use crate::kernel::{Domain, ExprId, ExprPool};
use std::fmt;
/// A condition recorded alongside a [`RewriteStep`] (see its `side_conditions` field).
///
/// This module only stores and renders conditions; whether a condition is actually
/// checked before a rewrite is applied is up to the code that builds the step.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SideCondition {
    /// The expression is non-zero; rendered with a pool as `expr ≠ 0`.
    NonZero(ExprId),
    /// The expression is strictly positive; rendered with a pool as `expr > 0`.
    Positive(ExprId),
    /// The expression lies in the given domain; rendered with a pool as `expr ∈ Domain`
    /// (the domain is printed with its `Debug` representation).
    InDomain(ExprId, Domain),
}
impl SideCondition {
    /// Returns an adapter that renders this condition in mathematical notation
    /// (e.g. `x ≠ 0`), using `pool` to print the expression it refers to.
    ///
    /// Nothing is formatted until the adapter is used with `{}` (for example via
    /// `to_string()`), so discarding the return value is always a mistake.
    #[must_use = "this does not display the condition, it returns an object that can be displayed"]
    pub fn display_with<'a>(&'a self, pool: &'a ExprPool) -> SideConditionDisplay<'a> {
        SideConditionDisplay { cond: self, pool }
    }
}
impl fmt::Display for SideCondition {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
SideCondition::NonZero(id) => write!(f, "nonzero({id:?})"),
SideCondition::Positive(id) => write!(f, "positive({id:?})"),
SideCondition::InDomain(id, d) => write!(f, "in_domain({id:?}, {d:?})"),
}
}
}
/// Pool-aware renderer for a [`SideCondition`], produced by
/// [`SideCondition::display_with`].
///
/// Expressions are printed through `ExprPool::display`; the domain of an
/// `InDomain` condition is printed with its `Debug` representation.
pub struct SideConditionDisplay<'a> {
    /// The condition being rendered.
    cond: &'a SideCondition,
    /// Pool used to print the condition's expression.
    pool: &'a ExprPool,
}
impl fmt::Display for SideConditionDisplay<'_> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let pool = self.pool;
        // Expression ids are copied out; the domain is only borrowed.
        match *self.cond {
            SideCondition::NonZero(expr) => write!(f, "{} ≠ 0", pool.display(expr)),
            SideCondition::Positive(expr) => write!(f, "{} > 0", pool.display(expr)),
            SideCondition::InDomain(expr, ref domain) => {
                write!(f, "{} ∈ {:?}", pool.display(expr), domain)
            }
        }
    }
}
/// One application of a named rewrite rule: the expression it was applied to,
/// the expression it produced, and any side conditions recorded with it.
#[derive(Debug, Clone)]
pub struct RewriteStep {
    /// Name of the rule that was applied.
    pub rule_name: &'static str,
    /// Expression before the rewrite.
    pub before: ExprId,
    /// Expression after the rewrite.
    pub after: ExprId,
    /// Conditions attached to this step (empty for steps built with [`RewriteStep::simple`]).
    pub side_conditions: Vec<SideCondition>,
}
impl RewriteStep {
    /// Builds a step that carries no side conditions.
    pub fn simple(rule_name: &'static str, before: ExprId, after: ExprId) -> Self {
        Self::with_conditions(rule_name, before, after, Vec::new())
    }

    /// Builds a step carrying the given side conditions, stored in the order supplied.
    pub fn with_conditions(
        rule_name: &'static str,
        before: ExprId,
        after: ExprId,
        side_conditions: Vec<SideCondition>,
    ) -> Self {
        Self {
            rule_name,
            before,
            after,
            side_conditions,
        }
    }
}
/// An ordered record of the rewrite steps that produced an expression.
///
/// Steps are kept in the order they were pushed or merged. The inner `Vec` is
/// public for direct access; [`DerivationLog::steps`] offers a read-only slice.
#[derive(Debug, Clone, Default)]
pub struct DerivationLog(pub Vec<RewriteStep>);
impl DerivationLog {
    /// Creates an empty log (equivalent to `DerivationLog::default()`).
    pub fn new() -> Self {
        DerivationLog(Vec::new())
    }

    /// Appends `step` to the end of the log.
    pub fn push(&mut self, step: RewriteStep) {
        self.0.push(step);
    }

    /// Concatenates two logs: every step of `self`, followed by every step of `other`.
    ///
    /// Both logs are consumed; the returned log is the only place their steps
    /// survive, so ignoring the result silently loses them.
    #[must_use = "`merge` consumes both logs and returns the combined log"]
    pub fn merge(mut self, other: DerivationLog) -> DerivationLog {
        if self.0.is_empty() {
            // Nothing to prepend: hand back `other` unchanged instead of copying
            // all of its steps into `self`'s empty buffer. This is the common case
            // when chaining from `DerivedExpr::new`.
            return other;
        }
        self.0.extend(other.0);
        self
    }

    /// Number of recorded steps.
    pub fn len(&self) -> usize {
        self.0.len()
    }

    /// `true` when no steps have been recorded.
    pub fn is_empty(&self) -> bool {
        self.0.is_empty()
    }

    /// The recorded steps, in order.
    pub fn steps(&self) -> &[RewriteStep] {
        &self.0
    }

    /// Returns an adapter that renders the log using `pool` to print expressions
    /// and side conditions. Nothing is formatted until the adapter is displayed.
    #[must_use = "this does not display the log, it returns an object that can be displayed"]
    pub fn display_with<'a>(&'a self, pool: &'a ExprPool) -> LogDisplay<'a> {
        LogDisplay { log: self, pool }
    }
}
/// Pool-free rendering of the log.
///
/// Each step is printed on its own line as `step N: rule (before → after)`, with
/// expression ids shown via `Debug`. It is followed by each side condition in
/// brackets, using [`SideCondition`]'s pool-free `Display`. An empty log renders
/// as `(no steps)`. Use [`DerivationLog::display_with`] for output that prints the
/// expressions themselves.
impl fmt::Display for DerivationLog {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        if self.0.is_empty() {
            return write!(f, "(no steps)");
        }
        for (i, step) in self.0.iter().enumerate() {
            if i > 0 {
                writeln!(f)?;
            }
            write!(
                f,
                "step {}: {} ({:?} → {:?})",
                i + 1,
                step.rule_name,
                step.before,
                step.after
            )?;
            // Side conditions are part of what a step means, so list them here too,
            // in the same ` [..]` layout that `LogDisplay` uses, instead of dropping them.
            for cond in &step.side_conditions {
                write!(f, " [{cond}]")?;
            }
        }
        Ok(())
    }
}
/// Pool-aware renderer for a [`DerivationLog`], produced by
/// [`DerivationLog::display_with`].
///
/// Each step is printed as `step N: rule applied to BEFORE → AFTER`. Expressions
/// are rendered through `ExprPool::display`, and each side condition follows as
/// ` [..]`. Steps are separated by newlines, with no trailing newline. An empty
/// log renders as `(no steps)`.
pub struct LogDisplay<'a> {
    /// The log being rendered.
    log: &'a DerivationLog,
    /// Pool used to print expressions and side conditions.
    pool: &'a ExprPool,
}
impl fmt::Display for LogDisplay<'_> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let steps = &self.log.0;
        if steps.is_empty() {
            return f.write_str("(no steps)");
        }
        let pool = self.pool;
        for (index, step) in steps.iter().enumerate() {
            // The newline goes *before* every step except the first.
            let separator = if index == 0 { "" } else { "\n" };
            write!(
                f,
                "{separator}step {}: {} applied to {} → {}",
                index + 1,
                step.rule_name,
                pool.display(step.before),
                pool.display(step.after),
            )?;
            step.side_conditions
                .iter()
                .try_for_each(|cond| write!(f, " [{}]", cond.display_with(pool)))?;
        }
        Ok(())
    }
}
/// A value paired with the [`DerivationLog`] of rewrite steps that produced it.
#[derive(Debug, Clone)]
pub struct DerivedExpr<T> {
    /// The result of the derivation.
    pub value: T,
    /// Steps recorded while producing `value`, in the order they were recorded.
    pub log: DerivationLog,
}
impl<T> DerivedExpr<T> {
    /// Wraps `value` with an empty log.
    pub fn new(value: T) -> Self {
        Self::with_log(value, DerivationLog::new())
    }

    /// Wraps `value` with a log containing exactly `step`.
    pub fn with_step(value: T, step: RewriteStep) -> Self {
        let mut single = DerivationLog::new();
        single.push(step);
        Self::with_log(value, single)
    }

    /// Wraps `value` with an existing log.
    pub fn with_log(value: T, log: DerivationLog) -> Self {
        Self { value, log }
    }

    /// Transforms the value with `f`, keeping the log untouched.
    pub fn map<U, F: FnOnce(T) -> U>(self, f: F) -> DerivedExpr<U> {
        let Self { value, log } = self;
        DerivedExpr::with_log(f(value), log)
    }

    /// Chains another derivation: runs `f` on the value and returns its result,
    /// with this log's steps followed by the steps `f` recorded.
    pub fn and_then<U, F: FnOnce(T) -> DerivedExpr<U>>(self, f: F) -> DerivedExpr<U> {
        let Self { value, log } = self;
        let DerivedExpr {
            value: next_value,
            log: next_log,
        } = f(value);
        DerivedExpr::with_log(next_value, log.merge(next_log))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::kernel::{Domain, ExprPool};

    /// Fresh pool plus a real-valued symbol `x` interned in it.
    fn pool_and_x() -> (ExprPool, ExprId) {
        let p = ExprPool::new();
        let x = p.symbol("x", Domain::Real);
        (p, x)
    }

    #[test]
    fn log_push_and_len() {
        let (p, x) = pool_and_x();
        let one = p.integer(1_i32);
        let mut log = DerivationLog::new();
        assert!(log.is_empty());
        log.push(RewriteStep::simple("test_rule", x, one));
        assert_eq!(log.len(), 1);
    }

    #[test]
    fn log_merge_order() {
        let (p, x) = pool_and_x();
        let one = p.integer(1_i32);
        let two = p.integer(2_i32);
        let mut a = DerivationLog::new();
        a.push(RewriteStep::simple("rule_a", x, one));
        let mut b = DerivationLog::new();
        b.push(RewriteStep::simple("rule_b", one, two));
        let merged = a.merge(b);
        assert_eq!(merged.len(), 2);
        assert_eq!(merged.steps()[0].rule_name, "rule_a");
        assert_eq!(merged.steps()[1].rule_name, "rule_b");
    }

    /// Merging with an empty log on either side keeps the non-empty side's steps.
    #[test]
    fn log_merge_with_empty() {
        let (p, x) = pool_and_x();
        let one = p.integer(1_i32);
        let mut b = DerivationLog::new();
        b.push(RewriteStep::simple("rule_b", x, one));
        let left = DerivationLog::new().merge(b.clone());
        assert_eq!(left.len(), 1);
        assert_eq!(left.steps()[0].rule_name, "rule_b");
        let right = b.merge(DerivationLog::new());
        assert_eq!(right.len(), 1);
        assert_eq!(right.steps()[0].rule_name, "rule_b");
    }

    #[test]
    fn display_without_pool() {
        let (p, x) = pool_and_x();
        let one = p.integer(1_i32);
        let mut log = DerivationLog::new();
        log.push(RewriteStep::simple("add_zero", x, one));
        let s = log.to_string();
        assert!(s.contains("step 1"), "should have step 1: {s}");
        assert!(s.contains("add_zero"), "should mention rule: {s}");
    }

    #[test]
    fn display_with_pool() {
        let (p, x) = pool_and_x();
        let one = p.integer(1_i32);
        let mut log = DerivationLog::new();
        log.push(RewriteStep::simple("diff_identity", x, one));
        let s = log.display_with(&p).to_string();
        assert!(s.contains("diff_identity"), "{s}");
        assert!(s.contains('x'), "{s}");
        assert!(s.contains('1'), "{s}");
    }

    /// Both renderings use the same placeholder for an empty log.
    #[test]
    fn empty_log_renders_placeholder() {
        let p = ExprPool::new();
        let log = DerivationLog::new();
        assert_eq!(log.to_string(), "(no steps)");
        assert_eq!(log.display_with(&p).to_string(), "(no steps)");
    }

    /// The pool-aware rendering lists every side condition of a step.
    #[test]
    fn display_with_pool_lists_side_conditions() {
        let (p, x) = pool_and_x();
        let one = p.integer(1_i32);
        let mut log = DerivationLog::new();
        log.push(RewriteStep::with_conditions(
            "div_self",
            x,
            one,
            vec![SideCondition::NonZero(x), SideCondition::Positive(x)],
        ));
        let s = log.display_with(&p).to_string();
        assert!(s.contains("div_self"), "{s}");
        assert!(s.contains("≠ 0"), "{s}");
        assert!(s.contains("> 0"), "{s}");
    }

    #[test]
    fn side_condition_display() {
        let (p, x) = pool_and_x();
        let cond = SideCondition::NonZero(x);
        let s = cond.display_with(&p).to_string();
        assert!(s.contains('x'), "{s}");
        assert!(s.contains('0'), "{s}");
    }

    #[test]
    fn side_condition_display_other_variants() {
        let (p, x) = pool_and_x();
        let pos = SideCondition::Positive(x).display_with(&p).to_string();
        assert!(pos.contains("> 0"), "{pos}");
        let dom = SideCondition::InDomain(x, Domain::Real)
            .display_with(&p)
            .to_string();
        assert!(dom.contains('∈'), "{dom}");
    }

    /// The pool-free rendering names the condition kind.
    #[test]
    fn side_condition_plain_display() {
        let (_p, x) = pool_and_x();
        assert!(SideCondition::NonZero(x).to_string().starts_with("nonzero("));
        assert!(SideCondition::Positive(x).to_string().starts_with("positive("));
        assert!(SideCondition::InDomain(x, Domain::Real)
            .to_string()
            .starts_with("in_domain("));
    }

    #[test]
    fn derived_expr_map() {
        let d: DerivedExpr<i32> = DerivedExpr::new(5);
        let doubled = d.map(|v| v * 2);
        assert_eq!(doubled.value, 10);
        assert!(doubled.log.is_empty());
    }

    #[test]
    fn derived_expr_and_then_merges_logs() {
        let (p, x) = pool_and_x();
        let one = p.integer(1_i32);
        let two = p.integer(2_i32);
        let d = DerivedExpr::with_step(x, RewriteStep::simple("step_a", x, one));
        let result =
            d.and_then(|_| DerivedExpr::with_step(two, RewriteStep::simple("step_b", one, two)));
        assert_eq!(result.value, two);
        assert_eq!(result.log.len(), 2);
        assert_eq!(result.log.steps()[0].rule_name, "step_a");
        assert_eq!(result.log.steps()[1].rule_name, "step_b");
    }

    /// Chaining from an empty log yields exactly the inner derivation's log.
    #[test]
    fn derived_expr_and_then_from_empty_log() {
        let (p, x) = pool_and_x();
        let one = p.integer(1_i32);
        let d: DerivedExpr<i32> = DerivedExpr::new(1);
        let result = d.and_then(|v| {
            DerivedExpr::with_step(v + 1, RewriteStep::simple("step_only", x, one))
        });
        assert_eq!(result.value, 2);
        assert_eq!(result.log.len(), 1);
        assert_eq!(result.log.steps()[0].rule_name, "step_only");
    }
}