use crate::{resolver::SrvResolver, SrvRecord};
use arc_swap::ArcSwap;
use futures_util::{
pin_mut,
stream::{self, Stream, StreamExt},
FutureExt,
};
use http::uri::{Scheme, Uri};
use std::{fmt::Debug, future::Future, iter::FromIterator, sync::Arc, time::Instant};
mod cache;
pub use cache::Cache;
pub mod policy;
#[derive(Debug, thiserror::Error)]
pub enum Error<Lookup: Debug> {
#[error("SRV lookup error")]
Lookup(Lookup),
#[error("building uri from SRV record: {0}")]
RecordParsing(#[from] http::Error),
#[error("no SRV targets to use")]
NoTargets,
}
/// Client that resolves a SRV name into candidate URIs and runs operations against them.
///
/// `Resolver` performs the SRV lookups. `Policy` (default: `policy::Affinity`) orders the
/// candidates, rebuilds the cache, and is told about the outcome of each attempt.
#[derive(Debug)]
pub struct SrvClient<Resolver, Policy = policy::Affinity>
where
    Policy: policy::Policy,
{
    /// SRV name passed to the resolver.
    srv: String,
    /// Backend used to perform SRV lookups.
    resolver: Resolver,
    /// Scheme used when building URIs from SRV records.
    http_scheme: Scheme,
    /// Path prefix used when building URIs from SRV records.
    path_prefix: String,
    /// Strategy for ordering candidates and recording outcomes.
    policy: Policy,
    /// Policy-specific cache, swapped atomically whenever it is refreshed.
    cache: ArcSwap<Cache<Policy::CacheItem>>,
}
/// How `SrvClient::execute` and `SrvClient::execute_stream` try the candidate URIs
/// produced by the client's policy.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Execution {
    /// Try candidates one at a time, in the order chosen by the policy; the next attempt
    /// only starts once the previous one has finished.
    Serial,
    /// Start an attempt for every candidate up front and yield results in the order
    /// they complete.
    Concurrent,
}

impl Default for Execution {
    /// Defaults to [`Execution::Serial`].
    fn default() -> Self {
        Self::Serial
    }
}
impl<Resolver, Policy> SrvClient<Resolver, Policy>
where
    Resolver: Default,
    Policy: policy::Policy + Default,
{
    /// Creates a client for `srv_name` with a default-constructed resolver.
    ///
    /// All other settings come from `new_with_resolver`.
    pub fn new(srv_name: impl ToString) -> Self {
        let resolver = Resolver::default();
        Self::new_with_resolver(srv_name, resolver)
    }
}
impl<Resolver, Policy> SrvClient<Resolver, Policy>
where
    Policy: policy::Policy + Default,
{
    /// Creates a client for `srv_name` that uses the given `resolver`.
    ///
    /// The client starts with the `https` scheme, a `/` path prefix, a default policy
    /// and a default cache.
    pub fn new_with_resolver(srv_name: impl ToString, resolver: Resolver) -> Self {
        let srv = srv_name.to_string();
        Self {
            srv,
            resolver,
            http_scheme: Scheme::HTTPS,
            path_prefix: "/".to_owned(),
            policy: Default::default(),
            cache: Default::default(),
        }
    }
}
impl<Resolver: SrvResolver, Policy: policy::Policy> SrvClient<Resolver, Policy> {
pub async fn get_srv_records(
&self,
) -> Result<(Vec<Resolver::Record>, Instant), Error<Resolver::Error>> {
self.resolver
.get_srv_records(&self.srv)
.await
.map_err(Error::Lookup)
}
pub async fn get_fresh_uri_candidates(
&self,
) -> Result<(Vec<Uri>, Instant), Error<Resolver::Error>> {
let (records, valid_until) = self.get_srv_records().await?;
let uris = records
.iter()
.map(|record| self.parse_record(record))
.collect::<Result<Vec<Uri>, _>>()?;
Ok((uris, valid_until))
}
async fn refresh_cache(&self) -> Result<Arc<Cache<Policy::CacheItem>>, Error<Resolver::Error>> {
let new_cache = Arc::new(self.policy.refresh_cache(self).await?);
self.cache.store(new_cache.clone());
Ok(new_cache)
}
async fn get_valid_cache(
&self,
) -> Result<Arc<Cache<Policy::CacheItem>>, Error<Resolver::Error>> {
match self.cache.load_full() {
cache if cache.valid() => Ok(cache),
_ => self.refresh_cache().await,
}
}
pub async fn execute_stream<'a, T, E, Fut>(
&'a self,
execution_mode: Execution,
func: impl FnMut(Uri) -> Fut + 'a,
) -> Result<impl Stream<Item = Result<T, E>> + 'a, Error<Resolver::Error>>
where
E: std::error::Error,
Fut: Future<Output = Result<T, E>> + 'a,
{
let mut func = func;
let cache = self.get_valid_cache().await?;
let order = self.policy.order(cache.items());
let func = {
let cache = cache.clone();
move |idx| {
let candidate = Policy::cache_item_to_uri(&cache.items()[idx]);
func(candidate.to_owned()).map(move |res| (idx, res))
}
};
let results = match execution_mode {
Execution::Serial => stream::iter(order).then(func).left_stream(),
#[allow(clippy::from_iter_instead_of_collect)]
Execution::Concurrent => {
stream::FuturesUnordered::from_iter(order.map(func)).right_stream()
}
};
let results = results.map(move |(candidate_idx, result)| {
let candidate = Policy::cache_item_to_uri(&cache.items()[candidate_idx]);
match result {
Ok(res) => {
#[cfg(feature = "log")]
tracing::info!(URI = %candidate, "execution attempt succeeded");
self.policy.note_success(candidate);
Ok(res)
}
Err(err) => {
#[cfg(feature = "log")]
tracing::info!(URI = %candidate, error = %err, "execution attempt failed");
self.policy.note_failure(candidate);
Err(err)
}
}
});
Ok(results)
}
pub async fn execute<T, E, Fut>(
&self,
execution_mode: Execution,
func: impl FnMut(Uri) -> Fut,
) -> Result<Result<T, E>, Error<Resolver::Error>>
where
E: std::error::Error,
Fut: Future<Output = Result<T, E>>,
{
let results = self.execute_stream(execution_mode, func).await?;
pin_mut!(results);
let mut last_error = None;
while let Some(result) = results.next().await {
match result {
Ok(res) => return Ok(Ok(res)),
Err(err) => last_error = Some(err),
}
}
if let Some(err) = last_error {
Ok(Err(err))
} else {
Err(Error::NoTargets)
}
}
fn parse_record(&self, record: &Resolver::Record) -> Result<Uri, http::Error> {
record.parse(self.http_scheme.clone(), self.path_prefix.as_str())
}
}
impl<Resolver, Policy: policy::Policy> SrvClient<Resolver, Policy> {
    /// Replaces the SRV name this client resolves.
    ///
    /// The cache is reset. It was built by the policy from this client's previous
    /// configuration, and `Policy::cache_item_to_uri` turns cached items into URIs without
    /// consulting the client. Keeping the cache would therefore keep targeting the old
    /// name's hosts until the cache expired.
    pub fn srv_name(self, srv_name: impl ToString) -> Self {
        Self {
            srv: srv_name.to_string(),
            cache: Default::default(),
            ..self
        }
    }

    /// Replaces the resolver, yielding a client over the new resolver type.
    ///
    /// The cache is reset so that results obtained through the old resolver are not reused.
    pub fn resolver<R>(self, resolver: R) -> SrvClient<R, Policy> {
        SrvClient {
            resolver,
            cache: Default::default(),
            policy: self.policy,
            srv: self.srv,
            http_scheme: self.http_scheme,
            path_prefix: self.path_prefix,
        }
    }

    /// Replaces the policy, yielding a client over the new policy type.
    ///
    /// The cache is reset because its item type belongs to the policy.
    pub fn policy<P: policy::Policy>(self, policy: P) -> SrvClient<Resolver, P> {
        SrvClient {
            policy,
            cache: Default::default(),
            resolver: self.resolver,
            srv: self.srv,
            http_scheme: self.http_scheme,
            path_prefix: self.path_prefix,
        }
    }

    /// Sets the scheme used when building URIs from SRV records.
    ///
    /// The cache is reset because cached items may hold URIs built with the previous
    /// scheme (items are converted to URIs without consulting the client).
    pub fn http_scheme(self, http_scheme: Scheme) -> Self {
        Self {
            http_scheme,
            cache: Default::default(),
            ..self
        }
    }

    /// Sets the path prefix used when building URIs from SRV records.
    ///
    /// The cache is reset because cached items may hold URIs built with the previous
    /// prefix (items are converted to URIs without consulting the client).
    pub fn path_prefix(self, path_prefix: impl ToString) -> Self {
        Self {
            path_prefix: path_prefix.to_string(),
            cache: Default::default(),
            ..self
        }
    }
}