1use std::collections::HashMap;
26use std::sync::Arc;
27
28use crate::ast::{Expr, FnBody, FnDef, Literal, MatchArm, Pattern, Spanned, Stmt, TailCallData};
29
/// How a detected buffer-build function hands back its accumulated list.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BufferBuildKind {
    /// The builder reverses the accumulator itself: its `true` arm is
    /// `List.reverse(acc)`.
    InternalReverse,
    /// The builder returns the bare accumulator from its empty-list arm;
    /// a fusable call site must wrap the call in `List.reverse(...)`.
    ExternalReverse,
}
56
/// Description of a function recognised as an accumulator-style list builder.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BufferBuildShape {
    /// Position of the accumulator in the function's parameter list
    /// (the last parameter whose type looks like `List<...>`).
    pub acc_param_idx: usize,
    /// Name of the accumulator parameter.
    pub acc_param_name: String,
    /// Whether the builder reverses the accumulator itself or leaves that
    /// to the caller.
    pub kind: BufferBuildKind,
}
73
/// The kind of consumer wrapped around a builder call at a fusion site.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ConsumerKind {
    /// A `String.join(<builder call>, sep)` call.
    StringJoin,
}
90
/// One place in the program where a builder call feeds directly into a
/// recognised consumer.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FusionSite {
    /// Name of the function whose body contains the site.
    pub enclosing_fn: String,
    /// Source line recorded for the consumer expression.
    pub line: usize,
    /// Name of the builder function being called.
    pub sink_fn: String,
    /// Which consumer wraps the builder call.
    pub consumer: ConsumerKind,
}
105
106pub fn compute_buffer_build_sinks(fns: &[&FnDef]) -> HashMap<String, BufferBuildShape> {
110 let mut out = HashMap::new();
111 for fd in fns {
112 if let Some(shape) = match_buffer_build_shape(fd) {
113 out.insert(fd.name.clone(), shape);
114 }
115 }
116 out
117}
118
119pub fn find_fusion_sites(
125 fns: &[&FnDef],
126 sinks: &HashMap<String, BufferBuildShape>,
127) -> Vec<FusionSite> {
128 let mut out = Vec::new();
129 for fd in fns {
130 for stmt in fd.body.stmts() {
131 match stmt {
132 Stmt::Binding(_, _, expr) | Stmt::Expr(expr) => {
133 walk_expr_for_fusion_sites(&expr.node, expr.line, &fd.name, sinks, &mut out);
134 }
135 }
136 }
137 }
138 out
139}
140
141fn walk_expr_for_fusion_sites(
145 expr: &Expr,
146 expr_line: usize,
147 enclosing_fn: &str,
148 sinks: &HashMap<String, BufferBuildShape>,
149 out: &mut Vec<FusionSite>,
150) {
151 if let Some(inner_name) = match_string_join_fusion_site(expr, sinks) {
152 out.push(FusionSite {
153 enclosing_fn: enclosing_fn.to_string(),
154 line: expr_line,
155 sink_fn: inner_name,
156 consumer: ConsumerKind::StringJoin,
157 });
158 }
159 visit_subexprs(expr, expr_line, enclosing_fn, sinks, out);
163}
164
165fn visit_subexprs(
169 expr: &Expr,
170 fallback_line: usize,
171 enclosing_fn: &str,
172 sinks: &HashMap<String, BufferBuildShape>,
173 out: &mut Vec<FusionSite>,
174) {
175 let line_of = |s: &crate::ast::Spanned<Expr>| {
176 if s.line > 0 { s.line } else { fallback_line }
177 };
178 match expr {
179 Expr::Literal(_) | Expr::Ident(_) | Expr::Resolved { .. } | Expr::Constructor(_, None) => {}
180 Expr::Constructor(_, Some(inner)) | Expr::Attr(inner, _) | Expr::ErrorProp(inner) => {
181 walk_expr_for_fusion_sites(&inner.node, line_of(inner), enclosing_fn, sinks, out);
182 }
183 Expr::FnCall(callee, args) => {
184 walk_expr_for_fusion_sites(&callee.node, line_of(callee), enclosing_fn, sinks, out);
185 for a in args {
186 walk_expr_for_fusion_sites(&a.node, line_of(a), enclosing_fn, sinks, out);
187 }
188 }
189 Expr::TailCall(data) => {
190 for a in &data.args {
191 walk_expr_for_fusion_sites(&a.node, line_of(a), enclosing_fn, sinks, out);
192 }
193 }
194 Expr::BinOp(_, l, r) => {
195 walk_expr_for_fusion_sites(&l.node, line_of(l), enclosing_fn, sinks, out);
196 walk_expr_for_fusion_sites(&r.node, line_of(r), enclosing_fn, sinks, out);
197 }
198 Expr::Match { subject, arms } => {
199 walk_expr_for_fusion_sites(&subject.node, line_of(subject), enclosing_fn, sinks, out);
200 for arm in arms {
201 walk_expr_for_fusion_sites(
202 &arm.body.node,
203 line_of(&arm.body),
204 enclosing_fn,
205 sinks,
206 out,
207 );
208 }
209 }
210 Expr::List(items) | Expr::Tuple(items) | Expr::IndependentProduct(items, _) => {
211 for it in items {
212 walk_expr_for_fusion_sites(&it.node, line_of(it), enclosing_fn, sinks, out);
213 }
214 }
215 Expr::MapLiteral(entries) => {
216 for (k, v) in entries {
217 walk_expr_for_fusion_sites(&k.node, line_of(k), enclosing_fn, sinks, out);
218 walk_expr_for_fusion_sites(&v.node, line_of(v), enclosing_fn, sinks, out);
219 }
220 }
221 Expr::RecordCreate { fields, .. } => {
222 for (_, v) in fields {
223 walk_expr_for_fusion_sites(&v.node, line_of(v), enclosing_fn, sinks, out);
224 }
225 }
226 Expr::RecordUpdate { base, updates, .. } => {
227 walk_expr_for_fusion_sites(&base.node, line_of(base), enclosing_fn, sinks, out);
228 for (_, v) in updates {
229 walk_expr_for_fusion_sites(&v.node, line_of(v), enclosing_fn, sinks, out);
230 }
231 }
232 Expr::InterpolatedStr(parts) => {
233 for part in parts {
234 if let crate::ast::StrPart::Parsed(inner) = part {
235 walk_expr_for_fusion_sites(
236 &inner.node,
237 line_of(inner),
238 enclosing_fn,
239 sinks,
240 out,
241 );
242 }
243 }
244 }
245 }
246}
247
248fn match_buffer_build_shape(fd: &FnDef) -> Option<BufferBuildShape> {
250 let (acc_idx, acc_name) = fd
263 .params
264 .iter()
265 .enumerate()
266 .rfind(|(_, (_, ty))| is_list_type_str(ty))
267 .map(|(i, (name, _))| (i, name.clone()))?;
268
269 let match_expr = single_match_body(&fd.body)?;
271 let (subject_expr, arms) = match match_expr {
272 Expr::Match { subject, arms } => (subject, arms),
273 _ => return None,
274 };
275
276 if let Some((true_body, false_body)) = pair_bool_arms(arms) {
278 let _ = subject_expr;
279 if is_list_reverse_of(true_body, &acc_name)
280 && is_self_tail_with_prepend_acc(false_body, &fd.name, acc_idx, &acc_name)
281 {
282 return Some(BufferBuildShape {
283 acc_param_idx: acc_idx,
284 acc_param_name: acc_name,
285 kind: BufferBuildKind::InternalReverse,
286 });
287 }
288 }
289
290 if let Some((nil_body, cons_body)) = pair_nil_cons_arms(arms)
295 && is_ident_named(nil_body, &acc_name)
296 && is_self_tail_with_prepend_acc(cons_body, &fd.name, acc_idx, &acc_name)
297 {
298 return Some(BufferBuildShape {
299 acc_param_idx: acc_idx,
300 acc_param_name: acc_name,
301 kind: BufferBuildKind::ExternalReverse,
302 });
303 }
304
305 None
306}
307
308fn pair_nil_cons_arms(arms: &[MatchArm]) -> Option<(&Expr, &Expr)> {
313 if arms.len() != 2 {
314 return None;
315 }
316 let mut nil_body: Option<&Expr> = None;
317 let mut cons_body: Option<&Expr> = None;
318 for arm in arms {
319 match &arm.pattern {
320 Pattern::EmptyList => nil_body = Some(&arm.body.node),
321 Pattern::Cons(_, _) => cons_body = Some(&arm.body.node),
322 _ => return None,
323 }
324 }
325 match (nil_body, cons_body) {
326 (Some(n), Some(c)) => Some((n, c)),
327 _ => None,
328 }
329}
330
331fn is_ident_named(expr: &Expr, name: &str) -> bool {
333 matches!(expr, Expr::Ident(n) if n == name)
334}
335
336fn match_string_join_fusion_site(
362 expr: &Expr,
363 sinks: &HashMap<String, BufferBuildShape>,
364) -> Option<String> {
365 let Expr::FnCall(callee, args) = expr else {
366 return None;
367 };
368 if !is_dotted_ident(&callee.node, "String", "join") || args.len() != 2 {
369 return None;
370 }
371 let consumer_arg = &args[0].node;
372
373 let (inner_call_expr, saw_external_reverse) = match consumer_arg {
375 Expr::FnCall(rev_callee, rev_args)
376 if is_dotted_ident(&rev_callee.node, "List", "reverse") && rev_args.len() == 1 =>
377 {
378 (&rev_args[0].node, true)
379 }
380 other => (other, false),
381 };
382
383 let Expr::FnCall(inner_callee, inner_args) = inner_call_expr else {
384 return None;
385 };
386 let Expr::Ident(name) = &inner_callee.node else {
387 return None;
388 };
389 let shape = sinks.get(name)?;
390
391 let kinds_align = matches!(
392 (saw_external_reverse, &shape.kind),
393 (false, BufferBuildKind::InternalReverse) | (true, BufferBuildKind::ExternalReverse)
394 );
395 if !kinds_align {
396 return None;
397 }
398
399 let acc_arg = inner_args.get(shape.acc_param_idx)?;
400 if !matches!(&acc_arg.node, Expr::List(items) if items.is_empty()) {
401 return None;
402 }
403
404 Some(name.clone())
405}
406
/// Returns `true` when the type string, ignoring surrounding whitespace,
/// starts with `List<` and ends with `>`.
///
/// This is a purely textual check; the type argument is not parsed.
fn is_list_type_str(ty: &str) -> bool {
    ty.trim()
        .strip_prefix("List<")
        .is_some_and(|rest| rest.ends_with('>'))
}
412
413fn single_match_body(body: &FnBody) -> Option<&Expr> {
417 let stmts = body.stmts();
418 if stmts.len() != 1 {
419 return None;
420 }
421 match &stmts[0] {
422 Stmt::Expr(spanned) => match &spanned.node {
423 Expr::Match { .. } => Some(&spanned.node),
424 _ => None,
425 },
426 Stmt::Binding(_, _, _) => None,
427 }
428}
429
430fn pair_bool_arms(arms: &[MatchArm]) -> Option<(&Expr, &Expr)> {
434 if arms.len() != 2 {
435 return None;
436 }
437 let mut t = None;
438 let mut f = None;
439 for arm in arms {
440 match &arm.pattern {
441 Pattern::Literal(Literal::Bool(true)) => {
442 if t.is_some() {
443 return None;
444 }
445 t = Some(&arm.body.node);
446 }
447 Pattern::Literal(Literal::Bool(false)) => {
448 if f.is_some() {
449 return None;
450 }
451 f = Some(&arm.body.node);
452 }
453 _ => return None,
454 }
455 }
456 Some((t?, f?))
457}
458
459fn is_list_reverse_of(expr: &Expr, acc_name: &str) -> bool {
461 let (callee, args) = match expr {
462 Expr::FnCall(c, a) => (c, a),
463 _ => return false,
464 };
465 if !is_dotted_ident(&callee.node, "List", "reverse") {
466 return false;
467 }
468 if args.len() != 1 {
469 return false;
470 }
471 matches!(&args[0].node, Expr::Ident(name) if name == acc_name)
472}
473
474fn is_self_tail_with_prepend_acc(
480 expr: &Expr,
481 self_name: &str,
482 acc_idx: usize,
483 acc_name: &str,
484) -> bool {
485 let data = match expr {
486 Expr::TailCall(data) => data,
487 _ => return false,
488 };
489 if data.target != self_name {
490 return false;
491 }
492 let acc_arg = match data.args.get(acc_idx) {
501 Some(a) => a,
502 None => return false,
503 };
504 is_list_prepend_to_acc(&acc_arg.node, acc_name)
505}
506
507fn is_list_prepend_to_acc(expr: &Expr, acc_name: &str) -> bool {
509 let (callee, args) = match expr {
510 Expr::FnCall(c, a) => (c, a),
511 _ => return false,
512 };
513 if !is_dotted_ident(&callee.node, "List", "prepend") {
514 return false;
515 }
516 if args.len() != 2 {
517 return false;
518 }
519 matches!(&args[1].node, Expr::Ident(name) if name == acc_name)
520}
521
522fn is_dotted_ident(expr: &Expr, module: &str, member: &str) -> bool {
525 let (base, attr) = match expr {
526 Expr::Attr(b, a) => (b, a),
527 _ => return false,
528 };
529 if attr != member {
530 return false;
531 }
532 matches!(&base.node, Expr::Ident(name) if name == module)
533}
534
535pub fn synthesize_buffered_variants(
573 fns: &[&FnDef],
574 sinks: &HashMap<String, BufferBuildShape>,
575) -> Vec<FnDef> {
576 let mut out = Vec::new();
577 for fd in fns {
578 if let Some(shape) = sinks.get(&fd.name)
579 && let Some(buffered) = build_buffered_variant(fd, shape)
580 {
581 out.push(buffered);
582 }
583 }
584 out
585}
586
587fn sp_at(line: usize, expr: Expr) -> Spanned<Expr> {
592 Spanned { node: expr, line }
593}
594
595fn intrinsic_call(line: usize, name: &str, args: Vec<Spanned<Expr>>) -> Spanned<Expr> {
600 let callee = sp_at(line, Expr::Ident(name.to_string()));
601 sp_at(line, Expr::FnCall(Box::new(callee), args))
602}
603
604pub fn run_buffer_build_pass(items: &mut Vec<crate::ast::TopLevel>) -> BufferBuildPassReport {
616 let fn_refs: Vec<&FnDef> = items
617 .iter()
618 .filter_map(|it| match it {
619 crate::ast::TopLevel::FnDef(fd) => Some(fd),
620 _ => None,
621 })
622 .collect();
623 let all_sinks = compute_buffer_build_sinks(&fn_refs);
624 if all_sinks.is_empty() {
625 return BufferBuildPassReport::default();
626 }
627 let sites = find_fusion_sites(&fn_refs, &all_sinks);
628
629 let mut used_sinks: HashMap<String, BufferBuildShape> = HashMap::new();
637 for site in &sites {
638 if let Some(shape) = all_sinks.get(&site.sink_fn) {
639 used_sinks.insert(site.sink_fn.clone(), shape.clone());
640 }
641 }
642 let synthesized = synthesize_buffered_variants(&fn_refs, &used_sinks);
643 let sinks = used_sinks;
644 drop(fn_refs);
645
646 let mut fn_defs_owned: Vec<&mut FnDef> = items
647 .iter_mut()
648 .filter_map(|it| match it {
649 crate::ast::TopLevel::FnDef(fd) => Some(fd),
650 _ => None,
651 })
652 .collect();
653 for fd in fn_defs_owned.iter_mut() {
657 rewrite_one_fn(fd, &sinks);
658 }
659
660 items.reserve(synthesized.len());
661 for fd in synthesized.iter() {
662 items.push(crate::ast::TopLevel::FnDef(fd.clone()));
663 }
664
665 let mut sink_fns: Vec<String> = sinks.keys().cloned().collect();
666 sink_fns.sort();
667 let synthesized_fns: Vec<String> = synthesized.iter().map(|fd| fd.name.clone()).collect();
668
669 let mut rewrites_by_sink: std::collections::BTreeMap<String, usize> =
670 std::collections::BTreeMap::new();
671 for site in &sites {
672 *rewrites_by_sink.entry(site.sink_fn.clone()).or_default() += 1;
673 }
674
675 BufferBuildPassReport {
676 rewrites: sites.len(),
677 synthesized: synthesized_fns,
678 sink_fns,
679 rewrites_by_sink,
680 }
681}
682
/// Summary of what one run of the buffer-build pass did.
#[derive(Debug, Clone, Default)]
pub struct BufferBuildPassReport {
    /// Number of fusion sites found (one per rewritten call site).
    pub rewrites: usize,
    /// Names of the synthesized buffered functions that were appended.
    pub synthesized: Vec<String>,
    /// Sorted names of the builder functions used by at least one site.
    pub sink_fns: Vec<String>,
    /// Number of fusion sites per builder function name.
    pub rewrites_by_sink: std::collections::BTreeMap<String, usize>,
}
700
701fn rewrite_one_fn(fd: &mut FnDef, sinks: &HashMap<String, BufferBuildShape>) {
705 let body_arc = std::sync::Arc::make_mut(&mut fd.body);
706 let FnBody::Block(stmts) = body_arc;
707 for stmt in stmts.iter_mut() {
708 match stmt {
709 Stmt::Binding(_, _, expr) | Stmt::Expr(expr) => {
710 rewrite_expr_in_place(expr, sinks);
711 }
712 }
713 }
714}
715
716pub fn rewrite_fusion_sites(fn_defs: &mut [FnDef], sinks: &HashMap<String, BufferBuildShape>) {
728 if sinks.is_empty() {
729 return;
730 }
731 for fd in fn_defs.iter_mut() {
732 let body_arc = std::sync::Arc::make_mut(&mut fd.body);
733 let FnBody::Block(stmts) = body_arc;
734 for stmt in stmts.iter_mut() {
735 match stmt {
736 Stmt::Binding(_, _, expr) | Stmt::Expr(expr) => {
737 rewrite_expr_in_place(expr, sinks);
738 }
739 }
740 }
741 }
742}
743
744fn rewrite_expr_in_place(expr: &mut Spanned<Expr>, sinks: &HashMap<String, BufferBuildShape>) {
749 if let Some(replacement) = try_rewrite_fusion_site(expr, sinks) {
750 *expr = replacement;
751 descend_into_subexprs(expr, sinks);
755 return;
756 }
757 descend_into_subexprs(expr, sinks);
758}
759
760fn descend_into_subexprs(expr: &mut Spanned<Expr>, sinks: &HashMap<String, BufferBuildShape>) {
764 match &mut expr.node {
765 Expr::Literal(_) | Expr::Ident(_) | Expr::Resolved { .. } | Expr::Constructor(_, None) => {}
766 Expr::Constructor(_, Some(inner)) | Expr::Attr(inner, _) | Expr::ErrorProp(inner) => {
767 rewrite_expr_in_place(inner, sinks);
768 }
769 Expr::FnCall(callee, args) => {
770 rewrite_expr_in_place(callee, sinks);
771 for a in args.iter_mut() {
772 rewrite_expr_in_place(a, sinks);
773 }
774 }
775 Expr::TailCall(data) => {
776 for a in data.args.iter_mut() {
777 rewrite_expr_in_place(a, sinks);
778 }
779 }
780 Expr::BinOp(_, l, r) => {
781 rewrite_expr_in_place(l, sinks);
782 rewrite_expr_in_place(r, sinks);
783 }
784 Expr::Match { subject, arms } => {
785 rewrite_expr_in_place(subject, sinks);
786 for arm in arms.iter_mut() {
787 rewrite_expr_in_place(&mut arm.body, sinks);
788 }
789 }
790 Expr::List(items) | Expr::Tuple(items) | Expr::IndependentProduct(items, _) => {
791 for it in items.iter_mut() {
792 rewrite_expr_in_place(it, sinks);
793 }
794 }
795 Expr::MapLiteral(entries) => {
796 for (k, v) in entries.iter_mut() {
797 rewrite_expr_in_place(k, sinks);
798 rewrite_expr_in_place(v, sinks);
799 }
800 }
801 Expr::RecordCreate { fields, .. } => {
802 for (_, v) in fields.iter_mut() {
803 rewrite_expr_in_place(v, sinks);
804 }
805 }
806 Expr::RecordUpdate { base, updates, .. } => {
807 rewrite_expr_in_place(base, sinks);
808 for (_, v) in updates.iter_mut() {
809 rewrite_expr_in_place(v, sinks);
810 }
811 }
812 Expr::InterpolatedStr(parts) => {
813 for part in parts.iter_mut() {
814 if let crate::ast::StrPart::Parsed(inner) = part {
815 rewrite_expr_in_place(inner, sinks);
816 }
817 }
818 }
819 }
820}
821
822fn try_rewrite_fusion_site(
826 expr: &Spanned<Expr>,
827 sinks: &HashMap<String, BufferBuildShape>,
828) -> Option<Spanned<Expr>> {
829 let line = expr.line;
830
831 let sink_name = match_string_join_fusion_site(&expr.node, sinks)?;
834 let shape = sinks.get(&sink_name)?;
835
836 let outer_args = match &expr.node {
840 Expr::FnCall(_, a) => a,
841 _ => return None,
842 };
843 let consumer_arg = &outer_args[0].node;
844 let inner_call_expr = if let Expr::FnCall(rev_callee, rev_args) = consumer_arg
845 && is_dotted_ident(&rev_callee.node, "List", "reverse")
846 && rev_args.len() == 1
847 {
848 &rev_args[0].node
849 } else {
850 consumer_arg
851 };
852 let inner_args = match inner_call_expr {
853 Expr::FnCall(_, a) => a,
854 _ => return None,
855 };
856
857 let sep_expr = outer_args[1].clone();
866 let buf_new = intrinsic_call(
867 line,
868 "__buf_new",
869 vec![sp_at(line, Expr::Literal(Literal::Int(8192)))],
870 );
871 let mut buffered_args: Vec<Spanned<Expr>> = inner_args
872 .iter()
873 .enumerate()
874 .filter_map(|(i, a)| (i != shape.acc_param_idx).then_some(a).cloned())
875 .collect();
876 buffered_args.push(buf_new);
877 buffered_args.push(sep_expr);
878 let buffered_call = sp_at(
879 line,
880 Expr::FnCall(
881 Box::new(sp_at(line, Expr::Ident(format!("{}__buffered", sink_name)))),
882 buffered_args,
883 ),
884 );
885 Some(intrinsic_call(line, "__buf_finalize", vec![buffered_call]))
886}
887
888fn build_buffered_variant(fd: &FnDef, shape: &BufferBuildShape) -> Option<FnDef> {
893 let stmts = fd.body.stmts();
900 if stmts.len() != 1 {
901 return None;
902 }
903 let outer_expr = match &stmts[0] {
904 Stmt::Expr(spanned) => spanned,
905 _ => return None,
906 };
907 let (subject_orig, arms_orig) = match &outer_expr.node {
908 Expr::Match { subject, arms } => (subject, arms),
909 _ => return None,
910 };
911 let recursive_body: &Spanned<Expr> = match shape.kind {
912 BufferBuildKind::InternalReverse => arms_orig
913 .iter()
914 .find(|a| matches!(a.pattern, Pattern::Literal(Literal::Bool(false))))
915 .map(|a| a.body.as_ref())?,
916 BufferBuildKind::ExternalReverse => arms_orig
917 .iter()
918 .find(|a| matches!(a.pattern, Pattern::Cons(_, _)))
919 .map(|a| a.body.as_ref())?,
920 };
921 let tail_data = match &recursive_body.node {
922 Expr::TailCall(data) => data,
923 _ => return None,
924 };
925
926 let acc_arg_orig = tail_data.args.get(shape.acc_param_idx)?;
929 let elem_expr = match &acc_arg_orig.node {
930 Expr::FnCall(callee, args) => {
931 if !is_dotted_ident(&callee.node, "List", "prepend") {
932 return None;
933 }
934 if args.len() != 2 {
935 return None;
936 }
937 match &args[1].node {
939 Expr::Ident(name) if name == &shape.acc_param_name => {}
940 _ => return None,
941 }
942 args[0].clone()
943 }
944 _ => return None,
945 };
946
947 let line = fd.line;
948 let buf_name = "__buf";
949 let sep_name = "__sep";
950 let buffered_target = format!("{}__buffered", fd.name);
951
952 let buf_ident = || sp_at(line, Expr::Ident(buf_name.to_string()));
961 let sep_ident = || sp_at(line, Expr::Ident(sep_name.to_string()));
962 let sep_then_buf = intrinsic_call(
963 line,
964 "__buf_append_sep_unless_first",
965 vec![buf_ident(), sep_ident()],
966 );
967 let final_buf = intrinsic_call(line, "__buf_append", vec![sep_then_buf, elem_expr]);
968
969 let mut new_args: Vec<Spanned<Expr>> = tail_data
972 .args
973 .iter()
974 .enumerate()
975 .map(|(i, a)| {
976 if i == shape.acc_param_idx {
977 final_buf.clone()
978 } else {
979 a.clone()
980 }
981 })
982 .collect();
983 new_args.push(sep_ident());
984
985 let new_recursive_body = sp_at(
986 line,
987 Expr::TailCall(Box::new(TailCallData {
988 target: buffered_target.clone(),
989 args: new_args,
990 })),
991 );
992
993 let new_arms = match shape.kind {
998 BufferBuildKind::InternalReverse => vec![
999 MatchArm {
1000 pattern: Pattern::Literal(Literal::Bool(true)),
1001 body: Box::new(buf_ident()),
1002 },
1003 MatchArm {
1004 pattern: Pattern::Literal(Literal::Bool(false)),
1005 body: Box::new(new_recursive_body),
1006 },
1007 ],
1008 BufferBuildKind::ExternalReverse => {
1009 let cons_pat = arms_orig
1013 .iter()
1014 .find_map(|a| match &a.pattern {
1015 Pattern::Cons(h, t) => Some(Pattern::Cons(h.clone(), t.clone())),
1016 _ => None,
1017 })
1018 .unwrap_or(Pattern::Cons("__head".to_string(), "__tail".to_string()));
1019 vec![
1020 MatchArm {
1021 pattern: Pattern::EmptyList,
1022 body: Box::new(buf_ident()),
1023 },
1024 MatchArm {
1025 pattern: cons_pat,
1026 body: Box::new(new_recursive_body),
1027 },
1028 ]
1029 }
1030 };
1031
1032 let new_match = sp_at(
1033 line,
1034 Expr::Match {
1035 subject: subject_orig.clone(),
1036 arms: new_arms,
1037 },
1038 );
1039
1040 let new_body = FnBody::Block(vec![Stmt::Expr(new_match)]);
1041
1042 let mut new_params: Vec<(String, String)> = fd
1044 .params
1045 .iter()
1046 .enumerate()
1047 .filter_map(|(i, p)| (i != shape.acc_param_idx).then_some(p).cloned())
1048 .collect();
1049 new_params.push((buf_name.to_string(), "Buffer".to_string()));
1050 new_params.push((sep_name.to_string(), "String".to_string()));
1051
1052 Some(FnDef {
1053 name: buffered_target,
1054 line,
1055 params: new_params,
1056 return_type: "Buffer".to_string(),
1057 effects: fd.effects.clone(),
1062 desc: Some(format!(
1063 "Synthesized buffered variant of `{}` for deforestation \
1064 lowering. Call sites that match `String.join({}(...), sep)` \
1065 are rewritten to alloc a buffer + call this variant + \
1066 finalize, skipping the intermediate List.",
1067 fd.name, fd.name
1068 )),
1069 body: Arc::new(new_body),
1070 resolution: None,
1071 })
1072}
1073
1074#[cfg(test)]
1075mod tests {
1076 use super::*;
1077 use crate::ast::{BinOp, FnBody, FnDef, Literal, Spanned, TailCallData};
1078 use std::sync::Arc;
1079
1080 fn sp<T>(value: T) -> Spanned<T> {
1081 Spanned {
1082 node: value,
1083 line: 1,
1084 }
1085 }
1086
1087 fn ident(name: &str) -> Spanned<Expr> {
1088 sp(Expr::Ident(name.to_string()))
1089 }
1090
1091 fn dotted(module: &str, member: &str) -> Spanned<Expr> {
1092 sp(Expr::Attr(Box::new(ident(module)), member.to_string()))
1093 }
1094
1095 fn call(callee: Spanned<Expr>, args: Vec<Spanned<Expr>>) -> Spanned<Expr> {
1096 sp(Expr::FnCall(Box::new(callee), args))
1097 }
1098
1099 fn canonical_builder(name: &str) -> FnDef {
1103 let true_body = call(dotted("List", "reverse"), vec![ident("acc")]);
1104 let prepend = call(dotted("List", "prepend"), vec![ident("col"), ident("acc")]);
1105 let false_body = sp(Expr::TailCall(Box::new(TailCallData {
1106 target: name.to_string(),
1107 args: vec![
1108 sp(Expr::BinOp(
1109 BinOp::Add,
1110 Box::new(ident("col")),
1111 Box::new(sp(Expr::Literal(Literal::Int(1)))),
1112 )),
1113 prepend,
1114 ],
1115 })));
1116 let match_expr = sp(Expr::Match {
1117 subject: Box::new(sp(Expr::BinOp(
1118 BinOp::Gte,
1119 Box::new(ident("col")),
1120 Box::new(sp(Expr::Literal(Literal::Int(10)))),
1121 ))),
1122 arms: vec![
1123 MatchArm {
1124 pattern: Pattern::Literal(Literal::Bool(true)),
1125 body: Box::new(true_body),
1126 },
1127 MatchArm {
1128 pattern: Pattern::Literal(Literal::Bool(false)),
1129 body: Box::new(false_body),
1130 },
1131 ],
1132 });
1133 FnDef {
1134 name: name.to_string(),
1135 line: 1,
1136 params: vec![
1137 ("col".to_string(), "Int".to_string()),
1138 ("acc".to_string(), "List<Int>".to_string()),
1139 ],
1140 return_type: "List<Int>".to_string(),
1141 effects: vec![],
1142 desc: None,
1143 body: Arc::new(FnBody::Block(vec![Stmt::Expr(match_expr)])),
1144 resolution: None,
1145 }
1146 }
1147
1148 #[test]
1149 fn matches_canonical_buffer_build() {
1150 let fd = canonical_builder("build");
1151 let info = compute_buffer_build_sinks(&[&fd]);
1152 let shape = info.get("build").expect("expected match");
1153 assert_eq!(shape.acc_param_idx, 1);
1154 assert_eq!(shape.acc_param_name, "acc");
1155 }
1156
1157 #[test]
1158 fn rejects_fn_without_list_param() {
1159 let mut fd = canonical_builder("build");
1160 fd.params = vec![("col".to_string(), "Int".to_string())];
1162 let info = compute_buffer_build_sinks(&[&fd]);
1163 assert!(info.is_empty(), "fn without List param should not match");
1164 }
1165
1166 #[test]
1167 fn rejects_when_true_arm_isnt_reverse() {
1168 let mut fd = canonical_builder("build");
1169 if let FnBody::Block(stmts) = Arc::make_mut(&mut fd.body) {
1171 if let Stmt::Expr(spanned) = &mut stmts[0] {
1172 if let Expr::Match { arms, .. } = &mut spanned.node {
1173 arms[0].body = Box::new(ident("acc"));
1174 }
1175 }
1176 }
1177 let info = compute_buffer_build_sinks(&[&fd]);
1178 assert!(
1179 info.is_empty(),
1180 "fn returning bare acc instead of reverse should not match"
1181 );
1182 }
1183
1184 #[test]
1185 fn rejects_when_false_arm_uses_append_not_prepend() {
1186 let mut fd = canonical_builder("build");
1187 if let FnBody::Block(stmts) = Arc::make_mut(&mut fd.body) {
1189 if let Stmt::Expr(spanned) = &mut stmts[0] {
1190 if let Expr::Match { arms, .. } = &mut spanned.node {
1191 let false_body = arms[1].body.as_mut();
1192 if let Expr::TailCall(data) = &mut false_body.node {
1193 if let Expr::FnCall(callee, _) = &mut data.args[1].node {
1194 if let Expr::Attr(_, attr) = &mut callee.node {
1195 *attr = "append".to_string();
1196 }
1197 }
1198 }
1199 }
1200 }
1201 }
1202 let info = compute_buffer_build_sinks(&[&fd]);
1203 assert!(
1204 info.is_empty(),
1205 "fn using List.append instead of prepend should not match"
1206 );
1207 }
1208
1209 #[test]
1210 fn rejects_tail_call_to_different_fn() {
1211 let mut fd = canonical_builder("build");
1212 if let FnBody::Block(stmts) = Arc::make_mut(&mut fd.body) {
1213 if let Stmt::Expr(spanned) = &mut stmts[0] {
1214 if let Expr::Match { arms, .. } = &mut spanned.node {
1215 let false_body = arms[1].body.as_mut();
1216 if let Expr::TailCall(data) = &mut false_body.node {
1217 data.target = "someone_else".to_string();
1218 }
1219 }
1220 }
1221 }
1222 let info = compute_buffer_build_sinks(&[&fd]);
1223 assert!(
1224 info.is_empty(),
1225 "fn whose recursive call targets a different name should not match"
1226 );
1227 }
1228
1229 #[test]
1230 fn rejects_match_with_non_bool_arms() {
1231 let mut fd = canonical_builder("build");
1232 if let FnBody::Block(stmts) = Arc::make_mut(&mut fd.body) {
1233 if let Stmt::Expr(spanned) = &mut stmts[0] {
1234 if let Expr::Match { arms, .. } = &mut spanned.node {
1235 arms[0].pattern = Pattern::Literal(Literal::Int(0));
1236 }
1237 }
1238 }
1239 let info = compute_buffer_build_sinks(&[&fd]);
1240 assert!(
1241 info.is_empty(),
1242 "match on non-bool patterns should not be detected as buffer-build"
1243 );
1244 }
1245
1246 #[test]
1251 fn detects_via_parser_after_tco() {
1252 let src = r#"
1253fn build(n: Int, acc: List<Int>) -> List<Int>
1254 match n <= 0
1255 true -> List.reverse(acc)
1256 false -> build(n - 1, List.prepend(n, acc))
1257"#;
1258 let mut lexer = crate::lexer::Lexer::new(src);
1259 let tokens = lexer.tokenize().expect("lex");
1260 let mut parser = crate::parser::Parser::new(tokens);
1261 let mut items = parser.parse().expect("parse");
1262 crate::ir::pipeline::tco(&mut items);
1263 let fns: Vec<&FnDef> = items
1264 .iter()
1265 .filter_map(|it| match it {
1266 crate::ast::TopLevel::FnDef(fd) => Some(fd),
1267 _ => None,
1268 })
1269 .collect();
1270 let info = compute_buffer_build_sinks(&fns);
1271 let shape = info
1272 .get("build")
1273 .expect("expected end-to-end shape match for canonical builder");
1274 assert_eq!(shape.acc_param_idx, 1);
1275 assert_eq!(shape.acc_param_name, "acc");
1276 }
1277
1278 #[test]
1281 fn finds_fusion_site_via_parser() {
1282 let src = r#"
1283fn build(n: Int, acc: List<Int>) -> List<Int>
1284 match n <= 0
1285 true -> List.reverse(acc)
1286 false -> build(n - 1, List.prepend(n, acc))
1287
1288fn main() -> String
1289 String.join(build(5, []), ",")
1290"#;
1291 let mut lexer = crate::lexer::Lexer::new(src);
1292 let tokens = lexer.tokenize().expect("lex");
1293 let mut parser = crate::parser::Parser::new(tokens);
1294 let mut items = parser.parse().expect("parse");
1295 crate::ir::pipeline::tco(&mut items);
1296 let fns: Vec<&FnDef> = items
1297 .iter()
1298 .filter_map(|it| match it {
1299 crate::ast::TopLevel::FnDef(fd) => Some(fd),
1300 _ => None,
1301 })
1302 .collect();
1303 let sinks = compute_buffer_build_sinks(&fns);
1304 let sites = find_fusion_sites(&fns, &sinks);
1305 assert_eq!(sites.len(), 1, "expected one fusion site, got {sites:?}");
1306 let site = &sites[0];
1307 assert_eq!(site.enclosing_fn, "main");
1308 assert_eq!(site.sink_fn, "build");
1309 assert!(site.line > 0, "expected real line info, got 0");
1310 }
1311
1312 #[test]
1316 fn ignores_call_when_not_wrapped_in_string_join() {
1317 let src = r#"
1318fn build(n: Int, acc: List<Int>) -> List<Int>
1319 match n <= 0
1320 true -> List.reverse(acc)
1321 false -> build(n - 1, List.prepend(n, acc))
1322
1323fn main() -> List<Int>
1324 build(5, [])
1325"#;
1326 let mut lexer = crate::lexer::Lexer::new(src);
1327 let tokens = lexer.tokenize().expect("lex");
1328 let mut parser = crate::parser::Parser::new(tokens);
1329 let mut items = parser.parse().expect("parse");
1330 crate::ir::pipeline::tco(&mut items);
1331 let fns: Vec<&FnDef> = items
1332 .iter()
1333 .filter_map(|it| match it {
1334 crate::ast::TopLevel::FnDef(fd) => Some(fd),
1335 _ => None,
1336 })
1337 .collect();
1338 let sinks = compute_buffer_build_sinks(&fns);
1339 let sites = find_fusion_sites(&fns, &sinks);
1340 assert!(
1341 sites.is_empty(),
1342 "build called outside String.join must not be a fusion site, got {sites:?}"
1343 );
1344 }
1345
1346 #[test]
1352 fn rejects_via_parser_when_true_arm_returns_bare_acc() {
1353 let src = r#"
1354fn build(n: Int, acc: List<Int>) -> List<Int>
1355 match n <= 0
1356 true -> acc
1357 false -> build(n - 1, List.prepend(n, acc))
1358"#;
1359 let mut lexer = crate::lexer::Lexer::new(src);
1360 let tokens = lexer.tokenize().expect("lex");
1361 let mut parser = crate::parser::Parser::new(tokens);
1362 let mut items = parser.parse().expect("parse");
1363 crate::ir::pipeline::tco(&mut items);
1364 let fns: Vec<&FnDef> = items
1365 .iter()
1366 .filter_map(|it| match it {
1367 crate::ast::TopLevel::FnDef(fd) => Some(fd),
1368 _ => None,
1369 })
1370 .collect();
1371 let info = compute_buffer_build_sinks(&fns);
1372 assert!(
1373 info.is_empty(),
1374 "fn returning bare acc must not be detected as a deforestation candidate"
1375 );
1376 }
1377
1378 #[test]
1384 fn synthesizes_buffered_variant_from_real_builder() {
1385 let src = r#"
1386fn build(n: Int, acc: List<Int>) -> List<Int>
1387 match n <= 0
1388 true -> List.reverse(acc)
1389 false -> build(n - 1, List.prepend(n, acc))
1390"#;
1391 let mut lexer = crate::lexer::Lexer::new(src);
1392 let tokens = lexer.tokenize().expect("lex");
1393 let mut parser = crate::parser::Parser::new(tokens);
1394 let mut items = parser.parse().expect("parse");
1395 crate::ir::pipeline::tco(&mut items);
1396 let fns: Vec<&FnDef> = items
1397 .iter()
1398 .filter_map(|it| match it {
1399 crate::ast::TopLevel::FnDef(fd) => Some(fd),
1400 _ => None,
1401 })
1402 .collect();
1403 let sinks = compute_buffer_build_sinks(&fns);
1404 assert!(sinks.contains_key("build"));
1405 let synthesized = synthesize_buffered_variants(&fns, &sinks);
1406 assert_eq!(
1407 synthesized.len(),
1408 1,
1409 "expected exactly one synthesized variant"
1410 );
1411 let bf = &synthesized[0];
1412
1413 assert_eq!(bf.name, "build__buffered");
1415 assert_eq!(bf.return_type, "Buffer");
1416 let param_names: Vec<&str> = bf.params.iter().map(|(n, _)| n.as_str()).collect();
1417 let param_types: Vec<&str> = bf.params.iter().map(|(_, t)| t.as_str()).collect();
1418 assert_eq!(param_names, vec!["n", "__buf", "__sep"]);
1419 assert_eq!(param_types, vec!["Int", "Buffer", "String"]);
1420
1421 let stmts = bf.body.stmts();
1423 assert_eq!(stmts.len(), 1);
1424 let match_expr = match &stmts[0] {
1425 Stmt::Expr(s) => match &s.node {
1426 Expr::Match { subject: _, arms } => arms,
1427 _ => panic!("body root must be a match"),
1428 },
1429 _ => panic!("body root must be Stmt::Expr"),
1430 };
1431 assert_eq!(match_expr.len(), 2);
1432
1433 let true_arm = match_expr
1435 .iter()
1436 .find(|a| matches!(a.pattern, Pattern::Literal(Literal::Bool(true))))
1437 .expect("true arm");
1438 match &true_arm.body.node {
1439 Expr::Ident(name) => assert_eq!(name, "__buf"),
1440 other => panic!("true arm should be Ident(__buf), got {other:?}"),
1441 }
1442
1443 let false_arm = match_expr
1445 .iter()
1446 .find(|a| matches!(a.pattern, Pattern::Literal(Literal::Bool(false))))
1447 .expect("false arm");
1448 let tail_data = match &false_arm.body.node {
1449 Expr::TailCall(d) => d,
1450 other => panic!("false arm should be TailCall, got {other:?}"),
1451 };
1452 assert_eq!(tail_data.target, "build__buffered");
1453 assert_eq!(tail_data.args.len(), 3);
1457 let outer = match &tail_data.args[1].node {
1460 Expr::FnCall(callee, args) => {
1461 match &callee.node {
1462 Expr::Ident(name) => assert_eq!(name, "__buf_append"),
1463 _ => panic!("expected Ident callee"),
1464 }
1465 args
1466 }
1467 _ => panic!("expected outer __buf_append FnCall"),
1468 };
1469 assert_eq!(outer.len(), 2);
1470 match &outer[0].node {
1472 Expr::FnCall(callee, _) => match &callee.node {
1473 Expr::Ident(name) => assert_eq!(name, "__buf_append_sep_unless_first"),
1474 _ => panic!("expected Ident callee for inner intrinsic"),
1475 },
1476 _ => panic!("expected inner __buf_append_sep_unless_first FnCall"),
1477 }
1478 match &outer[1].node {
1480 Expr::Ident(name) => assert_eq!(name, "n"),
1481 _ => panic!("expected `n` ident as elem"),
1482 }
1483 match &tail_data.args[2].node {
1485 Expr::Ident(name) => assert_eq!(name, "__sep"),
1486 _ => panic!("expected __sep ident as last arg"),
1487 }
1488 }
1489
1490 #[test]
1491 fn detects_acc_param_at_arbitrary_index() {
1492 let true_body = call(dotted("List", "reverse"), vec![ident("acc")]);
1500 let prepend = call(dotted("List", "prepend"), vec![ident("col"), ident("acc")]);
1501 let false_body = sp(Expr::TailCall(Box::new(TailCallData {
1504 target: "build".to_string(),
1505 args: vec![
1506 prepend,
1507 sp(Expr::BinOp(
1508 BinOp::Add,
1509 Box::new(ident("col")),
1510 Box::new(sp(Expr::Literal(Literal::Int(1)))),
1511 )),
1512 ],
1513 })));
1514 let match_expr = sp(Expr::Match {
1515 subject: Box::new(sp(Expr::BinOp(
1516 BinOp::Gte,
1517 Box::new(ident("col")),
1518 Box::new(sp(Expr::Literal(Literal::Int(10)))),
1519 ))),
1520 arms: vec![
1521 MatchArm {
1522 pattern: Pattern::Literal(Literal::Bool(true)),
1523 body: Box::new(true_body),
1524 },
1525 MatchArm {
1526 pattern: Pattern::Literal(Literal::Bool(false)),
1527 body: Box::new(false_body),
1528 },
1529 ],
1530 });
1531 let fd = FnDef {
1532 name: "build".to_string(),
1533 line: 1,
1534 params: vec![
1535 ("acc".to_string(), "List<Int>".to_string()),
1536 ("col".to_string(), "Int".to_string()),
1537 ],
1538 return_type: "List<Int>".to_string(),
1539 effects: vec![],
1540 desc: None,
1541 body: Arc::new(FnBody::Block(vec![Stmt::Expr(match_expr)])),
1542 resolution: None,
1543 };
1544 let info = compute_buffer_build_sinks(&[&fd]);
1545 let shape = info.get("build").expect("expected match");
1546 assert_eq!(shape.acc_param_idx, 0);
1547 assert_eq!(shape.acc_param_name, "acc");
1548 }
1549
1550 #[test]
1551 fn rejects_loose_prepend_in_non_acc_position() {
1552 let mut fd = canonical_builder("build");
1557 {
1562 let body = std::sync::Arc::make_mut(&mut fd.body);
1563 let FnBody::Block(stmts) = body;
1564 if let Stmt::Expr(spanned) = &mut stmts[0]
1565 && let Expr::Match { arms, .. } = &mut spanned.node
1566 {
1567 for arm in arms.iter_mut() {
1568 if matches!(arm.pattern, Pattern::Literal(Literal::Bool(false)))
1569 && let Expr::TailCall(data) = &mut arm.body.node
1570 {
1571 data.args.reverse();
1572 }
1573 }
1574 }
1575 }
1576 let info = compute_buffer_build_sinks(&[&fd]);
1577 assert!(
1578 info.get("build").is_none(),
1579 "loose-prepend (prepend not at acc-position) must not be detected"
1580 );
1581 }
1582
1583 #[test]
1584 fn skips_synth_when_no_rewriteable_call_site() {
1585 let sink = canonical_builder("build");
1592 let caller = FnDef {
1594 name: "use_build".to_string(),
1595 line: 2,
1596 params: vec![],
1597 return_type: "List<Int>".to_string(),
1598 effects: vec![],
1599 desc: None,
1600 body: Arc::new(FnBody::Block(vec![Stmt::Expr(call(
1601 ident_expr("build"),
1602 vec![sp(Expr::Literal(Literal::Int(0))), sp(Expr::List(vec![]))],
1603 ))])),
1604 resolution: None,
1605 };
1606 let mut items = vec![
1607 crate::ast::TopLevel::FnDef(sink),
1608 crate::ast::TopLevel::FnDef(caller),
1609 ];
1610 let initial_count = items.len();
1611 let report = run_buffer_build_pass(&mut items);
1612 assert_eq!(report.rewrites, 0, "no fusion sites — no rewriteable call");
1613 assert_eq!(
1614 report.synthesized.len(),
1615 0,
1616 "no synth — nothing to fuse against"
1617 );
1618 assert_eq!(items.len(), initial_count, "no buffered variant appended");
1619 }
1620
1621 #[test]
1622 fn external_reverse_pattern_round_trips() {
1623 let nil_body = ident("acc");
1627 let prepend = call(dotted("List", "prepend"), vec![ident("h"), ident("acc")]);
1628 let cons_body = sp(Expr::TailCall(Box::new(TailCallData {
1629 target: "build".to_string(),
1630 args: vec![ident("t"), prepend],
1631 })));
1632 let match_expr = sp(Expr::Match {
1633 subject: Box::new(ident("xs")),
1634 arms: vec![
1635 MatchArm {
1636 pattern: Pattern::EmptyList,
1637 body: Box::new(nil_body),
1638 },
1639 MatchArm {
1640 pattern: Pattern::Cons("h".to_string(), "t".to_string()),
1641 body: Box::new(cons_body),
1642 },
1643 ],
1644 });
1645 let sink = FnDef {
1646 name: "build".to_string(),
1647 line: 1,
1648 params: vec![
1649 ("xs".to_string(), "List<Int>".to_string()),
1650 ("acc".to_string(), "List<String>".to_string()),
1651 ],
1652 return_type: "List<String>".to_string(),
1653 effects: vec![],
1654 desc: None,
1655 body: Arc::new(FnBody::Block(vec![Stmt::Expr(match_expr)])),
1656 resolution: None,
1657 };
1658 let info = compute_buffer_build_sinks(&[&sink]);
1659 let shape = info
1660 .get("build")
1661 .expect("external-reverse sink should be detected");
1662 assert_eq!(shape.kind, BufferBuildKind::ExternalReverse);
1663 assert_eq!(shape.acc_param_idx, 1);
1664
1665 let join_call = call(
1667 dotted("String", "join"),
1668 vec![
1669 call(
1670 dotted("List", "reverse"),
1671 vec![call(
1672 ident_expr("build"),
1673 vec![ident("xs"), sp(Expr::List(vec![]))],
1674 )],
1675 ),
1676 sp(Expr::Literal(Literal::Str("\n".to_string()))),
1677 ],
1678 );
1679 let caller = FnDef {
1680 name: "render".to_string(),
1681 line: 2,
1682 params: vec![("xs".to_string(), "List<Int>".to_string())],
1683 return_type: "String".to_string(),
1684 effects: vec![],
1685 desc: None,
1686 body: Arc::new(FnBody::Block(vec![Stmt::Expr(join_call)])),
1687 resolution: None,
1688 };
1689
1690 let mut items = vec![
1691 crate::ast::TopLevel::FnDef(sink),
1692 crate::ast::TopLevel::FnDef(caller),
1693 ];
1694 let report = run_buffer_build_pass(&mut items);
1695 assert_eq!(
1696 report.rewrites, 1,
1697 "external-reverse pattern should be one fusion site"
1698 );
1699 assert_eq!(
1700 report.synthesized.len(),
1701 1,
1702 "exactly one buffered variant for the used sink"
1703 );
1704
1705 let synth_present = items.iter().any(|it| match it {
1707 crate::ast::TopLevel::FnDef(fd) => fd.name == "build__buffered",
1708 _ => false,
1709 });
1710 assert!(synth_present, "build__buffered must be appended");
1711 }
1712
1713 fn ident_expr(name: &str) -> Spanned<Expr> {
1714 sp(Expr::Ident(name.to_string()))
1715 }
1716}