Skip to main content

modelexpress_common/
lib.rs

1// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use serde::{Deserialize, Serialize};
5use std::env;
6
7pub mod cache;
8pub mod client_config;
9pub mod config;
10pub mod download;
11pub mod models;
12pub mod providers;
13
14// Generated gRPC code
15#[allow(clippy::similar_names)]
16#[allow(clippy::default_trait_access)]
17#[allow(clippy::doc_markdown)]
18#[allow(clippy::must_use_candidate)]
19pub mod grpc {
20    pub mod health {
21        tonic::include_proto!("model_express.health");
22    }
23    pub mod api {
24        tonic::include_proto!("model_express.api");
25    }
26    pub mod model {
27        tonic::include_proto!("model_express.model");
28    }
29}
30
31/// Defines the shared response format between server and client (legacy HTTP)
32#[derive(Debug, Clone, Serialize, Deserialize)]
33pub struct Response<T> {
34    pub success: bool,
35    pub data: Option<T>,
36    pub error: Option<String>,
37}
38
39/// Common error types that both client and server can use
40#[derive(Debug, thiserror::Error)]
41pub enum Error {
42    #[error("Network error: {0}")]
43    Network(String),
44
45    #[error("Server returned error: {0}")]
46    Server(String),
47
48    #[error("Serialization error: {0}")]
49    Serialization(String),
50
51    #[error("gRPC error: {0}")]
52    Grpc(#[from] tonic::Status),
53
54    #[error("Transport error: {0}")]
55    Transport(#[from] tonic::transport::Error),
56
57    #[error("Generic error: {0}")]
58    Generic(String),
59}
60
61// Implement From traits for Box<Error> to work with the Result<T> type
62impl From<tonic::Status> for Box<Error> {
63    fn from(err: tonic::Status) -> Self {
64        Box::new(Error::Grpc(err))
65    }
66}
67
68impl From<tonic::transport::Error> for Box<Error> {
69    fn from(err: tonic::transport::Error) -> Self {
70        Box::new(Error::Transport(err))
71    }
72}
73
74/// Common result type for the project
75pub type Result<T> = std::result::Result<T, Box<Error>>;
76
77/// Marker struct to use Utils methods
78pub struct Utils;
79
80impl Utils {
81    /// Get home directory from environment variables
82    pub fn get_home_dir() -> std::result::Result<String, Box<Error>> {
83        env::var("HOME")
84            .or_else(|_| env::var("USERPROFILE"))
85            .map_err(|e| Error::Generic(format!("Failed to get home directory: {e}")).into())
86    }
87}
88
/// Constants shared between client and server.
pub mod constants {
    use std::num::NonZeroU16;

    // The path defaults are relative (no leading `/`); callers choose the base
    // directory they are resolved against.

    /// Default Model Express cache directory (relative path).
    pub const DEFAULT_CACHE_PATH: &str = ".model-express/cache";
    /// Default Hugging Face hub cache directory (relative path).
    pub const DEFAULT_HF_CACHE_PATH: &str = ".cache/huggingface/hub";
    /// Default configuration file location (relative path).
    pub const DEFAULT_CONFIG_PATH: &str = ".model-express/config.yaml";

    /// Default gRPC port. Evaluated at compile time, so the `None` arm can never run.
    pub const DEFAULT_GRPC_PORT: NonZeroU16 = match NonZeroU16::new(8001) {
        Some(port) => port,
        None => panic!("8001 is non-zero"),
    };
    /// Default timeout, in seconds.
    pub const DEFAULT_TIMEOUT_SECS: u64 = 30;

    /// Default shared-storage mode (`true` = client and server share a network drive).
    pub const DEFAULT_SHARED_STORAGE: bool = true;

    /// Default chunk size for file-transfer streaming, in bytes (32 KiB).
    pub const DEFAULT_TRANSFER_CHUNK_SIZE: usize = 32 * 1024;
}
106
107// Conversion utilities between gRPC and legacy models
108impl From<&models::Status> for grpc::health::HealthResponse {
109    fn from(status: &models::Status) -> Self {
110        Self {
111            version: status.version.clone(),
112            status: status.status.clone(),
113            uptime: status.uptime,
114        }
115    }
116}
117
118impl From<grpc::health::HealthResponse> for models::Status {
119    fn from(response: grpc::health::HealthResponse) -> Self {
120        Self {
121            version: response.version,
122            status: response.status,
123            uptime: response.uptime,
124        }
125    }
126}
127
128impl From<models::ModelProvider> for grpc::model::ModelProvider {
129    fn from(provider: models::ModelProvider) -> Self {
130        match provider {
131            models::ModelProvider::HuggingFace => grpc::model::ModelProvider::HuggingFace,
132        }
133    }
134}
135
136impl From<grpc::model::ModelProvider> for models::ModelProvider {
137    fn from(provider: grpc::model::ModelProvider) -> Self {
138        match provider {
139            grpc::model::ModelProvider::HuggingFace => models::ModelProvider::HuggingFace,
140        }
141    }
142}
143
144impl From<models::ModelStatus> for grpc::model::ModelStatus {
145    fn from(status: models::ModelStatus) -> Self {
146        match status {
147            models::ModelStatus::DOWNLOADING => grpc::model::ModelStatus::Downloading,
148            models::ModelStatus::DOWNLOADED => grpc::model::ModelStatus::Downloaded,
149            models::ModelStatus::ERROR => grpc::model::ModelStatus::Error,
150        }
151    }
152}
153
154impl From<grpc::model::ModelStatus> for models::ModelStatus {
155    fn from(status: grpc::model::ModelStatus) -> Self {
156        match status {
157            grpc::model::ModelStatus::Downloading => models::ModelStatus::DOWNLOADING,
158            grpc::model::ModelStatus::Downloaded => models::ModelStatus::DOWNLOADED,
159            grpc::model::ModelStatus::Error => models::ModelStatus::ERROR,
160        }
161    }
162}
163
164impl From<&models::ModelStatusResponse> for grpc::model::ModelStatusUpdate {
165    fn from(response: &models::ModelStatusResponse) -> Self {
166        Self {
167            model_name: response.model_name.clone(),
168            status: grpc::model::ModelStatus::from(response.status) as i32,
169            message: None,
170            provider: grpc::model::ModelProvider::from(response.provider) as i32,
171        }
172    }
173}
174
175impl From<grpc::model::ModelStatusUpdate> for models::ModelStatusResponse {
176    fn from(update: grpc::model::ModelStatusUpdate) -> Self {
177        Self {
178            model_name: update.model_name,
179            status: grpc::model::ModelStatus::try_from(update.status)
180                .unwrap_or(grpc::model::ModelStatus::Error)
181                .into(),
182            provider: grpc::model::ModelProvider::try_from(update.provider)
183                .unwrap_or(grpc::model::ModelProvider::HuggingFace)
184                .into(),
185        }
186    }
187}
188
#[cfg(test)]
mod tests {
    use super::*;
    use std::env;

    #[test]
    fn test_status_conversion_from_models_to_grpc() {
        let status = models::Status {
            version: "1.0.0".to_string(),
            status: "ok".to_string(),
            uptime: 3600,
        };

        let grpc_response: grpc::health::HealthResponse = (&status).into();

        assert_eq!(grpc_response.version, status.version);
        assert_eq!(grpc_response.status, status.status);
        assert_eq!(grpc_response.uptime, status.uptime);
    }

    #[test]
    fn test_status_conversion_from_grpc_to_models() {
        let grpc_response = grpc::health::HealthResponse {
            version: "1.0.0".to_string(),
            status: "ok".to_string(),
            uptime: 3600,
        };

        let status: models::Status = grpc_response.into();

        assert_eq!(status.version, "1.0.0");
        assert_eq!(status.status, "ok");
        assert_eq!(status.uptime, 3600);
    }

    #[test]
    fn test_model_provider_conversion_both_ways() {
        let model_provider = models::ModelProvider::HuggingFace;
        let grpc_provider: grpc::model::ModelProvider = model_provider.into();
        let back_to_model: models::ModelProvider = grpc_provider.into();

        assert_eq!(model_provider, back_to_model);
    }

    #[test]
    fn test_model_status_conversion_both_ways() {
        // A plain array is enough here; `vec!` would heap-allocate for nothing
        // (clippy::useless_vec).
        let statuses = [
            models::ModelStatus::DOWNLOADING,
            models::ModelStatus::DOWNLOADED,
            models::ModelStatus::ERROR,
        ];

        for status in statuses {
            let grpc_status: grpc::model::ModelStatus = status.into();
            let back_to_model: models::ModelStatus = grpc_status.into();
            assert_eq!(status, back_to_model);
        }
    }

    #[test]
    fn test_model_status_response_conversion_from_models_to_grpc() {
        let response = models::ModelStatusResponse {
            model_name: "test-model".to_string(),
            status: models::ModelStatus::DOWNLOADED,
            provider: models::ModelProvider::HuggingFace,
        };

        let grpc_update: grpc::model::ModelStatusUpdate = (&response).into();

        assert_eq!(grpc_update.model_name, response.model_name);
        assert_eq!(
            grpc_update.status,
            grpc::model::ModelStatus::Downloaded as i32
        );
        assert_eq!(
            grpc_update.provider,
            grpc::model::ModelProvider::HuggingFace as i32
        );
        assert!(grpc_update.message.is_none());
    }

    #[test]
    fn test_model_status_response_conversion_from_grpc_to_models() {
        let grpc_update = grpc::model::ModelStatusUpdate {
            model_name: "test-model".to_string(),
            status: grpc::model::ModelStatus::Downloaded as i32,
            message: Some("Test message".to_string()),
            provider: grpc::model::ModelProvider::HuggingFace as i32,
        };

        let response: models::ModelStatusResponse = grpc_update.into();

        assert_eq!(response.model_name, "test-model");
        assert_eq!(response.status, models::ModelStatus::DOWNLOADED);
        assert_eq!(response.provider, models::ModelProvider::HuggingFace);
    }

    /// Out-of-range enum values on the wire must not panic: the status
    /// degrades to `ERROR` and the provider to `HuggingFace`.
    #[test]
    fn test_model_status_response_conversion_unknown_values_fall_back() {
        let grpc_update = grpc::model::ModelStatusUpdate {
            model_name: "test-model".to_string(),
            status: i32::MAX,
            message: None,
            provider: i32::MAX,
        };

        let response: models::ModelStatusResponse = grpc_update.into();

        assert_eq!(response.model_name, "test-model");
        assert_eq!(response.status, models::ModelStatus::ERROR);
        assert_eq!(response.provider, models::ModelProvider::HuggingFace);
    }

    #[test]
    fn test_error_types() {
        let network_error = Error::Network("Connection failed".to_string());
        assert!(network_error.to_string().contains("Network error"));

        let server_error = Error::Server("Internal error".to_string());
        assert!(server_error.to_string().contains("Server returned error"));

        let serialization_error = Error::Serialization("JSON parse error".to_string());
        assert!(
            serialization_error
                .to_string()
                .contains("Serialization error")
        );

        let generic_error = Error::Generic("Something went wrong".to_string());
        assert!(generic_error.to_string().contains("Generic error"));
    }

    /// `?` on tonic results relies on the direct `Box<Error>` conversion.
    #[test]
    fn test_boxed_error_from_tonic_status() {
        let boxed: Box<Error> = tonic::Status::not_found("missing model").into();
        assert!(matches!(*boxed, Error::Grpc(_)));
    }

    #[test]
    fn test_constants() {
        assert_eq!(constants::DEFAULT_GRPC_PORT.get(), 8001);
        assert_eq!(constants::DEFAULT_TIMEOUT_SECS, 30);
        assert_eq!(constants::DEFAULT_TRANSFER_CHUNK_SIZE, 32 * 1024);
    }

    #[test]
    fn test_response_creation() {
        let success_response = Response {
            success: true,
            data: Some("test data".to_string()),
            error: None,
        };

        assert!(success_response.success);
        assert!(success_response.data.is_some());
        assert!(success_response.error.is_none());

        let error_response: Response<String> = Response {
            success: false,
            data: None,
            error: Some("test error".to_string()),
        };

        assert!(!error_response.success);
        assert!(error_response.data.is_none());
        assert!(error_response.error.is_some());
    }

    #[test]
    fn test_utils_get_home_dir() {
        let home_dir = Utils::get_home_dir();

        if let Ok(home_dir) = home_dir {
            assert!(!home_dir.is_empty());
            // Check against HOME or USERPROFILE
            if let Ok(expected_home) = env::var("HOME") {
                assert_eq!(home_dir, expected_home);
            } else if let Ok(expected_home) = env::var("USERPROFILE") {
                assert_eq!(home_dir, expected_home);
            }
        }
    }
}