leann_core/embedding/
gemini.rs1use anyhow::Result;
2use ndarray::Array2;
3use tracing::info;
4
5use super::EmbeddingProvider;
6use crate::settings;
7
8pub struct GeminiEmbedding {
10 model: String,
11 api_key: String,
12 client: reqwest::blocking::Client,
13 dimensions: usize,
14}
15
16impl GeminiEmbedding {
17 pub fn new(model: &str, api_key: Option<&str>) -> Result<Self> {
18 let api_key = settings::resolve_gemini_api_key(api_key).ok_or_else(|| {
19 anyhow::anyhow!("Gemini API key required (set GOOGLE_API_KEY or GEMINI_API_KEY)")
20 })?;
21
22 Ok(Self {
23 model: model.to_string(),
24 api_key,
25 client: reqwest::blocking::Client::new(),
26 dimensions: 768, })
28 }
29}
30
31impl EmbeddingProvider for GeminiEmbedding {
32 fn compute_embeddings(
33 &self,
34 chunks: &[String],
35 _progress: Option<&dyn crate::hnsw::IndexProgress>,
36 ) -> Result<Array2<f32>> {
37 if chunks.is_empty() {
38 return Ok(Array2::zeros((0, self.dimensions)));
39 }
40
41 let max_batch_size = 100;
43 let url = format!(
44 "https://generativelanguage.googleapis.com/v1beta/models/{}:batchEmbedContents?key={}",
45 self.model, self.api_key
46 );
47
48 let mut all_data: Vec<f32> = Vec::new();
49 let mut dim: Option<usize> = None;
50 let num_batches = chunks.len().div_ceil(max_batch_size);
51
52 for (i, batch) in chunks.chunks(max_batch_size).enumerate() {
53 info!(
54 "Gemini embedding batch {}/{} ({} chunks)",
55 i + 1,
56 num_batches,
57 batch.len()
58 );
59 let requests: Vec<serde_json::Value> = batch
60 .iter()
61 .map(|text| {
62 serde_json::json!({
63 "model": format!("models/{}", self.model),
64 "content": {
65 "parts": [{"text": text}]
66 }
67 })
68 })
69 .collect();
70
71 let payload = serde_json::json!({
72 "requests": requests,
73 });
74
75 let response = self.client.post(&url).json(&payload).send()?;
76
77 if !response.status().is_success() {
78 let status = response.status();
79 let body = response.text().unwrap_or_default();
80 anyhow::bail!("Gemini API error ({}): {}", status, body);
81 }
82
83 let body: serde_json::Value = response.json()?;
84
85 let embeddings_array = body["embeddings"]
86 .as_array()
87 .ok_or_else(|| anyhow::anyhow!("Missing 'embeddings' in Gemini response"))?;
88
89 if embeddings_array.is_empty() {
90 anyhow::bail!("Empty embeddings response from Gemini");
91 }
92
93 if dim.is_none() {
94 let first_values = embeddings_array[0]["values"]
95 .as_array()
96 .ok_or_else(|| anyhow::anyhow!("Missing 'values' in embedding"))?;
97 dim = Some(first_values.len());
98 }
99
100 for emb in embeddings_array {
101 let values = emb["values"]
102 .as_array()
103 .ok_or_else(|| anyhow::anyhow!("Missing 'values' in embedding"))?;
104 for v in values {
105 all_data.push(v.as_f64().unwrap_or(0.0) as f32);
106 }
107 }
108 }
109
110 let d = dim.ok_or_else(|| anyhow::anyhow!("No embeddings returned from Gemini"))?;
111 Ok(Array2::from_shape_vec((chunks.len(), d), all_data)?)
112 }
113
114 fn dimensions(&self) -> usize {
115 self.dimensions
116 }
117
118 fn name(&self) -> &str {
119 "gemini"
120 }
121}