Skip to main content

modelexpress_common/
models.rs

// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
3
4use serde::{Deserialize, Serialize};
5
6/// Status model for server health checks
7#[derive(Debug, Clone, Serialize, Deserialize)]
8pub struct Status {
9    pub version: String,
10    pub status: String,
11    pub uptime: u64,
12}
13
14/// Status of a model download
15#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
16pub enum ModelStatus {
17    /// Model is currently being downloaded
18    DOWNLOADING,
19    /// Model has been successfully downloaded
20    DOWNLOADED,
21    /// Model download failed with an error
22    ERROR,
23}
24
25/// Supported model providers
26#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Default)]
27pub enum ModelProvider {
28    /// Hugging Face model hub
29    #[default]
30    HuggingFace,
31}
32
33/// Response for model status request
34#[derive(Debug, Clone, Serialize, Deserialize)]
35pub struct ModelStatusResponse {
36    pub model_name: String,
37    pub status: ModelStatus,
38    pub provider: ModelProvider,
39}
40
#[cfg(test)]
#[allow(clippy::expect_used)]
mod tests {
    use super::*;

    /// Every `ModelStatus` variant, so tests can exercise all of them.
    const ALL_MODEL_STATUSES: [ModelStatus; 3] = [
        ModelStatus::DOWNLOADING,
        ModelStatus::DOWNLOADED,
        ModelStatus::ERROR,
    ];

    /// Expected JSON encoding of each status.
    ///
    /// The exhaustive `match` (no `_` arm) makes adding a variant a compile
    /// error here, forcing its wire format to be pinned deliberately.
    fn expected_status_json(status: ModelStatus) -> &'static str {
        match status {
            ModelStatus::DOWNLOADING => "\"DOWNLOADING\"",
            ModelStatus::DOWNLOADED => "\"DOWNLOADED\"",
            ModelStatus::ERROR => "\"ERROR\"",
        }
    }

    #[test]
    fn test_model_status_serialization() {
        // Round-trip every variant and pin its encoding, so a rename that would
        // break clients of the wire format is caught.
        for status in ALL_MODEL_STATUSES {
            let serialized =
                serde_json::to_string(&status).expect("Failed to serialize ModelStatus");
            assert_eq!(serialized, expected_status_json(status));
            let deserialized: ModelStatus =
                serde_json::from_str(&serialized).expect("Failed to deserialize ModelStatus");
            assert_eq!(status, deserialized);
        }
    }

    #[test]
    fn test_model_provider_serialization() {
        let provider = ModelProvider::HuggingFace;
        let serialized =
            serde_json::to_string(&provider).expect("Failed to serialize ModelProvider");
        assert_eq!(serialized, "\"HuggingFace\"");
        let deserialized: ModelProvider =
            serde_json::from_str(&serialized).expect("Failed to deserialize ModelProvider");
        assert_eq!(provider, deserialized);
    }

    #[test]
    fn test_model_provider_default() {
        let provider = ModelProvider::default();
        assert_eq!(provider, ModelProvider::HuggingFace);
    }

    #[test]
    fn test_status_serialization() {
        let status = Status {
            version: "1.0.0".to_string(),
            status: "ok".to_string(),
            uptime: 3600,
        };

        // Pin the JSON shape (field names are the wire keys).
        let value = serde_json::to_value(&status).expect("Failed to convert Status to JSON");
        assert_eq!(
            value,
            serde_json::json!({"version": "1.0.0", "status": "ok", "uptime": 3600})
        );

        let serialized = serde_json::to_string(&status).expect("Failed to serialize Status");
        let deserialized: Status =
            serde_json::from_str(&serialized).expect("Failed to deserialize Status");

        assert_eq!(status.version, deserialized.version);
        assert_eq!(status.status, deserialized.status);
        assert_eq!(status.uptime, deserialized.uptime);
    }

    #[test]
    fn test_model_status_response_serialization() {
        let response = ModelStatusResponse {
            model_name: "test-model".to_string(),
            status: ModelStatus::DOWNLOADED,
            provider: ModelProvider::HuggingFace,
        };

        // Pin the JSON shape, including the nested enum encodings.
        let value = serde_json::to_value(&response)
            .expect("Failed to convert ModelStatusResponse to JSON");
        assert_eq!(
            value,
            serde_json::json!({
                "model_name": "test-model",
                "status": "DOWNLOADED",
                "provider": "HuggingFace"
            })
        );

        let serialized =
            serde_json::to_string(&response).expect("Failed to serialize ModelStatusResponse");
        let deserialized: ModelStatusResponse =
            serde_json::from_str(&serialized).expect("Failed to deserialize ModelStatusResponse");

        assert_eq!(response.model_name, deserialized.model_name);
        assert_eq!(response.status, deserialized.status);
        assert_eq!(response.provider, deserialized.provider);
    }

    #[test]
    fn test_model_status_all_variants() {
        // Each variant must equal itself and differ from every other variant.
        for (i, a) in ALL_MODEL_STATUSES.iter().enumerate() {
            for (j, b) in ALL_MODEL_STATUSES.iter().enumerate() {
                assert_eq!(a == b, i == j, "unexpected equality result for {a:?} vs {b:?}");
            }
        }
    }
}