1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
//! Compilation of Let bindings for local variable definitions.
use anyhow::Result;
use tensorlogic_ir::{EinsumGraph, TLExpr};
use crate::context::{CompileState, CompilerContext};
use super::compile_expr;
/// Compile let binding: let var = value in body
///
/// This compiles a let expression by:
/// 1. Compiling the value expression
/// 2. Binding the variable name to the value's tensor
/// 3. Compiling the body expression (which can reference the variable)
/// 4. Restoring the previous variable binding (if any)
pub(crate) fn compile_let(
var: &str,
value: &TLExpr,
body: &TLExpr,
ctx: &mut CompilerContext,
graph: &mut EinsumGraph,
) -> Result<CompileState> {
// Compile the value expression
let value_state = compile_expr(value, ctx, graph)?;
// Save the previous binding for this variable (if any)
let prev_binding = ctx.var_to_domain.get(var).cloned();
let prev_axis = ctx.var_to_axis.get(var).copied();
// Bind the variable to the value's tensor state
// For let bindings, we treat the variable as bound to the same domain/axis as the value
if let Some(domain) = infer_domain_from_axes(&value_state.axes, ctx) {
ctx.var_to_domain.insert(var.to_string(), domain);
}
// If the value has axes, bind the variable to the first axis
// (this is a simplification; more complex bindings might need different logic)
if let Some(first_axis) = value_state.axes.chars().next() {
ctx.var_to_axis.insert(var.to_string(), first_axis);
}
// Store the tensor binding
ctx.let_bindings
.insert(var.to_string(), value_state.tensor_idx);
// Compile the body expression (which can reference the variable)
let body_state = compile_expr(body, ctx, graph)?;
// Restore the previous binding
ctx.let_bindings.remove(var);
match prev_binding {
Some(domain) => {
ctx.var_to_domain.insert(var.to_string(), domain);
}
None => {
ctx.var_to_domain.remove(var);
}
}
match prev_axis {
Some(axis) => {
ctx.var_to_axis.insert(var.to_string(), axis);
}
None => {
ctx.var_to_axis.remove(var);
}
}
Ok(body_state)
}
/// Infer domain from axes by looking up the first axis in the context
fn infer_domain_from_axes(axes: &str, ctx: &CompilerContext) -> Option<String> {
if axes.is_empty() {
return None;
}
// Get the first axis character
let first_axis = axes.chars().next()?;
// Find a variable that maps to this axis
for (var, &var_axis) in &ctx.var_to_axis {
if var_axis == first_axis {
if let Some(domain) = ctx.var_to_domain.get(var) {
return Some(domain.clone());
}
}
}
None
}