Skip to main content

harn_vm/llm/
healthcheck.rs

1use std::collections::BTreeMap;
2use std::time::Duration;
3
4use reqwest::Method;
5use serde::{Deserialize, Serialize};
6use serde_json::{json, Value as JsonValue};
7
8use crate::llm_config::{self, AuthEnv, HealthcheckDef, ProviderDef};
9
10use super::api::apply_auth_headers;
11
/// Request timeout, in seconds, for the HTTP client that the healthcheck
/// builds itself when the caller does not supply one.
const DEFAULT_HEALTHCHECK_TIMEOUT_SECS: u64 = 5;
/// Maximum number of characters (`char`s, not bytes) of a response body that
/// `body_snippet` keeps.
const BODY_SNIPPET_LIMIT: usize = 1000;
14
/// Per-call overrides for a provider healthcheck.
///
/// The derived `Default` leaves both fields `None`.
#[derive(Debug, Clone, Default)]
pub struct ProviderHealthcheckOptions {
    /// Candidate API key to validate. When unset, Harn resolves credentials
    /// from the provider's configured environment variables.
    pub api_key: Option<String>,
    /// Optional client override for hosts that need custom transport policy.
    /// A supplied client is used as given; the default healthcheck timeout is
    /// only configured on the client built when this is `None`.
    pub client: Option<reqwest::Client>,
}
23
/// Outcome of a single provider healthcheck.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderHealthcheckResult {
    /// Provider name the check was run for.
    pub provider: String,
    /// Whether the check passed. `run_provider_healthcheck_with_options` sets
    /// this to `true` only when the probe response has a success status.
    pub valid: bool,
    /// Human-readable summary of the outcome.
    pub message: String,
    /// Structured details about the outcome. Results produced by
    /// `run_provider_healthcheck_with_options` always carry `reason` and
    /// `provider`, plus `status`, `url`, `method`, `body`, or `auth_env`
    /// depending on how far the check got.
    pub metadata: BTreeMap<String, JsonValue>,
}
31
32impl ProviderHealthcheckResult {
33    fn new(
34        provider: impl Into<String>,
35        valid: bool,
36        message: impl Into<String>,
37        metadata: BTreeMap<String, JsonValue>,
38    ) -> Self {
39        Self {
40            provider: provider.into(),
41            valid,
42            message: message.into(),
43            metadata,
44        }
45    }
46}
47
/// Runs the healthcheck for `provider` with default options, i.e. without a
/// candidate API key and without a client override.
///
/// See `run_provider_healthcheck_with_options` for how the outcome is reported.
pub async fn run_provider_healthcheck(provider: &str) -> ProviderHealthcheckResult {
    run_provider_healthcheck_with_options(provider, ProviderHealthcheckOptions::default()).await
}
51
52pub async fn run_provider_healthcheck_with_options(
53    provider: &str,
54    options: ProviderHealthcheckOptions,
55) -> ProviderHealthcheckResult {
56    let provider = if provider.trim().is_empty() {
57        "anthropic"
58    } else {
59        provider.trim()
60    };
61
62    let Some(def) = llm_config::provider_config(provider) else {
63        let mut metadata = base_metadata("unknown_provider");
64        metadata.insert("provider".to_string(), json!(provider));
65        return ProviderHealthcheckResult::new(
66            provider,
67            false,
68            format!("Unknown provider: {provider}"),
69            metadata,
70        );
71    };
72
73    let Some(healthcheck) = def.healthcheck.as_ref() else {
74        let mut metadata = base_metadata("no_healthcheck");
75        metadata.insert("provider".to_string(), json!(provider));
76        return ProviderHealthcheckResult::new(
77            provider,
78            false,
79            format!("No healthcheck configured for {provider}"),
80            metadata,
81        );
82    };
83
84    let auth = resolve_healthcheck_auth(&def, options.api_key);
85    if auth.requires_auth && auth.api_key.is_none() {
86        let mut metadata = base_metadata("missing_credentials");
87        metadata.insert("provider".to_string(), json!(provider));
88        metadata.insert("auth_env".to_string(), json!(auth.candidates));
89        return ProviderHealthcheckResult::new(
90            provider,
91            false,
92            format!(
93                "Missing credentials for {provider}: set {} or pass an api_key",
94                auth.candidates.join(", ")
95            ),
96            metadata,
97        );
98    }
99
100    let url = build_healthcheck_url(&def, healthcheck);
101    let method = Method::from_bytes(healthcheck.method.as_bytes()).unwrap_or(Method::GET);
102    let client = match options.client {
103        Some(client) => client,
104        None => match reqwest::Client::builder()
105            .timeout(Duration::from_secs(DEFAULT_HEALTHCHECK_TIMEOUT_SECS))
106            .build()
107        {
108            Ok(client) => client,
109            Err(error) => {
110                let mut metadata = base_metadata("client_build_failed");
111                metadata.insert("provider".to_string(), json!(provider));
112                return ProviderHealthcheckResult::new(
113                    provider,
114                    false,
115                    format!("{provider} healthcheck failed: {error}"),
116                    metadata,
117                );
118            }
119        },
120    };
121
122    let mut request = client.request(method.clone(), &url);
123    if let Some(api_key) = auth.api_key.as_deref() {
124        request = apply_auth_headers(request, api_key, Some(&def));
125    }
126    for (name, value) in &def.extra_headers {
127        request = request.header(name, value);
128    }
129    if let Some(body) = &healthcheck.body {
130        request = request
131            .header(reqwest::header::CONTENT_TYPE, "application/json")
132            .body(body.clone());
133    }
134
135    match request.send().await {
136        Ok(response) => {
137            let status = response.status();
138            let status_code = status.as_u16();
139            let valid = status.is_success();
140            let body_text = response.text().await.unwrap_or_default();
141            let mut metadata = base_metadata(if valid { "ok" } else { "http_status" });
142            metadata.insert("provider".to_string(), json!(provider));
143            metadata.insert("status".to_string(), json!(status_code));
144            metadata.insert("url".to_string(), json!(url));
145            metadata.insert("method".to_string(), json!(method.as_str()));
146            if !valid && !body_text.is_empty() {
147                metadata.insert("body".to_string(), json!(body_snippet(&body_text)));
148            }
149
150            let message = if valid {
151                format!("{provider} is reachable (HTTP {status_code})")
152            } else {
153                let suffix = body_snippet(&body_text);
154                if suffix.is_empty() {
155                    format!("{provider} returned HTTP {status_code}")
156                } else {
157                    format!("{provider} returned HTTP {status_code}: {suffix}")
158                }
159            };
160
161            ProviderHealthcheckResult::new(provider, valid, message, metadata)
162        }
163        Err(error) => {
164            let mut metadata = base_metadata("request_failed");
165            metadata.insert("provider".to_string(), json!(provider));
166            metadata.insert("url".to_string(), json!(url));
167            metadata.insert("method".to_string(), json!(method.as_str()));
168            ProviderHealthcheckResult::new(
169                provider,
170                false,
171                format!("{provider} healthcheck failed: {error}"),
172                metadata,
173            )
174        }
175    }
176}
177
178pub fn build_healthcheck_url(def: &ProviderDef, healthcheck: &HealthcheckDef) -> String {
179    if let Some(url) = &healthcheck.url {
180        return url.clone();
181    }
182
183    let base = llm_config::resolve_base_url(def);
184    let path = healthcheck.path.as_deref().unwrap_or("");
185    if path.starts_with('/') {
186        format!("{}{}", base.trim_end_matches('/'), path)
187    } else if path.is_empty() {
188        base
189    } else {
190        format!("{}/{}", base.trim_end_matches('/'), path)
191    }
192}
193
/// Credential decision for one healthcheck, as produced by
/// `resolve_healthcheck_auth`.
#[derive(Debug, Clone)]
struct ResolvedHealthcheckAuth {
    /// Whether the probe must carry credentials; when this is `true` and
    /// `api_key` is `None`, the check is reported as missing credentials.
    requires_auth: bool,
    /// Key to send with the probe, if one was resolved.
    api_key: Option<String>,
    /// Names of the environment variables the provider is configured to read
    /// its key from; used in the missing-credentials report.
    candidates: Vec<String>,
}
200
201fn resolve_healthcheck_auth(
202    def: &ProviderDef,
203    api_key_override: Option<String>,
204) -> ResolvedHealthcheckAuth {
205    let candidates = auth_env_candidates(&def.auth_env);
206    if def.auth_style == "none" || matches!(def.auth_env, AuthEnv::None) {
207        let api_key = api_key_override.and_then(non_empty);
208        return ResolvedHealthcheckAuth {
209            requires_auth: api_key.is_some(),
210            api_key,
211            candidates,
212        };
213    }
214
215    let api_key = api_key_override
216        .and_then(non_empty)
217        .or_else(|| resolve_api_key_from_env(&def.auth_env));
218    ResolvedHealthcheckAuth {
219        requires_auth: true,
220        api_key,
221        candidates,
222    }
223}
224
225fn auth_env_candidates(auth_env: &AuthEnv) -> Vec<String> {
226    match auth_env {
227        AuthEnv::None => Vec::new(),
228        AuthEnv::Single(env) => vec![env.clone()],
229        AuthEnv::Multiple(envs) => envs.clone(),
230    }
231}
232
233fn resolve_api_key_from_env(auth_env: &AuthEnv) -> Option<String> {
234    match auth_env {
235        AuthEnv::None => None,
236        AuthEnv::Single(env) => std::env::var(env).ok().and_then(non_empty),
237        AuthEnv::Multiple(envs) => envs
238            .iter()
239            .find_map(|env| std::env::var(env).ok().and_then(non_empty)),
240    }
241}
242
/// Trims `value` and returns it, or `None` when nothing but whitespace is left.
fn non_empty(value: String) -> Option<String> {
    Some(value.trim())
        .filter(|trimmed| !trimmed.is_empty())
        .map(str::to_owned)
}
251
252fn base_metadata(reason: &str) -> BTreeMap<String, JsonValue> {
253    BTreeMap::from([("reason".to_string(), json!(reason))])
254}
255
256fn body_snippet(body: &str) -> String {
257    let mut snippet = String::new();
258    for ch in body.chars().take(BODY_SNIPPET_LIMIT) {
259        snippet.push(ch);
260    }
261    snippet
262}
263
264#[cfg(test)]
265mod tests {
266    use std::io::{Read, Write};
267    use std::net::TcpListener;
268    use std::sync::{Arc, Mutex};
269
270    use super::*;
271
    /// Builds a test provider rooted at `base_url` that uses bearer auth read
    /// from `HARN_TEST_PROVIDER_KEY`, sends one extra header (`x-extra`), and
    /// carries the given `healthcheck`. All other fields keep their defaults.
    fn provider_with_healthcheck(base_url: String, healthcheck: HealthcheckDef) -> ProviderDef {
        ProviderDef {
            base_url,
            auth_style: "bearer".to_string(),
            auth_env: AuthEnv::Single("HARN_TEST_PROVIDER_KEY".to_string()),
            extra_headers: BTreeMap::from([("x-extra".to_string(), "extra-value".to_string())]),
            chat_endpoint: "/chat/completions".to_string(),
            healthcheck: Some(healthcheck),
            ..Default::default()
        }
    }
283
    /// Installs `provider` under `name` as the only entry of a fresh
    /// `ProvidersConfig` and passes that config to
    /// `llm_config::set_user_overrides`.
    // NOTE(review): presumably this replaces any previously installed
    // overrides rather than merging with them — confirm against `llm_config`.
    fn install_provider(name: &str, provider: ProviderDef) {
        let mut config = llm_config::ProvidersConfig::default();
        config.providers.insert(name.to_string(), provider);
        llm_config::set_user_overrides(Some(config));
    }
289
    /// Starts a one-shot HTTP stub on a loopback port chosen by the OS.
    ///
    /// The stub thread accepts a single connection, stores the raw request
    /// text in `captured`, answers with `status` and `body` as a JSON
    /// response, and then exits. Returns the stub's base URL
    /// (`http://<addr>`) and the thread handle; joining the handle surfaces
    /// any panic raised inside the stub.
    fn spawn_healthcheck_stub(
        status: u16,
        body: &'static str,
        captured: Arc<Mutex<Option<String>>>,
    ) -> (String, std::thread::JoinHandle<()>) {
        let listener = TcpListener::bind("127.0.0.1:0").expect("bind healthcheck stub");
        let addr = listener.local_addr().expect("stub addr");
        // Non-blocking accept lets the thread poll against a deadline instead
        // of blocking forever when no client ever connects.
        listener
            .set_nonblocking(true)
            .expect("set listener nonblocking");

        // Use a generous deadline so the stub doesn't trip when nextest fans
        // out across the workspace and starves this thread of CPU. The
        // healthcheck client itself completes in milliseconds against the
        // loopback stub once it gets scheduled — the deadline is just an
        // upper bound to keep a stuck test from hanging forever.
        let handle = std::thread::spawn(move || {
            let deadline = std::time::Instant::now() + std::time::Duration::from_secs(30);
            let (mut stream, _) = loop {
                match listener.accept() {
                    Ok(pair) => break pair,
                    Err(error) if error.kind() == std::io::ErrorKind::WouldBlock => {
                        if std::time::Instant::now() >= deadline {
                            panic!("healthcheck stub: no client within 30s");
                        }
                        std::thread::sleep(std::time::Duration::from_millis(10));
                    }
                    Err(error) => panic!("healthcheck stub accept failed: {error}"),
                }
            };
            // Switch the accepted stream to blocking mode so the reads and
            // writes below wait for data, bounded by the timeouts set next.
            stream
                .set_nonblocking(false)
                .expect("set accepted stream blocking");
            // Timeouts are best effort: a failure to set them is ignored.
            stream
                .set_read_timeout(Some(std::time::Duration::from_secs(30)))
                .ok();
            stream
                .set_write_timeout(Some(std::time::Duration::from_secs(30)))
                .ok();

            // Read until the peer closes the connection or the request looks
            // complete (headers plus the declared body length).
            let mut bytes = Vec::new();
            let mut buf = [0u8; 4096];
            loop {
                let n = stream.read(&mut buf).expect("read request");
                if n == 0 {
                    break;
                }
                bytes.extend_from_slice(&buf[..n]);
                let request = String::from_utf8_lossy(&bytes);
                if request_complete(&request) {
                    break;
                }
            }
            *captured.lock().expect("capture request") =
                Some(String::from_utf8_lossy(&bytes).to_string());

            // The reason phrase is always "OK", whatever `status` is.
            let response = format!(
                "HTTP/1.1 {status} OK\r\ncontent-type: application/json\r\ncontent-length: {}\r\nconnection: close\r\n\r\n{body}",
                body.len()
            );
            stream
                .write_all(response.as_bytes())
                .expect("write response");
        });

        (format!("http://{addr}"), handle)
    }
357
358    fn request_complete(request: &str) -> bool {
359        let Some((headers, body)) = request.split_once("\r\n\r\n") else {
360            return false;
361        };
362        let content_length = headers
363            .lines()
364            .find_map(|line| line.strip_prefix("content-length: "))
365            .or_else(|| {
366                headers
367                    .lines()
368                    .find_map(|line| line.strip_prefix("Content-Length: "))
369            })
370            .and_then(|value| value.trim().parse::<usize>().ok())
371            .unwrap_or(0);
372        body.len() >= content_length
373    }
374
    /// Runs a POST healthcheck against the local stub and checks both the
    /// reported result and the request that reached the stub: method, path,
    /// bearer header built from the candidate key, the provider's extra
    /// header, and the configured body.
    #[tokio::test(flavor = "current_thread")]
    #[allow(clippy::await_holding_lock)]
    async fn sends_configured_probe_request_with_candidate_key() {
        // Held for the whole test, including across the `.await` below.
        let _guard = crate::llm::env_lock().lock().expect("env lock");
        let captured = Arc::new(Mutex::new(None));
        let (base_url, server) = spawn_healthcheck_stub(200, r#"{"ok":true}"#, captured.clone());
        install_provider(
            "acme",
            provider_with_healthcheck(
                base_url.clone(),
                HealthcheckDef {
                    method: "POST".to_string(),
                    // No leading slash: the URL builder must insert the separator.
                    path: Some("probe".to_string()),
                    url: None,
                    body: Some(r#"{"ping":true}"#.to_string()),
                },
            ),
        );

        let result = run_provider_healthcheck_with_options(
            "acme",
            ProviderHealthcheckOptions {
                api_key: Some("candidate-key".to_string()),
                client: None,
            },
        )
        .await;
        // Join the stub and drop the overrides before asserting, so a failed
        // assertion does not skip either step.
        server.join().expect("stub server");
        llm_config::clear_user_overrides();

        assert!(result.valid);
        assert_eq!(result.provider, "acme");
        assert_eq!(result.metadata["status"], json!(200));
        assert_eq!(result.metadata["method"], json!("POST"));
        assert_eq!(result.metadata["url"], json!(format!("{base_url}/probe")));

        let request = captured
            .lock()
            .expect("captured request")
            .clone()
            .expect("request");
        assert!(request.starts_with("POST /probe HTTP/1.1\r\n"));
        assert!(request.contains("authorization: Bearer candidate-key\r\n"));
        assert!(request.contains("x-extra: extra-value\r\n"));
        assert!(request.ends_with(r#"{"ping":true}"#));
    }
421
    /// With the key variable unset and no candidate key passed, the check
    /// must report `missing_credentials` and name the environment variable
    /// it looked at.
    #[tokio::test(flavor = "current_thread")]
    #[allow(clippy::await_holding_lock)]
    async fn reports_missing_credentials_without_network() {
        let _guard = crate::llm::env_lock().lock().expect("env lock");
        // SAFETY: NOTE(review): changing the process environment is only sound
        // while no other thread reads or writes it. `_guard` is presumably the
        // lock that serializes environment access across tests — confirm that
        // every test touching the environment takes it.
        unsafe {
            std::env::remove_var("HARN_TEST_PROVIDER_KEY");
        }
        install_provider(
            "acme-missing-key",
            provider_with_healthcheck(
                // No stub is started for this address; the expected result is
                // produced by the credentials check, before any request.
                "http://127.0.0.1:9".to_string(),
                HealthcheckDef {
                    method: "GET".to_string(),
                    path: Some("/models".to_string()),
                    url: None,
                    body: None,
                },
            ),
        );

        let result = run_provider_healthcheck("acme-missing-key").await;
        llm_config::clear_user_overrides();

        assert!(!result.valid);
        assert_eq!(result.metadata["reason"], json!("missing_credentials"));
        assert_eq!(
            result.metadata["auth_env"],
            json!(["HARN_TEST_PROVIDER_KEY"])
        );
        assert!(result.message.contains("Missing credentials"));
    }
453
    /// A non-success response (here 401) must yield an invalid result with
    /// reason `http_status`, the status code, and the response body in the
    /// metadata.
    #[tokio::test(flavor = "current_thread")]
    #[allow(clippy::await_holding_lock)]
    async fn returns_stable_failure_shape_for_http_errors() {
        let _guard = crate::llm::env_lock().lock().expect("env lock");
        // The captured request is not inspected in this test.
        let captured = Arc::new(Mutex::new(None));
        let (base_url, server) = spawn_healthcheck_stub(401, r#"{"error":"bad key"}"#, captured);
        install_provider(
            "acme-auth",
            provider_with_healthcheck(
                base_url,
                HealthcheckDef {
                    method: "GET".to_string(),
                    path: Some("/models".to_string()),
                    url: None,
                    body: None,
                },
            ),
        );

        let result = run_provider_healthcheck_with_options(
            "acme-auth",
            ProviderHealthcheckOptions {
                api_key: Some("bad-key".to_string()),
                client: None,
            },
        )
        .await;
        // Join the stub and drop the overrides before asserting, so a failed
        // assertion does not skip either step.
        server.join().expect("stub server");
        llm_config::clear_user_overrides();

        assert!(!result.valid);
        assert_eq!(result.provider, "acme-auth");
        assert_eq!(result.metadata["reason"], json!("http_status"));
        assert_eq!(result.metadata["status"], json!(401));
        assert_eq!(result.metadata["body"], json!(r#"{"error":"bad key"}"#));
    }
490
491    #[test]
492    fn default_external_provider_catalog_has_healthchecks() {
493        for provider in [
494            "openrouter",
495            "anthropic",
496            "openai",
497            "huggingface",
498            "together",
499        ] {
500            let config = llm_config::provider_config(provider)
501                .unwrap_or_else(|| panic!("missing provider {provider}"));
502            let healthcheck = config
503                .healthcheck
504                .as_ref()
505                .unwrap_or_else(|| panic!("missing healthcheck for {provider}"));
506            assert!(!healthcheck.method.is_empty());
507            assert!(healthcheck.path.is_some() || healthcheck.url.is_some());
508        }
509    }
510}