use super::processor::StreamProcessor;
use super::types::StreamSource;
use anyhow::Result;
use async_trait::async_trait;
use std::collections::VecDeque;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::Mutex;
use tokio::time;
/// What [`BufferedStreamProcessor`] does when a line arrives while its buffer is full.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OverflowStrategy {
    /// Discard the oldest queued line to make room for the incoming one.
    DropOldest,
    /// Discard the incoming line and leave the buffer unchanged.
    DropNewest,
    /// Wait, up to the processor's block timeout, for space to free up.
    Block,
    /// Reject the incoming line with an error.
    Fail,
}
/// A [`StreamProcessor`] wrapper that queues lines in a bounded buffer before
/// forwarding them to an inner processor.
pub struct BufferedStreamProcessor {
    /// Processor that receives every line that is not dropped.
    inner: Box<dyn StreamProcessor>,
    /// Queued lines; new lines are pushed at the back and forwarded from the front.
    buffer: Arc<Mutex<VecDeque<BufferedLine>>>,
    /// Queue length at which the overflow strategy is applied.
    max_buffer_size: usize,
    /// Policy applied when a line arrives while the queue is full.
    overflow_strategy: OverflowStrategy,
    /// Upper bound on how long the `Block` strategy waits for space.
    block_timeout: Duration,
}
/// A single queued line together with the source tag it arrived with.
#[derive(Clone)]
struct BufferedLine {
    /// Owned copy of the line text, so it can outlive the caller's borrow.
    line: String,
    /// Source tag handed to the inner processor alongside the line.
    source: StreamSource,
}
impl BufferedStreamProcessor {
    /// Creates a processor that queues up to `max_buffer_size` lines in front of `inner`.
    ///
    /// `overflow_strategy` decides what happens when a line arrives while the queue is
    /// full; `block_timeout` bounds how long [`OverflowStrategy::Block`] waits for space.
    pub fn new(
        inner: Box<dyn StreamProcessor>,
        max_buffer_size: usize,
        overflow_strategy: OverflowStrategy,
        block_timeout: Duration,
    ) -> Self {
        Self {
            inner,
            buffer: Arc::new(Mutex::new(VecDeque::new())),
            max_buffer_size,
            overflow_strategy,
            block_timeout,
        }
    }

    /// Appends `line` to the queue, applying the overflow strategy if the queue is full.
    ///
    /// `DropOldest` and `DropNewest` never fail; they log a warning and discard a line.
    ///
    /// # Errors
    ///
    /// - [`OverflowStrategy::Fail`]: errors as soon as the queue is full.
    /// - [`OverflowStrategy::Block`]: errors if no space frees up within `block_timeout`
    ///   (space is re-checked every 10 ms).
    pub async fn process_with_backpressure(
        &self,
        line: String,
        source: StreamSource,
    ) -> Result<()> {
        let mut buffer = self.buffer.lock().await;
        if buffer.len() < self.max_buffer_size {
            buffer.push_back(BufferedLine { line, source });
            return Ok(());
        }
        match self.overflow_strategy {
            OverflowStrategy::DropOldest => {
                buffer.pop_front();
                buffer.push_back(BufferedLine { line, source });
                tracing::warn!("Buffer overflow: dropped oldest line");
            }
            OverflowStrategy::DropNewest => {
                tracing::warn!("Buffer overflow: dropped newest line");
            }
            OverflowStrategy::Block => {
                // Release the lock while waiting so other tasks can drain the queue.
                drop(buffer);
                let start = std::time::Instant::now();
                while start.elapsed() < self.block_timeout {
                    time::sleep(Duration::from_millis(10)).await;
                    let mut buffer = self.buffer.lock().await;
                    if buffer.len() < self.max_buffer_size {
                        buffer.push_back(BufferedLine { line, source });
                        return Ok(());
                    }
                }
                anyhow::bail!("Buffer overflow: timeout waiting for space");
            }
            OverflowStrategy::Fail => {
                anyhow::bail!(
                    "Buffer overflow: max size {} reached",
                    self.max_buffer_size
                );
            }
        }
        Ok(())
    }

    /// Drains the queue and forwards every line, oldest first, to the inner processor.
    ///
    /// The lock is released before forwarding, so new lines may be queued concurrently.
    ///
    /// # Errors
    ///
    /// Returns the first error reported by the inner processor. Lines that came after
    /// the failing one are put back at the front of the queue, in their original order,
    /// rather than being discarded, so a later flush can still deliver them. The failing
    /// line itself is not re-queued.
    pub async fn flush(&self) -> Result<()> {
        let drained: Vec<BufferedLine> = self.buffer.lock().await.drain(..).collect();
        let mut pending = drained.into_iter();
        while let Some(buffered) = pending.next() {
            let outcome = self
                .inner
                .process_line(&buffered.line, buffered.source)
                .await;
            if let Err(err) = outcome {
                let mut buffer = self.buffer.lock().await;
                // Push in reverse so the unsent lines land at the front in original order.
                for unsent in pending.rev() {
                    buffer.push_front(unsent);
                }
                return Err(err);
            }
        }
        Ok(())
    }

    /// Returns the number of lines currently queued.
    pub async fn buffer_size(&self) -> usize {
        self.buffer.lock().await.len()
    }
}
#[async_trait]
impl StreamProcessor for BufferedStreamProcessor {
async fn process_line(&self, line: &str, source: StreamSource) -> Result<()> {
self.process_with_backpressure(line.to_string(), source)
.await?;
let mut buffer = self.buffer.lock().await;
if let Some(buffered) = buffer.pop_front() {
drop(buffer); self.inner
.process_line(&buffered.line, buffered.source)
.await?;
}
Ok(())
}
async fn on_complete(&self, exit_code: Option<i32>) -> Result<()> {
self.flush().await?;
self.inner.on_complete(exit_code).await
}
async fn on_error(&self, error: &anyhow::Error) -> Result<()> {
let _ = self.flush().await;
self.inner.on_error(error).await
}
}
/// A [`StreamProcessor`] wrapper that throttles how quickly lines reach an inner
/// processor.
pub struct RateLimitedProcessor {
    /// Processor that receives every line after any rate-limit delay.
    inner: Box<dyn StreamProcessor>,
    /// Configured number of lines allowed per one-second window.
    max_lines_per_second: usize,
    /// Start instant of the current one-second rate window.
    last_process_time: Arc<Mutex<std::time::Instant>>,
    /// Number of lines counted against the current window.
    lines_processed: Arc<Mutex<usize>>,
}
impl RateLimitedProcessor {
    /// Wraps `inner` with a per-second line limit of `max_lines_per_second`.
    ///
    /// The first rate window starts at the moment of construction, with no lines counted.
    pub fn new(inner: Box<dyn StreamProcessor>, max_lines_per_second: usize) -> Self {
        let window_start = Arc::new(Mutex::new(std::time::Instant::now()));
        let counted_in_window = Arc::new(Mutex::new(0));
        Self {
            inner,
            max_lines_per_second,
            last_process_time: window_start,
            lines_processed: counted_in_window,
        }
    }
}
#[async_trait]
impl StreamProcessor for RateLimitedProcessor {
    /// Forwards `line` to the inner processor, delaying it when needed so that at most
    /// `max_lines_per_second` lines start in any one-second window.
    ///
    /// Each call reserves a slot in the earliest window that still has capacity, then
    /// sleeps until that window opens. The reservation is made under the lock, not by
    /// resetting the counter after sleeping, for two reasons:
    /// - a delayed line counts against the window it runs in;
    /// - concurrent waiters cannot all claim the same fresh window.
    ///
    /// A limit of `0` is treated as one line per second, so callers never block forever.
    async fn process_line(&self, line: &str, source: StreamSource) -> Result<()> {
        const WINDOW: Duration = Duration::from_secs(1);
        let limit = self.max_lines_per_second.max(1);

        let delay = {
            // Always lock the window start before the counter.
            let mut window_start = self.last_process_time.lock().await;
            let mut count = self.lines_processed.lock().await;
            let now = std::time::Instant::now();

            // The current window has ended: open a fresh one at `now`.
            if now.saturating_duration_since(*window_start) >= WINDOW {
                *window_start = now;
                *count = 0;
            }
            // The current (or already-reserved future) window is full: move to the next.
            if *count >= limit {
                *window_start += WINDOW;
                *count = 0;
            }
            *count += 1;
            // Zero when the reserved window has already started.
            window_start.saturating_duration_since(now)
        };

        if !delay.is_zero() {
            time::sleep(delay).await;
        }
        self.inner.process_line(line, source).await
    }

    /// Forwards completion to the inner processor without rate limiting.
    async fn on_complete(&self, exit_code: Option<i32>) -> Result<()> {
        self.inner.on_complete(exit_code).await
    }

    /// Forwards `error` to the inner processor without rate limiting.
    async fn on_error(&self, error: &anyhow::Error) -> Result<()> {
        self.inner.on_error(error).await
    }
}