use std::collections::{HashMap, HashSet};
use std::fmt::Write;
use crate::ast::logic::LogicExpr;
use crate::ast::stmt::Stmt;
use crate::intern::{Interner, Symbol};
use super::{codegen_assertion, codegen_expr};
/// Bookkeeping carried through code generation: refinement constraints,
/// recorded variable types, boxed bindings, string variables and an optional
/// set of live variables.
///
/// Constraints and boxed bindings are tracked per lexical scope (two stacks
/// that `push_scope` / `pop_scope` grow and shrink together); variable types
/// and string variables are flat and are not affected by scope changes.
pub struct RefinementContext<'a> {
/// Scope stack, innermost last. Each scope maps a variable to the pair
/// `(bound_var, predicate)` stored by `register`.
scopes: Vec<HashMap<Symbol, (Symbol, &'a LogicExpr<'a>)>>,
/// Type name recorded for each variable (not scoped).
variable_types: HashMap<Symbol, String>,
/// Scope stack, innermost last, of variables registered as boxed bindings.
boxed_binding_scopes: Vec<HashSet<Symbol>>,
/// Variables registered as string variables (not scoped).
string_vars: HashSet<Symbol>,
/// Set handed over via `set_live_vars_after` and removed again by
/// `take_live_vars_after`; `None` when nothing is pending.
live_vars_after: Option<HashSet<Symbol>>,
}
impl<'a> RefinementContext<'a> {
    /// Builds a context with a single empty root scope and no recorded
    /// variable types, string variables or live-variable set.
    pub fn new() -> Self {
        let root_constraints = HashMap::new();
        let root_boxed = HashSet::new();
        Self {
            scopes: vec![root_constraints],
            variable_types: HashMap::new(),
            boxed_binding_scopes: vec![root_boxed],
            string_vars: HashSet::new(),
            live_vars_after: None,
        }
    }

    /// Builds a context with a single empty root scope whose variable types
    /// and string variables are taken from `type_env`'s legacy views.
    pub fn from_type_env(type_env: &crate::analysis::types::TypeEnv) -> Self {
        let variable_types = type_env.to_legacy_variable_types();
        let string_vars = type_env.to_legacy_string_vars();
        Self {
            scopes: vec![HashMap::new()],
            variable_types,
            boxed_binding_scopes: vec![HashSet::new()],
            string_vars,
            live_vars_after: None,
        }
    }

    /// Stores `live`, replacing any set stored earlier.
    pub fn set_live_vars_after(&mut self, live: HashSet<Symbol>) {
        self.live_vars_after = Some(live);
    }

    /// Removes and returns the stored live-variable set, leaving `None`.
    pub fn take_live_vars_after(&mut self) -> Option<HashSet<Symbol>> {
        std::mem::take(&mut self.live_vars_after)
    }

    /// Opens a new innermost scope on both the constraint stack and the
    /// boxed-binding stack.
    pub(super) fn push_scope(&mut self) {
        self.boxed_binding_scopes.push(HashSet::new());
        self.scopes.push(HashMap::new());
    }

    /// Discards the innermost scope of both stacks (a no-op on a stack that
    /// is already empty).
    pub(super) fn pop_scope(&mut self) {
        let _ = self.boxed_binding_scopes.pop();
        let _ = self.scopes.pop();
    }

    /// Marks `var` as a boxed binding in the innermost scope, if one exists.
    pub(super) fn register_boxed_binding(&mut self, var: Symbol) {
        if let Some(innermost) = self.boxed_binding_scopes.last_mut() {
            innermost.insert(var);
        }
    }

    /// Returns `true` when `var` is a boxed binding in any open scope.
    pub(super) fn is_boxed_binding(&self, var: Symbol) -> bool {
        self.boxed_binding_scopes
            .iter()
            .rev()
            .any(|bindings| bindings.contains(&var))
    }

    /// Records `var` as a string variable.
    pub(super) fn register_string_var(&mut self, var: Symbol) {
        self.string_vars.insert(var);
    }

    /// Returns `true` when `var` has been recorded as a string variable.
    pub(super) fn is_string_var(&self, var: Symbol) -> bool {
        self.string_vars.contains(&var)
    }

    /// Borrows the full set of recorded string variables.
    pub(super) fn get_string_vars(&self) -> &HashSet<Symbol> {
        &self.string_vars
    }

    /// Attaches the constraint `(bound_var, predicate)` to `var` in the
    /// innermost scope, if one exists, overwriting any entry already there.
    pub(super) fn register(&mut self, var: Symbol, bound_var: Symbol, predicate: &'a LogicExpr<'a>) {
        if let Some(innermost) = self.scopes.last_mut() {
            innermost.insert(var, (bound_var, predicate));
        }
    }

    /// Looks `var` up from the innermost scope outwards and returns the first
    /// constraint found, so inner registrations shadow outer ones.
    pub(super) fn get_constraint(&self, var: Symbol) -> Option<(Symbol, &'a LogicExpr<'a>)> {
        self.scopes
            .iter()
            .rev()
            .find_map(|constraints| constraints.get(&var).copied())
    }

    /// Records `type_name` as the type of `var`, replacing any earlier entry.
    pub(super) fn register_variable_type(&mut self, var: Symbol, type_name: String) {
        self.variable_types.insert(var, type_name);
    }

    /// Borrows the variable-to-type-name table.
    pub(super) fn get_variable_types(&self) -> &HashMap<Symbol, String> {
        &self.variable_types
    }

    /// Mutably borrows the variable-to-type-name table.
    pub(super) fn get_variable_types_mut(&mut self) -> &mut HashMap<Symbol, String> {
        &mut self.variable_types
    }

    /// Returns the resolved name of a variable whose recorded type equals
    /// `type_name` when both are lowercased, or `None` if there is none.
    ///
    /// The table is a `HashMap`, so when several variables share the type the
    /// one returned follows the map's iteration order.
    pub(super) fn find_variable_by_type(&self, type_name: &str, interner: &Interner) -> Option<String> {
        let wanted = type_name.to_lowercase();
        self.variable_types
            .iter()
            .find(|(_, declared)| declared.to_lowercase() == wanted)
            .map(|(sym, _)| interner.resolve(*sym).to_string())
    }
}
/// Appends one `debug_assert!(...)` line, prefixed by `indent_str`, to
/// `output`.
///
/// The asserted condition is the code generated for `predicate`. When the
/// resolved name of `bound_var` differs from `var_name`, every whole-word
/// occurrence of that name in the condition is rewritten to `var_name`.
pub(super) fn emit_refinement_check(
    var_name: &str,
    bound_var: Symbol,
    predicate: &LogicExpr,
    interner: &Interner,
    indent_str: &str,
    output: &mut String,
) {
    let mut condition = codegen_assertion(predicate, interner);
    let placeholder = interner.resolve(bound_var);
    // Substitution is only needed when the predicate was written against a
    // different name than the variable being checked.
    if placeholder != var_name {
        condition = replace_word(&condition, placeholder, var_name);
    }
    writeln!(output, "{}debug_assert!({});", indent_str, condition).unwrap();
}
/// Returns a copy of `text` in which every whole word equal to `from` is
/// replaced by `to`.
///
/// A word is a maximal run of alphanumeric characters and underscores; all
/// other characters are copied through unchanged. An empty `from` never
/// matches, because only non-empty words are compared.
pub(super) fn replace_word(text: &str, from: &str, to: &str) -> String {
    fn is_word_char(c: char) -> bool {
        c.is_alphanumeric() || c == '_'
    }

    let mut out = String::with_capacity(text.len());
    let mut remaining = text;
    while !remaining.is_empty() {
        // Leading word (possibly empty when `remaining` starts with a separator).
        let word_len = remaining
            .find(|c: char| !is_word_char(c))
            .unwrap_or(remaining.len());
        let (word, after_word) = remaining.split_at(word_len);
        if !word.is_empty() {
            out.push_str(if word == from { to } else { word });
        }

        // Separator run that follows the word, copied verbatim.
        let sep_len = after_word.find(is_word_char).unwrap_or(after_word.len());
        let (separators, next) = after_word.split_at(sep_len);
        out.push_str(separators);
        remaining = next;
    }
    out
}
/// Per-variable record of the `Mount` and `Sync` statements that
/// `analyze_variable_capabilities` found for it. The default value has both
/// flags `false` and both strings `None`.
#[derive(Debug, Default)]
pub struct VariableCapabilities {
/// Set when a `Mount` statement targeting the variable was seen.
pub(super) mounted: bool,
/// Set when a `Sync` statement targeting the variable was seen.
pub(super) synced: bool,
/// Output of `codegen_expr` for the `Mount` statement's `path` expression.
pub(super) mount_path: Option<String>,
/// Output of `codegen_expr` for the `Sync` statement's `topic` expression.
pub(super) sync_topic: Option<String>,
}
/// Returns a capability table with no entries.
pub fn empty_var_caps() -> HashMap<Symbol, VariableCapabilities> {
    HashMap::default()
}
/// Collects, for each variable, whether `stmts` mounts and/or syncs it.
///
/// `Mount` and `Sync` statements are recorded directly, with the path or
/// topic expression rendered through `codegen_expr`. The branches of `If`
/// and the bodies of `While` and `Repeat` are analysed recursively and folded
/// into the result. All other statement kinds are ignored.
///
/// Statements are processed in order, so when a variable is mounted (or
/// synced) more than once, the path (or topic) seen last wins; for an `If`,
/// the else-branch is folded in after the then-branch.
pub(super) fn analyze_variable_capabilities<'a>(
    stmts: &[Stmt<'a>],
    interner: &Interner,
) -> HashMap<Symbol, VariableCapabilities> {
    /// Folds the capabilities found in a nested block into `caps`. A flag
    /// set in `nested` is set in `caps` and brings its path/topic with it;
    /// flags that are unset in `nested` leave `caps` untouched.
    fn merge_nested(
        caps: &mut HashMap<Symbol, VariableCapabilities>,
        nested: HashMap<Symbol, VariableCapabilities>,
    ) {
        for (var, cap) in nested {
            let entry = caps.entry(var).or_default();
            if cap.mounted {
                entry.mounted = true;
                entry.mount_path = cap.mount_path;
            }
            if cap.synced {
                entry.synced = true;
                entry.sync_topic = cap.sync_topic;
            }
        }
    }

    let mut caps: HashMap<Symbol, VariableCapabilities> = HashMap::new();
    // Path and topic expressions are rendered with an empty synced-variable set.
    let empty_synced = HashSet::new();

    for stmt in stmts {
        match stmt {
            Stmt::Mount { var, path } => {
                let entry = caps.entry(*var).or_default();
                entry.mounted = true;
                entry.mount_path = Some(codegen_expr(path, interner, &empty_synced));
            }
            Stmt::Sync { var, topic } => {
                let entry = caps.entry(*var).or_default();
                entry.synced = true;
                entry.sync_topic = Some(codegen_expr(topic, interner, &empty_synced));
            }
            Stmt::If { then_block, else_block, .. } => {
                merge_nested(&mut caps, analyze_variable_capabilities(then_block, interner));
                if let Some(else_b) = else_block {
                    merge_nested(&mut caps, analyze_variable_capabilities(else_b, interner));
                }
            }
            Stmt::While { body, .. } | Stmt::Repeat { body, .. } => {
                merge_nested(&mut caps, analyze_variable_capabilities(body, interner));
            }
            _ => {}
        }
    }
    caps
}