leann_core/embedding/
gemini.rs1use anyhow::Result;
2use ndarray::Array2;
3use tracing::info;
4
5use super::EmbeddingProvider;
6use crate::settings;
7
8pub struct GeminiEmbedding {
10 model: String,
11 api_key: String,
12 client: reqwest::blocking::Client,
13 dimensions: usize,
14}
15
16impl GeminiEmbedding {
17 pub fn new(model: &str, api_key: Option<&str>) -> Result<Self> {
18 let api_key = settings::resolve_gemini_api_key(api_key).ok_or_else(|| {
19 anyhow::anyhow!("Gemini API key required (set GOOGLE_API_KEY or GEMINI_API_KEY)")
20 })?;
21
22 Ok(Self {
23 model: model.to_string(),
24 api_key,
25 client: reqwest::blocking::Client::new(),
26 dimensions: 768, })
28 }
29}
30
31impl EmbeddingProvider for GeminiEmbedding {
32 fn compute_embeddings(&self, chunks: &[String]) -> Result<Array2<f32>> {
33 if chunks.is_empty() {
34 return Ok(Array2::zeros((0, self.dimensions)));
35 }
36
37 let max_batch_size = 100;
39 let url = format!(
40 "https://generativelanguage.googleapis.com/v1beta/models/{}:batchEmbedContents?key={}",
41 self.model, self.api_key
42 );
43
44 let mut all_data: Vec<f32> = Vec::new();
45 let mut dim: Option<usize> = None;
46 let num_batches = chunks.len().div_ceil(max_batch_size);
47
48 for (i, batch) in chunks.chunks(max_batch_size).enumerate() {
49 info!(
50 "Gemini embedding batch {}/{} ({} chunks)",
51 i + 1,
52 num_batches,
53 batch.len()
54 );
55 let requests: Vec<serde_json::Value> = batch
56 .iter()
57 .map(|text| {
58 serde_json::json!({
59 "model": format!("models/{}", self.model),
60 "content": {
61 "parts": [{"text": text}]
62 }
63 })
64 })
65 .collect();
66
67 let payload = serde_json::json!({
68 "requests": requests,
69 });
70
71 let response = self.client.post(&url).json(&payload).send()?;
72
73 if !response.status().is_success() {
74 let status = response.status();
75 let body = response.text().unwrap_or_default();
76 anyhow::bail!("Gemini API error ({}): {}", status, body);
77 }
78
79 let body: serde_json::Value = response.json()?;
80
81 let embeddings_array = body["embeddings"]
82 .as_array()
83 .ok_or_else(|| anyhow::anyhow!("Missing 'embeddings' in Gemini response"))?;
84
85 if embeddings_array.is_empty() {
86 anyhow::bail!("Empty embeddings response from Gemini");
87 }
88
89 if dim.is_none() {
90 let first_values = embeddings_array[0]["values"]
91 .as_array()
92 .ok_or_else(|| anyhow::anyhow!("Missing 'values' in embedding"))?;
93 dim = Some(first_values.len());
94 }
95
96 for emb in embeddings_array {
97 let values = emb["values"]
98 .as_array()
99 .ok_or_else(|| anyhow::anyhow!("Missing 'values' in embedding"))?;
100 for v in values {
101 all_data.push(v.as_f64().unwrap_or(0.0) as f32);
102 }
103 }
104 }
105
106 let d = dim.ok_or_else(|| anyhow::anyhow!("No embeddings returned from Gemini"))?;
107 Ok(Array2::from_shape_vec((chunks.len(), d), all_data)?)
108 }
109
110 fn dimensions(&self) -> usize {
111 self.dimensions
112 }
113
114 fn name(&self) -> &str {
115 "gemini"
116 }
117}