use std::fmt;
use std::time::Duration;
use serde::{Deserialize, Serialize};
use crate::error::SynthResult;
/// A single event emitted by a streaming generator.
///
/// `T` is the generated item type. Payloads arrive as `Data`,
/// bookkeeping arrives as `Progress` / `BatchComplete`, and errors are
/// reported inline via `Error` (which may be recoverable, see
/// [`StreamError`]); a finished stream ends with `Complete`.
#[derive(Debug, Clone)]
pub enum StreamEvent<T> {
    /// A generated item.
    Data(T),
    /// Periodic progress snapshot (see [`StreamProgress`]).
    Progress(StreamProgress),
    /// Marks the end of one batch of items.
    BatchComplete {
        /// Identifier of the batch that just finished.
        batch_id: u64,
        /// Number of items in that batch.
        count: usize,
    },
    /// An error encountered during generation.
    Error(StreamError),
    /// Terminal event carrying end-of-stream statistics.
    Complete(StreamSummary),
}
impl<T> StreamEvent<T> {
pub fn is_data(&self) -> bool {
matches!(self, StreamEvent::Data(_))
}
pub fn is_complete(&self) -> bool {
matches!(self, StreamEvent::Complete(_))
}
pub fn is_error(&self) -> bool {
matches!(self, StreamEvent::Error(_))
}
pub fn into_data(self) -> Option<T> {
match self {
StreamEvent::Data(data) => Some(data),
_ => None,
}
}
}
/// Point-in-time progress snapshot for a running stream.
///
/// Optional metrics are omitted from the serialized form when unset.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamProgress {
    /// Items produced so far.
    pub items_generated: u64,
    /// Average throughput over the whole elapsed time, in items/second.
    pub items_per_second: f64,
    /// Wall-clock time since the stream started, in milliseconds.
    pub elapsed_ms: u64,
    /// Free-form label of the current generation phase.
    pub phase: String,
    /// Current memory usage in MB, if the generator tracks it.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub memory_usage_mb: Option<u64>,
    /// Output-buffer fill level, if tracked (presumably 0.0..=1.0 — not
    /// enforced here; confirm with producers).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub buffer_fill_ratio: Option<f64>,
    /// Items still expected, when the total is known up front.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub items_remaining: Option<u64>,
}
impl StreamProgress {
    /// Creates an empty snapshot for `phase` with all counters zeroed
    /// and all optional metrics unset.
    pub fn new(phase: impl Into<String>) -> Self {
        Self {
            items_generated: 0,
            items_per_second: 0.0,
            elapsed_ms: 0,
            phase: phase.into(),
            memory_usage_mb: None,
            buffer_fill_ratio: None,
            items_remaining: None,
        }
    }

    /// Records the latest counters and recomputes the average throughput.
    ///
    /// When `elapsed_ms` is zero no rate can be derived, so
    /// `items_per_second` is reset to `0.0` — previously the old value
    /// was silently kept, leaving the snapshot internally inconsistent
    /// (zero elapsed time with a non-zero rate).
    pub fn update(&mut self, items_generated: u64, elapsed_ms: u64) {
        self.items_generated = items_generated;
        self.elapsed_ms = elapsed_ms;
        self.items_per_second = if elapsed_ms > 0 {
            (items_generated as f64) / (elapsed_ms as f64 / 1000.0)
        } else {
            0.0
        };
    }

    /// Estimated time to completion, in milliseconds.
    ///
    /// Returns `None` when the remaining item count is unknown, and
    /// also when the measured rate is not positive — the old code
    /// reported `Some(0)` in that case, which wrongly claimed the
    /// stream was about to finish even though no throughput had been
    /// observed.
    pub fn eta_ms(&self) -> Option<u64> {
        let remaining = self.items_remaining?;
        if self.items_per_second > 0.0 {
            Some(((remaining as f64 / self.items_per_second) * 1000.0) as u64)
        } else {
            None
        }
    }
}
/// An error surfaced through the event stream.
///
/// Unlike a hard failure, a `StreamError` may be marked `recoverable`,
/// letting consumers decide whether to keep draining the stream.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamError {
    /// Human-readable description of what went wrong.
    pub message: String,
    /// Broad classification, usable for filtering and metrics.
    pub category: StreamErrorCategory,
    /// `true` (the default set by [`StreamError::new`]) when the stream
    /// can continue after this error.
    pub recoverable: bool,
    /// How many items were affected, when known.
    pub items_affected: Option<usize>,
}
impl StreamError {
    /// Builds a recoverable error with the given message and category;
    /// the affected-item count starts out unknown.
    pub fn new(message: impl Into<String>, category: StreamErrorCategory) -> Self {
        Self {
            message: message.into(),
            category,
            recoverable: true,
            items_affected: None,
        }
    }

    /// Marks this error as fatal for the stream (builder style).
    pub fn non_recoverable(self) -> Self {
        Self {
            recoverable: false,
            ..self
        }
    }

    /// Records how many items this error affected (builder style).
    pub fn with_affected_items(self, count: usize) -> Self {
        Self {
            items_affected: Some(count),
            ..self
        }
    }
}
impl fmt::Display for StreamError {
    /// Formats as `[Category] message`, using the category's `Debug`
    /// representation (e.g. `[Generation] boom`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "[{:?}] {}", self.category, self.message)
    }
}

// Marker impl so `StreamError` works with `?` / `Box<dyn Error>`;
// the required `Debug` + `Display` impls are provided above.
impl std::error::Error for StreamError {}
/// Coarse classification of stream errors.
///
/// Serialized in `snake_case` (e.g. `Configuration` → `"configuration"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum StreamErrorCategory {
    Configuration,
    Generation,
    Output,
    Resource,
    Validation,
    Network,
    Internal,
}
/// End-of-stream statistics delivered with [`StreamEvent::Complete`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamSummary {
    /// Total items produced over the whole run.
    pub total_items: u64,
    /// Total wall-clock duration, in milliseconds.
    pub total_time_ms: u64,
    /// Average throughput over the whole run, in items/second
    /// (`0.0` when `total_time_ms` is zero — see [`StreamSummary::new`]).
    pub avg_items_per_second: f64,
    /// Number of errors observed during the run.
    pub error_count: u64,
    /// Number of items dropped (e.g. by a drop-based backpressure strategy).
    pub dropped_count: u64,
    /// Peak memory usage in MB; omitted from serialization when unknown.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub peak_memory_mb: Option<u64>,
    /// Names of the phases that ran to completion.
    pub phases_completed: Vec<String>,
}
impl StreamSummary {
    /// Builds a summary for `total_items` produced over `total_time_ms`
    /// milliseconds. The average rate is derived from the two (and is
    /// `0.0` when no time elapsed, avoiding a division by zero); all
    /// counters start at zero and optional metrics unset.
    pub fn new(total_items: u64, total_time_ms: u64) -> Self {
        let elapsed_seconds = total_time_ms as f64 / 1000.0;
        let avg_items_per_second = match total_time_ms {
            0 => 0.0,
            _ => total_items as f64 / elapsed_seconds,
        };
        Self {
            total_items,
            total_time_ms,
            avg_items_per_second,
            error_count: 0,
            dropped_count: 0,
            peak_memory_mb: None,
            phases_completed: Vec::new(),
        }
    }
}
/// Tuning knobs for a streaming run. See the [`Default`] impl for the
/// stock values.
#[derive(Debug, Clone)]
pub struct StreamConfig {
    /// Capacity hint for the stream's internal buffering — exact use is
    /// up to the generator implementation.
    pub buffer_size: usize,
    /// Whether `Progress` events are emitted at all.
    pub enable_progress: bool,
    /// How often progress is reported (presumably every N items — units
    /// are not enforced here; confirm with the generator impl).
    pub progress_interval: u64,
    /// Policy applied when the consumer falls behind.
    pub backpressure: BackpressureStrategy,
    /// Overall deadline for the stream; `None` means unbounded.
    pub timeout: Option<Duration>,
    /// Number of items grouped into one `BatchComplete` batch.
    pub batch_size: usize,
}
impl Default for StreamConfig {
fn default() -> Self {
Self {
buffer_size: 1000,
enable_progress: true,
progress_interval: 100,
backpressure: BackpressureStrategy::Block,
timeout: None,
batch_size: 100,
}
}
}
/// Policy applied when the downstream consumer cannot keep up.
///
/// Serialized in `snake_case`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum BackpressureStrategy {
    /// Block the producer until space frees up (the default).
    #[default]
    Block,
    /// Drop the oldest pending item to make room.
    DropOldest,
    /// Drop the newly produced item.
    DropNewest,
    /// Allow buffering beyond the nominal size.
    Buffer {
        /// Maximum number of extra items beyond the configured buffer size.
        max_overflow: usize,
    },
}
use std::sync::atomic::{AtomicBool, Ordering};

/// Shared cancel/pause switchboard for a running stream.
///
/// The producer polls [`StreamControl::is_cancelled`] /
/// [`StreamControl::is_paused`] while another thread flips the flags;
/// all accesses use sequentially-consistent atomics.
#[derive(Debug)]
pub struct StreamControl {
    cancelled: AtomicBool,
    paused: AtomicBool,
}

impl StreamControl {
    /// Creates a control handle with neither flag set.
    pub fn new() -> Self {
        Self {
            cancelled: AtomicBool::new(false),
            paused: AtomicBool::new(false),
        }
    }

    /// Requests cancellation; no method here clears the flag again.
    pub fn cancel(&self) {
        self.cancelled.store(true, Ordering::SeqCst);
    }

    /// Asks the producer to pause.
    pub fn pause(&self) {
        self.paused.store(true, Ordering::SeqCst);
    }

    /// Clears the pause flag set by [`StreamControl::pause`].
    pub fn resume(&self) {
        self.paused.store(false, Ordering::SeqCst);
    }

    /// `true` once [`StreamControl::cancel`] has been called.
    pub fn is_cancelled(&self) -> bool {
        self.cancelled.load(Ordering::SeqCst)
    }

    /// `true` while the pause flag is set.
    pub fn is_paused(&self) -> bool {
        self.paused.load(Ordering::SeqCst)
    }
}

impl Default for StreamControl {
    fn default() -> Self {
        StreamControl::new()
    }
}

/// Atomics are not `Clone`, so cloning snapshots the current flag values
/// into a new, fully independent control handle.
impl Clone for StreamControl {
    fn clone(&self) -> Self {
        StreamControl {
            cancelled: AtomicBool::new(self.is_cancelled()),
            paused: AtomicBool::new(self.is_paused()),
        }
    }
}
/// A generator that produces items incrementally through a channel
/// instead of returning them all at once.
#[allow(clippy::type_complexity)]
pub trait StreamingGenerator {
    /// The produced item type; must be sendable across threads.
    type Item: Clone + Send + 'static;

    /// Starts generation, returning the receiving end of the event
    /// channel plus a shared [`StreamControl`] for cancel/pause.
    fn stream(
        &mut self,
        config: StreamConfig,
    ) -> SynthResult<(
        std::sync::mpsc::Receiver<StreamEvent<Self::Item>>,
        std::sync::Arc<StreamControl>,
    )>;

    /// Like [`StreamingGenerator::stream`], but additionally invokes
    /// `on_progress` with each progress snapshot (whether `Progress`
    /// events still also flow on the channel is up to implementors —
    /// confirm against the concrete generator).
    fn stream_with_progress<F>(
        &mut self,
        config: StreamConfig,
        on_progress: F,
    ) -> SynthResult<(
        std::sync::mpsc::Receiver<StreamEvent<Self::Item>>,
        std::sync::Arc<StreamControl>,
    )>
    where
        F: Fn(&StreamProgress) + Send + Sync + 'static;
}
/// Consumer side of a stream: receives events one at a time.
pub trait StreamingSink<T>: Send {
    /// Handles a single event; implementations decide which variants matter.
    fn process(&mut self, event: StreamEvent<T>) -> SynthResult<()>;
    /// Flushes any buffered output.
    fn flush(&mut self) -> SynthResult<()>;
    /// Consumes the sink, releasing any resources it holds.
    fn close(self) -> SynthResult<()>;
    /// Number of items handled so far.
    fn items_processed(&self) -> u64;
}
/// A [`StreamingSink`] that accumulates everything in memory: data
/// items, errors, and the final summary.
pub struct CollectorSink<T> {
    // Data items, in arrival order.
    items: Vec<T>,
    // Errors, in arrival order.
    errors: Vec<StreamError>,
    // Set once a `Complete` event has been processed.
    summary: Option<StreamSummary>,
}
impl<T> CollectorSink<T> {
    /// Creates an empty collector (no preallocation — a zero-capacity
    /// `Vec` does not touch the heap).
    pub fn new() -> Self {
        Self::with_capacity(0)
    }

    /// Creates a collector whose item buffer is preallocated for
    /// `capacity` items.
    pub fn with_capacity(capacity: usize) -> Self {
        Self {
            items: Vec::with_capacity(capacity),
            errors: Vec::new(),
            summary: None,
        }
    }

    /// Consumes the sink and returns the collected items.
    pub fn into_items(self) -> Vec<T> {
        self.items
    }

    /// Collected items, in arrival order.
    pub fn items(&self) -> &[T] {
        self.items.as_slice()
    }

    /// Errors observed so far.
    pub fn errors(&self) -> &[StreamError] {
        self.errors.as_slice()
    }

    /// The final summary, if a `Complete` event has been processed.
    pub fn summary(&self) -> Option<&StreamSummary> {
        self.summary.as_ref()
    }
}
impl<T> Default for CollectorSink<T> {
fn default() -> Self {
Self::new()
}
}
impl<T: Send> StreamingSink<T> for CollectorSink<T> {
fn process(&mut self, event: StreamEvent<T>) -> SynthResult<()> {
match event {
StreamEvent::Data(item) => {
self.items.push(item);
}
StreamEvent::Error(error) => {
self.errors.push(error);
}
StreamEvent::Complete(summary) => {
self.summary = Some(summary);
}
_ => {}
}
Ok(())
}
fn flush(&mut self) -> SynthResult<()> {
Ok(())
}
fn close(self) -> SynthResult<()> {
Ok(())
}
fn items_processed(&self) -> u64 {
self.items.len() as u64
}
}
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    #[test]
    fn test_stream_progress() {
        // 1000 items over 2 seconds -> exactly 500 items/s.
        let mut progress = StreamProgress::new("test_phase");
        progress.update(1000, 2000);
        assert_eq!(progress.items_generated, 1000);
        assert_eq!(progress.items_per_second, 500.0);
    }

    #[test]
    fn test_stream_error() {
        let error = StreamError::new("test error", StreamErrorCategory::Generation)
            .with_affected_items(5);
        assert_eq!(error.message, "test error");
        assert_eq!(error.items_affected, Some(5));
        // Errors are recoverable unless explicitly downgraded.
        assert!(error.recoverable);
    }

    #[test]
    fn test_stream_summary() {
        // 10_000 items over 5 seconds -> exactly 2000 items/s.
        let summary = StreamSummary::new(10_000, 5_000);
        assert_eq!(summary.total_items, 10_000);
        assert_eq!(summary.avg_items_per_second, 2000.0);
    }

    #[test]
    fn test_stream_control() {
        let control = StreamControl::new();
        assert!(!control.is_cancelled());
        assert!(!control.is_paused());

        control.pause();
        assert!(control.is_paused());
        control.resume();
        assert!(!control.is_paused());

        control.cancel();
        assert!(control.is_cancelled());
    }

    #[test]
    fn test_collector_sink() {
        let mut sink = CollectorSink::new();
        for n in 1..=3 {
            sink.process(StreamEvent::Data(n)).unwrap();
        }
        assert_eq!(sink.items(), &[1, 2, 3]);
        assert_eq!(sink.items_processed(), 3);
    }

    #[test]
    fn test_backpressure_strategy_default() {
        assert_eq!(BackpressureStrategy::default(), BackpressureStrategy::Block);
    }

    #[test]
    fn test_stream_config_default() {
        let config = StreamConfig::default();
        assert_eq!(config.buffer_size, 1000);
        assert!(config.enable_progress);
        assert_eq!(config.progress_interval, 100);
    }
}