use anyhow::{bail, Result};
use std::collections::HashMap;
use crate::config::CompilationConfig;
pub use tensorlogic_adapters::DomainInfo;
/// State shared across compilation: registered domains, variable-to-domain
/// bindings, variable-to-axis assignments, and the active [`CompilationConfig`].
#[derive(Debug, Clone)]
pub struct CompilerContext {
    /// Registered domains, keyed by domain name.
    pub domains: HashMap<String, DomainInfo>,
    /// Maps each bound variable name to the name of its domain.
    pub var_to_domain: HashMap<String, String>,
    /// Maps each variable name to its single-character axis label.
    pub var_to_axis: HashMap<String, char>,
    /// Label handed out by the next call to `assign_axis` for an unseen variable.
    next_axis: char,
    /// Numeric suffix used by the next call to `fresh_temp`.
    temp_counter: usize,
    /// Options controlling compilation.
    pub config: CompilationConfig,
    /// `Some("imported")` when the context was built from a symbol table, `None` otherwise.
    /// NOTE(review): appears to be a marker only — confirm whether anything reads it.
    symbol_table_ref: Option<String>,
    /// NOTE(review): not populated by the constructors beyond an empty map; presumably
    /// maps let-bound names to tensor indices — confirm against the code that fills it.
    pub let_bindings: HashMap<String, usize>,
}
impl CompilerContext {
    /// Creates an empty context that uses the default [`CompilationConfig`].
    ///
    /// No domains or variables are registered. Axis labels start at `'a'`, and
    /// temporary names start at `temp_0`.
    pub fn new() -> Self {
        Self::with_config(CompilationConfig::default())
    }

    /// Creates an empty context that uses the given compilation `config`.
    pub fn with_config(config: CompilationConfig) -> Self {
        CompilerContext {
            domains: HashMap::new(),
            var_to_domain: HashMap::new(),
            var_to_axis: HashMap::new(),
            next_axis: 'a',
            temp_counter: 0,
            config,
            symbol_table_ref: None,
            let_bindings: HashMap::new(),
        }
    }

    /// Builds a context pre-populated from `table`, using the default config.
    ///
    /// Each domain is registered under its own `name` field. Every variable binding
    /// in `table.variables` is copied as-is. Bindings are not checked against
    /// `table.domains`, and no axes are assigned.
    pub fn from_symbol_table(table: &tensorlogic_adapters::SymbolTable) -> Self {
        let mut ctx = Self::new();
        ctx.domains.extend(
            table
                .domains
                .values()
                .map(|domain| (domain.name.clone(), domain.clone())),
        );
        ctx.var_to_domain.extend(
            table
                .variables
                .iter()
                .map(|(var, domain)| (var.clone(), domain.clone())),
        );
        ctx.symbol_table_ref = Some("imported".to_string());
        ctx
    }

    /// Like [`Self::from_symbol_table`], but uses `config` instead of the default.
    pub fn from_symbol_table_with_config(
        table: &tensorlogic_adapters::SymbolTable,
        config: CompilationConfig,
    ) -> Self {
        let mut ctx = Self::from_symbol_table(table);
        ctx.config = config;
        ctx
    }

    /// Registers a domain called `name` with the given `cardinality`.
    ///
    /// Any domain already registered under that name is replaced.
    pub fn add_domain(&mut self, name: impl Into<String>, cardinality: usize) {
        let name = name.into();
        self.domains
            .insert(name.clone(), DomainInfo::new(name, cardinality));
    }

    /// Registers a fully built [`DomainInfo`] under its `name`, replacing any existing entry.
    pub fn add_domain_info(&mut self, domain: DomainInfo) {
        self.domains.insert(domain.name.clone(), domain);
    }

    /// Binds variable `var` to the registered domain `domain`.
    ///
    /// Rebinding a variable replaces its previous domain.
    ///
    /// # Errors
    /// Returns an error if no domain named `domain` has been registered.
    pub fn bind_var(&mut self, var: &str, domain: &str) -> Result<()> {
        if !self.domains.contains_key(domain) {
            bail!("Domain '{}' not found", domain);
        }
        self.var_to_domain
            .insert(var.to_string(), domain.to_string());
        Ok(())
    }

    /// Returns the axis label for `var`, allocating the next unused label on first use.
    ///
    /// Labels are issued in code-point order starting at `'a'`, so the first 26
    /// distinct variables receive `'a'..='z'`. Every variable gets a distinct label.
    /// NOTE(review): the 27th and later variables receive non-alphabetic characters
    /// (`'{'`, `'|'`, ...). Confirm that the downstream einsum consumer accepts them.
    pub fn assign_axis(&mut self, var: &str) -> char {
        if let Some(&axis) = self.var_to_axis.get(var) {
            return axis;
        }
        let axis = self.next_axis;
        self.var_to_axis.insert(var.to_string(), axis);
        self.next_axis = Self::successor_axis(axis);
        axis
    }

    /// Returns the valid `char` that immediately follows `axis` in code-point order.
    ///
    /// Uses `u32` arithmetic so labels never wrap. An 8-bit increment overflows past
    /// `'\u{ff}'`: debug builds panic, and release builds wrap and reissue labels that
    /// are already in use. The UTF-16 surrogate range holds no `char`s, so it is skipped.
    ///
    /// # Panics
    /// Panics only if `axis` is `char::MAX`, which would take over a million variables.
    fn successor_axis(axis: char) -> char {
        const SURROGATE_START: u32 = 0xD800;
        const SURROGATE_END: u32 = 0xDFFF;
        let next = match u32::from(axis) + 1 {
            SURROGATE_START..=SURROGATE_END => SURROGATE_END + 1,
            n => n,
        };
        char::from_u32(next).expect("axis labels exhausted: no char follows char::MAX")
    }

    /// Returns a new temporary name of the form `temp_<n>`.
    ///
    /// `n` starts at 0 and increases by one on every call for this context.
    pub fn fresh_temp(&mut self) -> String {
        let name = format!("temp_{}", self.temp_counter);
        self.temp_counter += 1;
        name
    }

    /// Concatenates the axis labels of the variables in `terms`, in order.
    ///
    /// Non-variable terms are skipped and contribute no label.
    ///
    /// # Errors
    /// Returns an error for the first variable that has not been given an axis
    /// via [`Self::assign_axis`].
    pub fn get_axes(&self, terms: &[tensorlogic_ir::Term]) -> Result<String> {
        use anyhow::anyhow;
        use tensorlogic_ir::Term;
        terms
            .iter()
            .filter_map(|term| match term {
                Term::Var(v) => Some(v),
                _ => None,
            })
            .map(|v| {
                self.var_to_axis
                    .get(v)
                    .copied()
                    .ok_or_else(|| anyhow!("Variable '{}' not assigned an axis", v))
            })
            .collect()
    }
}
impl Default for CompilerContext {
fn default() -> Self {
Self::new()
}
}
/// Crate-internal record pairing a tensor index with an axis-label string.
///
/// NOTE(review): presumably this is the result of compiling one sub-expression,
/// i.e. the tensor holding its value and that tensor's axes. Confirm this against
/// the code that constructs it.
#[derive(Debug)]
pub(crate) struct CompileState {
    /// Index of the tensor. Assumed to index the graph's tensor list; confirm.
    pub tensor_idx: usize,
    /// Axis labels, one `char` per axis. Assumed to match the format returned by
    /// `CompilerContext::get_axes`; confirm.
    pub axes: String,
}