use super::data_source::DataSource;
use super::prompt_handler::{ExecutionContext, WorkflowPromptHandler};
use super::sequential::SequentialWorkflow;
use super::workflow_step::WorkflowStep;
use crate::error::Result;
use crate::server::cancellation::RequestHandlerExtra;
use crate::server::tasks::TaskRouter;
use crate::server::PromptHandler;
#[cfg(test)]
use crate::types::Role;
use crate::types::{Content, GetPromptResult, PromptInfo, PromptMessage};
use async_trait::async_trait;
use serde_json::Value;
use std::collections::HashMap;
use std::sync::Arc;
// Task-variable keys under which workflow state is persisted.
const WORKFLOW_PROGRESS_KEY: &str = "_workflow.progress";
const WORKFLOW_PAUSE_REASON_KEY: &str = "_workflow.pause_reason";

/// Task-variable key that stores the result of the step named `step_name`.
fn workflow_result_key(step_name: &str) -> String {
    const PREFIX: &str = "_workflow.result.";
    let mut key = String::with_capacity(PREFIX.len() + step_name.len());
    key.push_str(PREFIX);
    key.push_str(step_name);
    key
}
/// Execution state of one workflow step as tracked during a run.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub(crate) enum StepStatus {
    /// Not attempted yet; the default state for fresh progress records.
    #[default]
    Pending,
    /// The step finished successfully.
    Completed,
    /// The step's tool call returned an error.
    Failed,
    /// The step was not executed.
    Skipped,
}

impl StepStatus {
    /// Lowercase label used in progress JSON and `_meta` step entries.
    fn as_str(self) -> &'static str {
        match self {
            StepStatus::Pending => "pending",
            StepStatus::Completed => "completed",
            StepStatus::Failed => "failed",
            StepStatus::Skipped => "skipped",
        }
    }
}
/// Why workflow execution stopped before all steps completed.
///
/// Serialized (camelCase) via [`PauseReason::to_value`] into task variables
/// and `_meta`, and narrated to the client in the handoff message.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) enum PauseReason {
    /// A step's parameters could not be resolved from prompt args / context.
    UnresolvableParams {
        blocked_step: String,
        missing_param: String,
        suggested_tool: String,
    },
    /// Resolved parameters were missing fields required by the tool schema.
    SchemaMismatch {
        blocked_step: String,
        missing_fields: Vec<String>,
        suggested_tool: String,
    },
    /// A tool call was executed and returned an error.
    ToolError {
        failed_step: String,
        error: String,
        /// Whether the failed step may simply be retried by the client.
        retryable: bool,
        suggested_tool: String,
    },
    /// A step needs output from an earlier step that failed or was skipped.
    UnresolvedDependency {
        blocked_step: String,
        missing_output: String,
        producing_step: String,
        suggested_tool: String,
    },
}
impl PauseReason {
fn to_value(&self) -> Value {
match self {
Self::UnresolvableParams {
blocked_step,
missing_param,
suggested_tool,
} => serde_json::json!({
"type": "unresolvableParams",
"blockedStep": blocked_step,
"missingParam": missing_param,
"suggestedTool": suggested_tool,
}),
Self::SchemaMismatch {
blocked_step,
missing_fields,
suggested_tool,
} => serde_json::json!({
"type": "schemaMismatch",
"blockedStep": blocked_step,
"missingFields": missing_fields,
"suggestedTool": suggested_tool,
}),
Self::ToolError {
failed_step,
error,
retryable,
suggested_tool,
} => serde_json::json!({
"type": "toolError",
"failedStep": failed_step,
"error": error,
"retryable": retryable,
"suggestedTool": suggested_tool,
}),
Self::UnresolvedDependency {
blocked_step,
missing_output,
producing_step,
suggested_tool,
} => serde_json::json!({
"type": "unresolvedDependency",
"blockedStep": blocked_step,
"missingOutput": missing_output,
"producingStep": producing_step,
"suggestedTool": suggested_tool,
}),
}
}
}
/// A prompt handler that wraps a [`WorkflowPromptHandler`] and mirrors
/// workflow execution into a task managed by a [`TaskRouter`], so clients can
/// observe progress, per-step results, and pause reasons.
pub struct TaskWorkflowPromptHandler {
    // Delegate that resolves parameters, executes steps, and builds messages.
    inner: WorkflowPromptHandler,
    // Router used to create, update, and complete the backing task.
    task_router: Arc<dyn TaskRouter>,
    // The workflow definition whose steps are executed.
    workflow: SequentialWorkflow,
}
// Manual `Debug`: only the workflow name is shown; the inner handler and the
// router are represented by placeholders (presumably because they do not
// implement `Debug` themselves — NOTE(review): confirm).
impl std::fmt::Debug for TaskWorkflowPromptHandler {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("TaskWorkflowPromptHandler")
            .field("workflow", &self.workflow.name())
            .field("inner", &"WorkflowPromptHandler")
            .finish()
    }
}
impl TaskWorkflowPromptHandler {
pub fn new(
inner: WorkflowPromptHandler,
task_router: Arc<dyn TaskRouter>,
workflow: SequentialWorkflow,
) -> Self {
Self {
inner,
task_router,
workflow,
}
}
fn build_initial_progress_typed(&self) -> Value {
let steps: Vec<Value> = self
.workflow
.steps()
.iter()
.map(|step| {
let mut step_obj = serde_json::Map::new();
step_obj.insert("name".to_string(), Value::String(step.name().to_string()));
if let Some(tool) = step.tool() {
step_obj.insert("tool".to_string(), Value::String(tool.name().to_string()));
}
step_obj.insert("status".to_string(), Value::String("pending".to_string()));
Value::Object(step_obj)
})
.collect();
let goal = format!("{}: {}", self.workflow.name(), self.workflow.description());
serde_json::json!({
"goal": goal,
"steps": steps,
"schemaVersion": 1
})
}
fn build_meta_map(
task_id: &str,
task_status: &str,
step_names: &[String],
step_statuses: &[StepStatus],
pause_reason: Option<&PauseReason>,
) -> serde_json::Map<String, Value> {
let steps: Vec<Value> = step_names
.iter()
.zip(step_statuses.iter())
.map(|(name, status)| {
serde_json::json!({
"name": name,
"status": status.as_str(),
})
})
.collect();
let mut meta = serde_json::Map::new();
meta.insert("task_id".to_string(), Value::String(task_id.to_string()));
meta.insert(
"task_status".to_string(),
Value::String(task_status.to_string()),
);
meta.insert("steps".to_string(), Value::Array(steps));
if let Some(reason) = pause_reason {
meta.insert("pause_reason".to_string(), reason.to_value());
}
meta
}
fn build_placeholder_args(step: &WorkflowStep, args: &HashMap<String, String>) -> String {
let mut map = serde_json::Map::new();
for (arg_name, data_source) in step.arguments() {
let value = match data_source {
DataSource::PromptArg(name) => {
if let Some(val) = args.get(name.as_str()) {
Value::String(val.clone())
} else {
Value::String(format!("<prompt arg {}>", name))
}
},
DataSource::StepOutput {
step: binding,
field: None,
} => Value::String(format!("<output from {}>", binding)),
DataSource::StepOutput {
step: binding,
field: Some(f),
} => Value::String(format!("<field '{}' from {}>", f, binding)),
DataSource::Constant(val) => val.clone(),
};
map.insert(arg_name.to_string(), value);
}
serde_json::to_string(&Value::Object(map)).unwrap_or_else(|_| "{}".to_string())
}
fn build_handoff_message(
&self,
step_statuses: &[StepStatus],
pause_reason: &PauseReason,
args: &HashMap<String, String>,
execution_context: &ExecutionContext,
) -> PromptMessage {
let mut text = String::new();
match pause_reason {
PauseReason::ToolError {
failed_step,
error,
retryable,
..
} => {
text.push_str(&format!("Step '{}' failed: {}.", failed_step, error));
if *retryable {
text.push_str(" This step is retryable.");
}
},
PauseReason::UnresolvableParams {
blocked_step,
missing_param,
..
} => {
text.push_str(&format!(
"Could not resolve parameter '{}' for step '{}'.",
missing_param, blocked_step
));
},
PauseReason::SchemaMismatch {
blocked_step,
missing_fields,
..
} => {
let fields = missing_fields.join(", ");
text.push_str(&format!(
"Step '{}' has missing required fields: {}.",
blocked_step, fields
));
},
PauseReason::UnresolvedDependency {
blocked_step,
missing_output,
producing_step,
..
} => {
text.push_str(&format!(
"Step '{}' depends on output '{}' from step '{}', which did not complete.",
blocked_step, missing_output, producing_step
));
},
}
text.push_str("\n\nTo continue the workflow, make these tool calls:\n\n");
let mut step_num = 1;
if let PauseReason::ToolError {
failed_step,
retryable: true,
..
} = pause_reason
{
for (idx, step) in self.workflow.steps().iter().enumerate() {
if step_statuses.get(idx) == Some(&StepStatus::Failed)
&& step.name().as_str() == failed_step.as_str()
{
let tool_name = step
.tool()
.map_or_else(|| "unknown".to_string(), |t| t.name().to_string());
let args_str =
match self
.inner
.resolve_tool_parameters(step, args, execution_context)
{
Ok(resolved) => serde_json::to_string(&resolved)
.unwrap_or_else(|_| "{}".to_string()),
Err(_) => Self::build_placeholder_args(step, args),
};
text.push_str(&format!(
"{}. Call {} with {}\n",
step_num, tool_name, args_str
));
if let Some(guidance) = step.guidance() {
let guidance_text =
WorkflowPromptHandler::substitute_arguments(guidance, args);
text.push_str(&format!(" Note: {}\n", guidance_text));
}
step_num += 1;
break;
}
}
}
for (idx, step) in self.workflow.steps().iter().enumerate() {
if step_statuses.get(idx) != Some(&StepStatus::Pending) {
continue;
}
let tool_name = step
.tool()
.map_or_else(|| "unknown".to_string(), |t| t.name().to_string());
let args_str = match self
.inner
.resolve_tool_parameters(step, args, execution_context)
{
Ok(resolved) => {
serde_json::to_string(&resolved).unwrap_or_else(|_| "{}".to_string())
},
Err(_) => Self::build_placeholder_args(step, args),
};
text.push_str(&format!(
"{}. Call {} with {}\n",
step_num, tool_name, args_str
));
if let Some(guidance) = step.guidance() {
let guidance_text = WorkflowPromptHandler::substitute_arguments(guidance, args);
text.push_str(&format!(" Note: {}\n", guidance_text));
}
step_num += 1;
}
PromptMessage::assistant(Content::text(text))
}
fn resolve_owner(&self, extra: &RequestHandlerExtra) -> String {
match &extra.auth_context {
Some(ctx) => {
self.task_router
.resolve_owner(Some(&ctx.subject), ctx.client_id.as_deref(), None)
},
None => self.task_router.resolve_owner(None, None, None),
}
}
}
/// Explains why parameter resolution failed for `step`.
///
/// If any argument reads the output of a step whose producer failed or was
/// skipped, that is reported as an [`PauseReason::UnresolvedDependency`];
/// otherwise the failure is attributed to the step's first declared argument
/// as generic [`PauseReason::UnresolvableParams`].
fn classify_resolution_failure(
    step: &WorkflowStep,
    all_steps: &[WorkflowStep],
    step_statuses: &[StepStatus],
) -> PauseReason {
    let blocked_step = step.name().to_string();
    let suggested_tool = step
        .tool()
        .map(|t| t.name().to_string())
        .unwrap_or_default();
    for data_source in step.arguments().values() {
        let DataSource::StepOutput {
            step: binding_name, ..
        } = data_source
        else {
            continue;
        };
        // Look for a producer of this binding that failed or was skipped.
        let broken_producer = all_steps.iter().enumerate().find(|(idx, producer)| {
            producer
                .binding()
                .is_some_and(|b| b.as_str() == binding_name.as_str())
                && matches!(
                    step_statuses.get(*idx),
                    Some(StepStatus::Failed | StepStatus::Skipped)
                )
        });
        if let Some((_, producer)) = broken_producer {
            return PauseReason::UnresolvedDependency {
                blocked_step,
                missing_output: binding_name.to_string(),
                producing_step: producer.name().to_string(),
                suggested_tool: producer
                    .tool()
                    .map(|t| t.name().to_string())
                    .unwrap_or_default(),
            };
        }
    }
    let missing_param = step
        .arguments()
        .keys()
        .next()
        .map_or_else(|| "unknown".to_string(), |k| k.to_string());
    PauseReason::UnresolvableParams {
        blocked_step,
        missing_param,
        suggested_tool,
    }
}
/// Returns a copy of `initial` progress with each step's `"status"` field
/// overwritten from `step_statuses`; extra entries on either side are left
/// untouched.
fn build_updated_progress(initial: &Value, step_statuses: &[StepStatus]) -> Value {
    let mut progress = initial.clone();
    if let Some(steps) = progress.get_mut("steps").and_then(Value::as_array_mut) {
        // `zip` stops at the shorter of the two sequences, matching the
        // index-checked original behavior.
        for (entry, status) in steps.iter_mut().zip(step_statuses.iter()) {
            if let Some(obj) = entry.as_object_mut() {
                obj.insert(
                    "status".to_string(),
                    Value::String(status.as_str().to_string()),
                );
            }
        }
    }
    progress
}
#[async_trait]
impl PromptHandler for TaskWorkflowPromptHandler {
    /// Runs the workflow eagerly, executing tool steps through the inner
    /// handler while mirroring progress, per-step results, and any pause
    /// reason into a task created via the task router.
    ///
    /// If task creation fails, execution falls back to the plain inner
    /// handler with no task tracking.
    async fn handle(
        &self,
        args: HashMap<String, String>,
        extra: RequestHandlerExtra,
    ) -> Result<GetPromptResult> {
        let owner_id = self.resolve_owner(&extra);
        let initial_progress = self.build_initial_progress_typed();
        // Create the backing task up front; a failure downgrades to untracked mode.
        let task_id = match self
            .task_router
            .create_workflow_task(self.workflow.name(), &owner_id, initial_progress.clone())
            .await
        {
            Ok(value) => value
                .get("task")
                .and_then(|t| t.get("taskId"))
                .and_then(|v| v.as_str())
                .map(String::from),
            Err(e) => {
                tracing::warn!(
                    "Task creation failed for workflow '{}', proceeding without task tracking: {}",
                    self.workflow.name(),
                    e
                );
                None
            },
        };
        let Some(task_id) = task_id else {
            // No task id: delegate entirely to the untracked inner handler.
            let result = self.inner.handle(args, extra).await?;
            return Ok(result);
        };
        let step_count = self.workflow.steps().len();
        let total_steps = step_count;
        let mut messages: Vec<PromptMessage> = Vec::new();
        let mut execution_context = ExecutionContext::new();
        let mut step_results: Vec<(String, Value)> = Vec::new();
        let mut step_statuses: Vec<StepStatus> = vec![StepStatus::Pending; step_count];
        let mut pause_reason: Option<PauseReason> = None;
        messages.push(self.inner.create_user_intent(&args));
        messages.push(self.inner.create_assistant_plan()?);
        for (idx, step) in self.workflow.steps().iter().enumerate() {
            // Honor cooperative cancellation between steps.
            if extra.is_cancelled() {
                tracing::warn!("Workflow cancelled at step: {}", step.name());
                return Err(crate::Error::internal(format!(
                    "Workflow '{}' cancelled at step {}",
                    self.workflow.name(),
                    step.name()
                )));
            }
            let progress_message = format!("Step {}/{}: {}", idx + 1, total_steps, step.name());
            // Progress reporting is best-effort; failures are logged only.
            if let Err(e) = extra
                .report_count(idx + 1, total_steps, Some(progress_message))
                .await
            {
                tracing::warn!("Failed to report workflow progress: {}", e);
            }
            if let Some(guidance_template) = step.guidance() {
                let guidance_text =
                    WorkflowPromptHandler::substitute_arguments(guidance_template, &args);
                messages.push(PromptMessage::assistant(Content::text(guidance_text)));
            }
            // Resources whose template bindings reference step outputs must be
            // fetched after the tool runs; everything else is fetched up front.
            let fetch_resources_after_tool =
                WorkflowPromptHandler::template_bindings_use_step_outputs(step.template_bindings());
            if !fetch_resources_after_tool
                && !step.resources().is_empty()
                && self
                    .inner
                    .fetch_step_resources(step, &args, &execution_context, &extra, &mut messages)
                    .await
                    .is_err()
            {
                // NOTE(review): a resource-fetch failure stops execution without
                // setting a pause reason — the task stays "working".
                break;
            }
            if step.is_resource_only() {
                // Resource-only step: no tool call, just fetch and mark done.
                messages.push(PromptMessage::assistant(Content::text(format!(
                    "I'll fetch the required resources for {}...",
                    step.name()
                ))));
                if fetch_resources_after_tool
                    && self
                        .inner
                        .fetch_step_resources(
                            step,
                            &args,
                            &execution_context,
                            &extra,
                            &mut messages,
                        )
                        .await
                        .is_err()
                {
                    break;
                }
                step_statuses[idx] = StepStatus::Completed;
                step_results.push((step.name().to_string(), Value::Null));
                continue;
            }
            match self
                .inner
                .create_tool_call_announcement(step, &args, &execution_context)
            {
                Err(_) => {
                    // Could not even announce the call: classify the resolution
                    // failure and pause.
                    pause_reason = Some(classify_resolution_failure(
                        step,
                        self.workflow.steps(),
                        &step_statuses,
                    ));
                    break;
                },
                Ok(announcement) => {
                    let Ok(params) =
                        self.inner
                            .resolve_tool_parameters(step, &args, &execution_context)
                    else {
                        // Announcement resolved but parameters did not — unexpected,
                        // treat it like any resolution failure.
                        tracing::warn!(
                            "resolve_tool_parameters failed unexpectedly for step '{}' \
                             after announcement succeeded",
                            step.name()
                        );
                        pause_reason = Some(classify_resolution_failure(
                            step,
                            self.workflow.steps(),
                            &step_statuses,
                        ));
                        break;
                    };
                    // Validate resolved params against the tool's schema before executing.
                    match self.inner.params_satisfy_tool_schema(step, &params) {
                        Err(e) => {
                            tracing::warn!(
                                "params_satisfy_tool_schema error for step '{}': {}",
                                step.name(),
                                e
                            );
                            pause_reason = Some(PauseReason::UnresolvableParams {
                                blocked_step: step.name().to_string(),
                                missing_param: "unknown".to_string(),
                                suggested_tool: step
                                    .tool()
                                    .map(|t| t.name().to_string())
                                    .unwrap_or_default(),
                            });
                            break;
                        },
                        Ok(ref missing) if !missing.is_empty() => {
                            // Schema check returned missing required fields.
                            let suggested_tool = step
                                .tool()
                                .map(|t| t.name().to_string())
                                .unwrap_or_default();
                            pause_reason = Some(PauseReason::SchemaMismatch {
                                blocked_step: step.name().to_string(),
                                missing_fields: missing.clone(),
                                suggested_tool,
                            });
                            break;
                        },
                        Ok(_) => {
                            messages.push(announcement);
                            match self
                                .inner
                                .execute_tool_step(step, &args, &execution_context, &extra)
                                .await
                            {
                                Ok(result) => {
                                    messages.push(PromptMessage::user(Content::text(format!(
                                        "Tool result:\n{}",
                                        serde_json::to_string_pretty(&result)
                                            .unwrap_or_else(|_| format!("{:?}", result))
                                    ))));
                                    step_results.push((step.name().to_string(), result.clone()));
                                    step_statuses[idx] = StepStatus::Completed;
                                    // Make the output available to later steps.
                                    if let Some(binding) = step.binding() {
                                        execution_context.store_binding(binding.clone(), result);
                                    }
                                    if fetch_resources_after_tool
                                        && self
                                            .inner
                                            .fetch_step_resources(
                                                step,
                                                &args,
                                                &execution_context,
                                                &extra,
                                                &mut messages,
                                            )
                                            .await
                                            .is_err()
                                    {
                                        break;
                                    }
                                },
                                Err(e) => {
                                    // Tool failure: record it and pause with a
                                    // (possibly retryable) ToolError.
                                    messages.push(PromptMessage::user(Content::text(format!(
                                        "Error executing tool: {}",
                                        e
                                    ))));
                                    let step_name = step.name().to_string();
                                    step_results.push((
                                        step_name.clone(),
                                        serde_json::json!({"error": e.to_string()}),
                                    ));
                                    step_statuses[idx] = StepStatus::Failed;
                                    let suggested_tool = step
                                        .tool()
                                        .map(|t| t.name().to_string())
                                        .unwrap_or_default();
                                    pause_reason = Some(PauseReason::ToolError {
                                        failed_step: step_name,
                                        error: e.to_string(),
                                        retryable: step.is_retryable(),
                                        suggested_tool,
                                    });
                                    break;
                                },
                            }
                        },
                    }
                },
            }
        }
        // If anything blocked the run, append a narrative handoff message
        // telling the client how to continue.
        if let Some(ref reason) = pause_reason {
            let handoff =
                self.build_handoff_message(&step_statuses, reason, &args, &execution_context);
            messages.push(handoff);
        }
        // Persist progress, per-step results, and the pause reason in one batch write.
        let updated_progress = build_updated_progress(&initial_progress, &step_statuses);
        let mut batch: HashMap<String, Value> = HashMap::new();
        batch.insert(WORKFLOW_PROGRESS_KEY.to_string(), updated_progress);
        for (step_name, result) in &step_results {
            batch.insert(workflow_result_key(step_name), result.clone());
        }
        if let Some(ref reason) = pause_reason {
            batch.insert(WORKFLOW_PAUSE_REASON_KEY.to_string(), reason.to_value());
        }
        let batch_value = serde_json::to_value(&batch).unwrap_or_else(|_| serde_json::json!({}));
        if let Err(e) = self
            .task_router
            .set_task_variables(&task_id, &owner_id, batch_value)
            .await
        {
            tracing::warn!(
                "Failed to batch-write task variables for workflow '{}': {}",
                self.workflow.name(),
                e
            );
        }
        // Auto-complete the task only when every step finished without pause.
        let mut task_status = "working";
        let all_completed =
            pause_reason.is_none() && step_statuses.iter().all(|s| *s == StepStatus::Completed);
        if all_completed {
            let completion_result = serde_json::json!({
                "completed": true,
                "steps_completed": step_count,
            });
            match self
                .task_router
                .complete_workflow_task(&task_id, &owner_id, completion_result)
                .await
            {
                Ok(_) => {
                    task_status = "completed";
                },
                Err(e) => {
                    tracing::warn!(
                        "Failed to auto-complete task for workflow '{}': {}",
                        self.workflow.name(),
                        e
                    );
                },
            }
        }
        let step_names: Vec<String> = self
            .workflow
            .steps()
            .iter()
            .map(|s| s.name().to_string())
            .collect();
        let meta = Self::build_meta_map(
            &task_id,
            task_status,
            &step_names,
            &step_statuses,
            pause_reason.as_ref(),
        );
        // Final progress report is best-effort; the error is deliberately ignored.
        let _ = extra
            .report_count(
                total_steps,
                total_steps,
                Some("Workflow execution complete".to_string()),
            )
            .await;
        let mut result = GetPromptResult {
            description: Some(self.workflow.description().to_string()),
            messages,
            _meta: None,
        };
        result._meta = Some(meta);
        Ok(result)
    }

    /// Prompt metadata is delegated unchanged to the wrapped handler.
    fn metadata(&self) -> Option<PromptInfo> {
        self.inner.metadata()
    }
}
#[cfg(test)]
mod tests {
    // Unit tests cover the pure helpers (meta/progress builders, pause
    // classification, placeholder args, handoff text) using a stub router;
    // the async `handle` path is not exercised here.
    use super::*;
    #[test]
    fn build_meta_map_working_with_pause_reason() {
        let step_names = vec!["validate".to_string(), "deploy".to_string()];
        let step_statuses = vec![StepStatus::Completed, StepStatus::Pending];
        let pause = PauseReason::ToolError {
            failed_step: "deploy".to_string(),
            error: "timeout".to_string(),
            retryable: true,
            suggested_tool: "deploy_service".to_string(),
        };
        let meta = TaskWorkflowPromptHandler::build_meta_map(
            "task-123",
            "working",
            &step_names,
            &step_statuses,
            Some(&pause),
        );
        assert_eq!(meta["task_id"], "task-123");
        assert_eq!(meta["task_status"], "working");
        let pr = meta.get("pause_reason").expect("should have pause_reason");
        assert_eq!(pr["type"], "toolError");
        assert_eq!(pr["failedStep"], "deploy");
        assert_eq!(pr["retryable"], true);
        let steps = meta["steps"].as_array().unwrap();
        assert_eq!(steps.len(), 2);
        assert_eq!(steps[0]["status"], "completed");
        assert_eq!(steps[1]["status"], "pending");
    }
    #[test]
    fn build_meta_map_completed_no_pause_reason() {
        let step_names = vec![
            "validate".to_string(),
            "deploy".to_string(),
            "notify".to_string(),
        ];
        let step_statuses = vec![
            StepStatus::Completed,
            StepStatus::Completed,
            StepStatus::Completed,
        ];
        let meta = TaskWorkflowPromptHandler::build_meta_map(
            "task-456",
            "completed",
            &step_names,
            &step_statuses,
            None,
        );
        assert_eq!(meta["task_id"], "task-456");
        assert_eq!(meta["task_status"], "completed");
        assert!(
            meta.get("pause_reason").is_none(),
            "completed task should not have pause_reason"
        );
        let steps = meta["steps"].as_array().unwrap();
        assert_eq!(steps.len(), 3);
        for step in steps {
            assert_eq!(step["status"], "completed");
        }
    }
    #[test]
    fn build_meta_map_with_step_statuses() {
        let step_names = vec![
            "a".to_string(),
            "b".to_string(),
            "c".to_string(),
            "d".to_string(),
        ];
        let step_statuses = vec![
            StepStatus::Completed,
            StepStatus::Failed,
            StepStatus::Skipped,
            StepStatus::Pending,
        ];
        let meta = TaskWorkflowPromptHandler::build_meta_map(
            "task-789",
            "working",
            &step_names,
            &step_statuses,
            None,
        );
        let steps = meta["steps"].as_array().unwrap();
        assert_eq!(steps[0]["status"], "completed");
        assert_eq!(steps[1]["status"], "failed");
        assert_eq!(steps[2]["status"], "skipped");
        assert_eq!(steps[3]["status"], "pending");
    }
    #[test]
    fn build_meta_map_empty_steps() {
        let meta = TaskWorkflowPromptHandler::build_meta_map("task-000", "working", &[], &[], None);
        assert_eq!(meta["task_id"], "task-000");
        assert_eq!(meta["task_status"], "working");
        let steps = meta["steps"].as_array().expect("steps should be an array");
        assert!(steps.is_empty());
    }
    #[test]
    fn build_initial_progress_typed_all_pending() {
        use super::super::handles::ToolHandle;
        let workflow = SequentialWorkflow::new("test_wf", "Test workflow")
            .step(WorkflowStep::new("validate", ToolHandle::new("checker")))
            .step(WorkflowStep::new("deploy", ToolHandle::new("deployer")))
            .step(
                WorkflowStep::fetch_resources("read_guide")
                    .with_resource("docs://guide")
                    .expect("valid URI"),
            );
        let inner = WorkflowPromptHandler::new(
            SequentialWorkflow::new("dummy", "dummy"),
            HashMap::new(),
            HashMap::new(),
            None,
        );
        let task_router: Arc<dyn TaskRouter> = Arc::new(DummyTaskRouter);
        let handler = TaskWorkflowPromptHandler::new(inner, task_router, workflow);
        let progress = handler.build_initial_progress_typed();
        assert_eq!(progress["goal"], "test_wf: Test workflow");
        assert_eq!(progress["schemaVersion"], 1);
        let steps = progress["steps"].as_array().unwrap();
        assert_eq!(steps.len(), 3);
        for step in steps {
            assert_eq!(step["status"], "pending");
        }
        assert_eq!(steps[0]["name"], "validate");
        assert_eq!(steps[0]["tool"], "checker");
        assert_eq!(steps[1]["name"], "deploy");
        assert_eq!(steps[1]["tool"], "deployer");
        assert_eq!(steps[2]["name"], "read_guide");
        assert!(
            steps[2].get("tool").is_none(),
            "resource-only step should not have tool"
        );
    }
    #[test]
    fn classify_resolution_failure_unresolved_dependency() {
        use super::super::handles::ToolHandle;
        let step_a = WorkflowStep::new("produce", ToolHandle::new("fetcher")).bind("data_out");
        let step_b = WorkflowStep::new("consume", ToolHandle::new("processor"))
            .arg("input", DataSource::from_step("data_out"));
        let all_steps = vec![step_a, step_b.clone()];
        let statuses = vec![StepStatus::Failed, StepStatus::Pending];
        let result = classify_resolution_failure(&step_b, &all_steps, &statuses);
        match result {
            PauseReason::UnresolvedDependency {
                blocked_step,
                missing_output,
                producing_step,
                suggested_tool,
            } => {
                assert_eq!(blocked_step, "consume");
                assert_eq!(missing_output, "data_out");
                assert_eq!(producing_step, "produce");
                assert_eq!(suggested_tool, "fetcher");
            },
            other => panic!("Expected UnresolvedDependency, got: {:?}", other),
        }
    }
    #[test]
    fn classify_resolution_failure_generic_unresolvable() {
        use super::super::handles::ToolHandle;
        let step = WorkflowStep::new("do_thing", ToolHandle::new("tool_x"))
            .arg("name", DataSource::prompt_arg("missing_arg"));
        let all_steps = vec![step.clone()];
        let statuses = vec![StepStatus::Pending];
        let result = classify_resolution_failure(&step, &all_steps, &statuses);
        match result {
            PauseReason::UnresolvableParams {
                blocked_step,
                missing_param,
                suggested_tool,
            } => {
                assert_eq!(blocked_step, "do_thing");
                assert_eq!(missing_param, "name");
                assert_eq!(suggested_tool, "tool_x");
            },
            other => panic!("Expected UnresolvableParams, got: {:?}", other),
        }
    }
    #[test]
    fn classify_resolution_failure_skipped_producer() {
        use super::super::handles::ToolHandle;
        let step_a = WorkflowStep::new("gather", ToolHandle::new("gatherer")).bind("info");
        let step_b = WorkflowStep::new("analyze", ToolHandle::new("analyzer"))
            .arg("data", DataSource::from_step("info"));
        let all_steps = vec![step_a, step_b.clone()];
        let statuses = vec![StepStatus::Skipped, StepStatus::Pending];
        let result = classify_resolution_failure(&step_b, &all_steps, &statuses);
        match result {
            PauseReason::UnresolvedDependency {
                blocked_step,
                missing_output,
                producing_step,
                suggested_tool,
            } => {
                assert_eq!(blocked_step, "analyze");
                assert_eq!(missing_output, "info");
                assert_eq!(producing_step, "gather");
                assert_eq!(suggested_tool, "gatherer");
            },
            other => panic!("Expected UnresolvedDependency, got: {:?}", other),
        }
    }
    #[test]
    fn pause_reason_to_value_all_variants() {
        let reason = PauseReason::UnresolvableParams {
            blocked_step: "step_a".to_string(),
            missing_param: "param_x".to_string(),
            suggested_tool: "tool_y".to_string(),
        };
        let val = reason.to_value();
        assert_eq!(val["type"], "unresolvableParams");
        assert_eq!(val["blockedStep"], "step_a");
        assert_eq!(val["missingParam"], "param_x");
        assert_eq!(val["suggestedTool"], "tool_y");
        let reason = PauseReason::SchemaMismatch {
            blocked_step: "step_b".to_string(),
            missing_fields: vec!["f1".to_string(), "f2".to_string()],
            suggested_tool: "tool_z".to_string(),
        };
        let val = reason.to_value();
        assert_eq!(val["type"], "schemaMismatch");
        assert_eq!(val["blockedStep"], "step_b");
        assert_eq!(val["missingFields"], serde_json::json!(["f1", "f2"]));
        let reason = PauseReason::ToolError {
            failed_step: "deploy".to_string(),
            error: "connection refused".to_string(),
            retryable: false,
            suggested_tool: "deploy_service".to_string(),
        };
        let val = reason.to_value();
        assert_eq!(val["type"], "toolError");
        assert_eq!(val["failedStep"], "deploy");
        assert_eq!(val["error"], "connection refused");
        assert_eq!(val["retryable"], false);
        let reason = PauseReason::UnresolvedDependency {
            blocked_step: "step_c".to_string(),
            missing_output: "data".to_string(),
            producing_step: "step_a".to_string(),
            suggested_tool: "fetch_data".to_string(),
        };
        let val = reason.to_value();
        assert_eq!(val["type"], "unresolvedDependency");
        assert_eq!(val["blockedStep"], "step_c");
        assert_eq!(val["missingOutput"], "data");
        assert_eq!(val["producingStep"], "step_a");
        assert_eq!(val["suggestedTool"], "fetch_data");
    }
    #[test]
    fn workflow_result_key_produces_correct_keys() {
        assert_eq!(workflow_result_key("validate"), "_workflow.result.validate");
        assert_eq!(workflow_result_key("deploy"), "_workflow.result.deploy");
    }
    #[test]
    fn build_updated_progress_applies_statuses() {
        let initial = serde_json::json!({
            "goal": "Test",
            "steps": [
                {"name": "a", "tool": "tool_a", "status": "pending"},
                {"name": "b", "tool": "tool_b", "status": "pending"},
                {"name": "c", "status": "pending"}
            ],
            "schemaVersion": 1
        });
        let statuses = vec![
            StepStatus::Completed,
            StepStatus::Failed,
            StepStatus::Pending,
        ];
        let updated = build_updated_progress(&initial, &statuses);
        let steps = updated["steps"].as_array().unwrap();
        assert_eq!(steps[0]["status"], "completed");
        assert_eq!(steps[1]["status"], "failed");
        assert_eq!(steps[2]["status"], "pending");
        assert_eq!(updated["goal"], "Test");
        assert_eq!(updated["schemaVersion"], 1);
    }
    // Stub router: only `resolve_owner`/`tool_requires_task`/`task_capabilities`
    // are callable; everything else panics if reached.
    struct DummyTaskRouter;
    #[async_trait]
    impl TaskRouter for DummyTaskRouter {
        async fn handle_task_call(
            &self,
            _tool_name: &str,
            _arguments: Value,
            _task_params: Value,
            _owner_id: &str,
            _progress_token: Option<Value>,
        ) -> Result<Value> {
            unimplemented!()
        }
        async fn handle_tasks_get(&self, _params: Value, _owner_id: &str) -> Result<Value> {
            unimplemented!()
        }
        async fn handle_tasks_result(&self, _params: Value, _owner_id: &str) -> Result<Value> {
            unimplemented!()
        }
        async fn handle_tasks_list(&self, _params: Value, _owner_id: &str) -> Result<Value> {
            unimplemented!()
        }
        async fn handle_tasks_cancel(&self, _params: Value, _owner_id: &str) -> Result<Value> {
            unimplemented!()
        }
        fn resolve_owner(
            &self,
            _subject: Option<&str>,
            _client_id: Option<&str>,
            _session_id: Option<&str>,
        ) -> String {
            "owner".to_string()
        }
        fn tool_requires_task(&self, _tool_name: &str, _tool_execution: Option<&Value>) -> bool {
            false
        }
        fn task_capabilities(&self) -> Value {
            serde_json::json!({})
        }
    }
    // Builds a handler whose inner handler shares the same workflow definition.
    fn make_handler(workflow: SequentialWorkflow) -> TaskWorkflowPromptHandler {
        let inner =
            WorkflowPromptHandler::new(workflow.clone(), HashMap::new(), HashMap::new(), None);
        let task_router: Arc<dyn TaskRouter> = Arc::new(DummyTaskRouter);
        TaskWorkflowPromptHandler::new(inner, task_router, workflow)
    }
    #[test]
    fn handoff_message_tool_error_retryable() {
        use super::super::handles::ToolHandle;
        let workflow = SequentialWorkflow::new("deploy_wf", "Deploy workflow")
            .step(WorkflowStep::new("validate", ToolHandle::new("checker")))
            .step(
                WorkflowStep::new("deploy", ToolHandle::new("deploy_service"))
                    .arg("region", DataSource::prompt_arg("region"))
                    .retryable(true),
            )
            .step(
                WorkflowStep::new("notify", ToolHandle::new("notify_team"))
                    .arg("result", DataSource::from_step("deploy_out")),
            );
        let handler = make_handler(workflow);
        let step_statuses = vec![
            StepStatus::Completed,
            StepStatus::Failed,
            StepStatus::Pending,
        ];
        let pause = PauseReason::ToolError {
            failed_step: "deploy".to_string(),
            error: "connection timeout".to_string(),
            retryable: true,
            suggested_tool: "deploy_service".to_string(),
        };
        let mut args = HashMap::new();
        args.insert("region".to_string(), "us-east-1".to_string());
        let ctx = ExecutionContext::new();
        let msg = handler.build_handoff_message(&step_statuses, &pause, &args, &ctx);
        assert_eq!(msg.role, Role::Assistant);
        let text = match &msg.content {
            Content::Text { text } => text.as_str(),
            _ => panic!("Expected text content"),
        };
        assert!(
            text.contains("Step 'deploy' failed: connection timeout."),
            "should contain failure description"
        );
        assert!(
            text.contains("This step is retryable."),
            "should note retryable"
        );
        assert!(
            text.contains("To continue the workflow, make these tool calls:"),
            "should contain continuation header"
        );
        assert!(
            text.contains("1. Call deploy_service with"),
            "retryable failed step should be first"
        );
        assert!(
            text.contains("2. Call notify_team with"),
            "pending step should follow"
        );
    }
    #[test]
    fn handoff_message_unresolvable_params() {
        use super::super::handles::ToolHandle;
        let workflow = SequentialWorkflow::new("wf", "Workflow").step(
            WorkflowStep::new("step_a", ToolHandle::new("tool_a"))
                .arg("x", DataSource::prompt_arg("missing")),
        );
        let handler = make_handler(workflow);
        let step_statuses = vec![StepStatus::Pending];
        let pause = PauseReason::UnresolvableParams {
            blocked_step: "step_a".to_string(),
            missing_param: "x".to_string(),
            suggested_tool: "tool_a".to_string(),
        };
        let msg = handler.build_handoff_message(
            &step_statuses,
            &pause,
            &HashMap::new(),
            &ExecutionContext::new(),
        );
        let text = match &msg.content {
            Content::Text { text } => text.as_str(),
            _ => panic!("Expected text content"),
        };
        assert!(
            text.contains("Could not resolve parameter 'x' for step 'step_a'."),
            "should describe unresolvable param. Got: {}",
            text
        );
        assert!(
            text.contains("1. Call tool_a with"),
            "should list remaining step"
        );
    }
    #[test]
    fn handoff_message_unresolved_dependency() {
        use super::super::handles::ToolHandle;
        let workflow = SequentialWorkflow::new("wf", "Workflow")
            .step(WorkflowStep::new("produce", ToolHandle::new("fetcher")).bind("data"))
            .step(
                WorkflowStep::new("consume", ToolHandle::new("processor"))
                    .arg("input", DataSource::from_step("data")),
            );
        let handler = make_handler(workflow);
        let step_statuses = vec![StepStatus::Failed, StepStatus::Pending];
        let pause = PauseReason::UnresolvedDependency {
            blocked_step: "consume".to_string(),
            missing_output: "data".to_string(),
            producing_step: "produce".to_string(),
            suggested_tool: "fetcher".to_string(),
        };
        let msg = handler.build_handoff_message(
            &step_statuses,
            &pause,
            &HashMap::new(),
            &ExecutionContext::new(),
        );
        let text = match &msg.content {
            Content::Text { text } => text.as_str(),
            _ => panic!("Expected text content"),
        };
        assert!(
            text.contains("Step 'consume' depends on output 'data' from step 'produce', which did not complete."),
            "should mention dependency. Got: {}",
            text
        );
        assert!(text.contains("produce"), "should mention producing step");
    }
    #[test]
    fn handoff_message_schema_mismatch() {
        use super::super::handles::ToolHandle;
        let workflow = SequentialWorkflow::new("wf", "Workflow")
            .step(WorkflowStep::new("step_b", ToolHandle::new("tool_z")));
        let handler = make_handler(workflow);
        let step_statuses = vec![StepStatus::Pending];
        let pause = PauseReason::SchemaMismatch {
            blocked_step: "step_b".to_string(),
            missing_fields: vec!["field_1".to_string(), "field_2".to_string()],
            suggested_tool: "tool_z".to_string(),
        };
        let msg = handler.build_handoff_message(
            &step_statuses,
            &pause,
            &HashMap::new(),
            &ExecutionContext::new(),
        );
        let text = match &msg.content {
            Content::Text { text } => text.as_str(),
            _ => panic!("Expected text content"),
        };
        assert!(
            text.contains("Step 'step_b' has missing required fields: field_1, field_2."),
            "should list missing fields. Got: {}",
            text
        );
    }
    #[test]
    fn handoff_message_no_task_id_in_text() {
        use super::super::handles::ToolHandle;
        let workflow = SequentialWorkflow::new("wf", "Workflow")
            .step(WorkflowStep::new("s1", ToolHandle::new("t1")))
            .step(WorkflowStep::new("s2", ToolHandle::new("t2")));
        let handler = make_handler(workflow);
        let step_statuses = vec![StepStatus::Completed, StepStatus::Pending];
        let pause = PauseReason::ToolError {
            failed_step: "s1".to_string(),
            error: "oops".to_string(),
            retryable: false,
            suggested_tool: "t1".to_string(),
        };
        let msg = handler.build_handoff_message(
            &step_statuses,
            &pause,
            &HashMap::new(),
            &ExecutionContext::new(),
        );
        let text = match &msg.content {
            Content::Text { text } => text.as_str(),
            _ => panic!("Expected text content"),
        };
        assert!(
            !text.contains("task_id"),
            "narrative should not contain task_id"
        );
        assert!(
            !text.contains("task-"),
            "narrative should not contain task- prefix"
        );
    }
    #[test]
    fn placeholder_args_step_output() {
        use super::super::handles::ToolHandle;
        let step = WorkflowStep::new("consume", ToolHandle::new("processor"))
            .arg("data", DataSource::from_step("result_binding"))
            .arg("name", DataSource::prompt_arg("user_name"))
            .arg("flag", DataSource::constant(serde_json::json!(true)))
            .arg(
                "detail",
                DataSource::from_step_field("other_binding", "nested_field"),
            );
        let mut args = HashMap::new();
        args.insert("user_name".to_string(), "Alice".to_string());
        let result = TaskWorkflowPromptHandler::build_placeholder_args(&step, &args);
        let parsed: serde_json::Value = serde_json::from_str(&result).expect("valid JSON");
        let obj = parsed.as_object().expect("should be an object");
        assert_eq!(
            obj["data"], "<output from result_binding>",
            "StepOutput without field should use placeholder"
        );
        assert_eq!(
            obj["name"], "Alice",
            "PromptArg with available value should resolve"
        );
        assert_eq!(obj["flag"], true, "Constant should serialize as-is");
        assert_eq!(
            obj["detail"], "<field 'nested_field' from other_binding>",
            "StepOutput with field should use field placeholder"
        );
    }
    #[test]
    fn handoff_includes_guidance() {
        use super::super::handles::ToolHandle;
        let workflow = SequentialWorkflow::new("wf", "Workflow")
            .step(WorkflowStep::new("validate", ToolHandle::new("checker")))
            .step(
                WorkflowStep::new("deploy", ToolHandle::new("deploy_service"))
                    .arg("region", DataSource::prompt_arg("region"))
                    .with_guidance("Deploy to the '{region}' region with validated config"),
            );
        let handler = make_handler(workflow);
        let step_statuses = vec![StepStatus::Completed, StepStatus::Pending];
        let pause = PauseReason::ToolError {
            failed_step: "validate".to_string(),
            error: "check failed".to_string(),
            retryable: false,
            suggested_tool: "checker".to_string(),
        };
        let mut args = HashMap::new();
        args.insert("region".to_string(), "us-west-2".to_string());
        let msg =
            handler.build_handoff_message(&step_statuses, &pause, &args, &ExecutionContext::new());
        let text = match &msg.content {
            Content::Text { text } => text.as_str(),
            _ => panic!("Expected text content"),
        };
        assert!(
            text.contains("Note: Deploy to the 'us-west-2' region with validated config"),
            "should include guidance with substituted args. Got: {}",
            text
        );
    }
}