1use crate::device::Device;
6use crate::error::{check_error, Result};
7use crate::ffi;
8use crate::types::KemAlgorithm;
9
10use zeroize::ZeroizeOnDrop;
11
12#[derive(Clone, ZeroizeOnDrop)]
17pub struct KeyPair {
18 #[zeroize(skip)]
20 public_key: Vec<u8>,
21 secret_key: Vec<u8>,
23 #[zeroize(skip)]
25 algorithm: KemAlgorithm,
26}
27
28impl KeyPair {
29 fn new(public_key: Vec<u8>, secret_key: Vec<u8>, algorithm: KemAlgorithm) -> Self {
31 Self {
32 public_key,
33 secret_key,
34 algorithm,
35 }
36 }
37
38 pub fn public_key(&self) -> &[u8] {
40 &self.public_key
41 }
42
43 pub fn secret_key(&self) -> &[u8] {
45 &self.secret_key
46 }
47
48 pub fn algorithm(&self) -> KemAlgorithm {
50 self.algorithm
51 }
52
53 pub fn into_bytes(mut self) -> (Vec<u8>, Vec<u8>) {
57 let pk = std::mem::take(&mut self.public_key);
58 let sk = std::mem::take(&mut self.secret_key);
59 std::mem::forget(self); (pk, sk)
61 }
62}
63
64impl std::fmt::Debug for KeyPair {
65 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
66 f.debug_struct("KeyPair")
67 .field("algorithm", &self.algorithm)
68 .field("public_key_len", &self.public_key.len())
69 .field("secret_key_len", &self.secret_key.len())
70 .finish()
71 }
72}
73
74#[derive(Clone, ZeroizeOnDrop)]
79pub struct EncapsulationResult {
80 #[zeroize(skip)]
82 ciphertext: Vec<u8>,
83 shared_secret: Vec<u8>,
85}
86
87impl EncapsulationResult {
88 fn new(ciphertext: Vec<u8>, shared_secret: Vec<u8>) -> Self {
90 Self {
91 ciphertext,
92 shared_secret,
93 }
94 }
95
96 pub fn ciphertext(&self) -> &[u8] {
98 &self.ciphertext
99 }
100
101 pub fn shared_secret(&self) -> &[u8] {
103 &self.shared_secret
104 }
105
106 pub fn into_parts(mut self) -> (Vec<u8>, Vec<u8>) {
108 let ct = std::mem::take(&mut self.ciphertext);
109 let ss = std::mem::take(&mut self.shared_secret);
110 std::mem::forget(self);
111 (ct, ss)
112 }
113}
114
115impl std::fmt::Debug for EncapsulationResult {
116 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
117 f.debug_struct("EncapsulationResult")
118 .field("ciphertext_len", &self.ciphertext.len())
119 .field("shared_secret_len", &self.shared_secret.len())
120 .finish()
121 }
122}
123
124#[derive(Clone)]
157pub struct Kem {
158 device: Device,
159}
160
161impl Kem {
162 pub(crate) fn new(device: Device) -> Self {
164 Self { device }
165 }
166
167 pub fn generate_keypair(&self, algorithm: KemAlgorithm) -> Result<KeyPair> {
177 let pk_size = algorithm.public_key_size();
178 let sk_size = algorithm.secret_key_size();
179
180 let mut public_key = vec![0u8; pk_size];
181 let mut secret_key = vec![0u8; sk_size];
182 let mut pk_len = pk_size;
183 let mut sk_len = sk_size;
184
185 let result = unsafe {
186 ffi::quac_kem_keygen(
187 self.device.handle(),
188 algorithm.to_raw(),
189 public_key.as_mut_ptr(),
190 &mut pk_len,
191 secret_key.as_mut_ptr(),
192 &mut sk_len,
193 )
194 };
195
196 check_error(result)?;
197
198 public_key.truncate(pk_len);
199 secret_key.truncate(sk_len);
200
201 Ok(KeyPair::new(public_key, secret_key, algorithm))
202 }
203
204 pub fn generate_keypair_512(&self) -> Result<KeyPair> {
206 self.generate_keypair(KemAlgorithm::MlKem512)
207 }
208
209 pub fn generate_keypair_768(&self) -> Result<KeyPair> {
211 self.generate_keypair(KemAlgorithm::MlKem768)
212 }
213
214 pub fn generate_keypair_1024(&self) -> Result<KeyPair> {
216 self.generate_keypair(KemAlgorithm::MlKem1024)
217 }
218
219 pub fn encapsulate(
233 &self,
234 public_key: &[u8],
235 algorithm: KemAlgorithm,
236 ) -> Result<(Vec<u8>, Vec<u8>)> {
237 let ct_size = algorithm.ciphertext_size();
238 let ss_size = algorithm.shared_secret_size();
239
240 let mut ciphertext = vec![0u8; ct_size];
241 let mut shared_secret = vec![0u8; ss_size];
242 let mut ct_len = ct_size;
243 let mut ss_len = ss_size;
244
245 let result = unsafe {
246 ffi::quac_kem_encapsulate(
247 self.device.handle(),
248 algorithm.to_raw(),
249 public_key.as_ptr(),
250 public_key.len(),
251 ciphertext.as_mut_ptr(),
252 &mut ct_len,
253 shared_secret.as_mut_ptr(),
254 &mut ss_len,
255 )
256 };
257
258 check_error(result)?;
259
260 ciphertext.truncate(ct_len);
261 shared_secret.truncate(ss_len);
262
263 Ok((ciphertext, shared_secret))
264 }
265
266 pub fn encapsulate_result(
268 &self,
269 public_key: &[u8],
270 algorithm: KemAlgorithm,
271 ) -> Result<EncapsulationResult> {
272 let (ct, ss) = self.encapsulate(public_key, algorithm)?;
273 Ok(EncapsulationResult::new(ct, ss))
274 }
275
276 pub fn encapsulate_512(&self, public_key: &[u8]) -> Result<(Vec<u8>, Vec<u8>)> {
278 self.encapsulate(public_key, KemAlgorithm::MlKem512)
279 }
280
281 pub fn encapsulate_768(&self, public_key: &[u8]) -> Result<(Vec<u8>, Vec<u8>)> {
283 self.encapsulate(public_key, KemAlgorithm::MlKem768)
284 }
285
286 pub fn encapsulate_1024(&self, public_key: &[u8]) -> Result<(Vec<u8>, Vec<u8>)> {
288 self.encapsulate(public_key, KemAlgorithm::MlKem1024)
289 }
290
291 pub fn decapsulate(
306 &self,
307 secret_key: &[u8],
308 ciphertext: &[u8],
309 algorithm: KemAlgorithm,
310 ) -> Result<Vec<u8>> {
311 let ss_size = algorithm.shared_secret_size();
312 let mut shared_secret = vec![0u8; ss_size];
313 let mut ss_len = ss_size;
314
315 let result = unsafe {
316 ffi::quac_kem_decapsulate(
317 self.device.handle(),
318 algorithm.to_raw(),
319 secret_key.as_ptr(),
320 secret_key.len(),
321 ciphertext.as_ptr(),
322 ciphertext.len(),
323 shared_secret.as_mut_ptr(),
324 &mut ss_len,
325 )
326 };
327
328 check_error(result)?;
329 shared_secret.truncate(ss_len);
330
331 Ok(shared_secret)
332 }
333
334 pub fn decapsulate_512(&self, secret_key: &[u8], ciphertext: &[u8]) -> Result<Vec<u8>> {
336 self.decapsulate(secret_key, ciphertext, KemAlgorithm::MlKem512)
337 }
338
339 pub fn decapsulate_768(&self, secret_key: &[u8], ciphertext: &[u8]) -> Result<Vec<u8>> {
341 self.decapsulate(secret_key, ciphertext, KemAlgorithm::MlKem768)
342 }
343
344 pub fn decapsulate_1024(&self, secret_key: &[u8], ciphertext: &[u8]) -> Result<Vec<u8>> {
346 self.decapsulate(secret_key, ciphertext, KemAlgorithm::MlKem1024)
347 }
348}
349
350impl std::fmt::Debug for Kem {
351 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
352 f.debug_struct("Kem").finish()
353 }
354}
355
#[cfg(test)]
mod tests {
    use super::*;

    /// `Debug` for `KeyPair` must expose lengths only, never the secret bytes.
    #[test]
    fn test_keypair_debug_hides_secret() {
        let kp = KeyPair::new(vec![1, 2, 3], vec![4, 5, 6], KemAlgorithm::MlKem768);
        let debug = format!("{:?}", kp);
        assert!(!debug.contains("[4, 5, 6]"));
        assert!(debug.contains("secret_key_len"));
    }

    /// `Debug` for `EncapsulationResult` must expose lengths only, never the
    /// shared secret bytes.
    #[test]
    fn test_encapsulation_result_debug() {
        let result = EncapsulationResult::new(vec![1, 2, 3], vec![4, 5, 6]);
        let debug = format!("{:?}", result);
        assert!(debug.contains("ciphertext_len"));
        assert!(debug.contains("shared_secret_len"));
        assert!(!debug.contains("[4, 5, 6]"));
    }

    /// `into_bytes` hands back both buffers unchanged, public key first.
    #[test]
    fn test_keypair_into_bytes_round_trip() {
        let kp = KeyPair::new(vec![1, 2, 3], vec![4, 5, 6], KemAlgorithm::MlKem768);
        assert_eq!(kp.into_bytes(), (vec![1, 2, 3], vec![4, 5, 6]));
    }

    /// `into_parts` hands back both buffers unchanged, ciphertext first.
    #[test]
    fn test_encapsulation_result_into_parts_round_trip() {
        let result = EncapsulationResult::new(vec![1, 2, 3], vec![4, 5, 6]);
        assert_eq!(result.into_parts(), (vec![1, 2, 3], vec![4, 5, 6]));
    }
}