Skip to main content

pe_core/
mock_provider.rs

1//! MockProvider — deterministic LLM provider for testing.
2//!
3//! Responses are consumed from a queue in FIFO order.
4//! Returns an error if the queue is exhausted.
5
6use std::collections::VecDeque;
7use std::pin::Pin;
8use std::sync::{Mutex, MutexGuard};
9
10use futures::Stream;
11
12use crate::error::PeError;
13use crate::llm::{LlmProvider, LlmResponse, StreamChunk, ToolSchema};
14use crate::message::{AiMessage, Message, MessageContent, ToolCall};
15
16/// What the mock should return for a given call.
17#[derive(Debug, Clone)]
18enum MockResponse {
19    /// Return an AiMessage with text content.
20    Text(String),
21    /// Return an AiMessage with a tool call.
22    ToolCall {
23        tool_name: String,
24        args: serde_json::Value,
25    },
26    /// Return an error.
27    Error(PeError),
28}
29
30/// Deterministic LLM provider for tests.
31///
32/// Queue responses with the builder API, then call `complete()` to
33/// consume them in FIFO order.
34///
35/// # Example
36///
37/// ```
38/// use pe_core::mock_provider::MockProvider;
39/// use pe_core::llm::LlmProvider;
40///
41/// # tokio_test::block_on(async {
42/// let provider = MockProvider::new()
43///     .respond_with("Hello!")
44///     .respond_with("Goodbye!");
45///
46/// let r1 = provider.complete(&[], &[]).await.unwrap();
47/// assert_eq!(r1.message.content.as_text(), Some("Hello!"));
48///
49/// let r2 = provider.complete(&[], &[]).await.unwrap();
50/// assert_eq!(r2.message.content.as_text(), Some("Goodbye!"));
51/// # });
52/// ```
53pub struct MockProvider {
54    responses: Mutex<VecDeque<MockResponse>>,
55    embed_response: Vec<f32>,
56}
57
58impl MockProvider {
59    fn responses_guard(&self) -> MutexGuard<'_, VecDeque<MockResponse>> {
60        match self.responses.lock() {
61            Ok(guard) => guard,
62            Err(poisoned) => poisoned.into_inner(),
63        }
64    }
65
66    /// Create a new MockProvider with an empty response queue.
67    pub fn new() -> Self {
68        Self {
69            responses: Mutex::new(VecDeque::new()),
70            embed_response: vec![0.0; 128], // default 128-dim zero vector
71        }
72    }
73
74    /// Queue a plain text response.
75    #[must_use = "builder methods return the modified builder"]
76    pub fn respond_with(self, text: impl Into<String>) -> Self {
77        self.responses_guard()
78            .push_back(MockResponse::Text(text.into()));
79        self
80    }
81
82    /// Queue a tool call response.
83    #[must_use = "builder methods return the modified builder"]
84    pub fn respond_with_tool_call(
85        self,
86        tool_name: impl Into<String>,
87        args: serde_json::Value,
88    ) -> Self {
89        self.responses_guard().push_back(MockResponse::ToolCall {
90            tool_name: tool_name.into(),
91            args,
92        });
93        self
94    }
95
96    /// Queue an error response.
97    #[must_use = "builder methods return the modified builder"]
98    pub fn respond_with_error(self, err: PeError) -> Self {
99        self.responses_guard().push_back(MockResponse::Error(err));
100        self
101    }
102
103    /// Set the embedding to return for all `embed()` calls.
104    #[must_use = "builder methods return the modified builder"]
105    pub fn with_embedding(mut self, embedding: Vec<f32>) -> Self {
106        self.embed_response = embedding;
107        self
108    }
109
110    /// Number of responses remaining in the queue.
111    pub fn remaining(&self) -> usize {
112        self.responses_guard().len()
113    }
114
115    fn next_response(&self) -> Result<MockResponse, PeError> {
116        self.responses_guard()
117            .pop_front()
118            .ok_or(PeError::MockProviderExhausted)
119    }
120
121    fn mock_response_to_llm(resp: MockResponse) -> Result<LlmResponse, PeError> {
122        match resp {
123            MockResponse::Text(text) => Ok(LlmResponse {
124                message: AiMessage {
125                    content: MessageContent::Text(text),
126                    tool_calls: vec![],
127                    invalid_tool_calls: vec![],
128                    usage_metadata: None,
129                    response_metadata: Default::default(),
130                    id: None,
131                },
132                provider_metadata: Default::default(),
133            }),
134            MockResponse::ToolCall { tool_name, args } => Ok(LlmResponse {
135                message: AiMessage {
136                    content: MessageContent::Text(String::new()),
137                    tool_calls: vec![ToolCall {
138                        id: format!("call_{}", tool_name),
139                        name: tool_name,
140                        args,
141                    }],
142                    invalid_tool_calls: vec![],
143                    usage_metadata: None,
144                    response_metadata: Default::default(),
145                    id: None,
146                },
147                provider_metadata: Default::default(),
148            }),
149            MockResponse::Error(e) => Err(e),
150        }
151    }
152}
153
154impl Default for MockProvider {
155    fn default() -> Self {
156        Self::new()
157    }
158}
159
160impl LlmProvider for MockProvider {
161    fn complete(
162        &self,
163        _messages: &[Message],
164        _tools: &[ToolSchema],
165    ) -> Pin<Box<dyn std::future::Future<Output = Result<LlmResponse, PeError>> + Send + '_>> {
166        Box::pin(async move {
167            let resp = self.next_response()?;
168            Self::mock_response_to_llm(resp)
169        })
170    }
171
172    fn stream(&self, _messages: &[Message], _tools: &[ToolSchema]) -> crate::llm::StreamFuture<'_> {
173        Box::pin(async move {
174            let resp = self.next_response()?;
175            let llm_resp = Self::mock_response_to_llm(resp)?;
176
177            // For mock streaming, emit the full text as one token then Done
178            let text = llm_resp.message.content.as_text().unwrap_or("").to_string();
179            let chunks = vec![StreamChunk::Token(text), StreamChunk::Done(llm_resp)];
180
181            Ok(Box::pin(futures::stream::iter(chunks))
182                as Pin<Box<dyn Stream<Item = StreamChunk> + Send>>)
183        })
184    }
185
186    fn embed(
187        &self,
188        _text: &str,
189    ) -> Pin<Box<dyn std::future::Future<Output = Result<Vec<f32>, PeError>> + Send + '_>> {
190        let embedding = self.embed_response.clone();
191        Box::pin(async move { Ok(embedding) })
192    }
193
194    fn provider_name(&self) -> &'static str {
195        "mock"
196    }
197}
198
#[cfg(test)]
mod tests {
    use super::*;
    use futures::StreamExt;

    #[tokio::test]
    async fn test_text_response() {
        let provider = MockProvider::new().respond_with("Hello, world!");

        let resp = provider.complete(&[], &[]).await.unwrap();
        assert_eq!(resp.message.content.as_text(), Some("Hello, world!"));
    }

    #[tokio::test]
    async fn test_tool_call_response() {
        let args = serde_json::json!({ "query": "rust async" });
        let provider = MockProvider::new().respond_with_tool_call("web_search", args.clone());

        let resp = provider.complete(&[], &[]).await.unwrap();
        assert_eq!(resp.message.tool_calls.len(), 1);
        assert_eq!(resp.message.tool_calls[0].name, "web_search");
        // The id is derived from the tool name, and args pass through unchanged.
        assert_eq!(resp.message.tool_calls[0].id, "call_web_search");
        assert_eq!(resp.message.tool_calls[0].args, args);
    }

    #[tokio::test]
    async fn test_multiple_responses_fifo() {
        let provider = MockProvider::new()
            .respond_with("first")
            .respond_with("second")
            .respond_with("third");

        let r1 = provider.complete(&[], &[]).await.unwrap();
        let r2 = provider.complete(&[], &[]).await.unwrap();
        let r3 = provider.complete(&[], &[]).await.unwrap();

        assert_eq!(r1.message.content.as_text(), Some("first"));
        assert_eq!(r2.message.content.as_text(), Some("second"));
        assert_eq!(r3.message.content.as_text(), Some("third"));
    }

    #[tokio::test]
    async fn test_exhausted_queue_returns_error() {
        let provider = MockProvider::new().respond_with("only one");

        let _ = provider.complete(&[], &[]).await.unwrap();
        let err = provider.complete(&[], &[]).await.unwrap_err();

        assert!(matches!(err, PeError::MockProviderExhausted));
    }

    #[tokio::test]
    async fn test_error_response() {
        let provider = MockProvider::new().respond_with_error(PeError::LlmProvider {
            details: "rate limited".into(),
        });

        let err = provider.complete(&[], &[]).await.unwrap_err();
        assert!(matches!(err, PeError::LlmProvider { .. }));
    }

    #[tokio::test]
    async fn test_error_consumes_exactly_one_slot() {
        let provider = MockProvider::new()
            .respond_with_error(PeError::LlmProvider {
                details: "transient".into(),
            })
            .respond_with("recovered");

        assert!(provider.complete(&[], &[]).await.is_err());
        assert_eq!(provider.remaining(), 1);

        let resp = provider.complete(&[], &[]).await.unwrap();
        assert_eq!(resp.message.content.as_text(), Some("recovered"));
    }

    #[tokio::test]
    async fn test_embed_returns_configured_vector() {
        let provider = MockProvider::new().with_embedding(vec![1.0, 2.0, 3.0]);

        let embedding = provider.embed("test text").await.unwrap();
        assert_eq!(embedding, vec![1.0, 2.0, 3.0]);
    }

    #[tokio::test]
    async fn test_default_embedding_is_128_zeros_and_does_not_consume_queue() {
        let provider = MockProvider::default().respond_with("untouched");

        let embedding = provider.embed("anything").await.unwrap();
        assert_eq!(embedding, vec![0.0_f32; 128]);
        assert_eq!(provider.remaining(), 1);
    }

    #[tokio::test]
    async fn test_remaining_count() {
        let provider = MockProvider::new().respond_with("a").respond_with("b");

        assert_eq!(provider.remaining(), 2);
        let _ = provider.complete(&[], &[]).await;
        assert_eq!(provider.remaining(), 1);
    }

    #[tokio::test]
    async fn test_stream_emits_token_then_done() {
        let provider = MockProvider::new().respond_with("streamed");

        let chunks: Vec<StreamChunk> = provider.stream(&[], &[]).await.unwrap().collect().await;

        assert_eq!(chunks.len(), 2);
        assert!(matches!(&chunks[0], StreamChunk::Token(t) if t == "streamed"));
        assert!(matches!(
            &chunks[1],
            StreamChunk::Done(resp) if resp.message.content.as_text() == Some("streamed")
        ));
        assert_eq!(provider.remaining(), 0);
    }

    #[tokio::test]
    async fn test_stream_and_complete_share_queue() {
        let provider = MockProvider::new().respond_with("first").respond_with("second");

        let r1 = provider.complete(&[], &[]).await.unwrap();
        assert_eq!(r1.message.content.as_text(), Some("first"));

        let chunks: Vec<StreamChunk> = provider.stream(&[], &[]).await.unwrap().collect().await;
        assert!(matches!(&chunks[0], StreamChunk::Token(t) if t == "second"));
    }

    #[tokio::test]
    async fn test_stream_on_exhausted_queue_returns_error() {
        let provider = MockProvider::new();

        // The Ok type (a boxed stream) is not `Debug`, so match instead of `unwrap_err`.
        let result = provider.stream(&[], &[]).await;
        assert!(matches!(result, Err(PeError::MockProviderExhausted)));
    }

    #[test]
    fn test_provider_name() {
        assert_eq!(MockProvider::new().provider_name(), "mock");
    }

    #[test]
    fn poisoned_queue_lock_is_recovered() {
        let provider = MockProvider::new().respond_with("hello");

        let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            let _guard = provider.responses.lock().unwrap();
            panic!("poison mock provider");
        }));
        assert!(result.is_err());

        assert_eq!(provider.remaining(), 1);
    }

    #[tokio::test]
    async fn test_object_safety() {
        // Verify LlmProvider is object-safe
        let provider: Box<dyn LlmProvider> = Box::new(MockProvider::new().respond_with("boxed"));
        let resp = provider.complete(&[], &[]).await.unwrap();
        assert_eq!(resp.message.content.as_text(), Some("boxed"));
    }
}