use zeroize::Zeroize;
use aes_gcm::{Aes256Gcm, KeyInit, aead::Aead};
use generic_array::GenericArray;
use rand_core::{OsRng, RngCore};
use std::time::{Instant};
/// Length in bytes of the authentication tag that AES-GCM appends to every ciphertext.
const RAM_PART_TAG_LEN: usize = 16;

/// Length of [`ScatteredParts::ram_part_encrypted`]: the 32-byte masked key part
/// followed by its AES-GCM authentication tag.
pub(crate) const RAM_PART_ENCRYPTED_LEN: usize = 32 + RAM_PART_TAG_LEN;

/// The pieces a 32-byte master key is split into by [`MemoryScatterer::scatter`].
///
/// All fields are wiped by [`Zeroize`], which also runs on drop. `Clone` yields an
/// independent copy of the key material that is wiped only when that copy is dropped.
#[derive(Clone)]
pub struct ScatteredParts {
    /// Random pad XOR-ed into master-key bytes `0..8`.
    pub(crate) l1_part: [u8; 8],
    /// Random pad XOR-ed into master-key bytes `0..16`.
    pub(crate) l2_part: [u8; 16],
    /// AES-256-GCM ciphertext of the masked key (32 bytes) followed by the
    /// 16-byte authentication tag, exactly as returned by `Aead::encrypt`.
    pub(crate) ram_part_encrypted: [u8; RAM_PART_ENCRYPTED_LEN],
    /// 96-bit nonce used to produce `ram_part_encrypted`.
    pub(crate) ram_part_iv: [u8; 12],
    /// Random pad XOR-ed into master-key bytes `0..16`.
    pub(crate) register_seed: [u8; 16],
}

impl ScatteredParts {
    /// Creates an all-zero set of parts that holds no key material.
    pub fn new() -> Self {
        Self {
            l1_part: [0; 8],
            l2_part: [0; 16],
            ram_part_encrypted: [0; RAM_PART_ENCRYPTED_LEN],
            ram_part_iv: [0; 12],
            register_seed: [0; 16],
        }
    }
}

impl Default for ScatteredParts {
    /// Same as [`ScatteredParts::new`]: all fields zeroed.
    fn default() -> Self {
        Self::new()
    }
}

impl Zeroize for ScatteredParts {
    /// Overwrites every field with zeros.
    fn zeroize(&mut self) {
        self.l1_part.zeroize();
        self.l2_part.zeroize();
        self.ram_part_encrypted.zeroize();
        self.ram_part_iv.zeroize();
        self.register_seed.zeroize();
    }
}

impl Drop for ScatteredParts {
    fn drop(&mut self) {
        self.zeroize();
    }
}

/// Splits master keys into [`ScatteredParts`], encrypting one part under a
/// per-instance random AES-256 key.
pub struct MemoryScatterer {
    /// Per-instance AES-256-GCM key, filled from `OsRng`; wiped on drop.
    encryption_key: [u8; 32],
}

impl Drop for MemoryScatterer {
    /// Wipes the per-instance encryption key so it does not linger in freed memory.
    fn drop(&mut self) {
        self.encryption_key.zeroize();
    }
}

impl MemoryScatterer {
    /// Creates a scatterer with a fresh random 256-bit encryption key from `OsRng`.
    pub fn new() -> Self {
        let mut encryption_key = [0u8; 32];
        OsRng.fill_bytes(&mut encryption_key);
        Self { encryption_key }
    }

    /// Splits `master_key` into freshly randomised [`ScatteredParts`].
    ///
    /// Masking layout (gather code must invert exactly this):
    /// - bytes `0..8` are XOR-ed with `l1_part`, `l2_part` and `register_seed`;
    /// - bytes `8..16` are XOR-ed with `l2_part` and `register_seed`;
    /// - bytes `16..32` are not masked and are protected only by the encryption.
    ///
    /// The masked 32 bytes are then encrypted with AES-256-GCM under this
    /// scatterer's key and a random 96-bit nonce. The full ciphertext plus tag is
    /// stored in `ram_part_encrypted`.
    ///
    /// # Panics
    ///
    /// Panics only if an internal invariant is broken (key, nonce or ciphertext
    /// length). `OsRng::fill_bytes` may also panic if the OS RNG is unavailable.
    pub fn scatter(&self, master_key: &[u8; 32]) -> ScatteredParts {
        let start = Instant::now();

        let mut rng = OsRng;
        let mut l1_part = [0u8; 8];
        let mut l2_part = [0u8; 16];
        let mut register_seed = [0u8; 16];
        let mut ram_part_iv = [0u8; 12];
        rng.fill_bytes(&mut l1_part);
        rng.fill_bytes(&mut l2_part);
        rng.fill_bytes(&mut register_seed);
        rng.fill_bytes(&mut ram_part_iv);

        // Mask the key in place. `zip` stops at the shorter side, so each pad
        // only touches the leading bytes it covers (8 and 16 respectively).
        let mut ram_part = *master_key;
        for (byte, pad) in ram_part.iter_mut().zip(&l1_part) {
            *byte ^= pad;
        }
        for ((byte, l2), seed) in ram_part.iter_mut().zip(&l2_part).zip(&register_seed) {
            *byte ^= l2 ^ seed;
        }

        let cipher = Aes256Gcm::new_from_slice(&self.encryption_key)
            .expect("encryption_key is always 32 bytes, a valid AES-256 key");
        let nonce = GenericArray::from_slice(&ram_part_iv);
        let ciphertext = cipher.encrypt(nonce, &ram_part[..]);
        // Wipe the masked key before anything below can panic and unwind past it.
        ram_part.zeroize();

        // `Aead::encrypt` returns ciphertext || 16-byte tag. The whole output must
        // be kept; a truncated or zeroed value could never be decrypted again.
        let ram_part_encrypted: [u8; RAM_PART_ENCRYPTED_LEN] = ciphertext
            .expect("AES-GCM encryption of a 32-byte plaintext cannot fail")
            .try_into()
            .expect("AES-GCM output is the plaintext length plus a 16-byte tag");

        #[cfg(feature = "metrics")]
        metrics::histogram!(
            "phantom.scatterer.scatter_time",
            start.elapsed().as_nanos() as f64
        );
        // Without the `metrics` feature the timer has no consumer; use it
        // explicitly to avoid an unused-variable warning.
        #[cfg(not(feature = "metrics"))]
        let _ = start;

        ScatteredParts {
            l1_part,
            l2_part,
            ram_part_encrypted,
            ram_part_iv,
            register_seed,
        }
    }

    /// Wipes `parts` and replaces them with a fresh scattering of `master_key`
    /// (new random pads and nonce).
    pub fn rescatter(&self, parts: &mut ScatteredParts, master_key: &[u8; 32]) {
        parts.zeroize();
        // The assignment drops the old (already wiped) value; its `Drop` zeroizes
        // again, which is harmless.
        *parts = self.scatter(master_key);
    }
}
impl Default for MemoryScatterer {
fn default() -> Self {
Self::new()
}
}