oqs-safe 0.5.0

Post-Quantum Cryptography (PQC) toolkit in Rust with ML-KEM, ML-DSA, hybrid cryptography (X25519 + ML-KEM), and secure session primitives.
Documentation
use crate::{
    hybrid::derive_hybrid_secret,
    kem::{Kem, KemAlgorithm, KemInstance, SecretKey},
    session::SecureSession,
    OqsError,
};

use rand_core::OsRng;
#[cfg(not(feature = "liboqs"))]
use sha2::{Digest, Sha256};
use x25519_dalek::{PublicKey as X25519PublicKey, StaticSecret};

/// Context label passed to `derive_hybrid_secret` by both `HybridClient::finish` and
/// `HybridServer::respond`; the two sides must use the same value to derive the same key.
const HANDSHAKE_CONTEXT: &[u8] = b"oqs-safe-v0.5.0-hybrid-handshake";

/// First handshake message, sent from the client to the server.
///
/// Carries only the client's freshly generated public values; no secret material.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ClientHello {
    /// Client X25519 public key; the server rejects anything that is not exactly 32 bytes.
    pub client_x25519_public: Vec<u8>,
    /// Client KEM public key for the configured KEM algorithm; must be non-empty.
    pub client_kem_public: Vec<u8>,
}

/// Second handshake message, sent from the server in reply to a [`ClientHello`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ServerHello {
    /// Server's per-handshake X25519 public key; the client rejects anything not 32 bytes.
    pub server_x25519_public: Vec<u8>,
    /// KEM ciphertext produced against the client's KEM public key; must be non-empty.
    pub kem_ciphertext: Vec<u8>,
}

/// Errors produced while running the hybrid handshake.
#[derive(Debug)]
pub enum HandshakeError {
    /// `HybridClient::finish` ran without pending state from `start_handshake`.
    MissingClientState,
    /// `HybridServer::session` was queried before a successful `respond`.
    MissingServerState,
    /// A peer message failed validation (e.g. wrong X25519 key length or empty KEM data).
    InvalidHandshakeState,
    /// Wraps an [`OqsError`] from the crypto layer (KEM key generation, encapsulation,
    /// or decapsulation).
    CryptoError(OqsError),
}

impl core::fmt::Display for HandshakeError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self {
            HandshakeError::MissingClientState => write!(f, "missing client handshake state"),
            HandshakeError::MissingServerState => write!(f, "missing server handshake state"),
            HandshakeError::InvalidHandshakeState => write!(f, "invalid handshake state"),
            HandshakeError::CryptoError(err) => write!(f, "cryptographic operation failed: {err}"),
        }
    }
}

impl std::error::Error for HandshakeError {}

impl From<OqsError> for HandshakeError {
    fn from(value: OqsError) -> Self {
        HandshakeError::CryptoError(value)
    }
}

/// Client side of the hybrid X25519 + KEM handshake.
///
/// Call `start_handshake` to produce a [`ClientHello`], then `finish` with the server's
/// [`ServerHello`] to obtain a [`SecureSession`].
pub struct HybridClient {
    // KEM for the post-quantum half. The algorithm is not sent on the wire, so the
    // server must be configured with the same one.
    kem: KemInstance,
    // Secrets from `start_handshake`; `None` before it runs and after `finish` takes it.
    state: Option<ClientHandshakeState>,
}

/// Secret material the client holds between `start_handshake` and `finish`.
struct ClientHandshakeState {
    // X25519 secret whose public half was sent in the `ClientHello`.
    x25519_secret: StaticSecret,
    // KEM secret key handed to `client_pqc_secret` when finishing the handshake.
    kem_secret: SecretKey,
}

impl HybridClient {
    /// Creates a client that uses the default KEM, ML-KEM-768.
    pub fn new() -> Self {
        Self::with_algorithm(KemAlgorithm::MlKem768)
    }

    /// Creates a client that uses `algorithm` for the post-quantum half of the handshake.
    ///
    /// The algorithm is not carried in [`ClientHello`], so the server must be configured
    /// with the same one.
    pub fn with_algorithm(algorithm: KemAlgorithm) -> Self {
        Self {
            kem: KemInstance::new(algorithm),
            state: None,
        }
    }

    /// Generates fresh X25519 and KEM key pairs and returns their public halves.
    ///
    /// The secret halves are kept until [`HybridClient::finish`] consumes them. Calling this
    /// again before `finish` replaces any previously pending state.
    ///
    /// # Errors
    ///
    /// Returns [`HandshakeError::CryptoError`] if KEM key generation fails.
    pub fn start_handshake(&mut self) -> Result<ClientHello, HandshakeError> {
        let client_x25519_secret = StaticSecret::random_from_rng(OsRng);
        let client_x25519_public = X25519PublicKey::from(&client_x25519_secret);

        let (client_kem_public, client_kem_secret) = self.kem.keypair()?;

        self.state = Some(ClientHandshakeState {
            x25519_secret: client_x25519_secret,
            kem_secret: client_kem_secret,
        });

        Ok(ClientHello {
            client_x25519_public: client_x25519_public.as_bytes().to_vec(),
            client_kem_public: client_kem_public.as_bytes().to_vec(),
        })
    }

    /// Completes the handshake with the server's reply and returns the shared session.
    ///
    /// The pending client state is taken before any validation. Even a failed call
    /// therefore requires a new [`HybridClient::start_handshake`], and secrets are never
    /// reused across attempts.
    ///
    /// # Errors
    ///
    /// - [`HandshakeError::MissingClientState`] if there is no pending `start_handshake`.
    /// - [`HandshakeError::InvalidHandshakeState`] if any of these hold:
    ///   - the server X25519 key is not 32 bytes;
    ///   - the KEM ciphertext is empty;
    ///   - the X25519 exchange is non-contributory, meaning the server sent a low-order
    ///     point that forces an all-zero shared secret.
    /// - [`HandshakeError::CryptoError`] if recovering the KEM shared secret fails.
    pub fn finish(&mut self, server_hello: ServerHello) -> Result<SecureSession, HandshakeError> {
        let state = self
            .state
            .take()
            .ok_or(HandshakeError::MissingClientState)?;

        if server_hello.kem_ciphertext.is_empty() {
            return Err(HandshakeError::InvalidHandshakeState);
        }

        // `try_into` into a fixed-size array enforces the exact 32-byte X25519 length.
        let server_public_bytes: [u8; 32] = server_hello
            .server_x25519_public
            .as_slice()
            .try_into()
            .map_err(|_| HandshakeError::InvalidHandshakeState)?;

        let server_x25519_public = X25519PublicKey::from(server_public_bytes);
        let classical_secret = state.x25519_secret.diffie_hellman(&server_x25519_public);

        // Reject low-order peer points (RFC 7748 §6.1). They yield an all-zero DH output,
        // which would let an active attacker remove the classical contribution from the
        // hybrid secret.
        if !classical_secret.was_contributory() {
            return Err(HandshakeError::InvalidHandshakeState);
        }

        let pqc_secret = client_pqc_secret(
            self.kem.algorithm(),
            &server_hello.kem_ciphertext,
            &state.kem_secret,
        )?;

        let hybrid_secret = derive_hybrid_secret(
            pqc_secret.as_slice(),
            classical_secret.as_bytes(),
            HANDSHAKE_CONTEXT,
        );

        Ok(SecureSession::new(hybrid_secret.as_bytes().to_vec()))
    }
}

impl Default for HybridClient {
    /// Equivalent to [`HybridClient::new`].
    fn default() -> Self {
        Self::new()
    }
}

/// Server side of the hybrid X25519 + KEM handshake.
///
/// `respond` answers a [`ClientHello`] and stores the resulting [`SecureSession`], which
/// can then be borrowed through `session`.
pub struct HybridServer {
    // KEM for the post-quantum half; must be configured like the client's, since the
    // algorithm is not sent on the wire.
    kem: KemInstance,
    // Session from the most recent successful `respond`; `None` until then.
    session: Option<SecureSession>,
}

impl HybridServer {
    /// Creates a server that uses the default KEM, ML-KEM-768.
    pub fn new() -> Self {
        Self::with_algorithm(KemAlgorithm::MlKem768)
    }

    /// Creates a server that uses `algorithm` for the post-quantum half of the handshake.
    ///
    /// The algorithm is not carried in [`ClientHello`], so clients must be configured with
    /// the same one.
    pub fn with_algorithm(algorithm: KemAlgorithm) -> Self {
        Self {
            kem: KemInstance::new(algorithm),
            session: None,
        }
    }

    /// Processes a [`ClientHello`], derives the hybrid session, and returns the reply.
    ///
    /// On success the derived session is stored, replacing any previous one, and is
    /// available via [`HybridServer::session`]. On error the stored session is unchanged.
    ///
    /// # Errors
    ///
    /// - [`HandshakeError::InvalidHandshakeState`] if any of these hold:
    ///   - the client X25519 key is not 32 bytes;
    ///   - the client KEM public key is empty;
    ///   - the X25519 exchange is non-contributory, meaning the client sent a low-order
    ///     point that forces an all-zero shared secret.
    /// - [`HandshakeError::CryptoError`] if KEM encapsulation fails.
    pub fn respond(&mut self, client_hello: ClientHello) -> Result<ServerHello, HandshakeError> {
        // NOTE(review): only non-emptiness is checked here; whether the KEM public key
        // length is validated against the algorithm depends on `server_pqc_secret`.
        if client_hello.client_kem_public.is_empty() {
            return Err(HandshakeError::InvalidHandshakeState);
        }

        // `try_into` into a fixed-size array enforces the exact 32-byte X25519 length.
        let client_public_bytes: [u8; 32] = client_hello
            .client_x25519_public
            .as_slice()
            .try_into()
            .map_err(|_| HandshakeError::InvalidHandshakeState)?;

        let client_x25519_public = X25519PublicKey::from(client_public_bytes);

        // A fresh X25519 key is generated for every `respond` call.
        let server_x25519_secret = StaticSecret::random_from_rng(OsRng);
        let server_x25519_public = X25519PublicKey::from(&server_x25519_secret);

        let classical_secret = server_x25519_secret.diffie_hellman(&client_x25519_public);

        // Reject low-order peer points (RFC 7748 §6.1). They yield an all-zero DH output,
        // which would let an active attacker remove the classical contribution from the
        // hybrid secret.
        if !classical_secret.was_contributory() {
            return Err(HandshakeError::InvalidHandshakeState);
        }

        let (kem_ciphertext, pqc_secret) =
            server_pqc_secret(self.kem.algorithm(), &client_hello.client_kem_public)?;

        let hybrid_secret = derive_hybrid_secret(
            pqc_secret.as_slice(),
            classical_secret.as_bytes(),
            HANDSHAKE_CONTEXT,
        );

        self.session = Some(SecureSession::new(hybrid_secret.as_bytes().to_vec()));

        Ok(ServerHello {
            server_x25519_public: server_x25519_public.as_bytes().to_vec(),
            kem_ciphertext,
        })
    }

    /// Borrows the session established by the most recent successful `respond`.
    ///
    /// # Errors
    ///
    /// Returns [`HandshakeError::MissingServerState`] if no handshake has completed yet.
    pub fn session(&self) -> Result<&SecureSession, HandshakeError> {
        self.session
            .as_ref()
            .ok_or(HandshakeError::MissingServerState)
    }
}

impl Default for HybridServer {
    /// Equivalent to [`HybridServer::new`].
    fn default() -> Self {
        Self::new()
    }
}

/// KEM encapsulation against the client's public key (built with the `liboqs` feature).
///
/// Returns `(ciphertext, shared_secret)` as raw bytes; any KEM error is converted into
/// [`HandshakeError`].
#[cfg(feature = "liboqs")]
fn server_pqc_secret(
    algorithm: KemAlgorithm,
    client_kem_public: &[u8],
) -> Result<(Vec<u8>, Vec<u8>), HandshakeError> {
    use crate::kem::PublicKey;

    let kem = KemInstance::new(algorithm);
    let peer_public = PublicKey::new(algorithm, client_kem_public.to_vec());
    let encapsulated = kem.encapsulate(&peer_public);

    encapsulated
        .map(|(ct, ss)| (ct.as_bytes().to_vec(), ss.as_bytes().to_vec()))
        .map_err(HandshakeError::from)
}

/// KEM decapsulation of the server's ciphertext with the client's secret key
/// (built with the `liboqs` feature).
///
/// Returns the shared secret as raw bytes; any KEM error is converted into
/// [`HandshakeError`].
#[cfg(feature = "liboqs")]
fn client_pqc_secret(
    algorithm: KemAlgorithm,
    kem_ciphertext: &[u8],
    kem_secret: &SecretKey,
) -> Result<Vec<u8>, HandshakeError> {
    use crate::kem::Ciphertext;

    let kem = KemInstance::new(algorithm);
    let peer_ciphertext = Ciphertext::new(algorithm, kem_ciphertext.to_vec());
    let decapsulated = kem.decapsulate(&peer_ciphertext, kem_secret);

    decapsulated
        .map(|secret| secret.as_bytes().to_vec())
        .map_err(HandshakeError::from)
}

/// Mock KEM encapsulation used when the `liboqs` feature is disabled. It never fails.
///
/// WARNING: this provides no confidentiality. The "ciphertext" is derived only from the
/// client's public key, and the "shared secret" only from that ciphertext. Anyone who
/// observes the `ServerHello` can therefore recompute the post-quantum half of the hybrid
/// secret. Intended for testing only.
#[cfg(not(feature = "liboqs"))]
fn server_pqc_secret(
    algorithm: KemAlgorithm,
    client_kem_public: &[u8],
) -> Result<(Vec<u8>, Vec<u8>), HandshakeError> {
    let ct = mock_ciphertext(algorithm, client_kem_public);
    let secret = mock_shared_secret(algorithm, &ct);
    Ok((ct, secret))
}

/// Mock counterpart of KEM decapsulation used when the `liboqs` feature is disabled.
///
/// The secret key is deliberately ignored: the mock secret depends only on `algorithm`
/// and the (public) ciphertext, so it matches what the mock `server_pqc_secret` derived.
/// Never fails.
#[cfg(not(feature = "liboqs"))]
fn client_pqc_secret(
    algorithm: KemAlgorithm,
    kem_ciphertext: &[u8],
    _kem_secret: &SecretKey,
) -> Result<Vec<u8>, HandshakeError> {
    let secret = mock_shared_secret(algorithm, kem_ciphertext);
    Ok(secret)
}

/// Builds a deterministic mock ciphertext of `algorithm.ciphertext_len()` bytes.
///
/// The output is a SHA-256 counter-mode stream. Chunk `i` is
/// `SHA-256("oqs-safe-v0.5.0-mock-ciphertext" || client_kem_public || i as u64 LE)`, and
/// the final chunk is truncated to fit.
#[cfg(not(feature = "liboqs"))]
fn mock_ciphertext(algorithm: KemAlgorithm, client_kem_public: &[u8]) -> Vec<u8> {
    // SHA-256 digest length in bytes.
    const DIGEST_LEN: usize = 32;

    let mut out = vec![0u8; algorithm.ciphertext_len()];

    for (chunk, counter) in out.chunks_mut(DIGEST_LEN).zip(0u64..) {
        let mut hasher = Sha256::new();
        hasher.update(b"oqs-safe-v0.5.0-mock-ciphertext");
        hasher.update(client_kem_public);
        hasher.update(counter.to_le_bytes());

        let digest = hasher.finalize();
        let width = chunk.len();
        chunk.copy_from_slice(&digest[..width]);
    }

    out
}

/// Derives the 32-byte mock shared secret as
/// `SHA-256("oqs-safe-v0.5.0-mock-pqc-secret" || Debug(algorithm) || kem_ciphertext)`.
///
/// Every input is public, so the result is not secret; it exists only for mock builds.
#[cfg(not(feature = "liboqs"))]
fn mock_shared_secret(algorithm: KemAlgorithm, kem_ciphertext: &[u8]) -> Vec<u8> {
    let algorithm_label = format!("{algorithm:?}");
    let parts: [&[u8]; 3] = [
        b"oqs-safe-v0.5.0-mock-pqc-secret",
        algorithm_label.as_bytes(),
        kem_ciphertext,
    ];

    let mut hasher = Sha256::new();
    for part in parts {
        hasher.update(part);
    }

    hasher.finalize().to_vec()
}