loro_internal/
subscription.rs1use super::{
2 arena::SharedArena,
3 event::{DiffEvent, DocDiff},
4};
5use crate::{
6 container::idx::ContainerIdx, utils::subscription::SubscriberSet, ContainerDiff, LoroDoc,
7 Subscription,
8};
9use loro_common::{ContainerID, ID};
10use rustc_hash::FxHashMap;
11use smallvec::SmallVec;
12use std::{collections::VecDeque, sync::Arc};
13
14use crate::sync::Mutex;
15pub type LocalUpdateCallback = Box<dyn Fn(&Vec<u8>) -> bool + Send + Sync + 'static>;
17pub type PeerIdUpdateCallback = Box<dyn Fn(&ID) -> bool + Send + Sync + 'static>;
19#[allow(clippy::unused_unit)]
20pub type Subscriber = Arc<dyn for<'a> Fn(DiffEvent<'a>) + Send + Sync>;
21
22impl LoroDoc {
23 pub fn subscribe_peer_id_change(&self, callback: PeerIdUpdateCallback) -> Subscription {
25 let (s, enable) = self.peer_id_change_subs.inner().insert((), callback);
26 enable();
27 s
28 }
29}
30
31struct ObserverInner {
32 subscriber_set: SubscriberSet<Option<ContainerIdx>, Subscriber>,
33 queue: Arc<Mutex<VecDeque<DocDiff>>>,
34}
35
36impl Default for ObserverInner {
37 fn default() -> Self {
38 Self {
39 subscriber_set: SubscriberSet::new(),
40 queue: Arc::new(Mutex::new(VecDeque::new())),
41 }
42 }
43}
44
45pub struct Observer {
46 inner: ObserverInner,
47 arena: SharedArena,
48}
49
50impl std::fmt::Debug for Observer {
51 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
52 f.debug_struct("Observer").finish()
53 }
54}
55
56impl Observer {
57 pub fn new(arena: SharedArena) -> Self {
58 Self {
59 arena,
60 inner: ObserverInner::default(),
61 }
62 }
63
64 pub fn subscribe(&self, id: &ContainerID, callback: Subscriber) -> Subscription {
65 let idx = self.arena.register_container(id);
66 let inner = &self.inner;
67 let (sub, enable) = inner.subscriber_set.insert(Some(idx), callback);
68 enable();
69 sub
70 }
71
72 pub fn subscribe_root(&self, callback: Subscriber) -> Subscription {
73 let inner = &self.inner;
74 let (sub, enable) = inner.subscriber_set.insert(None, callback);
75 enable();
76 sub
77 }
78
79 pub(crate) fn emit(&self, doc_diff: DocDiff) {
80 let success = self.emit_inner(doc_diff);
81 if success {
82 let mut e = self.inner.queue.lock().unwrap().pop_front();
83 while let Some(event) = e {
84 self.emit_inner(event);
85 e = self.inner.queue.lock().unwrap().pop_front();
86 }
87 }
88 }
89
90 fn emit_inner(&self, doc_diff: DocDiff) -> bool {
92 let inner = &self.inner;
93 let mut container_events_map: FxHashMap<ContainerIdx, SmallVec<[&ContainerDiff; 1]>> =
94 Default::default();
95 for container_diff in doc_diff.diff.iter() {
96 self.arena
97 .with_ancestors(container_diff.idx, |ancestor, _| {
98 if inner.subscriber_set.may_include(&Some(ancestor)) {
99 container_events_map
100 .entry(ancestor)
101 .or_default()
102 .push(container_diff);
103 }
104 });
105 }
106
107 {
108 if inner.subscriber_set.is_recursive_calling(&None)
111 || container_events_map
112 .keys()
113 .any(|x| inner.subscriber_set.is_recursive_calling(&Some(*x)))
114 {
115 drop(container_events_map);
116 inner.queue.lock().unwrap().push_back(doc_diff);
117 return false;
118 }
119 }
120
121 for (container_idx, container_diffs) in container_events_map {
122 inner
123 .subscriber_set
124 .retain(&Some(container_idx), &mut |callback| {
125 (callback)(DiffEvent {
126 current_target: Some(self.arena.get_container_id(container_idx).unwrap()),
127 events: &container_diffs,
128 event_meta: &doc_diff,
129 });
130 true
131 })
132 .unwrap();
133 }
134
135 let events: Vec<_> = doc_diff.diff.iter().collect();
136 inner
137 .subscriber_set
138 .retain(&None, &mut |callback| {
139 (callback)(DiffEvent {
140 current_target: None,
141 events: &events,
142 event_meta: &doc_diff,
143 });
144 true
145 })
146 .unwrap();
147
148 true
149 }
150}
151
#[cfg(test)]
mod test {
    use std::sync::atomic::{AtomicUsize, Ordering};

    use super::*;
    use crate::{cursor::PosType, handler::HandlerTrait, LoroDoc};

    /// A root subscriber that commits from inside its own callback must keep being
    /// re-invoked through the deferred queue until the text exceeds 10 characters.
    #[test]
    fn test_recursive_events() {
        let doc = Arc::new(LoroDoc::new());
        let handle = doc.clone();
        let calls = Arc::new(AtomicUsize::new(0));
        let calls_in_callback = Arc::clone(&calls);
        let _subscription = handle.subscribe_root(Arc::new(move |_| {
            calls_in_callback.fetch_add(1, Ordering::SeqCst);
            let mut txn = doc.txn().unwrap();
            let text = doc.get_text("id");
            if text.get_value().as_string().unwrap().len() > 10 {
                return;
            }
            text.insert_with_txn(&mut txn, 0, "123", PosType::Unicode)
                .unwrap();
            txn.commit().unwrap();
        }));

        let doc = handle;
        let mut txn = doc.txn().unwrap();
        let text = doc.get_text("id");
        text.insert_with_txn(&mut txn, 0, "123", PosType::Unicode)
            .unwrap();
        txn.commit().unwrap();

        let observed = calls.load(Ordering::SeqCst);
        assert!(observed > 2, "{}", observed);
    }

    /// Each commit reaches the root subscriber exactly once, and none do after
    /// `unsubscribe`.
    #[test]
    fn unsubscribe() {
        let doc = Arc::new(LoroDoc::new());
        let calls = Arc::new(AtomicUsize::new(0));
        let calls_in_callback = Arc::clone(&calls);
        let subscription = doc.subscribe_root(Arc::new(move |_| {
            calls_in_callback.fetch_add(1, Ordering::SeqCst);
        }));

        let text = doc.get_text("id");
        let commit_insert = || {
            let mut txn = doc.txn().unwrap();
            text.insert_with_txn(&mut txn, 0, "123", PosType::Unicode)
                .unwrap();
            txn.commit().unwrap();
        };

        assert_eq!(calls.load(Ordering::SeqCst), 0);
        commit_insert();
        assert_eq!(calls.load(Ordering::SeqCst), 1);
        commit_insert();
        assert_eq!(calls.load(Ordering::SeqCst), 2);
        subscription.unsubscribe();
        commit_insert();
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }
}