// loro_internal/subscription.rs

use super::{
    arena::SharedArena,
    event::{DiffEvent, DocDiff},
};
use crate::{
    container::idx::ContainerIdx, utils::subscription::SubscriberSet, ContainerDiff, LoroDoc,
    Subscription,
};
use fxhash::FxHashMap;
use loro_common::{ContainerID, ID};
use smallvec::SmallVec;
use std::{collections::VecDeque, sync::Arc};

use crate::sync::Mutex;
15/// The callback of the local update.
16pub type LocalUpdateCallback = Box<dyn Fn(&Vec<u8>) -> bool + Send + Sync + 'static>;
17/// The callback of the peer id change. The second argument is the next counter for the peer.
18pub type PeerIdUpdateCallback = Box<dyn Fn(&ID) -> bool + Send + Sync + 'static>;
19pub type Subscriber = Arc<dyn (for<'a> Fn(DiffEvent<'a>)) + Send + Sync>;
20
21impl LoroDoc {
22    /// Subscribe to the changes of the peer id.
23    pub fn subscribe_peer_id_change(&self, callback: PeerIdUpdateCallback) -> Subscription {
24        let (s, enable) = self.peer_id_change_subs.inner().insert((), callback);
25        enable();
26        s
27    }
28}
29
30struct ObserverInner {
31    subscriber_set: SubscriberSet<Option<ContainerIdx>, Subscriber>,
32    queue: Arc<Mutex<VecDeque<DocDiff>>>,
33}
34
35impl Default for ObserverInner {
36    fn default() -> Self {
37        Self {
38            subscriber_set: SubscriberSet::new(),
39            queue: Arc::new(Mutex::new(VecDeque::new())),
40        }
41    }
42}
43
44pub struct Observer {
45    inner: ObserverInner,
46    arena: SharedArena,
47}
48
49impl std::fmt::Debug for Observer {
50    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
51        f.debug_struct("Observer").finish()
52    }
53}
54
55impl Observer {
56    pub fn new(arena: SharedArena) -> Self {
57        Self {
58            arena,
59            inner: ObserverInner::default(),
60        }
61    }
62
63    pub fn subscribe(&self, id: &ContainerID, callback: Subscriber) -> Subscription {
64        let idx = self.arena.register_container(id);
65        let inner = &self.inner;
66        let (sub, enable) = inner.subscriber_set.insert(Some(idx), callback);
67        enable();
68        sub
69    }
70
71    pub fn subscribe_root(&self, callback: Subscriber) -> Subscription {
72        let inner = &self.inner;
73        let (sub, enable) = inner.subscriber_set.insert(None, callback);
74        enable();
75        sub
76    }
77
78    pub(crate) fn emit(&self, doc_diff: DocDiff) {
79        let success = self.emit_inner(doc_diff);
80        if success {
81            let mut e = self.inner.queue.lock().unwrap().pop_front();
82            while let Some(event) = e {
83                self.emit_inner(event);
84                e = self.inner.queue.lock().unwrap().pop_front();
85            }
86        }
87    }
88
89    // When emitting changes, we need to make sure that the observer is not locked.
90    fn emit_inner(&self, doc_diff: DocDiff) -> bool {
91        let inner = &self.inner;
92        let mut container_events_map: FxHashMap<ContainerIdx, SmallVec<[&ContainerDiff; 1]>> =
93            Default::default();
94        for container_diff in doc_diff.diff.iter() {
95            self.arena
96                .with_ancestors(container_diff.idx, |ancestor, _| {
97                    if inner.subscriber_set.may_include(&Some(ancestor)) {
98                        container_events_map
99                            .entry(ancestor)
100                            .or_default()
101                            .push(container_diff);
102                    }
103                });
104        }
105
106        {
107            // Check whether we are calling events recursively.
108            // If so, push the event to the queue
109            if inner.subscriber_set.is_recursive_calling(&None)
110                || container_events_map
111                    .keys()
112                    .any(|x| inner.subscriber_set.is_recursive_calling(&Some(*x)))
113            {
114                drop(container_events_map);
115                inner.queue.lock().unwrap().push_back(doc_diff);
116                return false;
117            }
118        }
119
120        for (container_idx, container_diffs) in container_events_map {
121            inner
122                .subscriber_set
123                .retain(&Some(container_idx), &mut |callback| {
124                    (callback)(DiffEvent {
125                        current_target: Some(self.arena.get_container_id(container_idx).unwrap()),
126                        events: &container_diffs,
127                        event_meta: &doc_diff,
128                    });
129                    true
130                })
131                .unwrap();
132        }
133
134        let events: Vec<_> = doc_diff.diff.iter().collect();
135        inner
136            .subscriber_set
137            .retain(&None, &mut |callback| {
138                (callback)(DiffEvent {
139                    current_target: None,
140                    events: &events,
141                    event_meta: &doc_diff,
142                });
143                true
144            })
145            .unwrap();
146
147        true
148    }
149}
150
#[cfg(test)]
mod test {
    use std::sync::atomic::{AtomicUsize, Ordering};

    use tracing::trace;

    use super::*;
    use crate::{handler::HandlerTrait, LoroDoc};

    /// Commits made from inside a root subscriber are delivered as further events:
    /// the callback must run more than twice before the text grows past 10 chars.
    #[test]
    fn test_recursive_events() {
        let doc = Arc::new(LoroDoc::new());
        let doc_outer = Arc::clone(&doc);
        let calls = Arc::new(AtomicUsize::new(0));
        let calls_in_cb = Arc::clone(&calls);
        let _guard = doc_outer.subscribe_root(Arc::new(move |_| {
            calls_in_cb.fetch_add(1, Ordering::SeqCst);
            // The transaction is opened before the length check, as in the original flow.
            let mut txn = doc.txn().unwrap();
            let text = doc.get_text("id");
            // Stop re-triggering once the text is longer than 10 characters.
            if text.get_value().as_string().unwrap().len() > 10 {
                trace!("Skip");
                return;
            }
            text.insert_with_txn(&mut txn, 0, "123").unwrap();
            trace!("PRE Another commit");
            txn.commit().unwrap();
            trace!("AFTER Another commit");
        }));

        let doc = doc_outer;
        let mut txn = doc.txn().unwrap();
        let text = doc.get_text("id");
        text.insert_with_txn(&mut txn, 0, "123").unwrap();
        txn.commit().unwrap();
        let total = calls.load(Ordering::SeqCst);
        assert!(total > 2, "{}", total);
    }

    /// After `unsubscribe`, later commits no longer reach the root callback.
    #[test]
    fn unsubscribe() {
        let doc = Arc::new(LoroDoc::new());
        let calls = Arc::new(AtomicUsize::new(0));
        let calls_in_cb = Arc::clone(&calls);
        let sub = doc.subscribe_root(Arc::new(move |_| {
            calls_in_cb.fetch_add(1, Ordering::SeqCst);
        }));

        let text = doc.get_text("id");
        // One transaction that inserts "123" at the front and commits immediately.
        let insert_and_commit = || {
            let mut txn = doc.txn().unwrap();
            text.insert_with_txn(&mut txn, 0, "123").unwrap();
            txn.commit().unwrap();
        };

        assert_eq!(calls.load(Ordering::SeqCst), 0);
        insert_and_commit();
        assert_eq!(calls.load(Ordering::SeqCst), 1);
        insert_and_commit();
        assert_eq!(calls.load(Ordering::SeqCst), 2);
        sub.unsubscribe();
        insert_and_commit();
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }
}