Skip to main content

aster/providers/
testprovider.rs

1use anyhow::Result;
2use async_trait::async_trait;
3use serde::{Deserialize, Serialize};
4use sha2::{Digest, Sha256};
5use std::collections::HashMap;
6use std::fs;
7use std::path::Path;
8use std::sync::{Arc, Mutex};
9
10use super::base::{Provider, ProviderMetadata, ProviderUsage};
11use super::errors::ProviderError;
12use crate::conversation::message::Message;
13use crate::model::ModelConfig;
14use rmcp::model::Tool;
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
17struct TestInput {
18    system: String,
19    messages: Vec<Message>,
20    tools: Vec<Tool>,
21}
22
23#[derive(Debug, Clone, Serialize, Deserialize)]
24struct TestOutput {
25    message: Message,
26    usage: ProviderUsage,
27}
28
29#[derive(Debug, Clone, Serialize, Deserialize)]
30struct TestRecord {
31    input: TestInput,
32    output: TestOutput,
33}
34
35pub struct TestProvider {
36    inner: Option<Arc<dyn Provider>>,
37    records: Arc<Mutex<HashMap<String, TestRecord>>>,
38    file_path: String,
39    name: String,
40}
41
42impl TestProvider {
43    pub fn new_recording(inner: Arc<dyn Provider>, file_path: impl Into<String>) -> Self {
44        Self {
45            inner: Some(inner),
46            records: Arc::new(Mutex::new(HashMap::new())),
47            file_path: file_path.into(),
48            name: Self::metadata().name,
49        }
50    }
51
52    pub fn new_replaying(file_path: impl Into<String>) -> Result<Self> {
53        let file_path = file_path.into();
54        let records = Self::load_records(&file_path)?;
55
56        Ok(Self {
57            inner: None,
58            records: Arc::new(Mutex::new(records)),
59            file_path,
60            name: Self::metadata().name,
61        })
62    }
63
64    pub fn finish_recording(self) -> Result<()> {
65        if self.inner.is_some() {
66            self.save_records()?;
67        }
68        Ok(())
69    }
70
71    fn hash_input(messages: &[Message]) -> String {
72        let stable_messages: Vec<_> = messages
73            .iter()
74            .map(|msg| (msg.role.clone(), msg.content.clone()))
75            .collect();
76        let serialized = serde_json::to_string(&stable_messages).unwrap_or_default();
77        let mut hasher = Sha256::new();
78        hasher.update(serialized.as_bytes());
79        format!("{:x}", hasher.finalize())
80    }
81
82    fn load_records(file_path: &str) -> Result<HashMap<String, TestRecord>> {
83        if !Path::new(file_path).exists() {
84            return Ok(HashMap::new());
85        }
86
87        let content = fs::read_to_string(file_path)?;
88        let records: HashMap<String, TestRecord> = serde_json::from_str(&content)?;
89        Ok(records)
90    }
91
92    pub fn save_records(&self) -> Result<()> {
93        let records = self.records.lock().unwrap();
94        let content = serde_json::to_string_pretty(&*records)?;
95        fs::write(&self.file_path, content)?;
96        Ok(())
97    }
98
99    pub fn get_record_count(&self) -> usize {
100        self.records.lock().unwrap().len()
101    }
102}
103
104#[async_trait]
105impl Provider for TestProvider {
106    fn metadata() -> ProviderMetadata {
107        ProviderMetadata::new(
108            "test",
109            "Test Provider",
110            "Provider for testing that can record/replay interactions",
111            "test-model",
112            vec!["test-model"],
113            "",
114            vec![],
115        )
116    }
117
118    fn get_name(&self) -> &str {
119        &self.name
120    }
121
122    async fn complete_with_model(
123        &self,
124        _model_config: &ModelConfig,
125        system: &str,
126        messages: &[Message],
127        tools: &[Tool],
128    ) -> Result<(Message, ProviderUsage), ProviderError> {
129        let hash = Self::hash_input(messages);
130
131        if let Some(inner) = &self.inner {
132            let (message, usage) = inner.complete(system, messages, tools).await?;
133
134            let record = TestRecord {
135                input: TestInput {
136                    system: system.to_string(),
137                    messages: messages.to_vec(),
138                    tools: tools.to_vec(),
139                },
140                output: TestOutput {
141                    message: message.clone(),
142                    usage: usage.clone(),
143                },
144            };
145
146            {
147                let mut records = self.records.lock().unwrap();
148                records.insert(hash, record);
149            }
150
151            Ok((message, usage))
152        } else {
153            let records = self.records.lock().unwrap();
154            if let Some(record) = records.get(&hash) {
155                Ok((record.output.message.clone(), record.output.usage.clone()))
156            } else {
157                Err(ProviderError::ExecutionError(format!(
158                    "No recorded response found for input hash: {}",
159                    hash
160                )))
161            }
162        }
163    }
164
165    fn get_model_config(&self) -> ModelConfig {
166        ModelConfig::new_or_fail("test-model")
167    }
168}
169
#[cfg(test)]
mod tests {
    use super::*;
    use crate::conversation::message::{Message, MessageContent};
    use crate::providers::base::{ProviderUsage, Usage};
    use chrono::Utc;
    use rmcp::model::{RawTextContent, Role, TextContent};
    use std::env;

    /// Minimal inner provider that answers every request with a fixed text response.
    #[derive(Clone)]
    struct MockProvider {
        model_config: ModelConfig,
        response: String,
    }

    #[async_trait]
    impl Provider for MockProvider {
        fn metadata() -> ProviderMetadata {
            ProviderMetadata::new(
                "mock",
                "Mock Provider",
                "Mock provider for testing",
                "mock-model",
                vec!["mock-model"],
                "",
                vec![],
            )
        }

        fn get_name(&self) -> &str {
            "mock-testprovider"
        }

        /// Ignores its inputs and returns `self.response` as a single text content item.
        async fn complete_with_model(
            &self,
            _model_config: &ModelConfig,
            _system: &str,
            _messages: &[Message],
            _tools: &[Tool],
        ) -> Result<(Message, ProviderUsage), ProviderError> {
            Ok((
                Message::new(
                    Role::Assistant,
                    Utc::now().timestamp(),
                    vec![MessageContent::Text(TextContent {
                        raw: RawTextContent {
                            text: self.response.clone(),
                            meta: None,
                        },
                        annotations: None,
                    })],
                ),
                ProviderUsage::new("mock-model".to_string(), Usage::default()),
            ))
        }

        fn get_model_config(&self) -> ModelConfig {
            self.model_config.clone()
        }
    }

    /// Builds a per-process scratch file path in the system temp directory.
    ///
    /// Uses `Path::join` rather than string concatenation so the separator is always
    /// correct for the platform.
    fn temp_records_path(prefix: &str) -> String {
        env::temp_dir()
            .join(format!("{}_{}.json", prefix, std::process::id()))
            .to_string_lossy()
            .into_owned()
    }

    /// Returns the text of the message's first content item.
    ///
    /// Panics when the message has no content or the first item is not text, so an
    /// assertion on the result can never pass vacuously.
    fn first_text(message: &Message) -> &str {
        match &message.content[0] {
            MessageContent::Text(content) => content.text.as_str(),
            _ => panic!("expected the first content item to be text"),
        }
    }

    #[tokio::test]
    async fn test_record_and_replay() {
        let temp_file = temp_records_path("test_records");

        let mock = Arc::new(MockProvider {
            model_config: ModelConfig::new_or_fail("mock-model"),
            response: "Hello, world!".to_string(),
        });

        // Record one interaction and persist it.
        {
            let test_provider = TestProvider::new_recording(mock, temp_file.as_str());

            let (message, _) = test_provider
                .complete("You are helpful", &[], &[])
                .await
                .expect("recording a completion should succeed");

            assert_eq!(first_text(&message), "Hello, world!");
            assert_eq!(test_provider.get_record_count(), 1);
            test_provider.finish_recording().unwrap();
        }

        // Replay the same request from the file written above.
        {
            let replay_provider = TestProvider::new_replaying(temp_file.as_str()).unwrap();
            assert_eq!(replay_provider.get_record_count(), 1);

            let (message, _) = replay_provider
                .complete("You are helpful", &[], &[])
                .await
                .expect("replaying a recorded completion should succeed");

            assert_eq!(first_text(&message), "Hello, world!");
        }

        let _ = fs::remove_file(&temp_file);
    }

    #[tokio::test]
    async fn test_replay_missing_record() {
        let temp_file = temp_records_path("test_missing");

        // Make sure no stale file exists: a missing file loads as an empty record set,
        // which is what makes the lookup below fail.
        let _ = fs::remove_file(&temp_file);

        let replay_provider = TestProvider::new_replaying(temp_file.as_str()).unwrap();
        assert_eq!(replay_provider.get_record_count(), 0);

        let result = replay_provider
            .complete("Different system prompt", &[], &[])
            .await;

        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("No recorded response found"));

        let _ = fs::remove_file(&temp_file);
    }
}