#![forbid(unsafe_code)]
/// A single balanced-ternary digit.
///
/// The explicit discriminants equal the digit's numeric value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Trit {
    /// The digit -1.
    Neg = -1,
    /// The digit 0.
    Zero = 0,
    /// The digit +1.
    Pos = 1,
}

impl Trit {
    /// Converts an integer into a trit.
    ///
    /// Returns `None` for any value outside `-1..=1`.
    pub fn from_i8(v: i8) -> Option<Self> {
        let trit = match v {
            1 => Trit::Pos,
            0 => Trit::Zero,
            -1 => Trit::Neg,
            _ => return None,
        };
        Some(trit)
    }

    /// Returns the numeric value of this trit: `-1`, `0` or `1`.
    pub fn to_i8(self) -> i8 {
        match self {
            Trit::Neg => -1,
            Trit::Zero => 0,
            Trit::Pos => 1,
        }
    }
}
/// One instruction of the ternary VM.
///
/// NOTE(review): execution semantics are defined by an interpreter that is not
/// visible here. The passes in this file treat the machine as stack-based
/// (e.g. `LoadConst a; LoadConst b; Add` is folded into one `LoadConst`) —
/// confirm against the interpreter.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Op {
/// Loads the given constant trit.
LoadConst(Trit),
/// Presumably reads the numbered register (the passes bind the operand as `reg`).
Load(usize),
/// Presumably writes the numbered register — confirm whether it pops its operand.
Store(usize),
/// Adds two trits; `constant_folding` assumes the sum saturates to `-1..=1`.
Add,
/// Multiplies two trits (`constant_folding` folds two constants to their product).
Mul,
/// Negates a trit (`constant_folding` folds `LoadConst a; Neg` to `-a`).
Neg,
/// No operation; `dead_trit_elimination` removes it.
Nop,
/// Unconditional jump. The operand is an absolute instruction index:
/// `detect_loops` treats a target at or before the jump's own index as a backward jump.
Jump(usize),
/// Conditional jump to an absolute instruction index, presumably taken when the
/// tested value is zero — confirm against the interpreter.
JumpIfZero(usize),
/// Conditional jump, presumably taken when the tested value is negative.
JumpIfNeg(usize),
/// Conditional jump, presumably taken when the tested value is positive.
JumpIfPos(usize),
/// Input instruction (semantics defined by the interpreter).
Input,
/// Output instruction (semantics defined by the interpreter).
Output,
/// Halt instruction (semantics defined by the interpreter).
Halt,
}
/// A sequence of [`Op`]s; jump operands are indices into `instructions`.
#[derive(Debug, Clone)]
pub struct Program {
    /// The instructions in program order.
    pub instructions: Vec<Op>,
}

impl Program {
    /// Takes ownership of `instructions` and wraps them as a program.
    pub fn new(instructions: Vec<Op>) -> Self {
        Program { instructions }
    }

    /// Number of instructions in the program.
    pub fn len(&self) -> usize {
        let Program { instructions } = self;
        instructions.len()
    }

    /// Returns `true` when the program has no instructions.
    pub fn is_empty(&self) -> bool {
        self.instructions.as_slice().is_empty()
    }
}
/// Outcome of running an `OptimizationPipeline`.
#[derive(Debug, Clone)]
pub struct OptimizationResult {
/// The optimized program.
pub program: Program,
/// Names of the passes that were run, in order; a fixed-point run repeats the
/// full list once per sweep.
pub passes_applied: Vec<String>,
/// Instruction-count reduction reported by the pipeline. Despite the name, it
/// counts instructions (`Program::len`), not trits.
pub trits_eliminated: usize,
}
pub fn dead_trit_elimination(program: &Program) -> Program {
let mut used_regs: std::collections::HashSet<usize> = std::collections::HashSet::new();
let mut used_labels: std::collections::HashSet<usize> = std::collections::HashSet::new();
for op in &program.instructions {
match op {
Op::Load(reg) => { used_regs.insert(*reg); }
Op::Jump(target) | Op::JumpIfZero(target) | Op::JumpIfNeg(target) | Op::JumpIfPos(target) => {
used_labels.insert(*target);
}
_ => {}
}
}
let mut needed = std::collections::HashSet::new();
for op in program.instructions.iter().rev() {
match op {
Op::Load(reg) => { if needed.contains(reg) { used_regs.insert(*reg); } }
Op::Store(reg) => { if used_regs.contains(reg) { needed.insert(*reg); } }
_ => {}
}
}
let mut result = Vec::new();
for (i, op) in program.instructions.iter().enumerate() {
match op {
Op::Nop => {} Op::LoadConst(_) => result.push(*op), Op::Load(reg) => {
if used_regs.contains(reg) {
result.push(*op);
}
}
_ => result.push(*op),
}
let _ = used_labels.contains(&i); }
Program::new(result)
}
/// Folds arithmetic on constant trits into single `LoadConst` instructions.
///
/// Rules (one left-to-right sweep; run repeatedly to cascade):
/// * `LoadConst(a); LoadConst(b); Add` -> `LoadConst(a + b)`, saturated to `-1..=1`
/// * `LoadConst(a); LoadConst(b); Mul` -> `LoadConst(a * b)`
/// * `LoadConst(a); Neg` -> `LoadConst(-a)`
///
/// A fold may start at an instruction that a jump lands on, but never
/// swallows one; otherwise the jump would land in the middle of a pattern that
/// no longer exists. Because folding shrinks the program, every jump target
/// (an absolute instruction index) is re-indexed afterwards. Targets past the
/// end keep their distance from the end.
///
/// `LoadConst(Zero); Mul` on its own is deliberately not rewritten to
/// `LoadConst(Zero)`. On the stack discipline the folds above assume, `Mul`
/// consumes the unknown multiplicand, and the replacement would leave it behind.
pub fn constant_folding(program: &Program) -> Program {
    let instrs = &program.instructions;
    let n = instrs.len();

    // Mark every instruction that some jump lands on.
    let mut is_target = vec![false; n];
    for op in instrs {
        if let Op::Jump(t) | Op::JumpIfZero(t) | Op::JumpIfNeg(t) | Op::JumpIfPos(t) = *op {
            if let Some(flag) = is_target.get_mut(t) {
                *flag = true;
            }
        }
    }

    // Saturate an integer result to the nearest trit (same as clamping to -1..=1).
    let saturate = |v: i8| match v.signum() {
        -1 => Trit::Neg,
        0 => Trit::Zero,
        _ => Trit::Pos,
    };

    let mut result: Vec<Op> = Vec::with_capacity(n);
    // new_index[i] = position in `result` that old instruction i maps to.
    let mut new_index = vec![0usize; n + 1];
    let mut i = 0;
    while i < n {
        // (folded constant, number of instructions consumed)
        let fold = match (instrs[i], instrs.get(i + 1), instrs.get(i + 2)) {
            (Op::LoadConst(a), Some(Op::LoadConst(b)), Some(Op::Add)) => {
                Some((saturate(a.to_i8() + b.to_i8()), 3))
            }
            (Op::LoadConst(a), Some(Op::LoadConst(b)), Some(Op::Mul)) => {
                Some((saturate(a.to_i8() * b.to_i8()), 3))
            }
            (Op::LoadConst(a), Some(Op::Neg), _) => Some((saturate(-a.to_i8()), 2)),
            _ => None,
        }
        // Refuse folds whose interior contains a jump landing site.
        .filter(|&(_, width)| !is_target[i + 1..i + width].iter().any(|&hit| hit));

        let (op, width) = match fold {
            Some((value, width)) => (Op::LoadConst(value), width),
            None => (instrs[i], 1),
        };
        new_index[i..i + width].fill(result.len());
        result.push(op);
        i += width;
    }
    new_index[n] = result.len();

    let removed = n - result.len();
    for op in &mut result {
        if let Op::Jump(t) | Op::JumpIfZero(t) | Op::JumpIfNeg(t) | Op::JumpIfPos(t) = op {
            *t = new_index.get(*t).copied().unwrap_or(*t - removed);
        }
    }
    Program::new(result)
}
/// Applies algebraic identities to adjacent instruction pairs.
///
/// Rules (one left-to-right sweep):
/// * `Neg; Neg` -> removed (`-(-x) == x`)
/// * `LoadConst(Zero); Add` -> removed (`x + 0 == x`)
/// * `LoadConst(Pos); Mul` -> removed (`x * 1 == x`)
/// * `LoadConst(Zero); Mul` -> `LoadConst(Zero)`
///
/// A pair is never merged when its second instruction is a jump landing site.
/// Because merging shrinks the program, every jump target (an absolute
/// instruction index) is re-indexed afterwards. A jump to a removed pair lands
/// on whatever follows it. Targets past the end keep their distance from the end.
pub fn trit_merging(program: &Program) -> Program {
    let instrs = &program.instructions;
    let n = instrs.len();

    // Mark every instruction that some jump lands on.
    let mut is_target = vec![false; n];
    for op in instrs {
        if let Op::Jump(t) | Op::JumpIfZero(t) | Op::JumpIfNeg(t) | Op::JumpIfPos(t) = *op {
            if let Some(flag) = is_target.get_mut(t) {
                *flag = true;
            }
        }
    }

    let mut result: Vec<Op> = Vec::with_capacity(n);
    // new_index[i] = position in `result` that old instruction i maps to.
    let mut new_index = vec![0usize; n + 1];
    let mut i = 0;
    while i < n {
        // `Some(replacement)` when the pair (i, i + 1) can be merged; the
        // replacement (`None` = nothing) is emitted in place of both instructions.
        let merged: Option<Option<Op>> = match (instrs[i], instrs.get(i + 1)) {
            (Op::Neg, Some(Op::Neg)) => Some(None),
            (Op::LoadConst(Trit::Zero), Some(Op::Add)) => Some(None),
            (Op::LoadConst(Trit::Pos), Some(Op::Mul)) => Some(None),
            // NOTE(review): on a stack machine `LoadConst(Zero); Mul` consumes the
            // multiplicand, while this replacement leaves it on the stack. Kept
            // because `test_trit_merging_zero_mul` pins it; confirm against the
            // interpreter.
            (Op::LoadConst(Trit::Zero), Some(Op::Mul)) => Some(Some(Op::LoadConst(Trit::Zero))),
            _ => None,
        }
        // Never swallow an instruction that a jump lands on.
        .filter(|_| !is_target[i + 1]);

        new_index[i] = result.len();
        match merged {
            Some(replacement) => {
                new_index[i + 1] = result.len();
                result.extend(replacement);
                i += 2;
            }
            None => {
                result.push(instrs[i]);
                i += 1;
            }
        }
    }
    new_index[n] = result.len();

    let removed = n - result.len();
    for op in &mut result {
        if let Op::Jump(t) | Op::JumpIfZero(t) | Op::JumpIfNeg(t) | Op::JumpIfPos(t) = op {
            *t = new_index.get(*t).copied().unwrap_or(*t - removed);
        }
    }
    Program::new(result)
}
/// Local pattern-based rewrites over short instruction windows.
pub struct PeepholeOptimizer {
    /// Requested window size. Not consulted by [`PeepholeOptimizer::optimize`],
    /// whose rules use fixed two- and three-instruction windows.
    pub window_size: usize,
}

impl PeepholeOptimizer {
    /// Creates an optimizer with the given window size (currently informational).
    pub fn new(window_size: usize) -> Self {
        Self { window_size }
    }

    /// Applies peephole rewrites in one left-to-right sweep.
    ///
    /// Rules:
    /// * `LoadConst(c); Store(r); Load(r)` -> `LoadConst(c); Store(r); LoadConst(c)`.
    ///   This is store-to-load forwarding: the load observes the value just
    ///   stored, so it can become a constant for later folding. The two
    ///   registers must match; a `Load` of a different register is left alone.
    /// * `Store(r); Load(r)` -> removed, but only when that `Load` is the
    ///   program's sole read of `r`, so the dropped register write is never
    ///   observed. NOTE(review): this assumes `Store` pops its operand and
    ///   `Load` pushes it back, leaving the stack unchanged — confirm against
    ///   the interpreter.
    ///
    /// A rule never rewrites an instruction (other than the first of its window)
    /// that a jump lands on. Jump targets, which are absolute instruction
    /// indices, are re-indexed after instructions are removed; targets past the
    /// end keep their distance from the end.
    pub fn optimize(&self, program: &Program) -> Program {
        let instrs = &program.instructions;
        let n = instrs.len();

        // Jump landing sites, and how many times each register is loaded.
        let mut is_target = vec![false; n];
        let mut load_counts: std::collections::HashMap<usize, usize> =
            std::collections::HashMap::new();
        for op in instrs {
            match *op {
                Op::Jump(t) | Op::JumpIfZero(t) | Op::JumpIfNeg(t) | Op::JumpIfPos(t) => {
                    if let Some(flag) = is_target.get_mut(t) {
                        *flag = true;
                    }
                }
                Op::Load(reg) => *load_counts.entry(reg).or_insert(0) += 1,
                _ => {}
            }
        }

        let mut result: Vec<Op> = Vec::with_capacity(n);
        // new_index[i] = position in `result` that old instruction i maps to.
        let mut new_index = vec![0usize; n + 1];
        let mut i = 0;
        while i < n {
            let start = result.len();
            new_index[i] = start;
            match (instrs[i], instrs.get(i + 1).copied(), instrs.get(i + 2).copied()) {
                (Op::LoadConst(c), Some(Op::Store(r)), Some(Op::Load(r2)))
                    if r == r2 && !is_target[i + 1] && !is_target[i + 2] =>
                {
                    result.extend_from_slice(&[Op::LoadConst(c), Op::Store(r), Op::LoadConst(c)]);
                    new_index[i + 1] = start + 1;
                    new_index[i + 2] = start + 2;
                    i += 3;
                }
                (Op::Store(r), Some(Op::Load(r2)), _)
                    if r == r2 && !is_target[i + 1] && load_counts.get(&r) == Some(&1) =>
                {
                    // Both vanish; a jump to `i` now falls through to what follows.
                    new_index[i + 1] = start;
                    i += 2;
                }
                (op, _, _) => {
                    result.push(op);
                    i += 1;
                }
            }
        }
        new_index[n] = result.len();

        let removed = n - result.len();
        for op in &mut result {
            if let Op::Jump(t) | Op::JumpIfZero(t) | Op::JumpIfNeg(t) | Op::JumpIfPos(t) = op {
                *t = new_index.get(*t).copied().unwrap_or(*t - removed);
            }
        }
        Program::new(result)
    }
}
/// A loop found by `detect_loops`: a backward jump and the instruction it targets.
#[derive(Debug, Clone)]
pub struct LoopInfo {
/// Index of the jump target, i.e. the first instruction of the loop.
pub start: usize,
/// Index of the backward jump; `detect_loops` sets this equal to `back_edge`.
pub end: usize,
/// Index of the backward jump instruction.
pub back_edge: usize,
/// Heuristic trip count; `None` unless `detect_loops_with_iterations` guessed one.
pub estimated_iterations: Option<usize>,
}
/// Finds loops by scanning for backward jumps.
///
/// Every jump, conditional or not, whose target is at or before the jump's own
/// index is reported as one loop spanning `target..=jump`. Loops are returned in
/// the order their jumps appear, with `estimated_iterations` left as `None`.
pub fn detect_loops(program: &Program) -> Vec<LoopInfo> {
    program
        .instructions
        .iter()
        .enumerate()
        .filter_map(|(pos, op)| match *op {
            Op::Jump(dest) | Op::JumpIfZero(dest) | Op::JumpIfNeg(dest) | Op::JumpIfPos(dest)
                if dest <= pos =>
            {
                Some(LoopInfo {
                    start: dest,
                    end: pos,
                    back_edge: pos,
                    estimated_iterations: None,
                })
            }
            _ => None,
        })
        .collect()
}
/// Like `detect_loops`, but also fills in a crude trip-count estimate.
///
/// When the instruction immediately before a loop's first instruction is
/// `LoadConst(Pos)`, the loop gets `estimated_iterations = Some(1)`; otherwise
/// the estimate stays `None`. NOTE(review): the heuristic never inspects the
/// loop's exit condition, so treat the estimate as a hint only.
pub fn detect_loops_with_iterations(program: &Program) -> Vec<LoopInfo> {
    let mut loops = detect_loops(program);
    for info in loops.iter_mut() {
        let before_header = info
            .start
            .checked_sub(1)
            .and_then(|pos| program.instructions.get(pos));
        if matches!(before_header, Some(Op::LoadConst(Trit::Pos))) {
            info.estimated_iterations = Some(1);
        }
    }
    loops
}
/// An ordered list of named `Program -> Program` passes.
///
/// Build one with [`OptimizationPipeline::new`] and chained
/// [`OptimizationPipeline::add_pass`] calls. Then run it once
/// ([`OptimizationPipeline::run_once`]) or repeatedly until the program stops
/// changing ([`OptimizationPipeline::run_to_fixed_point`]).
pub struct OptimizationPipeline {
    /// Passes, applied in insertion order.
    pub passes: Vec<Box<dyn Fn(&Program) -> Program>>,
    /// Display name of each pass; `pass_names[k]` belongs to `passes[k]`.
    pub pass_names: Vec<String>,
    /// Upper bound on the number of sweeps `run_to_fixed_point` performs.
    pub max_iterations: usize,
}

impl Default for OptimizationPipeline {
    fn default() -> Self {
        Self::new()
    }
}

impl OptimizationPipeline {
    /// Creates an empty pipeline that performs at most 10 sweeps when run to a
    /// fixed point.
    pub fn new() -> Self {
        Self {
            passes: Vec::new(),
            pass_names: Vec::new(),
            max_iterations: 10,
        }
    }

    /// Appends `pass` under `name` and returns the pipeline (builder style).
    pub fn add_pass<F: Fn(&Program) -> Program + 'static>(mut self, name: &str, pass: F) -> Self {
        self.passes.push(Box::new(pass));
        self.pass_names.push(name.to_string());
        self
    }

    /// Applies every pass exactly once, in order.
    ///
    /// `trits_eliminated` is the drop in instruction count, or 0 if the program grew.
    pub fn run_once(&self, program: &Program) -> OptimizationResult {
        let optimized = self.sweep(program);
        let trits_eliminated = program.len().saturating_sub(optimized.len());
        OptimizationResult {
            program: optimized,
            passes_applied: self.pass_names.clone(),
            trits_eliminated,
        }
    }

    /// Sweeps all passes repeatedly until a sweep leaves the program unchanged
    /// or `max_iterations` sweeps have run.
    ///
    /// Convergence compares instructions, not lengths. A pass that rewrites
    /// instructions without changing their count therefore does not end the
    /// loop early. `passes_applied` lists every pass name once per sweep
    /// performed. `trits_eliminated` is the overall drop in instruction count,
    /// matching [`OptimizationPipeline::run_once`].
    pub fn run_to_fixed_point(&self, program: &Program) -> OptimizationResult {
        let mut current = program.clone();
        let mut passes_applied = Vec::new();
        for _ in 0..self.max_iterations {
            let next = self.sweep(&current);
            passes_applied.extend(self.pass_names.iter().cloned());
            let converged = next.instructions == current.instructions;
            current = next;
            if converged {
                break;
            }
        }
        let trits_eliminated = program.len().saturating_sub(current.len());
        OptimizationResult {
            program: current,
            passes_applied,
            trits_eliminated,
        }
    }

    /// Runs every pass once, feeding each pass the previous pass's output.
    fn sweep(&self, program: &Program) -> Program {
        self.passes
            .iter()
            .fold(program.clone(), |current, pass| pass(&current))
    }
}
// Unit tests for `Trit` and the optimisation passes. Most tests pin instruction
// counts or the first output instruction rather than whole output programs.
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_trit_from_i8() {
assert_eq!(Trit::from_i8(-1), Some(Trit::Neg));
assert_eq!(Trit::from_i8(0), Some(Trit::Zero));
assert_eq!(Trit::from_i8(1), Some(Trit::Pos));
assert_eq!(Trit::from_i8(2), None);
}
#[test]
fn test_dead_trit_elimination_nop() {
let prog = Program::new(vec![Op::LoadConst(Trit::Pos), Op::Nop, Op::Halt]);
let optimized = dead_trit_elimination(&prog);
assert_eq!(optimized.len(), 2);
assert_eq!(optimized.instructions[0], Op::LoadConst(Trit::Pos));
assert_eq!(optimized.instructions[1], Op::Halt);
}
// Only requires that something survives; the exact output is not pinned.
#[test]
fn test_dead_trit_elimination_unused_load() {
let prog = Program::new(vec![Op::Load(5), Op::Store(5), Op::Halt]);
let optimized = dead_trit_elimination(&prog);
assert!(optimized.len() >= 1);
}
// Pos + Pos folds to Pos: constant_folding saturates sums to the trit range.
#[test]
fn test_constant_folding_add() {
let prog = Program::new(vec![
Op::LoadConst(Trit::Pos),
Op::LoadConst(Trit::Pos),
Op::Add,
]);
let optimized = constant_folding(&prog);
assert_eq!(optimized.len(), 1);
assert_eq!(optimized.instructions[0], Op::LoadConst(Trit::Pos)); }
#[test]
fn test_constant_folding_mul() {
let prog = Program::new(vec![
Op::LoadConst(Trit::Neg),
Op::LoadConst(Trit::Neg),
Op::Mul,
]);
let optimized = constant_folding(&prog);
assert_eq!(optimized.len(), 1);
assert_eq!(optimized.instructions[0], Op::LoadConst(Trit::Pos)); }
#[test]
fn test_constant_folding_neg() {
let prog = Program::new(vec![Op::LoadConst(Trit::Pos), Op::Neg]);
let optimized = constant_folding(&prog);
assert_eq!(optimized.len(), 1);
assert_eq!(optimized.instructions[0], Op::LoadConst(Trit::Neg));
}
#[test]
fn test_constant_folding_neg_zero() {
let prog = Program::new(vec![Op::LoadConst(Trit::Zero), Op::Neg]);
let optimized = constant_folding(&prog);
assert_eq!(optimized.len(), 1);
assert_eq!(optimized.instructions[0], Op::LoadConst(Trit::Zero));
}
#[test]
fn test_trit_merging_double_neg() {
let prog = Program::new(vec![Op::LoadConst(Trit::Pos), Op::Neg, Op::Neg, Op::Halt]);
let optimized = trit_merging(&prog);
assert_eq!(optimized.len(), 2); assert_eq!(optimized.instructions[0], Op::LoadConst(Trit::Pos));
}
#[test]
fn test_trit_merging_zero_add() {
let prog = Program::new(vec![Op::LoadConst(Trit::Zero), Op::Add, Op::Halt]);
let optimized = trit_merging(&prog);
assert_eq!(optimized.len(), 1); assert_eq!(optimized.instructions[0], Op::Halt);
}
#[test]
fn test_trit_merging_pos_mul() {
let prog = Program::new(vec![Op::LoadConst(Trit::Pos), Op::Mul, Op::Halt]);
let optimized = trit_merging(&prog);
assert_eq!(optimized.len(), 1); }
// Pins the `LoadConst(Zero); Mul` -> `LoadConst(Zero)` rewrite (2 of 3 instructions remain).
#[test]
fn test_trit_merging_zero_mul() {
let prog = Program::new(vec![Op::LoadConst(Trit::Zero), Op::Mul, Op::Halt]);
let optimized = trit_merging(&prog);
assert_eq!(optimized.len(), 2);
assert_eq!(optimized.instructions[0], Op::LoadConst(Trit::Zero));
}
// Only bounds the output length; does not pin which rewrite fires.
#[test]
fn test_peephole_store_load() {
let prog = Program::new(vec![Op::LoadConst(Trit::Pos), Op::Store(0), Op::Load(0), Op::Halt]);
let optimizer = PeepholeOptimizer::new(2);
let optimized = optimizer.optimize(&prog);
assert!(optimized.len() <= 4);
}
#[test]
fn test_peephole_preserves_halt() {
let prog = Program::new(vec![Op::Halt]);
let optimizer = PeepholeOptimizer::new(2);
let optimized = optimizer.optimize(&prog);
assert_eq!(optimized.len(), 1);
assert_eq!(optimized.instructions[0], Op::Halt);
}
#[test]
fn test_loop_detection_simple() {
let prog = Program::new(vec![Op::LoadConst(Trit::Pos), Op::JumpIfZero(0)]);
let loops = detect_loops(&prog);
assert_eq!(loops.len(), 1);
assert_eq!(loops[0].start, 0);
assert_eq!(loops[0].back_edge, 1);
}
// A forward jump is not a loop.
#[test]
fn test_loop_detection_no_loop() {
let prog = Program::new(vec![Op::LoadConst(Trit::Pos), Op::Jump(2), Op::Halt]);
let loops = detect_loops(&prog);
assert!(loops.is_empty());
}
// Two backward jumps to index 0 are reported as two separate loops.
#[test]
fn test_loop_detection_nested() {
let prog = Program::new(vec![
Op::LoadConst(Trit::Pos),
Op::JumpIfZero(0),
Op::JumpIfNeg(0),
]);
let loops = detect_loops(&prog);
assert_eq!(loops.len(), 2);
}
#[test]
fn test_optimization_pipeline_single_pass() {
let pipeline = OptimizationPipeline::new()
.add_pass("dead_trit_elimination", |p| dead_trit_elimination(p))
.add_pass("constant_folding", |p| constant_folding(p));
let prog = Program::new(vec![
Op::LoadConst(Trit::Pos),
Op::LoadConst(Trit::Pos),
Op::Add,
Op::Nop,
]);
let result = pipeline.run_once(&prog);
assert!(result.program.len() < prog.len());
assert_eq!(result.passes_applied.len(), 2);
}
#[test]
fn test_optimization_pipeline_fixed_point() {
let pipeline = OptimizationPipeline::new()
.add_pass("constant_folding", |p| constant_folding(p))
.add_pass("trit_merging", |p| trit_merging(p))
.add_pass("dead_trit_elimination", |p| dead_trit_elimination(p));
let prog = Program::new(vec![
Op::LoadConst(Trit::Pos),
Op::LoadConst(Trit::Neg),
Op::Mul, Op::Neg, Op::Neg, Op::Nop, ]);
let result = pipeline.run_to_fixed_point(&prog);
assert!(result.program.len() < prog.len());
}
#[test]
fn test_constant_folding_add_neg_pos() {
let prog = Program::new(vec![
Op::LoadConst(Trit::Neg),
Op::LoadConst(Trit::Pos),
Op::Add,
]);
let optimized = constant_folding(&prog);
assert_eq!(optimized.len(), 1);
assert_eq!(optimized.instructions[0], Op::LoadConst(Trit::Zero));
}
// Only checks the first instruction, not the output length.
#[test]
fn test_constant_folding_zero_mul_pattern() {
let prog = Program::new(vec![Op::LoadConst(Trit::Zero), Op::Mul, Op::Halt]);
let optimized = constant_folding(&prog);
assert_eq!(optimized.instructions[0], Op::LoadConst(Trit::Zero));
}
#[test]
fn test_program_empty() {
let prog = Program::new(vec![]);
assert!(prog.is_empty());
assert_eq!(prog.len(), 0);
}
#[test]
fn test_optimization_pipeline_no_change() {
let pipeline = OptimizationPipeline::new()
.add_pass("constant_folding", |p| constant_folding(p));
let prog = Program::new(vec![Op::Halt]);
let result = pipeline.run_once(&prog);
assert_eq!(result.program.len(), 1);
}
// JumpIfZero(1) at index 1 targets itself, forming a one-instruction loop.
#[test]
fn test_detect_loops_with_iterations() {
let prog = Program::new(vec![
Op::LoadConst(Trit::Pos),
Op::JumpIfZero(1),
]);
let loops = detect_loops_with_iterations(&prog);
assert_eq!(loops.len(), 1);
}
}