use std::{
fmt::Display,
sync::Arc,
time::{Duration, Instant},
};
use crate::{error::error_utils, config::ProcessorConfig, debug::debug};
use tokio::sync::{mpsc, oneshot};
use async_trait::async_trait;
use tokio::select;
use crate::{metrics, ForgeResult};
/// Shared, lockable slot holding the (optional) receiving half of the task channel.
///
/// The async mutex lets whichever caller is currently pulling work borrow the
/// receiver mutably through a shared `Arc`.
type QueueReceiver<T, O> =
Arc<tokio::sync::Mutex<Option<mpsc::Receiver<QueuedTask<T, O>>>>>;
/// Outcome (or lifecycle stage) reported for a task in a `TaskResult`.
#[derive(Debug, Clone, PartialEq)]
pub enum TaskStatus {
    /// Task has not been picked up yet. Not assigned anywhere in this file.
    Pending,
    /// Task is being processed. Not assigned anywhere in this file.
    Processing,
    /// The processor returned `Ok` within the configured timeout.
    Completed,
    /// The processor returned `Err` and no retries were left; carries the error text.
    Failed(String),
    /// The processing attempt exceeded the configured task timeout.
    Timeout,
    /// The task was dequeued while the processor was no longer running.
    Cancelled,
}
impl From<&TaskStatus> for &'static str {
fn from(status: &TaskStatus) -> Self {
match status {
TaskStatus::Pending => "pending",
TaskStatus::Processing => "processing",
TaskStatus::Completed => "completed",
TaskStatus::Failed(_) => "failed",
TaskStatus::Timeout => "timeout",
TaskStatus::Cancelled => "cancelled",
}
}
}
/// Errors surfaced by the processor and by `TaskProcessor` implementations.
#[derive(Debug)]
pub enum ProcessorError {
    /// The task queue is full.
    QueueFull,
    /// A task failed; carries a description of the failure.
    TaskFailed(String),
    /// The processor itself is in an unexpected state; carries a description.
    InternalError(String),
    /// A task ran longer than allowed.
    TaskTimeout,
    /// A task was cancelled.
    TaskCancelled,
    /// All retry attempts were used up; carries a description.
    RetryExhausted(String),
}
impl Display for ProcessorError {
fn fmt(
&self,
f: &mut std::fmt::Formatter<'_>,
) -> std::fmt::Result {
match self {
ProcessorError::QueueFull => write!(f, "任务队列已满"),
ProcessorError::TaskFailed(msg) => {
write!(f, "任务执行失败: {msg}")
},
ProcessorError::InternalError(msg) => {
write!(f, "内部错误: {msg}")
},
ProcessorError::TaskTimeout => {
write!(f, "任务执行超时")
},
ProcessorError::TaskCancelled => write!(f, "任务被取消"),
ProcessorError::RetryExhausted(msg) => {
write!(f, "重试次数耗尽: {msg}")
},
}
}
}
/// Lets `ProcessorError` be used as a standard error type; all trait defaults are kept.
impl std::error::Error for ProcessorError {}
/// Counters describing the queue and the tasks that passed through it.
///
/// All counters start at zero (`Default`).
#[derive(Debug, Default, Clone)]
pub struct ProcessorStats {
    /// Number of tasks successfully enqueued.
    pub total_tasks: u64,
    /// Number of results recorded with status `Completed`.
    pub completed_tasks: u64,
    /// Number of results recorded with status `Failed`.
    pub failed_tasks: u64,
    /// Number of results recorded with status `Timeout`.
    pub timeout_tasks: u64,
    /// Number of results recorded with status `Cancelled`.
    pub cancelled_tasks: u64,
    /// Tasks enqueued but not yet handed out to a consumer.
    pub current_queue_size: usize,
    /// Tasks handed out whose result has not been recorded yet.
    pub current_processing_tasks: usize,
}
/// Final report for one submitted task, delivered through the task's result channel.
#[derive(Debug)]
pub struct TaskResult<T, O>
where
    T: Send + Sync,
    O: Send + Sync,
{
    /// Identifier handed out when the task was enqueued.
    pub task_id: u64,
    /// How the task ended.
    pub status: TaskStatus,
    /// The original task value; every producer in this file sets it to `Some`.
    pub task: Option<T>,
    /// The processor's output; `Some` only for completed tasks.
    pub output: Option<O>,
    /// Error description; `None` for completed tasks.
    pub error: Option<String>,
    /// Time spent on the task, including retries and retry delays.
    pub processing_time: Option<Duration>,
}
/// Envelope travelling through the internal channel from `enqueue_task` to the consumer.
struct QueuedTask<T, O>
where
    T: Send + Sync,
    O: Send + Sync,
{
    /// The payload to process.
    task: T,
    /// Identifier assigned at enqueue time.
    task_id: u64,
    /// Where the final `TaskResult` for this task is sent.
    result_tx: mpsc::Sender<TaskResult<T, O>>,
    /// Caller-supplied priority. It is carried along but does not influence ordering in
    /// this file: tasks leave the channel in the order they entered it.
    priority: u32,
    /// Retries already consumed; `enqueue_task` always starts it at 0.
    retry_count: u32,
}
/// Bounded task queue built on a tokio mpsc channel, with id allocation and statistics.
pub struct TaskQueue<T, O>
where
    T: Send + Sync,
    O: Send + Sync,
{
    /// Sending half of the task channel.
    queue: mpsc::Sender<QueuedTask<T, O>>,
    /// Receiving half of the task channel, shared behind an async mutex.
    queue_rx: QueueReceiver<T, O>,
    /// Last task id handed out; the first task receives id 1.
    next_task_id: Arc<tokio::sync::Mutex<u64>>,
    /// Running counters, readable through `get_stats`.
    stats: Arc<tokio::sync::Mutex<ProcessorStats>>,
}
impl<T: Clone + Send + Sync + 'static, O: Clone + Send + Sync + 'static>
TaskQueue<T, O>
{
pub fn new(config: &ProcessorConfig) -> Self {
let (tx, rx) = mpsc::channel(config.max_queue_size);
Self {
queue: tx,
queue_rx: Arc::new(tokio::sync::Mutex::new(Some(rx))),
next_task_id: Arc::new(tokio::sync::Mutex::new(0)),
stats: Arc::new(tokio::sync::Mutex::new(ProcessorStats::default())),
}
}
pub async fn enqueue_task(
&self,
task: T,
priority: u32,
) -> ForgeResult<(u64, mpsc::Receiver<TaskResult<T, O>>)> {
let mut task_id = self.next_task_id.lock().await;
*task_id += 1;
let current_id = *task_id;
let (result_tx, result_rx) = mpsc::channel(1);
let queued_task = QueuedTask {
task,
task_id: current_id,
result_tx,
priority,
retry_count: 0,
};
self.queue
.send(queued_task)
.await
.map_err(|_| error_utils::resource_exhausted_error("任务队列"))?;
let mut stats = self.stats.lock().await;
stats.total_tasks += 1;
stats.current_queue_size += 1;
metrics::task_submitted();
metrics::set_queue_size(stats.current_queue_size);
Ok((current_id, result_rx))
}
pub async fn get_next_ready(
&self
) -> Option<(T, u64, mpsc::Sender<TaskResult<T, O>>, u32, u32)> {
let mut rx_guard = self.queue_rx.lock().await;
if let Some(rx) = rx_guard.as_mut() {
if let Some(queued) = rx.recv().await {
let mut stats: tokio::sync::MutexGuard<'_, ProcessorStats> =
self.stats.lock().await;
stats.current_queue_size -= 1;
stats.current_processing_tasks += 1;
metrics::set_queue_size(stats.current_queue_size);
metrics::increment_processing_tasks();
return Some((
queued.task,
queued.task_id,
queued.result_tx,
queued.priority,
queued.retry_count,
));
}
}
None
}
pub async fn get_stats(&self) -> ProcessorStats {
self.stats.lock().await.clone()
}
pub async fn update_stats(
&self,
result: &TaskResult<T, O>,
) {
let mut stats = self.stats.lock().await;
stats.current_processing_tasks -= 1;
metrics::decrement_processing_tasks();
let status_str: &'static str = (&result.status).into();
metrics::task_processed(status_str);
if let Some(duration) = result.processing_time {
metrics::task_processing_duration(duration);
}
match result.status {
TaskStatus::Completed => {
stats.completed_tasks += 1;
},
TaskStatus::Failed(_) => stats.failed_tasks += 1,
TaskStatus::Timeout => stats.timeout_tasks += 1,
TaskStatus::Cancelled => stats.cancelled_tasks += 1,
_ => {},
}
}
}
/// User-supplied logic that turns a task of type `T` into an output of type `O`.
#[async_trait]
pub trait TaskProcessor<T, O>: Send + Sync + 'static
where
    T: Clone + Send + Sync + 'static,
    O: Clone + Send + Sync + 'static,
{
    /// Processes one task.
    ///
    /// The processor in this file runs each call under the configured task timeout and
    /// may call `process` again with a clone of the same task when it returns `Err`.
    async fn process(&self, task: T) -> Result<O, ProcessorError>;
}
/// Lifecycle of an `AsyncProcessor`.
#[derive(Debug, Clone, PartialEq)]
pub enum ProcessorState {
    /// Created, `start` has not been called yet.
    NotStarted,
    /// The background dispatcher is accepting and running tasks.
    Running,
    /// Shutdown was requested; running tasks are being wound down.
    Shutting,
    /// The background dispatcher has finished.
    Shutdown,
}
/// Runs submitted tasks concurrently on a background tokio task using a `TaskProcessor`.
pub struct AsyncProcessor<T, O, P>
where
    T: Clone + Send + Sync + 'static,
    O: Clone + Send + Sync + 'static,
    P: TaskProcessor<T, O>,
{
    /// Queue shared between submitters and the background dispatcher.
    task_queue: Arc<TaskQueue<T, O>>,
    /// Limits and timings (queue size, concurrency, timeout, retries).
    config: ProcessorConfig,
    /// The user-supplied processing logic.
    processor: Arc<P>,
    /// Sender used to request shutdown; `Some` once `start` has run, taken when used.
    shutdown_tx: Option<oneshot::Sender<()>>,
    /// Join handle of the background dispatcher; `Some` once `start` has run.
    handle: Option<tokio::task::JoinHandle<()>>,
    /// Current lifecycle state, shared with the background dispatcher.
    state: Arc<tokio::sync::Mutex<ProcessorState>>,
}
impl<T, O, P> AsyncProcessor<T, O, P>
where
T: Clone + Send + Sync + 'static,
O: Clone + Send + Sync + 'static,
P: TaskProcessor<T, O>,
{
pub fn new(
config: ProcessorConfig,
processor: P,
) -> Self {
let task_queue = Arc::new(TaskQueue::new(&config));
Self {
task_queue,
config,
processor: Arc::new(processor),
shutdown_tx: None,
handle: None,
state: Arc::new(tokio::sync::Mutex::new(
ProcessorState::NotStarted,
)),
}
}
pub async fn submit_task(
&self,
task: T,
priority: u32,
) -> ForgeResult<(u64, mpsc::Receiver<TaskResult<T, O>>)> {
self.task_queue.enqueue_task(task, priority).await
}
pub async fn start(&mut self) -> Result<(), ProcessorError> {
let mut state = self.state.lock().await;
if *state != ProcessorState::NotStarted {
return Err(ProcessorError::InternalError(
"处理器已经启动或正在关闭".to_string(),
));
}
*state = ProcessorState::Running;
drop(state);
let queue = self.task_queue.clone();
let processor = self.processor.clone();
let config = self.config.clone();
let state_ref = self.state.clone();
let (shutdown_tx, mut shutdown_rx) = oneshot::channel();
self.shutdown_tx = Some(shutdown_tx);
let handle = tokio::spawn(async move {
let mut join_set = tokio::task::JoinSet::new();
async fn cleanup_tasks(
join_set: &mut tokio::task::JoinSet<()>,
timeout: Duration,
) {
debug!("开始清理正在运行的任务...");
let cleanup_start = Instant::now();
while !join_set.is_empty() {
if cleanup_start.elapsed() > timeout {
debug!("清理超时,强制中止剩余任务");
join_set.abort_all();
break;
}
if let Some(Err(e)) = join_set.join_next().await {
if !e.is_cancelled() {
debug!("任务执行失败: {}", e);
}
}
}
debug!("任务清理完成");
}
loop {
select! {
_ = &mut shutdown_rx => {
debug!("收到关闭信号,开始优雅关闭");
{
let mut state = state_ref.lock().await;
*state = ProcessorState::Shutting;
}
cleanup_tasks(&mut join_set, Duration::from_secs(30)).await;
break;
}
Some(result) = join_set.join_next() => {
if let Err(e) = result {
if !e.is_cancelled() {
debug!("任务执行失败: {}", e);
}
}
}
Some((task, task_id, result_tx, _priority, retry_count)) = queue.get_next_ready() => {
{
let state = state_ref.lock().await;
if *state != ProcessorState::Running {
let task_result = TaskResult {
task_id,
status: TaskStatus::Cancelled,
task: Some(task),
output: None,
error: Some("处理器正在关闭".to_string()),
processing_time: Some(Duration::from_millis(0)),
};
queue.update_stats(&task_result).await;
let _ = result_tx.send(task_result).await;
continue;
}
}
if join_set.len() < config.max_concurrent_tasks {
let processor = processor.clone();
let config = config.clone();
let queue = queue.clone();
join_set.spawn(async move {
let start_time = Instant::now();
let mut current_retry = retry_count;
loop {
let result = tokio::time::timeout(
config.task_timeout,
processor.process(task.clone())
).await;
match result {
Ok(Ok(output)) => {
let processing_time = start_time.elapsed();
let task_result = TaskResult {
task_id,
status: TaskStatus::Completed,
task: Some(task),
output: Some(output),
error: None,
processing_time: Some(processing_time),
};
queue.update_stats(&task_result).await;
let _ = result_tx.send(task_result).await;
break;
}
Ok(Err(e)) => {
if current_retry < config.max_retries {
current_retry += 1;
tokio::time::sleep(config.retry_delay).await;
continue;
}
let task_result = TaskResult {
task_id,
status: TaskStatus::Failed(e.to_string()),
task: Some(task),
output: None,
error: Some(e.to_string()),
processing_time: Some(start_time.elapsed()),
};
queue.update_stats(&task_result).await;
let _ = result_tx.send(task_result).await;
break;
}
Err(_) => {
let task_result = TaskResult {
task_id,
status: TaskStatus::Timeout,
task: Some(task),
output: None,
error: Some("任务执行超时".to_string()),
processing_time: Some(start_time.elapsed()),
};
queue.update_stats(&task_result).await;
let _ = result_tx.send(task_result).await;
break;
}
}
}
});
}
}
}
}
{
let mut state = state_ref.lock().await;
*state = ProcessorState::Shutdown;
}
debug!("异步处理器已完全关闭");
});
self.handle = Some(handle);
Ok(())
}
pub async fn shutdown(&mut self) -> Result<(), ProcessorError> {
{
let mut state = self.state.lock().await;
match *state {
ProcessorState::NotStarted => {
return Err(ProcessorError::InternalError(
"处理器尚未启动".to_string(),
));
},
ProcessorState::Shutdown => {
return Ok(()); },
ProcessorState::Shutting => {
},
ProcessorState::Running => {
*state = ProcessorState::Shutting;
},
}
}
if let Some(shutdown_tx) = self.shutdown_tx.take() {
shutdown_tx.send(()).map_err(|_| {
ProcessorError::InternalError(
"Failed to send shutdown signal".to_string(),
)
})?;
}
if let Some(handle) = self.handle.take() {
if let Err(e) = handle.await {
return Err(ProcessorError::InternalError(format!(
"等待后台任务完成时出错: {e}"
)));
}
}
{
let state = self.state.lock().await;
if *state != ProcessorState::Shutdown {
return Err(ProcessorError::InternalError(
"关闭过程未正确完成".to_string(),
));
}
}
debug!("异步处理器已成功关闭");
Ok(())
}
pub async fn get_state(&self) -> ProcessorState {
let state = self.state.lock().await;
state.clone()
}
pub async fn is_running(&self) -> bool {
let state = self.state.lock().await;
*state == ProcessorState::Running
}
pub async fn get_stats(&self) -> ProcessorStats {
self.task_queue.get_stats().await
}
}
/// Best-effort shutdown when the processor is dropped without calling `shutdown`.
impl<T, O, P> Drop for AsyncProcessor<T, O, P>
where
    T: Clone + Send + Sync + 'static,
    O: Clone + Send + Sync + 'static,
    P: TaskProcessor<T, O>,
{
    fn drop(&mut self) {
        // `send` fails only when the dispatcher has already dropped its receiver.
        let signalled = self
            .shutdown_tx
            .take()
            .map(|shutdown_tx| shutdown_tx.send(()).is_ok())
            .unwrap_or(false);
        if signalled {
            debug!("AsyncProcessor Drop: 已发送关闭信号");
        }

        if let Some(handle) = self.handle.take() {
            if !signalled {
                // No way to ask the dispatcher to stop, so stop it the hard way.
                handle.abort();
                debug!("AsyncProcessor Drop: 已中止后台任务");
            }
            // Otherwise the handle is just dropped, which detaches the dispatcher so it
            // can act on the signal. Aborting here would cancel it before it ever saw
            // the signal, making the graceful path unreachable.
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Processor that sleeps briefly and echoes the task number.
    struct TestProcessor;

    #[async_trait::async_trait]
    impl TaskProcessor<i32, String> for TestProcessor {
        async fn process(&self, task: i32) -> Result<String, ProcessorError> {
            tokio::time::sleep(Duration::from_millis(100)).await;
            Ok(format!("Processed: {task}"))
        }
    }

    /// Configuration shared by every test in this module.
    fn test_config() -> ProcessorConfig {
        ProcessorConfig {
            max_queue_size: 100,
            max_concurrent_tasks: 5,
            task_timeout: Duration::from_secs(1),
            max_retries: 3,
            retry_delay: Duration::from_secs(1),
            cleanup_timeout: Duration::from_secs(10),
        }
    }

    /// Ten tasks submitted to a running processor all complete successfully.
    #[tokio::test]
    async fn test_async_processor() {
        let mut processor = AsyncProcessor::new(test_config(), TestProcessor);
        processor.start().await.unwrap();

        let mut pending = Vec::new();
        for n in 0..10 {
            let (_, rx) = processor.submit_task(n, 0).await.unwrap();
            pending.push(rx);
        }

        for mut rx in pending {
            let outcome = rx.recv().await.unwrap();
            assert_eq!(outcome.status, TaskStatus::Completed);
            assert!(outcome.error.is_none());
            assert!(outcome.output.is_some());
        }

        processor.shutdown().await.unwrap();
    }

    /// Tasks submitted right before an explicit shutdown still report a result.
    ///
    /// NOTE(review): `shutdown` moves the state to `Shutting` before the dispatcher may
    /// have started these tasks, and tasks dequeued in that state are reported as
    /// `Cancelled`; asserting `Completed` looks too strict — confirm the intent.
    #[tokio::test]
    async fn test_processor_shutdown() {
        let mut processor = AsyncProcessor::new(test_config(), TestProcessor);
        processor.start().await.unwrap();

        let mut pending = Vec::new();
        for n in 0..5 {
            let (_, rx) = processor.submit_task(n, 0).await.unwrap();
            pending.push(rx);
        }

        processor.shutdown().await.unwrap();

        for mut rx in pending {
            let outcome = rx.recv().await.unwrap();
            assert_eq!(outcome.status, TaskStatus::Completed);
        }
    }

    /// Dropping the processor still yields a result for every submitted task.
    #[tokio::test]
    async fn test_processor_auto_shutdown() {
        let mut processor = AsyncProcessor::new(test_config(), TestProcessor);
        processor.start().await.unwrap();

        let mut pending = Vec::new();
        for n in 0..5 {
            let (_, rx) = processor.submit_task(n, 0).await.unwrap();
            pending.push(rx);
        }

        drop(processor);

        for mut rx in pending {
            let outcome = rx.recv().await.unwrap();
            assert!(matches!(
                outcome.status,
                TaskStatus::Completed | TaskStatus::Cancelled
            ));
        }
    }
}