// adk_model/mock.rs

1use adk_core::{Llm, LlmRequest, LlmResponse, LlmResponseStream, Result};
2use async_trait::async_trait;
3
4/// A mock LLM implementation for testing without real API calls.
5///
6/// Returns pre-configured responses in order when `generate_content` is called.
7pub struct MockLlm {
8    name: String,
9    responses: Vec<LlmResponse>,
10}
11
12impl MockLlm {
13    /// Create a new mock LLM with the given name.
14    pub fn new(name: impl Into<String>) -> Self {
15        Self { name: name.into(), responses: vec![] }
16    }
17
18    /// Add a response to the queue of responses returned by this mock.
19    pub fn with_response(mut self, response: LlmResponse) -> Self {
20        self.responses.push(response);
21        self
22    }
23}
24
25#[async_trait]
26impl Llm for MockLlm {
27    fn name(&self) -> &str {
28        &self.name
29    }
30
31    async fn generate_content(&self, _req: LlmRequest, _stream: bool) -> Result<LlmResponseStream> {
32        let responses = self.responses.clone();
33        let stream = async_stream::stream! {
34            for response in responses {
35                yield Ok(response);
36            }
37        };
38        Ok(Box::pin(stream))
39    }
40}
41
#[cfg(test)]
mod tests {
    use super::*;
    use adk_core::Content;
    use futures::StreamExt;

    #[test]
    fn test_mock_llm() {
        let mock =
            MockLlm::new("test-llm").with_response(LlmResponse::new(Content::new("assistant")));
        assert_eq!(mock.name(), "test-llm");
        assert_eq!(mock.responses.len(), 1);
    }

    #[tokio::test]
    async fn test_mock_llm_generate() {
        let mock = MockLlm::new("test")
            .with_response(LlmResponse::new(Content::new("assistant").with_text("Hello")));

        let req = LlmRequest::new("test", vec![]);
        let mut stream = mock.generate_content(req, false).await.unwrap();

        let response = stream.next().await.unwrap().unwrap();
        assert!(response.content.is_some());
        // Exactly one response was configured, so the stream must end here.
        assert!(stream.next().await.is_none());
    }

    #[tokio::test]
    async fn test_mock_llm_empty_yields_nothing() {
        let mock = MockLlm::new("empty");

        let req = LlmRequest::new("test", vec![]);
        let stream = mock.generate_content(req, false).await.unwrap();

        assert_eq!(stream.count().await, 0);
    }

    #[tokio::test]
    async fn test_mock_llm_replays_on_every_call() {
        let mock = MockLlm::new("replay")
            .with_response(LlmResponse::new(Content::new("assistant").with_text("a")))
            .with_response(LlmResponse::new(Content::new("assistant").with_text("b")));

        // The canned responses are not consumed: each call yields all of them.
        for _ in 0..2 {
            let req = LlmRequest::new("test", vec![]);
            let stream = mock.generate_content(req, true).await.unwrap();
            assert_eq!(stream.count().await, 2);
        }
    }
}