modelexpress_common/
lib.rs

1// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use serde::{Deserialize, Serialize};
5use std::env;
6
7pub mod cache;
8pub mod client_config;
9pub mod config;
10pub mod download;
11pub mod models;
12pub mod providers;
13
14// Generated gRPC code
15#[allow(clippy::similar_names)]
16#[allow(clippy::default_trait_access)]
17#[allow(clippy::doc_markdown)]
18#[allow(clippy::must_use_candidate)]
19pub mod grpc {
20    pub mod health {
21        tonic::include_proto!("model_express.health");
22    }
23    pub mod api {
24        tonic::include_proto!("model_express.api");
25    }
26    pub mod model {
27        tonic::include_proto!("model_express.model");
28    }
29}
30
31/// Defines the shared response format between server and client (legacy HTTP)
32#[derive(Debug, Clone, Serialize, Deserialize)]
33pub struct Response<T> {
34    pub success: bool,
35    pub data: Option<T>,
36    pub error: Option<String>,
37}
38
39/// Common error types that both client and server can use
40#[derive(Debug, thiserror::Error)]
41pub enum Error {
42    #[error("Network error: {0}")]
43    Network(String),
44
45    #[error("Server returned error: {0}")]
46    Server(String),
47
48    #[error("Serialization error: {0}")]
49    Serialization(String),
50
51    #[error("gRPC error: {0}")]
52    Grpc(#[from] tonic::Status),
53
54    #[error("Transport error: {0}")]
55    Transport(#[from] tonic::transport::Error),
56
57    #[error("Generic error: {0}")]
58    Generic(String),
59}
60
61// Implement From traits for Box<Error> to work with the Result<T> type
62impl From<tonic::Status> for Box<Error> {
63    fn from(err: tonic::Status) -> Self {
64        Box::new(Error::Grpc(err))
65    }
66}
67
68impl From<tonic::transport::Error> for Box<Error> {
69    fn from(err: tonic::transport::Error) -> Self {
70        Box::new(Error::Transport(err))
71    }
72}
73
74/// Common result type for the project
75pub type Result<T> = std::result::Result<T, Box<Error>>;
76
77/// Marker struct to use Utils methods
78pub struct Utils;
79
80impl Utils {
81    /// Get home directory from environment variables
82    pub fn get_home_dir() -> std::result::Result<String, Box<Error>> {
83        env::var("HOME")
84            .or_else(|_| env::var("USERPROFILE"))
85            .map_err(|e| Error::Generic(format!("Failed to get home directory: {e}")).into())
86    }
87}
88
/// Constants shared between client and server
pub mod constants {
    use std::num::NonZeroU16;

    // Relative filesystem locations (no leading `/`).
    // NOTE(review): presumably joined onto the directory from `Utils::get_home_dir`; confirm
    // at the call sites.

    /// Default cache directory (relative path).
    pub const DEFAULT_CACHE_PATH: &str = ".model-express/cache";
    /// Default Hugging Face hub cache directory (relative path).
    pub const DEFAULT_HF_CACHE_PATH: &str = ".cache/huggingface/hub";
    /// Default configuration file location (relative path).
    pub const DEFAULT_CONFIG_PATH: &str = ".model-express/config.yaml";

    /// Default gRPC port. The non-zero check runs during constant evaluation, so a zero
    /// literal here would fail the build rather than panic at runtime.
    pub const DEFAULT_GRPC_PORT: NonZeroU16 = match NonZeroU16::new(8001) {
        Some(port) => port,
        None => panic!("8001 is non-zero"),
    };

    /// Default timeout, in seconds.
    pub const DEFAULT_TIMEOUT_SECS: u64 = 30;
}
100
101// Conversion utilities between gRPC and legacy models
102impl From<&models::Status> for grpc::health::HealthResponse {
103    fn from(status: &models::Status) -> Self {
104        Self {
105            version: status.version.clone(),
106            status: status.status.clone(),
107            uptime: status.uptime,
108        }
109    }
110}
111
112impl From<grpc::health::HealthResponse> for models::Status {
113    fn from(response: grpc::health::HealthResponse) -> Self {
114        Self {
115            version: response.version,
116            status: response.status,
117            uptime: response.uptime,
118        }
119    }
120}
121
122impl From<models::ModelProvider> for grpc::model::ModelProvider {
123    fn from(provider: models::ModelProvider) -> Self {
124        match provider {
125            models::ModelProvider::HuggingFace => grpc::model::ModelProvider::HuggingFace,
126        }
127    }
128}
129
130impl From<grpc::model::ModelProvider> for models::ModelProvider {
131    fn from(provider: grpc::model::ModelProvider) -> Self {
132        match provider {
133            grpc::model::ModelProvider::HuggingFace => models::ModelProvider::HuggingFace,
134        }
135    }
136}
137
138impl From<models::ModelStatus> for grpc::model::ModelStatus {
139    fn from(status: models::ModelStatus) -> Self {
140        match status {
141            models::ModelStatus::DOWNLOADING => grpc::model::ModelStatus::Downloading,
142            models::ModelStatus::DOWNLOADED => grpc::model::ModelStatus::Downloaded,
143            models::ModelStatus::ERROR => grpc::model::ModelStatus::Error,
144        }
145    }
146}
147
148impl From<grpc::model::ModelStatus> for models::ModelStatus {
149    fn from(status: grpc::model::ModelStatus) -> Self {
150        match status {
151            grpc::model::ModelStatus::Downloading => models::ModelStatus::DOWNLOADING,
152            grpc::model::ModelStatus::Downloaded => models::ModelStatus::DOWNLOADED,
153            grpc::model::ModelStatus::Error => models::ModelStatus::ERROR,
154        }
155    }
156}
157
158impl From<&models::ModelStatusResponse> for grpc::model::ModelStatusUpdate {
159    fn from(response: &models::ModelStatusResponse) -> Self {
160        Self {
161            model_name: response.model_name.clone(),
162            status: grpc::model::ModelStatus::from(response.status) as i32,
163            message: None,
164            provider: grpc::model::ModelProvider::from(response.provider) as i32,
165        }
166    }
167}
168
169impl From<grpc::model::ModelStatusUpdate> for models::ModelStatusResponse {
170    fn from(update: grpc::model::ModelStatusUpdate) -> Self {
171        Self {
172            model_name: update.model_name,
173            status: grpc::model::ModelStatus::try_from(update.status)
174                .unwrap_or(grpc::model::ModelStatus::Error)
175                .into(),
176            provider: grpc::model::ModelProvider::try_from(update.provider)
177                .unwrap_or(grpc::model::ModelProvider::HuggingFace)
178                .into(),
179        }
180    }
181}
182
// Unit tests for the gRPC <-> legacy conversions, error types, constants, and utilities above.
#[cfg(test)]
mod tests {
    use super::*;
    use std::env;

    #[test]
    fn test_status_conversion_from_models_to_grpc() {
        let status = models::Status {
            version: "1.0.0".to_string(),
            status: "ok".to_string(),
            uptime: 3600,
        };

        let grpc_response: grpc::health::HealthResponse = (&status).into();

        assert_eq!(grpc_response.version, status.version);
        assert_eq!(grpc_response.status, status.status);
        assert_eq!(grpc_response.uptime, status.uptime);
    }

    #[test]
    fn test_status_conversion_from_grpc_to_models() {
        let grpc_response = grpc::health::HealthResponse {
            version: "1.0.0".to_string(),
            status: "ok".to_string(),
            uptime: 3600,
        };

        let status: models::Status = grpc_response.into();

        assert_eq!(status.version, "1.0.0");
        assert_eq!(status.status, "ok");
        assert_eq!(status.uptime, 3600);
    }

    #[test]
    fn test_model_provider_conversion_both_ways() {
        let model_provider = models::ModelProvider::HuggingFace;
        let grpc_provider: grpc::model::ModelProvider = model_provider.into();
        let back_to_model: models::ModelProvider = grpc_provider.into();

        assert_eq!(model_provider, back_to_model);
    }

    #[test]
    fn test_model_status_conversion_both_ways() {
        let statuses = vec![
            models::ModelStatus::DOWNLOADING,
            models::ModelStatus::DOWNLOADED,
            models::ModelStatus::ERROR,
        ];

        for status in statuses {
            let grpc_status: grpc::model::ModelStatus = status.into();
            let back_to_model: models::ModelStatus = grpc_status.into();
            assert_eq!(status, back_to_model);
        }
    }

    #[test]
    fn test_model_status_response_conversion_from_models_to_grpc() {
        let response = models::ModelStatusResponse {
            model_name: "test-model".to_string(),
            status: models::ModelStatus::DOWNLOADED,
            provider: models::ModelProvider::HuggingFace,
        };

        let grpc_update: grpc::model::ModelStatusUpdate = (&response).into();

        assert_eq!(grpc_update.model_name, response.model_name);
        assert_eq!(
            grpc_update.status,
            grpc::model::ModelStatus::Downloaded as i32
        );
        assert_eq!(
            grpc_update.provider,
            grpc::model::ModelProvider::HuggingFace as i32
        );
        assert!(grpc_update.message.is_none());
    }

    #[test]
    fn test_model_status_response_conversion_from_grpc_to_models() {
        let grpc_update = grpc::model::ModelStatusUpdate {
            model_name: "test-model".to_string(),
            status: grpc::model::ModelStatus::Downloaded as i32,
            message: Some("Test message".to_string()),
            provider: grpc::model::ModelProvider::HuggingFace as i32,
        };

        let response: models::ModelStatusResponse = grpc_update.into();

        assert_eq!(response.model_name, "test-model");
        assert_eq!(response.status, models::ModelStatus::DOWNLOADED);
        assert_eq!(response.provider, models::ModelProvider::HuggingFace);
    }

    #[test]
    fn test_model_status_response_conversion_falls_back_on_unknown_wire_values() {
        // Out-of-range enum values must not panic. Status decodes as ERROR and provider
        // falls back to HuggingFace, matching the conversion's documented fallbacks.
        let grpc_update = grpc::model::ModelStatusUpdate {
            model_name: "test-model".to_string(),
            status: i32::MAX,
            message: None,
            provider: i32::MAX,
        };

        let response: models::ModelStatusResponse = grpc_update.into();

        assert_eq!(response.model_name, "test-model");
        assert_eq!(response.status, models::ModelStatus::ERROR);
        assert_eq!(response.provider, models::ModelProvider::HuggingFace);
    }

    #[test]
    fn test_error_types() {
        let network_error = Error::Network("Connection failed".to_string());
        assert!(network_error.to_string().contains("Network error"));

        let server_error = Error::Server("Internal error".to_string());
        assert!(server_error.to_string().contains("Server returned error"));

        let serialization_error = Error::Serialization("JSON parse error".to_string());
        assert!(
            serialization_error
                .to_string()
                .contains("Serialization error")
        );

        let generic_error = Error::Generic("Something went wrong".to_string());
        assert!(generic_error.to_string().contains("Generic error"));
    }

    #[test]
    fn test_tonic_status_converts_into_boxed_error() {
        // Exercises the `From<tonic::Status> for Box<Error>` impl that `?` relies on.
        let boxed: Box<Error> = tonic::Status::not_found("missing model").into();

        assert!(matches!(*boxed, Error::Grpc(_)));
        assert!(boxed.to_string().contains("gRPC error"));
    }

    #[test]
    fn test_constants() {
        assert_eq!(constants::DEFAULT_GRPC_PORT.get(), 8001);
        assert_eq!(constants::DEFAULT_TIMEOUT_SECS, 30);
    }

    #[test]
    fn test_response_creation() {
        let success_response = Response {
            success: true,
            data: Some("test data".to_string()),
            error: None,
        };

        assert!(success_response.success);
        assert!(success_response.data.is_some());
        assert!(success_response.error.is_none());

        let error_response: Response<String> = Response {
            success: false,
            data: None,
            error: Some("test error".to_string()),
        };

        assert!(!error_response.success);
        assert!(error_response.data.is_none());
        assert!(error_response.error.is_some());
    }

    #[test]
    fn test_utils_get_home_dir() {
        let home_dir = Utils::get_home_dir();

        if let Ok(home_dir) = home_dir {
            assert!(!home_dir.is_empty());
            // Check against HOME or USERPROFILE. A variable that is set but empty does not
            // count as a home directory, so skip it when choosing the expected value.
            let non_empty = |name: &str| env::var(name).ok().filter(|value| !value.is_empty());
            if let Some(expected_home) = non_empty("HOME") {
                assert_eq!(home_dir, expected_home);
            } else if let Some(expected_home) = non_empty("USERPROFILE") {
                assert_eq!(home_dir, expected_home);
            }
        }
    }
}