reqlang_expr/
vm.rs

1//! The virtual machine and associated types
2
3use std::{fmt::Display, rc::Rc};
4
5use crate::{
6    compiler::{
7        BuiltinFn, CompileTimeEnv, ExprByteCode,
8        lookup::{BUILTIN, PROMPT, SECRET, VAR},
9        opcode,
10    },
11    errors::ExprResult,
12    prelude::lookup::{CLIENT_CTX, USER_BUILTIN},
13};
14
15#[derive(Debug, Clone, PartialEq)]
16pub enum Value {
17    String(String),
18    Fn(Rc<BuiltinFn>),
19    Bool(bool),
20}
21
22impl Value {
23    pub fn get_string(&self) -> &str {
24        match self {
25            Value::String(s) => s.as_str(),
26            _ => panic!("Value is not a string"),
27        }
28    }
29
30    pub fn get_func(&self) -> Rc<BuiltinFn> {
31        match self {
32            Value::Fn(f) => f.clone(),
33            _ => panic!("Value is not a function"),
34        }
35    }
36
37    pub fn get_bool(&self) -> bool {
38        match self {
39            Value::Bool(s) => *s,
40            _ => panic!("Value is not a string"),
41        }
42    }
43}
44
45impl From<&str> for Value {
46    fn from(s: &str) -> Self {
47        Value::String(s.to_string())
48    }
49}
50
51impl Display for Value {
52    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
53        match self {
54            Value::String(string) => write!(f, "`{}`", string),
55            Value::Fn(builtin) => write!(f, "{builtin:?}"),
56            Value::Bool(value) => write!(f, "{}", value),
57        }
58    }
59}
60
61#[derive(Debug, Clone, Default)]
62pub struct RuntimeEnv {
63    pub vars: Vec<String>,
64    pub prompts: Vec<String>,
65    pub secrets: Vec<String>,
66    pub client_context: Vec<Value>,
67}
68
69impl RuntimeEnv {
70    pub fn add_to_client_context(&mut self, index: usize, value: Value) {
71        if index < self.client_context.len() {
72            self.client_context[index] = value;
73        } else {
74            self.client_context.push(value);
75        }
76    }
77}
78
79#[derive(Debug)]
80pub struct Vm {
81    bytecode: Option<Box<ExprByteCode>>,
82    ip: usize,
83    stack: Vec<Value>,
84}
85
86impl Default for Vm {
87    fn default() -> Self {
88        Self::new()
89    }
90}
91
92impl Vm {
93    pub fn new() -> Self {
94        Self {
95            bytecode: None,
96            ip: 0,
97            stack: vec![],
98        }
99    }
100
101    pub fn interpret(
102        &mut self,
103        bytecode: Box<ExprByteCode>,
104        env: &CompileTimeEnv,
105        runtime_env: &RuntimeEnv,
106    ) -> ExprResult<Value> {
107        self.bytecode = Some(bytecode);
108        self.ip = 0;
109
110        while let Some(op_code) = self
111            .bytecode
112            .as_ref()
113            .and_then(|bc| bc.codes().get(self.ip))
114        {
115            self.interpret_op(env, runtime_env, *op_code);
116        }
117
118        assert_eq!(1, self.stack.len());
119
120        let value = self.stack_pop();
121
122        Ok(value)
123    }
124
125    fn interpret_op(&mut self, env: &CompileTimeEnv, runtime_env: &RuntimeEnv, op_code: u8) {
126        match op_code {
127            opcode::CALL => self.op_call(),
128            opcode::CONSTANT => self.op_constant(),
129            opcode::GET => self.op_get(env, runtime_env),
130            opcode::TRUE => self.op_true(),
131            opcode::FALSE => self.op_false(),
132            _ => panic!("Invalid OP code: {op_code}"),
133        }
134    }
135
136    fn op_call(&mut self) {
137        // Confirm the current op code is CALL
138        assert_eq!(opcode::CALL, self.read_u8(), "Expected CALL opcode");
139
140        let arg_count = self.read_u8() as usize;
141
142        let mut args: Vec<Value> = vec![];
143
144        for _ in 0..arg_count {
145            args.push(self.stack_pop());
146        }
147
148        args.reverse();
149
150        let value = self.stack_pop();
151
152        let builtin = value.get_func().func.clone();
153
154        let result = builtin(args);
155
156        self.stack_push(result);
157    }
158
159    fn op_get(&mut self, env: &CompileTimeEnv, runtime_env: &RuntimeEnv) {
160        assert_eq!(opcode::GET, self.read_u8(), "Expected GET opcode");
161        let get_lookup = self.read_u8();
162        let get_idx = self.read_u8() as usize;
163
164        match get_lookup {
165            BUILTIN => {
166                let value = env
167                    .get_builtin(get_idx)
168                    .unwrap_or_else(|| panic!("undefined builtin: {get_idx}"));
169                self.stack_push(Value::Fn(value.clone()));
170            }
171            USER_BUILTIN => {
172                let value = env
173                    .get_user_builtin(get_idx)
174                    .unwrap_or_else(|| panic!("undefined user builtin: {get_idx}"));
175                self.stack_push(Value::Fn(value.clone()));
176            }
177            VAR => {
178                let value = env
179                    .get_var(get_idx)
180                    .and_then(|_| runtime_env.vars.get(get_idx))
181                    .unwrap_or_else(|| panic!("undefined variable: {get_idx}"));
182
183                self.stack_push(Value::String(value.clone()));
184            }
185            PROMPT => {
186                let value = env
187                    .get_prompt(get_idx)
188                    .and_then(|_| runtime_env.prompts.get(get_idx))
189                    .unwrap_or_else(|| panic!("undefined prompt: {get_idx}"));
190
191                self.stack_push(Value::String(value.clone()));
192            }
193            SECRET => {
194                let value = env
195                    .get_secret(get_idx)
196                    .and_then(|_| runtime_env.secrets.get(get_idx))
197                    .unwrap_or_else(|| panic!("undefined secret: {get_idx}"));
198
199                self.stack_push(Value::String(value.clone()));
200            }
201            CLIENT_CTX => {
202                let value = env
203                    .get_client_context(get_idx)
204                    .and_then(|_| runtime_env.client_context.get(get_idx))
205                    .unwrap_or_else(|| panic!("undefined client context: {get_idx}"));
206
207                self.stack_push(value.clone());
208            }
209            _ => panic!("invalid get lookup code: {}", get_lookup),
210        };
211    }
212
213    fn op_constant(&mut self) {
214        assert_eq!(opcode::CONSTANT, self.read_u8(), "Expected CONSTANT opcode");
215
216        let get_idx = self.read_u8() as usize;
217
218        let s = self
219            .bytecode
220            .as_ref()
221            .expect("should have bytecode")
222            .strings()
223            .get(get_idx)
224            .unwrap_or_else(|| panic!("undefined string: {}", get_idx));
225
226        self.stack_push(Value::String(s.clone()));
227    }
228
229    fn op_true(&mut self) {
230        assert_eq!(opcode::TRUE, self.read_u8(), "Expected TRUE opcode");
231
232        self.stack_push(Value::Bool(true));
233    }
234
235    fn op_false(&mut self) {
236        assert_eq!(opcode::FALSE, self.read_u8(), "Expected FALSE opcode");
237
238        self.stack_push(Value::Bool(false));
239    }
240
241    fn stack_push(&mut self, value: Value) {
242        self.stack.push(value);
243    }
244
245    fn stack_pop(&mut self) -> Value {
246        self.stack
247            .pop()
248            .expect("should have a value to pop from the stack")
249    }
250
251    fn read_u8(&mut self) -> u8 {
252        let current_ip = self.ip as u8;
253
254        self.ip += 1;
255
256        *self
257            .bytecode
258            .as_ref()
259            .expect("should have bytecode")
260            .codes()
261            .get(current_ip as usize)
262            .expect("should have op in bytecode at {}")
263    }
264}