1use crate::{
4 DBData, Timestamp, ZWeight,
5 algebra::{IndexedZSet, OrdIndexedZSet},
6 circuit::{Circuit, Stream},
7 dynamic::{ClonableTrait, DataTrait, Erase},
8 operator::dynamic::{
9 aggregate::{
10 IncAggregateLinearFactories, StreamLinearAggregateFactories, WeightedCountOutFunc,
11 },
12 distinct::DistinctFactories,
13 },
14 trace::{BatchReaderFactories, Deserializable},
15};
16
17pub struct DistinctCountFactories<Z, O, T>
18where
19 Z: IndexedZSet,
20 O: IndexedZSet<Key = Z::Key>,
21 O::Val: DataTrait,
22 T: Timestamp,
23{
24 distinct_factories: DistinctFactories<Z, T>,
25 aggregate_factories: IncAggregateLinearFactories<Z, Z::R, O, T>,
26}
27
28impl<Z, O, T> DistinctCountFactories<Z, O, T>
29where
30 Z: IndexedZSet,
31 O: IndexedZSet<Key = Z::Key>,
32 T: Timestamp,
33{
34 pub fn new<KType, VType, OType>() -> Self
35 where
36 KType: DBData + Erase<Z::Key>,
37 <KType as Deserializable>::ArchivedDeser: Ord,
38 VType: DBData + Erase<Z::Val>,
39 OType: DBData + Erase<O::Val>,
40 {
41 Self {
42 distinct_factories: DistinctFactories::new::<KType, VType>(),
43 aggregate_factories: IncAggregateLinearFactories::new::<KType, ZWeight, OType>(),
44 }
45 }
46}
47
48pub struct StreamDistinctCountFactories<Z, O>
49where
50 Z: IndexedZSet,
51 O: IndexedZSet<Key = Z::Key>,
52{
53 input_factories: Z::Factories,
54 aggregate_factories: StreamLinearAggregateFactories<Z, Z::R, O>,
55}
56
57impl<Z, O> StreamDistinctCountFactories<Z, O>
58where
59 Z: IndexedZSet,
60 O: IndexedZSet<Key = Z::Key>,
61{
62 pub fn new<KType, VType, OType>() -> Self
63 where
64 KType: DBData + Erase<Z::Key>,
65 <KType as Deserializable>::ArchivedDeser: Ord,
66 VType: DBData + Erase<Z::Val>,
67 OType: DBData + Erase<O::Val>,
68 {
69 Self {
70 input_factories: BatchReaderFactories::new::<KType, VType, ZWeight>(),
71 aggregate_factories: StreamLinearAggregateFactories::new::<KType, VType, ZWeight, OType>(
72 ),
73 }
74 }
75}
76
77impl<C, Z> Stream<C, Z>
78where
79 C: Circuit,
80 Z: IndexedZSet,
81{
82 #[allow(clippy::type_complexity)]
84 pub fn dyn_weighted_count(
85 &self,
86 persistent_id: Option<&str>,
87 factories: &IncAggregateLinearFactories<Z, Z::R, OrdIndexedZSet<Z::Key, Z::R>, C::Time>,
88 ) -> Stream<C, OrdIndexedZSet<Z::Key, Z::R>> {
89 self.dyn_weighted_count_generic(persistent_id, factories, Box::new(|w, out| w.move_to(out)))
90 }
91
92 pub fn dyn_weighted_count_generic<A, O>(
94 &self,
95 persistent_id: Option<&str>,
96 factories: &IncAggregateLinearFactories<Z, Z::R, O, C::Time>,
97 out_func: Box<dyn WeightedCountOutFunc<Z::R, A>>,
98 ) -> Stream<C, O>
99 where
100 O: IndexedZSet<Key = Z::Key, Val = A>,
101 A: DataTrait + ?Sized,
102 {
103 self.dyn_aggregate_linear_generic(
104 persistent_id,
105 factories,
106 Box::new(|_k, _v, w, res| w.clone_to(res)),
107 out_func,
108 )
109 }
110
111 #[allow(clippy::type_complexity)]
113 pub fn dyn_distinct_count(
114 &self,
115 persistent_id: Option<&str>,
116 factories: &DistinctCountFactories<Z, OrdIndexedZSet<Z::Key, Z::R>, C::Time>,
117 ) -> Stream<C, OrdIndexedZSet<Z::Key, Z::R>>
118 where
119 Z: Send,
120 {
121 self.dyn_distinct_count_generic(persistent_id, factories, Box::new(|w, out| w.move_to(out)))
122 }
123
124 pub fn dyn_distinct_count_generic<A, O>(
126 &self,
127 persistent_id: Option<&str>,
128 factories: &DistinctCountFactories<Z, O, C::Time>,
129 out_func: Box<dyn WeightedCountOutFunc<Z::R, A>>,
130 ) -> Stream<C, O>
131 where
132 A: DataTrait + ?Sized,
133 O: IndexedZSet<Key = Z::Key, Val = A>,
134 Z: Send,
135 {
136 self.dyn_distinct(&factories.distinct_factories)
137 .dyn_weighted_count_generic(persistent_id, &factories.aggregate_factories, out_func)
138 }
139
140 #[allow(clippy::type_complexity)]
142 pub fn dyn_stream_weighted_count(
143 &self,
144 factories: &StreamLinearAggregateFactories<Z, Z::R, OrdIndexedZSet<Z::Key, Z::R>>,
145 ) -> Stream<C, OrdIndexedZSet<Z::Key, Z::R>> {
146 self.dyn_stream_weighted_count_generic(factories, Box::new(|w, out| w.move_to(out)))
147 }
148
149 pub fn dyn_stream_weighted_count_generic<A, O>(
151 &self,
152 factories: &StreamLinearAggregateFactories<Z, Z::R, O>,
153 out_func: Box<dyn WeightedCountOutFunc<Z::R, A>>,
154 ) -> Stream<C, O>
155 where
156 A: DataTrait + ?Sized,
157 O: IndexedZSet<Key = Z::Key, Val = A>,
158 {
159 self.dyn_stream_aggregate_linear_generic(
160 factories,
161 Box::new(|_k, _v, w, res| w.clone_to(res)),
162 out_func,
163 )
164 }
165
166 #[allow(clippy::type_complexity)]
168 pub fn dyn_stream_distinct_count(
169 &self,
170 factories: &StreamDistinctCountFactories<Z, OrdIndexedZSet<Z::Key, Z::R>>,
171 ) -> Stream<C, OrdIndexedZSet<Z::Key, Z::R>>
172 where
173 Z: Send,
174 {
175 self.dyn_stream_distinct_count_generic(factories, Box::new(|w, out| w.move_to(out)))
176 }
177
178 pub fn dyn_stream_distinct_count_generic<A, O>(
180 &self,
181 factories: &StreamDistinctCountFactories<Z, O>,
182 out_func: Box<dyn WeightedCountOutFunc<Z::R, A>>,
183 ) -> Stream<C, O>
184 where
185 A: DataTrait + ?Sized,
186 O: IndexedZSet<Key = Z::Key, Val = A>,
187 Z: Send,
188 {
189 self.dyn_stream_distinct(&factories.input_factories)
190 .dyn_stream_weighted_count_generic(&factories.aggregate_factories, out_func)
191 }
192}
193
#[cfg(test)]
mod test {
    use crate::{
        Runtime, indexed_zset,
        typed_batch::{IndexedZSetReader, OrdIndexedZSet, SpineSnapshot},
        utils::Tup2,
    };
    use core::ops::Range;
    use rand::{Rng, SeedableRng, rngs::StdRng, seq::SliceRandom};

    /// Feeds a Z-set with geometrically growing, sign-alternating weights and
    /// checks both the incremental and the non-incremental weighted count
    /// against a closed-form expectation after every transaction.
    #[test]
    fn weighted_count_test() {
        let (mut circuit, (input_handle, counts, stream_counts)) =
            Runtime::init_circuit(1, move |circuit| {
                let (inputs, input_handle) = circuit.add_input_zset::<i64>();

                // Incremental count, integrated so that every step exposes the
                // running totals.
                let counts = inputs.weighted_count().accumulate_integrate();
                // Non-incremental count applied to the integrated input.
                let stream_counts = circuit
                    .non_incremental(&inputs, |_child, inputs| {
                        Ok(inputs.integrate().stream_weighted_count())
                    })
                    .unwrap();

                Ok((
                    input_handle,
                    counts.accumulate_output(),
                    stream_counts.accumulate_output(),
                ))
            })
            .unwrap();

        // Closed form for the partial sums of 1, -2, 4, -8, ...:
        // `(1 - (-2)^(n + 1)) / 3`. `2 << n` is `2^(n + 1)`, and its sign is
        // flipped when `n` is even (i.e. when `n + 1` is odd).
        fn a077925(n: i64) -> i64 {
            let mut x = 2 << n;
            if (n & 1) == 0 {
                x = -x;
            }
            (1 - x) / 3
        }

        let mut next = 0;
        let mut term = 0;
        let mut ones_count = 0;

        for _ in 0..10 {
            // Key 2 receives the weight that key 1 received one step earlier
            // (a weight of 0 on the first step).
            input_handle.push(2, next);
            // Weights pushed for key 1 are 1, -2, 4, -8, ...
            next = if next == 0 { 1 } else { next * (-2) };
            input_handle.push(1, next);

            circuit.transaction().unwrap();
            let counts = counts.concat().consolidate();
            let stream_counts = stream_counts.concat().consolidate();
            term += 1;

            // Key 2 lags key 1 by exactly one step.
            let twos_count = ones_count;
            ones_count = a077925(term - 1);

            // Key 2 is expected to be absent while its count is still zero.
            let expected_counts = if twos_count == 0 {
                indexed_zset! { 1 => {ones_count => 1 } }
            } else {
                indexed_zset! { 1 => {ones_count => 1 }, 2 => {twos_count => 1} }
            };

            assert_eq!(counts, expected_counts);
            assert_eq!(stream_counts, expected_counts);
        }
    }

    /// Pushes randomly generated indexed Z-sets through `distinct_count` and
    /// checks that, at every step, each key maps to the number of its values
    /// that were given a positive weight in that step's input.
    #[test]
    fn distinct_count_test() {
        // Number of steps to test.
        const N: usize = 50;

        // Range of keys.
        const K: Range<u64> = 0..10;
        // Range of values.
        const V: Range<u64> = 0..10;
        // Range of weights.
        const W: Range<i64> = -10..10;

        // Fixed seed keeps the test reproducible.
        let mut rng = StdRng::seed_from_u64(0);

        // `input[i]` is the batch fed at step `i`; `expected[i]` holds the
        // `(key, distinct_count, weight)` tuples expected after that step.
        let mut input: Vec<Vec<Tup2<u64, Tup2<i64, i64>>>> = Vec::new();
        let mut expected: Vec<Vec<(u64, i64, i64)>> = Vec::new();
        for _ in 0..N {
            let mut input_tuples = Vec::new();
            let mut expected_tuples = Vec::new();
            for k in K {
                // Take `n` values from a permutation of `V`; they are
                // pairwise distinct because `V` is a range.
                let mut values: Vec<u64> = V.collect();
                let n = rng.gen_range(V) as usize;
                values.partial_shuffle(&mut rng, n);

                let mut distinct_count = 0;
                for &value in &values[0..n] {
                    let w = rng.gen_range(W);
                    input_tuples.push(Tup2(k, Tup2(value as i64, w)));
                    // Only values with positive weight are counted.
                    if w > 0 {
                        distinct_count += 1;
                    }
                }
                // Keys without any positively weighted value are expected to
                // be absent from the output.
                if distinct_count > 0 {
                    expected_tuples.push((k, distinct_count, 1i64));
                }
            }
            input.push(input_tuples);
            expected.push(expected_tuples);
        }

        let (mut circuit, (source_handle, counts, _stream_counts)) =
            Runtime::init_circuit(1, move |circuit| {
                let (source, source_handle) = circuit.add_input_indexed_zset::<u64, i64>();
                let counts = source
                    .accumulate_differentiate()
                    .distinct_count()
                    .accumulate_integrate();
                // NOTE(review): the non-incremental output is wired up but is
                // not asserted on below.
                let stream_counts = source.stream_distinct_count();
                Ok((
                    source_handle,
                    counts.accumulate_output(),
                    stream_counts.accumulate_output(),
                ))
            })
            .unwrap();

        // `input` and `expected` are consumed here, so neither needs cloning.
        for (mut batch, expected_counts) in input.into_iter().zip(expected) {
            source_handle.append(&mut batch);
            circuit.transaction().unwrap();

            let counts = SpineSnapshot::<OrdIndexedZSet<u64, i64>>::concat(&counts.take_from_all())
                .iter()
                .collect::<Vec<_>>();

            assert_eq!(counts, expected_counts);
        }
    }
}