1use crate::ir::Instr;
2use crate::symbol::{SymTable, Symbol};
3use bincode::config;
4use colored::Colorize;
5use rust_decimal::Decimal;
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use thiserror::Error;
9
/// Version stamped into every compiled program binary, taken from this crate's
/// Cargo package version at build time. `Program::load` rejects any binary whose
/// stamp is not exactly equal to this string.
const PROGRAM_VERSION: &str = env!("CARGO_PKG_VERSION");
12
/// Errors produced while compiling a [`Program`] to bytes or loading one back.
#[derive(Error, Debug)]
pub enum ProgramError {
    /// Encoding the program to bytes failed; the payload is a descriptive message.
    #[error("Compilation error: {0}")]
    CompileError(String),
    /// The input bytes could not be decoded by bincode.
    #[error("Decoding error: {0}")]
    DecodingError(#[from] bincode::error::DecodeError),
    /// The binary's version stamp differs from this build's; fields are
    /// `(expected, found)`.
    #[error("incompatible program version: expected {0}, got {1}")]
    IncompatibleVersions(String, String),
    /// A symbol name referenced by the binary is missing from the symbol table.
    #[error("Unknown symbol: {0}")]
    UnknownSymbol(String),
    /// A symbol exists but has the wrong kind; fields are `(name, expected kind)`,
    /// e.g. `"constant"` or `"function"`.
    #[error("Symbol '{0}' is not a {1}")]
    SymbolKindMismatch(String, String),
    /// A function's arity is incorrect; the payload is the function name.
    #[error("Function '{0}' incorrect arity")]
    InvalidFuncArity(String),
    /// An instruction in the binary is malformed; the payload names the
    /// instruction (e.g. `"LOAD"`).
    ///
    /// NOTE: the variant name is misspelled ("Corruped"). It is kept as-is because
    /// renaming a public variant would break existing callers.
    #[error("Corrupted instruction: {0}")]
    CorrupedInstruction(String),
}
31
32#[derive(Default)]
37pub struct Program<'sym> {
38 pub version: String,
39 pub code: Vec<Instr<'sym>>,
40}
41
/// Serializable form of a [`Program`], as produced by `Program::compile`.
///
/// Symbol operands are replaced by indices into `symbols`, which holds one entry
/// per distinct symbol name. The value is encoded with bincode's serde support,
/// which writes fields in declaration order. Reordering these fields therefore
/// changes the binary format.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct Binary {
    /// Version stamp, checked against `PROGRAM_VERSION` on load.
    version: String,
    /// Deduplicated symbol table, referenced by index from `code`.
    symbols: Vec<BinarySymbol>,
    /// Instruction stream.
    code: Vec<BinaryInstr>,
}
50
/// A symbol as recorded in a binary: its name plus enough information to check,
/// at load time, that the symbol table still agrees on its kind.
///
/// Variant and field order are part of the encoded format.
#[derive(Debug, Clone, Serialize, Deserialize)]
enum BinarySymbol {
    /// A constant, identified by name.
    Const(String),
    /// A function, identified by name, with the signature it had at compile time.
    Func {
        /// Function name.
        name: String,
        /// `args` copied from the symbol table's `Symbol::Func` at compile time.
        args: usize,
        /// `variadic` copied from the symbol table's `Symbol::Func` at compile time.
        variadic: bool,
    },
}
62
63impl BinarySymbol {
64 fn name(&self) -> String {
65 match self {
66 BinarySymbol::Const(name) => name.clone(),
67 BinarySymbol::Func { name, .. } => name.clone(),
68 }
69 }
70}
71
/// Serializable instruction. It mirrors `Instr`, but symbol operands are
/// indices into `Binary::symbols`.
///
/// Variant order is part of the encoded format: bincode stores each variant by
/// its index. Append new variants at the end.
#[derive(Debug, Clone, Serialize, Deserialize)]
enum BinaryInstr {
    /// Push a decimal literal.
    Push(Decimal),
    /// Load the constant at this index in `Binary::symbols`.
    Load(u32),
    Neg,
    Add,
    Sub,
    Mul,
    Div,
    Pow,
    Fact,
    /// Call the function at this symbol index with the given argument count.
    Call(u32, usize),
    Equal,
    NotEqual,
    Less,
    LessEqual,
    Greater,
    GreaterEqual,
}
91
92impl<'sym> Program<'sym> {
93 pub fn new() -> Self {
95 Self {
96 version: PROGRAM_VERSION.to_string(),
97 code: Vec::new(),
98 }
99 }
100
101 pub fn compile(&self) -> Result<Vec<u8>, ProgramError> {
103 let binary = self.to_binary();
104 let config = config::standard();
105 bincode::serde::encode_to_vec(&binary, config)
106 .map_err(|err| ProgramError::CompileError(format!("failed to encode program: {}", err)))
107 }
108
109 pub fn load(data: &[u8], table: &'sym SymTable) -> Result<Program<'sym>, ProgramError> {
113 let config = config::standard();
114 let (decoded, _): (Binary, usize) = bincode::serde::decode_from_slice(&data, config)
115 .map_err(ProgramError::DecodingError)?;
116
117 Self::validate_version(&decoded.version)?;
118
119 let get_sym = |bin_sym: &BinarySymbol| -> Result<&'sym Symbol, ProgramError> {
120 let name = bin_sym.name();
121 table.get(&name).ok_or(ProgramError::UnknownSymbol(name))
122 };
123
124 let mut program = Program::new();
125 program.version = decoded.version.clone();
126
127 for instr in &decoded.code {
128 match instr {
129 BinaryInstr::Push(v) => {
130 program.code.push(Instr::Push(*v));
131 }
132 BinaryInstr::Load(idx) => {
133 let bin_sym = &decoded.symbols[*idx as usize];
134 let sym = get_sym(&bin_sym)?;
135 match bin_sym {
136 BinarySymbol::Const(_) => {
137 if !matches!(sym, Symbol::Const { .. }) {
138 return Err(ProgramError::SymbolKindMismatch(
139 sym.name().to_string(),
140 "constant".to_string(),
141 ));
142 }
143 }
144 _ => {
145 return Err(ProgramError::CorrupedInstruction("LOAD".to_string()));
146 }
147 }
148 program.code.push(Instr::Load(sym))
149 }
150 BinaryInstr::Neg => program.code.push(Instr::Neg),
151 BinaryInstr::Add => program.code.push(Instr::Add),
152 BinaryInstr::Sub => program.code.push(Instr::Sub),
153 BinaryInstr::Mul => program.code.push(Instr::Mul),
154 BinaryInstr::Div => program.code.push(Instr::Div),
155 BinaryInstr::Pow => program.code.push(Instr::Pow),
156 BinaryInstr::Fact => program.code.push(Instr::Fact),
157 BinaryInstr::Call(idx, argc) => {
158 let bin_sym = &decoded.symbols[*idx as usize];
159 let sym = get_sym(&bin_sym)?;
160 if !matches!(sym, Symbol::Func { .. }) {
161 return Err(ProgramError::SymbolKindMismatch(
162 sym.name().to_string(),
163 "function".to_string(),
164 ));
165 }
166 program.code.push(Instr::Call(sym, *argc));
167 }
168 BinaryInstr::Equal => program.code.push(Instr::Equal),
170 BinaryInstr::NotEqual => program.code.push(Instr::NotEqual),
171 BinaryInstr::Less => program.code.push(Instr::Less),
172 BinaryInstr::LessEqual => program.code.push(Instr::LessEqual),
173 BinaryInstr::Greater => program.code.push(Instr::Greater),
174 BinaryInstr::GreaterEqual => program.code.push(Instr::GreaterEqual),
175 }
176 }
177
178 Ok(program)
179 }
180
181 fn validate_version(version: &String) -> Result<(), ProgramError> {
182 if version != PROGRAM_VERSION {
183 return Err(ProgramError::IncompatibleVersions(
184 PROGRAM_VERSION.to_string(),
185 version.clone(),
186 ));
187 }
188 Ok(())
189 }
190
191 pub fn get_assembly(&self) -> String {
193 use std::fmt::Write as _;
194
195 let mut out = String::new();
196 out += &format!("; VERSION {}\n", self.version)
197 .bright_black()
198 .to_string();
199
200 let emit = |mnemonic: &str| -> String { format!("{}", mnemonic.magenta()) };
201 let emit1 = |mnemonic: &str, op: &str| -> String {
202 format!("{} {}", mnemonic.magenta(), op.green())
203 };
204
205 for (i, instr) in self.code.iter().enumerate() {
206 let _ = write!(out, "{} ", format!("{:04X}", i).yellow());
207 let line = match instr {
208 Instr::Push(v) => emit1("PUSH", &v.to_string().green()),
209 Instr::Load(sym) => emit1("LOAD", &sym.name().blue()),
210 Instr::Neg => emit("NEG"),
211 Instr::Add => emit("ADD"),
212 Instr::Sub => emit("SUB"),
213 Instr::Mul => emit("MUL"),
214 Instr::Div => emit("DIV"),
215 Instr::Pow => emit("POW"),
216 Instr::Fact => emit("FACT"),
217 Instr::Call(sym, argc) => format!(
218 "{} {} args: {}",
219 emit("CALL"),
220 sym.name().cyan(),
221 argc.to_string().bright_blue()
222 ),
223 Instr::Equal => emit("EQ"),
224 Instr::NotEqual => emit("NEQ"),
225 Instr::Less => emit("LT"),
226 Instr::LessEqual => emit("LTE"),
227 Instr::Greater => emit("GT"),
228 Instr::GreaterEqual => emit("GTE"),
229 };
230 let _ = writeln!(out, "{}", line);
231 }
232 out
233 }
234
235 fn to_binary(&self) -> Binary {
236 let mut map: HashMap<String, u32> = HashMap::new();
237 let mut binary = Binary {
238 version: self.version.clone(),
239 symbols: Vec::new(),
240 code: Vec::new(),
241 };
242
243 let mut get_index = |sym: &'sym Symbol| -> u32 {
244 map.get(sym.name()).map(|val| *val).unwrap_or_else(|| {
245 let i = binary.symbols.len() as u32;
246 map.insert(sym.name().to_string(), i);
247 binary.symbols.push(match sym {
248 Symbol::Const { .. } => BinarySymbol::Const(sym.name().to_string()),
249 Symbol::Func { args, variadic, .. } => BinarySymbol::Func {
250 name: sym.name().to_string(),
251 args: *args,
252 variadic: *variadic,
253 },
254 });
255 i
256 })
257 };
258
259 for instr in &self.code {
260 match instr {
261 Instr::Push(v) => {
262 binary.code.push(BinaryInstr::Push(*v));
263 }
264 Instr::Load(sym) => {
265 let idx = get_index(sym);
266 binary.code.push(BinaryInstr::Load(idx));
267 }
268 Instr::Neg => {
269 binary.code.push(BinaryInstr::Neg);
270 }
271 Instr::Add => {
272 binary.code.push(BinaryInstr::Add);
273 }
274 Instr::Sub => {
275 binary.code.push(BinaryInstr::Sub);
276 }
277 Instr::Mul => {
278 binary.code.push(BinaryInstr::Mul);
279 }
280 Instr::Div => {
281 binary.code.push(BinaryInstr::Div);
282 }
283 Instr::Pow => {
284 binary.code.push(BinaryInstr::Pow);
285 }
286 Instr::Fact => {
287 binary.code.push(BinaryInstr::Fact);
288 }
289 Instr::Call(sym, argc) => {
290 let idx = get_index(sym);
291 binary.code.push(BinaryInstr::Call(idx, *argc));
292 }
293 Instr::Equal => {
294 binary.code.push(BinaryInstr::Equal);
295 }
296 Instr::NotEqual => {
297 binary.code.push(BinaryInstr::NotEqual);
298 }
299 Instr::Less => {
300 binary.code.push(BinaryInstr::Less);
301 }
302 Instr::LessEqual => {
303 binary.code.push(BinaryInstr::LessEqual);
304 }
305 Instr::Greater => {
306 binary.code.push(BinaryInstr::Greater);
307 }
308 Instr::GreaterEqual => {
309 binary.code.push(BinaryInstr::GreaterEqual);
310 }
311 }
312 }
313
314 binary
315 }
316}