Skip to main content

diaryx_sync/
memory_storage.rs

//! In-memory storage implementation for testing and WASM.
//!
//! This provides a simple in-memory implementation of [`CrdtStorage`]
//! for use in unit tests, development, and WASM environments.
5
6use std::collections::HashMap;
7use std::sync::{Arc, RwLock};
8
9use yrs::{Doc, ReadTxn, Transact, Update, updates::decoder::Decode};
10
11use crate::crdt_storage::{CrdtStorage, CrdtUpdate, StorageResult, UpdateOrigin};
12
/// Number of logged updates for a single document above which the next
/// append triggers auto-compaction.
const AUTO_COMPACT_THRESHOLD: usize = 1000;

/// Number of most-recent updates retained when auto-compaction runs.
const AUTO_COMPACT_KEEP: usize = 500;

// Auto-compaction drains `len - AUTO_COMPACT_KEEP` entries once
// `len > AUTO_COMPACT_THRESHOLD`. Keeping KEEP strictly below THRESHOLD means
// that subtraction can never underflow and every compaction frees real room.
// Checked at compile time so retuning the constants cannot break it silently.
const _: () = assert!(AUTO_COMPACT_KEEP < AUTO_COMPACT_THRESHOLD);
18
19/// In-memory CRDT storage for testing.
20///
21/// This implementation stores all data in memory using `HashMap` and `Vec`.
22/// It's thread-safe via `RwLock` but data is lost when dropped.
23///
24/// Auto-compaction is triggered when the number of updates for a document
25/// exceeds [`AUTO_COMPACT_THRESHOLD`], keeping the most recent
26/// [`AUTO_COMPACT_KEEP`] updates.
27#[derive(Debug, Default)]
28pub struct MemoryStorage {
29    /// Document snapshots (name -> binary state)
30    docs: Arc<RwLock<HashMap<String, Vec<u8>>>>,
31
32    /// Update logs (name -> list of updates)
33    updates: Arc<RwLock<HashMap<String, Vec<StoredUpdate>>>>,
34
35    /// Counter for generating update IDs
36    next_id: Arc<RwLock<i64>>,
37}
38
39#[derive(Debug, Clone)]
40struct StoredUpdate {
41    id: i64,
42    data: Vec<u8>,
43    timestamp: i64,
44    origin: UpdateOrigin,
45    device_id: Option<String>,
46    device_name: Option<String>,
47}
48
49impl MemoryStorage {
50    /// Create a new empty in-memory storage.
51    pub fn new() -> Self {
52        Self::default()
53    }
54
55    fn next_update_id(&self) -> i64 {
56        let mut id = self.next_id.write().unwrap();
57        *id += 1;
58        *id
59    }
60}
61
62impl CrdtStorage for MemoryStorage {
63    fn load_doc(&self, name: &str) -> StorageResult<Option<Vec<u8>>> {
64        let docs = self.docs.read().unwrap();
65        Ok(docs.get(name).cloned())
66    }
67
68    fn save_doc(&self, name: &str, state: &[u8]) -> StorageResult<()> {
69        let mut docs = self.docs.write().unwrap();
70        docs.insert(name.to_string(), state.to_vec());
71        Ok(())
72    }
73
74    fn delete_doc(&self, name: &str) -> StorageResult<()> {
75        let mut docs = self.docs.write().unwrap();
76        let mut updates = self.updates.write().unwrap();
77        docs.remove(name);
78        updates.remove(name);
79        Ok(())
80    }
81
82    fn list_docs(&self) -> StorageResult<Vec<String>> {
83        let docs = self.docs.read().unwrap();
84        Ok(docs.keys().cloned().collect())
85    }
86
87    fn append_update_with_device(
88        &self,
89        name: &str,
90        update: &[u8],
91        origin: UpdateOrigin,
92        device_id: Option<&str>,
93        device_name: Option<&str>,
94    ) -> StorageResult<i64> {
95        let id = self.next_update_id();
96        let stored = StoredUpdate {
97            id,
98            data: update.to_vec(),
99            timestamp: crate::time::now_timestamp_millis(),
100            origin,
101            device_id: device_id.map(String::from),
102            device_name: device_name.map(String::from),
103        };
104
105        let mut updates = self.updates.write().unwrap();
106        let doc_updates = updates.entry(name.to_string()).or_default();
107        doc_updates.push(stored);
108
109        // Auto-compact if we've exceeded the threshold
110        if doc_updates.len() > AUTO_COMPACT_THRESHOLD {
111            let drain_count = doc_updates.len() - AUTO_COMPACT_KEEP;
112            doc_updates.drain(0..drain_count);
113        }
114
115        Ok(id)
116    }
117
118    fn get_updates_since(&self, name: &str, since_id: i64) -> StorageResult<Vec<CrdtUpdate>> {
119        let updates = self.updates.read().unwrap();
120        let doc_updates = updates.get(name).map(|u| u.as_slice()).unwrap_or(&[]);
121
122        Ok(doc_updates
123            .iter()
124            .filter(|u| u.id > since_id)
125            .map(|u| CrdtUpdate {
126                update_id: u.id,
127                doc_name: name.to_string(),
128                data: u.data.clone(),
129                timestamp: u.timestamp,
130                origin: u.origin,
131                device_id: u.device_id.clone(),
132                device_name: u.device_name.clone(),
133            })
134            .collect())
135    }
136
137    fn get_all_updates(&self, name: &str) -> StorageResult<Vec<CrdtUpdate>> {
138        self.get_updates_since(name, 0)
139    }
140
141    fn get_state_at(&self, name: &str, update_id: i64) -> StorageResult<Option<Vec<u8>>> {
142        // Load base document snapshot
143        let base_state = self.load_doc(name)?;
144
145        // Get updates up to the specified ID
146        let updates_lock = self.updates.read().unwrap();
147        let doc_updates: Vec<Vec<u8>> = updates_lock
148            .get(name)
149            .map(|updates| {
150                updates
151                    .iter()
152                    .filter(|u| u.id <= update_id)
153                    .map(|u| u.data.clone())
154                    .collect()
155            })
156            .unwrap_or_default();
157
158        // If no base state and no updates, return None
159        if base_state.is_none() && doc_updates.is_empty() {
160            return Ok(None);
161        }
162
163        // Create a new doc and apply all state
164        let doc = Doc::new();
165        {
166            let mut txn = doc.transact_mut();
167
168            // Apply base state if it exists
169            if let Some(state) = &base_state
170                && let Ok(update) = Update::decode_v1(state)
171                && let Err(e) = txn.apply_update(update)
172            {
173                log::warn!("Failed to apply base state for {}: {}", name, e);
174            }
175
176            // Apply incremental updates up to the specified ID
177            for update_data in doc_updates {
178                if let Ok(update) = Update::decode_v1(&update_data)
179                    && let Err(e) = txn.apply_update(update)
180                {
181                    log::warn!("Failed to apply incremental update for {}: {}", name, e);
182                }
183            }
184        }
185
186        // Encode final state
187        let txn = doc.transact();
188        Ok(Some(txn.encode_state_as_update_v1(&Default::default())))
189    }
190
191    fn compact(&self, name: &str, keep_updates: usize) -> StorageResult<()> {
192        let mut updates = self.updates.write().unwrap();
193
194        if let Some(doc_updates) = updates.get_mut(name)
195            && doc_updates.len() > keep_updates
196        {
197            // Keep only the last `keep_updates` entries
198            let drain_count = doc_updates.len() - keep_updates;
199            doc_updates.drain(0..drain_count);
200        }
201
202        Ok(())
203    }
204
205    fn get_latest_update_id(&self, name: &str) -> StorageResult<i64> {
206        let updates = self.updates.read().unwrap();
207        Ok(updates
208            .get(name)
209            .and_then(|u| u.last())
210            .map(|u| u.id)
211            .unwrap_or(0))
212    }
213
214    fn rename_doc(&self, old_name: &str, new_name: &str) -> StorageResult<()> {
215        // Copy document snapshot
216        {
217            let mut docs = self.docs.write().unwrap();
218            if let Some(state) = docs.remove(old_name) {
219                docs.insert(new_name.to_string(), state);
220            }
221        }
222
223        // Copy updates with new doc_name
224        {
225            let mut updates = self.updates.write().unwrap();
226            if let Some(old_updates) = updates.remove(old_name) {
227                let new_updates: Vec<StoredUpdate> = old_updates
228                    .into_iter()
229                    .map(|u| StoredUpdate {
230                        id: u.id,
231                        data: u.data,
232                        timestamp: u.timestamp,
233                        origin: u.origin,
234                        device_id: u.device_id,
235                        device_name: u.device_name,
236                    })
237                    .collect();
238                updates.insert(new_name.to_string(), new_updates);
239            }
240        }
241
242        Ok(())
243    }
244
245    fn clear_updates(&self, name: &str) -> StorageResult<()> {
246        let mut updates = self.updates.write().unwrap();
247        if let Some(doc_updates) = updates.get_mut(name) {
248            doc_updates.clear();
249        }
250        Ok(())
251    }
252}
253
#[cfg(test)]
mod tests {
    use super::*;
    use yrs::{GetString, Text};

    /// Applies a v1-encoded state to a fresh `Doc` and returns the text
    /// stored under the `"content"` root.
    fn decode_text_content(state: &[u8]) -> String {
        let doc = Doc::new();
        {
            let mut txn = doc.transact_mut();
            let update = Update::decode_v1(state).unwrap();
            txn.apply_update(update).unwrap();
        }
        let text = doc.get_or_insert_text("content");
        let txn = doc.transact();
        text.get_string(&txn)
    }

    #[test]
    fn test_save_and_load_doc() {
        let storage = MemoryStorage::new();
        let data = b"test document state";

        storage.save_doc("test", data).unwrap();
        let loaded = storage.load_doc("test").unwrap();

        assert_eq!(loaded, Some(data.to_vec()));
    }

    #[test]
    fn test_load_nonexistent_doc() {
        let storage = MemoryStorage::new();
        let loaded = storage.load_doc("nonexistent").unwrap();
        assert!(loaded.is_none());
    }

    #[test]
    fn test_delete_doc() {
        let storage = MemoryStorage::new();
        storage.save_doc("test", b"data").unwrap();
        storage
            .append_update("test", b"update", UpdateOrigin::Local)
            .unwrap();

        storage.delete_doc("test").unwrap();

        assert!(storage.load_doc("test").unwrap().is_none());
        assert!(storage.get_all_updates("test").unwrap().is_empty());
    }

    #[test]
    fn test_list_docs() {
        let storage = MemoryStorage::new();
        storage.save_doc("doc1", b"data1").unwrap();
        storage.save_doc("doc2", b"data2").unwrap();

        let mut docs = storage.list_docs().unwrap();
        docs.sort();

        assert_eq!(docs, vec!["doc1", "doc2"]);
    }

    #[test]
    fn test_append_and_get_updates() {
        let storage = MemoryStorage::new();

        let id1 = storage
            .append_update("test", b"update1", UpdateOrigin::Local)
            .unwrap();
        let id2 = storage
            .append_update("test", b"update2", UpdateOrigin::Remote)
            .unwrap();
        let id3 = storage
            .append_update("test", b"update3", UpdateOrigin::Sync)
            .unwrap();

        assert!(id1 < id2);
        assert!(id2 < id3);

        let all = storage.get_all_updates("test").unwrap();
        assert_eq!(all.len(), 3);
        assert_eq!(all[0].origin, UpdateOrigin::Local);
        assert_eq!(all[1].origin, UpdateOrigin::Remote);

        let since_id1 = storage.get_updates_since("test", id1).unwrap();
        assert_eq!(since_id1.len(), 2);
        assert_eq!(since_id1[0].update_id, id2);
    }

    #[test]
    fn test_compact() {
        let storage = MemoryStorage::new();

        for i in 0..10 {
            storage
                .append_update(
                    "test",
                    format!("update{}", i).as_bytes(),
                    UpdateOrigin::Local,
                )
                .unwrap();
        }

        assert_eq!(storage.get_all_updates("test").unwrap().len(), 10);

        storage.compact("test", 3).unwrap();

        // Only the three newest updates survive, in their original order.
        let remaining = storage.get_all_updates("test").unwrap();
        assert_eq!(remaining.len(), 3);
        assert_eq!(remaining[0].data, b"update7");
        assert_eq!(remaining[2].data, b"update9");
    }

    #[test]
    fn test_auto_compaction_keeps_newest_updates() {
        let storage = MemoryStorage::new();

        // One append past the threshold triggers auto-compaction.
        let ids: Vec<i64> = (0..=AUTO_COMPACT_THRESHOLD)
            .map(|i| {
                storage
                    .append_update("test", format!("u{i}").as_bytes(), UpdateOrigin::Local)
                    .unwrap()
            })
            .collect();

        let remaining = storage.get_all_updates("test").unwrap();
        assert_eq!(remaining.len(), AUTO_COMPACT_KEEP);

        let remaining_ids: Vec<i64> = remaining.iter().map(|u| u.update_id).collect();
        assert_eq!(remaining_ids, &ids[ids.len() - AUTO_COMPACT_KEEP..]);
        assert_eq!(
            storage.get_latest_update_id("test").unwrap(),
            *ids.last().unwrap()
        );
    }

    #[test]
    fn test_get_latest_update_id() {
        let storage = MemoryStorage::new();

        assert_eq!(storage.get_latest_update_id("test").unwrap(), 0);

        let id1 = storage
            .append_update("test", b"update1", UpdateOrigin::Local)
            .unwrap();
        assert_eq!(storage.get_latest_update_id("test").unwrap(), id1);

        let id2 = storage
            .append_update("test", b"update2", UpdateOrigin::Local)
            .unwrap();
        assert_eq!(storage.get_latest_update_id("test").unwrap(), id2);
    }

    #[test]
    fn test_get_state_at_reconstructs_history() {
        let storage = MemoryStorage::new();

        // Create a Y.Doc and make some changes, storing updates
        let doc = Doc::new();
        let text = doc.get_or_insert_text("content");

        // First update: add "Hello"
        let update1 = {
            let mut txn = doc.transact_mut();
            text.insert(&mut txn, 0, "Hello");
            txn.encode_update_v1()
        };
        let id1 = storage
            .append_update("test", &update1, UpdateOrigin::Local)
            .unwrap();

        // Second update: add " World"
        let update2 = {
            let mut txn = doc.transact_mut();
            text.insert(&mut txn, 5, " World");
            txn.encode_update_v1()
        };
        let id2 = storage
            .append_update("test", &update2, UpdateOrigin::Local)
            .unwrap();

        // Third update: add "!"
        let update3 = {
            let mut txn = doc.transact_mut();
            text.insert(&mut txn, 11, "!");
            txn.encode_update_v1()
        };
        let id3 = storage
            .append_update("test", &update3, UpdateOrigin::Local)
            .unwrap();

        // Verify current state is "Hello World!"
        {
            let txn = doc.transact();
            assert_eq!(text.get_string(&txn), "Hello World!");
        }

        // Each historical point only contains the updates up to that ID.
        let state_at_1 = storage.get_state_at("test", id1).unwrap().unwrap();
        assert_eq!(decode_text_content(&state_at_1), "Hello");

        let state_at_2 = storage.get_state_at("test", id2).unwrap().unwrap();
        assert_eq!(decode_text_content(&state_at_2), "Hello World");

        let state_at_3 = storage.get_state_at("test", id3).unwrap().unwrap();
        assert_eq!(decode_text_content(&state_at_3), "Hello World!");
    }

    #[test]
    fn test_get_state_at_nonexistent() {
        let storage = MemoryStorage::new();

        // No doc, no updates - should return None
        let result = storage.get_state_at("nonexistent", 1).unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn test_rename_doc() {
        let storage = MemoryStorage::new();

        // Create a doc with state and updates
        storage.save_doc("old_name", b"test state").unwrap();
        storage
            .append_update("old_name", b"update1", UpdateOrigin::Local)
            .unwrap();
        storage
            .append_update("old_name", b"update2", UpdateOrigin::Remote)
            .unwrap();

        // Verify old name exists
        assert!(storage.load_doc("old_name").unwrap().is_some());
        assert_eq!(storage.get_all_updates("old_name").unwrap().len(), 2);

        // Rename
        storage.rename_doc("old_name", "new_name").unwrap();

        // Old name should be gone
        assert!(storage.load_doc("old_name").unwrap().is_none());
        assert!(storage.get_all_updates("old_name").unwrap().is_empty());

        // New name should have the content
        assert_eq!(
            storage.load_doc("new_name").unwrap(),
            Some(b"test state".to_vec())
        );
        let updates = storage.get_all_updates("new_name").unwrap();
        assert_eq!(updates.len(), 2);
        assert_eq!(updates[0].origin, UpdateOrigin::Local);
        assert_eq!(updates[1].origin, UpdateOrigin::Remote);
    }

    #[test]
    fn test_rename_doc_preserves_update_metadata() {
        let storage = MemoryStorage::new();

        let id = storage
            .append_update_with_device(
                "old",
                b"update",
                UpdateOrigin::Remote,
                Some("dev-1"),
                Some("Laptop"),
            )
            .unwrap();

        storage.rename_doc("old", "new").unwrap();

        let updates = storage.get_all_updates("new").unwrap();
        assert_eq!(updates.len(), 1);
        assert_eq!(updates[0].update_id, id);
        assert_eq!(updates[0].doc_name, "new");
        assert_eq!(updates[0].device_id.as_deref(), Some("dev-1"));
        assert_eq!(updates[0].device_name.as_deref(), Some("Laptop"));
        assert_eq!(storage.get_latest_update_id("new").unwrap(), id);
    }

    #[test]
    fn test_rename_doc_nonexistent() {
        let storage = MemoryStorage::new();

        // Renaming a nonexistent doc should not error
        let result = storage.rename_doc("nonexistent", "new_name");
        assert!(result.is_ok());

        // Both should be empty
        assert!(storage.load_doc("nonexistent").unwrap().is_none());
        assert!(storage.load_doc("new_name").unwrap().is_none());
    }

    #[test]
    fn test_clear_updates() {
        let storage = MemoryStorage::new();

        // Add some updates and a doc snapshot
        storage.save_doc("test", b"snapshot").unwrap();
        storage
            .append_update("test", b"update1", UpdateOrigin::Local)
            .unwrap();
        storage
            .append_update("test", b"update2", UpdateOrigin::Remote)
            .unwrap();

        // Verify updates exist
        assert_eq!(storage.get_all_updates("test").unwrap().len(), 2);

        // Clear updates
        storage.clear_updates("test").unwrap();

        // Updates should be gone but snapshot should remain
        assert!(storage.get_all_updates("test").unwrap().is_empty());
        assert_eq!(
            storage.load_doc("test").unwrap(),
            Some(b"snapshot".to_vec())
        );
    }

    #[test]
    fn test_clear_updates_nonexistent() {
        let storage = MemoryStorage::new();

        // Clearing updates for nonexistent doc should not error
        let result = storage.clear_updates("nonexistent");
        assert!(result.is_ok());
    }
}