use super::common::SerializationOptions;
use crate::{Tensor, TensorElement};
use std::fs::File;
use std::io::{BufReader, BufWriter, Read, Seek, Write};
use std::sync::Arc;
use std::time::{Duration, Instant};
use torsh_core::error::{Result, TorshError};
/// Progress callback invoked as `(bytes_processed, total_bytes, elapsed)`;
/// wrapped in `Arc` with `Send + Sync` so it can be shared across threads.
pub type ProgressCallback = Arc<dyn Fn(u64, u64, Duration) + Send + Sync>;
/// Tuning knobs for streaming tensor serialization and deserialization.
#[derive(Debug, Clone)]
pub struct StreamingConfig {
    /// Number of logical tensor bytes processed per chunk.
    pub chunk_size: usize,
    /// Minimum number of logical bytes between two progress-callback invocations.
    pub progress_interval: u64,
    /// Capacity of the internal `BufReader`/`BufWriter`.
    pub buffer_size: usize,
    /// Zstd-compress each chunk (takes effect only with the `serialize` feature).
    pub compress_chunks: bool,
    /// Memory budget, in bytes, for transient buffers during streaming.
    pub memory_limit: usize,
}
impl Default for StreamingConfig {
    /// Balanced defaults: 64 MiB chunks, 1 MiB progress granularity,
    /// 8 KiB I/O buffer, no compression, 1 GiB memory budget.
    fn default() -> Self {
        Self {
            chunk_size: 64 << 20,
            progress_interval: 1 << 20,
            buffer_size: 8 << 10,
            compress_chunks: false,
            memory_limit: 1 << 30,
        }
    }
}
impl StreamingConfig {
    /// Throughput-oriented profile: large chunks and buffers, no compression,
    /// 2 GiB memory budget.
    pub fn fast() -> Self {
        Self {
            chunk_size: 128 << 20,
            progress_interval: 16 << 20,
            buffer_size: 64 << 10,
            compress_chunks: false,
            memory_limit: 2 << 30,
        }
    }

    /// Memory-constrained profile: small chunks and buffers, compression on,
    /// 128 MiB memory budget.
    pub fn low_memory() -> Self {
        Self {
            chunk_size: 1 << 20,
            progress_interval: 512 << 10,
            buffer_size: 4 << 10,
            compress_chunks: true,
            memory_limit: 128 << 20,
        }
    }
}
/// Incrementally writes tensor bytes to a seekable sink in configurable
/// chunks, optionally compressing each chunk and reporting progress.
pub struct StreamingTensorWriter<W: Write + Seek> {
    /// Buffered sink; capacity comes from `config.buffer_size`.
    writer: BufWriter<W>,
    /// Streaming parameters in effect for this writer.
    config: StreamingConfig,
    /// Logical (uncompressed) bytes written for the current tensor.
    bytes_written: u64,
    /// Expected total logical bytes, as announced by `begin_tensor`.
    total_bytes: u64,
    /// `bytes_written` value at the time of the last progress report.
    last_progress_report: u64,
    /// Reset by `begin_tensor`; used to compute elapsed time for progress.
    start_time: Instant,
    /// Optional `(written, total, elapsed)` progress callback.
    progress_callback: Option<ProgressCallback>,
}
impl<W: Write + Seek> StreamingTensorWriter<W> {
    /// Creates a writer over `writer`, buffered with `config.buffer_size`.
    pub fn new(writer: W, config: StreamingConfig) -> Self {
        let buf_writer = BufWriter::with_capacity(config.buffer_size, writer);
        Self {
            writer: buf_writer,
            config,
            bytes_written: 0,
            total_bytes: 0,
            last_progress_report: 0,
            start_time: Instant::now(),
            progress_callback: None,
        }
    }

    /// Installs a callback invoked as `(bytes_written, total_bytes, elapsed)`.
    pub fn with_progress_callback(mut self, callback: ProgressCallback) -> Self {
        self.progress_callback = Some(callback);
        self
    }

    /// Resets counters for a new tensor of `total_bytes` logical bytes and
    /// emits an initial (zero-progress) report.
    pub fn begin_tensor(&mut self, total_bytes: u64) -> Result<()> {
        self.total_bytes = total_bytes;
        self.bytes_written = 0;
        self.last_progress_report = 0;
        self.start_time = Instant::now();
        self.report_progress();
        Ok(())
    }

    /// Writes one chunk of raw tensor bytes.
    ///
    /// With `config.compress_chunks` set and the `serialize` feature enabled,
    /// the chunk is zstd-compressed and framed as a little-endian `u32`
    /// payload length followed by the compressed payload; otherwise the bytes
    /// are written verbatim. Progress accounting always uses the logical
    /// (uncompressed) byte count.
    ///
    /// # Errors
    /// Returns `TorshError::SerializationError` on compression or I/O failure,
    /// or if a compressed frame would not fit in the `u32` length header.
    pub fn write_chunk(&mut self, data: &[u8]) -> Result<()> {
        #[cfg(feature = "serialize")]
        if self.config.compress_chunks {
            let compressed = oxiarc_zstd::compress_with_level(data, 3).map_err(|e| {
                TorshError::SerializationError(format!(
                    "Failed to compress streaming chunk: {}",
                    e
                ))
            })?;
            // A plain `as u32` cast would silently truncate frames >= 4 GiB
            // and corrupt the stream; fail loudly instead.
            if compressed.len() > u32::MAX as usize {
                return Err(TorshError::SerializationError(format!(
                    "Compressed chunk too large for u32 frame header: {} bytes",
                    compressed.len()
                )));
            }
            let len = compressed.len() as u32;
            // Two buffered writes produce the same byte stream as the
            // original single framed Vec, without the intermediate copy.
            self.writer.write_all(&len.to_le_bytes()).map_err(|e| {
                TorshError::SerializationError(format!("Failed to write chunk: {}", e))
            })?;
            self.writer.write_all(&compressed).map_err(|e| {
                TorshError::SerializationError(format!("Failed to write chunk: {}", e))
            })?;
            self.track_progress(data.len() as u64);
            return Ok(());
        }
        // Uncompressed path (also taken when the `serialize` feature is off):
        // write straight from the caller's slice instead of cloning it into a
        // fresh Vec per chunk as the original did.
        self.writer
            .write_all(data)
            .map_err(|e| TorshError::SerializationError(format!("Failed to write chunk: {}", e)))?;
        self.track_progress(data.len() as u64);
        Ok(())
    }

    /// Flushes all buffered bytes and emits a final progress report.
    pub fn finish(mut self) -> Result<()> {
        self.writer.flush().map_err(|e| {
            TorshError::SerializationError(format!("Failed to flush writer: {}", e))
        })?;
        self.report_progress();
        Ok(())
    }

    /// Adds `n` logical bytes and reports progress once at least
    /// `config.progress_interval` bytes accumulated since the last report.
    fn track_progress(&mut self, n: u64) {
        self.bytes_written += n;
        if self.bytes_written - self.last_progress_report >= self.config.progress_interval {
            self.report_progress();
            self.last_progress_report = self.bytes_written;
        }
    }

    /// Invokes the progress callback, if one is installed.
    fn report_progress(&self) {
        if let Some(ref callback) = self.progress_callback {
            callback(self.bytes_written, self.total_bytes, self.start_time.elapsed());
        }
    }
}
/// Incrementally reads tensor bytes from a source in configurable chunks,
/// optionally decompressing framed chunks and reporting progress.
pub struct StreamingTensorReader<R: Read> {
    /// Buffered source; capacity comes from `config.buffer_size`.
    reader: BufReader<R>,
    /// Streaming parameters in effect for this reader.
    config: StreamingConfig,
    /// Logical (decompressed) bytes delivered for the current tensor.
    bytes_read: u64,
    /// Expected total logical bytes, as announced by `begin_tensor`.
    total_bytes: u64,
    /// `bytes_read` value at the time of the last progress report.
    last_progress_report: u64,
    /// Reset by `begin_tensor`; used to compute elapsed time for progress.
    start_time: Instant,
    /// Optional `(read, total, elapsed)` progress callback.
    progress_callback: Option<ProgressCallback>,
    /// Most recently decompressed chunk, served out incrementally.
    #[cfg(feature = "serialize")]
    decomp_buf: Vec<u8>,
    /// Read position within `decomp_buf`.
    #[cfg(feature = "serialize")]
    decomp_cursor: usize,
}
impl<R: Read> StreamingTensorReader<R> {
    /// Creates a reader over `reader`, buffered with `config.buffer_size`.
    pub fn new(reader: R, config: StreamingConfig) -> Self {
        let buf_reader = BufReader::with_capacity(config.buffer_size, reader);
        Self {
            reader: buf_reader,
            config,
            bytes_read: 0,
            total_bytes: 0,
            last_progress_report: 0,
            start_time: Instant::now(),
            progress_callback: None,
            #[cfg(feature = "serialize")]
            decomp_buf: Vec::new(),
            #[cfg(feature = "serialize")]
            decomp_cursor: 0,
        }
    }

    /// Installs a callback invoked as `(bytes_read, total_bytes, elapsed)`.
    pub fn with_progress_callback(mut self, callback: ProgressCallback) -> Self {
        self.progress_callback = Some(callback);
        self
    }

    /// Resets counters for a new tensor of `total_bytes` logical bytes and
    /// emits an initial (zero-progress) report.
    pub fn begin_tensor(&mut self, total_bytes: u64) -> Result<()> {
        self.total_bytes = total_bytes;
        self.bytes_read = 0;
        self.last_progress_report = 0;
        self.start_time = Instant::now();
        self.report_progress();
        Ok(())
    }

    /// Reads up to `buffer.len()` logical bytes, returning how many were read.
    ///
    /// With `config.compress_chunks` set and the `serialize` feature enabled,
    /// data is read as length-prefixed zstd frames (as produced by
    /// `StreamingTensorWriter`) and decompressed; otherwise raw bytes are
    /// read. Returns `Ok(0)` only at end of stream.
    ///
    /// # Errors
    /// Returns `TorshError::SerializationError` on I/O or decompression
    /// failure, or if a frame header declares a length above
    /// `config.memory_limit`.
    pub fn read_chunk(&mut self, buffer: &mut [u8]) -> Result<usize> {
        #[cfg(feature = "serialize")]
        if self.config.compress_chunks {
            // Loop so a zero-length decompressed frame is not mistaken for
            // EOF (the original returned Ok(0) in that case).
            while self.decomp_cursor >= self.decomp_buf.len() {
                if !self.refill_decompressed()? {
                    return Ok(0); // clean end of stream at a frame boundary
                }
            }
            let available = self.decomp_buf.len() - self.decomp_cursor;
            let to_copy = available.min(buffer.len());
            buffer[..to_copy].copy_from_slice(
                &self.decomp_buf[self.decomp_cursor..self.decomp_cursor + to_copy],
            );
            self.decomp_cursor += to_copy;
            self.track_progress(to_copy as u64);
            return Ok(to_copy);
        }
        // Raw path (also taken when the `serialize` feature is off, matching
        // the writer, which emits uncompressed bytes in that configuration).
        let bytes_read = self
            .reader
            .read(buffer)
            .map_err(|e| TorshError::SerializationError(format!("Failed to read chunk: {}", e)))?;
        self.track_progress(bytes_read as u64);
        Ok(bytes_read)
    }

    /// Reads and decompresses the next framed chunk into `decomp_buf`,
    /// resetting the cursor. Returns `Ok(false)` on clean EOF exactly at a
    /// frame boundary.
    #[cfg(feature = "serialize")]
    fn refill_decompressed(&mut self) -> Result<bool> {
        let mut len_buf = [0u8; 4];
        match self.reader.read_exact(&mut len_buf) {
            Ok(()) => {}
            // EOF before any header byte means the stream ended cleanly.
            Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(false),
            Err(e) => {
                return Err(TorshError::SerializationError(format!(
                    "Failed to read compressed frame length: {}",
                    e
                )));
            }
        }
        let compressed_len = u32::from_le_bytes(len_buf) as usize;
        // A corrupt or hostile header could otherwise demand a multi-GiB
        // up-front allocation; bound it by the configured memory budget
        // (this field was previously declared but never consulted).
        if compressed_len > self.config.memory_limit {
            return Err(TorshError::SerializationError(format!(
                "Compressed frame length {} exceeds memory limit {}",
                compressed_len, self.config.memory_limit
            )));
        }
        let mut compressed = vec![0u8; compressed_len];
        self.reader.read_exact(&mut compressed).map_err(|e| {
            TorshError::SerializationError(format!(
                "Failed to read compressed frame payload: {}",
                e
            ))
        })?;
        self.decomp_buf = oxiarc_zstd::decompress(&compressed).map_err(|e| {
            TorshError::SerializationError(format!(
                "Failed to decompress streaming chunk: {}",
                e
            ))
        })?;
        self.decomp_cursor = 0;
        Ok(true)
    }

    /// Reads exactly `buffer.len()` raw bytes (bypasses any decompression).
    ///
    /// # Errors
    /// Returns `TorshError::SerializationError` if the source ends early or
    /// another I/O error occurs.
    pub fn read_exact(&mut self, buffer: &mut [u8]) -> Result<()> {
        self.reader.read_exact(buffer).map_err(|e| {
            TorshError::SerializationError(format!("Failed to read exact bytes: {}", e))
        })?;
        self.track_progress(buffer.len() as u64);
        Ok(())
    }

    /// Emits a final progress report and consumes the reader.
    pub fn finish(self) -> Result<()> {
        self.report_progress();
        Ok(())
    }

    /// Adds `n` logical bytes and reports progress once at least
    /// `config.progress_interval` bytes accumulated since the last report.
    fn track_progress(&mut self, n: u64) {
        self.bytes_read += n;
        if self.bytes_read - self.last_progress_report >= self.config.progress_interval {
            self.report_progress();
            self.last_progress_report = self.bytes_read;
        }
    }

    /// Invokes the progress callback, if one is installed.
    fn report_progress(&self) {
        if let Some(ref callback) = self.progress_callback {
            callback(self.bytes_read, self.total_bytes, self.start_time.elapsed());
        }
    }
}
/// Convenience helpers built on the streaming writer.
pub mod utils {
    use super::*;

    /// Serializes `tensor`'s raw element bytes to `path` in chunks of
    /// `config.chunk_size`, optionally reporting progress.
    ///
    /// `_options` is accepted for interface compatibility but currently
    /// unused.
    ///
    /// # Errors
    /// Returns `TorshError::SerializationError` if the file cannot be
    /// created or any chunk fails to write.
    pub fn stream_serialize_to_file<T: TensorElement>(
        tensor: &Tensor<T>,
        path: &std::path::Path,
        _options: &SerializationOptions,
        config: StreamingConfig,
        progress_callback: Option<ProgressCallback>,
    ) -> Result<()> {
        let file = File::create(path)
            .map_err(|e| TorshError::SerializationError(format!("Failed to create file: {}", e)))?;
        let mut writer = StreamingTensorWriter::new(file, config);
        if let Some(callback) = progress_callback {
            writer = writer.with_progress_callback(callback);
        }
        let data = tensor.data()?;
        // Logical payload size: element count times element width.
        let total_bytes = data.len() * std::mem::size_of::<T>();
        writer.begin_tensor(total_bytes as u64)?;
        // SAFETY: reinterprets the element buffer as raw bytes. Sound
        // provided `data` is a live, contiguous allocation of `data.len()`
        // values of `T` (so the byte length is exactly `total_bytes`) and
        // `T` has no uninitialized padding — presumably guaranteed by
        // `TensorElement` for plain numeric types; TODO confirm against
        // `Tensor::data` and the `TensorElement` bound.
        let data_bytes =
            unsafe { std::slice::from_raw_parts(data.as_ptr() as *const u8, total_bytes) };
        let chunk_size = writer.config.chunk_size;
        for chunk in data_bytes.chunks(chunk_size) {
            writer.write_chunk(chunk)?;
        }
        writer.finish()?;
        Ok(())
    }

    /// Returns a progress callback that prints percentage, byte counts,
    /// throughput (MB/s), and elapsed time to stdout.
    pub fn console_progress_callback() -> ProgressCallback {
        Arc::new(|bytes_processed, total_bytes, elapsed| {
            // Guard both ratios against division by zero (unknown total,
            // first instantaneous report).
            let percentage = if total_bytes > 0 {
                (bytes_processed as f64 / total_bytes as f64) * 100.0
            } else {
                0.0
            };
            let rate = if elapsed.as_secs_f64() > 0.0 {
                bytes_processed as f64 / elapsed.as_secs_f64() / 1024.0 / 1024.0
            } else {
                0.0
            };
            println!(
                "Progress: {:.1}% ({}/{} bytes) - {:.1} MB/s - {:?}",
                percentage, bytes_processed, total_bytes, rate, elapsed
            );
        })
    }
}