rubbler 0.1.2

Rubbler is a RISC-V assembler written in Rust 🦀. This library was written primarily to embed a simple RISC-V assembler inside a RISC-V CPU test bench written with Verilator.
Documentation
use crate::directives;
use crate::encoder;
use crate::error::error;
use crate::stanalyzer::Context;
use crate::stanalyzer::Section;
use crate::statement::Statement;
use crate::statement::StmtType;

/// Output buffer used while generating machine code.
///
/// Keeps the bytes emitted so far for each stored section (`.text`,
/// `.rodata`, `.data`), the section currently being written, and the
/// analysis context used to resolve symbols and section sizes.
pub struct Code {
    /// Bytes emitted into `.text`.
    text: Vec<u8>,
    /// Bytes emitted into `.rodata`.
    rodata: Vec<u8>,
    /// Bytes emitted into `.data`.
    data: Vec<u8>,
    /// Section that new bytes are currently directed to.
    curr_section: Section,
    /// Results of static analysis (constants, symbol addresses, section sizes).
    ctx: Context,
}

/// How a symbol address should be resolved by [`Code::resolve_sym`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ResolveType {
    /// Return the symbol's absolute address.
    Abs,
    /// Return the symbol's address relative to the current emission position.
    Rel,
}

impl Code {
    /// Creates an empty buffer bound to the analysis context `ctx`.
    ///
    /// Emission starts in the `.text` section.
    pub fn new(ctx: Context) -> Code {
        Code {
            ctx,
            curr_section: Section::Text,
            text: vec![],
            data: vec![],
            rodata: vec![],
        }
    }

    /// Returns the number of bytes emitted so far into the active section.
    ///
    /// # Panics
    /// Panics when the active section is `.bss`, whose contents are not stored here.
    pub fn get_addr(&self) -> usize {
        match self.curr_section {
            Section::Text => self.text.len(),
            Section::Data => self.data.len(),
            Section::ROData => self.rodata.len(),
            Section::BSS => panic!("Cannot request for BSS address."),
        }
    }

    /// Directs subsequent output to `.text`.
    pub fn to_text(&mut self) {
        self.curr_section = Section::Text;
    }

    /// Directs subsequent output to `.data`.
    pub fn to_data(&mut self) {
        self.curr_section = Section::Data;
    }

    /// Directs subsequent output to `.rodata`.
    pub fn to_rodata(&mut self) {
        self.curr_section = Section::ROData;
    }

    /// Directs subsequent output to `.bss`.
    pub fn to_bss(&mut self) {
        self.curr_section = Section::BSS;
    }

    /// Pads the active section with zero bytes until its length is a multiple
    /// of `alignment`.
    ///
    /// Does nothing in `.bss` (no bytes are stored for it) or when `alignment`
    /// is 0 or 1.
    pub fn align_addr(&mut self, alignment: usize) {
        // The `.bss` check must come before `get_addr`, which panics for `.bss`;
        // `alignment == 0` would otherwise panic on the modulo below.
        if matches!(self.curr_section, Section::BSS) || alignment <= 1 {
            return;
        }
        let offset = (alignment - self.get_addr() % alignment) % alignment;
        self.append_bytes(vec![0; offset])
            .expect("appending to a non-.bss section cannot fail");
    }

    /// Appends `bytes` to the active section.
    ///
    /// # Errors
    /// Returns an error when the active section is `.bss`.
    pub fn append_bytes(&mut self, mut bytes: Vec<u8>) -> Result<(), String> {
        match self.curr_section {
            Section::Text => self.text.append(&mut bytes),
            Section::Data => self.data.append(&mut bytes),
            Section::ROData => self.rodata.append(&mut bytes),
            Section::BSS => return Err("Cannot append bytes to .bss section".to_string()),
        };
        Ok(())
    }

    /// Resolves `sym` to a value.
    ///
    /// Constants take precedence over symbols. A symbol address is returned
    /// as-is for [`ResolveType::Abs`], or relative to the current emission
    /// position for [`ResolveType::Rel`]. Returns `None` if `sym` is unknown.
    pub fn resolve_sym(&self, sym: &str, resolve_type: ResolveType) -> Option<i32> {
        if let Some(value) = self.ctx.resolve_const(sym) {
            Some(value)
        } else if let Some(addr) = self.ctx.resolve_sym(sym) {
            match resolve_type {
                ResolveType::Abs => Some(addr as i32),
                ResolveType::Rel => Some(self.get_rel_addr(addr)),
            }
        } else {
            None
        }
    }

    /// Returns `addr` minus the absolute address of the current emission position.
    ///
    /// The absolute position is computed for the final image layout:
    /// `.text`, then `.rodata`, then `.data` (the order used by `get_code`),
    /// followed by `.bss`.
    fn get_rel_addr(&self, addr: usize) -> i32 {
        let here = match self.curr_section {
            Section::Text => self.text.len(),
            Section::ROData => self.ctx.text_size() + self.rodata.len(),
            // The `.data` cursor is the `.data` buffer length (not `.rodata`'s).
            Section::Data => self.ctx.text_size() + self.ctx.rodata_size() + self.data.len(),
            // No per-byte cursor is tracked for `.bss`, so measure from its start.
            Section::BSS => {
                self.ctx.text_size() + self.ctx.rodata_size() + self.ctx.data_size()
            }
        };
        addr as i32 - here as i32
    }

    /// Concatenates `.text`, `.rodata` and `.data` into one image.
    ///
    /// The per-section buffers are drained (left empty) by this call.
    fn get_code(&mut self) -> Vec<u8> {
        let mut aggregate_code = vec![];
        aggregate_code.append(&mut self.text);
        aggregate_code.append(&mut self.rodata);
        aggregate_code.append(&mut self.data);
        aggregate_code
    }
}

/// Drives machine-code generation over a list of analyzed statements.
pub struct Generator {
    /// Output buffer that receives the encoded bytes.
    code: Code,
    /// Statements to encode, in source order.
    stmts: Vec<Statement>,
}

impl Generator {
    /// Creates a generator for `stmts`, using `ctx` to resolve symbols and
    /// section sizes during encoding.
    pub fn new(stmts: Vec<Statement>, ctx: Context) -> Generator {
        Generator {
            stmts,
            code: Code::new(ctx),
        }
    }

    /// Encodes every statement and returns the assembled image.
    ///
    /// Operations go through `encoder::generate_inst_code`. Directives go
    /// through `directives::generate_directive`. Labels emit nothing.
    ///
    /// # Errors
    /// Returns the first encoding error, formatted by `error` with the
    /// offending statement's line number.
    pub fn generate_code(&mut self) -> Result<Vec<u8>, String> {
        // Start from a clean state, so that repeated calls (or a call after a
        // failed one) do not inherit the previous run's section or leftover bytes.
        self.code.to_text();
        self.code.text.clear();
        self.code.rodata.clear();
        self.code.data.clear();

        for stmt in &self.stmts {
            let ln = stmt.get_line_number();
            match stmt.get_type() {
                StmtType::Operation(op, args) => {
                    encoder::generate_inst_code(op, args, &mut self.code)
                        .map_err(|e| error(ln, "Generator error", &e))?;
                }
                StmtType::Directive(dir, args) => {
                    directives::generate_directive(dir, args, &mut self.code)
                        .map_err(|e| error(ln, "Generator error", &e))?;
                }
                StmtType::Label(_) => (),
            }
        }
        Ok(self.code.get_code())
    }
}

#[cfg(test)]
mod test {
    use crate::{
        parser::Parser,
        scanner::Scanner,
        stanalyzer::{Analyzer, Context},
        statement::Statement,
    };

    use super::Generator;

    /// Runs the scanner, parser and analyzer on `source`, panicking on any
    /// front-end error.
    fn source_to_stmts_ctx(source: &str) -> (Vec<Statement>, Context) {
        let scanner = Scanner::new(source.to_string());
        let parser = Parser::new(scanner.scan_tokens().unwrap());
        let analyzer = Analyzer::new(parser.parse().unwrap());
        analyzer.analyze().unwrap()
    }

    /// Assembles `source` end to end and returns the generated image.
    fn assemble(source: &str) -> Vec<u8> {
        let (stmts, ctx) = source_to_stmts_ctx(source);
        let mut generator = Generator::new(stmts, ctx);
        generator.generate_code().unwrap()
    }

    /// Asserts that `source` assembles to exactly one word, `expected_code`.
    fn single_inst_test(source: &str, expected_code: u32) {
        assert_eq!(assemble(source), u32_to_vecu8(expected_code));
    }

    /// Little-endian bytes of one 32-bit instruction word.
    fn u32_to_vecu8(value: u32) -> Vec<u8> {
        value.to_le_bytes().to_vec()
    }

    /// Little-endian bytes of a sequence of 32-bit instruction words.
    fn vecu32_to_vecu8(words: &[u32]) -> Vec<u8> {
        words.iter().flat_map(|word| word.to_le_bytes()).collect()
    }

    #[test]
    fn u_inst() {
        let source = "lui t2, -3";
        let expected_code: u32 = 0b11111111111111111111_00111_0110111;
        single_inst_test(source, expected_code)
    }
    #[test]
    fn op() {
        let source = "add t2, t1, t0";
        let expected_code: u32 = 0b0000000_00101_00110_000_00111_0110011;
        single_inst_test(source, expected_code)
    }
    #[test]
    fn op_imm() {
        let source = "addi t2, t1, -3";
        let expected_code: u32 = 0b111111111101_00110_000_00111_0010011;
        single_inst_test(source, expected_code)
    }

    #[test]
    fn jal() {
        let source = "jal t2, -3";
        let expected_code: u32 = 0b1_1111111110_1_11111111_00111_1101111;
        single_inst_test(source, expected_code)
    }

    #[test]
    fn jalr() {
        let source = "jalr t2, t1, -3";
        let expected_code: u32 = 0b111111111101_00110_000_00111_1100111;
        single_inst_test(source, expected_code);
    }
    #[test]
    fn b_inst() {
        let source = "beq t2, t1, -3";
        let expected_code: u32 = 0b1_111111_00110_00111_000_1110_1_1100011;
        single_inst_test(source, expected_code);
    }
    #[test]
    fn load() {
        let source = "lw t2, -3(t1)";
        let expected_code: u32 = 0b111111111101_00110_010_00111_0000011;
        single_inst_test(source, expected_code);
    }
    #[test]
    fn store() {
        let source = "sw t2, -3(t1)";
        let expected_code: u32 = 0b1111111_00111_00110_010_11101_0100011;
        single_inst_test(source, expected_code);
    }
    #[test]
    fn jal_w_label() {
        let source = "
        fail:
        jal t2, success
        jal t2, fail
        success:
        ";
        let code_w_label = assemble(source);
        let source = "
        jal t2, 8
        jal t2, -4
        ";
        let code_wo_label = assemble(source);
        assert_eq!(code_w_label, code_wo_label);
    }
    #[test]
    fn jalr_w_label() {
        let source = "
        jalr t2, t1, success
        .comm fuad, 10, 1
        success:
        ";
        let code_w_label = assemble(source);
        let source = "
        jalr t2, t1, 14
        ";
        let code_wo_label = assemble(source);
        assert_eq!(code_w_label, code_wo_label);
    }
    #[test]
    fn b_w_label() {
        let source = "
        fail:
        beq t2, t1, success
        bne t2, t1, fail
        success:
        ";
        let code_w_label = assemble(source);
        let source = "
        beq t2, t1, 8
        bne t2, t1, -4
        ";
        let code_wo_label = assemble(source);
        assert_eq!(code_w_label, code_wo_label);
    }
    #[test]
    fn align() {
        let source = "
        beq t2, t1, 8
        .align 3
        ";
        let code_w_align = assemble(source);
        let source = "
        beq t2, t1, 8
        ";
        // One 4-byte instruction padded up to an 8-byte boundary.
        let mut code_wo_align = assemble(source);
        code_wo_align.extend_from_slice(&[0; 4]);
        assert_eq!(code_w_align, code_wo_align);
    }
    #[test]
    fn bytes() {
        let source = "
        .byte 1, 2, 3
        ";
        assert_eq!(assemble(source), vec![1, 2, 3]);
    }
    #[test]
    fn all() {
        let source = "
        addi t2, t1, -3
        lui t2, -3
        add t2, t1, t0
        jal t2, -3
        jalr t2, t1, -3
        beq t2, t1, -3
        lw t2, -3(t1)
        sw t2, -3(t1)
        ";
        let code = assemble(source);
        let expected_codes: Vec<u32> = vec![
            0b111111111101_00110_000_00111_0010011,
            0b11111111111111111111_00111_0110111,
            0b0000000_00101_00110_000_00111_0110011,
            0b1_1111111110_1_11111111_00111_1101111,
            0b111111111101_00110_000_00111_1100111,
            0b1_111111_00110_00111_000_1110_1_1100011,
            0b111111111101_00110_010_00111_0000011,
            0b1111111_00111_00110_010_11101_0100011,
        ];
        assert_eq!(code, vecu32_to_vecu8(&expected_codes));
    }
}