use block_device_driver::BlockDevice;
use elain::{Align, Alignment};
use embedded_io_async::{ErrorKind, Read, Seek, SeekFrom, Write};
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
#[non_exhaustive]
pub enum BufStreamError<T> {
Io(T),
}
impl<T> From<T> for BufStreamError<T> {
fn from(t: T) -> Self {
BufStreamError::Io(t)
}
}
impl<T: core::fmt::Debug> embedded_io_async::Error for BufStreamError<T> {
fn kind(&self) -> ErrorKind {
ErrorKind::Other
}
}
pub struct BufStream<T: BlockDevice<SIZE>, const SIZE: usize, const ALIGN: usize>
where
Align<ALIGN>: Alignment,
{
inner: T,
buffer: AlignedBuffer<SIZE, ALIGN>,
current_block: u32,
current_offset: u64,
dirty: bool,
}
impl<T: BlockDevice<SIZE>, const SIZE: usize, const ALIGN: usize> BufStream<T, SIZE, ALIGN>
where
Align<ALIGN>: Alignment,
{
pub fn new(inner: T) -> Self {
Self {
inner,
current_block: u32::MAX,
current_offset: 0,
buffer: AlignedBuffer::new(),
dirty: false,
}
}
pub fn into_inner(self) -> T {
self.inner
}
#[inline]
fn pointer_block_start_addr(&self) -> u64 {
self.pointer_block_start() as u64 * SIZE as u64
}
#[inline]
fn pointer_block_start(&self) -> u32 {
(self.current_offset / SIZE as u64)
.try_into()
.expect("Block larger than 2TB")
}
async fn flush(&mut self) -> Result<(), T::Error> {
if self.dirty {
self.dirty = false;
self.inner
.write(self.current_block, slice_to_blocks(&self.buffer[..]))
.await?;
}
Ok(())
}
async fn check_cache(&mut self) -> Result<(), T::Error> {
let block_start = self.pointer_block_start();
if block_start != self.current_block {
self.flush().await?;
let buf = &mut self.buffer[..];
self.inner
.read(block_start, slice_to_blocks_mut(buf))
.await?;
self.current_block = block_start;
}
Ok(())
}
}
impl<T: BlockDevice<SIZE>, const SIZE: usize, const ALIGN: usize> embedded_io_async::ErrorType
for BufStream<T, SIZE, ALIGN>
where
Align<ALIGN>: Alignment,
{
type Error = BufStreamError<T::Error>;
}
impl<T: BlockDevice<SIZE>, const SIZE: usize, const ALIGN: usize> Read for BufStream<T, SIZE, ALIGN>
where
Align<ALIGN>: Alignment,
{
async fn read(&mut self, mut buf: &mut [u8]) -> Result<usize, Self::Error> {
let mut total = 0;
let target = buf.len();
loop {
let bytes_read = if buf.len() % SIZE == 0
&& &buf[0] as *const _ as usize % ALIGN == 0
&& self.current_offset % SIZE as u64 == 0
{
let block = self.pointer_block_start();
self.inner.read(block, slice_to_blocks_mut(buf)).await?;
buf.len()
} else {
let block_start = self.pointer_block_start_addr();
let block_end = block_start + SIZE as u64;
trace!(
"offset {}, block_start {}, block_end {}",
self.current_offset,
block_start,
block_end
);
self.check_cache().await?;
let buffer_offset = (self.current_offset - block_start) as usize;
let bytes_to_read = buf.len();
let end = core::cmp::min(buffer_offset + bytes_to_read, SIZE);
trace!("buffer_offset {}, end {}", buffer_offset, end);
let bytes_read = end - buffer_offset;
buf[..bytes_read].copy_from_slice(&self.buffer[buffer_offset..end]);
buf = &mut buf[bytes_read..]; bytes_read
};
self.current_offset += bytes_read as u64;
total += bytes_read;
if total == target {
return Ok(total);
}
}
}
}
impl<T: BlockDevice<SIZE>, const SIZE: usize, const ALIGN: usize> Write
for BufStream<T, SIZE, ALIGN>
where
Align<ALIGN>: Alignment,
{
async fn write(&mut self, mut buf: &[u8]) -> Result<usize, Self::Error> {
let mut total = 0;
let target = buf.len();
loop {
let bytes_written = if buf.len() % SIZE == 0
&& &buf[0] as *const _ as usize % ALIGN == 0
&& self.current_offset % SIZE as u64 == 0
{
let block = self.pointer_block_start();
self.inner.write(block, slice_to_blocks(buf)).await?;
buf.len()
} else {
let block_start = self.pointer_block_start_addr();
let block_end = block_start + SIZE as u64;
trace!(
"offset {}, block_start {}, block_end {}",
self.current_offset,
block_start,
block_end
);
self.check_cache().await?;
let buffer_offset = (self.current_offset - block_start) as usize;
let bytes_to_write = buf.len();
let end = core::cmp::min(buffer_offset + bytes_to_write, SIZE);
trace!("buffer_offset {}, end {}", buffer_offset, end);
let bytes_written = end - buffer_offset;
self.buffer[buffer_offset..buffer_offset + bytes_written]
.copy_from_slice(&buf[..bytes_written]);
buf = &buf[bytes_written..]; self.dirty = true;
if block_start + end as u64 == block_end {
trace!("Flushing sector cache");
self.flush().await?;
}
bytes_written
};
self.current_offset += bytes_written as u64;
total += bytes_written;
if total == target {
return Ok(total);
}
}
}
async fn flush(&mut self) -> Result<(), Self::Error> {
self.flush().await?;
Ok(())
}
}
fn slice_to_blocks<const SIZE: usize>(slice: &[u8]) -> &[[u8; SIZE]] {
assert!(slice.len() % SIZE == 0);
unsafe { core::slice::from_raw_parts(slice.as_ptr() as *const [u8; SIZE], slice.len() / SIZE) }
}
fn slice_to_blocks_mut<const SIZE: usize>(slice: &mut [u8]) -> &mut [[u8; SIZE]] {
assert!(slice.len() % SIZE == 0);
unsafe {
core::slice::from_raw_parts_mut(slice.as_mut_ptr() as *mut [u8; SIZE], slice.len() / SIZE)
}
}
impl<T: BlockDevice<SIZE>, const SIZE: usize, const ALIGN: usize> Seek for BufStream<T, SIZE, ALIGN>
where
Align<ALIGN>: Alignment,
{
async fn seek(&mut self, pos: SeekFrom) -> Result<u64, Self::Error> {
self.current_offset = match pos {
SeekFrom::Start(x) => x,
SeekFrom::End(x) => (self.inner.size().await? as i64 - x) as u64,
SeekFrom::Current(x) => (self.current_offset as i64 + x) as u64,
};
Ok(self.current_offset)
}
}
#[derive(Clone)]
struct AlignedBuffer<const SIZE: usize, const ALIGN: usize>
where
Align<ALIGN>: Alignment,
{
_align: Align<ALIGN>,
buffer: [u8; SIZE],
}
impl<const SIZE: usize, const ALIGN: usize> AlignedBuffer<SIZE, ALIGN>
where
Align<ALIGN>: Alignment,
{
pub const fn new() -> Self {
Self {
_align: Align::NEW,
buffer: [0; SIZE],
}
}
}
impl<const SIZE: usize, const ALIGN: usize> core::ops::Deref for AlignedBuffer<SIZE, ALIGN>
where
Align<ALIGN>: Alignment,
{
type Target = [u8; SIZE];
fn deref(&self) -> &Self::Target {
&self.buffer
}
}
impl<const SIZE: usize, const ALIGN: usize> core::ops::DerefMut for AlignedBuffer<SIZE, ALIGN>
where
Align<ALIGN>: Alignment,
{
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.buffer
}
}
#[cfg(test)]
mod tests {
use embedded_io_async::ErrorType;
use super::{BufStream, *};
struct TestBlockDevice<T: Read + Write + Seek>(T);
impl<T: Read + Write + Seek> ErrorType for TestBlockDevice<T> {
type Error = T::Error;
}
impl<T: Read + Write + Seek> Read for TestBlockDevice<T> {
async fn read(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error> {
Ok(self.0.read(buf).await?)
}
}
impl<T: Read + Write + Seek> Write for TestBlockDevice<T> {
async fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
Ok(self.0.write(buf).await?)
}
}
impl<T: Read + Write + Seek> Seek for TestBlockDevice<T> {
async fn seek(&mut self, pos: SeekFrom) -> Result<u64, Self::Error> {
Ok(self.0.seek(pos).await?)
}
}
impl<T: Read + Write + Seek> BlockDevice<512> for TestBlockDevice<T> {
type Error = T::Error;
async fn read(
&mut self,
block_address: u32,
data: &mut [[u8; 512]],
) -> Result<(), Self::Error> {
self.0
.seek(SeekFrom::Start((block_address * 512).into()))
.await?;
for b in data {
self.0.read(b).await?;
}
Ok(())
}
async fn write(
&mut self,
block_address: u32,
data: &[[u8; 512]],
) -> Result<(), Self::Error> {
self.0
.seek(SeekFrom::Start((block_address * 512).into()))
.await?;
for b in data {
self.0.write(b).await?;
}
Ok(())
}
async fn size(&mut self) -> Result<u64, Self::Error> {
Ok(u64::MAX)
}
}
#[tokio::test]
async fn block_512_read_test() {
let _ = env_logger::builder().is_test(true).try_init();
let buf = ("A".repeat(512) + "B".repeat(512).as_str()).into_bytes();
let cur = std::io::Cursor::new(buf);
let mut block: BufStream<_, 512, 4> = BufStream::new(TestBlockDevice(
embedded_io_adapters::tokio_1::FromTokio::new(cur),
));
let mut buf = vec![0; 128];
block.seek(SeekFrom::Start(0)).await.unwrap();
block.read_exact(&mut buf[..]).await.unwrap();
assert_eq!(buf, "A".repeat(128).into_bytes());
let mut buf = vec![0; 128];
block.seek(SeekFrom::Start(512)).await.unwrap();
block.read_exact(&mut buf[..]).await.unwrap();
assert_eq!(buf, "B".repeat(128).into_bytes());
let mut buf = vec![0; 128];
block.seek(SeekFrom::Start(512 - 64)).await.unwrap();
block.read_exact(&mut buf[..]).await.unwrap();
assert_eq!(buf, ("A".repeat(64) + "B".repeat(64).as_str()).into_bytes());
}
#[tokio::test]
async fn block_512_read_successive() {
let _ = env_logger::builder().is_test(true).try_init();
let buf = ("A".repeat(64) + "B".repeat(64).as_str())
.repeat(16)
.into_bytes();
let cur = std::io::Cursor::new(buf);
let mut block: BufStream<_, 512, 4> = BufStream::new(TestBlockDevice(
embedded_io_adapters::tokio_1::FromTokio::new(cur),
));
let mut buf = vec![0; 64];
block.seek(SeekFrom::Start(0)).await.unwrap();
block.read_exact(&mut buf[..]).await.unwrap();
assert_eq!(buf, "A".repeat(64).into_bytes());
let mut buf = vec![0; 64];
block.seek(SeekFrom::Start(64)).await.unwrap();
block.read_exact(&mut buf[..]).await.unwrap();
assert_eq!(buf, "B".repeat(64).into_bytes());
let mut buf = vec![0; 64];
block.seek(SeekFrom::Start(32)).await.unwrap();
block.read_exact(&mut buf[..]).await.unwrap();
assert_eq!(buf, ("A".repeat(32) + "B".repeat(32).as_str()).into_bytes());
}
#[tokio::test]
async fn block_512_write_single_sector() {
let _ = env_logger::builder().is_test(true).try_init();
let buf = vec![0; 2048];
let cur = std::io::Cursor::new(buf);
let mut block: BufStream<_, 512, 4> = BufStream::new(TestBlockDevice(
embedded_io_adapters::tokio_1::FromTokio::new(cur),
));
let data_a = "A".repeat(512).into_bytes();
block.seek(SeekFrom::Start(0)).await.unwrap();
block.write_all(&data_a).await.unwrap();
assert_eq!(
&block.into_inner().0.into_inner().into_inner()[..512],
data_a
)
}
#[tokio::test]
async fn block_512_write_across_sectors() {
let _ = env_logger::builder().is_test(true).try_init();
let buf = vec![0; 2048];
let cur = std::io::Cursor::new(buf);
let mut block: BufStream<_, 512, 4> = BufStream::new(TestBlockDevice(
embedded_io_adapters::tokio_1::FromTokio::new(cur),
));
let data_a = "A".repeat(512).into_bytes();
block.seek(SeekFrom::Start(256)).await.unwrap();
block.write_all(&data_a).await.unwrap();
block.flush().await.unwrap();
let buf = block.into_inner().0.into_inner().into_inner();
assert_eq!(&buf[..256], [0; 256]);
assert_eq!(&buf[256..768], data_a);
assert_eq!(&buf[768..1024], [0; 256]);
}
#[tokio::test]
async fn aligned_write_block_optimization() {
let _ = env_logger::builder().is_test(true).try_init();
let buf = vec![0; 2048];
let cur = std::io::Cursor::new(buf);
let mut block: BufStream<_, 512, 4> = BufStream::new(TestBlockDevice(
embedded_io_adapters::tokio_1::FromTokio::new(cur),
));
let mut aligned_buffer: AlignedBuffer<512, 4> = AlignedBuffer::new();
let data_a = "A".repeat(512).into_bytes();
aligned_buffer[..].copy_from_slice(&data_a[..]);
block.seek(SeekFrom::Start(0)).await.unwrap();
block.write_all(&aligned_buffer[..]).await.unwrap();
assert_eq!(&block.buffer[..], [0u8; 512]);
assert_eq!(block.current_offset, 512);
assert_eq!(
&block.into_inner().0.into_inner().into_inner()[..512],
&data_a
)
}
#[tokio::test]
async fn aligned_write_block_optimization_misaligned_block() {
let _ = env_logger::builder().is_test(true).try_init();
let buf = vec![0; 2048];
let cur = std::io::Cursor::new(buf);
let mut block: BufStream<_, 512, 4> = BufStream::new(TestBlockDevice(
embedded_io_adapters::tokio_1::FromTokio::new(cur),
));
let mut aligned_buffer: AlignedBuffer<2048, 4> = AlignedBuffer::new();
let data_a = "A".repeat(512).into_bytes();
aligned_buffer[..512].copy_from_slice(&data_a[..]);
block.seek(SeekFrom::Start(3)).await.unwrap();
block.write_all(&aligned_buffer[..512]).await.unwrap();
block.flush().await.unwrap();
assert_ne!(&block.buffer[..], [0u8; 512]);
assert_eq!(
&block.into_inner().0.into_inner().into_inner()[3..515],
&data_a
)
}
#[tokio::test]
async fn aligned_read_block_optimization() {
let _ = env_logger::builder().is_test(true).try_init();
let buf = "A".repeat(2048).into_bytes();
let cur = std::io::Cursor::new(buf);
let mut block: BufStream<_, 512, 4> = BufStream::new(TestBlockDevice(
embedded_io_adapters::tokio_1::FromTokio::new(cur),
));
let mut aligned_buffer: AlignedBuffer<512, 4> = AlignedBuffer::new();
block.seek(SeekFrom::Start(0)).await.unwrap();
block.read_exact(&mut aligned_buffer[..]).await.unwrap();
assert_eq!(&block.buffer[..], [0u8; 512]);
assert_eq!(block.current_offset, 512);
assert_eq!(
&block.into_inner().0.into_inner().into_inner()[..512],
&aligned_buffer[..]
)
}
#[tokio::test]
async fn aligned_read_block_optimization_misaligned() {
let _ = env_logger::builder().is_test(true).try_init();
let buf = "A".repeat(2048).into_bytes();
let cur = std::io::Cursor::new(buf);
let mut block: BufStream<_, 512, 4> = BufStream::new(TestBlockDevice(
embedded_io_adapters::tokio_1::FromTokio::new(cur),
));
let mut aligned_buffer: AlignedBuffer<512, 4> = AlignedBuffer::new();
block.seek(SeekFrom::Start(3)).await.unwrap();
block.read_exact(&mut aligned_buffer[..]).await.unwrap();
assert_ne!(&block.buffer[..], [0u8; 512]);
assert_eq!(
&block.into_inner().0.into_inner().into_inner()[3..512],
&aligned_buffer[3..]
)
}
#[tokio::test]
async fn write_seek_read_write() {
let _ = env_logger::builder().is_test(true).try_init();
let buf = "A".repeat(2048).into_bytes();
let cur = std::io::Cursor::new(buf);
let mut block: BufStream<_, 512, 4> = BufStream::new(TestBlockDevice(
embedded_io_adapters::tokio_1::FromTokio::new(cur),
));
block.seek(SeekFrom::Start(524)).await.unwrap();
block
.write_all(&"B".repeat(512).into_bytes())
.await
.unwrap();
block.flush().await.unwrap();
block.seek(SeekFrom::Start(0)).await.unwrap();
let mut tmp = [0u8; 256];
block.read(&mut tmp[..]).await.unwrap();
assert_eq!(&tmp[..], "A".repeat(256).into_bytes().as_slice());
block.seek(SeekFrom::Start(524 + 512)).await.unwrap();
block
.write_all(&"C".repeat(512).into_bytes())
.await
.unwrap();
block.flush().await.unwrap();
let buf = block.into_inner().0.into_inner().into_inner();
assert_eq!(
buf,
("A".repeat(524) + &"B".repeat(512) + &"C".repeat(512) + &"A".repeat(500)).into_bytes()
)
}
}