// loro_internal/subscription.rs

1use super::{
2    arena::SharedArena,
3    event::{DiffEvent, DocDiff},
4};
5use crate::{
6    container::idx::ContainerIdx, utils::subscription::SubscriberSet, ContainerDiff, LoroDoc,
7    Subscription,
8};
9use fxhash::FxHashMap;
10use loro_common::{ContainerID, ID};
11use smallvec::SmallVec;
12use std::{
13    collections::VecDeque,
14    sync::{Arc, Mutex},
15};
16
17/// The callback of the local update.
18pub type LocalUpdateCallback = Box<dyn Fn(&Vec<u8>) -> bool + Send + Sync + 'static>;
19/// The callback of the peer id change. The second argument is the next counter for the peer.
20pub type PeerIdUpdateCallback = Box<dyn Fn(&ID) -> bool + Send + Sync + 'static>;
21pub type Subscriber = Arc<dyn (for<'a> Fn(DiffEvent<'a>)) + Send + Sync>;
22
23impl LoroDoc {
24    /// Subscribe to the changes of the peer id.
25    pub fn subscribe_peer_id_change(&self, callback: PeerIdUpdateCallback) -> Subscription {
26        let (s, enable) = self.peer_id_change_subs.inner().insert((), callback);
27        enable();
28        s
29    }
30}
31
32struct ObserverInner {
33    subscriber_set: SubscriberSet<Option<ContainerIdx>, Subscriber>,
34    queue: Arc<Mutex<VecDeque<DocDiff>>>,
35}
36
37impl Default for ObserverInner {
38    fn default() -> Self {
39        Self {
40            subscriber_set: SubscriberSet::new(),
41            queue: Arc::new(Mutex::new(VecDeque::new())),
42        }
43    }
44}
45
46pub struct Observer {
47    inner: ObserverInner,
48    arena: SharedArena,
49}
50
51impl std::fmt::Debug for Observer {
52    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
53        f.debug_struct("Observer").finish()
54    }
55}
56
57impl Observer {
58    pub fn new(arena: SharedArena) -> Self {
59        Self {
60            arena,
61            inner: ObserverInner::default(),
62        }
63    }
64
65    pub fn subscribe(&self, id: &ContainerID, callback: Subscriber) -> Subscription {
66        let idx = self.arena.register_container(id);
67        let inner = &self.inner;
68        let (sub, enable) = inner.subscriber_set.insert(Some(idx), callback);
69        enable();
70        sub
71    }
72
73    pub fn subscribe_root(&self, callback: Subscriber) -> Subscription {
74        let inner = &self.inner;
75        let (sub, enable) = inner.subscriber_set.insert(None, callback);
76        enable();
77        sub
78    }
79
80    pub(crate) fn emit(&self, doc_diff: DocDiff) {
81        let success = self.emit_inner(doc_diff);
82        if success {
83            let mut e = self.inner.queue.try_lock().unwrap().pop_front();
84            while let Some(event) = e {
85                self.emit_inner(event);
86                e = self.inner.queue.try_lock().unwrap().pop_front();
87            }
88        }
89    }
90
91    // When emitting changes, we need to make sure that the observer is not locked.
92    fn emit_inner(&self, doc_diff: DocDiff) -> bool {
93        let inner = &self.inner;
94        let mut container_events_map: FxHashMap<ContainerIdx, SmallVec<[&ContainerDiff; 1]>> =
95            Default::default();
96        for container_diff in doc_diff.diff.iter() {
97            self.arena
98                .with_ancestors(container_diff.idx, |ancestor, _| {
99                    if inner.subscriber_set.may_include(&Some(ancestor)) {
100                        container_events_map
101                            .entry(ancestor)
102                            .or_default()
103                            .push(container_diff);
104                    }
105                });
106        }
107
108        {
109            // Check whether we are calling events recursively.
110            // If so, push the event to the queue
111            if inner.subscriber_set.is_recursive_calling(&None)
112                || container_events_map
113                    .keys()
114                    .any(|x| inner.subscriber_set.is_recursive_calling(&Some(*x)))
115            {
116                drop(container_events_map);
117                inner.queue.try_lock().unwrap().push_back(doc_diff);
118                return false;
119            }
120        }
121
122        for (container_idx, container_diffs) in container_events_map {
123            inner
124                .subscriber_set
125                .retain(&Some(container_idx), &mut |callback| {
126                    (callback)(DiffEvent {
127                        current_target: Some(self.arena.get_container_id(container_idx).unwrap()),
128                        events: &container_diffs,
129                        event_meta: &doc_diff,
130                    });
131                    true
132                })
133                .unwrap();
134        }
135
136        let events: Vec<_> = doc_diff.diff.iter().collect();
137        inner
138            .subscriber_set
139            .retain(&None, &mut |callback| {
140                (callback)(DiffEvent {
141                    current_target: None,
142                    events: &events,
143                    event_meta: &doc_diff,
144                });
145                true
146            })
147            .unwrap();
148
149        true
150    }
151}
152
#[cfg(test)]
mod test {
    use std::sync::atomic::{AtomicUsize, Ordering};

    use tracing::trace;

    use super::*;
    use crate::{handler::HandlerTrait, LoroDoc};

    /// Commits made from inside a root subscriber are queued and delivered once the running
    /// callback returns, so the subscriber keeps firing until the text exceeds 10 chars.
    #[test]
    fn test_recursive_events() {
        let doc_in_callback = Arc::new(LoroDoc::new());
        let doc_handle = Arc::clone(&doc_in_callback);
        let calls = Arc::new(AtomicUsize::new(0));
        let calls_in_callback = Arc::clone(&calls);
        let _subscription = doc_handle.subscribe_root(Arc::new(move |_| {
            calls_in_callback.fetch_add(1, Ordering::SeqCst);
            let mut txn = doc_in_callback.txn().unwrap();
            let text = doc_in_callback.get_text("id");
            if text.get_value().as_string().unwrap().len() > 10 {
                trace!("Skip");
            } else {
                text.insert_with_txn(&mut txn, 0, "123").unwrap();
                trace!("PRE Another commit");
                txn.commit().unwrap();
                trace!("AFTER Another commit");
            }
        }));

        // Rebind so that this handle is dropped before `_subscription` at the end of the test.
        let doc = doc_handle;
        let mut txn = doc.txn().unwrap();
        let text = doc.get_text("id");
        text.insert_with_txn(&mut txn, 0, "123").unwrap();
        txn.commit().unwrap();
        let total = calls.load(Ordering::SeqCst);
        assert!(total > 2, "{}", total);
    }

    /// A root subscriber stops being invoked once it has been unsubscribed.
    #[test]
    fn unsubscribe() {
        let doc = Arc::new(LoroDoc::new());
        let hits = Arc::new(AtomicUsize::new(0));
        let hits_in_callback = Arc::clone(&hits);
        let sub = doc.subscribe_root(Arc::new(move |_| {
            hits_in_callback.fetch_add(1, Ordering::SeqCst);
        }));

        let text = doc.get_text("id");
        let insert_and_commit = || {
            let mut txn = doc.txn().unwrap();
            text.insert_with_txn(&mut txn, 0, "123").unwrap();
            txn.commit().unwrap();
        };
        let hits_so_far = || hits.load(Ordering::SeqCst);

        assert_eq!(hits_so_far(), 0);
        insert_and_commit();
        assert_eq!(hits_so_far(), 1);
        insert_and_commit();
        assert_eq!(hits_so_far(), 2);
        sub.unsubscribe();
        insert_and_commit();
        assert_eq!(hits_so_far(), 2);
    }
}