use serde_json::json;
use super::{Client, client::ApiResponse};
use crate::{
embeddings::{self, EmbeddingError},
http_client::HttpClientExt,
wasm_compat::WasmCompatSend,
};
/// Gemini model name `embedding-001` (passed to the API as `models/embedding-001`).
pub const EMBEDDING_001: &str = "embedding-001";
/// Gemini model name `text-embedding-004` (passed to the API as `models/text-embedding-004`).
pub const EMBEDDING_004: &str = "text-embedding-004";
/// Handle to a Gemini embedding model: an API client, the model name, and an
/// optional requested output dimensionality.
#[derive(Clone)]
pub struct EmbeddingModel<T = reqwest::Client> {
    /// Client used to send requests to the Gemini API.
    client: Client<T>,
    /// Bare model name (e.g. [`EMBEDDING_004`]); the `models/` prefix is added per request.
    model: String,
    /// Requested embedding size, sent as `output_dimensionality` on every request
    /// (serialized as `null` when `None`).
    ndims: Option<usize>,
}
impl<T> EmbeddingModel<T> {
    /// Creates an embedding model named `model` that sends requests through `client`.
    ///
    /// `ndims`, when `Some`, is forwarded to the API as the requested output
    /// dimensionality and reported by `ndims()`.
    pub fn new(client: Client<T>, model: impl Into<String>, ndims: Option<usize>) -> Self {
        Self {
            client,
            model: model.into(),
            ndims,
        }
    }

    /// Equivalent to [`EmbeddingModel::new`], taking the model name as a `&str`.
    ///
    /// Delegates to `new` so the two constructors cannot drift apart.
    pub fn with_model(client: Client<T>, model: &str, ndims: Option<usize>) -> Self {
        Self::new(client, model, ndims)
    }
}
impl<T> embeddings::EmbeddingModel for EmbeddingModel<T>
where
    T: Clone + HttpClientExt + 'static,
{
    type Client = Client<T>;
    const MAX_DOCUMENTS: usize = 1024;

    fn make(client: &Self::Client, model: impl Into<String>, dims: Option<usize>) -> Self {
        Self::new(client.clone(), model, dims)
    }

    /// Returns the embedding size this model produces.
    ///
    /// If an output dimensionality was configured, that value is sent to the API as
    /// `output_dimensionality`, so it is reported here as well. Otherwise falls back to 768.
    /// NOTE(review): 768 is presumably the default size of the models named in this
    /// file; confirm for other Gemini embedding models.
    fn ndims(&self) -> usize {
        const DEFAULT_NDIMS: usize = 768;
        self.ndims.unwrap_or(DEFAULT_NDIMS)
    }

    /// Embeds `documents` with a single `batchEmbedContents` call.
    ///
    /// Returns one [`embeddings::Embedding`] per input document, in input order.
    ///
    /// # Errors
    /// - HTTP / transport failures and JSON (de)serialization failures are propagated.
    /// - An API error body is returned as [`EmbeddingError::ProviderError`].
    /// - A successful response whose embedding count differs from the number of
    ///   documents is returned as [`EmbeddingError::ProviderError`], rather than
    ///   silently truncating the result.
    async fn embed_texts(
        &self,
        documents: impl IntoIterator<Item = String> + WasmCompatSend,
    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
        let documents: Vec<String> = documents.into_iter().collect();

        // Nothing to embed: avoid sending an empty batch over the network.
        if documents.is_empty() {
            return Ok(Vec::new());
        }

        // The fully-qualified model name is identical for every request; build it once.
        let model = format!("models/{}", self.model);
        let requests: Vec<_> = documents
            .iter()
            .map(|doc| {
                json!({
                    "model": model,
                    "content": {
                        "parts": [{ "text": doc }]
                    },
                    "output_dimensionality": self.ndims,
                })
            })
            .collect();

        let request_body = json!({ "requests": requests });

        tracing::trace!(
            target: "rig::embedding",
            "Sending embedding request to Gemini API {}",
            serde_json::to_string_pretty(&request_body).unwrap()
        );

        let request_body = serde_json::to_vec(&request_body)?;
        let path = format!("/v1beta/models/{}:batchEmbedContents", self.model);
        let req = self
            .client
            .post(path.as_str())?
            .body(request_body)
            .map_err(|e| EmbeddingError::HttpError(e.into()))?;
        let response = self.client.send::<_, Vec<u8>>(req).await?;

        let response: ApiResponse<gemini_api_types::EmbeddingResponse> =
            serde_json::from_slice(&response.into_body().await?)?;

        match response {
            ApiResponse::Ok(response) => {
                // `zip` would silently drop unmatched documents; surface the mismatch instead.
                if response.embeddings.len() != documents.len() {
                    return Err(EmbeddingError::ProviderError(format!(
                        "Gemini returned {} embeddings for {} documents",
                        response.embeddings.len(),
                        documents.len()
                    )));
                }

                let docs = documents
                    .into_iter()
                    .zip(response.embeddings)
                    .map(|(document, embedding)| embeddings::Embedding {
                        document,
                        vec: embedding.values,
                    })
                    .collect();

                Ok(docs)
            }
            ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
        }
    }
}
// Wire types for the Gemini embedding endpoints. Several are not yet constructed
// anywhere, hence `dead_code` is allowed for the whole module.
#[allow(dead_code)]
mod gemini_api_types {
    use serde::{Deserialize, Serialize};
    use serde_json::Value;

    use crate::providers::gemini::gemini_api_types::{CodeExecutionResult, ExecutableCode};

    /// Request body for a single `embedContent` call.
    ///
    /// Optional fields are omitted from the JSON when unset instead of being sent as `null`.
    #[derive(Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct EmbedContentRequest {
        model: String,
        content: EmbeddingContent,
        #[serde(skip_serializing_if = "Option::is_none")]
        task_type: Option<TaskType>,
        #[serde(skip_serializing_if = "Option::is_none")]
        title: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        output_dimensionality: Option<i32>,
    }

    /// Content to embed: a list of parts plus an optional role.
    #[derive(Serialize)]
    pub struct EmbeddingContent {
        parts: Vec<EmbeddingContentPart>,
        #[serde(skip_serializing_if = "Option::is_none")]
        role: Option<String>,
    }

    /// One part of the content.
    ///
    /// Each part is expected to carry exactly one payload. Unset fields are therefore skipped
    /// rather than serialized as `null`, and keys use camelCase like the other request
    /// types in this module.
    #[derive(Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct EmbeddingContentPart {
        #[serde(skip_serializing_if = "Option::is_none")]
        text: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        inline_data: Option<Blob>,
        #[serde(skip_serializing_if = "Option::is_none")]
        function_call: Option<FunctionCall>,
        #[serde(skip_serializing_if = "Option::is_none")]
        function_response: Option<FunctionResponse>,
        #[serde(skip_serializing_if = "Option::is_none")]
        file_data: Option<FileData>,
        #[serde(skip_serializing_if = "Option::is_none")]
        executable_code: Option<ExecutableCode>,
        #[serde(skip_serializing_if = "Option::is_none")]
        code_execution_result: Option<CodeExecutionResult>,
    }

    /// Inline binary payload (serialized as `{ "data": ..., "mimeType": ... }`).
    #[derive(Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct Blob {
        data: String,
        mime_type: String,
    }

    #[derive(Serialize)]
    pub struct FunctionCall {
        name: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        args: Option<Value>,
    }

    #[derive(Serialize)]
    pub struct FunctionResponse {
        name: String,
        result: Value,
    }

    #[derive(Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct FileData {
        file_uri: String,
        mime_type: String,
    }

    /// Intended downstream use of the embedding.
    #[derive(Serialize)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum TaskType {
        // The API spells the unset value `TASK_TYPE_UNSPECIFIED`, which the blanket
        // SCREAMING_SNAKE_CASE rename would render as plain `UNSPECIFIED`.
        #[serde(rename = "TASK_TYPE_UNSPECIFIED")]
        Unspecified,
        RetrievalQuery,
        RetrievalDocument,
        SemanticSimilarity,
        Classification,
        Clustering,
        QuestionAnswering,
        FactVerification,
    }

    /// Response body of `batchEmbedContents`: one entry per request, in request order.
    #[derive(Debug, Deserialize)]
    pub struct EmbeddingResponse {
        pub embeddings: Vec<EmbeddingValues>,
    }

    /// A single embedding vector.
    #[derive(Debug, Deserialize)]
    pub struct EmbeddingValues {
        pub values: Vec<f64>,
    }
}