detsys_srv/client/mod.rs

1//! Clients based on SRV lookups.
2
3use crate::{resolver::SrvResolver, SrvRecord};
4use arc_swap::ArcSwap;
5use http::uri::Scheme;
6use std::{fmt::Debug, future::Future, sync::Arc, time::Instant};
7use url::Url;
8
9mod cache;
10pub use cache::Cache;
11
12/// SRV target selection policies.
13pub mod policy;
14
15/// Errors encountered by a [`SrvClient`].
16#[derive(Debug, thiserror::Error)]
17pub enum Error<Lookup: Debug> {
18    /// SRV lookup errors
19    #[error("SRV lookup error")]
20    Lookup(Lookup),
21    /// SRV record parsing errors
22    #[error("building url from SRV record: {0}")]
23    RecordParsing(#[from] url::ParseError),
24    /// Produced when there are no SRV targets for a client to use
25    #[error("no SRV targets to use")]
26    NoTargets,
27}
28
29/// Client for intelligently performing operations on a service located by SRV records.
30///
31/// # Usage
32///
33/// After being created by [`SrvClient::new`] or [`SrvClient::new_with_resolver`],
34/// operations can be performed on the service pointed to by a [`SrvClient`] with
35/// the [`execute`] and [`execute_stream`] methods.
36///
37/// ## DNS Resolvers
38///
39/// The resolver used to lookup SRV records is determined by a client's
40/// [`SrvResolver`], and can be set with [`SrvClient::resolver`].
41///
42/// ## SRV Target Selection Policies
43///
44/// SRV target selection order is determined by a client's [`Policy`],
45/// and can be set with [`SrvClient::policy`].
46///
47/// [`execute`]: SrvClient::execute()
48/// [`execute_stream`]: SrvClient::execute_stream()
49/// [`Policy`]: policy::Policy
50#[derive(Debug)]
51pub struct SrvClient<Resolver, Policy: policy::Policy = policy::Affinity> {
52    srv: String,
53    fallback: url::Url,
54    allowed_suffixes: Option<Vec<url::Host>>,
55    resolver: Resolver,
56    http_scheme: Scheme,
57    path_prefix: String,
58    policy: Policy,
59    cache: ArcSwap<Cache<Policy::CacheItem>>,
60}
61
62impl<Resolver: Default, Policy: policy::Policy + Default> SrvClient<Resolver, Policy> {
63    /// Creates a new client for communicating with services located by `srv_name`.
64    ///
65    pub fn new(
66        srv_name: impl ToString,
67        fallback: url::Url,
68        allowed_suffixes: Option<Vec<url::Host>>,
69    ) -> Self {
70        Self::new_with_resolver(srv_name, fallback, allowed_suffixes, Resolver::default())
71    }
72}
73
74impl<Resolver, Policy: policy::Policy + Default> SrvClient<Resolver, Policy> {
75    /// Creates a new client for communicating with services located by `srv_name`.
76    pub fn new_with_resolver(
77        srv_name: impl ToString,
78        fallback: url::Url,
79        allowed_suffixes: Option<Vec<url::Host>>,
80        resolver: Resolver,
81    ) -> Self {
82        Self {
83            srv: srv_name.to_string(),
84            fallback,
85            allowed_suffixes,
86            resolver,
87            http_scheme: Scheme::HTTPS,
88            path_prefix: String::from("/"),
89            policy: Default::default(),
90            cache: Default::default(),
91        }
92    }
93}
94
95impl<Resolver: SrvResolver, Policy: policy::Policy> SrvClient<Resolver, Policy> {
96    /// Gets a fresh set of SRV records from a client's DNS resolver, returning
97    /// them along with the time they're valid until.
98    async fn get_srv_records(
99        &self,
100    ) -> Result<(Vec<Resolver::Record>, Instant), Error<Resolver::Error>> {
101        self.resolver
102            .get_srv_records(&self.srv)
103            .await
104            .map_err(Error::Lookup)
105    }
106
107    /// Gets a fresh set of SRV records from a client's DNS resolver and parses
108    /// their target/port pairs into URIs, which are returned along with the
109    /// time they're valid until--i.e., the time a cache containing these URIs
110    /// should expire.
111    pub async fn get_fresh_uri_candidates(
112        &self,
113    ) -> Result<(Vec<Url>, Instant), Error<Resolver::Error>> {
114        // Query DNS for the SRV record
115        let (records, valid_until) = self.get_srv_records().await?;
116
117        // Create URIs from SRV records
118        let uri_iter = records
119            .iter()
120            .map(|record| self.parse_record(record))
121            .filter_map(|parsed| match parsed {
122                Ok(record) => Some(record),
123                Err(e) => {
124                    tracing::trace!(%e, "Failed to parse an SRV record");
125                    None
126                }
127            });
128
129        let uris = if let Some(allowed_suffixes) = &self.allowed_suffixes {
130            use url::Host;
131
132            let mut allowed_ipv4 = Vec::<&std::net::Ipv4Addr>::new();
133            let mut allowed_ipv6 = Vec::<&std::net::Ipv6Addr>::new();
134            let mut allowed_domains = Vec::<&str>::new();
135
136            for suffix in allowed_suffixes {
137                match suffix {
138                    Host::Ipv4(ip) => {
139                        allowed_ipv4.push(ip);
140                    }
141                    Host::Ipv6(ip) => {
142                        allowed_ipv6.push(ip);
143                    }
144                    Host::Domain(d) => {
145                        allowed_domains.push(d);
146                    }
147                }
148            }
149
150            uri_iter
151                .filter(|record| {
152                    let allow = match record.host() {
153                        None => false,
154                        Some(Host::Ipv4(ip)) => allowed_ipv4.contains(&&ip),
155                        Some(Host::Ipv6(ip)) => allowed_ipv6.contains(&&ip),
156                        Some(Host::Domain(candidate)) => allowed_domains
157                            .iter()
158                            .any(|allowed| candidate.ends_with(allowed)),
159                    };
160
161                    if !allow {
162                        tracing::trace!(
163                            %record,
164                            "Rejecting SRV record because it is not allowed by the allowed suffixes"
165                        );
166                    }
167
168                    allow
169                })
170                .collect::<Vec<Url>>()
171        } else {
172            uri_iter.collect::<Vec<Url>>()
173        };
174
175        Ok((uris, valid_until))
176    }
177
178    async fn refresh_cache(&self) -> Result<Arc<Cache<Policy::CacheItem>>, Error<Resolver::Error>> {
179        let new_cache = Arc::new(self.policy.refresh_cache(self).await?);
180        self.cache.store(new_cache.clone());
181        Ok(new_cache)
182    }
183
184    /// Gets a client's cached items, refreshing the existing cache if it is invalid.
185    async fn get_valid_cache(
186        &self,
187    ) -> Result<Arc<Cache<Policy::CacheItem>>, Error<Resolver::Error>> {
188        match self.cache.load_full() {
189            cache if cache.valid() => Ok(cache),
190            _ => self.refresh_cache().await,
191        }
192    }
193
194    /// Performs an operation on a client's SRV targets, producing the first
195    /// successful result or the last error encountered if every execution of
196    /// the operation was unsuccessful.
197    ///
198    pub async fn execute<T, E, Fut>(&self, func: impl FnMut(Url) -> Fut) -> Result<T, E>
199    where
200        E: std::error::Error,
201        Fut: Future<Output = Result<T, E>>,
202    {
203        let mut func = func;
204        let cache = match self.get_valid_cache().await {
205            Ok(c) => c,
206            Err(e) => {
207                tracing::trace!(%e, "No valid cache");
208                return func(self.fallback.clone()).await;
209            }
210        };
211
212        let order = self.policy.order(cache.items());
213        let cache_items = order.map(|idx| &cache.items()[idx]);
214
215        for cache_item in cache_items.into_iter() {
216            let candidate = Policy::cache_item_to_uri(cache_item);
217
218            match func(candidate.to_owned()).await {
219                Ok(res) => {
220                    tracing::trace!(URI = %candidate, "execution attempt succeeded");
221                    self.policy.note_success(candidate);
222                    return Ok(res);
223                }
224                Err(err) => {
225                    tracing::trace!(URI = %candidate, error = %err, "execution attempt failed");
226                    self.policy.note_failure(candidate);
227                }
228            }
229        }
230
231        func(self.fallback.clone()).await
232    }
233
234    fn parse_record(&self, record: &Resolver::Record) -> Result<Url, url::ParseError> {
235        record.parse(self.http_scheme.clone())
236    }
237}
238
239impl<Resolver, Policy: policy::Policy> SrvClient<Resolver, Policy> {
240    /// Sets the SRV name of the client.
241    pub fn srv_name(self, srv_name: impl ToString) -> Self {
242        Self {
243            srv: srv_name.to_string(),
244            ..self
245        }
246    }
247
248    /// Sets the resolver of the client.
249    pub fn resolver<R>(self, resolver: R) -> SrvClient<R, Policy> {
250        SrvClient {
251            resolver,
252            cache: Default::default(),
253            policy: self.policy,
254            srv: self.srv,
255            fallback: self.fallback,
256            allowed_suffixes: self.allowed_suffixes,
257            http_scheme: self.http_scheme,
258            path_prefix: self.path_prefix,
259        }
260    }
261
262    /// Sets the policy of the client.
263    pub fn policy<P: policy::Policy>(self, policy: P) -> SrvClient<Resolver, P> {
264        SrvClient {
265            policy,
266            cache: Default::default(),
267            resolver: self.resolver,
268            srv: self.srv,
269            fallback: self.fallback,
270            allowed_suffixes: self.allowed_suffixes,
271            http_scheme: self.http_scheme,
272            path_prefix: self.path_prefix,
273        }
274    }
275
276    /// Sets the http scheme of the client.
277    pub fn http_scheme(self, http_scheme: Scheme) -> Self {
278        Self {
279            http_scheme,
280            ..self
281        }
282    }
283
284    /// Sets the path prefix of the client.
285    pub fn path_prefix(self, path_prefix: impl ToString) -> Self {
286        Self {
287            path_prefix: path_prefix.to_string(),
288            ..self
289        }
290    }
291}