expr_solver/
program.rs

1//! Type-state program implementation for compile-link-execute workflow.
2
3use super::ast::{BinOp, Expr, ExprKind, UnOp};
4use super::error::{LinkError, ParseError, ProgramError};
5use super::metadata::{SymbolKind, SymbolMetadata};
6use super::parser::Parser;
7use crate::ir::Instr;
8use crate::span::{Span, SpanError};
9use crate::symbol::{SymTable, Symbol};
10use crate::vm::{Vm, VmError};
11use colored::Colorize;
12use rust_decimal::Decimal;
13use serde::{Deserialize, Serialize};
14use unicode_width::UnicodeWidthStr;
15
/// Version string stamped into, and required of, serialized programs.
///
/// This is the crate version from Cargo, so any crate version change makes
/// previously serialized bytecode incompatible: deserialization rejects any
/// version string that is not exactly equal to this one.
const PROGRAM_VERSION: &str = env!("CARGO_PKG_VERSION");
18
/// On-disk / in-memory binary representation of a program.
///
/// bincode encodes struct fields positionally in declaration order, so
/// reordering, adding or removing fields changes the serialized format.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct BinaryFormat {
    // Format version; must equal `PROGRAM_VERSION` for the data to be accepted.
    version: String,
    // Instructions whose `Load`/`Call` operands index into `symbols`.
    bytecode: Vec<Instr>,
    // Symbols the bytecode needs, resolved by name when the program is linked.
    symbols: Vec<SymbolMetadata>,
}
26
/// Where a compiled program came from.
#[derive(Debug, Clone)]
pub enum ProgramOrigin {
    /// Loaded from a bytecode file; holds the path it was read from.
    File(String),
    /// Compiled from a source string.
    Source,
    /// Deserialized from an in-memory bytecode buffer.
    Bytecode,
}
37
/// Type-state program using Rust's type system to enforce correct usage.
///
/// A program starts in the [`Compiled`] state (from source or bytecode) and
/// must be linked against a symbol table to reach the [`Linked`] state,
/// which is the only state that can be executed or serialized.
///
/// # Examples
///
/// ```
/// use expr_solver::{Program, SymTable};
/// use rust_decimal_macros::dec;
///
/// // Compile from source
/// let program = Program::new_from_source("x * 2 + 1").unwrap();
///
/// // Link with symbol table
/// let mut table = SymTable::new();
/// table.add_const("x", dec!(5)).unwrap();
/// let linked = program.link(table).unwrap();
///
/// // Execute
/// assert_eq!(linked.execute().unwrap(), dec!(11));
/// ```
#[derive(Debug)]
pub struct Program<'src, State> {
    // Trimmed source text when compiled from source; `None` when the program
    // was loaded from bytecode.
    source: Option<&'src str>,
    // State-specific payload (`Compiled` or `Linked`).
    state: State,
}
62
/// Compiled state - bytecode ready for linking.
#[derive(Debug)]
pub struct Compiled {
    // Where this program came from.
    origin: ProgramOrigin,
    // Program format version.
    version: String,
    // Instructions whose `Load`/`Call` operands index into `symbols`
    // (not yet into any symbol table).
    bytecode: Vec<Instr>,
    // Symbols required by the bytecode; each entry's `index` is filled in
    // with the resolved symbol-table index during linking.
    symbols: Vec<SymbolMetadata>,
}
71
/// Linked state - ready to execute.
#[derive(Debug)]
pub struct Linked {
    // Not read by any method in this file; it is carried over from the
    // compiled state so the origin survives linking (visible via `Debug`),
    // hence the dead-code allowance.
    #[allow(dead_code)]
    origin: ProgramOrigin,
    // Program format version.
    version: String,
    // Instructions whose `Load`/`Call` operands are indices into `symtable`.
    bytecode: Vec<Instr>,
    // Symbol table the bytecode was linked against.
    symtable: SymTable,
}
81
82// ============================================================================
// Program<Compiled> - Constructors, linking, and compile-time helpers
84// ============================================================================
85
86impl<'src> Program<'src, Compiled> {
87    // ========================================================================
88    // Public API
89    // ========================================================================
90
91    /// Creates a compiled program from source code.
92    ///
93    /// # Examples
94    ///
95    /// ```
96    /// use expr_solver::Program;
97    ///
98    /// let program = Program::new_from_source("2 + 3 * 4").unwrap();
99    /// ```
100    pub fn new_from_source(source: &'src str) -> Result<Self, ProgramError> {
101        let trimmed = source.trim();
102
103        // Parse
104        let mut parser = Parser::new(trimmed);
105        let ast = parser
106            .parse()
107            .map_err(|parse_err| {
108                // Format error with source highlighting
109                let highlighted = Self::highlight_error(trimmed, &parse_err);
110                ProgramError::ParseError(format!("{}\n{}", parse_err, highlighted))
111            })?
112            .ok_or_else(|| {
113                let parse_err = ParseError::UnexpectedEof {
114                    span: Span::new(0, 0),
115                };
116                let highlighted = Self::highlight_error(trimmed, &parse_err);
117                ProgramError::ParseError(format!("{}\n{}", parse_err, highlighted))
118            })?;
119
120        // Compile
121        let (bytecode, symbols) = Self::generate_bytecode(&ast);
122
123        Ok(Program {
124            source: Some(trimmed),
125            state: Compiled {
126                origin: ProgramOrigin::Source,
127                version: PROGRAM_VERSION.to_string(),
128                bytecode,
129                symbols,
130            },
131        })
132    }
133
134    /// Creates a compiled program from a binary file.
135    ///
136    /// # Examples
137    ///
138    /// ```no_run
139    /// use expr_solver::Program;
140    ///
141    /// let program = Program::new_from_file("expr.bin").unwrap();
142    /// ```
143    pub fn new_from_file(path: impl Into<String>) -> Result<Self, ProgramError> {
144        let path_str = path.into();
145        let data = std::fs::read(&path_str)?;
146        Self::from_bytecode(&data, ProgramOrigin::File(path_str))
147    }
148
149    /// Creates a compiled program from bytecode bytes.
150    ///
151    /// Deserializes the bytecode and validates the version.
152    pub fn new_from_bytecode(data: &[u8]) -> Result<Self, ProgramError> {
153        Self::from_bytecode(data, ProgramOrigin::Bytecode)
154    }
155
156    /// Links the bytecode with a symbol table.
157    ///
158    /// Validates that all required symbols are present and compatible.
159    ///
160    /// # Examples
161    ///
162    /// ```
163    /// use expr_solver::{Program, SymTable};
164    ///
165    /// let program = Program::new_from_source("sin(pi)").unwrap();
166    /// let linked = program.link(SymTable::stdlib()).unwrap();
167    /// ```
168    pub fn link(mut self, table: SymTable) -> Result<Program<'src, Linked>, ProgramError> {
169        // Validate symbols and fill in their resolved indices
170        for metadata in &mut self.state.symbols {
171            let (resolved_idx, symbol) =
172                table
173                    .get_with_index(&metadata.name)
174                    .ok_or_else(|| LinkError::MissingSymbol {
175                        name: metadata.name.to_string(),
176                    })?;
177
178            // Validate kind matches
179            Self::validate_symbol_kind(metadata, symbol)?;
180
181            // Store resolved index in metadata
182            metadata.index = Some(resolved_idx);
183        }
184
185        // Rewrite all indices in bytecode using resolved indices from metadata
186        for instr in &mut self.state.bytecode {
187            match instr {
188                Instr::Load(idx) => {
189                    *idx = self.state.symbols[*idx]
190                        .index
191                        .expect("Symbol should have been resolved during linking");
192                }
193                Instr::Call(idx, _) => {
194                    *idx = self.state.symbols[*idx]
195                        .index
196                        .expect("Symbol should have been resolved during linking");
197                }
198                _ => {}
199            }
200        }
201
202        Ok(Program {
203            source: self.source,
204            state: Linked {
205                origin: self.state.origin,
206                version: self.state.version,
207                bytecode: self.state.bytecode,
208                symtable: table,
209            },
210        })
211    }
212
213    /// Returns the symbol metadata required by this program.
214    pub fn symbols(&self) -> &[SymbolMetadata] {
215        &self.state.symbols
216    }
217
218    /// Returns the version of this program.
219    pub fn version(&self) -> &str {
220        &self.state.version
221    }
222
223    // ========================================================================
224    // Private helpers
225    // ========================================================================
226
227    /// Internal helper to create program from bytecode with a specific origin.
228    fn from_bytecode(data: &[u8], origin: ProgramOrigin) -> Result<Self, ProgramError> {
229        let config = bincode::config::standard();
230        let (binary, _): (BinaryFormat, _) = bincode::serde::decode_from_slice(data, config)?;
231
232        // Validate version
233        if binary.version != PROGRAM_VERSION {
234            return Err(ProgramError::IncompatibleVersion {
235                expected: PROGRAM_VERSION.to_string(),
236                found: binary.version,
237            });
238        }
239
240        Ok(Program {
241            source: None, // No source for bytecode
242            state: Compiled {
243                origin,
244                version: binary.version,
245                bytecode: binary.bytecode,
246                symbols: binary.symbols,
247            },
248        })
249    }
250
251    /// Highlights an error in the source code.
252    fn highlight_error(input: &str, error: &ParseError) -> String {
253        let span = error.span();
254        let pre = Self::escape(&input[..span.start]);
255        let tok = Self::escape(&input[span.start..span.end]);
256        let post = Self::escape(&input[span.end..]);
257        let line = format!("{}{}{}", pre, tok.red().bold(), post);
258
259        let caret = "^".green().bold();
260        let squiggly_len = UnicodeWidthStr::width(tok.as_str());
261        let caret_offset = UnicodeWidthStr::width(pre.as_str()) + caret.len();
262
263        format!(
264            "1 | {0}\n  | {1: >2$}{3}",
265            line,
266            caret,
267            caret_offset,
268            "~".repeat(squiggly_len.saturating_sub(1)).green()
269        )
270    }
271
272    /// Escapes special characters for display.
273    fn escape(s: &str) -> String {
274        let mut out = String::with_capacity(s.len());
275        for c in s.chars() {
276            match c {
277                '\n' => out.push_str("\\n"),
278                '\r' => out.push_str("\\r"),
279                other => out.push(other),
280            }
281        }
282        out
283    }
284
285    /// Generates bytecode and collects symbol metadata in a single AST traversal.
286    fn generate_bytecode(ast: &Expr) -> (Vec<Instr>, Vec<SymbolMetadata>) {
287        let mut bytecode = Vec::new();
288        let mut symbols = Vec::new();
289        Self::emit_instr(ast, &mut bytecode, &mut symbols);
290        (bytecode, symbols)
291    }
292
293    /// Emits bytecode instructions for an expression node.
294    fn emit_instr(expr: &Expr, bytecode: &mut Vec<Instr>, symbols: &mut Vec<SymbolMetadata>) {
295        match &expr.kind {
296            ExprKind::Literal(v) => {
297                bytecode.push(Instr::Push(*v));
298            }
299            ExprKind::Ident { name } => {
300                // Get or create index for this constant
301                let idx = Self::get_or_create_symbol(name, SymbolKind::Const, symbols);
302                bytecode.push(Instr::Load(idx));
303            }
304            ExprKind::Unary { op, expr } => {
305                Self::emit_instr(expr, bytecode, symbols);
306                match op {
307                    UnOp::Neg => bytecode.push(Instr::Neg),
308                    UnOp::Fact => bytecode.push(Instr::Fact),
309                }
310            }
311            ExprKind::Binary { op, left, right } => {
312                Self::emit_instr(left, bytecode, symbols);
313                Self::emit_instr(right, bytecode, symbols);
314                bytecode.push(match op {
315                    BinOp::Add => Instr::Add,
316                    BinOp::Sub => Instr::Sub,
317                    BinOp::Mul => Instr::Mul,
318                    BinOp::Div => Instr::Div,
319                    BinOp::Pow => Instr::Pow,
320                    BinOp::Equal => Instr::Equal,
321                    BinOp::NotEqual => Instr::NotEqual,
322                    BinOp::Less => Instr::Less,
323                    BinOp::LessEqual => Instr::LessEqual,
324                    BinOp::Greater => Instr::Greater,
325                    BinOp::GreaterEqual => Instr::GreaterEqual,
326                });
327            }
328            ExprKind::Call { name, args } => {
329                // Emit arguments first
330                for arg in args {
331                    Self::emit_instr(arg, bytecode, symbols);
332                }
333
334                // Get or create index for this function
335                let idx = Self::get_or_create_symbol(
336                    name,
337                    SymbolKind::Func {
338                        arity: args.len(),
339                        variadic: false, // Will be validated during linking
340                    },
341                    symbols,
342                );
343                bytecode.push(Instr::Call(idx, args.len()));
344            }
345        }
346    }
347
348    /// Gets existing symbol index or creates a new one.
349    /// For ~50 symbols, linear search is faster than HashMap overhead.
350    fn get_or_create_symbol(
351        name: &str,
352        kind: SymbolKind,
353        symbols: &mut Vec<SymbolMetadata>,
354    ) -> usize {
355        // Check if symbol already exists
356        if let Some(pos) = symbols.iter().position(|s| s.name == name) {
357            return pos;
358        }
359
360        // Create new symbol entry
361        symbols.push(SymbolMetadata {
362            name: name.to_string().into(),
363            kind,
364            index: None,
365        });
366        symbols.len() - 1
367    }
368
369    /// Validates that a symbol matches the expected kind.
370    fn validate_symbol_kind(metadata: &SymbolMetadata, symbol: &Symbol) -> Result<(), LinkError> {
371        match (&metadata.kind, symbol) {
372            (SymbolKind::Const, Symbol::Const { .. }) => Ok(()),
373            (
374                SymbolKind::Func { arity, .. },
375                Symbol::Func {
376                    args: min_args,
377                    variadic,
378                    ..
379                },
380            ) => {
381                // Check if the call is valid:
382                // - For non-variadic: arity must match exactly
383                // - For variadic: arity must be >= min_args
384                let valid = if *variadic {
385                    arity >= min_args
386                } else {
387                    arity == min_args
388                };
389
390                if valid {
391                    Ok(())
392                } else {
393                    let expected_msg = if *variadic {
394                        format!("at least {} arguments", min_args)
395                    } else {
396                        format!("exactly {} arguments", min_args)
397                    };
398                    Err(LinkError::TypeMismatch {
399                        name: metadata.name.to_string(),
400                        expected: expected_msg,
401                        found: format!("{} arguments provided", arity),
402                    })
403                }
404            }
405            (SymbolKind::Const, Symbol::Func { .. }) => Err(LinkError::TypeMismatch {
406                name: metadata.name.to_string(),
407                expected: "constant".to_string(),
408                found: "function".to_string(),
409            }),
410            (SymbolKind::Func { .. }, Symbol::Const { .. }) => Err(LinkError::TypeMismatch {
411                name: metadata.name.to_string(),
412                expected: "function".to_string(),
413                found: "constant".to_string(),
414            }),
415        }
416    }
417}
418
419// ============================================================================
420// Program<Linked> - After linking, ready to execute
421// ============================================================================
422
423impl<'src> Program<'src, Linked> {
424    // ========================================================================
425    // Public API
426    // ========================================================================
427
428    /// Executes the program and returns the result.
429    pub fn execute(&self) -> Result<Decimal, VmError> {
430        Vm.run_bytecode(&self.state.bytecode, &self.state.symtable)
431    }
432
433    /// Returns a reference to the symbol table.
434    pub fn symtable(&self) -> &SymTable {
435        &self.state.symtable
436    }
437
438    /// Returns a mutable reference to the symbol table.
439    pub fn symtable_mut(&mut self) -> &mut SymTable {
440        &mut self.state.symtable
441    }
442
443    /// Returns the version of this program.
444    pub fn version(&self) -> &str {
445        &self.state.version
446    }
447
448    /// Returns a human-readable assembly representation of the program.
449    pub fn get_assembly(&self) -> String {
450        Self::format_assembly(
451            &self.state.version,
452            &self.state.bytecode,
453            &self.state.symtable,
454        )
455    }
456
457    /// Converts the program to bytecode bytes.
458    ///
459    /// This involves reverse-mapping the bytecode indices back to metadata indices.
460    pub fn to_bytecode(&self) -> Result<Vec<u8>, ProgramError> {
461        use std::collections::HashMap;
462
463        let mut reverse_map = HashMap::new();
464        let mut symbols = Vec::new();
465
466        // Helper closure to get or create metadata index
467        // All indices are valid since we successfully linked
468        let mut get_or_create_metadata = |idx: usize| -> usize {
469            if let Some(&existing) = reverse_map.get(&idx) {
470                existing
471            } else {
472                let symbol = self
473                    .state
474                    .symtable
475                    .get_by_index(idx)
476                    .expect("symbol index must be valid after linking");
477
478                let new_idx = symbols.len();
479                symbols.push(symbol.into());
480                reverse_map.insert(idx, new_idx);
481                new_idx
482            }
483        };
484
485        // Single pass: build symbol mapping and rewrite bytecode
486        let bytecode: Vec<Instr> = self
487            .state
488            .bytecode
489            .iter()
490            .map(|instr| match instr {
491                Instr::Load(idx) => Instr::Load(get_or_create_metadata(*idx)),
492                Instr::Call(idx, argc) => Instr::Call(get_or_create_metadata(*idx), *argc),
493                other => other.clone(),
494            })
495            .collect();
496
497        // Serialize
498        let binary = BinaryFormat {
499            version: self.state.version.clone(),
500            bytecode,
501            symbols,
502        };
503
504        let config = bincode::config::standard();
505        Ok(bincode::serde::encode_to_vec(&binary, config)?)
506    }
507
508    /// Saves the program bytecode to a file.
509    pub fn save_bytecode_to_file(
510        &self,
511        path: impl AsRef<std::path::Path>,
512    ) -> Result<(), ProgramError> {
513        let bytecode = self.to_bytecode()?;
514        std::fs::write(path, bytecode)?;
515        Ok(())
516    }
517
518    // ========================================================================
519    // Private helpers
520    // ========================================================================
521
522    /// Formats bytecode as human-readable assembly.
523    fn format_assembly(version: &str, bytecode: &[Instr], table: &SymTable) -> String {
524        use std::fmt::Write as _;
525
526        let mut out = String::new();
527        out += &format!("; VERSION {}\n", version)
528            .bright_black()
529            .to_string();
530
531        for (i, instr) in bytecode.iter().enumerate() {
532            let _ = write!(out, "{} ", format!("{:04X}", i).yellow());
533            let line = match instr {
534                Instr::Push(v) => format!("{} {}", "PUSH".magenta(), v.to_string().green()),
535                Instr::Load(idx) => {
536                    let sym_name = table.get_by_index(*idx).map(|s| s.name()).unwrap_or("???");
537                    format!("{} {}", "LOAD".magenta(), sym_name.blue())
538                }
539                Instr::Neg => format!("{}", "NEG".magenta()),
540                Instr::Add => format!("{}", "ADD".magenta()),
541                Instr::Sub => format!("{}", "SUB".magenta()),
542                Instr::Mul => format!("{}", "MUL".magenta()),
543                Instr::Div => format!("{}", "DIV".magenta()),
544                Instr::Pow => format!("{}", "POW".magenta()),
545                Instr::Fact => format!("{}", "FACT".magenta()),
546                Instr::Call(idx, argc) => {
547                    let sym_name = table.get_by_index(*idx).map(|s| s.name()).unwrap_or("???");
548                    format!(
549                        "{} {} args: {}",
550                        "CALL".magenta(),
551                        sym_name.cyan(),
552                        argc.to_string().bright_blue()
553                    )
554                }
555                Instr::Equal => format!("{}", "EQ".magenta()),
556                Instr::NotEqual => format!("{}", "NEQ".magenta()),
557                Instr::Less => format!("{}", "LT".magenta()),
558                Instr::LessEqual => format!("{}", "LTE".magenta()),
559                Instr::Greater => format!("{}", "GT".magenta()),
560                Instr::GreaterEqual => format!("{}", "GTE".magenta()),
561            };
562            let _ = writeln!(out, "{}", line);
563        }
564        out
565    }
566}