1use std::io::{self, BufRead, Write};
13use thiserror::Error;
14
/// Errors reported while compiling or executing a program.
#[derive(Error, Debug)]
pub enum Error {
    /// A `[` or `]` in the source has no matching partner.
    ///
    /// The payload is the bracket character that could not be matched.
    #[error("unmatched bracket `{0}`")]
    SyntaxError(char),
    /// The data pointer was about to leave the memory tape.
    ///
    /// The payload is the command (`<` or `>`) that attempted the move.
    #[error("pointer out of bounds on `{0}`")]
    PointerOutOfBoundsError(char),
    /// Reading from the input or writing to the output failed.
    #[error("unexpected I/O error")]
    IoError(#[from] io::Error),
}
28
/// Number of one-byte cells on the memory tape; the data pointer must stay in `0..MEMORY_SIZE`.
const MEMORY_SIZE: usize = 30000;
30
31pub fn run<I: BufRead, O: Write>(source: &str, input: &mut I, output: &mut O) -> Result<(), Error> {
33    let basic_blocks = into_basic_blocks(source)?;
34    let mut bb_no = 0usize;
35    let mut memory = vec![0u8; MEMORY_SIZE];
36    let mut ptr = 0usize;
37    loop {
38        let BasicBlock { instrs, jz, jnz } = &basic_blocks[bb_no];
39        for &instr in instrs {
40            match instr {
41                Cmd::Inc => memory[ptr] = memory[ptr].wrapping_add(1),
42                Cmd::Dec => memory[ptr] = memory[ptr].wrapping_sub(1),
43                Cmd::Left => {
44                    ptr = ptr
45                        .checked_sub(1)
46                        .ok_or(Error::PointerOutOfBoundsError('<'))?
47                }
48                Cmd::Right => {
49                    ptr = Some(ptr + 1)
50                        .filter(|&x| x < MEMORY_SIZE)
51                        .ok_or(Error::PointerOutOfBoundsError('>'))?
52                }
53                Cmd::Getc => {
54                    if let Some(byte) = getc(input)? {
55                        memory[ptr] = byte;
56                    }
57                }
58                Cmd::Putc => putc(output, memory[ptr])?,
59            }
60        }
61        if let &Some(next_bb) = if memory[ptr] == 0 { jz } else { jnz } {
62            bb_no = next_bb;
63        } else {
64            break;
65        }
66    }
67    Ok(())
68}
69
/// A single non-branching command; brackets are represented by block jumps instead.
#[derive(Debug, Clone, Copy)]
enum Cmd {
    /// `+`: add one to the current cell, wrapping on overflow.
    Inc,
    /// `-`: subtract one from the current cell, wrapping on underflow.
    Dec,
    /// `<`: move the data pointer one cell to the left.
    Left,
    /// `>`: move the data pointer one cell to the right.
    Right,
    /// `,`: read one byte from the input into the current cell.
    Getc,
    /// `.`: write the current cell to the output.
    Putc,
}
79
/// A straight-line run of commands followed by a conditional jump.
#[derive(Debug)]
struct BasicBlock {
    /// Commands executed in order when the block is entered.
    instrs: Vec<Cmd>,
    /// Index of the next block when the current cell is zero after `instrs`;
    /// `None` ends execution.
    jz: Option<usize>,
    /// Index of the next block when the current cell is non-zero after `instrs`;
    /// `None` ends execution.
    jnz: Option<usize>,
}
86
/// A compiled program: basic blocks addressed by their index, with execution starting at block 0.
type ByteCodeProgram = Vec<BasicBlock>;
88
89fn into_basic_blocks(source: &str) -> Result<ByteCodeProgram, Error> {
90    let mut bbno_stack = vec![]; let mut basic_blocks = vec![];
92    let mut cur_basic_block = vec![];
93    let mut cur_bb_id = 0usize;
94    for c in source.chars() {
95        match c {
96            '+' => cur_basic_block.push(Cmd::Inc),
97            '-' => cur_basic_block.push(Cmd::Dec),
98            '<' => cur_basic_block.push(Cmd::Left),
99            '>' => cur_basic_block.push(Cmd::Right),
100            ',' => cur_basic_block.push(Cmd::Getc),
101            '.' => cur_basic_block.push(Cmd::Putc),
102            '[' => {
103                let bb = BasicBlock {
106                    instrs: cur_basic_block,
107                    jz: None,
108                    jnz: Some(cur_bb_id + 1),
109                };
110                basic_blocks.push(bb);
111                bbno_stack.push(cur_bb_id);
112                cur_bb_id += 1;
113                cur_basic_block = vec![];
114            }
115            ']' => {
116                let popped = bbno_stack.pop().ok_or(Error::SyntaxError(']'))?;
119                let bb = BasicBlock {
120                    instrs: cur_basic_block,
121                    jz: Some(cur_bb_id + 1),
122                    jnz: Some(popped + 1),
123                };
124                basic_blocks.push(bb);
125                basic_blocks[popped].jz = Some(cur_bb_id + 1);
126                cur_bb_id += 1;
127                cur_basic_block = vec![];
128            }
129            _ => (),
130        }
131    }
132    if !bbno_stack.is_empty() {
133        return Err(Error::SyntaxError('['));
134    }
135    let bb = BasicBlock {
136        instrs: cur_basic_block,
137        jz: None,
138        jnz: None,
139    };
140    basic_blocks.push(bb);
141    Ok(basic_blocks)
142}
143
144fn getc<I: BufRead>(input: &mut I) -> Result<Option<u8>, Error> {
145    let buf = input.fill_buf()?;
146    let value = buf.get(0).copied();
147    input.consume(1);
148    Ok(value)
149}
150
151fn putc<O: Write>(output: &mut O, byte: u8) -> Result<(), Error> {
152    output.write_all(&[byte][..])?;
153    Ok(())
154}
155
#[cfg(test)]
mod tests {
    use std::io::BufReader;

    use super::*;

    /// Runs `code` with `input` as its standard input and returns the
    /// interpreter result together with everything the program wrote.
    fn execute(code: &str, input: &[u8]) -> (Result<(), Error>, Vec<u8>) {
        let mut stdin = BufReader::new(input);
        let mut stdout: Vec<u8> = vec![];
        let res = run(code, &mut stdin, &mut stdout);
        (res, stdout)
    }

    /// A program without input that prints a fixed word.
    #[test]
    fn test_printer_1() {
        let code = "
        >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
        >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
        >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
        >>>>>>>>
        +[[-<]-[->]<-]<.<<<<.>>>>-.<<-.<.>>.<<<+++.>>>---.<++.";
        let (res, stdout) = execute(code, b"");
        assert!(res.is_ok(), "unexpected error: {:?}", res);
        assert_eq!(stdout, b"brainfuck");
    }

    /// A program without input that prints a 10x10 grid of asterisks.
    #[test]
    fn test_printer_2() {
        let code = "++[>+++++<-]++[>>>+++++<<<-]++++++[>>+++++++<<-]>[>..........>.<<-]";
        let (res, stdout) = execute(code, b"");
        assert!(res.is_ok(), "unexpected error: {:?}", res);
        assert_eq!(stdout, b"**********\n**********\n**********\n**********\n**********\n**********\n**********\n**********\n**********\n**********\n");
    }

    /// A program that reads a full name and prints only the middle name(s).
    #[test]
    fn test_io_1() {
        let code = ">+[,>++++[<-------->-]<]>++++[<++++++++>-]<[>,]++++[<-------->-]<[[-]++++[<-------->-]<]<[<]>>[.>]";
        let testcases = [
            &b"Samantha Vee Hills"[..],
            b"Bob Dillinger",
            b"John Jacob Jingleheimer Schmidt",
            b"Jose Mario Carasco-Williams",
            b"James Alfred Van Allen",
        ];
        let outputs = [
            &b"Vee"[..],
            b"",
            b"Jacob Jingleheimer",
            b"Mario",
            b"Alfred Van",
        ];
        for (testcase, expected) in testcases.into_iter().zip(outputs) {
            let (res, stdout) = execute(code, testcase);
            assert!(res.is_ok(), "unexpected error: {:?}", res);
            assert_eq!(stdout, expected);
        }
    }

    /// A program that echoes transformed words; the run is expected to end
    /// with a `>` pointer-out-of-bounds error after the output is produced.
    #[test]
    fn test_io_2() {
        let code = ">>>>>>>-[-[-<]>>+<]>-<<+[[[-]<,[->+>+<<]>[-<+>]>>[-<->>+<]<]<<[>>+<<[-]]<[<]>[[.>]<[<]>[-]>]>>>>.<<]";
        let testcases = [&b"laser bat "[..], b"on the topic of existence "];
        let outputs = [
            &b"laserasersererr batatt "[..],
            b"onn thehee topicopicpicicc off existencexistenceistencestencetenceencencecee ",
        ];
        for (testcase, expected) in testcases.into_iter().zip(outputs) {
            let (res, stdout) = execute(code, testcase);
            assert!(
                matches!(res, Err(Error::PointerOutOfBoundsError('>'))),
                "unexpected result: {:?}",
                res
            );
            assert_eq!(stdout, expected);
        }
    }
}