//! systemprompt-ai 0.2.2
//!
//! Provider-agnostic LLM integration for systemprompt.io AI governance — Anthropic, OpenAI,
//! Gemini, and local models unified behind one governed pipeline with cost tracking and audit.
//!
//! Documentation
use anyhow::{Result, anyhow};
use std::time::Instant;

use crate::models::ai::{AiMessage, SamplingParams, SearchGroundedResponse, WebSource};
use crate::models::providers::gemini::{
    GeminiPart, GeminiRequest, GeminiResponse, GeminiTool, GoogleSearch, UrlContext,
};

use super::constants::defaults;
use super::provider::GeminiProvider;
use super::{converters, request_builders};

/// Parameters for a search-grounded Gemini generation request.
///
/// Built either directly or via [`SearchParams::builder`]; consumed by
/// [`generate_with_google_search`].
pub struct SearchParams<'a> {
    /// Conversation messages forming the prompt.
    pub messages: &'a [AiMessage],
    /// Optional sampling parameters; `None` leaves them unset.
    pub sampling: Option<&'a SamplingParams>,
    /// Upper bound on the number of tokens the model may generate.
    pub max_output_tokens: u32,
    /// Model identifier to call.
    pub model: &'a str,
    /// When present, enables the URL-context tool for these URLs.
    pub urls: Option<Vec<String>>,
}

/// Builder for [`SearchParams`].
///
/// Required fields are supplied to [`SearchParamsBuilder::new`]; optional
/// fields start as `None` and are set via the `with_*` methods.
pub struct SearchParamsBuilder<'a> {
    messages: &'a [AiMessage],
    sampling: Option<&'a SamplingParams>,
    max_output_tokens: u32,
    model: &'a str,
    urls: Option<Vec<String>>,
}

impl<'a> SearchParamsBuilder<'a> {
    /// Creates a builder from the required fields; optional fields start unset.
    pub const fn new(messages: &'a [AiMessage], max_output_tokens: u32, model: &'a str) -> Self {
        Self {
            messages,
            max_output_tokens,
            model,
            sampling: None,
            urls: None,
        }
    }

    /// Attaches sampling parameters to the eventual request.
    pub const fn with_sampling(mut self, sampling: &'a SamplingParams) -> Self {
        self.sampling = Some(sampling);
        self
    }

    /// Supplies URLs for the URL-context tool.
    pub fn with_urls(mut self, urls: Vec<String>) -> Self {
        self.urls = Some(urls);
        self
    }

    /// Finalizes the builder into a [`SearchParams`] value.
    pub fn build(self) -> SearchParams<'a> {
        let Self {
            messages,
            sampling,
            max_output_tokens,
            model,
            urls,
        } = self;
        SearchParams {
            messages,
            sampling,
            max_output_tokens,
            model,
            urls,
        }
    }
}

impl<'a> SearchParams<'a> {
    /// Convenience entry point: starts a [`SearchParamsBuilder`] with the
    /// required fields, leaving `sampling` and `urls` unset.
    pub const fn builder(
        messages: &'a [AiMessage],
        max_output_tokens: u32,
        model: &'a str,
    ) -> SearchParamsBuilder<'a> {
        SearchParamsBuilder::new(messages, max_output_tokens, model)
    }
}

/// Runs a Gemini `generateContent` call with Google Search grounding enabled.
///
/// Converts `params.messages` into Gemini contents, always attaches the
/// Google Search tool, and additionally attaches the URL-context tool when
/// `params.urls` is present. The elapsed time from entry is reported as
/// latency on the returned response.
///
/// # Errors
/// Propagates failures from sending the request, parsing the response body,
/// or extracting the grounded response (e.g. no candidates returned).
pub async fn generate_with_google_search(
    provider: &GeminiProvider,
    params: SearchParams<'_>,
) -> Result<SearchGroundedResponse> {
    let started_at = Instant::now();

    let request = GeminiRequest {
        contents: converters::convert_messages(params.messages),
        generation_config: Some(request_builders::build_generation_config(
            params.sampling,
            params.max_output_tokens,
            None,
            None,
        )),
        safety_settings: None,
        tools: Some(build_search_tools(params.urls.is_some())),
        tool_config: None,
    };

    let raw_body =
        request_builders::send_request(provider, &request, params.model, "generateContent").await?;
    let parsed: GeminiResponse = request_builders::parse_response(&raw_body)?;

    extract_grounded_response(&parsed, started_at)
}

/// Assembles the tool list for a search-grounded request: always the Google
/// Search tool, plus the URL-context tool when `include_url_context` is set.
fn build_search_tools(include_url_context: bool) -> Vec<GeminiTool> {
    let search_tool = GeminiTool {
        function_declarations: None,
        google_search: Some(GoogleSearch::default()),
        url_context: None,
        code_execution: None,
    };

    if include_url_context {
        let url_tool = GeminiTool {
            function_declarations: None,
            google_search: None,
            url_context: Some(UrlContext::default()),
            code_execution: None,
        };
        vec![search_tool, url_tool]
    } else {
        vec![search_tool]
    }
}

/// Converts a raw [`GeminiResponse`] into a [`SearchGroundedResponse`],
/// pulling the first candidate's text, grounding metadata, URL-context
/// metadata, finish reason, and safety ratings.
///
/// `start` is the instant the request began; elapsed time becomes
/// `latency_ms`.
///
/// # Errors
/// Returns an error when the response contains no candidates.
fn extract_grounded_response(
    response: &GeminiResponse,
    start: Instant,
) -> Result<SearchGroundedResponse> {
    let candidate = response
        .candidates
        .first()
        .ok_or_else(|| anyhow!("No response from Gemini"))?;

    // First text part of the candidate's content; empty string when the
    // candidate has no content or no text part.
    let content_text = candidate
        .content
        .as_ref()
        .and_then(|c| {
            c.parts.iter().find_map(|p| match p {
                GeminiPart::Text { text } => Some(text.clone()),
                _ => None,
            })
        })
        .unwrap_or_default();

    let mut sources = Vec::new();
    let mut confidence_scores = Vec::new();
    let mut web_search_queries = Vec::new();

    if let Some(grounding) = &candidate.grounding_metadata {
        // Each grounding chunk maps to one web source; the chunk carries no
        // relevance score, so the module default is used.
        sources.extend(grounding.grounding_chunks.iter().map(|chunk| WebSource {
            title: chunk.web.title.clone(),
            uri: chunk.web.uri.clone(),
            relevance: defaults::RELEVANCE_SCORE,
        }));

        // Flatten the per-support confidence scores into a single list.
        confidence_scores.extend(
            grounding
                .grounding_supports
                .iter()
                .flat_map(|support| support.confidence_scores.iter().copied()),
        );

        web_search_queries.clone_from(&grounding.web_search_queries);
    }

    // Per-URL retrieval outcomes, present when the URL-context tool ran.
    let url_context_metadata = candidate.url_context_metadata.as_ref().map(|meta| {
        use systemprompt_models::ai::UrlMetadata;
        meta.url_metadata
            .iter()
            .map(|url_meta| UrlMetadata {
                retrieved_url: url_meta.retrieved_url.clone(),
                url_retrieval_status: url_meta.url_retrieval_status.clone(),
            })
            .collect()
    });

    // `as_millis` yields u128; saturate instead of silently truncating with
    // an `as` cast (overflow is practically impossible, but be explicit).
    let latency_ms = u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX);

    let finish_reason = candidate.finish_reason.clone();
    let safety_ratings = candidate.safety_ratings.as_ref().map(|ratings| {
        ratings
            .iter()
            .map(|r| {
                serde_json::json!({
                    "category": r.category,
                    "probability": r.probability
                })
            })
            .collect()
    });

    Ok(SearchGroundedResponse {
        content: content_text,
        sources,
        confidence_scores,
        web_search_queries,
        url_context_metadata,
        tokens_used: response.usage_metadata.as_ref().map(|u| u.total),
        latency_ms,
        finish_reason,
        safety_ratings,
    })
}