use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use serde::{Deserialize, Serialize};
use tracing::debug;
#[cfg(test)]
use crate::llm::LlmResponse;
use crate::llm::{ChatMessage, LlmError, LlmProvider, LlmRequest};
/// Coarse health of the inference backend as reported by [`InferenceProbe`].
///
/// Serialised in `snake_case` (`"ok"`, `"unreachable"`, `"auth_error"`, `"unknown"`); the
/// type's `Display` impl emits the same strings.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum InferenceStatus {
    /// The probe completion request succeeded.
    Ok,
    /// The probe timed out, or the provider returned a transport, rate-limit or upstream error.
    Unreachable,
    /// The provider rejected the request: access denied, model not found / not ready, or a
    /// validation failure (see [`map_llm_error`]).
    AuthError,
    /// Never produced by the probe in this module.
    /// NOTE(review): presumably a placeholder for "not yet probed" — confirm against callers.
    Unknown,
}
impl InferenceStatus {
    /// Returns `true` only for [`InferenceStatus::Ok`].
    ///
    /// Written with `matches!` rather than `==` so the function can be `const`:
    /// `PartialEq::eq` cannot be called in const contexts on stable Rust.
    #[must_use]
    pub const fn is_ok(self) -> bool {
        matches!(self, InferenceStatus::Ok)
    }
}
impl std::fmt::Display for InferenceStatus {
    /// Writes the same `snake_case` name used by the serde representation.
    // `Formatter::pad` applies the caller's width/alignment/precision flags (e.g. `{:<12}` in a
    // table); `write!(f, "{s}")` would ignore them.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            InferenceStatus::Ok => "ok",
            InferenceStatus::Unreachable => "unreachable",
            InferenceStatus::AuthError => "auth_error",
            InferenceStatus::Unknown => "unknown",
        };
        f.pad(name)
    }
}
/// Classifies a provider error into the coarse [`InferenceStatus`] reported by the probe.
///
/// * Transport failures, rate limiting and upstream errors → [`InferenceStatus::Unreachable`].
/// * Access denied, model not found / not ready, and validation errors →
///   [`InferenceStatus::AuthError`].
///
/// There is deliberately no wildcard arm: adding an [`LlmError`] variant must force a decision
/// here.
pub fn map_llm_error(err: &LlmError) -> InferenceStatus {
    match err {
        LlmError::Transport(_) | LlmError::RateLimited | LlmError::Upstream { .. } => {
            InferenceStatus::Unreachable
        }
        LlmError::AccessDenied(_)
        | LlmError::ModelNotFound(_)
        | LlmError::ModelNotReady(_)
        | LlmError::Validation(_) => InferenceStatus::AuthError,
    }
}
/// Probes the inference backend with a tiny completion request and caches the outcome.
///
/// Clones share a single cache because the slot sits behind an `Arc`.
#[derive(Debug, Clone)]
pub struct InferenceProbe {
    /// Last probe result and the instant it was recorded; `None` until the first probe.
    cached: Arc<Mutex<Option<(InferenceStatus, Instant)>>>,
    /// How long a cached result is served before the provider is probed again.
    ttl: Duration,
    /// Upper bound on a single probe request before it is reported as unreachable.
    probe_timeout: Duration,
}
impl Default for InferenceProbe {
    /// A probe that serves cached results for 10 s and gives each probe request 3 s to answer.
    fn default() -> Self {
        const DEFAULT_TTL: Duration = Duration::from_secs(10);
        const DEFAULT_PROBE_TIMEOUT: Duration = Duration::from_secs(3);
        Self::new(DEFAULT_TTL, DEFAULT_PROBE_TIMEOUT)
    }
}
impl InferenceProbe {
    /// Creates a probe that serves cached results for `ttl` and bounds each probe request by
    /// `probe_timeout`. The cache starts empty.
    pub fn new(ttl: Duration, probe_timeout: Duration) -> Self {
        let cached = Arc::new(Mutex::new(None));
        Self {
            cached,
            ttl,
            probe_timeout,
        }
    }

    /// Returns the backend status.
    ///
    /// The provider `llm` is probed with `model` only when there is no cached result, or the
    /// cached one is at least `ttl` old. Otherwise the cached status is returned.
    ///
    /// The cache is not keyed by `llm` or `model`: any entry younger than `ttl` is returned
    /// whatever arguments are passed.
    ///
    /// Concurrent callers that all miss the cache each run their own probe, and the last one
    /// to finish overwrites the entry.
    pub async fn probe(&self, llm: &Arc<dyn LlmProvider>, model: &str) -> InferenceStatus {
        if let Some(hit) = self.fresh_cached_status() {
            debug!(status = %hit, "inference probe: cache hit");
            return hit;
        }

        let fresh = run_probe(llm, model, self.probe_timeout).await;
        debug!(status = %fresh, "inference probe: fresh result");
        self.remember(fresh);
        fresh
    }

    /// Cached status, if one exists and is younger than `ttl`.
    ///
    /// Kept synchronous so the `std::sync::Mutex` guard can never live across an `.await`.
    /// A poisoned lock is recovered rather than propagated, since the slot only holds `Copy`
    /// data.
    fn fresh_cached_status(&self) -> Option<InferenceStatus> {
        let entry = *self.cached.lock().unwrap_or_else(|poisoned| poisoned.into_inner());
        entry
            .filter(|(_, recorded_at)| recorded_at.elapsed() < self.ttl)
            .map(|(status, _)| status)
    }

    /// Stores `status` in the cache, stamped with the current instant.
    fn remember(&self, status: InferenceStatus) {
        let mut slot = self.cached.lock().unwrap_or_else(|poisoned| poisoned.into_inner());
        *slot = Some((status, Instant::now()));
    }
}
/// Sends a one-token completion ("hi") to `model` and classifies the outcome.
///
/// A call that does not finish within `timeout` is reported as
/// [`InferenceStatus::Unreachable`]. Provider errors are classified by [`map_llm_error`].
async fn run_probe(llm: &Arc<dyn LlmProvider>, model: &str, timeout: Duration) -> InferenceStatus {
    let ping = ChatMessage {
        role: "user".to_string(),
        content: "hi".to_string(),
    };
    let request = LlmRequest {
        model: model.to_owned(),
        system: String::new(),
        messages: vec![ping],
        // Smallest possible request: only success vs. failure matters here.
        temperature: 0.0,
        max_tokens: 1,
        response_schema: None,
    };

    let Ok(outcome) = tokio::time::timeout(timeout, llm.complete(request)).await else {
        debug!("inference probe: timed out");
        return InferenceStatus::Unreachable;
    };

    match outcome {
        Ok(_) => InferenceStatus::Ok,
        Err(e) => {
            debug!(error = %e, "inference probe: provider error");
            map_llm_error(&e)
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Provider stub that always succeeds.
    struct OkLlm;

    #[async_trait]
    impl LlmProvider for OkLlm {
        fn name(&self) -> &str {
            "ok-stub"
        }
        async fn complete(&self, req: LlmRequest) -> Result<LlmResponse, LlmError> {
            Ok(LlmResponse {
                text: "hi".to_string(),
                model: req.model.clone(),
                input_tokens: 1,
                output_tokens: 1,
                latency_ms: 0,
                cost_usd: 0.0,
            })
        }
    }

    /// Provider stub that always rejects the credentials.
    struct AuthErrorLlm;

    #[async_trait]
    impl LlmProvider for AuthErrorLlm {
        fn name(&self) -> &str {
            "auth-error-stub"
        }
        async fn complete(&self, _req: LlmRequest) -> Result<LlmResponse, LlmError> {
            Err(LlmError::AccessDenied("invalid api key".into()))
        }
    }

    /// Provider stub that always fails at the transport layer.
    struct TransportErrorLlm;

    #[async_trait]
    impl LlmProvider for TransportErrorLlm {
        fn name(&self) -> &str {
            "transport-stub"
        }
        async fn complete(&self, _req: LlmRequest) -> Result<LlmResponse, LlmError> {
            Err(LlmError::Transport("connection refused".into()))
        }
    }

    /// Provider stub that takes far longer than any probe timeout used in these tests.
    struct HungLlm;

    #[async_trait]
    impl LlmProvider for HungLlm {
        fn name(&self) -> &str {
            "hung-stub"
        }
        async fn complete(&self, _req: LlmRequest) -> Result<LlmResponse, LlmError> {
            tokio::time::sleep(Duration::from_secs(60)).await;
            Err(LlmError::Transport("hung".into()))
        }
    }

    /// Succeeding provider stub that counts how often it is called.
    struct CountingLlm {
        calls: Arc<AtomicU32>,
    }

    #[async_trait]
    impl LlmProvider for CountingLlm {
        fn name(&self) -> &str {
            "counting-stub"
        }
        async fn complete(&self, req: LlmRequest) -> Result<LlmResponse, LlmError> {
            self.calls.fetch_add(1, Ordering::Relaxed);
            Ok(LlmResponse {
                text: "x".into(),
                model: req.model.clone(),
                input_tokens: 1,
                output_tokens: 1,
                latency_ms: 0,
                cost_usd: 0.0,
            })
        }
    }

    #[test]
    fn probe_status_serialises_lowercase() {
        assert_eq!(
            serde_json::to_string(&InferenceStatus::Ok).unwrap(),
            "\"ok\""
        );
        assert_eq!(
            serde_json::to_string(&InferenceStatus::Unreachable).unwrap(),
            "\"unreachable\""
        );
        assert_eq!(
            serde_json::to_string(&InferenceStatus::AuthError).unwrap(),
            "\"auth_error\""
        );
        assert_eq!(
            serde_json::to_string(&InferenceStatus::Unknown).unwrap(),
            "\"unknown\""
        );
    }

    #[test]
    fn probe_status_display_matches_serde_wire_name() {
        for status in [
            InferenceStatus::Ok,
            InferenceStatus::Unreachable,
            InferenceStatus::AuthError,
            InferenceStatus::Unknown,
        ] {
            let wire = serde_json::to_string(&status).unwrap();
            assert_eq!(
                wire,
                format!("\"{status}\""),
                "Display and serde disagree for {status:?}"
            );
        }
    }

    #[test]
    fn probe_status_is_ok() {
        assert!(InferenceStatus::Ok.is_ok());
        assert!(!InferenceStatus::Unreachable.is_ok());
        assert!(!InferenceStatus::AuthError.is_ok());
        assert!(!InferenceStatus::Unknown.is_ok());
    }

    #[test]
    fn error_mapping_access_denied_is_auth_error() {
        let status = map_llm_error(&LlmError::AccessDenied("denied".into()));
        assert_eq!(status, InferenceStatus::AuthError);
    }

    #[test]
    fn error_mapping_model_not_found_is_auth_error() {
        let status = map_llm_error(&LlmError::ModelNotFound("no-model".into()));
        assert_eq!(status, InferenceStatus::AuthError);
    }

    #[test]
    fn error_mapping_model_not_ready_is_auth_error() {
        let status = map_llm_error(&LlmError::ModelNotReady("creating".into()));
        assert_eq!(status, InferenceStatus::AuthError);
    }

    #[test]
    fn error_mapping_validation_is_auth_error() {
        let status = map_llm_error(&LlmError::Validation("bad prefix".into()));
        assert_eq!(status, InferenceStatus::AuthError);
    }

    #[test]
    fn error_mapping_transport_is_unreachable() {
        let status = map_llm_error(&LlmError::Transport("connection refused".into()));
        assert_eq!(status, InferenceStatus::Unreachable);
    }

    #[test]
    fn error_mapping_rate_limited_is_unreachable() {
        let status = map_llm_error(&LlmError::RateLimited);
        assert_eq!(status, InferenceStatus::Unreachable);
    }

    #[test]
    fn error_mapping_upstream_5xx_is_unreachable() {
        let status = map_llm_error(&LlmError::Upstream {
            status: 503,
            body: "overloaded".into(),
        });
        assert_eq!(status, InferenceStatus::Unreachable);
    }

    #[tokio::test]
    async fn probe_returns_ok_on_success() {
        let llm: Arc<dyn LlmProvider> = Arc::new(OkLlm);
        let status = run_probe(&llm, "test-model", Duration::from_secs(5)).await;
        assert_eq!(status, InferenceStatus::Ok);
    }

    #[tokio::test]
    async fn probe_returns_auth_error_on_access_denied() {
        let llm: Arc<dyn LlmProvider> = Arc::new(AuthErrorLlm);
        let status = run_probe(&llm, "test-model", Duration::from_secs(5)).await;
        assert_eq!(status, InferenceStatus::AuthError);
    }

    #[tokio::test]
    async fn probe_returns_unreachable_on_transport() {
        let llm: Arc<dyn LlmProvider> = Arc::new(TransportErrorLlm);
        let status = run_probe(&llm, "test-model", Duration::from_secs(5)).await;
        assert_eq!(status, InferenceStatus::Unreachable);
    }

    #[tokio::test(start_paused = true)]
    async fn probe_respects_timeout() {
        let llm: Arc<dyn LlmProvider> = Arc::new(HungLlm);
        let status = run_probe(&llm, "test-model", Duration::from_millis(10)).await;
        assert_eq!(
            status,
            InferenceStatus::Unreachable,
            "hung endpoint must produce Unreachable"
        );
    }

    #[tokio::test]
    async fn probe_cache_prevents_redundant_calls() {
        let calls = Arc::new(AtomicU32::new(0));
        let llm: Arc<dyn LlmProvider> = Arc::new(CountingLlm {
            calls: Arc::clone(&calls),
        });
        let probe = InferenceProbe::new(Duration::from_secs(60), Duration::from_secs(5));
        let s1 = probe.probe(&llm, "m").await;
        let s2 = probe.probe(&llm, "m").await;
        assert_eq!(s1, InferenceStatus::Ok);
        assert_eq!(s2, InferenceStatus::Ok);
        assert_eq!(
            calls.load(Ordering::Relaxed),
            1,
            "provider must be called exactly once when cache is warm"
        );
    }

    #[tokio::test(start_paused = true)]
    async fn probe_cache_ttl_zero_always_reprobes() {
        let calls = Arc::new(AtomicU32::new(0));
        let llm: Arc<dyn LlmProvider> = Arc::new(CountingLlm {
            calls: Arc::clone(&calls),
        });
        let probe = InferenceProbe::new(Duration::ZERO, Duration::from_secs(5));
        probe.probe(&llm, "m").await;
        probe.probe(&llm, "m").await;
        assert_eq!(
            calls.load(Ordering::Relaxed),
            2,
            "zero TTL must reprobe on every call"
        );
    }

    // Not `start_paused`: the cache timestamps with `std::time::Instant`, which tokio's paused
    // clock does not advance, so expiry needs a short real-time wait.
    #[tokio::test]
    async fn probe_cache_expires_after_ttl() {
        let calls = Arc::new(AtomicU32::new(0));
        let llm: Arc<dyn LlmProvider> = Arc::new(CountingLlm {
            calls: Arc::clone(&calls),
        });
        let probe = InferenceProbe::new(Duration::from_millis(5), Duration::from_secs(5));
        probe.probe(&llm, "m").await;
        tokio::time::sleep(Duration::from_millis(20)).await;
        probe.probe(&llm, "m").await;
        assert_eq!(
            calls.load(Ordering::Relaxed),
            2,
            "an expired cache entry must trigger a reprobe"
        );
    }

    #[tokio::test]
    async fn probe_clones_share_cache() {
        let calls = Arc::new(AtomicU32::new(0));
        let llm: Arc<dyn LlmProvider> = Arc::new(CountingLlm {
            calls: Arc::clone(&calls),
        });
        let probe = InferenceProbe::new(Duration::from_secs(60), Duration::from_secs(5));
        let twin = probe.clone();
        probe.probe(&llm, "m").await;
        twin.probe(&llm, "m").await;
        assert_eq!(
            calls.load(Ordering::Relaxed),
            1,
            "a cloned probe must reuse the original's cache"
        );
    }
}