use core::{
fmt::Display,
ops::{Add, AddAssign, Range},
};
use embedded_storage_async::nor_flash::{
ErrorType, MultiwriteNorFlash, NorFlash, NorFlashError, NorFlashErrorKind, ReadNorFlash,
};
/// Per-word write state tracked by the mock flash so `write` can enforce
/// the configured `WriteCountCheck` policy.
///
/// NOTE(review): the one-letter names presumably abbreviate something like
/// "Twice / Once / Never" writable — confirm with crate history. What the
/// code shows: words start as `T` (`new`, `erase`), move to `O` on first
/// write, and to `N` when their write budget is spent.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Writable {
    /// Erased; the word may be written.
    T,
    /// Written once since the last erase.
    O,
    /// Write budget exhausted; further writes error with `NotWritable`.
    N,
}
use Writable::*;
/// In-memory mock NOR flash with `PAGES` pages, each of `PAGE_WORDS` words
/// of `BYTES_PER_WORD` bytes, for testing flash-driver logic without
/// hardware. Simulates NOR bit-clearing writes, write-count limits, and
/// early power-off.
#[derive(Debug, Clone)]
pub struct MockFlashBase<const PAGES: usize, const BYTES_PER_WORD: usize, const PAGE_WORDS: usize> {
    /// Write state of every word (indexed by word, not byte); see `Writable`.
    writable: Vec<Writable>,
    /// Raw flash contents; erased bytes read as `0xFF` (`u8::MAX`).
    data: Vec<u8>,
    /// Running operation counters, exposed via `stats_snapshot`.
    current_stats: FlashStatsSnapshot,
    /// Policy for how many times a word may be written between erases.
    pub write_count_check: WriteCountCheck,
    /// When `Some(n)`, the flash "loses power" (`EarlyShutoff`) after `n`
    /// more byte-level touches; see `check_shutoff`.
    pub bytes_until_shutoff: Option<u32>,
    /// When true, `write` panics if the source RAM buffer is not 4-byte aligned.
    pub alignment_check: bool,
}
/// Default configuration: strict once-per-erase write policy, no simulated
/// power-off, and source-buffer alignment checking enabled.
impl<const PAGES: usize, const BYTES_PER_WORD: usize, const PAGE_WORDS: usize> Default
    for MockFlashBase<PAGES, BYTES_PER_WORD, PAGE_WORDS>
{
    fn default() -> Self {
        let write_count_check = WriteCountCheck::OnceOnly;
        let bytes_until_shutoff = None;
        let alignment_check = true;
        Self::new(write_count_check, bytes_until_shutoff, alignment_check)
    }
}
impl<const PAGES: usize, const BYTES_PER_WORD: usize, const PAGE_WORDS: usize>
    MockFlashBase<PAGES, BYTES_PER_WORD, PAGE_WORDS>
{
    /// Total number of words in the flash.
    const CAPACITY_WORDS: usize = PAGES * PAGE_WORDS;
    /// Total number of bytes in the flash.
    const CAPACITY_BYTES: usize = Self::CAPACITY_WORDS * BYTES_PER_WORD;
    /// Size of one erase page in bytes.
    const PAGE_BYTES: usize = PAGE_WORDS * BYTES_PER_WORD;
    /// Byte-address range spanning the entire flash.
    pub const FULL_FLASH_RANGE: Range<u32> = 0..(PAGES * PAGE_WORDS * BYTES_PER_WORD) as u32;

    /// Creates a fully-erased mock flash: every byte is `0xFF` and every
    /// word is marked writable; all statistics counters start at zero.
    pub fn new(
        write_count_check: WriteCountCheck,
        bytes_until_shutoff: Option<u32>,
        alignment_check: bool,
    ) -> Self {
        Self {
            writable: vec![T; Self::CAPACITY_WORDS],
            data: vec![u8::MAX; Self::CAPACITY_BYTES],
            current_stats: FlashStatsSnapshot {
                erases: 0,
                reads: 0,
                writes: 0,
                bytes_read: 0,
                bytes_written: 0,
            },
            write_count_check,
            bytes_until_shutoff,
            alignment_check,
        }
    }

    /// Borrows the raw flash contents.
    pub fn as_bytes(&self) -> &[u8] {
        &self.data
    }

    /// Mutably borrows the raw flash contents. Bypasses all write-count and
    /// NOR-semantics checks — test use only.
    pub fn as_bytes_mut(&mut self) -> &mut [u8] {
        &mut self.data
    }

    /// Validates that an access of `length` bytes at `offset` is word-aligned,
    /// non-empty and in bounds, returning the covered byte range.
    fn validate_operation(offset: u32, length: usize) -> Result<Range<usize>, MockFlashError> {
        let offset = offset as usize;
        if (offset % Self::READ_SIZE) != 0 || length == 0 || length % BYTES_PER_WORD != 0 {
            Err(MockFlashError::NotAligned)
        } else if offset > Self::CAPACITY_BYTES || offset + length > Self::CAPACITY_BYTES {
            Err(MockFlashError::OutOfBounds)
        } else {
            Ok(offset..(offset + length))
        }
    }

    /// Ticks the simulated power-off countdown by one byte. When the
    /// countdown reaches zero the operation fails with `EarlyShutoff` and
    /// the countdown is disarmed, so subsequent operations succeed again.
    fn check_shutoff(&mut self, address: u32, operation: Operation) -> Result<(), MockFlashError> {
        if let Some(bytes_until_shutoff) = self.bytes_until_shutoff.as_mut() {
            if let Some(next) = bytes_until_shutoff.checked_sub(1) {
                *bytes_until_shutoff = next;
                Ok(())
            } else {
                // Only printed in fuzzing-reproduction builds to locate the cut point.
                #[cfg(fuzzing_repro)]
                eprintln!("!!! Shutoff at {address} while doing '{operation:?}' !!!");
                self.bytes_until_shutoff = None;
                Err(MockFlashError::EarlyShutoff(address, operation))
            }
        } else {
            Ok(())
        }
    }

    /// Returns a copy of the current operation counters; diff two snapshots
    /// with `FlashStatsSnapshot::compare_to`.
    pub fn stats_snapshot(&self) -> FlashStatsSnapshot {
        self.current_stats
    }

    /// Test helper: reports whether the stored item covering
    /// `target_item_address` is present (`Some(true)`), erased or corrupted
    /// (`Some(false)`), or not found / address out of range (`None`).
    ///
    /// NOTE(review): relies on crate-internal layout helpers
    /// (`calculate_page_*`, `ItemHeaderIter`); the `WORD_SIZE` offsets assume
    /// one marker word at each end of a page — confirm against the crate's
    /// page layout.
    #[cfg(any(test, feature = "_test"))]
    pub async fn get_item_presence(&mut self, target_item_address: u32) -> Option<bool> {
        use crate::NorFlashExt;
        if !Self::FULL_FLASH_RANGE.contains(&target_item_address) {
            return None;
        }
        // Scratch buffer the item payload is deserialized into.
        let mut buf = [0; 1024 * 16];
        let page_index =
            crate::calculate_page_index::<Self>(Self::FULL_FLASH_RANGE, target_item_address);
        // Data area of the page, excluding the marker word at either end.
        let page_data_start =
            crate::calculate_page_address::<Self>(Self::FULL_FLASH_RANGE, page_index)
                + Self::WORD_SIZE as u32;
        let page_data_end =
            crate::calculate_page_end_address::<Self>(Self::FULL_FLASH_RANGE, page_index)
                - Self::WORD_SIZE as u32;
        let mut found_item = None;
        let mut it = crate::item::ItemHeaderIter::new(page_data_start, page_data_end);
        // Walk item headers until we find the one whose span contains the target.
        while let (Some(header), item_address) = it.traverse(self, |_, _| false).await.unwrap() {
            let next_item_address = header.next_item_address::<Self>(item_address);
            if (item_address..next_item_address).contains(&target_item_address) {
                let maybe_item = header
                    .read_item(self, &mut buf, item_address, page_data_end)
                    .await
                    .unwrap();
                match maybe_item {
                    crate::item::MaybeItem::Corrupted(_, _)
                    | crate::item::MaybeItem::Erased(_, _) => {
                        found_item.replace(false);
                        break;
                    }
                    crate::item::MaybeItem::Present(_) => {
                        found_item.replace(true);
                        break;
                    }
                }
            }
        }
        found_item
    }
}
// All flash trait operations on the mock report `MockFlashError`.
impl<const PAGES: usize, const BYTES_PER_WORD: usize, const PAGE_WORDS: usize> ErrorType
    for MockFlashBase<PAGES, BYTES_PER_WORD, PAGE_WORDS>
{
    type Error = MockFlashError;
}
impl<const PAGES: usize, const BYTES_PER_WORD: usize, const PAGE_WORDS: usize> ReadNorFlash
    for MockFlashBase<PAGES, BYTES_PER_WORD, PAGE_WORDS>
{
    const READ_SIZE: usize = BYTES_PER_WORD;

    /// Copies `bytes.len()` bytes starting at `offset` into `bytes`.
    /// Panics if the buffer length is not a whole number of words; returns
    /// `NotAligned`/`OutOfBounds` for invalid offsets or ranges.
    async fn read(&mut self, offset: u32, bytes: &mut [u8]) -> Result<(), Self::Error> {
        let len = bytes.len();
        // Statistics are updated before validation, mirroring `write`/`erase`.
        self.current_stats.reads += 1;
        self.current_stats.bytes_read += len as u64;
        assert!(
            len % Self::READ_SIZE == 0,
            "any read must be a multiple of Self::READ_SIZE bytes"
        );
        let range = Self::validate_operation(offset, len)?;
        bytes.copy_from_slice(&self.as_bytes()[range]);
        Ok(())
    }

    /// Total flash size in bytes.
    fn capacity(&self) -> usize {
        Self::CAPACITY_BYTES
    }
}
// Marker impl: advertises that the same word may be written more than once.
// NOTE(review): whether a rewrite actually succeeds is still governed by the
// configured `write_count_check` policy in `write`.
impl<const PAGES: usize, const BYTES_PER_WORD: usize, const PAGE_WORDS: usize> MultiwriteNorFlash
    for MockFlashBase<PAGES, BYTES_PER_WORD, PAGE_WORDS>
{
}
impl<const PAGES: usize, const BYTES_PER_WORD: usize, const PAGE_WORDS: usize> NorFlash
    for MockFlashBase<PAGES, BYTES_PER_WORD, PAGE_WORDS>
{
    const WRITE_SIZE: usize = BYTES_PER_WORD;
    const ERASE_SIZE: usize = Self::PAGE_BYTES;

    /// Erases the page-aligned byte range `from..to`: every byte becomes
    /// `0xFF` and every word becomes writable again. The simulated power-off
    /// is checked per byte, so a shutoff leaves the range partially erased.
    async fn erase(&mut self, from: u32, to: u32) -> Result<(), Self::Error> {
        self.current_stats.erases += 1;
        let from = from as usize;
        let to = to as usize;
        assert!(from <= to);
        if to > Self::CAPACITY_BYTES {
            return Err(MockFlashError::OutOfBounds);
        }
        if from % Self::PAGE_BYTES != 0 || to % Self::PAGE_BYTES != 0 {
            return Err(MockFlashError::NotAligned);
        }
        for index in from..to {
            self.check_shutoff(index as u32, Operation::Erase)?;
            self.as_bytes_mut()[index] = u8::MAX;
            // Reset the word's write state once, at the word's first byte.
            if index % BYTES_PER_WORD == 0 {
                self.writable[index / BYTES_PER_WORD] = T;
            }
        }
        Ok(())
    }

    /// NOR-style write: ANDs `bytes` into the flash word by word (bits can
    /// only be cleared, never set), enforcing the configured
    /// `write_count_check` policy per word and the simulated power-off per
    /// byte.
    async fn write(&mut self, offset: u32, bytes: &[u8]) -> Result<(), Self::Error> {
        self.current_stats.writes += 1;
        let range = Self::validate_operation(offset, bytes.len())?;
        // Checks the alignment of the source buffer in RAM, not the flash offset.
        if self.alignment_check && bytes.as_ptr() as usize % 4 != 0 {
            panic!("write buffer must be aligned to 4 bytes");
        }
        if bytes.len() % Self::WRITE_SIZE != 0 {
            panic!("any write must be a multiple of Self::WRITE_SIZE bytes");
        }
        // `address` is the byte offset of each destination word in flash.
        for (source_word, address) in bytes
            .chunks_exact(BYTES_PER_WORD)
            .zip(range.step_by(BYTES_PER_WORD))
        {
            for (byte_index, byte) in source_word.iter().enumerate() {
                self.check_shutoff((address + byte_index) as u32, Operation::Write)?;
                // Advance the word's write-state machine once per word, on
                // its first byte (T = erased, O = written once, N = spent).
                if byte_index == 0 {
                    let word_writable = &mut self.writable[address / BYTES_PER_WORD];
                    *word_writable = match (*word_writable, self.write_count_check) {
                        // Policy disabled: state never advances, always writable.
                        (v, WriteCountCheck::Disabled) => v,
                        // First write after an erase is always allowed.
                        (Writable::T, _) => Writable::O,
                        // Second write allowed unconditionally; word is then spent.
                        (Writable::O, WriteCountCheck::Twice) => Writable::N,
                        // Re-writing identical data does not consume the word.
                        (Writable::O, WriteCountCheck::TwiceDifferent)
                            if source_word == &self.data[address..][..BYTES_PER_WORD] =>
                        {
                            Writable::O
                        }
                        (Writable::O, WriteCountCheck::TwiceDifferent) => Writable::N,
                        // Second write allowed only if it is all zero bytes.
                        (Writable::O, WriteCountCheck::TwiceWithZero)
                            if source_word.iter().all(|b| *b == 0) =>
                        {
                            Writable::N
                        }
                        // Anything else (e.g. OnceOnly rewrite, or state N) fails.
                        _ => return Err(MockFlashError::NotWritable(address as u32)),
                    };
                }
                self.current_stats.bytes_written += 1;
                // NOR semantics: programming can only clear bits.
                self.as_bytes_mut()[address + byte_index] &= byte;
            }
        }
        Ok(())
    }
}
/// Errors produced by `MockFlashBase` operations.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MockFlashError {
    /// The operation touched addresses past the end of the flash.
    OutOfBounds,
    /// The offset or length violated word/page alignment rules.
    NotAligned,
    /// The word at this byte address has exhausted its allowed writes
    /// under the configured `WriteCountCheck`.
    NotWritable(u32),
    /// The simulated power-off countdown expired at this byte address
    /// during the given operation.
    EarlyShutoff(u32, Operation),
}
impl Display for MockFlashError {
    /// User-facing text is the same as the derived `Debug` representation.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        write!(f, "{:?}", self)
    }
}
impl NorFlashError for MockFlashError {
fn kind(&self) -> NorFlashErrorKind {
match self {
MockFlashError::OutOfBounds => NorFlashErrorKind::OutOfBounds,
MockFlashError::NotAligned => NorFlashErrorKind::NotAligned,
MockFlashError::NotWritable(_) => NorFlashErrorKind::Other,
MockFlashError::EarlyShutoff(_, _) => NorFlashErrorKind::Other,
}
}
}
/// Policy for how many times a word may be written between erases,
/// enforced by `MockFlashBase::write`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WriteCountCheck {
    /// A word may be written exactly once per erase cycle.
    OnceOnly,
    /// A second write is allowed; re-writing identical data does not
    /// consume the word's budget.
    TwiceDifferent,
    /// A second write is always allowed.
    Twice,
    /// A second write is allowed only if it writes all zero bytes.
    TwiceWithZero,
    /// No write-count enforcement at all.
    Disabled,
}
/// Point-in-time copy of the flash's operation counters. Obtain one via
/// `MockFlashBase::stats_snapshot` and diff two with `compare_to`.
#[derive(Debug, Clone, Copy)]
pub struct FlashStatsSnapshot {
    // Number of erase operations performed.
    erases: u64,
    // Number of read operations performed.
    reads: u64,
    // Number of write operations performed.
    writes: u64,
    // Total bytes requested across all reads.
    bytes_read: u64,
    // Total bytes programmed across all writes.
    bytes_written: u64,
}
impl FlashStatsSnapshot {
    /// Computes the per-counter deltas going from `self` (the older
    /// snapshot) to `other` (the newer one).
    ///
    /// Panics if any counter in `other` is smaller than in `self`,
    /// i.e. when the snapshots are passed in the wrong order.
    pub fn compare_to(&self, other: Self) -> FlashStatsResult {
        // Every field uses the same checked subtraction with the same
        // wrong-order diagnostic.
        let delta = |old: u64, new: u64| new.checked_sub(old).expect("Order is old compare to new");
        FlashStatsResult {
            erases: delta(self.erases, other.erases),
            reads: delta(self.reads, other.reads),
            writes: delta(self.writes, other.writes),
            bytes_read: delta(self.bytes_read, other.bytes_read),
            bytes_written: delta(self.bytes_written, other.bytes_written),
        }
    }
}
/// Difference between two `FlashStatsSnapshot`s: how many operations and
/// bytes were consumed over some interval. Summable with `+`/`+=`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct FlashStatsResult {
    /// Erase operations in the interval.
    pub erases: u64,
    /// Read operations in the interval.
    pub reads: u64,
    /// Write operations in the interval.
    pub writes: u64,
    /// Bytes read in the interval.
    pub bytes_read: u64,
    /// Bytes written in the interval.
    pub bytes_written: u64,
}
impl FlashStatsResult {
    /// Divides every counter by `divider` (e.g. a number of runs),
    /// yielding floating-point per-run averages.
    pub fn take_average(&self, divider: u64) -> FlashAverageStatsResult {
        let avg = |total: u64| total as f64 / divider as f64;
        FlashAverageStatsResult {
            avg_erases: avg(self.erases),
            avg_reads: avg(self.reads),
            avg_writes: avg(self.writes),
            avg_bytes_read: avg(self.bytes_read),
            avg_bytes_written: avg(self.bytes_written),
        }
    }
}
impl AddAssign for FlashStatsResult {
    /// Accumulates `rhs` into `self`, field by field.
    fn add_assign(&mut self, rhs: Self) {
        self.erases += rhs.erases;
        self.reads += rhs.reads;
        self.writes += rhs.writes;
        self.bytes_read += rhs.bytes_read;
        self.bytes_written += rhs.bytes_written;
    }
}
impl Add for FlashStatsResult {
    type Output = FlashStatsResult;

    /// Field-wise sum of two stat results.
    fn add(self, rhs: Self) -> Self::Output {
        let mut sum = self;
        sum.erases += rhs.erases;
        sum.reads += rhs.reads;
        sum.writes += rhs.writes;
        sum.bytes_read += rhs.bytes_read;
        sum.bytes_written += rhs.bytes_written;
        sum
    }
}
/// Per-run averages produced by `FlashStatsResult::take_average`.
/// Compared approximately via the `approx` trait impls below, since the
/// values are floating point.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub struct FlashAverageStatsResult {
    /// Average erase operations per run.
    pub avg_erases: f64,
    /// Average read operations per run.
    pub avg_reads: f64,
    /// Average write operations per run.
    pub avg_writes: f64,
    /// Average bytes read per run.
    pub avg_bytes_read: f64,
    /// Average bytes written per run.
    pub avg_bytes_written: f64,
}
impl approx::AbsDiffEq for FlashAverageStatsResult {
    type Epsilon = f64;

    fn default_epsilon() -> Self::Epsilon {
        f64::EPSILON
    }

    /// Two results are equal when every field is within `epsilon`,
    /// checked in field order with short-circuiting.
    fn abs_diff_eq(&self, other: &Self, epsilon: Self::Epsilon) -> bool {
        let pairs = [
            (self.avg_erases, other.avg_erases),
            (self.avg_reads, other.avg_reads),
            (self.avg_writes, other.avg_writes),
            (self.avg_bytes_read, other.avg_bytes_read),
            (self.avg_bytes_written, other.avg_bytes_written),
        ];
        pairs.iter().all(|(lhs, rhs)| lhs.abs_diff_eq(rhs, epsilon))
    }
}
impl approx::RelativeEq for FlashAverageStatsResult {
    fn default_max_relative() -> Self::Epsilon {
        f64::default_max_relative()
    }

    /// Relative comparison over every field, in field order with
    /// short-circuiting; delegates to `f64::relative_eq` per field.
    fn relative_eq(
        &self,
        other: &Self,
        epsilon: Self::Epsilon,
        max_relative: Self::Epsilon,
    ) -> bool {
        let pairs = [
            (self.avg_erases, other.avg_erases),
            (self.avg_reads, other.avg_reads),
            (self.avg_writes, other.avg_writes),
            (self.avg_bytes_read, other.avg_bytes_read),
            (self.avg_bytes_written, other.avg_bytes_written),
        ];
        pairs
            .iter()
            .all(|(lhs, rhs)| lhs.relative_eq(rhs, epsilon, max_relative))
    }
}
/// The kind of flash operation in progress when an `EarlyShutoff` fired.
#[allow(missing_docs)]
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Operation {
    Read,
    Write,
    Erase,
}