nes_core/
mem.rs

1use alloc::vec;
2use alloc::vec::Vec;
3use core::str::FromStr;
4use rand::{Rng, SeedableRng};
5use serde::{Deserialize, Serialize};
6
/// The kind of memory access being performed; every [`Mem`] method takes one
/// and passes it through to the underlying implementation.
///
/// NOTE(review): variant order is significant — it fixes the enum
/// discriminants and the variant indices that index-based serde formats encode.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[must_use]
pub enum Access {
    /// Read access.
    Read,
    /// Write access.
    Write,
    /// Presumably an instruction fetch — confirm against callers.
    Execute,
    /// Presumably a bus access whose result is not used — confirm against callers.
    Dummy,
}
15
16pub trait Mem {
17    #[inline]
18    fn read(&mut self, addr: u16, access: Access) -> u8 {
19        self.peek(addr, access)
20    }
21
22    fn peek(&self, addr: u16, access: Access) -> u8;
23
24    #[inline]
25    fn read_u16(&mut self, addr: u16, access: Access) -> u16 {
26        let lo = self.read(addr, access);
27        let hi = self.read(addr.wrapping_add(1), access);
28        u16::from_le_bytes([lo, hi])
29    }
30
31    #[inline]
32    fn peek_u16(&self, addr: u16, access: Access) -> u16 {
33        let lo = self.peek(addr, access);
34        let hi = self.peek(addr.wrapping_add(1), access);
35        u16::from_le_bytes([lo, hi])
36    }
37
38    fn write(&mut self, addr: u16, val: u8, access: Access);
39
40    #[inline]
41    fn write_u16(&mut self, addr: u16, val: u16, access: Access) {
42        let [lo, hi] = val.to_le_bytes();
43        self.write(addr, lo, access);
44        self.write(addr, hi, access);
45    }
46}
47
/// Initial pattern used to populate RAM (see `RamState::fill`).
///
/// `AllZeros` fills with `0x00`, `AllOnes` with `0xFF`, and `Random` with bytes
/// from a fixed-seed RNG, so the "random" pattern is identical on every fill.
/// The default is `AllZeros`.
///
/// NOTE(review): variant order is significant — it fixes the discriminants and
/// the variant indices that index-based serde formats encode.
#[derive(Default, Debug, Copy, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[must_use]
pub enum RamState {
    /// Every byte is `0x00`.
    #[default]
    AllZeros,
    /// Every byte is `0xFF`.
    AllOnes,
    /// Bytes come from a deterministically seeded RNG.
    Random,
}
56
57impl RamState {
58    pub const fn as_slice() -> &'static [Self] {
59        &[Self::AllZeros, Self::AllOnes, Self::Random]
60    }
61
62    #[must_use]
63    pub fn with_capacity(capacity: usize, state: Self) -> Vec<u8> {
64        let mut ram = vec![0x00; capacity];
65        Self::fill(&mut ram, state);
66        ram
67    }
68
69    pub fn fill(ram: &mut [u8], state: RamState) {
70        match state {
71            RamState::AllZeros => ram.fill(0x00),
72            RamState::AllOnes => ram.fill(0xFF),
73            RamState::Random => {
74                let mut rng = rand::rngs::SmallRng::seed_from_u64(256);
75                for val in ram {
76                    *val = rng.gen_range(0x00..=0xFF);
77                }
78            }
79        }
80    }
81}
82
83impl From<usize> for RamState {
84    fn from(value: usize) -> Self {
85        match value {
86            0 => Self::AllZeros,
87            1 => Self::AllOnes,
88            _ => Self::Random,
89        }
90    }
91}
92
93impl AsRef<str> for RamState {
94    fn as_ref(&self) -> &str {
95        match self {
96            Self::AllZeros => "All $00",
97            Self::AllOnes => "All $FF",
98            Self::Random => "Random",
99        }
100    }
101}
102
103impl FromStr for RamState {
104    type Err = &'static str;
105    fn from_str(s: &str) -> Result<Self, Self::Err> {
106        match s {
107            "all_zeros" => Ok(Self::AllZeros),
108            "all_ones" => Ok(Self::AllOnes),
109            "random" => Ok(Self::Random),
110            _ => Err("invalid RamState value. valid options: `all_zeros`, `all_ones`, or `random`"),
111        }
112    }
113}
114
/// Maps fixed-size slots of an address range onto pages of backing memory.
///
/// NOTE(review): field order is significant for serialization — positional
/// serde formats encode fields in declaration order, so do not reorder.
#[derive(Default, Clone, Serialize, Deserialize)]
#[must_use]
pub struct MemBanks {
    /// First address of the banked range (inclusive).
    start: usize,
    /// Last address of the banked range (inclusive).
    end: usize,
    /// `end - start`, i.e. one less than the range length; used as an address
    /// mask when selecting a slot.
    size: usize,
    /// Bytes per slot / page.
    window: usize,
    /// `window.trailing_zeros()`: converts between page numbers and byte offsets.
    shift: usize,
    /// `page_count - 1`: applied to page numbers before they are stored.
    mask: usize,
    /// Per-slot base offset into backing memory.
    banks: Vec<usize>,
    /// Number of `window`-sized pages in backing memory (at least 1 when built by `new`).
    page_count: usize,
}
127
128impl MemBanks {
129    pub fn new(start: usize, end: usize, capacity: usize, window: usize) -> Self {
130        let size = end - start;
131        let mut banks = vec![0; (size + 1) / window];
132        for (i, bank) in banks.iter_mut().enumerate() {
133            *bank = i * window;
134        }
135        let page_count = core::cmp::max(1, capacity / window);
136        Self {
137            start,
138            end,
139            size,
140            window,
141            shift: window.trailing_zeros() as usize,
142            mask: page_count - 1,
143            banks,
144            page_count,
145        }
146    }
147
148    #[inline]
149    pub fn set(&mut self, slot: usize, bank: usize) {
150        self.banks[slot] = (bank & self.mask) << self.shift;
151        debug_assert!(self.banks[slot] < self.page_count * self.window);
152    }
153
154    #[inline]
155    pub fn set_range(&mut self, start: usize, end: usize, bank: usize) {
156        let mut new_addr = (bank & self.mask) << self.shift;
157        for slot in start..=end {
158            self.banks[slot] = new_addr;
159            debug_assert!(self.banks[slot] < self.page_count * self.window);
160            new_addr += self.window;
161        }
162    }
163
164    #[inline]
165    #[must_use]
166    pub const fn last(&self) -> usize {
167        self.page_count.saturating_sub(1)
168    }
169
170    #[inline]
171    #[must_use]
172    pub const fn get_bank(&self, addr: u16) -> usize {
173        // 0x6005    - 0b0110000000000101 -> bank 0
174        //  (0x2000)   0b0010000000000000
175        //
176        // 0x8005    - 0b1000000000000101 -> bank 0
177        //   (0x4000)  0b0100000000000000
178        // 0xC005    - 0b1100000000000101 -> bank 1
179        //
180        // 0x8005    - 0b1000000000000101 -> bank 0
181        // 0xA005    - 0b1010000000000101 -> bank 1
182        // 0xC005    - 0b1100000000000101 -> bank 2
183        // 0xE005    - 0b1110000000000101 -> bank 3
184        //   (0x2000)  0b0010000000000000
185        ((addr as usize) & self.size) >> self.shift
186    }
187
188    #[inline]
189    #[must_use]
190    pub fn translate(&self, addr: u16) -> usize {
191        // 0x6005    - 0b0110000000000101 -> bank 0
192        //  (0x2000)   0b0010000000000000
193        //
194        // 0x8005    - 0b1000000000000101 -> bank 0
195        //   (0x4000)  0b0100000000000000
196        // 0xC005    - 0b1100000000000101 -> bank 1
197        //
198        // 0x8005    - 0b1000000000000101 -> bank 0
199        //  0 -> 0x0000
200        //  1
201        //  2
202        // 0xA005    - 0b1010000000000101 -> bank 1
203        // 0xC005    - 0b1100000000000101 -> bank 2
204        // 0xE005    - 0b1110000000000101 -> bank 3
205        //   (0x2000)  0b0010000000000000
206        let page = self.banks[self.get_bank(addr)];
207        page | (addr as usize) & (self.window - 1)
208    }
209}
210
211impl core::fmt::Debug for MemBanks {
212    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> Result<(), core::fmt::Error> {
213        f.debug_struct("Bank")
214            .field("start", &format_args!("0x{:04X}", self.start))
215            .field("end", &format_args!("0x{:04X}", self.end))
216            .field("size", &format_args!("0x{:04X}", self.size))
217            .field("window", &format_args!("0x{:04X}", self.window))
218            .field("shift", &self.shift)
219            .field("mask", &self.shift)
220            .field("banks", &self.banks)
221            .field("page_count", &self.page_count)
222            .finish()
223    }
224}
225
#[cfg(test)]
mod tests {
    use super::*;

    /// With 16 KiB windows over `$8000-$FFFF`, the lower half of the range is
    /// slot 0 and the upper half is slot 1.
    #[test]
    fn get_bank() {
        let banks = MemBanks::new(0x8000, 0xFFFF, 128 * 1024, 0x4000);
        let cases = [
            (0x8000, 0),
            (0x9FFF, 0),
            (0xA000, 0),
            (0xBFFF, 0),
            (0xC000, 1),
            (0xDFFF, 1),
            (0xE000, 1),
            (0xFFFF, 1),
        ];
        for (addr, slot) in cases {
            assert_eq!(banks.get_bank(addr), slot, "addr 0x{addr:04X}");
        }
    }

    /// Switching slot 0 between pages moves `$8000` to each page's base offset.
    #[test]
    fn bank_translate() {
        let mut banks = MemBanks::new(0x8000, 0xFFFF, 128 * 1024, 0x2000);
        assert_eq!(banks.last(), 15, "bank count");
        assert_eq!(banks.translate(0x8000), 0x0000);

        for (page, base) in [(1, 0x2000), (2, 0x4000), (0, 0x0000)] {
            banks.set(0, page);
            assert_eq!(banks.translate(0x8000), base);
        }

        let last = banks.last();
        banks.set(0, last);
        assert_eq!(banks.translate(0x8000), 0x1E000);
    }
}
262}