use std::{
collections::{HashMap, HashSet},
time::Duration,
};
use claudius::{
Anthropic, Error as AnthropicError, Message, MessageBatchCreateParams,
MessageBatchCreateRequest, MessageBatchProcessingStatus, MessageBatchResult,
MessageBatchResultVariant, MessageCreateParams,
};
use futures::StreamExt;
use serde_json::{Value, json};
use sqlx::{PgPool, Row};
use uuid::Uuid;
use crate::{
CausalRef, ObservabilityConfig, ObservabilityContext, PendingWorkflowEvent, Trampoline,
Workflow, WorkflowError, WorkflowNext, WorkflowResult, WorkflowStepOutcome,
};
// Submodule kept in its own file.
pub mod visualizer;
// Schema SQL embedded at compile time; `migrate` executes it statement by statement.
const MIGRATION_SQL: &str = include_str!("../migrations/0001_batch.sql");
// Workflow status strings as stored in `batch_workflows.status`;
// `WorkflowStatus` converts to and from these.
const WORKFLOW_RUNNABLE: &str = "runnable";
const WORKFLOW_WAITING_ANTHROPIC: &str = "waiting_anthropic";
const WORKFLOW_BLOCKED_HUMAN: &str = "blocked_human";
const WORKFLOW_BLOCKED_OPENAI: &str = "blocked_openai";
const WORKFLOW_WAITING_FORK_JOIN: &str = "waiting_fork_join";
const WORKFLOW_HALTED: &str = "halted";
const WORKFLOW_FAILED: &str = "failed";
// Continuation kind strings as stored in `batch_continuations.kind`;
// decoded by `ContinuationKind::from_db`.
const CONTINUATION_ANTHROPIC: &str = "anthropic";
const CONTINUATION_HUMAN: &str = "human";
const CONTINUATION_OPENAI: &str = "openai";
// Continuation status strings as stored in `batch_continuations.status`;
// decoded by `ContinuationStatus::from_db`.
const CONTINUATION_PENDING: &str = "pending";
const CONTINUATION_SUBMITTED: &str = "submitted";
const CONTINUATION_BLOCKED: &str = "blocked";
const CONTINUATION_SUCCEEDED: &str = "succeeded";
const CONTINUATION_FAILED: &str = "failed";
const CONTINUATION_RESUMED: &str = "resumed";
// Provider batch state strings.
// NOTE(review): not referenced in the part of the file reviewed here; presumably
// they label provider batch rows — confirm at the usage sites.
const PROVIDER_BATCH_IN_PROGRESS: &str = "in_progress";
const PROVIDER_BATCH_CANCELING: &str = "canceling";
const PROVIDER_BATCH_ENDED: &str = "ended";
// Fork/join state strings.
// NOTE(review): not referenced in the part of the file reviewed here — confirm
// at the usage sites.
const FORK_JOIN_WAITING: &str = "waiting";
const FORK_JOIN_JOINED: &str = "joined";
const FORK_JOIN_FAILED: &str = "failed";
/// Result alias used throughout this module; the error type is `handled::SError`.
pub type Result<T> = std::result::Result<T, handled::SError>;
/// Applies the embedded batch schema to the database behind `pool`.
///
/// `MIGRATION_SQL` is cut at every `;`, each piece is trimmed, blank pieces are
/// dropped, and the remaining statements are executed one at a time, in file
/// order, directly on the pool. The first failing statement stops the run and
/// is reported through `sqlx_error` under the `batch-migrate` tag.
///
/// NOTE(review): the split is purely textual, so a `;` inside a string literal
/// or a function body in the migration file would be cut mid-statement —
/// confirm the migration file contains none.
pub async fn migrate(pool: &PgPool) -> Result<()> {
    let statements = MIGRATION_SQL
        .split(';')
        .map(str::trim)
        .filter(|sql| !sql.is_empty());
    for sql in statements {
        sqlx::query(sql)
            .execute(pool)
            .await
            .map_err(|err| sqlx_error("batch-migrate", err))?;
    }
    Ok(())
}
/// Tunables for the batch executor.
#[derive(Clone, Debug)]
pub struct Config {
    /// How long `Executor::run` sleeps between polls when nothing is immediately runnable.
    pub poll_interval: Duration,
    /// NOTE(review): presumably the smallest batch worth submitting early — confirm at the usage site.
    pub min_batch_size: usize,
    /// NOTE(review): presumably how long requests may wait before a batch is submitted anyway — confirm.
    pub max_batch_age: Duration,
    /// NOTE(review): presumably the cap on requests per provider batch — confirm.
    pub max_batch_requests: usize,
    /// Upper bound on runnable workflows loaded by a single poll.
    pub max_workflows_per_poll: usize,
}

impl Default for Config {
    /// One-minute polling, five-minute batch age, 100 / 10 000 batch size bounds
    /// and at most 1 000 workflows per poll.
    fn default() -> Self {
        const ONE_MINUTE: Duration = Duration::from_secs(60);
        const FIVE_MINUTES: Duration = Duration::from_secs(300);
        Config {
            poll_interval: ONE_MINUTE,
            min_batch_size: 100,
            max_batch_age: FIVE_MINUTES,
            max_batch_requests: 10_000,
            max_workflows_per_poll: 1_000,
        }
    }
}
/// Counters describing what one or more executor polls accomplished.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct PollSummary {
    /// Runnable workflows picked up for advancement.
    pub workflows_advanced: u64,
    /// Workflows that reached the halted state.
    pub workflows_halted: u64,
    /// Workflows that were marked failed.
    pub workflows_failed: u64,
    /// Human / OpenAI continuations stored as blocked.
    pub continuations_blocked: u64,
    pub fork_joins_resumed: u64,
    pub provider_batches_submitted: u64,
    pub provider_requests_submitted: u64,
    pub provider_batches_completed: u64,
    pub provider_results_completed: u64,
    pub workflows_resumed: u64,
    /// Workflow events appended to the event log.
    pub events_committed: u64,
    /// Whether another poll should run right away.
    pub more_work: bool,
}

impl PollSummary {
    /// Folds `other` into `self`: every counter is added and `more_work` is OR-ed.
    fn absorb(&mut self, other: &Self) {
        // Exhaustive destructuring: a newly added field fails to compile here
        // until it is folded in as well.
        let Self {
            workflows_advanced,
            workflows_halted,
            workflows_failed,
            continuations_blocked,
            fork_joins_resumed,
            provider_batches_submitted,
            provider_requests_submitted,
            provider_batches_completed,
            provider_results_completed,
            workflows_resumed,
            events_committed,
            more_work,
        } = other;
        self.workflows_advanced += *workflows_advanced;
        self.workflows_halted += *workflows_halted;
        self.workflows_failed += *workflows_failed;
        self.continuations_blocked += *continuations_blocked;
        self.fork_joins_resumed += *fork_joins_resumed;
        self.provider_batches_submitted += *provider_batches_submitted;
        self.provider_requests_submitted += *provider_requests_submitted;
        self.provider_batches_completed += *provider_batches_completed;
        self.provider_results_completed += *provider_results_completed;
        self.workflows_resumed += *workflows_resumed;
        self.events_committed += *events_committed;
        self.more_work |= *more_work;
    }
}
/// Lifecycle state of a workflow row.
#[derive(Clone, Copy, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum WorkflowStatus {
    Runnable,
    WaitingAnthropic,
    BlockedHuman,
    BlockedOpenAI,
    WaitingForkJoin,
    Halted,
    Failed,
}

impl WorkflowStatus {
    /// Returns the string stored in the database for this status.
    pub fn as_str(self) -> &'static str {
        match self {
            Self::Runnable => WORKFLOW_RUNNABLE,
            Self::WaitingAnthropic => WORKFLOW_WAITING_ANTHROPIC,
            Self::BlockedHuman => WORKFLOW_BLOCKED_HUMAN,
            Self::BlockedOpenAI => WORKFLOW_BLOCKED_OPENAI,
            Self::WaitingForkJoin => WORKFLOW_WAITING_FORK_JOIN,
            Self::Halted => WORKFLOW_HALTED,
            Self::Failed => WORKFLOW_FAILED,
        }
    }

    /// Decodes a database status string.
    ///
    /// Returns an `unknown-workflow-status` error carrying the offending value
    /// when the string matches no variant.
    fn from_db(value: &str) -> Result<Self> {
        // Decoding is the inverse of `as_str`, so scan the variants for the one
        // whose stored form equals `value`.
        const VARIANTS: [WorkflowStatus; 7] = [
            WorkflowStatus::Runnable,
            WorkflowStatus::WaitingAnthropic,
            WorkflowStatus::BlockedHuman,
            WorkflowStatus::BlockedOpenAI,
            WorkflowStatus::WaitingForkJoin,
            WorkflowStatus::Halted,
            WorkflowStatus::Failed,
        ];
        VARIANTS
            .iter()
            .copied()
            .find(|candidate| candidate.as_str() == value)
            .ok_or_else(|| {
                batch_error("unknown-workflow-status", "unknown workflow status")
                    .with_string_field("status", value)
            })
    }
}
#[derive(Clone, Copy, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum ContinuationKind {
Anthropic,
Human,
OpenAI,
}
impl ContinuationKind {
fn from_db(value: &str) -> Result<Self> {
match value {
CONTINUATION_ANTHROPIC => Ok(Self::Anthropic),
CONTINUATION_HUMAN => Ok(Self::Human),
CONTINUATION_OPENAI => Ok(Self::OpenAI),
other => Err(
batch_error("unknown-continuation-kind", "unknown continuation kind")
.with_string_field("kind", other),
),
}
}
}
#[derive(Clone, Copy, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum ContinuationStatus {
Pending,
Submitted,
Blocked,
Succeeded,
Failed,
Resumed,
}
impl ContinuationStatus {
fn from_db(value: &str) -> Result<Self> {
match value {
CONTINUATION_PENDING => Ok(Self::Pending),
CONTINUATION_SUBMITTED => Ok(Self::Submitted),
CONTINUATION_BLOCKED => Ok(Self::Blocked),
CONTINUATION_SUCCEEDED => Ok(Self::Succeeded),
CONTINUATION_FAILED => Ok(Self::Failed),
CONTINUATION_RESUMED => Ok(Self::Resumed),
other => Err(
batch_error("unknown-continuation-status", "unknown continuation status")
.with_string_field("status", other),
),
}
}
}
/// A workflow row as returned by `Executor::load_workflow`.
///
/// The fields correspond to the columns that query selects from `batch_workflows`.
#[derive(Clone, Debug)]
pub struct WorkflowRecord {
pub run_id: String,
pub status: WorkflowStatus,
// The workflow state itself (stored as JSON in the `workflow` column).
pub workflow: Workflow,
pub parent_run_id: Option<String>,
pub fork_name: Option<String>,
// Error text recorded when the workflow failed; cleared when it is made runnable again.
pub error_sexpr: Option<String>,
// Set when the workflow is written as halted or failed; cleared when it is made runnable.
pub quiescent: bool,
}
/// A continuation row as returned by the `list_*_continuations` queries.
///
/// The fields correspond to the columns selected from `batch_continuations`.
#[derive(Clone, Debug)]
pub struct ContinuationRecord {
pub continuation_id: String,
// `run_id` of the workflow this continuation belongs to.
pub workflow_run_id: String,
pub kind: ContinuationKind,
pub status: ContinuationStatus,
pub provider: Option<String>,
pub output_key: Option<String>,
pub error_sexpr: Option<String>,
pub quiescent: bool,
}
/// One row of the workflow event log (`batch_workflow_events`).
#[derive(Clone, Debug)]
pub struct WorkflowEventRecord {
pub event_id: Uuid,
pub root_run_id: String,
pub run_id: String,
pub parent_run_id: Option<String>,
pub fork_name: Option<String>,
// Ordering key used when listing the events of a single run.
pub event_ordinal: i64,
pub caused_by: CausalRef,
pub event_type: String,
pub event_version: i16,
pub continuation_id: Option<String>,
// Event payload as JSON.
pub event: Value,
// Text form of the timestamp: the loading queries select `created_at::TEXT`.
pub created_at: String,
}
/// Database-backed executor that advances workflows with a `Trampoline` and
/// persists their state and events through a Postgres pool.
pub struct Executor {
trampoline: Trampoline,
pool: PgPool,
config: Config,
// Anthropic clients keyed by provider name (see `register_anthropic`).
anthropic: HashMap<String, Anthropic>,
}
impl Executor {
/// Creates an executor with no Anthropic clients registered.
///
/// Calls `crate::wire_executor_indicio_stderr()` before building the value.
pub fn new(trampoline: Trampoline, pool: PgPool, config: Config) -> Self {
    crate::wire_executor_indicio_stderr();
    let anthropic = HashMap::new();
    Self {
        trampoline,
        pool,
        config,
        anthropic,
    }
}
/// Shorthand for [`Executor::new`] with `Config::default()`.
pub fn with_default_config(trampoline: Trampoline, pool: PgPool) -> Self {
    let config = Config::default();
    Self::new(trampoline, pool, config)
}
/// Registers `client` under the provider name `provider`, replacing any
/// client previously stored under that name.
pub fn register_anthropic(&mut self, provider: impl Into<String>, client: Anthropic) {
    let provider_name: String = provider.into();
    self.anthropic.insert(provider_name, client);
}
/// Builder-style variant of `register_anthropic`: stores `client` under
/// `provider` and hands the executor back.
pub fn with_anthropic(mut self, provider: impl Into<String>, client: Anthropic) -> Self {
    self.anthropic.insert(provider.into(), client);
    self
}
/// Persists a new top-level workflow as runnable and records its initial events.
///
/// The workflow's pending events are drained first, so the JSON that is stored
/// (and compared with any existing row) is the drained workflow. If a row with
/// the same `run_id` already holds identical JSON the call is a no-op; if it
/// holds different JSON a `duplicate-workflow-run` error is returned. Otherwise
/// the row and a `workflow.enqueued` event, combined with the drained events by
/// `chain_pending_after`, are written in one transaction.
///
/// NOTE(review): the existence check runs on the pool before the transaction
/// starts, so two concurrent calls with the same `run_id` can both pass it;
/// presumably a uniqueness constraint on `run_id` rejects the second insert —
/// confirm against the migration.
pub async fn enqueue_workflow(&self, mut workflow: Workflow) -> Result<()> {
let run_id = workflow.run_id().to_string();
// Drain before serializing so the stored JSON does not contain the pending events.
let pending_events = workflow.drain_pending_events();
let workflow_json = workflow_to_value(&workflow)?;
// Idempotency check: re-enqueueing identical state succeeds, different state is an error.
if let Some(row) = sqlx::query("SELECT workflow FROM batch_workflows WHERE run_id = $1")
.bind(&run_id)
.fetch_optional(&self.pool)
.await
.map_err(|err| sqlx_error("batch-enqueue", err))?
{
let existing: Value = row
.try_get("workflow")
.map_err(|err| sqlx_error("batch-enqueue", err))?;
if existing == workflow_json {
return Ok(());
}
return Err(batch_error(
"duplicate-workflow-run",
"workflow run id already exists with different state",
)
.with_string_field("run_id", &run_id));
}
// A brand-new workflow's causal cursor is its own run id.
let causal_cursor = CausalRef::RunId {
run_id: run_id.clone(),
};
let causal_cursor_json = causal_ref_to_value(&causal_cursor)?;
let enqueued = first_party_event(
self.trampoline.observability_config(),
"workflow.enqueued",
None,
json!({
"kind": "top_level",
"run_id": run_id,
"next_action": self.trampoline.next_action(&workflow),
}),
causal_cursor.clone(),
)?;
let events = chain_pending_after(enqueued, pending_events);
let mut tx = self
.pool
.begin()
.await
.map_err(|err| sqlx_error("batch-enqueue", err))?;
// `run_id` is bound twice: a top-level workflow is its own root.
sqlx::query(
"INSERT INTO batch_workflows \
(run_id, root_run_id, workflow, status, causal_cursor, quiescent) \
VALUES ($1, $2, $3, $4, $5, false)",
)
.bind(&run_id)
.bind(&run_id)
.bind(workflow_json)
.bind(WORKFLOW_RUNNABLE)
.bind(causal_cursor_json)
.execute(&mut *tx)
.await
.map_err(|err| sqlx_error("batch-enqueue", err))?;
append_workflow_events(&mut tx, &run_id, &causal_cursor, &events).await?;
tx.commit()
.await
.map_err(|err| sqlx_error("batch-enqueue", err))?;
Ok(())
}
/// Enqueues each workflow in turn via `enqueue_workflow`.
///
/// Stops at the first error; workflows enqueued before it stay enqueued,
/// because every call commits on its own.
pub async fn enqueue_workflows(
    &self,
    workflows: impl IntoIterator<Item = Workflow>,
) -> Result<()> {
    for next in workflows {
        self.enqueue_workflow(next).await?;
    }
    Ok(())
}
/// Looks up the current status of the workflow `run_id`.
///
/// Returns `Ok(None)` when no such workflow exists, and an error when the
/// query fails or the stored status string is not recognised.
pub async fn workflow_status(&self, run_id: &str) -> Result<Option<WorkflowStatus>> {
    let fetched = sqlx::query("SELECT status FROM batch_workflows WHERE run_id = $1")
        .bind(run_id)
        .fetch_optional(&self.pool)
        .await
        .map_err(|err| sqlx_error("batch-load-workflow-status", err))?;
    match fetched {
        None => Ok(None),
        Some(row) => {
            let raw: String = row
                .try_get("status")
                .map_err(|err| sqlx_error("batch-load-workflow-status", err))?;
            let status = WorkflowStatus::from_db(&raw)?;
            Ok(Some(status))
        }
    }
}
pub async fn load_workflow(&self, run_id: &str) -> Result<Option<WorkflowRecord>> {
let row = sqlx::query(
"SELECT run_id, workflow, status, parent_run_id, fork_name, \
error_sexpr, quiescent \
FROM batch_workflows WHERE run_id = $1",
)
.bind(run_id)
.fetch_optional(&self.pool)
.await
.map_err(|err| sqlx_error("batch-load-workflow", err))?;
row.map(workflow_record_from_row).transpose()
}
pub async fn load_workflow_event(&self, event_id: Uuid) -> Result<Option<WorkflowEventRecord>> {
let row = sqlx::query(
"SELECT event_id, root_run_id, run_id, parent_run_id, fork_name, \
event_ordinal, caused_by, event_type, event_version, \
continuation_id, event, created_at::TEXT AS created_at \
FROM batch_workflow_events WHERE event_id = $1",
)
.bind(event_id)
.fetch_optional(&self.pool)
.await
.map_err(|err| sqlx_error("batch-load-workflow-event", err))?;
row.map(workflow_event_record_from_row).transpose()
}
pub async fn load_workflow_events(&self, run_id: &str) -> Result<Vec<WorkflowEventRecord>> {
let rows = sqlx::query(
"SELECT event_id, root_run_id, run_id, parent_run_id, fork_name, \
event_ordinal, caused_by, event_type, event_version, \
continuation_id, event, created_at::TEXT AS created_at \
FROM batch_workflow_events \
WHERE run_id = $1 ORDER BY event_ordinal",
)
.bind(run_id)
.fetch_all(&self.pool)
.await
.map_err(|err| sqlx_error("batch-load-workflow-events", err))?;
rows.into_iter()
.map(workflow_event_record_from_row)
.collect()
}
/// Loads every event belonging to the workflow tree rooted at `root_run_id`,
/// ordered by creation time and then by event id.
///
/// Fails on the first row that cannot be decoded.
pub async fn load_root_workflow_events(
    &self,
    root_run_id: &str,
) -> Result<Vec<WorkflowEventRecord>> {
    // The select list aliases `created_at::TEXT` back to `created_at`. In
    // PostgreSQL a bare name in ORDER BY that matches both an output column and
    // an input column resolves to the OUTPUT column, so an unqualified
    // `ORDER BY created_at` would sort by the text rendering of the timestamp.
    // Qualifying the column with the table name sorts by the real timestamp.
    let rows = sqlx::query(
        "SELECT event_id, root_run_id, run_id, parent_run_id, fork_name, \
         event_ordinal, caused_by, event_type, event_version, \
         continuation_id, event, created_at::TEXT AS created_at \
         FROM batch_workflow_events \
         WHERE root_run_id = $1 \
         ORDER BY batch_workflow_events.created_at, event_id",
    )
    .bind(root_run_id)
    .fetch_all(&self.pool)
    .await
    .map_err(|err| sqlx_error("batch-load-root-workflow-events", err))?;
    rows.into_iter()
        .map(workflow_event_record_from_row)
        .collect()
}
/// Lists every continuation currently in the blocked state, oldest first.
pub async fn list_blocked_continuations(&self) -> Result<Vec<ContinuationRecord>> {
    let blocked = self
        .list_continuations_by_status(CONTINUATION_BLOCKED)
        .await?;
    Ok(blocked)
}
/// Lists every continuation currently in the failed state, oldest first.
pub async fn list_failed_continuations(&self) -> Result<Vec<ContinuationRecord>> {
    let failed = self
        .list_continuations_by_status(CONTINUATION_FAILED)
        .await?;
    Ok(failed)
}
pub async fn resume_human<T: serde::Serialize>(
&self,
continuation_id: &str,
value: T,
) -> Result<()> {
let value_json = serde_json::to_value(&value).map_err(|err| {
batch_error(
"invalid-human-response",
"failed to serialize human response",
)
.with_string_field("continuation_id", continuation_id)
.with_string_field("source", &err.to_string())
})?;
let blocked = self
.load_blocked_continuation(continuation_id, CONTINUATION_HUMAN)
.await?;
let output_key = blocked.output_key.ok_or_else(|| {
batch_error(
"missing-human-output-key",
"blocked human continuation is missing its output key",
)
.with_string_field("continuation_id", continuation_id)
})?;
let workflow = self
.trampoline
.resume_human(blocked.workflow, output_key, value)
.map_err(|err| err.with_string_field("continuation_id", continuation_id))?;
self.finish_manual_resume(
continuation_id,
&blocked.run_id,
workflow,
value_json,
"human.resumed",
blocked.causal_cursor,
)
.await
}
pub async fn resume_open_ai(
&self,
continuation_id: &str,
output_key: impl Into<String>,
value: Value,
) -> Result<()> {
let blocked = self
.load_blocked_continuation(continuation_id, CONTINUATION_OPENAI)
.await?;
let workflow = self
.trampoline
.resume_open_ai(blocked.workflow, output_key, value.clone())
.map_err(|err| err.with_string_field("continuation_id", continuation_id))?;
self.finish_manual_resume(
continuation_id,
&blocked.run_id,
workflow,
value,
"openai.resumed",
blocked.causal_cursor,
)
.await
}
/// Schedules a fresh attempt for a failed Anthropic continuation and returns
/// the id of the newly created continuation.
///
/// The new continuation is pending, reuses the provider, output key and
/// request of the failed one, carries `attempt + 1`, and points `attempt_of`
/// at the first attempt in the chain. In the same transaction the workflow is
/// put back into the waiting-for-Anthropic state and an `anthropic.retried`
/// event is appended. The failed continuation row itself is not modified here.
///
/// Returns `retry-continuation-not-found` when `continuation_id` does not name
/// a failed Anthropic continuation.
pub async fn retry_continuation(&self, continuation_id: &str) -> Result<String> {
// Read on the pool, before the transaction begins.
let row = sqlx::query(
"SELECT workflow_run_id, provider, output_key, request, attempt_of, attempt \
FROM batch_continuations \
WHERE continuation_id = $1 AND kind = $2 AND status = $3",
)
.bind(continuation_id)
.bind(CONTINUATION_ANTHROPIC)
.bind(CONTINUATION_FAILED)
.fetch_optional(&self.pool)
.await
.map_err(|err| sqlx_error("batch-retry-continuation", err))?
.ok_or_else(|| {
batch_error(
"retry-continuation-not-found",
"failed Anthropic continuation was not found",
)
.with_string_field("continuation_id", continuation_id)
})?;
let workflow_run_id: String = row
.try_get("workflow_run_id")
.map_err(|err| sqlx_error("batch-retry-continuation", err))?;
let provider: String = row
.try_get("provider")
.map_err(|err| sqlx_error("batch-retry-continuation", err))?;
let output_key: String = row
.try_get("output_key")
.map_err(|err| sqlx_error("batch-retry-continuation", err))?;
let request: Value = row
.try_get("request")
.map_err(|err| sqlx_error("batch-retry-continuation", err))?;
let attempt_of: Option<String> = row
.try_get("attempt_of")
.map_err(|err| sqlx_error("batch-retry-continuation", err))?;
let attempt: i32 = row
.try_get("attempt")
.map_err(|err| sqlx_error("batch-retry-continuation", err))?;
// A first attempt has no `attempt_of`; it becomes the root of the retry chain.
let root_attempt = attempt_of.unwrap_or_else(|| continuation_id.to_string());
let mut tx = self
.pool
.begin()
.await
.map_err(|err| sqlx_error("batch-retry-continuation", err))?;
let workflow_cursor = load_workflow_causal_cursor(&mut tx, &workflow_run_id).await?;
// Built before the new continuation exists; its id is filled in below.
let mut retried = first_party_event(
self.trampoline.observability_config(),
"anthropic.retried",
None,
json!({
"previous_continuation_id": continuation_id,
"attempt": attempt + 1,
"provider": provider.clone(),
}),
workflow_cursor.clone(),
)?;
// The new continuation's causal cursor is the retry event itself.
let retry_cursor = CausalRef::EventId {
event_id: retried.event_id,
};
let new_id = insert_continuation(
&mut tx,
NewContinuation {
kind: CONTINUATION_ANTHROPIC,
status: CONTINUATION_PENDING,
workflow_run_id: &workflow_run_id,
provider: Some(&provider),
output_key: Some(&output_key),
request: Some(request),
attempt_of: Some(&root_attempt),
attempt: attempt + 1,
causal_cursor: retry_cursor,
},
)
.await?;
// Back-fill the event with the id that was only known after the insert.
retried.continuation_id = Some(new_id.clone());
if let Value::Object(payload) = &mut retried.event {
payload.insert(
"new_continuation_id".to_string(),
Value::String(new_id.clone()),
);
}
// Clear the recorded failure and put the workflow back into the waiting state.
sqlx::query(
"UPDATE batch_workflows \
SET status = $2, error_sexpr = NULL, quiescent = false, \
quiesced_at = NULL, updated_at = now() \
WHERE run_id = $1",
)
.bind(&workflow_run_id)
.bind(WORKFLOW_WAITING_ANTHROPIC)
.execute(&mut *tx)
.await
.map_err(|err| sqlx_error("batch-retry-continuation", err))?;
append_workflow_events(&mut tx, &workflow_run_id, &workflow_cursor, &[retried]).await?;
tx.commit()
.await
.map_err(|err| sqlx_error("batch-retry-continuation", err))?;
Ok(new_id)
}
/// Runs one pass of the executor and reports what it did.
///
/// The phases run in a fixed order: `advance_runnable_workflows`,
/// `resume_ready_fork_joins`, `process_provider_batches`,
/// `submit_ready_anthropic_batches`. An error in any phase aborts the pass.
/// `more_work` in the returned summary is whatever `has_immediate_work`
/// reports after the last phase. A transition is logged at the start and at
/// the end of a successful pass.
pub async fn poll(&self) -> Result<PollSummary> {
log_batch_executor_transition("poll.started", indicio::value!({}));
let mut summary = PollSummary::default();
self.advance_runnable_workflows(&mut summary).await?;
self.resume_ready_fork_joins(&mut summary).await?;
self.process_provider_batches(&mut summary).await?;
self.submit_ready_anthropic_batches(&mut summary).await?;
summary.more_work = self.has_immediate_work().await?;
log_batch_executor_transition("poll.completed", poll_summary_to_value(&summary));
Ok(summary)
}
/// Polls repeatedly until there is neither immediate nor runtime work left,
/// returning the totals accumulated over all polls.
///
/// After each poll:
/// * if the poll reported `more_work`, poll again straight away;
/// * otherwise, if `has_runtime_work` is true, sleep for
///   `config.poll_interval` and poll again;
/// * otherwise finish, with `more_work` forced to `false` in the total.
///
/// Any poll error ends the run with that error.
pub async fn run(&self) -> Result<PollSummary> {
    log_batch_executor_transition("run.started", indicio::value!({}));
    let mut total = PollSummary::default();
    loop {
        let summary = self.poll().await?;
        total.absorb(&summary);
        if summary.more_work {
            log_batch_executor_transition(
                "run.continued",
                indicio::value!({ summary: poll_summary_to_value(&summary) }),
            );
        } else if self.has_runtime_work().await? {
            // Saturate rather than fail if the interval does not fit in 64 bits.
            let poll_interval_ms =
                u64::try_from(self.config.poll_interval.as_millis()).unwrap_or(u64::MAX);
            log_batch_executor_transition(
                "run.sleeping",
                indicio::value!({
                    poll_interval_ms: poll_interval_ms,
                    summary: poll_summary_to_value(&summary),
                }),
            );
            tokio::time::sleep(self.config.poll_interval).await;
        } else {
            total.more_work = false;
            log_batch_executor_transition(
                "run.completed",
                indicio::value!({ summary: poll_summary_to_value(&total) }),
            );
            return Ok(total);
        }
    }
}
async fn list_continuations_by_status(&self, status: &str) -> Result<Vec<ContinuationRecord>> {
let rows = sqlx::query(
"SELECT continuation_id, workflow_run_id, kind, status, provider, \
output_key, error_sexpr, quiescent \
FROM batch_continuations \
WHERE status = $1 ORDER BY created_at, id",
)
.bind(status)
.fetch_all(&self.pool)
.await
.map_err(|err| sqlx_error("batch-list-continuations", err))?;
rows.into_iter().map(continuation_record_from_row).collect()
}
/// Loads a blocked continuation of the given `kind` together with the
/// current state of its workflow.
///
/// Returns `blocked-continuation-not-found` when no continuation matches the
/// id, the kind and the blocked status all at once.
async fn load_blocked_continuation(
    &self,
    continuation_id: &str,
    kind: &str,
) -> Result<BlockedContinuation> {
    let fetched = sqlx::query(
        "SELECT c.workflow_run_id, c.output_key, c.causal_cursor, w.workflow \
         FROM batch_continuations c \
         JOIN batch_workflows w ON w.run_id = c.workflow_run_id \
         WHERE c.continuation_id = $1 AND c.kind = $2 AND c.status = $3",
    )
    .bind(continuation_id)
    .bind(kind)
    .bind(CONTINUATION_BLOCKED)
    .fetch_optional(&self.pool)
    .await
    .map_err(|err| sqlx_error("batch-load-blocked-continuation", err))?;
    let Some(row) = fetched else {
        return Err(batch_error(
            "blocked-continuation-not-found",
            "blocked continuation was not found",
        )
        .with_string_field("continuation_id", continuation_id)
        .with_string_field("kind", kind));
    };
    // Columns are read in the same order as before so the first decoding
    // failure reported stays the same.
    let run_id: String = row
        .try_get("workflow_run_id")
        .map_err(|err| sqlx_error("batch-load-blocked-continuation", err))?;
    let output_key = row
        .try_get("output_key")
        .map_err(|err| sqlx_error("batch-load-blocked-continuation", err))?;
    let stored_workflow: Value = row
        .try_get("workflow")
        .map_err(|err| sqlx_error("batch-load-blocked-continuation", err))?;
    let workflow = workflow_from_value(stored_workflow)?;
    let stored_cursor = row
        .try_get("causal_cursor")
        .map_err(|err| sqlx_error("batch-load-blocked-continuation", err))?;
    let causal_cursor = causal_ref_from_value(stored_cursor)?;
    Ok(BlockedContinuation {
        run_id,
        output_key,
        causal_cursor,
        workflow,
    })
}
async fn finish_manual_resume(
&self,
continuation_id: &str,
run_id: &str,
workflow: Workflow,
response: Value,
event_type: &str,
causal_cursor: CausalRef,
) -> Result<()> {
let workflow_json = workflow_to_value(&workflow)?;
let mut tx = self
.pool
.begin()
.await
.map_err(|err| sqlx_error("batch-finish-manual-resume", err))?;
let resumed = first_party_event(
self.trampoline.observability_config(),
event_type,
Some(continuation_id.to_string()),
json!({
"continuation_id": continuation_id,
"response_ref": {
"kind": "batch_continuation_response",
"continuation_id": continuation_id
}
}),
causal_cursor.clone(),
)?;
let final_cursor = CausalRef::EventId {
event_id: resumed.event_id,
};
sqlx::query(
"UPDATE batch_workflows \
SET workflow = $2, status = $3, error_sexpr = NULL, \
quiescent = false, quiesced_at = NULL, updated_at = now() \
WHERE run_id = $1",
)
.bind(run_id)
.bind(workflow_json)
.bind(WORKFLOW_RUNNABLE)
.execute(&mut *tx)
.await
.map_err(|err| sqlx_error("batch-finish-manual-resume", err))?;
sqlx::query(
"UPDATE batch_continuations \
SET status = $2, response = $3, completed_at = now(), \
quiescent = true, quiesced_at = now(), causal_cursor = $4, \
updated_at = now() \
WHERE continuation_id = $1 AND status = $5",
)
.bind(continuation_id)
.bind(CONTINUATION_RESUMED)
.bind(response)
.bind(causal_ref_to_value(&final_cursor)?)
.bind(CONTINUATION_BLOCKED)
.execute(&mut *tx)
.await
.map_err(|err| sqlx_error("batch-finish-manual-resume", err))?;
append_workflow_events(&mut tx, run_id, &causal_cursor, &[resumed]).await?;
tx.commit()
.await
.map_err(|err| sqlx_error("batch-finish-manual-resume", err))?;
Ok(())
}
/// Advances every runnable, non-quiescent workflow as far as it can go in this poll.
///
/// At most `config.max_workflows_per_poll` workflows are loaded, least
/// recently updated first. Each one is driven in a loop:
/// * while the next action is a local call, the call is bracketed by
///   `local_call.started` / `local_call.completed` events and the workflow is
///   persisted after every call; a failing call marks the workflow failed and
///   ends its loop;
/// * any other next action hands the workflow to `Trampoline::run`, whose
///   result is persisted (or whose error marks the workflow failed), ending
///   the loop.
///
/// A database or serialization error aborts the whole pass via `?`.
async fn advance_runnable_workflows(&self, summary: &mut PollSummary) -> Result<()> {
let rows = sqlx::query(
"SELECT run_id, workflow, causal_cursor FROM batch_workflows \
WHERE status = $1 AND quiescent = false \
ORDER BY updated_at, id LIMIT $2",
)
.bind(WORKFLOW_RUNNABLE)
.bind(limit_i64(self.config.max_workflows_per_poll))
.fetch_all(&self.pool)
.await
.map_err(|err| sqlx_error("batch-load-runnable-workflows", err))?;
for row in rows {
let run_id: String = row
.try_get("run_id")
.map_err(|err| sqlx_error("batch-load-runnable-workflows", err))?;
let workflow_json: Value = row
.try_get("workflow")
.map_err(|err| sqlx_error("batch-load-runnable-workflows", err))?;
let mut workflow = workflow_from_value(workflow_json)?;
let mut causal_cursor = causal_ref_from_value(
row.try_get("causal_cursor")
.map_err(|err| sqlx_error("batch-load-runnable-workflows", err))?,
)?;
// Counted as soon as it is picked up, whatever the outcome.
summary.workflows_advanced += 1;
loop {
match self.trampoline.next_action(&workflow) {
WorkflowNext::LocalCall { function } => {
// Record the start first so the call's own events can hang off it.
let started_event_id = self
.commit_local_call_started(&run_id, &function, &causal_cursor, summary)
.await?;
let started_cursor = CausalRef::EventId {
event_id: started_event_id,
};
workflow.set_observability_context(ObservabilityContext {
causal_cursor: started_cursor.clone(),
});
match self.trampoline.run_one_local_call(workflow).await {
Ok(outcome) => {
// The completed event becomes the cursor for the next iteration.
causal_cursor = self
.commit_local_call_completed(
&run_id,
&started_cursor,
&outcome,
summary,
)
.await?;
workflow = outcome.workflow;
}
Err(err) => {
self.commit_local_call_failed(
&run_id,
&started_cursor,
err,
summary,
)
.await?;
summary.workflows_failed += 1;
break;
}
}
}
// Every other next action is handled by running the trampoline once.
_ => {
match self.trampoline.run(workflow).await {
Ok(outcome) => {
self.persist_workflow_result(&run_id, outcome.result, summary)
.await?
}
Err(err) => {
self.mark_workflow_failed(&run_id, err.source.to_string(), summary)
.await?;
summary.workflows_failed += 1;
}
}
break;
}
}
}
}
Ok(())
}
/// Appends a `local_call.started` event for `function` and returns its id.
///
/// The event is caused by `expected_cursor`, which is also passed to
/// `append_workflow_events` as the expected workflow cursor. The number of
/// appended events is added to `summary.events_committed`.
async fn commit_local_call_started(
    &self,
    run_id: &str,
    function: &str,
    expected_cursor: &CausalRef,
    summary: &mut PollSummary,
) -> Result<Uuid> {
    let started = first_party_event(
        self.trampoline.observability_config(),
        "local_call.started",
        None,
        json!({ "function": function }),
        expected_cursor.clone(),
    )?;
    // Capture the id before the event is moved into the slice below.
    let started_id = started.event_id;
    let mut tx = self
        .pool
        .begin()
        .await
        .map_err(|err| sqlx_error("batch-local-call-started", err))?;
    let appended = append_workflow_events(&mut tx, run_id, expected_cursor, &[started]).await?;
    summary.events_committed += appended;
    tx.commit()
        .await
        .map_err(|err| sqlx_error("batch-local-call-started", err))?;
    Ok(started_id)
}
/// Persists the result of a successful local call and returns the new causal cursor.
///
/// The events produced by the call are followed by a `local_call.completed`
/// event, which is caused by the last produced event (or by `started_cursor`
/// if the call produced none). In one transaction the workflow row is
/// overwritten with the post-call state and kept runnable, and all events are
/// appended with `started_cursor` as the expected cursor. The returned cursor
/// points at the completed event.
async fn commit_local_call_completed(
&self,
run_id: &str,
started_cursor: &CausalRef,
outcome: &WorkflowStepOutcome,
summary: &mut PollSummary,
) -> Result<CausalRef> {
let mut events = outcome.events.clone();
// Chain the completion onto the call's last event when there is one.
let caused_by = events
.last()
.map(|event| CausalRef::EventId {
event_id: event.event_id,
})
.unwrap_or_else(|| started_cursor.clone());
let completed = first_party_event(
self.trampoline.observability_config(),
"local_call.completed",
None,
json!({
"function": outcome.function.clone(),
"duration_ms": u64::try_from(outcome.duration_ms).unwrap_or(u64::MAX),
"env_changes": outcome.env_changes.clone(),
"flow": outcome.flow.clone(),
}),
caused_by,
)?;
let final_cursor = CausalRef::EventId {
event_id: completed.event_id,
};
events.push(completed);
let workflow_json = workflow_to_value(&outcome.workflow)?;
let mut tx = self
.pool
.begin()
.await
.map_err(|err| sqlx_error("batch-local-call-completed", err))?;
sqlx::query(
"UPDATE batch_workflows \
SET workflow = $2, status = $3, error_sexpr = NULL, \
quiescent = false, quiesced_at = NULL, updated_at = now() \
WHERE run_id = $1",
)
.bind(run_id)
.bind(workflow_json)
.bind(WORKFLOW_RUNNABLE)
.execute(&mut *tx)
.await
.map_err(|err| sqlx_error("batch-local-call-completed", err))?;
summary.events_committed +=
append_workflow_events(&mut tx, run_id, started_cursor, &events).await?;
tx.commit()
.await
.map_err(|err| sqlx_error("batch-local-call-completed", err))?;
Ok(final_cursor)
}
/// Persists the failure of a local call and marks the workflow failed.
///
/// The events carried by `error` are followed by a `local_call.failed` event
/// (caused by the last carried event, or by `started_cursor` if there is none)
/// and then a `workflow.failed` event caused by that failure event. In one
/// transaction the workflow row is set to failed and quiescent with the error
/// text recorded, and all events are appended with `started_cursor` as the
/// expected cursor. The stored workflow state is left untouched.
async fn commit_local_call_failed(
&self,
run_id: &str,
started_cursor: &CausalRef,
error: WorkflowError,
summary: &mut PollSummary,
) -> Result<()> {
let function = error.function.clone();
let error_sexpr = error.source.to_string();
let mut events = error.events;
let caused_by = events
.last()
.map(|event| CausalRef::EventId {
event_id: event.event_id,
})
.unwrap_or_else(|| started_cursor.clone());
let failed = first_party_event(
self.trampoline.observability_config(),
"local_call.failed",
None,
json!({
"function": function,
"duration_ms": error.duration_ms.map(|value| u64::try_from(value).unwrap_or(u64::MAX)),
"env_changes": error.env_changes,
"flow": error.flow,
"error": {
"source": error_sexpr,
}
}),
caused_by,
)?;
let workflow_failed = first_party_event(
self.trampoline.observability_config(),
"workflow.failed",
None,
json!({
"reason": "local_call_failed",
"error_ref": {
"kind": "batch_workflow_error"
}
}),
CausalRef::EventId {
event_id: failed.event_id,
},
)?;
events.push(failed);
events.push(workflow_failed);
let mut tx = self
.pool
.begin()
.await
.map_err(|err| sqlx_error("batch-local-call-failed", err))?;
sqlx::query(
"UPDATE batch_workflows \
SET status = $2, error_sexpr = $3, quiescent = true, \
quiesced_at = now(), updated_at = now() \
WHERE run_id = $1",
)
.bind(run_id)
.bind(WORKFLOW_FAILED)
.bind(error_sexpr)
.execute(&mut *tx)
.await
.map_err(|err| sqlx_error("batch-local-call-failed", err))?;
summary.events_committed +=
append_workflow_events(&mut tx, run_id, started_cursor, &events).await?;
tx.commit()
.await
.map_err(|err| sqlx_error("batch-local-call-failed", err))?;
Ok(())
}
/// Stores the state a workflow yielded with after `Trampoline::run`.
///
/// Tool-call results are never persisted as such: while the result is a tool
/// call, the tools are run inline and the trampoline is run again. If either
/// step fails the workflow is marked failed and this returns `Ok(())`.
///
/// Once the result is anything else it is dispatched:
/// * `Halt` — the workflow is written as halted;
/// * `Anthropic` — a continuation is stored, with streaming disabled on the request;
/// * `Human` / `OpenAI` — a blocked continuation is stored;
/// * `ForkJoin` — handed to `store_fork_join`.
async fn persist_workflow_result(
&self,
run_id: &str,
result: WorkflowResult,
summary: &mut PollSummary,
) -> Result<()> {
let mut result = result;
// Drain tool calls inline until the workflow yields something else.
while let WorkflowResult::ToolCall {
workflow,
tool_uses,
output_key,
} = result
{
match self
.run_tool_call_inline(workflow, tool_uses, output_key, summary)
.await
{
Ok(workflow) => match self.trampoline.run(workflow).await {
Ok(next) => result = next.result,
Err(err) => {
self.mark_workflow_failed(run_id, err.source.to_string(), summary)
.await?;
summary.workflows_failed += 1;
return Ok(());
}
},
Err(err) => {
self.mark_workflow_failed(run_id, err.to_string(), summary)
.await?;
summary.workflows_failed += 1;
return Ok(());
}
}
}
match result {
WorkflowResult::Halt { workflow } => {
self.update_workflow_terminal(run_id, workflow, WORKFLOW_HALTED, None, summary)
.await?;
summary.workflows_halted += 1;
}
WorkflowResult::Anthropic {
workflow,
provider,
message,
output_key,
} => {
// The stored request always has streaming turned off.
let request = (*message).with_stream(false);
self.store_anthropic_continuation(
run_id, workflow, provider, output_key, request, summary,
)
.await?;
}
WorkflowResult::Human {
workflow,
request,
output_key,
} => {
self.store_blocked_continuation(
run_id,
workflow,
CONTINUATION_HUMAN,
WORKFLOW_BLOCKED_HUMAN,
Some(output_key),
Some(serde_json::to_value(request).map_err(|err| {
json_error(
"invalid-human-request",
"failed to serialize human request",
err,
)
})?),
summary,
)
.await?;
summary.continuations_blocked += 1;
}
WorkflowResult::OpenAI { workflow } => {
// OpenAI continuations are stored without an output key or request.
self.store_blocked_continuation(
run_id,
workflow,
CONTINUATION_OPENAI,
WORKFLOW_BLOCKED_OPENAI,
None,
None,
summary,
)
.await?;
summary.continuations_blocked += 1;
}
WorkflowResult::ToolCall { .. } => {
// The `while let` above only exits once the result is not a tool call.
unreachable!("tool-call results are dispatched inline before persistence");
}
WorkflowResult::ForkJoin { workflow, lhs, rhs } => {
self.store_fork_join(run_id, workflow, *lhs, *rhs, summary)
.await?;
}
}
Ok(())
}
/// Runs a workflow's requested tool calls immediately and returns the resumed workflow.
///
/// The run id is taken from `workflow` itself. Steps:
/// 1. a `tool_call.started` event is committed in its own transaction;
/// 2. the tools are dispatched through the trampoline, outside any transaction;
/// 3. the workflow is resumed with the results;
/// 4. the resumed workflow is stored as runnable together with a
///    `tool_call.completed` event in a second transaction.
///
/// If dispatch or resume fails, a `tool_call.failed` event is committed and a
/// `tool-call-dispatch-failed` / `tool-call-resume-failed` error is returned;
/// the workflow row is not updated on those paths.
async fn run_tool_call_inline(
&self,
workflow: Workflow,
tool_uses: Vec<claudius::ToolUseBlock>,
output_key: String,
summary: &mut PollSummary,
) -> Result<Workflow> {
let run_id = workflow.run_id().to_string();
let tool_names: Vec<_> = tool_uses
.iter()
.map(|tool_use| tool_use.name.clone())
.collect();
// Event-log ids are the tool-use ids prefixed with the run id.
let tool_call_ids: Vec<_> = tool_uses
.iter()
.map(|tool_use| format!("{}:{}", run_id, tool_use.id))
.collect();
let mut tx = self
.pool
.begin()
.await
.map_err(|err| sqlx_error("batch-tool-call-started", err))?;
let workflow_cursor = load_workflow_causal_cursor(&mut tx, &run_id).await?;
let started = first_party_event(
self.trampoline.observability_config(),
"tool_call.started",
None,
json!({
"tool_names": tool_names,
"tool_call_ids": tool_call_ids,
}),
workflow_cursor.clone(),
)?;
let started_cursor = CausalRef::EventId {
event_id: started.event_id,
};
summary.events_committed +=
append_workflow_events(&mut tx, &run_id, &workflow_cursor, &[started]).await?;
// Commit the start marker before running the tools.
tx.commit()
.await
.map_err(|err| sqlx_error("batch-tool-call-started", err))?;
let results = match self.trampoline.run_tool_calls(&run_id, &tool_uses).await {
Ok(results) => results,
Err(err) => {
self.commit_tool_call_failed(
&run_id,
&started_cursor,
"dispatch_failed",
&err.to_string(),
summary,
)
.await?;
return Err(batch_error(
"tool-call-dispatch-failed",
"failed to dispatch tool calls",
)
.with_string_field("run_id", &run_id)
.with_string_field("source", &err.to_string()));
}
};
// Counted before `results` is moved into the resume call; a missing flag counts as success.
let result_error_count = results
.iter()
.filter(|result| result.is_error.unwrap_or(false))
.count();
let workflow = match self
.trampoline
.resume_tool_call(workflow, output_key, results)
{
Ok(workflow) => workflow,
Err(err) => {
self.commit_tool_call_failed(
&run_id,
&started_cursor,
"resume_failed",
&err.to_string(),
summary,
)
.await?;
return Err(batch_error(
"tool-call-resume-failed",
"failed to resume after tool calls",
)
.with_string_field("run_id", &run_id)
.with_string_field("source", &err.to_string()));
}
};
let workflow_json = workflow_to_value(&workflow)?;
let completed = first_party_event(
self.trampoline.observability_config(),
"tool_call.completed",
None,
json!({
"tool_count": tool_uses.len(),
"result_error_count": result_error_count,
}),
started_cursor.clone(),
)?;
let mut tx = self
.pool
.begin()
.await
.map_err(|err| sqlx_error("batch-tool-call-completed", err))?;
sqlx::query(
"UPDATE batch_workflows \
SET workflow = $2, status = $3, error_sexpr = NULL, \
quiescent = false, quiesced_at = NULL, updated_at = now() \
WHERE run_id = $1",
)
.bind(&run_id)
.bind(workflow_json)
.bind(WORKFLOW_RUNNABLE)
.execute(&mut *tx)
.await
.map_err(|err| sqlx_error("batch-tool-call-completed", err))?;
summary.events_committed +=
append_workflow_events(&mut tx, &run_id, &started_cursor, &[completed]).await?;
tx.commit()
.await
.map_err(|err| sqlx_error("batch-tool-call-completed", err))?;
Ok(workflow)
}
/// Appends a `tool_call.failed` event carrying `reason` and `source`.
///
/// The event is caused by `started_cursor`, which is also the expected
/// workflow cursor for the append. Only the event log is written; the
/// workflow row is left as it is.
async fn commit_tool_call_failed(
    &self,
    run_id: &str,
    started_cursor: &CausalRef,
    reason: &str,
    source: &str,
    summary: &mut PollSummary,
) -> Result<()> {
    let payload = json!({
        "reason": reason,
        "source": source,
    });
    let failure = first_party_event(
        self.trampoline.observability_config(),
        "tool_call.failed",
        None,
        payload,
        started_cursor.clone(),
    )?;
    let mut tx = self
        .pool
        .begin()
        .await
        .map_err(|err| sqlx_error("batch-tool-call-failed", err))?;
    let appended = append_workflow_events(&mut tx, run_id, started_cursor, &[failure]).await?;
    summary.events_committed += appended;
    tx.commit()
        .await
        .map_err(|err| sqlx_error("batch-tool-call-failed", err))?;
    Ok(())
}
async fn update_workflow_terminal(
&self,
run_id: &str,
workflow: Workflow,
status: &str,
error_sexpr: Option<String>,
summary: &mut PollSummary,
) -> Result<()> {
let workflow_json = workflow_to_value(&workflow)?;
let mut tx = self
.pool
.begin()
.await
.map_err(|err| sqlx_error("batch-update-workflow-terminal", err))?;
let workflow_cursor = load_workflow_causal_cursor(&mut tx, run_id).await?;
let event = match status {
WORKFLOW_HALTED => Some(first_party_event(
self.trampoline.observability_config(),
"workflow.halted",
None,
json!({
"env_key_count": workflow.env.len(),
}),
workflow_cursor.clone(),
)?),
_ => None,
};
sqlx::query(
"UPDATE batch_workflows \
SET workflow = $2, status = $3, error_sexpr = $4, \
quiescent = true, quiesced_at = now(), updated_at = now() \
WHERE run_id = $1",
)
.bind(run_id)
.bind(workflow_json)
.bind(status)
.bind(error_sexpr)
.execute(&mut *tx)
.await
.map_err(|err| sqlx_error("batch-update-workflow-terminal", err))?;
if let Some(event) = event {
summary.events_committed +=
append_workflow_events(&mut tx, run_id, &workflow_cursor, &[event]).await?;
}
tx.commit()
.await
.map_err(|err| sqlx_error("batch-update-workflow-terminal", err))?;
Ok(())
}
/// Marks workflow `run_id` as failed and records a `workflow.failed` event.
///
/// In one transaction the `batch_workflows` row is set to `WORKFLOW_FAILED`
/// with `error_sexpr` and flagged quiescent, and a `workflow.failed` event
/// with reason `workflow_execution_failed` is appended using the cursor
/// loaded for the run. The stored workflow value itself is not rewritten.
///
/// # Errors
///
/// Fails if the cursor cannot be loaded, the event cannot be built, or any
/// database step fails.
async fn mark_workflow_failed(
    &self,
    run_id: &str,
    error_sexpr: String,
    summary: &mut PollSummary,
) -> Result<()> {
    const OP: &str = "batch-mark-workflow-failed";
    const FAIL_WORKFLOW: &str = "UPDATE batch_workflows \
        SET status = $2, error_sexpr = $3, quiescent = true, \
        quiesced_at = now(), updated_at = now() \
        WHERE run_id = $1";

    let mut txn = self
        .pool
        .begin()
        .await
        .map_err(|err| sqlx_error(OP, err))?;
    let cursor = load_workflow_causal_cursor(&mut txn, run_id).await?;

    let details = json!({
        "reason": "workflow_execution_failed",
        "error_ref": {
            "kind": "batch_workflow_error"
        }
    });
    let failure_event = first_party_event(
        self.trampoline.observability_config(),
        "workflow.failed",
        None,
        details,
        cursor.clone(),
    )?;

    sqlx::query(FAIL_WORKFLOW)
        .bind(run_id)
        .bind(WORKFLOW_FAILED)
        .bind(error_sexpr)
        .execute(&mut *txn)
        .await
        .map_err(|err| sqlx_error(OP, err))?;

    let appended = append_workflow_events(&mut txn, run_id, &cursor, &[failure_event]).await?;
    summary.events_committed += appended;

    txn.commit().await.map_err(|err| sqlx_error(OP, err))?;
    Ok(())
}
/// Suspends workflow `run_id` on an Anthropic request and stores the
/// continuation that will later be submitted.
///
/// In one transaction this:
/// - updates the workflow row with the serialized `workflow` and status
///   `WORKFLOW_WAITING_ANTHROPIC`;
/// - inserts a `CONTINUATION_ANTHROPIC` continuation in status
///   `CONTINUATION_PENDING` (attempt 1) holding the serialized `request`,
///   with the `anthropic.suspended` event as its causal cursor;
/// - appends the `anthropic.suspended` event, after stamping the new
///   continuation id onto the event, its payload, and its `request_ref`.
///
/// Returns the id of the inserted continuation.
///
/// # Errors
///
/// Fails if the workflow or request cannot be serialized, the event cannot
/// be built, or any database step fails.
async fn store_anthropic_continuation(
    &self,
    run_id: &str,
    workflow: Workflow,
    provider: String,
    output_key: String,
    request: MessageCreateParams,
    summary: &mut PollSummary,
) -> Result<String> {
    const OP: &str = "batch-store-anthropic-continuation";
    const SUSPEND_WORKFLOW: &str = "UPDATE batch_workflows \
        SET workflow = $2, status = $3, error_sexpr = NULL, \
        quiescent = false, quiesced_at = NULL, updated_at = now() \
        WHERE run_id = $1";

    let serialized_workflow = workflow_to_value(&workflow)?;
    let serialized_request = serde_json::to_value(request).map_err(|err| {
        json_error(
            "invalid-anthropic-request",
            "failed to serialize Anthropic request",
            err,
        )
    })?;

    let mut txn = self
        .pool
        .begin()
        .await
        .map_err(|err| sqlx_error(OP, err))?;
    let cursor = load_workflow_causal_cursor(&mut txn, run_id).await?;

    let details = json!({
        "provider": provider.clone(),
        "output_key": output_key.clone(),
        "request_ref": {
            "kind": "batch_continuation_request"
        }
    });
    let mut suspended_event = first_party_event(
        self.trampoline.observability_config(),
        "anthropic.suspended",
        None,
        details,
        cursor.clone(),
    )?;
    // The continuation is causally anchored on the suspension event.
    let anchor = CausalRef::EventId {
        event_id: suspended_event.event_id,
    };

    sqlx::query(SUSPEND_WORKFLOW)
        .bind(run_id)
        .bind(serialized_workflow)
        .bind(WORKFLOW_WAITING_ANTHROPIC)
        .execute(&mut *txn)
        .await
        .map_err(|err| sqlx_error(OP, err))?;

    let continuation = NewContinuation {
        kind: CONTINUATION_ANTHROPIC,
        status: CONTINUATION_PENDING,
        workflow_run_id: run_id,
        provider: Some(&provider),
        output_key: Some(&output_key),
        request: Some(serialized_request),
        attempt_of: None,
        attempt: 1,
        causal_cursor: anchor,
    };
    let id = insert_continuation(&mut txn, continuation).await?;

    // The id is only known after the insert, so it is patched into the
    // event before the event is appended.
    suspended_event.continuation_id = Some(id.clone());
    if let Some(fields) = suspended_event.event.as_object_mut() {
        fields.insert("continuation_id".to_string(), Value::String(id.clone()));
        if let Some(request_ref) = fields
            .get_mut("request_ref")
            .and_then(Value::as_object_mut)
        {
            request_ref.insert("continuation_id".to_string(), Value::String(id.clone()));
        }
    }

    let appended = append_workflow_events(&mut txn, run_id, &cursor, &[suspended_event]).await?;
    summary.events_committed += appended;

    txn.commit().await.map_err(|err| sqlx_error(OP, err))?;

    log_batch_executor_transition(
        "anthropic.continuation_stored",
        indicio::value!({
            run_id: run_id,
            continuation_id: &id,
            provider: &provider,
            output_key: &output_key,
        }),
    );
    Ok(id)
}
async fn store_blocked_continuation(
&self,
run_id: &str,
workflow: Workflow,
continuation_kind: &str,
workflow_status: &str,
output_key: Option<String>,
request: Option<Value>,
summary: &mut PollSummary,
) -> Result<String> {
let workflow_json = workflow_to_value(&workflow)?;
let mut tx = self
.pool
.begin()
.await
.map_err(|err| sqlx_error("batch-store-blocked-continuation", err))?;
let workflow_cursor = load_workflow_causal_cursor(&mut tx, run_id).await?;
let event_type = match continuation_kind {
CONTINUATION_HUMAN => "human.blocked",
CONTINUATION_OPENAI => "openai.blocked",
_ => "continuation.blocked",
};
let mut blocked = first_party_event(
self.trampoline.observability_config(),
event_type,
None,
json!({
"kind": continuation_kind,
"output_key": output_key.clone(),
}),
workflow_cursor.clone(),
)?;
let continuation_cursor = CausalRef::EventId {
event_id: blocked.event_id,
};
sqlx::query(
"UPDATE batch_workflows \
SET workflow = $2, status = $3, error_sexpr = NULL, \
quiescent = false, quiesced_at = NULL, updated_at = now() \
WHERE run_id = $1",
)
.bind(run_id)
.bind(workflow_json)
.bind(workflow_status)
.execute(&mut *tx)
.await
.map_err(|err| sqlx_error("batch-store-blocked-continuation", err))?;
let id = insert_continuation(
&mut tx,
NewContinuation {
kind: continuation_kind,
status: CONTINUATION_BLOCKED,
workflow_run_id: run_id,
provider: None,
output_key: output_key.as_deref(),
request,
attempt_of: None,
attempt: 1,
causal_cursor: continuation_cursor,
},
)
.await?;
blocked.continuation_id = Some(id.clone());
if let Value::Object(payload) = &mut blocked.event {
payload.insert("continuation_id".to_string(), Value::String(id.clone()));
}
summary.events_committed +=
append_workflow_events(&mut tx, run_id, &workflow_cursor, &[blocked]).await?;
tx.commit()
.await
.map_err(|err| sqlx_error("batch-store-blocked-continuation", err))?;
Ok(id)
}
/// Suspends workflow `run_id` on a fork/join and enqueues its two branches.
///
/// All writes happen in a single transaction, in this order:
/// 1. the parent's `root_run_id` and `causal_cursor` are read from
///    `batch_workflows`;
/// 2. the parent row is updated with the serialized `workflow` and status
///    `WORKFLOW_WAITING_FORK_JOIN`;
/// 3. a `fork_join.started` event naming both branch run ids is appended to
///    the parent;
/// 4. a workflow row is inserted for `lhs` and for `rhs`, each passed the
///    parent's root run id, the parent run id, its branch name, and a cursor
///    pointing at the `fork_join.started` event;
/// 5. a `workflow.enqueued` event is appended for each branch;
/// 6. a `batch_fork_joins` row with status `FORK_JOIN_WAITING` is inserted.
///
/// # Errors
///
/// Fails if any workflow cannot be serialized, the parent row cannot be
/// read or decoded, an event cannot be built, or any database step fails.
async fn store_fork_join(
&self,
run_id: &str,
workflow: Workflow,
lhs: Workflow,
rhs: Workflow,
summary: &mut PollSummary,
) -> Result<()> {
let workflow_json = workflow_to_value(&workflow)?;
let lhs_run_id = lhs.run_id().to_string();
let rhs_run_id = rhs.run_id().to_string();
let lhs_json = workflow_to_value(&lhs)?;
let rhs_json = workflow_to_value(&rhs)?;
// Payload fragment for the `fork_join.started` event: branch name -> run id.
let mut branch_run_id = serde_json::Map::new();
branch_run_id.insert("lhs".to_string(), Value::String(lhs_run_id.clone()));
branch_run_id.insert("rhs".to_string(), Value::String(rhs_run_id.clone()));
let mut tx = self
.pool
.begin()
.await
.map_err(|err| sqlx_error("batch-store-fork-join", err))?;
let parent_row =
sqlx::query("SELECT root_run_id, causal_cursor FROM batch_workflows WHERE run_id = $1")
.bind(run_id)
.fetch_one(&mut *tx)
.await
.map_err(|err| sqlx_error("batch-store-fork-join", err))?;
let root_run_id: String = parent_row
.try_get("root_run_id")
.map_err(|err| sqlx_error("batch-store-fork-join", err))?;
let parent_cursor = causal_ref_from_value(
parent_row
.try_get("causal_cursor")
.map_err(|err| sqlx_error("batch-store-fork-join", err))?,
)?;
let fork_started = first_party_event(
self.trampoline.observability_config(),
"fork_join.started",
None,
json!({
"branch_run_id": Value::Object(branch_run_id),
}),
parent_cursor.clone(),
)?;
// Saved before `fork_started` is handed to `append_workflow_events`; the
// branches use this event as their causal starting point.
let fork_event_id = fork_started.event_id;
sqlx::query(
"UPDATE batch_workflows \
SET workflow = $2, status = $3, error_sexpr = NULL, \
quiescent = false, quiesced_at = NULL, updated_at = now() \
WHERE run_id = $1",
)
.bind(run_id)
.bind(workflow_json)
.bind(WORKFLOW_WAITING_FORK_JOIN)
.execute(&mut *tx)
.await
.map_err(|err| sqlx_error("batch-store-fork-join", err))?;
summary.events_committed +=
append_workflow_events(&mut tx, run_id, &parent_cursor, &[fork_started]).await?;
let branch_cursor = CausalRef::EventId {
event_id: fork_event_id,
};
// Both branch rows share the parent's root run id and the fork event cursor.
insert_workflow_row(
&mut tx,
&lhs_run_id,
&root_run_id,
lhs_json,
Some(run_id),
Some("lhs"),
&branch_cursor,
)
.await?;
insert_workflow_row(
&mut tx,
&rhs_run_id,
&root_run_id,
rhs_json,
Some(run_id),
Some("rhs"),
&branch_cursor,
)
.await?;
let lhs_enqueued = first_party_event(
self.trampoline.observability_config(),
"workflow.enqueued",
None,
json!({
"kind": "fork_branch",
"parent_run_id": run_id,
"fork_name": "lhs",
}),
branch_cursor.clone(),
)?;
summary.events_committed +=
append_workflow_events(&mut tx, &lhs_run_id, &branch_cursor, &[lhs_enqueued]).await?;
let rhs_enqueued = first_party_event(
self.trampoline.observability_config(),
"workflow.enqueued",
None,
json!({
"kind": "fork_branch",
"parent_run_id": run_id,
"fork_name": "rhs",
}),
branch_cursor.clone(),
)?;
summary.events_committed +=
append_workflow_events(&mut tx, &rhs_run_id, &branch_cursor, &[rhs_enqueued]).await?;
// Bookkeeping row that `resume_ready_fork_joins` later selects on.
sqlx::query(
"INSERT INTO batch_fork_joins \
(parent_run_id, lhs_run_id, rhs_run_id, status, quiescent) \
VALUES ($1, $2, $3, $4, false)",
)
.bind(run_id)
.bind(&lhs_run_id)
.bind(&rhs_run_id)
.bind(FORK_JOIN_WAITING)
.execute(&mut *tx)
.await
.map_err(|err| sqlx_error("batch-store-fork-join", err))?;
tx.commit()
.await
.map_err(|err| sqlx_error("batch-store-fork-join", err))?;
Ok(())
}
/// Resumes parents whose two fork/join branches have both halted.
///
/// Selects fork/join rows in status `FORK_JOIN_WAITING` that are not
/// quiescent and whose `lhs` and `rhs` workflows both have status
/// `WORKFLOW_HALTED`, limited by `self.config.max_workflows_per_poll`.
/// For each row, `self.trampoline.resume_fork_join(parent, lhs, rhs)` is
/// called:
///
/// - on success, one transaction stores the resumed parent as
///   `WORKFLOW_RUNNABLE`, marks the fork/join rows of that parent as
///   `FORK_JOIN_JOINED` and quiescent, and appends a `fork_join.completed`
///   event; `summary.fork_joins_resumed` is incremented;
/// - on failure, `mark_fork_join_failed` is called with the error's string
///   form and `summary.workflows_failed` is incremented.
///
/// # Errors
///
/// Any database, decoding, or serialization error aborts the loop and is
/// returned; rows already handled stay committed.
async fn resume_ready_fork_joins(&self, summary: &mut PollSummary) -> Result<()> {
// NOTE(review): this SELECT runs on the pool, outside the per-row
// transaction below, so the parent cursor used for the completed event is
// the one read here — confirm that is intended.
let rows = sqlx::query(
"SELECT fj.parent_run_id, p.workflow AS parent_workflow, \
p.causal_cursor, lhs.workflow AS lhs_workflow, \
rhs.workflow AS rhs_workflow, fj.lhs_run_id, fj.rhs_run_id \
FROM batch_fork_joins fj \
JOIN batch_workflows p ON p.run_id = fj.parent_run_id \
JOIN batch_workflows lhs ON lhs.run_id = fj.lhs_run_id \
JOIN batch_workflows rhs ON rhs.run_id = fj.rhs_run_id \
WHERE fj.status = $1 AND fj.quiescent = false \
AND lhs.status = $2 AND rhs.status = $2 \
ORDER BY fj.updated_at, fj.id \
LIMIT $3",
)
.bind(FORK_JOIN_WAITING)
.bind(WORKFLOW_HALTED)
.bind(limit_i64(self.config.max_workflows_per_poll))
.fetch_all(&self.pool)
.await
.map_err(|err| sqlx_error("batch-load-ready-fork-joins", err))?;
for row in rows {
// Decode the parent, its cursor, and both branches from the joined row.
let parent_run_id: String = row
.try_get("parent_run_id")
.map_err(|err| sqlx_error("batch-load-ready-fork-joins", err))?;
let parent = workflow_from_value(
row.try_get("parent_workflow")
.map_err(|err| sqlx_error("batch-load-ready-fork-joins", err))?,
)?;
let parent_cursor = causal_ref_from_value(
row.try_get("causal_cursor")
.map_err(|err| sqlx_error("batch-load-ready-fork-joins", err))?,
)?;
let lhs_run_id: String = row
.try_get("lhs_run_id")
.map_err(|err| sqlx_error("batch-load-ready-fork-joins", err))?;
let rhs_run_id: String = row
.try_get("rhs_run_id")
.map_err(|err| sqlx_error("batch-load-ready-fork-joins", err))?;
let lhs = workflow_from_value(
row.try_get("lhs_workflow")
.map_err(|err| sqlx_error("batch-load-ready-fork-joins", err))?,
)?;
let rhs = workflow_from_value(
row.try_get("rhs_workflow")
.map_err(|err| sqlx_error("batch-load-ready-fork-joins", err))?,
)?;
match self.trampoline.resume_fork_join(parent, lhs, rhs) {
Ok(workflow) => {
let workflow_json = workflow_to_value(&workflow)?;
let completed = first_party_event(
self.trampoline.observability_config(),
"fork_join.completed",
None,
json!({
"branch_run_id": {
"lhs": lhs_run_id,
"rhs": rhs_run_id
}
}),
parent_cursor.clone(),
)?;
// Parent update, fork/join update, and event append commit together.
let mut tx = self
.pool
.begin()
.await
.map_err(|err| sqlx_error("batch-resume-fork-join", err))?;
sqlx::query(
"UPDATE batch_workflows \
SET workflow = $2, status = $3, error_sexpr = NULL, \
quiescent = false, quiesced_at = NULL, updated_at = now() \
WHERE run_id = $1",
)
.bind(&parent_run_id)
.bind(workflow_json)
.bind(WORKFLOW_RUNNABLE)
.execute(&mut *tx)
.await
.map_err(|err| sqlx_error("batch-resume-fork-join", err))?;
sqlx::query(
"UPDATE batch_fork_joins \
SET status = $2, quiescent = true, quiesced_at = now(), \
updated_at = now() \
WHERE parent_run_id = $1",
)
.bind(&parent_run_id)
.bind(FORK_JOIN_JOINED)
.execute(&mut *tx)
.await
.map_err(|err| sqlx_error("batch-resume-fork-join", err))?;
summary.events_committed += append_workflow_events(
&mut tx,
&parent_run_id,
&parent_cursor,
&[completed],
)
.await?;
tx.commit()
.await
.map_err(|err| sqlx_error("batch-resume-fork-join", err))?;
summary.fork_joins_resumed += 1;
}
Err(err) => {
// A join failure is recorded against the parent; the loop then
// continues with the next ready fork/join.
let error_sexpr = err.to_string();
self.mark_fork_join_failed(&parent_run_id, error_sexpr, summary)
.await?;
summary.workflows_failed += 1;
}
}
}
Ok(())
}
/// Records a failed fork/join for `parent_run_id`.
///
/// In one transaction the parent workflow row is set to `WORKFLOW_FAILED`,
/// the `batch_fork_joins` rows for that parent are set to
/// `FORK_JOIN_FAILED`, both are flagged quiescent with `error_sexpr`, and
/// two events are appended: `fork_join.failed`, followed by
/// `workflow.failed` (reason `fork_join_failed`), which is causally chained
/// to the first.
///
/// # Errors
///
/// Fails if the cursor cannot be loaded, an event cannot be built, or any
/// database step fails.
async fn mark_fork_join_failed(
    &self,
    parent_run_id: &str,
    error_sexpr: String,
    summary: &mut PollSummary,
) -> Result<()> {
    const OP: &str = "batch-mark-fork-join-failed";
    const FAIL_WORKFLOW: &str = "UPDATE batch_workflows \
        SET status = $2, error_sexpr = $3, quiescent = true, \
        quiesced_at = now(), updated_at = now() \
        WHERE run_id = $1";
    const FAIL_FORK_JOIN: &str = "UPDATE batch_fork_joins \
        SET status = $2, error_sexpr = $3, quiescent = true, \
        quiesced_at = now(), updated_at = now() \
        WHERE parent_run_id = $1";

    let mut txn = self
        .pool
        .begin()
        .await
        .map_err(|err| sqlx_error(OP, err))?;
    let cursor = load_workflow_causal_cursor(&mut txn, parent_run_id).await?;

    let join_failed = first_party_event(
        self.trampoline.observability_config(),
        "fork_join.failed",
        None,
        json!({
            "error_ref": {
                "kind": "batch_workflow_error"
            }
        }),
        cursor.clone(),
    )?;
    // The workflow failure is chained to the fork/join failure.
    let after_join_failed = CausalRef::EventId {
        event_id: join_failed.event_id,
    };
    let workflow_failed = first_party_event(
        self.trampoline.observability_config(),
        "workflow.failed",
        None,
        json!({
            "reason": "fork_join_failed",
            "error_ref": {
                "kind": "batch_workflow_error"
            }
        }),
        after_join_failed,
    )?;

    sqlx::query(FAIL_WORKFLOW)
        .bind(parent_run_id)
        .bind(WORKFLOW_FAILED)
        .bind(&error_sexpr)
        .execute(&mut *txn)
        .await
        .map_err(|err| sqlx_error(OP, err))?;
    sqlx::query(FAIL_FORK_JOIN)
        .bind(parent_run_id)
        .bind(FORK_JOIN_FAILED)
        .bind(error_sexpr)
        .execute(&mut *txn)
        .await
        .map_err(|err| sqlx_error(OP, err))?;

    let appended = append_workflow_events(
        &mut txn,
        parent_run_id,
        &cursor,
        &[join_failed, workflow_failed],
    )
    .await?;
    summary.events_committed += appended;

    txn.commit().await.map_err(|err| sqlx_error(OP, err))?;
    Ok(())
}
/// Polls every open provider batch and applies results of those that ended.
///
/// Loads all non-quiescent `batch_provider_batches` rows whose status is
/// `PROVIDER_BATCH_IN_PROGRESS` or `PROVIDER_BATCH_CANCELING`. For each one:
///
/// 1. looks up the Anthropic client registered for the row's provider;
/// 2. retrieves the batch from the provider;
/// 3. if the batch has not ended, stores the refreshed status and moves on;
/// 4. otherwise streams the batch results, passing each one to
///    `process_provider_result`, then calls `finish_provider_batch`.
///
/// `summary.provider_results_completed` counts results for which
/// `process_provider_result` returned `true`, and
/// `summary.provider_batches_completed` counts batches that were finished.
///
/// # Errors
///
/// A missing provider client, a failed retrieve or stream call, a failed
/// stream item, or any error from result processing or the database stops
/// the whole pass and is returned; remaining batches are not visited.
async fn process_provider_batches(&self, summary: &mut PollSummary) -> Result<()> {
let rows = sqlx::query(
"SELECT provider, provider_batch_id FROM batch_provider_batches \
WHERE status IN ($1, $2) AND quiescent = false \
ORDER BY submitted_at, id",
)
.bind(PROVIDER_BATCH_IN_PROGRESS)
.bind(PROVIDER_BATCH_CANCELING)
.fetch_all(&self.pool)
.await
.map_err(|err| sqlx_error("batch-load-provider-batches", err))?;
for row in rows {
let provider: String = row
.try_get("provider")
.map_err(|err| sqlx_error("batch-load-provider-batches", err))?;
let provider_batch_id: String = row
.try_get("provider_batch_id")
.map_err(|err| sqlx_error("batch-load-provider-batches", err))?;
let client = self.anthropic.get(&provider).ok_or_else(|| {
missing_anthropic_provider_error(&provider)
.with_string_field("provider_batch_id", &provider_batch_id)
})?;
log_batch_executor_transition(
"anthropic.batch.retrieve_started",
indicio::value!({
provider: &provider,
provider_batch_id: &provider_batch_id,
}),
);
let batch = match client.retrieve_message_batch(&provider_batch_id).await {
Ok(batch) => batch,
Err(err) => {
// Log the failure before propagating it.
let error = anthropic_error(err);
log_batch_executor_transition(
"anthropic.batch.retrieve_failed",
indicio::value!({
provider: &provider,
provider_batch_id: &provider_batch_id,
source: error.to_string(),
}),
);
return Err(error);
}
};
log_batch_executor_transition(
"anthropic.batch.retrieve_completed",
indicio::value!({
provider: &provider,
provider_batch_id: &provider_batch_id,
status: provider_batch_status(batch.processing_status),
}),
);
// Batches that have not ended only get their stored status refreshed.
if batch.processing_status != MessageBatchProcessingStatus::Ended {
self.update_provider_batch_status(&batch).await?;
log_batch_executor_transition(
"anthropic.batch.status_updated",
indicio::value!({
provider: &provider,
provider_batch_id: &provider_batch_id,
status: provider_batch_status(batch.processing_status),
}),
);
continue;
}
log_batch_executor_transition(
"anthropic.batch.results_stream_started",
indicio::value!({
provider: &provider,
provider_batch_id: &provider_batch_id,
}),
);
let stream = match client
.stream_message_batch_results(&provider_batch_id)
.await
{
Ok(stream) => stream,
Err(err) => {
let error = anthropic_error(err);
log_batch_executor_transition(
"anthropic.batch.results_stream_failed",
indicio::value!({
provider: &provider,
provider_batch_id: &provider_batch_id,
source: error.to_string(),
}),
);
return Err(error);
}
};
futures::pin_mut!(stream);
// `seen_results` counts every streamed item; `processed_results` counts
// only those `process_provider_result` reported as handled.
let mut seen_results = 0_u64;
let mut processed_results = 0_u64;
while let Some(result) = stream.next().await {
let result = match result {
Ok(result) => result,
Err(err) => {
let error = anthropic_error(err);
log_batch_executor_transition(
"anthropic.batch.result_stream_item_failed",
indicio::value!({
provider: &provider,
provider_batch_id: &provider_batch_id,
source: error.to_string(),
}),
);
return Err(error);
}
};
seen_results += 1;
if self.process_provider_result(result, summary).await? {
processed_results += 1;
summary.provider_results_completed += 1;
}
}
// Reached only when the stream was consumed without an error.
self.finish_provider_batch(&batch).await?;
log_batch_executor_transition(
"anthropic.batch.results_stream_completed",
indicio::value!({
provider: &provider,
provider_batch_id: &provider_batch_id,
seen_results: seen_results,
processed_results: processed_results,
}),
);
summary.provider_batches_completed += 1;
}
Ok(())
}
/// Stores the latest status and serialized response of a provider batch.
///
/// The `batch_provider_batches` row matching `batch.id` gets the status
/// derived by `provider_batch_status(batch.processing_status)` and the
/// batch serialized as JSON.
///
/// # Errors
///
/// Fails if the batch cannot be serialized or the update fails.
async fn update_provider_batch_status(&self, batch: &claudius::MessageBatch) -> Result<()> {
    const REFRESH_BATCH: &str = "UPDATE batch_provider_batches \
        SET status = $2, response = $3, updated_at = now() \
        WHERE provider_batch_id = $1";

    let snapshot = serde_json::to_value(batch).map_err(|err| {
        json_error(
            "invalid-provider-batch-response",
            "failed to serialize provider batch response",
            err,
        )
    })?;

    sqlx::query(REFRESH_BATCH)
        .bind(&batch.id)
        .bind(provider_batch_status(batch.processing_status))
        .bind(snapshot)
        .execute(&self.pool)
        .await
        .map_err(|err| sqlx_error("batch-update-provider-batch", err))?;
    Ok(())
}
/// Marks a provider batch as ended and quiescent.
///
/// The `batch_provider_batches` row matching `batch.id` is set to
/// `PROVIDER_BATCH_ENDED`, receives the batch serialized as JSON, and has
/// its completion and quiescence timestamps set to `now()`.
///
/// # Errors
///
/// Fails if the batch cannot be serialized or the update fails.
async fn finish_provider_batch(&self, batch: &claudius::MessageBatch) -> Result<()> {
    const CLOSE_BATCH: &str = "UPDATE batch_provider_batches \
        SET status = $2, response = $3, completed_at = now(), \
        quiescent = true, quiesced_at = now(), updated_at = now() \
        WHERE provider_batch_id = $1";

    let snapshot = serde_json::to_value(batch).map_err(|err| {
        json_error(
            "invalid-provider-batch-response",
            "failed to serialize provider batch response",
            err,
        )
    })?;

    sqlx::query(CLOSE_BATCH)
        .bind(&batch.id)
        .bind(PROVIDER_BATCH_ENDED)
        .bind(snapshot)
        .execute(&self.pool)
        .await
        .map_err(|err| sqlx_error("batch-finish-provider-batch", err))?;
    Ok(())
}
/// Applies one streamed batch result to the continuation it belongs to.
///
/// The continuation is looked up by `result.custom_id` among continuations
/// of kind `CONTINUATION_ANTHROPIC`, joined with its workflow row.
///
/// Returns `Ok(false)` without changing anything when no such continuation
/// exists or when its status is not `CONTINUATION_SUBMITTED`. Otherwise the
/// result variant decides what happens, and `Ok(true)` is returned:
///
/// - `Succeeded` calls `finish_anthropic_success` and increments
///   `summary.workflows_resumed`;
/// - `Errored`, `Canceled`, and `Expired` each call
///   `fail_anthropic_continuation` with a matching result type and error
///   text, and increment `summary.workflows_failed`.
///
/// # Errors
///
/// Fails if the row cannot be loaded or decoded, the result cannot be
/// serialized, or the success/failure handler returns an error.
async fn process_provider_result(
&self,
result: MessageBatchResult,
summary: &mut PollSummary,
) -> Result<bool> {
let result_type = provider_result_type(&result.result);
let row = sqlx::query(
"SELECT c.status, c.workflow_run_id, c.output_key, c.causal_cursor, \
c.provider_batch_id, c.attempt, w.workflow \
FROM batch_continuations c \
JOIN batch_workflows w ON w.run_id = c.workflow_run_id \
WHERE c.continuation_id = $1 AND c.kind = $2",
)
.bind(&result.custom_id)
.bind(CONTINUATION_ANTHROPIC)
.fetch_optional(&self.pool)
.await
.map_err(|err| sqlx_error("batch-load-provider-result-continuation", err))?;
// A result whose custom id matches no Anthropic continuation is logged and skipped.
let Some(row) = row else {
log_batch_executor_transition(
"anthropic.batch.result_ignored",
indicio::value!({
continuation_id: &result.custom_id,
result_type: result_type,
reason: "unknown_continuation",
}),
);
return Ok(false);
};
let status: String = row
.try_get("status")
.map_err(|err| sqlx_error("batch-load-provider-result-continuation", err))?;
// Only continuations still in the submitted state are acted on; any other
// status is logged and skipped.
if status != CONTINUATION_SUBMITTED {
log_batch_executor_transition(
"anthropic.batch.result_ignored",
indicio::value!({
continuation_id: &result.custom_id,
result_type: result_type,
reason: "unexpected_status",
status: &status,
}),
);
return Ok(false);
}
let workflow_run_id: String = row
.try_get("workflow_run_id")
.map_err(|err| sqlx_error("batch-load-provider-result-continuation", err))?;
let output_key: String = row
.try_get("output_key")
.map_err(|err| sqlx_error("batch-load-provider-result-continuation", err))?;
let causal_cursor = causal_ref_from_value(
row.try_get("causal_cursor")
.map_err(|err| sqlx_error("batch-load-provider-result-continuation", err))?,
)?;
let provider_batch_id: Option<String> = row
.try_get("provider_batch_id")
.map_err(|err| sqlx_error("batch-load-provider-result-continuation", err))?;
let attempt: i32 = row
.try_get("attempt")
.map_err(|err| sqlx_error("batch-load-provider-result-continuation", err))?;
let workflow = workflow_from_value(
row.try_get("workflow")
.map_err(|err| sqlx_error("batch-load-provider-result-continuation", err))?,
)?;
// Serialized before the match below moves `result.result` out of `result`.
let result_json = serde_json::to_value(&result).map_err(|err| {
json_error(
"invalid-provider-result",
"failed to serialize provider batch result",
err,
)
})?;
log_batch_executor_transition(
"anthropic.batch.result_received",
indicio::value!({
continuation_id: &result.custom_id,
run_id: &workflow_run_id,
provider_batch_id: crate::optional_indicio_string(provider_batch_id.as_deref()),
attempt: attempt,
result_type: result_type,
}),
);
// `provider_batch_id` is moved into whichever handler runs, so keep a copy
// for the final log entry.
let provider_batch_id_for_log = provider_batch_id.clone();
match result.result {
MessageBatchResultVariant::Succeeded { message } => {
self.finish_anthropic_success(
&result.custom_id,
&workflow_run_id,
workflow,
output_key,
message,
result_json,
causal_cursor,
provider_batch_id,
attempt,
summary,
)
.await?;
summary.workflows_resumed += 1;
}
MessageBatchResultVariant::Errored { error } => {
let error_sexpr =
batch_error("anthropic-batch-item-error", "Anthropic batch item failed")
.with_string_field("type", &error.error.r#type)
.with_string_field("message", &error.error.message)
.to_string();
self.fail_anthropic_continuation(
&result.custom_id,
&workflow_run_id,
"errored",
result_json,
error_sexpr,
causal_cursor,
provider_batch_id,
attempt,
summary,
)
.await?;
summary.workflows_failed += 1;
}
MessageBatchResultVariant::Canceled => {
let error_sexpr = batch_error(
"anthropic-batch-item-canceled",
"Anthropic batch item was canceled",
)
.to_string();
self.fail_anthropic_continuation(
&result.custom_id,
&workflow_run_id,
"canceled",
result_json,
error_sexpr,
causal_cursor,
provider_batch_id,
attempt,
summary,
)
.await?;
summary.workflows_failed += 1;
}
MessageBatchResultVariant::Expired => {
let error_sexpr = batch_error(
"anthropic-batch-item-expired",
"Anthropic batch item expired",
)
.to_string();
self.fail_anthropic_continuation(
&result.custom_id,
&workflow_run_id,
"expired",
result_json,
error_sexpr,
causal_cursor,
provider_batch_id,
attempt,
summary,
)
.await?;
summary.workflows_failed += 1;
}
}
log_batch_executor_transition(
"anthropic.batch.result_processed",
indicio::value!({
continuation_id: &result.custom_id,
run_id: &workflow_run_id,
provider_batch_id: crate::optional_indicio_string(provider_batch_id_for_log.as_deref()),
attempt: attempt,
result_type: result_type,
}),
);
Ok(true)
}
async fn finish_anthropic_success(
&self,
continuation_id: &str,
workflow_run_id: &str,
workflow: Workflow,
output_key: String,
message: Message,
result_json: Value,
causal_cursor: CausalRef,
provider_batch_id: Option<String>,
attempt: i32,
summary: &mut PollSummary,
) -> Result<()> {
let provider_message_id = message.id.clone();
let usage = message.usage;
log_batch_executor_transition(
"anthropic.resume_started",
indicio::value!({
continuation_id: continuation_id,
run_id: workflow_run_id,
provider_batch_id: crate::optional_indicio_string(provider_batch_id.as_deref()),
provider_message_id: &provider_message_id,
attempt: attempt,
output_key: &output_key,
}),
);
match self
.trampoline
.resume_anthropic(workflow, output_key.clone(), message.clone())
{
Ok(workflow) => {
let workflow_json = workflow_to_value(&workflow)?;
let usage_json = serde_json::to_value(usage).map_err(|err| {
json_error(
"invalid-anthropic-usage",
"failed to serialize Anthropic usage",
err,
)
})?;
let server_tool_use = usage
.server_tool_use
.map(serde_json::to_value)
.transpose()
.map_err(|err| {
json_error(
"invalid-anthropic-server-tool-use",
"failed to serialize Anthropic server tool usage",
err,
)
})?;
let mut tx = self
.pool
.begin()
.await
.map_err(|err| sqlx_error("batch-finish-anthropic-success", err))?;
let completed = first_party_event(
self.trampoline.observability_config(),
"anthropic.completed",
Some(continuation_id.to_string()),
json!({
"continuation_id": continuation_id,
"provider_batch_id": provider_batch_id.clone(),
"provider_message_id": provider_message_id.clone(),
"attempt": attempt,
"usage": usage_json.clone(),
"response_ref": {
"kind": "batch_continuation_response",
"continuation_id": continuation_id
}
}),
causal_cursor.clone(),
)?;
let resumed = first_party_event(
self.trampoline.observability_config(),
"anthropic.resumed",
Some(continuation_id.to_string()),
json!({
"continuation_id": continuation_id,
"output_key": output_key.clone(),
}),
CausalRef::EventId {
event_id: completed.event_id,
},
)?;
let final_cursor = CausalRef::EventId {
event_id: resumed.event_id,
};
sqlx::query(
"UPDATE batch_workflows \
SET workflow = $2, status = $3, error_sexpr = NULL, \
quiescent = false, quiesced_at = NULL, updated_at = now() \
WHERE run_id = $1",
)
.bind(workflow_run_id)
.bind(workflow_json)
.bind(WORKFLOW_RUNNABLE)
.execute(&mut *tx)
.await
.map_err(|err| sqlx_error("batch-finish-anthropic-success", err))?;
sqlx::query(
"UPDATE batch_continuations \
SET status = $2, provider_message_id = $3, response = $4, \
usage = $5, input_tokens = $6, output_tokens = $7, \
cache_creation_input_tokens = $8, \
cache_read_input_tokens = $9, server_tool_use = $10, \
completed_at = now(), quiescent = true, \
quiesced_at = now(), causal_cursor = $11, updated_at = now() \
WHERE continuation_id = $1 AND status = $12",
)
.bind(continuation_id)
.bind(CONTINUATION_SUCCEEDED)
.bind(provider_message_id.clone())
.bind(result_json)
.bind(usage_json.clone())
.bind(i64::from(usage.input_tokens))
.bind(i64::from(usage.output_tokens))
.bind(usage.cache_creation_input_tokens.map(i64::from))
.bind(usage.cache_read_input_tokens.map(i64::from))
.bind(server_tool_use)
.bind(causal_ref_to_value(&final_cursor)?)
.bind(CONTINUATION_SUBMITTED)
.execute(&mut *tx)
.await
.map_err(|err| sqlx_error("batch-finish-anthropic-success", err))?;
summary.events_committed += append_workflow_events(
&mut tx,
workflow_run_id,
&causal_cursor,
&[completed, resumed],
)
.await?;
tx.commit()
.await
.map_err(|err| sqlx_error("batch-finish-anthropic-success", err))?;
log_batch_executor_transition(
"anthropic.resume_completed",
indicio::value!({
continuation_id: continuation_id,
run_id: workflow_run_id,
provider_batch_id: crate::optional_indicio_string(provider_batch_id.as_deref()),
provider_message_id: &provider_message_id,
attempt: attempt,
output_key: &output_key,
}),
);
Ok(())
}
Err(err) => {
log_batch_executor_transition(
"anthropic.resume_failed",
indicio::value!({
continuation_id: continuation_id,
run_id: workflow_run_id,
provider_batch_id: crate::optional_indicio_string(provider_batch_id.as_deref()),
provider_message_id: &provider_message_id,
attempt: attempt,
output_key: &output_key,
source: err.to_string(),
}),
);
self.fail_anthropic_continuation(
continuation_id,
workflow_run_id,
"resume_error",
result_json,
err.to_string(),
causal_cursor,
provider_batch_id,
attempt,
summary,
)
.await
}
}
}
/// Fails an Anthropic continuation and the workflow waiting on it.
///
/// In one transaction the continuation is set to `CONTINUATION_FAILED`
/// (only if it is still `CONTINUATION_SUBMITTED`) with `result_type`,
/// `response`, and `error_sexpr`; the workflow row is set to
/// `WORKFLOW_FAILED` and flagged quiescent; and two events are appended:
/// `anthropic.failed`, followed by `workflow.failed` (reason
/// `anthropic_failed`), which is causally chained to the first. The
/// continuation's causal cursor is moved to the `workflow.failed` event.
///
/// # Errors
///
/// Fails if an event cannot be built, the cursor cannot be serialized, or
/// any database step fails.
async fn fail_anthropic_continuation(
    &self,
    continuation_id: &str,
    workflow_run_id: &str,
    result_type: &str,
    response: Value,
    error_sexpr: String,
    causal_cursor: CausalRef,
    provider_batch_id: Option<String>,
    attempt: i32,
    summary: &mut PollSummary,
) -> Result<()> {
    let mut tx = self
        .pool
        .begin()
        .await
        .map_err(|err| sqlx_error("batch-fail-anthropic-continuation", err))?;
    // `json!` serializes interpolated values by reference, so
    // `provider_batch_id` remains available for the log entry at the end
    // without keeping a separate clone.
    let failed = first_party_event(
        self.trampoline.observability_config(),
        "anthropic.failed",
        Some(continuation_id.to_string()),
        json!({
            "continuation_id": continuation_id,
            "provider_batch_id": provider_batch_id,
            "attempt": attempt,
            "result_type": result_type,
            "error_ref": {
                "kind": "batch_continuation_error",
                "continuation_id": continuation_id
            }
        }),
        causal_cursor.clone(),
    )?;
    // The workflow failure is causally chained to the continuation failure.
    let workflow_failed = first_party_event(
        self.trampoline.observability_config(),
        "workflow.failed",
        None,
        json!({
            "reason": "anthropic_failed",
            "error_ref": {
                "kind": "batch_workflow_error"
            }
        }),
        CausalRef::EventId {
            event_id: failed.event_id,
        },
    )?;
    let final_cursor = CausalRef::EventId {
        event_id: workflow_failed.event_id,
    };
    // The `status = $7` guard restricts the update to a continuation that is
    // still in the submitted state.
    sqlx::query(
        "UPDATE batch_continuations \
         SET status = $2, result_type = $3, response = $4, error_sexpr = $5, \
         completed_at = now(), quiescent = true, quiesced_at = now(), \
         causal_cursor = $6, updated_at = now() \
         WHERE continuation_id = $1 AND status = $7",
    )
    .bind(continuation_id)
    .bind(CONTINUATION_FAILED)
    .bind(result_type)
    .bind(response)
    .bind(&error_sexpr)
    .bind(causal_ref_to_value(&final_cursor)?)
    .bind(CONTINUATION_SUBMITTED)
    .execute(&mut *tx)
    .await
    .map_err(|err| sqlx_error("batch-fail-anthropic-continuation", err))?;
    sqlx::query(
        "UPDATE batch_workflows \
         SET status = $2, error_sexpr = $3, quiescent = true, \
         quiesced_at = now(), updated_at = now() \
         WHERE run_id = $1",
    )
    .bind(workflow_run_id)
    .bind(WORKFLOW_FAILED)
    .bind(error_sexpr)
    .execute(&mut *tx)
    .await
    .map_err(|err| sqlx_error("batch-fail-anthropic-continuation", err))?;
    summary.events_committed += append_workflow_events(
        &mut tx,
        workflow_run_id,
        &causal_cursor,
        &[failed, workflow_failed],
    )
    .await?;
    tx.commit()
        .await
        .map_err(|err| sqlx_error("batch-fail-anthropic-continuation", err))?;
    log_batch_executor_transition(
        "anthropic.continuation_failed",
        indicio::value!({
            continuation_id: continuation_id,
            run_id: workflow_run_id,
            provider_batch_id: crate::optional_indicio_string(provider_batch_id.as_deref()),
            attempt: attempt,
            result_type: result_type,
        }),
    );
    Ok(())
}
/// Submits one Anthropic message batch for every provider that has pending
/// continuations and passes the flush check.
///
/// For each provider returned by `pending_anthropic_providers` this:
/// 1. looks up the registered client (an error if none is registered),
/// 2. skips the provider when `should_flush_provider` returns `false` or no
///    pending requests load,
/// 3. decodes every saved request into `MessageCreateParams` and creates a
///    provider batch whose entries are keyed by continuation id,
/// 4. in one transaction, upserts the `batch_provider_batches` row, moves each
///    continuation from pending to submitted, and appends an
///    `anthropic.submitted` workflow event per continuation.
///
/// `summary.events_committed` grows as events are appended; the batch and
/// request counters are bumped only after the transaction commits.
///
/// # Errors
///
/// Returns on the first failure (missing provider client, request decode,
/// provider API call, event construction, or Postgres error). Providers later
/// in the list are not attempted in that call.
async fn submit_ready_anthropic_batches(&self, summary: &mut PollSummary) -> Result<()> {
let providers = self.pending_anthropic_providers().await?;
for provider in providers {
// The client lookup runs before the flush check, so a provider with pending
// work but no registered client is an error even when it would not flush.
let client = self
.anthropic
.get(&provider)
.ok_or_else(|| missing_anthropic_provider_error(&provider))?;
if !self.should_flush_provider(&provider).await? {
continue;
}
let pending = self.pending_anthropic_requests(&provider).await?;
if pending.is_empty() {
continue;
}
// A single undecodable saved request aborts the whole submission for this call.
let requests: Vec<_> = pending
.iter()
.map(|pending| {
let params: MessageCreateParams =
serde_json::from_value(pending.request.clone()).map_err(|err| {
json_error(
"invalid-anthropic-request",
"failed to decode saved Anthropic request",
err,
)
})?;
// The continuation id doubles as the per-request id inside the batch;
// streaming is explicitly turned off on every request.
Ok(MessageBatchCreateRequest::new(
pending.continuation_id.clone(),
params.with_stream(false),
))
})
.collect::<Result<Vec<_>>>()?;
log_batch_executor_transition(
"anthropic.batch.create_started",
indicio::value!({
provider: &provider,
request_count: pending.len(),
}),
);
// NOTE(review): the provider batch is created before the database
// transaction below begins. If a later database step fails, the
// continuations stay pending while the provider-side batch already
// exists - confirm how that case is reconciled.
let batch = match client
.create_message_batch(MessageBatchCreateParams::new(requests))
.await
{
Ok(batch) => batch,
Err(err) => {
let error = anthropic_error(err);
log_batch_executor_transition(
"anthropic.batch.create_failed",
indicio::value!({
provider: &provider,
request_count: pending.len(),
source: error.to_string(),
}),
);
return Err(error);
}
};
let batch_json = serde_json::to_value(&batch).map_err(|err| {
json_error(
"invalid-provider-batch-response",
"failed to serialize provider batch response",
err,
)
})?;
let mut tx = self
.pool
.begin()
.await
.map_err(|err| sqlx_error("batch-submit-anthropic-batch", err))?;
// Upsert keyed on provider_batch_id: an existing row only has its status
// and response refreshed. The request count saturates at i32::MAX.
sqlx::query(
"INSERT INTO batch_provider_batches \
(provider, provider_batch_id, status, request_count, response, quiescent) \
VALUES ($1, $2, $3, $4, $5, false) \
ON CONFLICT (provider_batch_id) DO UPDATE \
SET status = EXCLUDED.status, response = EXCLUDED.response, \
updated_at = now()",
)
.bind(&provider)
.bind(&batch.id)
.bind(provider_batch_status(batch.processing_status))
.bind(i32::try_from(pending.len()).unwrap_or(i32::MAX))
.bind(batch_json)
.execute(&mut *tx)
.await
.map_err(|err| sqlx_error("batch-submit-anthropic-batch", err))?;
for pending in &pending {
// The submitted event is caused by the continuation's stored cursor and
// becomes the continuation's new cursor.
let submitted = first_party_event(
self.trampoline.observability_config(),
"anthropic.submitted",
Some(pending.continuation_id.clone()),
json!({
"continuation_id": pending.continuation_id.clone(),
"provider_batch_id": batch.id.clone(),
}),
pending.causal_cursor.clone(),
)?;
let submitted_cursor = causal_ref_to_value(&CausalRef::EventId {
event_id: submitted.event_id,
})?;
// Only rows still in the pending status are updated; the affected-row
// count is not inspected here.
sqlx::query(
"UPDATE batch_continuations \
SET status = $2, provider_batch_id = $3, submitted_at = now(), \
causal_cursor = $4, updated_at = now() \
WHERE continuation_id = $1 AND status = $5",
)
.bind(&pending.continuation_id)
.bind(CONTINUATION_SUBMITTED)
.bind(&batch.id)
.bind(submitted_cursor)
.bind(CONTINUATION_PENDING)
.execute(&mut *tx)
.await
.map_err(|err| sqlx_error("batch-submit-anthropic-batch", err))?;
// The workflow's cursor must still equal the continuation's cursor, or
// the append fails and the whole call returns the error.
summary.events_committed += append_workflow_events(
&mut tx,
&pending.workflow_run_id,
&pending.causal_cursor,
&[submitted],
)
.await?;
}
tx.commit()
.await
.map_err(|err| sqlx_error("batch-submit-anthropic-batch", err))?;
log_batch_executor_transition(
"anthropic.batch.create_completed",
indicio::value!({
provider: &provider,
provider_batch_id: &batch.id,
request_count: pending.len(),
status: provider_batch_status(batch.processing_status),
}),
);
summary.provider_batches_submitted += 1;
summary.provider_requests_submitted += pending.len() as u64;
}
Ok(())
}
/// Lists the distinct, non-null providers that currently have at least one
/// pending Anthropic continuation, ordered by provider name.
///
/// # Errors
///
/// Returns a Postgres error if the query or a column decode fails.
async fn pending_anthropic_providers(&self) -> Result<Vec<String>> {
    const CONTEXT: &str = "batch-load-pending-providers";
    let query = sqlx::query(
        "SELECT DISTINCT provider FROM batch_continuations \
         WHERE kind = $1 AND status = $2 AND provider IS NOT NULL \
         ORDER BY provider",
    )
    .bind(CONTINUATION_ANTHROPIC)
    .bind(CONTINUATION_PENDING);
    let rows = query
        .fetch_all(&self.pool)
        .await
        .map_err(|err| sqlx_error(CONTEXT, err))?;
    let mut providers = Vec::with_capacity(rows.len());
    for row in rows {
        let provider: String = row
            .try_get("provider")
            .map_err(|err| sqlx_error(CONTEXT, err))?;
        providers.push(provider);
    }
    Ok(providers)
}
/// Decides whether the pending Anthropic continuations for `provider` should
/// be submitted as a batch now.
///
/// Loads the number of pending continuations and the age in seconds of the
/// oldest one from `batch_continuations`, checks whether the provider already
/// has an active batch, and delegates the decision to [`should_flush`].
///
/// # Errors
///
/// Returns a Postgres error if either query or a column decode fails.
async fn should_flush_provider(&self, provider: &str) -> Result<bool> {
    const CONTEXT: &str = "batch-count-pending-provider";
    // Upper bound applied before converting to `Duration`; far beyond any real
    // timestamp difference and still inside `Duration`'s representable range.
    const MAX_AGE_SECS: f64 = 1.0e18;
    let row = sqlx::query(
        "SELECT COUNT(*)::BIGINT AS count, \
         EXTRACT(EPOCH FROM (now() - MIN(created_at)))::DOUBLE PRECISION AS age \
         FROM batch_continuations \
         WHERE kind = $1 AND status = $2 AND provider = $3",
    )
    .bind(CONTINUATION_ANTHROPIC)
    .bind(CONTINUATION_PENDING)
    .bind(provider)
    .fetch_one(&self.pool)
    .await
    .map_err(|err| sqlx_error(CONTEXT, err))?;
    let count: i64 = row
        .try_get("count")
        .map_err(|err| sqlx_error(CONTEXT, err))?;
    // `age` is NULL (None) when there are no pending rows.
    let age: Option<f64> = row
        .try_get("age")
        .map_err(|err| sqlx_error(CONTEXT, err))?;
    let active = self.provider_has_active_batch(provider).await?;
    // `Duration::from_secs_f64` panics on negative, NaN, or infinite input.
    // The age is a database-computed float, so a `created_at` slightly in the
    // future (for example after a clock step) would otherwise crash the poll
    // loop. `f64::max` returns the non-NaN operand, so NaN also maps to zero.
    let oldest_age =
        age.map(|secs| Duration::from_secs_f64(secs.max(0.0).min(MAX_AGE_SECS)));
    Ok(should_flush(
        // COUNT(*) is never negative; a failed conversion saturates.
        usize::try_from(count).unwrap_or(usize::MAX),
        oldest_age,
        active,
        &self.config,
    ))
}
async fn provider_has_active_batch(&self, provider: &str) -> Result<bool> {
let row = sqlx::query(
"SELECT EXISTS( \
SELECT 1 FROM batch_provider_batches \
WHERE provider = $1 AND status IN ($2, $3) AND quiescent = false \
) AS active",
)
.bind(provider)
.bind(PROVIDER_BATCH_IN_PROGRESS)
.bind(PROVIDER_BATCH_CANCELING)
.fetch_one(&self.pool)
.await
.map_err(|err| sqlx_error("batch-provider-has-active", err))?;
row.try_get("active")
.map_err(|err| sqlx_error("batch-provider-has-active", err))
}
async fn pending_anthropic_requests(&self, provider: &str) -> Result<Vec<PendingAnthropic>> {
let rows = sqlx::query(
"SELECT continuation_id, workflow_run_id, request, causal_cursor \
FROM batch_continuations \
WHERE kind = $1 AND status = $2 AND provider = $3 \
ORDER BY created_at, id LIMIT $4",
)
.bind(CONTINUATION_ANTHROPIC)
.bind(CONTINUATION_PENDING)
.bind(provider)
.bind(limit_i64(self.config.max_batch_requests))
.fetch_all(&self.pool)
.await
.map_err(|err| sqlx_error("batch-load-pending-anthropic", err))?;
rows.into_iter()
.map(|row| {
Ok(PendingAnthropic {
continuation_id: row
.try_get("continuation_id")
.map_err(|err| sqlx_error("batch-load-pending-anthropic", err))?,
workflow_run_id: row
.try_get("workflow_run_id")
.map_err(|err| sqlx_error("batch-load-pending-anthropic", err))?,
request: row
.try_get("request")
.map_err(|err| sqlx_error("batch-load-pending-anthropic", err))?,
causal_cursor: causal_ref_from_value(
row.try_get("causal_cursor")
.map_err(|err| sqlx_error("batch-load-pending-anthropic", err))?,
)?,
})
})
.collect()
}
/// Reports whether there is work that can be done right now: a runnable
/// workflow, a fork/join that is ready, or a provider whose pending requests
/// should be flushed.
///
/// The checks run in that order and stop at the first positive answer.
async fn has_immediate_work(&self) -> Result<bool> {
    let local_work = self.exists_workflow_status(WORKFLOW_RUNNABLE).await?
        || self.exists_ready_fork_join().await?;
    if local_work {
        return Ok(true);
    }
    let providers = self.pending_anthropic_providers().await?;
    for provider in &providers {
        let flushable = self.should_flush_provider(provider).await?;
        if flushable {
            return Ok(true);
        }
    }
    Ok(false)
}
/// Reports whether the executor still has anything to do: either immediate
/// work (see `has_immediate_work`) or any non-quiescent provider batch that is
/// in progress or canceling.
///
/// # Errors
///
/// Returns a Postgres error if a query or column decode fails.
async fn has_runtime_work(&self) -> Result<bool> {
    const CONTEXT: &str = "batch-has-runtime-work";
    let immediate = self.has_immediate_work().await?;
    if immediate {
        return Ok(true);
    }
    let query = sqlx::query(
        "SELECT EXISTS( \
         SELECT 1 FROM batch_provider_batches \
         WHERE status IN ($1, $2) AND quiescent = false \
         ) AS active",
    )
    .bind(PROVIDER_BATCH_IN_PROGRESS)
    .bind(PROVIDER_BATCH_CANCELING);
    let row = query
        .fetch_one(&self.pool)
        .await
        .map_err(|err| sqlx_error(CONTEXT, err))?;
    let active: bool = row
        .try_get("active")
        .map_err(|err| sqlx_error(CONTEXT, err))?;
    Ok(active)
}
async fn exists_workflow_status(&self, status: &str) -> Result<bool> {
let row = sqlx::query(
"SELECT EXISTS( \
SELECT 1 FROM batch_workflows \
WHERE status = $1 AND quiescent = false \
) AS exists",
)
.bind(status)
.fetch_one(&self.pool)
.await
.map_err(|err| sqlx_error("batch-exists-workflow-status", err))?;
row.try_get("exists")
.map_err(|err| sqlx_error("batch-exists-workflow-status", err))
}
async fn exists_ready_fork_join(&self) -> Result<bool> {
let row = sqlx::query(
"SELECT EXISTS( \
SELECT 1 FROM batch_fork_joins fj \
JOIN batch_workflows lhs ON lhs.run_id = fj.lhs_run_id \
JOIN batch_workflows rhs ON rhs.run_id = fj.rhs_run_id \
WHERE fj.status = $1 AND fj.quiescent = false \
AND lhs.status = $2 AND rhs.status = $2 \
) AS exists",
)
.bind(FORK_JOIN_WAITING)
.bind(WORKFLOW_HALTED)
.fetch_one(&self.pool)
.await
.map_err(|err| sqlx_error("batch-exists-ready-fork-join", err))?;
row.try_get("exists")
.map_err(|err| sqlx_error("batch-exists-ready-fork-join", err))
}
}
/// Data carried for a blocked continuation: the owning run, an optional
/// output key, a causal cursor, and the workflow value.
///
/// NOTE(review): this type is constructed and consumed outside this section
/// of the file; confirm the exact field semantics at those call sites.
struct BlockedContinuation {
run_id: String,
output_key: Option<String>,
causal_cursor: CausalRef,
workflow: Workflow,
}
/// One pending Anthropic continuation row, as loaded by
/// `pending_anthropic_requests` and consumed when a provider batch is
/// submitted.
struct PendingAnthropic {
/// Continuation id; also used as the request id inside the provider batch.
continuation_id: String,
/// Run id of the workflow that owns this continuation.
workflow_run_id: String,
/// Saved request JSON; decoded into `MessageCreateParams` at submit time.
request: Value,
/// Cursor stored on the continuation; used as the cause of the
/// `anthropic.submitted` event and as the expected workflow cursor.
causal_cursor: CausalRef,
}
/// Appends `events` to the event log of workflow `run_id` inside `tx` and
/// advances the workflow's causal cursor to the last appended event.
///
/// The `UPDATE` on `batch_workflows` only matches when the stored
/// `causal_cursor` equals `expected_cursor`, so a caller holding a stale
/// cursor gets a `workflow-event-cursor-mismatch` error instead of appending.
/// The same statement reserves `events.len()` consecutive event ordinals.
///
/// Returns the number of events appended. An empty slice returns `Ok(0)`
/// without validating or writing anything.
///
/// # Errors
///
/// Fails if a cause cannot be validated (see `validate_event_causes`), a
/// causal reference cannot be serialized, the cursor does not match, or a
/// Postgres operation fails. This function never commits `tx`; that is left
/// to the caller.
async fn append_workflow_events(
tx: &mut sqlx::Transaction<'_, sqlx::Postgres>,
run_id: &str,
expected_cursor: &CausalRef,
events: &[PendingWorkflowEvent],
) -> Result<u64> {
if events.is_empty() {
return Ok(0);
}
validate_event_causes(tx, events).await?;
// The workflow's new cursor points at the last event of this append.
let new_cursor = CausalRef::EventId {
event_id: events.last().expect("events is not empty").event_id,
};
let expected_cursor_json = causal_ref_to_value(expected_cursor)?;
let new_cursor_json = causal_ref_to_value(&new_cursor)?;
let event_count = i64::try_from(events.len()).map_err(|_| {
batch_error(
"too-many-workflow-events",
"too many workflow events were provided in one commit",
)
})?;
// Compare-and-swap on the cursor. RETURNING reports the post-update row, so
// subtracting the count recovers the first ordinal reserved for this append.
let row = sqlx::query(
"UPDATE batch_workflows \
SET next_event_ordinal = next_event_ordinal + $2, \
causal_cursor = $3, updated_at = now() \
WHERE run_id = $1 AND causal_cursor = $4 \
RETURNING root_run_id, parent_run_id, fork_name, \
next_event_ordinal - $2 AS first_ordinal",
)
.bind(run_id)
.bind(event_count)
.bind(new_cursor_json)
.bind(expected_cursor_json)
.fetch_optional(&mut **tx)
.await
.map_err(|err| sqlx_error("batch-append-workflow-events", err))?
// No row back means the run id is unknown or the cursor did not match.
.ok_or_else(|| {
batch_error(
"workflow-event-cursor-mismatch",
"workflow causal cursor did not match while appending events",
)
.with_string_field("run_id", run_id)
})?;
let root_run_id: String = row
.try_get("root_run_id")
.map_err(|err| sqlx_error("batch-append-workflow-events", err))?;
let parent_run_id: Option<String> = row
.try_get("parent_run_id")
.map_err(|err| sqlx_error("batch-append-workflow-events", err))?;
let fork_name: Option<String> = row
.try_get("fork_name")
.map_err(|err| sqlx_error("batch-append-workflow-events", err))?;
let first_ordinal: i64 = row
.try_get("first_ordinal")
.map_err(|err| sqlx_error("batch-append-workflow-events", err))?;
for (offset, event) in events.iter().enumerate() {
// Ordinals are assigned in slice order starting at the reserved ordinal.
let event_ordinal = first_ordinal + i64::try_from(offset).unwrap_or(i64::MAX);
let caused_by = causal_ref_to_value(&event.caused_by)?;
sqlx::query(
"INSERT INTO batch_workflow_events \
(event_id, root_run_id, run_id, parent_run_id, fork_name, \
event_ordinal, caused_by, event_type, event_version, \
continuation_id, event) \
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)",
)
.bind(event.event_id)
.bind(&root_run_id)
.bind(run_id)
.bind(&parent_run_id)
.bind(&fork_name)
.bind(event_ordinal)
.bind(caused_by)
.bind(&event.event_type)
.bind(event.event_version)
.bind(&event.continuation_id)
.bind(&event.event)
.execute(&mut **tx)
.await
.map_err(|err| sqlx_error("batch-insert-workflow-event", err))?;
// Logged as soon as the row is inserted, i.e. before the caller commits `tx`.
log_workflow_event(
&root_run_id,
run_id,
parent_run_id.as_deref(),
fork_name.as_deref(),
event_ordinal,
event,
);
}
Ok(events.len() as u64)
}
/// Emits one INFO-level structured log record describing a workflow event
/// together with the run it belongs to and its ordinal.
fn log_workflow_event(
    root_run_id: &str,
    run_id: &str,
    parent_run_id: Option<&str>,
    fork_name: Option<&str>,
    event_ordinal: i64,
    event: &PendingWorkflowEvent,
) {
    let clue = json!({
        "log_type": "langcontinuation.batch.workflow_event",
        "event_id": &event.event_id,
        "root_run_id": root_run_id,
        "run_id": run_id,
        "parent_run_id": parent_run_id,
        "fork_name": fork_name,
        "event_ordinal": event_ordinal,
        "caused_by": &event.caused_by,
        "event_type": &event.event_type,
        "event_version": event.event_version,
        "continuation_id": &event.continuation_id,
        "event": &event.event,
    });
    crate::log_json_clue(indicio::INFO, clue);
}
/// Logs an executor transition under the fixed executor name `"batch"`,
/// forwarding `transition` and `fields` to `crate::log_executor_transition`.
fn log_batch_executor_transition(transition: &str, fields: indicio::Value) {
crate::log_executor_transition("batch", transition, fields);
}
/// Converts a `PollSummary` into an `indicio::Value` map with one entry per
/// summary field, each keyed by the field's own name.
fn poll_summary_to_value(summary: &PollSummary) -> indicio::Value {
indicio::value!({
workflows_advanced: summary.workflows_advanced,
workflows_halted: summary.workflows_halted,
workflows_failed: summary.workflows_failed,
continuations_blocked: summary.continuations_blocked,
fork_joins_resumed: summary.fork_joins_resumed,
provider_batches_submitted: summary.provider_batches_submitted,
provider_requests_submitted: summary.provider_requests_submitted,
provider_batches_completed: summary.provider_batches_completed,
provider_results_completed: summary.provider_results_completed,
workflows_resumed: summary.workflows_resumed,
events_committed: summary.events_committed,
more_work: summary.more_work,
})
}
fn provider_result_type(result: &MessageBatchResultVariant) -> &'static str {
match result {
MessageBatchResultVariant::Succeeded { .. } => "succeeded",
MessageBatchResultVariant::Errored { .. } => "errored",
MessageBatchResultVariant::Canceled => "canceled",
MessageBatchResultVariant::Expired => "expired",
}
}
async fn load_workflow_causal_cursor(
tx: &mut sqlx::Transaction<'_, sqlx::Postgres>,
run_id: &str,
) -> Result<CausalRef> {
let row = sqlx::query("SELECT causal_cursor FROM batch_workflows WHERE run_id = $1")
.bind(run_id)
.fetch_one(&mut **tx)
.await
.map_err(|err| sqlx_error("batch-load-workflow-causal-cursor", err))?;
causal_ref_from_value(
row.try_get("causal_cursor")
.map_err(|err| sqlx_error("batch-load-workflow-causal-cursor", err))?,
)
}
async fn validate_event_causes(
tx: &mut sqlx::Transaction<'_, sqlx::Postgres>,
events: &[PendingWorkflowEvent],
) -> Result<()> {
let mut seen = HashSet::new();
for event in events {
match &event.caused_by {
CausalRef::RunId { run_id } => {
let row = sqlx::query(
"SELECT EXISTS(SELECT 1 FROM batch_workflows WHERE run_id = $1) AS exists",
)
.bind(run_id)
.fetch_one(&mut **tx)
.await
.map_err(|err| sqlx_error("batch-validate-event-cause", err))?;
let exists: bool = row
.try_get("exists")
.map_err(|err| sqlx_error("batch-validate-event-cause", err))?;
if !exists {
return Err(batch_error(
"unknown-causal-run",
"workflow event references an unknown run anchor",
)
.with_string_field("run_id", run_id));
}
}
CausalRef::EventId { event_id } if seen.contains(event_id) => {}
CausalRef::EventId { event_id } => {
let row = sqlx::query(
"SELECT EXISTS(SELECT 1 FROM batch_workflow_events WHERE event_id = $1) AS exists",
)
.bind(event_id)
.fetch_one(&mut **tx)
.await
.map_err(|err| sqlx_error("batch-validate-event-cause", err))?;
let exists: bool = row
.try_get("exists")
.map_err(|err| sqlx_error("batch-validate-event-cause", err))?;
if !exists {
return Err(batch_error(
"unknown-causal-event",
"workflow event references an unknown event id",
)
.with_string_field("event_id", &event_id.to_string()));
}
}
}
seen.insert(event.event_id);
}
Ok(())
}
/// Serializes a causal reference to JSON, mapping a serialization failure to
/// an `invalid-causal-ref` error.
fn causal_ref_to_value(causal_ref: &CausalRef) -> Result<Value> {
    match serde_json::to_value(causal_ref) {
        Ok(encoded) => Ok(encoded),
        Err(err) => Err(json_error(
            "invalid-causal-ref",
            "failed to serialize workflow event causal reference",
            err,
        )),
    }
}
/// Deserializes a causal reference from JSON, mapping a decode failure to an
/// `invalid-causal-ref` error.
fn causal_ref_from_value(value: Value) -> Result<CausalRef> {
    match serde_json::from_value::<CausalRef>(value) {
        Ok(decoded) => Ok(decoded),
        Err(err) => Err(json_error(
            "invalid-causal-ref",
            "failed to deserialize workflow event causal reference",
            err,
        )),
    }
}
fn first_party_event(
config: &ObservabilityConfig,
event_type: &str,
continuation_id: Option<String>,
payload: Value,
caused_by: CausalRef,
) -> Result<PendingWorkflowEvent> {
PendingWorkflowEvent::first_party(event_type, continuation_id, payload, caused_by, config)
.map_err(|err| {
batch_error(
"invalid-first-party-event",
"failed to build workflow event",
)
.with_string_field("event_type", event_type)
.with_string_field("source", &err.to_string())
})
}
/// Returns `first` followed by `pending`, linking automatically-caused events
/// into a chain.
///
/// Each pending event for which `caused_automatically()` is true has its cause
/// replaced with the event immediately before it in the output. Events with
/// another cause keep it, but still become the predecessor of the next event.
fn chain_pending_after(
    first: PendingWorkflowEvent,
    pending: Vec<PendingWorkflowEvent>,
) -> Vec<PendingWorkflowEvent> {
    let mut chained = Vec::with_capacity(pending.len() + 1);
    let mut previous_id = first.event_id;
    chained.push(first);
    for mut next in pending {
        if next.caused_automatically() {
            next.caused_by = CausalRef::EventId {
                event_id: previous_id,
            };
        }
        previous_id = next.event_id;
        chained.push(next);
    }
    chained
}
/// Column values for a new `batch_continuations` row, as consumed by
/// `insert_continuation`. The row id and `continuation_id` are generated by
/// the insert statement itself and are therefore not part of this struct.
struct NewContinuation<'a> {
/// Value for the `kind` column.
kind: &'a str,
/// Value for the `status` column.
status: &'a str,
/// Run id of the owning workflow.
workflow_run_id: &'a str,
/// Provider name, if any.
provider: Option<&'a str>,
/// Value for the `output_key` column, if any.
output_key: Option<&'a str>,
/// Saved request JSON, if any.
request: Option<Value>,
/// Value for the `attempt_of` column, if any.
attempt_of: Option<&'a str>,
/// Value for the `attempt` column.
attempt: i32,
/// Causal cursor stored on the row (serialized to JSON on insert).
causal_cursor: CausalRef,
}
async fn insert_continuation(
tx: &mut sqlx::Transaction<'_, sqlx::Postgres>,
continuation: NewContinuation<'_>,
) -> Result<String> {
let row = sqlx::query(
"WITH next_id AS (SELECT nextval('batch_continuations_id_seq') AS id) \
INSERT INTO batch_continuations \
(id, continuation_id, workflow_run_id, kind, status, provider, \
output_key, request, attempt_of, attempt, causal_cursor, quiescent) \
SELECT id, 'lc' || id::TEXT, $1, $2, $3, $4, $5, $6, $7, $8, $9, false \
FROM next_id \
RETURNING continuation_id",
)
.bind(continuation.workflow_run_id)
.bind(continuation.kind)
.bind(continuation.status)
.bind(continuation.provider)
.bind(continuation.output_key)
.bind(continuation.request)
.bind(continuation.attempt_of)
.bind(continuation.attempt)
.bind(causal_ref_to_value(&continuation.causal_cursor)?)
.fetch_one(&mut **tx)
.await
.map_err(|err| sqlx_error("batch-insert-continuation", err))?;
row.try_get("continuation_id")
.map_err(|err| sqlx_error("batch-insert-continuation", err))
}
/// Inserts a new, non-quiescent workflow row in the runnable status.
///
/// # Errors
///
/// Returns a JSON error if the causal cursor cannot be serialized, or a
/// Postgres error if the insert fails.
async fn insert_workflow_row(
    tx: &mut sqlx::Transaction<'_, sqlx::Postgres>,
    run_id: &str,
    root_run_id: &str,
    workflow: Value,
    parent_run_id: Option<&str>,
    fork_name: Option<&str>,
    causal_cursor: &CausalRef,
) -> Result<()> {
    let cursor_json = causal_ref_to_value(causal_cursor)?;
    let insert = sqlx::query(
        "INSERT INTO batch_workflows \
         (run_id, root_run_id, workflow, status, parent_run_id, fork_name, \
         causal_cursor, quiescent) \
         VALUES ($1, $2, $3, $4, $5, $6, $7, false)",
    )
    .bind(run_id)
    .bind(root_run_id)
    .bind(workflow)
    .bind(WORKFLOW_RUNNABLE)
    .bind(parent_run_id)
    .bind(fork_name)
    .bind(cursor_json);
    match insert.execute(&mut **tx).await {
        Ok(_) => Ok(()),
        Err(err) => Err(sqlx_error("batch-insert-workflow-row", err)),
    }
}
fn workflow_record_from_row(row: sqlx::postgres::PgRow) -> Result<WorkflowRecord> {
let status: String = row
.try_get("status")
.map_err(|err| sqlx_error("batch-workflow-record", err))?;
let workflow_json: Value = row
.try_get("workflow")
.map_err(|err| sqlx_error("batch-workflow-record", err))?;
Ok(WorkflowRecord {
run_id: row
.try_get("run_id")
.map_err(|err| sqlx_error("batch-workflow-record", err))?,
status: WorkflowStatus::from_db(&status)?,
workflow: workflow_from_value(workflow_json)?,
parent_run_id: row
.try_get("parent_run_id")
.map_err(|err| sqlx_error("batch-workflow-record", err))?,
fork_name: row
.try_get("fork_name")
.map_err(|err| sqlx_error("batch-workflow-record", err))?,
error_sexpr: row
.try_get("error_sexpr")
.map_err(|err| sqlx_error("batch-workflow-record", err))?,
quiescent: row
.try_get("quiescent")
.map_err(|err| sqlx_error("batch-workflow-record", err))?,
})
}
fn continuation_record_from_row(row: sqlx::postgres::PgRow) -> Result<ContinuationRecord> {
let kind: String = row
.try_get("kind")
.map_err(|err| sqlx_error("batch-continuation-record", err))?;
let status: String = row
.try_get("status")
.map_err(|err| sqlx_error("batch-continuation-record", err))?;
Ok(ContinuationRecord {
continuation_id: row
.try_get("continuation_id")
.map_err(|err| sqlx_error("batch-continuation-record", err))?,
workflow_run_id: row
.try_get("workflow_run_id")
.map_err(|err| sqlx_error("batch-continuation-record", err))?,
kind: ContinuationKind::from_db(&kind)?,
status: ContinuationStatus::from_db(&status)?,
provider: row
.try_get("provider")
.map_err(|err| sqlx_error("batch-continuation-record", err))?,
output_key: row
.try_get("output_key")
.map_err(|err| sqlx_error("batch-continuation-record", err))?,
error_sexpr: row
.try_get("error_sexpr")
.map_err(|err| sqlx_error("batch-continuation-record", err))?,
quiescent: row
.try_get("quiescent")
.map_err(|err| sqlx_error("batch-continuation-record", err))?,
})
}
fn workflow_event_record_from_row(row: sqlx::postgres::PgRow) -> Result<WorkflowEventRecord> {
let caused_by: Value = row
.try_get("caused_by")
.map_err(|err| sqlx_error("batch-workflow-event-record", err))?;
Ok(WorkflowEventRecord {
event_id: row
.try_get("event_id")
.map_err(|err| sqlx_error("batch-workflow-event-record", err))?,
root_run_id: row
.try_get("root_run_id")
.map_err(|err| sqlx_error("batch-workflow-event-record", err))?,
run_id: row
.try_get("run_id")
.map_err(|err| sqlx_error("batch-workflow-event-record", err))?,
parent_run_id: row
.try_get("parent_run_id")
.map_err(|err| sqlx_error("batch-workflow-event-record", err))?,
fork_name: row
.try_get("fork_name")
.map_err(|err| sqlx_error("batch-workflow-event-record", err))?,
event_ordinal: row
.try_get("event_ordinal")
.map_err(|err| sqlx_error("batch-workflow-event-record", err))?,
caused_by: causal_ref_from_value(caused_by)?,
event_type: row
.try_get("event_type")
.map_err(|err| sqlx_error("batch-workflow-event-record", err))?,
event_version: row
.try_get("event_version")
.map_err(|err| sqlx_error("batch-workflow-event-record", err))?,
continuation_id: row
.try_get("continuation_id")
.map_err(|err| sqlx_error("batch-workflow-event-record", err))?,
event: row
.try_get("event")
.map_err(|err| sqlx_error("batch-workflow-event-record", err))?,
created_at: row
.try_get("created_at")
.map_err(|err| sqlx_error("batch-workflow-event-record", err))?,
})
}
/// Serializes workflow state to JSON, mapping a serialization failure to an
/// `invalid-workflow-state` error.
fn workflow_to_value(workflow: &Workflow) -> Result<Value> {
    match serde_json::to_value(workflow) {
        Ok(encoded) => Ok(encoded),
        Err(err) => Err(json_error(
            "invalid-workflow-state",
            "failed to serialize workflow state",
            err,
        )),
    }
}
/// Deserializes workflow state from JSON, mapping a decode failure to an
/// `invalid-workflow-state` error.
fn workflow_from_value(value: Value) -> Result<Workflow> {
    match serde_json::from_value::<Workflow>(value) {
        Ok(decoded) => Ok(decoded),
        Err(err) => Err(json_error(
            "invalid-workflow-state",
            "failed to deserialize workflow state",
            err,
        )),
    }
}
fn provider_batch_status(status: MessageBatchProcessingStatus) -> &'static str {
match status {
MessageBatchProcessingStatus::InProgress => PROVIDER_BATCH_IN_PROGRESS,
MessageBatchProcessingStatus::Canceling => PROVIDER_BATCH_CANCELING,
MessageBatchProcessingStatus::Ended => PROVIDER_BATCH_ENDED,
}
}
/// Decides whether pending requests should be submitted as a batch now.
///
/// - Nothing pending: never flush.
/// - No active batch for the provider: flush immediately.
/// - Otherwise flush only once the pending count reaches
///   `config.min_batch_size` or the oldest pending request is at least
///   `config.max_batch_age` old. An unknown age never triggers a flush.
fn should_flush(
    pending_count: usize,
    oldest_age: Option<Duration>,
    has_active_batch: bool,
    config: &Config,
) -> bool {
    match (pending_count, has_active_batch) {
        (0, _) => false,
        (_, false) => true,
        (count, true) => {
            let big_enough = count >= config.min_batch_size;
            big_enough || matches!(oldest_age, Some(age) if age >= config.max_batch_age)
        }
    }
}
/// Converts a `usize` limit to `i64`, saturating at `i64::MAX` when the value
/// does not fit.
fn limit_i64(limit: usize) -> i64 {
    match i64::try_from(limit) {
        Ok(value) => value,
        Err(_) => i64::MAX,
    }
}
fn batch_error(code: &str, message: &str) -> handled::SError {
handled::SError::new("langcontinuation-batch")
.with_code(code)
.with_message(message)
}
fn sqlx_error(context: &str, err: sqlx::Error) -> handled::SError {
batch_error("postgres-error", "Postgres operation failed")
.with_string_field("context", context)
.with_string_field("source", &err.to_string())
}
fn json_error(code: &str, message: &str, err: serde_json::Error) -> handled::SError {
batch_error(code, message).with_string_field("source", &err.to_string())
}
/// Converts an Anthropic client error into a batch error.
///
/// The code comes from `anthropic_error_code`. The result always carries the
/// `retryable` and `source` fields, plus `status_code` and `request_id` when
/// the client error provides them.
fn anthropic_error(err: AnthropicError) -> handled::SError {
    let code = anthropic_error_code(&err);
    let base = batch_error(code, "Anthropic batch request failed")
        .with_atom_field("retryable", err.is_retryable())
        .with_string_field("source", &err.to_string());
    let with_status = match err.status_code() {
        Some(status_code) => base.with_atom_field("status_code", status_code),
        None => base,
    };
    match err.request_id() {
        Some(request_id) => with_status.with_string_field("request_id", request_id),
        None => with_status,
    }
}
/// Picks the batch error code for an Anthropic client error.
///
/// The predicates are checked in a fixed order and the first match wins;
/// errors matching none of them get `anthropic-request-failed`.
fn anthropic_error_code(err: &AnthropicError) -> &'static str {
    if err.is_authentication() {
        return "anthropic-authentication";
    }
    if err.is_permission() {
        return "anthropic-permission";
    }
    if err.is_not_found() {
        return "anthropic-not-found";
    }
    if err.is_rate_limit() {
        return "anthropic-rate-limit";
    }
    if err.is_bad_request() {
        return "anthropic-bad-request";
    }
    if err.is_timeout() {
        return "anthropic-timeout";
    }
    if err.is_abort() {
        return "anthropic-abort";
    }
    if err.is_connection() {
        return "anthropic-connection";
    }
    if err.is_server_error() {
        return "anthropic-server-error";
    }
    if err.is_validation() {
        return "anthropic-validation";
    }
    if err.is_todo() {
        return "anthropic-unimplemented";
    }
    "anthropic-request-failed"
}
fn missing_anthropic_provider_error(provider: &str) -> handled::SError {
batch_error(
"missing-anthropic-provider",
"Anthropic provider is not registered",
)
.with_string_field("provider", provider)
}
// Unit tests for the flush policy and status round-tripping, plus Postgres
// integration tests that only do real work when `DATABASE_URL` is set.
#[cfg(test)]
mod tests {
use super::*;
use sqlx::postgres::PgPoolOptions;
use std::time::{SystemTime, UNIX_EPOCH};
// Pins the default configuration values.
#[test]
fn default_config_matches_planned_batch_defaults() {
let config = Config::default();
assert_eq!(config.poll_interval, Duration::from_secs(60));
assert_eq!(config.min_batch_size, 100);
assert_eq!(config.max_batch_age, Duration::from_secs(300));
assert_eq!(config.max_batch_requests, 10_000);
assert_eq!(config.max_workflows_per_poll, 1_000);
}
// With no active batch, a single brand-new request is flushed right away.
#[test]
fn should_flush_immediately_without_active_batch() {
let config = Config::default();
assert!(should_flush(1, Some(Duration::ZERO), false, &config));
}
// Behind an active batch, a batch just under both thresholds is held back.
#[test]
fn should_hold_small_young_batch_behind_active_batch() {
let config = Config::default();
assert!(!should_flush(
config.min_batch_size - 1,
Some(config.max_batch_age - Duration::from_secs(1)),
true,
&config
));
}
// Behind an active batch, reaching the minimum count alone triggers a flush.
#[test]
fn should_flush_by_count_behind_active_batch() {
let config = Config::default();
assert!(should_flush(
config.min_batch_size,
Some(Duration::ZERO),
true,
&config
));
}
// Behind an active batch, reaching the maximum age alone triggers a flush.
#[test]
fn should_flush_by_age_behind_active_batch() {
let config = Config::default();
assert!(should_flush(1, Some(config.max_batch_age), true, &config));
}
// Every workflow status survives a trip through its database string.
#[test]
fn workflow_status_round_trips_database_strings() {
for status in [
WorkflowStatus::Runnable,
WorkflowStatus::WaitingAnthropic,
WorkflowStatus::BlockedHuman,
WorkflowStatus::BlockedOpenAI,
WorkflowStatus::WaitingForkJoin,
WorkflowStatus::Halted,
WorkflowStatus::Failed,
] {
assert_eq!(WorkflowStatus::from_db(status.as_str()).unwrap(), status);
}
}
// The largest possible generated id ("lc" + i64::MAX) stays within 64 bytes
// and uses only ASCII alphanumerics, '-' and '_'.
#[test]
fn generated_continuation_ids_fit_anthropic_custom_id_rules() {
let id = format!("lc{}", i64::MAX);
assert!(id.len() <= 64);
assert!(
id.bytes()
.all(|byte| byte.is_ascii_alphanumeric() || byte == b'-' || byte == b'_')
);
}
// Returns a single-connection pool for `DATABASE_URL`, or `None` when the
// variable is unset. Panics if the variable is set but the connection fails.
async fn test_pool() -> Option<PgPool> {
let Ok(url) = std::env::var("DATABASE_URL") else {
return None;
};
Some(
PgPoolOptions::new()
.max_connections(1)
.connect(&url)
.await
.expect("connect to DATABASE_URL"),
)
}
// Builds a run id from the prefix, the process id and the current time in
// nanoseconds, so tests do not reuse run ids in a shared database.
fn unique_run_id(prefix: &str) -> String {
let nanos = SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("system time")
.as_nanos();
format!("{prefix}-{}-{nanos}", std::process::id())
}
// Running the migration twice and enqueueing the same workflow twice must both
// succeed and leave exactly one `workflow.enqueued` event at ordinal 0.
// Returns early (and so passes) when `DATABASE_URL` is unset.
#[tokio::test]
async fn postgres_migration_and_enqueue_are_idempotent_when_database_url_is_set() {
let Some(pool) = test_pool().await else {
return;
};
migrate(&pool).await.expect("migrate");
migrate(&pool).await.expect("migrate twice");
let run_id = unique_run_id("batch-enqueue");
let workflow = Workflow::new(&run_id, "entry");
let executor = Executor::with_default_config(Trampoline::default(), pool);
executor
.enqueue_workflow(workflow.clone())
.await
.expect("enqueue");
executor
.enqueue_workflow(workflow)
.await
.expect("enqueue same workflow");
assert_eq!(
executor.workflow_status(&run_id).await.expect("status"),
Some(WorkflowStatus::Runnable)
);
let events = executor
.load_workflow_events(&run_id)
.await
.expect("load events");
assert_eq!(events.len(), 1);
assert_eq!(events[0].event_ordinal, 0);
assert_eq!(events[0].event_type, "workflow.enqueued");
}
// End-to-end human continuation: the "entry" step blocks on a human request,
// the test answers "yes", and the "after" step records the answer before the
// workflow halts. Returns early (and so passes) when `DATABASE_URL` is unset.
#[tokio::test]
async fn postgres_human_block_and_resume_when_database_url_is_set() {
let Some(pool) = test_pool().await else {
return;
};
migrate(&pool).await.expect("migrate");
let run_id = unique_run_id("batch-human");
let mut trampoline = Trampoline::default();
// "entry" asks a human to approve; the answer is stored under "human_answer"
// and execution continues at "after".
trampoline.register(
"entry",
crate::test_sync_call(|workflow| {
crate::__with_continuation(
workflow,
|_, continuation| -> std::result::Result<crate::ContinuationChoice, handled::SError> {
Ok(continuation.human(
crate::HumanRequest::new("approve"),
"human_answer",
"after",
))
},
)
}),
);
// "after" stores whether the human answered "yes" under the "after" key.
trampoline.register(
"after",
crate::test_sync_call(|workflow| {
let answer: String = workflow.from_env("human_answer").unwrap().unwrap();
workflow.into_env("after", answer == "yes").unwrap();
Ok(())
}),
);
let executor = Executor::with_default_config(trampoline, pool);
executor
.enqueue_workflow(Workflow::new(&run_id, "entry"))
.await
.expect("enqueue");
executor.run().await.expect("run to human block");
assert_eq!(
executor.workflow_status(&run_id).await.expect("status"),
Some(WorkflowStatus::BlockedHuman)
);
// Other runs may share the database, so pick out this run's continuation.
let blocked = executor
.list_blocked_continuations()
.await
.expect("blocked continuations");
let continuation = blocked
.into_iter()
.find(|record| record.workflow_run_id == run_id)
.expect("blocked human continuation");
executor
.resume_human(&continuation.continuation_id, "yes".to_string())
.await
.expect("resume human");
executor.run().await.expect("run after resume");
let workflow = executor
.load_workflow(&run_id)
.await
.expect("load workflow")
.expect("workflow row");
assert_eq!(workflow.status, WorkflowStatus::Halted);
assert_eq!(
workflow.workflow.from_env::<bool>("after").unwrap(),
Some(true)
);
// Only the presence of each event type is checked, not order or count.
let events = executor
.load_workflow_events(&run_id)
.await
.expect("load events");
let event_types: Vec<_> = events
.iter()
.map(|event| event.event_type.as_str())
.collect();
assert!(event_types.contains(&"workflow.enqueued"));
assert!(event_types.contains(&"local_call.started"));
assert!(event_types.contains(&"local_call.completed"));
assert!(event_types.contains(&"human.blocked"));
assert!(event_types.contains(&"human.resumed"));
assert!(event_types.contains(&"workflow.halted"));
}
}