Skip to main content

dbsp/operator/dynamic/
count.rs

1//! Count operators.
2
3use crate::{
4    DBData, Timestamp, ZWeight,
5    algebra::{IndexedZSet, OrdIndexedZSet},
6    circuit::{Circuit, Stream},
7    dynamic::{ClonableTrait, DataTrait, Erase},
8    operator::dynamic::{
9        aggregate::{
10            IncAggregateLinearFactories, StreamLinearAggregateFactories, WeightedCountOutFunc,
11        },
12        distinct::DistinctFactories,
13    },
14    trace::{BatchReaderFactories, Deserializable},
15};
16
17pub struct DistinctCountFactories<Z, O, T>
18where
19    Z: IndexedZSet,
20    O: IndexedZSet<Key = Z::Key>,
21    O::Val: DataTrait,
22    T: Timestamp,
23{
24    distinct_factories: DistinctFactories<Z, T>,
25    aggregate_factories: IncAggregateLinearFactories<Z, Z::R, O, T>,
26}
27
28impl<Z, O, T> DistinctCountFactories<Z, O, T>
29where
30    Z: IndexedZSet,
31    O: IndexedZSet<Key = Z::Key>,
32    T: Timestamp,
33{
34    pub fn new<KType, VType, OType>() -> Self
35    where
36        KType: DBData + Erase<Z::Key>,
37        <KType as Deserializable>::ArchivedDeser: Ord,
38        VType: DBData + Erase<Z::Val>,
39        OType: DBData + Erase<O::Val>,
40    {
41        Self {
42            distinct_factories: DistinctFactories::new::<KType, VType>(),
43            aggregate_factories: IncAggregateLinearFactories::new::<KType, ZWeight, OType>(),
44        }
45    }
46}
47
48pub struct StreamDistinctCountFactories<Z, O>
49where
50    Z: IndexedZSet,
51    O: IndexedZSet<Key = Z::Key>,
52{
53    input_factories: Z::Factories,
54    aggregate_factories: StreamLinearAggregateFactories<Z, Z::R, O>,
55}
56
57impl<Z, O> StreamDistinctCountFactories<Z, O>
58where
59    Z: IndexedZSet,
60    O: IndexedZSet<Key = Z::Key>,
61{
62    pub fn new<KType, VType, OType>() -> Self
63    where
64        KType: DBData + Erase<Z::Key>,
65        <KType as Deserializable>::ArchivedDeser: Ord,
66        VType: DBData + Erase<Z::Val>,
67        OType: DBData + Erase<O::Val>,
68    {
69        Self {
70            input_factories: BatchReaderFactories::new::<KType, VType, ZWeight>(),
71            aggregate_factories: StreamLinearAggregateFactories::new::<KType, VType, ZWeight, OType>(
72            ),
73        }
74    }
75}
76
77impl<C, Z> Stream<C, Z>
78where
79    C: Circuit,
80    Z: IndexedZSet,
81{
82    /// See [`Stream::weighted_count`].
83    #[allow(clippy::type_complexity)]
84    pub fn dyn_weighted_count(
85        &self,
86        persistent_id: Option<&str>,
87        factories: &IncAggregateLinearFactories<Z, Z::R, OrdIndexedZSet<Z::Key, Z::R>, C::Time>,
88    ) -> Stream<C, OrdIndexedZSet<Z::Key, Z::R>> {
89        self.dyn_weighted_count_generic(persistent_id, factories, Box::new(|w, out| w.move_to(out)))
90    }
91
92    /// Like [`Self::dyn_weighted_count`], but can return any batch type.
93    pub fn dyn_weighted_count_generic<A, O>(
94        &self,
95        persistent_id: Option<&str>,
96        factories: &IncAggregateLinearFactories<Z, Z::R, O, C::Time>,
97        out_func: Box<dyn WeightedCountOutFunc<Z::R, A>>,
98    ) -> Stream<C, O>
99    where
100        O: IndexedZSet<Key = Z::Key, Val = A>,
101        A: DataTrait + ?Sized,
102    {
103        self.dyn_aggregate_linear_generic(
104            persistent_id,
105            factories,
106            Box::new(|_k, _v, w, res| w.clone_to(res)),
107            out_func,
108        )
109    }
110
111    /// See [`Stream::distinct_count`].
112    #[allow(clippy::type_complexity)]
113    pub fn dyn_distinct_count(
114        &self,
115        persistent_id: Option<&str>,
116        factories: &DistinctCountFactories<Z, OrdIndexedZSet<Z::Key, Z::R>, C::Time>,
117    ) -> Stream<C, OrdIndexedZSet<Z::Key, Z::R>>
118    where
119        Z: Send,
120    {
121        self.dyn_distinct_count_generic(persistent_id, factories, Box::new(|w, out| w.move_to(out)))
122    }
123
124    /// Like [`Self::dyn_distinct_count`], but can return any batch type.
125    pub fn dyn_distinct_count_generic<A, O>(
126        &self,
127        persistent_id: Option<&str>,
128        factories: &DistinctCountFactories<Z, O, C::Time>,
129        out_func: Box<dyn WeightedCountOutFunc<Z::R, A>>,
130    ) -> Stream<C, O>
131    where
132        A: DataTrait + ?Sized,
133        O: IndexedZSet<Key = Z::Key, Val = A>,
134        Z: Send,
135    {
136        self.dyn_distinct(&factories.distinct_factories)
137            .dyn_weighted_count_generic(persistent_id, &factories.aggregate_factories, out_func)
138    }
139
140    /// See [`Stream::stream_weighted_count`].
141    #[allow(clippy::type_complexity)]
142    pub fn dyn_stream_weighted_count(
143        &self,
144        factories: &StreamLinearAggregateFactories<Z, Z::R, OrdIndexedZSet<Z::Key, Z::R>>,
145    ) -> Stream<C, OrdIndexedZSet<Z::Key, Z::R>> {
146        self.dyn_stream_weighted_count_generic(factories, Box::new(|w, out| w.move_to(out)))
147    }
148
149    /// Like [`Self::dyn_stream_weighted_count`], but can return any batch type.
150    pub fn dyn_stream_weighted_count_generic<A, O>(
151        &self,
152        factories: &StreamLinearAggregateFactories<Z, Z::R, O>,
153        out_func: Box<dyn WeightedCountOutFunc<Z::R, A>>,
154    ) -> Stream<C, O>
155    where
156        A: DataTrait + ?Sized,
157        O: IndexedZSet<Key = Z::Key, Val = A>,
158    {
159        self.dyn_stream_aggregate_linear_generic(
160            factories,
161            Box::new(|_k, _v, w, res| w.clone_to(res)),
162            out_func,
163        )
164    }
165
166    /// See [`Stream::stream_distinct_count`].
167    #[allow(clippy::type_complexity)]
168    pub fn dyn_stream_distinct_count(
169        &self,
170        factories: &StreamDistinctCountFactories<Z, OrdIndexedZSet<Z::Key, Z::R>>,
171    ) -> Stream<C, OrdIndexedZSet<Z::Key, Z::R>>
172    where
173        Z: Send,
174    {
175        self.dyn_stream_distinct_count_generic(factories, Box::new(|w, out| w.move_to(out)))
176    }
177
178    /// Like [`Self::dyn_distinct_count`], but can return any batch type.
179    pub fn dyn_stream_distinct_count_generic<A, O>(
180        &self,
181        factories: &StreamDistinctCountFactories<Z, O>,
182        out_func: Box<dyn WeightedCountOutFunc<Z::R, A>>,
183    ) -> Stream<C, O>
184    where
185        A: DataTrait + ?Sized,
186        O: IndexedZSet<Key = Z::Key, Val = A>,
187        Z: Send,
188    {
189        self.dyn_stream_distinct(&factories.input_factories)
190            .dyn_stream_weighted_count_generic(&factories.aggregate_factories, out_func)
191    }
192}
193
#[cfg(test)]
mod test {
    use crate::{
        Runtime, indexed_zset,
        typed_batch::{IndexedZSetReader, OrdIndexedZSet, SpineSnapshot},
        utils::Tup2,
    };
    use core::ops::Range;
    use rand::{Rng, SeedableRng, rngs::StdRng, seq::SliceRandom};

    /// Checks `weighted_count` (incremental) and `stream_weighted_count`
    /// (evaluated non-incrementally over the integrated input) against a
    /// closed-form expectation.
    #[test]
    fn weighted_count_test() {
        let (mut circuit, (input_handle, counts, stream_counts)) =
            Runtime::init_circuit(1, move |circuit| {
                let (inputs, input_handle) = circuit.add_input_zset::<i64>();

                let counts = inputs.weighted_count().accumulate_integrate();
                let stream_counts = circuit
                    .non_incremental(&inputs, |_child, inputs| {
                        Ok(inputs.integrate().stream_weighted_count())
                    })
                    .unwrap();

                Ok((
                    input_handle,
                    counts.accumulate_output(),
                    stream_counts.accumulate_output(),
                ))
            })
            .unwrap();

        // Generate expected values in `counts` by another means, using the formula for
        // A077925 (https://oeis.org/A077925).
        fn a077925(n: i64) -> i64 {
            let mut x = 2 << n;
            if (n & 1) == 0 {
                x = -x;
            }
            (1 - x) / 3
        }

        let mut next = 0;
        let mut term = 0;
        let mut ones_count = 0;

        for _ in 0..10 {
            // Key 1 receives weights 1, -2, 4, -8, 16, -32, ...
            // Key 2 receives the same sequence delayed by one step (its first
            // push carries weight 0).
            input_handle.push(2, next);
            next = if next == 0 { 1 } else { next * (-2) };
            input_handle.push(1, next);

            circuit.transaction().unwrap();
            let counts = counts.concat().consolidate();
            let stream_counts = stream_counts.concat().consolidate();

            term += 1;

            // Key 2 lags key 1 by one step, so its expected count is key 1's
            // count from the previous step.
            let twos_count = ones_count;
            ones_count = a077925(term - 1);

            let expected_counts = if twos_count == 0 {
                indexed_zset! { 1 => {ones_count => 1 } }
            } else {
                indexed_zset! { 1 => {ones_count => 1 }, 2 => {twos_count => 1} }
            };

            assert_eq!(counts, expected_counts);
            assert_eq!(stream_counts, expected_counts);
        }
    }

    /// Checks `distinct_count` against counts computed directly from randomly
    /// generated inputs.
    #[test]
    fn distinct_count_test() {
        // Number of steps to test.
        const N: usize = 50;

        // Generate `input` as a vector of `N` Z-sets with keys in range `K`, values in
        // range `V`, and weights in range `W`, and `expected` as a vector that
        // for each element in `input` contains a Z-set that maps from each key
        // to the number of values with positive weight.
        const K: Range<u64> = 0..10; // Range of keys in Z-set.
        const V: Range<u64> = 0..10; // Range of values in Z-set.
        const W: Range<i64> = -10..10; // Range of weights in Z-set.
        let mut rng = StdRng::seed_from_u64(0); // Make the test reproducible.
        let mut input: Vec<Vec<Tup2<u64, Tup2<i64, i64>>>> = Vec::new();
        let mut expected: Vec<Vec<(u64, i64, i64)>> = Vec::new();
        for _ in 0..N {
            let mut input_tuples = Vec::new();
            let mut expected_tuples = Vec::new();
            for k in K {
                let mut v: Vec<u64> = V.collect();
                let n = rng.gen_range(V);
                v.partial_shuffle(&mut rng, n as usize);

                let mut distinct_count = 0;
                for &v in &v[0..n as usize] {
                    let w = rng.gen_range(W);
                    input_tuples.push(Tup2(k, Tup2(v as i64, w)));
                    if w > 0 {
                        distinct_count += 1;
                    }
                }
                if distinct_count > 0 {
                    expected_tuples.push((k, distinct_count, 1i64));
                }
            }
            input.push(input_tuples);
            expected.push(expected_tuples);
        }

        let (mut circuit, (source_handle, counts, _stream_counts)) =
            Runtime::init_circuit(1, move |circuit| {
                let (source, source_handle) = circuit.add_input_indexed_zset::<u64, i64>();
                let counts = source
                    .accumulate_differentiate()
                    .distinct_count()
                    .accumulate_integrate();
                let stream_counts = source.stream_distinct_count();
                Ok((
                    source_handle,
                    counts.accumulate_output(),
                    stream_counts.accumulate_output(),
                ))
            })
            .unwrap();

        // `input` is not needed after this point, so consume it directly.
        for (mut input, expected_counts) in input.into_iter().zip(expected) {
            source_handle.append(&mut input);
            circuit.transaction().unwrap();

            let counts = SpineSnapshot::<OrdIndexedZSet<u64, i64>>::concat(&counts.take_from_all())
                .iter()
                .collect::<Vec<_>>();

            assert_eq!(counts, expected_counts);

            // TODO: also check the output of `stream_distinct_count`
            // (`_stream_counts`) against `expected_counts`.
        }
    }
}