1use std::{collections::VecDeque, sync::Arc};
4
5use async_trait::async_trait;
6use futures::{stream, StreamExt};
7use tokio::sync::Mutex;
8
9use crate::{
10 CompletionRequest, CompletionResponse, LlmProvider, ProviderError, Result, StreamChunk,
11};
12
13#[derive(Clone, Default)]
15pub struct MockLlmProvider {
16 name: String,
17 responses: Arc<Mutex<VecDeque<CompletionResponse>>>,
18}
19
20impl MockLlmProvider {
21 pub fn new() -> Self {
23 Self {
24 name: "mock".to_string(),
25 responses: Arc::new(Mutex::new(VecDeque::new())),
26 }
27 }
28
29 pub fn with_responses(responses: Vec<CompletionResponse>) -> Self {
31 Self {
32 name: "mock".to_string(),
33 responses: Arc::new(Mutex::new(VecDeque::from(responses))),
34 }
35 }
36
37 pub fn with_text_responses(texts: Vec<impl Into<String>>) -> Self {
39 Self::with_responses(texts.into_iter().map(CompletionResponse::text).collect())
40 }
41
42 pub fn with_tool_call_sequence(
44 tool_name: impl Into<String>,
45 arguments: serde_json::Value,
46 final_text: impl Into<String>,
47 ) -> Self {
48 Self::with_responses(vec![
49 CompletionResponse::tool_call(tool_name, arguments),
50 CompletionResponse::text(final_text),
51 ])
52 }
53
54 pub async fn push_response(&self, response: CompletionResponse) {
56 self.responses.lock().await.push_back(response);
57 }
58
59 async fn pop_response(&self) -> Result<CompletionResponse> {
60 self.responses.lock().await.pop_front().ok_or_else(|| {
61 ProviderError::InvalidResponse("mock provider exhausted".to_string()).into()
62 })
63 }
64}
65
66#[async_trait]
67impl LlmProvider for MockLlmProvider {
68 async fn complete(&self, _req: CompletionRequest) -> Result<CompletionResponse> {
69 self.pop_response().await
70 }
71
72 async fn stream(
73 &self,
74 _req: CompletionRequest,
75 ) -> Result<futures::stream::BoxStream<'_, Result<StreamChunk>>> {
76 let response = self.pop_response().await?;
77 let text = response.message.text_content();
78 let chunks = text
79 .split_whitespace()
80 .map(|token| {
81 Ok(StreamChunk {
82 delta: format!("{token} "),
83 tool_call_delta: None,
84 finish_reason: None,
85 })
86 })
87 .chain(std::iter::once(Ok(StreamChunk {
88 delta: String::new(),
89 tool_call_delta: None,
90 finish_reason: Some(response.stop_reason),
91 })))
92 .collect::<Vec<_>>();
93
94 Ok(stream::iter(chunks).boxed())
95 }
96
97 fn name(&self) -> &str {
98 &self.name
99 }
100}