quantacore/
kem.rs

1//! Key Encapsulation Mechanism (KEM) operations.
2//!
3//! This module provides post-quantum key encapsulation using ML-KEM (Kyber).
4
5use crate::device::Device;
6use crate::error::{check_error, Result};
7use crate::ffi;
8use crate::types::KemAlgorithm;
9
10use zeroize::ZeroizeOnDrop;
11
12/// KEM key pair.
13///
14/// Contains the public and secret keys generated by KEM key generation.
15/// The secret key is automatically zeroed when dropped.
16#[derive(Clone, ZeroizeOnDrop)]
17pub struct KeyPair {
18    /// Public key bytes
19    #[zeroize(skip)]
20    public_key: Vec<u8>,
21    /// Secret key bytes (zeroized on drop)
22    secret_key: Vec<u8>,
23    /// Algorithm used
24    #[zeroize(skip)]
25    algorithm: KemAlgorithm,
26}
27
28impl KeyPair {
29    /// Create a new key pair.
30    fn new(public_key: Vec<u8>, secret_key: Vec<u8>, algorithm: KemAlgorithm) -> Self {
31        Self {
32            public_key,
33            secret_key,
34            algorithm,
35        }
36    }
37
38    /// Get the public key.
39    pub fn public_key(&self) -> &[u8] {
40        &self.public_key
41    }
42
43    /// Get the secret key.
44    pub fn secret_key(&self) -> &[u8] {
45        &self.secret_key
46    }
47
48    /// Get the algorithm.
49    pub fn algorithm(&self) -> KemAlgorithm {
50        self.algorithm
51    }
52
53    /// Consume the key pair and return the raw bytes.
54    ///
55    /// The caller is responsible for zeroing the secret key.
56    pub fn into_bytes(mut self) -> (Vec<u8>, Vec<u8>) {
57        let pk = std::mem::take(&mut self.public_key);
58        let sk = std::mem::take(&mut self.secret_key);
59        std::mem::forget(self); // Don't run destructor
60        (pk, sk)
61    }
62}
63
64impl std::fmt::Debug for KeyPair {
65    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
66        f.debug_struct("KeyPair")
67            .field("algorithm", &self.algorithm)
68            .field("public_key_len", &self.public_key.len())
69            .field("secret_key_len", &self.secret_key.len())
70            .finish()
71    }
72}
73
74/// Result of KEM encapsulation.
75///
76/// Contains the ciphertext and shared secret. The shared secret
77/// is automatically zeroed when dropped.
78#[derive(Clone, ZeroizeOnDrop)]
79pub struct EncapsulationResult {
80    /// Ciphertext bytes
81    #[zeroize(skip)]
82    ciphertext: Vec<u8>,
83    /// Shared secret bytes (zeroized on drop)
84    shared_secret: Vec<u8>,
85}
86
87impl EncapsulationResult {
88    /// Create a new encapsulation result.
89    fn new(ciphertext: Vec<u8>, shared_secret: Vec<u8>) -> Self {
90        Self {
91            ciphertext,
92            shared_secret,
93        }
94    }
95
96    /// Get the ciphertext.
97    pub fn ciphertext(&self) -> &[u8] {
98        &self.ciphertext
99    }
100
101    /// Get the shared secret.
102    pub fn shared_secret(&self) -> &[u8] {
103        &self.shared_secret
104    }
105
106    /// Consume and return as tuple.
107    pub fn into_parts(mut self) -> (Vec<u8>, Vec<u8>) {
108        let ct = std::mem::take(&mut self.ciphertext);
109        let ss = std::mem::take(&mut self.shared_secret);
110        std::mem::forget(self);
111        (ct, ss)
112    }
113}
114
115impl std::fmt::Debug for EncapsulationResult {
116    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
117        f.debug_struct("EncapsulationResult")
118            .field("ciphertext_len", &self.ciphertext.len())
119            .field("shared_secret_len", &self.shared_secret.len())
120            .finish()
121    }
122}
123
124/// KEM (Key Encapsulation Mechanism) subsystem.
125///
126/// Provides access to post-quantum key encapsulation operations
127/// using ML-KEM (formerly Kyber).
128///
129/// # Example
130///
131/// ```no_run
132/// use quantacore::{initialize, open_first_device, KemAlgorithm};
133///
134/// initialize().unwrap();
135/// let device = open_first_device().unwrap();
136/// let kem = device.kem();
137///
138/// // Generate key pair
139/// let keypair = kem.generate_keypair(KemAlgorithm::MlKem768).unwrap();
140///
141/// // Encapsulate (sender side)
142/// let (ciphertext, sender_secret) = kem.encapsulate(
143///     keypair.public_key(),
144///     KemAlgorithm::MlKem768
145/// ).unwrap();
146///
147/// // Decapsulate (receiver side)
148/// let receiver_secret = kem.decapsulate(
149///     keypair.secret_key(),
150///     &ciphertext,
151///     KemAlgorithm::MlKem768
152/// ).unwrap();
153///
154/// assert_eq!(sender_secret, receiver_secret);
155/// ```
156#[derive(Clone)]
157pub struct Kem {
158    device: Device,
159}
160
161impl Kem {
162    /// Create a new KEM subsystem handle.
163    pub(crate) fn new(device: Device) -> Self {
164        Self { device }
165    }
166
167    /// Generate a KEM key pair.
168    ///
169    /// # Arguments
170    ///
171    /// * `algorithm` - The KEM algorithm to use
172    ///
173    /// # Returns
174    ///
175    /// A `KeyPair` containing the public and secret keys.
176    pub fn generate_keypair(&self, algorithm: KemAlgorithm) -> Result<KeyPair> {
177        let pk_size = algorithm.public_key_size();
178        let sk_size = algorithm.secret_key_size();
179
180        let mut public_key = vec![0u8; pk_size];
181        let mut secret_key = vec![0u8; sk_size];
182        let mut pk_len = pk_size;
183        let mut sk_len = sk_size;
184
185        let result = unsafe {
186            ffi::quac_kem_keygen(
187                self.device.handle(),
188                algorithm.to_raw(),
189                public_key.as_mut_ptr(),
190                &mut pk_len,
191                secret_key.as_mut_ptr(),
192                &mut sk_len,
193            )
194        };
195
196        check_error(result)?;
197
198        public_key.truncate(pk_len);
199        secret_key.truncate(sk_len);
200
201        Ok(KeyPair::new(public_key, secret_key, algorithm))
202    }
203
204    /// Generate ML-KEM-512 key pair.
205    pub fn generate_keypair_512(&self) -> Result<KeyPair> {
206        self.generate_keypair(KemAlgorithm::MlKem512)
207    }
208
209    /// Generate ML-KEM-768 key pair.
210    pub fn generate_keypair_768(&self) -> Result<KeyPair> {
211        self.generate_keypair(KemAlgorithm::MlKem768)
212    }
213
214    /// Generate ML-KEM-1024 key pair.
215    pub fn generate_keypair_1024(&self) -> Result<KeyPair> {
216        self.generate_keypair(KemAlgorithm::MlKem1024)
217    }
218
219    /// Encapsulate to generate a shared secret and ciphertext.
220    ///
221    /// This is the sender's operation. The ciphertext should be sent
222    /// to the recipient who can decapsulate using their secret key.
223    ///
224    /// # Arguments
225    ///
226    /// * `public_key` - The recipient's public key
227    /// * `algorithm` - The KEM algorithm to use
228    ///
229    /// # Returns
230    ///
231    /// A tuple of (ciphertext, shared_secret).
232    pub fn encapsulate(
233        &self,
234        public_key: &[u8],
235        algorithm: KemAlgorithm,
236    ) -> Result<(Vec<u8>, Vec<u8>)> {
237        let ct_size = algorithm.ciphertext_size();
238        let ss_size = algorithm.shared_secret_size();
239
240        let mut ciphertext = vec![0u8; ct_size];
241        let mut shared_secret = vec![0u8; ss_size];
242        let mut ct_len = ct_size;
243        let mut ss_len = ss_size;
244
245        let result = unsafe {
246            ffi::quac_kem_encapsulate(
247                self.device.handle(),
248                algorithm.to_raw(),
249                public_key.as_ptr(),
250                public_key.len(),
251                ciphertext.as_mut_ptr(),
252                &mut ct_len,
253                shared_secret.as_mut_ptr(),
254                &mut ss_len,
255            )
256        };
257
258        check_error(result)?;
259
260        ciphertext.truncate(ct_len);
261        shared_secret.truncate(ss_len);
262
263        Ok((ciphertext, shared_secret))
264    }
265
266    /// Encapsulate returning an `EncapsulationResult`.
267    pub fn encapsulate_result(
268        &self,
269        public_key: &[u8],
270        algorithm: KemAlgorithm,
271    ) -> Result<EncapsulationResult> {
272        let (ct, ss) = self.encapsulate(public_key, algorithm)?;
273        Ok(EncapsulationResult::new(ct, ss))
274    }
275
276    /// Encapsulate using ML-KEM-512.
277    pub fn encapsulate_512(&self, public_key: &[u8]) -> Result<(Vec<u8>, Vec<u8>)> {
278        self.encapsulate(public_key, KemAlgorithm::MlKem512)
279    }
280
281    /// Encapsulate using ML-KEM-768.
282    pub fn encapsulate_768(&self, public_key: &[u8]) -> Result<(Vec<u8>, Vec<u8>)> {
283        self.encapsulate(public_key, KemAlgorithm::MlKem768)
284    }
285
286    /// Encapsulate using ML-KEM-1024.
287    pub fn encapsulate_1024(&self, public_key: &[u8]) -> Result<(Vec<u8>, Vec<u8>)> {
288        self.encapsulate(public_key, KemAlgorithm::MlKem1024)
289    }
290
291    /// Decapsulate to recover the shared secret.
292    ///
293    /// This is the recipient's operation. Use the secret key and
294    /// received ciphertext to recover the shared secret.
295    ///
296    /// # Arguments
297    ///
298    /// * `secret_key` - The recipient's secret key
299    /// * `ciphertext` - The ciphertext received from the sender
300    /// * `algorithm` - The KEM algorithm to use
301    ///
302    /// # Returns
303    ///
304    /// The shared secret (same as the sender's).
305    pub fn decapsulate(
306        &self,
307        secret_key: &[u8],
308        ciphertext: &[u8],
309        algorithm: KemAlgorithm,
310    ) -> Result<Vec<u8>> {
311        let ss_size = algorithm.shared_secret_size();
312        let mut shared_secret = vec![0u8; ss_size];
313        let mut ss_len = ss_size;
314
315        let result = unsafe {
316            ffi::quac_kem_decapsulate(
317                self.device.handle(),
318                algorithm.to_raw(),
319                secret_key.as_ptr(),
320                secret_key.len(),
321                ciphertext.as_ptr(),
322                ciphertext.len(),
323                shared_secret.as_mut_ptr(),
324                &mut ss_len,
325            )
326        };
327
328        check_error(result)?;
329        shared_secret.truncate(ss_len);
330
331        Ok(shared_secret)
332    }
333
334    /// Decapsulate using ML-KEM-512.
335    pub fn decapsulate_512(&self, secret_key: &[u8], ciphertext: &[u8]) -> Result<Vec<u8>> {
336        self.decapsulate(secret_key, ciphertext, KemAlgorithm::MlKem512)
337    }
338
339    /// Decapsulate using ML-KEM-768.
340    pub fn decapsulate_768(&self, secret_key: &[u8], ciphertext: &[u8]) -> Result<Vec<u8>> {
341        self.decapsulate(secret_key, ciphertext, KemAlgorithm::MlKem768)
342    }
343
344    /// Decapsulate using ML-KEM-1024.
345    pub fn decapsulate_1024(&self, secret_key: &[u8], ciphertext: &[u8]) -> Result<Vec<u8>> {
346        self.decapsulate(secret_key, ciphertext, KemAlgorithm::MlKem1024)
347    }
348}
349
350impl std::fmt::Debug for Kem {
351    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
352        f.debug_struct("Kem").finish()
353    }
354}
355
#[cfg(test)]
mod tests {
    use super::*;

    /// `Debug` output must expose key lengths but never the secret bytes.
    #[test]
    fn test_keypair_debug_hides_secret() {
        let kp = KeyPair::new(vec![1, 2, 3], vec![4, 5, 6], KemAlgorithm::MlKem768);
        let debug = format!("{:?}", kp);
        assert!(!debug.contains("[4, 5, 6]"));
        assert!(debug.contains("secret_key_len"));
    }

    /// Accessors return the stored bytes, and `into_bytes` hands them back intact.
    #[test]
    fn test_keypair_accessors_and_into_bytes() {
        let kp = KeyPair::new(vec![1, 2, 3], vec![4, 5, 6], KemAlgorithm::MlKem768);
        assert_eq!(kp.public_key(), &[1u8, 2, 3][..]);
        assert_eq!(kp.secret_key(), &[4u8, 5, 6][..]);
        assert!(matches!(kp.algorithm(), KemAlgorithm::MlKem768));

        let (pk, sk) = kp.into_bytes();
        assert_eq!(pk, vec![1u8, 2, 3]);
        assert_eq!(sk, vec![4u8, 5, 6]);
    }

    /// Dropping (and therefore zeroizing) the original must not affect a clone.
    #[test]
    fn test_keypair_clone_survives_original_drop() {
        let kp = KeyPair::new(vec![1], vec![2], KemAlgorithm::MlKem512);
        let copy = kp.clone();
        drop(kp);
        assert_eq!(copy.secret_key(), &[2u8][..]);
    }

    /// `Debug` output must expose lengths but never the shared secret bytes.
    #[test]
    fn test_encapsulation_result_debug() {
        let result = EncapsulationResult::new(vec![1, 2, 3], vec![4, 5, 6]);
        let debug = format!("{:?}", result);
        assert!(debug.contains("ciphertext_len"));
        assert!(debug.contains("shared_secret_len"));
        assert!(!debug.contains("[4, 5, 6]"));
    }

    /// Accessors return the stored bytes, and `into_parts` hands them back intact.
    #[test]
    fn test_encapsulation_result_into_parts() {
        let result = EncapsulationResult::new(vec![1, 2, 3], vec![4, 5, 6]);
        assert_eq!(result.ciphertext(), &[1u8, 2, 3][..]);
        assert_eq!(result.shared_secret(), &[4u8, 5, 6][..]);

        let (ct, ss) = result.into_parts();
        assert_eq!(ct, vec![1u8, 2, 3]);
        assert_eq!(ss, vec![4u8, 5, 6]);
    }
}