use std::collections::HashMap;
use bigdecimal::ToPrimitive;
use chrono::{DateTime, Utc};
use sqlx::PgPool;
use thiserror::Error;
use tracing::error;
use uuid::Uuid;
use tasker_shared::database::sql_functions::{
SqlFunctionExecutor, StepReadinessStatus, TaskExecutionContext,
};
use tasker_shared::models::core::task::{PaginationInfo, Task, TaskListQuery};
use tasker_shared::models::orchestration::{ExecutionStatus, RecommendedAction};
use tasker_shared::types::api::orchestration::TaskResponse;
/// Errors produced by [`TaskQueryService`] lookups.
#[derive(Error, Debug)]
pub enum TaskQueryError {
/// No task row — or no execution context for an existing task — matches the UUID.
#[error("Task not found: {0}")]
NotFound(Uuid),
/// An underlying sqlx query failed; converted automatically via `#[from]`.
#[error("Database error: {0}")]
Database(#[from] sqlx::Error),
/// The orchestration metadata lookup (`Task::for_orchestration`) failed.
#[error("Failed to fetch task metadata: {0}")]
MetadataError(String),
}
/// Convenience alias for results returned by this module's query methods.
pub type TaskQueryResult<T> = Result<T, TaskQueryError>;
/// A task joined with its naming metadata, execution context, and per-step
/// readiness. Produced by [`TaskQueryService`]; `steps` is populated only for
/// single-task lookups and left empty by the list query.
#[derive(Debug, Clone)]
pub struct TaskWithContext {
/// The underlying task row.
pub task: Task,
/// Human-readable task (template) name from orchestration metadata.
pub task_name: String,
/// Namespace the task definition belongs to.
pub namespace: String,
/// Version of the task definition.
pub version: String,
/// Status string taken from the execution context (see `get_task_with_context`)
/// or from the list-query row metadata (see `list_tasks_with_context`).
pub status: String,
/// Aggregated execution state (step counts, health, recommended action).
pub execution_context: TaskExecutionContext,
/// Per-step readiness detail; empty in list views.
pub steps: Vec<StepReadinessStatus>,
}
/// One page of [`TaskWithContext`] results plus the pagination metadata
/// returned by `Task::list_with_pagination`.
#[derive(Debug)]
pub struct PaginatedTasksWithContext {
/// Tasks on this page, each enriched with its execution context.
pub tasks: Vec<TaskWithContext>,
/// Page/offset information for the overall result set.
pub pagination: PaginationInfo,
}
/// Read-side service for querying tasks together with their orchestration
/// execution state. Holds only a connection pool, so cloning the service is
/// inexpensive.
#[derive(Debug, Clone)]
pub struct TaskQueryService {
// Postgres pool; cloned into SqlFunctionExecutor per call.
pool: PgPool,
}
impl TaskQueryService {
    /// Create a new query service backed by the given connection pool.
    pub fn new(pool: PgPool) -> Self {
        Self { pool }
    }

    /// Fetch a single task together with its execution context and full
    /// per-step readiness detail.
    ///
    /// # Errors
    ///
    /// * [`TaskQueryError::NotFound`] — no task row, or no execution context,
    ///   exists for `uuid`.
    /// * [`TaskQueryError::Database`] — any underlying query failure.
    /// * [`TaskQueryError::MetadataError`] — the orchestration metadata
    ///   lookup failed (also logged via `tracing::error!`).
    pub async fn get_task_with_context(&self, uuid: Uuid) -> TaskQueryResult<TaskWithContext> {
        let sql_executor = SqlFunctionExecutor::new(self.pool.clone());
        let task = Task::find_by_id(&self.pool, uuid)
            .await?
            .ok_or(TaskQueryError::NotFound(uuid))?;
        // A task without an execution context is reported as NotFound rather
        // than padded with defaults: single-task callers expect a complete view.
        let execution_context = sql_executor
            .get_task_execution_context(uuid)
            .await?
            .ok_or(TaskQueryError::NotFound(uuid))?;
        let steps = sql_executor.get_step_readiness_status(uuid, None).await?;
        let task_metadata = task.for_orchestration(&self.pool).await.map_err(|e| {
            error!(error = %e, "Failed to get task metadata");
            TaskQueryError::MetadataError(e.to_string())
        })?;
        Ok(TaskWithContext {
            // Clone `status` before `execution_context` is moved into the struct.
            status: execution_context.status.clone(),
            task_name: task_metadata.task_name,
            namespace: task_metadata.namespace_name,
            version: task_metadata.task_version,
            task,
            execution_context,
            steps,
        })
    }

    /// List tasks matching `query`, each enriched with its execution context.
    ///
    /// Execution contexts are fetched in a single batch call. A task whose
    /// context is absent from the batch result gets a conservative default
    /// (see [`Self::default_execution_context`]). Step readiness is left
    /// empty for list views; use [`Self::get_task_with_context`] for detail.
    ///
    /// # Errors
    ///
    /// [`TaskQueryError::Database`] on any underlying query failure.
    pub async fn list_tasks_with_context(
        &self,
        query: &TaskListQuery,
    ) -> TaskQueryResult<PaginatedTasksWithContext> {
        let sql_executor = SqlFunctionExecutor::new(self.pool.clone());
        let paginated_result = Task::list_with_pagination(&self.pool, query).await?;
        let task_uuids: Vec<Uuid> = paginated_result
            .tasks
            .iter()
            .map(|t| t.task.task_uuid)
            .collect();
        let execution_contexts = sql_executor
            .get_task_execution_contexts_batch(task_uuids)
            .await?;
        // Index contexts by task UUID for O(1) association below.
        let mut context_map: HashMap<Uuid, TaskExecutionContext> = execution_contexts
            .into_iter()
            .map(|ctx| (ctx.task_uuid, ctx))
            .collect();
        let tasks: Vec<TaskWithContext> = paginated_result
            .tasks
            .into_iter()
            .map(|twm| {
                // Task UUIDs are unique within a page, so `remove` takes
                // ownership of each context instead of cloning it.
                let execution_context =
                    context_map.remove(&twm.task.task_uuid).unwrap_or_else(|| {
                        Self::default_execution_context(
                            twm.task.task_uuid,
                            twm.task.named_task_uuid,
                        )
                    });
                TaskWithContext {
                    task: twm.task,
                    task_name: twm.task_name,
                    namespace: twm.namespace_name,
                    version: twm.task_version,
                    status: twm.status,
                    execution_context,
                    steps: Vec::new(),
                }
            })
            .collect();
        Ok(PaginatedTasksWithContext {
            tasks,
            pagination: paginated_result.pagination,
        })
    }

    /// Convert a [`TaskWithContext`] into the public API [`TaskResponse`].
    ///
    /// Missing optional task fields (`initiator`, `source_system`, `reason`)
    /// are rendered as `"unknown"`; a missing `context` becomes `{}`. Only
    /// string entries in the `tags` JSON array are kept; a non-array (or
    /// absent) `tags` value maps to `None`.
    pub fn to_task_response(twc: &TaskWithContext) -> TaskResponse {
        let task = &twc.task;
        let ec = &twc.execution_context;
        // Keep only string tags, silently dropping any non-string entries.
        let tags = match &task.tags {
            Some(serde_json::Value::Array(arr)) => Some(
                arr.iter()
                    .filter_map(|v| v.as_str())
                    .map(|s| s.to_string())
                    .collect(),
            ),
            _ => None,
        };
        // BigDecimal -> f64; falls back to 0.0 if the value is unrepresentable.
        let completion_percentage = ec.completion_percentage.to_f64().unwrap_or(0.0);
        TaskResponse {
            task_uuid: task.task_uuid.to_string(),
            name: twc.task_name.clone(),
            namespace: twc.namespace.clone(),
            version: twc.version.clone(),
            status: twc.status.clone(),
            // Stored timestamps are naive; interpreted here as UTC.
            created_at: DateTime::from_naive_utc_and_offset(task.created_at, Utc),
            updated_at: DateTime::from_naive_utc_and_offset(task.updated_at, Utc),
            // NOTE(review): completion time is not tracked by this view and is
            // always None — confirm whether callers expect a real value here.
            completed_at: None,
            context: task
                .context
                .clone()
                .unwrap_or_else(|| serde_json::json!({})),
            initiator: task
                .initiator
                .clone()
                .unwrap_or_else(|| "unknown".to_string()),
            source_system: task
                .source_system
                .clone()
                .unwrap_or_else(|| "unknown".to_string()),
            reason: task.reason.clone().unwrap_or_else(|| "unknown".to_string()),
            priority: Some(task.priority),
            tags,
            total_steps: ec.total_steps,
            pending_steps: ec.pending_steps,
            in_progress_steps: ec.in_progress_steps,
            completed_steps: ec.completed_steps,
            failed_steps: ec.failed_steps,
            ready_steps: ec.ready_steps,
            execution_status: ec.execution_status.as_str().to_string(),
            recommended_action: ec
                .recommended_action
                .map(|ra| ra.into())
                .unwrap_or_else(|| "none".to_string()),
            completion_percentage,
            health_status: ec.health_status.clone(),
            steps: twc.steps.clone(),
            correlation_id: task.correlation_id,
            parent_correlation_id: task.parent_correlation_id,
        }
    }

    /// Conservative placeholder context for a task whose execution context was
    /// not returned by the batch lookup: zero step counts, `"unknown"` status
    /// and health, and a wait-for-dependencies recommendation.
    fn default_execution_context(task_uuid: Uuid, named_task_uuid: Uuid) -> TaskExecutionContext {
        TaskExecutionContext {
            task_uuid,
            named_task_uuid,
            status: "unknown".to_string(),
            total_steps: 0,
            pending_steps: 0,
            in_progress_steps: 0,
            completed_steps: 0,
            failed_steps: 0,
            ready_steps: 0,
            execution_status: ExecutionStatus::WaitingForDependencies,
            recommended_action: Some(RecommendedAction::WaitForDependencies),
            completion_percentage: sqlx::types::BigDecimal::from(0),
            health_status: "unknown".to_string(),
            enqueued_steps: 0,
        }
    }
}