use std::any::Any;
use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::time::Duration;
use crossbeam_channel::{Receiver, Sender};
use tokio::runtime::Handle;
use crate::types::TaskId;
/// Lifecycle state recorded in a task's [`TaskMetadata`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TaskStatus {
    /// Not set by the result constructors shown here.
    /// NOTE(review): presumably used by callers tracking queued work — confirm.
    Pending,
    /// Not set by the result constructors shown here.
    /// NOTE(review): presumably used by callers tracking in-flight work — confirm.
    Running,
    /// Set by [`AsyncTaskResult::success`].
    Completed,
    /// Set by [`AsyncTaskResult::failure`].
    Failed,
    /// Set by [`AsyncTaskResult::timeout`].
    Timeout,
}
/// A unit of asynchronous work that [`AsyncTaskSystem`] can run.
///
/// The `Send + Sync + 'static` bounds let a boxed task be moved into a
/// spawned future.
pub trait AsyncTask: Send + Sync + 'static {
    /// Short label describing what kind of task this is (e.g. `"delay"`).
    fn task_type(&self) -> &str;

    /// Returns the future that performs the work and yields its result.
    ///
    /// When run through [`AsyncTaskSystem`], the result's `task_id`,
    /// `task_type` and `metadata.duration` are overwritten by the system, so
    /// implementors may fill those with placeholders.
    fn execute(&self) -> Pin<Box<dyn Future<Output = AsyncTaskResult> + Send>>;

    /// Longest time [`execute`](Self::execute)'s future may run before the
    /// task is reported as timed out. Defaults to 30 seconds.
    fn timeout(&self) -> Duration {
        Duration::from_millis(30_000)
    }
}
/// Outcome of running one [`AsyncTask`], as delivered by
/// [`AsyncTaskSystem::collect_results`].
pub struct AsyncTaskResult {
    /// Task identifier; for results produced by `AsyncTaskSystem` this is the
    /// id returned when the task was spawned.
    pub task_id: TaskId,
    /// Label copied from the task's `task_type()`.
    pub task_type: String,
    /// Type-erased value produced on success (`None` for failures and
    /// timeouts); recover it with `downcast_ref` / `downcast`.
    pub payload: Option<Box<dyn Any + Send>>,
    /// Timing, status and error information.
    pub metadata: TaskMetadata,
}
impl AsyncTaskResult {
pub fn success<T: Any + Send + 'static>(
task_id: TaskId,
task_type: impl Into<String>,
payload: T,
duration: Duration,
) -> Self {
Self {
task_id,
task_type: task_type.into(),
payload: Some(Box::new(payload)),
metadata: TaskMetadata {
duration,
status: TaskStatus::Completed,
error: None,
},
}
}
pub fn failure(
task_id: TaskId,
task_type: impl Into<String>,
error: String,
duration: Duration,
) -> Self {
Self {
task_id,
task_type: task_type.into(),
payload: None,
metadata: TaskMetadata {
duration,
status: TaskStatus::Failed,
error: Some(AsyncTaskError { message: error }),
},
}
}
pub fn timeout(task_id: TaskId, task_type: impl Into<String>, duration: Duration) -> Self {
Self {
task_id,
task_type: task_type.into(),
payload: None,
metadata: TaskMetadata {
duration,
status: TaskStatus::Timeout,
error: Some(AsyncTaskError {
message: "Task timed out".to_string(),
}),
},
}
}
}
/// Timing and outcome details attached to every [`AsyncTaskResult`].
///
/// Derives `Debug` and `Clone` (all fields support both) so results can be
/// logged and their metadata copied out without moving the payload.
#[derive(Debug, Clone)]
pub struct TaskMetadata {
    /// How long the task ran. For results delivered by `AsyncTaskSystem` this
    /// is the elapsed time measured around the task's execution.
    pub duration: Duration,
    /// Final status of the task.
    pub status: TaskStatus,
    /// Error details; `None` for results built with `AsyncTaskResult::success`,
    /// `Some` for `failure` and `timeout`.
    pub error: Option<AsyncTaskError>,
}
#[derive(Debug, Clone, thiserror::Error)]
#[error("Async task error: {message}")]
pub struct AsyncTaskError {
pub message: String,
}
impl From<crate::error::SwarmError> for AsyncTaskError {
fn from(err: crate::error::SwarmError) -> Self {
Self {
message: err.message().to_string(),
}
}
}
impl From<AsyncTaskError> for crate::error::SwarmError {
fn from(err: AsyncTaskError) -> Self {
crate::error::SwarmError::AsyncTask {
message: err.message,
}
}
}
/// Spawns [`AsyncTask`]s onto a Tokio runtime and gathers their results on a
/// channel that the owner drains with `collect_results`.
pub struct AsyncTaskSystem {
    /// Sending half of the result channel; cloned into every spawned task.
    result_tx: Sender<AsyncTaskResult>,
    /// Receiving half of the result channel; drained by `collect_results`.
    result_rx: Receiver<AsyncTaskResult>,
    /// Runtime handle that tasks are spawned on.
    runtime: Handle,
    /// Task factories keyed by the name given to `register_factory`.
    factories: HashMap<String, Box<dyn AsyncTaskFactory>>,
}
impl AsyncTaskSystem {
    /// Creates a task system with no registered factories that spawns work
    /// onto `runtime`.
    ///
    /// Results are buffered on an unbounded channel until drained with
    /// [`collect_results`](Self::collect_results).
    pub fn new(runtime: Handle) -> Self {
        let (tx, rx) = crossbeam_channel::unbounded();
        Self {
            result_tx: tx,
            result_rx: rx,
            runtime,
            factories: HashMap::new(),
        }
    }

    /// Spawns `task` on the runtime; convenience wrapper around
    /// [`spawn_boxed`](Self::spawn_boxed).
    pub fn spawn<T: AsyncTask>(&self, task: T) -> TaskId {
        self.spawn_boxed(Box::new(task))
    }

    /// Spawns an already-boxed task and returns the id its result will carry.
    ///
    /// One [`AsyncTaskResult`] is queued for the returned id:
    /// - if the task finishes within `task.timeout()`, its own result is
    ///   forwarded with `task_id`, `task_type` and `metadata.duration`
    ///   overwritten by the system;
    /// - if the task's future panics (or is cancelled by the runtime), a
    ///   `Failed` result carrying the join error's text is queued;
    /// - if the timeout elapses first, the task is aborted and a `Timeout`
    ///   result is queued.
    ///
    /// Delivery is best effort: if this system is dropped before the task
    /// finishes, the result is discarded.
    pub fn spawn_boxed(&self, task: Box<dyn AsyncTask>) -> TaskId {
        let task_id = TaskId::new();
        let tx = self.result_tx.clone();
        let timeout_duration = task.timeout();
        let task_type = task.task_type().to_string();
        self.runtime.spawn(async move {
            let start = std::time::Instant::now();
            // Run the task's future as its own Tokio task so that a panic inside
            // it surfaces here as a `JoinError`. Previously a panic unwound this
            // wrapper before `tx.send`, and no result was ever produced.
            let mut handle = tokio::spawn(task.execute());
            let outcome = tokio::time::timeout(timeout_duration, &mut handle).await;
            let duration = start.elapsed();
            let task_result = match outcome {
                Ok(Ok(mut r)) => {
                    r.task_id = task_id;
                    r.task_type = task_type;
                    r.metadata.duration = duration;
                    r
                }
                Ok(Err(join_err)) => {
                    AsyncTaskResult::failure(task_id, task_type, join_err.to_string(), duration)
                }
                Err(_) => {
                    // Dropping a JoinHandle detaches the task rather than
                    // cancelling it, so abort explicitly to stop the work.
                    handle.abort();
                    AsyncTaskResult::timeout(task_id, task_type, duration)
                }
            };
            // `send` only fails once the receiver (owned by this system) has
            // been dropped; nobody is left to read the result, so ignore it.
            let _ = tx.send(task_result);
        });
        task_id
    }

    /// Drains every result that is currently queued, without blocking.
    pub fn collect_results(&self) -> Vec<AsyncTaskResult> {
        self.result_rx.try_iter().collect()
    }

    /// Registers `factory` under `name`, replacing any factory previously
    /// registered with the same name.
    pub fn register_factory<F: AsyncTaskFactory + 'static>(&mut self, name: &str, factory: F) {
        self.factories.insert(name.to_string(), Box::new(factory));
    }

    /// Builds a task from the factory registered as `name`, or returns `None`
    /// if no such factory exists.
    pub fn create_task(&self, name: &str, params: TaskParams) -> Option<Box<dyn AsyncTask>> {
        self.factories.get(name).map(|f| f.create(params))
    }
}
/// Constructs [`AsyncTask`]s from string-keyed parameters.
///
/// Factories are registered by name with `AsyncTaskSystem::register_factory`.
pub trait AsyncTaskFactory: Send + Sync {
    /// Creates a new boxed task configured from `params`.
    fn create(&self, params: TaskParams) -> Box<dyn AsyncTask>;
}
/// String-keyed parameters handed to an [`AsyncTaskFactory`] when building a
/// task. `Default` yields an empty map.
#[derive(Clone, Debug, Default)]
pub struct TaskParams {
    /// Raw key/value parameters; interpretation is up to each factory.
    pub data: HashMap<String, String>,
}
/// Tunables for asynchronous task execution.
///
/// NOTE(review): `AsyncTaskSystem::new` as shown here takes only a runtime
/// handle and does not read this config; confirm where these limits are
/// enforced.
#[derive(Clone, Debug)]
pub struct AsyncConfig {
    /// Intended maximum number of tasks running at once.
    pub max_concurrent: usize,
    /// Intended default task timeout, in seconds.
    pub default_timeout_secs: u64,
}
impl Default for AsyncConfig {
fn default() -> Self {
Self {
max_concurrent: 100,
default_timeout_secs: 30,
}
}
}
/// Simple task that sleeps for a fixed delay and then succeeds with a string
/// payload.
pub struct DelayTask {
    /// How long to sleep before completing.
    pub delay: Duration,
    /// Value delivered as the success payload once the delay has elapsed.
    pub result: String,
}
impl AsyncTask for DelayTask {
    fn task_type(&self) -> &str {
        "delay"
    }

    /// Sleeps for `delay`, then succeeds with a clone of `result` as the
    /// payload.
    ///
    /// The reported duration is the measured sleep time rather than the
    /// nominal `delay`. The freshly generated id is a placeholder that
    /// `AsyncTaskSystem::spawn_boxed` overwrites with the spawned task's id.
    fn execute(&self) -> Pin<Box<dyn Future<Output = AsyncTaskResult> + Send>> {
        let delay = self.delay;
        let result = self.result.clone();
        Box::pin(async move {
            let started = std::time::Instant::now();
            tokio::time::sleep(delay).await;
            AsyncTaskResult::success(TaskId::new(), "delay", result, started.elapsed())
        })
    }

    /// Allows one second of slack beyond `delay`.
    ///
    /// Saturates instead of panicking on overflow, so a very large `delay`
    /// (e.g. `Duration::MAX`) yields `Duration::MAX`.
    fn timeout(&self) -> Duration {
        self.delay.saturating_add(Duration::from_secs(1))
    }
}