spec-ai 0.8.4

A framework for building AI agents with structured outputs, policy enforcement, and execution tracking
Documentation
//! Mock Model Provider
//!
//! A simple mock provider for testing that returns canned responses.

use crate::spec_ai_core::agent::model::{
    GenerationConfig, ModelProvider, ModelResponse, ModelStreamItem, ProviderKind,
    ProviderMetadata, TokenUsage,
};
use anyhow::Result;
use async_stream::stream;
use async_trait::async_trait;
use futures::Stream;
use std::pin::Pin;

/// Mock provider that returns predefined responses.
///
/// Responses are handed out in order and wrap around to the first one once the
/// list is exhausted. The cursor is stored behind an `Arc<Mutex<_>>`, so clones
/// of a provider share a single cursor and advance through the same sequence.
#[derive(Debug, Clone)]
pub struct MockProvider {
    /// Canned responses to cycle through
    responses: Vec<String>,
    /// Position of the response the next call will return (shared between clones)
    current_index: std::sync::Arc<std::sync::Mutex<usize>>,
    /// Model name to report
    model_name: String,
}

impl MockProvider {
    /// Create a new mock provider with a single response.
    ///
    /// The model name defaults to `"mock-model"`.
    pub fn new(response: impl Into<String>) -> Self {
        Self::with_responses(vec![response.into()])
    }

    /// Create a new mock provider with multiple responses.
    ///
    /// The responses are returned in the given order, cycling back to the first
    /// after the last. An empty list is accepted; such a provider answers every
    /// request with an empty string. The model name defaults to `"mock-model"`.
    pub fn with_responses(responses: Vec<String>) -> Self {
        Self {
            responses,
            current_index: std::sync::Arc::new(std::sync::Mutex::new(0)),
            model_name: "mock-model".to_string(),
        }
    }

    /// Set the model name, consuming and returning the provider (builder style).
    pub fn with_model_name(mut self, model_name: impl Into<String>) -> Self {
        self.model_name = model_name.into();
        self
    }

    /// Get the next response (cycles through available responses).
    ///
    /// Returns an empty string when no responses were configured, instead of
    /// panicking on a modulo-by-zero.
    fn next_response(&self) -> String {
        let len = self.responses.len();
        if len == 0 {
            // Nothing to cycle through; bail out before touching the cursor.
            return String::new();
        }

        // The guarded value is a plain counter that cannot be left half-updated,
        // so a poisoned lock is safe to recover from.
        let mut cursor = self
            .current_index
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());

        // Keep the stored cursor inside `0..len` so it can never overflow.
        let slot = *cursor % len;
        *cursor = (slot + 1) % len;
        self.responses[slot].clone()
    }
}

impl Default for MockProvider {
    fn default() -> Self {
        Self::new("This is a mock response from the test provider.")
    }
}

#[async_trait]
impl ModelProvider for MockProvider {
    /// Returns the next canned response as a complete `ModelResponse`.
    ///
    /// The prompt and generation config are ignored. Token usage is fabricated:
    /// a fixed prompt count of 10 and one completion token per
    /// whitespace-separated word of the response.
    async fn generate(&self, _prompt: &str, _config: &GenerationConfig) -> Result<ModelResponse> {
        let reply = self.next_response();

        // Mock accounting only; no tokenizer is involved.
        let prompt_tokens = 10;
        let completion_tokens = reply.split_whitespace().count() as u32;
        let usage = TokenUsage {
            prompt_tokens,
            completion_tokens,
            total_tokens: prompt_tokens + completion_tokens,
        };

        Ok(ModelResponse {
            content: reply,
            model: self.model_name.clone(),
            usage: Some(usage),
            finish_reason: Some("stop".to_string()),
            tool_calls: None,
            reasoning: None,
        })
    }

    /// Streams the next canned response one word at a time.
    ///
    /// Each whitespace-separated word is emitted as a `Content` item with a
    /// trailing space, followed by a 10 ms pause. After the last word the stream
    /// yields a `Usage` item and then a `FinishReason("stop")` item. The prompt
    /// and generation config are ignored.
    async fn stream(
        &self,
        _prompt: &str,
        _config: &GenerationConfig,
    ) -> Result<Pin<Box<dyn Stream<Item = Result<ModelStreamItem>> + Send>>> {
        let reply = self.next_response();

        // Pre-render every chunk so the stream body only has to hand them out.
        let chunks: Vec<String> = reply
            .split_whitespace()
            .map(|word| format!("{} ", word))
            .collect();

        let prompt_tokens = 10;
        let completion_tokens = chunks.len() as u32;
        let usage = TokenUsage {
            prompt_tokens,
            completion_tokens,
            total_tokens: prompt_tokens + completion_tokens,
        };

        let items = stream! {
            for chunk in chunks {
                yield Ok(ModelStreamItem::Content(chunk));
                // Simulate network delay between chunks
                tokio::time::sleep(std::time::Duration::from_millis(10)).await;
            }

            // Usage comes after all content, then the finish reason closes the stream
            yield Ok(ModelStreamItem::Usage(usage));
            yield Ok(ModelStreamItem::FinishReason("stop".to_string()));
        };

        Ok(Box::pin(items))
    }

    /// Describes the mock provider: its display name, the model names it lists
    /// as supported, and that streaming is available.
    fn metadata(&self) -> ProviderMetadata {
        let supported_models: Vec<String> = ["mock-model", "mock-gpt-5", "mock-claude-3"]
            .iter()
            .map(|name| String::from(*name))
            .collect();

        ProviderMetadata {
            name: "Mock Provider".to_string(),
            supported_models,
            supports_streaming: true,
        }
    }

    /// Identifies this provider as the mock kind.
    fn kind(&self) -> ProviderKind {
        ProviderKind::Mock
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use futures::StreamExt;

    /// A single-response provider returns that response with default metadata.
    #[tokio::test]
    async fn test_mock_provider_generate() {
        let config = GenerationConfig::default();
        let provider = MockProvider::new("Hello, world!");

        let reply = provider.generate("test prompt", &config).await.unwrap();

        assert_eq!(reply.content, "Hello, world!");
        assert_eq!(reply.model, "mock-model");
        assert!(reply.usage.is_some());
        assert_eq!(reply.finish_reason, Some("stop".to_string()));
    }

    /// Responses come back in order and wrap around after the last one.
    #[tokio::test]
    async fn test_mock_provider_multiple_responses() {
        let provider = MockProvider::with_responses(vec![
            "First response".to_string(),
            "Second response".to_string(),
            "Third response".to_string(),
        ]);
        let config = GenerationConfig::default();

        // The fourth call should cycle back to the first response.
        let expected = [
            "First response",
            "Second response",
            "Third response",
            "First response",
        ];

        for want in expected {
            let got = provider.generate("prompt", &config).await.unwrap();
            assert_eq!(got.content, want);
        }
    }

    /// Streaming yields every word, a usage record, and a "stop" finish reason.
    #[tokio::test]
    async fn test_mock_provider_stream() {
        let config = GenerationConfig::default();
        let provider = MockProvider::new("Hello world test");

        let mut items = provider.stream("test prompt", &config).await.unwrap();

        let mut collected = String::new();
        let mut seen_usage = None;
        let mut seen_finish = None;

        while let Some(item) = items.next().await {
            match item.unwrap() {
                ModelStreamItem::Content(piece) => collected.push_str(&piece),
                ModelStreamItem::Usage(u) => seen_usage = Some(u),
                ModelStreamItem::FinishReason(reason) => seen_finish = Some(reason),
            }
        }

        for word in ["Hello", "world", "test"] {
            assert!(collected.contains(word));
        }
        assert!(seen_usage.is_some());
        assert_eq!(seen_finish, Some("stop".to_string()));
    }

    /// Metadata reports the provider name, streaming support, and some models.
    #[tokio::test]
    async fn test_mock_provider_metadata() {
        let info = MockProvider::default().metadata();

        assert_eq!(info.name, "Mock Provider");
        assert!(info.supports_streaming);
        assert!(!info.supported_models.is_empty());
    }

    /// A model name set through the builder is echoed back in responses.
    #[tokio::test]
    async fn test_mock_provider_custom_model_name() {
        let config = GenerationConfig::default();
        let provider = MockProvider::new("test").with_model_name("custom-model");

        let reply = provider.generate("prompt", &config).await.unwrap();
        assert_eq!(reply.model, "custom-model");
    }
}