kyber_pke/
lib.rs

1use crate::Error::Decrypt;
2use byteorder::{ByteOrder, NetworkEndian, WriteBytesExt};
3use pqc_kyber::{
4    Keypair, KyberError, PublicKey, SecretKey, KYBER_CIPHERTEXTBYTES, KYBER_SECRETKEYBYTES,
5};
6
7use pqc_kyber::indcpa::{indcpa_dec, indcpa_enc, indcpa_keypair};
8pub use pqc_kyber::{decapsulate, encapsulate};
9
/// Number of plaintext bytes handed to each `indcpa_enc` call; the final block of a
/// message is zero-padded up to this size.
const KYBER_BLOCK_SIZE: usize = 32;
/// Size of the trailing big-endian `u64` that records the original plaintext length.
const LENGTH_FIELD: usize = std::mem::size_of::<u64>();
12
13pub fn encrypt<T: AsRef<[u8]>, R: AsRef<[u8]>, V: AsRef<[u8]>>(
14    public_key: T,
15    plaintext: R,
16    nonce: V,
17) -> Result<Vec<u8>, Error> {
18    let full_ciphertext_len = ct_len(plaintext.as_ref().len());
19    let mut out = vec![0u8; full_ciphertext_len];
20    encrypt_into(public_key, plaintext, nonce, out.as_mut_slice())?;
21    Ok(out)
22}
23
24/// returns the ciphertext expected length given an input plaintext length
25pub fn ct_len(plaintext_len: usize) -> usize {
26    std::cmp::max(
27        KYBER_CIPHERTEXTBYTES,
28        div_ceil(plaintext_len as f32, KYBER_BLOCK_SIZE as f32) * KYBER_CIPHERTEXTBYTES,
29    ) + LENGTH_FIELD
30}
31
32pub fn plaintext_len(ciphertext: &[u8]) -> Option<usize> {
33    // The final 8 bytes are for the original length of the plaintext
34    let split_pt = ciphertext.len().saturating_sub(8);
35    if split_pt > ciphertext.len() || split_pt == 0 {
36        return None;
37    }
38
39    let (_, field_length_be) = ciphertext.split_at(split_pt);
40    let plaintext_length = byteorder::NetworkEndian::read_u64(field_length_be) as usize;
41    Some(plaintext_length)
42}
43
44pub fn encrypt_into<T: AsRef<[u8]>, R: AsRef<[u8]>, V: AsRef<[u8]>, O: AsMut<[u8]>>(
45    public_key: T,
46    plaintext: R,
47    nonce: V,
48    mut ret: O,
49) -> Result<(), Error> {
50    let public_key = public_key.as_ref();
51    let nonce = nonce.as_ref();
52    let plaintext = plaintext.as_ref();
53    let plaintext_length = plaintext.len();
54    let ret = ret.as_mut();
55
56    if nonce.len() != 32 {
57        return Err(Error::Encrypt(format!(
58            "Nonce must be 32 bytes, got {}",
59            nonce.len()
60        )));
61    }
62
63    if ret.len() < ct_len(plaintext.len()) {
64        return Err(Error::Encrypt(format!(
65            "Bad output buffer len {}",
66            ret.len()
67        )));
68    }
69
70    if plaintext_length != 0 {
71        let chunks = plaintext.chunks(KYBER_BLOCK_SIZE);
72
73        for (chunk, output) in chunks.zip(ret.chunks_mut(KYBER_CIPHERTEXTBYTES)) {
74            if chunk.len() < KYBER_BLOCK_SIZE {
75                // fit the buffer to KYBER_BLOCK_SIZE
76                let mut buf = [0u8; KYBER_BLOCK_SIZE];
77                let slice = &mut buf[..chunk.len()];
78                slice.copy_from_slice(chunk);
79                indcpa_enc(output, &buf, public_key, nonce);
80            } else {
81                indcpa_enc(output, chunk, public_key, nonce);
82            }
83        }
84    } else {
85        // fill with zeroes
86        let zeroes = [0u8; KYBER_BLOCK_SIZE];
87        indcpa_enc(ret, &zeroes, public_key, nonce);
88    }
89
90    // append the plaintext len
91    let length_pos = ret.len() - 8;
92    (&mut ret[length_pos..])
93        .write_u64::<NetworkEndian>(plaintext_length as u64)
94        .unwrap();
95
96    Ok(())
97}
98
99pub fn decrypt<T: AsRef<[u8]>, R: AsRef<[u8]>>(
100    secret_key: T,
101    ciphertext: R,
102) -> Result<Vec<u8>, Error> {
103    let ciphertext = ciphertext.as_ref();
104    let secret_key = secret_key.as_ref();
105    // calculate the length of each block
106    const CIPHERTEXT_BLOCK_LEN: usize = pqc_kyber::KYBER_CIPHERTEXTBYTES;
107
108    if ciphertext.len() < CIPHERTEXT_BLOCK_LEN {
109        return Err(Decrypt("The input ciphertext is too short".to_string()));
110    }
111
112    let plaintext_length = plaintext_len(ciphertext)
113        .ok_or_else(|| Error::Decrypt("Invalid ciphertext input length".to_string()))?;
114    let split_pt = ciphertext.len().saturating_sub(8);
115    let (concatenated_ciphertexts, _) = ciphertext.split_at(split_pt);
116    // pt len < 32: size must be 32
117    // pt len = 32: size must be 32
118    // pt len > 32: size must be div.ceil(pt.len()/32)*32
119    let buffer_len = div_ceil(plaintext_length as f32, KYBER_BLOCK_SIZE as f32) * KYBER_BLOCK_SIZE;
120    let mut ret = vec![0u8; buffer_len];
121    // split the concatenated ciphertexts
122    for (chunk, output) in concatenated_ciphertexts
123        .chunks(CIPHERTEXT_BLOCK_LEN)
124        .zip(ret.chunks_mut(KYBER_BLOCK_SIZE))
125    {
126        indcpa_dec(output, chunk, secret_key);
127    }
128
129    // finally, truncate the vec, as the final block is 32 in length, and may be more
130    // than what the plaintext requires
131    ret.truncate(plaintext_length);
132
133    Ok(ret)
134}
135
136pub fn pke_keypair() -> Result<(PublicKey, SecretKey), KyberError> {
137    let mut rng = rand::rngs::OsRng;
138    let mut public = [0u8; pqc_kyber::KYBER_PUBLICKEYBYTES];
139    let mut secret = [0u8; KYBER_SECRETKEYBYTES];
140    indcpa_keypair(&mut public, &mut secret, None, &mut rng)?;
141    Ok((public, secret))
142}
143
144pub fn kem_keypair() -> Result<Keypair, KyberError> {
145    let mut rng = rand::rngs::OsRng;
146    pqc_kyber::keypair(&mut rng)
147}
148
/// Errors returned by this crate's encryption and decryption functions.
///
/// Each variant carries a human-readable description of what went wrong.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Error {
    /// Encryption failed, e.g. a nonce of the wrong length or an undersized output buffer.
    Encrypt(String),
    /// Decryption failed, e.g. a truncated or malformed ciphertext.
    Decrypt(String),
}

impl std::fmt::Display for Error {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Error::Encrypt(msg) => write!(f, "encryption failed: {}", msg),
            Error::Decrypt(msg) => write!(f, "decryption failed: {}", msg),
        }
    }
}

// Lets callers use `?` into `Box<dyn std::error::Error>` / `anyhow`-style error types.
impl std::error::Error for Error {}
154
/// Ceiling of `a / b` for non-negative floats, returned as `usize`.
///
/// Divides first and rounds up with `f32::ceil`. The former `(a + b - 1.0) / b` form
/// rounded the intermediate sum and over-counted for large exact inputs
/// (e.g. `a = 2^25`, `b = 32`). Integers above 2^24 are not exactly representable
/// in `f32`, so integer arithmetic is preferable for buffer lengths.
fn div_ceil(a: f32, b: f32) -> usize {
    (a / b).ceil() as usize
}
158
#[cfg(test)]
mod tests {
    use crate::{ct_len, decrypt, encrypt, encrypt_into, pke_keypair, plaintext_len};
    use pqc_kyber::KYBER_CIPHERTEXTBYTES;

    /// Deterministic 32-byte nonce shared by the tests.
    fn test_nonce() -> Vec<u8> {
        (0..32).collect()
    }

    #[test]
    fn test_pke() {
        let (pk, sk) = pke_keypair().unwrap();
        let nonce = test_nonce();
        let mut message = vec![];
        for x in 0..1000 {
            // x == 0 leaves the message empty to exercise zero-length encryption.
            if x != 0 {
                message.push(x as u8);
            }

            let ciphertext = encrypt(pk, &message, &nonce).unwrap();
            assert_eq!(ciphertext.len(), ct_len(message.len()));
            assert_eq!(plaintext_len(&ciphertext), Some(message.len()));
            assert_ne!(ciphertext, message);
            let plaintext = decrypt(sk, &ciphertext).unwrap();
            assert_eq!(plaintext, message);
        }
    }

    #[test]
    fn test_pke_large() {
        let (pk, sk) = pke_keypair().unwrap();
        let nonce = test_nonce();
        let message = (0..10000).map(|r| (r % 256) as u8).collect::<Vec<u8>>();
        let ciphertext = encrypt(pk, &message, nonce).unwrap();
        assert_ne!(ciphertext, message);
        let plaintext = decrypt(sk, &ciphertext).unwrap();
        assert_eq!(plaintext, message);
    }

    #[test]
    fn test_ct_len_block_boundaries() {
        assert_eq!(ct_len(0), KYBER_CIPHERTEXTBYTES + 8);
        assert_eq!(ct_len(1), KYBER_CIPHERTEXTBYTES + 8);
        assert_eq!(ct_len(32), KYBER_CIPHERTEXTBYTES + 8);
        assert_eq!(ct_len(33), 2 * KYBER_CIPHERTEXTBYTES + 8);
    }

    #[test]
    fn test_plaintext_len_short_and_minimal_inputs() {
        assert_eq!(plaintext_len(&[]), None);
        assert_eq!(plaintext_len(&[0u8; 8]), None);
        assert_eq!(plaintext_len(&[0, 0, 0, 0, 0, 0, 0, 0, 5]), Some(5));
    }

    #[test]
    fn test_invalid_inputs_are_rejected() {
        let (pk, sk) = pke_keypair().unwrap();
        // Nonce must be exactly 32 bytes.
        assert!(encrypt(pk, b"hello", [0u8; 16]).is_err());
        // Output buffer one byte short of the required ciphertext length.
        let mut too_small = vec![0u8; ct_len(5) - 1];
        assert!(encrypt_into(pk, b"hello", test_nonce(), too_small.as_mut_slice()).is_err());
        // Far shorter than a single ciphertext block.
        assert!(decrypt(sk, [0u8; 10]).is_err());
    }
}