Skip to main content

colab_cli/auth/
oauth.rs

1use base64::Engine;
2use base64::engine::general_purpose::URL_SAFE_NO_PAD;
3use chrono::{Duration, Utc};
4use reqwest::Client;
5use serde::Deserialize;
6use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
7use tokio::net::TcpListener;
8
9use crate::auth::storage::{AccountInfo, TokenStorage};
10use crate::config::ColabConfig;
11use crate::error::{ColabError, Result};
12
/// Upper bound, in seconds, on how long `run_login_flow` waits for the browser
/// redirect (this covers binding the listener, opening the browser and
/// receiving the redirect).
const FLOW_TIMEOUT_SECS: u64 = 120;
/// HTML page written back to the browser after the authorization code has been
/// extracted from the redirect request.
const REDIRECT_SUCCESS_HTML: &str = r#"<!DOCTYPE html>
<html><head><title>Signed in</title></head>
<body style="font-family:sans-serif;text-align:center;padding:4em">
<h1>Signed in to colab-cli</h1>
<p>You can close this tab.</p>
</body></html>"#;

/// OAuth scopes requested at login; `build_auth_url` joins them with spaces
/// into the `scope` query parameter.
pub const REQUIRED_SCOPES: &[&str] = &[
    "profile",
    "email",
    "https://www.googleapis.com/auth/colaboratory",
];
26
27#[derive(Debug, Deserialize)]
28struct TokenResponse {
29    access_token: String,
30    refresh_token: Option<String>,
31    expires_in: Option<i64>,
32    #[allow(dead_code)]
33    token_type: String,
34}
35
/// Subset of the JSON returned by Google's `oauth2/v2/userinfo` endpoint
/// (queried in `fetch_user_info`); both fields are required for parsing.
#[derive(Debug, Deserialize)]
struct UserInfoResponse {
    /// Account name as reported by Google; copied into `AccountInfo::name`.
    name: String,
    /// Account email as reported by Google; copied into `AccountInfo::email`.
    email: String,
}
41
42pub async fn run_login_flow(config: &ColabConfig) -> Result<AccountInfo> {
43    let (code, redirect_uri, code_verifier) = tokio::time::timeout(
44        std::time::Duration::from_secs(FLOW_TIMEOUT_SECS),
45        wait_for_auth_code(config),
46    )
47    .await
48    .map_err(|_| ColabError::oauth("authentication timed out (2 min)"))?
49    .map_err(|e| ColabError::AuthFailed(e.to_string()))?;
50
51    let http = Client::builder()
52        .use_rustls_tls()
53        .build()
54        .map_err(ColabError::Network)?;
55    let tokens = exchange_code(&http, config, &code, &redirect_uri, &code_verifier).await?;
56
57    let expires_at = Utc::now() + Duration::seconds(tokens.expires_in.unwrap_or(3600));
58
59    TokenStorage::store_access_token(&tokens.access_token, expires_at)?;
60
61    let refresh_token = tokens
62        .refresh_token
63        .ok_or_else(|| ColabError::oauth("no refresh token in response"))?;
64    TokenStorage::store_refresh_token(&refresh_token)?;
65
66    let account = fetch_user_info(&http, &tokens.access_token).await?;
67    TokenStorage::store_account(&account)?;
68
69    Ok(account)
70}
71
72pub async fn refresh_access_token(config: &ColabConfig) -> Result<String> {
73    let refresh_token = TokenStorage::get_refresh_token()?.ok_or(ColabError::NotAuthenticated)?;
74
75    let http = Client::builder()
76        .use_rustls_tls()
77        .build()
78        .map_err(ColabError::Network)?;
79
80    let params = [
81        ("client_id", config.client_id.as_str()),
82        ("client_secret", config.client_secret.as_str()),
83        ("refresh_token", refresh_token.as_str()),
84        ("grant_type", "refresh_token"),
85    ];
86
87    let resp = http
88        .post("https://oauth2.googleapis.com/token")
89        .form(&params)
90        .send()
91        .await?;
92
93    if !resp.status().is_success() {
94        let status = resp.status().as_u16();
95        let body = resp.text().await.ok();
96        return Err(ColabError::TokenRefreshFailed {
97            reason: body.unwrap_or_else(|| format!("HTTP {status}")),
98        });
99    }
100
101    let tokens: TokenResponse = resp.json().await?;
102    let expires_at = Utc::now() + Duration::seconds(tokens.expires_in.unwrap_or(3600));
103    TokenStorage::store_access_token(&tokens.access_token, expires_at)?;
104
105    Ok(tokens.access_token)
106}
107
108async fn wait_for_auth_code(config: &ColabConfig) -> Result<(String, String, String)> {
109    let listener = TcpListener::bind("127.0.0.1:0").await?;
110    let port = listener.local_addr()?.port();
111    let redirect_uri = format!("http://127.0.0.1:{port}");
112
113    let nonce: String = {
114        use std::fmt::Write;
115        let bytes: [u8; 16] = rand_bytes();
116        let mut s = String::with_capacity(32);
117        for b in &bytes {
118            let _ = write!(s, "{b:02x}");
119        }
120        s
121    };
122
123    let (code_verifier, code_challenge) = pkce_pair();
124
125    let auth_url = build_auth_url(config, &redirect_uri, &nonce, &code_challenge);
126
127    if let Err(e) = open_browser(&auth_url) {
128        eprintln!("Could not open browser automatically: {e}");
129        eprintln!("Open this URL manually:\n  {auth_url}");
130    }
131
132    let code = accept_one_redirect(&listener, &nonce).await?;
133
134    Ok((code, redirect_uri, code_verifier))
135}
136
137async fn accept_one_redirect(listener: &TcpListener, expected_nonce: &str) -> Result<String> {
138    let (stream, _) = listener.accept().await?;
139    let (reader, mut writer) = stream.into_split();
140    let mut reader = BufReader::new(reader);
141
142    let mut request_line = String::new();
143    reader.read_line(&mut request_line).await?;
144
145    let path = request_line
146        .split_whitespace()
147        .nth(1)
148        .ok_or_else(|| ColabError::oauth("malformed HTTP request from browser"))?;
149
150    let url = url::Url::parse(&format!("http://localhost{path}"))
151        .map_err(|e| ColabError::oauth(format!("invalid redirect URL: {e}")))?;
152
153    let state = url
154        .query_pairs()
155        .find(|(k, _)| k == "state")
156        .map(|(_, v)| v.into_owned())
157        .ok_or_else(|| ColabError::oauth("missing state in redirect"))?;
158
159    let received_nonce = state
160        .strip_prefix("nonce=")
161        .ok_or_else(|| ColabError::oauth("invalid state format in redirect"))?;
162
163    if received_nonce != expected_nonce {
164        return Err(ColabError::oauth("nonce mismatch — possible CSRF"));
165    }
166
167    let code = url
168        .query_pairs()
169        .find(|(k, _)| k == "code")
170        .map(|(_, v)| v.into_owned())
171        .ok_or_else(|| ColabError::oauth("missing authorization code in redirect"))?;
172
173    let body = REDIRECT_SUCCESS_HTML.as_bytes();
174    let response = format!(
175        "HTTP/1.1 200 OK\r\nContent-Type: text/html\r\nContent-Length: {}\r\nConnection: close\r\n\r\n",
176        body.len()
177    );
178    writer.write_all(response.as_bytes()).await?;
179    writer.write_all(body).await?;
180    writer.flush().await?;
181
182    Ok(code)
183}
184
185async fn exchange_code(
186    http: &Client,
187    config: &ColabConfig,
188    code: &str,
189    redirect_uri: &str,
190    code_verifier: &str,
191) -> Result<TokenResponse> {
192    let params = [
193        ("client_id", config.client_id.as_str()),
194        ("client_secret", config.client_secret.as_str()),
195        ("code", code),
196        ("code_verifier", code_verifier),
197        ("redirect_uri", redirect_uri),
198        ("grant_type", "authorization_code"),
199    ];
200
201    let resp = http
202        .post("https://oauth2.googleapis.com/token")
203        .form(&params)
204        .send()
205        .await?;
206
207    if !resp.status().is_success() {
208        let status = resp.status().as_u16();
209        let body = resp.text().await.ok();
210        return Err(ColabError::oauth(format!(
211            "token exchange failed (HTTP {status}): {}",
212            body.as_deref().unwrap_or("no body")
213        )));
214    }
215
216    Ok(resp.json().await?)
217}
218
219async fn fetch_user_info(http: &Client, access_token: &str) -> Result<AccountInfo> {
220    let resp = http
221        .get("https://www.googleapis.com/oauth2/v2/userinfo")
222        .bearer_auth(access_token)
223        .send()
224        .await?;
225
226    if !resp.status().is_success() {
227        return Err(ColabError::oauth("failed to fetch user info"));
228    }
229
230    let info: UserInfoResponse = resp.json().await?;
231    Ok(AccountInfo {
232        email: info.email,
233        name: info.name,
234    })
235}
236
237fn pkce_pair() -> (String, String) {
238    let verifier_bytes = rand_bytes_n(64);
239    let verifier = URL_SAFE_NO_PAD.encode(&verifier_bytes);
240
241    use sha2::{Digest, Sha256};
242    let mut hasher = Sha256::new();
243    hasher.update(verifier.as_bytes());
244    let challenge = URL_SAFE_NO_PAD.encode(hasher.finalize());
245
246    (verifier, challenge)
247}
248
249fn rand_bytes() -> [u8; 16] {
250    let mut b = [0u8; 16];
251    getrandom::getrandom(&mut b).expect("getrandom failed");
252    b
253}
254
255fn rand_bytes_n(n: usize) -> Vec<u8> {
256    let mut b = vec![0u8; n];
257    getrandom::getrandom(&mut b).expect("getrandom failed");
258    b
259}
260
261fn build_auth_url(
262    config: &ColabConfig,
263    redirect_uri: &str,
264    nonce: &str,
265    code_challenge: &str,
266) -> String {
267    let scopes = REQUIRED_SCOPES.join(" ");
268    let encoded_redirect = urlencoding::encode(redirect_uri);
269    let encoded_scopes = urlencoding::encode(&scopes);
270    let encoded_challenge = urlencoding::encode(code_challenge);
271    let state = format!("nonce={nonce}");
272    let encoded_state = urlencoding::encode(&state);
273
274    format!(
275        "https://accounts.google.com/o/oauth2/v2/auth\
276?client_id={}\
277&redirect_uri={encoded_redirect}\
278&response_type=code\
279&scope={encoded_scopes}\
280&state={encoded_state}\
281&code_challenge={encoded_challenge}\
282&code_challenge_method=S256\
283&access_type=offline\
284&prompt=consent",
285        config.client_id
286    )
287}
288
/// Asks the platform's URL handler to open `url` in a browser.
///
/// `Ok(())` only means the launcher process was spawned. It does not prove
/// that a browser window appeared.
///
/// # Errors
///
/// Returns a description of the failure if the launcher could not be spawned,
/// or if the current platform has no known launcher.
fn open_browser(url: &str) -> std::result::Result<(), String> {
    browser_launcher(url)?
        .spawn()
        .map(|_child| ())
        .map_err(|e| e.to_string())
}

/// Launcher command for macOS.
#[cfg(target_os = "macos")]
fn browser_launcher(url: &str) -> std::result::Result<std::process::Command, String> {
    let mut cmd = std::process::Command::new("open");
    cmd.arg(url);
    Ok(cmd)
}

/// Launcher command for Linux and the BSDs (freedesktop `xdg-open`).
#[cfg(any(
    target_os = "linux",
    target_os = "freebsd",
    target_os = "dragonfly",
    target_os = "netbsd",
    target_os = "openbsd"
))]
fn browser_launcher(url: &str) -> std::result::Result<std::process::Command, String> {
    let mut cmd = std::process::Command::new("xdg-open");
    cmd.arg(url);
    Ok(cmd)
}

/// Launcher command for Windows.
///
/// `cmd /C start <url>` would treat the `&` separators in the query string
/// as command separators and truncate the URL; `rundll32` passes it through
/// verbatim to the default URL handler.
#[cfg(target_os = "windows")]
fn browser_launcher(url: &str) -> std::result::Result<std::process::Command, String> {
    let mut cmd = std::process::Command::new("rundll32");
    cmd.args(["url.dll,FileProtocolHandler", url]);
    Ok(cmd)
}

/// Fallback for platforms without a known launcher.
///
/// Reports an error so the caller shows the URL instead of silently doing
/// nothing.
#[cfg(not(any(
    target_os = "macos",
    target_os = "windows",
    target_os = "linux",
    target_os = "freebsd",
    target_os = "dragonfly",
    target_os = "netbsd",
    target_os = "openbsd"
)))]
fn browser_launcher(_url: &str) -> std::result::Result<std::process::Command, String> {
    Err(format!(
        "no known browser launcher for {}",
        std::env::consts::OS
    ))
}