1use std::collections::BTreeMap;
2use std::time::Duration;
3
4use reqwest::Method;
5use serde::{Deserialize, Serialize};
6use serde_json::{json, Value as JsonValue};
7
8use crate::llm_config::{self, AuthEnv, HealthcheckDef, ProviderDef};
9
10use super::api::apply_auth_headers;
11
/// Request timeout, in seconds, applied to the HTTP client that is built when
/// the caller does not supply one through `ProviderHealthcheckOptions::client`.
const DEFAULT_HEALTHCHECK_TIMEOUT_SECS: u64 = 5;
/// Maximum number of characters of a response body that `body_snippet` keeps.
const BODY_SNIPPET_LIMIT: usize = 1000;
14
/// Optional overrides for `run_provider_healthcheck_with_options`.
///
/// `Default` leaves both fields as `None`.
#[derive(Debug, Clone, Default)]
pub struct ProviderHealthcheckOptions {
    /// API key to probe with. When present and non-blank it is used in
    /// preference to any key found in the provider's auth environment variables.
    pub api_key: Option<String>,
    /// HTTP client used to send the probe. When `None`, a client with a
    /// `DEFAULT_HEALTHCHECK_TIMEOUT_SECS` timeout is built for the call.
    pub client: Option<reqwest::Client>,
}
23
/// Outcome of a single provider healthcheck.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ProviderHealthcheckResult {
    /// Provider name the check ran against (after trimming / defaulting).
    pub provider: String,
    /// `true` only when the probe received a successful (2xx) HTTP status.
    pub valid: bool,
    /// Human-readable summary of the outcome.
    pub message: String,
    /// Structured details. Always carries `reason` and `provider`; depending on
    /// the outcome it may also carry `status`, `url`, `method`, `body`, or
    /// `auth_env`.
    pub metadata: BTreeMap<String, JsonValue>,
}
31
32impl ProviderHealthcheckResult {
33 fn new(
34 provider: impl Into<String>,
35 valid: bool,
36 message: impl Into<String>,
37 metadata: BTreeMap<String, JsonValue>,
38 ) -> Self {
39 Self {
40 provider: provider.into(),
41 valid,
42 message: message.into(),
43 metadata,
44 }
45 }
46}
47
48pub async fn run_provider_healthcheck(provider: &str) -> ProviderHealthcheckResult {
49 run_provider_healthcheck_with_options(provider, ProviderHealthcheckOptions::default()).await
50}
51
52pub async fn run_provider_healthcheck_with_options(
53 provider: &str,
54 options: ProviderHealthcheckOptions,
55) -> ProviderHealthcheckResult {
56 let provider = if provider.trim().is_empty() {
57 "anthropic"
58 } else {
59 provider.trim()
60 };
61
62 let Some(def) = llm_config::provider_config(provider) else {
63 let mut metadata = base_metadata("unknown_provider");
64 metadata.insert("provider".to_string(), json!(provider));
65 return ProviderHealthcheckResult::new(
66 provider,
67 false,
68 format!("Unknown provider: {provider}"),
69 metadata,
70 );
71 };
72
73 let Some(healthcheck) = def.healthcheck.as_ref() else {
74 let mut metadata = base_metadata("no_healthcheck");
75 metadata.insert("provider".to_string(), json!(provider));
76 return ProviderHealthcheckResult::new(
77 provider,
78 false,
79 format!("No healthcheck configured for {provider}"),
80 metadata,
81 );
82 };
83
84 let auth = resolve_healthcheck_auth(&def, options.api_key);
85 if auth.requires_auth && auth.api_key.is_none() {
86 let mut metadata = base_metadata("missing_credentials");
87 metadata.insert("provider".to_string(), json!(provider));
88 metadata.insert("auth_env".to_string(), json!(auth.candidates));
89 return ProviderHealthcheckResult::new(
90 provider,
91 false,
92 format!(
93 "Missing credentials for {provider}: set {} or pass an api_key",
94 auth.candidates.join(", ")
95 ),
96 metadata,
97 );
98 }
99
100 let url = build_healthcheck_url(&def, healthcheck);
101 let method = Method::from_bytes(healthcheck.method.as_bytes()).unwrap_or(Method::GET);
102 let client = match options.client {
103 Some(client) => client,
104 None => match reqwest::Client::builder()
105 .timeout(Duration::from_secs(DEFAULT_HEALTHCHECK_TIMEOUT_SECS))
106 .build()
107 {
108 Ok(client) => client,
109 Err(error) => {
110 let mut metadata = base_metadata("client_build_failed");
111 metadata.insert("provider".to_string(), json!(provider));
112 return ProviderHealthcheckResult::new(
113 provider,
114 false,
115 format!("{provider} healthcheck failed: {error}"),
116 metadata,
117 );
118 }
119 },
120 };
121
122 let mut request = client.request(method.clone(), &url);
123 if let Some(api_key) = auth.api_key.as_deref() {
124 request = apply_auth_headers(request, api_key, Some(&def));
125 }
126 for (name, value) in &def.extra_headers {
127 request = request.header(name, value);
128 }
129 if let Some(body) = &healthcheck.body {
130 request = request
131 .header(reqwest::header::CONTENT_TYPE, "application/json")
132 .body(body.clone());
133 }
134
135 match request.send().await {
136 Ok(response) => {
137 let status = response.status();
138 let status_code = status.as_u16();
139 let valid = status.is_success();
140 let body_text = response.text().await.unwrap_or_default();
141 let mut metadata = base_metadata(if valid { "ok" } else { "http_status" });
142 metadata.insert("provider".to_string(), json!(provider));
143 metadata.insert("status".to_string(), json!(status_code));
144 metadata.insert("url".to_string(), json!(url));
145 metadata.insert("method".to_string(), json!(method.as_str()));
146 if !valid && !body_text.is_empty() {
147 metadata.insert("body".to_string(), json!(body_snippet(&body_text)));
148 }
149
150 let message = if valid {
151 format!("{provider} is reachable (HTTP {status_code})")
152 } else {
153 let suffix = body_snippet(&body_text);
154 if suffix.is_empty() {
155 format!("{provider} returned HTTP {status_code}")
156 } else {
157 format!("{provider} returned HTTP {status_code}: {suffix}")
158 }
159 };
160
161 ProviderHealthcheckResult::new(provider, valid, message, metadata)
162 }
163 Err(error) => {
164 let mut metadata = base_metadata("request_failed");
165 metadata.insert("provider".to_string(), json!(provider));
166 metadata.insert("url".to_string(), json!(url));
167 metadata.insert("method".to_string(), json!(method.as_str()));
168 ProviderHealthcheckResult::new(
169 provider,
170 false,
171 format!("{provider} healthcheck failed: {error}"),
172 metadata,
173 )
174 }
175 }
176}
177
178pub fn build_healthcheck_url(def: &ProviderDef, healthcheck: &HealthcheckDef) -> String {
179 if let Some(url) = &healthcheck.url {
180 return url.clone();
181 }
182
183 let base = llm_config::resolve_base_url(def);
184 let path = healthcheck.path.as_deref().unwrap_or("");
185 if path.starts_with('/') {
186 format!("{}{}", base.trim_end_matches('/'), path)
187 } else if path.is_empty() {
188 base
189 } else {
190 format!("{}/{}", base.trim_end_matches('/'), path)
191 }
192}
193
/// Credential decision produced by `resolve_healthcheck_auth` for one probe.
#[derive(Debug, Clone)]
struct ResolvedHealthcheckAuth {
    /// Whether the probe must carry credentials; the caller reports
    /// `missing_credentials` when this is `true` and `api_key` is `None`.
    requires_auth: bool,
    /// The key to send, if one was supplied or found.
    api_key: Option<String>,
    /// Environment variable names the provider reads its key from.
    candidates: Vec<String>,
}
200
201fn resolve_healthcheck_auth(
202 def: &ProviderDef,
203 api_key_override: Option<String>,
204) -> ResolvedHealthcheckAuth {
205 let candidates = auth_env_candidates(&def.auth_env);
206 if def.auth_style == "none" || matches!(def.auth_env, AuthEnv::None) {
207 let api_key = api_key_override.and_then(non_empty);
208 return ResolvedHealthcheckAuth {
209 requires_auth: api_key.is_some(),
210 api_key,
211 candidates,
212 };
213 }
214
215 let api_key = api_key_override
216 .and_then(non_empty)
217 .or_else(|| resolve_api_key_from_env(&def.auth_env));
218 ResolvedHealthcheckAuth {
219 requires_auth: true,
220 api_key,
221 candidates,
222 }
223}
224
225fn auth_env_candidates(auth_env: &AuthEnv) -> Vec<String> {
226 match auth_env {
227 AuthEnv::None => Vec::new(),
228 AuthEnv::Single(env) => vec![env.clone()],
229 AuthEnv::Multiple(envs) => envs.clone(),
230 }
231}
232
233fn resolve_api_key_from_env(auth_env: &AuthEnv) -> Option<String> {
234 match auth_env {
235 AuthEnv::None => None,
236 AuthEnv::Single(env) => std::env::var(env).ok().and_then(non_empty),
237 AuthEnv::Multiple(envs) => envs
238 .iter()
239 .find_map(|env| std::env::var(env).ok().and_then(non_empty)),
240 }
241}
242
/// Trims `value` and returns it, or `None` when nothing but whitespace remains.
fn non_empty(value: String) -> Option<String> {
    match value.trim() {
        "" => None,
        kept => Some(kept.to_owned()),
    }
}
251
252fn base_metadata(reason: &str) -> BTreeMap<String, JsonValue> {
253 BTreeMap::from([("reason".to_string(), json!(reason))])
254}
255
256fn body_snippet(body: &str) -> String {
257 let mut snippet = String::new();
258 for ch in body.chars().take(BODY_SNIPPET_LIMIT) {
259 snippet.push(ch);
260 }
261 snippet
262}
263
/// Tests for the provider healthcheck, driven against a one-shot local TCP
/// stub instead of a real provider endpoint.
#[cfg(test)]
mod tests {
    use crate::http::framing::{http_content_length_from_header_lines, TEST_HTTP_MAX_BODY_BYTES};
    use std::io::{Read, Write};
    use std::net::TcpListener;
    use std::sync::{Arc, Mutex};

    use super::*;

    /// Builds a bearer-auth provider definition that reads its key from
    /// `HARN_TEST_PROVIDER_KEY`, sends one extra header, and uses the given
    /// base URL and healthcheck; all remaining fields take their defaults.
    fn provider_with_healthcheck(base_url: String, healthcheck: HealthcheckDef) -> ProviderDef {
        ProviderDef {
            base_url,
            auth_style: "bearer".to_string(),
            auth_env: AuthEnv::Single("HARN_TEST_PROVIDER_KEY".to_string()),
            extra_headers: BTreeMap::from([("x-extra".to_string(), "extra-value".to_string())]),
            chat_endpoint: "/chat/completions".to_string(),
            healthcheck: Some(healthcheck),
            ..Default::default()
        }
    }

    /// Installs `provider` under `name` as the only entry of a fresh user
    /// override config. Callers undo this with `llm_config::clear_user_overrides`.
    fn install_provider(name: &str, provider: ProviderDef) {
        let mut config = llm_config::ProvidersConfig::default();
        config.providers.insert(name.to_string(), provider);
        llm_config::set_user_overrides(Some(config));
    }

    /// Starts a single-request HTTP stub on an OS-assigned loopback port.
    ///
    /// The spawned thread accepts one connection, reads until
    /// `request_complete` reports a full request (or the peer closes), stores
    /// the raw request text in `captured`, and replies with `status` and
    /// `body`. Returns the stub's base URL and the thread handle to join.
    fn spawn_healthcheck_stub(
        status: u16,
        body: &'static str,
        captured: Arc<Mutex<Option<String>>>,
    ) -> (String, std::thread::JoinHandle<()>) {
        // Port 0 lets the OS choose a free port; the real address is read back below.
        let listener = TcpListener::bind("127.0.0.1:0").expect("bind healthcheck stub");
        let addr = listener.local_addr().expect("stub addr");

        let handle = std::thread::spawn(move || {
            let (mut stream, _) = listener
                .accept()
                .unwrap_or_else(|e| panic!("healthcheck stub accept failed: {e}"));
            // Timeouts bound how long a stalled read/write can block the stub
            // thread; failure to set them is ignored.
            stream
                .set_read_timeout(Some(std::time::Duration::from_secs(30)))
                .ok();
            stream
                .set_write_timeout(Some(std::time::Duration::from_secs(30)))
                .ok();

            // Accumulate bytes until the headers and the declared body have arrived.
            let mut bytes = Vec::new();
            let mut buf = [0u8; 4096];
            loop {
                let n = stream.read(&mut buf).expect("read request");
                if n == 0 {
                    break;
                }
                bytes.extend_from_slice(&buf[..n]);
                let request = String::from_utf8_lossy(&bytes);
                if request_complete(&request) {
                    break;
                }
            }
            *captured.lock().expect("capture request") =
                Some(String::from_utf8_lossy(&bytes).to_string());

            // The reason phrase is always "OK", whatever `status` is.
            let response = format!(
                "HTTP/1.1 {status} OK\r\ncontent-type: application/json\r\ncontent-length: {}\r\nconnection: close\r\n\r\n{body}",
                body.len()
            );
            stream
                .write_all(response.as_bytes())
                .expect("write response");
        });

        (format!("http://{addr}"), handle)
    }

    /// Reports whether `request` holds a full HTTP request: the header
    /// terminator has been seen and at least the declared number of body bytes
    /// has arrived. A content length the framing helper rejects is treated as
    /// complete so the stub stops reading instead of waiting for more data.
    fn request_complete(request: &str) -> bool {
        let Some((headers, body)) = request.split_once("\r\n\r\n") else {
            return false;
        };
        let content_length = match http_content_length_from_header_lines(
            headers.lines(),
            TEST_HTTP_MAX_BODY_BYTES,
        ) {
            Ok(content_length) => content_length,
            Err(_) => return true,
        };
        body.len() >= content_length
    }

    /// A declared body larger than the test limit must end reading immediately.
    #[test]
    fn request_complete_stops_on_oversized_content_length() {
        let request = format!(
            "POST /probe HTTP/1.1\r\nContent-Length: {}\r\n\r\n",
            TEST_HTTP_MAX_BODY_BYTES + 1
        );
        assert!(request_complete(&request));
    }

    /// The probe uses the configured method, path, body, extra headers, and
    /// the caller-supplied API key.
    #[tokio::test(flavor = "current_thread")]
    // The env lock guard is held across the await for the whole test; the
    // runtime is single-threaded (presumably to keep global overrides and env
    // state stable while the probe runs).
    #[allow(clippy::await_holding_lock)]
    async fn sends_configured_probe_request_with_candidate_key() {
        let _guard = crate::llm::env_lock().lock().expect("env lock");
        let captured = Arc::new(Mutex::new(None));
        let (base_url, server) = spawn_healthcheck_stub(200, r#"{"ok":true}"#, captured.clone());
        install_provider(
            "acme",
            provider_with_healthcheck(
                base_url.clone(),
                HealthcheckDef {
                    method: "POST".to_string(),
                    // No leading slash: the URL builder must insert the separator.
                    path: Some("probe".to_string()),
                    url: None,
                    body: Some(r#"{"ping":true}"#.to_string()),
                },
            ),
        );

        let result = run_provider_healthcheck_with_options(
            "acme",
            ProviderHealthcheckOptions {
                api_key: Some("candidate-key".to_string()),
                client: None,
            },
        )
        .await;
        server.join().expect("stub server");
        // Overrides are removed before asserting so a failed assertion does
        // not leave them installed.
        llm_config::clear_user_overrides();

        assert!(result.valid);
        assert_eq!(result.provider, "acme");
        assert_eq!(result.metadata["status"], json!(200));
        assert_eq!(result.metadata["method"], json!("POST"));
        assert_eq!(result.metadata["url"], json!(format!("{base_url}/probe")));

        let request = captured
            .lock()
            .expect("captured request")
            .clone()
            .expect("request");
        assert!(request.starts_with("POST /probe HTTP/1.1\r\n"));
        assert!(request.contains("authorization: Bearer candidate-key\r\n"));
        assert!(request.contains("x-extra: extra-value\r\n"));
        assert!(request.ends_with(r#"{"ping":true}"#));
    }

    /// With no key override and no key in the environment, the check reports
    /// `missing_credentials` and names the variable it looked for.
    #[tokio::test(flavor = "current_thread")]
    // The env lock guard is held across the await for the whole test; see the
    // SAFETY note below.
    #[allow(clippy::await_holding_lock)]
    async fn reports_missing_credentials_without_network() {
        let _guard = crate::llm::env_lock().lock().expect("env lock");
        // SAFETY: environment mutation is guarded by the env lock held above —
        // assumes every test touching the process environment takes the same
        // lock; TODO confirm.
        unsafe {
            std::env::remove_var("HARN_TEST_PROVIDER_KEY");
        }
        install_provider(
            "acme-missing-key",
            provider_with_healthcheck(
                // No stub listens here; the test expects to return before any
                // request is sent.
                "http://127.0.0.1:9".to_string(),
                HealthcheckDef {
                    method: "GET".to_string(),
                    path: Some("/models".to_string()),
                    url: None,
                    body: None,
                },
            ),
        );

        let result = run_provider_healthcheck("acme-missing-key").await;
        llm_config::clear_user_overrides();

        assert!(!result.valid);
        assert_eq!(result.metadata["reason"], json!("missing_credentials"));
        assert_eq!(
            result.metadata["auth_env"],
            json!(["HARN_TEST_PROVIDER_KEY"])
        );
        assert!(result.message.contains("Missing credentials"));
    }

    /// A non-success HTTP status yields `http_status` with the status code and
    /// the response body recorded in metadata.
    #[tokio::test(flavor = "current_thread")]
    // The env lock guard is held across the await for the whole test.
    #[allow(clippy::await_holding_lock)]
    async fn returns_stable_failure_shape_for_http_errors() {
        let _guard = crate::llm::env_lock().lock().expect("env lock");
        let captured = Arc::new(Mutex::new(None));
        let (base_url, server) = spawn_healthcheck_stub(401, r#"{"error":"bad key"}"#, captured);
        install_provider(
            "acme-auth",
            provider_with_healthcheck(
                base_url,
                HealthcheckDef {
                    method: "GET".to_string(),
                    path: Some("/models".to_string()),
                    url: None,
                    body: None,
                },
            ),
        );

        let result = run_provider_healthcheck_with_options(
            "acme-auth",
            ProviderHealthcheckOptions {
                api_key: Some("bad-key".to_string()),
                client: None,
            },
        )
        .await;
        server.join().expect("stub server");
        llm_config::clear_user_overrides();

        assert!(!result.valid);
        assert_eq!(result.provider, "acme-auth");
        assert_eq!(result.metadata["reason"], json!("http_status"));
        assert_eq!(result.metadata["status"], json!(401));
        assert_eq!(result.metadata["body"], json!(r#"{"error":"bad key"}"#));
    }

    /// Each listed provider must have a healthcheck with a non-empty method
    /// and either a path or a full URL.
    #[test]
    fn default_external_provider_catalog_has_healthchecks() {
        for provider in [
            "openrouter",
            "anthropic",
            "openai",
            "huggingface",
            "together",
        ] {
            let config = llm_config::provider_config(provider)
                .unwrap_or_else(|| panic!("missing provider {provider}"));
            let healthcheck = config
                .healthcheck
                .as_ref()
                .unwrap_or_else(|| panic!("missing healthcheck for {provider}"));
            assert!(!healthcheck.method.is_empty());
            assert!(healthcheck.path.is_some() || healthcheck.url.is_some());
        }
    }
}