Skip to main content

risky/
lib.rs

1#![allow(clippy::unreadable_literal)]
2
3use std::collections::HashSet;
4use std::fmt;
5use std::fmt::Display;
6use std::ops::Range;
7
/// Encoders for the RV32I base instruction set and the RV32M standard extension.
///
/// Every function returns the encoded 32-bit instruction word. Operand validation is done
/// by the crate-private `*_instruction` helpers, which `unwrap()` each check, so an
/// out-of-range operand panics instead of returning an error.
pub mod instructions {
    use crate::b_instruction;
    use crate::csr_instruction;
    use crate::fence_instruction;
    use crate::i_instruction;
    use crate::j_instruction;
    use crate::opcode;
    use crate::parse_fence_mask;
    use crate::r_instruction;
    use crate::s_instruction;
    use crate::u_instruction;

    // Generates one `pub fn` encoder per instruction. The arm selects the instruction
    // format, which fixes the generated function's parameter order:
    //   R: (rd, rs1, rs2)    I: (rd, rs1, imm)    S: (rs1, imm, rs2)
    //   B: (imm, rs1, rs2)   U: (rd, imm)         J: (rd, imm)
    // `stringify!($name)` hands the Rust function name to the helper so that panic
    // messages name the instruction being encoded.
    macro_rules! instruction {
        (R: $name: ident, $opcode: path, funct3 $funct3: literal, funct7 $funct7: literal) => {
            // TODO #[track_caller]
            // (would only help if the helpers were #[track_caller] too, since the
            // `unwrap()` calls that panic live in the helpers, not here)
            pub fn $name(rd: u8, rs1: u8, rs2: u8) -> u32 {
                r_instruction(stringify!($name), $opcode, rd, $funct3, rs1, rs2, $funct7)
            }
        };

        // `imm` must lie in the signed 12-bit range enforced by `check_imm_i_s`.
        (I: $name: ident, $opcode: path, funct3 $funct3: literal) => {
            // TODO #[track_caller]
            pub fn $name(rd: u8, rs1: u8, imm: i16) -> u32 {
                i_instruction(stringify!($name), $opcode, rd, $funct3, rs1, imm)
            }
        };

        // Stores: `rs1` is the base register, `imm` the signed 12-bit offset, `rs2` the
        // register whose value is stored.
        (S: $name: ident, $opcode: path, funct3 $funct3: literal) => {
            // TODO #[track_caller]
            pub fn $name(rs1: u8, imm: i16, rs2: u8) -> u32 {
                s_instruction(stringify!($name), $opcode, imm, $funct3, rs1, rs2)
            }
        };

        // Branches: `imm` is the byte offset; `check_imm_b` requires it to be even and in
        // `-(1 << 12)..(1 << 12)`.
        (B: $name: ident, $opcode: path, funct3 $funct3: literal) => {
            // TODO #[track_caller]
            pub fn $name(imm: i16, rs1: u8, rs2: u8) -> u32 {
                b_instruction(stringify!($name), $opcode, imm, $funct3, rs1, rs2)
            }
        };

        // `imm` is the full 32-bit value, not a pre-shifted 20-bit field: `u_instruction`
        // keeps bits 31..12 and discards the low 12 bits.
        (U: $name: ident, $opcode: path) => {
            // TODO #[track_caller]
            pub fn $name(rd: u8, imm: i32) -> u32 {
                u_instruction(stringify!($name), $opcode, rd, imm)
            }
        };

        // Jumps: `imm` is the byte offset; `check_imm_j` requires it to be even and in
        // `-(1 << 20)..(1 << 20)`.
        (J: $name: ident, $opcode: path) => {
            // TODO #[track_caller]
            pub fn $name(rd: u8, imm: i32) -> u32 {
                j_instruction(stringify!($name), $opcode, rd, imm)
            }
        };
    }

    // RV32I Base Instruction Set

    // Upper-immediate and control-transfer instructions.
    instruction!(U: lui, opcode::LUI);
    instruction!(U: auipc, opcode::AUIPC);
    instruction!(J: jal, opcode::JAL);
    instruction!(I: jalr, opcode::JALR, funct3 0b000);
    instruction!(B: beq, opcode::BRANCH, funct3 0b000);
    instruction!(B: bne, opcode::BRANCH, funct3 0b001);
    instruction!(B: blt, opcode::BRANCH, funct3 0b100);
    instruction!(B: bge, opcode::BRANCH, funct3 0b101);
    instruction!(B: bltu, opcode::BRANCH, funct3 0b110);
    instruction!(B: bgeu, opcode::BRANCH, funct3 0b111);
    // Loads: (rd, base register rs1, offset imm).
    instruction!(I: lb, opcode::LOAD, funct3 0b000);
    instruction!(I: lh, opcode::LOAD, funct3 0b001);
    instruction!(I: lw, opcode::LOAD, funct3 0b010);
    instruction!(I: lbu, opcode::LOAD, funct3 0b100);
    instruction!(I: lhu, opcode::LOAD, funct3 0b101);
    // Stores: (base register rs1, offset imm, source register rs2).
    instruction!(S: sb, opcode::STORE, funct3 0b000);
    instruction!(S: sh, opcode::STORE, funct3 0b001);
    instruction!(S: sw, opcode::STORE, funct3 0b010);
    // Register-immediate arithmetic/logic.
    instruction!(I: addi, opcode::OP_IMM, funct3 0b000);
    instruction!(I: slti, opcode::OP_IMM, funct3 0b010);
    instruction!(I: sltiu, opcode::OP_IMM, funct3 0b011);
    instruction!(I: xori, opcode::OP_IMM, funct3 0b100);
    instruction!(I: ori, opcode::OP_IMM, funct3 0b110);
    instruction!(I: andi, opcode::OP_IMM, funct3 0b111);
    // Register-register arithmetic/logic; `sub` and `sra` differ from `add`/`srl` only in
    // funct7.
    instruction!(R: add, opcode::OP, funct3 0b000, funct7 0b0000000);
    instruction!(R: sub, opcode::OP, funct3 0b000, funct7 0b0100000);
    instruction!(R: sll, opcode::OP, funct3 0b001, funct7 0b0000000);
    instruction!(R: slt, opcode::OP, funct3 0b010, funct7 0b0000000);
    instruction!(R: sltu, opcode::OP, funct3 0b011, funct7 0b0000000);
    instruction!(R: xor, opcode::OP, funct3 0b100, funct7 0b0000000);
    instruction!(R: srl, opcode::OP, funct3 0b101, funct7 0b0000000);
    instruction!(R: sra, opcode::OP, funct3 0b101, funct7 0b0100000);
    instruction!(R: or, opcode::OP, funct3 0b110, funct7 0b0000000);
    instruction!(R: and, opcode::OP, funct3 0b111, funct7 0b0000000);
    // Register CSR instructions: the CSR address is passed as the signed 12-bit `imm`,
    // so addresses 0x800..=0xFFF must be given as their sign-extended (negative) value,
    // e.g. 0xC00 as -1024.
    instruction!(I: csrrw, opcode::SYSTEM, funct3 0b001);
    instruction!(I: csrrs, opcode::SYSTEM, funct3 0b010);
    instruction!(I: csrrc, opcode::SYSTEM, funct3 0b011);

    /// Shift left logical immediate. `shamt` is placed in the rs2 field and validated with
    /// the register range check (`0..32`).
    pub fn slli(rd: u8, rs1: u8, shamt: u8) -> u32 {
        r_instruction("slli", opcode::OP_IMM, rd, 0b001, rs1, shamt, 0b0000000)
    }

    /// Shift right logical immediate. `shamt` is placed in the rs2 field and validated with
    /// the register range check (`0..32`).
    pub fn srli(rd: u8, rs1: u8, shamt: u8) -> u32 {
        r_instruction("srli", opcode::OP_IMM, rd, 0b101, rs1, shamt, 0b0000000)
    }

    /// Shift right arithmetic immediate: same funct3 as `srli`, distinguished by
    /// funct7 = 0b0100000.
    pub fn srai(rd: u8, rs1: u8, shamt: u8) -> u32 {
        r_instruction("srai", opcode::OP_IMM, rd, 0b101, rs1, shamt, 0b0100000)
    }

    /// FENCE with predecessor/successor sets written as flag strings over `i`, `o`, `r`,
    /// `w` (e.g. `fence("rw", "w")`).
    ///
    /// # Panics
    /// Panics if either string contains an unrecognised or repeated flag.
    pub fn fence(pred: &'static str, succ: &'static str) -> u32 {
        let pred = parse_fence_mask(pred).unwrap();
        let succ = parse_fence_mask(succ).unwrap();
        fence_instruction("fence", 0b000, pred, succ)
    }

    /// FENCE.I: calls `fence_instruction` with funct3 = 0b001 and empty pred/succ masks.
    pub fn fence_i() -> u32 {
        fence_instruction("fence_i", 0b001, 0b0000, 0b0000)
    }

    /// ECALL: a SYSTEM instruction with every other field zero.
    pub fn ecall() -> u32 {
        i_instruction(
            "ecall",
            opcode::SYSTEM,
            0b00000,
            0b000,
            0b00000,
            0b0000_0000_0000,
        )
    }

    /// EBREAK: a SYSTEM instruction with imm = 1 and every other field zero.
    pub fn ebreak() -> u32 {
        i_instruction(
            "ebreak",
            opcode::SYSTEM,
            0b00000,
            0b000,
            0b00000,
            0b0000_0000_0001,
        )
    }

    /// CSRRWI. In the RISC-V encoding the `rs1` field carries the 5-bit zero-extended
    /// immediate; `csr` is the CSR address.
    pub fn csrrwi(rd: u8, rs1: u8, csr: u16) -> u32 {
        csr_instruction("csrrwi", rd, 0b101, rs1, csr)
    }

    /// CSRRSI. `rs1` carries the 5-bit immediate; `csr` is the CSR address.
    pub fn csrrsi(rd: u8, rs1: u8, csr: u16) -> u32 {
        csr_instruction("csrrsi", rd, 0b110, rs1, csr)
    }

    /// CSRRCI. `rs1` carries the 5-bit immediate; `csr` is the CSR address.
    pub fn csrrci(rd: u8, rs1: u8, csr: u16) -> u32 {
        csr_instruction("csrrci", rd, 0b111, rs1, csr)
    }

    // RV32M Standard Extension

    // All RV32M instructions share opcode OP and funct7 = 0b0000001.
    instruction!(R: mul, opcode::OP, funct3 0b000, funct7 0b0000001);
    instruction!(R: mulh, opcode::OP, funct3 0b001, funct7 0b0000001);
    instruction!(R: mulhsu, opcode::OP, funct3 0b010, funct7 0b0000001);
    instruction!(R: mulhu, opcode::OP, funct3 0b011, funct7 0b0000001);
    instruction!(R: div, opcode::OP, funct3 0b100, funct7 0b0000001);
    instruction!(R: divu, opcode::OP, funct3 0b101, funct7 0b0000001);
    instruction!(R: rem, opcode::OP, funct3 0b110, funct7 0b0000001);
    instruction!(R: remu, opcode::OP, funct3 0b111, funct7 0b0000001);
}
171
172// Implementation
173
/// Major opcodes (instruction bits 6..0) of the 32-bit RISC-V encoding.
///
/// Each entry below lists bits 6..2 of the opcode. The low two bits are `0b11` for every
/// 32-bit instruction, so `major` appends them.
mod opcode {
    /// Builds a 7-bit major opcode from its bits 6..2.
    const fn major(bits6to2: u8) -> u8 {
        (bits6to2 << 2) | 0b11
    }

    macro_rules! opcodes {
        ($($name:ident = $bits6to2:literal;)+) => {
            $(pub(crate) const $name: u8 = major($bits6to2);)+
        };
    }

    opcodes! {
        LOAD      = 0b00000;
        LOAD_FP   = 0b00001;
        MISC_MEM  = 0b00011;
        OP_IMM    = 0b00100;
        AUIPC     = 0b00101;
        OP_IMM_32 = 0b00110;

        STORE     = 0b01000;
        STORE_FP  = 0b01001;
        AMO       = 0b01011;
        OP        = 0b01100;
        LUI       = 0b01101;
        OP_32     = 0b01110;

        MADD      = 0b10000;
        MSUB      = 0b10001;
        NMSUB     = 0b10010;
        NMADD     = 0b10011;
        OP_FP     = 0b10100;

        BRANCH    = 0b11000;
        JALR      = 0b11001;
        JAL       = 0b11011;
        SYSTEM    = 0b11100;
    }
}
206
/// Result type returned by the validation helpers in this crate. The error is a
/// human-readable message, which the encoders turn into a panic via `unwrap()`.
type Result<T> = std::result::Result<T, String>;
208
209fn fence_instruction(function_name: &'static str, funct3: u8, pred: u8, succ: u8) -> u32 {
210    check_funct3(function_name, funct3).unwrap();
211    check_range(function_name, "pred", pred, 0..1 << 4).unwrap();
212    check_range(function_name, "succ", succ, 0..1 << 4).unwrap();
213    let imm = i16::from(succ | pred << 4);
214    i_instruction(function_name, opcode::MISC_MEM, 0x00000, 0b00, 0x00000, imm)
215}
216
217fn parse_fence_mask(mask_str: &'static str) -> Result<u8> {
218    let mut chars_processed = HashSet::new();
219    let mut mask = 0;
220    for flag_name in mask_str.chars() {
221        if chars_processed.contains(&flag_name) {
222            return Err(fence_flag_error("unknown", flag_name, mask_str));
223        }
224        chars_processed.insert(flag_name);
225        mask |= match flag_name {
226            'i' => 1 << 3,
227            'o' => 1 << 2,
228            'r' => 1 << 1,
229            'w' => 1,
230            _ => return Err(fence_flag_error("invalid", flag_name, mask_str)),
231        };
232    }
233    Ok(mask)
234}
235
/// Builds the error message for a bad flag in a FENCE mask string.
///
/// `kind` is the adjective describing the problem (e.g. `"invalid"`), `flag_name` the
/// offending character and `mask_str` the whole mask being parsed.
fn fence_flag_error(kind: &'static str, flag_name: char, mask_str: &str) -> String {
    format!(
        "{kind} fence flag name `{flag}' in fence mask `{mask}'",
        kind = kind,
        flag = flag_name,
        mask = mask_str,
    )
}
242
243fn csr_instruction(function_name: &'static str, rd: u8, funct3: u8, rs1: u8, csr: u16) -> u32 {
244    i_instruction(function_name, opcode::SYSTEM, rd, funct3, rs1, csr as i16)
245}
246
247fn r_instruction(
248    function_name: &'static str,
249    opcode: u8,
250    rd: u8,
251    funct3: u8,
252    rs1: u8,
253    rs2: u8,
254    funct7: u8,
255) -> u32 {
256    check_opcode(function_name, opcode).unwrap();
257    check_register(function_name, rd).unwrap();
258    check_funct3(function_name, funct3).unwrap();
259    check_register(function_name, rs1).unwrap();
260    check_register(function_name, rs2).unwrap();
261    check_funct7(function_name, funct7).unwrap();
262    u32::from(opcode)
263        | (u32::from(rd) << 7)
264        | (u32::from(funct3) << 12)
265        | (u32::from(rs1) << 15)
266        | (u32::from(rs2) << 20)
267        | (u32::from(funct7) << 25)
268}
269
270fn i_instruction(
271    function_name: &'static str,
272    opcode: u8,
273    rd: u8,
274    funct3: u8,
275    rs1: u8,
276    imm: i16,
277) -> u32 {
278    check_opcode(function_name, opcode).unwrap();
279    check_register(function_name, rd).unwrap();
280    check_funct3(function_name, funct3).unwrap();
281    check_imm_i_s(function_name, imm).unwrap();
282    u32::from(opcode)
283        | (u32::from(rd) << 7)
284        | (u32::from(funct3) << 12)
285        | (u32::from(rs1) << 15)
286        | ((imm as u32) << 20)
287}
288
289fn s_instruction(
290    function_name: &'static str,
291    opcode: u8,
292    imm: i16,
293    funct3: u8,
294    rs1: u8,
295    rs2: u8,
296) -> u32 {
297    check_opcode(function_name, opcode).unwrap();
298    check_imm_i_s(function_name, imm).unwrap();
299    check_funct3(function_name, funct3).unwrap();
300    check_register(function_name, rs1).unwrap();
301    check_register(function_name, rs2).unwrap();
302    let imm_u32 = imm as u32;
303    u32::from(opcode)
304        | ((imm_u32 & 0b11111) << 7)
305        | (u32::from(funct3) << 12)
306        | (u32::from(rs1) << 15)
307        | (u32::from(rs2) << 20)
308        | ((imm_u32 >> 5) & 0b111_1111) << 25
309}
310
311fn b_instruction(
312    function_name: &'static str,
313    opcode: u8,
314    imm: i16,
315    funct3: u8,
316    rs1: u8,
317    rs2: u8,
318) -> u32 {
319    check_opcode(function_name, opcode).unwrap();
320    check_imm_b(function_name, imm).unwrap();
321    check_funct3(function_name, funct3).unwrap();
322    check_register(function_name, rs1).unwrap();
323    check_register(function_name, rs2).unwrap();
324    let imm_u32 = imm as u32;
325    u32::from(opcode)
326        | ((imm_u32 >> 11) & 0b1 << 7)
327        | ((imm_u32 >> 1) & 0b1111 << 8)
328        | (u32::from(funct3) << 12)
329        | (u32::from(rs1) << 15)
330        | (u32::from(rs2) << 20)
331        | ((imm_u32 >> 5) & 0b11_1111 << 25)
332        | ((imm_u32 >> 12) & 0b1 << 31)
333}
334
335fn u_instruction(function_name: &'static str, opcode: u8, rd: u8, imm: i32) -> u32 {
336    check_opcode(function_name, opcode).unwrap();
337    check_register(function_name, rd).unwrap();
338    u32::from(opcode) | (u32::from(rd) << 7) | (imm as u32 & 0xFFFF_F000)
339}
340
341fn j_instruction(function_name: &'static str, opcode: u8, rd: u8, imm: i32) -> u32 {
342    check_opcode(function_name, opcode).unwrap();
343    check_register(function_name, rd).unwrap();
344    check_imm_j(function_name, imm).unwrap();
345    let imm_u32 = imm as u32;
346    let imm_field = (imm_u32 >> 12 & 0b1111_1111) << 12
347        | (imm_u32 >> 11 & 0b1) << 20
348        | (imm_u32 >> 1 & 0b11_1111_1111) << 21
349        | (imm_u32 >> 20 & 0b1) << 31;
350    u32::from(opcode) | (u32::from(rd) << 7) | imm_field
351}
352
353fn check_opcode(function_name: &'static str, opcode: u8) -> Result<()> {
354    check_range(function_name, "opcode", opcode, 0b11..1 << 7)
355}
356
357fn check_register(function_name: &'static str, register: u8) -> Result<()> {
358    check_range(function_name, "register", register, 0..1 << 5)
359}
360
361fn check_imm_i_s(function_name: &'static str, imm: i16) -> Result<()> {
362    check_range(function_name, "imm", imm, -(1 << 11)..1 << 11)
363}
364
365fn check_imm_b(function_name: &'static str, imm: i16) -> Result<()> {
366    check_range(function_name, "imm", imm, -(1 << 12)..1 << 12).and_then(|()| match imm & 1 {
367        0 => Ok(()),
368        _ => Err(format!(
369            "imm = {} (0x{:08x}) is not a multiple of 2",
370            imm, imm
371        )),
372    })
373}
374
375fn check_imm_j(function_name: &'static str, imm: i32) -> Result<()> {
376    check_range(function_name, "imm", imm, -(1 << 20)..1 << 20).and_then(|()| match imm & 1 {
377        0 => Ok(()),
378        _ => Err(format!(
379            "imm = {} (0x{:08x}) is not a multiple of 2",
380            imm, imm
381        )),
382    })
383}
384
385fn check_funct3(function_name: &'static str, funct3: u8) -> Result<()> {
386    check_range(function_name, "funct3", funct3, 0..1 << 3)
387}
388
389fn check_funct7(function_name: &'static str, funct7: u8) -> Result<()> {
390    check_range(function_name, "funct7", funct7, 0..1 << 7)
391}
392
393fn check_range<T>(
394    function_name: &'static str,
395    value_name: &'static str,
396    value: T,
397    range: Range<T>,
398) -> Result<()>
399where
400    T: PartialOrd + Display + fmt::LowerHex,
401{
402    if value >= range.start && value < range.end {
403        Ok(())
404    } else {
405        Err(format!(
406            "{}: {} = {} (0x{:08x}) is out of range: {} .. {} (0x{:x} .. 0x{:x})",
407            function_name, value_name, value, value, range.start, range.end, range.start, range.end
408        ))
409    }
410}