1use libloading::{Library, Symbol};
27use std::os::raw::c_int;
28use std::sync::Once;
29
// Byte sizes of the Kyber768 objects exchanged with the native library.
// Each size is spelled out as a sum of its parts so it can be cross-checked
// against the parameter set of the pqcrystals reference implementation
// (k = 3, 384-byte encoded polynomials, 32-byte seeds/hashes). NOTE(review):
// the decomposition follows the reference `params.h`; re-verify if the DLL
// is ever swapped for a different parameter set.

/// Public key length: 3 * 384 + 32 = 1184 bytes.
pub const CRYPTO_PUBLICKEYBYTES: usize = 3 * 384 + 32;
/// Secret key length: 3 * 384 + 1184 + 2 * 32 = 2400 bytes.
pub const CRYPTO_SECRETKEYBYTES: usize = 3 * 384 + CRYPTO_PUBLICKEYBYTES + 2 * 32;
/// Ciphertext length: 3 * 320 + 128 = 1088 bytes.
pub const CRYPTO_CIPHERTEXTBYTES: usize = 3 * 320 + 128;
/// Shared secret length: 32 bytes.
pub const CRYPTO_BYTES: usize = 32;
34
35static INIT: Once = Once::new();
36static mut LIBRARY: Option<Library> = None;
37
38fn load_library() -> &'static Library {
39 INIT.call_once(|| unsafe {
40 LIBRARY = Some(Library::new("kyber.dll").expect("Failed to load kyber.dll"));
41 });
42 unsafe { LIBRARY.as_ref().unwrap() }
43}
44
45pub fn generate_keypair(
49) -> Result<([u8; CRYPTO_PUBLICKEYBYTES], [u8; CRYPTO_SECRETKEYBYTES]), String> {
50 let mut pk = [0u8; CRYPTO_PUBLICKEYBYTES];
51 let mut sk = [0u8; CRYPTO_SECRETKEYBYTES];
52
53 let result = crypto_kem_keypair(&mut pk, &mut sk);
54 if result != 0 {
55 return Err(format!(
56 "Keypair generation failed with error code: {}",
57 result
58 ));
59 }
60
61 Ok((pk, sk))
62}
63
64pub fn encapsulate(
68 pk: &[u8; CRYPTO_PUBLICKEYBYTES],
69) -> Result<([u8; CRYPTO_CIPHERTEXTBYTES], [u8; CRYPTO_BYTES]), String> {
70 let mut ct = [0u8; CRYPTO_CIPHERTEXTBYTES];
71 let mut ss = [0u8; CRYPTO_BYTES];
72
73 let result = crypto_kem_enc(&mut ct, &mut ss, pk);
74 if result != 0 {
75 return Err(format!("Encapsulation failed with error code: {}", result));
76 }
77
78 Ok((ct, ss))
79}
80
81pub fn decapsulate(
85 ct: &[u8; CRYPTO_CIPHERTEXTBYTES],
86 sk: &[u8; CRYPTO_SECRETKEYBYTES],
87) -> Result<[u8; CRYPTO_BYTES], String> {
88 let mut ss = [0u8; CRYPTO_BYTES];
89
90 let result = crypto_kem_dec(&mut ss, ct, sk);
91 if result != 0 {
92 return Err(format!("Decapsulation failed with error code: {}", result));
93 }
94
95 Ok(ss)
96}
97
98fn crypto_kem_keypair(
99 pk: &mut [u8; CRYPTO_PUBLICKEYBYTES],
100 sk: &mut [u8; CRYPTO_SECRETKEYBYTES],
101) -> i32 {
102 let lib = load_library();
103 unsafe {
104 let func: Symbol<unsafe extern "C" fn(*mut u8, *mut u8) -> c_int> = lib
105 .get(b"pqcrystals_kyber768_ref_keypair")
106 .expect("Failed to load keypair function");
107 func(pk.as_mut_ptr(), sk.as_mut_ptr())
108 }
109}
110
111fn crypto_kem_enc(
112 ct: &mut [u8; CRYPTO_CIPHERTEXTBYTES],
113 ss: &mut [u8; CRYPTO_BYTES],
114 pk: &[u8; CRYPTO_PUBLICKEYBYTES],
115) -> i32 {
116 let lib = load_library();
117 unsafe {
118 let func: Symbol<unsafe extern "C" fn(*mut u8, *mut u8, *const u8) -> c_int> = lib
119 .get(b"pqcrystals_kyber768_ref_enc")
120 .expect("Failed to load enc function");
121 func(ct.as_mut_ptr(), ss.as_mut_ptr(), pk.as_ptr())
122 }
123}
124
125fn crypto_kem_dec(
126 ss: &mut [u8; CRYPTO_BYTES],
127 ct: &[u8; CRYPTO_CIPHERTEXTBYTES],
128 sk: &[u8; CRYPTO_SECRETKEYBYTES],
129) -> i32 {
130 let lib = load_library();
131 unsafe {
132 let func: Symbol<unsafe extern "C" fn(*mut u8, *const u8, *const u8) -> c_int> = lib
133 .get(b"pqcrystals_kyber768_ref_dec")
134 .expect("Failed to load dec function");
135 func(ss.as_mut_ptr(), ct.as_ptr(), sk.as_ptr())
136 }
137}
138
#[cfg(test)]
mod tests {
    use super::*;

    /// True when every byte is zero. Used to detect an FFI call that reported
    /// success without writing its output buffer. Array lengths are fixed by
    /// the types, so asserting on `.len()` cannot catch anything.
    fn all_zero(bytes: &[u8]) -> bool {
        bytes.iter().all(|&b| b == 0)
    }

    #[test]
    fn test_generate_keypair() {
        let (pk, sk) = generate_keypair().unwrap();
        assert!(!all_zero(&pk), "public key buffer was never written");
        assert!(!all_zero(&sk), "secret key buffer was never written");
    }

    #[test]
    fn test_keypairs_are_distinct() {
        // Key generation is randomized, so two fresh key pairs must differ.
        let (pk1, sk1) = generate_keypair().unwrap();
        let (pk2, sk2) = generate_keypair().unwrap();
        assert_ne!(pk1, pk2);
        assert_ne!(sk1, sk2);
    }

    #[test]
    fn test_encapsulate() {
        let (pk, _) = generate_keypair().unwrap();
        let (ct, ss) = encapsulate(&pk).unwrap();
        assert!(!all_zero(&ct), "ciphertext buffer was never written");
        assert!(!all_zero(&ss), "shared secret buffer was never written");
    }

    #[test]
    fn test_decapsulate() {
        let (pk, sk) = generate_keypair().unwrap();
        let (ct, ss_enc) = encapsulate(&pk).unwrap();
        let ss_dec = decapsulate(&ct, &sk).unwrap();
        assert_eq!(ss_enc, ss_dec);
    }

    #[test]
    fn test_invalid_decapsulation() {
        let (pk1, _sk1) = generate_keypair().unwrap();
        let (_, sk2) = generate_keypair().unwrap();
        let (ct, ss_enc) = encapsulate(&pk1).unwrap();

        // Decapsulating with the wrong secret key still succeeds but must not
        // reproduce the sender's shared secret.
        let ss_dec = decapsulate(&ct, &sk2).unwrap();

        assert_ne!(ss_enc, ss_dec);
    }

    #[test]
    fn test_tampered_ciphertext() {
        let (pk, sk) = generate_keypair().unwrap();
        let (mut ct, ss_enc) = encapsulate(&pk).unwrap();

        // Flip one bit. The correct secret key must then yield a different
        // shared secret.
        ct[0] ^= 0x01;
        let ss_dec = decapsulate(&ct, &sk).unwrap();

        assert_ne!(ss_enc, ss_dec);
    }

    #[test]
    fn test_multiple_encapsulations() {
        let (pk, sk) = generate_keypair().unwrap();
        let mut previous_ct: Option<[u8; CRYPTO_CIPHERTEXTBYTES]> = None;

        for _ in 0..10 {
            let (ct, ss_enc) = encapsulate(&pk).unwrap();
            let ss_dec = decapsulate(&ct, &sk).unwrap();
            assert_eq!(ss_enc, ss_dec);

            // Encapsulation is randomized, so consecutive ciphertexts differ.
            if let Some(prev) = previous_ct {
                assert_ne!(prev, ct);
            }
            previous_ct = Some(ct);
        }
    }
}