celerix_store/engine/memstore.rs

1use std::collections::HashMap;
2use std::sync::{Arc, RwLock};
3use async_trait::async_trait;
4use crate::{Result, Error, KVReader, KVWriter, AppEnumeration, BatchExporter, GlobalSearcher, Orchestrator, CelerixStore, AppScope, VaultScope};
5use crate::engine::{Persistence, vault};
6
7use std::sync::atomic::{AtomicUsize, Ordering};
8
9type StoreData = HashMap<String, HashMap<String, HashMap<String, serde_json::Value>>>;
10
11/// A thread-safe, in-memory implementation of the [`CelerixStore`] trait.
12/// 
13/// `MemStore` maintains all data in memory for high-performance concurrent access
14/// and supports asynchronous persistence to JSON files.
15pub struct MemStore {
16    data: RwLock<StoreData>,
17    persistence: Option<Arc<Persistence>>,
18    pending_tasks: Arc<AtomicUsize>,
19}
20
21impl MemStore {
22    /// Initializes a new `MemStore`.
23    /// 
24    /// - `initial_data`: Existing data to load into the store.
25    /// - `persistence`: Optional persistence handler for background saves.
26    pub fn new(initial_data: StoreData, persistence: Option<Arc<Persistence>>) -> Self {
27        Self {
28            data: RwLock::new(initial_data),
29            persistence,
30            pending_tasks: Arc::new(AtomicUsize::new(0)),
31        }
32    }
33
34    /// Waits for all background persistence tasks to complete.
35    /// 
36    /// This is useful during graceful shutdown to ensure no data is lost.
37    pub async fn wait(&self) {
38        while self.pending_tasks.load(Ordering::SeqCst) > 0 {
39            tokio::time::sleep(std::time::Duration::from_millis(10)).await;
40        }
41    }
42
43    fn copy_persona_data(&self, persona_id: &str) -> Option<HashMap<String, HashMap<String, serde_json::Value>>> {
44        let data = self.data.read().unwrap();
45        data.get(persona_id).cloned()
46    }
47
48    async fn persist(&self, persona_id: String) {
49        if let Some(p) = &self.persistence {
50            if let Some(persona_data) = self.copy_persona_data(&persona_id) {
51                let p = p.clone();
52                let pending = self.pending_tasks.clone();
53                pending.fetch_add(1, Ordering::SeqCst);
54                tokio::task::spawn_blocking(move || {
55                    if let Err(e) = p.save_persona(&persona_id, &persona_data) {
56                        log::error!("Failed to persist persona {}: {}", persona_id, e);
57                    }
58                    pending.fetch_sub(1, Ordering::SeqCst);
59                });
60            }
61        }
62    }
63}
64
65#[async_trait]
66impl KVReader for MemStore {
67    async fn get(&self, persona_id: &str, app_id: &str, key: &str) -> Result<serde_json::Value> {
68        let data = self.data.read().unwrap();
69        let persona = data.get(persona_id);
70        
71        if persona.is_none() {
72            // Log for debugging if needed
73            // log::debug!("Persona {} not found, checking legacy or empty", persona_id);
74            return Err(Error::PersonaNotFound);
75        }
76
77        persona.unwrap()
78            .get(app_id)
79            .ok_or(Error::AppNotFound)?
80            .get(key)
81            .cloned()
82            .ok_or(Error::KeyNotFound)
83    }
84}
85
86#[async_trait]
87impl KVWriter for MemStore {
88    async fn set(&self, persona_id: &str, app_id: &str, key: &str, value: serde_json::Value) -> Result<()> {
89        {
90            let mut data = self.data.write().unwrap();
91            let persona = data.entry(persona_id.to_string()).or_default();
92            let app = persona.entry(app_id.to_string()).or_default();
93            app.insert(key.to_string(), value);
94        }
95        self.persist(persona_id.to_string()).await;
96        Ok(())
97    }
98
99    async fn delete(&self, persona_id: &str, app_id: &str, key: &str) -> Result<()> {
100        {
101            let mut data = self.data.write().unwrap();
102            if let Some(persona) = data.get_mut(persona_id) {
103                if let Some(app) = persona.get_mut(app_id) {
104                    app.remove(key);
105                }
106            }
107        }
108        self.persist(persona_id.to_string()).await;
109        Ok(())
110    }
111}
112
113#[async_trait]
114impl AppEnumeration for MemStore {
115    async fn get_personas(&self) -> Result<Vec<String>> {
116        let data = self.data.read().unwrap();
117        Ok(data.keys().cloned().collect())
118    }
119
120    async fn get_apps(&self, persona_id: &str) -> Result<Vec<String>> {
121        let data = self.data.read().unwrap();
122        Ok(data.get(persona_id)
123            .map(|p| p.keys().cloned().collect())
124            .unwrap_or_default())
125    }
126}
127
128#[async_trait]
129impl BatchExporter for MemStore {
130    async fn get_app_store(&self, persona_id: &str, app_id: &str) -> Result<HashMap<String, serde_json::Value>> {
131        let data = self.data.read().unwrap();
132        data.get(persona_id)
133            .ok_or(Error::PersonaNotFound)?
134            .get(app_id)
135            .cloned()
136            .ok_or(Error::AppNotFound)
137    }
138
139    async fn dump_app(&self, app_id: &str) -> Result<HashMap<String, HashMap<String, serde_json::Value>>> {
140        let data = self.data.read().unwrap();
141        let mut result = HashMap::new();
142        for (persona_id, apps) in data.iter() {
143            if let Some(app_data) = apps.get(app_id) {
144                result.insert(persona_id.clone(), app_data.clone());
145            }
146        }
147        Ok(result)
148    }
149}
150
151#[async_trait]
152impl GlobalSearcher for MemStore {
153    async fn get_global(&self, app_id: &str, key: &str) -> Result<(serde_json::Value, String)> {
154        let data = self.data.read().unwrap();
155        for (persona_id, apps) in data.iter() {
156            if let Some(app_data) = apps.get(app_id) {
157                if let Some(val) = app_data.get(key) {
158                    return Ok((val.clone(), persona_id.clone()));
159                }
160            }
161        }
162        Err(Error::KeyNotFound)
163    }
164}
165
166#[async_trait]
167impl Orchestrator for MemStore {
168    async fn move_key(&self, src_persona: &str, dst_persona: &str, app_id: &str, key: &str) -> Result<()> {
169        let val = {
170            let mut data = self.data.write().unwrap();
171            let src_persona_data = data.get_mut(src_persona).ok_or(Error::PersonaNotFound)?;
172            let src_app_data = src_persona_data.get_mut(app_id).ok_or(Error::AppNotFound)?;
173            src_app_data.remove(key).ok_or(Error::KeyNotFound)?
174        };
175
176        self.set(dst_persona, app_id, key, val).await?;
177        self.persist(src_persona.to_string()).await;
178        
179        Ok(())
180    }
181}
182
183impl CelerixStore for MemStore {
184    fn app(&self, persona_id: &str, app_id: &str) -> Box<dyn AppScope + '_> {
185        Box::new(MemAppScope {
186            store: self,
187            persona_id: persona_id.to_string(),
188            app_id: app_id.to_string(),
189        })
190    }
191}
192
193pub struct MemAppScope<'a> {
194    store: &'a MemStore,
195    persona_id: String,
196    app_id: String,
197}
198
199#[async_trait]
200impl<'a> AppScope for MemAppScope<'a> {
201    async fn get(&self, key: &str) -> Result<serde_json::Value> {
202        self.store.get(&self.persona_id, &self.app_id, key).await
203    }
204
205    async fn set(&self, key: &str, value: serde_json::Value) -> Result<()> {
206        self.store.set(&self.persona_id, &self.app_id, key, value).await
207    }
208
209    async fn delete(&self, key: &str) -> Result<()> {
210        self.store.delete(&self.persona_id, &self.app_id, key).await
211    }
212
213    fn vault(&self, master_key: &[u8]) -> Box<dyn VaultScope + '_> {
214        Box::new(MemVaultScope {
215            app: self,
216            master_key: master_key.to_vec(),
217        })
218    }
219}
220
221pub struct MemVaultScope<'a> {
222    app: &'a MemAppScope<'a>,
223    master_key: Vec<u8>,
224}
225
226#[async_trait]
227impl<'a> VaultScope for MemVaultScope<'a> {
228    async fn get(&self, key: &str) -> Result<String> {
229        let val = self.app.get(key).await?;
230        let cipher_hex = val.as_str().ok_or_else(|| Error::Internal("Vault data is not a string".to_string()))?;
231        vault::decrypt(cipher_hex, &self.master_key)
232    }
233
234    async fn set(&self, key: &str, plaintext: &str) -> Result<()> {
235        let cipher_hex = vault::encrypt(plaintext, &self.master_key)?;
236        self.app.set(key, serde_json::Value::String(cipher_hex)).await
237    }
238}
239
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[tokio::test]
    async fn test_memstore_get_set() {
        let store = MemStore::new(HashMap::new(), None);
        store.set("p1", "app1", "k1", json!("v1")).await.unwrap();

        let val = store.get("p1", "app1", "k1").await.unwrap();
        assert_eq!(val, json!("v1"));
    }

    #[tokio::test]
    async fn test_memstore_delete() {
        let store = MemStore::new(HashMap::new(), None);
        store.set("p1", "app1", "k1", json!("v1")).await.unwrap();
        store.delete("p1", "app1", "k1").await.unwrap();

        let res = store.get("p1", "app1", "k1").await;
        assert!(matches!(res, Err(Error::KeyNotFound)));
    }

    #[tokio::test]
    async fn test_move_key() {
        let store = MemStore::new(HashMap::new(), None);
        store.set("p1", "app1", "k1", json!("v1")).await.unwrap();
        store.move_key("p1", "p2", "app1", "k1").await.unwrap();

        assert!(matches!(store.get("p1", "app1", "k1").await, Err(Error::KeyNotFound)));
        assert_eq!(store.get("p2", "app1", "k1").await.unwrap(), json!("v1"));
    }

    #[tokio::test]
    async fn test_app_scope_and_vault() {
        let store = MemStore::new(HashMap::new(), None);
        let master_key = b"thisis32byteslongsecretkey123456";

        let scope = store.app("p1", "a1");
        scope.set("secret", json!("hidden")).await.unwrap();

        let val = scope.get("secret").await.unwrap();
        assert_eq!(val, json!("hidden"));

        let v = scope.vault(master_key);
        v.set("password", "topsecret").await.unwrap();

        let pass = v.get("password").await.unwrap();
        assert_eq!(pass, "topsecret");

        // Check that it's encrypted in the underlying store
        let raw = scope.get("password").await.unwrap();
        assert_ne!(raw, json!("topsecret"));
        assert!(raw.is_string());
    }

    /// Each missing level of the persona → app → key hierarchy maps to its own error.
    #[tokio::test]
    async fn test_get_error_variants() {
        let store = MemStore::new(HashMap::new(), None);
        assert!(matches!(store.get("nobody", "app1", "k1").await, Err(Error::PersonaNotFound)));

        store.set("p1", "app1", "k1", json!(1)).await.unwrap();
        assert!(matches!(store.get("p1", "other", "k1").await, Err(Error::AppNotFound)));
        assert!(matches!(store.get("p1", "app1", "missing").await, Err(Error::KeyNotFound)));
    }

    /// Deleting something that does not exist succeeds and leaves other data alone.
    #[tokio::test]
    async fn test_delete_missing_key_is_ok() {
        let store = MemStore::new(HashMap::new(), None);
        store.delete("nobody", "app1", "k1").await.unwrap();

        store.set("p1", "app1", "k1", json!("v1")).await.unwrap();
        store.delete("p1", "app1", "absent").await.unwrap();
        assert_eq!(store.get("p1", "app1", "k1").await.unwrap(), json!("v1"));
    }

    /// A failed move reports the missing level and does not touch either persona.
    #[tokio::test]
    async fn test_move_key_errors_leave_data_intact() {
        let store = MemStore::new(HashMap::new(), None);
        assert!(matches!(
            store.move_key("nobody", "p2", "app1", "k1").await,
            Err(Error::PersonaNotFound)
        ));

        store.set("p1", "app1", "k1", json!("v1")).await.unwrap();
        assert!(matches!(
            store.move_key("p1", "p2", "other", "k1").await,
            Err(Error::AppNotFound)
        ));
        assert!(matches!(
            store.move_key("p1", "p2", "app1", "absent").await,
            Err(Error::KeyNotFound)
        ));

        assert_eq!(store.get("p1", "app1", "k1").await.unwrap(), json!("v1"));
        assert!(matches!(store.get("p2", "app1", "k1").await, Err(Error::PersonaNotFound)));
    }

    /// Moving a key onto the same persona keeps the value in place.
    #[tokio::test]
    async fn test_move_key_within_same_persona() {
        let store = MemStore::new(HashMap::new(), None);
        store.set("p1", "app1", "k1", json!("v1")).await.unwrap();
        store.move_key("p1", "p1", "app1", "k1").await.unwrap();
        assert_eq!(store.get("p1", "app1", "k1").await.unwrap(), json!("v1"));
    }

    #[tokio::test]
    async fn test_enumeration_export_and_global_search() {
        let store = MemStore::new(HashMap::new(), None);
        assert!(store.get_personas().await.unwrap().is_empty());
        assert!(store.get_apps("nobody").await.unwrap().is_empty());

        store.set("p1", "app1", "k1", json!("a")).await.unwrap();
        store.set("p2", "app1", "k2", json!("b")).await.unwrap();
        store.set("p2", "app2", "k3", json!("c")).await.unwrap();

        let mut personas = store.get_personas().await.unwrap();
        personas.sort();
        assert_eq!(personas, vec!["p1", "p2"]);

        let mut apps = store.get_apps("p2").await.unwrap();
        apps.sort();
        assert_eq!(apps, vec!["app1", "app2"]);

        let app_store = store.get_app_store("p1", "app1").await.unwrap();
        assert_eq!(app_store.len(), 1);
        assert_eq!(app_store.get("k1"), Some(&json!("a")));
        assert!(matches!(store.get_app_store("nobody", "app1").await, Err(Error::PersonaNotFound)));
        assert!(matches!(store.get_app_store("p1", "app2").await, Err(Error::AppNotFound)));

        let dump = store.dump_app("app1").await.unwrap();
        assert_eq!(dump.len(), 2);
        assert_eq!(dump["p1"].get("k1"), Some(&json!("a")));
        assert_eq!(dump["p2"].get("k2"), Some(&json!("b")));
        assert!(store.dump_app("unknown").await.unwrap().is_empty());

        let (value, owner) = store.get_global("app2", "k3").await.unwrap();
        assert_eq!(value, json!("c"));
        assert_eq!(owner, "p2");
        assert!(matches!(store.get_global("app1", "nope").await, Err(Error::KeyNotFound)));
    }

    /// Reading a non-string value through the vault is reported as an internal error.
    #[tokio::test]
    async fn test_vault_rejects_non_string_value() {
        let store = MemStore::new(HashMap::new(), None);
        let scope = store.app("p1", "a1");
        scope.set("plain", json!(42)).await.unwrap();

        let v = scope.vault(b"thisis32byteslongsecretkey123456");
        assert!(matches!(v.get("plain").await, Err(Error::Internal(_))));
    }

    /// Without persistence no background tasks exist, so `wait` returns immediately.
    #[tokio::test]
    async fn test_wait_without_persistence() {
        let store = MemStore::new(HashMap::new(), None);
        store.set("p1", "app1", "k1", json!("v1")).await.unwrap();
        store.wait().await;
        assert_eq!(store.get("p1", "app1", "k1").await.unwrap(), json!("v1"));
    }
}