modelexpress_common/
models.rs1use clap::{ValueEnum, builder::PossibleValue};
5use serde::{Deserialize, Serialize};
6use std::fmt::{Display, Formatter};
7
8#[derive(Debug, Clone, Serialize, Deserialize)]
10pub struct Status {
11 pub version: String,
12 pub status: String,
13 pub uptime: u64,
14}
15
16#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
18pub enum ModelStatus {
19 DOWNLOADING,
21 DOWNLOADED,
23 ERROR,
25}
26
27#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Default)]
29pub enum ModelProvider {
30 #[default]
32 HuggingFace,
33 Ngc,
35 Gcs,
37}
38
39impl ModelProvider {
40 #[must_use]
41 pub const fn as_str(self) -> &'static str {
42 match self {
43 Self::HuggingFace => "hugging-face",
44 Self::Ngc => "ngc",
45 Self::Gcs => "gcs",
46 }
47 }
48}
49
50impl Display for ModelProvider {
51 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
52 f.write_str(self.as_str())
53 }
54}
55
56impl ValueEnum for ModelProvider {
57 fn value_variants<'a>() -> &'a [Self] {
58 &[Self::HuggingFace, Self::Ngc, Self::Gcs]
59 }
60
61 fn to_possible_value(&self) -> Option<PossibleValue> {
62 Some(PossibleValue::new(self.as_str()))
63 }
64}
65
66#[derive(Debug, Clone, Serialize, Deserialize)]
68pub struct ModelStatusResponse {
69 pub model_name: String,
70 pub status: ModelStatus,
71 pub provider: ModelProvider,
72}
73
#[cfg(test)]
#[allow(clippy::expect_used)] // `expect` gives readable failure messages in tests.
mod tests {
    use super::*;

    /// Every `ModelStatus` variant, so the loops below cover the whole enum.
    const ALL_STATUSES: [ModelStatus; 3] = [
        ModelStatus::DOWNLOADING,
        ModelStatus::DOWNLOADED,
        ModelStatus::ERROR,
    ];

    /// Every `ModelProvider` variant, in declaration order.
    const ALL_PROVIDERS: [ModelProvider; 3] = [
        ModelProvider::HuggingFace,
        ModelProvider::Ngc,
        ModelProvider::Gcs,
    ];

    #[test]
    fn test_model_status_serialization() {
        for status in ALL_STATUSES {
            let serialized =
                serde_json::to_string(&status).expect("Failed to serialize ModelStatus");
            let deserialized: ModelStatus =
                serde_json::from_str(&serialized).expect("Failed to deserialize ModelStatus");
            assert_eq!(status, deserialized);
        }
    }

    #[test]
    fn test_model_status_wire_format() {
        // serde serializes unit variants as their exact Rust name; pin it so a rename
        // cannot silently change the JSON contract.
        for (status, expected) in [
            (ModelStatus::DOWNLOADING, "\"DOWNLOADING\""),
            (ModelStatus::DOWNLOADED, "\"DOWNLOADED\""),
            (ModelStatus::ERROR, "\"ERROR\""),
        ] {
            let serialized =
                serde_json::to_string(&status).expect("Failed to serialize ModelStatus");
            assert_eq!(serialized, expected);
        }
    }

    #[test]
    fn test_model_provider_serialization() {
        for provider in ALL_PROVIDERS {
            let serialized =
                serde_json::to_string(&provider).expect("Failed to serialize ModelProvider");
            let deserialized: ModelProvider =
                serde_json::from_str(&serialized).expect("Failed to deserialize ModelProvider");
            assert_eq!(provider, deserialized);
        }
    }

    #[test]
    fn test_model_provider_wire_format() {
        // The serde spelling is the variant name, which intentionally differs from the
        // kebab-case `Display`/CLI spelling checked further below.
        for (provider, expected) in [
            (ModelProvider::HuggingFace, "\"HuggingFace\""),
            (ModelProvider::Ngc, "\"Ngc\""),
            (ModelProvider::Gcs, "\"Gcs\""),
        ] {
            let serialized =
                serde_json::to_string(&provider).expect("Failed to serialize ModelProvider");
            assert_eq!(serialized, expected);
        }
    }

    #[test]
    fn test_model_provider_default() {
        let provider = ModelProvider::default();
        assert_eq!(provider, ModelProvider::HuggingFace);
    }

    #[test]
    fn test_model_provider_display() {
        assert_eq!(ModelProvider::HuggingFace.to_string(), "hugging-face");
        assert_eq!(ModelProvider::Ngc.to_string(), "ngc");
        assert_eq!(ModelProvider::Gcs.to_string(), "gcs");
    }

    #[test]
    fn test_model_provider_value_enum_matches_display() {
        // clap must offer every variant, otherwise one would be unreachable from the CLI.
        assert_eq!(ModelProvider::value_variants(), &ALL_PROVIDERS[..]);
        for provider in ALL_PROVIDERS {
            let parsed = ModelProvider::from_str(provider.as_str(), false)
                .expect("Failed to parse ModelProvider from clap value");
            assert_eq!(parsed, provider);

            // Text produced by `Display` must be accepted back by the CLI parser.
            let reparsed = ModelProvider::from_str(&provider.to_string(), false)
                .expect("Failed to parse ModelProvider from Display output");
            assert_eq!(reparsed, provider);
        }
    }

    #[test]
    fn test_model_provider_value_enum_case_and_unknown() {
        assert!(ModelProvider::from_str("s3", false).is_err());
        assert!(ModelProvider::from_str("NGC", false).is_err());
        let parsed = ModelProvider::from_str("NGC", true)
            .expect("Failed to parse ModelProvider case-insensitively");
        assert_eq!(parsed, ModelProvider::Ngc);
    }

    #[test]
    fn test_status_serialization() {
        let status = Status {
            version: "1.0.0".to_string(),
            status: "ok".to_string(),
            uptime: 3600,
        };

        let serialized = serde_json::to_string(&status).expect("Failed to serialize Status");
        let deserialized: Status =
            serde_json::from_str(&serialized).expect("Failed to deserialize Status");

        assert_eq!(status.version, deserialized.version);
        assert_eq!(status.status, deserialized.status);
        assert_eq!(status.uptime, deserialized.uptime);
    }

    #[test]
    fn test_model_status_response_serialization() {
        let response = ModelStatusResponse {
            model_name: "test-model".to_string(),
            status: ModelStatus::DOWNLOADED,
            provider: ModelProvider::HuggingFace,
        };

        let serialized =
            serde_json::to_string(&response).expect("Failed to serialize ModelStatusResponse");
        let deserialized: ModelStatusResponse =
            serde_json::from_str(&serialized).expect("Failed to deserialize ModelStatusResponse");

        assert_eq!(response.model_name, deserialized.model_name);
        assert_eq!(response.status, deserialized.status);
        assert_eq!(response.provider, deserialized.provider);
    }

    #[test]
    fn test_model_status_all_variants() {
        // Each variant equals itself and differs from every other variant.
        for (i, a) in ALL_STATUSES.iter().enumerate() {
            for (j, b) in ALL_STATUSES.iter().enumerate() {
                assert_eq!(i == j, a == b, "{a:?} vs {b:?}");
            }
        }
    }
}