Skip to main content

raps_kernel/auth/
device_code.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright 2024-2025 Dmytro Yemelianov
3
4//! Manual PKCE OAuth flow for headless environments
5//!
6//! Instead of the device code grant (which APS doesn't support), this module
7//! implements a manual authorization code flow with PKCE (S256). The user is
8//! shown an authorize URL, opens it on any device, and pastes the resulting
9//! callback URL (or bare authorization code) back into the terminal.
10
11use anyhow::{Context, Result};
12use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
13use colored::Colorize;
14use sha2::{Digest, Sha256};
15
16use super::AuthClient;
17use super::types::TokenResponse;
18use crate::types::StoredToken;
19
20/// Generate a cryptographically random PKCE code verifier.
21///
22/// The verifier is 128 characters from the unreserved character set
23/// defined in RFC 7636 §4.1: [A-Z] / [a-z] / [0-9] / "-" / "." / "_" / "~"
24fn generate_code_verifier() -> String {
25    use rand::Rng;
26    const CHARSET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~";
27    let mut rng = rand::thread_rng();
28    (0..128)
29        .map(|_| {
30            let idx = rng.gen_range(0..CHARSET.len());
31            CHARSET[idx] as char
32        })
33        .collect()
34}
35
36/// Derive the S256 code challenge from a code verifier (RFC 7636 §4.2).
37fn derive_code_challenge(verifier: &str) -> String {
38    let hash = Sha256::digest(verifier.as_bytes());
39    URL_SAFE_NO_PAD.encode(hash)
40}
41
42impl AuthClient {
43    /// Login with 3-legged OAuth using manual PKCE flow (headless-friendly).
44    ///
45    /// 1. Generates a PKCE code verifier / challenge pair.
46    /// 2. Prints an authorization URL for the user to open in any browser.
47    /// 3. Prompts the user to paste back the callback URL (or bare code).
48    /// 4. Exchanges the authorization code + verifier for tokens.
49    pub async fn login_device(&self, scopes: &[&str]) -> Result<StoredToken> {
50        self.config.require_credentials()?;
51
52        // --- PKCE ---
53        let code_verifier = generate_code_verifier();
54        let code_challenge = derive_code_challenge(&code_verifier);
55
56        // --- CSRF state ---
57        let state = uuid::Uuid::new_v4().to_string();
58
59        // --- Build authorize URL ---
60        let scope_str = scopes.join(" ");
61        let redirect_uri = &self.config.callback_url;
62        let auth_url = format!(
63            "{}?response_type=code&client_id={}&redirect_uri={}&scope={}&state={}&code_challenge={}&code_challenge_method=S256",
64            self.config.authorize_url(),
65            urlencoding::encode(&self.config.client_id),
66            urlencoding::encode(redirect_uri),
67            urlencoding::encode(&scope_str),
68            urlencoding::encode(&state),
69            urlencoding::encode(&code_challenge),
70        );
71
72        // --- Display instructions ---
73        println!("\n{}", "Manual PKCE Authentication".bold().cyan());
74        println!("{}", "-".repeat(50));
75        println!(
76            "{}",
77            "Open the following URL in any browser to authorize:".dimmed()
78        );
79        println!("\n  {}\n", auth_url.cyan());
80        println!(
81            "{}",
82            "After authorizing, you will be redirected to your callback URL.".dimmed()
83        );
84        println!(
85            "{}",
86            "Paste the full callback URL (or just the authorization code) below.".dimmed()
87        );
88        println!("{}", "-".repeat(50));
89
90        // --- Prompt user for the callback URL / code ---
91        let input: String = dialoguer::Input::new()
92            .with_prompt("Callback URL or authorization code")
93            .interact_text()
94            .context("Failed to read user input")?;
95
96        let input = input.trim().to_string();
97        if input.is_empty() {
98            anyhow::bail!("No authorization code provided. Please try again.");
99        }
100
101        // Parse the authorization code and validate state
102        let auth_code = if input.contains("code=") || input.starts_with("http") {
103            // User pasted a full URL — extract query parameters
104            let parsed_url = url::Url::parse(&input)
105                .context("Failed to parse the pasted URL. Please try again with a valid URL.")?;
106            let params: std::collections::HashMap<_, _> = parsed_url.query_pairs().collect();
107
108            // Check for OAuth error in the callback
109            if let Some(error) = params.get("error") {
110                let desc = params
111                    .get("error_description")
112                    .map(|s| s.to_string())
113                    .unwrap_or_default();
114                anyhow::bail!("Authorization error: {error} - {desc}");
115            }
116
117            // Validate CSRF state
118            let returned_state = params
119                .get("state")
120                .ok_or_else(|| anyhow::anyhow!("Missing state parameter in callback URL"))?;
121            if returned_state.as_ref() != state.as_str() {
122                anyhow::bail!("State mismatch — possible CSRF attack. Please try again.");
123            }
124
125            params
126                .get("code")
127                .ok_or_else(|| anyhow::anyhow!("No authorization code found in callback URL"))?
128                .to_string()
129        } else {
130            // User pasted a bare authorization code
131            input
132        };
133
134        println!("Authorization code received, exchanging for token...");
135
136        // --- Exchange code for tokens ---
137        let token = self
138            .exchange_code_pkce(&auth_code, redirect_uri, &code_verifier)
139            .await?;
140
141        println!("\n{} Authorization successful!", "OK".green().bold());
142
143        // --- Store token ---
144        let stored = StoredToken {
145            access_token: token.access_token.clone(),
146            refresh_token: token.refresh_token.clone(),
147            expires_at: chrono::Utc::now().timestamp() + token.expires_in as i64,
148            scopes: scopes.iter().map(|s| s.to_string()).collect(),
149        };
150
151        self.save_token(&stored)?;
152
153        // Update cache
154        {
155            let mut cache = self.cached_3leg_token.lock().await;
156            cache.token = Some(stored.clone());
157        }
158
159        Ok(stored)
160    }
161
162    /// Exchange an authorization code for tokens using PKCE (no client_secret required).
163    ///
164    /// Uses HTTP Basic auth with client_id/client_secret for compatibility with APS,
165    /// while also sending the PKCE code_verifier.
166    async fn exchange_code_pkce(
167        &self,
168        code: &str,
169        redirect_uri: &str,
170        code_verifier: &str,
171    ) -> Result<TokenResponse> {
172        let url = self.config.auth_url();
173
174        let params = [
175            ("grant_type", "authorization_code"),
176            ("code", code),
177            ("redirect_uri", redirect_uri),
178            ("code_verifier", code_verifier),
179        ];
180
181        let _auth_start = std::time::Instant::now();
182        let response = self
183            .http_client
184            .post(&url)
185            .basic_auth(&self.config.client_id, Some(&self.config.client_secret))
186            .form(&params)
187            .send()
188            .await
189            .context("Failed to exchange authorization code")?;
190        crate::profiler::record_http_request(_auth_start.elapsed());
191
192        if !response.status().is_success() {
193            let status = response.status();
194            let error_text = response.text().await.unwrap_or_default();
195            let redacted = crate::logging::redact_secrets(&error_text);
196            anyhow::bail!("Token exchange failed ({status}): {redacted}");
197        }
198
199        let token: TokenResponse = response
200            .json()
201            .await
202            .context("Failed to parse token response")?;
203
204        Ok(token)
205    }
206}
207
#[cfg(test)]
mod tests {
    use super::*;

    /// True for the RFC 7636 §4.1 unreserved characters permitted in a verifier.
    fn is_unreserved(ch: char) -> bool {
        ch.is_ascii_alphanumeric() || matches!(ch, '-' | '.' | '_' | '~')
    }

    /// True for characters of the base64url alphabet (no `=` padding).
    fn is_base64url(ch: char) -> bool {
        ch.is_ascii_alphanumeric() || matches!(ch, '-' | '_')
    }

    #[test]
    fn test_code_verifier_length_and_charset() {
        let verifier = generate_code_verifier();
        assert_eq!(verifier.len(), 128);
        if let Some(ch) = verifier.chars().find(|&c| !is_unreserved(c)) {
            panic!("Invalid character in code verifier: {ch}");
        }
    }

    #[test]
    fn test_code_verifier_uniqueness() {
        let first = generate_code_verifier();
        let second = generate_code_verifier();
        assert_ne!(first, second, "Two verifiers should not be identical");
    }

    #[test]
    fn test_code_challenge_is_valid_base64url() {
        let challenge = derive_code_challenge(&generate_code_verifier());
        // A 32-byte SHA-256 digest encodes to 43 unpadded base64url characters.
        assert_eq!(challenge.len(), 43);
        if let Some(ch) = challenge.chars().find(|&c| !is_base64url(c)) {
            panic!("Invalid character in code challenge: {ch}");
        }
    }

    #[test]
    fn test_code_challenge_deterministic() {
        let verifier = "test-verifier-value";
        assert_eq!(
            derive_code_challenge(verifier),
            derive_code_challenge(verifier)
        );
    }

    #[test]
    fn test_code_challenge_known_vector() {
        // Test vector from RFC 7636 Appendix B.
        assert_eq!(
            derive_code_challenge("dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"),
            "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM"
        );
    }
}