use std::collections::VecDeque;
use std::future::Future;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use anyhow::Result;
use crate::claude::ai::{AiClient, AiClientMetadata};
/// Scriptable [`AiClient`] test double.
///
/// Replies are handed out front-to-back from a pre-loaded queue, and every
/// `(system_prompt, user_prompt)` pair the client receives is appended to a
/// log. Both collections sit behind `Arc<Mutex<_>>`, so the handles returned by
/// [`response_handle`](Self::response_handle) and
/// [`prompt_handle`](Self::prompt_handle) observe the same state as the client.
pub(crate) struct ConfigurableMockAiClient {
    /// Pending replies, popped from the front on each request.
    responses: Arc<Mutex<VecDeque<Result<String>>>>,
    /// Metadata reported back through the `AiClient` trait.
    metadata: AiClientMetadata,
    /// Every prompt pair received so far, in call order.
    recorded_prompts: Arc<Mutex<Vec<(String, String)>>>,
}

impl ConfigurableMockAiClient {
    /// Creates a mock that answers requests with `responses`, in order.
    ///
    /// The reported metadata uses provider `"Mock"`, model `"mock-model"`,
    /// `max_context_length` 200_000, `max_response_length` 8_192 and no
    /// active beta. Use [`with_context_length`](Self::with_context_length) to
    /// override the context length.
    #[must_use]
    pub(crate) fn new(responses: Vec<Result<String>>) -> Self {
        Self {
            responses: Arc::new(Mutex::new(VecDeque::from(responses))),
            metadata: AiClientMetadata {
                provider: "Mock".to_string(),
                model: "mock-model".to_string(),
                max_context_length: 200_000,
                max_response_length: 8_192,
                active_beta: None,
            },
            recorded_prompts: Arc::new(Mutex::new(Vec::new())),
        }
    }

    /// Overrides the `max_context_length` reported in the metadata.
    ///
    /// Builder-style: consumes and returns `self`. Ignoring the return value
    /// would drop the configured client, hence `#[must_use]`.
    #[must_use]
    pub(crate) fn with_context_length(mut self, max_context_length: usize) -> Self {
        self.metadata.max_context_length = max_context_length;
        self
    }

    /// Returns a handle that shares this client's response queue, for
    /// checking how many scripted replies remain.
    #[must_use]
    pub(crate) fn response_handle(&self) -> ResponseQueueHandle {
        ResponseQueueHandle {
            responses: Arc::clone(&self.responses),
        }
    }

    /// Returns a handle that shares this client's prompt log, for asserting
    /// on the prompts that were sent.
    #[must_use]
    pub(crate) fn prompt_handle(&self) -> PromptRecordHandle {
        PromptRecordHandle {
            recorded_prompts: Arc::clone(&self.recorded_prompts),
        }
    }
}
/// Shared view onto a mock client's queue of scripted responses.
///
/// Holds a clone of the client's `Arc`, so it sees the same queue the client
/// pops replies from.
pub(crate) struct ResponseQueueHandle {
    responses: Arc<Mutex<VecDeque<Result<String>>>>,
}

impl ResponseQueueHandle {
    /// Number of scripted responses that have not been handed out yet.
    ///
    /// # Panics
    ///
    /// Panics if the queue's mutex is poisoned.
    pub(crate) fn remaining(&self) -> usize {
        let queue = self.responses.lock().unwrap();
        queue.len()
    }
}
/// Shared view onto the log of prompt pairs a mock client has received.
pub(crate) struct PromptRecordHandle {
    recorded_prompts: Arc<Mutex<Vec<(String, String)>>>,
}

impl PromptRecordHandle {
    /// Returns an owned snapshot of every `(system_prompt, user_prompt)` pair
    /// recorded so far, oldest first.
    ///
    /// # Panics
    ///
    /// Panics if the log's mutex is poisoned.
    pub(crate) fn prompts(&self) -> Vec<(String, String)> {
        let log = self.recorded_prompts.lock().unwrap();
        log.to_vec()
    }

    /// Number of requests recorded so far.
    ///
    /// # Panics
    ///
    /// Panics if the log's mutex is poisoned.
    pub(crate) fn request_count(&self) -> usize {
        let log = self.recorded_prompts.lock().unwrap();
        log.len()
    }
}
impl AiClient for ConfigurableMockAiClient {
fn send_request<'a>(
&'a self,
system_prompt: &'a str,
user_prompt: &'a str,
) -> Pin<Box<dyn Future<Output = Result<String>> + Send + 'a>> {
let responses = self.responses.clone();
let recorded = self.recorded_prompts.clone();
let sys = system_prompt.to_string();
let usr = user_prompt.to_string();
Box::pin(async move {
recorded.lock().unwrap().push((sys, usr));
responses
.lock()
.unwrap()
.pop_front()
.unwrap_or_else(|| Err(anyhow::anyhow!("no more mock responses")))
})
}
fn get_metadata(&self) -> AiClientMetadata {
self.metadata.clone()
}
}