1use std::collections::BTreeMap;
2use std::time::Duration;
3
4use reqwest::Method;
5use serde::{Deserialize, Serialize};
6use serde_json::{json, Value as JsonValue};
7
8use crate::llm_config::{self, AuthEnv, HealthcheckDef, ProviderDef};
9
10use super::api::apply_auth_headers;
11
12const DEFAULT_HEALTHCHECK_TIMEOUT_SECS: u64 = 5;
13const BODY_SNIPPET_LIMIT: usize = 1000;
14
15#[derive(Debug, Clone, Default)]
16pub struct ProviderHealthcheckOptions {
17 pub api_key: Option<String>,
20 pub client: Option<reqwest::Client>,
22}
23
24#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
25pub struct ProviderHealthcheckResult {
26 pub provider: String,
27 pub valid: bool,
28 pub message: String,
29 pub metadata: BTreeMap<String, JsonValue>,
30}
31
32impl ProviderHealthcheckResult {
33 fn new(
34 provider: impl Into<String>,
35 valid: bool,
36 message: impl Into<String>,
37 metadata: BTreeMap<String, JsonValue>,
38 ) -> Self {
39 Self {
40 provider: provider.into(),
41 valid,
42 message: message.into(),
43 metadata,
44 }
45 }
46}
47
48pub async fn run_provider_healthcheck(provider: &str) -> ProviderHealthcheckResult {
49 run_provider_healthcheck_with_options(provider, ProviderHealthcheckOptions::default()).await
50}
51
52pub async fn run_provider_healthcheck_with_options(
53 provider: &str,
54 options: ProviderHealthcheckOptions,
55) -> ProviderHealthcheckResult {
56 let provider = if provider.trim().is_empty() {
57 "anthropic"
58 } else {
59 provider.trim()
60 };
61
62 let Some(def) = llm_config::provider_config(provider) else {
63 let mut metadata = base_metadata("unknown_provider");
64 metadata.insert("provider".to_string(), json!(provider));
65 return ProviderHealthcheckResult::new(
66 provider,
67 false,
68 format!("Unknown provider: {provider}"),
69 metadata,
70 );
71 };
72
73 let Some(healthcheck) = def.healthcheck.as_ref() else {
74 let mut metadata = base_metadata("no_healthcheck");
75 metadata.insert("provider".to_string(), json!(provider));
76 return ProviderHealthcheckResult::new(
77 provider,
78 false,
79 format!("No healthcheck configured for {provider}"),
80 metadata,
81 );
82 };
83
84 let auth = resolve_healthcheck_auth(&def, options.api_key);
85 if auth.requires_auth && auth.api_key.is_none() {
86 let mut metadata = base_metadata("missing_credentials");
87 metadata.insert("provider".to_string(), json!(provider));
88 metadata.insert("auth_env".to_string(), json!(auth.candidates));
89 return ProviderHealthcheckResult::new(
90 provider,
91 false,
92 format!(
93 "Missing credentials for {provider}: set {} or pass an api_key",
94 auth.candidates.join(", ")
95 ),
96 metadata,
97 );
98 }
99
100 let url = build_healthcheck_url(&def, healthcheck);
101 let method = Method::from_bytes(healthcheck.method.as_bytes()).unwrap_or(Method::GET);
102 let client = match options.client {
103 Some(client) => client,
104 None => match reqwest::Client::builder()
105 .timeout(Duration::from_secs(DEFAULT_HEALTHCHECK_TIMEOUT_SECS))
106 .build()
107 {
108 Ok(client) => client,
109 Err(error) => {
110 let mut metadata = base_metadata("client_build_failed");
111 metadata.insert("provider".to_string(), json!(provider));
112 return ProviderHealthcheckResult::new(
113 provider,
114 false,
115 format!("{provider} healthcheck failed: {error}"),
116 metadata,
117 );
118 }
119 },
120 };
121
122 let mut request = client.request(method.clone(), &url);
123 if let Some(api_key) = auth.api_key.as_deref() {
124 request = apply_auth_headers(request, api_key, Some(&def));
125 }
126 for (name, value) in &def.extra_headers {
127 request = request.header(name, value);
128 }
129 if let Some(body) = &healthcheck.body {
130 request = request
131 .header(reqwest::header::CONTENT_TYPE, "application/json")
132 .body(body.clone());
133 }
134
135 match request.send().await {
136 Ok(response) => {
137 let status = response.status();
138 let status_code = status.as_u16();
139 let valid = status.is_success();
140 let body_text = response.text().await.unwrap_or_default();
141 let mut metadata = base_metadata(if valid { "ok" } else { "http_status" });
142 metadata.insert("provider".to_string(), json!(provider));
143 metadata.insert("status".to_string(), json!(status_code));
144 metadata.insert("url".to_string(), json!(url));
145 metadata.insert("method".to_string(), json!(method.as_str()));
146 if !valid && !body_text.is_empty() {
147 metadata.insert("body".to_string(), json!(body_snippet(&body_text)));
148 }
149
150 let message = if valid {
151 format!("{provider} is reachable (HTTP {status_code})")
152 } else {
153 let suffix = body_snippet(&body_text);
154 if suffix.is_empty() {
155 format!("{provider} returned HTTP {status_code}")
156 } else {
157 format!("{provider} returned HTTP {status_code}: {suffix}")
158 }
159 };
160
161 ProviderHealthcheckResult::new(provider, valid, message, metadata)
162 }
163 Err(error) => {
164 let mut metadata = base_metadata("request_failed");
165 metadata.insert("provider".to_string(), json!(provider));
166 metadata.insert("url".to_string(), json!(url));
167 metadata.insert("method".to_string(), json!(method.as_str()));
168 ProviderHealthcheckResult::new(
169 provider,
170 false,
171 format!("{provider} healthcheck failed: {error}"),
172 metadata,
173 )
174 }
175 }
176}
177
178pub fn build_healthcheck_url(def: &ProviderDef, healthcheck: &HealthcheckDef) -> String {
179 if let Some(url) = &healthcheck.url {
180 return url.clone();
181 }
182
183 let base = llm_config::resolve_base_url(def);
184 let path = healthcheck.path.as_deref().unwrap_or("");
185 if path.starts_with('/') {
186 format!("{}{}", base.trim_end_matches('/'), path)
187 } else if path.is_empty() {
188 base
189 } else {
190 format!("{}/{}", base.trim_end_matches('/'), path)
191 }
192}
193
/// Credentials decision for a single healthcheck run.
#[derive(Debug, Clone)]
struct ResolvedHealthcheckAuth {
    /// Whether the probe must carry credentials (and fail fast without them).
    requires_auth: bool,
    /// The non-blank key to send, if one was found.
    api_key: Option<String>,
    /// Env var names consulted, reported back when credentials are missing.
    candidates: Vec<String>,
}
200
201fn resolve_healthcheck_auth(
202 def: &ProviderDef,
203 api_key_override: Option<String>,
204) -> ResolvedHealthcheckAuth {
205 let candidates = auth_env_candidates(&def.auth_env);
206 if def.auth_style == "none" || matches!(def.auth_env, AuthEnv::None) {
207 let api_key = api_key_override.and_then(non_empty);
208 return ResolvedHealthcheckAuth {
209 requires_auth: api_key.is_some(),
210 api_key,
211 candidates,
212 };
213 }
214
215 let api_key = api_key_override
216 .and_then(non_empty)
217 .or_else(|| resolve_api_key_from_env(&def.auth_env));
218 ResolvedHealthcheckAuth {
219 requires_auth: true,
220 api_key,
221 candidates,
222 }
223}
224
225fn auth_env_candidates(auth_env: &AuthEnv) -> Vec<String> {
226 match auth_env {
227 AuthEnv::None => Vec::new(),
228 AuthEnv::Single(env) => vec![env.clone()],
229 AuthEnv::Multiple(envs) => envs.clone(),
230 }
231}
232
233fn resolve_api_key_from_env(auth_env: &AuthEnv) -> Option<String> {
234 match auth_env {
235 AuthEnv::None => None,
236 AuthEnv::Single(env) => std::env::var(env).ok().and_then(non_empty),
237 AuthEnv::Multiple(envs) => envs
238 .iter()
239 .find_map(|env| std::env::var(env).ok().and_then(non_empty)),
240 }
241}
242
/// Trims surrounding whitespace, mapping blank strings to `None`.
fn non_empty(value: String) -> Option<String> {
    let trimmed = value.trim();
    (!trimmed.is_empty()).then(|| trimmed.to_owned())
}
251
252fn base_metadata(reason: &str) -> BTreeMap<String, JsonValue> {
253 BTreeMap::from([("reason".to_string(), json!(reason))])
254}
255
256fn body_snippet(body: &str) -> String {
257 let mut snippet = String::new();
258 for ch in body.chars().take(BODY_SNIPPET_LIMIT) {
259 snippet.push(ch);
260 }
261 snippet
262}
263
#[cfg(test)]
mod tests {
    use std::io::{Read, Write};
    use std::net::TcpListener;
    use std::sync::{Arc, Mutex};

    use super::*;

    /// Builds a bearer-auth provider rooted at `base_url` with the given probe config.
    fn provider_with_healthcheck(base_url: String, healthcheck: HealthcheckDef) -> ProviderDef {
        ProviderDef {
            base_url,
            auth_style: "bearer".to_string(),
            auth_env: AuthEnv::Single("HARN_TEST_PROVIDER_KEY".to_string()),
            extra_headers: BTreeMap::from([("x-extra".to_string(), "extra-value".to_string())]),
            chat_endpoint: "/chat/completions".to_string(),
            healthcheck: Some(healthcheck),
            ..Default::default()
        }
    }

    /// Installs a user-override config that contains only `provider`, registered under `name`.
    fn install_provider(name: &str, provider: ProviderDef) {
        let mut config = llm_config::ProvidersConfig::default();
        config.providers.insert(name.to_string(), provider);
        llm_config::set_user_overrides(Some(config));
    }

    /// Serves exactly one HTTP request on an ephemeral localhost port.
    ///
    /// The stub stores the raw request text in `captured`, replies with
    /// `status` and `body`, and then closes. Returns the stub's base URL and
    /// its thread handle. The thread panics if no client connects within 3s.
    fn spawn_healthcheck_stub(
        status: u16,
        body: &'static str,
        captured: Arc<Mutex<Option<String>>>,
    ) -> (String, std::thread::JoinHandle<()>) {
        let listener = TcpListener::bind("127.0.0.1:0").expect("bind healthcheck stub");
        let addr = listener.local_addr().expect("stub addr");
        // Non-blocking accept lets the thread enforce its own deadline instead of hanging.
        listener
            .set_nonblocking(true)
            .expect("set listener nonblocking");

        let handle = std::thread::spawn(move || {
            let deadline = std::time::Instant::now() + std::time::Duration::from_secs(3);
            let (mut stream, _) = loop {
                match listener.accept() {
                    Ok(pair) => break pair,
                    Err(error) if error.kind() == std::io::ErrorKind::WouldBlock => {
                        if std::time::Instant::now() >= deadline {
                            panic!("healthcheck stub: no client within 3s");
                        }
                        std::thread::sleep(std::time::Duration::from_millis(10));
                    }
                    Err(error) => panic!("healthcheck stub accept failed: {error}"),
                }
            };
            // Accepted sockets inherit O_NONBLOCK from the listener on BSD/macOS
            // (not on Linux). Switch back to blocking so the read below waits for
            // data under the timeouts instead of failing with WouldBlock.
            stream
                .set_nonblocking(false)
                .expect("set stream blocking");
            stream
                .set_read_timeout(Some(std::time::Duration::from_secs(3)))
                .ok();
            stream
                .set_write_timeout(Some(std::time::Duration::from_secs(3)))
                .ok();

            let mut bytes = Vec::new();
            let mut buf = [0u8; 4096];
            loop {
                let n = stream.read(&mut buf).expect("read request");
                if n == 0 {
                    break;
                }
                bytes.extend_from_slice(&buf[..n]);
                let request = String::from_utf8_lossy(&bytes);
                if request_complete(&request) {
                    break;
                }
            }
            *captured.lock().expect("capture request") =
                Some(String::from_utf8_lossy(&bytes).to_string());

            let response = format!(
                "HTTP/1.1 {status} OK\r\ncontent-type: application/json\r\ncontent-length: {}\r\nconnection: close\r\n\r\n{body}",
                body.len()
            );
            stream
                .write_all(response.as_bytes())
                .expect("write response");
        });

        (format!("http://{addr}"), handle)
    }

    /// True once the headers are terminated and at least `Content-Length` body
    /// bytes have arrived.
    ///
    /// The header name is matched case-insensitively, as header names are in
    /// HTTP. A missing or unparsable length counts as 0.
    fn request_complete(request: &str) -> bool {
        let Some((headers, body)) = request.split_once("\r\n\r\n") else {
            return false;
        };
        let content_length = headers
            .lines()
            .filter_map(|line| line.split_once(':'))
            .find(|(name, _)| name.trim().eq_ignore_ascii_case("content-length"))
            .and_then(|(_, value)| value.trim().parse::<usize>().ok())
            .unwrap_or(0);
        body.len() >= content_length
    }

    // Each async test holds `env_lock` across its awaits on purpose, so the
    // global provider overrides and env vars it mutates are not touched
    // concurrently by other tests that take the same lock. The single-threaded
    // runtime means nothing else on this executor contends for it.
    #[tokio::test(flavor = "current_thread")]
    #[allow(clippy::await_holding_lock)]
    async fn sends_configured_probe_request_with_candidate_key() {
        let _guard = crate::llm::env_lock().lock().expect("env lock");
        let captured = Arc::new(Mutex::new(None));
        let (base_url, server) = spawn_healthcheck_stub(200, r#"{"ok":true}"#, captured.clone());
        install_provider(
            "acme",
            provider_with_healthcheck(
                base_url.clone(),
                HealthcheckDef {
                    method: "POST".to_string(),
                    path: Some("probe".to_string()),
                    url: None,
                    body: Some(r#"{"ping":true}"#.to_string()),
                },
            ),
        );

        let result = run_provider_healthcheck_with_options(
            "acme",
            ProviderHealthcheckOptions {
                api_key: Some("candidate-key".to_string()),
                client: None,
            },
        )
        .await;
        server.join().expect("stub server");
        llm_config::clear_user_overrides();

        assert!(result.valid);
        assert_eq!(result.provider, "acme");
        assert_eq!(result.metadata["status"], json!(200));
        assert_eq!(result.metadata["method"], json!("POST"));
        assert_eq!(result.metadata["url"], json!(format!("{base_url}/probe")));

        let request = captured
            .lock()
            .expect("captured request")
            .clone()
            .expect("request");
        assert!(request.starts_with("POST /probe HTTP/1.1\r\n"));
        assert!(request.contains("authorization: Bearer candidate-key\r\n"));
        assert!(request.contains("x-extra: extra-value\r\n"));
        assert!(request.ends_with(r#"{"ping":true}"#));
    }

    #[tokio::test(flavor = "current_thread")]
    #[allow(clippy::await_holding_lock)]
    async fn reports_missing_credentials_without_network() {
        let _guard = crate::llm::env_lock().lock().expect("env lock");
        // SAFETY: mutating the process environment is only sound without
        // concurrent env access. This test holds `env_lock`, which is presumably
        // taken by every test that touches env vars.
        unsafe {
            std::env::remove_var("HARN_TEST_PROVIDER_KEY");
        }
        // Port 9 (discard) is never contacted: the credentials check fails first.
        install_provider(
            "acme-missing-key",
            provider_with_healthcheck(
                "http://127.0.0.1:9".to_string(),
                HealthcheckDef {
                    method: "GET".to_string(),
                    path: Some("/models".to_string()),
                    url: None,
                    body: None,
                },
            ),
        );

        let result = run_provider_healthcheck("acme-missing-key").await;
        llm_config::clear_user_overrides();

        assert!(!result.valid);
        assert_eq!(result.metadata["reason"], json!("missing_credentials"));
        assert_eq!(
            result.metadata["auth_env"],
            json!(["HARN_TEST_PROVIDER_KEY"])
        );
        assert!(result.message.contains("Missing credentials"));
    }

    #[tokio::test(flavor = "current_thread")]
    #[allow(clippy::await_holding_lock)]
    async fn returns_stable_failure_shape_for_http_errors() {
        let _guard = crate::llm::env_lock().lock().expect("env lock");
        let captured = Arc::new(Mutex::new(None));
        let (base_url, server) = spawn_healthcheck_stub(401, r#"{"error":"bad key"}"#, captured);
        install_provider(
            "acme-auth",
            provider_with_healthcheck(
                base_url,
                HealthcheckDef {
                    method: "GET".to_string(),
                    path: Some("/models".to_string()),
                    url: None,
                    body: None,
                },
            ),
        );

        let result = run_provider_healthcheck_with_options(
            "acme-auth",
            ProviderHealthcheckOptions {
                api_key: Some("bad-key".to_string()),
                client: None,
            },
        )
        .await;
        server.join().expect("stub server");
        llm_config::clear_user_overrides();

        assert!(!result.valid);
        assert_eq!(result.provider, "acme-auth");
        assert_eq!(result.metadata["reason"], json!("http_status"));
        assert_eq!(result.metadata["status"], json!(401));
        assert_eq!(result.metadata["body"], json!(r#"{"error":"bad key"}"#));
    }

    /// Every bundled external provider must ship a usable healthcheck definition.
    #[test]
    fn default_external_provider_catalog_has_healthchecks() {
        for provider in [
            "openrouter",
            "anthropic",
            "openai",
            "huggingface",
            "together",
        ] {
            let config = llm_config::provider_config(provider)
                .unwrap_or_else(|| panic!("missing provider {provider}"));
            let healthcheck = config
                .healthcheck
                .as_ref()
                .unwrap_or_else(|| panic!("missing healthcheck for {provider}"));
            assert!(!healthcheck.method.is_empty());
            assert!(healthcheck.path.is_some() || healthcheck.url.is_some());
        }
    }
}