1use std::sync::Arc;
2
3use crate::error::{Error, Result};
4
5use super::backend::EmbeddingBackend;
6
7pub struct EmbeddingProvider(Arc<dyn EmbeddingBackend>);
24
25impl Clone for EmbeddingProvider {
26 fn clone(&self) -> Self {
27 Self(Arc::clone(&self.0))
28 }
29}
30
31impl EmbeddingProvider {
32 pub fn new(backend: impl EmbeddingBackend + 'static) -> Self {
34 Self(Arc::new(backend))
35 }
36
37 pub async fn embed(&self, input: &str) -> Result<Vec<u8>> {
44 if input.is_empty() {
45 return Err(Error::bad_request("embedding input must not be empty"));
46 }
47 self.0.embed(input).await
48 }
49
50 pub fn dimensions(&self) -> usize {
52 self.0.dimensions()
53 }
54
55 pub fn model_name(&self) -> &str {
57 self.0.model_name()
58 }
59}
60
#[cfg(test)]
mod tests {
    use super::*;
    use crate::embed::convert::from_f32_blob;
    use crate::embed::test::InMemoryBackend;

    /// Builds a provider backed by the in-memory test backend.
    fn provider_with(dims: usize) -> EmbeddingProvider {
        EmbeddingProvider::new(InMemoryBackend::new(dims))
    }

    #[tokio::test]
    async fn embed_returns_blob_of_correct_length() {
        let dims = 128;
        let bytes = provider_with(dims).embed("hello").await.unwrap();
        // The test expects four bytes per dimension in the returned blob.
        let expected_len = dims * 4;
        assert_eq!(bytes.len(), expected_len);
    }

    #[tokio::test]
    async fn embed_blob_roundtrips_to_floats() {
        let dims = 4;
        let bytes = provider_with(dims).embed("test").await.unwrap();
        let decoded = from_f32_blob(&bytes).unwrap();
        assert_eq!(decoded.len(), dims);
        // The in-memory backend is expected to emit 0.1 for every component.
        let expected = vec![0.1_f32; dims];
        assert_eq!(decoded, expected);
    }

    #[tokio::test]
    async fn embed_rejects_empty_input() {
        let outcome = provider_with(4).embed("").await;
        let failure = outcome.unwrap_err();
        assert_eq!(failure.status(), http::StatusCode::BAD_REQUEST);
    }

    #[test]
    fn dimensions_delegated() {
        let reported = provider_with(768).dimensions();
        assert_eq!(reported, 768);
    }

    #[test]
    fn model_name_delegated() {
        let handle = provider_with(4);
        assert_eq!(handle.model_name(), "test-embedding");
    }
}