use std::sync::Arc;
use std::sync::Mutex;
use async_trait::async_trait;
use crate::core::sm::providers::{LlmProvider, LlmRequest, LlmResponse, SmLlmError};
/// Strategy a [`MockProvider`] uses to produce the text of its response.
#[derive(Clone, Debug)]
pub enum MockReply {
    /// Always answer with exactly this text, regardless of the request.
    Fixed(String),
    /// Answer with `prefix` followed by the contents of every message whose
    /// role is `"user"`, joined by newlines.
    Echo {
        /// Text placed in front of the echoed user content.
        prefix: String,
    },
}
/// In-memory stand-in for an LLM provider that returns canned replies and
/// records every request it receives.
///
/// Cloning is cheap: the request log lives behind an [`Arc`], so all clones
/// append to, and read from, the same log.
#[derive(Clone)]
pub struct MockProvider {
    /// How the response text is produced.
    reply: MockReply,
    /// Requests seen so far, in arrival order; shared between clones.
    requests: Arc<Mutex<Vec<LlmRequest>>>,
}
impl MockProvider {
    /// Creates a provider that answers every request with `text`.
    ///
    /// The request log starts out empty.
    pub fn fixed(text: impl Into<String>) -> Self {
        let reply = MockReply::Fixed(text.into());
        let requests = Arc::new(Mutex::default());
        Self { reply, requests }
    }

    /// Creates a provider that answers with `prefix` followed by the
    /// request's user-message contents.
    ///
    /// The request log starts out empty.
    pub fn echo(prefix: impl Into<String>) -> Self {
        let prefix = prefix.into();
        let reply = MockReply::Echo { prefix };
        let requests = Arc::new(Mutex::default());
        Self { reply, requests }
    }

    /// Returns a snapshot (clone) of every request recorded so far, oldest
    /// first.
    ///
    /// # Panics
    ///
    /// Panics if the request-log mutex is poisoned.
    pub fn requests(&self) -> Vec<LlmRequest> {
        let recorded = self.requests.lock().expect("mock lock not poisoned");
        recorded.clone()
    }

    /// Returns the model named by the most recently recorded request, or
    /// `None` when nothing has been recorded yet.
    ///
    /// # Panics
    ///
    /// Panics if the request-log mutex is poisoned.
    pub fn last_model(&self) -> Option<String> {
        let recorded = self.requests.lock().expect("mock lock not poisoned");
        recorded.last().map(|latest| latest.model.clone())
    }
}
#[async_trait]
impl LlmProvider for MockProvider {
    /// Identifies this provider as `"mock"`.
    fn name(&self) -> &str {
        "mock"
    }

    /// Builds the configured reply, records `req` in the shared request log,
    /// and returns a successful response.
    ///
    /// The response carries the request's model name; token counts, latency
    /// and cost are all reported as zero. This implementation never returns
    /// an `Err`.
    ///
    /// # Panics
    ///
    /// Panics if the request-log mutex is poisoned.
    async fn complete(&self, req: LlmRequest) -> Result<LlmResponse, SmLlmError> {
        // Capture the model name up front: `req` is moved into the log below.
        let model = req.model.clone();

        let text = match &self.reply {
            MockReply::Fixed(canned) => canned.clone(),
            MockReply::Echo { prefix } => {
                // Start from the prefix and append each user message,
                // separating consecutive messages with a single newline.
                let mut echoed = prefix.clone();
                let user_turns = req.messages.iter().filter(|msg| msg.role == "user");
                for (idx, msg) in user_turns.enumerate() {
                    if idx > 0 {
                        echoed.push('\n');
                    }
                    echoed.push_str(msg.content.as_str());
                }
                echoed
            }
        };

        // The guard is a temporary of this statement, so the lock is released
        // as soon as the push completes.
        self.requests.lock().expect("mock lock not poisoned").push(req);

        Ok(LlmResponse {
            text,
            model,
            input_tokens: 0,
            output_tokens: 0,
            latency_ms: 0,
            cost_usd: 0.0,
        })
    }
}