reqlang_expr/
compiler.rs

1//! The compiler and associated types
2
3use std::rc::Rc;
4
5use crate::{
6    ast::{Expr, ExprS},
7    builtins::{BuiltinFn, BuiltinFns},
8    errors::{
9        CompileError::{self, WrongNumberOfArgs},
10        ExprError, ExprErrorS, ExprResult,
11    },
12    prelude::FnArg,
13    types::Type,
14};
15
/// Bytecode instruction opcodes.
///
/// The values are spelled out explicitly because they are the on-the-wire
/// encoding of compiled expressions.
pub mod opcode {
    /// Call the callee on the stack with the following byte as argument count.
    pub const CALL: u8 = 0;
    /// Read a value; followed by a lookup kind byte and an index byte.
    pub const GET: u8 = 1;
    /// Push a string constant; followed by an index into the string table.
    pub const CONSTANT: u8 = 2;
    /// Push `true`.
    pub const TRUE: u8 = 3;
    /// Push `false`.
    pub const FALSE: u8 = 4;
    /// Boolean negation of the preceding operand.
    pub const NOT: u8 = 5;
    /// Equality of the two preceding operands.
    pub const EQ: u8 = 6;
}
28
/// Kinds of lookup performed by the `GET` opcode.
///
/// At compile time these values are emitted after `GET` to say which table the
/// following index byte refers to; at runtime they select that table again.
pub mod lookup {
    /// Crate-provided builtin functions.
    pub const BUILTIN: u8 = 0;
    /// Variables (`:name`).
    pub const VAR: u8 = 1;
    /// Prompts (`?name`).
    pub const PROMPT: u8 = 2;
    /// Secrets (`!name`).
    pub const SECRET: u8 = 3;
    /// Builtins registered by the embedding application.
    pub const USER_BUILTIN: u8 = 4;
    /// Client context values (`@name`).
    pub const CLIENT_CTX: u8 = 5;
}
45
/// Find the index of `identifier` in `list`, encoded as a one-byte bytecode operand.
///
/// Returns the index of the first matching entry. Returns `None` when there is
/// no match, or when the match sits beyond index 255: such an index cannot be
/// encoded in one byte, and truncating it would silently refer to a different
/// entry.
fn get(list: &[String], identifier: &str) -> Option<u8> {
    let index = list.iter().position(|entry| entry == identifier)?;
    u8::try_from(index).ok()
}
50
51#[derive(Debug)]
52pub struct CompileTimeEnv {
53    builtins: Vec<Rc<BuiltinFn>>,
54    user_builtins: Vec<Rc<BuiltinFn>>,
55    vars: Vec<String>,
56    prompts: Vec<String>,
57    secrets: Vec<String>,
58    client_context: Vec<String>,
59}
60
61impl Default for CompileTimeEnv {
62    fn default() -> Self {
63        Self {
64            builtins: vec![
65                Rc::new(BuiltinFn {
66                    name: String::from("id"),
67                    args: vec![FnArg::new("value", Type::Value)],
68                    return_type: Type::Value,
69                    func: Rc::new(BuiltinFns::id),
70                }),
71                Rc::new(BuiltinFn {
72                    name: String::from("noop"),
73                    args: vec![],
74                    return_type: Type::String,
75                    func: Rc::new(BuiltinFns::noop),
76                }),
77                Rc::new(BuiltinFn {
78                    name: String::from("is_empty"),
79                    args: vec![FnArg::new("value", Type::String)],
80                    return_type: Type::String,
81                    func: Rc::new(BuiltinFns::is_empty),
82                }),
83                Rc::new(BuiltinFn {
84                    name: String::from("and"),
85                    args: vec![FnArg::new("a", Type::Bool), FnArg::new("b", Type::Bool)],
86                    return_type: Type::Bool,
87                    func: Rc::new(BuiltinFns::and),
88                }),
89                Rc::new(BuiltinFn {
90                    name: String::from("or"),
91                    args: vec![FnArg::new("a", Type::Bool), FnArg::new("b", Type::Bool)],
92                    return_type: Type::Bool,
93                    func: Rc::new(BuiltinFns::or),
94                }),
95                Rc::new(BuiltinFn {
96                    name: String::from("cond"),
97                    args: vec![
98                        FnArg::new("cond", Type::Bool),
99                        FnArg::new("then", Type::Value),
100                        FnArg::new("else", Type::Value),
101                    ],
102                    return_type: Type::Bool,
103                    func: Rc::new(BuiltinFns::cond),
104                }),
105                Rc::new(BuiltinFn {
106                    name: String::from("to_str"),
107                    args: vec![FnArg::new("value", Type::Value)],
108                    return_type: Type::String,
109                    func: Rc::new(BuiltinFns::to_str),
110                }),
111                Rc::new(BuiltinFn {
112                    name: String::from("concat"),
113                    args: vec![
114                        FnArg::new("a", Type::String),
115                        FnArg::new("b", Type::String),
116                        FnArg::new_varadic("rest", Type::String),
117                    ],
118                    return_type: Type::String,
119                    func: Rc::new(BuiltinFns::concat),
120                }),
121                Rc::new(BuiltinFn {
122                    name: String::from("contains"),
123                    args: vec![
124                        FnArg::new("needle", Type::String),
125                        FnArg::new("haystack", Type::String),
126                    ],
127                    return_type: Type::Bool,
128                    func: Rc::new(BuiltinFns::contains),
129                }),
130                Rc::new(BuiltinFn {
131                    name: String::from("trim"),
132                    args: vec![FnArg::new("value", Type::String)],
133                    return_type: Type::String,
134                    func: Rc::new(BuiltinFns::trim),
135                }),
136                Rc::new(BuiltinFn {
137                    name: String::from("trim_start"),
138                    args: vec![FnArg::new("value", Type::String)],
139                    return_type: Type::String,
140                    func: Rc::new(BuiltinFns::trim_start),
141                }),
142                Rc::new(BuiltinFn {
143                    name: String::from("trim_end"),
144                    args: vec![FnArg::new("value", Type::String)],
145                    return_type: Type::String,
146                    func: Rc::new(BuiltinFns::trim_end),
147                }),
148                Rc::new(BuiltinFn {
149                    name: String::from("lowercase"),
150                    args: vec![FnArg::new("value", Type::String)],
151                    return_type: Type::String,
152                    func: Rc::new(BuiltinFns::lowercase),
153                }),
154                Rc::new(BuiltinFn {
155                    name: String::from("uppercase"),
156                    args: vec![FnArg::new("value", Type::String)],
157                    return_type: Type::String,
158                    func: Rc::new(BuiltinFns::uppercase),
159                }),
160                Rc::new(BuiltinFn {
161                    name: String::from("type"),
162                    args: vec![FnArg::new("value", Type::Value)],
163                    return_type: Type::String,
164                    func: Rc::new(BuiltinFns::get_type),
165                }),
166            ],
167            user_builtins: vec![],
168            vars: Vec::new(),
169            prompts: Vec::new(),
170            secrets: Vec::new(),
171            client_context: Vec::new(),
172        }
173    }
174}
175
176impl CompileTimeEnv {
177    pub fn new(
178        vars: Vec<String>,
179        prompts: Vec<String>,
180        secrets: Vec<String>,
181        client_context: Vec<String>,
182    ) -> Self {
183        Self {
184            vars,
185            prompts,
186            secrets,
187            client_context,
188            ..Default::default()
189        }
190    }
191
192    pub fn get_builtin_index(&self, name: &str) -> Option<(&Rc<BuiltinFn>, u8)> {
193        let index = self.builtins.iter().position(|x| x.name == name);
194
195        let result = index.map(|i| (self.builtins.get(i).unwrap(), i as u8));
196        result
197    }
198
199    pub fn get_user_builtin_index(&self, name: &str) -> Option<(&Rc<BuiltinFn>, u8)> {
200        let index = self.user_builtins.iter().position(|x| x.name == name);
201
202        let result = index.map(|i| (self.user_builtins.get(i).unwrap(), i as u8));
203        result
204    }
205
206    pub fn add_user_builtins(&mut self, builtins: Vec<Rc<BuiltinFn>>) {
207        for builtin in builtins {
208            self.add_user_builtin(builtin);
209        }
210    }
211
212    pub fn add_user_builtin(&mut self, builtin: Rc<BuiltinFn>) {
213        self.user_builtins.push(builtin);
214    }
215
216    pub fn get_builtin(&self, index: usize) -> Option<&Rc<BuiltinFn>> {
217        self.builtins.get(index)
218    }
219
220    pub fn get_user_builtin(&self, index: usize) -> Option<&Rc<BuiltinFn>> {
221        self.user_builtins.get(index)
222    }
223
224    pub fn get_var(&self, index: usize) -> Option<&String> {
225        self.vars.get(index)
226    }
227
228    pub fn get_prompt(&self, index: usize) -> Option<&String> {
229        self.prompts.get(index)
230    }
231
232    pub fn get_secret(&self, index: usize) -> Option<&String> {
233        self.secrets.get(index)
234    }
235
236    pub fn get_client_context(&self, index: usize) -> Option<&String> {
237        self.client_context.get(index)
238    }
239
240    pub fn add_to_client_context(&mut self, key: &str) -> usize {
241        match self.client_context.iter().position(|x| x == key) {
242            Some(i) => i,
243            None => {
244                self.client_context.push(key.to_string());
245
246                self.client_context.len() - 1
247            }
248        }
249    }
250
251    pub fn add_keys_to_client_context(&mut self, keys: Vec<String>) {
252        self.client_context.extend(keys);
253    }
254
255    pub fn get_client_context_index(&self, name: &str) -> Option<(&String, u8)> {
256        let index = self
257            .client_context
258            .iter()
259            .position(|context_name| context_name == name);
260
261        let result = index.map(|i| (self.client_context.get(i).unwrap(), i as u8));
262        result
263    }
264}
265
/// Bytecode for one compiled expression: the instruction stream together with
/// the string constant table that its `CONSTANT` instructions index into.
#[derive(Debug, Clone, PartialEq)]
pub struct ExprByteCode {
    /// Opcodes and their one-byte operands.
    pub codes: Vec<u8>,
    /// Interned string constants.
    pub strings: Vec<String>,
}

impl ExprByteCode {
    /// Pair an instruction stream with its string constant table.
    pub fn new(codes: Vec<u8>, strings: Vec<String>) -> Self {
        ExprByteCode { strings, codes }
    }

    /// Borrow the instruction stream.
    pub fn codes(&self) -> &[u8] {
        self.codes.as_slice()
    }

    /// Borrow the string constant table.
    pub fn strings(&self) -> &[String] {
        self.strings.as_slice()
    }
}
286
287/// Compile an [`ast::Expr`] into [`ExprByteCode`]
288pub fn compile(expr: &ExprS, env: &CompileTimeEnv) -> ExprResult<ExprByteCode> {
289    let mut strings: Vec<String> = vec![];
290    let codes = compile_expr(expr, env, &mut strings)?;
291    Ok(ExprByteCode::new(codes, strings))
292}
293
294fn compile_expr(
295    (expr, span): &ExprS,
296    env: &CompileTimeEnv,
297    strings: &mut Vec<String>,
298) -> ExprResult<Vec<u8>> {
299    use opcode::*;
300
301    let mut codes = vec![];
302    let mut errs: Vec<ExprErrorS> = vec![];
303
304    match expr {
305        Expr::String(string) => {
306            if let Some(index) = strings.iter().position(|x| x == &string.0) {
307                codes.push(CONSTANT);
308                codes.push(index as u8);
309            } else {
310                strings.push(string.0.clone());
311                let index = strings.len() - 1;
312                codes.push(CONSTANT);
313                codes.push(index as u8);
314            }
315        }
316        Expr::Identifier(identifier) => {
317            let identifier_name = identifier.0.as_str();
318
319            if let Some((_, index)) = env.get_builtin_index(identifier_name) {
320                codes.push(GET);
321                codes.push(lookup::BUILTIN);
322                codes.push(index);
323            } else if let Some((_, index)) = env.get_user_builtin_index(identifier_name) {
324                codes.push(GET);
325                codes.push(lookup::USER_BUILTIN);
326                codes.push(index);
327            } else {
328                let identifier_prefix = &identifier_name[..1];
329                let identifier_suffix = &identifier_name[1..];
330
331                match identifier_prefix {
332                    "?" => {
333                        if let Some(index) = get(&env.prompts, identifier_suffix) {
334                            codes.push(GET);
335                            codes.push(lookup::PROMPT);
336                            codes.push(index);
337                        } else {
338                            errs.push((
339                                CompileError::Undefined(identifier_name.to_string()).into(),
340                                span.clone(),
341                            ));
342                        }
343                    }
344                    "!" => {
345                        if let Some(index) = get(&env.secrets, identifier_suffix) {
346                            codes.push(GET);
347                            codes.push(lookup::SECRET);
348                            codes.push(index);
349                        } else {
350                            errs.push((
351                                CompileError::Undefined(identifier_name.to_string()).into(),
352                                span.clone(),
353                            ));
354                        }
355                    }
356                    ":" => {
357                        if let Some(index) = get(&env.vars, identifier_suffix) {
358                            codes.push(GET);
359                            codes.push(lookup::VAR);
360                            codes.push(index);
361                        } else {
362                            errs.push((
363                                CompileError::Undefined(identifier_name.to_string()).into(),
364                                span.clone(),
365                            ));
366                        }
367                    }
368                    "@" => {
369                        if let Some(index) = get(&env.client_context, identifier_suffix) {
370                            codes.push(GET);
371                            codes.push(lookup::CLIENT_CTX);
372                            codes.push(index);
373                        } else {
374                            errs.push((
375                                CompileError::Undefined(identifier_name.to_string()).into(),
376                                span.clone(),
377                            ));
378                        }
379                    }
380                    _ => {
381                        errs.push((
382                            ExprError::CompileError(CompileError::Undefined(
383                                identifier_name.to_string(),
384                            )),
385                            span.clone(),
386                        ));
387                    }
388                };
389            }
390        }
391        Expr::Call(expr_call) => {
392            let identifier_name = expr_call.callee.0.identifier_name().unwrap_or_default();
393
394            match identifier_name {
395                "eq" => {
396                    if expr_call.args.is_empty() {
397                        errs.push((
398                            ExprError::CompileError(WrongNumberOfArgs {
399                                expected: 2,
400                                actual: 0,
401                            }),
402                            span.clone(),
403                        ));
404                    } else if expr_call.args.len() == 1 {
405                        errs.push((
406                            ExprError::CompileError(WrongNumberOfArgs {
407                                expected: 2,
408                                actual: 1,
409                            }),
410                            span.clone(),
411                        ));
412                    } else if expr_call.args.len() > 2 {
413                        errs.push((
414                            ExprError::CompileError(WrongNumberOfArgs {
415                                expected: 2,
416                                actual: expr_call.args.len(),
417                            }),
418                            span.clone(),
419                        ));
420                    } else {
421                        let arg = expr_call.args.first().expect("should have first argument");
422
423                        match compile_expr(arg, env, strings) {
424                            Ok(arg_bytecode) => {
425                                codes.extend(arg_bytecode);
426                            }
427                            Err(err) => {
428                                errs.extend(err);
429                            }
430                        }
431
432                        let arg2 = expr_call.args.get(1).expect("should have second argument");
433
434                        match compile_expr(arg2, env, strings) {
435                            Ok(arg_bytecode) => {
436                                codes.extend(arg_bytecode);
437                            }
438                            Err(err) => {
439                                errs.extend(err);
440                            }
441                        }
442
443                        codes.push(opcode::EQ);
444                    }
445                }
446                "not" => {
447                    if expr_call.args.is_empty() {
448                        errs.push((
449                            ExprError::CompileError(WrongNumberOfArgs {
450                                expected: 1,
451                                actual: 0,
452                            }),
453                            span.clone(),
454                        ));
455                    } else if expr_call.args.len() > 1 {
456                        errs.push((
457                            ExprError::CompileError(WrongNumberOfArgs {
458                                expected: 1,
459                                actual: expr_call.args.len(),
460                            }),
461                            span.clone(),
462                        ));
463
464                        let arg = expr_call.args.first().expect("should have first argument");
465
466                        if !arg.0.is_bool() {
467                            errs.push((
468                                CompileError::TypeMismatch {
469                                    expected: Type::Bool,
470                                    actual: arg.0.get_type(),
471                                }
472                                .into(),
473                                arg.1.clone(),
474                            ));
475                        }
476                    } else {
477                        let arg = expr_call.args.first().expect("should have first argument");
478                        if !arg.0.is_bool() {
479                            errs.push((
480                                CompileError::TypeMismatch {
481                                    expected: Type::Bool,
482                                    actual: arg.0.get_type(),
483                                }
484                                .into(),
485                                arg.1.clone(),
486                            ));
487                        }
488
489                        match compile_expr(arg, env, strings) {
490                            Ok(arg_bytecode) => {
491                                codes.extend(arg_bytecode);
492                            }
493                            Err(err) => {
494                                errs.extend(err);
495                            }
496                        }
497
498                        codes.push(opcode::NOT);
499                    }
500                }
501                _ => {
502                    let callee_bytecode = compile_expr(&expr_call.callee, env, strings)?;
503
504                    if let Some(_op) = callee_bytecode.first() {
505                        if let Some(lookup) = callee_bytecode.get(1) {
506                            if let Some(index) = callee_bytecode.get(2) {
507                                match *lookup {
508                                    lookup::BUILTIN => {
509                                        let builtin = env.get_builtin((*index).into()).unwrap();
510
511                                        let call_arity: usize = expr_call.args.len();
512
513                                        if !builtin.arity_matches(call_arity.try_into().unwrap()) {
514                                            errs.push((
515                                                ExprError::CompileError(WrongNumberOfArgs {
516                                                    expected: builtin.arity() as usize,
517                                                    actual: call_arity,
518                                                }),
519                                                span.clone(),
520                                            ));
521                                        }
522                                    }
523                                    lookup::USER_BUILTIN => {
524                                        let builtin =
525                                            env.get_user_builtin((*index).into()).unwrap();
526
527                                        let call_arity: usize = expr_call.args.len();
528
529                                        if !builtin.arity_matches(call_arity.try_into().unwrap()) {
530                                            errs.push((
531                                                ExprError::CompileError(WrongNumberOfArgs {
532                                                    expected: builtin.arity() as usize,
533                                                    actual: call_arity,
534                                                }),
535                                                span.clone(),
536                                            ));
537                                        }
538                                    }
539                                    _ => {}
540                                }
541                            }
542                        }
543                    }
544
545                    codes.extend(callee_bytecode);
546
547                    for arg in expr_call.args.iter() {
548                        match compile_expr(arg, env, strings) {
549                            Ok(arg_bytecode) => {
550                                codes.extend(arg_bytecode);
551                            }
552                            Err(err) => {
553                                errs.extend(err);
554                            }
555                        }
556                    }
557
558                    codes.push(opcode::CALL);
559                    codes.push(expr_call.args.len() as u8);
560                }
561            }
562        }
563        Expr::Bool(value) => match value.0 {
564            true => {
565                codes.push(opcode::TRUE);
566            }
567            false => {
568                codes.push(opcode::FALSE);
569            }
570        },
571        Expr::Error => panic!("tried to compile despite parser errors"),
572    }
573
574    if !errs.is_empty() {
575        return Err(errs);
576    }
577
578    Ok(codes)
579}