1use crate::llm_config::{self, ProviderDef};
9
10use super::auth::apply_auth_headers;
11
/// Outcome of probing a provider's model listing for one specific model.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ModelReadiness {
    /// `true` only for results built by [`ModelReadiness::ok`].
    pub valid: bool,
    /// Short machine-readable label: `"ok"` on success, otherwise the
    /// category string supplied by the caller of [`ModelReadiness::error`].
    pub category: String,
    /// Human-readable description of the outcome.
    pub message: String,
    /// Provider name the probe was run for.
    pub provider: String,
    /// Model name the probe was looking for.
    pub model: String,
    /// URL that was probed, when one was available.
    pub url: Option<String>,
    /// HTTP status code of the response, when one was received.
    pub status: Option<u16>,
    /// Model ids gathered from the response (may be empty).
    pub available_models: Vec<String>,
}

impl ModelReadiness {
    /// Builds the successful result: `valid` is `true`, `category` is `"ok"`,
    /// and the message names the provider, the model and the probed URL.
    fn ok(
        provider: &str,
        model: &str,
        url: &str,
        status: u16,
        available_models: Vec<String>,
    ) -> Self {
        let message = format!("{provider} is reachable and serves model '{model}' at {url}");
        Self {
            valid: true,
            category: String::from("ok"),
            message,
            provider: provider.to_owned(),
            model: model.to_owned(),
            url: Some(url.to_owned()),
            status: Some(status),
            available_models,
        }
    }

    /// Builds a failed result (`valid` is `false`) carrying the caller's
    /// category and message; `url`, `status` and `available_models` are
    /// stored exactly as given.
    fn error(
        provider: &str,
        model: &str,
        category: &str,
        message: String,
        url: Option<String>,
        status: Option<u16>,
        available_models: Vec<String>,
    ) -> Self {
        let provider = provider.to_owned();
        let model = model.to_owned();
        let category = category.to_owned();
        Self {
            valid: false,
            category,
            message,
            provider,
            model,
            url,
            status,
            available_models,
        }
    }
}
65
66pub fn supports_model_readiness_probe(def: &ProviderDef) -> bool {
67 let healthcheck_uses_models = def.healthcheck.as_ref().is_some_and(|hc| {
68 hc.method.eq_ignore_ascii_case("GET") && {
69 hc.path
70 .as_deref()
71 .is_some_and(|path| path.contains("models"))
72 || hc.url.as_deref().is_some_and(|url| url.contains("models"))
73 }
74 });
75 healthcheck_uses_models || def.chat_endpoint.ends_with("/chat/completions")
76}
77
78pub fn selected_model_for_provider(provider: &str) -> Option<String> {
79 if provider == "local" {
80 if let Ok(model) = std::env::var("LOCAL_LLM_MODEL") {
81 if !model.trim().is_empty() {
82 let (resolved, _) = llm_config::resolve_model(model.trim());
83 return Some(resolved);
84 }
85 }
86 }
87
88 let selected_provider = std::env::var("HARN_LLM_PROVIDER")
89 .ok()
90 .filter(|value| !value.trim().is_empty());
91 if selected_provider.as_deref() == Some(provider) {
92 if let Ok(model) = std::env::var("HARN_LLM_MODEL") {
93 if !model.trim().is_empty() {
94 let (resolved, _) = llm_config::resolve_model(model.trim());
95 return Some(resolved);
96 }
97 }
98 }
99
100 None
101}
102
103pub fn build_models_url(def: &ProviderDef) -> Result<String, String> {
104 let raw = models_healthcheck_url(def).unwrap_or_else(|| {
105 join_base_and_path(
106 &llm_config::resolve_base_url(def),
107 &model_path_from_chat_endpoint(&def.chat_endpoint),
108 )
109 });
110 validate_url(&normalize_loopback(&raw))
111}
112
113fn models_healthcheck_url(def: &ProviderDef) -> Option<String> {
114 let healthcheck = def.healthcheck.as_ref()?;
115 if !healthcheck.method.eq_ignore_ascii_case("GET") {
116 return None;
117 }
118 if let Some(url) = healthcheck.url.as_ref() {
119 return url.contains("models").then(|| url.clone());
120 }
121 let path = healthcheck.path.as_deref()?;
122 path.contains("models")
123 .then(|| join_base_and_path(&llm_config::resolve_base_url(def), path))
124}
125
126pub fn parse_model_ids(json: &serde_json::Value) -> Vec<String> {
127 if let Some(entries) = json.as_array() {
128 return collect_model_ids(entries);
129 }
130
131 if let Some(data) = json.get("data").and_then(|value| value.as_array()) {
132 return collect_model_ids(data);
133 }
134
135 if let Some(models) = json.get("models").and_then(|value| value.as_array()) {
136 return collect_model_ids(models);
137 }
138
139 Vec::new()
140}
141
142fn collect_model_ids(entries: &[serde_json::Value]) -> Vec<String> {
143 entries
144 .iter()
145 .filter_map(|entry| {
146 entry.as_str().or_else(|| {
147 entry
148 .get("id")
149 .or_else(|| entry.get("name"))
150 .and_then(|value| value.as_str())
151 })
152 })
153 .map(str::to_string)
154 .collect()
155}
156
/// Reports whether `model` is among the ids a server says it serves.
///
/// A served id matches when it equals `model` or starts with it, so a short
/// name such as `llama-local` matches `llama-local-long-id`.
///
/// An empty `model` never matches. The empty string is a prefix of every id,
/// so without this guard a missing model name would be reported as served by
/// any server with a non-empty listing.
pub fn model_is_served(available: &[String], model: &str) -> bool {
    if model.is_empty() {
        return false;
    }
    // `starts_with` already covers exact equality.
    available.iter().any(|id| id.starts_with(model))
}
162
163pub async fn probe_openai_compatible_model(
164 provider: &str,
165 model: &str,
166 api_key: &str,
167) -> ModelReadiness {
168 let Some(def) = llm_config::provider_config(provider) else {
169 return ModelReadiness::error(
170 provider,
171 model,
172 "unknown_provider",
173 format!("Unknown provider: {provider}"),
174 None,
175 None,
176 Vec::new(),
177 );
178 };
179
180 probe_openai_compatible_model_with_def(provider, model, api_key, &def).await
181}
182
183pub(crate) async fn probe_openai_compatible_model_with_def(
184 provider: &str,
185 model: &str,
186 api_key: &str,
187 def: &ProviderDef,
188) -> ModelReadiness {
189 let url = match build_models_url(def) {
190 Ok(url) => url,
191 Err(error) => {
192 return ModelReadiness::error(
193 provider,
194 model,
195 "invalid_url",
196 format!("Invalid OpenAI-compatible models URL for {provider}: {error}"),
197 None,
198 None,
199 Vec::new(),
200 );
201 }
202 };
203
204 let client = crate::llm::shared_utility_client();
205 let req = client
206 .get(&url)
207 .header("Content-Type", "application/json")
208 .timeout(std::time::Duration::from_secs(10));
209 let req = apply_auth_headers(req, api_key, Some(def));
210 let req = def
211 .extra_headers
212 .iter()
213 .fold(req, |req, (name, value)| req.header(name, value));
214
215 let response = match req.send().await {
216 Ok(response) => response,
217 Err(error) => {
218 return ModelReadiness::error(
219 provider,
220 model,
221 "unreachable",
222 format!("{provider} OpenAI-compatible server not reachable at {url}: {error}"),
223 Some(url),
224 None,
225 Vec::new(),
226 );
227 }
228 };
229
230 let status = response.status();
231 if !status.is_success() {
232 let body = response.text().await.unwrap_or_default();
233 return ModelReadiness::error(
234 provider,
235 model,
236 "bad_status",
237 format!(
238 "{provider} returned HTTP {} at {url}: {body}",
239 status.as_u16()
240 ),
241 Some(url),
242 Some(status.as_u16()),
243 Vec::new(),
244 );
245 }
246
247 let status_code = status.as_u16();
248 let json: serde_json::Value = match response.json().await {
249 Ok(json) => json,
250 Err(error) => {
251 return ModelReadiness::error(
252 provider,
253 model,
254 "invalid_response",
255 format!("Could not parse {provider} /models response at {url}: {error}"),
256 Some(url),
257 Some(status_code),
258 Vec::new(),
259 );
260 }
261 };
262 let available_models = parse_model_ids(&json);
263 if available_models.is_empty() {
264 return ModelReadiness::error(
265 provider,
266 model,
267 "invalid_response",
268 format!("Could not find model ids in {provider} /models response at {url}"),
269 Some(url),
270 Some(status_code),
271 available_models,
272 );
273 }
274
275 if !model_is_served(&available_models, model) {
276 let available = available_models.join(", ");
277 return ModelReadiness::error(
278 provider,
279 model,
280 "model_missing",
281 format!(
282 "Model '{model}' is not served by {provider} at {url}. Currently served: {available}"
283 ),
284 Some(url),
285 Some(status_code),
286 available_models,
287 );
288 }
289
290 ModelReadiness::ok(provider, model, &url, status_code, available_models)
291}
292
/// Derives the model-listing path from a chat endpoint path.
///
/// A trailing `/chat/completions` is replaced by `/models`, keeping any
/// prefix (`/v1/chat/completions` becomes `/v1/models`). Endpoints that do
/// not end with `/chat/completions` map to plain `/models`.
fn model_path_from_chat_endpoint(chat_endpoint: &str) -> String {
    match chat_endpoint.strip_suffix("/chat/completions") {
        // An empty prefix formats to "/models" as well, so no special case is needed.
        Some(prefix) => format!("{prefix}/models"),
        None => "/models".to_string(),
    }
}
304
/// Joins a base URL and a path with exactly one `/` between them.
///
/// All trailing slashes are stripped from `base`. An empty `path` returns
/// the stripped base unchanged; a `path` without a leading `/` gets one.
fn join_base_and_path(base: &str, path: &str) -> String {
    let root = base.trim_end_matches('/');
    match path {
        "" => root.to_string(),
        rooted if rooted.starts_with('/') => format!("{root}{rooted}"),
        relative => format!("{root}/{relative}"),
    }
}
315
/// Rewrites a URL whose host is `localhost` to use `127.0.0.1` instead.
///
/// The host is the text right after the first `://`. It is rewritten only
/// when it is exactly `localhost`, meaning it is followed by `:`, `/`, `?`,
/// `#`, or the end of the string. This covers URLs with and without an
/// explicit port (`http://localhost:8001/...` and `http://localhost/...`).
///
/// Any other input is returned unchanged, including hosts that merely start
/// with `localhost` (such as `localhost.example.com`) and `localhost` URLs
/// that appear later in the string, for example inside a query parameter.
fn normalize_loopback(url: &str) -> String {
    const LOCALHOST: &str = "localhost";
    const LOOPBACK_V4: &str = "127.0.0.1";

    let rewritten = url.split_once("://").and_then(|(scheme, rest)| {
        let after_host = rest.strip_prefix(LOCALHOST)?;
        // Make sure the host really ends here and is not just a longer name.
        let host_ends_here = matches!(
            after_host.chars().next(),
            None | Some(':' | '/' | '?' | '#')
        );
        host_ends_here.then(|| format!("{scheme}://{LOOPBACK_V4}{after_host}"))
    });

    rewritten.unwrap_or_else(|| url.to_string())
}
319
320fn validate_url(url: &str) -> Result<String, String> {
321 reqwest::Url::parse(url)
322 .map(|_| url.to_string())
323 .map_err(|error| format!("{url} ({error})"))
324}
325
#[cfg(test)]
mod tests {
    use super::*;
    use crate::llm_config::{HealthcheckDef, ProviderDef};

    /// Builds a `GET` healthcheck that targets `path` and has no URL or body.
    fn get_healthcheck(path: &str) -> HealthcheckDef {
        HealthcheckDef {
            method: "GET".to_string(),
            path: Some(path.to_string()),
            url: None,
            body: None,
        }
    }

    /// Builds a provider with the `/v1/chat/completions` endpoint and the
    /// given healthcheck; every other field keeps its default.
    fn provider_def(base_url: &str, healthcheck: Option<HealthcheckDef>) -> ProviderDef {
        ProviderDef {
            base_url: base_url.to_string(),
            chat_endpoint: "/v1/chat/completions".to_string(),
            healthcheck,
            ..Default::default()
        }
    }

    /// Builds a provider with a `/v1/models` healthcheck and `auth_style` set
    /// to `"none"`, as used by the probe tests.
    fn unauthenticated_def(base_url: &str) -> ProviderDef {
        ProviderDef {
            auth_style: "none".to_string(),
            ..provider_def(base_url, Some(get_healthcheck("/v1/models")))
        }
    }

    /// Converts string literals into the owned ids the assertions compare against.
    fn owned(ids: &[&str]) -> Vec<String> {
        ids.iter().map(|id| id.to_string()).collect()
    }

    #[test]
    fn parses_openai_and_ollama_style_model_ids() {
        // Object with a "data" array.
        let openai = serde_json::json!({
            "data": [{"id": "qwen-alias"}, {"id": "other"}]
        });
        assert_eq!(parse_model_ids(&openai), owned(&["qwen-alias", "other"]));

        // Object with a "models" array mixing "name" and "id" entries.
        let models = serde_json::json!({
            "models": [{"name": "llama"}, {"id": "qwen"}]
        });
        assert_eq!(parse_model_ids(&models), owned(&["llama", "qwen"]));

        // Bare array mixing objects and a plain string.
        let top_level = serde_json::json!([
            {"id": "deepseek-ai/DeepSeek-V4-Pro"},
            {"name": "qwen"},
            "string-model"
        ]);
        assert_eq!(
            parse_model_ids(&top_level),
            owned(&["deepseek-ai/DeepSeek-V4-Pro", "qwen", "string-model"])
        );
    }

    #[test]
    fn model_matching_accepts_exact_or_prefix() {
        let ids = owned(&["qwen36", "gpt-oss:20b", "llama-local-long-id"]);

        assert!(model_is_served(&ids, "qwen36"));
        assert!(model_is_served(&ids, "llama-local"));
        assert!(!model_is_served(&ids, "missing"));
    }

    #[test]
    fn models_url_uses_healthcheck_path_and_loopback_normalization() {
        let def = provider_def(
            "http://localhost:8001",
            Some(get_healthcheck("/v1/models")),
        );

        let url = build_models_url(&def).unwrap();

        assert_eq!(url, "http://127.0.0.1:8001/v1/models");
    }

    #[test]
    fn models_url_derives_path_from_chat_endpoint() {
        let def = provider_def("http://127.0.0.1:8000", None);

        let url = build_models_url(&def).unwrap();

        assert_eq!(url, "http://127.0.0.1:8000/v1/models");
    }

    #[test]
    fn models_url_ignores_non_model_healthcheck_path() {
        let def = provider_def("http://127.0.0.1:8080", Some(get_healthcheck("/health")));

        let url = build_models_url(&def).unwrap();

        assert_eq!(url, "http://127.0.0.1:8080/v1/models");
    }

    #[tokio::test]
    async fn probe_reports_ready_when_model_is_served() {
        let def = test_def_with_response(200, r#"{"data":[{"id":"served-model-long"}]}"#).await;

        let result =
            probe_openai_compatible_model_with_def("local", "served-model", "", &def).await;

        assert!(result.valid);
        assert_eq!(result.category, "ok");
        assert_eq!(result.available_models, owned(&["served-model-long"]));
    }

    #[tokio::test]
    async fn probe_distinguishes_model_missing() {
        let def = test_def_with_response(200, r#"{"data":[{"id":"served-model"}]}"#).await;

        let result = probe_openai_compatible_model_with_def("local", "missing", "", &def).await;

        assert!(!result.valid);
        assert_eq!(result.category, "model_missing");
        assert_eq!(result.available_models, owned(&["served-model"]));
    }

    #[tokio::test]
    async fn probe_distinguishes_bad_status() {
        let def = test_def_with_response(503, "loading").await;

        let result =
            probe_openai_compatible_model_with_def("local", "served-model", "", &def).await;

        assert!(!result.valid);
        assert_eq!(result.category, "bad_status");
        assert_eq!(result.status, Some(503));
    }

    #[tokio::test]
    async fn probe_distinguishes_invalid_url() {
        let def = unauthenticated_def("not a url");

        let result =
            probe_openai_compatible_model_with_def("local", "served-model", "", &def).await;

        assert!(!result.valid);
        assert_eq!(result.category, "invalid_url");
    }

    #[tokio::test]
    async fn probe_distinguishes_unreachable() {
        // NOTE(review): relies on nothing listening on 127.0.0.1:1 - confirm for CI hosts.
        let def = unauthenticated_def("http://127.0.0.1:1");

        let result =
            probe_openai_compatible_model_with_def("local", "served-model", "", &def).await;

        assert!(!result.valid);
        assert_eq!(result.category, "unreachable");
    }

    /// Starts a one-shot HTTP server on an ephemeral loopback port and returns
    /// a provider definition pointing at it.
    ///
    /// The server accepts a single connection, reads up to 1024 bytes of the
    /// request, and answers with `status` and `body`. The reason phrase in the
    /// status line is always `OK`, whatever the status code.
    async fn test_def_with_response(status: u16, body: &'static str) -> ProviderDef {
        use tokio::io::{AsyncReadExt, AsyncWriteExt};

        let listener = tokio::net::TcpListener::bind("127.0.0.1:0")
            .await
            .expect("bind");
        let addr = listener.local_addr().expect("addr");

        tokio::spawn(async move {
            let (mut socket, _) = listener.accept().await.expect("accept");

            // The request content is irrelevant; read it so the client can finish sending.
            let mut request = [0_u8; 1024];
            let _ = socket.read(&mut request).await;

            let response = format!(
                "HTTP/1.1 {status} OK\r\ncontent-type: application/json\r\ncontent-length: {}\r\n\r\n{}",
                body.len(),
                body
            );
            socket
                .write_all(response.as_bytes())
                .await
                .expect("write");
        });

        unauthenticated_def(&format!("http://{addr}"))
    }
}