use std::{
fmt::Display,
sync::Arc,
time::{Duration, Instant},
};
use moduforge_state::debug;
use tokio::sync::{mpsc, oneshot};
use async_trait::async_trait;
use tokio::select;
/// Lifecycle state of a submitted task, as reported in [`TaskResult::status`].
#[derive(Clone, Debug, PartialEq)]
pub enum TaskStatus {
    /// Waiting in the queue (not currently emitted by [`AsyncProcessor`]).
    Pending,
    /// Being executed (not currently emitted by [`AsyncProcessor`]).
    Processing,
    /// Finished successfully.
    Completed,
    /// Failed once retries were used up; carries the error message.
    Failed(String),
    /// An attempt exceeded the configured per-attempt time limit.
    Timeout,
    /// Cancelled before completion (not currently emitted by [`AsyncProcessor`]).
    Cancelled,
}
/// Errors produced while submitting or processing tasks.
#[derive(Debug)]
pub enum ProcessorError {
    /// The task could not be placed on the task queue.
    QueueFull,
    /// A processing attempt reported a failure; carries the failure message.
    TaskFailed(String),
    /// The processor machinery itself failed (e.g. signalling or joining the
    /// background worker); carries a description.
    InternalError(String),
    /// Task execution exceeded its time limit.
    TaskTimeout,
    /// The task was cancelled.
    TaskCancelled,
    /// Retry attempts were exhausted; carries a descriptive message.
    RetryExhausted(String),
}

impl Display for ProcessorError {
    /// Renders a short, human-readable description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::QueueFull => f.write_str("Task queue is full"),
            Self::TaskFailed(msg) => write!(f, "Task failed: {}", msg),
            Self::InternalError(msg) => write!(f, "Internal error: {}", msg),
            Self::TaskTimeout => f.write_str("Task execution timed out"),
            Self::TaskCancelled => f.write_str("Task was cancelled"),
            Self::RetryExhausted(msg) => write!(f, "Retry attempts exhausted: {}", msg),
        }
    }
}

/// Lets `ProcessorError` be used as a standard error (`Box<dyn Error>`, `?`).
impl std::error::Error for ProcessorError {}
/// Tunables for [`AsyncProcessor`] and its [`TaskQueue`].
#[derive(Debug, Clone)]
pub struct ProcessorConfig {
    /// Capacity of the bounded task queue; submitting waits while it is full.
    pub max_queue_size: usize,
    /// Maximum number of tasks the background worker runs at the same time.
    pub max_concurrent_tasks: usize,
    /// Time limit for a single processing attempt. An attempt that exceeds it
    /// ends the task as timed out; timeouts are not retried.
    pub task_timeout: Duration,
    /// How many times a failed attempt is retried before the task is reported
    /// as failed.
    pub max_retries: u32,
    /// Pause between a failed attempt and its retry.
    pub retry_delay: Duration,
}

impl Default for ProcessorConfig {
    /// 1000-slot queue, 10 concurrent tasks, 30 s per-attempt timeout and
    /// 3 retries spaced 1 s apart.
    fn default() -> Self {
        let one_second = Duration::from_secs(1);
        Self {
            max_concurrent_tasks: 10,
            max_queue_size: 1_000,
            max_retries: 3,
            task_timeout: one_second * 30,
            retry_delay: one_second,
        }
    }
}
/// Snapshot of the processor's counters, as returned by `get_stats`.
#[derive(Clone, Debug, Default)]
pub struct ProcessorStats {
    /// Tasks accepted by the queue so far.
    pub total_tasks: u64,
    /// Tasks that finished successfully.
    pub completed_tasks: u64,
    /// Tasks that failed once retries were exhausted.
    pub failed_tasks: u64,
    /// Tasks whose attempt exceeded the time limit.
    pub timeout_tasks: u64,
    /// Tasks reported as cancelled.
    pub cancelled_tasks: u64,
    /// Running average of the processing time of completed tasks, as
    /// maintained by `TaskQueue::update_stats`.
    pub average_processing_time: Duration,
    /// Tasks enqueued but not yet picked up by the worker.
    pub current_queue_size: usize,
    /// Tasks picked up by the worker whose outcome has not been recorded yet.
    pub current_processing_tasks: usize,
}
/// Outcome of one task, delivered on the receiver returned at submission.
#[derive(Debug)]
pub struct TaskResult<T: Send + Sync, O: Send + Sync> {
    /// Id assigned to the task when it was enqueued.
    pub task_id: u64,
    /// Final status of the task.
    pub status: TaskStatus,
    /// The task value handed back to the caller.
    pub task: Option<T>,
    /// Processor output; only present for completed tasks.
    pub output: Option<O>,
    /// Error description for tasks that did not complete.
    pub error: Option<String>,
    /// Wall-clock time from the first attempt to the final outcome,
    /// including any retries and retry delays.
    pub processing_time: Option<Duration>,
}
/// A task sitting in the queue together with everything needed to run it and
/// report its outcome.
struct QueuedTask<T: Send + Sync, O: Send + Sync> {
    /// The user's task payload.
    task: T,
    /// Id returned to the submitter.
    task_id: u64,
    /// Where the single [`TaskResult`] for this task is sent.
    result_tx: mpsc::Sender<TaskResult<T, O>>,
    /// Priority given at submission (carried along; does not affect ordering).
    priority: u32,
    /// Number of attempts already used before this one.
    retry_count: u32,
}
pub struct TaskQueue<T, O>
where
T: Send + Sync,
O: Send + Sync,
{
queue: mpsc::Sender<QueuedTask<T, O>>,
queue_rx: Arc<tokio::sync::Mutex<Option<mpsc::Receiver<QueuedTask<T, O>>>>>,
next_task_id: Arc<tokio::sync::Mutex<u64>>,
stats: Arc<tokio::sync::Mutex<ProcessorStats>>,
}
impl<T: Clone + Send + Sync + 'static, O: Clone + Send + Sync + 'static>
TaskQueue<T, O>
{
pub fn new(config: &ProcessorConfig) -> Self {
let (tx, rx) = mpsc::channel(config.max_queue_size);
Self {
queue: tx,
queue_rx: Arc::new(tokio::sync::Mutex::new(Some(rx))),
next_task_id: Arc::new(tokio::sync::Mutex::new(0)),
stats: Arc::new(tokio::sync::Mutex::new(ProcessorStats::default())),
}
}
pub async fn enqueue_task(
&self,
task: T,
priority: u32,
) -> Result<(u64, mpsc::Receiver<TaskResult<T, O>>), ProcessorError> {
let mut task_id = self.next_task_id.lock().await;
*task_id += 1;
let current_id = *task_id;
let (result_tx, result_rx) = mpsc::channel(1);
let queued_task = QueuedTask {
task,
task_id: current_id,
result_tx,
priority,
retry_count: 0,
};
self.queue
.send(queued_task)
.await
.map_err(|_| ProcessorError::QueueFull)?;
let mut stats = self.stats.lock().await;
stats.total_tasks += 1;
stats.current_queue_size += 1;
Ok((current_id, result_rx))
}
pub async fn get_next_ready(
&self
) -> Option<(T, u64, mpsc::Sender<TaskResult<T, O>>, u32, u32)> {
let mut rx_guard = self.queue_rx.lock().await;
if let Some(rx) = rx_guard.as_mut() {
if let Some(queued) = rx.recv().await {
let mut stats = self.stats.lock().await;
stats.current_queue_size -= 1;
stats.current_processing_tasks += 1;
return Some((
queued.task,
queued.task_id,
queued.result_tx,
queued.priority,
queued.retry_count,
));
}
}
None
}
pub async fn get_stats(&self) -> ProcessorStats {
self.stats.lock().await.clone()
}
pub async fn update_stats(
&self,
result: &TaskResult<T, O>,
) {
let mut stats = self.stats.lock().await;
match result.status {
TaskStatus::Completed => {
stats.completed_tasks += 1;
if let Some(processing_time) = result.processing_time {
stats.average_processing_time =
(stats.average_processing_time + processing_time) / 2;
}
},
TaskStatus::Failed(_) => stats.failed_tasks += 1,
TaskStatus::Timeout => stats.timeout_tasks += 1,
TaskStatus::Cancelled => stats.cancelled_tasks += 1,
_ => {},
}
stats.current_processing_tasks -= 1;
}
}
#[async_trait]
pub trait TaskProcessor<T, O>: Send + Sync + 'static
where
T: Clone + Send + Sync + 'static,
O: Clone + Send + Sync + 'static,
{
async fn process(
&self,
task: T,
) -> Result<O, ProcessorError>;
}
pub struct AsyncProcessor<T, O, P>
where
T: Clone + Send + Sync + 'static,
O: Clone + Send + Sync + 'static,
P: TaskProcessor<T, O>,
{
task_queue: Arc<TaskQueue<T, O>>,
config: ProcessorConfig,
processor: Arc<P>,
shutdown_tx: Option<oneshot::Sender<()>>,
handle: Option<tokio::task::JoinHandle<()>>,
}
impl<T, O, P> AsyncProcessor<T, O, P>
where
T: Clone + Send + Sync + 'static,
O: Clone + Send + Sync + 'static,
P: TaskProcessor<T, O>,
{
pub fn new(
config: ProcessorConfig,
processor: P,
) -> Self {
let task_queue = Arc::new(TaskQueue::new(&config));
Self {
task_queue,
config,
processor: Arc::new(processor),
shutdown_tx: None,
handle: None,
}
}
pub async fn submit_task(
&self,
task: T,
priority: u32,
) -> Result<(u64, mpsc::Receiver<TaskResult<T, O>>), ProcessorError> {
self.task_queue.enqueue_task(task, priority).await
}
pub fn start(&mut self) {
let queue = self.task_queue.clone();
let processor = self.processor.clone();
let config = self.config.clone();
let (shutdown_tx, mut shutdown_rx) = oneshot::channel();
self.shutdown_tx = Some(shutdown_tx);
let handle = tokio::spawn(async move {
let mut join_set = tokio::task::JoinSet::new();
loop {
select! {
_ = &mut shutdown_rx => {
break;
}
Some(result) = join_set.join_next() => {
if let Err(e) = result {
debug!("Task failed: {}", e);
}
}
Some((task, task_id, result_tx, _priority, retry_count)) = queue.get_next_ready() => {
if join_set.len() < config.max_concurrent_tasks {
let processor = processor.clone();
let config = config.clone();
let queue = queue.clone();
join_set.spawn(async move {
let start_time = Instant::now();
let mut current_retry = retry_count;
loop {
let result = tokio::time::timeout(
config.task_timeout,
processor.process(task.clone())
).await;
match result {
Ok(Ok(output)) => {
let processing_time = start_time.elapsed();
let task_result = TaskResult {
task_id,
status: TaskStatus::Completed,
task: Some(task),
output: Some(output),
error: None,
processing_time: Some(processing_time),
};
queue.update_stats(&task_result).await;
let _ = result_tx.send(task_result).await;
break;
}
Ok(Err(e)) => {
if current_retry < config.max_retries {
current_retry += 1;
tokio::time::sleep(config.retry_delay).await;
continue;
}
let task_result = TaskResult {
task_id,
status: TaskStatus::Failed(e.to_string()),
task: Some(task),
output: None,
error: Some(e.to_string()),
processing_time: Some(start_time.elapsed()),
};
queue.update_stats(&task_result).await;
let _ = result_tx.send(task_result).await;
break;
}
Err(_) => {
let task_result = TaskResult {
task_id,
status: TaskStatus::Timeout,
task: Some(task),
output: None,
error: Some("Task execution timed out".to_string()),
processing_time: Some(start_time.elapsed()),
};
queue.update_stats(&task_result).await;
let _ = result_tx.send(task_result).await;
break;
}
}
}
});
}
}
}
}
});
self.handle = Some(handle);
}
pub async fn shutdown(&mut self) -> Result<(), ProcessorError> {
if let Some(shutdown_tx) = self.shutdown_tx.take() {
shutdown_tx.send(()).map_err(|_| {
ProcessorError::InternalError(
"Failed to send shutdown signal".to_string(),
)
})?;
if let Some(handle) = self.handle.take() {
handle.await.map_err(|e| {
ProcessorError::InternalError(format!(
"Failed to join processor task: {}",
e
))
})?;
}
}
Ok(())
}
pub async fn get_stats(&self) -> ProcessorStats {
self.task_queue.get_stats().await
}
}
impl<T, O, P> Drop for AsyncProcessor<T, O, P>
where
    T: Clone + Send + Sync + 'static,
    O: Clone + Send + Sync + 'static,
    P: TaskProcessor<T, O>,
{
    /// Signals the background worker to stop, without waiting for it.
    ///
    /// Blocking here is not an option. Creating a runtime and calling
    /// `block_on` panics when the processor is dropped inside an async
    /// context, which is the common case. The previous `unwrap` would also
    /// panic inside `drop` if the worker was already gone. Instead, the
    /// worker's `JoinHandle` is simply dropped with `self`, which detaches
    /// (does not abort) the worker. Call [`AsyncProcessor::shutdown`] to wait
    /// for it deterministically.
    fn drop(&mut self) {
        if let Some(shutdown_tx) = self.shutdown_tx.take() {
            // The worker may already have exited; then there is nothing to stop.
            let _ = shutdown_tx.send(());
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Upper bound on waiting for any single result, so a lost task fails the
    /// test instead of hanging it.
    const RESULT_WAIT: Duration = Duration::from_secs(5);

    /// Succeeds after a short delay, echoing the task in its output.
    struct TestProcessor;

    #[async_trait::async_trait]
    impl TaskProcessor<i32, String> for TestProcessor {
        async fn process(
            &self,
            task: i32,
        ) -> Result<String, ProcessorError> {
            tokio::time::sleep(Duration::from_millis(100)).await;
            Ok(format!("Processed: {}", task))
        }
    }

    /// Always fails, counting how often it was invoked.
    struct FailingProcessor {
        attempts: Arc<AtomicU32>,
    }

    #[async_trait::async_trait]
    impl TaskProcessor<i32, String> for FailingProcessor {
        async fn process(
            &self,
            _task: i32,
        ) -> Result<String, ProcessorError> {
            self.attempts.fetch_add(1, Ordering::SeqCst);
            Err(ProcessorError::TaskFailed("boom".to_string()))
        }
    }

    /// Takes far longer than the timeouts used in the tests.
    struct SlowProcessor;

    #[async_trait::async_trait]
    impl TaskProcessor<i32, String> for SlowProcessor {
        async fn process(
            &self,
            task: i32,
        ) -> Result<String, ProcessorError> {
            tokio::time::sleep(Duration::from_secs(10)).await;
            Ok(task.to_string())
        }
    }

    /// Configuration shared by the tests.
    fn test_config() -> ProcessorConfig {
        ProcessorConfig {
            max_queue_size: 100,
            max_concurrent_tasks: 5,
            task_timeout: Duration::from_secs(1),
            max_retries: 3,
            retry_delay: Duration::from_secs(1),
        }
    }

    /// Receives one result, failing loudly on timeout or a closed channel.
    async fn recv_result(
        rx: &mut mpsc::Receiver<TaskResult<i32, String>>,
    ) -> TaskResult<i32, String> {
        tokio::time::timeout(RESULT_WAIT, rx.recv())
            .await
            .expect("timed out waiting for a task result")
            .expect("result channel closed without delivering a result")
    }

    #[tokio::test]
    async fn test_async_processor() {
        let mut processor = AsyncProcessor::new(test_config(), TestProcessor);
        processor.start();
        let mut receivers = Vec::new();
        // More tasks than `max_concurrent_tasks`, so some must wait for a slot.
        for i in 0..10 {
            let (_, rx) = processor.submit_task(i, 0).await.unwrap();
            receivers.push(rx);
        }
        for (i, mut rx) in receivers.into_iter().enumerate() {
            let result = recv_result(&mut rx).await;
            assert_eq!(result.status, TaskStatus::Completed);
            assert!(result.error.is_none());
            assert_eq!(result.output, Some(format!("Processed: {}", i)));
        }
        let stats = processor.get_stats().await;
        assert_eq!(stats.total_tasks, 10);
        assert_eq!(stats.completed_tasks, 10);
        assert_eq!(stats.current_queue_size, 0);
        assert_eq!(stats.current_processing_tasks, 0);
    }

    #[tokio::test]
    async fn test_processor_shutdown() {
        let mut processor = AsyncProcessor::new(test_config(), TestProcessor);
        processor.start();
        let mut receivers = Vec::new();
        for i in 0..5 {
            let (_, rx) = processor.submit_task(i, 0).await.unwrap();
            receivers.push(rx);
        }
        processor.shutdown().await.unwrap();
        for mut rx in receivers {
            let result = recv_result(&mut rx).await;
            assert_eq!(result.status, TaskStatus::Completed);
        }
    }

    #[tokio::test]
    async fn test_processor_auto_shutdown() {
        let mut processor = AsyncProcessor::new(test_config(), TestProcessor);
        processor.start();
        let mut receivers = Vec::new();
        for i in 0..5 {
            let (_, rx) = processor.submit_task(i, 0).await.unwrap();
            receivers.push(rx);
        }
        drop(processor);
        for mut rx in receivers {
            let result = recv_result(&mut rx).await;
            assert_eq!(result.status, TaskStatus::Completed);
        }
    }

    #[tokio::test]
    async fn test_failed_task_is_retried() {
        let attempts = Arc::new(AtomicU32::new(0));
        let config = ProcessorConfig {
            max_retries: 2,
            retry_delay: Duration::from_millis(10),
            ..test_config()
        };
        let mut processor = AsyncProcessor::new(
            config,
            FailingProcessor {
                attempts: Arc::clone(&attempts),
            },
        );
        processor.start();
        let (_, mut rx) = processor.submit_task(7, 0).await.unwrap();
        let result = recv_result(&mut rx).await;
        assert_eq!(
            result.status,
            TaskStatus::Failed("Task failed: boom".to_string())
        );
        assert_eq!(result.error.as_deref(), Some("Task failed: boom"));
        assert_eq!(result.task, Some(7));
        assert!(result.output.is_none());
        // One initial attempt plus `max_retries` retries.
        assert_eq!(attempts.load(Ordering::SeqCst), 3);
        assert_eq!(processor.get_stats().await.failed_tasks, 1);
        processor.shutdown().await.unwrap();
    }

    #[tokio::test]
    async fn test_timed_out_task() {
        let config = ProcessorConfig {
            task_timeout: Duration::from_millis(50),
            ..test_config()
        };
        let mut processor = AsyncProcessor::new(config, SlowProcessor);
        processor.start();
        let (_, mut rx) = processor.submit_task(1, 0).await.unwrap();
        let result = recv_result(&mut rx).await;
        assert_eq!(result.status, TaskStatus::Timeout);
        assert_eq!(result.error.as_deref(), Some("Task execution timed out"));
        assert!(result.output.is_none());
        assert_eq!(processor.get_stats().await.timeout_tasks, 1);
        processor.shutdown().await.unwrap();
    }
}