use crate::field::Goldilocks;
use crate::signal::SpectralSignal;
use crate::vm::SpectralOp;
use std::collections::HashMap;
use tracing::info;
/// A transformation applied to an [`IRProgram`] in place.
pub trait OptimizationPass {
    /// Human-readable pass name, used in log output.
    fn name(&self) -> &str;

    /// Rewrites `program` in place and returns `true` if anything was modified.
    fn apply(&self, program: &mut IRProgram) -> bool;
}
/// Removes top-level assignments whose target variable is never read.
///
/// A variable counts as read if it is a function parameter or appears in any
/// expression of the function body, including nested `if`/`while` bodies.
/// Only statements directly in the function body are removal candidates;
/// assignments nested inside `if`/`while` bodies are always kept.
pub struct DeadCodeElimination;

impl OptimizationPass for DeadCodeElimination {
    fn apply(&self, program: &mut IRProgram) -> bool {
        let mut removed_any = false;
        for function in program.functions.iter_mut() {
            // Parameters are always live, plus every name read anywhere in the body.
            let mut live: std::collections::HashSet<String> =
                function.params.iter().cloned().collect();
            function
                .body
                .iter()
                .for_each(|stmt| collect_used_vars_stmt(stmt, &mut live));

            let before = function.body.len();
            function.body.retain(|stmt| is_used_assignment(stmt, &live));
            removed_any |= function.body.len() != before;
        }
        removed_any
    }

    fn name(&self) -> &str {
        "Dead Code Elimination"
    }
}
pub struct ConstantFolding;
impl OptimizationPass for ConstantFolding {
fn apply(&self, program: &mut IRProgram) -> bool {
let mut changed = false;
for func in &mut program.functions {
for stmt in &mut func.body {
if let IRStmt::Assign(var, expr) = stmt {
if let Some(const_val) = evaluate_constant_expr(expr) {
*expr = IRExpr::Const(const_val);
changed = true;
}
} else if let IRStmt::Return(expr) = stmt {
if let Some(const_val) = evaluate_constant_expr(expr) {
*expr = IRExpr::Const(const_val);
changed = true;
}
}
}
}
changed
}
fn name(&self) -> &str {
"Constant Folding"
}
}
pub struct CommonSubexpressionElimination;
impl OptimizationPass for CommonSubexpressionElimination {
fn apply(&self, program: &mut IRProgram) -> bool {
let mut changed = false;
for func in &mut program.functions {
let mut expr_cache: std::collections::HashMap<String, String> = std::collections::HashMap::new();
let mut var_counter = 0;
for stmt in &mut func.body {
if let IRStmt::Assign(var, expr) = stmt {
let expr_key = format!("{:?}", expr);
if let Some(existing_var) = expr_cache.get(&expr_key) {
*expr = IRExpr::Var(existing_var.clone());
changed = true;
} else {
if is_complex_expr(expr) {
expr_cache.insert(expr_key, var.clone());
}
}
}
}
}
changed
}
fn name(&self) -> &str {
"Common Subexpression Elimination"
}
}
/// Expression node of the intermediate representation.
///
/// NOTE: `CommonSubexpressionElimination` uses this type's `Debug` output as its
/// cache key. Renaming variants or changing their payloads changes which
/// expressions the pass treats as identical.
#[derive(Debug, Clone)]
pub enum IRExpr {
    /// Integer literal.
    Const(i64),
    /// Reference to a named variable or function parameter.
    Var(String),
    /// Binary operation `left <op> right`; the operator is a VM opcode.
    BinOp(Box<IRExpr>, SpectralOp, Box<IRExpr>),
    /// Call to a named function. `CircuitCompiler` only handles `add` and `fib`.
    Call(String, Vec<IRExpr>),
    /// Conditional expression `(condition, then, else)`; the code generator panics on it.
    If(Box<IRExpr>, Box<IRExpr>, Box<IRExpr>),
}
/// Statement node of the intermediate representation.
#[derive(Debug, Clone)]
pub enum IRStmt {
    /// `name = expr`.
    Assign(String, IRExpr),
    /// Return the value of the expression.
    Return(IRExpr),
    /// `if condition { then } else { else }`; the else branch is optional.
    If(IRExpr, Vec<IRStmt>, Option<Vec<IRStmt>>),
    /// `while condition { body }`.
    While(IRExpr, Vec<IRStmt>),
    /// Call used as a statement; any result is discarded.
    Call(String, Vec<IRExpr>),
}
/// A single function in the IR.
#[derive(Debug, Clone)]
pub struct IRFunction {
    /// Function name.
    pub name: String,
    /// Parameter names; the compiler binds each one to a register before compiling the body.
    pub params: Vec<String>,
    /// Statements executed in order.
    pub body: Vec<IRStmt>,
    /// Declared return type as free-form text (e.g. `"int"`); not inspected in this module.
    pub return_type: String,
}
/// A whole IR program.
#[derive(Debug, Clone)]
pub struct IRProgram {
    /// All functions; `CircuitCompiler::compile` lowers only the first one.
    pub functions: Vec<IRFunction>,
    /// Global name -> expression (presumably the initializer; TODO confirm).
    /// Not read by the passes or the compiler in this module.
    pub globals: HashMap<String, IRExpr>,
}
/// Mutable code-generation state: register bindings, emitted instructions and labels.
pub struct CompilationContext {
    /// Variable name -> register index assigned by [`CompilationContext::alloc_reg`].
    pub registers: HashMap<String, usize>,
    /// Next unused register index (starts at 1).
    pub next_reg: usize,
    /// Packed instruction words, in emission order.
    pub instructions: Vec<u64>,
    /// Counter that keeps names from [`CompilationContext::new_label`] unique.
    pub label_counter: usize,
    /// Label name -> index into `instructions` where the label was placed.
    /// Nothing resolves these yet (see [`CompilationContext::emit_jump`]).
    pub labels: HashMap<String, usize>,
}

impl Default for CompilationContext {
    fn default() -> Self {
        Self::new()
    }
}

impl CompilationContext {
    /// Width mask of each operand field in a packed instruction word.
    const OPERAND_MASK: u64 = 0xFF;

    /// Creates an empty context. Register numbering starts at 1 (`next_reg = 1`).
    pub fn new() -> Self {
        Self {
            registers: HashMap::new(),
            next_reg: 1,
            instructions: Vec::new(),
            label_counter: 0,
            labels: HashMap::new(),
        }
    }

    /// Returns the register bound to `var`, binding it to a fresh register on first use.
    pub fn alloc_reg(&mut self, var: &str) -> usize {
        if let Some(&reg) = self.registers.get(var) {
            return reg;
        }
        let reg = self.next_reg;
        self.next_reg += 1;
        self.registers.insert(var.to_string(), reg);
        reg
    }

    /// Returns a new label name of the form `"{prefix}_{n}"`, unique within this context.
    pub fn new_label(&mut self, prefix: &str) -> String {
        let label = format!("{}_{}", prefix, self.label_counter);
        self.label_counter += 1;
        label
    }

    /// Places `label` at the position of the next instruction to be emitted.
    pub fn set_label(&mut self, label: &str) {
        self.labels.insert(label.to_string(), self.instructions.len());
    }

    /// Packs one instruction word.
    ///
    /// Layout: the opcode sits in bits 16 and up, `arg1` in bits 8..16, `arg2` in bits 0..8.
    /// Operands are truncated to their low 8 bits. Without the mask, an operand >= 256 (or a
    /// negative immediate cast to `usize`) would overwrite neighbouring fields, including the opcode.
    fn encode(op: SpectralOp, arg1: usize, arg2: usize) -> u64 {
        ((op as u64) << 16)
            | (((arg1 as u64) & Self::OPERAND_MASK) << 8)
            | ((arg2 as u64) & Self::OPERAND_MASK)
    }

    /// Appends the register-register instruction `op arg1, arg2`.
    pub fn emit(&mut self, op: SpectralOp, arg1: usize, arg2: usize) {
        self.instructions.push(Self::encode(op, arg1, arg2));
    }

    /// Appends the register-immediate instruction `op arg1, imm` (`imm` is truncated to 8 bits).
    pub fn emit_imm(&mut self, op: SpectralOp, arg1: usize, imm: usize) {
        self.instructions.push(Self::encode(op, arg1, imm));
    }

    /// Placeholder for an unconditional jump to `_label`.
    ///
    /// TODO: label resolution is not implemented. This emits `S_HALT` and ignores the label,
    /// so control flow built on it presumably stops execution instead of jumping.
    pub fn emit_jump(&mut self, _label: &str) {
        self.emit(SpectralOp::S_HALT, 0, 0);
    }

    /// Emits `S_BEQ condition_reg, 0`, then a jump to `target_label` via the
    /// [`emit_jump`](Self::emit_jump) placeholder.
    pub fn emit_branch(&mut self, condition_reg: usize, target_label: &str) {
        self.emit(SpectralOp::S_BEQ, condition_reg, 0);
        self.emit_jump(target_label);
    }
}
pub struct CircuitCompiler {
context: CompilationContext,
}
impl CircuitCompiler {
pub fn new() -> Self {
Self {
context: CompilationContext::new(),
}
}
pub fn compile_expr(&mut self, expr: &IRExpr) -> usize {
match expr {
IRExpr::Const(val) => {
let reg = self.context.next_reg;
self.context.next_reg += 1;
self.context.emit_imm(SpectralOp::S_ADDI, reg, *val as usize);
reg
}
IRExpr::Var(name) => {
self.context.alloc_reg(name)
}
IRExpr::BinOp(left, op, right) => {
let left_reg = self.compile_expr(left);
let right_reg = self.compile_expr(right);
let result_reg = self.context.next_reg;
self.context.next_reg += 1;
match op {
SpectralOp::S_ADD => {
self.context.emit(SpectralOp::S_ADD, result_reg, left_reg);
self.context.emit(SpectralOp::S_ADD, result_reg, right_reg);
}
SpectralOp::S_SUB => {
self.context.emit(SpectralOp::S_SUB, result_reg, left_reg);
}
SpectralOp::S_MUL => {
self.context.emit(SpectralOp::S_MUL, result_reg, left_reg);
}
_ => {
self.context.emit(*op, result_reg, left_reg);
}
}
result_reg
}
IRExpr::Call(func_name, args) => {
match func_name.as_str() {
"add" => {
let left_reg = self.compile_expr(&args[0]);
let right_reg = self.compile_expr(&args[1]);
let result_reg = self.context.next_reg;
self.context.next_reg += 1;
self.context.emit(SpectralOp::S_ADD, result_reg, left_reg);
result_reg
}
"fib" => {
self.compile_fib(&args[0])
}
_ => panic!("Unknown function: {}", func_name),
}
}
IRExpr::If(_, _, _) => {
panic!("Conditional expressions not implemented");
}
}
}
fn compile_fib(&mut self, n_expr: &IRExpr) -> usize {
let n_reg = self.compile_expr(n_expr);
let result_reg = self.context.next_reg;
self.context.next_reg += 1;
self.context.emit(SpectralOp::S_ADD, result_reg, n_reg);
result_reg
}
pub fn compile_stmt(&mut self, stmt: &IRStmt) {
match stmt {
IRStmt::Assign(var, expr) => {
let expr_reg = self.compile_expr(expr);
let var_reg = self.context.alloc_reg(var);
}
IRStmt::Return(expr) => {
let expr_reg = self.compile_expr(expr);
let result_reg = 0;
self.context.emit(SpectralOp::S_ADD, result_reg, expr_reg);
}
IRStmt::If(condition, then_branch, else_branch) => {
self.compile_if(condition, then_branch, else_branch.as_deref());
}
IRStmt::While(condition, body) => {
self.compile_while(condition, body);
}
IRStmt::Call(func_name, args) => {
if func_name == "fib" && args.len() == 1 {
self.compile_fib(&args[0]);
}
}
}
}
pub fn compile_function(&mut self, func: &IRFunction) -> Vec<u64> {
for param in &func.params {
self.context.alloc_reg(param);
}
for stmt in &func.body {
self.compile_stmt(stmt);
}
self.context.emit(SpectralOp::S_HALT, 0, 0);
self.context.instructions.clone()
}
pub fn to_spectral_signal(&self, instructions: &[u64]) -> SpectralSignal {
let mut values = Vec::new();
for &instr in instructions {
let op = (instr >> 16) & 0xFF;
let arg1 = (instr >> 8) & 0xFF;
let arg2 = instr & 0xFF;
values.push(op as i64);
values.push(arg1 as i64);
values.push(arg2 as i64);
}
while values.len() & (values.len() - 1) != 0 {
values.push(0);
}
SpectralSignal::new(values)
}
fn compile_if(&mut self, condition: &IRExpr, then_branch: &[IRStmt], else_branch: Option<&[IRStmt]>) {
let cond_reg = self.compile_expr(condition);
let then_label = self.context.new_label("then");
let else_label = self.context.new_label("else");
let end_label = self.context.new_label("endif");
self.context.emit(SpectralOp::S_BEQ, cond_reg, 0); self.context.emit_jump(&then_label);
if let Some(else_stmts) = else_branch {
self.context.set_label(&else_label);
for stmt in else_stmts {
self.compile_stmt(stmt);
}
} else {
self.context.set_label(&else_label);
}
self.context.emit_jump(&end_label);
self.context.set_label(&then_label);
for stmt in then_branch {
self.compile_stmt(stmt);
}
self.context.set_label(&end_label);
}
fn compile_while(&mut self, condition: &IRExpr, body: &[IRStmt]) {
let start_label = self.context.new_label("while_start");
let body_label = self.context.new_label("while_body");
let end_label = self.context.new_label("while_end");
self.context.set_label(&start_label);
let cond_reg = self.compile_expr(condition);
self.context.emit(SpectralOp::S_BEQ, cond_reg, 0);
self.context.emit_jump(&end_label);
self.context.set_label(&body_label);
for stmt in body {
self.compile_stmt(stmt);
}
self.context.emit_jump(&start_label);
self.context.set_label(&end_label);
}
pub fn compile(&mut self, program: &IRProgram) -> SpectralSignal {
let mut optimized_program = program.clone();
let passes: Vec<Box<dyn OptimizationPass>> = vec![
Box::new(ConstantFolding),
Box::new(DeadCodeElimination),
Box::new(CommonSubexpressionElimination),
];
let mut any_change = false;
for pass in &passes {
info!(" ├─ Running {}...", pass.name());
if pass.apply(&mut optimized_program) {
any_change = true;
info!(" └─ Applied changes");
} else {
info!(" └─ No changes");
}
}
if any_change {
info!(" ├─ Optimizations completed - {} passes applied", passes.len());
}
if let Some(func) = optimized_program.functions.first() {
let instructions = self.compile_function(func);
self.to_spectral_signal(&instructions)
} else {
SpectralSignal::new(vec![0; 8]) }
}
}
/// Adds every variable name read anywhere in `stmt`, including nested bodies, to `used_vars`.
///
/// Assignment targets are not counted as uses; only expressions are scanned.
fn collect_used_vars_stmt(stmt: &IRStmt, used_vars: &mut std::collections::HashSet<String>) {
    match stmt {
        IRStmt::Assign(_, value) | IRStmt::Return(value) => {
            collect_used_vars_expr(value, used_vars)
        }
        IRStmt::If(cond, then_stmts, else_stmts) => {
            collect_used_vars_expr(cond, used_vars);
            then_stmts
                .iter()
                .chain(else_stmts.iter().flatten())
                .for_each(|nested| collect_used_vars_stmt(nested, used_vars));
        }
        IRStmt::While(cond, loop_body) => {
            collect_used_vars_expr(cond, used_vars);
            loop_body
                .iter()
                .for_each(|nested| collect_used_vars_stmt(nested, used_vars));
        }
        IRStmt::Call(_, call_args) => call_args
            .iter()
            .for_each(|arg| collect_used_vars_expr(arg, used_vars)),
    }
}
/// Adds every variable name referenced anywhere inside `expr` to `used_vars`.
fn collect_used_vars_expr(expr: &IRExpr, used_vars: &mut std::collections::HashSet<String>) {
    match expr {
        IRExpr::Const(_) => {}
        IRExpr::Var(name) => {
            used_vars.insert(name.clone());
        }
        IRExpr::BinOp(lhs, _, rhs) => {
            for side in [lhs, rhs] {
                collect_used_vars_expr(side, used_vars);
            }
        }
        IRExpr::Call(_, call_args) => call_args
            .iter()
            .for_each(|arg| collect_used_vars_expr(arg, used_vars)),
        IRExpr::If(cond, when_true, when_false) => {
            for part in [cond, when_true, when_false] {
                collect_used_vars_expr(part, used_vars);
            }
        }
    }
}
/// Returns `false` only for an assignment whose target is absent from `used_vars`.
/// Every other kind of statement counts as live.
fn is_used_assignment(stmt: &IRStmt, used_vars: &std::collections::HashSet<String>) -> bool {
    if let IRStmt::Assign(target, _) = stmt {
        used_vars.contains(target)
    } else {
        true
    }
}
/// Evaluates `expr` at compile time when all of its leaves are constants.
///
/// Supported forms:
/// - `S_ADD`, `S_SUB`, `S_MUL` and `S_DIV` using wrapping `i64` arithmetic;
///   division by zero yields `None`;
/// - `fib(n)` for a constant `n` in `0..=20`.
///
/// Everything else yields `None`.
fn evaluate_constant_expr(expr: &IRExpr) -> Option<i64> {
    match expr {
        IRExpr::Const(value) => Some(*value),
        IRExpr::BinOp(lhs, op, rhs) => {
            let (a, b) = (evaluate_constant_expr(lhs)?, evaluate_constant_expr(rhs)?);
            match op {
                SpectralOp::S_ADD => Some(a.wrapping_add(b)),
                SpectralOp::S_SUB => Some(a.wrapping_sub(b)),
                SpectralOp::S_MUL => Some(a.wrapping_mul(b)),
                SpectralOp::S_DIV if b != 0 => Some(a.wrapping_div(b)),
                _ => None,
            }
        }
        IRExpr::Call(callee, args) if callee == "fib" && args.len() == 1 => {
            let n = evaluate_constant_expr(&args[0])?;
            if (0..=20).contains(&n) {
                Some(fibonacci(n as u64))
            } else {
                None
            }
        }
        _ => None,
    }
}
/// Returns the `n`-th Fibonacci number (`F(0) = 0`, `F(1) = 1`), wrapping on `i64` overflow.
fn fibonacci(n: u64) -> i64 {
    // Walk the pair (F(k), F(k+1)) forward n times; its first element is then F(n).
    (0..n)
        .fold((0i64, 1i64), |(current, next), _| {
            (next, current.wrapping_add(next))
        })
        .0
}
/// Reports whether `expr` is worth caching for common-subexpression elimination.
///
/// Operations, calls and conditionals qualify; plain literals and variable
/// references do not. The match is kept exhaustive so a new variant forces a decision.
fn is_complex_expr(expr: &IRExpr) -> bool {
    match expr {
        IRExpr::BinOp(..) | IRExpr::Call(..) | IRExpr::If(..) => true,
        IRExpr::Const(_) | IRExpr::Var(_) => false,
    }
}
/// Builds a one-function sample program whose `fib` function returns `fib(n)`.
///
/// The body is `return fib(n)` with `n` embedded as a literal. When compiled through
/// `CircuitCompiler::compile`, `ConstantFolding` replaces it with the literal result for
/// `0 <= n <= 20`. Other values reach the placeholder `fib` lowering unchanged.
pub fn create_fib_program(n: i64) -> IRProgram {
    IRProgram {
        functions: vec![IRFunction {
            name: "fib".to_string(),
            params: vec!["n".to_string()],
            body: vec![IRStmt::Return(IRExpr::Call(
                "fib".to_string(),
                vec![IRExpr::Const(n)],
            ))],
            return_type: "int".to_string(),
        }],
        globals: HashMap::new(),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Compiling the sample program yields a non-empty signal of power-of-two length.
    #[test]
    fn test_simple_compilation() {
        let program = create_fib_program(10);
        let mut compiler = CircuitCompiler::new();
        let signal = compiler.compile(&program);
        let len = signal.values.len();
        assert!(len > 0);
        assert!(len.is_power_of_two());
    }

    /// A fresh compiler hands out register 1 first, because `next_reg` starts at 1.
    #[test]
    fn test_constant_compilation() {
        let mut compiler = CircuitCompiler::new();
        let first_reg = compiler.compile_expr(&IRExpr::Const(42));
        assert_eq!(first_reg, 1);
    }
}