aisdk 0.5.2

An open-source Rust library for building AI-powered applications, inspired by the Vercel AI SDK. It provides a robust, type-safe, and easy-to-use interface for interacting with various Large Language Models (LLMs).
Documentation
//! Client implementation for the Google provider.
use crate::core::client::{EmbeddingClient, LanguageModelClient};
use crate::error::{Error, Result};
use crate::providers::google::{Google, ModelName};
use derive_builder::Builder;
use reqwest::header::CONTENT_TYPE;
use reqwest_eventsource::Event;
use serde::{Deserialize, Serialize};

pub(crate) mod types;

#[derive(Debug, Default, Clone, Serialize, Deserialize, Builder)]
#[builder(setter(into), build_fn(error = "Error"))]
pub(crate) struct GoogleOptions {
    pub(crate) model: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    #[builder(default)]
    pub(crate) request: Option<types::GenerateContentRequest>,
    #[serde(skip)]
    #[builder(default)]
    pub(crate) streaming: bool,
}

impl GoogleOptions {
    pub(crate) fn builder() -> GoogleOptionsBuilder {
        GoogleOptionsBuilder::default()
    }
}

#[derive(Builder, Debug, Default, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub(crate) struct GoogleEmbeddingOptions {
    pub(crate) model: String,
    pub(crate) requests: Vec<types::EmbedContentRequest>,
}

impl<M: ModelName> LanguageModelClient for Google<M> {
    type Response = types::GenerateContentResponse;
    type StreamEvent = types::GoogleStreamEvent;

    fn path(&self) -> String {
        if let Some(ref path) = self.settings.path {
            return path.clone();
        }
        if self.lm_options.streaming {
            return format!(
                "/v1beta/models/{}:streamGenerateContent",
                self.lm_options.model
            );
        };
        format!("/v1beta/models/{}:generateContent", self.lm_options.model)
    }

    fn method(&self) -> reqwest::Method {
        reqwest::Method::POST
    }

    fn headers(&self) -> reqwest::header::HeaderMap {
        let mut headers = reqwest::header::HeaderMap::new();
        headers.insert(CONTENT_TYPE, "application/json".parse().unwrap());
        headers.insert("x-goog-api-key", self.settings.api_key.parse().unwrap());
        headers
    }

    fn query_params(&self) -> Vec<(&str, &str)> {
        if self.lm_options.streaming {
            return vec![("alt", "sse")];
        }
        Vec::new()
    }

    fn body(&self) -> reqwest::Body {
        if let Some(request) = &self.lm_options.request {
            let body = serde_json::to_string(request).unwrap();
            return reqwest::Body::from(body);
        };
        reqwest::Body::from("{}")
    }

    fn parse_stream_sse(
        event: std::result::Result<Event, reqwest_eventsource::Error>,
    ) -> Result<Self::StreamEvent> {
        match event {
            Ok(event) => match event {
                Event::Open => Ok(types::GoogleStreamEvent::NotSupported("{}".to_string())),
                Event::Message(msg) => {
                    let value: serde_json::Value =
                        serde_json::from_str(&msg.data).map_err(|e| Error::ApiError {
                            status_code: None,
                            details: format!("Invalid JSON in SSE data: {e}"),
                        })?;

                    Ok(
                        serde_json::from_value::<types::GenerateContentResponse>(value)
                            .map(types::GoogleStreamEvent::Response)
                            .unwrap_or(types::GoogleStreamEvent::NotSupported(msg.data)),
                    )
                }
            },
            Err(e) => {
                // Extract status code if it's an InvalidStatusCode error
                let status_code = match &e {
                    reqwest_eventsource::Error::InvalidStatusCode(status, _) => Some(*status),
                    _ => None,
                };
                Err(Error::ApiError {
                    status_code,
                    details: e.to_string(),
                })
            }
        }
    }

    fn end_stream(event: &Self::StreamEvent) -> bool {
        match event {
            types::GoogleStreamEvent::Response(resp) => {
                resp.candidates.iter().any(|c| c.finish_reason.is_some())
            }
            _ => false,
        }
    }
}

impl<M: ModelName> EmbeddingClient for Google<M> {
    type Response = types::BatchEmbedContentsResponse;

    fn path(&self) -> String {
        format!(
            "/v1beta/models/{}:batchEmbedContents",
            self.embedding_options.model
        )
    }

    fn method(&self) -> reqwest::Method {
        reqwest::Method::POST
    }

    fn headers(&self) -> reqwest::header::HeaderMap {
        let mut headers = reqwest::header::HeaderMap::new();
        headers.insert(CONTENT_TYPE, "application/json".parse().unwrap());
        headers.insert("x-goog-api-key", self.settings.api_key.parse().unwrap());
        headers
    }

    fn query_params(&self) -> Vec<(&str, &str)> {
        Vec::new()
    }

    fn body(&self) -> reqwest::Body {
        let request = types::BatchEmbedContentsRequest {
            requests: self.embedding_options.requests.clone(),
        };
        let body = serde_json::to_string(&request).unwrap();
        reqwest::Body::from(body)
    }
}