1use serde::{Deserialize, Serialize};
5use std::env;
6use std::error::Error as StdError;
7
8pub mod cache;
9pub mod client_config;
10pub mod config;
11pub mod download;
12pub mod models;
13pub mod providers;
14#[cfg(any(test, feature = "test-support"))]
15#[doc(hidden)]
16pub mod test_support;
17
/// Protobuf message types and gRPC service stubs, one submodule per proto
/// package, pulled in with `tonic::include_proto!` (i.e. build-script output,
/// not hand-written code).
// The lint allowances below apply to that included code, which this crate does
// not author; presumably the generated items trip these pedantic lints.
#[allow(clippy::similar_names)]
#[allow(clippy::default_trait_access)]
#[allow(clippy::doc_markdown)]
#[allow(clippy::must_use_candidate)]
pub mod grpc {
    /// Types generated from the `model_express.health` proto package.
    pub mod health {
        tonic::include_proto!("model_express.health");
    }
    /// Types generated from the `model_express.api` proto package.
    pub mod api {
        tonic::include_proto!("model_express.api");
    }
    /// Types generated from the `model_express.model` proto package.
    pub mod model {
        tonic::include_proto!("model_express.model");
    }
    /// Types generated from the `model_express.p2p` proto package.
    pub mod p2p {
        tonic::include_proto!("model_express.p2p");
    }
}
37
/// Generic serializable response envelope: a success flag, an optional payload
/// and an optional error message.
///
/// The type itself does not enforce consistency between the fields. As
/// exercised in this file's tests, a successful response carries `data` with
/// `error: None`, and a failed one carries `error` with `data: None`.
// Field order is also the serde serialization order; keep it stable.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Response<T> {
    /// Whether the operation succeeded.
    pub success: bool,
    /// Payload of the response, if any.
    pub data: Option<T>,
    /// Error message, if any.
    pub error: Option<String>,
}
45
/// Error type for this crate.
///
/// Every variant except `Grpc` carries a pre-rendered message string; `Grpc`
/// keeps the original [`tonic::Status`]. The crate-wide [`Result`] alias uses
/// this type boxed.
#[derive(Debug, thiserror::Error)]
pub enum Error {
    /// Network failure; displayed as `Network error: <msg>`.
    #[error("Network error: {0}")]
    Network(String),

    /// Error reported by the server; displayed as `Server returned error: <msg>`.
    #[error("Server returned error: {0}")]
    Server(String),

    /// I/O failure, stored as text rather than as an `std::io::Error`;
    /// displayed as `I/O error: <msg>`.
    #[error("I/O error: {0}")]
    Io(String),

    /// Rejected input (the tests use an unsafe path as an example); displayed
    /// as `Validation error: <msg>`.
    #[error("Validation error: {0}")]
    Validation(String),

    /// Encoding/decoding failure; displayed as `Serialization error: <msg>`.
    #[error("Serialization error: {0}")]
    Serialization(String),

    /// A gRPC status. `#[from]` lets `?` convert a `tonic::Status` into this
    /// variant automatically.
    #[error("gRPC error: {0}")]
    Grpc(#[from] tonic::Status),

    /// Connection-level failure. The `From<tonic::transport::Error>` impl in
    /// this file builds it with the error's whole source chain flattened into
    /// the message.
    #[error("Transport error: {0}")]
    Transport(String),

    /// Catch-all for failures without a more specific variant; displayed as
    /// `Generic error: <msg>`.
    #[error("Generic error: {0}")]
    Generic(String),
}
73
/// Renders `err` and its whole `source()` chain as one `": "`-joined message.
///
/// Walking the chain surfaces root causes (e.g. "connection refused") that the
/// top-level `Display` alone would hide. Redundant text is removed on the way:
///
/// * empty messages are skipped;
/// * a cause is skipped when the previously kept message is identical to it or
///   already ends with `": <cause>"`. Some error types embed their source in
///   their own `Display`, which would otherwise print the same cause repeatedly;
/// * a leading generic `"transport error"` is dropped when more specific causes
///   follow, because `Error::Transport` already displays a `Transport error:` prefix.
///
/// Returns `"transport error"` when every message in the chain is empty.
fn format_error_chain(err: &(dyn StdError + 'static)) -> String {
    const GENERIC_TRANSPORT: &str = "transport error";

    let mut parts: Vec<String> = Vec::new();
    let mut current = Some(err);

    while let Some(error) = current {
        let part = error.to_string();
        // The last kept message already shows `part` if it equals it or ends
        // with ": {part}". A bare suffix without the separator is not treated
        // as a repeat, so "outer error" followed by "error" keeps both.
        let already_shown = parts.last().is_some_and(|prev| {
            prev.strip_suffix(part.as_str())
                .is_some_and(|head| head.is_empty() || head.ends_with(": "))
        });
        if !part.is_empty() && !already_shown {
            parts.push(part);
        }
        current = error.source();
    }

    if parts.len() > 1 && parts[0] == GENERIC_TRANSPORT {
        parts.remove(0);
    }

    if parts.is_empty() {
        GENERIC_TRANSPORT.to_string()
    } else {
        parts.join(": ")
    }
}
96
97impl From<tonic::Status> for Box<Error> {
99 fn from(err: tonic::Status) -> Self {
100 Box::new(Error::Grpc(err))
101 }
102}
103
104impl From<tonic::transport::Error> for Error {
105 fn from(err: tonic::transport::Error) -> Self {
106 Error::Transport(format_error_chain(&err))
107 }
108}
109
110impl From<tonic::transport::Error> for Box<Error> {
111 fn from(err: tonic::transport::Error) -> Self {
112 Box::new(Error::from(err))
113 }
114}
115
/// Crate-wide result type whose error is a boxed [`Error`].
// NOTE(review): boxing presumably keeps `Result` small, since `Error::Grpc`
// embeds a whole `tonic::Status` — confirm before changing.
pub type Result<T> = std::result::Result<T, Box<Error>>;
118
119pub struct Utils;
121
122impl Utils {
123 pub fn get_home_dir() -> std::result::Result<String, Box<Error>> {
125 env::var("HOME")
126 .or_else(|_| env::var("USERPROFILE"))
127 .map_err(|e| Error::Generic(format!("Failed to get home directory: {e}")).into())
128 }
129}
130
/// Default paths, network settings and sizes shared across the crate.
pub mod constants {
    use std::num::NonZeroU16;

    /// Default relative location of the model-express cache directory.
    pub const DEFAULT_CACHE_PATH: &str = ".model-express/cache";
    /// Default relative location of the Hugging Face hub cache.
    pub const DEFAULT_HF_CACHE_PATH: &str = ".cache/huggingface/hub";
    /// Default relative location of the model-express config file.
    pub const DEFAULT_CONFIG_PATH: &str = ".model-express/config.yaml";

    /// Default gRPC port (8001). Typed as `NonZeroU16` so port 0 can never be
    /// configured as the default.
    pub const DEFAULT_GRPC_PORT: NonZeroU16 = match NonZeroU16::new(8001) {
        Some(port) => port,
        None => panic!("8001 is non-zero"),
    };
    /// Default timeout, in seconds.
    pub const DEFAULT_TIMEOUT_SECS: u64 = 30;

    /// Default for the shared-storage setting.
    pub const DEFAULT_SHARED_STORAGE: bool = true;

    /// Default transfer chunk size in bytes (32 KiB).
    pub const DEFAULT_TRANSFER_CHUNK_SIZE: usize = 32 * 1024;
}
148
149impl From<&models::Status> for grpc::health::HealthResponse {
151 fn from(status: &models::Status) -> Self {
152 Self {
153 version: status.version.clone(),
154 status: status.status.clone(),
155 uptime: status.uptime,
156 }
157 }
158}
159
160impl From<grpc::health::HealthResponse> for models::Status {
161 fn from(response: grpc::health::HealthResponse) -> Self {
162 Self {
163 version: response.version,
164 status: response.status,
165 uptime: response.uptime,
166 }
167 }
168}
169
170impl From<models::ModelProvider> for grpc::model::ModelProvider {
171 fn from(provider: models::ModelProvider) -> Self {
172 match provider {
173 models::ModelProvider::HuggingFace => grpc::model::ModelProvider::HuggingFace,
174 models::ModelProvider::Ngc => grpc::model::ModelProvider::Ngc,
175 models::ModelProvider::Gcs => grpc::model::ModelProvider::Gcs,
176 }
177 }
178}
179
180impl From<grpc::model::ModelProvider> for models::ModelProvider {
181 fn from(provider: grpc::model::ModelProvider) -> Self {
182 match provider {
183 grpc::model::ModelProvider::HuggingFace => models::ModelProvider::HuggingFace,
184 grpc::model::ModelProvider::Ngc => models::ModelProvider::Ngc,
185 grpc::model::ModelProvider::Gcs => models::ModelProvider::Gcs,
186 }
187 }
188}
189
190impl From<models::ModelStatus> for grpc::model::ModelStatus {
191 fn from(status: models::ModelStatus) -> Self {
192 match status {
193 models::ModelStatus::DOWNLOADING => grpc::model::ModelStatus::Downloading,
194 models::ModelStatus::DOWNLOADED => grpc::model::ModelStatus::Downloaded,
195 models::ModelStatus::ERROR => grpc::model::ModelStatus::Error,
196 }
197 }
198}
199
200impl From<grpc::model::ModelStatus> for models::ModelStatus {
201 fn from(status: grpc::model::ModelStatus) -> Self {
202 match status {
203 grpc::model::ModelStatus::Downloading => models::ModelStatus::DOWNLOADING,
204 grpc::model::ModelStatus::Downloaded => models::ModelStatus::DOWNLOADED,
205 grpc::model::ModelStatus::Error => models::ModelStatus::ERROR,
206 }
207 }
208}
209
210impl From<&models::ModelStatusResponse> for grpc::model::ModelStatusUpdate {
211 fn from(response: &models::ModelStatusResponse) -> Self {
212 Self {
213 model_name: response.model_name.clone(),
214 status: grpc::model::ModelStatus::from(response.status) as i32,
215 message: None,
216 provider: grpc::model::ModelProvider::from(response.provider) as i32,
217 }
218 }
219}
220
221impl From<grpc::model::ModelStatusUpdate> for models::ModelStatusResponse {
222 fn from(update: grpc::model::ModelStatusUpdate) -> Self {
223 Self {
224 model_name: update.model_name,
225 status: grpc::model::ModelStatus::try_from(update.status)
226 .unwrap_or(grpc::model::ModelStatus::Error)
227 .into(),
228 provider: grpc::model::ModelProvider::try_from(update.provider)
229 .unwrap_or(grpc::model::ModelProvider::HuggingFace)
230 .into(),
231 }
232 }
233}
234
// Unit tests for the wire/domain conversions, error-chain formatting, error
// display strings, constants, the response envelope and the home-dir helper.
#[cfg(test)]
mod tests {
    use super::*;
    use std::env;
    use std::io;

    // Borrowed domain status -> gRPC health reply copies every field.
    #[test]
    fn test_status_conversion_from_models_to_grpc() {
        let status = models::Status {
            version: "1.0.0".to_string(),
            status: "ok".to_string(),
            uptime: 3600,
        };

        let grpc_response: grpc::health::HealthResponse = (&status).into();

        assert_eq!(grpc_response.version, status.version);
        assert_eq!(grpc_response.status, status.status);
        assert_eq!(grpc_response.uptime, status.uptime);
    }

    // Minimal error types with one `source()` link, used to drive `format_error_chain`.
    #[derive(Debug, thiserror::Error)]
    #[error("outer error")]
    struct OuterError(#[source] io::Error);

    // Top-level message equals the generic "transport error" prefix the formatter strips.
    #[derive(Debug, thiserror::Error)]
    #[error("transport error")]
    struct TransportWrapper(#[source] io::Error);

    #[test]
    fn test_format_error_chain_includes_nested_causes() {
        let err = OuterError(io::Error::other("connection reset by peer"));
        assert_eq!(
            format_error_chain(&err),
            "outer error: connection reset by peer"
        );
    }

    // A leading "transport error" is dropped when a more specific cause follows.
    #[test]
    fn test_format_error_chain_skips_repeated_transport_prefix() {
        let err = TransportWrapper(io::Error::other("underlying cause"));
        assert_eq!(format_error_chain(&err), "underlying cause");
    }

    #[test]
    fn test_status_conversion_from_grpc_to_models() {
        let grpc_response = grpc::health::HealthResponse {
            version: "1.0.0".to_string(),
            status: "ok".to_string(),
            uptime: 3600,
        };

        let status: models::Status = grpc_response.into();

        assert_eq!(status.version, "1.0.0");
        assert_eq!(status.status, "ok");
        assert_eq!(status.uptime, 3600);
    }

    // Round-trip every provider variant through the gRPC enum.
    #[test]
    fn test_model_provider_conversion_both_ways() {
        for model_provider in [
            models::ModelProvider::HuggingFace,
            models::ModelProvider::Ngc,
            models::ModelProvider::Gcs,
        ] {
            let grpc_provider: grpc::model::ModelProvider = model_provider.into();
            let back_to_model: models::ModelProvider = grpc_provider.into();
            assert_eq!(model_provider, back_to_model);
        }
    }

    // Round-trip every status variant through the gRPC enum.
    #[test]
    fn test_model_status_conversion_both_ways() {
        let statuses = vec![
            models::ModelStatus::DOWNLOADING,
            models::ModelStatus::DOWNLOADED,
            models::ModelStatus::ERROR,
        ];

        for status in statuses {
            let grpc_status: grpc::model::ModelStatus = status.into();
            let back_to_model: models::ModelStatus = grpc_status.into();
            assert_eq!(status, back_to_model);
        }
    }

    // Enum fields are encoded as `i32` discriminants and no message is attached.
    #[test]
    fn test_model_status_response_conversion_from_models_to_grpc() {
        let response = models::ModelStatusResponse {
            model_name: "test-model".to_string(),
            status: models::ModelStatus::DOWNLOADED,
            provider: models::ModelProvider::HuggingFace,
        };

        let grpc_update: grpc::model::ModelStatusUpdate = (&response).into();

        assert_eq!(grpc_update.model_name, response.model_name);
        assert_eq!(
            grpc_update.status,
            grpc::model::ModelStatus::Downloaded as i32
        );
        assert_eq!(
            grpc_update.provider,
            grpc::model::ModelProvider::HuggingFace as i32
        );
        assert!(grpc_update.message.is_none());
    }

    // Decoding drops `message`; only name, status and provider are checked.
    #[test]
    fn test_model_status_response_conversion_from_grpc_to_models() {
        let grpc_update = grpc::model::ModelStatusUpdate {
            model_name: "test-model".to_string(),
            status: grpc::model::ModelStatus::Downloaded as i32,
            message: Some("Test message".to_string()),
            provider: grpc::model::ModelProvider::HuggingFace as i32,
        };

        let response: models::ModelStatusResponse = grpc_update.into();

        assert_eq!(response.model_name, "test-model");
        assert_eq!(response.status, models::ModelStatus::DOWNLOADED);
        assert_eq!(response.provider, models::ModelProvider::HuggingFace);
    }

    // Each string-carrying variant's Display includes its category prefix.
    #[test]
    fn test_error_types() {
        let network_error = Error::Network("Connection failed".to_string());
        assert!(network_error.to_string().contains("Network error"));

        let server_error = Error::Server("Internal error".to_string());
        assert!(server_error.to_string().contains("Server returned error"));

        let io_error = Error::Io("Permission denied".to_string());
        assert!(io_error.to_string().contains("I/O error"));

        let validation_error = Error::Validation("Unsafe path".to_string());
        assert!(validation_error.to_string().contains("Validation error"));

        let serialization_error = Error::Serialization("JSON parse error".to_string());
        assert!(
            serialization_error
                .to_string()
                .contains("Serialization error")
        );
    }

    #[test]
    fn test_constants() {
        assert_eq!(constants::DEFAULT_GRPC_PORT.get(), 8001);
        assert_eq!(constants::DEFAULT_TIMEOUT_SECS, 30);
        assert_eq!(constants::DEFAULT_TRANSFER_CHUNK_SIZE, 32 * 1024);
    }

    // Field-level construction of success and failure envelopes.
    #[test]
    fn test_response_creation() {
        let success_response = Response {
            success: true,
            data: Some("test data".to_string()),
            error: None,
        };

        assert!(success_response.success);
        assert!(success_response.data.is_some());
        assert!(success_response.error.is_none());

        let error_response: Response<String> = Response {
            success: false,
            data: None,
            error: Some("test error".to_string()),
        };

        assert!(!error_response.success);
        assert!(error_response.data.is_none());
        assert!(error_response.error.is_some());
    }

    // Compares against the live environment.
    // NOTE(review): passes vacuously when the lookup fails (neither variable set).
    #[test]
    fn test_utils_get_home_dir() {
        let home_dir = Utils::get_home_dir();

        if let Ok(home_dir) = home_dir {
            assert!(!home_dir.is_empty());
            if let Ok(expected_home) = env::var("HOME") {
                assert_eq!(home_dir, expected_home);
            } else if let Ok(expected_home) = env::var("USERPROFILE") {
                assert_eq!(home_dir, expected_home);
            }
        }
    }
}