embedded_savegame/
storage.rs

1use crate::{
2    Slot,
3    chksum::{self, Chksum},
4};
5
/// Byte-addressed storage driver used by [`Storage`].
///
/// `Storage` computes addresses as `slot_index * SLOT_SIZE` plus an offset
/// inside the slot, so every `addr` below is a byte offset and slot 0 starts
/// at address 0.
pub trait Flash {
    /// Error reported by the underlying driver.
    type Error;

    /// Read `buf.len()` bytes starting at `addr` into `buf`.
    fn read(&mut self, addr: u32, buf: &mut [u8]) -> Result<(), Self::Error>;

    /// Program `data` starting at `addr`.
    ///
    /// `data` is `&mut` only because the w25q crate requires a mutable buffer
    /// for its write operations.
    fn write(&mut self, addr: u32, data: &mut [u8]) -> Result<(), Self::Error>;

    /// Erase the device memory at `addr`.
    ///
    /// How much surrounding memory is erased (a byte, a page, a sector) is up
    /// to the implementation; `Storage` calls this with the start address of
    /// each slot it is about to reuse.
    fn erase(&mut self, addr: u32) -> Result<(), Self::Error>;

    /// Erase `count` units of the device.
    ///
    /// Chips with a faster bulk/chip erase should override this. The default
    /// calls [`Flash::erase`] once for every value in `0..count`, in order,
    /// stopping at the first error.
    ///
    /// NOTE(review): those values reach `erase` unchanged, i.e. as addresses
    /// `0, 1, …, count - 1`, while `Storage::erase` passes slot addresses
    /// (`idx * SLOT_SIZE`). Confirm that drivers relying on this default
    /// really clear every slot.
    fn erase_all(&mut self, count: usize) -> Result<(), Self::Error> {
        (0..count).try_for_each(|unit| self.erase(unit as u32))
    }
}
25
26#[derive(Debug)]
27pub struct Storage<F: Flash, const SLOT_SIZE: usize, const SLOT_COUNT: usize> {
28    flash: F,
29    prev: Chksum,
30    idx: usize,
31}
32
33impl<F: Flash, const SLOT_SIZE: usize, const SLOT_COUNT: usize> Storage<F, SLOT_SIZE, SLOT_COUNT> {
34    pub const SPACE: u32 = SLOT_SIZE as u32 * SLOT_COUNT as u32;
35
36    pub const fn new(flash: F) -> Self {
37        Self {
38            flash,
39            prev: Chksum::zero(),
40            idx: 0,
41        }
42    }
43
44    const fn addr(&self, idx: usize) -> u32 {
45        ((idx % SLOT_COUNT) * SLOT_SIZE) as u32
46    }
47
48    fn scan_slot(&mut self, idx: usize) -> Result<Option<Slot>, F::Error> {
49        let mut buf = [0u8; Slot::HEADER_SIZE];
50        let (head, tail) = buf.split_at_mut(1);
51
52        // Read first byte for sanity check to allow early skip
53        let addr = self.addr(idx);
54        self.flash.read(addr, head)?;
55
56        if head[0] & chksum::BYTE_MASK != 0 {
57            return Ok(None);
58        }
59
60        // Read the rest of the header
61        let addr = addr.saturating_add(1);
62        self.flash.read(addr, tail)?;
63
64        // Parse and validate slot
65        let slot = Slot::from_bytes(idx, buf);
66        let slot = slot.is_valid().then_some(slot);
67        Ok(slot)
68    }
69
70    pub fn scan(&mut self) -> Result<Option<Slot>, F::Error> {
71        let mut current: Option<Slot> = None;
72
73        for idx in 0..SLOT_COUNT {
74            let Some(slot) = self.scan_slot(idx)? else {
75                continue;
76            };
77
78            if let Some(existing) = &current {
79                if slot.is_update_to(existing) {
80                    current = Some(slot);
81                }
82            } else {
83                current = Some(slot);
84            }
85        }
86
87        if let Some(current) = &current {
88            self.idx = current.next_slot::<SLOT_SIZE, SLOT_COUNT>();
89            self.prev = current.chksum;
90        }
91
92        Ok(current)
93    }
94
95    pub fn erase(&mut self, idx: usize) -> Result<(), F::Error> {
96        self.flash.erase(self.addr(idx))?;
97        Ok(())
98    }
99
100    pub fn erase_all(&mut self) -> Result<(), F::Error> {
101        self.flash.erase_all(SLOT_COUNT)
102    }
103
104    pub fn read<'a>(
105        &mut self,
106        mut idx: usize,
107        buf: &'a mut [u8],
108    ) -> Result<Option<&'a mut [u8]>, F::Error> {
109        let mut addr = self.addr(idx);
110        let mut slot = [0u8; Slot::HEADER_SIZE];
111        self.flash.read(addr, &mut slot)?;
112        addr = addr.saturating_add(Slot::HEADER_SIZE as u32);
113        let slot = Slot::from_bytes(idx, slot);
114
115        let Some(data) = buf.get_mut(..slot.len as usize) else {
116            return Ok(None);
117        };
118        let mut buf = &mut *data;
119        let mut remaining_space = SLOT_SIZE - Slot::HEADER_SIZE;
120        while !buf.is_empty() {
121            let read_size = remaining_space.min(buf.len());
122            let (to_read, remaining) = buf.split_at_mut(read_size);
123            self.flash.read(addr, to_read)?;
124            buf = remaining;
125
126            idx = idx.saturating_add(1) % SLOT_COUNT;
127            addr = self.addr(idx).saturating_add(1);
128            remaining_space = SLOT_SIZE - 1;
129        }
130
131        // TODO: validate checksum
132
133        Ok(Some(data))
134    }
135
136    pub fn write(
137        &mut self,
138        mut idx: usize,
139        prev: Chksum,
140        mut data: &mut [u8],
141    ) -> Result<(usize, Chksum), F::Error> {
142        let slot = Slot::create(idx, prev, data);
143        let slot_addr = self.addr(idx);
144        self.flash.erase(slot_addr)?;
145
146        let mut addr = slot_addr.saturating_add(Slot::HEADER_SIZE as u32);
147        let mut remaining_space = SLOT_SIZE - Slot::HEADER_SIZE;
148
149        loop {
150            let write_size = remaining_space.min(data.len());
151            let (to_write, remaining) = data.split_at_mut(write_size);
152            self.flash.write(addr, to_write)?;
153            data = remaining;
154            idx = idx.saturating_add(1) % SLOT_COUNT;
155
156            // erase first byte of next slot, but only if more data remains
157            if data.is_empty() {
158                break;
159            }
160
161            addr = self.addr(idx);
162            self.flash.erase(addr)?;
163
164            addr = addr.saturating_add(1);
165            remaining_space = SLOT_SIZE - 1;
166        }
167
168        // Write header last, to finalize the slot
169        // The last field is `prev`, marking the previous slot as outdated
170        let mut bytes = slot.to_bytes();
171        self.flash.write(slot_addr, &mut bytes)?;
172
173        Ok((idx, slot.chksum))
174    }
175
176    pub fn append(&mut self, data: &mut [u8]) -> Result<(), F::Error> {
177        let (idx, chksum) = self.write(self.idx, self.prev, data)?;
178        self.idx = idx;
179        self.prev = chksum;
180        Ok(())
181    }
182}
183
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mock::{MockFlash, SectorMockFlash};
    use core::convert::Infallible;

    const SLOT_SIZE: usize = 64;
    const SLOT_COUNT: usize = 8;
    const SIZE: usize = SLOT_SIZE * SLOT_COUNT;

    /// Storage over `MockFlash`, exercised by the `at24cxx` test variants.
    const fn mock_storage() -> Storage<MockFlash<SIZE>, SLOT_SIZE, SLOT_COUNT> {
        let flash = MockFlash::<SIZE>::new();
        Storage::<_, SLOT_SIZE, SLOT_COUNT>::new(flash)
    }

    /// Storage over `SectorMockFlash`, exercised by the `w25qxx` test variants.
    const fn mock_sector_storage()
    -> Storage<SectorMockFlash<SLOT_SIZE, SLOT_COUNT>, SLOT_SIZE, SLOT_COUNT> {
        let flash = SectorMockFlash::<SLOT_SIZE, SLOT_COUNT>::new();
        Storage::<_, SLOT_SIZE, SLOT_COUNT>::new(flash)
    }

    /// Scanning untouched storage finds no entry.
    fn test_storage_empty_scan<F: Flash<Error = Infallible>>(
        mut storage: Storage<F, SLOT_SIZE, SLOT_COUNT>,
    ) {
        let Ok(slot) = storage.scan();
        assert_eq!(slot, None);
    }

    #[test]
    fn test_at24cxx_storage_empty_scan() {
        let storage = mock_storage();
        test_storage_empty_scan(storage);
    }

    #[test]
    fn test_w25qxx_storage_empty_scan() {
        let storage = mock_sector_storage();
        test_storage_empty_scan(storage);
    }

    /// The header written for the first entry lands at address 0.
    #[test]
    fn test_storage_write() {
        let mut storage = mock_storage();

        let mut data = *b"hello world";
        let Ok(()) = storage.append(&mut data);

        let mut buf = [0u8; Slot::HEADER_SIZE];
        let Ok(()) = storage.flash.read(0, &mut buf);
        let slot = Slot::from_bytes(0, buf);
        assert_eq!(
            slot,
            Slot {
                idx: 0,
                chksum: Chksum::hash(Chksum::zero(), &data),
                len: data.len() as u32,
                prev: Chksum::zero(),
            }
        );
    }

    /// A single appended entry is found again by `scan`.
    fn test_storage_write_scan<F: Flash<Error = Infallible>>(
        mut storage: Storage<F, SLOT_SIZE, SLOT_COUNT>,
    ) {
        let mut data = *b"hello world";
        let Ok(()) = storage.append(&mut data);

        let Ok(scan) = storage.scan();
        assert_eq!(
            scan,
            Some(Slot {
                idx: 0,
                chksum: Chksum::hash(Chksum::zero(), &data),
                len: data.len() as u32,
                prev: Chksum::zero(),
            })
        );
    }

    #[test]
    fn test_at24cxx_storage_write_scan() {
        let storage = mock_storage();
        test_storage_write_scan(storage);
    }

    #[test]
    fn test_w25qxx_storage_write_scan() {
        let storage = mock_sector_storage();
        test_storage_write_scan(storage);
    }

    /// A single appended entry can be read back byte for byte.
    fn test_storage_write_read<F: Flash<Error = Infallible>>(
        mut storage: Storage<F, SLOT_SIZE, SLOT_COUNT>,
    ) {
        let mut data = *b"hello world";
        let Ok(()) = storage.append(&mut data);

        let mut buf = [0u8; 1024];
        let Ok(slice) = storage.read(0, &mut buf);

        assert_eq!(slice.map(|s| &*s), Some("hello world".as_bytes()));
    }

    #[test]
    fn test_at24cxx_storage_write_read() {
        let storage = mock_storage();
        test_storage_write_read(storage);
    }

    #[test]
    fn test_w25qxx_storage_write_read() {
        let storage = mock_sector_storage();
        test_storage_write_read(storage);
    }

    /// After several trips around the ring, `scan` still finds the newest entry.
    fn test_storage_write_wrap_around<F: Flash<Error = Infallible>>(
        mut storage: Storage<F, SLOT_SIZE, SLOT_COUNT>,
    ) {
        for num in 0..(SLOT_COUNT as u32 * 3 + 2) {
            let mut buf = [0u8; 6];
            buf[..4].copy_from_slice(&num.to_be_bytes());
            let Ok(()) = storage.append(&mut buf);
        }

        let slot = storage.scan().unwrap().unwrap();
        assert_eq!(slot.idx, 1);
        assert_eq!(storage.idx, 2);

        let mut buf = [0u8; 32];
        let Ok(slice) = storage.read(slot.idx, &mut buf);
        assert_eq!(slice, Some(&mut [0, 0, 0, 25, 0, 0][..]));
    }

    #[test]
    fn test_at24cxx_storage_write_wrap_around() {
        let storage = mock_storage();
        test_storage_write_wrap_around(storage);
    }

    #[test]
    fn test_w25qxx_storage_write_wrap_around() {
        let storage = mock_sector_storage();
        test_storage_write_wrap_around(storage);
    }

    /// Entries spanning several slots are scanned and read back correctly.
    fn test_storage_big_write<F: Flash<Error = Infallible>>(
        mut storage: Storage<F, SLOT_SIZE, SLOT_COUNT>,
    ) {
        let mut buf = [b'A'; SLOT_SIZE * 5];
        let Ok(()) = storage.append(&mut buf);
        let slot = storage.scan().unwrap().unwrap();
        assert_eq!(
            slot,
            Slot {
                idx: 0,
                chksum: Chksum::hash(Chksum::zero(), &buf),
                len: buf.len() as u32,
                prev: Chksum::zero(),
            }
        );

        let mut buf2 = [0u8; 512];
        let Ok(slice) = storage.read(slot.idx, &mut buf2);
        assert_eq!(slice.map(|s| &*s), Some(&buf[..]));

        let mut buf = [b'B'; SLOT_SIZE * 5];
        let Ok(()) = storage.append(&mut buf);
        let new_slot = storage.scan().unwrap().unwrap();
        assert_eq!(
            new_slot,
            Slot {
                idx: 6,
                chksum: Chksum::hash(slot.chksum, &buf),
                len: buf.len() as u32,
                prev: slot.chksum,
            }
        );
        // TODO: this test is also broken because it's parsing the content of a slot as header
    }

    #[test]
    fn test_at24cxx_storage_big_write() {
        let storage = mock_storage();
        test_storage_big_write(storage);
    }

    #[test]
    fn test_w25qxx_storage_big_write() {
        let storage = mock_sector_storage();
        test_storage_big_write(storage);
    }

    /// `scan` restores the cursor after it has been clobbered in memory.
    fn test_append_after_scan<F: Flash<Error = Infallible>>(
        mut storage: Storage<F, SLOT_SIZE, SLOT_COUNT>,
    ) {
        let mut big = [b'A'; SLOT_SIZE * 2];
        let Ok(()) = storage.append(&mut big);
        assert_eq!(storage.idx, 3);
        storage.idx = 0;

        storage.scan().unwrap();
        assert_eq!(storage.idx, 3);
        assert_eq!(storage.prev, Chksum::hash(Chksum::zero(), &big));
    }

    #[test]
    fn test_at24cxx_append_after_scan() {
        let storage = mock_storage();
        test_append_after_scan(storage);
    }

    #[test]
    fn test_w25qxx_append_after_scan() {
        let storage = mock_sector_storage();
        test_append_after_scan(storage);
    }
}