Skip to main content

harn_vm/llm/
healthcheck.rs

1use std::collections::BTreeMap;
2use std::time::Duration;
3
4use reqwest::Method;
5use serde::{Deserialize, Serialize};
6use serde_json::{json, Value as JsonValue};
7
8use crate::llm_config::{self, AuthEnv, HealthcheckDef, ProviderDef};
9
10use super::api::apply_auth_headers;
11
/// Request timeout, in seconds, applied to HTTP clients that the healthcheck
/// builds itself (caller-supplied clients keep their own policy).
const DEFAULT_HEALTHCHECK_TIMEOUT_SECS: u64 = 5;
/// Maximum number of characters of a failing response body copied into a
/// healthcheck result's message and metadata.
const BODY_SNIPPET_LIMIT: usize = 1_000;
14
15#[derive(Debug, Clone, Default)]
16pub struct ProviderHealthcheckOptions {
17    /// Candidate API key to validate. When unset, Harn resolves credentials
18    /// from the provider's configured environment variables.
19    pub api_key: Option<String>,
20    /// Optional client override for hosts that need custom transport policy.
21    pub client: Option<reqwest::Client>,
22}
23
24#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
25pub struct ProviderHealthcheckResult {
26    pub provider: String,
27    pub valid: bool,
28    pub message: String,
29    pub metadata: BTreeMap<String, JsonValue>,
30}
31
32impl ProviderHealthcheckResult {
33    fn new(
34        provider: impl Into<String>,
35        valid: bool,
36        message: impl Into<String>,
37        metadata: BTreeMap<String, JsonValue>,
38    ) -> Self {
39        Self {
40            provider: provider.into(),
41            valid,
42            message: message.into(),
43            metadata,
44        }
45    }
46}
47
48pub async fn run_provider_healthcheck(provider: &str) -> ProviderHealthcheckResult {
49    run_provider_healthcheck_with_options(provider, ProviderHealthcheckOptions::default()).await
50}
51
52pub async fn run_provider_healthcheck_with_options(
53    provider: &str,
54    options: ProviderHealthcheckOptions,
55) -> ProviderHealthcheckResult {
56    let provider = if provider.trim().is_empty() {
57        "anthropic"
58    } else {
59        provider.trim()
60    };
61
62    let Some(def) = llm_config::provider_config(provider) else {
63        let mut metadata = base_metadata("unknown_provider");
64        metadata.insert("provider".to_string(), json!(provider));
65        return ProviderHealthcheckResult::new(
66            provider,
67            false,
68            format!("Unknown provider: {provider}"),
69            metadata,
70        );
71    };
72
73    let Some(healthcheck) = def.healthcheck.as_ref() else {
74        let mut metadata = base_metadata("no_healthcheck");
75        metadata.insert("provider".to_string(), json!(provider));
76        return ProviderHealthcheckResult::new(
77            provider,
78            false,
79            format!("No healthcheck configured for {provider}"),
80            metadata,
81        );
82    };
83
84    let auth = resolve_healthcheck_auth(&def, options.api_key);
85    if auth.requires_auth && auth.api_key.is_none() {
86        let mut metadata = base_metadata("missing_credentials");
87        metadata.insert("provider".to_string(), json!(provider));
88        metadata.insert("auth_env".to_string(), json!(auth.candidates));
89        return ProviderHealthcheckResult::new(
90            provider,
91            false,
92            format!(
93                "Missing credentials for {provider}: set {} or pass an api_key",
94                auth.candidates.join(", ")
95            ),
96            metadata,
97        );
98    }
99
100    let url = build_healthcheck_url(&def, healthcheck);
101    let method = Method::from_bytes(healthcheck.method.as_bytes()).unwrap_or(Method::GET);
102    let client = match options.client {
103        Some(client) => client,
104        None => match reqwest::Client::builder()
105            .timeout(Duration::from_secs(DEFAULT_HEALTHCHECK_TIMEOUT_SECS))
106            .build()
107        {
108            Ok(client) => client,
109            Err(error) => {
110                let mut metadata = base_metadata("client_build_failed");
111                metadata.insert("provider".to_string(), json!(provider));
112                return ProviderHealthcheckResult::new(
113                    provider,
114                    false,
115                    format!("{provider} healthcheck failed: {error}"),
116                    metadata,
117                );
118            }
119        },
120    };
121
122    let mut request = client.request(method.clone(), &url);
123    if let Some(api_key) = auth.api_key.as_deref() {
124        request = apply_auth_headers(request, api_key, Some(&def));
125    }
126    for (name, value) in &def.extra_headers {
127        request = request.header(name, value);
128    }
129    if let Some(body) = &healthcheck.body {
130        request = request
131            .header(reqwest::header::CONTENT_TYPE, "application/json")
132            .body(body.clone());
133    }
134
135    match request.send().await {
136        Ok(response) => {
137            let status = response.status();
138            let status_code = status.as_u16();
139            let valid = status.is_success();
140            let body_text = response.text().await.unwrap_or_default();
141            let mut metadata = base_metadata(if valid { "ok" } else { "http_status" });
142            metadata.insert("provider".to_string(), json!(provider));
143            metadata.insert("status".to_string(), json!(status_code));
144            metadata.insert("url".to_string(), json!(url));
145            metadata.insert("method".to_string(), json!(method.as_str()));
146            if !valid && !body_text.is_empty() {
147                metadata.insert("body".to_string(), json!(body_snippet(&body_text)));
148            }
149
150            let message = if valid {
151                format!("{provider} is reachable (HTTP {status_code})")
152            } else {
153                let suffix = body_snippet(&body_text);
154                if suffix.is_empty() {
155                    format!("{provider} returned HTTP {status_code}")
156                } else {
157                    format!("{provider} returned HTTP {status_code}: {suffix}")
158                }
159            };
160
161            ProviderHealthcheckResult::new(provider, valid, message, metadata)
162        }
163        Err(error) => {
164            let mut metadata = base_metadata("request_failed");
165            metadata.insert("provider".to_string(), json!(provider));
166            metadata.insert("url".to_string(), json!(url));
167            metadata.insert("method".to_string(), json!(method.as_str()));
168            ProviderHealthcheckResult::new(
169                provider,
170                false,
171                format!("{provider} healthcheck failed: {error}"),
172                metadata,
173            )
174        }
175    }
176}
177
178pub fn build_healthcheck_url(def: &ProviderDef, healthcheck: &HealthcheckDef) -> String {
179    if let Some(url) = &healthcheck.url {
180        return url.clone();
181    }
182
183    let base = llm_config::resolve_base_url(def);
184    let path = healthcheck.path.as_deref().unwrap_or("");
185    if path.starts_with('/') {
186        format!("{}{}", base.trim_end_matches('/'), path)
187    } else if path.is_empty() {
188        base
189    } else {
190        format!("{}/{}", base.trim_end_matches('/'), path)
191    }
192}
193
/// Credential decision for one healthcheck run.
#[derive(Clone, Debug)]
struct ResolvedHealthcheckAuth {
    /// Whether the probe must not be sent without an API key.
    requires_auth: bool,
    /// Key to send, taken from the caller's override or the environment.
    api_key: Option<String>,
    /// Environment variable names declared for the provider; surfaced to the
    /// caller when no key could be found.
    candidates: Vec<String>,
}
200
201fn resolve_healthcheck_auth(
202    def: &ProviderDef,
203    api_key_override: Option<String>,
204) -> ResolvedHealthcheckAuth {
205    let candidates = auth_env_candidates(&def.auth_env);
206    if def.auth_style == "none" || matches!(def.auth_env, AuthEnv::None) {
207        let api_key = api_key_override.and_then(non_empty);
208        return ResolvedHealthcheckAuth {
209            requires_auth: api_key.is_some(),
210            api_key,
211            candidates,
212        };
213    }
214
215    let api_key = api_key_override
216        .and_then(non_empty)
217        .or_else(|| resolve_api_key_from_env(&def.auth_env));
218    ResolvedHealthcheckAuth {
219        requires_auth: true,
220        api_key,
221        candidates,
222    }
223}
224
225fn auth_env_candidates(auth_env: &AuthEnv) -> Vec<String> {
226    match auth_env {
227        AuthEnv::None => Vec::new(),
228        AuthEnv::Single(env) => vec![env.clone()],
229        AuthEnv::Multiple(envs) => envs.clone(),
230    }
231}
232
233fn resolve_api_key_from_env(auth_env: &AuthEnv) -> Option<String> {
234    match auth_env {
235        AuthEnv::None => None,
236        AuthEnv::Single(env) => std::env::var(env).ok().and_then(non_empty),
237        AuthEnv::Multiple(envs) => envs
238            .iter()
239            .find_map(|env| std::env::var(env).ok().and_then(non_empty)),
240    }
241}
242
/// Trims surrounding whitespace and returns `None` if nothing is left.
fn non_empty(value: String) -> Option<String> {
    let stripped = value.trim();
    (!stripped.is_empty()).then(|| stripped.to_owned())
}
251
252fn base_metadata(reason: &str) -> BTreeMap<String, JsonValue> {
253    BTreeMap::from([("reason".to_string(), json!(reason))])
254}
255
256fn body_snippet(body: &str) -> String {
257    let mut snippet = String::new();
258    for ch in body.chars().take(BODY_SNIPPET_LIMIT) {
259        snippet.push(ch);
260    }
261    snippet
262}
263
#[cfg(test)]
mod tests {
    use std::io::{Read, Write};
    use std::net::TcpListener;
    use std::sync::{Arc, Mutex};

    use super::*;

    /// Bearer-auth provider at `base_url` whose key is read from
    /// `HARN_TEST_PROVIDER_KEY` and which adds an `x-extra` header to requests.
    fn provider_with_healthcheck(base_url: String, healthcheck: HealthcheckDef) -> ProviderDef {
        ProviderDef {
            base_url,
            auth_style: "bearer".to_string(),
            auth_env: AuthEnv::Single("HARN_TEST_PROVIDER_KEY".to_string()),
            extra_headers: BTreeMap::from([("x-extra".to_string(), "extra-value".to_string())]),
            chat_endpoint: "/chat/completions".to_string(),
            healthcheck: Some(healthcheck),
            ..Default::default()
        }
    }

    /// Installs a user-override config whose only entry is `provider` under `name`.
    fn install_provider(name: &str, provider: ProviderDef) {
        let mut config = llm_config::ProvidersConfig::default();
        config.providers.insert(name.to_string(), provider);
        llm_config::set_user_overrides(Some(config));
    }

    /// Starts a one-shot HTTP/1.1 server on an ephemeral localhost port.
    ///
    /// The server accepts a single connection and reads one request. It stores
    /// the raw request text in `captured`, replies with `status` and `body`,
    /// and then exits. Returns the server's base URL and the thread handle to
    /// join after the request has been sent.
    fn spawn_healthcheck_stub(
        status: u16,
        body: &'static str,
        captured: Arc<Mutex<Option<String>>>,
    ) -> (String, std::thread::JoinHandle<()>) {
        let listener = TcpListener::bind("127.0.0.1:0").expect("bind healthcheck stub");
        let addr = listener.local_addr().expect("stub addr");

        // Block on `accept()` directly. The earlier nonblocking poll plus a
        // 30s wall-clock deadline had two failure modes under CI fan-out:
        // (1) the 10ms poll tick stretched under load, delaying accept past
        // the client's connect timeout; (2) the deadline could elapse before
        // the polling thread woke, panicking even though the client was
        // already connected.
        // Blocking accept is deterministic and cannot starve. Every test sends
        // a request, so accept always returns. If a test panics before
        // sending, nextest's `leak-timeout = "fail"` reports the leaked thread
        // instead of hiding it behind a synthetic panic.
        let handle = std::thread::spawn(move || {
            let (mut stream, _) = listener
                .accept()
                .unwrap_or_else(|e| panic!("healthcheck stub accept failed: {e}"));
            stream
                .set_read_timeout(Some(std::time::Duration::from_secs(30)))
                .ok();
            stream
                .set_write_timeout(Some(std::time::Duration::from_secs(30)))
                .ok();

            let mut bytes = Vec::new();
            let mut buf = [0u8; 4096];
            loop {
                let n = stream.read(&mut buf).expect("read request");
                if n == 0 {
                    break;
                }
                bytes.extend_from_slice(&buf[..n]);
                let request = String::from_utf8_lossy(&bytes);
                if request_complete(&request) {
                    break;
                }
            }
            *captured.lock().expect("capture request") =
                Some(String::from_utf8_lossy(&bytes).to_string());

            let response = format!(
                "HTTP/1.1 {status} OK\r\ncontent-type: application/json\r\ncontent-length: {}\r\nconnection: close\r\n\r\n{body}",
                body.len()
            );
            stream
                .write_all(response.as_bytes())
                .expect("write response");
        });

        (format!("http://{addr}"), handle)
    }

    /// Returns true once `request` contains the full header block plus as many
    /// body bytes as its `Content-Length` header announces. A missing or
    /// unparsable `Content-Length` counts as 0.
    fn request_complete(request: &str) -> bool {
        let Some((headers, body)) = request.split_once("\r\n\r\n") else {
            return false;
        };
        // Header field names are case-insensitive, so match the name without
        // assuming how a particular HTTP client capitalises it.
        let content_length = headers
            .lines()
            .filter_map(|line| line.split_once(':'))
            .find(|(name, _)| name.trim().eq_ignore_ascii_case("content-length"))
            .and_then(|(_, value)| value.trim().parse::<usize>().ok())
            .unwrap_or(0);
        body.len() >= content_length
    }

    #[tokio::test(flavor = "current_thread")]
    // The env lock is held across `.await` on purpose: it serialises tests that
    // mutate process-global provider config and environment variables.
    #[allow(clippy::await_holding_lock)]
    async fn sends_configured_probe_request_with_candidate_key() {
        let _guard = crate::llm::env_lock().lock().expect("env lock");
        let captured = Arc::new(Mutex::new(None));
        let (base_url, server) = spawn_healthcheck_stub(200, r#"{"ok":true}"#, captured.clone());
        install_provider(
            "acme",
            provider_with_healthcheck(
                base_url.clone(),
                HealthcheckDef {
                    method: "POST".to_string(),
                    path: Some("probe".to_string()),
                    url: None,
                    body: Some(r#"{"ping":true}"#.to_string()),
                },
            ),
        );

        let result = run_provider_healthcheck_with_options(
            "acme",
            ProviderHealthcheckOptions {
                api_key: Some("candidate-key".to_string()),
                client: None,
            },
        )
        .await;
        server.join().expect("stub server");
        llm_config::clear_user_overrides();

        assert!(result.valid);
        assert_eq!(result.provider, "acme");
        assert_eq!(result.metadata["status"], json!(200));
        assert_eq!(result.metadata["method"], json!("POST"));
        assert_eq!(result.metadata["url"], json!(format!("{base_url}/probe")));

        let request = captured
            .lock()
            .expect("captured request")
            .clone()
            .expect("request");
        assert!(request.starts_with("POST /probe HTTP/1.1\r\n"));
        assert!(request.contains("authorization: Bearer candidate-key\r\n"));
        assert!(request.contains("x-extra: extra-value\r\n"));
        assert!(request.ends_with(r#"{"ping":true}"#));
    }

    #[tokio::test(flavor = "current_thread")]
    // See `sends_configured_probe_request_with_candidate_key` for why the
    // env lock is held across `.await`.
    #[allow(clippy::await_holding_lock)]
    async fn reports_missing_credentials_without_network() {
        let _guard = crate::llm::env_lock().lock().expect("env lock");
        // SAFETY: the env lock held above is the guard these tests use around
        // environment mutation; this assumes every concurrent reader/writer of
        // the process environment also takes it, so no other thread touches
        // the environment while the variable is removed.
        unsafe {
            std::env::remove_var("HARN_TEST_PROVIDER_KEY");
        }
        install_provider(
            "acme-missing-key",
            provider_with_healthcheck(
                "http://127.0.0.1:9".to_string(),
                HealthcheckDef {
                    method: "GET".to_string(),
                    path: Some("/models".to_string()),
                    url: None,
                    body: None,
                },
            ),
        );

        let result = run_provider_healthcheck("acme-missing-key").await;
        llm_config::clear_user_overrides();

        assert!(!result.valid);
        assert_eq!(result.metadata["reason"], json!("missing_credentials"));
        assert_eq!(
            result.metadata["auth_env"],
            json!(["HARN_TEST_PROVIDER_KEY"])
        );
        assert!(result.message.contains("Missing credentials"));
    }

    #[tokio::test(flavor = "current_thread")]
    // See `sends_configured_probe_request_with_candidate_key` for why the
    // env lock is held across `.await`.
    #[allow(clippy::await_holding_lock)]
    async fn returns_stable_failure_shape_for_http_errors() {
        let _guard = crate::llm::env_lock().lock().expect("env lock");
        let captured = Arc::new(Mutex::new(None));
        let (base_url, server) = spawn_healthcheck_stub(401, r#"{"error":"bad key"}"#, captured);
        install_provider(
            "acme-auth",
            provider_with_healthcheck(
                base_url,
                HealthcheckDef {
                    method: "GET".to_string(),
                    path: Some("/models".to_string()),
                    url: None,
                    body: None,
                },
            ),
        );

        let result = run_provider_healthcheck_with_options(
            "acme-auth",
            ProviderHealthcheckOptions {
                api_key: Some("bad-key".to_string()),
                client: None,
            },
        )
        .await;
        server.join().expect("stub server");
        llm_config::clear_user_overrides();

        assert!(!result.valid);
        assert_eq!(result.provider, "acme-auth");
        assert_eq!(result.metadata["reason"], json!("http_status"));
        assert_eq!(result.metadata["status"], json!(401));
        assert_eq!(result.metadata["body"], json!(r#"{"error":"bad key"}"#));
    }

    /// Every built-in external provider must ship a usable healthcheck.
    #[test]
    fn default_external_provider_catalog_has_healthchecks() {
        for provider in [
            "openrouter",
            "anthropic",
            "openai",
            "huggingface",
            "together",
        ] {
            let config = llm_config::provider_config(provider)
                .unwrap_or_else(|| panic!("missing provider {provider}"));
            let healthcheck = config
                .healthcheck
                .as_ref()
                .unwrap_or_else(|| panic!("missing healthcheck for {provider}"));
            assert!(!healthcheck.method.is_empty());
            assert!(healthcheck.path.is_some() || healthcheck.url.is_some());
        }
    }
}