use std::io::{self, SeekFrom};
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use memmap2::Mmap;
use tokio::fs::File;
use tokio::io::{AsyncRead, AsyncSeek, BufReader, ReadBuf};
/// Byte source for block reads, backed either by a buffered tokio file or by
/// a cursor over a shared memory map. Both variants are driven through the
/// `AsyncRead` / `AsyncSeek` impls in this file, which forward to the inner
/// value.
pub(crate) enum BlockSource {
    /// Reads go through a tokio `BufReader` wrapping an open `File`.
    Buffered(BufReader<File>),
    /// Reads are copied directly out of a memory-mapped region.
    Mapped(MmapCursor),
}
impl BlockSource {
    /// Builds a source that reads `file` through a tokio [`BufReader`].
    pub(crate) fn buffered(file: File) -> Self {
        let reader = BufReader::new(file);
        Self::Buffered(reader)
    }

    /// Builds a source that serves reads from `mmap`, starting at offset 0.
    pub(crate) fn mapped(mmap: Arc<Mmap>) -> Self {
        let cursor = MmapCursor::new(mmap);
        Self::Mapped(cursor)
    }

    /// Test helper: reports whether this source is the memory-mapped variant.
    #[cfg(test)]
    pub(crate) fn is_mmap(&self) -> bool {
        match self {
            Self::Mapped(_) => true,
            Self::Buffered(_) => false,
        }
    }
}
impl AsyncRead for BlockSource {
    /// Forwards the read unchanged to whichever backend this source wraps.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        // Both backends are `Unpin`, so unwrapping the pin and re-pinning the
        // inner value is sound.
        match Pin::into_inner(self) {
            Self::Mapped(cursor) => Pin::new(cursor).poll_read(cx, buf),
            Self::Buffered(reader) => Pin::new(reader).poll_read(cx, buf),
        }
    }
}
impl AsyncSeek for BlockSource {
    /// Starts a seek on the active backend; completion is reported by
    /// [`AsyncSeek::poll_complete`].
    fn start_seek(self: Pin<&mut Self>, position: SeekFrom) -> io::Result<()> {
        match Pin::into_inner(self) {
            Self::Mapped(cursor) => Pin::new(cursor).start_seek(position),
            Self::Buffered(reader) => Pin::new(reader).start_seek(position),
        }
    }

    /// Polls the active backend for the outcome of the last `start_seek`,
    /// yielding the resulting absolute position.
    fn poll_complete(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
        match Pin::into_inner(self) {
            Self::Mapped(cursor) => Pin::new(cursor).poll_complete(cx),
            Self::Buffered(reader) => Pin::new(reader).poll_complete(cx),
        }
    }
}
/// Cursor over a shared [`Mmap`] that implements `AsyncRead` and `AsyncSeek`.
/// Its read and seek operations always complete immediately (they never
/// return `Poll::Pending`).
pub(crate) struct MmapCursor {
    /// The mapped bytes, held behind an `Arc` so ownership can be shared.
    mmap: Arc<Mmap>,
    /// Absolute read offset in bytes. A seek may set this past `mmap.len()`;
    /// reads from such a position copy nothing.
    pos: u64,
}
impl MmapCursor {
    /// Creates a cursor over `mmap` positioned at the first byte.
    fn new(mmap: Arc<Mmap>) -> Self {
        let pos = 0;
        Self { pos, mmap }
    }
}
impl AsyncRead for MmapCursor {
    /// Copies bytes from the current offset into `buf` and advances past them.
    /// The amount copied is the smaller of the bytes left in the map and the
    /// space left in `buf`.
    ///
    /// Always returns `Poll::Ready(Ok(()))`. When the offset is at or beyond
    /// the end of the map nothing is copied, which readers observe as EOF.
    fn poll_read(
        self: Pin<&mut Self>,
        _cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        let cursor = Pin::into_inner(self);
        let bytes: &[u8] = &cursor.mmap;
        // A prior seek may have parked `pos` past the end; only copy when in range.
        if cursor.pos < bytes.len() as u64 {
            // `pos` is below a `usize` length here, so the cast is lossless.
            let tail = &bytes[cursor.pos as usize..];
            let copied = tail.len().min(buf.remaining());
            buf.put_slice(&tail[..copied]);
            cursor.pos += copied as u64;
        }
        Poll::Ready(Ok(()))
    }
}
impl AsyncSeek for MmapCursor {
    /// Resolves `position` to an absolute offset and stores it.
    ///
    /// Seeking past the end of the map is allowed. An offset that would be
    /// negative, or would overflow `u64`, is rejected with `InvalidInput`, and
    /// the current position is left untouched.
    fn start_seek(self: Pin<&mut Self>, position: SeekFrom) -> io::Result<()> {
        let cursor = Pin::into_inner(self);
        cursor.pos = match position {
            SeekFrom::Start(absolute) => absolute,
            SeekFrom::Current(delta) => offset_from(cursor.pos, delta)?,
            SeekFrom::End(delta) => offset_from(cursor.mmap.len() as u64, delta)?,
        };
        Ok(())
    }

    /// Seeking is synchronous, so this just reports the current offset.
    fn poll_complete(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
        let position = self.pos;
        Poll::Ready(Ok(position))
    }
}
/// Applies a signed `offset` to an unsigned `base` position.
///
/// # Errors
///
/// Returns an [`io::ErrorKind::InvalidInput`] error when:
/// - the sum exceeds `u64::MAX` ("invalid seek to an overflowing position"), or
/// - the result would be below zero ("invalid seek to a negative position").
fn offset_from(base: u64, offset: i64) -> io::Result<u64> {
    // Pick the checked operation and its failure message by the sign of `offset`;
    // `unsigned_abs` keeps `i64::MIN` representable.
    let (resolved, reason) = if offset.is_negative() {
        (
            base.checked_sub(offset.unsigned_abs()),
            "invalid seek to a negative position",
        )
    } else {
        (
            base.checked_add(offset as u64),
            "invalid seek to an overflowing position",
        )
    };
    resolved.ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, reason))
}
#[cfg(test)]
mod tests {
    //! Behavioural tests for `MmapCursor` and `BlockSource`. Where a real file
    //! is involved, the mmap cursor is checked against tokio `File` behaviour.
    use super::*;
    use tokio::io::{AsyncReadExt, AsyncSeekExt};

    /// Builds a cursor over an anonymous read-only map holding `bytes`.
    ///
    /// The map is sized `bytes.len().max(1)`, so an empty `bytes` yields a
    /// one-byte (zeroed) cursor rather than an empty one.
    fn cursor(bytes: &[u8]) -> MmapCursor {
        let mut mmap = memmap2::MmapMut::map_anon(bytes.len().max(1)).unwrap();
        mmap[..bytes.len()].copy_from_slice(bytes);
        let mmap = mmap.make_read_only().unwrap();
        MmapCursor::new(Arc::new(mmap))
    }

    /// Temp-dir path unique to this test process, so concurrent test runs on
    /// the same machine do not race on (or delete) each other's files.
    fn temp_path(stem: &str) -> std::path::PathBuf {
        std::env::temp_dir().join(format!("{}_{}.bin", stem, std::process::id()))
    }

    #[tokio::test]
    async fn reads_sequentially() {
        let mut c = cursor(b"hello world");
        let mut buf = [0u8; 5];
        c.read_exact(&mut buf).await.unwrap();
        assert_eq!(&buf, b"hello");
        assert_eq!(c.stream_position().await.unwrap(), 5);
    }

    #[tokio::test]
    async fn seek_start_current_end() {
        let mut c = cursor(b"0123456789");
        c.seek(SeekFrom::Start(3)).await.unwrap();
        let mut b = [0u8; 2];
        c.read_exact(&mut b).await.unwrap();
        assert_eq!(&b, b"34");
        c.seek(SeekFrom::Current(2)).await.unwrap();
        c.read_exact(&mut b).await.unwrap();
        assert_eq!(&b, b"78");
        let end = c.seek(SeekFrom::End(0)).await.unwrap();
        assert_eq!(end, 10);
    }

    #[tokio::test]
    async fn read_past_eof_is_unexpected_eof() {
        let mut c = cursor(b"abc");
        c.seek(SeekFrom::Start(2)).await.unwrap();
        let mut b = [0u8; 8];
        let err = c.read_exact(&mut b).await.unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::UnexpectedEof);
    }

    #[tokio::test]
    async fn negative_seek_before_start_errors() {
        let mut c = cursor(b"abc");
        let err = c.seek(SeekFrom::Current(-5)).await.unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidInput);
    }

    #[tokio::test]
    async fn block_source_reports_backend() {
        let mmap = memmap2::MmapMut::map_anon(8).unwrap();
        let mmap = mmap.make_read_only().unwrap();
        assert!(BlockSource::mapped(Arc::new(mmap)).is_mmap());

        let path = temp_path("cqlite_blocksource_backend_test");
        tokio::fs::write(&path, b"buffered").await.unwrap();
        let file = tokio::fs::File::open(&path).await.unwrap();
        // Evaluate and drop the source (closing the file) before cleanup, so a
        // failing assertion cannot leave the temp file behind.
        let buffered_is_mmap = BlockSource::buffered(file).is_mmap();
        tokio::fs::remove_file(&path).await.ok();
        assert!(!buffered_is_mmap);
    }

    #[tokio::test]
    async fn positive_seek_overflow_errors() {
        let mut c = cursor(b"abc");
        c.seek(SeekFrom::Start(u64::MAX)).await.unwrap();
        let err = c.seek(SeekFrom::Current(1)).await.unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidInput);

        let mut c2 = cursor(b"abc");
        let landed = c2.seek(SeekFrom::End(i64::MAX)).await.unwrap();
        assert_eq!(landed, 3u64 + i64::MAX as u64);
    }

    #[tokio::test]
    async fn seek_past_eof_preserves_position_like_file() {
        let mut c = cursor(b"abc");
        let landed = c.seek(SeekFrom::Start(10)).await.unwrap();
        assert_eq!(landed, 10);
        let mut b = [0u8; 4];
        let n = c.read(&mut b).await.unwrap();
        assert_eq!(n, 0, "read past EOF yields no bytes");
        assert_eq!(c.stream_position().await.unwrap(), 10);
        c.seek(SeekFrom::Start(1)).await.unwrap();
        let mut one = [0u8; 1];
        c.read_exact(&mut one).await.unwrap();
        assert_eq!(&one, b"b");
    }

    #[tokio::test]
    async fn seek_past_eof_position_matches_real_file() {
        let bytes = b"abc";
        let path = temp_path("cqlite_mmapcursor_eof_parity");
        tokio::fs::write(&path, bytes).await.unwrap();
        let mut file = tokio::fs::File::open(&path).await.unwrap();
        file.seek(SeekFrom::Start(10)).await.unwrap();
        let mut fb = [0u8; 4];
        let file_n = file.read(&mut fb).await.unwrap();
        let file_pos = file.stream_position().await.unwrap();
        tokio::fs::remove_file(&path).await.ok();

        let mut c = cursor(bytes);
        c.seek(SeekFrom::Start(10)).await.unwrap();
        let mut cb = [0u8; 4];
        let cur_n = c.read(&mut cb).await.unwrap();
        let cur_pos = c.stream_position().await.unwrap();
        assert_eq!(cur_n, file_n, "byte count parity");
        assert_eq!(cur_pos, file_pos, "post-read position parity");
    }

    #[tokio::test]
    async fn multipage_read_across_page_boundary() {
        let len = 10_000usize;
        let mut data = vec![0u8; len];
        for (i, b) in data.iter_mut().enumerate() {
            *b = (i % 251) as u8;
        }
        let mut c = cursor(&data);

        // 4090..4106 straddles a 4 KiB boundary.
        c.seek(SeekFrom::Start(4090)).await.unwrap();
        let mut window = [0u8; 16];
        c.read_exact(&mut window).await.unwrap();
        for (k, b) in window.iter().enumerate() {
            assert_eq!(*b, ((4090 + k) % 251) as u8);
        }
        assert_eq!(c.stream_position().await.unwrap(), 4106);

        c.seek(SeekFrom::Start((len - 4) as u64)).await.unwrap();
        let mut tail = [0u8; 4];
        c.read_exact(&mut tail).await.unwrap();
        for (k, b) in tail.iter().enumerate() {
            assert_eq!(*b, ((len - 4 + k) % 251) as u8);
        }
        assert_eq!(c.stream_position().await.unwrap(), len as u64);
    }

    #[tokio::test]
    async fn partial_read_returns_available_bytes() {
        let mut c = cursor(b"abcd");
        c.seek(SeekFrom::Start(2)).await.unwrap();
        let mut b = [0u8; 8];
        let n = c.read(&mut b).await.unwrap();
        assert_eq!(n, 2);
        assert_eq!(&b[..2], b"cd");
    }
}