fidget_core/compiler/ssa_tape.rs

1//use crate::vm::{RegisterAllocator, Tape as VmTape};
2use crate::{
3    Context, Error,
4    compiler::SsaOp,
5    context::{BinaryOpcode, Node, Op, UnaryOpcode},
6    var::VarMap,
7};
8use serde::{Deserialize, Serialize};
9
10use std::collections::{HashMap, HashSet};
11
12/// Instruction tape, storing [opcodes in SSA form](crate::compiler::SsaOp)
13///
14/// Each operation has the following parameters
15/// - 4-byte opcode (required)
16/// - 4-byte output register (required)
17/// - 4-byte LHS register
18/// - 4-byte RHS register (or immediate `f32`)
19///
20/// All register addressing is absolute.
21#[derive(Clone, Debug, Default, Serialize, Deserialize)]
22pub struct SsaTape {
23    /// The tape is stored in reverse order, such that the root of the tree is
24    /// the first item in the tape.
25    pub tape: Vec<SsaOp>,
26
27    /// Number of choice operations in the tape
28    pub choice_count: usize,
29
30    /// Number of output operations in the tape
31    pub output_count: usize,
32}
33
34impl SsaTape {
35    /// Flattens a subtree of the graph into straight-line code.
36    ///
37    /// This should always succeed unless the `root` is from a different
38    /// `Context`, in which case `Error::BadNode` will be returned.
39    pub fn new(ctx: &Context, roots: &[Node]) -> Result<(Self, VarMap), Error> {
40        let mut mapping = HashMap::new();
41        let mut parent_count: HashMap<Node, usize> = HashMap::new();
42        let mut slot_count = 0;
43
44        // Get either a node or constant index
45        #[derive(Copy, Clone, Debug)]
46        enum Slot {
47            Reg(u32),
48            Immediate(f32),
49        }
50
51        // Accumulate parent counts and declare all nodes
52        let mut seen = HashSet::new();
53        let mut vars = VarMap::new();
54        let mut todo = roots.to_vec();
55        while let Some(node) = todo.pop() {
56            if !seen.insert(node) {
57                continue;
58            }
59            let op = ctx.get_op(node).ok_or(Error::BadNode)?;
60            let prev = match op {
61                Op::Const(c) => {
62                    mapping.insert(node, Slot::Immediate(c.0 as f32))
63                }
64                _ => {
65                    if let Op::Input(v) = op {
66                        vars.insert(*v);
67                    }
68                    let i = slot_count;
69                    slot_count += 1;
70                    mapping.insert(node, Slot::Reg(i))
71                }
72            };
73            assert!(prev.is_none());
74            for child in op.iter_children() {
75                *parent_count.entry(child).or_default() += 1;
76                todo.push(child);
77            }
78        }
79
80        // Now that we've populated our parents, flatten the graph
81        let mut seen = HashSet::new();
82        let mut todo = roots.to_vec();
83        let mut choice_count = 0;
84
85        let mut tape = vec![];
86        for (i, r) in roots.iter().enumerate() {
87            let i = i as u32;
88            match mapping[r] {
89                Slot::Reg(out_reg) => tape.push(SsaOp::Output(out_reg, i)),
90                Slot::Immediate(imm) => {
91                    let o = slot_count;
92                    slot_count += 1;
93                    tape.push(SsaOp::Output(o, i));
94                    tape.push(SsaOp::CopyImm(o, imm));
95                }
96            }
97        }
98
99        while let Some(node) = todo.pop() {
100            if *parent_count.get(&node).unwrap_or(&0) > 0 || !seen.insert(node)
101            {
102                continue;
103            }
104
105            let op = ctx.get_op(node).unwrap();
106            for child in op.iter_children() {
107                todo.push(child);
108                *parent_count.get_mut(&child).unwrap() -= 1;
109            }
110
111            let Slot::Reg(i) = mapping[&node] else {
112                // Constants are skipped, because they become immediates
113                continue;
114            };
115            let op = match op {
116                Op::Input(v) => {
117                    let arg = vars[v];
118                    SsaOp::Input(i, arg.try_into().unwrap())
119                }
120                Op::Const(..) => {
121                    unreachable!("skipped above")
122                }
123                Op::Binary(op, lhs, rhs) => {
124                    let lhs = mapping[lhs];
125                    let rhs = mapping[rhs];
126
127                    type RegFn = fn(u32, u32, u32) -> SsaOp;
128                    type ImmFn = fn(u32, u32, f32) -> SsaOp;
129                    let f: (RegFn, ImmFn, ImmFn) = match op {
130                        BinaryOpcode::Add => (
131                            SsaOp::AddRegReg,
132                            SsaOp::AddRegImm,
133                            SsaOp::AddRegImm,
134                        ),
135                        BinaryOpcode::Sub => (
136                            SsaOp::SubRegReg,
137                            SsaOp::SubRegImm,
138                            SsaOp::SubImmReg,
139                        ),
140                        BinaryOpcode::Mul => (
141                            SsaOp::MulRegReg,
142                            SsaOp::MulRegImm,
143                            SsaOp::MulRegImm,
144                        ),
145                        BinaryOpcode::Div => (
146                            SsaOp::DivRegReg,
147                            SsaOp::DivRegImm,
148                            SsaOp::DivImmReg,
149                        ),
150                        BinaryOpcode::Atan => (
151                            SsaOp::AtanRegReg,
152                            SsaOp::AtanRegImm,
153                            SsaOp::AtanImmReg,
154                        ),
155                        BinaryOpcode::Min => (
156                            SsaOp::MinRegReg,
157                            SsaOp::MinRegImm,
158                            SsaOp::MinRegImm,
159                        ),
160                        BinaryOpcode::Max => (
161                            SsaOp::MaxRegReg,
162                            SsaOp::MaxRegImm,
163                            SsaOp::MaxRegImm,
164                        ),
165                        BinaryOpcode::And => (
166                            SsaOp::AndRegReg,
167                            SsaOp::AndRegImm,
168                            |_out, _lhs, _rhs| {
169                                panic!("AndImmReg must be collapsed")
170                            },
171                        ),
172                        BinaryOpcode::Or => (
173                            SsaOp::OrRegReg,
174                            SsaOp::OrRegImm,
175                            |_out, _lhs, _rhs| {
176                                panic!("OrImmReg must be collapsed")
177                            },
178                        ),
179                        BinaryOpcode::Compare => (
180                            SsaOp::CompareRegReg,
181                            SsaOp::CompareRegImm,
182                            SsaOp::CompareImmReg,
183                        ),
184                        BinaryOpcode::Mod => (
185                            SsaOp::ModRegReg,
186                            SsaOp::ModRegImm,
187                            SsaOp::ModImmReg,
188                        ),
189                    };
190
191                    if matches!(
192                        op,
193                        BinaryOpcode::Min
194                            | BinaryOpcode::Max
195                            | BinaryOpcode::And
196                            | BinaryOpcode::Or
197                    ) {
198                        choice_count += 1;
199                    }
200
201                    match (lhs, rhs) {
202                        (Slot::Reg(lhs), Slot::Reg(rhs)) => f.0(i, lhs, rhs),
203                        (Slot::Reg(arg), Slot::Immediate(imm)) => {
204                            f.1(i, arg, imm)
205                        }
206                        (Slot::Immediate(imm), Slot::Reg(arg)) => {
207                            f.2(i, arg, imm)
208                        }
209                        (Slot::Immediate(..), Slot::Immediate(..)) => {
210                            panic!("Cannot handle f(imm, imm)")
211                        }
212                    }
213                }
214                Op::Unary(op, lhs) => {
215                    let lhs = match mapping[lhs] {
216                        Slot::Reg(r) => r,
217                        Slot::Immediate(..) => {
218                            panic!("Cannot handle f(imm)")
219                        }
220                    };
221                    let op = match op {
222                        UnaryOpcode::Neg => SsaOp::NegReg,
223                        UnaryOpcode::Abs => SsaOp::AbsReg,
224                        UnaryOpcode::Recip => SsaOp::RecipReg,
225                        UnaryOpcode::Sqrt => SsaOp::SqrtReg,
226                        UnaryOpcode::Square => SsaOp::SquareReg,
227                        UnaryOpcode::Floor => SsaOp::FloorReg,
228                        UnaryOpcode::Ceil => SsaOp::CeilReg,
229                        UnaryOpcode::Round => SsaOp::RoundReg,
230                        UnaryOpcode::Sin => SsaOp::SinReg,
231                        UnaryOpcode::Cos => SsaOp::CosReg,
232                        UnaryOpcode::Tan => SsaOp::TanReg,
233                        UnaryOpcode::Asin => SsaOp::AsinReg,
234                        UnaryOpcode::Acos => SsaOp::AcosReg,
235                        UnaryOpcode::Atan => SsaOp::AtanReg,
236                        UnaryOpcode::Exp => SsaOp::ExpReg,
237                        UnaryOpcode::Ln => SsaOp::LnReg,
238                        UnaryOpcode::Not => SsaOp::NotReg,
239                    };
240                    op(i, lhs)
241                }
242            };
243            tape.push(op);
244        }
245
246        Ok((
247            SsaTape {
248                tape,
249                choice_count,
250                output_count: roots.len(),
251            },
252            vars,
253        ))
254    }
255
256    /// Checks whether the tape is empty
257    pub fn is_empty(&self) -> bool {
258        self.tape.is_empty()
259    }
260
261    /// Returns the length of the tape
262    pub fn len(&self) -> usize {
263        self.tape.len()
264    }
265
266    /// Iterates over clauses in the tape in reverse-evaluation order
267    ///
268    /// The root (output) of the tape will be first in the iterator
269    pub fn iter(&self) -> impl DoubleEndedIterator<Item = &SsaOp> {
270        self.tape.iter()
271    }
272
273    /// Resets to an empty tape, preserving allocations
274    pub fn reset(&mut self) {
275        self.tape.clear();
276        self.choice_count = 0;
277    }
278    /// Pretty-prints the given tape to `stdout`
279    pub fn pretty_print(&self) {
280        for &op in self.tape.iter().rev() {
281            match op {
282                SsaOp::Output(arg, i) => {
283                    println!("OUTPUT[{i}] = ${arg}");
284                }
285                SsaOp::Input(out, i) => {
286                    println!("${out} = INPUT[{i}]");
287                }
288                SsaOp::NegReg(out, arg)
289                | SsaOp::AbsReg(out, arg)
290                | SsaOp::RecipReg(out, arg)
291                | SsaOp::SqrtReg(out, arg)
292                | SsaOp::CopyReg(out, arg)
293                | SsaOp::SquareReg(out, arg)
294                | SsaOp::FloorReg(out, arg)
295                | SsaOp::CeilReg(out, arg)
296                | SsaOp::RoundReg(out, arg)
297                | SsaOp::SinReg(out, arg)
298                | SsaOp::CosReg(out, arg)
299                | SsaOp::TanReg(out, arg)
300                | SsaOp::AsinReg(out, arg)
301                | SsaOp::AcosReg(out, arg)
302                | SsaOp::AtanReg(out, arg)
303                | SsaOp::ExpReg(out, arg)
304                | SsaOp::LnReg(out, arg)
305                | SsaOp::NotReg(out, arg) => {
306                    let op = match op {
307                        SsaOp::NegReg(..) => "NEG",
308                        SsaOp::AbsReg(..) => "ABS",
309                        SsaOp::RecipReg(..) => "RECIP",
310                        SsaOp::SqrtReg(..) => "SQRT",
311                        SsaOp::SquareReg(..) => "SQUARE",
312                        SsaOp::FloorReg(..) => "FLOOR",
313                        SsaOp::CeilReg(..) => "CEIL",
314                        SsaOp::RoundReg(..) => "ROUND",
315                        SsaOp::SinReg(..) => "SIN",
316                        SsaOp::CosReg(..) => "COS",
317                        SsaOp::TanReg(..) => "TAN",
318                        SsaOp::AsinReg(..) => "ASIN",
319                        SsaOp::AcosReg(..) => "ACOS",
320                        SsaOp::AtanReg(..) => "ATAN",
321                        SsaOp::ExpReg(..) => "EXP",
322                        SsaOp::LnReg(..) => "LN",
323                        SsaOp::NotReg(..) => "NOT",
324                        SsaOp::CopyReg(..) => "COPY",
325                        _ => unreachable!(),
326                    };
327                    println!("${out} = {op} ${arg}");
328                }
329
330                SsaOp::AddRegReg(out, lhs, rhs)
331                | SsaOp::MulRegReg(out, lhs, rhs)
332                | SsaOp::DivRegReg(out, lhs, rhs)
333                | SsaOp::SubRegReg(out, lhs, rhs)
334                | SsaOp::MinRegReg(out, lhs, rhs)
335                | SsaOp::MaxRegReg(out, lhs, rhs)
336                | SsaOp::ModRegReg(out, lhs, rhs)
337                | SsaOp::AndRegReg(out, lhs, rhs)
338                | SsaOp::AtanRegReg(out, lhs, rhs)
339                | SsaOp::OrRegReg(out, lhs, rhs) => {
340                    let op = match op {
341                        SsaOp::AddRegReg(..) => "ADD",
342                        SsaOp::MulRegReg(..) => "MUL",
343                        SsaOp::DivRegReg(..) => "DIV",
344                        SsaOp::AtanRegReg(..) => "ATAN",
345                        SsaOp::SubRegReg(..) => "SUB",
346                        SsaOp::MinRegReg(..) => "MIN",
347                        SsaOp::MaxRegReg(..) => "MAX",
348                        SsaOp::ModRegReg(..) => "MAX",
349                        SsaOp::AndRegReg(..) => "AND",
350                        SsaOp::OrRegReg(..) => "OR",
351                        _ => unreachable!(),
352                    };
353                    println!("${out} = {op} ${lhs} ${rhs}");
354                }
355
356                SsaOp::AddRegImm(out, arg, imm)
357                | SsaOp::MulRegImm(out, arg, imm)
358                | SsaOp::DivRegImm(out, arg, imm)
359                | SsaOp::DivImmReg(out, arg, imm)
360                | SsaOp::SubImmReg(out, arg, imm)
361                | SsaOp::SubRegImm(out, arg, imm)
362                | SsaOp::AtanRegImm(out, arg, imm)
363                | SsaOp::AtanImmReg(out, arg, imm)
364                | SsaOp::MinRegImm(out, arg, imm)
365                | SsaOp::MaxRegImm(out, arg, imm)
366                | SsaOp::ModRegImm(out, arg, imm)
367                | SsaOp::ModImmReg(out, arg, imm)
368                | SsaOp::AndRegImm(out, arg, imm)
369                | SsaOp::OrRegImm(out, arg, imm) => {
370                    let (op, swap) = match op {
371                        SsaOp::AddRegImm(..) => ("ADD", false),
372                        SsaOp::MulRegImm(..) => ("MUL", false),
373                        SsaOp::DivImmReg(..) => ("DIV", true),
374                        SsaOp::DivRegImm(..) => ("DIV", false),
375                        SsaOp::SubImmReg(..) => ("SUB", true),
376                        SsaOp::SubRegImm(..) => ("SUB", false),
377                        SsaOp::AtanImmReg(..) => ("ATAN", true),
378                        SsaOp::AtanRegImm(..) => ("ATAN", false),
379                        SsaOp::MinRegImm(..) => ("MIN", false),
380                        SsaOp::MaxRegImm(..) => ("MAX", false),
381                        SsaOp::ModRegImm(..) => ("MOD", false),
382                        SsaOp::ModImmReg(..) => ("MOD", true),
383                        SsaOp::AndRegImm(..) => ("AND", false),
384                        SsaOp::OrRegImm(..) => ("OR", false),
385                        _ => unreachable!(),
386                    };
387                    if swap {
388                        println!("${out} = {op} {imm} ${arg}");
389                    } else {
390                        println!("${out} = {op} ${arg} {imm}");
391                    }
392                }
393                SsaOp::CompareRegReg(out, lhs, rhs) => {
394                    println!("${out} = COMPARE {lhs} {rhs}")
395                }
396                SsaOp::CompareRegImm(out, arg, imm) => {
397                    println!("${out} = COMPARE {arg} {imm}")
398                }
399                SsaOp::CompareImmReg(out, arg, imm) => {
400                    println!("${out} = COMPARE {imm} {arg}")
401                }
402                SsaOp::CopyImm(out, imm) => {
403                    println!("${out} = COPY {imm}");
404                }
405            }
406        }
407    }
408}
409
#[cfg(test)]
mod test {
    use super::*;

    /// Flattens `max(0.25 - r, r - 0.5)` with `r = x² + y²` and checks the
    /// number of emitted operations and discovered variables.
    #[test]
    fn test_ring() -> Result<(), Error> {
        let mut ctx = Context::new();
        let half = ctx.constant(0.5);
        let x = ctx.x();
        let y = ctx.y();
        let x_sq = ctx.square(x)?;
        let y_sq = ctx.square(y)?;
        let radius = ctx.add(x_sq, y_sq)?;
        let outer = ctx.sub(radius, half)?;
        let quarter = ctx.constant(0.25);
        let inner = ctx.sub(quarter, radius)?;
        let ring = ctx.max(inner, outer)?;

        let (tape, vars) = SsaTape::new(&ctx, &[ring])?;
        assert_eq!(tape.len(), 9);
        assert_eq!(vars.len(), 2);
        Ok(())
    }

    /// A node that uses the same child twice must still emit that child once.
    #[test]
    fn test_dupe() -> Result<(), Error> {
        let mut ctx = Context::new();
        let x = ctx.x();
        let product = ctx.mul(x, x)?;

        let (tape, vars) = SsaTape::new(&ctx, &[product])?;
        assert_eq!(tape.len(), 3); // x, square, output
        assert_eq!(vars.len(), 1);
        Ok(())
    }

    /// A constant root is materialized with a copy before being output.
    #[test]
    fn test_constant() -> Result<(), Error> {
        let mut ctx = Context::new();
        let value = ctx.constant(1.5);
        let (tape, vars) = SsaTape::new(&ctx, &[value])?;
        assert_eq!(tape.len(), 2); // CopyImm, output
        assert_eq!(vars.len(), 0);
        Ok(())
    }
}