use std::sync::Arc;
use std::sync::atomic::{AtomicU8, Ordering};
use std::time::{Duration, Instant};
use tracing::warn;
/// Observer notified with the current progress fraction (clamped to `0.0..=1.0`).
///
/// Boxed so trackers can hold any closure; `Send + Sync` so a tracker and its
/// clones can be shared across threads.
pub type StreamingProgressCallback = Box<dyn Fn(f32) + Sync + Send + 'static>;
/// Tracks how far a streaming operation has progressed relative to its timeout.
///
/// Cloning is cheap: every clone shares the same start time, last reported
/// percentage, and callback, because all of them live behind `Arc`.
#[derive(Clone)]
pub struct StreamingProgressTracker {
    /// Duration against which elapsed time is measured.
    total_timeout: Duration,
    /// Progress fraction (`0.0..=1.0`) at or above which the operation is
    /// considered to be approaching its timeout.
    warning_threshold: f32,
    /// Moment the tracker was created; shared between clones.
    start_time: Arc<Instant>,
    /// Highest whole-number percentage reported so far (`0..=100`); shared
    /// between clones.
    last_reported_progress: Arc<AtomicU8>,
    /// Optional observer invoked whenever the reported percentage increases.
    callback: Option<Arc<StreamingProgressCallback>>,
}
impl StreamingProgressTracker {
    /// Creates a tracker for an operation expected to finish within
    /// `total_timeout`.
    ///
    /// The clock starts now, no callback is installed, and the warning
    /// threshold defaults to `0.8` (80% of the timeout).
    pub fn new(total_timeout: Duration) -> Self {
        Self {
            callback: None,
            warning_threshold: 0.8,
            total_timeout,
            start_time: Arc::new(Instant::now()),
            last_reported_progress: Arc::new(AtomicU8::new(0)),
        }
    }

    /// Installs `callback`. It receives the progress fraction (`0.0..=1.0`)
    /// each time the reported whole-number percentage increases.
    pub fn with_callback(mut self, callback: StreamingProgressCallback) -> Self {
        self.callback = Some(Arc::new(callback));
        self
    }

    /// Sets the progress fraction at or above which the operation counts as
    /// approaching its timeout. Values outside `0.0..=1.0` are clamped.
    pub fn with_warning_threshold(mut self, threshold: f32) -> Self {
        self.warning_threshold = threshold.clamp(0.0, 1.0);
        self
    }

    /// Reports a fixed 10% progress mark for the first received chunk.
    pub fn report_first_chunk(&self) {
        self.report_progress(0.1);
    }

    /// Reports a received chunk, computing progress from the wall-clock time
    /// elapsed since this tracker was created.
    pub fn report_chunk_received(&self) {
        let elapsed = self.start_time.elapsed();
        self.report_progress_with_elapsed(elapsed);
    }

    /// Reports progress as `elapsed / total_timeout`.
    ///
    /// The result is capped at 99%, so that only
    /// [`report_error`](Self::report_error) can reach 100%.
    ///
    /// Does nothing when the timeout is zero, because the ratio is undefined.
    pub fn report_progress_with_elapsed(&self, elapsed: Duration) {
        // `is_zero` rather than `as_secs() == 0`: a sub-second timeout such as
        // 500ms has `as_secs() == 0` but is still a valid, non-zero divisor.
        if self.total_timeout.is_zero() {
            return;
        }
        let progress = elapsed.as_secs_f32() / self.total_timeout.as_secs_f32();
        self.report_progress(progress.min(0.99));
    }

    /// Reports 100% progress.
    ///
    /// This is the only public path to 100%, since elapsed-based reports are
    /// capped at 99%.
    pub fn report_error(&self) {
        self.report_progress(1.0);
    }

    /// Returns the highest whole-number percentage (`0..=100`) reported so far.
    pub fn progress_percent(&self) -> u8 {
        self.last_reported_progress.load(Ordering::Relaxed)
    }

    /// Returns the wall-clock time elapsed since this tracker was created.
    pub fn elapsed(&self) -> Duration {
        self.start_time.elapsed()
    }

    /// Returns `true` once the warning threshold is reached.
    ///
    /// The check uses the larger of two values: the wall-clock elapsed fraction
    /// of the timeout, or the last reported progress.
    ///
    /// Always returns `false` for a zero timeout.
    pub fn is_approaching_timeout(&self) -> bool {
        // Same zero check as `report_progress_with_elapsed`, so sub-second
        // timeouts can still trigger the warning.
        if self.total_timeout.is_zero() {
            return false;
        }
        let elapsed_progress =
            self.start_time.elapsed().as_secs_f32() / self.total_timeout.as_secs_f32();
        let reported_progress =
            f32::from(self.last_reported_progress.load(Ordering::Relaxed)) / 100.0;
        elapsed_progress.max(reported_progress) >= self.warning_threshold
    }

    /// Clamps `progress` to `0.0..=1.0` and converts it to a whole-number
    /// percentage.
    ///
    /// If that percentage is higher than any reported before, this method:
    /// - records it,
    /// - invokes the callback,
    /// - logs a warning when the fraction is at or above the warning threshold
    ///   but below 100%.
    fn report_progress(&self, progress: f32) {
        let progress_clamped = progress.clamp(0.0, 1.0);
        // Truncating cast is intended; the value is already within 0.0..=100.0.
        let percent = (progress_clamped * 100.0) as u8;
        // Check-and-update must be one atomic step. With a separate load and
        // store, two clones sharing this counter could both pass the "is it
        // higher?" check and then store out of order, moving progress
        // backwards. Relaxed suffices: the counter does not guard other data.
        let previous = self
            .last_reported_progress
            .fetch_max(percent, Ordering::Relaxed);
        if percent <= previous {
            return;
        }
        // NOTE: concurrent reporters may still invoke the callback out of
        // order; only the stored percentage is guaranteed to be monotonic.
        if let Some(callback) = &self.callback {
            callback(progress_clamped);
        }
        if progress_clamped >= self.warning_threshold && progress_clamped < 1.0 {
            warn!(
                "Streaming operation at {:.0}% of timeout limit ({:?}/{:?} elapsed). Approaching timeout.",
                progress_clamped * 100.0,
                self.elapsed(),
                self.total_timeout
            );
        }
    }
}
/// Builder that assembles a [`StreamingProgressTracker`] from a timeout, an
/// optional callback, and a warning threshold.
pub struct StreamingProgressBuilder {
    /// Timeout against which progress will be measured.
    total_timeout: Duration,
    /// Fraction of the timeout (`0.0..=1.0`) treated as "approaching timeout".
    warning_threshold: f32,
    /// Observer handed to the tracker when it is built, if any.
    callback: Option<StreamingProgressCallback>,
}
impl StreamingProgressBuilder {
    /// Starts a builder for a timeout of `timeout_secs` whole seconds.
    ///
    /// No callback is set and the warning threshold defaults to `0.8`.
    pub fn new(timeout_secs: u64) -> Self {
        Self::with_duration(Duration::from_secs(timeout_secs))
    }

    /// Starts a builder for an arbitrary `duration` timeout.
    ///
    /// No callback is set and the warning threshold defaults to `0.8`.
    pub fn with_duration(duration: Duration) -> Self {
        Self {
            total_timeout: duration,
            callback: None,
            warning_threshold: 0.8,
        }
    }

    /// Sets the callback the built tracker will invoke on progress increases.
    pub fn callback(self, callback: StreamingProgressCallback) -> Self {
        Self {
            callback: Some(callback),
            ..self
        }
    }

    /// Sets the warning threshold, clamped to `0.0..=1.0`.
    pub fn warning_threshold(self, threshold: f32) -> Self {
        Self {
            warning_threshold: threshold.clamp(0.0, 1.0),
            ..self
        }
    }

    /// Consumes the builder and returns a tracker whose clock starts now.
    pub fn build(self) -> StreamingProgressTracker {
        let Self {
            total_timeout,
            callback,
            warning_threshold,
        } = self;
        let mut tracker = StreamingProgressTracker::new(total_timeout);
        tracker.callback = callback.map(Arc::new);
        tracker.warning_threshold = warning_threshold;
        tracker
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Mutex;

    /// Builds a tracker over `timeout` whose callback appends every reported
    /// fraction to the returned log.
    fn recording_tracker(timeout: Duration) -> (StreamingProgressTracker, Arc<Mutex<Vec<f32>>>) {
        let log = Arc::new(Mutex::new(Vec::new()));
        let sink = Arc::clone(&log);
        let tracker = StreamingProgressTracker::new(timeout).with_callback(Box::new(
            move |progress: f32| sink.lock().unwrap().push(progress),
        ));
        (tracker, log)
    }

    #[test]
    fn test_progress_tracker_creation() {
        let tracker = StreamingProgressTracker::new(Duration::from_secs(600));
        assert_eq!(tracker.progress_percent(), 0);
        assert!(!tracker.is_approaching_timeout());
    }

    #[test]
    fn test_progress_reporting() {
        let tracker = StreamingProgressTracker::new(Duration::from_secs(100));
        tracker.report_progress_with_elapsed(Duration::from_secs(30));
        assert_eq!(tracker.progress_percent(), 30);
        tracker.report_progress_with_elapsed(Duration::from_secs(80));
        assert_eq!(tracker.progress_percent(), 80);
    }

    #[test]
    fn test_progress_never_regresses() {
        let tracker = StreamingProgressTracker::new(Duration::from_secs(100));
        tracker.report_progress_with_elapsed(Duration::from_secs(80));
        tracker.report_progress_with_elapsed(Duration::from_secs(30));
        assert_eq!(tracker.progress_percent(), 80);
    }

    #[test]
    fn test_warning_threshold() {
        let tracker =
            StreamingProgressTracker::new(Duration::from_secs(100)).with_warning_threshold(0.8);
        tracker.report_progress_with_elapsed(Duration::from_secs(50));
        assert!(!tracker.is_approaching_timeout());
        tracker.report_progress_with_elapsed(Duration::from_secs(85));
        assert!(tracker.is_approaching_timeout());
    }

    #[test]
    fn test_callback_execution() {
        let (tracker, log) = recording_tracker(Duration::from_secs(100));
        tracker.report_progress_with_elapsed(Duration::from_secs(30));
        tracker.report_progress_with_elapsed(Duration::from_secs(60));
        tracker.report_progress_with_elapsed(Duration::from_secs(90));
        // Pin the exact sequence rather than only "non-empty and in range".
        assert_eq!(*log.lock().unwrap(), vec![0.3, 0.6, 0.9]);
    }

    #[test]
    fn test_callback_skipped_when_percent_does_not_increase() {
        let (tracker, log) = recording_tracker(Duration::from_secs(100));
        tracker.report_progress_with_elapsed(Duration::from_secs(30));
        tracker.report_progress_with_elapsed(Duration::from_secs(30));
        tracker.report_progress_with_elapsed(Duration::from_secs(10));
        assert_eq!(log.lock().unwrap().len(), 1);
    }

    #[test]
    fn test_clones_share_progress() {
        let tracker = StreamingProgressTracker::new(Duration::from_secs(100));
        let clone = tracker.clone();
        clone.report_progress_with_elapsed(Duration::from_secs(40));
        assert_eq!(tracker.progress_percent(), 40);
    }

    #[test]
    fn test_builder_pattern() {
        let tracker = StreamingProgressBuilder::new(300)
            .warning_threshold(0.75)
            .build();
        assert_eq!(tracker.total_timeout.as_secs(), 300);
        assert_eq!(tracker.warning_threshold, 0.75);
    }

    #[test]
    fn test_builder_installs_callback() {
        let log = Arc::new(Mutex::new(Vec::new()));
        let sink = Arc::clone(&log);
        let tracker = StreamingProgressBuilder::with_duration(Duration::from_secs(10))
            .callback(Box::new(move |progress: f32| {
                sink.lock().unwrap().push(progress)
            }))
            .build();
        tracker.report_first_chunk();
        assert_eq!(*log.lock().unwrap(), vec![0.1]);
    }

    #[test]
    fn test_zero_timeout_safety() {
        let tracker = StreamingProgressTracker::new(Duration::from_secs(0));
        tracker.report_chunk_received();
        assert!(!tracker.is_approaching_timeout());
    }

    #[test]
    fn test_progress_clamping() {
        let tracker = StreamingProgressTracker::new(Duration::from_secs(100));
        tracker.report_progress_with_elapsed(Duration::from_secs(150));
        assert_eq!(tracker.progress_percent(), 99);
        tracker.report_error();
        assert_eq!(tracker.progress_percent(), 100);
    }
}