use std::io;
use domain_core::bits::name::{
Dname, ParsedDname, ParsedDnameError, ToRelativeDname, ToDname
};
use domain_core::iana::Rtype;
use domain_core::rdata::parsed::{A, Aaaa, Srv};
use rand;
use rand::distributions::{Distribution, Uniform};
use tokio::prelude::{Async, Future, Poll, Stream};
use crate::resolver::Resolver;
use super::host::{FoundHosts, FoundHostsSocketIter, LookupHost, lookup_host};
/// Starts an SRV lookup for `service` at `name`.
///
/// The name actually queried is `service` followed by `name`. The returned
/// future resolves to a single fallback target (`name` on `fallback_port`)
/// if the answer contains no matching SRV records or if the SRV query
/// itself fails. It resolves to `None` if the only SRV record points at the
/// root name.
///
/// If `service` and `name` together do not form a valid domain name, no
/// query is sent and the returned future fails with `SrvError::LongName`
/// when polled.
pub fn lookup_srv<R, S, N>(
    resolver: R,
    service: S,
    name: N,
    fallback_port: u16,
) -> LookupSrv<R, S, N>
where
    R: Resolver,
    S: ToRelativeDname + Clone + Send + 'static,
    N: ToDname + Send + 'static,
{
    // The chained name borrows `service` and `name`; keep that borrow
    // confined to this statement so both can be moved afterwards.
    let query = match (&service).chain(&name) {
        Ok(full_name) => resolver.query((full_name, Rtype::Srv)),
        Err(_) => {
            return LookupSrv {
                data: None,
                query: Err(Some(SrvError::LongName)),
            }
        }
    };

    let data = LookupData {
        resolver,
        host: name,
        service,
        fallback_port,
    };
    LookupSrv {
        data: Some(data),
        query: Ok(query),
    }
}
/// State kept by `LookupSrv` while the SRV query is running.
///
/// Once the query resolves, it is handed to `FoundSrvs::new` or, if the
/// query failed, to `FoundSrvs::new_dummy`.
#[derive(Debug)]
struct LookupData<R, S, N> {
// The resolver; moved into the resulting `FoundSrvs`.
resolver: R,
// The host name; becomes the target of the fallback item.
host: N,
// The service name; cloned into every item built from SRV records.
service: S,
// Port of the fallback item used when there are no SRV records.
fallback_port: u16,
}
/// Future returned by `lookup_srv`.
///
/// Resolves to `Some(FoundSrvs)` with the targets to try, or to `None` if
/// the SRV answer says the service is not available.
pub struct LookupSrv<R: Resolver, S, N> {
// Lookup state. `None` if the query could not be started, and taken
// (leaving `None`) once the future has produced its result.
data: Option<LookupData<R, S, N>>,
// The running SRV query, or the error that prevented starting it. The
// error is wrapped in `Option` so the first poll can move it out.
query: Result<R::Query, Option<SrvError>>,
}
impl<R, S, N> Future for LookupSrv<R, S, N>
where
    R: Resolver,
    S: ToRelativeDname + Clone + Send + 'static,
    N: ToDname + Send + 'static,
{
    type Item = Option<FoundSrvs<R, S>>;
    type Error = SrvError;

    /// Drives the SRV query.
    ///
    /// A successful answer is turned into a target list by `FoundSrvs::new`.
    /// If the query fails, the future still succeeds, using the single
    /// fallback target built by `FoundSrvs::new_dummy`.
    ///
    /// # Panics
    ///
    /// Panics if polled again after it has resolved.
    fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
        // A query that could not even be started yields its stored error.
        let query = match self.query {
            Ok(ref mut running) => running,
            Err(ref mut failure) => {
                return Err(failure.take().expect("polled resolved future"))
            }
        };

        let outcome = match query.poll() {
            Ok(Async::NotReady) => return Ok(Async::NotReady),
            Ok(Async::Ready(answer)) => {
                let state = self.data.take().expect("polled resolved future");
                FoundSrvs::new(answer, state)?
            }
            Err(_) => {
                let state = self.data.take().expect("polled resolved future");
                Some(FoundSrvs::new_dummy(state))
            }
        };
        Ok(Async::Ready(outcome))
    }
}
/// Stream of resolved SRV targets, created by `FoundSrvs::into_stream`.
pub struct LookupSrvStream<R: Resolver, S> {
// Resolver used to look up the addresses of unresolved targets.
resolver: R,
// Remaining items in reverse order, so the next item is the last element.
items: Vec<SrvItem<S>>,
// Address lookup currently running for the last element of `items`.
lookup: Option<LookupHost<R>>
}
impl<R: Resolver, S> LookupSrvStream<R, S> {
fn new(found: FoundSrvs<R, S>) -> Self {
LookupSrvStream {
resolver: found.resolver,
items: found.items.into_iter().rev().collect(),
lookup: None,
}
}
}
impl<R, S> Stream for LookupSrvStream<R, S>
where
    R: Resolver,
    S: ToRelativeDname + Clone + Send + 'static,
{
    type Item = ResolvedSrvItem<S>;
    type Error = SrvError;

    /// Yields the SRV targets in order.
    ///
    /// Targets that already carry addresses are yielded directly. For
    /// unresolved targets, an address lookup is started and the target is
    /// yielded once it completes. Targets whose address lookup fails are
    /// skipped. The stream itself never returns an error.
    fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
        // Loop rather than recurse so that skipping many failed targets
        // cannot grow the call stack.
        loop {
            // First drive the address lookup for the last item, if one is
            // running.
            if let Some(mut query) = self.lookup.take() {
                match query.poll() {
                    Ok(Async::NotReady) => {
                        // Keep the lookup for the next poll.
                        self.lookup = Some(query);
                        return Ok(Async::NotReady);
                    }
                    Ok(Async::Ready(found)) => {
                        let item = self
                            .items
                            .pop()
                            .expect("host lookup running without a pending SRV item");
                        return Ok(Async::Ready(Some(
                            ResolvedSrvItem::from_item_and_hosts(item, found),
                        )));
                    }
                    Err(_) => {
                        // Drop this target and continue with the next one.
                        // Keeping it would restart the lookup for the same
                        // host over and over.
                        self.items.pop();
                        continue;
                    }
                }
            }

            // No lookup running: examine the next item (the last element).
            let lookup = match self.items.last() {
                None => return Ok(Async::Ready(None)),
                Some(item) => match item.state {
                    SrvItemState::Unresolved(ref host) => {
                        Some(lookup_host(&self.resolver, host))
                    }
                    SrvItemState::Resolved(_) => None,
                },
            };

            match lookup {
                // Start resolving it; the next loop iteration polls it.
                Some(lookup) => self.lookup = Some(lookup),
                None => {
                    let item = self.items.pop().expect("item was just inspected");
                    return Ok(Async::Ready(Some(
                        ResolvedSrvItem::from_item(item).expect("item is resolved"),
                    )));
                }
            }
        }
    }
}
/// Targets found by an SRV lookup, together with the resolver used for it.
///
/// Items built from SRV records are kept in the order produced by
/// `reorder_items` (ascending priority, weighted-random within a priority).
#[derive(Clone, Debug)]
pub struct FoundSrvs<R, S> {
// Resolver used later to look up addresses of unresolved targets.
resolver: R,
// The targets, in the order they are handed out by `into_stream`.
items: Vec<SrvItem<S>>,
}
impl<R, S> FoundSrvs<R, S> {
    /// Converts the targets into a stream that yields them in order,
    /// resolving the addresses of targets as needed.
    pub fn into_stream(self) -> LookupSrvStream<R, S>
    where
        R: Resolver,
    {
        LookupSrvStream::new(self)
    }

    /// Moves all items of `other` into `self` and reorders the combined
    /// list by priority and weight. `other` is left without items.
    pub fn merge(&mut self, other: &mut Self) {
        self.items.extend(other.items.drain(..));
        Self::reorder_items(&mut self.items);
    }
}
impl<R: Resolver, S: Clone> FoundSrvs<R, S> {
fn new<N: ToDname>(
answer: R::Answer,
data: LookupData<R, S, N>
) -> Result<Option<Self>, SrvError> {
let name = answer.as_ref().canonical_name().unwrap();
let mut rrs = Vec::new();
Self::process_records(&mut rrs, &answer, &name)?;
if rrs.len() == 0 {
return Ok(Some(Self::new_dummy(data)))
}
if rrs.len() == 1 && rrs[0].target().is_root() {
return Ok(None)
}
let mut items = Vec::with_capacity(rrs.len());
Self::items_from_rrs(&rrs, &answer, &mut items, &data)?;
Self::reorder_items(&mut items);
Ok(Some(FoundSrvs {
resolver: data.resolver,
items
}))
}
fn new_dummy<N: ToDname>(data: LookupData<R, S, N>) -> Self {
FoundSrvs {
resolver: data.resolver,
items: vec![
SrvItem {
priority: 0,
weight: 0,
port: data.fallback_port,
service: None,
state: SrvItemState::Unresolved(data.host.to_name())
}
]
}
}
fn process_records(
rrs: &mut Vec<Srv>,
answer: &R::Answer,
name: &ParsedDname
) -> Result<(), SrvError> {
for record in answer.as_ref().answer()?.limit_to::<Srv>() {
if let Ok(record) = record {
if record.owner() == name {
rrs.push(record.data().clone())
}
}
}
Ok(())
}
fn items_from_rrs<N>(
rrs: &[Srv],
answer: &R::Answer,
result: &mut Vec<SrvItem<S>>,
data: &LookupData<R, S, N>,
) -> Result<(), SrvError> {
for rr in rrs {
let mut addrs = Vec::new();
let name = rr.target().to_name();
for record in answer.as_ref().additional()?.limit_to::<A>() {
if let Ok(record) = record {
if record.owner() == &name {
addrs.push(record.data().addr().into())
}
}
}
for record in answer.as_ref().additional()?.limit_to::<Aaaa>() {
if let Ok(record) = record {
if record.owner() == &name {
addrs.push(record.data().addr().into())
}
}
}
let state = if addrs.is_empty() {
SrvItemState::Unresolved(name)
}
else {
SrvItemState::Resolved(FoundHosts::new(name, addrs))
};
result.push(SrvItem {
priority: rr.priority(),
weight: rr.weight(),
state: state,
port: rr.port(),
service: Some(data.service.clone())
})
}
Ok(())
}
}
impl<R, S> FoundSrvs<R, S> {
    /// Orders `items` following the target selection rules of RFC 2782.
    ///
    /// Items are sorted by ascending priority. Each run of items sharing a
    /// priority is then put into a weighted random order by
    /// `reorder_by_weight`.
    fn reorder_items(items: &mut [SrvItem<S>]) {
        items.sort_by_key(|k| (k.priority, k.weight));
        let mut current_prio = 0;
        let mut weight_sum: u32 = 0;
        let mut first_index = 0;
        for i in 0..items.len() {
            if current_prio != items[i].priority {
                // A new priority group starts here; order the previous one.
                current_prio = items[i].priority;
                Self::reorder_by_weight(&mut items[first_index..i], weight_sum);
                weight_sum = 0;
                first_index = i;
            }
            weight_sum += u32::from(items[i].weight);
        }
        Self::reorder_by_weight(&mut items[first_index..], weight_sum);
    }

    /// Puts a slice of same-priority items into a weighted random order.
    ///
    /// `weight_sum` is expected to be the sum of the weights of `items`.
    /// For each position `i`, a number is drawn uniformly from
    /// `0..=remaining`, where `remaining` is the total weight of the items
    /// not yet placed (`items[i..]`). The first of those items whose running
    /// weight sum reaches the number is moved to position `i`.
    fn reorder_by_weight(items: &mut [SrvItem<S>], weight_sum: u32) {
        let mut rng = rand::thread_rng();
        let mut remaining = weight_sum;
        for i in 0..items.len() {
            let range = Uniform::new(0, remaining + 1);
            let pick: u32 = range.sample(&mut rng);
            // Scan only the items not placed yet. Scanning from index 0
            // would re-pick placed items and subtract their weight a second
            // time, which can underflow `remaining`.
            let chosen = items[i..]
                .iter()
                .scan(0u32, |running, item| {
                    *running += u32::from(item.weight);
                    Some(*running)
                })
                .position(|running| running >= pick)
                .map_or(items.len() - 1, |offset| i + offset);
            remaining = remaining.saturating_sub(u32::from(items[chosen].weight));
            // Rotate rather than swap: this keeps the unplaced items in their
            // original weight-ascending order, so zero-weight items stay at
            // the front as RFC 2782 requires.
            items[i..=chosen].rotate_right(1);
        }
    }
}
/// A single SRV target.
#[derive(Clone, Debug)]
pub struct SrvItem<S> {
// Priority from the SRV record; items are sorted by ascending priority.
priority: u16,
// Weight from the SRV record; used for random ordering within a priority.
weight: u16,
// Port to use with this target.
port: u16,
// The service name; `None` for the fallback item built by `new_dummy`.
service: Option<S>,
// The target's name, or its name together with its addresses.
state: SrvItemState
}
/// Resolution state of an SRV target.
#[derive(Clone, Debug)]
pub enum SrvItemState {
// Only the target's name is known; its addresses still need a lookup.
Unresolved(Dname),
// The target's addresses are known (e.g. from the additional section).
Resolved(FoundHosts)
}
impl<S> SrvItem<S> {
pub fn txt_service(&self) -> Option<&S> {
self.service.as_ref()
}
pub fn target(&self) -> &Dname {
match self.state {
SrvItemState::Unresolved(ref target) => target,
SrvItemState::Resolved(ref found_hosts) => found_hosts.canonical_name()
}
}
}
/// An SRV target whose host addresses are known.
#[derive(Clone, Debug)]
pub struct ResolvedSrvItem<S> {
// Priority copied from the originating `SrvItem`.
priority: u16,
// Weight copied from the originating `SrvItem`.
weight: u16,
// Port to combine with each of the host's addresses.
port: u16,
// Service copied from the originating `SrvItem`.
service: Option<S>,
// The resolved host and its addresses.
hosts: FoundHosts,
}
impl<S> ResolvedSrvItem<S> {
pub fn to_socket_addrs(&self) -> FoundHostsSocketIter {
self.hosts.port_iter(self.port)
}
fn from_item(item: SrvItem<S>) -> Option<Self> {
if let SrvItemState::Resolved(hosts) = item.state {
Some(ResolvedSrvItem {
priority: item.priority,
weight: item.weight,
port: item.port,
service: item.service,
hosts: hosts
})
}
else {
None
}
}
fn from_item_and_hosts(item: SrvItem<S>, hosts: FoundHosts) -> Self {
ResolvedSrvItem {
priority: item.priority,
weight: item.weight,
port: item.port,
service: item.service,
hosts: hosts
}
}
}
/// Errors that can occur during an SRV lookup.
#[derive(Debug)]
pub enum SrvError {
    /// The service name and host name together do not form a valid
    /// domain name.
    LongName,
    /// The answer to the SRV query could not be interpreted.
    MalformedAnswer,
    /// An I/O error occurred.
    Query(io::Error),
}

impl std::fmt::Display for SrvError {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        match *self {
            SrvError::LongName => f.write_str("name too long"),
            SrvError::MalformedAnswer => f.write_str("malformed answer"),
            SrvError::Query(ref err) => write!(f, "query failed: {}", err),
        }
    }
}

impl std::error::Error for SrvError {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match *self {
            SrvError::Query(ref err) => Some(err),
            SrvError::LongName | SrvError::MalformedAnswer => None,
        }
    }
}
/// Converts an I/O error into `SrvError::Query`.
impl From<io::Error> for SrvError {
    fn from(source: io::Error) -> Self {
        SrvError::Query(source)
    }
}
/// Reports any domain name parsing failure as `SrvError::MalformedAnswer`.
/// The details of the parse error are discarded.
impl From<ParsedDnameError> for SrvError {
    fn from(_parse_error: ParsedDnameError) -> Self {
        SrvError::MalformedAnswer
    }
}