1use serde::{Deserialize, Serialize};
2use std::fmt;
3
4#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
6#[serde(rename_all = "UPPERCASE")]
7pub enum DeviceType {
8 CPU,
9 GPU,
10 NPU,
11}
12
13impl fmt::Display for DeviceType {
14 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
15 match self {
16 DeviceType::CPU => write!(f, "CPU"),
17 DeviceType::GPU => write!(f, "GPU"),
18 DeviceType::NPU => write!(f, "NPU"),
19 }
20 }
21}
22
23#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
25pub enum ExecutionProvider {
26 #[serde(rename = "CPUExecutionProvider")]
27 CPU,
28 #[serde(rename = "WebGpuExecutionProvider")]
29 WebGPU,
30 #[serde(rename = "CUDAExecutionProvider")]
31 CUDA,
32 #[serde(rename = "QNNExecutionProvider")]
33 QNN,
34}
35
36impl ExecutionProvider {
37 pub fn get_alias(&self) -> String {
39 match self {
40 ExecutionProvider::CPU => "cpu".to_string(),
41 ExecutionProvider::WebGPU => "webgpu".to_string(),
42 ExecutionProvider::CUDA => "cuda".to_string(),
43 ExecutionProvider::QNN => "qnn".to_string(),
44 }
45 }
46}
47
48impl fmt::Display for ExecutionProvider {
49 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
50 write!(
51 f,
52 "{}",
53 match self {
54 ExecutionProvider::CPU => "CPUExecutionProvider",
55 ExecutionProvider::WebGPU => "WebGpuExecutionProvider",
56 ExecutionProvider::CUDA => "CUDAExecutionProvider",
57 ExecutionProvider::QNN => "QNNExecutionProvider",
58 }
59 )
60 }
61}
62
63#[derive(Debug, Clone, Serialize, Deserialize)]
65pub struct ModelRuntime {
66 #[serde(rename = "deviceType")]
67 pub device_type: DeviceType,
68 #[serde(rename = "executionProvider")]
69 pub execution_provider: ExecutionProvider,
70}
71
72#[derive(Debug, Clone, Serialize, Deserialize)]
74pub struct FoundryListResponseModel {
75 pub name: String,
76 #[serde(rename = "displayName")]
77 pub display_name: String,
78 #[serde(rename = "modelType")]
79 pub model_type: String,
80 #[serde(rename = "providerType")]
81 pub provider_type: String,
82 pub uri: String,
83 pub version: String,
84 #[serde(rename = "promptTemplate")]
85 pub prompt_template: serde_json::Value,
86 pub publisher: String,
87 pub task: String,
88 pub runtime: ModelRuntime,
89 #[serde(rename = "fileSizeMb")]
90 pub file_size_mb: i32,
91 #[serde(rename = "modelSettings")]
92 pub model_settings: serde_json::Value,
93 pub alias: String,
94 #[serde(rename = "supportsToolCalling")]
95 pub supports_tool_calling: bool,
96 pub license: String,
97 #[serde(rename = "licenseDescription")]
98 pub license_description: String,
99 #[serde(rename = "parentModelUri")]
100 pub parent_model_uri: String,
101}
102
103#[derive(Debug, Clone, Serialize, Deserialize)]
105pub struct FoundryModelInfo {
106 pub alias: String,
107 pub id: String,
108 pub version: String,
109 pub runtime: ExecutionProvider,
110 pub uri: String,
111 pub file_size_mb: i32,
112 pub prompt_template: serde_json::Value,
113 pub provider: String,
114 pub publisher: String,
115 pub license: String,
116 pub task: String,
117}
118
119impl FoundryModelInfo {
120 pub fn from_list_response(response: &FoundryListResponseModel) -> Self {
122 Self {
123 alias: response.alias.clone(),
124 id: response.name.clone(),
125 version: response.version.clone(),
126 runtime: response.runtime.execution_provider.clone(),
127 uri: response.uri.clone(),
128 file_size_mb: response.file_size_mb,
129 prompt_template: response.prompt_template.clone(),
130 provider: response.provider_type.clone(),
131 publisher: response.publisher.clone(),
132 license: response.license.clone(),
133 task: response.task.clone(),
134 }
135 }
136
137 pub fn to_download_body(&self) -> serde_json::Value {
139 let provider_type = if self.provider == "AzureFoundry" {
140 format!("{}Local", self.provider)
141 } else {
142 self.provider.clone()
143 };
144
145 serde_json::json!({
146 "model": {
147 "Name": self.id,
148 "Uri": self.uri,
149 "Publisher": self.publisher,
150 "ProviderType": provider_type,
151 "PromptTemplate": self.prompt_template,
152 },
153 "IgnorePipeReport": true
154 })
155 }
156}
157
158impl fmt::Display for FoundryModelInfo {
159 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
160 write!(
161 f,
162 "FoundryModelInfo(alias={}, id={}, runtime={}, file_size={} MB, license={})",
163 self.alias,
164 self.id,
165 self.runtime.get_alias(),
166 self.file_size_mb,
167 self.license
168 )
169 }
170}