// Source: anyllm/embedding/mock.rs

1use std::collections::{HashMap, VecDeque};
2use std::sync::{Arc, Mutex};
3
4use crate::{
5    CapabilitySupport, EmbeddingCapability, EmbeddingProvider, EmbeddingRequest, EmbeddingResponse,
6    Error, ProviderIdentity, Result,
7};
8
9/// Deterministic embedding provider for tests.
10///
11/// Returns canned responses (or errors) in sequence, records every request,
12/// and exposes provider identity plus embedding capability overrides.
13#[derive(Debug, Clone)]
14pub struct MockEmbeddingProvider {
15    state: Arc<Mutex<MockEmbeddingState>>,
16    embedding_capabilities: HashMap<EmbeddingCapability, CapabilitySupport>,
17    provider_name: &'static str,
18}
19
20#[derive(Debug)]
21struct MockEmbeddingState {
22    responses: VecDeque<Result<EmbeddingResponse>>,
23    requests: Vec<EmbeddingRequest>,
24}
25
26impl MockEmbeddingProvider {
27    /// Create a mock embedding provider with no queued responses.
28    #[must_use]
29    pub fn empty() -> Self {
30        Self::new(std::iter::empty::<Result<EmbeddingResponse>>())
31    }
32
33    /// Create a mock embedding provider from an ordered queue of responses.
34    #[must_use]
35    pub fn new<I>(responses: I) -> Self
36    where
37        I: IntoIterator<Item = Result<EmbeddingResponse>>,
38    {
39        Self {
40            state: Arc::new(Mutex::new(MockEmbeddingState {
41                responses: responses.into_iter().collect(),
42                requests: Vec::new(),
43            })),
44            embedding_capabilities: HashMap::new(),
45            provider_name: "mock",
46        }
47    }
48
49    /// Convenience for a single successful response with the given vectors.
50    #[must_use]
51    pub fn with_vectors(vectors: Vec<Vec<f32>>) -> Self {
52        Self::new([Ok(EmbeddingResponse::new(vectors))])
53    }
54
55    /// Convenience for a single error response.
56    #[must_use]
57    pub fn with_error(error: Error) -> Self {
58        Self::new([Err(error)])
59    }
60
61    /// Queue an additional response without mutating existing queue entries.
62    pub fn push_response(&self, response: Result<EmbeddingResponse>) {
63        self.state.lock().unwrap().responses.push_back(response);
64    }
65
66    /// Override the provider identity name reported by this mock.
67    #[must_use]
68    pub fn with_provider_name(mut self, name: &'static str) -> Self {
69        self.provider_name = name;
70        self
71    }
72
73    /// Override support for an embedding capability.
74    #[must_use]
75    pub fn with_embedding_capability(
76        mut self,
77        capability: EmbeddingCapability,
78        support: CapabilitySupport,
79    ) -> Self {
80        self.embedding_capabilities.insert(capability, support);
81        self
82    }
83
84    /// Snapshot of every request the mock has received.
85    #[must_use]
86    pub fn requests(&self) -> Vec<EmbeddingRequest> {
87        self.state.lock().unwrap().requests.clone()
88    }
89
90    /// Number of queued responses remaining.
91    #[must_use]
92    pub fn pending_responses(&self) -> usize {
93        self.state.lock().unwrap().responses.len()
94    }
95}
96
97impl ProviderIdentity for MockEmbeddingProvider {
98    fn provider_name(&self) -> &'static str {
99        self.provider_name
100    }
101}
102
103impl EmbeddingProvider for MockEmbeddingProvider {
104    async fn embed(&self, request: &EmbeddingRequest) -> Result<EmbeddingResponse> {
105        let mut state = self.state.lock().unwrap();
106        state.requests.push(request.clone());
107        match state.responses.pop_front() {
108            Some(response) => response,
109            None => Err(Error::UnexpectedResponse(format!(
110                "mock embedding provider '{}' has no queued responses",
111                self.provider_name
112            ))),
113        }
114    }
115
116    fn embedding_capability(
117        &self,
118        _model: &str,
119        capability: EmbeddingCapability,
120    ) -> CapabilitySupport {
121        self.embedding_capabilities
122            .get(&capability)
123            .copied()
124            .unwrap_or(CapabilitySupport::Unknown)
125    }
126}
127
#[cfg(test)]
mod embedding_mock_tests {
    // `super::*` already brings in every crate type the parent module imports
    // (including the `EmbeddingProvider` / `ProviderIdentity` traits), so no
    // separate `use crate::{...}` is needed.
    use super::*;

    #[tokio::test]
    async fn mock_embedding_provider_returns_queued_response() {
        let provider = MockEmbeddingProvider::with_vectors(vec![vec![0.1, 0.2]]);
        let request = EmbeddingRequest::new("mock-embed").input("hello");
        let response = provider.embed(&request).await.unwrap();
        assert_eq!(response.embeddings, vec![vec![0.1, 0.2]]);

        let requests = provider.requests();
        assert_eq!(requests.len(), 1);
        assert_eq!(requests[0].inputs, vec!["hello".to_string()]);
    }

    #[tokio::test]
    async fn mock_embedding_provider_returns_queued_error() {
        let provider = MockEmbeddingProvider::with_error(Error::Auth("bad".into()));
        let err = provider
            .embed(&EmbeddingRequest::new("m").input("x"))
            .await
            .unwrap_err();
        assert!(matches!(err, Error::Auth(_)));
    }

    #[tokio::test]
    async fn mock_embedding_provider_returns_responses_in_order() {
        let provider = MockEmbeddingProvider::new([
            Ok(EmbeddingResponse::new(vec![vec![1.0]])),
            Ok(EmbeddingResponse::new(vec![vec![2.0]])),
        ]);
        let first = provider
            .embed(&EmbeddingRequest::new("m").input("a"))
            .await
            .unwrap();
        let second = provider
            .embed(&EmbeddingRequest::new("m").input("b"))
            .await
            .unwrap();
        assert_eq!(first.embeddings, vec![vec![1.0]]);
        assert_eq!(second.embeddings, vec![vec![2.0]]);
    }

    #[tokio::test]
    async fn mock_embedding_provider_reports_exhaustion() {
        let provider = MockEmbeddingProvider::empty();
        let err = provider
            .embed(&EmbeddingRequest::new("m").input("x"))
            .await
            .unwrap_err();
        assert!(matches!(err, Error::UnexpectedResponse(_)));
    }

    #[tokio::test]
    async fn mock_embedding_provider_records_request_even_when_exhausted() {
        let provider = MockEmbeddingProvider::empty();
        assert!(provider
            .embed(&EmbeddingRequest::new("m").input("x"))
            .await
            .is_err());
        assert_eq!(provider.requests().len(), 1);
    }

    #[tokio::test]
    async fn mock_embedding_provider_push_response_extends_queue() {
        let provider = MockEmbeddingProvider::empty();
        assert_eq!(provider.pending_responses(), 0);

        provider.push_response(Ok(EmbeddingResponse::new(vec![vec![3.0]])));
        assert_eq!(provider.pending_responses(), 1);

        let response = provider
            .embed(&EmbeddingRequest::new("m").input("x"))
            .await
            .unwrap();
        assert_eq!(response.embeddings, vec![vec![3.0]]);
        assert_eq!(provider.pending_responses(), 0);
    }

    #[tokio::test]
    async fn mock_embedding_provider_clones_share_queue_and_request_log() {
        let provider = MockEmbeddingProvider::with_vectors(vec![vec![1.0]]);
        let clone = provider.clone();

        let response = clone
            .embed(&EmbeddingRequest::new("m").input("a"))
            .await
            .unwrap();
        assert_eq!(response.embeddings, vec![vec![1.0]]);

        // The original sees the clone's consumption and its logged request.
        assert_eq!(provider.pending_responses(), 0);
        assert_eq!(provider.requests().len(), 1);
    }

    #[test]
    fn mock_embedding_provider_exposes_provider_identity_and_capabilities() {
        let provider = MockEmbeddingProvider::empty()
            .with_provider_name("demo-embed")
            .with_embedding_capability(
                EmbeddingCapability::BatchInput,
                CapabilitySupport::Supported,
            );
        assert_eq!(provider.provider_name(), "demo-embed");
        assert_eq!(
            provider.embedding_capability("m", EmbeddingCapability::BatchInput),
            CapabilitySupport::Supported
        );
        assert_eq!(
            provider.embedding_capability("m", EmbeddingCapability::OutputDimensions),
            CapabilitySupport::Unknown
        );
    }
}