1use std::collections::{HashMap, VecDeque};
2use std::sync::{Arc, Mutex};
3
4use crate::{
5 CapabilitySupport, EmbeddingCapability, EmbeddingProvider, EmbeddingRequest, EmbeddingResponse,
6 Error, ProviderIdentity, Result,
7};
8
9#[derive(Debug, Clone)]
14pub struct MockEmbeddingProvider {
15 state: Arc<Mutex<MockEmbeddingState>>,
16 embedding_capabilities: HashMap<EmbeddingCapability, CapabilitySupport>,
17 provider_name: &'static str,
18}
19
20#[derive(Debug)]
21struct MockEmbeddingState {
22 responses: VecDeque<Result<EmbeddingResponse>>,
23 requests: Vec<EmbeddingRequest>,
24}
25
26impl MockEmbeddingProvider {
27 #[must_use]
29 pub fn empty() -> Self {
30 Self::new(std::iter::empty::<Result<EmbeddingResponse>>())
31 }
32
33 #[must_use]
35 pub fn new<I>(responses: I) -> Self
36 where
37 I: IntoIterator<Item = Result<EmbeddingResponse>>,
38 {
39 Self {
40 state: Arc::new(Mutex::new(MockEmbeddingState {
41 responses: responses.into_iter().collect(),
42 requests: Vec::new(),
43 })),
44 embedding_capabilities: HashMap::new(),
45 provider_name: "mock",
46 }
47 }
48
49 #[must_use]
51 pub fn with_vectors(vectors: Vec<Vec<f32>>) -> Self {
52 Self::new([Ok(EmbeddingResponse::new(vectors))])
53 }
54
55 #[must_use]
57 pub fn with_error(error: Error) -> Self {
58 Self::new([Err(error)])
59 }
60
61 pub fn push_response(&self, response: Result<EmbeddingResponse>) {
63 self.state.lock().unwrap().responses.push_back(response);
64 }
65
66 #[must_use]
68 pub fn with_provider_name(mut self, name: &'static str) -> Self {
69 self.provider_name = name;
70 self
71 }
72
73 #[must_use]
75 pub fn with_embedding_capability(
76 mut self,
77 capability: EmbeddingCapability,
78 support: CapabilitySupport,
79 ) -> Self {
80 self.embedding_capabilities.insert(capability, support);
81 self
82 }
83
84 #[must_use]
86 pub fn requests(&self) -> Vec<EmbeddingRequest> {
87 self.state.lock().unwrap().requests.clone()
88 }
89
90 #[must_use]
92 pub fn pending_responses(&self) -> usize {
93 self.state.lock().unwrap().responses.len()
94 }
95}
96
97impl ProviderIdentity for MockEmbeddingProvider {
98 fn provider_name(&self) -> &'static str {
99 self.provider_name
100 }
101}
102
103impl EmbeddingProvider for MockEmbeddingProvider {
104 async fn embed(&self, request: &EmbeddingRequest) -> Result<EmbeddingResponse> {
105 let mut state = self.state.lock().unwrap();
106 state.requests.push(request.clone());
107 match state.responses.pop_front() {
108 Some(response) => response,
109 None => Err(Error::UnexpectedResponse(format!(
110 "mock embedding provider '{}' has no queued responses",
111 self.provider_name
112 ))),
113 }
114 }
115
116 fn embedding_capability(
117 &self,
118 _model: &str,
119 capability: EmbeddingCapability,
120 ) -> CapabilitySupport {
121 self.embedding_capabilities
122 .get(&capability)
123 .copied()
124 .unwrap_or(CapabilitySupport::Unknown)
125 }
126}
127
#[cfg(test)]
mod embedding_mock_tests {
    // `use super::*` already brings in the parent's `crate::{...}` imports,
    // including the traits whose methods these tests call.
    use super::*;

    #[tokio::test]
    async fn mock_embedding_provider_returns_queued_response() {
        let provider = MockEmbeddingProvider::with_vectors(vec![vec![0.1, 0.2]]);
        let request = EmbeddingRequest::new("mock-embed").input("hello");
        let response = provider.embed(&request).await.unwrap();
        assert_eq!(response.embeddings, vec![vec![0.1, 0.2]]);

        let requests = provider.requests();
        assert_eq!(requests.len(), 1);
        assert_eq!(requests[0].inputs, vec!["hello".to_string()]);
    }

    #[tokio::test]
    async fn mock_embedding_provider_returns_queued_error() {
        let provider = MockEmbeddingProvider::with_error(Error::Auth("bad".into()));
        let err = provider
            .embed(&EmbeddingRequest::new("m").input("x"))
            .await
            .unwrap_err();
        assert!(matches!(err, Error::Auth(_)));
    }

    #[tokio::test]
    async fn mock_embedding_provider_returns_responses_in_order() {
        let provider = MockEmbeddingProvider::new([
            Ok(EmbeddingResponse::new(vec![vec![1.0]])),
            Ok(EmbeddingResponse::new(vec![vec![2.0]])),
        ]);
        let first = provider
            .embed(&EmbeddingRequest::new("m").input("a"))
            .await
            .unwrap();
        let second = provider
            .embed(&EmbeddingRequest::new("m").input("b"))
            .await
            .unwrap();
        assert_eq!(first.embeddings, vec![vec![1.0]]);
        assert_eq!(second.embeddings, vec![vec![2.0]]);
    }

    #[tokio::test]
    async fn mock_embedding_provider_reports_exhaustion() {
        let provider = MockEmbeddingProvider::empty();
        let err = provider
            .embed(&EmbeddingRequest::new("m").input("x"))
            .await
            .unwrap_err();
        assert!(matches!(err, Error::UnexpectedResponse(_)));
    }

    #[tokio::test]
    async fn mock_embedding_provider_records_requests_even_when_exhausted() {
        let provider = MockEmbeddingProvider::empty();
        assert!(provider
            .embed(&EmbeddingRequest::new("m").input("x"))
            .await
            .is_err());

        let requests = provider.requests();
        assert_eq!(requests.len(), 1);
        assert_eq!(requests[0].inputs, vec!["x".to_string()]);
    }

    #[tokio::test]
    async fn mock_embedding_provider_push_response_extends_queue() {
        let provider = MockEmbeddingProvider::empty();
        assert_eq!(provider.pending_responses(), 0);

        provider.push_response(Ok(EmbeddingResponse::new(vec![vec![3.0]])));
        assert_eq!(provider.pending_responses(), 1);

        let response = provider
            .embed(&EmbeddingRequest::new("m").input("c"))
            .await
            .unwrap();
        assert_eq!(response.embeddings, vec![vec![3.0]]);
        assert_eq!(provider.pending_responses(), 0);
    }

    #[tokio::test]
    async fn mock_embedding_provider_clones_share_queue_and_request_log() {
        let provider = MockEmbeddingProvider::with_vectors(vec![vec![4.0]]);
        let clone = provider.clone();

        clone
            .embed(&EmbeddingRequest::new("m").input("shared"))
            .await
            .unwrap();

        // The clone consumed the only response, so the original sees an empty queue
        // and the request in its log.
        assert_eq!(provider.pending_responses(), 0);
        let requests = provider.requests();
        assert_eq!(requests.len(), 1);
        assert_eq!(requests[0].inputs, vec!["shared".to_string()]);
    }

    #[test]
    fn mock_embedding_provider_exposes_provider_identity_and_capabilities() {
        let provider = MockEmbeddingProvider::empty()
            .with_provider_name("demo-embed")
            .with_embedding_capability(
                EmbeddingCapability::BatchInput,
                CapabilitySupport::Supported,
            );
        assert_eq!(provider.provider_name(), "demo-embed");
        assert_eq!(
            provider.embedding_capability("m", EmbeddingCapability::BatchInput),
            CapabilitySupport::Supported
        );
        assert_eq!(
            provider.embedding_capability("m", EmbeddingCapability::OutputDimensions),
            CapabilitySupport::Unknown
        );
    }
}