use super::context::ProtectionContext;
use super::processor::{OutputProcessor, ProcessedEvent};
use crate::streaming::StreamEvent;
use aes_gcm::{
aead::{Aead, KeyInit},
Aes256Gcm, Nonce,
};
use async_trait::async_trait;
use rand::RngCore;
/// Raw 256-bit key material for AES-256-GCM.
pub type EncryptionKey = [u8; 32];

/// Development-only output processor that AES-256-GCM-encrypts text-bearing
/// stream events whose destination requires encryption.
pub struct EncryptionProcessor {
    /// Key used for every encryption performed by this processor.
    key: EncryptionKey,
    /// When `false`, every event passes through untouched.
    enabled: bool,
}
impl EncryptionProcessor {
    /// Text written into an event field whose plaintext has been moved into the
    /// encrypted side-channel payload.
    const PLACEHOLDER: &'static str = "[ENCRYPTED]";

    /// Creates a disabled processor with an all-zero key.
    ///
    /// The all-zero key is only a placeholder. Call [`Self::with_key`] before
    /// enabling. Otherwise anything "encrypted" is readable by anyone who knows
    /// this default.
    #[must_use]
    pub fn new() -> Self {
        Self {
            enabled: false,
            key: [0u8; 32],
        }
    }

    /// Replaces the AES-256-GCM key used for all subsequent encryptions.
    #[must_use]
    pub fn with_key(mut self, key: EncryptionKey) -> Self {
        self.key = key;
        self
    }

    /// Turns encryption on.
    ///
    /// # Panics
    ///
    /// Panics if the `ENACT_PRODUCTION` environment variable equals `"true"`
    /// (ASCII case-insensitive). This processor is development-only and must not
    /// be enabled in production.
    #[must_use]
    pub fn enabled(mut self) -> Self {
        // `eq_ignore_ascii_case` matches the same inputs as `to_lowercase() == "true"`
        // here, without allocating a lowercase copy of the value.
        let in_production = std::env::var("ENACT_PRODUCTION")
            .map(|v| v.eq_ignore_ascii_case("true"))
            .unwrap_or(false);
        if in_production {
            panic!(
                "EncryptionProcessor is DEVELOPMENT-ONLY and cannot be used in production. \
                 Set up a production-grade encryption solution with KMS integration. \
                 See documentation for details."
            );
        }
        self.enabled = true;
        self
    }

    /// Returns whether [`Self::enabled`] has been applied.
    pub fn is_enabled(&self) -> bool {
        self.enabled
    }

    /// Encrypts `text` with AES-256-GCM under a fresh random 12-byte nonce.
    ///
    /// Returns `"ENC:"` followed by the hex encoding of `nonce || ciphertext`,
    /// where `ciphertext` is the output of `Aead::encrypt`.
    ///
    /// # Errors
    ///
    /// Returns an error if the underlying AEAD encryption fails.
    fn encrypt_text(&self, text: &str) -> anyhow::Result<String> {
        let key = aes_gcm::Key::<Aes256Gcm>::from_slice(&self.key);
        let cipher = Aes256Gcm::new(key);

        // GCM is only secure if a (key, nonce) pair is never reused, so every call
        // draws a new random nonce and ships it alongside the ciphertext.
        let mut nonce_bytes = [0u8; 12];
        rand::thread_rng().fill_bytes(&mut nonce_bytes);
        let nonce = Nonce::from_slice(&nonce_bytes);

        let ciphertext = cipher
            .encrypt(nonce, text.as_bytes())
            .map_err(|e| anyhow::anyhow!("encryption failed: {:?}", e))?;

        let mut payload = Vec::with_capacity(nonce_bytes.len() + ciphertext.len());
        payload.extend_from_slice(&nonce_bytes);
        payload.extend_from_slice(&ciphertext);
        Ok(format!("ENC:{}", hex::encode(payload)))
    }

    /// Decides whether `event` should be encrypted for the destination in `ctx`.
    ///
    /// The variants matched here must be exactly those that
    /// [`Self::encrypt_event`] rewrites.
    fn should_encrypt_event(&self, event: &StreamEvent, ctx: &ProtectionContext) -> bool {
        if !ctx.destination.requires_encryption() {
            return false;
        }
        matches!(
            event,
            StreamEvent::TextDelta { .. }
                | StreamEvent::StepEnd {
                    output: Some(_),
                    ..
                }
                | StreamEvent::ExecutionEnd {
                    final_output: Some(_),
                    ..
                }
        )
    }

    /// Replaces the text-bearing field of `event` with a placeholder and returns
    /// the rewritten event together with the encrypted original text.
    ///
    /// Events without encryptable text are returned unchanged with `None`.
    ///
    /// # Errors
    ///
    /// Propagates failures from [`Self::encrypt_text`].
    fn encrypt_event(&self, event: StreamEvent) -> anyhow::Result<(StreamEvent, Option<String>)> {
        match event {
            StreamEvent::TextDelta { id, delta } => {
                let encrypted = self.encrypt_text(&delta)?;
                Ok((
                    StreamEvent::TextDelta {
                        id,
                        delta: Self::PLACEHOLDER.to_string(),
                    },
                    Some(encrypted),
                ))
            }
            StreamEvent::StepEnd {
                execution_id,
                step_id,
                output: Some(text),
                duration_ms,
                timestamp,
            } => {
                let encrypted = self.encrypt_text(&text)?;
                Ok((
                    StreamEvent::StepEnd {
                        execution_id,
                        step_id,
                        output: Some(Self::PLACEHOLDER.to_string()),
                        duration_ms,
                        timestamp,
                    },
                    Some(encrypted),
                ))
            }
            StreamEvent::ExecutionEnd {
                execution_id,
                final_output: Some(text),
                duration_ms,
                timestamp,
            } => {
                let encrypted = self.encrypt_text(&text)?;
                Ok((
                    StreamEvent::ExecutionEnd {
                        execution_id,
                        final_output: Some(Self::PLACEHOLDER.to_string()),
                        duration_ms,
                        timestamp,
                    },
                    Some(encrypted),
                ))
            }
            _ => Ok((event, None)),
        }
    }
}
impl Default for EncryptionProcessor {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl OutputProcessor for EncryptionProcessor {
    fn name(&self) -> &str {
        "encryption"
    }

    /// Encrypts text-bearing events whose destination requires encryption.
    /// Every other event is returned unchanged.
    ///
    /// When an event is encrypted, its visible text is replaced by a placeholder
    /// and the ciphertext is returned in `encrypted_payload`.
    ///
    /// # Errors
    ///
    /// Returns an error if AES-GCM encryption of the event text fails.
    async fn process(
        &self,
        event: StreamEvent,
        ctx: &ProtectionContext,
    ) -> anyhow::Result<ProcessedEvent> {
        // This method runs once per event (including every text delta). Emitting
        // the warning on every call floods the logs, so it is emitted once per process.
        static DEV_ONLY_WARNING: std::sync::Once = std::sync::Once::new();

        if !self.enabled {
            return Ok(ProcessedEvent::unchanged(event));
        }
        DEV_ONLY_WARNING.call_once(|| {
            tracing::warn!(
                target: "enact_core::encryption",
                "DEVELOPMENT-ONLY encryption in use. Do not use in production. Implement KMS-backed encryption instead."
            );
        });
        if !self.should_encrypt_event(&event, ctx) {
            return Ok(ProcessedEvent::unchanged(event));
        }
        let (encrypted_event, encrypted_payload) = self.encrypt_event(event)?;
        Ok(ProcessedEvent {
            // Report a modification only when a payload was actually produced. This
            // keeps the flag truthful if the variant lists in `should_encrypt_event`
            // and `encrypt_event` ever drift apart.
            was_modified: encrypted_payload.is_some(),
            event: encrypted_event,
            encrypted_payload,
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::kernel::ExecutionId;
    use std::sync::Mutex;

    static ENACT_PRODUCTION_TEST_LOCK: std::sync::OnceLock<Mutex<()>> = std::sync::OnceLock::new();

    /// Serializes tests that touch the process-global `ENACT_PRODUCTION` variable.
    ///
    /// A poisoned lock is recovered instead of unwrapped. Otherwise one failing
    /// test would cascade into spurious failures in every other env-dependent test.
    fn env_lock() -> std::sync::MutexGuard<'static, ()> {
        ENACT_PRODUCTION_TEST_LOCK
            .get_or_init(|| Mutex::new(()))
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Sets `ENACT_PRODUCTION` while holding the env lock. Restores the previous
    /// value on drop, including during a panic unwind.
    struct ProductionEnv {
        previous: Option<std::ffi::OsString>,
        // Field drops run after `Drop::drop`, so the restore happens under the lock.
        _lock: std::sync::MutexGuard<'static, ()>,
    }

    impl ProductionEnv {
        fn set(value: &str) -> Self {
            let lock = env_lock();
            let previous = std::env::var_os("ENACT_PRODUCTION");
            std::env::set_var("ENACT_PRODUCTION", value);
            Self {
                previous,
                _lock: lock,
            }
        }
    }

    impl Drop for ProductionEnv {
        fn drop(&mut self) {
            match self.previous.take() {
                Some(value) => std::env::set_var("ENACT_PRODUCTION", value),
                None => std::env::remove_var("ENACT_PRODUCTION"),
            }
        }
    }

    /// Enables `processor` with `ENACT_PRODUCTION=false`.
    ///
    /// The env lock is released before returning, so callers never hold it
    /// across an `.await`.
    fn enable(processor: EncryptionProcessor) -> EncryptionProcessor {
        let _env = ProductionEnv::set("false");
        processor.enabled()
    }

    #[tokio::test]
    async fn test_encryption_processor_name() {
        let processor = EncryptionProcessor::new();
        assert_eq!(processor.name(), "encryption");
    }

    #[tokio::test]
    async fn test_encryption_processor_disabled_by_default() {
        let processor = EncryptionProcessor::new();
        assert!(!processor.is_enabled());
    }

    #[tokio::test]
    async fn test_encryption_processor_can_enable() {
        let processor = enable(EncryptionProcessor::new());
        assert!(processor.is_enabled());
    }

    #[tokio::test]
    async fn test_encryption_skips_when_disabled() {
        let processor = EncryptionProcessor::new();
        let ctx = ProtectionContext::for_storage();
        let event = StreamEvent::text_delta("id", "secret data");
        let result = processor.process(event, &ctx).await.unwrap();
        assert!(!result.was_modified);
    }

    #[tokio::test]
    async fn test_encryption_skips_streaming_destination() {
        let processor = enable(EncryptionProcessor::new());
        let ctx = ProtectionContext::for_stream();
        let event = StreamEvent::text_delta("id", "secret data");
        let result = processor.process(event, &ctx).await.unwrap();
        assert!(!result.was_modified);
    }

    #[tokio::test]
    async fn test_encryption_encrypts_for_storage() {
        let processor = enable(EncryptionProcessor::new().with_key([1u8; 32]));
        let ctx = ProtectionContext::for_storage();
        let event = StreamEvent::text_delta("id", "secret data");
        let result = processor.process(event, &ctx).await.unwrap();
        assert!(result.was_modified);
        if let StreamEvent::TextDelta { delta, .. } = result.event {
            assert_eq!(delta, "[ENCRYPTED]");
        } else {
            panic!("Expected TextDelta");
        }
        let payload = result
            .encrypted_payload
            .expect("expected encrypted payload");
        assert!(payload.starts_with("ENC:"));
        assert!(
            !payload.contains("secret data"),
            "ciphertext should not contain plaintext"
        );
    }

    #[tokio::test]
    async fn test_encrypted_payload_round_trips() {
        let key = [7u8; 32];
        let processor = enable(EncryptionProcessor::new().with_key(key));
        let ctx = ProtectionContext::for_storage();
        let event = StreamEvent::text_delta("id", "secret data");
        let result = processor.process(event, &ctx).await.unwrap();

        let payload = result
            .encrypted_payload
            .expect("expected encrypted payload");
        let hex_body = payload.strip_prefix("ENC:").expect("missing ENC: prefix");
        let bytes = hex::decode(hex_body).expect("payload should be valid hex");
        assert!(bytes.len() > 12, "payload must hold a nonce and ciphertext");
        let (nonce, ciphertext) = bytes.split_at(12);

        let cipher = Aes256Gcm::new(aes_gcm::Key::<Aes256Gcm>::from_slice(&key));
        let plaintext = cipher
            .decrypt(Nonce::from_slice(nonce), ciphertext)
            .expect("payload should decrypt with the same key");
        assert_eq!(plaintext, b"secret data");
    }

    #[tokio::test]
    async fn test_encryption_uses_fresh_nonce_per_event() {
        let processor = enable(EncryptionProcessor::new().with_key([3u8; 32]));
        let ctx = ProtectionContext::for_storage();
        let first = processor
            .process(StreamEvent::text_delta("id", "same text"), &ctx)
            .await
            .unwrap();
        let second = processor
            .process(StreamEvent::text_delta("id", "same text"), &ctx)
            .await
            .unwrap();
        assert_ne!(
            first.encrypted_payload, second.encrypted_payload,
            "identical plaintexts must not yield identical ciphertexts"
        );
    }

    #[tokio::test]
    async fn test_encryption_control_events_pass_through() {
        let processor = enable(EncryptionProcessor::new());
        let ctx = ProtectionContext::for_storage();
        let exec_id = ExecutionId::new();
        let event = StreamEvent::execution_start(&exec_id);
        let result = processor.process(event, &ctx).await.unwrap();
        assert!(!result.was_modified);
    }

    #[test]
    fn test_encryption_panics_in_production_mode() {
        let _env = ProductionEnv::set("true");
        let result = std::panic::catch_unwind(|| {
            let _ = EncryptionProcessor::new().enabled();
        });
        assert!(result.is_err(), "Should panic when ENACT_PRODUCTION=true");
    }

    #[test]
    fn test_encryption_production_check_is_case_insensitive() {
        let _env = ProductionEnv::set("TRUE");
        let result = std::panic::catch_unwind(|| {
            let _ = EncryptionProcessor::new().enabled();
        });
        assert!(result.is_err(), "Should panic when ENACT_PRODUCTION=TRUE");
    }

    #[tokio::test]
    async fn test_encryption_works_when_production_false() {
        let processor = enable(EncryptionProcessor::new());
        assert!(processor.is_enabled());
    }
}