1use crate::ir::Instr;
2use crate::symbol::{SymTable, Symbol};
3use bincode::config;
4use colored::Colorize;
5use rust_decimal::Decimal;
6use std::collections::HashMap;
7use serde::{Deserialize, Serialize};
8use thiserror::Error;
9
10#[derive(Error, Debug)]
12pub enum ProgramError {
13 #[error("Compilation error: {0}")]
14 CompileError(String),
15 #[error("Decoding error: {0}")]
16 DecodingError(#[from] bincode::error::DecodeError),
17 #[error("incompatible program version: expected {0}, got {1}")]
18 IncompatibleVersions(String, String),
19 #[error("Unknown symbol: {0}")]
20 UnknownSymbol(String),
21 #[error("Symbol '{0}' is not a {1}")]
22 SymbolKindMismatch(String, String),
23 #[error("Function '{0}' incorrect arity")]
24 InvalidFuncArity(String),
25 #[error("Corrupted instruction: {0}")]
26 CorrupedInstruction(String),
27}
28
29#[derive(Default)]
32pub struct Program<'sym> {
33 pub version: String,
34 pub code: Vec<Instr<'sym>>,
35}
36
/// Serialized form of a [`Program`], encoded and decoded with bincode's standard config.
///
/// NOTE: bincode writes struct fields positionally, without names. The field order below
/// is therefore part of the on-disk format; reordering fields breaks existing binaries.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct Binary {
    /// Version string of the compiled program; `Program::load` rejects it unless it equals
    /// the current crate version.
    version: String,
    /// Symbol section: one entry per distinct symbol name referenced by `code`.
    symbols: Vec<BinarySymbol>,
    /// Instruction stream; symbol operands are indices into `symbols`.
    code: Vec<BinaryInstr>,
}
45
/// One entry of the symbol section.
///
/// It records the symbol's name and kind so the symbol can be looked up again by name in a
/// `SymTable` when the program is loaded.
///
/// NOTE: bincode encodes the variant index, so variant order is part of the binary format.
#[derive(Debug, Clone, Serialize, Deserialize)]
enum BinarySymbol {
    /// A constant, identified by name.
    Const(String),
    /// A function, identified by name, with the `args` count and `variadic` flag copied
    /// from the `Symbol::Func` it was created from at compile time.
    Func {
        name: String,
        args: usize,
        variadic: bool,
    },
}
57
58impl BinarySymbol {
59 fn name(&self) -> String {
60 match self {
61 BinarySymbol::Const(name) => name.clone(),
62 BinarySymbol::Func { name, .. } => name.clone(),
63 }
64 }
65}
66
/// Serialized form of an `Instr`.
///
/// Each variant mirrors the `Instr` variant of the same name. The difference is that symbol
/// operands are stored as `u32` indices into `Binary::symbols` instead of references.
///
/// NOTE: bincode encodes the variant index, so variant order is the opcode encoding. Only
/// append new variants; never reorder or remove existing ones.
#[derive(Debug, Clone, Serialize, Deserialize)]
enum BinaryInstr {
    /// Literal operand.
    Push(Decimal),
    /// Index into `Binary::symbols`; `Program::load` requires the entry to be a constant.
    Load(u32),
    Neg,
    Add,
    Sub,
    Mul,
    Div,
    Pow,
    Fact,
    /// `(index into Binary::symbols, argument count)`; the entry is resolved as a function.
    Call(u32, usize),
    Equal,
    NotEqual,
    Less,
    LessEqual,
    Greater,
    GreaterEqual,
}
86
87impl<'sym> Program<'sym> {
88 pub fn new() -> Self {
89 Self {
90 version: env!("CARGO_PKG_VERSION").to_string(),
91 code: Vec::new(),
92 }
93 }
94
95 pub fn compile(&self) -> Result<Vec<u8>, ProgramError> {
96 let binary = self.to_binary();
97 let config = config::standard();
98 bincode::serde::encode_to_vec(&binary, config)
99 .map_err(|err| ProgramError::CompileError(format!("failed to encode program: {}", err)))
100 }
101
102 pub fn load(data: &[u8], table: &'sym SymTable) -> Result<Program<'sym>, ProgramError> {
103 let config = config::standard();
104 let (decoded, _): (Binary, usize) = bincode::serde::decode_from_slice(&data, config)
105 .map_err(ProgramError::DecodingError)?;
106
107 Self::validate_version(&decoded.version)?;
108
109 let get_sym = |bin_sym: &BinarySymbol| -> Result<&'sym Symbol, ProgramError> {
110 let name = bin_sym.name();
111 table.get(&name).ok_or(ProgramError::UnknownSymbol(name))
112 };
113
114 let mut program = Program::new();
115 program.version = decoded.version.clone();
116
117 for instr in &decoded.code {
118 match instr {
119 BinaryInstr::Push(v) => {
120 program.code.push(Instr::Push(*v));
121 }
122 BinaryInstr::Load(idx) => {
123 let bin_sym = &decoded.symbols[*idx as usize];
124 let sym = get_sym(&bin_sym)?;
125 match bin_sym {
126 BinarySymbol::Const(_) => {
127 if !matches!(sym, Symbol::Const { .. }) {
128 return Err(ProgramError::SymbolKindMismatch(
129 sym.name().to_string(),
130 "constant".to_string(),
131 ));
132 }
133 }
134 _ => {
135 return Err(ProgramError::CorrupedInstruction("LOAD".to_string()));
136 }
137 }
138 program.code.push(Instr::Load(sym))
139 }
140 BinaryInstr::Neg => program.code.push(Instr::Neg),
141 BinaryInstr::Add => program.code.push(Instr::Add),
142 BinaryInstr::Sub => program.code.push(Instr::Sub),
143 BinaryInstr::Mul => program.code.push(Instr::Mul),
144 BinaryInstr::Div => program.code.push(Instr::Div),
145 BinaryInstr::Pow => program.code.push(Instr::Pow),
146 BinaryInstr::Fact => program.code.push(Instr::Fact),
147 BinaryInstr::Call(idx, argc) => {
148 let bin_sym = &decoded.symbols[*idx as usize];
149 let sym = get_sym(&bin_sym)?;
150 if !matches!(sym, Symbol::Func { .. }) {
151 return Err(ProgramError::SymbolKindMismatch(
152 sym.name().to_string(),
153 "function".to_string(),
154 ));
155 }
156 program.code.push(Instr::Call(sym, *argc));
157 }
158 BinaryInstr::Equal => program.code.push(Instr::Equal),
160 BinaryInstr::NotEqual => program.code.push(Instr::NotEqual),
161 BinaryInstr::Less => program.code.push(Instr::Less),
162 BinaryInstr::LessEqual => program.code.push(Instr::LessEqual),
163 BinaryInstr::Greater => program.code.push(Instr::Greater),
164 BinaryInstr::GreaterEqual => program.code.push(Instr::GreaterEqual),
165 }
166 }
167
168 Ok(program)
169 }
170
171 fn validate_version(version: &String) -> Result<(), ProgramError> {
172 let current_version = env!("CARGO_PKG_VERSION");
173 if version != current_version {
174 return Err(ProgramError::IncompatibleVersions(
175 current_version.to_string(),
176 version.clone(),
177 ));
178 }
179 Ok(())
180 }
181
182 pub fn get_assembly(&self) -> String {
183 use std::fmt::Write as _;
184
185 let mut out = String::new();
186 out += &format!("; VERSION {}\n", self.version).bright_black().to_string();
187
188 let emit = |mnemonic: &str| -> String { format!("{}", mnemonic.magenta()) };
189 let emit1 = |mnemonic: &str, op: &str| -> String {
190 format!("{} {}", mnemonic.magenta(), op.green())
191 };
192
193 for (i, instr) in self.code.iter().enumerate() {
194 let _ = write!(out, "{} ", format!("{:04X}", i).yellow());
195 let line = match instr {
196 Instr::Push(v) => emit1("PUSH", &v.to_string().green()),
197 Instr::Load(sym) => emit1("LOAD", &sym.name().blue()),
198 Instr::Neg => emit("NEG"),
199 Instr::Add => emit("ADD"),
200 Instr::Sub => emit("SUB"),
201 Instr::Mul => emit("MUL"),
202 Instr::Div => emit("DIV"),
203 Instr::Pow => emit("POW"),
204 Instr::Fact => emit("FACT"),
205 Instr::Call(sym, argc) => format!(
206 "{} {} args: {}",
207 emit("CALL"),
208 sym.name().cyan(),
209 argc.to_string().bright_blue()
210 ),
211 Instr::Equal => emit("EQ"),
212 Instr::NotEqual => emit("NEQ"),
213 Instr::Less => emit("LT"),
214 Instr::LessEqual => emit("LTE"),
215 Instr::Greater => emit("GT"),
216 Instr::GreaterEqual => emit("GTE"),
217 };
218 let _ = writeln!(out, "{}", line);
219 }
220 out
221 }
222
223 fn to_binary(&self) -> Binary {
224 let mut map: HashMap<String, u32> = HashMap::new();
225 let mut binary = Binary {
226 version: self.version.clone(),
227 symbols: Vec::new(),
228 code: Vec::new(),
229 };
230
231 let mut get_index = |sym: &'sym Symbol| -> u32 {
232 map.get(sym.name()).map(|val| *val).unwrap_or_else(|| {
233 let i = binary.symbols.len() as u32;
234 map.insert(sym.name().to_string(), i);
235 binary.symbols.push(match sym {
236 Symbol::Const { .. } => BinarySymbol::Const(sym.name().to_string()),
237 Symbol::Func { args, variadic, .. } => BinarySymbol::Func {
238 name: sym.name().to_string(),
239 args: *args,
240 variadic: *variadic,
241 },
242 });
243 i
244 })
245 };
246
247 for instr in &self.code {
248 match instr {
249 Instr::Push(v) => {
250 binary.code.push(BinaryInstr::Push(*v));
251 }
252 Instr::Load(sym) => {
253 let idx = get_index(sym);
254 binary.code.push(BinaryInstr::Load(idx));
255 }
256 Instr::Neg => {
257 binary.code.push(BinaryInstr::Neg);
258 }
259 Instr::Add => {
260 binary.code.push(BinaryInstr::Add);
261 }
262 Instr::Sub => {
263 binary.code.push(BinaryInstr::Sub);
264 }
265 Instr::Mul => {
266 binary.code.push(BinaryInstr::Mul);
267 }
268 Instr::Div => {
269 binary.code.push(BinaryInstr::Div);
270 }
271 Instr::Pow => {
272 binary.code.push(BinaryInstr::Pow);
273 }
274 Instr::Fact => {
275 binary.code.push(BinaryInstr::Fact);
276 }
277 Instr::Call(sym, argc) => {
278 let idx = get_index(sym);
279 binary.code.push(BinaryInstr::Call(idx, *argc));
280 }
281 Instr::Equal => {
282 binary.code.push(BinaryInstr::Equal);
283 }
284 Instr::NotEqual => {
285 binary.code.push(BinaryInstr::NotEqual);
286 }
287 Instr::Less => {
288 binary.code.push(BinaryInstr::Less);
289 }
290 Instr::LessEqual => {
291 binary.code.push(BinaryInstr::LessEqual);
292 }
293 Instr::Greater => {
294 binary.code.push(BinaryInstr::Greater);
295 }
296 Instr::GreaterEqual => {
297 binary.code.push(BinaryInstr::GreaterEqual);
298 }
299 }
300 }
301
302 binary
303 }
304}