spf-milter 0.6.0

Milter for SPF verification
Documentation
// SPF Milter – milter for SPF verification
// Copyright © 2020–2023 David Bürgin <dbuergin@gluet.ch>
//
// This program is free software: you can redistribute it and/or modify it under
// the terms of the GNU General Public License as published by the Free Software
// Foundation, either version 3 of the License, or (at your option) any later
// version.
//
// This program is distributed in the hope that it will be useful, but WITHOUT
// ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
// FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
// details.
//
// You should have received a copy of the GNU General Public License along with
// this program. If not, see <https://www.gnu.org/licenses/>.

use async_trait::async_trait;
use domain::{
    base::{iana::Rcode, Dname, Rtype},
    rdata::{Aaaa, Mx, Txt, A},
    resolv::{
        stub::conf::{ResolvConf, ResolvOptions},
        StubResolver,
    },
};
use std::{
    error::Error,
    io::{self, ErrorKind},
    net::{IpAddr, Ipv4Addr, Ipv6Addr},
    str::FromStr,
    sync::Arc,
    time::Duration,
};
use viaspf::lookup::{Lookup, LookupError, LookupResult, Name};

/// The DNS lookup backend used for SPF evaluation.
pub enum Resolver {
    /// Real DNS queries through a [`DomainResolver`].
    Live(DomainResolver),
    /// An arbitrary [`Lookup`] implementation held behind a shared pointer.
    ///
    /// NOTE(review): the name suggests this is meant for tests or simulated
    /// DNS data — confirm with the callers.
    Mock(Arc<Box<dyn Lookup>>),
}

/// A [`Lookup`] implementation that sends queries through the `domain`
/// crate's [`StubResolver`].
pub struct DomainResolver {
    /// The underlying stub resolver that all lookups are delegated to.
    resolver: StubResolver,
}

impl DomainResolver {
    pub fn new(timeout: Duration) -> Self {
        let options = ResolvOptions {
            timeout,
            ..Default::default()
        };

        let mut conf = ResolvConf {
            options,
            ..Default::default()
        };

        conf.finalize();

        let resolver = StubResolver::from_conf(conf);

        Self { resolver }
    }
}

#[async_trait]
impl Lookup for DomainResolver {
    async fn lookup_a(&self, name: &Name) -> LookupResult<Vec<Ipv4Addr>> {
        let name = to_dname(name)?;
        self.resolver
            .query((name, Rtype::A))
            .await
            .map_err(to_lookup_error)?
            .answer()
            .map_err(wrap_error)?
            .limit_to::<A>()
            .map(|record| record.map(|r| r.data().addr()))
            .collect::<Result<Vec<_>, _>>()
            .map_err(wrap_error)
    }

    async fn lookup_aaaa(&self, name: &Name) -> LookupResult<Vec<Ipv6Addr>> {
        let name = to_dname(name)?;
        self.resolver
            .query((name, Rtype::Aaaa))
            .await
            .map_err(to_lookup_error)?
            .answer()
            .map_err(wrap_error)?
            .limit_to::<Aaaa>()
            .map(|record| record.map(|r| r.data().addr()))
            .collect::<Result<Vec<_>, _>>()
            .map_err(wrap_error)
    }

    async fn lookup_mx(&self, name: &Name) -> LookupResult<Vec<Name>> {
        let name = to_dname(name)?;
        let answer = self
            .resolver
            .query((name, Rtype::Mx))
            .await
            .map_err(to_lookup_error)?;
        let mut mxs = answer
            .answer()
            .map_err(wrap_error)?
            .limit_to::<Mx<_>>()
            .map(|record| record.map(|r| r.into_data()))
            .collect::<Result<Vec<_>, _>>()
            .map_err(wrap_error)?;
        mxs.sort_by_key(|mx| mx.preference());
        mxs.into_iter()
            .map(|mx| Name::new(&mx.exchange().to_string()))
            .collect::<Result<Vec<_>, _>>()
            .map_err(wrap_error)
    }

    async fn lookup_txt(&self, name: &Name) -> LookupResult<Vec<String>> {
        let name = to_dname(name)?;
        self.resolver
            .query((name, Rtype::Txt))
            .await
            .map_err(to_lookup_error)?
            .answer()
            .map_err(wrap_error)?
            .limit_to::<Txt<_>>()
            .map(|record| {
                record.map(|r| {
                    let text = r.into_data().text::<Vec<_>>();
                    String::from_utf8_lossy(&text).into_owned()
                })
            })
            .collect::<Result<Vec<_>, _>>()
            .map_err(wrap_error)
    }

    async fn lookup_ptr(&self, ip: IpAddr) -> LookupResult<Vec<Name>> {
        self.resolver
            .lookup_addr(ip)
            .await
            .map_err(to_lookup_error)?
            .into_iter()
            .map(|ptr| Name::new(&ptr.to_string()))
            .collect::<Result<Vec<_>, _>>()
            .map_err(wrap_error)
    }
}

fn to_dname(name: &Name) -> LookupResult<Dname<Vec<u8>>> {
    Dname::from_str(name.as_str()).map_err(wrap_error)
}

fn to_lookup_error(error: io::Error) -> LookupError {
    match error.kind() {
        ErrorKind::NotFound => LookupError::NoRecords,
        ErrorKind::TimedOut => LookupError::Timeout,
        _ => wrap_error(error),
    }
}

/// Wraps any thread-safe error as the source of a [`LookupError::Dns`].
fn wrap_error(error: impl Error + Send + Sync + 'static) -> LookupError {
    let source = Some(error).map(Into::into);
    LookupError::Dns(source)
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Live smoke test: forward-resolves a known domain for both address
    /// families, then reverse-resolves the first address of each.
    #[tokio::test]
    #[ignore = "depends on live DNS records"]
    async fn domain_resolver_lookup_ok() {
        let domain = Name::new("gluet.ch").unwrap();
        let resolver = DomainResolver::new(Duration::from_secs(30));

        // IPv4: A lookup, then PTR lookup of the first address found.
        let v4_addrs = resolver.lookup_a(&domain).await;
        assert!(v4_addrs.is_ok());
        let first_v4 = IpAddr::from(v4_addrs.unwrap().into_iter().next().unwrap());
        let v4_ptr = resolver.lookup_ptr(first_v4).await;
        assert!(v4_ptr.is_ok());

        // IPv6: AAAA lookup, then PTR lookup of the first address found.
        let v6_addrs = resolver.lookup_aaaa(&domain).await;
        assert!(v6_addrs.is_ok());
        let first_v6 = IpAddr::from(v6_addrs.unwrap().into_iter().next().unwrap());
        let v6_ptr = resolver.lookup_ptr(first_v6).await;
        assert!(v6_ptr.is_ok());
    }
}