loro_internal/
subscription.rs

1use super::{
2    arena::SharedArena,
3    event::{DiffEvent, DocDiff},
4};
5use crate::{
6    container::idx::ContainerIdx, utils::subscription::SubscriberSet, ContainerDiff, LoroDoc,
7    Subscription,
8};
9use rustc_hash::FxHashMap;
10use loro_common::{ContainerID, ID};
11use smallvec::SmallVec;
12use std::{collections::VecDeque, sync::Arc};
13
14use crate::sync::Mutex;
15/// The callback of the local update.
16pub type LocalUpdateCallback = Box<dyn Fn(&Vec<u8>) -> bool + Send + Sync + 'static>;
17/// The callback of the peer id change. The second argument is the next counter for the peer.
18pub type PeerIdUpdateCallback = Box<dyn Fn(&ID) -> bool + Send + Sync + 'static>;
19#[allow(clippy::unused_unit)]
20pub type Subscriber = Arc<dyn (for<'a> Fn(DiffEvent<'a>)) + Send + Sync>;
21
22impl LoroDoc {
23    /// Subscribe to the changes of the peer id.
24    pub fn subscribe_peer_id_change(&self, callback: PeerIdUpdateCallback) -> Subscription {
25        let (s, enable) = self.peer_id_change_subs.inner().insert((), callback);
26        enable();
27        s
28    }
29}
30
31struct ObserverInner {
32    subscriber_set: SubscriberSet<Option<ContainerIdx>, Subscriber>,
33    queue: Arc<Mutex<VecDeque<DocDiff>>>,
34}
35
36impl Default for ObserverInner {
37    fn default() -> Self {
38        Self {
39            subscriber_set: SubscriberSet::new(),
40            queue: Arc::new(Mutex::new(VecDeque::new())),
41        }
42    }
43}
44
45pub struct Observer {
46    inner: ObserverInner,
47    arena: SharedArena,
48}
49
50impl std::fmt::Debug for Observer {
51    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
52        f.debug_struct("Observer").finish()
53    }
54}
55
56impl Observer {
57    pub fn new(arena: SharedArena) -> Self {
58        Self {
59            arena,
60            inner: ObserverInner::default(),
61        }
62    }
63
64    pub fn subscribe(&self, id: &ContainerID, callback: Subscriber) -> Subscription {
65        let idx = self.arena.register_container(id);
66        let inner = &self.inner;
67        let (sub, enable) = inner.subscriber_set.insert(Some(idx), callback);
68        enable();
69        sub
70    }
71
72    pub fn subscribe_root(&self, callback: Subscriber) -> Subscription {
73        let inner = &self.inner;
74        let (sub, enable) = inner.subscriber_set.insert(None, callback);
75        enable();
76        sub
77    }
78
79    pub(crate) fn emit(&self, doc_diff: DocDiff) {
80        let success = self.emit_inner(doc_diff);
81        if success {
82            let mut e = self.inner.queue.lock().unwrap().pop_front();
83            while let Some(event) = e {
84                self.emit_inner(event);
85                e = self.inner.queue.lock().unwrap().pop_front();
86            }
87        }
88    }
89
90    // When emitting changes, we need to make sure that the observer is not locked.
91    fn emit_inner(&self, doc_diff: DocDiff) -> bool {
92        let inner = &self.inner;
93        let mut container_events_map: FxHashMap<ContainerIdx, SmallVec<[&ContainerDiff; 1]>> =
94            Default::default();
95        for container_diff in doc_diff.diff.iter() {
96            self.arena
97                .with_ancestors(container_diff.idx, |ancestor, _| {
98                    if inner.subscriber_set.may_include(&Some(ancestor)) {
99                        container_events_map
100                            .entry(ancestor)
101                            .or_default()
102                            .push(container_diff);
103                    }
104                });
105        }
106
107        {
108            // Check whether we are calling events recursively.
109            // If so, push the event to the queue
110            if inner.subscriber_set.is_recursive_calling(&None)
111                || container_events_map
112                    .keys()
113                    .any(|x| inner.subscriber_set.is_recursive_calling(&Some(*x)))
114            {
115                drop(container_events_map);
116                inner.queue.lock().unwrap().push_back(doc_diff);
117                return false;
118            }
119        }
120
121        for (container_idx, container_diffs) in container_events_map {
122            inner
123                .subscriber_set
124                .retain(&Some(container_idx), &mut |callback| {
125                    (callback)(DiffEvent {
126                        current_target: Some(self.arena.get_container_id(container_idx).unwrap()),
127                        events: &container_diffs,
128                        event_meta: &doc_diff,
129                    });
130                    true
131                })
132                .unwrap();
133        }
134
135        let events: Vec<_> = doc_diff.diff.iter().collect();
136        inner
137            .subscriber_set
138            .retain(&None, &mut |callback| {
139                (callback)(DiffEvent {
140                    current_target: None,
141                    events: &events,
142                    event_meta: &doc_diff,
143                });
144                true
145            })
146            .unwrap();
147
148        true
149    }
150}
151
#[cfg(test)]
mod test {
    use std::sync::atomic::{AtomicUsize, Ordering};

    use super::*;
    use crate::{handler::HandlerTrait, LoroDoc};

    /// A root subscriber that commits from inside its own callback must still have
    /// those nested commits delivered to it, so it fires more than twice.
    #[test]
    fn test_recursive_events() {
        let doc = Arc::new(LoroDoc::new());
        let doc_in_cb = Arc::clone(&doc);
        let calls = Arc::new(AtomicUsize::new(0));
        let calls_in_cb = Arc::clone(&calls);
        let _guard = doc.subscribe_root(Arc::new(move |_| {
            calls_in_cb.fetch_add(1, Ordering::SeqCst);
            let mut txn = doc_in_cb.txn().unwrap();
            let text = doc_in_cb.get_text("id");
            // Stop re-committing once the text is long enough so the chain ends.
            if text.get_value().as_string().unwrap().len() > 10 {
                return;
            }
            text.insert_with_txn(&mut txn, 0, "123").unwrap();
            txn.commit().unwrap();
        }));

        let mut txn = doc.txn().unwrap();
        let text = doc.get_text("id");
        text.insert_with_txn(&mut txn, 0, "123").unwrap();
        txn.commit().unwrap();

        let total = calls.load(Ordering::SeqCst);
        assert!(total > 2, "{}", total);
    }

    /// A root subscriber fires once per commit until it is unsubscribed, and
    /// never afterwards.
    #[test]
    fn unsubscribe() {
        let doc = Arc::new(LoroDoc::new());
        let calls = Arc::new(AtomicUsize::new(0));
        let calls_in_cb = Arc::clone(&calls);
        let sub = doc.subscribe_root(Arc::new(move |_| {
            calls_in_cb.fetch_add(1, Ordering::SeqCst);
        }));

        let text = doc.get_text("id");
        let commit_insert = || {
            let mut txn = doc.txn().unwrap();
            text.insert_with_txn(&mut txn, 0, "123").unwrap();
            txn.commit().unwrap();
        };

        assert_eq!(calls.load(Ordering::SeqCst), 0);
        for expected in 1..=2 {
            commit_insert();
            assert_eq!(calls.load(Ordering::SeqCst), expected);
        }

        sub.unsubscribe();
        commit_insert();
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }
}