use chrono::Utc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::Instant;
use tokio::sync::Semaphore;
use tracing::{debug, error, info, warn};
use super::slot_token::SlotToken;
use super::task_handle::{with_task_handle, TaskHandle};
use super::types::{ClaimedTask, ExecutionScope, ExecutorConfig};
use crate::dal::DAL;
use crate::database::universal_types::UniversalUuid;
use crate::dispatcher::{
DispatchError, ExecutionResult, ExecutorMetrics, TaskExecutor, TaskReadyEvent,
};
use crate::error::ExecutorError;
use crate::retry::{RetryCondition, RetryPolicy};
use crate::task::get_task;
use crate::{parse_namespace, Context, Database, Task, TaskRegistry};
use async_trait::async_trait;
/// Task executor that runs dispatched tasks while bounding concurrency with a semaphore.
///
/// Each call to `execute` acquires one permit from `semaphore` before running the task, so at
/// most `config.max_concurrent_tasks` permits are handed out at a time.
pub struct ThreadTaskExecutor {
/// Database handle the executor was constructed with; `new` also clones it to build `dal`.
database: Database,
/// Data access layer used for every read and write performed by this executor.
dal: DAL,
/// Task registry supplied at construction time.
task_registry: Arc<TaskRegistry>,
/// Identifier for this executor; `new` generates it with `UniversalUuid::new_v4()`.
instance_id: UniversalUuid,
/// Executor settings; this file reads `max_concurrent_tasks` and `task_timeout` from it.
config: ExecutorConfig,
/// Concurrency limiter, created with `config.max_concurrent_tasks` permits.
semaphore: Arc<Semaphore>,
/// Counter reported as `total_executed` by `metrics()`.
total_executed: AtomicU64,
/// Counter reported as `total_failed` by `metrics()`.
total_failed: AtomicU64,
}
impl ThreadTaskExecutor {
/// Creates an executor that uses `database` for persistence and `task_registry` as its registry.
///
/// The concurrency semaphore is sized from `config.max_concurrent_tasks`, both counters start
/// at zero, and a fresh random instance id is generated.
pub fn new(
    database: Database,
    task_registry: Arc<TaskRegistry>,
    config: ExecutorConfig,
) -> Self {
    let semaphore = Arc::new(Semaphore::new(config.max_concurrent_tasks));
    let dal = DAL::new(database.clone());

    Self {
        database,
        dal,
        task_registry,
        instance_id: UniversalUuid::new_v4(),
        config,
        semaphore,
        total_executed: AtomicU64::default(),
        total_failed: AtomicU64::default(),
    }
}
/// Creates an executor whose registry is populated from the crate-wide global task registry.
///
/// Every `(namespace, constructor)` entry in the global registry is instantiated and registered
/// into a fresh `TaskRegistry`.
///
/// # Errors
/// Returns the registration error of the first task that fails to register.
pub fn with_global_registry(
    database: Database,
    config: ExecutorConfig,
) -> Result<Self, crate::error::RegistrationError> {
    let global_registry = crate::global_task_registry();
    let registered = global_registry.read();

    let mut local_registry = TaskRegistry::new();
    for (namespace, constructor) in registered.iter() {
        local_registry.register_arc(namespace.clone(), constructor())?;
    }

    Ok(Self::new(database, Arc::new(local_registry), config))
}
/// Returns the semaphore that limits how many tasks this executor runs concurrently.
pub fn semaphore(&self) -> &Arc<Semaphore> {
    &self.semaphore
}
async fn build_task_context(
&self,
claimed_task: &ClaimedTask,
dependencies: &[crate::task::TaskNamespace],
) -> Result<Context<serde_json::Value>, ExecutorError> {
tracing::debug!(
"Building context for task '{}' with {} dependencies: {:?}",
claimed_task.task_name,
dependencies.len(),
dependencies
);
tracing::debug!(
"DEBUG: Building context for task '{}' with {} dependencies: {:?}",
claimed_task.task_name,
dependencies.len(),
dependencies
);
let execution_scope = ExecutionScope {
pipeline_execution_id: claimed_task.pipeline_execution_id,
task_execution_id: Some(claimed_task.task_execution_id),
task_name: Some(claimed_task.task_name.clone()),
};
let mut context = Context::new();
let _execution_scope = execution_scope;
if dependencies.is_empty() {
if let Ok(pipeline_execution) = self
.dal
.pipeline_execution()
.get_by_id(claimed_task.pipeline_execution_id)
.await
{
if let Some(context_id) = pipeline_execution.context_id {
if let Ok(initial_context) = self
.dal
.context()
.read::<serde_json::Value>(context_id)
.await
{
for (key, value) in initial_context.data() {
let _ = context.insert(key, value.clone());
}
debug!(
"Loaded initial pipeline context with {} keys",
initial_context.data().len()
);
}
}
}
}
if !dependencies.is_empty() {
debug!(
"Loading dependency contexts for {} dependencies: {:?}",
dependencies.len(),
dependencies
);
if let Ok(dep_metadata_with_contexts) = self
.dal
.task_execution_metadata()
.get_dependency_metadata_with_contexts(
claimed_task.pipeline_execution_id,
dependencies,
)
.await
{
debug!(
"Found {} dependency metadata records",
dep_metadata_with_contexts.len()
);
for (_task_metadata, context_json) in dep_metadata_with_contexts {
if let Some(json_str) = context_json {
if let Ok(dep_context) = Context::<serde_json::Value>::from_json(json_str) {
debug!(
"Merging dependency context with {} keys: {:?}",
dep_context.data().len(),
dep_context.data().keys().collect::<Vec<_>>()
);
for (key, value) in dep_context.data() {
if let Some(existing_value) = context.get(key) {
let merged_value =
Self::merge_context_values(existing_value, value);
let _ = context.update(key, merged_value);
} else {
let _ = context.insert(key, value.clone());
}
}
} else {
debug!("Failed to parse dependency context JSON");
}
}
}
} else {
debug!(
"Failed to load dependency metadata for dependencies: {:?}",
dependencies
);
}
}
debug!(
"Final context for task {} has {} keys: {:?}",
claimed_task.task_name,
context.data().len(),
context.data().keys().collect::<Vec<_>>()
);
Ok(context)
}
/// Combines two JSON values that were stored under the same context key.
///
/// * Array + array: the existing items, followed by each new item that is not already present.
/// * Object + object: merged key by key, recursing where both sides define the key.
/// * Any other combination: the new value replaces the existing one.
fn merge_context_values(
    existing: &serde_json::Value,
    new: &serde_json::Value,
) -> serde_json::Value {
    use serde_json::Value;

    if let (Value::Array(current_items), Value::Array(incoming_items)) = (existing, new) {
        let mut combined = current_items.clone();
        for candidate in incoming_items {
            if !combined.contains(candidate) {
                combined.push(candidate.clone());
            }
        }
        return Value::Array(combined);
    }

    if let (Value::Object(current_fields), Value::Object(incoming_fields)) = (existing, new) {
        let mut combined = current_fields.clone();
        for (field, incoming) in incoming_fields {
            let resolved = match combined.get(field) {
                Some(current) => Self::merge_context_values(current, incoming),
                None => incoming.clone(),
            };
            combined.insert(field.clone(), resolved);
        }
        return Value::Object(combined);
    }

    new.clone()
}
/// Runs `task` with `context`, bounded by `config.task_timeout`.
///
/// # Errors
/// * `ExecutorError::TaskTimeout` when the timeout elapses first.
/// * `ExecutorError::TaskExecution` wrapping the task's own error otherwise.
async fn execute_with_timeout(
    &self,
    task: &dyn Task,
    context: Context<serde_json::Value>,
) -> Result<Context<serde_json::Value>, ExecutorError> {
    tokio::time::timeout(self.config.task_timeout, task.execute(context))
        .await
        .map_err(|_elapsed| ExecutorError::TaskTimeout)?
        .map_err(ExecutorError::TaskExecution)
}
/// Persists the outcome of a finished task attempt.
///
/// On success the result context is saved and the task is marked completed. On failure the
/// task's retry policy is consulted: the attempt is either rescheduled or the task is marked
/// permanently failed.
///
/// # Errors
/// Returns `ExecutorError::TaskNotFound` when the task name cannot be parsed or resolved, and
/// propagates any error from the persistence calls.
async fn handle_task_result(
    &self,
    claimed_task: ClaimedTask,
    result: Result<Context<serde_json::Value>, ExecutorError>,
) -> Result<(), ExecutorError> {
    let failure = match result {
        Ok(result_context) => {
            self.complete_task_transaction(&claimed_task, result_context)
                .await?;
            info!("Task completed successfully: {}", claimed_task.task_name);
            return Ok(());
        }
        Err(failure) => failure,
    };

    // Failure path: the retry policy lives on the task definition, so resolve it first.
    let namespace = parse_namespace(&claimed_task.task_name)
        .map_err(|e| ExecutorError::TaskNotFound(format!("Invalid namespace: {}", e)))?;
    let task = match get_task(&namespace) {
        Some(found) => found,
        None => return Err(ExecutorError::TaskNotFound(claimed_task.task_name.clone())),
    };
    let retry_policy = task.retry_policy();

    let retryable = self
        .should_retry_task(&claimed_task, &failure, &retry_policy)
        .await?;
    if retryable {
        self.schedule_task_retry(&claimed_task, &retry_policy)
            .await?;
        warn!(
            "Task failed, scheduled for retry: {} (attempt {})",
            claimed_task.task_name, claimed_task.attempt
        );
    } else {
        self.mark_task_failed(claimed_task.task_execution_id, &failure)
            .await?;
        error!(
            "Task failed permanently: {} - {}",
            claimed_task.task_name, failure
        );
    }
    Ok(())
}
/// Stores a task's output context and links it to the task via an execution-metadata record.
///
/// # Errors
/// Propagates any error from creating the context or upserting the metadata record.
async fn save_task_context(
    &self,
    claimed_task: &ClaimedTask,
    context: Context<serde_json::Value>,
) -> Result<(), ExecutorError> {
    use crate::models::task_execution_metadata::NewTaskExecutionMetadata;

    let context_id = self.dal.context().create(&context).await?;

    self.dal
        .task_execution_metadata()
        .upsert_task_execution_metadata(NewTaskExecutionMetadata {
            task_execution_id: claimed_task.task_execution_id,
            pipeline_execution_id: claimed_task.pipeline_execution_id,
            task_name: claimed_task.task_name.clone(),
            context_id,
        })
        .await?;

    let stored = context.data();
    info!(
        "Context saved: {} (pipeline: {}, {} keys: {:?}, context_id: {:?})",
        claimed_task.task_name,
        claimed_task.pipeline_execution_id,
        stored.len(),
        stored.keys().collect::<Vec<_>>(),
        context_id
    );
    Ok(())
}
/// Marks a task execution as completed and logs the state transition.
///
/// The record is read before the update so the log line can show the status it had
/// beforehand.
///
/// # Errors
/// Propagates any error from reading or updating the task execution.
async fn mark_task_completed(
    &self,
    task_execution_id: UniversalUuid,
) -> Result<(), ExecutorError> {
    let previous = self
        .dal
        .task_execution()
        .get_by_id(task_execution_id)
        .await?;

    self.dal
        .task_execution()
        .mark_completed(task_execution_id)
        .await?;

    info!(
        "Task state change: {} -> Completed (task: {}, pipeline: {})",
        previous.status, previous.task_name, previous.pipeline_execution_id
    );
    Ok(())
}
/// Saves the task's output context, then marks the task completed.
///
/// The two steps run one after the other; if saving the context fails, the task is not marked
/// completed.
///
/// # Errors
/// Propagates the error of whichever step fails first.
async fn complete_task_transaction(
    &self,
    claimed_task: &ClaimedTask,
    context: Context<serde_json::Value>,
) -> Result<(), ExecutorError> {
    self.save_task_context(claimed_task, context).await?;
    self.mark_task_completed(claimed_task.task_execution_id)
        .await
}
/// Marks a task execution as failed, recording the error text, and logs the state transition.
///
/// The record is read before the update so the log line can show the status it had
/// beforehand.
///
/// # Errors
/// Propagates any error from reading or updating the task execution.
async fn mark_task_failed(
    &self,
    task_execution_id: UniversalUuid,
    error: &ExecutorError,
) -> Result<(), ExecutorError> {
    let previous = self
        .dal
        .task_execution()
        .get_by_id(task_execution_id)
        .await?;

    let reason = error.to_string();
    self.dal
        .task_execution()
        .mark_failed(task_execution_id, &reason)
        .await?;

    error!(
        "Task state change: {} -> Failed (task: {}, pipeline: {}, error: {})",
        previous.status, previous.task_name, previous.pipeline_execution_id, error
    );
    Ok(())
}
/// Decides whether a failed attempt should be retried under `retry_policy`.
///
/// Returns `Ok(false)` once `claimed_task.attempt` has reached `max_attempts`. Otherwise the
/// attempt is retried only if every configured retry condition accepts the error; with no
/// conditions configured, the result is `true`.
///
/// Pattern conditions match case-insensitively against the error's display text.
async fn should_retry_task(
    &self,
    claimed_task: &ClaimedTask,
    error: &ExecutorError,
    retry_policy: &RetryPolicy,
) -> Result<bool, ExecutorError> {
    if claimed_task.attempt >= retry_policy.max_attempts {
        debug!(
            "Task {} exceeded max retry attempts ({}/{})",
            claimed_task.task_name, claimed_task.attempt, retry_policy.max_attempts
        );
        return Ok(false);
    }

    let mut should_retry = true;
    for condition in retry_policy.retry_conditions.iter() {
        let accepted = match condition {
            RetryCondition::Never => false,
            RetryCondition::AllErrors => true,
            RetryCondition::TransientOnly => self.is_transient_error(error),
            RetryCondition::ErrorPattern { patterns } => {
                let haystack = error.to_string().to_lowercase();
                patterns
                    .iter()
                    .any(|needle| haystack.contains(&needle.to_lowercase()))
            }
        };
        if !accepted {
            // One rejecting condition is enough; stop evaluating the rest.
            should_retry = false;
            break;
        }
    }

    debug!(
        "Retry decision for task {}: {} (conditions: {:?}, error: {})",
        claimed_task.task_name, should_retry, retry_policy.retry_conditions, error
    );
    Ok(should_retry)
}
/// Classifies an executor error as transient (worth retrying) or not.
///
/// Timeouts, database errors and connection-pool errors are always transient. A task execution
/// error is transient when its lowercased message contains one of a fixed set of keywords.
/// Every other variant, including `TaskNotFound`, is not transient.
fn is_transient_error(&self, error: &ExecutorError) -> bool {
    const TRANSIENT_MARKERS: [&str; 5] = [
        "timeout",
        "connection",
        "network",
        "temporary",
        "unavailable",
    ];

    match error {
        ExecutorError::TaskTimeout
        | ExecutorError::Database(_)
        | ExecutorError::ConnectionPool(_) => true,
        ExecutorError::TaskExecution(task_error) => {
            let message = task_error.to_string().to_lowercase();
            TRANSIENT_MARKERS
                .iter()
                .any(|marker| message.contains(*marker))
        }
        _ => false,
    }
}
/// Schedules the next attempt of a failed task.
///
/// The delay comes from `retry_policy.calculate_delay` for the current attempt number; the
/// retry time is "now" (UTC) plus that delay, and the stored attempt counter is incremented.
///
/// # Errors
/// Propagates any error from persisting the retry schedule.
async fn schedule_task_retry(
    &self,
    claimed_task: &ClaimedTask,
    retry_policy: &RetryPolicy,
) -> Result<(), ExecutorError> {
    let next_attempt = claimed_task.attempt + 1;
    let retry_delay = retry_policy.calculate_delay(claimed_task.attempt);
    let retry_at = crate::database::UniversalTimestamp(Utc::now() + retry_delay);

    self.dal
        .task_execution()
        .schedule_retry(claimed_task.task_execution_id, retry_at, next_attempt)
        .await?;

    info!(
        "Scheduled retry for task {} in {:?} (attempt {})",
        claimed_task.task_name, retry_delay, next_attempt
    );
    Ok(())
}
}
/// Cloning shares the task registry and the semaphore (both `Arc`s) with the original, while
/// the two counters are copied by value: the clone starts from the original's current totals
/// and counts independently from then on.
impl Clone for ThreadTaskExecutor {
    fn clone(&self) -> Self {
        let executed_snapshot = self.total_executed.load(Ordering::SeqCst);
        let failed_snapshot = self.total_failed.load(Ordering::SeqCst);

        Self {
            database: self.database.clone(),
            dal: self.dal.clone(),
            task_registry: Arc::clone(&self.task_registry),
            instance_id: self.instance_id,
            config: self.config.clone(),
            semaphore: Arc::clone(&self.semaphore),
            total_executed: AtomicU64::new(executed_snapshot),
            total_failed: AtomicU64::new(failed_snapshot),
        }
    }
}
#[async_trait]
impl TaskExecutor for ThreadTaskExecutor {
/// Runs one dispatched task attempt and reports the outcome as an `ExecutionResult`.
///
/// Waits for a permit from the shared semaphore, resolves the task by name, builds its input
/// context, executes it under the configured timeout, and then either persists the result
/// context or decides whether to schedule a retry.
///
/// # Errors
/// Returns `DispatchError::ExecutorNotFound` only when the semaphore has been closed. Every
/// other failure (bad namespace, unknown task, context or persistence errors, task errors) is
/// reported as `Ok` carrying a failure or retry `ExecutionResult`.
async fn execute(&self, event: TaskReadyEvent) -> Result<ExecutionResult, DispatchError> {
let start = Instant::now();
// Block until a concurrency slot is free. The owned permit is moved into one of the two
// execution branches below, which determines how long the slot stays occupied.
let permit = self
.semaphore
.clone()
.acquire_owned()
.await
.map_err(|_| DispatchError::ExecutorNotFound("semaphore closed".into()))?;
let claimed_task = ClaimedTask {
task_execution_id: event.task_execution_id,
pipeline_execution_id: event.pipeline_execution_id,
task_name: event.task_name.clone(),
attempt: event.attempt,
};
// Resolve the task definition. Failures here are counted and returned as failure results
// rather than as dispatch errors.
let namespace = match parse_namespace(&claimed_task.task_name) {
Ok(ns) => ns,
Err(e) => {
self.total_failed.fetch_add(1, Ordering::SeqCst);
return Ok(ExecutionResult::failure(
event.task_execution_id,
format!("Invalid namespace: {}", e),
start.elapsed(),
));
}
};
let task = match get_task(&namespace) {
Some(t) => t,
None => {
self.total_failed.fetch_add(1, Ordering::SeqCst);
return Ok(ExecutionResult::failure(
event.task_execution_id,
format!("Task not found: {}", claimed_task.task_name),
start.elapsed(),
));
}
};
let dependencies = task.dependencies();
let context = match self.build_task_context(&claimed_task, dependencies).await {
Ok(ctx) => ctx,
Err(e) => {
self.total_failed.fetch_add(1, Ordering::SeqCst);
return Ok(ExecutionResult::failure(
event.task_execution_id,
format!("Context build failed: {}", e),
start.elapsed(),
));
}
};
let execution_result = if task.requires_handle() {
// Handle-aware task: ownership of the permit moves into a SlotToken, which is wrapped in a
// TaskHandle and made available to the task via `with_task_handle`.
let slot_token = SlotToken::new(permit, self.semaphore.clone());
let handle =
TaskHandle::with_dal(slot_token, event.task_execution_id, self.dal.clone());
// Sub-status updates are best-effort: a failure is logged and execution continues.
if let Err(e) = self
.dal
.task_execution()
.set_sub_status(event.task_execution_id, Some("Active"))
.await
{
tracing::warn!(
task_execution_id = %event.task_execution_id,
error = %e,
"Failed to set initial sub_status to Active"
);
}
// `_returned_handle` stays bound until the end of this branch, i.e. until after the
// sub_status has been cleared below.
let (result, _returned_handle) =
with_task_handle(handle, self.execute_with_timeout(task.as_ref(), context)).await;
if let Err(e) = self
.dal
.task_execution()
.set_sub_status(event.task_execution_id, None)
.await
{
tracing::warn!(
task_execution_id = %event.task_execution_id,
error = %e,
"Failed to clear sub_status after execution"
);
}
result
} else {
// Plain task: keep the permit alive for exactly the duration of the execution; it is
// dropped when this branch ends.
let _permit = permit;
self.execute_with_timeout(task.as_ref(), context).await
};
// Result persistence and retry bookkeeping below run after the branch above has ended.
let duration = start.elapsed();
match execution_result {
Ok(result_context) => {
// Persist the output context and mark the task completed; only then count it as executed.
match self
.complete_task_transaction(&claimed_task, result_context)
.await
{
Ok(_) => {
self.total_executed.fetch_add(1, Ordering::SeqCst);
info!(
task_id = %event.task_execution_id,
task_name = %event.task_name,
duration_ms = duration.as_millis(),
"Task executed successfully via dispatcher"
);
Ok(ExecutionResult::success(event.task_execution_id, duration))
}
Err(e) => {
// The task itself succeeded but its result could not be persisted: report a failure.
self.total_failed.fetch_add(1, Ordering::SeqCst);
Ok(ExecutionResult::failure(
event.task_execution_id,
format!("Failed to save context: {}", e),
duration,
))
}
}
}
Err(error) => {
let retry_policy = task.retry_policy();
// An error while evaluating the retry decision is treated as "do not retry".
let should_retry = self
.should_retry_task(&claimed_task, &error, &retry_policy)
.await
.unwrap_or(false);
if should_retry {
// NOTE(review): a retry result is returned even when scheduling the retry fails (only a
// warning is logged) — confirm the caller can recover from that.
if let Err(e) = self.schedule_task_retry(&claimed_task, &retry_policy).await {
warn!(
task_id = %event.task_execution_id,
error = %e,
"Failed to schedule retry"
);
}
// NOTE(review): a retried attempt is counted under total_executed, not total_failed —
// confirm this is the intended metric semantics.
self.total_executed.fetch_add(1, Ordering::SeqCst);
Ok(ExecutionResult::retry(
event.task_execution_id,
error.to_string(),
duration,
))
} else {
// NOTE(review): this path does not call mark_task_failed; presumably the caller records
// the failure based on the returned result — verify.
self.total_failed.fetch_add(1, Ordering::SeqCst);
Ok(ExecutionResult::failure(
event.task_execution_id,
error.to_string(),
duration,
))
}
}
}
}
/// Reports whether at least one semaphore permit is currently available.
fn has_capacity(&self) -> bool {
    let free_permits = self.semaphore.available_permits();
    free_permits != 0
}
/// Returns a snapshot of the executor's counters and current load.
///
/// `active_tasks` is derived as the configured maximum minus the currently available permits
/// (saturating at zero). `avg_duration_ms` is always reported as 0.
fn metrics(&self) -> ExecutorMetrics {
    let max_concurrent = self.config.max_concurrent_tasks;
    let free_permits = self.semaphore.available_permits();

    ExecutorMetrics {
        active_tasks: max_concurrent.saturating_sub(free_permits),
        max_concurrent,
        total_executed: self.total_executed.load(Ordering::SeqCst),
        total_failed: self.total_failed.load(Ordering::SeqCst),
        avg_duration_ms: 0,
    }
}
/// Returns the fixed display name of this executor implementation.
fn name(&self) -> &str {
    "ThreadTaskExecutor"
}
}