use chacha20poly1305::{
ChaCha20Poly1305, KeyInit as _,
aead::{Aead as _, Payload},
};
use rand::{RngCore, rng};
use std::{collections::HashMap, fmt::Debug};
use super::{Error, Key, KeyId, StrongBox};
/// A [`StrongBox`] backed by a fixed set of keys supplied at construction time.
///
/// One key is used for every encryption; decryption selects its key by looking
/// up the key ID carried in the ciphertext.
#[derive(Clone, Debug)]
pub struct StaticStrongBox {
/// Key used by `encrypt`.
encryption_key: Key,
/// ID of `encryption_key`; it is mixed into the AAD and written into every
/// ciphertext produced by `encrypt`.
encryption_key_id: KeyId,
/// Keys that `decrypt` may use, indexed by their key ID.
decryption_keys: HashMap<KeyId, Key>,
}
impl StaticStrongBox {
#[tracing::instrument(level = "debug", skip(enc_key, dec_keys))]
pub fn new(
enc_key: impl Into<Key>,
dec_keys: impl IntoIterator<Item = impl Into<Key>>,
) -> Self {
let mut key_map: HashMap<KeyId, Key> = HashMap::default();
for key in dec_keys.into_iter() {
let key = key.into();
let key_id = super::key_id(&key);
tracing::debug!(%key_id, "Including decryption key");
key_map.insert(key_id, key);
}
let enc_key = enc_key.into();
let enc_key_id = super::key_id(&enc_key);
tracing::debug!("Encryption key is {enc_key_id}");
Self {
encryption_key_id: enc_key_id,
encryption_key: enc_key,
decryption_keys: key_map,
}
}
pub(crate) fn decrypt_ciphertext(
&self,
ciphertext: &Ciphertext,
ctx: &[u8],
) -> Result<Vec<u8>, Error> {
if let Some(key) = self.decryption_keys.get(&ciphertext.key_id) {
tracing::debug!(key_id=%ciphertext.key_id, "Decrypting");
let mut aad = Vec::<u8>::new();
aad.extend_from_slice(ctx.as_ref());
aad.extend_from_slice(ciphertext.key_id.as_bytes());
aad.extend_from_slice(&ciphertext.nonce);
let cipher = ChaCha20Poly1305::new(key.expose_secret().into());
let payload = Payload {
msg: &ciphertext.ciphertext,
aad: &aad,
};
cipher
.decrypt((&ciphertext.nonce[..]).into(), payload)
.map_err(|_| Error::Decryption)
} else {
tracing::debug!(key_id=%ciphertext.key_id, "Decryption key not found");
Err(Error::Decryption)
}
}
}
impl StrongBox for StaticStrongBox {
    /// Encrypts `plaintext` under the box's encryption key, bound to `ctx`.
    ///
    /// A fresh 12-byte nonce is drawn from `rand::rng()` for each call. The
    /// additional authenticated data is `ctx`, then the encryption key ID
    /// bytes, then the nonce. The result is the serialized [`Ciphertext`].
    ///
    /// # Errors
    ///
    /// Returns [`Error::Encryption`] if the AEAD fails, or whatever error
    /// `Ciphertext::to_bytes` reports while serializing.
    #[tracing::instrument(level = "debug", skip(plaintext))]
    fn encrypt(
        &self,
        plaintext: impl AsRef<[u8]>,
        ctx: impl AsRef<[u8]> + Debug,
    ) -> Result<Vec<u8>, Error> {
        let cipher = ChaCha20Poly1305::new((self.encryption_key.expose_secret()).into());

        let mut nonce = [0u8; 12];
        rng().fill_bytes(&mut nonce);

        // AAD layout: context || key ID || nonce.
        let mut aad = ctx.as_ref().to_vec();
        aad.extend_from_slice(self.encryption_key_id.as_bytes());
        aad.extend_from_slice(&nonce);

        let payload = Payload {
            msg: plaintext.as_ref(),
            aad: &aad,
        };
        let sealed = cipher
            .encrypt((&nonce).into(), payload)
            .map_err(|_| Error::Encryption)?;

        tracing::debug!(key_id=%self.encryption_key_id, "Encrypting");
        Ciphertext::new(self.encryption_key_id, nonce, sealed).to_bytes()
    }

    /// Parses `ciphertext` from its serialized form and decrypts it under
    /// `ctx`.
    ///
    /// # Errors
    ///
    /// Propagates the parse error from `Ciphertext::try_from`, or the error
    /// returned by `decrypt_ciphertext`.
    #[tracing::instrument(level = "debug", skip(ciphertext))]
    fn decrypt(
        &self,
        ciphertext: impl AsRef<[u8]>,
        ctx: impl AsRef<[u8]> + Debug,
    ) -> Result<Vec<u8>, Error> {
        Ciphertext::try_from(ciphertext.as_ref())
            .and_then(|parsed| self.decrypt_ciphertext(&parsed, ctx.as_ref()))
    }
}
/// Fixed three-byte marker written at the start of every serialized
/// [`Ciphertext`]; parsing rejects input that does not begin with it.
const CIPHERTEXT_MAGIC: [u8; 3] = [0xb1, 0xb8, 0xf5];
/// The parsed (in-memory) form of an encrypted message: everything needed,
/// besides the key itself and the caller's context, to attempt decryption.
#[derive(Clone, Debug)]
pub(crate) struct Ciphertext {
/// ID of the key the message was encrypted with; used to look up the
/// decryption key.
pub(crate) key_id: KeyId,
/// The 12-byte nonce passed to the AEAD cipher.
pub(crate) nonce: [u8; 12],
/// The bytes returned by the AEAD `encrypt` call.
pub(crate) ciphertext: Vec<u8>,
}
impl Ciphertext {
pub(crate) fn new(key_id: KeyId, nonce: [u8; 12], ciphertext: Vec<u8>) -> Self {
Self {
key_id,
nonce,
ciphertext,
}
}
pub(crate) fn to_bytes(&self) -> Result<Vec<u8>, Error> {
use ciborium_ll::{Encoder, Header};
let mut v: Vec<u8> = Vec::new();
v.extend_from_slice(&CIPHERTEXT_MAGIC);
let mut enc = Encoder::from(&mut v);
enc.push(Header::Array(Some(3)))
.map_err(|e| Error::ciphertext_encoding("key_id", e))?;
self.key_id.encode(&mut enc)?;
enc.bytes(&self.nonce, None)
.map_err(|e| Error::ciphertext_encoding("nonce", e))?;
enc.bytes(&self.ciphertext, None)
.map_err(|e| Error::ciphertext_encoding("ciphertext", e))?;
tracing::debug!(
nonce = self
.nonce
.iter()
.map(|i| format!("{i:02x}"))
.collect::<Vec<_>>()
.join(""),
ct = self
.ciphertext
.iter()
.map(|i| format!("{i:02x}"))
.collect::<Vec<_>>()
.join(""),
"{}",
v.iter()
.map(|i| format!("{i:02x}"))
.collect::<Vec<_>>()
.join("")
);
Ok(v)
}
}
impl TryFrom<&[u8]> for Ciphertext {
type Error = Error;
fn try_from(b: &[u8]) -> Result<Self, Self::Error> {
use ciborium_ll::{Decoder, Header};
if b.len() < 21 {
return Err(Error::invalid_ciphertext("too short"));
}
if b[0..3] != CIPHERTEXT_MAGIC {
tracing::debug!(magic=?CIPHERTEXT_MAGIC, actual=?b[0..3]);
return Err(Error::invalid_ciphertext("incorrect magic"));
}
let mut dec = Decoder::from(&b[3..]);
let Header::Array(Some(3)) = dec
.pull()
.map_err(|e| Error::ciphertext_decoding("array", e))?
else {
return Err(Error::invalid_ciphertext("expected array"));
};
let key_id = KeyId::decode(&mut dec)?;
let Header::Bytes(len) = dec
.pull()
.map_err(|e| Error::ciphertext_decoding("nonce header", e))?
else {
return Err(Error::invalid_ciphertext("expected nonce"));
};
let mut segments = dec.bytes(len);
let Ok(Some(mut segment)) = segments.pull() else {
return Err(Error::invalid_ciphertext("bad nonce"));
};
let mut buf = [0u8; 1024];
let mut nonce = [0u8; 12];
if let Some(chunk) = segment
.pull(&mut buf[..])
.map_err(|e| Error::ciphertext_decoding("nonce", e))?
{
nonce[..].copy_from_slice(chunk);
} else {
return Err(Error::invalid_ciphertext("short nonce"));
}
let Header::Bytes(len) = dec
.pull()
.map_err(|e| Error::ciphertext_decoding("ciphertext header", e))?
else {
return Err(Error::invalid_ciphertext("expected ciphertext"));
};
let mut segments = dec.bytes(len);
let Ok(Some(mut segment)) = segments.pull() else {
return Err(Error::invalid_ciphertext("bad ciphertext"));
};
let mut ciphertext: Vec<u8> = Vec::new();
while let Some(chunk) = segment
.pull(&mut buf[..])
.map_err(|e| Error::ciphertext_decoding("ciphertext", e))?
{
ciphertext.extend_from_slice(chunk);
}
Ok(Self {
key_id,
nonce,
ciphertext,
})
}
}