nom_learn/
lib.rs

1mod mem;
2
3pub use mem::Mem;
4use nom::branch::alt;
5use nom::bytes::complete::tag;
6use nom::character::complete as c;
7use nom::combinator::{opt, recognize};
8use nom::multi::many0;
9use nom::sequence::{delimited, pair, preceded, separated_pair, terminated, tuple};
10use nom::IResult;
11use std::borrow::Borrow;
12use std::collections::HashMap;
13use std::error::Error;
14use text_io::read;
15
/// A host function that interpreted programs can call by name.
///
/// It receives a single integer argument and yields an integer result or an error.
/// The `'a` bound covers whatever the closure borrows, so a callable may capture local
/// state (for example, a buffer that collects output).
pub type DefinedFunc<'a> = Box<dyn FnMut(i128) -> Result<i128, Box<dyn Error>> + 'a>;
17pub fn builtin_callables<'a>() -> HashMap<&'static str, DefinedFunc<'a>> {
18    let mut callables: HashMap<&str, DefinedFunc<'a>> = HashMap::new();
19    callables.insert(
20        "write_int",
21        Box::new(|e| {
22            print!("{}", e);
23            Ok(e)
24        }),
25    );
26    callables.insert(
27        "write_char",
28        Box::new(|e| {
29            print!("{}", e as u8 as char);
30            Ok(e)
31        }),
32    );
33    callables.insert(
34        "read_int",
35        Box::new(|_| {
36            let res: i128 = read!();
37            Ok(res)
38        }),
39    );
40    callables.insert(
41        "read_char",
42        Box::new(|_| {
43            let res: char = read!();
44            Ok(res as i128)
45        }),
46    );
47    callables
48}
49
/// An expression node in the interpreted language's syntax tree.
///
/// The parsers in this module fill the `&'a str` fields with slices of the source text.
/// `Clone`/`PartialEq`/`Eq` are derived so trees can be duplicated and compared (e.g. in
/// parser tests); every field type already supports them.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Expr<'a> {
    /// A binary operator symbol (`"+"`, `"&&"`, `"<="`, …) with its left and right operands.
    BinOp(&'a str, Box<Expr<'a>>, Box<Expr<'a>>),
    /// A prefix operator symbol (`"+"`, `"-"`, `"!"`, or `"*"` for a memory read) and its operand.
    UnOp(&'a str, Box<Expr<'a>>),
    /// A call `name(arg)`; the single argument is optional.
    Call(&'a str, Option<Box<Expr<'a>>>),
    /// A reference to a variable by name.
    Ident(&'a str),
    /// An integer literal.
    Int(i128),
}
58
59impl<'a> Expr<'a> {
60    pub fn eval<'b>(
61        &self,
62        registers: &'a HashMap<&'a str, i128>,
63        mem: &mut Mem<i128>,
64        callables: &mut HashMap<&str, DefinedFunc<'b>>,
65    ) -> Result<i128, Box<dyn Error>> {
66        match self {
67            Expr::BinOp(op, l, r) => Ok(match op {
68                &"+" => l.eval(registers, mem, callables)? + r.eval(registers, mem, callables)?,
69                &"-" => l.eval(registers, mem, callables)? - r.eval(registers, mem, callables)?,
70                &"*" => l.eval(registers, mem, callables)? * r.eval(registers, mem, callables)?,
71                &"/" => {
72                    let (lv, rv) = (
73                        l.eval(registers, mem, callables)?,
74                        r.eval(registers, mem, callables)?,
75                    );
76                    if rv == 0 {
77                        return Err("attempt to divide by zero".into());
78                    }
79                    lv / rv
80                }
81                &"%" => l.eval(registers, mem, callables)? % r.eval(registers, mem, callables)?,
82                &"^" => l
83                    .eval(registers, mem, callables)?
84                    .pow(r.eval(registers, mem, callables)? as u32),
85                &">" => {
86                    (l.eval(registers, mem, callables)? > r.eval(registers, mem, callables)?)
87                        as i128
88                }
89                &">=" => {
90                    (l.eval(registers, mem, callables)? >= r.eval(registers, mem, callables)?)
91                        as i128
92                }
93                &"<" => {
94                    (l.eval(registers, mem, callables)? < r.eval(registers, mem, callables)?)
95                        as i128
96                }
97                &"<=" => {
98                    (l.eval(registers, mem, callables)? <= r.eval(registers, mem, callables)?)
99                        as i128
100                }
101                &"==" => {
102                    (l.eval(registers, mem, callables)? == r.eval(registers, mem, callables)?)
103                        as i128
104                }
105                &"!=" => {
106                    (l.eval(registers, mem, callables)? != r.eval(registers, mem, callables)?)
107                        as i128
108                }
109                &"&&" => {
110                    (l.eval(registers, mem, callables)? != 0
111                        && r.eval(registers, mem, callables)? != 0) as i128
112                }
113                &"||" => {
114                    (l.eval(registers, mem, callables)? != 0
115                        || r.eval(registers, mem, callables)? != 0) as i128
116                }
117                _ => unreachable!(),
118            }),
119            Expr::UnOp(op, e) => Ok(match op {
120                &"+" => e.eval(registers, mem, callables)?,
121                &"-" => -e.eval(registers, mem, callables)?,
122                &"*" => {
123                    let start = e.eval(registers, mem, callables)?;
124                    match mem.mem.get(start as usize) {
125                        Some(Some(res)) => *res,
126                        _ => return Err("visiting invalid memory".into()),
127                    }
128                }
129                &"!" => (e.eval(registers, mem, callables)? == 0) as i128,
130                _ => unreachable!(),
131            }),
132            Expr::Call(fname, opt_e) => Ok(match (fname, opt_e) {
133                (&"malloc", Some(e)) => {
134                    let size = e.eval(registers, mem, callables)? as usize;
135                    mem.malloc(size, 0) as i128
136                }
137                (&"free", Some(e)) => {
138                    let start = e.eval(registers, mem, callables)? as usize;
139                    mem.free(start) as i128
140                }
141                (&otherwise, e) => {
142                    let arg = e
143                        .as_ref()
144                        .and_then(|e| Some(e.eval(registers, mem, callables)))
145                        .unwrap_or(Ok(0))?;
146                    callables
147                        .get_mut(otherwise)
148                        .ok_or(format!("undefined function: {:?}", otherwise))?(
149                        arg
150                    )?
151                }
152            }),
153            Expr::Int(i) => Ok(*i),
154            Expr::Ident(x) => registers
155                .get(x)
156                .and_then(|e| Some(*e))
157                .ok_or(format!("undefined variable: {x}").into()),
158        }
159    }
160}
161
162pub fn identifier(s: &str) -> IResult<&str, &str> {
163    recognize(tuple((
164        alt((tag("_"), c::alpha1)),
165        many0(alt((tag("_"), c::alphanumeric1))),
166    )))(s)
167}
168
169pub fn parse_uint(input: &str) -> IResult<&str, Box<Expr>> {
170    let (rem, res) = c::u128(input)?;
171    Ok((rem, Box::new(Expr::Int(res as i128))))
172}
173
174pub fn parse_ident(input: &str) -> IResult<&str, Box<Expr>> {
175    let (rem, res) = identifier(input)?;
176    Ok((rem, Box::new(Expr::Ident(res))))
177}
178
179pub fn parse_call(input: &str) -> IResult<&str, Box<Expr>> {
180    let (rem, res) = tuple((
181        identifier,
182        delimited(c::multispace0, tag("("), c::multispace0),
183        opt(parse_expr),
184        preceded(c::multispace0, tag(")")),
185    ))(input)?;
186    Ok((rem, Box::new(Expr::Call(res.0, res.2))))
187}
188
189pub fn parse_single_expr(input: &str) -> IResult<&str, Box<Expr>> {
190    alt((
191        parse_uint,
192        parse_call,
193        parse_ident,
194        delimited(
195            terminated(tag("("), c::multispace0),
196            parse_expr,
197            preceded(c::multispace0, tag(")")),
198        ),
199    ))(input)
200}
201
202pub fn parse_pow(input: &str) -> IResult<&str, Box<Expr>> {
203    match tuple((
204        parse_single_expr,
205        delimited(c::multispace0, tag("^"), c::multispace0),
206        parse_pow,
207    ))(input)
208    {
209        Ok((rem, res)) => Ok((rem, Box::new(Expr::BinOp("^", res.0, res.2)))),
210        _ => parse_single_expr(input),
211    }
212}
213
214pub fn parse_higher_unop(input: &str) -> IResult<&str, Box<Expr>> {
215    fn higher_unop(input: &str) -> IResult<&str, Box<Expr>> {
216        let (rem, res) = tuple((terminated(tag("*"), c::multispace0), parse_higher_unop))(input)?;
217        Ok((rem, Box::new(Expr::UnOp(res.0, res.1))))
218    }
219    alt((higher_unop, parse_pow))(input)
220}
221
222pub fn parse_higher_binop(input: &str) -> IResult<&str, Box<Expr>> {
223    let (rem, (mut res, res1)) = tuple((
224        parse_higher_unop,
225        many0(tuple((
226            delimited(
227                c::multispace0,
228                alt((tag("*"), tag("/"), tag("%"))),
229                c::multispace0,
230            ),
231            parse_higher_unop,
232        ))),
233    ))(input)?;
234    for (op, e) in res1.into_iter() {
235        res = Box::new(Expr::BinOp(op, res, e));
236    }
237    Ok((rem, res))
238}
239
240pub fn parse_lower_unop(input: &str) -> IResult<&str, Box<Expr>> {
241    fn lower_unop(input: &str) -> IResult<&str, Box<Expr>> {
242        let (rem, res) = tuple((
243            terminated(alt((tag("+"), tag("-"))), c::multispace0),
244            parse_lower_unop,
245        ))(input)?;
246        Ok((rem, Box::new(Expr::UnOp(res.0, res.1))))
247    }
248    alt((lower_unop, parse_higher_binop))(input)
249}
250
251pub fn parse_lower_binop(input: &str) -> IResult<&str, Box<Expr>> {
252    let (rem, (mut res, res1)) = tuple((
253        parse_lower_unop,
254        many0(tuple((
255            delimited(c::multispace0, alt((tag("+"), tag("-"))), c::multispace0),
256            parse_lower_unop,
257        ))),
258    ))(input)?;
259    for (op, e) in res1.into_iter() {
260        res = Box::new(Expr::BinOp(op, res, e));
261    }
262    Ok((rem, res))
263}
264
265pub fn parse_cmp_binop(input: &str) -> IResult<&str, Box<Expr>> {
266    let (rem, (mut res, res1)) = tuple((
267        parse_lower_binop,
268        many0(tuple((
269            delimited(
270                c::multispace0,
271                alt((
272                    tag(">="),
273                    tag(">"),
274                    tag("<="),
275                    tag("<"),
276                    tag("=="),
277                    tag("!="),
278                )),
279                c::multispace0,
280            ),
281            parse_lower_binop,
282        ))),
283    ))(input)?;
284    for (op, e) in res1.into_iter() {
285        res = Box::new(Expr::BinOp(op, res, e));
286    }
287    Ok((rem, res))
288}
289
290pub fn parse_not_unop(input: &str) -> IResult<&str, Box<Expr>> {
291    fn not_unop(input: &str) -> IResult<&str, Box<Expr>> {
292        let (rem, res) = tuple((terminated(tag("!"), c::multispace0), parse_not_unop))(input)?;
293        Ok((rem, Box::new(Expr::UnOp(res.0, res.1))))
294    }
295    alt((not_unop, parse_cmp_binop))(input)
296}
297
298pub fn parse_and_binop(input: &str) -> IResult<&str, Box<Expr>> {
299    let (rem, (mut res, res1)) = tuple((
300        parse_not_unop,
301        many0(tuple((
302            delimited(c::multispace0, tag("&&"), c::multispace0),
303            parse_not_unop,
304        ))),
305    ))(input)?;
306    for (op, e) in res1.into_iter() {
307        res = Box::new(Expr::BinOp(op, res, e));
308    }
309    Ok((rem, res))
310}
311
312pub fn parse_or_binop(input: &str) -> IResult<&str, Box<Expr>> {
313    let (rem, (mut res, res1)) = tuple((
314        parse_and_binop,
315        many0(tuple((
316            delimited(c::multispace0, tag("||"), c::multispace0),
317            parse_and_binop,
318        ))),
319    ))(input)?;
320    for (op, e) in res1.into_iter() {
321        res = Box::new(Expr::BinOp(op, res, e));
322    }
323    Ok((rem, res))
324}
325
326pub fn parse_expr(input: &str) -> IResult<&str, Box<Expr>> {
327    parse_or_binop(input)
328}
329
330#[derive(Debug)]
331pub enum Cmd<'a> {
332    Expr(Box<Expr<'a>>),
333    Decl(&'a str),
334    Assign(Box<Expr<'a>>, Box<Expr<'a>>),
335    Seq(Vec<Box<Cmd<'a>>>),
336    If(Box<Expr<'a>>, Box<Cmd<'a>>, Box<Cmd<'a>>),
337    While(Box<Expr<'a>>, Box<Cmd<'a>>),
338}
339
340impl<'a> Cmd<'a> {
341    pub fn exec<'b>(
342        &self,
343        registers: &mut HashMap<&'a str, i128>,
344        mem: &mut Mem<i128>,
345        callables: &mut HashMap<&str, DefinedFunc<'b>>,
346    ) -> Result<(), Box<dyn Error>> {
347        match self {
348            Cmd::Expr(e) => {
349                e.eval(registers, mem, callables)?;
350            }
351            Cmd::Decl(ident) => {
352                registers.insert(ident, 0);
353            }
354            Cmd::Assign(e1, e2) => match e1.borrow() {
355                Expr::UnOp("*", e1) => {
356                    let tmp = e2.eval(registers, mem, callables)?;
357                    let index = e1.eval(registers, mem, callables)? as usize;
358                    match mem.mem.get_mut(index) {
359                        Some(m) => {
360                            *m = Some(tmp);
361                        }
362                        None => return Err("cannot assign to invalid memory".into()),
363                    }
364                }
365                Expr::Ident(ident) => {
366                    let tmp = e2.eval(registers, mem, callables)?;
367                    registers.insert(ident, tmp);
368                }
369                _ => return Err(format!("cannot assign to {:?}", e1).into()),
370            },
371            Cmd::Seq(arr) => {
372                for c in arr.iter() {
373                    c.exec(registers, mem, callables)?;
374                }
375            }
376            Cmd::If(cond, c1, c2) => {
377                if cond.eval(registers, mem, callables)? != 0 {
378                    c1.exec(registers, mem, callables)?;
379                } else {
380                    c2.exec(registers, mem, callables)?;
381                }
382            }
383            Cmd::While(cond, c) => {
384                while cond.eval(registers, mem, callables)? != 0 {
385                    c.exec(registers, mem, callables)?;
386                }
387            }
388        };
389        Ok(())
390    }
391}
392
393pub fn parse_expr_cmd(input: &str) -> IResult<&str, Box<Cmd>> {
394    let (rem, res) = parse_expr(input)?;
395    Ok((rem, Box::new(Cmd::Expr(res))))
396}
397
398pub fn parse_decl(input: &str) -> IResult<&str, Box<Cmd>> {
399    let (rem, res) = preceded(tuple((tag("var"), c::multispace1)), identifier)(input)?;
400    Ok((rem, Box::new(Cmd::Decl(res))))
401}
402
403pub fn parse_assign(input: &str) -> IResult<&str, Box<Cmd>> {
404    let (rem, res) = separated_pair(
405        parse_expr,
406        delimited(c::multispace0, tag("="), c::multispace0),
407        parse_expr,
408    )(input)?;
409    Ok((rem, Box::new(Cmd::Assign(res.0, res.1))))
410}
411
412pub fn parse_single_cmd(input: &str) -> IResult<&str, Box<Cmd>> {
413    alt((parse_decl, parse_assign, parse_expr_cmd))(input)
414}
415
416pub fn parse_if(input: &str) -> IResult<&str, Box<Cmd>> {
417    let (rem, res) = tuple((
418        recognize(tuple((tag("if"), c::multispace1))),
419        parse_expr,
420        recognize(tuple((
421            opt(preceded(c::multispace1, tag("then"))),
422            c::multispace0,
423            tag("{"),
424            c::multispace0,
425        ))),
426        parse_cmd,
427        recognize(pair(c::multispace0, tag("}"))),
428        opt(delimited(
429            tuple((
430                c::multispace0,
431                tag("else"),
432                c::multispace0,
433                tag("{"),
434                c::multispace0,
435            )),
436            parse_cmd,
437            tuple((c::multispace0, tag("}"))),
438        )),
439    ))(input)?;
440    Ok((
441        rem,
442        Box::new(Cmd::If(
443            res.1,
444            res.3,
445            match res.5 {
446                Some(x) => x,
447                None => Box::new(Cmd::Seq(vec![])),
448            },
449        )),
450    ))
451}
452
453fn parse_while(input: &str) -> IResult<&str, Box<Cmd>> {
454    let (rem, res) = tuple((
455        tag("while"),
456        preceded(c::multispace1, parse_expr),
457        opt(preceded(c::multispace1, tag("do"))),
458        delimited(c::multispace0, tag("{"), c::multispace0),
459        parse_cmd,
460        preceded(c::multispace0, tag("}")),
461    ))(input)?;
462    Ok((rem, Box::new(Cmd::While(res.1, res.4))))
463}
464
465pub fn parse_block_cmd(input: &str) -> IResult<&str, Box<Cmd>> {
466    alt((parse_if, parse_while))(input)
467}
468
469pub fn parse_cmd(input: &str) -> IResult<&str, Box<Cmd>> {
470    let (rem, (mut res, opt_cmd)) = tuple((
471        many0(preceded(
472            c::multispace0,
473            alt((
474                terminated(parse_single_cmd, preceded(c::multispace0, tag(";"))),
475                terminated(parse_block_cmd, opt(preceded(c::multispace0, tag(";")))),
476            )),
477        )),
478        opt(preceded(c::multispace0, parse_single_cmd)),
479    ))(input)?;
480    if let Some(cmd) = opt_cmd {
481        res.push(cmd);
482    }
483    Ok((rem, Box::new(Cmd::Seq(res))))
484}
485
#[test]
fn test_parse_expr() {
    let src = "1000000 / (1 + 2 * 3 ^ 4 + 5 - 7 * 5474 / 9110)";
    let (remaining_input, output) = parse_expr(src).unwrap();
    println!("{:?} {:?}", remaining_input, output);
    assert_eq!(remaining_input, "");
    // 3 ^ 4 = 81, 7 * 5474 / 9110 = 4 (integer division), so the divisor is
    // 1 + 162 + 5 - 4 = 164 and 1000000 / 164 = 6097.
    let value = output
        .eval(&HashMap::new(), &mut Mem::new(), &mut builtin_callables())
        .unwrap();
    assert_eq!(value, 6097);
}
492
#[test]
fn test_eval_expr() {
    let src = "write_int ( read_int ( ) * k ) + write_char ( 10 ) - 10 + * ( malloc ( 2 ) + 1 )";
    let (remaining_input, output) = parse_expr(src).unwrap();
    assert_eq!(remaining_input, "");
    let mut registers = HashMap::new();
    registers.insert("k", 3000);
    // Feed `read_int` a fixed value so the test never blocks on, or fails reading, stdin.
    let mut callables = builtin_callables();
    callables.insert(
        "read_int",
        Box::new(|_: i128| -> Result<i128, Box<dyn Error>> { Ok(7) }),
    );
    let result = output.eval(&registers, &mut Mem::new(), &mut callables);
    println!("{:?} {:?}", output, result);
    assert!(result.is_ok(), "evaluation failed: {:?}", result);
}
505
#[test]
fn test_parse_cmd() {
    let src = "
    var n; var i; var p; var q; var s;
    n = read_int();
    i = 0; p = 0;
    while (i < n) do {
        q = malloc(16);
        * q = read_int();
        * (q + 8) = p;
        p = q;
        i = i + 1
    };
    s = 0;
    while (p != 0) do {
        s = s + * p;
        p = * (p + 8)
    };
    write_int(s);
    write_char(10)
    ";
    let (remaining_input, output) =
        delimited(c::multispace0, parse_cmd, c::multispace0)(src).unwrap();
    println!("{:?} {:?}", remaining_input, output);
    assert_eq!(remaining_input, "");
    // Top level: 5 declarations, 4 assignments, 2 while loops and 2 output calls.
    match &*output {
        Cmd::Seq(cmds) => assert_eq!(cmds.len(), 13),
        other => panic!("expected a top-level Seq, got {:?}", other),
    }
}
531
#[test]
fn test_exec_cmd() -> Result<(), Box<dyn Error>> {
    let (mut registers, mut mem) = (HashMap::new(), Mem::new());
    let mut callables = builtin_callables();
    // Scripted stdin: n = 3, then the three list values, so the test is deterministic and
    // never blocks waiting for a terminal.
    let mut input = vec![3, 10, 20, 30].into_iter();
    callables.insert(
        "read_int",
        Box::new(move |_: i128| -> Result<i128, Box<dyn Error>> {
            input
                .next()
                .ok_or_else(|| "read_int: no more scripted input".into())
        }),
    );
    let (rest, program) = parse_cmd(
        "
    var n; var i; var p; var q; var s;
    n = read_int();
    i = 0; p = 0;
    while (i < n) do {
        q = malloc(2);
        * q = read_int();
        * (q + 1) = p;
        p = q;
        i = i + 1
    };
    s = 0;
    while (p != 0) do {
        s = s + * p;
        tmp = * (p + 1);
        free(p);
        p = tmp;
    };
    write_int(s);
    write_char(10)
    ",
    )?;
    assert_eq!(rest.trim(), "");
    program.exec(&mut registers, &mut mem, &mut callables)?;
    assert_eq!(registers.get("n"), Some(&3));
    assert_eq!(registers.get("i"), Some(&3));
    println!("{:?}", (registers, mem));
    Ok(())
}
564
#[test]
fn test_rustfunc() -> Result<(), Box<dyn Error>> {
    let mut written = Vec::new();
    {
        let (mut registers, mut mem) = (HashMap::new(), Mem::new());
        let mut callables = builtin_callables();
        callables.insert("add2", Box::new(|x| Ok(x + 2)));
        // Capture `write_int` output instead of printing it so the result can be checked.
        callables.insert(
            "write_int",
            Box::new(|x: i128| -> Result<i128, Box<dyn Error>> {
                written.push(x);
                Ok(x)
            }),
        );
        parse_cmd(
            "
    write_int(add2(100));
    write_char(10)
    ",
        )?
        .1
        .exec(&mut registers, &mut mem, &mut callables)?;
    }
    assert_eq!(written, vec![102]);
    Ok(())
}
581
#[test]
fn test_typ() -> Result<(), Box<dyn Error>> {
    use std::fmt::Write;
    let input: &[u8] = "
write_int(123 / 0);
write_char(10);
a = 100000000000000000000000000000000000 * 1000;
write_int(a);
    "
    .as_bytes();

    let src = std::str::from_utf8(input).map_err(|op| format!("Invalid UTF-8: {}", op))?;
    let mut buf = String::new();
    let result = {
        let (mut registers, mut mem) = (HashMap::new(), Mem::new());
        let mut callables: HashMap<&str, DefinedFunc> = builtin_callables();
        callables.insert("add2", Box::new(|x| Ok(x + 2)));
        callables.insert(
            "write_int",
            Box::new(|x| {
                buf.write_str(&x.to_string())?;
                Ok(x)
            }),
        );
        let (_, program) = parse_cmd(src)?;
        program.exec(&mut registers, &mut mem, &mut callables)
    };

    // The first statement divides by zero. Execution must stop with that error before
    // `write_int` ever runs (its argument is evaluated first), so nothing is captured.
    let err = result.expect_err("`123 / 0` must abort execution");
    assert_eq!(err.to_string(), "attempt to divide by zero");
    assert!(buf.is_empty(), "unexpected output: {:?}", buf);
    Ok(())
}