use forge_guardrails::{
compact::NoCompact,
fold_and_serialize, format_tool_call_id,
respond::{respond_spec, respond_tool},
workflow::{IntoToolCallable, TerminalToolInput, ToolDef},
ApiFormat, ChunkStream, ClassifierAction, ContextManager, FinalResponseClass,
FinalResponseContext, FinalResponseScore, FinalResponseScorer, ForgeError, LLMClient,
LLMResponse, Message, MessageMeta, MessageRole, MessageType, OnMessageFn, SamplingParams,
ScoringContext, StreamChunk, ToolCall, ToolCallClass, ToolCallInfo, ToolCallScore,
ToolCallScorer, ToolResolutionError, ToolSpec, Workflow, WorkflowRunner,
};
use indexmap::IndexMap;
use serde_json::Value;
use std::sync::atomic::{AtomicI32, Ordering as AtomicOrdering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::{watch, Mutex};
mod support;
use support::{ScriptedLlmClient, StreamMode};
type MockClient = ScriptedLlmClient;
/// Test double for `LLMClient` whose `send` always fails with a backend error.
struct BackendErrorClient {
/// Number of times `send` has been invoked on this client.
call_count: AtomicI32,
}
/// Fake `FinalResponseScorer` whose verdict depends on how many times it has
/// already been asked to score.
struct SequenceFinalResponseScorer {
/// Number of `score` calls made so far.
call_count: AtomicI32,
}
/// Fake `ToolCallScorer` that blocks the calling thread on its first `score`
/// call only, then always reports the candidate as valid.
struct SleepOnceToolCallScorer {
/// Number of `score` calls made so far.
calls: AtomicI32,
/// How long the first `score` call sleeps.
sleep: Duration,
}
impl SequenceFinalResponseScorer {
    /// Builds a scorer whose call counter starts at zero.
    fn new() -> Self {
        let call_count = AtomicI32::new(0);
        Self { call_count }
    }
}
impl SleepOnceToolCallScorer {
    /// Builds a scorer that will sleep for `sleep` on its first `score` call.
    fn new(sleep: Duration) -> Self {
        let calls = AtomicI32::new(0);
        Self { calls, sleep }
    }
}
impl FinalResponseScorer for SequenceFinalResponseScorer {
    /// Flags the first scored response as `MissingToolFact` with an advisory
    /// nudge; every later response is scored as a valid final response.
    fn score(&self, _ctx: &FinalResponseContext) -> anyhow::Result<FinalResponseScore> {
        let is_first_call = self.call_count.fetch_add(1, AtomicOrdering::SeqCst) == 0;
        let (label, confidence, logits, action) = if is_first_call {
            (
                FinalResponseClass::MissingToolFact,
                0.99,
                vec![0.0, 9.0, 0.0, 0.0, 0.0],
                ClassifierAction::AdvisoryNudge,
            )
        } else {
            (
                FinalResponseClass::ValidFinalResponse,
                1.0,
                vec![9.0, 0.0, 0.0, 0.0, 0.0],
                ClassifierAction::Allow,
            )
        };
        Ok(FinalResponseScore {
            label,
            confidence,
            logits,
            action,
            model_version: "fake-final".to_string(),
            latency_ms: 1.0,
        })
    }
}
impl ToolCallScorer for SleepOnceToolCallScorer {
    /// Always allows the candidate. Only the very first call blocks the
    /// calling thread for the configured duration.
    fn score(&self, _ctx: &ScoringContext, _candidate: &ToolCall) -> anyhow::Result<ToolCallScore> {
        let previous_calls = self.calls.fetch_add(1, AtomicOrdering::SeqCst);
        if previous_calls == 0 {
            // Synchronous sleep: this occupies the thread instead of yielding.
            std::thread::sleep(self.sleep);
        }
        // The reported latency is the configured sleep, even on calls that did not sleep.
        let latency_ms = self.sleep.as_secs_f64() * 1000.0;
        Ok(ToolCallScore {
            label: ToolCallClass::Valid,
            confidence: 1.0,
            logits: vec![9.0, 0.0],
            action: ClassifierAction::Allow,
            model_version: "sleep-once-fake".to_string(),
            latency_ms,
        })
    }
}
impl BackendErrorClient {
    /// Builds a client that has not yet received any `send` call.
    fn new() -> Self {
        let call_count = AtomicI32::new(0);
        Self { call_count }
    }
}
impl LLMClient for BackendErrorClient {
    /// Reports the OpenAI wire format.
    fn api_format(&self) -> ApiFormat {
        ApiFormat::OpenAI
    }

    /// Counts the call, then fails with a 503 "backend down" error.
    async fn send(
        &self,
        _messages: Vec<Value>,
        _tools: Option<Vec<ToolSpec>>,
        _sampling: Option<SamplingParams>,
    ) -> Result<LLMResponse, forge_guardrails::BackendError> {
        self.call_count.fetch_add(1, AtomicOrdering::SeqCst);
        let failure = forge_guardrails::BackendError::new(503, "backend down");
        Err(failure)
    }

    /// Streaming always fails with a "not used" stream error.
    async fn send_stream(
        &self,
        _messages: Vec<Value>,
        _tools: Option<Vec<ToolSpec>>,
        _sampling: Option<SamplingParams>,
    ) -> Result<ChunkStream, forge_guardrails::StreamError> {
        let failure = forge_guardrails::StreamError::new("not used");
        Err(failure)
    }

    /// Reports a fixed context length of 4096.
    async fn get_context_length(
        &self,
    ) -> Result<Option<i64>, forge_guardrails::ContextDiscoveryError> {
        let context_length = Some(4096);
        Ok(context_length)
    }
}
/// Test double for `LLMClient` whose stream yields a single text delta chunk
/// and then ends.
struct NoFinalStreamClient {
/// Number of times `send_stream` has been invoked on this client.
call_count: AtomicI32,
}
impl NoFinalStreamClient {
    /// Builds a client that has not yet received any `send_stream` call.
    fn new() -> Self {
        let call_count = AtomicI32::new(0);
        Self { call_count }
    }
}
impl LLMClient for NoFinalStreamClient {
    /// Reports the OpenAI wire format.
    fn api_format(&self) -> ApiFormat {
        ApiFormat::OpenAI
    }

    /// Non-streaming path: returns a fixed "not used" text response.
    async fn send(
        &self,
        _messages: Vec<Value>,
        _tools: Option<Vec<ToolSpec>>,
        _sampling: Option<SamplingParams>,
    ) -> Result<LLMResponse, forge_guardrails::BackendError> {
        let response = make_text_response("not used");
        Ok(response)
    }

    /// Counts the call and returns a stream holding one `TextDelta` chunk
    /// with the content "partial" and nothing after it.
    async fn send_stream(
        &self,
        _messages: Vec<Value>,
        _tools: Option<Vec<ToolSpec>>,
        _sampling: Option<SamplingParams>,
    ) -> Result<ChunkStream, forge_guardrails::StreamError> {
        self.call_count.fetch_add(1, AtomicOrdering::SeqCst);
        let partial_chunk =
            StreamChunk::new(forge_guardrails::ChunkType::TextDelta).with_content("partial");
        let chunks = futures_util::stream::iter(vec![Ok(partial_chunk)]);
        Ok(Box::pin(chunks))
    }

    /// Reports a fixed context length of 4096.
    async fn get_context_length(
        &self,
    ) -> Result<Option<i64>, forge_guardrails::ContextDiscoveryError> {
        let context_length = Some(4096);
        Ok(context_length)
    }
}
/// Builds a scripted client that replays `responses` and has its stream mode
/// set to `StreamMode::Unsupported`.
fn mock_client(responses: Vec<LLMResponse>) -> MockClient {
    let scripted = ScriptedLlmClient::new(responses);
    scripted.with_stream_mode(StreamMode::Unsupported)
}
/// Wraps a single call to `tool` with `args` in an `LLMResponse::ToolCalls`.
fn make_tool_call(tool: &str, args: IndexMap<String, Value>) -> LLMResponse {
    let call = forge_guardrails::ToolCall::new(tool, args);
    LLMResponse::ToolCalls(vec![call])
}
/// Wraps `content` in an `LLMResponse::Text`.
fn make_text_response(content: &str) -> LLMResponse {
    let text = forge_guardrails::TextResponse::new(content);
    LLMResponse::Text(text)
}
/// Builds the two-tool test workflow: a `search` tool backed by `step_tool`
/// (listed as the single required step) and a `respond` terminal tool backed
/// by `terminal_tool`.
fn make_workflow_with_step_and_terminal<S, T>(step_tool: S, terminal_tool: T) -> Workflow
where
    S: IntoToolCallable,
    T: IntoToolCallable,
{
    // `search` takes one mandatory string argument named `query`.
    let search_schema = serde_json::json!({
        "type": "object",
        "properties": {"query": {"type": "string"}},
        "required": ["query"]
    });
    let search_spec =
        ToolSpec::from_json_schema("search", "Search tool", &search_schema).expect("valid spec");

    let mut tools: IndexMap<String, ToolDef> = IndexMap::new();
    tools.insert("search".to_string(), ToolDef::new(search_spec, step_tool));
    tools.insert(
        "respond".to_string(),
        ToolDef::new(respond_spec(), terminal_tool),
    );

    let required_steps = vec!["search".to_string()];
    let terminal = TerminalToolInput::Single("respond".to_string());
    Workflow::new(
        "test_workflow",
        "test workflow",
        tools,
        required_steps,
        terminal,
        "You are a helper.",
    )
    .expect("valid workflow")
}
/// Builds the standard search-then-respond workflow used by most tests.
fn make_simple_workflow() -> Workflow {
    /// Echoes the received arguments back in a canned search result.
    fn step_fn(args: Vec<String>) -> Result<String, ToolResolutionError> {
        Ok(format!("search result for {:?}", args))
    }

    /// Returns the value of the first `message=` argument, or a default
    /// string when no such argument is present.
    fn terminal_fn(args: Vec<String>) -> Result<String, ToolResolutionError> {
        let message = args
            .iter()
            .find_map(|arg| arg.strip_prefix("message="))
            .map(str::to_owned)
            .unwrap_or_else(|| "default response".to_string());
        Ok(message)
    }

    make_workflow_with_step_and_terminal(step_fn, terminal_fn)
}
/// Builds a shared `ContextManager` using the `NoCompact` strategy, a 4096
/// budget, and `None` for the remaining options.
fn make_context_manager() -> Arc<Mutex<ContextManager>> {
    let manager = ContextManager::new(Box::new(NoCompact), 4096, None, None, None);
    Arc::new(Mutex::new(manager))
}
/// Builds a shared `WorkflowRunner` around `client` with the default settings
/// used across this file and a fresh context manager.
fn make_runner(client: MockClient) -> Arc<WorkflowRunner<MockClient>> {
    let shared_client = Arc::new(client);
    let runner = WorkflowRunner::new(
        shared_client,
        make_context_manager(),
        10,
        3,
        2,
        false,
        None,
        None,
        true,
        None,
    );
    Arc::new(runner)
}
/// Happy path: the model calls `search`, then `respond`, and the run yields
/// the responded message.
#[tokio::test]
async fn ts001_simple_two_step_workflow() {
    let mut search_args = IndexMap::new();
    search_args.insert("query".to_string(), Value::from("test"));
    let mut respond_args = IndexMap::new();
    respond_args.insert("message".to_string(), Value::from("final answer"));

    let scripted = vec![
        make_tool_call("search", search_args),
        make_tool_call("respond", respond_args),
    ];
    let runner = make_runner(mock_client(scripted));
    let workflow = make_simple_workflow();

    let outcome = runner
        .run(&workflow, "search for test", None, None, None)
        .await;

    assert!(outcome.is_ok(), "Expected Ok, got {:?}", outcome);
    let answer = outcome.expect("ok");
    assert_eq!(answer, Value::from("final answer"));
}
/// On a current-thread runtime, a tool-call scorer that sleeps synchronously
/// for 100ms must not stop a 20ms timer from firing on time while the
/// workflow is still in flight.
#[tokio::test(flavor = "current_thread")]
async fn workflow_runner_classifier_scoring_does_not_block_current_thread_runtime() {
    let mut search_args = IndexMap::new();
    search_args.insert("query".to_string(), Value::from("test"));
    let mut respond_args = IndexMap::new();
    respond_args.insert("message".to_string(), Value::from("final answer"));
    let client = Arc::new(mock_client(vec![
        make_tool_call("search", search_args),
        make_tool_call("respond", respond_args),
    ]));

    let slow_scorer = Arc::new(SleepOnceToolCallScorer::new(Duration::from_millis(100)));
    let runner = WorkflowRunner::new(
        client,
        make_context_manager(),
        10,
        3,
        2,
        false,
        None,
        None,
        true,
        None,
    )
    .with_tool_call_scorer(slow_scorer, None);
    let workflow = make_simple_workflow();

    let started = Instant::now();
    let run = runner.run(&workflow, "search for test", None, None, None);
    tokio::pin!(run);

    // Race the workflow against a short timer: the timer is expected to win,
    // and to do so well before the scorer's sleep would have ended.
    tokio::select! {
        _ = tokio::time::sleep(Duration::from_millis(20)) => {
            let waited = started.elapsed();
            assert!(
                waited < Duration::from_millis(80),
                "WorkflowRunner classifier scoring blocked the current-thread runtime"
            );
        }
        result = &mut run => {
            panic!("workflow completed before blocking scorer delay: {result:?}");
        }
    }

    // The still-pending run must then finish normally.
    let answer = run.await.expect("workflow result");
    assert_eq!(answer, Value::from("final answer"));
}
#[tokio::test]
async fn final_response_scorer_retries_terminal_answer_before_execution() {
let collected: Arc<std::sync::Mutex<Vec<Message>>> =
Arc::new(std::sync::Mutex::new(Vec::new()));
let collected_clone = collected.clone();
let cb: OnMessageFn = Box::new(move |msg: &Message| {
if let Ok(mut messages) = collected_clone.lock() {
messages.push(msg.clone());
}
});
let mut search_args = IndexMap::new();
search_args.insert("query".to_string(), Value::String("test".to_string()));
let mut bad_args = IndexMap::new();
bad_args.insert("message".to_string(), Value::String("bad".to_string()));
let mut good_args = IndexMap::new();
good_args.insert("message".to_string(), Value::String("good".to_string()));
let client = Arc::new(mock_client(vec![
make_tool_call("search", search_args),
make_tool_call("respond", bad_args),
make_tool_call("respond", good_args),
]));
let final_scorer = Arc::new(SequenceFinalResponseScorer::new());
let runner = WorkflowRunner::new(
client.clone(),
make_context_manager(),
10,
3,
2,
false,
None,
Some(cb),
true,
None,
)
.with_final_response_scorer(final_scorer, None);
let workflow = make_simple_workflow();
let result = runner
.run(&workflow, "search for test", None, None, None)
.await;
assert_eq!(result.expect("ok"), Value::String("good".to_string()));
assert_eq!(client.calls(), 3);
assert!(collected
.lock()
.expect("messages lock")
.iter()
.any(|message| message.content.contains("[FinalResponseNudge]")));
}
/// A plain-text reply where a tool call is expected is retried; the
/// following `search` and `respond` calls then complete the workflow and the
/// responded message is returned.
#[tokio::test]
async fn ts002_text_response_retry() {
    let mut search_args = IndexMap::new();
    search_args.insert("query".to_string(), Value::String("test".to_string()));
    let mut respond_args = IndexMap::new();
    respond_args.insert("message".to_string(), Value::String("done".to_string()));

    let client = mock_client(vec![
        make_text_response("I think you should search for test"),
        make_tool_call("search", search_args),
        make_tool_call("respond", respond_args),
    ]);
    let runner = make_runner(client);
    let workflow = make_simple_workflow();

    let result = runner.run(&workflow, "do stuff", None, None, None).await;

    assert!(result.is_ok(), "Expected Ok, got {:?}", result);
    // Pin the terminal value too, as the sibling tests on this workflow do,
    // so a run that succeeds with the wrong answer is caught.
    assert_eq!(result.expect("ok"), Value::String("done".to_string()));
}
#[tokio::test]
async fn empty_tool_batch_retried_without_noop_execution() {
let collected: Arc<std::sync::Mutex<Vec<MessageType>>> =
Arc::new(std::sync::Mutex::new(Vec::new()));
let collected_clone = collected.clone();
let cb: OnMessageFn = Box::new(move |msg: &Message| {
if let Ok(mut v) = collected_clone.lock() {
v.push(msg.metadata.msg_type);
}
});
let mut search_args = IndexMap::new();
search_args.insert("query".to_string(), Value::String("test".to_string()));
let mut respond_args = IndexMap::new();
respond_args.insert("message".to_string(), Value::String("done".to_string()));
let client = mock_client(vec![
LLMResponse::ToolCalls(Vec::new()),
make_tool_call("search", search_args),
make_tool_call("respond", respond_args),
]);
let runner = Arc::new(WorkflowRunner::new(
Arc::new(client),
make_context_manager(),
10,
3,
2,
false,
None,
Some(cb),
true,
None,
));
let workflow = make_simple_workflow();
let result = runner.run(&workflow, "do stuff", None, None, None).await;
assert!(result.is_ok(), "Expected Ok, got {:?}", result);
assert_eq!(result.expect("ok"), Value::String("done".to_string()));
let final_collected = collected.lock().expect("lock");
assert_eq!(final_collected.first(), Some(&MessageType::RetryNudge));
assert_eq!(
final_collected
.iter()
.filter(|msg_type| **msg_type == MessageType::RetryNudge)
.count(),
1
);
}
#[tokio::test]
async fn invalid_tool_arguments_retry_before_execution() {
let collected: Arc<std::sync::Mutex<Vec<Message>>> =
Arc::new(std::sync::Mutex::new(Vec::new()));
let collected_clone = collected.clone();
let cb: OnMessageFn = Box::new(move |msg: &Message| {
if let Ok(mut v) = collected_clone.lock() {
v.push(msg.clone());
}
});
let mut search_args = IndexMap::new();
search_args.insert("query".to_string(), Value::String("test".to_string()));
let mut respond_args = IndexMap::new();
respond_args.insert("message".to_string(), Value::String("done".to_string()));
let client = Arc::new(mock_client(vec![
make_tool_call("search", IndexMap::new()),
make_tool_call("search", search_args),
make_tool_call("respond", respond_args),
]));
let runner = Arc::new(WorkflowRunner::new(
client.clone(),
make_context_manager(),
10,
3,
2,
false,
None,
Some(cb),
true,
None,
));
let workflow = make_simple_workflow();
let result = runner.run(&workflow, "do stuff", None, None, None).await;
assert_eq!(
result.expect("workflow ok"),
Value::String("done".to_string())
);
assert_eq!(client.calls(), 3);
let messages = collected.lock().expect("lock");
let invalid_result = messages
.iter()
.find(|msg| {
msg.metadata.msg_type == MessageType::RetryNudge
&& msg.content.contains("[InvalidArguments]")
})
.expect("invalid argument retry message");
assert!(invalid_result.content.contains("query is required"));
}
#[tokio::test]
async fn mixed_terminal_batch_retried_without_executing_partial_work() {
let collected: Arc<std::sync::Mutex<Vec<Message>>> =
Arc::new(std::sync::Mutex::new(Vec::new()));
let collected_clone = collected.clone();
let cb: OnMessageFn = Box::new(move |msg: &Message| {
if let Ok(mut v) = collected_clone.lock() {
v.push(msg.clone());
}
});
let mut search_args = IndexMap::new();
search_args.insert("query".to_string(), Value::String("test".to_string()));
let mut premature_respond_args = IndexMap::new();
premature_respond_args.insert("message".to_string(), Value::String("too soon".to_string()));
let mut final_respond_args = IndexMap::new();
final_respond_args.insert("message".to_string(), Value::String("done".to_string()));
let client = Arc::new(mock_client(vec![
LLMResponse::ToolCalls(vec![
forge_guardrails::ToolCall::new("search", search_args.clone()),
forge_guardrails::ToolCall::new("respond", premature_respond_args),
]),
make_tool_call("search", search_args),
make_tool_call("respond", final_respond_args),
]));
let runner = Arc::new(WorkflowRunner::new(
client.clone(),
make_context_manager(),
10,
3,
2,
false,
None,
Some(cb),
true,
None,
));
let workflow = make_simple_workflow();
let result = runner.run(&workflow, "do stuff", None, None, None).await;
assert_eq!(
result.expect("workflow ok"),
Value::String("done".to_string())
);
assert_eq!(client.calls(), 3);
let messages = collected.lock().expect("lock");
let retry_results: Vec<&Message> = messages
.iter()
.filter(|msg| msg.metadata.msg_type == MessageType::RetryNudge)
.collect();
assert_eq!(retry_results.len(), 2);
assert!(retry_results
.iter()
.all(|msg| msg.content.contains("Do not combine terminal")));
let blocked_call_ids: Vec<String> = messages
.iter()
.filter_map(|msg| {
if msg.metadata.msg_type == MessageType::ToolCall {
msg.tool_calls.as_ref().map(|calls| {
calls
.iter()
.map(|call| call.call_id.clone())
.collect::<Vec<_>>()
})
} else {
None
}
})
.next()
.expect("blocked tool call ids");
let blocked_result_ids: Vec<String> = retry_results
.iter()
.map(|msg| msg.tool_call_id.clone().expect("tool result id"))
.collect();
assert_eq!(blocked_call_ids, blocked_result_ids);
}
/// A `respond` issued before the required `search` step does not end the
/// run; the value of the later `respond` is what the workflow returns.
#[tokio::test]
async fn ts003_premature_terminal_blocked() {
    let mut premature_args = IndexMap::new();
    premature_args.insert("message".to_string(), Value::from("premature"));
    let mut search_args = IndexMap::new();
    search_args.insert("query".to_string(), Value::from("test"));
    let mut final_args = IndexMap::new();
    final_args.insert("message".to_string(), Value::from("final"));

    let scripted = vec![
        make_tool_call("respond", premature_args),
        make_tool_call("search", search_args),
        make_tool_call("respond", final_args),
    ];
    let runner = make_runner(mock_client(scripted));
    let workflow = make_simple_workflow();

    let outcome = runner.run(&workflow, "do stuff", None, None, None).await;

    assert!(outcome.is_ok(), "Expected Ok, got {:?}", outcome);
    assert_eq!(outcome.expect("ok"), Value::from("final"));
}
/// With a `search` tool that always returns an error and a model that keeps
/// calling it, the run ends in an error.
#[tokio::test]
async fn ts004_tool_error_feedback() {
    fn failing_step(_args: Vec<String>) -> Result<String, ToolResolutionError> {
        Err(ToolResolutionError::new("search failed: bad query"))
    }

    let mut first_args = IndexMap::new();
    first_args.insert("query".to_string(), Value::from("bad query"));
    let mut later_args = IndexMap::new();
    later_args.insert("query".to_string(), Value::from("good query"));

    // One call with the bad query followed by nine with the good one.
    let responses = std::iter::once(make_tool_call("search", first_args))
        .chain(std::iter::repeat_with(|| make_tool_call("search", later_args.clone())).take(9))
        .collect();
    let client = mock_client(responses);

    let search_schema = serde_json::json!({
        "type": "object", "properties": {"query": {"type": "string"}}
    });
    let search_spec =
        ToolSpec::from_json_schema("search", "Search", &search_schema).expect("valid");
    let mut tools: IndexMap<String, ToolDef> = IndexMap::new();
    tools.insert("search".to_string(), ToolDef::new(search_spec, failing_step));
    tools.insert("respond".to_string(), respond_tool());

    let workflow = Workflow::new(
        "fail_workflow",
        "fail test",
        tools,
        vec!["search".to_string()],
        TerminalToolInput::Single("respond".to_string()),
        "Helper.",
    )
    .expect("valid");
    let runner = make_runner(client);

    let outcome = runner.run(&workflow, "search", None, None, None).await;

    assert!(outcome.is_err(), "Expected error, got {:?}", outcome);
}
/// Smoke test: a runner can be constructed with the sixth constructor
/// argument set to `true`. Nothing is run.
// NOTE(review): going by the test name, that argument is presumably the
// streaming flag — confirm against `WorkflowRunner::new`.
#[tokio::test]
async fn ts005_streaming_mode_flag() {
    let client = Arc::new(mock_client(vec![]));
    let context = make_context_manager();
    let _streaming_runner =
        WorkflowRunner::new(client, context, 10, 3, 2, true, None, None, true, None);
}
/// A run handed a watch receiver whose value is already `true` fails with
/// `WorkflowCancelled` at iteration 0.
#[tokio::test]
async fn ts006_cancellation() {
    let mut search_args = IndexMap::new();
    search_args.insert("query".to_string(), Value::from("test"));
    let runner = make_runner(mock_client(vec![make_tool_call("search", search_args)]));
    let workflow = make_simple_workflow();

    // Keep the sender bound for the whole test so the channel stays open.
    let (_cancel_tx, cancel_rx) = watch::channel(true);
    let outcome = runner
        .run(&workflow, "search", None, None, Some(cancel_rx))
        .await;

    assert!(outcome.is_err());
    match outcome.expect_err("should be error") {
        ForgeError::WorkflowCancelled(cancelled) => assert_eq!(cancelled.iteration, 0),
        other => panic!("Expected WorkflowCancelled, got {:?}", other),
    }
}
#[tokio::test]
async fn ts007_initial_messages_seed() {
let collected: Arc<std::sync::Mutex<Vec<MessageType>>> =
Arc::new(std::sync::Mutex::new(Vec::new()));
let collected_clone = collected.clone();
let cb: OnMessageFn = Box::new(move |msg: &Message| {
if let Ok(mut v) = collected_clone.lock() {
v.push(msg.metadata.msg_type);
}
});
let mut args1 = IndexMap::new();
args1.insert(
"message".to_string(),
Value::String("seeded result".to_string()),
);
let client = mock_client(vec![make_tool_call("respond", args1)]);
let mut tools: IndexMap<String, ToolDef> = IndexMap::new();
tools.insert("respond".to_string(), respond_tool());
let workflow = Workflow::new(
"seeded",
"seeded test",
tools,
vec![],
TerminalToolInput::Single("respond".to_string()),
"You are a helper.",
)
.expect("valid");
let runner = Arc::new(WorkflowRunner::new(
Arc::new(client),
make_context_manager(),
10,
3,
2,
false,
None,
Some(cb),
true,
None,
));
let seed = vec![
Message::new(
MessageRole::System,
"System prompt",
MessageMeta::new(MessageType::SystemPrompt),
),
Message::new(
MessageRole::User,
"User message",
MessageMeta::new(MessageType::UserInput),
),
];
let result = runner
.run(&workflow, "ignored", None, Some(seed), None)
.await;
assert!(result.is_ok());
let final_collected = collected.lock().expect("lock");
assert!(!final_collected.contains(&MessageType::SystemPrompt));
assert!(!final_collected.contains(&MessageType::UserInput));
}
/// A `ToolResolutionError` from a tool that is not a required step does not
/// abort the run: the following `respond` call still completes it.
#[tokio::test]
async fn ts008_resolution_error_soft() {
    fn soft_fail_step(_args: Vec<String>) -> Result<String, ToolResolutionError> {
        Err(ToolResolutionError::new("try again with different args"))
    }

    let mut search_args = IndexMap::new();
    search_args.insert("query".to_string(), Value::from("test"));
    let mut respond_args = IndexMap::new();
    respond_args.insert("message".to_string(), Value::from("done"));
    let client = mock_client(vec![
        make_tool_call("search", search_args),
        make_tool_call("respond", respond_args),
    ]);

    let search_schema = serde_json::json!({
        "type": "object", "properties": {"query": {"type": "string"}}
    });
    let search_spec =
        ToolSpec::from_json_schema("search", "Search", &search_schema).expect("valid");
    let mut tools: IndexMap<String, ToolDef> = IndexMap::new();
    tools.insert(
        "search".to_string(),
        ToolDef::new(search_spec, soft_fail_step),
    );
    tools.insert("respond".to_string(), respond_tool());

    // No required steps here, unlike the standard workflow.
    let workflow = Workflow::new(
        "soft_error",
        "soft error test",
        tools,
        vec![],
        TerminalToolInput::Single("respond".to_string()),
        "Helper.",
    )
    .expect("valid");
    let runner = make_runner(client);

    let outcome = runner
        .run(&workflow, "search then respond", None, None, None)
        .await;

    assert!(outcome.is_ok(), "Expected Ok, got {:?}", outcome);
}
/// On a fresh `StepTracker`, a name-only prerequisite on `search` is
/// reported as unsatisfied and listed as missing.
#[tokio::test]
async fn ts009_prerequisite_enforcement() {
    use forge_guardrails::steps::{Prerequisite, StepTracker};

    let tracker = StepTracker::new(vec!["search".to_string()]);
    let no_args = IndexMap::new();
    let prerequisites = [Prerequisite::NameOnly("search".to_string())];

    let check = tracker.check_prerequisites("analyze", &no_args, &prerequisites);

    assert!(!check.satisfied);
    assert!(check.missing.contains(&"search".to_string()));
}
/// A reasoning message followed by a tool-call message stays as two entries
/// internally but is serialized as one wire message that carries both the
/// reasoning text and the tool calls.
#[test]
fn ts010_reasoning_folding_wire_vs_internal() {
    let reasoning = Message::new(
        MessageRole::Assistant,
        "let me think about this",
        MessageMeta::new(MessageType::Reasoning),
    );
    let search_call = ToolCallInfo::new("search", Some(IndexMap::new()), "tc_0001");
    let tool_call = Message::new(
        MessageRole::Assistant,
        "",
        MessageMeta::new(MessageType::ToolCall),
    )
    .with_tool_calls(vec![search_call]);

    let internal = vec![reasoning.clone(), tool_call.clone()];
    assert_eq!(internal.len(), 2, "Internally, 2 separate messages");

    let wire = fold_and_serialize(&internal, "openai");
    assert_eq!(wire.len(), 1, "On wire, folded into 1 message");
    let folded = &wire[0];
    assert_eq!(folded["content"], "let me think about this");
    assert!(folded["tool_calls"].is_array());
}
/// A call to a tool the workflow does not define is recoverable: the
/// following valid `search` and `respond` calls complete the workflow and
/// the responded message is returned.
#[tokio::test]
async fn ts011_unknown_tool_nudge() {
    let mut args_bad = IndexMap::new();
    args_bad.insert("query".to_string(), Value::String("test".to_string()));
    let mut args_good = IndexMap::new();
    args_good.insert("query".to_string(), Value::String("test".to_string()));
    let mut args_resp = IndexMap::new();
    args_resp.insert("message".to_string(), Value::String("done".to_string()));

    let client = mock_client(vec![
        make_tool_call("nonexistent_tool", args_bad),
        make_tool_call("search", args_good),
        make_tool_call("respond", args_resp),
    ]);
    let runner = make_runner(client);
    let workflow = make_simple_workflow();

    let result = runner.run(&workflow, "do stuff", None, None, None).await;

    assert!(
        result.is_ok(),
        "Expected Ok after recovery, got {:?}",
        result
    );
    // Pin the terminal value too, as the sibling tests on this workflow do,
    // so a recovery that returns the wrong answer is caught.
    assert_eq!(result.expect("ok"), Value::String("done".to_string()));
}
/// With the runner's first three numeric limits set to 3, 1 and 2, three
/// consecutive text replies make the run fail before the scripted tool calls
/// are reached.
// NOTE(review): the assertion message suggests the first limit is the
// maximum iteration count — confirm against `WorkflowRunner::new`.
#[tokio::test]
async fn ts012_retries_consume_iterations() {
    let mut search_args = IndexMap::new();
    search_args.insert("query".to_string(), Value::from("test"));
    let mut respond_args = IndexMap::new();
    respond_args.insert("message".to_string(), Value::from("done"));

    let scripted = vec![
        make_text_response("thinking..."),
        make_text_response("still thinking..."),
        make_text_response("more thinking..."),
        make_tool_call("search", search_args),
        make_tool_call("respond", respond_args),
    ];
    let runner = Arc::new(WorkflowRunner::new(
        Arc::new(mock_client(scripted)),
        make_context_manager(),
        3,
        1,
        2,
        false,
        None,
        None,
        true,
        None,
    ));
    let workflow = make_simple_workflow();

    let outcome = runner.run(&workflow, "do stuff", None, None, None).await;

    assert!(outcome.is_err(), "Should fail with max iterations");
}
/// `run_inference` gives up with `ForgeError::ToolCall` after three backend
/// calls when every reply is plain text and the tracker is built with (2, 2).
#[tokio::test]
async fn inference_retry_budget_raises_tool_call_error() {
    let scripted = vec![
        make_text_response("bad 1"),
        make_text_response("bad 2"),
        make_text_response("bad 3"),
        make_text_response("bad 4"),
    ];
    let client = mock_client(scripted);

    let user_start = Message::new(
        MessageRole::User,
        "start",
        MessageMeta::new(MessageType::UserInput),
    );
    let mut transcript = vec![user_start];
    let mut context = ContextManager::new(Box::new(NoCompact), 4096, None, None, None);
    let validator = forge_guardrails::guardrails::ResponseValidator::new(
        vec!["respond".to_string()],
        false,
        None,
    );
    let mut error_tracker = forge_guardrails::ErrorTracker::new(2, 2);
    let mut call_counter = 0;
    let tool_specs = vec![respond_spec()];

    let outcome = forge_guardrails::run_inference(
        &mut transcript,
        &client,
        &mut context,
        &validator,
        &mut error_tracker,
        &tool_specs,
        &mut call_counter,
        0,
        "",
        Some(10),
        false,
        None,
        None,
    )
    .await;

    assert!(matches!(outcome, Err(ForgeError::ToolCall(_))));
    // Only three of the four scripted replies are consumed.
    assert_eq!(client.calls(), 3);
}
/// A backend failure surfaces from `run_inference` as `ForgeError::Backend`
/// after a single `send` call.
#[tokio::test]
async fn inference_backend_error_is_not_retried() {
    let client = BackendErrorClient::new();

    let user_start = Message::new(
        MessageRole::User,
        "start",
        MessageMeta::new(MessageType::UserInput),
    );
    let mut transcript = vec![user_start];
    let mut context = ContextManager::new(Box::new(NoCompact), 4096, None, None, None);
    let validator = forge_guardrails::guardrails::ResponseValidator::new(
        vec!["respond".to_string()],
        false,
        None,
    );
    let mut error_tracker = forge_guardrails::ErrorTracker::new(2, 2);
    let mut call_counter = 0;
    let tool_specs = vec![respond_spec()];

    let outcome = forge_guardrails::run_inference(
        &mut transcript,
        &client,
        &mut context,
        &validator,
        &mut error_tracker,
        &tool_specs,
        &mut call_counter,
        0,
        "",
        Some(10),
        false,
        None,
        None,
    )
    .await;

    assert!(matches!(outcome, Err(ForgeError::Backend(_))));
    let send_calls = client.call_count.load(AtomicOrdering::SeqCst);
    assert_eq!(send_calls, 1);
}
/// With the streaming argument set to `true`, a stream that yields only a
/// text delta makes `run_inference` fail with `ForgeError::Stream` after a
/// single `send_stream` call.
#[tokio::test]
async fn inference_stream_without_final_is_stream_error() {
    let client = NoFinalStreamClient::new();

    let user_start = Message::new(
        MessageRole::User,
        "start",
        MessageMeta::new(MessageType::UserInput),
    );
    let mut transcript = vec![user_start];
    let mut context = ContextManager::new(Box::new(NoCompact), 4096, None, None, None);
    let validator = forge_guardrails::guardrails::ResponseValidator::new(
        vec!["respond".to_string()],
        false,
        None,
    );
    let mut error_tracker = forge_guardrails::ErrorTracker::new(2, 2);
    let mut call_counter = 0;
    let tool_specs = vec![respond_spec()];

    let outcome = forge_guardrails::run_inference(
        &mut transcript,
        &client,
        &mut context,
        &validator,
        &mut error_tracker,
        &tool_specs,
        &mut call_counter,
        0,
        "",
        Some(10),
        true,
        None,
        None,
    )
    .await;

    assert!(matches!(outcome, Err(ForgeError::Stream(_))));
    let stream_calls = client.call_count.load(AtomicOrdering::SeqCst);
    assert_eq!(stream_calls, 1);
}
/// Only the tier-3 step nudge contains the word "STOP"; tiers 1 and 2 do not.
#[test]
fn ts013_escalating_step_nudge_tiers() {
    use forge_guardrails::nudges::step_nudge;

    let polite = step_nudge("respond", &["search"], 1);
    assert!(!polite.contains("STOP"), "Tier 1 should be polite");

    let direct = step_nudge("respond", &["search"], 2);
    assert!(!direct.contains("STOP"), "Tier 2 should be direct");

    let aggressive = step_nudge("respond", &["search"], 3);
    assert!(aggressive.contains("STOP"), "Tier 3 should be aggressive");
}
/// A workflow may declare several terminal tools; finishing through
/// `summarize` returns that tool's output.
#[tokio::test]
async fn ts014_multiple_terminal_tools() {
    /// Echoes the received arguments back in a canned search result.
    fn search_fn(args: Vec<String>) -> Result<String, ToolResolutionError> {
        Ok(format!("found: {:?}", args))
    }

    /// Returns the first `message=` argument value, or a fixed fallback.
    fn respond_fn(args: Vec<String>) -> Result<String, ToolResolutionError> {
        let message = args
            .iter()
            .find_map(|arg| arg.strip_prefix("message="))
            .map(str::to_owned)
            .unwrap_or_else(|| "responded".to_string());
        Ok(message)
    }

    /// Returns the first `summary=` argument value, or a fixed fallback.
    fn summarize_fn(args: Vec<String>) -> Result<String, ToolResolutionError> {
        let summary = args
            .iter()
            .find_map(|arg| arg.strip_prefix("summary="))
            .map(str::to_owned)
            .unwrap_or_else(|| "summarized".to_string());
        Ok(summary)
    }

    let search_schema = serde_json::json!({
        "type": "object", "properties": {"query": {"type": "string"}}
    });
    let search_spec =
        ToolSpec::from_json_schema("search", "Search", &search_schema).expect("valid");
    let summarize_schema = serde_json::json!({
        "type": "object", "properties": {"summary": {"type": "string"}}
    });
    let summarize_spec =
        ToolSpec::from_json_schema("summarize", "Summarize", &summarize_schema).expect("valid");

    let mut tools: IndexMap<String, ToolDef> = IndexMap::new();
    tools.insert("search".to_string(), ToolDef::new(search_spec, search_fn));
    tools.insert(
        "respond".to_string(),
        ToolDef::new(respond_spec(), respond_fn),
    );
    tools.insert(
        "summarize".to_string(),
        ToolDef::new(summarize_spec, summarize_fn),
    );

    let terminals =
        TerminalToolInput::Multiple(vec!["respond".to_string(), "summarize".to_string()]);
    let workflow = Workflow::new(
        "multi_terminal",
        "multi terminal test",
        tools,
        vec!["search".to_string()],
        terminals,
        "Helper.",
    )
    .expect("valid");

    let mut search_args = IndexMap::new();
    search_args.insert("query".to_string(), Value::from("test"));
    let mut summarize_args = IndexMap::new();
    summarize_args.insert("summary".to_string(), Value::from("summary result"));
    let client = mock_client(vec![
        make_tool_call("search", search_args),
        make_tool_call("summarize", summarize_args),
    ]);
    let runner = make_runner(client);

    let outcome = runner
        .run(&workflow, "search and summarize", None, None, None)
        .await;

    assert!(outcome.is_ok(), "Expected Ok, got {:?}", outcome);
    assert_eq!(outcome.expect("ok"), Value::from("summary result"));
}
/// Even with no required steps, a batch that pairs the terminal tool with
/// another tool never runs that other tool; the later lone `respond` call
/// finishes the workflow.
#[tokio::test]
async fn terminal_mixed_batch_rejected_when_no_steps_pending() {
    /// Returns the first `message=` argument value, or a fixed fallback.
    fn respond_fn(args: Vec<String>) -> Result<String, ToolResolutionError> {
        let message = args
            .iter()
            .find_map(|arg| arg.strip_prefix("message="))
            .map(str::to_owned)
            .unwrap_or_else(|| "responded".to_string());
        Ok(message)
    }

    // Counts how often the side-effecting tool actually executes.
    let side_effects = Arc::new(AtomicI32::new(0));
    let executions = Arc::clone(&side_effects);
    let side_effect_tool = move |_args: Vec<String>| -> Result<String, ToolResolutionError> {
        executions.fetch_add(1, AtomicOrdering::SeqCst);
        Ok("side effect executed".to_string())
    };

    let side_effect_schema = serde_json::json!({
        "type": "object", "properties": {}
    });
    let side_effect_spec =
        ToolSpec::from_json_schema("side_effect", "Side effect", &side_effect_schema)
            .expect("valid");

    let mut tools: IndexMap<String, ToolDef> = IndexMap::new();
    tools.insert(
        "respond".to_string(),
        ToolDef::new(respond_spec(), respond_fn),
    );
    tools.insert(
        "side_effect".to_string(),
        ToolDef::new(side_effect_spec, side_effect_tool),
    );
    let workflow = Workflow::new(
        "mixed_terminal",
        "mixed terminal test",
        tools,
        vec![],
        TerminalToolInput::Single("respond".to_string()),
        "Helper.",
    )
    .expect("valid");

    let mut terminal_args = IndexMap::new();
    terminal_args.insert("message".to_string(), Value::from("done"));
    let mixed_batch = LLMResponse::ToolCalls(vec![
        forge_guardrails::ToolCall::new("respond", terminal_args.clone()),
        forge_guardrails::ToolCall::new("side_effect", IndexMap::new()),
    ]);
    let scripted = vec![mixed_batch, make_tool_call("respond", terminal_args)];
    let runner = make_runner(mock_client(scripted));

    let outcome = runner.run(&workflow, "finish", None, None, None).await;

    assert_eq!(outcome.expect("ok"), Value::from("done"));
    assert_eq!(side_effects.load(AtomicOrdering::SeqCst), 0);
}
/// A JSON object embedded in plain text that names an available tool is
/// rescued as exactly one tool call for that tool.
#[test]
fn ts015_rescue_json_from_text() {
    let text = r#"{"tool": "search", "args": {"query": "test"}}"#;
    let available = vec!["search", "respond"];

    let rescued = forge_guardrails::rescue_tool_call(text, &available);

    assert_eq!(rescued.len(), 1);
    let only_call = &rescued[0];
    assert_eq!(only_call.tool, "search");
}
/// A runner built with `false` for its boolean option and a custom string
/// as its last argument still recovers from a text reply and completes.
// NOTE(review): the last argument is presumably the custom retry-nudge text,
// going by the test name — confirm against `WorkflowRunner::new`.
#[tokio::test]
async fn ts016_custom_retry_nudge_string() {
    let nudge_text = "Please use a tool call!";

    let mut search_args = IndexMap::new();
    search_args.insert("query".to_string(), Value::from("test"));
    let mut respond_args = IndexMap::new();
    respond_args.insert("message".to_string(), Value::from("done"));
    let scripted = vec![
        make_text_response("I will help you"),
        make_tool_call("search", search_args),
        make_tool_call("respond", respond_args),
    ];

    let runner = Arc::new(WorkflowRunner::new(
        Arc::new(mock_client(scripted)),
        make_context_manager(),
        10,
        3,
        2,
        false,
        None,
        None,
        false,
        Some(nudge_text.to_string()),
    ));
    let workflow = make_simple_workflow();

    let outcome = runner.run(&workflow, "do stuff", None, None, None).await;

    assert!(outcome.is_ok(), "Expected Ok, got {:?}", outcome);
}
/// The tier-1 step nudge names the missing step and the tier-3 nudge
/// contains "STOP".
// NOTE(review): the test name refers to slot-worker task priority, but the
// body only exercises `step_nudge` — confirm whether the name is stale.
#[test]
fn slot_worker_task_priority_ordering() {
    use forge_guardrails::nudges::step_nudge;

    let first_tier = step_nudge("respond", &["search"], 1);
    let third_tier = step_nudge("respond", &["search"], 3);

    assert!(first_tier.contains("search"));
    assert!(third_tier.contains("STOP"));
}
/// Tool-call ids are `call_` followed by the counter zero-padded to nine
/// digits.
#[test]
fn tool_call_id_monotonic() {
    let expectations = [
        (0, "call_000000000"),
        (1, "call_000000001"),
        (100, "call_000000100"),
    ];
    for (counter, expected) in expectations {
        assert_eq!(format_tool_call_id(counter), expected);
    }
}
/// Serializing an empty transcript yields an empty wire payload.
#[test]
fn fold_and_serialize_empty() {
    let wire = fold_and_serialize(&[], "openai");
    assert!(wire.is_empty());
}
/// Two consecutive reasoning/tool-call pairs fold into two wire messages,
/// each keeping the reasoning text of its own pair.
#[test]
fn fold_and_serialize_multiple_pairs() {
    // Local builders for the two message shapes used below.
    let reasoning = |text: &'static str| {
        Message::new(
            MessageRole::Assistant,
            text,
            MessageMeta::new(MessageType::Reasoning),
        )
    };
    let tool_call = |tool: &'static str, call_id: &'static str| {
        Message::new(
            MessageRole::Assistant,
            "",
            MessageMeta::new(MessageType::ToolCall),
        )
        .with_tool_calls(vec![ToolCallInfo::new(
            tool,
            Some(IndexMap::new()),
            call_id,
        )])
    };

    let first_reasoning = reasoning("think 1");
    let first_call = tool_call("a", "tc_0001");
    let second_reasoning = reasoning("think 2");
    let second_call = tool_call("b", "tc_0002");

    let wire = fold_and_serialize(
        &[first_reasoning, first_call, second_reasoning, second_call],
        "openai",
    );

    assert_eq!(wire.len(), 2);
    assert_eq!(wire[0]["content"], "think 1");
    assert_eq!(wire[1]["content"], "think 2");
}
#[tokio::test]
async fn test_protocol_pairing_invariant() {
let mut args1 = IndexMap::new();
args1.insert(
"query".to_string(),
Value::String("pairing test".to_string()),
);
let mut args2 = IndexMap::new();
args2.insert("message".to_string(), Value::String("done".to_string()));
let collected = Arc::new(std::sync::Mutex::new(Vec::<Message>::new()));
let collected_clone = collected.clone();
let cb: OnMessageFn = Box::new(move |msg: &Message| {
let mut guard = collected_clone.lock().unwrap();
guard.push(msg.clone());
});
let client = mock_client(vec![
make_tool_call("search", args1),
make_tool_call("respond", args2),
]);
let context_mgr = make_context_manager();
let runner = WorkflowRunner::new(
Arc::new(client),
context_mgr.clone(),
10,
3,
2,
false,
None,
Some(cb),
true,
None,
);
let workflow = make_simple_workflow();
let result = runner.run(&workflow, "start", None, None, None).await;
assert!(result.is_ok());
let msgs = collected.lock().unwrap();
let mut tool_calls = Vec::new();
let mut tool_results = Vec::new();
for msg in msgs.iter() {
if msg.role == MessageRole::Assistant && msg.tool_calls.is_some() {
if let Some(ref calls) = msg.tool_calls {
for tc in calls {
tool_calls.push(tc.call_id.clone());
}
}
} else if msg.role == MessageRole::Tool {
if let Some(ref call_id) = msg.tool_call_id {
tool_results.push(call_id.clone());
}
}
}
assert_eq!(tool_calls.len(), 2);
assert_eq!(tool_results.len(), 2);
assert_eq!(tool_calls[0], "call_000000000");
assert_eq!(tool_calls, tool_results, "Each assistant tool call ID must match the corresponding tool result tool_call_id exactly.");
}
#[tokio::test]
async fn test_step_blocked_transcript() {
let mut args1 = IndexMap::new();
args1.insert(
"message".to_string(),
Value::String("premature terminal call".to_string()),
);
let mut args2 = IndexMap::new();
args2.insert(
"query".to_string(),
Value::String("valid required step".to_string()),
);
let mut args3 = IndexMap::new();
args3.insert(
"message".to_string(),
Value::String("terminal after step".to_string()),
);
let collected = Arc::new(std::sync::Mutex::new(Vec::<Message>::new()));
let collected_clone = collected.clone();
let cb: OnMessageFn = Box::new(move |msg: &Message| {
let mut guard = collected_clone.lock().unwrap();
guard.push(msg.clone());
});
let client = mock_client(vec![
make_tool_call("respond", args1),
make_tool_call("search", args2),
make_tool_call("respond", args3),
]);
let context_mgr = make_context_manager();
let runner = WorkflowRunner::new(
Arc::new(client),
context_mgr,
10,
3,
2,
false,
None,
Some(cb),
true,
None,
);
let workflow = make_simple_workflow();
let result = runner.run(&workflow, "start", None, None, None).await;
assert!(result.is_ok());
let msgs = collected.lock().unwrap();
let mut step_blocked_tool_call_ids = Vec::new();
let mut step_blocked_tool_result_ids = Vec::new();
for msg in msgs.iter() {
if msg.metadata.msg_type == MessageType::ToolCall {
if let Some(ref calls) = msg.tool_calls {
for tc in calls {
if tc.name == "respond" {
step_blocked_tool_call_ids.push(tc.call_id.clone());
}
}
}
} else if msg.metadata.msg_type == MessageType::StepNudge {
if let Some(ref name) = msg.tool_name {
if name == "respond" && msg.content.contains("[StepEnforcementError]") {
step_blocked_tool_result_ids.push(msg.tool_call_id.clone().unwrap_or_default());
}
}
}
}
assert!(
!step_blocked_tool_call_ids.is_empty(),
"Must contain a blocked respond tool call"
);
assert_eq!(
step_blocked_tool_call_ids.first(),
step_blocked_tool_result_ids.first(),
"Blocked tool call must have a matching error tool result message with same call ID."
);
}