pq_msg/exchange/pair.rs

1use pqcrypto_mlkem::{
2    ffi::{
3        PQCLEAN_MLKEM1024_CLEAN_CRYPTO_BYTES, PQCLEAN_MLKEM1024_CLEAN_CRYPTO_PUBLICKEYBYTES,
4        PQCLEAN_MLKEM1024_CLEAN_CRYPTO_SECRETKEYBYTES,
5    },
6    mlkem1024::{self, SharedSecret},
7    mlkem1024_decapsulate, mlkem1024_encapsulate, mlkem1024_keypair,
8};
9use pqcrypto_traits::kem::{PublicKey, SecretKey};
10
11use crate::errors::CryptoError;
12
13/// A Key Encapsulation Mechanism (KEM) pair using ML-KEM (formerly Kyber)
14///
15/// This struct represents a post-quantum cryptography key pair used for
16/// key encapsulation and decapsulation operations. It utilizes ML-KEM1024,
17/// which provides 256-bit equivalent security strength.
18pub struct KEMPair {
19    pub_key: mlkem1024::PublicKey,
20    sec_key: mlkem1024::SecretKey,
21}
22
23impl KEMPair {
24    /// Creates a new random KEM pair
25    ///
26    /// # Returns
27    /// A new KEMPair with generated public and secret keys
28    pub fn create() -> Self {
29        let (pk, sk) = mlkem1024_keypair();
30        Self {
31            pub_key: pk,
32            sec_key: sk,
33        }
34    }
35
36    /// Creates a KEM pair from separate public and secret key bytes
37    ///
38    /// # Arguments
39    /// * `pub_key` - The public key bytes
40    /// * `sec_key` - The secret key bytes
41    ///
42    /// # Returns
43    /// - `Result<KEMPair, CryptoError>`: The constructed KEMPair or an error
44    pub fn from_bytes(pub_key: &[u8], sec_key: &[u8]) -> Result<Self, CryptoError> {
45        let pub_key = mlkem1024::PublicKey::from_bytes(pub_key)?;
46        let sec_key = mlkem1024::SecretKey::from_bytes(sec_key)?;
47        Ok(Self { pub_key, sec_key })
48    }
49
50    /// Converts the key pair to raw byte arrays
51    ///
52    /// # Returns
53    /// - `Result<([u8; PUBLICKEYBYTES], [u8; SECRETKEYBYTES]), CryptoError>`:
54    ///   A tuple containing the public and secret keys as byte arrays
55    pub fn to_bytes(
56        &self,
57    ) -> Result<
58        (
59            [u8; PQCLEAN_MLKEM1024_CLEAN_CRYPTO_PUBLICKEYBYTES],
60            [u8; PQCLEAN_MLKEM1024_CLEAN_CRYPTO_SECRETKEYBYTES],
61        ),
62        CryptoError,
63    > {
64        Ok((
65            self.pub_key.as_bytes().try_into()?,
66            self.sec_key.as_bytes().try_into()?,
67        ))
68    }
69
70    /// Converts the key pair to a single byte vector with public key followed by secret key
71    ///
72    /// # Returns
73    /// A vector containing the concatenated public and secret key bytes
74    pub fn to_bytes_uniform(&self) -> Vec<u8> {
75        let mut bytes = Vec::new();
76        bytes.extend_from_slice(self.pub_key.as_bytes());
77        bytes.extend_from_slice(self.sec_key.as_bytes());
78        bytes
79    }
80
81    /// Creates a KEM pair from a single byte slice containing both public and secret keys
82    ///
83    /// # Arguments
84    /// * `bytes` - The concatenated public and secret key bytes
85    ///
86    /// # Returns
87    /// - `Result<KEMPair, CryptoError>`: The constructed KEMPair or an error
88    pub fn from_bytes_uniform(bytes: &[u8]) -> Result<Self, CryptoError> {
89        if bytes.len()
90            != PQCLEAN_MLKEM1024_CLEAN_CRYPTO_PUBLICKEYBYTES
91                + PQCLEAN_MLKEM1024_CLEAN_CRYPTO_SECRETKEYBYTES
92        {
93            return Err(CryptoError::IncongruentLength(
94                PQCLEAN_MLKEM1024_CLEAN_CRYPTO_PUBLICKEYBYTES
95                    + PQCLEAN_MLKEM1024_CLEAN_CRYPTO_SECRETKEYBYTES,
96                bytes.len(),
97            ));
98        }
99        let pub_key = mlkem1024::PublicKey::from_bytes(
100            &bytes[..PQCLEAN_MLKEM1024_CLEAN_CRYPTO_PUBLICKEYBYTES],
101        )?;
102        let sec_key = mlkem1024::SecretKey::from_bytes(
103            &bytes[PQCLEAN_MLKEM1024_CLEAN_CRYPTO_PUBLICKEYBYTES..],
104        )?;
105        Ok(Self { pub_key, sec_key })
106    }
107
108    /// Encapsulates a shared secret using the provided public key
109    ///
110    /// # Arguments
111    /// * `receiver_pubkey` - The receiver's public key
112    ///
113    /// # Returns
114    /// A tuple containing the shared secret and the ciphertext to send to the receiver
115    pub fn encapsulate(
116        &self,
117        receiver_pubkey: &mlkem1024::PublicKey,
118    ) -> (SharedSecret, mlkem1024::Ciphertext) {
119        mlkem1024_encapsulate(receiver_pubkey)
120    }
121
122    /// Decapsulates a shared secret from the provided ciphertext using this pair's secret key
123    ///
124    /// # Arguments
125    /// * `ciphertext` - The ciphertext received from the sender
126    ///
127    /// # Returns
128    /// - `Result<SharedSecret, CryptoError>`: The decapsulated shared secret or an error
129    pub fn decapsulate(
130        &self,
131        ciphertext: &mlkem1024::Ciphertext,
132    ) -> Result<SharedSecret, CryptoError> {
133        let shared_secret = mlkem1024_decapsulate(ciphertext, &self.sec_key);
134        Ok(shared_secret)
135    }
136}
137
138/// Converts a SharedSecret to a byte array
139///
140/// # Arguments
141/// * `ss` - The SharedSecret to convert
142///
143/// # Returns
144/// A byte array representation of the shared secret
145pub fn ss2b(ss: &SharedSecret) -> [u8; PQCLEAN_MLKEM1024_CLEAN_CRYPTO_BYTES] {
146    unsafe { *(ss as *const SharedSecret as *const [u8; PQCLEAN_MLKEM1024_CLEAN_CRYPTO_BYTES]) }
147}
148
149/// Converts a byte array to a SharedSecret
150///
151/// # Arguments
152/// * `bytes` - The byte array to convert
153///
154/// # Returns
155/// A SharedSecret created from the provided bytes
156pub fn b2ss(bytes: &[u8; PQCLEAN_MLKEM1024_CLEAN_CRYPTO_BYTES]) -> SharedSecret {
157    unsafe {
158        std::ptr::read(
159            bytes as *const [u8; PQCLEAN_MLKEM1024_CLEAN_CRYPTO_BYTES] as *const SharedSecret,
160        )
161    }
162}
163
#[cfg(test)]
mod tests {
    use super::*;

    /// `to_bytes` followed by `from_bytes` reproduces both keys.
    #[test]
    fn test_keypair() {
        let keypair = KEMPair::create();
        let (pub_key, sec_key) = keypair.to_bytes().unwrap();
        let new_keypair = KEMPair::from_bytes(&pub_key, &sec_key).unwrap();
        assert_eq!(keypair.pub_key.as_bytes(), new_keypair.pub_key.as_bytes());
        assert_eq!(keypair.sec_key.as_bytes(), new_keypair.sec_key.as_bytes());
    }

    /// `to_bytes_uniform` followed by `from_bytes_uniform` reproduces both keys.
    #[test]
    fn test_keypair_uniform_roundtrip() {
        let keypair = KEMPair::create();
        let bytes = keypair.to_bytes_uniform();
        assert_eq!(
            bytes.len(),
            PQCLEAN_MLKEM1024_CLEAN_CRYPTO_PUBLICKEYBYTES
                + PQCLEAN_MLKEM1024_CLEAN_CRYPTO_SECRETKEYBYTES
        );

        let restored = KEMPair::from_bytes_uniform(&bytes).unwrap();
        assert_eq!(keypair.pub_key.as_bytes(), restored.pub_key.as_bytes());
        assert_eq!(keypair.sec_key.as_bytes(), restored.sec_key.as_bytes());
    }

    /// A buffer of the wrong size is rejected with `IncongruentLength(expected, actual)`.
    #[test]
    fn test_from_bytes_uniform_rejects_bad_length() {
        let expected = PQCLEAN_MLKEM1024_CLEAN_CRYPTO_PUBLICKEYBYTES
            + PQCLEAN_MLKEM1024_CLEAN_CRYPTO_SECRETKEYBYTES;
        let too_short = vec![0u8; expected - 1];
        assert!(matches!(
            KEMPair::from_bytes_uniform(&too_short),
            Err(CryptoError::IncongruentLength(e, a)) if e == expected && a == expected - 1
        ));
    }

    /// Both sides of an encapsulation agree on the shared secret.
    #[test]
    fn test_encapsulate_decapsulate() {
        let sender = KEMPair::create();
        let receiver = KEMPair::create();

        let (shared_secret, ciphertext) = sender.encapsulate(&receiver.pub_key);
        let dec_shared_secret = receiver.decapsulate(&ciphertext).unwrap();

        let ss1 = ss2b(&shared_secret);
        let ss2 = ss2b(&dec_shared_secret);

        assert_eq!(ss1, ss2, "Difference in shared secrets!");
    }

    /// `ss2b` matches the canonical byte view, and `b2ss` inverts it.
    #[test]
    fn test_shared_secret_byte_roundtrip() {
        let receiver = KEMPair::create();
        let (shared_secret, _ciphertext) = receiver.encapsulate(&receiver.pub_key);

        let bytes = ss2b(&shared_secret);
        assert_eq!(
            &bytes[..],
            pqcrypto_traits::kem::SharedSecret::as_bytes(&shared_secret)
        );
        assert_eq!(ss2b(&b2ss(&bytes)), bytes);
    }
}