Skip to main content

raps_kernel/auth/
three_leg.rs

// SPDX-License-Identifier: Apache-2.0
// Copyright 2024-2025 Dmytro Yemelianov

//! 3-legged OAuth (authorization code) flow
5
6use anyhow::{Context, Result};
7use std::time::Duration;
8use tiny_http::{Response, Server};
9
10use super::AuthClient;
11use super::types::TokenResponse;
12use crate::config::DEFAULT_CALLBACK_PORT;
13use crate::types::StoredToken;
14
15impl AuthClient {
16    /// Get a valid 3-legged access token (requires prior login)
17    ///
18    /// Uses Mutex-based coordination to ensure only one refresh occurs at a time.
19    /// Concurrent callers wait and receive the newly refreshed token.
20    pub async fn get_3leg_token(&self) -> Result<String> {
21        loop {
22            let refresh_token_to_use: Option<String>;
23            {
24                let cache = self.cached_3leg_token.lock().await;
25                if let Some(ref token) = cache.token {
26                    if token.is_valid() {
27                        return Ok(token.access_token.clone());
28                    }
29                    if cache.refreshing {
30                        // Another task is already refreshing; drop lock and wait
31                        drop(cache);
32                        tokio::time::sleep(Duration::from_millis(100)).await;
33                        continue;
34                    }
35                    refresh_token_to_use = token.refresh_token.clone();
36                } else {
37                    refresh_token_to_use = None;
38                }
39            }
40
41            // Try to refresh if we have a refresh token
42            if let Some(refresh) = refresh_token_to_use {
43                // Mark as refreshing
44                {
45                    let mut cache = self.cached_3leg_token.lock().await;
46                    cache.refreshing = true;
47                }
48                let result = self.refresh_token(refresh).await;
49                // Always reset the refreshing flag, even on error
50                if result.is_err() {
51                    let mut cache = self.cached_3leg_token.lock().await;
52                    cache.refreshing = false;
53                }
54                return result;
55            }
56
57            anyhow::bail!("Not logged in. Please run 'raps auth login' first.")
58        }
59    }
60
61    /// Check if user is logged in with 3-legged OAuth
62    pub async fn is_logged_in(&self) -> bool {
63        let cache = self.cached_3leg_token.lock().await;
64        if let Some(ref token) = cache.token {
65            if token.is_valid() {
66                return true;
67            }
68            // Check if we can refresh
69            if token.refresh_token.is_some() {
70                return true;
71            }
72        }
73        false
74    }
75
76    /// Start 3-legged OAuth login flow
77    pub async fn login(&self, scopes: &[&str]) -> Result<StoredToken> {
78        self.config.require_credentials()?;
79
80        let state = uuid::Uuid::new_v4().to_string();
81        let scope = scopes.join(" ");
82
83        // Parse port from callback URL or default to DEFAULT_CALLBACK_PORT
84        let preferred_port = match url::Url::parse(&self.config.callback_url) {
85            Ok(u) => u.port().unwrap_or(DEFAULT_CALLBACK_PORT),
86            Err(_) => DEFAULT_CALLBACK_PORT,
87        };
88
89        // Fallback ports (RAPS in leet speak + common alternatives)
90        let fallback_ports: Vec<u16> = vec![preferred_port, 12495, 7495, 9247, 3000, 5000];
91
92        // Try to bind to a port
93        let mut server = None;
94        let mut actual_port = preferred_port;
95
96        for &port in &fallback_ports {
97            match Server::http(format!("127.0.0.1:{}", port)) {
98                Ok(s) => {
99                    server = Some(s);
100                    actual_port = port;
101                    break;
102                }
103                Err(e) => {
104                    if crate::logging::debug() {
105                        println!("Port {} unavailable: {}", port, e);
106                    }
107                    continue;
108                }
109            }
110        }
111
112        let server = server.ok_or_else(|| {
113            anyhow::anyhow!(
114                "Failed to start callback server. Tried ports: {:?}.",
115                fallback_ports
116            )
117        })?;
118
119        tracing::info!(port = actual_port, "Callback server started");
120        if actual_port != preferred_port {
121            tracing::info!(
122                fallback_port = actual_port,
123                preferred_port,
124                "Using fallback port"
125            );
126        }
127
128        // Build callback URL with the actual port we bound to
129        let actual_callback_url = format!("http://localhost:{}/callback", actual_port);
130
131        // Build authorization URL
132        let auth_url = format!(
133            "{}?response_type=code&client_id={}&redirect_uri={}&scope={}&state={}",
134            self.config.authorize_url(),
135            urlencoding::encode(&self.config.client_id),
136            urlencoding::encode(&actual_callback_url),
137            urlencoding::encode(&scope),
138            urlencoding::encode(&state)
139        );
140
141        eprintln!("Opening browser for authentication...");
142        eprintln!("If the browser doesn't open, visit this URL:");
143        eprintln!("{}", auth_url);
144
145        // Open browser
146        if webbrowser::open(&auth_url).is_err() {
147            eprintln!("Failed to open browser automatically.");
148        }
149
150        eprintln!("\nWaiting for authentication callback...");
151
152        // Wait for callback
153        #[allow(unused_assignments)]
154        let mut auth_code: Option<String> = None;
155
156        let server = std::sync::Arc::new(server);
157        loop {
158            let server_clone = server.clone();
159            let request = tokio::task::spawn_blocking(move || server_clone.recv())
160                .await
161                .context("Callback server thread panicked")?
162                .map_err(|e| anyhow::anyhow!("Failed to receive callback: {}", e))?;
163
164            let url = request.url().to_string();
165            tracing::debug!("Received callback request");
166
167            // Skip non-callback requests (like favicon)
168            if !url.starts_with("/callback") && !url.contains("code=") {
169                let response = Response::from_string("Not found").with_status_code(404);
170                request.respond(response).ok();
171                continue;
172            }
173
174            // Parse the callback URL for code and state
175            let parsed = url::Url::parse(&format!("http://localhost{}", url))?;
176            let params: std::collections::HashMap<_, _> = parsed.query_pairs().collect();
177
178            // Check for error
179            if let Some(error) = params.get("error") {
180                let desc = params
181                    .get("error_description")
182                    .map(|s| s.to_string())
183                    .unwrap_or_default();
184                let response = Response::from_string(format!(
185                    "<html><body><h1>Login Failed</h1><p>{}: {}</p></body></html>",
186                    error, desc
187                ))
188                .with_header(
189                    tiny_http::Header::from_bytes(&b"Content-Type"[..], &b"text/html"[..])
190                        .expect("Content-Type: text/html is a valid header"),
191                );
192                request.respond(response).ok();
193                anyhow::bail!("Authorization error: {error} - {desc}");
194            }
195
196            // Check state
197            let returned_state = params
198                .get("state")
199                .ok_or_else(|| anyhow::anyhow!("Missing state parameter"))?;
200            if returned_state != &state {
201                let response = Response::from_string("State mismatch").with_status_code(400);
202                request.respond(response).ok();
203                anyhow::bail!("State mismatch - possible CSRF attack");
204            }
205
206            // Get authorization code
207            if let Some(code) = params.get("code") {
208                auth_code = Some(code.to_string());
209
210                // Send success response to browser
211                let response = Response::from_string(
212                    "<html><body><h1>Login Successful!</h1><p>You can close this window and return to the terminal.</p></body></html>"
213                ).with_header(
214                    tiny_http::Header::from_bytes(&b"Content-Type"[..], &b"text/html"[..]).expect("Content-Type: text/html is a valid header")
215                );
216                request.respond(response).ok();
217                break;
218            }
219        }
220
221        let code = auth_code.ok_or_else(|| anyhow::anyhow!("No authorization code received"))?;
222
223        println!("Authorization code received, exchanging for token...");
224
225        // Exchange code for tokens (must use the actual callback URL that was sent in the authorize request)
226        let token = self.exchange_code(&code, &actual_callback_url).await?;
227
228        // Store the token
229        let stored = StoredToken {
230            access_token: token.access_token.clone(),
231            refresh_token: token.refresh_token.clone(),
232            expires_at: chrono::Utc::now().timestamp() + token.expires_in as i64,
233            scopes: scopes.iter().map(|s| s.to_string()).collect(),
234        };
235
236        self.save_token(&stored)?;
237
238        // Update cache
239        {
240            let mut cache = self.cached_3leg_token.lock().await;
241            cache.token = Some(stored.clone());
242        }
243
244        Ok(stored)
245    }
246
247    /// Exchange authorization code for tokens
248    async fn exchange_code(&self, code: &str, redirect_uri: &str) -> Result<TokenResponse> {
249        let url = self.config.auth_url();
250
251        let params = [
252            ("grant_type", "authorization_code"),
253            ("code", code),
254            ("redirect_uri", redirect_uri),
255        ];
256
257        let _auth_start = std::time::Instant::now();
258        let response = self
259            .http_client
260            .post(&url)
261            .basic_auth(&self.config.client_id, Some(&self.config.client_secret))
262            .form(&params)
263            .send()
264            .await
265            .context("Failed to exchange authorization code")?;
266        crate::profiler::record_http_request(_auth_start.elapsed());
267
268        if !response.status().is_success() {
269            let status = response.status();
270            let error_text = response.text().await.unwrap_or_default();
271            let redacted = crate::logging::redact_secrets(&error_text);
272            anyhow::bail!("Token exchange failed ({status}): {redacted}");
273        }
274
275        let token: TokenResponse = response
276            .json()
277            .await
278            .context("Failed to parse token response")?;
279
280        Ok(token)
281    }
282
283    /// Refresh an expired access token
284    ///
285    /// On failure: preserves cached token (does not clear it), resets refreshing flag.
286    /// On success: updates cached token, resets refreshing flag.
287    async fn refresh_token(&self, refresh_token: String) -> Result<String> {
288        self.config.require_credentials()?;
289
290        let url = self.config.auth_url();
291
292        let params = [
293            ("grant_type", "refresh_token"),
294            ("refresh_token", &refresh_token),
295        ];
296
297        let _auth_start = std::time::Instant::now();
298        let response = self
299            .http_client
300            .post(&url)
301            .basic_auth(&self.config.client_id, Some(&self.config.client_secret))
302            .form(&params)
303            .send()
304            .await
305            .context("Failed to refresh token")?;
306        crate::profiler::record_http_request(_auth_start.elapsed());
307
308        if !response.status().is_success() {
309            // Refresh failed -- preserve cached token, just reset refreshing flag
310            {
311                let mut cache = self.cached_3leg_token.lock().await;
312                cache.refreshing = false;
313            }
314            anyhow::bail!("Token refresh failed. Please login again with 'raps auth login'");
315        }
316
317        let token: TokenResponse = response
318            .json()
319            .await
320            .context("Failed to parse refresh response")?;
321
322        // Update stored token, preserving scopes from the original
323        let original_scopes = {
324            let cache = self.cached_3leg_token.lock().await;
325            cache
326                .token
327                .as_ref()
328                .map(|t| t.scopes.clone())
329                .unwrap_or_default()
330        };
331        let stored = StoredToken {
332            access_token: token.access_token.clone(),
333            refresh_token: token.refresh_token.or(Some(refresh_token)),
334            expires_at: chrono::Utc::now().timestamp() + token.expires_in as i64,
335            scopes: original_scopes,
336        };
337
338        self.save_token(&stored)?;
339
340        {
341            let mut cache = self.cached_3leg_token.lock().await;
342            cache.token = Some(stored);
343            cache.refreshing = false;
344        }
345
346        Ok(token.access_token)
347    }
348
349    /// Logout - clear stored tokens
350    pub async fn logout(&self) -> Result<()> {
351        self.delete_stored_token()?;
352        let mut cache = self.cached_3leg_token.lock().await;
353        cache.token = None;
354        cache.refreshing = false;
355        Ok(())
356    }
357
358    /// Get user profile information (requires 3-legged auth with user:read or user-profile:read scope)
359    pub async fn get_user_info(&self) -> Result<super::types::UserInfo> {
360        let token = self.get_3leg_token().await?;
361        self.get_user_info_with_token(&token).await
362    }
363
364    /// Get token expiry timestamp
365    pub async fn get_token_expiry(&self) -> Option<i64> {
366        let cache = self.cached_3leg_token.lock().await;
367        cache.token.as_ref().map(|t| t.expires_at)
368    }
369}