use std::fmt;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::{Duration, Instant, SystemTime};
use async_trait::async_trait;
use smallvec::SmallVec;
use thiserror::Error;
use crate::audit::{BoundedString, BoundedStringTooLong};
use crate::authority::CapabilityKind;
use crate::identity::TraceId;
use crate::proto::{AtUri, Did, Nsid, RecordKey};
use crate::target::SensitiveRepresentation;
/// Opaque 32-byte identifier naming an audit encryption key.
///
/// Returned by [`AuditEncryptionResolver::active_key_id`] and reported in
/// [`EncryptionError::KeyNotFound`]. The bytes are neither interpreted nor validated.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct AuditEncryptionKeyId([u8; 32]);

impl AuditEncryptionKeyId {
    /// Wraps `bytes` as a key id; usable in `const` contexts.
    #[must_use]
    pub const fn from_bytes(bytes: [u8; 32]) -> Self {
        Self(bytes)
    }

    /// Borrows the 32 raw bytes of this key id.
    #[must_use]
    pub const fn as_bytes(&self) -> &[u8; 32] {
        let Self(raw) = self;
        raw
    }
}
/// Identifier of an audit encryption algorithm, carried by
/// [`EncryptionError::AlgorithmNotSupported`].
///
/// The enum currently has no variants, so no value of this type can exist yet;
/// `#[non_exhaustive]` lets variants be added later without breaking downstream `match`es.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AuditEncryptionAlgorithm {}
/// Per-call context handed to an [`AuditEncryptionResolver`] for `encrypt` / `decrypt`.
#[non_exhaustive]
#[derive(Debug, Clone)]
pub struct EncryptionContext {
/// Capability kind associated with the operation
/// (NOTE(review): how resolvers use it, e.g. for key selection, is implementation-defined — confirm).
pub capability: CapabilityKind,
/// Trace identifier of the surrounding request (presumably for log correlation).
pub trace_id: TraceId,
/// Opaque operator-supplied key/value pairs; up to two entries are stored inline by
/// `SmallVec` before it spills to the heap.
pub operator_context: SmallVec<[(String, Vec<u8>); 2]>,
}
/// Errors returned by [`AuditEncryptionResolver`] operations.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Eq, Error)]
pub enum EncryptionError {
/// No key with the given id is available to the resolver.
#[error("encryption key not found: {key_id:?}")]
KeyNotFound {
/// Id of the key that could not be found.
key_id: AuditEncryptionKeyId,
},
/// The requested algorithm is not supported. `AuditEncryptionAlgorithm` has no variants
/// yet, so this variant cannot currently be constructed.
#[error("encryption algorithm not supported: {0:?}")]
AlgorithmNotSupported(AuditEncryptionAlgorithm),
/// The payload handed to the resolver is malformed.
#[error("encryption payload malformed")]
Malformed,
/// The resolver refused the operation.
#[error("encryption access denied: {reason}")]
AccessDenied {
/// Static, human-readable explanation of the refusal.
reason: &'static str,
},
/// The operation did not finish before its deadline.
#[error("encryption deadline exceeded after {elapsed:?}")]
DeadlineExceeded {
/// Time spent before giving up (presumably measured from the start of the call — confirm).
elapsed: Duration,
},
/// A failure reported by something the resolver depends on, rendered as text.
#[error("encryption upstream error: {0}")]
UpstreamError(String),
}
/// Pluggable backend that encrypts and decrypts sensitive audit payloads.
///
/// Implementations must be `Send + Sync`. This trait does not enforce the `deadline`
/// argument itself. Implementations are presumably expected to stop with
/// [`EncryptionError::DeadlineExceeded`] once it passes — confirm per implementation.
#[async_trait]
pub trait AuditEncryptionResolver: Send + Sync {
/// Encrypts `plaintext` under `context`, producing a [`SensitiveRepresentation`].
async fn encrypt(
&self,
plaintext: &[u8],
context: &EncryptionContext,
deadline: Instant,
) -> Result<SensitiveRepresentation, EncryptionError>;
/// Recovers plaintext bytes from `sensitive` under `context`.
async fn decrypt(
&self,
sensitive: &SensitiveRepresentation,
context: &EncryptionContext,
deadline: Instant,
) -> Result<Vec<u8>, EncryptionError>;
/// Id of the key this resolver currently treats as active
/// (NOTE(review): presumably the key used by new `encrypt` calls — confirm).
fn active_key_id(&self) -> AuditEncryptionKeyId;
}
/// Encrypts `plaintext` with `resolver`, if one is installed.
///
/// When `resolver` is `None` this returns `Ok(None)` without looking at `plaintext`.
/// Otherwise `plaintext`, `context` and `deadline` are forwarded to
/// [`AuditEncryptionResolver::encrypt`], and the result is wrapped in `Some`.
///
/// # Errors
///
/// Propagates any [`EncryptionError`] returned by the resolver unchanged.
pub async fn produce_sensitive_representation(
    plaintext: &[u8],
    context: &EncryptionContext,
    deadline: Instant,
    resolver: Option<&dyn AuditEncryptionResolver>,
) -> Result<Option<SensitiveRepresentation>, EncryptionError> {
    if let Some(active) = resolver {
        let sensitive = active.encrypt(plaintext, context, deadline).await?;
        Ok(Some(sensitive))
    } else {
        Ok(None)
    }
}
/// Maximum length of a [`CodecId`]. Codec ids are ASCII-only, so this is both a byte
/// count and a character count.
pub const MAX_CODEC_ID_LEN: usize = 128;
/// Maximum length of a [`RotationGenerationMark`], as enforced by its `BoundedString`.
pub const MAX_ROTATION_GENERATION_MARK_LEN: usize = 128;
/// Validated identifier of a content codec, e.g. `"laquna/0.2"`.
///
/// Construct with [`CodecId::new`]. It requires a non-empty string made only of ASCII
/// alphanumerics and `/ . - _`, at most [`MAX_CODEC_ID_LEN`] long.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct CodecId(BoundedString<MAX_CODEC_ID_LEN>);
/// Reasons [`CodecId::new`] rejects an input.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Eq, Error)]
pub enum CodecIdError {
/// The input exceeds the length bound.
#[error("codec id too long: {len} bytes exceeds max {max}")]
TooLong {
/// Length of the rejected input.
len: usize,
/// Maximum allowed length ([`MAX_CODEC_ID_LEN`]).
max: usize,
},
/// The input contains a byte outside ASCII alphanumerics and `/ . - _`.
#[error("codec id contains disallowed character at byte {index}")]
InvalidCharset {
/// Byte offset of the first disallowed byte.
index: usize,
},
/// The input is the empty string.
#[error("codec id is empty")]
Empty,
}
impl CodecId {
pub fn new(s: impl Into<String>) -> Result<Self, CodecIdError> {
let s = s.into();
if s.is_empty() {
return Err(CodecIdError::Empty);
}
for (index, b) in s.bytes().enumerate() {
if !(b.is_ascii_alphanumeric() || matches!(b, b'/' | b'.' | b'-' | b'_')) {
return Err(CodecIdError::InvalidCharset { index });
}
}
let inner = BoundedString::new(s).map_err(|BoundedStringTooLong { len, bound }| {
CodecIdError::TooLong { len, max: bound }
})?;
Ok(CodecId(inner))
}
#[must_use]
pub fn as_str(&self) -> &str {
self.0.as_str()
}
}
impl fmt::Display for CodecId {
    /// Writes the raw codec id.
    ///
    /// Uses [`fmt::Formatter::pad`], so width, fill, alignment and precision flags
    /// (e.g. `{:<16}`) apply just as they would to the underlying `&str`. Plain `{}`
    /// output, including every error message that embeds a codec id, is unchanged.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.pad(self.0.as_str())
    }
}
/// Marker identifying a rotation generation, as reported by a [`RotationOracle`].
///
/// The only validation is the [`MAX_ROTATION_GENERATION_MARK_LEN`] bound. No charset or
/// emptiness check is applied by [`RotationGenerationMark::new`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct RotationGenerationMark(BoundedString<MAX_ROTATION_GENERATION_MARK_LEN>);
impl RotationGenerationMark {
pub fn new(s: impl Into<String>) -> Result<Self, BoundedStringTooLong> {
Ok(RotationGenerationMark(BoundedString::new(s)?))
}
#[must_use]
pub fn as_str(&self) -> &str {
self.0.as_str()
}
}
impl fmt::Display for RotationGenerationMark {
    /// Writes the raw generation mark.
    ///
    /// Uses [`fmt::Formatter::pad`] so width, fill, alignment and precision flags are
    /// honored, as they are for `&str`. Plain `{}` output is unchanged.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.pad(self.0.as_str())
    }
}
/// Errors returned by [`ContentCodec`] operations and by [`resolve_rotation_generation`].
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Eq, Error)]
pub enum CodecError {
/// The content is malformed for `codec`.
#[error("content malformed for codec {codec}")]
Malformed {
/// Codec that rejected the content.
codec: CodecId,
},
/// The codec recorded with the content (`stored`) differs from the installed one.
#[error("unknown or wrong codec: stored {stored} != installed {installed}")]
UnknownOrWrongCodec {
/// Codec id recorded alongside the content.
stored: CodecId,
/// Codec id of the codec currently installed.
installed: CodecId,
},
/// Rotation state for `codec` is unavailable. [`resolve_rotation_generation`] returns
/// this when the oracle's data is stale.
#[error("rotation state unavailable for codec {codec}")]
RotationStateUnavailable {
/// Codec on whose behalf rotation state was requested.
codec: CodecId,
},
/// The operation did not finish before its deadline.
#[error("codec deadline exceeded after {elapsed:?}")]
DeadlineExceeded {
/// Time spent before giving up.
elapsed: Duration,
},
/// A backend the codec depends on could not be reached or used.
#[error("codec backend unavailable: {detail}")]
BackendUnavailable {
/// Free-form description of the failure.
detail: String,
},
}
/// Payload-free discriminant of [`CodecError`], with one variant per error variant.
/// Obtained via [`CodecError::class`].
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CodecErrorClass {
/// Class of [`CodecError::Malformed`].
Malformed,
/// Class of [`CodecError::UnknownOrWrongCodec`].
UnknownOrWrongCodec,
/// Class of [`CodecError::RotationStateUnavailable`].
RotationStateUnavailable,
/// Class of [`CodecError::DeadlineExceeded`].
DeadlineExceeded,
/// Class of [`CodecError::BackendUnavailable`].
BackendUnavailable,
}
impl CodecError {
#[must_use]
pub fn class(&self) -> CodecErrorClass {
match self {
CodecError::Malformed { .. } => CodecErrorClass::Malformed,
CodecError::UnknownOrWrongCodec { .. } => CodecErrorClass::UnknownOrWrongCodec,
CodecError::RotationStateUnavailable { .. } => {
CodecErrorClass::RotationStateUnavailable
}
CodecError::DeadlineExceeded { .. } => CodecErrorClass::DeadlineExceeded,
CodecError::BackendUnavailable { .. } => CodecErrorClass::BackendUnavailable,
}
}
}
/// Per-record context passed to [`ContentCodec::encode`].
#[non_exhaustive]
#[derive(Debug, Clone)]
pub struct EncodeContext {
/// NSID of the record being encoded.
pub nsid: Nsid,
/// Record key of the record being encoded.
pub rkey: RecordKey,
/// DID associated with the record (presumably its author — confirm).
pub originator: Did,
/// AT-URI of the record's audience list, if it has one.
pub audience_list: Option<AtUri>,
/// Rotation generation the caller believes is current, if known
/// (NOTE(review): presumably obtained via `resolve_rotation_generation` — confirm).
pub current_generation_hint: Option<RotationGenerationMark>,
/// Trace identifier of the surrounding request.
pub trace_id: TraceId,
/// Opaque operator-supplied key/value pairs; up to two entries are stored inline.
pub operator_context: SmallVec<[(String, Vec<u8>); 2]>,
}
/// Per-record context passed to [`ContentCodec::decode`].
///
/// Unlike [`EncodeContext`], it carries no generation hint. The generation used at
/// encode time travels with the record in [`EncodedRecord::generation`].
#[non_exhaustive]
#[derive(Debug, Clone)]
pub struct DecodeContext {
/// NSID of the record being decoded.
pub nsid: Nsid,
/// Record key of the record being decoded.
pub rkey: RecordKey,
/// DID associated with the record (presumably its author — confirm).
pub originator: Did,
/// AT-URI of the record's audience list, if it has one.
pub audience_list: Option<AtUri>,
/// Trace identifier of the surrounding request.
pub trace_id: TraceId,
/// Opaque operator-supplied key/value pairs; up to two entries are stored inline.
pub operator_context: SmallVec<[(String, Vec<u8>); 2]>,
}
/// Record content as stored at rest, handed to [`ContentCodec::decode`].
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EncodedRecord {
/// Id of the codec recorded with the content (presumably the codec that produced it).
pub codec: CodecId,
/// Encoded content bytes.
pub content: Vec<u8>,
/// Rotation generation recorded with the content, if any.
pub generation: Option<RotationGenerationMark>,
}
/// Converts record content to and from its at-rest encoding.
///
/// Implementations must be `Send + Sync`. This trait does not enforce the `deadline`
/// arguments; enforcement is implementation-defined.
#[async_trait]
pub trait ContentCodec: Send + Sync {
/// Identifier of this codec (the default hooks install one reporting `"laquna/0.2"`).
fn codec_id(&self) -> CodecId;
/// Encodes `plaintext` for storage under `context`.
async fn encode(
&self,
plaintext: &[u8],
context: &EncodeContext,
deadline: Instant,
) -> Result<Vec<u8>, CodecError>;
/// Decodes a stored record back to plaintext. A mismatch between `encoded.codec` and
/// this codec is presumably reported as `CodecError::UnknownOrWrongCodec` — confirm.
async fn decode(
&self,
encoded: &EncodedRecord,
context: &DecodeContext,
deadline: Instant,
) -> Result<Vec<u8>, CodecError>;
/// Whether this codec presumably needs rotation-generation state (see [`RotationOracle`]).
/// Defaults to `false`.
fn requires_rotation(&self) -> bool {
false
}
}
/// Lookup key passed to [`RotationOracle::current_generation`].
#[non_exhaustive]
#[derive(Debug, Clone)]
pub struct RotationContext {
/// DID whose rotation state is being queried.
pub originator: Did,
/// NSID the query concerns.
pub nsid: Nsid,
/// Audience list the query concerns, if any.
pub audience_list: Option<AtUri>,
}
impl RotationContext {
    /// Fixed, synthetic context for probing an oracle at install time.
    ///
    /// It uses a constant probe DID and NSID and no audience list.
    ///
    /// # Panics
    ///
    /// Only if the hard-coded DID or NSID literal fails to parse, which would be a bug
    /// in this function.
    #[must_use]
    pub fn for_install_probe() -> Self {
        let originator = Did::new("did:plc:rotationinstallprobe")
            .expect("constant probe DID is well-formed");
        let nsid = Nsid::new("tools.kryphocron.rotation.installProbe")
            .expect("constant probe NSID is well-formed");
        Self {
            originator,
            nsid,
            audience_list: None,
        }
    }
}
/// Source of rotation-generation state, with freshness metadata.
///
/// [`resolve_rotation_generation`] treats the data as stale and fails closed in two cases:
/// the time since `last_synced_at` exceeds `data_freshness_bound`, or `last_synced_at` is
/// later than the caller's `now`.
pub trait RotationOracle: Send + Sync {
/// Current generation mark for `ctx`, or `None` if there is none.
fn current_generation(&self, ctx: &RotationContext) -> Option<RotationGenerationMark>;
/// Wall-clock time at which this oracle's data was last synced.
fn last_synced_at(&self) -> SystemTime;
/// Maximum age of the synced data before it counts as stale (an age equal to the bound is still fresh).
fn data_freshness_bound(&self) -> Duration;
}
/// [`RotationOracle`] that never reports a generation.
///
/// It reports the Unix epoch as its sync time and an unbounded ([`Duration::MAX`])
/// freshness window. As a result, [`resolve_rotation_generation`] never treats it as
/// stale for any `now` at or after the epoch.
#[derive(Debug, Default, Clone, Copy)]
pub struct NoRotationOracle;

impl RotationOracle for NoRotationOracle {
    fn current_generation(&self, _: &RotationContext) -> Option<RotationGenerationMark> {
        Option::None
    }

    fn last_synced_at(&self) -> SystemTime {
        // Any past instant works here because the freshness window below is unbounded.
        SystemTime::UNIX_EPOCH
    }

    fn data_freshness_bound(&self) -> Duration {
        Duration::MAX
    }
}
/// Bundle of at-rest storage hooks: audit encryption, content codec, and rotation oracle.
pub trait AtRestHooks: Send + Sync {
/// Audit encryption resolver, if one is installed.
fn audit(&self) -> Option<Arc<dyn AuditEncryptionResolver>>;
/// Content codec for record content; always present.
fn content_codec(&self) -> Arc<dyn ContentCodec>;
/// Rotation oracle, if one is installed.
fn rotation_oracle(&self) -> Option<Arc<dyn RotationOracle>>;
}
/// Stock [`AtRestHooks`] implementation.
///
/// Build it with `DefaultAtRestHooks::for_data_dir` or `DefaultAtRestHooks::builder`.
/// Cloning only bumps `Arc` reference counts. Unlike the trait's signature allows, it
/// always reports a rotation oracle.
#[derive(Clone)]
pub struct DefaultAtRestHooks {
/// Content codec returned by `content_codec`.
codec: Arc<dyn ContentCodec>,
/// Rotation oracle, always wrapped in `Some` by `rotation_oracle`.
rotation_oracle: Arc<dyn RotationOracle>,
/// Optional audit encryption resolver.
audit: Option<Arc<dyn AuditEncryptionResolver>>,
}
impl DefaultAtRestHooks {
    /// Builds hooks with the default wiring:
    /// - `crate::codec::laquna::Codec::default()` as the content codec;
    /// - a `DefaultRotationOracle` rooted at `data_dir`;
    /// - no audit resolver.
    ///
    /// Equivalent to `DefaultAtRestHooks::builder(data_dir).build()`.
    ///
    /// # Errors
    ///
    /// Returns the `RotationOracleConstructionError` from
    /// `DefaultRotationOracle::for_data_dir` when the oracle cannot be created.
    pub fn for_data_dir(
        data_dir: PathBuf,
    ) -> Result<Self, crate::codec::laquna::RotationOracleConstructionError> {
        // Delegate to the builder so the default wiring is defined in exactly one place
        // and the two constructors cannot drift apart.
        Self::builder(data_dir).build()
    }

    /// Starts a [`DefaultAtRestHooksBuilder`] with every slot unset.
    ///
    /// `data_dir` is only used if no rotation oracle is supplied before `build`.
    #[must_use]
    pub fn builder(data_dir: PathBuf) -> DefaultAtRestHooksBuilder {
        DefaultAtRestHooksBuilder {
            data_dir,
            codec: None,
            rotation_oracle: None,
            audit: None,
        }
    }
}
/// Builder for [`DefaultAtRestHooks`]; `build` fills unset slots with defaults.
pub struct DefaultAtRestHooksBuilder {
/// Directory handed to the default rotation oracle when no oracle is supplied.
data_dir: PathBuf,
/// Codec override; `None` selects `crate::codec::laquna::Codec::default()`.
codec: Option<Arc<dyn ContentCodec>>,
/// Oracle override; `None` builds a `DefaultRotationOracle` for `data_dir`.
rotation_oracle: Option<Arc<dyn RotationOracle>>,
/// Optional audit encryption resolver, passed through unchanged.
audit: Option<Arc<dyn AuditEncryptionResolver>>,
}
impl DefaultAtRestHooksBuilder {
    /// Overrides the content codec. When unset, `build` installs
    /// `crate::codec::laquna::Codec::default()`.
    #[must_use]
    pub fn with_codec(self, codec: Arc<dyn ContentCodec>) -> Self {
        Self {
            codec: Some(codec),
            ..self
        }
    }

    /// Overrides the rotation oracle. When one is supplied, `build` does not construct a
    /// `DefaultRotationOracle`, and the builder's `data_dir` goes unused.
    #[must_use]
    pub fn with_rotation_oracle(self, oracle: Arc<dyn RotationOracle>) -> Self {
        Self {
            rotation_oracle: Some(oracle),
            ..self
        }
    }

    /// Installs an audit encryption resolver. When unset, the built hooks report none.
    #[must_use]
    pub fn with_audit(self, audit: Arc<dyn AuditEncryptionResolver>) -> Self {
        Self {
            audit: Some(audit),
            ..self
        }
    }

    /// Assembles the hooks, filling unset slots with defaults.
    ///
    /// # Errors
    ///
    /// Returns `RotationOracleConstructionError` only when no oracle was supplied and
    /// `DefaultRotationOracle::for_data_dir` fails for the builder's `data_dir`.
    pub fn build(
        self,
    ) -> Result<DefaultAtRestHooks, crate::codec::laquna::RotationOracleConstructionError> {
        let Self {
            data_dir,
            codec,
            rotation_oracle,
            audit,
        } = self;
        let codec: Arc<dyn ContentCodec> =
            codec.unwrap_or_else(|| Arc::new(crate::codec::laquna::Codec::default()));
        let rotation_oracle: Arc<dyn RotationOracle> = if let Some(supplied) = rotation_oracle {
            supplied
        } else {
            Arc::new(crate::codec::laquna::DefaultRotationOracle::for_data_dir(
                data_dir,
            )?)
        };
        Ok(DefaultAtRestHooks {
            codec,
            rotation_oracle,
            audit,
        })
    }
}
impl AtRestHooks for DefaultAtRestHooks {
    /// Shares the configured audit resolver, if any.
    fn audit(&self) -> Option<Arc<dyn AuditEncryptionResolver>> {
        self.audit.as_ref().map(Arc::clone)
    }

    /// Shares the configured content codec.
    fn content_codec(&self) -> Arc<dyn ContentCodec> {
        Arc::clone(&self.codec)
    }

    /// Shares the configured oracle; the default hooks always have one.
    fn rotation_oracle(&self) -> Option<Arc<dyn RotationOracle>> {
        Some(Arc::clone(&self.rotation_oracle))
    }
}
/// Asks `oracle` for the current rotation generation of `ctx`, failing closed on stale data.
///
/// - With no oracle, returns `Ok(None)`.
/// - The oracle's data is fresh only when `now` is not earlier than its `last_synced_at`
///   and the elapsed time is at most its `data_freshness_bound`. A sync time in the future
///   (e.g. clock skew) therefore counts as stale.
/// - When fresh, returns the oracle's `current_generation(ctx)`, which may be `None`.
///
/// # Errors
///
/// Returns [`CodecError::RotationStateUnavailable`] carrying a clone of `codec` when the
/// oracle's data is stale.
pub fn resolve_rotation_generation(
    oracle: Option<&dyn RotationOracle>,
    codec: &CodecId,
    ctx: &RotationContext,
    now: SystemTime,
) -> Result<Option<RotationGenerationMark>, CodecError> {
    if let Some(source) = oracle {
        // `duration_since` errs when the sync time is later than `now`; that falls
        // through to "not fresh" together with an over-age sync.
        let fresh = matches!(
            now.duration_since(source.last_synced_at()),
            Ok(age) if age <= source.data_freshness_bound()
        );
        if fresh {
            Ok(source.current_generation(ctx))
        } else {
            Err(CodecError::RotationStateUnavailable {
                codec: codec.clone(),
            })
        }
    } else {
        Ok(None)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Compile-time check: an empty `match` on `AuditEncryptionAlgorithm` only type-checks
    // while the enum has no variants, so adding one forces this test to be revisited.
    #[test]
    fn audit_algorithm_enum_has_zero_v1_variants() {
        fn _assert_audit_alg_zero_variants(a: AuditEncryptionAlgorithm) -> ! {
            match a {}
        }
    }

    #[test]
    fn audit_key_id_bytes_round_trip() {
        let bytes = [0xCC; 32];
        assert_eq!(AuditEncryptionKeyId::from_bytes(bytes).as_bytes(), &bytes);
    }

    /// Context shared by the `produce_sensitive_representation` tests.
    fn encryption_context() -> EncryptionContext {
        EncryptionContext {
            capability: CapabilityKind::ViewPrivate,
            trace_id: TraceId::from_bytes([0; 16]),
            operator_context: SmallVec::new(),
        }
    }

    #[tokio::test]
    async fn produce_sensitive_returns_none_when_resolver_absent() {
        let context = encryption_context();
        let deadline = Instant::now() + Duration::from_secs(30);
        let result = produce_sensitive_representation(b"plaintext", &context, deadline, None)
            .await
            .unwrap();
        assert!(result.is_none());
    }

    /// Resolver stub whose every operation is refused.
    struct AlwaysAccessDenied;

    #[async_trait]
    impl AuditEncryptionResolver for AlwaysAccessDenied {
        async fn encrypt(
            &self,
            _plaintext: &[u8],
            _context: &EncryptionContext,
            _deadline: Instant,
        ) -> Result<SensitiveRepresentation, EncryptionError> {
            Err(EncryptionError::AccessDenied {
                reason: "mock resolver: always denies",
            })
        }

        async fn decrypt(
            &self,
            _sensitive: &SensitiveRepresentation,
            _context: &EncryptionContext,
            _deadline: Instant,
        ) -> Result<Vec<u8>, EncryptionError> {
            Err(EncryptionError::AccessDenied {
                reason: "mock resolver: always denies",
            })
        }

        fn active_key_id(&self) -> AuditEncryptionKeyId {
            AuditEncryptionKeyId::from_bytes([0xFF; 32])
        }
    }

    #[tokio::test]
    async fn produce_sensitive_propagates_resolver_error() {
        let context = encryption_context();
        let deadline = Instant::now() + Duration::from_secs(30);
        let resolver = AlwaysAccessDenied;
        let err = produce_sensitive_representation(
            b"plaintext",
            &context,
            deadline,
            Some(&resolver as &dyn AuditEncryptionResolver),
        )
        .await
        .unwrap_err();
        assert!(matches!(
            err,
            EncryptionError::AccessDenied {
                reason: "mock resolver: always denies",
            }
        ));
    }

    #[test]
    fn codec_id_new_validates() {
        assert_eq!(CodecId::new("laquna/0.2").unwrap().as_str(), "laquna/0.2");
        assert!(matches!(CodecId::new(""), Err(CodecIdError::Empty)));
        assert!(matches!(
            CodecId::new("bad space"),
            Err(CodecIdError::InvalidCharset { index: 3 })
        ));
        // Non-ASCII is rejected at the first byte of the offending character.
        assert!(matches!(
            CodecId::new("laquna/\u{e9}"),
            Err(CodecIdError::InvalidCharset { index: 7 })
        ));
        let over = "a".repeat(MAX_CODEC_ID_LEN + 1);
        assert!(matches!(
            CodecId::new(over),
            Err(CodecIdError::TooLong {
                len,
                max: MAX_CODEC_ID_LEN
            }) if len == MAX_CODEC_ID_LEN + 1
        ));
    }

    #[test]
    fn rotation_generation_mark_round_trips_and_bounds() {
        assert_eq!(RotationGenerationMark::new("000042").unwrap().as_str(), "000042");
        let over = "a".repeat(MAX_ROTATION_GENERATION_MARK_LEN + 1);
        assert!(RotationGenerationMark::new(over).is_err());
    }

    #[test]
    fn codec_error_class_maps_each_variant() {
        let c = CodecId::new("laquna/0.2").unwrap();
        let cases = [
            (
                CodecError::Malformed { codec: c.clone() },
                CodecErrorClass::Malformed,
            ),
            (
                CodecError::UnknownOrWrongCodec {
                    stored: c.clone(),
                    installed: c.clone(),
                },
                CodecErrorClass::UnknownOrWrongCodec,
            ),
            (
                CodecError::RotationStateUnavailable { codec: c },
                CodecErrorClass::RotationStateUnavailable,
            ),
            (
                CodecError::DeadlineExceeded {
                    elapsed: Duration::from_secs(1),
                },
                CodecErrorClass::DeadlineExceeded,
            ),
            (
                CodecError::BackendUnavailable {
                    detail: "backend offline".to_owned(),
                },
                CodecErrorClass::BackendUnavailable,
            ),
        ];
        for (err, expected) in cases {
            assert_eq!(err.class(), expected, "{:?}", err);
        }
    }

    /// Removes a scratch directory on drop, so cleanup happens even if an assertion panics.
    struct RemoveDirOnDrop(PathBuf);

    impl Drop for RemoveDirOnDrop {
        fn drop(&mut self) {
            // Best-effort: the directory may never have been created.
            let _ = std::fs::remove_dir_all(&self.0);
        }
    }

    #[test]
    fn default_at_rest_hooks_installs_real_codec_and_oracle() {
        let dir = std::env::temp_dir().join(format!(
            "kryphocron-hooks-{}-{}",
            std::process::id(),
            "default"
        ));
        // Declared before `hooks`, so it is dropped (and the dir removed) after them.
        let _cleanup = RemoveDirOnDrop(dir.clone());
        let hooks = DefaultAtRestHooks::for_data_dir(dir.clone()).expect("construct");
        assert_eq!(hooks.content_codec().codec_id().as_str(), "laquna/0.2");
        assert!(hooks.rotation_oracle().is_some());
        assert!(hooks.audit().is_none());
    }

    /// Oracle stub with fully controllable generation, sync time, and freshness bound.
    struct StubOracle {
        generation: Option<RotationGenerationMark>,
        synced: SystemTime,
        bound: Duration,
    }

    impl RotationOracle for StubOracle {
        fn current_generation(&self, _ctx: &RotationContext) -> Option<RotationGenerationMark> {
            self.generation.clone()
        }

        fn last_synced_at(&self) -> SystemTime {
            self.synced
        }

        fn data_freshness_bound(&self) -> Duration {
            self.bound
        }
    }

    fn rotation_ctx() -> RotationContext {
        RotationContext {
            originator: Did::new("did:plc:exampleexampleexample").unwrap(),
            nsid: Nsid::new("tools.kryphocron.feed.postPrivate").unwrap(),
            audience_list: None,
        }
    }

    #[test]
    fn resolve_rotation_generation_no_oracle_is_none() {
        let codec = CodecId::new("laquna/0.2").unwrap();
        let got = resolve_rotation_generation(None, &codec, &rotation_ctx(), SystemTime::now())
            .unwrap();
        assert!(got.is_none());
    }

    #[test]
    fn resolve_rotation_generation_fresh_returns_value() {
        let codec = CodecId::new("laquna/0.2").unwrap();
        let now = SystemTime::now();
        let oracle = StubOracle {
            generation: Some(RotationGenerationMark::new("000042").unwrap()),
            synced: now,
            bound: Duration::from_secs(3600),
        };
        let got = resolve_rotation_generation(Some(&oracle), &codec, &rotation_ctx(), now)
            .unwrap()
            .unwrap();
        assert_eq!(got.as_str(), "000042");
    }

    #[test]
    fn resolve_rotation_generation_age_equal_to_bound_is_fresh() {
        let codec = CodecId::new("laquna/0.2").unwrap();
        let now = SystemTime::now();
        let oracle = StubOracle {
            generation: Some(RotationGenerationMark::new("000042").unwrap()),
            synced: now - Duration::from_secs(3600),
            bound: Duration::from_secs(3600),
        };
        let got = resolve_rotation_generation(Some(&oracle), &codec, &rotation_ctx(), now)
            .unwrap()
            .unwrap();
        assert_eq!(got.as_str(), "000042");
    }

    #[test]
    fn resolve_rotation_generation_stale_fails_closed() {
        let codec = CodecId::new("laquna/0.2").unwrap();
        let now = SystemTime::now();
        let oracle = StubOracle {
            generation: Some(RotationGenerationMark::new("000042").unwrap()),
            synced: now - Duration::from_secs(7200),
            bound: Duration::from_secs(3600),
        };
        let err = resolve_rotation_generation(Some(&oracle), &codec, &rotation_ctx(), now)
            .unwrap_err();
        assert_eq!(err.class(), CodecErrorClass::RotationStateUnavailable);
    }

    #[test]
    fn resolve_rotation_generation_future_sync_fails_closed() {
        // A sync time later than `now` (clock skew) is treated as stale, not fresh.
        let codec = CodecId::new("laquna/0.2").unwrap();
        let now = SystemTime::now();
        let oracle = StubOracle {
            generation: Some(RotationGenerationMark::new("000042").unwrap()),
            synced: now + Duration::from_secs(60),
            bound: Duration::from_secs(3600),
        };
        let err = resolve_rotation_generation(Some(&oracle), &codec, &rotation_ctx(), now)
            .unwrap_err();
        assert_eq!(
            err,
            CodecError::RotationStateUnavailable {
                codec: codec.clone()
            }
        );
    }

    #[test]
    fn no_rotation_oracle_never_stale_and_none() {
        let codec = CodecId::new("laquna/0.2").unwrap();
        let oracle = NoRotationOracle;
        let got =
            resolve_rotation_generation(Some(&oracle), &codec, &rotation_ctx(), SystemTime::now())
                .unwrap();
        assert!(got.is_none());
    }
}