use futures_util::stream::Stream;
use hickory_resolver::{
Resolver,
config::{ResolverConfig, ResolverOpts},
name_server::TokioConnectionProvider,
};
use std::collections::HashMap;
use std::net::IpAddr;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Duration;
use tokio::sync::mpsc;
use tokio::time::{MissedTickBehavior, interval};
use tower::discover::Change;
use tracing::{debug, error, trace};
/// Hickory resolver using the Tokio connection provider.
type TokioResolver = Resolver<TokioConnectionProvider>;
/// One discovery event: a tower [`Change`] keyed by resolved IP address whose inserted value
/// is the service URL, or a boxed error describing a failed resolution.
type DiscoveryResult = Result<Change<IpAddr, String>, Box<dyn std::error::Error + Send + Sync>>;
/// Receiving half of the unbounded channel that delivers [`DiscoveryResult`]s to a stream.
type DiscoveryReceiver = mpsc::UnboundedReceiver<DiscoveryResult>;
/// Sending half of that channel, owned by the background resolution task.
type DiscoverySender = mpsc::UnboundedSender<DiscoveryResult>;
/// Settings shared by [`DnsDiscovery`] and [`StaticDnsDiscovery`].
///
/// Create one with [`DnsDiscoveryConfig::new`] and adjust it with the `with_*` methods.
#[derive(Debug, Clone)]
pub struct DnsDiscoveryConfig {
    /// Hostname handed to the resolver's `lookup_ip`.
    pub hostname: String,
    /// Port appended to every resolved address when building its service URL.
    pub port: u16,
    /// Period between re-resolutions in [`DnsDiscovery`]; not read by [`StaticDnsDiscovery`].
    pub refresh_interval: Duration,
    /// When `true`, service URLs use the `https` scheme instead of `http`.
    pub use_https: bool,
    /// Optional custom resolver configuration; when `None`, the resolver is built from the
    /// system configuration.
    pub resolver_config: Option<ResolverConfig>,
    /// Optional custom resolver options.
    pub resolver_opts: Option<ResolverOpts>,
}
impl DnsDiscoveryConfig {
    /// Creates a configuration targeting `hostname` on `port`.
    ///
    /// Defaults: re-resolve every 30 seconds, use the plain `http` scheme, and carry no
    /// custom resolver configuration or options.
    pub fn new<S: Into<String>>(hostname: S, port: u16) -> Self {
        let hostname: String = hostname.into();
        Self {
            hostname,
            port,
            refresh_interval: Duration::from_secs(30),
            use_https: false,
            resolver_config: None,
            resolver_opts: None,
        }
    }

    /// Replaces the re-resolution period used by [`DnsDiscovery`].
    ///
    /// The value is used as a `tokio::time::interval` period, so it should be non-zero.
    pub fn with_refresh_interval(self, interval: Duration) -> Self {
        Self {
            refresh_interval: interval,
            ..self
        }
    }

    /// Chooses between the `https` (`true`) and `http` (`false`) scheme for service URLs.
    pub fn with_https(self, use_https: bool) -> Self {
        Self { use_https, ..self }
    }

    /// Supplies a custom resolver configuration instead of the system one.
    pub fn with_resolver_config(self, config: ResolverConfig) -> Self {
        Self {
            resolver_config: Some(config),
            ..self
        }
    }

    /// Supplies custom resolver options.
    pub fn with_resolver_opts(self, opts: ResolverOpts) -> Self {
        Self {
            resolver_opts: Some(opts),
            ..self
        }
    }
}
/// Stream of [`Change`] events produced by periodically re-resolving a hostname.
///
/// Clones share one receiver (behind `Arc<tokio::sync::Mutex<_>>`), so each event is delivered
/// to whichever clone polls it first rather than to every clone.
#[derive(Clone)]
pub struct DnsDiscovery {
    /// Shared receiving end of the channel fed by the background resolution task.
    receiver: Arc<tokio::sync::Mutex<DiscoveryReceiver>>,
    /// Handle of the spawned resolution task. Dropping a Tokio `JoinHandle` detaches the task;
    /// it does not cancel it.
    _handle: Arc<tokio::task::JoinHandle<()>>,
}
impl DnsDiscovery {
pub fn new(
config: DnsDiscoveryConfig,
) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
let resolver = if let (Some(resolver_config), Some(_opts)) =
(&config.resolver_config, &config.resolver_opts)
{
Resolver::builder_with_config(
resolver_config.clone(),
TokioConnectionProvider::default(),
)
.build()
} else {
Resolver::builder_tokio()
.map_err(|e| format!("Failed to create resolver from system config: {e}"))?
.build()
};
let (sender, receiver) = mpsc::unbounded_channel();
let handle = tokio::spawn(Self::discovery_task(config, resolver, sender));
Ok(Self {
receiver: Arc::new(tokio::sync::Mutex::new(receiver)),
_handle: Arc::new(handle),
})
}
async fn discovery_task(
config: DnsDiscoveryConfig,
resolver: TokioResolver,
sender: DiscoverySender,
) {
let mut current_services = HashMap::new();
let mut interval = interval(config.refresh_interval);
interval.set_missed_tick_behavior(MissedTickBehavior::Skip);
if let Err(e) =
Self::resolve_and_send(&config, &resolver, &mut current_services, &sender).await
{
error!("Initial DNS resolution failed: {}", e);
let _ = sender.send(Err(e));
}
loop {
interval.tick().await;
if let Err(e) =
Self::resolve_and_send(&config, &resolver, &mut current_services, &sender).await
{
error!("DNS resolution refresh failed: {}", e);
let _ = sender.send(Err(e));
}
}
}
async fn resolve_and_send(
config: &DnsDiscoveryConfig,
resolver: &TokioResolver,
current_services: &mut HashMap<IpAddr, String>,
sender: &DiscoverySender,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
trace!("Resolving DNS for hostname: {}", config.hostname);
let lookup_result = resolver
.lookup_ip(&config.hostname)
.await
.map_err(|e| format!("DNS lookup failed for {}: {}", config.hostname, e))?;
let mut new_services = HashMap::new();
let scheme = if config.use_https { "https" } else { "http" };
for ip in lookup_result.iter() {
let service_url = format!("{}://{}:{}", scheme, ip, config.port);
new_services.insert(ip, service_url);
}
for (ip, _) in current_services.iter() {
if !new_services.contains_key(ip) {
debug!("Removing service: {}", ip);
let _ = sender.send(Ok(Change::Remove(*ip)));
}
}
for (ip, service_url) in &new_services {
if !current_services.contains_key(ip) {
debug!("Adding service: {} -> {}", ip, service_url);
let _ = sender.send(Ok(Change::Insert(*ip, service_url.clone())));
}
}
*current_services = new_services;
Ok(())
}
}
impl Stream for DnsDiscovery {
    type Item = DiscoveryResult;

    /// Polls the shared receiver for the next discovery event.
    ///
    /// Yields `None` once the sending side of the channel has been dropped and buffered events
    /// are drained.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        match self.receiver.try_lock() {
            Ok(mut receiver) => receiver.poll_recv(cx),
            Err(_) => {
                // Another clone holds the receiver right now. Returning `Pending` without
                // registering the waker would leave this task parked forever, so schedule an
                // immediate re-poll instead.
                cx.waker().wake_by_ref();
                Poll::Pending
            }
        }
    }
}
/// Stream of [`Change`] events produced by resolving a hostname a single time.
pub struct StaticDnsDiscovery {
    /// Receiving end of the channel fed by the one-shot resolution task.
    receiver: DiscoveryReceiver,
    /// Handle of the spawned resolution task. Dropping a Tokio `JoinHandle` detaches the task;
    /// it does not cancel it.
    _handle: tokio::task::JoinHandle<()>,
}
impl StaticDnsDiscovery {
pub fn new(
config: DnsDiscoveryConfig,
) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
let resolver = if let (Some(resolver_config), Some(_opts)) =
(&config.resolver_config, &config.resolver_opts)
{
Resolver::builder_with_config(
resolver_config.clone(),
TokioConnectionProvider::default(),
)
.build()
} else {
Resolver::builder_tokio()
.map_err(|e| format!("Failed to create resolver from system config: {e}"))?
.build()
};
let (sender, receiver) = mpsc::unbounded_channel();
let handle = tokio::spawn(Self::static_discovery_task(config, resolver, sender));
Ok(Self {
receiver,
_handle: handle,
})
}
async fn static_discovery_task(
config: DnsDiscoveryConfig,
resolver: TokioResolver,
sender: DiscoverySender,
) {
trace!(
"Performing static DNS resolution for hostname: {}",
config.hostname
);
match resolver.lookup_ip(&config.hostname).await {
Ok(lookup_result) => {
let scheme = if config.use_https { "https" } else { "http" };
for ip in lookup_result.iter() {
let service_url = format!("{}://{}:{}", scheme, ip, config.port);
debug!(
"Discovered service: {} -> {}://{}:{}",
ip, scheme, ip, config.port
);
let _ = sender.send(Ok(Change::Insert(ip, service_url)));
}
}
Err(e) => {
error!(
"Static DNS resolution failed for {}: {}",
config.hostname, e
);
let _ = sender.send(Err(format!(
"DNS lookup failed for {}: {}",
config.hostname, e
)
.into()));
}
}
}
}
impl Stream for StaticDnsDiscovery {
    type Item = DiscoveryResult;

    /// Yields the next buffered discovery event.
    ///
    /// Returns `None` once the one-shot task has dropped its sender and every event has been
    /// drained.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // Both fields are `Unpin`, so unwrapping the pin is sound and lets us borrow mutably.
        let this = self.get_mut();
        this.receiver.poll_recv(cx)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// The builder methods overwrite only the field they target.
    #[test]
    fn test_dns_discovery_config() {
        let cfg = DnsDiscoveryConfig::new("example.com", 8080)
            .with_refresh_interval(Duration::from_secs(60))
            .with_https(true);

        assert_eq!(cfg.hostname, "example.com");
        assert_eq!(cfg.port, 8080);
        assert_eq!(cfg.refresh_interval, Duration::from_secs(60));
        assert!(cfg.use_https);
    }

    /// Constructing the one-shot discovery inside a runtime succeeds.
    #[tokio::test]
    async fn test_static_dns_discovery_creation() {
        let outcome = StaticDnsDiscovery::new(DnsDiscoveryConfig::new("localhost", 8080));
        assert!(outcome.is_ok());
    }

    /// Constructing the refreshing discovery inside a runtime succeeds.
    #[tokio::test]
    async fn test_dns_discovery_creation() {
        let outcome = DnsDiscovery::new(DnsDiscoveryConfig::new("localhost", 8080));
        assert!(outcome.is_ok());
    }
}