use std::collections::HashMap;
use std::future::Future;
use std::sync::Arc;
use std::time::Instant;
use rustvello_core::broker::Broker;
use rustvello_core::context::{InvocationContext, RunnerContext};
use rustvello_core::error::{RustvelloError, RustvelloResult, TaskError};
use rustvello_core::middleware::TaskMiddleware;
use rustvello_core::observability::{EventEmitter, LastResult, WorkerState};
use rustvello_core::orchestrator::Orchestrator;
use rustvello_core::state_backend::StateBackend;
use rustvello_core::task::{DynTask, TaskRegistry};
use rustvello_core::trigger::TriggerManager;
use rustvello_proto::call::SerializedArguments;
use rustvello_proto::identifiers::{InvocationId, RunnerId};
use rustvello_proto::invocation::{InvocationHistory, WorkflowIdentity};
use rustvello_proto::status::{ConcurrencyControlType, InvocationStatus, InvocationStatusRecord};
use rustvello_proto::trigger::{ExceptionContext, ResultContext, StatusContext};
/// Converts a caught panic payload into a [`RustvelloError::Internal`].
///
/// The payload is probed first as a `&str` and then as a `String`; any other
/// payload type is reported as `"unknown panic"`. The resulting message is
/// always prefixed with `task panicked: `.
pub(crate) fn unwrap_panic(panic: Box<dyn std::any::Any + Send>) -> RustvelloError {
    let description = panic
        .downcast_ref::<&str>()
        .map(|text| (*text).to_string())
        .or_else(|| panic.downcast_ref::<String>().cloned())
        .unwrap_or_else(|| "unknown panic".to_string());

    RustvelloError::Internal {
        message: format!("task panicked: {description}"),
    }
}
/// Bundle of shared collaborators that `execute_invocation_common` needs to
/// run a single invocation.
pub(crate) struct ExecutionDeps {
/// Drives invocation status transitions, waiter release and the
/// concurrency-control index.
pub orchestrator: Arc<dyn Orchestrator>,
/// Source of invocation/call records and history; sink for results, errors
/// and new history entries.
pub state_backend: Arc<dyn StateBackend>,
/// Used to re-route an invocation when it is scheduled for a retry.
pub broker: Arc<dyn Broker>,
/// Receives task started / succeeded / retried / failed notifications.
pub emitter: Arc<dyn EventEmitter>,
/// Middlewares whose `before` hooks run in order ahead of the task and whose
/// `after` hooks run in reverse order once it has finished.
pub middlewares: Vec<Arc<dyn TaskMiddleware>>,
/// Lookup table from task id to the registered task implementation.
pub task_registry: Arc<TaskRegistry>,
/// Optional trigger manager; when present it is told about results,
/// failures and status changes.
pub trigger_manager: Option<Arc<TriggerManager>>,
/// Optional per-runner state map, updated with the invocation currently
/// being executed and the outcome of the last one.
pub worker_states: Option<Arc<std::sync::Mutex<HashMap<RunnerId, WorkerState>>>>,
}
/// Value passed to `Orchestrator::get_blocking_invocations` when looking for
/// invocations to prioritise (presumably an upper bound on how many
/// candidates are returned — confirm against the trait definition).
const MAX_BLOCKING_CANDIDATES: usize = 8;
/// Maximum number of invocations pulled from the broker in one call to
/// `retrieve_next_invocation_with_cc` before giving up and returning `None`.
const MAX_CC_RETRIES: usize = 8;
/// Picks the next invocation a runner should execute, honouring concurrency
/// control.
///
/// Invocations reported by the orchestrator as blocking are tried first; the
/// first one that passes the concurrency check is returned. If none pass (or
/// the orchestrator lookup fails, which is only logged), up to
/// `MAX_CC_RETRIES` invocations are pulled from the broker and the first one
/// that passes the check is returned.
///
/// Returns `Ok(None)` when the broker has nothing to offer or every attempt
/// was denied by concurrency control.
///
/// # Errors
///
/// Propagates errors from the broker and from the concurrency check itself.
pub(crate) async fn retrieve_next_invocation_with_cc(
    orchestrator: &dyn Orchestrator,
    broker: &dyn Broker,
    state_backend: Option<&dyn StateBackend>,
    task_registry: Option<&TaskRegistry>,
) -> RustvelloResult<Option<InvocationId>> {
    // Phase 1: invocations that other invocations are waiting on.
    let blocking_lookup = orchestrator
        .get_blocking_invocations(MAX_BLOCKING_CANDIDATES)
        .await;
    match blocking_lookup {
        Ok(candidates) => {
            // An empty list simply skips the loop.
            for candidate in &candidates {
                let admitted = check_cc_for_candidate(
                    orchestrator,
                    broker,
                    state_backend,
                    task_registry,
                    candidate,
                )
                .await?;
                if admitted {
                    tracing::debug!(
                        "Prioritizing blocking invocation {} (has waiters)",
                        candidate
                    );
                    return Ok(Some(candidate.clone()));
                }
            }
        }
        Err(e) => {
            // Best effort: a failed lookup must not stop normal broker polling.
            tracing::warn!(
                "get_blocking_invocations failed, falling back to broker: {}",
                e
            );
        }
    }

    // Phase 2: regular broker queue, bounded so a run of denied invocations
    // cannot keep this call spinning.
    for _attempt in 0..MAX_CC_RETRIES {
        let Some(candidate) = broker.retrieve_invocation(None).await? else {
            return Ok(None);
        };
        let admitted = check_cc_for_candidate(
            orchestrator,
            broker,
            state_backend,
            task_registry,
            &candidate,
        )
        .await?;
        if admitted {
            return Ok(Some(candidate));
        }
    }

    Ok(None)
}
/// Decides whether `invocation_id` may run now under its task's concurrency
/// control settings.
///
/// Returns `Ok(true)` when the invocation may proceed. The check is skipped
/// (and `true` returned) when no state backend or task registry is supplied,
/// when the invocation, its task or its call cannot be looked up, or when the
/// task's concurrency control is `Unlimited`.
///
/// Returns `Ok(false)` when the orchestrator denies the invocation. In that
/// case the invocation is either marked `ConcurrencyControlled`, then
/// `Rerouted`, and handed back to the broker (`reroute_on_cc`), or marked
/// `ConcurrencyControlledFinal`. An `InvalidStatusTransition` on that first
/// status change is ignored.
///
/// # Errors
///
/// Propagates orchestrator and broker errors other than the ignored
/// `InvalidStatusTransition`.
async fn check_cc_for_candidate(
    orchestrator: &dyn Orchestrator,
    broker: &dyn Broker,
    state_backend: Option<&dyn StateBackend>,
    task_registry: Option<&TaskRegistry>,
    invocation_id: &InvocationId,
) -> RustvelloResult<bool> {
    // Without both collaborators there is nothing to check against.
    let (sb, tr) = match (state_backend, task_registry) {
        (Some(sb), Some(tr)) => (sb, tr),
        _ => return Ok(true),
    };

    // Any lookup failure below lets the invocation through unchecked.
    let Ok(invocation) = sb.get_invocation(invocation_id).await else {
        return Ok(true);
    };
    let Some(task) = tr.get_dyn(&invocation.task_id) else {
        return Ok(true);
    };
    let config = task.config();
    if config.concurrency_control == ConcurrencyControlType::Unlimited {
        return Ok(true);
    }
    let Ok(call) = sb.get_call(&invocation.call_id).await else {
        return Ok(true);
    };

    let cc_args = compute_cc_args(config, &call.serialized_arguments);
    let may_run = orchestrator
        .check_running_concurrency(&invocation.task_id, config, cc_args.as_ref())
        .await?;
    if may_run {
        return Ok(true);
    }

    tracing::debug!(
        "Concurrency control denied invocation {} for task {}",
        invocation_id,
        invocation.task_id
    );

    let reroute = config.reroute_on_cc;
    let denial_status = if reroute {
        InvocationStatus::ConcurrencyControlled
    } else {
        InvocationStatus::ConcurrencyControlledFinal
    };

    match orchestrator
        .set_invocation_status(invocation_id, denial_status, None)
        .await
    {
        Ok(_) if reroute => {
            orchestrator
                .set_invocation_status(invocation_id, InvocationStatus::Rerouted, None)
                .await?;
            broker
                .route_invocation_for_task(invocation_id, &invocation.task_id)
                .await?;
            tracing::info!(
                "Rerouted CC-denied invocation {} back to broker",
                invocation_id
            );
        }
        Ok(_) => {
            tracing::info!(
                "Permanently rejected CC-denied invocation {}",
                invocation_id
            );
        }
        // The status could not be changed from its current value; leave the
        // invocation alone and still report it as denied.
        Err(RustvelloError::InvalidStatusTransition { .. }) => {}
        Err(e) => return Err(e),
    }

    Ok(false)
}
/// Derives the argument set used as the concurrency-control key for a call.
///
/// * `Unlimited` — `None`.
/// * `Task` — an empty argument set.
/// * `Argument` with a non-empty `key_arguments` list — only the listed
///   arguments that are actually present in `args`.
/// * Everything else (`Argument` without key arguments, `None`, and any other
///   variant) — a full copy of `args`.
pub(crate) fn compute_cc_args(
    config: &rustvello_proto::config::TaskConfig,
    args: &SerializedArguments,
) -> Option<SerializedArguments> {
    match config.concurrency_control {
        ConcurrencyControlType::Unlimited => None,
        ConcurrencyControlType::Task => Some(SerializedArguments::new()),
        ConcurrencyControlType::Argument if !config.key_arguments.is_empty() => {
            let mut key_subset = SerializedArguments::new();
            for name in &config.key_arguments {
                // Keys missing from the call are skipped, not treated as errors.
                if let Some(value) = args.0.get(name) {
                    key_subset.insert(name, value.clone());
                }
            }
            Some(key_subset)
        }
        _ => Some(args.clone()),
    }
}
pub(crate) async fn execute_invocation_common<F, Fut>(
deps: &ExecutionDeps,
invocation_id: &InvocationId,
worker_runner_id: &RunnerId,
runner_label: &str,
worker_ctx: &RunnerContext,
execute_task: F,
) -> RustvelloResult<()>
where
F: FnOnce(Arc<dyn DynTask>, SerializedArguments, InvocationContext, RunnerContext) -> Fut,
Fut: Future<Output = RustvelloResult<String>> + Send,
{
match deps
.orchestrator
.set_invocation_status(
invocation_id,
InvocationStatus::Pending,
Some(worker_runner_id),
)
.await
{
Ok(_) => {
deps.state_backend
.add_history(
&InvocationHistory::new(
invocation_id.clone(),
InvocationStatusRecord::new(
InvocationStatus::Pending,
Some(worker_runner_id.clone()),
),
Some(format!("{runner_label} claimed invocation")),
)
.with_runner(worker_runner_id.clone()),
)
.await?;
}
Err(RustvelloError::InvalidStatusTransition {
from_status,
to_status,
..
}) => {
tracing::warn!(
"Already claimed (race): from_status:{} to_status:{} skipped",
from_status,
to_status
);
return Ok(());
}
Err(RustvelloError::OwnershipViolation { .. }) => {
tracing::warn!("Already owned by another runner");
return Ok(());
}
Err(e) => return Err(e),
}
if let Ok(inv_dto) = deps.state_backend.get_invocation(invocation_id).await {
if let Some(task) = deps.task_registry.get_dyn(&inv_dto.task_id) {
let config = task.config();
if config.concurrency_control != ConcurrencyControlType::Unlimited {
if let Ok(call_dto) = deps.state_backend.get_call(&inv_dto.call_id).await {
let cc_args = compute_cc_args(config, &call_dto.serialized_arguments);
if let Err(e) = deps
.orchestrator
.index_for_concurrency_control(
invocation_id,
&inv_dto.task_id,
cc_args.as_ref(),
)
.await
{
tracing::warn!("Failed to index for CC: {}", e);
}
}
}
}
}
deps.orchestrator
.set_invocation_status(
invocation_id,
InvocationStatus::Running,
Some(worker_runner_id),
)
.await?;
deps.state_backend
.add_history(
&InvocationHistory::new(
invocation_id.clone(),
InvocationStatusRecord::new(
InvocationStatus::Running,
Some(worker_runner_id.clone()),
),
Some(format!("{runner_label} executing")),
)
.with_runner(worker_runner_id.clone()),
)
.await?;
let inv_dto = deps.state_backend.get_invocation(invocation_id).await?;
let call_dto = deps.state_backend.get_call(&inv_dto.call_id).await?;
tracing::Span::current().record("task_id", tracing::field::display(&inv_dto.task_id));
let task = deps
.task_registry
.get_dyn(&inv_dto.task_id)
.ok_or_else(|| RustvelloError::TaskNotRegistered {
task_id: inv_dto.task_id.clone(),
})?;
let retry_history = deps.state_backend.get_history(invocation_id).await?;
let num_retries = retry_history
.iter()
.filter(|h| h.status_record.status == InvocationStatus::Retry)
.count() as u32;
let inv_ctx = InvocationContext {
invocation_id: invocation_id.clone(),
task_id: inv_dto.task_id.clone(),
workflow: inv_dto.workflow.clone().unwrap_or_else(|| {
WorkflowIdentity::root(invocation_id.clone(), inv_dto.task_id.clone())
}),
parent_invocation_id: inv_dto.parent_invocation_id.clone(),
num_retries,
};
let run_ctx = worker_ctx.clone();
deps.emitter
.on_task_started(&inv_dto.task_id, invocation_id);
if let Some(ref ws) = deps.worker_states {
if let Ok(mut ws) = ws.lock() {
if let Some(state) = ws.get_mut(worker_runner_id) {
state.current_invocation = Some(invocation_id.clone());
state.current_task = Some(inv_dto.task_id.clone());
state.started_at = Some(Instant::now());
}
}
}
let exec_start = Instant::now();
for mw in &deps.middlewares {
mw.before(invocation_id, &inv_dto.task_id).await?;
}
let exec_result = execute_task(
Arc::clone(&task),
call_dto.serialized_arguments.clone(),
inv_ctx,
run_ctx,
)
.await;
for mw in deps.middlewares.iter().rev() {
if let Err(e) = mw
.after(invocation_id, &inv_dto.task_id, &exec_result)
.await
{
tracing::warn!("After-middleware failed: {}", e);
}
}
match exec_result {
Ok(result) => {
deps.state_backend
.store_result(invocation_id, &result)
.await?;
deps.orchestrator
.set_invocation_status(
invocation_id,
InvocationStatus::Success,
Some(worker_runner_id),
)
.await?;
deps.state_backend
.add_history(
&InvocationHistory::new(
invocation_id.clone(),
InvocationStatusRecord::new(
InvocationStatus::Success,
Some(worker_runner_id.clone()),
),
None,
)
.with_runner(worker_runner_id.clone()),
)
.await?;
deps.orchestrator.release_waiters(invocation_id).await?;
if let Err(e) = deps
.orchestrator
.remove_from_concurrency_index(invocation_id)
.await
{
tracing::warn!("Failed to remove from CC index: {}", e);
}
if let Some(ref tm) = deps.trigger_manager {
let result_ctx = ResultContext {
invocation_id: invocation_id.clone(),
task_id: inv_dto.task_id.clone(),
result: serde_json::Value::String(result.clone()),
arguments: std::collections::BTreeMap::new(),
};
if let Err(e) = tm.report_result(&result_ctx).await {
tracing::warn!("Trigger report_result failed: {}", e);
}
let status_ctx = StatusContext {
invocation_id: invocation_id.clone(),
task_id: inv_dto.task_id.clone(),
status: InvocationStatus::Success,
arguments: std::collections::BTreeMap::new(),
};
if let Err(e) = tm.report_status_change(&status_ctx).await {
tracing::warn!("Trigger report_status_change failed: {}", e);
}
}
tracing::info!("Invocation completed status:success");
let exec_duration = exec_start.elapsed();
deps.emitter
.on_task_succeeded(&inv_dto.task_id, invocation_id, exec_duration);
if let Some(ref ws) = deps.worker_states {
if let Ok(mut ws) = ws.lock() {
if let Some(state) = ws.get_mut(worker_runner_id) {
state.current_invocation = None;
state.current_task = None;
state.started_at = None;
state.last_result = Some(LastResult::Success {
task_id: inv_dto.task_id.clone(),
duration: exec_duration,
});
state.invocations_completed += 1;
}
}
}
}
Err(err) => {
let task_error = match &err {
RustvelloError::TaskExecution {
error_type,
message,
traceback,
} => TaskError {
error_type: error_type.clone(),
message: message.clone(),
traceback: traceback.clone(),
},
_ => TaskError {
error_type: "TaskExecutionError".to_string(),
message: err.to_string(),
traceback: None,
},
};
let retry_count = num_retries;
let max_retries = task.config().max_retries;
let retry_for_errors = &task.config().retry_for_errors;
let should_retry = retry_count < max_retries
&& (retry_for_errors.is_empty()
|| retry_for_errors
.iter()
.any(|e| task_error.error_type.contains(e.as_str())));
if should_retry {
deps.orchestrator
.set_invocation_status(
invocation_id,
InvocationStatus::Retry,
Some(worker_runner_id),
)
.await?;
deps.broker.route_invocation(invocation_id).await?;
deps.state_backend
.add_history(
&InvocationHistory::new(
invocation_id.clone(),
InvocationStatusRecord::new(
InvocationStatus::Retry,
Some(worker_runner_id.clone()),
),
Some(format!(
"Retry {}/{}: {}",
retry_count + 1,
max_retries,
err
)),
)
.with_runner(worker_runner_id.clone()),
)
.await?;
tracing::warn!("Failed status:retry {}/{}", retry_count + 1, max_retries);
deps.emitter
.on_task_retried(&inv_dto.task_id, invocation_id, retry_count + 1);
} else {
deps.state_backend
.store_error(invocation_id, &task_error)
.await?;
deps.orchestrator
.set_invocation_status(
invocation_id,
InvocationStatus::Failed,
Some(worker_runner_id),
)
.await?;
deps.state_backend
.add_history(
&InvocationHistory::new(
invocation_id.clone(),
InvocationStatusRecord::new(
InvocationStatus::Failed,
Some(worker_runner_id.clone()),
),
Some(format!("Failed: {}", err)),
)
.with_runner(worker_runner_id.clone()),
)
.await?;
if let Some(ref tm) = deps.trigger_manager {
let exc_ctx = ExceptionContext {
invocation_id: invocation_id.clone(),
task_id: inv_dto.task_id.clone(),
error_type: task_error.error_type.clone(),
error_message: task_error.message.clone(),
arguments: std::collections::BTreeMap::new(),
};
if let Err(e) = tm.report_failure(&exc_ctx).await {
tracing::warn!("Trigger report_failure failed: {}", e);
}
let status_ctx = StatusContext {
invocation_id: invocation_id.clone(),
task_id: inv_dto.task_id.clone(),
status: InvocationStatus::Failed,
arguments: std::collections::BTreeMap::new(),
};
if let Err(e) = tm.report_status_change(&status_ctx).await {
tracing::warn!("Trigger report_status_change failed: {}", e);
}
}
tracing::error!("Invocation status:failed permanently: {}", err);
if let Err(e) = deps
.orchestrator
.remove_from_concurrency_index(invocation_id)
.await
{
tracing::warn!("Failed to remove from CC index: {}", e);
}
let exec_duration = exec_start.elapsed();
deps.emitter.on_task_failed(
&inv_dto.task_id,
invocation_id,
&err.to_string(),
exec_duration,
);
if let Some(ref ws) = deps.worker_states {
if let Ok(mut ws) = ws.lock() {
if let Some(state) = ws.get_mut(worker_runner_id) {
state.current_invocation = None;
state.current_task = None;
state.started_at = None;
state.last_result = Some(LastResult::Failed {
task_id: inv_dto.task_id.clone(),
error: err.to_string(),
});
state.invocations_completed += 1;
}
}
}
}
}
}
Ok(())
}