soda-pool 0.0.4

Connection pool for tonic's gRPC channels
Documentation
use std::{io::Result, net::IpAddr};

#[cfg(not(any(test, feature = "_mock-dns")))]
pub use std::net::ToSocketAddrs;

#[cfg(any(test, feature = "_mock-dns"))]
pub use mock_net::ToSocketAddrs;

/// Resolves `domain` to the IP addresses it currently maps to.
///
/// The lookup goes through [`ToSocketAddrs`] with a placeholder port of `0`.
/// Only the IP part of each returned socket address is kept, so the port never
/// matters. Addresses are yielded in the order the resolver produced them, with
/// no de-duplication.
///
/// In normal builds the `std::net` resolver is used, and it may block the
/// calling thread while a DNS lookup is performed. In test builds, or with the
/// `_mock-dns` feature, the mocked resolver from the `mock_net` module is used
/// instead.
///
/// # Errors
///
/// Returns the resolver's [`std::io::Error`] unchanged if the lookup fails.
pub fn resolve_domain(domain: &str) -> Result<impl Iterator<Item = IpAddr>> {
    let resolved = (domain, 0).to_socket_addrs()?;
    Ok(resolved.map(|socket_addr| socket_addr.ip()))
}

/// Test double for host-name resolution.
///
/// Exposes a [`ToSocketAddrs`] trait whose method mirrors
/// `std::net::ToSocketAddrs::to_socket_addrs`, so callers can switch between the
/// real and the mocked resolver with a `cfg`-gated `use`. Every lookup is
/// forwarded to a single process-wide closure installed with
/// [`set_socket_addrs`]. Until one is installed, every lookup succeeds with an
/// empty address list.
#[cfg(any(test, feature = "_mock-dns"))]
#[cfg_attr(coverage_nightly, coverage(off))]
pub mod mock_net {
    use std::{
        io,
        net::SocketAddr,
        sync::{LazyLock, RwLock},
        vec,
    };

    /// Signature of a mocked resolver: `(host, port)` -> the addresses it "resolves" to.
    type ToSocketAddrsFn = dyn Fn(&str, u16) -> io::Result<Vec<SocketAddr>> + Send + Sync;

    /// The resolver currently in effect. The initial one returns no addresses.
    static DNS_RESULT: LazyLock<RwLock<Box<ToSocketAddrsFn>>> = LazyLock::new(|| {
        let resolve_nothing: Box<ToSocketAddrsFn> = Box::new(|_host, _port| Ok(Vec::new()));
        RwLock::new(resolve_nothing)
    });

    /// Minimal stand-in for `std::net::ToSocketAddrs`.
    ///
    /// It has the same method name and error type, but lookups are answered by
    /// the installed mock resolver.
    pub trait ToSocketAddrs {
        /// Iterator over the resolved socket addresses.
        type Iter: Iterator<Item = SocketAddr>;

        /// Resolves `self` into socket addresses.
        fn to_socket_addrs(&self) -> io::Result<Self::Iter>;
    }

    impl ToSocketAddrs for (&str, u16) {
        type Iter = vec::IntoIter<SocketAddr>;

        /// Passes `(host, port)` to the installed mock resolver and returns its
        /// result, forwarding any error unchanged.
        ///
        /// # Panics
        ///
        /// Panics if the lock guarding the resolver is poisoned.
        fn to_socket_addrs(&self) -> io::Result<Self::Iter> {
            let (host, port) = *self;
            // The read guard is held while the mock resolver runs.
            let resolver = DNS_RESULT
                .read()
                .expect("failed to acquire read lock on DNS_RESULT");
            let addrs = (*resolver)(host, port)?;
            Ok(addrs.into_iter())
        }
    }

    /// Installs `func` as the resolver for every subsequent lookup.
    ///
    /// The previously installed resolver is dropped.
    ///
    /// # Panics
    ///
    /// Panics if the lock guarding the resolver is poisoned.
    // NOTE(review): `dead_code` is presumably allowed for builds where the
    // `_mock-dns` feature is enabled but nothing calls this — confirm.
    #[allow(dead_code)]
    pub fn set_socket_addrs(func: Box<ToSocketAddrsFn>) {
        let mut current = DNS_RESULT
            .write()
            .expect("failed to acquire write lock on DNS_RESULT");
        *current = func;
    }
}

#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use std::{io, net::SocketAddr};

    use serial_test::serial;

    use super::*;

    /// `resolve_domain` yields exactly the IPs produced by the mocked resolver, in order.
    #[test]
    #[serial]
    fn can_mock_address_resolution() {
        let expected: Vec<IpAddr> = ["128.0.0.1", "129.0.0.1", "::2", "::3"]
            .iter()
            .map(|text| text.parse().unwrap())
            .collect();

        let mocked: Vec<SocketAddr> = expected
            .iter()
            .map(|&ip| SocketAddr::new(ip, 0))
            .collect();
        mock_net::set_socket_addrs(Box::new(move |_, _| Ok(mocked.clone())));

        let resolved: Vec<IpAddr> = resolve_domain("localhost").unwrap().collect();
        assert_eq!(resolved, expected);
    }

    /// Errors raised by the resolver reach the caller with kind and message intact.
    #[test]
    #[serial]
    fn forwards_errors() {
        mock_net::set_socket_addrs(Box::new(|_, _| Err(io::Error::other("mock error"))));

        let result = resolve_domain("localhost");
        assert!(result.is_err());
        // `unwrap_err` is not usable: the Ok type (`impl Iterator`) is not `Debug`.
        let Err(error) = result else { unreachable!() };

        assert_eq!(error.kind(), io::ErrorKind::Other);
        assert_eq!(error.to_string(), "mock error");
    }
}