use std::sync::Arc;
use std::time::{Duration, Instant};
use async_trait::async_trait;
use smallvec::SmallVec;
use thiserror::Error;
use crate::authority::CapabilityKind;
use crate::identity::TraceId;
use crate::proto::{AtUri, Did, Nsid};
use crate::target::SensitiveRepresentation;
/// Fixed-width, 32-byte identifier type for audit encryption keys.
///
/// Values are plain byte wrappers: construction performs no validation and the
/// bytes are returned exactly as supplied.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct AuditEncryptionKeyId([u8; 32]);

impl AuditEncryptionKeyId {
    /// Wraps `bytes` as an identifier without inspecting them.
    #[must_use]
    pub const fn from_bytes(bytes: [u8; 32]) -> Self {
        Self(bytes)
    }

    /// Borrows the 32 identifier bytes this value was built from.
    #[must_use]
    pub const fn as_bytes(&self) -> &[u8; 32] {
        let Self(raw) = self;
        raw
    }
}
/// Fixed-width, 32-byte identifier type for record encryption keys.
///
/// Values are plain byte wrappers: construction performs no validation and the
/// bytes are returned exactly as supplied.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct RecordEncryptionKeyId([u8; 32]);

impl RecordEncryptionKeyId {
    /// Wraps `bytes` as an identifier without inspecting them.
    #[must_use]
    pub const fn from_bytes(bytes: [u8; 32]) -> Self {
        Self(bytes)
    }

    /// Borrows the 32 identifier bytes this value was built from.
    #[must_use]
    pub const fn as_bytes(&self) -> &[u8; 32] {
        let Self(raw) = self;
        raw
    }
}
/// Algorithms available for audit-payload encryption.
///
/// There are deliberately no variants in this version, so no value of this type
/// can exist; `#[non_exhaustive]` keeps adding variants later a non-breaking change
/// for downstream crates.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum AuditEncryptionAlgorithm {}
/// Algorithms available for record encryption.
///
/// There are deliberately no variants in this version, so no value of this type
/// can exist; `#[non_exhaustive]` keeps adding variants later a non-breaking change
/// for downstream crates.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum RecordEncryptionAlgorithm {}
#[non_exhaustive]
#[derive(Debug, Clone)]
pub struct EncryptionContext {
pub capability: CapabilityKind,
pub trace_id: TraceId,
pub operator_context: SmallVec<[(String, Vec<u8>); 2]>,
}
#[non_exhaustive]
#[derive(Debug, Clone)]
pub struct RecordEncryptionContext {
pub nsid: Nsid,
pub originator: Did,
pub audience_list: Option<AtUri>,
pub trace_id: TraceId,
pub operator_context: SmallVec<[(String, Vec<u8>); 2]>,
}
/// Output of [`RecordEncryptionResolver::encrypt_record`] and input to
/// [`RecordEncryptionResolver::decrypt_record`].
///
/// `algorithm` is a [`RecordEncryptionAlgorithm`], which has no variants in this
/// version, so no value of this struct can currently be constructed.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub struct EncryptedRecord {
    /// Identifier of the key reported for this ciphertext.
    pub key_id: RecordEncryptionKeyId,
    /// Algorithm reported for this ciphertext.
    pub algorithm: RecordEncryptionAlgorithm,
    /// Encrypted payload bytes.
    pub ciphertext: Vec<u8>,
    /// Associated data stored with the ciphertext (presumably AEAD additional
    /// authenticated data — confirm with resolver implementations).
    pub aad: Vec<u8>,
}
/// Failures reported by audit and record encryption resolvers.
///
/// `Display` text is generated by `thiserror` from each variant's `#[error]` string.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Eq, Error)]
pub enum EncryptionError {
    /// No key matching `key_id` was available.
    #[error("encryption key not found: {key_id:?}")]
    KeyNotFound { key_id: AuditEncryptionKeyId },

    /// The named audit algorithm is not supported. `AuditEncryptionAlgorithm` has no
    /// variants in this version, so this variant cannot currently be constructed.
    #[error("encryption algorithm not supported: {0:?}")]
    AlgorithmNotSupported(AuditEncryptionAlgorithm),

    /// The encrypted payload could not be interpreted.
    #[error("encryption payload malformed")]
    Malformed,

    /// The resolver refused the operation; `reason` is a static explanation.
    #[error("encryption access denied: {reason}")]
    AccessDenied { reason: &'static str },

    /// The operation overran its deadline; `elapsed` records how long it ran.
    #[error("encryption deadline exceeded after {elapsed:?}")]
    DeadlineExceeded { elapsed: Duration },

    /// Failure reported by an upstream service, carried as a message.
    #[error("encryption upstream error: {0}")]
    UpstreamError(String),
}
/// Backend that encrypts and decrypts audit payloads.
///
/// The `Send + Sync` bound lets one instance be shared across tasks, e.g. through the
/// `Arc<dyn AuditEncryptionResolver>` handed out by [`EncryptionResolverSet::audit`].
#[async_trait]
pub trait AuditEncryptionResolver: Send + Sync {
    /// Identifier of the key this resolver reports as active.
    fn active_key_id(&self) -> AuditEncryptionKeyId;

    /// Encrypts `plaintext` into a [`SensitiveRepresentation`].
    ///
    /// `deadline` is the caller's time budget for the call.
    ///
    /// # Errors
    ///
    /// Returns an [`EncryptionError`] describing why encryption failed.
    async fn encrypt(
        &self,
        plaintext: &[u8],
        context: &EncryptionContext,
        deadline: Instant,
    ) -> Result<SensitiveRepresentation, EncryptionError>;

    /// Recovers the plaintext bytes from `sensitive`.
    ///
    /// `deadline` is the caller's time budget for the call.
    ///
    /// # Errors
    ///
    /// Returns an [`EncryptionError`] describing why decryption failed.
    async fn decrypt(
        &self,
        sensitive: &SensitiveRepresentation,
        context: &EncryptionContext,
        deadline: Instant,
    ) -> Result<Vec<u8>, EncryptionError>;
}
/// Backend that encrypts records and decrypts them for a given reader.
///
/// The `Send + Sync` bound lets one instance be shared across tasks, e.g. through the
/// `Arc<dyn RecordEncryptionResolver>` handed out by [`EncryptionResolverSet::record`].
#[async_trait]
pub trait RecordEncryptionResolver: Send + Sync {
    /// Encrypts `plaintext` into an [`EncryptedRecord`].
    ///
    /// `deadline` is the caller's time budget for the call.
    ///
    /// # Errors
    ///
    /// Returns an [`EncryptionError`] describing why encryption failed.
    async fn encrypt_record(
        &self,
        plaintext: &[u8],
        context: &RecordEncryptionContext,
        deadline: Instant,
    ) -> Result<EncryptedRecord, EncryptionError>;

    /// Decrypts `encrypted` on behalf of `reader`.
    ///
    /// `deadline` is the caller's time budget for the call.
    ///
    /// # Errors
    ///
    /// Returns an [`EncryptionError`] describing why decryption failed, for example
    /// [`EncryptionError::AccessDenied`] if the implementation rejects the reader.
    async fn decrypt_record(
        &self,
        encrypted: &EncryptedRecord,
        reader: &Did,
        context: &RecordEncryptionContext,
        deadline: Instant,
    ) -> Result<Vec<u8>, EncryptionError>;
}
/// The optional encryption backends a deployment provides.
///
/// Either accessor may return `None`; [`NoEncryption`] is the set where both do.
pub trait EncryptionResolverSet: Send + Sync {
    /// Shared audit-encryption backend, if any.
    fn audit(&self) -> Option<Arc<dyn AuditEncryptionResolver>>;

    /// Shared record-encryption backend, if any.
    fn record(&self) -> Option<Arc<dyn RecordEncryptionResolver>>;
}
/// Zero-sized [`EncryptionResolverSet`] that provides no resolvers at all.
///
/// Both [`EncryptionResolverSet::audit`] and [`EncryptionResolverSet::record`] return
/// `None` for this type.
#[derive(Debug, Default, Clone, Copy)]
pub struct NoEncryption;

impl EncryptionResolverSet for NoEncryption {
    /// Always `None`: there is no audit resolver.
    fn audit(&self) -> Option<Arc<dyn AuditEncryptionResolver>> {
        Option::None
    }

    /// Always `None`: there is no record resolver.
    fn record(&self) -> Option<Arc<dyn RecordEncryptionResolver>> {
        Option::None
    }
}
/// Encrypts `plaintext` through `resolver` when one is supplied.
///
/// Returns `Ok(None)` without calling anything when `resolver` is `None`. Otherwise it
/// returns `Ok(Some(..))` wrapping the resolver's [`SensitiveRepresentation`].
///
/// # Errors
///
/// Propagates, unchanged, any [`EncryptionError`] returned by
/// [`AuditEncryptionResolver::encrypt`].
pub async fn produce_sensitive_representation(
    plaintext: &[u8],
    context: &EncryptionContext,
    deadline: Instant,
    resolver: Option<&dyn AuditEncryptionResolver>,
) -> Result<Option<SensitiveRepresentation>, EncryptionError> {
    if let Some(active) = resolver {
        let sensitive = active.encrypt(plaintext, context, deadline).await?;
        Ok(Some(sensitive))
    } else {
        Ok(None)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn algorithm_enums_have_zero_v1_variants() {
        // These only compile while both enums are uninhabited: an empty `match`
        // is exhaustive solely for types with no variants.
        fn _assert_audit_alg_zero_variants(a: AuditEncryptionAlgorithm) -> ! {
            match a {}
        }
        fn _assert_record_alg_zero_variants(a: RecordEncryptionAlgorithm) -> ! {
            match a {}
        }
    }

    #[test]
    fn key_id_bytes_round_trip() {
        let bytes = [0xCC; 32];
        assert_eq!(AuditEncryptionKeyId::from_bytes(bytes).as_bytes(), &bytes);
        assert_eq!(RecordEncryptionKeyId::from_bytes(bytes).as_bytes(), &bytes);
    }

    #[test]
    fn no_encryption_returns_none_from_both_methods() {
        let set = NoEncryption;
        assert!(set.audit().is_none());
        assert!(set.record().is_none());
    }

    #[test]
    fn no_encryption_is_zero_sized() {
        assert_eq!(std::mem::size_of::<NoEncryption>(), 0);
    }

    /// Pins the `thiserror`-generated `Display` text and derived equality of every
    /// constructible variant. `AlgorithmNotSupported` is skipped because its payload
    /// type has no values in v1.
    #[test]
    fn encryption_error_variants_render_expected_messages() {
        let key_not_found = EncryptionError::KeyNotFound {
            key_id: AuditEncryptionKeyId::from_bytes([0; 32]),
        };
        assert!(key_not_found
            .to_string()
            .starts_with("encryption key not found: AuditEncryptionKeyId("));
        assert_eq!(key_not_found.clone(), key_not_found);

        assert_eq!(
            EncryptionError::Malformed.to_string(),
            "encryption payload malformed"
        );
        assert_eq!(
            EncryptionError::AccessDenied {
                reason: "test-only",
            }
            .to_string(),
            "encryption access denied: test-only"
        );
        assert_eq!(
            EncryptionError::DeadlineExceeded {
                elapsed: Duration::from_secs(1),
            }
            .to_string(),
            "encryption deadline exceeded after 1s"
        );
        assert_eq!(
            EncryptionError::UpstreamError("kms unreachable".into()).to_string(),
            "encryption upstream error: kms unreachable"
        );

        // Derived equality compares payloads, not just the variant.
        assert_ne!(
            EncryptionError::AccessDenied { reason: "a" },
            EncryptionError::AccessDenied { reason: "b" }
        );
    }

    #[tokio::test]
    async fn produce_sensitive_returns_none_when_resolver_absent() {
        let context = EncryptionContext {
            capability: CapabilityKind::ViewPrivate,
            trace_id: TraceId::from_bytes([0; 16]),
            operator_context: SmallVec::new(),
        };
        let deadline = Instant::now() + Duration::from_secs(30);
        let result = produce_sensitive_representation(b"plaintext", &context, deadline, None)
            .await
            .unwrap();
        assert!(result.is_none());
    }

    /// Audit resolver stub whose encrypt/decrypt always fail with `AccessDenied`.
    struct AlwaysAccessDenied;

    #[async_trait]
    impl AuditEncryptionResolver for AlwaysAccessDenied {
        async fn encrypt(
            &self,
            _plaintext: &[u8],
            _context: &EncryptionContext,
            _deadline: Instant,
        ) -> Result<SensitiveRepresentation, EncryptionError> {
            Err(EncryptionError::AccessDenied {
                reason: "mock resolver: always denies",
            })
        }

        async fn decrypt(
            &self,
            _sensitive: &SensitiveRepresentation,
            _context: &EncryptionContext,
            _deadline: Instant,
        ) -> Result<Vec<u8>, EncryptionError> {
            Err(EncryptionError::AccessDenied {
                reason: "mock resolver: always denies",
            })
        }

        fn active_key_id(&self) -> AuditEncryptionKeyId {
            AuditEncryptionKeyId::from_bytes([0xFF; 32])
        }
    }

    #[tokio::test]
    async fn produce_sensitive_propagates_resolver_error() {
        let context = EncryptionContext {
            capability: CapabilityKind::ViewPrivate,
            trace_id: TraceId::from_bytes([0; 16]),
            operator_context: SmallVec::new(),
        };
        let deadline = Instant::now() + Duration::from_secs(30);
        let resolver = AlwaysAccessDenied;
        let err = produce_sensitive_representation(
            b"plaintext",
            &context,
            deadline,
            Some(&resolver as &dyn AuditEncryptionResolver),
        )
        .await
        .unwrap_err();
        assert!(matches!(
            err,
            EncryptionError::AccessDenied {
                reason: "mock resolver: always denies",
            }
        ));
    }

    #[test]
    fn mock_audit_resolver_active_key_id_round_trips() {
        let resolver = AlwaysAccessDenied;
        assert_eq!(resolver.active_key_id().as_bytes(), &[0xFF; 32]);
    }

    /// Record resolver stub whose encrypt/decrypt always fail with `Malformed`.
    struct AlwaysMalformedRecord;

    #[async_trait]
    impl RecordEncryptionResolver for AlwaysMalformedRecord {
        async fn encrypt_record(
            &self,
            _plaintext: &[u8],
            _context: &RecordEncryptionContext,
            _deadline: Instant,
        ) -> Result<EncryptedRecord, EncryptionError> {
            Err(EncryptionError::Malformed)
        }

        async fn decrypt_record(
            &self,
            _encrypted: &EncryptedRecord,
            _reader: &Did,
            _context: &RecordEncryptionContext,
            _deadline: Instant,
        ) -> Result<Vec<u8>, EncryptionError> {
            Err(EncryptionError::Malformed)
        }
    }

    #[tokio::test]
    async fn mock_record_resolver_returns_malformed() {
        let nsid = Nsid::new("tools.kryphocron.feed.postPrivate").unwrap();
        let did = Did::new("did:plc:exampleexampleexample").unwrap();
        let context = RecordEncryptionContext {
            nsid,
            // `did` is not used again, so move it instead of cloning.
            originator: did,
            audience_list: None,
            trace_id: TraceId::from_bytes([0; 16]),
            operator_context: SmallVec::new(),
        };
        let deadline = Instant::now() + Duration::from_secs(30);
        let resolver = AlwaysMalformedRecord;
        let err = resolver
            .encrypt_record(b"plaintext", &context, deadline)
            .await
            .unwrap_err();
        assert!(matches!(err, EncryptionError::Malformed));
    }
}