1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
use std::cell::RefCell;
use std::collections::hash_map::Entry;
use std::collections::{BinaryHeap, HashMap, HashSet};
use std::ops::Deref;
use std::rc::Rc;
use petgraph::graph::NodeIndex;
use crate::expressions::Expression;
use crate::graph_traversal::complexity_rec;
use crate::{
graph::{Graph, RelType, Relationship},
optimization_profiles::OptimizationProfile,
};
/// Debug information yielded during a derivation operation.
///
/// `Default` yields an empty record, which is the natural starting point before handing it to a
/// [`Deriver`] wrapped in `Rc<RefCell<_>>`.
#[derive(Clone, Debug, Default)]
pub struct DerivationDebugInfo {
    /// How many times each rule was applied and returned an expression, keyed by rule name.
    pub rule_uses: HashMap<String, u32>,
}
/// Object used to expand equivalence graphs. Derefs to the graph it holds.
pub struct Deriver {
    /// The equivalence graph being expanded.
    graph: Graph,
    /// Profile queried via `find_equivalents` for expressions equivalent to a node.
    optimizer: Box<dyn OptimizationProfile>,
    /// Debug record handed to the optimizer at construction; not read directly by this type.
    _debug_info: Option<Rc<RefCell<DerivationDebugInfo>>>,
    /// Rule names supplied at construction.
    // NOTE(review): stored but never read by `Deriver` — confirm whether it should be
    // forwarded to the optimizer.
    _allowed_rules: Option<Vec<String>>,
    /// Lookup from an expression to its node in `graph`, used to reuse the node when an
    /// already-seen expression is derived again.
    node_indices: HashMap<Expression, NodeIndex>,
    /// Max-heap of `(priority, node)` pairs waiting to be expanded by `expand_increment`.
    // Tuples are compared lexicographically, so the sort discriminant is first.
    derivation_queue: BinaryHeap<(u32, NodeIndex)>,
}
impl Deriver {
/// * `graph` - A graph containing the initial expressions to derive from.
/// * `optimizer` - An optimization profile used in the derivation process.
/// * `allowed_rules` - If not specifed, all rules are applied as the optimization profile
/// finds appropriate. If specified and the optimization profile supports it, only the rules
/// specified by name are used.
/// * `debug_info` - If supplied debug information about the derivation process is recorded in
/// this object.
pub fn new(
graph: Graph,
optimizer: Box<dyn OptimizationProfile>,
allowed_rules: Option<Vec<String>>,
debug_info: Option<Rc<RefCell<DerivationDebugInfo>>>,
) -> Self {
let mut optimizer = optimizer;
if let Some(debug) = debug_info.clone() {
let _ = optimizer.set_debug(debug);
}
let mut derivation_queue = BinaryHeap::<(u32, NodeIndex)>::new();
derivation_queue.extend(graph.node_indices().map(|n| (u32::MAX, n)));
Self {
graph,
optimizer,
_debug_info: debug_info,
_allowed_rules: allowed_rules,
node_indices: HashMap::new(),
derivation_queue,
}
}
/// Expands the graph with derivation rules until one of the given constraints is met.
///
/// * `depth` - The maximum number of derivation steps from the original expressions in the
/// graph to take.
/// * `max_derivations` - The maximum number of equivalent expressions to derive. This number
/// is approximate and may be slightly exceeded before the process is stopped.
pub fn expand_to_constraint(&mut self, depth: u32, max_derivations: u32) {
for i in self.graph.node_indices() {
let node = self.graph.node_weight(i).unwrap();
self.node_indices.insert(node.clone(), i);
}
for _ in 0..depth {
if self.graph.node_count() as u32 >= max_derivations {
return;
}
for i in self.graph.node_indices() {
let expression = self.graph.node_weight(i).unwrap().clone();
let equivalents = self.optimizer.find_equivalents(&expression);
for (derived, argument) in equivalents.0 {
let index = if let Entry::Vacant(e) = self.node_indices.entry(derived.clone()) {
let result = self.graph.add_node(derived.clone());
e.insert(result);
result
} else {
self.node_indices[&derived]
};
match self.graph.find_edge(i, index) {
Some(edge_id) => {
self.graph
.edge_weight_mut(edge_id)
.unwrap()
.derived_from
.insert(argument);
}
None => {
let mut new_edge = Relationship {
r_type: RelType::Equal,
derived_from: HashSet::new(),
};
new_edge.derived_from.insert(argument);
self.graph.add_edge(i, index, new_edge);
}
}
}
}
}
}
/// Expands the graph with derivation rules.
///
/// * `max_new_derivations` - The maximum number of new expressions to derive during this call.
/// May be slightly exceeded.
///
/// Returns true if there are no further derivations to make. Any further call to this function would yield no result.
pub fn expand_increment(&mut self, max_new_derivations: u32) -> bool {
let mut derivations = 0;
while let Some((_, curr_index)) = self.derivation_queue.pop() {
let expression = self.graph.node_weight(curr_index).unwrap().clone();
let equivalents = self.optimizer.find_equivalents(&expression);
for (derived, argument) in equivalents.0 {
let index = if let Entry::Vacant(e) = self.node_indices.entry(derived.clone()) {
let result = self.graph.add_node(derived.clone());
e.insert(result);
derivations += 1;
self.derivation_queue.push((
if equivalents.1 {
u32::MAX
} else {
u32::MAX / 2
} - complexity_rec(&derived),
result,
));
result
} else {
self.node_indices[&derived]
};
match self.graph.find_edge(curr_index, index) {
Some(edge_id) => {
self.graph
.edge_weight_mut(edge_id)
.unwrap()
.derived_from
.insert(argument);
}
None => {
let mut new_edge = Relationship {
r_type: RelType::Equal,
derived_from: HashSet::new(),
};
new_edge.derived_from.insert(argument);
self.graph.add_edge(curr_index, index, new_edge);
}
}
}
if derivations >= max_new_derivations {
return false;
}
}
true
}
}
impl Deref for Deriver {
    type Target = Graph;

    /// Gives read-only access to the wrapped equivalence graph, so graph query methods
    /// (e.g. `node_weights`) can be called directly on a `Deriver`.
    fn deref(&self) -> &Graph {
        let Self { graph, .. } = self;
        graph
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        convenience_expressions::i,
        expressions::{product::product_of, sum::sum_of},
        optimization_profiles::BruteForceProfile,
    };

    /// Five derivation rounds starting from `1 + 3 + 3 + 3*3` must reach the folded value 16.
    #[test]
    fn applies_multiple_rules() {
        let initial = sum_of(&[i(1), i(3), i(3), product_of(&[i(3), i(3)])]);
        let mut graph = Graph::new();
        graph.add_node(initial);

        let mut deriver = Deriver::new(graph, BruteForceProfile::new(), None, None);
        deriver.expand_to_constraint(5, 10000);

        let target = i(16);
        let reached = deriver
            .node_weights()
            .any(|expression| expression == &target);
        assert!(reached);
    }
}