use std::sync::Arc;
use async_trait::async_trait;
use crate::{EmbeddingProvider, error::EmbeddingError};
/// Type-erased error: owns an arbitrary boxed `std::error::Error` value that is
/// `Send + Sync + 'static`.
///
/// In this file it serves as the associated `Error` type of
/// `ErasedEmbeddingProvider` and of the `DynEmbeddingProvider` trait object.
#[derive(Debug)]
pub struct BoxError(Box<dyn std::error::Error + Send + Sync + 'static>);
impl BoxError {
    /// Moves `err` onto the heap and wraps it, erasing its concrete type `E`.
    ///
    /// The boxed value is stored as
    /// `Box<dyn std::error::Error + Send + Sync + 'static>`.
    #[must_use]
    pub fn new<E>(err: E) -> Self
    where
        E: EmbeddingError,
    {
        // Spell out the unsizing coercion so the erased type is visible here.
        let erased: Box<dyn std::error::Error + Send + Sync + 'static> = Box::new(err);
        Self(erased)
    }
}
impl std::fmt::Display for BoxError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
std::fmt::Display::fmt(&self.0, f)
}
}
impl std::error::Error for BoxError {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
Some(&*self.0)
}
}
/// Adapter around a provider `P` whose `EmbeddingProvider` implementation
/// reports failures as [`BoxError`] instead of `P`'s own error type.
///
/// The wrapped provider is a private field; within this file, values are
/// constructed by [`into_dyn`].
///
/// `Debug` is derived so the public type can be inspected in logs and
/// assertions; the derived impl is available whenever `P: Debug`.
#[derive(Debug)]
pub struct ErasedEmbeddingProvider<P>(P);
/// Delegates every method to the wrapped provider; only the error type differs.
#[async_trait]
impl<P> EmbeddingProvider for ErasedEmbeddingProvider<P>
where
    P: EmbeddingProvider,
{
    type Error = BoxError;

    /// Returns the wrapped provider's model identifier unchanged.
    fn model_id(&self) -> &str {
        let Self(inner) = self;
        inner.model_id()
    }

    /// Returns the wrapped provider's reported dimension count unchanged.
    fn dimensions(&self) -> usize {
        let Self(inner) = self;
        inner.dimensions()
    }

    /// Forwards `texts` to the wrapped provider.
    ///
    /// # Errors
    ///
    /// Any error returned by the wrapped provider, boxed via [`BoxError::new`].
    async fn embed(&self, texts: &[String]) -> Result<Vec<Vec<f32>>, Self::Error> {
        match self.0.embed(texts).await {
            Ok(vectors) => Ok(vectors),
            Err(cause) => Err(BoxError::new(cause)),
        }
    }
}
/// Trait-object form of [`EmbeddingProvider`] with its associated `Error`
/// fixed to [`BoxError`].
///
/// [`into_dyn`] returns an `Arc` of this type.
pub type DynEmbeddingProvider = dyn EmbeddingProvider<Error = BoxError>;
/// Converts a concrete provider into a shared trait object.
///
/// The provider is wrapped in [`ErasedEmbeddingProvider`], so its errors are
/// surfaced as [`BoxError`], and the wrapper is placed in an [`Arc`].
#[must_use]
pub fn into_dyn<P>(provider: P) -> Arc<DynEmbeddingProvider>
where
    P: EmbeddingProvider,
{
    let erased = ErasedEmbeddingProvider(provider);
    Arc::new(erased)
}
#[cfg(test)]
mod tests {
    // Lints are relaxed for this test-only module.
    #![allow(clippy::pedantic, clippy::nursery, missing_docs)]

    use std::sync::Arc;

    use super::{BoxError, DynEmbeddingProvider, into_dyn};
    use crate::{EmbeddingProvider, error::DummyError};

    /// Minimal provider: each text maps to `[byte length, first byte or 0]`.
    struct DummyEmbedder;

    #[async_trait::async_trait]
    impl EmbeddingProvider for DummyEmbedder {
        type Error = DummyError;

        fn model_id(&self) -> &str {
            "dummy-2"
        }

        fn dimensions(&self) -> usize {
            2
        }

        async fn embed(&self, texts: &[String]) -> Result<Vec<Vec<f32>>, Self::Error> {
            if texts.is_empty() {
                Err(DummyError::Embed("empty batch".to_owned()))
            } else {
                let mut vectors = Vec::with_capacity(texts.len());
                for text in texts {
                    let first_byte = text.bytes().next().unwrap_or(0);
                    vectors.push(vec![text.len() as f32, first_byte as f32]);
                }
                Ok(vectors)
            }
        }
    }

    /// The erased provider forwards metadata and embeddings unchanged.
    #[tokio::test]
    async fn erased_embeds_batch() {
        let provider: Arc<DynEmbeddingProvider> = into_dyn(DummyEmbedder);
        assert_eq!(provider.model_id(), "dummy-2");
        assert_eq!(provider.dimensions(), 2);

        let batch = ["ab".to_owned(), "xyz".to_owned()];
        let vectors = provider.embed(&batch).await.unwrap();
        assert_eq!(vectors.len(), 2);
        assert_eq!(vectors[0], vec![2.0, b'a' as f32]);
    }

    /// The original error's message survives erasure, both via `Display`
    /// and via `source`.
    #[tokio::test]
    async fn erased_error_is_preserved() {
        let provider: Arc<DynEmbeddingProvider> = into_dyn(DummyEmbedder);
        let failure = provider.embed(&[]).await.unwrap_err();
        assert_eq!(failure.to_string(), "embed failed: empty batch");

        let cause = std::error::Error::source(&failure).expect("source present");
        assert_eq!(cause.to_string(), "embed failed: empty batch");
    }

    /// Compile-time check that `BoxError` meets the `EmbeddingError` bound.
    #[test]
    fn box_error_satisfies_embedding_error() {
        fn require<E>()
        where
            E: crate::error::EmbeddingError,
        {
        }

        require::<BoxError>();
    }
}