use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use crate::error::Result;
/// Encoding format that can be selected on an [`EmbeddingRequest`] through its
/// `encoding_format` field.
///
/// Serde (de)serializes the variants using their lowercase names
/// (`"float"` and `"base64"`). [`EncodingFormat::Float`] is the `Default` value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum EncodingFormat {
/// Default variant; serialized as `"float"`.
#[default]
Float,
/// Serialized as `"base64"`.
/// NOTE(review): presumably the provider returns base64-encoded vector bytes
/// for this variant — confirm against the provider API.
Base64,
}
impl EncodingFormat {
#[must_use]
pub const fn as_str(&self) -> &'static str {
match self {
Self::Float => "float",
Self::Base64 => "base64",
}
}
}
/// Request describing which inputs should be embedded and with which options.
///
/// The three optional fields are left out of the serialized form entirely
/// when they are `None` (see the `skip_serializing_if` attributes).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingRequest {
/// Model identifier as a free-form string.
pub model: String,
/// Input strings carried by this request; a single request may hold several.
pub input: Vec<String>,
/// Optional encoding format; omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub encoding_format: Option<EncodingFormat>,
/// Optional dimension count; omitted from serialized output when `None`.
/// NOTE(review): presumably the requested size of the returned vectors —
/// confirm against the provider API.
#[serde(skip_serializing_if = "Option::is_none")]
pub dimensions: Option<u32>,
/// Optional user string; omitted from serialized output when `None`.
/// NOTE(review): presumably an end-user identifier forwarded to the
/// provider — confirm.
#[serde(skip_serializing_if = "Option::is_none")]
pub user: Option<String>,
}
impl EmbeddingRequest {
    /// Builds a request for `model` over the given `input` strings.
    ///
    /// All optional settings (`encoding_format`, `dimensions`, `user`) start
    /// out as `None`; use the builder methods below to set them.
    #[must_use]
    pub fn new(model: impl Into<String>, input: Vec<String>) -> Self {
        let model = model.into();
        Self {
            model,
            input,
            user: None,
            dimensions: None,
            encoding_format: None,
        }
    }

    /// Builds a request for `model` whose input list contains only `text`.
    #[must_use]
    pub fn single(model: impl Into<String>, text: impl Into<String>) -> Self {
        let only_input = text.into();
        Self::new(model, vec![only_input])
    }

    /// Returns the request with `encoding_format` set to `Some(format)`.
    #[must_use]
    pub const fn encoding_format(self, format: EncodingFormat) -> Self {
        let mut request = self;
        request.encoding_format = Some(format);
        request
    }

    /// Returns the request with `dimensions` set to `Some(dims)`.
    #[must_use]
    pub const fn dimensions(self, dims: u32) -> Self {
        let mut request = self;
        request.dimensions = Some(dims);
        request
    }

    /// Returns the request with `user` set to `Some(user.into())`.
    #[must_use]
    pub fn user(self, user: impl Into<String>) -> Self {
        Self {
            user: Some(user.into()),
            ..self
        }
    }
}
/// A single embedding: a vector of `f32` components together with an index.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Embedding {
/// The embedding components.
pub vector: Vec<f32>,
/// Index associated with this embedding.
/// NOTE(review): presumably the position of the corresponding input string
/// in the originating request — confirm against the provider implementations.
pub index: usize,
}
impl Embedding {
    /// Creates an embedding from its components and its index.
    #[must_use]
    pub const fn new(vector: Vec<f32>, index: usize) -> Self {
        Self { vector, index }
    }

    /// Returns the number of components in the vector.
    #[must_use]
    pub const fn dimension(&self) -> usize {
        self.vector.len()
    }

    /// Computes the cosine similarity between `self` and `other`.
    ///
    /// # Returns
    ///
    /// * A value in `[-1.0, 1.0]` for two non-zero vectors of equal length.
    /// * `0.0` when the vectors have different lengths, or when either vector
    ///   has zero magnitude (this includes two empty vectors). The mismatch
    ///   case is a sentinel, not an error: callers that must distinguish it
    ///   should compare [`Embedding::dimension`] first.
    /// * `NaN` if any component is `NaN` or infinite.
    ///
    /// Sums are accumulated in `f64`, so components whose squares would
    /// overflow or underflow `f32` still produce a meaningful result.
    #[must_use]
    pub fn cosine_similarity(&self, other: &Self) -> f32 {
        if self.vector.len() != other.vector.len() {
            return 0.0;
        }

        // One pass over both vectors. Squares of `f32` values always fit in
        // `f64`, which avoids the `inf / inf = NaN` (overflow) and the
        // spurious zero-norm (underflow) results of `f32` accumulation.
        let (dot, norm_sq_a, norm_sq_b) = self.vector.iter().zip(&other.vector).fold(
            (0.0_f64, 0.0_f64, 0.0_f64),
            |(dot, sq_a, sq_b), (&a, &b)| {
                let (a, b) = (f64::from(a), f64::from(b));
                (dot + a * b, sq_a + a * a, sq_b + b * b)
            },
        );

        if norm_sq_a == 0.0 || norm_sq_b == 0.0 {
            return 0.0;
        }

        let cosine = dot / (norm_sq_a.sqrt() * norm_sq_b.sqrt());

        // Rounding can push the quotient marginally outside the mathematical
        // range; clamping keeps downstream `acos`-style math well defined.
        // The narrowing cast is safe because the value lies in [-1, 1] (or is NaN).
        #[allow(clippy::cast_possible_truncation)]
        let clamped = cosine.clamp(-1.0, 1.0) as f32;
        clamped
    }

    /// Computes the Euclidean (L2) distance between `self` and `other`.
    ///
    /// # Returns
    ///
    /// * The non-negative distance for vectors of equal length (`0.0` for two
    ///   empty vectors).
    /// * `f32::MAX` when the vectors have different lengths. This is a
    ///   sentinel, not an error: callers that must distinguish it should
    ///   compare [`Embedding::dimension`] first.
    ///
    /// Differences are squared and summed in `f64`, so large components do not
    /// overflow to infinity while the true distance is still representable.
    #[must_use]
    pub fn euclidean_distance(&self, other: &Self) -> f32 {
        if self.vector.len() != other.vector.len() {
            return f32::MAX;
        }

        let sum_of_squares: f64 = self
            .vector
            .iter()
            .zip(&other.vector)
            .map(|(&a, &b)| {
                let delta = f64::from(a) - f64::from(b);
                delta * delta
            })
            .sum();

        // Narrowing back to the public `f32` return type; a distance beyond
        // the `f32` range saturates to infinity, as before.
        #[allow(clippy::cast_possible_truncation)]
        let distance = sum_of_squares.sqrt() as f32;
        distance
    }
}
/// Token counters attached to an [`EmbeddingResponse`].
///
/// Both counters are `0` in the `Default` value.
/// NOTE(review): presumably these mirror the usage figures reported by the
/// provider — confirm against the provider implementations.
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)]
pub struct EmbeddingUsage {
/// Number of prompt tokens.
pub prompt_tokens: u32,
/// Total number of tokens; this is the value returned by
/// `EmbeddingResponse::tokens_used`.
pub total_tokens: u32,
}
/// Result of an embedding call: the embeddings plus optional metadata.
///
/// The `Default` value has no embeddings, no model and no usage. The optional
/// fields are left out of the serialized form when they are `None`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct EmbeddingResponse {
/// The returned embeddings.
pub embeddings: Vec<Embedding>,
/// Optional model name; omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub model: Option<String>,
/// Optional token usage; omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub usage: Option<EmbeddingUsage>,
}
impl EmbeddingResponse {
    /// Wraps `embeddings` in a response with no model name and no usage data.
    #[must_use]
    pub const fn new(embeddings: Vec<Embedding>) -> Self {
        Self {
            usage: None,
            model: None,
            embeddings,
        }
    }

    /// Returns the response with `model` set to `Some(model.into())`.
    #[must_use]
    pub fn with_model(self, model: impl Into<String>) -> Self {
        Self {
            model: Some(model.into()),
            ..self
        }
    }

    /// Returns the response with `usage` set to the given token counters.
    #[must_use]
    pub const fn with_usage(self, prompt_tokens: u32, total_tokens: u32) -> Self {
        let mut response = self;
        response.usage = Some(EmbeddingUsage {
            prompt_tokens,
            total_tokens,
        });
        response
    }

    /// Borrows the first embedding, or returns `None` when there are none.
    #[must_use]
    pub fn first(&self) -> Option<&Embedding> {
        match self.embeddings.as_slice() {
            [head, ..] => Some(head),
            [] => None,
        }
    }

    /// Borrows every embedding's vector, in the same order as `embeddings`.
    #[must_use]
    pub fn vectors(&self) -> Vec<&Vec<f32>> {
        let mut borrowed = Vec::with_capacity(self.embeddings.len());
        borrowed.extend(self.embeddings.iter().map(|embedding| &embedding.vector));
        borrowed
    }

    /// Returns `usage.total_tokens`, or `None` when no usage data is attached.
    #[must_use]
    pub fn tokens_used(&self) -> Option<u32> {
        // `EmbeddingUsage` is `Copy`, so the option can be mapped by value.
        self.usage.map(|usage| usage.total_tokens)
    }
}
#[async_trait]
pub trait EmbeddingProvider: Send + Sync {
async fn embed(&self, request: &EmbeddingRequest) -> Result<EmbeddingResponse>;
async fn embed_single(&self, model: &str, text: &str) -> Result<Embedding> {
let request = EmbeddingRequest::single(model, text);
let response = self.embed(&request).await?;
response.embeddings.into_iter().next().ok_or_else(|| {
crate::error::LlmError::response_format("embedding", "empty response").into()
})
}
fn default_embedding_model(&self) -> &str;
fn embedding_dimension(&self) -> Option<usize> {
None
}
}