#![allow(dead_code)]
use crate::utils::error::{GatewayError, Result};
use futures::{Future, StreamExt};
use std::time::Duration;
use tokio::time::{sleep, timeout};
use tracing::{debug, error, warn};
#[derive(Clone)]
pub struct ConcurrentRunner {
max_concurrent: usize,
timeout_duration: Option<Duration>,
}
impl ConcurrentRunner {
pub fn new(max_concurrent: usize) -> Self {
Self {
max_concurrent,
timeout_duration: None,
}
}
pub fn with_timeout(mut self, timeout_duration: Duration) -> Self {
self.timeout_duration = Some(timeout_duration);
self
}
pub async fn run_concurrent<F, T, E>(&self, futures: Vec<F>) -> Vec<std::result::Result<T, E>>
where
F: Future<Output = std::result::Result<T, E>> + Send + 'static,
T: Send + 'static,
E: Send + 'static,
{
let stream = futures::stream::iter(futures)
.map(|fut| async move { fut.await })
.buffer_unordered(self.max_concurrent);
stream.collect().await
}
pub async fn run_concurrent_ok<F, T, E>(&self, futures: Vec<F>) -> Vec<T>
where
F: Future<Output = std::result::Result<T, E>> + Send + 'static,
T: Send + 'static,
E: Send + 'static + std::fmt::Debug,
{
let results = self.run_concurrent(futures).await;
results
.into_iter()
.filter_map(|result| match result {
Ok(value) => Some(value),
Err(e) => {
debug!("Concurrent operation failed: {:?}", e);
None
}
})
.collect()
}
}
/// Retries a fallible async operation with capped exponential backoff.
#[derive(Clone)]
pub struct RetryPolicy {
    /// Total number of attempts, counting the first call.
    max_attempts: usize,
    /// Delay before the first retry (capped at `max_delay`).
    base_delay: Duration,
    /// Upper bound on any single delay between attempts.
    max_delay: Duration,
    /// Factor the delay is multiplied by after each retry.
    backoff_multiplier: f64,
}

impl RetryPolicy {
    /// Creates a policy allowing up to `max_attempts` attempts, with a 100 ms base
    /// delay, a 30 s delay cap and a backoff multiplier of 2.0.
    pub fn new(max_attempts: usize) -> Self {
        Self {
            max_attempts,
            base_delay: Duration::from_millis(100),
            max_delay: Duration::from_secs(30),
            backoff_multiplier: 2.0,
        }
    }

    /// Sets the delay before the first retry.
    pub fn with_base_delay(mut self, delay: Duration) -> Self {
        self.base_delay = delay;
        self
    }

    /// Sets the upper bound for any single delay between attempts.
    pub fn with_max_delay(mut self, delay: Duration) -> Self {
        self.max_delay = delay;
        self
    }

    /// Sets the factor applied to the delay after every retry.
    pub fn with_backoff_multiplier(mut self, multiplier: f64) -> Self {
        self.backoff_multiplier = multiplier;
        self
    }

    /// Calls `operation` until it succeeds or `max_attempts` attempts have been made.
    ///
    /// Between attempts the policy sleeps for the current delay. The delay starts at
    /// `base_delay` and is multiplied by `backoff_multiplier` after every retry.
    /// Every delay, including the first, is capped at `max_delay`. Each failed
    /// attempt is logged at `warn` level and the final failure at `error` level.
    ///
    /// A `max_attempts` of 0 behaves like 1: the operation is always tried once.
    ///
    /// # Errors
    /// Returns the error from the last attempt when every attempt fails.
    pub async fn execute<F, Fut, T, E>(&self, mut operation: F) -> std::result::Result<T, E>
    where
        F: FnMut() -> Fut,
        Fut: Future<Output = std::result::Result<T, E>>,
        E: std::fmt::Debug,
    {
        let mut attempt = 0;
        // Cap the first delay as well, so `max_delay` really bounds every sleep.
        let mut delay = std::cmp::min(self.base_delay, self.max_delay);
        loop {
            attempt += 1;
            match operation().await {
                Ok(result) => return Ok(result),
                Err(e) => {
                    if attempt >= self.max_attempts {
                        error!("Operation failed after {} attempts: {:?}", attempt, e);
                        return Err(e);
                    }
                    warn!(
                        "Operation failed (attempt {}/{}): {:?}. Retrying in {:?}",
                        attempt, self.max_attempts, e, delay
                    );
                    sleep(delay).await;
                    delay = self.next_delay(delay);
                }
            }
        }
    }

    /// Returns the delay that follows `current`: `current * backoff_multiplier`,
    /// clamped to `[0, max_delay]`.
    ///
    /// The maths is done in floating-point seconds so that sub-millisecond delays
    /// keep growing instead of truncating to zero. Non-positive or NaN products
    /// map to zero, and anything at or above the cap (including +inf) maps to
    /// `max_delay`, so `Duration::from_secs_f64` never sees an out-of-range value.
    fn next_delay(&self, current: Duration) -> Duration {
        let scaled = current.as_secs_f64() * self.backoff_multiplier;
        if scaled.is_nan() || scaled <= 0.0 {
            Duration::ZERO
        } else if scaled >= self.max_delay.as_secs_f64() {
            self.max_delay
        } else {
            Duration::from_secs_f64(scaled)
        }
    }
}
/// Feeds items to an async batch handler in consecutive fixed-size chunks.
pub struct BatchProcessor {
    /// Maximum number of items handed to the processor per call (0 is treated as 1).
    batch_size: usize,
    /// Not currently read by [`BatchProcessor::process`].
    flush_interval: Duration,
}

impl BatchProcessor {
    /// Creates a processor that hands out batches of up to `batch_size` items.
    pub fn new(batch_size: usize, flush_interval: Duration) -> Self {
        Self {
            batch_size,
            flush_interval,
        }
    }

    /// Splits `items` into consecutive batches of at most `batch_size` items and
    /// awaits `processor` on each batch in turn (sequentially, in input order).
    ///
    /// A successful batch contributes its returned values, in order. A failed
    /// batch contributes one clone of its error per item in that batch. A
    /// processor that returns exactly one value per item therefore keeps the
    /// output aligned with `items`. If it returns a different count, that
    /// alignment is lost.
    ///
    /// A `batch_size` of 0 is treated as 1 rather than panicking. Empty input
    /// never calls `processor` and yields an empty vector.
    pub async fn process<T, F, Fut, R, E>(
        &self,
        items: Vec<T>,
        processor: F,
    ) -> Vec<std::result::Result<R, E>>
    where
        T: Clone,
        F: Fn(Vec<T>) -> Fut + Clone,
        Fut: Future<Output = std::result::Result<Vec<R>, E>>,
        E: Clone,
    {
        // `slice::chunks(0)` would panic; a zero size degrades to one item per batch.
        let batch_size = self.batch_size.max(1);
        let mut results = Vec::with_capacity(items.len());
        let mut remaining = items.into_iter();
        loop {
            // Move items into the batch rather than cloning them.
            let batch: Vec<T> = remaining.by_ref().take(batch_size).collect();
            if batch.is_empty() {
                break;
            }
            let batch_len = batch.len();
            match processor(batch).await {
                Ok(values) => results.extend(values.into_iter().map(Ok)),
                Err(e) => results.extend(std::iter::repeat(e).take(batch_len).map(Err)),
            }
        }
        results
    }
}
pub struct GracefulShutdown {
shutdown_timeout: Duration,
}
impl GracefulShutdown {
pub fn new(shutdown_timeout: Duration) -> Self {
Self { shutdown_timeout }
}
pub async fn wait_for_shutdown<F, Fut>(&self, cleanup: F) -> Result<()>
where
F: FnOnce() -> Fut,
Fut: Future<Output = Result<()>>,
{
tokio::signal::ctrl_c().await.map_err(|e| {
GatewayError::Internal(format!("Failed to listen for shutdown signal: {}", e))
})?;
debug!("Shutdown signal received, starting graceful shutdown");
match timeout(self.shutdown_timeout, cleanup()).await {
Ok(Ok(())) => {
debug!("Graceful shutdown completed successfully");
Ok(())
}
Ok(Err(e)) => {
error!("Error during graceful shutdown: {}", e);
Err(e)
}
Err(_) => {
error!(
"Graceful shutdown timed out after {:?}",
self.shutdown_timeout
);
Err(GatewayError::Timeout(
"Graceful shutdown timed out".to_string(),
))
}
}
}
}
/// Builds the default runner: at most 10 concurrent futures, configured with a
/// 30 second timeout (see `ConcurrentRunner::with_timeout` for how it is applied).
pub fn default_concurrent_runner() -> ConcurrentRunner {
    let configured_timeout = Duration::from_secs(30);
    let runner = ConcurrentRunner::new(10);
    runner.with_timeout(configured_timeout)
}
/// Builds the default retry policy: 3 attempts, a 100 ms initial delay that
/// doubles after each retry, with every delay capped at 5 seconds.
pub fn default_retry_policy() -> RetryPolicy {
    let initial_delay = Duration::from_millis(100);
    let delay_cap = Duration::from_secs(5);
    let growth = 2.0;
    RetryPolicy::new(3)
        .with_base_delay(initial_delay)
        .with_max_delay(delay_cap)
        .with_backoff_multiplier(growth)
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};

    #[tokio::test]
    async fn test_concurrent_runner() {
        let runner = ConcurrentRunner::new(2);
        let counter = Arc::new(AtomicUsize::new(0));
        // Track how many futures run simultaneously to verify the concurrency cap.
        let in_flight = Arc::new(AtomicUsize::new(0));
        let peak = Arc::new(AtomicUsize::new(0));
        let futures: Vec<_> = (0..5)
            .map(|_| {
                let counter = counter.clone();
                let in_flight = in_flight.clone();
                let peak = peak.clone();
                async move {
                    let running = in_flight.fetch_add(1, Ordering::SeqCst) + 1;
                    peak.fetch_max(running, Ordering::SeqCst);
                    counter.fetch_add(1, Ordering::SeqCst);
                    tokio::time::sleep(Duration::from_millis(10)).await;
                    in_flight.fetch_sub(1, Ordering::SeqCst);
                    Ok::<_, GatewayError>(())
                }
            })
            .collect();
        let results = runner.run_concurrent(futures).await;
        assert_eq!(results.len(), 5);
        assert!(results.iter().all(|r| r.is_ok()));
        assert_eq!(counter.load(Ordering::SeqCst), 5);
        assert!(peak.load(Ordering::SeqCst) <= 2);
    }

    #[tokio::test]
    async fn test_retry_policy() {
        let policy = RetryPolicy::new(3);
        let counter = Arc::new(AtomicUsize::new(0));
        let result = policy
            .execute(|| {
                let counter = counter.clone();
                async move {
                    let count = counter.fetch_add(1, Ordering::SeqCst);
                    if count < 2 {
                        Err("temporary failure")
                    } else {
                        Ok("success")
                    }
                }
            })
            .await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "success");
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_retry_policy_gives_up_after_max_attempts() {
        let policy = RetryPolicy::new(2).with_base_delay(Duration::from_millis(1));
        let counter = Arc::new(AtomicUsize::new(0));
        let result: std::result::Result<(), &str> = policy
            .execute(|| {
                let counter = counter.clone();
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                    Err::<(), &str>("permanent failure")
                }
            })
            .await;
        assert_eq!(result, Err("permanent failure"));
        assert_eq!(counter.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn test_batch_processor() {
        let processor = BatchProcessor::new(2, Duration::from_millis(10));
        let results = processor
            .process(vec![1, 2, 3, 4, 5], |batch: Vec<i32>| async move {
                if batch.contains(&3) {
                    Err::<Vec<i32>, String>("bad batch".to_string())
                } else {
                    Ok(batch.into_iter().map(|x| x * 10).collect())
                }
            })
            .await;
        assert_eq!(
            results,
            vec![
                Ok(10),
                Ok(20),
                Err("bad batch".to_string()),
                Err("bad batch".to_string()),
                Ok(50)
            ]
        );
    }
}