Skip to main content

modelexpress_common/
models.rs

1// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use clap::{ValueEnum, builder::PossibleValue};
5use serde::{Deserialize, Serialize};
6use std::fmt::{Display, Formatter};
7
8/// Status model for server health checks
9#[derive(Debug, Clone, Serialize, Deserialize)]
10pub struct Status {
11    pub version: String,
12    pub status: String,
13    pub uptime: u64,
14}
15
16/// Status of a model download
17#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
18pub enum ModelStatus {
19    /// Model is currently being downloaded
20    DOWNLOADING,
21    /// Model has been successfully downloaded
22    DOWNLOADED,
23    /// Model download failed with an error
24    ERROR,
25}
26
27/// Supported model providers
28#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Default)]
29pub enum ModelProvider {
30    /// Hugging Face model hub
31    #[default]
32    HuggingFace,
33}
34
35impl ModelProvider {
36    #[must_use]
37    pub const fn as_str(self) -> &'static str {
38        match self {
39            Self::HuggingFace => "hugging-face",
40        }
41    }
42}
43
44impl Display for ModelProvider {
45    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
46        f.write_str(self.as_str())
47    }
48}
49
50impl ValueEnum for ModelProvider {
51    fn value_variants<'a>() -> &'a [Self] {
52        &[Self::HuggingFace]
53    }
54
55    fn to_possible_value(&self) -> Option<PossibleValue> {
56        Some(PossibleValue::new(self.as_str()))
57    }
58}
59
60/// Response for model status request
61#[derive(Debug, Clone, Serialize, Deserialize)]
62pub struct ModelStatusResponse {
63    pub model_name: String,
64    pub status: ModelStatus,
65    pub provider: ModelProvider,
66}
67
#[cfg(test)]
// `expect` gives failing tests a descriptive message; the lint is relaxed only
// inside this test module.
#[allow(clippy::expect_used)]
mod tests {
    use super::*;

    /// Every `ModelStatus` variant, so tests can exercise the whole enum.
    const ALL_MODEL_STATUSES: [ModelStatus; 3] = [
        ModelStatus::DOWNLOADING,
        ModelStatus::DOWNLOADED,
        ModelStatus::ERROR,
    ];

    #[test]
    fn test_model_status_serialization() {
        for status in ALL_MODEL_STATUSES {
            let serialized =
                serde_json::to_string(&status).expect("Failed to serialize ModelStatus");
            let deserialized: ModelStatus =
                serde_json::from_str(&serialized).expect("Failed to deserialize ModelStatus");
            assert_eq!(status, deserialized);
        }
    }

    #[test]
    fn test_model_status_wire_format() {
        // Pin the serialized names so an accidental rename or serde attribute
        // change is caught.
        let expected = [
            (ModelStatus::DOWNLOADING, "\"DOWNLOADING\""),
            (ModelStatus::DOWNLOADED, "\"DOWNLOADED\""),
            (ModelStatus::ERROR, "\"ERROR\""),
        ];
        for (status, json) in expected {
            let serialized =
                serde_json::to_string(&status).expect("Failed to serialize ModelStatus");
            assert_eq!(serialized, json);
        }
    }

    #[test]
    fn test_model_provider_serialization() {
        let provider = ModelProvider::HuggingFace;
        let serialized =
            serde_json::to_string(&provider).expect("Failed to serialize ModelProvider");
        // serde uses the variant name, not the kebab-case CLI/Display name.
        assert_eq!(serialized, "\"HuggingFace\"");
        let deserialized: ModelProvider =
            serde_json::from_str(&serialized).expect("Failed to deserialize ModelProvider");
        assert_eq!(provider, deserialized);
    }

    #[test]
    fn test_model_provider_default() {
        let provider = ModelProvider::default();
        assert_eq!(provider, ModelProvider::HuggingFace);
    }

    #[test]
    fn test_model_provider_display() {
        assert_eq!(ModelProvider::HuggingFace.to_string(), "hugging-face");
    }

    #[test]
    fn test_model_provider_value_enum_matches_display() {
        for provider in ModelProvider::value_variants() {
            let parsed = ModelProvider::from_str(provider.as_str(), false)
                .expect("Failed to parse ModelProvider from clap value");
            assert_eq!(parsed, *provider);
            assert_eq!(provider.to_string(), provider.as_str());
        }
    }

    #[test]
    fn test_model_provider_value_enum_rejects_unknown() {
        assert!(ModelProvider::from_str("not-a-provider", false).is_err());
        // The serde spelling is not a valid CLI value.
        assert!(ModelProvider::from_str("HuggingFace", false).is_err());
        let parsed = ModelProvider::from_str("HUGGING-FACE", true)
            .expect("Failed to parse ModelProvider case-insensitively");
        assert_eq!(parsed, ModelProvider::HuggingFace);
    }

    #[test]
    fn test_status_serialization() {
        let status = Status {
            version: "1.0.0".to_string(),
            status: "ok".to_string(),
            uptime: 3600,
        };

        let serialized = serde_json::to_string(&status).expect("Failed to serialize Status");
        let deserialized: Status =
            serde_json::from_str(&serialized).expect("Failed to deserialize Status");

        assert_eq!(status.version, deserialized.version);
        assert_eq!(status.status, deserialized.status);
        assert_eq!(status.uptime, deserialized.uptime);
    }

    #[test]
    fn test_model_status_response_serialization() {
        let response = ModelStatusResponse {
            model_name: "test-model".to_string(),
            status: ModelStatus::DOWNLOADED,
            provider: ModelProvider::HuggingFace,
        };

        let serialized =
            serde_json::to_string(&response).expect("Failed to serialize ModelStatusResponse");
        let deserialized: ModelStatusResponse =
            serde_json::from_str(&serialized).expect("Failed to deserialize ModelStatusResponse");

        assert_eq!(response.model_name, deserialized.model_name);
        assert_eq!(response.status, deserialized.status);
        assert_eq!(response.provider, deserialized.provider);
    }

    #[test]
    fn test_model_status_all_variants() {
        // Each variant equals itself and differs from every other variant.
        for (i, a) in ALL_MODEL_STATUSES.iter().enumerate() {
            for (j, b) in ALL_MODEL_STATUSES.iter().enumerate() {
                assert_eq!(i == j, a == b, "{a:?} vs {b:?}");
            }
        }
    }
}