use std::sync::Arc;
use async_trait::async_trait;
use markhor_core::embedding::{Embedder, Embedding, EmbeddingError, EmbeddingUseCase};
use reqwest::Client;
use serde::{Deserialize, Serialize};
use tracing::{debug, error, instrument, trace, warn};
use url::Url; use secrecy::ExposeSecret;
use crate::gemini::error::map_response_error;
use super::{error::GeminiError, shared::{self, GeminiConfig, SharedGeminiClient, EXTENSION_URI}};
/// Embedding client bound to a single Gemini model.
///
/// Bundles the shared Gemini client, the `models/<name>` path segment used
/// when building request URLs, and an optional task type that is attached to
/// every embed request.
#[derive(Debug, Clone)]
pub struct GeminiEmbedder {
    /// Shared client providing URL building, the HTTP client and the config
    /// (including the API key).
    shared_client: Arc<SharedGeminiClient>,
    /// Model identifier in `models/<name>` form.
    model_path_segment: String,
    /// Optional task type string, forwarded unchanged in each embed request.
    task_type: Option<String>,
}
impl GeminiEmbedder {
    /// Builds an embedder for `model_name` with no task type, no base-URL
    /// override and no custom HTTP client.
    ///
    /// # Errors
    ///
    /// Propagates any [`GeminiError`] produced by [`Self::new_with_options`].
    pub fn new(
        api_key: impl Into<String>,
        model_name: impl Into<String>,
    ) -> Result<Self, GeminiError> {
        Self::new_with_options(api_key, model_name, None, None, None)
    }

    /// Builds an embedder together with its own [`SharedGeminiClient`].
    ///
    /// * `api_key` - passed to `GeminiConfig::new`.
    /// * `model_name` - model name without the `models/` prefix.
    /// * `task_type` - optional task type string sent with each request.
    /// * `api_base_url` - when present, applied to the config via `base_url`.
    /// * `client_override` - optional `reqwest::Client` handed to
    ///   `SharedGeminiClient::new`.
    ///
    /// # Errors
    ///
    /// Returns the [`GeminiError`] raised while creating the config, applying
    /// the base URL, creating the shared client, or validating the model name.
    pub fn new_with_options(
        api_key: impl Into<String>,
        model_name: impl Into<String>,
        task_type: Option<String>,
        api_base_url: Option<String>,
        client_override: Option<Client>,
    ) -> Result<Self, GeminiError> {
        let base_config = GeminiConfig::new(api_key)?;
        let config = match api_base_url {
            Some(custom_base) => base_config.base_url(&custom_base)?,
            None => base_config,
        };
        let client = Arc::new(SharedGeminiClient::new(config, client_override)?);
        Self::new_with_shared_client(client, model_name.into(), task_type)
    }

    /// Builds an embedder on top of an already constructed shared client.
    ///
    /// The model is stored as `models/<model_name>`. The resolved use case is
    /// only computed for the debug log line; it is not stored.
    ///
    /// # Errors
    ///
    /// Returns [`GeminiError::InvalidConfiguration`] when `model_name` is empty.
    #[instrument(name = "gemini_embedder_from_config", skip(shared_client), fields(model_name=%model_name))]
    pub fn new_with_shared_client(
        shared_client: Arc<SharedGeminiClient>,
        model_name: String,
        task_type: Option<String>,
    ) -> Result<Self, GeminiError> {
        if model_name.is_empty() {
            let reason = String::from("Model name cannot be empty");
            return Err(GeminiError::InvalidConfiguration(reason));
        }

        let use_case = map_task_type_to_use_case(task_type.as_deref());
        let model_path_segment = format!("models/{}", model_name);
        debug!(model=%model_name, task_type=?task_type, use_case=?use_case, "GeminiEmbedder created.");

        Ok(Self {
            shared_client,
            model_path_segment,
            task_type,
        })
    }

    /// Returns the URL of the `<model path>:batchEmbedContents` endpoint, as
    /// resolved by the shared client.
    fn build_batch_embed_url(&self) -> Result<Url, GeminiError> {
        let endpoint = format!("{}:batchEmbedContents", self.model_path_segment);
        self.shared_client.build_url(&endpoint)
    }
}
/// Translates an optional Gemini task type string into an [`EmbeddingUseCase`].
///
/// * `None` maps to `General`.
/// * The known task type names map to their dedicated variants
///   (`SEMANTIC_SIMILARITY` and `SIMILARITY` both map to `Similarity`).
/// * Any other string starting with `CODE_` maps to `CodeRetrievalQuery`.
/// * Everything else is preserved verbatim in `Other`.
///
/// Matching is case-sensitive.
pub fn map_task_type_to_use_case(task_type: Option<&str>) -> EmbeddingUseCase {
    let tag = match task_type {
        Some(tag) => tag,
        None => return EmbeddingUseCase::General,
    };

    match tag {
        "RETRIEVAL_QUERY" => EmbeddingUseCase::RetrievalQuery,
        "RETRIEVAL_DOCUMENT" => EmbeddingUseCase::RetrievalDocument,
        "SEMANTIC_SIMILARITY" | "SIMILARITY" => EmbeddingUseCase::Similarity,
        "CLASSIFICATION" => EmbeddingUseCase::Classification,
        "CLUSTERING" => EmbeddingUseCase::Clustering,
        "QUESTION_ANSWERING" => EmbeddingUseCase::QuestionAnswering,
        "FACT_VERIFICATION" => EmbeddingUseCase::FactVerification,
        code_task if code_task.starts_with("CODE_") => EmbeddingUseCase::CodeRetrievalQuery,
        unrecognised => EmbeddingUseCase::Other(unrecognised.to_string()),
    }
}
/// Maximum number of texts accepted by a single `embed` call; larger inputs
/// are rejected with `GeminiError::BatchTooLarge`. The same value is reported
/// by `max_batch_size_hint`.
const BATCH_LIMIT: usize = 100;
#[async_trait]
impl Embedder for GeminiEmbedder {
#[instrument(skip(self, texts), fields(model=%self.model_name(), num_texts=texts.len()))]
async fn embed(&self, texts: &[&str]) -> Result<Vec<Embedding>, EmbeddingError> {
async {
if texts.is_empty() {
debug!("Input texts slice is empty, returning empty embeddings.");
return Ok(vec![]);
}
if texts.len() > BATCH_LIMIT {
error!(requested = texts.len(), limit = BATCH_LIMIT, "Batch size exceeds limit");
return Err(GeminiError::BatchTooLarge {
limit: Some(BATCH_LIMIT),
actual: texts.len(),
});
}
let url = self.build_batch_embed_url()?; debug!(%url, "Sending batch embed request to Gemini");
let requests: Vec<GeminiEmbedRequest> = texts
.iter()
.map(|text| GeminiEmbedRequest {
model: &self.model_path_segment, content: GeminiContent {
parts: vec![GeminiPart { text: text }],
},
task_type: self.task_type.as_deref(),
})
.collect();
let request_body = GeminiBatchRequest { requests };
let request_json = serde_json::to_string(&request_body)
.map_err(|e| {
error!(error = %e, "Failed to serialize Gemini embed request body");
GeminiError::RequestSerialization(e)
})?;
trace!(body = %request_json, "Constructed Gemini embed request body JSON");
let response = self.shared_client.http_client()
.post(url)
.header("x-goog-api-key", self.shared_client.config().api_key.expose_secret()) .header("Content-Type", "application/json")
.body(request_json)
.send()
.await
.map_err(GeminiError::Network)?;
if !response.status().is_success() {
let status = response.status();
error!(%status, "Gemini embed API returned error status");
return Err(map_response_error(response).await);
}
let status = response.status();
debug!(%status, "Received successful response for embed request");
let raw_body = response.text()
.await
.map_err(|e| {
error!(error = %e, "Failed to read successful response body for embed");
GeminiError::Network(e)
})?;
trace!(body = %raw_body, "Received Gemini embed response body");
let response_data: GeminiBatchResponse = serde_json::from_str(&raw_body)
.map_err(|e| {
error!(parse_error = %e, raw_body = %raw_body, "Failed to parse Gemini embed response JSON");
GeminiError::ResponseParsing {
context: "Parsing batch embed response".to_string(),
source: e,
}
})?;
if response_data.embeddings.len() != texts.len() {
let msg = format!(
"API returned {} embeddings, but expected {}",
response_data.embeddings.len(), texts.len()
);
error!(message = %msg, "Mismatch between input text count and received embeddings count");
return Err(GeminiError::UnexpectedResponse(msg));
}
debug!("Successfully parsed Gemini embed response, received {} embeddings.", response_data.embeddings.len());
let embeddings_vec = response_data.embeddings
.into_iter()
.map(|e| Embedding::from(e.values)) .collect();
Ok(embeddings_vec)
}
.await
.map_err(|err| { err.into() })
}
fn dimensions(&self) -> Option<usize> {
match self.model_name() {
"embedding-001" => Some(768),
"text-embedding-004" => Some(768),
_ => {
warn!(model = %self.model_name(), "Unknown Gemini embedding model, dimensions not set.");
None
}
}
}
fn model_name(&self) -> &str {
&self.model_path_segment[7..] }
fn intended_use_case(&self) -> EmbeddingUseCase {
match self.task_type.as_ref().map(|s| s.as_str()) {
Some("RETRIEVAL_QUERY") => EmbeddingUseCase::RetrievalQuery,
Some("RETRIEVAL_DOCUMENT") => EmbeddingUseCase::RetrievalDocument,
Some("SEMANTIC_SIMILARITY") | Some("SIMILARITY") => EmbeddingUseCase::Similarity,
Some("CLASSIFICATION") => EmbeddingUseCase::Classification,
Some("CLUSTERING") => EmbeddingUseCase::Clustering,
Some("QUESTION_ANSWERING") => EmbeddingUseCase::QuestionAnswering,
Some("FACT_VERIFICATION") => EmbeddingUseCase::FactVerification,
Some(other) if other.starts_with("CODE_") => EmbeddingUseCase::CodeRetrievalQuery, Some(other) => EmbeddingUseCase::Other(other.to_string()),
None => EmbeddingUseCase::General, }
}
fn max_batch_size_hint(&self) -> Option<usize> {
Some(BATCH_LIMIT)
}
fn max_chunk_length_hint(&self) -> Option<usize> {
match self.model_name() {
"embedding-001" => Some(8000),
"text-embedding-004" => Some(8000),
"gemini-embedding-exp-03-07" => Some(32000),
_ => {
warn!(model = %self.model_name(), "Unknown Gemini embedding model, max chunk length hint not set.");
None
}
}
}
}
/// Serializable body of a batch embed request: a list of individual embed
/// requests under the `requests` key.
#[derive(Serialize, Debug)]
pub struct GeminiBatchRequest<'a> {
/// Individual embed requests; `embed` builds one per input text.
pub requests: Vec<GeminiEmbedRequest<'a>>,
}
/// One entry of a batch embed request.
///
/// Field names are serialized in camelCase, so `task_type` is emitted as
/// `taskType`; it is omitted entirely when `None`.
#[derive(Serialize, Debug)]
#[serde(rename_all = "camelCase")]
pub struct GeminiEmbedRequest<'a> {
    /// Model identifier; `embed` passes the `models/<name>` path segment.
    pub model: &'a str,
    /// Content to embed.
    pub content: GeminiContent<'a>,
    /// Optional task type string; skipped during serialization when absent.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub task_type: Option<&'a str>,
}
/// Serializable content wrapper holding a list of parts.
#[derive(Serialize, Debug)]
pub struct GeminiContent<'a> {
/// Parts of the content; `embed` supplies exactly one text part per input.
pub parts: Vec<GeminiPart<'a>>,
}
/// Serializable content part carrying borrowed text.
#[derive(Serialize, Debug)]
pub struct GeminiPart<'a> {
/// The text to embed, borrowed from the caller's input.
pub text: &'a str,
}
/// Deserialized body of a successful batch embed response.
#[derive(Deserialize, Debug)]
pub struct GeminiBatchResponse {
/// Returned embeddings; `embed` requires the count to equal the number of inputs.
pub embeddings: Vec<GeminiEmbeddingValue>,
}
/// A single embedding entry of a batch embed response.
#[derive(Deserialize, Debug)]
pub struct GeminiEmbeddingValue {
/// Raw vector components; `embed` converts them with `Embedding::from`.
pub values: Vec<f32>,
}
/// Deserializable wrapper around a single `error` object.
///
/// NOTE(review): presumably the error envelope returned by the Gemini API; it
/// is not referenced by the code visible here — confirm where it is consumed
/// (e.g. `map_response_error`).
#[derive(Deserialize, Debug)]
pub struct GeminiApiErrorResponse {
/// The error payload.
pub error: GeminiApiErrorDetail,
}
/// Deserializable error payload with a numeric code, a message and a status
/// string.
///
/// NOTE(review): the exact meaning of `code` and `status` is defined by the
/// remote API and is not established by the code visible here — confirm
/// against the API error documentation.
#[derive(Deserialize, Debug)]
pub struct GeminiApiErrorDetail {
    /// Numeric error code.
    pub code: i32,
    /// Error message text.
    pub message: String,
    /// Status string.
    pub status: String,
}
/// Bundle of optional embedder settings; every field defaults to `None`.
///
/// `Default` is derived rather than hand-written: the derived implementation
/// sets each `Option` field to `None`, which is exactly what the manual impl
/// did.
///
/// NOTE(review): the field names parallel the optional arguments of
/// `GeminiEmbedder::new_with_options` (`task_type`, `client_override`,
/// `api_base_url`), but no code visible here consumes this struct — confirm
/// the intended call site.
#[derive(Debug, Clone, Default)]
pub struct GeminiEmbedderOptions {
    /// Optional task type string.
    pub task_type: Option<String>,
    /// Optional pre-built HTTP client.
    pub client: Option<reqwest::Client>,
    /// Optional API base URL.
    pub api_base_url: Option<String>,
}