1use super::ast::{BinOp, Expr, ExprKind, UnOp};
4use super::error::{LinkError, ParseError, ProgramError};
5use super::metadata::{SymbolKind, SymbolMetadata};
6use super::parser::Parser;
7use crate::ir::Instr;
8use crate::span::{Span, SpanError};
9use crate::symbol::{SymTable, Symbol};
10use crate::vm::{Vm, VmError};
11use colored::Colorize;
12use rust_decimal::Decimal;
13use serde::{Deserialize, Serialize};
14use unicode_width::UnicodeWidthStr;
15
/// Version of this crate, captured at compile time from Cargo's `CARGO_PKG_VERSION`.
///
/// It is stamped into every program compiled from source and compared against the
/// version string stored in serialized bytecode when that bytecode is loaded.
const PROGRAM_VERSION: &str = env!("CARGO_PKG_VERSION");
18
/// On-disk / in-memory layout of a serialized program.
///
/// This is the value that is encoded to and decoded from bytes with `bincode`'s
/// serde integration.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct BinaryFormat {
    /// Version string of the program that produced the bytes.
    version: String,
    /// Instruction stream; `Load`/`Call` operands index into `symbols`.
    bytecode: Vec<Instr>,
    /// Metadata for every symbol the bytecode refers to.
    symbols: Vec<SymbolMetadata>,
}
26
/// Records how a program was created.
#[derive(Debug, Clone)]
pub enum ProgramOrigin {
    /// Loaded from a bytecode file; carries the path that was read.
    File(String),
    /// Compiled from source text.
    Source,
    /// Decoded from an in-memory bytecode buffer.
    Bytecode,
}
37
/// A program in a particular lifecycle stage, selected by the `State` type parameter.
///
/// The stages used in this file are [`Compiled`] (symbols not yet resolved) and
/// [`Linked`] (symbols resolved against a `SymTable`, ready to execute).
#[derive(Debug)]
pub struct Program<'src, State> {
    /// The trimmed source text when compiled from source; `None` when loaded from bytecode.
    source: Option<&'src str>,
    /// Stage-specific data.
    state: State,
}
62
/// State of a program that has been compiled (or decoded) but not yet linked.
#[derive(Debug)]
pub struct Compiled {
    /// How this program was created.
    origin: ProgramOrigin,
    /// Version string of the compiler that produced the bytecode.
    version: String,
    /// Instruction stream; `Load`/`Call` operands index into `symbols`.
    bytecode: Vec<Instr>,
    /// Symbols referenced by the bytecode, still unresolved (`index` is filled in by linking).
    symbols: Vec<SymbolMetadata>,
}
71
/// State of a program whose symbols have been resolved against a symbol table.
#[derive(Debug)]
pub struct Linked {
    // Carried over from the compiled state; nothing in this file reads it after
    // linking, hence the lint suppression.
    #[allow(dead_code)]
    origin: ProgramOrigin,
    /// Version string carried over from the compiled state.
    version: String,
    /// Instruction stream; `Load`/`Call` operands index into `symtable`.
    bytecode: Vec<Instr>,
    /// Symbol table the program was linked against.
    symtable: SymTable,
}
81
82impl<'src> Program<'src, Compiled> {
87 pub fn new_from_source(source: &'src str) -> Result<Self, ProgramError> {
101 let trimmed = source.trim();
102
103 let mut parser = Parser::new(trimmed);
105 let ast = parser
106 .parse()
107 .map_err(|parse_err| {
108 let highlighted = Self::highlight_error(trimmed, &parse_err);
110 ProgramError::ParseError(format!("{}\n{}", parse_err, highlighted))
111 })?
112 .ok_or_else(|| {
113 let parse_err = ParseError::UnexpectedEof {
114 span: Span::new(0, 0),
115 };
116 let highlighted = Self::highlight_error(trimmed, &parse_err);
117 ProgramError::ParseError(format!("{}\n{}", parse_err, highlighted))
118 })?;
119
120 let (bytecode, symbols) = Self::generate_bytecode(&ast);
122
123 Ok(Program {
124 source: Some(trimmed),
125 state: Compiled {
126 origin: ProgramOrigin::Source,
127 version: PROGRAM_VERSION.to_string(),
128 bytecode,
129 symbols,
130 },
131 })
132 }
133
134 pub fn new_from_file(path: impl Into<String>) -> Result<Self, ProgramError> {
144 let path_str = path.into();
145 let data = std::fs::read(&path_str)?;
146 Self::from_bytecode(&data, ProgramOrigin::File(path_str))
147 }
148
149 pub fn new_from_bytecode(data: &[u8]) -> Result<Self, ProgramError> {
153 Self::from_bytecode(data, ProgramOrigin::Bytecode)
154 }
155
156 pub fn link(mut self, table: SymTable) -> Result<Program<'src, Linked>, ProgramError> {
169 for metadata in &mut self.state.symbols {
171 let (resolved_idx, symbol) =
172 table
173 .get_with_index(&metadata.name)
174 .ok_or_else(|| LinkError::MissingSymbol {
175 name: metadata.name.to_string(),
176 })?;
177
178 Self::validate_symbol_kind(metadata, symbol)?;
180
181 metadata.index = Some(resolved_idx);
183 }
184
185 for instr in &mut self.state.bytecode {
187 match instr {
188 Instr::Load(idx) => {
189 *idx = self.state.symbols[*idx]
190 .index
191 .expect("Symbol should have been resolved during linking");
192 }
193 Instr::Call(idx, _) => {
194 *idx = self.state.symbols[*idx]
195 .index
196 .expect("Symbol should have been resolved during linking");
197 }
198 _ => {}
199 }
200 }
201
202 Ok(Program {
203 source: self.source,
204 state: Linked {
205 origin: self.state.origin,
206 version: self.state.version,
207 bytecode: self.state.bytecode,
208 symtable: table,
209 },
210 })
211 }
212
213 pub fn symbols(&self) -> &[SymbolMetadata] {
215 &self.state.symbols
216 }
217
218 pub fn version(&self) -> &str {
220 &self.state.version
221 }
222
223 fn from_bytecode(data: &[u8], origin: ProgramOrigin) -> Result<Self, ProgramError> {
229 let config = bincode::config::standard();
230 let (binary, _): (BinaryFormat, _) = bincode::serde::decode_from_slice(data, config)?;
231
232 if binary.version != PROGRAM_VERSION {
234 return Err(ProgramError::IncompatibleVersion {
235 expected: PROGRAM_VERSION.to_string(),
236 found: binary.version,
237 });
238 }
239
240 Ok(Program {
241 source: None, state: Compiled {
243 origin,
244 version: binary.version,
245 bytecode: binary.bytecode,
246 symbols: binary.symbols,
247 },
248 })
249 }
250
251 fn highlight_error(input: &str, error: &ParseError) -> String {
253 let span = error.span();
254 let pre = Self::escape(&input[..span.start]);
255 let tok = Self::escape(&input[span.start..span.end]);
256 let post = Self::escape(&input[span.end..]);
257 let line = format!("{}{}{}", pre, tok.red().bold(), post);
258
259 let caret = "^".green().bold();
260 let squiggly_len = UnicodeWidthStr::width(tok.as_str());
261 let caret_offset = UnicodeWidthStr::width(pre.as_str()) + caret.len();
262
263 format!(
264 "1 | {0}\n | {1: >2$}{3}",
265 line,
266 caret,
267 caret_offset,
268 "~".repeat(squiggly_len.saturating_sub(1)).green()
269 )
270 }
271
272 fn escape(s: &str) -> String {
274 let mut out = String::with_capacity(s.len());
275 for c in s.chars() {
276 match c {
277 '\n' => out.push_str("\\n"),
278 '\r' => out.push_str("\\r"),
279 other => out.push(other),
280 }
281 }
282 out
283 }
284
285 fn generate_bytecode(ast: &Expr) -> (Vec<Instr>, Vec<SymbolMetadata>) {
287 let mut bytecode = Vec::new();
288 let mut symbols = Vec::new();
289 Self::emit_instr(ast, &mut bytecode, &mut symbols);
290 (bytecode, symbols)
291 }
292
293 fn emit_instr(expr: &Expr, bytecode: &mut Vec<Instr>, symbols: &mut Vec<SymbolMetadata>) {
295 match &expr.kind {
296 ExprKind::Literal(v) => {
297 bytecode.push(Instr::Push(*v));
298 }
299 ExprKind::Ident { name } => {
300 let idx = Self::get_or_create_symbol(name, SymbolKind::Const, symbols);
302 bytecode.push(Instr::Load(idx));
303 }
304 ExprKind::Unary { op, expr } => {
305 Self::emit_instr(expr, bytecode, symbols);
306 match op {
307 UnOp::Neg => bytecode.push(Instr::Neg),
308 UnOp::Fact => bytecode.push(Instr::Fact),
309 }
310 }
311 ExprKind::Binary { op, left, right } => {
312 Self::emit_instr(left, bytecode, symbols);
313 Self::emit_instr(right, bytecode, symbols);
314 bytecode.push(match op {
315 BinOp::Add => Instr::Add,
316 BinOp::Sub => Instr::Sub,
317 BinOp::Mul => Instr::Mul,
318 BinOp::Div => Instr::Div,
319 BinOp::Pow => Instr::Pow,
320 BinOp::Equal => Instr::Equal,
321 BinOp::NotEqual => Instr::NotEqual,
322 BinOp::Less => Instr::Less,
323 BinOp::LessEqual => Instr::LessEqual,
324 BinOp::Greater => Instr::Greater,
325 BinOp::GreaterEqual => Instr::GreaterEqual,
326 });
327 }
328 ExprKind::Call { name, args } => {
329 for arg in args {
331 Self::emit_instr(arg, bytecode, symbols);
332 }
333
334 let idx = Self::get_or_create_symbol(
336 name,
337 SymbolKind::Func {
338 arity: args.len(),
339 variadic: false, },
341 symbols,
342 );
343 bytecode.push(Instr::Call(idx, args.len()));
344 }
345 }
346 }
347
348 fn get_or_create_symbol(
351 name: &str,
352 kind: SymbolKind,
353 symbols: &mut Vec<SymbolMetadata>,
354 ) -> usize {
355 if let Some(pos) = symbols.iter().position(|s| s.name == name) {
357 return pos;
358 }
359
360 symbols.push(SymbolMetadata {
362 name: name.to_string().into(),
363 kind,
364 index: None,
365 });
366 symbols.len() - 1
367 }
368
369 fn validate_symbol_kind(metadata: &SymbolMetadata, symbol: &Symbol) -> Result<(), LinkError> {
371 match (&metadata.kind, symbol) {
372 (SymbolKind::Const, Symbol::Const { .. }) => Ok(()),
373 (
374 SymbolKind::Func { arity, .. },
375 Symbol::Func {
376 args: min_args,
377 variadic,
378 ..
379 },
380 ) => {
381 let valid = if *variadic {
385 arity >= min_args
386 } else {
387 arity == min_args
388 };
389
390 if valid {
391 Ok(())
392 } else {
393 let expected_msg = if *variadic {
394 format!("at least {} arguments", min_args)
395 } else {
396 format!("exactly {} arguments", min_args)
397 };
398 Err(LinkError::TypeMismatch {
399 name: metadata.name.to_string(),
400 expected: expected_msg,
401 found: format!("{} arguments provided", arity),
402 })
403 }
404 }
405 (SymbolKind::Const, Symbol::Func { .. }) => Err(LinkError::TypeMismatch {
406 name: metadata.name.to_string(),
407 expected: "constant".to_string(),
408 found: "function".to_string(),
409 }),
410 (SymbolKind::Func { .. }, Symbol::Const { .. }) => Err(LinkError::TypeMismatch {
411 name: metadata.name.to_string(),
412 expected: "function".to_string(),
413 found: "constant".to_string(),
414 }),
415 }
416 }
417}
418
419impl<'src> Program<'src, Linked> {
424 pub fn execute(&self) -> Result<Decimal, VmError> {
430 Vm.run_bytecode(&self.state.bytecode, &self.state.symtable)
431 }
432
433 pub fn symtable(&self) -> &SymTable {
435 &self.state.symtable
436 }
437
438 pub fn symtable_mut(&mut self) -> &mut SymTable {
440 &mut self.state.symtable
441 }
442
443 pub fn version(&self) -> &str {
445 &self.state.version
446 }
447
448 pub fn get_assembly(&self) -> String {
450 Self::format_assembly(
451 &self.state.version,
452 &self.state.bytecode,
453 &self.state.symtable,
454 )
455 }
456
457 pub fn to_bytecode(&self) -> Result<Vec<u8>, ProgramError> {
461 use std::collections::HashMap;
462
463 let mut reverse_map = HashMap::new();
464 let mut symbols = Vec::new();
465
466 let mut get_or_create_metadata = |idx: usize| -> usize {
469 if let Some(&existing) = reverse_map.get(&idx) {
470 existing
471 } else {
472 let symbol = self
473 .state
474 .symtable
475 .get_by_index(idx)
476 .expect("symbol index must be valid after linking");
477
478 let new_idx = symbols.len();
479 symbols.push(symbol.into());
480 reverse_map.insert(idx, new_idx);
481 new_idx
482 }
483 };
484
485 let bytecode: Vec<Instr> = self
487 .state
488 .bytecode
489 .iter()
490 .map(|instr| match instr {
491 Instr::Load(idx) => Instr::Load(get_or_create_metadata(*idx)),
492 Instr::Call(idx, argc) => Instr::Call(get_or_create_metadata(*idx), *argc),
493 other => other.clone(),
494 })
495 .collect();
496
497 let binary = BinaryFormat {
499 version: self.state.version.clone(),
500 bytecode,
501 symbols,
502 };
503
504 let config = bincode::config::standard();
505 Ok(bincode::serde::encode_to_vec(&binary, config)?)
506 }
507
508 pub fn save_bytecode_to_file(
510 &self,
511 path: impl AsRef<std::path::Path>,
512 ) -> Result<(), ProgramError> {
513 let bytecode = self.to_bytecode()?;
514 std::fs::write(path, bytecode)?;
515 Ok(())
516 }
517
518 fn format_assembly(version: &str, bytecode: &[Instr], table: &SymTable) -> String {
524 use std::fmt::Write as _;
525
526 let mut out = String::new();
527 out += &format!("; VERSION {}\n", version)
528 .bright_black()
529 .to_string();
530
531 for (i, instr) in bytecode.iter().enumerate() {
532 let _ = write!(out, "{} ", format!("{:04X}", i).yellow());
533 let line = match instr {
534 Instr::Push(v) => format!("{} {}", "PUSH".magenta(), v.to_string().green()),
535 Instr::Load(idx) => {
536 let sym_name = table.get_by_index(*idx).map(|s| s.name()).unwrap_or("???");
537 format!("{} {}", "LOAD".magenta(), sym_name.blue())
538 }
539 Instr::Neg => format!("{}", "NEG".magenta()),
540 Instr::Add => format!("{}", "ADD".magenta()),
541 Instr::Sub => format!("{}", "SUB".magenta()),
542 Instr::Mul => format!("{}", "MUL".magenta()),
543 Instr::Div => format!("{}", "DIV".magenta()),
544 Instr::Pow => format!("{}", "POW".magenta()),
545 Instr::Fact => format!("{}", "FACT".magenta()),
546 Instr::Call(idx, argc) => {
547 let sym_name = table.get_by_index(*idx).map(|s| s.name()).unwrap_or("???");
548 format!(
549 "{} {} args: {}",
550 "CALL".magenta(),
551 sym_name.cyan(),
552 argc.to_string().bright_blue()
553 )
554 }
555 Instr::Equal => format!("{}", "EQ".magenta()),
556 Instr::NotEqual => format!("{}", "NEQ".magenta()),
557 Instr::Less => format!("{}", "LT".magenta()),
558 Instr::LessEqual => format!("{}", "LTE".magenta()),
559 Instr::Greater => format!("{}", "GT".magenta()),
560 Instr::GreaterEqual => format!("{}", "GTE".magenta()),
561 };
562 let _ = writeln!(out, "{}", line);
563 }
564 out
565 }
566}