1use std::collections::HashMap;
7use std::sync::{Arc, RwLock};
8
9use yrs::{Doc, ReadTxn, Transact, Update, updates::decoder::Decode};
10
11use crate::crdt_storage::{CrdtStorage, CrdtUpdate, StorageResult, UpdateOrigin};
12
/// Once a document's retained update log grows past this many entries,
/// `append_update_with_device` trims it automatically.
const AUTO_COMPACT_THRESHOLD: usize = 1000;

/// Number of newest updates that survive an automatic trim.
const AUTO_COMPACT_KEEP: usize = 500;

// A trim must actually shrink the log, so the retained count has to be
// strictly smaller than the threshold; reject a bad configuration at compile
// time instead of discovering it as a no-op trim (or an arithmetic underflow).
const _: () = assert!(AUTO_COMPACT_KEEP < AUTO_COMPACT_THRESHOLD);
18
/// In-memory implementation of [`CrdtStorage`].
///
/// Document snapshots and per-document update logs are kept in `HashMap`s
/// behind `RwLock`s; nothing is written anywhere else, so all contents are
/// gone once the value is dropped.
#[derive(Debug, Default)]
pub struct MemoryStorage {
    /// Most recent snapshot per document name, as last passed to `save_doc`.
    docs: Arc<RwLock<HashMap<String, Vec<u8>>>>,

    /// Incremental updates per document name, in the order they were appended.
    updates: Arc<RwLock<HashMap<String, Vec<StoredUpdate>>>>,

    /// Last update id handed out; a single counter shared by all documents.
    next_id: Arc<RwLock<i64>>,
}
38
/// One entry in a document's update log.
#[derive(Debug, Clone)]
struct StoredUpdate {
    /// Id assigned by `MemoryStorage::next_update_id` when the update was appended.
    id: i64,
    /// Raw update bytes exactly as passed to `append_update_with_device`.
    data: Vec<u8>,
    /// Value of `crate::time::now_timestamp_millis()` at append time.
    timestamp: i64,
    /// Origin reported by the caller when appending.
    origin: UpdateOrigin,
    /// Optional id of the device that produced the update.
    device_id: Option<String>,
    /// Optional device name supplied alongside `device_id`.
    device_name: Option<String>,
}
48
49impl MemoryStorage {
50 pub fn new() -> Self {
52 Self::default()
53 }
54
55 fn next_update_id(&self) -> i64 {
56 let mut id = self.next_id.write().unwrap();
57 *id += 1;
58 *id
59 }
60}
61
62impl CrdtStorage for MemoryStorage {
63 fn load_doc(&self, name: &str) -> StorageResult<Option<Vec<u8>>> {
64 let docs = self.docs.read().unwrap();
65 Ok(docs.get(name).cloned())
66 }
67
68 fn save_doc(&self, name: &str, state: &[u8]) -> StorageResult<()> {
69 let mut docs = self.docs.write().unwrap();
70 docs.insert(name.to_string(), state.to_vec());
71 Ok(())
72 }
73
74 fn delete_doc(&self, name: &str) -> StorageResult<()> {
75 let mut docs = self.docs.write().unwrap();
76 let mut updates = self.updates.write().unwrap();
77 docs.remove(name);
78 updates.remove(name);
79 Ok(())
80 }
81
82 fn list_docs(&self) -> StorageResult<Vec<String>> {
83 let docs = self.docs.read().unwrap();
84 Ok(docs.keys().cloned().collect())
85 }
86
87 fn append_update_with_device(
88 &self,
89 name: &str,
90 update: &[u8],
91 origin: UpdateOrigin,
92 device_id: Option<&str>,
93 device_name: Option<&str>,
94 ) -> StorageResult<i64> {
95 let id = self.next_update_id();
96 let stored = StoredUpdate {
97 id,
98 data: update.to_vec(),
99 timestamp: crate::time::now_timestamp_millis(),
100 origin,
101 device_id: device_id.map(String::from),
102 device_name: device_name.map(String::from),
103 };
104
105 let mut updates = self.updates.write().unwrap();
106 let doc_updates = updates.entry(name.to_string()).or_default();
107 doc_updates.push(stored);
108
109 if doc_updates.len() > AUTO_COMPACT_THRESHOLD {
111 let drain_count = doc_updates.len() - AUTO_COMPACT_KEEP;
112 doc_updates.drain(0..drain_count);
113 }
114
115 Ok(id)
116 }
117
118 fn get_updates_since(&self, name: &str, since_id: i64) -> StorageResult<Vec<CrdtUpdate>> {
119 let updates = self.updates.read().unwrap();
120 let doc_updates = updates.get(name).map(|u| u.as_slice()).unwrap_or(&[]);
121
122 Ok(doc_updates
123 .iter()
124 .filter(|u| u.id > since_id)
125 .map(|u| CrdtUpdate {
126 update_id: u.id,
127 doc_name: name.to_string(),
128 data: u.data.clone(),
129 timestamp: u.timestamp,
130 origin: u.origin,
131 device_id: u.device_id.clone(),
132 device_name: u.device_name.clone(),
133 })
134 .collect())
135 }
136
137 fn get_all_updates(&self, name: &str) -> StorageResult<Vec<CrdtUpdate>> {
138 self.get_updates_since(name, 0)
139 }
140
141 fn get_state_at(&self, name: &str, update_id: i64) -> StorageResult<Option<Vec<u8>>> {
142 let base_state = self.load_doc(name)?;
144
145 let updates_lock = self.updates.read().unwrap();
147 let doc_updates: Vec<Vec<u8>> = updates_lock
148 .get(name)
149 .map(|updates| {
150 updates
151 .iter()
152 .filter(|u| u.id <= update_id)
153 .map(|u| u.data.clone())
154 .collect()
155 })
156 .unwrap_or_default();
157
158 if base_state.is_none() && doc_updates.is_empty() {
160 return Ok(None);
161 }
162
163 let doc = Doc::new();
165 {
166 let mut txn = doc.transact_mut();
167
168 if let Some(state) = &base_state
170 && let Ok(update) = Update::decode_v1(state)
171 && let Err(e) = txn.apply_update(update)
172 {
173 log::warn!("Failed to apply base state for {}: {}", name, e);
174 }
175
176 for update_data in doc_updates {
178 if let Ok(update) = Update::decode_v1(&update_data)
179 && let Err(e) = txn.apply_update(update)
180 {
181 log::warn!("Failed to apply incremental update for {}: {}", name, e);
182 }
183 }
184 }
185
186 let txn = doc.transact();
188 Ok(Some(txn.encode_state_as_update_v1(&Default::default())))
189 }
190
191 fn compact(&self, name: &str, keep_updates: usize) -> StorageResult<()> {
192 let mut updates = self.updates.write().unwrap();
193
194 if let Some(doc_updates) = updates.get_mut(name)
195 && doc_updates.len() > keep_updates
196 {
197 let drain_count = doc_updates.len() - keep_updates;
199 doc_updates.drain(0..drain_count);
200 }
201
202 Ok(())
203 }
204
205 fn get_latest_update_id(&self, name: &str) -> StorageResult<i64> {
206 let updates = self.updates.read().unwrap();
207 Ok(updates
208 .get(name)
209 .and_then(|u| u.last())
210 .map(|u| u.id)
211 .unwrap_or(0))
212 }
213
214 fn rename_doc(&self, old_name: &str, new_name: &str) -> StorageResult<()> {
215 {
217 let mut docs = self.docs.write().unwrap();
218 if let Some(state) = docs.remove(old_name) {
219 docs.insert(new_name.to_string(), state);
220 }
221 }
222
223 {
225 let mut updates = self.updates.write().unwrap();
226 if let Some(old_updates) = updates.remove(old_name) {
227 let new_updates: Vec<StoredUpdate> = old_updates
228 .into_iter()
229 .map(|u| StoredUpdate {
230 id: u.id,
231 data: u.data,
232 timestamp: u.timestamp,
233 origin: u.origin,
234 device_id: u.device_id,
235 device_name: u.device_name,
236 })
237 .collect();
238 updates.insert(new_name.to_string(), new_updates);
239 }
240 }
241
242 Ok(())
243 }
244
245 fn clear_updates(&self, name: &str) -> StorageResult<()> {
246 let mut updates = self.updates.write().unwrap();
247 if let Some(doc_updates) = updates.get_mut(name) {
248 doc_updates.clear();
249 }
250 Ok(())
251 }
252}
253
#[cfg(test)]
mod tests {
    use super::*;
    use yrs::{GetString, Text};

    /// Applies a full-state update (as returned by `get_state_at`) to a fresh
    /// document and returns the contents of its `content` text.
    fn content_of(state: &[u8]) -> String {
        let doc = Doc::new();
        {
            let mut txn = doc.transact_mut();
            let update = Update::decode_v1(state).unwrap();
            txn.apply_update(update).unwrap();
        }
        let text = doc.get_or_insert_text("content");
        let txn = doc.transact();
        text.get_string(&txn)
    }

    #[test]
    fn test_save_and_load_doc() {
        let storage = MemoryStorage::new();
        let data = b"test document state";

        storage.save_doc("test", data).unwrap();
        let loaded = storage.load_doc("test").unwrap();

        assert_eq!(loaded, Some(data.to_vec()));
    }

    #[test]
    fn test_load_nonexistent_doc() {
        let storage = MemoryStorage::new();
        let loaded = storage.load_doc("nonexistent").unwrap();
        assert!(loaded.is_none());
    }

    #[test]
    fn test_delete_doc() {
        let storage = MemoryStorage::new();
        storage.save_doc("test", b"data").unwrap();
        storage
            .append_update("test", b"update", UpdateOrigin::Local)
            .unwrap();

        storage.delete_doc("test").unwrap();

        assert!(storage.load_doc("test").unwrap().is_none());
        assert!(storage.get_all_updates("test").unwrap().is_empty());
    }

    #[test]
    fn test_list_docs() {
        let storage = MemoryStorage::new();
        storage.save_doc("doc1", b"data1").unwrap();
        storage.save_doc("doc2", b"data2").unwrap();

        let mut docs = storage.list_docs().unwrap();
        docs.sort();

        assert_eq!(docs, vec!["doc1", "doc2"]);
    }

    #[test]
    fn test_append_and_get_updates() {
        let storage = MemoryStorage::new();

        let id1 = storage
            .append_update("test", b"update1", UpdateOrigin::Local)
            .unwrap();
        let id2 = storage
            .append_update("test", b"update2", UpdateOrigin::Remote)
            .unwrap();
        let id3 = storage
            .append_update("test", b"update3", UpdateOrigin::Sync)
            .unwrap();

        assert!(id1 < id2);
        assert!(id2 < id3);

        let all = storage.get_all_updates("test").unwrap();
        assert_eq!(all.len(), 3);
        assert_eq!(all[0].origin, UpdateOrigin::Local);
        assert_eq!(all[1].origin, UpdateOrigin::Remote);

        let since_id1 = storage.get_updates_since("test", id1).unwrap();
        assert_eq!(since_id1.len(), 2);
        assert_eq!(since_id1[0].update_id, id2);

        // Asking for everything after the newest id yields nothing.
        assert!(storage.get_updates_since("test", id3).unwrap().is_empty());
    }

    #[test]
    fn test_append_update_with_device_records_metadata() {
        let storage = MemoryStorage::new();

        storage
            .append_update_with_device(
                "test",
                b"update1",
                UpdateOrigin::Remote,
                Some("device-1"),
                Some("Laptop"),
            )
            .unwrap();

        let all = storage.get_all_updates("test").unwrap();
        assert_eq!(all.len(), 1);
        assert_eq!(all[0].doc_name, "test");
        assert_eq!(all[0].data, b"update1".to_vec());
        assert_eq!(all[0].origin, UpdateOrigin::Remote);
        assert_eq!(all[0].device_id.as_deref(), Some("device-1"));
        assert_eq!(all[0].device_name.as_deref(), Some("Laptop"));
    }

    #[test]
    fn test_compact() {
        let storage = MemoryStorage::new();

        for i in 0..10 {
            storage
                .append_update(
                    "test",
                    format!("update{}", i).as_bytes(),
                    UpdateOrigin::Local,
                )
                .unwrap();
        }

        assert_eq!(storage.get_all_updates("test").unwrap().len(), 10);

        storage.compact("test", 3).unwrap();

        let remaining = storage.get_all_updates("test").unwrap();
        assert_eq!(remaining.len(), 3);
        // Compaction keeps the newest entries, not the oldest.
        assert_eq!(remaining[0].data, b"update7".to_vec());
        assert_eq!(remaining[2].data, b"update9".to_vec());
    }

    #[test]
    fn test_compact_noop_when_within_limit() {
        let storage = MemoryStorage::new();
        storage
            .append_update("test", b"update1", UpdateOrigin::Local)
            .unwrap();
        storage
            .append_update("test", b"update2", UpdateOrigin::Local)
            .unwrap();

        storage.compact("test", 5).unwrap();
        assert_eq!(storage.get_all_updates("test").unwrap().len(), 2);

        // Compacting an unknown document is not an error.
        storage.compact("nonexistent", 0).unwrap();
    }

    #[test]
    fn test_auto_compact_keeps_newest_updates() {
        let storage = MemoryStorage::new();

        let mut last_id = 0;
        for i in 0..=AUTO_COMPACT_THRESHOLD {
            last_id = storage
                .append_update(
                    "test",
                    format!("update{}", i).as_bytes(),
                    UpdateOrigin::Local,
                )
                .unwrap();
        }

        let remaining = storage.get_all_updates("test").unwrap();
        assert_eq!(remaining.len(), AUTO_COMPACT_KEEP);
        assert_eq!(remaining.last().unwrap().update_id, last_id);
        assert!(remaining
            .windows(2)
            .all(|pair| pair[0].update_id < pair[1].update_id));
        assert_eq!(storage.get_latest_update_id("test").unwrap(), last_id);
    }

    #[test]
    fn test_get_latest_update_id() {
        let storage = MemoryStorage::new();

        assert_eq!(storage.get_latest_update_id("test").unwrap(), 0);

        let id1 = storage
            .append_update("test", b"update1", UpdateOrigin::Local)
            .unwrap();
        assert_eq!(storage.get_latest_update_id("test").unwrap(), id1);

        let id2 = storage
            .append_update("test", b"update2", UpdateOrigin::Local)
            .unwrap();
        assert_eq!(storage.get_latest_update_id("test").unwrap(), id2);
    }

    #[test]
    fn test_get_state_at_reconstructs_history() {
        let storage = MemoryStorage::new();

        let doc = Doc::new();
        let text = doc.get_or_insert_text("content");

        let update1 = {
            let mut txn = doc.transact_mut();
            text.insert(&mut txn, 0, "Hello");
            txn.encode_update_v1()
        };
        let id1 = storage
            .append_update("test", &update1, UpdateOrigin::Local)
            .unwrap();

        let update2 = {
            let mut txn = doc.transact_mut();
            text.insert(&mut txn, 5, " World");
            txn.encode_update_v1()
        };
        let id2 = storage
            .append_update("test", &update2, UpdateOrigin::Local)
            .unwrap();

        let update3 = {
            let mut txn = doc.transact_mut();
            text.insert(&mut txn, 11, "!");
            txn.encode_update_v1()
        };
        let id3 = storage
            .append_update("test", &update3, UpdateOrigin::Local)
            .unwrap();

        {
            let txn = doc.transact();
            assert_eq!(text.get_string(&txn), "Hello World!");
        }

        let state_at_1 = storage.get_state_at("test", id1).unwrap().unwrap();
        assert_eq!(content_of(&state_at_1), "Hello");

        let state_at_2 = storage.get_state_at("test", id2).unwrap().unwrap();
        assert_eq!(content_of(&state_at_2), "Hello World");

        let state_at_3 = storage.get_state_at("test", id3).unwrap().unwrap();
        assert_eq!(content_of(&state_at_3), "Hello World!");
    }

    #[test]
    fn test_get_state_at_nonexistent() {
        let storage = MemoryStorage::new();

        let result = storage.get_state_at("nonexistent", 1).unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn test_rename_doc() {
        let storage = MemoryStorage::new();

        storage.save_doc("old_name", b"test state").unwrap();
        storage
            .append_update("old_name", b"update1", UpdateOrigin::Local)
            .unwrap();
        storage
            .append_update("old_name", b"update2", UpdateOrigin::Remote)
            .unwrap();

        assert!(storage.load_doc("old_name").unwrap().is_some());
        assert_eq!(storage.get_all_updates("old_name").unwrap().len(), 2);

        storage.rename_doc("old_name", "new_name").unwrap();

        assert!(storage.load_doc("old_name").unwrap().is_none());
        assert!(storage.get_all_updates("old_name").unwrap().is_empty());

        assert_eq!(
            storage.load_doc("new_name").unwrap(),
            Some(b"test state".to_vec())
        );
        let updates = storage.get_all_updates("new_name").unwrap();
        assert_eq!(updates.len(), 2);
        assert_eq!(updates[0].origin, UpdateOrigin::Local);
        assert_eq!(updates[1].origin, UpdateOrigin::Remote);
        assert!(updates.iter().all(|u| u.doc_name == "new_name"));
    }

    #[test]
    fn test_rename_doc_nonexistent() {
        let storage = MemoryStorage::new();

        let result = storage.rename_doc("nonexistent", "new_name");
        assert!(result.is_ok());

        assert!(storage.load_doc("nonexistent").unwrap().is_none());
        assert!(storage.load_doc("new_name").unwrap().is_none());
    }

    #[test]
    fn test_clear_updates() {
        let storage = MemoryStorage::new();

        storage.save_doc("test", b"snapshot").unwrap();
        storage
            .append_update("test", b"update1", UpdateOrigin::Local)
            .unwrap();
        storage
            .append_update("test", b"update2", UpdateOrigin::Remote)
            .unwrap();

        assert_eq!(storage.get_all_updates("test").unwrap().len(), 2);

        storage.clear_updates("test").unwrap();

        assert!(storage.get_all_updates("test").unwrap().is_empty());
        assert_eq!(
            storage.load_doc("test").unwrap(),
            Some(b"snapshot".to_vec())
        );
    }

    #[test]
    fn test_clear_updates_nonexistent() {
        let storage = MemoryStorage::new();

        let result = storage.clear_updates("nonexistent");
        assert!(result.is_ok());
    }
}