Skip to main content

crypto_brainfuck/utils/
ir.rs

1use crate::utils::{Op, Program};
2use std::fmt::{Debug, Error, Formatter};
3
/// Link (a.k.a. pointer) to the next operation in the program graph.
///
/// `Some(i)` is the index of that operation inside `IrCode::ops`; `None` means there is no
/// following operation.
type Link = Option<usize>;
6
/// Operations in intermediate representation.
///
/// The first field of every variant is the link to the operation that follows it in the
/// program graph.
#[derive(Debug, Copy, Clone)]
pub enum IrOp {
    /// No-op; carries only the link to the next operation.
    Noop(Link),
    /// Moves the data pointer right by the given number of cells.
    Right(Link, u8),
    /// Moves the data pointer left by the given number of cells.
    Left(Link, u8),
    /// Adds the given amount to the current cell.
    Add(Link, u8),
    /// Subtracts the given amount from the current cell.
    Sub(Link, u8),
    /// Sets the current cell to the given constant (a clear loop such as `[-]` becomes
    /// `SetIndirect(_, 0)`).
    SetIndirect(Link, u8),
    /// Fields after the link: offset, factor. Emitted by the multiplication-loop
    /// optimisation, once per target cell of the loop.
    MulCopy(Link, i8, i8),
    /// Produced from the program's write-byte instruction.
    Write(Link),
    /// Produced from the program's read-byte instruction.
    Read(Link),
    /// Fields: next operation, address to continue at if the current cell is 0.
    JumpIfZero(Link, Link),
    /// Fields: next operation, address to continue at if the current cell is not 0.
    JumpIfNotZero(Link, Link),
}
25
26impl IrOp {
27    fn next(&self) -> Link {
28        return match self {
29            IrOp::Noop(l) => l,
30            IrOp::Right(l, _) => l,
31            IrOp::Left(l, _) => l,
32            IrOp::Add(l, _) => l,
33            IrOp::Sub(l, _) => l,
34            IrOp::SetIndirect(l, _) => l,
35            IrOp::MulCopy(l, _, _) => l,
36            IrOp::Write(l) => l,
37            IrOp::Read(l) => l,
38            IrOp::JumpIfZero(l, _) => l,
39            IrOp::JumpIfNotZero(l, _) => l,
40        }
41        .clone();
42    }
43}
44
/// Graph representation of program using intermediate representation with IrOps.
///
/// Operations are stored in a flat vector and chained through their links; the chain
/// starts at index 0. After optimisation the vector may also hold operations that are no
/// longer reachable from the chain, so consumers should follow links (see `iter`) rather
/// than walk `ops` positionally.
pub struct IrCode {
    /// Backing storage of the graph; links are indices into this vector.
    pub ops: Vec<IrOp>,
}
49
50impl IrCode {
51    pub fn new(program: &Program) -> Self {
52        let mut ops: Vec<IrOp> = Vec::new();
53        for (idx, op) in program.instructions.iter().enumerate() {
54            let is_last = program.instructions.len() - 1 == idx;
55            let next = if is_last { None } else { Some(idx + 1) };
56
57            ops.push(match op {
58                Op::IncrementPtr => IrOp::Right(next, 1),
59                Op::DecrementPtr => IrOp::Left(next, 1),
60                Op::IncrementMemory => IrOp::Add(next, 1),
61                Op::DecrementMemory => IrOp::Sub(next, 1),
62                Op::ReadByte => IrOp::Read(next),
63                Op::WriteByte => IrOp::Write(next),
64                Op::JumpForward => IrOp::JumpIfZero(next, Some(program.find_matching_jump_end(idx) + 1)),
65                Op::JumpBackward => IrOp::JumpIfNotZero(next, Some(program.find_matching_jump_start(idx))),
66            })
67        }
68
69        IrCode { ops }
70    }
71
72    fn find_replacement(&self, current_idx: usize) -> Vec<IrOp> {
73        let current = self.ops.get(current_idx).expect("current not found");
74        let next_idx = match current.next() {
75            Some(t) => t,
76            None => return vec![*current],
77        };
78        let next = self.ops.get(next_idx).expect("next not found");
79        let subsequent_idx = next.next();
80
81        // three consecutive ops
82        if let Some(t) = subsequent_idx {
83            if let Some(t) = IrCode::find_three_consecutive(current, next, self.ops.get(t).expect("subsequent not found")) {
84                return vec![t];
85            }
86        }
87
88        // two consecutive ops
89        if let Some(t) = IrCode::find_two_consecutive(current, next) {
90            return vec![t];
91        }
92
93        // multiplication loop
94        if let IrOp::JumpIfZero(_, _) = current {
95            if let Some(t) = self.find_multiplication_loop(current) {
96                return t;
97            }
98        }
99
100        // nothing to optimize
101        vec![*current]
102    }
103
104    fn find_three_consecutive(current: &IrOp, next: &IrOp, subsequent: &IrOp) -> Option<IrOp> {
105        return match (current, next, subsequent) {
106            (IrOp::JumpIfZero(_, _), IrOp::Sub(_, 1), IrOp::JumpIfNotZero(far, _)) => Some(IrOp::SetIndirect(*far, 0)),
107            (IrOp::JumpIfZero(_, _), IrOp::Add(_, 1), IrOp::JumpIfNotZero(far, _)) => Some(IrOp::SetIndirect(*far, 0)),
108            _ => None,
109        };
110    }
111
112    fn find_two_consecutive(current: &IrOp, next: &IrOp) -> Option<IrOp> {
113        return match (current, next) {
114            (IrOp::Add(_, x), IrOp::Add(far, y)) => Some(IrOp::Add(*far, *x + *y)),
115            (IrOp::Sub(_, x), IrOp::Sub(far, y)) => Some(IrOp::Sub(*far, *x + *y)),
116            (IrOp::Sub(_, x), IrOp::Add(far, y)) => {
117                let result = *y as i8 - *x as i8;
118                Some(if result > 0 { IrOp::Add(*far, result as u8) } else { IrOp::Sub(*far, -result as u8) })
119            }
120            (IrOp::Add(_, x), IrOp::Sub(far, y)) => {
121                let result = *x as i8 - *y as i8;
122                Some(if result > 0 { IrOp::Add(*far, result as u8) } else { IrOp::Sub(*far, -result as u8) })
123            }
124
125            (IrOp::Right(_, x), IrOp::Right(far, y)) => Some(IrOp::Right(*far, *x + *y)),
126            (IrOp::Left(_, x), IrOp::Left(far, y)) => Some(IrOp::Left(*far, *x + *y)),
127            (IrOp::Right(_, x), IrOp::Left(far, y)) => {
128                let result = *x as i8 - *y as i8;
129                Some(if result > 0 { IrOp::Right(*far, result as u8) } else { IrOp::Left(*far, -result as u8) })
130            }
131            (IrOp::Left(_, x), IrOp::Right(far, y)) => {
132                let result = *y as i8 - *x as i8;
133                Some(if result > 0 { IrOp::Right(*far, result as u8) } else { IrOp::Left(*far, -result as u8) })
134            }
135
136            (IrOp::SetIndirect(_, c), IrOp::Add(far, x)) => Some(IrOp::SetIndirect(*far, c + x)),
137            (IrOp::SetIndirect(_, c), IrOp::Sub(far, x)) => Some(IrOp::SetIndirect(*far, c.wrapping_sub(*x))),
138
139            (IrOp::Add(_, _), IrOp::SetIndirect(far, c)) => Some(IrOp::SetIndirect(*far, *c)),
140            (IrOp::Sub(_, _), IrOp::SetIndirect(far, c)) => Some(IrOp::SetIndirect(*far, *c)),
141
142            (IrOp::SetIndirect(_, _), IrOp::SetIndirect(far, c)) => Some(IrOp::SetIndirect(*far, *c)),
143
144            (IrOp::SetIndirect(_, 0), IrOp::JumpIfZero(x, y)) => Some(IrOp::JumpIfZero(*x, *y)),
145
146            (IrOp::Add(_, _), IrOp::Read(far)) => Some(IrOp::Read(*far)),
147            (IrOp::Sub(_, _), IrOp::Read(far)) => Some(IrOp::Read(*far)),
148            (IrOp::SetIndirect(_, _), IrOp::Read(far)) => Some(IrOp::Read(*far)),
149
150            (_, _) => None,
151        };
152    }
153
154    fn find_multiplication_loop(&self, current: &IrOp) -> Option<Vec<IrOp>> {
155        if let IrOp::JumpIfZero(_, _) = current {
156        }
157        else {
158            return None;
159        }
160
161        let mut iter = Iter { ir_code: self, idx: current.next()? }; /* None: next does not exists */
162
163        // we are matching patterns like: [sub(1), right(1), add(3), right(1), add(7), left(2)]
164        // we will record adds for different offsets by interpreting the code at compile time
165        // if we subtract more than 1, this is not clear-loop and so cannot be multiplication
166        // loop. if offset is at the end different from zero, this is not multiplication loop.
167        // if we see any other instructions, we return none too.
168
169        let mut offset: i8 = 0;
170        let mut factors: [i16; 256] = [0; 256];
171        let far_op: Option<usize>;
172
173        loop {
174            let current = match iter.next() {
175                Some(t) => t,
176                None => return None,
177            };
178
179            match current {
180                IrOp::Right(_, data) => offset += *data as i8,
181                IrOp::Left(_, data) => offset -= *data as i8,
182                IrOp::Add(_, data) => {
183                    let idx = offset as usize + 128;
184                    factors[idx] += i16::from(*data)
185                }
186                IrOp::Sub(_, data) => {
187                    let idx = offset as usize + 128;
188                    factors[idx] -= i16::from(*data)
189                }
190                IrOp::JumpIfNotZero(far, _) => {
191                    far_op = *far;
192                    break;
193                }
194                _ => return None, // None: does not match pattern
195            }
196        }
197
198        if factors[128] != -1 {
199            return None;
200        } /* None: we must subtract exactly one from original cell to be clear loop */
201        if offset != 0 {
202            return None;
203        } /* None: lefts/rights unbalanced - would not be clear loop */
204
205        // all seems good, lets emit instructions
206        let mut op_idx = self.ops.len();
207        let mut generated: Vec<IrOp> = factors
208            .iter()
209            .enumerate()
210            .filter(|(offset, factor)| *offset != 128 && **factor != 0)
211            .map(|(idx, factor)| {
212                let offset = idx as i16 - 128;
213                let r = IrOp::MulCopy(Some(op_idx), offset as i8, *factor as i8);
214                op_idx += 1;
215                r
216            })
217            .collect();
218
219        generated.push(IrOp::SetIndirect(far_op, 0));
220
221        Some(generated)
222    }
223
224    fn optimize_program_once(&mut self) -> usize {
225        let mut idx = 0;
226        let mut len = 0;
227
228        loop {
229            if idx == std::usize::MAX {
230                return len;
231            }
232
233            let replacement = self.find_replacement(idx);
234            let first = replacement.first().expect("find_replacement returned empty vector");
235            let last = replacement.last().unwrap();
236
237            // push new instructions to ops array (links should be set-up by find_replacement)
238            replacement.iter().skip(1).for_each(|x| self.ops.push(*x));
239
240            let next_idx = match last.next() {
241                Some(t) => t,
242                None => std::usize::MAX,
243            };
244            self.ops[idx] = *first;
245            idx = next_idx;
246            len += 1;
247        }
248    }
249
250    pub fn optimize(&mut self) {
251        let mut old = self.optimize_program_once();
252
253        loop {
254            let new = self.optimize_program_once();
255            if new >= old {
256                break;
257            }
258            old = new;
259        }
260    }
261
262    pub fn iter(&self) -> Iter {
263        Iter { ir_code: &self, idx: 0 }
264    }
265
266    // O(n)
267    pub fn len(&self) -> usize {
268        self.iter().count()
269    }
270}
271
/// Iterator over the operations of an `IrCode`, visiting them in link order rather than
/// in storage order.
pub struct Iter<'a> {
    /// Graph being traversed.
    ir_code: &'a IrCode,
    /// Index in `ir_code.ops` of the operation returned by the next call to `next`.
    idx: usize,
}
276
277impl<'a> Iterator for Iter<'a> {
278    type Item = &'a IrOp;
279
280    fn next(&mut self) -> Option<Self::Item> {
281        self.ir_code.ops.get(self.idx).map(|t| {
282            self.idx = t.next().unwrap_or(std::usize::MAX); // proceed or point to invalid idx
283            t
284        })
285    }
286}
287
288impl Debug for IrCode {
289    fn fmt(&self, f: &mut Formatter) -> Result<(), Error> {
290        let mut current = self.ops.get(0);
291
292        f.write_str("IrCode {\n")?;
293
294        loop {
295            if current.is_none() {
296                break;
297            }
298
299            let next = current.unwrap().next();
300
301            f.write_fmt(format_args!("\t{:?},\n", current))?;
302            current = next.and_then(|x| self.ops.get(x));
303        }
304
305        f.write_str("}\n")?;
306        Ok(())
307    }
308}
309
#[cfg(test)]
mod test {
    use super::*;

    /// Translates `source` into IR without optimizing it.
    fn lower(source: &str) -> IrCode {
        IrCode::new(&Program::from_string(source))
    }

    /// Translates `source` into IR and runs the optimizer on it.
    fn optimized(source: &str) -> IrCode {
        let mut code = lower(source);
        code.optimize();
        code
    }

    #[test]
    fn iter() {
        let code = lower("+-<>.,");
        let mut ops = code.iter();

        assert_matches!(ops.next(), Some(IrOp::Add(_, 1)));
        assert_matches!(ops.next(), Some(IrOp::Sub(_, 1)));
        assert_matches!(ops.next(), Some(IrOp::Left(_, 1)));
        assert_matches!(ops.next(), Some(IrOp::Right(_, 1)));
        assert_matches!(ops.next(), Some(IrOp::Write(_)));
        assert_matches!(ops.next(), Some(IrOp::Read(_)));
        assert_matches!(ops.next(), None);
    }

    #[test]
    fn len() {
        let mut code = lower("+++>+");
        assert_eq!(code.len(), 5);

        code.optimize();
        assert_eq!(code.len(), 3);
    }

    #[test]
    fn can_print_debug() {
        let code = optimized("++[+++[->++>+++<<]>>>[+]--<<<-]");

        println!("{:?}", code);
    }

    #[test]
    fn optimizes_tail_instructions() {
        let code = optimized("+++");
        let mut ops = code.iter();

        assert_matches!(ops.next(), Some(IrOp::Add(_, 3)));
        assert_matches!(ops.next(), None);
    }

    #[test]
    fn optimizes_multiplication_loops() {
        let code = optimized("++[+++[->++>+++<<]>>>[+]--<<<-]");
        let mut ops = code.iter();

        assert_matches!(ops.next(), Some(IrOp::Add(_, 2)));
        assert_matches!(ops.next(), Some(IrOp::JumpIfZero(_, _)));
        assert_matches!(ops.next(), Some(IrOp::Add(_, 3)));
        assert_matches!(ops.next(), Some(IrOp::MulCopy(_, 1, 2)));
        assert_matches!(ops.next(), Some(IrOp::MulCopy(_, 2, 3)));
        assert_matches!(ops.next(), Some(IrOp::SetIndirect(_, 0)));
        assert_matches!(ops.next(), Some(IrOp::Right(_, 3)));
        assert_matches!(ops.next(), Some(IrOp::SetIndirect(_, 254)));
        assert_matches!(ops.next(), Some(IrOp::Left(_, 3)));
        assert_matches!(ops.next(), Some(IrOp::Sub(_, 1)));
        assert_matches!(ops.next(), Some(IrOp::JumpIfNotZero(_, _)));
        assert_matches!(ops.next(), None);
    }

    #[test]
    fn multiplication_loop_negative_bug() {
        let code = optimized("[>----<-]");
        let mut ops = code.iter();

        assert_matches!(ops.next(), Some(IrOp::MulCopy(_, 1, -4)));
        assert_matches!(ops.next(), Some(IrOp::SetIndirect(_, 0)));
        assert_matches!(ops.next(), None);
    }

    #[test]
    fn optimizes_consecutive_adds() {
        let code = optimized("+++>++");
        let mut ops = code.iter();

        assert_matches!(ops.next(), Some(IrOp::Add(_, 3)));
        assert_matches!(ops.next(), Some(IrOp::Right(_, 1)));
        assert_matches!(ops.next(), Some(IrOp::Add(_, 2)));
        assert_matches!(ops.next(), None);
    }

    #[test]
    fn optimizes_consecutive_mixed_adds() {
        let code = optimized("+++-->---++>--+++>++---");
        let mut ops = code.iter();

        assert_matches!(ops.next(), Some(IrOp::Add(_, 1)));
        assert_matches!(ops.next(), Some(IrOp::Right(_, 1)));
        assert_matches!(ops.next(), Some(IrOp::Sub(_, 1)));
        assert_matches!(ops.next(), Some(IrOp::Right(_, 1)));
        assert_matches!(ops.next(), Some(IrOp::Add(_, 1)));
        assert_matches!(ops.next(), Some(IrOp::Right(_, 1)));
        assert_matches!(ops.next(), Some(IrOp::Sub(_, 1)));
        assert_matches!(ops.next(), None);
    }

    #[test]
    fn optimizes_consecutive_mixed_lefts_rights() {
        let code = optimized(">>><<+<<<>>+<<>>>+>><<<");
        let mut ops = code.iter();

        assert_matches!(ops.next(), Some(IrOp::Right(_, 1)));
        assert_matches!(ops.next(), Some(IrOp::Add(_, 1)));
        assert_matches!(ops.next(), Some(IrOp::Left(_, 1)));
        assert_matches!(ops.next(), Some(IrOp::Add(_, 1)));
        assert_matches!(ops.next(), Some(IrOp::Right(_, 1)));
        assert_matches!(ops.next(), Some(IrOp::Add(_, 1)));
        assert_matches!(ops.next(), Some(IrOp::Left(_, 1)));
        assert_matches!(ops.next(), None);
    }

    #[test]
    fn optimizes_consecutive_subtractions() {
        let code = optimized("--->-");
        let mut ops = code.iter();

        assert_matches!(ops.next(), Some(IrOp::Sub(_, 3)));
        assert_matches!(ops.next(), Some(IrOp::Right(_, 1)));
        assert_matches!(ops.next(), Some(IrOp::Sub(_, 1)));
        assert_matches!(ops.next(), None);
    }

    #[test]
    fn optimizes_consecutive_lefts_rights() {
        let code = optimized(">>+>>>-<<<<+");
        let mut ops = code.iter();

        assert_matches!(ops.next(), Some(IrOp::Right(_, 2)));
        assert_matches!(ops.next(), Some(IrOp::Add(_, 1)));
        assert_matches!(ops.next(), Some(IrOp::Right(_, 3)));
        assert_matches!(ops.next(), Some(IrOp::Sub(_, 1)));
        assert_matches!(ops.next(), Some(IrOp::Left(_, 4)));
        assert_matches!(ops.next(), Some(IrOp::Add(_, 1)));
        assert_matches!(ops.next(), None);
    }

    #[test]
    fn optimizes_clear_loops() {
        let code = optimized("[-]>[+]>");
        let mut ops = code.iter();

        assert_matches!(ops.next(), Some(IrOp::SetIndirect(_, 0)));
        assert_matches!(ops.next(), Some(IrOp::Right(_, 1)));
        assert_matches!(ops.next(), Some(IrOp::SetIndirect(_, 0)));
        assert_matches!(ops.next(), Some(IrOp::Right(_, 1)));
        assert_matches!(ops.next(), None);
    }

    #[test]
    fn optimizes_adds_following_preceding_clear_loops() {
        let code = optimized("+[-]+++++>-[+]----");
        let mut ops = code.iter();

        assert_matches!(ops.next(), Some(IrOp::SetIndirect(_, 5)));
        assert_matches!(ops.next(), Some(IrOp::Right(_, 1)));
        assert_matches!(ops.next(), Some(IrOp::SetIndirect(_, 252)));
        assert_matches!(ops.next(), None);
    }

    #[test]
    fn optimizes_consecutive_sets() {
        let code = optimized("+[-]+++++-[+]----");
        assert_eq!(code.len(), 1);

        let mut ops = code.iter();
        assert_matches!(ops.next(), Some(IrOp::SetIndirect(_, 252)));
        assert_matches!(ops.next(), None);
    }
}