use std::sync::Arc;
use crate::conf::ConfDynSeed;
use crate::seeds::{SeedsError, SeedsProvider};
/// Raw answer produced by a [`Resolver`] for a seed-discovery name.
///
/// `DnsSeedsProvider` converts either shape into `ConfDynSeed` values by
/// running strings through `ConfDynSeed::parse`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ResolvedSeeds {
    /// Fully formed seed strings, one per entry, each parsed verbatim.
    Txt(Vec<String>),
    /// Bare addresses plus attributes shared by all of them; every address is
    /// expanded into an `ip:port:rack:dc:tokens` string before parsing.
    A {
        /// Addresses; one seed is produced per entry.
        ips: Vec<String>,
        /// Port applied to every address.
        port: u16,
        /// Rack copied verbatim into every synthesised seed string.
        rack: String,
        /// Datacenter copied verbatim into every synthesised seed string.
        dc: String,
        /// Token field copied verbatim into every synthesised seed string.
        tokens: String,
    },
}
/// Looks up the seed records published under a name.
///
/// The `Send + Sync` supertraits allow a `Box<dyn Resolver>` to be moved to or
/// shared between threads.
pub trait Resolver: Send + Sync {
    /// Resolves `name` into a raw [`ResolvedSeeds`] answer.
    ///
    /// # Errors
    ///
    /// Returns a [`SeedsError`] when the lookup fails; `DnsSeedsProvider`
    /// propagates it to its caller via `?`.
    fn resolve(&self, name: &str) -> Result<ResolvedSeeds, SeedsError>;
}
/// Lets a reference-counted resolver (including `Arc<dyn Resolver>`) be used
/// anywhere a [`Resolver`] is expected, by forwarding to the shared value.
impl<T> Resolver for Arc<T>
where
    T: Resolver + ?Sized,
{
    fn resolve(&self, name: &str) -> Result<ResolvedSeeds, SeedsError> {
        // Deref-coerce `&Arc<T>` to `&T` and delegate to the inner impl.
        let inner: &T = self;
        inner.resolve(name)
    }
}
/// Seed provider that asks a [`Resolver`] about a single configured name and
/// turns the answer into seeds.
pub struct DnsSeedsProvider {
    /// Name handed to the resolver on every `get_seeds` call
    /// (presumably a DNS name, given the type name — confirm with callers).
    name: String,
    /// Lookup backend; boxed as a trait object so any `Resolver` can be plugged in.
    resolver: Box<dyn Resolver>,
}
impl DnsSeedsProvider {
#[must_use]
pub fn new(name: String, resolver: Box<dyn Resolver>) -> Self {
Self { name, resolver }
}
#[must_use]
pub fn name(&self) -> &str {
&self.name
}
}
impl std::fmt::Debug for DnsSeedsProvider {
    /// Renders as `DnsSeedsProvider { name: "...", .. }`.
    ///
    /// The boxed resolver is elided because `Resolver` has no `Debug` bound.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("DnsSeedsProvider");
        builder.field("name", &self.name);
        builder.finish_non_exhaustive()
    }
}
/// Parses one seed string, converting any parser failure into
/// [`SeedsError::Parse`] carrying the parser's error message.
fn parse_seed(raw: &str) -> Result<ConfDynSeed, SeedsError> {
    ConfDynSeed::parse(raw).map_err(|e| SeedsError::Parse(e.to_string()))
}

impl SeedsProvider for DnsSeedsProvider {
    /// Resolves the configured name and converts every returned record into a
    /// [`ConfDynSeed`].
    ///
    /// * `Txt` — each entry is parsed verbatim.
    /// * `A` — each IP is combined with the shared `port`, `rack`, `dc` and
    ///   `tokens` into `ip:port:rack:dc:tokens` before parsing.
    ///
    /// An empty answer yields an empty vector.
    ///
    /// # Errors
    ///
    /// Propagates any resolver error, and returns [`SeedsError::Parse`] for the
    /// first record that fails to parse (later records are not examined).
    fn get_seeds(&self) -> Result<Vec<ConfDynSeed>, SeedsError> {
        match self.resolver.resolve(&self.name)? {
            ResolvedSeeds::Txt(entries) => entries.iter().map(|raw| parse_seed(raw)).collect(),
            ResolvedSeeds::A {
                ips,
                port,
                rack,
                dc,
                tokens,
            } => ips
                .iter()
                .map(|ip| parse_seed(&format!("{ip}:{port}:{rack}:{dc}:{tokens}")))
                .collect(),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Resolver that always answers with a fixed, pre-built result.
    struct StaticResolver(ResolvedSeeds);
    impl Resolver for StaticResolver {
        fn resolve(&self, _: &str) -> Result<ResolvedSeeds, SeedsError> {
            Ok(self.0.clone())
        }
    }

    /// Resolver that fails the test unless it is queried with the expected name.
    struct ExpectNameResolver(&'static str);
    impl Resolver for ExpectNameResolver {
        fn resolve(&self, name: &str) -> Result<ResolvedSeeds, SeedsError> {
            assert_eq!(name, self.0);
            Ok(ResolvedSeeds::Txt(Vec::new()))
        }
    }

    /// Builds a provider named "n" whose resolver returns `seeds`.
    fn provider(seeds: ResolvedSeeds) -> DnsSeedsProvider {
        DnsSeedsProvider::new("n".into(), Box::new(StaticResolver(seeds)))
    }

    #[test]
    fn txt_branch() {
        let p = provider(ResolvedSeeds::Txt(vec![
            "127.0.0.1:8101:rA:dc1:1".into(),
            "127.0.0.2:8101:rA:dc1:2".into(),
        ]));
        let v = p.get_seeds().unwrap();
        assert_eq!(v.len(), 2);
        assert_eq!(v[0].host(), "127.0.0.1");
        assert_eq!(v[1].host(), "127.0.0.2");
    }

    #[test]
    fn a_branch_synthesises_seed_format() {
        let p = provider(ResolvedSeeds::A {
            ips: vec!["10.0.0.1".into(), "10.0.0.2".into()],
            port: 8101,
            rack: "rA".into(),
            dc: "dc1".into(),
            tokens: "1".into(),
        });
        let v = p.get_seeds().unwrap();
        assert_eq!(v.len(), 2);
        assert_eq!(v[0].port(), 8101);
        assert_eq!(v[0].dc(), "dc1");
        assert_eq!(v[1].host(), "10.0.0.2");
        assert_eq!(v[1].port(), 8101);
    }

    #[test]
    fn empty_answers_yield_no_seeds() {
        assert!(provider(ResolvedSeeds::Txt(Vec::new()))
            .get_seeds()
            .unwrap()
            .is_empty());
        assert!(provider(ResolvedSeeds::A {
            ips: Vec::new(),
            port: 8101,
            rack: "rA".into(),
            dc: "dc1".into(),
            tokens: "1".into(),
        })
        .get_seeds()
        .unwrap()
        .is_empty());
    }

    #[test]
    fn resolver_receives_configured_name() {
        let p = DnsSeedsProvider::new(
            "seeds.example.com".into(),
            Box::new(ExpectNameResolver("seeds.example.com")),
        );
        assert_eq!(p.name(), "seeds.example.com");
        assert!(p.get_seeds().unwrap().is_empty());
    }

    #[test]
    fn arc_wrapped_resolver_is_accepted() {
        let shared = Arc::new(StaticResolver(ResolvedSeeds::Txt(vec![
            "127.0.0.1:8101:rA:dc1:1".into(),
        ])));
        let p = DnsSeedsProvider::new("n".into(), Box::new(Arc::clone(&shared)));
        assert_eq!(p.get_seeds().unwrap().len(), 1);
    }

    #[test]
    fn debug_omits_resolver() {
        let p = provider(ResolvedSeeds::Txt(Vec::new()));
        assert_eq!(format!("{p:?}"), r#"DnsSeedsProvider { name: "n", .. }"#);
    }

    #[test]
    fn parse_error_propagates() {
        let p = provider(ResolvedSeeds::Txt(vec!["invalid-seed".into()]));
        assert!(matches!(p.get_seeds(), Err(SeedsError::Parse(_))));
    }
}