Skip to main content

colab_cli/client/
api.rs

1use serde::{Deserialize, Serialize};
2
/// Runtime variant (accelerator family) of a Colab assignment.
///
/// On the wire this appears either as a string (`"DEFAULT"`, `"GPU"`, `"TPU"`)
/// or as an integer code (`0`, `1`, `2`); see the `Deserialize`/`Serialize`
/// impls for the exact mapping.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Variant {
    /// CPU-only runtime; the API spells this `"DEFAULT"`.
    Cpu,
    /// GPU-backed runtime.
    Gpu,
    /// TPU-backed runtime.
    Tpu,
}
9
10impl<'de> serde::Deserialize<'de> for Variant {
11    fn deserialize<D: serde::Deserializer<'de>>(d: D) -> std::result::Result<Self, D::Error> {
12        use serde::de::{self, Visitor};
13        struct V;
14        impl Visitor<'_> for V {
15            type Value = Variant;
16            fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
17                write!(
18                    f,
19                    "variant as string (\"DEFAULT\"/\"GPU\"/\"TPU\") or integer (0/1/2)"
20                )
21            }
22            fn visit_str<E: de::Error>(self, v: &str) -> std::result::Result<Variant, E> {
23                match v {
24                    "DEFAULT" | "default" | "cpu" => Ok(Variant::Cpu),
25                    "GPU" | "gpu" => Ok(Variant::Gpu),
26                    "TPU" | "tpu" => Ok(Variant::Tpu),
27                    other => Err(E::unknown_variant(other, &["DEFAULT", "GPU", "TPU"])),
28                }
29            }
30            fn visit_u64<E: de::Error>(self, v: u64) -> std::result::Result<Variant, E> {
31                match v {
32                    0 => Ok(Variant::Cpu),
33                    1 => Ok(Variant::Gpu),
34                    2 => Ok(Variant::Tpu),
35                    other => Err(E::custom(format!("unknown variant integer: {other}"))),
36                }
37            }
38            fn visit_i64<E: de::Error>(self, v: i64) -> std::result::Result<Variant, E> {
39                self.visit_u64(v as u64)
40            }
41        }
42        d.deserialize_any(V)
43    }
44}
45
46impl serde::Serialize for Variant {
47    fn serialize<S: serde::Serializer>(&self, s: S) -> std::result::Result<S::Ok, S::Error> {
48        s.serialize_str(match self {
49            Variant::Cpu => "DEFAULT",
50            Variant::Gpu => "GPU",
51            Variant::Tpu => "TPU",
52        })
53    }
54}
55
56impl Variant {
57    #[inline]
58    pub fn display_name(&self) -> &'static str {
59        match self {
60            Variant::Cpu => "CPU",
61            Variant::Gpu => "GPU",
62            Variant::Tpu => "TPU",
63        }
64    }
65}
66
67impl std::fmt::Display for Variant {
68    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
69        f.write_str(self.display_name())
70    }
71}
72
73#[derive(Debug, Clone, Deserialize, Serialize)]
74#[serde(rename_all = "camelCase")]
75pub struct RuntimeProxyInfo {
76    pub token: String,
77    pub token_expires_in_seconds: i64,
78    pub url: String,
79}
80
81#[derive(Debug, Deserialize)]
82pub struct GetAssignmentResponse {
83    #[serde(rename = "token")]
84    pub xsrf_token: String,
85}
86
87#[derive(Debug, Clone, Copy, PartialEq, Eq)]
88pub enum Outcome {
89    Undefined,
90    QuotaDeniedVariants,
91    QuotaExceededUsageTime,
92    Success,
93    Denylisted,
94}
95
96impl<'de> serde::Deserialize<'de> for Outcome {
97    fn deserialize<D: serde::Deserializer<'de>>(d: D) -> std::result::Result<Self, D::Error> {
98        let v = u8::deserialize(d)?;
99        match v {
100            0 => Ok(Outcome::Undefined),
101            1 => Ok(Outcome::QuotaDeniedVariants),
102            2 => Ok(Outcome::QuotaExceededUsageTime),
103            4 => Ok(Outcome::Success),
104            5 => Ok(Outcome::Denylisted),
105            other => Err(serde::de::Error::custom(format!(
106                "unknown outcome: {other}"
107            ))),
108        }
109    }
110}
111
112impl serde::Serialize for Outcome {
113    fn serialize<S: serde::Serializer>(&self, s: S) -> std::result::Result<S::Ok, S::Error> {
114        let v: u8 = match self {
115            Outcome::Undefined => 0,
116            Outcome::QuotaDeniedVariants => 1,
117            Outcome::QuotaExceededUsageTime => 2,
118            Outcome::Success => 4,
119            Outcome::Denylisted => 5,
120        };
121        s.serialize_u8(v)
122    }
123}
124
125#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Serialize, Default)]
126#[serde(from = "u8", into = "u8")]
127pub enum Shape {
128    #[default]
129    Standard,
130    HighMem,
131    Unknown(u8),
132}
133
134impl From<u8> for Shape {
135    #[inline]
136    fn from(v: u8) -> Self {
137        match v {
138            0 => Shape::Standard,
139            1 => Shape::HighMem,
140            other => Shape::Unknown(other),
141        }
142    }
143}
144
145impl From<Shape> for u8 {
146    #[inline]
147    fn from(s: Shape) -> u8 {
148        match s {
149            Shape::Standard => 0,
150            Shape::HighMem => 1,
151            Shape::Unknown(v) => v,
152        }
153    }
154}
155
156impl Shape {
157    #[inline]
158    pub fn display_name(&self) -> &'static str {
159        match self {
160            Shape::Standard => "standard",
161            Shape::HighMem => "high-ram",
162            Shape::Unknown(_) => "unknown",
163        }
164    }
165}
166
167impl std::fmt::Display for Shape {
168    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
169        f.write_str(self.display_name())
170    }
171}
172
173#[derive(Debug, Clone, Deserialize, Serialize)]
174#[serde(rename_all = "camelCase")]
175pub struct Assignment {
176    pub endpoint: String,
177    pub variant: Variant,
178    pub accelerator: Option<String>,
179    pub machine_shape: Option<Shape>,
180    pub runtime_proxy_info: RuntimeProxyInfo,
181    #[serde(rename = "fit")]
182    pub idle_timeout_sec: Option<u64>,
183    pub outcome: Option<Outcome>,
184}
185
186#[derive(Debug, Clone, Deserialize, Serialize)]
187#[serde(rename_all = "camelCase")]
188pub struct ListedAssignment {
189    pub endpoint: String,
190    pub variant: Variant,
191    pub accelerator: Option<String>,
192    pub machine_shape: Option<Shape>,
193    pub runtime_proxy_info: Option<RuntimeProxyInfo>,
194}
195
196#[derive(Debug, Deserialize)]
197pub struct ListAssignmentsResponse {
198    pub assignments: Vec<ListedAssignment>,
199}
200
201#[derive(Debug, Deserialize)]
202pub struct Session {
203    pub id: String,
204}
205
206#[derive(Debug, Deserialize)]
207pub struct JupyterTerminal {
208    pub name: String,
209}
210
211#[derive(Debug, Deserialize)]
212#[serde(rename_all = "camelCase")]
213pub struct CcuInfo {
214    pub current_balance: f64,
215    pub consumption_rate_hourly: f64,
216    pub assignments_count: u32,
217    pub eligible_gpus: Vec<String>,
218    pub eligible_tpus: Vec<String>,
219}
220
#[cfg(test)]
mod tests {
    use super::*;

    // Table-driven checks of the wire formats defined above.

    #[test]
    fn variant_deserialises_strings() {
        let cases = [
            ("\"DEFAULT\"", Variant::Cpu),
            ("\"GPU\"", Variant::Gpu),
            ("\"TPU\"", Variant::Tpu),
        ];
        for (json, want) in cases {
            assert_eq!(
                serde_json::from_str::<Variant>(json).unwrap(),
                want,
                "input {json}"
            );
        }
    }

    #[test]
    fn variant_deserialises_integers() {
        let cases = [("0", Variant::Cpu), ("1", Variant::Gpu), ("2", Variant::Tpu)];
        for (json, want) in cases {
            assert_eq!(
                serde_json::from_str::<Variant>(json).unwrap(),
                want,
                "input {json}"
            );
        }
    }

    #[test]
    fn variant_rejects_unknown() {
        for json in ["\"QUANTUM\"", "99"] {
            assert!(
                serde_json::from_str::<Variant>(json).is_err(),
                "input {json}"
            );
        }
    }

    #[test]
    fn variant_serialises_to_canonical_string() {
        let cases = [
            (Variant::Cpu, "\"DEFAULT\""),
            (Variant::Gpu, "\"GPU\""),
            (Variant::Tpu, "\"TPU\""),
        ];
        for (variant, want) in cases {
            assert_eq!(serde_json::to_string(&variant).unwrap(), want);
        }
    }

    #[test]
    fn outcome_deserialises_known_codes() {
        let cases = [
            ("4", Outcome::Success),
            ("1", Outcome::QuotaDeniedVariants),
            ("5", Outcome::Denylisted),
        ];
        for (json, want) in cases {
            assert_eq!(
                serde_json::from_str::<Outcome>(json).unwrap(),
                want,
                "input {json}"
            );
        }
    }

    #[test]
    fn outcome_rejects_unknown_code() {
        let parsed = serde_json::from_str::<Outcome>("42");
        assert!(parsed.is_err());
    }

    #[test]
    fn assignment_parses_real_payload() {
        let json = r#"{
            "endpoint": "abc-123",
            "variant": "GPU",
            "accelerator": "T4",
            "machineShape": 1,
            "runtimeProxyInfo": {
                "token": "tok",
                "tokenExpiresInSeconds": 3600,
                "url": "https://example.com"
            },
            "fit": 5400,
            "outcome": 4
        }"#;
        let assignment: Assignment = serde_json::from_str(json).unwrap();
        assert_eq!(assignment.endpoint, "abc-123");
        assert_eq!(assignment.variant, Variant::Gpu);
        assert_eq!(assignment.accelerator.as_deref(), Some("T4"));
        assert_eq!(assignment.machine_shape, Some(Shape::HighMem));
        assert_eq!(assignment.idle_timeout_sec, Some(5400));
        assert_eq!(assignment.outcome, Some(Outcome::Success));
    }

    #[test]
    fn listed_assignment_allows_missing_proxy_info() {
        let listed: ListedAssignment =
            serde_json::from_str(r#"{"endpoint":"e","variant":0}"#).unwrap();
        assert_eq!(listed.variant, Variant::Cpu);
        assert!(listed.runtime_proxy_info.is_none());
    }

    #[test]
    fn shape_round_trip_known_and_unknown() {
        let table = [
            (0u8, Shape::Standard),
            (1u8, Shape::HighMem),
            (7u8, Shape::Unknown(7)),
        ];
        for (code, shape) in table {
            assert_eq!(Shape::from(code), shape);
            assert_eq!(u8::from(shape), code);
        }
    }

    #[test]
    fn shape_json_round_trip() {
        assert_eq!(serde_json::from_str::<Shape>("0").unwrap(), Shape::Standard);
        assert_eq!(serde_json::from_str::<Shape>("1").unwrap(), Shape::HighMem);
        assert_eq!(serde_json::to_string(&Shape::HighMem).unwrap(), "1");
    }

    #[test]
    fn ccu_info_parses() {
        let json = r#"{
            "currentBalance": 42.25,
            "consumptionRateHourly": 1.76,
            "assignmentsCount": 1,
            "eligibleGpus": ["T4", "A100"],
            "eligibleTpus": ["v2-8"]
        }"#;
        let info: CcuInfo = serde_json::from_str(json).unwrap();
        assert_eq!(info.current_balance, 42.25);
        assert_eq!(info.consumption_rate_hourly, 1.76);
        assert_eq!(info.assignments_count, 1);
        assert_eq!(info.eligible_gpus, ["T4", "A100"]);
        assert_eq!(info.eligible_tpus, ["v2-8"]);
    }
}