modelexpress_common/
lib.rs

1// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use serde::{Deserialize, Serialize};
5
// Public submodules of this crate. Each `pub mod name;` declaration pulls the
// module body in from its own source file next to this one.
pub mod cache;
pub mod client_config;
pub mod config;
pub mod download;
pub mod models;
pub mod providers;
12
13// Generated gRPC code
14#[allow(clippy::similar_names)]
15#[allow(clippy::default_trait_access)]
16#[allow(clippy::doc_markdown)]
17#[allow(clippy::must_use_candidate)]
18pub mod grpc {
19    pub mod health {
20        tonic::include_proto!("model_express.health");
21    }
22    pub mod api {
23        tonic::include_proto!("model_express.api");
24    }
25    pub mod model {
26        tonic::include_proto!("model_express.model");
27    }
28}
29
30/// Defines the shared response format between server and client (legacy HTTP)
31#[derive(Debug, Clone, Serialize, Deserialize)]
32pub struct Response<T> {
33    pub success: bool,
34    pub data: Option<T>,
35    pub error: Option<String>,
36}
37
38/// Common error types that both client and server can use
39#[derive(Debug, thiserror::Error)]
40pub enum Error {
41    #[error("Network error: {0}")]
42    Network(String),
43
44    #[error("Server returned error: {0}")]
45    Server(String),
46
47    #[error("Serialization error: {0}")]
48    Serialization(String),
49
50    #[error("gRPC error: {0}")]
51    Grpc(#[from] tonic::Status),
52
53    #[error("Transport error: {0}")]
54    Transport(#[from] tonic::transport::Error),
55}
56
57// Implement From traits for Box<Error> to work with the Result<T> type
58impl From<tonic::Status> for Box<Error> {
59    fn from(err: tonic::Status) -> Self {
60        Box::new(Error::Grpc(err))
61    }
62}
63
64impl From<tonic::transport::Error> for Box<Error> {
65    fn from(err: tonic::transport::Error) -> Self {
66        Box::new(Error::Transport(err))
67    }
68}
69
70/// Common result type for the project
71pub type Result<T> = std::result::Result<T, Box<Error>>;
72
/// Constants shared between client and server
pub mod constants {
    use std::num::NonZeroU16;

    /// Default port for the gRPC endpoint.
    ///
    /// Built with a `match` so the non-zero check runs at compile time. A zero
    /// literal here would be rejected while evaluating the constant.
    pub const DEFAULT_GRPC_PORT: NonZeroU16 = match NonZeroU16::new(8001) {
        Some(port) => port,
        None => panic!("8001 is non-zero"),
    };

    /// Default timeout, in seconds.
    pub const DEFAULT_TIMEOUT_SECS: u64 = 30;
}
80
81// Conversion utilities between gRPC and legacy models
82impl From<&models::Status> for grpc::health::HealthResponse {
83    fn from(status: &models::Status) -> Self {
84        Self {
85            version: status.version.clone(),
86            status: status.status.clone(),
87            uptime: status.uptime,
88        }
89    }
90}
91
92impl From<grpc::health::HealthResponse> for models::Status {
93    fn from(response: grpc::health::HealthResponse) -> Self {
94        Self {
95            version: response.version,
96            status: response.status,
97            uptime: response.uptime,
98        }
99    }
100}
101
102impl From<models::ModelProvider> for grpc::model::ModelProvider {
103    fn from(provider: models::ModelProvider) -> Self {
104        match provider {
105            models::ModelProvider::HuggingFace => grpc::model::ModelProvider::HuggingFace,
106        }
107    }
108}
109
110impl From<grpc::model::ModelProvider> for models::ModelProvider {
111    fn from(provider: grpc::model::ModelProvider) -> Self {
112        match provider {
113            grpc::model::ModelProvider::HuggingFace => models::ModelProvider::HuggingFace,
114        }
115    }
116}
117
118impl From<models::ModelStatus> for grpc::model::ModelStatus {
119    fn from(status: models::ModelStatus) -> Self {
120        match status {
121            models::ModelStatus::DOWNLOADING => grpc::model::ModelStatus::Downloading,
122            models::ModelStatus::DOWNLOADED => grpc::model::ModelStatus::Downloaded,
123            models::ModelStatus::ERROR => grpc::model::ModelStatus::Error,
124        }
125    }
126}
127
128impl From<grpc::model::ModelStatus> for models::ModelStatus {
129    fn from(status: grpc::model::ModelStatus) -> Self {
130        match status {
131            grpc::model::ModelStatus::Downloading => models::ModelStatus::DOWNLOADING,
132            grpc::model::ModelStatus::Downloaded => models::ModelStatus::DOWNLOADED,
133            grpc::model::ModelStatus::Error => models::ModelStatus::ERROR,
134        }
135    }
136}
137
138impl From<&models::ModelStatusResponse> for grpc::model::ModelStatusUpdate {
139    fn from(response: &models::ModelStatusResponse) -> Self {
140        Self {
141            model_name: response.model_name.clone(),
142            status: grpc::model::ModelStatus::from(response.status) as i32,
143            message: None,
144            provider: grpc::model::ModelProvider::from(response.provider) as i32,
145        }
146    }
147}
148
149impl From<grpc::model::ModelStatusUpdate> for models::ModelStatusResponse {
150    fn from(update: grpc::model::ModelStatusUpdate) -> Self {
151        Self {
152            model_name: update.model_name,
153            status: grpc::model::ModelStatus::try_from(update.status)
154                .unwrap_or(grpc::model::ModelStatus::Error)
155                .into(),
156            provider: grpc::model::ModelProvider::try_from(update.provider)
157                .unwrap_or(grpc::model::ModelProvider::HuggingFace)
158                .into(),
159        }
160    }
161}
162
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_status_conversion_from_models_to_grpc() {
        let status = models::Status {
            version: "1.0.0".to_string(),
            status: "ok".to_string(),
            uptime: 3600,
        };

        let grpc_response: grpc::health::HealthResponse = (&status).into();

        assert_eq!(grpc_response.version, status.version);
        assert_eq!(grpc_response.status, status.status);
        assert_eq!(grpc_response.uptime, status.uptime);
    }

    #[test]
    fn test_status_conversion_from_grpc_to_models() {
        let grpc_response = grpc::health::HealthResponse {
            version: "1.0.0".to_string(),
            status: "ok".to_string(),
            uptime: 3600,
        };

        let status: models::Status = grpc_response.into();

        assert_eq!(status.version, "1.0.0");
        assert_eq!(status.status, "ok");
        assert_eq!(status.uptime, 3600);
    }

    #[test]
    fn test_model_provider_conversion_both_ways() {
        let model_provider = models::ModelProvider::HuggingFace;
        let grpc_provider: grpc::model::ModelProvider = model_provider.into();
        let back_to_model: models::ModelProvider = grpc_provider.into();

        assert_eq!(model_provider, back_to_model);
    }

    #[test]
    fn test_model_status_conversion_both_ways() {
        // A fixed-size array is enough to iterate; `vec!` would allocate for
        // nothing (clippy::useless_vec).
        let statuses = [
            models::ModelStatus::DOWNLOADING,
            models::ModelStatus::DOWNLOADED,
            models::ModelStatus::ERROR,
        ];

        for status in statuses {
            let grpc_status: grpc::model::ModelStatus = status.into();
            let back_to_model: models::ModelStatus = grpc_status.into();
            assert_eq!(status, back_to_model);
        }
    }

    #[test]
    fn test_model_status_response_conversion_from_models_to_grpc() {
        let response = models::ModelStatusResponse {
            model_name: "test-model".to_string(),
            status: models::ModelStatus::DOWNLOADED,
            provider: models::ModelProvider::HuggingFace,
        };

        let grpc_update: grpc::model::ModelStatusUpdate = (&response).into();

        assert_eq!(grpc_update.model_name, response.model_name);
        assert_eq!(
            grpc_update.status,
            grpc::model::ModelStatus::Downloaded as i32
        );
        assert_eq!(
            grpc_update.provider,
            grpc::model::ModelProvider::HuggingFace as i32
        );
        assert!(grpc_update.message.is_none());
    }

    #[test]
    fn test_model_status_response_conversion_from_grpc_to_models() {
        let grpc_update = grpc::model::ModelStatusUpdate {
            model_name: "test-model".to_string(),
            status: grpc::model::ModelStatus::Downloaded as i32,
            message: Some("Test message".to_string()),
            provider: grpc::model::ModelProvider::HuggingFace as i32,
        };

        let response: models::ModelStatusResponse = grpc_update.into();

        assert_eq!(response.model_name, "test-model");
        assert_eq!(response.status, models::ModelStatus::DOWNLOADED);
        assert_eq!(response.provider, models::ModelProvider::HuggingFace);
    }

    #[test]
    fn test_model_status_response_unknown_enum_values_fall_back() {
        // Out-of-range wire values must not panic. The conversion maps them to
        // ERROR (status) and HuggingFace (provider).
        let grpc_update = grpc::model::ModelStatusUpdate {
            model_name: "test-model".to_string(),
            status: i32::MAX,
            message: None,
            provider: i32::MAX,
        };

        let response: models::ModelStatusResponse = grpc_update.into();

        assert_eq!(response.model_name, "test-model");
        assert_eq!(response.status, models::ModelStatus::ERROR);
        assert_eq!(response.provider, models::ModelProvider::HuggingFace);
    }

    #[test]
    fn test_error_types() {
        // Pin the full Display output, not just a substring, so an accidental
        // change to an `#[error]` format string is caught.
        assert_eq!(
            Error::Network("Connection failed".to_string()).to_string(),
            "Network error: Connection failed"
        );
        assert_eq!(
            Error::Server("Internal error".to_string()).to_string(),
            "Server returned error: Internal error"
        );
        assert_eq!(
            Error::Serialization("JSON parse error".to_string()).to_string(),
            "Serialization error: JSON parse error"
        );
    }

    #[test]
    fn test_tonic_status_converts_into_boxed_error() {
        let boxed: Box<Error> = tonic::Status::not_found("missing model").into();

        assert!(matches!(*boxed, Error::Grpc(_)));
        assert!(boxed.to_string().starts_with("gRPC error: "));
    }

    #[test]
    fn test_constants() {
        assert_eq!(constants::DEFAULT_GRPC_PORT.get(), 8001);
        assert_eq!(constants::DEFAULT_TIMEOUT_SECS, 30);
    }

    #[test]
    fn test_response_creation() {
        let success_response = Response {
            success: true,
            data: Some("test data".to_string()),
            error: None,
        };

        assert!(success_response.success);
        assert!(success_response.data.is_some());
        assert!(success_response.error.is_none());

        let error_response: Response<String> = Response {
            success: false,
            data: None,
            error: Some("test error".to_string()),
        };

        assert!(!error_response.success);
        assert!(error_response.data.is_none());
        assert!(error_response.error.is_some());
    }
}