Skip to main content

modelexpress_common/
lib.rs

1// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use serde::{Deserialize, Serialize};
5use std::env;
6use std::error::Error as StdError;
7
8pub mod cache;
9pub mod client_config;
10pub mod config;
11pub mod download;
12pub mod models;
13pub mod providers;
14#[cfg(any(test, feature = "test-support"))]
15#[doc(hidden)]
16pub mod test_support;
17
18// Generated gRPC code
19#[allow(clippy::similar_names)]
20#[allow(clippy::default_trait_access)]
21#[allow(clippy::doc_markdown)]
22#[allow(clippy::must_use_candidate)]
23pub mod grpc {
24    pub mod health {
25        tonic::include_proto!("model_express.health");
26    }
27    pub mod api {
28        tonic::include_proto!("model_express.api");
29    }
30    pub mod model {
31        tonic::include_proto!("model_express.model");
32    }
33    pub mod p2p {
34        tonic::include_proto!("model_express.p2p");
35    }
36}
37
38/// Defines the shared response format between server and client (legacy HTTP)
39#[derive(Debug, Clone, Serialize, Deserialize)]
40pub struct Response<T> {
41    pub success: bool,
42    pub data: Option<T>,
43    pub error: Option<String>,
44}
45
46/// Common error types that both client and server can use
47#[derive(Debug, thiserror::Error)]
48pub enum Error {
49    #[error("Network error: {0}")]
50    Network(String),
51
52    #[error("Server returned error: {0}")]
53    Server(String),
54
55    #[error("I/O error: {0}")]
56    Io(String),
57
58    #[error("Validation error: {0}")]
59    Validation(String),
60
61    #[error("Serialization error: {0}")]
62    Serialization(String),
63
64    #[error("gRPC error: {0}")]
65    Grpc(#[from] tonic::Status),
66
67    #[error("Transport error: {0}")]
68    Transport(String),
69
70    #[error("Generic error: {0}")]
71    Generic(String),
72}
73
74fn format_error_chain(err: &(dyn StdError + 'static)) -> String {
75    let mut parts = Vec::new();
76    let mut current = Some(err);
77
78    while let Some(error) = current {
79        let part = error.to_string();
80        if !part.is_empty() && parts.last() != Some(&part) {
81            parts.push(part);
82        }
83        current = error.source();
84    }
85
86    if parts.len() > 1 && parts.first().is_some_and(|part| part == "transport error") {
87        parts.remove(0);
88    }
89
90    if parts.is_empty() {
91        "transport error".to_string()
92    } else {
93        parts.join(": ")
94    }
95}
96
97// Implement From traits for Box<Error> to work with the Result<T> type
98impl From<tonic::Status> for Box<Error> {
99    fn from(err: tonic::Status) -> Self {
100        Box::new(Error::Grpc(err))
101    }
102}
103
104impl From<tonic::transport::Error> for Error {
105    fn from(err: tonic::transport::Error) -> Self {
106        Error::Transport(format_error_chain(&err))
107    }
108}
109
110impl From<tonic::transport::Error> for Box<Error> {
111    fn from(err: tonic::transport::Error) -> Self {
112        Box::new(Error::from(err))
113    }
114}
115
116/// Common result type for the project
117pub type Result<T> = std::result::Result<T, Box<Error>>;
118
119/// Marker struct to use Utils methods
120pub struct Utils;
121
122impl Utils {
123    /// Get home directory from environment variables
124    pub fn get_home_dir() -> std::result::Result<String, Box<Error>> {
125        env::var("HOME")
126            .or_else(|_| env::var("USERPROFILE"))
127            .map_err(|e| Error::Generic(format!("Failed to get home directory: {e}")).into())
128    }
129}
130
/// Constants shared between client and server.
pub mod constants {
    use std::num::NonZeroU16;

    /// Bytes per kibibyte; used to spell out buffer sizes below.
    const KIB: usize = 1024;

    /// Default ModelExpress cache directory, as a relative path.
    pub const DEFAULT_CACHE_PATH: &str = ".model-express/cache";
    /// Default Hugging Face hub cache directory, as a relative path.
    pub const DEFAULT_HF_CACHE_PATH: &str = ".cache/huggingface/hub";
    /// Default configuration file, as a relative path.
    pub const DEFAULT_CONFIG_PATH: &str = ".model-express/config.yaml";

    /// Default gRPC port.
    ///
    /// `NonZeroU16` makes port 0 unrepresentable. The `expect` runs during const
    /// evaluation, so a zero literal here would fail the build rather than panic at runtime.
    pub const DEFAULT_GRPC_PORT: NonZeroU16 = NonZeroU16::new(8001).expect("8001 is non-zero");
    /// Default timeout, in seconds.
    pub const DEFAULT_TIMEOUT_SECS: u64 = 30;

    /// Default setting for shared storage mode (true = client and server share a network drive).
    pub const DEFAULT_SHARED_STORAGE: bool = true;

    /// Default chunk size for file transfer streaming, in bytes (32 KiB).
    pub const DEFAULT_TRANSFER_CHUNK_SIZE: usize = 32 * KIB;
}
148
149// Conversion utilities between gRPC and legacy models
150impl From<&models::Status> for grpc::health::HealthResponse {
151    fn from(status: &models::Status) -> Self {
152        Self {
153            version: status.version.clone(),
154            status: status.status.clone(),
155            uptime: status.uptime,
156        }
157    }
158}
159
160impl From<grpc::health::HealthResponse> for models::Status {
161    fn from(response: grpc::health::HealthResponse) -> Self {
162        Self {
163            version: response.version,
164            status: response.status,
165            uptime: response.uptime,
166        }
167    }
168}
169
170impl From<models::ModelProvider> for grpc::model::ModelProvider {
171    fn from(provider: models::ModelProvider) -> Self {
172        match provider {
173            models::ModelProvider::HuggingFace => grpc::model::ModelProvider::HuggingFace,
174            models::ModelProvider::Ngc => grpc::model::ModelProvider::Ngc,
175            models::ModelProvider::Gcs => grpc::model::ModelProvider::Gcs,
176        }
177    }
178}
179
180impl From<grpc::model::ModelProvider> for models::ModelProvider {
181    fn from(provider: grpc::model::ModelProvider) -> Self {
182        match provider {
183            grpc::model::ModelProvider::HuggingFace => models::ModelProvider::HuggingFace,
184            grpc::model::ModelProvider::Ngc => models::ModelProvider::Ngc,
185            grpc::model::ModelProvider::Gcs => models::ModelProvider::Gcs,
186        }
187    }
188}
189
190impl From<models::ModelStatus> for grpc::model::ModelStatus {
191    fn from(status: models::ModelStatus) -> Self {
192        match status {
193            models::ModelStatus::DOWNLOADING => grpc::model::ModelStatus::Downloading,
194            models::ModelStatus::DOWNLOADED => grpc::model::ModelStatus::Downloaded,
195            models::ModelStatus::ERROR => grpc::model::ModelStatus::Error,
196        }
197    }
198}
199
200impl From<grpc::model::ModelStatus> for models::ModelStatus {
201    fn from(status: grpc::model::ModelStatus) -> Self {
202        match status {
203            grpc::model::ModelStatus::Downloading => models::ModelStatus::DOWNLOADING,
204            grpc::model::ModelStatus::Downloaded => models::ModelStatus::DOWNLOADED,
205            grpc::model::ModelStatus::Error => models::ModelStatus::ERROR,
206        }
207    }
208}
209
210impl From<&models::ModelStatusResponse> for grpc::model::ModelStatusUpdate {
211    fn from(response: &models::ModelStatusResponse) -> Self {
212        Self {
213            model_name: response.model_name.clone(),
214            status: grpc::model::ModelStatus::from(response.status) as i32,
215            message: None,
216            provider: grpc::model::ModelProvider::from(response.provider) as i32,
217        }
218    }
219}
220
221impl From<grpc::model::ModelStatusUpdate> for models::ModelStatusResponse {
222    fn from(update: grpc::model::ModelStatusUpdate) -> Self {
223        Self {
224            model_name: update.model_name,
225            status: grpc::model::ModelStatus::try_from(update.status)
226                .unwrap_or(grpc::model::ModelStatus::Error)
227                .into(),
228            provider: grpc::model::ModelProvider::try_from(update.provider)
229                .unwrap_or(grpc::model::ModelProvider::HuggingFace)
230                .into(),
231        }
232    }
233}
234
#[cfg(test)]
mod tests {
    use super::*;
    use std::env;
    use std::io;

    #[test]
    fn test_status_conversion_from_models_to_grpc() {
        let status = models::Status {
            version: "1.0.0".to_string(),
            status: "ok".to_string(),
            uptime: 3600,
        };

        let grpc_response: grpc::health::HealthResponse = (&status).into();

        assert_eq!(grpc_response.version, status.version);
        assert_eq!(grpc_response.status, status.status);
        assert_eq!(grpc_response.uptime, status.uptime);
    }

    #[derive(Debug, thiserror::Error)]
    #[error("outer error")]
    struct OuterError(#[source] io::Error);

    #[derive(Debug, thiserror::Error)]
    #[error("transport error")]
    struct TransportWrapper(#[source] io::Error);

    /// Error whose message is identical to its source's, used to check de-duplication.
    #[derive(Debug, thiserror::Error)]
    #[error("connection refused")]
    struct RepeatingError(#[source] io::Error);

    #[test]
    fn test_format_error_chain_includes_nested_causes() {
        let err = OuterError(io::Error::other("connection reset by peer"));
        assert_eq!(
            format_error_chain(&err),
            "outer error: connection reset by peer"
        );
    }

    #[test]
    fn test_format_error_chain_skips_repeated_transport_prefix() {
        let err = TransportWrapper(io::Error::other("underlying cause"));
        assert_eq!(format_error_chain(&err), "underlying cause");
    }

    #[test]
    fn test_format_error_chain_collapses_consecutive_duplicates() {
        let err = RepeatingError(io::Error::other("connection refused"));
        assert_eq!(format_error_chain(&err), "connection refused");
    }

    #[test]
    fn test_format_error_chain_keeps_lone_transport_error() {
        // The empty inner message is skipped, leaving only the prefix, which must survive.
        let err = TransportWrapper(io::Error::other(""));
        assert_eq!(format_error_chain(&err), "transport error");
    }

    #[test]
    fn test_format_error_chain_falls_back_when_every_message_is_empty() {
        assert_eq!(format_error_chain(&io::Error::other("")), "transport error");
    }

    #[test]
    fn test_status_conversion_from_grpc_to_models() {
        let grpc_response = grpc::health::HealthResponse {
            version: "1.0.0".to_string(),
            status: "ok".to_string(),
            uptime: 3600,
        };

        let status: models::Status = grpc_response.into();

        assert_eq!(status.version, "1.0.0");
        assert_eq!(status.status, "ok");
        assert_eq!(status.uptime, 3600);
    }

    #[test]
    fn test_model_provider_conversion_both_ways() {
        for model_provider in [
            models::ModelProvider::HuggingFace,
            models::ModelProvider::Ngc,
            models::ModelProvider::Gcs,
        ] {
            let grpc_provider: grpc::model::ModelProvider = model_provider.into();
            let back_to_model: models::ModelProvider = grpc_provider.into();
            assert_eq!(model_provider, back_to_model);
        }
    }

    #[test]
    fn test_model_status_conversion_both_ways() {
        let statuses = [
            models::ModelStatus::DOWNLOADING,
            models::ModelStatus::DOWNLOADED,
            models::ModelStatus::ERROR,
        ];

        for status in statuses {
            let grpc_status: grpc::model::ModelStatus = status.into();
            let back_to_model: models::ModelStatus = grpc_status.into();
            assert_eq!(status, back_to_model);
        }
    }

    #[test]
    fn test_model_status_response_conversion_from_models_to_grpc() {
        let response = models::ModelStatusResponse {
            model_name: "test-model".to_string(),
            status: models::ModelStatus::DOWNLOADED,
            provider: models::ModelProvider::HuggingFace,
        };

        let grpc_update: grpc::model::ModelStatusUpdate = (&response).into();

        assert_eq!(grpc_update.model_name, response.model_name);
        assert_eq!(
            grpc_update.status,
            grpc::model::ModelStatus::Downloaded as i32
        );
        assert_eq!(
            grpc_update.provider,
            grpc::model::ModelProvider::HuggingFace as i32
        );
        assert!(grpc_update.message.is_none());
    }

    #[test]
    fn test_model_status_response_conversion_from_grpc_to_models() {
        let grpc_update = grpc::model::ModelStatusUpdate {
            model_name: "test-model".to_string(),
            status: grpc::model::ModelStatus::Downloaded as i32,
            message: Some("Test message".to_string()),
            provider: grpc::model::ModelProvider::HuggingFace as i32,
        };

        let response: models::ModelStatusResponse = grpc_update.into();

        assert_eq!(response.model_name, "test-model");
        assert_eq!(response.status, models::ModelStatus::DOWNLOADED);
        assert_eq!(response.provider, models::ModelProvider::HuggingFace);
    }

    #[test]
    fn test_model_status_response_conversion_falls_back_on_unknown_codes() {
        // Out-of-range wire values must decode leniently instead of failing.
        let grpc_update = grpc::model::ModelStatusUpdate {
            model_name: "test-model".to_string(),
            status: i32::MAX,
            message: None,
            provider: i32::MAX,
        };

        let response: models::ModelStatusResponse = grpc_update.into();

        assert_eq!(response.status, models::ModelStatus::ERROR);
        assert_eq!(response.provider, models::ModelProvider::HuggingFace);
    }

    #[test]
    fn test_error_types() {
        let network_error = Error::Network("Connection failed".to_string());
        assert!(network_error.to_string().contains("Network error"));

        let server_error = Error::Server("Internal error".to_string());
        assert!(server_error.to_string().contains("Server returned error"));

        let io_error = Error::Io("Permission denied".to_string());
        assert!(io_error.to_string().contains("I/O error"));

        let validation_error = Error::Validation("Unsafe path".to_string());
        assert!(validation_error.to_string().contains("Validation error"));

        let serialization_error = Error::Serialization("JSON parse error".to_string());
        assert!(
            serialization_error
                .to_string()
                .contains("Serialization error")
        );

        let transport_error = Error::Transport("connection refused".to_string());
        assert_eq!(
            transport_error.to_string(),
            "Transport error: connection refused"
        );

        let generic_error = Error::Generic("something odd".to_string());
        assert_eq!(generic_error.to_string(), "Generic error: something odd");
    }

    #[test]
    fn test_grpc_status_converts_into_boxed_error() {
        let boxed: Box<Error> = tonic::Status::not_found("missing").into();

        match *boxed {
            Error::Grpc(ref status) => {
                assert_eq!(status.code(), tonic::Code::NotFound);
                assert_eq!(status.message(), "missing");
            }
            ref other => panic!("expected Error::Grpc, got {other:?}"),
        }
        assert!(boxed.to_string().starts_with("gRPC error: "));
    }

    #[test]
    fn test_constants() {
        assert_eq!(constants::DEFAULT_GRPC_PORT.get(), 8001);
        assert_eq!(constants::DEFAULT_TIMEOUT_SECS, 30);
        assert_eq!(constants::DEFAULT_TRANSFER_CHUNK_SIZE, 32 * 1024);
    }

    #[test]
    fn test_response_creation() {
        let success_response = Response {
            success: true,
            data: Some("test data".to_string()),
            error: None,
        };

        assert!(success_response.success);
        assert!(success_response.data.is_some());
        assert!(success_response.error.is_none());

        let error_response: Response<String> = Response {
            success: false,
            data: None,
            error: Some("test error".to_string()),
        };

        assert!(!error_response.success);
        assert!(error_response.data.is_none());
        assert!(error_response.error.is_some());
    }

    #[test]
    fn test_utils_get_home_dir() {
        let home_dir = Utils::get_home_dir();

        if let Ok(home_dir) = home_dir {
            assert!(!home_dir.is_empty());
            // Check against HOME or USERPROFILE
            if let Ok(expected_home) = env::var("HOME") {
                assert_eq!(home_dir, expected_home);
            } else if let Ok(expected_home) = env::var("USERPROFILE") {
                assert_eq!(home_dir, expected_home);
            }
        }
    }
}