1use std::sync::Arc;
2
3use crate::error::{Error, Result};
4
5use super::backend::EmbeddingBackend;
6
7pub struct EmbeddingProvider(Arc<dyn EmbeddingBackend>);
22
23impl Clone for EmbeddingProvider {
24 fn clone(&self) -> Self {
25 Self(Arc::clone(&self.0))
26 }
27}
28
29impl EmbeddingProvider {
30 pub fn new(backend: impl EmbeddingBackend + 'static) -> Self {
32 Self(Arc::new(backend))
33 }
34
35 pub async fn embed(&self, input: &str) -> Result<Vec<u8>> {
42 if input.is_empty() {
43 return Err(Error::bad_request("embedding input must not be empty"));
44 }
45 self.0.embed(input).await
46 }
47
48 pub fn dimensions(&self) -> usize {
50 self.0.dimensions()
51 }
52
53 pub fn model_name(&self) -> &str {
55 self.0.model_name()
56 }
57}
58
#[cfg(test)]
mod tests {
    //! Unit tests for `EmbeddingProvider`, driven by the in-crate
    //! `InMemoryBackend` test double.

    use super::*;
    use crate::embed::convert::from_f32_blob;
    use crate::embed::test::InMemoryBackend;

    #[tokio::test]
    async fn embed_returns_blob_of_correct_length() {
        let dims = 128;
        let provider = EmbeddingProvider::new(InMemoryBackend::new(dims));
        let blob = provider.embed("hello").await.unwrap();
        assert_eq!(blob.len(), dims * 4);
    }

    #[tokio::test]
    async fn embed_blob_roundtrips_to_floats() {
        let dims = 4;
        let provider = EmbeddingProvider::new(InMemoryBackend::new(dims));
        let blob = provider.embed("test").await.unwrap();
        let floats = from_f32_blob(&blob).unwrap();
        assert_eq!(floats.len(), dims);
        assert_eq!(floats, vec![0.1_f32; dims]);
    }

    #[tokio::test]
    async fn embed_rejects_empty_input() {
        let provider = EmbeddingProvider::new(InMemoryBackend::new(4));
        let err = provider.embed("").await.unwrap_err();
        assert_eq!(err.status(), http::StatusCode::BAD_REQUEST);
    }

    #[test]
    fn dimensions_delegated() {
        let provider = EmbeddingProvider::new(InMemoryBackend::new(768));
        assert_eq!(provider.dimensions(), 768);
    }

    #[test]
    fn model_name_delegated() {
        let provider = EmbeddingProvider::new(InMemoryBackend::new(4));
        assert_eq!(provider.model_name(), "test-embedding");
    }

    /// The hand-written `Clone` impl must hand out a second handle to the
    /// same backend rather than a copy of it.
    #[test]
    fn clone_shares_backend() {
        let provider = EmbeddingProvider::new(InMemoryBackend::new(16));
        let cloned = provider.clone();
        assert!(std::sync::Arc::ptr_eq(&provider.0, &cloned.0));
        assert_eq!(cloned.dimensions(), 16);
        assert_eq!(cloned.model_name(), provider.model_name());
    }
}