1use std::collections::BTreeSet;
2
3use proc_macro2::Span;
4use quote::quote_spanned;
5use syn::spanned::Spanned;
6use syn::token::Colon;
7use syn::{parse_quote_spanned, Expr, Ident, LitInt, LitStr, Pat, PatType};
8
9use super::{
10 OperatorCategory, OperatorConstraints, OperatorInstance, OperatorWriteOutput, PortIndexValue,
11 PortListSpec, WriteContextArgs, RANGE_0, RANGE_1,
12};
13use crate::diagnostic::{Diagnostic, Level};
14use crate::pretty_span::PrettySpan;
15
/// `partition` operator: takes one input and sends each item to exactly one of its two or more
/// outputs, as chosen by a user-supplied closure.
///
/// The single argument must be a closure. The generated code calls it as `(func)(&item, arg2)`
/// and matches the result against the output indices `0_usize..N` (`N` = number of output ports):
///
/// * Integer output ports (`[0]`, `[1]`, ...): `arg2` is `N` as a `usize` literal, and a returned
///   index `i` routes the item to the `i`-th entry of `outputs` (assumes `outputs` is ordered to
///   match the port indices — TODO confirm).
/// * Named output ports (`[foo]`, `[bar]`, ...): the closure's second parameter must contain a
///   slice pattern naming every port, e.g. `|item, [foo, bar]| ...` or
///   `|item, arr @ [foo, bar]| ...`. `arg2` is the array `[0_usize, 1_usize, ...]`, so each name
///   binds to its position in that slice pattern, and returning e.g. `foo` routes to port `foo`.
///
/// Any other returned index panics at runtime with an out-of-bounds message.
pub const PARTITION: OperatorConstraints = OperatorConstraints {
    name: "partition",
    categories: &[OperatorCategory::MultiOut],
    hard_range_inn: RANGE_1,
    soft_range_inn: RANGE_1,
    hard_range_out: &(2..),
    soft_range_out: &(2..),
    num_args: 1,
    persistence_args: RANGE_0,
    type_args: RANGE_0,
    is_external_input: false,
    has_singleton_output: false,
    ports_inn: None,
    ports_out: Some(|| PortListSpec::Variadic),
    input_delaytype_fn: |_| None,
    write_fn: |wc @ &WriteContextArgs {
                   root,
                   op_span,
                   ident,
                   outputs,
                   is_pull,
                   op_name,
                   op_inst: OperatorInstance { output_ports, .. },
                   arguments,
                   ..
               },
               diagnostics| {
        // This operator is only ever generated in push mode (it builds a `Demux` pusherator).
        assert!(!is_pull);

        // Cloned because `extract_closure_idents` may rewrite the closure's second parameter
        // (adding a `[usize; N]` type annotation).
        let mut func = arguments[0].clone();

        // Literals `0_usize`, `1_usize`, ... used both as match arms and (for named ports) as
        // the array passed to the closure.
        let idx_ints = (0..output_ports.len())
            .map(|i| LitInt::new(&format!("{}_usize", i), op_span))
            .collect::<Vec<_>>();

        // Order in which `outputs` are passed to `Demux`; starts as the identity and is reordered
        // below for named ports.
        let mut output_sort_permutation: Vec<_> = (0..outputs.len()).collect();
        let (output_idents, arg2_val) = if let Some(port_idents) =
            determine_indices_or_idents(output_ports, op_span, op_name, diagnostics)?
        {
            // Named ports: the closure's slice pattern fixes the name -> index mapping.
            let (closure_idents, arg2_span) =
                extract_closure_idents(&mut func, op_name).map_err(|err| diagnostics.push(err))?;
            check_closure_ports_match(
                &closure_idents,
                &port_idents,
                op_name,
                arg2_span,
                diagnostics,
            )?;
            // Reorder outputs so that position `k` holds the output whose port name is the
            // `k`-th name in the closure's slice pattern; `var_args!` below binds the closure
            // idents to the outputs positionally. Assumes `outputs[i]` corresponds to
            // `port_idents[i]` (i.e. to `output_ports[i]`) — TODO confirm.
            output_sort_permutation.sort_by_key(|&i| {
                closure_idents
                    .iter()
                    .position(|ident| ident == &port_idents[i])
                    .expect(
                        "Missing port, this should've been caught in the check above, this is a Hydroflow bug.",
                    )
            });
            let arg2_val = quote_spanned! {arg2_span.span()=> [ #( #idx_ints ),* ] };

            (closure_idents, arg2_val)
        } else {
            // Integer ports: generate `0_push`, `1_push`, ... idents and pass the port count.
            let numeric_idents = (0..output_ports.len())
                .map(|i| wc.make_ident(format!("{}_push", i)))
                .collect();
            let len_lit = LitInt::new(&format!("{}_usize", output_ports.len()), op_span);
            let arg2_val = quote_spanned! {op_span=> #len_lit };
            (numeric_idents, arg2_val)
        };

        // Runtime panic message for an index with no matching output; `{}` is left for the
        // generated `panic!` to fill with the offending index.
        let err_str = LitStr::new(
            &format!(
                "Index `{{}}` returned by `{}(..)` closure is out-of-bounds.",
                op_name
            ),
            op_span,
        );
        let ident_item = wc.make_ident("item");
        let ident_index = wc.make_ident("index");
        let ident_unknown = wc.make_ident("match_unknown");

        let sorted_outputs = output_sort_permutation.into_iter().map(|i| &outputs[i]);

        // Generated code: evaluate the closure once per item, then `give` the item to the output
        // whose index matches.
        let write_iterator = quote_spanned! {op_span=>
            let #ident = {
                #root::pusherator::demux::Demux::new(
                    |#ident_item, #root::var_args!( #( #output_idents ),* )| {
                        #[allow(unused_imports)]
                        use #root::pusherator::Pusherator;

                        let #ident_index = {
                            #[allow(clippy::redundant_closure_call)]
                            (#func)(&#ident_item, #arg2_val)
                        };
                        match #ident_index {
                            #(
                                #idx_ints => #output_idents.give(#ident_item),
                            )*
                            #ident_unknown => panic!(#err_str, #ident_unknown),
                        };
                    },
                    #root::var_expr!( #( #sorted_outputs ),* ),
                )
            };
        };

        Ok(OperatorWriteOutput {
            write_iterator,
            ..Default::default()
        })
    },
};
169
170fn determine_indices_or_idents(
173 output_ports: &[PortIndexValue],
174 op_span: Span,
175 op_name: &'static str,
176 diagnostics: &mut Vec<Diagnostic>,
177) -> Result<Option<Vec<Ident>>, ()> {
178 let mut ports_numeric = BTreeSet::new();
183 let mut ports_idents = Vec::new();
184 let mut err_elided = false;
186 for output_port in output_ports {
187 match output_port {
188 PortIndexValue::Elided(port_span) => {
189 err_elided = true;
190 diagnostics.push(Diagnostic::spanned(
191 port_span.unwrap_or(op_span),
192 Level::Error,
193 format!(
194 "Output ports from `{}` cannot be blank, must be named or indexed.",
195 op_name
196 ),
197 ));
198 }
199 PortIndexValue::Int(port_idx) => {
200 ports_numeric.insert(port_idx);
201
202 if port_idx.value < 0 {
203 diagnostics.push(Diagnostic::spanned(
204 port_idx.span,
205 Level::Error,
206 format!("Output ports from `{}` must be non-nonegative indices starting from zero.", op_name),
207 ));
208 }
209 }
210 PortIndexValue::Path(port_path) => {
211 let port_ident = syn::parse2::<Ident>(quote_spanned!(op_span=> #port_path))
212 .map_err(|err| diagnostics.push(err.into()))?;
213 ports_idents.push(port_ident);
214 }
215 }
216 }
217 if err_elided {
218 return Err(());
219 }
220
221 match (!ports_numeric.is_empty(), !ports_idents.is_empty()) {
222 (false, false) => {
223 assert!(diagnostics.iter().any(Diagnostic::is_error), "Empty input ports, expected an error diagnostic but none were emitted, this is a Hydroflow bug.");
225 Err(())
226 }
227 (true, true) => {
228 let msg = &*format!(
230 "Output ports from `{}` must either be all integer indices or all identifiers.",
231 op_name
232 );
233 diagnostics.extend(
234 output_ports
235 .iter()
236 .map(|output_port| Diagnostic::spanned(output_port.span(), Level::Error, msg)),
237 );
238 Err(())
239 }
240 (true, false) => {
241 let max_port_idx = ports_numeric.last().unwrap().value;
242 if usize::try_from(max_port_idx).unwrap() >= ports_numeric.len() {
243 let mut expected = 0;
244 for port_numeric in ports_numeric {
245 if expected != port_numeric.value {
246 diagnostics.push(Diagnostic::spanned(
247 port_numeric.span,
248 Level::Error,
249 format!(
250 "Output port indices from `{}` must be consecutive from zero, missing {}.",
251 op_name, expected
252 ),
253 ));
254 }
255 expected = port_numeric.value + 1;
256 }
257 }
260 Ok(None)
261 }
262 (false, true) => Ok(Some(ports_idents)),
263 }
264}
265
266fn extract_closure_idents(
268 func: &mut Expr,
269 op_name: &'static str,
270) -> Result<(Vec<Ident>, Span), Diagnostic> {
271 let Expr::Closure(func) = func else {
272 return Err(Diagnostic::spanned(
273 func.span(),
274 Level::Error,
275 "Argument must be a two-argument closure expression",
276 ));
277 };
278 if 2 != func.inputs.len() {
279 return Err(Diagnostic::spanned(
280 func.inputs.span(),
281 Level::Error,
282 &*format!(
283 "Closure provided to `{}(..)` must have two arguments: \
284 the first argument is the item, and for named ports the second argument must contain a Rust 'slice pattern' to determine the port names and order. \
285 For example, the second argument could be `[foo, bar, baz]` for ports `foo`, `bar`, and `baz`.",
286 op_name
287 ),
288 ));
289 }
290
291 let mut arg2 = &mut func.inputs[1];
293 let mut already_has_type = false;
294 if let Pat::Type(pat_type) = arg2 {
295 arg2 = &mut *pat_type.pat;
296 already_has_type = true;
297 }
298
299 let arg2_span = arg2.span();
300 if let Pat::Ident(pat_ident) = arg2 {
301 arg2 = &mut *pat_ident
302 .subpat
303 .as_mut()
304 .ok_or_else(|| Diagnostic::spanned(
305 arg2_span,
306 Level::Error,
307 format!(
308 "Second argument for the `{}` closure must contain a Rust 'slice pattern' to determine the port names and order. \
309 For example: `arr @ [foo, bar, baz]` for ports `foo`, `bar`, and `baz`.",
310 op_name
311 )
312 ))?
313 .1;
314 }
315 let Pat::Slice(pat_slice) = arg2 else {
316 return Err(Diagnostic::spanned(
317 arg2_span,
318 Level::Error,
319 format!(
320 "Second argument for the `{}` closure must have a Rust 'slice pattern' to determine the port names and order. \
321 For example: `[foo, bar, baz]` for ports `foo`, `bar`, and `baz`.",
322 op_name
323 )
324 ));
325 };
326
327 let idents = pat_slice
328 .elems
329 .iter()
330 .map(|pat| {
331 let Pat::Ident(pat_ident) = pat else {
332 panic!("TODO(mingwei) expected ident pat");
333 };
334 pat_ident.ident.clone()
335 })
336 .collect();
337
338 if !already_has_type {
340 let len = LitInt::new(&pat_slice.elems.len().to_string(), arg2_span);
341 *arg2 = Pat::Type(PatType {
342 attrs: vec![],
343 pat: Box::new(arg2.clone()),
344 colon_token: Colon { spans: [arg2_span] },
345 ty: parse_quote_spanned! {arg2_span=> [usize; #len] },
346 });
347 }
348
349 Ok((idents, arg2_span))
350}
351
352fn check_closure_ports_match(
354 closure_idents: &[Ident],
355 port_idents: &[Ident],
356 op_name: &'static str,
357 arg2_span: Span,
358 diagnostics: &mut Vec<Diagnostic>,
359) -> Result<(), ()> {
360 let mut err = false;
361 for port_ident in port_idents {
362 if !closure_idents.contains(port_ident) {
363 err = true;
365 diagnostics.push(Diagnostic::spanned(
366 arg2_span,
367 Level::Error,
368 format!(
369 "Argument specifying the output ports in `{0}(..)` does not contain extra port `{1}`: ({2}) (1/2).",
370 op_name, port_ident, PrettySpan(port_ident.span()),
371 ),
372 ));
373 diagnostics.push(Diagnostic::spanned(
374 port_ident.span(),
375 Level::Error,
376 format!(
377 "Port `{1}` not found in the arguments specified in `{0}(..)`'s closure: ({2}) (2/2).",
378 op_name, port_ident, PrettySpan(arg2_span),
379 ),
380 ));
381 }
382 }
383 for closure_ident in closure_idents {
384 if !port_idents.contains(closure_ident) {
385 err = true;
387 diagnostics.push(Diagnostic::spanned(
388 closure_ident.span(),
389 Level::Error,
390 format!(
391 "`{}(..)` closure argument `{}` missing corresponding output port.",
392 op_name, closure_ident,
393 ),
394 ));
395 }
396 }
397 (!err).then_some(()).ok_or(())
398}