hydroflow_lang/graph/ops/
partition.rs

1use std::collections::BTreeSet;
2
3use proc_macro2::Span;
4use quote::quote_spanned;
5use syn::spanned::Spanned;
6use syn::token::Colon;
7use syn::{parse_quote_spanned, Expr, Ident, LitInt, LitStr, Pat, PatType};
8
9use super::{
10    OperatorCategory, OperatorConstraints, OperatorInstance, OperatorWriteOutput, PortIndexValue,
11    PortListSpec, WriteContextArgs, RANGE_0, RANGE_1,
12};
13use crate::diagnostic::{Diagnostic, Level};
14use crate::pretty_span::PrettySpan;
15
16/// This operator takes the input pipeline and allows the user to determine which singular output
17/// pipeline each item should be delivered to.
18///
19/// > Arguments: A Rust closure, the first argument is a reference to the item and the second
20/// > argument corresponds to one of two modes, either named or indexed.
21///
22/// > Note: The closure has access to the [`context` object](surface_flows.mdx#the-context-object).
23///
24/// # Named mode
25/// With named ports, the closure's second argument must be a Rust 'slice pattern' of names, such as
26/// `[port_a, port_b, port_c]`, where each name is an output port. The closure should return the
27/// name of the desired output port.
28///
29/// ```hydroflow
30/// my_partition = source_iter(1..=100) -> partition(|val: &usize, [fzbz, fizz, buzz, rest]|
31///     match (val % 3, val % 5) {
32///         (0, 0) => fzbz,
33///         (0, _) => fizz,
34///         (_, 0) => buzz,
35///         (_, _) => rest,
36///     }
37/// );
38/// my_partition[fzbz] -> for_each(|v| println!("{}: fizzbuzz", v));
39/// my_partition[fizz] -> for_each(|v| println!("{}: fizz", v));
40/// my_partition[buzz] -> for_each(|v| println!("{}: buzz", v));
41/// my_partition[rest] -> for_each(|v| println!("{}", v));
42/// ```
43///
44/// # Indexed mode
45/// With indexed mode, the closure's second argument is a the number of output ports. This is a
46/// single usize value, useful for e.g. round robin partitioning. Each output pipeline port must be
47/// numbered with an index, starting from zero and with no gaps. The closure returns the index of
48/// the desired output port.
49///
50/// ```hydroflow
51/// my_partition = source_iter(1..=100) -> partition(|val, num_outputs| val % num_outputs);
52/// my_partition[0] -> for_each(|v| println!("0: {}", v));
53/// my_partition[1] -> for_each(|v| println!("1: {}", v));
54/// my_partition[2] -> for_each(|v| println!("2: {}", v));
55/// ```
56pub const PARTITION: OperatorConstraints = OperatorConstraints {
57    name: "partition",
58    categories: &[OperatorCategory::MultiOut],
59    hard_range_inn: RANGE_1,
60    soft_range_inn: RANGE_1,
61    hard_range_out: &(2..),
62    soft_range_out: &(2..),
63    num_args: 1,
64    persistence_args: RANGE_0,
65    type_args: RANGE_0,
66    is_external_input: false,
67    has_singleton_output: false,
68    ports_inn: None,
69    ports_out: Some(|| PortListSpec::Variadic),
70    input_delaytype_fn: |_| None,
71    write_fn: |wc @ &WriteContextArgs {
72                   root,
73                   op_span,
74                   ident,
75                   outputs,
76                   is_pull,
77                   op_name,
78                   op_inst: OperatorInstance { output_ports, .. },
79                   arguments,
80                   ..
81               },
82               diagnostics| {
83        assert!(!is_pull);
84
85        // Clone because we may modify the closure's arg2 to inject the type.
86        let mut func = arguments[0].clone();
87
88        let idx_ints = (0..output_ports.len())
89            .map(|i| LitInt::new(&format!("{}_usize", i), op_span))
90            .collect::<Vec<_>>();
91
92        let mut output_sort_permutation: Vec<_> = (0..outputs.len()).collect();
93        let (output_idents, arg2_val) = if let Some(port_idents) =
94            determine_indices_or_idents(output_ports, op_span, op_name, diagnostics)?
95        {
96            // All idents.
97            let (closure_idents, arg2_span) =
98                extract_closure_idents(&mut func, op_name).map_err(|err| diagnostics.push(err))?;
99            check_closure_ports_match(
100                &closure_idents,
101                &port_idents,
102                op_name,
103                arg2_span,
104                diagnostics,
105            )?;
106            output_sort_permutation.sort_by_key(|&i| {
107                closure_idents
108                    .iter()
109                    .position(|ident| ident == &port_idents[i])
110                    .expect(
111                        "Missing port, this should've been caught in the check above, this is a Hydroflow bug.",
112                    )
113            });
114            let arg2_val = quote_spanned! {arg2_span.span()=> [ #( #idx_ints ),* ] };
115
116            (closure_idents, arg2_val)
117        } else {
118            // All indices.
119            let numeric_idents = (0..output_ports.len())
120                .map(|i| wc.make_ident(format!("{}_push", i)))
121                .collect();
122            let len_lit = LitInt::new(&format!("{}_usize", output_ports.len()), op_span);
123            let arg2_val = quote_spanned! {op_span=> #len_lit };
124            (numeric_idents, arg2_val)
125        };
126
127        let err_str = LitStr::new(
128            &format!(
129                "Index `{{}}` returned by `{}(..)` closure is out-of-bounds.",
130                op_name
131            ),
132            op_span,
133        );
134        let ident_item = wc.make_ident("item");
135        let ident_index = wc.make_ident("index");
136        let ident_unknown = wc.make_ident("match_unknown");
137
138        let sorted_outputs = output_sort_permutation.into_iter().map(|i| &outputs[i]);
139
140        let write_iterator = quote_spanned! {op_span=>
141            let #ident = {
142                #root::pusherator::demux::Demux::new(
143                    |#ident_item, #root::var_args!( #( #output_idents ),* )| {
144                        #[allow(unused_imports)]
145                        use #root::pusherator::Pusherator;
146
147                        let #ident_index = {
148                            #[allow(clippy::redundant_closure_call)]
149                            (#func)(&#ident_item, #arg2_val)
150                        };
151                        match #ident_index {
152                            #(
153                                #idx_ints => #output_idents.give(#ident_item),
154                            )*
155                            #ident_unknown => panic!(#err_str, #ident_unknown),
156                        };
157                    },
158                    #root::var_expr!( #( #sorted_outputs ),* ),
159                )
160            };
161        };
162
163        Ok(OperatorWriteOutput {
164            write_iterator,
165            ..Default::default()
166        })
167    },
168};
169
170/// Returns `Ok(Some(idents))` if ports are idents, or `Ok(None)` if ports are indices.
171/// Returns `Err(())` if there are any errors (pushed to `diagnostics`).
172fn determine_indices_or_idents(
173    output_ports: &[PortIndexValue],
174    op_span: Span,
175    op_name: &'static str,
176    diagnostics: &mut Vec<Diagnostic>,
177) -> Result<Option<Vec<Ident>>, ()> {
178    // Port idents supplied via port connections in the surface syntax.
179    // Two modes, either all numeric `0, 1, 2, 3, ...` or all `Ident`s.
180    // If ports are `Idents` then the closure's 2nd argument, input array must have named
181    // values corresponding to the port idents.
182    let mut ports_numeric = BTreeSet::new();
183    let mut ports_idents = Vec::new();
184    // If any ports are elided we return `Err(())` early.
185    let mut err_elided = false;
186    for output_port in output_ports {
187        match output_port {
188            PortIndexValue::Elided(port_span) => {
189                err_elided = true;
190                diagnostics.push(Diagnostic::spanned(
191                    port_span.unwrap_or(op_span),
192                    Level::Error,
193                    format!(
194                        "Output ports from `{}` cannot be blank, must be named or indexed.",
195                        op_name
196                    ),
197                ));
198            }
199            PortIndexValue::Int(port_idx) => {
200                ports_numeric.insert(port_idx);
201
202                if port_idx.value < 0 {
203                    diagnostics.push(Diagnostic::spanned(
204                        port_idx.span,
205                        Level::Error,
206                        format!("Output ports from `{}` must be non-nonegative indices starting from zero.", op_name),
207                    ));
208                }
209            }
210            PortIndexValue::Path(port_path) => {
211                let port_ident = syn::parse2::<Ident>(quote_spanned!(op_span=> #port_path))
212                    .map_err(|err| diagnostics.push(err.into()))?;
213                ports_idents.push(port_ident);
214            }
215        }
216    }
217    if err_elided {
218        return Err(());
219    }
220
221    match (!ports_numeric.is_empty(), !ports_idents.is_empty()) {
222        (false, false) => {
223            // Had no ports or only elided ports.
224            assert!(diagnostics.iter().any(Diagnostic::is_error), "Empty input ports, expected an error diagnostic but none were emitted, this is a Hydroflow bug.");
225            Err(())
226        }
227        (true, true) => {
228            // Conflict.
229            let msg = &*format!(
230                "Output ports from `{}` must either be all integer indices or all identifiers.",
231                op_name
232            );
233            diagnostics.extend(
234                output_ports
235                    .iter()
236                    .map(|output_port| Diagnostic::spanned(output_port.span(), Level::Error, msg)),
237            );
238            Err(())
239        }
240        (true, false) => {
241            let max_port_idx = ports_numeric.last().unwrap().value;
242            if usize::try_from(max_port_idx).unwrap() >= ports_numeric.len() {
243                let mut expected = 0;
244                for port_numeric in ports_numeric {
245                    if expected != port_numeric.value {
246                        diagnostics.push(Diagnostic::spanned(
247                            port_numeric.span,
248                            Level::Error,
249                            format!(
250                                "Output port indices from `{}` must be consecutive from zero, missing {}.",
251                                op_name, expected
252                            ),
253                        ));
254                    }
255                    expected = port_numeric.value + 1;
256                }
257                // Can continue with code gen, port numbers will be treated as if they're
258                // consecutive from their ascending order.
259            }
260            Ok(None)
261        }
262        (false, true) => Ok(Some(ports_idents)),
263    }
264}
265
266// Returns a vec of closure idents and the arg2 span.
267fn extract_closure_idents(
268    func: &mut Expr,
269    op_name: &'static str,
270) -> Result<(Vec<Ident>, Span), Diagnostic> {
271    let Expr::Closure(func) = func else {
272        return Err(Diagnostic::spanned(
273            func.span(),
274            Level::Error,
275            "Argument must be a two-argument closure expression",
276        ));
277    };
278    if 2 != func.inputs.len() {
279        return Err(Diagnostic::spanned(
280            func.inputs.span(),
281            Level::Error,
282            &*format!(
283                "Closure provided to `{}(..)` must have two arguments: \
284                the first argument is the item, and for named ports the second argument must contain a Rust 'slice pattern' to determine the port names and order. \
285                For example, the second argument could be `[foo, bar, baz]` for ports `foo`, `bar`, and `baz`.",
286                op_name
287            ),
288        ));
289    }
290
291    // Port idents specified in the closure's second argument.
292    let mut arg2 = &mut func.inputs[1];
293    let mut already_has_type = false;
294    if let Pat::Type(pat_type) = arg2 {
295        arg2 = &mut *pat_type.pat;
296        already_has_type = true;
297    }
298
299    let arg2_span = arg2.span();
300    if let Pat::Ident(pat_ident) = arg2 {
301        arg2 = &mut *pat_ident
302            .subpat
303            .as_mut()
304            .ok_or_else(|| Diagnostic::spanned(
305                arg2_span,
306                Level::Error,
307                format!(
308                    "Second argument for the `{}` closure must contain a Rust 'slice pattern' to determine the port names and order. \
309                    For example: `arr @ [foo, bar, baz]` for ports `foo`, `bar`, and `baz`.",
310                    op_name
311                )
312            ))?
313            .1;
314    }
315    let Pat::Slice(pat_slice) = arg2 else {
316        return Err(Diagnostic::spanned(
317            arg2_span,
318            Level::Error,
319            format!(
320                "Second argument for the `{}` closure must have a Rust 'slice pattern' to determine the port names and order. \
321                For example: `[foo, bar, baz]` for ports `foo`, `bar`, and `baz`.",
322                op_name
323            )
324        ));
325    };
326
327    let idents = pat_slice
328        .elems
329        .iter()
330        .map(|pat| {
331            let Pat::Ident(pat_ident) = pat else {
332                panic!("TODO(mingwei) expected ident pat");
333            };
334            pat_ident.ident.clone()
335        })
336        .collect();
337
338    // Last step: set the type `[a, b, c]: [usize; 3]` if it is not already specified.
339    if !already_has_type {
340        let len = LitInt::new(&pat_slice.elems.len().to_string(), arg2_span);
341        *arg2 = Pat::Type(PatType {
342            attrs: vec![],
343            pat: Box::new(arg2.clone()),
344            colon_token: Colon { spans: [arg2_span] },
345            ty: parse_quote_spanned! {arg2_span=> [usize; #len] },
346        });
347    }
348
349    Ok((idents, arg2_span))
350}
351
352// Checks that the closure names and output port names match.
353fn check_closure_ports_match(
354    closure_idents: &[Ident],
355    port_idents: &[Ident],
356    op_name: &'static str,
357    arg2_span: Span,
358    diagnostics: &mut Vec<Diagnostic>,
359) -> Result<(), ()> {
360    let mut err = false;
361    for port_ident in port_idents {
362        if !closure_idents.contains(port_ident) {
363            // An output port is missing from the closure args.
364            err = true;
365            diagnostics.push(Diagnostic::spanned(
366                arg2_span,
367                Level::Error,
368                format!(
369                    "Argument specifying the output ports in `{0}(..)` does not contain extra port `{1}`: ({2}) (1/2).",
370                    op_name, port_ident, PrettySpan(port_ident.span()),
371                ),
372            ));
373            diagnostics.push(Diagnostic::spanned(
374                port_ident.span(),
375                Level::Error,
376                format!(
377                    "Port `{1}` not found in the arguments specified in `{0}(..)`'s closure: ({2}) (2/2).",
378                    op_name, port_ident, PrettySpan(arg2_span),
379                ),
380            ));
381        }
382    }
383    for closure_ident in closure_idents {
384        if !port_idents.contains(closure_ident) {
385            // A closure arg is missing from the output ports.
386            err = true;
387            diagnostics.push(Diagnostic::spanned(
388                closure_ident.span(),
389                Level::Error,
390                format!(
391                    "`{}(..)` closure argument `{}` missing corresponding output port.",
392                    op_name, closure_ident,
393                ),
394            ));
395        }
396    }
397    (!err).then_some(()).ok_or(())
398}