1use crate::{
4 ast::{Expr, ExprS, IdentifierKind, add_type_to_expr},
5 builtins::BuiltinFn,
6 errors::{
7 CompileError::{self, WrongNumberOfArgs},
8 ExprError, ExprErrorS, ExprResult,
9 },
10 prelude::lookup::TYPE,
11 types::Type,
12 value::Value,
13};
14
/// Instruction opcodes emitted by the compiler.
///
/// Each opcode is one byte; operands, where present, follow it as single bytes. The values are
/// part of the emitted bytecode, so changing them invalidates previously compiled code.
pub mod opcode {
    /// Emitted after the callee's code and its arguments' code; operand: argument count.
    pub const CALL: u8 = 0;
    /// Load a named value; operands: lookup kind (see `lookup`) and table index.
    pub const GET: u8 = 1;
    /// Load a literal; operand: index into the constant pool.
    pub const CONSTANT: u8 = 2;
    /// Boolean literal `true`; no operand.
    pub const TRUE: u8 = 3;
    /// Boolean literal `false`; no operand.
    pub const FALSE: u8 = 4;
}
25
/// Lookup kinds: the byte that follows `opcode::GET` and selects which table the next byte
/// (the index) refers to.
///
/// The values are part of the emitted bytecode, so changing them invalidates previously
/// compiled code.
pub mod lookup {
    /// Default builtins of the `CompileTimeEnv`.
    pub const BUILTIN: u8 = 0;
    /// Variable names of the `CompileTimeEnv`.
    pub const VAR: u8 = 1;
    /// Prompt names of the `CompileTimeEnv`.
    pub const PROMPT: u8 = 2;
    /// Secret names of the `CompileTimeEnv`.
    pub const SECRET: u8 = 3;
    /// User-registered builtins of the `CompileTimeEnv`.
    pub const USER_BUILTIN: u8 = 4;
    /// Client-context keys of the `CompileTimeEnv`.
    pub const CLIENT_CTX: u8 = 5;
    /// Type lookup; the compiler pairs it with an index into the compiled type pool.
    pub const TYPE: u8 = 6;
}
43
/// Returns the position of `identifier` in `list` as a single-byte bytecode operand.
///
/// Returns `None` when `identifier` is absent, and also when its position does not fit in a
/// `u8`. Bytecode lookup indices are one byte wide, so truncating a larger position (as a plain
/// `as u8` cast would) silently makes the bytecode reference a different entry.
fn get(list: &[String], identifier: &str) -> Option<u8> {
    list.iter()
        .position(|entry| entry == identifier)
        .and_then(|index| u8::try_from(index).ok())
}
48
49#[derive(Debug)]
50pub struct CompileTimeEnv {
51 builtins: Vec<BuiltinFn<'static>>,
52 user_builtins: Vec<BuiltinFn<'static>>,
53 vars: Vec<String>,
54 prompts: Vec<String>,
55 secrets: Vec<String>,
56 client_context: Vec<String>,
57}
58
59impl Default for CompileTimeEnv {
60 fn default() -> Self {
61 Self {
62 builtins: BuiltinFn::DEFAULT_BUILTINS.to_vec(),
63 user_builtins: vec![],
64 vars: vec![],
65 prompts: vec![],
66 secrets: vec![],
67 client_context: vec![],
68 }
69 }
70}
71
72impl CompileTimeEnv {
73 pub fn new(
74 vars: Vec<String>,
75 prompts: Vec<String>,
76 secrets: Vec<String>,
77 client_context: Vec<String>,
78 ) -> Self {
79 Self {
80 vars,
81 prompts,
82 secrets,
83 client_context,
84 ..Default::default()
85 }
86 }
87
88 pub fn get_builtin_index(&self, name: &str) -> Option<(&BuiltinFn<'_>, u8)> {
89 let index = self.builtins.iter().position(|x| x.name == name);
90
91 index.map(|i| (self.builtins.get(i).unwrap(), i as u8))
92 }
93
94 pub fn get_user_builtin_index(&self, name: &str) -> Option<(&BuiltinFn<'_>, u8)> {
95 let index = self.user_builtins.iter().position(|x| x.name == name);
96
97 index.map(|i| (self.user_builtins.get(i).unwrap(), i as u8))
98 }
99
100 pub fn add_user_builtins(&mut self, builtins: Vec<BuiltinFn<'static>>) {
101 for builtin in builtins {
102 self.add_user_builtin(builtin);
103 }
104 }
105
106 pub fn add_user_builtin(&mut self, builtin: BuiltinFn<'static>) {
107 self.user_builtins.push(builtin);
108 }
109
110 pub fn get_builtin(&self, index: usize) -> Option<&BuiltinFn<'static>> {
111 self.builtins.get(index)
112 }
113
114 pub fn get_user_builtin(&self, index: usize) -> Option<&BuiltinFn<'static>> {
115 self.user_builtins.get(index)
116 }
117
118 pub fn get_var(&self, index: usize) -> Option<&String> {
119 self.vars.get(index)
120 }
121
122 pub fn get_var_index(&self, name: &str) -> Option<usize> {
123 self.vars
124 .iter()
125 .position(|context_name| context_name == name)
126 }
127
128 pub fn get_prompt(&self, index: usize) -> Option<&String> {
129 self.prompts.get(index)
130 }
131
132 pub fn get_prompt_index(&self, name: &str) -> Option<usize> {
133 self.prompts
134 .iter()
135 .position(|context_name| context_name == name)
136 }
137
138 pub fn get_secret(&self, index: usize) -> Option<&String> {
139 self.secrets.get(index)
140 }
141
142 pub fn get_secret_index(&self, name: &str) -> Option<usize> {
143 self.secrets
144 .iter()
145 .position(|context_name| context_name == name)
146 }
147
148 pub fn get_client_context(&self, index: usize) -> Option<&String> {
149 self.client_context.get(index)
150 }
151
152 pub fn add_to_client_context(&mut self, key: &str) -> usize {
153 match self.client_context.iter().position(|x| x == key) {
154 Some(i) => i,
155 None => {
156 self.client_context.push(key.to_string());
157
158 self.client_context.len() - 1
159 }
160 }
161 }
162
163 pub fn get_client_context_index(&self, name: &str) -> Option<(&String, u8)> {
164 let index = self
165 .client_context
166 .iter()
167 .position(|context_name| context_name == name);
168
169 index.map(|i| (self.client_context.get(i).unwrap(), i as u8))
170 }
171}
172
173#[derive(Debug, Clone, PartialEq)]
175pub struct ExprByteCode {
176 version: [u8; 4],
177 codes: Vec<u8>,
178 constants: Vec<Value>,
179 types: Vec<Type>,
180}
181
182impl ExprByteCode {
183 pub fn new(codes: Vec<u8>, constants: Vec<Value>, types: Vec<Type>) -> Self {
184 let version_bytes = get_version_bytes();
185 let version_bytes_from_codes = &codes[0..4];
186
187 assert_eq!(
188 version_bytes, version_bytes_from_codes,
189 "Version bytes do not match"
190 );
191
192 let codes = codes[4..].to_vec();
193
194 Self {
195 version: version_bytes,
196 codes,
197 constants,
198 types,
199 }
200 }
201
202 pub fn version(&self) -> &[u8; 4] {
203 &self.version
204 }
205
206 pub fn codes(&self) -> &[u8] {
207 &self.codes
208 }
209
210 pub fn constants(&self) -> &[Value] {
211 &self.constants
212 }
213
214 pub fn types(&self) -> &[Type] {
215 &self.types
216 }
217}
218
219pub fn get_version_bytes() -> [u8; 4] {
220 [
221 env!("CARGO_PKG_VERSION_MAJOR").parse().unwrap(),
222 env!("CARGO_PKG_VERSION_MINOR").parse().unwrap(),
223 env!("CARGO_PKG_VERSION_PATCH").parse().unwrap(),
224 0,
225 ]
226}
227
228pub fn compile(expr: &mut ExprS, env: &CompileTimeEnv) -> ExprResult<ExprByteCode> {
230 let mut constants: Vec<Value> = vec![];
231 let mut types: Vec<Type> = vec![];
232 let mut codes = vec![];
233
234 codes.extend(get_version_bytes());
235
236 codes.extend(compile_expr(expr, env, &mut constants, &mut types)?);
237
238 Ok(ExprByteCode::new(codes, constants, types))
239}
240
241fn compile_expr(
242 (expr, span): &mut ExprS,
243 env: &CompileTimeEnv,
244 constants: &mut Vec<Value>,
245 types: &mut Vec<Type>,
246) -> ExprResult<Vec<u8>> {
247 use opcode::*;
248
249 let mut codes = vec![];
250 let mut errs: Vec<ExprErrorS> = vec![];
251
252 add_type_to_expr(expr, env);
253
254 match expr {
255 Expr::String(string) => {
256 if let Some(index) = constants.iter().position(|x| {
257 if let Value::String(string_constant) = x {
258 string_constant == &string.0
259 } else {
260 false
261 }
262 }) {
263 codes.push(CONSTANT);
264 codes.push(index as u8);
265 } else {
266 constants.push(Value::String(string.0.clone()));
267 let index = constants.len() - 1;
268 codes.push(CONSTANT);
269 codes.push(index as u8);
270 }
271 }
272 Expr::Number(number) => {
273 if let Some(index) = constants.iter().position(|x| {
274 if let Value::Number(value) = x {
275 value == &number.0
276 } else {
277 false
278 }
279 }) {
280 codes.push(CONSTANT);
281 codes.push(index as u8);
282 } else {
283 constants.push(Value::Number(number.0));
284 let index = constants.len() - 1;
285 codes.push(CONSTANT);
286 codes.push(index as u8);
287 }
288 }
289 Expr::Identifier(identifier) => {
290 let identifier_lookup_name = identifier.lookup_name();
291 let identifier_name = identifier.full_name().to_string();
292
293 let identifier_undefined_err = (
294 CompileError::Undefined(identifier_name.clone()).into(),
295 span.clone(),
296 );
297
298 let result = match identifier.identifier_kind() {
299 IdentifierKind::Var => get(&env.vars, identifier_lookup_name).map(|index| {
300 codes.push(GET);
301 codes.push(lookup::VAR);
302 codes.push(index);
303 }),
304 IdentifierKind::Prompt => get(&env.prompts, identifier_lookup_name).map(|index| {
305 codes.push(GET);
306 codes.push(lookup::PROMPT);
307 codes.push(index);
308 }),
309 IdentifierKind::Secret => get(&env.secrets, identifier_lookup_name).map(|index| {
310 codes.push(GET);
311 codes.push(lookup::SECRET);
312 codes.push(index);
313 }),
314 IdentifierKind::Client => {
315 get(&env.client_context, identifier_lookup_name).map(|index| {
316 codes.push(GET);
317 codes.push(lookup::CLIENT_CTX);
318 codes.push(index);
319 })
320 }
321 IdentifierKind::Builtin => {
322 if let Some((_, index)) = env.get_builtin_index(identifier_lookup_name) {
323 codes.push(GET);
324 codes.push(lookup::BUILTIN);
325 codes.push(index);
326
327 Some(())
328 } else if let Some((_, index)) =
329 env.get_user_builtin_index(identifier_lookup_name)
330 {
331 codes.push(GET);
332 codes.push(lookup::USER_BUILTIN);
333 codes.push(index);
334
335 Some(())
336 } else {
337 None
338 }
339 }
340 IdentifierKind::Type => {
341 let ty = Type::from(&identifier_name);
342 if let Some(index) = types.iter().position(|x| x == &ty) {
343 codes.push(GET);
344 codes.push(TYPE);
345 codes.push(index as u8);
346 } else {
347 types.push(ty);
348 let index = types.len() - 1;
349 codes.push(GET);
350 codes.push(TYPE);
351 codes.push(index as u8);
352 }
353
354 Some(())
355 }
356 };
357
358 if result.is_none() {
359 errs.push(identifier_undefined_err);
360 }
361 }
362 Expr::Call(expr_call) => {
363 let callee_bytecode = compile_expr(&mut expr_call.callee, env, constants, types)?;
364
365 if let Some(_op) = callee_bytecode.first()
366 && let Some(lookup) = callee_bytecode.get(1)
367 && let Some(index) = callee_bytecode.get(2)
368 {
369 match *lookup {
370 lookup::BUILTIN => {
371 let builtin = env.get_builtin((*index).into()).unwrap();
372
373 let call_arity: usize = expr_call.args.len();
374
375 if !builtin.arity_matches(call_arity.try_into().unwrap()) {
376 errs.push((
377 ExprError::CompileError(WrongNumberOfArgs {
378 expected: builtin.arity() as usize,
379 actual: call_arity,
380 }),
381 span.clone(),
382 ));
383 }
384
385 let args: Vec<_> = expr_call.args.iter().take(call_arity).collect();
386
387 for (i, fnarg) in builtin.args.iter().enumerate() {
388 if let Some((a, a_span)) = args.get(i) {
389 let a_type = a.get_type();
390
391 let types_match = fnarg.ty == a_type
392 || fnarg.ty == Type::Value
393 || a_type == Type::Unknown;
394
395 if !types_match {
396 errs.push((
397 CompileError::TypeMismatch {
398 expected: fnarg.ty.clone(),
399 actual: a_type.clone(),
400 }
401 .into(),
402 a_span.clone(),
403 ));
404 }
405 }
406 }
407 }
408 lookup::USER_BUILTIN => {
409 let builtin = env.get_user_builtin((*index).into()).unwrap();
410
411 let call_arity: usize = expr_call.args.len();
412
413 if !builtin.arity_matches(call_arity.try_into().unwrap()) {
414 errs.push((
415 ExprError::CompileError(WrongNumberOfArgs {
416 expected: builtin.arity() as usize,
417 actual: call_arity,
418 }),
419 span.clone(),
420 ));
421 }
422 }
423 lookup::CLIENT_CTX => {
424 }
428 _ => {
429 errs.push((
430 CompileError::InvalidLookupType(*lookup).into(),
431 span.clone(),
432 ));
433 }
434 }
435 }
436
437 codes.extend(callee_bytecode);
438
439 for arg in expr_call.args.iter_mut() {
440 match compile_expr(arg, env, constants, types) {
441 Ok(arg_bytecode) => {
442 codes.extend(arg_bytecode);
443 }
444 Err(err) => {
445 errs.extend(err);
446 }
447 }
448 }
449
450 codes.push(opcode::CALL);
451 codes.push(expr_call.args.len() as u8);
452 }
453 Expr::Bool(value) => match value.0 {
454 true => {
455 codes.push(opcode::TRUE);
456 }
457 false => {
458 codes.push(opcode::FALSE);
459 }
460 },
461 Expr::Error => panic!("tried to compile despite parser errors"),
462 }
463
464 if !errs.is_empty() {
465 return Err(errs);
466 }
467
468 Ok(codes)
469}
470
#[cfg(test)]
mod compiler_tests {
    use super::*;

    // Pinned to the current crate version; update together with the Cargo.toml version.
    #[test]
    fn current_version_bytes() {
        assert_eq!(get_version_bytes(), [0, 8, 0, 0]);
    }

    #[test]
    fn valid_bytecode_version_bytes() {
        let mut codes = get_version_bytes().to_vec();
        codes.push(opcode::TRUE);

        let bytecode = ExprByteCode::new(codes, vec![], vec![]);

        // The version header is stripped; only the instruction stream remains.
        assert_eq!(bytecode.codes(), &[opcode::TRUE]);
        assert!(bytecode.constants().is_empty());
        assert!(bytecode.types().is_empty());
    }

    #[test]
    #[should_panic(expected = "Version bytes do not match")]
    fn invalid_bytecode_version_bytes() {
        let codes: Vec<u8> = vec![0, 0, 0, 0, opcode::TRUE];

        ExprByteCode::new(codes, vec![], vec![]);
    }

    #[test]
    fn get_version_bytes_from_bytecode() {
        let mut codes = get_version_bytes().to_vec();
        codes.push(opcode::TRUE);

        let bytecode = ExprByteCode::new(codes, vec![], vec![]);

        assert_eq!(bytecode.version(), &get_version_bytes());
    }
}