1use std::collections::HashMap;
2
3use hecate_common::{
4 get_pattern, get_pattern_by_mnemonic, Bytecode, BytecodeFile, BytecodeFileHeader,
5 ExpectedOperandType, InstructionPattern, OperandType,
6};
7use indexmap::IndexMap;
8use num_traits::cast::FromPrimitive;
9use thiserror::Error;
10
11#[derive(Error, Debug)]
12pub enum AssemblerError {
13 #[error("Unknown instruction: {0}")]
14 UnknownInstruction(String),
15 #[error("Wrong number of operands for {mnemonic}: expected {expected}, got {got}")]
16 WrongOperandCount {
17 mnemonic: String,
18 expected: usize,
19 got: usize,
20 },
21 #[error("Invalid register name: {0}")]
22 InvalidRegister(String),
23 #[error("Invalid immediate value: {0}")]
24 InvalidImmediate(String),
25 #[error("Invalid entrypoint: {0}")]
26 InvalidEntrypoint(String),
27 #[error("Invalid label: {0}")]
28 InvalidLabel(String),
29 #[error("Undefined label: {0}")]
30 UndefinedLabel(String),
31}
32
33#[derive(Error, Debug)]
34pub enum DisassemblerError {
35 #[error("Invalid opcode: {0:#x}")]
36 InvalidOpcode(u32),
37 #[error("Unexpected end of bytecode")]
38 UnexpectedEnd,
39}
40
/// A single source operand after parsing, before it is encoded as a bytecode word.
///
/// `PartialEq` is derived so parsed operands can be compared directly (e.g. in tests);
/// `Eq` is intentionally absent because `ImmediateF32` holds an `f32`.
#[derive(Debug, Clone, PartialEq)]
pub enum ParsedOperand {
    /// Register index taken from `R<n>` / `r<n>`.
    Register(u32),
    /// Signed 32-bit immediate.
    ImmediateI32(i32),
    /// 32-bit floating-point immediate.
    ImmediateF32(f32),
    /// Absolute address, from `@<n>` or from a label that resolved to an address.
    Address(u32),
    /// Label name that could not be resolved to an address.
    Label(String),
}
49
50pub struct Assembler {
51 labels: IndexMap<String, u32>,
52 current_address: u32,
53}
54
55impl Default for Assembler {
56 fn default() -> Self {
57 Self::new()
58 }
59}
60
61impl Assembler {
62 pub fn new() -> Self {
63 Self {
64 labels: IndexMap::new(),
65 current_address: 0,
66 }
67 }
68
69 pub fn parse_register(reg: &str) -> Result<u32, AssemblerError> {
70 if !reg.to_uppercase().starts_with('R') {
71 return Err(AssemblerError::InvalidRegister(reg.to_string()));
72 }
73 reg[1..]
74 .parse::<u32>()
75 .map_err(|_| AssemblerError::InvalidRegister(reg.to_string()))
76 }
77
78 fn parse_operand(
79 &self,
80 operand: &str,
81 expected_type: ExpectedOperandType,
82 ) -> Result<ParsedOperand, AssemblerError> {
83 match expected_type {
84 ExpectedOperandType::Register => {
85 Ok(ParsedOperand::Register(Self::parse_register(operand)?))
86 }
87 ExpectedOperandType::ImmediateI32 => {
88 let value = if let Some(operand) = operand.strip_prefix("0x") {
89 i32::from_str_radix(operand, 16)
90 } else if let Some(operand) = operand.strip_prefix("b") {
91 i32::from_str_radix(operand, 2)
92 } else {
93 operand.parse::<i32>()
94 }
95 .map_err(|_| AssemblerError::InvalidImmediate(operand.to_string()))?;
96 Ok(ParsedOperand::ImmediateI32(value))
97 }
98 ExpectedOperandType::ImmediateF32 => {
99 let value = operand
100 .parse::<f32>()
101 .map_err(|_| AssemblerError::InvalidImmediate(operand.to_string()))?;
102 Ok(ParsedOperand::ImmediateF32(value))
103 }
104 ExpectedOperandType::MemoryAddress => {
105 let addr = if let Some(operand) = operand.strip_prefix('@') {
106 operand
107 .parse::<u32>()
108 .map_err(|_| AssemblerError::InvalidImmediate(operand.to_string()))?
109 } else {
110 return Err(AssemblerError::InvalidImmediate(operand.to_string()));
111 };
112 Ok(ParsedOperand::Address(addr))
113 }
114 ExpectedOperandType::LabelOrAddress => {
115 let label_or_addr = if let Some(operand) = operand.strip_prefix('@') {
116 operand
117 } else {
118 return Err(AssemblerError::InvalidImmediate(operand.to_string()));
119 };
120 if let Some(&addr) = self.labels.get(label_or_addr) {
121 Ok(ParsedOperand::Address(addr))
122 } else if let Ok(addr) = label_or_addr.parse::<u32>() {
123 Ok(ParsedOperand::Address(addr))
124 } else {
125 Ok(ParsedOperand::Label(label_or_addr.to_string()))
126 }
127 }
128 }
129 }
130
131 pub fn assemble_line(&mut self, line: &str) -> Result<Vec<u32>, AssemblerError> {
132 let line = line.trim();
133
134 if line.ends_with(':') {
135 return Ok(vec![]);
136 }
137
138 let mut parts = line.split_whitespace();
139 let mnemonic = match parts.next() {
140 Some(m) => m,
141 None => return Ok(vec![]),
142 };
143
144 let operand_str = parts.collect::<Vec<_>>().join("");
145 let operand_strs: Vec<&str> = if operand_str.is_empty() {
146 vec![]
147 } else {
148 operand_str.split(',').map(str::trim).collect()
149 };
150
151 let pattern = self
152 .parse_line(line)
153 .ok_or_else(|| AssemblerError::UnknownInstruction(mnemonic.to_string()))?;
154
155 if operand_strs.len() != pattern.operands.len() {
156 return Err(AssemblerError::WrongOperandCount {
157 mnemonic: mnemonic.to_string(),
158 expected: pattern.operands.len(),
159 got: operand_strs.len(),
160 });
161 }
162
163 let mut result = vec![pattern.bytecode as u32];
164
165 for (operand_str, &operand_type) in operand_strs.iter().zip(pattern.operands.iter()) {
166 let parsed = self.parse_operand(operand_str, operand_type)?;
167 match parsed {
168 ParsedOperand::Register(reg) => result.push(reg),
169 ParsedOperand::ImmediateI32(imm) => result.push(imm as u32),
170 ParsedOperand::ImmediateF32(imm) => result.push(imm.to_bits()),
171 ParsedOperand::Address(addr) => result.push(addr),
172 ParsedOperand::Label(label) => {
173 if let Some(&addr) = self.labels.get(&label) {
174 result.push(addr);
175 } else {
176 return Err(AssemblerError::UndefinedLabel(label));
177 }
178 }
179 }
180 }
181
182 self.current_address += result.len() as u32;
183 Ok(result)
184 }
185
186 fn parse_line(&mut self, line: &str) -> Option<&'static InstructionPattern> {
187 let line = line.split(";").next().unwrap().trim();
188 if line.contains(" ") {
189 let (mnemonic, args) = line.split_once(" ").unwrap();
190 let args = args
191 .split(",")
192 .map(|a| a.trim())
193 .map(|a| {
194 if a.to_uppercase().starts_with("R") {
195 Ok(OperandType::Register)
196 } else if a
197 .strip_prefix("@")
198 .map(|a| a.parse::<u32>().is_ok())
199 .unwrap_or_default()
200 {
201 Ok(OperandType::MemoryAddress)
202 } else if a
203 .strip_prefix("@")
204 .map(|a| a.is_ascii())
205 .unwrap_or_default()
206 {
207 Ok(OperandType::Label)
208 } else if (if let Some(a) = a.strip_prefix("0x") {
209 i32::from_str_radix(a, 16)
210 } else if let Some(a) = a.strip_prefix("b") {
211 i32::from_str_radix(a, 2)
212 } else {
213 a.parse::<i32>()
214 })
215 .is_ok()
216 {
217 Ok(OperandType::ImmediateI32)
218 } else if a.parse::<f32>().is_ok() {
219 Ok(OperandType::ImmediateF32)
220 } else {
221 Err(format!("Invalid operand! {a}"))
222 }
223 })
224 .collect::<Result<Vec<_>, _>>()
225 .unwrap();
226 get_pattern_by_mnemonic(mnemonic, &args)
227 } else {
228 get_pattern_by_mnemonic(line, &[])
229 }
230 }
231
232 pub fn assemble_program(&mut self, program: &str) -> Result<BytecodeFile, AssemblerError> {
233 let mut settings = HashMap::new();
234
235 for line in program.lines() {
237 let line = line.trim();
238 if line.is_empty() && line.starts_with(";") {
239 continue;
240 }
241 let line = if line.contains(";") {
242 line.split_once(";").unwrap().0.trim()
243 } else {
244 line
245 };
246 if line.starts_with(".") {
247 let (name, value) = line.split_once(" ").unwrap();
248 settings.insert(&name[1..], value);
249 } else if line.ends_with(':') {
250 let label = &line[..line.trim().len() - 1];
251 self.labels.insert(label.to_string(), self.current_address);
252 } else if let Some(p) = self.parse_line(line) {
253 self.current_address += p.operands.len() as u32 + 1;
254 }
255 }
256
257 self.current_address = 0;
259 let mut bytecode = Vec::new();
260
261 for line in program.lines() {
263 let line = line.trim();
264 if line.starts_with(";") {
265 continue;
266 }
267 if line.starts_with(".") {
268 continue;
269 }
270 let line = if line.contains(";") {
271 line.split_once(";").unwrap().0.trim()
272 } else {
273 line
274 };
275 let mut line_code = self.assemble_line(line)?;
276 bytecode.append(&mut line_code);
277 }
278
279 let entry = if let Some(entry) = settings.get("entry") {
280 if let Some(entry) = entry.strip_prefix("@") {
281 let value = if let Some(entry) = entry.strip_prefix("0x") {
282 u32::from_str_radix(entry, 16)
283 } else if let Some(entry) = entry.strip_prefix("b") {
284 u32::from_str_radix(entry, 2)
285 } else {
286 entry.parse::<u32>()
287 }
288 .map_err(|_| AssemblerError::InvalidEntrypoint(entry.to_string()))?;
289 Ok(value)
290 } else {
291 Err(*entry)
292 }
293 } else {
294 Err("main")
295 };
296
297 Ok(BytecodeFile {
298 header: BytecodeFileHeader {
299 labels: self.labels.clone(),
300 entrypoint: entry
301 .unwrap_or_else(|label| self.labels.get(label).copied().unwrap_or_default()),
302 },
303 data: bytecode,
304 })
305 }
306}
307
308pub struct Disassembler {
309 labels: IndexMap<u32, String>,
310}
311
312impl Default for Disassembler {
313 fn default() -> Self {
314 Self::new()
315 }
316}
317
318impl Disassembler {
319 pub fn new() -> Self {
320 Self {
321 labels: IndexMap::new(),
322 }
323 }
324
325 pub fn from_bytecode_file(file: &BytecodeFile) -> Self {
326 let reverse_labels: IndexMap<u32, String> = file
327 .header
328 .labels
329 .iter()
330 .map(|(name, &addr)| (addr, name.clone()))
331 .collect();
332 Self {
333 labels: reverse_labels,
334 }
335 }
336
337 fn format_operand(&self, value: u32, typ: ExpectedOperandType) -> String {
338 match typ {
339 ExpectedOperandType::Register => format!("R{}", value),
340 ExpectedOperandType::ImmediateI32 => format!("{}", value as i32),
341 ExpectedOperandType::ImmediateF32 => format!("{}", f32::from_bits(value)),
342 ExpectedOperandType::MemoryAddress => format!("@{}", value),
343 ExpectedOperandType::LabelOrAddress => {
344 if let Some(label) = self.labels.get(&value) {
345 format!("@{}", label)
346 } else {
347 format!("@{}", value)
348 }
349 }
350 }
351 }
352
353 pub fn disassemble_instruction(
354 &self,
355 bytecode: &[u32],
356 ) -> Result<(String, usize), DisassemblerError> {
357 if bytecode.is_empty() {
358 return Err(DisassemblerError::UnexpectedEnd);
359 }
360
361 let opcode = bytecode[0];
362 let bytecode_enum =
363 Bytecode::from_u32(opcode).ok_or(DisassemblerError::InvalidOpcode(opcode))?;
364
365 let pattern = get_pattern(bytecode_enum).ok_or(DisassemblerError::InvalidOpcode(opcode))?;
366
367 let mut result = pattern.mnemonic.to_string();
368
369 if !pattern.operands.is_empty() {
370 result.push(' ');
371 let operands: Vec<String> = pattern
372 .operands
373 .iter()
374 .enumerate()
375 .map(|(i, &operand_type)| {
376 if i + 1 >= bytecode.len() {
377 return Err(DisassemblerError::UnexpectedEnd);
378 }
379 Ok(self.format_operand(bytecode[i + 1], operand_type))
380 })
381 .collect::<Result<_, _>>()?;
382 result.push_str(&operands.join(", "));
383 }
384
385 Ok((result, 1 + pattern.operands.len()))
386 }
387
388 pub fn disassemble_program(&self, bytecode: &[u32]) -> Result<String, DisassemblerError> {
389 let mut result = String::new();
390 let mut offset = 0;
391
392 while offset < bytecode.len() {
393 if let Some(label) = self.labels.get(&(offset as u32)) {
395 result.push_str(&format!("{}:\n", label));
396 }
397
398 let (instruction, size) = self.disassemble_instruction(&bytecode[offset..])?;
399 result.push_str(&format!(" {}\n", instruction));
400 offset += size;
401 }
402
403 Ok(result)
404 }
405}
406
#[cfg(test)]
mod tests {
    use hecate_common::Bytecode;

    use super::*;

    /// Assembles `source` with a fresh assembler, panicking on any assembly error.
    fn assemble(source: &str) -> BytecodeFile {
        Assembler::new()
            .assemble_program(source)
            .expect("test program should assemble")
    }

    /// Disassembles `file`'s data section using the file's own label table.
    fn disassemble(file: &BytecodeFile) -> String {
        Disassembler::from_bytecode_file(file)
            .disassemble_program(&file.data)
            .expect("assembled bytecode should disassemble")
    }

    #[test]
    fn test_simple_assembly() {
        let file = assemble("start:\nload r1, 42\nadd r1, 10\njmp @start");
        assert_eq!(file.header.labels.get("start"), Some(&0));
    }

    #[test]
    fn test_simple_disassembly() {
        let words = [
            Bytecode::LoadValue as u32,
            1,
            42,
            Bytecode::AddValue as u32,
            1,
            10,
        ];
        let listing = Disassembler::new()
            .disassemble_program(&words)
            .unwrap()
            .to_lowercase();
        assert!(listing.contains("load r1, 42"));
        assert!(listing.contains("add r1, 10"));
    }

    #[test]
    fn test_memory_addressing() {
        let file = assemble("load r1, @1234\nstore @1234, r1");
        let listing = disassemble(&file).to_lowercase();
        assert!(listing.contains("load r1, @1234"));
        assert!(listing.contains("store @1234, r1"));
    }

    #[test]
    fn test_roundtrip() {
        let source = "start:\nload r1, 42\nadd r1, 10\n jmp @start\n";
        let listing = disassemble(&assemble(source));
        let expected = "start:\n load r1, 42\n add r1, 10\n jmp @start\n";
        assert_eq!(listing.to_uppercase(), expected.to_uppercase());
    }
}