1use loro::{LoroDoc, LoroMap, LoroValue, ValueOrContainer};
8
9use crate::error::{CrdtError, Result};
10
11pub struct CrdtState {
13 doc: LoroDoc,
14 peer_id: u64,
15}
16
17impl CrdtState {
18 pub fn new(peer_id: u64) -> Result<Self> {
20 let doc = LoroDoc::new();
21 doc.set_peer_id(peer_id)
22 .map_err(|e| CrdtError::Loro(format!("failed to set peer_id {peer_id}: {e}")))?;
23 Ok(Self { doc, peer_id })
24 }
25
26 pub fn upsert(
28 &self,
29 collection: &str,
30 row_id: &str,
31 fields: &[(&str, LoroValue)],
32 ) -> Result<()> {
33 let coll = self.doc.get_map(collection);
34 let row_container = coll
35 .insert_container(row_id, LoroMap::new())
36 .map_err(|e| CrdtError::Loro(e.to_string()))?;
37 for (field, value) in fields {
38 row_container
39 .insert(field, value.clone())
40 .map_err(|e| CrdtError::Loro(e.to_string()))?;
41 }
42 Ok(())
43 }
44
45 pub fn delete(&self, collection: &str, row_id: &str) -> Result<()> {
47 let coll = self.doc.get_map(collection);
48 coll.delete(row_id)
49 .map_err(|e| CrdtError::Loro(e.to_string()))?;
50 Ok(())
51 }
52
53 pub fn clear_collection(&self, collection: &str) -> Result<usize> {
55 let coll = self.doc.get_map(collection);
56 let keys: Vec<String> = coll.keys().map(|k| k.to_string()).collect();
57 let count = keys.len();
58 for key in &keys {
59 coll.delete(key)
60 .map_err(|e| CrdtError::Loro(e.to_string()))?;
61 }
62 Ok(count)
63 }
64
65 pub fn read_row(&self, collection: &str, row_id: &str) -> Option<LoroValue> {
70 let coll = self.doc.get_map(collection);
71 match coll.get(row_id)? {
72 ValueOrContainer::Container(loro::Container::Map(m)) => Some(m.get_value()),
73 ValueOrContainer::Container(loro::Container::List(l)) => Some(l.get_value()),
74 ValueOrContainer::Container(_) => Some(LoroValue::Null),
75 ValueOrContainer::Value(v) => Some(v),
76 }
77 }
78
79 pub fn read_field(&self, collection: &str, row_id: &str, field: &str) -> Option<LoroValue> {
88 let coll = self.doc.get_map(collection);
89 let row_map = match coll.get(row_id)? {
90 ValueOrContainer::Container(loro::Container::Map(m)) => m,
91 ValueOrContainer::Value(v) => return Some(v),
92 _ => return None,
93 };
94 match row_map.get(field)? {
95 ValueOrContainer::Value(v) => Some(v),
96 ValueOrContainer::Container(loro::Container::Map(m)) => Some(m.get_value()),
97 ValueOrContainer::Container(loro::Container::List(l)) => Some(l.get_value()),
98 ValueOrContainer::Container(_) => Some(LoroValue::Null),
99 }
100 }
101
102 pub fn row_exists(&self, collection: &str, row_id: &str) -> bool {
104 let coll = self.doc.get_map(collection);
105 coll.get(row_id).is_some()
106 }
107
108 pub fn collection_names(&self) -> Vec<String> {
110 let root = self.doc.get_deep_value();
111 match root {
112 LoroValue::Map(map) => map.keys().map(|k| k.to_string()).collect(),
113 _ => Vec::new(),
114 }
115 }
116
117 pub fn row_ids(&self, collection: &str) -> Vec<String> {
119 let coll = self.doc.get_map(collection);
120 coll.keys().map(|k| k.to_string()).collect()
121 }
122
123 pub fn field_value_exists(&self, collection: &str, field: &str, value: &LoroValue) -> bool {
126 let coll = self.doc.get_map(collection);
127 for key in coll.keys() {
128 let path = format!("{collection}/{key}/{field}");
129 if let Some(voc) = self.doc.get_by_str_path(&path) {
130 let field_val = match voc {
131 ValueOrContainer::Value(v) => v,
132 ValueOrContainer::Container(_) => {
133 continue;
134 }
135 };
136 if &field_val == value {
137 return true;
138 }
139 }
140 }
141 false
142 }
143
144 pub fn export_snapshot(&self) -> Result<Vec<u8>> {
146 self.doc
147 .export(loro::ExportMode::Snapshot)
148 .map_err(|e| CrdtError::Loro(format!("snapshot export failed: {e}")))
149 }
150
151 pub fn import(&self, data: &[u8]) -> Result<()> {
153 self.doc
154 .import(data)
155 .map_err(|e| CrdtError::DeltaApplyFailed(e.to_string()))?;
156 Ok(())
157 }
158
159 pub fn doc(&self) -> &LoroDoc {
161 &self.doc
162 }
163
164 pub fn peer_id(&self) -> u64 {
166 self.peer_id
167 }
168
169 pub fn compact_history(&mut self) -> Result<()> {
186 let frontiers = self.doc.oplog_frontiers();
188 let snapshot = self
189 .doc
190 .export(loro::ExportMode::shallow_snapshot(&frontiers))
191 .map_err(|e| CrdtError::Loro(format!("shallow snapshot export: {e}")))?;
192
193 let new_doc = LoroDoc::new();
195 new_doc
196 .set_peer_id(self.peer_id)
197 .map_err(|e| CrdtError::Loro(format!("failed to set peer_id on compacted doc: {e}")))?;
198 new_doc
199 .import(&snapshot)
200 .map_err(|e| CrdtError::Loro(format!("shallow snapshot import: {e}")))?;
201
202 self.doc = new_doc;
203 Ok(())
204 }
205
206 pub fn oplog_version_vector(&self) -> loro::VersionVector {
210 self.doc.oplog_vv()
211 }
212
213 pub fn read_at_version(
220 &self,
221 collection: &str,
222 row_id: &str,
223 version: &loro::VersionVector,
224 ) -> Result<Option<LoroValue>> {
225 let frontiers = self.doc.vv_to_frontiers(version);
226 let forked = self.doc.fork_at(&frontiers);
227
228 let coll = forked.get_map(collection);
229 match coll.get(row_id) {
230 Some(ValueOrContainer::Container(loro::Container::Map(m))) => Ok(Some(m.get_value())),
231 Some(ValueOrContainer::Container(loro::Container::List(l))) => Ok(Some(l.get_value())),
232 Some(ValueOrContainer::Value(v)) => Ok(Some(v)),
233 Some(ValueOrContainer::Container(_)) => Ok(Some(LoroValue::Null)),
234 None => Ok(None),
235 }
236 }
237
238 pub fn export_updates_since(&self, from_version: &loro::VersionVector) -> Result<Vec<u8>> {
243 self.doc
244 .export(loro::ExportMode::updates(from_version))
245 .map_err(|e| CrdtError::Loro(format!("delta export: {e}")))
246 }
247
248 pub fn compact_at_version(&mut self, version: &loro::VersionVector) -> Result<()> {
253 let frontiers = self.doc.vv_to_frontiers(version);
254 let snapshot = self
255 .doc
256 .export(loro::ExportMode::shallow_snapshot(&frontiers))
257 .map_err(|e| CrdtError::Loro(format!("shallow snapshot export: {e}")))?;
258
259 let new_doc = LoroDoc::new();
260 new_doc
261 .set_peer_id(self.peer_id)
262 .map_err(|e| CrdtError::Loro(format!("set peer_id on compacted doc: {e}")))?;
263 new_doc
264 .import(&snapshot)
265 .map_err(|e| CrdtError::Loro(format!("shallow snapshot import: {e}")))?;
266
267 self.doc = new_doc;
268 Ok(())
269 }
270
271 pub fn restore_to_version(
279 &self,
280 collection: &str,
281 row_id: &str,
282 version: &loro::VersionVector,
283 ) -> Result<Vec<u8>> {
284 let historical = self
285 .read_at_version(collection, row_id, version)?
286 .ok_or_else(|| CrdtError::Loro("document did not exist at target version".into()))?;
287
288 let vv_before = self.doc.oplog_vv();
289
290 let fields: Vec<(&str, LoroValue)> = match &historical {
291 LoroValue::Map(map) => map.iter().map(|(k, v)| (k.as_ref(), v.clone())).collect(),
292 _ => return Err(CrdtError::Loro("historical state is not a map".into())),
293 };
294 self.upsert(collection, row_id, &fields)?;
295
296 self.doc
297 .export(loro::ExportMode::updates(&vv_before))
298 .map_err(|e| CrdtError::Loro(format!("restore delta export: {e}")))
299 }
300
301 pub fn estimated_memory_bytes(&self) -> usize {
306 self.doc
310 .export(loro::ExportMode::Snapshot)
311 .map(|s| s.len())
312 .unwrap_or(0)
313 }
314}
315
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a string-typed `LoroValue` from a literal.
    fn text(s: &'static str) -> LoroValue {
        LoroValue::String(s.into())
    }

    /// Creates a state for `peer`, panicking if construction fails.
    fn fresh_state(peer: u64) -> CrdtState {
        CrdtState::new(peer).unwrap()
    }

    /// A row that was upserted exists; one that was never written does not.
    #[test]
    fn upsert_and_check_existence() {
        let state = fresh_state(1);
        let fields = [
            ("name", text("Alice")),
            ("email", text("alice@example.com")),
        ];
        state.upsert("users", "user-1", &fields).unwrap();

        assert!(state.row_exists("users", "user-1"));
        assert!(!state.row_exists("users", "user-2"));
    }

    /// Deleting a row makes `row_exists` report `false` for it.
    #[test]
    fn delete_row() {
        let state = fresh_state(1);
        state
            .upsert("users", "user-1", &[("name", text("Alice"))])
            .unwrap();
        assert!(state.row_exists("users", "user-1"));

        state.delete("users", "user-1").unwrap();
        assert!(!state.row_exists("users", "user-1"));
    }

    /// `row_ids` returns every upserted key (order normalised by sorting).
    #[test]
    fn row_ids_listing() {
        let state = fresh_state(1);
        for (id, x) in [("a", 1), ("b", 2)] {
            state
                .upsert("users", id, &[("x", LoroValue::I64(x))])
                .unwrap();
        }

        let mut ids = state.row_ids("users");
        ids.sort();
        assert_eq!(ids, vec!["a", "b"]);
    }

    /// `field_value_exists` finds a stored value and rejects an absent one.
    #[test]
    fn field_value_uniqueness_check() {
        let state = fresh_state(1);
        state
            .upsert("users", "u1", &[("email", text("alice@example.com"))])
            .unwrap();

        let present = text("alice@example.com");
        let absent = text("bob@example.com");
        assert!(state.field_value_exists("users", "email", &present));
        assert!(!state.field_value_exists("users", "email", &absent));
    }

    /// Rows survive compaction and the state stays writable afterwards.
    #[test]
    fn compact_history_preserves_state() {
        let mut state = fresh_state(1);
        let writes = [
            ("u1", text("Alice")),
            ("u2", text("Bob")),
            ("u1", text("Alice Updated")),
        ];
        for (id, name) in writes {
            state.upsert("users", id, &[("name", name)]).unwrap();
        }

        state.compact_history().unwrap();

        assert!(state.row_exists("users", "u1"));
        assert!(state.row_exists("users", "u2"));

        state
            .upsert("users", "u3", &[("name", text("Carol"))])
            .unwrap();
        assert!(state.row_exists("users", "u3"));
    }

    /// The size estimate increases once rows have been added.
    #[test]
    fn estimated_memory_grows_with_data() {
        let state = fresh_state(1);
        let before = state.estimated_memory_bytes();

        for i in 0..100 {
            let row_id = format!("item-{i}");
            state
                .upsert("items", &row_id, &[("value", LoroValue::I64(i))])
                .unwrap();
        }

        let after = state.estimated_memory_bytes();
        assert!(
            after > before,
            "memory should grow: before={before}, after={after}"
        );
    }

    /// A snapshot exported by one peer can be imported by another.
    #[test]
    fn snapshot_roundtrip() {
        let source = fresh_state(1);
        source
            .upsert("users", "u1", &[("name", text("Bob"))])
            .unwrap();
        let snapshot = source.export_snapshot().unwrap();

        let target = fresh_state(2);
        target.import(&snapshot).unwrap();

        assert!(target.row_exists("users", "u1"));
    }
}