1use crate::{
4 circuit::{
5 Circuit, OwnershipPreference, Scope, Stream,
6 circuit_builder::StreamId,
7 operator_traits::{Operator, UnaryOperator},
8 },
9 circuit_cache_key,
10 dynamic::{ClonableTrait, DataTrait, DynPair, DynUnit},
11 trace::{
12 Batch, BatchFactories, BatchReader, BatchReaderFactories, Builder, Cursor, OrdIndexedWSet,
13 },
14};
15use std::{borrow::Cow, marker::PhantomData, ops::DerefMut};
16
17circuit_cache_key!(IndexId<C, D>(StreamId => Stream<C, D>));
18
19impl<C, CI> Stream<C, CI>
20where
21 CI: Clone + 'static,
22 C: Circuit,
23{
24 pub fn index<K, V>(
31 &self,
32 output_factories: &<OrdIndexedWSet<K, V, CI::R> as BatchReader>::Factories,
33 ) -> Stream<C, OrdIndexedWSet<K, V, CI::R>>
34 where
35 K: DataTrait + ?Sized,
36 V: DataTrait + ?Sized,
37 CI: BatchReader<Key = DynPair<K, V>, Val = DynUnit, Time = ()>,
38 {
39 self.index_generic(output_factories)
40 }
41
42 pub fn index_generic<CO>(&self, output_factories: &CO::Factories) -> Stream<C, CO>
45 where
46 CI: BatchReader<Key = DynPair<CO::Key, CO::Val>, Val = DynUnit, Time = (), R = CO::R>,
47 CO: Batch<Time = ()>,
48 {
49 self.circuit()
50 .cache_get_or_insert_with(IndexId::new(self.stream_id()), || {
51 self.circuit()
52 .add_unary_operator(Index::new(output_factories), self)
53 })
54 .clone()
55 }
56
57 pub fn index_with<K, V, F>(
67 &self,
68 output_factories: &<OrdIndexedWSet<K, V, CI::R> as BatchReader>::Factories,
69 index_func: F,
70 ) -> Stream<C, OrdIndexedWSet<K, V, CI::R>>
71 where
72 CI: BatchReader<Time = (), Val = DynUnit>,
73 F: Fn(&CI::Key, &mut DynPair<K, V>) + Clone + 'static,
74 K: DataTrait + ?Sized,
75 V: DataTrait + ?Sized,
76 {
77 self.index_with_generic(index_func, output_factories)
78 }
79
80 pub fn index_with_generic<CO, F>(
83 &self,
84 index_func: F,
85 output_factories: &CO::Factories,
86 ) -> Stream<C, CO>
87 where
88 CI: BatchReader<Time = (), Val = DynUnit>,
89 CO: Batch<Time = (), R = CI::R>,
90 F: Fn(&CI::Key, &mut DynPair<CO::Key, CO::Val>) + Clone + 'static,
91 {
92 self.circuit()
93 .add_unary_operator(IndexWith::new(index_func, output_factories), self)
94 }
95}
96
97pub struct Index<CI, CO: BatchReader> {
108 factories: CO::Factories,
109 _type: PhantomData<(CI, CO)>,
110}
111
112impl<CI, CO: BatchReader> Index<CI, CO> {
113 pub fn new(factories: &CO::Factories) -> Self {
114 Self {
115 factories: factories.clone(),
116 _type: PhantomData,
117 }
118 }
119}
120
121impl<CI, CO> Operator for Index<CI, CO>
122where
123 CI: 'static,
124 CO: BatchReader + 'static,
125{
126 fn name(&self) -> Cow<'static, str> {
127 Cow::from("Index")
128 }
129 fn fixedpoint(&self, _scope: Scope) -> bool {
130 true
131 }
132}
133
134impl<CI, CO> UnaryOperator<CI, CO> for Index<CI, CO>
135where
136 CO: Batch<Time = ()>,
137 CI: BatchReader<Key = DynPair<CO::Key, CO::Val>, Val = DynUnit, Time = (), R = CO::R>,
138{
139 async fn eval(&mut self, input: &CI) -> CO {
140 let mut builder =
141 <CO as Batch>::Builder::with_capacity(&self.factories, input.len(), input.len());
142
143 let mut cursor = input.cursor();
144 let mut prev_key = self.factories.key_factory().default_box();
145 let mut has_prev_key = false;
146 while cursor.key_valid() {
147 builder.push_diff(cursor.weight());
148 let (k, v) = cursor.key().split();
149 if has_prev_key {
150 if k != &*prev_key {
151 builder.push_key_mut(&mut prev_key);
152 k.clone_to(&mut prev_key);
153 }
154 } else {
155 k.clone_to(&mut prev_key);
156 has_prev_key = true;
157 }
158 builder.push_val(v);
159
160 cursor.step_key();
161 }
162 if has_prev_key {
163 builder.push_key_mut(&mut prev_key);
164 }
165
166 builder.done()
167 }
168
169 fn input_preference(&self) -> OwnershipPreference {
187 OwnershipPreference::PREFER_OWNED
188 }
189}
190
191pub struct IndexWith<CI, CO: BatchReader, F> {
206 factories: CO::Factories,
207 index_func: F,
208 _type: PhantomData<(CI, CO)>,
209}
210
211impl<CI, CO: BatchReader, F> IndexWith<CI, CO, F> {
212 pub fn new(index_func: F, factories: &CO::Factories) -> Self {
213 Self {
214 factories: factories.clone(),
215 index_func,
216 _type: PhantomData,
217 }
218 }
219}
220
221impl<CI, CO, F> Operator for IndexWith<CI, CO, F>
222where
223 CI: 'static,
224 CO: BatchReader + 'static,
225 F: 'static,
226{
227 fn name(&self) -> Cow<'static, str> {
228 Cow::from("IndexWith")
229 }
230
231 fn fixedpoint(&self, _scope: Scope) -> bool {
232 true
233 }
234}
235
236impl<CI, CO, F> UnaryOperator<CI, CO> for IndexWith<CI, CO, F>
237where
238 CO: Batch<Time = ()>,
239 CI: BatchReader<Val = DynUnit, Time = (), R = CO::R>,
240 F: Fn(&CI::Key, &mut DynPair<CO::Key, CO::Val>) + 'static,
241{
242 async fn eval(&mut self, i: &CI) -> CO {
243 let mut tuples = self.factories.weighted_items_factory().default_box();
244 tuples.reserve(i.len());
245
246 let mut item = self.factories.weighted_item_factory().default_box();
247
248 let mut cursor = i.cursor();
249 while cursor.key_valid() {
250 let (kv, weight) = item.split_mut();
251 (self.index_func)(cursor.key(), kv);
252 cursor.weight().clone_to(weight);
253 tuples.push_val(item.deref_mut());
254 cursor.step_key();
255 }
256
257 CO::dyn_from_tuples(&self.factories, (), &mut tuples)
258 }
259
260 async fn eval_owned(&mut self, i: CI) -> CO {
261 self.eval(&i).await
263 }
264}
265
#[cfg(test)]
mod test {
    use crate::{
        Circuit, RootCircuit, ZWeight,
        dynamic::{ClonableTrait, DynData, DynPair, Erase, LeanVec},
        indexed_zset,
        operator::Generator,
        trace::{BatchReaderFactories, Batcher},
        typed_batch::{DynBatch, DynOrdZSet, OrdIndexedZSet},
        utils::Tup2,
    };

    /// Feeds two batches of `(i32, char)` tuples through `index` and checks
    /// the indexed output of each step.
    #[test]
    fn index_test() {
        let circuit = RootCircuit::build(move |circuit| {
            // Weights are pinned to `i64` (`ZWeight`) by the first literal.
            let mut inputs = vec![
                vec![
                    (Tup2(1, 'a'), 1i64),
                    (Tup2(1, 'b'), 1),
                    (Tup2(2, 'a'), 1),
                    (Tup2(2, 'c'), 1),
                    (Tup2(1, 'a'), 2),
                    (Tup2(1, 'b'), -1),
                ],
                vec![
                    (Tup2(1, 'd'), 1),
                    (Tup2(1, 'e'), 1),
                    (Tup2(2, 'a'), -1),
                    (Tup2(3, 'a'), 2),
                ],
            ]
            .into_iter()
            .map(|tuples| {
                // Wrap each tuple as ((key, ()), weight) for the batcher.
                let tuples = tuples
                    .into_iter()
                    .map(|(k, v)| Tup2(Tup2(k, ()), v))
                    .collect::<Vec<_>>();
                let mut batcher =
                    <DynOrdZSet<DynPair<DynData, DynData>> as DynBatch>::Batcher::new_batcher(
                        &BatchReaderFactories::new::<Tup2<i32, char>, (), ZWeight>(),
                        (),
                    );
                batcher.push_batch(&mut Box::new(LeanVec::from(tuples)).erase_box());
                batcher.seal()
            });
            // In the first step ('a', 1) + ('a', 2) sum to 3 and 'b' cancels.
            let mut outputs = vec![
                indexed_zset! { 1 => {'a' => 3}, 2 => {'a' => 1, 'c' => 1}},
                indexed_zset! { 1 => {'e' => 1, 'd' => 1}, 2 => {'a' => -1}, 3 => {'a' => 2}},
            ]
            .into_iter();
            circuit
                .add_source(Generator::new(move || inputs.next().unwrap()))
                .index(&BatchReaderFactories::new::<i32, char, ZWeight>())
                .typed()
                .inspect(move |fm: &OrdIndexedZSet<_, _>| assert_eq!(fm, &outputs.next().unwrap()));
            Ok(())
        })
        .unwrap()
        .0;

        // One transaction per prepared input batch.
        for _ in 0..2 {
            circuit.transaction().unwrap();
        }
    }

    /// Same scenario as `index_test`, but going through `index_with` with a
    /// function that copies the input pair unchanged.
    #[test]
    fn index_with_test() {
        let circuit = RootCircuit::build(move |circuit| {
            // The first weight carries an explicit `i64` suffix, as in
            // `index_test`, so the weight type matches the `ZWeight` factories
            // instead of being left to integer-literal inference.
            let mut inputs = vec![
                vec![
                    (Tup2(1, 'a'), 1i64),
                    (Tup2(1, 'b'), 1),
                    (Tup2(2, 'a'), 1),
                    (Tup2(2, 'c'), 1),
                    (Tup2(1, 'a'), 2),
                    (Tup2(1, 'b'), -1),
                ],
                vec![
                    (Tup2(1, 'd'), 1),
                    (Tup2(1, 'e'), 1),
                    (Tup2(2, 'a'), -1),
                    (Tup2(3, 'a'), 2),
                ],
            ]
            .into_iter()
            .map(|tuples| {
                // Wrap each tuple as ((key, ()), weight) for the batcher.
                let tuples = tuples
                    .into_iter()
                    .map(|(k, v)| Tup2(Tup2(k, ()), v))
                    .collect::<Vec<_>>();
                let mut batcher =
                    <DynOrdZSet<DynPair<DynData, DynData>> as DynBatch>::Batcher::new_batcher(
                        &BatchReaderFactories::new::<Tup2<i32, char>, (), ZWeight>(),
                        (),
                    );
                batcher.push_batch(&mut Box::new(LeanVec::from(tuples)).erase_box());
                batcher.seal()
            });

            let mut outputs = vec![
                indexed_zset! { 1 => {'a' => 3}, 2 => {'a' => 1, 'c' => 1}},
                indexed_zset! { 1 => {'e' => 1, 'd' => 1}, 2 => {'a' => -1}, 3 => {'a' => 2}},
            ]
            .into_iter();

            circuit
                .add_source(Generator::new(move || inputs.next().unwrap()))
                .index_with(
                    &BatchReaderFactories::new::<i32, char, ZWeight>(),
                    // Identity indexing: the input key already is the pair.
                    |kv: &DynPair<DynData, DynData>, result| kv.clone_to(result),
                )
                .typed()
                .inspect(move |fm: &OrdIndexedZSet<_, _>| assert_eq!(fm, &outputs.next().unwrap()));
            Ok(())
        })
        .unwrap()
        .0;

        // One transaction per prepared input batch.
        for _ in 0..2 {
            circuit.transaction().unwrap();
        }
    }
}