// leann_core/embedding/gemini.rs

use std::sync::atomic::{AtomicUsize, Ordering};

use anyhow::Result;
use ndarray::Array2;
use tracing::info;

use super::EmbeddingProvider;
use crate::settings;

8/// Gemini API embedding provider.
9pub struct GeminiEmbedding {
10    model: String,
11    api_key: String,
12    client: reqwest::blocking::Client,
13    dimensions: usize,
14}
15
16impl GeminiEmbedding {
17    pub fn new(model: &str, api_key: Option<&str>) -> Result<Self> {
18        let api_key = settings::resolve_gemini_api_key(api_key).ok_or_else(|| {
19            anyhow::anyhow!("Gemini API key required (set GOOGLE_API_KEY or GEMINI_API_KEY)")
20        })?;
21
22        Ok(Self {
23            model: model.to_string(),
24            api_key,
25            client: reqwest::blocking::Client::new(),
26            dimensions: 768, // Default, will be updated on first call
27        })
28    }
29}
30
31impl EmbeddingProvider for GeminiEmbedding {
32    fn compute_embeddings(
33        &self,
34        chunks: &[String],
35        _progress: Option<&dyn crate::hnsw::IndexProgress>,
36    ) -> Result<Array2<f32>> {
37        if chunks.is_empty() {
38            return Ok(Array2::zeros((0, self.dimensions)));
39        }
40
41        // Gemini limits to 100 requests per batch call (matches Python's max_batch_size=100)
42        let max_batch_size = 100;
43        let url = format!(
44            "https://generativelanguage.googleapis.com/v1beta/models/{}:batchEmbedContents?key={}",
45            self.model, self.api_key
46        );
47
48        let mut all_data: Vec<f32> = Vec::new();
49        let mut dim: Option<usize> = None;
50        let num_batches = chunks.len().div_ceil(max_batch_size);
51
52        for (i, batch) in chunks.chunks(max_batch_size).enumerate() {
53            info!(
54                "Gemini embedding batch {}/{} ({} chunks)",
55                i + 1,
56                num_batches,
57                batch.len()
58            );
59            let requests: Vec<serde_json::Value> = batch
60                .iter()
61                .map(|text| {
62                    serde_json::json!({
63                        "model": format!("models/{}", self.model),
64                        "content": {
65                            "parts": [{"text": text}]
66                        }
67                    })
68                })
69                .collect();
70
71            let payload = serde_json::json!({
72                "requests": requests,
73            });
74
75            let response = self.client.post(&url).json(&payload).send()?;
76
77            if !response.status().is_success() {
78                let status = response.status();
79                let body = response.text().unwrap_or_default();
80                anyhow::bail!("Gemini API error ({}): {}", status, body);
81            }
82
83            let body: serde_json::Value = response.json()?;
84
85            let embeddings_array = body["embeddings"]
86                .as_array()
87                .ok_or_else(|| anyhow::anyhow!("Missing 'embeddings' in Gemini response"))?;
88
89            if embeddings_array.is_empty() {
90                anyhow::bail!("Empty embeddings response from Gemini");
91            }
92
93            if dim.is_none() {
94                let first_values = embeddings_array[0]["values"]
95                    .as_array()
96                    .ok_or_else(|| anyhow::anyhow!("Missing 'values' in embedding"))?;
97                dim = Some(first_values.len());
98            }
99
100            for emb in embeddings_array {
101                let values = emb["values"]
102                    .as_array()
103                    .ok_or_else(|| anyhow::anyhow!("Missing 'values' in embedding"))?;
104                for v in values {
105                    all_data.push(v.as_f64().unwrap_or(0.0) as f32);
106                }
107            }
108        }
109
110        let d = dim.ok_or_else(|| anyhow::anyhow!("No embeddings returned from Gemini"))?;
111        Ok(Array2::from_shape_vec((chunks.len(), d), all_data)?)
112    }
113
114    fn dimensions(&self) -> usize {
115        self.dimensions
116    }
117
118    fn name(&self) -> &str {
119        "gemini"
120    }
121}