use std::collections::{HashMap, VecDeque};
use std::sync::{Arc, Mutex};
use crate::{
CapabilitySupport, EmbeddingCapability, EmbeddingProvider, EmbeddingRequest, EmbeddingResponse,
Error, ProviderIdentity, Result,
};
/// Scripted embedding provider that replays queued responses and records the
/// requests it receives.
///
/// The response queue and request log sit behind an `Arc<Mutex<_>>`, so clones
/// of a provider share them. The capability map and provider name are plain
/// fields and are copied per clone.
#[derive(Debug, Clone)]
pub struct MockEmbeddingProvider {
/// Shared response queue and request log.
state: Arc<Mutex<MockEmbeddingState>>,
/// Per-capability support overrides returned by `embedding_capability`.
embedding_capabilities: HashMap<EmbeddingCapability, CapabilitySupport>,
/// Name returned by `ProviderIdentity::provider_name`; set to `"mock"` by `new`.
provider_name: &'static str,
}
/// Mutable state held behind the mutex of a `MockEmbeddingProvider`.
#[derive(Debug)]
struct MockEmbeddingState {
/// Scripted outcomes not yet handed out; `embed` pops from the front.
responses: VecDeque<Result<EmbeddingResponse>>,
/// Clone of every request passed to `embed`, in call order.
requests: Vec<EmbeddingRequest>,
}
impl MockEmbeddingProvider {
#[must_use]
pub fn empty() -> Self {
Self::new(std::iter::empty::<Result<EmbeddingResponse>>())
}
#[must_use]
pub fn new<I>(responses: I) -> Self
where
I: IntoIterator<Item = Result<EmbeddingResponse>>,
{
Self {
state: Arc::new(Mutex::new(MockEmbeddingState {
responses: responses.into_iter().collect(),
requests: Vec::new(),
})),
embedding_capabilities: HashMap::new(),
provider_name: "mock",
}
}
#[must_use]
pub fn with_vectors(vectors: Vec<Vec<f32>>) -> Self {
Self::new([Ok(EmbeddingResponse::new(vectors))])
}
#[must_use]
pub fn with_error(error: Error) -> Self {
Self::new([Err(error)])
}
pub fn push_response(&self, response: Result<EmbeddingResponse>) {
self.state.lock().unwrap().responses.push_back(response);
}
#[must_use]
pub fn with_provider_name(mut self, name: &'static str) -> Self {
self.provider_name = name;
self
}
#[must_use]
pub fn with_embedding_capability(
mut self,
capability: EmbeddingCapability,
support: CapabilitySupport,
) -> Self {
self.embedding_capabilities.insert(capability, support);
self
}
#[must_use]
pub fn requests(&self) -> Vec<EmbeddingRequest> {
self.state.lock().unwrap().requests.clone()
}
#[must_use]
pub fn pending_responses(&self) -> usize {
self.state.lock().unwrap().responses.len()
}
}
impl ProviderIdentity for MockEmbeddingProvider {
/// Returns the stored provider name: `"mock"` unless replaced via
/// `with_provider_name`.
fn provider_name(&self) -> &'static str {
self.provider_name
}
}
impl EmbeddingProvider for MockEmbeddingProvider {
async fn embed(&self, request: &EmbeddingRequest) -> Result<EmbeddingResponse> {
let mut state = self.state.lock().unwrap();
state.requests.push(request.clone());
match state.responses.pop_front() {
Some(response) => response,
None => Err(Error::UnexpectedResponse(format!(
"mock embedding provider '{}' has no queued responses",
self.provider_name
))),
}
}
fn embedding_capability(
&self,
_model: &str,
capability: EmbeddingCapability,
) -> CapabilitySupport {
self.embedding_capabilities
.get(&capability)
.copied()
.unwrap_or(CapabilitySupport::Unknown)
}
}
#[cfg(test)]
mod embedding_mock_tests {
    use super::*;
    use crate::{
        CapabilitySupport, EmbeddingCapability, EmbeddingProvider, EmbeddingRequest,
        EmbeddingResponse, Error, ProviderIdentity,
    };

    /// Builds a request for `model` carrying the single input `input`.
    fn single_input_request(model: &'static str, input: &'static str) -> EmbeddingRequest {
        EmbeddingRequest::new(model).input(input)
    }

    /// A queued success is returned and the request is logged.
    #[tokio::test]
    async fn mock_embedding_provider_returns_queued_response() {
        let mock = MockEmbeddingProvider::with_vectors(vec![vec![0.1, 0.2]]);
        let sent = single_input_request("mock-embed", "hello");

        let received = mock.embed(&sent).await.unwrap();
        assert_eq!(received.embeddings, vec![vec![0.1, 0.2]]);

        let recorded = mock.requests();
        assert_eq!(recorded.len(), 1);
        assert_eq!(recorded[0].inputs, vec![String::from("hello")]);
    }

    /// A queued error is surfaced unchanged to the caller.
    #[tokio::test]
    async fn mock_embedding_provider_returns_queued_error() {
        let mock = MockEmbeddingProvider::with_error(Error::Auth("bad".into()));

        let outcome = mock.embed(&single_input_request("m", "x")).await;

        assert!(matches!(outcome, Err(Error::Auth(_))));
    }

    /// Responses come back in the order they were queued.
    #[tokio::test]
    async fn mock_embedding_provider_returns_responses_in_order() {
        let scripted = [
            Ok(EmbeddingResponse::new(vec![vec![1.0]])),
            Ok(EmbeddingResponse::new(vec![vec![2.0]])),
        ];
        let mock = MockEmbeddingProvider::new(scripted);

        let first = mock
            .embed(&single_input_request("m", "a"))
            .await
            .unwrap();
        let second = mock
            .embed(&single_input_request("m", "b"))
            .await
            .unwrap();

        assert_eq!(first.embeddings, vec![vec![1.0]]);
        assert_eq!(second.embeddings, vec![vec![2.0]]);
    }

    /// An empty queue yields `Error::UnexpectedResponse`.
    #[tokio::test]
    async fn mock_embedding_provider_reports_exhaustion() {
        let mock = MockEmbeddingProvider::empty();

        let outcome = mock.embed(&single_input_request("m", "x")).await;

        assert!(matches!(outcome, Err(Error::UnexpectedResponse(_))));
    }

    /// Builder overrides are reported; unconfigured capabilities stay `Unknown`.
    #[test]
    fn mock_embedding_provider_exposes_provider_identity_and_capabilities() {
        let mock = MockEmbeddingProvider::empty()
            .with_provider_name("demo-embed")
            .with_embedding_capability(
                EmbeddingCapability::BatchInput,
                CapabilitySupport::Supported,
            );

        assert_eq!(mock.provider_name(), "demo-embed");

        let configured = mock.embedding_capability("m", EmbeddingCapability::BatchInput);
        assert_eq!(configured, CapabilitySupport::Supported);

        let unconfigured = mock.embedding_capability("m", EmbeddingCapability::OutputDimensions);
        assert_eq!(unconfigured, CapabilitySupport::Unknown);
    }
}