hecate_assembler/
lib.rs

1use std::collections::HashMap;
2
3use hecate_common::{
4    get_pattern, get_pattern_by_mnemonic, Bytecode, BytecodeFile, BytecodeFileHeader,
5    ExpectedOperandType, InstructionPattern, OperandType,
6};
7use indexmap::IndexMap;
8use num_traits::cast::FromPrimitive;
9use thiserror::Error;
10
11#[derive(Error, Debug)]
12pub enum AssemblerError {
13    #[error("Unknown instruction: {0}")]
14    UnknownInstruction(String),
15    #[error("Wrong number of operands for {mnemonic}: expected {expected}, got {got}")]
16    WrongOperandCount {
17        mnemonic: String,
18        expected: usize,
19        got: usize,
20    },
21    #[error("Invalid register name: {0}")]
22    InvalidRegister(String),
23    #[error("Invalid immediate value: {0}")]
24    InvalidImmediate(String),
25    #[error("Invalid entrypoint: {0}")]
26    InvalidEntrypoint(String),
27    #[error("Invalid label: {0}")]
28    InvalidLabel(String),
29    #[error("Undefined label: {0}")]
30    UndefinedLabel(String),
31}
32
33#[derive(Error, Debug)]
34pub enum DisassemblerError {
35    #[error("Invalid opcode: {0:#x}")]
36    InvalidOpcode(u32),
37    #[error("Unexpected end of bytecode")]
38    UnexpectedEnd,
39}
40
/// A single instruction operand after it has been parsed against its expected type.
#[derive(Debug, Clone, PartialEq)]
pub enum ParsedOperand {
    /// Register index, e.g. `r3` parses to `Register(3)`.
    Register(u32),
    /// Signed 32-bit integer immediate.
    ImmediateI32(i32),
    /// 32-bit floating-point immediate.
    ImmediateF32(f32),
    /// Numeric address (written in source with a leading `@`).
    Address(u32),
    /// Label reference that could not yet be resolved to an address.
    Label(String),
}
49
50pub struct Assembler {
51    labels: IndexMap<String, u32>,
52    current_address: u32,
53}
54
55impl Default for Assembler {
56    fn default() -> Self {
57        Self::new()
58    }
59}
60
61impl Assembler {
62    pub fn new() -> Self {
63        Self {
64            labels: IndexMap::new(),
65            current_address: 0,
66        }
67    }
68
69    pub fn parse_register(reg: &str) -> Result<u32, AssemblerError> {
70        if !reg.to_uppercase().starts_with('R') {
71            return Err(AssemblerError::InvalidRegister(reg.to_string()));
72        }
73        reg[1..]
74            .parse::<u32>()
75            .map_err(|_| AssemblerError::InvalidRegister(reg.to_string()))
76    }
77
78    fn parse_operand(
79        &self,
80        operand: &str,
81        expected_type: ExpectedOperandType,
82    ) -> Result<ParsedOperand, AssemblerError> {
83        match expected_type {
84            ExpectedOperandType::Register => {
85                Ok(ParsedOperand::Register(Self::parse_register(operand)?))
86            }
87            ExpectedOperandType::ImmediateI32 => {
88                let value = if let Some(operand) = operand.strip_prefix("0x") {
89                    i32::from_str_radix(operand, 16)
90                } else if let Some(operand) = operand.strip_prefix("b") {
91                    i32::from_str_radix(operand, 2)
92                } else {
93                    operand.parse::<i32>()
94                }
95                .map_err(|_| AssemblerError::InvalidImmediate(operand.to_string()))?;
96                Ok(ParsedOperand::ImmediateI32(value))
97            }
98            ExpectedOperandType::ImmediateF32 => {
99                let value = operand
100                    .parse::<f32>()
101                    .map_err(|_| AssemblerError::InvalidImmediate(operand.to_string()))?;
102                Ok(ParsedOperand::ImmediateF32(value))
103            }
104            ExpectedOperandType::MemoryAddress => {
105                let addr = if let Some(operand) = operand.strip_prefix('@') {
106                    operand
107                        .parse::<u32>()
108                        .map_err(|_| AssemblerError::InvalidImmediate(operand.to_string()))?
109                } else {
110                    return Err(AssemblerError::InvalidImmediate(operand.to_string()));
111                };
112                Ok(ParsedOperand::Address(addr))
113            }
114            ExpectedOperandType::LabelOrAddress => {
115                let label_or_addr = if let Some(operand) = operand.strip_prefix('@') {
116                    operand
117                } else {
118                    return Err(AssemblerError::InvalidImmediate(operand.to_string()));
119                };
120                if let Some(&addr) = self.labels.get(label_or_addr) {
121                    Ok(ParsedOperand::Address(addr))
122                } else if let Ok(addr) = label_or_addr.parse::<u32>() {
123                    Ok(ParsedOperand::Address(addr))
124                } else {
125                    Ok(ParsedOperand::Label(label_or_addr.to_string()))
126                }
127            }
128        }
129    }
130
131    pub fn assemble_line(&mut self, line: &str) -> Result<Vec<u32>, AssemblerError> {
132        let line = line.trim();
133
134        if line.ends_with(':') {
135            return Ok(vec![]);
136        }
137
138        let mut parts = line.split_whitespace();
139        let mnemonic = match parts.next() {
140            Some(m) => m,
141            None => return Ok(vec![]),
142        };
143
144        let operand_str = parts.collect::<Vec<_>>().join("");
145        let operand_strs: Vec<&str> = if operand_str.is_empty() {
146            vec![]
147        } else {
148            operand_str.split(',').map(str::trim).collect()
149        };
150
151        let pattern = self
152            .parse_line(line)
153            .ok_or_else(|| AssemblerError::UnknownInstruction(mnemonic.to_string()))?;
154
155        if operand_strs.len() != pattern.operands.len() {
156            return Err(AssemblerError::WrongOperandCount {
157                mnemonic: mnemonic.to_string(),
158                expected: pattern.operands.len(),
159                got: operand_strs.len(),
160            });
161        }
162
163        let mut result = vec![pattern.bytecode as u32];
164
165        for (operand_str, &operand_type) in operand_strs.iter().zip(pattern.operands.iter()) {
166            let parsed = self.parse_operand(operand_str, operand_type)?;
167            match parsed {
168                ParsedOperand::Register(reg) => result.push(reg),
169                ParsedOperand::ImmediateI32(imm) => result.push(imm as u32),
170                ParsedOperand::ImmediateF32(imm) => result.push(imm.to_bits()),
171                ParsedOperand::Address(addr) => result.push(addr),
172                ParsedOperand::Label(label) => {
173                    if let Some(&addr) = self.labels.get(&label) {
174                        result.push(addr);
175                    } else {
176                        return Err(AssemblerError::UndefinedLabel(label));
177                    }
178                }
179            }
180        }
181
182        self.current_address += result.len() as u32;
183        Ok(result)
184    }
185
186    fn parse_line(&mut self, line: &str) -> Option<&'static InstructionPattern> {
187        let line = line.split(";").next().unwrap().trim();
188        if line.contains(" ") {
189            let (mnemonic, args) = line.split_once(" ").unwrap();
190            let args = args
191                .split(",")
192                .map(|a| a.trim())
193                .map(|a| {
194                    if a.to_uppercase().starts_with("R") {
195                        Ok(OperandType::Register)
196                    } else if a
197                        .strip_prefix("@")
198                        .map(|a| a.parse::<u32>().is_ok())
199                        .unwrap_or_default()
200                    {
201                        Ok(OperandType::MemoryAddress)
202                    } else if a
203                        .strip_prefix("@")
204                        .map(|a| a.is_ascii())
205                        .unwrap_or_default()
206                    {
207                        Ok(OperandType::Label)
208                    } else if (if let Some(a) = a.strip_prefix("0x") {
209                        i32::from_str_radix(a, 16)
210                    } else if let Some(a) = a.strip_prefix("b") {
211                        i32::from_str_radix(a, 2)
212                    } else {
213                        a.parse::<i32>()
214                    })
215                    .is_ok()
216                    {
217                        Ok(OperandType::ImmediateI32)
218                    } else if a.parse::<f32>().is_ok() {
219                        Ok(OperandType::ImmediateF32)
220                    } else {
221                        Err(format!("Invalid operand! {a}"))
222                    }
223                })
224                .collect::<Result<Vec<_>, _>>()
225                .unwrap();
226            get_pattern_by_mnemonic(mnemonic, &args)
227        } else {
228            get_pattern_by_mnemonic(line, &[])
229        }
230    }
231
232    pub fn assemble_program(&mut self, program: &str) -> Result<BytecodeFile, AssemblerError> {
233        let mut settings = HashMap::new();
234
235        // First pass: collect labels
236        for line in program.lines() {
237            let line = line.trim();
238            if line.is_empty() && line.starts_with(";") {
239                continue;
240            }
241            let line = if line.contains(";") {
242                line.split_once(";").unwrap().0.trim()
243            } else {
244                line
245            };
246            if line.starts_with(".") {
247                let (name, value) = line.split_once(" ").unwrap();
248                settings.insert(&name[1..], value);
249            } else if line.ends_with(':') {
250                let label = &line[..line.trim().len() - 1];
251                self.labels.insert(label.to_string(), self.current_address);
252            } else if let Some(p) = self.parse_line(line) {
253                self.current_address += p.operands.len() as u32 + 1;
254            }
255        }
256
257        // Reset for second pass
258        self.current_address = 0;
259        let mut bytecode = Vec::new();
260
261        // Second pass: generate bytecode
262        for line in program.lines() {
263            let line = line.trim();
264            if line.starts_with(";") {
265                continue;
266            }
267            if line.starts_with(".") {
268                continue;
269            }
270            let line = if line.contains(";") {
271                line.split_once(";").unwrap().0.trim()
272            } else {
273                line
274            };
275            let mut line_code = self.assemble_line(line)?;
276            bytecode.append(&mut line_code);
277        }
278
279        let entry = if let Some(entry) = settings.get("entry") {
280            if let Some(entry) = entry.strip_prefix("@") {
281                let value = if let Some(entry) = entry.strip_prefix("0x") {
282                    u32::from_str_radix(entry, 16)
283                } else if let Some(entry) = entry.strip_prefix("b") {
284                    u32::from_str_radix(entry, 2)
285                } else {
286                    entry.parse::<u32>()
287                }
288                .map_err(|_| AssemblerError::InvalidEntrypoint(entry.to_string()))?;
289                Ok(value)
290            } else {
291                Err(*entry)
292            }
293        } else {
294            Err("main")
295        };
296
297        Ok(BytecodeFile {
298            header: BytecodeFileHeader {
299                labels: self.labels.clone(),
300                entrypoint: entry
301                    .unwrap_or_else(|label| self.labels.get(label).copied().unwrap_or_default()),
302            },
303            data: bytecode,
304        })
305    }
306}
307
308pub struct Disassembler {
309    labels: IndexMap<u32, String>,
310}
311
312impl Default for Disassembler {
313    fn default() -> Self {
314        Self::new()
315    }
316}
317
318impl Disassembler {
319    pub fn new() -> Self {
320        Self {
321            labels: IndexMap::new(),
322        }
323    }
324
325    pub fn from_bytecode_file(file: &BytecodeFile) -> Self {
326        let reverse_labels: IndexMap<u32, String> = file
327            .header
328            .labels
329            .iter()
330            .map(|(name, &addr)| (addr, name.clone()))
331            .collect();
332        Self {
333            labels: reverse_labels,
334        }
335    }
336
337    fn format_operand(&self, value: u32, typ: ExpectedOperandType) -> String {
338        match typ {
339            ExpectedOperandType::Register => format!("R{}", value),
340            ExpectedOperandType::ImmediateI32 => format!("{}", value as i32),
341            ExpectedOperandType::ImmediateF32 => format!("{}", f32::from_bits(value)),
342            ExpectedOperandType::MemoryAddress => format!("@{}", value),
343            ExpectedOperandType::LabelOrAddress => {
344                if let Some(label) = self.labels.get(&value) {
345                    format!("@{}", label)
346                } else {
347                    format!("@{}", value)
348                }
349            }
350        }
351    }
352
353    pub fn disassemble_instruction(
354        &self,
355        bytecode: &[u32],
356    ) -> Result<(String, usize), DisassemblerError> {
357        if bytecode.is_empty() {
358            return Err(DisassemblerError::UnexpectedEnd);
359        }
360
361        let opcode = bytecode[0];
362        let bytecode_enum =
363            Bytecode::from_u32(opcode).ok_or(DisassemblerError::InvalidOpcode(opcode))?;
364
365        let pattern = get_pattern(bytecode_enum).ok_or(DisassemblerError::InvalidOpcode(opcode))?;
366
367        let mut result = pattern.mnemonic.to_string();
368
369        if !pattern.operands.is_empty() {
370            result.push(' ');
371            let operands: Vec<String> = pattern
372                .operands
373                .iter()
374                .enumerate()
375                .map(|(i, &operand_type)| {
376                    if i + 1 >= bytecode.len() {
377                        return Err(DisassemblerError::UnexpectedEnd);
378                    }
379                    Ok(self.format_operand(bytecode[i + 1], operand_type))
380                })
381                .collect::<Result<_, _>>()?;
382            result.push_str(&operands.join(", "));
383        }
384
385        Ok((result, 1 + pattern.operands.len()))
386    }
387
388    pub fn disassemble_program(&self, bytecode: &[u32]) -> Result<String, DisassemblerError> {
389        let mut result = String::new();
390        let mut offset = 0;
391
392        while offset < bytecode.len() {
393            // Add label if this address has one
394            if let Some(label) = self.labels.get(&(offset as u32)) {
395                result.push_str(&format!("{}:\n", label));
396            }
397
398            let (instruction, size) = self.disassemble_instruction(&bytecode[offset..])?;
399            result.push_str(&format!("    {}\n", instruction));
400            offset += size;
401        }
402
403        Ok(result)
404    }
405}
406
#[cfg(test)]
mod tests {
    use hecate_common::Bytecode;

    use super::*;

    /// Assembles `source` with a fresh [`Assembler`], failing the test on any error.
    fn assemble(source: &str) -> BytecodeFile {
        Assembler::new()
            .assemble_program(source)
            .expect("test program should assemble")
    }

    /// Disassembles `file` using its own label table.
    fn disassemble(file: &BytecodeFile) -> String {
        Disassembler::from_bytecode_file(file)
            .disassemble_program(&file.data)
            .expect("assembled bytecode should disassemble")
    }

    #[test]
    fn test_simple_assembly() {
        let file = assemble("start:\nload r1, 42\nadd r1, 10\njmp @start");
        let labels = &file.header.labels;
        assert!(labels.contains_key("start"));
        assert_eq!(labels["start"], 0);
    }

    #[test]
    fn test_simple_disassembly() {
        let words = [
            Bytecode::LoadValue as u32,
            1,
            42,
            Bytecode::AddValue as u32,
            1,
            10,
        ];
        let listing = Disassembler::new()
            .disassemble_program(&words)
            .unwrap()
            .to_lowercase();
        assert!(listing.contains("load r1, 42"));
        assert!(listing.contains("add r1, 10"));
    }

    #[test]
    fn test_memory_addressing() {
        let file = assemble("load r1, @1234\nstore @1234, r1");
        let listing = disassemble(&file).to_lowercase();
        assert!(listing.contains("load r1, @1234"));
        assert!(listing.contains("store @1234, r1"));
    }

    #[test]
    fn test_roundtrip() {
        let file = assemble("start:\nload r1, 42\nadd r1, 10\n jmp @start\n");
        let expected = "start:\n    load r1, 42\n    add r1, 10\n    jmp @start\n";
        assert_eq!(disassemble(&file).to_uppercase(), expected.to_uppercase());
    }
}
464}