Skip to main content

modelexpress_common/
models.rs

1// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use clap::{ValueEnum, builder::PossibleValue};
5use serde::{Deserialize, Serialize};
6use std::fmt::{Display, Formatter};
7
8/// Status model for server health checks
9#[derive(Debug, Clone, Serialize, Deserialize)]
10pub struct Status {
11    pub version: String,
12    pub status: String,
13    pub uptime: u64,
14}
15
16/// Status of a model download
17#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
18pub enum ModelStatus {
19    /// Model is currently being downloaded
20    DOWNLOADING,
21    /// Model has been successfully downloaded
22    DOWNLOADED,
23    /// Model download failed with an error
24    ERROR,
25}
26
27/// Supported model providers
28#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Default)]
29pub enum ModelProvider {
30    /// Hugging Face model hub
31    #[default]
32    HuggingFace,
33    /// NVIDIA NGC catalog
34    Ngc,
35    /// Google Cloud Storage
36    Gcs,
37}
38
39impl ModelProvider {
40    #[must_use]
41    pub const fn as_str(self) -> &'static str {
42        match self {
43            Self::HuggingFace => "hugging-face",
44            Self::Ngc => "ngc",
45            Self::Gcs => "gcs",
46        }
47    }
48}
49
50impl Display for ModelProvider {
51    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
52        f.write_str(self.as_str())
53    }
54}
55
56impl ValueEnum for ModelProvider {
57    fn value_variants<'a>() -> &'a [Self] {
58        &[Self::HuggingFace, Self::Ngc, Self::Gcs]
59    }
60
61    fn to_possible_value(&self) -> Option<PossibleValue> {
62        Some(PossibleValue::new(self.as_str()))
63    }
64}
65
66/// Response for model status request
67#[derive(Debug, Clone, Serialize, Deserialize)]
68pub struct ModelStatusResponse {
69    pub model_name: String,
70    pub status: ModelStatus,
71    pub provider: ModelProvider,
72}
73
#[cfg(test)]
// `expect` is fine in tests: a panic with a message is the desired failure mode.
#[allow(clippy::expect_used)]
mod tests {
    use super::*;

    /// Every `ModelStatus` variant; keep in sync with the enum.
    const ALL_STATUSES: [ModelStatus; 3] = [
        ModelStatus::DOWNLOADING,
        ModelStatus::DOWNLOADED,
        ModelStatus::ERROR,
    ];

    /// Every `ModelProvider` variant; keep in sync with the enum.
    const ALL_PROVIDERS: [ModelProvider; 3] = [
        ModelProvider::HuggingFace,
        ModelProvider::Ngc,
        ModelProvider::Gcs,
    ];

    /// Each status must survive a JSON round trip (all variants, not just one).
    #[test]
    fn test_model_status_serialization() {
        for status in ALL_STATUSES {
            let serialized =
                serde_json::to_string(&status).expect("Failed to serialize ModelStatus");
            let deserialized: ModelStatus =
                serde_json::from_str(&serialized).expect("Failed to deserialize ModelStatus");
            assert_eq!(status, deserialized);
        }
    }

    /// Each provider must survive a JSON round trip.
    #[test]
    fn test_model_provider_serialization() {
        for provider in ALL_PROVIDERS {
            let serialized =
                serde_json::to_string(&provider).expect("Failed to serialize ModelProvider");
            let deserialized: ModelProvider =
                serde_json::from_str(&serialized).expect("Failed to deserialize ModelProvider");
            assert_eq!(provider, deserialized);
        }
    }

    /// The default provider is Hugging Face.
    #[test]
    fn test_model_provider_default() {
        let provider = ModelProvider::default();
        assert_eq!(provider, ModelProvider::HuggingFace);
    }

    /// `Display` prints the hyphenated, lowercase provider names.
    #[test]
    fn test_model_provider_display() {
        assert_eq!(ModelProvider::HuggingFace.to_string(), "hugging-face");
        assert_eq!(ModelProvider::Ngc.to_string(), "ngc");
        assert_eq!(ModelProvider::Gcs.to_string(), "gcs");
    }

    /// The string from `as_str` must parse back to the same variant via clap.
    #[test]
    fn test_model_provider_value_enum_matches_display() {
        for provider in ALL_PROVIDERS {
            let parsed = ModelProvider::from_str(provider.as_str(), false)
                .expect("Failed to parse ModelProvider from clap value");
            assert_eq!(parsed, provider);
        }
    }

    /// A `Status` payload keeps all of its fields across a JSON round trip.
    #[test]
    fn test_status_serialization() {
        let status = Status {
            version: "1.0.0".to_string(),
            status: "ok".to_string(),
            uptime: 3600,
        };

        let serialized = serde_json::to_string(&status).expect("Failed to serialize Status");
        let deserialized: Status =
            serde_json::from_str(&serialized).expect("Failed to deserialize Status");

        assert_eq!(status.version, deserialized.version);
        assert_eq!(status.status, deserialized.status);
        assert_eq!(status.uptime, deserialized.uptime);
    }

    /// A `ModelStatusResponse` keeps all of its fields across a JSON round trip.
    #[test]
    fn test_model_status_response_serialization() {
        let response = ModelStatusResponse {
            model_name: "test-model".to_string(),
            status: ModelStatus::DOWNLOADED,
            provider: ModelProvider::HuggingFace,
        };

        let serialized =
            serde_json::to_string(&response).expect("Failed to serialize ModelStatusResponse");
        let deserialized: ModelStatusResponse =
            serde_json::from_str(&serialized).expect("Failed to deserialize ModelStatusResponse");

        assert_eq!(response.model_name, deserialized.model_name);
        assert_eq!(response.status, deserialized.status);
        assert_eq!(response.provider, deserialized.provider);
    }

    /// Variants compare equal to themselves and unequal to each other.
    #[test]
    fn test_model_status_all_variants() {
        assert_eq!(ModelStatus::DOWNLOADING, ModelStatus::DOWNLOADING);
        assert_eq!(ModelStatus::DOWNLOADED, ModelStatus::DOWNLOADED);
        assert_eq!(ModelStatus::ERROR, ModelStatus::ERROR);

        assert_ne!(ModelStatus::DOWNLOADING, ModelStatus::DOWNLOADED);
        assert_ne!(ModelStatus::DOWNLOADED, ModelStatus::ERROR);
        assert_ne!(ModelStatus::ERROR, ModelStatus::DOWNLOADING);
    }
}