ns_inscriber/wallet/encrypt.rs

1use aes_gcm::{
2    aead::{AeadCore, KeyInit},
3    AeadInPlace, Aes256Gcm, Key, Nonce,
4};
5use coset::{iana, CborSerializable, CoseEncrypt0, CoseEncrypt0Builder, HeaderBuilder};
6use rand_core::OsRng;
7
8use ns_protocol::{ns::Value, state::to_bytes};
9
10use super::{skip_tag, with_tag, ENCRYPT0_TAG};
11
12pub struct Encrypt0 {
13    kid: Option<Value>,
14    cipher: Aes256Gcm,
15}
16
17impl Encrypt0 {
18    pub fn new(key: [u8; 32], kid: Option<Value>) -> Self {
19        let key = Key::<Aes256Gcm>::from_slice(&key);
20        let cipher = Aes256Gcm::new(key);
21        Self { kid, cipher }
22    }
23
24    pub fn encrypt(
25        &self,
26        plaintext: &[u8],
27        aad: &[u8],
28        cid: Option<Value>,
29    ) -> anyhow::Result<Vec<u8>> {
30        let protected = HeaderBuilder::new()
31            .algorithm(iana::Algorithm::A256GCM)
32            .build();
33        let nonce = Aes256Gcm::generate_nonce(&mut OsRng);
34        let mut unprotected = HeaderBuilder::new()
35            .key_id(to_bytes(&self.kid)?)
36            .iv(nonce.to_vec());
37        if let Some(kid) = self.kid.as_ref() {
38            unprotected = unprotected.key_id(to_bytes(kid)?);
39        }
40        if let Some(cid) = cid {
41            unprotected = unprotected.text_value("cid".to_string(), cid);
42        }
43
44        let e0 = CoseEncrypt0Builder::new()
45            .protected(protected)
46            .unprotected(unprotected.build())
47            .create_ciphertext(plaintext, aad, |plain, enc| {
48                let mut buf: Vec<u8> = Vec::with_capacity(plain.len() + 16);
49                buf.extend_from_slice(plain);
50                self.cipher.encrypt_in_place(&nonce, enc, &mut buf).unwrap();
51                buf
52            })
53            .build();
54        Ok(with_tag(
55            &ENCRYPT0_TAG,
56            e0.to_vec().map_err(anyhow::Error::msg)?.as_slice(),
57        ))
58    }
59
60    pub fn decrypt(&self, encrypt0_data: &[u8], aad: &[u8]) -> anyhow::Result<Vec<u8>> {
61        let e0 = CoseEncrypt0::from_slice(skip_tag(&ENCRYPT0_TAG, encrypt0_data))
62            .map_err(anyhow::Error::msg)?;
63        if e0.unprotected.iv.len() != 12 {
64            return Err(anyhow::Error::msg("invalid iv length"));
65        }
66        let nonce = Nonce::from_slice(&e0.unprotected.iv);
67        e0.decrypt(aad, |cipher, enc| {
68            let mut buf: Vec<u8> = Vec::with_capacity(cipher.len() + 16);
69            buf.extend_from_slice(cipher);
70            self.cipher
71                .decrypt_in_place(nonce, enc, &mut buf)
72                .map_err(anyhow::Error::msg)?;
73            Ok(buf)
74        })
75    }
76}
77
#[cfg(test)]
mod tests {
    use super::*;

    use rand_core::RngCore;

    /// External AAD used for the successful round trip.
    const AAD: &[u8] = b"Name & Service Protocol";

    /// Round-trips a message, then checks that decryption fails in two cases:
    /// when the leading bytes are stripped, and when a different AAD is supplied.
    #[test]
    fn encrypt0_works() {
        let mut secret = [0u8; 32];
        OsRng.fill_bytes(&mut secret);
        let sealer = Encrypt0::new(secret, None);

        let message = b"hello world";
        let sealed = sealer.encrypt(message, AAD, None).unwrap();

        let opened = sealer.decrypt(&sealed, AAD).unwrap();
        assert_eq!(opened, message);

        // Stripping the first two bytes must make decryption fail.
        let truncated = &sealed[2..];
        assert!(sealer.decrypt(truncated, AAD).is_err());

        // A different AAD must fail authentication.
        assert!(sealer.decrypt(&sealed, b"NS").is_err());
    }
}