use std::collections::HashMap;
/// Identifier of an e-class: an index into the e-graph's per-class arrays.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Id(u32);

impl Id {
    /// Widens the raw id to a `usize` suitable for indexing.
    fn index(self) -> usize {
        let Id(raw) = self;
        raw as usize
    }
}
/// Opaque numeric handle wrapped by `Expr::Var`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct NodeId(pub u32);
/// A term of the layout language. Operands refer to e-classes by `Id`, so one
/// `Expr` stands for every term obtainable by picking members of those classes.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Expr {
    /// Constant size.
    Num(u16),
    /// Symbolic value identified by a `NodeId`.
    Var(NodeId),
    /// Sum of two classes.
    Add(Id, Id),
    /// Difference of two classes.
    Sub(Id, Id),
    /// Larger of two classes.
    Max(Id, Id),
    /// Smaller of two classes.
    Min(Id, Id),
    /// Quotient of two classes.
    Div(Id, Id),
    /// Product of two classes.
    Mul(Id, Id),
    /// `val` restricted to the range `[min, max]`.
    Clamp { min: Id, max: Id, val: Id },
    /// Horizontal flex container over child classes (built by `encode_flex`).
    HFlex(Vec<Id>),
    /// Vertical flex container over child classes (built by `encode_flex`).
    VFlex(Vec<Id>),
    /// Fill the space given by the child class.
    Fill(Id),
}
/// An e-graph: a union-find over e-classes, each owning a set of equivalent
/// e-nodes, plus a hash-cons table used to deduplicate inserted terms.
#[derive(Debug)]
pub struct EGraph {
    /// Union-find parent for every class id; a root is its own parent.
    parents: Vec<u32>,
    /// Union-by-rank heuristic value for every class id.
    ranks: Vec<u8>,
    /// E-nodes owned by each class; a merged-away class is left empty.
    nodes: Vec<Vec<Expr>>,
    /// Canonical form (at insertion time) of each added term -> its class.
    memo: HashMap<Expr, Id>,
    /// Total reported by the most recent rule-application / saturation run.
    last_apply_count: usize,
}
impl EGraph {
pub fn new() -> Self {
Self {
parents: Vec::new(),
ranks: Vec::new(),
nodes: Vec::new(),
memo: HashMap::new(),
last_apply_count: 0,
}
}
pub fn add(&mut self, expr: Expr) -> Id {
let canonical = self.canonicalize(&expr);
if let Some(&id) = self.memo.get(&canonical) {
return self.find(id);
}
let id = Id(self.parents.len() as u32);
self.parents.push(id.0);
self.ranks.push(0);
self.nodes.push(vec![canonical.clone()]);
self.memo.insert(canonical, id);
id
}
pub fn equiv(&self, a: Id, b: Id) -> bool {
self.find(a) == self.find(b)
}
pub fn merge(&mut self, a: Id, b: Id) -> Id {
let a = self.find(a);
let b = self.find(b);
if a == b {
return a;
}
let (winner, loser) = if self.ranks[a.index()] >= self.ranks[b.index()] {
(a, b)
} else {
(b, a)
};
self.parents[loser.index()] = winner.0;
if self.ranks[winner.index()] == self.ranks[loser.index()] {
self.ranks[winner.index()] += 1;
}
let loser_nodes = std::mem::take(&mut self.nodes[loser.index()]);
self.nodes[winner.index()].extend(loser_nodes);
winner
}
pub fn find(&self, mut id: Id) -> Id {
while self.parents[id.index()] != id.0 {
id = Id(self.parents[id.index()]);
}
id
}
pub fn class_count(&self) -> usize {
(0..self.parents.len())
.filter(|&i| self.parents[i] == i as u32)
.count()
}
pub fn node_count(&self) -> usize {
self.nodes.iter().map(|n| n.len()).sum()
}
pub fn last_apply_count(&self) -> usize {
self.last_apply_count
}
pub fn apply_rules(&mut self) -> usize {
self.apply_rules_with_budget(100)
}
pub fn apply_rules_with_budget(&mut self, max_iterations: usize) -> usize {
let mut total = 0;
for _ in 0..max_iterations {
let applied = self.apply_rules_once();
total += applied;
self.last_apply_count = total;
if applied == 0 {
break;
}
}
total
}
fn apply_rules_once(&mut self) -> usize {
let mut merges: Vec<(Id, Id)> = Vec::new();
let mut new_nodes: Vec<(Id, Expr)> = Vec::new();
let class_count = self.parents.len();
for class_idx in 0..class_count {
let class_id = Id(class_idx as u32);
let canonical = self.find(class_id);
if canonical != class_id {
continue;
}
let nodes = self.nodes[class_idx].clone();
for expr in &nodes {
if let Some(target) = self.rewrite_simplify(expr)
&& self.find(class_id) != self.find(target)
{
merges.push((class_id, target));
}
self.rewrite_construct(expr, class_id, &mut new_nodes);
}
}
for (class_id, expr) in new_nodes {
let new_id = self.add(expr);
if self.find(class_id) != self.find(new_id) {
merges.push((class_id, new_id));
}
}
merges.sort_by_key(|&(a, b)| (self.find(a).0, self.find(b).0));
merges
.dedup_by(|a, b| self.find(a.0) == self.find(b.0) && self.find(a.1) == self.find(b.1));
let count = merges
.iter()
.filter(|&&(a, b)| self.find(a) != self.find(b))
.count();
for (a, b) in merges {
self.merge(a, b);
}
count
}
fn rewrite_simplify(&self, expr: &Expr) -> Option<Id> {
match expr {
Expr::Add(a, b) => {
if self.is_zero(*b) {
return Some(*a);
}
if self.is_zero(*a) {
return Some(*b);
}
None
}
Expr::Sub(a, b) => {
if self.is_zero(*b) {
return Some(*a);
}
if self.find(*a) == self.find(*b) {
return self.find_num(0);
}
None
}
Expr::Mul(a, b) => {
if self.is_one(*b) {
return Some(*a);
}
if self.is_one(*a) {
return Some(*b);
}
if self.is_zero(*a) {
return Some(*a);
}
if self.is_zero(*b) {
return Some(*b);
}
None
}
Expr::Max(a, b) => {
if self.find(*a) == self.find(*b) {
return Some(*a);
}
if self.is_zero(*b) {
return Some(*a);
}
if self.is_zero(*a) {
return Some(*b);
}
None
}
Expr::Min(a, b) => {
if self.find(*a) == self.find(*b) {
return Some(*a);
}
None
}
Expr::Clamp { min, max, val } => {
if self.is_zero(*min) && self.is_max(*max) {
return Some(*val);
}
if self.find(*min) == self.find(*max) {
return Some(*min);
}
if self.find(*val) == self.find(*min) {
return Some(*min);
}
if self.find(*val) == self.find(*max) {
return Some(*max);
}
None
}
Expr::Div(a, b) => {
if self.is_one(*b) {
return Some(*a);
}
if self.find(*a) == self.find(*b) && !self.is_zero(*a) {
return self.find_num(1);
}
None
}
_ => None,
}
}
fn rewrite_construct(&self, expr: &Expr, _class_id: Id, out: &mut Vec<(Id, Expr)>) {
match expr {
Expr::Add(a, b) if self.find(*a) != self.find(*b) => {
out.push((_class_id, Expr::Add(*b, *a)));
}
Expr::Max(a, b) if self.find(*a) != self.find(*b) => {
out.push((_class_id, Expr::Max(*b, *a)));
}
Expr::Min(a, b) if self.find(*a) != self.find(*b) => {
out.push((_class_id, Expr::Min(*b, *a)));
}
Expr::Mul(a, b) if self.find(*a) != self.find(*b) => {
out.push((_class_id, Expr::Mul(*b, *a)));
}
_ => {}
}
}
fn find_num(&self, n: u16) -> Option<Id> {
self.memo.get(&Expr::Num(n)).map(|&id| self.find(id))
}
fn is_zero(&self, id: Id) -> bool {
let id = self.find(id);
self.nodes[id.index()]
.iter()
.any(|e| matches!(e, Expr::Num(0)))
}
fn is_one(&self, id: Id) -> bool {
let id = self.find(id);
self.nodes[id.index()]
.iter()
.any(|e| matches!(e, Expr::Num(1)))
}
fn is_max(&self, id: Id) -> bool {
let id = self.find(id);
self.nodes[id.index()]
.iter()
.any(|e| matches!(e, Expr::Num(u16::MAX)))
}
fn canonicalize(&self, expr: &Expr) -> Expr {
match expr {
Expr::Num(n) => Expr::Num(*n),
Expr::Var(v) => Expr::Var(*v),
Expr::Add(a, b) => Expr::Add(self.find(*a), self.find(*b)),
Expr::Sub(a, b) => Expr::Sub(self.find(*a), self.find(*b)),
Expr::Max(a, b) => Expr::Max(self.find(*a), self.find(*b)),
Expr::Min(a, b) => Expr::Min(self.find(*a), self.find(*b)),
Expr::Div(a, b) => Expr::Div(self.find(*a), self.find(*b)),
Expr::Mul(a, b) => Expr::Mul(self.find(*a), self.find(*b)),
Expr::Clamp { min, max, val } => Expr::Clamp {
min: self.find(*min),
max: self.find(*max),
val: self.find(*val),
},
Expr::HFlex(ids) => Expr::HFlex(ids.iter().map(|id| self.find(*id)).collect()),
Expr::VFlex(ids) => Expr::VFlex(ids.iter().map(|id| self.find(*id)).collect()),
Expr::Fill(id) => Expr::Fill(self.find(*id)),
}
}
pub fn extract(&self, id: Id) -> Expr {
let id = self.find(id);
self.nodes[id.index()]
.iter()
.min_by_key(|e| Self::cost(e))
.cloned()
.expect("e-class has no expressions")
}
fn cost(expr: &Expr) -> u32 {
match expr {
Expr::Num(_) => 1,
Expr::Var(_) => 2,
Expr::Add(_, _) | Expr::Sub(_, _) | Expr::Max(_, _) | Expr::Min(_, _) => 3,
Expr::Mul(_, _) | Expr::Div(_, _) => 4,
Expr::Fill(_) => 3,
Expr::Clamp { .. } => 5,
Expr::HFlex(ids) | Expr::VFlex(ids) => 10 + ids.len() as u32,
}
}
}
impl Default for EGraph {
fn default() -> Self {
Self::new()
}
}
/// Encodes one layout constraint as an e-graph term and returns its class.
///
/// `total` is the space available to the layout. Encodings:
/// - `Fixed(n)` -> `Num(n)`
/// - `Min(n)` -> `Clamp { min: n, max: total, val: n }`
/// - `Max(n)` -> `Clamp { min: 0, max: n, val: n }`
/// - `Percentage(p)` -> `Num(p% of total)`, truncated toward zero
/// - `Ratio(num, den)` -> `Num(total * num / den)`, saturated to `u16::MAX`
///   (0 when `den == 0`)
/// - `Fill` / `FitContent` / `FitMin` -> `Fill(Num(total))`
/// - `FitContentBounded { min, max }` -> `Clamp { min, max, val: total }`
pub fn encode_constraint(graph: &mut EGraph, constraint: &crate::Constraint, total: u16) -> Id {
    match constraint {
        crate::Constraint::Fixed(n) => graph.add(Expr::Num(*n)),
        crate::Constraint::Min(n) => {
            let min = graph.add(Expr::Num(*n));
            let total_id = graph.add(Expr::Num(total));
            graph.add(Expr::Clamp {
                min,
                max: total_id,
                val: min,
            })
        }
        crate::Constraint::Max(n) => {
            let max = graph.add(Expr::Num(*n));
            let zero = graph.add(Expr::Num(0));
            graph.add(Expr::Clamp {
                min: zero,
                max,
                val: max,
            })
        }
        crate::Constraint::Percentage(pct) => {
            // Float-to-int `as` truncates toward zero, saturates out-of-range
            // results, and maps NaN to 0.
            let scaled = ((*pct / 100.0) * total as f32) as u16;
            graph.add(Expr::Num(scaled))
        }
        crate::Constraint::Ratio(num, den) => {
            graph.add(Expr::Num(ratio_of_total(total, *num, *den)))
        }
        crate::Constraint::Fill | crate::Constraint::FitContent | crate::Constraint::FitMin => {
            let total_id = graph.add(Expr::Num(total));
            graph.add(Expr::Fill(total_id))
        }
        crate::Constraint::FitContentBounded { min, max } => {
            let min_id = graph.add(Expr::Num(*min));
            let max_id = graph.add(Expr::Num(*max));
            let preferred = graph.add(Expr::Num(total));
            graph.add(Expr::Clamp {
                min: min_id,
                max: max_id,
                val: preferred,
            })
        }
    }
}

/// Computes `total * num / den` in 64-bit arithmetic (which cannot overflow for
/// these operand widths) and saturates the result at `u16::MAX` instead of
/// letting a narrowing cast wrap it. A zero denominator yields 0.
fn ratio_of_total(total: u16, num: u32, den: u32) -> u16 {
    if den == 0 {
        return 0;
    }
    let scaled = u64::from(total) * u64::from(num) / u64::from(den);
    scaled.min(u64::from(u16::MAX)) as u16
}
/// Adds a flex container over `children` and returns its class.
///
/// `horizontal == true` builds `Expr::HFlex`; otherwise `Expr::VFlex`.
pub fn encode_flex(graph: &mut EGraph, children: &[Id], horizontal: bool) -> Id {
    let members = children.to_vec();
    let container = if horizontal {
        Expr::HFlex(members)
    } else {
        Expr::VFlex(members)
    };
    graph.add(container)
}
/// Resource limits for a saturation run.
#[derive(Clone, Debug)]
pub struct SaturationConfig {
    /// Stop once the e-graph holds at least this many e-nodes.
    pub node_budget: usize,
    /// Maximum number of rewrite rounds.
    pub iteration_limit: usize,
    /// Wall-clock budget in microseconds; `0` disables the check.
    pub time_limit_us: u64,
    /// Estimated memory budget in bytes; `0` disables the check.
    pub memory_limit: usize,
}

impl Default for SaturationConfig {
    /// 10 000 nodes, 100 rounds, 5 ms, and 10 MiB.
    fn default() -> Self {
        Self {
            node_budget: 10_000,
            iteration_limit: 100,
            time_limit_us: 5_000,
            memory_limit: 10 * 1024 * 1024,
        }
    }
}

impl SaturationConfig {
    /// Builds a config from the defaults, overridden by environment variables:
    ///
    /// - `FRANKENTUI_EGRAPH_NODE_BUDGET`: node budget
    /// - `FRANKENTUI_EGRAPH_TIMEOUT_MS`: time limit in **milliseconds**
    /// - `FRANKENTUI_EGRAPH_MAX_ITERS`: iteration limit
    /// - `FRANKENTUI_EGRAPH_MEMORY_MB`: memory limit in **MiB**
    ///
    /// Values may carry surrounding whitespace. Unset or unparsable variables
    /// keep the default. Unit conversions saturate instead of overflowing, so a
    /// huge value means "effectively unlimited" rather than a panic or wrap-around.
    pub fn from_env() -> Self {
        let mut config = Self::default();
        if let Some(n) = Self::env_parse::<usize>("FRANKENTUI_EGRAPH_NODE_BUDGET") {
            config.node_budget = n;
        }
        if let Some(ms) = Self::env_parse::<u64>("FRANKENTUI_EGRAPH_TIMEOUT_MS") {
            config.time_limit_us = ms.saturating_mul(1000);
        }
        if let Some(n) = Self::env_parse::<usize>("FRANKENTUI_EGRAPH_MAX_ITERS") {
            config.iteration_limit = n;
        }
        if let Some(mb) = Self::env_parse::<usize>("FRANKENTUI_EGRAPH_MEMORY_MB") {
            config.memory_limit = mb.saturating_mul(1024 * 1024);
        }
        config
    }

    /// Reads `key` and parses its trimmed value; `None` if unset, not valid
    /// Unicode, or unparsable as `T`.
    fn env_parse<T: std::str::FromStr>(key: &str) -> Option<T> {
        std::env::var(key).ok()?.trim().parse().ok()
    }
}
/// Which limit, if any, ended a saturation run.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GuardTriggered {
    /// A fixpoint was reached; no limit was hit.
    None,
    /// The e-node count reached the configured node budget.
    NodeBudget,
    /// Elapsed time reached the configured time limit.
    Timeout,
    /// Estimated memory reached the configured memory limit.
    Memory,
    /// The round limit was exhausted before a fixpoint was reached.
    IterationLimit,
}
/// Outcome and resource usage of one saturation run.
#[derive(Debug, Clone)]
pub struct SaturationResult {
    /// Sum of the per-round rewrite counts.
    pub rewrites: usize,
    /// Number of rewrite rounds actually executed.
    pub iterations: usize,
    /// `true` only when a round reported no rewrites (fixpoint reached).
    pub saturated: bool,
    /// `true` whenever `guard` is not `GuardTriggered::None`.
    pub stopped_early: bool,
    /// E-node count when the run ended.
    pub node_count: usize,
    /// Which limit, if any, ended the run.
    pub guard: GuardTriggered,
    /// Wall-clock duration in microseconds.
    pub time_us: u64,
    /// Estimated memory footprint in bytes when the run ended.
    pub memory_bytes: usize,
}
impl EGraph {
pub fn saturate(&mut self, config: &SaturationConfig) -> SaturationResult {
let start = std::time::Instant::now();
let mut total_rewrites = 0;
let mut iterations = 0;
let make_result = |this: &Self, rewrites, iterations, saturated, guard: GuardTriggered| {
let elapsed = start.elapsed().as_micros() as u64;
SaturationResult {
rewrites,
iterations,
saturated,
stopped_early: guard != GuardTriggered::None,
node_count: this.node_count(),
guard,
time_us: elapsed,
memory_bytes: this.memory_usage(),
}
};
for _ in 0..config.iteration_limit {
if self.node_count() >= config.node_budget {
return make_result(
self,
total_rewrites,
iterations,
false,
GuardTriggered::NodeBudget,
);
}
if config.time_limit_us > 0
&& start.elapsed().as_micros() as u64 >= config.time_limit_us
{
return make_result(
self,
total_rewrites,
iterations,
false,
GuardTriggered::Timeout,
);
}
if config.memory_limit > 0 && self.memory_usage() >= config.memory_limit {
return make_result(
self,
total_rewrites,
iterations,
false,
GuardTriggered::Memory,
);
}
let applied = self.apply_rules_once();
total_rewrites += applied;
iterations += 1;
if applied == 0 {
self.last_apply_count = total_rewrites;
return make_result(self, total_rewrites, iterations, true, GuardTriggered::None);
}
}
self.last_apply_count = total_rewrites;
make_result(
self,
total_rewrites,
iterations,
false,
GuardTriggered::IterationLimit,
)
}
pub fn memory_usage(&self) -> usize {
let parents = self.parents.capacity() * std::mem::size_of::<u32>();
let ranks = self.ranks.capacity();
let nodes: usize = self
.nodes
.iter()
.map(|v| v.capacity() * std::mem::size_of::<Expr>())
.sum();
let nodes_vec = self.nodes.capacity() * std::mem::size_of::<Vec<Expr>>();
let memo = self.memo.capacity() * (std::mem::size_of::<Expr>() + std::mem::size_of::<Id>());
parents + ranks + nodes + nodes_vec + memo
}
}
pub fn solve_layout(
constraints: &[crate::Constraint],
total: u16,
config: &SaturationConfig,
) -> (Vec<u16>, SaturationResult) {
let mut graph = EGraph::new();
let ids: Vec<Id> = constraints
.iter()
.map(|c| encode_constraint(&mut graph, c, total))
.collect();
let result = graph.saturate(config);
let sizes: Vec<u16> = ids
.iter()
.map(|&id| {
let expr = graph.extract(id);
match expr {
Expr::Num(n) => n,
_ => total,
}
})
.collect();
(sizes, result)
}
/// Convenience wrapper: solves with `SaturationConfig::default()` and returns
/// only the sizes, discarding the saturation report.
pub fn solve_layout_default(constraints: &[crate::Constraint], total: u16) -> Vec<u16> {
    let config = SaturationConfig::default();
    let (sizes, _report) = solve_layout(constraints, total, &config);
    sizes
}
/// One saturation run's metrics, tagged with the test that produced them.
#[derive(Clone, Debug)]
pub struct EvidenceRecord {
    /// Free-form label for the run; escaped when serialized.
    pub test_name: String,
    /// Number of constraints solved.
    pub constraint_count: usize,
    /// Available space passed to the solver.
    pub total_space: u16,
    /// E-node count when saturation ended.
    pub nodes_at_completion: usize,
    /// Rewrite rounds executed.
    pub iterations: usize,
    /// Wall-clock duration in microseconds.
    pub time_us: u64,
    /// Estimated memory footprint in bytes.
    pub memory_bytes: usize,
    /// Guard that ended the run (`None` if it saturated).
    pub guard_triggered: GuardTriggered,
    /// Whether a fixpoint was reached.
    pub saturated: bool,
    /// Total rewrites reported.
    pub rewrites: usize,
}

impl EvidenceRecord {
    /// Copies the metrics from `result` and attaches the identifying fields.
    pub fn from_result(
        test_name: &str,
        constraint_count: usize,
        total: u16,
        result: &SaturationResult,
    ) -> Self {
        Self {
            test_name: test_name.to_string(),
            constraint_count,
            total_space: total,
            nodes_at_completion: result.node_count,
            iterations: result.iterations,
            time_us: result.time_us,
            memory_bytes: result.memory_bytes,
            guard_triggered: result.guard,
            saturated: result.saturated,
            rewrites: result.rewrites,
        }
    }

    /// Serializes the record as a single-line JSON object (no trailing newline).
    ///
    /// `test_name` is escaped per JSON string rules, so arbitrary names always
    /// produce valid JSON. The guard is written as its `Debug` variant name.
    pub fn to_jsonl(&self) -> String {
        format!(
            concat!(
                "{{\"test\":\"{}\",\"constraints\":{},\"total\":{},",
                "\"nodes\":{},\"iterations\":{},\"time_us\":{},",
                "\"memory_bytes\":{},\"guard\":\"{:?}\",",
                "\"saturated\":{},\"rewrites\":{}}}"
            ),
            Self::escape_json(&self.test_name),
            self.constraint_count,
            self.total_space,
            self.nodes_at_completion,
            self.iterations,
            self.time_us,
            self.memory_bytes,
            self.guard_triggered,
            self.saturated,
            self.rewrites,
        )
    }

    /// Escapes `raw` for embedding inside a JSON string literal.
    ///
    /// Handles the quote, the backslash, and all control characters below
    /// U+0020 (RFC 8259 §7).
    fn escape_json(raw: &str) -> String {
        let mut out = String::with_capacity(raw.len());
        for ch in raw.chars() {
            match ch {
                '"' => out.push_str("\\\""),
                '\\' => out.push_str("\\\\"),
                '\n' => out.push_str("\\n"),
                '\r' => out.push_str("\\r"),
                '\t' => out.push_str("\\t"),
                c if (c as u32) < 0x20 => out.push_str(&format!("\\u{:04x}", c as u32)),
                c => out.push(c),
            }
        }
        out
    }
}
// Unit tests covering e-graph primitives, the rewrite rules, the saturation
// guards, constraint encoding/solving, and evidence-record serialization.
#[cfg(test)]
mod tests {
    use super::*;

    // ---- E-graph primitives: hash-consing, union-find, counts ----

    #[test]
    fn add_num() {
        let mut g = EGraph::new();
        let a = g.add(Expr::Num(42));
        assert_eq!(g.node_count(), 1);
        assert_eq!(g.extract(a), Expr::Num(42));
    }
    #[test]
    fn add_dedup() {
        let mut g = EGraph::new();
        let a = g.add(Expr::Num(42));
        // The memo table returns the existing class for an identical term.
        let b = g.add(Expr::Num(42));
        assert_eq!(a, b);
        assert_eq!(g.node_count(), 1);
    }
    #[test]
    fn merge_classes() {
        let mut g = EGraph::new();
        let a = g.add(Expr::Num(1));
        let b = g.add(Expr::Num(2));
        assert!(!g.equiv(a, b));
        g.merge(a, b);
        assert!(g.equiv(a, b));
    }
    #[test]
    fn merge_idempotent() {
        let mut g = EGraph::new();
        let a = g.add(Expr::Num(1));
        let first = g.merge(a, a);
        assert_eq!(first, a);
    }

    // ---- Simplification rules ----

    #[test]
    fn add_zero_identity() {
        let mut g = EGraph::new();
        let a = g.add(Expr::Num(100));
        let zero = g.add(Expr::Num(0));
        let sum = g.add(Expr::Add(a, zero));
        g.apply_rules();
        assert!(g.equiv(sum, a), "Add(x, 0) should equal x");
    }
    #[test]
    fn add_zero_identity_left() {
        let mut g = EGraph::new();
        let a = g.add(Expr::Num(50));
        let zero = g.add(Expr::Num(0));
        let sum = g.add(Expr::Add(zero, a));
        g.apply_rules();
        assert!(g.equiv(sum, a), "Add(0, x) should equal x");
    }
    #[test]
    fn sub_zero_identity() {
        let mut g = EGraph::new();
        let a = g.add(Expr::Num(100));
        let zero = g.add(Expr::Num(0));
        let diff = g.add(Expr::Sub(a, zero));
        g.apply_rules();
        assert!(g.equiv(diff, a), "Sub(x, 0) should equal x");
    }
    #[test]
    fn mul_one_identity() {
        let mut g = EGraph::new();
        let a = g.add(Expr::Num(42));
        let one = g.add(Expr::Num(1));
        let prod = g.add(Expr::Mul(a, one));
        g.apply_rules();
        assert!(g.equiv(prod, a), "Mul(x, 1) should equal x");
    }
    #[test]
    fn mul_zero_annihilation() {
        let mut g = EGraph::new();
        let a = g.add(Expr::Num(42));
        let zero = g.add(Expr::Num(0));
        let prod = g.add(Expr::Mul(a, zero));
        g.apply_rules();
        assert!(g.equiv(prod, zero), "Mul(x, 0) should equal 0");
    }
    #[test]
    fn max_idempotent() {
        let mut g = EGraph::new();
        let a = g.add(Expr::Num(42));
        let m = g.add(Expr::Max(a, a));
        g.apply_rules();
        assert!(g.equiv(m, a), "Max(x, x) should equal x");
    }
    #[test]
    fn min_idempotent() {
        let mut g = EGraph::new();
        let a = g.add(Expr::Num(42));
        let m = g.add(Expr::Min(a, a));
        g.apply_rules();
        assert!(g.equiv(m, a), "Min(x, x) should equal x");
    }
    #[test]
    fn clamp_unclamped() {
        let mut g = EGraph::new();
        let zero = g.add(Expr::Num(0));
        let max = g.add(Expr::Num(u16::MAX));
        let val = g.add(Expr::Num(50));
        let clamped = g.add(Expr::Clamp {
            min: zero,
            max,
            val,
        });
        g.apply_rules();
        assert!(g.equiv(clamped, val), "Clamp(0, MAX, x) should equal x");
    }

    // ---- Extraction ----

    #[test]
    fn extract_prefers_simpler() {
        let mut g = EGraph::new();
        let a = g.add(Expr::Num(100));
        let zero = g.add(Expr::Num(0));
        let sum = g.add(Expr::Add(a, zero));
        g.apply_rules();
        let extracted = g.extract(sum);
        assert_eq!(extracted, Expr::Num(100));
    }
    #[test]
    fn extract_var_over_complex() {
        let mut g = EGraph::new();
        let v = g.add(Expr::Var(NodeId(0)));
        let zero = g.add(Expr::Num(0));
        let sum = g.add(Expr::Add(v, zero));
        g.apply_rules();
        let extracted = g.extract(sum);
        assert_eq!(extracted, Expr::Var(NodeId(0)));
    }
    #[test]
    fn counts_after_merge() {
        let mut g = EGraph::new();
        let a = g.add(Expr::Num(1));
        let b = g.add(Expr::Num(2));
        assert_eq!(g.class_count(), 2);
        g.merge(a, b);
        assert_eq!(g.class_count(), 1);
        // `merge` moves the loser's e-nodes into the winner, so both survive.
        assert_eq!(g.node_count(), 2);
    }

    // ---- Constraint encoding ----

    #[test]
    fn encode_fixed_constraint() {
        let mut g = EGraph::new();
        let id = encode_constraint(&mut g, &crate::Constraint::Fixed(50), 200);
        assert_eq!(g.extract(id), Expr::Num(50));
    }
    #[test]
    fn encode_percentage_constraint() {
        let mut g = EGraph::new();
        let id = encode_constraint(&mut g, &crate::Constraint::Percentage(50.0), 200);
        // 50% of 200.
        assert_eq!(g.extract(id), Expr::Num(100));
    }
    #[test]
    fn encode_ratio_constraint() {
        let mut g = EGraph::new();
        let id = encode_constraint(&mut g, &crate::Constraint::Ratio(1, 4), 200);
        // 200 * 1 / 4.
        assert_eq!(g.extract(id), Expr::Num(50));
    }
    #[test]
    fn encode_ratio_zero_denom() {
        let mut g = EGraph::new();
        let id = encode_constraint(&mut g, &crate::Constraint::Ratio(1, 0), 200);
        assert_eq!(g.extract(id), Expr::Num(0));
    }
    #[test]
    fn encode_flex_horizontal() {
        let mut g = EGraph::new();
        let a = g.add(Expr::Num(100));
        let b = g.add(Expr::Num(200));
        let flex = encode_flex(&mut g, &[a, b], true);
        assert!(matches!(g.extract(flex), Expr::HFlex(_)));
    }
    #[test]
    fn encode_flex_vertical() {
        let mut g = EGraph::new();
        let a = g.add(Expr::Num(100));
        let b = g.add(Expr::Num(200));
        let flex = encode_flex(&mut g, &[a, b], false);
        assert!(matches!(g.extract(flex), Expr::VFlex(_)));
    }

    // ---- Rule application and termination ----

    #[test]
    fn saturation_terminates() {
        let mut g = EGraph::new();
        let a = g.add(Expr::Num(10));
        let zero = g.add(Expr::Num(0));
        let one = g.add(Expr::Num(1));
        let mul = g.add(Expr::Mul(a, one));
        let sum = g.add(Expr::Add(mul, zero));
        let total = g.apply_rules();
        assert!(total > 0, "rules should fire");
        assert!(g.equiv(sum, a), "expression should simplify to a");
    }
    #[test]
    fn apply_rules_returns_zero_when_saturated() {
        let mut g = EGraph::new();
        let a = g.add(Expr::Num(10));
        let b = g.add(Expr::Num(20));
        let _sum = g.add(Expr::Add(a, b));
        let total = g.apply_rules();
        assert!(total >= 1, "commutativity should fire");
        let second = g.apply_rules();
        assert_eq!(second, 0, "already saturated");
    }
    #[test]
    fn chained_identity_simplification() {
        let mut g = EGraph::new();
        let x = g.add(Expr::Num(42));
        let zero = g.add(Expr::Num(0));
        let one = g.add(Expr::Num(1));
        let mul = g.add(Expr::Mul(x, one));
        let sub = g.add(Expr::Sub(mul, zero));
        let add = g.add(Expr::Add(sub, zero));
        g.apply_rules();
        assert!(g.equiv(add, x));
    }
    #[test]
    fn typical_layout_size() {
        let mut g = EGraph::new();
        let widgets: Vec<_> = (0..5).map(|i| g.add(Expr::Var(NodeId(i)))).collect();
        let _flex = encode_flex(&mut g, &widgets, true);
        assert!(g.node_count() <= 10, "small layout should be compact");
    }
    #[test]
    fn medium_layout_size() {
        let mut g = EGraph::new();
        let widgets: Vec<_> = (0..100).map(|i| g.add(Expr::Var(NodeId(i)))).collect();
        let _flex = encode_flex(&mut g, &widgets, true);
        assert!(g.node_count() <= 200);
    }
    #[test]
    fn sub_self_is_zero() {
        let mut g = EGraph::new();
        let a = g.add(Expr::Num(42));
        // Sub(x, x) only rewrites when a Num(0) class exists, so add one.
        let zero = g.add(Expr::Num(0));
        let sub = g.add(Expr::Sub(a, a));
        g.apply_rules();
        assert!(g.equiv(sub, zero), "Sub(x, x) should equal 0");
    }
    #[test]
    fn div_one_identity() {
        let mut g = EGraph::new();
        let a = g.add(Expr::Num(42));
        let one = g.add(Expr::Num(1));
        let div = g.add(Expr::Div(a, one));
        g.apply_rules();
        assert!(g.equiv(div, a), "Div(x, 1) should equal x");
    }
    #[test]
    fn div_self_is_one() {
        let mut g = EGraph::new();
        let a = g.add(Expr::Num(42));
        let one = g.add(Expr::Num(1));
        let div = g.add(Expr::Div(a, a));
        g.apply_rules();
        assert!(g.equiv(div, one), "Div(x, x) should equal 1");
    }
    #[test]
    fn max_zero_identity() {
        let mut g = EGraph::new();
        let a = g.add(Expr::Num(42));
        let zero = g.add(Expr::Num(0));
        let m = g.add(Expr::Max(a, zero));
        g.apply_rules();
        assert!(g.equiv(m, a), "Max(x, 0) should equal x");
    }
    #[test]
    fn max_zero_identity_left() {
        let mut g = EGraph::new();
        let a = g.add(Expr::Num(42));
        let zero = g.add(Expr::Num(0));
        let m = g.add(Expr::Max(zero, a));
        g.apply_rules();
        assert!(g.equiv(m, a), "Max(0, x) should equal x");
    }
    #[test]
    fn clamp_equal_bounds() {
        let mut g = EGraph::new();
        let bound = g.add(Expr::Num(50));
        let val = g.add(Expr::Num(100));
        let clamped = g.add(Expr::Clamp {
            min: bound,
            max: bound,
            val,
        });
        g.apply_rules();
        assert!(g.equiv(clamped, bound), "Clamp(x, x, _) should equal x");
    }
    #[test]
    fn clamp_val_equals_min() {
        let mut g = EGraph::new();
        let min = g.add(Expr::Num(10));
        let max = g.add(Expr::Num(100));
        let clamped = g.add(Expr::Clamp { min, max, val: min });
        g.apply_rules();
        assert!(
            g.equiv(clamped, min),
            "Clamp(min, max, min) should equal min"
        );
    }

    // ---- Commutativity (construction rules) ----

    #[test]
    fn add_commutativity() {
        let mut g = EGraph::new();
        let a = g.add(Expr::Num(10));
        let b = g.add(Expr::Num(20));
        let ab = g.add(Expr::Add(a, b));
        let ba = g.add(Expr::Add(b, a));
        g.apply_rules();
        assert!(g.equiv(ab, ba), "Add(a, b) should equal Add(b, a)");
    }
    #[test]
    fn max_commutativity() {
        let mut g = EGraph::new();
        let a = g.add(Expr::Num(10));
        let b = g.add(Expr::Num(20));
        let ab = g.add(Expr::Max(a, b));
        let ba = g.add(Expr::Max(b, a));
        g.apply_rules();
        assert!(g.equiv(ab, ba), "Max(a, b) should equal Max(b, a)");
    }
    #[test]
    fn min_commutativity() {
        let mut g = EGraph::new();
        let a = g.add(Expr::Num(10));
        let b = g.add(Expr::Num(20));
        let ab = g.add(Expr::Min(a, b));
        let ba = g.add(Expr::Min(b, a));
        g.apply_rules();
        assert!(g.equiv(ab, ba), "Min(a, b) should equal Min(b, a)");
    }
    #[test]
    fn mul_commutativity() {
        let mut g = EGraph::new();
        let a = g.add(Expr::Num(3));
        let b = g.add(Expr::Num(7));
        let ab = g.add(Expr::Mul(a, b));
        let ba = g.add(Expr::Mul(b, a));
        g.apply_rules();
        assert!(g.equiv(ab, ba), "Mul(a, b) should equal Mul(b, a)");
    }

    // ---- Saturation guards ----

    #[test]
    fn saturate_basic() {
        let mut g = EGraph::new();
        let a = g.add(Expr::Num(100));
        let zero = g.add(Expr::Num(0));
        let sum = g.add(Expr::Add(a, zero));
        let result = g.saturate(&SaturationConfig::default());
        assert!(result.saturated);
        assert!(!result.stopped_early);
        assert!(result.rewrites > 0);
        assert!(g.equiv(sum, a));
    }
    #[test]
    fn saturate_with_node_budget() {
        let mut g = EGraph::new();
        for i in 0..50u16 {
            let v = g.add(Expr::Var(NodeId(i as u32)));
            let n = g.add(Expr::Num(i + 1));
            g.add(Expr::Add(v, n));
        }
        let initial = g.node_count();
        // Commutativity adds one node per Add, so 50 new nodes exceed +10.
        let config = SaturationConfig {
            node_budget: initial + 10,
            iteration_limit: 1000,
            time_limit_us: 0,
            memory_limit: 0,
        };
        let result = g.saturate(&config);
        assert!(result.stopped_early, "should stop due to node budget");
        assert_eq!(result.guard, GuardTriggered::NodeBudget);
    }
    #[test]
    fn saturate_with_iteration_limit() {
        let mut g = EGraph::new();
        let a = g.add(Expr::Num(10));
        let b = g.add(Expr::Num(20));
        let _sum = g.add(Expr::Add(a, b));
        // The single allowed round still makes progress (commutativity), so
        // no fixpoint is observed before the limit.
        let config = SaturationConfig {
            node_budget: 10_000,
            iteration_limit: 1,
            time_limit_us: 0,
            memory_limit: 0,
        };
        let result = g.saturate(&config);
        assert!(result.iterations <= 1);
        assert_eq!(result.guard, GuardTriggered::IterationLimit);
    }
    #[test]
    fn saturate_with_memory_limit() {
        let mut g = EGraph::new();
        for i in 0..200u16 {
            let v = g.add(Expr::Var(NodeId(i as u32)));
            let n = g.add(Expr::Num(i + 1));
            g.add(Expr::Add(v, n));
        }
        let config = SaturationConfig {
            node_budget: 100_000,
            iteration_limit: 1000,
            time_limit_us: 0,
            // Any non-empty graph's estimate exceeds a 1-byte budget.
            memory_limit: 1,
        };
        let result = g.saturate(&config);
        assert!(result.stopped_early);
        assert_eq!(result.guard, GuardTriggered::Memory);
    }
    #[test]
    fn saturate_guard_none_on_completion() {
        let mut g = EGraph::new();
        let a = g.add(Expr::Num(100));
        let zero = g.add(Expr::Num(0));
        let _sum = g.add(Expr::Add(a, zero));
        let result = g.saturate(&SaturationConfig::default());
        assert!(result.saturated);
        assert_eq!(result.guard, GuardTriggered::None);
        assert!(result.time_us > 0 || result.iterations > 0);
        assert!(result.memory_bytes > 0);
    }
    #[test]
    fn saturate_result_has_timing() {
        let mut g = EGraph::new();
        for i in 0..100u16 {
            let v = g.add(Expr::Var(NodeId(i as u32)));
            let zero = g.add(Expr::Num(0));
            g.add(Expr::Add(v, zero));
        }
        let result = g.saturate(&SaturationConfig::default());
        assert!(result.memory_bytes > 0);
        assert!(result.node_count > 0);
    }
    #[test]
    fn saturate_memory_bounded() {
        let mut g = EGraph::new();
        for i in 0..500u16 {
            let v = g.add(Expr::Var(NodeId(i as u32)));
            let zero = g.add(Expr::Num(0));
            g.add(Expr::Add(v, zero));
        }
        g.saturate(&SaturationConfig::default());
        let mem = g.memory_usage();
        assert!(mem < 10 * 1024 * 1024, "memory {} exceeds 10MB budget", mem);
    }

    // ---- End-to-end solving ----

    #[test]
    fn solve_fixed_constraints() {
        let constraints = vec![
            crate::Constraint::Fixed(50),
            crate::Constraint::Fixed(100),
            crate::Constraint::Fixed(50),
        ];
        let sizes = solve_layout_default(&constraints, 200);
        assert_eq!(sizes, vec![50, 100, 50]);
    }
    #[test]
    fn solve_percentage_constraints() {
        let constraints = vec![
            crate::Constraint::Percentage(25.0),
            crate::Constraint::Percentage(75.0),
        ];
        let sizes = solve_layout_default(&constraints, 200);
        assert_eq!(sizes, vec![50, 150]);
    }
    #[test]
    fn solve_ratio_constraints() {
        let constraints = vec![
            crate::Constraint::Ratio(1, 3),
            crate::Constraint::Ratio(2, 3),
        ];
        let sizes = solve_layout_default(&constraints, 300);
        assert_eq!(sizes, vec![100, 200]);
    }
    #[test]
    fn solve_fill_constraint() {
        let constraints = vec![crate::Constraint::Fill];
        let sizes = solve_layout_default(&constraints, 120);
        assert_eq!(sizes, vec![120]);
    }
    #[test]
    fn solve_mixed_constraints() {
        let constraints = vec![
            crate::Constraint::Fixed(30),
            crate::Constraint::Percentage(50.0),
            crate::Constraint::Ratio(1, 4),
        ];
        let sizes = solve_layout_default(&constraints, 200);
        assert_eq!(sizes, vec![30, 100, 50]);
    }
    #[test]
    fn solve_empty_constraints() {
        let constraints: Vec<crate::Constraint> = vec![];
        let sizes = solve_layout_default(&constraints, 200);
        assert!(sizes.is_empty());
    }
    #[test]
    fn solve_returns_saturation_result() {
        let constraints = vec![crate::Constraint::Fixed(50)];
        let (sizes, result) = solve_layout(&constraints, 200, &SaturationConfig::default());
        assert_eq!(sizes, vec![50]);
        assert!(result.saturated);
    }
    #[test]
    fn solve_500_widgets() {
        let constraints: Vec<_> = (0..500)
            .map(|i| crate::Constraint::Fixed(i as u16 % 100))
            .collect();
        let config = SaturationConfig::default();
        let (sizes, result) = solve_layout(&constraints, 1000, &config);
        assert_eq!(sizes.len(), 500);
        for (i, &s) in sizes.iter().enumerate() {
            assert_eq!(s, i as u16 % 100);
        }
        assert!(!result.stopped_early || result.node_count <= config.node_budget + 500);
    }
    #[test]
    fn solve_fit_content_bounded() {
        let constraints = vec![crate::Constraint::FitContentBounded { min: 10, max: 50 }];
        let sizes = solve_layout_default(&constraints, 200);
        assert_eq!(sizes.len(), 1);
    }

    // ---- Configuration ----

    #[test]
    fn config_default_values() {
        let config = SaturationConfig::default();
        assert_eq!(config.node_budget, 10_000);
        assert_eq!(config.time_limit_us, 5_000);
        assert_eq!(config.iteration_limit, 100);
        assert_eq!(config.memory_limit, 10 * 1024 * 1024);
    }
    #[test]
    fn from_env_returns_valid_config() {
        let config = SaturationConfig::from_env();
        assert!(config.node_budget > 0);
        assert!(config.iteration_limit > 0);
    }

    // ---- Robustness: bounded time and memory on many inputs ----

    #[test]
    fn random_constraints_never_oom_or_hang() {
        let constraint_sets: Vec<Vec<crate::Constraint>> = vec![
            (0..1000)
                .map(|i| crate::Constraint::Fixed(i as u16))
                .collect(),
            (0..500).map(|_| crate::Constraint::Fill).collect(),
            (0..200)
                .map(|i| crate::Constraint::Percentage(i as f32 * 0.5))
                .collect(),
            (0..100)
                .map(|i| crate::Constraint::Ratio(i + 1, 100))
                .collect(),
            (0..300).map(|i| crate::Constraint::Min(i as u16)).collect(),
            (0..300).map(|i| crate::Constraint::Max(i as u16)).collect(),
            (0..500)
                .map(|i| crate::Constraint::FitContentBounded {
                    min: i as u16,
                    max: i as u16 + 100,
                })
                .collect(),
        ];
        // A looser time limit (50 ms) than the 5 ms default.
        let config = SaturationConfig {
            node_budget: 10_000,
            iteration_limit: 100,
            time_limit_us: 50_000,
            memory_limit: 10 * 1024 * 1024,
        };
        for constraints in &constraint_sets {
            let (sizes, result) = solve_layout(constraints, 1000, &config);
            assert_eq!(sizes.len(), constraints.len());
            assert!(
                result.memory_bytes <= config.memory_limit + 1024 * 1024,
                "memory {} exceeded limit + margin for {:?}",
                result.memory_bytes,
                result.guard,
            );
        }
    }
    #[test]
    fn fuzz_1000_random_constraint_sets() {
        // xorshift32 with a fixed seed keeps the fuzz inputs reproducible.
        let mut seed: u32 = 42;
        let mut rng = || -> u32 {
            seed ^= seed << 13;
            seed ^= seed >> 17;
            seed ^= seed << 5;
            seed
        };
        let config = SaturationConfig {
            node_budget: 10_000,
            iteration_limit: 100,
            time_limit_us: 50_000,
            memory_limit: 10 * 1024 * 1024,
        };
        for _ in 0..1000 {
            let count = (rng() % 50 + 1) as usize;
            let total = (rng() % 500 + 1) as u16;
            let constraints: Vec<_> = (0..count)
                .map(|_| {
                    let kind = rng() % 7;
                    match kind {
                        0 => crate::Constraint::Fixed((rng() % (total as u32 + 1)) as u16),
                        1 => crate::Constraint::Percentage((rng() % 101) as f32),
                        2 => crate::Constraint::Min((rng() % (total as u32 + 1)) as u16),
                        3 => crate::Constraint::Max((rng() % (total as u32 + 1)) as u16),
                        4 => {
                            let den = rng() % 10 + 1;
                            let num = rng() % (den + 1);
                            crate::Constraint::Ratio(num, den)
                        }
                        5 => crate::Constraint::Fill,
                        _ => {
                            let min = (rng() % (total as u32 + 1)) as u16;
                            let max = min.saturating_add((rng() % 100) as u16);
                            crate::Constraint::FitContentBounded { min, max }
                        }
                    }
                })
                .collect();
            let (sizes, result) = solve_layout(&constraints, total, &config);
            assert_eq!(sizes.len(), constraints.len());
            assert!(result.memory_bytes < config.memory_limit + 2 * 1024 * 1024);
        }
    }
    #[test]
    fn deterministic_across_runs() {
        let constraints = vec![
            crate::Constraint::Fixed(30),
            crate::Constraint::Percentage(50.0),
            crate::Constraint::Ratio(1, 4),
            crate::Constraint::Fill,
            crate::Constraint::Min(10),
            crate::Constraint::Max(80),
        ];
        // A value of 0 disables the time and memory guards, so wall-clock
        // jitter cannot change the outcome between the two runs.
        let config = SaturationConfig {
            node_budget: 10_000,
            iteration_limit: 100,
            time_limit_us: 0,
            memory_limit: 0,
        };
        let (sizes1, r1) = solve_layout(&constraints, 200, &config);
        let (sizes2, r2) = solve_layout(&constraints, 200, &config);
        assert_eq!(sizes1, sizes2, "sizes must be deterministic");
        assert_eq!(r1.rewrites, r2.rewrites, "rewrites must be deterministic");
        assert_eq!(
            r1.iterations, r2.iterations,
            "iterations must be deterministic"
        );
        assert_eq!(
            r1.node_count, r2.node_count,
            "node_count must be deterministic"
        );
        assert_eq!(r1.guard, r2.guard, "guard must be deterministic");
    }

    // ---- Evidence records ----

    #[test]
    fn evidence_record_from_result() {
        let constraints = vec![
            crate::Constraint::Fixed(50),
            crate::Constraint::Percentage(50.0),
        ];
        let (_, result) = solve_layout(&constraints, 200, &SaturationConfig::default());
        let record = EvidenceRecord::from_result("test_case", 2, 200, &result);
        assert_eq!(record.test_name, "test_case");
        assert_eq!(record.constraint_count, 2);
        assert_eq!(record.total_space, 200);
        assert!(record.nodes_at_completion > 0);
    }
    #[test]
    fn evidence_record_to_jsonl() {
        let result = SaturationResult {
            rewrites: 5,
            iterations: 3,
            saturated: true,
            stopped_early: false,
            node_count: 42,
            guard: GuardTriggered::None,
            time_us: 1234,
            memory_bytes: 8192,
        };
        let record = EvidenceRecord::from_result("demo", 10, 200, &result);
        let jsonl = record.to_jsonl();
        assert!(jsonl.starts_with('{'));
        assert!(jsonl.ends_with('}'));
        assert!(jsonl.contains("\"test\":\"demo\""));
        assert!(jsonl.contains("\"constraints\":10"));
        assert!(jsonl.contains("\"nodes\":42"));
        assert!(jsonl.contains("\"saturated\":true"));
        assert!(jsonl.contains("\"guard\":\"None\""));
    }
    #[test]
    fn evidence_records_for_all_guards() {
        for guard in [
            GuardTriggered::None,
            GuardTriggered::NodeBudget,
            GuardTriggered::Timeout,
            GuardTriggered::Memory,
            GuardTriggered::IterationLimit,
        ] {
            let result = SaturationResult {
                rewrites: 0,
                iterations: 1,
                saturated: guard == GuardTriggered::None,
                stopped_early: guard != GuardTriggered::None,
                node_count: 10,
                guard,
                time_us: 100,
                memory_bytes: 1024,
            };
            let record = EvidenceRecord::from_result("guard_test", 1, 100, &result);
            let jsonl = record.to_jsonl();
            assert!(jsonl.contains(&format!("{guard:?}")));
        }
    }

    // ---- Edge cases ----

    #[test]
    fn empty_input_produces_empty_layout() {
        let (sizes, result) = solve_layout(&[], 200, &SaturationConfig::default());
        assert!(sizes.is_empty());
        assert!(result.saturated);
    }
    #[test]
    fn single_constraint_no_rewriting_needed() {
        let (sizes, result) = solve_layout(
            &[crate::Constraint::Fixed(42)],
            200,
            &SaturationConfig::default(),
        );
        assert_eq!(sizes, vec![42]);
        assert!(result.saturated);
        assert_eq!(result.guard, GuardTriggered::None);
    }
    #[test]
    fn zero_total_space() {
        let constraints = vec![
            crate::Constraint::Fixed(50),
            crate::Constraint::Percentage(50.0),
            crate::Constraint::Fill,
        ];
        let (sizes, _) = solve_layout(&constraints, 0, &SaturationConfig::default());
        assert_eq!(sizes.len(), 3);
        // `Fixed` does not depend on the available space.
        assert_eq!(sizes[0], 50);
        // 50% of 0 is 0.
        assert_eq!(sizes[1], 0);
    }
}