rig/providers/gemini/embedding.rs

1// ================================================================
2//! Google Gemini Embeddings Integration
3//! From [Gemini API Reference](https://ai.google.dev/api/embeddings)
4// ================================================================
5
6use serde_json::json;
7
8use crate::embeddings::{self, EmbeddingError};
9
10use super::{client::ApiResponse, Client};
11
/// Model identifier for Gemini's `embedding-001` embedding model.
pub const EMBEDDING_001: &str = "embedding-001";

/// Model identifier for Gemini's `text-embedding-004` embedding model.
pub const EMBEDDING_004: &str = "text-embedding-004";
16#[derive(Clone)]
17pub struct EmbeddingModel {
18    client: Client,
19    model: String,
20    ndims: Option<usize>,
21}
22
23impl EmbeddingModel {
24    pub fn new(client: Client, model: &str, ndims: Option<usize>) -> Self {
25        Self {
26            client,
27            model: model.to_string(),
28            ndims,
29        }
30    }
31}
32
33impl embeddings::EmbeddingModel for EmbeddingModel {
34    const MAX_DOCUMENTS: usize = 1024;
35
36    fn ndims(&self) -> usize {
37        match self.model.as_str() {
38            EMBEDDING_001 => 768,
39            EMBEDDING_004 => 1024,
40            _ => 0, // Default to 0 for unknown models
41        }
42    }
43
44    #[cfg_attr(feature = "worker", worker::send)]
45    async fn embed_texts(
46        &self,
47        documents: impl IntoIterator<Item = String> + Send,
48    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
49        let documents: Vec<_> = documents.into_iter().collect();
50        let mut request_body = json!({
51            "model": format!("models/{}", self.model),
52            "content": {
53                "parts": documents.iter().map(|doc| json!({ "text": doc })).collect::<Vec<_>>(),
54            },
55        });
56
57        if let Some(ndims) = self.ndims {
58            request_body["output_dimensionality"] = json!(ndims);
59        }
60
61        let response = self
62            .client
63            .post(&format!("/v1beta/models/{}:embedContent", self.model))
64            .json(&request_body)
65            .send()
66            .await?
67            .error_for_status()?
68            .json::<ApiResponse<gemini_api_types::EmbeddingResponse>>()
69            .await?;
70
71        match response {
72            ApiResponse::Ok(response) => {
73                let chunk_size = self.ndims.unwrap_or_else(|| self.ndims());
74                Ok(documents
75                    .into_iter()
76                    .zip(response.embedding.values.chunks(chunk_size))
77                    .map(|(document, embedding)| embeddings::Embedding {
78                        document,
79                        vec: embedding.to_vec(),
80                    })
81                    .collect())
82            }
83            ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
84        }
85    }
86}
87
88// =================================================================
89// Gemini API Types
90// =================================================================
91/// Rust Implementation of the Gemini Types from [Gemini API Reference](https://ai.google.dev/api/embeddings)
92#[allow(dead_code)]
93mod gemini_api_types {
94    use serde::{Deserialize, Serialize};
95    use serde_json::Value;
96
97    use crate::providers::gemini::gemini_api_types::{CodeExecutionResult, ExecutableCode};
98
99    #[derive(Serialize)]
100    #[serde(rename_all = "camelCase")]
101    pub struct EmbedContentRequest {
102        model: String,
103        content: EmbeddingContent,
104        task_type: TaskType,
105        title: String,
106        output_dimensionality: i32,
107    }
108
109    #[derive(Serialize)]
110    pub struct EmbeddingContent {
111        parts: Vec<EmbeddingContentPart>,
112        /// Optional. The producer of the content. Must be either 'user' or 'model'. Useful to set for multi-turn
113        /// conversations, otherwise can be left blank or unset.
114        role: Option<String>,
115    }
116
117    /// A datatype containing media that is part of a multi-part Content message.
118    ///  - A Part consists of data which has an associated datatype. A Part can only contain one of the accepted types in Part.data.
119    ///  - A Part must have a fixed IANA MIME type identifying the type and subtype of the media if the inlineData field is filled with raw bytes.
120    #[derive(Serialize)]
121    pub struct EmbeddingContentPart {
122        /// Inline text.
123        text: String,
124        /// Inline media bytes.
125        inline_data: Option<Blob>,
126        /// A predicted FunctionCall returned from the model that contains a string representing the [FunctionDeclaration.name]
127        /// with the arguments and their values.
128        function_call: Option<FunctionCall>,
129        /// The result output of a FunctionCall that contains a string representing the [FunctionDeclaration.name] and a structured
130        /// JSON object containing any output from the function is used as context to the model.
131        function_response: Option<FunctionResponse>,
132        /// URI based data.
133        file_data: Option<FileData>,
134        /// Code generated by the model that is meant to be executed.
135        executable_code: Option<ExecutableCode>,
136        /// Result of executing the ExecutableCode.
137        code_execution_result: Option<CodeExecutionResult>,
138    }
139
140    /// Raw media bytes.
141    /// Text should not be sent as raw bytes, use the 'text' field.
142    #[derive(Serialize)]
143    pub struct Blob {
144        /// Raw bytes for media formats.A base64-encoded string.
145        data: String,
146        /// The IANA standard MIME type of the source data. Examples: - image/png - image/jpeg If an unsupported MIME type is
147        /// provided, an error will be returned. For a complete list of supported types, see Supported file formats.
148        mime_type: String,
149    }
150
151    #[derive(Serialize)]
152    pub struct FunctionCall {
153        /// The name of the function to call. Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length of 63.
154        name: String,
155        /// The function parameters and values in JSON object format.
156        args: Option<Value>,
157    }
158
159    #[derive(Serialize)]
160    pub struct FunctionResponse {
161        /// The name of the function to call. Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length of 63.
162        name: String,
163        /// The result of the function call in JSON object format.
164        result: Value,
165    }
166
167    #[derive(Serialize)]
168    #[serde(rename_all = "camelCase")]
169    pub struct FileData {
170        /// The URI of the file.
171        file_uri: String,
172        /// The IANA standard MIME type of the source data.
173        mime_type: String,
174    }
175
176    #[derive(Serialize)]
177    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
178    pub enum TaskType {
179        /// Unset value, which will default to one of the other enum values.
180        Unspecified,
181        /// Specifies the given text is a query in a search/retrieval setting.
182        RetrievalQuery,
183        /// Specifies the given text is a document from the corpus being searched.
184        RetrievalDocument,
185        /// Specifies the given text will be used for STS.
186        SemanticSimilarity,
187        /// Specifies that the given text will be classified.
188        Classification,
189        /// Specifies that the embeddings will be used for clustering.
190        Clustering,
191        /// Specifies that the given text will be used for question answering.
192        QuestionAnswering,
193        /// Specifies that the given text will be used for fact verification.
194        FactVerification,
195    }
196
197    #[derive(Debug, Deserialize)]
198    pub struct EmbeddingResponse {
199        pub embedding: EmbeddingValues,
200    }
201
202    #[derive(Debug, Deserialize)]
203    pub struct EmbeddingValues {
204        pub values: Vec<f64>,
205    }
206}