// pe-core 0.1.0
//
// Core types for Potential Expectations — messages, channels, state, traits.
// Documentation
//! MockProvider — deterministic LLM provider for testing.
//!
//! Responses are consumed from a queue in FIFO order.
//! Returns an error if the queue is exhausted.

use std::collections::VecDeque;
use std::pin::Pin;
use std::sync::{Mutex, MutexGuard};

use futures::Stream;

use crate::error::PeError;
use crate::llm::{LlmProvider, LlmResponse, StreamChunk, ToolSchema};
use crate::message::{AiMessage, Message, MessageContent, ToolCall};

/// One scripted outcome for a single `complete()` / `stream()` call.
///
/// Values are queued through the `respond_with*` builder methods on
/// `MockProvider` and consumed in FIFO order.
#[derive(Debug, Clone)]
enum MockResponse {
    /// Succeed with an `AiMessage` whose content is this text and which
    /// carries no tool calls.
    Text(String),
    /// Succeed with an `AiMessage` that has empty text content and exactly
    /// one tool call.
    ToolCall {
        // Name of the tool; the call id is derived from it as `call_<tool_name>`.
        tool_name: String,
        // JSON arguments, passed through unchanged to the resulting `ToolCall`.
        args: serde_json::Value,
    },
    /// Fail the call by returning this error.
    Error(PeError),
}

/// Deterministic LLM provider for tests.
///
/// Queue responses with the builder API, then call `complete()` (or
/// `stream()`) to consume them in FIFO order. Once the queue is empty,
/// further calls fail with `PeError::MockProviderExhausted`.
///
/// # Examples
///
/// ```
/// use pe_core::mock_provider::MockProvider;
/// use pe_core::llm::LlmProvider;
///
/// # tokio_test::block_on(async {
/// let provider = MockProvider::new()
///     .respond_with("Hello!")
///     .respond_with("Goodbye!");
///
/// let r1 = provider.complete(&[], &[]).await.unwrap();
/// assert_eq!(r1.message.content.as_text(), Some("Hello!"));
///
/// let r2 = provider.complete(&[], &[]).await.unwrap();
/// assert_eq!(r2.message.content.as_text(), Some("Goodbye!"));
/// # });
/// ```
// `Debug` is derived so the provider can appear in assertion output and in
// `#[derive(Debug)]` structs that embed it.
#[derive(Debug)]
pub struct MockProvider {
    /// Scripted responses, popped from the front. Wrapped in a `Mutex`
    /// because the provider methods that consume it only receive `&self`.
    responses: Mutex<VecDeque<MockResponse>>,
    /// Vector handed back (as a clone) by every `embed()` call.
    embed_response: Vec<f32>,
}

impl MockProvider {
    fn responses_guard(&self) -> MutexGuard<'_, VecDeque<MockResponse>> {
        match self.responses.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        }
    }

    /// Create a new MockProvider with an empty response queue.
    pub fn new() -> Self {
        Self {
            responses: Mutex::new(VecDeque::new()),
            embed_response: vec![0.0; 128], // default 128-dim zero vector
        }
    }

    /// Queue a plain text response.
    #[must_use = "builder methods return the modified builder"]
    pub fn respond_with(self, text: impl Into<String>) -> Self {
        self.responses_guard()
            .push_back(MockResponse::Text(text.into()));
        self
    }

    /// Queue a tool call response.
    #[must_use = "builder methods return the modified builder"]
    pub fn respond_with_tool_call(
        self,
        tool_name: impl Into<String>,
        args: serde_json::Value,
    ) -> Self {
        self.responses_guard().push_back(MockResponse::ToolCall {
            tool_name: tool_name.into(),
            args,
        });
        self
    }

    /// Queue an error response.
    #[must_use = "builder methods return the modified builder"]
    pub fn respond_with_error(self, err: PeError) -> Self {
        self.responses_guard().push_back(MockResponse::Error(err));
        self
    }

    /// Set the embedding to return for all `embed()` calls.
    #[must_use = "builder methods return the modified builder"]
    pub fn with_embedding(mut self, embedding: Vec<f32>) -> Self {
        self.embed_response = embedding;
        self
    }

    /// Number of responses remaining in the queue.
    pub fn remaining(&self) -> usize {
        self.responses_guard().len()
    }

    fn next_response(&self) -> Result<MockResponse, PeError> {
        self.responses_guard()
            .pop_front()
            .ok_or(PeError::MockProviderExhausted)
    }

    fn mock_response_to_llm(resp: MockResponse) -> Result<LlmResponse, PeError> {
        match resp {
            MockResponse::Text(text) => Ok(LlmResponse {
                message: AiMessage {
                    content: MessageContent::Text(text),
                    tool_calls: vec![],
                    invalid_tool_calls: vec![],
                    usage_metadata: None,
                    response_metadata: Default::default(),
                    id: None,
                },
                provider_metadata: Default::default(),
            }),
            MockResponse::ToolCall { tool_name, args } => Ok(LlmResponse {
                message: AiMessage {
                    content: MessageContent::Text(String::new()),
                    tool_calls: vec![ToolCall {
                        id: format!("call_{}", tool_name),
                        name: tool_name,
                        args,
                    }],
                    invalid_tool_calls: vec![],
                    usage_metadata: None,
                    response_metadata: Default::default(),
                    id: None,
                },
                provider_metadata: Default::default(),
            }),
            MockResponse::Error(e) => Err(e),
        }
    }
}

impl Default for MockProvider {
    fn default() -> Self {
        Self::new()
    }
}

impl LlmProvider for MockProvider {
    /// Resolves to the next scripted response.
    ///
    /// Both arguments are ignored. The queue entry is taken when the returned
    /// future is polled, not when `complete` is called. The future yields the
    /// queued error for an error entry, or `PeError::MockProviderExhausted`
    /// when nothing is queued.
    fn complete(
        &self,
        _messages: &[Message],
        _tools: &[ToolSchema],
    ) -> Pin<Box<dyn std::future::Future<Output = Result<LlmResponse, PeError>> + Send + '_>> {
        Box::pin(async move {
            self.next_response()
                .and_then(Self::mock_response_to_llm)
        })
    }

    /// Resolves to a two-item stream built from the next scripted response.
    ///
    /// Both arguments are ignored. Errors (queued or exhaustion) surface from
    /// the future itself, before any stream is produced.
    fn stream(&self, _messages: &[Message], _tools: &[ToolSchema]) -> crate::llm::StreamFuture<'_> {
        Box::pin(async move {
            let full = self
                .next_response()
                .and_then(Self::mock_response_to_llm)?;

            // The mock does not split text: a single token carries the whole
            // text (empty when the message has none), then the final response.
            let token = StreamChunk::Token(full.message.content.as_text().unwrap_or("").to_string());
            let items = vec![token, StreamChunk::Done(full)];

            let boxed: Pin<Box<dyn Stream<Item = StreamChunk> + Send>> =
                Box::pin(futures::stream::iter(items));
            Ok(boxed)
        })
    }

    /// Resolves to a clone of the configured embedding, whatever the input text.
    fn embed(
        &self,
        _text: &str,
    ) -> Pin<Box<dyn std::future::Future<Output = Result<Vec<f32>, PeError>> + Send + '_>> {
        // Clone up front so the future owns its data and does not touch `self`.
        let configured = self.embed_response.clone();
        Box::pin(async move { Ok(configured) })
    }

    /// Fixed provider identifier: `"mock"`.
    fn provider_name(&self) -> &'static str {
        "mock"
    }
}

#[cfg(test)]
mod tests {
    use futures::StreamExt;

    use super::*;

    /// A queued text response comes back as the message content.
    #[tokio::test]
    async fn test_text_response() {
        let provider = MockProvider::new().respond_with("Hello, world!");

        let resp = provider.complete(&[], &[]).await.unwrap();
        assert_eq!(resp.message.content.as_text(), Some("Hello, world!"));
        assert!(resp.message.tool_calls.is_empty());
    }

    /// A queued tool call keeps its name and arguments and gets a derived id.
    #[tokio::test]
    async fn test_tool_call_response() {
        let args = serde_json::json!({ "query": "rust async" });
        let provider = MockProvider::new().respond_with_tool_call("web_search", args.clone());

        let resp = provider.complete(&[], &[]).await.unwrap();
        assert_eq!(resp.message.tool_calls.len(), 1);

        let call = &resp.message.tool_calls[0];
        assert_eq!(call.name, "web_search");
        assert_eq!(call.id, "call_web_search");
        assert_eq!(call.args, args);
    }

    /// Responses are consumed in the order they were queued.
    #[tokio::test]
    async fn test_multiple_responses_fifo() {
        let provider = MockProvider::new()
            .respond_with("first")
            .respond_with("second")
            .respond_with("third");

        let r1 = provider.complete(&[], &[]).await.unwrap();
        let r2 = provider.complete(&[], &[]).await.unwrap();
        let r3 = provider.complete(&[], &[]).await.unwrap();

        assert_eq!(r1.message.content.as_text(), Some("first"));
        assert_eq!(r2.message.content.as_text(), Some("second"));
        assert_eq!(r3.message.content.as_text(), Some("third"));
    }

    /// Calling past the end of the queue yields `MockProviderExhausted`.
    #[tokio::test]
    async fn test_exhausted_queue_returns_error() {
        let provider = MockProvider::new().respond_with("only one");

        let _ = provider.complete(&[], &[]).await.unwrap();
        let err = provider.complete(&[], &[]).await.unwrap_err();

        assert!(matches!(err, PeError::MockProviderExhausted));
    }

    /// A queued error is returned as-is from `complete`.
    #[tokio::test]
    async fn test_error_response() {
        let provider = MockProvider::new().respond_with_error(PeError::LlmProvider {
            details: "rate limited".into(),
        });

        let err = provider.complete(&[], &[]).await.unwrap_err();
        assert!(matches!(err, PeError::LlmProvider { .. }));
    }

    /// `embed` returns the vector set through `with_embedding`.
    #[tokio::test]
    async fn test_embed_returns_configured_vector() {
        let provider = MockProvider::new().with_embedding(vec![1.0, 2.0, 3.0]);

        let embedding = provider.embed("test text").await.unwrap();
        assert_eq!(embedding, vec![1.0, 2.0, 3.0]);
    }

    /// `remaining` tracks every consumed response down to zero.
    #[tokio::test]
    async fn test_remaining_count() {
        let provider = MockProvider::new().respond_with("a").respond_with("b");

        assert_eq!(provider.remaining(), 2);
        // Unwrap instead of discarding, so a failing call cannot hide behind
        // a count that happens to match.
        let _ = provider.complete(&[], &[]).await.unwrap();
        assert_eq!(provider.remaining(), 1);
        let _ = provider.complete(&[], &[]).await.unwrap();
        assert_eq!(provider.remaining(), 0);
    }

    /// `Default` produces an empty queue, like `new`.
    #[test]
    fn test_default_has_empty_queue() {
        assert_eq!(MockProvider::default().remaining(), 0);
    }

    /// `stream` emits the whole text as one token, then the final response.
    #[tokio::test]
    async fn test_stream_emits_token_then_done() {
        let provider = MockProvider::new().respond_with("streamed");

        let stream = provider.stream(&[], &[]).await.unwrap();
        let chunks: Vec<StreamChunk> = stream.collect().await;

        assert_eq!(chunks.len(), 2);
        assert!(matches!(
            &chunks[0],
            StreamChunk::Token(text) if text.as_str() == "streamed"
        ));
        assert!(matches!(
            &chunks[1],
            StreamChunk::Done(resp) if resp.message.content.as_text() == Some("streamed")
        ));
        assert_eq!(provider.remaining(), 0);
    }

    /// `stream` on an empty queue fails before producing any stream.
    #[tokio::test]
    async fn test_stream_exhausted_queue_returns_error() {
        let provider = MockProvider::new();

        // The `Ok` payload is a boxed stream without `Debug`, so match on the
        // result instead of calling `unwrap_err`.
        let result = provider.stream(&[], &[]).await;
        assert!(matches!(result, Err(PeError::MockProviderExhausted)));
    }

    /// A poisoned queue mutex is recovered instead of panicking.
    #[test]
    fn poisoned_queue_lock_is_recovered() {
        let provider = MockProvider::new().respond_with("hello");

        // Panic while holding the lock to poison the mutex.
        let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            let _guard = provider.responses.lock().unwrap();
            panic!("poison mock provider");
        }));
        assert!(result.is_err());

        assert_eq!(provider.remaining(), 1);
    }

    /// `MockProvider` works behind a `Box<dyn LlmProvider>`.
    #[tokio::test]
    async fn test_object_safety() {
        // Verify LlmProvider is object-safe
        let provider: Box<dyn LlmProvider> = Box::new(MockProvider::new().respond_with("boxed"));
        let resp = provider.complete(&[], &[]).await.unwrap();
        assert_eq!(resp.message.content.as_text(), Some("boxed"));
    }
}