use serde_json::json;
use crate::embeddings::{self, EmbeddingError};
use super::{client::ApiResponse, Client};
/// Model name of the Gemini `embedding-001` embedding model.
pub const EMBEDDING_001: &str = "embedding-001";
/// Model name of the Gemini `text-embedding-004` embedding model.
pub const EMBEDDING_004: &str = "text-embedding-004";
/// Handle for one embedding model, pairing a [`Client`] with a model name.
#[derive(Clone)]
pub struct EmbeddingModel {
/// Client used to send the embedding request.
client: Client,
/// Model name; interpolated into both the request path and the request body.
model: String,
/// Optional output dimension; when set it is sent as `output_dimensionality`.
ndims: Option<usize>,
}
impl EmbeddingModel {
    /// Builds a model handle from a client, a model name and an optional
    /// output-dimension override.
    ///
    /// The model name is copied into an owned `String`; `client` and `ndims`
    /// are stored as given.
    pub fn new(client: Client, model: &str, ndims: Option<usize>) -> Self {
        let model = model.to_owned();
        Self {
            client,
            model,
            ndims,
        }
    }
}
impl embeddings::EmbeddingModel for EmbeddingModel {
    /// Value supplied for the trait's `MAX_DOCUMENTS` constant
    /// (presumably the per-call document limit; the trait is defined elsewhere).
    const MAX_DOCUMENTS: usize = 1024;

    /// Returns the length of the vectors this model instance produces.
    ///
    /// A configured `ndims` takes precedence, because it is sent to the API as
    /// `output_dimensionality` and therefore determines the length of the
    /// returned vectors. Otherwise the built-in default for the known models
    /// is used, and `0` is returned for unknown models.
    fn ndims(&self) -> usize {
        self.ndims.unwrap_or_else(|| match self.model.as_str() {
            EMBEDDING_001 => 768,
            // The public model documentation lists 768 output dimensions for
            // `text-embedding-004`.
            EMBEDDING_004 => 768,
            _ => 0,
        })
    }

    /// Embeds `documents` with a single `embedContent` request.
    ///
    /// Every document becomes one `text` part of the request content. The
    /// flat list of values in the response is split into consecutive slices
    /// of `ndims()` values, and each slice is paired with the document at the
    /// same position.
    ///
    /// # Errors
    ///
    /// * Transport, HTTP-status and decoding failures are propagated via `?`.
    /// * `EmbeddingError::ProviderError` when the API answers with an error
    ///   payload, or when the embedding dimension is zero/unknown so the
    ///   response cannot be split (previously this case panicked).
    #[cfg_attr(feature = "worker", worker::send)]
    async fn embed_texts(
        &self,
        documents: impl IntoIterator<Item = String> + Send,
    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
        let documents: Vec<_> = documents.into_iter().collect();

        let mut request_body = json!({
            "model": format!("models/{}", self.model),
            "content": {
                "parts": documents.iter().map(|doc| json!({ "text": doc })).collect::<Vec<_>>(),
            },
        });

        // Only request a specific output size when the caller configured one.
        if let Some(ndims) = self.ndims {
            request_body["output_dimensionality"] = json!(ndims);
        }

        let response = self
            .client
            .post(&format!("/v1beta/models/{}:embedContent", self.model))
            .json(&request_body)
            .send()
            .await?
            .error_for_status()?
            .json::<ApiResponse<gemini_api_types::EmbeddingResponse>>()
            .await?;

        match response {
            ApiResponse::Ok(response) => {
                let chunk_size = self.ndims();

                // `slice::chunks` panics on a chunk size of 0, which happens for
                // an unknown model without a configured dimension (or an explicit
                // `Some(0)`). Report it as an error instead of aborting.
                if chunk_size == 0 {
                    return Err(EmbeddingError::ProviderError(format!(
                        "cannot split embedding values for model `{}`: the embedding \
                         dimension is unknown or zero; configure a non-zero `ndims`",
                        self.model
                    )));
                }

                // NOTE(review): `zip` stops at the shorter side, so a mismatch
                // between the number of documents and the number of slices is
                // silently truncated.
                Ok(documents
                    .into_iter()
                    .zip(response.embedding.values.chunks(chunk_size))
                    .map(|(document, embedding)| embeddings::Embedding {
                        document,
                        vec: embedding.to_vec(),
                    })
                    .collect())
            }
            ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
        }
    }
}
/// Wire types for the Gemini `embedContent` endpoint.
///
/// Within this file only [`EmbeddingResponse`] and [`EmbeddingValues`] are
/// referenced (to decode the response); the request is assembled with `json!`.
// The request-side types are kept as a typed description of the request shape
// but are not constructed here, hence the lint exemption.
#[allow(dead_code)]
mod gemini_api_types {
    use serde::{Deserialize, Serialize};
    use serde_json::Value;

    use crate::providers::gemini::gemini_api_types::{CodeExecutionResult, ExecutableCode};

    /// Typed form of an `embedContent` request body (fields serialized in camelCase).
    #[derive(Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct EmbedContentRequest {
        model: String,
        content: EmbeddingContent,
        task_type: TaskType,
        title: String,
        output_dimensionality: i32,
    }

    /// Content to embed: a list of parts plus an optional role.
    #[derive(Serialize)]
    pub struct EmbeddingContent {
        parts: Vec<EmbeddingContentPart>,
        role: Option<String>,
    }

    /// A single part of the content.
    ///
    /// Serialized in camelCase so multi-word fields (`inlineData`,
    /// `functionCall`, ...) are spelled the same way as in `FileData` and
    /// `EmbedContentRequest`.
    #[derive(Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct EmbeddingContentPart {
        text: String,
        inline_data: Option<Blob>,
        function_call: Option<FunctionCall>,
        function_response: Option<FunctionResponse>,
        file_data: Option<FileData>,
        executable_code: Option<ExecutableCode>,
        code_execution_result: Option<CodeExecutionResult>,
    }

    /// Inline data together with its MIME type (serialized as `mimeType`).
    #[derive(Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct Blob {
        data: String,
        mime_type: String,
    }

    /// A function call: the function name and optional arguments.
    #[derive(Serialize)]
    pub struct FunctionCall {
        name: String,
        args: Option<Value>,
    }

    /// The outcome of a function call.
    // NOTE(review): the public API reference appears to name the payload field
    // `response` rather than `result` — confirm before constructing this type.
    #[derive(Serialize)]
    pub struct FunctionResponse {
        name: String,
        result: Value,
    }

    /// Reference to file data by URI, with its MIME type.
    #[derive(Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct FileData {
        file_uri: String,
        mime_type: String,
    }

    /// Task type hint for the embedding request, serialized in SCREAMING_SNAKE_CASE.
    #[derive(Serialize)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum TaskType {
        /// Unset value. The API spells it `TASK_TYPE_UNSPECIFIED`, which the
        /// blanket rename (`UNSPECIFIED`) would not produce, so it is renamed
        /// explicitly.
        #[serde(rename = "TASK_TYPE_UNSPECIFIED")]
        Unspecified,
        RetrievalQuery,
        RetrievalDocument,
        SemanticSimilarity,
        Classification,
        Clustering,
        QuestionAnswering,
        FactVerification,
    }

    /// Successful `embedContent` response body.
    #[derive(Debug, Deserialize)]
    pub struct EmbeddingResponse {
        pub embedding: EmbeddingValues,
    }

    /// The embedding values as a flat list of floats.
    #[derive(Debug, Deserialize)]
    pub struct EmbeddingValues {
        pub values: Vec<f64>,
    }
}