use crate::spec_ai_core::agent::model::{
GenerationConfig, ModelProvider, ModelResponse, ModelStreamItem, ProviderKind,
ProviderMetadata, TokenUsage,
};
use anyhow::Result;
use async_stream::stream;
use async_trait::async_trait;
use futures::Stream;
use std::pin::Pin;
/// A deterministic [`ModelProvider`] for tests.
///
/// Replays a fixed list of canned responses in round-robin order, regardless of
/// the prompt it is given. The cursor lives behind an `Arc`, so clones of a
/// provider share it and advance one common sequence.
#[derive(Debug, Clone)]
pub struct MockProvider {
    /// Canned responses, replayed cyclically.
    responses: Vec<String>,
    /// Position of the next response to hand out; shared by all clones.
    current_index: std::sync::Arc<std::sync::Mutex<usize>>,
    /// Model name reported in every response.
    model_name: String,
}

impl MockProvider {
    /// Creates a provider that always answers with `response`.
    pub fn new(response: impl Into<String>) -> Self {
        Self::with_responses(vec![response.into()])
    }

    /// Creates a provider that cycles through `responses` in order, wrapping
    /// back to the first one after the last.
    ///
    /// An empty list is accepted; such a provider answers with an empty string.
    pub fn with_responses(responses: Vec<String>) -> Self {
        Self {
            responses,
            current_index: std::sync::Arc::new(std::sync::Mutex::new(0)),
            model_name: "mock-model".to_string(),
        }
    }

    /// Overrides the model name reported in responses (default: `"mock-model"`).
    pub fn with_model_name(mut self, model_name: impl Into<String>) -> Self {
        self.model_name = model_name.into();
        self
    }

    /// Returns the next canned response and advances the shared cursor.
    ///
    /// Returns an empty string when no responses are configured, rather than
    /// panicking on a remainder-by-zero.
    fn next_response(&self) -> String {
        let len = self.responses.len();
        if len == 0 {
            return String::new();
        }
        // The guarded value is a plain counter that is reduced modulo `len`
        // before use, so it is always usable. Recover from poisoning instead of
        // letting one panicking test break every later call on shared clones.
        let mut cursor = self
            .current_index
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        let slot = *cursor % len;
        // Store the already-wrapped position so the counter can never overflow.
        *cursor = (slot + 1) % len;
        self.responses[slot].clone()
    }
}

impl Default for MockProvider {
    /// A single-response provider that answers with a generic placeholder text.
    fn default() -> Self {
        Self::new("This is a mock response from the test provider.")
    }
}
#[async_trait]
impl ModelProvider for MockProvider {
async fn generate(&self, _prompt: &str, _config: &GenerationConfig) -> Result<ModelResponse> {
let content = self.next_response();
let prompt_tokens = 10; let completion_tokens = content.split_whitespace().count() as u32;
Ok(ModelResponse {
content,
model: self.model_name.clone(),
usage: Some(TokenUsage {
prompt_tokens,
completion_tokens,
total_tokens: prompt_tokens + completion_tokens,
}),
finish_reason: Some("stop".to_string()),
tool_calls: None,
reasoning: None,
})
}
async fn stream(
&self,
_prompt: &str,
_config: &GenerationConfig,
) -> Result<Pin<Box<dyn Stream<Item = Result<ModelStreamItem>> + Send>>> {
let content = self.next_response();
let words: Vec<String> = content.split_whitespace().map(|s| s.to_string()).collect();
let prompt_tokens = 10;
let completion_tokens = words.len() as u32;
let stream = stream! {
use crate::spec_ai_core::agent::model::ModelStreamItem;
for word in words {
yield Ok(ModelStreamItem::Content(format!("{} ", word)));
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
}
yield Ok(ModelStreamItem::Usage(TokenUsage {
prompt_tokens,
completion_tokens,
total_tokens: prompt_tokens + completion_tokens,
}));
yield Ok(ModelStreamItem::FinishReason("stop".to_string()));
};
Ok(Box::pin(stream))
}
fn metadata(&self) -> ProviderMetadata {
ProviderMetadata {
name: "Mock Provider".to_string(),
supported_models: vec![
"mock-model".to_string(),
"mock-gpt-5".to_string(),
"mock-claude-3".to_string(),
],
supports_streaming: true,
}
}
fn kind(&self) -> ProviderKind {
ProviderKind::Mock
}
}
#[cfg(test)]
mod tests {
    // Behavioural checks for `MockProvider`: canned replies, round-robin
    // cycling, streamed output, metadata, and the model-name override.
    use super::*;
    use futures::StreamExt;

    #[tokio::test]
    async fn test_mock_provider_generate() {
        let cfg = GenerationConfig::default();
        let reply = MockProvider::new("Hello, world!")
            .generate("test prompt", &cfg)
            .await
            .unwrap();

        assert_eq!(reply.content, "Hello, world!");
        assert_eq!(reply.model, "mock-model");
        assert!(reply.usage.is_some());
        assert_eq!(reply.finish_reason.as_deref(), Some("stop"));
    }

    #[tokio::test]
    async fn test_mock_provider_multiple_responses() {
        let canned = ["First response", "Second response", "Third response"];
        let provider =
            MockProvider::with_responses(canned.iter().map(|r| r.to_string()).collect());
        let cfg = GenerationConfig::default();

        // Four calls against three canned replies: the last one must wrap around.
        for expected in canned.iter().chain(canned.iter().take(1)) {
            let reply = provider.generate("prompt", &cfg).await.unwrap();
            assert_eq!(reply.content, *expected);
        }
    }

    #[tokio::test]
    async fn test_mock_provider_stream() {
        let cfg = GenerationConfig::default();
        let items: Vec<_> = MockProvider::new("Hello world test")
            .stream("test prompt", &cfg)
            .await
            .unwrap()
            .collect()
            .await;

        let mut text = String::new();
        let (mut usage, mut finish_reason) = (None, None);
        for item in items {
            match item.unwrap() {
                ModelStreamItem::Content(piece) => text.push_str(&piece),
                ModelStreamItem::Usage(u) => usage = Some(u),
                ModelStreamItem::FinishReason(r) => finish_reason = Some(r),
            }
        }

        for word in ["Hello", "world", "test"] {
            assert!(text.contains(word));
        }
        assert!(usage.is_some());
        assert_eq!(finish_reason.as_deref(), Some("stop"));
    }

    #[tokio::test]
    async fn test_mock_provider_metadata() {
        let meta = MockProvider::default().metadata();

        assert_eq!(meta.name, "Mock Provider");
        assert!(meta.supports_streaming);
        assert!(!meta.supported_models.is_empty());
    }

    #[tokio::test]
    async fn test_mock_provider_custom_model_name() {
        let cfg = GenerationConfig::default();
        let reply = MockProvider::new("test")
            .with_model_name("custom-model")
            .generate("prompt", &cfg)
            .await
            .unwrap();

        assert_eq!(reply.model, "custom-model");
    }
}