1use aligned::Aligned;
2use block_device_driver::{slice_to_blocks, slice_to_blocks_mut, BlockDevice};
3use embedded_io_async::{ErrorKind, Read, Seek, SeekFrom, Write};
4
/// Error type produced by [`BufStream`] operations.
///
/// The enum is `#[non_exhaustive]`, so code outside this crate must include a
/// wildcard arm when matching on it.
#[non_exhaustive]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum BufStreamError<T> {
    /// Wraps an underlying error value of type `T`.
    Io(T),
}
11
12impl<T> From<T> for BufStreamError<T> {
13 fn from(t: T) -> Self {
14 BufStreamError::Io(t)
15 }
16}
17
18impl<T: core::fmt::Debug> embedded_io_async::Error for BufStreamError<T> {
19 fn kind(&self) -> ErrorKind {
20 ErrorKind::Other
21 }
22}
23
24pub struct BufStream<T: BlockDevice<SIZE>, const SIZE: usize> {
42 inner: T,
43 buffer: Aligned<T::Align, [u8; SIZE]>,
44 current_block: u32,
45 current_offset: u64,
46 dirty: bool,
47}
48
49impl<T: BlockDevice<SIZE>, const SIZE: usize> BufStream<T, SIZE> {
50 const ALIGN: usize = core::mem::align_of::<Aligned<T::Align, [u8; SIZE]>>();
51 pub fn new(inner: T) -> Self {
53 Self {
54 inner,
55 current_block: u32::MAX,
56 current_offset: 0,
57 buffer: Aligned([0; SIZE]),
58 dirty: false,
59 }
60 }
61
62 pub fn into_inner(self) -> T {
64 self.inner
65 }
66
67 #[inline]
68 fn pointer_block_start_addr(&self) -> u64 {
69 self.pointer_block_start() as u64 * SIZE as u64
70 }
71
72 #[inline]
73 fn pointer_block_start(&self) -> u32 {
74 (self.current_offset / SIZE as u64)
75 .try_into()
76 .expect("Block larger than 2TB")
77 }
78
79 async fn flush(&mut self) -> Result<(), T::Error> {
80 if self.dirty {
82 self.dirty = false;
83 self.inner
85 .write(self.current_block, slice_to_blocks(&self.buffer[..]))
86 .await?;
87 }
88 Ok(())
89 }
90
91 async fn check_cache(&mut self) -> Result<(), T::Error> {
92 let block_start = self.pointer_block_start();
93 if block_start != self.current_block {
94 self.flush().await?;
96 let buf = &mut self.buffer[..];
98 self.inner
99 .read(block_start, slice_to_blocks_mut(buf))
100 .await?;
101 self.current_block = block_start;
102 }
103 Ok(())
104 }
105}
106
107impl<T: BlockDevice<SIZE>, const SIZE: usize> embedded_io_async::ErrorType for BufStream<T, SIZE> {
108 type Error = BufStreamError<T::Error>;
109}
110
111impl<T: BlockDevice<SIZE>, const SIZE: usize> Read for BufStream<T, SIZE> {
112 async fn read(&mut self, mut buf: &mut [u8]) -> Result<usize, Self::Error> {
113 let mut total = 0;
114 let target = buf.len();
115 loop {
116 let bytes_read = if buf.len() % SIZE == 0
117 && buf.as_ptr().cast::<u8>() as usize % Self::ALIGN == 0
118 && self.current_offset % SIZE as u64 == 0
119 {
120 let block = self.pointer_block_start();
122 self.inner.read(block, slice_to_blocks_mut(buf)).await?;
123
124 buf.len()
125 } else {
126 let block_start = self.pointer_block_start_addr();
127 let block_end = block_start + SIZE as u64;
128 trace!(
129 "offset {}, block_start {}, block_end {}",
130 self.current_offset,
131 block_start,
132 block_end
133 );
134
135 self.check_cache().await?;
136
137 let buffer_offset = (self.current_offset - block_start) as usize;
139 let bytes_to_read = buf.len();
140
141 let end = core::cmp::min(buffer_offset + bytes_to_read, SIZE);
142 trace!("buffer_offset {}, end {}", buffer_offset, end);
143 let bytes_read = end - buffer_offset;
144 buf[..bytes_read].copy_from_slice(&self.buffer[buffer_offset..end]);
145 buf = &mut buf[bytes_read..]; bytes_read
148 };
149
150 self.current_offset += bytes_read as u64;
151 total += bytes_read;
152
153 if total == target {
154 return Ok(total);
155 }
156 }
157 }
158}
159
160impl<T: BlockDevice<SIZE>, const SIZE: usize> Write for BufStream<T, SIZE> {
161 async fn write(&mut self, mut buf: &[u8]) -> Result<usize, Self::Error> {
162 let mut total = 0;
163 let target = buf.len();
164 loop {
165 let bytes_written = if buf.len() % SIZE == 0
166 && buf.as_ptr().cast::<u8>() as usize % Self::ALIGN == 0
167 && self.current_offset % SIZE as u64 == 0
168 {
169 let block = self.pointer_block_start();
171 self.inner.write(block, slice_to_blocks(buf)).await?;
172
173 buf.len()
174 } else {
175 let block_start = self.pointer_block_start_addr();
176 let block_end = block_start + SIZE as u64;
177 trace!(
178 "offset {}, block_start {}, block_end {}",
179 self.current_offset,
180 block_start,
181 block_end
182 );
183
184 self.check_cache().await?;
186
187 let buffer_offset = (self.current_offset - block_start) as usize;
189 let bytes_to_write = buf.len();
190
191 let end = core::cmp::min(buffer_offset + bytes_to_write, SIZE);
192 trace!("buffer_offset {}, end {}", buffer_offset, end);
193 let bytes_written = end - buffer_offset;
194 self.buffer[buffer_offset..buffer_offset + bytes_written]
195 .copy_from_slice(&buf[..bytes_written]);
196 buf = &buf[bytes_written..]; self.dirty = true;
201
202 if block_start + end as u64 == block_end {
204 trace!("Flushing sector cache");
205 self.flush().await?;
206 }
207
208 bytes_written
209 };
210
211 self.current_offset += bytes_written as u64;
212 total += bytes_written;
213
214 if total == target {
215 return Ok(total);
216 }
217 }
218 }
219
220 async fn flush(&mut self) -> Result<(), Self::Error> {
221 self.flush().await?;
222 Ok(())
223 }
224}
225
226impl<T: BlockDevice<SIZE>, const SIZE: usize> Seek for BufStream<T, SIZE> {
227 async fn seek(&mut self, pos: SeekFrom) -> Result<u64, Self::Error> {
228 self.current_offset = match pos {
229 SeekFrom::Start(x) => x,
230 SeekFrom::End(x) => (self.inner.size().await? as i64 - x) as u64,
231 SeekFrom::Current(x) => (self.current_offset as i64 + x) as u64,
232 };
233 Ok(self.current_offset)
234 }
235}
236
237#[cfg(test)]
238mod tests {
239 use aligned::A4;
240 use embedded_io_async::ErrorType;
241
242 use super::*;
243
244 struct TestBlockDevice<T: Read + Write + Seek>(T);
245
246 impl<T: Read + Write + Seek> ErrorType for TestBlockDevice<T> {
247 type Error = T::Error;
248 }
249
250 impl<T: Read + Write + Seek> Read for TestBlockDevice<T> {
251 async fn read(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error> {
252 Ok(self.0.read(buf).await?)
253 }
254 }
255
256 impl<T: Read + Write + Seek> Write for TestBlockDevice<T> {
257 async fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
258 Ok(self.0.write(buf).await?)
259 }
260 }
261
262 impl<T: Read + Write + Seek> Seek for TestBlockDevice<T> {
263 async fn seek(&mut self, pos: SeekFrom) -> Result<u64, Self::Error> {
264 Ok(self.0.seek(pos).await?)
265 }
266 }
267
268 impl<T: Read + Write + Seek> BlockDevice<512> for TestBlockDevice<T> {
269 type Error = T::Error;
270 type Align = aligned::A4;
271
272 async fn read(
274 &mut self,
275 block_address: u32,
276 data: &mut [Aligned<Self::Align, [u8; 512]>],
277 ) -> Result<(), Self::Error> {
278 self.0
279 .seek(SeekFrom::Start((block_address * 512).into()))
280 .await?;
281 for b in data {
282 self.0.read(&mut b[..]).await?;
283 }
284 Ok(())
285 }
286
287 async fn write(
289 &mut self,
290 block_address: u32,
291 data: &[Aligned<Self::Align, [u8; 512]>],
292 ) -> Result<(), Self::Error> {
293 self.0
294 .seek(SeekFrom::Start((block_address * 512).into()))
295 .await?;
296 for b in data {
297 self.0.write(&b[..]).await?;
298 }
299 Ok(())
300 }
301
302 async fn size(&mut self) -> Result<u64, Self::Error> {
303 Ok(u64::MAX)
304 }
305 }
306
307 #[tokio::test]
308 async fn block_512_read_test() {
309 let _ = env_logger::builder().is_test(true).try_init();
310 let buf = ("A".repeat(512) + "B".repeat(512).as_str()).into_bytes();
311 let cur = std::io::Cursor::new(buf);
312 let mut block: BufStream<_, 512> = BufStream::new(TestBlockDevice(
313 embedded_io_adapters::tokio_1::FromTokio::new(cur),
314 ));
315
316 let mut buf = vec![0; 128];
318 block.seek(SeekFrom::Start(0)).await.unwrap();
319 block.read_exact(&mut buf[..]).await.unwrap();
320 assert_eq!(buf, "A".repeat(128).into_bytes());
321
322 let mut buf = vec![0; 128];
323 block.seek(SeekFrom::Start(512)).await.unwrap();
324 block.read_exact(&mut buf[..]).await.unwrap();
325 assert_eq!(buf, "B".repeat(128).into_bytes());
326
327 let mut buf = vec![0; 128];
329 block.seek(SeekFrom::Start(512 - 64)).await.unwrap();
330 block.read_exact(&mut buf[..]).await.unwrap();
331 assert_eq!(buf, ("A".repeat(64) + "B".repeat(64).as_str()).into_bytes());
332 }
333
334 #[tokio::test]
335 async fn block_512_read_successive() {
336 let _ = env_logger::builder().is_test(true).try_init();
337 let buf = ("A".repeat(64) + "B".repeat(64).as_str())
338 .repeat(16)
339 .into_bytes();
340 let cur = std::io::Cursor::new(buf);
341 let mut block: BufStream<_, 512> = BufStream::new(TestBlockDevice(
342 embedded_io_adapters::tokio_1::FromTokio::new(cur),
343 ));
344
345 let mut buf = vec![0; 64];
347 block.seek(SeekFrom::Start(0)).await.unwrap();
348 block.read_exact(&mut buf[..]).await.unwrap();
349 assert_eq!(buf, "A".repeat(64).into_bytes());
350
351 let mut buf = vec![0; 64];
352 block.seek(SeekFrom::Start(64)).await.unwrap();
353 block.read_exact(&mut buf[..]).await.unwrap();
354 assert_eq!(buf, "B".repeat(64).into_bytes());
355
356 let mut buf = vec![0; 64];
357 block.seek(SeekFrom::Start(32)).await.unwrap();
358 block.read_exact(&mut buf[..]).await.unwrap();
359 assert_eq!(buf, ("A".repeat(32) + "B".repeat(32).as_str()).into_bytes());
360 }
361
362 #[tokio::test]
363 async fn block_512_write_single_sector() {
364 let _ = env_logger::builder().is_test(true).try_init();
365 let buf = vec![0; 2048];
366 let cur = std::io::Cursor::new(buf);
367 let mut block: BufStream<_, 512> = BufStream::new(TestBlockDevice(
368 embedded_io_adapters::tokio_1::FromTokio::new(cur),
369 ));
370
371 let data_a = "A".repeat(512).into_bytes();
373 block.seek(SeekFrom::Start(0)).await.unwrap();
374 block.write_all(&data_a).await.unwrap();
375 assert_eq!(
376 &block.into_inner().0.into_inner().into_inner()[..512],
377 data_a
378 )
379 }
380
381 #[tokio::test]
382 async fn block_512_write_across_sectors() {
383 let _ = env_logger::builder().is_test(true).try_init();
384 let buf = vec![0; 2048];
385 let cur = std::io::Cursor::new(buf);
386 let mut block: BufStream<_, 512> = BufStream::new(TestBlockDevice(
387 embedded_io_adapters::tokio_1::FromTokio::new(cur),
388 ));
389
390 let data_a = "A".repeat(512).into_bytes();
392 block.seek(SeekFrom::Start(256)).await.unwrap();
393 block.write_all(&data_a).await.unwrap();
394 block.flush().await.unwrap();
395 let buf = block.into_inner().0.into_inner().into_inner();
396 assert_eq!(&buf[..256], [0; 256]);
397 assert_eq!(&buf[256..768], data_a);
398 assert_eq!(&buf[768..1024], [0; 256]);
399 }
400
401 #[tokio::test]
402 async fn aligned_write_block_optimization() {
403 let _ = env_logger::builder().is_test(true).try_init();
404 let buf = vec![0; 2048];
405 let cur = std::io::Cursor::new(buf);
406 let mut block: BufStream<_, 512> = BufStream::new(TestBlockDevice(
407 embedded_io_adapters::tokio_1::FromTokio::new(cur),
408 ));
409
410 let mut aligned_buffer: Aligned<A4, [u8; 512]> = Aligned([0; 512]);
411 let data_a = "A".repeat(512).into_bytes();
412 aligned_buffer[..].copy_from_slice(&data_a[..]);
413 block.seek(SeekFrom::Start(0)).await.unwrap();
414 block.write_all(&aligned_buffer[..]).await.unwrap();
415
416 assert_eq!(&block.buffer[..], [0u8; 512]);
418 assert_eq!(block.current_offset, 512);
420 assert_eq!(
422 &block.into_inner().0.into_inner().into_inner()[..512],
423 &data_a
424 )
425 }
426
427 #[tokio::test]
428 async fn aligned_write_block_optimization_misaligned_block() {
429 let _ = env_logger::builder().is_test(true).try_init();
430 let buf = vec![0; 2048];
431 let cur = std::io::Cursor::new(buf);
432 let mut block: BufStream<_, 512> = BufStream::new(TestBlockDevice(
433 embedded_io_adapters::tokio_1::FromTokio::new(cur),
434 ));
435
436 let mut aligned_buffer: Aligned<A4, [u8; 2048]> = Aligned([0; 2048]);
437 let data_a = "A".repeat(512).into_bytes();
438 aligned_buffer[..512].copy_from_slice(&data_a[..]);
439 block.seek(SeekFrom::Start(3)).await.unwrap();
441 block.write_all(&aligned_buffer[..512]).await.unwrap();
443 block.flush().await.unwrap();
444
445 assert_ne!(&block.buffer[..], [0u8; 512]);
447 assert_eq!(
449 &block.into_inner().0.into_inner().into_inner()[3..515],
450 &data_a
451 )
452 }
453
454 #[tokio::test]
455 async fn aligned_read_block_optimization() {
456 let _ = env_logger::builder().is_test(true).try_init();
457 let buf = "A".repeat(2048).into_bytes();
458 let cur = std::io::Cursor::new(buf);
459 let mut block: BufStream<_, 512> = BufStream::new(TestBlockDevice(
460 embedded_io_adapters::tokio_1::FromTokio::new(cur),
461 ));
462
463 let mut aligned_buffer: Aligned<A4, [u8; 512]> = Aligned([0; 512]);
464 block.seek(SeekFrom::Start(0)).await.unwrap();
465 block.read_exact(&mut aligned_buffer[..]).await.unwrap();
466
467 assert_eq!(&block.buffer[..], [0u8; 512]);
469 assert_eq!(block.current_offset, 512);
471 assert_eq!(
473 &block.into_inner().0.into_inner().into_inner()[..512],
474 &aligned_buffer[..]
475 )
476 }
477
478 #[tokio::test]
479 async fn aligned_read_block_optimization_misaligned() {
480 let _ = env_logger::builder().is_test(true).try_init();
481 let buf = "A".repeat(2048).into_bytes();
482 let cur = std::io::Cursor::new(buf);
483 let mut block: BufStream<_, 512> = BufStream::new(TestBlockDevice(
484 embedded_io_adapters::tokio_1::FromTokio::new(cur),
485 ));
486
487 let mut aligned_buffer: Aligned<A4, [u8; 512]> = Aligned([0; 512]);
488 block.seek(SeekFrom::Start(3)).await.unwrap();
490 block.read_exact(&mut aligned_buffer[..]).await.unwrap();
492
493 assert_ne!(&block.buffer[..], [0u8; 512]);
496
497 assert_eq!(
499 &block.into_inner().0.into_inner().into_inner()[3..512],
500 &aligned_buffer[3..]
501 )
502 }
503
504 #[tokio::test]
505 async fn write_seek_read_write() {
506 let _ = env_logger::builder().is_test(true).try_init();
507 let buf = "A".repeat(2048).into_bytes();
508 let cur = std::io::Cursor::new(buf);
509 let mut block: BufStream<_, 512> = BufStream::new(TestBlockDevice(
510 embedded_io_adapters::tokio_1::FromTokio::new(cur),
511 ));
512
513 block.seek(SeekFrom::Start(524)).await.unwrap();
514 block
515 .write_all(&"B".repeat(512).into_bytes())
516 .await
517 .unwrap();
518 block.flush().await.unwrap();
519
520 block.seek(SeekFrom::Start(0)).await.unwrap();
521 let mut tmp = [0u8; 256];
522 block.read(&mut tmp[..]).await.unwrap();
523
524 assert_eq!(&tmp[..], "A".repeat(256).into_bytes().as_slice());
525
526 block.seek(SeekFrom::Start(524 + 512)).await.unwrap();
527 block
528 .write_all(&"C".repeat(512).into_bytes())
529 .await
530 .unwrap();
531 block.flush().await.unwrap();
532
533 let buf = block.into_inner().0.into_inner().into_inner();
534
535 assert_eq!(
536 buf,
537 ("A".repeat(524) + &"B".repeat(512) + &"C".repeat(512) + &"A".repeat(500)).into_bytes()
538 )
539 }
540}