use std::{
    borrow::Cow,
    collections::{BinaryHeap, HashMap},
    fmt::Debug,
};

use loro_common::IdSpanVector;
use rle::{HasLength, Sliceable};
use rustc_hash::{FxHashMap, FxHashSet};
use smallvec::SmallVec;
mod iter;
mod mermaid;
#[cfg(feature = "test_utils")]
mod test;
#[cfg(feature = "test_utils")]
pub use test::{fuzz_alloc_tree, Interaction};

use crate::{
    change::Lamport,
    diff_calc::DiffMode,
    id::{Counter, PeerID, ID},
    span::{CounterSpan, HasId, HasIdSpan, HasLamport, HasLamportSpan, IdSpan},
    version::{Frontiers, VersionVector, VersionVectorDiff},
};

use self::{
    iter::{iter_dag, iter_dag_with_vv, DagCausalIter, DagIterator, DagIteratorVV},
    mermaid::dag_to_mermaid,
};

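/// A run of consecutive operations from a single peer in the operation DAG.
///
/// Each node knows its starting `ID`, its Lamport timestamp, its length, and the
/// frontiers (`deps`) it causally depends on, and it can be sliced into sub-spans.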
pub(crate) trait DagNode: HasLamport + HasId + HasLength + Debug + Sliceable {
    fn deps(&self) -> &Frontiers;

    #[allow(unused)]
    #[inline]
    fn get_lamport_from_counter(&self, c: Counter) -> Lamport {
        self.lamport() + c as Lamport - self.id_start().counter as Lamport
    }
}

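/// Minimal read-only view over an operation DAG: node lookup by `ID`, the current
/// frontiers, the current version vector, and a containment check.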
pub(crate) trait Dag: Debug {
    type Node: DagNode;

    fn get(&self, id: ID) -> Option<Self::Node>;
    #[allow(unused)]
    fn frontier(&self) -> &Frontiers;
    fn vv(&self) -> &VersionVector;
    fn contains(&self, id: ID) -> bool;
}

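/// Algorithms derived from the `Dag` interface: common-ancestor search,
/// version-vector computation, path finding between frontiers, causal iteration,
/// and a Mermaid rendering of the graph for debugging.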
pub(crate) trait DagUtils: Dag {
    fn find_common_ancestor(&self, a_id: &Frontiers, b_id: &Frontiers) -> (Frontiers, DiffMode);
    #[allow(unused)]
    fn get_vv(&self, id: ID) -> VersionVector;
    #[allow(unused)]
    fn find_path(&self, from: &Frontiers, to: &Frontiers) -> VersionVectorDiff;
    fn iter_causal(&self, from: Frontiers, target: IdSpanVector) -> DagCausalIter<'_, Self>
    where
        Self: Sized;
    #[allow(unused)]
    fn iter(&self) -> DagIterator<'_, Self::Node>
    where
        Self: Sized;
    #[allow(unused)]
    fn iter_with_vv(&self) -> DagIteratorVV<'_, Self::Node>
    where
        Self: Sized;
    #[allow(unused)]
    fn mermaid(&self) -> String
    where
        Self: Sized;
}

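// Blanket implementation: every `Dag` gets the `DagUtils` helpers for free.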
impl<T: Dag + ?Sized> DagUtils for T {
    #[inline]
    fn find_common_ancestor(&self, a_id: &Frontiers, b_id: &Frontiers) -> (Frontiers, DiffMode) {
        find_common_ancestor(&|id| self.get(id), a_id, b_id)
    }

    #[inline]
    fn get_vv(&self, id: ID) -> VersionVector {
        get_version_vector(&|id| self.get(id), id)
    }

    fn find_path(&self, from: &Frontiers, to: &Frontiers) -> VersionVectorDiff {
        let mut ans = VersionVectorDiff::default();
        if from == to {
            return ans;
        }

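        // Fast path: both frontiers are single IDs on the same peer. If one is an
        // ancestor of the other within that peer's history, the diff is a single
        // forward or retreat span and no graph traversal is needed.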
        if from.len() == 1 && to.len() == 1 {
            let from = from.as_single().unwrap();
            let to = to.as_single().unwrap();
            if from.peer == to.peer {
                let from_span = self.get(from).unwrap();
                let to_span = self.get(to).unwrap();
                if from_span.id_start() == to_span.id_start() {
                    if from.counter < to.counter {
                        ans.forward.insert(
                            from.peer,
                            CounterSpan::new(from.counter + 1, to.counter + 1),
                        );
                    } else {
                        ans.retreat.insert(
                            from.peer,
                            CounterSpan::new(to.counter + 1, from.counter + 1),
                        );
                    }
                    return ans;
                }

                if from_span.deps().len() == 1
                    && to_span.contains_id(from_span.deps().as_single().unwrap())
                {
                    ans.retreat.insert(
                        from.peer,
                        CounterSpan::new(to.counter + 1, from.counter + 1),
                    );
                    return ans;
                }

                if to_span.deps().len() == 1
                    && from_span.contains_id(to_span.deps().as_single().unwrap())
                {
                    ans.forward.insert(
                        from.peer,
                        CounterSpan::new(from.counter + 1, to.counter + 1),
                    );
                    return ans;
                }
            }
        }

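        // General case: traverse the DAG from both frontiers. Spans reachable only
        // from `from` are merged into the left side of the diff, spans reachable only
        // from `to` into the right side, and shared spans are subtracted from both.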
        _find_common_ancestor(
            &|v| self.get(v),
            from,
            to,
            &mut |span, node_type| match node_type {
                NodeType::A => ans.merge_left(span),
                NodeType::B => ans.merge_right(span),
                NodeType::Shared => {
                    ans.subtract_start_left(span);
                    ans.subtract_start_right(span);
                }
            },
            true,
        );

        ans
    }

    #[inline(always)]
    fn iter_with_vv(&self) -> DagIteratorVV<'_, Self::Node>
    where
        Self: Sized,
    {
        iter_dag_with_vv(self)
    }

    #[inline(always)]
    fn iter_causal(&self, from: Frontiers, target: IdSpanVector) -> DagCausalIter<'_, Self>
    where
        Self: Sized,
    {
        DagCausalIter::new(self, from, target)
    }

    #[inline(always)]
    fn iter(&self) -> DagIterator<'_, Self::Node>
    where
        Self: Sized,
    {
        iter_dag(self)
    }

    #[inline]
    fn mermaid(&self) -> String
    where
        Self: Sized,
    {
        dag_to_mermaid(self)
    }
}

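/// Compute the version vector at `id` by walking the dependency graph backwards from
/// `id` with a depth-first traversal, recording the highest counter reached for each
/// peer along the way.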
#[allow(dead_code)]
fn get_version_vector<'a, Get, D>(get: &'a Get, id: ID) -> VersionVector
where
    Get: Fn(ID) -> Option<D>,
    D: DagNode + 'a,
{
    let mut vv = VersionVector::new();
    let mut visited: FxHashSet<ID> = FxHashSet::default();
    vv.insert(id.peer, id.counter + 1);
    let node = get(id).unwrap();

    if node.deps().is_empty() {
        return vv;
    }

    let mut stack = Vec::with_capacity(node.deps().len());
    for dep in node.deps().iter() {
        stack.push(dep);
    }

    while let Some(node_id) = stack.pop() {
        let node = get(node_id).unwrap();
        let node_id_start = node.id_start();
        if !visited.contains(&node_id_start) {
            vv.try_update_last(node_id);
            for dep in node.deps().iter() {
                if !visited.contains(&dep) {
                    stack.push(dep);
                }
            }

            visited.insert(node_id_start);
        }
    }

    vv
}

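/// An ID span paired with its Lamport timestamp and dependencies, ordered so that it
/// can drive the max-heaps used by the common-ancestor searches below.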
#[derive(Debug, PartialEq, Eq)]
struct OrdIdSpan<'a> {
    id: ID,
    lamport: Lamport,
    len: usize,
    deps: Cow<'a, Frontiers>,
}

impl HasLength for OrdIdSpan<'_> {
    fn content_len(&self) -> usize {
        self.len
    }
}

impl HasId for OrdIdSpan<'_> {
    fn id_start(&self) -> ID {
        self.id
    }
}

impl HasLamport for OrdIdSpan<'_> {
    fn lamport(&self) -> Lamport {
        self.lamport
    }
}

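// Spans are ordered by the Lamport timestamp of their last op, with ties broken by
// peer ID and then by length (shorter spans compare as greater).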
impl PartialOrd for OrdIdSpan<'_> {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for OrdIdSpan<'_> {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.lamport_last()
            .cmp(&other.lamport_last())
            .then(self.id.peer.cmp(&other.id.peer))
            .then(other.len.cmp(&self.len))
    }
}

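/// Which side of the search a span was reached from: the `A` frontier, the `B`
/// frontier, or both (`Shared`).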
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)]
enum NodeType {
    A,
    B,
    Shared,
}

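// Constructors for `OrdIdSpan`: `from_dag_node` covers the containing node from its
// start up to `id` (inclusive); `get_min` shrinks a span down to its first op.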
impl<'a> OrdIdSpan<'a> {
    #[inline]
    fn from_dag_node<D, F>(id: ID, get: &'a F) -> Option<OrdIdSpan<'a>>
    where
        D: DagNode + 'a,
        F: Fn(ID) -> Option<D>,
    {
        let span = get(id)?;
        let span_id = span.id_start();
        Some(OrdIdSpan {
            id: span_id,
            lamport: span.lamport(),
            deps: Cow::Owned(span.deps().clone()),
            len: (id.counter - span_id.counter) as usize + 1,
        })
    }

    #[inline]
    fn get_min(&self) -> OrdIdSpan<'a> {
        OrdIdSpan {
            id: self.id,
            lamport: self.lamport,
            deps: Cow::Owned(Default::default()),
            len: 1,
        }
    }
}

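/// Find the common ancestors of the two frontiers `a_id` and `b_id`, together with a
/// `DiffMode` describing how the two versions relate. An empty `b_id` short-circuits
/// to a checkout with no common ancestor.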
#[inline(always)]
fn find_common_ancestor<'a, F, D>(
    get: &'a F,
    a_id: &Frontiers,
    b_id: &Frontiers,
) -> (Frontiers, DiffMode)
where
    D: DagNode + 'a,
    F: Fn(ID) -> Option<D>,
{
    if b_id.is_empty() {
        return (Default::default(), DiffMode::Checkout);
    }

    _find_common_ancestor_new(get, a_id, b_id)
}

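/// Traverse the DAG backwards from `a_ids` and `b_ids` at the same time, reporting
/// every visited span to `notify` tagged with the side (A, B, or Shared) it was
/// reached from. Returns, per peer, the counter at which the two traversals
/// converge; the map is cleared when no common ancestor exists. When `find_path`
/// is false the search may bail out early once no common ancestor is possible.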
fn _find_common_ancestor<'a, F, D, G>(
    get: &'a F,
    a_ids: &Frontiers,
    b_ids: &Frontiers,
    notify: &mut G,
    find_path: bool,
) -> FxHashMap<PeerID, Counter>
where
    D: DagNode + 'a,
    F: Fn(ID) -> Option<D>,
    G: FnMut(IdSpan, NodeType),
{
    let mut ans: FxHashMap<PeerID, Counter> = Default::default();
    let mut queue: BinaryHeap<(OrdIdSpan, NodeType)> = BinaryHeap::new();
    for id in a_ids.iter() {
        queue.push((OrdIdSpan::from_dag_node(id, get).unwrap(), NodeType::A));
    }
    for id in b_ids.iter() {
        queue.push((OrdIdSpan::from_dag_node(id, get).unwrap(), NodeType::B));
    }
    let mut visited: HashMap<PeerID, (Counter, NodeType), _> = FxHashMap::default();
    let mut a_count = a_ids.len();
    let mut b_count = b_ids.len();
    let mut min = None;
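    // Pop spans in descending (Lamport, peer) order. `a_count` / `b_count` track how
    // many pending spans still belong exclusively to each side, and `min` records the
    // smallest non-shared span seen so far, which bounds when the search may stop.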
    while let Some((node, mut node_type)) = queue.pop() {
        match node_type {
            NodeType::A => a_count -= 1,
            NodeType::B => b_count -= 1,
            NodeType::Shared => {}
        }

        if node_type != NodeType::Shared {
            if let Some(min) = &mut min {
                let node_start = node.get_min();
                if node_start < *min {
                    *min = node_start;
                }
            } else {
                min = Some(node.get_min())
            }
        }

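        // Merge duplicates sitting at the top of the heap; a span reached from both
        // sides becomes `Shared` and is recorded as a candidate common ancestor.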
        while let Some((other_node, other_type)) = queue.peek() {
            if node.id_span() == other_node.id_span() {
                if node_type == *other_type {
                    match node_type {
                        NodeType::A => a_count -= 1,
                        NodeType::B => b_count -= 1,
                        NodeType::Shared => {}
                    }
                } else {
                    if node_type != NodeType::Shared {
                        if visited.get(&node.id.peer).map(|(_, t)| *t) != Some(NodeType::Shared) {
                            ans.insert(node.id.peer, other_node.id_last().counter);
                        }
                        node_type = NodeType::Shared;
                    }
                    match other_type {
                        NodeType::A => a_count -= 1,
                        NodeType::B => b_count -= 1,
                        NodeType::Shared => {}
                    }
                }

                queue.pop();
            } else {
                break;
            }
        }

        if let Some((ctr, visited_type)) = visited.get_mut(&node.id.peer) {
            debug_assert!(*ctr >= node.id_last().counter);
            if *visited_type == NodeType::Shared {
                node_type = NodeType::Shared;
            } else if *visited_type != node_type {
                if node_type != NodeType::Shared {
                    ans.insert(node.id.peer, node.id_last().counter);
                }
                *visited_type = NodeType::Shared;
                node_type = NodeType::Shared;
            }
        } else {
            visited.insert(node.id.peer, (node.id_last().counter, node_type));
        }

        notify(node.id_span(), node_type);

        match node_type {
            NodeType::A => a_count += node.deps.len(),
            NodeType::B => b_count += node.deps.len(),
            NodeType::Shared => {}
        }

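        // Termination: once neither side has exclusive pending spans and the current
        // span is no greater than the smallest span seen so far, the search can stop.
        // If the meeting point is not shared, there is no common ancestor.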
        if a_count == 0 && b_count == 0 && (min.is_none() || &node <= min.as_ref().unwrap()) {
            if node_type != NodeType::Shared {
                ans.clear();
            }

            break;
        }

        for dep_id in node.deps.as_ref().iter() {
            queue.push((OrdIdSpan::from_dag_node(dep_id, get).unwrap(), node_type));
        }

        if node_type != NodeType::Shared {
            if queue.is_empty() {
                ans.clear();
                break;
            }
            if node.deps.is_empty() && !find_path {
                if node.len == 1 {
                    ans.clear();
                    break;
                }

                match node_type {
                    NodeType::A => a_count += 1,
                    NodeType::B => b_count += 1,
                    NodeType::Shared => {}
                }

                queue.push((
                    OrdIdSpan {
                        deps: Cow::Owned(Default::default()),
                        id: node.id,
                        len: 1,
                        lamport: node.lamport,
                    },
                    node_type,
                ));
            }
        }
    }

    ans
}

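/// Heap-based common-ancestor search used by `find_common_ancestor`. Returns the
/// common-ancestor frontiers together with a `DiffMode` hint: `Linear` when the
/// relevant history forms a single unbranched chain, `ImportGreaterUpdates` when
/// `right` is strictly newer than `left`, and `Checkout` otherwise.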
fn _find_common_ancestor_new<'a, F, D>(
    get: &'a F,
    left: &Frontiers,
    right: &Frontiers,
) -> (Frontiers, DiffMode)
where
    D: DagNode + 'a,
    F: Fn(ID) -> Option<D>,
{
    if right.is_empty() {
        return (Default::default(), DiffMode::Checkout);
    }

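    // `left` is empty: if `right` is a single head whose dependency chain never
    // branches, the history is linear; otherwise the right side simply has more
    // updates than the (empty) left side.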
    if left.is_empty() {
        if right.len() == 1 {
            let mut node_id = right.as_single().unwrap();
            let mut node = get(node_id).unwrap();
            while node.deps().len() == 1 {
                node_id = node.deps().as_single().unwrap();
                node = get(node_id).unwrap();
            }

            if node.deps().is_empty() {
                return (Default::default(), DiffMode::Linear);
            }
        }

        return (Default::default(), DiffMode::ImportGreaterUpdates);
    }

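    // Both frontiers are single IDs on the same peer: handle the cases where they sit
    // in the same node or one is the direct dependency of the other without a full
    // graph traversal.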
    if left.len() == 1 && right.len() == 1 {
        let left = left.as_single().unwrap();
        let right = right.as_single().unwrap();
        if left.peer == right.peer {
            let left_span = get(left).unwrap();
            let right_span = get(right).unwrap();
            if left_span.id_start() == right_span.id_start() {
                if left.counter < right.counter {
                    return (left.into(), DiffMode::Linear);
                } else {
                    return (right.into(), DiffMode::Checkout);
                }
            }

            if left_span.deps().len() == 1
                && right_span.contains_id(left_span.deps().as_single().unwrap())
            {
                return (right.into(), DiffMode::Checkout);
            }

            if right_span.deps().len() == 1
                && left_span.contains_id(right_span.deps().as_single().unwrap())
            {
                return (left.into(), DiffMode::Linear);
            }
        }
    }

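    // General case: a joint max-heap traversal from both frontiers, always expanding
    // the span with the greatest Lamport timestamp until the two sides meet.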
    let mut is_linear = left.len() <= 1 && right.len() == 1;
    let mut is_right_greater = true;
    let mut ans: Frontiers = Default::default();
    let mut queue: BinaryHeap<(SmallVec<[OrdIdSpan; 1]>, NodeType)> = BinaryHeap::new();

    fn ids_to_ord_id_spans<'a, D: DagNode + 'a, F: Fn(ID) -> Option<D>>(
        ids: &Frontiers,
        get: &'a F,
    ) -> Option<SmallVec<[OrdIdSpan<'a>; 1]>> {
        let mut ans: SmallVec<[OrdIdSpan<'a>; 1]> = SmallVec::with_capacity(ids.len());
        for id in ids.iter() {
            if let Some(node) = OrdIdSpan::from_dag_node(id, get) {
                ans.push(node);
            } else {
                return None;
            }
        }

        if ans.len() > 1 {
            ans.sort_unstable_by(|a, b| b.cmp(a));
        }

        Some(ans)
    }

    queue.push((ids_to_ord_id_spans(left, get).unwrap(), NodeType::A));
    queue.push((ids_to_ord_id_spans(right, get).unwrap(), NodeType::B));
    while let Some((mut node, mut node_type)) = queue.pop() {
        while let Some((other_node, other_type)) = queue.peek() {
            if node == *other_node
                || (node.len() == 1
                    && other_node.len() == 1
                    && node[0].id_last() == other_node[0].id_last())
            {
                if node_type != *other_type {
                    node_type = NodeType::Shared;
                }

                queue.pop();
            } else {
                break;
            }
        }

        if queue.is_empty() {
            if node_type == NodeType::Shared {
                ans = node.into_iter().map(|x| x.id_last()).collect();
            }

            is_right_greater = false;
            break;
        }

        if node_type == NodeType::A {
            is_right_greater = false;
        }

        if node.len() > 1 {
            for node in node.drain(1..node.len()) {
                queue.push((smallvec::smallvec![node], node_type));
            }
        }

        if let Some(other) = queue.peek() {
            if other.0.len() == 1
                && node[0].contains_id(other.0[0].id_last())
                && node_type != other.1
            {
                node[0].len = (other.0[0].id_last().counter - node[0].id.counter + 1) as usize;
                queue.push((node, node_type));
                continue;
            }

            if node[0].len > 1 {
                if other.0[0].lamport_last() > node[0].lamport {
                    node[0].len = (other.0[0].lamport_last() - node[0].lamport)
                        .min(node[0].len as u32 - 1) as usize;
                    queue.push((node, node_type));
                    continue;
                } else {
                    node[0].len = 1;
                    queue.push((node, node_type));
                    continue;
                }
            }
        }

        if !node[0].deps.is_empty() {
            if let Some(deps) = ids_to_ord_id_spans(node[0].deps.as_ref(), get) {
                queue.push((deps, node_type));
            } else {
                panic!("deps on trimmed history");
            }

            is_linear = false;
        } else {
            is_right_greater = false;
            break;
        }
    }

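    // Decide the diff mode: when `right` turned out to be strictly newer than
    // `left`, report `Linear` for an unbranched chain or `ImportGreaterUpdates`
    // otherwise; in every other case fall back to `Checkout`.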
    let mode = if is_right_greater {
        if ans.len() <= 1 {
            debug_assert_eq!(&ans, left);
        }

        if is_linear {
            debug_assert!(ans.len() <= 1);
            DiffMode::Linear
        } else {
            DiffMode::ImportGreaterUpdates
        }
    } else {
        DiffMode::Checkout
    };

    (ans, mode)
}

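/// Remove from `frontiers` every peer whose recorded last counter is already covered
/// by one of `new_change_deps`.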
pub fn remove_included_frontiers(frontiers: &mut VersionVector, new_change_deps: &[ID]) {
    for dep in new_change_deps.iter() {
        if let Some(last) = frontiers.get_last(dep.peer) {
            if last <= dep.counter {
                frontiers.remove(&dep.peer);
            }
        }
    }
}