use super::{KeyScope, KeyScopeGranularity, KeyStore, PayloadKey, KEY_LEN};
use crate::store::file_classification::KEYSET_FILENAME;
use crate::store::platform::fs::{read as fs_read, write_file_atomically_with_fs, RealFs, StoreFs};
use crate::store::StoreError;
use std::collections::BTreeMap;
use std::path::Path;
use zeroize::{Zeroize, Zeroizing};
/// Six-byte magic prefix identifying a crypto-shred keyset file.
pub(crate) const KEYSET_MAGIC: &[u8; 6] = b"FBATKS";
/// On-disk format version, stored little-endian immediately after the magic.
pub(crate) const KEYSET_VERSION: u16 = 1;
/// Fixed header size: magic, then the `u16` version, then the `u32` CRC-32 of the body.
const HEADER_LEN: usize =
    KEYSET_MAGIC.len() + std::mem::size_of::<u16>() + std::mem::size_of::<u32>();
// One-byte key-scope granularity discriminants persisted in the keyset body.
// These values are part of the file format: changing them changes how existing
// files decode.
const DISC_PER_ENTITY: u8 = 0x01;
const DISC_PER_CATEGORY: u8 = 0x02;
const DISC_PER_TYPE_ID: u8 = 0x03;
const DISC_PER_EVENT: u8 = 0x04;
fn granularity_to_disc(granularity: KeyScopeGranularity) -> u8 {
match granularity {
KeyScopeGranularity::PerEntity => DISC_PER_ENTITY,
KeyScopeGranularity::PerCategory => DISC_PER_CATEGORY,
KeyScopeGranularity::PerTypeId => DISC_PER_TYPE_ID,
KeyScopeGranularity::PerEvent => DISC_PER_EVENT,
}
}
/// Decodes a persisted granularity byte.
///
/// Returns `None` for any byte that is not one of the `DISC_*` constants.
fn granularity_from_disc(disc: u8) -> Option<KeyScopeGranularity> {
    let granularity = match disc {
        DISC_PER_ENTITY => KeyScopeGranularity::PerEntity,
        DISC_PER_CATEGORY => KeyScopeGranularity::PerCategory,
        DISC_PER_TYPE_ID => KeyScopeGranularity::PerTypeId,
        DISC_PER_EVENT => KeyScopeGranularity::PerEvent,
        _ => return None,
    };
    Some(granularity)
}
/// Serialized keyset body: everything after the fixed file header.
///
/// serde derives (de)serialize fields in declaration order, so reordering or
/// renaming fields may change the on-disk format depending on how
/// `crate::encoding` represents structs. Treat this layout as frozen for
/// `KEYSET_VERSION`.
#[derive(serde::Serialize, serde::Deserialize)]
struct KeysetWire {
// Granularity discriminant (one of the `DISC_*` constants); produced by
// `granularity_to_disc` and checked by `granularity_from_disc` on load.
granularity: u8,
// Scope/key pairs; `flush_with_fs` writes one entry per scope in the store.
entries: Vec<KeysetEntryWire>,
}
/// One persisted scope together with its payload key.
///
/// This type holds a plaintext copy of the key. Both `flush_with_fs` and
/// `decode_keyset` zeroize `key` manually once they are done with it.
/// NOTE(review): there is no `Drop`/`ZeroizeOnDrop` impl for this type in this
/// file. Copies dropped on other paths, such as a failed deserialization,
/// are therefore not scrubbed. Consider deriving `ZeroizeOnDrop`.
#[derive(serde::Serialize, serde::Deserialize)]
struct KeysetEntryWire {
// Raw bytes of the `KeyScope` this key belongs to.
scope: Vec<u8>,
// Plaintext key material; must be zeroized after use.
key: [u8; KEY_LEN],
}
/// Builds a [`StoreError::KeysetCorrupt`] carrying a human-readable `reason`.
///
/// Accepts anything convertible into a `String`, so callers may pass either a
/// `format!` result or a plain string literal without an explicit `.to_owned()`.
/// Existing callers passing a `String` are unaffected.
fn corrupt(reason: impl Into<String>) -> StoreError {
    StoreError::KeysetCorrupt {
        reason: reason.into(),
    }
}
impl KeyStore {
pub fn flush(&mut self, dir: &Path) -> Result<(), StoreError> {
self.flush_with_fs(dir, &RealFs)
}
pub(crate) fn flush_with_fs(&mut self, dir: &Path, fs: &dyn StoreFs) -> Result<(), StoreError> {
let mut wire = KeysetWire {
granularity: granularity_to_disc(self.granularity),
entries: Vec::with_capacity(self.keys.len()),
};
for (scope, key) in &self.keys {
wire.entries.push(KeysetEntryWire {
scope: scope.0.to_vec(),
key: *key.0,
});
}
let body = Zeroizing::new(
crate::encoding::to_bytes(&wire)
.map_err(|error| StoreError::ser_msg(&format!("encode keyset: {error}")))?,
);
for entry in &mut wire.entries {
entry.key.zeroize();
}
let crc = crc32fast::hash(&body);
let final_path = dir.join(KEYSET_FILENAME);
write_file_atomically_with_fs(
dir,
&final_path,
"crypto-shred-keyset",
|file| {
use std::io::Write;
file.write_all(KEYSET_MAGIC).map_err(StoreError::Io)?;
file.write_all(&KEYSET_VERSION.to_le_bytes())
.map_err(StoreError::Io)?;
file.write_all(&crc.to_le_bytes()).map_err(StoreError::Io)?;
file.write_all(&body).map_err(StoreError::Io)?;
Ok(())
},
fs,
)?;
self.dirty = false;
tracing::debug!(
target: "batpak::keyscope",
count = self.keys.len(),
"flushed crypto-shred keyset"
);
Ok(())
}
pub fn load(dir: &Path, granularity: KeyScopeGranularity) -> Result<Self, StoreError> {
Self::load_with_fs(dir, &RealFs, granularity)
}
pub(crate) fn load_with_fs(
dir: &Path,
fs: &dyn StoreFs,
granularity: KeyScopeGranularity,
) -> Result<Self, StoreError> {
let path = dir.join(KEYSET_FILENAME);
fs.reject_symlink_leaf(&path, "crypto-shred-keyset")?;
let raw = match fs_read(&path) {
Ok(bytes) => Zeroizing::new(bytes),
Err(error) if error.kind() == std::io::ErrorKind::NotFound => {
return Ok(Self::new_absent(granularity));
}
Err(error) => return Err(StoreError::Io(error)),
};
decode_keyset(&raw, granularity)
}
}
/// Validates the fixed header of a raw keyset file image and returns the body.
///
/// Header layout:
/// - the 6-byte `KEYSET_MAGIC`;
/// - a little-endian `u16` version, which must equal `KEYSET_VERSION`;
/// - a little-endian `u32` CRC-32 of the remaining bytes.
///
/// # Errors
///
/// Returns `StoreError::KeysetCorrupt` in this order of precedence: the input
/// is shorter than the header, the magic is wrong, the version is unsupported,
/// or the CRC does not match the body.
fn validate_header_and_body(raw: &[u8]) -> Result<&[u8], StoreError> {
    if raw.len() < HEADER_LEN {
        return Err(corrupt(format!("file too short: {} bytes", raw.len())));
    }
    let (header, body) = raw.split_at(HEADER_LEN);
    let (magic, fields) = header.split_at(KEYSET_MAGIC.len());
    if magic != &KEYSET_MAGIC[..] {
        return Err(corrupt("wrong magic bytes".to_owned()));
    }
    let version = u16::from_le_bytes([fields[0], fields[1]]);
    if version != KEYSET_VERSION {
        return Err(corrupt(format!(
            "unsupported keyset version {version}; this binary reads and writes version \
             {KEYSET_VERSION}"
        )));
    }
    let stored_crc = u32::from_le_bytes([fields[2], fields[3], fields[4], fields[5]]);
    let computed_crc = crc32fast::hash(body);
    if stored_crc == computed_crc {
        Ok(body)
    } else {
        Err(corrupt(format!(
            "crc mismatch: stored {stored_crc:#010x}, computed {computed_crc:#010x}"
        )))
    }
}
/// Decodes a complete keyset file image into a [`KeyStore`].
///
/// The header and CRC are checked by `validate_header_and_body`. The body is
/// then decoded into a [`KeysetWire`] and turned into a store by `rehydrate`,
/// which compares the persisted granularity against `configured`. The key
/// bytes held in the intermediate wire struct are zeroized whether or not
/// rehydration succeeds.
///
/// NOTE(review): if body decoding itself fails, the early return skips the
/// scrub below. Any partially decoded entries are dropped without zeroization.
fn decode_keyset(raw: &[u8], configured: KeyScopeGranularity) -> Result<KeyStore, StoreError> {
    let mut wire: KeysetWire = crate::encoding::from_bytes(validate_header_and_body(raw)?)
        .map_err(|error| corrupt(format!("decode keyset body: {error}")))?;
    let outcome = rehydrate(&wire, configured);
    wire.entries
        .iter_mut()
        .for_each(|entry| entry.key.zeroize());
    outcome
}
/// Builds a [`KeyStore`] from a decoded wire keyset after validating it.
///
/// The returned store uses `configured` as its granularity. It is not dirty
/// and is not marked absent.
///
/// # Errors
///
/// Returns `StoreError::KeysetCorrupt` in three cases:
/// - the persisted granularity discriminant is unknown;
/// - the persisted granularity differs from `configured`;
/// - the same scope appears in more than one entry.
///
/// `flush_with_fs` serializes from a map, so a duplicate scope can only come
/// from a damaged or hand-crafted file. Previously the later entry silently
/// replaced the earlier key.
fn rehydrate(wire: &KeysetWire, configured: KeyScopeGranularity) -> Result<KeyStore, StoreError> {
    let persisted = granularity_from_disc(wire.granularity).ok_or_else(|| {
        corrupt(format!(
            "unknown key-scope granularity discriminant {}",
            wire.granularity
        ))
    })?;
    if persisted != configured {
        return Err(corrupt(format!(
            "configured key-scope granularity {configured:?} does not match persisted keyset \
             granularity {persisted:?}"
        )));
    }
    let mut keys = BTreeMap::new();
    for (index, entry) in wire.entries.iter().enumerate() {
        let scope = KeyScope(entry.scope.clone().into_boxed_slice());
        let key = PayloadKey(Zeroizing::new(entry.key));
        // Last-writer-wins would silently discard a key, so treat duplicates
        // as corruption. The displaced key, and every key already in `keys`,
        // is held in `Zeroizing` and scrubbed when dropped on this early return.
        if keys.insert(scope, key).is_some() {
            return Err(corrupt(format!(
                "duplicate key scope in keyset entry {index}"
            )));
        }
    }
    Ok(KeyStore {
        keys,
        granularity: configured,
        dirty: false,
        absent_on_load: false,
    })
}
#[cfg(test)]
mod tests;
#[cfg(all(test, feature = "dangerous-test-hooks"))]
mod crash_tests;