modelexpress_common/
lib.rs1use serde::{Deserialize, Serialize};
5use std::env;
6
7pub mod cache;
8pub mod client_config;
9pub mod config;
10pub mod download;
11pub mod models;
12pub mod providers;
13
14#[allow(clippy::similar_names)]
16#[allow(clippy::default_trait_access)]
17#[allow(clippy::doc_markdown)]
18#[allow(clippy::must_use_candidate)]
19pub mod grpc {
20 pub mod health {
21 tonic::include_proto!("model_express.health");
22 }
23 pub mod api {
24 tonic::include_proto!("model_express.api");
25 }
26 pub mod model {
27 tonic::include_proto!("model_express.model");
28 }
29}
30
31#[derive(Debug, Clone, Serialize, Deserialize)]
33pub struct Response<T> {
34 pub success: bool,
35 pub data: Option<T>,
36 pub error: Option<String>,
37}
38
39#[derive(Debug, thiserror::Error)]
41pub enum Error {
42 #[error("Network error: {0}")]
43 Network(String),
44
45 #[error("Server returned error: {0}")]
46 Server(String),
47
48 #[error("Serialization error: {0}")]
49 Serialization(String),
50
51 #[error("gRPC error: {0}")]
52 Grpc(#[from] tonic::Status),
53
54 #[error("Transport error: {0}")]
55 Transport(#[from] tonic::transport::Error),
56
57 #[error("Generic error: {0}")]
58 Generic(String),
59}
60
61impl From<tonic::Status> for Box<Error> {
63 fn from(err: tonic::Status) -> Self {
64 Box::new(Error::Grpc(err))
65 }
66}
67
68impl From<tonic::transport::Error> for Box<Error> {
69 fn from(err: tonic::transport::Error) -> Self {
70 Box::new(Error::Transport(err))
71 }
72}
73
74pub type Result<T> = std::result::Result<T, Box<Error>>;
76
77pub struct Utils;
79
80impl Utils {
81 pub fn get_home_dir() -> std::result::Result<String, Box<Error>> {
83 env::var("HOME")
84 .or_else(|_| env::var("USERPROFILE"))
85 .map_err(|e| Error::Generic(format!("Failed to get home directory: {e}")).into())
86 }
87}
88
/// Default values shared across the crate.
pub mod constants {
    use std::num::NonZeroU16;

    // The three paths below are relative (no leading `/`).
    // NOTE(review): presumably joined onto the user's home directory by
    // callers; confirm at the use sites.

    /// Relative location of the `.model-express` cache directory.
    pub const DEFAULT_CACHE_PATH: &str = ".model-express/cache";
    /// Relative location of the Hugging Face hub cache.
    pub const DEFAULT_HF_CACHE_PATH: &str = ".cache/huggingface/hub";
    /// Relative location of the `.model-express` YAML config file.
    pub const DEFAULT_CONFIG_PATH: &str = ".model-express/config.yaml";

    /// Default gRPC port (8001); `NonZeroU16` rules out port 0 by type.
    pub const DEFAULT_GRPC_PORT: NonZeroU16 = match NonZeroU16::new(8001) {
        Some(port) => port,
        // Never taken: 8001 != 0. This is const-evaluated, so a zero literal
        // would fail the build rather than panic at runtime.
        None => panic!("8001 is non-zero"),
    };
    /// Default timeout, in seconds.
    pub const DEFAULT_TIMEOUT_SECS: u64 = 30;

    /// Default for the shared-storage setting.
    pub const DEFAULT_SHARED_STORAGE: bool = true;

    /// Default transfer chunk size: 32 * 1024 (presumably bytes, i.e. 32 KiB).
    pub const DEFAULT_TRANSFER_CHUNK_SIZE: usize = 32 * 1024;
}

107impl From<&models::Status> for grpc::health::HealthResponse {
109 fn from(status: &models::Status) -> Self {
110 Self {
111 version: status.version.clone(),
112 status: status.status.clone(),
113 uptime: status.uptime,
114 }
115 }
116}
117
118impl From<grpc::health::HealthResponse> for models::Status {
119 fn from(response: grpc::health::HealthResponse) -> Self {
120 Self {
121 version: response.version,
122 status: response.status,
123 uptime: response.uptime,
124 }
125 }
126}
127
128impl From<models::ModelProvider> for grpc::model::ModelProvider {
129 fn from(provider: models::ModelProvider) -> Self {
130 match provider {
131 models::ModelProvider::HuggingFace => grpc::model::ModelProvider::HuggingFace,
132 }
133 }
134}
135
136impl From<grpc::model::ModelProvider> for models::ModelProvider {
137 fn from(provider: grpc::model::ModelProvider) -> Self {
138 match provider {
139 grpc::model::ModelProvider::HuggingFace => models::ModelProvider::HuggingFace,
140 }
141 }
142}
143
144impl From<models::ModelStatus> for grpc::model::ModelStatus {
145 fn from(status: models::ModelStatus) -> Self {
146 match status {
147 models::ModelStatus::DOWNLOADING => grpc::model::ModelStatus::Downloading,
148 models::ModelStatus::DOWNLOADED => grpc::model::ModelStatus::Downloaded,
149 models::ModelStatus::ERROR => grpc::model::ModelStatus::Error,
150 }
151 }
152}
153
154impl From<grpc::model::ModelStatus> for models::ModelStatus {
155 fn from(status: grpc::model::ModelStatus) -> Self {
156 match status {
157 grpc::model::ModelStatus::Downloading => models::ModelStatus::DOWNLOADING,
158 grpc::model::ModelStatus::Downloaded => models::ModelStatus::DOWNLOADED,
159 grpc::model::ModelStatus::Error => models::ModelStatus::ERROR,
160 }
161 }
162}
163
164impl From<&models::ModelStatusResponse> for grpc::model::ModelStatusUpdate {
165 fn from(response: &models::ModelStatusResponse) -> Self {
166 Self {
167 model_name: response.model_name.clone(),
168 status: grpc::model::ModelStatus::from(response.status) as i32,
169 message: None,
170 provider: grpc::model::ModelProvider::from(response.provider) as i32,
171 }
172 }
173}
174
175impl From<grpc::model::ModelStatusUpdate> for models::ModelStatusResponse {
176 fn from(update: grpc::model::ModelStatusUpdate) -> Self {
177 Self {
178 model_name: update.model_name,
179 status: grpc::model::ModelStatus::try_from(update.status)
180 .unwrap_or(grpc::model::ModelStatus::Error)
181 .into(),
182 provider: grpc::model::ModelProvider::try_from(update.provider)
183 .unwrap_or(grpc::model::ModelProvider::HuggingFace)
184 .into(),
185 }
186 }
187}
188
#[cfg(test)]
mod tests {
    use super::*;
    use std::env;

    #[test]
    fn test_status_conversion_from_models_to_grpc() {
        let status = models::Status {
            version: "1.0.0".to_string(),
            status: "ok".to_string(),
            uptime: 3600,
        };

        let grpc_response: grpc::health::HealthResponse = (&status).into();

        assert_eq!(grpc_response.version, status.version);
        assert_eq!(grpc_response.status, status.status);
        assert_eq!(grpc_response.uptime, status.uptime);
    }

    #[test]
    fn test_status_conversion_from_grpc_to_models() {
        let grpc_response = grpc::health::HealthResponse {
            version: "1.0.0".to_string(),
            status: "ok".to_string(),
            uptime: 3600,
        };

        let status: models::Status = grpc_response.into();

        assert_eq!(status.version, "1.0.0");
        assert_eq!(status.status, "ok");
        assert_eq!(status.uptime, 3600);
    }

    #[test]
    fn test_model_provider_conversion_both_ways() {
        let model_provider = models::ModelProvider::HuggingFace;
        let grpc_provider: grpc::model::ModelProvider = model_provider.into();
        let back_to_model: models::ModelProvider = grpc_provider.into();

        assert_eq!(model_provider, back_to_model);
    }

    #[test]
    fn test_model_status_conversion_both_ways() {
        // A fixed-size array is enough; no heap-allocated `Vec` needed.
        let statuses = [
            models::ModelStatus::DOWNLOADING,
            models::ModelStatus::DOWNLOADED,
            models::ModelStatus::ERROR,
        ];

        for status in statuses {
            let grpc_status: grpc::model::ModelStatus = status.into();
            let back_to_model: models::ModelStatus = grpc_status.into();
            assert_eq!(status, back_to_model);
        }
    }

    #[test]
    fn test_model_status_response_conversion_from_models_to_grpc() {
        let response = models::ModelStatusResponse {
            model_name: "test-model".to_string(),
            status: models::ModelStatus::DOWNLOADED,
            provider: models::ModelProvider::HuggingFace,
        };

        let grpc_update: grpc::model::ModelStatusUpdate = (&response).into();

        assert_eq!(grpc_update.model_name, response.model_name);
        assert_eq!(
            grpc_update.status,
            grpc::model::ModelStatus::Downloaded as i32
        );
        assert_eq!(
            grpc_update.provider,
            grpc::model::ModelProvider::HuggingFace as i32
        );
        assert!(grpc_update.message.is_none());
    }

    #[test]
    fn test_model_status_response_conversion_from_grpc_to_models() {
        let grpc_update = grpc::model::ModelStatusUpdate {
            model_name: "test-model".to_string(),
            status: grpc::model::ModelStatus::Downloaded as i32,
            message: Some("Test message".to_string()),
            provider: grpc::model::ModelProvider::HuggingFace as i32,
        };

        let response: models::ModelStatusResponse = grpc_update.into();

        assert_eq!(response.model_name, "test-model");
        assert_eq!(response.status, models::ModelStatus::DOWNLOADED);
        assert_eq!(response.provider, models::ModelProvider::HuggingFace);
    }

    #[test]
    fn test_model_status_response_conversion_maps_unknown_values() {
        // Pins the documented fallbacks for discriminants no proto variant uses.
        let grpc_update = grpc::model::ModelStatusUpdate {
            model_name: "test-model".to_string(),
            status: i32::MAX,
            message: None,
            provider: i32::MAX,
        };

        let response: models::ModelStatusResponse = grpc_update.into();

        assert_eq!(response.model_name, "test-model");
        assert_eq!(response.status, models::ModelStatus::ERROR);
        assert_eq!(response.provider, models::ModelProvider::HuggingFace);
    }

    #[test]
    fn test_error_types() {
        let network_error = Error::Network("Connection failed".to_string());
        assert!(network_error.to_string().contains("Network error"));

        let server_error = Error::Server("Internal error".to_string());
        assert!(server_error.to_string().contains("Server returned error"));

        let serialization_error = Error::Serialization("JSON parse error".to_string());
        assert!(
            serialization_error
                .to_string()
                .contains("Serialization error")
        );

        let generic_error = Error::Generic("something broke".to_string());
        assert_eq!(generic_error.to_string(), "Generic error: something broke");
    }

    #[test]
    fn test_grpc_status_converts_into_boxed_error() {
        let boxed: Box<Error> = tonic::Status::internal("boom").into();

        assert!(matches!(
            boxed.as_ref(),
            Error::Grpc(status) if status.message() == "boom"
        ));
        assert!(boxed.to_string().starts_with("gRPC error: "));
    }

    #[test]
    fn test_constants() {
        assert_eq!(constants::DEFAULT_GRPC_PORT.get(), 8001);
        assert_eq!(constants::DEFAULT_TIMEOUT_SECS, 30);
        assert_eq!(constants::DEFAULT_TRANSFER_CHUNK_SIZE, 32 * 1024);
    }

    #[test]
    fn test_response_creation() {
        let success_response = Response {
            success: true,
            data: Some("test data".to_string()),
            error: None,
        };

        assert!(success_response.success);
        assert!(success_response.data.is_some());
        assert!(success_response.error.is_none());

        let error_response: Response<String> = Response {
            success: false,
            data: None,
            error: Some("test error".to_string()),
        };

        assert!(!error_response.success);
        assert!(error_response.data.is_none());
        assert!(error_response.error.is_some());
    }

    #[test]
    fn test_utils_get_home_dir() {
        // Work out what the lookup should yield (first non-empty of HOME,
        // USERPROFILE). An unexpected `Err` then fails the test instead of
        // passing silently.
        let expected = ["HOME", "USERPROFILE"]
            .into_iter()
            .find_map(|key| env::var(key).ok().filter(|value| !value.is_empty()));

        match (Utils::get_home_dir(), expected) {
            (Ok(actual), Some(expected)) => {
                assert!(!actual.is_empty());
                assert_eq!(actual, expected);
            }
            (Err(_), None) => {}
            (actual, expected) => {
                panic!("get_home_dir returned {actual:?}, expected {expected:?}")
            }
        }
    }
}