Skip to main content

modelexpress_common/
lib.rs

1// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use serde::{Deserialize, Serialize};
5use std::env;
6
7pub mod cache;
8pub mod client_config;
9pub mod config;
10pub mod download;
11pub mod models;
12pub mod providers;
13#[cfg(any(test, feature = "test-support"))]
14#[doc(hidden)]
15pub mod test_support;
16
17// Generated gRPC code
18#[allow(clippy::similar_names)]
19#[allow(clippy::default_trait_access)]
20#[allow(clippy::doc_markdown)]
21#[allow(clippy::must_use_candidate)]
22pub mod grpc {
23    pub mod health {
24        tonic::include_proto!("model_express.health");
25    }
26    pub mod api {
27        tonic::include_proto!("model_express.api");
28    }
29    pub mod model {
30        tonic::include_proto!("model_express.model");
31    }
32    pub mod p2p {
33        tonic::include_proto!("model_express.p2p");
34    }
35}
36
37/// Defines the shared response format between server and client (legacy HTTP)
38#[derive(Debug, Clone, Serialize, Deserialize)]
39pub struct Response<T> {
40    pub success: bool,
41    pub data: Option<T>,
42    pub error: Option<String>,
43}
44
45/// Common error types that both client and server can use
46#[derive(Debug, thiserror::Error)]
47pub enum Error {
48    #[error("Network error: {0}")]
49    Network(String),
50
51    #[error("Server returned error: {0}")]
52    Server(String),
53
54    #[error("I/O error: {0}")]
55    Io(String),
56
57    #[error("Validation error: {0}")]
58    Validation(String),
59
60    #[error("Serialization error: {0}")]
61    Serialization(String),
62
63    #[error("gRPC error: {0}")]
64    Grpc(#[from] tonic::Status),
65
66    #[error("Transport error: {0}")]
67    Transport(#[from] tonic::transport::Error),
68
69    #[error("Generic error: {0}")]
70    Generic(String),
71}
72
73// Implement From traits for Box<Error> to work with the Result<T> type
74impl From<tonic::Status> for Box<Error> {
75    fn from(err: tonic::Status) -> Self {
76        Box::new(Error::Grpc(err))
77    }
78}
79
80impl From<tonic::transport::Error> for Box<Error> {
81    fn from(err: tonic::transport::Error) -> Self {
82        Box::new(Error::Transport(err))
83    }
84}
85
86/// Common result type for the project
87pub type Result<T> = std::result::Result<T, Box<Error>>;
88
89/// Marker struct to use Utils methods
90pub struct Utils;
91
92impl Utils {
93    /// Get home directory from environment variables
94    pub fn get_home_dir() -> std::result::Result<String, Box<Error>> {
95        env::var("HOME")
96            .or_else(|_| env::var("USERPROFILE"))
97            .map_err(|e| Error::Generic(format!("Failed to get home directory: {e}")).into())
98    }
99}
100
/// Constants shared between client and server.
pub mod constants {
    use std::num::NonZeroU16;

    // Filesystem locations. These are relative paths; presumably they are resolved
    // against the user's home directory by callers (TODO confirm at call sites).

    /// Default location of the model cache.
    pub const DEFAULT_CACHE_PATH: &str = ".model-express/cache";
    /// Default location of the Hugging Face hub cache.
    pub const DEFAULT_HF_CACHE_PATH: &str = ".cache/huggingface/hub";
    /// Default location of the YAML configuration file.
    pub const DEFAULT_CONFIG_PATH: &str = ".model-express/config.yaml";

    // Networking.

    /// Default gRPC port. `NonZeroU16` rules out port 0 at the type level.
    pub const DEFAULT_GRPC_PORT: NonZeroU16 = match NonZeroU16::new(8001) {
        Some(port) => port,
        None => panic!("8001 is non-zero"),
    };
    /// Default timeout, in seconds.
    pub const DEFAULT_TIMEOUT_SECS: u64 = 30;

    // Storage and transfer.

    /// Default setting for shared storage mode (true = client and server share a network drive).
    pub const DEFAULT_SHARED_STORAGE: bool = true;

    /// Default chunk size for file transfer streaming, in bytes (32 KiB).
    pub const DEFAULT_TRANSFER_CHUNK_SIZE: usize = 32 * 1024;
}
118
119// Conversion utilities between gRPC and legacy models
120impl From<&models::Status> for grpc::health::HealthResponse {
121    fn from(status: &models::Status) -> Self {
122        Self {
123            version: status.version.clone(),
124            status: status.status.clone(),
125            uptime: status.uptime,
126        }
127    }
128}
129
130impl From<grpc::health::HealthResponse> for models::Status {
131    fn from(response: grpc::health::HealthResponse) -> Self {
132        Self {
133            version: response.version,
134            status: response.status,
135            uptime: response.uptime,
136        }
137    }
138}
139
140impl From<models::ModelProvider> for grpc::model::ModelProvider {
141    fn from(provider: models::ModelProvider) -> Self {
142        match provider {
143            models::ModelProvider::HuggingFace => grpc::model::ModelProvider::HuggingFace,
144        }
145    }
146}
147
148impl From<grpc::model::ModelProvider> for models::ModelProvider {
149    fn from(provider: grpc::model::ModelProvider) -> Self {
150        match provider {
151            grpc::model::ModelProvider::HuggingFace => models::ModelProvider::HuggingFace,
152        }
153    }
154}
155
156impl From<models::ModelStatus> for grpc::model::ModelStatus {
157    fn from(status: models::ModelStatus) -> Self {
158        match status {
159            models::ModelStatus::DOWNLOADING => grpc::model::ModelStatus::Downloading,
160            models::ModelStatus::DOWNLOADED => grpc::model::ModelStatus::Downloaded,
161            models::ModelStatus::ERROR => grpc::model::ModelStatus::Error,
162        }
163    }
164}
165
166impl From<grpc::model::ModelStatus> for models::ModelStatus {
167    fn from(status: grpc::model::ModelStatus) -> Self {
168        match status {
169            grpc::model::ModelStatus::Downloading => models::ModelStatus::DOWNLOADING,
170            grpc::model::ModelStatus::Downloaded => models::ModelStatus::DOWNLOADED,
171            grpc::model::ModelStatus::Error => models::ModelStatus::ERROR,
172        }
173    }
174}
175
176impl From<&models::ModelStatusResponse> for grpc::model::ModelStatusUpdate {
177    fn from(response: &models::ModelStatusResponse) -> Self {
178        Self {
179            model_name: response.model_name.clone(),
180            status: grpc::model::ModelStatus::from(response.status) as i32,
181            message: None,
182            provider: grpc::model::ModelProvider::from(response.provider) as i32,
183        }
184    }
185}
186
187impl From<grpc::model::ModelStatusUpdate> for models::ModelStatusResponse {
188    fn from(update: grpc::model::ModelStatusUpdate) -> Self {
189        Self {
190            model_name: update.model_name,
191            status: grpc::model::ModelStatus::try_from(update.status)
192                .unwrap_or(grpc::model::ModelStatus::Error)
193                .into(),
194            provider: grpc::model::ModelProvider::try_from(update.provider)
195                .unwrap_or(grpc::model::ModelProvider::HuggingFace)
196                .into(),
197        }
198    }
199}
200
#[cfg(test)]
mod tests {
    use super::*;
    use std::env;

    #[test]
    fn test_status_conversion_from_models_to_grpc() {
        let status = models::Status {
            version: "1.0.0".to_string(),
            status: "ok".to_string(),
            uptime: 3600,
        };

        let grpc_response = grpc::health::HealthResponse::from(&status);

        assert_eq!(grpc_response.version, status.version);
        assert_eq!(grpc_response.status, status.status);
        assert_eq!(grpc_response.uptime, status.uptime);
    }

    #[test]
    fn test_status_conversion_from_grpc_to_models() {
        let grpc_response = grpc::health::HealthResponse {
            version: "1.0.0".to_string(),
            status: "ok".to_string(),
            uptime: 3600,
        };

        let status = models::Status::from(grpc_response);

        assert_eq!(status.version, "1.0.0");
        assert_eq!(status.status, "ok");
        assert_eq!(status.uptime, 3600);
    }

    #[test]
    fn test_model_provider_conversion_both_ways() {
        let model_provider = models::ModelProvider::HuggingFace;
        let grpc_provider: grpc::model::ModelProvider = model_provider.into();
        let back_to_model: models::ModelProvider = grpc_provider.into();

        assert_eq!(model_provider, back_to_model);
    }

    #[test]
    fn test_model_status_conversion_both_ways() {
        // A fixed array is enough to iterate; no heap allocation needed.
        for status in [
            models::ModelStatus::DOWNLOADING,
            models::ModelStatus::DOWNLOADED,
            models::ModelStatus::ERROR,
        ] {
            let grpc_status: grpc::model::ModelStatus = status.into();
            let back_to_model: models::ModelStatus = grpc_status.into();
            assert_eq!(status, back_to_model);
        }
    }

    #[test]
    fn test_model_status_response_conversion_from_models_to_grpc() {
        let response = models::ModelStatusResponse {
            model_name: "test-model".to_string(),
            status: models::ModelStatus::DOWNLOADED,
            provider: models::ModelProvider::HuggingFace,
        };

        let grpc_update: grpc::model::ModelStatusUpdate = (&response).into();

        assert_eq!(grpc_update.model_name, response.model_name);
        assert_eq!(
            grpc_update.status,
            grpc::model::ModelStatus::Downloaded as i32
        );
        assert_eq!(
            grpc_update.provider,
            grpc::model::ModelProvider::HuggingFace as i32
        );
        assert!(grpc_update.message.is_none());
    }

    #[test]
    fn test_model_status_response_conversion_from_grpc_to_models() {
        let grpc_update = grpc::model::ModelStatusUpdate {
            model_name: "test-model".to_string(),
            status: grpc::model::ModelStatus::Downloaded as i32,
            message: Some("Test message".to_string()),
            provider: grpc::model::ModelProvider::HuggingFace as i32,
        };

        let response: models::ModelStatusResponse = grpc_update.into();

        assert_eq!(response.model_name, "test-model");
        assert_eq!(response.status, models::ModelStatus::DOWNLOADED);
        assert_eq!(response.provider, models::ModelProvider::HuggingFace);
    }

    #[test]
    fn test_model_status_response_conversion_tolerates_unknown_enum_values() {
        // Raw wire values outside the proto enums must not panic. They fall back to
        // ERROR / HuggingFace.
        let grpc_update = grpc::model::ModelStatusUpdate {
            model_name: "test-model".to_string(),
            status: i32::MAX,
            message: None,
            provider: i32::MAX,
        };

        let response: models::ModelStatusResponse = grpc_update.into();

        assert_eq!(response.model_name, "test-model");
        assert_eq!(response.status, models::ModelStatus::ERROR);
        assert_eq!(response.provider, models::ModelProvider::HuggingFace);
    }

    #[test]
    fn test_error_types() {
        let cases = [
            (Error::Network("Connection failed".to_string()), "Network error"),
            (Error::Server("Internal error".to_string()), "Server returned error"),
            (Error::Io("Permission denied".to_string()), "I/O error"),
            (Error::Validation("Unsafe path".to_string()), "Validation error"),
            (Error::Serialization("JSON parse error".to_string()), "Serialization error"),
            (Error::Generic("Unexpected state".to_string()), "Generic error"),
        ];

        for (error, expected_prefix) in cases {
            let message = error.to_string();
            assert!(
                message.starts_with(expected_prefix),
                "unexpected message: {message}"
            );
        }
    }

    #[test]
    fn test_tonic_status_converts_into_boxed_error() {
        let boxed: Box<Error> = tonic::Status::not_found("missing").into();

        assert!(matches!(*boxed, Error::Grpc(_)));
        assert!(boxed.to_string().starts_with("gRPC error"));
    }

    #[test]
    fn test_constants() {
        assert_eq!(constants::DEFAULT_GRPC_PORT.get(), 8001);
        assert_eq!(constants::DEFAULT_TIMEOUT_SECS, 30);
        assert_eq!(constants::DEFAULT_TRANSFER_CHUNK_SIZE, 32 * 1024);
    }

    #[test]
    fn test_response_creation() {
        let success_response = Response {
            success: true,
            data: Some("test data".to_string()),
            error: None,
        };

        assert!(success_response.success);
        assert!(success_response.data.is_some());
        assert!(success_response.error.is_none());

        let error_response: Response<String> = Response {
            success: false,
            data: None,
            error: Some("test error".to_string()),
        };

        assert!(!error_response.success);
        assert!(error_response.data.is_none());
        assert!(error_response.error.is_some());
    }

    #[test]
    fn test_utils_get_home_dir() {
        let Ok(home_dir) = Utils::get_home_dir() else {
            // No usable home directory in this environment; nothing to compare.
            return;
        };
        assert!(!home_dir.is_empty());

        // Mirror the lookup order: HOME first, then USERPROFILE, skipping empty values.
        let non_empty = |key: &str| env::var(key).ok().filter(|value| !value.is_empty());
        let expected = non_empty("HOME").or_else(|| non_empty("USERPROFILE"));
        assert_eq!(Some(home_dir), expected);
    }
}