cargo_faasta/
github_oauth.rs1use anyhow::{Result, anyhow};
2use cyper::Client as HttpClient;
3use oauth2::http as oauth_http;
4use oauth2::{
5 AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, HttpRequest, HttpResponse,
6 RedirectUrl, Scope, TokenResponse, TokenUrl, basic::BasicClient,
7};
8use serde::Deserialize;
9use std::{net::SocketAddr, str::FromStr};
10use tiny_http::{Response, Server};
11use url::Url;
12
/// Client ID used when the `FAASTA_GITHUB_CLIENT_ID` environment variable does not provide one.
const DEFAULT_CLIENT_ID: &str = "Iv23lik79igmHPi63dO1";
/// Client secret used when the `FAASTA_GITHUB_CLIENT_SECRET` environment variable does not
/// provide one.
// NOTE(review): a secret compiled into a distributed binary can be extracted by anyone who has
// the binary. Confirm this is acceptable for this app, or move to a flow that needs no secret.
const DEFAULT_CLIENT_SECRET: &str = "2a10cd3c2465622a1649b766e574f15eb9211eb7";
/// Loopback port the temporary callback server listens on; it is also embedded in the OAuth
/// redirect URL.
const REDIRECT_PORT: u16 = 9876;
17
18type GithubOAuthClient = BasicClient<
19 oauth2::EndpointSet,
20 oauth2::EndpointNotSet,
21 oauth2::EndpointNotSet,
22 oauth2::EndpointNotSet,
23 oauth2::EndpointSet,
24>;
25
26use std::sync::Mutex;
27use std::sync::atomic::{AtomicBool, Ordering};
28
/// Set to `true` by `enable_test_mode`; when set (and both credentials below are present)
/// `github_oauth_flow` returns the canned credentials instead of contacting GitHub.
static TEST_MODE: AtomicBool = AtomicBool::new(false);
/// Username returned by the OAuth flow while test mode is active.
static TEST_USERNAME: Mutex<Option<String>> = Mutex::new(None);
/// Raw token (without the `Bearer ` prefix) returned by the OAuth flow while test mode is active.
static TEST_TOKEN: Mutex<Option<String>> = Mutex::new(None);
33
/// Switches the module into test mode and records the credentials that
/// `github_oauth_flow` should hand back instead of talking to GitHub.
///
/// Only compiled for test builds.
///
/// # Panics
///
/// Panics if either credential mutex is poisoned.
#[cfg(test)]
pub fn enable_test_mode(username: String, token: String) {
    TEST_MODE.store(true, Ordering::Relaxed);
    // Fill the username slot first, then the token slot; each lock is released before the
    // next one is taken.
    for (slot, value) in [(&TEST_USERNAME, username), (&TEST_TOKEN, token)] {
        *slot.lock().unwrap() = Some(value);
    }
}
41
42fn get_test_data() -> (bool, Option<String>, Option<String>) {
44 (
45 TEST_MODE.load(Ordering::Relaxed),
46 TEST_USERNAME.lock().unwrap().clone(),
47 TEST_TOKEN.lock().unwrap().clone(),
48 )
49}
50
51fn get_client_id() -> String {
53 std::env::var("FAASTA_GITHUB_CLIENT_ID").unwrap_or_else(|_| DEFAULT_CLIENT_ID.to_string())
54}
55
56fn get_client_secret() -> String {
58 std::env::var("FAASTA_GITHUB_CLIENT_SECRET")
59 .unwrap_or_else(|_| DEFAULT_CLIENT_SECRET.to_string())
60}
61
62#[derive(Debug, Deserialize)]
64struct GitHubUser {
65 login: String,
66}
67
68pub async fn github_oauth_flow() -> Result<(String, String)> {
70 let (is_test_mode, test_username, test_token) = get_test_data();
72 if is_test_mode {
73 if let (Some(username), Some(token)) = (test_username, test_token) {
74 println!("Using test credentials");
75 return Ok((username, format!("Bearer {token}")));
76 }
77 }
78
79 let github_client = get_oauth_client()?;
81
82 let (authorize_url, csrf_state) = github_client
84 .authorize_url(CsrfToken::new_random)
85 .add_scope(Scope::new("user:email".to_string()))
86 .url();
87
88 let server = start_redirect_server()?;
90
91 println!("Opening browser for GitHub authentication...");
93 println!("Authorization URL: {authorize_url}");
94 if let Err(e) = open::that(authorize_url.to_string()) {
95 return Err(anyhow!("Failed to open browser: {}", e));
96 }
97
98 println!("Waiting for GitHub authentication...");
100 let auth_code = wait_for_callback(server, &csrf_state)?;
101
102 println!("Exchanging authorization code for token...");
104 let token = match github_client
105 .exchange_code(AuthorizationCode::new(auth_code))
106 .request_async(&cyper_async_http_client)
107 .await
108 {
109 Ok(token) => token,
110 Err(e) => {
111 println!("Error exchanging code for token: {e:?}");
112 return Err(anyhow!(
113 "Failed to exchange authorization code for token: {}",
114 e
115 ));
116 }
117 };
118
119 let access_token = token.access_token().secret();
121
122 println!("Getting GitHub user information...");
124 let username = get_github_username(access_token).await?;
125
126 Ok((username, format!("Bearer {access_token}")))
127}
128
129fn get_oauth_client() -> Result<GithubOAuthClient> {
131 let redirect_url = format!("http://localhost:{REDIRECT_PORT}/oauth/callback");
132 println!("Redirect URL: {redirect_url}");
133
134 Ok(BasicClient::new(ClientId::new(get_client_id()))
135 .set_client_secret(ClientSecret::new(get_client_secret()))
136 .set_auth_uri(AuthUrl::new(
137 "https://github.com/login/oauth/authorize".to_string(),
138 )?)
139 .set_token_uri(TokenUrl::new(
140 "https://github.com/login/oauth/access_token".to_string(),
141 )?)
142 .set_redirect_uri(RedirectUrl::new(redirect_url)?))
143}
144
145fn start_redirect_server() -> Result<Server> {
147 let addr = SocketAddr::from_str(&format!("127.0.0.1:{REDIRECT_PORT}"))?;
148 let server = Server::http(addr).map_err(|e| anyhow!("Failed to start server: {}", e))?;
149 Ok(server)
150}
151
152fn wait_for_callback(server: Server, csrf_state: &CsrfToken) -> Result<String> {
154 let req = server.recv()?;
156
157 let url_str = format!("http://localhost{}", req.url());
159 let url = Url::parse(&url_str)?;
160
161 let mut code = None;
163 let mut state = None;
164
165 for (key, value) in url.query_pairs() {
166 if key == "code" {
167 code = Some(value.to_string());
168 } else if key == "state" {
169 state = Some(value.to_string());
170 }
171 }
172
173 if state.as_deref() != Some(csrf_state.secret()) {
175 let error_html = "<html><body><h1>Authentication Error</h1><p>Invalid state parameter. This could be a CSRF attack.</p></body></html>";
177 req.respond(Response::from_string(error_html))?;
178
179 return Err(anyhow!("Invalid OAuth state"));
180 }
181
182 match code {
184 Some(code_value) => {
185 let success_html = "<h1>Authentication Successful!</h1><p>You can now close this window and return to the terminal.</p>";
187 req.respond(Response::from_string(success_html))?;
188
189 Ok(code_value)
190 }
191 None => {
192 let error_html =
194 "<h1>Authentication Error</h1><p>No authorization code received from GitHub.</p>";
195 req.respond(Response::from_string(error_html))?;
196
197 Err(anyhow!("No authorization code received"))
198 }
199 }
200}
201
202async fn get_github_username(token: &str) -> Result<String> {
204 let user: GitHubUser = HttpClient::new()
205 .get("https://api.github.com/user")?
206 .header("User-Agent", "faasta-cli")?
207 .header("Authorization", format!("Bearer {token}"))?
208 .send()
209 .await?
210 .json()
211 .await?;
212
213 Ok(user.login)
214}
215
216async fn cyper_async_http_client(
217 request: HttpRequest,
218) -> std::result::Result<HttpResponse, cyper::Error> {
219 let method = request.method().clone();
220
221 let mut outbound_headers = http::HeaderMap::new();
222 for (name, value) in request.headers().iter() {
223 outbound_headers.append(name.clone(), value.clone());
224 }
225
226 let response = HttpClient::new()
227 .request(method, request.uri().to_string())?
228 .headers(outbound_headers)
229 .body(request.body().clone())
230 .send()
231 .await?;
232
233 let mut inbound_headers = oauth_http::HeaderMap::new();
234 for (name, value) in response.headers().iter() {
235 inbound_headers.append(name.clone(), value.clone());
236 }
237
238 let status_code = oauth_http::StatusCode::from_u16(response.status().as_u16())
239 .expect("response status code should be valid");
240 let body = response.bytes().await?.to_vec();
241
242 let mut response_builder = oauth_http::Response::builder().status(status_code);
243 {
244 let headers = response_builder
245 .headers_mut()
246 .expect("builder should be valid");
247 *headers = inbound_headers;
248 }
249
250 Ok(response_builder.body(body)?)
251}
252
#[cfg(test)]
mod tests {
    use super::*;

    /// With test mode enabled, the flow must hand back the canned credentials (token
    /// prefixed with `Bearer `) without contacting GitHub.
    #[compio::test]
    async fn test_oauth_flow_with_test_mode() {
        enable_test_mode(String::from("test_user"), String::from("test_token"));

        let outcome = github_oauth_flow().await;
        assert!(outcome.is_ok());

        let (username, token) = outcome.unwrap();
        assert_eq!(username, "test_user");
        assert_eq!(token, "Bearer test_token");
    }
}