polyc_embeddings/
erased.rs1use std::sync::Arc;
13
14use async_trait::async_trait;
15
16use crate::{EmbeddingProvider, error::EmbeddingError};
17
/// Type-erased error wrapper around a boxed `std::error::Error` trait object.
///
/// The boxed error is required to be `Send + Sync + 'static`. The inner field is private,
/// so values are created through [`BoxError::new`].
#[derive(Debug)]
pub struct BoxError(Box<dyn std::error::Error + Send + Sync + 'static>);
26
27impl BoxError {
28 #[must_use]
30 pub fn new<E: EmbeddingError>(err: E) -> Self {
31 Self(Box::new(err))
32 }
33}
34
35impl std::fmt::Display for BoxError {
36 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
37 std::fmt::Display::fmt(&self.0, f)
38 }
39}
40
41impl std::error::Error for BoxError {
42 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
43 Some(&*self.0)
44 }
45}
46
/// Newtype wrapper around a provider `P`.
///
/// Its `EmbeddingProvider` implementation forwards to `P` and reports failures as
/// [`BoxError`]. The wrapped provider is private; use [`into_dyn`] to build one.
#[derive(Debug)]
pub struct ErasedEmbeddingProvider<P>(P);
51
52#[async_trait]
53impl<P: EmbeddingProvider> EmbeddingProvider for ErasedEmbeddingProvider<P> {
54 type Error = BoxError;
55
56 fn model_id(&self) -> &str {
57 self.0.model_id()
58 }
59
60 fn dimensions(&self) -> usize {
61 self.0.dimensions()
62 }
63
64 async fn embed(&self, texts: &[String]) -> Result<Vec<Vec<f32>>, Self::Error> {
65 self.0.embed(texts).await.map_err(BoxError::new)
66 }
67}
68
69pub type DynEmbeddingProvider = dyn EmbeddingProvider<Error = BoxError>;
72
73#[must_use]
76pub fn into_dyn<P: EmbeddingProvider>(provider: P) -> Arc<DynEmbeddingProvider> {
77 Arc::new(ErasedEmbeddingProvider(provider))
78}
79
#[cfg(test)]
mod tests {
    // Test-only code: pedantic/nursery lints and the missing-docs lint are switched off here.
    #![allow(clippy::pedantic, clippy::nursery, missing_docs)]

    use std::sync::Arc;

    use super::{BoxError, DynEmbeddingProvider, into_dyn};
    use crate::{EmbeddingProvider, error::DummyError};

    /// Minimal provider used to exercise the erased wrapper.
    struct DummyEmbedder;

    #[async_trait::async_trait]
    impl EmbeddingProvider for DummyEmbedder {
        type Error = DummyError;

        fn model_id(&self) -> &str {
            "dummy-2"
        }

        fn dimensions(&self) -> usize {
            2
        }

        async fn embed(&self, texts: &[String]) -> Result<Vec<Vec<f32>>, Self::Error> {
            if texts.is_empty() {
                return Err(DummyError::Embed("empty batch".to_owned()));
            }
            // Deterministic toy embedding: [text length, first byte (0 when the text is empty)].
            Ok(texts
                .iter()
                .map(|t| vec![t.len() as f32, t.bytes().next().unwrap_or(0) as f32])
                .collect())
        }
    }

    /// The erased provider forwards metadata and produces one vector per input text.
    #[tokio::test]
    async fn erased_embeds_batch() {
        let p: Arc<DynEmbeddingProvider> = into_dyn(DummyEmbedder);
        assert_eq!(p.model_id(), "dummy-2");
        assert_eq!(p.dimensions(), 2);
        let out = p.embed(&["ab".to_owned(), "xyz".to_owned()]).await.unwrap();
        assert_eq!(out.len(), 2);
        assert_eq!(out[0], vec![2.0, b'a' as f32]);
        // Also pin the second row so a wrapper that mangles later items would be caught.
        assert_eq!(out[1], vec![3.0, b'x' as f32]);
    }

    /// The boxed error keeps the original message and exposes the original error as source.
    #[tokio::test]
    async fn erased_error_is_preserved() {
        let p: Arc<DynEmbeddingProvider> = into_dyn(DummyEmbedder);
        let err = p.embed(&[]).await.unwrap_err();
        assert_eq!(format!("{err}"), "embed failed: empty batch");
        let src = std::error::Error::source(&err).expect("source present");
        assert_eq!(format!("{src}"), "embed failed: empty batch");
    }

    /// Compile-time check that `BoxError` meets the `EmbeddingError` bound.
    #[test]
    fn box_error_satisfies_embedding_error() {
        fn require<E: crate::error::EmbeddingError>() {}
        require::<BoxError>();
    }
}