1use std::collections::BTreeMap;
2use std::time::Duration;
3
4use reqwest::Method;
5use serde::{Deserialize, Serialize};
6use serde_json::{json, Value as JsonValue};
7
8use crate::llm_config::{self, AuthEnv, HealthcheckDef, ProviderDef};
9
10use super::api::apply_auth_headers;
11
/// Request timeout, in seconds, of the HTTP client built when the caller does
/// not supply one.
const DEFAULT_HEALTHCHECK_TIMEOUT_SECS: u64 = 5;
/// Maximum number of characters (not bytes) of a response body kept in
/// failure messages and metadata.
const BODY_SNIPPET_LIMIT: usize = 1000;
14
/// Optional overrides for a single provider healthcheck run.
///
/// `Default` leaves both fields `None`: credentials come from the provider's
/// configured environment variables and a fresh HTTP client is built.
#[derive(Debug, Clone, Default)]
pub struct ProviderHealthcheckOptions {
    /// Explicit API key. When non-blank it takes precedence over any key
    /// found in the provider's auth environment variables.
    pub api_key: Option<String>,
    /// HTTP client to send the probe with. When `None`, a client with a
    /// `DEFAULT_HEALTHCHECK_TIMEOUT_SECS` timeout is built for the call.
    pub client: Option<reqwest::Client>,
}
23
/// Outcome of a provider healthcheck, serializable for reporting.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderHealthcheckResult {
    /// Provider name the check was run for (trimmed; `"anthropic"` when the
    /// requested name was blank).
    pub provider: String,
    /// `true` only when the probe request completed with a success HTTP status.
    pub valid: bool,
    /// Human-readable summary of the outcome.
    pub message: String,
    /// Structured details. Every result carries `"reason"` and `"provider"`;
    /// other keys (`"status"`, `"url"`, `"method"`, `"body"`, `"auth_env"`)
    /// depend on how far the check got.
    pub metadata: BTreeMap<String, JsonValue>,
}
31
32impl ProviderHealthcheckResult {
33 fn new(
34 provider: impl Into<String>,
35 valid: bool,
36 message: impl Into<String>,
37 metadata: BTreeMap<String, JsonValue>,
38 ) -> Self {
39 Self {
40 provider: provider.into(),
41 valid,
42 message: message.into(),
43 metadata,
44 }
45 }
46}
47
48pub async fn run_provider_healthcheck(provider: &str) -> ProviderHealthcheckResult {
49 run_provider_healthcheck_with_options(provider, ProviderHealthcheckOptions::default()).await
50}
51
52pub async fn run_provider_healthcheck_with_options(
53 provider: &str,
54 options: ProviderHealthcheckOptions,
55) -> ProviderHealthcheckResult {
56 let provider = if provider.trim().is_empty() {
57 "anthropic"
58 } else {
59 provider.trim()
60 };
61
62 let Some(def) = llm_config::provider_config(provider) else {
63 let mut metadata = base_metadata("unknown_provider");
64 metadata.insert("provider".to_string(), json!(provider));
65 return ProviderHealthcheckResult::new(
66 provider,
67 false,
68 format!("Unknown provider: {provider}"),
69 metadata,
70 );
71 };
72
73 let Some(healthcheck) = def.healthcheck.as_ref() else {
74 let mut metadata = base_metadata("no_healthcheck");
75 metadata.insert("provider".to_string(), json!(provider));
76 return ProviderHealthcheckResult::new(
77 provider,
78 false,
79 format!("No healthcheck configured for {provider}"),
80 metadata,
81 );
82 };
83
84 let auth = resolve_healthcheck_auth(&def, options.api_key);
85 if auth.requires_auth && auth.api_key.is_none() {
86 let mut metadata = base_metadata("missing_credentials");
87 metadata.insert("provider".to_string(), json!(provider));
88 metadata.insert("auth_env".to_string(), json!(auth.candidates));
89 return ProviderHealthcheckResult::new(
90 provider,
91 false,
92 format!(
93 "Missing credentials for {provider}: set {} or pass an api_key",
94 auth.candidates.join(", ")
95 ),
96 metadata,
97 );
98 }
99
100 let url = build_healthcheck_url(&def, healthcheck);
101 let method = Method::from_bytes(healthcheck.method.as_bytes()).unwrap_or(Method::GET);
102 let client = match options.client {
103 Some(client) => client,
104 None => match reqwest::Client::builder()
105 .timeout(Duration::from_secs(DEFAULT_HEALTHCHECK_TIMEOUT_SECS))
106 .build()
107 {
108 Ok(client) => client,
109 Err(error) => {
110 let mut metadata = base_metadata("client_build_failed");
111 metadata.insert("provider".to_string(), json!(provider));
112 return ProviderHealthcheckResult::new(
113 provider,
114 false,
115 format!("{provider} healthcheck failed: {error}"),
116 metadata,
117 );
118 }
119 },
120 };
121
122 let mut request = client.request(method.clone(), &url);
123 if let Some(api_key) = auth.api_key.as_deref() {
124 request = apply_auth_headers(request, api_key, Some(&def));
125 }
126 for (name, value) in &def.extra_headers {
127 request = request.header(name, value);
128 }
129 if let Some(body) = &healthcheck.body {
130 request = request
131 .header(reqwest::header::CONTENT_TYPE, "application/json")
132 .body(body.clone());
133 }
134
135 match request.send().await {
136 Ok(response) => {
137 let status = response.status();
138 let status_code = status.as_u16();
139 let valid = status.is_success();
140 let body_text = response.text().await.unwrap_or_default();
141 let mut metadata = base_metadata(if valid { "ok" } else { "http_status" });
142 metadata.insert("provider".to_string(), json!(provider));
143 metadata.insert("status".to_string(), json!(status_code));
144 metadata.insert("url".to_string(), json!(url));
145 metadata.insert("method".to_string(), json!(method.as_str()));
146 if !valid && !body_text.is_empty() {
147 metadata.insert("body".to_string(), json!(body_snippet(&body_text)));
148 }
149
150 let message = if valid {
151 format!("{provider} is reachable (HTTP {status_code})")
152 } else {
153 let suffix = body_snippet(&body_text);
154 if suffix.is_empty() {
155 format!("{provider} returned HTTP {status_code}")
156 } else {
157 format!("{provider} returned HTTP {status_code}: {suffix}")
158 }
159 };
160
161 ProviderHealthcheckResult::new(provider, valid, message, metadata)
162 }
163 Err(error) => {
164 let mut metadata = base_metadata("request_failed");
165 metadata.insert("provider".to_string(), json!(provider));
166 metadata.insert("url".to_string(), json!(url));
167 metadata.insert("method".to_string(), json!(method.as_str()));
168 ProviderHealthcheckResult::new(
169 provider,
170 false,
171 format!("{provider} healthcheck failed: {error}"),
172 metadata,
173 )
174 }
175 }
176}
177
178pub fn build_healthcheck_url(def: &ProviderDef, healthcheck: &HealthcheckDef) -> String {
179 if let Some(url) = &healthcheck.url {
180 return url.clone();
181 }
182
183 let base = llm_config::resolve_base_url(def);
184 let path = healthcheck.path.as_deref().unwrap_or("");
185 if path.starts_with('/') {
186 format!("{}{}", base.trim_end_matches('/'), path)
187 } else if path.is_empty() {
188 base
189 } else {
190 format!("{}/{}", base.trim_end_matches('/'), path)
191 }
192}
193
/// Credentials resolved for a healthcheck probe.
#[derive(Debug, Clone)]
struct ResolvedHealthcheckAuth {
    /// Whether the probe must carry credentials; when `true` and `api_key`
    /// is `None` the check is reported as missing credentials.
    requires_auth: bool,
    /// Trimmed, non-empty API key from the caller override or the environment.
    api_key: Option<String>,
    /// Environment variable names the provider declares for its key; used in
    /// the missing-credentials message and metadata.
    candidates: Vec<String>,
}
200
201fn resolve_healthcheck_auth(
202 def: &ProviderDef,
203 api_key_override: Option<String>,
204) -> ResolvedHealthcheckAuth {
205 let candidates = auth_env_candidates(&def.auth_env);
206 if def.auth_style == "none" || matches!(def.auth_env, AuthEnv::None) {
207 let api_key = api_key_override.and_then(non_empty);
208 return ResolvedHealthcheckAuth {
209 requires_auth: api_key.is_some(),
210 api_key,
211 candidates,
212 };
213 }
214
215 let api_key = api_key_override
216 .and_then(non_empty)
217 .or_else(|| resolve_api_key_from_env(&def.auth_env));
218 ResolvedHealthcheckAuth {
219 requires_auth: true,
220 api_key,
221 candidates,
222 }
223}
224
225fn auth_env_candidates(auth_env: &AuthEnv) -> Vec<String> {
226 match auth_env {
227 AuthEnv::None => Vec::new(),
228 AuthEnv::Single(env) => vec![env.clone()],
229 AuthEnv::Multiple(envs) => envs.clone(),
230 }
231}
232
233fn resolve_api_key_from_env(auth_env: &AuthEnv) -> Option<String> {
234 match auth_env {
235 AuthEnv::None => None,
236 AuthEnv::Single(env) => std::env::var(env).ok().and_then(non_empty),
237 AuthEnv::Multiple(envs) => envs
238 .iter()
239 .find_map(|env| std::env::var(env).ok().and_then(non_empty)),
240 }
241}
242
/// Trims `value` and returns it as an owned string, or `None` when nothing
/// but whitespace remains.
fn non_empty(value: String) -> Option<String> {
    match value.trim() {
        "" => None,
        text => Some(text.to_owned()),
    }
}
251
252fn base_metadata(reason: &str) -> BTreeMap<String, JsonValue> {
253 BTreeMap::from([("reason".to_string(), json!(reason))])
254}
255
256fn body_snippet(body: &str) -> String {
257 let mut snippet = String::new();
258 for ch in body.chars().take(BODY_SNIPPET_LIMIT) {
259 snippet.push(ch);
260 }
261 snippet
262}
263
264#[cfg(test)]
265mod tests {
266 use std::io::{Read, Write};
267 use std::net::TcpListener;
268 use std::sync::{Arc, Mutex};
269
270 use super::*;
271
272 fn provider_with_healthcheck(base_url: String, healthcheck: HealthcheckDef) -> ProviderDef {
273 ProviderDef {
274 base_url,
275 auth_style: "bearer".to_string(),
276 auth_env: AuthEnv::Single("HARN_TEST_PROVIDER_KEY".to_string()),
277 extra_headers: BTreeMap::from([("x-extra".to_string(), "extra-value".to_string())]),
278 chat_endpoint: "/chat/completions".to_string(),
279 healthcheck: Some(healthcheck),
280 ..Default::default()
281 }
282 }
283
284 fn install_provider(name: &str, provider: ProviderDef) {
285 let mut config = llm_config::ProvidersConfig::default();
286 config.providers.insert(name.to_string(), provider);
287 llm_config::set_user_overrides(Some(config));
288 }
289
    /// Starts a one-shot HTTP stub on an ephemeral localhost port.
    ///
    /// The background thread accepts a single connection, stores the raw
    /// request text in `captured`, and replies with `status` and `body` as a
    /// JSON response. Returns the stub's base URL and the thread handle; the
    /// caller joins the handle to surface any panic raised inside the stub.
    fn spawn_healthcheck_stub(
        status: u16,
        body: &'static str,
        captured: Arc<Mutex<Option<String>>>,
    ) -> (String, std::thread::JoinHandle<()>) {
        let listener = TcpListener::bind("127.0.0.1:0").expect("bind healthcheck stub");
        let addr = listener.local_addr().expect("stub addr");
        listener
            .set_nonblocking(true)
            .expect("set listener nonblocking");

        let handle = std::thread::spawn(move || {
            // Poll the non-blocking listener against a deadline so a test that
            // never connects makes the stub panic instead of blocking forever.
            let deadline = std::time::Instant::now() + std::time::Duration::from_secs(30);
            let (mut stream, _) = loop {
                match listener.accept() {
                    Ok(pair) => break pair,
                    Err(error) if error.kind() == std::io::ErrorKind::WouldBlock => {
                        if std::time::Instant::now() >= deadline {
                            panic!("healthcheck stub: no client within 30s");
                        }
                        std::thread::sleep(std::time::Duration::from_millis(10));
                    }
                    Err(error) => panic!("healthcheck stub accept failed: {error}"),
                }
            };
            // Switch the accepted stream to blocking mode with timeouts for
            // the read/write phase; failures to set the timeouts are ignored.
            stream
                .set_nonblocking(false)
                .expect("set accepted stream blocking");
            stream
                .set_read_timeout(Some(std::time::Duration::from_secs(30)))
                .ok();
            stream
                .set_write_timeout(Some(std::time::Duration::from_secs(30)))
                .ok();

            // Read until the headers and the declared body length have
            // arrived, or the peer closes the connection.
            let mut bytes = Vec::new();
            let mut buf = [0u8; 4096];
            loop {
                let n = stream.read(&mut buf).expect("read request");
                if n == 0 {
                    break;
                }
                bytes.extend_from_slice(&buf[..n]);
                let request = String::from_utf8_lossy(&bytes);
                if request_complete(&request) {
                    break;
                }
            }
            *captured.lock().expect("capture request") =
                Some(String::from_utf8_lossy(&bytes).to_string());

            // The reason phrase is always "OK", whatever `status` is.
            let response = format!(
                "HTTP/1.1 {status} OK\r\ncontent-type: application/json\r\ncontent-length: {}\r\nconnection: close\r\n\r\n{body}",
                body.len()
            );
            stream
                .write_all(response.as_bytes())
                .expect("write response");
        });

        (format!("http://{addr}"), handle)
    }
357
358 fn request_complete(request: &str) -> bool {
359 let Some((headers, body)) = request.split_once("\r\n\r\n") else {
360 return false;
361 };
362 let content_length = headers
363 .lines()
364 .find_map(|line| line.strip_prefix("content-length: "))
365 .or_else(|| {
366 headers
367 .lines()
368 .find_map(|line| line.strip_prefix("Content-Length: "))
369 })
370 .and_then(|value| value.trim().parse::<usize>().ok())
371 .unwrap_or(0);
372 body.len() >= content_length
373 }
374
    /// A successful probe uses the configured method, path, body and extra
    /// headers, and authenticates with the key passed in the options.
    ///
    /// The path is given without a leading slash, so the asserted URL also
    /// covers the separator inserted by `build_healthcheck_url`.
    #[tokio::test(flavor = "current_thread")]
    // The env lock guard is held across `.await` on purpose, presumably to
    // serialize tests that install process-wide provider overrides — confirm
    // against `env_lock`'s documentation.
    #[allow(clippy::await_holding_lock)]
    async fn sends_configured_probe_request_with_candidate_key() {
        let _guard = crate::llm::env_lock().lock().expect("env lock");
        let captured = Arc::new(Mutex::new(None));
        let (base_url, server) = spawn_healthcheck_stub(200, r#"{"ok":true}"#, captured.clone());
        install_provider(
            "acme",
            provider_with_healthcheck(
                base_url.clone(),
                HealthcheckDef {
                    method: "POST".to_string(),
                    path: Some("probe".to_string()),
                    url: None,
                    body: Some(r#"{"ping":true}"#.to_string()),
                },
            ),
        );

        let result = run_provider_healthcheck_with_options(
            "acme",
            ProviderHealthcheckOptions {
                api_key: Some("candidate-key".to_string()),
                client: None,
            },
        )
        .await;
        // Join the stub first so a panic inside it fails the test, then drop
        // the overrides before asserting.
        server.join().expect("stub server");
        llm_config::clear_user_overrides();

        assert!(result.valid);
        assert_eq!(result.provider, "acme");
        assert_eq!(result.metadata["status"], json!(200));
        assert_eq!(result.metadata["method"], json!("POST"));
        assert_eq!(result.metadata["url"], json!(format!("{base_url}/probe")));

        // Inspect the raw request the stub received.
        let request = captured
            .lock()
            .expect("captured request")
            .clone()
            .expect("request");
        assert!(request.starts_with("POST /probe HTTP/1.1\r\n"));
        assert!(request.contains("authorization: Bearer candidate-key\r\n"));
        assert!(request.contains("x-extra: extra-value\r\n"));
        assert!(request.ends_with(r#"{"ping":true}"#));
    }
421
    /// With no key in the options or the environment, the check fails with
    /// `missing_credentials` and names the environment variable to set.
    ///
    /// No stub server is started; the base URL is a placeholder that the
    /// check is expected not to contact.
    #[tokio::test(flavor = "current_thread")]
    // The env lock guard is held across `.await` on purpose, presumably to
    // serialize tests that touch the process environment and provider
    // overrides — confirm against `env_lock`'s documentation.
    #[allow(clippy::await_holding_lock)]
    async fn reports_missing_credentials_without_network() {
        let _guard = crate::llm::env_lock().lock().expect("env lock");
        // SAFETY: NOTE(review) — relies on `env_lock` being held by every test
        // that reads or writes the process environment, so no other thread
        // accesses it concurrently; confirm this holds crate-wide.
        unsafe {
            std::env::remove_var("HARN_TEST_PROVIDER_KEY");
        }
        install_provider(
            "acme-missing-key",
            provider_with_healthcheck(
                "http://127.0.0.1:9".to_string(),
                HealthcheckDef {
                    method: "GET".to_string(),
                    path: Some("/models".to_string()),
                    url: None,
                    body: None,
                },
            ),
        );

        let result = run_provider_healthcheck("acme-missing-key").await;
        llm_config::clear_user_overrides();

        assert!(!result.valid);
        assert_eq!(result.metadata["reason"], json!("missing_credentials"));
        assert_eq!(
            result.metadata["auth_env"],
            json!(["HARN_TEST_PROVIDER_KEY"])
        );
        assert!(result.message.contains("Missing credentials"));
    }
453
    /// A non-success HTTP response yields an invalid result whose metadata
    /// carries the `http_status` reason, the status code and the response body.
    #[tokio::test(flavor = "current_thread")]
    // The env lock guard is held across `.await` on purpose, presumably to
    // serialize tests that install process-wide provider overrides — confirm
    // against `env_lock`'s documentation.
    #[allow(clippy::await_holding_lock)]
    async fn returns_stable_failure_shape_for_http_errors() {
        let _guard = crate::llm::env_lock().lock().expect("env lock");
        let captured = Arc::new(Mutex::new(None));
        let (base_url, server) = spawn_healthcheck_stub(401, r#"{"error":"bad key"}"#, captured);
        install_provider(
            "acme-auth",
            provider_with_healthcheck(
                base_url,
                HealthcheckDef {
                    method: "GET".to_string(),
                    path: Some("/models".to_string()),
                    url: None,
                    body: None,
                },
            ),
        );

        let result = run_provider_healthcheck_with_options(
            "acme-auth",
            ProviderHealthcheckOptions {
                api_key: Some("bad-key".to_string()),
                client: None,
            },
        )
        .await;
        // Join the stub first so a panic inside it fails the test, then drop
        // the overrides before asserting.
        server.join().expect("stub server");
        llm_config::clear_user_overrides();

        assert!(!result.valid);
        assert_eq!(result.provider, "acme-auth");
        assert_eq!(result.metadata["reason"], json!("http_status"));
        assert_eq!(result.metadata["status"], json!(401));
        assert_eq!(result.metadata["body"], json!(r#"{"error":"bad key"}"#));
    }
490
491 #[test]
492 fn default_external_provider_catalog_has_healthchecks() {
493 for provider in [
494 "openrouter",
495 "anthropic",
496 "openai",
497 "huggingface",
498 "together",
499 ] {
500 let config = llm_config::provider_config(provider)
501 .unwrap_or_else(|| panic!("missing provider {provider}"));
502 let healthcheck = config
503 .healthcheck
504 .as_ref()
505 .unwrap_or_else(|| panic!("missing healthcheck for {provider}"));
506 assert!(!healthcheck.method.is_empty());
507 assert!(healthcheck.path.is_some() || healthcheck.url.is_some());
508 }
509 }
510}