//! `chat_gemini/api/embedding.rs` — [`EmbeddingsProvider`] implementation for the Gemini client.

use chat_core::error::{ChatError, ChatFailure};
use chat_core::traits::EmbeddingsProvider;
use chat_core::transport::Transport;
use chat_core::types::messages::Messages;
use chat_core::types::response::EmbeddingsResponse;

use crate::api::types::error::handle_gemini_error;
use crate::api::types::request::GeminiEmbeddingRequest;
use crate::api::types::response::GeminiEmbeddingResponse;
use crate::client::GeminiClient;

11#[async_trait::async_trait]
12impl<T: Transport> EmbeddingsProvider for GeminiClient<T> {
13    async fn embed(&self, messages: &mut Messages) -> Result<EmbeddingsResponse, ChatFailure> {
14        let path = format!(
15            "{}/models/{}:embedContent",
16            self.base_path, self.model_name
17        );
18
19        let request_body =
20            GeminiEmbeddingRequest::from_core(messages, self.embeddings_config.as_ref())
21                .map_err(ChatFailure::from_err)?;
22
23        let body = serde_json::to_vec(&request_body)
24            .map_err(|e| ChatFailure::from_err(ChatError::InvalidResponse(e.to_string())))?;
25
26        let req = chat_core::transport::Request {
27            scheme: self.scheme.clone(),
28            host: self.host.clone(),
29            path,
30            headers: vec![
31                ("x-goog-api-key".into(), self.api_key.clone()),
32                ("Content-Type".into(), "application/json".into()),
33            ],
34            body,
35        };
36
37        let res = self
38            .transport
39            .send(req)
40            .await
41            .map_err(ChatFailure::from_err)?;
42
43        let res = handle_gemini_error(res)?;
44
45        let gemini_data: GeminiEmbeddingResponse = serde_json::from_slice(&res.body)
46            .map_err(|e| ChatFailure::from_err(ChatError::InvalidResponse(e.to_string())))?;
47
48        gemini_data
49            .into_core_embeddings_response()
50            .map_err(ChatFailure::from_err)
51    }
52}