Skip to main content

paladin_llm/
mock.rs

1//! Mock LLM adapters for testing.
2//!
3//! This module provides [`MockLlmAdapter`] and [`MultiStepMockLlmPort`] for
4//! use in unit and integration tests that need an in-process LLM adapter
5//! without real API calls.
6
7use async_trait::async_trait;
8use chrono::Utc;
9use futures::stream;
10use paladin_ports::output::llm_port::{
11    FinishReason, LlmError, LlmPort, LlmRequest, LlmResponse, ProviderCapabilities,
12    StreamingResponse, TokenUsage,
13};
14use std::collections::HashMap;
15use std::sync::{Arc, Mutex};
16use std::time::Duration;
17use uuid::Uuid;
18
19/// A single mocked response entry — either a success string or an error.
20#[derive(Debug, Clone)]
21enum MockEntry {
22    Success(String),
23    Error(LlmError),
24}
25
26/// Internal state shared between clones of a [`MockLlmAdapter`].
27#[derive(Debug)]
28struct MockState {
29    responses: Vec<MockEntry>,
30    response_index: usize,
31    delay: Option<Duration>,
32    token_usage: TokenUsage,
33    finish_reason: FinishReason,
34    available_models: Vec<String>,
35    call_count: usize,
36}
37
38impl Default for MockState {
39    fn default() -> Self {
40        Self {
41            responses: vec![MockEntry::Success("Mock LLM response".to_string())],
42            response_index: 0,
43            delay: None,
44            token_usage: TokenUsage {
45                prompt_tokens: 10,
46                completion_tokens: 20,
47                total_tokens: 30,
48            },
49            finish_reason: FinishReason::Stop,
50            available_models: vec!["mock-model".to_string()],
51            call_count: 0,
52        }
53    }
54}
55
56/// Configurable mock adapter for the [`LlmPort`] trait.
57///
58/// Suitable for unit tests and integration tests that need an in-process LLM
59/// without making real API calls.
60///
61/// # Example
62///
63/// ```rust
64/// use paladin_llm::mock::MockLlmAdapter;
65///
66/// let adapter = MockLlmAdapter::new()
67///     .with_responses(vec![
68///         "First response".to_string(),
69///         "Second response".to_string(),
70///     ]);
71/// ```
72#[derive(Debug, Clone)]
73pub struct MockLlmAdapter {
74    state: Arc<Mutex<MockState>>,
75}
76
77impl MockLlmAdapter {
78    /// Create a new adapter with default configuration (single success response).
79    pub fn new() -> Self {
80        Self {
81            state: Arc::new(Mutex::new(MockState::default())),
82        }
83    }
84
85    /// Set the sequence of success responses to return (cycling when exhausted).
86    pub fn with_responses(self, responses: Vec<String>) -> Self {
87        let mut state = self.state.lock().unwrap();
88        state.responses = responses.into_iter().map(MockEntry::Success).collect();
89        state.response_index = 0;
90        drop(state);
91        self
92    }
93
94    /// Queue a single success response.
95    pub fn with_response(self, response: impl Into<String>) -> Self {
96        self.with_responses(vec![response.into()])
97    }
98
99    /// Set an error to be returned on the next call.
100    pub fn with_error(self, error: LlmError) -> Self {
101        let mut state = self.state.lock().unwrap();
102        state.responses = vec![MockEntry::Error(error)];
103        state.response_index = 0;
104        drop(state);
105        self
106    }
107
108    /// Simulate network latency by adding a delay before each response.
109    pub fn with_delay(self, delay: Duration) -> Self {
110        self.state.lock().unwrap().delay = Some(delay);
111        self
112    }
113
114    /// Configure the token usage returned with each response using a `TokenUsage` struct.
115    pub fn with_token_usage_struct(self, usage: TokenUsage) -> Self {
116        self.state.lock().unwrap().token_usage = usage;
117        self
118    }
119
120    /// Configure the token usage returned with each response (prompt, completion, total).
121    pub fn with_token_usage(
122        self,
123        prompt_tokens: u32,
124        completion_tokens: u32,
125        total_tokens: u32,
126    ) -> Self {
127        self.state.lock().unwrap().token_usage = TokenUsage {
128            prompt_tokens,
129            completion_tokens,
130            total_tokens,
131        };
132        self
133    }
134
135    /// Configure the `FinishReason` returned with each successful response.
136    pub fn with_finish_reason(self, reason: FinishReason) -> Self {
137        self.state.lock().unwrap().finish_reason = reason;
138        self
139    }
140
141    /// Configure the list of models reported as available.
142    pub fn with_available_models(self, models: Vec<String>) -> Self {
143        self.state.lock().unwrap().available_models = models;
144        self
145    }
146
147    /// Queue an error response followed by a success response.
148    ///
149    /// Useful for testing retry / recovery logic.
150    pub fn with_error_then_response(self, error: LlmError, response: impl Into<String>) -> Self {
151        let mut state = self.state.lock().unwrap();
152        state.responses = vec![MockEntry::Error(error), MockEntry::Success(response.into())];
153        state.response_index = 0;
154        drop(state);
155        self
156    }
157
158    /// Return the number of times [`LlmPort::generate`] has been called.
159    pub fn call_count(&self) -> usize {
160        self.state.lock().unwrap().call_count
161    }
162
163    /// Alias for [`call_count`](Self::call_count) for compatibility.
164    pub fn get_call_count(&self) -> usize {
165        self.call_count()
166    }
167
168    /// Reset the call counter and response index to zero.
169    pub fn reset(&self) {
170        let mut state = self.state.lock().unwrap();
171        state.call_count = 0;
172        state.response_index = 0;
173    }
174
175    /// Return `true` if the adapter was called at least once.
176    pub fn was_called(&self) -> bool {
177        self.call_count() > 0
178    }
179}
180
181impl Default for MockLlmAdapter {
182    fn default() -> Self {
183        Self::new()
184    }
185}
186
187#[async_trait]
188impl LlmPort for MockLlmAdapter {
189    async fn generate(&self, request: LlmRequest) -> Result<LlmResponse, LlmError> {
190        let (response_entry, delay, token_usage, finish_reason) = {
191            let mut state = self.state.lock().unwrap();
192            state.call_count += 1;
193            let index = state.response_index;
194            let entry = state
195                .responses
196                .get(index)
197                .cloned()
198                .unwrap_or(MockEntry::Success("Mock LLM response".to_string()));
199            // Advance index, cycling through responses
200            state.response_index = (index + 1) % state.responses.len().max(1);
201            (
202                entry,
203                state.delay,
204                state.token_usage.clone(),
205                state.finish_reason.clone(),
206            )
207        };
208
209        if let Some(delay) = delay {
210            tokio::time::sleep(delay).await;
211        }
212
213        match response_entry {
214            MockEntry::Error(e) => Err(e),
215            MockEntry::Success(content) => Ok(LlmResponse {
216                id: Uuid::new_v4(),
217                request_id: request.id,
218                model: request.model.clone(),
219                content,
220                finish_reason,
221                usage: token_usage,
222                created_at: Utc::now(),
223                metadata: HashMap::new(),
224                function_call: None,
225            }),
226        }
227    }
228
229    async fn generate_stream(
230        &self,
231        request: LlmRequest,
232    ) -> Result<Box<dyn futures::Stream<Item = Result<StreamingResponse, LlmError>> + Send>, LlmError>
233    {
234        let response = self.generate(request).await?;
235        // Emit the full response as a single streaming chunk, then stop.
236        let chunks = vec![
237            Ok(StreamingResponse {
238                id: Uuid::new_v4(),
239                delta: response.content.clone(),
240                finish_reason: None,
241            }),
242            Ok(StreamingResponse {
243                id: Uuid::new_v4(),
244                delta: String::new(),
245                finish_reason: Some(response.finish_reason),
246            }),
247        ];
248        Ok(Box::new(stream::iter(chunks)))
249    }
250
251    async fn validate_model(&self, model: &str) -> Result<bool, LlmError> {
252        let state = self.state.lock().unwrap();
253        Ok(state.available_models.contains(&model.to_string()))
254    }
255
256    async fn get_available_models(&self) -> Result<Vec<String>, LlmError> {
257        Ok(self.state.lock().unwrap().available_models.clone())
258    }
259
260    fn get_provider_name(&self) -> &'static str {
261        "MockLLM"
262    }
263
264    fn get_capabilities(&self) -> ProviderCapabilities {
265        ProviderCapabilities {
266            supports_streaming: true,
267            supports_tool_calling: false,
268            supports_function_calling: false,
269            supports_vision: false,
270            max_context_tokens: Some(4096),
271            supports_embeddings: false,
272            supports_system_messages: true,
273        }
274    }
275}
276
277// ---------------------------------------------------------------------------
278// MultiStepMockLlmPort
279// ---------------------------------------------------------------------------
280
/// A mock adapter that returns different responses for successive calls.
///
/// Unlike [`MockLlmAdapter`] (which cycles), this adapter walks through its
/// responses once, in order, and never wraps around. It also never panics or
/// errors when over-called. Instead, zero-based call `i` past the end of the
/// list returns the placeholder content `"Mock step {i} response"`. For
/// example, the third call on a two-element list yields
/// `"Mock step 2 response"`. Pair it with [`call_count`](Self::call_count)
/// to assert the *exact* number of LLM calls a test makes.
///
/// # Example
///
/// ```rust
/// use paladin_llm::mock::MultiStepMockLlmPort;
///
/// let adapter = MultiStepMockLlmPort::new(vec![
///     "Step 1 response".to_string(),
///     "Step 2 response".to_string(),
/// ]);
/// ```
#[derive(Debug)]
pub struct MultiStepMockLlmPort {
    /// Scripted response contents, returned in order by call index.
    responses: Vec<String>,
    /// Number of `generate` calls so far; also the index of the next response.
    call_count: Arc<Mutex<usize>>,
}
302
303impl MultiStepMockLlmPort {
304    /// Create a new multi-step mock with the given response sequence.
305    pub fn new(responses: Vec<String>) -> Self {
306        Self {
307            responses,
308            call_count: Arc::new(Mutex::new(0)),
309        }
310    }
311
312    /// Return the number of times [`LlmPort::generate`] has been called.
313    pub fn call_count(&self) -> usize {
314        *self.call_count.lock().unwrap()
315    }
316}
317
318#[async_trait]
319impl LlmPort for MultiStepMockLlmPort {
320    async fn generate(&self, request: LlmRequest) -> Result<LlmResponse, LlmError> {
321        let mut count = self.call_count.lock().unwrap();
322        let index = *count;
323        *count += 1;
324        drop(count);
325
326        let content = self
327            .responses
328            .get(index)
329            .cloned()
330            .unwrap_or_else(|| format!("Mock step {} response", index));
331
332        Ok(LlmResponse {
333            id: Uuid::new_v4(),
334            request_id: request.id,
335            model: request.model.clone(),
336            content,
337            finish_reason: FinishReason::Stop,
338            usage: TokenUsage {
339                prompt_tokens: 10,
340                completion_tokens: 20,
341                total_tokens: 30,
342            },
343            created_at: Utc::now(),
344            metadata: HashMap::new(),
345            function_call: None,
346        })
347    }
348
349    async fn generate_stream(
350        &self,
351        request: LlmRequest,
352    ) -> Result<Box<dyn futures::Stream<Item = Result<StreamingResponse, LlmError>> + Send>, LlmError>
353    {
354        let response = self.generate(request).await?;
355        let chunks = vec![Ok(StreamingResponse {
356            id: Uuid::new_v4(),
357            delta: response.content,
358            finish_reason: Some(FinishReason::Stop),
359        })];
360        Ok(Box::new(stream::iter(chunks)))
361    }
362
363    async fn validate_model(&self, _model: &str) -> Result<bool, LlmError> {
364        Ok(true)
365    }
366
367    async fn get_available_models(&self) -> Result<Vec<String>, LlmError> {
368        Ok(vec!["mock-model".to_string()])
369    }
370
371    fn get_provider_name(&self) -> &'static str {
372        "multi-step-mock"
373    }
374
375    fn get_capabilities(&self) -> ProviderCapabilities {
376        ProviderCapabilities {
377            supports_streaming: true,
378            supports_tool_calling: false,
379            supports_function_calling: false,
380            supports_vision: false,
381            max_context_tokens: Some(4096),
382            supports_embeddings: false,
383            supports_system_messages: true,
384        }
385    }
386}
387
#[cfg(test)]
mod tests {
    // `super::*` already brings in `LlmPort`, `Uuid`, `HashMap`, etc. from
    // the parent module's imports.
    use super::*;
    use futures::StreamExt;
    use paladin_core::platform::container::prompt::{PromptItem, PromptType, UserPrompt};

    /// Build a minimal user-prompt request addressed to `"mock-model"`.
    fn make_request() -> LlmRequest {
        let prompt = PromptItem::new(PromptType::User(UserPrompt {
            query: "test query".to_string(),
            context: None,
        }))
        .unwrap();
        LlmRequest {
            id: Uuid::new_v4(),
            model: "mock-model".to_string(),
            prompt,
            attachments: vec![],
            stream: false,
            metadata: HashMap::new(),
        }
    }

    #[tokio::test]
    async fn test_mock_returns_default_response() {
        let adapter = MockLlmAdapter::new();
        let request = make_request();
        let response = adapter.generate(request).await.unwrap();
        assert_eq!(response.content, "Mock LLM response");
    }

    #[tokio::test]
    async fn test_mock_cycles_responses() {
        let adapter =
            MockLlmAdapter::new().with_responses(vec!["First".to_string(), "Second".to_string()]);
        let r1 = adapter.generate(make_request()).await.unwrap();
        let r2 = adapter.generate(make_request()).await.unwrap();
        let r3 = adapter.generate(make_request()).await.unwrap(); // cycles back
        assert_eq!(r1.content, "First");
        assert_eq!(r2.content, "Second");
        assert_eq!(r3.content, "First");
    }

    #[tokio::test]
    async fn test_mock_tracks_call_count() {
        let adapter = MockLlmAdapter::new();
        assert_eq!(adapter.call_count(), 0);
        adapter.generate(make_request()).await.unwrap();
        assert_eq!(adapter.call_count(), 1);
    }

    #[tokio::test]
    async fn test_mock_returns_error() {
        let adapter = MockLlmAdapter::new().with_error(LlmError::RateLimitExceeded);
        let result = adapter.generate(make_request()).await;
        assert!(matches!(result, Err(LlmError::RateLimitExceeded)));
    }

    #[tokio::test]
    async fn test_mock_empty_responses_fall_back_to_default() {
        let adapter = MockLlmAdapter::new().with_responses(vec![]);
        let response = adapter.generate(make_request()).await.unwrap();
        assert_eq!(response.content, "Mock LLM response");
    }

    #[tokio::test]
    async fn test_mock_error_then_response_alternates() {
        let adapter = MockLlmAdapter::new()
            .with_error_then_response(LlmError::RateLimitExceeded, "Recovered");
        let first = adapter.generate(make_request()).await;
        assert!(matches!(first, Err(LlmError::RateLimitExceeded)));
        let second = adapter.generate(make_request()).await.unwrap();
        assert_eq!(second.content, "Recovered");
        // The pair cycles, so the third call fails again.
        let third = adapter.generate(make_request()).await;
        assert!(matches!(third, Err(LlmError::RateLimitExceeded)));
        // Failed calls are counted too.
        assert_eq!(adapter.call_count(), 3);
    }

    #[tokio::test]
    async fn test_mock_reset_rewinds_sequence_and_count() {
        let adapter =
            MockLlmAdapter::new().with_responses(vec!["A".to_string(), "B".to_string()]);
        adapter.generate(make_request()).await.unwrap();
        adapter.reset();
        assert_eq!(adapter.call_count(), 0);
        assert!(!adapter.was_called());
        let response = adapter.generate(make_request()).await.unwrap();
        assert_eq!(response.content, "A");
    }

    #[tokio::test]
    async fn test_mock_reports_configured_usage() {
        let adapter = MockLlmAdapter::new().with_token_usage(1, 2, 3);
        let response = adapter.generate(make_request()).await.unwrap();
        assert_eq!(response.usage.prompt_tokens, 1);
        assert_eq!(response.usage.completion_tokens, 2);
        assert_eq!(response.usage.total_tokens, 3);
    }

    #[tokio::test]
    async fn test_mock_validate_model_uses_available_models() {
        let adapter = MockLlmAdapter::new().with_available_models(vec!["gpt-x".to_string()]);
        assert!(adapter.validate_model("gpt-x").await.unwrap());
        assert!(!adapter.validate_model("mock-model").await.unwrap());
    }

    #[tokio::test]
    async fn test_mock_stream_emits_content_then_finish() {
        let adapter = MockLlmAdapter::new().with_response("Streamed");
        let stream = adapter.generate_stream(make_request()).await.unwrap();
        // The boxed trait object is not `Unpin`; pin it so `StreamExt` applies.
        let chunks: Vec<_> = Box::into_pin(stream).collect().await;
        assert_eq!(chunks.len(), 2);
        let first = chunks[0].as_ref().unwrap();
        assert_eq!(first.delta, "Streamed");
        assert!(first.finish_reason.is_none());
        let last = chunks[1].as_ref().unwrap();
        assert!(last.delta.is_empty());
        assert!(matches!(last.finish_reason, Some(FinishReason::Stop)));
        assert_eq!(adapter.call_count(), 1);
    }

    #[tokio::test]
    async fn test_multi_step_returns_sequence() {
        let adapter = MultiStepMockLlmPort::new(vec![
            "Step 1".to_string(),
            "Step 2".to_string(),
            "Step 3".to_string(),
        ]);
        let r1 = adapter.generate(make_request()).await.unwrap();
        let r2 = adapter.generate(make_request()).await.unwrap();
        let r3 = adapter.generate(make_request()).await.unwrap();
        assert_eq!(r1.content, "Step 1");
        assert_eq!(r2.content, "Step 2");
        assert_eq!(r3.content, "Step 3");
    }

    #[tokio::test]
    async fn test_multi_step_tracks_call_count() {
        let adapter = MultiStepMockLlmPort::new(vec!["A".to_string()]);
        assert_eq!(adapter.call_count(), 0);
        adapter.generate(make_request()).await.unwrap();
        assert_eq!(adapter.call_count(), 1);
    }

    #[tokio::test]
    async fn test_multi_step_falls_back_after_sequence() {
        let adapter = MultiStepMockLlmPort::new(vec!["Only".to_string()]);
        let r1 = adapter.generate(make_request()).await.unwrap();
        // Past the end: no panic, no wrap-around, just a placeholder.
        let r2 = adapter.generate(make_request()).await.unwrap();
        assert_eq!(r1.content, "Only");
        assert_eq!(r2.content, "Mock step 1 response");
        assert_eq!(adapter.call_count(), 2);
    }
}
464        assert_eq!(adapter.call_count(), 0);
465        adapter.generate(make_request()).await.unwrap();
466        assert_eq!(adapter.call_count(), 1);
467    }
468}