Skip to main content

cargo_faasta/
github_oauth.rs

1use anyhow::{Result, anyhow};
2use cyper::Client as HttpClient;
3use oauth2::http as oauth_http;
4use oauth2::{
5    AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, HttpRequest, HttpResponse,
6    RedirectUrl, Scope, TokenResponse, TokenUrl, basic::BasicClient,
7};
8use serde::Deserialize;
9use std::{net::SocketAddr, str::FromStr};
10use tiny_http::{Response, Server};
11use url::Url;
12
// GitHub OAuth app details.
//
// The ID and secret defaults can be overridden at runtime through the
// `FAASTA_GITHUB_CLIENT_ID` / `FAASTA_GITHUB_CLIENT_SECRET` environment variables
// (read by `get_client_id` / `get_client_secret`).
/// Built-in OAuth client ID, used when `FAASTA_GITHUB_CLIENT_ID` is not set.
const DEFAULT_CLIENT_ID: &str = "Iv23lik79igmHPi63dO1";
/// Built-in OAuth client secret, used when `FAASTA_GITHUB_CLIENT_SECRET` is not set.
// NOTE(review): this secret is a plain literal in source, so it is readable by anyone with
// the source or the compiled binary. Confirm it is acceptable to treat it as non-confidential.
const DEFAULT_CLIENT_SECRET: &str = "2a10cd3c2465622a1649b766e574f15eb9211eb7";
/// Local port used both in the OAuth redirect URL (`http://localhost:<port>/oauth/callback`)
/// and for the callback listener bound on `127.0.0.1`.
const REDIRECT_PORT: u16 = 9876;
17
/// Concrete `oauth2` client type returned by `get_oauth_client`.
///
/// The five type parameters are `oauth2` typestate markers recording which endpoints have
/// been configured (order per the `oauth2` 5.x `BasicClient` definition; verify if the
/// crate is upgraded). Only the authorization and token endpoints are set.
type GithubOAuthClient = BasicClient<
    oauth2::EndpointSet, // authorization URL (`set_auth_uri`)
    oauth2::EndpointNotSet, // device authorization URL
    oauth2::EndpointNotSet, // introspection URL
    oauth2::EndpointNotSet, // revocation URL
    oauth2::EndpointSet, // token URL (`set_token_uri`)
>;
25
26use std::sync::Mutex;
27use std::sync::atomic::{AtomicBool, Ordering};
28
// Test-mode state, read by `github_oauth_flow` through `get_test_data`.
//
// Within this file these are only written by `enable_test_mode`, which is compiled only
// under `cfg(test)`. In other builds they keep their initial values, so the flag stays
// `false` and the real OAuth flow runs.
/// Whether `github_oauth_flow` should short-circuit with the stored test credentials.
static TEST_MODE: AtomicBool = AtomicBool::new(false);
/// Username returned by `github_oauth_flow` in test mode.
static TEST_USERNAME: Mutex<Option<String>> = Mutex::new(None);
/// Raw token returned (prefixed with `Bearer `) by `github_oauth_flow` in test mode.
static TEST_TOKEN: Mutex<Option<String>> = Mutex::new(None);
33
#[cfg(test)]
/// Enable test mode with given username and token.
///
/// Afterwards `github_oauth_flow` returns `(username, "Bearer <token>")` instead of
/// running the browser-based OAuth flow.
pub fn enable_test_mode(username: String, token: String) {
    // Store the credentials *before* raising the flag, and raise it with `Release`, so a
    // reader that observes `TEST_MODE == true` with an `Acquire` load also observes the
    // credentials. Raising the flag first let a concurrent reader see the flag with empty
    // credentials and silently fall through to the real OAuth flow.
    *TEST_USERNAME.lock().unwrap() = Some(username);
    *TEST_TOKEN.lock().unwrap() = Some(token);
    TEST_MODE.store(true, Ordering::Release);
}
41
42/// Get the test mode status and credentials
43fn get_test_data() -> (bool, Option<String>, Option<String>) {
44    (
45        TEST_MODE.load(Ordering::Relaxed),
46        TEST_USERNAME.lock().unwrap().clone(),
47        TEST_TOKEN.lock().unwrap().clone(),
48    )
49}
50
51/// Get client ID from environment or use default
52fn get_client_id() -> String {
53    std::env::var("FAASTA_GITHUB_CLIENT_ID").unwrap_or_else(|_| DEFAULT_CLIENT_ID.to_string())
54}
55
56/// Get client secret from environment or use default
57fn get_client_secret() -> String {
58    std::env::var("FAASTA_GITHUB_CLIENT_SECRET")
59        .unwrap_or_else(|_| DEFAULT_CLIENT_SECRET.to_string())
60}
61
/// The part of GitHub's `GET https://api.github.com/user` JSON response this module uses.
///
/// Only `login` is deserialized. serde ignores the response's other fields because
/// unknown fields are not denied here.
#[derive(Debug, Deserialize)]
struct GitHubUser {
    /// The account's GitHub username; returned by `get_github_username`.
    login: String,
}
67
68/// Performs the GitHub OAuth flow and returns the username and token
69pub async fn github_oauth_flow() -> Result<(String, String)> {
70    // Check if we're in test mode
71    let (is_test_mode, test_username, test_token) = get_test_data();
72    if is_test_mode {
73        if let (Some(username), Some(token)) = (test_username, test_token) {
74            println!("Using test credentials");
75            return Ok((username, format!("Bearer {token}")));
76        }
77    }
78
79    // Set up the OAuth2 client
80    let github_client = get_oauth_client()?;
81
82    // Generate the authorization URL
83    let (authorize_url, csrf_state) = github_client
84        .authorize_url(CsrfToken::new_random)
85        .add_scope(Scope::new("user:email".to_string()))
86        .url();
87
88    // Start the redirect server
89    let server = start_redirect_server()?;
90
91    // Open the browser to authenticate the user
92    println!("Opening browser for GitHub authentication...");
93    println!("Authorization URL: {authorize_url}");
94    if let Err(e) = open::that(authorize_url.to_string()) {
95        return Err(anyhow!("Failed to open browser: {}", e));
96    }
97
98    // Wait for the callback from GitHub
99    println!("Waiting for GitHub authentication...");
100    let auth_code = wait_for_callback(server, &csrf_state)?;
101
102    // Exchange the authorization code for a token
103    println!("Exchanging authorization code for token...");
104    let token = match github_client
105        .exchange_code(AuthorizationCode::new(auth_code))
106        .request_async(&cyper_async_http_client)
107        .await
108    {
109        Ok(token) => token,
110        Err(e) => {
111            println!("Error exchanging code for token: {e:?}");
112            return Err(anyhow!(
113                "Failed to exchange authorization code for token: {}",
114                e
115            ));
116        }
117    };
118
119    // Get the access token as a string
120    let access_token = token.access_token().secret();
121
122    // Get the user's GitHub info using the token
123    println!("Getting GitHub user information...");
124    let username = get_github_username(access_token).await?;
125
126    Ok((username, format!("Bearer {access_token}")))
127}
128
129/// Create an OAuth client for GitHub
130fn get_oauth_client() -> Result<GithubOAuthClient> {
131    let redirect_url = format!("http://localhost:{REDIRECT_PORT}/oauth/callback");
132    println!("Redirect URL: {redirect_url}");
133
134    Ok(BasicClient::new(ClientId::new(get_client_id()))
135        .set_client_secret(ClientSecret::new(get_client_secret()))
136        .set_auth_uri(AuthUrl::new(
137            "https://github.com/login/oauth/authorize".to_string(),
138        )?)
139        .set_token_uri(TokenUrl::new(
140            "https://github.com/login/oauth/access_token".to_string(),
141        )?)
142        .set_redirect_uri(RedirectUrl::new(redirect_url)?))
143}
144
145/// Starts a local HTTP server to receive the OAuth redirect
146fn start_redirect_server() -> Result<Server> {
147    let addr = SocketAddr::from_str(&format!("127.0.0.1:{REDIRECT_PORT}"))?;
148    let server = Server::http(addr).map_err(|e| anyhow!("Failed to start server: {}", e))?;
149    Ok(server)
150}
151
152/// Waits for and processes the OAuth callback
153fn wait_for_callback(server: Server, csrf_state: &CsrfToken) -> Result<String> {
154    // Wait for the callback from GitHub
155    let req = server.recv()?;
156
157    // Parse the request URL to extract the code and state
158    let url_str = format!("http://localhost{}", req.url());
159    let url = Url::parse(&url_str)?;
160
161    // Extract query parameters
162    let mut code = None;
163    let mut state = None;
164
165    for (key, value) in url.query_pairs() {
166        if key == "code" {
167            code = Some(value.to_string());
168        } else if key == "state" {
169            state = Some(value.to_string());
170        }
171    }
172
173    // Verify the state to prevent CSRF attacks
174    if state.as_deref() != Some(csrf_state.secret()) {
175        // Send an error response to the browser
176        let error_html = "<html><body><h1>Authentication Error</h1><p>Invalid state parameter. This could be a CSRF attack.</p></body></html>";
177        req.respond(Response::from_string(error_html))?;
178
179        return Err(anyhow!("Invalid OAuth state"));
180    }
181
182    // Check for the code and respond appropriately
183    match code {
184        Some(code_value) => {
185            // Send a success response to the browser
186            let success_html = "<h1>Authentication Successful!</h1><p>You can now close this window and return to the terminal.</p>";
187            req.respond(Response::from_string(success_html))?;
188
189            Ok(code_value)
190        }
191        None => {
192            // Send an error response for missing code
193            let error_html =
194                "<h1>Authentication Error</h1><p>No authorization code received from GitHub.</p>";
195            req.respond(Response::from_string(error_html))?;
196
197            Err(anyhow!("No authorization code received"))
198        }
199    }
200}
201
202/// Gets the GitHub username from the user's profile
203async fn get_github_username(token: &str) -> Result<String> {
204    let user: GitHubUser = HttpClient::new()
205        .get("https://api.github.com/user")?
206        .header("User-Agent", "faasta-cli")?
207        .header("Authorization", format!("Bearer {token}"))?
208        .send()
209        .await?
210        .json()
211        .await?;
212
213    Ok(user.login)
214}
215
216async fn cyper_async_http_client(
217    request: HttpRequest,
218) -> std::result::Result<HttpResponse, cyper::Error> {
219    let method = request.method().clone();
220
221    let mut outbound_headers = http::HeaderMap::new();
222    for (name, value) in request.headers().iter() {
223        outbound_headers.append(name.clone(), value.clone());
224    }
225
226    let response = HttpClient::new()
227        .request(method, request.uri().to_string())?
228        .headers(outbound_headers)
229        .body(request.body().clone())
230        .send()
231        .await?;
232
233    let mut inbound_headers = oauth_http::HeaderMap::new();
234    for (name, value) in response.headers().iter() {
235        inbound_headers.append(name.clone(), value.clone());
236    }
237
238    let status_code = oauth_http::StatusCode::from_u16(response.status().as_u16())
239        .expect("response status code should be valid");
240    let body = response.bytes().await?.to_vec();
241
242    let mut response_builder = oauth_http::Response::builder().status(status_code);
243    {
244        let headers = response_builder
245            .headers_mut()
246            .expect("builder should be valid");
247        *headers = inbound_headers;
248    }
249
250    Ok(response_builder.body(body)?)
251}
252
#[cfg(test)]
mod tests {
    use super::*;

    /// In test mode the flow must return the injected credentials, with the token formatted
    /// as a bearer authorization value, without running the browser/network flow.
    #[compio::test]
    async fn test_oauth_flow_with_test_mode() {
        enable_test_mode("test_user".to_string(), "test_token".to_string());

        // `expect` (rather than `assert!(result.is_ok())`) reports the underlying error if
        // the flow unexpectedly fails.
        let (username, token) = github_oauth_flow()
            .await
            .expect("test-mode OAuth flow should succeed");

        assert_eq!(username, "test_user");
        assert_eq!(token, "Bearer test_token");
    }
}