reqlang_expr/
compiler.rs

1//! The compiler and associated types
2
3use crate::{
4    ast::{Expr, ExprS, IdentifierKind, add_type_to_expr},
5    builtins::BuiltinFn,
6    errors::{
7        CompileError::{self, WrongNumberOfArgs},
8        ExprError, ExprErrorS, ExprResult,
9    },
10    prelude::lookup::TYPE,
11    types::Type,
12};
13
/// Instruction opcodes emitted by the compiler.
///
/// The numeric values are written out explicitly because they are part of the
/// encoded bytecode format.
pub mod opcode {
    /// Function call. Emitted after the callee and argument code, followed by one
    /// operand: the number of arguments.
    pub const CALL: u8 = 0;
    /// Lookup. Followed by two operands: a lookup kind (see [`super::lookup`]) and
    /// an index into the table of that kind.
    pub const GET: u8 = 1;
    /// String constant. Followed by one operand: the index in the string pool.
    pub const CONSTANT: u8 = 2;
    /// Boolean `true` literal; takes no operands.
    pub const TRUE: u8 = 3;
    /// Boolean `false` literal; takes no operands.
    pub const FALSE: u8 = 4;
}
24
/// Kinds of lookup encoded as the first operand of the `GET` opcode.
///
/// At compile time these tag which table a lookup index refers to; at runtime
/// they select the table used to resolve that index back to a value. The
/// numeric values are part of the encoded bytecode format.
pub mod lookup {
    /// Index into the default builtin functions.
    pub const BUILTIN: u8 = 0;
    /// Index into the variable names.
    pub const VAR: u8 = 1;
    /// Index into the prompt names.
    pub const PROMPT: u8 = 2;
    /// Index into the secret names.
    pub const SECRET: u8 = 3;
    /// Index into the user-registered builtin functions.
    pub const USER_BUILTIN: u8 = 4;
    /// Index into the client context keys.
    pub const CLIENT_CTX: u8 = 5;
    /// Index into the bytecode's type pool.
    pub const TYPE: u8 = 6;
}
42
/// Find the position of `identifier` in `list`, as a one-byte bytecode operand.
///
/// Returns `None` when `identifier` is absent, or when its position does not fit
/// in a `u8` (position 256 and beyond). Truncating that position would instead
/// produce an operand that silently refers to a different entry.
fn get(list: &[String], identifier: &str) -> Option<u8> {
    list.iter()
        .position(|entry| entry == identifier)
        .and_then(|index| u8::try_from(index).ok())
}
47
48#[derive(Debug)]
49pub struct CompileTimeEnv {
50    builtins: Vec<BuiltinFn<'static>>,
51    user_builtins: Vec<BuiltinFn<'static>>,
52    vars: Vec<String>,
53    prompts: Vec<String>,
54    secrets: Vec<String>,
55    client_context: Vec<String>,
56}
57
58impl Default for CompileTimeEnv {
59    fn default() -> Self {
60        Self {
61            builtins: BuiltinFn::DEFAULT_BUILTINS.to_vec(),
62            user_builtins: vec![],
63            vars: vec![],
64            prompts: vec![],
65            secrets: vec![],
66            client_context: vec![],
67        }
68    }
69}
70
71impl CompileTimeEnv {
72    pub fn new(
73        vars: Vec<String>,
74        prompts: Vec<String>,
75        secrets: Vec<String>,
76        client_context: Vec<String>,
77    ) -> Self {
78        Self {
79            vars,
80            prompts,
81            secrets,
82            client_context,
83            ..Default::default()
84        }
85    }
86
87    pub fn get_builtin_index(&self, name: &str) -> Option<(&BuiltinFn, u8)> {
88        let index = self.builtins.iter().position(|x| x.name == name);
89
90        let result = index.map(|i| (self.builtins.get(i).unwrap(), i as u8));
91        result
92    }
93
94    pub fn get_user_builtin_index(&self, name: &str) -> Option<(&BuiltinFn, u8)> {
95        let index = self.user_builtins.iter().position(|x| x.name == name);
96
97        let result = index.map(|i| (self.user_builtins.get(i).unwrap(), i as u8));
98        result
99    }
100
101    pub fn add_user_builtins(&mut self, builtins: Vec<BuiltinFn<'static>>) {
102        for builtin in builtins {
103            self.add_user_builtin(builtin);
104        }
105    }
106
107    pub fn add_user_builtin(&mut self, builtin: BuiltinFn<'static>) {
108        self.user_builtins.push(builtin);
109    }
110
111    pub fn get_builtin(&self, index: usize) -> Option<&BuiltinFn<'static>> {
112        self.builtins.get(index)
113    }
114
115    pub fn get_user_builtin(&self, index: usize) -> Option<&BuiltinFn<'static>> {
116        self.user_builtins.get(index)
117    }
118
119    pub fn get_var(&self, index: usize) -> Option<&String> {
120        self.vars.get(index)
121    }
122
123    pub fn get_var_index(&self, name: &str) -> Option<usize> {
124        let index = self
125            .vars
126            .iter()
127            .position(|context_name| context_name == name);
128
129        index
130    }
131
132    pub fn get_prompt(&self, index: usize) -> Option<&String> {
133        self.prompts.get(index)
134    }
135
136    pub fn get_prompt_index(&self, name: &str) -> Option<usize> {
137        let index = self
138            .prompts
139            .iter()
140            .position(|context_name| context_name == name);
141
142        index
143    }
144
145    pub fn get_secret(&self, index: usize) -> Option<&String> {
146        self.secrets.get(index)
147    }
148
149    pub fn get_secret_index(&self, name: &str) -> Option<usize> {
150        let index = self
151            .secrets
152            .iter()
153            .position(|context_name| context_name == name);
154
155        index
156    }
157
158    pub fn get_client_context(&self, index: usize) -> Option<&String> {
159        self.client_context.get(index)
160    }
161
162    pub fn add_to_client_context(&mut self, key: &str) -> usize {
163        match self.client_context.iter().position(|x| x == key) {
164            Some(i) => i,
165            None => {
166                self.client_context.push(key.to_string());
167
168                self.client_context.len() - 1
169            }
170        }
171    }
172
173    pub fn add_keys_to_client_context(&mut self, keys: Vec<String>) {
174        self.client_context.extend(keys);
175    }
176
177    pub fn get_client_context_index(&self, name: &str) -> Option<(&String, u8)> {
178        let index = self
179            .client_context
180            .iter()
181            .position(|context_name| context_name == name);
182
183        let result = index.map(|i| (self.client_context.get(i).unwrap(), i as u8));
184        result
185    }
186}
187
188/// The compiled bytecode for an expression
189#[derive(Debug, Clone, PartialEq)]
190pub struct ExprByteCode {
191    version: [u8; 4],
192    codes: Vec<u8>,
193    strings: Vec<String>,
194    types: Vec<Type>,
195}
196
197impl ExprByteCode {
198    pub fn new(codes: Vec<u8>, strings: Vec<String>, types: Vec<Type>) -> Self {
199        let version_bytes = get_version_bytes();
200        let version_bytes_from_codes = &codes[0..4];
201
202        assert_eq!(
203            version_bytes, version_bytes_from_codes,
204            "Version bytes do not match"
205        );
206
207        let codes = codes[4..].to_vec();
208
209        Self {
210            version: version_bytes,
211            codes,
212            strings,
213            types,
214        }
215    }
216
217    pub fn version(&self) -> &[u8; 4] {
218        &self.version
219    }
220
221    pub fn codes(&self) -> &[u8] {
222        &self.codes
223    }
224
225    pub fn get_code(&self, index: usize) -> Option<&u8> {
226        self.codes.get(index)
227    }
228
229    pub fn strings(&self) -> &[String] {
230        &self.strings
231    }
232
233    pub fn types(&self) -> &[Type] {
234        &self.types
235    }
236}
237
238pub fn get_version_bytes() -> [u8; 4] {
239    [
240        env!("CARGO_PKG_VERSION_MAJOR").parse().unwrap(),
241        env!("CARGO_PKG_VERSION_MINOR").parse().unwrap(),
242        env!("CARGO_PKG_VERSION_PATCH").parse().unwrap(),
243        0,
244    ]
245}
246
247/// Compile an [`ast::Expr`] into [`ExprByteCode`]
248pub fn compile(expr: &mut ExprS, env: &CompileTimeEnv) -> ExprResult<ExprByteCode> {
249    let mut strings: Vec<String> = vec![];
250    let mut types: Vec<Type> = vec![];
251    let mut codes = vec![];
252
253    codes.extend(get_version_bytes());
254
255    codes.extend(compile_expr(expr, env, &mut strings, &mut types)?);
256
257    Ok(ExprByteCode::new(codes, strings, types))
258}
259
260fn compile_expr(
261    (expr, span): &mut ExprS,
262    env: &CompileTimeEnv,
263    strings: &mut Vec<String>,
264    types: &mut Vec<Type>,
265) -> ExprResult<Vec<u8>> {
266    use opcode::*;
267
268    let mut codes = vec![];
269    let mut errs: Vec<ExprErrorS> = vec![];
270
271    add_type_to_expr(expr, env);
272
273    match expr {
274        Expr::String(string) => {
275            if let Some(index) = strings.iter().position(|x| x == &string.0) {
276                codes.push(CONSTANT);
277                codes.push(index as u8);
278            } else {
279                strings.push(string.0.clone());
280                let index = strings.len() - 1;
281                codes.push(CONSTANT);
282                codes.push(index as u8);
283            }
284        }
285        Expr::Identifier(identifier) => {
286            let identifier_lookup_name = identifier.lookup_name();
287            let identifier_name = identifier.full_name().to_string();
288
289            let identifier_undefined_err = (
290                CompileError::Undefined(identifier_name.clone()).into(),
291                span.clone(),
292            );
293
294            let result = match identifier.identifier_kind() {
295                IdentifierKind::Var => get(&env.vars, identifier_lookup_name).map(|index| {
296                    codes.push(GET);
297                    codes.push(lookup::VAR);
298                    codes.push(index);
299                }),
300                IdentifierKind::Prompt => get(&env.prompts, identifier_lookup_name).map(|index| {
301                    codes.push(GET);
302                    codes.push(lookup::PROMPT);
303                    codes.push(index);
304                }),
305                IdentifierKind::Secret => get(&env.secrets, identifier_lookup_name).map(|index| {
306                    codes.push(GET);
307                    codes.push(lookup::SECRET);
308                    codes.push(index);
309                }),
310                IdentifierKind::Client => {
311                    get(&env.client_context, identifier_lookup_name).map(|index| {
312                        codes.push(GET);
313                        codes.push(lookup::CLIENT_CTX);
314                        codes.push(index);
315                    })
316                }
317                IdentifierKind::Builtin => {
318                    if let Some((_, index)) = env.get_builtin_index(identifier_lookup_name) {
319                        codes.push(GET);
320                        codes.push(lookup::BUILTIN);
321                        codes.push(index);
322
323                        Some(())
324                    } else if let Some((_, index)) =
325                        env.get_user_builtin_index(identifier_lookup_name)
326                    {
327                        codes.push(GET);
328                        codes.push(lookup::USER_BUILTIN);
329                        codes.push(index);
330
331                        Some(())
332                    } else {
333                        None
334                    }
335                }
336                IdentifierKind::Type => {
337                    let ty = Type::from(&identifier_name);
338                    if let Some(index) = types.iter().position(|x| x == &ty) {
339                        codes.push(GET);
340                        codes.push(TYPE);
341                        codes.push(index as u8);
342                    } else {
343                        types.push(ty);
344                        let index = types.len() - 1;
345                        codes.push(GET);
346                        codes.push(TYPE);
347                        codes.push(index as u8);
348                    }
349
350                    Some(())
351                }
352            };
353
354            if let None = result {
355                errs.push(identifier_undefined_err);
356            }
357        }
358        Expr::Call(expr_call) => {
359            let callee_bytecode = compile_expr(&mut expr_call.callee, env, strings, types)?;
360
361            if let Some(_op) = callee_bytecode.first()
362                && let Some(lookup) = callee_bytecode.get(1)
363                && let Some(index) = callee_bytecode.get(2)
364            {
365                match *lookup {
366                    lookup::BUILTIN => {
367                        let builtin = env.get_builtin((*index).into()).unwrap();
368
369                        let call_arity: usize = expr_call.args.len();
370
371                        if !builtin.arity_matches(call_arity.try_into().unwrap()) {
372                            errs.push((
373                                ExprError::CompileError(WrongNumberOfArgs {
374                                    expected: builtin.arity() as usize,
375                                    actual: call_arity,
376                                }),
377                                span.clone(),
378                            ));
379                        }
380
381                        let args: Vec<_> = expr_call.args.iter().take(call_arity).collect();
382
383                        for (i, fnarg) in builtin.args.iter().enumerate() {
384                            if let Some((a, a_span)) = args.get(i) {
385                                let a_type = a.get_type();
386
387                                let types_match = fnarg.ty == a_type
388                                    || fnarg.ty == Type::Value
389                                    || a_type == Type::Unknown;
390
391                                if !types_match {
392                                    errs.push((
393                                        CompileError::TypeMismatch {
394                                            expected: fnarg.ty.clone(),
395                                            actual: a_type.clone(),
396                                        }
397                                        .into(),
398                                        a_span.clone(),
399                                    ));
400                                }
401                            }
402                        }
403                    }
404                    lookup::USER_BUILTIN => {
405                        let builtin = env.get_user_builtin((*index).into()).unwrap();
406
407                        let call_arity: usize = expr_call.args.len();
408
409                        if !builtin.arity_matches(call_arity.try_into().unwrap()) {
410                            errs.push((
411                                ExprError::CompileError(WrongNumberOfArgs {
412                                    expected: builtin.arity() as usize,
413                                    actual: call_arity,
414                                }),
415                                span.clone(),
416                            ));
417                        }
418                    }
419                    lookup::CLIENT_CTX => {
420                        // No validation needs to be ran at this point
421                        // This won't happen until runtime when the client
422                        // a value.
423                    }
424                    _ => {
425                        errs.push((
426                            CompileError::InvalidLookupType(*lookup).into(),
427                            span.clone(),
428                        ));
429                    }
430                }
431            }
432
433            codes.extend(callee_bytecode);
434
435            for arg in expr_call.args.iter_mut() {
436                match compile_expr(arg, env, strings, types) {
437                    Ok(arg_bytecode) => {
438                        codes.extend(arg_bytecode);
439                    }
440                    Err(err) => {
441                        errs.extend(err);
442                    }
443                }
444            }
445
446            codes.push(opcode::CALL);
447            codes.push(expr_call.args.len() as u8);
448        }
449        Expr::Bool(value) => match value.0 {
450            true => {
451                codes.push(opcode::TRUE);
452            }
453            false => {
454                codes.push(opcode::FALSE);
455            }
456        },
457        Expr::Error => panic!("tried to compile despite parser errors"),
458    }
459
460    if !errs.is_empty() {
461        return Err(errs);
462    }
463
464    Ok(codes)
465}
466
#[cfg(test)]
mod compiler_tests {
    use super::*;

    /// Raw compiler output: the current version header followed by a single `TRUE`.
    fn versioned_true() -> Vec<u8> {
        let mut codes = get_version_bytes().to_vec();
        codes.push(opcode::TRUE);
        codes
    }

    #[test]
    fn current_version_bytes() {
        assert_eq!(get_version_bytes(), [0, 7, 0, 0]);
    }

    #[test]
    fn valid_bytecode_version_bytes() {
        ExprByteCode::new(versioned_true(), vec![], vec![]);
    }

    #[test]
    #[should_panic(expected = "Version bytes do not match")]
    fn invalid_bytecode_version_bytes() {
        let codes = vec![0, 0, 0, 0, opcode::TRUE];

        ExprByteCode::new(codes, vec![], vec![]);
    }

    #[test]
    fn get_version_bytes_from_bytecode() {
        let bytecode = ExprByteCode::new(versioned_true(), vec![], vec![]);

        assert_eq!(bytecode.version(), &get_version_bytes());
    }
}