1use crate::errors::CResultMany;
2use crate::hlr::hlr_data::{FuncRep, VariableInfo, ArgIndex, VarID, GotoLabelID};
3use crate::{Type, VarName};
4
5use super::{ExprID, ExprTree, HNodeData, NodeDataGen, SetGen};
6use super::{ExprNode, HNodeData::*};
7
8impl ExprTree {
9 pub fn iter_mut<'a>(
10 &'a mut self,
11 ) -> Box<dyn Iterator<Item = (ExprID, &mut HNodeData)> + 'a> {
12 Box::new(self.nodes.iter_mut().map(|(id, node)| (id, &mut node.data)))
13 }
14
15 pub fn iter<'a>(&'a self) -> Box<dyn Iterator<Item = (ExprID, &HNodeData)> + 'a> {
16 Box::new(self.nodes.iter().map(|(id, node)| (id, &node.data)))
17 }
18
19 pub fn ids_in_order(&self) -> Vec<ExprID> {
20 self.ids_of(self.root)
21 }
22
23 pub fn ids_of(&self, id: ExprID) -> Vec<ExprID> {
24 use HNodeData::*;
25 let rest = match self.get(id) {
26 Number { .. }
27 | Float { .. }
28 | Bool { .. }
29 | GlobalLoad { .. }
30 | Ident { .. }
31 | AccessAlias(_)
32 | Goto(_)
33 | GotoLabel(_) => Vec::new(),
34 StructLit { fields, .. } => fields
35 .iter()
36 .flat_map(|(_, id)| self.ids_of(*id))
37 .collect(),
38 ArrayLit { parts: many, .. }
39 | Call { a: many, .. }
40 | Block { stmts: many, .. } => {
41 many.iter().flat_map(|id| self.ids_of(*id)).collect()
42 },
43 IndirectCall {
44 f: one, a: many, ..
45 } => many
46 .iter()
47 .flat_map(|id| self.ids_of(*id))
48 .chain(self.ids_of(one).drain(..))
49 .collect(),
50 BinOp { lhs: l, rhs: r, .. }
51 | Set { lhs: l, rhs: r, .. }
52 | While { w: l, d: r, .. }
53 | Index {
54 object: l,
55 index: r,
56 ..
57 } => self.ids_of(l)
58 .drain(..)
59 .chain(self.ids_of(r).drain(..))
60 .collect(),
61 UnarOp { hs: one, .. }
62 | Transform { hs: one, .. }
63 | Member { object: one, .. } => self.ids_of(one),
64 IfThenElse { i, t, e, .. } => self.ids_of(i)
65 .drain(..)
66 .chain(self.ids_of(t).drain(..))
67 .chain(self.ids_of(e).drain(..))
68 .collect(),
69 Return { to_return, .. } => {
70 if let Some(to_return) = to_return {
71 self.ids_of(to_return)
72 } else {
73 Vec::new()
74 }
75 },
76 };
77
78 [id].into_iter().chain(rest.into_iter()).collect()
79 }
80
81 pub fn ids_unordered(&self) -> slotmap::basic::Keys<ExprID, ExprNode> {
82 self.nodes.keys()
83 }
84
85 pub fn insert(&mut self, parent: ExprID, data: HNodeData) -> ExprID {
86 self.nodes.insert(ExprNode { parent, data })
87 }
88
89 pub fn replace(&mut self, at: ExprID, with: HNodeData) {
90 self.nodes.get_mut(at).unwrap().data = with;
91 }
92
93 pub fn make_one_space(&mut self, parent: ExprID) -> ExprID {
94 self.nodes.insert(ExprNode {
95 parent,
96 data: HNodeData::Number {
97 value: 0,
98 lit_type: Type::i(32),
99 },
100 })
101 }
102
103 pub fn get(&self, at: ExprID) -> HNodeData { self.nodes[at].data.clone() }
104
105 pub fn get_ref(&self, at: ExprID) -> &HNodeData { &self.nodes[at].data }
106
107 pub fn get_mut(&mut self, at: ExprID) -> &mut HNodeData { &mut self.nodes[at].data }
108
109 pub fn parent(&self, of: ExprID) -> ExprID { self.nodes[of].parent }
110
111 pub fn statement_and_block(&self, of: ExprID) -> (ExprID, ExprID) {
112 if of == self.root {
113 return (of, of);
114 }
115
116 let parent = self.parent(of);
117
118 if matches!(self.get_ref(parent), Block { .. }) {
119 (of, parent)
120 } else {
121 self.statement_and_block(parent)
122 }
123 }
124
125 pub fn block_of(&self, of: ExprID) -> ExprID {
126 if matches!(self.get_ref(of), HNodeData::Block { .. }) {
127 return of;
128 }
129
130 self.statement_and_block(of).1
131 }
132
133 pub fn count(&self) -> usize { self.nodes.len() }
134
135 pub fn remove_node(&mut self, remove_id: ExprID) -> Result<(), ()> {
136 let parent = self.get_mut(self.parent(remove_id));
137 match parent {
138 Block { stmts, .. } => {
139 let old_len = stmts.len();
140 stmts.retain(|id| *id != remove_id);
141 assert_eq!(stmts.len(), old_len - 1);
142 }
143 _ => return Err(()),
144 };
145 self.nodes.remove(remove_id);
146 Ok(())
147 }
148
149 pub fn prune(&mut self) {
150 let ids = self.ids_in_order();
151 self.nodes.retain(|key, _| ids.contains(&key));
152 }
153}
154
155impl<'a> FuncRep<'a> {
156 pub fn modify_many(
157 &mut self,
158 mut modifier: impl FnMut(ExprID, &mut HNodeData, &mut FuncRep) -> CResultMany<()>,
159 ) -> CResultMany<()> {
160 self.modify_many_inner(self.tree.ids_in_order().drain(..), |a, b, c| { modifier(a, b, c) })
161 }
162
163 pub fn modify_many_rev(
164 &mut self,
165 mut modifier: impl FnMut(ExprID, &mut HNodeData, &mut FuncRep) -> CResultMany<()>,
166 ) -> CResultMany<()> {
167 self.modify_many_inner(self.tree.ids_in_order().drain(..).rev(), |a, b, c| { modifier(a, b, c) })
168 }
169
170 pub fn modify_many_infallible(
171 &mut self,
172 mut modifier: impl FnMut(ExprID, &mut HNodeData, &mut FuncRep),
173 ) {
174 self.modify_many(|a, b, c| { modifier(a, b, c); Ok(()) }).unwrap();
175 }
176
177 pub fn modify_many_infallible_rev(
178 &mut self,
179 mut modifier: impl FnMut(ExprID, &mut HNodeData, &mut FuncRep),
180 ) {
181 self.modify_many_rev(|a, b, c| { modifier(a, b, c); Ok(()) }).unwrap();
182 }
183
184 fn modify_many_inner(
185 &mut self,
186 id_iterator: impl Iterator<Item = ExprID>,
187 mut modifier: impl FnMut(ExprID, &mut HNodeData, &mut FuncRep) -> CResultMany<()>,
188 ) -> CResultMany<()> {
189 for id in id_iterator {
190 let mut data_copy = self.tree.get(id);
191
192 modifier(id, &mut data_copy, self)?;
193
194 if self.tree.nodes.contains_key(id) {
196 self.tree.replace(id, data_copy);
197 }
198 }
199
200 Ok(())
201 }
202
203 pub fn insert_statement_before<'ptr>(
204 &'ptr mut self,
205 statement_origin: ExprID,
206 new_data: impl NodeDataGen,
207 ) -> InsertionData<'ptr, 'a> {
208 let new_statement = self.insert_statement_inner(statement_origin, new_data, 0);
209 InsertionData(self, new_statement)
210 }
211
212 pub fn insert_statement_after<'ptr>(
213 &'ptr mut self,
214 statement_origin: ExprID,
215 new_data: impl NodeDataGen,
216 ) -> InsertionData<'ptr, 'a> {
217 let new_statement = self.insert_statement_inner(statement_origin, new_data, 1);
218 InsertionData(self, new_statement)
219 }
220
221 fn insert_statement_inner(
222 &mut self,
223 statement_origin: ExprID,
224 new_data: impl NodeDataGen,
225 offset: usize,
226 ) -> ExprID {
227 let (statement, block) = self.tree.statement_and_block(statement_origin);
228
229 let new = self.insert_quick(block, new_data);
230
231 let HNodeData::Block { ref mut stmts, .. } = self.tree.get_mut(block)
232 else { unreachable!() };
233
234 let block_pos = stmts.iter().position(|s_id| *s_id == statement).unwrap() + offset;
235 stmts.insert(block_pos, new);
236
237 new
238 }
239
240 pub fn add_variable(&mut self, typ: &Type) -> VarID {
241 let var = self.variables.insert(VariableInfo {
242 typ: typ.clone(),
243 arg_index: ArgIndex::None,
244 ..Default::default()
245 });
246
247 let HNodeData::Block { ref mut declared, .. } = self.tree.get_mut(self.tree.root)
248 else { unreachable!() };
249 declared.insert(var);
250
251 var
252 }
253
254 pub fn add_goto_label(&mut self, name: VarName, use_block_of: ExprID) -> GotoLabelID {
255 let label = self.goto_labels.insert(use_block_of);
256
257 let block = if matches!(self.tree.get_ref(use_block_of), HNodeData::Block { .. }) {
258 use_block_of
259 } else {
260 self.tree.statement_and_block(use_block_of).1
261 };
262
263 let HNodeData::Block { ref mut goto_labels, .. } = self.tree.get_mut(block)
264 else { unreachable!() };
265 goto_labels.insert(name, label);
266
267 label
268 }
269
270 pub fn separate_expression(&mut self, expression: ExprID) -> ExprID {
271 let expr_data = self.tree.get(expression);
272
273 use HNodeData::*;
274 if matches!(expr_data, Ident { .. } | Number { .. } | Float { .. } | Bool { .. }) {
275 return expression;
276 }
277
278 let new_var = self.add_variable(&expr_data.ret_type());
279
280 self.insert_statement_before(expression, SetGen {
281 lhs: new_var,
282 rhs: expr_data,
283 });
284 self.replace_quick(expression, new_var);
285 self.insert_quick(expression, new_var)
286 }
287}
288
289pub struct InsertionData<'ptr, 'a>(&'ptr mut FuncRep<'a>, ExprID);
290
291impl<'ptr, 'a> InsertionData<'ptr, 'a> {
292 pub fn after_that(self, new_data: impl NodeDataGen) -> Self {
293 self.0.insert_statement_after(self.1, new_data)
294 }
295
296 pub fn inserted_id(self) -> ExprID { self.1 }
297}