1use std::collections::HashSet;
2
3use crate::ast::{
4 Expr, FnBody, FnDef, Spanned, Stmt, StrPart, TailCallData, TopLevel, TypeDef, TypeVariant,
5 VerifyBlock, VerifyGivenDomain, VerifyKind,
6};
7use crate::codegen::CodegenContext;
8use crate::types::Type;
9
10pub fn is_pure_fn(fd: &FnDef) -> bool {
16 fd.effects.is_empty() && fd.name != "main"
17}
18
19pub fn is_recursive_type_def(td: &TypeDef) -> bool {
22 match td {
23 TypeDef::Sum { name, variants, .. } => is_recursive_sum(name, variants),
24 TypeDef::Product { name, fields, .. } => is_recursive_product(name, fields),
25 }
26}
27
28pub fn type_def_name(td: &TypeDef) -> &str {
30 match td {
31 TypeDef::Sum { name, .. } | TypeDef::Product { name, .. } => name,
32 }
33}
34
35pub fn is_recursive_sum(name: &str, variants: &[TypeVariant]) -> bool {
39 variants
40 .iter()
41 .any(|v| v.fields.iter().any(|f| type_ref_contains(f, name)))
42}
43
44pub fn is_recursive_product(name: &str, fields: &[(String, String)]) -> bool {
46 fields.iter().any(|(_, ty)| type_ref_contains(ty, name))
47}
48
/// Reports whether the type annotation `annotation` mentions `type_name`.
///
/// A mention is an occurrence of `type_name` that is not glued to further
/// identifier characters (alphanumerics or `_`) on either side. This accepts
/// the bare name as well as nested uses such as `List<Tree>`,
/// `Map<String, Tree>` or `(Tree, Int)`, while rejecting names that merely
/// share a prefix or suffix (`List<TreeNode>` and `Option<MyTree>` do not
/// mention `Tree`).
///
/// An empty `type_name` only matches an empty annotation.
fn type_ref_contains(annotation: &str, type_name: &str) -> bool {
    if type_name.is_empty() {
        return annotation.is_empty();
    }
    let is_ident_char = |c: char| c.is_alphanumeric() || c == '_';
    annotation.match_indices(type_name).any(|(start, matched)| {
        // Inspect the neighbouring characters to make sure the occurrence is
        // a complete identifier rather than a fragment of a longer one.
        let before = annotation[..start].chars().next_back();
        let after = annotation[start + matched.len()..].chars().next();
        !before.is_some_and(is_ident_char) && !after.is_some_and(is_ident_char)
    })
}
58
59pub(crate) fn is_user_type(name: &str, ctx: &CodegenContext) -> bool {
61 let check_td = |td: &TypeDef| match td {
62 TypeDef::Sum { name: n, .. } => n == name,
63 TypeDef::Product { name: n, .. } => n == name,
64 };
65 ctx.type_defs.iter().any(check_td)
66 || ctx.modules.iter().any(|m| m.type_defs.iter().any(check_td))
67}
68
69pub(crate) fn resolve_module_call<'a>(
72 dotted_name: &'a str,
73 ctx: &'a CodegenContext,
74) -> Option<(&'a str, &'a str)> {
75 let mut best: Option<&str> = None;
76 for prefix in &ctx.module_prefixes {
77 let dotted_prefix = format!("{}.", prefix);
78 if dotted_name.starts_with(&dotted_prefix) && best.is_none_or(|b| prefix.len() > b.len()) {
79 best = Some(prefix.as_str());
80 }
81 }
82 best.map(|prefix| (prefix, &dotted_name[prefix.len() + 1..]))
83}
84
85pub(crate) fn module_prefix_to_rust_segments(prefix: &str) -> Vec<String> {
86 prefix.split('.').map(module_segment_to_rust).collect()
87}
88
/// Turns a dotted module prefix into a slash-separated relative path by
/// replacing every `.` with `/`.
pub(crate) fn module_prefix_to_filename(prefix: &str) -> String {
    let components: Vec<&str> = prefix.split('.').collect();
    components.join("/")
}
96
/// Effects gathered from function declarations, split by how they were
/// written.
pub(crate) struct DeclaredEffects {
    /// Effects declared without a dot (a namespace as a whole).
    pub bare_namespaces: HashSet<String>,
    /// Effects declared with a dot (one specific method).
    pub methods: HashSet<String>,
}

impl DeclaredEffects {
    /// Reports whether `c_method` is covered by the declared effects.
    ///
    /// It is covered when it appears verbatim in `methods`, or when the text
    /// before its first `.` appears in `bare_namespaces`. A name without a
    /// dot is only covered by an exact entry in `methods`.
    pub fn includes(&self, c_method: &str) -> bool {
        self.methods.contains(c_method)
            || c_method
                .split_once('.')
                .is_some_and(|(namespace, _)| self.bare_namespaces.contains(namespace))
    }
}
126
127pub(crate) fn collect_declared_effects(ctx: &CodegenContext) -> DeclaredEffects {
131 let mut bare_namespaces: HashSet<String> = HashSet::new();
132 let mut methods: HashSet<String> = HashSet::new();
133 let mut record = |effect: &str| {
134 if effect.contains('.') {
135 methods.insert(effect.to_string());
136 } else {
137 bare_namespaces.insert(effect.to_string());
138 }
139 };
140 for item in &ctx.items {
141 if let TopLevel::FnDef(fd) = item {
142 for eff in &fd.effects {
143 record(&eff.node);
144 }
145 }
146 }
147 for module in &ctx.modules {
148 for fd in &module.fn_defs {
149 for eff in &fd.effects {
150 record(&eff.node);
151 }
152 }
153 }
154 DeclaredEffects {
155 bare_namespaces,
156 methods,
157 }
158}
159
160pub fn entry_basename(ctx: &CodegenContext) -> String {
168 ctx.items
169 .iter()
170 .find_map(|item| match item {
171 TopLevel::Module(m) => Some(m.name.clone()),
172 _ => None,
173 })
174 .unwrap_or_else(|| {
175 let mut chars = ctx.project_name.chars();
176 match chars.next() {
177 None => String::new(),
178 Some(c) => c.to_uppercase().chain(chars).collect(),
179 }
180 })
181}
182
183pub(crate) fn fn_owning_scope(ctx: &CodegenContext) -> std::collections::HashMap<String, String> {
188 let mut scope = std::collections::HashMap::new();
189 for m in &ctx.modules {
190 for fd in &m.fn_defs {
191 scope.insert(fd.name.clone(), m.prefix.clone());
192 }
193 }
194 for fd in &ctx.fn_defs {
195 scope.insert(fd.name.clone(), String::new());
196 }
197 scope
198}
199
200pub(crate) fn module_prefix_to_rust_path(prefix: &str) -> String {
201 format!(
202 "crate::aver_generated::{}",
203 module_prefix_to_rust_segments(prefix).join("::")
204 )
205}
206
/// Converts one module-name segment into a snake_case Rust module name.
///
/// * ASCII letters and digits are lower-cased; every run of other characters
///   collapses into a single `_`.
/// * An `_` is inserted before an upper-case letter that follows a
///   lower-case letter or digit, or that is followed by a lower-case letter
///   (so `HTTPServer` becomes `http_server`).
/// * Leading and trailing `_` are trimmed; an empty result becomes `module`.
/// * A result equal to one of the reserved words listed below gets `_mod`
///   appended.
fn module_segment_to_rust(segment: &str) -> String {
    const RESERVED: &[&str] = &[
        "as", "break", "const", "continue", "crate", "else", "enum", "extern", "false", "fn",
        "for", "if", "impl", "in", "let", "loop", "match", "mod", "move", "mut", "pub", "ref",
        "return", "self", "Self", "static", "struct", "super", "trait", "true", "type", "unsafe",
        "use", "where", "while",
    ];

    let symbols: Vec<char> = segment.chars().collect();
    let mut snake = String::new();

    for (pos, &sym) in symbols.iter().enumerate() {
        if !sym.is_ascii_alphanumeric() {
            // Separator-like character: emit at most one underscore per run.
            if !snake.ends_with('_') {
                snake.push('_');
            }
            continue;
        }

        if sym.is_ascii_uppercase() && pos > 0 && !snake.ends_with('_') {
            let before = symbols[pos - 1];
            let followed_by_lower = symbols
                .get(pos + 1)
                .is_some_and(|next| next.is_ascii_lowercase());
            // Word boundary: `fooBar`, `v2Api`, or the last capital of an
            // acronym that starts a new word (`HTTPServer`).
            if before.is_ascii_lowercase() || before.is_ascii_digit() || followed_by_lower {
                snake.push('_');
            }
        }
        snake.push(sym.to_ascii_lowercase());
    }

    let core = snake.trim_matches('_');
    let mut result = if core.is_empty() {
        String::from("module")
    } else {
        core.to_owned()
    };

    if RESERVED.contains(&result.as_str()) {
        result.push_str("_mod");
    }
    result
}
280
/// Splits `s` on `delim`, ignoring delimiters nested inside `<...>` or
/// `(...)`.
///
/// Every piece is trimmed. Empty pieces between two delimiters are kept,
/// but an empty piece after the last delimiter (or an entirely blank input)
/// is dropped. Unbalanced closing brackets never push the nesting depth
/// below zero.
pub(crate) fn split_type_params(s: &str, delim: char) -> Vec<String> {
    let mut pieces = Vec::new();
    let mut nesting = 0usize;
    let mut piece_start = 0usize;

    for (at, ch) in s.char_indices() {
        if ch == '<' || ch == '(' {
            nesting += 1;
        } else if ch == '>' || ch == ')' {
            nesting = nesting.saturating_sub(1);
        } else if ch == delim && nesting == 0 {
            pieces.push(s[piece_start..at].trim().to_string());
            piece_start = at + ch.len_utf8();
        }
    }

    let tail = s[piece_start..].trim();
    if !tail.is_empty() {
        pieces.push(tail.to_string());
    }
    pieces
}
312
/// Escapes `s` for embedding inside a double-quoted string literal.
///
/// Backslash, double quote, newline, carriage return, tab and NUL get their
/// two-character escapes. Any other control character is written as
/// `\U{xxxxxx}` (six lower-case hex digits) when `unicode_escapes` is set,
/// and as `\xNN` (at least two lower-case hex digits) otherwise. All
/// remaining characters are copied unchanged.
pub(crate) fn escape_string_literal_ext(s: &str, unicode_escapes: bool) -> String {
    let mut escaped = String::with_capacity(s.len());
    for ch in s.chars() {
        let short_escape = match ch {
            '\\' => Some("\\\\"),
            '"' => Some("\\\""),
            '\n' => Some("\\n"),
            '\r' => Some("\\r"),
            '\t' => Some("\\t"),
            '\0' => Some("\\0"),
            _ => None,
        };
        match short_escape {
            Some(sequence) => escaped.push_str(sequence),
            None if ch.is_control() => {
                let code = ch as u32;
                let numeric = if unicode_escapes {
                    format!("\\U{{{:06x}}}", code)
                } else {
                    format!("\\x{:02x}", code)
                };
                escaped.push_str(&numeric);
            }
            None => escaped.push(ch),
        }
    }
    escaped
}
342
343pub(crate) fn escape_string_literal(s: &str) -> String {
345 escape_string_literal_ext(s, false)
346}
347
348pub(crate) fn escape_string_literal_unicode(s: &str) -> String {
350 escape_string_literal_ext(s, true)
351}
352
353pub(crate) fn parse_type_annotation(ann: &str) -> Type {
357 crate::types::parse_type_str(ann)
358}
359
360pub(crate) fn is_set_type(ty: &Type) -> bool {
366 matches!(ty, Type::Map(_, v) if matches!(v.as_ref(), Type::Unit))
367}
368
369pub(crate) fn is_set_annotation(ann: &str) -> bool {
371 is_set_type(&parse_type_annotation(ann))
372}
373
374pub(crate) fn is_unit_expr(expr: &crate::ast::Expr) -> bool {
376 matches!(expr, crate::ast::Expr::Literal(crate::ast::Literal::Unit))
377}
378
379pub(crate) fn is_unit_expr_spanned(expr: &crate::ast::Spanned<crate::ast::Expr>) -> bool {
381 is_unit_expr(&expr.node)
382}
383
/// Returns `name` with `suffix` appended when `name` is one of `reserved`;
/// otherwise returns `name` unchanged.
pub(crate) fn escape_reserved_word(name: &str, reserved: &[&str], suffix: &str) -> String {
    let mut escaped = String::from(name);
    if reserved.iter().any(|word| *word == name) {
        escaped.push_str(suffix);
    }
    escaped
}
395
/// Returns `name` with `prefix` prepended when `name` is one of `reserved`;
/// otherwise returns `name` unchanged.
pub(crate) fn escape_reserved_word_prefix(name: &str, reserved: &[&str], prefix: &str) -> String {
    match reserved.iter().find(|word| **word == name) {
        Some(_) => [prefix, name].concat(),
        None => name.to_owned(),
    }
}
405
/// Lower-cases the first character of `s` and leaves the rest untouched.
/// An empty input yields an empty string.
pub(crate) fn to_lower_first(s: &str) -> String {
    let mut rest = s.chars();
    let Some(head) = rest.next() else {
        return String::new();
    };
    // Lower-casing one char may yield several chars, so chain the iterators.
    head.to_lowercase().chain(rest).collect()
}
416
417pub(crate) fn expr_to_dotted_name(expr: &Expr) -> Option<String> {
420 crate::ir::expr_to_dotted_name(expr)
421}
422
/// Selects which expression `rewrite_effectful_calls_in_law` injects for
/// each `given` of a verify law when it rewrites calls to effectful
/// functions.
#[derive(Debug, Clone)]
pub(crate) enum OracleInjectionMode<'a> {
    /// Inject an identifier expression carrying the given's name.
    LemmaBinding,
    /// Inject the same identifier as `LemmaBinding`; afterwards, identifiers
    /// of givens whose type is classified as `Generative` or
    /// `GenerativeOutput` are rewritten into an attribute access on `val`.
    LemmaBindingProjected,
    // NOTE(review): the lint is silenced presumably because nothing
    // constructs this variant at the moment — confirm before removing.
    #[allow(dead_code)]
    /// Inject the first value of the given's explicit domain; givens with
    /// any other kind of domain, or an empty explicit domain, contribute no
    /// injection.
    SampleValue,
    /// Inject the expression paired with the given's name in the supplied
    /// case bindings; givens without a matching binding contribute no
    /// injection.
    SampleCaseBinding(&'a [(String, crate::ast::Spanned<Expr>)]),
}
453
454pub(crate) fn rewrite_effectful_calls_in_law(
463 expr: &crate::ast::Spanned<Expr>,
464 law: &crate::ast::VerifyLaw,
465 ctx: &CodegenContext,
466 mode: OracleInjectionMode,
467) -> crate::ast::Spanned<Expr> {
468 use crate::ast::{Spanned, VerifyGivenDomain};
469
470 let injection_by_effect: std::collections::HashMap<String, Spanned<Expr>> = law
471 .givens
472 .iter()
473 .filter_map(|g| {
474 let arg_expr = match &mode {
475 OracleInjectionMode::LemmaBinding => Spanned {
476 node: Expr::Ident(g.name.clone()),
477 line: expr.line,
478 },
479 OracleInjectionMode::LemmaBindingProjected => {
480 Spanned {
488 node: Expr::Ident(g.name.clone()),
489 line: expr.line,
490 }
491 }
492 OracleInjectionMode::SampleValue => match &g.domain {
493 VerifyGivenDomain::Explicit(vals) => vals.first().cloned()?,
494 _ => return None,
495 },
496 OracleInjectionMode::SampleCaseBinding(case_bindings) => case_bindings
497 .iter()
498 .find(|(name, _)| name == &g.name)
499 .map(|(_, v)| v.clone())?,
500 };
501 Some((g.type_name.clone(), arg_expr))
502 })
503 .collect();
504 let rewritten = rewrite_effectful_call(expr, &injection_by_effect, ctx);
505
506 if matches!(mode, OracleInjectionMode::LemmaBindingProjected) {
515 let oracle_names: std::collections::HashSet<String> = law
516 .givens
517 .iter()
518 .filter(|g| {
519 matches!(
520 crate::types::checker::effect_classification::classify(&g.type_name)
521 .map(|c| c.dimension),
522 Some(crate::types::checker::effect_classification::EffectDimension::Generative)
523 | Some(
524 crate::types::checker::effect_classification::EffectDimension::GenerativeOutput
525 )
526 )
527 })
528 .map(|g| g.name.clone())
529 .collect();
530 if !oracle_names.is_empty() {
531 return project_oracle_direct_calls(&rewritten, &oracle_names);
532 }
533 }
534 rewritten
535}
536
537fn project_oracle_direct_calls(
550 expr: &crate::ast::Spanned<Expr>,
551 oracle_names: &std::collections::HashSet<String>,
552) -> crate::ast::Spanned<Expr> {
553 use crate::ast::Spanned;
554 let line = expr.line;
555 let project_ident = |name: &str, line: usize| -> Spanned<Expr> {
556 Spanned {
557 node: Expr::Attr(
558 Box::new(Spanned {
559 node: Expr::Ident(name.to_string()),
560 line,
561 }),
562 "val".to_string(),
563 ),
564 line,
565 }
566 };
567 let new_node = match &expr.node {
568 Expr::Ident(name) if oracle_names.contains(name) => {
572 return project_ident(name, line);
573 }
574 Expr::FnCall(callee, args) => {
575 let new_args: Vec<Spanned<Expr>> = args
576 .iter()
577 .map(|a| project_oracle_direct_calls(a, oracle_names))
578 .collect();
579 let new_callee = if let Expr::Ident(name) = &callee.node
581 && oracle_names.contains(name)
582 {
583 project_ident(name, callee.line)
584 } else {
585 project_oracle_direct_calls(callee, oracle_names)
586 };
587 Expr::FnCall(Box::new(new_callee), new_args)
588 }
589 Expr::Constructor(name, Some(arg)) => Expr::Constructor(
590 name.clone(),
591 Some(Box::new(project_oracle_direct_calls(arg, oracle_names))),
592 ),
593 Expr::Attr(obj, field) => Expr::Attr(
594 Box::new(project_oracle_direct_calls(obj, oracle_names)),
595 field.clone(),
596 ),
597 Expr::BinOp(op, l, r) => Expr::BinOp(
598 *op,
599 Box::new(project_oracle_direct_calls(l, oracle_names)),
600 Box::new(project_oracle_direct_calls(r, oracle_names)),
601 ),
602 other => other.clone(),
603 };
604 Spanned {
605 node: new_node,
606 line,
607 }
608}
609
610fn rewrite_effectful_call(
611 expr: &crate::ast::Spanned<Expr>,
612 injection_by_effect: &std::collections::HashMap<String, crate::ast::Spanned<Expr>>,
613 ctx: &CodegenContext,
614) -> crate::ast::Spanned<Expr> {
615 use crate::ast::Spanned;
616 use crate::types::checker::effect_classification::{EffectDimension, classify};
617
618 match &expr.node {
619 Expr::FnCall(callee, args) => {
620 let rewritten_args: Vec<Spanned<Expr>> = args
621 .iter()
622 .map(|a| rewrite_effectful_call(a, injection_by_effect, ctx))
623 .collect();
624 let rewritten_callee =
625 Box::new(rewrite_effectful_call(callee, injection_by_effect, ctx));
626
627 let callee_name = match &callee.node {
628 Expr::Ident(name) => Some(name.clone()),
629 Expr::Resolved { name, .. } => Some(name.clone()),
630 _ => None,
631 };
632
633 if let Some(name) = callee_name
634 && let Some(fd) = ctx.fn_defs.iter().find(|fd| fd.name == name)
635 && !fd.effects.is_empty()
636 && fd
637 .effects
638 .iter()
639 .all(|e| crate::types::checker::effect_classification::is_classified(&e.node))
640 {
641 let mut injected: Vec<Spanned<Expr>> = Vec::new();
642 let needs_path = fd.effects.iter().any(|e| {
643 matches!(
644 classify(&e.node).map(|c| c.dimension),
645 Some(EffectDimension::Generative | EffectDimension::GenerativeOutput)
646 )
647 });
648 if needs_path {
649 injected.push(Spanned {
650 node: Expr::Attr(
654 Box::new(Spanned {
655 node: Expr::Ident("BranchPath".to_string()),
656 line: expr.line,
657 }),
658 "Root".to_string(),
659 ),
660 line: expr.line,
661 });
662 }
663 let mut seen = std::collections::HashSet::new();
664 for e in &fd.effects {
665 if !seen.insert(e.node.clone()) {
666 continue;
667 }
668 let Some(c) = classify(&e.node) else { continue };
669 if matches!(c.dimension, EffectDimension::Output) {
670 continue;
671 }
672 if let Some(inj) = injection_by_effect.get(&e.node) {
673 injected.push(inj.clone());
674 }
675 }
676 injected.extend(rewritten_args);
677 return Spanned {
678 node: Expr::FnCall(rewritten_callee, injected),
679 line: expr.line,
680 };
681 }
682
683 Spanned {
684 node: Expr::FnCall(rewritten_callee, rewritten_args),
685 line: expr.line,
686 }
687 }
688 Expr::BinOp(op, l, r) => Spanned {
689 node: Expr::BinOp(
690 *op,
691 Box::new(rewrite_effectful_call(l, injection_by_effect, ctx)),
692 Box::new(rewrite_effectful_call(r, injection_by_effect, ctx)),
693 ),
694 line: expr.line,
695 },
696 Expr::Tuple(items) => Spanned {
697 node: Expr::Tuple(
698 items
699 .iter()
700 .map(|i| rewrite_effectful_call(i, injection_by_effect, ctx))
701 .collect(),
702 ),
703 line: expr.line,
704 },
705 _ => expr.clone(),
706 }
707}
708
709pub(crate) fn verify_reachable_fn_names(items: &[TopLevel]) -> HashSet<String> {
719 let mut reachable: HashSet<String> = HashSet::new();
720 for item in items {
721 if let TopLevel::Verify(vb) = item {
722 collect_verify_block_refs(vb, &mut reachable);
723 }
724 }
725 loop {
727 let mut changed = false;
728 for item in items {
729 if let TopLevel::FnDef(fd) = item
730 && reachable.contains(&fd.name)
731 {
732 let mut called = HashSet::new();
733 collect_called_idents_in_body(&fd.body, &mut called);
734 for name in called {
735 if reachable.insert(name) {
736 changed = true;
737 }
738 }
739 }
740 }
741 if !changed {
742 break;
743 }
744 }
745 reachable
746}
747
748fn collect_verify_block_refs(vb: &VerifyBlock, out: &mut HashSet<String>) {
749 out.insert(vb.fn_name.clone());
750 for (lhs, rhs) in &vb.cases {
751 collect_called_idents(lhs, out);
752 collect_called_idents(rhs, out);
753 }
754 if let VerifyKind::Law(law) = &vb.kind {
755 collect_called_idents(&law.lhs, out);
756 collect_called_idents(&law.rhs, out);
757 if let Some(when) = &law.when {
758 collect_called_idents(when, out);
759 }
760 for given in &law.givens {
761 if let VerifyGivenDomain::Explicit(values) = &given.domain {
762 for v in values {
763 collect_called_idents(v, out);
764 }
765 }
766 }
767 }
768 for given in &vb.cases_givens {
769 if let VerifyGivenDomain::Explicit(values) = &given.domain {
770 for v in values {
771 collect_called_idents(v, out);
772 }
773 }
774 }
775}
776
777fn collect_called_idents_in_body(body: &FnBody, out: &mut HashSet<String>) {
778 for stmt in body.stmts() {
779 match stmt {
780 Stmt::Binding(_, _, e) | Stmt::Expr(e) => collect_called_idents(e, out),
781 }
782 }
783}
784
785fn collect_called_idents(expr: &Spanned<Expr>, out: &mut HashSet<String>) {
786 match &expr.node {
787 Expr::FnCall(callee, args) => {
788 if let Expr::Ident(name) | Expr::Resolved { name, .. } = &callee.node {
789 out.insert(name.clone());
790 } else {
791 collect_called_idents(callee, out);
792 }
793 for a in args {
794 collect_called_idents(a, out);
795 }
796 }
797 Expr::TailCall(boxed) => {
798 let TailCallData { target, args, .. } = boxed.as_ref();
799 out.insert(target.clone());
800 for a in args {
801 collect_called_idents(a, out);
802 }
803 }
804 Expr::Ident(name) | Expr::Resolved { name, .. } => {
805 out.insert(name.clone());
806 }
807 Expr::BinOp(_, l, r) => {
808 collect_called_idents(l, out);
809 collect_called_idents(r, out);
810 }
811 Expr::Match { subject, arms, .. } => {
812 collect_called_idents(subject, out);
813 for arm in arms {
814 collect_called_idents(&arm.body, out);
815 }
816 }
817 Expr::ErrorProp(inner) | Expr::Attr(inner, _) => {
818 collect_called_idents(inner, out);
819 }
820 Expr::Constructor(_, Some(inner)) => {
821 collect_called_idents(inner, out);
822 }
823 Expr::InterpolatedStr(parts) => {
824 for part in parts {
825 if let StrPart::Parsed(inner) = part {
826 collect_called_idents(inner, out);
827 }
828 }
829 }
830 Expr::List(items) | Expr::Tuple(items) | Expr::IndependentProduct(items, _) => {
831 for i in items {
832 collect_called_idents(i, out);
833 }
834 }
835 Expr::MapLiteral(entries) => {
836 for (k, v) in entries {
837 collect_called_idents(k, out);
838 collect_called_idents(v, out);
839 }
840 }
841 Expr::RecordCreate { fields, .. } => {
842 for (_, v) in fields {
843 collect_called_idents(v, out);
844 }
845 }
846 Expr::RecordUpdate { base, updates, .. } => {
847 collect_called_idents(base, out);
848 for (_, v) in updates {
849 collect_called_idents(v, out);
850 }
851 }
852 Expr::Literal(_) | Expr::Constructor(_, None) => {}
853 }
854}
855
/// Emitted text sections grouped by the scope they belong to.
pub(crate) struct PerScopeSections {
    /// Sections keyed by scope name.
    pub by_scope: std::collections::HashMap<String, Vec<String>>,
}

impl PerScopeSections {
    /// Removes and returns the sections stored for `scope`. A scope with no
    /// entry (including one already taken) yields an empty vector.
    pub(crate) fn take(&mut self, scope: &str) -> Vec<String> {
        match self.by_scope.remove(scope) {
            Some(sections) => sections,
            None => Vec::new(),
        }
    }
}
868
869pub(crate) fn route_pure_components_per_scope<F, G>(
878 ctx: &CodegenContext,
879 is_pure: F,
880 mut emit: G,
881) -> PerScopeSections
882where
883 F: Fn(&FnDef) -> bool,
884 G: FnMut(&[&FnDef]) -> Vec<String>,
885{
886 let mut by_scope: std::collections::HashMap<String, Vec<String>> =
887 std::collections::HashMap::new();
888
889 let mut process =
890 |fns: Vec<&FnDef>,
891 scope: String,
892 by_scope: &mut std::collections::HashMap<String, Vec<String>>| {
893 let comps = crate::call_graph::ordered_fn_components(&fns, &ctx.module_prefixes);
894 let bucket = by_scope.entry(scope).or_default();
895 for comp in comps {
896 bucket.extend(emit(&comp));
897 }
898 };
899
900 for module in &ctx.modules {
901 let pure: Vec<&FnDef> = module.fn_defs.iter().filter(|fd| is_pure(fd)).collect();
902 process(pure, module.prefix.clone(), &mut by_scope);
903 }
904 let entry_pure: Vec<&FnDef> = ctx.fn_defs.iter().filter(|fd| is_pure(fd)).collect();
905 process(entry_pure, String::new(), &mut by_scope);
906
907 PerScopeSections { by_scope }
908}