1use crate::{
4 ast::{Expr, ExprS, IdentifierKind, add_type_to_expr},
5 builtins::BuiltinFn,
6 errors::{
7 CompileError::{self, WrongNumberOfArgs},
8 ExprError, ExprErrorS, ExprResult,
9 },
10 prelude::lookup::TYPE,
11 types::Type,
12};
13
/// Opcodes of the expression bytecode, numbered consecutively from 0.
pub mod opcode {
    /// Emitted after a call's callee and argument code; operand: argument count.
    pub const CALL: u8 = 0;
    /// Operands: a lookup kind (see the `lookup` module) and an index.
    pub const GET: u8 = 1;
    /// Operand: an index into the string constant pool.
    pub const CONSTANT: u8 = 2;
    /// Boolean literal `true`.
    pub const TRUE: u8 = 3;
    /// Boolean literal `false`.
    pub const FALSE: u8 = 4;
}
24
/// Lookup kinds: the first operand of a `GET` opcode, selecting which table
/// the second (index) operand refers to.
pub mod lookup {
    /// Index into the environment's builtins.
    pub const BUILTIN: u8 = 0;
    /// Index into the environment's variable names.
    pub const VAR: u8 = 1;
    /// Index into the environment's prompt names.
    pub const PROMPT: u8 = 2;
    /// Index into the environment's secret names.
    pub const SECRET: u8 = 3;
    /// Index into the environment's user builtins.
    pub const USER_BUILTIN: u8 = 4;
    /// Index into the environment's client-context keys.
    pub const CLIENT_CTX: u8 = 5;
    /// Index into the bytecode's type pool.
    pub const TYPE: u8 = 6;
}
42
/// Finds `identifier` in `list` and returns its position as a one-byte
/// bytecode operand.
///
/// Returns `None` when the identifier is absent, and also when its position
/// does not fit in a `u8`. A truncating `as u8` cast would make the emitted
/// code silently address a different entry.
fn get(list: &[String], identifier: &str) -> Option<u8> {
    let index = list.iter().position(|candidate| candidate == identifier)?;
    u8::try_from(index).ok()
}
47
/// Names known at compile time, used to resolve identifiers to
/// `GET <lookup kind> <index>` operands while compiling expressions.
#[derive(Debug)]
pub struct CompileTimeEnv {
    // Builtin functions; `Default` fills this from `BuiltinFn::DEFAULT_BUILTINS`.
    builtins: Vec<BuiltinFn<'static>>,
    // Functions registered via `add_user_builtin(s)`; identifier resolution
    // tries `builtins` first, then these.
    user_builtins: Vec<BuiltinFn<'static>>,
    // Variable names; a name's position is its `lookup::VAR` index.
    vars: Vec<String>,
    // Prompt names; a name's position is its `lookup::PROMPT` index.
    prompts: Vec<String>,
    // Secret names; a name's position is its `lookup::SECRET` index.
    secrets: Vec<String>,
    // Client-context keys; a key's position is its `lookup::CLIENT_CTX` index.
    client_context: Vec<String>,
}
57
58impl Default for CompileTimeEnv {
59 fn default() -> Self {
60 Self {
61 builtins: BuiltinFn::DEFAULT_BUILTINS.to_vec(),
62 user_builtins: vec![],
63 vars: vec![],
64 prompts: vec![],
65 secrets: vec![],
66 client_context: vec![],
67 }
68 }
69}
70
71impl CompileTimeEnv {
72 pub fn new(
73 vars: Vec<String>,
74 prompts: Vec<String>,
75 secrets: Vec<String>,
76 client_context: Vec<String>,
77 ) -> Self {
78 Self {
79 vars,
80 prompts,
81 secrets,
82 client_context,
83 ..Default::default()
84 }
85 }
86
87 pub fn get_builtin_index(&self, name: &str) -> Option<(&BuiltinFn, u8)> {
88 let index = self.builtins.iter().position(|x| x.name == name);
89
90 let result = index.map(|i| (self.builtins.get(i).unwrap(), i as u8));
91 result
92 }
93
94 pub fn get_user_builtin_index(&self, name: &str) -> Option<(&BuiltinFn, u8)> {
95 let index = self.user_builtins.iter().position(|x| x.name == name);
96
97 let result = index.map(|i| (self.user_builtins.get(i).unwrap(), i as u8));
98 result
99 }
100
101 pub fn add_user_builtins(&mut self, builtins: Vec<BuiltinFn<'static>>) {
102 for builtin in builtins {
103 self.add_user_builtin(builtin);
104 }
105 }
106
107 pub fn add_user_builtin(&mut self, builtin: BuiltinFn<'static>) {
108 self.user_builtins.push(builtin);
109 }
110
111 pub fn get_builtin(&self, index: usize) -> Option<&BuiltinFn<'static>> {
112 self.builtins.get(index)
113 }
114
115 pub fn get_user_builtin(&self, index: usize) -> Option<&BuiltinFn<'static>> {
116 self.user_builtins.get(index)
117 }
118
119 pub fn get_var(&self, index: usize) -> Option<&String> {
120 self.vars.get(index)
121 }
122
123 pub fn get_var_index(&self, name: &str) -> Option<usize> {
124 let index = self
125 .vars
126 .iter()
127 .position(|context_name| context_name == name);
128
129 index
130 }
131
132 pub fn get_prompt(&self, index: usize) -> Option<&String> {
133 self.prompts.get(index)
134 }
135
136 pub fn get_prompt_index(&self, name: &str) -> Option<usize> {
137 let index = self
138 .prompts
139 .iter()
140 .position(|context_name| context_name == name);
141
142 index
143 }
144
145 pub fn get_secret(&self, index: usize) -> Option<&String> {
146 self.secrets.get(index)
147 }
148
149 pub fn get_secret_index(&self, name: &str) -> Option<usize> {
150 let index = self
151 .secrets
152 .iter()
153 .position(|context_name| context_name == name);
154
155 index
156 }
157
158 pub fn get_client_context(&self, index: usize) -> Option<&String> {
159 self.client_context.get(index)
160 }
161
162 pub fn add_to_client_context(&mut self, key: &str) -> usize {
163 match self.client_context.iter().position(|x| x == key) {
164 Some(i) => i,
165 None => {
166 self.client_context.push(key.to_string());
167
168 self.client_context.len() - 1
169 }
170 }
171 }
172
173 pub fn add_keys_to_client_context(&mut self, keys: Vec<String>) {
174 self.client_context.extend(keys);
175 }
176
177 pub fn get_client_context_index(&self, name: &str) -> Option<(&String, u8)> {
178 let index = self
179 .client_context
180 .iter()
181 .position(|context_name| context_name == name);
182
183 let result = index.map(|i| (self.client_context.get(i).unwrap(), i as u8));
184 result
185 }
186}
187
/// Compiled form of an expression: the version header it was built with, the
/// opcode stream, and the constant pools that opcode operands index into.
#[derive(Debug, Clone, PartialEq)]
pub struct ExprByteCode {
    // Version header (major, minor, patch, 0), as produced by `get_version_bytes`.
    version: [u8; 4],
    // Opcode stream with the 4-byte version header stripped off.
    codes: Vec<u8>,
    // String constants; the operand of `CONSTANT` indexes into this pool.
    strings: Vec<String>,
    // Types; the index operand of `GET` with lookup kind `TYPE` points here.
    types: Vec<Type>,
}
196
197impl ExprByteCode {
198 pub fn new(codes: Vec<u8>, strings: Vec<String>, types: Vec<Type>) -> Self {
199 let version_bytes = get_version_bytes();
200 let version_bytes_from_codes = &codes[0..4];
201
202 assert_eq!(
203 version_bytes, version_bytes_from_codes,
204 "Version bytes do not match"
205 );
206
207 let codes = codes[4..].to_vec();
208
209 Self {
210 version: version_bytes,
211 codes,
212 strings,
213 types,
214 }
215 }
216
217 pub fn version(&self) -> &[u8; 4] {
218 &self.version
219 }
220
221 pub fn codes(&self) -> &[u8] {
222 &self.codes
223 }
224
225 pub fn get_code(&self, index: usize) -> Option<&u8> {
226 self.codes.get(index)
227 }
228
229 pub fn strings(&self) -> &[String] {
230 &self.strings
231 }
232
233 pub fn types(&self) -> &[Type] {
234 &self.types
235 }
236}
237
238pub fn get_version_bytes() -> [u8; 4] {
239 [
240 env!("CARGO_PKG_VERSION_MAJOR").parse().unwrap(),
241 env!("CARGO_PKG_VERSION_MINOR").parse().unwrap(),
242 env!("CARGO_PKG_VERSION_PATCH").parse().unwrap(),
243 0,
244 ]
245}
246
247pub fn compile(expr: &mut ExprS, env: &CompileTimeEnv) -> ExprResult<ExprByteCode> {
249 let mut strings: Vec<String> = vec![];
250 let mut types: Vec<Type> = vec![];
251 let mut codes = vec![];
252
253 codes.extend(get_version_bytes());
254
255 codes.extend(compile_expr(expr, env, &mut strings, &mut types)?);
256
257 Ok(ExprByteCode::new(codes, strings, types))
258}
259
260fn compile_expr(
261 (expr, span): &mut ExprS,
262 env: &CompileTimeEnv,
263 strings: &mut Vec<String>,
264 types: &mut Vec<Type>,
265) -> ExprResult<Vec<u8>> {
266 use opcode::*;
267
268 let mut codes = vec![];
269 let mut errs: Vec<ExprErrorS> = vec![];
270
271 add_type_to_expr(expr, env);
272
273 match expr {
274 Expr::String(string) => {
275 if let Some(index) = strings.iter().position(|x| x == &string.0) {
276 codes.push(CONSTANT);
277 codes.push(index as u8);
278 } else {
279 strings.push(string.0.clone());
280 let index = strings.len() - 1;
281 codes.push(CONSTANT);
282 codes.push(index as u8);
283 }
284 }
285 Expr::Identifier(identifier) => {
286 let identifier_lookup_name = identifier.lookup_name();
287 let identifier_name = identifier.full_name().to_string();
288
289 let identifier_undefined_err = (
290 CompileError::Undefined(identifier_name.clone()).into(),
291 span.clone(),
292 );
293
294 let result = match identifier.identifier_kind() {
295 IdentifierKind::Var => get(&env.vars, identifier_lookup_name).map(|index| {
296 codes.push(GET);
297 codes.push(lookup::VAR);
298 codes.push(index);
299 }),
300 IdentifierKind::Prompt => get(&env.prompts, identifier_lookup_name).map(|index| {
301 codes.push(GET);
302 codes.push(lookup::PROMPT);
303 codes.push(index);
304 }),
305 IdentifierKind::Secret => get(&env.secrets, identifier_lookup_name).map(|index| {
306 codes.push(GET);
307 codes.push(lookup::SECRET);
308 codes.push(index);
309 }),
310 IdentifierKind::Client => {
311 get(&env.client_context, identifier_lookup_name).map(|index| {
312 codes.push(GET);
313 codes.push(lookup::CLIENT_CTX);
314 codes.push(index);
315 })
316 }
317 IdentifierKind::Builtin => {
318 if let Some((_, index)) = env.get_builtin_index(identifier_lookup_name) {
319 codes.push(GET);
320 codes.push(lookup::BUILTIN);
321 codes.push(index);
322
323 Some(())
324 } else if let Some((_, index)) =
325 env.get_user_builtin_index(identifier_lookup_name)
326 {
327 codes.push(GET);
328 codes.push(lookup::USER_BUILTIN);
329 codes.push(index);
330
331 Some(())
332 } else {
333 None
334 }
335 }
336 IdentifierKind::Type => {
337 let ty = Type::from(&identifier_name);
338 if let Some(index) = types.iter().position(|x| x == &ty) {
339 codes.push(GET);
340 codes.push(TYPE);
341 codes.push(index as u8);
342 } else {
343 types.push(ty);
344 let index = types.len() - 1;
345 codes.push(GET);
346 codes.push(TYPE);
347 codes.push(index as u8);
348 }
349
350 Some(())
351 }
352 };
353
354 if let None = result {
355 errs.push(identifier_undefined_err);
356 }
357 }
358 Expr::Call(expr_call) => {
359 let callee_bytecode = compile_expr(&mut expr_call.callee, env, strings, types)?;
360
361 if let Some(_op) = callee_bytecode.first()
362 && let Some(lookup) = callee_bytecode.get(1)
363 && let Some(index) = callee_bytecode.get(2)
364 {
365 match *lookup {
366 lookup::BUILTIN => {
367 let builtin = env.get_builtin((*index).into()).unwrap();
368
369 let call_arity: usize = expr_call.args.len();
370
371 if !builtin.arity_matches(call_arity.try_into().unwrap()) {
372 errs.push((
373 ExprError::CompileError(WrongNumberOfArgs {
374 expected: builtin.arity() as usize,
375 actual: call_arity,
376 }),
377 span.clone(),
378 ));
379 }
380
381 let args: Vec<_> = expr_call.args.iter().take(call_arity).collect();
382
383 for (i, fnarg) in builtin.args.iter().enumerate() {
384 if let Some((a, a_span)) = args.get(i) {
385 let a_type = a.get_type();
386
387 let types_match = fnarg.ty == a_type
388 || fnarg.ty == Type::Value
389 || a_type == Type::Unknown;
390
391 if !types_match {
392 errs.push((
393 CompileError::TypeMismatch {
394 expected: fnarg.ty.clone(),
395 actual: a_type.clone(),
396 }
397 .into(),
398 a_span.clone(),
399 ));
400 }
401 }
402 }
403 }
404 lookup::USER_BUILTIN => {
405 let builtin = env.get_user_builtin((*index).into()).unwrap();
406
407 let call_arity: usize = expr_call.args.len();
408
409 if !builtin.arity_matches(call_arity.try_into().unwrap()) {
410 errs.push((
411 ExprError::CompileError(WrongNumberOfArgs {
412 expected: builtin.arity() as usize,
413 actual: call_arity,
414 }),
415 span.clone(),
416 ));
417 }
418 }
419 lookup::CLIENT_CTX => {
420 }
424 _ => {
425 errs.push((
426 CompileError::InvalidLookupType(*lookup).into(),
427 span.clone(),
428 ));
429 }
430 }
431 }
432
433 codes.extend(callee_bytecode);
434
435 for arg in expr_call.args.iter_mut() {
436 match compile_expr(arg, env, strings, types) {
437 Ok(arg_bytecode) => {
438 codes.extend(arg_bytecode);
439 }
440 Err(err) => {
441 errs.extend(err);
442 }
443 }
444 }
445
446 codes.push(opcode::CALL);
447 codes.push(expr_call.args.len() as u8);
448 }
449 Expr::Bool(value) => match value.0 {
450 true => {
451 codes.push(opcode::TRUE);
452 }
453 false => {
454 codes.push(opcode::FALSE);
455 }
456 },
457 Expr::Error => panic!("tried to compile despite parser errors"),
458 }
459
460 if !errs.is_empty() {
461 return Err(errs);
462 }
463
464 Ok(codes)
465}
466
#[cfg(test)]
mod compiler_tests {
    use super::*;

    /// Raw code stream: the current version header followed by a lone `TRUE`.
    fn versioned_true() -> Vec<u8> {
        let mut raw = get_version_bytes().to_vec();
        raw.push(opcode::TRUE);
        raw
    }

    #[test]
    pub fn current_version_bytes() {
        assert_eq!(get_version_bytes(), [0, 7, 0, 0]);
    }

    #[test]
    pub fn valid_bytecode_version_bytes() {
        ExprByteCode::new(versioned_true(), vec![], vec![]);
    }

    #[test]
    #[should_panic(expected = "Version bytes do not match")]
    pub fn invalid_bytecode_version_bytes() {
        let raw: Vec<u8> = vec![0, 0, 0, 0, opcode::TRUE];
        ExprByteCode::new(raw, vec![], vec![]);
    }

    #[test]
    pub fn get_version_bytes_from_bytecode() {
        let bytecode = ExprByteCode::new(versioned_true(), vec![], vec![]);
        assert_eq!(bytecode.version(), &get_version_bytes());
    }
}