1use std::{collections::HashMap, env};
2
3use anyhow::{anyhow, Result};
4use log::{debug, info};
5use serde_json::Value;
6
7use crate::{
8 client::HttpClient,
9 models::{ExecutionProvider, FoundryListResponseModel, FoundryModelInfo},
10 service::{check_foundry_installed, get_service_uri, start_service},
11};
12
/// Manages a local Foundry service: service discovery/startup, the HTTP
/// client bound to it, and lazily cached views of the model catalog.
pub struct FoundryLocalManager {
    // Base URI of the running service; `None` until started or discovered.
    service_uri: Option<String>,
    // HTTP client bound to `service_uri`; rebuilt whenever the URI changes.
    client: Option<HttpClient>,
    // Catalog as returned by the service; `None` until first fetch or after `refresh_catalog`.
    catalog_list: Option<Vec<FoundryModelInfo>>,
    // Lookup keyed by both model id and alias; each alias maps to the
    // candidate with the preferred execution provider.
    catalog_dict: Option<HashMap<String, FoundryModelInfo>>,
    // Request timeout in seconds forwarded to `HttpClient::new`; `None` = client default.
    timeout: Option<u64>,
}
21
/// Builder for [`FoundryLocalManager`]; configure options, then call `build()`.
pub struct FoundryLocalManagerBuilder {
    // Model (alias or id) to download and load during `build()` when `bootstrap` is true.
    alias_or_model_id: Option<String>,
    // When true, `build()` starts the service and prepares `alias_or_model_id`.
    bootstrap: bool,
    // Request timeout in seconds forwarded to the manager's HTTP client.
    timeout_secs: Option<u64>,
}
28
29impl FoundryLocalManagerBuilder {
30 pub fn new() -> Self {
32 Self {
33 alias_or_model_id: None,
34 bootstrap: false,
35 timeout_secs: None,
36 }
37 }
38
39 pub fn alias_or_model_id(mut self, alias_or_model_id: impl Into<String>) -> Self {
41 self.alias_or_model_id = Some(alias_or_model_id.into());
42 self
43 }
44
45 pub fn bootstrap(mut self, bootstrap: bool) -> Self {
47 self.bootstrap = bootstrap;
48 self
49 }
50
51 pub fn timeout_secs(mut self, timeout_secs: u64) -> Self {
53 self.timeout_secs = Some(timeout_secs);
54 self
55 }
56
57 pub async fn build(self) -> Result<FoundryLocalManager> {
59 check_foundry_installed()?;
60
61 let mut manager = FoundryLocalManager {
62 service_uri: None,
63 client: None,
64 catalog_list: None,
65 catalog_dict: None,
66 timeout: self.timeout_secs,
67 };
68
69 if let Some(uri) = get_service_uri() {
70 manager.set_service_uri_and_client(Some(uri));
71 }
72
73 if self.bootstrap {
74 manager.start_service()?;
75
76 if let Some(model) = self.alias_or_model_id {
77 manager.download_model(&model, None, false).await?;
78 manager.load_model(&model, Some(600)).await?;
79 }
80 }
81
82 Ok(manager)
83 }
84}
85
86impl FoundryLocalManager {
87 pub fn builder() -> FoundryLocalManagerBuilder {
89 FoundryLocalManagerBuilder::new()
90 }
91
92 fn set_service_uri_and_client(&mut self, service_uri: Option<String>) {
93 self.service_uri = service_uri.clone();
94 self.client = service_uri.map(|uri| HttpClient::new(&uri, self.timeout));
95 }
96
97 pub fn service_uri(&self) -> Result<&str> {
103 self.service_uri
104 .as_deref()
105 .ok_or_else(|| anyhow!("Service URI is not set. Please start the service first."))
106 }
107
108 fn client(&self) -> Result<&HttpClient> {
114 self.client
115 .as_ref()
116 .ok_or_else(|| anyhow!("HTTP client is not set. Please start the service first."))
117 }
118
119 pub fn endpoint(&self) -> Result<String> {
125 Ok(format!("{}/v1", self.service_uri()?))
126 }
127
128 pub fn api_key(&self) -> String {
134 env::var("OPENAI_API_KEY").unwrap_or_else(|_| "OPENAI_API_KEY".to_string())
135 }
136
137 pub fn is_service_running(&mut self) -> bool {
145 if let Some(uri) = get_service_uri() {
146 self.set_service_uri_and_client(Some(uri));
147 true
148 } else {
149 false
150 }
151 }
152
153 pub fn start_service(&mut self) -> Result<()> {
159 let uri = start_service()?;
160 self.set_service_uri_and_client(Some(uri));
161 Ok(())
162 }
163
164 pub async fn list_catalog_models(&mut self) -> Result<&Vec<FoundryModelInfo>> {
172 if self.catalog_list.is_none() {
173 let models: Vec<FoundryListResponseModel> = self
174 .client()?
175 .get("/foundry/list", None)
176 .await?
177 .ok_or_else(|| anyhow!("Failed to get catalog models"))?;
178
179 self.catalog_list = Some(
180 models
181 .iter()
182 .map(FoundryModelInfo::from_list_response)
183 .collect(),
184 );
185 }
186
187 Ok(self.catalog_list.as_ref().unwrap())
188 }
189
190 async fn get_catalog_dict(&mut self) -> Result<&HashMap<String, FoundryModelInfo>> {
197 if self.catalog_dict.is_some() {
198 return Ok(self.catalog_dict.as_ref().unwrap());
199 }
200
201 let catalog_models = self.list_catalog_models().await?;
202 let mut catalog_dict = HashMap::new();
203 let mut alias_candidates: HashMap<String, Vec<&FoundryModelInfo>> = HashMap::new();
204
205 for model in catalog_models {
207 catalog_dict.insert(model.id.clone(), model.clone());
208 }
209
210 for model in catalog_models {
212 alias_candidates
213 .entry(model.alias.clone())
214 .or_default()
215 .push(model);
216 }
217
218 let mut preferred_order = vec![
220 ExecutionProvider::QNN,
221 ExecutionProvider::CUDA,
222 ExecutionProvider::CPU,
223 ExecutionProvider::WebGPU,
224 ];
225
226 if cfg!(not(target_os = "windows")) {
227 preferred_order.retain(|p| !matches!(p, ExecutionProvider::CPU));
229 preferred_order.push(ExecutionProvider::CPU);
230 }
231
232 let priority_map: HashMap<_, _> = preferred_order
233 .into_iter()
234 .enumerate()
235 .map(|(i, provider)| (provider, i))
236 .collect();
237
238 for (alias, candidates) in alias_candidates {
240 if let Some(preferred) = candidates.into_iter().min_by_key(|model| {
241 priority_map
242 .get(&model.runtime)
243 .copied()
244 .unwrap_or(usize::MAX)
245 }) {
246 catalog_dict.insert(alias, preferred.clone());
247 }
248 }
249
250 self.catalog_dict = Some(catalog_dict);
251 Ok(self.catalog_dict.as_ref().unwrap())
252 }
253
254 pub fn refresh_catalog(&mut self) {
256 self.catalog_list = None;
257 self.catalog_dict = None;
258 }
259
260 pub async fn get_model_info(
271 &mut self,
272 alias_or_model_id: &str,
273 raise_on_not_found: bool,
274 ) -> Result<FoundryModelInfo> {
275 let catalog_dict = self.get_catalog_dict().await?;
276
277 match catalog_dict.get(alias_or_model_id) {
278 Some(model) => Ok(model.clone()),
279 None if raise_on_not_found => Err(anyhow!(
280 "Model {} not found in the catalog",
281 alias_or_model_id
282 )),
283 None => Err(anyhow!(
284 "Model {} not found in the catalog",
285 alias_or_model_id
286 )),
287 }
288 }
289
290 pub async fn get_cache_location(&self) -> Result<String> {
298 let response: Value = self
299 .client()?
300 .get("/foundry/cache", None)
301 .await?
302 .ok_or_else(|| anyhow!("Failed to get cache location"))?;
303
304 response["location"]
305 .as_str()
306 .map(|s| s.to_string())
307 .ok_or_else(|| anyhow!("Invalid cache location response"))
308 }
309
310 pub async fn list_cached_models(&mut self) -> Result<Vec<FoundryModelInfo>> {
316 let response: Value = self
317 .client()?
318 .get("/openai/models", None)
319 .await?
320 .ok_or_else(|| anyhow!("Failed to list cached models - no response"))?;
321
322 let model_ids = if response.is_array() {
324 response
325 .as_array()
326 .ok_or_else(|| anyhow!("Invalid models response - expected array"))?
327 .iter()
328 .filter_map(|v| v.as_str())
329 .map(|s| s.to_string())
330 .collect::<Vec<_>>()
331 } else {
332 response["models"]
333 .as_array()
334 .ok_or_else(|| anyhow!("Invalid models response - expected models field"))?
335 .iter()
336 .filter_map(|v| v.as_str())
337 .map(|s| s.to_string())
338 .collect::<Vec<_>>()
339 };
340
341 self.fetch_model_infos(&model_ids).await
342 }
343
344 async fn fetch_model_infos(&mut self, model_ids: &[String]) -> Result<Vec<FoundryModelInfo>> {
345 let mut results = Vec::new();
346 let catalog_dict = self.get_catalog_dict().await?;
347
348 for id in model_ids {
349 if let Some(model) = catalog_dict.get(id) {
350 results.push(model.clone());
351 } else {
352 debug!("Model {id} not found in the catalog");
353 }
354 }
355
356 Ok(results)
357 }
358
359 pub async fn download_model(
371 &mut self,
372 alias_or_model_id: &str,
373 token: Option<&str>,
374 force: bool,
375 ) -> Result<FoundryModelInfo> {
376 let model_info = self.get_model_info(alias_or_model_id, true).await?;
377 info!(
378 "Downloading model: {} ({}) - {} MB",
379 model_info.alias, model_info.id, model_info.file_size_mb
380 );
381
382 let mut body = model_info.to_download_body();
383
384 if let Some(t) = token {
385 body["token"] = Value::String(t.to_string());
386 }
387
388 if force {
389 body["Force"] = Value::Bool(true);
390 }
391
392 let client = self.client()?;
393 let _response: Value = client
394 .post_with_progress("/openai/download", Some(body))
395 .await?;
396
397 Ok(model_info)
398 }
399
400 pub async fn load_model(
411 &mut self,
412 alias_or_model_id: &str,
413 ttl: Option<i32>,
414 ) -> Result<FoundryModelInfo> {
415 let model_info = self.get_model_info(alias_or_model_id, true).await?;
416 info!("Loading model: {} ({})", model_info.alias, model_info.id);
417
418 let url = format!("/openai/load/{}", model_info.id);
419 let ttl_str = ttl.unwrap_or(600).to_string();
420 let mut query_params = vec![("ttl", ttl_str.as_str())];
421
422 let ep_str = if matches!(
424 model_info.runtime,
425 ExecutionProvider::WebGPU | ExecutionProvider::CUDA
426 ) {
427 let has_cuda_support = self
428 .list_catalog_models()
429 .await?
430 .iter()
431 .any(|mi| mi.runtime == ExecutionProvider::CUDA);
432
433 if has_cuda_support {
434 ExecutionProvider::CUDA.get_alias().to_string()
435 } else {
436 model_info.runtime.get_alias().to_string()
437 }
438 } else {
439 String::new()
440 };
441
442 if !ep_str.is_empty() {
443 query_params.push(("ep", ep_str.as_str()));
444 }
445
446 let client = self.client()?;
447 let _response: Option<Value> = client.get(&url, Some(&query_params)).await?;
448
449 Ok(model_info)
450 }
451
452 pub async fn unload_model(&mut self, alias_or_model_id: &str, force: bool) -> Result<()> {
463 let model_info = self.get_model_info(alias_or_model_id, true).await?;
464 info!("Unloading model: {} ({})", model_info.alias, model_info.id);
465
466 let url = format!("/openai/unload/{}", model_info.id);
467 let force_str = force.to_string();
468 let query_params = vec![("force", force_str.as_str())];
469
470 let client = self.client()?;
471 let _response: Option<Value> = client.get(&url, Some(&query_params)).await?;
472
473 Ok(())
474 }
475
476 pub async fn list_loaded_models(&mut self) -> Result<Vec<FoundryModelInfo>> {
482 let response: Value = self
483 .client()?
484 .get("/openai/loadedmodels", None)
485 .await?
486 .ok_or_else(|| anyhow!("Failed to list loaded models - no response"))?;
487
488 let model_ids = if response.is_array() {
490 response
491 .as_array()
492 .ok_or_else(|| anyhow!("Invalid models response - expected array"))?
493 .iter()
494 .filter_map(|v| v.as_str())
495 .map(|s| s.to_string())
496 .collect::<Vec<_>>()
497 } else {
498 response["models"]
499 .as_array()
500 .ok_or_else(|| anyhow!("Invalid models response - expected models field"))?
501 .iter()
502 .filter_map(|v| v.as_str())
503 .map(|s| s.to_string())
504 .collect::<Vec<_>>()
505 };
506
507 self.fetch_model_infos(&model_ids).await
508 }
509
510 #[doc(hidden)]
512 pub fn set_test_service_uri(&mut self, uri: &str) {
513 self.service_uri = Some(uri.to_string());
514 self.client = Some(HttpClient::new(uri, self.timeout));
515 self.catalog_list = None;
516 self.catalog_dict = None;
517 }
518
519 #[doc(hidden)]
521 pub async fn new_for_testing() -> Result<Self> {
522 Ok(Self {
523 service_uri: None,
524 client: None,
525 catalog_list: None,
526 catalog_dict: None,
527 timeout: None,
528 })
529 }
530}