1pub mod api;
2
3use std::future::Future;
4use std::path::Path;
5use std::pin::Pin;
6use std::sync::Arc;
7use std::time::Duration;
8
9use base64::Engine;
10use reqwest::{Client, RequestBuilder, Response, header};
11use serde::de::DeserializeOwned;
12use uuid::Uuid;
13
14use crate::client::api::{
15 Assignment, CcuInfo, GetAssignmentResponse, JupyterTerminal, ListAssignmentsResponse,
16 ListedAssignment, Outcome, RuntimeProxyInfo, Session, Shape, Variant,
17};
18use crate::config::ColabConfig;
19use crate::error::{ColabError, Result};
20
/// `Accept` header value sent with JSON requests.
const ACCEPT_JSON: &str = "application/json";
/// Value sent in the client-agent header on every request.
const CLIENT_AGENT: &str = "vscode";
/// Header/value pair added to requests routed through the Colab tunnel
/// (`refresh_connection`, `list_sessions_via_tunnel`, `send_keep_alive`).
const TUNNEL_HEADER: &str = "X-Colab-Tunnel";
const TUNNEL_VALUE: &str = "Google";
/// Header carrying the runtime-proxy token on requests sent straight to a proxy URL.
const PROXY_TOKEN_HEADER: &str = "X-Colab-Runtime-Proxy-Token";
/// Header echoing the token obtained from a preceding GET (assign / unassign).
const XSRF_TOKEN_HEADER: &str = "X-Goog-Colab-Token";
/// Header identifying the client agent.
const CLIENT_AGENT_HEADER: &str = "X-Colab-Client-Agent";
/// Path prefix shared by all Colab tunnel endpoints used here.
const TUN_PREFIX: &str = "/tun/m";
/// Anti-XSSI guard stripped from response bodies before JSON parsing.
const XSSI_PREFIX: &[u8] = b")]}'\n";

/// Removes the anti-XSSI guard (`)]}'` followed by `\n`) from the start of `s`.
///
/// Returns `s` unchanged when it does not begin with the full guard.
#[doc(hidden)]
#[inline]
pub fn strip_xssi(s: &str) -> &str {
    if s.as_bytes().starts_with(XSSI_PREFIX) {
        // The guard is pure ASCII, so its byte length is always a char
        // boundary of `s`; checked slicing cannot panic and needs no `unsafe`.
        &s[XSSI_PREFIX.len()..]
    } else {
        s
    }
}
41
42type TokenFn = Arc<dyn Fn() -> Pin<Box<dyn Future<Output = Result<String>> + Send>> + Send + Sync>;
43
44#[derive(Clone)]
45pub struct ColabClient {
46 http: Client,
47 colab_domain: String,
48 get_access_token: TokenFn,
49}
50
51impl ColabClient {
52 pub fn new<F, Fut>(config: &ColabConfig, get_access_token: F) -> Result<Self>
53 where
54 F: Fn() -> Fut + Send + Sync + 'static,
55 Fut: Future<Output = Result<String>> + Send + 'static,
56 {
57 let http = {
58 let mut b = Client::builder()
59 .use_rustls_tls()
60 .tcp_nodelay(true)
61 .http2_adaptive_window(true)
62 .http2_keep_alive_interval(Duration::from_secs(30))
63 .http2_keep_alive_while_idle(true)
64 .pool_idle_timeout(Duration::from_secs(90))
65 .pool_max_idle_per_host(8)
66 .timeout(Duration::from_secs(60))
67 .connect_timeout(Duration::from_secs(10));
68 if config.is_local() {
69 b = b.danger_accept_invalid_certs(true);
70 }
71 b.build().map_err(ColabError::Network)?
72 };
73
74 Ok(Self {
75 http,
76 colab_domain: config.colab_domain.trim_end_matches('/').to_string(),
77 get_access_token: Arc::new(move || Box::pin(get_access_token())),
78 })
79 }
80
81 pub async fn list_assignments(&self) -> Result<Vec<ListedAssignment>> {
82 let url = self.colab_url(format!("{TUN_PREFIX}/assignments"));
83 let resp = self.colab_request(self.http.get(&url)).await?;
84 let parsed: ListAssignmentsResponse = self.parse_json(resp).await?;
85 Ok(parsed.assignments)
86 }
87
88 pub async fn assign(
89 &self,
90 notebook_hash: Uuid,
91 variant: Variant,
92 accelerator: Option<&str>,
93 shape: Shape,
94 ) -> Result<(Assignment, bool)> {
95 let url = self.build_assign_url(notebook_hash, variant, accelerator, shape);
96
97 let get_resp = self.colab_request(self.http.get(&url)).await?;
98 let body = get_resp.text().await?;
99 let json: serde_json::Value = serde_json::from_str(strip_xssi(&body))?;
100
101 if json.get("endpoint").is_some() {
102 let assignment: Assignment = serde_json::from_value(json)?;
103 return Ok((assignment, false));
104 }
105
106 let get_response: GetAssignmentResponse = serde_json::from_value(json)?;
107 let xsrf_token = get_response.xsrf_token;
108
109 let post_resp = self
110 .colab_request(
111 self.http
112 .post(&url)
113 .header(XSRF_TOKEN_HEADER, &xsrf_token)
114 .header(header::CONTENT_LENGTH, "0"),
115 )
116 .await?;
117 let assignment: Assignment = self.parse_json(post_resp).await?;
118
119 match assignment.outcome {
120 Some(Outcome::QuotaDeniedVariants) | Some(Outcome::QuotaExceededUsageTime) => {
121 Err(ColabError::InsufficientQuota)
122 }
123 Some(Outcome::Denylisted) => Err(ColabError::AccountDenylisted),
124 _ => Ok((assignment, true)),
125 }
126 }
127
128 pub async fn unassign(&self, endpoint: &str) -> Result<()> {
129 let url = self.colab_url(format!("{TUN_PREFIX}/unassign/{endpoint}"));
130
131 let token_resp = self.colab_request(self.http.get(&url)).await?;
132 let token_body: serde_json::Value = self.parse_json(token_resp).await?;
133 let token = token_body["token"]
134 .as_str()
135 .ok_or_else(|| ColabError::parse("missing token in unassign response"))?
136 .to_string();
137
138 self.colab_request(
139 self.http
140 .post(&url)
141 .header(XSRF_TOKEN_HEADER, &token)
142 .header(header::CONTENT_LENGTH, "0"),
143 )
144 .await?;
145 Ok(())
146 }
147
148 pub async fn refresh_connection(&self, endpoint: &str) -> Result<RuntimeProxyInfo> {
149 let url = self.colab_url(format!("{TUN_PREFIX}/runtime-proxy-token"));
150 let url = format!("{url}&endpoint={endpoint}&port=8080");
151 let resp = self
152 .colab_request(self.http.get(&url).header(TUNNEL_HEADER, TUNNEL_VALUE))
153 .await?;
154 self.parse_json(resp).await
155 }
156
157 pub async fn list_sessions_via_tunnel(&self, endpoint: &str) -> Result<Vec<Session>> {
158 let url = self.colab_url(format!("{TUN_PREFIX}/{endpoint}/api/sessions"));
159 let resp = self
160 .colab_request(self.http.get(&url).header(TUNNEL_HEADER, TUNNEL_VALUE))
161 .await?;
162 self.parse_json(resp).await
163 }
164
165 pub async fn delete_session(
166 &self,
167 proxy_url: &str,
168 proxy_token: &str,
169 session_id: &str,
170 ) -> Result<()> {
171 let url = format!(
172 "{}/api/sessions/{session_id}",
173 proxy_url.trim_end_matches('/')
174 );
175 let resp = self
176 .http
177 .delete(&url)
178 .header(PROXY_TOKEN_HEADER, proxy_token)
179 .header(CLIENT_AGENT_HEADER, CLIENT_AGENT)
180 .send()
181 .await?;
182 self.check_status_raw(resp, &url).await?;
183 Ok(())
184 }
185
186 pub async fn create_terminal(
187 &self,
188 proxy_url: &str,
189 proxy_token: &str,
190 ) -> Result<JupyterTerminal> {
191 let url = format!("{}/api/terminals", proxy_url.trim_end_matches('/'));
192 let resp = self
193 .http
194 .post(&url)
195 .header(PROXY_TOKEN_HEADER, proxy_token)
196 .header(CLIENT_AGENT_HEADER, CLIENT_AGENT)
197 .header(header::ACCEPT, ACCEPT_JSON)
198 .header(header::CONTENT_LENGTH, "0")
199 .send()
200 .await?;
201 let resp = self.check_status_raw(resp, &url).await?;
202 Ok(resp.json().await?)
203 }
204
205 pub async fn delete_terminal(
210 &self,
211 proxy_url: &str,
212 proxy_token: &str,
213 terminal_name: &str,
214 ) -> Result<()> {
215 let url = format!(
216 "{}/api/terminals/{}",
217 proxy_url.trim_end_matches('/'),
218 terminal_name
219 );
220 let resp = self
221 .http
222 .delete(&url)
223 .header(PROXY_TOKEN_HEADER, proxy_token)
224 .header(CLIENT_AGENT_HEADER, CLIENT_AGENT)
225 .send()
226 .await?;
227 if resp.status().as_u16() == 404 {
231 return Ok(());
232 }
233 self.check_status_raw(resp, &url).await?;
234 Ok(())
235 }
236
237 pub fn terminal_ws_url(&self, proxy_url: &str, terminal_name: &str) -> String {
238 let base = proxy_url
239 .trim_end_matches('/')
240 .replace("https://", "wss://")
241 .replace("http://", "ws://");
242 format!("{base}/terminals/websocket/{terminal_name}")
243 }
244
245 pub async fn upload_file_streaming(
246 &self,
247 proxy_url: &str,
248 proxy_token: &str,
249 remote_path: &str,
250 file_path: &Path,
251 progress: impl Fn(u64) + Send + 'static,
252 ) -> Result<()> {
253 let encoded_path = remote_path
255 .trim_start_matches('/')
256 .split('/')
257 .map(|seg| urlencoding::encode(seg).into_owned())
258 .collect::<Vec<_>>()
259 .join("/");
260 let url = format!(
261 "{}/api/contents/{encoded_path}",
262 proxy_url.trim_end_matches('/')
263 );
264
265 let meta = std::fs::metadata(file_path)?;
266 if !meta.is_file() {
267 return Err(ColabError::config(format!(
268 "not a regular file: {}",
269 file_path.display()
270 )));
271 }
272 let file_size = meta.len();
273
274 const CHUNK_RAW: usize = 3 * 1024 * 1024;
275
276 let prefix = br#"{"type":"file","format":"base64","content":""#;
277 let suffix = br#""}"#;
278 let base64_len = (file_size.div_ceil(3) * 4) as usize;
279 let content_length = prefix.len() + base64_len + suffix.len();
280
281 let file_path = file_path.to_owned();
282 let (tx, rx) =
283 tokio::sync::mpsc::channel::<std::result::Result<Vec<u8>, std::io::Error>>(4);
284
285 tokio::task::spawn_blocking(move || {
286 use std::io::Read;
287
288 if tx.blocking_send(Ok(prefix.to_vec())).is_err() {
289 return;
290 }
291
292 let mut file = match std::fs::File::open(&file_path) {
293 Ok(f) => f,
294 Err(e) => {
295 let _ = tx.blocking_send(Err(e));
296 return;
297 }
298 };
299
300 let mut buf = vec![0u8; CHUNK_RAW];
301 let mut bytes_so_far = 0u64;
302
303 loop {
304 let mut filled = 0;
305 while filled < CHUNK_RAW {
306 match file.read(&mut buf[filled..]) {
307 Ok(0) => break,
308 Ok(n) => filled += n,
309 Err(e) => {
310 let _ = tx.blocking_send(Err(e));
311 return;
312 }
313 }
314 }
315 if filled == 0 {
316 break;
317 }
318 bytes_so_far += filled as u64;
319 progress(bytes_so_far);
320 let encoded = base64::engine::general_purpose::STANDARD
321 .encode(&buf[..filled])
322 .into_bytes();
323 if tx.blocking_send(Ok(encoded)).is_err() {
324 return;
325 }
326 }
327
328 let _ = tx.blocking_send(Ok(suffix.to_vec()));
329 });
330
331 let stream = futures_util::stream::unfold(rx, |mut rx| async {
332 rx.recv().await.map(|item| (item, rx))
333 });
334
335 let body = reqwest::Body::wrap_stream(stream);
336
337 let resp = self
338 .http
339 .put(&url)
340 .header(PROXY_TOKEN_HEADER, proxy_token)
341 .header(CLIENT_AGENT_HEADER, CLIENT_AGENT)
342 .header(header::CONTENT_TYPE, "application/json")
343 .header(header::CONTENT_LENGTH, content_length.to_string())
344 .body(body)
345 .send()
346 .await?;
347
348 self.check_status_raw(resp, &url).await?;
349 Ok(())
350 }
351
352 pub async fn send_keep_alive(&self, endpoint: &str) -> Result<()> {
353 let url = self.colab_url(format!("{TUN_PREFIX}/{endpoint}/keep-alive/"));
354 self.colab_request(
355 self.http
356 .post(&url)
357 .header(TUNNEL_HEADER, TUNNEL_VALUE)
358 .header(header::CONTENT_LENGTH, "0"),
359 )
360 .await?;
361 Ok(())
362 }
363
364 pub async fn get_ccu_info(&self) -> Result<CcuInfo> {
365 let url = self.colab_url(format!("{TUN_PREFIX}/ccu-info"));
366 let resp = self.colab_request(self.http.get(&url)).await?;
367 self.parse_json(resp).await
368 }
369
370 #[inline]
371 fn colab_url(&self, path: impl AsRef<str>) -> String {
372 let mut out = String::with_capacity(self.colab_domain.len() + path.as_ref().len() + 10);
373 out.push_str(&self.colab_domain);
374 out.push_str(path.as_ref());
375 out.push_str("?authuser=0");
376 out
377 }
378
379 fn build_assign_url(
380 &self,
381 notebook_hash: Uuid,
382 variant: Variant,
383 accelerator: Option<&str>,
384 shape: Shape,
385 ) -> String {
386 build_assign_url(
387 &self.colab_domain,
388 notebook_hash,
389 variant,
390 accelerator,
391 shape,
392 )
393 }
394
395 async fn colab_request(&self, builder: RequestBuilder) -> Result<Response> {
396 let token = (self.get_access_token)().await?;
397 let resp = builder
398 .header(header::AUTHORIZATION, format!("Bearer {token}"))
399 .header(header::ACCEPT, ACCEPT_JSON)
400 .header(CLIENT_AGENT_HEADER, CLIENT_AGENT)
401 .send()
402 .await?;
403 let url = resp.url().to_string();
404 self.check_status_raw(resp, &url).await
405 }
406
407 async fn check_status_raw(&self, resp: Response, url: &str) -> Result<Response> {
408 if resp.status().is_success() {
409 return Ok(resp);
410 }
411 let status = resp.status().as_u16();
412 let body = resp.text().await.ok();
413 match status {
414 412 => Err(ColabError::TooManyAssignments),
415 404 => Err(ColabError::ServerNotFound {
416 endpoint: url.to_string(),
417 }),
418 _ => Err(ColabError::api(status, url, body)),
419 }
420 }
421
422 async fn parse_json<T: DeserializeOwned>(&self, resp: Response) -> Result<T> {
423 let body = resp.text().await?;
424 serde_json::from_str(strip_xssi(&body)).map_err(|e| {
425 ColabError::parse(format!("failed to parse API response: {e}\nbody: {body}"))
426 })
427 }
428}
429
430#[doc(hidden)]
431pub fn build_assign_url(
432 colab_domain: &str,
433 notebook_hash: Uuid,
434 variant: Variant,
435 accelerator: Option<&str>,
436 shape: Shape,
437) -> String {
438 let nbh = uuid_to_websafe_base64(notebook_hash);
439 let mut url = String::with_capacity(colab_domain.len() + 96);
440 url.push_str(colab_domain);
441 url.push_str(TUN_PREFIX);
442 url.push_str("/assign?authuser=0&nbh=");
443 url.push_str(&nbh);
444 if !matches!(variant, Variant::Cpu) {
445 url.push_str("&variant=");
446 url.push_str(variant_param(variant));
447 }
448 if let Some(acc) = accelerator {
449 url.push_str("&accelerator=");
450 url.push_str(acc);
451 }
452 if matches!(shape, Shape::HighMem) {
456 url.push_str("&shape=hm");
457 }
458 url
459}
460
461#[doc(hidden)]
462#[inline]
463pub fn uuid_to_websafe_base64(id: Uuid) -> String {
464 let s = id.to_string().replace('-', "_");
465 let mut out = String::with_capacity(44);
466 out.push_str(&s);
467 for _ in s.len()..44 {
468 out.push('.');
469 }
470 out
471}
472
473#[inline]
474fn variant_param(v: Variant) -> &'static str {
475 match v {
476 Variant::Cpu => "DEFAULT",
477 Variant::Gpu => "GPU",
478 Variant::Tpu => "TPU",
479 }
480}
481
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn strip_xssi_removes_prefix_when_present() {
        assert_eq!(strip_xssi(")]}'\n{\"a\":1}"), "{\"a\":1}");
    }

    #[test]
    fn strip_xssi_is_identity_without_prefix() {
        assert_eq!(strip_xssi("{\"a\":1}"), "{\"a\":1}");
    }

    #[test]
    fn strip_xssi_handles_empty() {
        assert_eq!(strip_xssi(""), "");
    }

    // The remainder after the guard may be empty or begin with a multi-byte
    // character; both must come back intact without panicking.
    #[test]
    fn strip_xssi_handles_bare_prefix_and_unicode_remainder() {
        assert_eq!(strip_xssi(")]}'\n"), "");
        assert_eq!(strip_xssi(")]}'\nçé"), "çé");
    }

    // Only a complete guard at the very start is removed.
    #[test]
    fn strip_xssi_ignores_partial_or_non_leading_prefix() {
        assert_eq!(strip_xssi(")]}'"), ")]}'");
        assert_eq!(strip_xssi(" )]}'\n{}"), " )]}'\n{}");
    }

    #[test]
    fn uuid_encodes_to_44_char_websafe() {
        let id = Uuid::nil();
        let nbh = uuid_to_websafe_base64(id);
        assert_eq!(nbh.len(), 44);
        assert!(nbh.starts_with("00000000_0000_0000_0000_000000000000"));
        assert!(nbh.ends_with('.'));
        assert!(!nbh.contains('-'));
    }

    #[test]
    fn uuid_round_trips_a_real_uuid() {
        let id = Uuid::parse_str("550e8400-e29b-41d4-a716-446655440000").unwrap();
        let nbh = uuid_to_websafe_base64(id);
        assert_eq!(nbh.len(), 44);
        assert_eq!(&nbh[..36], "550e8400_e29b_41d4_a716_446655440000");
        assert_eq!(&nbh[36..], "........");
    }

    #[test]
    fn variant_param_mapping() {
        assert_eq!(variant_param(Variant::Cpu), "DEFAULT");
        assert_eq!(variant_param(Variant::Gpu), "GPU");
        assert_eq!(variant_param(Variant::Tpu), "TPU");
    }

    #[test]
    fn assign_url_cpu_standard_is_minimal() {
        let id = Uuid::parse_str("550e8400-e29b-41d4-a716-446655440000").unwrap();
        let u = build_assign_url(
            "https://colab.research.google.com",
            id,
            Variant::Cpu,
            None,
            Shape::Standard,
        );
        assert!(u.contains("/tun/m/assign?authuser=0"));
        assert!(u.contains("&nbh=550e8400_e29b_41d4_a716_446655440000"));
        assert!(!u.contains("variant="));
        assert!(!u.contains("accelerator="));
        assert!(!u.contains("shape="));
        assert!(!u.contains("machineShape="));
    }

    #[test]
    fn assign_url_gpu_with_accelerator_and_highmem() {
        let id = Uuid::nil();
        let u = build_assign_url(
            "https://colab.research.google.com",
            id,
            Variant::Gpu,
            Some("T4"),
            Shape::HighMem,
        );
        assert!(u.contains("variant=GPU"));
        assert!(u.contains("accelerator=T4"));
        assert!(u.contains("&shape=hm"));
        assert!(!u.contains("machineShape="));
    }

    #[test]
    fn assign_url_tpu_no_accelerator_standard() {
        let id = Uuid::nil();
        let u = build_assign_url("https://x.y", id, Variant::Tpu, None, Shape::Standard);
        assert!(u.contains("variant=TPU"));
        assert!(!u.contains("accelerator="));
        assert!(!u.contains("shape="));
    }

    #[test]
    fn assign_url_gpu_without_accelerator() {
        let u = build_assign_url("https://x.y", Uuid::nil(), Variant::Gpu, None, Shape::Standard);
        assert!(u.starts_with("https://x.y/tun/m/assign?authuser=0&nbh="));
        assert!(u.contains("&variant=GPU"));
        assert!(!u.contains("accelerator="));
        assert!(!u.contains("shape="));
    }

    // The CPU variant is the default and is never sent, even with an accelerator.
    #[test]
    fn assign_url_cpu_with_accelerator_omits_variant() {
        let u = build_assign_url(
            "https://x.y",
            Uuid::nil(),
            Variant::Cpu,
            Some("T4"),
            Shape::Standard,
        );
        assert!(!u.contains("variant="));
        assert!(u.contains("&accelerator=T4"));
    }
}