Skip to main content

leann_core/embedding/
gemini.rs

1use anyhow::Result;
2use ndarray::Array2;
3use tracing::info;
4
5use super::EmbeddingProvider;
6use crate::settings;
7
8/// Gemini API embedding provider.
9pub struct GeminiEmbedding {
10    model: String,
11    api_key: String,
12    client: reqwest::blocking::Client,
13    dimensions: usize,
14}
15
16impl GeminiEmbedding {
17    pub fn new(model: &str, api_key: Option<&str>) -> Result<Self> {
18        let api_key = settings::resolve_gemini_api_key(api_key).ok_or_else(|| {
19            anyhow::anyhow!("Gemini API key required (set GOOGLE_API_KEY or GEMINI_API_KEY)")
20        })?;
21
22        Ok(Self {
23            model: model.to_string(),
24            api_key,
25            client: reqwest::blocking::Client::new(),
26            dimensions: 768, // Default, will be updated on first call
27        })
28    }
29}
30
31impl EmbeddingProvider for GeminiEmbedding {
32    fn compute_embeddings(&self, chunks: &[String]) -> Result<Array2<f32>> {
33        if chunks.is_empty() {
34            return Ok(Array2::zeros((0, self.dimensions)));
35        }
36
37        // Gemini limits to 100 requests per batch call (matches Python's max_batch_size=100)
38        let max_batch_size = 100;
39        let url = format!(
40            "https://generativelanguage.googleapis.com/v1beta/models/{}:batchEmbedContents?key={}",
41            self.model, self.api_key
42        );
43
44        let mut all_data: Vec<f32> = Vec::new();
45        let mut dim: Option<usize> = None;
46        let num_batches = chunks.len().div_ceil(max_batch_size);
47
48        for (i, batch) in chunks.chunks(max_batch_size).enumerate() {
49            info!(
50                "Gemini embedding batch {}/{} ({} chunks)",
51                i + 1,
52                num_batches,
53                batch.len()
54            );
55            let requests: Vec<serde_json::Value> = batch
56                .iter()
57                .map(|text| {
58                    serde_json::json!({
59                        "model": format!("models/{}", self.model),
60                        "content": {
61                            "parts": [{"text": text}]
62                        }
63                    })
64                })
65                .collect();
66
67            let payload = serde_json::json!({
68                "requests": requests,
69            });
70
71            let response = self.client.post(&url).json(&payload).send()?;
72
73            if !response.status().is_success() {
74                let status = response.status();
75                let body = response.text().unwrap_or_default();
76                anyhow::bail!("Gemini API error ({}): {}", status, body);
77            }
78
79            let body: serde_json::Value = response.json()?;
80
81            let embeddings_array = body["embeddings"]
82                .as_array()
83                .ok_or_else(|| anyhow::anyhow!("Missing 'embeddings' in Gemini response"))?;
84
85            if embeddings_array.is_empty() {
86                anyhow::bail!("Empty embeddings response from Gemini");
87            }
88
89            if dim.is_none() {
90                let first_values = embeddings_array[0]["values"]
91                    .as_array()
92                    .ok_or_else(|| anyhow::anyhow!("Missing 'values' in embedding"))?;
93                dim = Some(first_values.len());
94            }
95
96            for emb in embeddings_array {
97                let values = emb["values"]
98                    .as_array()
99                    .ok_or_else(|| anyhow::anyhow!("Missing 'values' in embedding"))?;
100                for v in values {
101                    all_data.push(v.as_f64().unwrap_or(0.0) as f32);
102                }
103            }
104        }
105
106        let d = dim.ok_or_else(|| anyhow::anyhow!("No embeddings returned from Gemini"))?;
107        Ok(Array2::from_shape_vec((chunks.len(), d), all_data)?)
108    }
109
110    fn dimensions(&self) -> usize {
111        self.dimensions
112    }
113
114    fn name(&self) -> &str {
115        "gemini"
116    }
117}