use core::cell::RefCell;
use core::cmp::Ordering;
use core::future::poll_fn;
use core::ops::{Deref, DerefMut, RangeBounds};
use core::task::Poll;
use embassy_sync::blocking_mutex::raw::RawMutex;
use embassy_sync::blocking_mutex::Mutex as BlockingMutex;
use embassy_sync::mutex::Mutex;
use embassy_sync::waitqueue::WakerRegistration;
use heapless::Vec;
use crate::config::*;
use crate::errors::{no_eof, CorruptedError, Error, MountError, ReadError, WriteError};
use crate::file::{FileID, FileManager, FileReader, FileSearcher, FileWriter, SeekDirection, PAGE_MAX_PAYLOAD_SIZE};
use crate::flash::Flash;
use crate::page::{PageReader, ReadError as PageReadError};
use crate::{CommitError, Cursor, FormatError};
/// File flag set on the destination file of a compaction that has not finished yet.
const FILE_FLAG_COMPACT_DEST: u8 = 1 << 0;
/// File flag set on each source file of a compaction that has not finished yet.
const FILE_FLAG_COMPACT_SRC: u8 = 1 << 1;
/// Runtime configuration for a [`Database`].
#[derive(Debug, Clone, Eq, PartialEq)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
#[non_exhaustive]
pub struct Config {
    /// Seed handed to the file manager when the database is constructed.
    pub random_seed: u32,
}

impl Config {
    /// `const` constructor for the default configuration (seed `0`).
    const fn default() -> Self {
        Self { random_seed: 0 }
    }
}

impl Default for Config {
    fn default() -> Self {
        // Inherent associated functions win over trait methods in path resolution,
        // so this delegates to the `const fn` above rather than recursing.
        Config::default()
    }
}
/// Database-wide status of the single write-transaction slot.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
enum WriteTxState {
/// No write transaction exists; `write_transaction()` may create one.
Idle,
/// A write transaction exists; read transactions may still be started.
Created,
/// The write transaction is committing: new read transactions wait, and the
/// commit itself waits for all existing read transactions to be dropped.
Committing,
}
/// Transaction bookkeeping shared by all users of a `Database`, kept behind a blocking mutex.
struct State {
/// Number of live `ReadTransaction`s.
read_tx_count: usize,
/// Status of the write-transaction slot.
write_tx: WriteTxState,
/// Waker registration used by tasks waiting for `read_tx_count` or `write_tx` to change
/// (registered in `read_transaction`, `write_transaction` and `WriteTransaction::commit`).
waker: WakerRegistration,
}
/// Key-value database stored on a [`Flash`] device.
///
/// `state` holds transaction bookkeeping behind a blocking mutex that is only ever locked
/// inside synchronous closures. `inner` holds the file manager and page readers behind an
/// async mutex, because flash operations are awaited while it is held.
pub struct Database<F: Flash, M: RawMutex> {
    state: BlockingMutex<M, RefCell<State>>,
    pub(crate) inner: Mutex<M, Inner<F>>,
}

impl<F: Flash, M: RawMutex> Database<F, M> {
    /// Creates a database over `flash` using `config`.
    pub fn new(flash: F, config: Config) -> Self {
        let tx_state = State {
            read_tx_count: 0,
            write_tx: WriteTxState::Idle,
            waker: WakerRegistration::new(),
        };
        Self {
            inner: Mutex::new(Inner::new(flash, config.random_seed)),
            state: BlockingMutex::new(RefCell::new(tx_state)),
        }
    }

    /// Locks the database and returns a guard that dereferences to the underlying flash.
    pub async fn lock_flash(&self) -> impl DerefMut<Target = F> + '_ {
        let guard = self.inner.lock().await;
        FlashLockGuard(guard)
    }

    /// Formats the storage via the file manager.
    ///
    /// # Panics
    /// Panics if a dropped, uncommitted write transaction has left pending writer state
    /// that has not been rolled back yet.
    pub async fn format(&self) -> Result<(), FormatError<F::Error>> {
        let mut inner = self.inner.lock().await;
        inner.format().await
    }

    /// Mounts the storage, remounting only if the in-memory state is marked dirty.
    ///
    /// # Panics
    /// Same condition as [`Self::format`].
    pub async fn mount(&self) -> Result<(), MountError<F::Error>> {
        let mut inner = self.inner.lock().await;
        inner.mount().await
    }

    /// Dumps pages and file contents to the log (std builds only).
    #[cfg(feature = "std")]
    pub async fn dump(&self) {
        let mut inner = self.inner.lock().await;
        inner.dump().await
    }

    /// Opens a read transaction.
    ///
    /// Waits while a write transaction is committing; otherwise it bumps the reader count
    /// immediately. The count is decremented when the returned handle is dropped.
    pub async fn read_transaction(&self) -> ReadTransaction<'_, F, M> {
        poll_fn(|cx| {
            self.state.lock(|cell| {
                let mut st = cell.borrow_mut();
                match st.write_tx {
                    WriteTxState::Committing => {
                        st.waker.register(cx.waker());
                        Poll::Pending
                    }
                    WriteTxState::Idle | WriteTxState::Created => {
                        st.read_tx_count = st.read_tx_count.checked_add(1).unwrap();
                        Poll::Ready(())
                    }
                }
            })
        })
        .await;
        ReadTransaction { db: self }
    }

    /// Opens the (single) write transaction, waiting until no other one exists.
    pub async fn write_transaction(&self) -> WriteTransaction<'_, F, M> {
        poll_fn(|cx| {
            self.state.lock(|cell| {
                let mut st = cell.borrow_mut();
                if st.write_tx == WriteTxState::Idle {
                    st.write_tx = WriteTxState::Created;
                    Poll::Ready(())
                } else {
                    st.waker.register(cx.waker());
                    Poll::Pending
                }
            })
        })
        .await;
        WriteTransaction {
            db: self,
            state: WriteTransactionState::Created,
        }
    }
}
/// Wraps a lock guard over `Inner<F>` so that holders see only the flash device `F`.
///
/// The lock is held for as long as this wrapper lives.
struct FlashLockGuard<G, F>(G)
where
    G: Deref<Target = Inner<F>> + DerefMut,
    F: Flash;

impl<G, F> Deref for FlashLockGuard<G, F>
where
    G: Deref<Target = Inner<F>> + DerefMut,
    F: Flash,
{
    type Target = F;

    fn deref(&self) -> &F {
        let inner: &Inner<F> = &self.0;
        inner.files.flash()
    }
}

impl<G, F> DerefMut for FlashLockGuard<G, F>
where
    G: Deref<Target = Inner<F>> + DerefMut,
    F: Flash,
{
    fn deref_mut(&mut self) -> &mut F {
        let inner: &mut Inner<F> = &mut self.0;
        inner.files.flash_mut()
    }
}
/// A read transaction. While any of these are alive, a pending commit waits.
pub struct ReadTransaction<'a, F: Flash + 'a, M: RawMutex + 'a> {
    db: &'a Database<F, M>,
}

impl<'a, F: Flash + 'a, M: RawMutex + 'a> Drop for ReadTransaction<'a, F, M> {
    fn drop(&mut self) {
        self.db.state.lock(|cell| {
            let mut st = cell.borrow_mut();
            let remaining = st.read_tx_count.checked_sub(1).unwrap();
            st.read_tx_count = remaining;
            // Only a waiting commit cares about readers, and it only proceeds once none remain.
            if remaining == 0 {
                st.waker.wake();
            }
        })
    }
}

impl<'a, F: Flash + 'a, M: RawMutex + 'a> ReadTransaction<'a, F, M> {
    /// Reads the value stored for `key` into `value`, returning its length.
    ///
    /// # Errors
    /// `KeyTooBig` if `key` exceeds `MAX_KEY_SIZE`; otherwise any error from the lookup
    /// (e.g. `KeyNotFound`, `BufferTooSmall`, flash or corruption errors).
    pub async fn read(&self, key: &[u8], value: &mut [u8]) -> Result<usize, ReadError<F::Error>> {
        if key.len() > MAX_KEY_SIZE {
            return Err(ReadError::KeyTooBig);
        }
        let mut inner = self.db.inner.lock().await;
        inner.read(key, value).await
    }

    /// Returns a cursor over every key in the database.
    pub async fn read_all<'b>(&'b self) -> Result<Cursor<'b, F, M>, Error<F::Error>> {
        self.read_range(..).await
    }

    /// Returns a cursor over the keys within `range`.
    pub async fn read_range<'b>(
        &'b self,
        range: impl RangeBounds<&'b [u8]>,
    ) -> Result<Cursor<'b, F, M>, Error<F::Error>> {
        let start = range.start_bound().cloned();
        let end = range.end_bound().cloned();
        Cursor::new(self.db, start, end).await
    }
}
/// Progress of an individual `WriteTransaction` handle.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
enum WriteTransactionState {
/// No write issued yet; `commit` is a no-op.
Created,
/// At least one write succeeded; `commit` will persist the writes.
InProgress,
/// A write failed, or its future was dropped mid-way; further writes and `commit` are rejected.
Canceled,
}
pub struct WriteTransaction<'a, F: Flash + 'a, M: RawMutex + 'a> {
db: &'a Database<F, M>,
state: WriteTransactionState,
}
impl<'a, F: Flash + 'a, M: RawMutex + 'a> Drop for WriteTransaction<'a, F, M> {
fn drop(&mut self) {
self.db.state.lock(|s| {
let s = &mut s.borrow_mut();
assert!(s.write_tx != WriteTxState::Idle);
s.write_tx = WriteTxState::Idle;
s.waker.wake();
})
}
}
impl<'a, F: Flash + 'a, M: RawMutex + 'a> WriteTransaction<'a, F, M> {
pub async fn write(&mut self, key: &[u8], value: &[u8]) -> Result<(), WriteError<F::Error>> {
self.write_inner(key, value, false).await
}
pub async fn delete(&mut self, key: &[u8]) -> Result<(), WriteError<F::Error>> {
self.write_inner(key, &[], true).await
}
async fn write_inner(&mut self, key: &[u8], value: &[u8], is_delete: bool) -> Result<(), WriteError<F::Error>> {
let is_first_write = match self.state {
WriteTransactionState::Canceled => return Err(WriteError::TransactionCanceled),
WriteTransactionState::Created => true,
WriteTransactionState::InProgress => false,
};
if key.len() > MAX_KEY_SIZE {
return Err(WriteError::KeyTooBig);
}
if value.len() > MAX_VALUE_SIZE {
return Err(WriteError::ValueTooBig);
}
self.state = WriteTransactionState::Canceled;
let db = &mut *self.db.inner.lock().await;
db.files.remount_if_dirty(&mut db.readers[0]).await?;
if is_first_write {
db.rollback_if_any().await?;
}
db.write(key, value, is_delete).await?;
self.state = WriteTransactionState::InProgress;
Ok(())
}
pub async fn commit(self) -> Result<(), CommitError<F::Error>> {
match self.state {
WriteTransactionState::Canceled => return Err(CommitError::TransactionCanceled),
WriteTransactionState::Created => return Ok(()),
WriteTransactionState::InProgress => {}
}
self.db.state.lock(|s| {
let s = &mut s.borrow_mut();
assert!(s.write_tx == WriteTxState::Created);
s.write_tx = WriteTxState::Committing;
});
poll_fn(|cx| {
self.db.state.lock(|s| {
let s = &mut s.borrow_mut();
if s.read_tx_count != 0 {
s.waker.register(cx.waker());
return Poll::Pending;
}
Poll::Ready(())
})
})
.await;
self.db.inner.lock().await.commit().await?;
Ok(())
}
}
/// Database core guarded by `Database::inner`: file manager, page readers and the open writer.
pub(crate) struct Inner<F: Flash> {
/// File manager; owns the flash device (see `FileManager::new` call in `Inner::new`).
pub(crate) files: FileManager<F>,
/// Page-reader scratch state, one per possible compaction source file.
pub(crate) readers: [PageReader; BRANCHING_FACTOR],
/// State of the open, not-yet-committed write transaction, if any.
write_tx: Option<WriteTransactionInner>,
}
impl<F: Flash> Inner<F> {
/// Builds the core around `flash` with no open write transaction.
fn new(flash: F, random_seed: u32) -> Self {
    // Array-repeat syntax accepts a `const` item even for non-`Copy` element types.
    const EMPTY_READER: PageReader = PageReader::new();
    let files = FileManager::new(flash, random_seed);
    Self {
        files,
        readers: [EMPTY_READER; BRANCHING_FACTOR],
        write_tx: None,
    }
}
/// Formats storage through the file manager.
///
/// # Panics
/// Panics if a write transaction's writer is still pending (`write_tx` is `Some`).
async fn format(&mut self) -> Result<(), FormatError<F::Error>> {
    assert!(self.write_tx.is_none());
    let files = &mut self.files;
    files.format().await
}
/// Remounts storage if the file manager reports its in-memory state as dirty.
///
/// # Panics
/// Panics if a write transaction's writer is still pending (`write_tx` is `Some`).
async fn mount(&mut self) -> Result<(), MountError<F::Error>> {
    assert!(self.write_tx.is_none());
    let reader = &mut self.readers[0];
    self.files.remount_if_dirty(reader).await?;
    Ok(())
}
/// Looks up `key` across all files, newest first, copying the value into `value`.
///
/// New writes land in the last level (highest file IDs), so scanning IDs downward finds
/// the most recent record first. A tombstone stops the search with `KeyNotFound`
/// (propagated from `read_in_file`).
async fn read(&mut self, key: &[u8], value: &mut [u8]) -> Result<usize, ReadError<F::Error>> {
    self.files.remount_if_dirty(&mut self.readers[0]).await?;
    for file_id in (0..FILE_COUNT).rev() {
        trace!("read: checking file {}", file_id);
        match self.read_in_file(file_id as _, key, value).await? {
            Some(len) => return Ok(len),
            None => continue,
        }
    }
    Err(ReadError::KeyNotFound)
}
/// Looks up `key` in a single file, copying its value into `value` on a hit.
///
/// Returns `Ok(Some(len))` when the file has a live record for `key`, `Ok(None)` when it has
/// no record for it (older files should be consulted), and `Err(ReadError::KeyNotFound)`
/// when it holds a tombstone for `key`, which hides any older value.
///
/// Records in a file are in increasing key order (`write` rejects unsorted keys). The lookup
/// first narrows the position with `FileSearcher` (start/seek semantics are defined in
/// `file.rs`; presumably a bisection) and then finishes with a linear scan.
async fn read_in_file(
&mut self,
file_id: FileID,
key: &[u8],
value: &mut [u8],
) -> Result<Option<usize>, ReadError<F::Error>> {
let r = self.files.read(&mut self.readers[0], file_id);
let m = &mut self.files;
let mut s = FileSearcher::new(r);
let mut key_buf = [0u8; MAX_KEY_SIZE];
let mut header = [0; RECORD_HEADER_SIZE];
// Phase 1: search. Each step reads the record at the searcher's position and
// moves left or right depending on how its key compares with `key`.
let mut ok = s.start(m).await?;
while ok {
match s.reader().read(m, &mut header).await {
Ok(()) => {}
Err(PageReadError::Eof) => return Ok(None), Err(e) => return Err(no_eof(e).into()),
};
// Shadows the raw byte buffer only inside this iteration.
let header = RecordHeader::decode(header)?;
let got_key = &mut key_buf[..header.key_len];
s.reader().read(m, got_key).await.map_err(no_eof)?;
let dir = match got_key[..].cmp(key) {
Ordering::Equal => {
// Tombstone: the key was deleted here; don't fall through to older files.
if header.is_delete {
return Err(ReadError::KeyNotFound);
}
if header.value_len > value.len() {
return Err(ReadError::BufferTooSmall);
}
s.reader()
.read(m, &mut value[..header.value_len])
.await
.map_err(no_eof)?;
return Ok(Some(header.value_len));
}
Ordering::Less => SeekDirection::Right,
Ordering::Greater => SeekDirection::Left,
};
ok = s.seek(m, dir).await?;
}
// Phase 2: linear scan from wherever the search stopped.
let r = s.reader();
loop {
match r.read(m, &mut header).await {
Ok(()) => {}
Err(PageReadError::Eof) => return Ok(None), Err(e) => return Err(no_eof(e).into()),
};
let header = RecordHeader::decode(header)?;
let got_key = &mut key_buf[..header.key_len];
r.read(m, got_key).await.map_err(no_eof)?;
match got_key[..].cmp(key) {
Ordering::Equal => {
if header.is_delete {
return Err(ReadError::KeyNotFound);
}
if header.value_len > value.len() {
return Err(ReadError::BufferTooSmall);
}
r.read(m, &mut value[..header.value_len]).await.map_err(no_eof)?;
return Ok(Some(header.value_len));
}
// Smaller key: keep scanning. Larger key: since records are sorted, `key` is absent.
Ordering::Less => {} Ordering::Greater => return Ok(None), }
// Skip this record's value to reach the next header.
r.skip(m, header.value_len).await.map_err(no_eof)?;
}
}
/// Opens a writer for a new file in the last level, unless one is already open.
///
/// If the last level has no free slot, compaction runs until one frees up.
///
/// # Panics
/// Panics if compaction reports that it had nothing to do while no slot is free.
async fn ensure_write_transaction_started(&mut self) -> Result<(), Error<F::Error>> {
    if self.write_tx.is_none() {
        debug!("write_transaction: start");
        let file_id = loop {
            if let Some(free) = self.new_file_in_level(LEVEL_COUNT - 1) {
                break free;
            }
            debug!("write_transaction: no free file, compacting.");
            let did_something = self.compact().await?;
            assert!(did_something);
        };
        debug!("write_transaction: writing file {}", file_id);
        let w = self.files.write(&mut self.readers[0], file_id).await?;
        self.write_tx = Some(WriteTransactionInner { w, last_key: None });
    }
    Ok(())
}
/// Appends one record (header, key, value) to the open write transaction.
///
/// Keys must strictly increase within a transaction. Before writing, compaction runs
/// until the record plus a reserve of `MIN_FREE_PAGE_COUNT` pages fits.
///
/// # Errors
/// `NotSorted` for an out-of-order key, `Full` when compaction cannot free enough space,
/// plus flash/corruption errors.
async fn write(&mut self, key: &[u8], value: &[u8], is_delete: bool) -> Result<(), WriteError<F::Error>> {
    self.ensure_write_transaction_started().await?;

    let tx = self.write_tx.as_mut().unwrap();
    let in_order = match &tx.last_key {
        Some(prev) => key > &prev[..],
        None => true,
    };
    if !in_order {
        return Err(WriteError::NotSorted);
    }
    tx.last_key = Some(Vec::from_slice(key).unwrap());

    let header = RecordHeader {
        is_delete,
        key_len: key.len(),
        value_len: value.len(),
    };
    let need_size = header.record_size() + MIN_FREE_PAGE_COUNT * PAGE_MAX_PAYLOAD_SIZE;

    loop {
        let tx = self.write_tx.as_mut().unwrap();
        let available_size = tx.w.space_left_on_current_page() + self.files.free_pages() * PAGE_MAX_PAYLOAD_SIZE;
        if need_size <= available_size {
            break;
        }
        debug!("free pages less than buffer, compacting.");
        if !self.compact().await? {
            debug!("storage full");
            return Err(WriteError::Full);
        }
    }

    let tx = self.write_tx.as_mut().unwrap();
    let files = &mut self.files;
    tx.w.write(files, &header.encode()).await?;
    tx.w.write(files, key).await?;
    tx.w.write(files, value).await?;
    tx.w.record_end();
    Ok(())
}
/// Commits the open write transaction's file and clears the transaction.
///
/// # Panics
/// Panics if no write transaction is open.
async fn commit(&mut self) -> Result<(), Error<F::Error>> {
    debug!("write_transaction: commit");
    let writer = &mut self.write_tx.as_mut().unwrap().w;
    self.files.commit(writer).await?;
    self.write_tx = None;
    Ok(())
}
/// Discards the open write transaction's uncommitted data and clears the transaction.
///
/// # Errors
/// Propagates errors from discarding the writer. In that case `write_tx` is left in
/// place, so a later `rollback_if_any` can retry.
///
/// # Panics
/// Panics if no write transaction is open.
async fn rollback(&mut self) -> Result<(), Error<F::Error>> {
    debug!("write_transaction: rollback");
    let tx = self.write_tx.as_mut().unwrap();
    // Propagate instead of panicking: this fn already returns a `Result`, and callers
    // (`rollback_if_any` -> `write_inner`) forward it with `?`.
    tx.w.discard(&mut self.files).await?;
    self.write_tx = None;
    Ok(())
}
/// Rolls back the open write transaction, if there is one; otherwise does nothing.
async fn rollback_if_any(&mut self) -> Result<(), Error<F::Error>> {
    if self.write_tx.is_none() {
        return Ok(());
    }
    self.rollback().await
}
/// Maps a (level, slot) pair to its file ID.
///
/// ID 0 is not assigned to any slot: `do_compact` uses it as the destination when
/// compacting level 0.
fn file_id(level: usize, index: usize) -> FileID {
    let linear = level * BRANCHING_FACTOR + index + 1;
    linear as _
}
/// Returns the first empty slot after the last occupied slot in `level`.
///
/// Returns `None` when the level's last slot is occupied. Slots fill from low to high,
/// so a higher slot holds newer data.
fn new_file_in_level(&mut self, level: usize) -> Option<FileID> {
    // Walk backwards over the trailing run of empty slots; its lowest member is the answer.
    (0..BRANCHING_FACTOR)
        .rev()
        .map(|slot| Self::file_id(level, slot))
        .take_while(|&id| self.files.is_empty(id))
        .last()
}
/// Returns `true` when every slot of `level` holds a non-empty file.
fn is_level_full(&self, level: usize) -> bool {
    !(0..BRANCHING_FACTOR).any(|slot| self.files.is_empty(Self::file_id(level, slot)))
}
/// Counts the non-empty files in `level`.
fn level_file_count(&self, level: usize) -> usize {
    (0..BRANCHING_FACTOR)
        .map(|slot| Self::file_id(level, slot))
        .filter(|&id| !self.files.is_empty(id))
        .count()
}
/// Decides the next compaction step.
///
/// Returns `Ok(Some((src, dst)))` with the source files and destination file, or `Ok(None)`
/// when nothing is worth compacting. A compaction left unfinished by `do_compact`, which marks
/// its files with `FILE_FLAG_COMPACT_SRC` / `FILE_FLAG_COMPACT_DEST`, is always resumed before
/// a new one is chosen.
///
/// # Errors
/// `CorruptedError` if the flags or file layout are inconsistent.
fn compact_find_work(&mut self) -> Result<Option<(Vec<FileID, BRANCHING_FACTOR>, FileID)>, CorruptedError> {
// Resume case: exactly one destination-flagged file means a compaction is in progress.
match self.files.files_with_flag(FILE_FLAG_COMPACT_DEST).single() {
Ok(dst) => {
debug!("compact_find_work: continuing in-progress compact.");
let mut src = Vec::new();
for src_file in self.files.files_with_flag(FILE_FLAG_COMPACT_SRC) {
// Sources always come from a deeper level, so their IDs exceed the destination's.
if src_file <= dst {
corrupted!()
}
if src.push(src_file).is_err() {
corrupted!()
}
}
if src.is_empty() {
corrupted!()
}
return Ok(Some((src, dst)));
}
// More than one destination is never valid; none means start fresh below.
Err(SingleError::MultipleElements) => corrupted!(), Err(SingleError::NoElements) => {} }
// File 0 is only used as a level-0 compaction destination; outside a compaction it must be empty.
if !self.files.is_empty(0) {
corrupted!()
}
// Pick the level with the most files among those whose parent level has room
// (level 0 always qualifies: it compacts into file 0).
let lv = (0..LEVEL_COUNT)
.filter(|&lv| lv == 0 || !self.is_level_full(lv - 1))
.max_by_key(|&lv| self.level_file_count(lv))
.unwrap();
let dst = if lv == 0 {
0
} else {
match self.new_file_in_level(lv - 1) {
Some(dst) => dst,
None => corrupted!(),
}
};
// Sources: every non-empty file in the chosen level, in slot order (oldest first).
let mut src = Vec::new();
for i in 0..BRANCHING_FACTOR {
let src_file = Self::file_id(lv, i);
if !self.files.is_empty(src_file) {
src.push(src_file).unwrap();
}
}
// A single level-0 file is already fully compacted.
if src.is_empty() || (src.len() == 1 && lv == 0) {
debug!("compact_find_work: no work.");
return Ok(None);
}
debug!("compact_find_work: starting new compact.");
Ok(Some((src, dst)))
}
/// Merges the sorted files `src` (slot order, oldest first) into `dst`.
///
/// Runs a k-way merge by key. When several sources hold the same key, only the record from
/// the newest source (highest index) is kept. When `dst` is file 0 (compacting level 0, the
/// oldest data), tombstones are dropped because no older record can remain for them to hide.
/// If space runs out mid-way, the merge stops. Consumed source prefixes are then truncated and
/// the files are flagged so `compact_find_work` resumes later. All metadata changes are applied
/// in one file-manager transaction.
///
/// # Errors
/// `Error::Corrupted` if no record could be moved at all (presumably to avoid a caller
/// looping forever), plus flash/corruption errors.
async fn do_compact(&mut self, src: Vec<FileID, BRANCHING_FACTOR>, dst: FileID) -> Result<(), Error<F::Error>> {
debug!("do_compact {:?} -> {}", &src[..], dst);
let topmost = dst == 0;
assert!(!src.is_empty());
// Fast path: moving a single file into an empty destination is just a rename.
if self.files.is_empty(dst) && src.len() == 1 {
debug!("do_compact: short-circuit rename");
let mut tx = self.files.transaction();
tx.rename(src[0], dst).await?;
tx.commit().await?;
return Ok(());
}
let m = &mut self.files;
let mut w = m.write(&mut self.readers[0], dst).await?;
// One reader per source file, each using its own page-reader slot.
let mut r: Vec<FileReader, BRANCHING_FACTOR> = Vec::from_iter(
core::iter::zip(&src, &mut self.readers[..]).map(|(&file_id, reader)| m.read(reader, file_id)),
);
// The header and key of the next unconsumed record in one source file.
struct KeySlot {
// `false` once the source is exhausted.
valid: bool,
header: RecordHeader,
key_buf: [u8; MAX_KEY_SIZE],
}
impl KeySlot {
fn key(&self) -> &[u8] {
&self.key_buf[..self.header.key_len]
}
}
// Reads the next record's header and key into `buf`, marking it invalid at end of file.
// An EOF in the middle of a record is treated as corruption.
async fn read_key_slot<F: Flash>(
m: &mut FileManager<F>,
r: &mut FileReader<'_>,
buf: &mut KeySlot,
) -> Result<(), Error<F::Error>> {
let mut header = [0; RECORD_HEADER_SIZE];
match r.read(m, &mut header).await {
Ok(()) => {}
Err(PageReadError::Flash(e)) => return Err(Error::Flash(e)),
Err(PageReadError::Eof) => {
buf.valid = false;
return Ok(());
}
Err(PageReadError::Corrupted) => corrupted!(),
}
buf.valid = true;
buf.header = RecordHeader::decode(header)?;
match r.read(m, &mut buf.key_buf[..buf.header.key_len]).await {
Ok(()) => Ok(()),
Err(PageReadError::Flash(e)) => Err(Error::Flash(e)),
Err(PageReadError::Eof) => corrupted!(),
Err(PageReadError::Corrupted) => corrupted!(),
}
}
const NEW_SLOT: KeySlot = KeySlot {
valid: false,
header: RecordHeader {
key_len: 0,
value_len: 0,
is_delete: false,
},
key_buf: [0; MAX_KEY_SIZE],
};
let mut k = [NEW_SLOT; BRANCHING_FACTOR];
// Per-source offset just past the last fully consumed record; used for truncation.
let mut trunc = [0; BRANCHING_FACTOR];
for i in 0..src.len() {
read_key_slot(m, &mut r[i], &mut k[i]).await?;
}
let mut progress = false;
// `done` is true when every source was exhausted, false when we stopped for lack of space.
let done = loop {
fn highest_bit(x: u32) -> Option<usize> {
match x {
0 => None,
_ => Some(31 - x.leading_zeros() as usize),
}
}
// `bits` = set of source indices whose current key equals the smallest current key.
let mut bits: u32 = 0;
for i in 0..src.len() {
if !k[i].valid {
continue;
}
match highest_bit(bits) {
None => bits = 1 << i,
Some(j) => match k[j].key().cmp(k[i].key()) {
Ordering::Greater => bits = 1 << i,
Ordering::Equal => bits |= 1 << i,
Ordering::Less => {}
},
}
}
trace!("do_compact: bits {:02x}", bits);
match highest_bit(bits) {
None => break true,
// `i` is the newest source holding the smallest key; its record wins.
Some(i) => {
let need_size = k[i].header.record_size() + MIN_FREE_PAGE_COUNT_COMPACT * PAGE_MAX_PAYLOAD_SIZE;
let available_size = w.space_left_on_current_page() + m.free_pages() * PAGE_MAX_PAYLOAD_SIZE;
trace!(
"do_compact: key_len={} val_len={} space_left={} free_pages={} size={} available_size={}",
k[i].header.key_len,
k[i].header.value_len,
w.space_left_on_current_page(),
m.free_pages(),
need_size,
available_size
);
if need_size > available_size {
break false;
}
#[cfg(feature = "defmt")]
trace!("do_compact: copying key from file {:?}: {:02x}", src[i], &k[i].key());
#[cfg(not(feature = "defmt"))]
trace!("do_compact: copying key from file {:?}: {:02x?}", src[i], &k[i].key());
progress = true;
if topmost && k[i].header.is_delete {
trace!("do_compact: skipping tombstone.");
} else {
w.write(m, &k[i].header.encode()).await?;
w.write(m, k[i].key()).await?;
copy(m, &mut r[i], &mut w, k[i].header.value_len).await?;
w.record_end();
}
// Advance every source that held this key; shadowed (older) values are skipped.
for j in 0..BRANCHING_FACTOR {
if (bits & 1 << j) != 0 {
if j != i {
r[j].skip(m, k[j].header.value_len).await.map_err(no_eof)?;
}
trunc[j] = r[j].offset(m);
read_key_slot(m, &mut r[j], &mut k[j]).await?;
}
}
}
}
};
debug!("do_compact: stopped. done={:?} progress={:?}", done, progress);
if !progress {
return Err(Error::Corrupted);
}
// An unfinished compaction flags its files so it can be resumed.
let (src_flag, dst_flag) = match done {
true => (0, 0),
false => (FILE_FLAG_COMPACT_SRC, FILE_FLAG_COMPACT_DEST),
};
let mut tx = self.files.transaction();
for (i, &file_id) in src.iter().enumerate() {
tx.set_flags(file_id, src_flag).await?;
// Offset of the first unconsumed record (`truncate` semantics are defined in the file manager).
tx.truncate(file_id, trunc[i]).await?;
}
w.commit(&mut tx).await?;
tx.set_flags(dst, dst_flag).await?;
// A finished level-0 compaction moves the temporary file 0 into slot (0, 0).
if topmost && done {
tx.rename(0, Self::file_id(0, 0)).await?;
}
tx.commit().await?;
Ok(())
}
/// Performs one compaction step, returning whether there was any work to do.
async fn compact(&mut self) -> Result<bool, Error<F::Error>> {
    match self.compact_find_work()? {
        Some((src, dst)) => {
            self.do_compact(src, dst).await?;
            Ok(true)
        }
        None => Ok(false),
    }
}
/// Logs all pages, then (if the remount succeeds) the records of every file.
#[cfg(feature = "std")]
pub async fn dump(&mut self) {
    info!("============= BEGIN DUMP");
    self.files.dump_pages(&mut self.readers[0]).await;
    if let Err(e) = self.files.remount_if_dirty(&mut self.readers[0]).await {
        info!("db is dirty, and remount failed: {:?}", e);
        return;
    }
    info!("File dump:");
    for file_id in 0..FILE_COUNT {
        match self.dump_file(file_id as _).await {
            Ok(()) => {}
            Err(e) => {
                info!("failed to dump file: {:?}", e);
            }
        }
    }
}
/// Logs the header of every file (debugging aid; currently unused).
#[cfg(feature = "std")]
#[allow(unused)]
async fn dump_file_headers(&mut self) {
    info!("============= BEGIN DUMP");
    (0..FILE_COUNT).for_each(|file_id| {
        self.files.dump_file_header(file_id as _);
    });
}
/// Logs the raw file metadata for `file_id`, then every record it contains.
#[cfg(feature = "std")]
async fn dump_file(&mut self, file_id: FileID) -> Result<(), Error<F::Error>> {
    self.files.dump_file(&mut self.readers[0], file_id).await?;
    let mut reader = self.files.read(&mut self.readers[0], file_id);
    let mut key_buf = [0u8; MAX_KEY_SIZE];
    let mut value_buf = [0u8; MAX_VALUE_SIZE];
    loop {
        let seq = reader.curr_seq(&mut self.files);
        let mut raw_header = [0; RECORD_HEADER_SIZE];
        let read_result = reader.read(&mut self.files, &mut raw_header).await;
        match read_result {
            Ok(()) => {}
            Err(PageReadError::Eof) => break,
            Err(PageReadError::Flash(e)) => return Err(Error::Flash(e)),
            Err(PageReadError::Corrupted) => corrupted!(),
        }
        let header = RecordHeader::decode(raw_header)?;
        let key = &mut key_buf[..header.key_len];
        reader.read(&mut self.files, key).await.map_err(no_eof)?;
        let value = &mut value_buf[..header.value_len];
        reader.read(&mut self.files, value).await.map_err(no_eof)?;
        debug!(
            "record at seq={:?}: key_len={} key={:02x?} value_len={} value={:02x?}",
            seq,
            key.len(),
            key,
            value.len(),
            value
        );
    }
    Ok(())
}
}
/// In-memory state of an open write transaction.
pub struct WriteTransactionInner {
/// Writer appending records to the transaction's file.
w: FileWriter,
/// Last key written; `Inner::write` rejects any key that is not strictly greater.
last_key: Option<Vec<u8, MAX_KEY_SIZE>>,
}
/// Copies `len` bytes from reader `r` to writer `w` through a 128-byte stack buffer.
///
/// # Errors
/// An EOF before `len` bytes have been read is reported as corruption (via `no_eof`);
/// flash errors are propagated.
async fn copy<F: Flash>(
    m: &mut FileManager<F>,
    r: &mut FileReader<'_>,
    w: &mut FileWriter,
    mut len: usize,
) -> Result<(), Error<F::Error>> {
    let mut chunk = [0; 128];
    while len > 0 {
        let n = len.min(chunk.len());
        r.read(m, &mut chunk[..n]).await.map_err(no_eof)?;
        w.write(m, &chunk[..n]).await?;
        len -= n;
    }
    Ok(())
}
/// Decoded header that precedes each key/value pair in a file.
#[derive(Debug, Copy, Clone)]
pub(crate) struct RecordHeader {
/// Key length in bytes.
pub key_len: usize,
/// Value length in bytes (must be 0 for deletions; see `valid`).
pub value_len: usize,
/// Whether this record is a tombstone marking `key` as deleted.
pub is_delete: bool,
}
impl RecordHeader {
    /// Decodes a packed little-endian header of `RECORD_HEADER_SIZE` bytes.
    ///
    /// Bit layout, LSB first: `key_len` (`KEY_SIZE_BITS` bits), then `value_len`
    /// (`VALUE_SIZE_BITS` bits), then a single `is_delete` bit.
    ///
    /// # Errors
    /// `CorruptedError` if the decoded fields violate the limits checked by `valid`.
    pub fn decode(raw: [u8; RECORD_HEADER_SIZE]) -> Result<Self, CorruptedError> {
        let mut raw2 = [0u8; 4];
        raw2[..RECORD_HEADER_SIZE].copy_from_slice(&raw);
        let raw = u32::from_le_bytes(raw2);
        let key_len = raw & ((1 << KEY_SIZE_BITS) - 1);
        let value_len = (raw >> KEY_SIZE_BITS) & ((1 << VALUE_SIZE_BITS) - 1);
        let is_delete = (raw >> (KEY_SIZE_BITS + VALUE_SIZE_BITS)) & 1 != 0;
        let this = Self {
            is_delete,
            key_len: key_len as usize,
            value_len: value_len as usize,
        };
        if !this.valid() {
            corrupted!();
        }
        Ok(this)
    }

    /// Encodes the header into its packed `RECORD_HEADER_SIZE`-byte form (inverse of `decode`).
    ///
    /// # Panics
    /// Panics if the header is not `valid`.
    pub fn encode(self) -> [u8; RECORD_HEADER_SIZE] {
        assert!(self.valid());
        let res = (self.key_len as u32)
            | ((self.value_len as u32) << KEY_SIZE_BITS)
            | ((self.is_delete as u32) << (KEY_SIZE_BITS + VALUE_SIZE_BITS));
        res.to_le_bytes()[..RECORD_HEADER_SIZE].try_into().unwrap()
    }

    /// Number of bytes this record occupies on flash: encoded header, key and value.
    ///
    /// Uses `RECORD_HEADER_SIZE`, the exact length `encode` produces, rather than a fixed
    /// 4 bytes, so the space checks in `Inner::write` and `do_compact` do not overestimate.
    pub const fn record_size(self) -> usize {
        RECORD_HEADER_SIZE + self.key_len + self.value_len
    }

    /// Checks size limits, and that tombstones carry no value.
    fn valid(self) -> bool {
        self.key_len <= MAX_KEY_SIZE && self.value_len <= MAX_VALUE_SIZE && !(self.is_delete && self.value_len != 0)
    }
}
/// Extension trait for extracting the only element of an iterator.
pub trait Single: Iterator {
    /// Returns the sole element, or a [`SingleError`] when there are none or more than one.
    fn single(self) -> Result<Self::Item, SingleError>;
}

/// Why [`Single::single`] failed.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum SingleError {
    /// The iterator was empty.
    NoElements,
    /// The iterator yielded at least two elements.
    MultipleElements,
}

impl<I: Iterator> Single for I {
    fn single(mut self) -> Result<Self::Item, SingleError> {
        let only = self.next().ok_or(SingleError::NoElements)?;
        if self.next().is_some() {
            return Err(SingleError::MultipleElements);
        }
        Ok(only)
    }
}
#[cfg(test)]
mod tests {
use core::cell::Cell;
use core::future::Future;
use core::pin::Pin;
use core::ptr;
use core::task::{Context, RawWaker, RawWakerVTable, Waker};
use embassy_sync::blocking_mutex::raw::NoopRawMutex;
use tokio::task::yield_now;
use super::*;
use crate::flash::MemFlash;
/// Asserts that `key` reads back as exactly `value` in a fresh read transaction.
async fn check_read(db: &Database<impl Flash, NoopRawMutex>, key: &[u8], value: &[u8]) {
let rtx = db.read_transaction().await;
let mut buf = [0; 1024];
let n = rtx.read(key, &mut buf).await.unwrap();
assert_eq!(&buf[..n], value);
}
/// Asserts that reading `key` yields `ReadError::KeyNotFound`.
async fn check_not_found<F: Flash>(db: &Database<F, NoopRawMutex>, key: &[u8])
where
F::Error: PartialEq,
{
let rtx = db.read_transaction().await;
assert_eq!(rtx.read(key, &mut []).await, Err(ReadError::KeyNotFound));
}
/// Runs compaction steps until none is left; returns whether any step ran.
async fn compact(db: &Database<impl Flash, NoopRawMutex>) -> bool {
let mut work = false;
while db.inner.lock().await.compact().await.unwrap() {
work = true
}
work
}
/// Overwrites and additions across several committed transactions read back correctly.
#[test_log::test(tokio::test)]
async fn test() {
let mut f = MemFlash::new();
let db = Database::new(&mut f, Config::default());
db.format().await.unwrap();
let mut wtx = db.write_transaction().await;
wtx.write(b"bar", b"4321").await.unwrap();
wtx.write(b"foo", b"1234").await.unwrap();
wtx.commit().await.unwrap();
check_read(&db, b"foo", b"1234").await;
check_read(&db, b"bar", b"4321").await;
check_not_found(&db, b"baz").await;
let mut wtx = db.write_transaction().await;
wtx.write(b"bar", b"8765").await.unwrap();
wtx.write(b"baz", b"4242").await.unwrap();
wtx.write(b"foo", b"5678").await.unwrap();
wtx.commit().await.unwrap();
check_read(&db, b"foo", b"5678").await;
check_read(&db, b"bar", b"8765").await;
check_read(&db, b"baz", b"4242").await;
let mut wtx = db.write_transaction().await;
wtx.write(b"lol", b"9999").await.unwrap();
wtx.commit().await.unwrap();
check_read(&db, b"foo", b"5678").await;
check_read(&db, b"bar", b"8765").await;
check_read(&db, b"baz", b"4242").await;
check_read(&db, b"lol", b"9999").await;
}
/// The empty key is a valid key and survives compaction.
#[test_log::test(tokio::test)]
async fn test_empty_key() {
let mut f = MemFlash::new();
let db = Database::new(&mut f, Config::default());
db.format().await.unwrap();
let mut wtx = db.write_transaction().await;
wtx.write(b"", b"aaaa").await.unwrap();
wtx.write(b"foo", b"4321").await.unwrap();
wtx.commit().await.unwrap();
let mut wtx = db.write_transaction().await;
wtx.write(b"", b"bbbb").await.unwrap();
wtx.write(b"foo", b"1234").await.unwrap();
wtx.commit().await.unwrap();
compact(&db).await;
check_read(&db, b"", b"bbbb").await;
check_read(&db, b"foo", b"1234").await;
check_not_found(&db, b"baz").await;
}
/// Empty values are stored as real values (not deletions) and survive compaction.
#[test_log::test(tokio::test)]
async fn test_empty_value() {
let mut f = MemFlash::new();
let db = Database::new(&mut f, Config::default());
db.format().await.unwrap();
let mut wtx = db.write_transaction().await;
wtx.write(b"", b"aaaa").await.unwrap();
wtx.write(b"bar", b"barbar").await.unwrap();
wtx.write(b"foo", b"").await.unwrap();
wtx.commit().await.unwrap();
let mut wtx = db.write_transaction().await;
wtx.write(b"", b"").await.unwrap();
wtx.write(b"baz", b"").await.unwrap();
wtx.commit().await.unwrap();
compact(&db).await;
check_read(&db, b"", b"").await;
check_read(&db, b"foo", b"").await;
check_read(&db, b"bar", b"barbar").await;
check_read(&db, b"baz", b"").await;
check_not_found(&db, b"lol").await;
}
/// A tombstone hides an older value, both before and after compaction.
#[test_log::test(tokio::test)]
async fn test_delete() {
let mut f = MemFlash::new();
let db = Database::new(&mut f, Config::default());
db.format().await.unwrap();
let mut wtx = db.write_transaction().await;
wtx.write(b"", b"").await.unwrap();
wtx.commit().await.unwrap();
check_read(&db, b"", b"").await;
let mut wtx = db.write_transaction().await;
wtx.delete(b"").await.unwrap();
wtx.commit().await.unwrap();
check_not_found(&db, b"").await;
compact(&db).await;
check_not_found(&db, b"").await;
}
/// Uncommitted writes are invisible to read transactions until commit.
#[test_log::test(tokio::test)]
async fn test_transaction() {
let mut f = MemFlash::new();
let db = Database::new(&mut f, Config::default());
db.format().await.unwrap();
check_not_found(&db, b"foo").await;
check_not_found(&db, b"bar").await;
let mut wtx = db.write_transaction().await;
wtx.write(b"bar", b"1234").await.unwrap();
check_not_found(&db, b"foo").await;
check_not_found(&db, b"bar").await;
wtx.write(b"foo", b"4321").await.unwrap();
check_not_found(&db, b"foo").await;
check_not_found(&db, b"bar").await;
wtx.commit().await.unwrap();
check_read(&db, b"foo", b"4321").await;
check_read(&db, b"bar", b"1234").await;
}
/// Dropping a write transaction discards its writes; the next transaction rolls them back.
#[test_log::test(tokio::test)]
async fn test_transaction_drop() {
let mut f = MemFlash::new();
let db = Database::new(&mut f, Config::default());
db.format().await.unwrap();
let mut wtx = db.write_transaction().await;
wtx.write(b"foo", b"4321").await.unwrap();
drop(wtx);
check_not_found(&db, b"foo").await;
let mut wtx = db.write_transaction().await;
wtx.write(b"bar", b"4321").await.unwrap();
wtx.commit().await.unwrap();
check_not_found(&db, b"foo").await;
check_read(&db, b"bar", b"4321").await;
}
/// A commit waits for an open read transaction to be dropped before completing.
///
/// Futures are polled by hand with a no-op waker so the exact interleaving is fixed.
#[test_log::test(tokio::test)]
async fn test_transaction_locking() {
let mut f = MemFlash::new();
let db = Database::<_, NoopRawMutex>::new(&mut f, Config::default());
db.format().await.unwrap();
static VTABLE: RawWakerVTable =
RawWakerVTable::new(|_| RawWaker::new(ptr::null(), &VTABLE), |_| {}, |_| {}, |_| {});
let raw_waker = RawWaker::new(ptr::null(), &VTABLE);
let waker = unsafe { Waker::from_raw(raw_waker) };
let cx = &mut Context::from_waker(&waker);
let read_state = Cell::new(0);
let mut read_fut = async {
read_state.set(1);
let rtx = db.read_transaction().await;
read_state.set(2);
yield_now().await;
let mut buf = [0; 128];
let _ = rtx.read(b"foo", &mut buf).await;
read_state.set(3);
drop(rtx);
read_state.set(4);
};
let write_state = Cell::new(0);
let mut write_fut = async {
write_state.set(1);
let mut wtx = db.write_transaction().await;
write_state.set(2);
wtx.write(b"foo", b"lol").await.unwrap();
write_state.set(3);
wtx.commit().await.unwrap();
write_state.set(4);
};
let mut read_fut = unsafe { Pin::new_unchecked(&mut read_fut) };
let mut write_fut = unsafe { Pin::new_unchecked(&mut write_fut) };
// Reader opens its transaction, then parks at `yield_now`.
assert_eq!(read_fut.as_mut().poll(cx), Poll::Pending);
assert_eq!(read_state.get(), 2);
// Writer writes, then its commit blocks on the open reader.
assert_eq!(write_fut.as_mut().poll(cx), Poll::Pending);
assert_eq!(write_state.get(), 3);
assert_eq!(write_fut.as_mut().poll(cx), Poll::Pending);
assert_eq!(write_state.get(), 3);
// The existing reader can still finish while the commit is pending.
assert_eq!(read_fut.as_mut().poll(cx), Poll::Ready(()));
assert_eq!(read_state.get(), 4);
// With no readers left, the commit completes.
assert_eq!(write_fut.as_mut().poll(cx), Poll::Ready(()));
assert_eq!(write_state.get(), 4);
}
/// While a commit is pending, new read transactions queue behind it.
#[test_log::test(tokio::test)]
async fn test_transaction_locking_queue() {
let mut f = MemFlash::new();
let db = Database::<_, NoopRawMutex>::new(&mut f, Config::default());
db.format().await.unwrap();
static VTABLE: RawWakerVTable =
RawWakerVTable::new(|_| RawWaker::new(ptr::null(), &VTABLE), |_| {}, |_| {}, |_| {});
let raw_waker = RawWaker::new(ptr::null(), &VTABLE);
let waker = unsafe { Waker::from_raw(raw_waker) };
let cx = &mut Context::from_waker(&waker);
let read_state = Cell::new(0);
let mut read_fut = async {
read_state.set(1);
let rtx = db.read_transaction().await;
read_state.set(2);
yield_now().await;
let mut buf = [0; 128];
let _ = rtx.read(b"foo", &mut buf).await;
read_state.set(3);
drop(rtx);
read_state.set(4);
};
let read2_state = Cell::new(0);
let mut read2_fut = async {
read2_state.set(1);
let rtx = db.read_transaction().await;
read2_state.set(2);
let mut buf = [0; 128];
let _ = rtx.read(b"foo", &mut buf).await;
read2_state.set(3);
drop(rtx);
read2_state.set(4);
};
let write_state = Cell::new(0);
let mut write_fut = async {
write_state.set(1);
let mut wtx = db.write_transaction().await;
write_state.set(2);
wtx.write(b"foo", b"lol").await.unwrap();
write_state.set(3);
wtx.commit().await.unwrap();
write_state.set(4);
};
let mut read_fut = unsafe { Pin::new_unchecked(&mut read_fut) };
let mut read2_fut = unsafe { Pin::new_unchecked(&mut read2_fut) };
let mut write_fut = unsafe { Pin::new_unchecked(&mut write_fut) };
assert_eq!(read_fut.as_mut().poll(cx), Poll::Pending);
assert_eq!(read_state.get(), 2);
// Commit is pending (state Committing), so the second reader cannot start.
assert_eq!(write_fut.as_mut().poll(cx), Poll::Pending);
assert_eq!(write_state.get(), 3);
assert_eq!(read2_fut.as_mut().poll(cx), Poll::Pending);
assert_eq!(read2_state.get(), 1);
assert_eq!(write_fut.as_mut().poll(cx), Poll::Pending);
assert_eq!(write_state.get(), 3);
assert_eq!(read2_fut.as_mut().poll(cx), Poll::Pending);
assert_eq!(read2_state.get(), 1);
assert_eq!(read_fut.as_mut().poll(cx), Poll::Ready(()));
assert_eq!(read_state.get(), 4);
// The second reader still waits until the commit has finished.
assert_eq!(read2_fut.as_mut().poll(cx), Poll::Pending);
assert_eq!(read2_state.get(), 1);
assert_eq!(write_fut.as_mut().poll(cx), Poll::Ready(()));
assert_eq!(write_state.get(), 4);
assert_eq!(read2_fut.as_mut().poll(cx), Poll::Ready(()));
assert_eq!(read2_state.get(), 4);
}
/// Rolling back a dropped transaction returns its pages to the free pool.
#[test_log::test(tokio::test)]
async fn test_free_pages_on_transaction_drop() {
let mut f = MemFlash::new();
let db = Database::<_, NoopRawMutex>::new(&mut f, Config::default());
db.format().await.unwrap();
let prev_free = db.inner.lock().await.files.free_pages();
let mut wtx = db.write_transaction().await;
wtx.write(b"foo", b"4321").await.unwrap();
drop(wtx);
db.inner.lock().await.rollback().await.unwrap();
let now_free = db.inner.lock().await.files.free_pages();
assert_eq!(prev_free, now_free);
}
/// A destination buffer shorter than the value yields `BufferTooSmall`.
#[test_log::test(tokio::test)]
async fn test_buf_too_small() {
let mut f = MemFlash::new();
let db = Database::<_, NoopRawMutex>::new(&mut f, Config::default());
db.format().await.unwrap();
let mut wtx = db.write_transaction().await;
wtx.write(b"foo", b"1234").await.unwrap();
wtx.commit().await.unwrap();
let rtx = db.read_transaction().await;
let mut buf = [0u8; 1];
let r = rtx.read(b"foo", &mut buf).await;
assert!(matches!(r, Err(ReadError::BufferTooSmall)));
}
/// Reading from never-formatted flash reports corruption.
#[test_log::test(tokio::test)]
async fn test_unformatted_read() {
let mut f = MemFlash::new();
let db = Database::<_, NoopRawMutex>::new(&mut f, Config::default());
let rtx = db.read_transaction().await;
let mut buf = [0u8; 1];
let r = rtx.read(b"foo", &mut buf).await;
assert!(matches!(r, Err(ReadError::Corrupted)));
}
/// Writing to never-formatted flash reports corruption.
#[test_log::test(tokio::test)]
async fn test_unformatted_write() {
let mut f = MemFlash::new();
let db = Database::<_, NoopRawMutex>::new(&mut f, Config::default());
let mut wtx = db.write_transaction().await;
assert_eq!(wtx.write(b"bar", b"4321").await, Err(WriteError::Corrupted));
}
/// Committed data persists across fresh `Database` instances over the same flash.
#[test_log::test(tokio::test)]
async fn test_remount() {
let mut f = MemFlash::new();
{
let db = Database::<_, NoopRawMutex>::new(&mut f, Config::default());
db.format().await.unwrap();
let mut wtx = db.write_transaction().await;
wtx.write(b"bar", b"4321").await.unwrap();
wtx.write(b"foo", b"1234").await.unwrap();
wtx.commit().await.unwrap();
}
{
let db = Database::new(&mut f, Config::default());
check_read(&db, b"foo", b"1234").await;
check_read(&db, b"bar", b"4321").await;
check_not_found(&db, b"baz").await;
let mut wtx = db.write_transaction().await;
wtx.write(b"bar", b"8765").await.unwrap();
wtx.write(b"baz", b"4242").await.unwrap();
wtx.write(b"foo", b"5678").await.unwrap();
wtx.commit().await.unwrap();
}
{
let db = Database::new(&mut f, Config::default());
check_read(&db, b"foo", b"5678").await;
check_read(&db, b"bar", b"8765").await;
check_read(&db, b"baz", b"4242").await;
let mut wtx = db.write_transaction().await;
wtx.write(b"lol", b"9999").await.unwrap();
wtx.commit().await.unwrap();
}
{
let db = Database::new(&mut f, Config::default());
check_read(&db, b"foo", b"5678").await;
check_read(&db, b"bar", b"8765").await;
check_read(&db, b"baz", b"4242").await;
check_read(&db, b"lol", b"9999").await;
}
}
/// Compaction keeps the newest value for each key.
#[test_log::test(tokio::test)]
async fn test_compact() {
let mut f = MemFlash::new();
let db = Database::<_, NoopRawMutex>::new(&mut f, Config::default());
db.format().await.unwrap();
let mut wtx = db.write_transaction().await;
wtx.write(b"foo", b"4321").await.unwrap();
wtx.commit().await.unwrap();
compact(&db).await;
let mut wtx = db.write_transaction().await;
wtx.write(b"bar", b"6666").await.unwrap();
wtx.write(b"foo", b"5555").await.unwrap();
wtx.commit().await.unwrap();
compact(&db).await;
check_read(&db, b"foo", b"5555").await;
check_read(&db, b"bar", b"6666").await;
}
/// Full compaction drops tombstones, leaving every file empty.
#[test_log::test(tokio::test)]
async fn test_compact_removes_tombstones() {
let mut f = MemFlash::new();
let db = Database::<_, NoopRawMutex>::new(&mut f, Config::default());
db.format().await.unwrap();
let mut wtx = db.write_transaction().await;
wtx.write(b"foo", b"4321").await.unwrap();
wtx.commit().await.unwrap();
compact(&db).await;
let mut wtx = db.write_transaction().await;
wtx.delete(b"foo").await.unwrap();
wtx.commit().await.unwrap();
compact(&db).await;
let dbi = db.inner.lock().await;
assert!((0..FILE_COUNT).all(|i| dbi.files.is_empty(i as _)));
}
/// Writing keys out of increasing order is rejected with `NotSorted`.
#[test_log::test(tokio::test)]
async fn test_write_not_sorted() {
let mut f = MemFlash::new();
let db = Database::<_, NoopRawMutex>::new(&mut f, Config::default());
db.format().await.unwrap();
let mut wtx = db.write_transaction().await;
wtx.write(b"foo", b"4321").await.unwrap();
assert_eq!(wtx.write(b"bar", b"4321").await, Err(WriteError::NotSorted));
}
}