use crate::encryption_key::EncryptionKey;
use crate::service::{
AppendEventRequest, CreateRequest, DeleteRequest, GetRequest, ListRequest, SessionService,
};
use crate::session::Session;
use crate::state::State;
use crate::{Event, Events};
use adk_core::{AdkError, Result};
use aes_gcm::aead::Aead;
use aes_gcm::{Aes256Gcm, KeyInit, Nonce};
use async_trait::async_trait;
use base64::Engine;
use chrono::{DateTime, Utc};
use rand::RngCore;
use serde_json::Value;
use std::collections::HashMap;
/// State key under which the base64-encoded `nonce || ciphertext` blob is stored.
const ENCRYPTED_STATE_KEY: &str = "__encrypted_state";
/// A `SessionService` decorator that keeps session state encrypted with AES-256-GCM.
///
/// State passed to `create` is serialized to JSON, encrypted with `current_key`, and
/// handed to the inner service as a single `ENCRYPTED_STATE_KEY` entry. On the way out,
/// `current_key` is tried first and then each of `previous_keys` in order, which allows
/// rotating keys without losing access to older state.
pub struct EncryptedSession<S: SessionService> {
    // Wrapped service that every call is ultimately delegated to.
    inner: S,
    // Key used for all new encryptions; also the first key tried on decryption.
    current_key: EncryptionKey,
    // Retired keys, accepted for decryption only (tried in order after `current_key`).
    previous_keys: Vec<EncryptionKey>,
}
impl<S: SessionService> EncryptedSession<S> {
pub fn new(inner: S, current_key: EncryptionKey, previous_keys: Vec<EncryptionKey>) -> Self {
Self { inner, current_key, previous_keys }
}
fn encrypt_state(&self, state: &HashMap<String, Value>) -> Result<HashMap<String, Value>> {
let plaintext = serde_json::to_vec(state)
.map_err(|e| AdkError::session(format!("failed to serialize state: {e}")))?;
let encrypted = encrypt_bytes(self.current_key.as_bytes(), &plaintext)?;
let encoded = base64::engine::general_purpose::STANDARD.encode(&encrypted);
let mut wrapped = HashMap::new();
wrapped.insert(ENCRYPTED_STATE_KEY.to_string(), Value::String(encoded));
Ok(wrapped)
}
fn decrypt_state(&self, state: &HashMap<String, Value>) -> Result<HashMap<String, Value>> {
let encoded = match state.get(ENCRYPTED_STATE_KEY) {
Some(Value::String(s)) => s,
_ => {
return Ok(state.clone());
}
};
let encrypted = base64::engine::general_purpose::STANDARD
.decode(encoded)
.map_err(|e| AdkError::session(format!("invalid base64 in encrypted state: {e}")))?;
if let Ok(plaintext) = decrypt_bytes(self.current_key.as_bytes(), &encrypted) {
return parse_state(&plaintext);
}
for prev_key in &self.previous_keys {
if let Ok(plaintext) = decrypt_bytes(prev_key.as_bytes(), &encrypted) {
return parse_state(&plaintext);
}
}
Err(AdkError::session("decryption failed: no matching key"))
}
fn decrypt_state_with_rotation(
&self,
state: &HashMap<String, Value>,
) -> Result<(HashMap<String, Value>, bool)> {
let encoded = match state.get(ENCRYPTED_STATE_KEY) {
Some(Value::String(s)) => s,
_ => return Ok((state.clone(), false)),
};
let encrypted = base64::engine::general_purpose::STANDARD
.decode(encoded)
.map_err(|e| AdkError::session(format!("invalid base64 in encrypted state: {e}")))?;
if let Ok(plaintext) = decrypt_bytes(self.current_key.as_bytes(), &encrypted) {
return Ok((parse_state(&plaintext)?, false));
}
for prev_key in &self.previous_keys {
if let Ok(plaintext) = decrypt_bytes(prev_key.as_bytes(), &encrypted) {
return Ok((parse_state(&plaintext)?, true));
}
}
Err(AdkError::session("decryption failed: no matching key"))
}
}
/// Encrypts `plaintext` with AES-256-GCM under `key`.
///
/// A fresh random 12-byte nonce is drawn on every call and prepended to the output.
/// The result is therefore `nonce (12 bytes) || output of Aead::encrypt`, which is
/// the layout `decrypt_bytes` expects.
///
/// # Errors
/// Returns a session error if the cipher cannot be built or encryption fails.
fn encrypt_bytes(key: &[u8; 32], plaintext: &[u8]) -> Result<Vec<u8>> {
    let aead = Aes256Gcm::new_from_slice(key)
        .map_err(|e| AdkError::session(format!("failed to create cipher: {e}")))?;

    // A GCM nonce must never repeat under the same key, so draw a new one per message.
    let mut iv = [0u8; 12];
    rand::rng().fill_bytes(&mut iv);

    let sealed = aead
        .encrypt(Nonce::from_slice(&iv), plaintext)
        .map_err(|e| AdkError::session(format!("encryption failed: {e}")))?;

    Ok([iv.as_slice(), sealed.as_slice()].concat())
}
/// Reverses `encrypt_bytes`: splits off the 12-byte nonce prefix and decrypts the rest
/// with AES-256-GCM under `key`.
///
/// # Errors
/// Returns a session error if `data` is shorter than a nonce, if the cipher cannot be
/// built, or if authentication/decryption fails (e.g. wrong key or tampered data).
fn decrypt_bytes(key: &[u8; 32], data: &[u8]) -> Result<Vec<u8>> {
    // Length of the nonce prefix written by `encrypt_bytes`.
    const NONCE_LEN: usize = 12;

    if data.len() < NONCE_LEN {
        return Err(AdkError::session("encrypted data too short: missing nonce"));
    }
    let (iv, sealed) = data.split_at(NONCE_LEN);

    Aes256Gcm::new_from_slice(key)
        .map_err(|e| AdkError::session(format!("failed to create cipher: {e}")))?
        .decrypt(Nonce::from_slice(iv), sealed)
        .map_err(|e| AdkError::session(format!("decryption failed: {e}")))
}
/// Parses decrypted bytes back into a state map.
///
/// # Errors
/// Returns a session error if `plaintext` is not a JSON object.
fn parse_state(plaintext: &[u8]) -> Result<HashMap<String, Value>> {
    match serde_json::from_slice::<HashMap<String, Value>>(plaintext) {
        Ok(state) => Ok(state),
        Err(e) => Err(AdkError::session(format!(
            "failed to deserialize decrypted state: {e}"
        ))),
    }
}
/// State is encrypted before it reaches the inner service and decrypted on the way
/// out. Every other operation is delegated unchanged.
#[async_trait]
impl<S: SessionService> SessionService for EncryptedSession<S> {
    /// Creates a session whose initial state is encrypted before being handed to the
    /// inner service. Empty state is passed through unencrypted. The returned session
    /// exposes the decrypted state.
    async fn create(&self, mut req: CreateRequest) -> Result<Box<dyn Session>> {
        if !req.state.is_empty() {
            req.state = self.encrypt_state(&req.state)?;
        }
        let session = self.inner.create(req).await?;
        let decrypted = self.decrypt_state(&session.state().all())?;
        Ok(Box::new(DecryptedSession::new(session, decrypted)))
    }

    /// Fetches a session and decrypts its state.
    ///
    /// If the state only decrypts under one of the previous keys, it is re-encrypted
    /// under the current key and written back through the inner service's `create`.
    async fn get(&self, req: GetRequest) -> Result<Box<dyn Session>> {
        let session = self.inner.get(req).await?;
        let raw_state = session.state().all();
        let (decrypted, needs_reencrypt) = self.decrypt_state_with_rotation(&raw_state)?;
        if needs_reencrypt {
            let re_encrypted = self.encrypt_state(&decrypted)?;
            let update_req = CreateRequest {
                app_name: session.app_name().to_string(),
                user_id: session.user_id().to_string(),
                session_id: Some(session.id().to_string()),
                state: re_encrypted,
            };
            // Best-effort write-back: the caller already has correctly decrypted
            // state, so a failure here must not fail the read.
            // NOTE(review): this relies on the inner `create` overwriting an existing
            // session with the same id. If an inner service rejects duplicate ids
            // instead, rotation never persists. Confirm against inner implementations.
            let _ = self.inner.create(update_req).await;
        }
        Ok(Box::new(DecryptedSession::new(session, decrypted)))
    }

    /// Lists sessions, decrypting each session's state the same way `get` does.
    /// Unlike `get`, it does not write back re-encrypted state after key rotation.
    ///
    /// # Errors
    /// Fails if the inner listing fails or any listed session's state cannot be
    /// decrypted.
    async fn list(&self, req: ListRequest) -> Result<Vec<Box<dyn Session>>> {
        // Without this, callers would see the raw `ENCRYPTED_STATE_KEY` blob instead of
        // plaintext state, which is inconsistent with `create` and `get`.
        let sessions = self.inner.list(req).await?;
        sessions
            .into_iter()
            .map(|session| -> Result<Box<dyn Session>> {
                let decrypted = self.decrypt_state(&session.state().all())?;
                Ok(Box::new(DecryptedSession::new(session, decrypted)))
            })
            .collect()
    }

    async fn delete(&self, req: DeleteRequest) -> Result<()> {
        self.inner.delete(req).await
    }

    /// Forwards the event unchanged. This wrapper does not encrypt event contents.
    /// NOTE(review): if events carry state changes, those reach the inner service in
    /// plaintext. Confirm this is acceptable.
    async fn append_event(&self, session_id: &str, event: Event) -> Result<()> {
        self.inner.append_event(session_id, event).await
    }

    /// Forwards the request unchanged. This wrapper does not encrypt event contents.
    async fn append_event_for_identity(&self, req: AppendEventRequest) -> Result<()> {
        self.inner.append_event_for_identity(req).await
    }

    async fn delete_all_sessions(&self, app_name: &str, user_id: &str) -> Result<()> {
        self.inner.delete_all_sessions(app_name, user_id).await
    }

    async fn health_check(&self) -> Result<()> {
        self.inner.health_check().await
    }
}
/// A session view that pairs an inner session with its already-decrypted state.
///
/// Identity, events and timestamps come from `inner`. State reads and writes go to
/// `decrypted_state`.
struct DecryptedSession {
    // Session as returned by the wrapped service (its own state is the encrypted form).
    inner: Box<dyn Session>,
    // Plaintext state served through the `State` impl.
    decrypted_state: HashMap<String, Value>,
}
impl DecryptedSession {
    /// Wraps `inner`, exposing `decrypted_state` in place of its stored state.
    fn new(inner: Box<dyn Session>, decrypted_state: HashMap<String, Value>) -> Self {
        DecryptedSession { decrypted_state, inner }
    }
}
impl Session for DecryptedSession {
fn id(&self) -> &str {
self.inner.id()
}
fn app_name(&self) -> &str {
self.inner.app_name()
}
fn user_id(&self) -> &str {
self.inner.user_id()
}
fn state(&self) -> &dyn State {
self
}
fn events(&self) -> &dyn Events {
self.inner.events()
}
fn last_update_time(&self) -> DateTime<Utc> {
self.inner.last_update_time()
}
}
/// State access backed by the decrypted map.
///
/// `set` only updates this in-memory copy. It is not written back to, or re-encrypted
/// for, the inner service.
impl State for DecryptedSession {
    fn get(&self, key: &str) -> Option<Value> {
        let value = self.decrypted_state.get(key)?;
        Some(value.clone())
    }
    fn set(&mut self, key: String, value: Value) {
        // Any previous value for `key` is discarded.
        let _previous = self.decrypted_state.insert(key, value);
    }
    fn all(&self) -> HashMap<String, Value> {
        HashMap::clone(&self.decrypted_state)
    }
}
impl Events for DecryptedSession {
fn all(&self) -> Vec<Event> {
self.inner.events().all()
}
fn len(&self) -> usize {
self.inner.events().len()
}
fn at(&self, index: usize) -> Option<&Event> {
self.inner.events().at(index)
}
}