1use crate::{
4 circuit::{
5 circuit_builder::StreamId,
6 operator_traits::{Operator, UnaryOperator},
7 Circuit, OwnershipPreference, Scope, Stream,
8 },
9 circuit_cache_key,
10 dynamic::{ClonableTrait, DataTrait, DynPair, DynUnit},
11 trace::{
12 Batch, BatchFactories, BatchReader, BatchReaderFactories, Builder, Cursor, OrdIndexedWSet,
13 },
14};
15use minitrace::trace;
16use std::{borrow::Cow, marker::PhantomData, ops::DerefMut};
17
// Circuit-cache key type `IndexId<C, D>`, keyed by a `StreamId` and mapping to a
// `Stream<C, D>`. `Stream::index_generic` below uses it (via
// `IndexId::new(self.stream_id())`) to memoize the indexed output stream of a
// given input stream.
circuit_cache_key!(IndexId<C, D>(StreamId => Stream<C, D>));
19
20impl<C, CI> Stream<C, CI>
21where
22 CI: Clone + 'static,
23 C: Circuit,
24{
25 pub fn index<K, V>(
32 &self,
33 output_factories: &<OrdIndexedWSet<K, V, CI::R> as BatchReader>::Factories,
34 ) -> Stream<C, OrdIndexedWSet<K, V, CI::R>>
35 where
36 K: DataTrait + ?Sized,
37 V: DataTrait + ?Sized,
38 CI: BatchReader<Key = DynPair<K, V>, Val = DynUnit, Time = ()>,
39 {
40 self.index_generic(output_factories)
41 }
42
43 pub fn index_generic<CO>(&self, output_factories: &CO::Factories) -> Stream<C, CO>
46 where
47 CI: BatchReader<Key = DynPair<CO::Key, CO::Val>, Val = DynUnit, Time = (), R = CO::R>,
48 CO: Batch<Time = ()>,
49 {
50 self.circuit()
51 .cache_get_or_insert_with(IndexId::new(self.stream_id()), || {
52 self.circuit()
53 .add_unary_operator(Index::new(output_factories), self)
54 })
55 .clone()
56 }
57
58 pub fn index_with<K, V, F>(
68 &self,
69 output_factories: &<OrdIndexedWSet<K, V, CI::R> as BatchReader>::Factories,
70 index_func: F,
71 ) -> Stream<C, OrdIndexedWSet<K, V, CI::R>>
72 where
73 CI: BatchReader<Time = (), Val = DynUnit>,
74 F: Fn(&CI::Key, &mut DynPair<K, V>) + Clone + 'static,
75 K: DataTrait + ?Sized,
76 V: DataTrait + ?Sized,
77 {
78 self.index_with_generic(index_func, output_factories)
79 }
80
81 pub fn index_with_generic<CO, F>(
84 &self,
85 index_func: F,
86 output_factories: &CO::Factories,
87 ) -> Stream<C, CO>
88 where
89 CI: BatchReader<Time = (), Val = DynUnit>,
90 CO: Batch<Time = (), R = CI::R>,
91 F: Fn(&CI::Key, &mut DynPair<CO::Key, CO::Val>) + Clone + 'static,
92 {
93 self.circuit()
94 .add_unary_operator(IndexWith::new(index_func, output_factories), self)
95 }
96}
97
98pub struct Index<CI, CO: BatchReader> {
109 factories: CO::Factories,
110 _type: PhantomData<(CI, CO)>,
111}
112
113impl<CI, CO: BatchReader> Index<CI, CO> {
114 pub fn new(factories: &CO::Factories) -> Self {
115 Self {
116 factories: factories.clone(),
117 _type: PhantomData,
118 }
119 }
120}
121
122impl<CI, CO> Operator for Index<CI, CO>
123where
124 CI: 'static,
125 CO: BatchReader + 'static,
126{
127 fn name(&self) -> Cow<'static, str> {
128 Cow::from("Index")
129 }
130 fn fixedpoint(&self, _scope: Scope) -> bool {
131 true
132 }
133}
134
135impl<CI, CO> UnaryOperator<CI, CO> for Index<CI, CO>
136where
137 CO: Batch<Time = ()>,
138 CI: BatchReader<Key = DynPair<CO::Key, CO::Val>, Val = DynUnit, Time = (), R = CO::R>,
139{
140 #[trace]
141 async fn eval(&mut self, input: &CI) -> CO {
142 let mut builder = <CO as Batch>::Builder::with_capacity(&self.factories, input.len());
143
144 let mut cursor = input.cursor();
145 let mut prev_key = self.factories.key_factory().default_box();
146 let mut has_prev_key = false;
147 while cursor.key_valid() {
148 builder.push_diff(cursor.weight());
149 let (k, v) = cursor.key().split();
150 if has_prev_key {
151 if k != &*prev_key {
152 builder.push_key_mut(&mut prev_key);
153 k.clone_to(&mut prev_key);
154 }
155 } else {
156 k.clone_to(&mut prev_key);
157 has_prev_key = true;
158 }
159 builder.push_val(v);
160
161 cursor.step_key();
162 }
163 if has_prev_key {
164 builder.push_key_mut(&mut prev_key);
165 }
166
167 builder.done()
168 }
169
170 fn input_preference(&self) -> OwnershipPreference {
188 OwnershipPreference::PREFER_OWNED
189 }
190}
191
192pub struct IndexWith<CI, CO: BatchReader, F> {
207 factories: CO::Factories,
208 index_func: F,
209 _type: PhantomData<(CI, CO)>,
210}
211
212impl<CI, CO: BatchReader, F> IndexWith<CI, CO, F> {
213 pub fn new(index_func: F, factories: &CO::Factories) -> Self {
214 Self {
215 factories: factories.clone(),
216 index_func,
217 _type: PhantomData,
218 }
219 }
220}
221
222impl<CI, CO, F> Operator for IndexWith<CI, CO, F>
223where
224 CI: 'static,
225 CO: BatchReader + 'static,
226 F: 'static,
227{
228 fn name(&self) -> Cow<'static, str> {
229 Cow::from("IndexWith")
230 }
231
232 fn fixedpoint(&self, _scope: Scope) -> bool {
233 true
234 }
235}
236
237impl<CI, CO, F> UnaryOperator<CI, CO> for IndexWith<CI, CO, F>
238where
239 CO: Batch<Time = ()>,
240 CI: BatchReader<Val = DynUnit, Time = (), R = CO::R>,
241 F: Fn(&CI::Key, &mut DynPair<CO::Key, CO::Val>) + 'static,
242{
243 #[trace]
244 async fn eval(&mut self, i: &CI) -> CO {
245 let mut tuples = self.factories.weighted_items_factory().default_box();
246 tuples.reserve(i.len());
247
248 let mut item = self.factories.weighted_item_factory().default_box();
249
250 let mut cursor = i.cursor();
251 while cursor.key_valid() {
252 let (kv, weight) = item.split_mut();
253 (self.index_func)(cursor.key(), kv);
254 cursor.weight().clone_to(weight);
255 tuples.push_val(item.deref_mut());
256 cursor.step_key();
257 }
258
259 CO::dyn_from_tuples(&self.factories, (), &mut tuples)
260 }
261
262 async fn eval_owned(&mut self, i: CI) -> CO {
263 self.eval(&i).await
265 }
266}
267
#[cfg(test)]
mod test {
    use crate::{
        dynamic::{ClonableTrait, DynData, DynPair, Erase, LeanVec},
        indexed_zset,
        operator::Generator,
        trace::{BatchReaderFactories, Batcher},
        typed_batch::{DynBatch, DynOrdZSet, OrdIndexedZSet},
        utils::Tup2,
        Circuit, RootCircuit, ZWeight,
    };

    /// Flat `(key, value)`-keyed Z-set fed into both operators under test.
    type InputBatch = DynOrdZSet<DynPair<DynData, DynData>>;

    /// Number of circuit steps driven by each test (one per input batch).
    const STEPS: usize = 2;

    /// Builds the input batches, one per circuit step, from raw
    /// `((key, value), weight)` updates.
    ///
    /// The first step repeats `(1, 'a')` and `(1, 'b')`. The expected outputs
    /// reflect these being combined: `'a'` ends with weight 3 and `'b'`
    /// disappears.
    fn input_batches() -> impl Iterator<Item = InputBatch> {
        let steps: Vec<Vec<(Tup2<i32, char>, ZWeight)>> = vec![
            vec![
                (Tup2(1, 'a'), 1),
                (Tup2(1, 'b'), 1),
                (Tup2(2, 'a'), 1),
                (Tup2(2, 'c'), 1),
                (Tup2(1, 'a'), 2),
                (Tup2(1, 'b'), -1),
            ],
            vec![
                (Tup2(1, 'd'), 1),
                (Tup2(1, 'e'), 1),
                (Tup2(2, 'a'), -1),
                (Tup2(3, 'a'), 2),
            ],
        ];

        steps.into_iter().map(|tuples| {
            let tuples = tuples
                .into_iter()
                .map(|(k, w)| Tup2(Tup2(k, ()), w))
                .collect::<Vec<_>>();
            let mut batcher = <InputBatch as DynBatch>::Batcher::new_batcher(
                &BatchReaderFactories::new::<Tup2<i32, char>, (), ZWeight>(),
                (),
            );
            batcher.push_batch(&mut Box::new(LeanVec::from(tuples)).erase_box());
            batcher.seal()
        })
    }

    /// Expected indexed output for each step produced by `input_batches`.
    fn expected_outputs() -> impl Iterator<Item = OrdIndexedZSet<i32, char>> {
        vec![
            indexed_zset! { 1 => {'a' => 3}, 2 => {'a' => 1, 'c' => 1}},
            indexed_zset! { 1 => {'e' => 1, 'd' => 1}, 2 => {'a' => -1}, 3 => {'a' => 2}},
        ]
        .into_iter()
    }

    #[test]
    fn index_test() {
        let circuit = RootCircuit::build(move |circuit| {
            let mut inputs = input_batches();
            let mut outputs = expected_outputs();

            circuit
                .add_source(Generator::new(move || inputs.next().unwrap()))
                .index(&BatchReaderFactories::new::<i32, char, ZWeight>())
                .typed()
                .inspect(move |fm: &OrdIndexedZSet<_, _>| assert_eq!(fm, &outputs.next().unwrap()));
            Ok(())
        })
        .unwrap()
        .0;

        for _ in 0..STEPS {
            circuit.step().unwrap();
        }
    }

    #[test]
    fn index_with_test() {
        let circuit = RootCircuit::build(move |circuit| {
            let mut inputs = input_batches();
            let mut outputs = expected_outputs();

            circuit
                .add_source(Generator::new(move || inputs.next().unwrap()))
                .index_with(
                    &BatchReaderFactories::new::<i32, char, ZWeight>(),
                    // Identity indexing: the input key already is the `(key, value)` pair.
                    |kv: &DynPair<DynData, DynData>, result| kv.clone_to(result),
                )
                .typed()
                .inspect(move |fm: &OrdIndexedZSet<_, _>| assert_eq!(fm, &outputs.next().unwrap()));
            Ok(())
        })
        .unwrap()
        .0;

        for _ in 0..STEPS {
            circuit.step().unwrap();
        }
    }
}