// libdoh 0.9.7
//
// DoH and Oblivious DoH library for the rust-doh app.
use std::fmt;
use std::sync::Arc;
use std::time::Duration;

use arc_swap::ArcSwap;
use odoh_rs::{
    Deserialize, ObliviousDoHConfig, ObliviousDoHConfigs, ObliviousDoHKeyPair, ObliviousDoHMessage,
    ObliviousDoHMessagePlaintext, OdohSecret, ResponseNonce, Serialize,
};
use rand::Rng;
use tokio::runtime;

use crate::constants::ODOH_KEY_ROTATION_SECS;
use crate::errors::DoHError;

/// An ODoH key pair bundled with the wire encoding of the `ObliviousDoHConfigs`
/// that advertises its public half.
#[derive(Clone)]
pub struct ODoHPublicKey {
    // Private/public key pair used to decrypt incoming ODoH queries.
    key_pair: ObliviousDoHKeyPair,
    // Pre-serialized single-entry `ObliviousDoHConfigs` for the public key.
    serialized_configs: Vec<u8>,
}

// Prints only the type name; neither the key pair nor the serialized configs
// are emitted (presumably to keep key material out of logs).
impl fmt::Debug for ODoHPublicKey {
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        formatter.write_str("ODoHPublicKey")
    }
}

#[derive(Clone, Debug)]
pub struct ODoHQueryContext {
    query: ObliviousDoHMessagePlaintext,
    server_secret: OdohSecret,
}

impl ODoHPublicKey {
    /// Generates a fresh key pair from the thread-local RNG and pre-serializes
    /// a single-entry `ObliviousDoHConfigs` advertising its public key.
    ///
    /// # Errors
    ///
    /// Returns `DoHError::ODoHConfigError` if the configs cannot be serialized.
    pub fn new() -> Result<ODoHPublicKey, DoHError> {
        let mut rng = rand::thread_rng();
        let key_pair = ObliviousDoHKeyPair::new(&mut rng);

        let advertised = ObliviousDoHConfigs::from(vec![ObliviousDoHConfig::from(
            key_pair.public().clone(),
        )]);
        let mut serialized_configs = Vec::new();
        advertised
            .serialize(&mut serialized_configs)
            .map_err(|e| DoHError::ODoHConfigError(e.into()))?;

        Ok(ODoHPublicKey {
            key_pair,
            serialized_configs,
        })
    }

    /// Consumes the key and returns the serialized `ObliviousDoHConfigs` bytes.
    pub fn into_config(self) -> Vec<u8> {
        let ODoHPublicKey {
            serialized_configs, ..
        } = self;
        serialized_configs
    }

    /// Parses and decrypts an ODoH query addressed to this key.
    ///
    /// Returns the plaintext DNS message bytes and the context required to
    /// encrypt the matching response.
    ///
    /// # Errors
    ///
    /// - `DoHError::InvalidData` if the message cannot be parsed, this key's
    ///   identifier cannot be computed, or decryption fails.
    /// - `DoHError::StaleKey` if the message's key id does not match this key.
    pub fn decrypt_query(
        self,
        encrypted_query: Vec<u8>,
    ) -> Result<(Vec<u8>, ODoHQueryContext), DoHError> {
        let mut wire = bytes::Bytes::from(encrypted_query);
        let odoh_query =
            ObliviousDoHMessage::deserialize(&mut wire).map_err(|_| DoHError::InvalidData)?;

        let our_key_id = self
            .key_pair
            .public()
            .identifier()
            .map_err(|_| DoHError::InvalidData)?;
        // A mismatching id means the client encrypted to a key we no longer serve.
        if our_key_id != odoh_query.key_id() {
            return Err(DoHError::StaleKey);
        }

        let (query, server_secret) = odoh_rs::decrypt_query(&odoh_query, &self.key_pair)
            .map_err(|_| DoHError::InvalidData)?;
        let dns_message = query.clone().into_msg().to_vec();

        Ok((
            dns_message,
            ODoHQueryContext {
                query,
                server_secret,
            },
        ))
    }
}

impl ODoHQueryContext {
    pub fn encrypt_response(self, response_body: Vec<u8>) -> Result<Vec<u8>, DoHError> {
        let response_nonce = rand::thread_rng().gen::<ResponseNonce>();
        let response_body_ = ObliviousDoHMessagePlaintext::new(response_body, 0);
        let encrypted_response = odoh_rs::encrypt_response(
            &self.query,
            &response_body_,
            self.server_secret,
            response_nonce,
        )
        .map_err(|_| DoHError::InvalidData)?;
        let mut encrypted_response_bytes = Vec::new();
        encrypted_response
            .serialize(&mut encrypted_response_bytes)
            .map_err(|_| DoHError::InvalidData)?;
        Ok(encrypted_response_bytes)
    }
}

/// Shared handle to the ODoH key pair currently being served.
///
/// Clones share one underlying key slot, which the background task spawned by
/// `ODoHRotator::new` replaces periodically.
#[derive(Clone, Debug)]
pub struct ODoHRotator {
    // Atomically swappable slot; readers take `Arc` snapshots of the active key.
    key: Arc<ArcSwap<ODoHPublicKey>>,
}

impl ODoHRotator {
    pub fn new(runtime_handle: runtime::Handle) -> Result<ODoHRotator, DoHError> {
        let public_key = match ODoHPublicKey::new() {
            Ok(key) => Arc::new(ArcSwap::from_pointee(key)),
            Err(e) => panic!("ODoH key rotation error: {}", e),
        };

        let current_key = Arc::clone(&public_key);

        runtime_handle.spawn(async move {
            loop {
                tokio::time::sleep(Duration::from_secs(ODOH_KEY_ROTATION_SECS.into())).await;
                match ODoHPublicKey::new() {
                    Ok(key) => {
                        current_key.store(Arc::new(key));
                    }
                    Err(e) => eprintln!("ODoH key rotation error: {e}"),
                };
            }
        });

        Ok(ODoHRotator {
            key: Arc::clone(&public_key),
        })
    }

    pub fn current_public_key(&self) -> Arc<ODoHPublicKey> {
        let key = Arc::clone(&self.key);
        Arc::clone(&key.load())
    }
}