Skip to main content

iris_chat_protocol/
storage.rs

1use super::SharedConnection;
2use std::collections::HashMap;
3use std::fs;
4use std::path::PathBuf;
5use std::sync::{Arc, Mutex};
6use std::time::{Duration, Instant};
7
/// Error produced by any [`StorageAdapter`] operation.
///
/// It carries nothing but a human-readable message; the message is what
/// `Display` prints.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StorageError {
    message: String,
}

impl StorageError {
    /// Builds an error from anything convertible into a `String`.
    pub fn new(message: impl Into<String>) -> Self {
        let message: String = message.into();
        StorageError { message }
    }
}

impl std::fmt::Display for StorageError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `pad` is what `str`'s `Display` uses, so width/fill/precision flags
        // apply to the message exactly as they would to a plain string.
        f.pad(&self.message)
    }
}

impl std::error::Error for StorageError {}

/// Shorthand for results returned by the storage layer.
pub type StorageResult<T> = Result<T, StorageError>;

/// String key/value store abstraction shared by the in-memory, file and
/// SQLite backends in this module.
///
/// Implementors must be `Send + Sync`, so a single adapter can be shared
/// across threads.
pub trait StorageAdapter: Send + Sync {
    /// Looks up `key`. The implementations in this module return `Ok(None)`
    /// when the key is absent.
    fn get(&self, key: &str) -> StorageResult<Option<String>>;
    /// Stores `value` under `key`, overwriting any previous value.
    fn put(&self, key: &str, value: String) -> StorageResult<()>;
    /// Removes `key`. The implementations in this module treat deleting a
    /// missing key as success.
    fn del(&self, key: &str) -> StorageResult<()>;
    /// Returns keys beginning with `prefix`, in no guaranteed order.
    fn list(&self, prefix: &str) -> StorageResult<Vec<String>>;
}
37
38#[derive(Clone)]
39pub struct InMemoryStorage {
40    store: Arc<Mutex<HashMap<String, String>>>,
41}
42
43impl InMemoryStorage {
44    pub fn new() -> Self {
45        Self {
46            store: Arc::new(Mutex::new(HashMap::new())),
47        }
48    }
49}
50
51impl Default for InMemoryStorage {
52    fn default() -> Self {
53        Self::new()
54    }
55}
56
57impl StorageAdapter for InMemoryStorage {
58    fn get(&self, key: &str) -> StorageResult<Option<String>> {
59        let store = self
60            .store
61            .lock()
62            .map_err(|_| StorageError::new("storage mutex poisoned"))?;
63        Ok(store.get(key).cloned())
64    }
65
66    fn put(&self, key: &str, value: String) -> StorageResult<()> {
67        let mut store = self
68            .store
69            .lock()
70            .map_err(|_| StorageError::new("storage mutex poisoned"))?;
71        store.insert(key.to_string(), value);
72        Ok(())
73    }
74
75    fn del(&self, key: &str) -> StorageResult<()> {
76        let mut store = self
77            .store
78            .lock()
79            .map_err(|_| StorageError::new("storage mutex poisoned"))?;
80        store.remove(key);
81        Ok(())
82    }
83
84    fn list(&self, prefix: &str) -> StorageResult<Vec<String>> {
85        let store = self
86            .store
87            .lock()
88            .map_err(|_| StorageError::new("storage mutex poisoned"))?;
89        Ok(store
90            .keys()
91            .filter(|key| key.starts_with(prefix))
92            .cloned()
93            .collect())
94    }
95}
96
97pub struct FileStorageAdapter {
98    base_path: PathBuf,
99}
100
101impl FileStorageAdapter {
102    pub fn new(base_path: PathBuf) -> StorageResult<Self> {
103        fs::create_dir_all(&base_path)
104            .map_err(|err| storage_io_error("failed to create storage directory", err))?;
105        Ok(Self { base_path })
106    }
107
108    fn sanitize_key(key: &str) -> String {
109        key.replace(['/', '\\', ':'], "_")
110    }
111
112    fn key_to_path(&self, key: &str) -> PathBuf {
113        let sanitized = Self::sanitize_key(key);
114        self.base_path.join(format!("{}.json", sanitized))
115    }
116}
117
118impl StorageAdapter for FileStorageAdapter {
119    fn get(&self, key: &str) -> StorageResult<Option<String>> {
120        let path = self.key_to_path(key);
121        match fs::read_to_string(&path) {
122            Ok(contents) => Ok(Some(contents)),
123            Err(err) if err.kind() == std::io::ErrorKind::NotFound => Ok(None),
124            Err(err) => Err(storage_io_error("failed to read storage file", err)),
125        }
126    }
127
128    fn put(&self, key: &str, value: String) -> StorageResult<()> {
129        let path = self.key_to_path(key);
130
131        if let Some(parent) = path.parent() {
132            fs::create_dir_all(parent).map_err(|err| {
133                storage_io_error("failed to create storage parent directory", err)
134            })?;
135        }
136
137        let tmp_path = path.with_extension(format!("json.{}.tmp", rand::random::<u128>()));
138        fs::write(&tmp_path, value)
139            .map_err(|err| storage_io_error("failed to write storage temp file", err))?;
140
141        #[cfg(windows)]
142        {
143            if path.exists() {
144                fs::remove_file(&path).map_err(|err| {
145                    storage_io_error("failed to replace existing storage file", err)
146                })?;
147            }
148        }
149
150        fs::rename(&tmp_path, &path)
151            .map_err(|err| storage_io_error("failed to commit storage file", err))?;
152
153        Ok(())
154    }
155
156    fn del(&self, key: &str) -> StorageResult<()> {
157        let path = self.key_to_path(key);
158        match fs::remove_file(&path) {
159            Ok(()) => Ok(()),
160            Err(err) if err.kind() == std::io::ErrorKind::NotFound => Ok(()),
161            Err(err) => Err(storage_io_error("failed to delete storage file", err)),
162        }
163    }
164
165    fn list(&self, prefix: &str) -> StorageResult<Vec<String>> {
166        let mut keys = Vec::new();
167        let sanitized_prefix = Self::sanitize_key(prefix);
168        let entries = fs::read_dir(&self.base_path)
169            .map_err(|err| storage_io_error("failed to read storage directory", err))?;
170
171        for entry in entries {
172            let entry =
173                entry.map_err(|err| storage_io_error("failed to read storage entry", err))?;
174            let file_name = entry.file_name();
175            let file_name_str = file_name.to_string_lossy();
176
177            if !file_name_str.ends_with(".json") {
178                continue;
179            }
180
181            let key = file_name_str
182                .strip_suffix(".json")
183                .unwrap_or(&file_name_str)
184                .to_string();
185
186            if prefix.is_empty() {
187                keys.push(key);
188                continue;
189            }
190
191            if key.starts_with(&sanitized_prefix) {
192                let remainder = key.strip_prefix(&sanitized_prefix).unwrap_or("");
193                keys.push(format!("{}{}", prefix, remainder));
194            }
195        }
196
197        Ok(keys)
198    }
199}
200
201pub struct DebouncedFileStorage {
202    adapter: FileStorageAdapter,
203    pending_writes: Mutex<HashMap<String, String>>,
204    last_flush: Mutex<Instant>,
205    flush_interval: Duration,
206}
207
208impl DebouncedFileStorage {
209    pub fn new(base_path: PathBuf, flush_interval_ms: u64) -> StorageResult<Self> {
210        Ok(Self {
211            adapter: FileStorageAdapter::new(base_path)?,
212            pending_writes: Mutex::new(HashMap::new()),
213            last_flush: Mutex::new(Instant::now()),
214            flush_interval: Duration::from_millis(flush_interval_ms),
215        })
216    }
217
218    pub fn flush(&self) -> StorageResult<()> {
219        let mut pending = self
220            .pending_writes
221            .lock()
222            .map_err(|_| StorageError::new("pending file storage mutex poisoned"))?;
223        for (key, value) in pending.drain() {
224            self.adapter.put(&key, value)?;
225        }
226        *self
227            .last_flush
228            .lock()
229            .map_err(|_| StorageError::new("file storage flush mutex poisoned"))? = Instant::now();
230        Ok(())
231    }
232
233    fn maybe_flush(&self) -> StorageResult<()> {
234        let last_flush = *self
235            .last_flush
236            .lock()
237            .map_err(|_| StorageError::new("file storage flush mutex poisoned"))?;
238        let pending_count = self
239            .pending_writes
240            .lock()
241            .map_err(|_| StorageError::new("pending file storage mutex poisoned"))?
242            .len();
243
244        if last_flush.elapsed() >= self.flush_interval && pending_count > 0 {
245            self.flush()?;
246        }
247        Ok(())
248    }
249}
250
251impl StorageAdapter for DebouncedFileStorage {
252    fn get(&self, key: &str) -> StorageResult<Option<String>> {
253        let pending = self
254            .pending_writes
255            .lock()
256            .map_err(|_| StorageError::new("pending file storage mutex poisoned"))?;
257        if let Some(value) = pending.get(key) {
258            return Ok(Some(value.clone()));
259        }
260        drop(pending);
261        self.adapter.get(key)
262    }
263
264    fn put(&self, key: &str, value: String) -> StorageResult<()> {
265        self.pending_writes
266            .lock()
267            .map_err(|_| StorageError::new("pending file storage mutex poisoned"))?
268            .insert(key.to_string(), value);
269        self.maybe_flush()
270    }
271
272    fn del(&self, key: &str) -> StorageResult<()> {
273        self.pending_writes
274            .lock()
275            .map_err(|_| StorageError::new("pending file storage mutex poisoned"))?
276            .remove(key);
277        self.adapter.del(key)
278    }
279
280    fn list(&self, prefix: &str) -> StorageResult<Vec<String>> {
281        let mut keys = self.adapter.list(prefix)?;
282        let pending = self
283            .pending_writes
284            .lock()
285            .map_err(|_| StorageError::new("pending file storage mutex poisoned"))?;
286
287        for key in pending.keys() {
288            if key.starts_with(prefix) && !keys.contains(key) {
289                keys.push(key.clone());
290            }
291        }
292
293        Ok(keys)
294    }
295}
296
297/// SQLite-backed implementation of `iris_chat_protocol::StorageAdapter`.
298/// Keys are namespaced by (owner_pubkey_hex, device_pubkey_hex) so a
299/// single database serves multiple owner accounts and devices without
300/// keyspace collisions, matching the per-(owner, device) directory
301/// scoping the previous file-backed adapter used.
302pub struct SqliteStorageAdapter {
303    conn: SharedConnection,
304    owner_pubkey_hex: String,
305    device_pubkey_hex: String,
306}
307
308impl SqliteStorageAdapter {
309    pub fn new(
310        conn: SharedConnection,
311        owner_pubkey_hex: String,
312        device_pubkey_hex: String,
313    ) -> Self {
314        Self {
315            conn,
316            owner_pubkey_hex,
317            device_pubkey_hex,
318        }
319    }
320
321    fn map_err<E: std::fmt::Display>(error: E) -> StorageError {
322        StorageError::new(error.to_string())
323    }
324}
325
326impl StorageAdapter for SqliteStorageAdapter {
327    fn get(&self, key: &str) -> StorageResult<Option<String>> {
328        let conn = self
329            .conn
330            .lock()
331            .map_err(|_| StorageError::new("ndr_kv connection mutex poisoned"))?;
332        conn.query_row(
333            "SELECT value FROM ndr_kv WHERE owner_pubkey_hex = ?1 AND device_pubkey_hex = ?2 AND key = ?3",
334            (&self.owner_pubkey_hex, &self.device_pubkey_hex, key),
335            |row| row.get::<_, String>(0),
336        )
337        .map(Some)
338        .or_else(|err| match err {
339            rusqlite::Error::QueryReturnedNoRows => Ok(None),
340            other => Err(Self::map_err(other)),
341        })
342    }
343
344    fn put(&self, key: &str, value: String) -> StorageResult<()> {
345        let conn = self
346            .conn
347            .lock()
348            .map_err(|_| StorageError::new("ndr_kv connection mutex poisoned"))?;
349        conn.execute(
350            "INSERT INTO ndr_kv (owner_pubkey_hex, device_pubkey_hex, key, value)
351             VALUES (?1, ?2, ?3, ?4)
352             ON CONFLICT(owner_pubkey_hex, device_pubkey_hex, key) DO UPDATE SET value = excluded.value",
353            (&self.owner_pubkey_hex, &self.device_pubkey_hex, key, &value),
354        )
355        .map_err(Self::map_err)?;
356        Ok(())
357    }
358
359    fn del(&self, key: &str) -> StorageResult<()> {
360        let conn = self
361            .conn
362            .lock()
363            .map_err(|_| StorageError::new("ndr_kv connection mutex poisoned"))?;
364        conn.execute(
365            "DELETE FROM ndr_kv WHERE owner_pubkey_hex = ?1 AND device_pubkey_hex = ?2 AND key = ?3",
366            (&self.owner_pubkey_hex, &self.device_pubkey_hex, key),
367        )
368        .map_err(Self::map_err)?;
369        Ok(())
370    }
371
372    fn list(&self, prefix: &str) -> StorageResult<Vec<String>> {
373        let conn = self
374            .conn
375            .lock()
376            .map_err(|_| StorageError::new("ndr_kv connection mutex poisoned"))?;
377        let mut stmt = conn
378            .prepare(
379                "SELECT key FROM ndr_kv
380                 WHERE owner_pubkey_hex = ?1 AND device_pubkey_hex = ?2 AND key LIKE ?3 ESCAPE '\\'",
381            )
382            .map_err(Self::map_err)?;
383        let pattern = format!("{}%", escape_like(prefix));
384        let rows = stmt
385            .query_map(
386                (&self.owner_pubkey_hex, &self.device_pubkey_hex, &pattern),
387                |row| row.get::<_, String>(0),
388            )
389            .map_err(Self::map_err)?;
390        let mut keys = Vec::new();
391        for row in rows {
392            keys.push(row.map_err(Self::map_err)?);
393        }
394        Ok(keys)
395    }
396}
397
/// Escapes `input` for use in a SQL `LIKE` pattern with `ESCAPE '\'`.
///
/// Each `\`, `%` and `_` gets a leading backslash so it matches literally;
/// all other characters pass through unchanged.
fn escape_like(input: &str) -> String {
    input
        .chars()
        .fold(String::with_capacity(input.len()), |mut escaped, ch| {
            if matches!(ch, '\\' | '%' | '_') {
                escaped.push('\\');
            }
            escaped.push(ch);
            escaped
        })
}
411
412fn storage_io_error(context: &str, error: std::io::Error) -> StorageError {
413    StorageError::new(format!("{}: {}", context, error))
414}
415
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Arc, Mutex};
    use tempfile::TempDir;

    /// Opens an in-memory SQLite database containing only the `ndr_kv` table.
    fn fresh_connection() -> SharedConnection {
        let conn = rusqlite::Connection::open_in_memory().unwrap();
        conn.execute_batch(
            "CREATE TABLE ndr_kv (
                owner_pubkey_hex TEXT NOT NULL,
                device_pubkey_hex TEXT NOT NULL,
                key TEXT NOT NULL,
                value TEXT NOT NULL,
                PRIMARY KEY (owner_pubkey_hex, device_pubkey_hex, key)
            );",
        )
        .unwrap();
        Arc::new(Mutex::new(conn))
    }

    /// SQLite adapter for a fixed `owner` / `device` pair on a fresh database.
    fn fresh_adapter() -> SqliteStorageAdapter {
        SqliteStorageAdapter::new(
            fresh_connection(),
            String::from("owner"),
            String::from("device"),
        )
    }

    /// Writes the two `user/` keys and one `invite/` key used by the listing tests.
    fn seed(adapter: &dyn StorageAdapter) {
        for &(key, value) in &[("user/alice", "1"), ("user/bob", "2"), ("invite/charlie", "3")] {
            adapter.put(key, value.to_string()).unwrap();
        }
    }

    /// Sorts keys so assertions don't depend on backend iteration order.
    fn sorted(mut keys: Vec<String>) -> Vec<String> {
        keys.sort();
        keys
    }

    #[test]
    fn put_get_del_round_trip() {
        let store = fresh_adapter();
        let read = |key: &str| store.get(key).unwrap();

        assert_eq!(read("k"), None);
        store.put("k", "v".into()).unwrap();
        assert_eq!(read("k").as_deref(), Some("v"));
        store.put("k", "v2".into()).unwrap();
        assert_eq!(read("k").as_deref(), Some("v2"));
        store.del("k").unwrap();
        assert_eq!(read("k"), None);
    }

    #[test]
    fn list_returns_only_matching_prefix() {
        let store = fresh_adapter();
        seed(&store);
        assert_eq!(sorted(store.list("user/").unwrap()), ["user/alice", "user/bob"]);
    }

    #[test]
    fn keys_are_isolated_per_owner_device() {
        let shared = fresh_connection();
        let alice =
            SqliteStorageAdapter::new(Arc::clone(&shared), "owner_a".into(), "device_a".into());
        let bob = SqliteStorageAdapter::new(shared, "owner_b".into(), "device_b".into());

        alice.put("shared-key", "alice".into()).unwrap();
        bob.put("shared-key", "bob".into()).unwrap();

        assert_eq!(alice.get("shared-key").unwrap().as_deref(), Some("alice"));
        assert_eq!(bob.get("shared-key").unwrap().as_deref(), Some("bob"));
    }

    #[test]
    fn file_storage_round_trips_values() {
        let dir = TempDir::new().unwrap();
        let files = FileStorageAdapter::new(dir.path().to_path_buf()).unwrap();

        assert_eq!(files.get("test-key").unwrap(), None);

        files.put("test-key", String::from("test-value")).unwrap();
        assert_eq!(files.get("test-key").unwrap().as_deref(), Some("test-value"));

        files.del("test-key").unwrap();
        assert_eq!(files.get("test-key").unwrap(), None);
    }

    #[test]
    fn file_storage_lists_sanitized_runtime_keys() {
        let dir = TempDir::new().unwrap();
        let files = FileStorageAdapter::new(dir.path().to_path_buf()).unwrap();
        seed(&files);

        // With a prefix, the caller's prefix is restored in front of each key.
        assert_eq!(sorted(files.list("user/").unwrap()), ["user/alice", "user/bob"]);

        // Without one, the raw sanitized file names come back.
        assert_eq!(
            sorted(files.list("").unwrap()),
            ["invite_charlie", "user_alice", "user_bob"]
        );
    }

    #[test]
    fn debounced_file_storage_reads_pending_writes_and_flushes() {
        let dir = TempDir::new().unwrap();
        let storage = DebouncedFileStorage::new(dir.path().to_path_buf(), 1000).unwrap();

        storage.put("key1", "value1".to_string()).unwrap();

        // Within the 1s interval the write stays buffered but is still readable.
        assert_eq!(storage.get("key1").unwrap().as_deref(), Some("value1"));
        assert!(storage.pending_writes.lock().unwrap().contains_key("key1"));

        storage.flush().unwrap();

        assert!(storage.pending_writes.lock().unwrap().is_empty());
        assert_eq!(storage.adapter.get("key1").unwrap().as_deref(), Some("value1"));
    }
}