Skip to main content

bonsai_eval_utils/
replay_disk.rs

1use alloc::vec::Vec;
2use core::fmt;
3use core::ops::Range;
4
5use bonsai_disk::Disk;
6use embedded_io::ErrorType;
7use generic_array::ArrayLength;
8use generic_array::GenericArray;
9use generic_array::typenum;
10use typenum::marker_traits::Unsigned;
11
12
13
14#[derive(Debug, Clone)]
15pub struct RecordDisk<D: Disk> {
16	pub before_storage: D,
17
18	pub records: Vec<Record<D::WRITE_GRANULARITY>>,
19}
20impl<D: Disk> RecordDisk<D> {
21	pub fn new(before_storage: D) -> Self {
22		Self {
23			before_storage,
24			records: Vec::new(),
25		}
26	}
27}
28impl<D: Disk> ErrorType for RecordDisk<D> {
29	type Error = <D as ErrorType>::Error;
30}
31impl<D: Disk> Disk for RecordDisk<D> {
32	type WRITE_GRANULARITY = D::WRITE_GRANULARITY;
33
34	const ERASE_BLOCK_SIZE: usize = D::ERASE_BLOCK_SIZE;
35
36	fn block_count(&self) -> usize {
37		self.before_storage.block_count()
38	}
39
40	fn read(&mut self, offset: usize, buf: &mut [u8]) -> Result<usize, Self::Error> {
41		if buf.is_empty() {
42			return Ok(0);
43		}
44
45		assert!(offset % Self::WRITE_GRANULARITY::USIZE == 0);
46		assert!(buf.len() >= Self::WRITE_GRANULARITY::USIZE);
47
48		// Iter in reverse chronological order, i.e. return the most recent state
49		for rec in self.records.iter().rev() {
50			match rec {
51				Record::Write(WriteRecord {
52					block,
53					ibo,
54					data,
55				}) => {
56					if block * Self::ERASE_BLOCK_SIZE + ibo == offset {
57						buf[..Self::WRITE_GRANULARITY::USIZE].copy_from_slice(data.as_slice());
58						return Ok(Self::WRITE_GRANULARITY::USIZE);
59					}
60				},
61				Record::Erase(EraseRecord {
62					block,
63				}) => {
64					if *block == offset / Self::ERASE_BLOCK_SIZE {
65						// Emulate NOR Flash
66						buf[..Self::WRITE_GRANULARITY::USIZE].fill(0xff);
67						return Ok(Self::WRITE_GRANULARITY::USIZE);
68					}
69				},
70			}
71		}
72
73		self.before_storage
74			.read(offset, &mut buf[..Self::WRITE_GRANULARITY::USIZE])
75	}
76
77	fn write(&mut self, offset: usize, buf: &[u8]) -> Result<usize, Self::Error> {
78		if buf.is_empty() {
79			return Ok(0);
80		}
81
82		assert!(offset % Self::WRITE_GRANULARITY::USIZE == 0);
83		assert!(buf.len() >= Self::WRITE_GRANULARITY::USIZE);
84
85		// Emulate NOR Flash, i.e. read the current state and AND the new data
86		// into it
87		let mut buffer = GenericArray::default();
88		let len = self.read(offset, &mut buffer)?;
89		assert_eq!(len, Self::WRITE_GRANULARITY::USIZE);
90		// AND together the latest state and the new data
91		for (b, input) in buffer.iter_mut().zip(buf) {
92			*b &= input;
93		}
94
95		let rec = WriteRecord {
96			block: offset / Self::ERASE_BLOCK_SIZE,
97			ibo: offset % Self::ERASE_BLOCK_SIZE,
98			data: buffer,
99		};
100
101		self.records.push(Record::Write(rec));
102
103		Ok(Self::WRITE_GRANULARITY::USIZE)
104	}
105
106	fn erase(&mut self, block: usize) -> Result<(), Self::Error> {
107		self.records.push(Record::Erase(EraseRecord {
108			block,
109		}));
110
111		Ok(())
112	}
113}
114
115
116#[derive(Debug)]
117pub struct ReplayDisk<'a, D: Disk> {
118	pub before_storage: &'a mut D,
119
120	records: &'a [Record<D::WRITE_GRANULARITY>],
121
122	new_actions: Vec<Record<D::WRITE_GRANULARITY>>,
123
124	/// The highest index which will cause a different result upon the first
125	/// action taken by the user.
126	///
127	/// Zero means the state before the very first records.
128	pub first_undiscovered: Option<usize>,
129
130	merged_records: usize,
131}
132impl<'a, D: Disk> ReplayDisk<'a, D> {
133	pub fn new(before_storage: &'a mut D, records: &'a [Record<D::WRITE_GRANULARITY>]) -> Self {
134		Self {
135			before_storage,
136			records,
137			new_actions: Vec::new(),
138			first_undiscovered: None,
139			merged_records: 0,
140		}
141	}
142
143	pub fn from_sub_range(
144		before_storage: &'a mut D,
145		records: &'a [Record<D::WRITE_GRANULARITY>],
146		range: Range<usize>,
147	) -> Self {
148		let mut me = Self {
149			before_storage,
150			records: &records,
151			new_actions: Vec::new(),
152			first_undiscovered: None,
153			merged_records: 0,
154		};
155
156		me.cut_off_top(range.end);
157		me.merge_bottom(range.start);
158
159		me
160	}
161
162	pub fn with_limit(
163		before_storage: &'a mut D,
164		records: &'a [Record<D::WRITE_GRANULARITY>],
165		range: Range<usize>,
166		limit: Option<usize>,
167	) -> Self {
168		let mut me = Self::from_sub_range(before_storage, records, range);
169		me.first_undiscovered = limit;
170
171		me
172	}
173
174	pub fn current_range(&self) -> Range<usize> {
175		let start = self.merged_records;
176		let end = self.merged_records + self.records.len();
177		start..end
178	}
179
180	pub fn before_range(&self) -> Range<usize> {
181		let start = self.merged_records;
182		if let Some(first_undiscovered) = self.first_undiscovered {
183			start..(start + first_undiscovered)
184		} else {
185			start..start
186		}
187	}
188
189	pub fn after_range(&self) -> Range<usize> {
190		let start = self.merged_records;
191		if let Some(first_undiscovered) = self.first_undiscovered {
192			(start + first_undiscovered + 1)..(start + self.records.len())
193		} else {
194			(start + self.records.len())..(start + self.records.len())
195		}
196	}
197
198	pub fn cut_off_top(&mut self, top: usize) {
199		self.records = &self.records[..top];
200	}
201
202	pub fn merge_bottom(&mut self, bottom: usize) {
203		let mergable = &self.records[..bottom];
204		self.records = &self.records[bottom..];
205
206		self.merged_records += mergable.len();
207
208		// Modify the base image to incorporate the merged records
209		for rec in mergable.iter() {
210			match rec {
211				Record::Write(WriteRecord {
212					block,
213					ibo,
214					data,
215				}) => {
216					self.before_storage
217						.write(block * D::ERASE_BLOCK_SIZE + ibo, data.as_slice())
218						.unwrap();
219				},
220				Record::Erase(EraseRecord {
221					block,
222				}) => {
223					self.before_storage.erase(*block).unwrap();
224				},
225			}
226		}
227	}
228}
229impl<D: Disk> ErrorType for ReplayDisk<'_, D> {
230	type Error = <D as ErrorType>::Error;
231}
232impl<D: Disk> Disk for ReplayDisk<'_, D> {
233	type WRITE_GRANULARITY = D::WRITE_GRANULARITY;
234
235	const ERASE_BLOCK_SIZE: usize = D::ERASE_BLOCK_SIZE;
236
237	fn block_count(&self) -> usize {
238		self.before_storage.block_count()
239	}
240
241	fn read(&mut self, offset: usize, buf: &mut [u8]) -> Result<usize, Self::Error> {
242		if buf.is_empty() {
243			return Ok(0);
244		}
245
246		assert!(offset % Self::WRITE_GRANULARITY::USIZE == 0);
247		assert!(buf.len() >= Self::WRITE_GRANULARITY::USIZE);
248
249		// Iter in reverse chronological order, i.e. return the most recent state
250		for rec in self.new_actions.iter().rev() {
251			match rec {
252				Record::Write(WriteRecord {
253					block,
254					ibo,
255					data,
256				}) => {
257					if block * Self::ERASE_BLOCK_SIZE + ibo == offset {
258						buf[..Self::WRITE_GRANULARITY::USIZE].copy_from_slice(data.as_slice());
259						return Ok(Self::WRITE_GRANULARITY::USIZE);
260					}
261				},
262				Record::Erase(EraseRecord {
263					block,
264				}) => {
265					if *block == offset / Self::ERASE_BLOCK_SIZE {
266						// Emulate NOR Flash
267						buf[..Self::WRITE_GRANULARITY::USIZE].fill(0xff);
268						return Ok(Self::WRITE_GRANULARITY::USIZE);
269					}
270				},
271			}
272		}
273
274		let limit = if let Some(limit) = self.first_undiscovered {
275			limit + 1
276		} else {
277			self.records.len()
278		};
279
280		// Iter in reverse chronological order, i.e. return the most recent state
281		for (i, rec) in self.records[..limit].iter().enumerate().rev() {
282			match rec {
283				Record::Write(WriteRecord {
284					block,
285					ibo,
286					data,
287				}) => {
288					if block * Self::ERASE_BLOCK_SIZE + ibo == offset {
289						if self.first_undiscovered.is_none() {
290							self.first_undiscovered = Some(i);
291						}
292
293						buf[..Self::WRITE_GRANULARITY::USIZE].copy_from_slice(data.as_slice());
294						return Ok(Self::WRITE_GRANULARITY::USIZE);
295					}
296				},
297				Record::Erase(EraseRecord {
298					block,
299				}) => {
300					if *block == offset / Self::ERASE_BLOCK_SIZE {
301						if self.first_undiscovered.is_none() {
302							self.first_undiscovered = Some(i + 1);
303						}
304
305						// Emulate NOR Flash
306						buf[..Self::WRITE_GRANULARITY::USIZE].fill(0xff);
307						return Ok(Self::WRITE_GRANULARITY::USIZE);
308					}
309				},
310			}
311		}
312
313		self.before_storage
314			.read(offset, &mut buf[..Self::WRITE_GRANULARITY::USIZE])
315	}
316
317	fn write(&mut self, offset: usize, buf: &[u8]) -> Result<usize, Self::Error> {
318		if buf.is_empty() {
319			return Ok(0);
320		}
321
322		assert!(offset % Self::WRITE_GRANULARITY::USIZE == 0);
323		assert!(buf.len() >= Self::WRITE_GRANULARITY::USIZE);
324
325		// Emulate NOR Flash, i.e. read the current state and AND the new data
326		// into it
327		let mut buffer = GenericArray::default();
328		let len = self.read(offset, &mut buffer[..Self::WRITE_GRANULARITY::USIZE])?;
329		assert_eq!(len, Self::WRITE_GRANULARITY::USIZE);
330		// AND together the latest state and the new data
331		for (b, input) in buffer.iter_mut().zip(buf) {
332			*b &= input;
333		}
334
335		let rec = WriteRecord {
336			block: offset / Self::ERASE_BLOCK_SIZE,
337			ibo: offset % Self::ERASE_BLOCK_SIZE,
338			data: buffer,
339		};
340
341		self.new_actions.push(Record::Write(rec));
342
343		Ok(Self::WRITE_GRANULARITY::USIZE)
344	}
345
346	fn erase(&mut self, block: usize) -> Result<(), Self::Error> {
347		self.new_actions.push(Record::Erase(EraseRecord {
348			block,
349		}));
350
351		Ok(())
352	}
353}
354
355use derivative::Derivative;
356
357#[derive(Derivative)]
358#[derivative(Debug(bound = ""), Clone(bound = ""))]
359pub enum Record<N: ArrayLength> {
360	Write(WriteRecord<N>),
361	Erase(EraseRecord),
362}
363impl<N: ArrayLength> fmt::Display for Record<N> {
364	fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
365		match self {
366			Record::Write(rec) => {
367				let block = rec.block;
368				let range = rec.ibo..(rec.ibo + rec.data.len());
369				let data = &rec.data;
370
371				write!(f, "Write at {block} @ {range:?}, data: {data:02x?}")
372			},
373			Record::Erase(rec) => write!(f, "Erase at {}", rec.block),
374		}
375	}
376}
377
378
379#[derive(Derivative)]
380#[derivative(Debug(bound = ""), Clone(bound = ""))]
381pub struct WriteRecord<N: ArrayLength> {
382	block: usize,
383	ibo: usize,
384	data: GenericArray<u8, N>,
385}
386
/// A logged erase of a single erase block.
#[derive(Clone, Debug)]
pub struct EraseRecord {
	/// Index of the erased block.
	block: usize,
}