1use std::rc::Rc;
4
5use crate::{
6 ast::{Expr, ExprS},
7 builtins::{BuiltinFn, BuiltinFns},
8 errors::{
9 CompileError::{self, WrongNumberOfArgs},
10 ExprError, ExprErrorS, ExprResult,
11 },
12 prelude::FnArg,
13 types::Type,
14};
15
/// Single-byte instruction opcodes written into the bytecode stream by the compiler.
///
/// The numeric values are the on-the-wire encoding, so existing opcodes must keep their
/// numbers; new opcodes should take the next free value.
pub mod opcode {
    /// Written after the callee and argument code; followed by a one-byte argument count.
    pub const CALL: u8 = 0;
    /// Followed by a `lookup` table tag and a one-byte index into that table.
    pub const GET: u8 = 1;
    /// Followed by a one-byte index into the bytecode's string table.
    pub const CONSTANT: u8 = 2;
    /// Boolean literal `true`.
    pub const TRUE: u8 = 3;
    /// Boolean literal `false`.
    pub const FALSE: u8 = 4;
    /// Written after the single operand of a `not(...)` call.
    pub const NOT: u8 = 5;
    /// Written after both operands of an `eq(...)` call.
    pub const EQ: u8 = 6;
    /// Written after the single operand of a `type(...)` call.
    pub const TYPE: u8 = 7;
}
29
/// Table tags that follow a `GET` opcode and select which name table its index refers to.
///
/// Like the opcodes, these values are part of the bytecode encoding and must not be
/// renumbered.
pub mod lookup {
    /// Default builtin functions of the compile-time environment.
    pub const BUILTIN: u8 = 0;
    /// Variables, referenced in source as `:name`.
    pub const VAR: u8 = 1;
    /// Prompts, referenced in source as `?name`.
    pub const PROMPT: u8 = 2;
    /// Secrets, referenced in source as `!name`.
    pub const SECRET: u8 = 3;
    /// Builtins registered by the user on the compile-time environment.
    pub const USER_BUILTIN: u8 = 4;
    /// Client-context keys, referenced in source as `@name`.
    pub const CLIENT_CTX: u8 = 5;
}
46
/// Returns the position of `identifier` in `list` as a one-byte bytecode operand.
///
/// Returns `None` when the identifier is absent, or when its position does not fit in a
/// `u8`. Truncating with `as u8` would make entry 256 silently alias entry 0, so the
/// compiled bytecode would read the wrong name.
fn get(list: &[String], identifier: &str) -> Option<u8> {
    list.iter()
        .position(|entry| entry == identifier)
        .and_then(|index| u8::try_from(index).ok())
}
51
/// Names available while compiling an expression.
///
/// Every table is searched by position, and that position becomes the one-byte operand
/// written after `GET <table>`, so the order of entries matters.
#[derive(Debug)]
pub struct CompileTimeEnv {
    /// Default builtin functions; consulted before `user_builtins` when resolving a name.
    builtins: Vec<Rc<BuiltinFn>>,
    /// Builtins registered through `add_user_builtin` / `add_user_builtins`.
    user_builtins: Vec<Rc<BuiltinFn>>,
    /// Names resolvable as `:name`.
    vars: Vec<String>,
    /// Names resolvable as `?name`.
    prompts: Vec<String>,
    /// Names resolvable as `!name`.
    secrets: Vec<String>,
    /// Names resolvable as `@name`.
    client_context: Vec<String>,
}
61
62impl Default for CompileTimeEnv {
63 fn default() -> Self {
64 Self {
65 builtins: vec![
66 Rc::new(BuiltinFn {
67 name: String::from("id"),
68 args: vec![FnArg::new("value", Type::Value)],
69 return_type: Type::Value,
70 func: Rc::new(BuiltinFns::id),
71 }),
72 Rc::new(BuiltinFn {
73 name: String::from("noop"),
74 args: vec![],
75 return_type: Type::String,
76 func: Rc::new(BuiltinFns::noop),
77 }),
78 Rc::new(BuiltinFn {
79 name: String::from("is_empty"),
80 args: vec![FnArg::new("value", Type::String)],
81 return_type: Type::String,
82 func: Rc::new(BuiltinFns::is_empty),
83 }),
84 Rc::new(BuiltinFn {
85 name: String::from("and"),
86 args: vec![FnArg::new("a", Type::Bool), FnArg::new("b", Type::Bool)],
87 return_type: Type::Bool,
88 func: Rc::new(BuiltinFns::and),
89 }),
90 Rc::new(BuiltinFn {
91 name: String::from("or"),
92 args: vec![FnArg::new("a", Type::Bool), FnArg::new("b", Type::Bool)],
93 return_type: Type::Bool,
94 func: Rc::new(BuiltinFns::or),
95 }),
96 Rc::new(BuiltinFn {
97 name: String::from("cond"),
98 args: vec![
99 FnArg::new("cond", Type::Bool),
100 FnArg::new("then", Type::Value),
101 FnArg::new("else", Type::Value),
102 ],
103 return_type: Type::Bool,
104 func: Rc::new(BuiltinFns::cond),
105 }),
106 Rc::new(BuiltinFn {
107 name: String::from("to_str"),
108 args: vec![FnArg::new("value", Type::Value)],
109 return_type: Type::String,
110 func: Rc::new(BuiltinFns::to_str),
111 }),
112 Rc::new(BuiltinFn {
113 name: String::from("concat"),
114 args: vec![
115 FnArg::new("a", Type::String),
116 FnArg::new("b", Type::String),
117 FnArg::new_varadic("rest", Type::String),
118 ],
119 return_type: Type::String,
120 func: Rc::new(BuiltinFns::concat),
121 }),
122 Rc::new(BuiltinFn {
123 name: String::from("contains"),
124 args: vec![
125 FnArg::new("needle", Type::String),
126 FnArg::new("haystack", Type::String),
127 ],
128 return_type: Type::Bool,
129 func: Rc::new(BuiltinFns::contains),
130 }),
131 Rc::new(BuiltinFn {
132 name: String::from("trim"),
133 args: vec![FnArg::new("value", Type::String)],
134 return_type: Type::String,
135 func: Rc::new(BuiltinFns::trim),
136 }),
137 Rc::new(BuiltinFn {
138 name: String::from("trim_start"),
139 args: vec![FnArg::new("value", Type::String)],
140 return_type: Type::String,
141 func: Rc::new(BuiltinFns::trim_start),
142 }),
143 Rc::new(BuiltinFn {
144 name: String::from("trim_end"),
145 args: vec![FnArg::new("value", Type::String)],
146 return_type: Type::String,
147 func: Rc::new(BuiltinFns::trim_end),
148 }),
149 Rc::new(BuiltinFn {
150 name: String::from("lowercase"),
151 args: vec![FnArg::new("value", Type::String)],
152 return_type: Type::String,
153 func: Rc::new(BuiltinFns::lowercase),
154 }),
155 Rc::new(BuiltinFn {
156 name: String::from("uppercase"),
157 args: vec![FnArg::new("value", Type::String)],
158 return_type: Type::String,
159 func: Rc::new(BuiltinFns::uppercase),
160 }),
161 ],
162 user_builtins: vec![],
163 vars: Vec::new(),
164 prompts: Vec::new(),
165 secrets: Vec::new(),
166 client_context: Vec::new(),
167 }
168 }
169}
170
171impl CompileTimeEnv {
172 pub fn new(
173 vars: Vec<String>,
174 prompts: Vec<String>,
175 secrets: Vec<String>,
176 client_context: Vec<String>,
177 ) -> Self {
178 Self {
179 vars,
180 prompts,
181 secrets,
182 client_context,
183 ..Default::default()
184 }
185 }
186
187 pub fn get_builtin_index(&self, name: &str) -> Option<(&Rc<BuiltinFn>, u8)> {
188 let index = self.builtins.iter().position(|x| x.name == name);
189
190 let result = index.map(|i| (self.builtins.get(i).unwrap(), i as u8));
191 result
192 }
193
194 pub fn get_user_builtin_index(&self, name: &str) -> Option<(&Rc<BuiltinFn>, u8)> {
195 let index = self.user_builtins.iter().position(|x| x.name == name);
196
197 let result = index.map(|i| (self.user_builtins.get(i).unwrap(), i as u8));
198 result
199 }
200
201 pub fn add_user_builtins(&mut self, builtins: Vec<Rc<BuiltinFn>>) {
202 for builtin in builtins {
203 self.add_user_builtin(builtin);
204 }
205 }
206
207 pub fn add_user_builtin(&mut self, builtin: Rc<BuiltinFn>) {
208 self.user_builtins.push(builtin);
209 }
210
211 pub fn get_builtin(&self, index: usize) -> Option<&Rc<BuiltinFn>> {
212 self.builtins.get(index)
213 }
214
215 pub fn get_user_builtin(&self, index: usize) -> Option<&Rc<BuiltinFn>> {
216 self.user_builtins.get(index)
217 }
218
219 pub fn get_var(&self, index: usize) -> Option<&String> {
220 self.vars.get(index)
221 }
222
223 pub fn get_prompt(&self, index: usize) -> Option<&String> {
224 self.prompts.get(index)
225 }
226
227 pub fn get_secret(&self, index: usize) -> Option<&String> {
228 self.secrets.get(index)
229 }
230
231 pub fn get_client_context(&self, index: usize) -> Option<&String> {
232 self.client_context.get(index)
233 }
234
235 pub fn add_to_client_context(&mut self, key: &str) -> usize {
236 match self.client_context.iter().position(|x| x == key) {
237 Some(i) => i,
238 None => {
239 self.client_context.push(key.to_string());
240
241 self.client_context.len() - 1
242 }
243 }
244 }
245
246 pub fn add_keys_to_client_context(&mut self, keys: Vec<String>) {
247 self.client_context.extend(keys);
248 }
249
250 pub fn get_client_context_index(&self, name: &str) -> Option<(&String, u8)> {
251 let index = self
252 .client_context
253 .iter()
254 .position(|context_name| context_name == name);
255
256 let result = index.map(|i| (self.client_context.get(i).unwrap(), i as u8));
257 result
258 }
259}
260
/// A compiled expression: its version header, instruction stream and string constants.
///
/// Every field is `Eq`, so equality is total and the type also derives `Eq`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ExprByteCode {
    /// Version header the bytecode was produced with (see `get_version_bytes`).
    version: [u8; 4],
    /// Instruction stream, without the version header.
    codes: Vec<u8>,
    /// String constants referenced by `CONSTANT` operands.
    strings: Vec<String>,
}
268
269impl ExprByteCode {
270 pub fn new(codes: Vec<u8>, strings: Vec<String>) -> Self {
271 let version_bytes = get_version_bytes();
272 let version_bytes_from_codes = &codes[0..4];
273
274 assert_eq!(
275 version_bytes, version_bytes_from_codes,
276 "Version bytes do not match"
277 );
278
279 let codes = codes[4..].to_vec();
280
281 Self {
282 version: version_bytes,
283 codes,
284 strings,
285 }
286 }
287
288 pub fn version(&self) -> &[u8; 4] {
289 &self.version
290 }
291
292 pub fn codes(&self) -> &[u8] {
293 &self.codes
294 }
295
296 pub fn strings(&self) -> &[String] {
297 &self.strings
298 }
299}
300
301pub fn get_version_bytes() -> [u8; 4] {
302 [
303 env!("CARGO_PKG_VERSION_MAJOR").parse().unwrap(),
304 env!("CARGO_PKG_VERSION_MINOR").parse().unwrap(),
305 env!("CARGO_PKG_VERSION_PATCH").parse().unwrap(),
306 0,
307 ]
308}
309
310pub fn compile(expr: &ExprS, env: &CompileTimeEnv) -> ExprResult<ExprByteCode> {
312 let mut strings: Vec<String> = vec![];
313 let mut codes = vec![];
314
315 codes.extend(get_version_bytes());
316
317 codes.extend(compile_expr(expr, env, &mut strings)?);
318
319 Ok(ExprByteCode::new(codes, strings))
320}
321
322fn compile_expr(
323 (expr, span): &ExprS,
324 env: &CompileTimeEnv,
325 strings: &mut Vec<String>,
326) -> ExprResult<Vec<u8>> {
327 use opcode::*;
328
329 let mut codes = vec![];
330 let mut errs: Vec<ExprErrorS> = vec![];
331
332 match expr {
333 Expr::String(string) => {
334 if let Some(index) = strings.iter().position(|x| x == &string.0) {
335 codes.push(CONSTANT);
336 codes.push(index as u8);
337 } else {
338 strings.push(string.0.clone());
339 let index = strings.len() - 1;
340 codes.push(CONSTANT);
341 codes.push(index as u8);
342 }
343 }
344 Expr::Identifier(identifier) => {
345 let identifier_name = identifier.0.as_str();
346
347 if let Some((_, index)) = env.get_builtin_index(identifier_name) {
348 codes.push(GET);
349 codes.push(lookup::BUILTIN);
350 codes.push(index);
351 } else if let Some((_, index)) = env.get_user_builtin_index(identifier_name) {
352 codes.push(GET);
353 codes.push(lookup::USER_BUILTIN);
354 codes.push(index);
355 } else {
356 let identifier_prefix = &identifier_name[..1];
357 let identifier_suffix = &identifier_name[1..];
358
359 match identifier_prefix {
360 "?" => {
361 if let Some(index) = get(&env.prompts, identifier_suffix) {
362 codes.push(GET);
363 codes.push(lookup::PROMPT);
364 codes.push(index);
365 } else {
366 errs.push((
367 CompileError::Undefined(identifier_name.to_string()).into(),
368 span.clone(),
369 ));
370 }
371 }
372 "!" => {
373 if let Some(index) = get(&env.secrets, identifier_suffix) {
374 codes.push(GET);
375 codes.push(lookup::SECRET);
376 codes.push(index);
377 } else {
378 errs.push((
379 CompileError::Undefined(identifier_name.to_string()).into(),
380 span.clone(),
381 ));
382 }
383 }
384 ":" => {
385 if let Some(index) = get(&env.vars, identifier_suffix) {
386 codes.push(GET);
387 codes.push(lookup::VAR);
388 codes.push(index);
389 } else {
390 errs.push((
391 CompileError::Undefined(identifier_name.to_string()).into(),
392 span.clone(),
393 ));
394 }
395 }
396 "@" => {
397 if let Some(index) = get(&env.client_context, identifier_suffix) {
398 codes.push(GET);
399 codes.push(lookup::CLIENT_CTX);
400 codes.push(index);
401 } else {
402 errs.push((
403 CompileError::Undefined(identifier_name.to_string()).into(),
404 span.clone(),
405 ));
406 }
407 }
408 _ => {
409 errs.push((
410 ExprError::CompileError(CompileError::Undefined(
411 identifier_name.to_string(),
412 )),
413 span.clone(),
414 ));
415 }
416 };
417 }
418 }
419 Expr::Call(expr_call) => {
420 let identifier_name = expr_call.callee.0.identifier_name().unwrap_or_default();
421
422 match identifier_name {
423 "type" => {
424 if expr_call.args.is_empty() {
425 errs.push((
426 ExprError::CompileError(WrongNumberOfArgs {
427 expected: 1,
428 actual: 0,
429 }),
430 span.clone(),
431 ));
432 } else if expr_call.args.len() > 1 {
433 errs.push((
434 ExprError::CompileError(WrongNumberOfArgs {
435 expected: 1,
436 actual: expr_call.args.len(),
437 }),
438 span.clone(),
439 ));
440 } else {
441 let arg = expr_call.args.first().expect("should have first argument");
442
443 match compile_expr(arg, env, strings) {
444 Ok(arg_bytecode) => {
445 codes.extend(arg_bytecode);
446 codes.push(opcode::TYPE);
447 }
448 Err(err) => {
449 errs.extend(err);
450 }
451 }
452 }
453 }
454 "eq" => {
455 if expr_call.args.is_empty() {
456 errs.push((
457 ExprError::CompileError(WrongNumberOfArgs {
458 expected: 2,
459 actual: 0,
460 }),
461 span.clone(),
462 ));
463 } else if expr_call.args.len() == 1 {
464 errs.push((
465 ExprError::CompileError(WrongNumberOfArgs {
466 expected: 2,
467 actual: 1,
468 }),
469 span.clone(),
470 ));
471 } else if expr_call.args.len() > 2 {
472 errs.push((
473 ExprError::CompileError(WrongNumberOfArgs {
474 expected: 2,
475 actual: expr_call.args.len(),
476 }),
477 span.clone(),
478 ));
479 } else {
480 let arg = expr_call.args.first().expect("should have first argument");
481
482 match compile_expr(arg, env, strings) {
483 Ok(arg_bytecode) => {
484 codes.extend(arg_bytecode);
485 }
486 Err(err) => {
487 errs.extend(err);
488 }
489 }
490
491 let arg2 = expr_call.args.get(1).expect("should have second argument");
492
493 match compile_expr(arg2, env, strings) {
494 Ok(arg_bytecode) => {
495 codes.extend(arg_bytecode);
496 }
497 Err(err) => {
498 errs.extend(err);
499 }
500 }
501
502 codes.push(opcode::EQ);
503 }
504 }
505 "not" => {
506 if expr_call.args.is_empty() {
507 errs.push((
508 ExprError::CompileError(WrongNumberOfArgs {
509 expected: 1,
510 actual: 0,
511 }),
512 span.clone(),
513 ));
514 } else if expr_call.args.len() > 1 {
515 errs.push((
516 ExprError::CompileError(WrongNumberOfArgs {
517 expected: 1,
518 actual: expr_call.args.len(),
519 }),
520 span.clone(),
521 ));
522
523 let arg = expr_call.args.first().expect("should have first argument");
524
525 if !arg.0.is_bool() {
526 errs.push((
527 CompileError::TypeMismatch {
528 expected: Type::Bool,
529 actual: arg.0.get_type(),
530 }
531 .into(),
532 arg.1.clone(),
533 ));
534 }
535 } else {
536 let arg = expr_call.args.first().expect("should have first argument");
537 if !arg.0.is_bool() {
538 errs.push((
539 CompileError::TypeMismatch {
540 expected: Type::Bool,
541 actual: arg.0.get_type(),
542 }
543 .into(),
544 arg.1.clone(),
545 ));
546 }
547
548 match compile_expr(arg, env, strings) {
549 Ok(arg_bytecode) => {
550 codes.extend(arg_bytecode);
551 }
552 Err(err) => {
553 errs.extend(err);
554 }
555 }
556
557 codes.push(opcode::NOT);
558 }
559 }
560 _ => {
561 let callee_bytecode = compile_expr(&expr_call.callee, env, strings)?;
562
563 if let Some(_op) = callee_bytecode.first() {
564 if let Some(lookup) = callee_bytecode.get(1) {
565 if let Some(index) = callee_bytecode.get(2) {
566 match *lookup {
567 lookup::BUILTIN => {
568 let builtin = env.get_builtin((*index).into()).unwrap();
569
570 let call_arity: usize = expr_call.args.len();
571
572 if !builtin.arity_matches(call_arity.try_into().unwrap()) {
573 errs.push((
574 ExprError::CompileError(WrongNumberOfArgs {
575 expected: builtin.arity() as usize,
576 actual: call_arity,
577 }),
578 span.clone(),
579 ));
580 }
581 }
582 lookup::USER_BUILTIN => {
583 let builtin =
584 env.get_user_builtin((*index).into()).unwrap();
585
586 let call_arity: usize = expr_call.args.len();
587
588 if !builtin.arity_matches(call_arity.try_into().unwrap()) {
589 errs.push((
590 ExprError::CompileError(WrongNumberOfArgs {
591 expected: builtin.arity() as usize,
592 actual: call_arity,
593 }),
594 span.clone(),
595 ));
596 }
597 }
598 _ => {}
599 }
600 }
601 }
602 }
603
604 codes.extend(callee_bytecode);
605
606 for arg in expr_call.args.iter() {
607 match compile_expr(arg, env, strings) {
608 Ok(arg_bytecode) => {
609 codes.extend(arg_bytecode);
610 }
611 Err(err) => {
612 errs.extend(err);
613 }
614 }
615 }
616
617 codes.push(opcode::CALL);
618 codes.push(expr_call.args.len() as u8);
619 }
620 }
621 }
622 Expr::Bool(value) => match value.0 {
623 true => {
624 codes.push(opcode::TRUE);
625 }
626 false => {
627 codes.push(opcode::FALSE);
628 }
629 },
630 Expr::Error => panic!("tried to compile despite parser errors"),
631 }
632
633 if !errs.is_empty() {
634 return Err(errs);
635 }
636
637 Ok(codes)
638}
639
#[cfg(test)]
mod compiler_tests {
    use super::*;

    /// The header mirrors the crate version: major, minor, patch, then a zero byte.
    #[test]
    fn current_version_bytes() {
        // Derive the expectation from the full version string rather than hard-coding it,
        // so the test does not break on every release. Any pre-release/build suffix
        // (e.g. "1.2.3-beta+sha") is dropped first.
        let version = env!("CARGO_PKG_VERSION");
        let core = version.split(|c| c == '-' || c == '+').next().unwrap();
        let parts: Vec<u8> = core.split('.').map(|part| part.parse().unwrap()).collect();

        assert_eq!(get_version_bytes(), [parts[0], parts[1], parts[2], 0]);
    }

    #[test]
    fn valid_bytecode_version_bytes() {
        let mut codes = get_version_bytes().to_vec();
        codes.push(opcode::TRUE);

        let bytecode = ExprByteCode::new(codes, vec![]);

        // Only the instruction stream remains once the header is validated and stripped.
        assert_eq!(bytecode.codes(), &[opcode::TRUE]);
    }

    #[test]
    #[should_panic(expected = "Version bytes do not match")]
    fn invalid_bytecode_version_bytes() {
        // Flip the trailing header byte so the header differs from the current version,
        // whatever that version is.
        let mut codes = get_version_bytes().to_vec();
        codes[3] ^= 0xFF;
        codes.push(opcode::TRUE);

        ExprByteCode::new(codes, vec![]);
    }

    #[test]
    fn get_version_bytes_from_bytecode() {
        let mut codes = get_version_bytes().to_vec();
        codes.push(opcode::TRUE);

        let bytecode = ExprByteCode::new(codes, vec![]);

        assert_eq!(bytecode.version(), &get_version_bytes());
    }
}