Skip to main content

modo/embed/
gemini.rs

1use std::pin::Pin;
2use std::sync::Arc;
3
4use serde::{Deserialize, Serialize};
5
6use crate::error::{Error, Result};
7
8use super::backend::EmbeddingBackend;
9use super::config::GeminiConfig;
10use super::convert::to_f32_blob;
11
12const BASE_URL: &str = "https://generativelanguage.googleapis.com/v1beta";
13
14struct Inner {
15    client: reqwest::Client,
16    api_key: String,
17    model: String,
18    dimensions: usize,
19}
20
21/// Google Gemini embedding provider.
22///
23/// Calls the Gemini `embedContent` API and returns a little-endian f32 blob.
24///
25/// # Example
26///
27/// ```rust,ignore
28/// let client = reqwest::Client::new();
29/// let provider = GeminiEmbedding::new(client, &config)?;
30/// let embedder = EmbeddingProvider::new(provider);
31/// ```
32pub struct GeminiEmbedding(Arc<Inner>);
33
34impl Clone for GeminiEmbedding {
35    fn clone(&self) -> Self {
36        Self(Arc::clone(&self.0))
37    }
38}
39
40impl GeminiEmbedding {
41    /// Create from config. Validates config at construction.
42    ///
43    /// # Errors
44    ///
45    /// Returns `Error::bad_request` if config validation fails.
46    pub fn new(client: reqwest::Client, config: &GeminiConfig) -> Result<Self> {
47        config.validate()?;
48        Ok(Self(Arc::new(Inner {
49            client,
50            api_key: config.api_key.clone(),
51            model: config.model.clone(),
52            dimensions: config.dimensions,
53        })))
54    }
55}
56
57impl EmbeddingBackend for GeminiEmbedding {
58    fn embed(&self, input: &str) -> Pin<Box<dyn Future<Output = Result<Vec<u8>>> + Send + '_>> {
59        let input = input.to_owned();
60        Box::pin(async move {
61            let url = format!("{BASE_URL}/models/{}:embedContent", self.0.model,);
62            let body = Request {
63                content: Content {
64                    parts: vec![Part { text: &input }],
65                },
66                output_dimensionality: self.0.dimensions,
67            };
68
69            let resp = self
70                .0
71                .client
72                .post(&url)
73                .header("x-goog-api-key", &self.0.api_key)
74                .json(&body)
75                .send()
76                .await
77                .map_err(|e| Error::internal("gemini embeddings request failed").chain(e))?;
78
79            if !resp.status().is_success() {
80                let status = resp.status();
81                let text = resp.text().await.unwrap_or_default();
82                return Err(Error::internal(format!(
83                    "gemini embedding error: {status}: {text}"
84                )));
85            }
86
87            let parsed: Response = resp.json().await.map_err(|e| {
88                Error::internal("failed to parse gemini embedding response").chain(e)
89            })?;
90
91            Ok(to_f32_blob(&parsed.embedding.values))
92        })
93    }
94
95    fn dimensions(&self) -> usize {
96        self.0.dimensions
97    }
98
99    fn model_name(&self) -> &str {
100        &self.0.model
101    }
102}
103
104#[derive(Serialize)]
105#[serde(rename_all = "camelCase")]
106struct Request<'a> {
107    content: Content<'a>,
108    output_dimensionality: usize,
109}
110
111#[derive(Serialize)]
112struct Content<'a> {
113    parts: Vec<Part<'a>>,
114}
115
116#[derive(Serialize)]
117struct Part<'a> {
118    text: &'a str,
119}
120
121#[derive(Deserialize)]
122struct Response {
123    embedding: EmbeddingValues,
124}
125
126#[derive(Deserialize)]
127struct EmbeddingValues {
128    values: Vec<f32>,
129}