import json
import os
import sys
import re
NB_PATH = os.path.join(
os.path.dirname(__file__),
"..", "notebook", "toolcall_verifier_training_production_colab_v5.ipynb",
)
NB_PATH = os.path.normpath(NB_PATH)
def load_nb():
with open(NB_PATH, encoding="utf-8") as f:
return json.load(f)
def save_nb(nb):
with open(NB_PATH, "w", encoding="utf-8") as f:
json.dump(nb, f, indent=1, ensure_ascii=False)
f.write("\n")
def cell_src(cell):
return "".join(cell["source"])
def set_cell_src(cell, new_src):
lines = new_src.splitlines(keepends=True)
if lines and lines[-1].endswith("\n"):
lines[-1] = lines[-1][:-1]
cell["source"] = lines
def find_cell_by_marker(cells, marker):
for i, c in enumerate(cells):
if marker in cell_src(c):
return i
return None
SERIALIZE_V3_ADDITIONS = '''
def serialize_candidate_tool_schema(
tools: List[Dict[str, Any]],
candidate: Any,
) -> str:
"""Return the full parameter schema for the candidate tool only."""
names = set(candidate_tool_names(candidate))
for t in tools:
normalized = normalize_tool_for_prompt(t)
if normalized.get("name") in names:
return f"{normalized['name']}: {normalized['description']}\\nPARAMETERS: {compact_json(normalized['parameters'], 2400)}"
return ""
def serialize_competing_tool_signatures(
tools: List[Dict[str, Any]],
candidate: Any,
max_tools: int = 12,
) -> str:
"""Return name + description only (no parameters) for non-candidate tools."""
names = set(candidate_tool_names(candidate))
lines = []
for t in tools:
normalized = normalize_tool_for_prompt(t)
n = normalized.get("name", "")
if n not in names:
lines.append(f"{n}: {normalized.get('description', '')}")
if len(lines) >= max_tools:
break
return "\\n".join(lines)
def serialize_state_v3(input_obj: Dict[str, Any]) -> str:
"""Candidate-first layout: candidate call and its schema appear before the tool list.
Truncation at max_length will eat OTHER_AVAILABLE_TOOLS tail first,
preserving the most semantically important tokens."""
ws = input_obj["workflow_state"]
metadata = input_obj.get("metadata") or {}
candidate_schema = serialize_candidate_tool_schema(
input_obj["available_tools"], input_obj["candidate_call"]
)
competing_sigs = serialize_competing_tool_signatures(
input_obj["available_tools"], input_obj["candidate_call"]
)
return f"""SCHEMA_VERSION:
{input_obj['schema_version']}
USER_REQUEST:
{input_obj['user_request']}
CANDIDATE_CALL:
{compact_json(input_obj['candidate_call'], 2400)}
CANDIDATE_TOOL_SCHEMA:
{candidate_schema}
WORKFLOW_STATE:
required_steps={ws.get('required_steps', [])}
completed_steps={ws.get('completed_steps', [])}
pending_steps={ws.get('pending_steps', [])}
terminal_tools={ws.get('terminal_tools', [])}
recent_errors={ws.get('recent_errors', [])}
OTHER_AVAILABLE_TOOLS:
{competing_sigs}
SCORING_METADATA:
scenario_family={_json_or_null(metadata.get('scenario_family'))}
requires_transform={_json_or_null(metadata.get('requires_transform'))}
requires_synthesis={_json_or_null(metadata.get('requires_synthesis'))}
requires_all_tool_facts={_json_or_null(metadata.get('requires_all_tool_facts'))}
must_acknowledge_missing_data={_json_or_null(metadata.get('must_acknowledge_missing_data'))}""".strip()
'''
SERIALIZE_FROM_OBJECT_OLD = """def serialize_state_from_object(input_obj: Dict[str, Any]) -> str:
if SERIALIZER_VERSION == "serialize_state_v2":
return serialize_state_v2(input_obj)
return serialize_state_v1(input_obj)"""
SERIALIZE_FROM_OBJECT_NEW = """def serialize_state_from_object(input_obj: Dict[str, Any]) -> str:
if SERIALIZER_VERSION == "serialize_state_v3":
return serialize_state_v3(input_obj)
if SERIALIZER_VERSION == "serialize_state_v2":
return serialize_state_v2(input_obj)
return serialize_state_v1(input_obj)"""
FIXTURE_V2_BLOCK_OLD = """ACTIVE_SERIALIZER_FIXTURE = SERIALIZER_FIXTURE_V2 if USE_SERIALIZER_V2 else SERIALIZER_FIXTURE
(DATA_DIR / "serializer_fixture.json").write_text(json.dumps({
"input": ACTIVE_SERIALIZER_FIXTURE,
"serialized": serialize_state_from_object(ACTIVE_SERIALIZER_FIXTURE),
}, indent=2))
print(serialize_state_from_object(ACTIVE_SERIALIZER_FIXTURE))"""
FIXTURE_V3_BLOCK_NEW = """SERIALIZER_FIXTURE_V3 = build_input_object(
user_request=SERIALIZER_FIXTURE["user_request"],
tools=SERIALIZER_FIXTURE["available_tools"],
candidate=SERIALIZER_FIXTURE["candidate_call"],
required_steps=SERIALIZER_FIXTURE["workflow_state"]["required_steps"],
completed_steps=SERIALIZER_FIXTURE["workflow_state"]["completed_steps"],
pending_steps=SERIALIZER_FIXTURE["workflow_state"]["pending_steps"],
terminal_tools=SERIALIZER_FIXTURE["workflow_state"]["terminal_tools"],
scoring_metadata=infer_scoring_metadata("argument_transformation"),
schema_version=TOOLCALL_INPUT_SCHEMA_VERSION_V2,
)
(DATA_DIR / "serializer_fixture_v3.json").write_text(json.dumps({
"input": SERIALIZER_FIXTURE_V3,
"serialized": serialize_state_v3(SERIALIZER_FIXTURE_V3),
}, indent=2))
if SERIALIZER_VERSION == "serialize_state_v3":
ACTIVE_SERIALIZER_FIXTURE = SERIALIZER_FIXTURE_V3
elif USE_SERIALIZER_V2:
ACTIVE_SERIALIZER_FIXTURE = SERIALIZER_FIXTURE_V2
else:
ACTIVE_SERIALIZER_FIXTURE = SERIALIZER_FIXTURE
(DATA_DIR / "serializer_fixture.json").write_text(json.dumps({
"input": ACTIVE_SERIALIZER_FIXTURE,
"serialized": serialize_state_from_object(ACTIVE_SERIALIZER_FIXTURE),
}, indent=2))
print(serialize_state_from_object(ACTIVE_SERIALIZER_FIXTURE))"""
T4_PROVEN_OLD = """ # Proven T4 baseline from the first completed run:
# epoch 4 validation macro_f1=0.741871, accuracy=0.839146, runtime about 1h25m on T4.
"t4_proven": {
"max_per_source": 5_000,
"max_length": 768,
"epochs": 4,
"train_batch_size": 8,
"eval_batch_size": 16,
"grad_accum": 4,
"learning_rate": 1e-5,
"warmup_ratio": 0.08,
"early_stopping_patience": 2,
"max_tools_in_prompt": 24,
"dataloader_num_workers": 2,
"gradient_checkpointing": False,
"optimizer": "adamw_torch_fused" if torch.cuda.is_available() else "adamw_torch",
},"""
T4_PROVEN_NEW = """ # Proven T4 baseline — updated for candidate-first serializer experiment.
# max_length raised 768->1024 to cover p95 token lengths from truncation diagnostics.
# max_per_source reduced 5000->4000 and grad_accum raised 4->6 to stay within T4 VRAM budget.
# epoch 4 validation macro_f1=0.741871, accuracy=0.839146, runtime about 1h25m on T4 (pre-update baseline).
"t4_proven": {
"max_per_source": 4_000,
"max_length": 1024,
"epochs": 5,
"train_batch_size": 8,
"eval_batch_size": 16,
"grad_accum": 6,
"learning_rate": 1e-5,
"warmup_ratio": 0.08,
"early_stopping_patience": 2,
"max_tools_in_prompt": 24,
"dataloader_num_workers": 2,
"gradient_checkpointing": False,
"optimizer": "adamw_torch_fused" if torch.cuda.is_available() else "adamw_torch",
},"""
SERIALIZER_VERSION_OLD = """USE_SERIALIZER_V2 = True #@param {type:"boolean"}
INPUT_SCHEMA_VERSION = TOOLCALL_INPUT_SCHEMA_VERSION_V2 if USE_SERIALIZER_V2 else TOOLCALL_INPUT_SCHEMA_VERSION_V1
SERIALIZER_VERSION = "serialize_state_v2" if USE_SERIALIZER_V2 else "serialize_state_v1" """
SERIALIZER_VERSION_NEW = """USE_SERIALIZER_V2 = True #@param {type:"boolean"}
USE_SERIALIZER_V3 = True #@param {type:"boolean"} # Candidate-first layout; v1/v2 kept frozen for artifact compat
INPUT_SCHEMA_VERSION = TOOLCALL_INPUT_SCHEMA_VERSION_V2 if (USE_SERIALIZER_V2 or USE_SERIALIZER_V3) else TOOLCALL_INPUT_SCHEMA_VERSION_V1
if USE_SERIALIZER_V3:
SERIALIZER_VERSION = "serialize_state_v3"
elif USE_SERIALIZER_V2:
SERIALIZER_VERSION = "serialize_state_v2"
else:
SERIALIZER_VERSION = "serialize_state_v1" """
CONTRASTIVE_PAIRS_BLOCK = '''
# ---------------------------------------------------------------------------
# Contrastive wrong_tool_semantic pairs
# Each triplet emits: valid candidate + wrong_tool_semantic candidate with same group_id.
# Focus on tool families with overlapping argument shapes: get/set, search/fetch,
# completed-step repeats, premature terminal use, scoped vs global search.
# ---------------------------------------------------------------------------
FORGE_CONTRASTIVE_WRONG_TOOL_PAIRS = [
# ---- Fetch / Update pairs ----
{
"user_request": "Retrieve the order details for order ID 00812.",
"tools": [
{"name": "get_order", "description": "Fetch details for a specific order.", "parameters": {"type": "object", "properties": {"order_id": {"type": "string"}}, "required": ["order_id"]}},
{"name": "update_order", "description": "Modify an existing order.", "parameters": {"type": "object", "properties": {"order_id": {"type": "string"}, "status": {"type": "string"}}, "required": ["order_id", "status"]}},
{"name": "report", "description": "Produce final report.", "parameters": {"type": "object", "properties": {"findings": {"type": "string"}}, "required": ["findings"]}},
],
"valid_candidate": {"name": "get_order", "arguments": {"order_id": "00812"}},
"wrong_tool_candidate": {"name": "update_order", "arguments": {"order_id": "00812", "status": "pending"}},
"required_steps": ["get_order"], "completed_steps": [], "pending_steps": ["get_order"], "terminal_tools": ["report"],
},
{
"user_request": "Look up the inventory count for SKU-4421.",
"tools": [
{"name": "get_inventory", "description": "Look up stock levels for a given SKU.", "parameters": {"type": "object", "properties": {"sku": {"type": "string"}}, "required": ["sku"]}},
{"name": "set_inventory", "description": "Update stock levels for a given SKU.", "parameters": {"type": "object", "properties": {"sku": {"type": "string"}, "count": {"type": "integer"}}, "required": ["sku", "count"]}},
{"name": "respond", "description": "Send final response.", "parameters": {"type": "object", "properties": {"message": {"type": "string"}}, "required": ["message"]}},
],
"valid_candidate": {"name": "get_inventory", "arguments": {"sku": "SKU-4421"}},
"wrong_tool_candidate": {"name": "set_inventory", "arguments": {"sku": "SKU-4421", "count": 0}},
"required_steps": ["get_inventory"], "completed_steps": [], "pending_steps": ["get_inventory"], "terminal_tools": ["respond"],
},
{
"user_request": "Show me the customer record for customer 10045.",
"tools": [
{"name": "get_customer", "description": "Retrieve a customer record by ID.", "parameters": {"type": "object", "properties": {"customer_id": {"type": "integer"}}, "required": ["customer_id"]}},
{"name": "update_customer", "description": "Update fields on a customer record.", "parameters": {"type": "object", "properties": {"customer_id": {"type": "integer"}, "email": {"type": "string"}}, "required": ["customer_id"]}},
{"name": "respond", "description": "Send final response.", "parameters": {"type": "object", "properties": {"message": {"type": "string"}}, "required": ["message"]}},
],
"valid_candidate": {"name": "get_customer", "arguments": {"customer_id": 10045}},
"wrong_tool_candidate": {"name": "update_customer", "arguments": {"customer_id": 10045}},
"required_steps": ["get_customer"], "completed_steps": [], "pending_steps": ["get_customer"], "terminal_tools": ["respond"],
},
# ---- Scoped search vs. global search ----
{
"user_request": "Search for running shoes in the footwear category.",
"tools": [
{"name": "search_category", "description": "Search products within a specific category.", "parameters": {"type": "object", "properties": {"query": {"type": "string"}, "category": {"type": "string"}}, "required": ["query", "category"]}},
{"name": "search_all", "description": "Search all products across all categories.", "parameters": {"type": "object", "properties": {"query": {"type": "string"}}, "required": ["query"]}},
{"name": "respond", "description": "Send final response.", "parameters": {"type": "object", "properties": {"message": {"type": "string"}}, "required": ["message"]}},
],
"valid_candidate": {"name": "search_category", "arguments": {"query": "running shoes", "category": "footwear"}},
"wrong_tool_candidate": {"name": "search_all", "arguments": {"query": "running shoes"}},
"required_steps": ["search_category"], "completed_steps": [], "pending_steps": ["search_category"], "terminal_tools": ["respond"],
},
{
"user_request": "Find the Q4 2024 invoice for account ACC-982.",
"tools": [
{"name": "search_invoices", "description": "Search invoices for a specific account.", "parameters": {"type": "object", "properties": {"account_id": {"type": "string"}, "quarter": {"type": "string"}}, "required": ["account_id"]}},
{"name": "search_all_records", "description": "Global record search across accounts and invoices.", "parameters": {"type": "object", "properties": {"query": {"type": "string"}}, "required": ["query"]}},
{"name": "respond", "description": "Send final response.", "parameters": {"type": "object", "properties": {"message": {"type": "string"}}, "required": ["message"]}},
],
"valid_candidate": {"name": "search_invoices", "arguments": {"account_id": "ACC-982", "quarter": "Q4-2024"}},
"wrong_tool_candidate": {"name": "search_all_records", "arguments": {"query": "Q4 2024 invoice ACC-982"}},
"required_steps": ["search_invoices"], "completed_steps": [], "pending_steps": ["search_invoices"], "terminal_tools": ["respond"],
},
# ---- Premature terminal tool ----
{
"user_request": "Fetch the sales data for Q3 2024 and then generate a report.",
"tools": [
{"name": "fetch_sales", "description": "Fetch sales data for a given period.", "parameters": {"type": "object", "properties": {"quarter": {"type": "string"}}, "required": ["quarter"]}},
{"name": "generate_report", "description": "Generate the final sales report.", "parameters": {"type": "object", "properties": {"summary": {"type": "string"}}, "required": ["summary"]}},
],
"valid_candidate": {"name": "fetch_sales", "arguments": {"quarter": "Q3-2024"}},
"wrong_tool_candidate": {"name": "generate_report", "arguments": {"summary": "Sales data for Q3 2024."}},
"required_steps": ["fetch_sales", "generate_report"], "completed_steps": [], "pending_steps": ["fetch_sales", "generate_report"], "terminal_tools": ["generate_report"],
},
{
"user_request": "First pull the user activity log, then summarize it.",
"tools": [
{"name": "get_activity_log", "description": "Retrieve user activity log for a period.", "parameters": {"type": "object", "properties": {"user_id": {"type": "string"}, "days": {"type": "integer"}}, "required": ["user_id"]}},
{"name": "summarize", "description": "Summarize content into a report.", "parameters": {"type": "object", "properties": {"content": {"type": "string"}}, "required": ["content"]}},
],
"valid_candidate": {"name": "get_activity_log", "arguments": {"user_id": "usr-77", "days": 30}},
"wrong_tool_candidate": {"name": "summarize", "arguments": {"content": "Activity log for user usr-77."}},
"required_steps": ["get_activity_log", "summarize"], "completed_steps": [], "pending_steps": ["get_activity_log", "summarize"], "terminal_tools": ["summarize"],
},
# ---- Already-completed step repeat ----
{
"user_request": "Analyze the dataset and produce a report.",
"tools": [
{"name": "analyze_data", "description": "Run statistical analysis on a dataset.", "parameters": {"type": "object", "properties": {"dataset_id": {"type": "string"}}, "required": ["dataset_id"]}},
{"name": "produce_report", "description": "Write and deliver the final report.", "parameters": {"type": "object", "properties": {"findings": {"type": "string"}}, "required": ["findings"]}},
],
"valid_candidate": {"name": "produce_report", "arguments": {"findings": "Dataset shows 15% growth in Q4."}},
"wrong_tool_candidate": {"name": "analyze_data", "arguments": {"dataset_id": "ds-001"}},
"required_steps": ["analyze_data", "produce_report"], "completed_steps": ["analyze_data"], "pending_steps": ["produce_report"], "terminal_tools": ["produce_report"],
},
{
"user_request": "Pull the transaction history and generate the audit report.",
"tools": [
{"name": "get_transactions", "description": "Fetch transaction history.", "parameters": {"type": "object", "properties": {"account": {"type": "string"}}, "required": ["account"]}},
{"name": "audit_report", "description": "Generate audit report from transactions.", "parameters": {"type": "object", "properties": {"summary": {"type": "string"}}, "required": ["summary"]}},
],
"valid_candidate": {"name": "audit_report", "arguments": {"summary": "No anomalies found in transaction history for account-22."}},
"wrong_tool_candidate": {"name": "get_transactions", "arguments": {"account": "account-22"}},
"required_steps": ["get_transactions", "audit_report"], "completed_steps": ["get_transactions"], "pending_steps": ["audit_report"], "terminal_tools": ["audit_report"],
},
# ---- List vs. delete semantic confusion ----
{
"user_request": "List all pending tickets for project FORGE-7.",
"tools": [
{"name": "list_tickets", "description": "List tickets for a project filtered by status.", "parameters": {"type": "object", "properties": {"project": {"type": "string"}, "status": {"type": "string"}}, "required": ["project"]}},
{"name": "delete_ticket", "description": "Delete a ticket by ID.", "parameters": {"type": "object", "properties": {"ticket_id": {"type": "string"}}, "required": ["ticket_id"]}},
{"name": "respond", "description": "Send final response.", "parameters": {"type": "object", "properties": {"message": {"type": "string"}}, "required": ["message"]}},
],
"valid_candidate": {"name": "list_tickets", "arguments": {"project": "FORGE-7", "status": "pending"}},
"wrong_tool_candidate": {"name": "delete_ticket", "arguments": {"ticket_id": "FORGE-7"}},
"required_steps": ["list_tickets"], "completed_steps": [], "pending_steps": ["list_tickets"], "terminal_tools": ["respond"],
},
# ---- Create vs. fetch ----
{
"user_request": "Get the profile for user ID 5514.",
"tools": [
{"name": "get_profile", "description": "Retrieve an existing user profile.", "parameters": {"type": "object", "properties": {"user_id": {"type": "integer"}}, "required": ["user_id"]}},
{"name": "create_profile", "description": "Create a new user profile.", "parameters": {"type": "object", "properties": {"user_id": {"type": "integer"}, "name": {"type": "string"}}, "required": ["user_id", "name"]}},
{"name": "respond", "description": "Send final response.", "parameters": {"type": "object", "properties": {"message": {"type": "string"}}, "required": ["message"]}},
],
"valid_candidate": {"name": "get_profile", "arguments": {"user_id": 5514}},
"wrong_tool_candidate": {"name": "create_profile", "arguments": {"user_id": 5514, "name": "Unknown"}},
"required_steps": ["get_profile"], "completed_steps": [], "pending_steps": ["get_profile"], "terminal_tools": ["respond"],
},
# ---- Approve vs. reject ----
{
"user_request": "Approve the purchase request PR-2041.",
"tools": [
{"name": "approve_request", "description": "Approve a purchase request.", "parameters": {"type": "object", "properties": {"request_id": {"type": "string"}}, "required": ["request_id"]}},
{"name": "reject_request", "description": "Reject a purchase request.", "parameters": {"type": "object", "properties": {"request_id": {"type": "string"}, "reason": {"type": "string"}}, "required": ["request_id", "reason"]}},
{"name": "respond", "description": "Send final response.", "parameters": {"type": "object", "properties": {"message": {"type": "string"}}, "required": ["message"]}},
],
"valid_candidate": {"name": "approve_request", "arguments": {"request_id": "PR-2041"}},
"wrong_tool_candidate": {"name": "reject_request", "arguments": {"request_id": "PR-2041", "reason": "Budget exceeded"}},
"required_steps": ["approve_request"], "completed_steps": [], "pending_steps": ["approve_request"], "terminal_tools": ["respond"],
},
# ---- Read vs. write account ----
{
"user_request": "Check the balance on account ACC-1144.",
"tools": [
{"name": "get_balance", "description": "Get the current balance of an account.", "parameters": {"type": "object", "properties": {"account_id": {"type": "string"}}, "required": ["account_id"]}},
{"name": "transfer_funds", "description": "Transfer funds between accounts.", "parameters": {"type": "object", "properties": {"from_account": {"type": "string"}, "to_account": {"type": "string"}, "amount": {"type": "number"}}, "required": ["from_account", "to_account", "amount"]}},
{"name": "respond", "description": "Send final response.", "parameters": {"type": "object", "properties": {"message": {"type": "string"}}, "required": ["message"]}},
],
"valid_candidate": {"name": "get_balance", "arguments": {"account_id": "ACC-1144"}},
"wrong_tool_candidate": {"name": "transfer_funds", "arguments": {"from_account": "ACC-1144", "to_account": "ACC-0000", "amount": 0}},
"required_steps": ["get_balance"], "completed_steps": [], "pending_steps": ["get_balance"], "terminal_tools": ["respond"],
},
# ---- Compute vs. fetch computed result ----
{
"user_request": "Calculate the tax owed for invoice INV-9901.",
"tools": [
{"name": "compute_tax", "description": "Calculate tax for a given invoice.", "parameters": {"type": "object", "properties": {"invoice_id": {"type": "string"}}, "required": ["invoice_id"]}},
{"name": "get_invoice", "description": "Fetch the raw invoice data.", "parameters": {"type": "object", "properties": {"invoice_id": {"type": "string"}}, "required": ["invoice_id"]}},
{"name": "respond", "description": "Send final response.", "parameters": {"type": "object", "properties": {"message": {"type": "string"}}, "required": ["message"]}},
],
"valid_candidate": {"name": "compute_tax", "arguments": {"invoice_id": "INV-9901"}},
"wrong_tool_candidate": {"name": "get_invoice", "arguments": {"invoice_id": "INV-9901"}},
"required_steps": ["compute_tax"], "completed_steps": [], "pending_steps": ["compute_tax"], "terminal_tools": ["respond"],
},
# ---- Send vs. schedule ----
{
"user_request": "Send the welcome email to user usr-42 now.",
"tools": [
{"name": "send_email", "description": "Send an email immediately.", "parameters": {"type": "object", "properties": {"user_id": {"type": "string"}, "template": {"type": "string"}}, "required": ["user_id", "template"]}},
{"name": "schedule_email", "description": "Schedule an email for future delivery.", "parameters": {"type": "object", "properties": {"user_id": {"type": "string"}, "template": {"type": "string"}, "send_at": {"type": "string"}}, "required": ["user_id", "template", "send_at"]}},
{"name": "respond", "description": "Send final response.", "parameters": {"type": "object", "properties": {"message": {"type": "string"}}, "required": ["message"]}},
],
"valid_candidate": {"name": "send_email", "arguments": {"user_id": "usr-42", "template": "welcome"}},
"wrong_tool_candidate": {"name": "schedule_email", "arguments": {"user_id": "usr-42", "template": "welcome", "send_at": "2025-01-01T09:00:00Z"}},
"required_steps": ["send_email"], "completed_steps": [], "pending_steps": ["send_email"], "terminal_tools": ["respond"],
},
# ---- Lock vs. unlock ----
{
"user_request": "Lock the account for user usr-99 due to suspicious activity.",
"tools": [
{"name": "lock_account", "description": "Lock a user account to prevent access.", "parameters": {"type": "object", "properties": {"user_id": {"type": "string"}, "reason": {"type": "string"}}, "required": ["user_id", "reason"]}},
{"name": "unlock_account", "description": "Restore access to a locked user account.", "parameters": {"type": "object", "properties": {"user_id": {"type": "string"}}, "required": ["user_id"]}},
{"name": "respond", "description": "Send final response.", "parameters": {"type": "object", "properties": {"message": {"type": "string"}}, "required": ["message"]}},
],
"valid_candidate": {"name": "lock_account", "arguments": {"user_id": "usr-99", "reason": "suspicious activity"}},
"wrong_tool_candidate": {"name": "unlock_account", "arguments": {"user_id": "usr-99"}},
"required_steps": ["lock_account"], "completed_steps": [], "pending_steps": ["lock_account"], "terminal_tools": ["respond"],
},
# ---- Publish vs. archive ----
{
"user_request": "Publish the draft report RPT-555 to the portal.",
"tools": [
{"name": "publish_report", "description": "Publish a report draft to the live portal.", "parameters": {"type": "object", "properties": {"report_id": {"type": "string"}}, "required": ["report_id"]}},
{"name": "archive_report", "description": "Move a report to archive storage.", "parameters": {"type": "object", "properties": {"report_id": {"type": "string"}}, "required": ["report_id"]}},
{"name": "respond", "description": "Send final response.", "parameters": {"type": "object", "properties": {"message": {"type": "string"}}, "required": ["message"]}},
],
"valid_candidate": {"name": "publish_report", "arguments": {"report_id": "RPT-555"}},
"wrong_tool_candidate": {"name": "archive_report", "arguments": {"report_id": "RPT-555"}},
"required_steps": ["publish_report"], "completed_steps": [], "pending_steps": ["publish_report"], "terminal_tools": ["respond"],
},
# ---- Escalate vs. resolve ----
{
"user_request": "Escalate ticket TKT-1022 to tier-2 support.",
"tools": [
{"name": "escalate_ticket", "description": "Escalate a ticket to a higher support tier.", "parameters": {"type": "object", "properties": {"ticket_id": {"type": "string"}, "tier": {"type": "integer"}}, "required": ["ticket_id", "tier"]}},
{"name": "resolve_ticket", "description": "Mark a ticket as resolved.", "parameters": {"type": "object", "properties": {"ticket_id": {"type": "string"}, "resolution": {"type": "string"}}, "required": ["ticket_id", "resolution"]}},
{"name": "respond", "description": "Send final response.", "parameters": {"type": "object", "properties": {"message": {"type": "string"}}, "required": ["message"]}},
],
"valid_candidate": {"name": "escalate_ticket", "arguments": {"ticket_id": "TKT-1022", "tier": 2}},
"wrong_tool_candidate": {"name": "resolve_ticket", "arguments": {"ticket_id": "TKT-1022", "resolution": "Escalated automatically"}},
"required_steps": ["escalate_ticket"], "completed_steps": [], "pending_steps": ["escalate_ticket"], "terminal_tools": ["respond"],
},
# ---- Enable vs. disable feature ----
{
"user_request": "Enable the two-factor authentication feature for tenant ACME.",
"tools": [
{"name": "enable_feature", "description": "Enable a feature flag for a tenant.", "parameters": {"type": "object", "properties": {"tenant": {"type": "string"}, "feature": {"type": "string"}}, "required": ["tenant", "feature"]}},
{"name": "disable_feature", "description": "Disable a feature flag for a tenant.", "parameters": {"type": "object", "properties": {"tenant": {"type": "string"}, "feature": {"type": "string"}}, "required": ["tenant", "feature"]}},
{"name": "respond", "description": "Send final response.", "parameters": {"type": "object", "properties": {"message": {"type": "string"}}, "required": ["message"]}},
],
"valid_candidate": {"name": "enable_feature", "arguments": {"tenant": "ACME", "feature": "two_factor_auth"}},
"wrong_tool_candidate": {"name": "disable_feature", "arguments": {"tenant": "ACME", "feature": "two_factor_auth"}},
"required_steps": ["enable_feature"], "completed_steps": [], "pending_steps": ["enable_feature"], "terminal_tools": ["respond"],
},
# ---- Subscribe vs. unsubscribe ----
{
"user_request": "Subscribe user usr-301 to the weekly digest newsletter.",
"tools": [
{"name": "subscribe_newsletter", "description": "Subscribe a user to a newsletter.", "parameters": {"type": "object", "properties": {"user_id": {"type": "string"}, "newsletter": {"type": "string"}}, "required": ["user_id", "newsletter"]}},
{"name": "unsubscribe_newsletter", "description": "Unsubscribe a user from a newsletter.", "parameters": {"type": "object", "properties": {"user_id": {"type": "string"}, "newsletter": {"type": "string"}}, "required": ["user_id", "newsletter"]}},
{"name": "respond", "description": "Send final response.", "parameters": {"type": "object", "properties": {"message": {"type": "string"}}, "required": ["message"]}},
],
"valid_candidate": {"name": "subscribe_newsletter", "arguments": {"user_id": "usr-301", "newsletter": "weekly_digest"}},
"wrong_tool_candidate": {"name": "unsubscribe_newsletter", "arguments": {"user_id": "usr-301", "newsletter": "weekly_digest"}},
"required_steps": ["subscribe_newsletter"], "completed_steps": [], "pending_steps": ["subscribe_newsletter"], "terminal_tools": ["respond"],
},
]
def build_contrastive_wrong_tool_rows() -> List[VerifierRow]:
rows: List[VerifierRow] = []
for i, pair in enumerate(FORGE_CONTRASTIVE_WRONG_TOOL_PAIRS):
user_request = pair["user_request"]
tools = pair["tools"]
required = pair.get("required_steps", [])
completed = pair.get("completed_steps", [])
pending = pair.get("pending_steps", required)
terminal = pair.get("terminal_tools", [])
group_id = stable_id("forge_contrastive_wts", i, user_request)
group_id = f"contrastive_wts_{i:03d}_{group_id[:8]}"
base_meta = {
"generator": "contrastive_pair",
"scenario_family": "wrong_tool_contrastive",
"source_kind": "forge_contrastive",
"contrastive_pair_index": i,
}
sm = infer_scoring_metadata("wrong_tool_contrastive")
rows.append(make_row(
"forge_contrastive_wts", "valid",
user_request, tools, pair["valid_candidate"], 1.0,
dict(base_meta, negative_type="contrastive_valid"),
required_steps=required, completed_steps=completed,
pending_steps=pending, terminal_tools=terminal,
group_id=group_id, scoring_metadata=sm,
))
rows.append(make_row(
"forge_contrastive_wts", "wrong_tool_semantic",
user_request, tools, pair["wrong_tool_candidate"], 0.05,
dict(base_meta, negative_type="wrong_tool_contrastive"),
required_steps=required, completed_steps=completed,
pending_steps=pending, terminal_tools=terminal,
group_id=group_id, scoring_metadata=sm,
))
return rows
'''
CONTRASTIVE_CALL_SITE_OLD = "forge_rows = (\n build_forge_synthetic_rows()\n + build_argument_semantic_rows()\n + build_error_recovery_numeric_semantic_rows()\n)"
CONTRASTIVE_CALL_SITE_NEW = "forge_rows = (\n build_forge_synthetic_rows()\n + build_argument_semantic_rows()\n + build_error_recovery_numeric_semantic_rows()\n + build_contrastive_wrong_tool_rows()\n)"
PROTECTED_VALID_BLOCK = '''
# ---------------------------------------------------------------------------
# Protected valid slice expansion — fixed-width numeric strings
# ---------------------------------------------------------------------------
FORGE_FIXED_WIDTH_VALID_CASES = [
# (user_request, tool_name, tool_description, param_name, param_description, valid_value, wrong_int_value)
("Fetch 42 records.", "fetch", "Fetch records by count. Count must be zero-padded 4-digit string.", "count", "4-digit zero-padded count.", "0042", 42),
("Fetch 7 records.", "fetch", "Fetch records by count. Count must be zero-padded 4-digit string.", "count", "4-digit zero-padded count.", "0007", 7),
("Fetch 100 records.", "fetch", "Fetch records by count. Count must be zero-padded 4-digit string.", "count", "4-digit zero-padded count.", "0100", 100),
("Fetch 999 records.", "fetch", "Fetch records by count. Count must be zero-padded 4-digit string.", "count", "4-digit zero-padded count.", "0999", 999),
("Fetch 1 record.", "fetch", "Fetch records by count. Count must be zero-padded 4-digit string.", "count", "4-digit zero-padded count.", "0001", 1),
("Retrieve order 00045.", "get_order", "Fetch order by ID. ID must be zero-padded 5-digit string.", "order_id", "5-digit zero-padded order ID.", "00045", 45),
("Retrieve order 00001.", "get_order", "Fetch order by ID. ID must be zero-padded 5-digit string.", "order_id", "5-digit zero-padded order ID.", "00001", 1),
("Retrieve order 09999.", "get_order", "Fetch order by ID. ID must be zero-padded 5-digit string.", "order_id", "5-digit zero-padded order ID.", "09999", 9999),
("Look up zip code 07030.", "lookup_zip", "Look up a ZIP code. Must be 5-digit string.", "zip_code", "5-digit ZIP code string.", "07030", 7030),
("Look up zip code 00501.", "lookup_zip", "Look up a ZIP code. Must be 5-digit string.", "zip_code", "5-digit ZIP code string.", "00501", 501),
("Look up zip code 90210.", "lookup_zip", "Look up a ZIP code. Must be 5-digit string.", "zip_code", "5-digit ZIP code string.", "90210", 90210),
("Send SMS to +14155551234.", "send_sms", "Send SMS to a phone number in E.164 format.", "phone", "E.164 phone number string.", "+14155551234", 14155551234),
("Send SMS to +441234567890.", "send_sms", "Send SMS to a phone number in E.164 format.", "phone", "E.164 phone number string.", "+441234567890", 441234567890),
("Look up IBAN DE89370400440532013000.", "verify_iban", "Verify an IBAN. Must be string.", "iban", "IBAN string.", "DE89370400440532013000", 89370400440532013000),
("Route payment via routing number 021000021.", "route_payment", "Route payment. Routing number must be string.", "routing_number", "9-digit routing number string.", "021000021", 21000021),
("Route payment via routing number 111000025.", "route_payment", "Route payment. Routing number must be string.", "routing_number", "9-digit routing number string.", "111000025", 111000025),
("Lookup account 000123.", "lookup_account", "Lookup an account. ID must be zero-padded 6-digit string.", "account_id", "6-digit zero-padded account ID.", "000123", 123),
("Lookup account 001000.", "lookup_account", "Lookup an account. ID must be zero-padded 6-digit string.", "account_id", "6-digit zero-padded account ID.", "001000", 1000),
("Fetch batch 0050.", "fetch_batch", "Fetch a batch by ID. ID must be zero-padded 4-digit string.", "batch_id", "4-digit zero-padded batch ID.", "0050", 50),
("Fetch batch 0001.", "fetch_batch", "Fetch a batch by ID. ID must be zero-padded 4-digit string.", "batch_id", "4-digit zero-padded batch ID.", "0001", 1),
]
def fixed_width_tools(tool_name: str, tool_description: str, param_name: str, param_description: str) -> List[Dict[str, Any]]:
return [
{
"name": tool_name,
"description": tool_description,
"parameters": {
"type": "object",
"properties": {param_name: {"type": "string", "description": param_description}},
"required": [param_name],
},
},
{"name": "respond", "description": "Send final answer.", "parameters": {"type": "object", "properties": {"message": {"type": "string"}}, "required": ["message"]}},
]
def build_fixed_width_numeric_rows() -> List[VerifierRow]:
rows: List[VerifierRow] = []
family_meta = infer_scoring_metadata("fixed_width_numeric")
for (user_request, tool_name, tool_desc, param_name, param_desc, valid_value, wrong_int) in FORGE_FIXED_WIDTH_VALID_CASES:
tools = fixed_width_tools(tool_name, tool_desc, param_name, param_desc)
required = [tool_name]
terminal = ["respond"]
group_id = stable_id("forge_fixed_width", user_request, tool_name, param_name, valid_value)
valid_meta = {
"generator": "forge_fixed_width_numeric",
"scenario_family": "fixed_width_numeric",
"source_kind": "synthetic_numeric_string",
"corrected_positive": True,
"valid_protection_fixed_width_numeric_string": True,
}
rows.append(make_row(
"forge_fixed_width_numeric", "valid",
user_request, tools, {"name": tool_name, "arguments": {param_name: valid_value}}, 1.0,
valid_meta, required_steps=required, completed_steps=[], pending_steps=required, terminal_tools=terminal,
group_id=group_id, scoring_metadata=family_meta,
))
for neg_type, neg_val in [("integer_instead_of_string", wrong_int), ("unpadded_string", str(wrong_int))]:
neg_key = json.dumps(neg_val, default=str, sort_keys=True)
valid_key = json.dumps(valid_value, sort_keys=True)
if neg_key == valid_key:
continue
rows.append(make_row(
"forge_fixed_width_numeric", "wrong_arguments_semantic",
user_request, tools, {"name": tool_name, "arguments": {param_name: neg_val}}, 0.05,
{
"generator": "forge_fixed_width_numeric",
"scenario_family": "fixed_width_numeric",
"source_kind": "synthetic_numeric_string",
"negative_type": f"fixed_width_{neg_type}",
"valid_counterpart": valid_value,
},
required_steps=required, completed_steps=[], pending_steps=required, terminal_tools=terminal,
group_id=group_id, scoring_metadata=family_meta,
))
return rows
# ---------------------------------------------------------------------------
# Protected valid slice expansion — corrected error recovery
# ---------------------------------------------------------------------------
FORGE_ERROR_RECOVERY_SCENARIOS = [
{
"user_request": "Fetch records for account ACC-88. A previous fetch attempt failed with an invalid account ID format.",
"tool_name": "fetch_records", "tool_desc": "Fetch records for an account. Account ID must be string in 'ACC-NNN' format.",
"param_name": "account_id", "param_desc": "Account ID string in ACC-NNN format.",
"valid_value": "ACC-88", "wrong_args": [("integer_id", 88), ("wrong_prefix", "account-88"), ("bare_number_str", "88")],
"recent_errors": ["Error: Invalid account_id format. Got: 88. Expected string like ACC-NNN."],
"wrong_tool": "lookup_account",
"wrong_tool_desc": "Look up account metadata by numeric ID.",
"wrong_tool_params": {"type": "object", "properties": {"id": {"type": "integer"}}, "required": ["id"]},
"wrong_tool_call": {"name": "lookup_account", "arguments": {"id": 88}},
},
{
"user_request": "Retrieve order ORD-0042. Prior attempt got 'order not found' because the ID was numeric instead of string.",
"tool_name": "get_order", "tool_desc": "Fetch an order. Order ID must be string in 'ORD-NNNN' format.",
"param_name": "order_id", "param_desc": "Order ID string like ORD-0042.",
"valid_value": "ORD-0042", "wrong_args": [("integer_id", 42), ("bare_string", "0042"), ("wrong_prefix", "order-0042")],
"recent_errors": ["Error: order_id must be string like ORD-NNNN, got integer 42."],
"wrong_tool": "search_orders",
"wrong_tool_desc": "Search orders by keyword query.",
"wrong_tool_params": {"type": "object", "properties": {"query": {"type": "string"}}, "required": ["query"]},
"wrong_tool_call": {"name": "search_orders", "arguments": {"query": "ORD-0042"}},
},
{
"user_request": "Fetch the first 5 records. Last call used integer 5 but the API needs zero-padded string.",
"tool_name": "fetch", "tool_desc": "Fetch records by count. Count must be zero-padded 4-digit string.",
"param_name": "count", "param_desc": "4-digit zero-padded count string.",
"valid_value": "0005", "wrong_args": [("integer", 5), ("unpadded_string", "5"), ("over_padded", "000005")],
"recent_errors": ["Error: count must be a 4-digit zero-padded string like '0005', got 5."],
"wrong_tool": "summarize",
"wrong_tool_desc": "Summarize previously fetched records.",
"wrong_tool_params": {"type": "object", "properties": {"content": {"type": "string"}}, "required": ["content"]},
"wrong_tool_call": {"name": "summarize", "arguments": {"content": "5 records"}},
},
{
"user_request": "Look up ZIP code 07040. Previous lookup failed because the ZIP was passed as integer.",
"tool_name": "lookup_zip", "tool_desc": "Look up a ZIP code. Must be 5-digit string.",
"param_name": "zip_code", "param_desc": "5-digit ZIP code string.",
"valid_value": "07040", "wrong_args": [("integer", 7040), ("unpadded_string", "7040"), ("over_padded", "007040")],
"recent_errors": ["Error: zip_code must be string '07040', got integer 7040."],
"wrong_tool": "get_city",
"wrong_tool_desc": "Get city name by city ID.",
"wrong_tool_params": {"type": "object", "properties": {"city_id": {"type": "integer"}}, "required": ["city_id"]},
"wrong_tool_call": {"name": "get_city", "arguments": {"city_id": 7040}},
},
{
"user_request": "Submit audit report for transaction batch 0033. Previous attempt failed: batch_id was integer.",
"tool_name": "submit_audit", "tool_desc": "Submit an audit report. batch_id must be zero-padded 4-digit string.",
"param_name": "batch_id", "param_desc": "4-digit zero-padded batch ID string.",
"valid_value": "0033", "wrong_args": [("integer", 33), ("unpadded", "33"), ("over_padded", "000033")],
"recent_errors": ["Error: batch_id must be string '0033', got integer 33."],
"wrong_tool": "list_batches",
"wrong_tool_desc": "List all available batches.",
"wrong_tool_params": {"type": "object", "properties": {}, "required": []},
"wrong_tool_call": {"name": "list_batches", "arguments": {}},
},
]
def error_recovery_scenario_tools(scen: Dict[str, Any]) -> List[Dict[str, Any]]:
return [
{
"name": scen["tool_name"],
"description": scen["tool_desc"],
"parameters": {
"type": "object",
"properties": {scen["param_name"]: {"type": "string", "description": scen["param_desc"]}},
"required": [scen["param_name"]],
},
},
{
"name": scen["wrong_tool"],
"description": scen["wrong_tool_desc"],
"parameters": scen["wrong_tool_params"],
},
{"name": "respond", "description": "Send final answer.", "parameters": {"type": "object", "properties": {"message": {"type": "string"}}, "required": ["message"]}},
]
def build_error_recovery_protected_rows() -> List[VerifierRow]:
rows: List[VerifierRow] = []
family_meta = infer_scoring_metadata("error_recovery")
for scen in FORGE_ERROR_RECOVERY_SCENARIOS:
tools = error_recovery_scenario_tools(scen)
required = [scen["tool_name"]]
terminal = ["respond"]
recent_errors = scen.get("recent_errors", [])
group_id = stable_id("forge_error_recovery_protected", scen["user_request"], scen["tool_name"])
valid_meta = {
"generator": "forge_error_recovery_protected",
"scenario_family": "error_recovery",
"source_kind": "synthetic_error_recovery",
"corrected_positive": True,
"valid_protection_corrected_error_recovery": True,
}
rows.append(make_row(
"forge_error_recovery_protected", "valid",
scen["user_request"], tools, {"name": scen["tool_name"], "arguments": {scen["param_name"]: scen["valid_value"]}}, 1.0,
valid_meta, required_steps=required, completed_steps=[], pending_steps=required, terminal_tools=terminal,
recent_errors=recent_errors, group_id=group_id, scoring_metadata=family_meta,
))
# Paired wrong_arguments_semantic hard negatives
seen_keys = {json.dumps(scen["valid_value"], sort_keys=True)}
for neg_type, neg_val in scen.get("wrong_args", []):
neg_key = json.dumps(neg_val, default=str, sort_keys=True)
if neg_key in seen_keys:
continue
seen_keys.add(neg_key)
rows.append(make_row(
"forge_error_recovery_protected", "wrong_arguments_semantic",
scen["user_request"], tools, {"name": scen["tool_name"], "arguments": {scen["param_name"]: neg_val}}, 0.05,
{
"generator": "forge_error_recovery_protected",
"scenario_family": "error_recovery",
"source_kind": "synthetic_error_recovery",
"negative_type": f"error_recovery_{neg_type}",
"valid_counterpart": scen["valid_value"],
},
required_steps=required, completed_steps=[], pending_steps=required, terminal_tools=terminal,
recent_errors=recent_errors, group_id=group_id, scoring_metadata=family_meta,
))
# Paired wrong_tool_semantic hard negative
rows.append(make_row(
"forge_error_recovery_protected", "wrong_tool_semantic",
scen["user_request"], tools, scen["wrong_tool_call"], 0.05,
{
"generator": "forge_error_recovery_protected",
"scenario_family": "error_recovery",
"source_kind": "synthetic_error_recovery",
"negative_type": "error_recovery_wrong_tool",
},
required_steps=required, completed_steps=[], pending_steps=required, terminal_tools=terminal,
recent_errors=recent_errors, group_id=group_id, scoring_metadata=family_meta,
))
return rows
'''
CONTRASTIVE_CALL_SITE_PROTECTED_NEW = """forge_rows = (
build_forge_synthetic_rows()
+ build_argument_semantic_rows()
+ build_error_recovery_numeric_semantic_rows()
+ build_contrastive_wrong_tool_rows()
+ build_fixed_width_numeric_rows()
+ build_error_recovery_protected_rows()
)"""
GATE_DEFICIT_OLD = """ # Select checkpoints by distance from promotion gates. A low-confidence model with poor recall should not win
# merely because it makes few high-confidence objections.
valid_recall_deficit = max(0.0, CHECKPOINT_VALID_RECALL_GATE - valid_recall) / CHECKPOINT_VALID_RECALL_GATE
wrong_tool_precision_deficit = max(0.0, CHECKPOINT_WRONG_TOOL_PRECISION_GATE - wrong_tool_precision) / CHECKPOINT_WRONG_TOOL_PRECISION_GATE
false_objection_excess = max(0.0, valid_false_objection_90 - CHECKPOINT_FALSE_OBJECTION_90_GATE) / CHECKPOINT_FALSE_OBJECTION_90_GATE
gate_deficit = float(
valid_recall_deficit
+ wrong_tool_precision_deficit
+ 5.0 * false_objection_excess
+ 0.5 * valid_to_wrong_args_rate
)
gate_deficit_score = -gate_deficit"""
GATE_DEFICIT_NEW = """ # Constrained lexicographic checkpoint selection.
# 1. Discard checkpoints with false_objection > 2.5x gate ceiling (non-promotable).
# 2. Among remaining: maximize valid_recall, then wrong_tool_precision, then wrong_tool_recall, then macro_f1.
# This prevents the blended gate_deficit from selecting an epoch that collapses valid_recall
# while scoring well only because it rarely makes high-confidence objections.
CHECKPOINT_FALSE_OBJECTION_DISCARD_CEILING = 2.5 * CHECKPOINT_FALSE_OBJECTION_90_GATE
valid_recall_deficit = max(0.0, CHECKPOINT_VALID_RECALL_GATE - valid_recall) / CHECKPOINT_VALID_RECALL_GATE
wrong_tool_precision_deficit = max(0.0, CHECKPOINT_WRONG_TOOL_PRECISION_GATE - wrong_tool_precision) / CHECKPOINT_WRONG_TOOL_PRECISION_GATE
false_objection_excess = max(0.0, valid_false_objection_90 - CHECKPOINT_FALSE_OBJECTION_90_GATE) / CHECKPOINT_FALSE_OBJECTION_90_GATE
# Keep legacy gate_deficit for telemetry backward-compat.
gate_deficit = float(
valid_recall_deficit
+ wrong_tool_precision_deficit
+ 5.0 * false_objection_excess
+ 0.5 * valid_to_wrong_args_rate
)
constrained_promotable = bool(valid_false_objection_90 <= CHECKPOINT_FALSE_OBJECTION_DISCARD_CEILING)
if not constrained_promotable:
gate_deficit_score = float("-inf")
else:
# Lexicographic: valid_recall is primary, wrong_tool_precision secondary, etc.
gate_deficit_score = (
valid_recall # primary: maximize valid recall [0, 1]
+ 0.1 * wrong_tool_precision # secondary: wrong_tool precision [0, 0.1]
+ 0.01 * wrong_tool_recall # tertiary: wrong_tool recall [0, 0.01]
+ 0.001 * present_f1 # quaternary: macro F1 [0, 0.001]
- 10.0 * max(0.0, valid_false_objection_90 - CHECKPOINT_FALSE_OBJECTION_90_GATE)
)"""
GATE_DEFICIT_RETURN_OLD = """ \"gate_deficit\": gate_deficit,
\"gate_deficit_score\": gate_deficit_score,
}"""
GATE_DEFICIT_RETURN_NEW = """ "gate_deficit": gate_deficit,
"gate_deficit_score": gate_deficit_score,
"constrained_promotable": constrained_promotable,
}"""
EXCL_DI_FLAG_OLD = """# Keep class weights off by default. Previous weighted + synthetic rare-class runs regressed badly.
USE_CLASS_WEIGHTS = False #@param {type:"boolean"}"""
EXCL_DI_FLAG_NEW = """# Keep class weights off by default. Previous weighted + synthetic rare-class runs regressed badly.
USE_CLASS_WEIGHTS = False #@param {type:"boolean"}
# When True, deterministic_invalid rows are removed from train/val/test.
# Rust deterministic rules are authoritative for this class; ML label competition adds noise.
# The six-label artifact format is preserved — the model simply learns near-zero logits for this class.
EXCLUDE_DETERMINISTIC_INVALID_FROM_ML_TRAINING = True #@param {type:"boolean"}"""
EXCL_DI_IN_DATAFRAME_OLD = """print_source_label_summary(frame: pd.DataFrame, title: str) -> None:
print(f\"\\n{title} by source and label:\")
print(pd.crosstab(frame[\"source\"], frame[\"label\"]).to_string())"""
EXCL_DI_FILTER_INSERTION_MARKER = "FAIL_ON_SUSPICIOUS_VALID_HARD_NEGATIVES = True"
EXCL_DI_FILTER_SNIPPET = """FAIL_ON_SUSPICIOUS_VALID_HARD_NEGATIVES = True #@param {type:"boolean"}
MAX_SUSPICIOUS_VALID_SAMPLE_ROWS = 50 #@param {type:"integer"}
# Respect EXCLUDE_DETERMINISTIC_INVALID_FROM_ML_TRAINING flag set in the config cell.
_EXCLUDE_DI = bool(globals().get("EXCLUDE_DETERMINISTIC_INVALID_FROM_ML_TRAINING", True))"""
SOURCE_BALANCED_EVAL_SNIPPET = """
def source_balanced_eval_summary(scored: pd.DataFrame, split_name: str) -> Dict[str, Any]:
\"\"\"Per-source breakdown of key promotion metrics; forge-weighted aggregate score.\"\"\"
if scored.empty:
return {}
forge_sources = {"forge_eval", "forge_synthetic", "forge_contrastive_wts",
"forge_argument_semantic", "forge_error_recovery_numeric",
"forge_fixed_width_numeric", "forge_error_recovery_protected",
"forge_augmented"}
sources = scored["source"].unique().tolist()
rows_by_source: Dict[str, Any] = {}
forge_correct, forge_total, public_correct, public_total = 0, 0, 0, 0
for src in sources:
src_df = scored[scored["source"] == src]
if src_df.empty:
continue
valid_mask = src_df["true_label"] == "valid"
valid_count = int(valid_mask.sum())
valid_recall_src = float((src_df.loc[valid_mask, "pred_label"] == "valid").mean()) if valid_count else float("nan")
fo_mask = valid_mask & (src_df["pred_label"] != "valid") & (src_df["confidence"] >= 0.90)
false_obj_src = float(fo_mask.sum() / valid_count) if valid_count else float("nan")
wts_pred = src_df["pred_label"] == "wrong_tool_semantic"
wts_true = src_df["true_label"] == "wrong_tool_semantic"
wts_prec = float((src_df.loc[wts_pred, "true_label"] == "wrong_tool_semantic").mean()) if int(wts_pred.sum()) else float("nan")
correct = int((src_df["pred_label"] == src_df["true_label"]).sum())
total = len(src_df)
rows_by_source[src] = {
"total": total, "correct": correct, "accuracy": round(correct / total, 4) if total else float("nan"),
"valid_count": valid_count, "valid_recall": round(valid_recall_src, 4),
"false_objection_90": round(false_obj_src, 4),
"wrong_tool_semantic_precision": round(wts_prec, 4),
}
if src in forge_sources:
forge_correct += correct; forge_total += total
else:
public_correct += correct; public_total += total
forge_acc = forge_correct / forge_total if forge_total else float("nan")
public_acc = public_correct / public_total if public_total else float("nan")
forge_weight = 0.70
weighted_score = forge_weight * forge_acc + (1 - forge_weight) * public_acc if (forge_total and public_total) else (forge_acc if forge_total else public_acc)
summary = {
"split": split_name,
"per_source": rows_by_source,
"forge_accuracy": round(forge_acc, 4),
"public_accuracy": round(public_acc, 4),
"forge_weighted_score": round(weighted_score, 4),
}
print(f"\\n=== Source-balanced eval summary ({split_name}) ===")
for src, info in sorted(rows_by_source.items(), key=lambda kv: -kv[1]["total"]):
flag = " [FORGE]" if src in forge_sources else ""
print(f" {src}{flag}: n={info['total']}, acc={info['accuracy']:.3f}, "
f"valid_recall={info['valid_recall']:.3f}, "
f"false_obj_90={info['false_objection_90']:.3f}, "
f"wts_prec={info['wrong_tool_semantic_precision']:.3f}")
print(f" Forge-weighted aggregate score: {weighted_score:.4f} "
f"(forge={forge_acc:.3f} x0.70 + public={public_acc:.3f} x0.30)")
if not math.isnan(forge_acc) and forge_acc >= VALID_RECALL_GATE and math.isnan(public_acc):
print(" NOTE: Only forge sources present — score is forge-only.")
return summary
"""
NEEDS_CLARIFICATION_WARNING_OLD = "NEEDS_CLARIFICATION_MIN_SUPPORT = 50"
NEEDS_CLARIFICATION_WARNING_NEW = """NEEDS_CLARIFICATION_MIN_SUPPORT = 50
DETERMINISTIC_INVALID_EXCLUDED = bool(globals().get("EXCLUDE_DETERMINISTIC_INVALID_FROM_ML_TRAINING", True))"""
def patch_cell_2(cells):
idx = find_cell_by_marker(cells, "t4_proven")
assert idx is not None, "Could not find t4_proven cell"
src = cell_src(cells[idx])
assert T4_PROVEN_OLD in src, "t4_proven old profile not found in cell"
src = src.replace(T4_PROVEN_OLD, T4_PROVEN_NEW)
assert SERIALIZER_VERSION_OLD in src, "SERIALIZER_VERSION_OLD not found"
src = src.replace(SERIALIZER_VERSION_OLD, SERIALIZER_VERSION_NEW)
assert EXCL_DI_FLAG_OLD in src, "EXCL_DI_FLAG_OLD not found"
src = src.replace(EXCL_DI_FLAG_OLD, EXCL_DI_FLAG_NEW)
set_cell_src(cells[idx], src)
print(f" [C2+C6] Patched cell {idx}: t4_proven max_length + SERIALIZER_VERSION + EXCLUDE_DI flag")
def patch_cell_9(cells):
idx = find_cell_by_marker(cells, "def serialize_state_v1")
assert idx is not None, "Could not find serialize_state_v1 cell"
src = cell_src(cells[idx])
assert "def serialize_state_v2" in src, "serialize_state_v2 not found"
assert SERIALIZE_FROM_OBJECT_OLD in src, "serialize_state_from_object old not found"
src = src.replace(SERIALIZE_FROM_OBJECT_OLD, SERIALIZE_V3_ADDITIONS + "\n" + SERIALIZE_FROM_OBJECT_NEW)
assert FIXTURE_V2_BLOCK_OLD in src, "fixture v2 block old not found"
src = src.replace(FIXTURE_V2_BLOCK_OLD, FIXTURE_V3_BLOCK_NEW)
set_cell_src(cells[idx], src)
print(f" [C1] Patched cell {idx}: serialize_state_v3 + v3 fixture")
def patch_cell_14(cells):
idx = find_cell_by_marker(cells, "forge_rows = (\n build_forge_synthetic_rows()")
assert idx is not None, "Could not find forge_rows assignment cell"
src = cell_src(cells[idx])
assert CONTRASTIVE_CALL_SITE_OLD in src, "CONTRASTIVE_CALL_SITE_OLD not found"
new_block = CONTRASTIVE_PAIRS_BLOCK + "\n" + PROTECTED_VALID_BLOCK + "\n" + CONTRASTIVE_CALL_SITE_PROTECTED_NEW
src = src.replace(CONTRASTIVE_CALL_SITE_OLD, new_block)
set_cell_src(cells[idx], src)
print(f" [C3+C4] Patched cell {idx}: contrastive pairs + protected valid row builders")
def patch_cell_20(cells):
idx = find_cell_by_marker(cells, "FAIL_ON_SUSPICIOUS_VALID_HARD_NEGATIVES")
assert idx is not None, "Could not find FAIL_ON_SUSPICIOUS_VALID_HARD_NEGATIVES cell"
src = cell_src(cells[idx])
assert "FAIL_ON_SUSPICIOUS_VALID_HARD_NEGATIVES = True" in src, "marker not found"
old_marker = "FAIL_ON_SUSPICIOUS_VALID_HARD_NEGATIVES = True #@param {type:\"boolean\"}\nMAX_SUSPICIOUS_VALID_SAMPLE_ROWS = 50 #@param {type:\"integer\"}"
new_marker = ("FAIL_ON_SUSPICIOUS_VALID_HARD_NEGATIVES = True #@param {type:\"boolean\"}\n"
"MAX_SUSPICIOUS_VALID_SAMPLE_ROWS = 50 #@param {type:\"integer\"}\n\n"
"# Filter deterministic_invalid rows from ML training when flag is set.\n"
"_EXCLUDE_DI = bool(globals().get(\"EXCLUDE_DETERMINISTIC_INVALID_FROM_ML_TRAINING\", True))")
if old_marker in src:
src = src.replace(old_marker, new_marker)
filter_injection = ("\n\nif _EXCLUDE_DI:\n"
" _before_di = len(all_rows)\n"
" all_rows = [r for r in all_rows if r.label != \"deterministic_invalid\"]\n"
" _removed_di = _before_di - len(all_rows)\n"
" if _removed_di:\n"
" print(f\"[EXCLUDE_DI] Removed {_removed_di} deterministic_invalid rows from ML training.\")\n")
target = "def metadata_value("
if target in src:
src = src.replace(target, filter_injection + "\n" + target, 1)
set_cell_src(cells[idx], src)
print(f" [C6] Patched cell {idx}: EXCLUDE_DI filter")
else:
print(f" [C6] WARNING: old_marker not found in cell {idx}, skipping DI filter injection")
def patch_cell_28(cells):
idx = find_cell_by_marker(cells, "gate_deficit_score")
assert idx is not None, "Could not find gate_deficit_score cell"
src = cell_src(cells[idx])
assert GATE_DEFICIT_OLD in src, "GATE_DEFICIT_OLD not found in cell"
src = src.replace(GATE_DEFICIT_OLD, GATE_DEFICIT_NEW)
assert GATE_DEFICIT_RETURN_OLD in src, "GATE_DEFICIT_RETURN_OLD not found"
src = src.replace(GATE_DEFICIT_RETURN_OLD, GATE_DEFICIT_RETURN_NEW)
set_cell_src(cells[idx], src)
print(f" [C5] Patched cell {idx}: constrained lexicographic checkpoint selection")
def patch_cell_29(cells):
idx = find_cell_by_marker(cells, "NEEDS_CLARIFICATION_MIN_SUPPORT")
assert idx is not None, "Could not find NEEDS_CLARIFICATION_MIN_SUPPORT cell"
src = cell_src(cells[idx])
assert NEEDS_CLARIFICATION_WARNING_OLD in src, "NC warning old not found"
src = src.replace(NEEDS_CLARIFICATION_WARNING_OLD, NEEDS_CLARIFICATION_WARNING_NEW)
if "import math" not in src:
src = "import math\n" + src
anchor = "def confusion_pair_diagnostics("
if anchor in src and "source_balanced_eval_summary" not in src:
src = src.replace(anchor, SOURCE_BALANCED_EVAL_SNIPPET + "\n" + anchor)
nc_warn_anchor = "NEEDS_CLARIFICATION_MIN_SUPPORT = 50"
nc_warn = ("NEEDS_CLARIFICATION_MIN_SUPPORT = 50\n"
"_nc_train_support = 0 # will be updated after split\n"
"if _nc_train_support < 200:\n"
" print(f\"NOTE: needs_clarification train support={_nc_train_support} < 200. \"\n"
" \"Label is telemetry-only (advisory_min_confidence=1.01). \"\n"
" \"Promotion gate will not apply until validation support >= NEEDS_CLARIFICATION_MIN_SUPPORT={NEEDS_CLARIFICATION_MIN_SUPPORT}.\")")
if "needs_clarification low-support" not in src:
needs_cl_comment = ("\n# needs_clarification is telemetry-only until test support "
">= NEEDS_CLARIFICATION_MIN_SUPPORT. Current artifact "
"thresholds enforce advisory_min_confidence=1.01 (no-op).\n")
src = src.replace("NEEDS_CLARIFICATION_MIN_SUPPORT = 50",
"NEEDS_CLARIFICATION_MIN_SUPPORT = 50" + needs_cl_comment)
set_cell_src(cells[idx], src)
print(f" [C7+C8] Patched cell {idx}: DETERMINISTIC_INVALID_EXCLUDED + source_balanced_eval_summary")
def main():
print(f"Loading notebook: {NB_PATH}")
nb = load_nb()
cells = nb["cells"]
print(f"Total cells: {len(cells)}")
print("\nApplying patches...")
patch_cell_2(cells)
patch_cell_9(cells)
patch_cell_14(cells)
patch_cell_20(cells)
patch_cell_28(cells)
patch_cell_29(cells)
save_nb(nb)
print(f"\nNotebook saved: {NB_PATH}")
print("\nVerifying patches...")
nb2 = load_nb()
cells2 = nb2["cells"]
checks = [
("serialize_state_v3", "serialize_state_v3 defined"),
("serialize_candidate_tool_schema", "candidate tool schema helper defined"),
("CANDIDATE_TOOL_SCHEMA:", "v3 layout marker in notebook"),
("USE_SERIALIZER_V3 = True", "USE_SERIALIZER_V3 flag"),
("max_length\": 1024", "t4_proven max_length 1024"),
("max_per_source\": 4_000", "t4_proven max_per_source 4000"),
("grad_accum\": 6", "t4_proven grad_accum 6"),
("FORGE_CONTRASTIVE_WRONG_TOOL_PAIRS", "contrastive pairs list"),
("build_contrastive_wrong_tool_rows", "contrastive rows builder"),
("FORGE_FIXED_WIDTH_VALID_CASES", "fixed-width valid cases"),
("build_fixed_width_numeric_rows", "fixed-width rows builder"),
("build_error_recovery_protected_rows", "error recovery protected rows builder"),
("FORGE_ERROR_RECOVERY_SCENARIOS", "error recovery scenarios list"),
("constrained_promotable", "constrained promotable metric"),
("CHECKPOINT_FALSE_OBJECTION_DISCARD_CEILING", "discard ceiling constant"),
("EXCLUDE_DETERMINISTIC_INVALID_FROM_ML_TRAINING", "DI exclusion flag"),
("DETERMINISTIC_INVALID_EXCLUDED", "DI excluded runtime var"),
("source_balanced_eval_summary", "source balanced eval function"),
("serializer_fixture_v3.json", "v3 fixture write"),
]
full_src = "\n".join(cell_src(c) for c in cells2)
all_ok = True
for marker, label in checks:
ok = marker in full_src
status = "✓" if ok else "✗ MISSING"
print(f" {status}: {label}")
if not ok:
all_ok = False
if all_ok:
print("\nAll patches verified successfully.")
else:
print("\nSome patches failed verification — check output above.")
sys.exit(1)
if __name__ == "__main__":
main()