1use std::collections::HashMap;
26use std::sync::Arc;
27
28use crate::ast::{Expr, FnBody, FnDef, Literal, MatchArm, Pattern, Spanned, Stmt, TailCallData};
29
/// Which of the two recognised list-building shapes a function matched.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BufferBuildKind {
    /// Boolean match whose `true` arm returns `List.reverse(acc)` and whose
    /// `false` arm tail-calls the function with `List.prepend(_, acc)`.
    InternalReverse,
    /// Empty-list / cons match whose empty arm returns the bare accumulator;
    /// a fusion site is only recognised when the caller wraps the call in
    /// `List.reverse(...)`.
    ExternalReverse,
}
56
/// Description of a function that was recognised as a buffer-build sink.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BufferBuildShape {
    /// Position of the accumulator in the function's parameter list.
    pub acc_param_idx: usize,
    /// Name of the accumulator parameter.
    pub acc_param_name: String,
    /// Which recognised shape the function body has.
    pub kind: BufferBuildKind,
}
73
/// The operation that consumes a sink's list at a fusion site.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ConsumerKind {
    /// The list is passed as the first argument of `String.join(list, sep)`.
    StringJoin,
}
90
/// One call site where a sink's result feeds directly into a consumer.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FusionSite {
    /// Name of the function whose body contains the site.
    pub enclosing_fn: String,
    /// Line recorded for the matching expression.
    pub line: usize,
    /// Name of the sink function being called.
    pub sink_fn: String,
    /// How the sink's result is consumed.
    pub consumer: ConsumerKind,
}
105
106pub fn compute_buffer_build_sinks(fns: &[&FnDef]) -> HashMap<String, BufferBuildShape> {
110 let mut out = HashMap::new();
111 for fd in fns {
112 if let Some(shape) = match_buffer_build_shape(fd) {
113 out.insert(fd.name.clone(), shape);
114 }
115 }
116 out
117}
118
119pub fn find_fusion_sites(
125 fns: &[&FnDef],
126 sinks: &HashMap<String, BufferBuildShape>,
127) -> Vec<FusionSite> {
128 let mut out = Vec::new();
129 for fd in fns {
130 for stmt in fd.body.stmts() {
131 match stmt {
132 Stmt::Binding(_, _, expr) | Stmt::Expr(expr) => {
133 walk_expr_for_fusion_sites(&expr.node, expr.line, &fd.name, sinks, &mut out);
134 }
135 }
136 }
137 }
138 out
139}
140
141fn walk_expr_for_fusion_sites(
145 expr: &Expr,
146 expr_line: usize,
147 enclosing_fn: &str,
148 sinks: &HashMap<String, BufferBuildShape>,
149 out: &mut Vec<FusionSite>,
150) {
151 if let Some(inner_name) = match_string_join_fusion_site(expr, sinks) {
152 out.push(FusionSite {
153 enclosing_fn: enclosing_fn.to_string(),
154 line: expr_line,
155 sink_fn: inner_name,
156 consumer: ConsumerKind::StringJoin,
157 });
158 }
159 visit_subexprs(expr, expr_line, enclosing_fn, sinks, out);
163}
164
165fn visit_subexprs(
169 expr: &Expr,
170 fallback_line: usize,
171 enclosing_fn: &str,
172 sinks: &HashMap<String, BufferBuildShape>,
173 out: &mut Vec<FusionSite>,
174) {
175 let line_of = |s: &crate::ast::Spanned<Expr>| {
176 if s.line > 0 { s.line } else { fallback_line }
177 };
178 match expr {
179 Expr::Literal(_) | Expr::Ident(_) | Expr::Resolved { .. } | Expr::Constructor(_, None) => {}
180 Expr::Constructor(_, Some(inner)) | Expr::Attr(inner, _) | Expr::ErrorProp(inner) => {
181 walk_expr_for_fusion_sites(&inner.node, line_of(inner), enclosing_fn, sinks, out);
182 }
183 Expr::FnCall(callee, args) => {
184 walk_expr_for_fusion_sites(&callee.node, line_of(callee), enclosing_fn, sinks, out);
185 for a in args {
186 walk_expr_for_fusion_sites(&a.node, line_of(a), enclosing_fn, sinks, out);
187 }
188 }
189 Expr::TailCall(data) => {
190 for a in &data.args {
191 walk_expr_for_fusion_sites(&a.node, line_of(a), enclosing_fn, sinks, out);
192 }
193 }
194 Expr::BinOp(_, l, r) => {
195 walk_expr_for_fusion_sites(&l.node, line_of(l), enclosing_fn, sinks, out);
196 walk_expr_for_fusion_sites(&r.node, line_of(r), enclosing_fn, sinks, out);
197 }
198 Expr::Match { subject, arms } => {
199 walk_expr_for_fusion_sites(&subject.node, line_of(subject), enclosing_fn, sinks, out);
200 for arm in arms {
201 walk_expr_for_fusion_sites(
202 &arm.body.node,
203 line_of(&arm.body),
204 enclosing_fn,
205 sinks,
206 out,
207 );
208 }
209 }
210 Expr::List(items) | Expr::Tuple(items) | Expr::IndependentProduct(items, _) => {
211 for it in items {
212 walk_expr_for_fusion_sites(&it.node, line_of(it), enclosing_fn, sinks, out);
213 }
214 }
215 Expr::MapLiteral(entries) => {
216 for (k, v) in entries {
217 walk_expr_for_fusion_sites(&k.node, line_of(k), enclosing_fn, sinks, out);
218 walk_expr_for_fusion_sites(&v.node, line_of(v), enclosing_fn, sinks, out);
219 }
220 }
221 Expr::RecordCreate { fields, .. } => {
222 for (_, v) in fields {
223 walk_expr_for_fusion_sites(&v.node, line_of(v), enclosing_fn, sinks, out);
224 }
225 }
226 Expr::RecordUpdate { base, updates, .. } => {
227 walk_expr_for_fusion_sites(&base.node, line_of(base), enclosing_fn, sinks, out);
228 for (_, v) in updates {
229 walk_expr_for_fusion_sites(&v.node, line_of(v), enclosing_fn, sinks, out);
230 }
231 }
232 Expr::InterpolatedStr(parts) => {
233 for part in parts {
234 if let crate::ast::StrPart::Parsed(inner) = part {
235 walk_expr_for_fusion_sites(
236 &inner.node,
237 line_of(inner),
238 enclosing_fn,
239 sinks,
240 out,
241 );
242 }
243 }
244 }
245 }
246}
247
248fn match_buffer_build_shape(fd: &FnDef) -> Option<BufferBuildShape> {
250 let (acc_idx, acc_name) = fd
263 .params
264 .iter()
265 .enumerate()
266 .rfind(|(_, (_, ty))| is_list_type_str(ty))
267 .map(|(i, (name, _))| (i, name.clone()))?;
268
269 let match_expr = single_match_body(&fd.body)?;
271 let (subject_expr, arms) = match match_expr {
272 Expr::Match { subject, arms } => (subject, arms),
273 _ => return None,
274 };
275
276 if let Some((true_body, false_body)) = pair_bool_arms(arms) {
278 let _ = subject_expr;
279 if is_list_reverse_of(true_body, &acc_name)
280 && is_self_tail_with_prepend_acc(false_body, &fd.name, acc_idx, &acc_name)
281 {
282 return Some(BufferBuildShape {
283 acc_param_idx: acc_idx,
284 acc_param_name: acc_name,
285 kind: BufferBuildKind::InternalReverse,
286 });
287 }
288 }
289
290 if let Some((nil_body, cons_body)) = pair_nil_cons_arms(arms)
295 && is_ident_named(nil_body, &acc_name)
296 && is_self_tail_with_prepend_acc(cons_body, &fd.name, acc_idx, &acc_name)
297 {
298 return Some(BufferBuildShape {
299 acc_param_idx: acc_idx,
300 acc_param_name: acc_name,
301 kind: BufferBuildKind::ExternalReverse,
302 });
303 }
304
305 None
306}
307
308fn pair_nil_cons_arms(arms: &[MatchArm]) -> Option<(&Expr, &Expr)> {
313 if arms.len() != 2 {
314 return None;
315 }
316 let mut nil_body: Option<&Expr> = None;
317 let mut cons_body: Option<&Expr> = None;
318 for arm in arms {
319 match &arm.pattern {
320 Pattern::EmptyList => nil_body = Some(&arm.body.node),
321 Pattern::Cons(_, _) => cons_body = Some(&arm.body.node),
322 _ => return None,
323 }
324 }
325 match (nil_body, cons_body) {
326 (Some(n), Some(c)) => Some((n, c)),
327 _ => None,
328 }
329}
330
331fn is_ident_named(expr: &Expr, name: &str) -> bool {
333 matches!(expr, Expr::Ident(n) if n == name)
334}
335
336fn match_string_join_fusion_site(
362 expr: &Expr,
363 sinks: &HashMap<String, BufferBuildShape>,
364) -> Option<String> {
365 let Expr::FnCall(callee, args) = expr else {
366 return None;
367 };
368 if !is_dotted_ident(&callee.node, "String", "join") || args.len() != 2 {
369 return None;
370 }
371 let consumer_arg = &args[0].node;
372
373 let (inner_call_expr, saw_external_reverse) = match consumer_arg {
375 Expr::FnCall(rev_callee, rev_args)
376 if is_dotted_ident(&rev_callee.node, "List", "reverse") && rev_args.len() == 1 =>
377 {
378 (&rev_args[0].node, true)
379 }
380 other => (other, false),
381 };
382
383 let Expr::FnCall(inner_callee, inner_args) = inner_call_expr else {
384 return None;
385 };
386 let Expr::Ident(name) = &inner_callee.node else {
387 return None;
388 };
389 let shape = sinks.get(name)?;
390
391 let kinds_align = matches!(
392 (saw_external_reverse, &shape.kind),
393 (false, BufferBuildKind::InternalReverse) | (true, BufferBuildKind::ExternalReverse)
394 );
395 if !kinds_align {
396 return None;
397 }
398
399 let acc_arg = inner_args.get(shape.acc_param_idx)?;
400 if !matches!(&acc_arg.node, Expr::List(items) if items.is_empty()) {
401 return None;
402 }
403
404 Some(name.clone())
405}
406
/// True when the trimmed type string starts with `List<` and ends with `>`.
///
/// This is a purely textual check; the element type is not validated.
fn is_list_type_str(ty: &str) -> bool {
    ty.trim()
        .strip_prefix("List<")
        .is_some_and(|rest| rest.ends_with('>'))
}
412
413fn single_match_body(body: &FnBody) -> Option<&Expr> {
417 let stmts = body.stmts();
418 if stmts.len() != 1 {
419 return None;
420 }
421 match &stmts[0] {
422 Stmt::Expr(spanned) => match &spanned.node {
423 Expr::Match { .. } => Some(&spanned.node),
424 _ => None,
425 },
426 Stmt::Binding(_, _, _) => None,
427 }
428}
429
430fn pair_bool_arms(arms: &[MatchArm]) -> Option<(&Expr, &Expr)> {
434 if arms.len() != 2 {
435 return None;
436 }
437 let mut t = None;
438 let mut f = None;
439 for arm in arms {
440 match &arm.pattern {
441 Pattern::Literal(Literal::Bool(true)) => {
442 if t.is_some() {
443 return None;
444 }
445 t = Some(&arm.body.node);
446 }
447 Pattern::Literal(Literal::Bool(false)) => {
448 if f.is_some() {
449 return None;
450 }
451 f = Some(&arm.body.node);
452 }
453 _ => return None,
454 }
455 }
456 Some((t?, f?))
457}
458
459fn is_list_reverse_of(expr: &Expr, acc_name: &str) -> bool {
461 let (callee, args) = match expr {
462 Expr::FnCall(c, a) => (c, a),
463 _ => return false,
464 };
465 if !is_dotted_ident(&callee.node, "List", "reverse") {
466 return false;
467 }
468 if args.len() != 1 {
469 return false;
470 }
471 matches!(&args[0].node, Expr::Ident(name) if name == acc_name)
472}
473
474fn is_self_tail_with_prepend_acc(
480 expr: &Expr,
481 self_name: &str,
482 acc_idx: usize,
483 acc_name: &str,
484) -> bool {
485 let data = match expr {
486 Expr::TailCall(data) => data,
487 _ => return false,
488 };
489 if data.target != self_name {
490 return false;
491 }
492 let acc_arg = match data.args.get(acc_idx) {
501 Some(a) => a,
502 None => return false,
503 };
504 is_list_prepend_to_acc(&acc_arg.node, acc_name)
505}
506
507fn is_list_prepend_to_acc(expr: &Expr, acc_name: &str) -> bool {
509 let (callee, args) = match expr {
510 Expr::FnCall(c, a) => (c, a),
511 _ => return false,
512 };
513 if !is_dotted_ident(&callee.node, "List", "prepend") {
514 return false;
515 }
516 if args.len() != 2 {
517 return false;
518 }
519 matches!(&args[1].node, Expr::Ident(name) if name == acc_name)
520}
521
522fn is_dotted_ident(expr: &Expr, module: &str, member: &str) -> bool {
525 let (base, attr) = match expr {
526 Expr::Attr(b, a) => (b, a),
527 _ => return false,
528 };
529 if attr != member {
530 return false;
531 }
532 matches!(&base.node, Expr::Ident(name) if name == module)
533}
534
535pub fn synthesize_buffered_variants(
573 fns: &[&FnDef],
574 sinks: &HashMap<String, BufferBuildShape>,
575) -> Vec<FnDef> {
576 let mut out = Vec::new();
577 for fd in fns {
578 if let Some(shape) = sinks.get(&fd.name)
579 && let Some(buffered) = build_buffered_variant(fd, shape)
580 {
581 out.push(buffered);
582 }
583 }
584 out
585}
586
587fn sp_at(line: usize, expr: Expr) -> Spanned<Expr> {
592 Spanned { node: expr, line }
593}
594
595fn intrinsic_call(line: usize, name: &str, args: Vec<Spanned<Expr>>) -> Spanned<Expr> {
600 let callee = sp_at(line, Expr::Ident(name.to_string()));
601 sp_at(line, Expr::FnCall(Box::new(callee), args))
602}
603
604pub fn run_buffer_build_pass(items: &mut Vec<crate::ast::TopLevel>) -> (usize, usize) {
616 let fn_refs: Vec<&FnDef> = items
617 .iter()
618 .filter_map(|it| match it {
619 crate::ast::TopLevel::FnDef(fd) => Some(fd),
620 _ => None,
621 })
622 .collect();
623 let all_sinks = compute_buffer_build_sinks(&fn_refs);
624 if all_sinks.is_empty() {
625 return (0, 0);
626 }
627 let sites = find_fusion_sites(&fn_refs, &all_sinks);
628
629 let mut used_sinks: HashMap<String, BufferBuildShape> = HashMap::new();
637 for site in &sites {
638 if let Some(shape) = all_sinks.get(&site.sink_fn) {
639 used_sinks.insert(site.sink_fn.clone(), shape.clone());
640 }
641 }
642 let synthesized = synthesize_buffered_variants(&fn_refs, &used_sinks);
643 let sinks = used_sinks;
644 drop(fn_refs);
645
646 let mut fn_defs_owned: Vec<&mut FnDef> = items
647 .iter_mut()
648 .filter_map(|it| match it {
649 crate::ast::TopLevel::FnDef(fd) => Some(fd),
650 _ => None,
651 })
652 .collect();
653 for fd in fn_defs_owned.iter_mut() {
657 rewrite_one_fn(fd, &sinks);
658 }
659
660 items.reserve(synthesized.len());
661 for fd in synthesized.iter() {
662 items.push(crate::ast::TopLevel::FnDef(fd.clone()));
663 }
664
665 (sites.len(), synthesized.len())
666}
667
668fn rewrite_one_fn(fd: &mut FnDef, sinks: &HashMap<String, BufferBuildShape>) {
672 let body_arc = std::sync::Arc::make_mut(&mut fd.body);
673 let FnBody::Block(stmts) = body_arc;
674 for stmt in stmts.iter_mut() {
675 match stmt {
676 Stmt::Binding(_, _, expr) | Stmt::Expr(expr) => {
677 rewrite_expr_in_place(expr, sinks);
678 }
679 }
680 }
681}
682
683pub fn rewrite_fusion_sites(fn_defs: &mut [FnDef], sinks: &HashMap<String, BufferBuildShape>) {
695 if sinks.is_empty() {
696 return;
697 }
698 for fd in fn_defs.iter_mut() {
699 let body_arc = std::sync::Arc::make_mut(&mut fd.body);
700 let FnBody::Block(stmts) = body_arc;
701 for stmt in stmts.iter_mut() {
702 match stmt {
703 Stmt::Binding(_, _, expr) | Stmt::Expr(expr) => {
704 rewrite_expr_in_place(expr, sinks);
705 }
706 }
707 }
708 }
709}
710
711fn rewrite_expr_in_place(expr: &mut Spanned<Expr>, sinks: &HashMap<String, BufferBuildShape>) {
716 if let Some(replacement) = try_rewrite_fusion_site(expr, sinks) {
717 *expr = replacement;
718 descend_into_subexprs(expr, sinks);
722 return;
723 }
724 descend_into_subexprs(expr, sinks);
725}
726
727fn descend_into_subexprs(expr: &mut Spanned<Expr>, sinks: &HashMap<String, BufferBuildShape>) {
731 match &mut expr.node {
732 Expr::Literal(_) | Expr::Ident(_) | Expr::Resolved { .. } | Expr::Constructor(_, None) => {}
733 Expr::Constructor(_, Some(inner)) | Expr::Attr(inner, _) | Expr::ErrorProp(inner) => {
734 rewrite_expr_in_place(inner, sinks);
735 }
736 Expr::FnCall(callee, args) => {
737 rewrite_expr_in_place(callee, sinks);
738 for a in args.iter_mut() {
739 rewrite_expr_in_place(a, sinks);
740 }
741 }
742 Expr::TailCall(data) => {
743 for a in data.args.iter_mut() {
744 rewrite_expr_in_place(a, sinks);
745 }
746 }
747 Expr::BinOp(_, l, r) => {
748 rewrite_expr_in_place(l, sinks);
749 rewrite_expr_in_place(r, sinks);
750 }
751 Expr::Match { subject, arms } => {
752 rewrite_expr_in_place(subject, sinks);
753 for arm in arms.iter_mut() {
754 rewrite_expr_in_place(&mut arm.body, sinks);
755 }
756 }
757 Expr::List(items) | Expr::Tuple(items) | Expr::IndependentProduct(items, _) => {
758 for it in items.iter_mut() {
759 rewrite_expr_in_place(it, sinks);
760 }
761 }
762 Expr::MapLiteral(entries) => {
763 for (k, v) in entries.iter_mut() {
764 rewrite_expr_in_place(k, sinks);
765 rewrite_expr_in_place(v, sinks);
766 }
767 }
768 Expr::RecordCreate { fields, .. } => {
769 for (_, v) in fields.iter_mut() {
770 rewrite_expr_in_place(v, sinks);
771 }
772 }
773 Expr::RecordUpdate { base, updates, .. } => {
774 rewrite_expr_in_place(base, sinks);
775 for (_, v) in updates.iter_mut() {
776 rewrite_expr_in_place(v, sinks);
777 }
778 }
779 Expr::InterpolatedStr(parts) => {
780 for part in parts.iter_mut() {
781 if let crate::ast::StrPart::Parsed(inner) = part {
782 rewrite_expr_in_place(inner, sinks);
783 }
784 }
785 }
786 }
787}
788
789fn try_rewrite_fusion_site(
793 expr: &Spanned<Expr>,
794 sinks: &HashMap<String, BufferBuildShape>,
795) -> Option<Spanned<Expr>> {
796 let line = expr.line;
797
798 let sink_name = match_string_join_fusion_site(&expr.node, sinks)?;
801 let shape = sinks.get(&sink_name)?;
802
803 let outer_args = match &expr.node {
807 Expr::FnCall(_, a) => a,
808 _ => return None,
809 };
810 let consumer_arg = &outer_args[0].node;
811 let inner_call_expr = if let Expr::FnCall(rev_callee, rev_args) = consumer_arg
812 && is_dotted_ident(&rev_callee.node, "List", "reverse")
813 && rev_args.len() == 1
814 {
815 &rev_args[0].node
816 } else {
817 consumer_arg
818 };
819 let inner_args = match inner_call_expr {
820 Expr::FnCall(_, a) => a,
821 _ => return None,
822 };
823
824 let sep_expr = outer_args[1].clone();
833 let buf_new = intrinsic_call(
834 line,
835 "__buf_new",
836 vec![sp_at(line, Expr::Literal(Literal::Int(8192)))],
837 );
838 let mut buffered_args: Vec<Spanned<Expr>> = inner_args
839 .iter()
840 .enumerate()
841 .filter_map(|(i, a)| (i != shape.acc_param_idx).then_some(a).cloned())
842 .collect();
843 buffered_args.push(buf_new);
844 buffered_args.push(sep_expr);
845 let buffered_call = sp_at(
846 line,
847 Expr::FnCall(
848 Box::new(sp_at(line, Expr::Ident(format!("{}__buffered", sink_name)))),
849 buffered_args,
850 ),
851 );
852 Some(intrinsic_call(line, "__buf_finalize", vec![buffered_call]))
853}
854
855fn build_buffered_variant(fd: &FnDef, shape: &BufferBuildShape) -> Option<FnDef> {
860 let stmts = fd.body.stmts();
867 if stmts.len() != 1 {
868 return None;
869 }
870 let outer_expr = match &stmts[0] {
871 Stmt::Expr(spanned) => spanned,
872 _ => return None,
873 };
874 let (subject_orig, arms_orig) = match &outer_expr.node {
875 Expr::Match { subject, arms } => (subject, arms),
876 _ => return None,
877 };
878 let recursive_body: &Spanned<Expr> = match shape.kind {
879 BufferBuildKind::InternalReverse => arms_orig
880 .iter()
881 .find(|a| matches!(a.pattern, Pattern::Literal(Literal::Bool(false))))
882 .map(|a| a.body.as_ref())?,
883 BufferBuildKind::ExternalReverse => arms_orig
884 .iter()
885 .find(|a| matches!(a.pattern, Pattern::Cons(_, _)))
886 .map(|a| a.body.as_ref())?,
887 };
888 let tail_data = match &recursive_body.node {
889 Expr::TailCall(data) => data,
890 _ => return None,
891 };
892
893 let acc_arg_orig = tail_data.args.get(shape.acc_param_idx)?;
896 let elem_expr = match &acc_arg_orig.node {
897 Expr::FnCall(callee, args) => {
898 if !is_dotted_ident(&callee.node, "List", "prepend") {
899 return None;
900 }
901 if args.len() != 2 {
902 return None;
903 }
904 match &args[1].node {
906 Expr::Ident(name) if name == &shape.acc_param_name => {}
907 _ => return None,
908 }
909 args[0].clone()
910 }
911 _ => return None,
912 };
913
914 let line = fd.line;
915 let buf_name = "__buf";
916 let sep_name = "__sep";
917 let buffered_target = format!("{}__buffered", fd.name);
918
919 let buf_ident = || sp_at(line, Expr::Ident(buf_name.to_string()));
928 let sep_ident = || sp_at(line, Expr::Ident(sep_name.to_string()));
929 let sep_then_buf = intrinsic_call(
930 line,
931 "__buf_append_sep_unless_first",
932 vec![buf_ident(), sep_ident()],
933 );
934 let final_buf = intrinsic_call(line, "__buf_append", vec![sep_then_buf, elem_expr]);
935
936 let mut new_args: Vec<Spanned<Expr>> = tail_data
939 .args
940 .iter()
941 .enumerate()
942 .map(|(i, a)| {
943 if i == shape.acc_param_idx {
944 final_buf.clone()
945 } else {
946 a.clone()
947 }
948 })
949 .collect();
950 new_args.push(sep_ident());
951
952 let new_recursive_body = sp_at(
953 line,
954 Expr::TailCall(Box::new(TailCallData {
955 target: buffered_target.clone(),
956 args: new_args,
957 })),
958 );
959
960 let new_arms = match shape.kind {
965 BufferBuildKind::InternalReverse => vec![
966 MatchArm {
967 pattern: Pattern::Literal(Literal::Bool(true)),
968 body: Box::new(buf_ident()),
969 },
970 MatchArm {
971 pattern: Pattern::Literal(Literal::Bool(false)),
972 body: Box::new(new_recursive_body),
973 },
974 ],
975 BufferBuildKind::ExternalReverse => {
976 let cons_pat = arms_orig
980 .iter()
981 .find_map(|a| match &a.pattern {
982 Pattern::Cons(h, t) => Some(Pattern::Cons(h.clone(), t.clone())),
983 _ => None,
984 })
985 .unwrap_or(Pattern::Cons("__head".to_string(), "__tail".to_string()));
986 vec![
987 MatchArm {
988 pattern: Pattern::EmptyList,
989 body: Box::new(buf_ident()),
990 },
991 MatchArm {
992 pattern: cons_pat,
993 body: Box::new(new_recursive_body),
994 },
995 ]
996 }
997 };
998
999 let new_match = sp_at(
1000 line,
1001 Expr::Match {
1002 subject: subject_orig.clone(),
1003 arms: new_arms,
1004 },
1005 );
1006
1007 let new_body = FnBody::Block(vec![Stmt::Expr(new_match)]);
1008
1009 let mut new_params: Vec<(String, String)> = fd
1011 .params
1012 .iter()
1013 .enumerate()
1014 .filter_map(|(i, p)| (i != shape.acc_param_idx).then_some(p).cloned())
1015 .collect();
1016 new_params.push((buf_name.to_string(), "Buffer".to_string()));
1017 new_params.push((sep_name.to_string(), "String".to_string()));
1018
1019 Some(FnDef {
1020 name: buffered_target,
1021 line,
1022 params: new_params,
1023 return_type: "Buffer".to_string(),
1024 effects: fd.effects.clone(),
1029 desc: Some(format!(
1030 "Synthesized buffered variant of `{}` for deforestation \
1031 lowering. Call sites that match `String.join({}(...), sep)` \
1032 are rewritten to alloc a buffer + call this variant + \
1033 finalize, skipping the intermediate List.",
1034 fd.name, fd.name
1035 )),
1036 body: Arc::new(new_body),
1037 resolution: None,
1038 })
1039}
1040
1041#[cfg(test)]
1042mod tests {
1043 use super::*;
1044 use crate::ast::{BinOp, FnBody, FnDef, Literal, Spanned, TailCallData};
1045 use std::sync::Arc;
1046
1047 fn sp<T>(value: T) -> Spanned<T> {
1048 Spanned {
1049 node: value,
1050 line: 1,
1051 }
1052 }
1053
1054 fn ident(name: &str) -> Spanned<Expr> {
1055 sp(Expr::Ident(name.to_string()))
1056 }
1057
1058 fn dotted(module: &str, member: &str) -> Spanned<Expr> {
1059 sp(Expr::Attr(Box::new(ident(module)), member.to_string()))
1060 }
1061
1062 fn call(callee: Spanned<Expr>, args: Vec<Spanned<Expr>>) -> Spanned<Expr> {
1063 sp(Expr::FnCall(Box::new(callee), args))
1064 }
1065
    /// Builds the AST of a builder `name(col: Int, acc: List<Int>) -> List<Int>`
    /// whose body is a single `match col >= 10`:
    ///
    /// * `true`  -> `List.reverse(acc)`
    /// * `false` -> tail call `name(col + 1, List.prepend(col, acc))`
    ///
    /// Tests mutate this fixture to produce near-miss shapes.
    fn canonical_builder(name: &str) -> FnDef {
        let true_body = call(dotted("List", "reverse"), vec![ident("acc")]);
        let prepend = call(dotted("List", "prepend"), vec![ident("col"), ident("acc")]);
        let false_body = sp(Expr::TailCall(Box::new(TailCallData {
            target: name.to_string(),
            args: vec![
                sp(Expr::BinOp(
                    BinOp::Add,
                    Box::new(ident("col")),
                    Box::new(sp(Expr::Literal(Literal::Int(1)))),
                )),
                prepend,
            ],
        })));
        let match_expr = sp(Expr::Match {
            subject: Box::new(sp(Expr::BinOp(
                BinOp::Gte,
                Box::new(ident("col")),
                Box::new(sp(Expr::Literal(Literal::Int(10)))),
            ))),
            arms: vec![
                MatchArm {
                    pattern: Pattern::Literal(Literal::Bool(true)),
                    body: Box::new(true_body),
                },
                MatchArm {
                    pattern: Pattern::Literal(Literal::Bool(false)),
                    body: Box::new(false_body),
                },
            ],
        });
        FnDef {
            name: name.to_string(),
            line: 1,
            params: vec![
                ("col".to_string(), "Int".to_string()),
                ("acc".to_string(), "List<Int>".to_string()),
            ],
            return_type: "List<Int>".to_string(),
            effects: vec![],
            desc: None,
            body: Arc::new(FnBody::Block(vec![Stmt::Expr(match_expr)])),
            resolution: None,
        }
    }
1114
1115 #[test]
1116 fn matches_canonical_buffer_build() {
1117 let fd = canonical_builder("build");
1118 let info = compute_buffer_build_sinks(&[&fd]);
1119 let shape = info.get("build").expect("expected match");
1120 assert_eq!(shape.acc_param_idx, 1);
1121 assert_eq!(shape.acc_param_name, "acc");
1122 }
1123
1124 #[test]
1125 fn rejects_fn_without_list_param() {
1126 let mut fd = canonical_builder("build");
1127 fd.params = vec![("col".to_string(), "Int".to_string())];
1129 let info = compute_buffer_build_sinks(&[&fd]);
1130 assert!(info.is_empty(), "fn without List param should not match");
1131 }
1132
    /// Replaces the `true` arm of the fixture with bare `acc`. A boolean match
    /// that does not return `List.reverse(acc)` must not be detected.
    #[test]
    fn rejects_when_true_arm_isnt_reverse() {
        let mut fd = canonical_builder("build");
        if let FnBody::Block(stmts) = Arc::make_mut(&mut fd.body) {
            if let Stmt::Expr(spanned) = &mut stmts[0] {
                if let Expr::Match { arms, .. } = &mut spanned.node {
                    // arms[0] is the `true` arm of the fixture.
                    arms[0].body = Box::new(ident("acc"));
                }
            }
        }
        let info = compute_buffer_build_sinks(&[&fd]);
        assert!(
            info.is_empty(),
            "fn returning bare acc instead of reverse should not match"
        );
    }
1150
    /// Renames the `List.prepend` callee in the recursive arm to `List.append`.
    /// Only `List.prepend` onto the accumulator is accepted.
    #[test]
    fn rejects_when_false_arm_uses_append_not_prepend() {
        let mut fd = canonical_builder("build");
        if let FnBody::Block(stmts) = Arc::make_mut(&mut fd.body) {
            if let Stmt::Expr(spanned) = &mut stmts[0] {
                if let Expr::Match { arms, .. } = &mut spanned.node {
                    // arms[1] is the `false` arm; args[1] is the accumulator argument.
                    let false_body = arms[1].body.as_mut();
                    if let Expr::TailCall(data) = &mut false_body.node {
                        if let Expr::FnCall(callee, _) = &mut data.args[1].node {
                            if let Expr::Attr(_, attr) = &mut callee.node {
                                *attr = "append".to_string();
                            }
                        }
                    }
                }
            }
        }
        let info = compute_buffer_build_sinks(&[&fd]);
        assert!(
            info.is_empty(),
            "fn using List.append instead of prepend should not match"
        );
    }
1175
    /// Retargets the tail call in the recursive arm at another function. The
    /// recursive arm must call the function itself.
    #[test]
    fn rejects_tail_call_to_different_fn() {
        let mut fd = canonical_builder("build");
        if let FnBody::Block(stmts) = Arc::make_mut(&mut fd.body) {
            if let Stmt::Expr(spanned) = &mut stmts[0] {
                if let Expr::Match { arms, .. } = &mut spanned.node {
                    // arms[1] is the `false` (recursive) arm of the fixture.
                    let false_body = arms[1].body.as_mut();
                    if let Expr::TailCall(data) = &mut false_body.node {
                        data.target = "someone_else".to_string();
                    }
                }
            }
        }
        let info = compute_buffer_build_sinks(&[&fd]);
        assert!(
            info.is_empty(),
            "fn whose recursive call targets a different name should not match"
        );
    }
1195
    /// Changes the first arm's pattern to the integer literal `0`. The arms are
    /// then neither a true/false pair nor an empty-list/cons pair, so the
    /// function must not be detected.
    #[test]
    fn rejects_match_with_non_bool_arms() {
        let mut fd = canonical_builder("build");
        if let FnBody::Block(stmts) = Arc::make_mut(&mut fd.body) {
            if let Stmt::Expr(spanned) = &mut stmts[0] {
                if let Expr::Match { arms, .. } = &mut spanned.node {
                    arms[0].pattern = Pattern::Literal(Literal::Int(0));
                }
            }
        }
        let info = compute_buffer_build_sinks(&[&fd]);
        assert!(
            info.is_empty(),
            "match on non-bool patterns should not be detected as buffer-build"
        );
    }
1212
    /// End-to-end detection: lexes and parses a real builder, runs the `tco`
    /// pipeline step, and checks that the resulting function is detected with
    /// `acc` at parameter index 1.
    #[test]
    fn detects_via_parser_after_tco() {
        let src = r#"
fn build(n: Int, acc: List<Int>) -> List<Int>
    match n <= 0
        true -> List.reverse(acc)
        false -> build(n - 1, List.prepend(n, acc))
"#;
        let mut lexer = crate::lexer::Lexer::new(src);
        let tokens = lexer.tokenize().expect("lex");
        let mut parser = crate::parser::Parser::new(tokens);
        let mut items = parser.parse().expect("parse");
        // Detection expects `Expr::TailCall` in the recursive arm, so `tco`
        // runs before it.
        crate::ir::pipeline::tco(&mut items);
        let fns: Vec<&FnDef> = items
            .iter()
            .filter_map(|it| match it {
                crate::ast::TopLevel::FnDef(fd) => Some(fd),
                _ => None,
            })
            .collect();
        let info = compute_buffer_build_sinks(&fns);
        let shape = info
            .get("build")
            .expect("expected end-to-end shape match for canonical builder");
        assert_eq!(shape.acc_param_idx, 1);
        assert_eq!(shape.acc_param_name, "acc");
    }
1244
    /// A `String.join(build(5, []), ",")` call in `main` is reported as exactly
    /// one fusion site, attributed to `main`, targeting `build`, with a
    /// non-zero line.
    #[test]
    fn finds_fusion_site_via_parser() {
        let src = r#"
fn build(n: Int, acc: List<Int>) -> List<Int>
    match n <= 0
        true -> List.reverse(acc)
        false -> build(n - 1, List.prepend(n, acc))

fn main() -> String
    String.join(build(5, []), ",")
"#;
        let mut lexer = crate::lexer::Lexer::new(src);
        let tokens = lexer.tokenize().expect("lex");
        let mut parser = crate::parser::Parser::new(tokens);
        let mut items = parser.parse().expect("parse");
        crate::ir::pipeline::tco(&mut items);
        let fns: Vec<&FnDef> = items
            .iter()
            .filter_map(|it| match it {
                crate::ast::TopLevel::FnDef(fd) => Some(fd),
                _ => None,
            })
            .collect();
        let sinks = compute_buffer_build_sinks(&fns);
        let sites = find_fusion_sites(&fns, &sinks);
        assert_eq!(sites.len(), 1, "expected one fusion site, got {sites:?}");
        let site = &sites[0];
        assert_eq!(site.enclosing_fn, "main");
        assert_eq!(site.sink_fn, "build");
        assert!(site.line > 0, "expected real line info, got 0");
    }
1278
    /// Calling the sink directly, without a surrounding `String.join`, must
    /// not produce a fusion site.
    #[test]
    fn ignores_call_when_not_wrapped_in_string_join() {
        let src = r#"
fn build(n: Int, acc: List<Int>) -> List<Int>
    match n <= 0
        true -> List.reverse(acc)
        false -> build(n - 1, List.prepend(n, acc))

fn main() -> List<Int>
    build(5, [])
"#;
        let mut lexer = crate::lexer::Lexer::new(src);
        let tokens = lexer.tokenize().expect("lex");
        let mut parser = crate::parser::Parser::new(tokens);
        let mut items = parser.parse().expect("parse");
        crate::ir::pipeline::tco(&mut items);
        let fns: Vec<&FnDef> = items
            .iter()
            .filter_map(|it| match it {
                crate::ast::TopLevel::FnDef(fd) => Some(fd),
                _ => None,
            })
            .collect();
        let sinks = compute_buffer_build_sinks(&fns);
        let sites = find_fusion_sites(&fns, &sinks);
        assert!(
            sites.is_empty(),
            "build called outside String.join must not be a fusion site, got {sites:?}"
        );
    }
1312
    /// Parser-driven counterpart of the bare-`acc` rejection: a boolean match
    /// whose `true` arm returns `acc` without `List.reverse` is not a sink.
    #[test]
    fn rejects_via_parser_when_true_arm_returns_bare_acc() {
        let src = r#"
fn build(n: Int, acc: List<Int>) -> List<Int>
    match n <= 0
        true -> acc
        false -> build(n - 1, List.prepend(n, acc))
"#;
        let mut lexer = crate::lexer::Lexer::new(src);
        let tokens = lexer.tokenize().expect("lex");
        let mut parser = crate::parser::Parser::new(tokens);
        let mut items = parser.parse().expect("parse");
        crate::ir::pipeline::tco(&mut items);
        let fns: Vec<&FnDef> = items
            .iter()
            .filter_map(|it| match it {
                crate::ast::TopLevel::FnDef(fd) => Some(fd),
                _ => None,
            })
            .collect();
        let info = compute_buffer_build_sinks(&fns);
        assert!(
            info.is_empty(),
            "fn returning bare acc must not be detected as a deforestation candidate"
        );
    }
1344
    /// Synthesizes the buffered variant of a parsed builder and checks its
    /// name, return type, parameter list, and the structure of both match arms.
    #[test]
    fn synthesizes_buffered_variant_from_real_builder() {
        let src = r#"
fn build(n: Int, acc: List<Int>) -> List<Int>
    match n <= 0
        true -> List.reverse(acc)
        false -> build(n - 1, List.prepend(n, acc))
"#;
        let mut lexer = crate::lexer::Lexer::new(src);
        let tokens = lexer.tokenize().expect("lex");
        let mut parser = crate::parser::Parser::new(tokens);
        let mut items = parser.parse().expect("parse");
        crate::ir::pipeline::tco(&mut items);
        let fns: Vec<&FnDef> = items
            .iter()
            .filter_map(|it| match it {
                crate::ast::TopLevel::FnDef(fd) => Some(fd),
                _ => None,
            })
            .collect();
        let sinks = compute_buffer_build_sinks(&fns);
        assert!(sinks.contains_key("build"));
        let synthesized = synthesize_buffered_variants(&fns, &sinks);
        assert_eq!(
            synthesized.len(),
            1,
            "expected exactly one synthesized variant"
        );
        let bf = &synthesized[0];

        // Signature: accumulator removed, buffer and separator appended.
        assert_eq!(bf.name, "build__buffered");
        assert_eq!(bf.return_type, "Buffer");
        let param_names: Vec<&str> = bf.params.iter().map(|(n, _)| n.as_str()).collect();
        let param_types: Vec<&str> = bf.params.iter().map(|(_, t)| t.as_str()).collect();
        assert_eq!(param_names, vec!["n", "__buf", "__sep"]);
        assert_eq!(param_types, vec!["Int", "Buffer", "String"]);

        // Body: still a single two-arm match.
        let stmts = bf.body.stmts();
        assert_eq!(stmts.len(), 1);
        let match_expr = match &stmts[0] {
            Stmt::Expr(s) => match &s.node {
                Expr::Match { subject: _, arms } => arms,
                _ => panic!("body root must be a match"),
            },
            _ => panic!("body root must be Stmt::Expr"),
        };
        assert_eq!(match_expr.len(), 2);

        // Base arm returns the buffer itself.
        let true_arm = match_expr
            .iter()
            .find(|a| matches!(a.pattern, Pattern::Literal(Literal::Bool(true))))
            .expect("true arm");
        match &true_arm.body.node {
            Expr::Ident(name) => assert_eq!(name, "__buf"),
            other => panic!("true arm should be Ident(__buf), got {other:?}"),
        }

        // Recursive arm tail-calls the variant with three arguments.
        let false_arm = match_expr
            .iter()
            .find(|a| matches!(a.pattern, Pattern::Literal(Literal::Bool(false))))
            .expect("false arm");
        let tail_data = match &false_arm.body.node {
            Expr::TailCall(d) => d,
            other => panic!("false arm should be TailCall, got {other:?}"),
        };
        assert_eq!(tail_data.target, "build__buffered");
        assert_eq!(tail_data.args.len(), 3);
        // Argument 1 is `__buf_append(__buf_append_sep_unless_first(..), n)`.
        let outer = match &tail_data.args[1].node {
            Expr::FnCall(callee, args) => {
                match &callee.node {
                    Expr::Ident(name) => assert_eq!(name, "__buf_append"),
                    _ => panic!("expected Ident callee"),
                }
                args
            }
            _ => panic!("expected outer __buf_append FnCall"),
        };
        assert_eq!(outer.len(), 2);
        match &outer[0].node {
            Expr::FnCall(callee, _) => match &callee.node {
                Expr::Ident(name) => assert_eq!(name, "__buf_append_sep_unless_first"),
                _ => panic!("expected Ident callee for inner intrinsic"),
            },
            _ => panic!("expected inner __buf_append_sep_unless_first FnCall"),
        }
        match &outer[1].node {
            Expr::Ident(name) => assert_eq!(name, "n"),
            _ => panic!("expected `n` ident as elem"),
        }
        // Argument 2 forwards the separator unchanged.
        match &tail_data.args[2].node {
            Expr::Ident(name) => assert_eq!(name, "__sep"),
            _ => panic!("expected __sep ident as last arg"),
        }
    }
1456
/// Hand-builds a `build(acc, col)` function whose accumulator is the first
/// parameter and expects `compute_buffer_build_sinks` to report it with
/// `acc_param_idx == 0` and `acc_param_name == "acc"`.
#[test]
fn detects_acc_param_at_arbitrary_index() {
    // Local shorthands for the repetitive pieces of the hand-built AST.
    let int_lit = |value| sp(Expr::Literal(Literal::Int(value)));
    let param = |name: &str, ty: &str| (name.to_owned(), ty.to_owned());

    // `true` arm: `List.reverse(acc)`.
    let reversed_acc = call(dotted("List", "reverse"), vec![ident("acc")]);

    // `false` arm: tail call `build(List.prepend(col, acc), col + 1)`.
    // The prepend occupies argument slot 0, the same slot as the `acc` param.
    let grown_acc = call(dotted("List", "prepend"), vec![ident("col"), ident("acc")]);
    let next_col = sp(Expr::BinOp(
        BinOp::Add,
        Box::new(ident("col")),
        Box::new(int_lit(1)),
    ));
    let recurse = sp(Expr::TailCall(Box::new(TailCallData {
        target: String::from("build"),
        args: vec![grown_acc, next_col],
    })));

    // Match subject: `col >= 10`.
    let done = sp(Expr::BinOp(
        BinOp::Gte,
        Box::new(ident("col")),
        Box::new(int_lit(10)),
    ));
    let arms = vec![
        MatchArm {
            pattern: Pattern::Literal(Literal::Bool(true)),
            body: Box::new(reversed_acc),
        },
        MatchArm {
            pattern: Pattern::Literal(Literal::Bool(false)),
            body: Box::new(recurse),
        },
    ];
    let root = sp(Expr::Match {
        subject: Box::new(done),
        arms,
    });

    let fd = FnDef {
        name: String::from("build"),
        line: 1,
        params: vec![param("acc", "List<Int>"), param("col", "Int")],
        return_type: String::from("List<Int>"),
        effects: vec![],
        desc: None,
        body: Arc::new(FnBody::Block(vec![Stmt::Expr(root)])),
        resolution: None,
    };

    let sinks = compute_buffer_build_sinks(&[&fd]);
    let shape = sinks.get("build").expect("expected match");
    assert_eq!(shape.acc_param_idx, 0);
    assert_eq!(shape.acc_param_name, "acc");
}
1516
1517 #[test]
1518 fn rejects_loose_prepend_in_non_acc_position() {
1519 let mut fd = canonical_builder("build");
1524 {
1529 let body = std::sync::Arc::make_mut(&mut fd.body);
1530 let FnBody::Block(stmts) = body;
1531 if let Stmt::Expr(spanned) = &mut stmts[0]
1532 && let Expr::Match { arms, .. } = &mut spanned.node
1533 {
1534 for arm in arms.iter_mut() {
1535 if matches!(arm.pattern, Pattern::Literal(Literal::Bool(false)))
1536 && let Expr::TailCall(data) = &mut arm.body.node
1537 {
1538 data.args.reverse();
1539 }
1540 }
1541 }
1542 }
1543 let info = compute_buffer_build_sinks(&[&fd]);
1544 assert!(
1545 info.get("build").is_none(),
1546 "loose-prepend (prepend not at acc-position) must not be detected"
1547 );
1548 }
1549
/// Runs `run_buffer_build_pass` over a sink plus a caller that invokes the
/// sink directly as `build(0, [])`, with no wrapping call around it, and
/// expects zero fusion sites, zero synthesized functions, and an unchanged
/// item count.
#[test]
fn skips_synth_when_no_rewriteable_call_site() {
    let sink = canonical_builder("build");

    // The caller's whole body is the bare call `build(0, [])`.
    let bare_call = call(
        ident_expr("build"),
        vec![sp(Expr::Literal(Literal::Int(0))), sp(Expr::List(vec![]))],
    );
    let caller = FnDef {
        name: String::from("use_build"),
        line: 2,
        params: vec![],
        return_type: String::from("List<Int>"),
        effects: vec![],
        desc: None,
        body: Arc::new(FnBody::Block(vec![Stmt::Expr(bare_call)])),
        resolution: None,
    };

    let mut items = vec![
        crate::ast::TopLevel::FnDef(sink),
        crate::ast::TopLevel::FnDef(caller),
    ];
    let count_before = items.len();

    let (sites, synth) = run_buffer_build_pass(&mut items);

    assert_eq!(sites, 0, "no fusion sites — no rewriteable call");
    assert_eq!(synth, 0, "no synth — nothing to fuse against");
    assert_eq!(items.len(), count_before, "no buffered variant appended");
}
1583
/// End-to-end check of the external-reverse shape.
///
/// The sink `build(xs, acc)` returns `acc` unreversed on the empty list and
/// tail-calls `build(t, List.prepend(h, acc))` on a cons cell. The caller
/// wraps it as `String.join(List.reverse(build(xs, [])), "\n")`. The test
/// expects the sink to be classified as `ExternalReverse` with the
/// accumulator at index 1, and the pass to report one fusion site, one
/// synthesized function, and a `build__buffered` item in the output.
#[test]
fn external_reverse_pattern_round_trips() {
    let param = |name: &str, ty: &str| (name.to_owned(), ty.to_owned());

    // Empty-list arm: hand back the accumulator as-is.
    let on_empty = ident("acc");

    // Cons arm: `build(t, List.prepend(h, acc))` as a tail call.
    let grown_acc = call(dotted("List", "prepend"), vec![ident("h"), ident("acc")]);
    let on_cons = sp(Expr::TailCall(Box::new(TailCallData {
        target: String::from("build"),
        args: vec![ident("t"), grown_acc],
    })));

    let arms = vec![
        MatchArm {
            pattern: Pattern::EmptyList,
            body: Box::new(on_empty),
        },
        MatchArm {
            pattern: Pattern::Cons(String::from("h"), String::from("t")),
            body: Box::new(on_cons),
        },
    ];
    let sink_root = sp(Expr::Match {
        subject: Box::new(ident("xs")),
        arms,
    });
    let sink = FnDef {
        name: String::from("build"),
        line: 1,
        params: vec![param("xs", "List<Int>"), param("acc", "List<String>")],
        return_type: String::from("List<String>"),
        effects: vec![],
        desc: None,
        body: Arc::new(FnBody::Block(vec![Stmt::Expr(sink_root)])),
        resolution: None,
    };

    // Shape detection on the sink alone.
    let sinks = compute_buffer_build_sinks(&[&sink]);
    let shape = sinks
        .get("build")
        .expect("external-reverse sink should be detected");
    assert_eq!(shape.kind, BufferBuildKind::ExternalReverse);
    assert_eq!(shape.acc_param_idx, 1);

    // Caller body: `String.join(List.reverse(build(xs, [])), "\n")`.
    let build_call = call(
        ident_expr("build"),
        vec![ident("xs"), sp(Expr::List(vec![]))],
    );
    let reversed = call(dotted("List", "reverse"), vec![build_call]);
    let separator = sp(Expr::Literal(Literal::Str("\n".to_string())));
    let joined = call(dotted("String", "join"), vec![reversed, separator]);
    let caller = FnDef {
        name: String::from("render"),
        line: 2,
        params: vec![param("xs", "List<Int>")],
        return_type: String::from("String"),
        effects: vec![],
        desc: None,
        body: Arc::new(FnBody::Block(vec![Stmt::Expr(joined)])),
        resolution: None,
    };

    let mut items = vec![
        crate::ast::TopLevel::FnDef(sink),
        crate::ast::TopLevel::FnDef(caller),
    ];
    let (sites, synth) = run_buffer_build_pass(&mut items);
    assert_eq!(
        sites, 1,
        "external-reverse pattern should be one fusion site"
    );
    assert_eq!(synth, 1, "exactly one buffered variant for the used sink");

    // The pass must have left a function named `build__buffered` in `items`.
    let has_buffered = items
        .iter()
        .any(|item| matches!(item, crate::ast::TopLevel::FnDef(fd) if fd.name == "build__buffered"));
    assert!(has_buffered, "build__buffered must be appended");
}
1671
1672 fn ident_expr(name: &str) -> Spanned<Expr> {
1673 sp(Expr::Ident(name.to_string()))
1674 }
1675}