use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex};
use async_trait::async_trait;
use crate::core::sm::config::SmInferenceConfig;
use crate::core::sm::providers::{
LlmProvider, LlmRequest, LlmResponse, ProviderKind, ResolvedCall, SmLlmError, SmModelTier,
TierResolver,
};
/// Test double for [`LlmProvider`] that answers every completion with a fixed
/// reply and records each request it receives.
///
/// The request log sits behind an `Arc<Mutex<..>>`, so clones (derived below)
/// share the same log: a clone handed to the code under test can be inspected
/// afterwards through the original.
#[derive(Clone)]
pub struct MockChatProvider {
    /// Text returned as the response text of every completion.
    reply: String,
    /// Value reported as the cost of every completion.
    cost_usd: f64,
    /// Every request passed to `complete`, appended in call order; shared across clones.
    requests: Arc<Mutex<Vec<LlmRequest>>>,
}
impl MockChatProvider {
    /// Creates a mock that answers every completion with `reply` and reports
    /// `cost_usd` as the cost of each call. The request log starts out empty.
    pub fn new(reply: impl Into<String>, cost_usd: f64) -> Self {
        let reply = reply.into();
        let requests = Arc::new(Mutex::new(Vec::new()));
        Self {
            reply,
            cost_usd,
            requests,
        }
    }

    /// Returns a copy of the most recently recorded request, or `None` if
    /// `complete` has not been called on this mock (or any of its clones).
    ///
    /// # Panics
    /// Panics if the request-log mutex is poisoned.
    pub fn last_request(&self) -> Option<LlmRequest> {
        let log = self.requests.lock().expect("mock lock");
        log.last().cloned()
    }

    /// Returns how many requests have been recorded across this mock and its clones.
    ///
    /// # Panics
    /// Panics if the request-log mutex is poisoned.
    pub fn request_count(&self) -> usize {
        let log = self.requests.lock().expect("mock lock");
        log.len()
    }
}
#[async_trait]
impl LlmProvider for MockChatProvider {
    /// Fixed provider name reported by the mock.
    fn name(&self) -> &str {
        "mock"
    }

    /// Records `req` in the shared request log and returns a canned response.
    ///
    /// The response carries the configured reply, echoes back the request's
    /// model, and uses fixed usage figures (10 input tokens, 5 output tokens,
    /// 1 ms latency) with the configured cost. It never returns `Err`.
    ///
    /// # Panics
    /// Panics if the request-log mutex is poisoned.
    async fn complete(&self, req: LlmRequest) -> Result<LlmResponse, SmLlmError> {
        let response = LlmResponse {
            text: self.reply.clone(),
            model: req.model.clone(),
            input_tokens: 10,
            output_tokens: 5,
            latency_ms: 1,
            cost_usd: self.cost_usd,
        };
        // The guard is a statement temporary, so it is released before returning
        // and is never held across an await point.
        self.requests.lock().expect("mock lock").push(req);
        Ok(response)
    }
}
/// Canned outcome that a [`MockResolver`] produces from `resolve`.
#[derive(Clone)]
pub enum MockResolution {
    /// Resolve to the wrapped provider (tier-model resolution errors still propagate).
    Provider(MockChatProvider),
    /// Always fail with `SmLlmError::Degraded`.
    Degraded,
    /// Always fail with `SmLlmError::Validation`.
    Validation,
    /// Resolve to `provider` for the first `n` calls, then fail with
    /// `SmLlmError::Degraded` on every later call.
    ProviderThenDegraded {
        provider: MockChatProvider,
        /// Number of `resolve` calls that succeed before degrading.
        n: usize,
        /// Count of `resolve` calls made so far; behind an `Arc`, so clones share it.
        calls: Arc<AtomicUsize>,
    },
}
/// [`TierResolver`] test double whose result is fixed by a [`MockResolution`].
///
/// Cloning shares any call counter held by the resolution, because that counter
/// lives behind an `Arc`.
#[derive(Clone)]
pub struct MockResolver {
    resolution: MockResolution,
}
impl MockResolver {
pub fn with_provider(provider: MockChatProvider) -> Self {
Self {
resolution: MockResolution::Provider(provider),
}
}
pub fn degraded() -> Self {
Self {
resolution: MockResolution::Degraded,
}
}
pub fn validation() -> Self {
Self {
resolution: MockResolution::Validation,
}
}
pub fn provider_then_degraded(provider: MockChatProvider, n: usize) -> Self {
Self {
resolution: MockResolution::ProviderThenDegraded {
provider,
n,
calls: Arc::new(AtomicUsize::new(0)),
},
}
}
}
#[async_trait]
impl TierResolver for MockResolver {
    /// Produces the outcome configured by this resolver's [`MockResolution`].
    ///
    /// # Errors
    /// - `SmLlmError::Degraded` for `Degraded`, and for `ProviderThenDegraded`
    ///   once its `n` successful calls are used up.
    /// - `SmLlmError::Validation` for `Validation`.
    /// - Any error from tier-model resolution on the success paths.
    async fn resolve(
        &self,
        cfg: &SmInferenceConfig,
        tier: SmModelTier,
    ) -> Result<ResolvedCall, SmLlmError> {
        match &self.resolution {
            MockResolution::Provider(provider) => resolved_call(cfg, tier, provider),
            MockResolution::Degraded => Err(SmLlmError::Degraded(
                "mock: no provider configured".to_string(),
            )),
            MockResolution::Validation => {
                Err(SmLlmError::Validation("mock: unknown provider".to_string()))
            }
            MockResolution::ProviderThenDegraded { provider, n, calls } => {
                // `fetch_add` returns the previous value: the 0-based index of this call.
                let idx = calls.fetch_add(1, Ordering::SeqCst);
                if idx < *n {
                    resolved_call(cfg, tier, provider)
                } else {
                    // Report the actual budget. The previous text always said "first call",
                    // which was wrong whenever n != 1. The n == 1 wording is kept verbatim
                    // so existing assertions on it still hold.
                    let budget = match *n {
                        1 => "first call".to_string(),
                        count => format!("{count} calls"),
                    };
                    Err(SmLlmError::Degraded(format!(
                        "mock: provider exhausted after {budget}"
                    )))
                }
            }
        }
    }
}
/// Builds the [`ResolvedCall`] that a [`MockResolver`] returns on success.
///
/// The tier is mapped to a model spec with `resolve_tier_model`. That spec is
/// then split by `resolve_provider_and_model`, and only the model part is kept.
/// The call is bound to a clone of `provider`, which shares its request log.
///
/// # Errors
/// Propagates any error from `resolve_tier_model`.
fn resolved_call(
    cfg: &SmInferenceConfig,
    tier: SmModelTier,
    provider: &MockChatProvider,
) -> Result<ResolvedCall, SmLlmError> {
    use crate::core::sm::providers::{resolve_provider_and_model, resolve_tier_model};

    let model_spec = resolve_tier_model(cfg, tier)?;
    // NOTE(review): the provider kind inferred from the spec is discarded and
    // `kind` is always reported as Anthropic. Presumably this keeps callers on a
    // single code path regardless of the configured model; confirm it is intended.
    let (_, bare_model) = resolve_provider_and_model(&model_spec, ProviderKind::Auto);
    Ok(ResolvedCall {
        kind: ProviderKind::Anthropic,
        model: bare_model,
        provider: Arc::new(provider.clone()),
    })
}