use super::host::lookup_host;
use crate::base::iana::{Class, Rtype};
use crate::base::message::Message;
use crate::base::name::{Name, ToName, ToRelativeName};
use crate::base::wire::ParseError;
use crate::rdata::{Aaaa, Srv, A};
use crate::resolv::resolver::Resolver;
use core::fmt;
use futures_util::stream::{self, Stream, StreamExt};
use octseq::octets::Octets;
use rand::distr::{Distribution, Uniform};
use std::net::{IpAddr, SocketAddr};
use std::vec::Vec;
use std::{io, mem, ops};
/// Octets type backing the owned target names stored in this file's SRV
/// items. With the `smallvec` feature enabled this is `SmallOctets`.
#[cfg(feature = "smallvec")]
type OctetsVec = octseq::octets::SmallOctets;
/// Octets type backing the owned target names stored in this file's SRV
/// items. Without the `smallvec` feature this is a plain byte vector.
#[cfg(not(feature = "smallvec"))]
type OctetsVec = Vec<u8>;
/// Looks up the SRV records for `service` at the domain `name`.
///
/// The query name is formed by chaining `service` in front of `name`. One
/// SRV query for that name is sent through `resolver` and the answer is
/// turned into a `FoundSrvs` value.
///
/// On success the result is
///
/// * `Ok(Some(_))` with the found items, or with a single fallback item
///   built from `name` and `fallback_port` if the answer held no matching
///   SRV records,
/// * `Ok(None)` if the only matching SRV record has the root name as its
///   target.
///
/// # Errors
///
/// * `SrvError::LongName` if `service` and `name` cannot be chained,
/// * `SrvError::Query` if the resolver reports an I/O error,
/// * `SrvError::MalformedAnswer` if the answer cannot be processed.
pub async fn lookup_srv(
    resolver: &impl Resolver,
    service: impl ToRelativeName,
    name: impl ToName,
    fallback_port: u16,
) -> Result<Option<FoundSrvs>, SrvError> {
    let query_name = (&service)
        .chain(&name)
        .map_err(|_| SrvError::LongName)?;
    let response = resolver.query((query_name, Rtype::SRV)).await?;
    FoundSrvs::new(response.as_ref().for_slice(), name, fallback_port)
}
/// The items produced by an SRV lookup.
///
/// Can be turned into a stream of resolved items, into an iterator over the
/// plain SRV records, or merged with the result of another lookup.
#[derive(Clone, Debug)]
pub struct FoundSrvs {
/// `Ok` holds a list of items; `Err` holds exactly one item. The lookup
/// creates the `Err` case with the fallback item when the answer contained
/// no matching SRV records.
items: Result<Vec<SrvItem>, SrvItem>,
}
impl FoundSrvs {
    /// Converts the value into a stream of resolved SRV items.
    ///
    /// The items are processed in their stored order; each one is passed to
    /// `SrvItem::resolve` with `resolver`, so an item that already carries
    /// addresses is converted directly while any other item triggers a host
    /// lookup. A failed lookup is yielded as an `Err` element.
    pub fn into_stream<R: Resolver>(
        self,
        resolver: &R,
    ) -> impl Stream<Item = Result<ResolvedSrvItem, io::Error>> + '_
    where
        R::Octets: Octets,
    {
        // Split the two representations into a pair of options so that both
        // cases end up as one and the same iterator type.
        let (many, single) = match self.items {
            Ok(list) => (Some(list), None),
            Err(item) => (None, Some(item)),
        };
        let pending = many.into_iter().flatten().chain(single);
        stream::iter(pending).then(move |item| item.resolve(resolver))
    }

    /// Converts the value into an iterator over the bare SRV records.
    ///
    /// The records are yielded in the stored order; any addresses already
    /// attached to the items are dropped.
    pub fn into_srvs(self) -> impl Iterator<Item = Srv<Name<OctetsVec>>> {
        let (many, single) = match self.items {
            Ok(list) => (Some(list), None),
            Err(item) => (None, Some(item)),
        };
        many.into_iter()
            .flatten()
            .chain(single)
            .map(|item| item.srv)
    }

    /// Adds clones of all items of `other` to `self`.
    ///
    /// Afterwards `self` always holds a list, even if both values held only
    /// a single item, and the combined list has been passed through
    /// `reorder_items` again. `other` is left untouched.
    pub fn merge(&mut self, other: &Self) {
        // Take our own items out, turning a single item into a list.
        let mut merged = mem::replace(&mut self.items, Ok(Vec::new()))
            .unwrap_or_else(|single| vec![single]);
        match &other.items {
            Ok(list) => merged.extend_from_slice(list),
            Err(single) => merged.push(single.clone()),
        }
        Self::reorder_items(&mut merged);
        self.items = Ok(merged);
    }
}
impl FoundSrvs {
    /// Builds the lookup result from the answer to an SRV query.
    ///
    /// Only SRV records owned by the canonical name of `answer` are
    /// considered. The outcome is
    ///
    /// * `Ok(Some(_))` holding a single fallback item made from
    ///   `fallback_name` and `fallback_port` if there are no such records,
    /// * `Ok(None)` if there is exactly one such record and its target is
    ///   the root name,
    /// * `Ok(Some(_))` holding the ordered list of items otherwise, with
    ///   addresses from the additional section attached where available.
    ///
    /// # Errors
    ///
    /// Returns `SrvError::MalformedAnswer` if the message has no canonical
    /// name or if its answer or additional section cannot be parsed.
    fn new(
        answer: &Message<[u8]>,
        fallback_name: impl ToName,
        fallback_port: u16,
    ) -> Result<Option<Self>, SrvError> {
        let name =
            answer.canonical_name().ok_or(SrvError::MalformedAnswer)?;
        let mut items = Self::process_records(answer, &name)?;
        if items.is_empty() {
            // No SRV records at all: fall back to the plain host name.
            return Ok(Some(FoundSrvs {
                items: Err(SrvItem::fallback(fallback_name, fallback_port)),
            }));
        }
        if items.len() == 1 && items[0].target().is_root() {
            // A single record with the root target: nothing to return.
            return Ok(None);
        }
        Self::process_additional(&mut items, answer)?;
        Self::reorder_items(&mut items);
        Ok(Some(FoundSrvs { items: Ok(items) }))
    }

    /// Collects an item for every SRV record in the answer section of
    /// `answer` that is owned by `name`.
    ///
    /// Records that fail to parse are skipped.
    ///
    /// # Errors
    ///
    /// Fails if the answer section itself cannot be reached.
    fn process_records(
        answer: &Message<[u8]>,
        name: &impl ToName,
    ) -> Result<Vec<SrvItem>, SrvError> {
        let mut res = Vec::new();
        for record in answer.answer()?.limit_to_in::<Srv<_>>().flatten() {
            if record.owner() == name {
                res.push(SrvItem::from_rdata(record.data()))
            }
        }
        Ok(res)
    }

    /// Attaches addresses found in the additional section to `items`.
    ///
    /// For each item, the addresses of all class IN A and AAAA records
    /// owned by the item's target are gathered. If there is at least one,
    /// the list is stored in the item's `resolved` field; otherwise the
    /// item is left unchanged. Records that fail to parse are skipped.
    ///
    /// # Errors
    ///
    /// Fails if the additional section itself cannot be reached.
    fn process_additional(
        items: &mut [SrvItem],
        answer: &Message<[u8]>,
    ) -> Result<(), SrvError> {
        let additional = answer.additional()?;
        for item in items {
            let mut addrs = Vec::new();
            for record in additional {
                let record = match record {
                    Ok(record) => record,
                    Err(_) => continue,
                };
                if record.class() != Class::IN
                    || record.owner() != item.target()
                {
                    continue;
                }
                if let Ok(Some(record)) = record.to_record::<A>() {
                    addrs.push(record.data().addr().into())
                }
                if let Ok(Some(record)) = record.to_record::<Aaaa>() {
                    addrs.push(record.data().addr().into())
                }
            }
            if !addrs.is_empty() {
                item.resolved = Some(addrs)
            }
        }
        Ok(())
    }

    /// Orders `items` by priority and, within one priority, randomly by
    /// weight.
    ///
    /// The items are first sorted by `(priority, weight)`. Each run of
    /// items with equal priority is then handed to `reorder_by_weight`
    /// together with the sum of the weights in that run.
    fn reorder_items(items: &mut [SrvItem]) {
        items.sort_by_key(|k| (k.priority(), k.weight()));
        // `current_prio` starts at 0. If the first item has a different
        // priority, the first call below receives an empty slice and does
        // nothing.
        let mut current_prio = 0;
        let mut weight_sum = 0;
        let mut first_index = 0;
        for i in 0..items.len() {
            if current_prio != items[i].priority() {
                // A new priority starts at `i`: finish the previous group.
                current_prio = items[i].priority();
                Self::reorder_by_weight(
                    &mut items[first_index..i],
                    weight_sum,
                );
                weight_sum = 0;
                first_index = i;
            }
            weight_sum += u32::from(items[i].weight());
        }
        // The last group is not followed by a priority change.
        Self::reorder_by_weight(&mut items[first_index..], weight_sum);
    }

    /// Randomly orders the items of one priority group by their weight.
    ///
    /// `weight_sum` has to be the sum of the weights of all `items`.
    ///
    /// The slice is filled position by position. For each position a random
    /// number between 0 and the summed weight of the items not placed yet
    /// (inclusive) is drawn. The first unplaced item at which the running
    /// sum of weights reaches that number is swapped into the position, and
    /// its weight is removed from the remaining sum. Items with a larger
    /// weight are thus more likely to be placed early.
    fn reorder_by_weight(items: &mut [SrvItem], weight_sum: u32) {
        let mut rng = rand::rng();
        let mut weight_sum = weight_sum;
        for i in 0..items.len() {
            // The range `0..weight_sum + 1` always contains at least 0, so
            // constructing the distribution is not expected to fail.
            #[allow(clippy::unwrap_used)]
            let range = Uniform::new(0, weight_sum + 1).unwrap();
            let pick = range.sample(&mut rng);
            let mut sum: u32 = 0;
            // Scan only `items[i..]`. Positions before `i` are final; if
            // they were scanned too, a placed item could be picked again,
            // moved out of its position and its weight subtracted from
            // `weight_sum` a second time, which underflows.
            for j in i..items.len() {
                sum += u32::from(items[j].weight());
                if sum >= pick {
                    weight_sum -= u32::from(items[j].weight());
                    items.swap(i, j);
                    break;
                }
            }
        }
    }
}
/// A single SRV record, optionally with addresses already known for its
/// target.
#[derive(Clone, Debug)]
pub struct SrvItem {
/// The SRV record data with an owned target name.
srv: Srv<Name<OctetsVec>>,
/// Whether the item was created as a fallback instead of from an SRV
/// record. The visible code only ever sets this field, hence the lint
/// allowance.
#[allow(dead_code)] fallback: bool,
/// Addresses of the target collected from the additional section of the
/// answer, if there were any.
resolved: Option<Vec<IpAddr>>,
}
impl SrvItem {
fn from_rdata(srv: &Srv<impl ToName>) -> Self {
SrvItem {
srv: Srv::new(
srv.priority(),
srv.weight(),
srv.port(),
srv.target().to_name(),
),
fallback: false,
resolved: None,
}
}
fn fallback(name: impl ToName, fallback_port: u16) -> Self {
SrvItem {
srv: Srv::new(0, 0, fallback_port, name.to_name()),
fallback: true,
resolved: None,
}
}
pub async fn resolve<R: Resolver>(
self,
resolver: &R,
) -> Result<ResolvedSrvItem, io::Error>
where
R::Octets: Octets,
{
let port = self.port();
if let Some(resolved) = self.resolved {
return Ok(ResolvedSrvItem {
srv: self.srv,
resolved: {
resolved
.into_iter()
.map(|addr| SocketAddr::new(addr, port))
.collect()
},
});
}
let resolved = lookup_host(resolver, self.target()).await?;
Ok(ResolvedSrvItem {
srv: self.srv,
resolved: {
resolved
.iter()
.map(|addr| SocketAddr::new(addr, port))
.collect()
},
})
}
}
impl AsRef<Srv<Name<OctetsVec>>> for SrvItem {
    /// Returns a reference to the SRV record data of the item.
    fn as_ref(&self) -> &Srv<Name<OctetsVec>> {
        let Self { srv, .. } = self;
        srv
    }
}
impl ops::Deref for SrvItem {
    type Target = Srv<Name<OctetsVec>>;

    /// Gives direct access to the methods of the SRV record data.
    fn deref(&self) -> &Self::Target {
        &self.srv
    }
}
/// An SRV record together with the socket addresses it resolved to.
#[derive(Clone, Debug)]
pub struct ResolvedSrvItem {
/// The SRV record data with an owned target name.
srv: Srv<Name<OctetsVec>>,
/// The addresses of the target, each combined with the port of the SRV
/// record.
resolved: Vec<SocketAddr>,
}
impl ResolvedSrvItem {
    /// Returns the socket addresses the item resolved to.
    pub fn resolved(&self) -> &[SocketAddr] {
        self.resolved.as_slice()
    }
}
impl AsRef<Srv<Name<OctetsVec>>> for ResolvedSrvItem {
    /// Returns a reference to the SRV record data of the item.
    fn as_ref(&self) -> &Srv<Name<OctetsVec>> {
        let Self { srv, .. } = self;
        srv
    }
}
impl ops::Deref for ResolvedSrvItem {
    type Target = Srv<Name<OctetsVec>>;

    /// Gives direct access to the methods of the SRV record data.
    fn deref(&self) -> &Self::Target {
        &self.srv
    }
}
/// An error that happened while looking up SRV records.
#[derive(Debug)]
pub enum SrvError {
/// The service and the domain name could not be chained into one name.
LongName,
/// The answer had no canonical name or could not be parsed.
MalformedAnswer,
/// An I/O error occurred, for instance while executing the query.
Query(io::Error),
}
impl fmt::Display for SrvError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
SrvError::LongName => write!(f, "name too long"),
SrvError::MalformedAnswer => write!(f, "malformed answer"),
SrvError::Query(e) => write!(f, "error executing query {}", e),
}
}
}
// Uses the default methods of `Error` only: no `source()` is provided, the
// wrapped I/O error of `Query` is shown through the `Display` output.
impl std::error::Error for SrvError {}
impl From<io::Error> for SrvError {
fn from(err: io::Error) -> SrvError {
SrvError::Query(err)
}
}
impl From<ParseError> for SrvError {
fn from(_: ParseError) -> SrvError {
SrvError::MalformedAnswer
}
}