forge-guardrails 0.1.2

Foundation types for an LLM-agent workflow framework
Documentation
use std::sync::{Arc, Mutex as StdMutex};

use forge_guardrails::error::ToolError;
use forge_guardrails::workflow::ToolCallable;
use forge_guardrails::{ToolDef, ToolSpec, Workflow};
use indexmap::IndexMap;
use serde_json::{json, Value};

/// One self-contained smoke scenario: a workflow plus the inputs and
/// expectations that accompany it.
///
/// Instances are produced by `build_scenario` and the per-scenario builders in
/// this file.
pub(crate) struct SmokeScenario {
    /// Scenario key; each builder sets it to the same string that
    /// `build_scenario` dispatches on.
    pub(crate) name: String,
    /// Workflow assembled from the scenario's tools, required steps, terminal
    /// tool and instruction text.
    pub(crate) workflow: Workflow,
    /// The user request for this scenario (presumably sent as the opening user
    /// turn — confirm against the runner).
    pub(crate) user_message: String,
    /// Strings the scenario expects in a successful result.
    /// NOTE(review): how these are matched (substring, case, judge) is decided
    /// by the consumer and is not visible in this file.
    pub(crate) required_facts: Vec<String>,
    /// Shared slot that starts as `None`; the scenario's terminal tool stores
    /// the argument map it was called with here.
    pub(crate) capture: Arc<StdMutex<Option<IndexMap<String, Value>>>>,
    /// Optional JSON sample; the builders in this file populate it with an
    /// object carrying at least `final_text` and `candidate_calls`.
    pub(crate) corrected_positive: Option<Value>,
}

pub(crate) fn build_scenario(
    name: &str,
    use_required_steps: bool,
) -> Result<SmokeScenario, String> {
    match name {
        "basic_2step" => basic_2step(use_required_steps),
        "sequential_3step" => sequential_3step(use_required_steps),
        "error_recovery" => error_recovery(use_required_steps),
        "inconsistent_api_recovery_stateful" => {
            inconsistent_api_recovery_stateful(use_required_steps)
        }
        other => Err(format!("unsupported scenario: {other}")),
    }
}

/// Builds the `basic_2step` scenario: a country lookup tool followed by a
/// terminal `summarize` tool.
///
/// With `use_required_steps` set, `get_country_info` is handed to the workflow
/// as a required step; otherwise the required-step list is empty.
///
/// # Errors
///
/// Propagates any failure from building the tools or the workflow.
fn basic_2step(use_required_steps: bool) -> Result<SmokeScenario, String> {
    const SCENARIO: &str = "basic_2step";
    const LOOKUP: &str = "get_country_info";
    const TERMINAL: &str = "summarize";

    // Slot the terminal tool writes its received arguments into.
    let capture = Arc::new(StdMutex::new(None));

    let lookup_tool = make_tool(
        LOOKUP,
        "Look up facts about a country.",
        json!({
            "type": "object",
            "properties": {"country": {"type": "string", "description": "Country name"}},
            "required": ["country"]
        }),
        // Canned response; the call arguments are ignored.
        |_args| Ok(json!("The capital of France is Paris. Population: 2.1 million.")),
    )?;
    let summarize_tool = terminal_tool(
        TERMINAL,
        "Summarize content and provide the final answer.",
        json!({
            "type": "object",
            "properties": {"content": {"type": "string", "description": "The content to summarize"}},
            "required": ["content"]
        }),
        "content",
        Arc::clone(&capture),
    )?;

    // Register in the same order as before: lookup first, terminal tool second.
    let mut tools = IndexMap::new();
    tools.insert(LOOKUP.to_owned(), lookup_tool);
    tools.insert(TERMINAL.to_owned(), summarize_tool);

    let required = use_required_steps
        .then(|| vec![LOOKUP.to_owned()])
        .unwrap_or_default();

    let workflow = Workflow::new(
        SCENARIO,
        "Simple 2-step information retrieval and summary",
        tools,
        required,
        TERMINAL.to_owned().into(),
        "You are a helpful assistant. First use get_country_info, then summarize.",
    )?;

    let known_good = json!({
        "final_text": "The capital of France is Paris.",
        "candidate_calls": [
            {"name": "summarize", "arguments": {"content": "The capital of France is Paris."}}
        ]
    });

    Ok(SmokeScenario {
        name: SCENARIO.to_owned(),
        workflow,
        user_message: String::from("What is the capital of France?"),
        required_facts: Vec::from(["Paris", "capital"].map(String::from)),
        capture,
        corrected_positive: Some(known_good),
    })
}

/// Builds the `sequential_3step` scenario: fetch sales data, analyze it, then
/// finish with the terminal `report` tool.
///
/// With `use_required_steps` set, `fetch_sales_data` and `analyze_sales` are
/// handed to the workflow as required steps (in that order); otherwise the
/// required-step list is empty.
///
/// # Errors
///
/// Propagates any failure from building the tools or the workflow.
fn sequential_3step(use_required_steps: bool) -> Result<SmokeScenario, String> {
    const SCENARIO: &str = "sequential_3step";
    const FETCH: &str = "fetch_sales_data";
    const ANALYZE: &str = "analyze_sales";
    const TERMINAL: &str = "report";

    // Slot the terminal tool writes its received arguments into.
    let capture = Arc::new(StdMutex::new(None));

    let fetch_tool = make_tool(
        FETCH,
        "Fetch sales data for a given quarter and year.",
        json!({
            "type": "object",
            "properties": {
                "quarter": {"type": "integer", "description": "Quarter number"},
                "year": {"type": "integer", "description": "Four-digit year"}
            },
            "required": ["quarter", "year"]
        }),
        // Canned response; quarter/year are not inspected.
        |_args| Ok(json!("Dataset: 150 records, 12 columns, covering Q1-Q4 2024 sales data.")),
    )?;
    let analyze_tool = make_tool(
        ANALYZE,
        "Analyze the loaded sales data and produce findings.",
        json!({"type": "object", "properties": {}}),
        // Canned response; takes no parameters.
        |_args| {
            Ok(json!(
                "Analysis: Revenue grew 23% YoY. Top product: Widget Pro. Weakest region: APAC."
            ))
        },
    )?;
    let report_tool = terminal_tool(
        TERMINAL,
        "Produce a final report from findings.",
        json!({
            "type": "object",
            "properties": {"findings": {"type": "string", "description": "The findings to include in the report"}},
            "required": ["findings"]
        }),
        "findings",
        Arc::clone(&capture),
    )?;

    // Register in the same order as before: fetch, analyze, report.
    let mut tools = IndexMap::new();
    tools.insert(FETCH.to_owned(), fetch_tool);
    tools.insert(ANALYZE.to_owned(), analyze_tool);
    tools.insert(TERMINAL.to_owned(), report_tool);

    let required = use_required_steps
        .then(|| vec![FETCH.to_owned(), ANALYZE.to_owned()])
        .unwrap_or_default();

    let workflow = Workflow::new(
        SCENARIO,
        "Fetch data, analyze, then report",
        tools,
        required,
        TERMINAL.to_owned().into(),
        "You are a data analyst assistant. Fetch data, analyze it, then report.",
    )?;

    // The sample answer and the sample `report` call carry the same text.
    let findings = "Revenue grew 23% YoY. Top product: Widget Pro. Weakest region: APAC.";
    let known_good = json!({
        "final_text": findings,
        "candidate_calls": [
            {
                "name": "report",
                "arguments": {
                    "findings": findings
                }
            }
        ]
    });

    Ok(SmokeScenario {
        name: SCENARIO.to_owned(),
        workflow,
        user_message: String::from("Generate a sales report from the Q4 2024 dataset."),
        // NOTE(review): "23% YoY growth" does not occur verbatim in the canned
        // analysis output or in the sample above ("Revenue grew 23% YoY") —
        // confirm the fact matcher is not a plain substring check.
        required_facts: Vec::from(["23% YoY growth", "Widget Pro", "APAC"].map(String::from)),
        capture,
        corrected_positive: Some(known_good),
    })
}

fn error_recovery(use_required_steps: bool) -> Result<SmokeScenario, String> {
    let capture = Arc::new(StdMutex::new(None));
    let mut tools = IndexMap::new();
    tools.insert(
        "fetch".to_string(),
        make_tool(
            "fetch",
            "Fetch records. The count parameter must be a numeric string.",
            json!({
                "type": "object",
                "properties": {"count": {"type": "string", "description": "Zero-padded 4-digit count"}},
                "required": ["count"]
            }),
            |args| {
                let count = args.get("count").and_then(Value::as_str).unwrap_or("");
                if count.len() == 4 && count.chars().all(|c| c.is_ascii_digit()) {
                    Ok(json!(format!(
                        "Fetched {} records.",
                        count.parse::<i64>().unwrap_or(0)
                    )))
                } else {
                    Err(ToolError::Execution(format!(
                        "count must be a zero-padded 4-digit string, got '{count}'"
                    )))
                }
            },
        )?,
    );
    tools.insert(
        "summarize".to_string(),
        terminal_tool(
            "summarize",
            "Summarize the fetched content.",
            json!({
                "type": "object",
                "properties": {"content": {"type": "string", "description": "The content to summarize"}},
                "required": ["content"]
            }),
            "content",
            capture.clone(),
        )?,
    );
    let required = if use_required_steps {
        vec!["fetch".to_string()]
    } else {
        Vec::new()
    };
    let workflow = Workflow::new(
        "error_recovery",
        "Fetch with validation, then summarize",
        tools,
        required,
        "summarize".to_string().into(),
        "You are a helpful assistant. Fetch the requested records, then summarize them.",
    )?;
    Ok(SmokeScenario {
        name: "error_recovery".to_string(),
        workflow,
        user_message: "Fetch 10 records and summarize them.".to_string(),
        required_facts: vec!["10 records".to_string()],
        capture,
        corrected_positive: Some(json!({
            "final_text": "Fetched 10 records and summarized the record count.",
            "candidate_calls": [
                {"name": "fetch", "arguments": {"count": "0010"}},
                {"name": "summarize", "arguments": {"content": "Fetched 10 records and summarized the record count."}}
            ]
        })),
    })
}

/// Builds the `inconsistent_api_recovery_stateful` scenario: list accounts in
/// a legacy audit system, then finish with the terminal `legacy_submit_audit`
/// tool.
///
/// With `use_required_steps` set, `legacy_list_accounts` is handed to the
/// workflow as a required step; otherwise the required-step list is empty.
///
/// # Errors
///
/// Propagates any failure from building the tools or the workflow.
fn inconsistent_api_recovery_stateful(use_required_steps: bool) -> Result<SmokeScenario, String> {
    const SCENARIO: &str = "inconsistent_api_recovery_stateful";
    const LIST_ACCOUNTS: &str = "legacy_list_accounts";
    const TERMINAL: &str = "legacy_submit_audit";

    // Slot the terminal tool writes its received arguments into.
    let capture = Arc::new(StdMutex::new(None));

    let list_tool = make_tool(
        LIST_ACCOUNTS,
        "List available accounts in the legacy audit system.",
        json!({"type": "object", "properties": {}}),
        // Canned response; takes no parameters.
        |_args| Ok(json!("Accounts: ACC-12345 Acme Corp")),
    )?;
    let submit_tool = terminal_tool(
        TERMINAL,
        "Submit the final compliance audit report.",
        json!({
            "type": "object",
            "properties": {
                "report": {"type": "string", "description": "Final audit report"}
            },
            "required": ["report"]
        }),
        "report",
        Arc::clone(&capture),
    )?;

    // Register in the same order as before: listing first, terminal tool second.
    let mut tools = IndexMap::new();
    tools.insert(LIST_ACCOUNTS.to_owned(), list_tool);
    tools.insert(TERMINAL.to_owned(), submit_tool);

    let required = use_required_steps
        .then(|| vec![LIST_ACCOUNTS.to_owned()])
        .unwrap_or_default();

    let workflow = Workflow::new(
        SCENARIO,
        "Stateful legacy audit smoke scenario",
        tools,
        required,
        TERMINAL.to_owned().into(),
        "You are a compliance audit assistant. List accounts before submitting the final audit.",
    )?;

    let user_message = String::from(concat!(
        "Run a compliance audit for Acme Corp. Include account ACC-12345 and ",
        "compliance_status PASS in the submitted report."
    ));

    // Unlike the other scenarios, the sample also carries top-level
    // `account_id` / `compliance_status` entries.
    let known_good = json!({
        "final_text": "Account ACC-12345 compliance_status PASS.",
        "account_id": "ACC-12345",
        "compliance_status": "PASS",
        "candidate_calls": [
            {
                "name": "legacy_submit_audit",
                "arguments": {
                    "report": "account_id=ACC-12345; compliance_status=PASS"
                }
            }
        ]
    });

    Ok(SmokeScenario {
        name: SCENARIO.to_owned(),
        workflow,
        user_message,
        required_facts: Vec::from(["ACC-12345", "PASS"].map(String::from)),
        capture,
        corrected_positive: Some(known_good),
    })
}

/// Wraps a synchronous handler as a `ToolDef`.
///
/// The tool spec is derived from `name`, `description` and the JSON `schema`;
/// `func` is stored behind an `Arc` and invoked from inside a boxed future so
/// it fits the `ToolCallable` shape.
///
/// # Errors
///
/// Returns the error from `ToolSpec::from_json_schema` when the spec cannot be
/// built.
fn make_tool<F>(name: &str, description: &str, schema: Value, func: F) -> Result<ToolDef, String>
where
    F: Fn(IndexMap<String, Value>) -> Result<Value, ToolError> + Send + Sync + 'static,
{
    let spec = ToolSpec::from_json_schema(name, description, &schema)?;

    let shared = Arc::new(func);
    let callable: ToolCallable = Arc::new(move |args| {
        // Each call gets its own handle so the returned future owns the handler.
        let handler = Arc::clone(&shared);
        Box::pin(async move { handler(args) })
    });

    Ok(ToolDef::new(spec, callable))
}

/// Builds a terminal tool that records the arguments it was called with.
///
/// On every call the full argument map is stored in `capture` (replacing any
/// earlier value), and the tool returns the value found under `output_key`, or
/// an empty JSON string when that key is absent.
///
/// # Errors
///
/// Returns whatever error `make_tool` reports for the given schema.
///
/// # Panics
///
/// The tool panics when invoked if the `capture` mutex is poisoned.
fn terminal_tool(
    name: &str,
    description: &str,
    schema: Value,
    output_key: &'static str,
    capture: Arc<StdMutex<Option<IndexMap<String, Value>>>>,
) -> Result<ToolDef, String> {
    make_tool(name, description, schema, move |args| {
        // Read the echoed value first so the argument map can be moved into the
        // capture slot instead of being cloned wholesale on every call.
        let echoed = args.get(output_key).cloned().unwrap_or_else(|| json!(""));
        *capture.lock().expect("capture lock") = Some(args);
        Ok(echoed)
    })
}

#[cfg(test)]
mod tests {
    use super::*;

    /// `basic_2step` with required steps enabled wires up the lookup tool as
    /// required and `summarize` as terminal.
    #[test]
    fn builds_plumbing_scenario() {
        let built = build_scenario("basic_2step", true).expect("scenario");
        let workflow = &built.workflow;

        assert_eq!(workflow.required_steps, vec!["get_country_info"]);
        assert!(workflow.terminal_tools.contains("summarize"));
    }

    /// The stateful legacy-audit scenario requires the account listing and
    /// terminates on the audit submission.
    #[test]
    fn builds_inconsistent_api_recovery_stateful_scenario() {
        let built = build_scenario("inconsistent_api_recovery_stateful", true).expect("scenario");
        let workflow = &built.workflow;

        assert_eq!(workflow.required_steps, vec!["legacy_list_accounts"]);
        assert!(workflow.terminal_tools.contains("legacy_submit_audit"));
    }
}