Skip to main content

stakpak_shared/oauth/
flow.rs

1//! OAuth 2.0 authorization code flow implementation
2
3use super::config::OAuthConfig;
4use super::error::{OAuthError, OAuthResult};
5use super::pkce::PkceChallenge;
6use serde::{Deserialize, Serialize};
7
8/// OAuth token response from the token endpoint
9#[derive(Debug, Clone, Serialize, Deserialize)]
10pub struct TokenResponse {
11    /// Access token for API requests
12    pub access_token: String,
13    /// Refresh token for obtaining new access tokens
14    pub refresh_token: String,
15    /// Token lifetime in seconds
16    pub expires_in: i64,
17    /// Token type (usually "Bearer")
18    pub token_type: String,
19}
20
21/// OAuth 2.0 authorization code flow handler
22pub struct OAuthFlow {
23    config: OAuthConfig,
24    pkce: Option<PkceChallenge>,
25}
26
27impl OAuthFlow {
28    /// Create a new OAuth flow with the given configuration
29    pub fn new(config: OAuthConfig) -> Self {
30        Self { config, pkce: None }
31    }
32
33    /// Generate the authorization URL for the user to visit
34    ///
35    /// This generates a new PKCE challenge and returns the full authorization URL
36    /// that should be opened in the user's browser.
37    pub fn generate_auth_url(&mut self) -> String {
38        let pkce = PkceChallenge::generate();
39
40        let url = format!(
41            "{}?code=true&client_id={}&response_type=code&redirect_uri={}&scope={}&code_challenge={}&code_challenge_method={}&state={}",
42            self.config.auth_url,
43            urlencoding::encode(&self.config.client_id),
44            urlencoding::encode(&self.config.redirect_url),
45            urlencoding::encode(&self.config.scopes_string()),
46            urlencoding::encode(&pkce.challenge),
47            PkceChallenge::challenge_method(),
48            urlencoding::encode(&pkce.verifier), // State contains verifier for validation
49        );
50
51        self.pkce = Some(pkce);
52        url
53    }
54
55    /// Exchange authorization code for tokens
56    ///
57    /// The code should be in the format "authorization_code#state" as returned by Anthropic.
58    pub async fn exchange_code(&self, code: &str) -> OAuthResult<TokenResponse> {
59        let pkce = self.pkce.as_ref().ok_or(OAuthError::PkceNotInitialized)?;
60
61        // Parse the authorization code - format: "authorization_code#state"
62        let (auth_code, state) = parse_auth_code(code)?;
63
64        // Validate state matches our verifier
65        if state != pkce.verifier {
66            return Err(OAuthError::invalid_code_format(
67                "State mismatch - possible CSRF attack",
68            ));
69        }
70
71        let client =
72            crate::tls_client::create_tls_client(crate::tls_client::TlsClientConfig::default())
73                .expect("Failed to create TLS client for OAuth token exchange");
74        let response = client
75            .post(&self.config.token_url)
76            .json(&serde_json::json!({
77                "grant_type": "authorization_code",
78                "code": auth_code,
79                "state": state,
80                "client_id": self.config.client_id,
81                "redirect_uri": self.config.redirect_url,
82                "code_verifier": pkce.verifier,
83            }))
84            .send()
85            .await?;
86
87        if !response.status().is_success() {
88            let status = response.status();
89            let error_text = response.text().await.unwrap_or_default();
90            return Err(OAuthError::token_exchange_failed(format!(
91                "HTTP {}: {}",
92                status, error_text
93            )));
94        }
95
96        response.json::<TokenResponse>().await.map_err(|e| {
97            OAuthError::token_exchange_failed(format!("Failed to parse token response: {}", e))
98        })
99    }
100
101    /// Refresh an expired access token
102    pub async fn refresh_token(&self, refresh_token: &str) -> OAuthResult<TokenResponse> {
103        let client =
104            crate::tls_client::create_tls_client(crate::tls_client::TlsClientConfig::default())
105                .expect("Failed to create TLS client for OAuth token refresh");
106        let response = client
107            .post(&self.config.token_url)
108            .json(&serde_json::json!({
109                "grant_type": "refresh_token",
110                "refresh_token": refresh_token,
111                "client_id": self.config.client_id,
112            }))
113            .send()
114            .await?;
115
116        if !response.status().is_success() {
117            let status = response.status();
118            let error_text = response.text().await.unwrap_or_default();
119            return Err(OAuthError::token_refresh_failed(format!(
120                "HTTP {}: {}",
121                status, error_text
122            )));
123        }
124
125        response.json::<TokenResponse>().await.map_err(|e| {
126            OAuthError::token_refresh_failed(format!("Failed to parse token response: {}", e))
127        })
128    }
129
130    /// Get the PKCE verifier (for validation purposes)
131    pub fn pkce_verifier(&self) -> Option<&str> {
132        self.pkce.as_ref().map(|p| p.verifier.as_str())
133    }
134}
135
136/// Parse the authorization code from Anthropic's callback format
137///
138/// Anthropic returns codes in the format: "authorization_code#state"
139#[allow(clippy::string_slice)] // pos from find('#') on same string, '#' is ASCII
140fn parse_auth_code(code: &str) -> OAuthResult<(String, String)> {
141    // Handle both "#" and "%23" (URL-encoded #)
142    let code = code.replace("%23", "#");
143
144    if let Some(pos) = code.find('#') {
145        let auth_code = code[..pos].to_string();
146        let state = code[pos + 1..].to_string();
147
148        if auth_code.is_empty() || state.is_empty() {
149            return Err(OAuthError::invalid_code_format(
150                "Authorization code or state is empty",
151            ));
152        }
153
154        Ok((auth_code, state))
155    } else {
156        Err(OAuthError::invalid_code_format(
157            "Expected format: authorization_code#state",
158        ))
159    }
160}
161
#[cfg(test)]
mod tests {
    use super::*;

    /// Configuration pointing at placeholder endpoints; no test here performs a request.
    fn test_config() -> OAuthConfig {
        OAuthConfig::new(
            "test-client-id",
            "https://example.com/auth",
            "https://example.com/token",
            "https://example.com/callback",
            vec!["scope1".to_string(), "scope2".to_string()],
        )
    }

    #[test]
    fn test_generate_auth_url() {
        let mut flow = OAuthFlow::new(test_config());
        let url = flow.generate_auth_url();

        assert!(url.starts_with("https://example.com/auth?"));
        assert!(url.contains("client_id=test-client-id"));
        assert!(url.contains("response_type=code"));
        assert!(url.contains("redirect_uri="));
        assert!(url.contains("scope=scope1%20scope2"));
        assert!(url.contains("code_challenge="));
        assert!(url.contains("code_challenge_method=S256"));
        assert!(url.contains("state="));

        // PKCE should be initialized
        assert!(flow.pkce.is_some());
    }

    #[test]
    fn test_pkce_verifier_lifecycle() {
        let mut flow = OAuthFlow::new(test_config());
        assert!(flow.pkce_verifier().is_none());

        flow.generate_auth_url();
        assert!(flow.pkce_verifier().is_some());
    }

    #[test]
    fn test_parse_auth_code_valid() {
        let result = parse_auth_code("abc123#verifier456");
        assert!(result.is_ok());
        let (code, state) = result.unwrap();
        assert_eq!(code, "abc123");
        assert_eq!(state, "verifier456");
    }

    #[test]
    fn test_parse_auth_code_url_encoded() {
        let result = parse_auth_code("abc123%23verifier456");
        assert!(result.is_ok());
        let (code, state) = result.unwrap();
        assert_eq!(code, "abc123");
        assert_eq!(state, "verifier456");
    }

    #[test]
    fn test_parse_auth_code_missing_separator() {
        let result = parse_auth_code("abc123verifier456");
        assert!(result.is_err());
    }

    #[test]
    fn test_parse_auth_code_empty_parts() {
        assert!(parse_auth_code("#state").is_err());
        assert!(parse_auth_code("code#").is_err());
        assert!(parse_auth_code("#").is_err());
    }

    #[test]
    fn test_exchange_code_without_pkce() {
        let flow = OAuthFlow::new(test_config());
        let result = tokio_test::block_on(flow.exchange_code("code#state"));
        assert!(matches!(result, Err(OAuthError::PkceNotInitialized)));
    }

    #[test]
    fn test_exchange_code_state_mismatch() {
        let mut flow = OAuthFlow::new(test_config());
        flow.generate_auth_url();

        // The state does not match the generated verifier, so the exchange is rejected
        // by the CSRF check before any HTTP request is attempted.
        let result = tokio_test::block_on(flow.exchange_code("abc123#not-the-verifier"));
        assert!(result.is_err());
    }

    #[test]
    fn test_token_response_serde() {
        let json = r#"{
            "access_token": "access123",
            "refresh_token": "refresh456",
            "expires_in": 3600,
            "token_type": "Bearer"
        }"#;

        let response: TokenResponse = serde_json::from_str(json).unwrap();
        assert_eq!(response.access_token, "access123");
        assert_eq!(response.refresh_token, "refresh456");
        assert_eq!(response.expires_in, 3600);
        assert_eq!(response.token_type, "Bearer");
    }
}