#![doc = include_str!("../README.md")]
#![deny(missing_docs)]
use std::{
collections::BTreeMap,
collections::{HashMap, HashSet},
fmt::{Display, Formatter},
future::Future,
io::Write,
pin::Pin,
time::Instant,
};
use claudius::{
AccumulatingStream, Anthropic, ContentBlockDelta, Error as AnthropicError, Message,
MessageCreateParams, MessageStreamEvent, ToolResultBlock, ToolUnionParam, ToolUseBlock,
};
use futures::StreamExt;
use setsum::Setsum;
use uuid::Uuid;
pub use indicio;
pub mod live;
#[cfg(feature = "batch")]
pub mod batch;
/// Process-wide `indicio` collector that this crate's logging helpers emit clues to.
pub static COLLECTOR: indicio::Collector = indicio::Collector::new();
/// ANSI escape sequence that switches terminal text to blue; used to mark LLM output.
const LLM_OUTPUT_BLUE: &str = "\x1b[34m";
/// ANSI escape sequence that resets terminal colours after LLM output.
const ANSI_RESET: &str = "\x1b[0m";
/// Prints `text` to stdout in the LLM output colour, followed by a newline.
///
/// Hidden from rustdoc; presumably invoked from macro expansions or sibling modules — TODO confirm.
#[doc(hidden)]
pub fn __print_llm_output(text: &str) {
print_llm_output_chunk(text);
println!();
}
/// Prints `text` to stdout wrapped in blue/reset ANSI codes, without a trailing newline.
fn print_llm_output_chunk(text: &str) {
print!("{LLM_OUTPUT_BLUE}{text}{ANSI_RESET}");
// Stdout is line-buffered and no newline was written, so flush explicitly to show the
// chunk immediately. Flush failures are intentionally ignored.
let _ = std::io::stdout().flush();
}
/// Emits `clue` at verbosity `level` to [`COLLECTOR`].
///
/// NOTE(review): `module_path!()`, `file!()` and `line!()` expand inside this helper, so every
/// clue is attributed to this function's location rather than to the caller's.
pub(crate) fn log_indicio_clue(level: u64, clue: indicio::Value) {
COLLECTOR.emit(concat!(module_path!(), " ", file!()), line!(), level, clue);
}
/// Makes sure executor logging is visible.
///
/// Registers an `indicio::StdioEmitter` when the collector has no emitter yet, and raises
/// the collector's verbosity to at least `INFO`. It never lowers an existing higher verbosity.
pub(crate) fn wire_executor_indicio_stderr() {
    let has_emitter = COLLECTOR.is_logging();
    if !has_emitter {
        COLLECTOR.register(indicio::StdioEmitter);
    }
    let too_quiet = COLLECTOR.verbosity() < indicio::INFO;
    if too_quiet {
        COLLECTOR.set_verbosity(indicio::INFO);
    }
}
/// Converts a `serde_json` clue to an `indicio::Value` and emits it at `level`.
#[cfg(feature = "batch")]
pub(crate) fn log_json_clue(level: u64, clue: serde_json::Value) {
    let converted = serde_json_to_indicio_value(clue);
    log_indicio_clue(level, converted);
}
/// Emits an `INFO` clue describing an executor state transition.
///
/// The clue's `log_type` is `langcontinuation.<executor>.executor_transition`. `transition`
/// and the caller-supplied `fields` are attached unchanged.
pub(crate) fn log_executor_transition(executor: &str, transition: &str, fields: indicio::Value) {
    let log_type = format!("langcontinuation.{executor}.executor_transition");
    let clue = indicio::value!({
        log_type: log_type,
        transition: transition,
        fields: fields,
    });
    log_indicio_clue(indicio::INFO, clue);
}
/// Converts an optional string into an `indicio::Value`.
///
/// `None` is encoded as the object `{ null: true }`.
pub(crate) fn optional_indicio_string(value: Option<&str>) -> indicio::Value {
    match value {
        Some(text) => indicio::Value::from(text),
        None => indicio::value!({ null: true }),
    }
}
/// Recursively converts a `serde_json::Value` into the equivalent `indicio::Value`.
///
/// JSON `null` becomes the object `{ null: true }`. Numbers are tried as `u64`, then `i64`,
/// then `f64`, and fall back to their string form only if none of those succeed.
#[cfg(feature = "batch")]
fn serde_json_to_indicio_value(value: serde_json::Value) -> indicio::Value {
    use serde_json::Value as Json;
    match value {
        Json::Null => indicio::value!({ null: true }),
        Json::Bool(flag) => indicio::Value::from(flag),
        Json::String(text) => indicio::Value::from(text),
        Json::Number(number) => number
            .as_u64()
            .map(indicio::Value::from)
            .or_else(|| number.as_i64().map(indicio::Value::from))
            .or_else(|| number.as_f64().map(indicio::Value::from))
            .unwrap_or_else(|| indicio::Value::from(number.to_string())),
        Json::Array(items) => {
            let converted: Vec<_> = items.into_iter().map(serde_json_to_indicio_value).collect();
            indicio::Value::Array(converted.into())
        }
        Json::Object(entries) => indicio::Value::Object(
            entries
                .into_iter()
                .map(|(key, entry)| (key, serde_json_to_indicio_value(entry)))
                .collect(),
        ),
    }
}
/// Handle given to a workflow function for choosing what runs after it returns.
///
/// The private `_phantom` field means only this crate can construct one. Every method
/// consumes the handle and produces a [`ContinuationChoice`].
pub struct Continuation {
    _phantom: std::marker::PhantomData<()>,
}

impl Continuation {
    /// Schedules nothing new: the workflow proceeds with the steps already pending and
    /// halts if none remain.
    pub fn goto(self) -> ContinuationChoice {
        ContinuationChoice {
            steps: Vec::new(),
            halt: false,
        }
    }

    /// Discards every pending step so the workflow halts next.
    pub fn halt(self) -> ContinuationChoice {
        ContinuationChoice {
            steps: Vec::new(),
            halt: true,
        }
    }

    /// Schedules a local call to the registered function named `function`.
    pub fn call(self, function: impl Into<String>) -> ContinuationChoice {
        ContinuationChoice {
            steps: vec![Step::Call {
                function: function.into(),
            }],
            halt: false,
        }
    }

    /// Schedules an Anthropic request through `provider`, then a local call to `next_function`.
    ///
    /// The trampoline suspends at the request. The response is stored under `output_key`
    /// when the workflow is resumed with [`Trampoline::resume_anthropic`].
    pub fn anthropic(
        self,
        provider: impl Into<String>,
        message: MessageCreateParams,
        output_key: impl Into<String>,
        next_function: impl Into<String>,
    ) -> ContinuationChoice {
        ContinuationChoice {
            steps: vec![
                Step::Anthropic {
                    provider: provider.into(),
                    message: Box::new(message),
                    output_key: output_key.into(),
                },
                Step::Call {
                    function: next_function.into(),
                },
            ],
            halt: false,
        }
    }

    /// Schedules a request for human input, then a local call to `next_function`.
    ///
    /// The answer is stored under `output_key` when the workflow is resumed with
    /// [`Trampoline::resume_human`].
    pub fn human(
        self,
        request: HumanRequest,
        output_key: impl Into<String>,
        next_function: impl Into<String>,
    ) -> ContinuationChoice {
        ContinuationChoice {
            steps: vec![
                Step::Human {
                    request,
                    output_key: output_key.into(),
                },
                Step::Call {
                    function: next_function.into(),
                },
            ],
            halt: false,
        }
    }

    /// Schedules execution of `tool_uses`, then a local call to `next_function`.
    ///
    /// The tool results are stored under `output_key` when the workflow is resumed with
    /// [`Trampoline::resume_tool_call`].
    pub fn tool_call(
        self,
        tool_uses: Vec<ToolUseBlock>,
        output_key: impl Into<String>,
        next_function: impl Into<String>,
    ) -> ContinuationChoice {
        ContinuationChoice {
            steps: vec![
                Step::ToolCall {
                    tool_uses,
                    output_key: output_key.into(),
                },
                Step::Call {
                    function: next_function.into(),
                },
            ],
            halt: false,
        }
    }

    /// Schedules two branch workflows (`lhs`, `rhs`) followed by a join call to `function`.
    ///
    /// The branches' environments are merged back in by [`Trampoline::resume_fork_join`].
    pub fn fork_join(
        self,
        lhs: ForkBranch,
        rhs: ForkBranch,
        function: impl Into<String>,
    ) -> ContinuationChoice {
        ContinuationChoice {
            steps: vec![Step::ForkJoin {
                lhs,
                rhs,
                function: function.into(),
            }],
            halt: false,
        }
    }
}
/// One side of a fork/join: the run id the branch workflow gets and the function it starts at.
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct ForkBranch {
    run_id: String,
    function: String,
}

impl ForkBranch {
    /// Creates a branch that runs as `run_id` and begins by calling `function`.
    pub fn new(run_id: impl Into<String>, function: impl Into<String>) -> Self {
        Self {
            run_id: run_id.into(),
            function: function.into(),
        }
    }
}
/// A request for human input: a prompt plus optional context and metadata.
#[derive(Clone, Debug, PartialEq, serde::Serialize, serde::Deserialize)]
pub struct HumanRequest {
    prompt: String,
    context: serde_json::Value,
    metadata: serde_json::Value,
}

impl HumanRequest {
    /// Creates a request with `prompt`, a `null` context and empty-object metadata.
    pub fn new(prompt: impl Into<String>) -> Self {
        let empty_metadata = serde_json::Value::Object(serde_json::Map::new());
        Self {
            prompt: prompt.into(),
            context: serde_json::Value::Null,
            metadata: empty_metadata,
        }
    }

    /// Returns the request with its context replaced by `context`.
    pub fn with_context(self, context: serde_json::Value) -> Self {
        Self { context, ..self }
    }

    /// Returns the request with its metadata replaced by `metadata`.
    pub fn with_metadata(self, metadata: serde_json::Value) -> Self {
        Self { metadata, ..self }
    }

    /// The text shown to the human.
    pub fn prompt(&self) -> &str {
        self.prompt.as_str()
    }

    /// Additional context for the human (`null` unless set).
    pub fn context(&self) -> &serde_json::Value {
        &self.context
    }

    /// Free-form metadata (an empty object unless set).
    pub fn metadata(&self) -> &serde_json::Value {
        &self.metadata
    }
}
#[derive(Clone, Debug, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
pub struct ToolCallId {
run_id: String,
tool_use_id: String,
}
impl ToolCallId {
pub fn new(run_id: impl Into<String>, tool_use_id: impl Into<String>) -> Self {
Self {
run_id: run_id.into(),
tool_use_id: tool_use_id.into(),
}
}
pub fn run_id(&self) -> &str {
&self.run_id
}
pub fn tool_use_id(&self) -> &str {
&self.tool_use_id
}
}
impl Display for ToolCallId {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}:{}", self.run_id, self.tool_use_id)
}
}
/// A client-side tool that can be offered to the model and executed for a tool-use block.
pub trait Tool: Send + Sync {
/// The tool's name; presumably matched against `ToolUseBlock` names — TODO confirm.
fn name(&self) -> String;
/// The tool definition to include in a request (a claudius `ToolUnionParam`).
fn to_param(&self) -> ToolUnionParam;
/// Executes `tool_use`, identified by `id`, and resolves to the result block to send back.
fn call<'a>(
&'a self,
id: ToolCallId,
tool_use: &'a ToolUseBlock,
) -> Pin<Box<dyn Future<Output = ToolResultBlock> + Send + 'a>>;
}
/// Returns clones of every `ToolUse` content block in `response`, in order.
///
/// Other block kinds are ignored. An empty vector means the model requested no
/// client-side tools.
pub fn client_tool_uses(response: &Message) -> Vec<ToolUseBlock> {
    response
        .content
        .iter()
        .filter_map(|block| match block {
            claudius::ContentBlock::ToolUse(tool_use) => Some(tool_use.clone()),
            _ => None,
        })
        .collect()
}
/// Result of [`dispatch_tool_uses`]: either the tool calls to run next, or the unused
/// [`Continuation`] when the response requested no tools.
pub enum ToolDispatch {
/// The response contained tool uses; this choice schedules them followed by the next function.
Tools(ContinuationChoice),
/// The response contained no tool uses; the continuation is handed back unused.
Done(Continuation),
}
/// Schedules the client tool uses found in `response`, if there are any.
///
/// If `response` contains tool uses, returns [`ToolDispatch::Tools`] with a choice that
/// runs them, stores the results under `output_key` and then calls `next_function`.
/// Otherwise returns [`ToolDispatch::Done`] with `continuation` untouched.
pub fn dispatch_tool_uses(
    continuation: Continuation,
    response: &Message,
    output_key: impl Into<String>,
    next_function: impl Into<String>,
) -> ToolDispatch {
    let requested = client_tool_uses(response);
    if !requested.is_empty() {
        return ToolDispatch::Tools(continuation.tool_call(requested, output_key, next_function));
    }
    ToolDispatch::Done(continuation)
}
/// The steps a workflow function chose to run next, produced by a [`Continuation`].
///
/// A choice has no effect until it is applied to a workflow. Building one and dropping it
/// schedules nothing, so the type is `#[must_use]`.
#[must_use = "a ContinuationChoice has no effect unless it is applied to a workflow"]
pub struct ContinuationChoice {
    /// Steps in execution order (`steps[0]` runs first).
    steps: Vec<Step>,
    /// When set, all pending steps are discarded instead of scheduling `steps`.
    halt: bool,
}

impl ContinuationChoice {
    /// Applies this choice to `workflow`'s pending-step stack.
    ///
    /// `Workflow::continuation` is popped from the end by `Workflow::advance`, so the steps
    /// are pushed in reverse to make `steps[0]` the next step popped.
    fn apply_to(self, workflow: &mut Workflow) {
        if self.halt {
            workflow.continuation.clear();
        } else {
            workflow.continuation.extend(self.steps.into_iter().rev());
        }
    }
}
/// One unit of workflow control flow.
///
/// `Workflow::current_step` holds the step being executed. Future steps wait on the
/// `Workflow::continuation` stack.
#[derive(Clone, Debug, Default, serde::Serialize, serde::Deserialize)]
enum Step {
/// Nothing to do; the default, and what `Workflow::advance` sets when the stack is empty.
#[default]
Halt,
/// Suspend for an Anthropic request sent through `provider`; the reply goes under `output_key`.
Anthropic {
provider: String,
message: Box<MessageCreateParams>,
output_key: String,
},
/// Suspend for human input; the answer goes under `output_key`.
Human {
request: HumanRequest,
output_key: String,
},
/// Suspend for client tool execution; the results go under `output_key`.
ToolCall {
tool_uses: Vec<ToolUseBlock>,
output_key: String,
},
/// Suspend for an OpenAI step (not supported by a built-in client).
OpenAI {},
/// Run the registered local function `function`.
Call {
function: String,
},
/// Suspend while two branch workflows run, then call `function` as the join.
ForkJoin {
lhs: ForkBranch,
rhs: ForkBranch,
function: String,
},
}
/// Reference to whatever caused a workflow event.
///
/// Serialized with an internal `type` tag whose values are the snake_case variant names.
#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum CausalRef {
    /// Caused by a workflow run as a whole.
    RunId {
        /// Identifier of the causing run.
        run_id: String,
    },
    /// Caused by an earlier event.
    EventId {
        /// Identifier of the causing event.
        event_id: Uuid,
    },
}

impl CausalRef {
    /// Builds a [`CausalRef::RunId`].
    fn run_id(run_id: impl Into<String>) -> Self {
        let run_id: String = run_id.into();
        CausalRef::RunId { run_id }
    }

    /// Builds a [`CausalRef::EventId`].
    fn event_id(event_id: Uuid) -> Self {
        CausalRef::EventId { event_id }
    }
}
/// Externally supplied observability context for a workflow (see `Workflow::set_observability_context`).
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ObservabilityContext {
/// The cause to attribute to the next automatically caused event.
pub causal_cursor: CausalRef,
}
/// Limits applied when recording workflow observability data.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ObservabilityConfig {
    /// Maximum number of per-key environment changes to report. Presumably extra changes
    /// are counted but truncated — TODO confirm against `summarize_env_changes`.
    pub max_env_changes: usize,
    /// Maximum size, in bytes of serialized JSON, of one event payload. Larger payloads are
    /// rejected with an `event-payload-too-large` error.
    pub max_event_payload_bytes: usize,
}

impl Default for ObservabilityConfig {
    /// 64 environment changes and 32 KiB event payloads.
    fn default() -> Self {
        const DEFAULT_MAX_ENV_CHANGES: usize = 64;
        const DEFAULT_MAX_EVENT_PAYLOAD_BYTES: usize = 32 * 1024;
        Self {
            max_event_payload_bytes: DEFAULT_MAX_EVENT_PAYLOAD_BYTES,
            max_env_changes: DEFAULT_MAX_ENV_CHANGES,
        }
    }
}
/// How an event's `caused_by` was chosen.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum EventCauseMode {
/// Taken from the workflow's causal cursor (`Workflow::record_event`).
Automatic,
/// Supplied by the caller (`Workflow::record_event_caused_by`).
Explicit,
}
/// A workflow event that has been recorded but not yet persisted or emitted.
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct PendingWorkflowEvent {
/// Time-ordered (UUIDv7) identifier assigned when the event is created.
pub event_id: Uuid,
/// The run or event that caused this event.
pub caused_by: CausalRef,
/// Event type name; custom events are validated by `validate_custom_event_type`.
pub event_type: String,
/// Positive schema version of the payload.
pub event_version: i16,
/// Always `None` for custom events; first-party events may set it.
pub continuation_id: Option<String>,
/// The serialized event payload.
pub event: serde_json::Value,
// Not serialized; deserialized events fall back to `EventCauseMode::Automatic`.
#[cfg_attr(not(feature = "batch"), allow(dead_code))]
#[serde(skip, default = "automatic_event_cause_mode")]
cause_mode: EventCauseMode,
}
/// Serde default for `PendingWorkflowEvent::cause_mode`, which is never serialized.
fn automatic_event_cause_mode() -> EventCauseMode {
EventCauseMode::Automatic
}
impl PendingWorkflowEvent {
    /// Builds a user-recorded event after validating its type name.
    ///
    /// Custom events never carry a continuation id.
    fn custom<T: serde::Serialize>(
        event_type: impl Into<String>,
        event_version: i16,
        payload: T,
        caused_by: CausalRef,
        cause_mode: EventCauseMode,
        config: &ObservabilityConfig,
    ) -> Result<Self, handled::SError> {
        let event_type: String = event_type.into();
        validate_custom_event_type(&event_type)?;
        Self::new(event_type, event_version, None, payload, caused_by, cause_mode, config)
    }

    /// Builds a crate-internal event at version 1 with automatic causation.
    ///
    /// Unlike `custom`, the event type is not validated.
    #[cfg(feature = "batch")]
    pub(crate) fn first_party<T: serde::Serialize>(
        event_type: impl Into<String>,
        continuation_id: Option<String>,
        payload: T,
        caused_by: CausalRef,
        config: &ObservabilityConfig,
    ) -> Result<Self, handled::SError> {
        let version: i16 = 1;
        Self::new(
            event_type,
            version,
            continuation_id,
            payload,
            caused_by,
            EventCauseMode::Automatic,
            config,
        )
    }

    /// Shared constructor.
    ///
    /// Rejects non-positive versions. Serializes the payload to JSON and rejects it when its
    /// encoded size exceeds `config.max_event_payload_bytes`. Assigns a fresh UUIDv7 id.
    fn new<T: serde::Serialize>(
        event_type: impl Into<String>,
        event_version: i16,
        continuation_id: Option<String>,
        payload: T,
        caused_by: CausalRef,
        cause_mode: EventCauseMode,
        config: &ObservabilityConfig,
    ) -> Result<Self, handled::SError> {
        let event_type = event_type.into();
        if event_version < 1 {
            let invalid = observability_error(
                "invalid-event-version",
                "workflow event version must be positive",
            );
            return Err(invalid.with_atom_field("event_version", event_version));
        }
        let event = serde_json::to_value(payload).map_err(|source| {
            observability_error(
                "invalid-event-payload",
                "failed to serialize workflow event payload",
            )
            .with_string_field("source", &source.to_string())
        })?;
        // Measure the size the payload will have once encoded as JSON bytes.
        let payload_bytes = serde_json::to_vec(&event)
            .map_err(|source| {
                observability_error(
                    "invalid-event-payload",
                    "failed to measure workflow event payload",
                )
                .with_string_field("source", &source.to_string())
            })?
            .len();
        let max_payload_bytes = config.max_event_payload_bytes;
        if payload_bytes > max_payload_bytes {
            return Err(observability_error(
                "event-payload-too-large",
                "workflow event payload exceeds configured maximum size",
            )
            .with_string_field("event_type", &event_type)
            .with_atom_field("payload_bytes", payload_bytes)
            .with_atom_field("max_payload_bytes", max_payload_bytes));
        }
        Ok(Self {
            event_id: Uuid::now_v7(),
            caused_by,
            event_type,
            event_version,
            continuation_id,
            event,
            cause_mode,
        })
    }

    /// Whether `caused_by` came from the workflow's causal cursor rather than the caller.
    #[cfg(feature = "batch")]
    pub(crate) fn caused_automatically(&self) -> bool {
        matches!(self.cause_mode, EventCauseMode::Automatic)
    }
}
/// JSON shape of an environment value, or `Missing` when there is no value for the key
/// (see `EnvValueSummary::missing`).
#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ValueShape {
/// The key has no value on this side of the change.
Missing,
/// JSON `null`.
Null,
/// JSON boolean.
Bool,
/// JSON number.
Number,
/// JSON string.
String,
/// JSON array.
Array,
/// JSON object.
Object,
}
/// Content-free summary of one environment value, so change reports do not leak data.
#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct EnvValueSummary {
/// JSON shape of the value.
pub shape: ValueShape,
/// Presumably the value's serialized size in bytes; `None` when missing — TODO confirm.
pub bytes: Option<usize>,
/// Presumably a digest of the serialized value; `None` when missing — TODO confirm.
pub digest: Option<String>,
}
impl EnvValueSummary {
/// Summary for a key that has no value.
fn missing() -> Self {
Self {
shape: ValueShape::Missing,
bytes: None,
digest: None,
}
}
}
/// How a single environment key changed across a local call.
#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum EnvChangeKind {
/// The key was absent before and present after.
Added,
/// The key was present before and absent after.
Removed,
/// The key was present on both sides with different values.
Modified,
}
/// Change report for one environment key.
#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct EnvChangeSummary {
/// The environment key.
pub key: String,
/// Kind of change.
pub change: EnvChangeKind,
/// Summary of the value before the call.
pub before: EnvValueSummary,
/// Summary of the value after the call.
pub after: EnvValueSummary,
}
/// Aggregate change report for a workflow environment across one local call.
#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct EnvChangeSetSummary {
/// Number of keys that changed (presumably counted even when `changes` is truncated — TODO confirm).
pub changed_key_count: usize,
/// Presumably set when more changes existed than `ObservabilityConfig::max_env_changes` — TODO confirm.
pub changes_truncated: bool,
/// Digest of the whole environment before the call.
pub env_before_digest: String,
/// Digest of the whole environment after the call.
pub env_after_digest: String,
/// Number of keys before the call.
pub env_before_key_count: usize,
/// Number of keys after the call.
pub env_after_key_count: usize,
/// Per-key change reports.
pub changes: Vec<EnvChangeSummary>,
}
/// Payload-free, serializable description of a workflow step (produced by `Step::summary`).
#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[serde(tag = "kind", rename_all = "snake_case")]
pub enum StepSummary {
/// Nothing left to run.
Halt,
/// A local function call.
Call {
/// Name of the function.
function: String,
},
/// A suspended Anthropic request.
Anthropic {
/// Provider name the request is sent through.
provider: String,
/// Environment key the response is stored under.
output_key: String,
},
/// A suspended human-input request.
Human {
/// Environment key the answer is stored under.
output_key: String,
},
/// A suspended client tool call.
ToolCall {
/// Names of the requested tools.
tool_names: Vec<String>,
/// Environment key the results are stored under.
output_key: String,
},
/// A suspended OpenAI step.
OpenAI,
/// A suspended fork/join.
ForkJoin {
/// Branch run ids, presumably keyed by branch side (`lhs`/`rhs`) — TODO confirm.
branch_run_id: BTreeMap<String, String>,
/// Function called after both branches finish.
join_function: String,
},
}
/// Control-flow position of a workflow before and after one local call.
#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct FlowSummary {
/// Current step before the call.
pub current_step_before: StepSummary,
/// Current step after the call (and after advancing, on success).
pub current_step_after: StepSummary,
/// Number of pending continuation steps before the call.
pub continuation_depth_before: usize,
/// Number of pending continuation steps after the call.
pub continuation_depth_after: usize,
}
/// What a workflow will do next, as reported by `Trampoline::next_action`.
#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[serde(tag = "kind", rename_all = "snake_case")]
pub enum WorkflowNext {
/// The workflow is finished.
Halt,
/// A registered local function will run.
LocalCall {
/// Name of the function.
function: String,
},
/// The workflow waits for an Anthropic response.
Anthropic {
/// Provider name the request is sent through.
provider: String,
/// Environment key the response is stored under.
output_key: String,
},
/// The workflow waits for human input.
Human {
/// Environment key the answer is stored under.
output_key: String,
},
/// The workflow waits for client tool results.
ToolCall {
/// Names of the requested tools.
tool_names: Vec<String>,
/// Environment key the results are stored under.
output_key: String,
},
/// The workflow waits for an OpenAI step.
OpenAI,
/// The workflow waits for two branches to finish.
ForkJoin {
/// Branch run ids, presumably keyed by branch side (`lhs`/`rhs`) — TODO confirm.
branch_run_id: BTreeMap<String, String>,
/// Function called after both branches finish.
join_function: String,
},
}
/// Successful result of `Trampoline::run`.
#[derive(Clone, Debug)]
pub struct WorkflowOutcome {
/// Where the run stopped.
pub result: WorkflowResult,
/// Every event recorded during the run, in order.
pub events: Vec<PendingWorkflowEvent>,
}
/// Successful result of `Trampoline::run_one_local_call`.
#[derive(Clone, Debug)]
pub struct WorkflowStepOutcome {
/// The workflow after the call, already advanced to its next step.
pub workflow: Workflow,
/// Name of the function that ran.
pub function: String,
/// Environment changes made by the call.
pub env_changes: EnvChangeSetSummary,
/// Control-flow position before and after the call.
pub flow: FlowSummary,
/// Events drained from the workflow after the call.
pub events: Vec<PendingWorkflowEvent>,
/// Wall-clock milliseconds from just before the call until it finished and the workflow advanced.
pub duration_ms: u128,
}
/// Failure of a local workflow call, with the observability data captured up to the failure.
#[derive(Debug)]
pub struct WorkflowError {
    /// Workflow state at the failure; it has not been advanced past the failed step.
    pub workflow: Workflow,
    /// Function being executed, or `None` when the workflow was not at a local call.
    pub function: Option<String>,
    /// Environment changes made before the failure.
    pub env_changes: EnvChangeSetSummary,
    /// Control-flow position before and after the failed attempt.
    pub flow: FlowSummary,
    /// Events recorded before the failure.
    pub events: Vec<PendingWorkflowEvent>,
    /// The underlying error.
    pub source: handled::SError,
    /// Time spent, or `None` when the workflow was not at a local call.
    pub duration_ms: Option<u128>,
}

impl Display for WorkflowError {
    /// Displays only the underlying error, forwarding the formatter's flags.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let inner: &handled::SError = &self.source;
        Display::fmt(inner, f)
    }
}

impl std::error::Error for WorkflowError {}

impl From<WorkflowError> for handled::SError {
    /// Returns the underlying error annotated with `run_id` and `pending_event_count`.
    /// All other fields are discarded.
    fn from(error: WorkflowError) -> Self {
        let WorkflowError {
            workflow,
            events,
            source,
            ..
        } = error;
        source
            .with_string_field("run_id", workflow.run_id())
            .with_atom_field("pending_event_count", events.len())
    }
}
/// Per-workflow observability state; never serialized (see the `#[serde(skip)]` on `Workflow`).
#[derive(Clone, Debug)]
struct WorkflowObservabilityState {
/// Cause attributed to the next automatically caused event.
causal_cursor: CausalRef,
/// Events recorded but not yet drained.
pending_events: Vec<PendingWorkflowEvent>,
/// Limits applied when recording events.
config: ObservabilityConfig,
}
impl Default for WorkflowObservabilityState {
fn default() -> Self {
Self {
// Placeholder: an empty run id is replaced with the workflow's real run id by
// `Workflow::current_causal_cursor` (this default is what deserialization installs).
causal_cursor: CausalRef::run_id(""),
pending_events: Vec::new(),
config: ObservabilityConfig::default(),
}
}
}
#[derive(Clone, Debug, Default, serde::Serialize, serde::Deserialize)]
pub struct Workflow {
run_id: String,
env: HashMap<String, serde_json::Value>,
current_step: Step,
continuation: Vec<Step>,
#[serde(skip, default)]
observability: WorkflowObservabilityState,
}
impl Workflow {
pub fn new(run_id: impl Into<String>, call: impl Into<String>) -> Self {
let run_id = run_id.into();
let call = call.into();
Self {
run_id: run_id.clone(),
env: HashMap::default(),
current_step: Step::Call {
function: call.clone(),
},
continuation: Vec::new(),
observability: WorkflowObservabilityState {
causal_cursor: CausalRef::run_id(&run_id),
pending_events: Vec::new(),
config: ObservabilityConfig::default(),
},
}
}
pub fn run_id(&self) -> &str {
&self.run_id
}
pub fn set_observability_context(&mut self, context: ObservabilityContext) {
self.observability.causal_cursor = context.causal_cursor;
}
fn set_observability_config(&mut self, config: ObservabilityConfig) {
self.observability.config = config;
}
pub fn record_event<T: serde::Serialize>(
&mut self,
event_type: impl Into<String>,
payload: T,
) -> Result<Uuid, handled::SError> {
let caused_by = self.current_causal_cursor();
self.record_event_inner(event_type, 1, payload, caused_by, EventCauseMode::Automatic)
}
pub fn record_event_caused_by<T: serde::Serialize>(
&mut self,
event_type: impl Into<String>,
payload: T,
caused_by: CausalRef,
) -> Result<Uuid, handled::SError> {
self.record_event_inner(event_type, 1, payload, caused_by, EventCauseMode::Explicit)
}
fn record_event_inner<T: serde::Serialize>(
&mut self,
event_type: impl Into<String>,
event_version: i16,
payload: T,
caused_by: CausalRef,
cause_mode: EventCauseMode,
) -> Result<Uuid, handled::SError> {
let event = PendingWorkflowEvent::custom(
event_type,
event_version,
payload,
caused_by,
cause_mode,
&self.observability.config,
)?;
let event_id = event.event_id;
self.observability.causal_cursor = CausalRef::event_id(event_id);
self.observability.pending_events.push(event);
Ok(event_id)
}
fn current_causal_cursor(&self) -> CausalRef {
match &self.observability.causal_cursor {
CausalRef::RunId { run_id } if run_id.is_empty() => CausalRef::run_id(&self.run_id),
other => other.clone(),
}
}
pub(crate) fn drain_pending_events(&mut self) -> Vec<PendingWorkflowEvent> {
std::mem::take(&mut self.observability.pending_events)
}
pub fn into_env<T: serde::Serialize>(
&mut self,
key: impl Into<String>,
value: T,
) -> Result<(), serde_json::Error> {
(|| {
let key = key.into();
let value = serde_json::to_value(&value)?;
self.env.insert(key.clone(), value);
Ok(())
})()
}
pub fn from_env<T: for<'de> serde::Deserialize<'de>>(
&self,
key: impl AsRef<str>,
) -> Result<Option<T>, serde_json::Error> {
(|| {
let Some(value) = self.env.get(key.as_ref()) else {
return Ok(None);
};
serde_json::from_value(value.clone()).map(Some)
})()
}
fn advance(&mut self) -> bool {
if let Some(next_step) = self.continuation.pop() {
self.current_step = next_step;
true
} else {
self.current_step = Step::Halt;
false
}
}
fn prepare_next_step(&mut self) {
while matches!(self.current_step, Step::Halt) && !self.continuation.is_empty() {
self.advance();
}
}
fn fork_branch(&self, branch: ForkBranch) -> Self {
let run_id = branch.run_id;
Self {
run_id: run_id.clone(),
env: self.env.clone(),
current_step: Step::Call {
function: branch.function,
},
continuation: Vec::new(),
observability: WorkflowObservabilityState {
causal_cursor: CausalRef::run_id(&run_id),
pending_events: Vec::new(),
config: self.observability.config.clone(),
},
}
}
#[cfg(test)]
fn schedule(&mut self, next_step: Step) {
self.continuation.push(next_step);
}
}
/// Where `Trampoline::run` stopped.
///
/// Each suspended variant carries the workflow still positioned at the suspending step, to be
/// passed to the matching `Trampoline::resume_*` method once the external work is done.
#[derive(Clone, Debug)]
pub enum WorkflowResult {
/// No steps remain.
Halt {
/// The finished workflow.
workflow: Workflow,
},
/// Suspended at an Anthropic request.
Anthropic {
/// The suspended workflow.
workflow: Workflow,
/// Provider name the request should be sent through.
provider: String,
/// The request to send.
message: Box<MessageCreateParams>,
/// Key to pass to `Trampoline::resume_anthropic`.
output_key: String,
},
/// Suspended waiting for human input.
Human {
/// The suspended workflow.
workflow: Workflow,
/// What to ask the human.
request: HumanRequest,
/// Key to pass to `Trampoline::resume_human`.
output_key: String,
},
/// Suspended waiting for client tool results.
ToolCall {
/// The suspended workflow.
workflow: Workflow,
/// Tool uses to execute.
tool_uses: Vec<ToolUseBlock>,
/// Key to pass to `Trampoline::resume_tool_call`.
output_key: String,
},
/// Suspended at an OpenAI step.
OpenAI {
/// The suspended workflow.
workflow: Workflow,
},
/// Suspended at a fork/join; run both branches, then call `Trampoline::resume_fork_join`.
ForkJoin {
/// The suspended parent workflow.
workflow: Workflow,
/// Left branch workflow, ready to run.
lhs: Box<Workflow>,
/// Right branch workflow, ready to run.
rhs: Box<Workflow>,
},
}
/// Boxed future returned by a workflow function; it borrows the workflow for `'a` and is not
/// required to be `Send`.
pub type CallFuture<'a> = Pin<Box<dyn Future<Output = Result<(), handled::SError>> + 'a>>;
/// Signature of a registered workflow function: takes `&mut Workflow`, returns a [`CallFuture`].
pub type Call = dyn for<'a> Fn(&'a mut Workflow) -> CallFuture<'a> + 'static;
/// Test helper: adapts a synchronous workflow function into a [`Call`]-compatible closure
/// whose future resolves immediately.
// `F: Copy` lets each invocation move its own copy of `f` into the `async move` block
// while the outer closure stays `Fn`.
#[cfg(test)]
pub(crate) fn test_sync_call<F>(
f: F,
) -> impl for<'a> Fn(&'a mut Workflow) -> CallFuture<'a> + 'static
where
F: Fn(&mut Workflow) -> Result<(), handled::SError> + Copy + 'static,
{
move |workflow| -> CallFuture<'_> { Box::pin(async move { f(workflow) }) }
}
/// Test helper: runs `workflow` to completion or suspension on a fresh Tokio runtime and
/// returns only the result, boxing any error.
#[cfg(test)]
pub(crate) fn test_run_trampoline(
    trampoline: &Trampoline,
    workflow: Workflow,
) -> Result<WorkflowResult, Box<WorkflowError>> {
    let runtime = tokio::runtime::Runtime::new().expect("test runtime");
    match runtime.block_on(trampoline.run(workflow)) {
        Ok(outcome) => Ok(outcome.result),
        Err(error) => Err(Box::new(error)),
    }
}
/// Registry of Anthropic clients keyed by provider name.
#[derive(Default)]
pub struct Clients {
    clients: HashMap<String, Anthropic>,
}

impl Clients {
    /// Creates an empty registry.
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers `client` under `provider`, replacing any client already registered there.
    pub fn register_anthropic(&mut self, provider: impl Into<String>, client: Anthropic) {
        self.clients.insert(provider.into(), client);
    }

    /// Sends `params` through the client registered as `provider` and returns the full message.
    ///
    /// The response is streamed, and its text is echoed to stdout as it arrives.
    ///
    /// # Errors
    ///
    /// Returns `missing-anthropic-provider` when no client is registered under `provider`.
    /// Otherwise returns any request or streaming error.
    pub async fn send(
        &self,
        provider: &str,
        params: MessageCreateParams,
    ) -> Result<Message, handled::SError> {
        let client = self
            .clients
            .get(provider)
            .ok_or_else(|| missing_anthropic_provider_error(provider))?;
        stream_anthropic_message(client, params).await
    }
}
/// Sends `params` as a streaming request and returns the accumulated final message.
///
/// Text from each stream event is printed to stdout as it arrives. The first stream error
/// aborts the request.
async fn stream_anthropic_message(
    client: &Anthropic,
    params: MessageCreateParams,
) -> Result<Message, handled::SError> {
    let params = params.with_stream(true);
    // The original line read `client.stream(¶ms)`, which is a mis-encoded `&params`.
    let stream = client.stream(&params).await.map_err(anthropic_error)?;
    let (mut stream, message_rx) = AccumulatingStream::new(stream);
    while let Some(event) = stream.next().await {
        print_llm_stream_event_text(event.map_err(anthropic_error)?);
    }
    message_rx
        .await
        .map_err(|err| stream_accumulation_channel_error(&err.to_string()))?
        .map_err(anthropic_error)
}
/// Echoes the text carried by one streaming event to stdout.
///
/// Non-empty text blocks at content-block start and every text delta are printed. All other
/// events are ignored.
fn print_llm_stream_event_text(event: MessageStreamEvent) {
    match event {
        MessageStreamEvent::ContentBlockStart(start) => match start.content_block {
            claudius::ContentBlock::Text(block) if !block.text.is_empty() => {
                print_llm_output_chunk(&block.text);
            }
            _ => {}
        },
        MessageStreamEvent::ContentBlockDelta(delta) => match delta.delta {
            ContentBlockDelta::TextDelta(text_delta) => print_llm_output_chunk(&text_delta.text),
            _ => {}
        },
        _ => {}
    }
}
fn stream_accumulation_channel_error(source: &str) -> handled::SError {
handled::SError::new("langcontinuation")
.with_code("anthropic-stream-accumulation")
.with_message("failed to receive accumulated Anthropic streaming message")
.with_string_field("source", source)
}
fn unsupported_openai_error() -> handled::SError {
handled::SError::new("langcontinuation")
.with_code("unsupported-openai-provider")
.with_message("OpenAI workflow steps are not supported")
.with_string_field("next_action", "ask for a proper Rust OpenAI client")
}
fn human_input_required_error(request: &HumanRequest, output_key: &str) -> handled::SError {
handled::SError::new("langcontinuation")
.with_code("human-input-required")
.with_message("human input is required to continue the workflow")
.with_string_field("output_key", output_key)
.with_string_field("prompt", request.prompt())
.with_string_field("next_action", "resume the workflow with human input")
}
fn missing_anthropic_provider_error(provider: &str) -> handled::SError {
handled::SError::new("langcontinuation")
.with_code("missing-anthropic-provider")
.with_message("Anthropic provider is not registered")
.with_string_field("provider", provider)
}
/// Converts a claudius error into a `handled::SError`.
///
/// The result carries a classified code, a `retryable` flag and the source message, plus
/// the HTTP status code and request id when the error has them.
fn anthropic_error(err: AnthropicError) -> handled::SError {
    let base = handled::SError::new("langcontinuation")
        .with_code(anthropic_error_code(&err))
        .with_message("Anthropic request failed")
        .with_atom_field("retryable", err.is_retryable())
        .with_string_field("source", &err.to_string());
    let with_status = match err.status_code() {
        Some(status_code) => base.with_atom_field("status_code", status_code),
        None => base,
    };
    match err.request_id() {
        Some(request_id) => with_status.with_string_field("request_id", request_id),
        None => with_status,
    }
}
/// Maps a claudius error to a stable error code.
///
/// Classifiers are checked in priority order and the first match wins.
fn anthropic_error_code(err: &AnthropicError) -> &'static str {
    match err {
        e if e.is_authentication() => "anthropic-authentication",
        e if e.is_permission() => "anthropic-permission",
        e if e.is_not_found() => "anthropic-not-found",
        e if e.is_rate_limit() => "anthropic-rate-limit",
        e if e.is_bad_request() => "anthropic-bad-request",
        e if e.is_timeout() => "anthropic-timeout",
        e if e.is_abort() => "anthropic-abort",
        e if e.is_connection() => "anthropic-connection",
        e if e.is_server_error() => "anthropic-server-error",
        e if e.is_validation() => "anthropic-validation",
        e if e.is_todo() => "anthropic-unimplemented",
        _ => "anthropic-request-failed",
    }
}
/// Registry of workflow functions and tools, and the driver that executes workflows.
#[derive(Default)]
pub struct Trampoline {
/// Workflow functions keyed by the name used in `Step::Call`.
fns: HashMap<String, Box<Call>>,
/// Tools keyed by name; not used by the code visible here — presumably used when executing tool calls.
tools: HashMap<String, std::sync::Arc<dyn Tool>>,
/// Observability limits installed on every workflow this trampoline runs.
observability_config: ObservabilityConfig,
}
impl Trampoline {
/// Sets the observability limits installed on workflows by later runs.
pub fn set_observability_config(&mut self, config: ObservabilityConfig) {
self.observability_config = config;
}
/// The observability limits this trampoline installs on workflows.
#[cfg(feature = "batch")]
pub(crate) fn observability_config(&self) -> &ObservabilityConfig {
&self.observability_config
}
/// Drives `workflow` forward, executing registered local calls until it halts or suspends
/// at a step the trampoline cannot run itself (Anthropic, human, tool call, OpenAI,
/// fork/join).
///
/// The trampoline's [`ObservabilityConfig`] is installed on the workflow first. Every event
/// recorded along the way is returned in order, either in the [`WorkflowOutcome`] or, when a
/// local call fails, in [`WorkflowError::events`].
///
/// # Errors
///
/// Returns the [`WorkflowError`] of the first local call that fails, including a call to an
/// unregistered function.
pub async fn run(&self, mut workflow: Workflow) -> Result<WorkflowOutcome, WorkflowError> {
    let mut events = Vec::new();
    workflow.set_observability_config(self.observability_config.clone());
    loop {
        workflow.prepare_next_step();
        // Collect events pending before this step. A local call drains its own events
        // afterwards, so the overall ordering is preserved.
        events.extend(workflow.drain_pending_events());
        let result = match workflow.current_step.clone() {
            Step::Call { .. } => match self.run_one_local_call(workflow).await {
                Ok(outcome) => {
                    events.extend(outcome.events);
                    workflow = outcome.workflow;
                    continue;
                }
                Err(mut err) => {
                    events.append(&mut err.events);
                    err.events = events;
                    return Err(err);
                }
            },
            Step::Halt => WorkflowResult::Halt { workflow },
            Step::Anthropic {
                provider,
                message,
                output_key,
            } => WorkflowResult::Anthropic {
                workflow,
                provider,
                message,
                output_key,
            },
            Step::Human {
                request,
                output_key,
            } => WorkflowResult::Human {
                workflow,
                request,
                output_key,
            },
            Step::ToolCall {
                tool_uses,
                output_key,
            } => WorkflowResult::ToolCall {
                workflow,
                tool_uses,
                output_key,
            },
            Step::OpenAI {} => WorkflowResult::OpenAI { workflow },
            Step::ForkJoin {
                lhs,
                rhs,
                function: _,
            } => {
                let lhs = Box::new(workflow.fork_branch(lhs));
                let rhs = Box::new(workflow.fork_branch(rhs));
                WorkflowResult::ForkJoin { workflow, lhs, rhs }
            }
        };
        return Ok(WorkflowOutcome { result, events });
    }
}
/// Reports what `workflow` will do next, without running anything.
pub fn next_action(&self, workflow: &Workflow) -> WorkflowNext {
effective_step(workflow).next()
}
pub async fn run_one_local_call(
&self,
mut workflow: Workflow,
) -> Result<WorkflowStepOutcome, WorkflowError> {
workflow.set_observability_config(self.observability_config.clone());
workflow.prepare_next_step();
let before_env = workflow.env.clone();
let before_step = workflow.current_step.summary();
let before_depth = workflow.continuation.len();
let function = match workflow.current_step.clone() {
Step::Call { function } => function,
other => {
let flow = FlowSummary {
current_step_before: before_step,
current_step_after: other.summary(),
continuation_depth_before: before_depth,
continuation_depth_after: workflow.continuation.len(),
};
let env_changes =
summarize_env_changes(&before_env, &workflow.env, &self.observability_config)
.unwrap_or_else(empty_env_change_summary);
let events = workflow.drain_pending_events();
return Err(WorkflowError {
workflow,
function: None,
env_changes,
flow,
events,
source: observability_error(
"not-local-call",
"attempted to execute one local call when the workflow is not at a local call",
),
duration_ms: None,
});
}
};
let Some(implementation) = self.fns.get(&function) else {
let flow = FlowSummary {
current_step_before: before_step,
current_step_after: workflow.current_step.summary(),
continuation_depth_before: before_depth,
continuation_depth_after: workflow.continuation.len(),
};
let env_changes =
summarize_env_changes(&before_env, &workflow.env, &self.observability_config)
.unwrap_or_else(empty_env_change_summary);
let events = workflow.drain_pending_events();
return Err(WorkflowError {
workflow,
function: Some(function.clone()),
env_changes,
flow,
events,
source: missing_function_error(&function),
duration_ms: Some(0),
});
};
let started = Instant::now();
match implementation(&mut workflow).await {
Ok(()) => {
workflow.advance();
let duration_ms = started.elapsed().as_millis();
let env_changes =
summarize_env_changes(&before_env, &workflow.env, &self.observability_config)
.unwrap_or_else(empty_env_change_summary);
let flow = FlowSummary {
current_step_before: before_step,
current_step_after: workflow.current_step.summary(),
continuation_depth_before: before_depth,
continuation_depth_after: workflow.continuation.len(),
};
let events = workflow.drain_pending_events();
Ok(WorkflowStepOutcome {
workflow,
function,
env_changes,
flow,
events,
duration_ms,
})
}
Err(source) => {
let duration_ms = started.elapsed().as_millis();
let env_changes =
summarize_env_changes(&before_env, &workflow.env, &self.observability_config)
.unwrap_or_else(empty_env_change_summary);
let flow = FlowSummary {
current_step_before: before_step,
current_step_after: workflow.current_step.summary(),
continuation_depth_before: before_depth,
continuation_depth_after: workflow.continuation.len(),
};
let events = workflow.drain_pending_events();
Err(WorkflowError {
workflow,
function: Some(function),
env_changes,
flow,
events,
source,
duration_ms: Some(duration_ms),
})
}
}
}
/// Resumes a workflow suspended at an Anthropic step.
///
/// `message` is stored as JSON under `output_key`, and the workflow advances to its next step.
///
/// # Errors
///
/// - `anthropic-output-key-mismatch` when `output_key` differs from the suspended step's key.
/// - `not-suspended-at-anthropic` when the workflow is not at an Anthropic step.
/// - `invalid-anthropic-response` when `message` cannot be serialized.
pub fn resume_anthropic(
    &self,
    mut workflow: Workflow,
    output_key: impl Into<String>,
    message: Message,
) -> Result<Workflow, handled::SError> {
    let output_key = output_key.into();
    match &workflow.current_step {
        Step::Anthropic {
            output_key: suspended_output_key,
            ..
        } if suspended_output_key == &output_key => {}
        Step::Anthropic {
            output_key: suspended_output_key,
            ..
        } => {
            return Err(resume_error(
                "anthropic-output-key-mismatch",
                "attempted to resume an Anthropic step with the wrong output key",
                Some(("expected", suspended_output_key)),
                Some(("actual", &output_key)),
            ));
        }
        _ => {
            return Err(resume_error(
                "not-suspended-at-anthropic",
                "attempted to resume Anthropic output on a workflow that is not suspended at an Anthropic step",
                None,
                Some(("output_key", &output_key)),
            ));
        }
    }
    let value = serde_json::to_value(message).map_err(|err| {
        handled::SError::new("langcontinuation")
            .with_code("invalid-anthropic-response")
            .with_message("failed to serialize Anthropic response into workflow environment")
            .with_string_field("output_key", &output_key)
            .with_string_field("source", &err.to_string())
    })?;
    workflow.env.insert(output_key, value);
    workflow.advance();
    Ok(workflow)
}
/// Resumes a workflow suspended at a human-input step.
///
/// `value` is stored under `output_key`, and the workflow advances to its next step.
///
/// # Errors
///
/// - `human-output-key-mismatch` when `output_key` differs from the suspended step's key.
/// - `not-suspended-at-human` when the workflow is not at a human step.
/// - `invalid-human-response` when `value` cannot be serialized.
pub fn resume_human<T: serde::Serialize>(
    &self,
    mut workflow: Workflow,
    output_key: impl Into<String>,
    value: T,
) -> Result<Workflow, handled::SError> {
    let output_key = output_key.into();
    match &workflow.current_step {
        Step::Human {
            output_key: suspended_output_key,
            ..
        } if suspended_output_key == &output_key => {}
        Step::Human {
            output_key: suspended_output_key,
            ..
        } => {
            return Err(resume_error(
                "human-output-key-mismatch",
                "attempted to resume a human step with the wrong output key",
                Some(("expected", suspended_output_key)),
                Some(("actual", &output_key)),
            ));
        }
        _ => {
            return Err(resume_error(
                "not-suspended-at-human",
                "attempted to resume human output on a workflow that is not suspended at a human step",
                None,
                Some(("output_key", &output_key)),
            ));
        }
    }
    insert_resume_value(
        &mut workflow,
        output_key,
        value,
        "invalid-human-response",
        "failed to serialize human response into workflow environment",
    )?;
    workflow.advance();
    Ok(workflow)
}
/// Resumes a workflow suspended at a tool-call step.
///
/// `results` is stored under `output_key`, and the workflow advances to its next step.
///
/// # Errors
///
/// - `tool-call-output-key-mismatch` when `output_key` differs from the suspended step's key.
/// - `not-suspended-at-tool-call` when the workflow is not at a tool-call step.
/// - `invalid-tool-results` when `results` cannot be serialized.
pub fn resume_tool_call(
    &self,
    mut workflow: Workflow,
    output_key: impl Into<String>,
    results: Vec<ToolResultBlock>,
) -> Result<Workflow, handled::SError> {
    let output_key = output_key.into();
    match &workflow.current_step {
        Step::ToolCall {
            output_key: suspended_output_key,
            ..
        } if suspended_output_key == &output_key => {}
        Step::ToolCall {
            output_key: suspended_output_key,
            ..
        } => {
            return Err(resume_error(
                "tool-call-output-key-mismatch",
                "attempted to resume a tool-call step with the wrong output key",
                Some(("expected", suspended_output_key)),
                Some(("actual", &output_key)),
            ));
        }
        _ => {
            return Err(resume_error(
                "not-suspended-at-tool-call",
                "attempted to resume tool results on a workflow that is not suspended at a tool-call step",
                None,
                Some(("output_key", &output_key)),
            ));
        }
    }
    insert_resume_value(
        &mut workflow,
        output_key,
        results,
        "invalid-tool-results",
        "failed to serialize tool results into workflow environment",
    )?;
    workflow.advance();
    Ok(workflow)
}
/// Joins two halted fork branches back into the parent workflow.
///
/// The parent must be suspended at a `Step::ForkJoin` step and both branches must be at
/// `Step::Halt`. The branch environments are merged three-way against the parent's
/// environment (see `merge_fork_join_env`). The parent is then positioned at a
/// `Step::Call` for the fork/join's join function.
///
/// # Errors
///
/// - `not-suspended-at-fork-join` when the parent is not at a fork/join step.
/// - `fork-join-branch-not-halted` when either branch has not halted; `lhs` is checked
///   first.
/// - `fork-join-env-conflict` when both branches wrote different values to one key.
pub fn resume_fork_join(
    &self,
    mut workflow: Workflow,
    lhs: Workflow,
    rhs: Workflow,
) -> Result<Workflow, handled::SError> {
    let Step::ForkJoin {
        function: join, ..
    } = &workflow.current_step
    else {
        return Err(fork_join_resume_error(&workflow.current_step));
    };
    let join = join.clone();
    for (name, branch) in [("lhs", &lhs), ("rhs", &rhs)] {
        require_halted_branch(name, branch)?;
    }
    let merged = merge_fork_join_env(&workflow.env, &lhs.env, &rhs.env)?;
    workflow.env = merged;
    workflow.current_step = Step::Call { function: join };
    Ok(workflow)
}
/// Resumes a workflow that is suspended at a `Step::OpenAI` step.
///
/// `value` is serialized to JSON and stored in the workflow environment under
/// `output_key`, after which `Workflow::advance` is called and the workflow is returned.
/// The OpenAI step carries no output key of its own, so any `output_key` is accepted.
///
/// # Errors
///
/// - `not-suspended-at-openai` when the current step is not an OpenAI step.
/// - `invalid-openai-response` when `value` cannot be serialized to JSON.
pub fn resume_open_ai<T: serde::Serialize>(
    &self,
    mut workflow: Workflow,
    output_key: impl Into<String>,
    value: T,
) -> Result<Workflow, handled::SError> {
    let output_key: String = output_key.into();
    if !matches!(workflow.current_step, Step::OpenAI {}) {
        return Err(resume_error(
            "not-suspended-at-openai",
            "attempted to resume OpenAI output on a workflow that is not suspended at an OpenAI step",
            None,
            Some(("output_key", &output_key)),
        ));
    }
    insert_resume_value(
        &mut workflow,
        output_key,
        value,
        "invalid-openai-response",
        "failed to serialize OpenAI response into workflow environment",
    )?;
    workflow.advance();
    Ok(workflow)
}
/// Registers `implementation` as the local function named `function`.
///
/// A later registration under the same name replaces the earlier one.
pub fn register(
    &mut self,
    function: impl Into<String>,
    implementation: impl for<'a> Fn(&'a mut Workflow) -> CallFuture<'a> + 'static,
) {
    self.fns
        .insert(function.into(), Box::new(implementation) as _);
}
/// Registers `tool` under the name reported by `Tool::name`.
///
/// A tool with the same name replaces any earlier registration.
pub fn register_tool(&mut self, tool: impl Tool + 'static) {
    let key = tool.name();
    let shared = std::sync::Arc::new(tool);
    self.tools.insert(key, shared);
}
/// Looks up a registered tool by name, returning a shared handle if one exists.
pub fn tool(&self, name: &str) -> Option<std::sync::Arc<dyn Tool>> {
    self.tools.get(name).map(std::sync::Arc::clone)
}
/// Executes the given tool uses, in order, with their registered tools.
///
/// Each call receives a `ToolCallId` derived from `run_id` and the tool-use id. Results
/// are returned in the same order as `tool_uses`.
///
/// Every tool is resolved before any of them runs. A batch that names an unregistered
/// tool therefore fails without executing (and then discarding) the side effects of the
/// tools that precede it.
///
/// # Errors
///
/// `missing-tool` for the first tool use whose name is not registered.
pub async fn run_tool_calls(
    &self,
    run_id: &str,
    tool_uses: &[ToolUseBlock],
) -> Result<Vec<ToolResultBlock>, handled::SError> {
    let resolved = tool_uses
        .iter()
        .map(|tool_use| {
            self.tool(&tool_use.name)
                .map(|tool| (tool, tool_use))
                .ok_or_else(|| missing_tool_error(&tool_use.name))
        })
        .collect::<Result<Vec<_>, _>>()?;
    let mut results = Vec::with_capacity(resolved.len());
    // Calls run sequentially, so later tools can observe earlier tools' side effects.
    for (tool, tool_use) in resolved {
        let id = ToolCallId::new(run_id, &tool_use.id);
        results.push(tool.call(id, tool_use).await);
    }
    Ok(results)
}
}
impl Step {
    /// Builds a `StepSummary` for this step.
    ///
    /// The summary keeps only identifiers: providers, output keys, function names, tool
    /// names, and fork branch run ids. Request payloads are dropped.
    fn summary(&self) -> StepSummary {
        match self {
            Step::Halt => StepSummary::Halt,
            Step::Call { function } => StepSummary::Call {
                function: function.clone(),
            },
            Step::Anthropic {
                provider,
                output_key,
                ..
            } => StepSummary::Anthropic {
                provider: provider.clone(),
                output_key: output_key.clone(),
            },
            Step::Human { output_key, .. } => StepSummary::Human {
                output_key: output_key.clone(),
            },
            Step::ToolCall {
                tool_uses,
                output_key,
            } => {
                let tool_names = tool_uses.iter().map(|call| call.name.clone()).collect();
                StepSummary::ToolCall {
                    tool_names,
                    output_key: output_key.clone(),
                }
            }
            Step::OpenAI {} => StepSummary::OpenAI,
            Step::ForkJoin { lhs, rhs, function } => StepSummary::ForkJoin {
                branch_run_id: BTreeMap::from([
                    ("lhs".to_string(), lhs.run_id.clone()),
                    ("rhs".to_string(), rhs.run_id.clone()),
                ]),
                join_function: function.clone(),
            },
        }
    }

    /// Maps this step onto the `WorkflowNext` variant reported to callers.
    ///
    /// It goes through `summary`; a `Step::Call` is reported as `WorkflowNext::LocalCall`.
    fn next(&self) -> WorkflowNext {
        match self.summary() {
            StepSummary::Halt => WorkflowNext::Halt,
            StepSummary::Call { function } => WorkflowNext::LocalCall { function },
            StepSummary::Anthropic {
                provider,
                output_key,
            } => WorkflowNext::Anthropic {
                provider,
                output_key,
            },
            StepSummary::Human { output_key } => WorkflowNext::Human { output_key },
            StepSummary::ToolCall {
                tool_names,
                output_key,
            } => WorkflowNext::ToolCall {
                tool_names,
                output_key,
            },
            StepSummary::OpenAI => WorkflowNext::OpenAI,
            StepSummary::ForkJoin {
                branch_run_id,
                join_function,
            } => WorkflowNext::ForkJoin {
                branch_run_id,
                join_function,
            },
        }
    }
}
/// Checks that a user-supplied workflow event type is well-formed and not reserved.
///
/// A valid type is non-empty, contains at least one `.`, and does not start with any of
/// the first-party prefixes listed below.
///
/// # Errors
///
/// - `invalid-custom-event-type` when the type is empty or has no dot.
/// - `reserved-custom-event-type` when the type starts with a reserved prefix.
fn validate_custom_event_type(event_type: &str) -> Result<(), handled::SError> {
    const RESERVED_PREFIXES: [&str; 9] = [
        "workflow.",
        "local_call.",
        "continuation.",
        "anthropic.",
        "openai.",
        "human.",
        "tool.",
        "tool_call.",
        "fork_join.",
    ];
    let (code, message) = if event_type.is_empty() || !event_type.contains('.') {
        (
            "invalid-custom-event-type",
            "custom workflow event type must be non-empty and contain a dot",
        )
    } else if RESERVED_PREFIXES
        .iter()
        .any(|prefix| event_type.starts_with(prefix))
    {
        (
            "reserved-custom-event-type",
            "custom workflow event type uses a reserved first-party prefix",
        )
    } else {
        return Ok(());
    };
    Err(observability_error(code, message).with_string_field("event_type", event_type))
}
fn summarize_env_changes(
before: &HashMap<String, serde_json::Value>,
after: &HashMap<String, serde_json::Value>,
config: &ObservabilityConfig,
) -> Result<EnvChangeSetSummary, handled::SError> {
let mut keys = HashSet::new();
keys.extend(before.keys().cloned());
keys.extend(after.keys().cloned());
let mut changed = Vec::new();
for key in keys {
let before_value = before.get(&key);
let after_value = after.get(&key);
if before_value == after_value {
continue;
}
let change = match (before_value, after_value) {
(None, Some(_)) => EnvChangeKind::Added,
(Some(_), None) => EnvChangeKind::Removed,
(Some(_), Some(_)) => EnvChangeKind::Modified,
(None, None) => continue,
};
changed.push(EnvChangeSummary {
key,
change,
before: summarize_env_value(before_value)?,
after: summarize_env_value(after_value)?,
});
}
changed.sort_by(|lhs, rhs| lhs.key.cmp(&rhs.key));
let changed_key_count = changed.len();
let changes_truncated = changed_key_count > config.max_env_changes;
changed.truncate(config.max_env_changes);
Ok(EnvChangeSetSummary {
changed_key_count,
changes_truncated,
env_before_digest: env_digest(before)?,
env_after_digest: env_digest(after)?,
env_before_key_count: before.len(),
env_after_key_count: after.len(),
changes: changed,
})
}
/// Returns a placeholder `EnvChangeSetSummary` that reports no changes.
///
/// The error argument is discarded. Both digests are set to the sentinel
/// `"setsum:v1:error"`, and all counts are zero.
/// NOTE(review): presumably used as an `unwrap_or_else` fallback for
/// `summarize_env_changes`; confirm at call sites.
fn empty_env_change_summary(_: handled::SError) -> EnvChangeSetSummary {
    let sentinel = || String::from("setsum:v1:error");
    EnvChangeSetSummary {
        changes: Vec::new(),
        changed_key_count: 0,
        changes_truncated: false,
        env_before_key_count: 0,
        env_after_key_count: 0,
        env_before_digest: sentinel(),
        env_after_digest: sentinel(),
    }
}
/// Returns the step the workflow will actually execute next.
///
/// When the current step is `Step::Halt` but continuation steps are still scheduled, the
/// last scheduled continuation is returned. Otherwise the current step is returned.
fn effective_step(workflow: &Workflow) -> Step {
    let scheduled = match workflow.current_step {
        Step::Halt => workflow.continuation.last(),
        _ => None,
    };
    scheduled.unwrap_or(&workflow.current_step).clone()
}
/// Summarizes a single environment value by its shape, normalized byte length, and
/// set-sum digest.
///
/// `None` yields `EnvValueSummary::missing()`.
///
/// # Errors
///
/// `invalid-json-summary` if the value cannot be serialized.
fn summarize_env_value(
    value: Option<&serde_json::Value>,
) -> Result<EnvValueSummary, handled::SError> {
    match value {
        None => Ok(EnvValueSummary::missing()),
        Some(json) => {
            let encoded = normalized_json_bytes(json)?;
            let digest = setsum_digest([encoded.as_slice()]);
            Ok(EnvValueSummary {
                shape: value_shape(json),
                bytes: Some(encoded.len()),
                digest: Some(digest),
            })
        }
    }
}
fn value_shape(value: &serde_json::Value) -> ValueShape {
match value {
serde_json::Value::Null => ValueShape::Null,
serde_json::Value::Bool(_) => ValueShape::Bool,
serde_json::Value::Number(_) => ValueShape::Number,
serde_json::Value::String(_) => ValueShape::String,
serde_json::Value::Array(_) => ValueShape::Array,
serde_json::Value::Object(_) => ValueShape::Object,
}
}
/// Computes an order-independent set-sum digest of a whole environment.
///
/// Each entry is hashed as the length-prefixed pair (key bytes, per-value digest
/// string), and all entries are summed into one `Setsum`. The `HashMap` iteration order
/// therefore does not affect the result.
///
/// # Errors
///
/// `invalid-json-summary` if any value cannot be serialized.
fn env_digest(env: &HashMap<String, serde_json::Value>) -> Result<String, handled::SError> {
    let mut accumulator = Setsum::default();
    for (name, json) in env.iter() {
        let encoded = normalized_json_bytes(json)?;
        let json_digest = setsum_digest([encoded.as_slice()]);
        let framed = length_prefixed_parts([name.as_bytes(), json_digest.as_bytes()]);
        accumulator.insert(&framed);
    }
    Ok(format!("setsum:v1:{}", accumulator.hexdigest()))
}
/// Serializes `value` to compact JSON bytes after recursively sorting object keys.
///
/// Equal values therefore always produce identical bytes.
///
/// # Errors
///
/// `invalid-json-summary`, with the serializer's message in `source`.
fn normalized_json_bytes(value: &serde_json::Value) -> Result<Vec<u8>, handled::SError> {
    let canonical = normalize_json_value(value);
    serde_json::to_vec(&canonical).map_err(|failure| {
        let source = failure.to_string();
        observability_error(
            "invalid-json-summary",
            "failed to serialize normalized JSON for observability summary",
        )
        .with_string_field("source", &source)
    })
}
/// Returns a deep copy of `value` whose object keys are inserted in sorted order, at
/// every level.
///
/// This makes serialization deterministic even if `serde_json::Map` preserves insertion
/// order. Array element order is kept as-is.
fn normalize_json_value(value: &serde_json::Value) -> serde_json::Value {
    match value {
        serde_json::Value::Object(map) => {
            let mut entries: Vec<(&String, &serde_json::Value)> = map.iter().collect();
            entries.sort_by(|(left, _), (right, _)| left.cmp(right));
            serde_json::Value::Object(
                entries
                    .into_iter()
                    .map(|(name, inner)| (name.clone(), normalize_json_value(inner)))
                    .collect(),
            )
        }
        serde_json::Value::Array(items) => {
            serde_json::Value::Array(items.iter().map(normalize_json_value).collect())
        }
        scalar => scalar.clone(),
    }
}
/// Digests a sequence of byte parts as a single set-sum element.
///
/// The parts are framed with `length_prefixed_parts`, so boundaries between parts are
/// unambiguous. The result is formatted as `"setsum:v1:<hex>"`.
fn setsum_digest<'a>(parts: impl IntoIterator<Item = &'a [u8]>) -> String {
    let framed = length_prefixed_parts(parts);
    let mut single = Setsum::default();
    single.insert(&framed);
    format!("setsum:v1:{}", single.hexdigest())
}
/// Concatenates byte parts, each preceded by its length as a big-endian `u64`.
///
/// The length prefix makes the encoding injective: `["ab", "c"]` and `["a", "bc"]`
/// produce different bytes.
fn length_prefixed_parts<'a>(parts: impl IntoIterator<Item = &'a [u8]>) -> Vec<u8> {
    parts.into_iter().fold(Vec::new(), |mut framed, part| {
        framed.extend_from_slice(&(part.len() as u64).to_be_bytes());
        framed.extend_from_slice(part);
        framed
    })
}
/// Builds a `langcontinuation` error with the given code and message.
///
/// Used for observability and event-validation failures.
fn observability_error(code: &str, message: &str) -> handled::SError {
    let base = handled::SError::new("langcontinuation");
    base.with_code(code).with_message(message)
}
/// Ensures a fork branch has reached `Step::Halt` before it is joined.
///
/// # Errors
///
/// `fork-join-branch-not-halted`, carrying the branch label, its run id, and the
/// `Debug` rendering of its current step.
fn require_halted_branch(name: &str, branch: &Workflow) -> Result<(), handled::SError> {
    match branch.current_step {
        Step::Halt => Ok(()),
        ref pending => Err(handled::SError::new("langcontinuation")
            .with_code("fork-join-branch-not-halted")
            .with_message("fork/join branch did not halt before join resume")
            .with_string_field("branch", name)
            .with_string_field("run_id", &branch.run_id)
            .with_string_field("current_step", &format!("{pending:?}"))),
    }
}
/// Three-way merges two fork branch environments against their common `base`.
///
/// For every key in any of the three maps:
/// - if neither branch changed it relative to `base`, the base value is kept;
/// - if exactly one branch changed it (including removing it), that change wins;
/// - if both changed it to the same value (or both removed it), that change is applied;
/// - otherwise the merge fails.
///
/// Keys are visited in sorted order. When several keys conflict, the reported key is
/// therefore always the lexicographically smallest one, rather than depending on
/// `HashSet` iteration order.
///
/// # Errors
///
/// `fork-join-env-conflict`, naming the key and both branch values.
fn merge_fork_join_env(
    base: &HashMap<String, serde_json::Value>,
    lhs: &HashMap<String, serde_json::Value>,
    rhs: &HashMap<String, serde_json::Value>,
) -> Result<HashMap<String, serde_json::Value>, handled::SError> {
    let mut keys: Vec<&String> = base.keys().chain(lhs.keys()).chain(rhs.keys()).collect();
    keys.sort_unstable();
    keys.dedup();
    let mut merged = base.clone();
    for key in keys {
        let base_value = base.get(key);
        let lhs_value = lhs.get(key);
        let rhs_value = rhs.get(key);
        let lhs_changed = lhs_value != base_value;
        let rhs_changed = rhs_value != base_value;
        let winner = match (lhs_changed, rhs_changed) {
            (false, false) => continue,
            (true, false) => lhs_value,
            (false, true) => rhs_value,
            (true, true) if lhs_value == rhs_value => lhs_value,
            (true, true) => {
                return Err(fork_join_env_conflict_error(key, lhs_value, rhs_value));
            }
        };
        apply_fork_env_change(&mut merged, key.clone(), winner.cloned());
    }
    Ok(merged)
}
/// Applies a merged branch change to `env`.
///
/// `Some(value)` sets `key` to `value`; `None` removes `key`.
fn apply_fork_env_change(
    env: &mut HashMap<String, serde_json::Value>,
    key: String,
    value: Option<serde_json::Value>,
) {
    match value {
        None => {
            env.remove(&key);
        }
        Some(present) => {
            env.insert(key, present);
        }
    }
}
/// Builds the `fork-join-env-conflict` error.
///
/// The error carries the conflicting key and both branch values, rendered as JSON text;
/// a missing value is rendered as `<missing>`.
fn fork_join_env_conflict_error(
    key: &str,
    lhs: Option<&serde_json::Value>,
    rhs: Option<&serde_json::Value>,
) -> handled::SError {
    let (lhs_text, rhs_text) = (format_fork_env_value(lhs), format_fork_env_value(rhs));
    handled::SError::new("langcontinuation")
        .with_code("fork-join-env-conflict")
        .with_message("fork/join branches wrote conflicting environment values")
        .with_string_field("key", key)
        .with_string_field("lhs", &lhs_text)
        .with_string_field("rhs", &rhs_text)
}
fn format_fork_env_value(value: Option<&serde_json::Value>) -> String {
value
.map(serde_json::Value::to_string)
.unwrap_or_else(|| "<missing>".into())
}
/// Builds the `not-suspended-at-fork-join` error.
///
/// The error includes the `Debug` rendering of the step the workflow is actually at.
fn fork_join_resume_error(current_step: &Step) -> handled::SError {
    let rendered = format!("{current_step:?}");
    handled::SError::new("langcontinuation")
        .with_code("not-suspended-at-fork-join")
        .with_message(
            "attempted to resume fork/join on a workflow that is not suspended at a fork/join step",
        )
        .with_string_field("current_step", &rendered)
}
/// Serializes `value` to JSON and stores it in `workflow.env` under `output_key`.
///
/// Any previous value for that key is overwritten.
///
/// # Errors
///
/// On serialization failure, returns an error with the caller-supplied `error_code` and
/// `error_message`, plus the `output_key` and the serializer's message. The workflow is
/// left unchanged in that case.
fn insert_resume_value<T: serde::Serialize>(
    workflow: &mut Workflow,
    output_key: String,
    value: T,
    error_code: &'static str,
    error_message: &'static str,
) -> Result<(), handled::SError> {
    let encoded = match serde_json::to_value(value) {
        Ok(encoded) => encoded,
        Err(failure) => {
            return Err(handled::SError::new("langcontinuation")
                .with_code(error_code)
                .with_message(error_message)
                .with_string_field("output_key", &output_key)
                .with_string_field("source", &failure.to_string()));
        }
    };
    workflow.env.insert(output_key, encoded);
    Ok(())
}
/// Builds the `missing-tool` error for a tool name that has no registered implementation.
fn missing_tool_error(tool: &str) -> handled::SError {
    let base = handled::SError::new("langcontinuation").with_code("missing-tool");
    base.with_message("model called a tool that is not registered")
        .with_string_field("tool", tool)
}
/// Builds the `missing-function` error for a call step whose function is not registered.
fn missing_function_error(function: &str) -> handled::SError {
    let base = handled::SError::new("langcontinuation").with_code("missing-function");
    base.with_message("registered trampoline function is missing")
        .with_string_field("function", function)
}
/// Builds a resume-validation error with `code` and `message`.
///
/// Each present `(field, value)` pair from `expected` and `actual` is attached as a string
/// field, `expected` first.
fn resume_error(
    code: &str,
    message: &str,
    expected: Option<(&str, &str)>,
    actual: Option<(&str, &str)>,
) -> handled::SError {
    let base = handled::SError::new("langcontinuation")
        .with_code(code)
        .with_message(message);
    [expected, actual]
        .into_iter()
        .flatten()
        .fold(base, |error, (field, value)| error.with_string_field(field, value))
}
/// Creates a fresh `Continuation` token.
///
/// Used by the `generate_goto!` macro expansion; not part of the public API.
#[doc(hidden)]
pub fn __new_continuation() -> Continuation {
    let marker = std::marker::PhantomData;
    Continuation { _phantom: marker }
}
/// Applies a `ContinuationChoice` to the workflow via `ContinuationChoice::apply_to`.
///
/// Used by the `generate_goto!` macro expansion; not part of the public API.
#[doc(hidden)]
pub fn __apply_continuation(wf: &mut Workflow, result: ContinuationChoice) {
    let choice = result;
    choice.apply_to(wf);
}
/// Runs `f` with a fresh `Continuation` and applies the choice it returns to `wf`.
///
/// If `f` fails, its error is returned unchanged and nothing is applied.
#[doc(hidden)]
pub fn __with_continuation<F, E>(wf: &mut Workflow, f: F) -> Result<(), E>
where
    F: FnOnce(&mut Workflow, Continuation) -> Result<ContinuationChoice, E>,
{
    let choice = f(wf, __new_continuation())?;
    __apply_continuation(wf, choice);
    Ok(())
}
/// Generates a trampoline-compatible function from a function-like definition whose
/// leading arguments are loaded from the workflow environment.
///
/// The generated function has the shape `pub fn name(wf: &mut Workflow) -> CallFuture<'_>`.
/// Four forms are accepted: sync or `async`, each taking one or two environment
/// arguments followed by a `Continuation`.
///
/// Each argument `name: Type` is read from the environment key `"name: Type"`, built with
/// `stringify!`. Argument types must therefore be single identifiers. A missing key fails
/// with `missing-env-value`; an undecodable value fails with `invalid-env-value`. The
/// `ContinuationChoice` produced by the body is then applied to the workflow.
///
/// NOTE(review): the expansion names `handled::SError` without `$crate::`, so the
/// `handled` crate must be reachable at the call site.
///
/// NOTE(review): the sync forms run the body inside a closure, so `return` and `?` in the
/// body target `Result<ContinuationChoice, $error>`. The async forms evaluate the body as
/// a plain block inside the outer `async` block. There, `?` converts directly into
/// `handled::SError`, and an explicit `return Ok(choice)` will not type-check.
#[macro_export]
macro_rules! generate_goto {
    // Sync, two environment arguments.
    (fn $fn_name:ident($wf:ident: &mut Workflow, $a:ident: $at:ident, $b:ident: $bt:ident, $continuation:ident: Continuation) -> Result<ContinuationChoice, $error:ty> $body:block) => {
        pub fn $fn_name($wf: &mut Workflow) -> $crate::CallFuture<'_> {
            Box::pin(async move {
                let __result: Result<(), handled::SError> = (|| {
                    // Load `$a` from env key "<a>: <At>".
                    let key = format!("{}: {}", stringify!($a), stringify!($at));
                    let $a: $at = $wf
                        .from_env(key.clone())
                        .map_err(|err| $crate::env_decode_error(&key, err))?
                        .ok_or_else(|| $crate::missing_env_error(&key))?;
                    // Load `$b` from env key "<b>: <Bt>".
                    let key = format!("{}: {}", stringify!($b), stringify!($bt));
                    let $b: $bt = $wf
                        .from_env(key.clone())
                        .map_err(|err| $crate::env_decode_error(&key, err))?
                        .ok_or_else(|| $crate::missing_env_error(&key))?;
                    // The body's error is converted into `handled::SError` by this `?`.
                    $crate::__with_continuation(
                        $wf,
                        |$wf, $continuation| -> Result<$crate::ContinuationChoice, $error> {
                            let $a = $a;
                            let $b = $b;
                            $body
                        },
                    )?;
                    Ok(())
                })();
                __result
            })
        }
    };
    // Sync, one environment argument.
    (fn $fn_name:ident($wf:ident: &mut Workflow, $a:ident: $at:ident, $continuation:ident: Continuation) -> Result<ContinuationChoice, $error:ty> $body:block) => {
        pub fn $fn_name($wf: &mut Workflow) -> $crate::CallFuture<'_> {
            Box::pin(async move {
                let __result: Result<(), handled::SError> = (|| {
                    // Load `$a` from env key "<a>: <At>".
                    let key = format!("{}: {}", stringify!($a), stringify!($at));
                    let $a: $at = $wf
                        .from_env(key.clone())
                        .map_err(|err| $crate::env_decode_error(&key, err))?
                        .ok_or_else(|| $crate::missing_env_error(&key))?;
                    $crate::__with_continuation(
                        $wf,
                        |$wf, $continuation| -> Result<$crate::ContinuationChoice, $error> {
                            let $a = $a;
                            $body
                        },
                    )?;
                    Ok(())
                })();
                __result
            })
        }
    };
    // Async, two environment arguments: the body is evaluated as a block (not a closure)
    // so it may `.await`.
    (async fn $fn_name:ident($wf:ident: &mut Workflow, $a:ident: $at:ident, $b:ident: $bt:ident, $continuation:ident: Continuation) -> Result<ContinuationChoice, $error:ty> $body:block) => {
        pub fn $fn_name($wf: &mut Workflow) -> $crate::CallFuture<'_> {
            Box::pin(async move {
                let __result: Result<(), handled::SError> = async {
                    let key = format!("{}: {}", stringify!($a), stringify!($at));
                    let $a: $at = $wf
                        .from_env(key.clone())
                        .map_err(|err| $crate::env_decode_error(&key, err))?
                        .ok_or_else(|| $crate::missing_env_error(&key))?;
                    let key = format!("{}: {}", stringify!($b), stringify!($bt));
                    let $b: $bt = $wf
                        .from_env(key.clone())
                        .map_err(|err| $crate::env_decode_error(&key, err))?
                        .ok_or_else(|| $crate::missing_env_error(&key))?;
                    let $continuation = $crate::__new_continuation();
                    let result: Result<$crate::ContinuationChoice, $error> = {
                        let $a = $a;
                        let $b = $b;
                        $body
                    };
                    // The choice is only applied when the body succeeded.
                    $crate::__apply_continuation($wf, result?);
                    Ok(())
                }
                .await;
                __result
            })
        }
    };
    // Async, one environment argument.
    (async fn $fn_name:ident($wf:ident: &mut Workflow, $a:ident: $at:ident, $continuation:ident: Continuation) -> Result<ContinuationChoice, $error:ty> $body:block) => {
        pub fn $fn_name($wf: &mut Workflow) -> $crate::CallFuture<'_> {
            Box::pin(async move {
                let __result: Result<(), handled::SError> = async {
                    let key = format!("{}: {}", stringify!($a), stringify!($at));
                    let $a: $at = $wf
                        .from_env(key.clone())
                        .map_err(|err| $crate::env_decode_error(&key, err))?
                        .ok_or_else(|| $crate::missing_env_error(&key))?;
                    let $continuation = $crate::__new_continuation();
                    let result: Result<$crate::ContinuationChoice, $error> = {
                        let $a = $a;
                        $body
                    };
                    $crate::__apply_continuation($wf, result?);
                    Ok(())
                }
                .await;
                __result
            })
        }
    };
}
/// Builds the `invalid-env-value` error raised when a workflow environment value under
/// `key` cannot be decoded into the requested type.
///
/// The serializer's message is attached as `source`. The error domain is
/// `langcontinuation`, matching every other error this crate produces. It was previously
/// the application-specific `support-pipeline`.
pub fn env_decode_error(key: &str, err: serde_json::Error) -> handled::SError {
    handled::SError::new("langcontinuation")
        .with_code("invalid-env-value")
        .with_message("failed to decode workflow environment value")
        .with_string_field("key", key)
        .with_string_field("source", &err.to_string())
}
/// Builds the `invalid-output-value` error raised when a value cannot be encoded into the
/// workflow environment under `key`.
///
/// The serializer's message is attached as `source`. The error domain is
/// `langcontinuation`, matching every other error this crate produces. It was previously
/// the application-specific `support-pipeline`.
pub fn env_encode_error(key: &str, err: serde_json::Error) -> handled::SError {
    handled::SError::new("langcontinuation")
        .with_code("invalid-output-value")
        .with_message("failed to encode workflow environment value")
        .with_string_field("key", key)
        .with_string_field("source", &err.to_string())
}
/// Builds the `missing-env-value` error raised when a required workflow environment key
/// is absent.
///
/// The error domain is `langcontinuation`, matching every other error this crate
/// produces. It was previously the application-specific `support-pipeline`.
pub fn missing_env_error(key: &str) -> handled::SError {
    handled::SError::new("langcontinuation")
        .with_code("missing-env-value")
        .with_message("required workflow environment value is missing")
        .with_string_field("key", key)
}
/// Stores a typed value in the workflow environment.
///
/// `push_env!(wf.name: Type = expr)` evaluates `expr` as `Type` and stores it under the
/// key `"name: Type"`, which is the same key convention `from_env!` and `generate_goto!`
/// read. An encoding failure is returned early, via `?`, as an `invalid-output-value`
/// error.
#[macro_export]
macro_rules! push_env {
    ($wf:ident.$t:ident: $tt:ty = $e:expr) => {
        {
            let key = format!("{}: {}", stringify!($t), stringify!($tt));
            let value: $tt = $e;
            $wf.into_env(&key, value)
                .map_err(|err| $crate::env_encode_error(&key, err))?;
        }
    };
}
/// Loads a typed value from the workflow environment into a new local binding.
///
/// `from_env!(let name: Type = wf.lookup())` reads the key `"name: Type"` and binds the
/// decoded value to `name`. A missing key returns early with `missing-env-value`; an
/// undecodable value returns early with `invalid-env-value`.
#[macro_export]
macro_rules! from_env {
    (let $a:ident: $at:ident = $wf:ident.lookup()) => {
        let $a: $at = {
            let key = format!("{}: {}", stringify!($a), stringify!($at));
            $wf.from_env(key.clone())
                .map_err(|err| $crate::env_decode_error(&key, err))?
                .ok_or_else(|| $crate::missing_env_error(&key))?
        };
    };
}
/// Loads the workflow's `"retval"` environment entry into a new typed local binding.
///
/// `retval!(let name: Type = wf.lookup())` always reads the fixed key `"retval"`,
/// whatever `name` is. A missing key returns early with `missing-env-value`; an
/// undecodable value returns early with `invalid-env-value`.
#[macro_export]
macro_rules! retval {
    (let $a:ident: $at:ident = $wf:ident.lookup()) => {
        let $a: $at = {
            let key = "retval".to_string();
            $wf.from_env(key.clone())
                .map_err(|err| $crate::env_decode_error(&key, err))?
                .ok_or_else(|| $crate::missing_env_error(&key))?
        };
    };
}
#[cfg(test)]
mod tests {
use super::*;
use claudius::{ContentBlock, KnownModel, TextBlock, Usage};
use serde_json::json;
fn anthropic_request(prompt: &str) -> Box<MessageCreateParams> {
Box::new(MessageCreateParams::simple(
prompt,
KnownModel::ClaudeHaiku45,
))
}
fn anthropic_step(prompt: &str, output_key: &str) -> Step {
Step::Anthropic {
provider: "anthropic".into(),
message: anthropic_request(prompt),
output_key: output_key.into(),
}
}
fn human_request(prompt: &str) -> HumanRequest {
HumanRequest::new(prompt)
.with_context(json!({"ticket_id": "ticket-001"}))
.with_metadata(json!({"queue": "support"}))
}
fn human_step(prompt: &str, output_key: &str) -> Step {
Step::Human {
request: human_request(prompt),
output_key: output_key.into(),
}
}
fn anthropic_response(text: &str) -> Message {
Message::new(
"msg_test".into(),
vec![ContentBlock::Text(TextBlock::new(text))],
KnownModel::ClaudeHaiku45.into(),
Usage::new(1, 1),
)
}
fn call_error(code: &str) -> handled::SError {
handled::SError::new("test").with_code(code)
}
fn fork_join_workflows() -> (Workflow, Workflow, Workflow) {
let mut workflow = Workflow::new("parent", "entry");
workflow.into_env("base", "inherited").unwrap();
let mut trampoline = Trampoline::default();
trampoline.register(
"entry",
test_sync_call(|workflow| {
__with_continuation(
workflow,
|_, continuation| -> Result<ContinuationChoice, handled::SError> {
Ok(continuation.fork_join(
ForkBranch::new("caller-lhs", "run_lhs"),
ForkBranch::new("caller-rhs", "run_rhs"),
"join",
))
},
)
}),
);
let result = test_run_trampoline(&trampoline, workflow).expect("run");
let WorkflowResult::ForkJoin { workflow, lhs, rhs } = result else {
panic!("workflow should suspend for fork/join");
};
(workflow, *lhs, *rhs)
}
fn halt_branch(branch: &mut Workflow) {
branch.current_step = Step::Halt;
}
fn assert_call_step(step: &Step, expected: &str) {
let Step::Call { function } = step else {
panic!("expected call step");
};
assert_eq!(function, expected);
}
#[test]
fn default_workflow_halts_with_nop_current_step() {
let trampoline = Trampoline::default();
let result = test_run_trampoline(&trampoline, Workflow::default()).expect("run");
let WorkflowResult::Halt { workflow } = result else {
panic!("default workflow should halt");
};
assert!(matches!(workflow.current_step, Step::Halt));
}
#[test]
fn pending_workflow_events_are_skipped_by_serde() {
let mut workflow = Workflow::new("run", "entry");
workflow
.record_event("ticket.received", json!({"queue": "support"}))
.expect("record event");
assert_eq!(workflow.drain_pending_events().len(), 1);
workflow
.record_event("ticket.received", json!({"queue": "support"}))
.expect("record event");
let encoded = serde_json::to_value(&workflow).expect("encode workflow");
let mut decoded: Workflow = serde_json::from_value(encoded).expect("decode workflow");
assert!(decoded.drain_pending_events().is_empty());
}
#[test]
fn custom_workflow_events_reject_reserved_prefixes() {
let mut workflow = Workflow::new("run", "entry");
let error = workflow
.record_event("workflow.fake", json!({}))
.expect_err("reserved prefix should fail");
assert!(error.to_string().contains("reserved-custom-event-type"));
}
#[tokio::test]
async fn workflow_error_preserves_pre_failure_events_and_partial_env() {
let mut trampoline = Trampoline::default();
trampoline.register("entry", |workflow| -> CallFuture<'_> {
Box::pin(async move {
workflow
.record_event("ticket.loaded", json!({"id": "T-1"}))
.unwrap();
workflow.into_env("partial", true).unwrap();
Err(call_error("boom"))
})
});
let mut workflow = Workflow::new("run", "entry");
workflow.set_observability_context(ObservabilityContext {
causal_cursor: CausalRef::RunId {
run_id: "run".into(),
},
});
let error = trampoline
.run_one_local_call(workflow)
.await
.expect_err("call should fail");
assert_eq!(error.events.len(), 1);
assert_eq!(error.events[0].event_type, "ticket.loaded");
assert_eq!(
error.workflow.from_env::<bool>("partial").unwrap(),
Some(true)
);
assert!(error.env_changes.changed_key_count >= 1);
assert!(error.to_string().contains("boom"));
}
#[test]
fn registered_call_mutates_env_and_halts() {
let workflow = Workflow::new("test", "mark_done");
let mut trampoline = Trampoline::default();
trampoline.register(
"mark_done",
test_sync_call(|workflow| {
workflow
.into_env("done", true)
.map_err(|_| call_error("serialize"))?;
Ok(())
}),
);
let result = test_run_trampoline(&trampoline, workflow).expect("run");
let WorkflowResult::Halt { workflow } = result else {
panic!("workflow should halt");
};
assert_eq!(workflow.from_env::<bool>("done").unwrap(), Some(true));
assert!(matches!(workflow.current_step, Step::Halt));
}
#[test]
fn missing_call_returns_structured_error() {
let workflow = Workflow::new("test", "missing");
let trampoline = Trampoline::default();
let error =
test_run_trampoline(&trampoline, workflow).expect_err("missing function should error");
assert!(error.to_string().contains("missing-function"));
assert!(error.to_string().contains("missing"));
}
#[test]
fn call_schedules_next_step_with_lifo_order() {
let mut workflow = Workflow::new("test", "first");
workflow.schedule(Step::Call {
function: "third".into(),
});
let mut trampoline = Trampoline::default();
trampoline.register(
"first",
test_sync_call(|workflow| {
let mut order = workflow
.from_env::<Vec<String>>("order")
.unwrap()
.unwrap_or_default();
order.push("first".into());
workflow
.into_env("order", order)
.map_err(|_| call_error("serialize"))?;
workflow.schedule(Step::Call {
function: "second".into(),
});
Ok(())
}),
);
trampoline.register(
"second",
test_sync_call(|workflow| {
let mut order = workflow
.from_env::<Vec<String>>("order")
.unwrap()
.unwrap_or_default();
order.push("second".into());
workflow
.into_env("order", order)
.map_err(|_| call_error("serialize"))?;
Ok(())
}),
);
trampoline.register(
"third",
test_sync_call(|workflow| {
let mut order = workflow
.from_env::<Vec<String>>("order")
.unwrap()
.unwrap_or_default();
order.push("third".into());
workflow
.into_env("order", order)
.map_err(|_| call_error("serialize"))?;
Ok(())
}),
);
let result = test_run_trampoline(&trampoline, workflow).expect("run");
let WorkflowResult::Halt { workflow } = result else {
panic!("workflow should halt");
};
assert_eq!(
workflow.from_env::<Vec<String>>("order").unwrap().unwrap(),
vec!["first", "second", "third"]
);
}
#[test]
fn fork_join_uses_caller_provided_branch_run_ids() {
let (workflow, lhs, rhs) = fork_join_workflows();
assert!(matches!(workflow.current_step, Step::ForkJoin { .. }));
assert_eq!(lhs.run_id, "caller-lhs");
assert_eq!(rhs.run_id, "caller-rhs");
}
#[test]
fn fork_join_branches_inherit_env_and_start_at_configured_calls() {
let (_, lhs, rhs) = fork_join_workflows();
assert_eq!(
lhs.from_env::<String>("base").unwrap(),
Some("inherited".into())
);
assert_eq!(
rhs.from_env::<String>("base").unwrap(),
Some("inherited".into())
);
assert_call_step(&lhs.current_step, "run_lhs");
assert_call_step(&rhs.current_step, "run_rhs");
assert!(lhs.continuation.is_empty());
assert!(rhs.continuation.is_empty());
}
#[test]
fn fork_join_conflicting_env_writes_return_structured_error() {
let (workflow, mut lhs, mut rhs) = fork_join_workflows();
halt_branch(&mut lhs);
halt_branch(&mut rhs);
lhs.into_env("shared", "lhs").unwrap();
rhs.into_env("shared", "rhs").unwrap();
let error = Trampoline::default()
.resume_fork_join(workflow, lhs, rhs)
.expect_err("conflicting branch writes should fail");
assert!(error.to_string().contains("fork-join-env-conflict"));
assert!(error.to_string().contains("shared"));
}
#[test]
fn fork_join_identical_same_key_writes_are_accepted() {
let (workflow, mut lhs, mut rhs) = fork_join_workflows();
halt_branch(&mut lhs);
halt_branch(&mut rhs);
lhs.into_env("shared", "same").unwrap();
rhs.into_env("shared", "same").unwrap();
let workflow = Trampoline::default()
.resume_fork_join(workflow, lhs, rhs)
.expect("identical branch writes should merge");
assert_eq!(
workflow.from_env::<String>("shared").unwrap(),
Some("same".into())
);
assert_call_step(&workflow.current_step, "join");
}
#[test]
fn fork_join_resume_rejects_non_fork_join_workflow() {
let mut lhs = Workflow::default();
let mut rhs = Workflow::default();
halt_branch(&mut lhs);
halt_branch(&mut rhs);
let error = Trampoline::default()
.resume_fork_join(Workflow::default(), lhs, rhs)
.expect_err("non-fork workflow should fail");
assert!(error.to_string().contains("not-suspended-at-fork-join"));
}
#[test]
fn fork_join_resume_rejects_non_halted_branch() {
let (workflow, lhs, mut rhs) = fork_join_workflows();
halt_branch(&mut rhs);
let error = Trampoline::default()
.resume_fork_join(workflow, lhs, rhs)
.expect_err("non-halted branch should fail");
assert!(error.to_string().contains("fork-join-branch-not-halted"));
assert!(error.to_string().contains("lhs"));
assert!(error.to_string().contains("caller-lhs"));
}
#[test]
fn anthropic_step_suspends_with_workflow_and_output_key() {
let mut workflow = Workflow::default();
workflow.schedule(anthropic_step("classify", "response"));
let trampoline = Trampoline::default();
let result = test_run_trampoline(&trampoline, workflow).expect("run");
let WorkflowResult::Anthropic {
workflow,
provider,
message,
output_key,
} = result
else {
panic!("workflow should suspend for Anthropic");
};
assert!(matches!(workflow.current_step, Step::Anthropic { .. }));
assert_eq!(provider, "anthropic");
assert_eq!(message.messages.len(), 1);
assert_eq!(output_key, "response");
}
#[test]
fn anthropic_continuation_suspends_then_invokes_next_function() {
let workflow = Workflow::new("test", "entry");
let mut trampoline = Trampoline::default();
trampoline.register(
"entry",
test_sync_call(|workflow| {
__with_continuation(
workflow,
|_, continuation| -> Result<ContinuationChoice, handled::SError> {
Ok(continuation.anthropic(
"anthropic",
*anthropic_request("classify"),
"response",
"after",
))
},
)
}),
);
trampoline.register(
"after",
test_sync_call(|workflow| {
let _: Message = workflow.from_env("response").unwrap().unwrap();
workflow
.into_env("after", true)
.map_err(|_| call_error("serialize"))?;
Ok(())
}),
);
let result = test_run_trampoline(&trampoline, workflow).expect("run");
let WorkflowResult::Anthropic {
workflow,
provider,
message,
output_key,
} = result
else {
panic!("workflow should suspend for Anthropic");
};
assert_eq!(provider, "anthropic");
assert_eq!(message.messages.len(), 1);
assert_eq!(output_key, "response");
let workflow = trampoline
.resume_anthropic(workflow, output_key, anthropic_response("done"))
.expect("resume");
let result = test_run_trampoline(&trampoline, workflow).expect("run after resume");
let WorkflowResult::Halt { workflow } = result else {
panic!("workflow should halt");
};
assert!(workflow.from_env::<Message>("response").unwrap().is_some());
assert_eq!(workflow.from_env::<bool>("after").unwrap(), Some(true));
}
#[test]
fn anthropic_resume_stores_message_and_advances() {
let mut workflow = Workflow::default();
workflow.schedule(Step::Call {
function: "after".into(),
});
workflow.schedule(anthropic_step("classify", "response"));
let mut trampoline = Trampoline::default();
trampoline.register(
"after",
test_sync_call(|workflow| {
workflow
.into_env("after", true)
.map_err(|_| call_error("serialize"))?;
Ok(())
}),
);
let result = test_run_trampoline(&trampoline, workflow).expect("run");
let WorkflowResult::Anthropic {
workflow,
output_key,
message: _,
provider: _,
} = result
else {
panic!("workflow should suspend for Anthropic");
};
let workflow = trampoline
.resume_anthropic(workflow, output_key, anthropic_response("done"))
.expect("resume");
let result = test_run_trampoline(&trampoline, workflow).expect("run after resume");
let WorkflowResult::Halt { workflow } = result else {
panic!("workflow should halt");
};
assert!(workflow.from_env::<Message>("response").unwrap().is_some());
assert_eq!(workflow.from_env::<bool>("after").unwrap(), Some(true));
assert!(matches!(workflow.current_step, Step::Halt));
}
#[test]
fn anthropic_resume_rejects_wrong_current_step() {
let error = Trampoline::default()
.resume_anthropic(Workflow::default(), "response", anthropic_response("done"))
.expect_err("resume should fail");
assert!(error.to_string().contains("not-suspended-at-anthropic"));
}
#[test]
fn anthropic_resume_rejects_wrong_output_key() {
let mut workflow = Workflow::default();
workflow.schedule(anthropic_step("classify", "expected"));
let trampoline = Trampoline::default();
let result = test_run_trampoline(&trampoline, workflow).expect("run");
let WorkflowResult::Anthropic { workflow, .. } = result else {
panic!("workflow should suspend for Anthropic");
};
let error = Trampoline::default()
.resume_anthropic(workflow, "actual", anthropic_response("done"))
.expect_err("resume should fail");
assert!(error.to_string().contains("anthropic-output-key-mismatch"));
}
/// Resuming an OpenAI suspension stores the JSON value under the given key and
/// lets the remaining scheduled call run to completion.
#[test]
fn open_ai_resume_stores_value_and_advances() {
    let mut trampoline = Trampoline::default();
    // "after" leaves a marker in the env so we can see that it ran.
    trampoline.register(
        "after",
        test_sync_call(|wf| {
            wf.into_env("after", true)
                .map_err(|_| call_error("serialize"))?;
            Ok(())
        }),
    );

    // Queue a follow-up call plus an OpenAI step; the first run must suspend
    // at the OpenAI step.
    let mut pending = Workflow::default();
    pending.schedule(Step::Call {
        function: "after".into(),
    });
    pending.schedule(Step::OpenAI {});

    let WorkflowResult::OpenAI { workflow: suspended } =
        test_run_trampoline(&trampoline, pending).expect("run")
    else {
        panic!("workflow should suspend for OpenAI");
    };

    let reply = json!({"text": "done"});
    let resumed = trampoline
        .resume_open_ai(suspended, "response", reply.clone())
        .expect("resume");
    let WorkflowResult::Halt { workflow: finished } =
        test_run_trampoline(&trampoline, resumed).expect("run after resume")
    else {
        panic!("workflow should halt");
    };

    let stored: Option<serde_json::Value> = finished.from_env("response").unwrap();
    assert_eq!(stored, Some(reply));
    assert_eq!(finished.from_env::<bool>("after").unwrap(), Some(true));
    assert!(matches!(finished.current_step, Step::Halt));
}
/// Resuming a workflow that is not parked at an OpenAI step must fail.
#[test]
fn open_ai_resume_rejects_wrong_current_step() {
    // The payload is well-formed; the rejection comes purely from the
    // workflow's current step.
    let failure = Trampoline::default()
        .resume_open_ai(Workflow::default(), "response", json!({"text": "done"}))
        .expect_err("resume should fail");
    let rendered = format!("{failure}");
    assert!(rendered.contains("not-suspended-at-openai"));
}
/// `HumanRequest::new` fills in a JSON `null` context and an empty-object
/// metadata alongside the supplied prompt.
#[test]
fn human_request_new_uses_null_context_and_empty_metadata() {
    let expected = HumanRequest {
        prompt: "Review the answer".into(),
        context: serde_json::Value::Null,
        metadata: json!({}),
    };
    let built = HumanRequest::new("Review the answer");
    assert_eq!(built, expected);
}
/// A scheduled human step suspends the run and reports the request and the
/// output key the answer will be stored under.
#[test]
fn human_step_suspends_with_workflow_request_and_output_key() {
    let runner = Trampoline::default();
    let mut pending = Workflow::default();
    pending.schedule(human_step(
        "Approve the ticket closure",
        "human_answer: String",
    ));

    let outcome = test_run_trampoline(&runner, pending).expect("run");
    // The suspension hands back the workflow (still on the human step), the
    // request to show, and the env key for the answer.
    let WorkflowResult::Human {
        workflow: parked,
        request: shown,
        output_key: answer_key,
    } = outcome
    else {
        panic!("workflow should suspend for human input");
    };

    assert!(matches!(parked.current_step, Step::Human { .. }));
    assert_eq!(shown, human_request("Approve the ticket closure"));
    assert_eq!(answer_key, "human_answer: String");
}
/// A continuation that asks a human suspends the run; resuming with the answer
/// stores it and invokes the named follow-up function.
#[test]
fn human_continuation_suspends_then_invokes_next_function() {
    const ANSWER_KEY: &str = "human_answer: String";

    let mut trampoline = Trampoline::default();
    // "entry" asks for approval and names "after" as the next function.
    trampoline.register(
        "entry",
        test_sync_call(|wf| {
            __with_continuation(
                wf,
                |_, continuation| -> Result<ContinuationChoice, handled::SError> {
                    let ask = human_request("Approve the ticket closure");
                    Ok(continuation.human(ask, ANSWER_KEY, "after"))
                },
            )
        }),
    );
    // "after" reads the human's answer and records a summary of it.
    trampoline.register(
        "after",
        test_sync_call(|wf| {
            let answer: String = wf.from_env(ANSWER_KEY).unwrap().unwrap();
            let summary = format!("accepted: {answer}");
            wf.into_env("after", summary)
                .map_err(|_| call_error("serialize"))?;
            Ok(())
        }),
    );

    let first = test_run_trampoline(&trampoline, Workflow::new("test", "entry")).expect("run");
    let WorkflowResult::Human {
        workflow: parked,
        request,
        output_key,
    } = first
    else {
        panic!("workflow should suspend for human input");
    };
    assert_eq!(request, human_request("Approve the ticket closure"));
    assert_eq!(output_key, ANSWER_KEY);

    let resumed = trampoline
        .resume_human(parked, output_key, "yes".to_string())
        .expect("resume");
    let second = test_run_trampoline(&trampoline, resumed).expect("run after resume");
    let WorkflowResult::Halt { workflow: finished } = second else {
        panic!("workflow should halt");
    };

    let stored: Option<String> = finished.from_env(ANSWER_KEY).unwrap();
    assert_eq!(stored, Some("yes".into()));
    let summary: Option<String> = finished.from_env("after").unwrap();
    assert_eq!(summary, Some("accepted: yes".into()));
}
/// A human answer may be any serializable value (here a JSON object). Resuming
/// stores it and lets the remaining scheduled call run.
#[test]
fn human_resume_stores_serializable_value_and_advances() {
    let mut trampoline = Trampoline::default();
    trampoline.register(
        "after",
        test_sync_call(|wf| {
            wf.into_env("after", true)
                .map_err(|_| call_error("serialize"))?;
            Ok(())
        }),
    );

    let mut pending = Workflow::default();
    pending.schedule(Step::Call {
        function: "after".into(),
    });
    pending.schedule(human_step("Fill out the review form", "human_answer"));

    let WorkflowResult::Human {
        workflow: parked,
        output_key,
        ..
    } = test_run_trampoline(&trampoline, pending).expect("run")
    else {
        panic!("workflow should suspend for human input");
    };

    let form = json!({"decision": "approved", "note": "looks correct"});
    let resumed = trampoline
        .resume_human(parked, output_key, form.clone())
        .expect("resume");
    let WorkflowResult::Halt { workflow: finished } =
        test_run_trampoline(&trampoline, resumed).expect("run after resume")
    else {
        panic!("workflow should halt");
    };

    let stored: Option<serde_json::Value> = finished.from_env("human_answer").unwrap();
    assert_eq!(stored, Some(form));
    assert_eq!(finished.from_env::<bool>("after").unwrap(), Some(true));
    assert!(matches!(finished.current_step, Step::Halt));
}
/// Resuming a workflow that is not parked at a human step must fail.
#[test]
fn human_resume_rejects_wrong_current_step() {
    let failure = Trampoline::default()
        .resume_human(Workflow::default(), "human_answer", "yes")
        .expect_err("resume should fail");
    let rendered = format!("{failure}");
    assert!(rendered.contains("not-suspended-at-human"));
}
/// Resuming a human suspension under a different output key must fail.
#[test]
fn human_resume_rejects_wrong_output_key() {
    // Suspend at a human step whose answer is expected under "expected"...
    let runner = Trampoline::default();
    let mut pending = Workflow::default();
    pending.schedule(human_step("Approve the ticket closure", "expected"));
    let outcome = test_run_trampoline(&runner, pending).expect("run");
    let WorkflowResult::Human {
        workflow: parked, ..
    } = outcome
    else {
        panic!("workflow should suspend for human input");
    };

    // ...then deliver the answer under a different key.
    let failure = Trampoline::default()
        .resume_human(parked, "actual", "yes")
        .expect_err("resume should fail");
    assert!(failure.to_string().contains("human-output-key-mismatch"));
}
/// A human response that fails to serialize is rejected instead of being
/// stored in the workflow env.
#[test]
fn human_resume_rejects_invalid_response_serialization() {
    // Response type whose `Serialize` impl always returns an error.
    struct InvalidResponse;

    impl serde::Serialize for InvalidResponse {
        fn serialize<S: serde::Serializer>(&self, _serializer: S) -> Result<S::Ok, S::Error> {
            Err(<S::Error as serde::ser::Error>::custom(
                "cannot serialize human response",
            ))
        }
    }

    let runner = Trampoline::default();
    let mut pending = Workflow::default();
    pending.schedule(human_step("Approve the ticket closure", "human_answer"));
    let WorkflowResult::Human {
        workflow: parked, ..
    } = test_run_trampoline(&runner, pending).expect("run")
    else {
        panic!("workflow should suspend for human input");
    };

    let failure = Trampoline::default()
        .resume_human(parked, "human_answer", InvalidResponse)
        .expect_err("resume should fail");
    assert!(failure.to_string().contains("invalid-human-response"));
}
/// A workflow suspended for human input survives a JSON round trip. It can
/// then be resumed on a separate trampoline that registers the follow-up
/// function.
#[test]
fn human_suspended_workflow_round_trips_through_serde() {
    // Phase 1: suspend at the human step using a trampoline with no functions.
    let suspender = Trampoline::default();
    let mut pending = Workflow::default();
    pending.schedule(Step::Call {
        function: "after".into(),
    });
    pending.schedule(human_step("Approve the ticket closure", "human_answer"));
    let first = test_run_trampoline(&suspender, pending).expect("run");
    let WorkflowResult::Human {
        workflow: parked,
        request,
        output_key,
    } = first
    else {
        panic!("workflow should suspend for human input");
    };
    assert_eq!(request, human_request("Approve the ticket closure"));
    assert_eq!(output_key, "human_answer");

    // Phase 2: persist the suspended workflow as JSON and load it back.
    let json_text = serde_json::to_string(&parked).expect("serialize workflow");
    let reloaded = serde_json::from_str::<Workflow>(&json_text).expect("deserialize workflow");

    // Phase 3: resume on a fresh trampoline that knows the "after" function.
    let mut resumer = Trampoline::default();
    resumer.register(
        "after",
        test_sync_call(|wf| {
            let answer: String = wf.from_env("human_answer").unwrap().unwrap();
            let approved = answer == "approved";
            wf.into_env("after", approved)
                .map_err(|_| call_error("serialize"))?;
            Ok(())
        }),
    );
    let resumed = resumer
        .resume_human(reloaded, "human_answer", "approved".to_string())
        .expect("resume");
    let WorkflowResult::Halt { workflow: finished } =
        test_run_trampoline(&resumer, resumed).expect("run after resume")
    else {
        panic!("workflow should halt");
    };
    assert_eq!(finished.from_env::<bool>("after").unwrap(), Some(true));
}
/// Test-only tool named `echo` that answers every invocation with a string
/// result of the form `echo <call id>`.
struct EchoTool;

impl Tool for EchoTool {
    fn name(&self) -> String {
        String::from("echo")
    }

    fn to_param(&self) -> ToolUnionParam {
        unimplemented!("tests do not advertise the tool to a model")
    }

    fn call<'a>(
        &'a self,
        id: ToolCallId,
        tool_use: &'a ToolUseBlock,
    ) -> Pin<Box<dyn Future<Output = ToolResultBlock> + Send + 'a>> {
        // Build everything up front so the future owns its data rather than
        // borrowing `tool_use`.
        let reply = format!("echo {id}");
        let use_id = tool_use.id.clone();
        Box::pin(async move { ToolResultBlock::new(use_id).with_string_content(reply) })
    }
}
/// Builds a `ToolUseBlock` with the given id and tool name and an empty JSON
/// object as its input.
fn tool_use_block(id: &str, name: &str) -> ToolUseBlock {
    let empty_input = json!({});
    ToolUseBlock::new(id, name, empty_input)
}
/// Builds a model `Message` (id `msg_tool`) whose only content is a single
/// tool-use block for `name` with the given tool-use `id`.
fn tool_use_response(id: &str, name: &str) -> Message {
    let content = vec![ContentBlock::ToolUse(tool_use_block(id, name))];
    let usage = Usage::new(1, 1);
    Message::new(
        "msg_tool".into(),
        content,
        KnownModel::ClaudeHaiku45.into(),
        usage,
    )
}
/// A tool-call continuation suspends the run and reports the requested tool
/// uses together with the output key for their results.
#[test]
fn tool_call_continuation_suspends_with_uses_and_output_key() {
    let mut trampoline = Trampoline::default();
    // "entry" requests one echo call, collecting results under "results".
    trampoline.register(
        "entry",
        test_sync_call(|wf| {
            __with_continuation(
                wf,
                |_, continuation| -> Result<ContinuationChoice, handled::SError> {
                    let uses = vec![tool_use_block("toolu_1", "echo")];
                    Ok(continuation.tool_call(uses, "results", "after"))
                },
            )
        }),
    );

    let start = Workflow::new("run", "entry");
    let outcome = test_run_trampoline(&trampoline, start).expect("run");
    let WorkflowResult::ToolCall {
        workflow: parked,
        tool_uses: requested,
        output_key,
    } = outcome
    else {
        panic!("workflow should suspend at tool call");
    };

    assert_eq!(output_key, "results");
    assert_eq!(requested.len(), 1);
    assert_eq!(requested[0].id, "toolu_1");
    assert_eq!(parked.run_id(), "run");
}
/// `run_tool_calls` dispatches each tool use to its registered tool. It returns
/// one result per use, in request order, each answering the matching use.
#[tokio::test]
async fn run_tool_calls_dispatches_registered_tools_in_order() {
    let mut trampoline = Trampoline::default();
    trampoline.register_tool(EchoTool);
    let uses = vec![
        tool_use_block("toolu_a", "echo"),
        tool_use_block("toolu_b", "echo"),
    ];
    let results = trampoline
        .run_tool_calls("run-7", &uses)
        .await
        .expect("dispatch");

    // One result per tool use, in the same order as the request.
    assert_eq!(results.len(), 2);
    assert_eq!(results[0].tool_use_id, "toolu_a");
    assert_eq!(results[1].tool_use_id, "toolu_b");

    // EchoTool echoes the ToolCallId it was handed. Checking every payload,
    // not just the first, ensures each use got its own call id, not a
    // duplicated or shuffled one.
    let expected = ["echo run-7:toolu_a", "echo run-7:toolu_b"];
    for (result, want) in results.iter().zip(expected) {
        match &result.content {
            Some(claudius::ToolResultBlockContent::String(text)) => assert_eq!(text, want),
            // Any other shape, including no content at all, is a failure.
            _ => panic!("expected string content"),
        }
    }
}
/// Dispatching a tool use whose tool is not registered fails. The error names
/// the missing tool.
#[tokio::test]
async fn run_tool_calls_reports_missing_tool() {
    let registry = Trampoline::default();
    let calls = vec![tool_use_block("toolu_x", "absent")];
    let message = registry
        .run_tool_calls("run", &calls)
        .await
        .expect_err("unregistered tool should error")
        .to_string();
    assert!(message.contains("missing-tool"));
    assert!(message.contains("absent"));
}
/// Resuming a tool-call suspension stores the results under the output key
/// and runs the continuation function named by the tool call.
#[test]
fn resume_tool_call_stores_results_and_advances() {
    let mut trampoline = Trampoline::default();
    trampoline.register(
        "entry",
        test_sync_call(|wf| {
            __with_continuation(
                wf,
                |_, continuation| -> Result<ContinuationChoice, handled::SError> {
                    let uses = vec![tool_use_block("toolu_1", "echo")];
                    Ok(continuation.tool_call(uses, "results", "after"))
                },
            )
        }),
    );
    // "after" records how many tool results it found under "results".
    trampoline.register(
        "after",
        test_sync_call(|wf| {
            let stored: Vec<ToolResultBlock> = wf.from_env("results").unwrap().unwrap();
            let count = stored.len() as u64;
            wf.into_env("count", count)
                .map_err(|_| call_error("serialize"))?;
            Ok(())
        }),
    );

    let first = test_run_trampoline(&trampoline, Workflow::new("run", "entry")).expect("run");
    let WorkflowResult::ToolCall {
        workflow: parked, ..
    } = first
    else {
        panic!("workflow should suspend at tool call");
    };

    let answers = vec![ToolResultBlock::new("toolu_1".into()).with_string_content("ok".into())];
    let resumed = trampoline
        .resume_tool_call(parked, "results", answers)
        .expect("resume");
    let WorkflowResult::Halt { workflow: finished } =
        test_run_trampoline(&trampoline, resumed).expect("run after resume")
    else {
        panic!("workflow should halt");
    };
    assert_eq!(finished.from_env::<u64>("count").unwrap(), Some(1));
}
/// Resuming a tool-call suspension under a key other than the one it was
/// suspended with must fail.
#[test]
fn resume_tool_call_rejects_wrong_output_key() {
    let mut trampoline = Trampoline::default();
    trampoline.register(
        "entry",
        test_sync_call(|wf| {
            __with_continuation(
                wf,
                |_, continuation| -> Result<ContinuationChoice, handled::SError> {
                    let uses = vec![tool_use_block("toolu_1", "echo")];
                    Ok(continuation.tool_call(uses, "results", "after"))
                },
            )
        }),
    );

    let WorkflowResult::ToolCall {
        workflow: parked, ..
    } = test_run_trampoline(&trampoline, Workflow::new("run", "entry")).expect("run")
    else {
        panic!("workflow should suspend at tool call");
    };

    // The suspension expects "results"; any other key is a mismatch.
    let failure = trampoline
        .resume_tool_call(parked, "wrong", Vec::new())
        .expect_err("wrong output key should error");
    assert!(failure
        .to_string()
        .contains("tool-call-output-key-mismatch"));
}
/// Resuming a workflow that has not suspended at a tool call must fail.
#[test]
fn resume_tool_call_rejects_non_tool_step() {
    let fresh = Workflow::new("run", "entry");
    let failure = Trampoline::default()
        .resume_tool_call(fresh, "results", Vec::new())
        .expect_err("non-tool step should error");
    let rendered = format!("{failure}");
    assert!(rendered.contains("not-suspended-at-tool-call"));
}
/// `dispatch_tool_uses` yields `Tools` for a response containing a tool-use
/// block and `Done` for a plain-text response.
#[test]
fn dispatch_tool_uses_raises_suspension_only_when_tools_called() {
    let tool_message = tool_use_response("toolu_1", "echo");
    let routed = dispatch_tool_uses(__new_continuation(), &tool_message, "results", "after");
    assert!(
        matches!(routed, ToolDispatch::Tools(_)),
        "tool_use response should dispatch tools"
    );

    let text_message = anthropic_response("done");
    let routed = dispatch_tool_uses(__new_continuation(), &text_message, "results", "after");
    assert!(
        matches!(routed, ToolDispatch::Done(_)),
        "text response should not dispatch tools"
    );
}
/// A workflow suspended at a tool call survives a JSON round trip and is still
/// parked at the tool-call step afterwards.
#[test]
fn tool_call_workflow_round_trips_through_serde() {
    let mut trampoline = Trampoline::default();
    trampoline.register(
        "entry",
        test_sync_call(|wf| {
            __with_continuation(
                wf,
                |_, continuation| -> Result<ContinuationChoice, handled::SError> {
                    let uses = vec![tool_use_block("toolu_1", "echo")];
                    Ok(continuation.tool_call(uses, "results", "after"))
                },
            )
        }),
    );

    let outcome = test_run_trampoline(&trampoline, Workflow::new("run", "entry")).expect("run");
    let WorkflowResult::ToolCall {
        workflow: parked, ..
    } = outcome
    else {
        panic!("workflow should suspend at tool call");
    };

    let encoded = serde_json::to_string(&parked).expect("serialize");
    let restored = serde_json::from_str::<Workflow>(&encoded).expect("deserialize");
    assert!(matches!(restored.current_step, Step::ToolCall { .. }));
}
}