use super::*;
use async_trait::async_trait;
use std::sync::atomic::{AtomicU32, Ordering};
/// Provider stub whose `complete` always succeeds with the fixed text `"hi"`.
struct OkLlm;

#[async_trait]
impl LlmProvider for OkLlm {
    fn name(&self) -> &str {
        "ok-stub"
    }

    /// Echoes the requested model back with 1 input / 1 output token, zero latency and zero cost.
    async fn complete(&self, req: LlmRequest) -> Result<LlmResponse, LlmError> {
        let model = req.model.clone();
        let response = LlmResponse {
            text: String::from("hi"),
            model,
            input_tokens: 1,
            output_tokens: 1,
            latency_ms: 0,
            cost_usd: 0.0,
            finish_reason: None,
        };
        Ok(response)
    }
}
/// Provider stub whose `complete` always fails with `LlmError::AccessDenied`.
struct AuthErrorLlm;

#[async_trait]
impl LlmProvider for AuthErrorLlm {
    fn name(&self) -> &str {
        "auth-error-stub"
    }

    async fn complete(&self, _req: LlmRequest) -> Result<LlmResponse, LlmError> {
        let reason = "invalid api key";
        Err(LlmError::AccessDenied(reason.into()))
    }
}
/// Provider stub whose `complete` always fails with `LlmError::Transport`.
struct TransportErrorLlm;

#[async_trait]
impl LlmProvider for TransportErrorLlm {
    fn name(&self) -> &str {
        "transport-stub"
    }

    async fn complete(&self, _req: LlmRequest) -> Result<LlmResponse, LlmError> {
        let reason = "connection refused";
        Err(LlmError::Transport(reason.into()))
    }
}
/// Provider stub that stalls for 60 s before failing with `LlmError::Transport`.
///
/// Any probe timeout shorter than 60 s expires first, which lets tests drive the
/// timeout path.
struct HungLlm;

#[async_trait]
impl LlmProvider for HungLlm {
    fn name(&self) -> &str {
        "hung-stub"
    }

    async fn complete(&self, _req: LlmRequest) -> Result<LlmResponse, LlmError> {
        let stall = Duration::from_secs(60);
        tokio::time::sleep(stall).await;
        Err(LlmError::Transport("hung".into()))
    }
}
/// Provider stub that succeeds with the text `"x"` and counts how often it is invoked.
struct CountingLlm {
    /// Invocation counter, shared with the test via `Arc` so it can be read afterwards.
    calls: Arc<AtomicU32>,
}

#[async_trait]
impl LlmProvider for CountingLlm {
    fn name(&self) -> &str {
        "counting-stub"
    }

    async fn complete(&self, req: LlmRequest) -> Result<LlmResponse, LlmError> {
        // Relaxed is enough here: the tests only load the counter after awaiting the probes.
        let _ = self.calls.fetch_add(1, Ordering::Relaxed);
        let model = req.model.clone();
        Ok(LlmResponse {
            text: String::from("x"),
            model,
            input_tokens: 1,
            output_tokens: 1,
            latency_ms: 0,
            cost_usd: 0.0,
            finish_reason: None,
        })
    }
}
/// Every `InferenceStatus` variant must serialise to a bare lower-case (snake_case)
/// JSON string, e.g. `AuthError` → `"auth_error"`.
#[test]
fn probe_status_serialises_lowercase() {
    let expectations = [
        (InferenceStatus::Ok, "\"ok\""),
        (InferenceStatus::Unreachable, "\"unreachable\""),
        (InferenceStatus::AuthError, "\"auth_error\""),
        (InferenceStatus::Unknown, "\"unknown\""),
    ];
    for (variant, json) in expectations {
        let encoded = serde_json::to_string(&variant).unwrap();
        assert_eq!(encoded, json);
    }
}
/// Only `InferenceStatus::Ok` reports `is_ok()`; every other variant must not.
#[test]
fn probe_status_is_ok() {
    assert!(InferenceStatus::Ok.is_ok());
    let failing = [
        InferenceStatus::Unreachable,
        InferenceStatus::AuthError,
        InferenceStatus::Unknown,
    ];
    for variant in failing {
        assert!(!variant.is_ok());
    }
}
/// `LlmError::AccessDenied` must map to `InferenceStatus::AuthError`.
#[test]
fn error_mapping_access_denied_is_auth_error() {
    let err = LlmError::AccessDenied("denied".into());
    assert_eq!(map_llm_error(&err), InferenceStatus::AuthError);
}
/// `LlmError::ModelNotFound` must map to `InferenceStatus::AuthError`.
#[test]
fn error_mapping_model_not_found_is_auth_error() {
    let err = LlmError::ModelNotFound("no-model".into());
    assert_eq!(map_llm_error(&err), InferenceStatus::AuthError);
}
/// `LlmError::ModelNotReady` must map to `InferenceStatus::AuthError`.
#[test]
fn error_mapping_model_not_ready_is_auth_error() {
    let err = LlmError::ModelNotReady("creating".into());
    assert_eq!(map_llm_error(&err), InferenceStatus::AuthError);
}
/// `LlmError::Validation` must map to `InferenceStatus::AuthError`.
#[test]
fn error_mapping_validation_is_auth_error() {
    let err = LlmError::Validation("bad prefix".into());
    assert_eq!(map_llm_error(&err), InferenceStatus::AuthError);
}
/// `LlmError::Transport` must map to `InferenceStatus::Unreachable`.
#[test]
fn error_mapping_transport_is_unreachable() {
    let err = LlmError::Transport("connection refused".into());
    assert_eq!(map_llm_error(&err), InferenceStatus::Unreachable);
}
/// `LlmError::RateLimited` must map to `InferenceStatus::Unreachable`.
#[test]
fn error_mapping_rate_limited_is_unreachable() {
    let err = LlmError::RateLimited;
    assert_eq!(map_llm_error(&err), InferenceStatus::Unreachable);
}
/// An upstream 5xx response (here 503) must map to `InferenceStatus::Unreachable`.
#[test]
fn error_mapping_upstream_5xx_is_unreachable() {
    let err = LlmError::Upstream {
        status: 503,
        body: "overloaded".into(),
    };
    assert_eq!(map_llm_error(&err), InferenceStatus::Unreachable);
}
/// A provider that answers successfully must yield `InferenceStatus::Ok`.
#[tokio::test]
async fn probe_returns_ok_on_success() {
    let provider: Arc<dyn LlmProvider> = Arc::new(OkLlm);
    let timeout = Duration::from_secs(5);
    let outcome = run_probe(&provider, "test-model", timeout).await;
    assert_eq!(outcome, InferenceStatus::Ok);
}
/// A provider failing with `AccessDenied` must yield `InferenceStatus::AuthError`.
#[tokio::test]
async fn probe_returns_auth_error_on_access_denied() {
    let provider: Arc<dyn LlmProvider> = Arc::new(AuthErrorLlm);
    let timeout = Duration::from_secs(5);
    let outcome = run_probe(&provider, "test-model", timeout).await;
    assert_eq!(outcome, InferenceStatus::AuthError);
}
/// A provider failing with a transport error must yield `InferenceStatus::Unreachable`.
#[tokio::test]
async fn probe_returns_unreachable_on_transport() {
    let provider: Arc<dyn LlmProvider> = Arc::new(TransportErrorLlm);
    let timeout = Duration::from_secs(5);
    let outcome = run_probe(&provider, "test-model", timeout).await;
    assert_eq!(outcome, InferenceStatus::Unreachable);
}
/// When the provider outlives the probe timeout, the result must be `Unknown`
/// rather than `Unreachable` (#739).
///
/// The clock starts paused, so the 60 s stall in `HungLlm` costs no real time.
#[tokio::test(start_paused = true)]
async fn probe_timeout_returns_unknown() {
    let provider: Arc<dyn LlmProvider> = Arc::new(HungLlm);
    let timeout = Duration::from_millis(10);
    let outcome = run_probe(&provider, "test-model", timeout).await;
    assert_eq!(
        outcome,
        InferenceStatus::Unknown,
        "probe timeout must return Unknown, not Unreachable (#739)"
    );
}
/// Within the cache TTL (60 s here), a second probe for the same model must reuse
/// the cached result instead of calling the provider again.
#[tokio::test]
async fn probe_cache_prevents_redundant_calls() {
    let counter = Arc::new(AtomicU32::new(0));
    let provider: Arc<dyn LlmProvider> = Arc::new(CountingLlm {
        calls: Arc::clone(&counter),
    });
    let probe = InferenceProbe::new(Duration::from_secs(60), Duration::from_secs(5));

    let first = probe.probe(&provider, "m").await;
    let second = probe.probe(&provider, "m").await;

    assert_eq!(first, InferenceStatus::Ok);
    assert_eq!(second, InferenceStatus::Ok);
    let invocations = counter.load(Ordering::Relaxed);
    assert_eq!(
        invocations,
        1,
        "provider must be called exactly once when cache is warm"
    );
}
/// With a zero TTL nothing is served from cache, so every probe reaches the provider.
#[tokio::test(start_paused = true)]
async fn probe_cache_ttl_zero_always_reprobes() {
    let counter = Arc::new(AtomicU32::new(0));
    let provider: Arc<dyn LlmProvider> = Arc::new(CountingLlm {
        calls: Arc::clone(&counter),
    });
    let probe = InferenceProbe::new(Duration::ZERO, Duration::from_secs(5));

    for _ in 0..2 {
        probe.probe(&provider, "m").await;
    }

    let invocations = counter.load(Ordering::Relaxed);
    assert_eq!(invocations, 2, "zero TTL must reprobe on every call");
}
/// Consecutive timed-out probes stay `Unknown` until the streak reaches
/// `CONSECUTIVE_UNKNOWN_DEGRADATION_THRESHOLD`. The probe at the threshold must
/// escalate to `Unreachable` (#820).
#[tokio::test(start_paused = true)]
async fn consecutive_unknown_degrades_after_threshold() {
    let provider: Arc<dyn LlmProvider> = Arc::new(HungLlm);
    // Zero TTL forces a fresh probe each call; the 1 ms timeout is far shorter than
    // HungLlm's 60 s stall, so every probe times out.
    let probe = InferenceProbe::new(Duration::ZERO, Duration::from_millis(1));
    let below_threshold = CONSECUTIVE_UNKNOWN_DEGRADATION_THRESHOLD - 1;

    for i in 0..below_threshold {
        assert_eq!(
            probe.probe(&provider, "m").await,
            InferenceStatus::Unknown,
            "probe {i}: should be Unknown before threshold"
        );
    }

    let at_threshold = probe.probe(&provider, "m").await;
    assert_eq!(
        at_threshold,
        InferenceStatus::Unreachable,
        "probe at threshold must escalate Unknown → Unreachable (#820)"
    );
}
/// A successful probe must reset the consecutive-`Unknown` streak on the *same*
/// `InferenceProbe` (#820).
///
/// The sequence is: build the streak past the threshold (escalated to `Unreachable`),
/// then get one `Ok`. After that, `threshold - 1` further timeouts must stay `Unknown`.
/// Without a reset, the carried-over streak would escalate them. Finally, the next
/// timeout must escalate again, which shows the counter restarted from zero rather
/// than being disabled.
///
/// One probe instance is used throughout because the streak lives on the instance.
/// Fresh instances would start at zero and prove nothing about the reset.
#[tokio::test(start_paused = true)]
async fn consecutive_unknown_resets_on_ok() {
    let hung_llm: Arc<dyn LlmProvider> = Arc::new(HungLlm);
    let ok_llm: Arc<dyn LlmProvider> = Arc::new(OkLlm);
    // Zero TTL: no cached result masks the provider switch. The 1 ms timeout still lets
    // OkLlm succeed: its future resolves on the first poll, and the paused clock only
    // advances when the runtime is idle.
    let probe = InferenceProbe::new(Duration::ZERO, Duration::from_millis(1));
    let threshold = CONSECUTIVE_UNKNOWN_DEGRADATION_THRESHOLD;

    // Build the streak past the threshold.
    for _ in 0..threshold {
        probe.probe(&hung_llm, "m").await;
    }
    let escalated = probe.probe(&hung_llm, "m").await;
    assert_eq!(
        escalated,
        InferenceStatus::Unreachable,
        "should be escalated to Unreachable before reset"
    );

    // A success on the same probe instance must report Ok and clear the streak.
    let status = probe.probe(&ok_llm, "m").await;
    assert_eq!(
        status,
        InferenceStatus::Ok,
        "successful probe must return Ok"
    );

    // Had the streak not been reset, these would already be escalated.
    for i in 0..(threshold - 1) {
        assert_eq!(
            probe.probe(&hung_llm, "m").await,
            InferenceStatus::Unknown,
            "probe {i} after Ok: streak must have been reset"
        );
    }

    // The counter restarted from zero, so the threshold-th timeout escalates again.
    let re_escalated = probe.probe(&hung_llm, "m").await;
    assert_eq!(
        re_escalated,
        InferenceStatus::Unreachable,
        "streak must resume counting from zero after Ok"
    );
}
/// With `TRUSTY_REVIEW_HEALTH_TIMEOUT_SECS` unset, the probe timeout is 10 s (#739).
#[test]
#[serial_test::serial]
fn health_probe_timeout_default() {
    let var = "TRUSTY_REVIEW_HEALTH_TIMEOUT_SECS";
    // SAFETY: `#[serial_test::serial]` stops this from overlapping the other serial
    // env-var tests. NOTE(review): non-serial tests that read the environment could
    // still race with this mutation.
    unsafe { std::env::remove_var(var) };
    let timeout = health_probe_timeout();
    assert_eq!(
        timeout,
        Duration::from_secs(10),
        "default probe timeout must be 10 s (#739)"
    );
}
/// A valid numeric `TRUSTY_REVIEW_HEALTH_TIMEOUT_SECS` overrides the default timeout.
#[test]
#[serial_test::serial]
fn health_probe_timeout_env_override() {
    let var = "TRUSTY_REVIEW_HEALTH_TIMEOUT_SECS";
    // SAFETY: serialised against the other env-var tests via `#[serial_test::serial]`.
    // NOTE(review): non-serial readers of the environment could still race.
    unsafe { std::env::set_var(var, "15") };
    let observed = health_probe_timeout();
    // SAFETY: same serialisation as above. The variable is cleared before asserting,
    // so a failing assertion does not leak it into later tests.
    unsafe { std::env::remove_var(var) };
    assert_eq!(
        observed,
        Duration::from_secs(15),
        "env-var override must be honoured"
    );
}
/// A non-numeric `TRUSTY_REVIEW_HEALTH_TIMEOUT_SECS` is ignored and the 10 s default applies.
#[test]
#[serial_test::serial]
fn health_probe_timeout_env_invalid_falls_back() {
    let var = "TRUSTY_REVIEW_HEALTH_TIMEOUT_SECS";
    // SAFETY: serialised against the other env-var tests via `#[serial_test::serial]`.
    // NOTE(review): non-serial readers of the environment could still race.
    unsafe { std::env::set_var(var, "not-a-number") };
    let observed = health_probe_timeout();
    // SAFETY: same serialisation as above. The variable is cleared before asserting.
    unsafe { std::env::remove_var(var) };
    assert_eq!(
        observed,
        Duration::from_secs(10),
        "invalid env var must fall back to 10 s default"
    );
}
/// `TRUSTY_REVIEW_HEALTH_TIMEOUT_SECS=0` is rejected and the 10 s default applies.
#[test]
#[serial_test::serial]
fn health_probe_timeout_env_zero_falls_back() {
    let var = "TRUSTY_REVIEW_HEALTH_TIMEOUT_SECS";
    // SAFETY: serialised against the other env-var tests via `#[serial_test::serial]`.
    // NOTE(review): non-serial readers of the environment could still race.
    unsafe { std::env::set_var(var, "0") };
    let observed = health_probe_timeout();
    // SAFETY: same serialisation as above. The variable is cleared before asserting.
    unsafe { std::env::remove_var(var) };
    assert_eq!(
        observed,
        Duration::from_secs(10),
        "zero env var must fall back to 10 s default"
    );
}