Skip to main content

reqlang_expr/
compiler.rs

1//! The compiler and associated types
2
3use crate::{
4    ast::{Expr, ExprS, IdentifierKind, add_type_to_expr},
5    builtins::BuiltinFn,
6    errors::{
7        CompileError::{self, WrongNumberOfArgs},
8        ExprError, ExprErrorS, ExprResult,
9    },
10    prelude::lookup::TYPE,
11    types::Type,
12    value::Value,
13};
14
/// Op codes emitted by the compiler
///
/// The values are consecutive, starting at zero, in declaration order.
pub mod opcode {
    /// Call a value; followed by one byte holding the number of arguments
    pub const CALL: u8 = 0;
    /// Look a value up; followed by a lookup type byte and an index byte
    pub const GET: u8 = 1;
    /// Reference a constant; followed by one byte indexing the constants
    pub const CONSTANT: u8 = 2;
    /// The boolean `true`; takes no operands
    pub const TRUE: u8 = 3;
    /// The boolean `false`; takes no operands
    pub const FALSE: u8 = 4;
}
25
/// Types of lookups for the GET op code
///
/// At compile time these tag which list an encoded lookup index refers to.
///
/// At runtime the same tag selects where the index is resolved to a value.
///
/// The values are consecutive, starting at zero, in declaration order.
pub mod lookup {
    /// Index into the default builtin functions
    pub const BUILTIN: u8 = 0;
    /// Index into the variable names
    pub const VAR: u8 = 1;
    /// Index into the prompt names
    pub const PROMPT: u8 = 2;
    /// Index into the secret names
    pub const SECRET: u8 = 3;
    /// Index into the user-provided builtin functions
    pub const USER_BUILTIN: u8 = 4;
    /// Index into the client context names
    pub const CLIENT_CTX: u8 = 5;
    /// Index into the types collected during compilation
    pub const TYPE: u8 = 6;
}
43
/// Find the position of `identifier` in `list` as a single-byte index
///
/// Returns the index of the first entry equal to `identifier`.
///
/// Returns `None` when no entry matches, or when the matching position does
/// not fit in a `u8`. Bytecode operands are one byte wide, so casting a larger
/// position would silently wrap around and point at a different entry.
fn get(list: &[String], identifier: &str) -> Option<u8> {
    list.iter()
        .position(|candidate| candidate == identifier)
        .and_then(|index| u8::try_from(index).ok())
}
48
49#[derive(Debug)]
50pub struct CompileTimeEnv {
51    builtins: Vec<BuiltinFn<'static>>,
52    user_builtins: Vec<BuiltinFn<'static>>,
53    vars: Vec<String>,
54    prompts: Vec<String>,
55    secrets: Vec<String>,
56    client_context: Vec<String>,
57}
58
59impl Default for CompileTimeEnv {
60    fn default() -> Self {
61        Self {
62            builtins: BuiltinFn::DEFAULT_BUILTINS.to_vec(),
63            user_builtins: vec![],
64            vars: vec![],
65            prompts: vec![],
66            secrets: vec![],
67            client_context: vec![],
68        }
69    }
70}
71
72impl CompileTimeEnv {
73    pub fn new(
74        vars: Vec<String>,
75        prompts: Vec<String>,
76        secrets: Vec<String>,
77        client_context: Vec<String>,
78    ) -> Self {
79        Self {
80            vars,
81            prompts,
82            secrets,
83            client_context,
84            ..Default::default()
85        }
86    }
87
88    pub fn get_builtin_index(&self, name: &str) -> Option<(&BuiltinFn<'_>, u8)> {
89        let index = self.builtins.iter().position(|x| x.name == name);
90
91        index.map(|i| (self.builtins.get(i).unwrap(), i as u8))
92    }
93
94    pub fn get_user_builtin_index(&self, name: &str) -> Option<(&BuiltinFn<'_>, u8)> {
95        let index = self.user_builtins.iter().position(|x| x.name == name);
96
97        index.map(|i| (self.user_builtins.get(i).unwrap(), i as u8))
98    }
99
100    pub fn add_user_builtins(&mut self, builtins: Vec<BuiltinFn<'static>>) {
101        for builtin in builtins {
102            self.add_user_builtin(builtin);
103        }
104    }
105
106    pub fn add_user_builtin(&mut self, builtin: BuiltinFn<'static>) {
107        self.user_builtins.push(builtin);
108    }
109
110    pub fn get_builtin(&self, index: usize) -> Option<&BuiltinFn<'static>> {
111        self.builtins.get(index)
112    }
113
114    pub fn get_user_builtin(&self, index: usize) -> Option<&BuiltinFn<'static>> {
115        self.user_builtins.get(index)
116    }
117
118    pub fn get_var(&self, index: usize) -> Option<&String> {
119        self.vars.get(index)
120    }
121
122    pub fn get_var_index(&self, name: &str) -> Option<usize> {
123        self.vars
124            .iter()
125            .position(|context_name| context_name == name)
126    }
127
128    pub fn get_prompt(&self, index: usize) -> Option<&String> {
129        self.prompts.get(index)
130    }
131
132    pub fn get_prompt_index(&self, name: &str) -> Option<usize> {
133        self.prompts
134            .iter()
135            .position(|context_name| context_name == name)
136    }
137
138    pub fn get_secret(&self, index: usize) -> Option<&String> {
139        self.secrets.get(index)
140    }
141
142    pub fn get_secret_index(&self, name: &str) -> Option<usize> {
143        self.secrets
144            .iter()
145            .position(|context_name| context_name == name)
146    }
147
148    pub fn get_client_context(&self, index: usize) -> Option<&String> {
149        self.client_context.get(index)
150    }
151
152    pub fn add_to_client_context(&mut self, key: &str) -> usize {
153        match self.client_context.iter().position(|x| x == key) {
154            Some(i) => i,
155            None => {
156                self.client_context.push(key.to_string());
157
158                self.client_context.len() - 1
159            }
160        }
161    }
162
163    pub fn get_client_context_index(&self, name: &str) -> Option<(&String, u8)> {
164        let index = self
165            .client_context
166            .iter()
167            .position(|context_name| context_name == name);
168
169        index.map(|i| (self.client_context.get(i).unwrap(), i as u8))
170    }
171}
172
173/// The compiled bytecode for an expression
174#[derive(Debug, Clone, PartialEq)]
175pub struct ExprByteCode {
176    version: [u8; 4],
177    codes: Vec<u8>,
178    constants: Vec<Value>,
179    types: Vec<Type>,
180}
181
182impl ExprByteCode {
183    pub fn new(codes: Vec<u8>, constants: Vec<Value>, types: Vec<Type>) -> Self {
184        let version_bytes = get_version_bytes();
185        let version_bytes_from_codes = &codes[0..4];
186
187        assert_eq!(
188            version_bytes, version_bytes_from_codes,
189            "Version bytes do not match"
190        );
191
192        let codes = codes[4..].to_vec();
193
194        Self {
195            version: version_bytes,
196            codes,
197            constants,
198            types,
199        }
200    }
201
202    pub fn version(&self) -> &[u8; 4] {
203        &self.version
204    }
205
206    pub fn codes(&self) -> &[u8] {
207        &self.codes
208    }
209
210    pub fn constants(&self) -> &[Value] {
211        &self.constants
212    }
213
214    pub fn types(&self) -> &[Type] {
215        &self.types
216    }
217}
218
219pub fn get_version_bytes() -> [u8; 4] {
220    [
221        env!("CARGO_PKG_VERSION_MAJOR").parse().unwrap(),
222        env!("CARGO_PKG_VERSION_MINOR").parse().unwrap(),
223        env!("CARGO_PKG_VERSION_PATCH").parse().unwrap(),
224        0,
225    ]
226}
227
228/// Compile an [`ast::Expr`] into [`ExprByteCode`]
229pub fn compile(expr: &mut ExprS, env: &CompileTimeEnv) -> ExprResult<ExprByteCode> {
230    let mut constants: Vec<Value> = vec![];
231    let mut types: Vec<Type> = vec![];
232    let mut codes = vec![];
233
234    codes.extend(get_version_bytes());
235
236    codes.extend(compile_expr(expr, env, &mut constants, &mut types)?);
237
238    Ok(ExprByteCode::new(codes, constants, types))
239}
240
241fn compile_expr(
242    (expr, span): &mut ExprS,
243    env: &CompileTimeEnv,
244    constants: &mut Vec<Value>,
245    types: &mut Vec<Type>,
246) -> ExprResult<Vec<u8>> {
247    use opcode::*;
248
249    let mut codes = vec![];
250    let mut errs: Vec<ExprErrorS> = vec![];
251
252    add_type_to_expr(expr, env);
253
254    match expr {
255        Expr::String(string) => {
256            if let Some(index) = constants.iter().position(|x| {
257                if let Value::String(string_constant) = x {
258                    string_constant == &string.0
259                } else {
260                    false
261                }
262            }) {
263                codes.push(CONSTANT);
264                codes.push(index as u8);
265            } else {
266                constants.push(Value::String(string.0.clone()));
267                let index = constants.len() - 1;
268                codes.push(CONSTANT);
269                codes.push(index as u8);
270            }
271        }
272        Expr::Number(number) => {
273            if let Some(index) = constants.iter().position(|x| {
274                if let Value::Number(value) = x {
275                    value == &number.0
276                } else {
277                    false
278                }
279            }) {
280                codes.push(CONSTANT);
281                codes.push(index as u8);
282            } else {
283                constants.push(Value::Number(number.0));
284                let index = constants.len() - 1;
285                codes.push(CONSTANT);
286                codes.push(index as u8);
287            }
288        }
289        Expr::Identifier(identifier) => {
290            let identifier_lookup_name = identifier.lookup_name();
291            let identifier_name = identifier.full_name().to_string();
292
293            let identifier_undefined_err = (
294                CompileError::Undefined(identifier_name.clone()).into(),
295                span.clone(),
296            );
297
298            let result = match identifier.identifier_kind() {
299                IdentifierKind::Var => get(&env.vars, identifier_lookup_name).map(|index| {
300                    codes.push(GET);
301                    codes.push(lookup::VAR);
302                    codes.push(index);
303                }),
304                IdentifierKind::Prompt => get(&env.prompts, identifier_lookup_name).map(|index| {
305                    codes.push(GET);
306                    codes.push(lookup::PROMPT);
307                    codes.push(index);
308                }),
309                IdentifierKind::Secret => get(&env.secrets, identifier_lookup_name).map(|index| {
310                    codes.push(GET);
311                    codes.push(lookup::SECRET);
312                    codes.push(index);
313                }),
314                IdentifierKind::Client => {
315                    get(&env.client_context, identifier_lookup_name).map(|index| {
316                        codes.push(GET);
317                        codes.push(lookup::CLIENT_CTX);
318                        codes.push(index);
319                    })
320                }
321                IdentifierKind::Builtin => {
322                    if let Some((_, index)) = env.get_builtin_index(identifier_lookup_name) {
323                        codes.push(GET);
324                        codes.push(lookup::BUILTIN);
325                        codes.push(index);
326
327                        Some(())
328                    } else if let Some((_, index)) =
329                        env.get_user_builtin_index(identifier_lookup_name)
330                    {
331                        codes.push(GET);
332                        codes.push(lookup::USER_BUILTIN);
333                        codes.push(index);
334
335                        Some(())
336                    } else {
337                        None
338                    }
339                }
340                IdentifierKind::Type => {
341                    let ty = Type::from(&identifier_name);
342                    if let Some(index) = types.iter().position(|x| x == &ty) {
343                        codes.push(GET);
344                        codes.push(TYPE);
345                        codes.push(index as u8);
346                    } else {
347                        types.push(ty);
348                        let index = types.len() - 1;
349                        codes.push(GET);
350                        codes.push(TYPE);
351                        codes.push(index as u8);
352                    }
353
354                    Some(())
355                }
356            };
357
358            if result.is_none() {
359                errs.push(identifier_undefined_err);
360            }
361        }
362        Expr::Call(expr_call) => {
363            let callee_bytecode = compile_expr(&mut expr_call.callee, env, constants, types)?;
364
365            if let Some(_op) = callee_bytecode.first()
366                && let Some(lookup) = callee_bytecode.get(1)
367                && let Some(index) = callee_bytecode.get(2)
368            {
369                match *lookup {
370                    lookup::BUILTIN => {
371                        let builtin = env.get_builtin((*index).into()).unwrap();
372
373                        let call_arity: usize = expr_call.args.len();
374
375                        if !builtin.arity_matches(call_arity.try_into().unwrap()) {
376                            errs.push((
377                                ExprError::CompileError(WrongNumberOfArgs {
378                                    expected: builtin.arity() as usize,
379                                    actual: call_arity,
380                                }),
381                                span.clone(),
382                            ));
383                        }
384
385                        let args: Vec<_> = expr_call.args.iter().take(call_arity).collect();
386
387                        for (i, fnarg) in builtin.args.iter().enumerate() {
388                            if let Some((a, a_span)) = args.get(i) {
389                                let a_type = a.get_type();
390
391                                let types_match = fnarg.ty == a_type
392                                    || fnarg.ty == Type::Value
393                                    || a_type == Type::Unknown;
394
395                                if !types_match {
396                                    errs.push((
397                                        CompileError::TypeMismatch {
398                                            expected: fnarg.ty.clone(),
399                                            actual: a_type.clone(),
400                                        }
401                                        .into(),
402                                        a_span.clone(),
403                                    ));
404                                }
405                            }
406                        }
407                    }
408                    lookup::USER_BUILTIN => {
409                        let builtin = env.get_user_builtin((*index).into()).unwrap();
410
411                        let call_arity: usize = expr_call.args.len();
412
413                        if !builtin.arity_matches(call_arity.try_into().unwrap()) {
414                            errs.push((
415                                ExprError::CompileError(WrongNumberOfArgs {
416                                    expected: builtin.arity() as usize,
417                                    actual: call_arity,
418                                }),
419                                span.clone(),
420                            ));
421                        }
422                    }
423                    lookup::CLIENT_CTX => {
424                        // No validation needs to be ran at this point
425                        // This won't happen until runtime when the client
426                        // a value.
427                    }
428                    _ => {
429                        errs.push((
430                            CompileError::InvalidLookupType(*lookup).into(),
431                            span.clone(),
432                        ));
433                    }
434                }
435            }
436
437            codes.extend(callee_bytecode);
438
439            for arg in expr_call.args.iter_mut() {
440                match compile_expr(arg, env, constants, types) {
441                    Ok(arg_bytecode) => {
442                        codes.extend(arg_bytecode);
443                    }
444                    Err(err) => {
445                        errs.extend(err);
446                    }
447                }
448            }
449
450            codes.push(opcode::CALL);
451            codes.push(expr_call.args.len() as u8);
452        }
453        Expr::Bool(value) => match value.0 {
454            true => {
455                codes.push(opcode::TRUE);
456            }
457            false => {
458                codes.push(opcode::FALSE);
459            }
460        },
461        Expr::Error => panic!("tried to compile despite parser errors"),
462    }
463
464    if !errs.is_empty() {
465        return Err(errs);
466    }
467
468    Ok(codes)
469}
470
#[cfg(test)]
mod compiler_tests {
    use super::*;

    /// Raw bytes made of the given 4-byte header followed by a `TRUE` op code
    fn header_then_true(header: [u8; 4]) -> Vec<u8> {
        let mut codes = header.to_vec();
        codes.push(opcode::TRUE);
        codes
    }

    /// The version bytes match the crate version these tests were written for
    #[test]
    pub fn current_version_bytes() {
        assert_eq!(get_version_bytes(), [0, 8, 0, 0]);
    }

    /// Bytes that start with the current version header are accepted
    #[test]
    pub fn valid_bytecode_version_bytes() {
        let codes = header_then_true(get_version_bytes());

        ExprByteCode::new(codes, vec![], vec![]);
    }

    /// Bytes that start with a different version header are rejected
    #[test]
    #[should_panic(expected = "Version bytes do not match")]
    pub fn invalid_bytecode_version_bytes() {
        let codes = header_then_true([0, 0, 0, 0]);

        ExprByteCode::new(codes, vec![], vec![]);
    }

    /// The version header is kept on the bytecode after construction
    #[test]
    pub fn get_version_bytes_from_bytecode() {
        let codes = header_then_true(get_version_bytes());

        let bytecode = ExprByteCode::new(codes, vec![], vec![]);

        assert_eq!(bytecode.version(), &get_version_bytes());
    }
}