doh-client 3.1.2

DNS over HTTPS client
Documentation
use crate::context::Context;
use crate::{Cache, DohError, DohResult};
use bytes::Bytes;
use dns_message_parser::question::Question;
use dns_message_parser::Dns;
use futures::lock::Mutex;
use std::future::Future;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use tokio::net::UdpSocket;
use tokio::time::timeout as create_timeout;

async fn send_response(
    dns_response: &mut Dns,
    id: u16,
    addr: SocketAddr,
    sender: Arc<UdpSocket>,
) -> DohResult<()> {
    dns_response.id = id;
    let bytes = dns_response.encode()?;
    sender.send_to(bytes.as_ref(), addr).await?;
    Ok(())
}

enum CacheReturn<'a> {
    Found(DohResult<()>),
    NotFound(Option<(&'a Mutex<Cache<Question, Dns>>, Question)>),
}

#[allow(clippy::needless_lifetimes)]
async fn get_response_from_cache<'a>(
    context: &'a Context,
    dns_request: &Dns,
    addr: &SocketAddr,
) -> CacheReturn<'a> {
    if let Some(cache) = &context.cache {
        let questions = &dns_request.questions;
        if questions.len() == 1 {
            let question = &questions[0];
            let mut guard_cache = cache.lock().await;
            let entry = if context.cache_fallback {
                guard_cache.get_expired(question)
            } else {
                guard_cache.get(question)
            };

            if let Some(dns_response) = entry {
                let id = dns_request.id;
                let sender = context.sender.clone();
                let addr = *addr;
                debug!("Question is found in cache");
                let result = send_response(dns_response, id, addr, sender).await;
                CacheReturn::Found(result)
            } else {
                debug!("Question is not found in cache");
                CacheReturn::NotFound(Some((cache, question.clone())))
            }
        } else {
            debug!("The amount of questions is not equal 1");
            CacheReturn::NotFound(None)
        }
    } else {
        debug!("Cache is disable");
        CacheReturn::NotFound(None)
    }
}

async fn get_response(
    context: &Context,
    cache_question: &Option<(&Mutex<Cache<Question, Dns>>, Question)>,
    response: (
        impl Future<Output = DohResult<(Dns, Option<Duration>)>>,
        u32,
    ),
    id: u16,
    addr: &SocketAddr,
) -> Option<DohResult<()>> {
    let (response_future, connection_id) = response;
    let timeout = context.timeout;
    match create_timeout(timeout, response_future).await {
        Ok(Ok((mut dns_response, duration))) => {
            let addr = *addr;
            let sender = context.sender.clone();
            let result = send_response(&mut dns_response, id, addr, sender).await;
            if let Some(duration) = duration {
                if let Some((cache, question)) = cache_question {
                    let mut guard_cache = cache.lock().await;
                    debug!(
                        "Add records in cache: {}, {}, {:?}",
                        question, dns_response, duration
                    );
                    guard_cache.put(question.clone(), dns_response, duration);
                }
            }
            return Some(result);
        }
        Ok(Err(e)) => {
            error!("Could not retrieve DNS response from server: {}", e);
        }
        Err(e) => {
            error!("Timeout: {}", e);
        }
    }
    let mut guard_remote_session = context.remote_session.lock().await;
    guard_remote_session.disconnect(connection_id);
    None
}

async fn get_response_from_remote(
    context: &Context,
    cache_question: &Option<(&Mutex<Cache<Question, Dns>>, Question)>,
    dns_request: &mut Dns,
    addr: &SocketAddr,
) -> Option<DohResult<()>> {
    let mut guard_remote_session = context.remote_session.lock().await;
    let result = guard_remote_session.start_request(dns_request).await;
    drop(guard_remote_session);
    match result {
        Ok(response) => {
            let id = dns_request.id;
            get_response(context, cache_question, response, id, addr).await
        }
        Err(e) => {
            info!("Could not contact DNS server: {}", e);
            None
        }
    }
}

#[allow(clippy::needless_lifetimes)]
async fn get_response_from_cache_fallback<'a>(
    context: &'a Context,
    cache_question: Option<(&Mutex<Cache<Question, Dns>>, Question)>,
    dns_request: &Dns,
    addr: SocketAddr,
) -> Option<DohResult<()>> {
    if context.cache_fallback {
        if let Some((cache, question)) = &cache_question {
            let mut guard_cache = cache.lock().await;
            if let Some(dns_response) = guard_cache.get_expired_fallback(question) {
                let id = dns_request.id;
                let sender = context.sender.clone();
                debug!("Question is found in cache fallback");
                let result = send_response(dns_response, id, addr, sender).await;
                Some(result)
            } else {
                debug!("Question is not found in cache fallback");
                None
            }
        } else {
            debug!("Question cannot be cached");
            None
        }
    } else {
        debug!("Cache fallback is disable");
        None
    }
}

pub async fn request_handler(
    msg: Bytes,
    addr: SocketAddr,
    context: &'static Context,
) -> DohResult<()> {
    let mut dns_request = Dns::decode(msg)?;
    if dns_request.is_response() {
        return Err(DohError::DnsNotRequest(dns_request));
    }

    let cache = get_response_from_cache(context, &dns_request, &addr).await;
    let cache_question = match cache {
        CacheReturn::Found(result) => return result,
        CacheReturn::NotFound(cache_question) => cache_question,
    };

    let remote = get_response_from_remote(context, &cache_question, &mut dns_request, &addr).await;
    if let Some(result) = remote {
        return result;
    }

    let fallback =
        get_response_from_cache_fallback(context, cache_question, &dns_request, addr).await;
    if let Some(result) = fallback {
        return result;
    }

    Err(DohError::CouldNotGetResponse(dns_request))
}