reqlang_expr/
compiler.rs

1//! The compiler and associated types
2
3use core::fmt;
4use std::rc::Rc;
5
6use crate::{ast::Expr, errors::ExprResult, vm::Value};
7
/// Bytecode instruction opcodes emitted by the compiler.
///
/// Values are written out explicitly so the bytecode encoding is visible here.
pub mod opcode {
    /// Emitted after a call's callee and arguments; followed by a one-byte argument count.
    pub const CALL: u8 = 0;
    /// Followed by a lookup-kind byte (see the `lookup` module) and a one-byte index.
    pub const GET: u8 = 1;
    /// Followed by a one-byte index into the string constant pool.
    pub const CONSTANT: u8 = 2;
    /// Boolean literal `true`; takes no operands.
    pub const TRUE: u8 = 3;
    /// Boolean literal `false`; takes no operands.
    pub const FALSE: u8 = 4;
}
18
/// Kinds of lookup encoded as the first operand of a `GET` instruction.
///
/// At compile time one of these values records which table the following
/// index byte refers to; at runtime the same value selects where the
/// referenced value is read from.
pub mod lookup {
    /// Default builtin functions.
    pub const BUILTIN: u8 = 0;
    /// Variables, referenced in expressions as `:name`.
    pub const VAR: u8 = 1;
    /// Prompts, referenced in expressions as `?name`.
    pub const PROMPT: u8 = 2;
    /// Secrets, referenced in expressions as `!name`.
    pub const SECRET: u8 = 3;
    /// User-registered builtin functions.
    pub const USER_BUILTIN: u8 = 4;
    /// Client context values, referenced in expressions as `@name`.
    pub const CLIENT_CTX: u8 = 5;
}
35
/// Find the position of `identifier` in `list` as a one-byte bytecode index.
///
/// Returns `None` when `identifier` is absent, or when its position does not
/// fit in a `u8`. A plain `as u8` cast would wrap around and silently make the
/// emitted bytecode refer to a different entry.
fn get(list: &[String], identifier: &str) -> Option<u8> {
    list.iter()
        .position(|x| x == identifier)
        .filter(|&i| i <= usize::from(u8::MAX))
        .map(|i| i as u8)
}
40
41/// Builtin function used in expressions
42pub struct BuiltinFn {
43    // Needs to follow identifier naming rules
44    pub name: String,
45    // Number of arguments the function expects
46    pub arity: u8,
47    // Function used at runtime
48    pub func: Rc<dyn Fn(Vec<Value>) -> Value>,
49}
50
51impl PartialEq for BuiltinFn {
52    fn eq(&self, other: &Self) -> bool {
53        self.name == other.name && self.arity == other.arity
54    }
55}
56
57impl fmt::Debug for BuiltinFn {
58    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
59        write!(f, "builtin {}({})", self.name, self.arity)
60    }
61}
62
63pub struct BuiltinFns;
64
65impl BuiltinFns {
66    pub fn id(args: Vec<Value>) -> Value {
67        let arg = args.first().unwrap();
68
69        arg.get_string().into()
70    }
71
72    pub fn noop(_: Vec<Value>) -> Value {
73        Value::String(String::from("noop"))
74    }
75
76    pub fn is_empty(args: Vec<Value>) -> Value {
77        let string_arg = args
78            .first()
79            .expect("should have string expression passed")
80            .get_string();
81
82        Value::Bool(string_arg.is_empty())
83    }
84
85    pub fn not(args: Vec<Value>) -> Value {
86        let bool_arg = args
87            .first()
88            .expect("should have boolean expression passed")
89            .get_bool();
90
91        Value::Bool(!bool_arg)
92    }
93
94    pub fn and(args: Vec<Value>) -> Value {
95        let a_arg = args
96            .first()
97            .expect("should have first expression passed")
98            .get_bool();
99        let b_arg = args
100            .get(1)
101            .expect("should have second expression passed")
102            .get_bool();
103
104        Value::Bool(a_arg && b_arg)
105    }
106
107    pub fn or(args: Vec<Value>) -> Value {
108        let a_arg = args
109            .first()
110            .expect("should have first expression passed")
111            .get_bool();
112        let b_arg = args
113            .get(1)
114            .expect("should have second expression passed")
115            .get_bool();
116
117        Value::Bool(a_arg || b_arg)
118    }
119
120    pub fn cond(args: Vec<Value>) -> Value {
121        let cond_arg = args
122            .first()
123            .expect("should have cond expression passed")
124            .get_bool();
125        let then_arg = args
126            .get(1)
127            .cloned()
128            .expect("should have then expression passed");
129        let else_arg = args
130            .get(2)
131            .cloned()
132            .expect("should have else expression passed");
133
134        if cond_arg { then_arg } else { else_arg }
135    }
136
137    pub fn to_str(args: Vec<Value>) -> Value {
138        let value_arg = args.first().expect("should have string expression passed");
139
140        match value_arg {
141            Value::String(_) => value_arg.clone(),
142            _ => Value::String(value_arg.to_string()),
143        }
144    }
145
146    pub fn concat(args: Vec<Value>) -> Value {
147        let mut result = String::new();
148
149        for arg in args {
150            let value = match arg {
151                Value::String(string) => string,
152                _ => arg.to_string(),
153            };
154
155            result.push_str(value.as_str());
156        }
157
158        Value::String(result)
159    }
160
161    pub fn contains(args: Vec<Value>) -> Value {
162        let needle_arg = args
163            .first()
164            .expect("should have first expression passed")
165            .get_string();
166        let haystack_arg = args
167            .get(1)
168            .expect("should have second expression passed")
169            .get_string();
170
171        Value::Bool(haystack_arg.contains(needle_arg))
172    }
173
174    pub fn trim(args: Vec<Value>) -> Value {
175        let string_arg = args
176            .first()
177            .expect("should have string expression passed")
178            .get_string();
179
180        Value::String(string_arg.trim().to_string())
181    }
182
183    pub fn trim_start(args: Vec<Value>) -> Value {
184        let string_arg = args
185            .first()
186            .expect("should have string expression passed")
187            .get_string();
188
189        Value::String(string_arg.trim_start().to_string())
190    }
191
192    pub fn trim_end(args: Vec<Value>) -> Value {
193        let string_arg = args
194            .first()
195            .expect("should have string expression passed")
196            .get_string();
197
198        Value::String(string_arg.trim_end().to_string())
199    }
200
201    pub fn lowercase(args: Vec<Value>) -> Value {
202        let string_arg = args
203            .first()
204            .expect("should have string expression passed")
205            .get_string();
206
207        Value::String(string_arg.to_lowercase().to_string())
208    }
209
210    pub fn uppercase(args: Vec<Value>) -> Value {
211        let string_arg = args
212            .first()
213            .expect("should have string expression passed")
214            .get_string();
215
216        Value::String(string_arg.to_uppercase().to_string())
217    }
218
219    pub fn eq(args: Vec<Value>) -> Value {
220        let a_arg = args.first().expect("should have first expression passed");
221        let b_arg = args.get(1).expect("should have second expression passed");
222
223        Value::Bool(a_arg == b_arg)
224    }
225}
226
227#[derive(Debug)]
228pub struct CompileTimeEnv {
229    builtins: Vec<Rc<BuiltinFn>>,
230    user_builtins: Vec<Rc<BuiltinFn>>,
231    vars: Vec<String>,
232    prompts: Vec<String>,
233    secrets: Vec<String>,
234    client_context: Vec<String>,
235}
236
237impl Default for CompileTimeEnv {
238    fn default() -> Self {
239        Self {
240            builtins: vec![
241                Rc::new(BuiltinFn {
242                    name: String::from("id"),
243                    arity: 1,
244                    func: Rc::new(BuiltinFns::id),
245                }),
246                Rc::new(BuiltinFn {
247                    name: String::from("noop"),
248                    arity: 0,
249                    func: Rc::new(BuiltinFns::noop),
250                }),
251                Rc::new(BuiltinFn {
252                    name: String::from("is_empty"),
253                    arity: 1,
254                    func: Rc::new(BuiltinFns::is_empty),
255                }),
256                Rc::new(BuiltinFn {
257                    name: String::from("not"),
258                    arity: 1,
259                    func: Rc::new(BuiltinFns::not),
260                }),
261                Rc::new(BuiltinFn {
262                    name: String::from("and"),
263                    arity: 2,
264                    func: Rc::new(BuiltinFns::and),
265                }),
266                Rc::new(BuiltinFn {
267                    name: String::from("or"),
268                    arity: 2,
269                    func: Rc::new(BuiltinFns::or),
270                }),
271                Rc::new(BuiltinFn {
272                    name: String::from("cond"),
273                    arity: 3,
274                    func: Rc::new(BuiltinFns::cond),
275                }),
276                Rc::new(BuiltinFn {
277                    name: String::from("to_str"),
278                    arity: 1,
279                    func: Rc::new(BuiltinFns::to_str),
280                }),
281                Rc::new(BuiltinFn {
282                    name: String::from("concat"),
283                    arity: 10,
284                    func: Rc::new(BuiltinFns::concat),
285                }),
286                Rc::new(BuiltinFn {
287                    name: String::from("contains"),
288                    arity: 2,
289                    func: Rc::new(BuiltinFns::contains),
290                }),
291                Rc::new(BuiltinFn {
292                    name: String::from("trim"),
293                    arity: 1,
294                    func: Rc::new(BuiltinFns::trim),
295                }),
296                Rc::new(BuiltinFn {
297                    name: String::from("trim_start"),
298                    arity: 1,
299                    func: Rc::new(BuiltinFns::trim_start),
300                }),
301                Rc::new(BuiltinFn {
302                    name: String::from("trim_end"),
303                    arity: 1,
304                    func: Rc::new(BuiltinFns::trim_end),
305                }),
306                Rc::new(BuiltinFn {
307                    name: String::from("lowercase"),
308                    arity: 1,
309                    func: Rc::new(BuiltinFns::lowercase),
310                }),
311                Rc::new(BuiltinFn {
312                    name: String::from("uppercase"),
313                    arity: 1,
314                    func: Rc::new(BuiltinFns::uppercase),
315                }),
316                Rc::new(BuiltinFn {
317                    name: String::from("eq"),
318                    arity: 2,
319                    func: Rc::new(BuiltinFns::eq),
320                }),
321            ],
322            user_builtins: vec![],
323            vars: Vec::new(),
324            prompts: Vec::new(),
325            secrets: Vec::new(),
326            client_context: Vec::new(),
327        }
328    }
329}
330
331impl CompileTimeEnv {
332    pub fn new(
333        vars: Vec<String>,
334        prompts: Vec<String>,
335        secrets: Vec<String>,
336        client_context: Vec<String>,
337    ) -> Self {
338        Self {
339            vars,
340            prompts,
341            secrets,
342            client_context,
343            ..Default::default()
344        }
345    }
346
347    pub fn get_builtin_index(&self, name: &str) -> Option<(&Rc<BuiltinFn>, u8)> {
348        let index = self.builtins.iter().position(|x| x.name == name);
349
350        let result = index.map(|i| (self.builtins.get(i).unwrap(), i as u8));
351        result
352    }
353
354    pub fn get_user_builtin_index(&self, name: &str) -> Option<(&Rc<BuiltinFn>, u8)> {
355        let index = self.user_builtins.iter().position(|x| x.name == name);
356
357        let result = index.map(|i| (self.user_builtins.get(i).unwrap(), i as u8));
358        result
359    }
360
361    pub fn add_user_builtins(&mut self, builtins: Vec<Rc<BuiltinFn>>) {
362        for builtin in builtins {
363            self.add_user_builtin(builtin);
364        }
365    }
366
367    pub fn add_user_builtin(&mut self, builtin: Rc<BuiltinFn>) {
368        self.user_builtins.push(builtin);
369    }
370
371    pub fn get_builtin(&self, index: usize) -> Option<&Rc<BuiltinFn>> {
372        self.builtins.get(index)
373    }
374
375    pub fn get_user_builtin(&self, index: usize) -> Option<&Rc<BuiltinFn>> {
376        self.user_builtins.get(index)
377    }
378
379    pub fn get_var(&self, index: usize) -> Option<&String> {
380        self.vars.get(index)
381    }
382
383    pub fn get_prompt(&self, index: usize) -> Option<&String> {
384        self.prompts.get(index)
385    }
386
387    pub fn get_secret(&self, index: usize) -> Option<&String> {
388        self.secrets.get(index)
389    }
390
391    pub fn get_client_context(&self, index: usize) -> Option<&String> {
392        self.client_context.get(index)
393    }
394
395    pub fn add_to_client_context(&mut self, key: &str) -> usize {
396        match self.client_context.iter().position(|x| x == key) {
397            Some(i) => i,
398            None => {
399                self.client_context.push(key.to_string());
400
401                self.client_context.len() - 1
402            }
403        }
404    }
405
406    pub fn add_keys_to_client_context(&mut self, keys: Vec<String>) {
407        self.client_context.extend(keys);
408    }
409
410    pub fn get_client_context_index(&self, name: &str) -> Option<(&String, u8)> {
411        let index = self
412            .client_context
413            .iter()
414            .position(|context_name| context_name == name);
415
416        let result = index.map(|i| (self.client_context.get(i).unwrap(), i as u8));
417        result
418    }
419}
420
/// The compiled bytecode for an expression.
///
/// `codes` is the instruction stream; `strings` is the constant pool that
/// `CONSTANT` instructions index into.
#[derive(Debug, Clone)]
pub struct ExprByteCode {
    codes: Vec<u8>,
    strings: Vec<String>,
}

impl ExprByteCode {
    /// Bundle an instruction stream with its string constant pool.
    pub fn new(codes: Vec<u8>, strings: Vec<String>) -> Self {
        ExprByteCode { codes, strings }
    }

    /// The raw instruction bytes.
    pub fn codes(&self) -> &[u8] {
        self.codes.as_slice()
    }

    /// The string constants referenced by the instructions.
    pub fn strings(&self) -> &[String] {
        self.strings.as_slice()
    }
}
441
442/// Compile an [`ast::Expr`] into [`ExprByteCode`]
443pub fn compile(expr: &Expr, env: &CompileTimeEnv) -> ExprResult<ExprByteCode> {
444    let mut strings: Vec<String> = vec![];
445    let codes = compile_expr(expr, env, &mut strings)?;
446    Ok(ExprByteCode::new(codes, strings))
447}
448
449fn compile_expr(
450    expr: &Expr,
451    env: &CompileTimeEnv,
452    strings: &mut Vec<String>,
453) -> ExprResult<Vec<u8>> {
454    use opcode::*;
455
456    let mut codes = vec![];
457
458    match expr {
459        Expr::String(string) => {
460            if let Some(index) = strings.iter().position(|x| x == &string.0) {
461                codes.push(CONSTANT);
462                codes.push(index as u8);
463            } else {
464                strings.push(string.0.clone());
465                let index = strings.len() - 1;
466                codes.push(CONSTANT);
467                codes.push(index as u8);
468            }
469        }
470        Expr::Identifier(identifier) => {
471            let identifier_name = identifier.0.as_str();
472
473            if let Some((_, index)) = env.get_builtin_index(identifier_name) {
474                codes.push(GET);
475                codes.push(lookup::BUILTIN);
476                codes.push(index);
477            } else if let Some((_, index)) = env.get_user_builtin_index(identifier_name) {
478                codes.push(GET);
479                codes.push(lookup::USER_BUILTIN);
480                codes.push(index);
481            } else {
482                let identifier_prefix = &identifier_name[..1];
483                let identifier_suffix = &identifier_name[1..];
484
485                match identifier_prefix {
486                    "?" => {
487                        if let Some(index) = get(&env.prompts, identifier_suffix) {
488                            codes.push(GET);
489                            codes.push(lookup::PROMPT);
490                            codes.push(index);
491                        }
492                    }
493                    "!" => {
494                        if let Some(index) = get(&env.secrets, identifier_suffix) {
495                            codes.push(GET);
496                            codes.push(lookup::SECRET);
497                            codes.push(index);
498                        }
499                    }
500                    ":" => {
501                        if let Some(index) = get(&env.vars, identifier_suffix) {
502                            codes.push(GET);
503                            codes.push(lookup::VAR);
504                            codes.push(index);
505                        }
506                    }
507                    "@" => {
508                        if let Some(index) = get(&env.client_context, identifier_suffix) {
509                            codes.push(GET);
510                            codes.push(lookup::CLIENT_CTX);
511                            codes.push(index);
512                        }
513                    }
514                    _ => {}
515                };
516            }
517        }
518        Expr::Call(expr_call) => {
519            let callee_bytecode = compile_expr(&expr_call.callee.0, env, strings)?;
520
521            codes.extend(callee_bytecode);
522
523            for arg in expr_call.args.iter() {
524                let arg_bytecode = compile_expr(&arg.0, env, strings)?;
525
526                codes.extend(arg_bytecode);
527            }
528
529            codes.push(opcode::CALL);
530            codes.push(expr_call.args.len() as u8);
531        }
532        Expr::Bool(value) => match value.0 {
533            true => {
534                codes.push(opcode::TRUE);
535            }
536            false => {
537                codes.push(opcode::FALSE);
538            }
539        },
540    }
541
542    Ok(codes)
543}
544
#[cfg(test)]
mod value_tests {
    use super::*;

    /// Render a dummy builtin of the given arity with the alternate Debug format.
    fn render_builtin(arity: u8) -> String {
        let builtin = BuiltinFn {
            name: "test_builtin".to_string(),
            arity,
            func: Rc::new(|_| Value::String("test_builtin".to_string())),
        };

        format!("{:#?}", builtin)
    }

    #[test]
    fn test_builtins_debug_0_arity() {
        assert_eq!("builtin test_builtin(0)", render_builtin(0));
    }

    #[test]
    fn test_builtins_debug_1_arity() {
        assert_eq!("builtin test_builtin(1)", render_builtin(1));
    }

    #[test]
    fn test_builtins_debug_2_arity() {
        assert_eq!("builtin test_builtin(2)", render_builtin(2));
    }
}