Skip to main content

cortexai_encryption/
stores.rs

//! Encrypted store wrappers for sessions and checkpoints.
//!
//! These wrappers provide transparent encryption for storage backends.
//! Data is encrypted before saving and decrypted after loading.
5
6use std::sync::Arc;
7
8use crate::envelope::EnvelopeEncryptor;
9use crate::error::{CryptoError, CryptoResult};
10
11/// Encrypted session store wrapper.
12///
13/// Wraps any session store implementation to provide transparent encryption.
14/// Sessions are serialized to JSON, encrypted, then stored as bytes.
15///
16/// # Example
17///
18/// ```rust,ignore
19/// use cortexai_encryption::{EncryptionKey, EnvelopeEncryptor, EncryptedSessionStore};
20/// use cortexai_agents::session::MemorySessionStore;
21///
22/// let key = EncryptionKey::generate(32);
23/// let encryptor = Arc::new(EnvelopeEncryptor::new(key));
24/// let inner_store = MemorySessionStore::new();
25///
26/// let encrypted_store = EncryptedSessionStore::new(inner_store, encryptor);
27/// ```
28pub struct EncryptedSessionStore<S> {
29    inner: S,
30    encryptor: Arc<EnvelopeEncryptor>,
31}
32
33impl<S> EncryptedSessionStore<S> {
34    /// Create a new encrypted session store.
35    pub fn new(inner: S, encryptor: Arc<EnvelopeEncryptor>) -> Self {
36        Self { inner, encryptor }
37    }
38
39    /// Get a reference to the inner store.
40    pub fn inner(&self) -> &S {
41        &self.inner
42    }
43
44    /// Get a reference to the encryptor.
45    pub fn encryptor(&self) -> &EnvelopeEncryptor {
46        &self.encryptor
47    }
48
49    /// Encrypt session data.
50    #[cfg(feature = "aes")]
51    pub fn encrypt_session<T: serde::Serialize>(&self, session: &T) -> CryptoResult<Vec<u8>> {
52        let json = serde_json::to_vec(session)?;
53        self.encryptor.encrypt(&json, None)
54    }
55
56    /// Decrypt session data.
57    #[cfg(feature = "aes")]
58    pub fn decrypt_session<T: serde::de::DeserializeOwned>(
59        &self,
60        ciphertext: &[u8],
61    ) -> CryptoResult<T> {
62        let plaintext = self.encryptor.decrypt(ciphertext, None)?;
63        let session = serde_json::from_slice(&plaintext)?;
64        Ok(session)
65    }
66}
67
68impl<S: Clone> Clone for EncryptedSessionStore<S> {
69    fn clone(&self) -> Self {
70        Self {
71            inner: self.inner.clone(),
72            encryptor: self.encryptor.clone(),
73        }
74    }
75}
76
77impl<S: std::fmt::Debug> std::fmt::Debug for EncryptedSessionStore<S> {
78    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
79        f.debug_struct("EncryptedSessionStore")
80            .field("inner", &self.inner)
81            .field("encryptor", &"[EnvelopeEncryptor]")
82            .finish()
83    }
84}
85
86/// Encrypted checkpoint store wrapper.
87///
88/// Wraps any checkpoint store implementation to provide transparent encryption.
89/// Checkpoint state bytes are encrypted before storage.
90///
91/// # Example
92///
93/// ```rust,ignore
94/// use cortexai_encryption::{EncryptionKey, EnvelopeEncryptor, EncryptedCheckpointStore};
95/// use cortexai_agents::checkpoint::MemoryCheckpointStore;
96///
97/// let key = EncryptionKey::generate(32);
98/// let encryptor = Arc::new(EnvelopeEncryptor::new(key));
99/// let inner_store = MemoryCheckpointStore::new();
100///
101/// let encrypted_store = EncryptedCheckpointStore::new(inner_store, encryptor);
102/// ```
103pub struct EncryptedCheckpointStore<S> {
104    inner: S,
105    encryptor: Arc<EnvelopeEncryptor>,
106}
107
108impl<S> EncryptedCheckpointStore<S> {
109    /// Create a new encrypted checkpoint store.
110    pub fn new(inner: S, encryptor: Arc<EnvelopeEncryptor>) -> Self {
111        Self { inner, encryptor }
112    }
113
114    /// Get a reference to the inner store.
115    pub fn inner(&self) -> &S {
116        &self.inner
117    }
118
119    /// Get a reference to the encryptor.
120    pub fn encryptor(&self) -> &EnvelopeEncryptor {
121        &self.encryptor
122    }
123
124    /// Encrypt checkpoint state.
125    ///
126    /// Uses the thread_id as associated data for additional authentication.
127    #[cfg(feature = "aes")]
128    pub fn encrypt_state(&self, state: &[u8], thread_id: &str) -> CryptoResult<Vec<u8>> {
129        self.encryptor.encrypt(state, Some(thread_id.as_bytes()))
130    }
131
132    /// Decrypt checkpoint state.
133    ///
134    /// Uses the thread_id as associated data for verification.
135    #[cfg(feature = "aes")]
136    pub fn decrypt_state(&self, ciphertext: &[u8], thread_id: &str) -> CryptoResult<Vec<u8>> {
137        self.encryptor
138            .decrypt(ciphertext, Some(thread_id.as_bytes()))
139    }
140}
141
142impl<S: Clone> Clone for EncryptedCheckpointStore<S> {
143    fn clone(&self) -> Self {
144        Self {
145            inner: self.inner.clone(),
146            encryptor: self.encryptor.clone(),
147        }
148    }
149}
150
151impl<S: std::fmt::Debug> std::fmt::Debug for EncryptedCheckpointStore<S> {
152    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
153        f.debug_struct("EncryptedCheckpointStore")
154            .field("inner", &self.inner)
155            .field("encryptor", &"[EnvelopeEncryptor]")
156            .finish()
157    }
158}
159
/// Configuration for encrypted storage.
///
/// `Debug` is implemented by hand so that the master key is never printed;
/// it is shown as `Some("[REDACTED]")` (or `None` when absent).
#[derive(Clone)]
pub struct EncryptionConfig {
    /// Whether encryption is enabled
    pub enabled: bool,
    /// Master key (base64 encoded). Secret: redacted from `Debug` output.
    pub master_key: Option<String>,
    /// Key derivation salt (base64 encoded).
    /// NOTE(review): `create_encryptor` in this module does not read this field.
    pub salt: Option<String>,
    /// Algorithm to use (aes-256-gcm or chacha20-poly1305).
    /// NOTE(review): `create_encryptor` in this module does not read this field.
    pub algorithm: String,
}

impl std::fmt::Debug for EncryptionConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("EncryptionConfig")
            .field("enabled", &self.enabled)
            // Never expose key material in logs or panic messages.
            .field("master_key", &self.master_key.as_ref().map(|_| "[REDACTED]"))
            .field("salt", &self.salt)
            .field("algorithm", &self.algorithm)
            .finish()
    }
}
172
173impl Default for EncryptionConfig {
174    fn default() -> Self {
175        Self {
176            enabled: false,
177            master_key: None,
178            salt: None,
179            algorithm: "aes-256-gcm".to_string(),
180        }
181    }
182}
183
184impl EncryptionConfig {
185    /// Create a new encryption config with a generated key.
186    #[cfg(feature = "aes")]
187    pub fn with_generated_key() -> Self {
188        use crate::key::EncryptionKey;
189        let key = EncryptionKey::generate(32);
190        Self {
191            enabled: true,
192            master_key: Some(key.to_base64()),
193            salt: None,
194            algorithm: "aes-256-gcm".to_string(),
195        }
196    }
197
198    /// Create an encryptor from this config.
199    #[cfg(feature = "aes")]
200    pub fn create_encryptor(&self) -> CryptoResult<Option<EnvelopeEncryptor>> {
201        use crate::key::EncryptionKey;
202
203        if !self.enabled {
204            return Ok(None);
205        }
206
207        let key = match &self.master_key {
208            Some(k) => EncryptionKey::from_base64(k)?,
209            None => {
210                return Err(CryptoError::KeyDerivationFailed(
211                    "master_key required when encryption is enabled".to_string(),
212                ))
213            }
214        };
215
216        Ok(Some(EnvelopeEncryptor::new(key)))
217    }
218}
219
#[cfg(all(test, feature = "aes"))]
mod tests {
    use super::*;
    use crate::key::EncryptionKey;
    use serde::{Deserialize, Serialize};

    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
    struct MockSession {
        id: String,
        user_id: String,
        messages: Vec<String>,
    }

    #[test]
    fn test_encrypted_session_store_roundtrip() {
        let key = EncryptionKey::generate(32);
        let encryptor = Arc::new(EnvelopeEncryptor::new(key));

        // Using unit as a mock inner store for testing encryption only
        let store: EncryptedSessionStore<()> = EncryptedSessionStore::new((), encryptor);

        let session = MockSession {
            id: "sess-123".to_string(),
            user_id: "user-456".to_string(),
            messages: vec!["Hello".to_string(), "World".to_string()],
        };

        let encrypted = store.encrypt_session(&session).unwrap();
        let decrypted: MockSession = store.decrypt_session(&encrypted).unwrap();

        assert_eq!(session, decrypted);
    }

    #[test]
    fn test_encrypted_checkpoint_store_with_aad() {
        let key = EncryptionKey::generate(32);
        let encryptor = Arc::new(EnvelopeEncryptor::new(key));

        let store: EncryptedCheckpointStore<()> = EncryptedCheckpointStore::new((), encryptor);

        let state = b"checkpoint state data";
        let thread_id = "thread-123";

        let encrypted = store.encrypt_state(state, thread_id).unwrap();
        let decrypted = store.decrypt_state(&encrypted, thread_id).unwrap();

        assert_eq!(state.as_slice(), decrypted.as_slice());
    }

    #[test]
    fn test_checkpoint_wrong_thread_id_fails() {
        let key = EncryptionKey::generate(32);
        let encryptor = Arc::new(EnvelopeEncryptor::new(key));

        let store: EncryptedCheckpointStore<()> = EncryptedCheckpointStore::new((), encryptor);

        let state = b"checkpoint state data";
        let encrypted = store.encrypt_state(state, "thread-123").unwrap();

        // Try to decrypt with wrong thread_id
        let result = store.decrypt_state(&encrypted, "thread-456");
        assert!(result.is_err());
    }

    #[test]
    fn test_encryption_config_create_encryptor() {
        let config = EncryptionConfig::with_generated_key();
        assert!(config.enabled);
        assert!(config.master_key.is_some());

        let encryptor = config.create_encryptor().unwrap();
        assert!(encryptor.is_some());
    }

    #[test]
    fn test_disabled_encryption_config() {
        let config = EncryptionConfig::default();
        assert!(!config.enabled);

        let encryptor = config.create_encryptor().unwrap();
        assert!(encryptor.is_none());
    }

    #[test]
    fn test_enabled_config_without_master_key_fails() {
        // Enabled but no key: must be rejected rather than silently disabled.
        let config = EncryptionConfig {
            enabled: true,
            ..EncryptionConfig::default()
        };
        assert!(config.master_key.is_none());

        let result = config.create_encryptor();
        assert!(matches!(result, Err(CryptoError::KeyDerivationFailed(_))));
    }
}