1use async_trait::async_trait;
8use chrono::Utc;
9use futures::stream;
10use paladin_ports::output::llm_port::{
11 FinishReason, LlmError, LlmPort, LlmRequest, LlmResponse, ProviderCapabilities,
12 StreamingResponse, TokenUsage,
13};
14use std::collections::HashMap;
15use std::sync::{Arc, Mutex};
16use std::time::Duration;
17use uuid::Uuid;
18
19#[derive(Debug, Clone)]
21enum MockEntry {
22 Success(String),
23 Error(LlmError),
24}
25
26#[derive(Debug)]
28struct MockState {
29 responses: Vec<MockEntry>,
30 response_index: usize,
31 delay: Option<Duration>,
32 token_usage: TokenUsage,
33 finish_reason: FinishReason,
34 available_models: Vec<String>,
35 call_count: usize,
36}
37
38impl Default for MockState {
39 fn default() -> Self {
40 Self {
41 responses: vec![MockEntry::Success("Mock LLM response".to_string())],
42 response_index: 0,
43 delay: None,
44 token_usage: TokenUsage {
45 prompt_tokens: 10,
46 completion_tokens: 20,
47 total_tokens: 30,
48 },
49 finish_reason: FinishReason::Stop,
50 available_models: vec!["mock-model".to_string()],
51 call_count: 0,
52 }
53 }
54}
55
56#[derive(Debug, Clone)]
73pub struct MockLlmAdapter {
74 state: Arc<Mutex<MockState>>,
75}
76
77impl MockLlmAdapter {
78 pub fn new() -> Self {
80 Self {
81 state: Arc::new(Mutex::new(MockState::default())),
82 }
83 }
84
85 pub fn with_responses(self, responses: Vec<String>) -> Self {
87 let mut state = self.state.lock().unwrap();
88 state.responses = responses.into_iter().map(MockEntry::Success).collect();
89 state.response_index = 0;
90 drop(state);
91 self
92 }
93
94 pub fn with_response(self, response: impl Into<String>) -> Self {
96 self.with_responses(vec![response.into()])
97 }
98
99 pub fn with_error(self, error: LlmError) -> Self {
101 let mut state = self.state.lock().unwrap();
102 state.responses = vec![MockEntry::Error(error)];
103 state.response_index = 0;
104 drop(state);
105 self
106 }
107
108 pub fn with_delay(self, delay: Duration) -> Self {
110 self.state.lock().unwrap().delay = Some(delay);
111 self
112 }
113
114 pub fn with_token_usage_struct(self, usage: TokenUsage) -> Self {
116 self.state.lock().unwrap().token_usage = usage;
117 self
118 }
119
120 pub fn with_token_usage(
122 self,
123 prompt_tokens: u32,
124 completion_tokens: u32,
125 total_tokens: u32,
126 ) -> Self {
127 self.state.lock().unwrap().token_usage = TokenUsage {
128 prompt_tokens,
129 completion_tokens,
130 total_tokens,
131 };
132 self
133 }
134
135 pub fn with_finish_reason(self, reason: FinishReason) -> Self {
137 self.state.lock().unwrap().finish_reason = reason;
138 self
139 }
140
141 pub fn with_available_models(self, models: Vec<String>) -> Self {
143 self.state.lock().unwrap().available_models = models;
144 self
145 }
146
147 pub fn with_error_then_response(self, error: LlmError, response: impl Into<String>) -> Self {
151 let mut state = self.state.lock().unwrap();
152 state.responses = vec![MockEntry::Error(error), MockEntry::Success(response.into())];
153 state.response_index = 0;
154 drop(state);
155 self
156 }
157
158 pub fn call_count(&self) -> usize {
160 self.state.lock().unwrap().call_count
161 }
162
163 pub fn get_call_count(&self) -> usize {
165 self.call_count()
166 }
167
168 pub fn reset(&self) {
170 let mut state = self.state.lock().unwrap();
171 state.call_count = 0;
172 state.response_index = 0;
173 }
174
175 pub fn was_called(&self) -> bool {
177 self.call_count() > 0
178 }
179}
180
181impl Default for MockLlmAdapter {
182 fn default() -> Self {
183 Self::new()
184 }
185}
186
187#[async_trait]
188impl LlmPort for MockLlmAdapter {
189 async fn generate(&self, request: LlmRequest) -> Result<LlmResponse, LlmError> {
190 let (response_entry, delay, token_usage, finish_reason) = {
191 let mut state = self.state.lock().unwrap();
192 state.call_count += 1;
193 let index = state.response_index;
194 let entry = state
195 .responses
196 .get(index)
197 .cloned()
198 .unwrap_or(MockEntry::Success("Mock LLM response".to_string()));
199 state.response_index = (index + 1) % state.responses.len().max(1);
201 (
202 entry,
203 state.delay,
204 state.token_usage.clone(),
205 state.finish_reason.clone(),
206 )
207 };
208
209 if let Some(delay) = delay {
210 tokio::time::sleep(delay).await;
211 }
212
213 match response_entry {
214 MockEntry::Error(e) => Err(e),
215 MockEntry::Success(content) => Ok(LlmResponse {
216 id: Uuid::new_v4(),
217 request_id: request.id,
218 model: request.model.clone(),
219 content,
220 finish_reason,
221 usage: token_usage,
222 created_at: Utc::now(),
223 metadata: HashMap::new(),
224 function_call: None,
225 }),
226 }
227 }
228
229 async fn generate_stream(
230 &self,
231 request: LlmRequest,
232 ) -> Result<Box<dyn futures::Stream<Item = Result<StreamingResponse, LlmError>> + Send>, LlmError>
233 {
234 let response = self.generate(request).await?;
235 let chunks = vec![
237 Ok(StreamingResponse {
238 id: Uuid::new_v4(),
239 delta: response.content.clone(),
240 finish_reason: None,
241 }),
242 Ok(StreamingResponse {
243 id: Uuid::new_v4(),
244 delta: String::new(),
245 finish_reason: Some(response.finish_reason),
246 }),
247 ];
248 Ok(Box::new(stream::iter(chunks)))
249 }
250
251 async fn validate_model(&self, model: &str) -> Result<bool, LlmError> {
252 let state = self.state.lock().unwrap();
253 Ok(state.available_models.contains(&model.to_string()))
254 }
255
256 async fn get_available_models(&self) -> Result<Vec<String>, LlmError> {
257 Ok(self.state.lock().unwrap().available_models.clone())
258 }
259
260 fn get_provider_name(&self) -> &'static str {
261 "MockLLM"
262 }
263
264 fn get_capabilities(&self) -> ProviderCapabilities {
265 ProviderCapabilities {
266 supports_streaming: true,
267 supports_tool_calling: false,
268 supports_function_calling: false,
269 supports_vision: false,
270 max_context_tokens: Some(4096),
271 supports_embeddings: false,
272 supports_system_messages: true,
273 }
274 }
275}
276
/// Mock LLM port that hands out a fixed list of replies, one per call, in
/// order.
#[derive(Debug)]
pub struct MultiStepMockLlmPort {
    /// Replies indexed by call number (the first call gets element 0).
    responses: Vec<String>,
    /// Number of `generate` calls made so far.
    call_count: Arc<Mutex<usize>>,
}
302
303impl MultiStepMockLlmPort {
304 pub fn new(responses: Vec<String>) -> Self {
306 Self {
307 responses,
308 call_count: Arc::new(Mutex::new(0)),
309 }
310 }
311
312 pub fn call_count(&self) -> usize {
314 *self.call_count.lock().unwrap()
315 }
316}
317
318#[async_trait]
319impl LlmPort for MultiStepMockLlmPort {
320 async fn generate(&self, request: LlmRequest) -> Result<LlmResponse, LlmError> {
321 let mut count = self.call_count.lock().unwrap();
322 let index = *count;
323 *count += 1;
324 drop(count);
325
326 let content = self
327 .responses
328 .get(index)
329 .cloned()
330 .unwrap_or_else(|| format!("Mock step {} response", index));
331
332 Ok(LlmResponse {
333 id: Uuid::new_v4(),
334 request_id: request.id,
335 model: request.model.clone(),
336 content,
337 finish_reason: FinishReason::Stop,
338 usage: TokenUsage {
339 prompt_tokens: 10,
340 completion_tokens: 20,
341 total_tokens: 30,
342 },
343 created_at: Utc::now(),
344 metadata: HashMap::new(),
345 function_call: None,
346 })
347 }
348
349 async fn generate_stream(
350 &self,
351 request: LlmRequest,
352 ) -> Result<Box<dyn futures::Stream<Item = Result<StreamingResponse, LlmError>> + Send>, LlmError>
353 {
354 let response = self.generate(request).await?;
355 let chunks = vec![Ok(StreamingResponse {
356 id: Uuid::new_v4(),
357 delta: response.content,
358 finish_reason: Some(FinishReason::Stop),
359 })];
360 Ok(Box::new(stream::iter(chunks)))
361 }
362
363 async fn validate_model(&self, _model: &str) -> Result<bool, LlmError> {
364 Ok(true)
365 }
366
367 async fn get_available_models(&self) -> Result<Vec<String>, LlmError> {
368 Ok(vec!["mock-model".to_string()])
369 }
370
371 fn get_provider_name(&self) -> &'static str {
372 "multi-step-mock"
373 }
374
375 fn get_capabilities(&self) -> ProviderCapabilities {
376 ProviderCapabilities {
377 supports_streaming: true,
378 supports_tool_calling: false,
379 supports_function_calling: false,
380 supports_vision: false,
381 max_context_tokens: Some(4096),
382 supports_embeddings: false,
383 supports_system_messages: true,
384 }
385 }
386}
387
#[cfg(test)]
mod tests {
    use super::*;
    use paladin_core::platform::container::prompt::{PromptItem, PromptType, UserPrompt};
    use paladin_ports::output::llm_port::LlmPort;
    use uuid::Uuid;

    /// Builds a minimal non-streaming request addressed to `mock-model`.
    fn make_request() -> LlmRequest {
        let user_prompt = UserPrompt {
            query: "test query".to_string(),
            context: None,
        };
        let prompt = PromptItem::new(PromptType::User(user_prompt)).unwrap();

        LlmRequest {
            id: Uuid::new_v4(),
            model: "mock-model".to_string(),
            prompt,
            attachments: Vec::new(),
            stream: false,
            metadata: HashMap::new(),
        }
    }

    // An unconfigured adapter answers with the built-in default text.
    #[tokio::test]
    async fn test_mock_returns_default_response() {
        let mock = MockLlmAdapter::new();
        let reply = mock.generate(make_request()).await.unwrap();
        assert_eq!(reply.content, "Mock LLM response");
    }

    // After the last scripted reply the adapter wraps around to the first.
    #[tokio::test]
    async fn test_mock_cycles_responses() {
        let mock =
            MockLlmAdapter::new().with_responses(vec!["First".to_string(), "Second".to_string()]);

        for expected in ["First", "Second", "First"] {
            let reply = mock.generate(make_request()).await.unwrap();
            assert_eq!(reply.content, expected);
        }
    }

    // Each `generate` call is counted.
    #[tokio::test]
    async fn test_mock_tracks_call_count() {
        let mock = MockLlmAdapter::new();
        assert_eq!(mock.call_count(), 0);

        mock.generate(make_request()).await.unwrap();
        assert_eq!(mock.call_count(), 1);
    }

    // A scripted error is surfaced to the caller unchanged.
    #[tokio::test]
    async fn test_mock_returns_error() {
        let mock = MockLlmAdapter::new().with_error(LlmError::RateLimitExceeded);
        let outcome = mock.generate(make_request()).await;
        assert!(matches!(outcome, Err(LlmError::RateLimitExceeded)));
    }

    // The multi-step port hands out its replies in order, one per call.
    #[tokio::test]
    async fn test_multi_step_returns_sequence() {
        let steps = ["Step 1", "Step 2", "Step 3"];
        let port = MultiStepMockLlmPort::new(steps.iter().map(|s| s.to_string()).collect());

        for expected in steps {
            let reply = port.generate(make_request()).await.unwrap();
            assert_eq!(reply.content, expected);
        }
    }

    // Each `generate` call on the multi-step port is counted.
    #[tokio::test]
    async fn test_multi_step_tracks_call_count() {
        let port = MultiStepMockLlmPort::new(vec!["A".to_string()]);
        assert_eq!(port.call_count(), 0);

        port.generate(make_request()).await.unwrap();
        assert_eq!(port.call_count(), 1);
    }
}