foundry_local/
models.rs

1use serde::{Deserialize, Serialize};
2use std::fmt;
3
/// Enumeration of devices supported by the model.
///
/// All variants are unit variants. Because of `rename_all = "UPPERCASE"`,
/// serde reads and writes them as the strings `"CPU"`, `"GPU"` and `"NPU"`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "UPPERCASE")]
pub enum DeviceType {
    /// Serialized as `"CPU"`.
    CPU,
    /// Serialized as `"GPU"`.
    GPU,
    /// Serialized as `"NPU"`.
    NPU,
}
12
13impl fmt::Display for DeviceType {
14    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
15        match self {
16            DeviceType::CPU => write!(f, "CPU"),
17            DeviceType::GPU => write!(f, "GPU"),
18            DeviceType::NPU => write!(f, "NPU"),
19        }
20    }
21}
22
/// Enumeration of execution providers supported by the model.
///
/// Each variant carries an explicit serde rename, so the serialized form is
/// the full provider name (for example `"CPUExecutionProvider"`) rather than
/// the Rust variant identifier.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub enum ExecutionProvider {
    /// Serialized as `"CPUExecutionProvider"`.
    #[serde(rename = "CPUExecutionProvider")]
    CPU,
    /// Serialized as `"WebGpuExecutionProvider"`.
    #[serde(rename = "WebGpuExecutionProvider")]
    WebGPU,
    /// Serialized as `"CUDAExecutionProvider"`.
    #[serde(rename = "CUDAExecutionProvider")]
    CUDA,
    /// Serialized as `"QNNExecutionProvider"`.
    #[serde(rename = "QNNExecutionProvider")]
    QNN,
}
35
36impl ExecutionProvider {
37    /// Get the alias for the execution provider.
38    pub fn get_alias(&self) -> String {
39        match self {
40            ExecutionProvider::CPU => "cpu".to_string(),
41            ExecutionProvider::WebGPU => "webgpu".to_string(),
42            ExecutionProvider::CUDA => "cuda".to_string(),
43            ExecutionProvider::QNN => "qnn".to_string(),
44        }
45    }
46}
47
48impl fmt::Display for ExecutionProvider {
49    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
50        write!(
51            f,
52            "{}",
53            match self {
54                ExecutionProvider::CPU => "CPUExecutionProvider",
55                ExecutionProvider::WebGPU => "WebGpuExecutionProvider",
56                ExecutionProvider::CUDA => "CUDAExecutionProvider",
57                ExecutionProvider::QNN => "QNNExecutionProvider",
58            }
59        )
60    }
61}
62
63/// Model runtime information.
64#[derive(Debug, Clone, Serialize, Deserialize)]
65pub struct ModelRuntime {
66    #[serde(rename = "deviceType")]
67    pub device_type: DeviceType,
68    #[serde(rename = "executionProvider")]
69    pub execution_provider: ExecutionProvider,
70}
71
72/// Response model for listing models.
73#[derive(Debug, Clone, Serialize, Deserialize)]
74pub struct FoundryListResponseModel {
75    pub name: String,
76    #[serde(rename = "displayName")]
77    pub display_name: String,
78    #[serde(rename = "modelType")]
79    pub model_type: String,
80    #[serde(rename = "providerType")]
81    pub provider_type: String,
82    pub uri: String,
83    pub version: String,
84    #[serde(rename = "promptTemplate")]
85    pub prompt_template: serde_json::Value,
86    pub publisher: String,
87    pub task: String,
88    pub runtime: ModelRuntime,
89    #[serde(rename = "fileSizeMb")]
90    pub file_size_mb: i32,
91    #[serde(rename = "modelSettings")]
92    pub model_settings: serde_json::Value,
93    pub alias: String,
94    #[serde(rename = "supportsToolCalling")]
95    pub supports_tool_calling: bool,
96    pub license: String,
97    #[serde(rename = "licenseDescription")]
98    pub license_description: String,
99    #[serde(rename = "parentModelUri")]
100    pub parent_model_uri: String,
101}
102
/// Model information.
///
/// A flattened view of a [`FoundryListResponseModel`]; see
/// `FoundryModelInfo::from_list_response` for how each field is populated.
/// No serde renames are applied, so fields (de)serialize under their Rust
/// (snake_case) names.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FoundryModelInfo {
    /// Model alias, copied from the list response's `alias`.
    pub alias: String,
    /// Model identifier, copied from the list response's `name`.
    pub id: String,
    /// Model version string.
    pub version: String,
    /// Execution provider, taken from the list response's `runtime.execution_provider`.
    pub runtime: ExecutionProvider,
    /// Model URI.
    pub uri: String,
    /// File size; presumably in megabytes, going by the field name.
    pub file_size_mb: i32,
    /// Prompt template, kept as untyped JSON.
    pub prompt_template: serde_json::Value,
    /// Provider name, copied from the list response's `provider_type`.
    pub provider: String,
    /// Publisher name.
    pub publisher: String,
    /// License identifier or text, copied from the list response's `license`.
    pub license: String,
    /// Task string, copied from the list response's `task`.
    pub task: String,
}
118
119impl FoundryModelInfo {
120    /// Create a FoundryModelInfo object from a FoundryListResponseModel object.
121    pub fn from_list_response(response: &FoundryListResponseModel) -> Self {
122        Self {
123            alias: response.alias.clone(),
124            id: response.name.clone(),
125            version: response.version.clone(),
126            runtime: response.runtime.execution_provider.clone(),
127            uri: response.uri.clone(),
128            file_size_mb: response.file_size_mb,
129            prompt_template: response.prompt_template.clone(),
130            provider: response.provider_type.clone(),
131            publisher: response.publisher.clone(),
132            license: response.license.clone(),
133            task: response.task.clone(),
134        }
135    }
136
137    /// Convert the FoundryModelInfo object to a dictionary for download.
138    pub fn to_download_body(&self) -> serde_json::Value {
139        let provider_type = if self.provider == "AzureFoundry" {
140            format!("{}Local", self.provider)
141        } else {
142            self.provider.clone()
143        };
144
145        serde_json::json!({
146            "model": {
147                "Name": self.id,
148                "Uri": self.uri,
149                "Publisher": self.publisher,
150                "ProviderType": provider_type,
151                "PromptTemplate": self.prompt_template,
152            },
153            "IgnorePipeReport": true
154        })
155    }
156}
157
158impl fmt::Display for FoundryModelInfo {
159    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
160        write!(
161            f,
162            "FoundryModelInfo(alias={}, id={}, runtime={}, file_size={} MB, license={})",
163            self.alias,
164            self.id,
165            self.runtime.get_alias(),
166            self.file_size_mb,
167            self.license
168        )
169    }
170}