use serde_json::json;
use super::{Client, client::ApiResponse};
use crate::{
embeddings::{self, EmbeddingError},
http_client::HttpClientExt,
wasm_compat::WasmCompatSend,
};
/// Identifier string for the `gemini-embedding-001` model.
pub const EMBEDDING_001: &str = "gemini-embedding-001";
/// Identifier string for the `text-embedding-004` model.
pub const EMBEDDING_004: &str = "text-embedding-004";

/// Returns the default number of embedding dimensions for a known model.
///
/// Only [`EMBEDDING_001`] and [`EMBEDDING_004`] are recognised; any other
/// identifier yields `None` so the caller can choose its own fallback.
fn model_default_ndims(model: &str) -> Option<usize> {
    // Table of (model identifier, default dimensionality) pairs.
    const KNOWN_DEFAULTS: [(&str, usize); 2] = [(EMBEDDING_001, 3072), (EMBEDDING_004, 768)];

    KNOWN_DEFAULTS
        .iter()
        .find(|(name, _)| *name == model)
        .map(|&(_, ndims)| ndims)
}
/// Embedding model handle for the Gemini provider.
///
/// Bundles a provider [`Client`] with a model identifier and the number of
/// dimensions requested for each embedding. `T` is the underlying HTTP client
/// type and defaults to `reqwest::Client`.
#[derive(Clone)]
pub struct EmbeddingModel<T = reqwest::Client> {
// Provider client used to build and send the embedding requests.
client: Client<T>,
// Model identifier; interpolated into the request path and the payload.
model: String,
// Requested embedding size, sent as `output_dimensionality` and reported by `ndims()`.
ndims: usize,
}
impl<T> EmbeddingModel<T> {
    /// Creates an embedding model from a client, a model identifier and an
    /// explicit embedding dimensionality.
    ///
    /// `model` accepts anything convertible into a `String`.
    pub fn new(client: Client<T>, model: impl Into<String>, ndims: usize) -> Self {
        let model = model.into();
        Self {
            client,
            model,
            ndims,
        }
    }

    /// Creates an embedding model from a borrowed model identifier.
    ///
    /// Equivalent to [`Self::new`]; the identifier is copied into an owned
    /// `String`.
    pub fn with_model(client: Client<T>, model: &str, ndims: usize) -> Self {
        Self::new(client, model, ndims)
    }
}
impl<T> embeddings::EmbeddingModel for EmbeddingModel<T>
where
    T: Clone + HttpClientExt + 'static,
{
    type Client = Client<T>;

    // Presumably the largest batch callers may pass to `embed_texts`.
    // NOTE(review): confirm this matches the provider's current per-request
    // limit for `batchEmbedContents`.
    const MAX_DOCUMENTS: usize = 1024;

    /// Builds a model from a shared client reference.
    ///
    /// The dimensionality is resolved in order: the explicit `dims`, then the
    /// known default for `model`, then `768` for unrecognised models.
    fn make(client: &Self::Client, model: impl Into<String>, dims: Option<usize>) -> Self {
        let model = model.into();
        let ndims = dims.or_else(|| model_default_ndims(&model)).unwrap_or(768);
        Self::new(client.clone(), model, ndims)
    }

    /// Returns the embedding dimensionality this model requests.
    fn ndims(&self) -> usize {
        self.ndims
    }

    /// Embeds every document with a single `batchEmbedContents` request.
    ///
    /// Returns one [`embeddings::Embedding`] per input document, in input
    /// order. An empty input returns an empty vector without contacting the
    /// API.
    ///
    /// # Errors
    ///
    /// * Request building, transport and JSON (de)serialisation failures are
    ///   propagated as the corresponding [`EmbeddingError`].
    /// * `EmbeddingError::ProviderError` is returned when the API reports an
    ///   error, when the number of returned embeddings differs from the
    ///   number of documents, or when an embedding component cannot be
    ///   represented as `f64`.
    async fn embed_texts(
        &self,
        documents: impl IntoIterator<Item = String> + WasmCompatSend,
    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
        let documents: Vec<String> = documents.into_iter().collect();

        // Nothing to embed: skip the round trip instead of sending an empty batch.
        if documents.is_empty() {
            return Ok(Vec::new());
        }

        // The model resource name is identical for every entry; build it once.
        let model_resource = format!("models/{}", self.model);
        let requests: Vec<_> = documents
            .iter()
            .map(|doc| {
                json!({
                    "model": model_resource,
                    "content": {
                        "parts": [{ "text": doc }]
                    },
                    "output_dimensionality": self.ndims,
                })
            })
            .collect();
        let request_body = json!({ "requests": requests });

        tracing::trace!(
            target: "rig::embedding",
            "Sending embedding request to Gemini API {}",
            // Logging must never panic; fall back to an empty string.
            serde_json::to_string_pretty(&request_body).unwrap_or_default()
        );

        let request_body = serde_json::to_vec(&request_body)?;
        let path = format!("/v1beta/models/{}:batchEmbedContents", self.model);
        let req = self
            .client
            .post(path.as_str())?
            .body(request_body)
            .map_err(|e| EmbeddingError::HttpError(e.into()))?;

        let response = self.client.send::<_, Vec<u8>>(req).await?;
        let response: ApiResponse<gemini_api_types::EmbeddingResponse> =
            serde_json::from_slice(&response.into_body().await?)?;

        match response {
            ApiResponse::Ok(response) => {
                // `zip` would silently truncate on a count mismatch and pair
                // documents with the wrong (or no) embedding, so reject it.
                if response.embeddings.len() != documents.len() {
                    return Err(EmbeddingError::ProviderError(format!(
                        "Gemini returned {} embeddings for {} documents",
                        response.embeddings.len(),
                        documents.len()
                    )));
                }

                documents
                    .into_iter()
                    .zip(response.embeddings)
                    .map(|(document, embedding)| {
                        // Dropping an unrepresentable component would yield a
                        // vector of the wrong length; fail loudly instead.
                        embedding
                            .values
                            .iter()
                            .map(serde_json::Number::as_f64)
                            .collect::<Option<Vec<f64>>>()
                            .map(|vec| embeddings::Embedding { document, vec })
                            .ok_or_else(|| {
                                EmbeddingError::ProviderError(
                                    "Gemini returned an embedding value that is not representable as f64"
                                        .to_string(),
                                )
                            })
                    })
                    .collect()
            }
            ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
        }
    }
}
// Only `EmbeddingResponse` / `EmbeddingValues` are used by the code visible in
// this file; the request-side types are kept as a typed description of the
// wire format, hence the blanket `dead_code` allowance.
#[allow(dead_code)]
mod gemini_api_types {
    //! Serde types describing Gemini embedding request and response payloads.

    use serde::{Deserialize, Serialize};
    use serde_json::Value;

    use crate::providers::gemini::gemini_api_types::{CodeExecutionResult, ExecutableCode};

    /// A single embedding request; field names are serialised in camelCase.
    #[derive(Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct EmbedContentRequest {
        model: String,
        content: EmbeddingContent,
        task_type: TaskType,
        title: String,
        output_dimensionality: i32,
    }

    /// Content to embed: a list of parts plus an optional role.
    #[derive(Serialize)]
    pub struct EmbeddingContent {
        parts: Vec<EmbeddingContentPart>,
        role: Option<String>,
    }

    /// One part of an [`EmbeddingContent`]: text plus optional payloads.
    #[derive(Serialize)]
    pub struct EmbeddingContentPart {
        text: String,
        inline_data: Option<Blob>,
        function_call: Option<FunctionCall>,
        function_response: Option<FunctionResponse>,
        file_data: Option<FileData>,
        executable_code: Option<ExecutableCode>,
        code_execution_result: Option<CodeExecutionResult>,
    }

    /// Inline data together with its MIME type.
    // NOTE(review): `data` is presumably base64-encoded — confirm against the API docs.
    #[derive(Serialize)]
    pub struct Blob {
        data: String,
        mime_type: String,
    }

    /// A function call: the function name and optional JSON arguments.
    #[derive(Serialize)]
    pub struct FunctionCall {
        name: String,
        args: Option<Value>,
    }

    /// A function result: the function name and its JSON result.
    #[derive(Serialize)]
    pub struct FunctionResponse {
        name: String,
        result: Value,
    }

    /// Reference to file data by URI; field names are serialised in camelCase.
    #[derive(Serialize)]
    #[serde(rename_all = "camelCase")]
    pub struct FileData {
        file_uri: String,
        mime_type: String,
    }

    /// Task the embedding is intended for, serialised as SCREAMING_SNAKE_CASE.
    #[derive(Serialize)]
    #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
    pub enum TaskType {
        // The API spells the unset value `TASK_TYPE_UNSPECIFIED`; the blanket
        // rename alone would emit `UNSPECIFIED`, which is not a valid value.
        #[serde(rename = "TASK_TYPE_UNSPECIFIED")]
        Unspecified,
        RetrievalQuery,
        RetrievalDocument,
        SemanticSimilarity,
        Classification,
        Clustering,
        QuestionAnswering,
        FactVerification,
    }

    /// Successful batch response: the list of returned embeddings.
    #[derive(Debug, Deserialize)]
    pub struct EmbeddingResponse {
        pub embeddings: Vec<EmbeddingValues>,
    }

    /// Raw numeric components of a single embedding.
    #[derive(Debug, Deserialize)]
    pub struct EmbeddingValues {
        pub values: Vec<serde_json::Number>,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a model through the trait's `make` constructor and returns the
    /// dimensionality it reports.
    fn made_ndims(model: &str, dims: Option<usize>) -> usize {
        let client = Client::new("test_key").unwrap();
        let built = <EmbeddingModel as embeddings::EmbeddingModel>::make(&client, model, dims);
        embeddings::EmbeddingModel::ndims(&built)
    }

    #[test]
    fn test_model_default_ndims_lookup() {
        let cases = [
            (EMBEDDING_001, Some(3072)),
            (EMBEDDING_004, Some(768)),
            ("unknown-model", None),
        ];
        for (model, expected) in cases {
            assert_eq!(model_default_ndims(model), expected);
        }
    }

    #[test]
    fn test_make_resolves_default_dims() {
        // Known models use their table default.
        assert_eq!(made_ndims(EMBEDDING_001, None), 3072);
        assert_eq!(made_ndims(EMBEDDING_004, None), 768);
        // Unknown models fall back to 768.
        assert_eq!(made_ndims("some-future-model", None), 768);
    }

    #[test]
    fn test_make_respects_explicit_dims() {
        // An explicit value wins over the model's table default.
        assert_eq!(made_ndims(EMBEDDING_001, Some(256)), 256);
    }

    #[test]
    fn test_new_uses_provided_ndims() {
        let api_client = Client::new("test_key").unwrap();
        let embedder = EmbeddingModel::new(api_client, EMBEDDING_001, 512);
        assert_eq!(embeddings::EmbeddingModel::ndims(&embedder), 512);
    }
}