from __future__ import annotations
import asyncio
import sys
import time
from collections.abc import AsyncIterator, Callable
from dataclasses import dataclass, field
from typing import Any
from forge.clients.base import ChunkType, LLMClient, StreamChunk
from forge.context.manager import CompactEvent, ContextManager
from forge.context.strategies import CompactStrategy, NoCompact, SlidingWindowCompact, TieredCompact
from forge.core.messages import Message, MessageType
from forge.core.runner import WorkflowRunner
from forge.core.workflow import ToolCall, ToolDef, ToolSpec, Workflow
from forge.errors import ForgeError, StreamError
from forge.server import BudgetMode, ServerManager
from tests.eval.ablation import ABLATION_PRESETS, AblationConfig
from tests.eval.scenarios import ALL_SCENARIOS, EvalScenario
# Scenario names that run_eval treats as compaction scenarios: they keep their
# own ``budget_tokens`` and are skipped when an ablation disables compaction.
_COMPACTION_SCENARIOS = {
    "compaction_stress",
    "phase2_compaction",
    "compaction_stress_stateful",
    "phase2_compaction_stateful",
    "inventory_audit",
    "supplier_deep_dive",
    "compaction_chain_p1",
    "compaction_chain_p2",
    "compaction_chain_p3",
}
@dataclass
class RunResult:
    """Record of one execution of a single eval scenario.

    ``run_scenario`` builds one of these for every run, whether the workflow
    finished or raised.
    """

    # -- identity and headline outcome (required; positional order matters) --
    scenario_name: str
    completeness: bool  # True when the workflow run returned without raising
    iterations_used: int  # client calls counted during the final attempt

    # -- correctness, populated only on the success path --
    terminal_args: dict[str, Any] | None = None  # kwargs captured from the terminal tool
    accuracy: bool | None = None  # None when no validator produced a verdict
    validate_error: str | None = None  # type name of an exception raised by a validator

    # -- failure details, populated only when the run raised --
    error_type: str | None = None
    error_message: str | None = None

    # -- diagnostics --
    compaction_events: list[CompactEvent] = field(default_factory=list)
    messages: list[Message] | None = None  # None when history collection is disabled
    elapsed_seconds: float = 0.0
    stream_retries: int = 0  # retries consumed after StreamError

    # -- usage accounting --
    input_tokens: int = 0
    output_tokens: int = 0
    cost_usd: float = 0.0
@dataclass
class EvalConfig:
    """Knobs that control how scenarios are executed by the eval harness."""

    # How many times run_eval repeats each scenario.
    runs_per_scenario: int = 10
    # Forwarded to WorkflowRunner as its ``stream`` flag.
    stream: bool = False
    # Fallback strategy when no tag override matches; NoCompact is used if None.
    compact_strategy: CompactStrategy | None = None
    # Scenario tag -> strategy; the first scenario tag found here wins.
    strategy_overrides: dict[str, CompactStrategy] = field(default_factory=dict)
    # When True, every message of a run is collected into RunResult.messages.
    keep_message_history: bool = True
    # When True, a live per-message trace is printed during each run.
    verbose: bool = False
    # When set, replaces the scenario's own ``budget_tokens``.
    budget_override: int | None = None
    # Extra attempts allowed after a StreamError.
    stream_retries: int = 2
class CountingClientWrapper:
    """Proxy around an LLM client that counts calls and accumulates token usage.

    ``send`` and ``send_stream`` are intercepted to bump ``call_count`` and to
    fold the wrapped client's ``last_usage`` into the running totals; every
    other attribute is delegated to the wrapped client unchanged.
    """

    def __init__(self, client: LLMClient) -> None:
        self._client = client
        self.call_count = 0
        self.total_input_tokens = 0
        self.total_output_tokens = 0

    def __getattr__(self, name: str) -> Any:
        """Delegate attributes not found on the wrapper to the wrapped client.

        Raises:
            AttributeError: if neither the wrapper nor the client has ``name``.
        """
        # __getattr__ only runs after normal lookup has failed. If ``_client``
        # itself is missing (instance created through __new__, as copy/pickle
        # do), delegating via ``self._client`` would re-enter this method
        # forever and end in RecursionError instead of AttributeError.
        if name == "_client":
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            )
        return getattr(self._client, name)

    def _collect_usage(self) -> None:
        """Add the wrapped client's ``last_usage`` entries to the totals."""
        usage = getattr(self._client, "last_usage", None)
        if usage:
            for tu in usage.values():
                self.total_input_tokens += tu.prompt_tokens
                self.total_output_tokens += tu.completion_tokens

    async def send(
        self,
        messages: list[dict[str, str]],
        tools: list[ToolSpec] | None = None,
        sampling: dict[str, Any] | None = None,
        passthrough: dict[str, Any] | None = None,
        inbound_anthropic_body: dict[str, Any] | None = None,
    ) -> Any:
        """Forward a non-streaming request, counting the call and its usage."""
        self.call_count += 1
        result = await self._client.send(
            messages,
            tools=tools,
            sampling=sampling,
            passthrough=passthrough,
            inbound_anthropic_body=inbound_anthropic_body,
        )
        self._collect_usage()
        return result

    async def send_stream(
        self,
        messages: list[dict[str, str]],
        tools: list[ToolSpec] | None = None,
        sampling: dict[str, Any] | None = None,
        passthrough: dict[str, Any] | None = None,
        inbound_anthropic_body: dict[str, Any] | None = None,
    ) -> AsyncIterator[StreamChunk]:
        """Forward a streaming request, re-yielding each chunk.

        Usage is collected only after the underlying stream is exhausted.
        """
        self.call_count += 1
        async for chunk in self._client.send_stream(
            messages,
            tools=tools,
            sampling=sampling,
            passthrough=passthrough,
            inbound_anthropic_body=inbound_anthropic_body,
        ):
            yield chunk
        self._collect_usage()

    async def get_context_length(self) -> int | None:
        """Return the wrapped client's context length."""
        return await self._client.get_context_length()
def _resolve_strategy(
    scenario: EvalScenario,
    config: EvalConfig,
) -> CompactStrategy:
    """Pick the compaction strategy for ``scenario``.

    The first scenario tag that has an entry in ``config.strategy_overrides``
    wins. Otherwise ``config.compact_strategy`` is used, falling back to a
    fresh ``NoCompact`` when that is unset.
    """
    overrides = config.strategy_overrides
    no_match = object()  # sentinel so a falsy override value is still honoured
    chosen = next(
        (overrides[tag] for tag in scenario.tags if tag in overrides),
        no_match,
    )
    if chosen is not no_match:
        return chosen
    return config.compact_strategy or NoCompact()
def _build_workflow_with_capture(
    scenario: EvalScenario,
    ablation: AblationConfig | None = None,
) -> tuple[Workflow, dict[str, Any], Callable[[], bool] | None]:
    """Clone the scenario's workflow with terminal tools wrapped to record args.

    Returns:
        A ``(workflow, capture, validate_state_fn)`` triple. ``capture["args"]``
        is set to the keyword arguments of the most recent terminal-tool call.
        ``validate_state_fn`` comes from ``scenario.build_workflow()`` when the
        scenario provides a builder, otherwise from ``scenario.validate_state``.
    """
    captured: dict[str, Any] = {}

    if scenario.build_workflow is None:
        source_workflow = scenario.workflow
        state_validator: Callable[[], bool] | None = scenario.validate_state
    else:
        source_workflow, state_validator = scenario.build_workflow()

    wrapped_tools = dict(source_workflow.tools)
    for terminal_name in source_workflow.terminal_tools:
        inner = source_workflow.get_callable(terminal_name)
        spec = source_workflow.tools[terminal_name].spec

        # ``_fn`` is bound as a default so each wrapper keeps its own callable
        # instead of sharing the loop variable.
        def capturing_terminal(_fn=inner, **kwargs: Any) -> Any:
            captured["args"] = kwargs
            return _fn(**kwargs)

        wrapped_tools[terminal_name] = ToolDef(spec=spec, callable=capturing_terminal)

    steps = source_workflow.required_steps
    enforce_steps = ablation is None or ablation.step_enforcement_enabled
    if not enforce_steps:
        steps = []

    rebuilt = Workflow(
        name=source_workflow.name,
        description=source_workflow.description,
        tools=wrapped_tools,
        required_steps=steps,
        terminal_tool=source_workflow.terminal_tool,
        system_prompt_template=source_workflow.system_prompt_template,
    )
    return rebuilt, captured, state_validator
def _verbose_printer(msg: Message) -> None:
    """Print a one-line trace for ``msg`` based on its metadata type.

    Tool results and reasoning text are clipped to 120 characters. Message
    types not listed below print nothing.
    """
    limit = 120

    def clip(text: str) -> str:
        # Clip long content and mark the cut with an ellipsis.
        if len(text) > limit:
            return text[:limit] + "..."
        return text

    kind = msg.metadata.type
    if kind == MessageType.TOOL_CALL:
        calls = msg.tool_calls
        if not calls:
            print(f" [tool_call] {msg.content}")
            return
        names = [call.name for call in calls]
        label = ", ".join(names)
        if len(names) > 1:
            print(f" [tool_call] *** PARALLEL {len(names)} *** {label}")
        else:
            print(f" [tool_call] {label}")
    elif kind == MessageType.TOOL_RESULT:
        print(f" [result] {clip(msg.content)}")
    elif kind == MessageType.REASONING:
        print(f" [thinking] {clip(msg.content)}")
    elif kind == MessageType.RETRY_NUDGE:
        print(" [nudge] retry")
    elif kind == MessageType.STEP_NUDGE:
        print(" [nudge] step enforcement")
async def run_scenario(
    client: LLMClient,
    scenario: EvalScenario,
    config: EvalConfig,
    ablation: AblationConfig | None = None,
) -> RunResult:
    """Execute one scenario once and summarise the outcome as a ``RunResult``.

    The run is retried from scratch (counters, collected events/messages and
    the workflow are all reset) up to ``config.stream_retries`` times when the
    runner raises ``StreamError``. Any other exception ends the run at once
    and is reported as an incomplete result instead of propagating.
    """
    counter = CountingClientWrapper(client)
    events: list[CompactEvent] = []
    history: list[Message] = []

    compaction_off = ablation is not None and not ablation.compaction_enabled
    strategy = NoCompact() if compaction_off else _resolve_strategy(scenario, config)

    budget = scenario.budget_tokens if config.budget_override is None else config.budget_override
    ctx = ContextManager(
        strategy=strategy,
        budget_tokens=budget,
        on_compact=events.append,
    )
    workflow, capture, validate_state_fn = _build_workflow_with_capture(scenario, ablation=ablation)

    # Assemble the on_message hook: none, a single callback, or a fan-out.
    listeners: list[Any] = []
    if config.verbose:
        listeners.append(_verbose_printer)
    if config.keep_message_history:
        listeners.append(history.append)

    def fan_out(msg: Message) -> None:
        for listener in listeners:
            listener(msg)

    if len(listeners) > 1:
        on_message = fan_out
    elif listeners:
        on_message = listeners[0]
    else:
        on_message = None

    # An ablation preset, when given, replaces the scenario's guardrail limits.
    if ablation is None:
        max_retries = scenario.max_retries_per_step
        max_tool_errors = scenario.max_tool_errors
        rescue_enabled = True
    else:
        max_retries = ablation.max_retries_per_step
        max_tool_errors = ablation.max_tool_errors
        rescue_enabled = ablation.rescue_enabled

    runner = WorkflowRunner(
        client=counter,
        context_manager=ctx,
        max_iterations=scenario.max_iterations,
        max_retries_per_step=max_retries,
        max_tool_errors=max_tool_errors,
        stream=config.stream,
        on_message=on_message,
        rescue_enabled=rescue_enabled,
    )

    started = time.monotonic()

    def failure(error: BaseException, retries: int) -> RunResult:
        # Shared shape for every incomplete run; timed against the start of
        # the current attempt.
        elapsed = time.monotonic() - started
        return RunResult(
            scenario_name=scenario.name,
            completeness=False,
            iterations_used=counter.call_count,
            error_type=type(error).__name__,
            error_message=str(error),
            compaction_events=events,
            messages=history if config.keep_message_history else None,
            elapsed_seconds=elapsed,
            stream_retries=retries,
            input_tokens=counter.total_input_tokens,
            output_tokens=counter.total_output_tokens,
        )

    last_stream_error: StreamError | None = None
    for attempt in range(1 + config.stream_retries):
        if attempt > 0:
            print(f" [retry {attempt}/{config.stream_retries}] StreamError, retrying...", flush=True)
            # Start the retry from a clean slate so nothing from the failed
            # attempt leaks into the reported numbers.
            counter.call_count = 0
            counter.total_input_tokens = 0
            counter.total_output_tokens = 0
            events.clear()
            history.clear()
            workflow, capture, validate_state_fn = _build_workflow_with_capture(scenario, ablation=ablation)
            started = time.monotonic()
        try:
            await runner.run(workflow, scenario.user_message)
            elapsed = time.monotonic() - started

            accuracy: bool | None = None
            validate_error: str | None = None

            # Check the captured terminal arguments, if there is a validator.
            if scenario.validate and capture.get("args") is not None:
                try:
                    accuracy = scenario.validate(capture["args"])
                except Exception as exc:
                    accuracy = None
                    validate_error = type(exc).__name__

            # Combine with the state validator's verdict when one exists.
            if validate_state_fn is not None:
                try:
                    state_ok = validate_state_fn()
                    accuracy = state_ok if accuracy is None else (accuracy and state_ok)
                except Exception as exc:
                    accuracy = False
                    validate_error = f"validate_state: {type(exc).__name__}"

            return RunResult(
                scenario_name=scenario.name,
                completeness=True,
                iterations_used=counter.call_count,
                terminal_args=capture.get("args"),
                accuracy=accuracy,
                validate_error=validate_error,
                compaction_events=events,
                messages=history if config.keep_message_history else None,
                elapsed_seconds=elapsed,
                stream_retries=attempt,
                input_tokens=counter.total_input_tokens,
                output_tokens=counter.total_output_tokens,
            )
        except StreamError as exc:
            # Remember the error and try again while attempts remain.
            last_stream_error = exc
            continue
        except ForgeError as exc:
            return failure(exc, attempt)
        except Exception as exc:
            return failure(exc, attempt)

    # Every attempt ended in StreamError.
    elapsed_total_marker = time.monotonic()  # keeps timing before the assert, as before
    assert last_stream_error is not None
    del elapsed_total_marker
    return failure(last_stream_error, config.stream_retries)
async def run_eval(
    client: LLMClient,
    scenarios: list[EvalScenario],
    config: EvalConfig,
    resolved_budget: int | None = None,
    tags: list[str] | None = None,
    names: list[str] | None = None,
    ablation: AblationConfig | None = None,
) -> dict[str, list[RunResult]]:
    """Run every selected scenario ``config.runs_per_scenario`` times.

    Args:
        client: LLM client shared by all runs.
        scenarios: Candidate scenarios, filtered by ``tags`` then ``names``.
        config: Base configuration; only ``budget_override`` is replaced per
            scenario.
        resolved_budget: Budget used for scenarios that are not listed in
            ``_COMPACTION_SCENARIOS``; those listed keep their own
            ``budget_tokens``.
        tags: Keep only scenarios that carry at least one of these tags.
        names: Keep only scenarios whose name is listed.
        ablation: Optional preset; compaction scenarios are skipped when it
            disables compaction.

    Returns:
        Mapping of scenario name to its run results, in execution order.
    """
    from dataclasses import replace

    if tags:
        scenarios = [s for s in scenarios if any(t in s.tags for t in tags)]
    if names:
        scenarios = [s for s in scenarios if s.name in names]

    results: dict[str, list[RunResult]] = {}
    for scenario in scenarios:
        is_compaction = scenario.name in _COMPACTION_SCENARIOS
        if is_compaction and ablation is not None and not ablation.compaction_enabled:
            print(f" Skipping {scenario.name} (compaction disabled by ablation={ablation.name})")
            continue

        scenario_budget = scenario.budget_tokens if is_compaction else resolved_budget
        if scenario_budget is not None and hasattr(client, "set_num_ctx"):
            client.set_num_ctx(scenario_budget)

        # Copy the whole config and change only the budget, so fields added to
        # EvalConfig later are carried over instead of being silently dropped.
        per_scenario_config = replace(config, budget_override=scenario_budget)

        scenario_results: list[RunResult] = []
        for run_idx in range(config.runs_per_scenario):
            print(
                f" Running {scenario.name} "
                f"[{run_idx + 1}/{config.runs_per_scenario}]...",
                flush=True,
            )
            result = await run_scenario(client, scenario, per_scenario_config, ablation=ablation)
            scenario_results.append(result)

            if not result.completeness:
                status = f"FAIL ({result.error_type})"
            elif result.accuracy is False:
                status = "OK (incorrect)"
            else:
                status = "OK"

            cost_str = ""
            if result.input_tokens:
                # Imported lazily, only when there is usage to price.
                from tests.eval.batch_eval import _compute_cost

                cost = _compute_cost(
                    getattr(client, "model", ""),
                    result.input_tokens,
                    result.output_tokens,
                )
                if cost > 0:
                    cost_str = f", ${cost:.4f}"

            print(
                f" {status} — {result.iterations_used} iterations, "
                f"{result.elapsed_seconds:.1f}s{cost_str}",
                flush=True,
            )
        results[scenario.name] = scenario_results
    return results
async def main() -> None:
    """Command-line entry point for the eval harness.

    Parses arguments, builds the backend client, resolves the token budget,
    runs the selected scenarios through ``run_eval`` and prints the report
    plus token/cost totals. Exits with status 1 when the budget cannot be
    resolved and with status 0 after ``--probe``.
    """
    import argparse

    budget_choices = [m.value for m in BudgetMode]
    parser = argparse.ArgumentParser(description="Forge eval harness")
    parser.add_argument(
        "--backend",
        choices=["ollama", "llamafile", "anthropic"],
        default="ollama",
    )
    parser.add_argument(
        "--model",
        help="Model name (e.g. ministral-3:14b for Ollama, claude-... for Anthropic). "
        "Required for ollama/anthropic backends; rejected for llamafile (use --gguf).",
    )
    parser.add_argument(
        "--gguf",
        help="Path to GGUF / llamafile model file. Required for llamafile backend; "
        "rejected for ollama/anthropic (use --model).",
    )
    parser.add_argument("--runs", type=int, default=10)
    parser.add_argument("--stream", action="store_true")
    parser.add_argument(
        "--think",
        choices=["true", "false", "auto"],
        default="auto",
        help="Think mode: true/false/auto. Ollama: controls think param in request. "
        "Llamafile: true/auto captures [THINK] tags from content, false discards them.",
    )
    parser.add_argument("--tags", nargs="*", help="Filter scenarios by tag")
    parser.add_argument("--scenario", nargs="*", help="Run specific scenario(s) by name")
    parser.add_argument(
        "--llamafile-mode",
        choices=["native", "prompt", "auto"],
        default="auto",
    )
    parser.add_argument(
        "--budget-mode",
        choices=budget_choices,
        default=BudgetMode.FORGE_FULL.value,
        help="Budget mode (prod BudgetMode). Compaction scenarios always override with their own budget.",
    )
    parser.add_argument(
        "--num-ctx",
        type=int,
        default=None,
        help="Exact token budget (requires --budget-mode manual).",
    )
    parser.add_argument(
        "--no-history",
        action="store_true",
        help="Disable message history collection",
    )
    parser.add_argument(
        "--verbose", "-v",
        action="store_true",
        help="Print live per-message trace during each run",
    )
    parser.add_argument(
        "--probe",
        action="store_true",
        help="Print resolved budget from backend and exit (no eval run)",
    )
    parser.add_argument(
        "--base-url",
        default=None,
        help="Override backend base URL (e.g. http://172.x.x.x:8080/v1 for WSL->Windows)",
    )
    parser.add_argument(
        "--ablation",
        choices=list(ABLATION_PRESETS.keys()),
        default="reforged",
        help="Ablation preset: selectively disable guardrails (default: reforged = all enabled)",
    )
    parser.add_argument(
        "--tool-choice",
        choices=["auto", "any"],
        default=None,
        help="Anthropic tool_choice type (default: auto). 'any' forces tool calls.",
    )
    parser.add_argument(
        "--no-cache-prompt",
        action="store_true",
        help="Disable llama-server prompt caching (default: enabled)",
    )
    parser.add_argument(
        "--compact-strategy",
        choices=["tiered", "sliding", "none"],
        default=None,
        help="Override compaction strategy for all scenarios. "
        "tiered=TieredCompact (default for compaction scenarios), "
        "sliding=SlidingWindowCompact, none=NoCompact (context grows unbounded).",
    )
    args = parser.parse_args()

    budget_mode = BudgetMode(args.budget_mode)
    if budget_mode == BudgetMode.MANUAL and args.num_ctx is None:
        parser.error("--budget-mode manual requires --num-ctx")

    # --model and --gguf are mutually exclusive and backend-specific.
    if args.backend in ("ollama", "anthropic"):
        if args.gguf:
            parser.error(f"--backend {args.backend} does not accept --gguf (use --model)")
        if not args.model:
            parser.error(f"--backend {args.backend} requires --model")
    else:
        # llamafile backend
        if args.model:
            parser.error("--backend llamafile does not accept --model (use --gguf)")
        if not args.gguf:
            parser.error("--backend llamafile requires --gguf")

    from pathlib import Path as _Path

    display_name = args.model if args.backend != "llamafile" else _Path(args.gguf).stem
    url_kw: dict = {"base_url": args.base_url} if args.base_url else {}
    think_val = {"true": True, "false": False, "auto": None}[args.think]

    # Backend client modules are imported lazily so only the selected one loads.
    if args.backend == "ollama":
        from forge.clients.ollama import OllamaClient

        client: LLMClient = OllamaClient(
            model=args.model, think=think_val, **url_kw,
            recommended_sampling=True,
        )
    elif args.backend == "anthropic":
        from forge.clients.anthropic import AnthropicClient

        client = AnthropicClient(model=args.model, tool_choice=args.tool_choice)
    else:
        from forge.clients.llamafile import LlamafileClient

        client = LlamafileClient(
            gguf_path=args.gguf, mode=args.llamafile_mode, think=think_val,
            cache_prompt=not args.no_cache_prompt, **url_kw,
            recommended_sampling=True,
        )

    if args.backend == "anthropic":
        # No local server is queried for this backend; a fixed budget is used.
        resolved_budget = 200_000
    else:
        server = ServerManager(backend=args.backend)
        try:
            resolved_budget = await server.resolve_budget(budget_mode, args.num_ctx)
        except Exception as exc:
            print(f"ERROR: Cannot resolve budget: {exc}", file=sys.stderr)
            print("Make sure the backend is running before starting the eval.", file=sys.stderr)
            sys.exit(1)

    if hasattr(client, "set_num_ctx"):
        client.set_num_ctx(resolved_budget)

    if args.probe:
        print(f"Budget mode: {budget_mode.value}")
        print(f"Resolved budget: {resolved_budget} tokens")
        sys.exit(0)

    _STRATEGY_MAP = {
        "tiered": TieredCompact(keep_recent=2),
        "sliding": SlidingWindowCompact(keep_recent=2),
        "none": NoCompact(),
    }
    if args.compact_strategy is not None:
        # An explicit CLI strategy applies to every scenario: no tag overrides.
        strategy_kwargs: dict = {
            "compact_strategy": _STRATEGY_MAP[args.compact_strategy],
            "strategy_overrides": {},
        }
    else:
        # Default: only scenarios tagged "compaction" get a compaction strategy.
        strategy_kwargs = {
            "strategy_overrides": {
                "compaction": TieredCompact(keep_recent=2),
            },
        }
    config = EvalConfig(
        runs_per_scenario=args.runs,
        stream=args.stream,
        keep_message_history=not args.no_history,
        verbose=args.verbose,
        budget_override=resolved_budget,
        **strategy_kwargs,
    )

    ablation = ABLATION_PRESETS[args.ablation]
    strategy_label = args.compact_strategy or "auto"
    print(
        f"\nForge Eval — backend: {args.backend}, model: {display_name}, "
        f"runs: {args.runs}, stream: {args.stream}, budget-mode: {budget_mode.value}"
    )
    print(f"Resolved budget: {resolved_budget} tokens")
    print(f"Compact strategy: {strategy_label}")
    print(f"Ablation: {ablation.name}")
    print(f"Tags filter: {args.tags or 'all'}")
    print(f"Scenario filter: {args.scenario or 'all'}")
    print()

    results = await run_eval(
        client, ALL_SCENARIOS, config,
        resolved_budget=resolved_budget,
        tags=args.tags, names=args.scenario,
        ablation=ablation,
    )

    from tests.eval.metrics import print_report

    print_report(results, scenarios=ALL_SCENARIOS, model_name=display_name)

    all_runs = [r for runs in results.values() for r in runs]
    total_input = sum(r.input_tokens for r in all_runs)
    total_output = sum(r.output_tokens for r in all_runs)
    if total_input:
        from tests.eval.batch_eval import _compute_cost

        # args.model is None for the llamafile backend; pass an empty model
        # name instead, the same fallback run_eval uses when the client has no
        # ``model``. NOTE(review): assumes _compute_cost prices an unknown
        # model at zero -- confirm.
        total_cost = _compute_cost(args.model or "", total_input, total_output)
        print(
            f"Token usage: {total_input:,} input + {total_output:,} output"
            f" = {total_input + total_output:,} total"
        )
        if total_cost > 0:
            n_runs = len(all_runs)
            print(f"Total cost: ${total_cost:.4f} ({n_runs} runs, ${total_cost / n_runs:.4f}/run)")
if __name__ == "__main__":
    # Script entry point: drive the async CLI to completion.
    asyncio.run(main())