use std::path::Path;
use rustc_hash::FxHashMap;
use serde::Deserialize;
use crate::optimizer::eqsat::{EClassId, EGraph, ENodeLang, Rule};
pub trait OpIdNode {
fn op_id(&self) -> Option<&str>;
}
#[derive(Debug, Clone, PartialEq, Eq, Deserialize)]
pub struct EquivalenceRule {
pub left: String,
pub right: String,
}
#[derive(Debug, Clone, Deserialize)]
struct RuleFile {
#[serde(default)]
schema: u32,
#[serde(default)]
equivalence: Vec<EquivalenceRule>,
}
#[derive(Debug, Clone)]
pub struct TomlEquivalenceRules {
name: &'static str,
rules: Vec<EquivalenceRule>,
}
impl TomlEquivalenceRules {
#[must_use]
pub fn new(name: &'static str) -> Self {
Self {
name,
rules: Vec::new(),
}
}
pub fn load(name: &'static str, path: &Path) -> std::io::Result<Self> {
let text = std::fs::read_to_string(path)?;
Self::from_toml_str(name, &text)
}
pub fn from_toml_str(name: &'static str, text: &str) -> std::io::Result<Self> {
let parsed: RuleFile = toml::from_str(text)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
if parsed.schema != 1 {
return Err(std::io::Error::new(
std::io::ErrorKind::InvalidData,
format!(
"Fix: TOML rule file declares schema = {}, expected schema = 1.",
parsed.schema
),
));
}
Ok(Self {
name,
rules: parsed.equivalence,
})
}
#[must_use]
pub fn len(&self) -> usize {
self.rules.len()
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.rules.is_empty()
}
pub fn iter(&self) -> impl Iterator<Item = &EquivalenceRule> {
self.rules.iter()
}
}
impl<L> Rule<L> for TomlEquivalenceRules
where
L: ENodeLang + OpIdNode,
{
fn name(&self) -> &'static str {
self.name
}
fn matches(&self, egraph: &EGraph<L>) -> Vec<(EClassId, EClassId)> {
if self.rules.is_empty() {
return Vec::new();
}
let mut by_op: FxHashMap<&str, Vec<EClassId>> = FxHashMap::default();
for (cid, node) in egraph.iter_nodes() {
if let Some(op_id) = node.op_id() {
by_op.entry(op_id).or_default().push(cid);
}
}
let mut equivs = Vec::new();
for rule in &self.rules {
let lefts = by_op.get(rule.left.as_str());
let rights = by_op.get(rule.right.as_str());
if let (Some(lefts), Some(rights)) = (lefts, rights) {
for &a in lefts {
for &b in rights {
if a != b {
equivs.push((a, b));
}
}
}
}
}
equivs
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::optimizer::eqsat::{EChildren, EGraph};
use std::hash::{Hash, Hasher};
#[derive(Debug, Clone, PartialEq, Eq)]
enum Toy {
Named(&'static str, Vec<EClassId>),
Lit(u32),
}
impl Hash for Toy {
fn hash<H: Hasher>(&self, state: &mut H) {
match self {
Toy::Named(name, children) => {
state.write_u8(0);
name.hash(state);
for c in children {
c.hash(state);
}
}
Toy::Lit(v) => {
state.write_u8(1);
v.hash(state);
}
}
}
}
impl ENodeLang for Toy {
fn children(&self) -> EChildren {
match self {
Toy::Named(_, kids) => kids.iter().copied().collect(),
Toy::Lit(_) => EChildren::new(),
}
}
fn with_children(&self, children: &[EClassId]) -> Self {
match self {
Toy::Named(name, _) => Toy::Named(name, children.to_vec()),
Toy::Lit(v) => Toy::Lit(*v),
}
}
}
impl OpIdNode for Toy {
fn op_id(&self) -> Option<&str> {
match self {
Toy::Named(name, _) => Some(name),
Toy::Lit(_) => None,
}
}
}
#[test]
fn from_toml_str_parses_equivalence_pairs() {
let toml = r#"
schema = 1
[[equivalence]]
left = "a"
right = "b"
[[equivalence]]
left = "c"
right = "d"
"#;
let rules = TomlEquivalenceRules::from_toml_str("test", toml).unwrap();
assert_eq!(rules.len(), 2);
assert_eq!(rules.iter().next().unwrap().left, "a");
}
#[test]
fn from_toml_str_rejects_wrong_schema() {
let toml = "schema = 99\nequivalence = []\n";
let err = TomlEquivalenceRules::from_toml_str("test", toml).unwrap_err();
assert!(format!("{err}").contains("schema = 1"));
}
#[test]
fn from_toml_str_accepts_empty_equivalence() {
let toml = "schema = 1\n";
let rules = TomlEquivalenceRules::from_toml_str("test", toml).unwrap();
assert!(rules.is_empty());
}
#[test]
fn matches_returns_empty_when_no_rules() {
let mut egraph: EGraph<Toy> = EGraph::new();
let _ = egraph.add(Toy::Named("a", vec![]));
let _ = egraph.add(Toy::Named("b", vec![]));
let rules = TomlEquivalenceRules::new("empty");
assert!(rules.matches(&egraph).is_empty());
}
#[test]
fn matches_emits_pair_when_both_op_ids_present() {
let mut egraph: EGraph<Toy> = EGraph::new();
let a = egraph.add(Toy::Named("a", vec![]));
let b = egraph.add(Toy::Named("b", vec![]));
let toml = "schema = 1\n[[equivalence]]\nleft = \"a\"\nright = \"b\"\n";
let rules = TomlEquivalenceRules::from_toml_str("test", toml).unwrap();
let pairs = rules.matches(&egraph);
assert_eq!(pairs.len(), 1);
assert!(
(pairs[0].0 == a && pairs[0].1 == b) || (pairs[0].0 == b && pairs[0].1 == a),
"expected (a, b) pair; got {pairs:?}"
);
}
#[test]
fn matches_empty_when_one_side_absent() {
let mut egraph: EGraph<Toy> = EGraph::new();
let _ = egraph.add(Toy::Named("a", vec![]));
let toml = "schema = 1\n[[equivalence]]\nleft = \"a\"\nright = \"b\"\n";
let rules = TomlEquivalenceRules::from_toml_str("test", toml).unwrap();
assert!(rules.matches(&egraph).is_empty());
}
#[test]
fn matches_skips_leaf_nodes_without_op_id() {
let mut egraph: EGraph<Toy> = EGraph::new();
let _ = egraph.add(Toy::Lit(7));
let _ = egraph.add(Toy::Lit(8));
let toml = "schema = 1\n[[equivalence]]\nleft = \"7\"\nright = \"8\"\n";
let rules = TomlEquivalenceRules::from_toml_str("test", toml).unwrap();
assert!(rules.matches(&egraph).is_empty());
}
#[test]
fn rule_name_forwards_constructor_name() {
let rules: TomlEquivalenceRules = TomlEquivalenceRules::new("algebra_v1");
let r: &dyn Rule<Toy> = &rules;
assert_eq!(r.name(), "algebra_v1");
}
}