1use crate::{resolver::SrvResolver, SrvRecord};
4use arc_swap::ArcSwap;
5use http::uri::Scheme;
6use std::{fmt::Debug, future::Future, sync::Arc, time::Instant};
7use url::Url;
8
9mod cache;
10pub use cache::Cache;
11
12pub mod policy;
14
15#[derive(Debug, thiserror::Error)]
17pub enum Error<Lookup: Debug> {
18 #[error("SRV lookup error")]
20 Lookup(Lookup),
21 #[error("building url from SRV record: {0}")]
23 RecordParsing(#[from] url::ParseError),
24 #[error("no SRV targets to use")]
26 NoTargets,
27}
28
29#[derive(Debug)]
51pub struct SrvClient<Resolver, Policy: policy::Policy = policy::Affinity> {
52 srv: String,
53 fallback: url::Url,
54 allowed_suffixes: Option<Vec<url::Host>>,
55 resolver: Resolver,
56 http_scheme: Scheme,
57 path_prefix: String,
58 policy: Policy,
59 cache: ArcSwap<Cache<Policy::CacheItem>>,
60}
61
62impl<Resolver: Default, Policy: policy::Policy + Default> SrvClient<Resolver, Policy> {
63 pub fn new(
66 srv_name: impl ToString,
67 fallback: url::Url,
68 allowed_suffixes: Option<Vec<url::Host>>,
69 ) -> Self {
70 Self::new_with_resolver(srv_name, fallback, allowed_suffixes, Resolver::default())
71 }
72}
73
74impl<Resolver, Policy: policy::Policy + Default> SrvClient<Resolver, Policy> {
75 pub fn new_with_resolver(
77 srv_name: impl ToString,
78 fallback: url::Url,
79 allowed_suffixes: Option<Vec<url::Host>>,
80 resolver: Resolver,
81 ) -> Self {
82 Self {
83 srv: srv_name.to_string(),
84 fallback,
85 allowed_suffixes,
86 resolver,
87 http_scheme: Scheme::HTTPS,
88 path_prefix: String::from("/"),
89 policy: Default::default(),
90 cache: Default::default(),
91 }
92 }
93}
94
95impl<Resolver: SrvResolver, Policy: policy::Policy> SrvClient<Resolver, Policy> {
96 async fn get_srv_records(
99 &self,
100 ) -> Result<(Vec<Resolver::Record>, Instant), Error<Resolver::Error>> {
101 self.resolver
102 .get_srv_records(&self.srv)
103 .await
104 .map_err(Error::Lookup)
105 }
106
107 pub async fn get_fresh_uri_candidates(
112 &self,
113 ) -> Result<(Vec<Url>, Instant), Error<Resolver::Error>> {
114 let (records, valid_until) = self.get_srv_records().await?;
116
117 let uri_iter = records
119 .iter()
120 .map(|record| self.parse_record(record))
121 .filter_map(|parsed| match parsed {
122 Ok(record) => Some(record),
123 Err(e) => {
124 tracing::trace!(%e, "Failed to parse an SRV record");
125 None
126 }
127 });
128
129 let uris = if let Some(allowed_suffixes) = &self.allowed_suffixes {
130 use url::Host;
131
132 let mut allowed_ipv4 = Vec::<&std::net::Ipv4Addr>::new();
133 let mut allowed_ipv6 = Vec::<&std::net::Ipv6Addr>::new();
134 let mut allowed_domains = Vec::<&str>::new();
135
136 for suffix in allowed_suffixes {
137 match suffix {
138 Host::Ipv4(ip) => {
139 allowed_ipv4.push(ip);
140 }
141 Host::Ipv6(ip) => {
142 allowed_ipv6.push(ip);
143 }
144 Host::Domain(d) => {
145 allowed_domains.push(d);
146 }
147 }
148 }
149
150 uri_iter
151 .filter(|record| {
152 let allow = match record.host() {
153 None => false,
154 Some(Host::Ipv4(ip)) => allowed_ipv4.contains(&&ip),
155 Some(Host::Ipv6(ip)) => allowed_ipv6.contains(&&ip),
156 Some(Host::Domain(candidate)) => allowed_domains
157 .iter()
158 .any(|allowed| candidate.ends_with(allowed)),
159 };
160
161 if !allow {
162 tracing::trace!(
163 %record,
164 "Rejecting SRV record because it is not allowed by the allowed suffixes"
165 );
166 }
167
168 allow
169 })
170 .collect::<Vec<Url>>()
171 } else {
172 uri_iter.collect::<Vec<Url>>()
173 };
174
175 Ok((uris, valid_until))
176 }
177
178 async fn refresh_cache(&self) -> Result<Arc<Cache<Policy::CacheItem>>, Error<Resolver::Error>> {
179 let new_cache = Arc::new(self.policy.refresh_cache(self).await?);
180 self.cache.store(new_cache.clone());
181 Ok(new_cache)
182 }
183
184 async fn get_valid_cache(
186 &self,
187 ) -> Result<Arc<Cache<Policy::CacheItem>>, Error<Resolver::Error>> {
188 match self.cache.load_full() {
189 cache if cache.valid() => Ok(cache),
190 _ => self.refresh_cache().await,
191 }
192 }
193
194 pub async fn execute<T, E, Fut>(&self, func: impl FnMut(Url) -> Fut) -> Result<T, E>
199 where
200 E: std::error::Error,
201 Fut: Future<Output = Result<T, E>>,
202 {
203 let mut func = func;
204 let cache = match self.get_valid_cache().await {
205 Ok(c) => c,
206 Err(e) => {
207 tracing::trace!(%e, "No valid cache");
208 return func(self.fallback.clone()).await;
209 }
210 };
211
212 let order = self.policy.order(cache.items());
213 let cache_items = order.map(|idx| &cache.items()[idx]);
214
215 for cache_item in cache_items.into_iter() {
216 let candidate = Policy::cache_item_to_uri(cache_item);
217
218 match func(candidate.to_owned()).await {
219 Ok(res) => {
220 tracing::trace!(URI = %candidate, "execution attempt succeeded");
221 self.policy.note_success(candidate);
222 return Ok(res);
223 }
224 Err(err) => {
225 tracing::trace!(URI = %candidate, error = %err, "execution attempt failed");
226 self.policy.note_failure(candidate);
227 }
228 }
229 }
230
231 func(self.fallback.clone()).await
232 }
233
234 fn parse_record(&self, record: &Resolver::Record) -> Result<Url, url::ParseError> {
235 record.parse(self.http_scheme.clone())
236 }
237}
238
239impl<Resolver, Policy: policy::Policy> SrvClient<Resolver, Policy> {
240 pub fn srv_name(self, srv_name: impl ToString) -> Self {
242 Self {
243 srv: srv_name.to_string(),
244 ..self
245 }
246 }
247
248 pub fn resolver<R>(self, resolver: R) -> SrvClient<R, Policy> {
250 SrvClient {
251 resolver,
252 cache: Default::default(),
253 policy: self.policy,
254 srv: self.srv,
255 fallback: self.fallback,
256 allowed_suffixes: self.allowed_suffixes,
257 http_scheme: self.http_scheme,
258 path_prefix: self.path_prefix,
259 }
260 }
261
262 pub fn policy<P: policy::Policy>(self, policy: P) -> SrvClient<Resolver, P> {
264 SrvClient {
265 policy,
266 cache: Default::default(),
267 resolver: self.resolver,
268 srv: self.srv,
269 fallback: self.fallback,
270 allowed_suffixes: self.allowed_suffixes,
271 http_scheme: self.http_scheme,
272 path_prefix: self.path_prefix,
273 }
274 }
275
276 pub fn http_scheme(self, http_scheme: Scheme) -> Self {
278 Self {
279 http_scheme,
280 ..self
281 }
282 }
283
284 pub fn path_prefix(self, path_prefix: impl ToString) -> Self {
286 Self {
287 path_prefix: path_prefix.to_string(),
288 ..self
289 }
290 }
291}