forge-guardrails 0.1.2

Foundation types for an LLM-agent workflow framework
Documentation
use super::*;
use crate::clients::base::{
    ApiFormat, ChunkStream, LLMClient, LLMRequestOptions, LLMResponse, SamplingParams,
    TextResponse, ToolCall,
};
use crate::context::manager::ContextManager;
use crate::context::strategies::NoCompact;
use crate::core::message::{Message, MessageMeta, MessageRole, MessageType};
use crate::core::tool_spec::ToolSpec;
use crate::guardrails::{ErrorTracker, ResponseValidator};
use indexmap::IndexMap;
use serde_json::json;
use std::sync::Arc;

/// Tool-call ids are rendered as `call_` followed by the counter, zero-padded
/// to nine digits.
#[test]
fn tool_call_id_format() {
    let expectations = [
        (0, "call_000000000"),
        (1, "call_000000001"),
        (42, "call_000000042"),
        (999999999, "call_999999999"),
    ];

    for (counter, expected) in expectations {
        assert_eq!(format_tool_call_id(counter), expected);
    }
}

/// Builds an `InferenceResult` literal and reads back the two counters it was
/// given (`tool_call_counter` and `attempts`).
#[test]
fn inference_result_fields() {
    let response = LLMResponse::Text(crate::clients::base::TextResponse::new("hi"));

    let built = InferenceResult {
        // The two fields under test come first; everything else is left empty.
        tool_call_counter: 5,
        attempts: 1,
        response,
        new_messages: Vec::new(),
        usage: None,
        usage_details: None,
        call_info: None,
        provider_response: None,
        provider_events: None,
    };

    let InferenceResult {
        tool_call_counter,
        attempts,
        ..
    } = &built;
    assert_eq!(*tool_call_counter, 5);
    assert_eq!(*attempts, 1);
}

/// An id that already appears in the message history must not be handed out
/// again: with `call_000000000` present and the counter at 0, the next id is
/// `call_000000001` and the counter ends at 2.
#[test]
fn next_unique_tool_call_id_skips_existing_history_ids() {
    // The id shared by the prior assistant tool call and its tool result.
    let taken_id = "call_000000000";

    let assistant_call = Message::new(
        MessageRole::Assistant,
        "",
        MessageMeta::new(MessageType::ToolCall),
    )
    .with_tool_calls(vec![crate::core::message::ToolCallInfo::new(
        "prior",
        Some(IndexMap::new()),
        taken_id,
    )]);

    let tool_result = Message::new(
        MessageRole::Tool,
        "prior result",
        MessageMeta::new(MessageType::ToolResult),
    )
    .with_tool_name("prior")
    .with_tool_call_id(taken_id);

    let messages = vec![assistant_call, tool_result];
    let mut seen = existing_tool_call_ids(&messages);
    let mut counter = 0;

    let issued = next_unique_tool_call_id(&mut counter, &mut seen);

    assert_eq!(issued, "call_000000001");
    assert_eq!(counter, 2);
}

/// Test double for `LLMClient` that records, per `send_with_options` call,
/// the two optional request-option payloads it was handed.
///
/// Both logs sit behind a `Mutex` because the trait methods only receive
/// `&self`.
struct RetryRecordingClient {
    /// `options.inbound_anthropic_body` of each call, in call order.
    raw_bodies: std::sync::Mutex<Vec<Option<Arc<Value>>>>,
    /// `options.initial_openai_messages` of each call, in call order.
    initial_messages: std::sync::Mutex<Vec<Option<Arc<[Value]>>>>,
}

impl RetryRecordingClient {
    /// Creates a client whose two recording logs start out empty.
    fn new() -> Self {
        // `Mutex<Vec<_>>::default()` is an unlocked mutex around an empty vector.
        Self {
            raw_bodies: Default::default(),
            initial_messages: Default::default(),
        }
    }
}

impl LLMClient for RetryRecordingClient {
    fn api_format(&self) -> ApiFormat {
        ApiFormat::OpenAI
    }

    async fn send(
        &self,
        _messages: Vec<Value>,
        _tools: Option<Vec<ToolSpec>>,
        _sampling: Option<SamplingParams>,
    ) -> Result<LLMResponse, crate::error::BackendError> {
        Ok(LLMResponse::Text(TextResponse::new("unused")))
    }

    async fn send_with_options(
        &self,
        _messages: Vec<Value>,
        _tools: Option<Vec<ToolSpec>>,
        options: LLMRequestOptions,
    ) -> Result<LLMResponse, crate::error::BackendError> {
        let mut raw_bodies = self.raw_bodies.lock().unwrap();
        let attempt = raw_bodies.len();
        raw_bodies.push(options.inbound_anthropic_body);
        drop(raw_bodies);
        self.initial_messages
            .lock()
            .unwrap()
            .push(options.initial_openai_messages);

        if attempt == 0 {
            Ok(LLMResponse::Text(TextResponse::new("not a tool call")))
        } else {
            let mut args = IndexMap::new();
            args.insert("message".to_string(), json!("ok"));
            Ok(LLMResponse::ToolCalls(vec![ToolCall::new("respond", args)]))
        }
    }

    async fn send_stream(
        &self,
        _messages: Vec<Value>,
        _tools: Option<Vec<ToolSpec>>,
        _sampling: Option<SamplingParams>,
    ) -> Result<ChunkStream, crate::error::StreamError> {
        Err(crate::error::StreamError::new("not implemented"))
    }

    async fn get_context_length(&self) -> Result<Option<i64>, crate::error::ContextDiscoveryError> {
        Ok(Some(4096))
    }
}

/// Drives `run_inference_with_options` against `RetryRecordingClient`, whose
/// first reply is plain text and whose second is a `respond` tool call.
///
/// Expectations: the run takes two attempts; the first client call receives
/// the very same `Arc`s that were passed in for the raw Anthropic body and the
/// initial OpenAI messages; the second call receives `None` for both.
/// NOTE(review): the first reply is presumably rejected by the validator
/// (which only lists `respond`), triggering the retry — confirm against
/// `run_inference_with_options`.
#[tokio::test]
async fn raw_anthropic_body_is_cleared_after_clean_attempt() {
    let client = RetryRecordingClient::new();

    // Raw Anthropic-style request body, assembled from its two list parts.
    let system_blocks = json!([{
        "type": "text",
        "text": "system",
        "cache_control": {"type": "ephemeral"}
    }]);
    let conversation = json!([{
        "role": "user",
        "content": [{
            "type": "text",
            "text": "hi",
            "cache_control": {"type": "ephemeral"}
        }]
    }]);
    let raw = Arc::new(json!({
        "model": "claude-3",
        "max_tokens": 64,
        "system": system_blocks,
        "messages": conversation
    }));

    let mut messages = vec![Message::new(
        MessageRole::User,
        "hi",
        MessageMeta::new(MessageType::UserInput),
    )];
    let initial_messages: Arc<[Value]> = fold_and_serialize(&messages, "openai")
        .into_boxed_slice()
        .into();

    let validator = ResponseValidator::new(vec!["respond".to_string()], false, None);
    let mut tracker = ErrorTracker::new(3, 2);
    let tools = vec![crate::tools::respond::respond_spec()];
    let mut context = ContextManager::new(Box::new(NoCompact), 4096, None, None, None);
    let mut counter = 0;

    // Hand the client shared pointers so identity can be checked afterwards.
    let request_options = LLMRequestOptions {
        inbound_anthropic_body: Some(Arc::clone(&raw)),
        initial_openai_messages: Some(Arc::clone(&initial_messages)),
        ..Default::default()
    };

    let outcome = run_inference_with_options(
        &mut messages,
        &client,
        &mut context,
        &validator,
        &mut tracker,
        &tools,
        &mut counter,
        0,
        "",
        Some(3),
        false,
        None,
        request_options,
    )
    .await
    .expect("inference")
    .expect("result");

    assert_eq!(outcome.attempts, 2);

    // First call: the original body by pointer identity; second call: nothing.
    let seen_bodies = client.raw_bodies.lock().unwrap().clone();
    assert_eq!(seen_bodies.len(), 2);
    assert!(matches!(&seen_bodies[0], Some(body) if Arc::ptr_eq(body, &raw)));
    assert!(seen_bodies[1].is_none());

    // Same expectation for the pre-serialized initial messages.
    let seen_initial = client.initial_messages.lock().unwrap().clone();
    assert_eq!(seen_initial.len(), 2);
    assert!(matches!(
        &seen_initial[0],
        Some(recorded) if Arc::ptr_eq(recorded, &initial_messages)
    ));
    assert!(seen_initial[1].is_none());
}