1use super::error::{ParseError, ProgramError};
13use super::ir_builder::IrBuilder;
14use super::linker::Linker;
15use super::metadata::SymbolMetadata;
16use super::parser::Parser;
17use crate::ir::Instr;
18use crate::number::Number;
19use crate::span::SpanError;
20use crate::symtable::SymTable;
21use crate::vm::{Vm, VmError};
22use colored::Colorize;
23#[cfg(feature = "serialization")]
24use serde::{Deserialize, Serialize};
25use unicode_width::UnicodeWidthStr;
26
/// Version string recorded for every program, taken from the crate's
/// `CARGO_PKG_VERSION` at compile time. Serialized bytecode whose recorded
/// version differs from this value is rejected when it is loaded.
const PROGRAM_VERSION: &str = env!("CARGO_PKG_VERSION");
29
/// Serialized form of a program, encoded with bincode's standard configuration
/// by `to_bytecode` and decoded by `from_bytecode`.
///
/// NOTE: bincode encodes struct fields positionally, so the field order here is
/// part of the wire format. Reordering fields breaks compatibility with files
/// that were already written.
#[cfg(feature = "serialization")]
#[derive(Debug, Clone, Serialize, Deserialize)]
struct BinaryFormat {
    /// Version recorded by the writer; compared against `PROGRAM_VERSION` on load.
    version: String,
    /// Instructions whose symbol operands are indices into `symbols`.
    bytecode: Vec<Instr>,
    /// Metadata for each symbol referenced by `bytecode`, in order of first use.
    symbols: Vec<SymbolMetadata>,
}
38
/// Describes how a [`Program`] was obtained.
#[derive(Debug, Clone)]
pub enum ProgramOrigin {
    /// Loaded from a bytecode file; holds the path given to `new_from_file`.
    #[cfg(feature = "serialization")]
    File(String),
    /// Compiled from source text with `new_from_source`.
    Source,
    /// Decoded from an in-memory byte slice with `new_from_bytecode`.
    #[cfg(feature = "serialization")]
    Bytecode,
}
51
/// A program in one of two typestates:
/// - [`Compiled`]: built from source or bytecode, not yet bound to a symbol table.
/// - [`Linked`]: resolved against a [`SymTable`]; only this state can be executed.
///
/// `'src` is the lifetime of the source text the program was compiled from.
#[derive(Debug)]
pub struct Program<'src, State> {
    /// Trimmed source text; `None` when the program was loaded from bytecode.
    source: Option<&'src str>,
    /// Stage-specific data (`Compiled` or `Linked`).
    state: State,
}
75
/// State of a program that has been compiled (or decoded) but not yet linked.
#[derive(Debug)]
pub struct Compiled {
    /// How the program was obtained.
    origin: ProgramOrigin,
    /// Recorded version: `PROGRAM_VERSION` for source builds, or the decoded value
    /// (already checked to equal `PROGRAM_VERSION`) for bytecode.
    version: String,
    /// Unlinked instructions. For decoded bytecode, symbol operands index into
    /// `symbols`; presumably the same holds for IR-builder output — TODO confirm.
    bytecode: Vec<Instr>,
    /// Symbol metadata handed to the linker together with `bytecode`.
    symbols: Vec<SymbolMetadata>,
}
84
/// State of a program whose symbols have been resolved against a [`SymTable`].
#[derive(Debug)]
pub struct Linked {
    // No method in this file reads `origin`, hence the lint allowance. It is
    // still carried over from `Compiled` and appears in the `Debug` output.
    #[allow(dead_code)]
    origin: ProgramOrigin,
    /// Version string carried over from the compiled program.
    version: String,
    /// Linked instructions; symbol operands are indices into `symtable`.
    bytecode: Vec<Instr>,
    /// Symbol table produced by the linker and used during execution.
    symtable: SymTable,
}
94
95impl<'src> Program<'src, Compiled> {
100 pub fn new_from_source(source: &'src str) -> Result<Self, ProgramError> {
114 let trimmed = source.trim();
115
116 let mut parser = Parser::new(trimmed);
118 let ast_opt = parser.parse().map_err(|parse_err| {
119 let highlighted = Self::highlight_error(trimmed, &parse_err);
121 ProgramError::ParseError(format!("{}\n{}", parse_err, highlighted))
122 })?;
123
124 let (bytecode, symbols) = if let Some(ast) = ast_opt {
126 IrBuilder::new().build(&ast)?
127 } else {
128 (Vec::new(), Vec::new())
130 };
131
132 Ok(Program {
133 source: Some(trimmed),
134 state: Compiled {
135 origin: ProgramOrigin::Source,
136 version: PROGRAM_VERSION.to_string(),
137 bytecode,
138 symbols,
139 },
140 })
141 }
142
143 #[cfg(feature = "serialization")]
153 pub fn new_from_file(path: impl Into<String>) -> Result<Self, ProgramError> {
154 let path_str = path.into();
155 let data = std::fs::read(&path_str)?;
156 Self::from_bytecode(&data, ProgramOrigin::File(path_str))
157 }
158
159 #[cfg(feature = "serialization")]
163 pub fn new_from_bytecode(data: &[u8]) -> Result<Self, ProgramError> {
164 Self::from_bytecode(data, ProgramOrigin::Bytecode)
165 }
166
167 pub fn link(self, table: SymTable) -> Result<Program<'src, Linked>, ProgramError> {
180 let linker = Linker::new(self.state.bytecode, self.state.symbols, table);
181 let (bytecode, symtable) = linker.link()?;
182
183 Ok(Program {
184 source: self.source,
185 state: Linked {
186 origin: self.state.origin,
187 version: self.state.version,
188 bytecode,
189 symtable,
190 },
191 })
192 }
193
194 #[cfg(feature = "serialization")]
200 fn from_bytecode(data: &[u8], origin: ProgramOrigin) -> Result<Self, ProgramError> {
201 let config = bincode::config::standard();
202 let (binary, _): (BinaryFormat, _) = bincode::serde::decode_from_slice(data, config)?;
203
204 if binary.version != PROGRAM_VERSION {
206 return Err(ProgramError::IncompatibleVersion {
207 expected: PROGRAM_VERSION.to_string(),
208 found: binary.version,
209 });
210 }
211
212 Ok(Program {
213 source: None, state: Compiled {
215 origin,
216 version: binary.version,
217 bytecode: binary.bytecode,
218 symbols: binary.symbols,
219 },
220 })
221 }
222
223 fn highlight_error(input: &str, error: &ParseError) -> String {
225 let span = error.span();
226 let pre = Self::escape(&input[..span.start]);
227 let tok = Self::escape(&input[span.start..span.end]);
228 let post = Self::escape(&input[span.end..]);
229 let line = format!("{}{}{}", pre, tok.red().bold(), post);
230
231 let caret = "^".green().bold();
232 let squiggly_len = UnicodeWidthStr::width(tok.as_str());
233 let caret_offset = UnicodeWidthStr::width(pre.as_str()) + caret.len();
234
235 format!(
236 "1 | {0}\n | {1: >2$}{3}",
237 line,
238 caret,
239 caret_offset,
240 "~".repeat(squiggly_len.saturating_sub(1)).green()
241 )
242 }
243
244 fn escape(s: &str) -> String {
246 let mut out = String::with_capacity(s.len());
247 for c in s.chars() {
248 match c {
249 '\n' => out.push_str("\\n"),
250 '\r' => out.push_str("\\r"),
251 other => out.push(other),
252 }
253 }
254 out
255 }
256}
257
258impl<'src> Program<'src, Linked> {
263 pub fn execute(&mut self) -> Result<Number, VmError> {
269 Vm::run(&self.state.bytecode, &mut self.state.symtable)
270 }
271
272 pub fn symtable_mut(&mut self) -> &mut SymTable {
274 &mut self.state.symtable
275 }
276
277 pub fn get_assembly(&self) -> String {
279 use std::fmt::Write as _;
280
281 let mut out = String::new();
282 out += &format!("; VERSION {}\n", self.state.version)
283 .bright_black()
284 .to_string();
285
286 for (i, instr) in self.state.bytecode.iter().enumerate() {
287 let _ = write!(out, "{} ", format!("{:04X}", i).yellow());
288 let line = match instr {
289 Instr::Push(v) => format!("{} {}", "PUSH".magenta(), v.to_string().green()),
290 Instr::Load(idx) => {
291 let sym_name = self
292 .state
293 .symtable
294 .get_by_index(*idx)
295 .map(|s| s.name())
296 .expect("Symbol not found in assembly");
297 format!("{} {}", "LOAD".magenta(), sym_name.blue())
298 }
299 Instr::Store(idx) => {
300 let sym_name = self
301 .state
302 .symtable
303 .get_by_index(*idx)
304 .map(|s| s.name())
305 .expect("Symbol not found in assembly");
306 format!("{} {}", "STORE".magenta(), sym_name.blue())
307 }
308 Instr::Neg => format!("{}", "NEG".magenta()),
309 Instr::Add => format!("{}", "ADD".magenta()),
310 Instr::Sub => format!("{}", "SUB".magenta()),
311 Instr::Mul => format!("{}", "MUL".magenta()),
312 Instr::Div => format!("{}", "DIV".magenta()),
313 Instr::Pow => format!("{}", "POW".magenta()),
314 Instr::Fact => format!("{}", "FACT".magenta()),
315 Instr::Call(idx, argc) => {
316 let sym_name = self
317 .state
318 .symtable
319 .get_by_index(*idx)
320 .map(|s| s.name())
321 .expect("Symbol not found in assembly");
322 format!(
323 "{} {} args: {}",
324 "CALL".magenta(),
325 sym_name.cyan(),
326 argc.to_string().bright_blue()
327 )
328 }
329 Instr::Equal => format!("{}", "EQ".magenta()),
330 Instr::NotEqual => format!("{}", "NEQ".magenta()),
331 Instr::Less => format!("{}", "LT".magenta()),
332 Instr::LessEqual => format!("{}", "LTE".magenta()),
333 Instr::Greater => format!("{}", "GT".magenta()),
334 Instr::GreaterEqual => format!("{}", "GTE".magenta()),
335 Instr::Jmp(target) => {
336 format!("{} {}", "JMP".magenta(), format!("{:04X}", target).yellow())
337 }
338 Instr::Jz(target) => {
339 format!("{} {}", "JZ".magenta(), format!("{:04X}", target).yellow())
340 }
341 };
342 let _ = writeln!(out, "{}", line);
343 }
344 out
345 }
346
347 #[cfg(feature = "serialization")]
351 pub fn to_bytecode(&self) -> Result<Vec<u8>, ProgramError> {
352 use std::collections::HashMap;
353
354 let mut reverse_map = HashMap::new();
355 let mut symbols = Vec::new();
356
357 let mut get_or_create_metadata = |idx: usize| -> usize {
360 if let Some(&existing) = reverse_map.get(&idx) {
361 existing
362 } else {
363 let symbol = self
364 .state
365 .symtable
366 .get_by_index(idx)
367 .expect("symbol index must be valid after linking");
368
369 let new_idx = symbols.len();
370 symbols.push(symbol.into());
371 reverse_map.insert(idx, new_idx);
372 new_idx
373 }
374 };
375
376 let bytecode: Vec<Instr> = self
378 .state
379 .bytecode
380 .iter()
381 .map(|instr| match instr {
382 Instr::Load(idx) => Instr::Load(get_or_create_metadata(*idx)),
383 Instr::Store(idx) => Instr::Store(get_or_create_metadata(*idx)),
384 Instr::Call(idx, argc) => Instr::Call(get_or_create_metadata(*idx), *argc),
385 other => other.clone(),
386 })
387 .collect();
388
389 let binary = BinaryFormat {
391 version: self.state.version.clone(),
392 bytecode,
393 symbols,
394 };
395
396 let config = bincode::config::standard();
397 Ok(bincode::serde::encode_to_vec(&binary, config)?)
398 }
399
400 #[cfg(feature = "serialization")]
402 pub fn save_bytecode_to_file(
403 &self,
404 path: impl AsRef<std::path::Path>,
405 ) -> Result<(), ProgramError> {
406 let bytecode = self.to_bytecode()?;
407 std::fs::write(path, bytecode)?;
408 Ok(())
409 }
410}