reqlang_expr/
compiler.rs

1//! The compiler and associated types
2
3use std::{ops::Range, rc::Rc};
4
5use crate::{
6    ast::Expr,
7    builtins::{BuiltinFn, BuiltinFns},
8    errors::{
9        CompileError::{self, WrongNumberOfArgs},
10        ExprError, ExprResult,
11    },
12    prelude::FnArg,
13    types::Type,
14};
15
/// Opcodes emitted by the compiler.
///
/// The numeric values are part of the bytecode format and must stay stable.
pub mod opcode {
    /// Call a function; followed by a one-byte argument count.
    pub const CALL: u8 = 0;
    /// Fetch a value; followed by a lookup kind (see `lookup`) and an index.
    pub const GET: u8 = 1;
    /// Reference a string constant; followed by its index in the string table.
    pub const CONSTANT: u8 = 2;
    /// Emitted for the boolean literal `true`.
    pub const TRUE: u8 = 3;
    /// Emitted for the boolean literal `false`.
    pub const FALSE: u8 = 4;
}
26
/// Types of lookups for the GET op code
///
/// Used at compile time to encode lookup indexes
///
/// Used at runtime to use lookup indexes to reference runtime values
pub mod lookup {
    /// A standard builtin function.
    pub const BUILTIN: u8 = 0;
    /// A variable, written `:name` in expressions.
    pub const VAR: u8 = 1;
    /// A prompt, written `?name` in expressions.
    pub const PROMPT: u8 = 2;
    /// A secret, written `!name` in expressions.
    pub const SECRET: u8 = 3;
    /// A builtin added to the environment after construction.
    pub const USER_BUILTIN: u8 = 4;
    /// A client context key, written `@name` in expressions.
    pub const CLIENT_CTX: u8 = 5;
}
43
/// Find the position of `identifier` in `list` as a one-byte bytecode operand.
///
/// Returns `None` when `identifier` is not in `list`, or when its position is
/// past 255 and therefore cannot be encoded in a single byte (rather than
/// silently wrapping around to the index of a different entry).
fn get(list: &[String], identifier: &str) -> Option<u8> {
    list.iter()
        .position(|x| x == identifier)
        .and_then(|i| u8::try_from(i).ok())
}
48
49#[derive(Debug)]
50pub struct CompileTimeEnv {
51    builtins: Vec<Rc<BuiltinFn>>,
52    user_builtins: Vec<Rc<BuiltinFn>>,
53    vars: Vec<String>,
54    prompts: Vec<String>,
55    secrets: Vec<String>,
56    client_context: Vec<String>,
57}
58
59impl Default for CompileTimeEnv {
60    fn default() -> Self {
61        Self {
62            builtins: vec![
63                Rc::new(BuiltinFn {
64                    name: String::from("id"),
65                    args: vec![FnArg::new("value", Type::Value)],
66                    return_type: Type::Value,
67                    func: Rc::new(BuiltinFns::id),
68                }),
69                Rc::new(BuiltinFn {
70                    name: String::from("noop"),
71                    args: vec![],
72                    return_type: Type::String,
73                    func: Rc::new(BuiltinFns::noop),
74                }),
75                Rc::new(BuiltinFn {
76                    name: String::from("is_empty"),
77                    args: vec![FnArg::new("value", Type::String)],
78                    return_type: Type::String,
79                    func: Rc::new(BuiltinFns::is_empty),
80                }),
81                Rc::new(BuiltinFn {
82                    name: String::from("not"),
83                    args: vec![FnArg::new("value", Type::Bool)],
84                    return_type: Type::Bool,
85                    func: Rc::new(BuiltinFns::not),
86                }),
87                Rc::new(BuiltinFn {
88                    name: String::from("and"),
89                    args: vec![FnArg::new("a", Type::Bool), FnArg::new("b", Type::Bool)],
90                    return_type: Type::Bool,
91                    func: Rc::new(BuiltinFns::and),
92                }),
93                Rc::new(BuiltinFn {
94                    name: String::from("or"),
95                    args: vec![FnArg::new("a", Type::Bool), FnArg::new("b", Type::Bool)],
96                    return_type: Type::Bool,
97                    func: Rc::new(BuiltinFns::or),
98                }),
99                Rc::new(BuiltinFn {
100                    name: String::from("cond"),
101                    args: vec![
102                        FnArg::new("cond", Type::Bool),
103                        FnArg::new("then", Type::Value),
104                        FnArg::new("else", Type::Value),
105                    ],
106                    return_type: Type::Bool,
107                    func: Rc::new(BuiltinFns::cond),
108                }),
109                Rc::new(BuiltinFn {
110                    name: String::from("to_str"),
111                    args: vec![FnArg::new("value", Type::Value)],
112                    return_type: Type::String,
113                    func: Rc::new(BuiltinFns::to_str),
114                }),
115                Rc::new(BuiltinFn {
116                    name: String::from("concat"),
117                    args: vec![
118                        FnArg::new("a", Type::String),
119                        FnArg::new("b", Type::String),
120                        FnArg::new_varadic("rest", Type::String),
121                    ],
122                    return_type: Type::String,
123                    func: Rc::new(BuiltinFns::concat),
124                }),
125                Rc::new(BuiltinFn {
126                    name: String::from("contains"),
127                    args: vec![
128                        FnArg::new("needle", Type::String),
129                        FnArg::new("haystack", Type::String),
130                    ],
131                    return_type: Type::Bool,
132                    func: Rc::new(BuiltinFns::contains),
133                }),
134                Rc::new(BuiltinFn {
135                    name: String::from("trim"),
136                    args: vec![FnArg::new("value", Type::String)],
137                    return_type: Type::String,
138                    func: Rc::new(BuiltinFns::trim),
139                }),
140                Rc::new(BuiltinFn {
141                    name: String::from("trim_start"),
142                    args: vec![FnArg::new("value", Type::String)],
143                    return_type: Type::String,
144                    func: Rc::new(BuiltinFns::trim_start),
145                }),
146                Rc::new(BuiltinFn {
147                    name: String::from("trim_end"),
148                    args: vec![FnArg::new("value", Type::String)],
149                    return_type: Type::String,
150                    func: Rc::new(BuiltinFns::trim_end),
151                }),
152                Rc::new(BuiltinFn {
153                    name: String::from("lowercase"),
154                    args: vec![FnArg::new("value", Type::String)],
155                    return_type: Type::String,
156                    func: Rc::new(BuiltinFns::lowercase),
157                }),
158                Rc::new(BuiltinFn {
159                    name: String::from("uppercase"),
160                    args: vec![FnArg::new("value", Type::String)],
161                    return_type: Type::String,
162                    func: Rc::new(BuiltinFns::uppercase),
163                }),
164                Rc::new(BuiltinFn {
165                    name: String::from("eq"),
166                    args: vec![FnArg::new("a", Type::Value), FnArg::new("b", Type::Value)],
167                    return_type: Type::Bool,
168                    func: Rc::new(BuiltinFns::eq),
169                }),
170                Rc::new(BuiltinFn {
171                    name: String::from("type"),
172                    args: vec![FnArg::new("value", Type::Value)],
173                    return_type: Type::String,
174                    func: Rc::new(BuiltinFns::get_type),
175                }),
176            ],
177            user_builtins: vec![],
178            vars: Vec::new(),
179            prompts: Vec::new(),
180            secrets: Vec::new(),
181            client_context: Vec::new(),
182        }
183    }
184}
185
186impl CompileTimeEnv {
187    pub fn new(
188        vars: Vec<String>,
189        prompts: Vec<String>,
190        secrets: Vec<String>,
191        client_context: Vec<String>,
192    ) -> Self {
193        Self {
194            vars,
195            prompts,
196            secrets,
197            client_context,
198            ..Default::default()
199        }
200    }
201
202    pub fn get_builtin_index(&self, name: &str) -> Option<(&Rc<BuiltinFn>, u8)> {
203        let index = self.builtins.iter().position(|x| x.name == name);
204
205        let result = index.map(|i| (self.builtins.get(i).unwrap(), i as u8));
206        result
207    }
208
209    pub fn get_user_builtin_index(&self, name: &str) -> Option<(&Rc<BuiltinFn>, u8)> {
210        let index = self.user_builtins.iter().position(|x| x.name == name);
211
212        let result = index.map(|i| (self.user_builtins.get(i).unwrap(), i as u8));
213        result
214    }
215
216    pub fn add_user_builtins(&mut self, builtins: Vec<Rc<BuiltinFn>>) {
217        for builtin in builtins {
218            self.add_user_builtin(builtin);
219        }
220    }
221
222    pub fn add_user_builtin(&mut self, builtin: Rc<BuiltinFn>) {
223        self.user_builtins.push(builtin);
224    }
225
226    pub fn get_builtin(&self, index: usize) -> Option<&Rc<BuiltinFn>> {
227        self.builtins.get(index)
228    }
229
230    pub fn get_user_builtin(&self, index: usize) -> Option<&Rc<BuiltinFn>> {
231        self.user_builtins.get(index)
232    }
233
234    pub fn get_var(&self, index: usize) -> Option<&String> {
235        self.vars.get(index)
236    }
237
238    pub fn get_prompt(&self, index: usize) -> Option<&String> {
239        self.prompts.get(index)
240    }
241
242    pub fn get_secret(&self, index: usize) -> Option<&String> {
243        self.secrets.get(index)
244    }
245
246    pub fn get_client_context(&self, index: usize) -> Option<&String> {
247        self.client_context.get(index)
248    }
249
250    pub fn add_to_client_context(&mut self, key: &str) -> usize {
251        match self.client_context.iter().position(|x| x == key) {
252            Some(i) => i,
253            None => {
254                self.client_context.push(key.to_string());
255
256                self.client_context.len() - 1
257            }
258        }
259    }
260
261    pub fn add_keys_to_client_context(&mut self, keys: Vec<String>) {
262        self.client_context.extend(keys);
263    }
264
265    pub fn get_client_context_index(&self, name: &str) -> Option<(&String, u8)> {
266        let index = self
267            .client_context
268            .iter()
269            .position(|context_name| context_name == name);
270
271        let result = index.map(|i| (self.client_context.get(i).unwrap(), i as u8));
272        result
273    }
274}
275
/// The compiled bytecode for an expression.
///
/// `codes` is the instruction stream; `strings` is the table of string
/// constants that `CONSTANT` instructions refer to by index.
#[derive(Debug, Clone, PartialEq)]
pub struct ExprByteCode {
    /// Opcodes interleaved with their one-byte operands.
    pub codes: Vec<u8>,
    /// Interned string constants, indexed by `CONSTANT` operands.
    pub strings: Vec<String>,
}

impl ExprByteCode {
    /// Bundle an instruction stream with its string constant table.
    pub fn new(codes: Vec<u8>, strings: Vec<String>) -> Self {
        ExprByteCode { strings, codes }
    }

    /// Borrow the instruction stream.
    pub fn codes(&self) -> &[u8] {
        self.codes.as_slice()
    }

    /// Borrow the string constant table.
    pub fn strings(&self) -> &[String] {
        self.strings.as_slice()
    }
}
296
297/// Compile an [`ast::Expr`] into [`ExprByteCode`]
298pub fn compile(expr: &(Expr, Range<usize>), env: &CompileTimeEnv) -> ExprResult<ExprByteCode> {
299    let mut strings: Vec<String> = vec![];
300    let codes = compile_expr(expr, env, &mut strings)?;
301    Ok(ExprByteCode::new(codes, strings))
302}
303
304fn compile_expr(
305    (expr, span): &(Expr, Range<usize>),
306    env: &CompileTimeEnv,
307    strings: &mut Vec<String>,
308) -> ExprResult<Vec<u8>> {
309    use opcode::*;
310
311    let mut codes = vec![];
312    let mut errs: Vec<(ExprError, Range<usize>)> = vec![];
313
314    match expr {
315        Expr::String(string) => {
316            if let Some(index) = strings.iter().position(|x| x == &string.0) {
317                codes.push(CONSTANT);
318                codes.push(index as u8);
319            } else {
320                strings.push(string.0.clone());
321                let index = strings.len() - 1;
322                codes.push(CONSTANT);
323                codes.push(index as u8);
324            }
325        }
326        Expr::Identifier(identifier) => {
327            let identifier_name = identifier.0.as_str();
328
329            if let Some((_, index)) = env.get_builtin_index(identifier_name) {
330                codes.push(GET);
331                codes.push(lookup::BUILTIN);
332                codes.push(index);
333            } else if let Some((_, index)) = env.get_user_builtin_index(identifier_name) {
334                codes.push(GET);
335                codes.push(lookup::USER_BUILTIN);
336                codes.push(index);
337            } else {
338                let identifier_prefix = &identifier_name[..1];
339                let identifier_suffix = &identifier_name[1..];
340
341                match identifier_prefix {
342                    "?" => {
343                        if let Some(index) = get(&env.prompts, identifier_suffix) {
344                            codes.push(GET);
345                            codes.push(lookup::PROMPT);
346                            codes.push(index);
347                        }
348                    }
349                    "!" => {
350                        if let Some(index) = get(&env.secrets, identifier_suffix) {
351                            codes.push(GET);
352                            codes.push(lookup::SECRET);
353                            codes.push(index);
354                        }
355                    }
356                    ":" => {
357                        if let Some(index) = get(&env.vars, identifier_suffix) {
358                            codes.push(GET);
359                            codes.push(lookup::VAR);
360                            codes.push(index);
361                        }
362                    }
363                    "@" => {
364                        if let Some(index) = get(&env.client_context, identifier_suffix) {
365                            codes.push(GET);
366                            codes.push(lookup::CLIENT_CTX);
367                            codes.push(index);
368                        }
369                    }
370                    _ => {
371                        errs.push((
372                            ExprError::CompileError(CompileError::Undefined(
373                                identifier_name.to_string(),
374                            )),
375                            span.clone(),
376                        ));
377                    }
378                };
379            }
380        }
381        Expr::Call(expr_call) => {
382            let callee_bytecode = compile_expr(&expr_call.callee, env, strings)?;
383
384            if let Some(_op) = callee_bytecode.first() {
385                if let Some(lookup) = callee_bytecode.get(1) {
386                    if let Some(index) = callee_bytecode.get(2) {
387                        match *lookup {
388                            lookup::BUILTIN => {
389                                let builtin = env.get_builtin((*index).into()).unwrap();
390
391                                let call_arity: usize = expr_call.args.len();
392
393                                if !builtin.arity_matches(call_arity.try_into().unwrap()) {
394                                    errs.push((
395                                        ExprError::CompileError(WrongNumberOfArgs {
396                                            expected: builtin.arity() as usize,
397                                            actual: call_arity,
398                                        }),
399                                        span.clone(),
400                                    ));
401                                }
402                            }
403                            lookup::USER_BUILTIN => {
404                                let builtin = env.get_user_builtin((*index).into()).unwrap();
405
406                                let call_arity: usize = expr_call.args.len();
407
408                                if !builtin.arity_matches(call_arity.try_into().unwrap()) {
409                                    errs.push((
410                                        ExprError::CompileError(WrongNumberOfArgs {
411                                            expected: builtin.arity() as usize,
412                                            actual: call_arity,
413                                        }),
414                                        span.clone(),
415                                    ));
416                                }
417                            }
418                            _ => {}
419                        }
420                    }
421                }
422            }
423
424            codes.extend(callee_bytecode);
425
426            for arg in expr_call.args.iter() {
427                match compile_expr(arg, env, strings) {
428                    Ok(arg_bytecode) => {
429                        codes.extend(arg_bytecode);
430                    }
431                    Err(err) => {
432                        errs.extend(err);
433                    }
434                }
435            }
436
437            codes.push(opcode::CALL);
438            codes.push(expr_call.args.len() as u8);
439        }
440        Expr::Bool(value) => match value.0 {
441            true => {
442                codes.push(opcode::TRUE);
443            }
444            false => {
445                codes.push(opcode::FALSE);
446            }
447        },
448        Expr::Error => panic!("tried to compile despite parser errors"),
449    }
450
451    if !errs.is_empty() {
452        return Err(errs);
453    }
454
455    Ok(codes)
456}