// cortexai_encryption/stores.rs
use std::sync::Arc;

use crate::envelope::EnvelopeEncryptor;
use crate::error::{CryptoError, CryptoResult};

11pub struct EncryptedSessionStore<S> {
29 inner: S,
30 encryptor: Arc<EnvelopeEncryptor>,
31}
32
33impl<S> EncryptedSessionStore<S> {
34 pub fn new(inner: S, encryptor: Arc<EnvelopeEncryptor>) -> Self {
36 Self { inner, encryptor }
37 }
38
39 pub fn inner(&self) -> &S {
41 &self.inner
42 }
43
44 pub fn encryptor(&self) -> &EnvelopeEncryptor {
46 &self.encryptor
47 }
48
49 #[cfg(feature = "aes")]
51 pub fn encrypt_session<T: serde::Serialize>(&self, session: &T) -> CryptoResult<Vec<u8>> {
52 let json = serde_json::to_vec(session)?;
53 self.encryptor.encrypt(&json, None)
54 }
55
56 #[cfg(feature = "aes")]
58 pub fn decrypt_session<T: serde::de::DeserializeOwned>(
59 &self,
60 ciphertext: &[u8],
61 ) -> CryptoResult<T> {
62 let plaintext = self.encryptor.decrypt(ciphertext, None)?;
63 let session = serde_json::from_slice(&plaintext)?;
64 Ok(session)
65 }
66}
67
68impl<S: Clone> Clone for EncryptedSessionStore<S> {
69 fn clone(&self) -> Self {
70 Self {
71 inner: self.inner.clone(),
72 encryptor: self.encryptor.clone(),
73 }
74 }
75}
76
77impl<S: std::fmt::Debug> std::fmt::Debug for EncryptedSessionStore<S> {
78 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
79 f.debug_struct("EncryptedSessionStore")
80 .field("inner", &self.inner)
81 .field("encryptor", &"[EnvelopeEncryptor]")
82 .finish()
83 }
84}
85
86pub struct EncryptedCheckpointStore<S> {
104 inner: S,
105 encryptor: Arc<EnvelopeEncryptor>,
106}
107
108impl<S> EncryptedCheckpointStore<S> {
109 pub fn new(inner: S, encryptor: Arc<EnvelopeEncryptor>) -> Self {
111 Self { inner, encryptor }
112 }
113
114 pub fn inner(&self) -> &S {
116 &self.inner
117 }
118
119 pub fn encryptor(&self) -> &EnvelopeEncryptor {
121 &self.encryptor
122 }
123
124 #[cfg(feature = "aes")]
128 pub fn encrypt_state(&self, state: &[u8], thread_id: &str) -> CryptoResult<Vec<u8>> {
129 self.encryptor.encrypt(state, Some(thread_id.as_bytes()))
130 }
131
132 #[cfg(feature = "aes")]
136 pub fn decrypt_state(&self, ciphertext: &[u8], thread_id: &str) -> CryptoResult<Vec<u8>> {
137 self.encryptor
138 .decrypt(ciphertext, Some(thread_id.as_bytes()))
139 }
140}
141
142impl<S: Clone> Clone for EncryptedCheckpointStore<S> {
143 fn clone(&self) -> Self {
144 Self {
145 inner: self.inner.clone(),
146 encryptor: self.encryptor.clone(),
147 }
148 }
149}
150
151impl<S: std::fmt::Debug> std::fmt::Debug for EncryptedCheckpointStore<S> {
152 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
153 f.debug_struct("EncryptedCheckpointStore")
154 .field("inner", &self.inner)
155 .field("encryptor", &"[EnvelopeEncryptor]")
156 .finish()
157 }
158}
159
/// Settings that control whether and how an encryptor is created.
///
/// `Debug` is implemented by hand rather than derived so that formatting a
/// configuration (for example in a log line) never prints the master key.
#[derive(Clone)]
pub struct EncryptionConfig {
    /// Whether encryption is turned on. Defaults to `false`.
    pub enabled: bool,
    /// Base64-encoded master key. Must be set when `enabled` is `true`.
    pub master_key: Option<String>,
    /// Optional salt. Nothing in this file reads it.
    pub salt: Option<String>,
    /// Algorithm name. Defaults to `"aes-256-gcm"`.
    pub algorithm: String,
}

impl std::fmt::Debug for EncryptionConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Keep the Some/None shape visible for diagnostics, but never emit the
        // key material itself.
        let master_key = self.master_key.as_ref().map(|_| "[REDACTED]");
        f.debug_struct("EncryptionConfig")
            .field("enabled", &self.enabled)
            .field("master_key", &master_key)
            .field("salt", &self.salt)
            .field("algorithm", &self.algorithm)
            .finish()
    }
}

173impl Default for EncryptionConfig {
174 fn default() -> Self {
175 Self {
176 enabled: false,
177 master_key: None,
178 salt: None,
179 algorithm: "aes-256-gcm".to_string(),
180 }
181 }
182}
183
184impl EncryptionConfig {
185 #[cfg(feature = "aes")]
187 pub fn with_generated_key() -> Self {
188 use crate::key::EncryptionKey;
189 let key = EncryptionKey::generate(32);
190 Self {
191 enabled: true,
192 master_key: Some(key.to_base64()),
193 salt: None,
194 algorithm: "aes-256-gcm".to_string(),
195 }
196 }
197
198 #[cfg(feature = "aes")]
200 pub fn create_encryptor(&self) -> CryptoResult<Option<EnvelopeEncryptor>> {
201 use crate::key::EncryptionKey;
202
203 if !self.enabled {
204 return Ok(None);
205 }
206
207 let key = match &self.master_key {
208 Some(k) => EncryptionKey::from_base64(k)?,
209 None => {
210 return Err(CryptoError::KeyDerivationFailed(
211 "master_key required when encryption is enabled".to_string(),
212 ))
213 }
214 };
215
216 Ok(Some(EnvelopeEncryptor::new(key)))
217 }
218}
219
#[cfg(all(test, feature = "aes"))]
mod tests {
    use super::*;
    use crate::key::EncryptionKey;
    use serde::{Deserialize, Serialize};

    /// Minimal serializable session used to exercise the session store.
    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
    struct MockSession {
        id: String,
        user_id: String,
        messages: Vec<String>,
    }

    /// Builds a shared encryptor around a key from `EncryptionKey::generate(32)`.
    fn fresh_encryptor() -> Arc<EnvelopeEncryptor> {
        Arc::new(EnvelopeEncryptor::new(EncryptionKey::generate(32)))
    }

    /// Builds a checkpoint store with a unit inner value and a fresh encryptor.
    fn checkpoint_store() -> EncryptedCheckpointStore<()> {
        EncryptedCheckpointStore::new((), fresh_encryptor())
    }

    /// A session survives an encrypt/decrypt round trip unchanged.
    #[test]
    fn test_encrypted_session_store_roundtrip() {
        let store: EncryptedSessionStore<()> = EncryptedSessionStore::new((), fresh_encryptor());

        let session = MockSession {
            id: String::from("sess-123"),
            user_id: String::from("user-456"),
            messages: ["Hello", "World"].iter().map(|m| m.to_string()).collect(),
        };

        let ciphertext = store.encrypt_session(&session).unwrap();
        let restored: MockSession = store.decrypt_session(&ciphertext).unwrap();

        assert_eq!(session, restored);
    }

    /// State encrypted under a thread id decrypts under the same thread id.
    #[test]
    fn test_encrypted_checkpoint_store_with_aad() {
        let store = checkpoint_store();
        let state = b"checkpoint state data";
        let thread_id = "thread-123";

        let ciphertext = store.encrypt_state(state, thread_id).unwrap();
        let restored = store.decrypt_state(&ciphertext, thread_id).unwrap();

        assert_eq!(&state[..], &restored[..]);
    }

    /// Decrypting with a different thread id than was used to encrypt fails.
    #[test]
    fn test_checkpoint_wrong_thread_id_fails() {
        let store = checkpoint_store();
        let state = b"checkpoint state data";

        let ciphertext = store.encrypt_state(state, "thread-123").unwrap();

        assert!(store.decrypt_state(&ciphertext, "thread-456").is_err());
    }

    /// A generated-key configuration is enabled and yields an encryptor.
    #[test]
    fn test_encryption_config_create_encryptor() {
        let config = EncryptionConfig::with_generated_key();
        assert!(config.enabled);
        assert!(config.master_key.is_some());

        assert!(config.create_encryptor().unwrap().is_some());
    }

    /// The default configuration is disabled and yields no encryptor.
    #[test]
    fn test_disabled_encryption_config() {
        let config = EncryptionConfig::default();
        assert!(!config.enabled);

        assert!(config.create_encryptor().unwrap().is_none());
    }
}
302}