//! Readiness probing for OpenAI-compatible model endpoints: builds a
//! provider's models-listing URL, queries it, and reports whether the
//! selected model is currently served.

use crate::llm_config::{self, ProviderDef};

use super::auth::apply_auth_headers;

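/// Result of probing a provider's models endpoint for a specific model.
///
/// `valid` mirrors `category == "ok"`; `category` is a machine-readable
/// discriminator (`"ok"`, `"unknown_provider"`, `"invalid_url"`,
/// `"unreachable"`, `"bad_status"`, `"invalid_response"`, or
/// `"model_missing"`), and `message` carries the human-readable explanation.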
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ModelReadiness {
    pub valid: bool,
    pub category: String,
    pub message: String,
    pub provider: String,
    pub model: String,
    pub url: Option<String>,
    pub status: Option<u16>,
    pub available_models: Vec<String>,
}

impl ModelReadiness {
    fn ok(
        provider: &str,
        model: &str,
        url: &str,
        status: u16,
        available_models: Vec<String>,
    ) -> Self {
        Self {
            valid: true,
            category: "ok".to_string(),
            message: format!("{provider} is reachable and serves model '{model}' at {url}"),
            provider: provider.to_string(),
            model: model.to_string(),
            url: Some(url.to_string()),
            status: Some(status),
            available_models,
        }
    }

    fn error(
        provider: &str,
        model: &str,
        category: &str,
        message: String,
        url: Option<String>,
        status: Option<u16>,
        available_models: Vec<String>,
    ) -> Self {
        Self {
            valid: false,
            category: category.to_string(),
            message,
            provider: provider.to_string(),
            model: model.to_string(),
            url,
            status,
            available_models,
        }
    }
}

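/// Reports whether the provider exposes a model listing we can probe: either
/// its healthcheck is a GET against something mentioning `models`, or its
/// chat endpoint follows the OpenAI `/chat/completions` convention.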
pub fn supports_model_readiness_probe(def: &ProviderDef) -> bool {
    let healthcheck_uses_models = def.healthcheck.as_ref().is_some_and(|hc| {
        hc.method.eq_ignore_ascii_case("GET")
            && (hc
                .path
                .as_deref()
                .is_some_and(|path| path.contains("models"))
                || hc.url.as_deref().is_some_and(|url| url.contains("models")))
    });
    healthcheck_uses_models || def.chat_endpoint.ends_with("/chat/completions")
}

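/// Resolves the model currently selected for `provider` from the environment:
/// `LOCAL_LLM_MODEL` for the `local` provider, otherwise `HARN_LLM_MODEL`
/// when `HARN_LLM_PROVIDER` names this provider. Values are trimmed and
/// expanded through `llm_config::resolve_model`; returns `None` when nothing
/// is selected.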
pub fn selected_model_for_provider(provider: &str) -> Option<String> {
    if provider == "local" {
        if let Ok(model) = std::env::var("LOCAL_LLM_MODEL") {
            if !model.trim().is_empty() {
                let (resolved, _) = llm_config::resolve_model(model.trim());
                return Some(resolved);
            }
        }
    }

    let selected_provider = std::env::var("HARN_LLM_PROVIDER")
        .ok()
        .filter(|value| !value.trim().is_empty());
    if selected_provider.as_deref() == Some(provider) {
        if let Ok(model) = std::env::var("HARN_LLM_MODEL") {
            if !model.trim().is_empty() {
                let (resolved, _) = llm_config::resolve_model(model.trim());
                return Some(resolved);
            }
        }
    }

    None
}

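/// Builds the models-listing URL for a provider, preferring a models-shaped
/// healthcheck URL and otherwise deriving one from the chat endpoint, then
/// normalizes `localhost` to `127.0.0.1` and validates the result.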
pub fn build_models_url(def: &ProviderDef) -> Result<String, String> {
    let raw = models_healthcheck_url(def).unwrap_or_else(|| {
        join_base_and_path(
            &llm_config::resolve_base_url(def),
            &model_path_from_chat_endpoint(&def.chat_endpoint),
        )
    });
    validate_url(&normalize_loopback(&raw))
}

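/// Returns the healthcheck's models URL, if usable: the healthcheck must be a
/// GET and its `url` (preferred) or `path` must mention `models`. Otherwise
/// returns `None`, and the caller derives a URL from the chat endpoint.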
fn models_healthcheck_url(def: &ProviderDef) -> Option<String> {
    let healthcheck = def.healthcheck.as_ref()?;
    if !healthcheck.method.eq_ignore_ascii_case("GET") {
        return None;
    }
    if let Some(url) = healthcheck.url.as_ref() {
        return url.contains("models").then(|| url.clone());
    }
    let path = healthcheck.path.as_deref()?;
    path.contains("models")
        .then(|| join_base_and_path(&llm_config::resolve_base_url(def), path))
}

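/// Extracts model ids from a models-listing response, supporting the
/// OpenAI-style shape (`data[].id`) and the Ollama-style shape
/// (`models[].id` or `models[].name`); any other shape yields an empty list.
/// Both of these parse to `["qwen"]`:
///
/// ```json
/// {"data": [{"id": "qwen"}]}
/// {"models": [{"name": "qwen"}]}
/// ```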
pub fn parse_model_ids(json: &serde_json::Value) -> Vec<String> {
    if let Some(data) = json.get("data").and_then(|value| value.as_array()) {
        return data
            .iter()
            .filter_map(|entry| entry.get("id").and_then(|value| value.as_str()))
            .map(str::to_string)
            .collect();
    }

    if let Some(models) = json.get("models").and_then(|value| value.as_array()) {
        return models
            .iter()
            .filter_map(|entry| {
                entry
                    .get("id")
                    .or_else(|| entry.get("name"))
                    .and_then(|value| value.as_str())
            })
            .map(str::to_string)
            .collect();
    }

    Vec::new()
}

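/// Returns `true` when `model` matches a served id exactly or as a prefix, so
/// a request for `llama-local` matches a served `llama-local-long-id`
/// (servers often advertise ids with extra tags appended, e.g. `gpt-oss:20b`).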
pub fn model_is_served(available: &[String], model: &str) -> bool {
    available
        .iter()
        .any(|id| id == model || id.starts_with(model))
}

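/// Probes `provider`'s OpenAI-compatible models endpoint and reports whether
/// it currently serves `model`. Unknown provider names short-circuit with the
/// `unknown_provider` category.
///
/// A sketch of a typical call site, assuming an async context:
///
/// ```ignore
/// let readiness = probe_openai_compatible_model("local", "qwen", "").await;
/// if !readiness.valid {
///     eprintln!("model not ready ({}): {}", readiness.category, readiness.message);
/// }
/// ```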
pub async fn probe_openai_compatible_model(
    provider: &str,
    model: &str,
    api_key: &str,
) -> ModelReadiness {
    let Some(def) = llm_config::provider_config(provider) else {
        return ModelReadiness::error(
            provider,
            model,
            "unknown_provider",
            format!("Unknown provider: {provider}"),
            None,
            None,
            Vec::new(),
        );
    };

    probe_openai_compatible_model_with_def(provider, model, api_key, &def).await
}

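/// Runs the readiness probe against an explicit [`ProviderDef`]: builds the
/// models URL, issues an authenticated GET with a 10-second timeout, and maps
/// each failure mode to a distinct category (`invalid_url`, `unreachable`,
/// `bad_status`, `invalid_response`, or `model_missing`).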
pub(crate) async fn probe_openai_compatible_model_with_def(
    provider: &str,
    model: &str,
    api_key: &str,
    def: &ProviderDef,
) -> ModelReadiness {
    let url = match build_models_url(def) {
        Ok(url) => url,
        Err(error) => {
            return ModelReadiness::error(
                provider,
                model,
                "invalid_url",
                format!("Invalid OpenAI-compatible models URL for {provider}: {error}"),
                None,
                None,
                Vec::new(),
            );
        }
    };

    let client = crate::llm::shared_utility_client();
    let req = client
        .get(&url)
        .header("Content-Type", "application/json")
        .timeout(std::time::Duration::from_secs(10));
    let req = apply_auth_headers(req, api_key, Some(def));
    let req = def
        .extra_headers
        .iter()
        .fold(req, |req, (name, value)| req.header(name, value));

    let response = match req.send().await {
        Ok(response) => response,
        Err(error) => {
            return ModelReadiness::error(
                provider,
                model,
                "unreachable",
                format!("{provider} OpenAI-compatible server not reachable at {url}: {error}"),
                Some(url),
                None,
                Vec::new(),
            );
        }
    };

    let status = response.status();
    if !status.is_success() {
        let body = response.text().await.unwrap_or_default();
        return ModelReadiness::error(
            provider,
            model,
            "bad_status",
            format!(
                "{provider} returned HTTP {} at {url}: {body}",
                status.as_u16()
            ),
            Some(url),
            Some(status.as_u16()),
            Vec::new(),
        );
    }

    let status_code = status.as_u16();
    let json: serde_json::Value = match response.json().await {
        Ok(json) => json,
        Err(error) => {
            return ModelReadiness::error(
                provider,
                model,
                "invalid_response",
                format!("Could not parse {provider} /models response at {url}: {error}"),
                Some(url),
                Some(status_code),
                Vec::new(),
            );
        }
    };
    let available_models = parse_model_ids(&json);
    if available_models.is_empty() {
        return ModelReadiness::error(
            provider,
            model,
            "invalid_response",
            format!("Could not find model ids in {provider} /models response at {url}"),
            Some(url),
            Some(status_code),
            available_models,
        );
    }

    if !model_is_served(&available_models, model) {
        let available = available_models.join(", ");
        return ModelReadiness::error(
            provider,
            model,
            "model_missing",
            format!(
                "Model '{model}' is not served by {provider} at {url}. Currently served: {available}"
            ),
            Some(url),
            Some(status_code),
            available_models,
        );
    }

    ModelReadiness::ok(provider, model, &url, status_code, available_models)
}

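/// Derives a models path from a chat endpoint by swapping the trailing
/// `/chat/completions` for `/models` (so `/v1/chat/completions` becomes
/// `/v1/models`); anything else falls back to `/models`.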
fn model_path_from_chat_endpoint(chat_endpoint: &str) -> String {
    if let Some(prefix) = chat_endpoint.strip_suffix("/chat/completions") {
        if prefix.is_empty() {
            "/models".to_string()
        } else {
            format!("{prefix}/models")
        }
    } else {
        "/models".to_string()
    }
}

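/// Joins a base URL and a path with exactly one `/` between them; an empty
/// path returns the trimmed base unchanged.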
fn join_base_and_path(base: &str, path: &str) -> String {
    let base = base.trim_end_matches('/');
    if path.is_empty() {
        base.to_string()
    } else if path.starts_with('/') {
        format!("{base}{path}")
    } else {
        format!("{base}/{path}")
    }
}

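/// Rewrites a `localhost` host with an explicit port to `127.0.0.1`, pinning
/// probes to the IPv4 loopback even on hosts where `localhost` resolves to
/// `::1` first.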
fn normalize_loopback(url: &str) -> String {
    url.replace("://localhost:", "://127.0.0.1:")
}

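/// Parses `url` with `reqwest::Url` purely as validation, returning the
/// original string on success and a `{url} ({error})` message on failure.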
fn validate_url(url: &str) -> Result<String, String> {
    reqwest::Url::parse(url)
        .map(|_| url.to_string())
        .map_err(|error| format!("{url} ({error})"))
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::llm_config::{HealthcheckDef, ProviderDef};

    #[test]
    fn parses_openai_and_ollama_style_model_ids() {
        let openai = serde_json::json!({
            "data": [{"id": "qwen-alias"}, {"id": "other"}]
        });
        assert_eq!(
            parse_model_ids(&openai),
            vec!["qwen-alias".to_string(), "other".to_string()]
        );

        let models = serde_json::json!({
            "models": [{"name": "llama"}, {"id": "qwen"}]
        });
        assert_eq!(
            parse_model_ids(&models),
            vec!["llama".to_string(), "qwen".to_string()]
        );
    }

    #[test]
    fn model_matching_accepts_exact_or_prefix() {
        let ids = vec![
            "qwen36".to_string(),
            "gpt-oss:20b".to_string(),
            "llama-local-long-id".to_string(),
        ];
        assert!(model_is_served(&ids, "qwen36"));
        assert!(model_is_served(&ids, "llama-local"));
        assert!(!model_is_served(&ids, "missing"));
    }

    #[test]
    fn models_url_uses_healthcheck_path_and_loopback_normalization() {
        let def = ProviderDef {
            base_url: "http://localhost:8001".to_string(),
            chat_endpoint: "/v1/chat/completions".to_string(),
            healthcheck: Some(HealthcheckDef {
                method: "GET".to_string(),
                path: Some("/v1/models".to_string()),
                url: None,
                body: None,
            }),
            ..Default::default()
        };

        assert_eq!(
            build_models_url(&def).unwrap(),
            "http://127.0.0.1:8001/v1/models"
        );
    }

    #[test]
    fn models_url_derives_path_from_chat_endpoint() {
        let def = ProviderDef {
            base_url: "http://127.0.0.1:8000".to_string(),
            chat_endpoint: "/v1/chat/completions".to_string(),
            healthcheck: None,
            ..Default::default()
        };

        assert_eq!(
            build_models_url(&def).unwrap(),
            "http://127.0.0.1:8000/v1/models"
        );
    }

    #[test]
    fn models_url_ignores_non_model_healthcheck_path() {
        let def = ProviderDef {
            base_url: "http://127.0.0.1:8080".to_string(),
            chat_endpoint: "/v1/chat/completions".to_string(),
            healthcheck: Some(HealthcheckDef {
                method: "GET".to_string(),
                path: Some("/health".to_string()),
                url: None,
                body: None,
            }),
            ..Default::default()
        };

        assert_eq!(
            build_models_url(&def).unwrap(),
            "http://127.0.0.1:8080/v1/models"
        );
    }

    #[tokio::test]
    async fn probe_reports_ready_when_model_is_served() {
        let def = test_def_with_response(200, r#"{"data":[{"id":"served-model-long"}]}"#).await;

        let result =
            probe_openai_compatible_model_with_def("local", "served-model", "", &def).await;

        assert!(result.valid);
        assert_eq!(result.category, "ok");
        assert_eq!(
            result.available_models,
            vec!["served-model-long".to_string()]
        );
    }

    #[tokio::test]
    async fn probe_distinguishes_model_missing() {
        let def = test_def_with_response(200, r#"{"data":[{"id":"served-model"}]}"#).await;

        let result = probe_openai_compatible_model_with_def("local", "missing", "", &def).await;

        assert!(!result.valid);
        assert_eq!(result.category, "model_missing");
        assert_eq!(result.available_models, vec!["served-model".to_string()]);
    }

    #[tokio::test]
    async fn probe_distinguishes_bad_status() {
        let def = test_def_with_response(503, "loading").await;

        let result =
            probe_openai_compatible_model_with_def("local", "served-model", "", &def).await;

        assert!(!result.valid);
        assert_eq!(result.category, "bad_status");
        assert_eq!(result.status, Some(503));
    }

    #[tokio::test]
    async fn probe_distinguishes_invalid_url() {
        let def = ProviderDef {
            base_url: "not a url".to_string(),
            chat_endpoint: "/v1/chat/completions".to_string(),
            healthcheck: Some(HealthcheckDef {
                method: "GET".to_string(),
                path: Some("/v1/models".to_string()),
                url: None,
                body: None,
            }),
            auth_style: "none".to_string(),
            ..Default::default()
        };

        let result =
            probe_openai_compatible_model_with_def("local", "served-model", "", &def).await;

        assert!(!result.valid);
        assert_eq!(result.category, "invalid_url");
    }

    #[tokio::test]
    async fn probe_distinguishes_unreachable() {
        let def = ProviderDef {
            base_url: "http://127.0.0.1:1".to_string(),
            chat_endpoint: "/v1/chat/completions".to_string(),
            healthcheck: Some(HealthcheckDef {
                method: "GET".to_string(),
                path: Some("/v1/models".to_string()),
                url: None,
                body: None,
            }),
            auth_style: "none".to_string(),
            ..Default::default()
        };

        let result =
            probe_openai_compatible_model_with_def("local", "served-model", "", &def).await;

        assert!(!result.valid);
        assert_eq!(result.category, "unreachable");
    }

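    /// Spins up a one-shot HTTP server on an ephemeral loopback port that
    /// answers a single request with the given status and body, returning a
    /// `ProviderDef` that points the probe at it.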
    async fn test_def_with_response(status: u16, body: &'static str) -> ProviderDef {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0")
            .await
            .expect("bind");
        let addr = listener.local_addr().expect("addr");
        tokio::spawn(async move {
            let (mut socket, _) = listener.accept().await.expect("accept");
            let mut buf = [0_u8; 1024];
            let _ = tokio::io::AsyncReadExt::read(&mut socket, &mut buf).await;
            let response = format!(
                "HTTP/1.1 {status} OK\r\ncontent-type: application/json\r\ncontent-length: {}\r\n\r\n{}",
                body.len(),
                body
            );
            tokio::io::AsyncWriteExt::write_all(&mut socket, response.as_bytes())
                .await
                .expect("write");
        });

        ProviderDef {
            base_url: format!("http://{addr}"),
            chat_endpoint: "/v1/chat/completions".to_string(),
            healthcheck: Some(HealthcheckDef {
                method: "GET".to_string(),
                path: Some("/v1/models".to_string()),
                url: None,
                body: None,
            }),
            auth_style: "none".to_string(),
            ..Default::default()
        }
    }
}