modelexpress_common/
models.rs

// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
3
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6
7/// Request model for client -> server communication
8#[derive(Debug, Clone, Serialize, Deserialize)]
9pub struct Request {
10    pub id: String,
11    pub action: String,
12    pub payload: Option<HashMap<String, serde_json::Value>>,
13}
14
15/// Status model for server health checks
16#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct Status {
18    pub version: String,
19    pub status: String,
20    pub uptime: u64,
21}
22
23/// Status of a model download
24#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
25pub enum ModelStatus {
26    /// Model is currently being downloaded
27    DOWNLOADING,
28    /// Model has been successfully downloaded
29    DOWNLOADED,
30    /// Model download failed with an error
31    ERROR,
32}
33
34/// Supported model providers
35#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Default)]
36pub enum ModelProvider {
37    /// Hugging Face model hub
38    #[default]
39    HuggingFace,
40}
41
42/// Response for model status request
43#[derive(Debug, Clone, Serialize, Deserialize)]
44pub struct ModelStatusResponse {
45    pub model_name: String,
46    pub status: ModelStatus,
47    pub provider: ModelProvider,
48}
49
#[cfg(test)]
#[allow(clippy::expect_used)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Encodes `value` as JSON text and decodes it back into a fresh `T`.
    ///
    /// Panics with `ser_msg` if encoding fails and with `de_msg` if decoding fails.
    fn json_round_trip<T>(value: &T, ser_msg: &str, de_msg: &str) -> T
    where
        T: Serialize + for<'de> Deserialize<'de>,
    {
        let encoded = serde_json::to_string(value).expect(ser_msg);
        serde_json::from_str(&encoded).expect(de_msg)
    }

    #[test]
    fn test_model_status_serialization() {
        let original = ModelStatus::DOWNLOADING;
        let decoded = json_round_trip(
            &original,
            "Failed to serialize ModelStatus",
            "Failed to deserialize ModelStatus",
        );
        assert_eq!(original, decoded);
    }

    #[test]
    fn test_model_provider_serialization() {
        let original = ModelProvider::HuggingFace;
        let decoded = json_round_trip(
            &original,
            "Failed to serialize ModelProvider",
            "Failed to deserialize ModelProvider",
        );
        assert_eq!(original, decoded);
    }

    #[test]
    fn test_model_provider_default() {
        assert_eq!(ModelProvider::default(), ModelProvider::HuggingFace);
    }

    #[test]
    fn test_request_serialization() {
        let original = Request {
            id: "test-id".to_string(),
            action: "test-action".to_string(),
            payload: Some(HashMap::from([("key".to_string(), json!("value"))])),
        };

        let decoded = json_round_trip(
            &original,
            "Failed to serialize Request",
            "Failed to deserialize Request",
        );

        assert_eq!(original.id, decoded.id);
        assert_eq!(original.action, decoded.action);
        assert!(original.payload.is_some());
        assert!(decoded.payload.is_some());
    }

    #[test]
    fn test_status_serialization() {
        let original = Status {
            version: "1.0.0".to_string(),
            status: "ok".to_string(),
            uptime: 3600,
        };

        let decoded = json_round_trip(
            &original,
            "Failed to serialize Status",
            "Failed to deserialize Status",
        );

        assert_eq!(original.version, decoded.version);
        assert_eq!(original.status, decoded.status);
        assert_eq!(original.uptime, decoded.uptime);
    }

    #[test]
    fn test_model_status_response_serialization() {
        let original = ModelStatusResponse {
            model_name: "test-model".to_string(),
            status: ModelStatus::DOWNLOADED,
            provider: ModelProvider::HuggingFace,
        };

        let decoded = json_round_trip(
            &original,
            "Failed to serialize ModelStatusResponse",
            "Failed to deserialize ModelStatusResponse",
        );

        assert_eq!(original.model_name, decoded.model_name);
        assert_eq!(original.status, decoded.status);
        assert_eq!(original.provider, decoded.provider);
    }

    #[test]
    fn test_model_status_all_variants() {
        // Every variant must equal itself and differ from every other variant.
        let variants = [
            ModelStatus::DOWNLOADING,
            ModelStatus::DOWNLOADED,
            ModelStatus::ERROR,
        ];
        for (i, lhs) in variants.iter().enumerate() {
            for (j, rhs) in variants.iter().enumerate() {
                if i == j {
                    assert_eq!(lhs, rhs);
                } else {
                    assert_ne!(lhs, rhs);
                }
            }
        }
    }
}