1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
use crate::{util::HashMap, *};
use core_relations::BaseValuePrinter;
use ordered_float::NotNan;
use std::collections::VecDeque;
/// Options controlling how much of the e-graph `EGraph::serialize` emits.
pub struct SerializeConfig {
/// Maximum number of functions to include in the serialized graph; any after this will be discarded.
/// `None` means no limit.
pub max_functions: Option<usize>,
/// Maximum number of calls to include per function; any after this will be discarded.
/// `None` means no limit.
pub max_calls_per_function: Option<usize>,
/// Whether to include temporary functions in the serialized graph.
// NOTE(review): `serialize` as visible in this file never reads this flag — confirm where it is consumed.
pub include_temporary_functions: bool,
/// Root eclasses to include in the output, given as (sort, value) pairs.
pub root_eclasses: Vec<(ArcSort, Value)>,
}
/// Output of serializing an e-graph, including the names of functions that were omitted, if any.
pub struct SerializeOutput {
/// The serialized e-graph.
pub egraph: egraph_serialize::EGraph,
/// Names of functions with more calls than `max_calls_per_function`, so that not all of their calls are included.
pub truncated_functions: Vec<String>,
/// Names of functions that were discarded from the output because more functions were present than `max_functions`.
pub discarded_functions: Vec<String>,
}
impl SerializeOutput {
    /// Reports whether nothing was left out: no function was truncated and none was discarded.
    pub fn is_complete(&self) -> bool {
        [&self.truncated_functions, &self.discarded_functions]
            .iter()
            .all(|names| names.is_empty())
    }

    /// Builds a human-readable summary of what was left out of the e-graph.
    ///
    /// Discarded functions are listed first on an `Omitted:` line, followed by truncated
    /// functions on a `Truncated:` line. A line is only present when its list is non-empty,
    /// so the result is the empty string for a complete serialization.
    pub fn omitted_description(&self) -> String {
        let discarded = if self.discarded_functions.is_empty() {
            String::new()
        } else {
            format!(
                "Omitted: {}\n",
                self.discarded_functions.join(", ")
            )
        };
        let truncated = if self.truncated_functions.is_empty() {
            String::new()
        } else {
            format!(
                "Truncated: {}\n",
                self.truncated_functions.join(", ")
            )
        };
        discarded + &truncated
    }
}
/// Mutable state threaded through `serialize` and `serialize_value` while the output graph is built.
// NOTE(review): the reason for `allow(dead_code)` is not visible in this file — confirm it is still needed.
#[allow(dead_code)]
struct Serializer {
/// Node IDs recorded for each e-class; `serialize_value` rotates each queue to choose the node an edge points to.
node_ids: NodeIDs,
/// The serialized e-graph under construction.
result: egraph_serialize::EGraph,
/// Names of let-bound functions whose output falls in each e-class; `serialize_value` stores them in the
/// class data under the "let" key.
let_bindings: HashMap<egraph_serialize::ClassId, Vec<String>>,
}
/// Default is used for exporting JSON and will output all nodes.
impl Default for SerializeConfig {
fn default() -> Self {
SerializeConfig {
max_functions: None,
max_calls_per_function: None,
include_temporary_functions: false,
root_eclasses: vec![],
}
}
}
/// A node in the serialized egraph.
///
/// Values of this type are converted to and from serialized node IDs by
/// `EGraph::to_node_id` and `EGraph::from_node_id`.
#[derive(PartialEq, Debug, Clone)]
pub enum SerializedNode {
/// A user defined function call.
Function {
/// The name of the function.
name: String,
/// The offset of the index in the table; `serialize` assigns it from a per-function row counter.
/// This can be resolved to the output and input values with table.get_index(offset, true).
// NOTE(review): `table.get_index` is not visible in this file — confirm this lookup still exists.
offset: usize,
},
/// A primitive value.
Primitive(Value),
/// A dummy node used to represent omitted nodes; carries the value of the e-class it stands in for.
Dummy(Value),
/// A node that was split into multiple e-classes; wraps the node it was split from.
Split(Box<SerializedNode>),
}
impl SerializedNode {
    /// Reports whether this node is a primitive value, looking through any number of
    /// `Split` wrappers to the node they contain.
    pub fn is_primitive(&self) -> bool {
        // Peel off every `Split` layer, then classify whatever is left.
        let mut innermost = self;
        while let SerializedNode::Split(wrapped) = innermost {
            innermost = &**wrapped;
        }
        matches!(innermost, SerializedNode::Primitive(_))
    }
}
impl EGraph {
/// Serialize the egraph into a format that can be read by the egraph-serialize crate.
///
/// There are multiple different semantically valid ways to do this. This is how this implementation does it:
///
/// For node costs:
/// - Primitives: 1.0
/// - Function without costs: 1.0
/// - Function with costs: the cost
/// - Omitted nodes: infinite
///
/// For node IDs:
/// - Functions: Function name + hash of input values
/// - Args which are eq sorts: Choose one ID from the e-class, distribute roughly evenly.
/// - Args and outputs values which are primitives: Sort name + hash of value
///
/// For e-classes IDs:
/// - tag and value of canonicalized value
///
/// This is to achieve the following properties:
/// - Equivalent primitive values will show up once in the e-graph.
/// - Functions which return primitive values will be added to the e-class of that value.
/// - Nodes will have consistant IDs throughout execution of e-graph (used for animating changes in the visualization)
/// - Edges in the visualization will be well distributed (used for animating changes in the visualization)
/// (Note that this will be changed in `<https://github.com/egraphs-good/egglog/pull/158>` so that edges point to exact nodes instead of looking up the e-class)
pub fn serialize(&self, config: SerializeConfig) -> SerializeOutput {
let mut truncated_functions = Vec::new();
let mut discarded_functions = Vec::new();
let max_calls_per_function = config.max_calls_per_function.unwrap_or(usize::MAX);
let max_functions = config.max_functions.unwrap_or(usize::MAX);
let mut all_calls: Vec<(
&Function,
Vec<Value>, // inputs
Value, // output
bool, // is subsumed
egraph_serialize::ClassId,
egraph_serialize::NodeId,
)> = Vec::new();
let mut functions_kept = 0usize;
let mut let_bindings = HashMap::default();
for (name, function) in self.functions.iter() {
if functions_kept >= max_functions {
discarded_functions.push(name.clone());
continue;
}
let mut rows = 0;
self.backend.for_each_while(function.backend_id, |row| {
if rows >= max_calls_per_function {
truncated_functions.push(name.clone());
return false;
}
let (out, inps) = row.vals.split_last().unwrap();
let class_id = self.value_to_class_id(&function.schema.output, *out);
if function.decl.let_binding {
let_bindings
.entry(class_id.clone())
.or_insert_with(Vec::new)
.push(name.clone());
} else {
all_calls.push((
function,
inps.to_vec(),
*out,
row.subsumed,
class_id,
self.to_node_id(
None,
SerializedNode::Function {
name: name.clone(),
offset: rows,
},
),
));
rows += 1;
}
true
});
if rows != 0 {
functions_kept += 1;
}
}
// Then create a mapping from each canonical e-class ID to the set of node IDs in that e-class
// Note that this is only for e-classes, primitives have e-classes equal to their node ID
// This is for when we need to find what node ID to use for an edge to an e-class, we can rotate them evenly
// amoung all possible options.
let node_ids: NodeIDs = all_calls.iter().fold(
HashMap::default(),
|mut acc, (func, _input, _output, _subsumed, class_id, node_id)| {
if func.schema.output.is_eq_sort() {
acc.entry(class_id.clone())
.or_default()
.push_back(node_id.clone());
}
acc
},
);
let mut serializer = Serializer {
node_ids,
result: egraph_serialize::EGraph::default(),
let_bindings,
};
for (func, input, output, subsumed, class_id, node_id) in all_calls {
self.serialize_value(&mut serializer, &func.schema.output, output, &class_id);
assert_eq!(input.len(), func.schema.input.len());
let children: Vec<_> = input
.iter()
.zip(&func.schema.input)
.map(|(&v, sort)| {
self.serialize_value(&mut serializer, sort, v, &self.value_to_class_id(sort, v))
})
.collect();
serializer.result.nodes.insert(
node_id,
egraph_serialize::Node {
op: func.decl.name.to_string(),
eclass: class_id.clone(),
cost: NotNan::new(func.decl.cost.unwrap_or(1) as f64).unwrap(),
children,
subsumed,
},
);
}
serializer.result.root_eclasses = config
.root_eclasses
.iter()
.map(|(sort, v)| self.value_to_class_id(sort, *v))
.collect();
SerializeOutput {
egraph: serializer.result,
truncated_functions,
discarded_functions,
}
}
/// Gets the serialized class ID for a value, in the form `{sort name}-{canonical value}`.
///
/// # Panics
/// Panics if the sort name contains `'-'`, since that character separates the two parts.
pub fn value_to_class_id(&self, sort: &ArcSort, value: Value) -> egraph_serialize::ClassId {
    use numeric_id::NumericId;

    // Resolve the value to its canonical representative first, so that every member of an
    // e-class maps to the same ID.
    let column_ty = sort.column_ty(&self.backend);
    let canonical = self.backend.get_canon_repr(value, column_ty);

    let tag = sort.name().to_string();
    assert!(
        !tag.contains('-'),
        "Tag cannot contain '-' when serializing"
    );
    format!("{}-{}", tag, canonical.rep()).into()
}
/// Gets the value for a serialized class ID by parsing the part after the first `'-'`.
///
/// # Panics
/// Panics if the ID contains no `'-'` or if the part after it does not parse as a value.
pub fn class_id_to_value(&self, eclass_id: &egraph_serialize::ClassId) -> Value {
    let rendered = eclass_id.to_string();
    // The leading tag is the sort name; only the trailing bits identify the value.
    let bits = rendered.split_once('-').map(|(_tag, bits)| bits).unwrap();
    Value::new_const(bits.parse().unwrap())
}
/// Gets the serialized node ID for the primitive, omitted, or function value.
///
/// The ID is a tag followed by a payload:
/// - `function-{offset}-{name}` for function calls (`sort` must be `None`)
/// - `primitive-{e-class ID}` for primitive values (`sort` is required)
/// - `dummy-{e-class ID}` for omitted nodes (`sort` is required)
/// - `split-{node ID}` for split nodes, where the payload is the ID of the wrapped node
///
/// # Panics
/// Panics if `sort` is `Some` for a function node, or `None` for a primitive or dummy node.
pub fn to_node_id(
    &self,
    sort: Option<&ArcSort>,
    node: SerializedNode,
) -> egraph_serialize::NodeId {
    let rendered = match node {
        SerializedNode::Function { name, offset } => {
            assert!(sort.is_none());
            format!("function-{}-{}", offset, name)
        }
        SerializedNode::Primitive(value) => {
            let class_id = self.value_to_class_id(sort.unwrap(), value);
            format!("primitive-{}", class_id)
        }
        SerializedNode::Dummy(value) => {
            let class_id = self.value_to_class_id(sort.unwrap(), value);
            format!("dummy-{}", class_id)
        }
        SerializedNode::Split(inner) => {
            let inner_id = self.to_node_id(sort, *inner);
            format!("split-{}", inner_id)
        }
    };
    rendered.into()
}
/// Gets the serialized node for the node ID.
///
/// Recognized forms, keyed by the tag before the first `'-'`:
/// - `function-{offset}-{name}`
/// - `primitive-{e-class ID}`
/// - `dummy-{e-class ID}`
/// - `split-{index}-{node ID}`, and also `split-{node ID}` as produced by `to_node_id`
///
/// # Panics
/// Panics if the ID has no `'-'`, if its tag is unknown, or if a numeric part fails to parse.
pub fn from_node_id(&self, node_id: &egraph_serialize::NodeId) -> SerializedNode {
    let node_id = node_id.to_string();
    let (tag, rest) = node_id.split_once('-').unwrap();
    match tag {
        "function" => {
            let (offset, name) = rest.split_once('-').unwrap();
            SerializedNode::Function {
                name: name.into(),
                offset: offset.parse().unwrap(),
            }
        }
        "primitive" => {
            let class_id: egraph_serialize::ClassId = rest.into();
            SerializedNode::Primitive(self.class_id_to_value(&class_id))
        }
        "dummy" => {
            let class_id: egraph_serialize::ClassId = rest.into();
            SerializedNode::Dummy(self.class_id_to_value(&class_id))
        }
        "split" => {
            // Split IDs may carry a numeric index before the wrapped node ID, but `to_node_id`
            // emits `split-{node ID}` with no index. Skip the first segment only when it is
            // numeric; node ID tags are never numeric, so the two forms cannot be confused.
            // NOTE(review): the indexed form is presumably generated by egraph-serialize when
            // splitting classes — confirm against that crate.
            let inner = match rest.split_once('-') {
                Some((index, tail))
                    if !index.is_empty() && index.bytes().all(|b| b.is_ascii_digit()) =>
                {
                    tail
                }
                _ => rest,
            };
            let node_id: egraph_serialize::NodeId = inner.into();
            SerializedNode::Split(Box::new(self.from_node_id(&node_id)))
        }
        _ => std::panic::panic_any(format!("Unknown node ID: {}-{}", tag, rest)),
    }
}
/// Serialize the value and return the node ID.
///
/// - If this is an eq-sort value, no node is normally added: the returned ID is taken from the node IDs
///   recorded for `class_id`, rotating through them so that successive calls return different nodes of
///   the class. If no node IDs are recorded, a dummy node is added to stand in for the omitted nodes.
/// - Otherwise (a primitive or container value), a node for the value is added to the data, after its
///   inner values have been serialized recursively.
///
/// In both cases the class data for `class_id` is written with the sort name and any let-bound names
/// recorded for that class.
///
/// When this is called on the output of a node, we only use the e-class to know which e-class it is a part of.
/// When this is called on an input of a node, we only use the node ID to know which node to point to.
fn serialize_value(
&self,
serializer: &mut Serializer,
sort: &ArcSort,
value: Value,
class_id: &egraph_serialize::ClassId,
) -> egraph_serialize::NodeId {
let node_id = if sort.is_eq_sort() {
let node_ids = serializer
.node_ids
.entry(class_id.clone())
.or_insert_with(|| {
// If we don't find node IDs for this class, it means that all nodes for it were omitted due to size constraints.
// In this case, add a dummy node in this class to represent the missing nodes.
let node_id = self.to_node_id(Some(sort), SerializedNode::Dummy(value));
serializer.result.nodes.insert(
node_id.clone(),
egraph_serialize::Node {
op: "[...]".to_string(),
eclass: class_id.clone(),
cost: NotNan::new(f64::INFINITY).unwrap(),
children: vec![],
subsumed: false,
},
);
VecDeque::from(vec![node_id])
});
// Rotate before reading the front so that repeated lookups cycle through the nodes of the class.
node_ids.rotate_left(1);
node_ids.front().unwrap().clone()
} else {
let node_id = self.to_node_id(Some(sort), SerializedNode::Primitive(value));
// Add node for value
{
let container_values = self.backend.container_values();
// Children will be empty unless this is a container sort.
// Inner values are serialized (and their nodes inserted) before this value's own node.
let children: Vec<egraph_serialize::NodeId> = sort
.inner_values(container_values, value)
.into_iter()
.map(|(s, v)| {
self.serialize_value(serializer, &s, v, &self.value_to_class_id(&s, v))
})
.collect();
// If this is a container sort, use the name, otherwise use the Debug rendering of the value
let op = if sort.is_container_sort() {
sort.serialized_name(container_values, value)
} else {
let primitive_id = self
.backend
.base_values()
.get_ty_by_id(sort.value_type().unwrap());
let formatted_val = BaseValuePrinter {
base: self.backend.base_values(),
ty: primitive_id,
val: value,
};
format!("{:?}", formatted_val)
};
serializer.result.nodes.insert(
node_id.clone(),
egraph_serialize::Node {
op,
eclass: class_id.clone(),
cost: NotNan::new(1.0).unwrap(),
children,
subsumed: false,
},
);
};
node_id
};
// NOTE(review): presumably `ClassData::extra` requires the std `HashMap`, hence the lint exemption — confirm.
#[allow(clippy::disallowed_types)]
let mut extra = std::collections::HashMap::default();
// Record the names of any let bindings that resolve to this e-class under the "let" key.
if let Some(let_bindings) = serializer.let_bindings.get(class_id) {
if !let_bindings.is_empty() {
extra.insert("let".to_string(), let_bindings.join(", "));
}
}
// Written on every call, so a later call for the same class replaces the earlier entry.
serializer.result.class_data.insert(
class_id.clone(),
egraph_serialize::ClassData {
typ: Some(sort.name().to_string()),
extra,
},
);
node_id
}
}
/// Maps each e-class ID to the queue of node IDs recorded for that e-class.
type NodeIDs = HashMap<egraph_serialize::ClassId, VecDeque<egraph_serialize::NodeId>>;