1use anyhow::Result;
2use async_trait::async_trait;
3use serde::{Deserialize, Serialize};
4use sha2::{Digest, Sha256};
5use std::collections::HashMap;
6use std::fs;
7use std::path::Path;
8use std::sync::{Arc, Mutex};
9
10use super::base::{Provider, ProviderMetadata, ProviderUsage};
11use super::errors::ProviderError;
12use crate::conversation::message::Message;
13use crate::model::ModelConfig;
14use rmcp::model::Tool;
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
17struct TestInput {
18 system: String,
19 messages: Vec<Message>,
20 tools: Vec<Tool>,
21}
22
23#[derive(Debug, Clone, Serialize, Deserialize)]
24struct TestOutput {
25 message: Message,
26 usage: ProviderUsage,
27}
28
29#[derive(Debug, Clone, Serialize, Deserialize)]
30struct TestRecord {
31 input: TestInput,
32 output: TestOutput,
33}
34
35pub struct TestProvider {
36 inner: Option<Arc<dyn Provider>>,
37 records: Arc<Mutex<HashMap<String, TestRecord>>>,
38 file_path: String,
39 name: String,
40}
41
42impl TestProvider {
43 pub fn new_recording(inner: Arc<dyn Provider>, file_path: impl Into<String>) -> Self {
44 Self {
45 inner: Some(inner),
46 records: Arc::new(Mutex::new(HashMap::new())),
47 file_path: file_path.into(),
48 name: Self::metadata().name,
49 }
50 }
51
52 pub fn new_replaying(file_path: impl Into<String>) -> Result<Self> {
53 let file_path = file_path.into();
54 let records = Self::load_records(&file_path)?;
55
56 Ok(Self {
57 inner: None,
58 records: Arc::new(Mutex::new(records)),
59 file_path,
60 name: Self::metadata().name,
61 })
62 }
63
64 pub fn finish_recording(self) -> Result<()> {
65 if self.inner.is_some() {
66 self.save_records()?;
67 }
68 Ok(())
69 }
70
71 fn hash_input(messages: &[Message]) -> String {
72 let stable_messages: Vec<_> = messages
73 .iter()
74 .map(|msg| (msg.role.clone(), msg.content.clone()))
75 .collect();
76 let serialized = serde_json::to_string(&stable_messages).unwrap_or_default();
77 let mut hasher = Sha256::new();
78 hasher.update(serialized.as_bytes());
79 format!("{:x}", hasher.finalize())
80 }
81
82 fn load_records(file_path: &str) -> Result<HashMap<String, TestRecord>> {
83 if !Path::new(file_path).exists() {
84 return Ok(HashMap::new());
85 }
86
87 let content = fs::read_to_string(file_path)?;
88 let records: HashMap<String, TestRecord> = serde_json::from_str(&content)?;
89 Ok(records)
90 }
91
92 pub fn save_records(&self) -> Result<()> {
93 let records = self.records.lock().unwrap();
94 let content = serde_json::to_string_pretty(&*records)?;
95 fs::write(&self.file_path, content)?;
96 Ok(())
97 }
98
99 pub fn get_record_count(&self) -> usize {
100 self.records.lock().unwrap().len()
101 }
102}
103
104#[async_trait]
105impl Provider for TestProvider {
106 fn metadata() -> ProviderMetadata {
107 ProviderMetadata::new(
108 "test",
109 "Test Provider",
110 "Provider for testing that can record/replay interactions",
111 "test-model",
112 vec!["test-model"],
113 "",
114 vec![],
115 )
116 }
117
118 fn get_name(&self) -> &str {
119 &self.name
120 }
121
122 async fn complete_with_model(
123 &self,
124 _model_config: &ModelConfig,
125 system: &str,
126 messages: &[Message],
127 tools: &[Tool],
128 ) -> Result<(Message, ProviderUsage), ProviderError> {
129 let hash = Self::hash_input(messages);
130
131 if let Some(inner) = &self.inner {
132 let (message, usage) = inner.complete(system, messages, tools).await?;
133
134 let record = TestRecord {
135 input: TestInput {
136 system: system.to_string(),
137 messages: messages.to_vec(),
138 tools: tools.to_vec(),
139 },
140 output: TestOutput {
141 message: message.clone(),
142 usage: usage.clone(),
143 },
144 };
145
146 {
147 let mut records = self.records.lock().unwrap();
148 records.insert(hash, record);
149 }
150
151 Ok((message, usage))
152 } else {
153 let records = self.records.lock().unwrap();
154 if let Some(record) = records.get(&hash) {
155 Ok((record.output.message.clone(), record.output.usage.clone()))
156 } else {
157 Err(ProviderError::ExecutionError(format!(
158 "No recorded response found for input hash: {}",
159 hash
160 )))
161 }
162 }
163 }
164
165 fn get_model_config(&self) -> ModelConfig {
166 ModelConfig::new_or_fail("test-model")
167 }
168}
169
#[cfg(test)]
mod tests {
    use super::*;
    use crate::conversation::message::{Message, MessageContent};
    use crate::providers::base::{ProviderUsage, Usage};
    use chrono::Utc;
    use rmcp::model::{RawTextContent, Role, TextContent};
    use std::env;

    /// Builds a per-process JSON path in the system temp directory.
    fn temp_records_path(prefix: &str) -> String {
        env::temp_dir()
            .join(format!("{}_{}.json", prefix, std::process::id()))
            .to_string_lossy()
            .into_owned()
    }

    /// Deletes the wrapped file on drop, so fixtures are cleaned up even if an assertion
    /// panics part-way through a test.
    struct TempFileGuard(String);

    impl Drop for TempFileGuard {
        fn drop(&mut self) {
            let _ = fs::remove_file(&self.0);
        }
    }

    /// Returns the text of the message's first content item, failing the test if that item
    /// is not text (rather than silently skipping the assertion).
    fn first_text(message: &Message) -> &str {
        let MessageContent::Text(content) = &message.content[0] else {
            panic!("expected the first content item to be text");
        };
        &content.text
    }

    #[derive(Clone)]
    struct MockProvider {
        model_config: ModelConfig,
        response: String,
    }

    #[async_trait]
    impl Provider for MockProvider {
        fn metadata() -> ProviderMetadata {
            ProviderMetadata::new(
                "mock",
                "Mock Provider",
                "Mock provider for testing",
                "mock-model",
                vec!["mock-model"],
                "",
                vec![],
            )
        }

        fn get_name(&self) -> &str {
            "mock-testprovider"
        }

        async fn complete_with_model(
            &self,
            _model_config: &ModelConfig,
            _system: &str,
            _messages: &[Message],
            _tools: &[Tool],
        ) -> Result<(Message, ProviderUsage), ProviderError> {
            Ok((
                Message::new(
                    Role::Assistant,
                    Utc::now().timestamp(),
                    vec![MessageContent::Text(TextContent {
                        raw: RawTextContent {
                            text: self.response.clone(),
                            meta: None,
                        },
                        annotations: None,
                    })],
                ),
                ProviderUsage::new("mock-model".to_string(), Usage::default()),
            ))
        }

        fn get_model_config(&self) -> ModelConfig {
            self.model_config.clone()
        }
    }

    #[tokio::test]
    async fn test_record_and_replay() {
        let temp_file = temp_records_path("test_records");
        let _cleanup = TempFileGuard(temp_file.clone());

        let mock = Arc::new(MockProvider {
            model_config: ModelConfig::new_or_fail("mock-model"),
            response: "Hello, world!".to_string(),
        });

        {
            let test_provider = TestProvider::new_recording(mock, &temp_file);

            let (message, _) = test_provider
                .complete("You are helpful", &[], &[])
                .await
                .expect("recording completion should succeed");
            assert_eq!(first_text(&message), "Hello, world!");

            assert_eq!(test_provider.get_record_count(), 1);
            test_provider.finish_recording().unwrap();
        }

        {
            let replay_provider = TestProvider::new_replaying(&temp_file).unwrap();
            assert_eq!(replay_provider.get_record_count(), 1);

            let (message, _) = replay_provider
                .complete("You are helpful", &[], &[])
                .await
                .expect("replayed completion should succeed");
            assert_eq!(first_text(&message), "Hello, world!");
        }
    }

    #[tokio::test]
    async fn test_replay_missing_record() {
        let temp_file = temp_records_path("test_missing");
        let _cleanup = TempFileGuard(temp_file.clone());

        let replay_provider = TestProvider::new_replaying(&temp_file).unwrap();
        assert_eq!(replay_provider.get_record_count(), 0);

        let result = replay_provider
            .complete("Different system prompt", &[], &[])
            .await;

        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("No recorded response found"));
    }
}