/// Placeholder embedder id string.
///
/// NOTE(review): presumably recorded in place of a real embedder id when the
/// caller supplies vectors directly instead of going through an `Embedder` —
/// the usage is not visible in this file; confirm against the call sites.
pub const CALLER_PROVIDED_EMBEDDER_ID: &str = "<caller-provided>";
use std::collections::HashMap;
use std::fmt;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use async_trait::async_trait;
use crate::error::{Error, Result};
/// Role tag that accompanies a batch of texts handed to an embedder.
///
/// The enum is `#[non_exhaustive]`, so code outside this crate has to keep a
/// wildcard arm when matching on it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum EmbedRole {
    /// The texts are documents.
    // NOTE(review): presumably content being indexed — confirm with callers.
    Document,
    /// The texts are queries.
    // NOTE(review): presumably search-time input — confirm with callers.
    Query,
}

impl EmbedRole {
    /// Returns the fixed lowercase name of this role: `"document"` or
    /// `"query"`.
    #[must_use]
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::Document => "document",
            Self::Query => "query",
        }
    }
}
/// Interface implemented by every text-embedding backend.
///
/// Implementors must be `Debug`, thread-safe (`Send + Sync`) and `'static`,
/// which lets them be stored and shared as `Box<dyn Embedder>`.
#[async_trait]
pub trait Embedder: fmt::Debug + Send + Sync + 'static {
    /// Identifier string for this embedder.
    fn id(&self) -> &str;

    /// Vector length this embedder reports.
    fn dimensions(&self) -> usize;

    /// Input-size limit in tokens, if the embedder declares one.
    ///
    /// The provided implementation returns `None`.
    fn max_input_tokens(&self) -> Option<usize> {
        None
    }

    /// Capability descriptor for this embedder.
    ///
    /// The provided implementation returns
    /// `EmbedderCapabilities::default()`.
    fn capabilities(&self) -> crate::capabilities::EmbedderCapabilities {
        crate::capabilities::EmbedderCapabilities::default()
    }

    /// Embeds `texts` under the given `role`.
    ///
    /// NOTE(review): implementations presumably return one vector per input,
    /// in input order, each `dimensions()` long. Nothing in this trait
    /// enforces that — confirm with the consumers.
    async fn embed(&self, role: EmbedRole, texts: &[&str]) -> Result<Vec<Vec<f32>>>;

    /// JSON description of this embedder.
    ///
    /// The provided implementation returns `serde_json::Value::Null`.
    fn describe(&self) -> serde_json::Value {
        serde_json::Value::Null
    }
}
/// Heap-allocated, pinned, `Send` future yielding `T` that may borrow data
/// for the lifetime `'a`.
type BoxFut<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
/// Type-erased embedder factory as stored in `EmbedderRegistry`: a shared
/// (`Arc`), thread-safe closure that takes a JSON config value and returns a
/// `'static` boxed future resolving to a boxed `Embedder` or an error.
type FactoryFn = Arc<
dyn Fn(serde_json::Value) -> BoxFut<'static, Result<Box<dyn Embedder>>> + Send + Sync + 'static,
>;
/// Registry mapping embedder family names to the factories that build them.
///
/// `Default` yields a registry with no families. `Clone` copies the map; the
/// factories themselves are `Arc`-shared rather than duplicated.
#[derive(Default, Clone)]
pub struct EmbedderRegistry {
// Keyed by family name; written by `register`, looked up by `build`.
factories: HashMap<String, FactoryFn>,
}
/// Debug output lists only the registered family names; the factories are
/// opaque closures and cannot be printed.
impl fmt::Debug for EmbedderRegistry {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `HashMap` iteration order is arbitrary, so sort the names to keep
        // the output stable across runs and across equal registries.
        let mut families: Vec<&str> = self.factories.keys().map(String::as_str).collect();
        families.sort_unstable();
        f.debug_struct("EmbedderRegistry")
            .field("families", &families)
            .finish()
    }
}
impl EmbedderRegistry {
#[must_use]
pub fn empty() -> Self {
Self::default()
}
pub fn register<F, Fut>(&mut self, family: &'static str, factory: F)
where
F: Fn(serde_json::Value) -> Fut + Send + Sync + 'static,
Fut: Future<Output = Result<Box<dyn Embedder>>> + Send + 'static,
{
let f: FactoryFn = Arc::new(move |v| Box::pin(factory(v)));
self.factories.insert(family.to_string(), f);
}
pub async fn build(
&self,
family: &str,
config: serde_json::Value,
) -> Result<Box<dyn Embedder>> {
let factory = self
.factories
.get(family)
.ok_or_else(|| Error::Config(format!("embedder family {family:?} not registered")))?;
factory(config).await
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal embedder stub: fixed id, four dimensions, all-zero vectors.
    #[derive(Debug)]
    struct Fake;

    #[async_trait]
    impl Embedder for Fake {
        fn id(&self) -> &str {
            "fake:1"
        }

        fn dimensions(&self) -> usize {
            4
        }

        async fn embed(&self, _role: EmbedRole, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
            // One zero vector per input text, regardless of role.
            Ok(vec![vec![0.0; 4]; texts.len()])
        }
    }

    /// A registered factory can be built and the resulting embedder used.
    #[tokio::test]
    async fn registry_round_trip() {
        let mut registry = EmbedderRegistry::empty();
        registry.register("fake", |_cfg| async {
            let boxed: Box<dyn Embedder> = Box::new(Fake);
            Ok(boxed)
        });

        let embedder = registry
            .build("fake", serde_json::Value::Null)
            .await
            .unwrap();
        assert_eq!(embedder.id(), "fake:1");
        assert_eq!(embedder.dimensions(), 4);

        let vectors = embedder
            .embed(EmbedRole::Document, &["a", "b"])
            .await
            .unwrap();
        assert_eq!(vectors.len(), 2);
    }

    /// The stub ignores the role, so both roles give identical output.
    #[tokio::test]
    async fn embed_role_round_trip() {
        let embedder = Fake;
        let as_document = embedder.embed(EmbedRole::Document, &["x"]).await.unwrap();
        let as_query = embedder.embed(EmbedRole::Query, &["x"]).await.unwrap();
        assert_eq!(as_document, as_query);
    }

    /// Building an unregistered family fails with a config error.
    #[tokio::test]
    async fn registry_unknown_family() {
        let registry = EmbedderRegistry::empty();
        let failure = registry
            .build("nope", serde_json::Value::Null)
            .await
            .unwrap_err();
        assert!(matches!(failure, Error::Config(_)));
    }
}