use crate::Result;
use serde::{Deserialize, Serialize};
use std::sync::{
atomic::{AtomicBool, AtomicU64, Ordering},
Arc,
};
use std::time::{Duration, Instant};
use tokio::sync::{mpsc, Semaphore};
/// Tuning knobs for a [`BatchProcessor`] run.
#[derive(Debug, Clone)]
pub struct BatchOptions {
/// Upper bound on how many items `BatchProcessor::process` works on at once
/// (used to size the semaphore that gates item tasks).
pub concurrency: usize,
/// When `true`, an item whose closure returns an error sets the processor's
/// cancellation flag, so items that have not started yet are skipped.
pub fail_fast: bool,
/// Maximum time a single item may take; an item exceeding it is recorded as
/// a `"Timeout"` failure.
pub item_timeout: Duration,
/// Optional time limit for the whole batch.
/// NOTE(review): `BatchProcessor::process` in this file never reads this
/// field — confirm whether it is enforced elsewhere.
pub batch_timeout: Option<Duration>,
/// Period between progress snapshots sent on the progress channel.
pub progress_interval: Duration,
}
impl Default for BatchOptions {
fn default() -> Self {
Self {
concurrency: 10,
fail_fast: false,
item_timeout: Duration::from_secs(30),
batch_timeout: None,
progress_interval: Duration::from_secs(1),
}
}
}
/// Point-in-time snapshot of a running batch, as sent on the progress channel
/// by `BatchProcessor::process`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchProgress {
/// Number of items in the batch.
pub total: u64,
/// Items that have finished successfully so far.
pub completed: u64,
/// Items that have finished with an error or timed out so far.
pub failed: u64,
/// Items that have been handed to a task and have not finished yet.
pub in_progress: u64,
/// Milliseconds elapsed since the batch started.
pub elapsed_ms: u64,
/// Estimated milliseconds until the remaining items finish; `None` while
/// the measured rate is zero.
pub estimated_remaining_ms: Option<u64>,
/// Successfully completed items per second of elapsed time.
pub items_per_second: f64,
}
impl BatchProgress {
    /// Share of the batch that has completed successfully, in percent.
    ///
    /// Only `completed` counts towards the result; failed items do not.
    /// An empty batch (`total == 0`) is reported as `100.0`.
    pub fn percentage(&self) -> f64 {
        match self.total {
            0 => 100.0,
            total => {
                let ratio = self.completed as f64 / total as f64;
                ratio * 100.0
            }
        }
    }
}
/// Outcome of a single batch item.
#[derive(Debug, Clone)]
pub enum BatchItemResult<T> {
/// The item's closure returned `Ok`; carries the produced value.
Success(T),
/// The item failed; carries the error message (the closure's error rendered
/// as text, or `"Timeout"` when the per-item timeout elapsed).
Failure(String),
/// The item was not run; carries the reason (`"Cancelled"` when the
/// cancellation flag was already set).
Skipped(String),
}
/// Aggregated outcome of one `BatchProcessor::process` call.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchResult<T> {
/// Values produced by the items that succeeded (indices are not retained).
pub successes: Vec<T>,
/// `(item index, error message)` pairs for failed items.
pub failures: Vec<(usize, String)>,
/// `(item index, reason)` pairs for items that were skipped.
pub skipped: Vec<(usize, String)>,
/// Wall-clock duration of the whole batch in milliseconds.
pub total_time_ms: u64,
/// Whether the processor's cancellation flag was set when the batch ended.
pub cancelled: bool,
}
impl<T> BatchResult<T> {
    /// Returns `true` when no item failed and the batch was not cancelled.
    pub fn is_all_success(&self) -> bool {
        !(self.cancelled || !self.failures.is_empty())
    }

    /// Number of items that succeeded.
    pub fn success_count(&self) -> usize {
        let Self { successes, .. } = self;
        successes.len()
    }

    /// Number of items that failed.
    pub fn failure_count(&self) -> usize {
        let Self { failures, .. } = self;
        failures.len()
    }
}
/// Runs a batch of items through an async closure with bounded concurrency,
/// per-item timeouts and cooperative cancellation.
pub struct BatchProcessor {
/// Settings applied to every `process` call.
options: BatchOptions,
/// Cancellation flag; shared through the `Arc` with the tasks that
/// `process` spawns.
cancelled: Arc<AtomicBool>,
}
impl BatchProcessor {
pub fn new(options: BatchOptions) -> Self {
Self {
options,
cancelled: Arc::new(AtomicBool::new(false)),
}
}
pub fn with_defaults() -> Self {
Self::new(BatchOptions::default())
}
pub fn cancel(&self) {
self.cancelled.store(true, Ordering::SeqCst);
}
pub fn is_cancelled(&self) -> bool {
self.cancelled.load(Ordering::SeqCst)
}
pub async fn process<T, R, F, Fut>(
&self,
items: Vec<T>,
f: F,
progress_callback: Option<mpsc::Sender<BatchProgress>>,
) -> BatchResult<R>
where
T: Send + 'static,
R: Send + 'static,
F: Fn(T) -> Fut + Send + Sync + Clone + 'static,
Fut: std::future::Future<Output = Result<R>> + Send,
{
let start = Instant::now();
let total = items.len() as u64;
let completed = Arc::new(AtomicU64::new(0));
let failed = Arc::new(AtomicU64::new(0));
let in_progress = Arc::new(AtomicU64::new(0));
let semaphore = Arc::new(Semaphore::new(self.options.concurrency));
let cancelled = Arc::clone(&self.cancelled);
let fail_fast = self.options.fail_fast;
let progress_completed = Arc::clone(&completed);
let progress_failed = Arc::clone(&failed);
let progress_in_progress = Arc::clone(&in_progress);
let progress_interval = self.options.progress_interval;
if let Some(tx) = progress_callback.as_ref() {
let tx = tx.clone();
let cancelled = Arc::clone(&cancelled);
tokio::spawn(async move {
let mut interval = tokio::time::interval(progress_interval);
loop {
interval.tick().await;
if cancelled.load(Ordering::SeqCst) {
break;
}
let comp = progress_completed.load(Ordering::SeqCst);
let fail = progress_failed.load(Ordering::SeqCst);
let prog = progress_in_progress.load(Ordering::SeqCst);
let elapsed = start.elapsed().as_millis() as u64;
let items_per_second = if elapsed > 0 {
(comp as f64 / elapsed as f64) * 1000.0
} else {
0.0
};
let remaining = total.saturating_sub(comp + fail);
let estimated_remaining_ms = if items_per_second > 0.0 {
Some((remaining as f64 / items_per_second * 1000.0) as u64)
} else {
None
};
let progress = BatchProgress {
total,
completed: comp,
failed: fail,
in_progress: prog,
elapsed_ms: elapsed,
estimated_remaining_ms,
items_per_second,
};
if tx.send(progress).await.is_err() {
break;
}
if comp + fail >= total {
break;
}
}
});
}
let mut handles = Vec::with_capacity(items.len());
for (idx, item) in items.into_iter().enumerate() {
let permit = semaphore.clone().acquire_owned().await.unwrap();
let f = f.clone();
let completed = Arc::clone(&completed);
let failed = Arc::clone(&failed);
let in_progress = Arc::clone(&in_progress);
let cancelled = Arc::clone(&cancelled);
let timeout = self.options.item_timeout;
in_progress.fetch_add(1, Ordering::SeqCst);
let handle = tokio::spawn(async move {
let _permit = permit;
if cancelled.load(Ordering::SeqCst) {
in_progress.fetch_sub(1, Ordering::SeqCst);
return (idx, BatchItemResult::Skipped("Cancelled".to_string()));
}
let result = match tokio::time::timeout(timeout, f(item)).await {
Ok(Ok(value)) => {
completed.fetch_add(1, Ordering::SeqCst);
BatchItemResult::Success(value)
}
Ok(Err(e)) => {
failed.fetch_add(1, Ordering::SeqCst);
if fail_fast {
cancelled.store(true, Ordering::SeqCst);
}
BatchItemResult::Failure(e.to_string())
}
Err(_) => {
failed.fetch_add(1, Ordering::SeqCst);
BatchItemResult::Failure("Timeout".to_string())
}
};
in_progress.fetch_sub(1, Ordering::SeqCst);
(idx, result)
});
handles.push(handle);
}
let mut successes = Vec::new();
let mut failures = Vec::new();
let mut skipped = Vec::new();
for handle in handles {
match handle.await {
Ok((_idx, BatchItemResult::Success(value))) => {
successes.push(value);
}
Ok((idx, BatchItemResult::Failure(err))) => {
failures.push((idx, err));
}
Ok((idx, BatchItemResult::Skipped(reason))) => {
skipped.push((idx, reason));
}
Err(e) => {
failures.push((0, format!("Task panic: {}", e)));
}
}
}
BatchResult {
successes,
failures,
skipped,
total_time_ms: start.elapsed().as_millis() as u64,
cancelled: self.cancelled.load(Ordering::SeqCst),
}
}
}
/// Serializable request describing a batch of messages and how to run it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchMessageRequest {
/// Messages contained in the batch.
pub messages: Vec<BatchMessage>,
/// Batch options; when the field is absent during deserialization,
/// `BatchMessageOptions::default()` is used.
#[serde(default)]
pub options: BatchMessageOptions,
}
/// One message inside a [`BatchMessageRequest`].
///
/// NOTE(review): nothing in this file reads these fields; the descriptions
/// below are inferred from names and types — confirm against the handler.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchMessage {
/// Optional identifier for the message.
pub id: Option<String>,
/// Message text.
pub content: String,
/// Identifier of the room the message belongs to (presumably).
pub room_id: String,
/// Identifier of the entity associated with the message (presumably the sender).
pub entity_id: String,
/// Optional free-form JSON attached to the message.
pub metadata: Option<serde_json::Value>,
}
/// Per-request batch options.
///
/// `Default` is implemented by hand so that `BatchMessageOptions::default()`
/// (used when a request omits `options` entirely) agrees with the per-field
/// serde defaults (used when `options` is present but a field is omitted).
/// A derived `Default` would yield `concurrency == 0`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchMessageOptions {
    /// Maximum number of messages processed at once; defaults to
    /// `default_concurrency()` when omitted.
    #[serde(default = "default_concurrency")]
    pub concurrency: usize,
    /// Whether the batch should stop at the first failure; defaults to
    /// `false` when omitted.
    #[serde(default)]
    pub fail_fast: bool,
}

impl Default for BatchMessageOptions {
    fn default() -> Self {
        Self {
            concurrency: default_concurrency(),
            fail_fast: false,
        }
    }
}

/// Concurrency used when a request does not specify one.
fn default_concurrency() -> usize {
    5
}
/// Serializable response for a batch of messages.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchMessageResponse {
/// Per-message results.
pub results: Vec<BatchMessageResult>,
/// Per-message errors.
pub errors: Vec<BatchMessageError>,
/// Aggregate counters and timings for the batch.
pub stats: BatchStats,
}
/// Result entry for a single message of a batch.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchMessageResult {
/// Identifier of the message this entry refers to.
pub message_id: String,
/// Response text produced for the message.
pub response: String,
/// Processing time in milliseconds (presumably for this message alone — TODO confirm).
pub processing_time_ms: u64,
}
/// Error entry for a single message of a batch.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchMessageError {
/// Identifier of the message this entry refers to.
pub message_id: String,
/// Error description.
pub error: String,
}
/// Aggregate counters and timings for a batch.
///
/// NOTE(review): nothing in this file populates this struct; the field
/// descriptions are inferred from names — confirm against the producer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchStats {
/// Number of messages in the batch.
pub total: usize,
/// Number of messages that succeeded.
pub successful: usize,
/// Number of messages that failed.
pub failed: usize,
/// Duration of the whole batch in milliseconds.
pub total_time_ms: u64,
/// Average time in milliseconds (presumably per message — TODO confirm).
pub avg_time_ms: u64,
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ZoeyError;

    /// Every item succeeds, so the batch reports full success and the
    /// doubled values show up among the successes.
    #[tokio::test]
    async fn test_batch_processor_success() {
        let batch = BatchProcessor::with_defaults();
        let inputs = vec![1, 2, 3, 4, 5];

        let outcome = batch
            .process(inputs, |n| async move { Ok(n * 2) }, None)
            .await;

        assert!(outcome.is_all_success());
        assert_eq!(outcome.success_count(), 5);
        assert!(outcome.successes.contains(&2));
        assert!(outcome.successes.contains(&10));
    }

    /// One failing item is reported as a failure while the others succeed.
    #[tokio::test]
    async fn test_batch_processor_with_failures() {
        let batch = BatchProcessor::with_defaults();
        let inputs = vec![1, 2, 3, 4, 5];

        let outcome = batch
            .process(
                inputs,
                |n| async move {
                    match n {
                        3 => Err(ZoeyError::other("Test error")),
                        other => Ok(other * 2),
                    }
                },
                None,
            )
            .await;

        assert!(!outcome.is_all_success());
        assert_eq!(outcome.success_count(), 4);
        assert_eq!(outcome.failure_count(), 1);
    }

    /// With `fail_fast` and a single worker, a failing item marks the batch
    /// as cancelled.
    #[tokio::test]
    async fn test_batch_processor_fail_fast() {
        let batch = BatchProcessor::new(BatchOptions {
            concurrency: 1,
            fail_fast: true,
            ..Default::default()
        });
        let inputs = vec![1, 2, 3, 4, 5];

        let outcome = batch
            .process(
                inputs,
                |n| async move {
                    match n {
                        2 => Err(ZoeyError::other("Test error")),
                        other => Ok(other * 2),
                    }
                },
                None,
            )
            .await;

        assert!(outcome.cancelled);
    }

    /// `percentage` is based on completed items only.
    #[tokio::test]
    async fn test_batch_progress_calculation() {
        let snapshot = BatchProgress {
            total: 100,
            completed: 50,
            failed: 10,
            in_progress: 5,
            elapsed_ms: 1000,
            estimated_remaining_ms: Some(800),
            items_per_second: 50.0,
        };

        assert_eq!(snapshot.percentage(), 50.0);
    }
}