loro_internal/
subscription.rs1use super::{
2 arena::SharedArena,
3 event::{DiffEvent, DocDiff},
4};
5use crate::{
6 container::idx::ContainerIdx, utils::subscription::SubscriberSet, ContainerDiff, LoroDoc,
7 Subscription,
8};
9use fxhash::FxHashMap;
10use loro_common::{ContainerID, ID};
11use smallvec::SmallVec;
12use std::{
13 collections::VecDeque,
14 sync::{Arc, Mutex},
15};
16
17pub type LocalUpdateCallback = Box<dyn Fn(&Vec<u8>) -> bool + Send + Sync + 'static>;
19pub type PeerIdUpdateCallback = Box<dyn Fn(&ID) -> bool + Send + Sync + 'static>;
21pub type Subscriber = Arc<dyn (for<'a> Fn(DiffEvent<'a>)) + Send + Sync>;
22
23impl LoroDoc {
24 pub fn subscribe_peer_id_change(&self, callback: PeerIdUpdateCallback) -> Subscription {
26 let (s, enable) = self.peer_id_change_subs.inner().insert((), callback);
27 enable();
28 s
29 }
30}
31
32struct ObserverInner {
33 subscriber_set: SubscriberSet<Option<ContainerIdx>, Subscriber>,
34 queue: Arc<Mutex<VecDeque<DocDiff>>>,
35}
36
37impl Default for ObserverInner {
38 fn default() -> Self {
39 Self {
40 subscriber_set: SubscriberSet::new(),
41 queue: Arc::new(Mutex::new(VecDeque::new())),
42 }
43 }
44}
45
46pub struct Observer {
47 inner: ObserverInner,
48 arena: SharedArena,
49}
50
51impl std::fmt::Debug for Observer {
52 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
53 f.debug_struct("Observer").finish()
54 }
55}
56
57impl Observer {
58 pub fn new(arena: SharedArena) -> Self {
59 Self {
60 arena,
61 inner: ObserverInner::default(),
62 }
63 }
64
65 pub fn subscribe(&self, id: &ContainerID, callback: Subscriber) -> Subscription {
66 let idx = self.arena.register_container(id);
67 let inner = &self.inner;
68 let (sub, enable) = inner.subscriber_set.insert(Some(idx), callback);
69 enable();
70 sub
71 }
72
73 pub fn subscribe_root(&self, callback: Subscriber) -> Subscription {
74 let inner = &self.inner;
75 let (sub, enable) = inner.subscriber_set.insert(None, callback);
76 enable();
77 sub
78 }
79
80 pub(crate) fn emit(&self, doc_diff: DocDiff) {
81 let success = self.emit_inner(doc_diff);
82 if success {
83 let mut e = self.inner.queue.try_lock().unwrap().pop_front();
84 while let Some(event) = e {
85 self.emit_inner(event);
86 e = self.inner.queue.try_lock().unwrap().pop_front();
87 }
88 }
89 }
90
91 fn emit_inner(&self, doc_diff: DocDiff) -> bool {
93 let inner = &self.inner;
94 let mut container_events_map: FxHashMap<ContainerIdx, SmallVec<[&ContainerDiff; 1]>> =
95 Default::default();
96 for container_diff in doc_diff.diff.iter() {
97 self.arena
98 .with_ancestors(container_diff.idx, |ancestor, _| {
99 if inner.subscriber_set.may_include(&Some(ancestor)) {
100 container_events_map
101 .entry(ancestor)
102 .or_default()
103 .push(container_diff);
104 }
105 });
106 }
107
108 {
109 if inner.subscriber_set.is_recursive_calling(&None)
112 || container_events_map
113 .keys()
114 .any(|x| inner.subscriber_set.is_recursive_calling(&Some(*x)))
115 {
116 drop(container_events_map);
117 inner.queue.try_lock().unwrap().push_back(doc_diff);
118 return false;
119 }
120 }
121
122 for (container_idx, container_diffs) in container_events_map {
123 inner
124 .subscriber_set
125 .retain(&Some(container_idx), &mut |callback| {
126 (callback)(DiffEvent {
127 current_target: Some(self.arena.get_container_id(container_idx).unwrap()),
128 events: &container_diffs,
129 event_meta: &doc_diff,
130 });
131 true
132 })
133 .unwrap();
134 }
135
136 let events: Vec<_> = doc_diff.diff.iter().collect();
137 inner
138 .subscriber_set
139 .retain(&None, &mut |callback| {
140 (callback)(DiffEvent {
141 current_target: None,
142 events: &events,
143 event_meta: &doc_diff,
144 });
145 true
146 })
147 .unwrap();
148
149 true
150 }
151}
152
// Tests for event delivery through `Observer`, exercised via `LoroDoc`'s API.
#[cfg(test)]
mod test {
    use std::sync::atomic::{AtomicUsize, Ordering};

    use tracing::trace;

    use super::*;
    use crate::{handler::HandlerTrait, LoroDoc};

    // A root subscriber commits a new transaction from inside its own callback.
    // Such a commit presumably emits (via `Observer::emit`) while the root
    // listeners are still being invoked. The event is therefore queued and
    // delivered by the outer `emit` afterwards. The chain stops once the text is
    // longer than 10 characters.
    #[test]
    fn test_recursive_events() {
        let loro = Arc::new(LoroDoc::new());
        let loro_cp = loro.clone();
        let count = Arc::new(AtomicUsize::new(0));
        let count_cp = Arc::clone(&count);
        let _g = loro_cp.subscribe_root(Arc::new(move |_| {
            count_cp.fetch_add(1, Ordering::SeqCst);
            // NOTE(review): the transaction is opened before the length check, so on
            // the "Skip" path it is dropped without an explicit `commit`; what its
            // `Drop` does is defined outside this file.
            let mut txn = loro.txn().unwrap();
            let text = loro.get_text("id");
            if text.get_value().as_string().unwrap().len() > 10 {
                trace!("Skip");
                return;
            }
            text.insert_with_txn(&mut txn, 0, "123").unwrap();
            trace!("PRE Another commit");
            txn.commit().unwrap();
            trace!("AFTER Another commit");
        }));

        // The original `loro` was moved into the callback; continue with the clone.
        let loro = loro_cp;
        let mut txn = loro.txn().unwrap();
        let text = loro.get_text("id");
        text.insert_with_txn(&mut txn, 0, "123").unwrap();
        txn.commit().unwrap();
        let count = count.load(Ordering::SeqCst);
        // More than two invocations means events committed from inside the
        // callback were delivered rather than dropped.
        assert!(count > 2, "{}", count);
    }

    // Each commit notifies the root subscriber once until `unsubscribe` is called;
    // commits after that must not reach the callback.
    #[test]
    fn unsubscribe() {
        let loro = Arc::new(LoroDoc::new());
        let count = Arc::new(AtomicUsize::new(0));
        let count_cp = Arc::clone(&count);
        let sub = loro.subscribe_root(Arc::new(move |_| {
            count_cp.fetch_add(1, Ordering::SeqCst);
        }));

        let text = loro.get_text("id");

        // No commit yet, so no event.
        assert_eq!(count.load(Ordering::SeqCst), 0);
        {
            let mut txn = loro.txn().unwrap();
            text.insert_with_txn(&mut txn, 0, "123").unwrap();
            txn.commit().unwrap();
        }
        assert_eq!(count.load(Ordering::SeqCst), 1);
        {
            let mut txn = loro.txn().unwrap();
            text.insert_with_txn(&mut txn, 0, "123").unwrap();
            txn.commit().unwrap();
        }
        assert_eq!(count.load(Ordering::SeqCst), 2);
        sub.unsubscribe();
        {
            let mut txn = loro.txn().unwrap();
            text.insert_with_txn(&mut txn, 0, "123").unwrap();
            txn.commit().unwrap();
        }
        // Count is unchanged: the unsubscribed callback was not invoked.
        assert_eq!(count.load(Ordering::SeqCst), 2);
    }
}