1use async_trait::async_trait;
4use std::sync::Arc;
5use tokio::sync::Mutex;
6
7use super::provider::{
8 CompletionRequest, CompletionResponse, CompletionStream, LlmProvider, StopReason, TokenUsage,
9};
10use crate::Result;
11
12#[derive(Clone)]
16pub struct MockLlmProvider {
17 responses: Arc<Mutex<MockResponses>>,
18}
19
20struct MockResponses {
21 canned: Vec<String>,
22 index: usize,
23}
24
25impl MockLlmProvider {
26 pub fn new(responses: Vec<String>) -> Self {
42 Self {
43 responses: Arc::new(Mutex::new(MockResponses {
44 canned: responses,
45 index: 0,
46 })),
47 }
48 }
49
50 pub fn with_response(response: impl Into<String>) -> Self {
52 Self::new(vec![response.into()])
53 }
54}
55
56#[async_trait]
57impl LlmProvider for MockLlmProvider {
58 async fn complete(&self, _request: CompletionRequest) -> Result<CompletionResponse> {
59 let mut responses = self.responses.lock().await;
60
61 let content = responses.canned[responses.index].clone();
63
64 responses.index = (responses.index + 1) % responses.canned.len();
66
67 Ok(CompletionResponse {
68 content,
69 tokens_used: TokenUsage {
70 input: 10, output: 20,
72 },
73 stop_reason: StopReason::EndTurn,
74 })
75 }
76
77 async fn complete_streaming(&self, _request: CompletionRequest) -> Result<CompletionStream> {
78 Err(crate::Error::llm(
80 "Streaming not supported in mock provider",
81 ))
82 }
83}
84
// `unwrap` is acceptable here: any failure should abort the test loudly.
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use crate::llm::Message;

    #[tokio::test]
    async fn test_mock_provider_single_response() {
        let provider = MockLlmProvider::with_response("Test response");

        let request = CompletionRequest::new(vec![Message::user("Hello")]);

        let response = provider.complete(request).await.unwrap();
        assert_eq!(response.content, "Test response");
    }

    #[tokio::test]
    async fn test_mock_provider_multiple_responses() {
        let provider = MockLlmProvider::new(vec![
            "First".to_string(),
            "Second".to_string(),
            "Third".to_string(),
        ]);

        let request = CompletionRequest::new(vec![Message::user("Test")]);

        // The fourth call checks that the cursor wraps back to the start.
        for expected in ["First", "Second", "Third", "First"] {
            assert_eq!(
                provider.complete(request.clone()).await.unwrap().content,
                expected
            );
        }
    }

    #[tokio::test]
    async fn test_mock_provider_clone() {
        // Two distinct responses make the shared cursor observable: a reply
        // consumed through one handle must not be replayed by its clone.
        let provider = MockLlmProvider::new(vec!["First".to_string(), "Second".to_string()]);
        let provider2 = provider.clone();

        let request = CompletionRequest::new(vec![Message::user("Test")]);

        let first = provider.complete(request.clone()).await.unwrap();
        assert_eq!(first.content, "First");

        let second = provider2.complete(request).await.unwrap();
        assert_eq!(second.content, "Second");
    }
}