oqs-safe 0.5.0

Post-Quantum Cryptography (PQC) toolkit in Rust with ML-KEM, ML-DSA, hybrid cryptography (X25519 + ML-KEM), and secure session primitives.
Documentation
// Minimal FFI surface for libOQS when `--features liboqs` is enabled.
// Generic algorithm support with ML-* names and legacy fallbacks.

use crate::kem::KemAlgorithm;
use crate::sig::SigAlgorithm;
use crate::OqsError;
use core::ffi::{c_char, c_int, c_uint};
use std::ffi::CString;

// Raw declarations of the liboqs C symbols this module calls. The `#[link]`
// attribute asks the linker for a library named `oqs`.
//
// Every function returning `c_int` is used by the wrappers in this file with
// the convention "0 means success, anything else is a failure".
#[link(name = "oqs")]
extern "C" {
    // Looks up a KEM by its NUL-terminated name. `kem_new` in this file treats
    // a null return as "this name is not available" and tries the next name.
    fn OQS_KEM_new(method_name: *const c_char) -> *mut OQS_KEM;
    // Releases a handle obtained from `OQS_KEM_new`; the wrappers call it
    // exactly once per successful `OQS_KEM_new`.
    fn OQS_KEM_free(kem: *mut OQS_KEM);
    // Writes a key pair into caller-provided buffers. The wrappers size them
    // from `length_public_key` / `length_secret_key` of the handle.
    fn OQS_KEM_keypair(kem: *const OQS_KEM, pub_key: *mut u8, sec_key: *mut u8) -> c_int;
    // Writes a ciphertext and shared secret for `pub_key` into caller buffers
    // (sized from `length_ciphertext` / `length_shared_secret` by the wrapper).
    fn OQS_KEM_encaps(kem: *const OQS_KEM, ct: *mut u8, ss: *mut u8, pub_key: *const u8) -> c_int;
    // Writes the shared secret recovered from `ct` with `sec_key` into `ss`.
    fn OQS_KEM_decaps(kem: *const OQS_KEM, ss: *mut u8, ct: *const u8, sec_key: *const u8)
        -> c_int;

    // Signature counterpart of `OQS_KEM_new`; null is treated as "not
    // available" by `sig_new`.
    fn OQS_SIG_new(method_name: *const c_char) -> *mut OQS_SIG;
    // Releases a handle obtained from `OQS_SIG_new`.
    fn OQS_SIG_free(sig: *mut OQS_SIG);
    // Writes a signing key pair into caller-provided buffers.
    fn OQS_SIG_keypair(sig: *const OQS_SIG, pub_key: *mut u8, sec_key: *mut u8) -> c_int;
    // Signs `msg` with `sec_key`. `sig_out` is a caller buffer (the wrapper
    // allocates `length_signature` bytes); `sig_len` is an out-parameter that
    // the wrapper uses afterwards to truncate the buffer.
    fn OQS_SIG_sign(
        sig: *const OQS_SIG,
        sig_out: *mut u8,
        sig_len: *mut usize,
        msg: *const u8,
        msg_len: usize,
        sec_key: *const u8,
    ) -> c_int;
    // Checks `sig_in` over `msg` against `pub_key`. The wrapper maps any
    // non-zero return to `OqsError::VerifyFail`.
    fn OQS_SIG_verify(
        sig: *const OQS_SIG,
        msg: *const u8,
        msg_len: usize,
        sig_in: *const u8,
        sig_len: usize,
        pub_key: *const u8,
    ) -> c_int;
}

/// Rust view of the leading members of liboqs' `OQS_KEM` descriptor.
///
/// Only the `length_*` fields are read by this module, so what matters is that
/// their offsets match the C definition exactly.
#[repr(C)]
struct OQS_KEM {
    method_name: *const c_char,
    alg_version: *const c_char,
    // liboqs declares these two members as `uint8_t` and `bool`, i.e. one byte
    // each. Declaring them as `c_uint` only happened to work on 64-bit targets
    // (both layouts pad up to offset 24); on 32-bit targets it pushed every
    // `length_*` field one word too far, so the wrappers would read the wrong
    // buffer sizes. `u8` is used for the flag instead of `bool` so that no
    // byte value written by C can be an invalid Rust value.
    claimed_nist_level: u8,
    ind_cca: u8,

    length_public_key: usize,
    length_secret_key: usize,
    length_ciphertext: usize,
    length_shared_secret: usize,

    // Opaque trailing slots; never read or called from Rust.
    // NOTE(review): newer liboqs releases may place additional members before
    // the function pointers -- confirm the layout before using these.
    keypair_fn: *const core::ffi::c_void,
    encaps_fn: *const core::ffi::c_void,
    decaps_fn: *const core::ffi::c_void,
}

/// Rust view of the leading members of liboqs' `OQS_SIG` descriptor.
///
/// Only the `length_*` fields are read by this module, so what matters is that
/// their offsets match the C definition exactly.
#[repr(C)]
struct OQS_SIG {
    method_name: *const c_char,
    alg_version: *const c_char,
    // liboqs declares these two members as `uint8_t` and `bool`, i.e. one byte
    // each. Declaring them as `c_uint` only happened to work on 64-bit targets
    // (both layouts pad up to offset 24); on 32-bit targets it pushed every
    // `length_*` field one word too far. `u8` is used for the flag instead of
    // `bool` so that no byte value written by C can be an invalid Rust value.
    //
    // NOTE(review): newer liboqs releases appear to add further one-byte flags
    // after `euf_cma`; they would fall into the alignment padding before
    // `length_public_key`, so the offsets below should be unaffected -- confirm
    // against the liboqs version being linked.
    claimed_nist_level: u8,
    euf_cma: u8,

    length_public_key: usize,
    length_secret_key: usize,
    length_signature: usize,

    // Opaque trailing slots; never read or called from Rust.
    keypair_fn: *const core::ffi::c_void,
    sign_fn: *const core::ffi::c_void,
    verify_fn: *const core::ffi::c_void,
}

unsafe fn kem_new(alg: KemAlgorithm) -> *mut OQS_KEM {
    let (nist_name, legacy_name) = alg.names();

    for name in [nist_name, legacy_name] {
        let cname = CString::new(name).expect("CString::new failed");

        let ptr = unsafe { OQS_KEM_new(cname.as_ptr()) };

        if !ptr.is_null() {
            return ptr;
        }
    }

    core::ptr::null_mut()
}

unsafe fn sig_new(alg: SigAlgorithm) -> *mut OQS_SIG {
    for name in alg.names() {
        let cname = CString::new(*name).expect("CString::new failed");

        let ptr = unsafe { OQS_SIG_new(cname.as_ptr()) };

        if !ptr.is_null() {
            return ptr;
        }
    }

    core::ptr::null_mut()
}

/// Generates a KEM key pair for `alg`, returned as `(public_key, secret_key)`.
///
/// # Errors
///
/// * `OqsError::Internal("kem new")` if no liboqs handle could be opened.
/// * `OqsError::Internal("kem keypair")` if liboqs reports a non-zero status.
pub fn kem_keypair(alg: KemAlgorithm) -> Result<(Vec<u8>, Vec<u8>), OqsError> {
    // SAFETY: the handle is null-checked below and freed exactly once.
    let handle = unsafe { kem_new(alg) };
    if handle.is_null() {
        return Err(OqsError::Internal("kem new"));
    }

    // SAFETY: `handle` is non-null and has not been freed yet.
    let (pk_len, sk_len) = unsafe {
        let info = &*handle;
        (info.length_public_key, info.length_secret_key)
    };

    let mut public_key = vec![0u8; pk_len];
    let mut secret_key = vec![0u8; sk_len];

    // SAFETY: both buffers are sized from the lengths reported by the handle;
    // the handle is released right after its last use.
    let status = unsafe {
        let status = OQS_KEM_keypair(handle, public_key.as_mut_ptr(), secret_key.as_mut_ptr());
        OQS_KEM_free(handle);
        status
    };

    match status {
        0 => Ok((public_key, secret_key)),
        _ => Err(OqsError::Internal("kem keypair")),
    }
}

/// Encapsulates against the public key `pk`, returning
/// `(ciphertext, shared_secret)`.
///
/// # Errors
///
/// * `OqsError::Internal("kem new")` if no liboqs handle could be opened.
/// * `OqsError::InvalidLength` if `pk` is not exactly the public-key length
///   reported by the handle.
/// * `OqsError::Internal("kem encaps")` if liboqs reports a non-zero status.
pub fn kem_encapsulate(alg: KemAlgorithm, pk: &[u8]) -> Result<(Vec<u8>, Vec<u8>), OqsError> {
    // SAFETY: the handle is null-checked below and freed exactly once on
    // every path that follows.
    let handle = unsafe { kem_new(alg) };
    if handle.is_null() {
        return Err(OqsError::Internal("kem new"));
    }

    // SAFETY: `handle` is non-null and has not been freed yet.
    let (pk_len, ct_len, ss_len) = unsafe {
        let info = &*handle;
        (
            info.length_public_key,
            info.length_ciphertext,
            info.length_shared_secret,
        )
    };

    // Reject wrongly sized keys before handing the pointer to C.
    if pk.len() != pk_len {
        // SAFETY: `handle` is live and is not used after this call.
        unsafe { OQS_KEM_free(handle) };
        return Err(OqsError::InvalidLength);
    }

    let mut ciphertext = vec![0u8; ct_len];
    let mut shared_secret = vec![0u8; ss_len];

    // SAFETY: output buffers are sized from the handle's lengths and `pk` was
    // length-checked above; the handle is released right after its last use.
    let status = unsafe {
        let status = OQS_KEM_encaps(
            handle,
            ciphertext.as_mut_ptr(),
            shared_secret.as_mut_ptr(),
            pk.as_ptr(),
        );
        OQS_KEM_free(handle);
        status
    };

    match status {
        0 => Ok((ciphertext, shared_secret)),
        _ => Err(OqsError::Internal("kem encaps")),
    }
}

/// Decapsulates `ct` with the secret key `sk`, returning the shared secret.
///
/// # Errors
///
/// * `OqsError::Internal("kem new")` if no liboqs handle could be opened.
/// * `OqsError::InvalidLength` if `ct` or `sk` does not have exactly the
///   length reported by the handle.
/// * `OqsError::Internal("kem decaps")` if liboqs reports a non-zero status.
pub fn kem_decapsulate(alg: KemAlgorithm, ct: &[u8], sk: &[u8]) -> Result<Vec<u8>, OqsError> {
    // SAFETY: the handle is null-checked below and freed exactly once on
    // every path that follows.
    let handle = unsafe { kem_new(alg) };
    if handle.is_null() {
        return Err(OqsError::Internal("kem new"));
    }

    // SAFETY: `handle` is non-null and has not been freed yet.
    let (ct_len, sk_len, ss_len) = unsafe {
        let info = &*handle;
        (
            info.length_ciphertext,
            info.length_secret_key,
            info.length_shared_secret,
        )
    };

    // Both inputs must match exactly before their pointers are handed to C.
    let sizes_ok = ct.len() == ct_len && sk.len() == sk_len;
    if !sizes_ok {
        // SAFETY: `handle` is live and is not used after this call.
        unsafe { OQS_KEM_free(handle) };
        return Err(OqsError::InvalidLength);
    }

    let mut shared_secret = vec![0u8; ss_len];

    // SAFETY: the output buffer is sized from the handle's length and both
    // inputs were length-checked above; the handle is released right after
    // its last use.
    let status = unsafe {
        let status = OQS_KEM_decaps(handle, shared_secret.as_mut_ptr(), ct.as_ptr(), sk.as_ptr());
        OQS_KEM_free(handle);
        status
    };

    match status {
        0 => Ok(shared_secret),
        _ => Err(OqsError::Internal("kem decaps")),
    }
}

/// Generates a signature key pair for `alg`, returned as
/// `(public_key, secret_key)`.
///
/// # Errors
///
/// * `OqsError::Internal("sig new")` if no liboqs handle could be opened.
/// * `OqsError::Internal("sig keypair")` if liboqs reports a non-zero status.
pub fn sig_keypair(alg: SigAlgorithm) -> Result<(Vec<u8>, Vec<u8>), OqsError> {
    // SAFETY: the handle is null-checked below and freed exactly once.
    let handle = unsafe { sig_new(alg) };
    if handle.is_null() {
        return Err(OqsError::Internal("sig new"));
    }

    // SAFETY: `handle` is non-null and has not been freed yet.
    let (pk_len, sk_len) = unsafe {
        let info = &*handle;
        (info.length_public_key, info.length_secret_key)
    };

    let mut public_key = vec![0u8; pk_len];
    let mut secret_key = vec![0u8; sk_len];

    // SAFETY: both buffers are sized from the lengths reported by the handle;
    // the handle is released right after its last use.
    let status = unsafe {
        let status = OQS_SIG_keypair(handle, public_key.as_mut_ptr(), secret_key.as_mut_ptr());
        OQS_SIG_free(handle);
        status
    };

    match status {
        0 => Ok((public_key, secret_key)),
        _ => Err(OqsError::Internal("sig keypair")),
    }
}

/// Signs `msg` with the secret key `sk` and returns the signature bytes.
///
/// The output buffer is allocated at the signature length reported by the
/// handle and then truncated to the length liboqs writes back.
///
/// # Errors
///
/// * `OqsError::Internal("sig new")` if no liboqs handle could be opened.
/// * `OqsError::InvalidLength` if `sk` is not exactly the secret-key length
///   reported by the handle.
/// * `OqsError::Internal("sig sign")` if liboqs reports a non-zero status.
pub fn sig_sign(alg: SigAlgorithm, sk: &[u8], msg: &[u8]) -> Result<Vec<u8>, OqsError> {
    // SAFETY: the handle is null-checked below and freed exactly once on
    // every path that follows.
    let handle = unsafe { sig_new(alg) };
    if handle.is_null() {
        return Err(OqsError::Internal("sig new"));
    }

    // SAFETY: `handle` is non-null and has not been freed yet.
    let (sk_len, max_sig_len) = unsafe {
        let info = &*handle;
        (info.length_secret_key, info.length_signature)
    };

    // Reject wrongly sized keys before handing the pointer to C.
    if sk.len() != sk_len {
        // SAFETY: `handle` is live and is not used after this call.
        unsafe { OQS_SIG_free(handle) };
        return Err(OqsError::InvalidLength);
    }

    let mut signature = vec![0u8; max_sig_len];
    let mut written: usize = 0;

    // SAFETY: `signature` is sized from the handle's signature length, `sk`
    // was length-checked above, and `msg` is passed with its own length; the
    // handle is released right after its last use.
    let status = unsafe {
        let status = OQS_SIG_sign(
            handle,
            signature.as_mut_ptr(),
            &mut written,
            msg.as_ptr(),
            msg.len(),
            sk.as_ptr(),
        );
        OQS_SIG_free(handle);
        status
    };

    match status {
        0 => {
            // Keep only the bytes liboqs reported as written.
            signature.truncate(written);
            Ok(signature)
        }
        _ => Err(OqsError::Internal("sig sign")),
    }
}

/// Verifies `sig_in` over `msg` against the public key `pk`.
///
/// Returns `Ok(())` when liboqs reports status zero.
///
/// # Errors
///
/// * `OqsError::Internal("sig new")` if no liboqs handle could be opened.
/// * `OqsError::InvalidLength` if `pk` is not exactly the public-key length
///   reported by the handle.
/// * `OqsError::VerifyFail` if liboqs reports any non-zero status.
pub fn sig_verify(alg: SigAlgorithm, pk: &[u8], msg: &[u8], sig_in: &[u8]) -> Result<(), OqsError> {
    // SAFETY: the handle is null-checked below and freed exactly once on
    // every path that follows.
    let handle = unsafe { sig_new(alg) };
    if handle.is_null() {
        return Err(OqsError::Internal("sig new"));
    }

    // SAFETY: `handle` is non-null and has not been freed yet.
    let pk_len = unsafe { (*handle).length_public_key };

    // Reject wrongly sized keys before handing the pointer to C.
    if pk.len() != pk_len {
        // SAFETY: `handle` is live and is not used after this call.
        unsafe { OQS_SIG_free(handle) };
        return Err(OqsError::InvalidLength);
    }

    // SAFETY: `pk` was length-checked above; `msg` and `sig_in` are passed
    // together with their own lengths; the handle is released right after
    // its last use.
    let status = unsafe {
        let status = OQS_SIG_verify(
            handle,
            msg.as_ptr(),
            msg.len(),
            sig_in.as_ptr(),
            sig_in.len(),
            pk.as_ptr(),
        );
        OQS_SIG_free(handle);
        status
    };

    match status {
        0 => Ok(()),
        _ => Err(OqsError::VerifyFail),
    }
}