// colab_cli/client/mod.rs

1pub mod api;
2
3use std::future::Future;
4use std::path::Path;
5use std::pin::Pin;
6use std::sync::Arc;
7use std::time::Duration;
8
9use base64::Engine;
10use reqwest::{Client, RequestBuilder, Response, header};
11use serde::de::DeserializeOwned;
12use uuid::Uuid;
13
14use crate::client::api::{
15    Assignment, CcuInfo, GetAssignmentResponse, JupyterTerminal, ListAssignmentsResponse,
16    ListedAssignment, Outcome, RuntimeProxyInfo, Session, Shape, Variant,
17};
18use crate::config::ColabConfig;
19use crate::error::{ColabError, Result};
20
// `Accept` header value sent with Colab API calls and terminal creation.
const ACCEPT_JSON: &str = "application/json";
// Value sent in the `X-Colab-Client-Agent` header by the requests in this module.
const CLIENT_AGENT: &str = "vscode";
// Header (and its value, below) added to the runtime-proxy-token, tunnelled
// sessions and keep-alive requests.
const TUNNEL_HEADER: &str = "X-Colab-Tunnel";
const TUNNEL_VALUE: &str = "Google";
// Header carrying the proxy token on requests made directly to a runtime proxy URL.
const PROXY_TOKEN_HEADER: &str = "X-Colab-Runtime-Proxy-Token";
// Header carrying the token obtained from a GET on the follow-up POST
// (see `assign` and `unassign`).
const XSRF_TOKEN_HEADER: &str = "X-Goog-Colab-Token";
const CLIENT_AGENT_HEADER: &str = "X-Colab-Client-Agent";
// Path prefix shared by every endpoint addressed on the Colab domain.
const TUN_PREFIX: &str = "/tun/m";
// Guard prefix that `strip_xssi` removes from response bodies before JSON parsing.
const XSSI_PREFIX: &[u8] = b")]}'\n";
30
31#[doc(hidden)]
32#[inline]
33pub fn strip_xssi(s: &str) -> &str {
34    let b = s.as_bytes();
35    if b.len() >= XSSI_PREFIX.len() && &b[..XSSI_PREFIX.len()] == XSSI_PREFIX {
36        unsafe { std::str::from_utf8_unchecked(&b[XSSI_PREFIX.len()..]) }
37    } else {
38        s
39    }
40}
41
/// Type-erased, shareable async callback that resolves to an access token.
///
/// `ColabClient::new` wraps the caller's closure into this form and
/// `colab_request` awaits it before sending each authenticated request.
type TokenFn = Arc<dyn Fn() -> Pin<Box<dyn Future<Output = Result<String>> + Send>> + Send + Sync>;
43
/// HTTP client for the Colab backend and for the runtime proxies it hands out.
///
/// Requests to the Colab domain go through `colab_request`, which adds the
/// bearer token; requests to a runtime proxy authenticate with a proxy token
/// passed by the caller instead.
#[derive(Clone)]
pub struct ColabClient {
    /// `reqwest` client used for every request issued by this type.
    http: Client,
    /// Base URL of the Colab frontend, stored without trailing `/`.
    colab_domain: String,
    /// Callback that yields the bearer token for Colab-domain requests.
    get_access_token: TokenFn,
}
50
51impl ColabClient {
52    pub fn new<F, Fut>(config: &ColabConfig, get_access_token: F) -> Result<Self>
53    where
54        F: Fn() -> Fut + Send + Sync + 'static,
55        Fut: Future<Output = Result<String>> + Send + 'static,
56    {
57        let http = {
58            let mut b = Client::builder()
59                .use_rustls_tls()
60                .tcp_nodelay(true)
61                .http2_adaptive_window(true)
62                .http2_keep_alive_interval(Duration::from_secs(30))
63                .http2_keep_alive_while_idle(true)
64                .pool_idle_timeout(Duration::from_secs(90))
65                .pool_max_idle_per_host(8)
66                .timeout(Duration::from_secs(60))
67                .connect_timeout(Duration::from_secs(10));
68            if config.is_local() {
69                b = b.danger_accept_invalid_certs(true);
70            }
71            b.build().map_err(ColabError::Network)?
72        };
73
74        Ok(Self {
75            http,
76            colab_domain: config.colab_domain.trim_end_matches('/').to_string(),
77            get_access_token: Arc::new(move || Box::pin(get_access_token())),
78        })
79    }
80
81    pub async fn list_assignments(&self) -> Result<Vec<ListedAssignment>> {
82        let url = self.colab_url(format!("{TUN_PREFIX}/assignments"));
83        let resp = self.colab_request(self.http.get(&url)).await?;
84        let parsed: ListAssignmentsResponse = self.parse_json(resp).await?;
85        Ok(parsed.assignments)
86    }
87
88    pub async fn assign(
89        &self,
90        notebook_hash: Uuid,
91        variant: Variant,
92        accelerator: Option<&str>,
93        shape: Shape,
94    ) -> Result<(Assignment, bool)> {
95        let url = self.build_assign_url(notebook_hash, variant, accelerator, shape);
96
97        let get_resp = self.colab_request(self.http.get(&url)).await?;
98        let body = get_resp.text().await?;
99        let json: serde_json::Value = serde_json::from_str(strip_xssi(&body))?;
100
101        if json.get("endpoint").is_some() {
102            let assignment: Assignment = serde_json::from_value(json)?;
103            return Ok((assignment, false));
104        }
105
106        let get_response: GetAssignmentResponse = serde_json::from_value(json)?;
107        let xsrf_token = get_response.xsrf_token;
108
109        let post_resp = self
110            .colab_request(
111                self.http
112                    .post(&url)
113                    .header(XSRF_TOKEN_HEADER, &xsrf_token)
114                    .header(header::CONTENT_LENGTH, "0"),
115            )
116            .await?;
117        let assignment: Assignment = self.parse_json(post_resp).await?;
118
119        match assignment.outcome {
120            Some(Outcome::QuotaDeniedVariants) | Some(Outcome::QuotaExceededUsageTime) => {
121                Err(ColabError::InsufficientQuota)
122            }
123            Some(Outcome::Denylisted) => Err(ColabError::AccountDenylisted),
124            _ => Ok((assignment, true)),
125        }
126    }
127
128    pub async fn unassign(&self, endpoint: &str) -> Result<()> {
129        let url = self.colab_url(format!("{TUN_PREFIX}/unassign/{endpoint}"));
130
131        let token_resp = self.colab_request(self.http.get(&url)).await?;
132        let token_body: serde_json::Value = self.parse_json(token_resp).await?;
133        let token = token_body["token"]
134            .as_str()
135            .ok_or_else(|| ColabError::parse("missing token in unassign response"))?
136            .to_string();
137
138        self.colab_request(
139            self.http
140                .post(&url)
141                .header(XSRF_TOKEN_HEADER, &token)
142                .header(header::CONTENT_LENGTH, "0"),
143        )
144        .await?;
145        Ok(())
146    }
147
148    pub async fn refresh_connection(&self, endpoint: &str) -> Result<RuntimeProxyInfo> {
149        let url = self.colab_url(format!("{TUN_PREFIX}/runtime-proxy-token"));
150        let url = format!("{url}&endpoint={endpoint}&port=8080");
151        let resp = self
152            .colab_request(self.http.get(&url).header(TUNNEL_HEADER, TUNNEL_VALUE))
153            .await?;
154        self.parse_json(resp).await
155    }
156
157    pub async fn list_sessions_via_tunnel(&self, endpoint: &str) -> Result<Vec<Session>> {
158        let url = self.colab_url(format!("{TUN_PREFIX}/{endpoint}/api/sessions"));
159        let resp = self
160            .colab_request(self.http.get(&url).header(TUNNEL_HEADER, TUNNEL_VALUE))
161            .await?;
162        self.parse_json(resp).await
163    }
164
165    pub async fn delete_session(
166        &self,
167        proxy_url: &str,
168        proxy_token: &str,
169        session_id: &str,
170    ) -> Result<()> {
171        let url = format!(
172            "{}/api/sessions/{session_id}",
173            proxy_url.trim_end_matches('/')
174        );
175        let resp = self
176            .http
177            .delete(&url)
178            .header(PROXY_TOKEN_HEADER, proxy_token)
179            .header(CLIENT_AGENT_HEADER, CLIENT_AGENT)
180            .send()
181            .await?;
182        self.check_status_raw(resp, &url).await?;
183        Ok(())
184    }
185
186    pub async fn create_terminal(
187        &self,
188        proxy_url: &str,
189        proxy_token: &str,
190    ) -> Result<JupyterTerminal> {
191        let url = format!("{}/api/terminals", proxy_url.trim_end_matches('/'));
192        let resp = self
193            .http
194            .post(&url)
195            .header(PROXY_TOKEN_HEADER, proxy_token)
196            .header(CLIENT_AGENT_HEADER, CLIENT_AGENT)
197            .header(header::ACCEPT, ACCEPT_JSON)
198            .header(header::CONTENT_LENGTH, "0")
199            .send()
200            .await?;
201        let resp = self.check_status_raw(resp, &url).await?;
202        Ok(resp.json().await?)
203    }
204
205    /// Delete a Jupyter terminal that was previously created with
206    /// `create_terminal`. Used to cleanly reap the remote process tree
207    /// belonging to a specific short-lived view (e.g. `server ps`) without
208    /// touching unrelated sessions or the assigned server itself.
209    pub async fn delete_terminal(
210        &self,
211        proxy_url: &str,
212        proxy_token: &str,
213        terminal_name: &str,
214    ) -> Result<()> {
215        let url = format!(
216            "{}/api/terminals/{}",
217            proxy_url.trim_end_matches('/'),
218            terminal_name
219        );
220        let resp = self
221            .http
222            .delete(&url)
223            .header(PROXY_TOKEN_HEADER, proxy_token)
224            .header(CLIENT_AGENT_HEADER, CLIENT_AGENT)
225            .send()
226            .await?;
227        // A 404 here means the terminal was already reaped by the remote
228        // (e.g. because the user's bpytop exited and the shell walked out
229        // of its parent). That's not an error from our perspective.
230        if resp.status().as_u16() == 404 {
231            return Ok(());
232        }
233        self.check_status_raw(resp, &url).await?;
234        Ok(())
235    }
236
237    pub fn terminal_ws_url(&self, proxy_url: &str, terminal_name: &str) -> String {
238        let base = proxy_url
239            .trim_end_matches('/')
240            .replace("https://", "wss://")
241            .replace("http://", "ws://");
242        format!("{base}/terminals/websocket/{terminal_name}")
243    }
244
245    pub async fn upload_file_streaming(
246        &self,
247        proxy_url: &str,
248        proxy_token: &str,
249        remote_path: &str,
250        file_path: &Path,
251        progress: impl Fn(u64) + Send + 'static,
252    ) -> Result<()> {
253        // url-encode each segment so paths with spaces work; keep `/` as-is
254        let encoded_path = remote_path
255            .trim_start_matches('/')
256            .split('/')
257            .map(|seg| urlencoding::encode(seg).into_owned())
258            .collect::<Vec<_>>()
259            .join("/");
260        let url = format!(
261            "{}/api/contents/{encoded_path}",
262            proxy_url.trim_end_matches('/')
263        );
264
265        let meta = std::fs::metadata(file_path)?;
266        if !meta.is_file() {
267            return Err(ColabError::config(format!(
268                "not a regular file: {}",
269                file_path.display()
270            )));
271        }
272        let file_size = meta.len();
273
274        const CHUNK_RAW: usize = 3 * 1024 * 1024;
275
276        let prefix = br#"{"type":"file","format":"base64","content":""#;
277        let suffix = br#""}"#;
278        let base64_len = (file_size.div_ceil(3) * 4) as usize;
279        let content_length = prefix.len() + base64_len + suffix.len();
280
281        let file_path = file_path.to_owned();
282        let (tx, rx) =
283            tokio::sync::mpsc::channel::<std::result::Result<Vec<u8>, std::io::Error>>(4);
284
285        tokio::task::spawn_blocking(move || {
286            use std::io::Read;
287
288            if tx.blocking_send(Ok(prefix.to_vec())).is_err() {
289                return;
290            }
291
292            let mut file = match std::fs::File::open(&file_path) {
293                Ok(f) => f,
294                Err(e) => {
295                    let _ = tx.blocking_send(Err(e));
296                    return;
297                }
298            };
299
300            let mut buf = vec![0u8; CHUNK_RAW];
301            let mut bytes_so_far = 0u64;
302
303            loop {
304                let mut filled = 0;
305                while filled < CHUNK_RAW {
306                    match file.read(&mut buf[filled..]) {
307                        Ok(0) => break,
308                        Ok(n) => filled += n,
309                        Err(e) => {
310                            let _ = tx.blocking_send(Err(e));
311                            return;
312                        }
313                    }
314                }
315                if filled == 0 {
316                    break;
317                }
318                bytes_so_far += filled as u64;
319                progress(bytes_so_far);
320                let encoded = base64::engine::general_purpose::STANDARD
321                    .encode(&buf[..filled])
322                    .into_bytes();
323                if tx.blocking_send(Ok(encoded)).is_err() {
324                    return;
325                }
326            }
327
328            let _ = tx.blocking_send(Ok(suffix.to_vec()));
329        });
330
331        let stream = futures_util::stream::unfold(rx, |mut rx| async {
332            rx.recv().await.map(|item| (item, rx))
333        });
334
335        let body = reqwest::Body::wrap_stream(stream);
336
337        let resp = self
338            .http
339            .put(&url)
340            .header(PROXY_TOKEN_HEADER, proxy_token)
341            .header(CLIENT_AGENT_HEADER, CLIENT_AGENT)
342            .header(header::CONTENT_TYPE, "application/json")
343            .header(header::CONTENT_LENGTH, content_length.to_string())
344            .body(body)
345            .send()
346            .await?;
347
348        self.check_status_raw(resp, &url).await?;
349        Ok(())
350    }
351
352    pub async fn send_keep_alive(&self, endpoint: &str) -> Result<()> {
353        let url = self.colab_url(format!("{TUN_PREFIX}/{endpoint}/keep-alive/"));
354        self.colab_request(
355            self.http
356                .post(&url)
357                .header(TUNNEL_HEADER, TUNNEL_VALUE)
358                .header(header::CONTENT_LENGTH, "0"),
359        )
360        .await?;
361        Ok(())
362    }
363
364    pub async fn get_ccu_info(&self) -> Result<CcuInfo> {
365        let url = self.colab_url(format!("{TUN_PREFIX}/ccu-info"));
366        let resp = self.colab_request(self.http.get(&url)).await?;
367        self.parse_json(resp).await
368    }
369
370    #[inline]
371    fn colab_url(&self, path: impl AsRef<str>) -> String {
372        let mut out = String::with_capacity(self.colab_domain.len() + path.as_ref().len() + 10);
373        out.push_str(&self.colab_domain);
374        out.push_str(path.as_ref());
375        out.push_str("?authuser=0");
376        out
377    }
378
379    fn build_assign_url(
380        &self,
381        notebook_hash: Uuid,
382        variant: Variant,
383        accelerator: Option<&str>,
384        shape: Shape,
385    ) -> String {
386        build_assign_url(
387            &self.colab_domain,
388            notebook_hash,
389            variant,
390            accelerator,
391            shape,
392        )
393    }
394
395    async fn colab_request(&self, builder: RequestBuilder) -> Result<Response> {
396        let token = (self.get_access_token)().await?;
397        let resp = builder
398            .header(header::AUTHORIZATION, format!("Bearer {token}"))
399            .header(header::ACCEPT, ACCEPT_JSON)
400            .header(CLIENT_AGENT_HEADER, CLIENT_AGENT)
401            .send()
402            .await?;
403        let url = resp.url().to_string();
404        self.check_status_raw(resp, &url).await
405    }
406
407    async fn check_status_raw(&self, resp: Response, url: &str) -> Result<Response> {
408        if resp.status().is_success() {
409            return Ok(resp);
410        }
411        let status = resp.status().as_u16();
412        let body = resp.text().await.ok();
413        match status {
414            412 => Err(ColabError::TooManyAssignments),
415            404 => Err(ColabError::ServerNotFound {
416                endpoint: url.to_string(),
417            }),
418            _ => Err(ColabError::api(status, url, body)),
419        }
420    }
421
422    async fn parse_json<T: DeserializeOwned>(&self, resp: Response) -> Result<T> {
423        let body = resp.text().await?;
424        serde_json::from_str(strip_xssi(&body)).map_err(|e| {
425            ColabError::parse(format!("failed to parse API response: {e}\nbody: {body}"))
426        })
427    }
428}
429
430#[doc(hidden)]
431pub fn build_assign_url(
432    colab_domain: &str,
433    notebook_hash: Uuid,
434    variant: Variant,
435    accelerator: Option<&str>,
436    shape: Shape,
437) -> String {
438    let nbh = uuid_to_websafe_base64(notebook_hash);
439    let mut url = String::with_capacity(colab_domain.len() + 96);
440    url.push_str(colab_domain);
441    url.push_str(TUN_PREFIX);
442    url.push_str("/assign?authuser=0&nbh=");
443    url.push_str(&nbh);
444    if !matches!(variant, Variant::Cpu) {
445        url.push_str("&variant=");
446        url.push_str(variant_param(variant));
447    }
448    if let Some(acc) = accelerator {
449        url.push_str("&accelerator=");
450        url.push_str(acc);
451    }
452    // High-RAM is requested via `&shape=hm` — this matches exactly what the
453    // Colab web UI sends (see network capture). There is no `machineShape=N`
454    // parameter; Standard omits the param entirely.
455    if matches!(shape, Shape::HighMem) {
456        url.push_str("&shape=hm");
457    }
458    url
459}
460
461#[doc(hidden)]
462#[inline]
463pub fn uuid_to_websafe_base64(id: Uuid) -> String {
464    let s = id.to_string().replace('-', "_");
465    let mut out = String::with_capacity(44);
466    out.push_str(&s);
467    for _ in s.len()..44 {
468        out.push('.');
469    }
470    out
471}
472
473#[inline]
474fn variant_param(v: Variant) -> &'static str {
475    match v {
476        Variant::Cpu => "DEFAULT",
477        Variant::Gpu => "GPU",
478        Variant::Tpu => "TPU",
479    }
480}
481
#[cfg(test)]
mod tests {
    use super::*;

    const COLAB: &str = "https://colab.research.google.com";
    const SAMPLE_UUID: &str = "550e8400-e29b-41d4-a716-446655440000";
    const SAMPLE_UUID_WEBSAFE: &str = "550e8400_e29b_41d4_a716_446655440000";

    /// Parses the fixed sample UUID shared by several tests.
    fn sample_uuid() -> Uuid {
        Uuid::parse_str(SAMPLE_UUID).unwrap()
    }

    #[test]
    fn strip_xssi_removes_prefix_when_present() {
        let stripped = strip_xssi(")]}'\n{\"a\":1}");
        assert_eq!(stripped, "{\"a\":1}");
    }

    #[test]
    fn strip_xssi_is_identity_without_prefix() {
        let untouched = strip_xssi("{\"a\":1}");
        assert_eq!(untouched, "{\"a\":1}");
    }

    #[test]
    fn strip_xssi_handles_empty() {
        assert_eq!(strip_xssi(""), "");
    }

    #[test]
    fn uuid_encodes_to_44_char_websafe() {
        let nbh = uuid_to_websafe_base64(Uuid::nil());
        assert_eq!(nbh.len(), 44);
        assert!(nbh.starts_with("00000000_0000_0000_0000_000000000000"));
        assert!(nbh.ends_with('.'));
        assert!(!nbh.contains('-'));
    }

    #[test]
    fn uuid_round_trips_a_real_uuid() {
        let nbh = uuid_to_websafe_base64(sample_uuid());
        assert_eq!(nbh.len(), 44);
        let (head, tail) = nbh.split_at(36);
        assert_eq!(head, SAMPLE_UUID_WEBSAFE);
        assert_eq!(tail, "........");
    }

    #[test]
    fn variant_param_mapping() {
        assert_eq!(variant_param(Variant::Cpu), "DEFAULT");
        assert_eq!(variant_param(Variant::Gpu), "GPU");
        assert_eq!(variant_param(Variant::Tpu), "TPU");
    }

    #[test]
    fn assign_url_cpu_standard_is_minimal() {
        let u = build_assign_url(COLAB, sample_uuid(), Variant::Cpu, None, Shape::Standard);
        assert!(u.contains("/tun/m/assign?authuser=0"));
        assert!(u.contains("&nbh=550e8400_e29b_41d4_a716_446655440000"));
        for absent in ["variant=", "accelerator=", "shape=", "machineShape="] {
            assert!(!u.contains(absent));
        }
    }

    #[test]
    fn assign_url_gpu_with_accelerator_and_highmem() {
        let u = build_assign_url(
            COLAB,
            Uuid::nil(),
            Variant::Gpu,
            Some("T4"),
            Shape::HighMem,
        );
        assert!(u.contains("variant=GPU"));
        assert!(u.contains("accelerator=T4"));
        // High-RAM is signalled with `&shape=hm`, matching the colab.research.google.com web UI.
        assert!(u.contains("&shape=hm"));
        assert!(!u.contains("machineShape="));
    }

    #[test]
    fn assign_url_tpu_no_accelerator_standard() {
        let u = build_assign_url(
            "https://x.y",
            Uuid::nil(),
            Variant::Tpu,
            None,
            Shape::Standard,
        );
        assert!(u.contains("variant=TPU"));
        assert!(!u.contains("accelerator="));
        assert!(!u.contains("shape="));
    }
}