from __future__ import annotations
import asyncio
from collections.abc import AsyncIterator
from unittest.mock import MagicMock
import pytest
from forge.clients.base import ChunkType, LLMClient, StreamChunk
from forge.context.manager import ContextManager
from forge.context.strategies import NoCompact
from forge.core.messages import Message, MessageMeta, MessageRole, MessageType, ToolCallInfo
from forge.core.runner import WorkflowRunner
from pydantic import BaseModel
from forge.core.workflow import (
LLMResponse,
TextResponse,
ToolCall,
ToolDef,
ToolSpec,
Workflow,
)
from forge.errors import MaxIterationsError, PrerequisiteError, StepEnforcementError, StreamError, ToolCallError, ToolExecutionError, ToolResolutionError, WorkflowCancelledError
# Parameter schema with no fields; used as the ``parameters`` model of every
# ToolSpec built by ``_make_tool`` in this module.
class EmptyParams(BaseModel):
    ...
class MockClient:
    """Scripted stand-in for an LLM client.

    Replays a fixed list of responses, one per ``send`` / ``send_stream``
    call, and records the ``(messages, tools)`` pair passed to each call so
    tests can inspect what was sent.
    """

    def __init__(self, responses: list[ToolCall | TextResponse]):
        # Copy so later mutation of the caller's list cannot change the script.
        self.responses = list(responses)
        self._call_index = 0
        self.send_calls: list[tuple[list[dict], list[ToolSpec] | None]] = []
        self.send_stream_calls: list[tuple[list[dict], list[ToolSpec] | None]] = []

    def _next(self) -> LLMResponse:
        """Return the next scripted response and advance the cursor.

        A bare ``ToolCall`` is wrapped in a one-element list; any other
        response is returned unchanged.

        Raises:
            IndexError: if every scripted response has already been consumed.
        """
        if self._call_index >= len(self.responses):
            # Same exception type as plain list indexing, but with a message
            # that points at the real problem: the test under-scripted the mock.
            raise IndexError(
                f"MockClient script exhausted: call #{self._call_index + 1} "
                f"requested but only {len(self.responses)} response(s) were scripted"
            )
        resp = self.responses[self._call_index]
        self._call_index += 1
        if isinstance(resp, ToolCall):
            return [resp]
        return resp

    async def send(
        self,
        messages: list[dict[str, str]],
        tools: list[ToolSpec] | None = None,
        sampling: dict[str, object] | None = None,
        passthrough: dict[str, object] | None = None,
        inbound_anthropic_body: dict[str, object] | None = None,
    ) -> LLMResponse:
        """Record ``(messages, tools)`` and return the next scripted response.

        ``sampling``, ``passthrough`` and ``inbound_anthropic_body`` are
        accepted but ignored.
        """
        self.send_calls.append((messages, tools))
        return self._next()

    async def send_stream(
        self,
        messages: list[dict[str, str]],
        tools: list[ToolSpec] | None = None,
        sampling: dict[str, object] | None = None,
        passthrough: dict[str, object] | None = None,
        inbound_anthropic_body: dict[str, object] | None = None,
    ) -> AsyncIterator[StreamChunk]:
        """Record the call, then yield one TEXT_DELTA chunk and one FINAL chunk.

        The FINAL chunk carries the next scripted response. The extra keyword
        parameters are accepted but ignored.
        """
        self.send_stream_calls.append((messages, tools))
        resp = self._next()
        yield StreamChunk(type=ChunkType.TEXT_DELTA, content="partial...")
        yield StreamChunk(type=ChunkType.FINAL, response=resp)

    async def get_context_length(self) -> int | None:
        """Always report an unknown context length (``None``)."""
        return None
def _make_tool(name: str, fn=None) -> ToolDef:
    """Build a ``ToolDef`` called *name* whose parameter model is ``EmptyParams``.

    When *fn* is omitted, the tool's callable accepts any keyword arguments
    and returns the string ``"<name>_result"``.
    """
    handler = fn if fn is not None else (lambda **kwargs: f"{name}_result")
    tool_spec = ToolSpec(name=name, description=f"Tool {name}", parameters=EmptyParams)
    return ToolDef(spec=tool_spec, callable=handler)
def _make_workflow(
    tools: dict[str, ToolDef] | None = None,
    required_steps: list[str] | None = None,
    terminal_tool: str = "submit",
) -> Workflow:
    """Build the ``test_wf`` workflow used throughout this module.

    Defaults: a ``fetch`` tool and a ``submit`` tool, ``fetch`` as the only
    required step, and ``submit`` as the terminal tool. The system prompt
    template expects a ``role`` variable.
    """
    if tools is None:
        tools = {tool_name: _make_tool(tool_name) for tool_name in ("fetch", "submit")}
    # Only ``None`` selects the default; an explicit empty list is kept as-is.
    steps = ["fetch"] if required_steps is None else required_steps
    return Workflow(
        name="test_wf",
        description="A test workflow",
        tools=tools,
        required_steps=steps,
        terminal_tool=terminal_tool,
        system_prompt_template="You are a {role}.",
    )
def _make_runner(
    client: MockClient,
    max_iterations: int = 10,
    max_retries_per_step: int = 3,
    max_tool_errors: int = 2,
    stream: bool = False,
    on_chunk=None,
    budget_tokens: int = 100_000,
) -> WorkflowRunner:
    """Build a ``WorkflowRunner`` around *client* with a ``NoCompact`` context manager.

    *budget_tokens* goes to the ``ContextManager``; every other argument is
    forwarded to ``WorkflowRunner`` under the same name.
    """
    context = ContextManager(strategy=NoCompact(), budget_tokens=budget_tokens)
    runner_options = {
        "max_iterations": max_iterations,
        "max_retries_per_step": max_retries_per_step,
        "max_tool_errors": max_tool_errors,
        "stream": stream,
        "on_chunk": on_chunk,
    }
    return WorkflowRunner(client=client, context_manager=context, **runner_options)
class TestHappyPath:
    """Runs in which every scripted response is a valid, in-order tool call."""

    @pytest.mark.asyncio
    async def test_simple_workflow(self):
        """One required step then the terminal tool yields the terminal tool's result."""
        script = [
            ToolCall(tool="fetch", args={}),
            ToolCall(tool="submit", args={}),
        ]
        wf_runner = _make_runner(MockClient(script))
        outcome = await wf_runner.run(
            _make_workflow(), "do something", prompt_vars={"role": "tester"}
        )
        assert outcome == "submit_result"

    @pytest.mark.asyncio
    async def test_multi_step_workflow(self):
        """Two required steps followed by the terminal tool complete the run."""
        tool_names = ("step_a", "step_b", "submit")
        workflow = _make_workflow(
            tools={tool_name: _make_tool(tool_name) for tool_name in tool_names},
            required_steps=["step_a", "step_b"],
        )
        script = [
            ToolCall(tool="step_a", args={}),
            ToolCall(tool="step_b", args={}),
            ToolCall(tool="submit", args={}),
        ]
        wf_runner = _make_runner(MockClient(script))
        outcome = await wf_runner.run(workflow, "go", prompt_vars={"role": "agent"})
        assert outcome == "submit_result"

    @pytest.mark.asyncio
    async def test_tool_args_forwarded(self):
        """Arguments of a tool call reach the tool callable as keyword arguments."""
        seen_kwargs = {}

        def capture_tool(**kwargs):
            seen_kwargs.update(kwargs)
            return "captured"

        workflow = _make_workflow(
            tools={
                "fetch": _make_tool("fetch", fn=capture_tool),
                "submit": _make_tool("submit"),
            },
            required_steps=["fetch"],
        )
        script = [
            ToolCall(tool="fetch", args={"key": "value", "count": 42}),
            ToolCall(tool="submit", args={}),
        ]
        wf_runner = _make_runner(MockClient(script))
        await wf_runner.run(workflow, "go", prompt_vars={"role": "agent"})
        assert seen_kwargs == {"key": "value", "count": 42}
class TestRetryLogic:
    """Runs in which the scripted client returns plain text instead of tool calls."""

    @pytest.mark.asyncio
    async def test_text_response_triggers_retry(self):
        """A single text response followed by valid tool calls still completes."""
        script = [
            TextResponse(content="I don't know"),
            ToolCall(tool="fetch", args={}),
            ToolCall(tool="submit", args={}),
        ]
        wf_runner = _make_runner(MockClient(script))
        outcome = await wf_runner.run(_make_workflow(), "go", prompt_vars={"role": "agent"})
        assert outcome == "submit_result"

    @pytest.mark.asyncio
    async def test_retries_exhausted_raises_tool_call_error(self):
        """Four text responses with three retries allowed raise ``ToolCallError``."""
        script = [
            TextResponse(content=text)
            for text in ("nope1", "nope2", "nope3", "final_nope")
        ]
        wf_runner = _make_runner(MockClient(script), max_retries_per_step=3)
        with pytest.raises(ToolCallError) as raised:
            await wf_runner.run(_make_workflow(), "go", prompt_vars={"role": "agent"})
        # The error is expected to carry the last text the client returned.
        assert raised.value.raw_response == "final_nope"

    @pytest.mark.asyncio
    async def test_retry_counter_resets_on_tool_call(self):
        """Two text failures before each of two steps do not exhaust three retries."""
        tool_names = ("step_a", "step_b", "submit")
        workflow = _make_workflow(
            tools={tool_name: _make_tool(tool_name) for tool_name in tool_names},
            required_steps=["step_a", "step_b"],
        )
        script = [
            TextResponse(content="fail1"),
            TextResponse(content="fail2"),
            ToolCall(tool="step_a", args={}),
            TextResponse(content="fail3"),
            TextResponse(content="fail4"),
            ToolCall(tool="step_b", args={}),
            ToolCall(tool="submit", args={}),
        ]
        wf_runner = _make_runner(MockClient(script), max_retries_per_step=3)
        outcome = await wf_runner.run(workflow, "go", prompt_vars={"role": "agent"})
        assert outcome == "submit_result"

    @pytest.mark.asyncio
    async def test_retries_consume_iterations(self):
        """With ``max_iterations=3`` two text retries leave no room to finish."""
        script = [
            TextResponse(content="fail1"),
            TextResponse(content="fail2"),
            ToolCall(tool="fetch", args={}),
            ToolCall(tool="submit", args={}),
        ]
        wf_runner = _make_runner(
            MockClient(script), max_iterations=3, max_retries_per_step=3
        )
        with pytest.raises(MaxIterationsError) as raised:
            await wf_runner.run(_make_workflow(), "go", prompt_vars={"role": "agent"})
        assert raised.value.iterations == 3

    @pytest.mark.asyncio
    async def test_max_iterations_bounds_total_llm_calls(self):
        """``max_iterations=3`` caps the client at three ``send`` calls."""
        llm = MockClient([TextResponse(content="nope") for _ in range(3)])
        wf_runner = _make_runner(llm, max_iterations=3, max_retries_per_step=5)
        with pytest.raises(MaxIterationsError):
            await wf_runner.run(_make_workflow(), "go", prompt_vars={"role": "agent"})
        assert len(llm.send_calls) == 3
class TestStepEnforcement:
    """Runs in which the terminal tool is requested before the required steps."""

    @pytest.mark.asyncio
    async def test_terminal_before_required_steps_injects_nudge(self):
        """A premature ``submit`` is not executed; only the later one runs."""
        submit_invocations = 0

        def counting_submit(**kwargs):
            nonlocal submit_invocations
            submit_invocations += 1
            return "submitted"

        workflow = _make_workflow(
            tools={
                "fetch": _make_tool("fetch"),
                "submit": _make_tool("submit", fn=counting_submit),
            },
            required_steps=["fetch"],
        )
        script = [
            ToolCall(tool="submit", args={}),
            ToolCall(tool="fetch", args={}),
            ToolCall(tool="submit", args={}),
        ]
        wf_runner = _make_runner(MockClient(script))
        outcome = await wf_runner.run(workflow, "go", prompt_vars={"role": "agent"})
        assert outcome == "submitted"
        assert submit_invocations == 1

    @pytest.mark.asyncio
    async def test_premature_terminal_resets_retry_counter(self):
        """Two text failures on each side of a premature ``submit`` fit within three retries."""
        script = [
            TextResponse(content="garbage"),
            TextResponse(content="garbage"),
            ToolCall(tool="submit", args={}),
            TextResponse(content="garbage"),
            TextResponse(content="garbage"),
            ToolCall(tool="fetch", args={}),
            ToolCall(tool="submit", args={}),
        ]
        wf_runner = _make_runner(MockClient(script), max_retries_per_step=3)
        outcome = await wf_runner.run(_make_workflow(), "go", prompt_vars={"role": "agent"})
        assert outcome == "submit_result"

    @pytest.mark.asyncio
    async def test_step_nudge_allows_recovery(self):
        """After a premature ``submit`` the run completes once ``fetch`` is called."""
        script = [
            ToolCall(tool="submit", args={}),
            ToolCall(tool="fetch", args={}),
            ToolCall(tool="submit", args={}),
        ]
        wf_runner = _make_runner(MockClient(script))
        outcome = await wf_runner.run(_make_workflow(), "go", prompt_vars={"role": "agent"})
        assert outcome == "submit_result"

    @pytest.mark.asyncio
    async def test_escalating_nudge_tiers(self):
        """Three premature ``submit`` calls produce three distinct step nudges."""
        script = [
            ToolCall(tool="submit", args={}),
            ToolCall(tool="submit", args={}),
            ToolCall(tool="submit", args={}),
            ToolCall(tool="fetch", args={}),
            ToolCall(tool="submit", args={}),
        ]
        context = ContextManager(strategy=NoCompact(), budget_tokens=100_000)
        real_compact = context.maybe_compact
        snapshots: list[list[Message]] = []

        def recording_compact(messages, step_index=0, step_hint=""):
            # Snapshot the history on every call, then defer to the real method.
            snapshots.append(list(messages))
            return real_compact(messages, step_index=step_index, step_hint=step_hint)

        context.maybe_compact = recording_compact
        wf_runner = WorkflowRunner(client=MockClient(script), context_manager=context)
        outcome = await wf_runner.run(_make_workflow(), "go", prompt_vars={"role": "agent"})
        assert outcome == "submit_result"

        # Collect each step nudge once, in first-seen order across snapshots.
        nudges = []
        for snapshot in snapshots:
            for message in snapshot:
                if message.metadata.type == MessageType.STEP_NUDGE:
                    if message not in nudges:
                        nudges.append(message)

        assert len(nudges) == 3
        first, second, third = nudges
        assert "cannot call submit yet" in first.content.lower()
        assert "must call one of these tools now" in second.content.lower()
        assert "STOP" in third.content

    @pytest.mark.asyncio
    async def test_premature_terminal_exhausted_raises_step_enforcement_error(self):
        """Four premature ``submit`` calls in a row raise ``StepEnforcementError``."""
        script = [ToolCall(tool="submit", args={}) for _ in range(4)]
        wf_runner = _make_runner(MockClient(script))
        with pytest.raises(StepEnforcementError) as raised:
            await wf_runner.run(_make_workflow(), "go", prompt_vars={"role": "agent"})
        error = raised.value
        assert error.terminal_tool == "submit"
        assert error.attempts == 4
        assert error.pending_steps == ["fetch"]

    @pytest.mark.asyncio
    async def test_premature_terminal_counter_resets_on_progress(self):
        """Three premature ``submit`` calls, then ``fetch``, then ``submit`` succeeds."""
        script = [
            ToolCall(tool="submit", args={}),
            ToolCall(tool="submit", args={}),
            ToolCall(tool="submit", args={}),
            ToolCall(tool="fetch", args={}),
            ToolCall(tool="submit", args={}),
            ToolCall(tool="submit", args={}),
            ToolCall(tool="submit", args={}),
            ToolCall(tool="submit", args={}),
        ]
        wf_runner = _make_runner(MockClient(script), max_iterations=10)
        outcome = await wf_runner.run(_make_workflow(), "go", prompt_vars={"role": "agent"})
        assert outcome == "submit_result"
class TestErrorHandling:
    """Runs involving unknown tool names, failing tool callables and iteration limits."""

    @pytest.mark.asyncio
    async def test_unknown_tool_nudges_then_recovers(self):
        """One call to an unknown tool does not prevent the run from completing."""
        script = [
            ToolCall(tool="get_pricing", args={}),
            ToolCall(tool="fetch", args={}),
            ToolCall(tool="submit", args={}),
        ]
        wf_runner = _make_runner(MockClient(script))
        outcome = await wf_runner.run(_make_workflow(), "go", prompt_vars={"role": "agent"})
        assert outcome == "submit_result"

    @pytest.mark.asyncio
    async def test_unknown_tool_nudge_lists_available_tools(self):
        """The message after an unknown tool call names the bad tool and the real ones."""
        llm = MockClient([
            ToolCall(tool="nonexistent", args={}),
            ToolCall(tool="fetch", args={}),
            ToolCall(tool="submit", args={}),
        ])
        wf_runner = _make_runner(llm)
        await wf_runner.run(_make_workflow(), "go", prompt_vars={"role": "agent"})

        # Last message of the second request is the feedback for the bad call.
        messages_sent, _tools = llm.send_calls[1]
        feedback = messages_sent[-1]["content"]
        for fragment in ("nonexistent", "fetch", "submit", "does not exist"):
            assert fragment in feedback

    @pytest.mark.asyncio
    async def test_unknown_tool_consumes_iteration(self):
        """With ``max_iterations=2`` an unknown tool call leaves no room to finish."""
        script = [
            ToolCall(tool="bad_name", args={}),
            ToolCall(tool="fetch", args={}),
            ToolCall(tool="submit", args={}),
        ]
        wf_runner = _make_runner(MockClient(script), max_iterations=2)
        with pytest.raises(MaxIterationsError):
            await wf_runner.run(_make_workflow(), "go", prompt_vars={"role": "agent"})

    @pytest.mark.asyncio
    async def test_unknown_tool_exhausts_retries(self):
        """Four unknown tool calls with three retries allowed raise ``ToolCallError``."""
        script = [ToolCall(tool=f"bad{number}", args={}) for number in range(1, 5)]
        wf_runner = _make_runner(MockClient(script), max_retries_per_step=3)
        with pytest.raises(ToolCallError, match="Retries exhausted"):
            await wf_runner.run(_make_workflow(), "go", prompt_vars={"role": "agent"})

    @pytest.mark.asyncio
    async def test_unknown_tool_and_text_response_share_retry_counter(self):
        """Alternating text and unknown-tool failures exhaust retries together."""
        script = [
            TextResponse(content="garbage"),
            ToolCall(tool="nonexistent", args={}),
            TextResponse(content="more garbage"),
            ToolCall(tool="still_wrong", args={}),
        ]
        wf_runner = _make_runner(MockClient(script), max_retries_per_step=3)
        with pytest.raises(ToolCallError, match="Retries exhausted"):
            await wf_runner.run(_make_workflow(), "go", prompt_vars={"role": "agent"})

    @pytest.mark.asyncio
    async def test_tool_error_feeds_back_then_recovers(self):
        """A tool that raises once is called again and the run completes."""
        invocations = 0

        def flaky_fetch(**kwargs):
            nonlocal invocations
            invocations += 1
            if invocations == 1:
                raise TypeError("expected int, got str for argument 'count'")
            return "fetch_result"

        workflow = _make_workflow(
            tools={
                "fetch": _make_tool("fetch", fn=flaky_fetch),
                "submit": _make_tool("submit"),
            },
            required_steps=["fetch"],
        )
        script = [
            ToolCall(tool="fetch", args={"count": "five"}),
            ToolCall(tool="fetch", args={"count": 5}),
            ToolCall(tool="submit", args={}),
        ]
        wf_runner = _make_runner(MockClient(script))
        outcome = await wf_runner.run(workflow, "go", prompt_vars={"role": "agent"})
        assert outcome == "submit_result"
        assert invocations == 2

    @pytest.mark.asyncio
    async def test_tool_error_message_contains_exception_info(self):
        """The feedback after a tool failure includes the exception type and message."""

        def bad_tool(**kwargs):
            raise ValueError("something went wrong")

        workflow = _make_workflow(
            tools={
                "fetch": _make_tool("fetch", fn=bad_tool),
                "submit": _make_tool("submit"),
            },
            required_steps=["fetch"],
        )
        llm = MockClient([ToolCall(tool="fetch", args={}) for _ in range(3)])
        wf_runner = _make_runner(llm, max_tool_errors=2)
        with pytest.raises(ToolExecutionError):
            await wf_runner.run(workflow, "go", prompt_vars={"role": "agent"})

        messages_sent, _tools = llm.send_calls[1]
        feedback = messages_sent[-1]["content"]
        assert "[ToolError]" in feedback
        assert "ValueError" in feedback
        assert "something went wrong" in feedback

    @pytest.mark.asyncio
    async def test_tool_errors_exhaust_max_tool_errors(self):
        """A tool that always raises ends the run with ``ToolExecutionError``."""

        def bad_tool(**kwargs):
            raise ValueError("always fails")

        workflow = _make_workflow(
            tools={
                "fetch": _make_tool("fetch", fn=bad_tool),
                "submit": _make_tool("submit"),
            },
            required_steps=["fetch"],
        )
        script = [ToolCall(tool="fetch", args={}) for _ in range(3)]
        wf_runner = _make_runner(MockClient(script), max_tool_errors=2)
        with pytest.raises(ToolExecutionError) as raised:
            await wf_runner.run(workflow, "go", prompt_vars={"role": "agent"})
        assert raised.value.tool_name == "fetch"
        assert isinstance(raised.value.cause, ValueError)

    @pytest.mark.asyncio
    async def test_tool_error_counter_resets_on_success(self):
        """Failures on calls 1 and 3, separated by a success, pass with ``max_tool_errors=1``."""
        invocations = 0

        def sometimes_fails(**kwargs):
            nonlocal invocations
            invocations += 1
            if invocations in (1, 3):
                raise ValueError(f"fail #{invocations}")
            return "ok"

        workflow = _make_workflow(
            tools={
                "fetch": _make_tool("fetch", fn=sometimes_fails),
                "submit": _make_tool("submit"),
            },
            required_steps=["fetch"],
        )
        script = [ToolCall(tool="fetch", args={}) for _ in range(4)]
        script.append(ToolCall(tool="submit", args={}))
        wf_runner = _make_runner(MockClient(script), max_tool_errors=1)
        outcome = await wf_runner.run(workflow, "go", prompt_vars={"role": "agent"})
        assert outcome == "submit_result"

    @pytest.mark.asyncio
    async def test_terminal_tool_error_recovery(self):
        """A terminal tool that raises once can be called again successfully."""
        invocations = 0

        def flaky_submit(**kwargs):
            nonlocal invocations
            invocations += 1
            if invocations == 1:
                raise TypeError("bad args")
            return "submitted"

        workflow = _make_workflow(
            tools={
                "fetch": _make_tool("fetch"),
                "submit": _make_tool("submit", fn=flaky_submit),
            },
            required_steps=["fetch"],
        )
        script = [
            ToolCall(tool="fetch", args={}),
            ToolCall(tool="submit", args={"bad": True}),
            ToolCall(tool="submit", args={}),
        ]
        wf_runner = _make_runner(MockClient(script))
        outcome = await wf_runner.run(workflow, "go", prompt_vars={"role": "agent"})
        assert outcome == "submitted"

    @pytest.mark.asyncio
    async def test_failed_required_step_not_recorded(self):
        """A ``fetch`` that raised must be called again before ``submit`` completes the run."""
        invocations = 0

        def flaky_fetch(**kwargs):
            nonlocal invocations
            invocations += 1
            if invocations == 1:
                raise ValueError("file not found")
            return "fetch_result"

        workflow = _make_workflow(
            tools={
                "fetch": _make_tool("fetch", fn=flaky_fetch),
                "submit": _make_tool("submit"),
            },
            required_steps=["fetch"],
        )
        script = [
            ToolCall(tool="fetch", args={}),
            ToolCall(tool="submit", args={}),
            ToolCall(tool="fetch", args={}),
            ToolCall(tool="submit", args={}),
        ]
        wf_runner = _make_runner(MockClient(script))
        outcome = await wf_runner.run(workflow, "go", prompt_vars={"role": "agent"})
        assert outcome == "submit_result"
        assert invocations == 2

    @pytest.mark.asyncio
    async def test_tool_error_consumes_iteration(self):
        """With ``max_iterations=2`` a tool that always raises hits the iteration limit."""

        def bad_tool(**kwargs):
            raise ValueError("fails")

        workflow = _make_workflow(
            tools={
                "fetch": _make_tool("fetch", fn=bad_tool),
                "submit": _make_tool("submit"),
            },
            required_steps=["fetch"],
        )
        script = [ToolCall(tool="fetch", args={}) for _ in range(3)]
        wf_runner = _make_runner(MockClient(script), max_iterations=2, max_tool_errors=5)
        with pytest.raises(MaxIterationsError):
            await wf_runner.run(workflow, "go", prompt_vars={"role": "agent"})

    @pytest.mark.asyncio
    async def test_max_iterations_exceeded(self):
        """Repeated ``fetch`` calls without ``submit`` raise ``MaxIterationsError``."""
        script = [ToolCall(tool="fetch", args={}) for _ in range(5)]
        wf_runner = _make_runner(MockClient(script), max_iterations=3)
        with pytest.raises(MaxIterationsError) as raised:
            await wf_runner.run(_make_workflow(), "go", prompt_vars={"role": "agent"})
        assert raised.value.iterations == 3
        assert "fetch" in raised.value.completed_steps
class TestContextManagement:
    """Checks on how the runner uses the context manager and grows the history."""

    @pytest.mark.asyncio
    async def test_compaction_called_each_iteration(self):
        """``maybe_compact`` is called once per iteration with increasing ``step_index``."""
        script = [
            ToolCall(tool="fetch", args={}),
            ToolCall(tool="submit", args={}),
        ]
        context = ContextManager(strategy=NoCompact(), budget_tokens=100_000)
        real_compact = context.maybe_compact
        recorded = []

        def recording_compact(messages, step_index=0, step_hint=""):
            # Record the step arguments, then defer to the real method.
            recorded.append({"step_index": step_index, "step_hint": step_hint})
            return real_compact(messages, step_index=step_index, step_hint=step_hint)

        context.maybe_compact = recording_compact
        wf_runner = WorkflowRunner(client=MockClient(script), context_manager=context)
        await wf_runner.run(_make_workflow(), "go", prompt_vars={"role": "agent"})

        assert len(recorded) == 2
        assert recorded[0]["step_index"] == 0
        assert recorded[1]["step_index"] == 1

    @pytest.mark.asyncio
    async def test_messages_grow_correctly(self):
        """The second request holds system, user, assistant and tool messages, in that order."""
        llm = MockClient([
            ToolCall(tool="fetch", args={}),
            ToolCall(tool="submit", args={}),
        ])
        wf_runner = _make_runner(llm)
        await wf_runner.run(_make_workflow(), "go", prompt_vars={"role": "agent"})

        messages_sent, _tools = llm.send_calls[1]
        assert len(messages_sent) == 4
        expected_roles = ("system", "user", "assistant", "tool")
        for position, role in enumerate(expected_roles):
            assert messages_sent[position]["role"] == role
class TestStreaming:
    """Runs with ``stream=True``, which go through the client's ``send_stream``."""

    @pytest.mark.asyncio
    async def test_stream_mode_uses_send_stream(self):
        """In stream mode only ``send_stream`` is called, once per iteration."""
        llm = MockClient([
            ToolCall(tool="fetch", args={}),
            ToolCall(tool="submit", args={}),
        ])
        wf_runner = _make_runner(llm, stream=True)
        await wf_runner.run(_make_workflow(), "go", prompt_vars={"role": "agent"})
        assert len(llm.send_stream_calls) == 2
        assert len(llm.send_calls) == 0

    @pytest.mark.asyncio
    async def test_on_chunk_callback_receives_chunks(self):
        """The ``on_chunk`` callback receives every chunk the client yields."""
        seen: list[StreamChunk] = []

        async def collect(chunk: StreamChunk) -> None:
            seen.append(chunk)

        llm = MockClient([
            ToolCall(tool="fetch", args={}),
            ToolCall(tool="submit", args={}),
        ])
        wf_runner = _make_runner(llm, stream=True, on_chunk=collect)
        await wf_runner.run(_make_workflow(), "go", prompt_vars={"role": "agent"})

        # Two calls, each yielding one TEXT_DELTA and one FINAL chunk.
        assert len(seen) == 4
        assert seen[0].type == ChunkType.TEXT_DELTA
        assert seen[1].type == ChunkType.FINAL

    @pytest.mark.asyncio
    async def test_stream_extracts_final_response(self):
        """The response carried by the FINAL chunk drives the run to completion."""
        script = [
            ToolCall(tool="fetch", args={"x": 1}),
            ToolCall(tool="submit", args={}),
        ]
        wf_runner = _make_runner(MockClient(script), stream=True)
        outcome = await wf_runner.run(_make_workflow(), "go", prompt_vars={"role": "agent"})
        assert outcome == "submit_result"

    @pytest.mark.asyncio
    async def test_stream_without_final_chunk_raises_stream_error(self):
        """A stream that ends without a FINAL chunk raises ``StreamError``."""

        class NoFinalClient:
            """Client whose stream yields a single TEXT_DELTA and never a FINAL chunk."""

            async def send(
                self,
                messages,
                tools=None,
                sampling=None,
                passthrough=None,
                inbound_anthropic_body=None,
            ):
                return [ToolCall(tool="fetch", args={})]

            async def send_stream(
                self,
                messages,
                tools=None,
                sampling=None,
                passthrough=None,
                inbound_anthropic_body=None,
            ):
                yield StreamChunk(type=ChunkType.TEXT_DELTA, content="partial")

            async def get_context_length(self):
                return None

        # Build the runner with a placeholder client, then swap in the broken one.
        wf_runner = _make_runner(MockClient([]), stream=True)
        wf_runner.client = NoFinalClient()
        with pytest.raises(StreamError, match="FINAL chunk"):
            await wf_runner.run(_make_workflow(), "go", prompt_vars={"role": "agent"})
class TestAsyncToolSupport:
    """Tool callables may be coroutine functions or plain functions."""

    @pytest.mark.asyncio
    async def test_async_tool_callable(self):
        """A workflow whose required step is an ``async def`` tool completes."""

        async def async_fetch(**kwargs):
            return "async_result"

        workflow = _make_workflow(
            tools={
                "fetch": _make_tool("fetch", fn=async_fetch),
                "submit": _make_tool("submit"),
            },
            required_steps=["fetch"],
        )
        script = [
            ToolCall(tool="fetch", args={}),
            ToolCall(tool="submit", args={}),
        ]
        wf_runner = _make_runner(MockClient(script))
        outcome = await wf_runner.run(workflow, "go", prompt_vars={"role": "agent"})
        assert outcome == "submit_result"

    @pytest.mark.asyncio
    async def test_sync_tool_callable(self):
        """A workflow whose required step is a plain ``def`` tool completes."""

        def sync_fetch(**kwargs):
            return "sync_result"

        workflow = _make_workflow(
            tools={
                "fetch": _make_tool("fetch", fn=sync_fetch),
                "submit": _make_tool("submit"),
            },
            required_steps=["fetch"],
        )
        script = [
            ToolCall(tool="fetch", args={}),
            ToolCall(tool="submit", args={}),
        ]
        wf_runner = _make_runner(MockClient(script))
        outcome = await wf_runner.run(workflow, "go", prompt_vars={"role": "agent"})
        assert outcome == "submit_result"
class TestMessageStructure:
    """Shape and ordering of the messages the runner sends to the client."""

    @pytest.mark.asyncio
    async def test_initial_messages_correct(self):
        """The first request holds the rendered system prompt and the user input."""
        llm = MockClient([
            ToolCall(tool="fetch", args={}),
            ToolCall(tool="submit", args={}),
        ])
        wf_runner = _make_runner(llm)
        await wf_runner.run(_make_workflow(), "do something", prompt_vars={"role": "tester"})

        messages_sent, _tools = llm.send_calls[0]
        assert len(messages_sent) == 2
        system_msg, user_msg = messages_sent
        assert system_msg["role"] == "system"
        assert system_msg["content"] == "You are a tester."
        assert user_msg["role"] == "user"
        assert user_msg["content"] == "do something"

    @pytest.mark.asyncio
    async def test_tool_call_message_format(self):
        """A tool call is echoed as an assistant message followed by the tool's result."""
        llm = MockClient([
            ToolCall(tool="fetch", args={"key": "val"}),
            ToolCall(tool="submit", args={}),
        ])
        wf_runner = _make_runner(llm)
        await wf_runner.run(_make_workflow(), "go", prompt_vars={"role": "agent"})

        messages_sent = llm.send_calls[1][0]
        assistant_msg = messages_sent[2]
        assert assistant_msg["role"] == "assistant"
        called = assistant_msg["tool_calls"][0]["function"]
        assert called["name"] == "fetch"
        assert called["arguments"] == {"key": "val"}
        assert messages_sent[3]["content"] == "fetch_result"

    @pytest.mark.asyncio
    async def test_text_response_emits_assistant_before_retry_nudge(self):
        """A text response is kept as an assistant message ahead of the user-role nudge."""
        llm = MockClient([
            TextResponse(content="I'm not sure what to do"),
            ToolCall(tool="fetch", args={}),
            ToolCall(tool="submit", args={}),
        ])
        wf_runner = _make_runner(llm)
        await wf_runner.run(_make_workflow(), "go", prompt_vars={"role": "agent"})

        messages_sent = llm.send_calls[1][0]
        assert len(messages_sent) == 4
        echoed, nudge = messages_sent[2], messages_sent[3]
        assert echoed["role"] == "assistant"
        assert echoed["content"] == "I'm not sure what to do"
        assert nudge["role"] == "user"
        assert "not a valid tool call" in nudge["content"]

    @pytest.mark.asyncio
    async def test_unknown_tool_emits_assistant_before_nudge(self):
        """An unknown tool call is echoed, then answered with a tool-role error message."""
        llm = MockClient([
            ToolCall(tool="nonexistent", args={"x": 1}),
            ToolCall(tool="fetch", args={}),
            ToolCall(tool="submit", args={}),
        ])
        wf_runner = _make_runner(llm)
        await wf_runner.run(_make_workflow(), "go", prompt_vars={"role": "agent"})

        messages_sent = llm.send_calls[1][0]
        assert len(messages_sent) == 4
        echoed, nudge = messages_sent[2], messages_sent[3]
        assert echoed["role"] == "assistant"
        assert echoed["tool_calls"][0]["function"]["name"] == "nonexistent"
        assert nudge["role"] == "tool"
        assert "[UnknownTool]" in nudge["content"]
        assert "does not exist" in nudge["content"]

    @pytest.mark.asyncio
    async def test_step_nudge_emits_assistant_before_nudge(self):
        """A premature terminal call is echoed, then answered with a tool-role step error."""
        llm = MockClient([
            ToolCall(tool="submit", args={}),
            ToolCall(tool="fetch", args={}),
            ToolCall(tool="submit", args={}),
        ])
        wf_runner = _make_runner(llm)
        await wf_runner.run(_make_workflow(), "go", prompt_vars={"role": "agent"})

        messages_sent = llm.send_calls[1][0]
        assert len(messages_sent) == 4
        echoed, nudge = messages_sent[2], messages_sent[3]
        assert echoed["role"] == "assistant"
        assert echoed["tool_calls"][0]["function"]["name"] == "submit"
        assert nudge["role"] == "tool"
        assert "[StepEnforcementError]" in nudge["content"]
        assert "cannot call submit yet" in nudge["content"].lower()

    @pytest.mark.asyncio
    async def test_retry_nudge_message_metadata(self):
        """The nudge after a text response is tagged ``MessageType.RETRY_NUDGE``."""
        script = [
            TextResponse(content="bad output"),
            ToolCall(tool="fetch", args={}),
            ToolCall(tool="submit", args={}),
        ]
        context = ContextManager(strategy=NoCompact(), budget_tokens=100_000)
        real_compact = context.maybe_compact
        snapshots: list[list[Message]] = []

        def recording_compact(messages, step_index=0, step_hint=""):
            # Snapshot the history on every call, then defer to the real method.
            snapshots.append(list(messages))
            return real_compact(messages, step_index=step_index, step_hint=step_hint)

        context.maybe_compact = recording_compact
        wf_runner = WorkflowRunner(client=MockClient(script), context_manager=context)
        await wf_runner.run(_make_workflow(), "go", prompt_vars={"role": "agent"})

        retry_nudges = [
            message
            for message in snapshots[1]
            if message.metadata.type == MessageType.RETRY_NUDGE
        ]
        assert len(retry_nudges) == 1
        assert "not a valid tool call" in retry_nudges[0].content

    @pytest.mark.asyncio
    async def test_step_nudge_message_metadata(self):
        """The nudge after a premature terminal call is tagged ``MessageType.STEP_NUDGE``."""
        script = [
            ToolCall(tool="submit", args={}),
            ToolCall(tool="fetch", args={}),
            ToolCall(tool="submit", args={}),
        ]
        context = ContextManager(strategy=NoCompact(), budget_tokens=100_000)
        real_compact = context.maybe_compact
        snapshots: list[list[Message]] = []

        def recording_compact(messages, step_index=0, step_hint=""):
            # Snapshot the history on every call, then defer to the real method.
            snapshots.append(list(messages))
            return real_compact(messages, step_index=step_index, step_hint=step_hint)

        context.maybe_compact = recording_compact
        wf_runner = WorkflowRunner(client=MockClient(script), context_manager=context)
        await wf_runner.run(_make_workflow(), "go", prompt_vars={"role": "agent"})

        step_nudges = [
            message
            for message in snapshots[1]
            if message.metadata.type == MessageType.STEP_NUDGE
        ]
        assert len(step_nudges) == 1
        assert "cannot call submit yet" in step_nudges[0].content.lower()
class TestRescueToolCalls:
    """Text responses that encode a tool call are treated as that tool call."""

    @pytest.mark.asyncio
    async def test_rescue_json_from_text_response(self):
        """A JSON ``{"tool": ..., "args": ...}`` text response costs no extra client call."""
        llm = MockClient([
            TextResponse(content='{"tool": "fetch", "args": {"key": "val"}}'),
            ToolCall(tool="submit", args={}),
        ])
        wf_runner = _make_runner(llm)
        outcome = await wf_runner.run(_make_workflow(), "go", prompt_vars={"role": "agent"})
        assert outcome == "submit_result"
        assert len(llm.send_calls) == 2

    @pytest.mark.asyncio
    async def test_rescue_rehearsal_syntax(self):
        """A ``name[ARGS]{...}`` text response costs no extra client call."""
        llm = MockClient([
            TextResponse(content='fetch[ARGS]{"key": "val"}'),
            ToolCall(tool="submit", args={}),
        ])
        wf_runner = _make_runner(llm)
        outcome = await wf_runner.run(_make_workflow(), "go", prompt_vars={"role": "agent"})
        assert outcome == "submit_result"
        assert len(llm.send_calls) == 2

    @pytest.mark.asyncio
    async def test_rescue_resets_retry_counter(self):
        """Two garbage responses before each rescued call fit within three retries."""
        script = [
            TextResponse(content="plain garbage"),
            TextResponse(content="more garbage"),
            TextResponse(content='{"tool": "fetch", "args": {}}'),
            TextResponse(content="garbage again"),
            TextResponse(content="still garbage"),
            TextResponse(content='{"tool": "fetch", "args": {}}'),
            ToolCall(tool="submit", args={}),
        ]
        wf_runner = _make_runner(
            MockClient(script), max_retries_per_step=3, max_iterations=10
        )
        workflow = _make_workflow(required_steps=[])
        outcome = await wf_runner.run(workflow, "go", prompt_vars={"role": "agent"})
        assert outcome == "submit_result"

    @pytest.mark.asyncio
    async def test_rescue_unknown_tool_falls_through(self):
        """JSON text naming an unknown tool does not stop the run from completing."""
        script = [
            TextResponse(content='{"tool": "nonexistent", "args": {}}'),
            ToolCall(tool="fetch", args={}),
            ToolCall(tool="submit", args={}),
        ]
        wf_runner = _make_runner(MockClient(script))
        outcome = await wf_runner.run(_make_workflow(), "go", prompt_vars={"role": "agent"})
        assert outcome == "submit_result"

    @pytest.mark.asyncio
    async def test_rescue_still_executes_tool(self):
        """Arguments embedded in the rescued JSON reach the tool callable."""
        seen_kwargs = {}

        def capture_tool(**kwargs):
            seen_kwargs.update(kwargs)
            return "captured"

        workflow = _make_workflow(
            tools={
                "fetch": _make_tool("fetch", fn=capture_tool),
                "submit": _make_tool("submit"),
            },
            required_steps=["fetch"],
        )
        script = [
            TextResponse(content='{"tool": "fetch", "args": {"count": 42}}'),
            ToolCall(tool="submit", args={}),
        ]
        wf_runner = _make_runner(MockClient(script))
        await wf_runner.run(workflow, "go", prompt_vars={"role": "agent"})
        assert seen_kwargs == {"count": 42}

    @pytest.mark.asyncio
    async def test_rescue_satisfies_required_step(self):
        """A rescued ``fetch`` counts as the required step, so ``submit`` then succeeds."""
        script = [
            TextResponse(content='{"tool": "fetch", "args": {}}'),
            ToolCall(tool="submit", args={}),
        ]
        wf_runner = _make_runner(MockClient(script))
        outcome = await wf_runner.run(_make_workflow(), "go", prompt_vars={"role": "agent"})
        assert outcome == "submit_result"

    @pytest.mark.asyncio
    async def test_non_rescuable_text_still_nudges(self):
        """Plain prose is not rescued, so the run needs a third client call."""
        llm = MockClient([
            TextResponse(content="Let me think about this..."),
            ToolCall(tool="fetch", args={}),
            ToolCall(tool="submit", args={}),
        ])
        wf_runner = _make_runner(llm)
        outcome = await wf_runner.run(_make_workflow(), "go", prompt_vars={"role": "agent"})
        assert outcome == "submit_result"
        assert len(llm.send_calls) == 3
class TestReasoningCapture:
@pytest.mark.asyncio
async def test_reasoning_message_appended_before_tool_call(self):
    """Reasoning on a tool call becomes an assistant REASONING message placed before it."""
    script = [
        ToolCall(tool="fetch", args={}, reasoning="I need to fetch the data first."),
        ToolCall(tool="submit", args={}),
    ]
    context = ContextManager(strategy=NoCompact(), budget_tokens=100_000)
    real_compact = context.maybe_compact
    snapshots: list[list[Message]] = []

    def recording_compact(messages, step_index=0, step_hint=""):
        # Snapshot the history on every call, then defer to the real method.
        snapshots.append(list(messages))
        return real_compact(messages, step_index=step_index, step_hint=step_hint)

    context.maybe_compact = recording_compact
    wf_runner = WorkflowRunner(client=MockClient(script), context_manager=context)
    outcome = await wf_runner.run(_make_workflow(), "go", prompt_vars={"role": "agent"})
    assert outcome == "submit_result"

    history = snapshots[1]
    reasoning_only = [
        message for message in history if message.metadata.type == MessageType.REASONING
    ]
    assert len(reasoning_only) == 1
    reasoning_msg = reasoning_only[0]
    assert reasoning_msg.content == "I need to fetch the data first."
    assert reasoning_msg.role == MessageRole.ASSISTANT

    # The reasoning must appear earlier in the history than the tool call.
    kinds = [message.metadata.type for message in history]
    assert kinds.index(MessageType.REASONING) < kinds.index(MessageType.TOOL_CALL)
@pytest.mark.asyncio
async def test_no_reasoning_message_when_reasoning_is_none(self):
    """Tool calls without reasoning never add a REASONING message to the history."""
    script = [
        ToolCall(tool="fetch", args={}),
        ToolCall(tool="submit", args={}),
    ]
    context = ContextManager(strategy=NoCompact(), budget_tokens=100_000)
    real_compact = context.maybe_compact
    snapshots: list[list[Message]] = []

    def recording_compact(messages, step_index=0, step_hint=""):
        # Snapshot the history on every call, then defer to the real method.
        snapshots.append(list(messages))
        return real_compact(messages, step_index=step_index, step_hint=step_hint)

    context.maybe_compact = recording_compact
    wf_runner = WorkflowRunner(client=MockClient(script), context_manager=context)
    await wf_runner.run(_make_workflow(), "go", prompt_vars={"role": "agent"})

    for history in snapshots:
        reasoning_only = [
            message
            for message in history
            if message.metadata.type == MessageType.REASONING
        ]
        assert len(reasoning_only) == 0
@pytest.mark.asyncio
async def test_reasoning_preserved_on_tool_error(self):
call_count = 0
def flaky_fetch(**kwargs):
nonlocal call_count
call_count += 1
if call_count == 1:
raise ValueError("bad input")
return "ok"
tools = {
"fetch": _make_tool("fetch", fn=flaky_fetch),
"submit": _make_tool("submit"),
}
wf = _make_workflow(tools=tools, required_steps=["fetch"])
client = MockClient([
ToolCall(tool="fetch", args={}, reasoning="Let me try fetching."),
ToolCall(tool="fetch", args={}),
ToolCall(tool="submit", args={}),
])
ctx = ContextManager(strategy=NoCompact(), budget_tokens=100_000)
original_compact = ctx.maybe_compact
captured_messages: list[list[Message]] = []
def spy_compact(messages, step_index=0, step_hint=""):
captured_messages.append(list(messages))
return original_compact(messages, step_index=step_index, step_hint=step_hint)
ctx.maybe_compact = spy_compact
runner = WorkflowRunner(client=client, context_manager=ctx)
result = await runner.run(wf, "go", prompt_vars={"role": "agent"})
assert result == "submit_result"
second_iteration_msgs = captured_messages[1]
reasoning_msgs = [m for m in second_iteration_msgs if m.metadata.type == MessageType.REASONING]
assert len(reasoning_msgs) == 1
assert reasoning_msgs[0].content == "Let me try fetching."
types = [m.metadata.type for m in second_iteration_msgs]
reasoning_idx = types.index(MessageType.REASONING)
assert types[reasoning_idx + 1] == MessageType.TOOL_CALL
@pytest.mark.asyncio
async def test_reasoning_preserved_on_terminal_tool_error(self):
call_count = 0
def flaky_submit(**kwargs):
nonlocal call_count
call_count += 1
if call_count == 1:
raise TypeError("bad args")
return "submitted"
tools = {
"fetch": _make_tool("fetch"),
"submit": _make_tool("submit", fn=flaky_submit),
}
wf = _make_workflow(tools=tools, required_steps=["fetch"])
client = MockClient([
ToolCall(tool="fetch", args={}),
ToolCall(tool="submit", args={}, reasoning="Time to submit the result."),
ToolCall(tool="submit", args={}),
])
ctx = ContextManager(strategy=NoCompact(), budget_tokens=100_000)
original_compact = ctx.maybe_compact
captured_messages: list[list[Message]] = []
def spy_compact(messages, step_index=0, step_hint=""):
captured_messages.append(list(messages))
return original_compact(messages, step_index=step_index, step_hint=step_hint)
ctx.maybe_compact = spy_compact
runner = WorkflowRunner(client=client, context_manager=ctx)
result = await runner.run(wf, "go", prompt_vars={"role": "agent"})
assert result == "submitted"
third_iteration_msgs = captured_messages[2]
reasoning_msgs = [m for m in third_iteration_msgs if m.metadata.type == MessageType.REASONING]
assert len(reasoning_msgs) == 1
assert reasoning_msgs[0].content == "Time to submit the result."
types = [m.metadata.type for m in third_iteration_msgs]
reasoning_idx = types.index(MessageType.REASONING)
assert types[reasoning_idx + 1] == MessageType.TOOL_CALL
@pytest.mark.asyncio
async def test_reasoning_folded_into_tool_call_on_wire(self):
client = MockClient([
ToolCall(tool="fetch", args={}, reasoning="Thinking about this..."),
ToolCall(tool="submit", args={}),
])
runner = _make_runner(client)
await runner.run(_make_workflow(), "go", prompt_vars={"role": "agent"})
second_call_msgs = client.send_calls[1][0]
assert len(second_call_msgs) == 4
assert second_call_msgs[2]["role"] == "assistant"
assert second_call_msgs[2]["content"] == "Thinking about this..."
assert "tool_calls" in second_call_msgs[2]
@pytest.mark.asyncio
async def test_text_response_not_folded_into_tool_call(self):
client = MockClient([
TextResponse(content="Let me think about this..."),
ToolCall(tool="fetch", args={}, reasoning="Now I know what to do"),
ToolCall(tool="submit", args={}),
])
runner = _make_runner(client)
await runner.run(_make_workflow(), "go", prompt_vars={"role": "agent"})
third_call_msgs = client.send_calls[2][0]
tc_msgs = [m for m in third_call_msgs if "tool_calls" in m]
assert len(tc_msgs) == 1
assert tc_msgs[0]["content"] == "Now I know what to do"
class TestOnMessageCallback:
    """Tests for the ``on_message`` callback accepted by ``WorkflowRunner``."""

    # Message types expected, in order, for a fetch-then-submit run.
    _FETCH_SUBMIT_TYPES = [
        MessageType.SYSTEM_PROMPT,
        MessageType.USER_INPUT,
        MessageType.TOOL_CALL,
        MessageType.TOOL_RESULT,
        MessageType.TOOL_CALL,
        MessageType.TOOL_RESULT,
    ]

    @staticmethod
    def _recording_runner(client, **runner_kwargs):
        """Build a runner whose ``on_message`` appends to a fresh list.

        Args:
            client: The LLM client handed to ``WorkflowRunner``.
            **runner_kwargs: Extra keyword arguments forwarded to
                ``WorkflowRunner`` (e.g. ``stream`` or ``max_tool_errors``).

        Returns:
            A ``(runner, collected)`` pair; ``collected`` receives every
            message passed to the callback.
        """
        collected: list[Message] = []
        ctx = ContextManager(strategy=NoCompact(), budget_tokens=100_000)
        runner = WorkflowRunner(
            client=client, context_manager=ctx,
            on_message=collected.append, **runner_kwargs,
        )
        return runner, collected

    @pytest.mark.asyncio
    async def test_on_message_receives_all_messages(self):
        """The callback sees prompt, input and every tool call/result in order."""
        client = MockClient([
            ToolCall(tool="fetch", args={}),
            ToolCall(tool="submit", args={}),
        ])
        runner, collected = self._recording_runner(client)

        await runner.run(_make_workflow(), "go", prompt_vars={"role": "agent"})

        types = [m.metadata.type for m in collected]
        assert types == self._FETCH_SUBMIT_TYPES

    @pytest.mark.asyncio
    async def test_on_message_none_is_safe(self):
        """A runner built by ``_make_runner`` without a callback still completes."""
        client = MockClient([
            ToolCall(tool="fetch", args={}),
            ToolCall(tool="submit", args={}),
        ])
        runner = _make_runner(client)
        result = await runner.run(_make_workflow(), "go", prompt_vars={"role": "agent"})
        assert result == "submit_result"

    @pytest.mark.asyncio
    async def test_on_message_captures_retry_nudge(self):
        """A plain-text reply leads to a ``RETRY_NUDGE`` reaching the callback."""
        client = MockClient([
            TextResponse(content="bad"),
            ToolCall(tool="fetch", args={}),
            ToolCall(tool="submit", args={}),
        ])
        runner, collected = self._recording_runner(client)

        await runner.run(_make_workflow(), "go", prompt_vars={"role": "agent"})

        types = [m.metadata.type for m in collected]
        assert MessageType.RETRY_NUDGE in types

    @pytest.mark.asyncio
    async def test_on_message_captures_step_nudge(self):
        """Calling ``submit`` before ``fetch`` leads to a ``STEP_NUDGE``."""
        client = MockClient([
            ToolCall(tool="submit", args={}),
            ToolCall(tool="fetch", args={}),
            ToolCall(tool="submit", args={}),
        ])
        runner, collected = self._recording_runner(client)

        await runner.run(_make_workflow(), "go", prompt_vars={"role": "agent"})

        types = [m.metadata.type for m in collected]
        assert MessageType.STEP_NUDGE in types

    @pytest.mark.asyncio
    async def test_on_message_captures_tool_errors(self):
        """Each failing tool call yields a ``[ToolError]`` message before the raise."""

        def bad_fetch(**kwargs):
            raise ValueError("broken")

        tools = {
            "fetch": _make_tool("fetch", fn=bad_fetch),
            "submit": _make_tool("submit"),
        }
        wf = _make_workflow(tools=tools, required_steps=["fetch"])
        client = MockClient([
            ToolCall(tool="fetch", args={}),
            ToolCall(tool="fetch", args={}),
            ToolCall(tool="fetch", args={}),
        ])
        runner, collected = self._recording_runner(client, max_tool_errors=2)

        with pytest.raises(ToolExecutionError):
            await runner.run(wf, "go", prompt_vars={"role": "agent"})

        error_msgs = [m for m in collected if "[ToolError]" in m.content]
        assert len(error_msgs) == 3

    @pytest.mark.asyncio
    async def test_on_message_captures_reasoning(self):
        """Reasoning attached to a tool call reaches the callback verbatim."""
        client = MockClient([
            ToolCall(tool="fetch", args={}, reasoning="Thinking..."),
            ToolCall(tool="submit", args={}),
        ])
        runner, collected = self._recording_runner(client)

        await runner.run(_make_workflow(), "go", prompt_vars={"role": "agent"})

        types = [m.metadata.type for m in collected]
        assert MessageType.REASONING in types
        reasoning = [m for m in collected if m.metadata.type == MessageType.REASONING]
        assert reasoning[0].content == "Thinking..."

    @pytest.mark.asyncio
    async def test_on_message_captures_rescued_tool_call(self):
        """A JSON tool call sent as text is reported as a tool call, not text."""
        client = MockClient([
            TextResponse(content='{"tool": "fetch", "args": {"key": "val"}}'),
            ToolCall(tool="submit", args={}),
        ])
        runner, collected = self._recording_runner(client)

        await runner.run(_make_workflow(), "go", prompt_vars={"role": "agent"})

        types = [m.metadata.type for m in collected]
        assert MessageType.RETRY_NUDGE not in types
        assert MessageType.TEXT_RESPONSE not in types
        assert MessageType.TOOL_CALL in types
        assert MessageType.TOOL_RESULT in types

    @pytest.mark.asyncio
    async def test_on_message_with_streaming(self):
        """With ``stream=True`` the callback sees the same ordered sequence."""
        client = MockClient([
            ToolCall(tool="fetch", args={}),
            ToolCall(tool="submit", args={}),
        ])
        runner, collected = self._recording_runner(client, stream=True)

        await runner.run(_make_workflow(), "go", prompt_vars={"role": "agent"})

        types = [m.metadata.type for m in collected]
        assert types == self._FETCH_SUBMIT_TYPES
class TestInitialMessages:
    """Tests for the ``initial_messages`` argument of ``WorkflowRunner.run``."""

    @pytest.mark.asyncio
    async def test_initial_messages_skips_system_and_user_init(self):
        """With a seeded history, the callback sees no system/user messages.

        Only the tool call and tool result produced during the run are
        expected to reach ``on_message``.
        """
        seen: list[Message] = []
        llm = MockClient([ToolCall(tool="submit", args={})])
        runner = WorkflowRunner(
            client=llm,
            context_manager=ContextManager(strategy=NoCompact(), budget_tokens=100_000),
            on_message=seen.append,
        )
        history = [
            Message(MessageRole.SYSTEM, "You are a tester.", MessageMeta(MessageType.SYSTEM_PROMPT)),
            Message(MessageRole.USER, "do something", MessageMeta(MessageType.USER_INPUT)),
        ]
        workflow = _make_workflow(required_steps=[])

        await runner.run(workflow, "do something", initial_messages=history)

        emitted = [msg.metadata.type for msg in seen]
        for absent in (MessageType.SYSTEM_PROMPT, MessageType.USER_INPUT):
            assert absent not in emitted
        for present in (MessageType.TOOL_CALL, MessageType.TOOL_RESULT):
            assert present in emitted

    @pytest.mark.asyncio
    async def test_initial_messages_included_in_api_call(self):
        """The seeded history leads the payload of the first ``send`` call."""
        llm = MockClient([ToolCall(tool="submit", args={})])
        runner = WorkflowRunner(
            client=llm,
            context_manager=ContextManager(strategy=NoCompact(), budget_tokens=100_000),
        )
        history = [
            Message(MessageRole.SYSTEM, "You are a tester.", MessageMeta(MessageType.SYSTEM_PROMPT)),
            Message(MessageRole.USER, "first question", MessageMeta(MessageType.USER_INPUT)),
            Message(
                MessageRole.ASSISTANT, "", MessageMeta(MessageType.TOOL_CALL),
                tool_calls=[ToolCallInfo(name="fetch", args={}, call_id="c0")],
            ),
            Message(
                MessageRole.TOOL, "fetch_result", MessageMeta(MessageType.TOOL_RESULT),
                tool_name="fetch", tool_call_id="c0",
            ),
            Message(MessageRole.USER, "follow-up", MessageMeta(MessageType.USER_INPUT)),
        ]
        workflow = _make_workflow(required_steps=[])

        await runner.run(workflow, "follow-up", initial_messages=history)

        payload, _ = llm.send_calls[0]
        system_entry = payload[0]
        first_user_entry = payload[1]
        assert system_entry["role"] == "system"
        assert system_entry["content"] == "You are a tester."
        assert first_user_entry["role"] == "user"
        assert first_user_entry["content"] == "first question"

    @pytest.mark.asyncio
    async def test_none_initial_messages_is_default(self):
        """Passing ``initial_messages=None`` still emits system prompt then user input."""
        seen: list[Message] = []
        llm = MockClient([
            ToolCall(tool="fetch", args={}),
            ToolCall(tool="submit", args={}),
        ])
        runner = WorkflowRunner(
            client=llm,
            context_manager=ContextManager(strategy=NoCompact(), budget_tokens=100_000),
            on_message=seen.append,
        )
        workflow = _make_workflow()

        await runner.run(workflow, "go", prompt_vars={"role": "agent"}, initial_messages=None)

        emitted = [msg.metadata.type for msg in seen]
        assert emitted[0] == MessageType.SYSTEM_PROMPT
        assert emitted[1] == MessageType.USER_INPUT
class TestToolResolutionError:
    """Tests for tools that raise ``ToolResolutionError`` while running."""

    @staticmethod
    def _workflow_with_fetch(fetch_fn):
        """Return a workflow whose required ``fetch`` tool runs ``fetch_fn``."""
        toolset = {
            "fetch": _make_tool("fetch", fn=fetch_fn),
            "submit": _make_tool("submit"),
        }
        return _make_workflow(tools=toolset, required_steps=["fetch"])

    @pytest.mark.asyncio
    async def test_resolution_error_feeds_back_and_recovers(self):
        """A miss on the first lookup is followed by a successful retry."""
        attempts = 0

        def lookup(**kwargs):
            nonlocal attempts
            attempts += 1
            if attempts > 1:
                return "The capital of France is Paris."
            raise ToolResolutionError("No entry found for 'capital of France'. Try another key.")

        workflow = self._workflow_with_fetch(lookup)
        llm = MockClient([
            ToolCall(tool="fetch", args={"query": "capital of France"}),
            ToolCall(tool="fetch", args={"query": "france"}),
            ToolCall(tool="submit", args={}),
        ])

        outcome = await _make_runner(llm).run(workflow, "go", prompt_vars={"role": "agent"})

        assert outcome == "submit_result"
        assert attempts == 2

    @pytest.mark.asyncio
    async def test_resolution_error_message_fed_back_to_model(self):
        """The error text is the last message of the following ``send`` call."""

        def always_miss(**kwargs):
            raise ToolResolutionError("No entry found for 'bad key'. Try another key.")

        workflow = self._workflow_with_fetch(always_miss)
        llm = MockClient([ToolCall(tool="fetch", args={}) for _ in range(3)])
        runner = _make_runner(llm, max_iterations=3)

        with pytest.raises(MaxIterationsError):
            await runner.run(workflow, "go", prompt_vars={"role": "agent"})

        second_payload, _ = llm.send_calls[1]
        feedback = second_payload[-1]["content"]
        assert "[ToolResolutionError]" in feedback
        assert "No entry found" in feedback

    @pytest.mark.asyncio
    async def test_resolution_error_does_not_increment_consecutive_tool_errors(self):
        """Five misses with ``max_tool_errors=1`` end in ``MaxIterationsError``."""

        def always_miss(**kwargs):
            raise ToolResolutionError("miss")

        workflow = self._workflow_with_fetch(always_miss)
        scripted = [ToolCall(tool="fetch", args={}) for _ in range(5)]
        runner = _make_runner(MockClient(scripted), max_iterations=5, max_tool_errors=1)

        with pytest.raises(MaxIterationsError):
            await runner.run(workflow, "go", prompt_vars={"role": "agent"})

    @pytest.mark.asyncio
    async def test_resolution_error_does_not_record_step(self):
        """A ``submit`` issued right after a missed ``fetch`` does not end the run.

        The script is fetch (miss), submit, fetch (hit), submit; the test
        expects the run to finish with the ``submit`` result and ``fetch`` to
        have been invoked exactly twice.
        """
        attempts = 0

        def miss_then_hit(**kwargs):
            nonlocal attempts
            attempts += 1
            if attempts > 1:
                return "hit"
            raise ToolResolutionError("miss")

        workflow = self._workflow_with_fetch(miss_then_hit)
        llm = MockClient([
            ToolCall(tool="fetch", args={}),
            ToolCall(tool="submit", args={}),
            ToolCall(tool="fetch", args={}),
            ToolCall(tool="submit", args={}),
        ])

        outcome = await _make_runner(llm).run(workflow, "go", prompt_vars={"role": "agent"})

        assert outcome == "submit_result"
        assert attempts == 2

    @pytest.mark.asyncio
    async def test_resolution_error_bounded_by_max_iterations(self):
        """Endless misses stop at ``max_iterations`` and report that count."""

        def always_miss(**kwargs):
            raise ToolResolutionError("miss")

        workflow = self._workflow_with_fetch(always_miss)
        scripted = [ToolCall(tool="fetch", args={}) for _ in range(4)]
        runner = _make_runner(MockClient(scripted), max_iterations=3)

        with pytest.raises(MaxIterationsError) as caught:
            await runner.run(workflow, "go", prompt_vars={"role": "agent"})

        assert caught.value.iterations == 3

    @pytest.mark.asyncio
    async def test_resolution_error_then_hard_error_counts_correctly(self):
        """One resolution miss plus one ``ValueError`` is tolerated at ``max_tool_errors=1``."""
        attempts = 0

        def mixed_errors(**kwargs):
            nonlocal attempts
            attempts += 1
            if attempts == 1:
                raise ToolResolutionError("soft miss")
            elif attempts == 2:
                raise ValueError("hard crash")
            return "ok"

        workflow = self._workflow_with_fetch(mixed_errors)
        llm = MockClient([
            ToolCall(tool="fetch", args={}),
            ToolCall(tool="fetch", args={}),
            ToolCall(tool="fetch", args={}),
            ToolCall(tool="submit", args={}),
        ])
        runner = _make_runner(llm, max_tool_errors=1)

        outcome = await runner.run(workflow, "go", prompt_vars={"role": "agent"})

        assert outcome == "submit_result"

    @pytest.mark.asyncio
    async def test_resolution_error_on_message_callback(self):
        """Each miss reaches ``on_message`` as a ``TOOL_RESULT`` message."""
        seen: list[Message] = []

        def miss(**kwargs):
            raise ToolResolutionError("No entry found for 'x'. Try another key.")

        workflow = self._workflow_with_fetch(miss)
        llm = MockClient([ToolCall(tool="fetch", args={}) for _ in range(2)])
        runner = WorkflowRunner(
            client=llm,
            context_manager=ContextManager(strategy=NoCompact(), budget_tokens=100_000),
            max_iterations=2,
            on_message=seen.append,
        )

        with pytest.raises(MaxIterationsError):
            await runner.run(workflow, "go", prompt_vars={"role": "agent"})

        flagged = [msg for msg in seen if "[ToolResolutionError]" in msg.content]
        assert len(flagged) == 2
        assert all(msg.metadata.type == MessageType.TOOL_RESULT for msg in flagged)

    @pytest.mark.asyncio
    async def test_resolution_error_is_not_forge_error(self):
        """``ToolResolutionError`` is an ``Exception`` but not a ``ForgeError``."""
        from forge.errors import ForgeError

        instance = ToolResolutionError("test")
        assert isinstance(instance, Exception)
        assert not isinstance(instance, ForgeError)
class TestPrerequisiteEnforcement:
    """Tests for tools that declare ``prerequisites`` on their ``ToolDef``."""

    @staticmethod
    def _file_workflow(prerequisites):
        """Return a read/edit/submit workflow; ``edit_file`` gets ``prerequisites``."""
        toolset = {
            "read_file": _make_tool("read_file"),
            "edit_file": ToolDef(
                spec=ToolSpec(name="edit_file", description="Edit", parameters=EmptyParams),
                callable=lambda **kwargs: "edited",
                prerequisites=prerequisites,
            ),
            "submit": _make_tool("submit"),
        }
        return toolset

    @pytest.mark.asyncio
    async def test_prereq_nudge_then_success(self):
        """An early ``edit_file`` does not stop the run from completing."""
        toolset = self._file_workflow([{"tool": "read_file", "match_arg": "path"}])
        llm = MockClient([
            ToolCall(tool="edit_file", args={"path": "foo.py"}),
            ToolCall(tool="read_file", args={"path": "foo.py"}),
            ToolCall(tool="edit_file", args={"path": "foo.py"}),
            ToolCall(tool="submit", args={}),
        ])
        workflow = _make_workflow(tools=toolset, required_steps=[], terminal_tool="submit")

        outcome = await _make_runner(llm).run(workflow, "fix foo.py", prompt_vars={"role": "dev"})

        assert outcome == "submit_result"

    @pytest.mark.asyncio
    async def test_prereq_nudge_emits_correct_message_type(self):
        """Exactly one ``PREREQUISITE_NUDGE`` naming ``read_file`` is emitted."""
        seen = []
        toolset = self._file_workflow(["read_file"])
        llm = MockClient([
            ToolCall(tool="edit_file", args={}),
            ToolCall(tool="read_file", args={}),
            ToolCall(tool="edit_file", args={}),
            ToolCall(tool="submit", args={}),
        ])
        workflow = _make_workflow(tools=toolset, required_steps=[], terminal_tool="submit")
        runner = _make_runner(llm)
        runner.on_message = seen.append

        await runner.run(workflow, "go", prompt_vars={"role": "dev"})

        nudges = [msg for msg in seen if msg.metadata.type == MessageType.PREREQUISITE_NUDGE]
        assert len(nudges) == 1
        assert "read_file" in nudges[0].content

    @pytest.mark.asyncio
    async def test_prereq_exhaustion_raises(self):
        """Repeated ``edit_file`` calls without ``read_file`` raise ``PrerequisiteError``."""
        toolset = self._file_workflow(["read_file"])
        llm = MockClient([ToolCall(tool="edit_file", args={}) for _ in range(4)])
        workflow = _make_workflow(tools=toolset, required_steps=[], terminal_tool="submit")
        runner = WorkflowRunner(
            client=llm,
            context_manager=ContextManager(strategy=NoCompact(), budget_tokens=100_000),
            max_iterations=10,
        )

        with pytest.raises(PrerequisiteError, match="read_file"):
            await runner.run(workflow, "go", prompt_vars={"role": "dev"})

    @pytest.mark.asyncio
    async def test_no_prereqs_no_interference(self):
        """The default workflow, built without this class's tools, runs to completion."""
        llm = MockClient([
            ToolCall(tool="fetch", args={}),
            ToolCall(tool="submit", args={}),
        ])
        workflow = _make_workflow()

        outcome = await _make_runner(llm).run(workflow, "go", prompt_vars={"role": "dev"})

        assert outcome == "submit_result"

    @pytest.mark.asyncio
    async def test_args_recorded_for_prereq_tracking(self):
        """Reading ``a.py`` does not satisfy the prerequisite for editing ``b.py``.

        The script reads ``a.py``, tries to edit ``b.py``, reads ``b.py``,
        then edits ``b.py``; exactly one ``PREREQUISITE_NUDGE`` is expected.
        """
        seen = []
        toolset = self._file_workflow([{"tool": "read_file", "match_arg": "path"}])
        llm = MockClient([
            ToolCall(tool="read_file", args={"path": "a.py"}),
            ToolCall(tool="edit_file", args={"path": "b.py"}),
            ToolCall(tool="read_file", args={"path": "b.py"}),
            ToolCall(tool="edit_file", args={"path": "b.py"}),
            ToolCall(tool="submit", args={}),
        ])
        workflow = _make_workflow(tools=toolset, required_steps=[], terminal_tool="submit")
        runner = _make_runner(llm)
        runner.on_message = seen.append

        outcome = await runner.run(workflow, "go", prompt_vars={"role": "dev"})

        assert outcome == "submit_result"
        nudges = [msg for msg in seen if msg.metadata.type == MessageType.PREREQUISITE_NUDGE]
        assert len(nudges) == 1
class TestMultipleTerminalTools:
    """Tests for workflows whose ``terminal_tool`` is a list of tool names."""

    # Terminal tool names shared by every workflow in this class.
    _TERMINALS = ("set_ac", "no_action")

    @staticmethod
    def _ac_tools():
        """Return gather/set_ac/no_action tools with distinct terminal results."""
        return {
            "gather": _make_tool("gather"),
            "set_ac": _make_tool("set_ac", fn=lambda **kw: "ac_on"),
            "no_action": _make_tool("no_action", fn=lambda **kw: "skipped"),
        }

    @pytest.mark.asyncio
    async def test_first_terminal_exits(self):
        """Calling the first listed terminal tool ends the run with its result."""
        toolset = self._ac_tools()
        llm = MockClient([
            ToolCall(tool="gather", args={}),
            ToolCall(tool="set_ac", args={}),
        ])
        workflow = _make_workflow(
            tools=toolset,
            required_steps=["gather"],
            terminal_tool=list(self._TERMINALS),
        )

        outcome = await _make_runner(llm).run(workflow, "manage ac", prompt_vars={"role": "agent"})

        assert outcome == "ac_on"

    @pytest.mark.asyncio
    async def test_second_terminal_exits(self):
        """Calling the second listed terminal tool ends the run with its result."""
        toolset = self._ac_tools()
        llm = MockClient([
            ToolCall(tool="gather", args={}),
            ToolCall(tool="no_action", args={}),
        ])
        workflow = _make_workflow(
            tools=toolset,
            required_steps=["gather"],
            terminal_tool=list(self._TERMINALS),
        )

        outcome = await _make_runner(llm).run(workflow, "manage ac", prompt_vars={"role": "agent"})

        assert outcome == "skipped"

    @pytest.mark.asyncio
    async def test_premature_terminal_blocked_for_all(self):
        """Each terminal tool called before ``gather`` produces a ``STEP_NUDGE``."""
        seen = []
        toolset = {
            "gather": _make_tool("gather"),
            "set_ac": _make_tool("set_ac"),
            "no_action": _make_tool("no_action"),
        }
        llm = MockClient([
            ToolCall(tool="set_ac", args={}),
            ToolCall(tool="no_action", args={}),
            ToolCall(tool="gather", args={}),
            ToolCall(tool="set_ac", args={}),
        ])
        workflow = _make_workflow(
            tools=toolset,
            required_steps=["gather"],
            terminal_tool=list(self._TERMINALS),
        )
        runner = _make_runner(llm)
        runner.on_message = seen.append

        await runner.run(workflow, "go", prompt_vars={"role": "agent"})

        nudges = [msg for msg in seen if msg.metadata.type == MessageType.STEP_NUDGE]
        assert len(nudges) == 2
class TestCancellation:
    """Tests for cancelling a run through the ``cancel_event`` argument."""

    @staticmethod
    def _workflow_cancelling_on_fetch(cancel: asyncio.Event):
        """Return a workflow whose ``fetch`` tool sets ``cancel`` when it runs.

        The tool still returns ``"fetch_result"``, so the fetch step itself
        succeeds before the event is observed.
        """

        def cancel_after_fetch(**kwargs):
            cancel.set()
            return "fetch_result"

        tools = {
            "fetch": _make_tool("fetch", fn=cancel_after_fetch),
            "submit": _make_tool("submit"),
        }
        return _make_workflow(tools=tools, required_steps=["fetch"])

    @pytest.mark.asyncio
    async def test_cancel_before_start(self):
        """An already-set event cancels at iteration 0 with no completed steps."""
        client = MockClient([
            ToolCall(tool="fetch", args={}),
            ToolCall(tool="submit", args={}),
        ])
        wf = _make_workflow()
        runner = _make_runner(client)
        cancel = asyncio.Event()
        cancel.set()

        with pytest.raises(WorkflowCancelledError) as exc_info:
            await runner.run(wf, "go", prompt_vars={"role": "dev"}, cancel_event=cancel)

        assert exc_info.value.iteration == 0
        assert exc_info.value.completed_steps == {}

    @pytest.mark.asyncio
    async def test_cancel_mid_workflow(self):
        """Setting the event inside ``fetch`` cancels at iteration 1 with ``fetch`` recorded."""
        cancel = asyncio.Event()
        wf = self._workflow_cancelling_on_fetch(cancel)
        client = MockClient([
            ToolCall(tool="fetch", args={}),
            ToolCall(tool="submit", args={}),
        ])
        runner = _make_runner(client)

        with pytest.raises(WorkflowCancelledError) as exc_info:
            await runner.run(wf, "go", prompt_vars={"role": "dev"}, cancel_event=cancel)

        assert "fetch" in exc_info.value.completed_steps
        assert exc_info.value.iteration == 1

    @pytest.mark.asyncio
    async def test_cancel_preserves_messages(self):
        """The cancellation error carries the messages accumulated so far."""
        cancel = asyncio.Event()
        wf = self._workflow_cancelling_on_fetch(cancel)
        client = MockClient([
            ToolCall(tool="fetch", args={}),
            ToolCall(tool="submit", args={}),
        ])
        runner = _make_runner(client)

        with pytest.raises(WorkflowCancelledError) as exc_info:
            await runner.run(wf, "go", prompt_vars={"role": "dev"}, cancel_event=cancel)

        messages = exc_info.value.messages
        assert len(messages) > 0
        types = [m.metadata.type for m in messages]
        assert MessageType.SYSTEM_PROMPT in types
        assert MessageType.USER_INPUT in types
        assert MessageType.TOOL_RESULT in types

    @pytest.mark.asyncio
    async def test_no_cancel_event_runs_normally(self):
        """Passing ``cancel_event=None`` lets the run complete."""
        client = MockClient([
            ToolCall(tool="fetch", args={}),
            ToolCall(tool="submit", args={}),
        ])
        wf = _make_workflow()
        runner = _make_runner(client)

        result = await runner.run(wf, "go", prompt_vars={"role": "dev"}, cancel_event=None)

        assert result == "submit_result"

    @pytest.mark.asyncio
    async def test_unset_cancel_event_runs_normally(self):
        """An event that is never set lets the run complete."""
        cancel = asyncio.Event()
        client = MockClient([
            ToolCall(tool="fetch", args={}),
            ToolCall(tool="submit", args={}),
        ])
        wf = _make_workflow()
        runner = _make_runner(client)

        result = await runner.run(wf, "go", prompt_vars={"role": "dev"}, cancel_event=cancel)

        assert result == "submit_result"
class TestCustomRetryNudge:
    """Tests for overriding the text of the retry nudge."""

    @pytest.mark.asyncio
    async def test_custom_nudge_string(self):
        """A fixed-text nudge function replaces the retry nudge content."""
        collected = []
        client = MockClient([
            TextResponse(content="bare text"),
            ToolCall(tool="fetch", args={}),
            ToolCall(tool="submit", args={}),
        ])
        wf = _make_workflow()
        runner = _make_runner(client)
        runner.on_message = collected.append
        # Installed directly on the private hook; ignores the raw reply.
        runner._retry_nudge_fn = lambda _raw: "Wrap in respond tool."

        await runner.run(wf, "go", prompt_vars={"role": "dev"})

        nudges = [m for m in collected if m.metadata.type == MessageType.RETRY_NUDGE]
        assert len(nudges) == 1
        assert nudges[0].content == "Wrap in respond tool."

    @pytest.mark.asyncio
    async def test_custom_nudge_callable(self):
        """A callable ``retry_nudge`` receives the raw reply text."""
        collected = []
        client = MockClient([
            TextResponse(content="my response"),
            ToolCall(tool="fetch", args={}),
            ToolCall(tool="submit", args={}),
        ])
        wf = _make_workflow()
        ctx = ContextManager(strategy=NoCompact(), budget_tokens=100_000)
        runner = WorkflowRunner(
            client=client, context_manager=ctx,
            retry_nudge=lambda raw: f"Please use a tool. You said: {raw[:10]}",
        )
        runner.on_message = collected.append

        await runner.run(wf, "go", prompt_vars={"role": "dev"})

        nudges = [m for m in collected if m.metadata.type == MessageType.RETRY_NUDGE]
        assert len(nudges) == 1
        # raw[:10] of "my response" is "my respons".
        assert "Please use a tool. You said: my respons" in nudges[0].content

    @pytest.mark.asyncio
    async def test_string_retry_nudge_constructor(self):
        """A string ``retry_nudge`` becomes a function returning that string."""
        ctx = ContextManager(strategy=NoCompact(), budget_tokens=100_000)
        runner = WorkflowRunner(
            client=MockClient([]),
            context_manager=ctx,
            retry_nudge="Use the respond tool.",
        )
        assert runner._retry_nudge_fn is not None
        assert runner._retry_nudge_fn("anything") == "Use the respond tool."

    @pytest.mark.asyncio
    async def test_none_retry_nudge_uses_default(self):
        """Without an override, the nudge text mentions a "tool call"."""
        collected = []
        client = MockClient([
            TextResponse(content="bare text"),
            ToolCall(tool="fetch", args={}),
            ToolCall(tool="submit", args={}),
        ])
        wf = _make_workflow()
        runner = _make_runner(client)
        runner.on_message = collected.append

        await runner.run(wf, "go", prompt_vars={"role": "dev"})

        nudges = [m for m in collected if m.metadata.type == MessageType.RETRY_NUDGE]
        assert len(nudges) == 1
        assert "tool call" in nudges[0].content.lower()