// clavis 0.1.1-rc7
//
// High-performance async network encryption library featuring XChaCha20-Poly1305, a type-safe
// protocol DSL, and zero-copy packet streaming for secure communication channels.
//
// Documentation
use chacha20poly1305::{
    aead::{Aead, AeadCore, KeyInit, OsRng},
    XChaCha20Poly1305, XNonce,
};
use hkdf::Hkdf;
use hmac::{Hmac, Mac};
use rand::RngCore;
use sha2::{Digest, Sha256};
use subtle::ConstantTimeEq;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use x25519_dalek::{EphemeralSecret, PublicKey};

use crate::{
    error::{ClavisError, ClavisResult, CryptoError, CryptoOperation, MessageError, StreamError},
    stream::EncryptedStreamOptions,
};

/// HMAC-SHA256; keyed with the pre-shared key (when one is configured) to authenticate the
/// handshake transcript.
type HmacSha256 = Hmac<Sha256>;

struct CryptoContext {
    cipher: XChaCha20Poly1305,
    buffer: Vec<u8>,
    max_message_size: usize,
}

impl CryptoContext {
    fn new(key: &[u8], max_message_size: usize) -> ClavisResult<Self> {
        Ok(Self {
            cipher: XChaCha20Poly1305::new_from_slice(key).map_err(|e| {
                ClavisError::Crypto(CryptoError::InvalidKeyMaterial(format!(
                    "Invalid key material: {}",
                    e
                )))
            })?,
            buffer: Vec::with_capacity(4096),
            max_message_size,
        })
    }

    #[inline]
    fn validate_message_size(&self, size: usize) -> ClavisResult<()> {
        if size > self.max_message_size {
            Err(ClavisError::Message(MessageError::MessageTooLarge {
                size,
                max_size: self.max_message_size,
            }))
        } else {
            Ok(())
        }
    }

    async fn read_message<R: AsyncRead + Unpin>(
        &mut self,
        stream: &mut R,
    ) -> ClavisResult<Vec<u8>> {
        let length = stream.read_u32_le().await.map_err(|e| match e.kind() {
            std::io::ErrorKind::UnexpectedEof => ClavisError::Stream(StreamError::UnexpectedClose),
            _ => ClavisError::Stream(StreamError::Io(e)),
        })? as usize;

        self.validate_message_size(length)?;

        let mut nonce = [0u8; 24];
        stream
            .read_exact(&mut nonce)
            .await
            .map_err(|e| match e.kind() {
                std::io::ErrorKind::UnexpectedEof => {
                    ClavisError::Stream(StreamError::UnexpectedClose)
                }
                _ => ClavisError::Stream(StreamError::Io(e)),
            })?;

        let mut ciphertext = vec![0u8; length];
        stream
            .read_exact(&mut ciphertext)
            .await
            .map_err(|e| match e.kind() {
                std::io::ErrorKind::UnexpectedEof => {
                    ClavisError::Stream(StreamError::UnexpectedClose)
                }
                _ => ClavisError::Stream(StreamError::Io(e)),
            })?;

        let plaintext = self
            .cipher
            .decrypt(XNonce::from_slice(&nonce), ciphertext.as_ref())
            .map_err(|e| ClavisError::crypto_failure(CryptoOperation::Decryption, e.to_string()))?;

        Ok(plaintext)
    }

    async fn write_message<W: AsyncWrite + Unpin>(
        &mut self,
        stream: &mut W,
        message: &[u8],
    ) -> ClavisResult<()> {
        self.validate_message_size(message.len())?;

        let nonce = XChaCha20Poly1305::generate_nonce(&mut OsRng);

        let ciphertext = self
            .cipher
            .encrypt(&nonce, message)
            .map_err(|e| ClavisError::crypto_failure(CryptoOperation::Encryption, e.to_string()))?;

        self.buffer.clear();
        self.buffer
            .extend_from_slice(&(ciphertext.len() as u32).to_le_bytes());
        self.buffer.extend_from_slice(&nonce);
        self.buffer.extend_from_slice(&ciphertext);

        stream
            .write_all(&self.buffer)
            .await
            .map_err(|e| ClavisError::Stream(StreamError::Io(e)))?;
        stream
            .flush()
            .await
            .map_err(|e| ClavisError::Stream(StreamError::Io(e)))?;

        Ok(())
    }
}

/// Running record of the key-exchange messages.
///
/// Every appended byte feeds a SHA-256 transcript hash and, when a pre-shared key was
/// supplied, an HMAC-SHA256 keyed with that key.
struct HandshakeContext {
    transcript: Sha256,
    mac: Option<HmacSha256>,
}

impl HandshakeContext {
    /// Starts an empty transcript. If `psk` is present, it keys the HMAC.
    ///
    /// # Errors
    /// An `Authentication` crypto failure if the HMAC cannot be keyed.
    fn new(psk: Option<&[u8]>) -> ClavisResult<Self> {
        let mac = match psk {
            None => None,
            Some(key) => {
                let keyed = <HmacSha256 as KeyInit>::new_from_slice(key).map_err(|err| {
                    ClavisError::crypto_failure(
                        CryptoOperation::Authentication,
                        format!("Failed to create HMAC: {}", err),
                    )
                })?;
                Some(keyed)
            }
        };

        let transcript = Sha256::new();
        Ok(Self { transcript, mac })
    }

    /// Feeds `data` into the transcript hash and, if configured, the PSK HMAC.
    fn append(&mut self, data: &[u8]) {
        self.transcript.update(data);
        if let Some(keyed) = self.mac.as_mut() {
            keyed.update(data);
        }
    }

    /// Consumes the context and returns `(transcript_hash, psk_mac)`.
    fn finalize(self) -> ([u8; 32], Option<[u8; 32]>) {
        let Self { transcript, mac } = self;
        let transcript_hash: [u8; 32] = transcript.finalize().into();
        let psk_mac = mac.map(|keyed| -> [u8; 32] { keyed.finalize().into_bytes().into() });
        (transcript_hash, psk_mac)
    }
}

/// Receiving half of an established encrypted channel.
pub struct CryptoReader {
    context: CryptoContext,
}

impl CryptoReader {
    /// Reads a single frame from `stream` and returns its decrypted payload.
    ///
    /// # Errors
    /// Returns whatever the underlying frame reader reports, unchanged.
    pub async fn read<R>(&mut self, stream: &mut R) -> ClavisResult<Vec<u8>>
    where
        R: AsyncRead + Unpin + Send,
    {
        let Self { context } = self;
        context.read_message(stream).await
    }
}

/// Sending half of an established encrypted channel.
pub struct CryptoWriter {
    context: CryptoContext,
}

impl CryptoWriter {
    /// Encrypts `message` and writes it to `stream` as a single frame.
    ///
    /// # Errors
    /// Returns whatever the underlying frame writer reports, unchanged.
    pub async fn write<W>(&mut self, stream: &mut W, message: &[u8]) -> ClavisResult<()>
    where
        W: AsyncWrite + Unpin + Send,
    {
        let Self { context } = self;
        context.write_message(stream, message).await
    }
}

pub struct CryptoCore {
    pub(crate) reader: CryptoReader,
    pub(crate) writer: CryptoWriter,
}

impl CryptoCore {
    pub async fn establish<S: AsyncRead + AsyncWrite + Unpin + Send>(
        stream: &mut S,
        options: EncryptedStreamOptions,
    ) -> ClavisResult<Self> {
        let is_initiator = Self::determine_role(stream).await?;
        Self::validate_psk(&options)?;

        let (shared_secret, transcript_hash, mac) =
            Self::handshake(stream, options.psk.as_deref(), is_initiator).await?;
        Self::verify_mac(stream, mac, is_initiator).await?;
        let (enc_key, dec_key) = Self::derive_keys(&shared_secret, &transcript_hash, is_initiator)?;

        Ok(Self {
            reader: CryptoReader {
                context: CryptoContext::new(&dec_key, options.max_packet_size)?,
            },
            writer: CryptoWriter {
                context: CryptoContext::new(&enc_key, options.max_packet_size)?,
            },
        })
    }

    async fn determine_role<S: AsyncRead + AsyncWrite + Unpin>(
        stream: &mut S,
    ) -> ClavisResult<bool> {
        let mut local_nonce = [0u8; 32];
        OsRng.fill_bytes(&mut local_nonce);

        stream
            .write_all(&local_nonce)
            .await
            .map_err(|e| ClavisError::Stream(StreamError::Io(e)))?;
        stream
            .flush()
            .await
            .map_err(|e| ClavisError::Stream(StreamError::Io(e)))?;

        let mut peer_nonce = [0u8; 32];
        stream
            .read_exact(&mut peer_nonce)
            .await
            .map_err(|e| ClavisError::Stream(StreamError::Io(e)))?;

        Ok(local_nonce > peer_nonce)
    }

    fn validate_psk(options: &EncryptedStreamOptions) -> ClavisResult<()> {
        if let Some(psk) = &options.psk {
            if psk.len() < 16 {
                return Err(ClavisError::Crypto(CryptoError::InvalidKeyMaterial(
                    "Pre-shared key must be at least 16 bytes".into(),
                )));
            }
        }
        Ok(())
    }

    async fn verify_mac<S: AsyncRead + AsyncWrite + Unpin>(
        stream: &mut S,
        mac: Option<[u8; 32]>,
        is_initiator: bool,
    ) -> ClavisResult<()> {
        if let Some(local_mac) = mac {
            let mut peer_mac = [0u8; 32];

            if is_initiator {
                stream
                    .write_all(&local_mac)
                    .await
                    .map_err(|e| ClavisError::Stream(StreamError::Io(e)))?;
                stream
                    .flush()
                    .await
                    .map_err(|e| ClavisError::Stream(StreamError::Io(e)))?;
                stream
                    .read_exact(&mut peer_mac)
                    .await
                    .map_err(|e| ClavisError::Stream(StreamError::Io(e)))?;
            } else {
                stream
                    .read_exact(&mut peer_mac)
                    .await
                    .map_err(|e| ClavisError::Stream(StreamError::Io(e)))?;
                stream
                    .write_all(&local_mac)
                    .await
                    .map_err(|e| ClavisError::Stream(StreamError::Io(e)))?;
                stream
                    .flush()
                    .await
                    .map_err(|e| ClavisError::Stream(StreamError::Io(e)))?;
            }

            if local_mac.ct_eq(&peer_mac).unwrap_u8() == 0 {
                return Err(ClavisError::Crypto(CryptoError::AuthenticationFailure(
                    "MAC verification failed".into(),
                )));
            }
        }
        Ok(())
    }

    async fn handshake<S: AsyncRead + AsyncWrite + Unpin + Send>(
        stream: &mut S,
        psk: Option<&[u8]>,
        is_initiator: bool,
    ) -> ClavisResult<([u8; 32], [u8; 32], Option<[u8; 32]>)> {
        let mut context = HandshakeContext::new(psk)?;
        let (secret, peer_key) = Self::exchange_keys(stream, &mut context, is_initiator).await?;

        let shared_secret = secret.diffie_hellman(&peer_key);
        let shared_secret_bytes = shared_secret.as_bytes();
        let (transcript_hash, mac) = context.finalize();

        Ok((*shared_secret_bytes, transcript_hash, mac))
    }

    async fn exchange_keys<S: AsyncRead + AsyncWrite + Unpin + Send>(
        stream: &mut S,
        context: &mut HandshakeContext,
        is_initiator: bool,
    ) -> ClavisResult<(EphemeralSecret, PublicKey)> {
        let secret = EphemeralSecret::random_from_rng(OsRng);
        let public = PublicKey::from(&secret);
        let mut peer_bytes = [0u8; 32];

        if is_initiator {
            stream
                .write_all(public.as_bytes())
                .await
                .map_err(|e| ClavisError::Stream(StreamError::Io(e)))?;
            stream
                .flush()
                .await
                .map_err(|e| ClavisError::Stream(StreamError::Io(e)))?;
            stream
                .read_exact(&mut peer_bytes)
                .await
                .map_err(|e| ClavisError::Stream(StreamError::Io(e)))?;
        } else {
            stream
                .read_exact(&mut peer_bytes)
                .await
                .map_err(|e| ClavisError::Stream(StreamError::Io(e)))?;
            stream
                .write_all(public.as_bytes())
                .await
                .map_err(|e| ClavisError::Stream(StreamError::Io(e)))?;
            stream
                .flush()
                .await
                .map_err(|e| ClavisError::Stream(StreamError::Io(e)))?;
        }

        if is_initiator {
            context.append(public.as_bytes());
            context.append(&peer_bytes);
        } else {
            context.append(&peer_bytes);
            context.append(public.as_bytes());
        }

        let peer_key = PublicKey::from(peer_bytes);
        Ok((secret, peer_key))
    }

    fn derive_keys(
        shared_secret: &[u8; 32],
        transcript_hash: &[u8; 32],
        is_initiator: bool,
    ) -> ClavisResult<([u8; 32], [u8; 32])> {
        let hkdf = Hkdf::<Sha256>::new(Some(transcript_hash), shared_secret);
        let mut enc_key = [0u8; 32];
        let mut dec_key = [0u8; 32];

        if is_initiator {
            hkdf.expand(b"enc", &mut enc_key).map_err(|_| {
                ClavisError::Crypto(CryptoError::KeyDerivationFailure(
                    "Failed to derive encryption key".into(),
                ))
            })?;
            hkdf.expand(b"dec", &mut dec_key).map_err(|_| {
                ClavisError::Crypto(CryptoError::KeyDerivationFailure(
                    "Failed to derive decryption key".into(),
                ))
            })?;
        } else {
            hkdf.expand(b"dec", &mut enc_key).map_err(|_| {
                ClavisError::Crypto(CryptoError::KeyDerivationFailure(
                    "Failed to derive encryption key".into(),
                ))
            })?;
            hkdf.expand(b"enc", &mut dec_key).map_err(|_| {
                ClavisError::Crypto(CryptoError::KeyDerivationFailure(
                    "Failed to derive decryption key".into(),
                ))
            })?;
        }

        Ok((enc_key, dec_key))
    }
}