1use serde::{Deserialize, Serialize};
2
/// Accelerator class of a runtime assignment.
///
/// On the wire this is either a string (`"DEFAULT"` for CPU, `"GPU"`, `"TPU"`)
/// or an integer code (`0`, `1`, `2`); see the `Deserialize` / `Serialize`
/// impls for the exact encodings accepted and produced.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Variant {
    /// CPU-only runtime; serialized as `"DEFAULT"`, integer code `0`.
    Cpu,
    /// GPU runtime; serialized as `"GPU"`, integer code `1`.
    Gpu,
    /// TPU runtime; serialized as `"TPU"`, integer code `2`.
    Tpu,
}
9
10impl<'de> serde::Deserialize<'de> for Variant {
11 fn deserialize<D: serde::Deserializer<'de>>(d: D) -> std::result::Result<Self, D::Error> {
12 use serde::de::{self, Visitor};
13 struct V;
14 impl Visitor<'_> for V {
15 type Value = Variant;
16 fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
17 write!(
18 f,
19 "variant as string (\"DEFAULT\"/\"GPU\"/\"TPU\") or integer (0/1/2)"
20 )
21 }
22 fn visit_str<E: de::Error>(self, v: &str) -> std::result::Result<Variant, E> {
23 match v {
24 "DEFAULT" | "default" | "cpu" => Ok(Variant::Cpu),
25 "GPU" | "gpu" => Ok(Variant::Gpu),
26 "TPU" | "tpu" => Ok(Variant::Tpu),
27 other => Err(E::unknown_variant(other, &["DEFAULT", "GPU", "TPU"])),
28 }
29 }
30 fn visit_u64<E: de::Error>(self, v: u64) -> std::result::Result<Variant, E> {
31 match v {
32 0 => Ok(Variant::Cpu),
33 1 => Ok(Variant::Gpu),
34 2 => Ok(Variant::Tpu),
35 other => Err(E::custom(format!("unknown variant integer: {other}"))),
36 }
37 }
38 fn visit_i64<E: de::Error>(self, v: i64) -> std::result::Result<Variant, E> {
39 self.visit_u64(v as u64)
40 }
41 }
42 d.deserialize_any(V)
43 }
44}
45
46impl serde::Serialize for Variant {
47 fn serialize<S: serde::Serializer>(&self, s: S) -> std::result::Result<S::Ok, S::Error> {
48 s.serialize_str(match self {
49 Variant::Cpu => "DEFAULT",
50 Variant::Gpu => "GPU",
51 Variant::Tpu => "TPU",
52 })
53 }
54}
55
56impl Variant {
57 #[inline]
58 pub fn display_name(&self) -> &'static str {
59 match self {
60 Variant::Cpu => "CPU",
61 Variant::Gpu => "GPU",
62 Variant::Tpu => "TPU",
63 }
64 }
65}
66
67impl std::fmt::Display for Variant {
68 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
69 f.write_str(self.display_name())
70 }
71}
72
73#[derive(Debug, Clone, Deserialize, Serialize)]
74#[serde(rename_all = "camelCase")]
75pub struct RuntimeProxyInfo {
76 pub token: String,
77 pub token_expires_in_seconds: i64,
78 pub url: String,
79}
80
81#[derive(Debug, Deserialize)]
82pub struct GetAssignmentResponse {
83 #[serde(rename = "token")]
84 pub xsrf_token: String,
85}
86
/// Outcome code attached to an assignment.
///
/// Serialized as a bare integer. Code `3` has no variant and is rejected
/// when deserializing.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Outcome {
    /// Wire code `0`.
    Undefined,
    /// Wire code `1`.
    QuotaDeniedVariants,
    /// Wire code `2`.
    QuotaExceededUsageTime,
    /// Wire code `4`.
    Success,
    /// Wire code `5`.
    Denylisted,
}
95
96impl<'de> serde::Deserialize<'de> for Outcome {
97 fn deserialize<D: serde::Deserializer<'de>>(d: D) -> std::result::Result<Self, D::Error> {
98 let v = u8::deserialize(d)?;
99 match v {
100 0 => Ok(Outcome::Undefined),
101 1 => Ok(Outcome::QuotaDeniedVariants),
102 2 => Ok(Outcome::QuotaExceededUsageTime),
103 4 => Ok(Outcome::Success),
104 5 => Ok(Outcome::Denylisted),
105 other => Err(serde::de::Error::custom(format!(
106 "unknown outcome: {other}"
107 ))),
108 }
109 }
110}
111
112impl serde::Serialize for Outcome {
113 fn serialize<S: serde::Serializer>(&self, s: S) -> std::result::Result<S::Ok, S::Error> {
114 let v: u8 = match self {
115 Outcome::Undefined => 0,
116 Outcome::QuotaDeniedVariants => 1,
117 Outcome::QuotaExceededUsageTime => 2,
118 Outcome::Success => 4,
119 Outcome::Denylisted => 5,
120 };
121 s.serialize_u8(v)
122 }
123}
124
125#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Serialize, Default)]
126#[serde(from = "u8", into = "u8")]
127pub enum Shape {
128 #[default]
129 Standard,
130 HighMem,
131 Unknown(u8),
132}
133
134impl From<u8> for Shape {
135 #[inline]
136 fn from(v: u8) -> Self {
137 match v {
138 0 => Shape::Standard,
139 1 => Shape::HighMem,
140 other => Shape::Unknown(other),
141 }
142 }
143}
144
145impl From<Shape> for u8 {
146 #[inline]
147 fn from(s: Shape) -> u8 {
148 match s {
149 Shape::Standard => 0,
150 Shape::HighMem => 1,
151 Shape::Unknown(v) => v,
152 }
153 }
154}
155
156impl Shape {
157 #[inline]
158 pub fn display_name(&self) -> &'static str {
159 match self {
160 Shape::Standard => "standard",
161 Shape::HighMem => "high-ram",
162 Shape::Unknown(_) => "unknown",
163 }
164 }
165}
166
167impl std::fmt::Display for Shape {
168 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
169 f.write_str(self.display_name())
170 }
171}
172
173#[derive(Debug, Clone, Deserialize, Serialize)]
174#[serde(rename_all = "camelCase")]
175pub struct Assignment {
176 pub endpoint: String,
177 pub variant: Variant,
178 pub accelerator: Option<String>,
179 pub machine_shape: Option<Shape>,
180 pub runtime_proxy_info: RuntimeProxyInfo,
181 #[serde(rename = "fit")]
182 pub idle_timeout_sec: Option<u64>,
183 pub outcome: Option<Outcome>,
184}
185
186#[derive(Debug, Clone, Deserialize, Serialize)]
187#[serde(rename_all = "camelCase")]
188pub struct ListedAssignment {
189 pub endpoint: String,
190 pub variant: Variant,
191 pub accelerator: Option<String>,
192 pub machine_shape: Option<Shape>,
193 pub runtime_proxy_info: Option<RuntimeProxyInfo>,
194}
195
196#[derive(Debug, Deserialize)]
197pub struct ListAssignmentsResponse {
198 pub assignments: Vec<ListedAssignment>,
199}
200
201#[derive(Debug, Deserialize)]
202pub struct Session {
203 pub id: String,
204}
205
206#[derive(Debug, Deserialize)]
207pub struct JupyterTerminal {
208 pub name: String,
209}
210
211#[derive(Debug, Deserialize)]
212#[serde(rename_all = "camelCase")]
213pub struct CcuInfo {
214 pub current_balance: f64,
215 pub consumption_rate_hourly: f64,
216 pub assignments_count: u32,
217 pub eligible_gpus: Vec<String>,
218 pub eligible_tpus: Vec<String>,
219}
220
/// Unit tests for the wire encodings of the types above.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn variant_deserialises_strings() {
        assert_eq!(
            serde_json::from_str::<Variant>("\"DEFAULT\"").unwrap(),
            Variant::Cpu
        );
        assert_eq!(
            serde_json::from_str::<Variant>("\"GPU\"").unwrap(),
            Variant::Gpu
        );
        assert_eq!(
            serde_json::from_str::<Variant>("\"TPU\"").unwrap(),
            Variant::Tpu
        );
    }

    #[test]
    fn variant_accepts_lowercase_aliases() {
        for (input, expected) in [
            ("\"default\"", Variant::Cpu),
            ("\"cpu\"", Variant::Cpu),
            ("\"gpu\"", Variant::Gpu),
            ("\"tpu\"", Variant::Tpu),
        ] {
            assert_eq!(serde_json::from_str::<Variant>(input).unwrap(), expected);
        }
    }

    #[test]
    fn variant_deserialises_integers() {
        assert_eq!(serde_json::from_str::<Variant>("0").unwrap(), Variant::Cpu);
        assert_eq!(serde_json::from_str::<Variant>("1").unwrap(), Variant::Gpu);
        assert_eq!(serde_json::from_str::<Variant>("2").unwrap(), Variant::Tpu);
    }

    #[test]
    fn variant_rejects_unknown() {
        assert!(serde_json::from_str::<Variant>("\"QUANTUM\"").is_err());
        assert!(serde_json::from_str::<Variant>("99").is_err());
    }

    #[test]
    fn variant_rejects_negative_integer() {
        assert!(serde_json::from_str::<Variant>("-1").is_err());
    }

    #[test]
    fn variant_serialises_to_canonical_string() {
        assert_eq!(serde_json::to_string(&Variant::Cpu).unwrap(), "\"DEFAULT\"");
        assert_eq!(serde_json::to_string(&Variant::Gpu).unwrap(), "\"GPU\"");
        assert_eq!(serde_json::to_string(&Variant::Tpu).unwrap(), "\"TPU\"");
    }

    #[test]
    fn variant_display_uses_human_names() {
        assert_eq!(Variant::Cpu.to_string(), "CPU");
        assert_eq!(Variant::Gpu.to_string(), "GPU");
        assert_eq!(Variant::Tpu.to_string(), "TPU");
    }

    #[test]
    fn outcome_deserialises_known_codes() {
        assert_eq!(
            serde_json::from_str::<Outcome>("4").unwrap(),
            Outcome::Success
        );
        assert_eq!(
            serde_json::from_str::<Outcome>("1").unwrap(),
            Outcome::QuotaDeniedVariants
        );
        assert_eq!(
            serde_json::from_str::<Outcome>("5").unwrap(),
            Outcome::Denylisted
        );
    }

    #[test]
    fn outcome_round_trips_through_json() {
        for (outcome, code) in [
            (Outcome::Undefined, "0"),
            (Outcome::QuotaDeniedVariants, "1"),
            (Outcome::QuotaExceededUsageTime, "2"),
            (Outcome::Success, "4"),
            (Outcome::Denylisted, "5"),
        ] {
            assert_eq!(serde_json::to_string(&outcome).unwrap(), code);
            assert_eq!(serde_json::from_str::<Outcome>(code).unwrap(), outcome);
        }
    }

    #[test]
    fn outcome_rejects_unknown_code() {
        assert!(serde_json::from_str::<Outcome>("42").is_err());
        // 3 sits in the gap between known codes and must not be accepted.
        assert!(serde_json::from_str::<Outcome>("3").is_err());
    }

    #[test]
    fn assignment_parses_real_payload() {
        let json = r#"{
            "endpoint": "abc-123",
            "variant": "GPU",
            "accelerator": "T4",
            "machineShape": 1,
            "runtimeProxyInfo": {
                "token": "tok",
                "tokenExpiresInSeconds": 3600,
                "url": "https://example.com"
            },
            "fit": 5400,
            "outcome": 4
        }"#;
        let a: Assignment = serde_json::from_str(json).unwrap();
        assert_eq!(a.endpoint, "abc-123");
        assert_eq!(a.variant, Variant::Gpu);
        assert_eq!(a.accelerator.as_deref(), Some("T4"));
        assert_eq!(a.machine_shape, Some(Shape::HighMem));
        assert_eq!(a.idle_timeout_sec, Some(5400));
        assert_eq!(a.outcome, Some(Outcome::Success));
    }

    #[test]
    fn assignment_allows_missing_optional_fields() {
        let json = r#"{
            "endpoint": "e",
            "variant": "DEFAULT",
            "runtimeProxyInfo": {"token": "t", "tokenExpiresInSeconds": 1, "url": "u"}
        }"#;
        let a: Assignment = serde_json::from_str(json).unwrap();
        assert_eq!(a.variant, Variant::Cpu);
        assert!(a.accelerator.is_none());
        assert!(a.machine_shape.is_none());
        assert!(a.idle_timeout_sec.is_none());
        assert!(a.outcome.is_none());
    }

    #[test]
    fn listed_assignment_allows_missing_proxy_info() {
        let json = r#"{"endpoint":"e","variant":0}"#;
        let la: ListedAssignment = serde_json::from_str(json).unwrap();
        assert_eq!(la.variant, Variant::Cpu);
        assert!(la.runtime_proxy_info.is_none());
    }

    #[test]
    fn shape_round_trip_known_and_unknown() {
        assert_eq!(Shape::from(0u8), Shape::Standard);
        assert_eq!(Shape::from(1u8), Shape::HighMem);
        assert_eq!(Shape::from(7u8), Shape::Unknown(7));
        assert_eq!(u8::from(Shape::Standard), 0);
        assert_eq!(u8::from(Shape::HighMem), 1);
        assert_eq!(u8::from(Shape::Unknown(7)), 7);
    }

    #[test]
    fn shape_json_round_trip() {
        let s: Shape = serde_json::from_str("0").unwrap();
        assert_eq!(s, Shape::Standard);
        let s: Shape = serde_json::from_str("1").unwrap();
        assert_eq!(s, Shape::HighMem);
        assert_eq!(serde_json::to_string(&Shape::HighMem).unwrap(), "1");
    }

    #[test]
    fn shape_json_preserves_unknown_codes() {
        let s: Shape = serde_json::from_str("7").unwrap();
        assert_eq!(s, Shape::Unknown(7));
        assert_eq!(serde_json::to_string(&s).unwrap(), "7");
        assert_eq!(Shape::default(), Shape::Standard);
        assert_eq!(Shape::HighMem.to_string(), "high-ram");
    }

    #[test]
    fn ccu_info_parses() {
        let json = r#"{
            "currentBalance": 42.25,
            "consumptionRateHourly": 1.76,
            "assignmentsCount": 1,
            "eligibleGpus": ["T4", "A100"],
            "eligibleTpus": ["v2-8"]
        }"#;
        let c: CcuInfo = serde_json::from_str(json).unwrap();
        assert_eq!(c.current_balance, 42.25);
        assert_eq!(c.consumption_rate_hourly, 1.76);
        assert_eq!(c.assignments_count, 1);
        assert_eq!(c.eligible_gpus, vec!["T4".to_string(), "A100".to_string()]);
        assert_eq!(c.eligible_tpus, vec!["v2-8".to_string()]);
    }
}