use std::collections::VecDeque;
use std::future::Future;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use anyhow::Result;
use crate::claude::ai::{AiClient, AiClientMetadata};
/// Test double for the `AiClient` trait: serves a pre-loaded FIFO queue of
/// canned responses and records every prompt it receives so tests can
/// assert on both sides of the exchange.
pub(crate) struct ConfigurableMockAiClient {
    // Canned results served in insertion order; wrapped in Arc<Mutex<..>> so a
    // `ResponseQueueHandle` can observe the remaining count after the client
    // has been handed off elsewhere.
    responses: Arc<Mutex<VecDeque<Result<String>>>>,
    // Static metadata returned by `get_metadata`; context length is tunable
    // via `with_context_length`.
    metadata: AiClientMetadata,
    // Every (system_prompt, user_prompt) pair passed to `send_request`,
    // in call order; shared with `PromptRecordHandle`.
    recorded_prompts: Arc<Mutex<Vec<(String, String)>>>,
}
impl ConfigurableMockAiClient {
    /// Builds a mock that will hand out `responses` in FIFO order.
    ///
    /// Metadata defaults to a generic "Mock" provider with generous limits;
    /// adjust the context window with [`Self::with_context_length`].
    pub(crate) fn new(responses: Vec<Result<String>>) -> Self {
        let queue: VecDeque<Result<String>> = responses.into_iter().collect();
        Self {
            responses: Arc::new(Mutex::new(queue)),
            metadata: AiClientMetadata {
                provider: "Mock".to_string(),
                model: "mock-model".to_string(),
                max_context_length: 200_000,
                max_response_length: 8_192,
                active_beta: None,
            },
            recorded_prompts: Arc::new(Mutex::new(Vec::new())),
        }
    }

    /// Builder-style override for the advertised maximum context length.
    pub(crate) fn with_context_length(mut self, max_context_length: usize) -> Self {
        self.metadata.max_context_length = max_context_length;
        self
    }

    /// Hands out a view onto the response queue so tests can check how many
    /// canned responses are still pending.
    pub(crate) fn response_handle(&self) -> ResponseQueueHandle {
        ResponseQueueHandle {
            responses: Arc::clone(&self.responses),
        }
    }

    /// Hands out a view onto the recorded prompts so tests can inspect what
    /// the code under test actually sent.
    pub(crate) fn prompt_handle(&self) -> PromptRecordHandle {
        PromptRecordHandle {
            recorded_prompts: Arc::clone(&self.recorded_prompts),
        }
    }
}
/// Read-only view onto a mock client's queue of canned responses.
pub(crate) struct ResponseQueueHandle {
    responses: Arc<Mutex<VecDeque<Result<String>>>>,
}

impl ResponseQueueHandle {
    /// How many canned responses have not yet been consumed.
    pub(crate) fn remaining(&self) -> usize {
        let queue = self.responses.lock().unwrap();
        queue.len()
    }
}
/// Read-only view onto the prompts a mock client has recorded.
pub(crate) struct PromptRecordHandle {
    recorded_prompts: Arc<Mutex<Vec<(String, String)>>>,
}

impl PromptRecordHandle {
    /// Snapshot of every (system_prompt, user_prompt) pair recorded so far,
    /// in call order.
    pub(crate) fn prompts(&self) -> Vec<(String, String)> {
        let recorded = self.recorded_prompts.lock().unwrap();
        recorded.clone()
    }

    /// Number of requests the mock has received.
    pub(crate) fn request_count(&self) -> usize {
        let recorded = self.recorded_prompts.lock().unwrap();
        recorded.len()
    }
}
impl AiClient for ConfigurableMockAiClient {
    /// Records the prompt pair, then pops and returns the next canned
    /// response. Once the queue is exhausted, every further call yields
    /// an error.
    fn send_request<'a>(
        &'a self,
        system_prompt: &'a str,
        user_prompt: &'a str,
    ) -> Pin<Box<dyn Future<Output = Result<String>> + Send + 'a>> {
        let queue = Arc::clone(&self.responses);
        let log = Arc::clone(&self.recorded_prompts);
        let prompt_pair = (system_prompt.to_string(), user_prompt.to_string());
        Box::pin(async move {
            log.lock().unwrap().push(prompt_pair);
            // Pop outside the match so the queue lock is released promptly.
            let next = queue.lock().unwrap().pop_front();
            match next {
                Some(response) => response,
                None => Err(anyhow::anyhow!("no more mock responses")),
            }
        })
    }

    /// Returns a clone of the (possibly customized) mock metadata.
    fn get_metadata(&self) -> AiClientMetadata {
        self.metadata.clone()
    }
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;
    use crate::claude::ai::RequestOptions;

    /// The mock should opt out of schema support unless a test asks for it.
    #[test]
    fn mock_client_defaults_to_no_schema_support() {
        let client = ConfigurableMockAiClient::new(vec![]);
        assert!(
            !client.capabilities().supports_response_schema,
            "mock client should default to no schema support so tests don't have to care"
        );
    }

    /// The trait's default `send_request_with_options` should delegate to
    /// `send_request`, so the mock records the prompt and serves its queue.
    #[tokio::test]
    async fn mock_client_send_with_options_falls_through_to_send_request() {
        let client = ConfigurableMockAiClient::new(vec![Ok("hello".to_string())]);
        let prompt_handle = client.prompt_handle();

        let reply = client
            .send_request_with_options("sys", "user", RequestOptions::default())
            .await
            .expect("default send_request_with_options should succeed");

        assert_eq!(reply, "hello");
        assert_eq!(
            prompt_handle.prompts(),
            vec![("sys".to_string(), "user".to_string())]
        );
    }
}