fidget_core/vm/
data.rs

1//! General-purpose tapes for use during evaluation or further compilation
2use crate::{
3    Error,
4    compiler::{RegOp, RegTape, RegisterAllocator, SsaOp, SsaTape},
5    context::{Context, Node},
6    var::VarMap,
7    vm::Choice,
8};
9use serde::{Deserialize, Serialize};
10use std::sync::Arc;
11
12/// A flattened math expression, ready for evaluation or further compilation.
13///
14/// Under the hood, [`VmData`] stores two different representations:
15/// - A tape in [single static assignment form](https://en.wikipedia.org/wiki/Static_single-assignment_form)
16///   ([`SsaTape`]), which is suitable for use during tape simplification
17/// - A tape in register-allocated form ([`RegTape`]), which can be efficiently
18///   evaluated or lowered into machine assembly
19///
20/// # Example
21/// Consider the expression `x + y`.  The SSA tape will look something like
22/// this:
23/// ```text
24/// $0 = INPUT 0   // X
25/// $1 = INPUT 1   // Y
26/// $2 = ADD $0 $1 // (X + Y)
27/// ```
28///
29/// This will be lowered into a tape using real (or VM) registers:
30/// ```text
31/// r0 = INPUT 0 // X
32/// r1 = INPUT 1 // Y
33/// r0 = ADD r0 r1 // (X + Y)
34/// ```
35///
36/// Note that in this form, registers are reused (e.g. `r0` stores both `X` and
37/// `X + Y`).
38///
39/// We can peek at the internals and see this register-allocated tape:
40/// ```
41/// use fidget_core::{
42///     compiler::RegOp,
43///     context::{Context, Tree},
44///     vm::VmData,
45///     var::Var,
46/// };
47///
48/// let tree = Tree::x() + Tree::y();
49/// let mut ctx = Context::new();
50/// let sum = ctx.import(&tree);
51/// let data = VmData::<255>::new(&ctx, &[sum])?;
52/// assert_eq!(data.len(), 4); // X, Y, (X + Y), and output
53///
54/// let mut iter = data.iter_asm();
55/// let vars = &data.vars; // map from var to index
56/// assert_eq!(iter.next().unwrap(), RegOp::Input(0, vars[&Var::X] as u32));
57/// assert_eq!(iter.next().unwrap(), RegOp::Input(1, vars[&Var::Y] as u32));
58/// assert_eq!(iter.next().unwrap(), RegOp::AddRegReg(0, 0, 1));
59/// # Ok::<(), fidget_core::Error>(())
60/// ```
61///
62/// Despite this peek at its internals, users are unlikely to touch `VmData`
63/// directly; a [`VmShape`](crate::vm::VmShape) wraps the `VmData` and
64/// implements our common traits.
65#[derive(Default, Serialize, Deserialize)]
66pub struct VmData<const N: usize = { u8::MAX as usize }> {
67    ssa: SsaTape,
68    asm: RegTape,
69
70    /// Mapping from variables to indices during evaluation
71    ///
72    /// This member is stored in a shared pointer because it's passed down to
73    /// children (constructed with [`VmData::simplify`]).
74    pub vars: Arc<VarMap>,
75}
76
77impl<const N: usize> VmData<N> {
78    /// Builds a new tape for the given node
79    pub fn new(context: &Context, nodes: &[Node]) -> Result<Self, Error> {
80        let (ssa, vars) = SsaTape::new(context, nodes)?;
81        let asm = RegTape::new::<N>(&ssa);
82        Ok(Self {
83            ssa,
84            asm,
85            vars: vars.into(),
86        })
87    }
88
89    /// Returns the length of the internal VM tape
90    pub fn len(&self) -> usize {
91        self.asm.len()
92    }
93
94    /// Returns true if the internal VM tape is empty
95    pub fn is_empty(&self) -> bool {
96        self.asm.is_empty()
97    }
98
99    /// Returns the number of choice (min/max) nodes in the tape.
100    ///
101    /// This is required because some evaluators pre-allocate spaces for the
102    /// choice array.
103    pub fn choice_count(&self) -> usize {
104        self.ssa.choice_count
105    }
106
107    /// Returns the number of output nodes in the tape.
108    ///
109    /// This is required because some evaluators pre-allocate spaces for the
110    /// output array.
111    pub fn output_count(&self) -> usize {
112        self.ssa.output_count
113    }
114
115    /// Returns the number of slots used by the inner VM tape
116    pub fn slot_count(&self) -> usize {
117        self.asm.slot_count()
118    }
119
120    /// Simplifies both inner tapes, using the provided choice array
121    ///
122    /// To minimize allocations, this function takes a [`VmWorkspace`] and
123    /// spare [`VmData`]; it will reuse those allocations.
124    pub fn simplify<const M: usize>(
125        &self,
126        choices: &[Choice],
127        workspace: &mut VmWorkspace<M>,
128        mut tape: VmData<M>,
129    ) -> Result<VmData<M>, Error> {
130        if choices.len() != self.choice_count() {
131            return Err(Error::BadChoiceSlice(
132                choices.len(),
133                self.choice_count(),
134            ));
135        }
136        tape.ssa.reset();
137
138        // Steal `tape.asm` and hand it to the workspace for use in allocator
139        workspace.reset(self.ssa.tape.len(), tape.asm);
140
141        let mut choice_count = 0;
142        let mut output_count = 0;
143
144        // Other iterators to consume various arrays in order
145        let mut choice_iter = choices.iter().rev();
146
147        let mut ops_out = tape.ssa.tape;
148
149        for mut op in self.ssa.tape.iter().cloned() {
150            let index = match &mut op {
151                SsaOp::Output(reg, _i) => {
152                    *reg = workspace.get_or_insert_active(*reg);
153                    workspace.alloc.op(op);
154                    ops_out.push(op);
155                    output_count += 1;
156                    continue;
157                }
158                _ => op.output().unwrap(),
159            };
160
161            if workspace.active(index).is_none() {
162                if op.has_choice() {
163                    choice_iter.next().unwrap();
164                }
165                continue;
166            }
167
168            // Because we reassign nodes when they're used as an *input*
169            // (while walking the tape in reverse), this node must have been
170            // assigned already.
171            let new_index = workspace.active(index).unwrap();
172
173            match &mut op {
174                SsaOp::Output(..) => unreachable!(),
175                SsaOp::Input(index, ..) | SsaOp::CopyImm(index, ..) => {
176                    *index = new_index;
177                }
178                SsaOp::NegReg(index, arg)
179                | SsaOp::AbsReg(index, arg)
180                | SsaOp::RecipReg(index, arg)
181                | SsaOp::SqrtReg(index, arg)
182                | SsaOp::SquareReg(index, arg)
183                | SsaOp::FloorReg(index, arg)
184                | SsaOp::CeilReg(index, arg)
185                | SsaOp::RoundReg(index, arg)
186                | SsaOp::SinReg(index, arg)
187                | SsaOp::CosReg(index, arg)
188                | SsaOp::TanReg(index, arg)
189                | SsaOp::AsinReg(index, arg)
190                | SsaOp::AcosReg(index, arg)
191                | SsaOp::AtanReg(index, arg)
192                | SsaOp::ExpReg(index, arg)
193                | SsaOp::LnReg(index, arg)
194                | SsaOp::NotReg(index, arg) => {
195                    *index = new_index;
196                    *arg = workspace.get_or_insert_active(*arg);
197                }
198                SsaOp::CopyReg(index, src) => {
199                    // CopyReg effectively does
200                    //      dst <= src
201                    // If src has not yet been used (as we iterate backwards
202                    // through the tape), then we can replace it with dst
203                    // everywhere!
204                    match workspace.active(*src) {
205                        Some(new_src) => {
206                            *index = new_index;
207                            *src = new_src;
208                        }
209                        None => {
210                            workspace.set_active(*src, new_index);
211                            continue;
212                        }
213                    }
214                }
215                SsaOp::MinRegImm(index, arg, imm)
216                | SsaOp::MaxRegImm(index, arg, imm)
217                | SsaOp::AndRegImm(index, arg, imm)
218                | SsaOp::OrRegImm(index, arg, imm) => {
219                    match choice_iter.next().unwrap() {
220                        Choice::Left => match workspace.active(*arg) {
221                            Some(new_arg) => {
222                                op = SsaOp::CopyReg(new_index, new_arg);
223                            }
224                            None => {
225                                workspace.set_active(*arg, new_index);
226                                continue;
227                            }
228                        },
229                        Choice::Right => {
230                            op = SsaOp::CopyImm(new_index, *imm);
231                        }
232                        Choice::Both => {
233                            choice_count += 1;
234                            *index = new_index;
235                            *arg = workspace.get_or_insert_active(*arg);
236                        }
237                        Choice::Unknown => panic!("oh no"),
238                    }
239                }
240                SsaOp::MinRegReg(index, lhs, rhs)
241                | SsaOp::MaxRegReg(index, lhs, rhs)
242                | SsaOp::AndRegReg(index, lhs, rhs)
243                | SsaOp::OrRegReg(index, lhs, rhs) => {
244                    match choice_iter.next().unwrap() {
245                        Choice::Left => match workspace.active(*lhs) {
246                            Some(new_lhs) => {
247                                op = SsaOp::CopyReg(new_index, new_lhs);
248                            }
249                            None => {
250                                workspace.set_active(*lhs, new_index);
251                                continue;
252                            }
253                        },
254                        Choice::Right => match workspace.active(*rhs) {
255                            Some(new_rhs) => {
256                                op = SsaOp::CopyReg(new_index, new_rhs);
257                            }
258                            None => {
259                                workspace.set_active(*rhs, new_index);
260                                continue;
261                            }
262                        },
263                        Choice::Both => {
264                            choice_count += 1;
265                            *index = new_index;
266                            *lhs = workspace.get_or_insert_active(*lhs);
267                            *rhs = workspace.get_or_insert_active(*rhs);
268                        }
269                        Choice::Unknown => panic!("oh no"),
270                    }
271                }
272                SsaOp::AddRegReg(index, lhs, rhs)
273                | SsaOp::MulRegReg(index, lhs, rhs)
274                | SsaOp::SubRegReg(index, lhs, rhs)
275                | SsaOp::DivRegReg(index, lhs, rhs)
276                | SsaOp::AtanRegReg(index, lhs, rhs)
277                | SsaOp::CompareRegReg(index, lhs, rhs)
278                | SsaOp::ModRegReg(index, lhs, rhs) => {
279                    *index = new_index;
280                    *lhs = workspace.get_or_insert_active(*lhs);
281                    *rhs = workspace.get_or_insert_active(*rhs);
282                }
283                SsaOp::AddRegImm(index, arg, _imm)
284                | SsaOp::MulRegImm(index, arg, _imm)
285                | SsaOp::SubRegImm(index, arg, _imm)
286                | SsaOp::SubImmReg(index, arg, _imm)
287                | SsaOp::DivRegImm(index, arg, _imm)
288                | SsaOp::DivImmReg(index, arg, _imm)
289                | SsaOp::AtanImmReg(index, arg, _imm)
290                | SsaOp::AtanRegImm(index, arg, _imm)
291                | SsaOp::CompareRegImm(index, arg, _imm)
292                | SsaOp::CompareImmReg(index, arg, _imm)
293                | SsaOp::ModRegImm(index, arg, _imm)
294                | SsaOp::ModImmReg(index, arg, _imm) => {
295                    *index = new_index;
296                    *arg = workspace.get_or_insert_active(*arg);
297                }
298            }
299            workspace.alloc.op(op);
300            ops_out.push(op);
301        }
302
303        assert_eq!(workspace.count as usize + 1, ops_out.len());
304        let asm_tape = workspace.alloc.finalize();
305
306        Ok(VmData {
307            ssa: SsaTape {
308                tape: ops_out,
309                choice_count,
310                output_count,
311            },
312            asm: asm_tape,
313            vars: self.vars.clone(),
314        })
315    }
316
317    /// Produces an iterator that visits [`RegOp`] values in evaluation order
318    pub fn iter_asm(&self) -> impl Iterator<Item = RegOp> + '_ {
319        self.asm.iter().cloned().rev()
320    }
321
322    /// Pretty-prints the inner SSA tape
323    pub fn pretty_print(&self) {
324        self.ssa.pretty_print();
325        for a in self.iter_asm() {
326            println!("{a:?}");
327        }
328    }
329}
330
331////////////////////////////////////////////////////////////////////////////////
332
333/// Data structures used during [`VmData::simplify`]
334///
335/// This is exposed to minimize reallocations in hot loops.
336pub struct VmWorkspace<const N: usize> {
337    /// Register allocator
338    pub(crate) alloc: RegisterAllocator<N>,
339
340    /// Current bindings from SSA variables to registers
341    pub(crate) bind: Vec<u32>,
342
343    /// Number of active SSA bindings
344    ///
345    /// This value is monotonically increasing; each SSA variable gets the next
346    /// value if it is unassigned when encountered.
347    count: u32,
348}
349
350impl<const N: usize> Default for VmWorkspace<N> {
351    fn default() -> Self {
352        Self {
353            alloc: RegisterAllocator::empty(),
354            bind: vec![],
355            count: 0,
356        }
357    }
358}
359
360impl<const N: usize> VmWorkspace<N> {
361    fn active(&self, i: u32) -> Option<u32> {
362        if self.bind[i as usize] != u32::MAX {
363            Some(self.bind[i as usize])
364        } else {
365            None
366        }
367    }
368
369    fn get_or_insert_active(&mut self, i: u32) -> u32 {
370        if self.bind[i as usize] == u32::MAX {
371            self.bind[i as usize] = self.count;
372            self.count += 1;
373        }
374        self.bind[i as usize]
375    }
376
377    fn set_active(&mut self, i: u32, bind: u32) {
378        self.bind[i as usize] = bind;
379    }
380
381    /// Resets the workspace, preserving allocations and claiming the given
382    /// [`RegTape`].
383    pub fn reset(&mut self, tape_len: usize, tape: RegTape) {
384        self.alloc.reset(tape_len, tape);
385        self.bind.fill(u32::MAX);
386        self.bind.resize(tape_len, u32::MAX);
387        self.count = 0;
388    }
389}
390
#[cfg(test)]
mod test {
    use super::*;

    /// Simplifying into a tape with a different `N` must rebuild the
    /// register-allocated tape for the new parameter, changing its length.
    #[test]
    fn simplify_reg_count_change() {
        let mut ctx = Context::new();
        let (x, y, z) = (ctx.x(), ctx.y(), ctx.z());
        let xy = ctx.add(x, y).unwrap();
        let xyz = ctx.add(xy, z).unwrap();

        // Going from 3 down to 2
        let wide = VmData::<3>::new(&ctx, &[xyz]).unwrap();
        assert_eq!(wide.len(), 6); // 3x input, 2x add, 1x output
        let mut workspace = VmWorkspace::<2>::default();
        let narrowed = wide
            .simplify::<2>(&[], &mut workspace, VmData::<2>::default())
            .unwrap();
        assert_eq!(narrowed.len(), 8); // extra load + store

        // Going from 2 up to 3
        let narrow = VmData::<2>::new(&ctx, &[xyz]).unwrap();
        assert_eq!(narrow.len(), 8);
        let mut workspace = VmWorkspace::<3>::default();
        let widened = narrow
            .simplify::<3>(&[], &mut workspace, VmData::<3>::default())
            .unwrap();
        assert_eq!(widened.len(), 6);
    }
}
418}