Skip to main content

harn_vm/llm/api/
readiness.rs

1//! OpenAI-compatible model readiness probes.
2//!
3//! Local llama.cpp/vLLM-style servers expose their loaded model aliases
4//! through `/v1/models`. Unlike context-window discovery, this module keeps
5//! distinct user-facing failure categories so hosts can surface actionable
6//! startup diagnostics before the first chat request.
7
8use crate::llm_config::{self, ProviderDef};
9
10use super::auth::apply_auth_headers;
11
/// Outcome of probing an OpenAI-compatible server for a specific model.
///
/// `valid` is `true` only for a successful probe; on failure `category`
/// carries a machine-readable tag and `message` a human-readable diagnostic.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ModelReadiness {
    /// Whether the probe succeeded and the model is served.
    pub valid: bool,
    /// Outcome tag: `"ok"` on success, otherwise the failure category.
    pub category: String,
    /// Human-readable description of the outcome.
    pub message: String,
    /// Provider name the probe ran for.
    pub provider: String,
    /// Model name that was checked.
    pub model: String,
    /// URL that was probed, when one could be built.
    pub url: Option<String>,
    /// HTTP status code of the response, when a response was received.
    pub status: Option<u16>,
    /// Model ids reported by the server (empty when none were obtained).
    pub available_models: Vec<String>,
}

impl ModelReadiness {
    /// Builds a successful (`valid == true`, category `"ok"`) result for
    /// `model` served by `provider` at `url`.
    fn ok(
        provider: &str,
        model: &str,
        url: &str,
        status: u16,
        available_models: Vec<String>,
    ) -> Self {
        let message = format!("{provider} is reachable and serves model '{model}' at {url}");
        // Reuse the failure constructor for the shared fields, then flip `valid`.
        Self {
            valid: true,
            ..Self::error(
                provider,
                model,
                "ok",
                message,
                Some(url.to_string()),
                Some(status),
                available_models,
            )
        }
    }

    /// Builds a failed (`valid == false`) result tagged with `category`.
    fn error(
        provider: &str,
        model: &str,
        category: &str,
        message: String,
        url: Option<String>,
        status: Option<u16>,
        available_models: Vec<String>,
    ) -> Self {
        Self {
            provider: provider.to_owned(),
            model: model.to_owned(),
            category: category.to_owned(),
            valid: false,
            message,
            url,
            status,
            available_models,
        }
    }
}
65
66pub fn supports_model_readiness_probe(def: &ProviderDef) -> bool {
67    let healthcheck_uses_models = def.healthcheck.as_ref().is_some_and(|hc| {
68        hc.method.eq_ignore_ascii_case("GET") && {
69            hc.path
70                .as_deref()
71                .is_some_and(|path| path.contains("models"))
72                || hc.url.as_deref().is_some_and(|url| url.contains("models"))
73        }
74    });
75    healthcheck_uses_models || def.chat_endpoint.ends_with("/chat/completions")
76}
77
78pub fn selected_model_for_provider(provider: &str) -> Option<String> {
79    if provider == "local" {
80        if let Ok(model) = std::env::var("LOCAL_LLM_MODEL") {
81            if !model.trim().is_empty() {
82                let (resolved, _) = llm_config::resolve_model(model.trim());
83                return Some(resolved);
84            }
85        }
86    }
87
88    let selected_provider = std::env::var("HARN_LLM_PROVIDER")
89        .ok()
90        .filter(|value| !value.trim().is_empty());
91    if selected_provider.as_deref() == Some(provider) {
92        if let Ok(model) = std::env::var("HARN_LLM_MODEL") {
93            if !model.trim().is_empty() {
94                let (resolved, _) = llm_config::resolve_model(model.trim());
95                return Some(resolved);
96            }
97        }
98    }
99
100    None
101}
102
103pub fn build_models_url(def: &ProviderDef) -> Result<String, String> {
104    let raw = models_healthcheck_url(def).unwrap_or_else(|| {
105        join_base_and_path(
106            &llm_config::resolve_base_url(def),
107            &model_path_from_chat_endpoint(&def.chat_endpoint),
108        )
109    });
110    validate_url(&normalize_loopback(&raw))
111}
112
113fn models_healthcheck_url(def: &ProviderDef) -> Option<String> {
114    let healthcheck = def.healthcheck.as_ref()?;
115    if !healthcheck.method.eq_ignore_ascii_case("GET") {
116        return None;
117    }
118    if let Some(url) = healthcheck.url.as_ref() {
119        return url.contains("models").then(|| url.clone());
120    }
121    let path = healthcheck.path.as_deref()?;
122    path.contains("models")
123        .then(|| join_base_and_path(&llm_config::resolve_base_url(def), path))
124}
125
126pub fn parse_model_ids(json: &serde_json::Value) -> Vec<String> {
127    if let Some(data) = json.get("data").and_then(|value| value.as_array()) {
128        return data
129            .iter()
130            .filter_map(|entry| entry.get("id").and_then(|value| value.as_str()))
131            .map(str::to_string)
132            .collect();
133    }
134
135    if let Some(models) = json.get("models").and_then(|value| value.as_array()) {
136        return models
137            .iter()
138            .filter_map(|entry| {
139                entry
140                    .get("id")
141                    .or_else(|| entry.get("name"))
142                    .and_then(|value| value.as_str())
143            })
144            .map(str::to_string)
145            .collect();
146    }
147
148    Vec::new()
149}
150
/// Returns `true` when `model` matches one of the `available` ids exactly or
/// as a prefix (so `"llama-local"` matches `"llama-local-long-id"`).
///
/// An exact match is a special case of a prefix match, so one `starts_with`
/// test covers both. Note that an empty `model` matches any id.
pub fn model_is_served(available: &[String], model: &str) -> bool {
    for id in available {
        if id.starts_with(model) {
            return true;
        }
    }
    false
}
156
157pub async fn probe_openai_compatible_model(
158    provider: &str,
159    model: &str,
160    api_key: &str,
161) -> ModelReadiness {
162    let Some(def) = llm_config::provider_config(provider) else {
163        return ModelReadiness::error(
164            provider,
165            model,
166            "unknown_provider",
167            format!("Unknown provider: {provider}"),
168            None,
169            None,
170            Vec::new(),
171        );
172    };
173
174    probe_openai_compatible_model_with_def(provider, model, api_key, &def).await
175}
176
177pub(crate) async fn probe_openai_compatible_model_with_def(
178    provider: &str,
179    model: &str,
180    api_key: &str,
181    def: &ProviderDef,
182) -> ModelReadiness {
183    let url = match build_models_url(def) {
184        Ok(url) => url,
185        Err(error) => {
186            return ModelReadiness::error(
187                provider,
188                model,
189                "invalid_url",
190                format!("Invalid OpenAI-compatible models URL for {provider}: {error}"),
191                None,
192                None,
193                Vec::new(),
194            );
195        }
196    };
197
198    let client = crate::llm::shared_utility_client();
199    let req = client
200        .get(&url)
201        .header("Content-Type", "application/json")
202        .timeout(std::time::Duration::from_secs(10));
203    let req = apply_auth_headers(req, api_key, Some(def));
204    let req = def
205        .extra_headers
206        .iter()
207        .fold(req, |req, (name, value)| req.header(name, value));
208
209    let response = match req.send().await {
210        Ok(response) => response,
211        Err(error) => {
212            return ModelReadiness::error(
213                provider,
214                model,
215                "unreachable",
216                format!("{provider} OpenAI-compatible server not reachable at {url}: {error}"),
217                Some(url),
218                None,
219                Vec::new(),
220            );
221        }
222    };
223
224    let status = response.status();
225    if !status.is_success() {
226        let body = response.text().await.unwrap_or_default();
227        return ModelReadiness::error(
228            provider,
229            model,
230            "bad_status",
231            format!(
232                "{provider} returned HTTP {} at {url}: {body}",
233                status.as_u16()
234            ),
235            Some(url),
236            Some(status.as_u16()),
237            Vec::new(),
238        );
239    }
240
241    let status_code = status.as_u16();
242    let json: serde_json::Value = match response.json().await {
243        Ok(json) => json,
244        Err(error) => {
245            return ModelReadiness::error(
246                provider,
247                model,
248                "invalid_response",
249                format!("Could not parse {provider} /models response at {url}: {error}"),
250                Some(url),
251                Some(status_code),
252                Vec::new(),
253            );
254        }
255    };
256    let available_models = parse_model_ids(&json);
257    if available_models.is_empty() {
258        return ModelReadiness::error(
259            provider,
260            model,
261            "invalid_response",
262            format!("Could not find model ids in {provider} /models response at {url}"),
263            Some(url),
264            Some(status_code),
265            available_models,
266        );
267    }
268
269    if !model_is_served(&available_models, model) {
270        let available = available_models.join(", ");
271        return ModelReadiness::error(
272            provider,
273            model,
274            "model_missing",
275            format!(
276                "Model '{model}' is not served by {provider} at {url}. Currently served: {available}"
277            ),
278            Some(url),
279            Some(status_code),
280            available_models,
281        );
282    }
283
284    ModelReadiness::ok(provider, model, &url, status_code, available_models)
285}
286
/// Derives the model-listing path from an OpenAI-style chat endpoint by
/// replacing a trailing `/chat/completions` with `/models`
/// (e.g. `/v1/chat/completions` -> `/v1/models`). Any other endpoint shape
/// falls back to plain `/models`.
fn model_path_from_chat_endpoint(chat_endpoint: &str) -> String {
    // A missing suffix and an empty prefix both collapse to "/models".
    let prefix = chat_endpoint
        .strip_suffix("/chat/completions")
        .unwrap_or("");
    format!("{prefix}/models")
}
298
/// Joins `base` and `path` with a `/` separator.
///
/// Trailing slashes on `base` are dropped, and one leading `/` on `path` is
/// reused as the separator rather than doubled. An empty `path` yields the
/// trimmed `base` alone.
fn join_base_and_path(base: &str, path: &str) -> String {
    let trimmed_base = base.trim_end_matches('/');
    if path.is_empty() {
        return trimmed_base.to_string();
    }
    let relative = path.strip_prefix('/').unwrap_or(path);
    format!("{trimmed_base}/{relative}")
}
309
/// Rewrites a `localhost` host in `url` to `127.0.0.1`.
///
/// Only the host immediately following `scheme://` is considered. It is
/// matched ASCII-case-insensitively and only when followed by a port, path,
/// query, fragment, or the end of the string. Hosts such as
/// `localhost.example`, and `localhost` appearing later in the URL (e.g. in
/// a query string), are left untouched. Inputs without `://` are returned
/// unchanged.
///
/// NOTE(review): presumably this exists because `localhost` may resolve to
/// `::1` first while local model servers often bind IPv4 only — confirm.
fn normalize_loopback(url: &str) -> String {
    const LOCALHOST: &str = "localhost";
    let Some(scheme_end) = url.find("://") else {
        return url.to_string();
    };
    let host_start = scheme_end + "://".len();
    let host_end = host_start + LOCALHOST.len();
    // `get` returns None when the range is out of bounds or not on a char
    // boundary, so the slicing below is safe once it succeeds.
    match url.get(host_start..host_end) {
        Some(host) if host.eq_ignore_ascii_case(LOCALHOST) => {
            let rest = &url[host_end..];
            let at_host_boundary = rest.is_empty()
                || rest.starts_with(|c: char| matches!(c, ':' | '/' | '?' | '#'));
            if at_host_boundary {
                format!("{}127.0.0.1{rest}", &url[..host_start])
            } else {
                url.to_string()
            }
        }
        _ => url.to_string(),
    }
}
313
314fn validate_url(url: &str) -> Result<String, String> {
315    reqwest::Url::parse(url)
316        .map(|_| url.to_string())
317        .map_err(|error| format!("{url} ({error})"))
318}
319
#[cfg(test)]
mod tests {
    use super::*;
    use crate::llm_config::{HealthcheckDef, ProviderDef};

    // Both listing shapes: `data[].id` and `models[]` with `name` or `id`.
    #[test]
    fn parses_openai_and_ollama_style_model_ids() {
        let openai = serde_json::json!({
            "data": [{"id": "qwen-alias"}, {"id": "other"}]
        });
        assert_eq!(
            parse_model_ids(&openai),
            vec!["qwen-alias".to_string(), "other".to_string()]
        );

        let models = serde_json::json!({
            "models": [{"name": "llama"}, {"id": "qwen"}]
        });
        assert_eq!(
            parse_model_ids(&models),
            vec!["llama".to_string(), "qwen".to_string()]
        );
    }

    // A requested model matches an exact id or a prefix of a served id.
    #[test]
    fn model_matching_accepts_exact_or_prefix() {
        let ids = vec![
            "qwen36".to_string(),
            "gpt-oss:20b".to_string(),
            "llama-local-long-id".to_string(),
        ];
        assert!(model_is_served(&ids, "qwen36"));
        assert!(model_is_served(&ids, "llama-local"));
        assert!(!model_is_served(&ids, "missing"));
    }

    // A GET healthcheck path mentioning "models" is used as-is, and
    // `localhost:<port>` is rewritten to 127.0.0.1.
    #[test]
    fn models_url_uses_healthcheck_path_and_loopback_normalization() {
        let def = ProviderDef {
            base_url: "http://localhost:8001".to_string(),
            chat_endpoint: "/v1/chat/completions".to_string(),
            healthcheck: Some(HealthcheckDef {
                method: "GET".to_string(),
                path: Some("/v1/models".to_string()),
                url: None,
                body: None,
            }),
            ..Default::default()
        };

        assert_eq!(
            build_models_url(&def).unwrap(),
            "http://127.0.0.1:8001/v1/models"
        );
    }

    // Without a healthcheck, `/v1/chat/completions` yields `/v1/models`.
    #[test]
    fn models_url_derives_path_from_chat_endpoint() {
        let def = ProviderDef {
            base_url: "http://127.0.0.1:8000".to_string(),
            chat_endpoint: "/v1/chat/completions".to_string(),
            healthcheck: None,
            ..Default::default()
        };

        assert_eq!(
            build_models_url(&def).unwrap(),
            "http://127.0.0.1:8000/v1/models"
        );
    }

    // A healthcheck path that does not mention "models" (e.g. `/health`) is
    // ignored in favor of the chat-endpoint-derived path.
    #[test]
    fn models_url_ignores_non_model_healthcheck_path() {
        let def = ProviderDef {
            base_url: "http://127.0.0.1:8080".to_string(),
            chat_endpoint: "/v1/chat/completions".to_string(),
            healthcheck: Some(HealthcheckDef {
                method: "GET".to_string(),
                path: Some("/health".to_string()),
                url: None,
                body: None,
            }),
            ..Default::default()
        };

        assert_eq!(
            build_models_url(&def).unwrap(),
            "http://127.0.0.1:8080/v1/models"
        );
    }

    // Happy path: the requested name is a prefix of a served id.
    #[tokio::test]
    async fn probe_reports_ready_when_model_is_served() {
        let def = test_def_with_response(200, r#"{"data":[{"id":"served-model-long"}]}"#).await;

        let result =
            probe_openai_compatible_model_with_def("local", "served-model", "", &def).await;

        assert!(result.valid);
        assert_eq!(result.category, "ok");
        assert_eq!(
            result.available_models,
            vec!["served-model-long".to_string()]
        );
    }

    // A valid listing that lacks the model reports `model_missing` and still
    // returns the served ids.
    #[tokio::test]
    async fn probe_distinguishes_model_missing() {
        let def = test_def_with_response(200, r#"{"data":[{"id":"served-model"}]}"#).await;

        let result = probe_openai_compatible_model_with_def("local", "missing", "", &def).await;

        assert!(!result.valid);
        assert_eq!(result.category, "model_missing");
        assert_eq!(result.available_models, vec!["served-model".to_string()]);
    }

    // A non-success status reports `bad_status` with the status code.
    #[tokio::test]
    async fn probe_distinguishes_bad_status() {
        let def = test_def_with_response(503, "loading").await;

        let result =
            probe_openai_compatible_model_with_def("local", "served-model", "", &def).await;

        assert!(!result.valid);
        assert_eq!(result.category, "bad_status");
        assert_eq!(result.status, Some(503));
    }

    // An unparseable base URL fails before any request is sent.
    #[tokio::test]
    async fn probe_distinguishes_invalid_url() {
        let def = ProviderDef {
            base_url: "not a url".to_string(),
            chat_endpoint: "/v1/chat/completions".to_string(),
            healthcheck: Some(HealthcheckDef {
                method: "GET".to_string(),
                path: Some("/v1/models".to_string()),
                url: None,
                body: None,
            }),
            auth_style: "none".to_string(),
            ..Default::default()
        };

        let result =
            probe_openai_compatible_model_with_def("local", "served-model", "", &def).await;

        assert!(!result.valid);
        assert_eq!(result.category, "invalid_url");
    }

    // Binds an ephemeral port and immediately releases it so (presumably)
    // nothing is listening there; the request should fail as `unreachable`.
    // NOTE(review): another process could grab the port in between — rare.
    #[tokio::test]
    async fn probe_distinguishes_unreachable() {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0")
            .await
            .expect("bind");
        let addr = listener.local_addr().expect("addr");
        drop(listener);
        let def = ProviderDef {
            base_url: format!("http://{addr}"),
            chat_endpoint: "/v1/chat/completions".to_string(),
            healthcheck: Some(HealthcheckDef {
                method: "GET".to_string(),
                path: Some("/v1/models".to_string()),
                url: None,
                body: None,
            }),
            auth_style: "none".to_string(),
            ..Default::default()
        };

        let result =
            probe_openai_compatible_model_with_def("local", "served-model", "", &def).await;

        assert!(!result.valid);
        assert_eq!(result.category, "unreachable");
    }

    /// Starts a one-shot mock HTTP server on an ephemeral loopback port and
    /// returns a `ProviderDef` pointing at it.
    ///
    /// The spawned task accepts a single connection, performs one read of up
    /// to 1024 bytes (the request is not parsed), then writes a canned
    /// response with the given `status` and JSON `body`. The reason phrase is
    /// always "OK"; the probe looks only at the status code.
    async fn test_def_with_response(status: u16, body: &'static str) -> ProviderDef {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0")
            .await
            .expect("bind");
        let addr = listener.local_addr().expect("addr");
        tokio::spawn(async move {
            let (mut socket, _) = listener.accept().await.expect("accept");
            let mut buf = [0_u8; 1024];
            let _ = tokio::io::AsyncReadExt::read(&mut socket, &mut buf).await;
            let response = format!(
                "HTTP/1.1 {status} OK\r\ncontent-type: application/json\r\ncontent-length: {}\r\n\r\n{}",
                body.len(),
                body
            );
            tokio::io::AsyncWriteExt::write_all(&mut socket, response.as_bytes())
                .await
                .expect("write");
        });

        ProviderDef {
            base_url: format!("http://{addr}"),
            chat_endpoint: "/v1/chat/completions".to_string(),
            healthcheck: Some(HealthcheckDef {
                method: "GET".to_string(),
                path: Some("/v1/models".to_string()),
                url: None,
                body: None,
            }),
            auth_style: "none".to_string(),
            ..Default::default()
        }
    }
}