ns_inscriber/wallet/
encrypt.rs1use aes_gcm::{
2 aead::{AeadCore, KeyInit},
3 AeadInPlace, Aes256Gcm, Key, Nonce,
4};
5use coset::{iana, CborSerializable, CoseEncrypt0, CoseEncrypt0Builder, HeaderBuilder};
6use rand_core::OsRng;
7
8use ns_protocol::{ns::Value, state::to_bytes};
9
10use super::{skip_tag, with_tag, ENCRYPT0_TAG};
11
12pub struct Encrypt0 {
13 kid: Option<Value>,
14 cipher: Aes256Gcm,
15}
16
17impl Encrypt0 {
18 pub fn new(key: [u8; 32], kid: Option<Value>) -> Self {
19 let key = Key::<Aes256Gcm>::from_slice(&key);
20 let cipher = Aes256Gcm::new(key);
21 Self { kid, cipher }
22 }
23
24 pub fn encrypt(
25 &self,
26 plaintext: &[u8],
27 aad: &[u8],
28 cid: Option<Value>,
29 ) -> anyhow::Result<Vec<u8>> {
30 let protected = HeaderBuilder::new()
31 .algorithm(iana::Algorithm::A256GCM)
32 .build();
33 let nonce = Aes256Gcm::generate_nonce(&mut OsRng);
34 let mut unprotected = HeaderBuilder::new()
35 .key_id(to_bytes(&self.kid)?)
36 .iv(nonce.to_vec());
37 if let Some(kid) = self.kid.as_ref() {
38 unprotected = unprotected.key_id(to_bytes(kid)?);
39 }
40 if let Some(cid) = cid {
41 unprotected = unprotected.text_value("cid".to_string(), cid);
42 }
43
44 let e0 = CoseEncrypt0Builder::new()
45 .protected(protected)
46 .unprotected(unprotected.build())
47 .create_ciphertext(plaintext, aad, |plain, enc| {
48 let mut buf: Vec<u8> = Vec::with_capacity(plain.len() + 16);
49 buf.extend_from_slice(plain);
50 self.cipher.encrypt_in_place(&nonce, enc, &mut buf).unwrap();
51 buf
52 })
53 .build();
54 Ok(with_tag(
55 &ENCRYPT0_TAG,
56 e0.to_vec().map_err(anyhow::Error::msg)?.as_slice(),
57 ))
58 }
59
60 pub fn decrypt(&self, encrypt0_data: &[u8], aad: &[u8]) -> anyhow::Result<Vec<u8>> {
61 let e0 = CoseEncrypt0::from_slice(skip_tag(&ENCRYPT0_TAG, encrypt0_data))
62 .map_err(anyhow::Error::msg)?;
63 if e0.unprotected.iv.len() != 12 {
64 return Err(anyhow::Error::msg("invalid iv length"));
65 }
66 let nonce = Nonce::from_slice(&e0.unprotected.iv);
67 e0.decrypt(aad, |cipher, enc| {
68 let mut buf: Vec<u8> = Vec::with_capacity(cipher.len() + 16);
69 buf.extend_from_slice(cipher);
70 self.cipher
71 .decrypt_in_place(nonce, enc, &mut buf)
72 .map_err(anyhow::Error::msg)?;
73 Ok(buf)
74 })
75 }
76}
77
#[cfg(test)]
mod tests {
    use super::*;

    use rand_core::RngCore;

    /// External AAD shared by the encrypt and the successful decrypt calls.
    const AAD: &[u8] = b"Name & Service Protocol";

    #[test]
    fn encrypt0_works() {
        let key = {
            let mut bytes = [0u8; 32];
            OsRng.fill_bytes(&mut bytes);
            bytes
        };
        let sealer = Encrypt0::new(key, None);

        let message = b"hello world";
        let sealed = sealer.encrypt(message, AAD, None).unwrap();

        // Round trip with the matching AAD recovers the plaintext.
        let opened = sealer.decrypt(&sealed, AAD).unwrap();
        assert_eq!(opened, message);

        // Dropping the first two bytes (presumably part of the tag prefix)
        // must be rejected.
        assert!(sealer.decrypt(&sealed[2..], AAD).is_err());

        // A different AAD must fail authentication.
        assert!(sealer.decrypt(&sealed, b"NS").is_err());
    }
}