use std::cell::RefCell;
use super::{Provider, Request, Response, Usage};
use crate::AskError;
/// Test double for [`Provider`]: every call to `complete` stores the request it
/// received and answers with the same canned text and usage.
pub(crate) struct MockProvider {
    /// The most recent request seen by `complete`, or `None` before the first call.
    /// Wrapped in a `RefCell` because `complete` only has `&self`.
    pub last_request: RefCell<Option<CapturedRequest>>,
    /// Response text returned verbatim from every `complete` call.
    pub canned: String,
    /// Usage figures returned (cloned) from every `complete` call.
    pub canned_usage: Usage,
}
/// Owned snapshot of the parts of a request that tests inspect.
///
/// Derives `Debug` and `PartialEq`/`Eq` so tests can compare whole captures
/// with `assert_eq!` and get a readable diff when an assertion fails.
#[derive(Clone, Debug, PartialEq, Eq)]
pub(crate) struct CapturedRequest {
    /// Model name the request targeted.
    // Captured for completeness; presumably not read by current tests, so the
    // lint is silenced rather than dropping the field (derived impls do not
    // count as reads for `dead_code`).
    #[allow(dead_code)]
    pub model: String,
    /// Token limit the request asked for.
    // Same rationale as `model`: recorded but not necessarily asserted on.
    #[allow(dead_code)]
    pub max_tokens: u32,
    /// Text of each system block, in request order.
    pub system_blocks: Vec<String>,
    /// Content of the first message, or an empty string if there were none.
    pub user_message: String,
    /// Whether a system block carried a cache-control marker.
    pub schema_block_has_cache_control: bool,
}
impl MockProvider {
pub(crate) fn new(canned: impl Into<String>) -> Self {
Self {
last_request: RefCell::new(None),
canned: canned.into(),
canned_usage: Usage::default(),
}
}
}
impl Provider for MockProvider {
    /// Records an owned copy of `req` in `last_request`, replacing any earlier
    /// capture, then returns clones of the canned text and usage.
    ///
    /// # Errors
    ///
    /// Never returns `Err`.
    ///
    /// # Panics
    ///
    /// Panics if `last_request` is already borrowed when this is called (for
    /// example, a caller still holding a `borrow()` guard), because storing the
    /// capture requires a mutable borrow of the `RefCell`.
    fn complete(&self, req: Request<'_>) -> Result<Response, AskError> {
        let system_blocks: Vec<String> = req
            .system
            .iter()
            .map(|block| block.text.clone())
            .collect();

        // NOTE(review): this is true when *any* system block has cache_control,
        // not specifically a "schema" block — confirm that matches the field's intent.
        let schema_block_has_cache_control = req
            .system
            .iter()
            .any(|block| block.cache_control.is_some());

        // Only the first message is recorded; no messages yields an empty string.
        let user_message = match req.messages.first() {
            Some(message) => message.content.clone(),
            None => String::new(),
        };

        let snapshot = CapturedRequest {
            model: req.model.to_string(),
            max_tokens: req.max_tokens,
            system_blocks,
            user_message,
            schema_block_has_cache_control,
        };
        *self.last_request.borrow_mut() = Some(snapshot);

        let text = self.canned.clone();
        let usage = self.canned_usage.clone();
        Ok(Response { text, usage })
    }
}