use anyhow::{anyhow, Result};
use tensorlogic_ir::{EinsumGraph, MatchPattern, TLExpr};
use crate::context::{CompileState, CompilerContext};
use super::compile_expr;
/// Compiles a `Match` expression.
///
/// The match is first rewritten into a chain of nested `TLExpr::IfThenElse`
/// nodes, one per non-wildcard arm. The rewritten expression is then handed
/// to `compile_expr`.
///
/// Each condition in the chain holds its own clone of `scrutinee`. As a
/// result, `compile_expr` may compile the scrutinee once per non-wildcard
/// arm. NOTE(review): confirm whether `compile_expr` deduplicates this.
///
/// # Errors
/// Returns an error in these cases:
/// - `arms` is empty.
/// - The final arm's pattern is not `MatchPattern::Wildcard`.
/// - Lowering fails, for example because a wildcard appears in a non-tail arm.
/// - Compiling the lowered expression fails.
pub(crate) fn compile_match(
    scrutinee: &TLExpr,
    arms: &[(MatchPattern, Box<TLExpr>)],
    ctx: &mut CompilerContext,
    graph: &mut EinsumGraph,
) -> Result<CompileState> {
    match arms.last() {
        None => Err(anyhow!("Match expression has no arms")),
        // The wildcard tail becomes the innermost `else` of the lowered chain.
        Some((MatchPattern::Wildcard, _)) => {
            let if_chain = lower_match_to_if_chain(scrutinee, arms)?;
            compile_expr(&if_chain, ctx, graph)
        }
        Some(_) => Err(anyhow!(
            "Last arm of Match must be Wildcard — validate before compiling"
        )),
    }
}
/// Rewrites match arms into nested `TLExpr::IfThenElse` expressions.
///
/// The body of the final arm is the innermost `else` branch. This function
/// does not check the final arm's pattern; callers are expected to ensure it
/// is the wildcard. Every preceding arm wraps the chain built so far, so the
/// first arm ends up as the outermost test:
/// `if c0 { b0 } else if c1 { b1 } else { fallback }`.
///
/// # Errors
/// - Returns an error if `arms` is empty.
/// - Propagates any error from `pattern_condition`, such as a wildcard in a
///   non-tail arm.
fn lower_match_to_if_chain(
    scrutinee: &TLExpr,
    arms: &[(MatchPattern, Box<TLExpr>)],
) -> Result<TLExpr> {
    let (fallback_arm, guarded_arms) = match arms.split_last() {
        Some(parts) => parts,
        None => return Err(anyhow!("Empty arms in Match")),
    };
    let fallback = TLExpr::clone(&fallback_arm.1);

    // Walk from the last guarded arm to the first, so each step wraps the
    // previously built chain as its `else` branch.
    guarded_arms.iter().rev().try_fold(
        fallback,
        |else_expr, (pattern, then_expr)| -> Result<TLExpr> {
            let condition = pattern_condition(scrutinee, pattern)?;
            Ok(TLExpr::IfThenElse {
                condition: Box::new(condition),
                then_branch: Box::new(TLExpr::clone(then_expr)),
                else_branch: Box::new(else_expr),
            })
        },
    )
}
/// Builds the equality test `scrutinee == <literal>` for one match pattern.
///
/// A number pattern is compared against a `TLExpr::Constant`. A symbol
/// pattern is compared against a `TLExpr::SymbolLiteral`.
///
/// # Errors
/// Returns an error for `MatchPattern::Wildcard`. A wildcard has no literal
/// to compare against and is only valid as the final arm.
fn pattern_condition(scrutinee: &TLExpr, pat: &MatchPattern) -> Result<TLExpr> {
    let literal = match pat {
        MatchPattern::Wildcard => None,
        MatchPattern::ConstNumber(value) => Some(TLExpr::Constant(*value)),
        MatchPattern::ConstSymbol(name) => Some(TLExpr::SymbolLiteral(name.clone())),
    };
    literal
        .map(|rhs| TLExpr::Eq(Box::new(scrutinee.clone()), Box::new(rhs)))
        .ok_or_else(|| anyhow!("Wildcard pattern in non-tail position is invalid"))
}
/// Compiles a symbol literal by registering a tensor named `sym_<symbol>`.
///
/// The resulting `CompileState` carries an empty `axes` string. Presumably
/// this means the tensor is treated as rank-0; TODO confirm with the
/// `CompileState` consumers. The compiler context is not consulted.
///
/// NOTE(review): whether repeated calls with the same `symbol` reuse one
/// tensor depends on `EinsumGraph::add_tensor`, which is not visible here.
pub(crate) fn compile_symbol_literal(
    symbol: &str,
    _ctx: &mut CompilerContext,
    graph: &mut EinsumGraph,
) -> Result<CompileState> {
    let name: String = ["sym_", symbol].concat();
    Ok(CompileState {
        tensor_idx: graph.add_tensor(&name),
        axes: String::new(),
    })
}