//! systemprompt-ai 0.2.1
//!
//! Provider-agnostic LLM integration for systemprompt.io AI governance — Anthropic, OpenAI, Gemini, and local models unified behind one governed pipeline with cost tracking and audit.
//! Documentation
use anyhow::{Result, anyhow};
use std::time::Instant;

use crate::models::ai::{AiMessage, SamplingParams, SearchGroundedResponse, WebSource};
use crate::models::providers::anthropic::{
    AnthropicSearchContentBlock, AnthropicSearchRequest, AnthropicSearchResponse,
    AnthropicServerTool, AnthropicWebSearchResultItem,
};

use super::converters;
use super::provider::AnthropicProvider;

/// Borrowed input bundle for a web-search-grounded generation call.
///
/// Construct with [`SearchParams::new`], then layer on optional settings
/// through the `with_*` builder methods.
#[derive(Debug)]
pub struct SearchParams<'a> {
    pub messages: &'a [AiMessage],
    pub sampling: Option<&'a SamplingParams>,
    pub max_output_tokens: u32,
    pub model: &'a str,
    pub max_uses: Option<u32>,
}

impl<'a> SearchParams<'a> {
    /// Creates parameters with no sampling overrides and no cap on the
    /// number of web searches the model may issue.
    pub const fn new(messages: &'a [AiMessage], max_output_tokens: u32, model: &'a str) -> Self {
        Self {
            sampling: None,
            max_uses: None,
            messages,
            max_output_tokens,
            model,
        }
    }

    /// Attaches sampling overrides (temperature / top_p / top_k source).
    pub const fn with_sampling(mut self, sampling: &'a SamplingParams) -> Self {
        self.sampling = Some(sampling);
        self
    }

    /// Limits how many web searches the server-side tool may perform.
    pub const fn with_max_uses(mut self, max_uses: u32) -> Self {
        self.max_uses = Some(max_uses);
        self
    }
}

/// Sends an Anthropic `/messages` request with the server-side `web_search`
/// tool enabled and converts the reply into a [`SearchGroundedResponse`].
///
/// # Errors
///
/// Fails when the HTTP call itself errors, when the API responds with a
/// non-success status (the response body text is embedded in the error),
/// or when the response payload cannot be deserialized.
pub async fn generate_with_web_search(
    provider: &AnthropicProvider,
    params: SearchParams<'_>,
) -> Result<SearchGroundedResponse> {
    let started_at = Instant::now();

    let (system_prompt, converted_messages) = converters::convert_messages(params.messages);

    // Unpack optional sampling knobs from the caller's settings, if present.
    let (temperature, top_p, top_k) = match params.sampling {
        Some(s) => (s.temperature, s.top_p, s.top_k),
        None => (None, None, None),
    };

    let request = AnthropicSearchRequest {
        model: params.model.to_string(),
        messages: converted_messages,
        max_tokens: params.max_output_tokens,
        temperature,
        top_p,
        top_k,
        system: system_prompt,
        tools: vec![AnthropicServerTool::WebSearch {
            name: "web_search".to_string(),
            // Default to at most 5 searches when the caller set no cap.
            max_uses: params.max_uses.or(Some(5)),
        }],
    };

    let http_response = provider
        .client
        .post(format!("{}/messages", provider.endpoint))
        .header("x-api-key", &provider.api_key)
        .header("anthropic-version", "2023-06-01")
        .header("content-type", "application/json")
        .json(&request)
        .send()
        .await
        .map_err(|e| anyhow!("HTTP request failed: {}", e))?;

    let status = http_response.status();
    if !status.is_success() {
        // Surface the raw body to aid debugging; reading it can itself fail.
        let error_body = http_response
            .text()
            .await
            .unwrap_or_else(|e| format!("<error reading response: {}>", e));
        return Err(anyhow!(
            "Anthropic API returned status {}: {}",
            status,
            error_body
        ));
    }

    let parsed: AnthropicSearchResponse = http_response
        .json()
        .await
        .map_err(|e| anyhow!("Failed to parse response: {}", e))?;

    Ok(extract_search_response(&parsed, started_at))
}

/// Flattens an Anthropic search response into the provider-agnostic
/// [`SearchGroundedResponse`]: concatenates text blocks, records the
/// queries the model issued, and collects cited / returned web sources
/// deduplicated by URL (first occurrence wins, in content-block order).
fn extract_search_response(
    response: &AnthropicSearchResponse,
    start: Instant,
) -> SearchGroundedResponse {
    let mut body_text = String::new();
    let mut collected_sources = Vec::new();
    let mut queries = Vec::new();
    // URLs already emitted, so each source appears at most once.
    let mut known_urls = std::collections::HashSet::new();

    for block in &response.content {
        match block {
            AnthropicSearchContentBlock::Text { text, citations } => {
                body_text.push_str(text);

                // `citations` is optional; iterate zero-or-more entries.
                for citation in citations.iter().flatten() {
                    if known_urls.insert(citation.url.clone()) {
                        collected_sources.push(WebSource {
                            title: citation.title.clone(),
                            uri: citation.url.clone(),
                            // No per-source score from the API; use a flat 1.0.
                            relevance: 1.0,
                        });
                    }
                }
            },
            AnthropicSearchContentBlock::ServerToolUse { input, .. } => {
                // The tool-use payload carries the search string under "query".
                if let Some(query) = input.get("query").and_then(|q| q.as_str()) {
                    queries.push(query.to_owned());
                }
            },
            AnthropicSearchContentBlock::WebSearchToolResult { content, .. } => {
                for item in content {
                    if let AnthropicWebSearchResultItem::WebSearchResult { url, title, .. } = item {
                        if known_urls.insert(url.clone()) {
                            collected_sources.push(WebSource {
                                title: title.clone(),
                                uri: url.clone(),
                                relevance: 1.0,
                            });
                        }
                    }
                }
            },
        }
    }

    SearchGroundedResponse {
        content: body_text,
        sources: collected_sources,
        confidence_scores: Vec::new(),
        web_search_queries: queries,
        url_context_metadata: None,
        // Total token count (prompt + completion) as reported by the API.
        tokens_used: Some(response.usage.input_tokens + response.usage.output_tokens),
        latency_ms: start.elapsed().as_millis() as u64,
        finish_reason: response.stop_reason.clone(),
        safety_ratings: None,
    }
}