use std::collections::VecDeque;
use std::sync::{Arc, Mutex};
use async_trait::async_trait;
use either::Either;
use pravah::clients::{
Client, ClientError, ClientFactory, ClientOptions, ClientOutput, ClientResponse, Provider,
ToolCall,
};
use pravah::commons::Agent;
use pravah::flows::{AgentConfig, Flow, FlowError, FlowGraph, FlowRuntime, RunOut};
use pravah::tools::{Tool, ToolBox, ToolError};
use pravah::Context;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_json::json;
/// Test double for an LLM client: replays canned responses in FIFO order.
struct MockHandle {
    /// Queue of scripted responses; `MockFactory::create` hands each handle a clone of
    /// the same `Arc`, so all handles drain one shared queue.
    responses: Arc<Mutex<VecDeque<ClientResponse>>>,
}

#[async_trait]
impl Client for MockHandle {
    /// Returns the next scripted response; the incoming messages are ignored.
    ///
    /// # Errors
    /// Yields `ClientError::Llm` once the queue has been drained.
    ///
    /// # Panics
    /// Panics if the queue mutex is poisoned.
    async fn execute(
        &self,
        _messages: &[pravah::clients::Message],
    ) -> Result<ClientResponse, ClientError> {
        // No `.await` happens while the guard is alive, so holding a std mutex is fine here.
        let mut queue = self.responses.lock().unwrap();
        match queue.pop_front() {
            Some(scripted) => Ok(scripted),
            None => Err(ClientError::Llm("mock: response queue exhausted".into())),
        }
    }
}
/// Client factory whose clients all replay one shared, pre-scripted response queue.
struct MockFactory {
    /// Scripted responses, shared (via `Arc`) with every handle this factory creates.
    responses: Arc<Mutex<VecDeque<ClientResponse>>>,
}

impl MockFactory {
    /// Builds a factory that will replay `responses` front to back.
    fn new(responses: Vec<ClientResponse>) -> Self {
        let queue = VecDeque::from(responses);
        Self {
            responses: Arc::new(Mutex::new(queue)),
        }
    }
}

impl ClientFactory for MockFactory {
    /// Always succeeds; the URL and options are ignored and the returned client reads
    /// from the factory's shared queue.
    fn create(&self, _url: &str, _opts: ClientOptions) -> Result<Box<dyn Client>, ClientError> {
        let shared = Arc::clone(&self.responses);
        let handle: Box<dyn Client> = Box::new(MockHandle { responses: shared });
        Ok(handle)
    }
}
/// Wraps `val` in a `ClientOutput::Output` response attributed to the OpenAI provider.
fn resp(val: serde_json::Value) -> ClientResponse {
    let output = ClientOutput::Output(val);
    ClientResponse::new(Provider::OpenAi, output)
}
/// Builds a tool-call response (no accompanying thought) attributed to the OpenAI provider.
fn tool_resp(calls: Vec<ToolCall>) -> ClientResponse {
    let output = ClientOutput::ToolCalls {
        thought: None,
        calls,
    };
    ClientResponse::new(Provider::OpenAi, output)
}
/// Creates a `ToolCall` for tool `name` with the given arguments.
///
/// The call id is derived from the tool name (`id-<name>`), so two calls to the same
/// tool built with this helper share an id.
fn call(name: &str, args: serde_json::Value) -> ToolCall {
    let id = format!("id-{name}");
    ToolCall {
        id,
        name: name.to_owned(),
        args,
        thought_signatures: None,
    }
}
/// Creates a fresh `Context` whose working directory is the system temp dir; every
/// other `FlowConf` setting keeps its default.
fn ctx() -> Context {
    let conf = pravah::FlowConf {
        working_dir: Some(std::env::temp_dir()),
        ..Default::default()
    };
    Context::new(conf)
}
/// Drives a flow runtime with repeated `next()` calls until it reports `RunOut::Done`,
/// and evaluates to the value carried by `Done`.
///
/// Must be expanded inside an `async` context because it awaits `next()`. The argument
/// has to be a mutable place expression (e.g. a `let mut rt` binding); it is evaluated
/// exactly once.
///
/// # Panics
/// - if `next()` returns an error,
/// - if the flow suspends,
/// - if the flow has not finished after `MAX_STEPS` steps. This guards against a
///   runtime that keeps returning `Continue`, which would otherwise hang the test run.
macro_rules! run_to_done {
    ($rt:expr) => {{
        // Generous upper bound: the flows in this file finish within a handful of steps.
        const MAX_STEPS: usize = 10_000;
        // Bind once so the argument expression is not re-evaluated on every iteration.
        let runtime = &mut $rt;
        let c = ctx();
        let mut steps: usize = 0;
        loop {
            assert!(
                steps < MAX_STEPS,
                "flow did not reach Done within {} steps",
                MAX_STEPS
            );
            steps += 1;
            match runtime.next(c.clone()).await.expect("next() failed") {
                RunOut::Continue => {}
                RunOut::Done(v) => break v,
                RunOut::Suspend { .. } => panic!("unexpected suspension"),
            }
        }
    }};
}
// Input of the single-agent approval flow. The same type implements both `Agent` and
// `Flow` further down in this file.
// (Plain `//` comments are used on purpose: doc comments on `JsonSchema` types would
// end up in the generated schema.)
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
struct SimpleIn {
// Free-form task text; the tests pass values such as "deploy" or "test".
task: String,
}
// Output produced by the `SimpleIn` agent/flow.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq)]
struct SimpleOut {
// Signer name; the tests assert on this field once the flow is done.
signed_by: String,
}
// Arguments of the `gate` tool. (Plain comment on purpose: doc comments on
// `JsonSchema` types would end up in the generated schema.)
#[derive(Debug, Deserialize, JsonSchema)]
struct GateTool {
    reason: String,
}

impl Tool for GateTool {
    type Output = serde_json::Value;

    /// Tool name as referenced by the scripted tool calls in the tests.
    fn name() -> &'static str {
        "gate"
    }

    /// Human-readable tool description.
    fn description() -> &'static str {
        "Block until an external reviewer approves"
    }

    /// Never returns `Ok`: always requests a suspension, echoing `reason` in the payload.
    async fn call(self, _ctx: Context) -> Result<Self::Output, ToolError> {
        let Self { reason } = self;
        let payload = json!({"reason": reason});
        Err(ToolError::suspend(payload))
    }
}
impl Agent for SimpleIn {
    type Output = SimpleOut;

    /// Configures the approval agent against the test model URL with `GateTool` as its
    /// only registered tool.
    fn build() -> AgentConfig {
        let tools = ToolBox::builder().tool::<GateTool>().build();
        AgentConfig::new("Approval agent", "openai://test-model").with_tools(tools)
    }
}
impl Flow for SimpleIn {
    type Output = SimpleOut;

    /// A flow graph consisting of a single agent node (this type's own `Agent` impl).
    fn build() -> Result<FlowGraph, FlowError> {
        FlowGraph::builder()
            .agent::<Self>()
            .build()
    }
}
// Entry input of the three-stage `SetupIn` flow (work -> agent -> work).
// (Plain `//` comments are used on purpose: doc comments on `JsonSchema` types would
// end up in the generated schema.)
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
struct SetupIn {
// Interpolated into the task prompt as `process seed=<seed>` by the first work step.
seed: i32,
}
// Input of the task agent (this type implements `Agent`); produced from `SetupIn`.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
struct TaskIn {
prompt: String,
}
// Output of the task agent and input of the final work step.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
struct ReviewIn {
// Upper-cased by the final work step to form `ReviewOut::summary`.
result: String,
}
// Final output of the `SetupIn` flow.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq)]
struct ReviewOut {
summary: String,
}
// Arguments of the `check` tool. (Plain comment on purpose: doc comments on
// `JsonSchema` types would end up in the generated schema.)
#[derive(Debug, Deserialize, JsonSchema)]
struct CheckTool {
    note: String,
}

impl Tool for CheckTool {
    type Output = serde_json::Value;

    /// Tool name as referenced by the scripted tool calls in the tests.
    fn name() -> &'static str {
        "check"
    }

    /// Human-readable tool description.
    fn description() -> &'static str {
        "Check item and suspend"
    }

    /// Never returns `Ok`: always requests a suspension, echoing `note` in the payload.
    async fn call(self, _ctx: Context) -> Result<Self::Output, ToolError> {
        let Self { note } = self;
        let payload = json!({"note": note});
        Err(ToolError::suspend(payload))
    }
}
impl Agent for TaskIn {
    type Output = ReviewIn;

    /// Configures the task agent against the test model URL with `CheckTool` as its
    /// only registered tool.
    fn build() -> AgentConfig {
        let tools = ToolBox::builder().tool::<CheckTool>().build();
        AgentConfig::new("task agent", "openai://test-model").with_tools(tools)
    }
}
impl Flow for SetupIn {
    type Output = ReviewOut;

    /// Three-stage pipeline: a work step turns the seed into a prompt, the `TaskIn`
    /// agent runs, and a final work step upper-cases the agent's result.
    fn build() -> Result<FlowGraph, FlowError> {
        FlowGraph::builder()
            .work::<Self, TaskIn, _, _>(|setup, _| async move {
                let prompt = format!("process seed={}", setup.seed);
                Ok(TaskIn { prompt })
            })
            .agent::<TaskIn>()
            .work::<ReviewIn, ReviewOut, _, _>(|review, _| async move {
                let summary = review.result.to_uppercase();
                Ok(ReviewOut { summary })
            })
            .build()
    }
}
// Entry input of the fork/join flow.
// (Plain `//` comments are used on purpose: doc comments on `JsonSchema` types would
// end up in the generated schema.)
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
struct SplitIn {
n: i32,
}
// First fork branch; receives `n` unchanged and goes straight to the join.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
struct BranchAlpha {
n: i32,
}
// Second fork branch; receives `n * 3` and is processed by a work step before the join.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
struct BranchBeta {
n: i32,
}
// Result of the beta work step (`BranchBeta::n + 1`).
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
struct ProcessedBeta {
n: i32,
}
// Join result: sum of the alpha value and the processed beta value.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq)]
struct MergedOut {
total: i32,
}
impl Flow for SplitIn {
    type Output = MergedOut;

    /// Fork/join graph: the input is split into an alpha branch (`n`) and a beta branch
    /// (`n * 3`); beta is incremented by a work step, then both are summed at the join.
    fn build() -> Result<FlowGraph, FlowError> {
        FlowGraph::builder()
            .fork::<Self, BranchAlpha, BranchBeta, _>(|split, _| {
                let alpha = BranchAlpha { n: split.n };
                let beta = BranchBeta { n: split.n * 3 };
                Ok((alpha, beta))
            })
            .work::<BranchBeta, ProcessedBeta, _, _>(|beta, _| async move {
                let n = beta.n + 1;
                Ok(ProcessedBeta { n })
            })
            .join::<BranchAlpha, ProcessedBeta, MergedOut, _>(|alpha, processed, _| {
                let total = alpha.n + processed.n;
                Ok(MergedOut { total })
            })
            .build()
    }
}
// Entry input of the either/or routing flow.
// (Plain `//` comments are used on purpose: doc comments on `JsonSchema` types would
// end up in the generated schema.)
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
struct RouteIn {
// `true` selects the `LeftPath` branch, `false` the `RightPath` branch.
go_left: bool,
}
// Payload of the left branch; the router sets `tag` to "L".
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
struct LeftPath {
tag: String,
}
// Payload of the right branch; the router sets `tag` to "R".
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
struct RightPath {
tag: String,
}
// Final output of the routing flow, produced by whichever branch ran.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq)]
struct RouteOut {
// Tag copied from the branch payload.
tag: String,
// `true` when the left branch's work step produced this value.
from_left: bool,
}
impl Flow for RouteIn {
    type Output = RouteOut;

    /// Routing graph: `go_left` picks exactly one of two branches, and each branch has
    /// its own work step that converts the branch payload into a `RouteOut`.
    fn build() -> Result<FlowGraph, FlowError> {
        FlowGraph::builder()
            .either::<Self, LeftPath, RightPath, _>(|route, _| {
                let branch = if route.go_left {
                    Either::Left(LeftPath { tag: String::from("L") })
                } else {
                    Either::Right(RightPath { tag: String::from("R") })
                };
                Ok(branch)
            })
            .work::<LeftPath, RouteOut, _, _>(|left, _| async move {
                let LeftPath { tag } = left;
                Ok(RouteOut { tag, from_left: true })
            })
            .work::<RightPath, RouteOut, _, _>(|right, _| async move {
                let RightPath { tag } = right;
                Ok(RouteOut { tag, from_left: false })
            })
            .build()
    }
}
/// A single-agent flow that suspends on its `gate` tool call can be resumed and then
/// run through to `Done`.
#[tokio::test]
async fn simple_suspend_and_resume_completes() {
    // Scripted model replies, consumed in order: a `gate` call, then a `submit` call.
    let script = vec![
        tool_resp(vec![call("gate", json!({"reason": "needs approval"}))]),
        tool_resp(vec![call("submit", json!({"signed_by": "alice"}))]),
    ];
    let factory = MockFactory::new(script);
    let input = SimpleIn { task: "deploy".into() };
    let mut rt = FlowRuntime::new(input).unwrap().with_factory(factory);

    // The first step must suspend and report which tool call is pending.
    let out = rt.next(ctx()).await.expect("next() failed");
    let tool_id = match out {
        RunOut::Suspend { tool_id, .. } => tool_id,
        other => panic!("expected Suspend, got {other:?}"),
    };
    let mentions_gate = tool_id.contains("gate");
    assert!(mentions_gate, "tool_id should mention 'gate': {tool_id}");

    // Resuming with the matching tool id lets the flow continue.
    let reply = (tool_id, json!({"approved": true}));
    let after_resume = rt.resume(ctx(), reply).await.expect("resume() failed");
    let resumed_ok = matches!(after_resume, RunOut::Continue);
    assert!(
        resumed_ok,
        "expected Continue after resume, got {after_resume:?}"
    );

    // One more intermediate step, then the final output.
    assert!(matches!(rt.next(ctx()).await.unwrap(), RunOut::Continue));
    let last = rt.next(ctx()).await.unwrap();
    match last {
        RunOut::Done(out) => assert_eq!(out.signed_by, "alice"),
        other => panic!("expected Done, got {other:?}"),
    }
}
/// Calling `next()` again after the flow's `gate` call, without resuming first, must
/// fail with `FlowError::ResumeRequired`.
#[tokio::test]
async fn simple_suspend_next_without_resume_errors() {
    // Only one scripted reply is needed: the `gate` call.
    let script = vec![tool_resp(vec![call("gate", json!({"reason": "gate"}))])];
    let factory = MockFactory::new(script);
    let input = SimpleIn { task: "test".into() };
    let mut rt = FlowRuntime::new(input).unwrap().with_factory(factory);

    // First step; its outcome is not inspected here.
    let _ = rt.next(ctx()).await.unwrap();

    // Second step without an intervening `resume()`.
    let err = rt.next(ctx()).await.unwrap_err();
    let is_resume_required = matches!(err, FlowError::ResumeRequired(_));
    assert!(
        is_resume_required,
        "expected ResumeRequired, got {err:?}"
    );
}
/// After the flow's `gate` call, resuming with a tool id that does not belong to the
/// pending call must fail with `FlowError::ResumeMismatchError`.
#[tokio::test]
async fn simple_resume_with_wrong_tool_id_errors() {
    let script = vec![tool_resp(vec![call("gate", json!({"reason": "gate"}))])];
    let factory = MockFactory::new(script);
    let input = SimpleIn { task: "test".into() };
    let mut rt = FlowRuntime::new(input).unwrap().with_factory(factory);

    // First step; its outcome is not inspected here.
    let _ = rt.next(ctx()).await.unwrap();

    // Deliberately bogus tool id.
    let outcome = rt
        .resume(ctx(), ("WrongAgent::wrong_tool".into(), json!({})))
        .await;
    let err = outcome.unwrap_err();
    let is_mismatch = matches!(err, FlowError::ResumeMismatchError(_));
    assert!(
        is_mismatch,
        "expected ResumeMismatchError, got {err:?}"
    );
}
/// Calling `resume()` after a step that was scripted with a plain output (no `gate`
/// call) must fail with `FlowError::UnexpectedResumption`.
#[tokio::test]
async fn simple_resume_when_not_suspended_errors() {
    // A plain output response instead of a tool call.
    let script = vec![resp(json!({"signed_by": "bot"}))];
    let factory = MockFactory::new(script);
    let input = SimpleIn { task: "test".into() };
    let mut rt = FlowRuntime::new(input).unwrap().with_factory(factory);

    // First step; its outcome is not inspected here.
    let _ = rt.next(ctx()).await.unwrap();

    let outcome = rt
        .resume(ctx(), ("SimpleIn::gate".into(), json!({})))
        .await;
    let err = outcome.unwrap_err();
    let is_unexpected = matches!(err, FlowError::UnexpectedResumption(_));
    assert!(
        is_unexpected,
        "expected UnexpectedResumption, got {err:?}"
    );
}
/// Suspension inside the agent stage of the three-stage `SetupIn` flow: the flow can be
/// resumed and the remaining stages still run to `Done`.
#[tokio::test]
async fn nested_suspend_and_resume_completes() {
    // Scripted model replies, consumed in order: a `check` call, then a `submit` call.
    let script = vec![
        tool_resp(vec![call("check", json!({"note": "needs review"}))]),
        tool_resp(vec![call("submit", json!({"result": "ok"}))]),
    ];
    let factory = MockFactory::new(script);
    let mut rt = FlowRuntime::new(SetupIn { seed: 7 })
        .unwrap()
        .with_factory(factory);

    // Step 1 continues; step 2 is expected to suspend on the `check` tool.
    assert!(matches!(rt.next(ctx()).await.unwrap(), RunOut::Continue));
    let out = rt.next(ctx()).await.expect("step 2 failed");
    let tool_id = match out {
        RunOut::Suspend { tool_id, .. } => tool_id,
        other => panic!("expected Suspend after agent step, got {other:?}"),
    };
    let mentions_check = tool_id.contains("check");
    assert!(mentions_check, "tool_id should contain 'check': {tool_id}");

    let reply = (tool_id, json!({"approved": true}));
    let after_resume = rt.resume(ctx(), reply).await.expect("resume failed");
    let resumed_ok = matches!(after_resume, RunOut::Continue);
    assert!(
        resumed_ok,
        "expected Continue after resume, got {after_resume:?}"
    );

    // Two more intermediate steps, then the upper-cased result.
    for _ in 0..2 {
        assert!(matches!(rt.next(ctx()).await.unwrap(), RunOut::Continue));
    }
    let last = rt.next(ctx()).await.unwrap();
    match last {
        RunOut::Done(out) => assert_eq!(out.summary, "OK"),
        other => panic!("expected Done, got {other:?}"),
    }
}
/// Same three-stage flow as the previous nested test, but resumed with a `null`
/// payload; the final summary must still be the upper-cased agent result.
#[tokio::test]
async fn nested_suspend_state_is_preserved_through_resume() {
    let script = vec![
        tool_resp(vec![call("check", json!({"note": "verify"}))]),
        tool_resp(vec![call("submit", json!({"result": "completed"}))]),
    ];
    let factory = MockFactory::new(script);
    let mut rt = FlowRuntime::new(SetupIn { seed: 42 })
        .unwrap()
        .with_factory(factory);

    // Step 1 is not inspected; step 2 must suspend.
    let _ = rt.next(ctx()).await.unwrap();
    let step2 = rt.next(ctx()).await.unwrap();
    let tool_id = match step2 {
        RunOut::Suspend { tool_id, .. } => tool_id,
        other => panic!("{other:?}"),
    };

    let reply = (tool_id, json!(null));
    let _ = rt.resume(ctx(), reply).await.unwrap();

    // Two intermediate steps whose outcomes are not inspected, then the final output.
    for _ in 0..2 {
        let _ = rt.next(ctx()).await.unwrap();
    }
    let last = rt.next(ctx()).await.unwrap();
    match last {
        RunOut::Done(out) => assert_eq!(out.summary, "COMPLETED"),
        other => panic!("expected Done, got {other:?}"),
    }
}
/// For `n = 4` the alpha branch carries 4 and the beta branch 4 * 3 + 1 = 13, so the
/// join must produce a total of 17.
#[tokio::test]
async fn fork_join_produces_correct_merged_value() {
    let mut rt = FlowRuntime::new(SplitIn { n: 4 }).unwrap();

    // Three intermediate steps are expected before the flow finishes.
    for _ in 0..3 {
        assert!(matches!(rt.next(ctx()).await.unwrap(), RunOut::Continue));
    }

    let expected = MergedOut { total: 17 };
    let last = rt.next(ctx()).await.unwrap();
    match last {
        RunOut::Done(out) => assert_eq!(out, expected),
        other => panic!("expected Done, got {other:?}"),
    }
}
/// Once the fork/join flow has reported `Done`, a further `next()` call must report
/// `Done` again rather than processing leftover nodes.
#[tokio::test]
async fn fork_join_no_dangling_nodes_after_done() {
    let mut rt = FlowRuntime::new(SplitIn { n: 2 }).unwrap();

    // Advance through the three intermediate steps; their outcomes are not inspected.
    for _ in 0..3 {
        let _ = rt.next(ctx()).await.unwrap();
    }

    let first_done = rt.next(ctx()).await.unwrap();
    let finished = matches!(first_done, RunOut::Done(_));
    assert!(finished, "expected Done, got {first_done:?}");

    let second = rt.next(ctx()).await.unwrap();
    let still_finished = matches!(second, RunOut::Done(_));
    assert!(
        still_finished,
        "expected Done on second call (no dangling nodes), got {second:?}"
    );
}
/// The join must not complete the flow early: every step before the last one reports
/// `Continue`, and only the final step yields the merged total.
#[tokio::test]
async fn join_does_not_fire_until_both_branches_present() {
    let mut rt = FlowRuntime::new(SplitIn { n: 1 }).unwrap();

    assert!(matches!(rt.next(ctx()).await.unwrap(), RunOut::Continue));

    let step2 = rt.next(ctx()).await.unwrap();
    let still_running = matches!(step2, RunOut::Continue);
    assert!(
        still_running,
        "expected Continue (beta work, join not ready), got {step2:?}"
    );

    assert!(matches!(rt.next(ctx()).await.unwrap(), RunOut::Continue));

    // alpha = 1, beta = 1 * 3, plus the +1 applied by the beta work step.
    let expected = 1 + 3 + 1;
    let last = rt.next(ctx()).await.unwrap();
    match last {
        RunOut::Done(out) => assert_eq!(out.total, expected),
        other => panic!("expected Done, got {other:?}"),
    }
}
/// With `go_left = true` the flow finishes with the left branch's output.
#[tokio::test]
async fn either_left_branch_taken_and_completed() {
    let mut rt = FlowRuntime::new(RouteIn { go_left: true }).unwrap();

    // Two intermediate steps are expected before the flow finishes.
    for _ in 0..2 {
        assert!(matches!(rt.next(ctx()).await.unwrap(), RunOut::Continue));
    }

    let last = rt.next(ctx()).await.unwrap();
    match last {
        RunOut::Done(out) => {
            assert_eq!(out.tag, "L");
            assert!(out.from_left);
        }
        other => panic!("expected Done, got {other:?}"),
    }
}
/// With `go_left = false` the flow finishes with the right branch's output.
#[tokio::test]
async fn either_right_branch_taken_and_completed() {
    let mut rt = FlowRuntime::new(RouteIn { go_left: false }).unwrap();

    // Two intermediate steps are expected before the flow finishes.
    for _ in 0..2 {
        assert!(matches!(rt.next(ctx()).await.unwrap(), RunOut::Continue));
    }

    let last = rt.next(ctx()).await.unwrap();
    match last {
        RunOut::Done(out) => {
            assert_eq!(out.tag, "R");
            assert!(!out.from_left);
        }
        other => panic!("expected Done, got {other:?}"),
    }
}
/// After the left branch has completed the flow, another `next()` call must report
/// `Done` again instead of running anything for the branch that was not taken.
#[tokio::test]
async fn either_no_dangling_nodes_after_done() {
    let mut rt = FlowRuntime::new(RouteIn { go_left: true }).unwrap();

    // Drive the flow to completion; the output itself is irrelevant for this test.
    let _ = run_to_done!(rt);

    let second = rt.next(ctx()).await.unwrap();
    let still_finished = matches!(second, RunOut::Done(_));
    assert!(
        still_finished,
        "expected Done on second call (no dangling right-branch state), got {second:?}"
    );
}
/// Two separate runtimes, one routed left and one routed right, produce distinct
/// outputs that each reflect their own branch.
#[tokio::test]
async fn either_left_and_right_are_independent() {
    // Each runtime lives in its own scope and is dropped before the next one is built.
    let out_left = {
        let mut left_rt = FlowRuntime::new(RouteIn { go_left: true }).unwrap();
        run_to_done!(left_rt)
    };
    let out_right = {
        let mut right_rt = FlowRuntime::new(RouteIn { go_left: false }).unwrap();
        run_to_done!(right_rt)
    };

    assert!(out_left.from_left);
    assert!(!out_right.from_left);
    assert_ne!(out_left.tag, out_right.tag);
}