expr_solver/
program.rs

1//! Type-state program implementation for compile-link-execute workflow.
2//!
3//! This module provides the [`Program`] type which orchestrates the compilation pipeline:
4//!
5//! 1. **Parsing** - Source code is parsed into an AST using [`Parser`]
6//! 2. **IR Generation** - AST is compiled to bytecode using [`IrBuilder`]
7//! 3. **Linking** - Bytecode is linked with a symbol table using [`Linker`]
8//! 4. **Execution** - Linked bytecode is executed on the VM
9//!
10//! The type-state pattern ensures these stages occur in the correct order at compile time.
11
12use super::error::{ParseError, ProgramError};
13use super::ir_builder::IrBuilder;
14use super::linker::Linker;
15use super::metadata::SymbolMetadata;
16use super::parser::Parser;
17use crate::ir::Instr;
18use crate::number::Number;
19use crate::span::SpanError;
20use crate::symtable::SymTable;
21use crate::vm::{Vm, VmError};
22use colored::Colorize;
23#[cfg(feature = "serialization")]
24use serde::{Deserialize, Serialize};
25use unicode_width::UnicodeWidthStr;
26
27/// Current version of the program format
28const PROGRAM_VERSION: &str = env!("CARGO_PKG_VERSION");
29
/// On-disk layout of a serialized program (bincode-encoded through serde).
///
/// Symbol operands in `bytecode` are positions in `symbols`, not symbol-table
/// indices, so a loaded program can be linked against any compatible
/// [`SymTable`]. Field order is part of the encoding and must not change.
#[cfg(feature = "serialization")]
#[derive(Debug, Clone, Serialize, Deserialize)]
struct BinaryFormat {
    /// Version of the producer; must equal `PROGRAM_VERSION` to be loaded.
    version: String,
    /// Instructions whose symbol operands index into `symbols`.
    bytecode: Vec<Instr>,
    /// Metadata for each symbol the bytecode refers to.
    symbols: Vec<SymbolMetadata>,
}
38
/// Where a compiled program came from.
///
/// The `File` and `Bytecode` variants only exist when the `serialization`
/// feature is enabled.
#[derive(Debug, Clone)]
pub enum ProgramOrigin {
    /// Read from a file; holds the path it was read from.
    #[cfg(feature = "serialization")]
    File(String),
    /// Parsed and compiled from a source string.
    Source,
    /// Decoded from an in-memory bytecode buffer.
    #[cfg(feature = "serialization")]
    Bytecode,
}
51
/// An expression program whose pipeline stage is tracked in its type.
///
/// `State` is either [`Compiled`] (bytecode waiting for a symbol table) or
/// [`Linked`] (bytecode bound to a [`SymTable`] and ready to run). Each state
/// only exposes the operations valid for it, so the compiler rejects, for
/// example, executing a program that was never linked.
///
/// # Examples
///
/// ```
/// use expr_solver::{num, Program, SymTable};
///
/// // Compile from source
/// let program = Program::new_from_source("x * 2 + 1").unwrap();
///
/// // Link with symbol table
/// let mut table = SymTable::new();
/// table.add_const("x", num!(5), false).unwrap();
/// let mut linked = program.link(table).unwrap();
///
/// // Execute
/// assert_eq!(linked.execute().unwrap(), num!(11));
/// ```
#[derive(Debug)]
pub struct Program<'src, State> {
    /// The trimmed source text, or `None` when the program was loaded from bytecode.
    source: Option<&'src str>,
    /// Stage-specific payload ([`Compiled`] or [`Linked`]).
    state: State,
}
75
76/// Compiled state - bytecode ready for linking.
77#[derive(Debug)]
78pub struct Compiled {
79    origin: ProgramOrigin,
80    version: String,
81    bytecode: Vec<Instr>,
82    symbols: Vec<SymbolMetadata>,
83}
84
85/// Linked state - ready to execute.
86#[derive(Debug)]
87pub struct Linked {
88    #[allow(dead_code)]
89    origin: ProgramOrigin,
90    version: String,
91    bytecode: Vec<Instr>,
92    symtable: SymTable,
93}
94
95// ============================================================================
// Program<Compiled> - Constructors (return Compiled state directly) and linking
97// ============================================================================
98
99impl<'src> Program<'src, Compiled> {
100    // ========================================================================
101    // Public API
102    // ========================================================================
103
104    /// Creates a compiled program from source code.
105    ///
106    /// # Examples
107    ///
108    /// ```
109    /// use expr_solver::Program;
110    ///
111    /// let program = Program::new_from_source("2 + 3 * 4").unwrap();
112    /// ```
113    pub fn new_from_source(source: &'src str) -> Result<Self, ProgramError> {
114        let trimmed = source.trim();
115
116        // Parse
117        let mut parser = Parser::new(trimmed);
118        let ast_opt = parser.parse().map_err(|parse_err| {
119            // Format error with source highlighting
120            let highlighted = Self::highlight_error(trimmed, &parse_err);
121            ProgramError::ParseError(format!("{}\n{}", parse_err, highlighted))
122        })?;
123
124        // Compile (handle empty input by creating empty bytecode)
125        let (bytecode, symbols) = if let Some(ast) = ast_opt {
126            IrBuilder::new().build(&ast)?
127        } else {
128            // Empty input -> empty program (VM will return 0)
129            (Vec::new(), Vec::new())
130        };
131
132        Ok(Program {
133            source: Some(trimmed),
134            state: Compiled {
135                origin: ProgramOrigin::Source,
136                version: PROGRAM_VERSION.to_string(),
137                bytecode,
138                symbols,
139            },
140        })
141    }
142
143    /// Creates a compiled program from a binary file.
144    ///
145    /// # Examples
146    ///
147    /// ```no_run
148    /// use expr_solver::Program;
149    ///
150    /// let program = Program::new_from_file("expr.bin").unwrap();
151    /// ```
152    #[cfg(feature = "serialization")]
153    pub fn new_from_file(path: impl Into<String>) -> Result<Self, ProgramError> {
154        let path_str = path.into();
155        let data = std::fs::read(&path_str)?;
156        Self::from_bytecode(&data, ProgramOrigin::File(path_str))
157    }
158
159    /// Creates a compiled program from bytecode bytes.
160    ///
161    /// Deserializes the bytecode and validates the version.
162    #[cfg(feature = "serialization")]
163    pub fn new_from_bytecode(data: &[u8]) -> Result<Self, ProgramError> {
164        Self::from_bytecode(data, ProgramOrigin::Bytecode)
165    }
166
167    /// Links the bytecode with a symbol table.
168    ///
169    /// Validates that all required symbols are present and compatible.
170    ///
171    /// # Examples
172    ///
173    /// ```
174    /// use expr_solver::{Program, SymTable};
175    ///
176    /// let program = Program::new_from_source("sin(pi)").unwrap();
177    /// let linked = program.link(SymTable::stdlib()).unwrap();
178    /// ```
179    pub fn link(self, table: SymTable) -> Result<Program<'src, Linked>, ProgramError> {
180        let linker = Linker::new(self.state.bytecode, self.state.symbols, table);
181        let (bytecode, symtable) = linker.link()?;
182
183        Ok(Program {
184            source: self.source,
185            state: Linked {
186                origin: self.state.origin,
187                version: self.state.version,
188                bytecode,
189                symtable,
190            },
191        })
192    }
193
194    // ========================================================================
195    // Private helpers
196    // ========================================================================
197
198    /// Internal helper to create program from bytecode with a specific origin.
199    #[cfg(feature = "serialization")]
200    fn from_bytecode(data: &[u8], origin: ProgramOrigin) -> Result<Self, ProgramError> {
201        let config = bincode::config::standard();
202        let (binary, _): (BinaryFormat, _) = bincode::serde::decode_from_slice(data, config)?;
203
204        // Validate version
205        if binary.version != PROGRAM_VERSION {
206            return Err(ProgramError::IncompatibleVersion {
207                expected: PROGRAM_VERSION.to_string(),
208                found: binary.version,
209            });
210        }
211
212        Ok(Program {
213            source: None, // No source for bytecode
214            state: Compiled {
215                origin,
216                version: binary.version,
217                bytecode: binary.bytecode,
218                symbols: binary.symbols,
219            },
220        })
221    }
222
223    /// Highlights an error in the source code.
224    fn highlight_error(input: &str, error: &ParseError) -> String {
225        let span = error.span();
226        let pre = Self::escape(&input[..span.start]);
227        let tok = Self::escape(&input[span.start..span.end]);
228        let post = Self::escape(&input[span.end..]);
229        let line = format!("{}{}{}", pre, tok.red().bold(), post);
230
231        let caret = "^".green().bold();
232        let squiggly_len = UnicodeWidthStr::width(tok.as_str());
233        let caret_offset = UnicodeWidthStr::width(pre.as_str()) + caret.len();
234
235        format!(
236            "1 | {0}\n  | {1: >2$}{3}",
237            line,
238            caret,
239            caret_offset,
240            "~".repeat(squiggly_len.saturating_sub(1)).green()
241        )
242    }
243
244    /// Escapes special characters for display.
245    fn escape(s: &str) -> String {
246        let mut out = String::with_capacity(s.len());
247        for c in s.chars() {
248            match c {
249                '\n' => out.push_str("\\n"),
250                '\r' => out.push_str("\\r"),
251                other => out.push(other),
252            }
253        }
254        out
255    }
256}
257
258// ============================================================================
259// Program<Linked> - After linking, ready to execute
260// ============================================================================
261
262impl<'src> Program<'src, Linked> {
263    // ========================================================================
264    // Public API
265    // ========================================================================
266
267    /// Executes the program and returns the result.
268    pub fn execute(&mut self) -> Result<Number, VmError> {
269        Vm::run(&self.state.bytecode, &mut self.state.symtable)
270    }
271
272    /// Returns a mutable reference to the symbol table.
273    pub fn symtable_mut(&mut self) -> &mut SymTable {
274        &mut self.state.symtable
275    }
276
277    /// Returns a human-readable assembly representation of the program.
278    pub fn get_assembly(&self) -> String {
279        use std::fmt::Write as _;
280
281        let mut out = String::new();
282        out += &format!("; VERSION {}\n", self.state.version)
283            .bright_black()
284            .to_string();
285
286        for (i, instr) in self.state.bytecode.iter().enumerate() {
287            let _ = write!(out, "{} ", format!("{:04X}", i).yellow());
288            let line = match instr {
289                Instr::Push(v) => format!("{} {}", "PUSH".magenta(), v.to_string().green()),
290                Instr::Load(idx) => {
291                    let sym_name = self
292                        .state
293                        .symtable
294                        .get_by_index(*idx)
295                        .map(|s| s.name())
296                        .expect("Symbol not found in assembly");
297                    format!("{} {}", "LOAD".magenta(), sym_name.blue())
298                }
299                Instr::Store(idx) => {
300                    let sym_name = self
301                        .state
302                        .symtable
303                        .get_by_index(*idx)
304                        .map(|s| s.name())
305                        .expect("Symbol not found in assembly");
306                    format!("{} {}", "STORE".magenta(), sym_name.blue())
307                }
308                Instr::Neg => format!("{}", "NEG".magenta()),
309                Instr::Add => format!("{}", "ADD".magenta()),
310                Instr::Sub => format!("{}", "SUB".magenta()),
311                Instr::Mul => format!("{}", "MUL".magenta()),
312                Instr::Div => format!("{}", "DIV".magenta()),
313                Instr::Pow => format!("{}", "POW".magenta()),
314                Instr::Fact => format!("{}", "FACT".magenta()),
315                Instr::Call(idx, argc) => {
316                    let sym_name = self
317                        .state
318                        .symtable
319                        .get_by_index(*idx)
320                        .map(|s| s.name())
321                        .expect("Symbol not found in assembly");
322                    format!(
323                        "{} {} args: {}",
324                        "CALL".magenta(),
325                        sym_name.cyan(),
326                        argc.to_string().bright_blue()
327                    )
328                }
329                Instr::Equal => format!("{}", "EQ".magenta()),
330                Instr::NotEqual => format!("{}", "NEQ".magenta()),
331                Instr::Less => format!("{}", "LT".magenta()),
332                Instr::LessEqual => format!("{}", "LTE".magenta()),
333                Instr::Greater => format!("{}", "GT".magenta()),
334                Instr::GreaterEqual => format!("{}", "GTE".magenta()),
335                Instr::Jmp(target) => {
336                    format!("{} {}", "JMP".magenta(), format!("{:04X}", target).yellow())
337                }
338                Instr::Jz(target) => {
339                    format!("{} {}", "JZ".magenta(), format!("{:04X}", target).yellow())
340                }
341            };
342            let _ = writeln!(out, "{}", line);
343        }
344        out
345    }
346
347    /// Converts the program to bytecode bytes.
348    ///
349    /// This involves reverse-mapping the bytecode indices back to metadata indices.
350    #[cfg(feature = "serialization")]
351    pub fn to_bytecode(&self) -> Result<Vec<u8>, ProgramError> {
352        use std::collections::HashMap;
353
354        let mut reverse_map = HashMap::new();
355        let mut symbols = Vec::new();
356
357        // Helper closure to get or create metadata index
358        // All indices are valid since we successfully linked
359        let mut get_or_create_metadata = |idx: usize| -> usize {
360            if let Some(&existing) = reverse_map.get(&idx) {
361                existing
362            } else {
363                let symbol = self
364                    .state
365                    .symtable
366                    .get_by_index(idx)
367                    .expect("symbol index must be valid after linking");
368
369                let new_idx = symbols.len();
370                symbols.push(symbol.into());
371                reverse_map.insert(idx, new_idx);
372                new_idx
373            }
374        };
375
376        // Single pass: build symbol mapping and rewrite bytecode
377        let bytecode: Vec<Instr> = self
378            .state
379            .bytecode
380            .iter()
381            .map(|instr| match instr {
382                Instr::Load(idx) => Instr::Load(get_or_create_metadata(*idx)),
383                Instr::Store(idx) => Instr::Store(get_or_create_metadata(*idx)),
384                Instr::Call(idx, argc) => Instr::Call(get_or_create_metadata(*idx), *argc),
385                other => other.clone(),
386            })
387            .collect();
388
389        // Serialize
390        let binary = BinaryFormat {
391            version: self.state.version.clone(),
392            bytecode,
393            symbols,
394        };
395
396        let config = bincode::config::standard();
397        Ok(bincode::serde::encode_to_vec(&binary, config)?)
398    }
399
400    /// Saves the program bytecode to a file.
401    #[cfg(feature = "serialization")]
402    pub fn save_bytecode_to_file(
403        &self,
404        path: impl AsRef<std::path::Path>,
405    ) -> Result<(), ProgramError> {
406        let bytecode = self.to_bytecode()?;
407        std::fs::write(path, bytecode)?;
408        Ok(())
409    }
410}