srv-rs 0.2.0

Rust client for communicating with SRV-located services
Documentation
//! Clients based on SRV lookups.

use crate::{resolver::SrvResolver, SrvRecord};
use arc_swap::ArcSwap;
use futures_util::{
    pin_mut,
    stream::{self, Stream, StreamExt},
    FutureExt,
};
use http::uri::{Scheme, Uri};
use std::{fmt::Debug, future::Future, iter::FromIterator, sync::Arc, time::Instant};

mod cache;
pub use cache::Cache;

/// SRV target selection policies.
pub mod policy;

/// Errors encountered by a [`SrvClient`].
#[derive(Debug, thiserror::Error)]
pub enum Error<Lookup: Debug> {
    /// SRV lookup errors
    #[error("SRV lookup error")]
    Lookup(Lookup),
    /// SRV record parsing errors
    #[error("building uri from SRV record: {0}")]
    RecordParsing(#[from] http::Error),
    /// Produced when there are no SRV targets for a client to use
    #[error("no SRV targets to use")]
    NoTargets,
}

/// Client for intelligently performing operations on a service located by SRV records.
///
/// # Usage
///
/// After being created by [`SrvClient::new`] or [`SrvClient::new_with_resolver`],
/// operations can be performed on the service pointed to by a [`SrvClient`] with
/// the [`execute`] and [`execute_stream`] methods.
///
/// ## DNS Resolvers
///
/// The resolver used to lookup SRV records is determined by a client's
/// [`SrvResolver`], and can be set with [`SrvClient::resolver`].
///
/// ## SRV Target Selection Policies
///
/// SRV target selection order is determined by a client's [`Policy`],
/// and can be set with [`SrvClient::policy`].
///
/// [`execute`]: SrvClient::execute()
/// [`execute_stream`]: SrvClient::execute_stream()
/// [`Policy`]: policy::Policy
#[derive(Debug)]
pub struct SrvClient<Resolver, Policy: policy::Policy = policy::Affinity> {
    srv: String,
    resolver: Resolver,
    http_scheme: Scheme,
    path_prefix: String,
    policy: Policy,
    cache: ArcSwap<Cache<Policy::CacheItem>>,
}

/// Execution mode to use when performing an operation on SRV targets.
pub enum Execution {
    /// Operations are performed *serially* (i.e. one after the other).
    Serial,
    /// Operations are performed *concurrently* (i.e. all at once).
    /// Note that this does not imply parallelism--no additional tasks are spawned.
    Concurrent,
}

impl Default for Execution {
    fn default() -> Self {
        Self::Serial
    }
}

impl<Resolver: Default, Policy: policy::Policy + Default> SrvClient<Resolver, Policy> {
    /// Creates a new client for communicating with services located by `srv_name`.
    ///
    /// # Examples
    /// ```
    /// use srv_rs::{SrvClient, resolver::libresolv::LibResolv};
    /// let client = SrvClient::<LibResolv>::new("_http._tcp.example.com");
    /// ```
    pub fn new(srv_name: impl ToString) -> Self {
        Self::new_with_resolver(srv_name, Resolver::default())
    }
}

impl<Resolver, Policy: policy::Policy + Default> SrvClient<Resolver, Policy> {
    /// Creates a new client for communicating with services located by `srv_name`.
    pub fn new_with_resolver(srv_name: impl ToString, resolver: Resolver) -> Self {
        Self {
            srv: srv_name.to_string(),
            resolver,
            http_scheme: Scheme::HTTPS,
            path_prefix: String::from("/"),
            policy: Default::default(),
            cache: Default::default(),
        }
    }
}

impl<Resolver: SrvResolver, Policy: policy::Policy> SrvClient<Resolver, Policy> {
    /// Gets a fresh set of SRV records from a client's DNS resolver, returning
    /// them along with the time they're valid until.
    pub async fn get_srv_records(
        &self,
    ) -> Result<(Vec<Resolver::Record>, Instant), Error<Resolver::Error>> {
        self.resolver
            .get_srv_records(&self.srv)
            .await
            .map_err(Error::Lookup)
    }

    /// Gets a fresh set of SRV records from a client's DNS resolver and parses
    /// their target/port pairs into URIs, which are returned along with the
    /// time they're valid until--i.e., the time a cache containing these URIs
    /// should expire.
    pub async fn get_fresh_uri_candidates(
        &self,
    ) -> Result<(Vec<Uri>, Instant), Error<Resolver::Error>> {
        // Query DNS for the SRV record
        let (records, valid_until) = self.get_srv_records().await?;

        // Create URIs from SRV records
        let uris = records
            .iter()
            .map(|record| self.parse_record(record))
            .collect::<Result<Vec<Uri>, _>>()?;

        Ok((uris, valid_until))
    }

    async fn refresh_cache(&self) -> Result<Arc<Cache<Policy::CacheItem>>, Error<Resolver::Error>> {
        let new_cache = Arc::new(self.policy.refresh_cache(self).await?);
        self.cache.store(new_cache.clone());
        Ok(new_cache)
    }

    /// Gets a client's cached items, refreshing the existing cache if it is invalid.
    async fn get_valid_cache(
        &self,
    ) -> Result<Arc<Cache<Policy::CacheItem>>, Error<Resolver::Error>> {
        match self.cache.load_full() {
            cache if cache.valid() => Ok(cache),
            _ => self.refresh_cache().await,
        }
    }

    /// Performs an operation on all of a client's SRV targets, producing a
    /// stream of results (one for each target). If the serial execution mode is
    /// specified, the operation will be performed on each target in the order
    /// determined by the current [`Policy`], and the results will be returned
    /// in the same order. If the concurrent execution mode is specified, the
    /// operation will be performed on all targets concurrently, and results
    /// will be returned in the order they become available.
    ///
    /// # Examples
    ///
    /// ```
    /// # use srv_rs::EXAMPLE_SRV;
    /// use srv_rs::{SrvClient, Error, Execution};
    /// use srv_rs::resolver::libresolv::{LibResolv, LibResolvError};
    ///
    /// # #[tokio::main]
    /// # async fn main() -> Result<(), Error<LibResolvError>> {
    /// # let client = SrvClient::<LibResolv>::new(EXAMPLE_SRV);
    /// let results_stream = client.execute_stream(Execution::Serial, |address| async move {
    ///     Ok::<_, std::convert::Infallible>(address.to_string())
    /// })
    /// .await?;
    /// // Do something with the stream, for example collect all results into a `Vec`:
    /// use futures::stream::StreamExt;
    /// let results: Vec<Result<_, _>> = results_stream.collect().await;
    /// for result in results {
    ///     assert!(result.is_ok());
    /// }
    /// # Ok(())
    /// # }
    /// ```
    ///
    /// [`Policy`]: policy::Policy
    pub async fn execute_stream<'a, T, E, Fut>(
        &'a self,
        execution_mode: Execution,
        func: impl FnMut(Uri) -> Fut + 'a,
    ) -> Result<impl Stream<Item = Result<T, E>> + 'a, Error<Resolver::Error>>
    where
        E: std::error::Error,
        Fut: Future<Output = Result<T, E>> + 'a,
    {
        let mut func = func;
        let cache = self.get_valid_cache().await?;
        let order = self.policy.order(cache.items());
        let func = {
            let cache = cache.clone();
            move |idx| {
                let candidate = Policy::cache_item_to_uri(&cache.items()[idx]);
                func(candidate.to_owned()).map(move |res| (idx, res))
            }
        };
        let results = match execution_mode {
            Execution::Serial => stream::iter(order).then(func).left_stream(),
            #[allow(clippy::from_iter_instead_of_collect)]
            Execution::Concurrent => {
                stream::FuturesUnordered::from_iter(order.map(func)).right_stream()
            }
        };
        let results = results.map(move |(candidate_idx, result)| {
            let candidate = Policy::cache_item_to_uri(&cache.items()[candidate_idx]);
            match result {
                Ok(res) => {
                    #[cfg(feature = "log")]
                    tracing::info!(URI = %candidate, "execution attempt succeeded");
                    self.policy.note_success(candidate);
                    Ok(res)
                }
                Err(err) => {
                    #[cfg(feature = "log")]
                    tracing::info!(URI = %candidate, error = %err, "execution attempt failed");
                    self.policy.note_failure(candidate);
                    Err(err)
                }
            }
        });
        Ok(results)
    }

    /// Performs an operation on a client's SRV targets, producing the first
    /// successful result or the last error encountered if every execution of
    /// the operation was unsuccessful.
    ///
    /// # Examples
    ///
    /// ```
    /// # use srv_rs::EXAMPLE_SRV;
    /// use srv_rs::{SrvClient, Error, Execution};
    /// use srv_rs::resolver::libresolv::{LibResolv, LibResolvError};
    ///
    /// # #[tokio::main]
    /// # async fn main() -> Result<(), Error<LibResolvError>> {
    /// let client = SrvClient::<LibResolv>::new(EXAMPLE_SRV);
    ///
    /// let res = client.execute(Execution::Serial, |address| async move {
    ///     Ok::<_, std::convert::Infallible>(address.to_string())
    /// })
    /// .await?;
    /// assert!(res.is_ok());
    ///
    /// let res = client.execute(Execution::Concurrent, |address| async move {
    ///     address.to_string().parse::<usize>()
    /// })
    /// .await?;
    /// assert!(res.is_err());
    /// # Ok(())
    /// # }
    /// ```
    pub async fn execute<T, E, Fut>(
        &self,
        execution_mode: Execution,
        func: impl FnMut(Uri) -> Fut,
    ) -> Result<Result<T, E>, Error<Resolver::Error>>
    where
        E: std::error::Error,
        Fut: Future<Output = Result<T, E>>,
    {
        let results = self.execute_stream(execution_mode, func).await?;
        pin_mut!(results);

        let mut last_error = None;
        while let Some(result) = results.next().await {
            match result {
                Ok(res) => return Ok(Ok(res)),
                Err(err) => last_error = Some(err),
            }
        }

        if let Some(err) = last_error {
            Ok(Err(err))
        } else {
            Err(Error::NoTargets)
        }
    }

    fn parse_record(&self, record: &Resolver::Record) -> Result<Uri, http::Error> {
        record.parse(self.http_scheme.clone(), self.path_prefix.as_str())
    }
}

impl<Resolver, Policy: policy::Policy> SrvClient<Resolver, Policy> {
    /// Sets the SRV name of the client.
    pub fn srv_name(self, srv_name: impl ToString) -> Self {
        Self {
            srv: srv_name.to_string(),
            ..self
        }
    }

    /// Sets the resolver of the client.
    pub fn resolver<R>(self, resolver: R) -> SrvClient<R, Policy> {
        SrvClient {
            resolver,
            cache: Default::default(),
            policy: self.policy,
            srv: self.srv,
            http_scheme: self.http_scheme,
            path_prefix: self.path_prefix,
        }
    }

    /// Sets the policy of the client.
    ///
    /// # Examples
    ///
    /// ```
    /// # use srv_rs::EXAMPLE_SRV;
    /// use srv_rs::{SrvClient, policy::Rfc2782, resolver::libresolv::LibResolv};
    /// let client = SrvClient::<LibResolv>::new(EXAMPLE_SRV).policy(Rfc2782);
    /// ```
    pub fn policy<P: policy::Policy>(self, policy: P) -> SrvClient<Resolver, P> {
        SrvClient {
            policy,
            cache: Default::default(),
            resolver: self.resolver,
            srv: self.srv,
            http_scheme: self.http_scheme,
            path_prefix: self.path_prefix,
        }
    }

    /// Sets the http scheme of the client.
    pub fn http_scheme(self, http_scheme: Scheme) -> Self {
        Self {
            http_scheme,
            ..self
        }
    }

    /// Sets the path prefix of the client.
    pub fn path_prefix(self, path_prefix: impl ToString) -> Self {
        Self {
            path_prefix: path_prefix.to_string(),
            ..self
        }
    }
}