embedded_savegame/
storage.rs

1use crate::{
2    Slot,
3    chksum::{self, Chksum},
4};
5
/// Minimal flash / EEPROM interface that [`Storage`] is built on.
///
/// Addresses are absolute byte offsets into the device. `Storage` only ever issues
/// addresses below its `SPACE` constant.
pub trait Flash {
    /// Error reported by the underlying device driver.
    type Error;

    /// Erases the region that begins at `addr`.
    ///
    /// `Storage` calls this with the start address of a slot before writing into that slot.
    /// How many bytes one call clears (a byte, a page, a sector, ...) is up to the
    /// implementation. NOTE(review): `Storage` presumably relies on at least the byte at
    /// `addr` being reset to the erased state. Confirm this for each driver.
    fn erase(&mut self, addr: u32) -> Result<(), Self::Error>;

    /// Fills all of `buf` with the bytes stored starting at `addr`.
    fn read(&mut self, addr: u32, buf: &mut [u8]) -> Result<(), Self::Error>;

    /// Programs `data` into the device starting at `addr`.
    ///
    /// `data` is `&mut` only because the w25q driver crate requires a mutable buffer for
    /// write operations. Callers should not expect its contents to be changed.
    fn write(&mut self, addr: u32, data: &mut [u8]) -> Result<(), Self::Error>;
}
16
17#[derive(Debug)]
18pub struct Storage<F: Flash, const SLOT_SIZE: usize, const SLOT_COUNT: usize> {
19    flash: F,
20    prev: Chksum,
21    idx: usize,
22}
23
24impl<F: Flash, const SLOT_SIZE: usize, const SLOT_COUNT: usize> Storage<F, SLOT_SIZE, SLOT_COUNT> {
25    pub const SPACE: u32 = SLOT_SIZE as u32 * SLOT_COUNT as u32;
26
27    pub const fn new(flash: F) -> Self {
28        Self {
29            flash,
30            prev: Chksum::zero(),
31            idx: 0,
32        }
33    }
34
35    const fn addr(&self, idx: usize) -> u32 {
36        ((idx % SLOT_COUNT) * SLOT_SIZE) as u32
37    }
38
39    fn scan_slot(&mut self, idx: usize) -> Result<Option<Slot>, F::Error> {
40        let mut buf = [0u8; Slot::HEADER_SIZE];
41        let (head, tail) = buf.split_at_mut(1);
42
43        // Read first byte for sanity check to allow early skip
44        let addr = self.addr(idx);
45        self.flash.read(addr, head)?;
46
47        if head[0] & chksum::BYTE_MASK != 0 {
48            return Ok(None);
49        }
50
51        // Read the rest of the header
52        let addr = addr.saturating_add(1);
53        self.flash.read(addr, tail)?;
54
55        // Parse and validate slot
56        let slot = Slot::from_bytes(idx, buf);
57        let slot = slot.is_valid().then_some(slot);
58        Ok(slot)
59    }
60
61    pub fn scan(&mut self) -> Result<Option<Slot>, F::Error> {
62        let mut current: Option<Slot> = None;
63
64        for idx in 0..SLOT_COUNT {
65            let Some(slot) = self.scan_slot(idx)? else {
66                continue;
67            };
68
69            if let Some(existing) = &current {
70                if slot.is_update_to(existing) {
71                    current = Some(slot);
72                }
73            } else {
74                current = Some(slot);
75            }
76        }
77
78        if let Some(current) = &current {
79            self.idx = current.next_slot::<SLOT_SIZE, SLOT_COUNT>();
80            self.prev = current.chksum;
81        }
82
83        Ok(current)
84    }
85
86    pub fn erase(&mut self, idx: usize) -> Result<(), F::Error> {
87        self.flash.erase(self.addr(idx))?;
88        Ok(())
89    }
90
91    pub fn erase_all(&mut self) -> Result<(), F::Error> {
92        // TODO: some flash chips have a better way to do bulk erase
93        for idx in 0..SLOT_COUNT {
94            self.erase(idx)?;
95        }
96        Ok(())
97    }
98
99    pub fn read<'a>(
100        &mut self,
101        mut idx: usize,
102        buf: &'a mut [u8],
103    ) -> Result<Option<&'a mut [u8]>, F::Error> {
104        let mut addr = self.addr(idx);
105        let mut slot = [0u8; Slot::HEADER_SIZE];
106        self.flash.read(addr, &mut slot)?;
107        addr = addr.saturating_add(Slot::HEADER_SIZE as u32);
108        let slot = Slot::from_bytes(idx, slot);
109
110        let Some(data) = buf.get_mut(..slot.len as usize) else {
111            return Ok(None);
112        };
113        let mut buf = &mut *data;
114        let mut remaining_space = SLOT_SIZE - Slot::HEADER_SIZE;
115        while !buf.is_empty() {
116            let read_size = remaining_space.min(buf.len());
117            let (to_read, remaining) = buf.split_at_mut(read_size);
118            self.flash.read(addr, to_read)?;
119            buf = remaining;
120
121            idx = idx.saturating_add(1) % SLOT_COUNT;
122            addr = self.addr(idx).saturating_add(1);
123            remaining_space = SLOT_SIZE - 1;
124        }
125
126        // TODO: validate checksum
127
128        Ok(Some(data))
129    }
130
131    pub fn write(
132        &mut self,
133        mut idx: usize,
134        prev: Chksum,
135        mut data: &mut [u8],
136    ) -> Result<(usize, Chksum), F::Error> {
137        let slot = Slot::create(idx, prev, data);
138        let chksum = slot.chksum;
139        let addr = self.addr(idx);
140        let mut bytes = slot.to_bytes();
141        self.flash.erase(addr)?;
142        self.flash.write(addr, &mut bytes)?;
143
144        let mut addr = addr.saturating_add(Slot::HEADER_SIZE as u32);
145        let mut remaining_space = SLOT_SIZE - Slot::HEADER_SIZE;
146
147        loop {
148            let write_size = remaining_space.min(data.len());
149            let (to_write, remaining) = data.split_at_mut(write_size);
150            self.flash.write(addr, to_write)?;
151            data = remaining;
152            idx = idx.saturating_add(1) % SLOT_COUNT;
153
154            // erase first byte of next slot, but only if more data remains
155            if data.is_empty() {
156                break;
157            }
158
159            addr = self.addr(idx);
160            self.flash.erase(addr)?;
161
162            addr = addr.saturating_add(1);
163            remaining_space = SLOT_SIZE - 1;
164        }
165
166        Ok((idx, chksum))
167    }
168
169    pub fn append(&mut self, data: &mut [u8]) -> Result<(), F::Error> {
170        let (idx, chksum) = self.write(self.idx, self.prev, data)?;
171        self.idx = idx;
172        self.prev = chksum;
173        Ok(())
174    }
175}
176
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mock::{MockFlash, SectorMockFlash};
    use core::convert::Infallible;

    const SLOT_SIZE: usize = 64;
    const SLOT_COUNT: usize = 8;
    const SIZE: usize = SLOT_SIZE * SLOT_COUNT;

    /// Storage over `MockFlash`, used by the `at24cxx` test variants.
    const fn mock_storage() -> Storage<MockFlash<SIZE>, SLOT_SIZE, SLOT_COUNT> {
        let flash = MockFlash::<SIZE>::new();
        Storage::<_, SLOT_SIZE, SLOT_COUNT>::new(flash)
    }

    /// Storage over `SectorMockFlash`, used by the `w25qxx` test variants.
    const fn mock_sector_storage()
    -> Storage<SectorMockFlash<SLOT_SIZE, SLOT_COUNT>, SLOT_SIZE, SLOT_COUNT> {
        let flash = SectorMockFlash::<SLOT_SIZE, SLOT_COUNT>::new();
        Storage::<_, SLOT_SIZE, SLOT_COUNT>::new(flash)
    }

    fn test_storage_empty_scan<F: Flash<Error = Infallible>>(
        mut storage: Storage<F, SLOT_SIZE, SLOT_COUNT>,
    ) {
        let Ok(slot) = storage.scan();
        assert_eq!(slot, None);
    }

    #[test]
    fn test_at24cxx_storage_empty_scan() {
        test_storage_empty_scan(mock_storage());
    }

    #[test]
    fn test_w25qxx_storage_empty_scan() {
        test_storage_empty_scan(mock_sector_storage());
    }

    /// After one append, the raw header at address 0 must describe that save.
    fn test_storage_write<F: Flash<Error = Infallible>>(
        mut storage: Storage<F, SLOT_SIZE, SLOT_COUNT>,
    ) {
        let mut data = *b"hello world";
        let Ok(()) = storage.append(&mut data);

        let mut buf = [0u8; Slot::HEADER_SIZE];
        let Ok(()) = storage.flash.read(0, &mut buf);
        let slot = Slot::from_bytes(0, buf);
        assert_eq!(
            slot,
            Slot {
                idx: 0,
                prev: Chksum::zero(),
                chksum: Chksum::hash(Chksum::zero(), &data),
                len: data.len() as u32,
            }
        );
    }

    #[test]
    fn test_at24cxx_storage_write() {
        test_storage_write(mock_storage());
    }

    #[test]
    fn test_w25qxx_storage_write() {
        test_storage_write(mock_sector_storage());
    }

    fn test_storage_write_scan<F: Flash<Error = Infallible>>(
        mut storage: Storage<F, SLOT_SIZE, SLOT_COUNT>,
    ) {
        let mut data = *b"hello world";
        let Ok(()) = storage.append(&mut data);

        let Ok(scan) = storage.scan();
        assert_eq!(
            scan,
            Some(Slot {
                idx: 0,
                prev: Chksum::zero(),
                chksum: Chksum::hash(Chksum::zero(), &data),
                len: data.len() as u32,
            })
        );
    }

    #[test]
    fn test_at24cxx_storage_write_scan() {
        test_storage_write_scan(mock_storage());
    }

    #[test]
    fn test_w25qxx_storage_write_scan() {
        test_storage_write_scan(mock_sector_storage());
    }

    fn test_storage_write_read<F: Flash<Error = Infallible>>(
        mut storage: Storage<F, SLOT_SIZE, SLOT_COUNT>,
    ) {
        let mut data = *b"hello world";
        let Ok(()) = storage.append(&mut data);

        let mut buf = [0u8; 1024];
        let Ok(slice) = storage.read(0, &mut buf);

        assert_eq!(slice.map(|s| &*s), Some("hello world".as_bytes()));
    }

    #[test]
    fn test_at24cxx_storage_write_read() {
        test_storage_write_read(mock_storage());
    }

    #[test]
    fn test_w25qxx_storage_write_read() {
        test_storage_write_read(mock_sector_storage());
    }

    /// Appends `3 * SLOT_COUNT + 2` single-slot saves, so the ring wraps three times. The
    /// newest save (number 25) must then be found in slot 1.
    fn test_storage_write_wrap_around<F: Flash<Error = Infallible>>(
        mut storage: Storage<F, SLOT_SIZE, SLOT_COUNT>,
    ) {
        for num in 0..(SLOT_COUNT as u32 * 3 + 2) {
            let mut buf = [0u8; 6];
            buf[..4].copy_from_slice(&num.to_be_bytes());
            let Ok(()) = storage.append(&mut buf);
        }

        let slot = storage.scan().unwrap().unwrap();
        assert_eq!(slot.idx, 1);
        assert_eq!(storage.idx, 2);

        let mut buf = [0u8; 32];
        let Ok(slice) = storage.read(slot.idx, &mut buf);
        assert_eq!(slice, Some(&mut [0, 0, 0, 25, 0, 0][..]));
    }

    #[test]
    fn test_at24cxx_storage_write_wrap_around() {
        test_storage_write_wrap_around(mock_storage());
    }

    #[test]
    fn test_w25qxx_storage_write_wrap_around() {
        test_storage_write_wrap_around(mock_sector_storage());
    }

    /// Saves larger than one slot must spill into the following slots and read back intact.
    fn test_storage_big_write<F: Flash<Error = Infallible>>(
        mut storage: Storage<F, SLOT_SIZE, SLOT_COUNT>,
    ) {
        let mut buf = [b'A'; SLOT_SIZE * 5];
        let Ok(()) = storage.append(&mut buf);
        let slot = storage.scan().unwrap().unwrap();
        assert_eq!(
            slot,
            Slot {
                idx: 0,
                prev: Chksum::zero(),
                chksum: Chksum::hash(Chksum::zero(), &buf),
                len: buf.len() as u32,
            }
        );

        let mut buf2 = [0u8; 512];
        let Ok(slice) = storage.read(slot.idx, &mut buf2);
        assert_eq!(slice.map(|s| &*s), Some(&buf[..]));

        let mut buf = [b'B'; SLOT_SIZE * 5];
        let Ok(()) = storage.append(&mut buf);
        let new_slot = storage.scan().unwrap().unwrap();
        assert_eq!(
            new_slot,
            Slot {
                idx: 6,
                prev: slot.chksum,
                chksum: Chksum::hash(slot.chksum, &buf),
                len: buf.len() as u32,
            }
        );
        // TODO: this test is also broken because it's parsing the content of a slot as header
    }

    #[test]
    fn test_at24cxx_storage_big_write() {
        test_storage_big_write(mock_storage());
    }

    #[test]
    fn test_w25qxx_storage_big_write() {
        test_storage_big_write(mock_sector_storage());
    }

    /// `scan` must restore the append position and checksum chain after they were lost.
    fn test_append_after_scan<F: Flash<Error = Infallible>>(
        mut storage: Storage<F, SLOT_SIZE, SLOT_COUNT>,
    ) {
        // Header plus 2 * SLOT_SIZE bytes of payload occupy three slots.
        let mut big = [b'A'; SLOT_SIZE * 2];
        let Ok(()) = storage.append(&mut big);
        assert_eq!(storage.idx, 3);
        storage.idx = 0;

        storage.scan().unwrap();
        assert_eq!(storage.idx, 3);
        assert_eq!(storage.prev, Chksum::hash(Chksum::zero(), &big));
    }

    #[test]
    fn test_at24cxx_append_after_scan() {
        test_append_after_scan(mock_storage());
    }

    #[test]
    fn test_w25qxx_append_after_scan() {
        test_append_after_scan(mock_sector_storage());
    }
}