1use std::pin::Pin;
2use std::sync::Arc;
3
4use serde::{Deserialize, Serialize};
5
6use crate::error::{Error, Result};
7
8use super::backend::EmbeddingBackend;
9use super::config::GeminiConfig;
10use super::convert::to_f32_blob;
11
/// Root of the Generative Language REST API (`v1beta`).
///
/// Request URLs are built as `{BASE_URL}/models/{model}:embedContent`, so this
/// value intentionally carries no trailing slash.
const BASE_URL: &str = "https://generativelanguage.googleapis.com/v1beta";
13
14struct Inner {
15 client: reqwest::Client,
16 api_key: String,
17 model: String,
18 dimensions: usize,
19}
20
21pub struct GeminiEmbedding(Arc<Inner>);
33
34impl Clone for GeminiEmbedding {
35 fn clone(&self) -> Self {
36 Self(Arc::clone(&self.0))
37 }
38}
39
40impl GeminiEmbedding {
41 pub fn new(client: reqwest::Client, config: &GeminiConfig) -> Result<Self> {
47 config.validate()?;
48 Ok(Self(Arc::new(Inner {
49 client,
50 api_key: config.api_key.clone(),
51 model: config.model.clone(),
52 dimensions: config.dimensions,
53 })))
54 }
55}
56
57impl EmbeddingBackend for GeminiEmbedding {
58 fn embed(&self, input: &str) -> Pin<Box<dyn Future<Output = Result<Vec<u8>>> + Send + '_>> {
59 let input = input.to_owned();
60 Box::pin(async move {
61 let url = format!("{BASE_URL}/models/{}:embedContent", self.0.model,);
62 let body = Request {
63 content: Content {
64 parts: vec![Part { text: &input }],
65 },
66 output_dimensionality: self.0.dimensions,
67 };
68
69 let resp = self
70 .0
71 .client
72 .post(&url)
73 .header("x-goog-api-key", &self.0.api_key)
74 .json(&body)
75 .send()
76 .await
77 .map_err(|e| Error::internal("gemini embeddings request failed").chain(e))?;
78
79 if !resp.status().is_success() {
80 let status = resp.status();
81 let text = resp.text().await.unwrap_or_default();
82 return Err(Error::internal(format!(
83 "gemini embedding error: {status}: {text}"
84 )));
85 }
86
87 let parsed: Response = resp.json().await.map_err(|e| {
88 Error::internal("failed to parse gemini embedding response").chain(e)
89 })?;
90
91 Ok(to_f32_blob(&parsed.embedding.values))
92 })
93 }
94
95 fn dimensions(&self) -> usize {
96 self.0.dimensions
97 }
98
99 fn model_name(&self) -> &str {
100 &self.0.model
101 }
102}
103
104#[derive(Serialize)]
105#[serde(rename_all = "camelCase")]
106struct Request<'a> {
107 content: Content<'a>,
108 output_dimensionality: usize,
109}
110
111#[derive(Serialize)]
112struct Content<'a> {
113 parts: Vec<Part<'a>>,
114}
115
116#[derive(Serialize)]
117struct Part<'a> {
118 text: &'a str,
119}
120
121#[derive(Deserialize)]
122struct Response {
123 embedding: EmbeddingValues,
124}
125
126#[derive(Deserialize)]
127struct EmbeddingValues {
128 values: Vec<f32>,
129}