use async_trait::async_trait;
use chrono::{DateTime, Utc};
use tracing::instrument;
use rustvello_proto::call::{CallDTO, SerializedArguments};
use rustvello_proto::config::TaskConfig;
use rustvello_proto::identifiers::{CallId, InvocationId, RunnerId, TaskId};
use rustvello_proto::status::{InvocationStatus, InvocationStatusRecord};
use crate::error::RustvelloResult;
/// One recorded execution window of the atomic service, attributed to a runner.
///
/// Values of this type are what `OrchestratorRecovery::get_atomic_service_timeline`
/// returns. NOTE(review): presumably each entry mirrors a prior call to
/// `record_atomic_service_execution` — confirm against the backends.
#[derive(Debug, Clone)]
pub struct AtomicServiceExecution {
/// Identifier of the runner, kept as a plain `String` here (unlike
/// `ActiveRunnerInfo::runner_id`, which is a typed `RunnerId`).
pub runner_id: String,
/// UTC timestamp at which the execution window begins.
pub start: DateTime<Utc>,
/// UTC timestamp at which the execution window ends.
pub end: DateTime<Utc>,
}
impl AtomicServiceExecution {
    /// Returns the length of the execution window (`end - start`) in seconds.
    ///
    /// The difference is measured in whole milliseconds and then scaled to
    /// seconds, so any sub-millisecond remainder is discarded. The value is
    /// negative when `end` precedes `start`.
    pub fn duration_secs(&self) -> f64 {
        let elapsed = self.end.signed_duration_since(self.start);
        let millis = elapsed.num_milliseconds();
        millis as f64 / 1000.0
    }
}
/// Descriptive record for a runner, as returned by
/// `OrchestratorRecovery::get_active_runners`.
#[derive(Debug, Clone)]
pub struct ActiveRunnerInfo {
/// Typed identifier of the runner.
pub runner_id: RunnerId,
/// UTC timestamp of the runner's creation.
/// NOTE(review): presumably the time of its first registration — confirm.
pub creation_time: DateTime<Utc>,
/// UTC timestamp of the most recent heartbeat.
/// NOTE(review): presumably refreshed by `register_heartbeat` — confirm.
pub last_heartbeat: DateTime<Utc>,
/// Whether this runner is allowed to run the atomic service (the same flag
/// name is accepted by `register_heartbeat`).
pub can_run_atomic_service: bool,
/// UTC start of the runner's latest atomic-service run, if any.
/// NOTE(review): `None` presumably means it has never started one — confirm.
pub last_service_start: Option<DateTime<Utc>>,
/// UTC end of the runner's latest atomic-service run, if any.
/// NOTE(review): `None` presumably means no run has finished yet — confirm.
pub last_service_end: Option<DateTime<Utc>>,
}
/// Umbrella trait bundling every orchestrator capability trait.
///
/// It declares no methods of its own; it only requires `OrchestratorStatus`,
/// `OrchestratorConcurrency`, `OrchestratorBlocking`, `OrchestratorQuery` and
/// `OrchestratorRecovery`, so code can name a single bound instead of five.
/// A blanket impl in this file provides it for every type that implements all
/// five sub-traits, so it never needs to be implemented by hand.
pub trait Orchestrator:
OrchestratorStatus
+ OrchestratorConcurrency
+ OrchestratorBlocking
+ OrchestratorQuery
+ OrchestratorRecovery
{
}
/// Blanket implementation: any type that implements all five orchestrator
/// sub-traits is automatically an `Orchestrator`.
impl<T> Orchestrator for T
where
    T: OrchestratorStatus
        + OrchestratorConcurrency
        + OrchestratorBlocking
        + OrchestratorQuery
        + OrchestratorRecovery,
{
}
/// Invocation registration, status tracking, retry counting and purging.
///
/// Implementors must be `Send + Sync`. The per-method descriptions below state
/// the contract suggested by each signature; the precise semantics are defined
/// by the concrete backend and should be confirmed there.
#[async_trait]
pub trait OrchestratorStatus: Send + Sync {
    /// Registers an invocation for `call` and returns an `InvocationId` for it.
    async fn register_invocation(&self, call: &CallDTO) -> RustvelloResult<InvocationId>;

    /// Registers an invocation for `call` under the caller-supplied
    /// `invocation_id`, optionally attributed to `runner_id`, and returns a
    /// status record for it.
    async fn register_invocation_with_id(
        &self,
        invocation_id: &InvocationId,
        call: &CallDTO,
        runner_id: Option<&RunnerId>,
    ) -> RustvelloResult<InvocationStatusRecord>;

    /// Increments the retry counter of `invocation_id`.
    ///
    /// NOTE(review): the returned `u32` is presumably the updated count —
    /// confirm against the backends.
    async fn increment_invocation_retries(
        &self,
        invocation_id: &InvocationId,
    ) -> RustvelloResult<u32>;

    /// Returns the retry counter of `invocation_id`.
    async fn get_invocation_retries(&self, invocation_id: &InvocationId) -> RustvelloResult<u32>;

    /// Removes `invocation_id` from the orchestrator.
    async fn remove_invocation(&self, invocation_id: &InvocationId) -> RustvelloResult<()>;

    /// Returns the status record currently stored for `invocation_id`.
    async fn get_invocation_status(
        &self,
        invocation_id: &InvocationId,
    ) -> RustvelloResult<InvocationStatusRecord>;

    /// Sets the status of `invocation_id` to `status`, optionally on behalf of
    /// `runner_id`, and returns a status record.
    ///
    /// NOTE(review): presumably the returned record reflects the state after
    /// the update — confirm against the backends.
    async fn set_invocation_status(
        &self,
        invocation_id: &InvocationId,
        status: InvocationStatus,
        runner_id: Option<&RunnerId>,
    ) -> RustvelloResult<InvocationStatusRecord>;

    /// Static name of the backend; the default is `"Unknown"`.
    fn backend_name(&self) -> &'static str {
        "Unknown"
    }

    /// Backend usage statistics as `(label, value)` pairs.
    ///
    /// The default implementation reports nothing (an empty vector).
    async fn usage_stats(&self) -> Vec<(&'static str, String)> {
        vec![]
    }

    /// Purges orchestrator data.
    ///
    /// NOTE(review): presumably wipes all state held by the backend — confirm
    /// the scope before calling outside of tests.
    async fn purge(&self) -> RustvelloResult<()>;

    /// Schedules `invocation_id` for automatic purging.
    async fn schedule_auto_purge(&self, invocation_id: &InvocationId) -> RustvelloResult<()>;

    /// Runs the automatic purge using `max_age_secs` as the age threshold.
    ///
    /// NOTE(review): the returned IDs are presumably the invocations that were
    /// purged — confirm against the backends.
    async fn run_auto_purge(&self, max_age_secs: u64) -> RustvelloResult<Vec<InvocationId>>;
}
/// Concurrency-control hooks consulted before an invocation is allowed to run.
///
/// Implementors must be `Send + Sync`.
#[async_trait]
pub trait OrchestratorConcurrency: Send + Sync {
    /// Checks the running-concurrency constraints for `task_id` under
    /// `task_config`, optionally keyed by `cc_args`.
    ///
    /// `try_acquire_concurrency_slot` treats `Ok(true)` as "a slot is
    /// available" and `Ok(false)` as "do not run now".
    async fn check_running_concurrency(
        &self,
        task_id: &TaskId,
        task_config: &TaskConfig,
        cc_args: Option<&SerializedArguments>,
    ) -> RustvelloResult<bool>;

    /// Adds `invocation_id` to the concurrency-control index for `task_id`,
    /// optionally keyed by `cc_args`.
    async fn index_for_concurrency_control(
        &self,
        invocation_id: &InvocationId,
        task_id: &TaskId,
        cc_args: Option<&SerializedArguments>,
    ) -> RustvelloResult<()>;

    /// Removes `invocation_id` from the concurrency-control index.
    async fn remove_from_concurrency_index(
        &self,
        invocation_id: &InvocationId,
    ) -> RustvelloResult<()>;

    /// Tries to claim a concurrency slot for `invocation_id`.
    ///
    /// Returns `Ok(true)` when the concurrency check passed and the invocation
    /// was indexed, `Ok(false)` when the check refused it (nothing is indexed
    /// in that case). Errors from either step are propagated.
    ///
    /// This default performs the check and the indexing as two separate
    /// awaited calls, so nothing here makes the pair atomic.
    /// NOTE(review): backends that need an atomic check-and-index presumably
    /// override this method — confirm.
    #[instrument(skip(self, task_config, cc_args), fields(%invocation_id, %task_id))]
    async fn try_acquire_concurrency_slot(
        &self,
        invocation_id: &InvocationId,
        task_id: &TaskId,
        task_config: &TaskConfig,
        cc_args: Option<&SerializedArguments>,
    ) -> RustvelloResult<bool> {
        let slot_available = self
            .check_running_concurrency(task_id, task_config, cc_args)
            .await?;
        if !slot_available {
            return Ok(false);
        }

        self.index_for_concurrency_control(invocation_id, task_id, cc_args)
            .await?;
        Ok(true)
    }
}
/// Tracking of "waits-for" relationships between invocations.
///
/// Implementors must be `Send + Sync`. Descriptions below follow the method
/// signatures; exact semantics are defined by the concrete backend.
#[async_trait]
pub trait OrchestratorBlocking: Send + Sync {
/// Records that invocation `waiter` is waiting for invocation `waited_on`.
async fn set_waiting_for(
&self,
waiter: &InvocationId,
waited_on: &InvocationId,
) -> RustvelloResult<()>;
/// Returns the invocations recorded as waiting for `waited_on`.
async fn get_waiters(&self, waited_on: &InvocationId) -> RustvelloResult<Vec<InvocationId>>;
/// Releases the invocations that were waiting for `completed`.
///
/// NOTE(review): the returned IDs are presumably the waiters that were
/// released by this call — confirm against the backends.
async fn release_waiters(&self, completed: &InvocationId)
-> RustvelloResult<Vec<InvocationId>>;
}
#[async_trait]
pub trait OrchestratorQuery: OrchestratorStatus {
async fn get_invocations_by_task(&self, task_id: &TaskId)
-> RustvelloResult<Vec<InvocationId>>;
async fn get_invocations_by_call(&self, call_id: &CallId)
-> RustvelloResult<Vec<InvocationId>>;
async fn get_invocations_by_status(
&self,
status: InvocationStatus,
task_id: Option<&TaskId>,
) -> RustvelloResult<Vec<InvocationId>>;
async fn count_invocations(
&self,
task_id: Option<&TaskId>,
statuses: Option<&[InvocationStatus]>,
) -> RustvelloResult<usize>;
async fn get_invocation_ids_paginated(
&self,
task_id: Option<&TaskId>,
statuses: Option<&[InvocationStatus]>,
limit: usize,
offset: usize,
) -> RustvelloResult<Vec<InvocationId>>;
#[instrument(skip(self, invocation_ids, statuses), fields(count = invocation_ids.len()))]
async fn filter_by_status(
&self,
invocation_ids: &[InvocationId],
statuses: &[InvocationStatus],
) -> RustvelloResult<Vec<InvocationId>> {
let mut result = Vec::new();
for inv_id in invocation_ids {
if let Ok(record) = self.get_invocation_status(inv_id).await {
if statuses.contains(&record.status) {
result.push(inv_id.clone());
}
}
}
Ok(result)
}
async fn get_blocking_invocations(&self, max_num: usize) -> RustvelloResult<Vec<InvocationId>>;
async fn get_existing_invocations(
&self,
task_id: &TaskId,
cc_args: Option<&SerializedArguments>,
statuses: &[InvocationStatus],
) -> RustvelloResult<Vec<InvocationId>>;
}
/// Runner liveness tracking, stale-invocation detection and atomic-service
/// bookkeeping.
///
/// Implementors must be `Send + Sync`. Descriptions below follow the method
/// signatures; exact semantics are defined by the concrete backend.
#[async_trait]
pub trait OrchestratorRecovery: Send + Sync {
/// Records a heartbeat for `runner_id`, together with whether that runner
/// can run the atomic service.
async fn register_heartbeat(
&self,
runner_id: &RunnerId,
can_run_atomic_service: bool,
) -> RustvelloResult<()>;
/// Returns invocations considered stale while pending.
///
/// NOTE(review): presumably those pending for longer than
/// `max_pending_seconds` — confirm against the backends.
async fn get_stale_pending_invocations(
&self,
max_pending_seconds: u64,
) -> RustvelloResult<Vec<InvocationId>>;
/// Returns invocations considered stale while running.
///
/// NOTE(review): presumably those whose runner has sent no heartbeat for
/// `runner_dead_after_seconds` — confirm against the backends.
async fn get_stale_running_invocations(
&self,
runner_dead_after_seconds: u64,
) -> RustvelloResult<Vec<InvocationId>>;
/// Returns the IDs of runners considered active.
///
/// NOTE(review): presumably a runner is active when its last heartbeat is
/// at most `timeout_seconds` old — confirm.
async fn get_active_runner_ids(&self, timeout_seconds: u64) -> RustvelloResult<Vec<RunnerId>>;
/// Returns an `ActiveRunnerInfo` for each runner considered active within
/// `timeout_seconds`.
///
/// NOTE(review): `can_run_atomic_service` is presumably an optional filter
/// on the runner flag of the same name, with `None` meaning no filtering —
/// confirm.
async fn get_active_runners(
&self,
timeout_seconds: u64,
can_run_atomic_service: Option<bool>,
) -> RustvelloResult<Vec<ActiveRunnerInfo>>;
/// Records that `runner_id` ran the atomic service from `start` to `end`
/// (both UTC).
async fn record_atomic_service_execution(
&self,
runner_id: &RunnerId,
start: DateTime<Utc>,
end: DateTime<Utc>,
) -> RustvelloResult<()>;
/// Returns the recorded atomic-service executions.
///
/// NOTE(review): ordering of the returned entries is not specified here —
/// confirm against the backends before relying on it.
async fn get_atomic_service_timeline(&self) -> RustvelloResult<Vec<AtomicServiceExecution>>;
}