Skip to main content

dbsp/operator/
apply3.rs

//! Ternary operator that applies an arbitrary user-supplied function to its three inputs at each timestamp.
2
3use crate::circuit::{
4    Circuit, OwnershipPreference, Scope, Stream,
5    metadata::OperatorLocation,
6    operator_traits::{Operator, TernaryOperator},
7};
8use std::{borrow::Cow, panic::Location};
9
10impl<C, T1> Stream<C, T1>
11where
12    C: Circuit,
13    T1: Clone + 'static,
14{
15    /// Apply a user-provided ternary function to inputs from three streams at
16    /// each timestamp.
17    #[track_caller]
18    pub fn apply3<F, T2, T3, T4>(
19        &self,
20        other1: &Stream<C, T2>,
21        other2: &Stream<C, T3>,
22        func: F,
23    ) -> Stream<C, T4>
24    where
25        T2: Clone + 'static,
26        T3: Clone + 'static,
27        T4: Clone + 'static,
28        F: Fn(Cow<'_, T1>, Cow<'_, T2>, Cow<'_, T3>) -> T4 + 'static,
29    {
30        self.apply3_with_preference(
31            OwnershipPreference::INDIFFERENT,
32            (other1, OwnershipPreference::INDIFFERENT),
33            (other2, OwnershipPreference::INDIFFERENT),
34            func,
35        )
36    }
37
38    /// Apply a user-provided ternary function to inputs at each
39    /// timestamp.
40    ///
41    /// Allows the caller to specify the ownership preference for
42    /// each input stream.
43    #[track_caller]
44    pub fn apply3_with_preference<F, T2, T3, T4>(
45        &self,
46        self_preference: OwnershipPreference,
47        other1: (&Stream<C, T2>, OwnershipPreference),
48        other2: (&Stream<C, T3>, OwnershipPreference),
49        func: F,
50    ) -> Stream<C, T4>
51    where
52        T2: Clone + 'static,
53        T3: Clone + 'static,
54        T4: Clone + 'static,
55        F: Fn(Cow<'_, T1>, Cow<'_, T2>, Cow<'_, T3>) -> T4 + 'static,
56    {
57        self.circuit().add_ternary_operator_with_preference(
58            Apply3::new(func, Location::caller()),
59            (self, self_preference),
60            other1,
61            other2,
62        )
63    }
64}
65
/// Operator that invokes a user-supplied ternary function on its three inputs
/// at each timestamp.
///
/// Constructed, for example, by [`Stream::apply3_with_preference`].
pub struct Apply3<F> {
    /// Function called with the three inputs on every evaluation.
    func: F,
    /// Source location reported through [`Operator::location`].
    location: &'static Location<'static>,
}

impl<F> Apply3<F> {
    /// Creates an operator that wraps `func` and is tagged with `location`.
    pub const fn new(func: F, location: &'static Location<'static>) -> Self
    where
        F: 'static,
    {
        Apply3 { location, func }
    }
}
80
81impl<F> Operator for Apply3<F>
82where
83    F: 'static,
84{
85    fn name(&self) -> Cow<'static, str> {
86        Cow::Borrowed("Apply3")
87    }
88
89    fn location(&self) -> OperatorLocation {
90        Some(self.location)
91    }
92
93    fn fixedpoint(&self, _scope: Scope) -> bool {
94        // TODO: either change `F` type to `Fn` from `FnMut` or
95        // parameterize the operator with custom fixed point check.
96        unimplemented!();
97    }
98}
99
100impl<T1, T2, T3, T4, F> TernaryOperator<T1, T2, T3, T4> for Apply3<F>
101where
102    F: Fn(Cow<'_, T1>, Cow<'_, T2>, Cow<'_, T3>) -> T4 + 'static,
103    T1: Clone + 'static,
104    T2: Clone + 'static,
105    T3: Clone + 'static,
106{
107    async fn eval(&mut self, i1: Cow<'_, T1>, i2: Cow<'_, T2>, i3: Cow<'_, T3>) -> T4 {
108        (self.func)(i1, i2, i3)
109    }
110}
111
#[cfg(test)]
mod test {
    use crate::{Circuit, RootCircuit, operator::Generator};

    #[test]
    fn apply3_test() {
        // At step k (k = 1, 2, 3) the sources emit 2k, -k and -k, so their sum
        // is always zero.
        let handle: crate::CircuitHandle = RootCircuit::build(move |root| {
            let mut doubled = (1..=3).map(|k| 2 * k);
            let mut negated_a = (1..=3).map(|k| -k);
            let mut negated_b = (1..=3).map(|k| -k);

            let x = root.add_source(Generator::new(move || doubled.next().unwrap()));
            let y = root.add_source(Generator::new(move || negated_a.next().unwrap()));
            let z = root.add_source(Generator::new(move || negated_b.next().unwrap()));

            x.apply3(&y, &z, |a, b, c| *a + *b + *c)
                .inspect(|sum| assert_eq!(*sum, 0));
            Ok(())
        })
        .unwrap()
        .0;

        (0..3).for_each(|_step| {
            handle.transaction().unwrap();
        });
    }
}