lc3_zkvm/
instruction.rs

1//! Module `instruction`
2//!
3//! This module contains functions to execute various instructions. These instructions include arithmetic operations, logical operations, branching, jumping, loading, storing, and system calls.
4//!
5//! # Usage
6//!
7//! The module provides an `execute` function that executes the corresponding operation based on the given raw instruction, register file, and memory state.
8//!
9//! ```rust
10//! use lc3_zkvm::memory::Memory;
11//! use lc3_zkvm::register::RegisterFile;
12//! use lc3_zkvm::instruction::execute;
13//!
14//! let raw_instruction: u16 = 0x1234;
15//! let mut registers = RegisterFile::new();
16//! let mut memory = Memory::new();
17//!
18//! match execute(raw_instruction, &mut registers, &mut memory) {
19//!     Ok(_) => println!("Instruction executed successfully"),
20//!     Err(e) => println!("Instruction execution failed: {}", e),
21//! }
22//! ```
23//!
24//! # Instruction List
25//!
26//! - `OP_ADD`: Addition operation
27//! - `OP_AND`: Bitwise AND operation
28//! - `OP_NOT`: Bitwise NOT operation
29//! - `OP_BR`: Conditional branch
30//! - `OP_JMP`: Unconditional jump
31//! - `OP_JSR`: Jump to subroutine
32//! - `OP_LD`: Load data
33//! - `OP_LDI`: Indirect load data
//! - `OP_LDR`: Load from memory at the address in a base register plus a 6-bit offset
35//! - `OP_LEA`: Load effective address
36//! - `OP_ST`: Store data
37//! - `OP_STI`: Indirect store data
//! - `OP_STR`: Store to memory at the address in a base register plus a 6-bit offset
39//! - `OP_TRAP`: System call
40//!
41//! # Error Handling
42//!
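//! If the opcode is unrecognized, reserved (`OP_RES`), or not implemented (`OP_RTI`), or a TRAP uses an unknown vector, `execute` returns a static error message. A `HALT` trap (vector 0x25) is also reported through the error channel, as `Err("HALT")`, so callers must treat that value as a normal stop signal rather than a failure.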
44//!
45//! # Helper Functions
46//!
47//! - `sign_extend`: Sign-extend a value
48//! - `Register::from`: Convert `u16` to `Register` enum
49//!
50//! # Example
51//!
52//! ```rust
53//! use lc3_zkvm::memory::Memory;
54//! use lc3_zkvm::register::RegisterFile;
55//! use lc3_zkvm::instruction::execute;
56//!
57//! let raw_instruction: u16 = 0x1234;
58//! let mut registers = RegisterFile::new();
59//! let mut memory = Memory::new();
60//!
61//! match execute(raw_instruction, &mut registers, &mut memory) {
62//!     Ok(_) => println!("Instruction executed successfully"),
63//!     Err(e) => println!("Instruction execution failed: {}", e),
64//! }
65//! ```
66
67use crate::memory::Memory;
68use crate::opcode::{extract_opcode, Opcode};
69use crate::register::{condition_flags, Register, RegisterFile};
70use std::io::{self, Read, Write};
71
72pub fn execute(
73    raw: u16,
74    registers: &mut RegisterFile,
75    memory: &mut Memory,
76) -> Result<(), &'static str> {
77    let opcode = extract_opcode(raw).ok_or("Unknown opcode")?;
78    match opcode {
79        Opcode::OP_ADD => execute_add(raw, registers),
80        Opcode::OP_AND => execute_and(raw, registers),
81        Opcode::OP_NOT => execute_not(raw, registers),
82        Opcode::OP_BR => execute_br(raw, registers),
83        Opcode::OP_JMP => execute_jmp(raw, registers),
84        Opcode::OP_JSR => execute_jsr(raw, registers),
85        Opcode::OP_LD => execute_ld(raw, registers, memory),
86        Opcode::OP_LDI => execute_ldi(raw, registers, memory),
87        Opcode::OP_LDR => execute_ldr(raw, registers, memory),
88        Opcode::OP_LEA => execute_lea(raw, registers),
89        Opcode::OP_ST => execute_st(raw, registers, memory),
90        Opcode::OP_STI => execute_sti(raw, registers, memory),
91        Opcode::OP_STR => execute_str(raw, registers, memory),
92        Opcode::OP_TRAP => execute_trap(raw, registers, memory),
93        Opcode::OP_RES => Err("Reserved opcode"),
94        Opcode::OP_RTI => Err("RTI not implemented"),
95    }
96}
97
98/// ADD - Add
99///
100/// Add two values and store the result in a register.
101/// If bit [5] is 0, the second source operand is obtained from SR2.
102/// If bit [5] is 1, the second source operand is obtained by sign-extending the imm5 field to 16 bits.
103fn execute_add(raw: u16, registers: &mut RegisterFile) -> Result<(), &'static str> {
104    let dr = (raw >> 9) & 0x7;
105    let sr1 = (raw >> 6) & 0x7;
106    let mode = (raw >> 5) & 0x1;
107
108    let val1 = registers.read(Register::from(sr1));
109    let val2 = if mode == 0 {
110        let sr2 = raw & 0x7;
111        registers.read(Register::from(sr2))
112    } else {
113        sign_extend(raw & 0x1F, 5)
114    };
115
116    let result = val1.wrapping_add(val2);
117    registers.write(Register::from(dr), result);
118    registers.update_flags(result);
119
120    Ok(())
121}
122
123/// AND - Bitwise AND
124///
125/// Perform bitwise AND on two values and store the result in a register.
126/// If bit [5] is 0, the second source operand is obtained from SR2.
127/// If bit [5] is 1, the second source operand is obtained by sign-extending the imm5 field to 16 bits.
128fn execute_and(raw: u16, registers: &mut RegisterFile) -> Result<(), &'static str> {
129    let dr = (raw >> 9) & 0x7;
130    let sr1 = (raw >> 6) & 0x7;
131    let mode = (raw >> 5) & 0x1;
132
133    let val1 = registers.read(Register::from(sr1));
134    let val2 = if mode == 0 {
135        let sr2 = raw & 0x7;
136        registers.read(Register::from(sr2))
137    } else {
138        sign_extend(raw & 0x1F, 5)
139    };
140
141    let result = val1 & val2;
142    registers.write(Register::from(dr), result);
143    registers.update_flags(result);
144
145    Ok(())
146}
147
148/// NOT - Bitwise NOT
149///
150/// Perform bitwise NOT on a value and store the result in a register.
151fn execute_not(raw: u16, registers: &mut RegisterFile) -> Result<(), &'static str> {
152    let dr = (raw >> 9) & 0x7;
153    let sr = (raw >> 6) & 0x7;
154
155    let val = registers.read(Register::from(sr));
156    let result = !val;
157    registers.write(Register::from(dr), result);
158    registers.update_flags(result);
159
160    Ok(())
161}
162
163/// BR - Branch
164///
165/// Conditional branch based on condition codes (N, Z, P).
166/// If (n AND N) OR (z AND Z) OR (p AND P) is true, the program branches to the address specified by adding the sign-extended PCoffset9 field to the incremented PC.
167fn execute_br(raw: u16, registers: &mut RegisterFile) -> Result<(), &'static str> {
168    let pc = registers.read(Register::PC);
169    let cond = registers.read(Register::COND);
170    let n = (raw >> 11) & 0x1;
171    let z = (raw >> 10) & 0x1;
172    let p = (raw >> 9) & 0x1;
173
174    if (n == 1 && cond & condition_flags::FL_NEG != 0)
175        || (z == 1 && cond & condition_flags::FL_ZRO != 0)
176        || (p == 1 && cond & condition_flags::FL_POS != 0)
177    {
178        let pc_offset = sign_extend(raw & 0x1FF, 9);
179        registers.write(Register::PC, pc.wrapping_add(pc_offset));
180    }
181
182    Ok(())
183}
184
185/// JMP - Jump
186///
187/// Unconditional jump to the address specified by the contents of the base register.
188/// Also used for RET (return from subroutine) when BaseR is R7.
189fn execute_jmp(raw: u16, registers: &mut RegisterFile) -> Result<(), &'static str> {
190    let base_r = (raw >> 6) & 0x7;
191    let base = registers.read(Register::from(base_r));
192    registers.write(Register::PC, base);
193    Ok(())
194}
195
196/// JSR - Jump to Subroutine
197///
198/// Jump to a subroutine, saving the return address in R7.
199/// If bit [11] is 0, the PC is loaded with the contents of the base register (JSRR).
200/// If bit [11] is 1, the PC is loaded with the address specified by adding the sign-extended PCoffset11 field to the incremented PC (JSR).
201fn execute_jsr(raw: u16, registers: &mut RegisterFile) -> Result<(), &'static str> {
202    let pc = registers.read(Register::PC);
203    registers.write(Register::R7, pc);
204
205    if (raw >> 11) & 0x1 == 1 {
206        let pc_offset = sign_extend(raw & 0x7FF, 11);
207        registers.write(Register::PC, pc.wrapping_add(pc_offset));
208    } else {
209        let base_r = (raw >> 6) & 0x7;
210        let base = registers.read(Register::from(base_r));
211        registers.write(Register::PC, base);
212    }
213
214    Ok(())
215}
216
217/// LD - Load
218///
219/// Load a value from memory into a register.
220/// The address is calculated by sign-extending bits [8:0] to 16 bits and adding this value to the incremented PC.
221fn execute_ld(raw: u16, registers: &mut RegisterFile, memory: &Memory) -> Result<(), &'static str> {
222    let dr = (raw >> 9) & 0x7;
223    let pc = registers.read(Register::PC);
224    let pc_offset = sign_extend(raw & 0x1FF, 9);
225    let address = pc.wrapping_add(pc_offset);
226    let val = memory.read(address);
227    registers.write(Register::from(dr), val);
228    registers.update_flags(val);
229    Ok(())
230}
231
232/// LDI - Load Indirect
233///
234/// Load a value from memory into a register using an indirect address.
235/// The address of the address is calculated by sign-extending bits [8:0] to 16 bits and adding this value to the incremented PC.
236fn execute_ldi(
237    raw: u16,
238    registers: &mut RegisterFile,
239    memory: &Memory,
240) -> Result<(), &'static str> {
241    let dr = (raw >> 9) & 0x7;
242    let pc = registers.read(Register::PC);
243    let pc_offset = sign_extend(raw & 0x1FF, 9);
244    let address = pc.wrapping_add(pc_offset);
245    let indirect_address = memory.read(address);
246    let val = memory.read(indirect_address);
247    registers.write(Register::from(dr), val);
248    registers.update_flags(val);
249    Ok(())
250}
251
252/// LDR - Load Register
253///
254/// Load a value from memory into a register.
255/// The address is calculated by sign-extending bits [5:0] to 16 bits and adding this value to the contents of the base register.
256fn execute_ldr(
257    raw: u16,
258    registers: &mut RegisterFile,
259    memory: &Memory,
260) -> Result<(), &'static str> {
261    let dr = (raw >> 9) & 0x7;
262    let base_r = (raw >> 6) & 0x7;
263    let offset = sign_extend(raw & 0x3F, 6);
264    let base = registers.read(Register::from(base_r));
265    let address = base.wrapping_add(offset);
266    let val = memory.read(address);
267    registers.write(Register::from(dr), val);
268    registers.update_flags(val);
269    Ok(())
270}
271
272/// LEA - Load Effective Address
273///
274/// Load a register with an effective address.
275/// The address is calculated by sign-extending bits [8:0] to 16 bits and adding this value to the incremented PC.
276fn execute_lea(raw: u16, registers: &mut RegisterFile) -> Result<(), &'static str> {
277    let dr = (raw >> 9) & 0x7;
278    let pc = registers.read(Register::PC);
279    let pc_offset = sign_extend(raw & 0x1FF, 9);
280    let address = pc.wrapping_add(pc_offset);
281    registers.write(Register::from(dr), address);
282    registers.update_flags(address);
283    Ok(())
284}
285
286/// ST - Store
287///
288/// Store a value from a register into memory.
289/// The address is calculated by sign-extending bits [8:0] to 16 bits and adding this value to the incremented PC.
290fn execute_st(
291    raw: u16,
292    registers: &mut RegisterFile,
293    memory: &mut Memory,
294) -> Result<(), &'static str> {
295    let sr = (raw >> 9) & 0x7;
296    let pc = registers.read(Register::PC);
297    let pc_offset = sign_extend(raw & 0x1FF, 9);
298    let address = pc.wrapping_add(pc_offset);
299    let val = registers.read(Register::from(sr));
300    memory.write(address, val);
301    Ok(())
302}
303
304/// STI - Store Indirect
305///
306/// Store a value from a register into memory using an indirect address.
307/// The address of the address is calculated by sign-extending bits [8:0] to 16 bits and adding this value to the incremented PC.
308fn execute_sti(
309    raw: u16,
310    registers: &mut RegisterFile,
311    memory: &mut Memory,
312) -> Result<(), &'static str> {
313    let sr = (raw >> 9) & 0x7;
314    let pc = registers.read(Register::PC);
315    let pc_offset = sign_extend(raw & 0x1FF, 9);
316    let address = pc.wrapping_add(pc_offset);
317    let indirect_address = memory.read(address);
318    let val = registers.read(Register::from(sr));
319    memory.write(indirect_address, val);
320    Ok(())
321}
322
323/// STR - Store Register
324///
325/// Store a value from a register into memory.
326/// The address is calculated by sign-extending bits [5:0] to 16 bits and adding this value to the contents of the base register.
327fn execute_str(
328    raw: u16,
329    registers: &mut RegisterFile,
330    memory: &mut Memory,
331) -> Result<(), &'static str> {
332    let sr = (raw >> 9) & 0x7;
333    let base_r = (raw >> 6) & 0x7;
334    let offset = sign_extend(raw & 0x3F, 6);
335    let base = registers.read(Register::from(base_r));
336    let address = base.wrapping_add(offset);
337    let val = registers.read(Register::from(sr));
338    memory.write(address, val);
339    Ok(())
340}
341
342/// TRAP - System Call
343///
344/// Perform a system call specified by the trap vector.
345/// The trap vector is specified in bits [7:0] of the instruction.
346fn execute_trap(
347    raw: u16,
348    registers: &mut RegisterFile,
349    memory: &mut Memory,
350) -> Result<(), &'static str> {
351    let trapvect8 = raw & 0xFF;
352    match trapvect8 {
353        0x20 => trap_getc(registers),
354        0x21 => trap_out(registers),
355        0x22 => trap_puts(registers, memory),
356        0x23 => trap_in(registers),
357        0x24 => trap_putsp(registers, memory),
358        0x25 => trap_halt(),
359        _ => Err("Unknown TRAP vector"),
360    }
361}
362
363fn trap_getc(registers: &mut RegisterFile) -> Result<(), &'static str> {
364    let mut buffer = [0; 1];
365    if io::stdin().read_exact(&mut buffer).is_ok() {
366        registers.write(Register::R0, buffer[0] as u16);
367        Ok(())
368    } else {
369        Err("Failed to read character")
370    }
371}
372
373fn trap_out(registers: &mut RegisterFile) -> Result<(), &'static str> {
374    let char = (registers.read(Register::R0) & 0xFF) as u8 as char;
375    print!("{}", char);
376    io::stdout().flush().map_err(|_| "Failed to flush stdout")?;
377    Ok(())
378}
379
380fn trap_puts(registers: &mut RegisterFile, memory: &Memory) -> Result<(), &'static str> {
381    let mut address = registers.read(Register::R0);
382    loop {
383        let char = (memory.read(address) & 0xFF) as u8 as char;
384        if char == '\0' {
385            break;
386        }
387        print!("{}", char);
388        address += 1;
389    }
390    io::stdout().flush().map_err(|_| "Failed to flush stdout")?;
391    Ok(())
392}
393
394fn trap_in(registers: &mut RegisterFile) -> Result<(), &'static str> {
395    print!("Enter a character: ");
396    io::stdout().flush().map_err(|_| "Failed to flush stdout")?;
397    let mut buffer = [0; 1];
398    if io::stdin().read_exact(&mut buffer).is_ok() {
399        let char = buffer[0] as char;
400        println!("{}", char);
401        registers.write(Register::R0, buffer[0] as u16);
402        Ok(())
403    } else {
404        Err("Failed to read character")
405    }
406}
407
408fn trap_putsp(registers: &mut RegisterFile, memory: &Memory) -> Result<(), &'static str> {
409    let mut address = registers.read(Register::R0);
410    loop {
411        let word = memory.read(address);
412        let char1 = (word & 0xFF) as u8 as char;
413        if char1 == '\0' {
414            break;
415        }
416        print!("{}", char1);
417
418        let char2 = ((word >> 8) & 0xFF) as u8 as char;
419        if char2 != '\0' {
420            print!("{}", char2);
421        } else {
422            break;
423        }
424        address += 1;
425    }
426    io::stdout().flush().map_err(|_| "Failed to flush stdout")?;
427    Ok(())
428}
429
/// HALT (TRAP x25): stops execution.
///
/// Halting is signalled through the error channel. This always returns
/// `Err("HALT")`, which the caller of `execute` is expected to recognise as a
/// normal stop rather than a fault.
fn trap_halt() -> Result<(), &'static str> {
    let halt_signal: &'static str = "HALT";
    Err(halt_signal)
}
433
/// Sign-extends the low `bit_count` bits of `x` to a full 16-bit value.
///
/// Bit `bit_count - 1` is treated as the sign bit, and any bits of `x` above
/// `bit_count` are ignored. Shifting the field to the top of an `i16` and
/// arithmetic-shifting it back down does both in one step. It also makes
/// `bit_count == 16` a valid identity, where the previous
/// `0xFFFF << bit_count` overflowed the shift and panicked in debug builds.
///
/// `bit_count` must be in `1..=16` (checked in debug builds).
fn sign_extend(x: u16, bit_count: u16) -> u16 {
    debug_assert!(
        (1..=16).contains(&bit_count),
        "sign_extend: bit_count must be in 1..=16"
    );
    let shift = 16 - bit_count;
    (((x << shift) as i16) >> shift) as u16
}
441// Helper function: Convert u16 to Register
442impl From<u16> for Register {
443    fn from(value: u16) -> Self {
444        match value {
445            0 => Register::R0,
446            1 => Register::R1,
447            2 => Register::R2,
448            3 => Register::R3,
449            4 => Register::R4,
450            5 => Register::R5,
451            6 => Register::R6,
452            7 => Register::R7,
453            _ => panic!("Invalid register number"),
454        }
455    }
456}