1use std::collections::HashSet;
2
3use crate::ast::{
4 Expr, FnBody, FnDef, Spanned, Stmt, StrPart, TailCallData, TopLevel, TypeDef, TypeVariant,
5 VerifyBlock, VerifyGivenDomain, VerifyKind,
6};
7use crate::codegen::CodegenContext;
8use crate::types::Type;
9
10pub fn is_pure_fn(fd: &FnDef) -> bool {
16 fd.effects.is_empty() && fd.name != "main"
17}
18
19pub fn is_recursive_type_def(td: &TypeDef) -> bool {
22 match td {
23 TypeDef::Sum { name, variants, .. } => is_recursive_sum(name, variants),
24 TypeDef::Product { name, fields, .. } => is_recursive_product(name, fields),
25 }
26}
27
28pub fn type_def_name(td: &TypeDef) -> &str {
30 match td {
31 TypeDef::Sum { name, .. } | TypeDef::Product { name, .. } => name,
32 }
33}
34
35pub fn is_recursive_sum(name: &str, variants: &[TypeVariant]) -> bool {
39 variants
40 .iter()
41 .any(|v| v.fields.iter().any(|f| type_ref_contains(f, name)))
42}
43
44pub fn is_recursive_product(name: &str, fields: &[(String, String)]) -> bool {
46 fields.iter().any(|(_, ty)| type_ref_contains(ty, name))
47}
48
/// Returns whether the type annotation `annotation` mentions `type_name` as a
/// whole type name. For `type_name == "Tree"` it matches `Tree`, `List<Tree>`,
/// `Map<String, Tree>`, `Result<Tree,String>` and `(Tree, Int)`.
///
/// The annotation is split into name tokens: maximal runs of alphanumerics,
/// `_` and `.`. A token must equal `type_name` exactly, so:
/// - names that merely contain it (`SubTree`, `Treehouse`) do not match, which
///   the previous substring heuristic got wrong;
/// - a qualified name such as `Other.Tree` is a distinct name, both at top
///   level and inside `<...>`.
fn type_ref_contains(annotation: &str, type_name: &str) -> bool {
    let is_name_char = |c: char| c.is_alphanumeric() || c == '_' || c == '.';
    annotation
        .split(|c: char| !is_name_char(c))
        .any(|token| token == type_name)
}
58
59pub(crate) fn is_user_type(name: &str, ctx: &CodegenContext) -> bool {
61 let check_td = |td: &TypeDef| match td {
62 TypeDef::Sum { name: n, .. } => n == name,
63 TypeDef::Product { name: n, .. } => n == name,
64 };
65 ctx.type_defs.iter().any(check_td)
66 || ctx.modules.iter().any(|m| m.type_defs.iter().any(check_td))
67}
68
69pub(crate) fn resolve_module_call<'a>(
72 dotted_name: &'a str,
73 ctx: &'a CodegenContext,
74) -> Option<(&'a str, &'a str)> {
75 let mut best: Option<&str> = None;
76 for prefix in &ctx.module_prefixes {
77 let dotted_prefix = format!("{}.", prefix);
78 if dotted_name.starts_with(&dotted_prefix) && best.is_none_or(|b| prefix.len() > b.len()) {
79 best = Some(prefix.as_str());
80 }
81 }
82 best.map(|prefix| (prefix, &dotted_name[prefix.len() + 1..]))
83}
84
85pub(crate) fn module_prefix_to_rust_segments(prefix: &str) -> Vec<String> {
86 prefix.split('.').map(module_segment_to_rust).collect()
87}
88
/// Turns a dotted module prefix into a `/`-separated relative path
/// (every `.` becomes `/`; nothing else changes).
pub(crate) fn module_prefix_to_filename(prefix: &str) -> String {
    prefix
        .chars()
        .map(|c| if c == '.' { '/' } else { c })
        .collect()
}
96
/// The set of effects declared across the program, split by how they were
/// written.
pub(crate) struct DeclaredEffects {
    /// Effects declared without a `.` — a whole namespace.
    pub bare_namespaces: HashSet<String>,
    /// Effects declared with a `.` — a single `Namespace.method`.
    pub methods: HashSet<String>,
}

impl DeclaredEffects {
    /// Returns whether `c_method` is covered by a declaration: either that exact
    /// method was declared, or the namespace before its first `.` was declared
    /// bare. A name without any `.` only matches an exact method entry.
    pub fn includes(&self, c_method: &str) -> bool {
        self.methods.contains(c_method)
            || c_method
                .split_once('.')
                .is_some_and(|(namespace, _)| self.bare_namespaces.contains(namespace))
    }
}
126
127pub(crate) fn collect_declared_effects(ctx: &CodegenContext) -> DeclaredEffects {
131 let mut bare_namespaces: HashSet<String> = HashSet::new();
132 let mut methods: HashSet<String> = HashSet::new();
133 let mut record = |effect: &str| {
134 if effect.contains('.') {
135 methods.insert(effect.to_string());
136 } else {
137 bare_namespaces.insert(effect.to_string());
138 }
139 };
140 for item in &ctx.items {
141 if let TopLevel::FnDef(fd) = item {
142 for eff in &fd.effects {
143 record(&eff.node);
144 }
145 }
146 }
147 for module in &ctx.modules {
148 for fd in &module.fn_defs {
149 for eff in &fd.effects {
150 record(&eff.node);
151 }
152 }
153 }
154 DeclaredEffects {
155 bare_namespaces,
156 methods,
157 }
158}
159
160pub fn entry_basename(ctx: &CodegenContext) -> String {
168 ctx.items
169 .iter()
170 .find_map(|item| match item {
171 TopLevel::Module(m) => Some(m.name.clone()),
172 _ => None,
173 })
174 .unwrap_or_else(|| {
175 let mut chars = ctx.project_name.chars();
176 match chars.next() {
177 None => String::new(),
178 Some(c) => c.to_uppercase().chain(chars).collect(),
179 }
180 })
181}
182
183pub(crate) fn fn_owning_scope(ctx: &CodegenContext) -> std::collections::HashMap<String, String> {
188 let mut scope = std::collections::HashMap::new();
189 for m in &ctx.modules {
190 for fd in &m.fn_defs {
191 scope.insert(fd.name.clone(), m.prefix.clone());
192 }
193 }
194 for fd in &ctx.fn_defs {
195 scope.insert(fd.name.clone(), String::new());
196 }
197 scope
198}
199
200pub(crate) fn module_prefix_to_rust_path(prefix: &str) -> String {
201 format!(
202 "crate::aver_generated::{}",
203 module_prefix_to_rust_segments(prefix).join("::")
204 )
205}
206
/// Converts one module-name segment into a valid snake_case Rust module
/// identifier.
///
/// - Word boundaries (`aB`, `9B`, the last capital of an acronym as in
///   `HTTPServer`) become `_`.
/// - Any non-ASCII-alphanumeric run becomes a single `_`.
/// - Leading and trailing `_` are trimmed, and an empty result becomes `"module"`.
/// - Rust keywords get a `_mod` suffix.
/// - A result starting with a digit gets a `mod_` prefix; otherwise it would
///   not be an identifier.
fn module_segment_to_rust(segment: &str) -> String {
    // Strict and reserved Rust keywords (2015–2018 editions, plus `gen`, which is
    // reserved in 2024). None of them can be used as a plain identifier.
    // Previously only the 2015 strict set was listed, so `async`, `dyn`, `try`, …
    // produced invalid `mod` names. The output is always lowercase, so `Self` is
    // unreachable and omitted.
    const RUST_KEYWORDS: &[&str] = &[
        "abstract", "as", "async", "await", "become", "box", "break", "const", "continue",
        "crate", "do", "dyn", "else", "enum", "extern", "false", "final", "fn", "for", "gen",
        "if", "impl", "in", "let", "loop", "macro", "match", "mod", "move", "mut", "override",
        "priv", "pub", "ref", "return", "self", "static", "struct", "super", "trait", "true",
        "try", "type", "typeof", "unsafe", "unsized", "use", "virtual", "where", "while",
        "yield",
    ];

    let chars: Vec<char> = segment.chars().collect();
    let mut out = String::with_capacity(segment.len() + 4);

    for (idx, &ch) in chars.iter().enumerate() {
        if !ch.is_ascii_alphanumeric() {
            // Any other character collapses into a single `_` separator.
            if !out.ends_with('_') {
                out.push('_');
            }
            continue;
        }
        if ch.is_ascii_uppercase() && idx > 0 {
            // Start a new word after a lowercase letter/digit, or at the last
            // capital of an acronym (the one followed by a lowercase letter).
            let prev = chars[idx - 1];
            let prev_is_lower_or_digit = prev.is_ascii_lowercase() || prev.is_ascii_digit();
            let next_is_lower = chars
                .get(idx + 1)
                .is_some_and(|next| next.is_ascii_lowercase());
            if (prev_is_lower_or_digit || next_is_lower) && !out.ends_with('_') {
                out.push('_');
            }
        }
        out.push(ch.to_ascii_lowercase());
    }

    let trimmed = out.trim_matches('_');
    let mut normalized = if trimmed.is_empty() {
        "module".to_string()
    } else {
        trimmed.to_string()
    };

    if RUST_KEYWORDS.contains(&normalized.as_str()) {
        normalized.push_str("_mod");
    }
    // Rust identifiers cannot start with a digit.
    if normalized.starts_with(|c: char| c.is_ascii_digit()) {
        normalized.insert_str(0, "mod_");
    }

    normalized
}
280
/// Splits `s` on `delim`, but only at nesting depth 0.
///
/// `<`/`(` open a level and `>`/`)` close one, so `Map<K, V>, X` splits into
/// `["Map<K, V>", "X"]`. Each part is trimmed. Empty parts between delimiters
/// are kept; a trailing empty part is dropped.
///
/// The `>` of a function-type arrow `->` does not close a level. Previously it
/// did, so `String, Result<Fn(A) -> B, E>` was split inside `Result<…>`.
pub(crate) fn split_type_params(s: &str, delim: char) -> Vec<String> {
    let mut parts = Vec::new();
    let mut depth = 0usize;
    let mut current = String::new();
    let mut prev: Option<char> = None;
    for ch in s.chars() {
        match ch {
            '<' | '(' => {
                depth += 1;
                current.push(ch);
            }
            // Part of `->`, not a closing angle bracket.
            '>' if prev == Some('-') => current.push(ch),
            '>' | ')' => {
                depth = depth.saturating_sub(1);
                current.push(ch);
            }
            _ if ch == delim && depth == 0 => {
                parts.push(current.trim().to_string());
                current.clear();
            }
            _ => current.push(ch),
        }
        prev = Some(ch);
    }
    let rest = current.trim();
    if !rest.is_empty() {
        parts.push(rest.to_string());
    }
    parts
}
312
/// Escapes `s` for embedding between double quotes in generated source.
///
/// Backslash, double quote, `\n`, `\r`, `\t` and NUL get their short escapes.
/// Any other control character is written as `\U{xxxxxx}` (six hex digits) when
/// `unicode_escapes` is set, and as `\xHH` otherwise. All other characters are
/// copied unchanged.
///
/// NOTE(review): C1 controls (U+0080–U+009F) produce `\x80`–`\x9f` in the
/// non-Unicode mode, which some targets reject (e.g. Rust limits `\x` to 0x7F).
/// Confirm this is acceptable for every target that uses it.
pub(crate) fn escape_string_literal_ext(s: &str, unicode_escapes: bool) -> String {
    let mut escaped = String::with_capacity(s.len());
    for ch in s.chars() {
        let short_form: Option<&str> = match ch {
            '\\' => Some("\\\\"),
            '"' => Some("\\\""),
            '\n' => Some("\\n"),
            '\r' => Some("\\r"),
            '\t' => Some("\\t"),
            '\0' => Some("\\0"),
            _ => None,
        };
        if let Some(seq) = short_form {
            escaped.push_str(seq);
            continue;
        }
        if !ch.is_control() {
            escaped.push(ch);
            continue;
        }
        let code = ch as u32;
        let seq = if unicode_escapes {
            format!("\\U{{{:06x}}}", code)
        } else {
            format!("\\x{:02x}", code)
        };
        escaped.push_str(&seq);
    }
    escaped
}
342
343pub(crate) fn escape_string_literal(s: &str) -> String {
345 escape_string_literal_ext(s, false)
346}
347
348pub(crate) fn escape_string_literal_unicode(s: &str) -> String {
350 escape_string_literal_ext(s, true)
351}
352
353pub(crate) fn parse_type_annotation(ann: &str) -> Type {
357 crate::types::parse_type_str(ann)
358}
359
360pub(crate) fn is_set_type(ty: &Type) -> bool {
366 matches!(ty, Type::Map(_, v) if matches!(v.as_ref(), Type::Unit))
367}
368
369pub(crate) fn is_set_annotation(ann: &str) -> bool {
371 is_set_type(&parse_type_annotation(ann))
372}
373
374pub(crate) fn is_unit_expr(expr: &crate::ast::Expr) -> bool {
376 matches!(expr, crate::ast::Expr::Literal(crate::ast::Literal::Unit))
377}
378
379pub(crate) fn is_unit_expr_spanned(expr: &crate::ast::Spanned<crate::ast::Expr>) -> bool {
381 is_unit_expr(&expr.node)
382}
383
/// Appends `suffix` to `name` when it appears in `reserved`; otherwise returns
/// `name` unchanged.
pub(crate) fn escape_reserved_word(name: &str, reserved: &[&str], suffix: &str) -> String {
    let is_reserved = reserved.iter().any(|word| *word == name);
    if !is_reserved {
        return name.to_owned();
    }
    [name, suffix].concat()
}
395
/// Prepends `prefix` to `name` when it appears in `reserved`; otherwise returns
/// `name` unchanged.
pub(crate) fn escape_reserved_word_prefix(name: &str, reserved: &[&str], prefix: &str) -> String {
    let is_reserved = reserved.iter().any(|word| *word == name);
    if !is_reserved {
        return name.to_owned();
    }
    [prefix, name].concat()
}
405
/// Lowercases the first character of `s` (Unicode-aware; it may expand to
/// several chars) and leaves the rest untouched.
pub(crate) fn to_lower_first(s: &str) -> String {
    let Some(first) = s.chars().next() else {
        return String::new();
    };
    let mut out: String = first.to_lowercase().collect();
    out.push_str(&s[first.len_utf8()..]);
    out
}
416
417pub(crate) fn expr_to_dotted_name(expr: &Expr) -> Option<String> {
420 crate::ir::expr_to_dotted_name(expr)
421}
422
/// Selects what `rewrite_effectful_calls_in_law` injects for each classified
/// effect of a rewritten call inside a verify law.
///
/// Injections are computed per law `given` and keyed by that given's
/// `type_name` (the effect name). A given that yields no value contributes no
/// injection.
#[derive(Debug, Clone)]
pub(crate) enum OracleInjectionMode<'a> {
    /// Inject a reference to the given's binder: `Expr::Ident(given.name)`.
    LemmaBinding,
    /// Same injection as `LemmaBinding`. Afterwards, every reference to a given
    /// whose effect classifies as `Generative`/`GenerativeOutput` is rewritten
    /// to `<name>.val` by `project_oracle_direct_calls`.
    LemmaBindingProjected,
    /// Inject the first value of the given's explicit domain. Givens without an
    /// explicit, non-empty domain contribute nothing.
    // NOTE(review): nothing in this file constructs this variant, hence the
    // `dead_code` allowance — confirm whether it is still needed.
    #[allow(dead_code)]
    SampleValue,
    /// Inject the value bound to the given's name in this case's bindings.
    /// Givens with no matching binding contribute nothing.
    SampleCaseBinding(&'a [(String, crate::ast::Spanned<Expr>)]),
}
453
454pub(crate) fn rewrite_effectful_calls_in_law(
463 expr: &crate::ast::Spanned<Expr>,
464 law: &crate::ast::VerifyLaw,
465 ctx: &CodegenContext,
466 mode: OracleInjectionMode,
467) -> crate::ast::Spanned<Expr> {
468 use crate::ast::{Spanned, VerifyGivenDomain};
469
470 let injection_by_effect: std::collections::HashMap<String, Spanned<Expr>> = law
471 .givens
472 .iter()
473 .filter_map(|g| {
474 let arg_expr = match &mode {
475 OracleInjectionMode::LemmaBinding => {
476 Spanned::new(Expr::Ident(g.name.clone()), expr.line)
477 }
478 OracleInjectionMode::LemmaBindingProjected => {
479 Spanned::new(Expr::Ident(g.name.clone()), expr.line)
487 }
488 OracleInjectionMode::SampleValue => match &g.domain {
489 VerifyGivenDomain::Explicit(vals) => vals.first().cloned()?,
490 _ => return None,
491 },
492 OracleInjectionMode::SampleCaseBinding(case_bindings) => case_bindings
493 .iter()
494 .find(|(name, _)| name == &g.name)
495 .map(|(_, v)| v.clone())?,
496 };
497 Some((g.type_name.clone(), arg_expr))
498 })
499 .collect();
500 let rewritten = rewrite_effectful_call(expr, &injection_by_effect, ctx);
501
502 if matches!(mode, OracleInjectionMode::LemmaBindingProjected) {
511 let oracle_names: std::collections::HashSet<String> = law
512 .givens
513 .iter()
514 .filter(|g| {
515 matches!(
516 crate::types::checker::effect_classification::classify(&g.type_name)
517 .map(|c| c.dimension),
518 Some(crate::types::checker::effect_classification::EffectDimension::Generative)
519 | Some(
520 crate::types::checker::effect_classification::EffectDimension::GenerativeOutput
521 )
522 )
523 })
524 .map(|g| g.name.clone())
525 .collect();
526 if !oracle_names.is_empty() {
527 return project_oracle_direct_calls(&rewritten, &oracle_names);
528 }
529 }
530 rewritten
531}
532
533fn project_oracle_direct_calls(
546 expr: &crate::ast::Spanned<Expr>,
547 oracle_names: &std::collections::HashSet<String>,
548) -> crate::ast::Spanned<Expr> {
549 use crate::ast::Spanned;
550 let line = expr.line;
551 let project_ident = |name: &str, line: usize| -> Spanned<Expr> {
552 Spanned::new(
553 Expr::Attr(
554 Box::new(Spanned::new(Expr::Ident(name.to_string()), line)),
555 "val".to_string(),
556 ),
557 line,
558 )
559 };
560 let new_node = match &expr.node {
561 Expr::Ident(name) if oracle_names.contains(name) => {
565 return project_ident(name, line);
566 }
567 Expr::FnCall(callee, args) => {
568 let new_args: Vec<Spanned<Expr>> = args
569 .iter()
570 .map(|a| project_oracle_direct_calls(a, oracle_names))
571 .collect();
572 let new_callee = if let Expr::Ident(name) = &callee.node
574 && oracle_names.contains(name)
575 {
576 project_ident(name, callee.line)
577 } else {
578 project_oracle_direct_calls(callee, oracle_names)
579 };
580 Expr::FnCall(Box::new(new_callee), new_args)
581 }
582 Expr::Constructor(name, Some(arg)) => Expr::Constructor(
583 name.clone(),
584 Some(Box::new(project_oracle_direct_calls(arg, oracle_names))),
585 ),
586 Expr::Attr(obj, field) => Expr::Attr(
587 Box::new(project_oracle_direct_calls(obj, oracle_names)),
588 field.clone(),
589 ),
590 Expr::BinOp(op, l, r) => Expr::BinOp(
591 *op,
592 Box::new(project_oracle_direct_calls(l, oracle_names)),
593 Box::new(project_oracle_direct_calls(r, oracle_names)),
594 ),
595 other => other.clone(),
596 };
597 Spanned::new(new_node, line)
598}
599
/// Recursively rewrites calls to entry-scope functions (`ctx.fn_defs`) that
/// declare at least one effect, all of them classified, by prepending the
/// arguments those functions take for their effects.
///
/// For a matching call the new argument list is:
/// 1. `BranchPath.Root`, only if some declared effect is `Generative` or
///    `GenerativeOutput`;
/// 2. for each distinct declared effect that is not `Output`, in declaration
///    order, its entry in `injection_by_effect` (skipped when absent);
/// 3. the original arguments, themselves rewritten.
///
/// Only `FnCall`, `BinOp` and `Tuple` nodes are traversed. Every other node is
/// cloned unchanged, so calls nested elsewhere (e.g. inside constructors or
/// lists) are not rewritten.
fn rewrite_effectful_call(
    expr: &crate::ast::Spanned<Expr>,
    injection_by_effect: &std::collections::HashMap<String, crate::ast::Spanned<Expr>>,
    ctx: &CodegenContext,
) -> crate::ast::Spanned<Expr> {
    use crate::ast::Spanned;
    use crate::types::checker::effect_classification::{EffectDimension, classify};

    match &expr.node {
        Expr::FnCall(callee, args) => {
            // Rewrite nested calls first (arguments and callee).
            let rewritten_args: Vec<Spanned<Expr>> = args
                .iter()
                .map(|a| rewrite_effectful_call(a, injection_by_effect, ctx))
                .collect();
            let rewritten_callee =
                Box::new(rewrite_effectful_call(callee, injection_by_effect, ctx));

            // Only plain or resolved identifiers are looked up as function names.
            let callee_name = match &callee.node {
                Expr::Ident(name) => Some(name.clone()),
                Expr::Resolved { name, .. } => Some(name.clone()),
                _ => None,
            };

            // Rewrite only calls to entry-scope functions whose effects are
            // non-empty and all recognised by the classifier.
            if let Some(name) = callee_name
                && let Some(fd) = ctx.fn_defs.iter().find(|fd| fd.name == name)
                && !fd.effects.is_empty()
                && fd
                    .effects
                    .iter()
                    .all(|e| crate::types::checker::effect_classification::is_classified(&e.node))
            {
                let mut injected: Vec<Spanned<Expr>> = Vec::new();
                // Any generative effect means the callee also takes a branch path,
                // which starts at `BranchPath.Root` here.
                let needs_path = fd.effects.iter().any(|e| {
                    matches!(
                        classify(&e.node).map(|c| c.dimension),
                        Some(EffectDimension::Generative | EffectDimension::GenerativeOutput)
                    )
                });
                if needs_path {
                    injected.push(Spanned::new(
                        Expr::Attr(
                            Box::new(Spanned::new(
                                Expr::Ident("BranchPath".to_string()),
                                expr.line,
                            )),
                            "Root".to_string(),
                        ),
                        expr.line,
                    ));
                }
                // One injected argument per distinct, non-`Output` effect, in
                // declaration order.
                let mut seen = std::collections::HashSet::new();
                for e in &fd.effects {
                    if !seen.insert(e.node.clone()) {
                        continue;
                    }
                    let Some(c) = classify(&e.node) else { continue };
                    if matches!(c.dimension, EffectDimension::Output) {
                        continue;
                    }
                    // An effect with no entry in the map gets no injected argument.
                    if let Some(inj) = injection_by_effect.get(&e.node) {
                        injected.push(inj.clone());
                    }
                }
                injected.extend(rewritten_args);
                return Spanned::new(Expr::FnCall(rewritten_callee, injected), expr.line);
            }

            Spanned::new(Expr::FnCall(rewritten_callee, rewritten_args), expr.line)
        }
        Expr::BinOp(op, l, r) => Spanned::new(
            Expr::BinOp(
                *op,
                Box::new(rewrite_effectful_call(l, injection_by_effect, ctx)),
                Box::new(rewrite_effectful_call(r, injection_by_effect, ctx)),
            ),
            expr.line,
        ),
        Expr::Tuple(items) => Spanned::new(
            Expr::Tuple(
                items
                    .iter()
                    .map(|i| rewrite_effectful_call(i, injection_by_effect, ctx))
                    .collect(),
            ),
            expr.line,
        ),
        // Every other node is left as-is (not traversed).
        _ => expr.clone(),
    }
}
692
693pub(crate) fn verify_reachable_fn_names(items: &[TopLevel]) -> HashSet<String> {
703 let mut reachable: HashSet<String> = HashSet::new();
704 for item in items {
705 if let TopLevel::Verify(vb) = item {
706 collect_verify_block_refs(vb, &mut reachable);
707 }
708 }
709 loop {
711 let mut changed = false;
712 for item in items {
713 if let TopLevel::FnDef(fd) = item
714 && reachable.contains(&fd.name)
715 {
716 let mut called = HashSet::new();
717 collect_called_idents_in_body(&fd.body, &mut called);
718 for name in called {
719 if reachable.insert(name) {
720 changed = true;
721 }
722 }
723 }
724 }
725 if !changed {
726 break;
727 }
728 }
729 reachable
730}
731
732fn collect_verify_block_refs(vb: &VerifyBlock, out: &mut HashSet<String>) {
733 out.insert(vb.fn_name.clone());
734 for (lhs, rhs) in &vb.cases {
735 collect_called_idents(lhs, out);
736 collect_called_idents(rhs, out);
737 }
738 if let VerifyKind::Law(law) = &vb.kind {
739 collect_called_idents(&law.lhs, out);
740 collect_called_idents(&law.rhs, out);
741 if let Some(when) = &law.when {
742 collect_called_idents(when, out);
743 }
744 for given in &law.givens {
745 if let VerifyGivenDomain::Explicit(values) = &given.domain {
746 for v in values {
747 collect_called_idents(v, out);
748 }
749 }
750 }
751 }
752 for given in &vb.cases_givens {
753 if let VerifyGivenDomain::Explicit(values) = &given.domain {
754 for v in values {
755 collect_called_idents(v, out);
756 }
757 }
758 }
759}
760
761fn collect_called_idents_in_body(body: &FnBody, out: &mut HashSet<String>) {
762 for stmt in body.stmts() {
763 match stmt {
764 Stmt::Binding(_, _, e) | Stmt::Expr(e) => collect_called_idents(e, out),
765 }
766 }
767}
768
769fn collect_called_idents(expr: &Spanned<Expr>, out: &mut HashSet<String>) {
770 match &expr.node {
771 Expr::FnCall(callee, args) => {
772 if let Expr::Ident(name) | Expr::Resolved { name, .. } = &callee.node {
773 out.insert(name.clone());
774 } else {
775 collect_called_idents(callee, out);
776 }
777 for a in args {
778 collect_called_idents(a, out);
779 }
780 }
781 Expr::TailCall(boxed) => {
782 let TailCallData { target, args, .. } = boxed.as_ref();
783 out.insert(target.clone());
784 for a in args {
785 collect_called_idents(a, out);
786 }
787 }
788 Expr::Ident(name) | Expr::Resolved { name, .. } => {
789 out.insert(name.clone());
790 }
791 Expr::BinOp(_, l, r) => {
792 collect_called_idents(l, out);
793 collect_called_idents(r, out);
794 }
795 Expr::Match { subject, arms, .. } => {
796 collect_called_idents(subject, out);
797 for arm in arms {
798 collect_called_idents(&arm.body, out);
799 }
800 }
801 Expr::ErrorProp(inner) | Expr::Attr(inner, _) => {
802 collect_called_idents(inner, out);
803 }
804 Expr::Constructor(_, Some(inner)) => {
805 collect_called_idents(inner, out);
806 }
807 Expr::InterpolatedStr(parts) => {
808 for part in parts {
809 if let StrPart::Parsed(inner) = part {
810 collect_called_idents(inner, out);
811 }
812 }
813 }
814 Expr::List(items) | Expr::Tuple(items) | Expr::IndependentProduct(items, _) => {
815 for i in items {
816 collect_called_idents(i, out);
817 }
818 }
819 Expr::MapLiteral(entries) => {
820 for (k, v) in entries {
821 collect_called_idents(k, out);
822 collect_called_idents(v, out);
823 }
824 }
825 Expr::RecordCreate { fields, .. } => {
826 for (_, v) in fields {
827 collect_called_idents(v, out);
828 }
829 }
830 Expr::RecordUpdate { base, updates, .. } => {
831 collect_called_idents(base, out);
832 for (_, v) in updates {
833 collect_called_idents(v, out);
834 }
835 }
836 Expr::Literal(_) | Expr::Constructor(_, None) => {}
837 }
838}
839
/// Emitted output sections grouped by scope.
///
/// As built by `route_pure_components_per_scope`, keys are module prefixes,
/// with the empty string for the entry scope.
pub(crate) struct PerScopeSections {
    pub by_scope: std::collections::HashMap<String, Vec<String>>,
}

impl PerScopeSections {
    /// Removes and returns the sections stored for `scope`. Returns an empty
    /// list when the scope has none or was already taken.
    pub(crate) fn take(&mut self, scope: &str) -> Vec<String> {
        match self.by_scope.remove(scope) {
            Some(sections) => sections,
            None => Vec::new(),
        }
    }
}
852
853pub(crate) fn route_pure_components_per_scope<F, G>(
862 ctx: &CodegenContext,
863 is_pure: F,
864 mut emit: G,
865) -> PerScopeSections
866where
867 F: Fn(&FnDef) -> bool,
868 G: FnMut(&[&FnDef]) -> Vec<String>,
869{
870 let mut by_scope: std::collections::HashMap<String, Vec<String>> =
871 std::collections::HashMap::new();
872
873 let mut process =
874 |fns: Vec<&FnDef>,
875 scope: String,
876 by_scope: &mut std::collections::HashMap<String, Vec<String>>| {
877 let comps = crate::call_graph::ordered_fn_components(&fns, &ctx.module_prefixes);
878 let bucket = by_scope.entry(scope).or_default();
879 for comp in comps {
880 bucket.extend(emit(&comp));
881 }
882 };
883
884 for module in &ctx.modules {
885 let pure: Vec<&FnDef> = module.fn_defs.iter().filter(|fd| is_pure(fd)).collect();
886 process(pure, module.prefix.clone(), &mut by_scope);
887 }
888 let entry_pure: Vec<&FnDef> = ctx.fn_defs.iter().filter(|fd| is_pure(fd)).collect();
889 process(entry_pure, String::new(), &mut by_scope);
890
891 PerScopeSections { by_scope }
892}