use std::time::{Duration, Instant};
use serde::{Deserialize, Serialize};
use serde_json::Value;
/// Per-request timeout used by [`probe`] when the caller passes `None` (5 seconds).
pub const DEFAULT_PROBE_TIMEOUT: Duration = Duration::from_millis(5_000);
/// Input to [`probe`]: which endpoint to contact and with which credentials.
///
/// `Debug` is implemented by hand for this type so that `api_key` is never
/// printed; do not derive it.
#[derive(Deserialize)]
pub struct ProbeRequest {
    /// Caller-supplied provider identifier. Not read by `probe` itself; it only
    /// appears in `Debug` output.
    pub provider_id: String,
    /// API base URL; `probe` appends `/models` to it (trailing slashes are trimmed).
    pub base_url: String,
    /// Secret sent as a bearer token. Redacted from `Debug` output and from any
    /// error text `probe` returns.
    pub api_key: String,
    /// Optional model name hint. Defaults to `None` when absent from the input.
    /// Not read by `probe`.
    #[serde(default)]
    pub model_hint: Option<String>,
}
/// Hand-written `Debug` so that logging a request never exposes the API key.
/// All other fields are printed verbatim; `api_key` is shown as `"<redacted>"`.
impl std::fmt::Debug for ProbeRequest {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive destructuring: adding a field to `ProbeRequest` becomes a
        // compile error here, forcing a decision on whether it is safe to print.
        let Self {
            provider_id,
            base_url,
            api_key: _,
            model_hint,
        } = self;

        let mut out = f.debug_struct("ProbeRequest");
        out.field("provider_id", provider_id);
        out.field("base_url", base_url);
        out.field("api_key", &"<redacted>");
        out.field("model_hint", model_hint);
        out.finish()
    }
}
/// Outcome of a single [`probe`] call.
///
/// Serialized field order follows the declaration order below; `model_count`
/// and `error` are omitted from serialized output when `None`.
#[derive(Debug, Clone, Serialize, PartialEq, Eq)]
pub struct ProbeResult {
    /// `true` when the endpoint answered with a 2xx status.
    pub ok: bool,
    /// HTTP status code of the response, or `0` when no response was received
    /// (connection failure, timeout, ...).
    pub status: u16,
    /// Wall-clock milliseconds measured by `probe` for the request.
    pub latency_ms: u64,
    /// Length of the JSON `data` array in a successful response, when present.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub model_count: Option<usize>,
    /// Human-readable failure description with the API key redacted; `None` on success.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error: Option<String>,
}
pub async fn probe(
req: &ProbeRequest,
http: &reqwest::Client,
timeout: Option<Duration>,
) -> ProbeResult {
let url = build_models_url(&req.base_url);
let start = Instant::now();
let response = http
.get(&url)
.bearer_auth(&req.api_key)
.timeout(timeout.unwrap_or(DEFAULT_PROBE_TIMEOUT))
.send()
.await;
match response {
Ok(r) => {
let status = r.status().as_u16();
let body = r.bytes().await.unwrap_or_default();
let latency_ms = start.elapsed().as_millis() as u64;
let ok = (200..300).contains(&status);
if ok {
let model_count = parse_model_count(&body);
ProbeResult {
ok: true,
status,
latency_ms,
model_count,
error: None,
}
} else {
let raw_text = String::from_utf8_lossy(&body).into_owned();
let trimmed = raw_text.chars().take(400).collect::<String>();
let safe = redact_key(&trimmed, &req.api_key);
ProbeResult {
ok: false,
status,
latency_ms,
model_count: None,
error: Some(format!("HTTP {status}: {safe}")),
}
}
}
Err(e) => {
let latency_ms = start.elapsed().as_millis() as u64;
let raw = e.to_string();
let safe = redact_key(&raw, &req.api_key);
ProbeResult {
ok: false,
status: 0,
latency_ms,
model_count: None,
error: Some(safe),
}
}
}
}
/// Builds the model-listing URL `{base_url}/models`.
///
/// All trailing `/` characters on `base_url` are stripped first, so
/// `".../v1"` and `".../v1/"` both yield `".../v1/models"`.
fn build_models_url(base_url: &str) -> String {
    let mut url = String::from(base_url.trim_end_matches('/'));
    url.push_str("/models");
    url
}
fn parse_model_count(body: &[u8]) -> Option<usize> {
let v: Value = serde_json::from_slice(body).ok()?;
let arr = v.get("data")?.as_array()?;
Some(arr.len())
}
/// Masks an API key in `haystack`.
///
/// Every occurrence of `key` is replaced with `<redacted>`. If the key is longer
/// than eight characters, every occurrence of its first eight characters is
/// also replaced. This masks partially echoed keys, e.g. "token starts with
/// sk-xxxxx".
///
/// An empty `key` returns `haystack` unchanged. The prefix is measured in
/// `char`s and cut on a char boundary, so keys containing multi-byte UTF-8
/// characters never cause a panic.
fn redact_key(haystack: &str, key: &str) -> String {
    const PREFIX_CHARS: usize = 8;

    if key.is_empty() {
        return haystack.to_string();
    }
    let mut out = haystack.replace(key, "<redacted>");
    // `nth(PREFIX_CHARS)` yields the byte offset where the 9th char starts.
    // It exists only when the key has more than 8 chars and always lies on a
    // char boundary, unlike `&key[..8]`, which panics inside a multi-byte char.
    if let Some((prefix_end, _)) = key.char_indices().nth(PREFIX_CHARS) {
        out = out.replace(&key[..prefix_end], "<redacted>");
    }
    out
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a request for the tests; only `api_key` and `base_url` vary.
    fn req_with(api_key: &str, base_url: &str) -> ProbeRequest {
        ProbeRequest {
            provider_id: "minimax".into(),
            base_url: base_url.into(),
            api_key: api_key.into(),
            model_hint: None,
        }
    }

    /// Spawns a server on an ephemeral localhost port that accepts one
    /// connection and never responds, then returns the server's address.
    async fn spawn_silent_server() -> std::net::SocketAddr {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        tokio::spawn(async move {
            // Keep the accepted socket alive in a named binding. `let _ = ...`
            // would drop and close it immediately, so the client would fail fast
            // with a connection error instead of hanging until a timeout fires.
            let _conn = listener.accept().await;
            tokio::time::sleep(Duration::from_secs(15)).await;
        });
        addr
    }

    #[test]
    fn build_models_url_handles_trailing_slash() {
        assert_eq!(
            build_models_url("https://api.minimax.chat/v1"),
            "https://api.minimax.chat/v1/models"
        );
        assert_eq!(
            build_models_url("https://api.minimax.chat/v1/"),
            "https://api.minimax.chat/v1/models"
        );
    }

    #[test]
    fn parse_model_count_reads_data_array_length() {
        let body = br#"{"data":[{"id":"a"},{"id":"b"},{"id":"c"}]}"#;
        assert_eq!(parse_model_count(body), Some(3));
    }

    #[test]
    fn parse_model_count_returns_none_on_unexpected_shape() {
        assert_eq!(parse_model_count(b"not json"), None);
        assert_eq!(parse_model_count(br#"{"models":[]}"#), None);
        assert_eq!(parse_model_count(br#"{"data":"oops"}"#), None);
    }

    #[test]
    fn redact_key_replaces_full_and_prefix() {
        let key = "sk-supersecretkey-1234567890abcdef";
        let body = format!("error: invalid token {key} (origin: foo)");
        let redacted = redact_key(&body, key);
        assert!(!redacted.contains(key));
        assert!(redacted.contains("<redacted>"));
        let prefix_only = format!("token starts with {} which is wrong", &key[..8]);
        let redacted2 = redact_key(&prefix_only, key);
        assert!(!redacted2.contains(&key[..8]));
    }

    #[test]
    fn redact_key_noop_on_empty_key() {
        assert_eq!(redact_key("hello", ""), "hello");
    }

    #[test]
    fn debug_redacts_api_key() {
        let r = req_with("sk-leak-this-not", "https://x/v1");
        let s = format!("{r:?}");
        assert!(s.contains("<redacted>"));
        assert!(!s.contains("sk-leak-this-not"));
    }

    #[tokio::test]
    async fn probe_timeout_returns_error_under_seven_seconds() {
        let addr = spawn_silent_server().await;
        // No client-level timeout: only the per-request timeout that `probe`
        // applies itself (DEFAULT_PROBE_TIMEOUT for `None`) can end this call.
        let http = reqwest::Client::new();
        let req = req_with("sk-test", &format!("http://{addr}/v1"));
        let started = Instant::now();
        let result = probe(&req, &http, None).await;
        let elapsed = started.elapsed();
        assert!(!result.ok, "timeout should not be reported as ok");
        assert_eq!(result.status, 0, "no response was ever received");
        assert!(
            elapsed >= Duration::from_secs(4),
            "probe returned before the timeout could fire: {elapsed:?}"
        );
        assert!(
            elapsed < Duration::from_secs(7),
            "probe waited too long: {elapsed:?}"
        );
        assert!(result.error.is_some());
    }

    #[tokio::test]
    async fn probe_honours_explicit_timeout() {
        let addr = spawn_silent_server().await;
        let http = reqwest::Client::new();
        let req = req_with("sk-test", &format!("http://{addr}/v1"));
        let started = Instant::now();
        let result = probe(&req, &http, Some(Duration::from_millis(300))).await;
        let elapsed = started.elapsed();
        assert!(!result.ok, "timeout should not be reported as ok");
        assert_eq!(result.status, 0);
        assert!(
            elapsed < Duration::from_secs(3),
            "explicit timeout was not honoured: {elapsed:?}"
        );
        assert!(result.error.is_some());
    }
}