// forge-guardrails 0.1.2
//
// Foundation types for an LLM-agent workflow framework
// Documentation
use std::fs::OpenOptions;
use std::io::Write;
use std::path::{Path, PathBuf};

use forge_guardrails::{Message, MessageType, ToolSpec};
use indexmap::IndexMap;
use serde_json::{json, Value};

use crate::scenarios::SmokeScenario;

/// Records a failed run as hard-negative training rows.
///
/// Does nothing (and returns `Ok`) when `output` is `None`, when the scenario
/// has no `corrected_positive`, or when `row["success"]` is `true`. Otherwise
/// one `tool_call` row and one `final_response` row are appended to two
/// sibling JSONL files whose names are derived from `output`.
///
/// # Errors
///
/// Returns the stringified serialization or I/O error from appending a row.
pub(crate) fn write_hard_negatives(
    output: Option<&str>,
    row: &Value,
    scenario: &SmokeScenario,
    messages: &[Message],
) -> Result<(), String> {
    // Without an output target or a reference answer there is nothing to record.
    let (Some(output), Some(corrected_positive)) =
        (output, scenario.corrected_positive.as_ref())
    else {
        return Ok(());
    };
    // A missing or non-boolean `success` field counts as a failure.
    if row.get("success").and_then(Value::as_bool) == Some(true) {
        return Ok(());
    }

    // Row lookups that fall back to `null` or to an empty array respectively.
    let field = |key: &str| row.get(key).cloned().unwrap_or(Value::Null);
    let score_list = |key: &str| row.get(key).cloned().unwrap_or_else(|| json!([]));

    let corrected_calls = corrected_candidate_calls(corrected_positive);
    let first_corrected_call = corrected_calls
        .as_array()
        .and_then(|calls| calls.first())
        .map_or(Value::Null, Value::clone);
    let corrected_final_response = corrected_positive
        .get("final_text")
        .map_or(Value::Null, Value::clone);

    let context = hard_negative_context(row, scenario, messages);

    let attempted_calls = candidate_calls_from_trace(row);
    // The last attempted call is reported as the single candidate.
    let last_attempted_call = attempted_calls
        .as_array()
        .and_then(|calls| calls.last())
        .map_or(Value::Null, Value::clone);

    let outcome = json!({
        "scenario": scenario.name,
        "scenario_family": scenario.name,
        "user_request": scenario.user_message,
        "run": field("run"),
        "failure_kind": field("failure_kind"),
        "accuracy": field("accuracy"),
        "corrected_positive": corrected_positive,
        "corrected_candidate_call": first_corrected_call,
        "corrected_candidate_calls": corrected_calls,
        "corrected_final_response": corrected_final_response,
    });

    let tool_call_path = hard_negative_path(output, "tool_call_hard_negatives");
    let tool_call_row = json!({
        "kind": "tool_call",
        "context": context,
        "candidate": {
            "tool_sequence": field("tool_sequence"),
            "tool_args": field("tool_args"),
            "candidate_call": last_attempted_call,
            "candidate_calls": attempted_calls,
        },
        "classifier_scores": score_list("classifier_scores"),
        "outcome": outcome.clone(),
    });
    append_jsonl(&tool_call_path, &tool_call_row)?;

    let final_response_path = hard_negative_path(output, "final_response_hard_negatives");
    let final_response_row = json!({
        "kind": "final_response",
        "context": context,
        "candidate": {
            "final_text": field("final_text"),
        },
        "final_response_classifier_scores": score_list("final_response_classifier_scores"),
        "outcome": outcome,
    });
    append_jsonl(&final_response_path, &final_response_row)
}

fn hard_negative_context(row: &Value, scenario: &SmokeScenario, messages: &[Message]) -> Value {
    json!({
        "user_request": scenario.user_message,
        "workflow_state": hard_negative_workflow_state(row, scenario, messages),
        "available_tools": available_tools(&scenario.workflow.tools),
        "required_facts": scenario.required_facts,
    })
}

fn hard_negative_workflow_state(
    row: &Value,
    scenario: &SmokeScenario,
    messages: &[Message],
) -> Value {
    let tool_sequence = row
        .get("tool_sequence")
        .and_then(Value::as_array)
        .cloned()
        .unwrap_or_default();
    let prior_tools = tool_sequence
        .iter()
        .take(tool_sequence.len().saturating_sub(1))
        .filter_map(Value::as_str)
        .collect::<Vec<_>>();
    let completed_steps = scenario
        .workflow
        .required_steps
        .iter()
        .filter(|step| prior_tools.iter().any(|name| name == &step.as_str()))
        .cloned()
        .collect::<Vec<_>>();
    let pending_steps = scenario
        .workflow
        .required_steps
        .iter()
        .filter(|step| !completed_steps.contains(step))
        .cloned()
        .collect::<Vec<_>>();
    let mut terminal_tools = scenario
        .workflow
        .terminal_tools
        .iter()
        .cloned()
        .collect::<Vec<_>>();
    terminal_tools.sort();

    json!({
        "required_steps": scenario.workflow.required_steps,
        "completed_steps": completed_steps,
        "pending_steps": pending_steps,
        "terminal_tools": terminal_tools,
        "recent_errors": recent_errors(messages, row),
    })
}

/// Renders every tool definition in `tools`, in map iteration order, as a
/// JSON array of tool specs.
fn available_tools(tools: &IndexMap<String, forge_guardrails::ToolDef>) -> Value {
    let mut specs = Vec::new();
    for tool in tools.values() {
        specs.push(tool_spec_json(&tool.spec));
    }
    Value::Array(specs)
}

/// Converts a tool spec into a `{name, description, parameters}` object.
///
/// When the spec carries no JSON schema, `parameters` falls back to an empty
/// object schema.
fn tool_spec_json(spec: &ToolSpec) -> Value {
    let parameters = match &spec.json_schema {
        Some(schema) => schema.clone(),
        None => json!({"type": "object", "properties": {}}),
    };

    json!({
        "name": spec.name,
        "description": spec.description,
        "parameters": parameters,
    })
}

/// Collects error-like message contents, followed by the row's own
/// `error_message` when it is present and not `null`.
///
/// A message is error-like when it is a retry, step or prerequisite nudge, or
/// a tool result whose content contains `[ToolError]`. Messages whose content
/// is blank after trimming are skipped.
fn recent_errors(messages: &[Message], row: &Value) -> Value {
    let mut errors = Vec::new();

    for message in messages {
        let content = &message.content;
        let signals_error = match message.metadata.msg_type {
            MessageType::RetryNudge
            | MessageType::StepNudge
            | MessageType::PrerequisiteNudge => true,
            MessageType::ToolResult => content.contains("[ToolError]"),
            _ => false,
        };
        if signals_error && !content.trim().is_empty() {
            errors.push(json!(content));
        }
    }

    match row.get("error_message") {
        Some(error) if !error.is_null() => errors.push(error.clone()),
        _ => {}
    }

    Value::Array(errors)
}

/// Pairs each tool name in `row["tool_sequence"]` with the argument object at
/// the same index of `row["tool_args"]`.
///
/// Non-string names are skipped without shifting the index used for the
/// argument lookup; a missing argument entry becomes an empty object.
fn candidate_calls_from_trace(row: &Value) -> Value {
    let names = row.get("tool_sequence").and_then(Value::as_array);
    let args = row.get("tool_args").and_then(Value::as_array);

    let mut calls = Vec::new();
    for (index, name) in names.into_iter().flatten().enumerate() {
        let Some(name) = name.as_str() else {
            continue;
        };
        let arguments = args
            .and_then(|list| list.get(index))
            .cloned()
            .unwrap_or_else(|| json!({}));
        calls.push(json!({
            "name": name,
            "arguments": arguments,
        }));
    }

    Value::Array(calls)
}

/// Returns the corrected calls for a scenario's reference answer.
///
/// Prefers the `candidate_calls` value as-is; otherwise wraps a lone
/// `candidate_call` in a one-element array; otherwise yields an empty array.
fn corrected_candidate_calls(corrected_positive: &Value) -> Value {
    match (
        corrected_positive.get("candidate_calls"),
        corrected_positive.get("candidate_call"),
    ) {
        (Some(calls), _) => calls.clone(),
        (None, Some(call)) => json!([call]),
        (None, None) => json!([]),
    }
}

/// Appends `row` as one JSON line to the file at `path`, creating the file if
/// it does not exist.
///
/// The serialized row and its trailing newline are handed to the file in a
/// single `write_all` call. In append mode, data that belongs together should
/// be written in one operation; writing the payload and the newline
/// separately lets another appender slip in between them.
///
/// # Errors
///
/// Returns the stringified serialization error or I/O error.
fn append_jsonl(path: &Path, row: &Value) -> Result<(), String> {
    let mut line = serde_json::to_string(row).map_err(|err| err.to_string())?;
    line.push('\n');

    let mut file = OpenOptions::new()
        .create(true)
        .append(true)
        .open(path)
        .map_err(|err| err.to_string())?;
    file.write_all(line.as_bytes())
        .map_err(|err| err.to_string())
}

/// Derives a sibling path of `output` with `suffix` inserted into the file
/// name.
///
/// * `rows.jsonl` becomes `rows.<suffix>.jsonl`
/// * `rows` (no extension) becomes `rows.<suffix>`
/// * a path with no usable file stem becomes `<output>.<suffix>.jsonl`
fn hard_negative_path(output: &str, suffix: &str) -> PathBuf {
    let base = Path::new(output);

    let file_name = if let Some(stem) = base.file_stem().and_then(|part| part.to_str()) {
        match base.extension().and_then(|part| part.to_str()) {
            Some(extension) => format!("{stem}.{suffix}.{extension}"),
            None => format!("{stem}.{suffix}"),
        }
    } else {
        format!("{output}.{suffix}.jsonl")
    };

    base.with_file_name(file_name)
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::scenarios::build_scenario;
    use std::fs;
    use std::time::{SystemTime, UNIX_EPOCH};

    /// Uniquely named scratch directory under the system temp dir.
    ///
    /// The directory is removed when the guard is dropped, so the test does
    /// not leave files behind even when an assertion panics.
    struct ScratchDir(PathBuf);

    impl ScratchDir {
        /// Creates the directory; the name combines the process id with a
        /// nanosecond timestamp.
        fn create() -> Self {
            let nonce = SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .expect("clock")
                .as_nanos();
            let mut dir = std::env::temp_dir();
            dir.push(format!(
                "forge-eval-report-test-{}-{nonce}",
                std::process::id()
            ));
            fs::create_dir_all(&dir).expect("tempdir");
            Self(dir)
        }
    }

    impl Drop for ScratchDir {
        fn drop(&mut self) {
            // Best-effort cleanup: a failed removal must not mask the test result.
            let _ = fs::remove_dir_all(&self.0);
        }
    }

    /// A failed row must produce both sibling hard-negative files, each with
    /// the expected context, candidate and corrected outcome.
    #[test]
    fn write_hard_negatives_creates_sibling_files_for_failed_rows() {
        let scratch = ScratchDir::create();
        let dir = &scratch.0;
        let output = dir.join("rows.jsonl");
        let scenario = build_scenario("basic_2step", true).expect("scenario");
        let row = json!({
            "scenario": "basic_2step",
            "run": 1,
            "success": false,
            "accuracy": false,
            "failure_kind": "accuracy_false",
            "tool_sequence": ["get_country_info", "summarize"],
            "tool_args": [{"country": "France"}, {"content": "bad"}],
            "final_text": "bad",
            "classifier_scores": [],
            "final_response_classifier_scores": [],
        });

        write_hard_negatives(output.to_str(), &row, &scenario, &[]).expect("write");

        let tool_path = dir.join("rows.tool_call_hard_negatives.jsonl");
        let final_path = dir.join("rows.final_response_hard_negatives.jsonl");
        let tool_text = fs::read_to_string(tool_path).expect("tool hard negatives");
        let final_text = fs::read_to_string(final_path).expect("final hard negatives");
        let tool_row: Value = serde_json::from_str(tool_text.trim()).expect("tool row");
        let final_row: Value = serde_json::from_str(final_text.trim()).expect("final row");
        assert_eq!(tool_row["kind"], json!("tool_call"));
        assert_eq!(
            tool_row["context"]["user_request"],
            json!("What is the capital of France?")
        );
        assert_eq!(
            tool_row["context"]["required_facts"],
            json!(["Paris", "capital"])
        );
        assert_eq!(
            tool_row["context"]["workflow_state"]["required_steps"],
            json!(["get_country_info"])
        );
        assert_eq!(
            tool_row["candidate"]["candidate_call"],
            json!({"name": "summarize", "arguments": {"content": "bad"}})
        );
        assert_eq!(
            tool_row["outcome"]["corrected_candidate_call"],
            json!({"name": "summarize", "arguments": {"content": "The capital of France is Paris."}})
        );
        assert_eq!(final_row["kind"], json!("final_response"));
        assert_eq!(
            final_row["outcome"]["corrected_final_response"],
            json!("The capital of France is Paris.")
        );
    }
}