kaspa_notify/subscription/
compounded.rs

1use crate::{
2    address::{error::Result, tracker::Counters},
3    events::EventType,
4    scope::{Scope, UtxosChangedScope, VirtualChainChangedScope},
5    subscription::{context::SubscriptionContext, Command, Compounded, Mutation, Subscription},
6};
7use itertools::Itertools;
8use kaspa_addresses::{Address, Prefix};
9
10#[derive(Clone, Debug, PartialEq, Eq)]
11pub struct OverallSubscription {
12    event_type: EventType,
13    active: usize,
14}
15
16impl OverallSubscription {
17    pub fn new(event_type: EventType) -> Self {
18        Self { event_type, active: 0 }
19    }
20}
21
22impl Compounded for OverallSubscription {
23    fn compound(&mut self, mutation: Mutation, _context: &SubscriptionContext) -> Option<Mutation> {
24        assert_eq!(self.event_type(), mutation.event_type());
25        match mutation.command {
26            Command::Start => {
27                self.active += 1;
28                if self.active == 1 {
29                    return Some(mutation);
30                }
31            }
32            Command::Stop => {
33                assert!(self.active > 0);
34                self.active -= 1;
35                if self.active == 0 {
36                    return Some(mutation);
37                }
38            }
39        }
40        None
41    }
42}
43
44impl Subscription for OverallSubscription {
45    #[inline(always)]
46    fn event_type(&self) -> EventType {
47        self.event_type
48    }
49
50    fn active(&self) -> bool {
51        self.active > 0
52    }
53
54    fn scope(&self, _context: &SubscriptionContext) -> Scope {
55        self.event_type.into()
56    }
57}
58
/// Compounded subscription state for virtual-chain-changed events.
///
/// Holds two counters indexed by the `include_accepted_transaction_ids` flag cast to
/// `usize`: slot `1` ("all") for `true` and slot `0` ("reduced") for `false`.
#[derive(Clone, Default, Debug, PartialEq, Eq)]
pub struct VirtualChainChangedSubscription {
    /// `[reduced, all]` counters.
    include_accepted_transaction_ids: [usize; 2],
}

impl VirtualChainChangedSubscription {
    /// Slot of the counter for `include_accepted_transaction_ids == true`.
    const ALL_SLOT: usize = true as usize;
    /// Slot of the counter for `include_accepted_transaction_ids == false`.
    const REDUCED_SLOT: usize = false as usize;

    /// Current value of the "all" counter.
    #[inline(always)]
    fn all(&self) -> usize {
        self.include_accepted_transaction_ids[Self::ALL_SLOT]
    }

    /// Mutable access to the "all" counter.
    #[inline(always)]
    fn all_mut(&mut self) -> &mut usize {
        let counters = &mut self.include_accepted_transaction_ids;
        &mut counters[Self::ALL_SLOT]
    }

    /// Current value of the "reduced" counter.
    #[inline(always)]
    fn reduced(&self) -> usize {
        self.include_accepted_transaction_ids[Self::REDUCED_SLOT]
    }

    /// Mutable access to the "reduced" counter.
    #[inline(always)]
    fn reduced_mut(&mut self) -> &mut usize {
        let counters = &mut self.include_accepted_transaction_ids;
        &mut counters[Self::REDUCED_SLOT]
    }
}
85
86impl Compounded for VirtualChainChangedSubscription {
87    fn compound(&mut self, mutation: Mutation, _context: &SubscriptionContext) -> Option<Mutation> {
88        assert_eq!(self.event_type(), mutation.event_type());
89        if let Scope::VirtualChainChanged(ref scope) = mutation.scope {
90            let all = scope.include_accepted_transaction_ids;
91            match mutation.command {
92                Command::Start => {
93                    if all {
94                        // Add All
95                        *self.all_mut() += 1;
96                        if self.all() == 1 {
97                            return Some(mutation);
98                        }
99                    } else {
100                        // Add Reduced
101                        *self.reduced_mut() += 1;
102                        if self.reduced() == 1 && self.all() == 0 {
103                            return Some(mutation);
104                        }
105                    }
106                }
107                Command::Stop => {
108                    if !all {
109                        // Remove Reduced
110                        assert!(self.reduced() > 0);
111                        *self.reduced_mut() -= 1;
112                        if self.reduced() == 0 && self.all() == 0 {
113                            return Some(mutation);
114                        }
115                    } else {
116                        // Remove All
117                        assert!(self.all() > 0);
118                        *self.all_mut() -= 1;
119                        if self.all() == 0 {
120                            if self.reduced() > 0 {
121                                return Some(Mutation::new(
122                                    Command::Start,
123                                    Scope::VirtualChainChanged(VirtualChainChangedScope::new(false)),
124                                ));
125                            } else {
126                                return Some(mutation);
127                            }
128                        }
129                    }
130                }
131            }
132        }
133        None
134    }
135}
136
137impl Subscription for VirtualChainChangedSubscription {
138    #[inline(always)]
139    fn event_type(&self) -> EventType {
140        EventType::VirtualChainChanged
141    }
142
143    fn active(&self) -> bool {
144        self.include_accepted_transaction_ids.iter().sum::<usize>() > 0
145    }
146
147    fn scope(&self, _context: &SubscriptionContext) -> Scope {
148        Scope::VirtualChainChanged(VirtualChainChangedScope::new(self.all() > 0))
149    }
150}
151
152#[derive(Clone, Default, Debug, PartialEq, Eq)]
153pub struct UtxosChangedSubscription {
154    all: usize,
155    indexes: Counters,
156}
157
158impl UtxosChangedSubscription {
159    pub fn new() -> Self {
160        Self { all: 0, indexes: Counters::new() }
161    }
162
163    pub fn with_capacity(capacity: usize) -> Self {
164        Self { all: 0, indexes: Counters::with_capacity(capacity) }
165    }
166
167    pub fn to_addresses(&self, prefix: Prefix, context: &SubscriptionContext) -> Vec<Address> {
168        self.indexes
169            .iter()
170            .filter_map(|(&index, &count)| {
171                (count > 0).then_some(()).and_then(|_| context.address_tracker.get_address_at_index(index, prefix))
172            })
173            .collect_vec()
174    }
175
176    pub fn register(&mut self, addresses: Vec<Address>, context: &SubscriptionContext) -> Result<Vec<Address>> {
177        context.address_tracker.register(&mut self.indexes, addresses)
178    }
179
180    pub fn unregister(&mut self, addresses: Vec<Address>, context: &SubscriptionContext) -> Vec<Address> {
181        context.address_tracker.unregister(&mut self.indexes, addresses)
182    }
183}
184
185impl Compounded for UtxosChangedSubscription {
186    fn compound(&mut self, mutation: Mutation, context: &SubscriptionContext) -> Option<Mutation> {
187        assert_eq!(self.event_type(), mutation.event_type());
188        if let Scope::UtxosChanged(scope) = mutation.scope {
189            match mutation.command {
190                Command::Start => {
191                    if scope.addresses.is_empty() {
192                        // Add All
193                        self.all += 1;
194                        if self.all == 1 {
195                            return Some(Mutation::new(Command::Start, UtxosChangedScope::default().into()));
196                        }
197                    } else {
198                        // Add(A)
199                        let added = self.register(scope.addresses, context).expect("compounded always registers");
200                        if !added.is_empty() && self.all == 0 {
201                            return Some(Mutation::new(Command::Start, UtxosChangedScope::new(added).into()));
202                        }
203                    }
204                }
205                Command::Stop => {
206                    if !scope.addresses.is_empty() {
207                        // Remove(R)
208                        let removed = self.unregister(scope.addresses, context);
209                        if !removed.is_empty() && self.all == 0 {
210                            return Some(Mutation::new(Command::Stop, UtxosChangedScope::new(removed).into()));
211                        }
212                    } else {
213                        // Remove All
214                        assert!(self.all > 0);
215                        self.all -= 1;
216                        if self.all == 0 {
217                            let addresses = self.to_addresses(Prefix::Mainnet, context);
218                            if !addresses.is_empty() {
219                                return Some(Mutation::new(Command::Start, UtxosChangedScope::new(addresses).into()));
220                            } else {
221                                return Some(Mutation::new(Command::Stop, UtxosChangedScope::default().into()));
222                            }
223                        }
224                    }
225                }
226            }
227        }
228        None
229    }
230}
231
232impl Subscription for UtxosChangedSubscription {
233    #[inline(always)]
234    fn event_type(&self) -> EventType {
235        EventType::UtxosChanged
236    }
237
238    fn active(&self) -> bool {
239        self.all > 0 || !self.indexes.is_empty()
240    }
241
242    fn scope(&self, context: &SubscriptionContext) -> Scope {
243        let addresses = if self.all > 0 { vec![] } else { self.to_addresses(Prefix::Mainnet, context) };
244        Scope::UtxosChanged(UtxosChangedScope::new(addresses))
245    }
246}
247
#[cfg(test)]
mod tests {
    use kaspa_core::trace;

    use super::super::*;
    use super::*;
    use crate::{
        address::{test_helpers::get_3_addresses, tracker::Counter},
        scope::BlockAddedScope,
    };
    use std::panic::AssertUnwindSafe;

    /// One mutation fed to `compound`, paired with the value `compound` must return.
    struct Step {
        name: &'static str,
        mutation: Mutation,
        result: Option<Mutation>,
    }

    /// A scenario: an initial state, an ordered list of steps and the expected final state.
    struct Test {
        name: &'static str,
        context: SubscriptionContext,
        initial_state: CompoundedSubscription,
        steps: Vec<Step>,
        final_state: CompoundedSubscription,
    }

    impl Test {
        /// Replays every step on a clone of the initial state, checking each returned
        /// mutation and then the final state. Returns the resulting state.
        fn run(&self) -> CompoundedSubscription {
            let mut state = self.initial_state.clone_box();
            for (position, Step { name, mutation, result: expected }) in self.steps.iter().enumerate() {
                trace!("{}: {}", position, name);
                let produced = state.compound(mutation.clone(), &self.context);
                assert_eq!(*expected, produced, "{} - {}: wrong compound result", self.name, name);
                trace!("{}: state = {:?}", position, state);
            }
            assert_eq!(*self.final_state, *state, "{}: wrong final state", self.name);
            state
        }
    }

    #[test]
    #[allow(clippy::redundant_clone)]
    fn test_overall_compounding() {
        let empty = || Box::new(OverallSubscription::new(EventType::BlockAdded));
        let add = || Mutation::new(Command::Start, Scope::BlockAdded(BlockAddedScope {}));
        let remove = || Mutation::new(Command::Stop, Scope::BlockAdded(BlockAddedScope {}));

        let context = SubscriptionContext::new();
        let initial_state: CompoundedSubscription = empty();
        let steps = vec![
            Step { name: "add 1", mutation: add(), result: Some(add()) },
            Step { name: "add 2", mutation: add(), result: None },
            Step { name: "remove 2", mutation: remove(), result: None },
            Step { name: "remove 1", mutation: remove(), result: Some(remove()) },
        ];
        let final_state: CompoundedSubscription = empty();
        let test = Test { name: "OverallSubscription 0 to 2 to 0", context, initial_state, steps, final_state };
        let mut state = test.run();

        // One more removal on a zero counter has to panic.
        let outcome = std::panic::catch_unwind(AssertUnwindSafe(|| state.compound(remove(), &test.context)));
        assert!(outcome.is_err(), "{}: trying to remove when counter is zero must panic", test.name);
    }

    #[test]
    #[allow(clippy::redundant_clone)]
    fn test_virtual_chain_changed_compounding() {
        /// Builds a virtual-chain-changed mutation with the given command and flag.
        fn vcc(command: Command, include_accepted_transaction_ids: bool) -> Mutation {
            let scope = Scope::VirtualChainChanged(VirtualChainChangedScope { include_accepted_transaction_ids });
            Mutation { command, scope }
        }
        let empty = Box::<VirtualChainChangedSubscription>::default;
        let add_all = || vcc(Command::Start, true);
        let add_reduced = || vcc(Command::Start, false);
        let remove_reduced = || vcc(Command::Stop, false);
        let remove_all = || vcc(Command::Stop, true);

        let context = SubscriptionContext::new();
        let initial_state: CompoundedSubscription = empty();
        let steps = vec![
            Step { name: "add all 1", mutation: add_all(), result: Some(add_all()) },
            Step { name: "add all 2", mutation: add_all(), result: None },
            Step { name: "remove all 2", mutation: remove_all(), result: None },
            Step { name: "remove all 1", mutation: remove_all(), result: Some(remove_all()) },
            Step { name: "add reduced 1", mutation: add_reduced(), result: Some(add_reduced()) },
            Step { name: "add reduced 2", mutation: add_reduced(), result: None },
            Step { name: "remove reduced 2", mutation: remove_reduced(), result: None },
            Step { name: "remove reduced 1", mutation: remove_reduced(), result: Some(remove_reduced()) },
            // "all" and "reduced" interleaved
            Step { name: "add all 1", mutation: add_all(), result: Some(add_all()) },
            Step { name: "add reduced 1, masked by all", mutation: add_reduced(), result: None },
            Step { name: "remove all 1, revealing reduced", mutation: remove_all(), result: Some(add_reduced()) },
            Step { name: "add all 1, masking reduced", mutation: add_all(), result: Some(add_all()) },
            Step { name: "remove reduced 1, masked by all", mutation: remove_reduced(), result: None },
            Step { name: "remove all 1", mutation: remove_all(), result: Some(remove_all()) },
        ];
        let final_state: CompoundedSubscription = empty();
        let test = Test { name: "VirtualChainChanged", context, initial_state, steps, final_state };
        let mut state = test.run();

        // One more removal of either kind on zero counters has to panic.
        let outcome = std::panic::catch_unwind(AssertUnwindSafe(|| state.compound(remove_all(), &test.context)));
        assert!(outcome.is_err(), "{}: trying to remove all when counter is zero must panic", test.name);
        let outcome = std::panic::catch_unwind(AssertUnwindSafe(|| state.compound(remove_reduced(), &test.context)));
        assert!(outcome.is_err(), "{}: trying to remove reduced when counter is zero must panic", test.name);
    }

    #[test]
    #[allow(clippy::redundant_clone)]
    fn test_utxos_changed_compounding() {
        kaspa_core::log::try_init_logger("trace,kaspa_notify=trace");
        let a_stock = get_3_addresses(true);

        // Picks addresses out of the stock by position.
        let pick = |positions: &[usize]| -> Vec<_> { positions.iter().map(|&position| a_stock[position].clone()).collect() };
        // Builds a UTXOs-changed mutation over the picked addresses.
        let utxos = |command: Command, positions: &[usize]| -> Mutation {
            let scope = Scope::UtxosChanged(UtxosChangedScope::new(pick(positions)));
            Mutation { command, scope }
        };
        let empty = Box::<UtxosChangedSubscription>::default;

        let add_all = || utxos(Command::Start, &[]);
        let remove_all = || utxos(Command::Stop, &[]);
        let add_0 = || utxos(Command::Start, &[0]);
        let add_1 = || utxos(Command::Start, &[1]);
        let add_01 = || utxos(Command::Start, &[0, 1]);
        let remove_0 = || utxos(Command::Stop, &[0]);
        let remove_1 = || utxos(Command::Stop, &[1]);

        let context = SubscriptionContext::new();
        let initial_state: CompoundedSubscription = empty();
        let steps = vec![
            Step { name: "add all 1", mutation: add_all(), result: Some(add_all()) },
            Step { name: "add all 2", mutation: add_all(), result: None },
            Step { name: "remove all 2", mutation: remove_all(), result: None },
            Step { name: "remove all 1", mutation: remove_all(), result: Some(remove_all()) },
            Step { name: "add a0 1", mutation: add_0(), result: Some(add_0()) },
            Step { name: "add a0 2", mutation: add_0(), result: None },
            Step { name: "add a1 1", mutation: add_1(), result: Some(add_1()) },
            Step { name: "remove a0 2", mutation: remove_0(), result: None },
            Step { name: "remove a1 1", mutation: remove_1(), result: Some(remove_1()) },
            Step { name: "remove a0 1", mutation: remove_0(), result: Some(remove_0()) },
            // "all" and explicit address sets interleaved
            Step { name: "add all 1", mutation: add_all(), result: Some(add_all()) },
            Step { name: "add a0a1, masked by all", mutation: add_01(), result: None },
            Step { name: "remove all 1, revealing a0a1", mutation: remove_all(), result: Some(add_01()) },
            Step { name: "add all 1, masking a0a1", mutation: add_all(), result: Some(add_all()) },
            Step { name: "remove a1, masked by all", mutation: remove_1(), result: None },
            Step { name: "remove all 1, revealing a0", mutation: remove_all(), result: Some(add_0()) },
            Step { name: "remove a0", mutation: remove_0(), result: Some(remove_0()) },
        ];
        let final_state: CompoundedSubscription = Box::new(UtxosChangedSubscription {
            all: 0,
            indexes: Counters::with_counters(vec![
                Counter { index: 0, count: 0, locked: true },
                Counter { index: 1, count: 0, locked: false },
            ]),
        });
        let test = Test { name: "UtxosChanged", context, initial_state, steps, final_state };
        let mut state = test.run();

        // One more "remove all" on a zero counter has to panic.
        let outcome = std::panic::catch_unwind(AssertUnwindSafe(|| state.compound(remove_all(), &test.context)));
        assert!(outcome.is_err(), "{}: trying to remove all when counter is zero must panic", test.name);
        // Disabled check, kept for reference:
        // let result = std::panic::catch_unwind(AssertUnwindSafe(|| state.compound(remove_0(), &test.context)));
        // assert!(result.is_err(), "{}: trying to remove an address when its counter is zero must panic", test.name);
    }
}