foundry_local/
api.rs

use std::{collections::HashMap, env};

use anyhow::{anyhow, Result};
use log::{debug, info};
use serde_json::Value;

use crate::{
    client::HttpClient,
    models::{ExecutionProvider, FoundryListResponseModel, FoundryModelInfo},
    service::{check_foundry_installed, get_service_uri, start_service},
};

/// Manager for Foundry Local SDK operations.
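///
/// # Example
///
/// A minimal usage sketch. The import path assumes this type is re-exported at
/// the crate root as `foundry_local::FoundryLocalManager`, and the model alias
/// is illustrative; substitute any alias or model ID from your catalog.
///
/// ```no_run
/// use foundry_local::FoundryLocalManager;
///
/// # async fn demo() -> anyhow::Result<()> {
/// let mut manager = FoundryLocalManager::builder()
///     .alias_or_model_id("phi-3.5-mini")
///     .bootstrap(true)
///     .build()
///     .await?;
///
/// // The endpoint and API key can be handed to any OpenAI-compatible client.
/// println!("endpoint: {}", manager.endpoint()?);
/// println!("api key:  {}", manager.api_key());
/// # Ok(())
/// # }
/// ```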
pub struct FoundryLocalManager {
    service_uri: Option<String>,
    client: Option<HttpClient>,
    catalog_list: Option<Vec<FoundryModelInfo>>,
    catalog_dict: Option<HashMap<String, FoundryModelInfo>>,
    timeout: Option<u64>,
}

/// Builder for creating a FoundryLocalManager instance.
pub struct FoundryLocalManagerBuilder {
    alias_or_model_id: Option<String>,
    bootstrap: bool,
    timeout_secs: Option<u64>,
}

impl FoundryLocalManagerBuilder {
    /// Create a new builder instance.
    pub fn new() -> Self {
        Self {
            alias_or_model_id: None,
            bootstrap: false,
            timeout_secs: None,
        }
    }

    /// Set the alias or model ID to download and load.
    pub fn alias_or_model_id(mut self, alias_or_model_id: impl Into<String>) -> Self {
        self.alias_or_model_id = Some(alias_or_model_id.into());
        self
    }

    /// Set whether to start the service if it is not running.
    pub fn bootstrap(mut self, bootstrap: bool) -> Self {
        self.bootstrap = bootstrap;
        self
    }

    /// Set the timeout for the HTTP client in seconds.
    pub fn timeout_secs(mut self, timeout_secs: u64) -> Self {
        self.timeout_secs = Some(timeout_secs);
        self
    }

    /// Build the FoundryLocalManager instance.
    pub async fn build(self) -> Result<FoundryLocalManager> {
        check_foundry_installed()?;

        let mut manager = FoundryLocalManager {
            service_uri: None,
            client: None,
            catalog_list: None,
            catalog_dict: None,
            timeout: self.timeout_secs,
        };

        if let Some(uri) = get_service_uri() {
            manager.set_service_uri_and_client(Some(uri));
        }

        if self.bootstrap {
            manager.start_service()?;

            if let Some(model) = self.alias_or_model_id {
                manager.download_model(&model, None, false).await?;
                manager.load_model(&model, Some(600)).await?;
            }
        }

        Ok(manager)
    }
}
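
// `Default` simply delegates to `new()`, keeping the builder idiomatic
// (e.g. for clippy's `new_without_default` lint).
impl Default for FoundryLocalManagerBuilder {
    fn default() -> Self {
        Self::new()
    }
}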

impl FoundryLocalManager {
    /// Create a new builder for FoundryLocalManager.
    pub fn builder() -> FoundryLocalManagerBuilder {
        FoundryLocalManagerBuilder::new()
    }

    fn set_service_uri_and_client(&mut self, service_uri: Option<String>) {
        self.service_uri = service_uri.clone();
        self.client = service_uri.map(|uri| HttpClient::new(&uri, self.timeout));
    }

    /// Get the service URI.
    ///
    /// # Returns
    ///
    /// URI of the Foundry service.
    pub fn service_uri(&self) -> Result<&str> {
        self.service_uri
            .as_deref()
            .ok_or_else(|| anyhow!("Service URI is not set. Please start the service first."))
    }

    /// Get the HTTP client.
    ///
    /// # Returns
    ///
    /// HTTP client instance.
    fn client(&self) -> Result<&HttpClient> {
        self.client
            .as_ref()
            .ok_or_else(|| anyhow!("HTTP client is not set. Please start the service first."))
    }

    /// Get the endpoint for the service.
    ///
    /// # Returns
    ///
    /// Endpoint URL.
    pub fn endpoint(&self) -> Result<String> {
        Ok(format!("{}/v1", self.service_uri()?))
    }

    /// Get the API key for authentication.
    ///
    /// # Returns
    ///
    /// API key.
    pub fn api_key(&self) -> String {
        // The local service does not require a real key, so fall back to a
        // placeholder when OPENAI_API_KEY is not set in the environment.
        env::var("OPENAI_API_KEY").unwrap_or_else(|_| "OPENAI_API_KEY".to_string())
    }

    // Service management API

    /// Check if the service is running. Also sets the service URI if it is not already set.
    ///
    /// # Returns
    ///
    /// `true` if the service is running, `false` otherwise.
    pub fn is_service_running(&mut self) -> bool {
        if let Some(uri) = get_service_uri() {
            self.set_service_uri_and_client(Some(uri));
            true
        } else {
            false
        }
    }

    /// Start the service.
    ///
    /// # Returns
    ///
    /// Result indicating success or failure.
    pub fn start_service(&mut self) -> Result<()> {
        let uri = start_service()?;
        self.set_service_uri_and_client(Some(uri));
        Ok(())
    }

    // Catalog API

    /// Get a list of available models in the catalog.
    ///
    /// # Returns
    ///
    /// List of catalog models.
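    ///
    /// # Example
    ///
    /// A sketch that prints each catalog entry (crate path assumed to be
    /// `foundry_local`):
    ///
    /// ```no_run
    /// # async fn demo(manager: &mut foundry_local::FoundryLocalManager) -> anyhow::Result<()> {
    /// for model in manager.list_catalog_models().await? {
    ///     println!("{} ({})", model.alias, model.id);
    /// }
    /// # Ok(())
    /// # }
    /// ```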
    pub async fn list_catalog_models(&mut self) -> Result<&Vec<FoundryModelInfo>> {
        if self.catalog_list.is_none() {
            let models: Vec<FoundryListResponseModel> = self
                .client()?
                .get("/foundry/list", None)
                .await?
                .ok_or_else(|| anyhow!("Failed to get catalog models"))?;

            self.catalog_list = Some(
                models
                    .iter()
                    .map(FoundryModelInfo::from_list_response)
                    .collect(),
            );
        }

        Ok(self.catalog_list.as_ref().unwrap())
    }

    /// Get a dictionary of available models, keyed by both model ID and alias.
    /// Each alias maps to the most preferred model for that alias.
    ///
    /// # Returns
    ///
    /// Dictionary of catalog models.
    async fn get_catalog_dict(&mut self) -> Result<&HashMap<String, FoundryModelInfo>> {
        if self.catalog_dict.is_some() {
            return Ok(self.catalog_dict.as_ref().unwrap());
        }

        let catalog_models = self.list_catalog_models().await?;
        let mut catalog_dict = HashMap::new();
        let mut alias_candidates: HashMap<String, Vec<&FoundryModelInfo>> = HashMap::new();

        // Create dictionary of models by ID
        for model in catalog_models {
            catalog_dict.insert(model.id.clone(), model.clone());
        }

        // Group models by alias
        for model in catalog_models {
            alias_candidates
                .entry(model.alias.clone())
                .or_default()
                .push(model);
        }

        // Define the preferred order of execution providers
        let mut preferred_order = vec![
            ExecutionProvider::QNN,
            ExecutionProvider::CUDA,
            ExecutionProvider::CPU,
            ExecutionProvider::WebGPU,
        ];

        if cfg!(not(target_os = "windows")) {
            // On non-Windows platforms, prefer WebGPU over CPU by moving CPU to the end
            preferred_order.retain(|p| !matches!(p, ExecutionProvider::CPU));
            preferred_order.push(ExecutionProvider::CPU);
        }

        let priority_map: HashMap<_, _> = preferred_order
            .into_iter()
            .enumerate()
            .map(|(i, provider)| (provider, i))
            .collect();

        // Choose the preferred model for each alias
        for (alias, candidates) in alias_candidates {
            if let Some(preferred) = candidates.into_iter().min_by_key(|model| {
                priority_map
                    .get(&model.runtime)
                    .copied()
                    .unwrap_or(usize::MAX)
            }) {
                catalog_dict.insert(alias, preferred.clone());
            }
        }

        self.catalog_dict = Some(catalog_dict);
        Ok(self.catalog_dict.as_ref().unwrap())
    }

    /// Refresh the catalog.
    pub fn refresh_catalog(&mut self) {
        self.catalog_list = None;
        self.catalog_dict = None;
    }

    /// Get the model information by alias or ID.
    ///
    /// # Arguments
    ///
    /// * `alias_or_model_id` - Alias or model ID. If an alias is given, the most preferred model is returned.
    /// * `raise_on_not_found` - If false, a missing model is additionally logged at debug level.
    ///
    /// # Returns
    ///
    /// Model information. An error is returned if the model is not found in the catalog.
    pub async fn get_model_info(
        &mut self,
        alias_or_model_id: &str,
        raise_on_not_found: bool,
    ) -> Result<FoundryModelInfo> {
        let catalog_dict = self.get_catalog_dict().await?;

        match catalog_dict.get(alias_or_model_id) {
            Some(model) => Ok(model.clone()),
            None => {
                if !raise_on_not_found {
                    debug!("Model {alias_or_model_id} not found in the catalog");
                }
                Err(anyhow!("Model {alias_or_model_id} not found in the catalog"))
            }
        }
    }

    // Cache management API

    /// Get the cache location.
    ///
    /// # Returns
    ///
    /// Cache location as a string.
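    ///
    /// # Example
    ///
    /// A usage sketch, assuming the crate is exposed as `foundry_local`:
    ///
    /// ```no_run
    /// # async fn demo(manager: &foundry_local::FoundryLocalManager) -> anyhow::Result<()> {
    /// let cache_dir = manager.get_cache_location().await?;
    /// println!("models are cached under {cache_dir}");
    /// # Ok(())
    /// # }
    /// ```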
    pub async fn get_cache_location(&self) -> Result<String> {
        let response: Value = self
            .client()?
            .get("/foundry/cache", None)
            .await?
            .ok_or_else(|| anyhow!("Failed to get cache location"))?;

        response["location"]
            .as_str()
            .map(|s| s.to_string())
            .ok_or_else(|| anyhow!("Invalid cache location response"))
    }

    /// Extract model IDs from a response that is either a bare JSON array of
    /// ID strings or an object with a `models` array field.
    fn extract_model_ids(response: &Value) -> Result<Vec<String>> {
        let ids = match response.as_array() {
            Some(array) => array,
            None => response["models"]
                .as_array()
                .ok_or_else(|| anyhow!("Invalid models response - expected models field"))?,
        };

        Ok(ids
            .iter()
            .filter_map(|v| v.as_str())
            .map(|s| s.to_string())
            .collect())
    }

    /// List cached models.
    ///
    /// # Returns
    ///
    /// List of cached models.
    pub async fn list_cached_models(&mut self) -> Result<Vec<FoundryModelInfo>> {
        let response: Value = self
            .client()?
            .get("/openai/models", None)
            .await?
            .ok_or_else(|| anyhow!("Failed to list cached models - no response"))?;

        let model_ids = Self::extract_model_ids(&response)?;
        self.fetch_model_infos(&model_ids).await
    }

    async fn fetch_model_infos(&mut self, model_ids: &[String]) -> Result<Vec<FoundryModelInfo>> {
        let mut results = Vec::new();
        let catalog_dict = self.get_catalog_dict().await?;

        for id in model_ids {
            if let Some(model) = catalog_dict.get(id) {
                results.push(model.clone());
            } else {
                debug!("Model {id} not found in the catalog");
            }
        }

        Ok(results)
    }

    /// Download a model.
    ///
    /// # Arguments
    ///
    /// * `alias_or_model_id` - Alias or Model ID.
    /// * `token` - Optional token for authentication.
    /// * `force` - If true, force re-download even if the model is already cached.
    ///
    /// # Returns
    ///
    /// Downloaded model information.
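    ///
    /// # Example
    ///
    /// A sketch of the download-then-load flow (the alias is illustrative and
    /// the crate path is assumed to be `foundry_local`):
    ///
    /// ```no_run
    /// # async fn demo(manager: &mut foundry_local::FoundryLocalManager) -> anyhow::Result<()> {
    /// let info = manager.download_model("phi-3.5-mini", None, false).await?;
    /// manager.load_model(&info.id, Some(600)).await?;
    /// # Ok(())
    /// # }
    /// ```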
    pub async fn download_model(
        &mut self,
        alias_or_model_id: &str,
        token: Option<&str>,
        force: bool,
    ) -> Result<FoundryModelInfo> {
        let model_info = self.get_model_info(alias_or_model_id, true).await?;
        info!(
            "Downloading model: {} ({}) - {} MB",
            model_info.alias, model_info.id, model_info.file_size_mb
        );

        let mut body = model_info.to_download_body();

        if let Some(t) = token {
            body["token"] = Value::String(t.to_string());
        }

        if force {
            body["Force"] = Value::Bool(true);
        }

        let client = self.client()?;
        let _response: Value = client
            .post_with_progress("/openai/download", Some(body))
            .await?;

        Ok(model_info)
    }

    /// Load a model.
    ///
    /// # Arguments
    ///
    /// * `alias_or_model_id` - Alias or Model ID.
    /// * `ttl` - Optional time-to-live in seconds. Default is 10 minutes (600 seconds).
    ///
    /// # Returns
    ///
    /// Loaded model information.
    pub async fn load_model(
        &mut self,
        alias_or_model_id: &str,
        ttl: Option<i32>,
    ) -> Result<FoundryModelInfo> {
        let model_info = self.get_model_info(alias_or_model_id, true).await?;
        info!("Loading model: {} ({})", model_info.alias, model_info.id);

        let url = format!("/openai/load/{}", model_info.id);
        let ttl_str = ttl.unwrap_or(600).to_string();
        let mut query_params = vec![("ttl", ttl_str.as_str())];

        // Select an execution provider for WebGPU and CUDA models: prefer CUDA
        // whenever the catalog offers any CUDA model, otherwise keep the
        // model's own runtime.
        let ep_str = if matches!(
            model_info.runtime,
            ExecutionProvider::WebGPU | ExecutionProvider::CUDA
        ) {
            let has_cuda_support = self
                .list_catalog_models()
                .await?
                .iter()
                .any(|mi| mi.runtime == ExecutionProvider::CUDA);

            if has_cuda_support {
                ExecutionProvider::CUDA.get_alias().to_string()
            } else {
                model_info.runtime.get_alias().to_string()
            }
        } else {
            String::new()
        };

        if !ep_str.is_empty() {
            query_params.push(("ep", ep_str.as_str()));
        }

        let client = self.client()?;
        let _response: Option<Value> = client.get(&url, Some(&query_params)).await?;

        Ok(model_info)
    }

    /// Unload a model.
    ///
    /// # Arguments
    ///
    /// * `alias_or_model_id` - Alias or Model ID.
    /// * `force` - If true, force unload even if the model is in use.
    ///
    /// # Returns
    ///
    /// Result indicating success or failure.
    pub async fn unload_model(&mut self, alias_or_model_id: &str, force: bool) -> Result<()> {
        let model_info = self.get_model_info(alias_or_model_id, true).await?;
        info!("Unloading model: {} ({})", model_info.alias, model_info.id);

        let url = format!("/openai/unload/{}", model_info.id);
        let force_str = force.to_string();
        let query_params = vec![("force", force_str.as_str())];

        let client = self.client()?;
        let _response: Option<Value> = client.get(&url, Some(&query_params)).await?;

        Ok(())
    }

    /// List loaded models.
    ///
    /// # Returns
    ///
    /// List of loaded models.
    pub async fn list_loaded_models(&mut self) -> Result<Vec<FoundryModelInfo>> {
        let response: Value = self
            .client()?
            .get("/openai/loadedmodels", None)
            .await?
            .ok_or_else(|| anyhow!("Failed to list loaded models - no response"))?;

        let model_ids = Self::extract_model_ids(&response)?;
        self.fetch_model_infos(&model_ids).await
    }

    /// Set a custom service URI and client for testing purposes.
    #[doc(hidden)]
    pub fn set_test_service_uri(&mut self, uri: &str) {
        self.service_uri = Some(uri.to_string());
        self.client = Some(HttpClient::new(uri, self.timeout));
        self.catalog_list = None;
        self.catalog_dict = None;
    }

    /// Create a new FoundryLocalManager instance for testing without checking if Foundry is installed.
    #[doc(hidden)]
    pub async fn new_for_testing() -> Result<Self> {
        Ok(Self {
            service_uri: None,
            client: None,
            catalog_list: None,
            catalog_dict: None,
            timeout: None,
        })
    }
}