Skip to main content

harn_vm/llm/api/
readiness.rs

//! OpenAI-compatible model readiness probes.
//!
//! Local llama.cpp/vLLM-style servers expose their loaded model aliases
//! through `/v1/models`. Unlike context-window discovery, this module reports
//! distinct, user-facing failure categories (`unreachable`, `bad_status`,
//! `invalid_response`, `model_missing`, ...) so hosts can surface actionable
//! startup diagnostics before the first chat request.
7
8use crate::llm_config::{self, ProviderDef};
9
10use super::auth::apply_auth_headers;
11
/// Outcome of probing an OpenAI-compatible server for one specific model.
///
/// `valid` is `true` only for the `"ok"` outcome; every failure carries a
/// machine-readable `category` plus a human-readable `message`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ModelReadiness {
    /// Whether the requested model was found among the served models.
    pub valid: bool,
    /// `"ok"` on success, otherwise a failure category such as
    /// `"unreachable"`, `"bad_status"` or `"model_missing"`.
    pub category: String,
    /// Human-readable diagnostic suitable for surfacing to users.
    pub message: String,
    /// Provider name the probe ran against.
    pub provider: String,
    /// Model identifier that was requested.
    pub model: String,
    /// Models-listing URL that was probed, when one could be built.
    pub url: Option<String>,
    /// HTTP status of the models response, when a response was received.
    pub status: Option<u16>,
    /// Model ids advertised by the server (empty when none were parsed).
    pub available_models: Vec<String>,
}

impl ModelReadiness {
    /// Builds the successful result for `model` served at `url`.
    fn ok(
        provider: &str,
        model: &str,
        url: &str,
        status: u16,
        available_models: Vec<String>,
    ) -> Self {
        let message = format!("{provider} is reachable and serves model '{model}' at {url}");
        // Reuse the failure constructor for the shared fields, then flip the
        // verdict; the category is "ok" for the success case.
        Self {
            valid: true,
            ..Self::error(
                provider,
                model,
                "ok",
                message,
                Some(url.to_owned()),
                Some(status),
                available_models,
            )
        }
    }

    /// Builds a failed result tagged with `category`.
    fn error(
        provider: &str,
        model: &str,
        category: &str,
        message: String,
        url: Option<String>,
        status: Option<u16>,
        available_models: Vec<String>,
    ) -> Self {
        let (provider, model, category) =
            (provider.to_owned(), model.to_owned(), category.to_owned());
        Self {
            valid: false,
            category,
            message,
            provider,
            model,
            url,
            status,
            available_models,
        }
    }
}
65
66pub fn supports_model_readiness_probe(def: &ProviderDef) -> bool {
67    let healthcheck_uses_models = def.healthcheck.as_ref().is_some_and(|hc| {
68        hc.method.eq_ignore_ascii_case("GET") && {
69            hc.path
70                .as_deref()
71                .is_some_and(|path| path.contains("models"))
72                || hc.url.as_deref().is_some_and(|url| url.contains("models"))
73        }
74    });
75    healthcheck_uses_models || def.chat_endpoint.ends_with("/chat/completions")
76}
77
78pub fn selected_model_for_provider(provider: &str) -> Option<String> {
79    if provider == "local" {
80        if let Ok(model) = std::env::var("LOCAL_LLM_MODEL") {
81            if !model.trim().is_empty() {
82                let (resolved, _) = llm_config::resolve_model(model.trim());
83                return Some(resolved);
84            }
85        }
86    }
87
88    let selected_provider = std::env::var("HARN_LLM_PROVIDER")
89        .ok()
90        .filter(|value| !value.trim().is_empty());
91    if selected_provider.as_deref() == Some(provider) {
92        if let Ok(model) = std::env::var("HARN_LLM_MODEL") {
93            if !model.trim().is_empty() {
94                let (resolved, _) = llm_config::resolve_model(model.trim());
95                return Some(resolved);
96            }
97        }
98    }
99
100    None
101}
102
103pub fn build_models_url(def: &ProviderDef) -> Result<String, String> {
104    let raw = models_healthcheck_url(def).unwrap_or_else(|| {
105        join_base_and_path(
106            &llm_config::resolve_base_url(def),
107            &model_path_from_chat_endpoint(&def.chat_endpoint),
108        )
109    });
110    validate_url(&normalize_loopback(&raw))
111}
112
113fn models_healthcheck_url(def: &ProviderDef) -> Option<String> {
114    let healthcheck = def.healthcheck.as_ref()?;
115    if !healthcheck.method.eq_ignore_ascii_case("GET") {
116        return None;
117    }
118    if let Some(url) = healthcheck.url.as_ref() {
119        return url.contains("models").then(|| url.clone());
120    }
121    let path = healthcheck.path.as_deref()?;
122    path.contains("models")
123        .then(|| join_base_and_path(&llm_config::resolve_base_url(def), path))
124}
125
126pub fn parse_model_ids(json: &serde_json::Value) -> Vec<String> {
127    if let Some(entries) = json.as_array() {
128        return collect_model_ids(entries);
129    }
130
131    if let Some(data) = json.get("data").and_then(|value| value.as_array()) {
132        return collect_model_ids(data);
133    }
134
135    if let Some(models) = json.get("models").and_then(|value| value.as_array()) {
136        return collect_model_ids(models);
137    }
138
139    Vec::new()
140}
141
142fn collect_model_ids(entries: &[serde_json::Value]) -> Vec<String> {
143    entries
144        .iter()
145        .filter_map(|entry| {
146            entry.as_str().or_else(|| {
147                entry
148                    .get("id")
149                    .or_else(|| entry.get("name"))
150                    .and_then(|value| value.as_str())
151            })
152        })
153        .map(str::to_string)
154        .collect()
155}
156
/// Reports whether `model` appears among the `available` model ids.
///
/// A match is an id that equals `model` or starts with it, so a short alias
/// such as `gpt-oss` matches a tagged id like `gpt-oss:20b`. An empty `model`
/// never matches: every string starts with `""`, so without this guard any
/// non-empty listing would be reported as serving it.
pub fn model_is_served(available: &[String], model: &str) -> bool {
    // `starts_with` also covers the exact-match case.
    !model.is_empty() && available.iter().any(|id| id.starts_with(model))
}
162
163pub async fn probe_openai_compatible_model(
164    provider: &str,
165    model: &str,
166    api_key: &str,
167) -> ModelReadiness {
168    let Some(def) = llm_config::provider_config(provider) else {
169        return ModelReadiness::error(
170            provider,
171            model,
172            "unknown_provider",
173            format!("Unknown provider: {provider}"),
174            None,
175            None,
176            Vec::new(),
177        );
178    };
179
180    probe_openai_compatible_model_with_def(provider, model, api_key, &def).await
181}
182
183pub(crate) async fn probe_openai_compatible_model_with_def(
184    provider: &str,
185    model: &str,
186    api_key: &str,
187    def: &ProviderDef,
188) -> ModelReadiness {
189    let url = match build_models_url(def) {
190        Ok(url) => url,
191        Err(error) => {
192            return ModelReadiness::error(
193                provider,
194                model,
195                "invalid_url",
196                format!("Invalid OpenAI-compatible models URL for {provider}: {error}"),
197                None,
198                None,
199                Vec::new(),
200            );
201        }
202    };
203
204    let client = crate::llm::shared_utility_client();
205    let req = client
206        .get(&url)
207        .header("Content-Type", "application/json")
208        .timeout(std::time::Duration::from_secs(10));
209    let req = apply_auth_headers(req, api_key, Some(def));
210    let req = def
211        .extra_headers
212        .iter()
213        .fold(req, |req, (name, value)| req.header(name, value));
214
215    let response = match req.send().await {
216        Ok(response) => response,
217        Err(error) => {
218            return ModelReadiness::error(
219                provider,
220                model,
221                "unreachable",
222                format!("{provider} OpenAI-compatible server not reachable at {url}: {error}"),
223                Some(url),
224                None,
225                Vec::new(),
226            );
227        }
228    };
229
230    let status = response.status();
231    if !status.is_success() {
232        let body = response.text().await.unwrap_or_default();
233        return ModelReadiness::error(
234            provider,
235            model,
236            "bad_status",
237            format!(
238                "{provider} returned HTTP {} at {url}: {body}",
239                status.as_u16()
240            ),
241            Some(url),
242            Some(status.as_u16()),
243            Vec::new(),
244        );
245    }
246
247    let status_code = status.as_u16();
248    let json: serde_json::Value = match response.json().await {
249        Ok(json) => json,
250        Err(error) => {
251            return ModelReadiness::error(
252                provider,
253                model,
254                "invalid_response",
255                format!("Could not parse {provider} /models response at {url}: {error}"),
256                Some(url),
257                Some(status_code),
258                Vec::new(),
259            );
260        }
261    };
262    let available_models = parse_model_ids(&json);
263    if available_models.is_empty() {
264        return ModelReadiness::error(
265            provider,
266            model,
267            "invalid_response",
268            format!("Could not find model ids in {provider} /models response at {url}"),
269            Some(url),
270            Some(status_code),
271            available_models,
272        );
273    }
274
275    if !model_is_served(&available_models, model) {
276        let available = available_models.join(", ");
277        return ModelReadiness::error(
278            provider,
279            model,
280            "model_missing",
281            format!(
282                "Model '{model}' is not served by {provider} at {url}. Currently served: {available}"
283            ),
284            Some(url),
285            Some(status_code),
286            available_models,
287        );
288    }
289
290    ModelReadiness::ok(provider, model, &url, status_code, available_models)
291}
292
/// Derives the models-listing path from a chat-completions endpoint.
///
/// `/v1/chat/completions` becomes `/v1/models`. An endpoint without the
/// `/chat/completions` suffix falls back to `/models`.
fn model_path_from_chat_endpoint(chat_endpoint: &str) -> String {
    // A missing suffix and a bare "/chat/completions" both leave an empty
    // prefix, which yields "/models".
    let prefix = chat_endpoint
        .strip_suffix("/chat/completions")
        .unwrap_or_default();
    format!("{prefix}/models")
}
304
/// Appends `path` to `base`, inserting a `/` unless `path` already starts
/// with one.
///
/// Trailing slashes on `base` are dropped first. An empty `path` returns the
/// trimmed base on its own.
fn join_base_and_path(base: &str, path: &str) -> String {
    let trimmed_base = base.trim_end_matches('/');
    match path {
        "" => trimmed_base.to_owned(),
        p if p.starts_with('/') => format!("{trimmed_base}{p}"),
        p => format!("{trimmed_base}/{p}"),
    }
}
315
/// Rewrites a `localhost` host in `url` to `127.0.0.1`.
///
/// Only the host directly after the `scheme://` separator is touched,
/// case-insensitively, with or without an explicit port. Text elsewhere in
/// the URL (paths, query strings) is left unchanged. So are hosts that merely
/// start with `localhost`, such as `localhost.example`. Input without a `://`
/// separator is returned as-is.
///
/// NOTE(review): presumably this exists so loopback servers that bind only
/// IPv4 are reached even where `localhost` resolves to `::1`. Confirm.
fn normalize_loopback(url: &str) -> String {
    const LOOPBACK_NAME: &str = "localhost";

    let Some(separator) = url.find("://") else {
        return url.to_string();
    };
    let (scheme, rest) = url.split_at(separator + "://".len());
    // `get` returns None for inputs that are too short or that would split a
    // multi-byte character, so the slice below is always valid.
    let Some(host) = rest.get(..LOOPBACK_NAME.len()) else {
        return url.to_string();
    };
    let tail = &rest[LOOPBACK_NAME.len()..];
    // The host must end exactly here: at a port, path, query, fragment, or EOS.
    let host_ends = tail.is_empty() || tail.starts_with(|c: char| matches!(c, ':' | '/' | '?' | '#'));
    if host.eq_ignore_ascii_case(LOOPBACK_NAME) && host_ends {
        format!("{scheme}127.0.0.1{tail}")
    } else {
        url.to_string()
    }
}
319
320fn validate_url(url: &str) -> Result<String, String> {
321    reqwest::Url::parse(url)
322        .map(|_| url.to_string())
323        .map_err(|error| format!("{url} ({error})"))
324}
325
#[cfg(test)]
mod tests {
    use super::*;
    use crate::llm_config::{HealthcheckDef, ProviderDef};

    /// `GET /v1/models` healthcheck shared by most fixtures.
    fn models_healthcheck() -> Option<HealthcheckDef> {
        Some(HealthcheckDef {
            method: "GET".to_string(),
            path: Some("/v1/models".to_string()),
            url: None,
            body: None,
        })
    }

    #[test]
    fn parses_openai_and_ollama_style_model_ids() {
        let openai = serde_json::json!({
            "data": [{"id": "qwen-alias"}, {"id": "other"}]
        });
        assert_eq!(
            parse_model_ids(&openai),
            vec!["qwen-alias".to_string(), "other".to_string()]
        );

        let models = serde_json::json!({
            "models": [{"name": "llama"}, {"id": "qwen"}]
        });
        assert_eq!(
            parse_model_ids(&models),
            vec!["llama".to_string(), "qwen".to_string()]
        );

        let top_level = serde_json::json!([
            {"id": "deepseek-ai/DeepSeek-V4-Pro"},
            {"name": "qwen"},
            "string-model"
        ]);
        assert_eq!(
            parse_model_ids(&top_level),
            vec![
                "deepseek-ai/DeepSeek-V4-Pro".to_string(),
                "qwen".to_string(),
                "string-model".to_string()
            ]
        );
    }

    #[test]
    fn model_matching_accepts_exact_or_prefix() {
        let ids = vec![
            "qwen36".to_string(),
            "gpt-oss:20b".to_string(),
            "llama-local-long-id".to_string(),
        ];
        assert!(model_is_served(&ids, "qwen36"));
        assert!(model_is_served(&ids, "llama-local"));
        assert!(!model_is_served(&ids, "missing"));
    }

    #[test]
    fn models_url_uses_healthcheck_path_and_loopback_normalization() {
        let def = ProviderDef {
            base_url: "http://localhost:8001".to_string(),
            chat_endpoint: "/v1/chat/completions".to_string(),
            healthcheck: models_healthcheck(),
            ..Default::default()
        };

        assert_eq!(
            build_models_url(&def).unwrap(),
            "http://127.0.0.1:8001/v1/models"
        );
    }

    #[test]
    fn models_url_prefers_absolute_healthcheck_url() {
        let def = ProviderDef {
            base_url: "http://127.0.0.1:8001".to_string(),
            chat_endpoint: "/v1/chat/completions".to_string(),
            healthcheck: Some(HealthcheckDef {
                method: "GET".to_string(),
                path: Some("/health".to_string()),
                url: Some("http://localhost:9000/v1/models".to_string()),
                body: None,
            }),
            ..Default::default()
        };

        assert_eq!(
            build_models_url(&def).unwrap(),
            "http://127.0.0.1:9000/v1/models"
        );
    }

    #[test]
    fn models_url_derives_path_from_chat_endpoint() {
        let def = ProviderDef {
            base_url: "http://127.0.0.1:8000".to_string(),
            chat_endpoint: "/v1/chat/completions".to_string(),
            healthcheck: None,
            ..Default::default()
        };

        assert_eq!(
            build_models_url(&def).unwrap(),
            "http://127.0.0.1:8000/v1/models"
        );
    }

    #[test]
    fn models_url_falls_back_to_root_models_route() {
        let def = ProviderDef {
            base_url: "http://127.0.0.1:11434/".to_string(),
            chat_endpoint: "/api/chat".to_string(),
            healthcheck: None,
            ..Default::default()
        };

        assert_eq!(
            build_models_url(&def).unwrap(),
            "http://127.0.0.1:11434/models"
        );
    }

    #[test]
    fn models_url_ignores_non_model_healthcheck_path() {
        let def = ProviderDef {
            base_url: "http://127.0.0.1:8080".to_string(),
            chat_endpoint: "/v1/chat/completions".to_string(),
            healthcheck: Some(HealthcheckDef {
                method: "GET".to_string(),
                path: Some("/health".to_string()),
                url: None,
                body: None,
            }),
            ..Default::default()
        };

        assert_eq!(
            build_models_url(&def).unwrap(),
            "http://127.0.0.1:8080/v1/models"
        );
    }

    #[tokio::test]
    async fn probe_reports_ready_when_model_is_served() {
        let def = test_def_with_response(200, r#"{"data":[{"id":"served-model-long"}]}"#).await;

        let result =
            probe_openai_compatible_model_with_def("local", "served-model", "", &def).await;

        assert!(result.valid);
        assert_eq!(result.category, "ok");
        assert_eq!(
            result.available_models,
            vec!["served-model-long".to_string()]
        );
    }

    #[tokio::test]
    async fn probe_distinguishes_model_missing() {
        let def = test_def_with_response(200, r#"{"data":[{"id":"served-model"}]}"#).await;

        let result = probe_openai_compatible_model_with_def("local", "missing", "", &def).await;

        assert!(!result.valid);
        assert_eq!(result.category, "model_missing");
        assert_eq!(result.available_models, vec!["served-model".to_string()]);
    }

    #[tokio::test]
    async fn probe_distinguishes_bad_status() {
        let def = test_def_with_response(503, "loading").await;

        let result =
            probe_openai_compatible_model_with_def("local", "served-model", "", &def).await;

        assert!(!result.valid);
        assert_eq!(result.category, "bad_status");
        assert_eq!(result.status, Some(503));
    }

    #[tokio::test]
    async fn probe_distinguishes_unparseable_body() {
        let def = test_def_with_response(200, "not json").await;

        let result =
            probe_openai_compatible_model_with_def("local", "served-model", "", &def).await;

        assert!(!result.valid);
        assert_eq!(result.category, "invalid_response");
        assert_eq!(result.status, Some(200));
        assert!(result.available_models.is_empty());
    }

    #[tokio::test]
    async fn probe_reports_empty_listing_as_invalid_response() {
        let def = test_def_with_response(200, r#"{"data":[]}"#).await;

        let result =
            probe_openai_compatible_model_with_def("local", "served-model", "", &def).await;

        assert!(!result.valid);
        assert_eq!(result.category, "invalid_response");
        assert_eq!(result.status, Some(200));
        assert!(result.available_models.is_empty());
    }

    #[tokio::test]
    async fn probe_distinguishes_invalid_url() {
        let def = ProviderDef {
            base_url: "not a url".to_string(),
            chat_endpoint: "/v1/chat/completions".to_string(),
            healthcheck: models_healthcheck(),
            auth_style: "none".to_string(),
            ..Default::default()
        };

        let result =
            probe_openai_compatible_model_with_def("local", "served-model", "", &def).await;

        assert!(!result.valid);
        assert_eq!(result.category, "invalid_url");
    }

    #[tokio::test]
    async fn probe_distinguishes_unreachable() {
        // Probe a fixed port rather than "bind an ephemeral port, drop it,
        // probe it". That pattern was racy under nextest fan-out: another
        // test's `bind 127.0.0.1:0` could be handed the freed port in
        // between, so the connection succeeded and the test saw a different
        // category. Ephemeral binds never hand out port 1 (tcpmux), and on
        // typical Linux hosts it cannot be bound without privileges, so the
        // connection is refused and the probe reports "unreachable".
        // NOTE(review): a host that actually listens on port 1, or a
        // privileged test run that binds it, would break this assumption.
        let def = ProviderDef {
            base_url: "http://127.0.0.1:1".to_string(),
            chat_endpoint: "/v1/chat/completions".to_string(),
            healthcheck: models_healthcheck(),
            auth_style: "none".to_string(),
            ..Default::default()
        };

        let result =
            probe_openai_compatible_model_with_def("local", "served-model", "", &def).await;

        assert!(!result.valid);
        assert_eq!(result.category, "unreachable");
    }

    /// Canonical reason phrase for the statuses the mock server emits.
    fn reason_phrase(status: u16) -> &'static str {
        match status {
            200 => "OK",
            404 => "Not Found",
            500 => "Internal Server Error",
            503 => "Service Unavailable",
            _ => "Unknown",
        }
    }

    /// Spawns a one-shot HTTP server that answers a single request with
    /// `status` and `body`, and returns a provider definition pointing at it.
    async fn test_def_with_response(status: u16, body: &'static str) -> ProviderDef {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0")
            .await
            .expect("bind");
        let addr = listener.local_addr().expect("addr");
        tokio::spawn(async move {
            let (mut socket, _) = listener.accept().await.expect("accept");
            // The request itself is irrelevant; drain what arrives so the
            // client is not cut off mid-write.
            let mut buf = [0_u8; 1024];
            let _ = tokio::io::AsyncReadExt::read(&mut socket, &mut buf).await;
            let response = format!(
                "HTTP/1.1 {status} {}\r\ncontent-type: application/json\r\ncontent-length: {}\r\nconnection: close\r\n\r\n{body}",
                reason_phrase(status),
                body.len(),
            );
            tokio::io::AsyncWriteExt::write_all(&mut socket, response.as_bytes())
                .await
                .expect("write");
        });

        ProviderDef {
            base_url: format!("http://{addr}"),
            chat_endpoint: "/v1/chat/completions".to_string(),
            healthcheck: models_healthcheck(),
            auth_style: "none".to_string(),
            ..Default::default()
        }
    }
}