use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use crate::error::Result;
/// Identifier for an embedding model, stored as a single string.
///
/// [`ModelId::new`] builds the string as `provider:model:version`. Because of
/// `#[serde(transparent)]` the value serializes and deserializes as a bare
/// string instead of a wrapper object.
///
/// NOTE(review): the inner field is `pub`, so a `ModelId` can also be built
/// directly from any string, including one that does not follow the
/// `provider:model:version` layout.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(transparent)]
pub struct ModelId(pub String);
impl ModelId {
    /// Builds an identifier of the form `provider:model:version`.
    ///
    /// The three components are joined with `:` exactly as given; nothing is
    /// escaped or validated, so a `:` inside `provider` or `model` is copied
    /// through verbatim.
    pub fn new(provider: &str, model: &str, version: u32) -> Self {
        let version = version.to_string();
        Self([provider, model, version.as_str()].join(":"))
    }

    /// Borrows the identifier as a string slice.
    pub fn as_str(&self) -> &str {
        self.0.as_str()
    }
}
impl std::fmt::Display for ModelId {
    /// Writes the identifier string.
    ///
    /// The text is emitted through [`std::fmt::Formatter::pad`] rather than
    /// `write_str`, so width, fill, alignment and precision flags
    /// (e.g. `{:<32}`) are applied the same way they are for `str`. With no
    /// flags the output is the raw identifier, unchanged.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.pad(&self.0)
    }
}
/// Tells [`EmbeddingProvider::embed_batch`] which kind of text is being embedded.
///
/// NOTE(review): presumably providers use this to apply task-specific
/// handling; how (or whether) the value is interpreted is up to each
/// implementation — confirm against the concrete providers.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EmbeddingTask {
/// The variant the default [`EmbeddingProvider::embed_query`] passes along
/// with its single input text.
Query,
/// Presumably marks text that is document content rather than a query —
/// confirm against callers.
Document,
}
#[async_trait]
pub trait EmbeddingProvider: Send + Sync {
fn model_id(&self) -> ModelId;
fn dim(&self) -> u16;
async fn embed_query(&self, text: &str) -> Result<Vec<f32>> {
let mut out = self.embed_batch(&[text], EmbeddingTask::Query).await?;
out.pop()
.ok_or_else(|| crate::error::Error::Other("provider returned no vector".into()))
}
async fn embed_batch(&self, texts: &[&str], task: EmbeddingTask) -> Result<Vec<Vec<f32>>>;
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic stand-in provider: every component of a returned vector
    /// is the input's byte length divided by 100.
    struct FakeProvider {
        id: ModelId,
        dim: u16,
    }

    #[async_trait]
    impl EmbeddingProvider for FakeProvider {
        fn model_id(&self) -> ModelId {
            self.id.clone()
        }

        fn dim(&self) -> u16 {
            self.dim
        }

        async fn embed_batch(&self, texts: &[&str], _task: EmbeddingTask) -> Result<Vec<Vec<f32>>> {
            let width = usize::from(self.dim);
            let mut rows = Vec::with_capacity(texts.len());
            for text in texts {
                let fill = (text.len() as f32) / 100.0;
                rows.push(vec![fill; width]);
            }
            Ok(rows)
        }
    }

    /// Builds the four-dimensional fake shared by the provider tests.
    fn four_dim_fake() -> FakeProvider {
        FakeProvider {
            id: ModelId::new("test", "fake", 1),
            dim: 4,
        }
    }

    #[test]
    fn model_id_format_is_stable() {
        let built = ModelId::new("local", "multilingual-e5-small", 1);
        assert_eq!(built.as_str(), "local:multilingual-e5-small:1");
    }

    #[test]
    fn model_id_roundtrips_through_json() {
        let original = ModelId::new("local", "bge-m3", 1);

        // Serialized form must be a bare JSON string.
        let encoded = serde_json::to_string(&original).unwrap();
        assert_eq!(encoded, "\"local:bge-m3:1\"");

        let decoded: ModelId = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded, original);
    }

    #[tokio::test]
    async fn default_embed_query_forwards_to_batch() {
        let provider = four_dim_fake();
        let vector = provider.embed_query("hello world").await.unwrap();
        assert_eq!(vector.len(), 4);

        // "hello world" is 11 bytes long, so the fake fills with 0.11.
        let delta = (vector[0] - 0.11).abs();
        assert!(delta < f32::EPSILON);
    }

    #[tokio::test]
    async fn batch_returns_one_vector_per_input() {
        let provider = four_dim_fake();
        let inputs = ["a", "bb", "ccc"];
        let rows = provider
            .embed_batch(&inputs, EmbeddingTask::Document)
            .await
            .unwrap();
        assert_eq!(rows.len(), 3);
        for row in &rows {
            assert!(row.len() == 4);
        }
    }

    #[tokio::test]
    async fn embed_query_propagates_empty_provider_result() {
        /// Provider whose batch call always succeeds with zero vectors.
        struct Empty;

        #[async_trait]
        impl EmbeddingProvider for Empty {
            fn model_id(&self) -> ModelId {
                ModelId::new("test", "empty", 1)
            }

            fn dim(&self) -> u16 {
                4
            }

            async fn embed_batch(
                &self,
                _texts: &[&str],
                _task: EmbeddingTask,
            ) -> Result<Vec<Vec<f32>>> {
                Ok(Vec::new())
            }
        }

        let err = Empty.embed_query("x").await.unwrap_err();
        let message = err.to_string();
        assert!(message.contains("no vector"));
    }
}