Skip to main content

polyc_embeddings/
erased.rs

1//! Type erasure for [`EmbeddingProvider`] so callers can hold a single
2//! `Arc<dyn EmbeddingProvider>` regardless of which backend is configured.
3//!
//! Mirrors `polyc_llm::erased`: the trait keeps a per-backend associated
//! [`EmbeddingError`] type, but a trait object must fix that type, so this module
6//! supplies one uniform error — [`BoxError`] — and an
7//! [`ErasedEmbeddingProvider`] adapter that maps any backend's error into it.
8//! The result is [`DynEmbeddingProvider`], the single trait-object type callers
9//! store and dispatch through. Adding a backend costs one trait impl plus one
10//! [`into_dyn`] call at the wiring boundary.
11
12use std::sync::Arc;
13
14use async_trait::async_trait;
15
16use crate::{EmbeddingProvider, error::EmbeddingError};
17
/// Uniform, type-erased backend error.
///
/// Each backend declares its own associated `Error`; boxing that error behind
/// this one concrete type is what lets differently-typed backends be stored
/// behind a single trait object.
///
/// The wrapper contributes no text of its own:
/// [`Display`](std::fmt::Display) forwards to the boxed error, and
/// [`source`](std::error::Error::source) returns that same boxed error.
#[derive(Debug)]
pub struct BoxError(Box<dyn std::error::Error + Send + Sync>);
26
27impl BoxError {
28    /// Erase any [`EmbeddingError`] into a `BoxError`.
29    #[must_use]
30    pub fn new<E: EmbeddingError>(err: E) -> Self {
31        Self(Box::new(err))
32    }
33}
34
35impl std::fmt::Display for BoxError {
36    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
37        std::fmt::Display::fmt(&self.0, f)
38    }
39}
40
41impl std::error::Error for BoxError {
42    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
43        Some(&*self.0)
44    }
45}
46
/// Adapter that wraps a concrete [`EmbeddingProvider`] and erases its
/// associated error to [`BoxError`], so the wrapped value coerces to
/// [`DynEmbeddingProvider`].
///
/// The wrapped provider is a private field; [`into_dyn`] is the function in
/// this module that constructs the adapter.
///
/// `Debug` is derived, so the adapter is debuggable whenever the wrapped
/// provider is.
#[derive(Debug)]
pub struct ErasedEmbeddingProvider<P>(P);
51
52#[async_trait]
53impl<P: EmbeddingProvider> EmbeddingProvider for ErasedEmbeddingProvider<P> {
54    type Error = BoxError;
55
56    fn model_id(&self) -> &str {
57        self.0.model_id()
58    }
59
60    fn dimensions(&self) -> usize {
61        self.0.dimensions()
62    }
63
64    async fn embed(&self, texts: &[String]) -> Result<Vec<Vec<f32>>, Self::Error> {
65        self.0.embed(texts).await.map_err(BoxError::new)
66    }
67}
68
69/// The single trait-object provider type callers store. Every concrete backend
70/// is erased to this via [`into_dyn`].
71pub type DynEmbeddingProvider = dyn EmbeddingProvider<Error = BoxError>;
72
73/// Erase a concrete provider and wrap it in an `Arc` as a
74/// [`DynEmbeddingProvider`]. The one wiring-boundary call.
75#[must_use]
76pub fn into_dyn<P: EmbeddingProvider>(provider: P) -> Arc<DynEmbeddingProvider> {
77    Arc::new(ErasedEmbeddingProvider(provider))
78}
79
#[cfg(test)]
mod tests {
    // Test-only code: the stricter clippy groups and the docs requirement are
    // switched off for this module.
    #![allow(clippy::pedantic, clippy::nursery, missing_docs)]

    use std::error::Error as _;
    use std::sync::Arc;

    use super::{BoxError, DynEmbeddingProvider, into_dyn};
    use crate::{EmbeddingProvider, error::DummyError};

    /// Toy backend whose `Error` is `DummyError`, i.e. not `BoxError`, so
    /// pushing it through `into_dyn` really exercises the error conversion.
    struct DummyEmbedder;

    #[async_trait::async_trait]
    impl EmbeddingProvider for DummyEmbedder {
        type Error = DummyError;

        fn model_id(&self) -> &str {
            "dummy-2"
        }

        fn dimensions(&self) -> usize {
            2
        }

        async fn embed(&self, texts: &[String]) -> Result<Vec<Vec<f32>>, Self::Error> {
            if texts.is_empty() {
                return Err(DummyError::Embed("empty batch".to_owned()));
            }
            // Deterministic toy embedding: [byte length, first byte (0 if empty)].
            let toy_vector = |text: &String| {
                let first_byte = text.bytes().next().unwrap_or(0);
                vec![text.len() as f32, f32::from(first_byte)]
            };
            Ok(texts.iter().map(toy_vector).collect())
        }
    }

    /// Metadata and a successful batch pass straight through the erased provider.
    #[tokio::test]
    async fn erased_embeds_batch() {
        let provider: Arc<DynEmbeddingProvider> = into_dyn(DummyEmbedder);
        assert_eq!(provider.model_id(), "dummy-2");
        assert_eq!(provider.dimensions(), 2);

        let batch = ["ab".to_owned(), "xyz".to_owned()];
        let vectors = provider.embed(&batch).await.unwrap();
        assert_eq!(vectors.len(), 2);
        assert_eq!(vectors[0], vec![2.0, f32::from(b'a')]);
    }

    /// The erased error displays like the backend error and exposes it as its source.
    #[tokio::test]
    async fn erased_error_is_preserved() {
        let provider: Arc<DynEmbeddingProvider> = into_dyn(DummyEmbedder);
        let err = provider.embed(&[]).await.unwrap_err();
        assert_eq!(err.to_string(), "embed failed: empty batch");

        let inner = err.source().expect("source present");
        assert_eq!(inner.to_string(), "embed failed: empty batch");
    }

    /// Compile-time check: `BoxError` meets the `EmbeddingError` bound.
    #[test]
    fn box_error_satisfies_embedding_error() {
        fn assert_embedding_error<E: crate::error::EmbeddingError>() {}
        assert_embedding_error::<BoxError>();
    }
}