use std::collections::VecDeque;
use std::pin::Pin;
use std::sync::{Mutex, MutexGuard};
use futures::Stream;
use crate::error::PeError;
use crate::llm::{LlmProvider, LlmResponse, StreamChunk, ToolSchema};
use crate::message::{AiMessage, Message, MessageContent, ToolCall};
/// One scripted reply queued on a `MockProvider`; replies are consumed in FIFO order.
#[derive(Clone, Debug)]
enum MockResponse {
    /// Answer with a plain-text assistant message.
    Text(String),
    /// Answer with an assistant message carrying exactly one tool call.
    ToolCall {
        /// Name of the tool being "called".
        tool_name: String,
        /// JSON arguments attached to that call.
        args: serde_json::Value,
    },
    /// Fail the request with this error instead of answering.
    Error(PeError),
}
/// Scripted [`LlmProvider`] for tests.
///
/// Replies are queued up front with the `respond_with*` builders and handed out one per
/// `complete`/`stream` call, in the order they were queued.
#[derive(Debug)]
pub struct MockProvider {
    /// Pending scripted replies, popped from the front on each request.
    /// Wrapped in a `Mutex` because requests only receive `&self`.
    responses: Mutex<VecDeque<MockResponse>>,
    /// Vector returned (as a copy) by every `embed` call.
    embed_response: Vec<f32>,
}
impl MockProvider {
fn responses_guard(&self) -> MutexGuard<'_, VecDeque<MockResponse>> {
match self.responses.lock() {
Ok(guard) => guard,
Err(poisoned) => poisoned.into_inner(),
}
}
pub fn new() -> Self {
Self {
responses: Mutex::new(VecDeque::new()),
embed_response: vec![0.0; 128], }
}
#[must_use = "builder methods return the modified builder"]
pub fn respond_with(self, text: impl Into<String>) -> Self {
self.responses_guard()
.push_back(MockResponse::Text(text.into()));
self
}
#[must_use = "builder methods return the modified builder"]
pub fn respond_with_tool_call(
self,
tool_name: impl Into<String>,
args: serde_json::Value,
) -> Self {
self.responses_guard().push_back(MockResponse::ToolCall {
tool_name: tool_name.into(),
args,
});
self
}
#[must_use = "builder methods return the modified builder"]
pub fn respond_with_error(self, err: PeError) -> Self {
self.responses_guard().push_back(MockResponse::Error(err));
self
}
#[must_use = "builder methods return the modified builder"]
pub fn with_embedding(mut self, embedding: Vec<f32>) -> Self {
self.embed_response = embedding;
self
}
pub fn remaining(&self) -> usize {
self.responses_guard().len()
}
fn next_response(&self) -> Result<MockResponse, PeError> {
self.responses_guard()
.pop_front()
.ok_or(PeError::MockProviderExhausted)
}
fn mock_response_to_llm(resp: MockResponse) -> Result<LlmResponse, PeError> {
match resp {
MockResponse::Text(text) => Ok(LlmResponse {
message: AiMessage {
content: MessageContent::Text(text),
tool_calls: vec![],
invalid_tool_calls: vec![],
usage_metadata: None,
response_metadata: Default::default(),
id: None,
},
provider_metadata: Default::default(),
}),
MockResponse::ToolCall { tool_name, args } => Ok(LlmResponse {
message: AiMessage {
content: MessageContent::Text(String::new()),
tool_calls: vec![ToolCall {
id: format!("call_{}", tool_name),
name: tool_name,
args,
}],
invalid_tool_calls: vec![],
usage_metadata: None,
response_metadata: Default::default(),
id: None,
},
provider_metadata: Default::default(),
}),
MockResponse::Error(e) => Err(e),
}
}
}
impl Default for MockProvider {
fn default() -> Self {
Self::new()
}
}
impl LlmProvider for MockProvider {
    /// Resolves to the next scripted reply; `_messages` and `_tools` are ignored.
    ///
    /// The reply is dequeued when the returned future is first polled, not when
    /// `complete` is called.
    fn complete(
        &self,
        _messages: &[Message],
        _tools: &[ToolSchema],
    ) -> Pin<Box<dyn std::future::Future<Output = Result<LlmResponse, PeError>> + Send + '_>> {
        Box::pin(async move { self.next_response().and_then(Self::mock_response_to_llm) })
    }

    /// Streams the next scripted reply as exactly two chunks.
    ///
    /// The first chunk is a `Token` holding the reply's full text. It is an empty string
    /// when the content has no text, e.g. for tool-call replies. The second chunk is
    /// `Done`, carrying the complete response.
    fn stream(
        &self,
        _messages: &[Message],
        _tools: &[ToolSchema],
    ) -> crate::llm::StreamFuture<'_> {
        Box::pin(async move {
            let response = self
                .next_response()
                .and_then(Self::mock_response_to_llm)?;
            let token = response
                .message
                .content
                .as_text()
                .map(String::from)
                .unwrap_or_default();
            let chunks: Pin<Box<dyn Stream<Item = StreamChunk> + Send>> =
                Box::pin(futures::stream::iter(vec![
                    StreamChunk::Token(token),
                    StreamChunk::Done(response),
                ]));
            Ok(chunks)
        })
    }

    /// Resolves immediately to a copy of the configured embedding; `_text` is ignored.
    fn embed(
        &self,
        _text: &str,
    ) -> Pin<Box<dyn std::future::Future<Output = Result<Vec<f32>, PeError>> + Send + '_>> {
        // Cloned eagerly so the returned future does not need to read `self` later.
        Box::pin(std::future::ready(Ok::<_, PeError>(
            self.embed_response.clone(),
        )))
    }

    fn provider_name(&self) -> &'static str {
        "mock"
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Calls `stream` and drains it into a `Vec`, panicking if the call itself fails.
    async fn collect_stream(provider: &MockProvider) -> Vec<StreamChunk> {
        use futures::StreamExt;
        match provider.stream(&[], &[]).await {
            Ok(stream) => stream.collect().await,
            Err(e) => panic!("stream failed: {:?}", e),
        }
    }

    #[tokio::test]
    async fn test_text_response() {
        let provider = MockProvider::new().respond_with("Hello, world!");
        let resp = provider.complete(&[], &[]).await.unwrap();
        assert_eq!(resp.message.content.as_text(), Some("Hello, world!"));
    }

    #[tokio::test]
    async fn test_tool_call_response() {
        let provider = MockProvider::new()
            .respond_with_tool_call("web_search", serde_json::json!({ "query": "rust async" }));
        let resp = provider.complete(&[], &[]).await.unwrap();
        assert_eq!(resp.message.tool_calls.len(), 1);
        assert_eq!(resp.message.tool_calls[0].name, "web_search");
        // The ID is derived from the tool name.
        assert_eq!(resp.message.tool_calls[0].id, "call_web_search");
    }

    #[tokio::test]
    async fn test_multiple_responses_fifo() {
        let provider = MockProvider::new()
            .respond_with("first")
            .respond_with("second")
            .respond_with("third");
        let r1 = provider.complete(&[], &[]).await.unwrap();
        let r2 = provider.complete(&[], &[]).await.unwrap();
        let r3 = provider.complete(&[], &[]).await.unwrap();
        assert_eq!(r1.message.content.as_text(), Some("first"));
        assert_eq!(r2.message.content.as_text(), Some("second"));
        assert_eq!(r3.message.content.as_text(), Some("third"));
    }

    #[tokio::test]
    async fn test_exhausted_queue_returns_error() {
        let provider = MockProvider::new().respond_with("only one");
        let _ = provider.complete(&[], &[]).await.unwrap();
        let err = provider.complete(&[], &[]).await.unwrap_err();
        assert!(matches!(err, PeError::MockProviderExhausted));
    }

    #[tokio::test]
    async fn test_error_response() {
        let provider = MockProvider::new().respond_with_error(PeError::LlmProvider {
            details: "rate limited".into(),
        });
        let err = provider.complete(&[], &[]).await.unwrap_err();
        assert!(matches!(err, PeError::LlmProvider { .. }));
    }

    #[tokio::test]
    async fn test_embed_returns_configured_vector() {
        let provider = MockProvider::new().with_embedding(vec![1.0, 2.0, 3.0]);
        let embedding = provider.embed("test text").await.unwrap();
        assert_eq!(embedding, vec![1.0, 2.0, 3.0]);
    }

    #[tokio::test]
    async fn test_remaining_count() {
        let provider = MockProvider::new().respond_with("a").respond_with("b");
        assert_eq!(provider.remaining(), 2);
        let _ = provider.complete(&[], &[]).await;
        assert_eq!(provider.remaining(), 1);
    }

    #[tokio::test]
    async fn test_stream_text_response() {
        let provider = MockProvider::new().respond_with("streamed");
        let chunks = collect_stream(&provider).await;
        assert_eq!(chunks.len(), 2);
        assert!(matches!(&chunks[0], StreamChunk::Token(t) if t == "streamed"));
        assert!(matches!(
            &chunks[1],
            StreamChunk::Done(r) if r.message.content.as_text() == Some("streamed")
        ));
        assert_eq!(provider.remaining(), 0);
    }

    #[tokio::test]
    async fn test_stream_tool_call_emits_empty_token_then_done() {
        let provider = MockProvider::new()
            .respond_with_tool_call("web_search", serde_json::json!({ "query": "rust" }));
        let chunks = collect_stream(&provider).await;
        assert_eq!(chunks.len(), 2);
        assert!(matches!(&chunks[0], StreamChunk::Token(t) if t.is_empty()));
        assert!(matches!(
            &chunks[1],
            StreamChunk::Done(r) if r.message.tool_calls.len() == 1
                && r.message.tool_calls[0].name == "web_search"
        ));
    }

    #[tokio::test]
    async fn test_stream_exhausted_queue_returns_error() {
        let provider = MockProvider::new();
        let result = provider.stream(&[], &[]).await;
        assert!(matches!(result, Err(PeError::MockProviderExhausted)));
    }

    #[test]
    fn poisoned_queue_lock_is_recovered() {
        let provider = MockProvider::new().respond_with("hello");
        let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            let _guard = provider.responses.lock().unwrap();
            panic!("poison mock provider");
        }));
        assert!(result.is_err());
        assert_eq!(provider.remaining(), 1);
    }

    #[tokio::test]
    async fn test_object_safety() {
        let provider: Box<dyn LlmProvider> = Box::new(MockProvider::new().respond_with("boxed"));
        let resp = provider.complete(&[], &[]).await.unwrap();
        assert_eq!(resp.message.content.as_text(), Some("boxed"));
    }
}