1use crate::{
3 Context, Error,
4 compiler::SsaOp,
5 context::{BinaryOpcode, Node, Op, UnaryOpcode},
6 var::VarMap,
7};
8use serde::{Deserialize, Serialize};
9
10use std::collections::{HashMap, HashSet};
11
/// A flat list of [`SsaOp`] instructions together with summary counts.
///
/// When built by `SsaTape::new`, every non-constant node of the source graph
/// is given its own register, so each register is written by exactly one
/// instruction.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct SsaTape {
    /// The instructions.
    ///
    /// `SsaTape::new` pushes the `Output` instructions first and pushes every
    /// instruction before the instructions that compute its operands, so the
    /// vector is in reverse evaluation order (`pretty_print` walks it
    /// back-to-front).
    pub tape: Vec<SsaOp>,

    /// Number of choice operations in the tape.
    ///
    /// `SsaTape::new` counts one choice per `Min`, `Max`, `And` or `Or`
    /// binary operation that it emits.
    pub choice_count: usize,

    /// Number of outputs; `SsaTape::new` sets this to the number of roots.
    pub output_count: usize,
}
33
34impl SsaTape {
    /// Flattens the part of the graph reachable from `roots` into a tape.
    ///
    /// Returns the tape together with the [`VarMap`] that was populated with
    /// every `Op::Input` variable encountered. Output `i` of the tape
    /// corresponds to `roots[i]`.
    ///
    /// Instructions are pushed root-first: each instruction appears in
    /// `tape` before the instructions that compute its operands.
    ///
    /// # Errors
    /// Returns `Error::BadNode` if `ctx.get_op` has no operation for a root
    /// or for any node reachable from one.
    ///
    /// # Panics
    /// Panics if
    /// - a unary operation has a constant argument,
    /// - a binary operation has two constant arguments,
    /// - an `And` / `Or` operation has a constant left-hand side,
    /// - the value stored in the `VarMap` for an input does not convert into
    ///   the index type of `SsaOp::Input`.
    pub fn new(ctx: &Context, roots: &[Node]) -> Result<(Self, VarMap), Error> {
        // Node -> where its value lives (register or inline immediate)
        let mut mapping = HashMap::new();
        // Node -> number of distinct parent edges seen during the first pass
        let mut parent_count: HashMap<Node, usize> = HashMap::new();
        // Next unused register index
        let mut slot_count = 0;

        // Location of a node's value: constants are carried inline as
        // immediates, everything else is assigned a register.
        #[derive(Copy, Clone, Debug)]
        enum Slot {
            Reg(u32),
            Immediate(f32),
        }

        // First pass: visit every reachable node once, assigning slots,
        // collecting input variables and counting parents.
        let mut seen = HashSet::new();
        let mut vars = VarMap::new();
        let mut todo = roots.to_vec();
        while let Some(node) = todo.pop() {
            if !seen.insert(node) {
                continue;
            }
            let op = ctx.get_op(node).ok_or(Error::BadNode)?;
            let prev = match op {
                Op::Const(c) => {
                    // Constants take no register; the value is cast to `f32`
                    mapping.insert(node, Slot::Immediate(c.0 as f32))
                }
                _ => {
                    if let Op::Input(v) = op {
                        vars.insert(*v);
                    }
                    let i = slot_count;
                    slot_count += 1;
                    mapping.insert(node, Slot::Reg(i))
                }
            };
            // `seen` guarantees each node is mapped only once
            assert!(prev.is_none());
            for child in op.iter_children() {
                *parent_count.entry(child).or_default() += 1;
                todo.push(child);
            }
        }

        // Second pass: emit instructions, parents before children.
        let mut seen = HashSet::new();
        let mut todo = roots.to_vec();
        let mut choice_count = 0;

        // Outputs go first. A constant root has no register of its own, so a
        // fresh one is allocated and filled by a `CopyImm` (pushed after the
        // `Output`, matching the parents-before-children order of the tape).
        let mut tape = vec![];
        for (i, r) in roots.iter().enumerate() {
            let i = i as u32;
            match mapping[r] {
                Slot::Reg(out_reg) => tape.push(SsaOp::Output(out_reg, i)),
                Slot::Immediate(imm) => {
                    let o = slot_count;
                    slot_count += 1;
                    tape.push(SsaOp::Output(o, i));
                    tape.push(SsaOp::CopyImm(o, imm));
                }
            }
        }

        while let Some(node) = todo.pop() {
            // Defer a node until all of its parents have been emitted (its
            // remaining parent count is zero), and never emit it twice.
            if *parent_count.get(&node).unwrap_or(&0) > 0 || !seen.insert(node)
            {
                continue;
            }

            // Every reachable node was already validated by the first pass
            let op = ctx.get_op(node).unwrap();
            for child in op.iter_children() {
                todo.push(child);
                *parent_count.get_mut(&child).unwrap() -= 1;
            }

            // Constants are inlined into their users, so they emit nothing
            let Slot::Reg(i) = mapping[&node] else {
                continue;
            };
            let op = match op {
                Op::Input(v) => {
                    // Value recorded for this variable in `vars`
                    // (presumably its input index -- confirm against `VarMap`)
                    let arg = vars[v];
                    SsaOp::Input(i, arg.try_into().unwrap())
                }
                Op::Const(..) => {
                    unreachable!("skipped above")
                }
                Op::Binary(op, lhs, rhs) => {
                    let lhs = mapping[lhs];
                    let rhs = mapping[rhs];

                    // Constructors for (reg, reg), (reg, imm) and (imm, reg)
                    // operands. Add / Mul / Min / Max reuse the reg-imm form
                    // for an immediate left-hand side; And / Or have no
                    // imm-reg form and panic instead.
                    type RegFn = fn(u32, u32, u32) -> SsaOp;
                    type ImmFn = fn(u32, u32, f32) -> SsaOp;
                    let f: (RegFn, ImmFn, ImmFn) = match op {
                        BinaryOpcode::Add => (
                            SsaOp::AddRegReg,
                            SsaOp::AddRegImm,
                            SsaOp::AddRegImm,
                        ),
                        BinaryOpcode::Sub => (
                            SsaOp::SubRegReg,
                            SsaOp::SubRegImm,
                            SsaOp::SubImmReg,
                        ),
                        BinaryOpcode::Mul => (
                            SsaOp::MulRegReg,
                            SsaOp::MulRegImm,
                            SsaOp::MulRegImm,
                        ),
                        BinaryOpcode::Div => (
                            SsaOp::DivRegReg,
                            SsaOp::DivRegImm,
                            SsaOp::DivImmReg,
                        ),
                        BinaryOpcode::Atan => (
                            SsaOp::AtanRegReg,
                            SsaOp::AtanRegImm,
                            SsaOp::AtanImmReg,
                        ),
                        BinaryOpcode::Min => (
                            SsaOp::MinRegReg,
                            SsaOp::MinRegImm,
                            SsaOp::MinRegImm,
                        ),
                        BinaryOpcode::Max => (
                            SsaOp::MaxRegReg,
                            SsaOp::MaxRegImm,
                            SsaOp::MaxRegImm,
                        ),
                        BinaryOpcode::And => (
                            SsaOp::AndRegReg,
                            SsaOp::AndRegImm,
                            |_out, _lhs, _rhs| {
                                panic!("AndImmReg must be collapsed")
                            },
                        ),
                        BinaryOpcode::Or => (
                            SsaOp::OrRegReg,
                            SsaOp::OrRegImm,
                            |_out, _lhs, _rhs| {
                                panic!("OrImmReg must be collapsed")
                            },
                        ),
                        BinaryOpcode::Compare => (
                            SsaOp::CompareRegReg,
                            SsaOp::CompareRegImm,
                            SsaOp::CompareImmReg,
                        ),
                        BinaryOpcode::Mod => (
                            SsaOp::ModRegReg,
                            SsaOp::ModRegImm,
                            SsaOp::ModImmReg,
                        ),
                    };

                    // These four opcodes are the ones counted as choices
                    if matches!(
                        op,
                        BinaryOpcode::Min
                            | BinaryOpcode::Max
                            | BinaryOpcode::And
                            | BinaryOpcode::Or
                    ) {
                        choice_count += 1;
                    }

                    match (lhs, rhs) {
                        (Slot::Reg(lhs), Slot::Reg(rhs)) => f.0(i, lhs, rhs),
                        (Slot::Reg(arg), Slot::Immediate(imm)) => {
                            f.1(i, arg, imm)
                        }
                        (Slot::Immediate(imm), Slot::Reg(arg)) => {
                            f.2(i, arg, imm)
                        }
                        (Slot::Immediate(..), Slot::Immediate(..)) => {
                            panic!("Cannot handle f(imm, imm)")
                        }
                    }
                }
                Op::Unary(op, lhs) => {
                    // Unary operations only exist in register form
                    let lhs = match mapping[lhs] {
                        Slot::Reg(r) => r,
                        Slot::Immediate(..) => {
                            panic!("Cannot handle f(imm)")
                        }
                    };
                    let op = match op {
                        UnaryOpcode::Neg => SsaOp::NegReg,
                        UnaryOpcode::Abs => SsaOp::AbsReg,
                        UnaryOpcode::Recip => SsaOp::RecipReg,
                        UnaryOpcode::Sqrt => SsaOp::SqrtReg,
                        UnaryOpcode::Square => SsaOp::SquareReg,
                        UnaryOpcode::Floor => SsaOp::FloorReg,
                        UnaryOpcode::Ceil => SsaOp::CeilReg,
                        UnaryOpcode::Round => SsaOp::RoundReg,
                        UnaryOpcode::Sin => SsaOp::SinReg,
                        UnaryOpcode::Cos => SsaOp::CosReg,
                        UnaryOpcode::Tan => SsaOp::TanReg,
                        UnaryOpcode::Asin => SsaOp::AsinReg,
                        UnaryOpcode::Acos => SsaOp::AcosReg,
                        UnaryOpcode::Atan => SsaOp::AtanReg,
                        UnaryOpcode::Exp => SsaOp::ExpReg,
                        UnaryOpcode::Ln => SsaOp::LnReg,
                        UnaryOpcode::Not => SsaOp::NotReg,
                    };
                    op(i, lhs)
                }
            };
            tape.push(op);
        }

        Ok((
            SsaTape {
                tape,
                choice_count,
                output_count: roots.len(),
            },
            vars,
        ))
    }
255
256 pub fn is_empty(&self) -> bool {
258 self.tape.is_empty()
259 }
260
261 pub fn len(&self) -> usize {
263 self.tape.len()
264 }
265
266 pub fn iter(&self) -> impl DoubleEndedIterator<Item = &SsaOp> {
270 self.tape.iter()
271 }
272
273 pub fn reset(&mut self) {
275 self.tape.clear();
276 self.choice_count = 0;
277 }
278 pub fn pretty_print(&self) {
280 for &op in self.tape.iter().rev() {
281 match op {
282 SsaOp::Output(arg, i) => {
283 println!("OUTPUT[{i}] = ${arg}");
284 }
285 SsaOp::Input(out, i) => {
286 println!("${out} = INPUT[{i}]");
287 }
288 SsaOp::NegReg(out, arg)
289 | SsaOp::AbsReg(out, arg)
290 | SsaOp::RecipReg(out, arg)
291 | SsaOp::SqrtReg(out, arg)
292 | SsaOp::CopyReg(out, arg)
293 | SsaOp::SquareReg(out, arg)
294 | SsaOp::FloorReg(out, arg)
295 | SsaOp::CeilReg(out, arg)
296 | SsaOp::RoundReg(out, arg)
297 | SsaOp::SinReg(out, arg)
298 | SsaOp::CosReg(out, arg)
299 | SsaOp::TanReg(out, arg)
300 | SsaOp::AsinReg(out, arg)
301 | SsaOp::AcosReg(out, arg)
302 | SsaOp::AtanReg(out, arg)
303 | SsaOp::ExpReg(out, arg)
304 | SsaOp::LnReg(out, arg)
305 | SsaOp::NotReg(out, arg) => {
306 let op = match op {
307 SsaOp::NegReg(..) => "NEG",
308 SsaOp::AbsReg(..) => "ABS",
309 SsaOp::RecipReg(..) => "RECIP",
310 SsaOp::SqrtReg(..) => "SQRT",
311 SsaOp::SquareReg(..) => "SQUARE",
312 SsaOp::FloorReg(..) => "FLOOR",
313 SsaOp::CeilReg(..) => "CEIL",
314 SsaOp::RoundReg(..) => "ROUND",
315 SsaOp::SinReg(..) => "SIN",
316 SsaOp::CosReg(..) => "COS",
317 SsaOp::TanReg(..) => "TAN",
318 SsaOp::AsinReg(..) => "ASIN",
319 SsaOp::AcosReg(..) => "ACOS",
320 SsaOp::AtanReg(..) => "ATAN",
321 SsaOp::ExpReg(..) => "EXP",
322 SsaOp::LnReg(..) => "LN",
323 SsaOp::NotReg(..) => "NOT",
324 SsaOp::CopyReg(..) => "COPY",
325 _ => unreachable!(),
326 };
327 println!("${out} = {op} ${arg}");
328 }
329
330 SsaOp::AddRegReg(out, lhs, rhs)
331 | SsaOp::MulRegReg(out, lhs, rhs)
332 | SsaOp::DivRegReg(out, lhs, rhs)
333 | SsaOp::SubRegReg(out, lhs, rhs)
334 | SsaOp::MinRegReg(out, lhs, rhs)
335 | SsaOp::MaxRegReg(out, lhs, rhs)
336 | SsaOp::ModRegReg(out, lhs, rhs)
337 | SsaOp::AndRegReg(out, lhs, rhs)
338 | SsaOp::AtanRegReg(out, lhs, rhs)
339 | SsaOp::OrRegReg(out, lhs, rhs) => {
340 let op = match op {
341 SsaOp::AddRegReg(..) => "ADD",
342 SsaOp::MulRegReg(..) => "MUL",
343 SsaOp::DivRegReg(..) => "DIV",
344 SsaOp::AtanRegReg(..) => "ATAN",
345 SsaOp::SubRegReg(..) => "SUB",
346 SsaOp::MinRegReg(..) => "MIN",
347 SsaOp::MaxRegReg(..) => "MAX",
348 SsaOp::ModRegReg(..) => "MAX",
349 SsaOp::AndRegReg(..) => "AND",
350 SsaOp::OrRegReg(..) => "OR",
351 _ => unreachable!(),
352 };
353 println!("${out} = {op} ${lhs} ${rhs}");
354 }
355
356 SsaOp::AddRegImm(out, arg, imm)
357 | SsaOp::MulRegImm(out, arg, imm)
358 | SsaOp::DivRegImm(out, arg, imm)
359 | SsaOp::DivImmReg(out, arg, imm)
360 | SsaOp::SubImmReg(out, arg, imm)
361 | SsaOp::SubRegImm(out, arg, imm)
362 | SsaOp::AtanRegImm(out, arg, imm)
363 | SsaOp::AtanImmReg(out, arg, imm)
364 | SsaOp::MinRegImm(out, arg, imm)
365 | SsaOp::MaxRegImm(out, arg, imm)
366 | SsaOp::ModRegImm(out, arg, imm)
367 | SsaOp::ModImmReg(out, arg, imm)
368 | SsaOp::AndRegImm(out, arg, imm)
369 | SsaOp::OrRegImm(out, arg, imm) => {
370 let (op, swap) = match op {
371 SsaOp::AddRegImm(..) => ("ADD", false),
372 SsaOp::MulRegImm(..) => ("MUL", false),
373 SsaOp::DivImmReg(..) => ("DIV", true),
374 SsaOp::DivRegImm(..) => ("DIV", false),
375 SsaOp::SubImmReg(..) => ("SUB", true),
376 SsaOp::SubRegImm(..) => ("SUB", false),
377 SsaOp::AtanImmReg(..) => ("ATAN", true),
378 SsaOp::AtanRegImm(..) => ("ATAN", false),
379 SsaOp::MinRegImm(..) => ("MIN", false),
380 SsaOp::MaxRegImm(..) => ("MAX", false),
381 SsaOp::ModRegImm(..) => ("MOD", false),
382 SsaOp::ModImmReg(..) => ("MOD", true),
383 SsaOp::AndRegImm(..) => ("AND", false),
384 SsaOp::OrRegImm(..) => ("OR", false),
385 _ => unreachable!(),
386 };
387 if swap {
388 println!("${out} = {op} {imm} ${arg}");
389 } else {
390 println!("${out} = {op} ${arg} {imm}");
391 }
392 }
393 SsaOp::CompareRegReg(out, lhs, rhs) => {
394 println!("${out} = COMPARE {lhs} {rhs}")
395 }
396 SsaOp::CompareRegImm(out, arg, imm) => {
397 println!("${out} = COMPARE {arg} {imm}")
398 }
399 SsaOp::CompareImmReg(out, arg, imm) => {
400 println!("${out} = COMPARE {imm} {arg}")
401 }
402 SsaOp::CopyImm(out, imm) => {
403 println!("${out} = COPY {imm}");
404 }
405 }
406 }
407 }
408}
409
#[cfg(test)]
mod test {
    use super::*;

    /// `max(0.25 - (x² + y²), (x² + y²) - 0.5)` flattens to nine instructions
    /// over two variables.
    #[test]
    fn test_ring() {
        let mut ctx = Context::new();
        let half = ctx.constant(0.5);
        let x = ctx.x();
        let y = ctx.y();
        let x_sq = ctx.square(x).unwrap();
        let y_sq = ctx.square(y).unwrap();
        let sum_sq = ctx.add(x_sq, y_sq).unwrap();
        let outer = ctx.sub(sum_sq, half).unwrap();
        let quarter = ctx.constant(0.25);
        let inner = ctx.sub(quarter, sum_sq).unwrap();
        let ring = ctx.max(inner, outer).unwrap();

        let roots = [ring];
        let (tape, vars) = SsaTape::new(&ctx, &roots).unwrap();
        // output + max + 2 × sub + add + 2 × square + 2 × input
        assert_eq!(tape.len(), 9);
        assert_eq!(vars.len(), 2);
    }

    /// A node used twice by the same parent is only emitted once.
    #[test]
    fn test_dupe() {
        let mut ctx = Context::new();
        let x = ctx.x();
        let product = ctx.mul(x, x).unwrap();

        let roots = [product];
        let (tape, vars) = SsaTape::new(&ctx, &roots).unwrap();
        // output + mul + a single input
        assert_eq!(tape.len(), 3);
        assert_eq!(vars.len(), 1);
    }

    /// A constant root becomes an output plus an immediate copy.
    #[test]
    fn test_constant() {
        let mut ctx = Context::new();
        let value = ctx.constant(1.5);

        let roots = [value];
        let (tape, vars) = SsaTape::new(&ctx, &roots).unwrap();
        // output + copy-immediate, no variables
        assert_eq!(tape.len(), 2);
        assert_eq!(vars.len(), 0);
    }
}