use crate::core::types::{EmbeddingOptions, EmbeddingResult};
use anyhow::anyhow;
use async_trait::async_trait;
use reqwest::Client;
use serde::{Deserialize, Serialize};
pub struct GoogleEmbeddingModel {
pub api_key: String,
pub base_url: String,
pub client: Client,
}
impl GoogleEmbeddingModel {
#[must_use]
pub fn new(api_key: String) -> Self {
Self {
api_key,
base_url: "https://generativelanguage.googleapis.com/v1beta".to_string(),
client: Client::new(),
}
}
}
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct GoogleEmbedContentRequest {
model: String,
content: GoogleEmbedContent,
#[serde(skip_serializing_if = "Option::is_none")]
output_dimensionality: Option<u32>,
}
#[derive(Serialize)]
struct GoogleEmbedContent {
parts: Vec<GoogleEmbedPart>,
}
#[derive(Serialize)]
struct GoogleEmbedPart {
text: String,
}
#[derive(Deserialize)]
struct GoogleSingleEmbeddingResponse {
embedding: GoogleEmbeddingValues,
}
#[derive(Deserialize)]
struct GoogleEmbeddingValues {
values: Vec<f32>,
}
#[derive(Serialize)]
struct GoogleBatchEmbedRequest {
requests: Vec<GoogleBatchEmbedItem>,
}
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct GoogleBatchEmbedItem {
model: String,
content: GoogleEmbedContent,
#[serde(skip_serializing_if = "Option::is_none")]
output_dimensionality: Option<u32>,
}
#[derive(Deserialize)]
struct GoogleBatchEmbeddingResponse {
embeddings: Vec<GoogleEmbeddingValues>,
}
#[async_trait]
impl crate::core::EmbeddingModel for GoogleEmbeddingModel {
async fn embed(
&self,
values: Vec<String>,
options: EmbeddingOptions,
) -> crate::core::Result<EmbeddingResult> {
if values.len() == 1 {
let request = GoogleEmbedContentRequest {
model: format!("models/{}", options.model_id),
content: GoogleEmbedContent {
parts: vec![GoogleEmbedPart {
text: values[0].clone(),
}],
},
output_dimensionality: options.dimensions,
};
let url = format!(
"{}/models/{}:embedContent?key={}",
self.base_url, options.model_id, self.api_key
);
let response = self.client.post(&url).json(&request).send().await?;
if !response.status().is_success() {
let error_text = response.text().await?;
return Err(anyhow!("Google Embedding API error: {error_text}").into());
}
let resp: GoogleSingleEmbeddingResponse = response.json().await?;
Ok(EmbeddingResult {
embeddings: vec![resp.embedding.values],
usage: None,
})
} else {
let request = GoogleBatchEmbedRequest {
requests: values
.iter()
.map(|v| GoogleBatchEmbedItem {
model: format!("models/{}", options.model_id),
content: GoogleEmbedContent {
parts: vec![GoogleEmbedPart { text: v.clone() }],
},
output_dimensionality: options.dimensions,
})
.collect(),
};
let url = format!(
"{}/models/{}:batchEmbedContents?key={}",
self.base_url, options.model_id, self.api_key
);
let response = self.client.post(&url).json(&request).send().await?;
if !response.status().is_success() {
let error_text = response.text().await?;
return Err(anyhow!("Google Embedding API error: {error_text}").into());
}
let resp: GoogleBatchEmbeddingResponse = response.json().await?;
Ok(EmbeddingResult {
embeddings: resp.embeddings.into_iter().map(|e| e.values).collect(),
usage: None,
})
}
}
}