use std::sync::Arc;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use crate::error::Result;
#[cfg(feature = "anthropic")]
mod anthropic;
#[cfg(feature = "anthropic")]
pub use self::anthropic::{AnthropicConfig, AnthropicTextVectorizer};
#[cfg(feature = "azure-openai")]
mod azure_openai;
#[cfg(feature = "azure-openai")]
pub use azure_openai::{AzureOpenAIConfig, AzureOpenAITextVectorizer};
#[cfg(feature = "bedrock")]
mod bedrock;
#[cfg(feature = "bedrock")]
pub use self::bedrock::{BedrockConfig, BedrockTextVectorizer};
#[cfg(feature = "cohere")]
mod cohere;
#[cfg(feature = "cohere")]
pub use self::cohere::{CohereConfig, CohereTextVectorizer};
#[cfg(feature = "hf-local")]
mod hf_local;
#[cfg(feature = "hf-local")]
pub use self::hf_local::{HuggingFaceConfig, HuggingFaceTextVectorizer};
#[cfg(feature = "mistral")]
mod mistral;
#[cfg(feature = "mistral")]
pub use self::mistral::{MistralAITextVectorizer, MistralConfig};
#[cfg(feature = "voyageai")]
mod voyageai;
#[cfg(feature = "voyageai")]
pub use self::voyageai::{VoyageAIConfig, VoyageAITextVectorizer};
#[cfg(feature = "vertex-ai")]
mod vertex_ai;
#[cfg(feature = "vertex-ai")]
pub use self::vertex_ai::{VertexAIConfig, VertexAITextVectorizer};
/// JSON request body posted to an embeddings endpoint.
///
/// Both fields borrow for `'a`, so building a request does not copy the model name or the
/// input texts themselves.
#[derive(Debug, Clone, Serialize)]
pub struct EmbeddingRequest<'a> {
/// Model name, serialized under the `model` key.
pub model: &'a str,
/// Texts to embed, serialized as an array under the `input` key.
pub input: Vec<&'a str>,
}
/// Blocking interface for turning text into embedding vectors.
///
/// Implementors must be `Send + Sync`.
pub trait Vectorizer: Send + Sync {
    /// Produces the embedding vector for a single piece of text.
    fn embed(&self, text: &str) -> Result<Vec<f32>>;

    /// Embeds every entry of `texts`, returning one vector per input in input order.
    ///
    /// The default implementation calls [`Vectorizer::embed`] once per text and returns the
    /// first error it encounters without embedding the remaining texts.
    fn embed_many(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        let mut vectors = Vec::with_capacity(texts.len());
        for &text in texts {
            vectors.push(self.embed(text)?);
        }
        Ok(vectors)
    }
}
/// Asynchronous interface for turning text into embedding vectors.
///
/// Implementors must be `Send + Sync`.
#[async_trait]
pub trait AsyncVectorizer: Send + Sync {
    /// Produces the embedding vector for a single piece of text.
    async fn embed(&self, text: &str) -> Result<Vec<f32>>;

    /// Embeds every entry of `texts`, returning one vector per input in input order.
    ///
    /// The default implementation awaits [`AsyncVectorizer::embed`] for one text at a time and
    /// returns the first error it encounters without embedding the remaining texts.
    async fn embed_many(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        let mut vectors = Vec::with_capacity(texts.len());
        for &text in texts {
            let vector = self.embed(text).await?;
            vectors.push(vector);
        }
        Ok(vectors)
    }
}
/// Connection settings for an OpenAI-compatible embeddings API.
///
/// NOTE(review): `Debug` is derived, so `{:?}` output includes `api_key` verbatim.
#[derive(Debug, Clone)]
pub struct OpenAICompatibleConfig {
/// Base URL that the `embeddings` path is resolved against.
pub base_url: url::Url,
/// API key, sent as a bearer token with each request.
pub api_key: String,
/// Model name placed in the `model` field of each embedding request.
pub model: String,
}
impl OpenAICompatibleConfig {
    /// Builds a configuration from a base URL, an API key, and a model name.
    ///
    /// # Errors
    ///
    /// Returns an error if `base_url` cannot be parsed as a URL.
    pub fn new(
        base_url: impl AsRef<str>,
        api_key: impl Into<String>,
        model: impl Into<String>,
    ) -> Result<Self> {
        Ok(Self {
            base_url: url::Url::parse(base_url.as_ref())?,
            api_key: api_key.into(),
            model: model.into(),
        })
    }

    /// Returns the URL of the `embeddings` endpoint located underneath `base_url`.
    ///
    /// A base URL with or without a trailing slash yields the same result, e.g. both
    /// `https://host/v1` and `https://host/v1/` resolve to `https://host/v1/embeddings`.
    ///
    /// # Errors
    ///
    /// Returns an error if `embeddings` cannot be joined onto the base URL.
    fn embeddings_url(&self) -> Result<url::Url> {
        let mut base = self.base_url.clone();
        // `Url::join` resolves its argument like a relative link, so without a trailing slash
        // the last path segment of the base would be *replaced* (`/v1` -> `/embeddings`)
        // instead of extended. Make the base path a directory first.
        if !base.cannot_be_a_base() && !base.path().ends_with('/') {
            let directory = format!("{}/", base.path());
            base.set_path(&directory);
        }
        Ok(base.join("embeddings")?)
    }
}
/// Vectorizer that obtains embeddings over HTTP from the endpoint described by an
/// [`OpenAICompatibleConfig`].
///
/// It holds both an async and a blocking HTTP client so it can serve the async and the
/// blocking vectorizer traits.
#[derive(Debug, Clone)]
pub struct OpenAITextVectorizer {
// Endpoint, credentials, and model used for every request.
config: OpenAICompatibleConfig,
// HTTP client used by the async request path.
client: reqwest::Client,
// HTTP client used by the blocking `Vectorizer` implementation.
blocking_client: reqwest::blocking::Client,
}
impl OpenAITextVectorizer {
    /// Creates a vectorizer for `config` with freshly constructed async and blocking HTTP
    /// clients.
    pub fn new(config: OpenAICompatibleConfig) -> Self {
        Self {
            config,
            client: reqwest::Client::new(),
            blocking_client: reqwest::blocking::Client::new(),
        }
    }

    /// Requests embeddings for all of `texts` in a single asynchronous HTTP call.
    ///
    /// The request is a bearer-authenticated JSON `POST` to the configured embeddings URL.
    /// The returned vectors are the `embedding` fields of the response's `data` array, in the
    /// order the server sent them. An empty `texts` slice yields an empty result without
    /// contacting the server.
    ///
    /// # Errors
    ///
    /// Returns an error if the embeddings URL cannot be built, the request fails, the server
    /// answers with an error status, or the response body cannot be decoded.
    async fn embed_many_inner(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        // Nothing to embed: skip the network round trip instead of posting an empty `input`.
        if texts.is_empty() {
            return Ok(Vec::new());
        }

        let request = EmbeddingRequest {
            model: &self.config.model,
            input: texts.to_vec(),
        };
        let response: EmbeddingResponse = self
            .client
            .post(self.config.embeddings_url()?)
            .bearer_auth(&self.config.api_key)
            .json(&request)
            .send()
            .await?
            .error_for_status()?
            .json()
            .await?;

        Ok(response
            .data
            .into_iter()
            .map(|item| item.embedding)
            .collect())
    }
}
impl Vectorizer for OpenAITextVectorizer {
fn embed(&self, text: &str) -> Result<Vec<f32>> {
let response: EmbeddingResponse = self
.blocking_client
.post(self.config.embeddings_url()?)
.bearer_auth(&self.config.api_key)
.json(&EmbeddingRequest {
model: &self.config.model,
input: vec![text],
})
.send()?
.error_for_status()?
.json()?;
Ok(response
.data
.into_iter()
.next()
.map_or_else(Vec::new, |item| item.embedding))
}
fn embed_many(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
let response: EmbeddingResponse = self
.blocking_client
.post(self.config.embeddings_url()?)
.bearer_auth(&self.config.api_key)
.json(&EmbeddingRequest {
model: &self.config.model,
input: texts.to_vec(),
})
.send()?
.error_for_status()?
.json()?;
Ok(response
.data
.into_iter()
.map(|item| item.embedding)
.collect())
}
}
#[async_trait]
impl AsyncVectorizer for OpenAITextVectorizer {
    /// Embeds a single text by sending it as a one-element batch.
    ///
    /// Returns the last embedding of the response, or an empty vector if the response
    /// contains no embeddings.
    async fn embed(&self, text: &str) -> Result<Vec<f32>> {
        let batch = self.embed_many_inner(&[text]).await?;
        Ok(batch.into_iter().next_back().unwrap_or_default())
    }

    /// Embeds all of `texts` with a single asynchronous request.
    async fn embed_many(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        let vectors = self.embed_many_inner(texts).await?;
        Ok(vectors)
    }
}
/// Vectorizer that forwards every embedding call to a wrapped [`OpenAITextVectorizer`].
///
/// NOTE(review): presumably intended for a LiteLLM proxy exposing an OpenAI-compatible API;
/// nothing in this type is LiteLLM-specific.
#[derive(Debug, Clone)]
pub struct LiteLLMTextVectorizer {
// The vectorizer that performs the actual requests.
inner: OpenAITextVectorizer,
}
impl LiteLLMTextVectorizer {
    /// Creates a vectorizer that sends its requests through an [`OpenAITextVectorizer`] built
    /// from `config`.
    pub fn new(config: OpenAICompatibleConfig) -> Self {
        let inner = OpenAITextVectorizer::new(config);
        Self { inner }
    }
}
impl Vectorizer for LiteLLMTextVectorizer {
fn embed(&self, text: &str) -> Result<Vec<f32>> {
Vectorizer::embed(&self.inner, text)
}
fn embed_many(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
Vectorizer::embed_many(&self.inner, texts)
}
}
#[async_trait]
impl AsyncVectorizer for LiteLLMTextVectorizer {
async fn embed(&self, text: &str) -> Result<Vec<f32>> {
AsyncVectorizer::embed(&self.inner, text).await
}
async fn embed_many(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
AsyncVectorizer::embed_many(&self.inner, texts).await
}
}
/// Vectorizer backed by a user-supplied embedding function.
///
/// `F` receives the text to embed and returns its embedding vector or an error.
pub struct CustomTextVectorizer<F>
where
F: Fn(&str) -> Result<Vec<f32>> + Send + Sync + 'static,
{
// The embedding function, stored behind an `Arc`.
embedder: Arc<F>,
}
impl<F> CustomTextVectorizer<F>
where
    F: Fn(&str) -> Result<Vec<f32>> + Send + Sync + 'static,
{
    /// Wraps `embedder` so that it can be used as a vectorizer.
    pub fn new(embedder: F) -> Self {
        let embedder = Arc::new(embedder);
        Self { embedder }
    }
}
impl<F> Vectorizer for CustomTextVectorizer<F>
where
    F: Fn(&str) -> Result<Vec<f32>> + Send + Sync + 'static,
{
    /// Embeds `text` by calling the wrapped function and returning its result unchanged.
    fn embed(&self, text: &str) -> Result<Vec<f32>> {
        let embedder: &F = &self.embedder;
        embedder(text)
    }
}
/// Deserialized body of an embeddings response.
///
/// Only the `data` array is captured; any other fields in the payload are not stored.
#[derive(Debug, Deserialize)]
pub(crate) struct EmbeddingResponse {
/// One entry per returned embedding, in the order they appear in the payload.
pub(crate) data: Vec<EmbeddingDatum>,
}
/// A single entry of the `data` array in an embeddings response.
///
/// Only the `embedding` field is captured; any other fields of the entry are not stored.
#[derive(Debug, Deserialize)]
pub(crate) struct EmbeddingDatum {
/// The embedding vector for one input text.
pub(crate) embedding: Vec<f32>,
}