reqlang_expr/
compiler.rs

1//! The compiler and associated types
2
3use std::rc::Rc;
4
5use crate::{
6    ast::{Expr, ExprS},
7    builtins::{BuiltinFn, BuiltinFns},
8    errors::{
9        CompileError::{self, WrongNumberOfArgs},
10        ExprError, ExprErrorS, ExprResult,
11    },
12    prelude::FnArg,
13    types::Type,
14};
15
/// Single-byte instruction codes emitted by the compiler.
///
/// The values are consecutive, starting at zero, in declaration order.
pub mod opcode {
    /// Emitted after a callee and its arguments; followed by one byte holding
    /// the argument count.
    pub const CALL: u8 = 0;
    /// Followed by two bytes: a [`super::lookup`] kind and an index.
    pub const GET: u8 = 1;
    /// Followed by one byte: an index into the bytecode's string table.
    pub const CONSTANT: u8 = 2;
    /// The boolean literal `true`.
    pub const TRUE: u8 = 3;
    /// The boolean literal `false`.
    pub const FALSE: u8 = 4;
    /// Emitted after the single operand of a `not` call.
    pub const NOT: u8 = 5;
    /// Emitted after the two operands of an `eq` call.
    pub const EQ: u8 = 6;
    /// Emitted after the single operand of a `type` call.
    pub const TYPE: u8 = 7;
}
29
/// Types of lookups for the GET op code
///
/// Used at compile time to encode lookup indexes
///
/// Used at runtime to use lookup indexes to reference runtime values
///
/// The values are consecutive, starting at zero, in declaration order.
pub mod lookup {
    /// Index into the default builtin functions.
    pub const BUILTIN: u8 = 0;
    /// Index into the variable names (identifiers prefixed with `:`).
    pub const VAR: u8 = 1;
    /// Index into the prompt names (identifiers prefixed with `?`).
    pub const PROMPT: u8 = 2;
    /// Index into the secret names (identifiers prefixed with `!`).
    pub const SECRET: u8 = 3;
    /// Index into the user supplied builtin functions.
    pub const USER_BUILTIN: u8 = 4;
    /// Index into the client context keys (identifiers prefixed with `@`).
    pub const CLIENT_CTX: u8 = 5;
}
46
/// Find the position of `identifier` in `list` as a single-byte index.
///
/// Returns `None` when the identifier is not present, and also when its
/// position does not fit in a `u8`: bytecode operands are one byte wide, so a
/// larger position cannot be encoded and must not silently wrap around to a
/// different entry.
fn get(list: &[String], identifier: &str) -> Option<u8> {
    list.iter()
        .position(|candidate| candidate == identifier)
        .and_then(|index| u8::try_from(index).ok())
}
51
52#[derive(Debug)]
53pub struct CompileTimeEnv {
54    builtins: Vec<Rc<BuiltinFn>>,
55    user_builtins: Vec<Rc<BuiltinFn>>,
56    vars: Vec<String>,
57    prompts: Vec<String>,
58    secrets: Vec<String>,
59    client_context: Vec<String>,
60}
61
62impl Default for CompileTimeEnv {
63    fn default() -> Self {
64        Self {
65            builtins: vec![
66                Rc::new(BuiltinFn {
67                    name: String::from("id"),
68                    args: vec![FnArg::new("value", Type::Value)],
69                    return_type: Type::Value,
70                    func: Rc::new(BuiltinFns::id),
71                }),
72                Rc::new(BuiltinFn {
73                    name: String::from("noop"),
74                    args: vec![],
75                    return_type: Type::String,
76                    func: Rc::new(BuiltinFns::noop),
77                }),
78                Rc::new(BuiltinFn {
79                    name: String::from("is_empty"),
80                    args: vec![FnArg::new("value", Type::String)],
81                    return_type: Type::String,
82                    func: Rc::new(BuiltinFns::is_empty),
83                }),
84                Rc::new(BuiltinFn {
85                    name: String::from("and"),
86                    args: vec![FnArg::new("a", Type::Bool), FnArg::new("b", Type::Bool)],
87                    return_type: Type::Bool,
88                    func: Rc::new(BuiltinFns::and),
89                }),
90                Rc::new(BuiltinFn {
91                    name: String::from("or"),
92                    args: vec![FnArg::new("a", Type::Bool), FnArg::new("b", Type::Bool)],
93                    return_type: Type::Bool,
94                    func: Rc::new(BuiltinFns::or),
95                }),
96                Rc::new(BuiltinFn {
97                    name: String::from("cond"),
98                    args: vec![
99                        FnArg::new("cond", Type::Bool),
100                        FnArg::new("then", Type::Value),
101                        FnArg::new("else", Type::Value),
102                    ],
103                    return_type: Type::Bool,
104                    func: Rc::new(BuiltinFns::cond),
105                }),
106                Rc::new(BuiltinFn {
107                    name: String::from("to_str"),
108                    args: vec![FnArg::new("value", Type::Value)],
109                    return_type: Type::String,
110                    func: Rc::new(BuiltinFns::to_str),
111                }),
112                Rc::new(BuiltinFn {
113                    name: String::from("concat"),
114                    args: vec![
115                        FnArg::new("a", Type::String),
116                        FnArg::new("b", Type::String),
117                        FnArg::new_varadic("rest", Type::String),
118                    ],
119                    return_type: Type::String,
120                    func: Rc::new(BuiltinFns::concat),
121                }),
122                Rc::new(BuiltinFn {
123                    name: String::from("contains"),
124                    args: vec![
125                        FnArg::new("needle", Type::String),
126                        FnArg::new("haystack", Type::String),
127                    ],
128                    return_type: Type::Bool,
129                    func: Rc::new(BuiltinFns::contains),
130                }),
131                Rc::new(BuiltinFn {
132                    name: String::from("trim"),
133                    args: vec![FnArg::new("value", Type::String)],
134                    return_type: Type::String,
135                    func: Rc::new(BuiltinFns::trim),
136                }),
137                Rc::new(BuiltinFn {
138                    name: String::from("trim_start"),
139                    args: vec![FnArg::new("value", Type::String)],
140                    return_type: Type::String,
141                    func: Rc::new(BuiltinFns::trim_start),
142                }),
143                Rc::new(BuiltinFn {
144                    name: String::from("trim_end"),
145                    args: vec![FnArg::new("value", Type::String)],
146                    return_type: Type::String,
147                    func: Rc::new(BuiltinFns::trim_end),
148                }),
149                Rc::new(BuiltinFn {
150                    name: String::from("lowercase"),
151                    args: vec![FnArg::new("value", Type::String)],
152                    return_type: Type::String,
153                    func: Rc::new(BuiltinFns::lowercase),
154                }),
155                Rc::new(BuiltinFn {
156                    name: String::from("uppercase"),
157                    args: vec![FnArg::new("value", Type::String)],
158                    return_type: Type::String,
159                    func: Rc::new(BuiltinFns::uppercase),
160                }),
161            ],
162            user_builtins: vec![],
163            vars: Vec::new(),
164            prompts: Vec::new(),
165            secrets: Vec::new(),
166            client_context: Vec::new(),
167        }
168    }
169}
170
171impl CompileTimeEnv {
172    pub fn new(
173        vars: Vec<String>,
174        prompts: Vec<String>,
175        secrets: Vec<String>,
176        client_context: Vec<String>,
177    ) -> Self {
178        Self {
179            vars,
180            prompts,
181            secrets,
182            client_context,
183            ..Default::default()
184        }
185    }
186
187    pub fn get_builtin_index(&self, name: &str) -> Option<(&Rc<BuiltinFn>, u8)> {
188        let index = self.builtins.iter().position(|x| x.name == name);
189
190        let result = index.map(|i| (self.builtins.get(i).unwrap(), i as u8));
191        result
192    }
193
194    pub fn get_user_builtin_index(&self, name: &str) -> Option<(&Rc<BuiltinFn>, u8)> {
195        let index = self.user_builtins.iter().position(|x| x.name == name);
196
197        let result = index.map(|i| (self.user_builtins.get(i).unwrap(), i as u8));
198        result
199    }
200
201    pub fn add_user_builtins(&mut self, builtins: Vec<Rc<BuiltinFn>>) {
202        for builtin in builtins {
203            self.add_user_builtin(builtin);
204        }
205    }
206
207    pub fn add_user_builtin(&mut self, builtin: Rc<BuiltinFn>) {
208        self.user_builtins.push(builtin);
209    }
210
211    pub fn get_builtin(&self, index: usize) -> Option<&Rc<BuiltinFn>> {
212        self.builtins.get(index)
213    }
214
215    pub fn get_user_builtin(&self, index: usize) -> Option<&Rc<BuiltinFn>> {
216        self.user_builtins.get(index)
217    }
218
219    pub fn get_var(&self, index: usize) -> Option<&String> {
220        self.vars.get(index)
221    }
222
223    pub fn get_prompt(&self, index: usize) -> Option<&String> {
224        self.prompts.get(index)
225    }
226
227    pub fn get_secret(&self, index: usize) -> Option<&String> {
228        self.secrets.get(index)
229    }
230
231    pub fn get_client_context(&self, index: usize) -> Option<&String> {
232        self.client_context.get(index)
233    }
234
235    pub fn add_to_client_context(&mut self, key: &str) -> usize {
236        match self.client_context.iter().position(|x| x == key) {
237            Some(i) => i,
238            None => {
239                self.client_context.push(key.to_string());
240
241                self.client_context.len() - 1
242            }
243        }
244    }
245
246    pub fn add_keys_to_client_context(&mut self, keys: Vec<String>) {
247        self.client_context.extend(keys);
248    }
249
250    pub fn get_client_context_index(&self, name: &str) -> Option<(&String, u8)> {
251        let index = self
252            .client_context
253            .iter()
254            .position(|context_name| context_name == name);
255
256        let result = index.map(|i| (self.client_context.get(i).unwrap(), i as u8));
257        result
258    }
259}
260
261/// The compiled bytecode for an expression
262#[derive(Debug, Clone, PartialEq)]
263pub struct ExprByteCode {
264    version: [u8; 4],
265    codes: Vec<u8>,
266    strings: Vec<String>,
267}
268
269impl ExprByteCode {
270    pub fn new(codes: Vec<u8>, strings: Vec<String>) -> Self {
271        let version_bytes = get_version_bytes();
272        let version_bytes_from_codes = &codes[0..4];
273
274        assert_eq!(
275            version_bytes, version_bytes_from_codes,
276            "Version bytes do not match"
277        );
278
279        let codes = codes[4..].to_vec();
280
281        Self {
282            version: version_bytes,
283            codes,
284            strings,
285        }
286    }
287
288    pub fn version(&self) -> &[u8; 4] {
289        &self.version
290    }
291
292    pub fn codes(&self) -> &[u8] {
293        &self.codes
294    }
295
296    pub fn strings(&self) -> &[String] {
297        &self.strings
298    }
299}
300
301pub fn get_version_bytes() -> [u8; 4] {
302    [
303        env!("CARGO_PKG_VERSION_MAJOR").parse().unwrap(),
304        env!("CARGO_PKG_VERSION_MINOR").parse().unwrap(),
305        env!("CARGO_PKG_VERSION_PATCH").parse().unwrap(),
306        0,
307    ]
308}
309
310/// Compile an [`ast::Expr`] into [`ExprByteCode`]
311pub fn compile(expr: &ExprS, env: &CompileTimeEnv) -> ExprResult<ExprByteCode> {
312    let mut strings: Vec<String> = vec![];
313    let mut codes = vec![];
314
315    codes.extend(get_version_bytes());
316
317    codes.extend(compile_expr(expr, env, &mut strings)?);
318
319    Ok(ExprByteCode::new(codes, strings))
320}
321
322fn compile_expr(
323    (expr, span): &ExprS,
324    env: &CompileTimeEnv,
325    strings: &mut Vec<String>,
326) -> ExprResult<Vec<u8>> {
327    use opcode::*;
328
329    let mut codes = vec![];
330    let mut errs: Vec<ExprErrorS> = vec![];
331
332    match expr {
333        Expr::String(string) => {
334            if let Some(index) = strings.iter().position(|x| x == &string.0) {
335                codes.push(CONSTANT);
336                codes.push(index as u8);
337            } else {
338                strings.push(string.0.clone());
339                let index = strings.len() - 1;
340                codes.push(CONSTANT);
341                codes.push(index as u8);
342            }
343        }
344        Expr::Identifier(identifier) => {
345            let identifier_name = identifier.0.as_str();
346
347            if let Some((_, index)) = env.get_builtin_index(identifier_name) {
348                codes.push(GET);
349                codes.push(lookup::BUILTIN);
350                codes.push(index);
351            } else if let Some((_, index)) = env.get_user_builtin_index(identifier_name) {
352                codes.push(GET);
353                codes.push(lookup::USER_BUILTIN);
354                codes.push(index);
355            } else {
356                let identifier_prefix = &identifier_name[..1];
357                let identifier_suffix = &identifier_name[1..];
358
359                match identifier_prefix {
360                    "?" => {
361                        if let Some(index) = get(&env.prompts, identifier_suffix) {
362                            codes.push(GET);
363                            codes.push(lookup::PROMPT);
364                            codes.push(index);
365                        } else {
366                            errs.push((
367                                CompileError::Undefined(identifier_name.to_string()).into(),
368                                span.clone(),
369                            ));
370                        }
371                    }
372                    "!" => {
373                        if let Some(index) = get(&env.secrets, identifier_suffix) {
374                            codes.push(GET);
375                            codes.push(lookup::SECRET);
376                            codes.push(index);
377                        } else {
378                            errs.push((
379                                CompileError::Undefined(identifier_name.to_string()).into(),
380                                span.clone(),
381                            ));
382                        }
383                    }
384                    ":" => {
385                        if let Some(index) = get(&env.vars, identifier_suffix) {
386                            codes.push(GET);
387                            codes.push(lookup::VAR);
388                            codes.push(index);
389                        } else {
390                            errs.push((
391                                CompileError::Undefined(identifier_name.to_string()).into(),
392                                span.clone(),
393                            ));
394                        }
395                    }
396                    "@" => {
397                        if let Some(index) = get(&env.client_context, identifier_suffix) {
398                            codes.push(GET);
399                            codes.push(lookup::CLIENT_CTX);
400                            codes.push(index);
401                        } else {
402                            errs.push((
403                                CompileError::Undefined(identifier_name.to_string()).into(),
404                                span.clone(),
405                            ));
406                        }
407                    }
408                    _ => {
409                        errs.push((
410                            ExprError::CompileError(CompileError::Undefined(
411                                identifier_name.to_string(),
412                            )),
413                            span.clone(),
414                        ));
415                    }
416                };
417            }
418        }
419        Expr::Call(expr_call) => {
420            let identifier_name = expr_call.callee.0.identifier_name().unwrap_or_default();
421
422            match identifier_name {
423                "type" => {
424                    if expr_call.args.is_empty() {
425                        errs.push((
426                            ExprError::CompileError(WrongNumberOfArgs {
427                                expected: 1,
428                                actual: 0,
429                            }),
430                            span.clone(),
431                        ));
432                    } else if expr_call.args.len() > 1 {
433                        errs.push((
434                            ExprError::CompileError(WrongNumberOfArgs {
435                                expected: 1,
436                                actual: expr_call.args.len(),
437                            }),
438                            span.clone(),
439                        ));
440                    } else {
441                        let arg = expr_call.args.first().expect("should have first argument");
442
443                        match compile_expr(arg, env, strings) {
444                            Ok(arg_bytecode) => {
445                                codes.extend(arg_bytecode);
446                                codes.push(opcode::TYPE);
447                            }
448                            Err(err) => {
449                                errs.extend(err);
450                            }
451                        }
452                    }
453                }
454                "eq" => {
455                    if expr_call.args.is_empty() {
456                        errs.push((
457                            ExprError::CompileError(WrongNumberOfArgs {
458                                expected: 2,
459                                actual: 0,
460                            }),
461                            span.clone(),
462                        ));
463                    } else if expr_call.args.len() == 1 {
464                        errs.push((
465                            ExprError::CompileError(WrongNumberOfArgs {
466                                expected: 2,
467                                actual: 1,
468                            }),
469                            span.clone(),
470                        ));
471                    } else if expr_call.args.len() > 2 {
472                        errs.push((
473                            ExprError::CompileError(WrongNumberOfArgs {
474                                expected: 2,
475                                actual: expr_call.args.len(),
476                            }),
477                            span.clone(),
478                        ));
479                    } else {
480                        let arg = expr_call.args.first().expect("should have first argument");
481
482                        match compile_expr(arg, env, strings) {
483                            Ok(arg_bytecode) => {
484                                codes.extend(arg_bytecode);
485                            }
486                            Err(err) => {
487                                errs.extend(err);
488                            }
489                        }
490
491                        let arg2 = expr_call.args.get(1).expect("should have second argument");
492
493                        match compile_expr(arg2, env, strings) {
494                            Ok(arg_bytecode) => {
495                                codes.extend(arg_bytecode);
496                            }
497                            Err(err) => {
498                                errs.extend(err);
499                            }
500                        }
501
502                        codes.push(opcode::EQ);
503                    }
504                }
505                "not" => {
506                    if expr_call.args.is_empty() {
507                        errs.push((
508                            ExprError::CompileError(WrongNumberOfArgs {
509                                expected: 1,
510                                actual: 0,
511                            }),
512                            span.clone(),
513                        ));
514                    } else if expr_call.args.len() > 1 {
515                        errs.push((
516                            ExprError::CompileError(WrongNumberOfArgs {
517                                expected: 1,
518                                actual: expr_call.args.len(),
519                            }),
520                            span.clone(),
521                        ));
522
523                        let arg = expr_call.args.first().expect("should have first argument");
524
525                        if !arg.0.is_bool() {
526                            errs.push((
527                                CompileError::TypeMismatch {
528                                    expected: Type::Bool,
529                                    actual: arg.0.get_type(),
530                                }
531                                .into(),
532                                arg.1.clone(),
533                            ));
534                        }
535                    } else {
536                        let arg = expr_call.args.first().expect("should have first argument");
537                        if !arg.0.is_bool() {
538                            errs.push((
539                                CompileError::TypeMismatch {
540                                    expected: Type::Bool,
541                                    actual: arg.0.get_type(),
542                                }
543                                .into(),
544                                arg.1.clone(),
545                            ));
546                        }
547
548                        match compile_expr(arg, env, strings) {
549                            Ok(arg_bytecode) => {
550                                codes.extend(arg_bytecode);
551                            }
552                            Err(err) => {
553                                errs.extend(err);
554                            }
555                        }
556
557                        codes.push(opcode::NOT);
558                    }
559                }
560                _ => {
561                    let callee_bytecode = compile_expr(&expr_call.callee, env, strings)?;
562
563                    if let Some(_op) = callee_bytecode.first() {
564                        if let Some(lookup) = callee_bytecode.get(1) {
565                            if let Some(index) = callee_bytecode.get(2) {
566                                match *lookup {
567                                    lookup::BUILTIN => {
568                                        let builtin = env.get_builtin((*index).into()).unwrap();
569
570                                        let call_arity: usize = expr_call.args.len();
571
572                                        if !builtin.arity_matches(call_arity.try_into().unwrap()) {
573                                            errs.push((
574                                                ExprError::CompileError(WrongNumberOfArgs {
575                                                    expected: builtin.arity() as usize,
576                                                    actual: call_arity,
577                                                }),
578                                                span.clone(),
579                                            ));
580                                        }
581                                    }
582                                    lookup::USER_BUILTIN => {
583                                        let builtin =
584                                            env.get_user_builtin((*index).into()).unwrap();
585
586                                        let call_arity: usize = expr_call.args.len();
587
588                                        if !builtin.arity_matches(call_arity.try_into().unwrap()) {
589                                            errs.push((
590                                                ExprError::CompileError(WrongNumberOfArgs {
591                                                    expected: builtin.arity() as usize,
592                                                    actual: call_arity,
593                                                }),
594                                                span.clone(),
595                                            ));
596                                        }
597                                    }
598                                    _ => {}
599                                }
600                            }
601                        }
602                    }
603
604                    codes.extend(callee_bytecode);
605
606                    for arg in expr_call.args.iter() {
607                        match compile_expr(arg, env, strings) {
608                            Ok(arg_bytecode) => {
609                                codes.extend(arg_bytecode);
610                            }
611                            Err(err) => {
612                                errs.extend(err);
613                            }
614                        }
615                    }
616
617                    codes.push(opcode::CALL);
618                    codes.push(expr_call.args.len() as u8);
619                }
620            }
621        }
622        Expr::Bool(value) => match value.0 {
623            true => {
624                codes.push(opcode::TRUE);
625            }
626            false => {
627                codes.push(opcode::FALSE);
628            }
629        },
630        Expr::Error => panic!("tried to compile despite parser errors"),
631    }
632
633    if !errs.is_empty() {
634        return Err(errs);
635    }
636
637    Ok(codes)
638}
639
#[cfg(test)]
mod compiler_tests {
    use super::*;

    /// Raw bytecode made of the given version header followed by `TRUE`.
    fn true_with_header(header: [u8; 4]) -> Vec<u8> {
        let mut codes = header.to_vec();
        codes.push(opcode::TRUE);
        codes
    }

    /// Pins the version header to the crate version it was written for.
    #[test]
    pub fn current_version_bytes() {
        assert_eq!(get_version_bytes(), [0, 7, 0, 0]);
    }

    /// Bytecode carrying the current version header is accepted.
    #[test]
    pub fn valid_bytecode_version_bytes() {
        ExprByteCode::new(true_with_header(get_version_bytes()), vec![]);
    }

    /// Bytecode carrying a different version header is rejected.
    #[test]
    #[should_panic(expected = "Version bytes do not match")]
    pub fn invalid_bytecode_version_bytes() {
        ExprByteCode::new(true_with_header([0, 0, 0, 0]), vec![]);
    }

    /// The version header can be read back from constructed bytecode.
    #[test]
    pub fn get_version_bytes_from_bytecode() {
        let bytecode = ExprByteCode::new(true_with_header(get_version_bytes()), vec![]);

        assert_eq!(bytecode.version(), &get_version_bytes());
    }
}