use std::{collections::HashMap, sync::Arc};
use tokio::sync::RwLock;
use tracing::{info, warn};
use crate::{
common::task::{
ValorMasterTask, ValorTaskError, ValorTaskId, ValorTaskOutput, ValorTaskStatus,
},
types::ValorID,
};
/// Master-side registry of tasks plus the secondary indexes used to look them up.
///
/// Each field is guarded by its own async `RwLock`, so the primary map and the
/// indexes are locked independently of one another.
pub struct ValorTaskManager {
/// Primary store: every known task, keyed by its id.
tasks: RwLock<HashMap<ValorTaskId, ValorMasterTask>>,
/// Index from a status to the ids of the tasks recorded under that status.
tasks_by_status: RwLock<HashMap<ValorTaskStatus, Vec<ValorTaskId>>>,
/// Index from a worker to the ids of the tasks that were assigned to it.
tasks_by_worker: RwLock<HashMap<ValorID, Vec<ValorTaskId>>>,
/// Optional aggregator that is told when a task reaches a terminal status;
/// `None` until `set_result_aggregator` is called.
result_aggregator: RwLock<Option<Arc<crate::master::ValorResultAggregator>>>,
/// Join handles of the spawned timeout watchdog tasks, keyed by task id.
timeout_watchdogs: RwLock<HashMap<ValorTaskId, tokio::task::JoinHandle<()>>>,
}
impl ValorTaskManager {
/// Builds a manager with no tasks, empty indexes, no result aggregator and
/// no timeout watchdogs.
pub fn new() -> Self {
    let tasks = RwLock::new(HashMap::new());
    let tasks_by_status = RwLock::new(HashMap::new());
    let tasks_by_worker = RwLock::new(HashMap::new());
    let result_aggregator = RwLock::new(None);
    let timeout_watchdogs = RwLock::new(HashMap::new());
    Self {
        tasks,
        tasks_by_status,
        tasks_by_worker,
        result_aggregator,
        timeout_watchdogs,
    }
}
/// Installs `aggregator` as the result aggregator, replacing any previously
/// installed one.
pub async fn set_result_aggregator(
    &self,
    aggregator: Arc<crate::master::ValorResultAggregator>,
) {
    *self.result_aggregator.write().await = Some(aggregator);
}
/// Registers `task` as a new pending task and returns its id.
///
/// The task's `status`, `created_at` and `attempt` fields are overwritten
/// (`Pending`, the current time, and `0`) regardless of what the caller set.
///
/// # Errors
///
/// Returns an error string when a task with the same id is already registered;
/// in that case nothing is stored or indexed.
pub async fn create_task(&self, mut task: ValorMasterTask) -> Result<ValorTaskId, String> {
    let task_id = task.task_id.clone();
    task.status = ValorTaskStatus::Pending;
    task.created_at = current_timestamp_ms();
    task.attempt = 0;

    // Capture the log fields up front so the task itself can be moved into
    // the map instead of being cloned just for logging.
    let service_id = match &task.task_type {
        crate::common::task::ValorTaskType::ExecuteService { service_id, .. } => {
            service_id.clone()
        }
    };
    let priority = format!("{:?}", task.priority);

    {
        // Publish the task and its `Pending` index entry together. If the
        // task became visible first, a concurrent `assign_task` could move it
        // to `Assigned` before the `Pending` entry is pushed, leaving the id
        // listed under both statuses.
        //
        // Lock order is status index -> task map, the same order used by
        // `list_tasks_by_status`, so the two cannot deadlock each other.
        let mut by_status = self.tasks_by_status.write().await;
        let mut tasks = self.tasks.write().await;
        match tasks.entry(task_id.clone()) {
            std::collections::hash_map::Entry::Occupied(_) => {
                return Err(format!("Task {task_id} already exists"));
            }
            std::collections::hash_map::Entry::Vacant(slot) => {
                slot.insert(task);
            }
        }
        by_status
            .entry(ValorTaskStatus::Pending)
            .or_default()
            .push(task_id.clone());
    }

    tracing::info!(
        task_id = %task_id,
        service_id = %service_id,
        priority = %priority,
        "Task created (pending)"
    );
    Ok(task_id)
}
pub async fn maybe_spawn_watchdog(self: &Arc<Self>, task_id: &ValorTaskId) {
let timeout_ms = if let Some(task) = self.get_task(task_id).await {
task.timeout_ms
} else {
None
};
if let Some(ms) = timeout_ms {
if self.timeout_watchdogs.read().await.contains_key(task_id) {
return;
}
let id = task_id.clone();
let mgr = Arc::clone(self);
let handle = tokio::spawn(async move {
let dur = std::time::Duration::from_millis(ms);
tokio::time::sleep(dur).await;
if let Some(task) = mgr.get_task(&id).await {
use crate::common::task::ValorTaskStatus::*;
if !matches!(task.status, Completed | Failed | Cancelled) {
let _ = mgr.fail_task_timeout(&id, ms).await;
}
}
});
self.timeout_watchdogs
.write()
.await
.insert(task_id.clone(), handle);
}
}
/// Marks `task_id` as `Failed` with a `Timeout { timeout_ms }` error, updates
/// the status index, notifies the result aggregator (if one is installed) and
/// stops the task's timeout watchdog.
///
/// Calling this for a task that is already `Completed`, `Failed` or
/// `Cancelled` is a no-op that returns `Ok(())`.
///
/// # Errors
///
/// Returns an error string when no task with this id is registered.
pub async fn fail_task_timeout(
    &self,
    task_id: &ValorTaskId,
    timeout_ms: u64,
) -> Result<(), String> {
    let (old_status, snapshot) = {
        let mut tasks = self.tasks.write().await;
        let task = tasks
            .get_mut(task_id)
            .ok_or_else(|| format!("Task {task_id} not found"))?;
        if matches!(
            task.status,
            ValorTaskStatus::Completed | ValorTaskStatus::Failed | ValorTaskStatus::Cancelled
        ) {
            return Ok(());
        }
        let old_status = task.status;
        task.status = ValorTaskStatus::Failed;
        task.completed_at = Some(current_timestamp_ms());
        task.error = Some(ValorTaskError::Timeout { timeout_ms });
        (old_status, task.clone())
    };

    self.update_status_index(task_id, old_status, ValorTaskStatus::Failed)
        .await;

    // Take the watchdog out of the map now, but do not abort it yet.
    let watchdog = self.timeout_watchdogs.write().await.remove(task_id);

    if let Some(agg) = self.result_aggregator.read().await.as_ref().cloned() {
        agg.notify_if_terminal(&snapshot).await;
    }
    info!(task_id = %task_id, timeout_ms, "Task timed out and marked as Failed");

    // Abort strictly last. This method is normally executed *by* the watchdog
    // task, so the handle refers to the task currently running this code; an
    // aborted task is cancelled at its next suspension point, which would
    // otherwise drop the aggregator notification above before it completes.
    // Nothing is awaited after this point.
    if let Some(handle) = watchdog {
        handle.abort();
    }
    Ok(())
}
/// Returns a clone of the task registered under `task_id`, or `None` when
/// there is no such task.
pub async fn get_task(&self, task_id: &ValorTaskId) -> Option<ValorMasterTask> {
    let guard = self.tasks.read().await;
    match guard.get(task_id) {
        Some(found) => Some(found.clone()),
        None => None,
    }
}
/// Returns clones of all registered tasks, in the map's iteration order.
pub async fn list_tasks(&self) -> Vec<ValorMasterTask> {
    let guard = self.tasks.read().await;
    let mut all = Vec::with_capacity(guard.len());
    for task in guard.values() {
        all.push(task.clone());
    }
    all
}
/// Returns clones of the tasks that the status index lists under `status`.
///
/// Ids present in the index but missing from the task map are skipped.
/// The status index is read-locked first and the task map second.
pub async fn list_tasks_by_status(&self, status: ValorTaskStatus) -> Vec<ValorMasterTask> {
    let index = self.tasks_by_status.read().await;
    let ids = match index.get(&status) {
        Some(ids) => ids,
        None => return Vec::new(),
    };
    let store = self.tasks.read().await;
    let mut matching = Vec::new();
    for id in ids {
        if let Some(task) = store.get(id) {
            matching.push(task.clone());
        }
    }
    matching
}
/// Moves a pending task to `Assigned`, records `worker_id` as its worker,
/// stamps `assigned_at`, and bumps the attempt counter (saturating).
///
/// The status index and the per-worker index are updated after the task map
/// lock has been released.
///
/// # Errors
///
/// Returns an error string when the task does not exist or is not `Pending`.
pub async fn assign_task(
    &self,
    task_id: &ValorTaskId,
    worker_id: &ValorID,
) -> Result<(), String> {
    let mut store = self.tasks.write().await;
    let entry = match store.get_mut(task_id) {
        Some(entry) => entry,
        None => return Err(format!("Task {task_id} not found")),
    };
    if entry.status != ValorTaskStatus::Pending {
        return Err(format!("Task {task_id} is not pending"));
    }
    let service = match &entry.task_type {
        crate::common::task::ValorTaskType::ExecuteService { service_id, .. } => {
            service_id.clone()
        }
    };
    let attempt_no = entry.attempt.saturating_add(1);
    entry.status = ValorTaskStatus::Assigned;
    entry.assigned_worker = Some(worker_id.to_string());
    entry.assigned_at = Some(current_timestamp_ms());
    entry.attempt = attempt_no;
    // Release the task map before touching either index.
    drop(store);

    self.update_status_index(task_id, ValorTaskStatus::Pending, ValorTaskStatus::Assigned)
        .await;

    self.tasks_by_worker
        .write()
        .await
        .entry(worker_id.clone())
        .or_default()
        .push(task_id.clone());

    info!(worker_id = %worker_id, service_id = %service, attempt = attempt_no, "Task assigned");
    Ok(())
}
/// Applies a status report from `worker_id` to `task_id`.
///
/// * `Completed` stores `output` and stamps `completed_at`.
/// * `Failed` stores `error` and stamps `completed_at`.
/// * `Cancelled` stamps `completed_at`.
/// * Any other status only changes the `status` field.
///
/// When the new status is terminal, the task's timeout watchdog is aborted
/// and the result aggregator (if installed) is notified.
///
/// # Errors
///
/// Returns an error string when the task does not exist, is not assigned to
/// `worker_id`, or has already reached a terminal status. The last check
/// keeps a late worker report from overwriting a task that was cancelled or
/// timed out, mirroring the guards in `cancel_task` and `fail_task_timeout`.
pub async fn update_task_status(
    &self,
    task_id: &ValorTaskId,
    worker_id: &ValorID,
    status: ValorTaskStatus,
    output: Option<ValorTaskOutput>,
    error: Option<ValorTaskError>,
) -> Result<(), String> {
    let (old_status, task_snapshot) = {
        let mut tasks = self.tasks.write().await;
        let task = tasks
            .get_mut(task_id)
            .ok_or_else(|| format!("Task {task_id} not found"))?;
        if task.assigned_worker.as_deref() != Some(worker_id.to_string().as_str()) {
            return Err(format!(
                "Task {task_id} is not assigned to worker {worker_id}"
            ));
        }
        // Terminal states are final: waiters have already been notified.
        if matches!(
            task.status,
            ValorTaskStatus::Completed | ValorTaskStatus::Failed | ValorTaskStatus::Cancelled
        ) {
            return Err(format!("Task {task_id} is already finished"));
        }
        let old_status = task.status;
        task.status = status;
        match status {
            ValorTaskStatus::Completed => {
                task.completed_at = Some(current_timestamp_ms());
                task.output = output;
            }
            ValorTaskStatus::Failed => {
                task.completed_at = Some(current_timestamp_ms());
                task.error = error;
            }
            ValorTaskStatus::Cancelled => {
                // Same bookkeeping as `cancel_task`.
                task.completed_at = Some(current_timestamp_ms());
            }
            _ => {}
        }
        (old_status, task.clone())
    };

    self.update_status_index(task_id, old_status, status).await;

    if matches!(
        status,
        ValorTaskStatus::Completed | ValorTaskStatus::Failed | ValorTaskStatus::Cancelled
    ) {
        if let Some(handle) = self.timeout_watchdogs.write().await.remove(task_id) {
            handle.abort();
        }
        if let Some(agg) = self.result_aggregator.read().await.as_ref().cloned() {
            agg.notify_if_terminal(&task_snapshot).await;
        }
    }

    let worker_label = task_snapshot
        .assigned_worker
        .as_deref()
        .unwrap_or("<unassigned>");
    info!(
        task_id = %task_id,
        worker = worker_label,
        from = ?old_status,
        to = ?status,
        "Task status updated"
    );
    Ok(())
}
/// Cancels `task_id`: sets its status to `Cancelled`, stamps `completed_at`,
/// updates the status index, aborts its timeout watchdog and notifies the
/// result aggregator (if installed).
///
/// # Errors
///
/// Returns an error string when the task does not exist or is already
/// `Completed`, `Failed` or `Cancelled`.
pub async fn cancel_task(&self, task_id: &ValorTaskId) -> Result<(), String> {
    let (old_status, task_snapshot) = {
        let mut tasks = self.tasks.write().await;
        let task = tasks
            .get_mut(task_id)
            .ok_or_else(|| format!("Task {task_id} not found"))?;
        if matches!(
            task.status,
            ValorTaskStatus::Completed | ValorTaskStatus::Failed | ValorTaskStatus::Cancelled
        ) {
            return Err(format!("Task {task_id} is already finished"));
        }
        let old_status = task.status;
        task.status = ValorTaskStatus::Cancelled;
        task.completed_at = Some(current_timestamp_ms());
        (old_status, task.clone())
    };

    self.update_status_index(task_id, old_status, ValorTaskStatus::Cancelled)
        .await;

    if let Some(handle) = self.timeout_watchdogs.write().await.remove(task_id) {
        handle.abort();
    }
    if let Some(agg) = self.result_aggregator.read().await.as_ref().cloned() {
        agg.notify_if_terminal(&task_snapshot).await;
    }

    let worker_label = task_snapshot
        .assigned_worker
        .as_deref()
        .unwrap_or("<unassigned>");
    // Include the task id: the worker label alone does not identify which
    // task was cancelled (and is "<unassigned>" for never-assigned tasks).
    warn!(task_id = %task_id, worker = worker_label, "Task cancelled");
    Ok(())
}
/// Returns every task still held by `worker_id` to the `Pending` state and
/// drops the worker's entry from the per-worker index.
///
/// Only tasks that are still assigned to this worker and are `Assigned` or
/// `Running` are requeued; their `assigned_worker` and `assigned_at` fields
/// are cleared. The attempt counter is left untouched.
///
/// Returns the number of tasks that were requeued.
pub async fn requeue_tasks_for_unreachable_worker(&self, worker_id: &ValorID) -> usize {
    // Take the worker's list out of the index in one step. Reading a snapshot
    // first and removing the entry afterwards would silently discard ids that
    // `assign_task` pushes for this worker while the loop below is running.
    let assigned_ids = self
        .tasks_by_worker
        .write()
        .await
        .remove(worker_id)
        .unwrap_or_default();

    // Computed once instead of once per task.
    let worker_label = worker_id.to_string();
    let mut changed = 0usize;

    for task_id in &assigned_ids {
        // The task map lock is released before the status index is touched.
        let previous_status = {
            let mut tasks = self.tasks.write().await;
            let mut previous = None;
            if let Some(task) = tasks.get_mut(task_id) {
                let still_mine =
                    task.assigned_worker.as_deref() == Some(worker_label.as_str());
                if still_mine
                    && matches!(
                        task.status,
                        ValorTaskStatus::Assigned | ValorTaskStatus::Running
                    )
                {
                    previous = Some(task.status);
                    task.status = ValorTaskStatus::Pending;
                    task.assigned_worker = None;
                    task.assigned_at = None;
                }
            }
            previous
        };

        if let Some(old_status) = previous_status {
            changed += 1;
            self.update_status_index(task_id, old_status, ValorTaskStatus::Pending)
                .await;
        }
    }
    changed
}
pub async fn wait_for_terminal(
&self,
task_id: &ValorTaskId,
timeout_ms: Option<u64>,
) -> Result<ValorMasterTask, String> {
if let Some(current) = self.get_task(task_id).await {
if matches!(
current.status,
ValorTaskStatus::Completed | ValorTaskStatus::Failed | ValorTaskStatus::Cancelled
) {
return Ok(current);
}
}
let agg = self
.result_aggregator
.read()
.await
.as_ref()
.cloned()
.ok_or_else(|| "Result aggregator not configured".to_string())?;
let rx = agg.register_waiter(task_id).await;
if let Some(ms) = timeout_ms {
let dur = std::time::Duration::from_millis(ms);
match tokio::time::timeout(dur, rx).await {
Ok(Ok(task)) => Ok(task),
Ok(Err(_canceled)) => Err("Waiter cancelled".to_string()),
Err(_elapsed) => Err("Wait timeout".to_string()),
}
} else {
rx.await.map_err(|_| "Waiter cancelled".to_string())
}
}
/// Moves `task_id` in the status index: every occurrence under `old_status`
/// is removed, then the id is appended under `new_status`.
async fn update_status_index(
    &self,
    task_id: &ValorTaskId,
    old_status: ValorTaskStatus,
    new_status: ValorTaskStatus,
) {
    let mut index = self.tasks_by_status.write().await;
    match index.get_mut(&old_status) {
        Some(bucket) => bucket.retain(|existing| existing != task_id),
        None => {}
    }
    let target = index.entry(new_status).or_default();
    target.push(task_id.clone());
}
pub async fn get_stats(&self) -> TaskStats {
let by_status = self.tasks_by_status.read().await;
TaskStats {
pending: by_status
.get(&ValorTaskStatus::Pending)
.map(|v| v.len())
.unwrap_or(0),
assigned: by_status
.get(&ValorTaskStatus::Assigned)
.map(|v| v.len())
.unwrap_or(0),
running: by_status
.get(&ValorTaskStatus::Running)
.map(|v| v.len())
.unwrap_or(0),
completed: by_status
.get(&ValorTaskStatus::Completed)
.map(|v| v.len())
.unwrap_or(0),
failed: by_status
.get(&ValorTaskStatus::Failed)
.map(|v| v.len())
.unwrap_or(0),
cancelled: by_status
.get(&ValorTaskStatus::Cancelled)
.map(|v| v.len())
.unwrap_or(0),
}
}
}
/// Number of tasks recorded under each status.
#[derive(Debug, Clone, Copy)]
pub struct TaskStats {
    /// Tasks counted as pending.
    pub pending: usize,
    /// Tasks counted as assigned.
    pub assigned: usize,
    /// Tasks counted as running.
    pub running: usize,
    /// Tasks counted as completed.
    pub completed: usize,
    /// Tasks counted as failed.
    pub failed: usize,
    /// Tasks counted as cancelled.
    pub cancelled: usize,
}

impl TaskStats {
    /// Sum of all six per-status counters.
    pub fn total(&self) -> usize {
        let buckets = [
            self.pending,
            self.assigned,
            self.running,
            self.completed,
            self.failed,
            self.cancelled,
        ];
        buckets.iter().sum()
    }
}
/// Current wall-clock time as milliseconds since the Unix epoch.
///
/// Never panics: a system clock set before the epoch yields `0`, and a value
/// that does not fit in `u64` is clamped to `u64::MAX` instead of being
/// truncated.
fn current_timestamp_ms() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};
    let since_epoch = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        // `Err` means the clock is before the epoch; treat that as time zero
        // rather than taking the whole process down over a timestamp.
        .unwrap_or_default();
    since_epoch.as_millis().min(u128::from(u64::MAX)) as u64
}