1#[cfg(feature = "graphviz")]
2mod graphviz;
3
4mod algorithms;
5
6use std::collections::HashMap;
7use std::sync::Arc;
8
9use indexmap::{map::Entry, IndexMap};
10use once_cell::sync::OnceCell;
11use ordered_float::NotNan;
12
/// Cost attached to each [`Node`].
///
/// `NotNan` excludes NaN, which gives the type a total order; `Node` derives
/// `Ord`, and that derive requires `Cost: Ord`.
pub type Cost = NotNan<f64>;
14
/// Identifier of an e-node.
///
/// The id is stored as an `Arc<str>`, so cloning it only bumps a reference
/// count instead of copying the string.
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct NodeId(Arc<str>);
18
/// Identifier of an e-class.
///
/// The id is stored as an `Arc<str>`, so cloning it only bumps a reference
/// count instead of copying the string.
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct ClassId(Arc<str>);
22
23mod id_impls {
24 use super::*;
25
26 impl AsRef<str> for NodeId {
27 fn as_ref(&self) -> &str {
28 &self.0
29 }
30 }
31
32 impl<S: Into<String>> From<S> for NodeId {
33 fn from(s: S) -> Self {
34 Self(s.into().into())
35 }
36 }
37
38 impl std::fmt::Display for NodeId {
39 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
40 write!(f, "{}", self.0)
41 }
42 }
43
44 impl AsRef<str> for ClassId {
45 fn as_ref(&self) -> &str {
46 &self.0
47 }
48 }
49
50 impl<S: Into<String>> From<S> for ClassId {
51 fn from(s: S) -> Self {
52 Self(s.into().into())
53 }
54 }
55
56 impl std::fmt::Display for ClassId {
57 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
58 write!(f, "{}", self.0)
59 }
60 }
61}
62
63#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
64#[derive(Debug, Default, Clone, PartialEq, Eq)]
65pub struct EGraph {
66 pub nodes: IndexMap<NodeId, Node>,
67 #[cfg_attr(feature = "serde", serde(default))]
68 pub root_eclasses: Vec<ClassId>,
69 #[cfg_attr(feature = "serde", serde(default))]
71 pub class_data: IndexMap<ClassId, ClassData>,
72 #[cfg_attr(feature = "serde", serde(skip))]
73 once_cell_classes: OnceCell<IndexMap<ClassId, Class>>,
74}
75
76impl EGraph {
77 pub fn add_node(&mut self, node_id: impl Into<NodeId>, node: Node) {
81 match self.nodes.entry(node_id.into()) {
82 Entry::Occupied(e) => {
83 panic!(
84 "Duplicate node with id {key:?}\nold: {old:?}\nnew: {new:?}",
85 key = e.key(),
86 old = e.get(),
87 new = node
88 )
89 }
90 Entry::Vacant(e) => e.insert(node),
91 };
92 }
93
94 pub fn nid_to_cid(&self, node_id: &NodeId) -> &ClassId {
95 &self[node_id].eclass
96 }
97
98 pub fn nid_to_class(&self, node_id: &NodeId) -> &Class {
99 &self[&self[node_id].eclass]
100 }
101
102 pub fn classes(&self) -> &IndexMap<ClassId, Class> {
108 self.once_cell_classes.get_or_init(|| {
109 let mut classes = IndexMap::new();
110 for (node_id, node) in &self.nodes {
111 classes
112 .entry(node.eclass.clone())
113 .or_insert_with(|| Class {
114 id: node.eclass.clone(),
115 nodes: vec![],
116 })
117 .nodes
118 .push(node_id.clone())
119 }
120 classes
121 })
122 }
123
124 #[cfg(feature = "serde")]
125 pub fn from_json_file(path: impl AsRef<std::path::Path>) -> std::io::Result<Self> {
126 let file = std::fs::File::open(path)?;
127 let egraph: Self = serde_json::from_reader(std::io::BufReader::new(file))?;
128 Ok(egraph)
129 }
130
131 #[cfg(feature = "serde")]
132 pub fn to_json_file(&self, path: impl AsRef<std::path::Path>) -> std::io::Result<()> {
133 let file = std::fs::File::create(path)?;
134 serde_json::to_writer_pretty(std::io::BufWriter::new(file), self)?;
135 Ok(())
136 }
137
138 #[cfg(feature = "serde")]
139 pub fn test_round_trip(&self) {
140 let json = serde_json::to_string_pretty(&self).unwrap();
141 let egraph2: EGraph = serde_json::from_str(&json).unwrap();
142 assert_eq!(self, &egraph2);
143 }
144}
145
146impl std::ops::Index<&NodeId> for EGraph {
147 type Output = Node;
148
149 fn index(&self, index: &NodeId) -> &Self::Output {
150 self.nodes
151 .get(index)
152 .unwrap_or_else(|| panic!("No node with id {index:?}"))
153 }
154}
155
156impl std::ops::Index<&ClassId> for EGraph {
157 type Output = Class;
158
159 fn index(&self, index: &ClassId) -> &Self::Output {
160 self.classes()
161 .get(index)
162 .unwrap_or_else(|| panic!("No class with id {index:?}"))
163 }
164}
165
/// A single e-node: an operator applied to child nodes, belonging to one e-class.
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct Node {
    /// Operator label of this node.
    pub op: String,
    /// Ids of the child *nodes*, not classes. Defaults to empty when absent
    /// during deserialization.
    #[cfg_attr(feature = "serde", serde(default))]
    pub children: Vec<NodeId>,
    /// The e-class this node belongs to.
    pub eclass: ClassId,
    /// Cost of this node. Defaults to 1.0 (via `one`) when absent during
    /// deserialization.
    #[cfg_attr(feature = "serde", serde(default = "one"))]
    pub cost: Cost,
    /// Subsumed flag. Defaults to `false` when absent during deserialization.
    /// NOTE(review): its meaning is defined by consumers of this type, which
    /// are not visible here.
    #[cfg_attr(feature = "serde", serde(default))]
    pub subsumed: bool,
}
178
179impl Node {
180 pub fn is_leaf(&self) -> bool {
181 self.children.is_empty()
182 }
183}
184
185fn one() -> Cost {
186 Cost::new(1.0).unwrap()
187}
188
/// An e-class as produced by `EGraph::classes`: its id and the ids of its
/// member nodes.
///
/// This type has no serde derive; it is computed from `EGraph::nodes` rather
/// than stored.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Class {
    /// Id of this e-class.
    pub id: ClassId,
    /// Ids of the nodes whose `eclass` is `id`.
    pub nodes: Vec<NodeId>,
}
194
/// Optional metadata attached to an e-class (see `EGraph::class_data`).
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ClassData {
    /// Type name of the class, serialized under the JSON key `"type"`.
    #[cfg_attr(feature = "serde", serde(rename = "type"))]
    pub typ: Option<String>,

    /// All other fields of the class-data object. `serde(flatten)` gathers
    /// them here, and each value must deserialize as a string.
    #[cfg_attr(feature = "serde", serde(flatten))]
    pub extra: HashMap<String, String>,
}