1use crate::Error::Decrypt;
2use byteorder::{ByteOrder, NetworkEndian, WriteBytesExt};
3use pqc_kyber::{
4 Keypair, KyberError, PublicKey, SecretKey, KYBER_CIPHERTEXTBYTES, KYBER_SECRETKEYBYTES,
5};
6
7use pqc_kyber::indcpa::{indcpa_dec, indcpa_enc, indcpa_keypair};
8pub use pqc_kyber::{decapsulate, encapsulate};
9
/// Plaintext bytes consumed by a single `indcpa_enc` call; a shorter final chunk is
/// zero-padded up to this size before encryption.
const KYBER_BLOCK_SIZE: usize = 32;
/// Width of the trailer that stores the plaintext length as a big-endian `u64`.
const LENGTH_FIELD: usize = std::mem::size_of::<u64>();
12
13pub fn encrypt<T: AsRef<[u8]>, R: AsRef<[u8]>, V: AsRef<[u8]>>(
14 public_key: T,
15 plaintext: R,
16 nonce: V,
17) -> Result<Vec<u8>, Error> {
18 let full_ciphertext_len = ct_len(plaintext.as_ref().len());
19 let mut out = vec![0u8; full_ciphertext_len];
20 encrypt_into(public_key, plaintext, nonce, out.as_mut_slice())?;
21 Ok(out)
22}
23
24pub fn ct_len(plaintext_len: usize) -> usize {
26 std::cmp::max(
27 KYBER_CIPHERTEXTBYTES,
28 div_ceil(plaintext_len as f32, KYBER_BLOCK_SIZE as f32) * KYBER_CIPHERTEXTBYTES,
29 ) + LENGTH_FIELD
30}
31
32pub fn plaintext_len(ciphertext: &[u8]) -> Option<usize> {
33 let split_pt = ciphertext.len().saturating_sub(8);
35 if split_pt > ciphertext.len() || split_pt == 0 {
36 return None;
37 }
38
39 let (_, field_length_be) = ciphertext.split_at(split_pt);
40 let plaintext_length = byteorder::NetworkEndian::read_u64(field_length_be) as usize;
41 Some(plaintext_length)
42}
43
44pub fn encrypt_into<T: AsRef<[u8]>, R: AsRef<[u8]>, V: AsRef<[u8]>, O: AsMut<[u8]>>(
45 public_key: T,
46 plaintext: R,
47 nonce: V,
48 mut ret: O,
49) -> Result<(), Error> {
50 let public_key = public_key.as_ref();
51 let nonce = nonce.as_ref();
52 let plaintext = plaintext.as_ref();
53 let plaintext_length = plaintext.len();
54 let ret = ret.as_mut();
55
56 if nonce.len() != 32 {
57 return Err(Error::Encrypt(format!(
58 "Nonce must be 32 bytes, got {}",
59 nonce.len()
60 )));
61 }
62
63 if ret.len() < ct_len(plaintext.len()) {
64 return Err(Error::Encrypt(format!(
65 "Bad output buffer len {}",
66 ret.len()
67 )));
68 }
69
70 if plaintext_length != 0 {
71 let chunks = plaintext.chunks(KYBER_BLOCK_SIZE);
72
73 for (chunk, output) in chunks.zip(ret.chunks_mut(KYBER_CIPHERTEXTBYTES)) {
74 if chunk.len() < KYBER_BLOCK_SIZE {
75 let mut buf = [0u8; KYBER_BLOCK_SIZE];
77 let slice = &mut buf[..chunk.len()];
78 slice.copy_from_slice(chunk);
79 indcpa_enc(output, &buf, public_key, nonce);
80 } else {
81 indcpa_enc(output, chunk, public_key, nonce);
82 }
83 }
84 } else {
85 let zeroes = [0u8; KYBER_BLOCK_SIZE];
87 indcpa_enc(ret, &zeroes, public_key, nonce);
88 }
89
90 let length_pos = ret.len() - 8;
92 (&mut ret[length_pos..])
93 .write_u64::<NetworkEndian>(plaintext_length as u64)
94 .unwrap();
95
96 Ok(())
97}
98
99pub fn decrypt<T: AsRef<[u8]>, R: AsRef<[u8]>>(
100 secret_key: T,
101 ciphertext: R,
102) -> Result<Vec<u8>, Error> {
103 let ciphertext = ciphertext.as_ref();
104 let secret_key = secret_key.as_ref();
105 const CIPHERTEXT_BLOCK_LEN: usize = pqc_kyber::KYBER_CIPHERTEXTBYTES;
107
108 if ciphertext.len() < CIPHERTEXT_BLOCK_LEN {
109 return Err(Decrypt("The input ciphertext is too short".to_string()));
110 }
111
112 let plaintext_length = plaintext_len(ciphertext)
113 .ok_or_else(|| Error::Decrypt("Invalid ciphertext input length".to_string()))?;
114 let split_pt = ciphertext.len().saturating_sub(8);
115 let (concatenated_ciphertexts, _) = ciphertext.split_at(split_pt);
116 let buffer_len = div_ceil(plaintext_length as f32, KYBER_BLOCK_SIZE as f32) * KYBER_BLOCK_SIZE;
120 let mut ret = vec![0u8; buffer_len];
121 for (chunk, output) in concatenated_ciphertexts
123 .chunks(CIPHERTEXT_BLOCK_LEN)
124 .zip(ret.chunks_mut(KYBER_BLOCK_SIZE))
125 {
126 indcpa_dec(output, chunk, secret_key);
127 }
128
129 ret.truncate(plaintext_length);
132
133 Ok(ret)
134}
135
136pub fn pke_keypair() -> Result<(PublicKey, SecretKey), KyberError> {
137 let mut rng = rand::rngs::OsRng;
138 let mut public = [0u8; pqc_kyber::KYBER_PUBLICKEYBYTES];
139 let mut secret = [0u8; KYBER_SECRETKEYBYTES];
140 indcpa_keypair(&mut public, &mut secret, None, &mut rng)?;
141 Ok((public, secret))
142}
143
144pub fn kem_keypair() -> Result<Keypair, KyberError> {
145 let mut rng = rand::rngs::OsRng;
146 pqc_kyber::keypair(&mut rng)
147}
148
/// Errors returned by the encryption and decryption functions of this module.
///
/// Each variant carries a human-readable description of what was rejected.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Error {
    /// Encryption inputs were rejected, for example a wrong nonce length or an undersized
    /// output buffer.
    Encrypt(String),
    /// A ciphertext could not be decrypted, for example because it is truncated or its
    /// length field is inconsistent.
    Decrypt(String),
}

impl std::fmt::Display for Error {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Error::Encrypt(msg) => write!(f, "encryption failed: {}", msg),
            Error::Decrypt(msg) => write!(f, "decryption failed: {}", msg),
        }
    }
}

// Lets callers propagate this error with `?` into `Box<dyn std::error::Error>` and similar.
impl std::error::Error for Error {}
154
/// Returns the ceiling of `a / b` as a `usize`, for non-negative floating-point operands.
///
/// Computes `(a / b).ceil()` rather than the integer-style `(a + b - 1) / b`. In `f32`,
/// `a + b - 1.0` is rounded once it exceeds 2^24 and can overshoot by a whole unit. For
/// example, with `a = 16_777_216.0` and `b = 32.0` the old form gives 524_289 instead of
/// 524_288.
///
/// Converting a `usize` to `f32` before calling this already loses precision above 2^24, so
/// integer lengths should use integer arithmetic instead. Negative or NaN quotients become 0
/// through the saturating `as usize` cast.
#[allow(dead_code)] // integer length code need not route through f32; kept for float inputs
fn div_ceil(a: f32, b: f32) -> usize {
    (a / b).ceil() as usize
}
158
#[cfg(test)]
mod tests {
    use crate::{ct_len, decrypt, encrypt, encrypt_into, pke_keypair, plaintext_len, Error, LENGTH_FIELD};
    use pqc_kyber::KYBER_CIPHERTEXTBYTES;

    /// Round-trips every plaintext length from 0 to 999 bytes under one key pair.
    /// Also checks the ciphertext layout: total size from `ct_len` and the length trailer.
    #[test]
    fn test_pke() {
        let (pk, sk) = pke_keypair().unwrap();
        let nonce = (0..32).collect::<Vec<u8>>();
        let mut message = vec![];
        for x in 0..1000 {
            if x != 0 {
                message.push(x as u8);
            }

            let ciphertext = encrypt(pk, &message, &nonce).unwrap();
            assert_eq!(ciphertext.len(), ct_len(message.len()));
            assert_eq!(plaintext_len(&ciphertext), Some(message.len()));
            let plaintext = decrypt(sk, &ciphertext).unwrap();
            assert_eq!(plaintext, message);
        }
    }

    /// Round-trips a multi-block message whose length is not a multiple of the block size.
    #[test]
    fn test_pke_large() {
        let (pk, sk) = pke_keypair().unwrap();
        let nonce = (0..32).collect::<Vec<u8>>();
        let message = (0..10000).map(|r| (r % 256) as u8).collect::<Vec<u8>>();
        let ciphertext = encrypt(pk, &message, nonce).unwrap();
        assert_eq!(ciphertext.len(), ct_len(message.len()));
        let plaintext = decrypt(sk, &ciphertext).unwrap();
        assert_eq!(plaintext, message);
    }

    /// Pins the block accounting at chunk boundaries (empty input still costs one block).
    #[test]
    fn test_ct_len_block_boundaries() {
        assert_eq!(ct_len(0), KYBER_CIPHERTEXTBYTES + LENGTH_FIELD);
        assert_eq!(ct_len(1), KYBER_CIPHERTEXTBYTES + LENGTH_FIELD);
        assert_eq!(ct_len(32), KYBER_CIPHERTEXTBYTES + LENGTH_FIELD);
        assert_eq!(ct_len(33), 2 * KYBER_CIPHERTEXTBYTES + LENGTH_FIELD);
    }

    /// Invalid encryption inputs are reported as `Error::Encrypt` rather than panicking.
    #[test]
    fn test_encrypt_rejects_bad_inputs() {
        let (pk, _sk) = pke_keypair().unwrap();

        // Nonces must be exactly 32 bytes.
        assert!(matches!(encrypt(pk, b"hello", [0u8; 16]), Err(Error::Encrypt(_))));

        // The output buffer must hold at least `ct_len(plaintext.len())` bytes.
        let plaintext = [1u8; 40];
        let mut out = vec![0u8; ct_len(plaintext.len()) - 1];
        assert!(matches!(
            encrypt_into(pk, plaintext, [7u8; 32], &mut out),
            Err(Error::Encrypt(_))
        ));
    }

    /// Inputs shorter than a single ciphertext block are rejected as `Error::Decrypt`.
    #[test]
    fn test_decrypt_rejects_short_input() {
        let (_pk, sk) = pke_keypair().unwrap();
        assert!(matches!(decrypt(sk, [0u8; 16]), Err(Error::Decrypt(_))));
    }
}