//! block_device_adapters/buf_stream.rs

1use aligned::Aligned;
2use block_device_driver::{slice_to_blocks, slice_to_blocks_mut, BlockDevice};
3use embedded_io_async::{ErrorKind, Read, Seek, SeekFrom, Write};
4
/// Error type returned by [`BufStream`] operations.
///
/// Wraps the error type `T` of the underlying block device.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
#[non_exhaustive]
pub enum BufStreamError<T> {
    /// An error propagated from the inner block device.
    Io(T),
}
11
12impl<T> From<T> for BufStreamError<T> {
13    fn from(t: T) -> Self {
14        BufStreamError::Io(t)
15    }
16}
17
impl<T: core::fmt::Debug> embedded_io_async::Error for BufStreamError<T> {
    fn kind(&self) -> ErrorKind {
        // Device errors have no embedded-io equivalent, so they are all
        // reported as `Other`; the concrete cause stays in the `Io` payload.
        ErrorKind::Other
    }
}
23
/// A Stream wrapper for accessing a stream in block sized chunks.
///
/// [`BufStream<T, const SIZE: usize>`](BufStream) can be initialized with the following parameters.
///
/// - `T`: The inner block device.
/// - `SIZE`: The size of the block, this dictates the size of the internal buffer.
///
/// The alignment of the internal buffer is taken from the device's associated
/// `Align` type, so it is guaranteed to satisfy the device's requirements.
///
/// If the `buf` provided to either [`Read::read`] or [`Write::write`] meets the following conditions the `buf`
/// will be used directly instead of the intermediate buffer to avoid unnecessary copies:
///
/// - `buf.len()` is a multiple of block size
/// - `buf` has the same alignment as the internal buffer
/// - The byte address of the inner device is aligned to a block size.
///
/// [`BufStream`] implements the [`embedded_io_async`] traits, and implicitly
/// handles the RMW (Read, Modify, Write) cycle for you.
pub struct BufStream<T: BlockDevice<SIZE>, const SIZE: usize> {
    // The underlying block device.
    inner: T,
    // One block of cached data, aligned as the device requires.
    buffer: Aligned<T::Align, [u8; SIZE]>,
    // Block address currently held in `buffer`; `u32::MAX` means "nothing cached".
    current_block: u32,
    // Byte offset of the stream head within the device.
    current_offset: u64,
    // Whether `buffer` holds modifications not yet written back to the device.
    dirty: bool,
}
48
49impl<T: BlockDevice<SIZE>, const SIZE: usize> BufStream<T, SIZE> {
50    const ALIGN: usize = core::mem::align_of::<Aligned<T::Align, [u8; SIZE]>>();
51    /// Create a new [`BufStream`] around a hardware block device.
52    pub fn new(inner: T) -> Self {
53        Self {
54            inner,
55            current_block: u32::MAX,
56            current_offset: 0,
57            buffer: Aligned([0; SIZE]),
58            dirty: false,
59        }
60    }
61
62    /// Returns inner object.
63    pub fn into_inner(self) -> T {
64        self.inner
65    }
66
67    #[inline]
68    fn pointer_block_start_addr(&self) -> u64 {
69        self.pointer_block_start() as u64 * SIZE as u64
70    }
71
72    #[inline]
73    fn pointer_block_start(&self) -> u32 {
74        (self.current_offset / SIZE as u64)
75            .try_into()
76            .expect("Block larger than 2TB")
77    }
78
79    async fn flush(&mut self) -> Result<(), T::Error> {
80        // flush the internal buffer if we have modified the buffer
81        if self.dirty {
82            self.dirty = false;
83            // Note, alignment of internal buffer is guarenteed at compile time so we don't have to check it here
84            self.inner
85                .write(self.current_block, slice_to_blocks(&self.buffer[..]))
86                .await?;
87        }
88        Ok(())
89    }
90
91    async fn check_cache(&mut self) -> Result<(), T::Error> {
92        let block_start = self.pointer_block_start();
93        if block_start != self.current_block {
94            // we may have modified data in old block, flush it to disk
95            self.flush().await?;
96            // We have seeked to a new block, read it
97            let buf = &mut self.buffer[..];
98            self.inner
99                .read(block_start, slice_to_blocks_mut(buf))
100                .await?;
101            self.current_block = block_start;
102        }
103        Ok(())
104    }
105}
106
// All stream operations surface the device's error wrapped in `BufStreamError`.
impl<T: BlockDevice<SIZE>, const SIZE: usize> embedded_io_async::ErrorType for BufStream<T, SIZE> {
    type Error = BufStreamError<T::Error>;
}
110
111impl<T: BlockDevice<SIZE>, const SIZE: usize> Read for BufStream<T, SIZE> {
112    async fn read(&mut self, mut buf: &mut [u8]) -> Result<usize, Self::Error> {
113        let mut total = 0;
114        let target = buf.len();
115        loop {
116            let bytes_read = if buf.len() % SIZE == 0
117                && buf.as_ptr().cast::<u8>() as usize % Self::ALIGN == 0
118                && self.current_offset % SIZE as u64 == 0
119            {
120                // If the provided buffer has a suitable length and alignment _and_ the read head is on a block boundary, use it directly
121                let block = self.pointer_block_start();
122                self.inner.read(block, slice_to_blocks_mut(buf)).await?;
123
124                buf.len()
125            } else {
126                let block_start = self.pointer_block_start_addr();
127                let block_end = block_start + SIZE as u64;
128                trace!(
129                    "offset {}, block_start {}, block_end {}",
130                    self.current_offset,
131                    block_start,
132                    block_end
133                );
134
135                self.check_cache().await?;
136
137                // copy as much as possible, up to the block boundary
138                let buffer_offset = (self.current_offset - block_start) as usize;
139                let bytes_to_read = buf.len();
140
141                let end = core::cmp::min(buffer_offset + bytes_to_read, SIZE);
142                trace!("buffer_offset {}, end {}", buffer_offset, end);
143                let bytes_read = end - buffer_offset;
144                buf[..bytes_read].copy_from_slice(&self.buffer[buffer_offset..end]);
145                buf = &mut buf[bytes_read..]; // move the buffer along
146
147                bytes_read
148            };
149
150            self.current_offset += bytes_read as u64;
151            total += bytes_read;
152
153            if total == target {
154                return Ok(total);
155            }
156        }
157    }
158}
159
160impl<T: BlockDevice<SIZE>, const SIZE: usize> Write for BufStream<T, SIZE> {
161    async fn write(&mut self, mut buf: &[u8]) -> Result<usize, Self::Error> {
162        let mut total = 0;
163        let target = buf.len();
164        loop {
165            let bytes_written = if buf.len() % SIZE == 0
166                && buf.as_ptr().cast::<u8>() as usize % Self::ALIGN == 0
167                && self.current_offset % SIZE as u64 == 0
168            {
169                // If the provided buffer has a suitable length and alignment _and_ the write head is on a block boundary, use it directly
170                let block = self.pointer_block_start();
171                self.inner.write(block, slice_to_blocks(buf)).await?;
172
173                buf.len()
174            } else {
175                let block_start = self.pointer_block_start_addr();
176                let block_end = block_start + SIZE as u64;
177                trace!(
178                    "offset {}, block_start {}, block_end {}",
179                    self.current_offset,
180                    block_start,
181                    block_end
182                );
183
184                // reload the cache if we need to
185                self.check_cache().await?;
186
187                // copy as much as possible, up to the block boundary
188                let buffer_offset = (self.current_offset - block_start) as usize;
189                let bytes_to_write = buf.len();
190
191                let end = core::cmp::min(buffer_offset + bytes_to_write, SIZE);
192                trace!("buffer_offset {}, end {}", buffer_offset, end);
193                let bytes_written = end - buffer_offset;
194                self.buffer[buffer_offset..buffer_offset + bytes_written]
195                    .copy_from_slice(&buf[..bytes_written]);
196                buf = &buf[bytes_written..]; // move the buffer along
197
198                // If we haven't written directly, we will use the cache, which will may need to flush later
199                // so we mark it as dirty
200                self.dirty = true;
201
202                // write out the whole block with the modified data
203                if block_start + end as u64 == block_end {
204                    trace!("Flushing sector cache");
205                    self.flush().await?;
206                }
207
208                bytes_written
209            };
210
211            self.current_offset += bytes_written as u64;
212            total += bytes_written;
213
214            if total == target {
215                return Ok(total);
216            }
217        }
218    }
219
220    async fn flush(&mut self) -> Result<(), Self::Error> {
221        self.flush().await?;
222        Ok(())
223    }
224}
225
226impl<T: BlockDevice<SIZE>, const SIZE: usize> Seek for BufStream<T, SIZE> {
227    async fn seek(&mut self, pos: SeekFrom) -> Result<u64, Self::Error> {
228        self.current_offset = match pos {
229            SeekFrom::Start(x) => x,
230            SeekFrom::End(x) => (self.inner.size().await? as i64 - x) as u64,
231            SeekFrom::Current(x) => (self.current_offset as i64 + x) as u64,
232        };
233        Ok(self.current_offset)
234    }
235}
236
#[cfg(test)]
mod tests {
    use aligned::A4;
    use embedded_io_async::ErrorType;

    use super::*;

    /// Adapts any in-memory `Read + Write + Seek` stream into a 512-byte
    /// [`BlockDevice`] so [`BufStream`] can be exercised without hardware.
    struct TestBlockDevice<T: Read + Write + Seek>(T);

    impl<T: Read + Write + Seek> ErrorType for TestBlockDevice<T> {
        type Error = T::Error;
    }

    impl<T: Read + Write + Seek> Read for TestBlockDevice<T> {
        async fn read(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error> {
            Ok(self.0.read(buf).await?)
        }
    }

    impl<T: Read + Write + Seek> Write for TestBlockDevice<T> {
        async fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
            Ok(self.0.write(buf).await?)
        }
    }

    impl<T: Read + Write + Seek> Seek for TestBlockDevice<T> {
        async fn seek(&mut self, pos: SeekFrom) -> Result<u64, Self::Error> {
            Ok(self.0.seek(pos).await?)
        }
    }

    impl<T: Read + Write + Seek> BlockDevice<512> for TestBlockDevice<T> {
        type Error = T::Error;
        type Align = aligned::A4;

        /// Read one or more blocks at the given block address.
        ///
        /// NOTE(review): assumes each `self.0.read` call fills a whole block in
        /// one go — true for in-memory cursors, not arbitrary streams. Also
        /// `block_address * 512` is u32 arithmetic; fine for test-sized devices.
        async fn read(
            &mut self,
            block_address: u32,
            data: &mut [Aligned<Self::Align, [u8; 512]>],
        ) -> Result<(), Self::Error> {
            self.0
                .seek(SeekFrom::Start((block_address * 512).into()))
                .await?;
            for b in data {
                self.0.read(&mut b[..]).await?;
            }
            Ok(())
        }

        /// Write one or more blocks at the given block address.
        async fn write(
            &mut self,
            block_address: u32,
            data: &[Aligned<Self::Align, [u8; 512]>],
        ) -> Result<(), Self::Error> {
            self.0
                .seek(SeekFrom::Start((block_address * 512).into()))
                .await?;
            for b in data {
                self.0.write(&b[..]).await?;
            }
            Ok(())
        }

        async fn size(&mut self) -> Result<u64, Self::Error> {
            Ok(u64::MAX)
        }
    }

    /// Buffered reads within a sector and straddling a sector boundary.
    #[tokio::test]
    async fn block_512_read_test() {
        let _ = env_logger::builder().is_test(true).try_init();
        let buf = ("A".repeat(512) + "B".repeat(512).as_str()).into_bytes();
        let cur = std::io::Cursor::new(buf);
        let mut block: BufStream<_, 512> = BufStream::new(TestBlockDevice(
            embedded_io_adapters::tokio_1::FromTokio::new(cur),
        ));

        // Test sector aligned access
        let mut buf = vec![0; 128];
        block.seek(SeekFrom::Start(0)).await.unwrap();
        block.read_exact(&mut buf[..]).await.unwrap();
        assert_eq!(buf, "A".repeat(128).into_bytes());

        let mut buf = vec![0; 128];
        block.seek(SeekFrom::Start(512)).await.unwrap();
        block.read_exact(&mut buf[..]).await.unwrap();
        assert_eq!(buf, "B".repeat(128).into_bytes());

        // Read across sectors
        let mut buf = vec![0; 128];
        block.seek(SeekFrom::Start(512 - 64)).await.unwrap();
        block.read_exact(&mut buf[..]).await.unwrap();
        assert_eq!(buf, ("A".repeat(64) + "B".repeat(64).as_str()).into_bytes());
    }

    /// Successive small reads with seeks re-using the same cached sector.
    #[tokio::test]
    async fn block_512_read_successive() {
        let _ = env_logger::builder().is_test(true).try_init();
        let buf = ("A".repeat(64) + "B".repeat(64).as_str())
            .repeat(16)
            .into_bytes();
        let cur = std::io::Cursor::new(buf);
        let mut block: BufStream<_, 512> = BufStream::new(TestBlockDevice(
            embedded_io_adapters::tokio_1::FromTokio::new(cur),
        ));

        // Test sector aligned access
        let mut buf = vec![0; 64];
        block.seek(SeekFrom::Start(0)).await.unwrap();
        block.read_exact(&mut buf[..]).await.unwrap();
        assert_eq!(buf, "A".repeat(64).into_bytes());

        let mut buf = vec![0; 64];
        block.seek(SeekFrom::Start(64)).await.unwrap();
        block.read_exact(&mut buf[..]).await.unwrap();
        assert_eq!(buf, "B".repeat(64).into_bytes());

        let mut buf = vec![0; 64];
        block.seek(SeekFrom::Start(32)).await.unwrap();
        block.read_exact(&mut buf[..]).await.unwrap();
        assert_eq!(buf, ("A".repeat(32) + "B".repeat(32).as_str()).into_bytes());
    }

    /// A full-sector write at a block boundary lands on the device.
    #[tokio::test]
    async fn block_512_write_single_sector() {
        let _ = env_logger::builder().is_test(true).try_init();
        let buf = vec![0; 2048];
        let cur = std::io::Cursor::new(buf);
        let mut block: BufStream<_, 512> = BufStream::new(TestBlockDevice(
            embedded_io_adapters::tokio_1::FromTokio::new(cur),
        ));

        // Test sector aligned access
        let data_a = "A".repeat(512).into_bytes();
        block.seek(SeekFrom::Start(0)).await.unwrap();
        block.write_all(&data_a).await.unwrap();
        assert_eq!(
            &block.into_inner().0.into_inner().into_inner()[..512],
            data_a
        )
    }

    /// A write straddling two sectors performs a correct read-modify-write.
    #[tokio::test]
    async fn block_512_write_across_sectors() {
        let _ = env_logger::builder().is_test(true).try_init();
        let buf = vec![0; 2048];
        let cur = std::io::Cursor::new(buf);
        let mut block: BufStream<_, 512> = BufStream::new(TestBlockDevice(
            embedded_io_adapters::tokio_1::FromTokio::new(cur),
        ));

        // Test sector aligned access
        let data_a = "A".repeat(512).into_bytes();
        block.seek(SeekFrom::Start(256)).await.unwrap();
        block.write_all(&data_a).await.unwrap();
        block.flush().await.unwrap();
        let buf = block.into_inner().0.into_inner().into_inner();
        assert_eq!(&buf[..256], [0; 256]);
        assert_eq!(&buf[256..768], data_a);
        assert_eq!(&buf[768..1024], [0; 256]);
    }

    /// An aligned, block-sized write must bypass the internal cache.
    #[tokio::test]
    async fn aligned_write_block_optimization() {
        let _ = env_logger::builder().is_test(true).try_init();
        let buf = vec![0; 2048];
        let cur = std::io::Cursor::new(buf);
        let mut block: BufStream<_, 512> = BufStream::new(TestBlockDevice(
            embedded_io_adapters::tokio_1::FromTokio::new(cur),
        ));

        let mut aligned_buffer: Aligned<A4, [u8; 512]> = Aligned([0; 512]);
        let data_a = "A".repeat(512).into_bytes();
        aligned_buffer[..].copy_from_slice(&data_a[..]);
        block.seek(SeekFrom::Start(0)).await.unwrap();
        block.write_all(&aligned_buffer[..]).await.unwrap();

        // if we wrote directly, the block buffer will be empty
        assert_eq!(&block.buffer[..], [0u8; 512]);
        // ensure that the current offset is still updated
        assert_eq!(block.current_offset, 512);
        // the write succeeded
        assert_eq!(
            &block.into_inner().0.into_inner().into_inner()[..512],
            &data_a
        )
    }

    /// An aligned buffer written at a misaligned byte address must fall back
    /// to the cached read-modify-write path.
    #[tokio::test]
    async fn aligned_write_block_optimization_misaligned_block() {
        let _ = env_logger::builder().is_test(true).try_init();
        let buf = vec![0; 2048];
        let cur = std::io::Cursor::new(buf);
        let mut block: BufStream<_, 512> = BufStream::new(TestBlockDevice(
            embedded_io_adapters::tokio_1::FromTokio::new(cur),
        ));

        let mut aligned_buffer: Aligned<A4, [u8; 2048]> = Aligned([0; 2048]);
        let data_a = "A".repeat(512).into_bytes();
        aligned_buffer[..512].copy_from_slice(&data_a[..]);
        // seek away from aligned block address
        block.seek(SeekFrom::Start(3)).await.unwrap();
        // attempt write all
        block.write_all(&aligned_buffer[..512]).await.unwrap();
        block.flush().await.unwrap();

        // because the addr was not block aligned, we will have used the cache
        assert_ne!(&block.buffer[..], [0u8; 512]);
        // the write succeeded
        assert_eq!(
            &block.into_inner().0.into_inner().into_inner()[3..515],
            &data_a
        )
    }

    /// An aligned, block-sized read must bypass the internal cache.
    #[tokio::test]
    async fn aligned_read_block_optimization() {
        let _ = env_logger::builder().is_test(true).try_init();
        let buf = "A".repeat(2048).into_bytes();
        let cur = std::io::Cursor::new(buf);
        let mut block: BufStream<_, 512> = BufStream::new(TestBlockDevice(
            embedded_io_adapters::tokio_1::FromTokio::new(cur),
        ));

        let mut aligned_buffer: Aligned<A4, [u8; 512]> = Aligned([0; 512]);
        block.seek(SeekFrom::Start(0)).await.unwrap();
        block.read_exact(&mut aligned_buffer[..]).await.unwrap();

        // if we read directly, the block buffer will be empty
        assert_eq!(&block.buffer[..], [0u8; 512]);
        // ensure that the current offset is still updated
        assert_eq!(block.current_offset, 512);
        // the read succeeded
        assert_eq!(
            &block.into_inner().0.into_inner().into_inner()[..512],
            &aligned_buffer[..]
        )
    }

    /// An aligned buffer read at a misaligned byte address must fall back to
    /// the cached path.
    #[tokio::test]
    async fn aligned_read_block_optimization_misaligned() {
        let _ = env_logger::builder().is_test(true).try_init();
        let buf = "A".repeat(2048).into_bytes();
        let cur = std::io::Cursor::new(buf);
        let mut block: BufStream<_, 512> = BufStream::new(TestBlockDevice(
            embedded_io_adapters::tokio_1::FromTokio::new(cur),
        ));

        let mut aligned_buffer: Aligned<A4, [u8; 512]> = Aligned([0; 512]);
        // seek away from aligned block
        block.seek(SeekFrom::Start(3)).await.unwrap();
        // pass an aligned buffer with correct sizing
        block.read_exact(&mut aligned_buffer[..]).await.unwrap();

        // now, we must seek back and read the entire block
        // meaning our block cache will be written to:
        assert_ne!(&block.buffer[..], [0u8; 512]);

        // the read succeeded
        assert_eq!(
            &block.into_inner().0.into_inner().into_inner()[3..512],
            &aligned_buffer[3..]
        )
    }

    /// Interleaved writes, seeks, and reads must not corrupt neighbouring data.
    #[tokio::test]
    async fn write_seek_read_write() {
        let _ = env_logger::builder().is_test(true).try_init();
        let buf = "A".repeat(2048).into_bytes();
        let cur = std::io::Cursor::new(buf);
        let mut block: BufStream<_, 512> = BufStream::new(TestBlockDevice(
            embedded_io_adapters::tokio_1::FromTokio::new(cur),
        ));

        block.seek(SeekFrom::Start(524)).await.unwrap();
        block
            .write_all(&"B".repeat(512).into_bytes())
            .await
            .unwrap();
        block.flush().await.unwrap();

        block.seek(SeekFrom::Start(0)).await.unwrap();
        let mut tmp = [0u8; 256];
        block.read(&mut tmp[..]).await.unwrap();

        assert_eq!(&tmp[..], "A".repeat(256).into_bytes().as_slice());

        block.seek(SeekFrom::Start(524 + 512)).await.unwrap();
        block
            .write_all(&"C".repeat(512).into_bytes())
            .await
            .unwrap();
        block.flush().await.unwrap();

        let buf = block.into_inner().0.into_inner().into_inner();

        assert_eq!(
            buf,
            ("A".repeat(524) + &"B".repeat(512) + &"C".repeat(512) + &"A".repeat(500)).into_bytes()
        )
    }
}