fidget_core/vm/data.rs
1//! General-purpose tapes for use during evaluation or further compilation
2use crate::{
3 Error,
4 compiler::{RegOp, RegTape, RegisterAllocator, SsaOp, SsaTape},
5 context::{Context, Node},
6 var::VarMap,
7 vm::Choice,
8};
9use serde::{Deserialize, Serialize};
10use std::sync::Arc;
11
/// A flattened math expression, ready for evaluation or further compilation.
///
/// Under the hood, [`VmData`] stores two different representations:
/// - A tape in [static single-assignment form](https://en.wikipedia.org/wiki/Static_single-assignment_form)
///   ([`SsaTape`]), which is suitable for use during tape simplification
/// - A tape in register-allocated form ([`RegTape`]), which can be efficiently
///   evaluated or lowered into machine assembly
///
/// # Example
/// Consider the expression `x + y`.  The SSA tape will look something like
/// this:
/// ```text
/// $0 = INPUT 0   // X
/// $1 = INPUT 1   // Y
/// $2 = ADD $0 $1 // (X + Y)
/// ```
///
/// This will be lowered into a tape using real (or VM) registers:
/// ```text
/// r0 = INPUT 0   // X
/// r1 = INPUT 1   // Y
/// r0 = ADD r0 r1 // (X + Y)
/// ```
///
/// Note that in this form, registers are reused (e.g. `r0` stores both `X` and
/// `X + Y`).
///
/// We can peek at the internals and see this register-allocated tape:
/// ```
/// use fidget_core::{
///     compiler::RegOp,
///     context::{Context, Tree},
///     vm::VmData,
///     var::Var,
/// };
///
/// let tree = Tree::x() + Tree::y();
/// let mut ctx = Context::new();
/// let sum = ctx.import(&tree);
/// let data = VmData::<255>::new(&ctx, &[sum])?;
/// assert_eq!(data.len(), 4); // X, Y, (X + Y), and output
///
/// let mut iter = data.iter_asm();
/// let vars = &data.vars; // map from var to index
/// assert_eq!(iter.next().unwrap(), RegOp::Input(0, vars[&Var::X] as u32));
/// assert_eq!(iter.next().unwrap(), RegOp::Input(1, vars[&Var::Y] as u32));
/// assert_eq!(iter.next().unwrap(), RegOp::AddRegReg(0, 0, 1));
/// # Ok::<(), fidget_core::Error>(())
/// ```
///
/// Despite this peek at its internals, users are unlikely to touch `VmData`
/// directly; a [`VmShape`](crate::vm::VmShape) wraps the `VmData` and
/// implements our common traits.
#[derive(Default, Serialize, Deserialize)]
pub struct VmData<const N: usize = { u8::MAX as usize }> {
    /// Tape in SSA form
    ///
    /// This is the representation that [`VmData::simplify`] walks; it also
    /// carries the choice and output counts reported by
    /// [`VmData::choice_count`] and [`VmData::output_count`].
    ssa: SsaTape,

    /// Tape in register-allocated form
    ///
    /// This is the representation reported by [`VmData::len`],
    /// [`VmData::is_empty`], and [`VmData::slot_count`], and visited by
    /// [`VmData::iter_asm`].
    asm: RegTape,

    /// Mapping from variables to indices during evaluation
    ///
    /// This member is stored in a shared pointer because it's passed down to
    /// children (constructed with [`VmData::simplify`]).
    pub vars: Arc<VarMap>,
}
76
77impl<const N: usize> VmData<N> {
78 /// Builds a new tape for the given node
79 pub fn new(context: &Context, nodes: &[Node]) -> Result<Self, Error> {
80 let (ssa, vars) = SsaTape::new(context, nodes)?;
81 let asm = RegTape::new::<N>(&ssa);
82 Ok(Self {
83 ssa,
84 asm,
85 vars: vars.into(),
86 })
87 }
88
89 /// Returns the length of the internal VM tape
90 pub fn len(&self) -> usize {
91 self.asm.len()
92 }
93
94 /// Returns true if the internal VM tape is empty
95 pub fn is_empty(&self) -> bool {
96 self.asm.is_empty()
97 }
98
99 /// Returns the number of choice (min/max) nodes in the tape.
100 ///
101 /// This is required because some evaluators pre-allocate spaces for the
102 /// choice array.
103 pub fn choice_count(&self) -> usize {
104 self.ssa.choice_count
105 }
106
107 /// Returns the number of output nodes in the tape.
108 ///
109 /// This is required because some evaluators pre-allocate spaces for the
110 /// output array.
111 pub fn output_count(&self) -> usize {
112 self.ssa.output_count
113 }
114
115 /// Returns the number of slots used by the inner VM tape
116 pub fn slot_count(&self) -> usize {
117 self.asm.slot_count()
118 }
119
120 /// Simplifies both inner tapes, using the provided choice array
121 ///
122 /// To minimize allocations, this function takes a [`VmWorkspace`] and
123 /// spare [`VmData`]; it will reuse those allocations.
124 pub fn simplify<const M: usize>(
125 &self,
126 choices: &[Choice],
127 workspace: &mut VmWorkspace<M>,
128 mut tape: VmData<M>,
129 ) -> Result<VmData<M>, Error> {
130 if choices.len() != self.choice_count() {
131 return Err(Error::BadChoiceSlice(
132 choices.len(),
133 self.choice_count(),
134 ));
135 }
136 tape.ssa.reset();
137
138 // Steal `tape.asm` and hand it to the workspace for use in allocator
139 workspace.reset(self.ssa.tape.len(), tape.asm);
140
141 let mut choice_count = 0;
142 let mut output_count = 0;
143
144 // Other iterators to consume various arrays in order
145 let mut choice_iter = choices.iter().rev();
146
147 let mut ops_out = tape.ssa.tape;
148
149 for mut op in self.ssa.tape.iter().cloned() {
150 let index = match &mut op {
151 SsaOp::Output(reg, _i) => {
152 *reg = workspace.get_or_insert_active(*reg);
153 workspace.alloc.op(op);
154 ops_out.push(op);
155 output_count += 1;
156 continue;
157 }
158 _ => op.output().unwrap(),
159 };
160
161 if workspace.active(index).is_none() {
162 if op.has_choice() {
163 choice_iter.next().unwrap();
164 }
165 continue;
166 }
167
168 // Because we reassign nodes when they're used as an *input*
169 // (while walking the tape in reverse), this node must have been
170 // assigned already.
171 let new_index = workspace.active(index).unwrap();
172
173 match &mut op {
174 SsaOp::Output(..) => unreachable!(),
175 SsaOp::Input(index, ..) | SsaOp::CopyImm(index, ..) => {
176 *index = new_index;
177 }
178 SsaOp::NegReg(index, arg)
179 | SsaOp::AbsReg(index, arg)
180 | SsaOp::RecipReg(index, arg)
181 | SsaOp::SqrtReg(index, arg)
182 | SsaOp::SquareReg(index, arg)
183 | SsaOp::FloorReg(index, arg)
184 | SsaOp::CeilReg(index, arg)
185 | SsaOp::RoundReg(index, arg)
186 | SsaOp::SinReg(index, arg)
187 | SsaOp::CosReg(index, arg)
188 | SsaOp::TanReg(index, arg)
189 | SsaOp::AsinReg(index, arg)
190 | SsaOp::AcosReg(index, arg)
191 | SsaOp::AtanReg(index, arg)
192 | SsaOp::ExpReg(index, arg)
193 | SsaOp::LnReg(index, arg)
194 | SsaOp::NotReg(index, arg) => {
195 *index = new_index;
196 *arg = workspace.get_or_insert_active(*arg);
197 }
198 SsaOp::CopyReg(index, src) => {
199 // CopyReg effectively does
200 // dst <= src
201 // If src has not yet been used (as we iterate backwards
202 // through the tape), then we can replace it with dst
203 // everywhere!
204 match workspace.active(*src) {
205 Some(new_src) => {
206 *index = new_index;
207 *src = new_src;
208 }
209 None => {
210 workspace.set_active(*src, new_index);
211 continue;
212 }
213 }
214 }
215 SsaOp::MinRegImm(index, arg, imm)
216 | SsaOp::MaxRegImm(index, arg, imm)
217 | SsaOp::AndRegImm(index, arg, imm)
218 | SsaOp::OrRegImm(index, arg, imm) => {
219 match choice_iter.next().unwrap() {
220 Choice::Left => match workspace.active(*arg) {
221 Some(new_arg) => {
222 op = SsaOp::CopyReg(new_index, new_arg);
223 }
224 None => {
225 workspace.set_active(*arg, new_index);
226 continue;
227 }
228 },
229 Choice::Right => {
230 op = SsaOp::CopyImm(new_index, *imm);
231 }
232 Choice::Both => {
233 choice_count += 1;
234 *index = new_index;
235 *arg = workspace.get_or_insert_active(*arg);
236 }
237 Choice::Unknown => panic!("oh no"),
238 }
239 }
240 SsaOp::MinRegReg(index, lhs, rhs)
241 | SsaOp::MaxRegReg(index, lhs, rhs)
242 | SsaOp::AndRegReg(index, lhs, rhs)
243 | SsaOp::OrRegReg(index, lhs, rhs) => {
244 match choice_iter.next().unwrap() {
245 Choice::Left => match workspace.active(*lhs) {
246 Some(new_lhs) => {
247 op = SsaOp::CopyReg(new_index, new_lhs);
248 }
249 None => {
250 workspace.set_active(*lhs, new_index);
251 continue;
252 }
253 },
254 Choice::Right => match workspace.active(*rhs) {
255 Some(new_rhs) => {
256 op = SsaOp::CopyReg(new_index, new_rhs);
257 }
258 None => {
259 workspace.set_active(*rhs, new_index);
260 continue;
261 }
262 },
263 Choice::Both => {
264 choice_count += 1;
265 *index = new_index;
266 *lhs = workspace.get_or_insert_active(*lhs);
267 *rhs = workspace.get_or_insert_active(*rhs);
268 }
269 Choice::Unknown => panic!("oh no"),
270 }
271 }
272 SsaOp::AddRegReg(index, lhs, rhs)
273 | SsaOp::MulRegReg(index, lhs, rhs)
274 | SsaOp::SubRegReg(index, lhs, rhs)
275 | SsaOp::DivRegReg(index, lhs, rhs)
276 | SsaOp::AtanRegReg(index, lhs, rhs)
277 | SsaOp::CompareRegReg(index, lhs, rhs)
278 | SsaOp::ModRegReg(index, lhs, rhs) => {
279 *index = new_index;
280 *lhs = workspace.get_or_insert_active(*lhs);
281 *rhs = workspace.get_or_insert_active(*rhs);
282 }
283 SsaOp::AddRegImm(index, arg, _imm)
284 | SsaOp::MulRegImm(index, arg, _imm)
285 | SsaOp::SubRegImm(index, arg, _imm)
286 | SsaOp::SubImmReg(index, arg, _imm)
287 | SsaOp::DivRegImm(index, arg, _imm)
288 | SsaOp::DivImmReg(index, arg, _imm)
289 | SsaOp::AtanImmReg(index, arg, _imm)
290 | SsaOp::AtanRegImm(index, arg, _imm)
291 | SsaOp::CompareRegImm(index, arg, _imm)
292 | SsaOp::CompareImmReg(index, arg, _imm)
293 | SsaOp::ModRegImm(index, arg, _imm)
294 | SsaOp::ModImmReg(index, arg, _imm) => {
295 *index = new_index;
296 *arg = workspace.get_or_insert_active(*arg);
297 }
298 }
299 workspace.alloc.op(op);
300 ops_out.push(op);
301 }
302
303 assert_eq!(workspace.count as usize + 1, ops_out.len());
304 let asm_tape = workspace.alloc.finalize();
305
306 Ok(VmData {
307 ssa: SsaTape {
308 tape: ops_out,
309 choice_count,
310 output_count,
311 },
312 asm: asm_tape,
313 vars: self.vars.clone(),
314 })
315 }
316
317 /// Produces an iterator that visits [`RegOp`] values in evaluation order
318 pub fn iter_asm(&self) -> impl Iterator<Item = RegOp> + '_ {
319 self.asm.iter().cloned().rev()
320 }
321
322 /// Pretty-prints the inner SSA tape
323 pub fn pretty_print(&self) {
324 self.ssa.pretty_print();
325 for a in self.iter_asm() {
326 println!("{a:?}");
327 }
328 }
329}
330
331////////////////////////////////////////////////////////////////////////////////
332
/// Data structures used during [`VmData::simplify`]
///
/// This is exposed to minimize reallocations in hot loops.
pub struct VmWorkspace<const N: usize> {
    /// Register allocator
    ///
    /// Receives every operation kept by [`VmData::simplify`] and is finalized
    /// into the new register-allocated tape.
    pub(crate) alloc: RegisterAllocator<N>,

    /// Current bindings from SSA variables to registers
    ///
    /// Indexed by the SSA variable of the tape being simplified; an entry of
    /// `u32::MAX` marks a variable that has no binding.
    pub(crate) bind: Vec<u32>,

    /// Number of active SSA bindings
    ///
    /// This value is monotonically increasing; each SSA variable gets the next
    /// value if it is unassigned when encountered.  It is set back to zero
    /// when the workspace is reset.
    count: u32,
}
349
350impl<const N: usize> Default for VmWorkspace<N> {
351 fn default() -> Self {
352 Self {
353 alloc: RegisterAllocator::empty(),
354 bind: vec![],
355 count: 0,
356 }
357 }
358}
359
360impl<const N: usize> VmWorkspace<N> {
361 fn active(&self, i: u32) -> Option<u32> {
362 if self.bind[i as usize] != u32::MAX {
363 Some(self.bind[i as usize])
364 } else {
365 None
366 }
367 }
368
369 fn get_or_insert_active(&mut self, i: u32) -> u32 {
370 if self.bind[i as usize] == u32::MAX {
371 self.bind[i as usize] = self.count;
372 self.count += 1;
373 }
374 self.bind[i as usize]
375 }
376
377 fn set_active(&mut self, i: u32, bind: u32) {
378 self.bind[i as usize] = bind;
379 }
380
381 /// Resets the workspace, preserving allocations and claiming the given
382 /// [`RegTape`].
383 pub fn reset(&mut self, tape_len: usize, tape: RegTape) {
384 self.alloc.reset(tape_len, tape);
385 self.bind.fill(u32::MAX);
386 self.bind.resize(tape_len, u32::MAX);
387 self.count = 0;
388 }
389}
390
#[cfg(test)]
mod test {
    use super::*;

    /// Simplifying into a tape with a different register count must redo
    /// register allocation, which changes the length of the resulting tape.
    #[test]
    fn simplify_reg_count_change() {
        let mut ctx = Context::new();
        let x = ctx.x();
        let y = ctx.y();
        let z = ctx.z();
        let xy = ctx.add(x, y).unwrap();
        let xyz = ctx.add(xy, z).unwrap();
        let roots = [xyz];

        // Three registers -> two registers
        let wide = VmData::<3>::new(&ctx, &roots).unwrap();
        assert_eq!(wide.len(), 6); // 3x input, 2x add, 1x output
        let mut narrow_workspace = VmWorkspace::<2>::default();
        let narrowed = wide
            .simplify::<2>(&[], &mut narrow_workspace, VmData::<2>::default())
            .unwrap();
        assert_eq!(narrowed.len(), 8); // extra load + store

        // Two registers -> three registers
        let narrow = VmData::<2>::new(&ctx, &roots).unwrap();
        assert_eq!(narrow.len(), 8);
        let mut wide_workspace = VmWorkspace::<3>::default();
        let widened = narrow
            .simplify::<3>(&[], &mut wide_workspace, VmData::<3>::default())
            .unwrap();
        assert_eq!(widened.len(), 6);
    }
}