use std::sync::Arc;
use base64::{Engine as _, engine::general_purpose::STANDARD as B64};
use rand::RngCore;
use serde::{Deserialize, Serialize};
use zeroize::Zeroizing;
use crate::crypto::encryption::Encryptor;
use crate::crypto::keymanager::{KeyManager, WrappedDek};
use crate::error::{AppError, AppResult};
/// Envelope format version stamped on every record produced by `EnvelopeCipher::seal`.
pub const ENC_VERSION: i16 = 1;
/// Payload algorithm label stamped on every record produced by `EnvelopeCipher::seal`.
pub const ENC_ALG: &str = "AES-256-GCM";
/// One envelope-encrypted value: the payload ciphertext plus the wrapped DEK
/// that decrypts it.
///
/// Convert to/from the persisted JSON form with `to_storage_string` /
/// `from_storage_str`.
#[derive(Clone, Debug)]
pub struct EnvelopeRecord {
/// Payload encrypted with the (unwrapped) DEK.
pub ciphertext: Vec<u8>,
/// The DEK as wrapped by the `KeyManager`, with the provider, key id and key
/// version needed to unwrap it.
pub wrapped: WrappedDek,
/// Algorithm label recorded at seal time (`ENC_ALG` for records sealed here).
pub enc_alg: String,
/// Envelope format version recorded at seal time (`ENC_VERSION` for records sealed here).
pub enc_version: i16,
}
/// Serde shape of the persisted storage string.
///
/// Field names are serialised verbatim as JSON keys (`v`, `alg`, `ct`, `dek`).
/// Renaming or reordering a field changes the storage format.
#[derive(Serialize, Deserialize)]
struct StoredEnvelope {
// Mirrors `EnvelopeRecord::enc_version`.
v: i16,
// Mirrors `EnvelopeRecord::enc_alg`.
alg: String,
// Base64 (standard alphabet) of `EnvelopeRecord::ciphertext`.
ct: String,
// The wrapped DEK and its key metadata.
dek: StoredDek,
}
/// Serde shape of the wrapped DEK inside a [`StoredEnvelope`].
///
/// The short field names are the persisted JSON keys.
#[derive(Serialize, Deserialize)]
struct StoredDek {
// `WrappedDek::provider`.
p: String,
// `WrappedDek::key_id`.
kid: String,
// `WrappedDek::key_version`; compared against the current version when rewrapping.
kv: String,
// Base64 (standard alphabet) of `WrappedDek::ciphertext`.
ct: String,
}
impl EnvelopeRecord {
    /// Serialises this record into the compact JSON string used for storage.
    ///
    /// The two binary fields (payload ciphertext and wrapped DEK) are encoded as
    /// standard-alphabet base64. Every other field is copied as-is.
    ///
    /// # Panics
    ///
    /// Does not panic in practice. `StoredEnvelope` holds only strings and an
    /// integer, which `serde_json` always serialises.
    pub fn to_storage_string(&self) -> String {
        let wrapped = &self.wrapped;
        let stored_dek = StoredDek {
            p: wrapped.provider.clone(),
            kid: wrapped.key_id.clone(),
            kv: wrapped.key_version.clone(),
            ct: B64.encode(&wrapped.ciphertext),
        };
        let envelope = StoredEnvelope {
            v: self.enc_version,
            alg: self.enc_alg.clone(),
            ct: B64.encode(&self.ciphertext),
            dek: stored_dek,
        };
        serde_json::to_string(&envelope).expect("StoredEnvelope serialises")
    }

    /// Parses a storage string produced by [`to_storage_string`](Self::to_storage_string).
    ///
    /// Surrounding whitespace in `raw` is ignored.
    ///
    /// # Errors
    ///
    /// Returns `AppError::Encryption` if `raw` is not a JSON envelope. It also
    /// fails if either base64 field (`ct` or `dek.ct`) does not decode. The
    /// payload field is checked first.
    pub fn from_storage_str(raw: &str) -> AppResult<EnvelopeRecord> {
        let StoredEnvelope { v, alg, ct, dek } =
            match serde_json::from_str::<StoredEnvelope>(raw.trim()) {
                Ok(envelope) => envelope,
                Err(e) => {
                    return Err(AppError::Encryption(format!(
                        "not a wallet envelope record: {e}"
                    )))
                }
            };

        let payload = match B64.decode(&ct) {
            Ok(bytes) => bytes,
            Err(e) => return Err(AppError::Encryption(format!("envelope ct base64: {e}"))),
        };
        let wrapped_key = match B64.decode(&dek.ct) {
            Ok(bytes) => bytes,
            Err(e) => return Err(AppError::Encryption(format!("envelope dek base64: {e}"))),
        };

        let wrapped = WrappedDek {
            provider: dek.p,
            key_id: dek.kid,
            key_version: dek.kv,
            ciphertext: wrapped_key,
        };
        Ok(EnvelopeRecord {
            ciphertext: payload,
            wrapped,
            enc_alg: alg,
            enc_version: v,
        })
    }
}
/// Envelope encryption front-end.
///
/// Each sealed value gets its own random DEK. That DEK is wrapped by the shared
/// `KeyManager`. Cloning is cheap: it only clones the `Arc`.
#[derive(Clone)]
pub struct EnvelopeCipher {
// Key manager that wraps/unwraps DEKs and reports the provider and current key version.
km: Arc<dyn KeyManager>,
}
impl EnvelopeCipher {
pub fn new(km: Arc<dyn KeyManager>) -> Self {
Self { km }
}
pub fn provider(&self) -> &str {
self.km.provider()
}
pub async fn seal(&self, plaintext: &[u8]) -> AppResult<EnvelopeRecord> {
let mut dek = Zeroizing::new(vec![0u8; 32]);
rand::thread_rng().fill_bytes(dek.as_mut_slice());
let cipher = Encryptor::from_bytes(&dek)?;
let ciphertext = cipher.encrypt(plaintext)?;
let wrapped = self.km.wrap_dek(&dek).await?;
Ok(EnvelopeRecord {
ciphertext,
wrapped,
enc_alg: ENC_ALG.to_string(),
enc_version: ENC_VERSION,
})
}
pub async fn open(&self, record: &EnvelopeRecord) -> AppResult<Vec<u8>> {
let dek = self.km.unwrap_dek(&record.wrapped).await?;
let cipher = Encryptor::from_bytes(&dek)?;
cipher.decrypt(&record.ciphertext)
}
pub async fn seal_json(&self, value: &serde_json::Value) -> AppResult<EnvelopeRecord> {
let bytes = serde_json::to_vec(value)
.map_err(|e| crate::error::AppError::Encryption(format!("serialize: {e}")))?;
self.seal(&bytes).await
}
pub async fn open_json(&self, record: &EnvelopeRecord) -> AppResult<serde_json::Value> {
let bytes = self.open(record).await?;
serde_json::from_slice(&bytes)
.map_err(|e| crate::error::AppError::Encryption(format!("deserialize: {e}")))
}
pub async fn seal_json_to_storage(&self, value: &serde_json::Value) -> AppResult<String> {
Ok(self.seal_json(value).await?.to_storage_string())
}
pub async fn open_storage_json(&self, raw: &str) -> AppResult<serde_json::Value> {
let record = EnvelopeRecord::from_storage_str(raw)?;
self.open_json(&record).await
}
pub async fn rewrap_storage_string(&self, raw: &str) -> AppResult<RewrapOutcome> {
let record = EnvelopeRecord::from_storage_str(raw)?;
let current = self.km.current_key_version();
if record.wrapped.key_version == current {
return Ok(RewrapOutcome::Skipped {
key_version: current.to_string(),
});
}
let dek = self.km.unwrap_dek(&record.wrapped).await?;
let new_wrapped = self.km.wrap_dek(&dek).await?;
let new_record = EnvelopeRecord {
ciphertext: record.ciphertext,
wrapped: new_wrapped,
enc_alg: record.enc_alg,
enc_version: record.enc_version,
};
Ok(RewrapOutcome::Rewrapped {
old_key_version: record.wrapped.key_version,
new_key_version: current.to_string(),
new_storage_string: new_record.to_storage_string(),
})
}
}
/// Result of `EnvelopeCipher::rewrap_storage_string`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RewrapOutcome {
/// The record's wrapped DEK already carries the key manager's current key
/// version. Nothing was changed and there is nothing to persist.
Skipped { key_version: String },
/// The DEK was unwrapped and wrapped again. The payload ciphertext is unchanged.
Rewrapped {
/// Key version the record's DEK was wrapped under before rewrapping.
old_key_version: String,
/// Key version reported for the newly wrapped DEK.
new_key_version: String,
/// Replacement storage string for the caller to persist.
new_storage_string: String,
},
}
#[cfg(test)]
mod tests {
    //! Round-trip, key-isolation, storage-format parsing and rewrap tests for
    //! the envelope cipher. They run against the in-process `LocalDevKms`.

    use super::*;
    use crate::crypto::keymanager::LocalDevKms;

    /// A cipher backed by a master key freshly generated via
    /// `Encryptor::generate_key_base64`.
    fn cipher() -> EnvelopeCipher {
        let key = Encryptor::generate_key_base64();
        let km = LocalDevKms::from_master_key_base64(&key).unwrap();
        EnvelopeCipher::new(Arc::new(km))
    }

    #[tokio::test]
    async fn seal_open_round_trips() {
        let c = cipher();
        let pt = b"super-secret-password";
        let rec = c.seal(pt).await.unwrap();
        assert_eq!(rec.enc_version, ENC_VERSION);
        assert_eq!(rec.enc_alg, ENC_ALG);
        assert_ne!(rec.ciphertext, pt);
        assert!(!rec.wrapped.ciphertext.is_empty());
        let out = c.open(&rec).await.unwrap();
        assert_eq!(out, pt);
    }

    #[tokio::test]
    async fn two_seals_use_distinct_deks() {
        let c = cipher();
        let a = c.seal(b"x").await.unwrap();
        let b = c.seal(b"x").await.unwrap();
        assert_ne!(a.wrapped.ciphertext, b.wrapped.ciphertext);
        assert_ne!(a.ciphertext, b.ciphertext);
    }

    #[tokio::test]
    async fn seal_open_json() {
        let c = cipher();
        let v = serde_json::json!({"db_host": "pg", "db_password": "p@ss"});
        let rec = c.seal_json(&v).await.unwrap();
        let out = c.open_json(&rec).await.unwrap();
        assert_eq!(out, v);
    }

    #[tokio::test]
    async fn record_from_another_kek_does_not_open() {
        let c1 = cipher();
        let c2 = cipher();
        let rec = c1.seal(b"secret").await.unwrap();
        assert!(c2.open(&rec).await.is_err());
    }

    #[tokio::test]
    async fn tampered_ciphertext_does_not_open() {
        // The payload is AEAD-sealed, so flipping any bit must fail authentication.
        let c = cipher();
        let mut rec = c.seal(b"secret").await.unwrap();
        let last = rec.ciphertext.len() - 1;
        rec.ciphertext[last] ^= 0x01;
        assert!(c.open(&rec).await.is_err());
    }

    #[tokio::test]
    async fn storage_string_round_trips() {
        let c = cipher();
        let v = serde_json::json!({"db_host": "pg", "db_password": "p@ss;'"});
        let stored = c.seal_json_to_storage(&v).await.unwrap();
        assert!(stored.contains("\"v\":1"));
        assert!(stored.contains("\"local\""));
        let out = c.open_storage_json(&stored).await.unwrap();
        assert_eq!(out, v);
    }

    #[tokio::test]
    async fn record_storage_string_preserves_every_field() {
        let c = cipher();
        let rec = c.seal(b"payload").await.unwrap();
        let back = EnvelopeRecord::from_storage_str(&rec.to_storage_string()).unwrap();
        assert_eq!(back.ciphertext, rec.ciphertext);
        assert_eq!(back.wrapped.provider, rec.wrapped.provider);
        assert_eq!(back.wrapped.key_id, rec.wrapped.key_id);
        assert_eq!(back.wrapped.key_version, rec.wrapped.key_version);
        assert_eq!(back.wrapped.ciphertext, rec.wrapped.ciphertext);
        assert_eq!(back.enc_alg, rec.enc_alg);
        assert_eq!(back.enc_version, rec.enc_version);
    }

    #[test]
    fn storage_string_with_bad_base64_is_rejected() {
        let bad_ct = r#"{"v":1,"alg":"AES-256-GCM","ct":"!!!","dek":{"p":"local","kid":"k","kv":"v1","ct":"AAAA"}}"#;
        let err = EnvelopeRecord::from_storage_str(bad_ct).unwrap_err();
        assert!(format!("{err:?}").contains("envelope ct base64"));

        let bad_dek = r#"{"v":1,"alg":"AES-256-GCM","ct":"AAAA","dek":{"p":"local","kid":"k","kv":"v1","ct":"!!!"}}"#;
        let err = EnvelopeRecord::from_storage_str(bad_dek).unwrap_err();
        assert!(format!("{err:?}").contains("envelope dek base64"));
    }

    #[tokio::test]
    async fn non_envelope_storage_value_errors() {
        let c = cipher();
        assert!(c.open_storage_json("AAAAlegacy-base64==").await.is_err());
    }

    #[tokio::test]
    async fn rewrap_skips_records_already_on_current_version() {
        let c = cipher();
        let v = serde_json::json!({"db_password": "p@ss"});
        let stored = c.seal_json_to_storage(&v).await.unwrap();
        match c.rewrap_storage_string(&stored).await.unwrap() {
            RewrapOutcome::Skipped { key_version } => {
                assert_eq!(key_version, "v1");
            }
            other => panic!("expected Skipped, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn rewrap_emits_new_envelope_under_current_version_when_older() {
        // Same master key under two version labels: v2 can unwrap v1's DEKs.
        let key = Encryptor::generate_key_base64();
        let km_v1 = LocalDevKms::from_master_key_base64(&key).unwrap();
        let c_v1 = EnvelopeCipher::new(Arc::new(km_v1));
        let stored_v1 = c_v1
            .seal_json_to_storage(&serde_json::json!({"a": 1}))
            .await
            .unwrap();
        let km_v2 = LocalDevKms::from_master_key_base64_with_version(&key, "v2").unwrap();
        let c_v2 = EnvelopeCipher::new(Arc::new(km_v2));
        match c_v2.rewrap_storage_string(&stored_v1).await.unwrap() {
            RewrapOutcome::Rewrapped {
                old_key_version,
                new_key_version,
                new_storage_string,
            } => {
                assert_eq!(old_key_version, "v1");
                assert_eq!(new_key_version, "v2");
                let opened = c_v2.open_storage_json(&new_storage_string).await.unwrap();
                assert_eq!(opened, serde_json::json!({"a": 1}));
                assert!(new_storage_string.contains("\"kv\":\"v2\""));
            }
            other => panic!("expected Rewrapped, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn rewrap_rejects_non_envelope_storage_value() {
        let c = cipher();
        let err = c
            .rewrap_storage_string("not-a-wallet-envelope")
            .await
            .unwrap_err();
        assert!(format!("{err:?}").contains("not a wallet envelope"));
    }

    #[test]
    fn local_kms_reports_its_key_version() {
        let key = Encryptor::generate_key_base64();
        let km = LocalDevKms::from_master_key_base64(&key).unwrap();
        assert_eq!(km.current_key_version(), "v1");
    }
}