egraph_serialize/
lib.rs

1#[cfg(feature = "graphviz")]
2mod graphviz;
3
4mod algorithms;
5
6use std::collections::HashMap;
7use std::sync::Arc;
8
9use indexmap::{map::Entry, IndexMap};
10use once_cell::sync::OnceCell;
11use ordered_float::NotNan;
12
13pub type Cost = NotNan<f64>;
14
/// Identifier of an e-node; the key type of [`EGraph::nodes`].
///
/// Backed by an `Arc<str>`, so cloning an id shares the underlying string
/// instead of copying it.
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct NodeId(Arc<str>);
18
/// Identifier of an e-class, as referenced by [`Node::eclass`].
///
/// Backed by an `Arc<str>`, so cloning an id shares the underlying string
/// instead of copying it.
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct ClassId(Arc<str>);
22
23mod id_impls {
24    use super::*;
25
26    impl AsRef<str> for NodeId {
27        fn as_ref(&self) -> &str {
28            &self.0
29        }
30    }
31
32    impl<S: Into<String>> From<S> for NodeId {
33        fn from(s: S) -> Self {
34            Self(s.into().into())
35        }
36    }
37
38    impl std::fmt::Display for NodeId {
39        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
40            write!(f, "{}", self.0)
41        }
42    }
43
44    impl AsRef<str> for ClassId {
45        fn as_ref(&self) -> &str {
46            &self.0
47        }
48    }
49
50    impl<S: Into<String>> From<S> for ClassId {
51        fn from(s: S) -> Self {
52            Self(s.into().into())
53        }
54    }
55
56    impl std::fmt::Display for ClassId {
57        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
58            write!(f, "{}", self.0)
59        }
60    }
61}
62
63#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
64#[derive(Debug, Default, Clone, PartialEq, Eq)]
65pub struct EGraph {
66    pub nodes: IndexMap<NodeId, Node>,
67    #[cfg_attr(feature = "serde", serde(default))]
68    pub root_eclasses: Vec<ClassId>,
69    // Optional mapping of e-class ids to some additional data about the e-class
70    #[cfg_attr(feature = "serde", serde(default))]
71    pub class_data: IndexMap<ClassId, ClassData>,
72    #[cfg_attr(feature = "serde", serde(skip))]
73    once_cell_classes: OnceCell<IndexMap<ClassId, Class>>,
74}
75
76impl EGraph {
77    /// Adds a new node to the egraph
78    ///
79    /// Panics if a node with the same id already exists
80    pub fn add_node(&mut self, node_id: impl Into<NodeId>, node: Node) {
81        match self.nodes.entry(node_id.into()) {
82            Entry::Occupied(e) => {
83                panic!(
84                    "Duplicate node with id {key:?}\nold: {old:?}\nnew: {new:?}",
85                    key = e.key(),
86                    old = e.get(),
87                    new = node
88                )
89            }
90            Entry::Vacant(e) => e.insert(node),
91        };
92    }
93
94    pub fn nid_to_cid(&self, node_id: &NodeId) -> &ClassId {
95        &self[node_id].eclass
96    }
97
98    pub fn nid_to_class(&self, node_id: &NodeId) -> &Class {
99        &self[&self[node_id].eclass]
100    }
101
102    /// Groups the nodes in the e-graph by their e-class
103    ///
104    /// This is *only done once* and then the result is cached.
105    /// Modifications to the e-graph will not be reflected
106    /// in later calls to this function.
107    pub fn classes(&self) -> &IndexMap<ClassId, Class> {
108        self.once_cell_classes.get_or_init(|| {
109            let mut classes = IndexMap::new();
110            for (node_id, node) in &self.nodes {
111                classes
112                    .entry(node.eclass.clone())
113                    .or_insert_with(|| Class {
114                        id: node.eclass.clone(),
115                        nodes: vec![],
116                    })
117                    .nodes
118                    .push(node_id.clone())
119            }
120            classes
121        })
122    }
123
124    #[cfg(feature = "serde")]
125    pub fn from_json_file(path: impl AsRef<std::path::Path>) -> std::io::Result<Self> {
126        let file = std::fs::File::open(path)?;
127        let egraph: Self = serde_json::from_reader(std::io::BufReader::new(file))?;
128        Ok(egraph)
129    }
130
131    #[cfg(feature = "serde")]
132    pub fn to_json_file(&self, path: impl AsRef<std::path::Path>) -> std::io::Result<()> {
133        let file = std::fs::File::create(path)?;
134        serde_json::to_writer_pretty(std::io::BufWriter::new(file), self)?;
135        Ok(())
136    }
137
138    #[cfg(feature = "serde")]
139    pub fn test_round_trip(&self) {
140        let json = serde_json::to_string_pretty(&self).unwrap();
141        let egraph2: EGraph = serde_json::from_str(&json).unwrap();
142        assert_eq!(self, &egraph2);
143    }
144}
145
146impl std::ops::Index<&NodeId> for EGraph {
147    type Output = Node;
148
149    fn index(&self, index: &NodeId) -> &Self::Output {
150        self.nodes
151            .get(index)
152            .unwrap_or_else(|| panic!("No node with id {index:?}"))
153    }
154}
155
156impl std::ops::Index<&ClassId> for EGraph {
157    type Output = Class;
158
159    fn index(&self, index: &ClassId) -> &Self::Output {
160        self.classes()
161            .get(index)
162            .unwrap_or_else(|| panic!("No class with id {index:?}"))
163    }
164}
165
166#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
167#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
168pub struct Node {
169    pub op: String,
170    #[cfg_attr(feature = "serde", serde(default))]
171    pub children: Vec<NodeId>,
172    pub eclass: ClassId,
173    #[cfg_attr(feature = "serde", serde(default = "one"))]
174    pub cost: Cost,
175    #[cfg_attr(feature = "serde", serde(default))]
176    pub subsumed: bool,
177}
178
179impl Node {
180    pub fn is_leaf(&self) -> bool {
181        self.children.is_empty()
182    }
183}
184
185fn one() -> Cost {
186    Cost::new(1.0).unwrap()
187}
188
189#[derive(Debug, Clone, PartialEq, Eq)]
190pub struct Class {
191    pub id: ClassId,
192    pub nodes: Vec<NodeId>,
193}
194
/// Optional extra information about an e-class, stored in [`EGraph::class_data`].
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ClassData {
    /// The e-class's type, serialized under the key `"type"`.
    #[cfg_attr(feature = "serde", serde(rename = "type"))]
    pub typ: Option<String>,

    /// All other string-valued entries, flattened into the same serialized object.
    #[cfg_attr(feature = "serde", serde(flatten))]
    pub extra: HashMap<String, String>,
}