1use base64::Engine;
2use base64::engine::general_purpose::URL_SAFE_NO_PAD;
3use chrono::{Duration, Utc};
4use reqwest::Client;
5use serde::Deserialize;
6use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
7use tokio::net::TcpListener;
8
9use crate::auth::storage::{AccountInfo, TokenStorage};
10use crate::config::ColabConfig;
11use crate::error::{ColabError, Result};
12
/// Upper bound, in seconds, on how long `run_login_flow` waits for the browser step
/// (bind, open browser, receive redirect) to finish.
/// NOTE(review): the timeout error message in `run_login_flow` hard-codes "2 min"; keep the
/// two in sync if this value changes.
const FLOW_TIMEOUT_SECS: u64 = 120;
/// HTML page written back to the browser once the authorization redirect has been accepted.
const REDIRECT_SUCCESS_HTML: &str = r#"<!DOCTYPE html>
<html><head><title>Signed in</title></head>
<body style="font-family:sans-serif;text-align:center;padding:4em">
<h1>Signed in to colab-cli</h1>
<p>You can close this tab.</p>
</body></html>"#;

/// OAuth scopes requested in the authorization URL (space-joined into the `scope` parameter).
pub const REQUIRED_SCOPES: &[&str] = &[
    "profile",
    "email",
    "https://www.googleapis.com/auth/colaboratory",
];
26
/// Successful JSON body from the Google token endpoint, decoded for both the
/// authorization-code exchange and the refresh-token grant.
#[derive(Debug, Deserialize)]
struct TokenResponse {
    /// Token sent as a bearer credential on API calls.
    access_token: String,
    /// Optional in the response: `run_login_flow` treats its absence as an error, while
    /// `refresh_access_token` does not read it.
    refresh_token: Option<String>,
    /// Lifetime of `access_token` in seconds; callers fall back to 3600 when it is absent.
    expires_in: Option<i64>,
    // Required during deserialization (a response without it fails to decode) but never read
    // afterwards, hence the lint suppression.
    #[allow(dead_code)]
    token_type: String,
}
35
/// Fields read from the `oauth2/v2/userinfo` response and copied into `AccountInfo`.
/// NOTE(review): both fields are required, so a response that omits `name` fails to
/// deserialize — confirm the endpoint always returns it for the requested scopes.
#[derive(Debug, Deserialize)]
struct UserInfoResponse {
    /// Display name of the signed-in account.
    name: String,
    /// Email address of the signed-in account.
    email: String,
}
41
42pub async fn run_login_flow(config: &ColabConfig) -> Result<AccountInfo> {
43 let (code, redirect_uri, code_verifier) = tokio::time::timeout(
44 std::time::Duration::from_secs(FLOW_TIMEOUT_SECS),
45 wait_for_auth_code(config),
46 )
47 .await
48 .map_err(|_| ColabError::oauth("authentication timed out (2 min)"))?
49 .map_err(|e| ColabError::AuthFailed(e.to_string()))?;
50
51 let http = Client::builder()
52 .use_rustls_tls()
53 .build()
54 .map_err(ColabError::Network)?;
55 let tokens = exchange_code(&http, config, &code, &redirect_uri, &code_verifier).await?;
56
57 let expires_at = Utc::now() + Duration::seconds(tokens.expires_in.unwrap_or(3600));
58
59 TokenStorage::store_access_token(&tokens.access_token, expires_at)?;
60
61 let refresh_token = tokens
62 .refresh_token
63 .ok_or_else(|| ColabError::oauth("no refresh token in response"))?;
64 TokenStorage::store_refresh_token(&refresh_token)?;
65
66 let account = fetch_user_info(&http, &tokens.access_token).await?;
67 TokenStorage::store_account(&account)?;
68
69 Ok(account)
70}
71
72pub async fn refresh_access_token(config: &ColabConfig) -> Result<String> {
73 let refresh_token = TokenStorage::get_refresh_token()?.ok_or(ColabError::NotAuthenticated)?;
74
75 let http = Client::builder()
76 .use_rustls_tls()
77 .build()
78 .map_err(ColabError::Network)?;
79
80 let params = [
81 ("client_id", config.client_id.as_str()),
82 ("client_secret", config.client_secret.as_str()),
83 ("refresh_token", refresh_token.as_str()),
84 ("grant_type", "refresh_token"),
85 ];
86
87 let resp = http
88 .post("https://oauth2.googleapis.com/token")
89 .form(¶ms)
90 .send()
91 .await?;
92
93 if !resp.status().is_success() {
94 let status = resp.status().as_u16();
95 let body = resp.text().await.ok();
96 return Err(ColabError::TokenRefreshFailed {
97 reason: body.unwrap_or_else(|| format!("HTTP {status}")),
98 });
99 }
100
101 let tokens: TokenResponse = resp.json().await?;
102 let expires_at = Utc::now() + Duration::seconds(tokens.expires_in.unwrap_or(3600));
103 TokenStorage::store_access_token(&tokens.access_token, expires_at)?;
104
105 Ok(tokens.access_token)
106}
107
108async fn wait_for_auth_code(config: &ColabConfig) -> Result<(String, String, String)> {
109 let listener = TcpListener::bind("127.0.0.1:0").await?;
110 let port = listener.local_addr()?.port();
111 let redirect_uri = format!("http://127.0.0.1:{port}");
112
113 let nonce: String = {
114 use std::fmt::Write;
115 let bytes: [u8; 16] = rand_bytes();
116 let mut s = String::with_capacity(32);
117 for b in &bytes {
118 let _ = write!(s, "{b:02x}");
119 }
120 s
121 };
122
123 let (code_verifier, code_challenge) = pkce_pair();
124
125 let auth_url = build_auth_url(config, &redirect_uri, &nonce, &code_challenge);
126
127 if let Err(e) = open_browser(&auth_url) {
128 eprintln!("Could not open browser automatically: {e}");
129 eprintln!("Open this URL manually:\n {auth_url}");
130 }
131
132 let code = accept_one_redirect(&listener, &nonce).await?;
133
134 Ok((code, redirect_uri, code_verifier))
135}
136
137async fn accept_one_redirect(listener: &TcpListener, expected_nonce: &str) -> Result<String> {
138 let (stream, _) = listener.accept().await?;
139 let (reader, mut writer) = stream.into_split();
140 let mut reader = BufReader::new(reader);
141
142 let mut request_line = String::new();
143 reader.read_line(&mut request_line).await?;
144
145 let path = request_line
146 .split_whitespace()
147 .nth(1)
148 .ok_or_else(|| ColabError::oauth("malformed HTTP request from browser"))?;
149
150 let url = url::Url::parse(&format!("http://localhost{path}"))
151 .map_err(|e| ColabError::oauth(format!("invalid redirect URL: {e}")))?;
152
153 let state = url
154 .query_pairs()
155 .find(|(k, _)| k == "state")
156 .map(|(_, v)| v.into_owned())
157 .ok_or_else(|| ColabError::oauth("missing state in redirect"))?;
158
159 let received_nonce = state
160 .strip_prefix("nonce=")
161 .ok_or_else(|| ColabError::oauth("invalid state format in redirect"))?;
162
163 if received_nonce != expected_nonce {
164 return Err(ColabError::oauth("nonce mismatch — possible CSRF"));
165 }
166
167 let code = url
168 .query_pairs()
169 .find(|(k, _)| k == "code")
170 .map(|(_, v)| v.into_owned())
171 .ok_or_else(|| ColabError::oauth("missing authorization code in redirect"))?;
172
173 let body = REDIRECT_SUCCESS_HTML.as_bytes();
174 let response = format!(
175 "HTTP/1.1 200 OK\r\nContent-Type: text/html\r\nContent-Length: {}\r\nConnection: close\r\n\r\n",
176 body.len()
177 );
178 writer.write_all(response.as_bytes()).await?;
179 writer.write_all(body).await?;
180 writer.flush().await?;
181
182 Ok(code)
183}
184
185async fn exchange_code(
186 http: &Client,
187 config: &ColabConfig,
188 code: &str,
189 redirect_uri: &str,
190 code_verifier: &str,
191) -> Result<TokenResponse> {
192 let params = [
193 ("client_id", config.client_id.as_str()),
194 ("client_secret", config.client_secret.as_str()),
195 ("code", code),
196 ("code_verifier", code_verifier),
197 ("redirect_uri", redirect_uri),
198 ("grant_type", "authorization_code"),
199 ];
200
201 let resp = http
202 .post("https://oauth2.googleapis.com/token")
203 .form(¶ms)
204 .send()
205 .await?;
206
207 if !resp.status().is_success() {
208 let status = resp.status().as_u16();
209 let body = resp.text().await.ok();
210 return Err(ColabError::oauth(format!(
211 "token exchange failed (HTTP {status}): {}",
212 body.as_deref().unwrap_or("no body")
213 )));
214 }
215
216 Ok(resp.json().await?)
217}
218
219async fn fetch_user_info(http: &Client, access_token: &str) -> Result<AccountInfo> {
220 let resp = http
221 .get("https://www.googleapis.com/oauth2/v2/userinfo")
222 .bearer_auth(access_token)
223 .send()
224 .await?;
225
226 if !resp.status().is_success() {
227 return Err(ColabError::oauth("failed to fetch user info"));
228 }
229
230 let info: UserInfoResponse = resp.json().await?;
231 Ok(AccountInfo {
232 email: info.email,
233 name: info.name,
234 })
235}
236
237fn pkce_pair() -> (String, String) {
238 let verifier_bytes = rand_bytes_n(64);
239 let verifier = URL_SAFE_NO_PAD.encode(&verifier_bytes);
240
241 use sha2::{Digest, Sha256};
242 let mut hasher = Sha256::new();
243 hasher.update(verifier.as_bytes());
244 let challenge = URL_SAFE_NO_PAD.encode(hasher.finalize());
245
246 (verifier, challenge)
247}
248
249fn rand_bytes() -> [u8; 16] {
250 let mut b = [0u8; 16];
251 getrandom::getrandom(&mut b).expect("getrandom failed");
252 b
253}
254
255fn rand_bytes_n(n: usize) -> Vec<u8> {
256 let mut b = vec![0u8; n];
257 getrandom::getrandom(&mut b).expect("getrandom failed");
258 b
259}
260
261fn build_auth_url(
262 config: &ColabConfig,
263 redirect_uri: &str,
264 nonce: &str,
265 code_challenge: &str,
266) -> String {
267 let scopes = REQUIRED_SCOPES.join(" ");
268 let encoded_redirect = urlencoding::encode(redirect_uri);
269 let encoded_scopes = urlencoding::encode(&scopes);
270 let encoded_challenge = urlencoding::encode(code_challenge);
271 let state = format!("nonce={nonce}");
272 let encoded_state = urlencoding::encode(&state);
273
274 format!(
275 "https://accounts.google.com/o/oauth2/v2/auth\
276?client_id={}\
277&redirect_uri={encoded_redirect}\
278&response_type=code\
279&scope={encoded_scopes}\
280&state={encoded_state}\
281&code_challenge={encoded_challenge}\
282&code_challenge_method=S256\
283&access_type=offline\
284&prompt=consent",
285 config.client_id
286 )
287}
288
/// Asks the operating system to open `url` in the default browser.
///
/// The launcher is only spawned. Its child process is not waited on, so `Ok(())` means the
/// launcher started, not that a browser window actually appeared.
///
/// # Errors
/// Returns the spawn error text if the launcher cannot be started. On platforms with no known
/// launcher it returns an explanatory message instead of silently succeeding, so the caller
/// can fall back to showing the URL.
fn open_browser(url: &str) -> std::result::Result<(), String> {
    use std::process::Command;

    let mut command = if cfg!(target_os = "windows") {
        // Not `cmd /C start <url>`: cmd.exe treats every `&` in the query string as a command
        // separator, truncating the URL. rundll32 receives the URL verbatim.
        let mut launcher = Command::new("rundll32");
        launcher.arg("url.dll,FileProtocolHandler");
        launcher
    } else if cfg!(target_os = "macos") {
        Command::new("open")
    } else if cfg!(unix) {
        // Linux and the BSDs; if xdg-utils is missing, spawn fails and the caller is told.
        Command::new("xdg-open")
    } else {
        return Err(format!(
            "opening a browser is not supported on {}",
            std::env::consts::OS
        ));
    };

    command
        .arg(url)
        .spawn()
        .map(|_child| ())
        .map_err(|e| e.to_string())
}