use super::buffer::{ChunkDescriptor, ChunkedBuffer};
use crate::error::{Result, StreamingError};
use async_trait::async_trait;
use bytes::Bytes;
use std::path::{Path, PathBuf};
use tokio::fs::File;
use tokio::io::{AsyncReadExt, AsyncSeekExt, AsyncWriteExt, SeekFrom};
use tracing::{debug, info};
/// Strategy for deciding how large each chunk of a stream should be.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ChunkStrategy {
    /// Every chunk has the same fixed size in bytes.
    FixedSize(usize),
    /// Chunk size follows available memory, capped at `target_memory` and
    /// clamped into `[min_size, max_size]` (see `chunk_size_for_index`).
    Adaptive {
        min_size: usize,
        max_size: usize,
        target_memory: usize,
    },
    /// Line-oriented chunking. Only `max_bytes` is consulted for byte
    /// sizing in this file; `max_lines` is presumably enforced by the
    /// line-splitting reader — TODO confirm where it is used.
    LineBased {
        max_lines: usize,
        max_bytes: usize,
    },
}
impl ChunkStrategy {
pub fn chunk_size_for_index(&self, index: usize, available_memory: usize) -> usize {
match self {
ChunkStrategy::FixedSize(size) => *size,
ChunkStrategy::Adaptive {
min_size,
max_size,
target_memory,
} => {
let target_size = available_memory.min(*target_memory);
target_size.max(*min_size).min(*max_size)
}
ChunkStrategy::LineBased { max_bytes, .. } => *max_bytes,
}
}
}
impl Default for ChunkStrategy {
fn default() -> Self {
ChunkStrategy::FixedSize(1024 * 1024) }
}
/// Abstraction over chunk-granular, random-access I/O backends.
#[async_trait]
pub trait ChunkedIO: Send + Sync {
    /// Reads the bytes described by `descriptor`.
    async fn read_chunk(&mut self, descriptor: &ChunkDescriptor) -> Result<Bytes>;
    /// Writes `data` at the location described by `descriptor`.
    async fn write_chunk(&mut self, descriptor: &ChunkDescriptor, data: Bytes) -> Result<()>;
    /// Total size in bytes of the underlying store.
    async fn total_size(&self) -> Result<u64>;
    /// Flushes any buffered writes to the backing store.
    async fn flush(&mut self) -> Result<()>;
}
/// File-backed [`ChunkedIO`] implementation with lazily-opened, separate
/// read and write handles.
pub struct FileChunkedIO {
    // Path to the backing file.
    path: PathBuf,
    // Opened lazily by `open_read` (or the first `read_chunk`).
    read_file: Option<File>,
    // Opened lazily by `open_write` (or the first `write_chunk`).
    write_file: Option<File>,
    // Chunking strategy. NOTE(review): stored but never read in this
    // file — confirm it is consumed elsewhere.
    strategy: ChunkStrategy,
    // Set by `with_direct_io`; no visible code path reads it yet.
    direct_io: bool,
}
impl FileChunkedIO {
    /// Creates a chunked file I/O handle for `path`.
    ///
    /// No file handle is opened here; opening is deferred to
    /// `open_read` / `open_write` (or the first read/write).
    pub async fn new<P: AsRef<Path>>(path: P, strategy: ChunkStrategy) -> Result<Self> {
        Ok(Self {
            path: path.as_ref().to_path_buf(),
            read_file: None,
            write_file: None,
            strategy,
            direct_io: false,
        })
    }

    /// Opens the file for reading. Idempotent: a second call is a no-op.
    ///
    /// # Errors
    /// Returns `StreamingError::Io` if the file cannot be opened.
    pub async fn open_read(&mut self) -> Result<()> {
        if self.read_file.is_some() {
            return Ok(());
        }
        // Variant constructor passed directly instead of a redundant
        // closure (clippy::redundant_closure).
        let file = File::open(&self.path).await.map_err(StreamingError::Io)?;
        info!("Opened file for reading: {:?}", self.path);
        self.read_file = Some(file);
        Ok(())
    }

    /// Opens the file for writing. Idempotent: a second call is a no-op.
    ///
    /// NOTE: `File::create` truncates an existing file, so any previous
    /// contents are lost the first time this is called for a given path.
    ///
    /// # Errors
    /// Returns `StreamingError::Io` if the file cannot be created.
    pub async fn open_write(&mut self) -> Result<()> {
        if self.write_file.is_some() {
            return Ok(());
        }
        let file = File::create(&self.path).await.map_err(StreamingError::Io)?;
        info!("Opened file for writing: {:?}", self.path);
        self.write_file = Some(file);
        Ok(())
    }

    /// Builder-style toggle for direct (unbuffered) I/O.
    /// Currently only stored; nothing in this file reads `direct_io`.
    pub fn with_direct_io(mut self, enable: bool) -> Self {
        self.direct_io = enable;
        self
    }
}
#[async_trait]
impl ChunkedIO for FileChunkedIO {
    /// Reads exactly `descriptor.length` bytes at `descriptor.offset`,
    /// opening the read handle on first use.
    ///
    /// # Errors
    /// Returns `StreamingError::Io` on seek/read failure — including
    /// `UnexpectedEof` when fewer than `descriptor.length` bytes remain,
    /// i.e. a short final chunk is an error rather than a partial read.
    async fn read_chunk(&mut self, descriptor: &ChunkDescriptor) -> Result<Bytes> {
        if self.read_file.is_none() {
            self.open_read().await?;
        }
        let file = self
            .read_file
            .as_mut()
            .ok_or_else(|| StreamingError::InvalidState("File not open".to_string()))?;
        file.seek(SeekFrom::Start(descriptor.offset))
            .await
            .map_err(StreamingError::Io)?;
        let mut buffer = vec![0u8; descriptor.length];
        let bytes_read = file
            .read_exact(&mut buffer)
            .await
            .map_err(StreamingError::Io)?;
        debug!(
            "Read chunk {} at offset {} ({} bytes)",
            descriptor.index, descriptor.offset, bytes_read
        );
        Ok(Bytes::from(buffer))
    }

    /// Writes `data` at `descriptor.offset`, opening the write handle on
    /// first use. Data may still sit in OS buffers until `flush` is called.
    async fn write_chunk(&mut self, descriptor: &ChunkDescriptor, data: Bytes) -> Result<()> {
        if self.write_file.is_none() {
            self.open_write().await?;
        }
        let file = self
            .write_file
            .as_mut()
            .ok_or_else(|| StreamingError::InvalidState("File not open".to_string()))?;
        file.seek(SeekFrom::Start(descriptor.offset))
            .await
            .map_err(StreamingError::Io)?;
        file.write_all(&data).await.map_err(StreamingError::Io)?;
        debug!(
            "Wrote chunk {} at offset {} ({} bytes)",
            descriptor.index, descriptor.offset, data.len()
        );
        Ok(())
    }

    /// Size of the backing file, taken from filesystem metadata (does not
    /// require either handle to be open).
    async fn total_size(&self) -> Result<u64> {
        let metadata = tokio::fs::metadata(&self.path)
            .await
            .map_err(StreamingError::Io)?;
        Ok(metadata.len())
    }

    /// Flushes buffered writes and syncs data + metadata to disk.
    /// A no-op when no write handle has been opened.
    async fn flush(&mut self) -> Result<()> {
        if let Some(file) = &mut self.write_file {
            file.flush().await.map_err(StreamingError::Io)?;
            // sync_all also flushes file metadata, not just contents.
            file.sync_all().await.map_err(StreamingError::Io)?;
        }
        Ok(())
    }
}
/// In-memory [`ChunkedIO`] backend over a fixed-size byte buffer; useful
/// for tests and small payloads.
pub struct MemoryChunkedIO {
    // Fixed-size backing store, zero-initialized in `new`.
    buffer: Vec<u8>,
    // Chunking strategy. NOTE(review): stored but never read in this
    // file — confirm it is consumed elsewhere.
    strategy: ChunkStrategy,
}
impl MemoryChunkedIO {
    /// Creates an in-memory backing store of `size` zeroed bytes.
    pub fn new(size: usize, strategy: ChunkStrategy) -> Self {
        let buffer = vec![0u8; size];
        Self { buffer, strategy }
    }

    /// Borrows the underlying byte buffer.
    pub fn buffer(&self) -> &[u8] {
        self.buffer.as_slice()
    }
}
#[async_trait]
impl ChunkedIO for MemoryChunkedIO {
    /// Copies `descriptor.length` bytes starting at `descriptor.offset`
    /// out of the in-memory buffer.
    ///
    /// # Errors
    /// Returns `StreamingError::InvalidOperation` if the requested range
    /// falls outside the buffer (or offset + length overflows `usize`).
    async fn read_chunk(&mut self, descriptor: &ChunkDescriptor) -> Result<Bytes> {
        let start = descriptor.offset as usize;
        // checked_add guards against offset + length overflowing usize,
        // which would otherwise panic in debug or wrap in release.
        let end = start
            .checked_add(descriptor.length)
            .filter(|&e| e <= self.buffer.len())
            .ok_or_else(|| {
                StreamingError::InvalidOperation("Chunk exceeds buffer size".to_string())
            })?;
        Ok(Bytes::from(self.buffer[start..end].to_vec()))
    }

    /// Writes `data` into the buffer at the descriptor's range.
    ///
    /// # Errors
    /// Returns `StreamingError::InvalidOperation` if the range falls
    /// outside the buffer, or if `data.len()` differs from
    /// `descriptor.length` — previously that mismatch panicked inside
    /// `copy_from_slice`.
    async fn write_chunk(&mut self, descriptor: &ChunkDescriptor, data: Bytes) -> Result<()> {
        let start = descriptor.offset as usize;
        let end = start
            .checked_add(descriptor.length)
            .filter(|&e| e <= self.buffer.len())
            .ok_or_else(|| {
                StreamingError::InvalidOperation("Chunk exceeds buffer size".to_string())
            })?;
        // Validate up front instead of letting copy_from_slice panic on a
        // length mismatch.
        if data.len() != descriptor.length {
            return Err(StreamingError::InvalidOperation(
                "Data length does not match chunk descriptor".to_string(),
            ));
        }
        self.buffer[start..end].copy_from_slice(&data);
        Ok(())
    }

    /// Size of the backing buffer in bytes.
    async fn total_size(&self) -> Result<u64> {
        Ok(self.buffer.len() as u64)
    }

    /// Nothing to flush for an in-memory store.
    async fn flush(&mut self) -> Result<()> {
        Ok(())
    }
}
/// Decorator that places a chunk cache (with optional prefetch) in front
/// of another [`ChunkedIO`] backend.
pub struct CachedChunkedIO<T: ChunkedIO> {
    // Wrapped backend; all cache misses and every write go here.
    inner: T,
    // Holds chunks loaded ahead of time by `prefetch`.
    cache: ChunkedBuffer,
    // Maximum number of chunks fetched per `prefetch` call.
    prefetch_count: usize,
}
impl<T: ChunkedIO> CachedChunkedIO<T> {
    /// Wraps `inner` with a chunk cache holding up to `cache_size` chunks
    /// and a read-ahead of `prefetch_count` chunks.
    ///
    /// NOTE(review): the cache is built with a hard-coded 1 MiB chunk
    /// size, which may disagree with `inner`'s chunking — confirm.
    pub fn new(inner: T, cache_size: usize, prefetch_count: usize) -> Self {
        let cache = ChunkedBuffer::new(1024 * 1024, cache_size);
        Self {
            inner,
            cache,
            prefetch_count,
        }
    }

    /// Reads up to `prefetch_count` chunks starting at `start_index` from
    /// the inner backend and pushes them into the cache, stopping early at
    /// the end of the data.
    pub async fn prefetch(&mut self, start_index: usize, total_size: u64) -> Result<()> {
        let total_chunks = self.cache.calculate_chunks(total_size);
        // Bound the range up front instead of breaking mid-loop.
        let end_index = (start_index + self.prefetch_count).min(total_chunks);
        for index in start_index..end_index {
            let descriptor = self.cache.descriptor_for_index(index, total_size);
            let data = self.inner.read_chunk(&descriptor).await?;
            self.cache.push(descriptor, data).await?;
        }
        Ok(())
    }
}
#[async_trait]
impl<T: ChunkedIO> ChunkedIO for CachedChunkedIO<T> {
    /// Serves a chunk from the cache when possible, otherwise reads
    /// through to the inner backend.
    ///
    /// Bug fix: the previous code returned whatever `pop()` yielded
    /// without comparing it to `descriptor`, so a cache "hit" could hand
    /// back the wrong chunk's bytes. We now verify the chunk index, and a
    /// mismatched chunk is pushed back (so it is not silently dropped)
    /// before falling through to `inner`.
    async fn read_chunk(&mut self, descriptor: &ChunkDescriptor) -> Result<Bytes> {
        // Assumes pop() yields the same (descriptor, data) pair that
        // push() stored — TODO confirm against ChunkedBuffer.
        if let Some((cached, data)) = self.cache.pop().await? {
            if cached.index == descriptor.index {
                return Ok(data);
            }
            self.cache.push(cached, data).await?;
        }
        self.inner.read_chunk(descriptor).await
    }

    /// Writes bypass the cache entirely; no cache invalidation is done,
    /// so a prefetched copy of this chunk can go stale.
    async fn write_chunk(&mut self, descriptor: &ChunkDescriptor, data: Bytes) -> Result<()> {
        self.inner.write_chunk(descriptor, data).await
    }

    /// Delegates to the inner backend.
    async fn total_size(&self) -> Result<u64> {
        self.inner.total_size().await
    }

    /// Delegates to the inner backend.
    async fn flush(&mut self) -> Result<()> {
        self.inner.flush().await
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Round-trips one chunk through the in-memory backend.
    #[tokio::test]
    async fn test_memory_chunked_io() {
        let mut io = MemoryChunkedIO::new(10240, ChunkStrategy::FixedSize(1024));
        let descriptor = ChunkDescriptor::new(0, 1024, 0, 10);
        let data = Bytes::from(vec![42u8; 1024]);
        // Fail loudly instead of swallowing errors with `.ok()` — the old
        // version would pass even if write_chunk returned Err.
        io.write_chunk(&descriptor, data.clone())
            .await
            .expect("write_chunk should succeed");
        let read_data = io
            .read_chunk(&descriptor)
            .await
            .expect("read_chunk should succeed");
        assert_eq!(read_data.len(), 1024);
        // Also check the contents round-trip, not just the length.
        assert_eq!(read_data, data);
    }

    /// Verifies chunk sizing for the fixed and adaptive strategies.
    #[test]
    fn test_chunk_strategy() {
        let strategy = ChunkStrategy::FixedSize(1024);
        assert_eq!(strategy.chunk_size_for_index(0, 2048), 1024);
        let adaptive = ChunkStrategy::Adaptive {
            min_size: 512,
            max_size: 2048,
            target_memory: 1024,
        };
        // min(1500, target 1024) = 1024, already within [512, 2048].
        assert_eq!(adaptive.chunk_size_for_index(0, 1500), 1024);
        // min(500, 1024) = 500, raised to min_size.
        assert_eq!(adaptive.chunk_size_for_index(0, 500), 512);
        // min(3000, 1024) = 1024: target_memory caps the chunk even when
        // more memory is available. The old expectation of 2048 (max_size)
        // contradicted the implementation and failed.
        assert_eq!(adaptive.chunk_size_for_index(0, 3000), 1024);
    }
}