Skip to main content

agentrs_core/
testing.rs

1//! Test utilities for SDK users and internal crates.
2
3use std::{collections::VecDeque, sync::Arc};
4
5use async_trait::async_trait;
6use futures::{stream, StreamExt};
7use tokio::sync::Mutex;
8
9use crate::{
10    CompletionRequest, CompletionResponse, LlmProvider, ProviderError, Result, StreamChunk,
11};
12
13/// Mock provider for deterministic tests.
14#[derive(Clone, Default)]
15pub struct MockLlmProvider {
16    name: String,
17    responses: Arc<Mutex<VecDeque<CompletionResponse>>>,
18}
19
20impl MockLlmProvider {
21    /// Creates an empty mock provider.
22    pub fn new() -> Self {
23        Self {
24            name: "mock".to_string(),
25            responses: Arc::new(Mutex::new(VecDeque::new())),
26        }
27    }
28
29    /// Creates a mock provider seeded with responses.
30    pub fn with_responses(responses: Vec<CompletionResponse>) -> Self {
31        Self {
32            name: "mock".to_string(),
33            responses: Arc::new(Mutex::new(VecDeque::from(responses))),
34        }
35    }
36
37    /// Creates a mock provider from plain text responses.
38    pub fn with_text_responses(texts: Vec<impl Into<String>>) -> Self {
39        Self::with_responses(texts.into_iter().map(CompletionResponse::text).collect())
40    }
41
42    /// Creates a mock provider that first asks for a tool, then returns a final answer.
43    pub fn with_tool_call_sequence(
44        tool_name: impl Into<String>,
45        arguments: serde_json::Value,
46        final_text: impl Into<String>,
47    ) -> Self {
48        Self::with_responses(vec![
49            CompletionResponse::tool_call(tool_name, arguments),
50            CompletionResponse::text(final_text),
51        ])
52    }
53
54    /// Queues another response.
55    pub async fn push_response(&self, response: CompletionResponse) {
56        self.responses.lock().await.push_back(response);
57    }
58
59    async fn pop_response(&self) -> Result<CompletionResponse> {
60        self.responses.lock().await.pop_front().ok_or_else(|| {
61            ProviderError::InvalidResponse("mock provider exhausted".to_string()).into()
62        })
63    }
64}
65
66#[async_trait]
67impl LlmProvider for MockLlmProvider {
68    async fn complete(&self, _req: CompletionRequest) -> Result<CompletionResponse> {
69        self.pop_response().await
70    }
71
72    async fn stream(
73        &self,
74        _req: CompletionRequest,
75    ) -> Result<futures::stream::BoxStream<'_, Result<StreamChunk>>> {
76        let response = self.pop_response().await?;
77        let text = response.message.text_content();
78        let chunks = text
79            .split_whitespace()
80            .map(|token| {
81                Ok(StreamChunk {
82                    delta: format!("{token} "),
83                    tool_call_delta: None,
84                    finish_reason: None,
85                })
86            })
87            .chain(std::iter::once(Ok(StreamChunk {
88                delta: String::new(),
89                tool_call_delta: None,
90                finish_reason: Some(response.stop_reason),
91            })))
92            .collect::<Vec<_>>();
93
94        Ok(stream::iter(chunks).boxed())
95    }
96
97    fn name(&self) -> &str {
98        &self.name
99    }
100}