1mod calls;
2mod classify;
3mod expr;
4mod patterns;
5
6use std::collections::{HashMap, HashSet};
7
8use crate::ast::{FnBody, FnDef, Stmt, TopLevel, TypeDef};
9use crate::nan_value::{Arena, NanValue};
10use crate::source::find_module_file;
11use crate::types::{option, result};
12
13use super::builtin::VmBuiltin;
14use super::opcode::*;
15use super::symbol::{VmSymbolTable, VmVariantCtor};
16use super::types::{CodeStore, FnChunk};
17
18pub fn compile_program(
21 items: &[TopLevel],
22 arena: &mut Arena,
23) -> Result<(CodeStore, Vec<NanValue>), CompileError> {
24 compile_program_with_modules(items, arena, None, "")
25}
26
27pub fn compile_program_with_modules(
29 items: &[TopLevel],
30 arena: &mut Arena,
31 module_root: Option<&str>,
32 source_file: &str,
33) -> Result<(CodeStore, Vec<NanValue>), CompileError> {
34 let mut compiler = ProgramCompiler::new();
35 compiler.source_file = source_file.to_string();
36 compiler.sync_record_field_symbols(arena);
37
38 if let Some(module_root) = module_root {
39 compiler.load_modules(items, module_root, arena)?;
40 }
41
42 for item in items {
43 if let TopLevel::Stmt(Stmt::Binding(name, _, _)) = item {
44 compiler.ensure_global(name);
45 }
46 }
47
48 for item in items {
49 match item {
50 TopLevel::FnDef(fndef) => {
51 compiler.ensure_global(&fndef.name);
52 let effect_ids: Vec<u32> = fndef
53 .effects
54 .iter()
55 .map(|effect| compiler.symbols.intern_name(&effect.node))
56 .collect();
57 let fn_id = compiler.code.add_function(FnChunk {
58 name: fndef.name.clone(),
59 arity: fndef.params.len() as u8,
60 local_count: 0,
61 code: Vec::new(),
62 constants: Vec::new(),
63 effects: effect_ids,
64 thin: false,
65 parent_thin: false,
66 leaf: false,
67 source_file: String::new(),
68 line_table: Vec::new(),
69 });
70 let symbol_id = compiler.symbols.intern_function(
71 &fndef.name,
72 fn_id,
73 &fndef
74 .effects
75 .iter()
76 .map(|e| e.node.clone())
77 .collect::<Vec<_>>(),
78 );
79 let global_idx = compiler.global_names[&fndef.name];
80 compiler.globals[global_idx as usize] = VmSymbolTable::symbol_ref(symbol_id);
81 }
82 TopLevel::TypeDef(td) => {
83 compiler.register_type_def(td, arena);
84 }
85 _ => {}
86 }
87 }
88
89 compiler.register_current_module_namespace(items);
90
91 for item in items {
92 if let TopLevel::FnDef(fndef) = item {
93 let fn_id = compiler.code.find(&fndef.name).unwrap();
94 let chunk = compiler.compile_fn(fndef, arena)?;
95 compiler.code.functions[fn_id as usize] = chunk;
96 }
97 }
98
99 compiler.compile_top_level(items, arena)?;
100 compiler.code.symbols = compiler.symbols.clone();
101 classify::classify_thin_functions(&mut compiler.code, arena)?;
102
103 Ok((compiler.code, compiler.globals))
104}
105
/// Error produced while compiling a program to VM bytecode.
#[derive(Debug)]
pub struct CompileError {
    /// Human-readable description of what went wrong.
    pub msg: String,
}

impl std::fmt::Display for CompileError {
    /// Renders the error as `Compile error: <msg>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "Compile error: {}", self.msg)
    }
}

/// Makes `CompileError` usable wherever a standard error is expected, e.g. as
/// `Box<dyn std::error::Error>` or as the source of another error.
impl std::error::Error for CompileError {}
116
/// Program-wide compilation state shared by every function compiled in one run.
struct ProgramCompiler {
    /// Function chunks added so far, plus the record-field registrations.
    code: CodeStore,
    /// Symbol table holding interned names, functions, namespaces, builtins and variant
    /// constructors.
    symbols: VmSymbolTable,
    /// Initial value of every global slot; indexed by the slot numbers stored in
    /// `global_names`.
    globals: Vec<NanValue>,
    /// Maps a global's name to its slot index in `globals`.
    global_names: HashMap<String, u16>,
    /// Source file name copied into each function compiler created by
    /// `compile_fn_with_scope`.
    source_file: String,
}
125
126impl ProgramCompiler {
127 fn new() -> Self {
128 let mut compiler = ProgramCompiler {
129 code: CodeStore::new(),
130 symbols: VmSymbolTable::default(),
131 globals: Vec::new(),
132 global_names: HashMap::new(),
133 source_file: String::new(),
134 };
135 compiler.bootstrap_core_symbols();
136 compiler
137 }
138
139 fn sync_record_field_symbols(&mut self, arena: &Arena) {
140 for type_id in 0..arena.type_count() {
141 let type_name = arena.get_type_name(type_id);
142 self.symbols.intern_namespace_path(type_name);
143 let field_names = arena.get_field_names(type_id);
144 if field_names.is_empty() {
145 continue;
146 }
147 let field_symbol_ids: Vec<u32> = field_names
148 .iter()
149 .map(|field_name| self.symbols.intern_name(field_name))
150 .collect();
151 self.code.register_record_fields(type_id, &field_symbol_ids);
152 }
153 }
154
155 fn load_modules(
160 &mut self,
161 items: &[TopLevel],
162 module_root: &str,
163 arena: &mut Arena,
164 ) -> Result<(), CompileError> {
165 let module = items.iter().find_map(|i| {
166 if let TopLevel::Module(m) = i {
167 Some(m)
168 } else {
169 None
170 }
171 });
172 let module = match module {
173 Some(m) => m,
174 None => return Ok(()),
175 };
176
177 let mut loaded = HashSet::new();
178 for dep_name in &module.depends {
179 self.load_module_recursive(dep_name, module_root, arena, &mut loaded)?;
180 }
181 Ok(())
182 }
183
184 fn load_module_recursive(
185 &mut self,
186 dep_name: &str,
187 module_root: &str,
188 arena: &mut Arena,
189 loaded: &mut HashSet<String>,
190 ) -> Result<(), CompileError> {
191 if loaded.contains(dep_name) {
192 return Ok(());
193 }
194 loaded.insert(dep_name.to_string());
195
196 let file_path = find_module_file(dep_name, module_root).ok_or_else(|| CompileError {
197 msg: format!("module '{}' not found (root: {})", dep_name, module_root),
198 })?;
199
200 let source = std::fs::read_to_string(&file_path).map_err(|e| CompileError {
201 msg: format!("cannot read module '{}': {}", dep_name, e),
202 })?;
203
204 let mut mod_items = crate::source::parse_source(&source).map_err(|e| CompileError {
205 msg: format!("parse error in module '{}': {}", dep_name, e),
206 })?;
207
208 crate::tco::transform_program(&mut mod_items);
209 crate::resolver::resolve_program(&mut mod_items);
210
211 if let Some(sub_module) = mod_items.iter().find_map(|i| {
212 if let TopLevel::Module(m) = i {
213 Some(m)
214 } else {
215 None
216 }
217 }) {
218 for sub_dep in &sub_module.depends {
219 self.load_module_recursive(sub_dep, module_root, arena, loaded)?;
220 }
221 }
222
223 for item in &mod_items {
224 if let TopLevel::TypeDef(td) = item {
225 self.register_type_def(td, arena);
226 }
227 }
228
229 let exposes: Option<Vec<String>> = mod_items.iter().find_map(|i| {
230 if let TopLevel::Module(m) = i {
231 if m.exposes.is_empty() {
232 None
233 } else {
234 Some(m.exposes.clone())
235 }
236 } else {
237 None
238 }
239 });
240
241 let mut module_fn_ids: Vec<(String, u32)> = Vec::new();
242 for item in &mod_items {
243 if let TopLevel::FnDef(fndef) = item {
244 let qualified_name = format!("{}.{}", dep_name, fndef.name);
245 let effect_ids: Vec<u32> = fndef
246 .effects
247 .iter()
248 .map(|effect| self.symbols.intern_name(&effect.node))
249 .collect();
250 let fn_id = self.code.add_function(FnChunk {
251 name: qualified_name.clone(),
252 arity: fndef.params.len() as u8,
253 local_count: 0,
254 code: Vec::new(),
255 constants: Vec::new(),
256 effects: effect_ids,
257 thin: false,
258 parent_thin: false,
259 leaf: false,
260 source_file: String::new(),
261 line_table: Vec::new(),
262 });
263 self.symbols.intern_function(
264 &qualified_name,
265 fn_id,
266 &fndef
267 .effects
268 .iter()
269 .map(|e| e.node.clone())
270 .collect::<Vec<_>>(),
271 );
272 module_fn_ids.push((fndef.name.clone(), fn_id));
273 }
274 }
275
276 let module_scope: HashMap<String, u32> = module_fn_ids.iter().cloned().collect();
277
278 let mut fn_idx = 0;
279 for item in &mod_items {
280 if let TopLevel::FnDef(fndef) = item {
281 let (_, fn_id) = module_fn_ids[fn_idx];
282 let chunk = self.compile_fn_with_scope(fndef, arena, &module_scope)?;
283 self.code.functions[fn_id as usize] = chunk;
284 fn_idx += 1;
285 }
286 }
287
288 for (fn_name, _fn_id) in &module_fn_ids {
289 let exposed = match &exposes {
290 Some(list) => list.iter().any(|e| e == fn_name),
291 None => !fn_name.starts_with('_'),
292 };
293 if exposed {
294 let qualified = format!("{}.{}", dep_name, fn_name);
295 let global_idx = self.ensure_global(&qualified);
296 let symbol_id = self.symbols.find(&qualified).ok_or_else(|| CompileError {
297 msg: format!("missing VM symbol for exposed function {}", qualified),
298 })?;
299 self.globals[global_idx as usize] = VmSymbolTable::symbol_ref(symbol_id);
300 }
301 }
302
303 let module_symbol_id = self.symbols.intern_namespace_path(dep_name);
304 for item in &mod_items {
305 if let TopLevel::TypeDef(td) = item {
306 let type_name = match td {
307 TypeDef::Sum { name, .. } | TypeDef::Product { name, .. } => name,
308 };
309 let exposed = match &exposes {
310 Some(list) => list.iter().any(|e| e == type_name),
311 None => !type_name.starts_with('_'),
312 };
313 if exposed && let Some(type_symbol_id) = self.symbols.find(type_name) {
314 let member_symbol_id = self.symbols.intern_name(type_name);
315 self.symbols.add_namespace_member_by_id(
316 module_symbol_id,
317 member_symbol_id,
318 VmSymbolTable::symbol_ref(type_symbol_id),
319 );
320 }
321 }
322 }
323
324 for (fn_name, _fn_id) in &module_fn_ids {
325 let exposed = match &exposes {
326 Some(list) => list.iter().any(|e| e == fn_name),
327 None => !fn_name.starts_with('_'),
328 };
329 if exposed {
330 let qualified = format!("{}.{}", dep_name, fn_name);
331 if let Some(fn_symbol_id) = self.symbols.find(&qualified) {
332 let member_symbol_id = self.symbols.intern_name(fn_name);
333 self.symbols.add_namespace_member_by_id(
334 module_symbol_id,
335 member_symbol_id,
336 VmSymbolTable::symbol_ref(fn_symbol_id),
337 );
338 }
339 }
340 }
341
342 Ok(())
343 }
344
345 fn ensure_global(&mut self, name: &str) -> u16 {
346 if let Some(&idx) = self.global_names.get(name) {
347 return idx;
348 }
349 let idx = self.globals.len() as u16;
350 self.global_names.insert(name.to_string(), idx);
351 self.globals.push(NanValue::UNIT);
352 idx
353 }
354
355 fn register_type_def(&mut self, td: &TypeDef, arena: &mut Arena) {
356 match td {
357 TypeDef::Product { name, fields, .. } => {
358 self.symbols.intern_namespace_path(name);
359 let field_names: Vec<String> = fields.iter().map(|(n, _)| n.clone()).collect();
360 let field_symbol_ids: Vec<u32> = field_names
361 .iter()
362 .map(|field_name| self.symbols.intern_name(field_name))
363 .collect();
364 let type_id = arena.register_record_type(name, field_names);
365 self.code.register_record_fields(type_id, &field_symbol_ids);
366 }
367 TypeDef::Sum { name, variants, .. } => {
368 let type_symbol_id = self.symbols.intern_namespace_path(name);
369 let variant_names: Vec<String> = variants.iter().map(|v| v.name.clone()).collect();
370 let type_id = arena.register_sum_type(name, variant_names);
371 for (variant_id, variant) in variants.iter().enumerate() {
372 let ctor_id = arena
373 .find_ctor_id(type_id, variant_id as u16)
374 .expect("ctor id");
375 let qualified_name = format!("{}.{}", name, variant.name);
376 let ctor_symbol_id = self.symbols.intern_variant_ctor(
377 &qualified_name,
378 VmVariantCtor {
379 type_id,
380 variant_id: variant_id as u16,
381 ctor_id,
382 field_count: variant.fields.len() as u8,
383 },
384 );
385 let member_symbol_id = self.symbols.intern_name(&variant.name);
386 self.symbols.add_namespace_member_by_id(
387 type_symbol_id,
388 member_symbol_id,
389 VmSymbolTable::symbol_ref(ctor_symbol_id),
390 );
391 }
392 }
393 }
394 }
395
    /// Seeds the symbol table with the symbols every program can use: all VM builtins plus
    /// the `Result` and `Option` namespaces and their members.
    ///
    /// The statements below intern symbols in a fixed sequence; keep that sequence intact
    /// when editing, since later code refers to symbols by the ids handed out here.
    fn bootstrap_core_symbols(&mut self) {
        // Every builtin is interned; builtins named `Namespace.member` are additionally
        // added as a member of that namespace.
        for builtin in VmBuiltin::ALL.iter().copied() {
            let builtin_symbol_id = self.symbols.intern_builtin(builtin);
            if let Some((namespace, member)) = builtin.name().split_once('.') {
                let namespace_symbol_id = self.symbols.intern_namespace_path(namespace);
                let member_symbol_id = self.symbols.intern_name(member);
                self.symbols.add_namespace_member_by_id(
                    namespace_symbol_id,
                    member_symbol_id,
                    VmSymbolTable::symbol_ref(builtin_symbol_id),
                );
            }
        }

        // `Result` namespace with its `Ok` and `Err` wrappers.
        // NOTE(review): the numeric arguments to `intern_wrapper` (0, 1, and 2 for
        // `Option.Some` below) presumably identify the wrapper kind — confirm against
        // `VmSymbolTable`.
        let result_symbol_id = self.symbols.intern_namespace_path("Result");
        let ok_symbol_id = self.symbols.intern_wrapper("Result.Ok", 0);
        let err_symbol_id = self.symbols.intern_wrapper("Result.Err", 1);
        let ok_member_symbol_id = self.symbols.intern_name("Ok");
        self.symbols.add_namespace_member_by_id(
            result_symbol_id,
            ok_member_symbol_id,
            VmSymbolTable::symbol_ref(ok_symbol_id),
        );
        let err_member_symbol_id = self.symbols.intern_name("Err");
        self.symbols.add_namespace_member_by_id(
            result_symbol_id,
            err_member_symbol_id,
            VmSymbolTable::symbol_ref(err_symbol_id),
        );
        // Extra `Result` members map a member name onto an already interned symbol; entries
        // whose target symbol is not found are skipped.
        for (member, builtin_name) in result::extra_members() {
            if let Some(symbol_id) = self.symbols.find(&builtin_name) {
                let member_symbol_id = self.symbols.intern_name(member);
                self.symbols.add_namespace_member_by_id(
                    result_symbol_id,
                    member_symbol_id,
                    VmSymbolTable::symbol_ref(symbol_id),
                );
            }
        }

        // `Option` namespace: `Some` is a wrapper, `None` is a constant.
        let option_symbol_id = self.symbols.intern_namespace_path("Option");
        let some_symbol_id = self.symbols.intern_wrapper("Option.Some", 2);
        self.symbols.intern_constant("Option.None", NanValue::NONE);
        let some_member_symbol_id = self.symbols.intern_name("Some");
        self.symbols.add_namespace_member_by_id(
            option_symbol_id,
            some_member_symbol_id,
            VmSymbolTable::symbol_ref(some_symbol_id),
        );
        // Unlike the other members, `None` is stored as the value `NanValue::NONE` itself
        // rather than as a symbol reference.
        let none_member_symbol_id = self.symbols.intern_name("None");
        self.symbols.add_namespace_member_by_id(
            option_symbol_id,
            none_member_symbol_id,
            NanValue::NONE,
        );
        // Extra `Option` members, handled the same way as the extra `Result` members.
        for (member, builtin_name) in option::extra_members() {
            if let Some(symbol_id) = self.symbols.find(&builtin_name) {
                let member_symbol_id = self.symbols.intern_name(member);
                self.symbols.add_namespace_member_by_id(
                    option_symbol_id,
                    member_symbol_id,
                    VmSymbolTable::symbol_ref(symbol_id),
                );
            }
        }
    }
462
463 fn compile_fn(&mut self, fndef: &FnDef, arena: &mut Arena) -> Result<FnChunk, CompileError> {
464 let empty_scope = HashMap::new();
465 self.compile_fn_with_scope(fndef, arena, &empty_scope)
466 }
467
468 fn compile_fn_with_scope(
469 &mut self,
470 fndef: &FnDef,
471 arena: &mut Arena,
472 module_scope: &HashMap<String, u32>,
473 ) -> Result<FnChunk, CompileError> {
474 let resolution = fndef.resolution.as_ref();
475 let local_count = resolution.map_or(fndef.params.len() as u16, |r| r.local_count);
476 let local_slots: HashMap<String, u16> = resolution
477 .map(|r| r.local_slots.as_ref().clone())
478 .unwrap_or_else(|| {
479 fndef
480 .params
481 .iter()
482 .enumerate()
483 .map(|(i, (name, _))| (name.clone(), i as u16))
484 .collect()
485 });
486
487 let mut fc = FnCompiler::new(
488 &fndef.name,
489 fndef.params.len() as u8,
490 local_count,
491 fndef
492 .effects
493 .iter()
494 .map(|effect| self.symbols.intern_name(&effect.node))
495 .collect(),
496 local_slots,
497 &self.global_names,
498 module_scope,
499 &self.code,
500 &mut self.symbols,
501 arena,
502 );
503 fc.source_file = self.source_file.clone();
504 fc.note_line(fndef.line);
505
506 match fndef.body.as_ref() {
507 FnBody::Block(stmts) => fc.compile_body(stmts)?,
508 }
509
510 Ok(fc.finish())
511 }
512
513 fn compile_top_level(
514 &mut self,
515 items: &[TopLevel],
516 arena: &mut Arena,
517 ) -> Result<(), CompileError> {
518 let has_stmts = items.iter().any(|i| matches!(i, TopLevel::Stmt(_)));
519 if !has_stmts {
520 return Ok(());
521 }
522
523 for item in items {
524 if let TopLevel::Stmt(Stmt::Binding(name, _, _)) = item {
525 self.ensure_global(name);
526 }
527 }
528
529 let empty_mod_scope = HashMap::new();
530 let mut fc = FnCompiler::new(
531 "__top_level__",
532 0,
533 0,
534 Vec::new(),
535 HashMap::new(),
536 &self.global_names,
537 &empty_mod_scope,
538 &self.code,
539 &mut self.symbols,
540 arena,
541 );
542
543 for item in items {
544 if let TopLevel::Stmt(stmt) = item {
545 match stmt {
546 Stmt::Binding(name, _type_ann, expr) => {
547 fc.compile_expr(expr)?;
548 let idx = self.global_names[name.as_str()];
549 fc.emit_op(STORE_GLOBAL);
550 fc.emit_u16(idx);
551 }
552 Stmt::Expr(expr) => {
553 fc.compile_expr(expr)?;
554 fc.emit_op(POP);
555 }
556 }
557 }
558 }
559
560 fc.emit_op(LOAD_UNIT);
561 fc.emit_op(RETURN);
562
563 let chunk = fc.finish();
564 self.code.add_function(chunk);
565 Ok(())
566 }
567
568 fn register_current_module_namespace(&mut self, items: &[TopLevel]) {
569 let Some(module) = items.iter().find_map(|item| {
570 if let TopLevel::Module(module) = item {
571 Some(module)
572 } else {
573 None
574 }
575 }) else {
576 return;
577 };
578
579 let module_symbol_id = self.symbols.intern_namespace_path(&module.name);
580
581 for item in items {
582 match item {
583 TopLevel::FnDef(fndef) => {
584 let exposed = if module.exposes.is_empty() {
585 !fndef.name.starts_with('_')
586 } else {
587 module.exposes.iter().any(|name| name == &fndef.name)
588 };
589 if exposed && let Some(symbol_id) = self.symbols.find(&fndef.name) {
590 let member_symbol_id = self.symbols.intern_name(&fndef.name);
591 self.symbols.add_namespace_member_by_id(
592 module_symbol_id,
593 member_symbol_id,
594 VmSymbolTable::symbol_ref(symbol_id),
595 );
596 }
597 }
598 TopLevel::TypeDef(TypeDef::Product { name, .. } | TypeDef::Sum { name, .. }) => {
599 let exposed = if module.exposes.is_empty() {
600 !name.starts_with('_')
601 } else {
602 module.exposes.iter().any(|member| member == name)
603 };
604 if exposed && let Some(symbol_id) = self.symbols.find(name) {
605 let member_symbol_id = self.symbols.intern_name(name);
606 self.symbols.add_namespace_member_by_id(
607 module_symbol_id,
608 member_symbol_id,
609 VmSymbolTable::symbol_ref(symbol_id),
610 );
611 }
612 }
613 _ => {}
614 }
615 }
616 }
617}
618
/// Kinds of callee a call can refer to.
///
/// NOTE(review): this enum is neither constructed nor matched in this file; the per-variant
/// notes below are inferred from the variant names and payload types — confirm against the
/// code that uses it.
enum CallTarget {
    /// Presumably a compiled function, identified by its function id.
    KnownFn(u32),
    /// Presumably one of the wrapper constructors, identified by its wrapper kind.
    Wrapper(u8),
    /// Presumably the `None` value; the trailing underscore keeps it distinct from
    /// `Option::None`.
    None_,
    /// Presumably a sum-type variant constructor: a type id and a variant id.
    Variant(u32, u16),
    /// A VM builtin.
    Builtin(VmBuiltin),
    /// Presumably a qualified name that matched none of the other kinds.
    UnknownQualified(String),
}
634
/// Per-function compilation state: accumulates the bytecode, constants and line table of one
/// function while borrowing the program-wide tables.
struct FnCompiler<'a> {
    /// Function name; copied into the finished chunk.
    name: String,
    /// Parameter count; copied into the finished chunk.
    arity: u8,
    /// Number of local slots; copied into the finished chunk.
    local_count: u16,
    /// Interned symbol ids of the function's effects; copied into the finished chunk.
    effects: Vec<u32>,
    /// Maps a local variable name to its slot index.
    local_slots: HashMap<String, u16>,
    /// Program-wide map from global name to global slot index.
    global_names: &'a HashMap<String, u16>,
    /// Maps the unqualified names of functions in the module being compiled to their
    /// function ids; empty for functions compiled without a module scope.
    module_scope: &'a HashMap<String, u32>,
    /// The program's code store, borrowed read-only.
    code_store: &'a CodeStore,
    /// The program's symbol table, borrowed mutably.
    symbols: &'a mut VmSymbolTable,
    /// The value arena, borrowed mutably.
    arena: &'a mut Arena,
    /// Bytecode emitted so far.
    code: Vec<u8>,
    /// Constant pool; `add_constant` reuses an existing entry with identical bits.
    constants: Vec<NanValue>,
    /// Offset in `code` of the most recently emitted opcode; `usize::MAX` until the first
    /// opcode is emitted. Used by the peephole rewrites in `emit_op`.
    last_op_pos: usize,
    /// Source file name; copied into the finished chunk.
    source_file: String,
    /// `(code offset, source line)` pairs recorded by `note_line`.
    line_table: Vec<(u16, u16)>,
    /// Line most recently recorded by `note_line`; `0` until a line is recorded.
    last_noted_line: u16,
}
659
660impl<'a> FnCompiler<'a> {
661 #[allow(clippy::too_many_arguments)]
662 fn new(
663 name: &str,
664 arity: u8,
665 local_count: u16,
666 effects: Vec<u32>,
667 local_slots: HashMap<String, u16>,
668 global_names: &'a HashMap<String, u16>,
669 module_scope: &'a HashMap<String, u32>,
670 code_store: &'a CodeStore,
671 symbols: &'a mut VmSymbolTable,
672 arena: &'a mut Arena,
673 ) -> Self {
674 FnCompiler {
675 name: name.to_string(),
676 arity,
677 local_count,
678 effects,
679 local_slots,
680 global_names,
681 module_scope,
682 code_store,
683 symbols,
684 arena,
685 code: Vec::new(),
686 constants: Vec::new(),
687 last_op_pos: usize::MAX,
688 source_file: String::new(),
689 line_table: Vec::new(),
690 last_noted_line: 0,
691 }
692 }
693
694 fn finish(self) -> FnChunk {
695 FnChunk {
696 name: self.name,
697 arity: self.arity,
698 local_count: self.local_count,
699 code: self.code,
700 constants: self.constants,
701 effects: self.effects,
702 thin: false,
703 parent_thin: false,
704 leaf: false,
705 source_file: self.source_file,
706 line_table: self.line_table,
707 }
708 }
709
710 fn note_line(&mut self, line: usize) {
714 if line == 0 {
715 return;
716 }
717 let line16 = line as u16;
718 if line16 == self.last_noted_line {
719 return; }
721 self.last_noted_line = line16;
722 self.line_table.push((self.code.len() as u16, line16));
723 }
724
    /// Emits opcode `op`, applying peephole fusions against the bytes already emitted.
    ///
    /// When a fusion applies, an earlier opcode byte is rewritten in place and `op` itself
    /// is not appended. Otherwise `op` is appended and `last_op_pos` is set to its offset.
    fn emit_op(&mut self, op: u8) {
        // Opcode at the last recorded opcode position, or 0xFF when that position is not
        // inside `code` (e.g. before the first opcode has been emitted).
        let prev_pos = self.last_op_pos;
        let prev_op = if prev_pos < self.code.len() {
            self.code[prev_pos]
        } else {
            0xFF
        };

        // LOAD_LOCAL followed by exactly one byte, then another LOAD_LOCAL: rewrite the
        // first opcode to LOAD_LOCAL_2 and drop the second opcode. `last_op_pos` is left
        // unchanged.
        if op == LOAD_LOCAL && prev_op == LOAD_LOCAL && prev_pos + 2 == self.code.len() {
            self.code[prev_pos] = LOAD_LOCAL_2;
            return;
        }
        // LOAD_LOCAL followed by exactly one byte, then LOAD_CONST: rewrite the first opcode
        // to LOAD_LOCAL_CONST and drop the LOAD_CONST opcode.
        if op == LOAD_CONST && prev_op == LOAD_LOCAL && prev_pos + 2 == self.code.len() {
            self.code[prev_pos] = LOAD_LOCAL_CONST;
            return;
        }
        // UNWRAP_OR when the last four bytes are VECTOR_GET, LOAD_CONST, hi, lo: replace
        // them with the three bytes VECTOR_GET_OR, hi, lo and drop the UNWRAP_OR opcode.
        // NOTE(review): this matches raw bytes at fixed distances from the end and does not
        // check that those positions hold opcodes rather than operand bytes — confirm that
        // no other instruction can end with the same byte pattern.
        if op == UNWRAP_OR && self.code.len() >= 4 {
            let len = self.code.len();
            if self.code[len - 4] == VECTOR_GET && self.code[len - 3] == LOAD_CONST {
                let hi = self.code[len - 2];
                let lo = self.code[len - 1];
                self.code[len - 4] = VECTOR_GET_OR;
                self.code[len - 3] = hi;
                self.code[len - 2] = lo;
                // The fused form is one byte shorter than the original sequence.
                self.code.pop();
                self.last_op_pos = len - 4;
                return;
            }
        }
        self.last_op_pos = self.code.len();
        self.code.push(op);
    }
764
765 fn emit_u8(&mut self, val: u8) {
766 self.code.push(val);
767 }
768
769 fn emit_u16(&mut self, val: u16) {
770 self.code.push((val >> 8) as u8);
771 self.code.push((val & 0xFF) as u8);
772 }
773
774 fn emit_i16(&mut self, val: i16) {
775 self.emit_u16(val as u16);
776 }
777
778 fn emit_u32(&mut self, val: u32) {
779 self.code.push((val >> 24) as u8);
780 self.code.push(((val >> 16) & 0xFF) as u8);
781 self.code.push(((val >> 8) & 0xFF) as u8);
782 self.code.push((val & 0xFF) as u8);
783 }
784
785 fn emit_u64(&mut self, val: u64) {
786 self.code.extend_from_slice(&val.to_be_bytes());
787 }
788
789 fn add_constant(&mut self, val: NanValue) -> u16 {
790 for (i, c) in self.constants.iter().enumerate() {
791 if c.bits() == val.bits() {
792 return i as u16;
793 }
794 }
795 let idx = self.constants.len() as u16;
796 self.constants.push(val);
797 idx
798 }
799
800 fn offset(&self) -> usize {
801 self.code.len()
802 }
803
804 fn emit_jump(&mut self, op: u8) -> usize {
805 self.emit_op(op);
806 let patch_pos = self.code.len();
807 self.emit_i16(0);
808 patch_pos
809 }
810
811 fn patch_jump(&mut self, patch_pos: usize) {
812 let target = self.code.len();
813 let offset = (target as isize - patch_pos as isize - 2) as i16;
814 let bytes = (offset as u16).to_be_bytes();
815 self.code[patch_pos] = bytes[0];
816 self.code[patch_pos + 1] = bytes[1];
817 }
818
819 fn patch_jump_to(&mut self, patch_pos: usize, target: usize) {
820 let offset = (target as isize - patch_pos as isize - 2) as i16;
821 let bytes = (offset as u16).to_be_bytes();
822 self.code[patch_pos] = bytes[0];
823 self.code[patch_pos + 1] = bytes[1];
824 }
825
826 fn bind_top_to_local(&mut self, name: &str) {
827 if let Some(&slot) = self.local_slots.get(name) {
828 self.emit_op(STORE_LOCAL);
829 self.emit_u8(slot as u8);
830 } else {
831 self.emit_op(POP);
832 }
833 }
834
835 fn dup_and_bind_top_to_local(&mut self, name: &str) {
836 self.emit_op(DUP);
837 self.bind_top_to_local(name);
838 }
839}
840
#[cfg(test)]
mod tests {
    use super::compile_program;
    use crate::nan_value::Arena;
    use crate::source::parse_source;
    use crate::vm::opcode::{LT, NOT, VECTOR_GET_OR, VECTOR_SET_OR_KEEP};

    /// Runs `source` through parse, TCO, resolve and VM compile, and returns the bytecode of
    /// the function named `fn_name`.
    fn bytecode_of(source: &str, fn_name: &str) -> Vec<u8> {
        let mut items = parse_source(source).expect("source should parse");
        crate::tco::transform_program(&mut items);
        crate::resolver::resolve_program(&mut items);

        let mut arena = Arena::new();
        let (code, _globals) = compile_program(&items, &mut arena).expect("vm compile should pass");
        let fn_id = code
            .find(fn_name)
            .expect("function should exist in compiled code");
        code.get(fn_id).code.clone()
    }

    /// `Option.withDefault(Vector.get(..), literal)` should compile to `VECTOR_GET_OR`.
    #[test]
    fn vector_get_with_literal_default_lowers_to_vector_get_or() {
        let source = r#"
module Demo

fn cellAt(grid: Vector<Int>, idx: Int) -> Int
    Option.withDefault(Vector.get(grid, idx), 0)
"#;

        let bytecode = bytecode_of(source, "cellAt");

        assert!(
            bytecode.contains(&VECTOR_GET_OR),
            "expected VECTOR_GET_OR in bytecode, got {:?}",
            bytecode
        );
    }

    /// `Option.withDefault(Vector.set(vec, ..), vec)` should compile to `VECTOR_SET_OR_KEEP`.
    #[test]
    fn vector_set_with_same_default_lowers_to_vector_set_or_keep() {
        let source = r#"
module Demo

fn updateOrKeep(vec: Vector<Int>, idx: Int, value: Int) -> Vector<Int>
    Option.withDefault(Vector.set(vec, idx, value), vec)
"#;

        let bytecode = bytecode_of(source, "updateOrKeep");

        assert!(
            bytecode.contains(&VECTOR_SET_OR_KEEP),
            "expected VECTOR_SET_OR_KEEP in bytecode, got {:?}",
            bytecode
        );
    }

    /// A bool match on `n >= 10` should compile to an `LT` comparison with no `NOT`.
    #[test]
    fn bool_match_on_gte_uses_base_compare_without_not() {
        let source = r#"
module Demo

fn bucket(n: Int) -> Int
    match n >= 10
        true -> 7
        false -> 3
"#;

        let bytecode = bytecode_of(source, "bucket");

        assert!(
            bytecode.contains(&LT),
            "expected LT in bytecode, got {:?}",
            bytecode
        );
        assert!(
            !bytecode.contains(&NOT),
            "did not expect NOT in normalized bool-match bytecode, got {:?}",
            bytecode
        );
    }
}