use tensorlogic_ir::EinsumGraph;
use crate::batch::BatchResult;
/// How a batched workload is carved into streaming chunks.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StreamingMode {
    /// No streaming: the whole batch is processed as a single chunk.
    None,
    /// Fixed number of items per chunk (the final chunk may be smaller).
    FixedChunk(usize),
    /// Chunk size derived from a memory budget, expressed in megabytes.
    DynamicChunk { target_memory_mb: usize },
    /// Start from an initial chunk size and let the processor tune it.
    Adaptive { initial_chunk: usize },
}

/// Tuning knobs for streaming execution.
#[derive(Debug, Clone)]
pub struct StreamingConfig {
    /// Chunking strategy.
    pub mode: StreamingMode,
    /// Number of chunks to fetch ahead of the compute stage.
    pub prefetch_chunks: usize,
    /// Whether compute and I/O may run concurrently.
    pub overlap_compute_io: bool,
    /// Emit a checkpoint every N chunks, when set.
    pub checkpoint_interval: Option<usize>,
}

impl StreamingConfig {
    /// Builds a config for `mode` with default tuning: one prefetched
    /// chunk, compute/I-O overlap enabled, and no checkpointing.
    pub fn new(mode: StreamingMode) -> Self {
        Self {
            mode,
            prefetch_chunks: 1,
            overlap_compute_io: true,
            checkpoint_interval: None,
        }
    }

    /// Sets how many chunks to prefetch (builder style).
    pub fn with_prefetch(self, num_chunks: usize) -> Self {
        Self {
            prefetch_chunks: num_chunks,
            ..self
        }
    }

    /// Enables checkpointing every `interval` chunks (builder style).
    pub fn with_checkpointing(self, interval: usize) -> Self {
        Self {
            checkpoint_interval: Some(interval),
            ..self
        }
    }

    /// Disables compute/I-O overlap (builder style).
    pub fn disable_overlap(self) -> Self {
        Self {
            overlap_compute_io: false,
            ..self
        }
    }
}

impl Default for StreamingConfig {
    /// Defaults to `StreamingMode::None` with standard tuning.
    fn default() -> Self {
        Self::new(StreamingMode::None)
    }
}
/// Describes one chunk's position within the full workload.
#[derive(Debug, Clone)]
pub struct ChunkMetadata {
    /// Sequential chunk index, starting at 0.
    pub chunk_id: usize,
    /// First item index covered by this chunk (inclusive).
    pub start_idx: usize,
    /// One past the last item index covered (exclusive).
    pub end_idx: usize,
    /// Number of items in the chunk (`end_idx - start_idx`).
    pub size: usize,
    /// True when this chunk reaches the end of the workload.
    pub is_last: bool,
}

impl ChunkMetadata {
    /// Derives `size` and `is_last` from the half-open range
    /// `[start_idx, end_idx)` measured against `total_size`.
    ///
    /// Panics on usize underflow if `end_idx < start_idx`.
    pub fn new(chunk_id: usize, start_idx: usize, end_idx: usize, total_size: usize) -> Self {
        ChunkMetadata {
            chunk_id,
            start_idx,
            end_idx,
            size: end_idx - start_idx,
            is_last: end_idx >= total_size,
        }
    }
}

/// Outputs produced by processing one chunk, together with the chunk's
/// metadata and how long the chunk took to process.
#[derive(Debug, Clone)]
pub struct StreamResult<T> {
    /// Per-chunk output values.
    pub outputs: Vec<T>,
    /// Position of this chunk in the workload.
    pub metadata: ChunkMetadata,
    /// Wall-clock processing time for the chunk, in milliseconds.
    pub processing_time_ms: f64,
}

impl<T> StreamResult<T> {
    /// Bundles chunk outputs with their metadata and timing.
    pub fn new(outputs: Vec<T>, metadata: ChunkMetadata, processing_time_ms: f64) -> Self {
        StreamResult {
            outputs,
            metadata,
            processing_time_ms,
        }
    }

    /// Items processed per second for this chunk, based on the chunk's
    /// nominal `metadata.size`; 0.0 when the recorded time is not
    /// strictly positive (avoids division by zero).
    pub fn throughput_items_per_sec(&self) -> f64 {
        match self.processing_time_ms {
            t if t > 0.0 => (self.metadata.size as f64) / (t / 1000.0),
            _ => 0.0,
        }
    }
}
/// Backend contract for executing an einsum graph over a stream of input
/// chunks instead of one monolithic batch.
pub trait TlStreamingExecutor {
/// Backend-specific tensor handle.
type Tensor;
/// Backend-specific error type.
type Error;
/// Runs the whole input stream through `graph`, returning one
/// `StreamResult` per processed chunk.
///
/// NOTE(review): the `Vec<Vec<Vec<_>>>` nesting presumably means
/// chunks x graph-inputs x tensors — confirm against implementors.
fn execute_stream(
&mut self,
graph: &EinsumGraph,
input_stream: Vec<Vec<Vec<Self::Tensor>>>,
config: &StreamingConfig,
) -> Result<Vec<StreamResult<Self::Tensor>>, Self::Error>;
/// Executes a single chunk of inputs against `graph`.
fn execute_chunk(
&mut self,
graph: &EinsumGraph,
chunk_inputs: Vec<Self::Tensor>,
metadata: &ChunkMetadata,
) -> Result<StreamResult<Self::Tensor>, Self::Error>;
/// Suggests a chunk size for the given memory budget.
/// Default ignores both arguments and returns a fixed heuristic of 32
/// items; backends should override with a real estimate.
fn recommend_chunk_size(&self, graph: &EinsumGraph, available_memory_mb: usize) -> usize {
let _ = (graph, available_memory_mb);
32 }
/// Estimates memory needed for one chunk of `chunk_size` items.
/// Default assumes 1 MiB per item and ignores the graph.
/// NOTE(review): appears to return bytes while `recommend_chunk_size`
/// takes megabytes — confirm intended units before relying on either.
fn estimate_chunk_memory(&self, graph: &EinsumGraph, chunk_size: usize) -> usize {
let _ = (graph, chunk_size);
chunk_size * 1024 * 1024 }
}
/// Yields `ChunkMetadata` for consecutive chunks covering `total_size`
/// items, `chunk_size` items at a time (the final chunk may be shorter).
pub struct ChunkIterator {
    // Total number of items to cover.
    total_size: usize,
    // Items per chunk; `new` guarantees this is >= 1.
    chunk_size: usize,
    // Index of the next chunk `next` will yield.
    current_chunk: usize,
}

impl ChunkIterator {
    /// Creates an iterator over `total_size` items in chunks of
    /// `chunk_size`.
    ///
    /// `chunk_size` is clamped to at least 1: previously a value of 0
    /// (reachable via `FixedChunk(0)`, `Adaptive { initial_chunk: 0 }`,
    /// or `StreamingMode::None` on an empty workload) made `num_chunks`
    /// panic with a division by zero and made `next` loop forever,
    /// yielding empty chunks.
    pub fn new(total_size: usize, chunk_size: usize) -> Self {
        ChunkIterator {
            total_size,
            chunk_size: chunk_size.max(1),
            current_chunk: 0,
        }
    }

    /// Derives the chunk size from `config.mode`: `None` -> one chunk
    /// covering everything; `FixedChunk(n)` -> n; `DynamicChunk` ->
    /// `target_memory_mb` items; `Adaptive` -> its initial size.
    /// All results pass through `new`, so a zero size is clamped to 1.
    pub fn from_config(total_size: usize, config: &StreamingConfig) -> Self {
        let chunk_size = match config.mode {
            StreamingMode::None => total_size,
            StreamingMode::FixedChunk(size) => size,
            // NOTE(review): treats the MB budget directly as an item
            // count — confirm this heuristic against executor memory
            // estimates.
            StreamingMode::DynamicChunk { target_memory_mb } => target_memory_mb.max(1),
            StreamingMode::Adaptive { initial_chunk } => initial_chunk,
        };
        ChunkIterator::new(total_size, chunk_size)
    }

    /// Total number of chunks this iterator will yield
    /// (0 for an empty workload).
    pub fn num_chunks(&self) -> usize {
        self.total_size.div_ceil(self.chunk_size)
    }

    /// Index of the next chunk to be yielded.
    pub fn current_chunk(&self) -> usize {
        self.current_chunk
    }
}

impl Iterator for ChunkIterator {
    type Item = ChunkMetadata;

    /// Yields metadata for the next chunk, or `None` once `total_size`
    /// items are covered. Terminates because `chunk_size >= 1`.
    fn next(&mut self) -> Option<Self::Item> {
        let start_idx = self.current_chunk * self.chunk_size;
        if start_idx >= self.total_size {
            return None;
        }
        let end_idx = (start_idx + self.chunk_size).min(self.total_size);
        let metadata = ChunkMetadata::new(self.current_chunk, start_idx, end_idx, self.total_size);
        self.current_chunk += 1;
        Some(metadata)
    }
}
/// Splits batches into chunks and merges per-chunk results back
/// together, driven by a `StreamingConfig`.
pub struct StreamProcessor {
    // Chunking configuration consulted by `split_batch`.
    config: StreamingConfig,
}

impl StreamProcessor {
    /// Wraps the given streaming configuration.
    pub fn new(config: StreamingConfig) -> Self {
        StreamProcessor { config }
    }

    /// Slices `batch.outputs` into `(metadata, data)` pairs according
    /// to the configured chunking mode.
    pub fn split_batch<T: Clone>(&self, batch: &BatchResult<T>) -> Vec<(ChunkMetadata, Vec<T>)> {
        ChunkIterator::from_config(batch.len(), &self.config)
            .map(|meta| {
                let data = batch.outputs[meta.start_idx..meta.end_idx].to_vec();
                (meta, data)
            })
            .collect()
    }

    /// Concatenates per-chunk outputs back into a single `BatchResult`,
    /// preserving the order of `results`.
    pub fn merge_results<T>(results: Vec<StreamResult<T>>) -> BatchResult<T> {
        let capacity: usize = results.iter().map(|r| r.outputs.len()).sum();
        let merged = results
            .into_iter()
            .fold(Vec::with_capacity(capacity), |mut acc, result| {
                acc.extend(result.outputs);
                acc
            });
        BatchResult::new(merged)
    }

    /// Picks a chunk size aiming for roughly 100 ms per chunk, based on
    /// the mean throughput of previous results. Returns 32 when there
    /// is no history; otherwise the result is clamped to `[1, 1000]`.
    pub fn adaptive_chunk_size(&self, results: &[StreamResult<impl Clone>]) -> usize {
        if results.is_empty() {
            return 32;
        }
        let total_throughput: f64 = results
            .iter()
            .map(|r| r.throughput_items_per_sec())
            .sum();
        let avg_throughput = total_throughput / results.len() as f64;
        let target_time_ms = 100.0;
        ((avg_throughput * target_time_ms / 1000.0) as usize).clamp(1, 1000)
    }

    /// Read access to the underlying configuration.
    pub fn config(&self) -> &StreamingConfig {
        &self.config
    }
}

impl Default for StreamProcessor {
    /// A processor with the default (non-streaming) configuration.
    fn default() -> Self {
        StreamProcessor::new(StreamingConfig::default())
    }
}
// Unit tests for chunking, splitting, and merging. Assertions are kept
// byte-identical: they are the behavioral specification for this module.
#[cfg(test)]
mod tests {
use super::*;
// Builder methods must record mode, prefetch count, and checkpoint interval.
#[test]
fn test_streaming_config() {
let config = StreamingConfig::new(StreamingMode::FixedChunk(64))
.with_prefetch(2)
.with_checkpointing(100);
assert_eq!(config.mode, StreamingMode::FixedChunk(64));
assert_eq!(config.prefetch_chunks, 2);
assert_eq!(config.checkpoint_interval, Some(100));
}
// `size` is derived from the range; `is_last` fires when end reaches total.
#[test]
fn test_chunk_metadata() {
let metadata = ChunkMetadata::new(0, 0, 32, 100);
assert_eq!(metadata.chunk_id, 0);
assert_eq!(metadata.size, 32);
assert!(!metadata.is_last);
let last_metadata = ChunkMetadata::new(3, 96, 100, 100);
assert!(last_metadata.is_last);
}
// Positive processing time yields a positive throughput.
#[test]
fn test_stream_result() {
let metadata = ChunkMetadata::new(0, 0, 32, 100);
let result: StreamResult<i32> = StreamResult::new(vec![1, 2, 3], metadata, 100.0);
assert_eq!(result.outputs.len(), 3);
let throughput = result.throughput_items_per_sec();
assert!(throughput > 0.0);
}
// 100 items in chunks of 32 -> 3 full chunks plus a final chunk of 4.
#[test]
fn test_chunk_iterator() {
let iter = ChunkIterator::new(100, 32);
assert_eq!(iter.num_chunks(), 4);
let chunks: Vec<_> = iter.collect();
assert_eq!(chunks.len(), 4);
assert_eq!(chunks[0].size, 32);
assert_eq!(chunks[3].size, 4);
assert!(chunks[3].is_last);
}
// FixedChunk mode carries its size straight into the iterator.
#[test]
fn test_chunk_iterator_from_config() {
let config = StreamingConfig::new(StreamingMode::FixedChunk(25));
let iter = ChunkIterator::from_config(100, &config);
assert_eq!(iter.chunk_size, 25);
assert_eq!(iter.num_chunks(), 4);
}
// split_batch preserves order and leaves the remainder in the last chunk.
#[test]
fn test_stream_processor_split() {
let batch = BatchResult::new(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]);
let config = StreamingConfig::new(StreamingMode::FixedChunk(3));
let processor = StreamProcessor::new(config);
let chunks = processor.split_batch(&batch);
assert_eq!(chunks.len(), 4);
assert_eq!(chunks[0].1, vec![1, 2, 3]);
assert_eq!(chunks[1].1, vec![4, 5, 6]);
assert_eq!(chunks[2].1, vec![7, 8, 9]);
assert_eq!(chunks[3].1, vec![10]);
}
// merge_results concatenates chunk outputs in order.
#[test]
fn test_stream_processor_merge() {
let metadata1 = ChunkMetadata::new(0, 0, 3, 10);
let metadata2 = ChunkMetadata::new(1, 3, 6, 10);
let metadata3 = ChunkMetadata::new(2, 6, 10, 10);
let results = vec![
StreamResult::new(vec![1, 2, 3], metadata1, 10.0),
StreamResult::new(vec![4, 5, 6], metadata2, 10.0),
StreamResult::new(vec![7, 8, 9, 10], metadata3, 10.0),
];
let batch = StreamProcessor::merge_results(results);
assert_eq!(batch.len(), 10);
assert_eq!(batch.outputs, vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]);
}
// The adaptive size derived from throughput history stays within bounds.
#[test]
fn test_adaptive_chunk_size() {
let processor = StreamProcessor::default();
let metadata = ChunkMetadata::new(0, 0, 100, 1000);
let results = vec![
StreamResult::new(vec![(); 100], metadata.clone(), 50.0), StreamResult::new(vec![(); 100], metadata.clone(), 100.0), StreamResult::new(vec![(); 100], metadata, 75.0), ];
let chunk_size = processor.adaptive_chunk_size(&results);
assert!(chunk_size > 0);
assert!(chunk_size <= 1000); }
// Mode variants compare by value and pattern-match as expected.
#[test]
fn test_streaming_modes() {
assert_eq!(StreamingMode::None, StreamingConfig::default().mode);
let fixed = StreamingMode::FixedChunk(64);
assert_eq!(fixed, StreamingMode::FixedChunk(64));
let dynamic = StreamingMode::DynamicChunk {
target_memory_mb: 512,
};
match dynamic {
StreamingMode::DynamicChunk { target_memory_mb } => {
assert_eq!(target_memory_mb, 512);
}
_ => panic!("Wrong mode"),
}
}
}
/// What to do with incoming chunks when the buffer fills up.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BackpressureStrategy {
    /// Block the producer until buffer space frees up.
    Block,
    /// Evict the oldest buffered chunk to make room.
    DropOldest,
    /// Discard the newly arriving chunk.
    DropNewest,
    /// Fail with an error instead of buffering or dropping.
    ErrorOnFull,
}

/// Buffer-occupancy thresholds that decide when backpressure engages
/// and releases.
#[derive(Debug, Clone)]
pub struct BackpressureConfig {
    /// Hard cap on the number of buffered chunks.
    pub max_buffered_chunks: usize,
    /// Fraction of capacity above which backpressure engages (default 0.8).
    pub high_watermark: f64,
    /// Fraction of capacity below which backpressure releases (default 0.2).
    pub low_watermark: f64,
    /// Policy applied when the buffer is saturated.
    pub strategy: BackpressureStrategy,
}

impl BackpressureConfig {
    /// Config with watermarks 0.8 / 0.2 and the blocking strategy.
    pub fn new(max_buffered: usize) -> Self {
        BackpressureConfig {
            max_buffered_chunks: max_buffered,
            high_watermark: 0.8,
            low_watermark: 0.2,
            strategy: BackpressureStrategy::Block,
        }
    }

    /// Overrides both watermark fractions (builder style).
    pub fn with_watermarks(self, high: f64, low: f64) -> Self {
        BackpressureConfig {
            high_watermark: high,
            low_watermark: low,
            ..self
        }
    }

    /// Overrides the saturation strategy (builder style).
    pub fn with_strategy(self, strategy: BackpressureStrategy) -> Self {
        BackpressureConfig { strategy, ..self }
    }

    /// True once the buffer is strictly above the high-watermark count
    /// (`max * high`, truncated toward zero).
    pub fn is_above_high_watermark(&self, current_buffered: usize) -> bool {
        current_buffered > (self.max_buffered_chunks as f64 * self.high_watermark) as usize
    }

    /// True once the buffer is strictly below the low-watermark count
    /// (`max * low`, truncated toward zero).
    pub fn is_below_low_watermark(&self, current_buffered: usize) -> bool {
        current_buffered < (self.max_buffered_chunks as f64 * self.low_watermark) as usize
    }

    /// Backpressure applies exactly when the high watermark is exceeded.
    pub fn should_apply_backpressure(&self, current_buffered: usize) -> bool {
        self.is_above_high_watermark(current_buffered)
    }
}
/// Event-time watermark settings for tolerating out-of-order chunks.
#[derive(Debug, Clone)]
pub struct WatermarkConfig {
    /// Maximum lag (ms) behind the newest observed event time before an
    /// event is considered late.
    pub max_out_of_order_ms: u64,
    /// Optional idle timeout (ms) for advancing the watermark when no
    /// events arrive.
    pub idle_timeout_ms: Option<u64>,
    /// Whether late events should be dropped rather than processed.
    pub drop_late_events: bool,
}

impl WatermarkConfig {
    /// Config with the given out-of-order tolerance, no idle timeout,
    /// and late events kept.
    pub fn new(max_out_of_order_ms: u64) -> Self {
        WatermarkConfig {
            max_out_of_order_ms,
            idle_timeout_ms: None,
            drop_late_events: false,
        }
    }

    /// Sets the idle timeout in milliseconds (builder style).
    pub fn with_idle_timeout(self, timeout_ms: u64) -> Self {
        WatermarkConfig {
            idle_timeout_ms: Some(timeout_ms),
            ..self
        }
    }

    /// Sets whether late events are dropped (builder style).
    pub fn with_drop_late(self, drop: bool) -> Self {
        WatermarkConfig {
            drop_late_events: drop,
            ..self
        }
    }

    /// The watermark trails the newest observed event time by the
    /// out-of-order tolerance, saturating at 0 rather than underflowing.
    pub fn current_watermark(&self, max_event_time_ms: u64) -> u64 {
        max_event_time_ms.saturating_sub(self.max_out_of_order_ms)
    }

    /// An event is late when its timestamp is strictly behind the
    /// current watermark.
    pub fn is_late(&self, event_time_ms: u64, watermark_ms: u64) -> bool {
        event_time_ms < watermark_ms
    }
}
/// `StreamingConfig` extended with optional flow control: backpressure
/// thresholds and event-time watermarks.
#[derive(Debug, Clone)]
pub struct StreamingConfigV2 {
    /// Base chunking/prefetch configuration.
    pub base: StreamingConfig,
    /// Optional buffer backpressure policy.
    pub backpressure: Option<BackpressureConfig>,
    /// Optional out-of-order event handling.
    pub watermark: Option<WatermarkConfig>,
}

impl StreamingConfigV2 {
    /// Wraps a base config with all flow control disabled.
    pub fn new(base: StreamingConfig) -> Self {
        StreamingConfigV2 {
            base,
            backpressure: None,
            watermark: None,
        }
    }

    /// Enables backpressure with the given policy (builder style).
    pub fn with_backpressure(self, config: BackpressureConfig) -> Self {
        StreamingConfigV2 {
            backpressure: Some(config),
            ..self
        }
    }

    /// Enables watermark handling (builder style).
    pub fn with_watermark(self, config: WatermarkConfig) -> Self {
        StreamingConfigV2 {
            watermark: Some(config),
            ..self
        }
    }

    /// Delegates to the backpressure policy; always false when no
    /// policy is configured.
    pub fn should_apply_backpressure(&self, current_buffered: usize) -> bool {
        match &self.backpressure {
            Some(bp) => bp.should_apply_backpressure(current_buffered),
            None => false,
        }
    }

    /// Delegates to the watermark policy; always false when no policy
    /// is configured.
    pub fn is_late_event(&self, event_time_ms: u64, watermark_ms: u64) -> bool {
        match &self.watermark {
            Some(wm) => wm.is_late(event_time_ms, watermark_ms),
            None => false,
        }
    }
}

impl Default for StreamingConfigV2 {
    /// All-defaults config: no streaming, no flow control.
    fn default() -> Self {
        StreamingConfigV2::new(StreamingConfig::default())
    }
}
/// Aggregated counters for a streaming run; all fields start at zero
/// via `Default`.
#[derive(Debug, Clone, Default)]
pub struct StreamingStats {
    /// Chunks fully processed.
    pub chunks_processed: usize,
    /// Chunks discarded (e.g. by a drop strategy).
    pub chunks_dropped: usize,
    /// Times backpressure was triggered.
    pub backpressure_events: usize,
    /// Late events discarded by watermark handling.
    pub late_events_dropped: usize,
    /// Cumulative processing time across chunks, in milliseconds.
    pub total_processing_time_ms: u64,
    /// Cumulative number of elements processed.
    pub total_elements_processed: usize,
}

impl StreamingStats {
    /// Mean processing time per chunk in ms; 0.0 with no chunks.
    pub fn average_latency_ms(&self) -> f64 {
        match self.chunks_processed {
            0 => 0.0,
            n => self.total_processing_time_ms as f64 / n as f64,
        }
    }

    /// Fraction of attempted chunks that were dropped; 0.0 when
    /// nothing was attempted.
    pub fn drop_rate(&self) -> f64 {
        let attempted = self.chunks_processed + self.chunks_dropped;
        match attempted {
            0 => 0.0,
            n => self.chunks_dropped as f64 / n as f64,
        }
    }

    /// Processed chunks per second of total processing time; 0.0 when
    /// no time has been recorded.
    pub fn throughput_chunks_per_sec(&self) -> f64 {
        if self.total_processing_time_ms == 0 {
            return 0.0;
        }
        let seconds = self.total_processing_time_ms as f64 / 1000.0;
        self.chunks_processed as f64 / seconds
    }

    /// Accumulates another stats record into this one, field by field.
    pub fn merge(&mut self, other: &StreamingStats) {
        self.chunks_processed += other.chunks_processed;
        self.chunks_dropped += other.chunks_dropped;
        self.backpressure_events += other.backpressure_events;
        self.late_events_dropped += other.late_events_dropped;
        self.total_processing_time_ms += other.total_processing_time_ms;
        self.total_elements_processed += other.total_elements_processed;
    }
}
// Unit tests for backpressure, watermarks, stats, and the V2 config.
// Assertions are kept byte-identical: they are the behavioral spec.
#[cfg(test)]
mod v2_tests {
use super::*;
// Defaults: 0.8 / 0.2 watermarks and the blocking strategy.
#[test]
fn test_backpressure_config_new() {
let cfg = BackpressureConfig::new(100);
assert_eq!(cfg.max_buffered_chunks, 100);
assert!((cfg.high_watermark - 0.8).abs() < f64::EPSILON);
assert!((cfg.low_watermark - 0.2).abs() < f64::EPSILON);
assert_eq!(cfg.strategy, BackpressureStrategy::Block);
}
// High watermark is strict: 80 is the threshold for max=100, so 81 trips it.
#[test]
fn test_backpressure_above_high_watermark() {
let cfg = BackpressureConfig::new(100); assert!(cfg.is_above_high_watermark(81));
assert!(!cfg.is_above_high_watermark(80));
assert!(!cfg.is_above_high_watermark(0));
}
// Low watermark is strict: 20 is the threshold for max=100, so 19 is below.
#[test]
fn test_backpressure_below_low_watermark() {
let cfg = BackpressureConfig::new(100); assert!(cfg.is_below_low_watermark(19));
assert!(!cfg.is_below_low_watermark(20));
assert!(!cfg.is_below_low_watermark(100));
}
// Between the watermarks, backpressure stays off.
#[test]
fn test_backpressure_between_watermarks() {
let cfg = BackpressureConfig::new(100);
assert!(!cfg.should_apply_backpressure(50));
assert!(cfg.should_apply_backpressure(81));
}
// All four strategy variants are distinct and settable via the builder.
#[test]
fn test_backpressure_strategy_variants() {
let block = BackpressureStrategy::Block;
let drop_oldest = BackpressureStrategy::DropOldest;
let drop_newest = BackpressureStrategy::DropNewest;
let error = BackpressureStrategy::ErrorOnFull;
assert_ne!(drop_oldest, block);
assert_ne!(drop_newest, block);
assert_ne!(error, block);
assert_ne!(drop_oldest, drop_newest);
let cfg = BackpressureConfig::new(10).with_strategy(BackpressureStrategy::DropOldest);
assert_eq!(cfg.strategy, drop_oldest);
let _ = error; }
// Fresh watermark config: no idle timeout, late events kept.
#[test]
fn test_watermark_config_new() {
let wm = WatermarkConfig::new(100);
assert_eq!(wm.max_out_of_order_ms, 100);
assert_eq!(wm.idle_timeout_ms, None);
assert!(!wm.drop_late_events);
}
// Watermark trails the max event time and saturates at zero.
#[test]
fn test_watermark_current_watermark_calculation() {
let wm = WatermarkConfig::new(100);
assert_eq!(wm.current_watermark(500), 400);
let wm2 = WatermarkConfig::new(1000);
assert_eq!(wm2.current_watermark(500), 0);
}
// Lateness is a strict comparison against the watermark.
#[test]
fn test_watermark_is_late_event() {
let wm = WatermarkConfig::new(100);
assert!(wm.is_late(300, 400));
assert!(!wm.is_late(400, 400));
assert!(!wm.is_late(500, 400));
}
// Idle timeout is recorded without disturbing the tolerance.
#[test]
fn test_watermark_with_idle_timeout() {
let wm = WatermarkConfig::new(100).with_idle_timeout(5000);
assert_eq!(wm.idle_timeout_ms, Some(5000));
assert_eq!(wm.max_out_of_order_ms, 100);
}
// Zeroed stats produce zero rates without dividing by zero.
#[test]
fn test_streaming_stats_default() {
let stats = StreamingStats::default();
assert_eq!(stats.chunks_processed, 0);
assert_eq!(stats.chunks_dropped, 0);
assert!((stats.average_latency_ms() - 0.0).abs() < f64::EPSILON);
assert!((stats.drop_rate() - 0.0).abs() < f64::EPSILON);
assert!((stats.throughput_chunks_per_sec() - 0.0).abs() < f64::EPSILON);
}
// drop_rate = dropped / (processed + dropped).
#[test]
fn test_streaming_stats_drop_rate() {
let stats = StreamingStats {
chunks_processed: 9,
chunks_dropped: 1,
..Default::default()
};
assert!((stats.drop_rate() - 0.1).abs() < 1e-9);
}
// merge adds every counter field-wise.
#[test]
fn test_streaming_stats_merge() {
let mut a = StreamingStats {
chunks_processed: 10,
chunks_dropped: 2,
backpressure_events: 1,
late_events_dropped: 3,
total_processing_time_ms: 500,
total_elements_processed: 100,
};
let b = StreamingStats {
chunks_processed: 5,
chunks_dropped: 1,
backpressure_events: 2,
late_events_dropped: 0,
total_processing_time_ms: 250,
total_elements_processed: 50,
};
a.merge(&b);
assert_eq!(a.chunks_processed, 15);
assert_eq!(a.chunks_dropped, 3);
assert_eq!(a.backpressure_events, 3);
assert_eq!(a.late_events_dropped, 3);
assert_eq!(a.total_processing_time_ms, 750);
assert_eq!(a.total_elements_processed, 150);
}
// V2 config starts with flow control disabled.
#[test]
fn test_streaming_config_v2_new() {
let cfg = StreamingConfigV2::new(StreamingConfig::default());
assert!(cfg.backpressure.is_none());
assert!(cfg.watermark.is_none());
}
// Without a policy, backpressure never applies; with one, it delegates.
#[test]
fn test_streaming_config_v2_with_backpressure() {
let cfg_none = StreamingConfigV2::new(StreamingConfig::default());
assert!(!cfg_none.should_apply_backpressure(0));
assert!(!cfg_none.should_apply_backpressure(usize::MAX));
let bp = BackpressureConfig::new(100);
let cfg = StreamingConfigV2::new(StreamingConfig::default()).with_backpressure(bp);
assert!(!cfg.should_apply_backpressure(50));
assert!(cfg.should_apply_backpressure(81));
}
// Backpressure and watermark policies compose on one config.
#[test]
fn test_streaming_config_v2_combined() {
let bp = BackpressureConfig::new(50);
let wm = WatermarkConfig::new(200);
let cfg = StreamingConfigV2::new(StreamingConfig::default())
.with_backpressure(bp)
.with_watermark(wm);
assert!(cfg.backpressure.is_some());
assert!(cfg.watermark.is_some());
assert!(cfg.should_apply_backpressure(41));
assert!(cfg.is_late_event(100, 300));
assert!(!cfg.is_late_event(400, 300));
}
}