Skip to main content

stakpak_shared/oauth/
flow.rs

1//! OAuth 2.0 authorization code flow implementation
2
3use super::config::{AuthorizationRequestMode, OAuthConfig, TokenRequestMode};
4use super::error::{OAuthError, OAuthResult};
5use super::pkce::PkceChallenge;
6use serde::{Deserialize, Serialize};
7
/// OAuth token response from the token endpoint.
///
/// Decoded from the JSON body returned by both the authorization-code exchange
/// and the refresh-token request. None of the fields is an `Option` or carries
/// `#[serde(default)]`, so the derived `Deserialize` treats all four as
/// required: a response missing any of them fails to parse.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenResponse {
    /// Access token for API requests
    pub access_token: String,
    /// Refresh token for obtaining new access tokens
    pub refresh_token: String,
    /// Token lifetime in seconds, as reported by the provider
    pub expires_in: i64,
    /// Token type (usually "Bearer")
    pub token_type: String,
}
20
/// Body of a request to the token endpoint, in the encoding selected by the
/// config's `token_request_mode`.
enum TokenRequest {
    /// Sent as a JSON body (`TokenRequestMode::Json`).
    Json(serde_json::Value),
    /// Sent as URL-encoded form key/value pairs
    /// (`TokenRequestMode::FormUrlEncoded`).
    Form(Vec<(String, String)>),
}
25
/// OAuth 2.0 authorization code flow handler (with PKCE).
///
/// Usage: call `generate_auth_url`, send the user to the returned URL, then
/// pass what the provider returns to `exchange_code` (combined
/// `code#state` string) or `exchange_code_with_state` (separate values).
/// `refresh_token` does not depend on the authorization step.
pub struct OAuthFlow {
    /// Provider endpoints, client id, scopes and request-encoding options.
    config: OAuthConfig,
    /// PKCE challenge from the most recent `generate_auth_url` call;
    /// `None` until that method has been called.
    pkce: Option<PkceChallenge>,
    /// `state` value placed in the most recent authorization URL; compared
    /// with the state returned by the provider before a code is exchanged.
    state: Option<String>,
}
32
33impl OAuthFlow {
34    /// Create a new OAuth flow with the given configuration
35    pub fn new(config: OAuthConfig) -> Self {
36        Self {
37            config,
38            pkce: None,
39            state: None,
40        }
41    }
42
43    /// Generate the authorization URL for the user to visit
44    ///
45    /// This generates a new PKCE challenge and returns the full authorization URL
46    /// that should be opened in the user's browser.
47    pub fn generate_auth_url(&mut self) -> String {
48        let pkce = PkceChallenge::generate();
49        let state = uuid::Uuid::new_v4().simple().to_string();
50
51        let mut query = vec![
52            format!("client_id={}", urlencoding::encode(&self.config.client_id)),
53            "response_type=code".to_string(),
54            format!(
55                "redirect_uri={}",
56                urlencoding::encode(&self.config.redirect_url)
57            ),
58            format!(
59                "scope={}",
60                urlencoding::encode(&self.config.scopes_string())
61            ),
62            format!("code_challenge={}", urlencoding::encode(&pkce.challenge)),
63            format!(
64                "code_challenge_method={}",
65                PkceChallenge::challenge_method()
66            ),
67            format!("state={}", urlencoding::encode(&state)),
68        ];
69
70        if self.config.authorization_request_mode == AuthorizationRequestMode::LegacyCode {
71            query.insert(0, "code=true".to_string());
72        }
73
74        query.extend(self.config.authorization_params.iter().map(|(key, value)| {
75            format!(
76                "{}={}",
77                urlencoding::encode(key),
78                urlencoding::encode(value)
79            )
80        }));
81
82        let url = format!("{}?{}", self.config.auth_url, query.join("&"));
83
84        self.pkce = Some(pkce);
85        self.state = Some(state);
86        url
87    }
88
89    fn build_token_exchange_request(
90        &self,
91        auth_code: String,
92        state: String,
93    ) -> OAuthResult<TokenRequest> {
94        let pkce = self.pkce.as_ref().ok_or(OAuthError::PkceNotInitialized)?;
95
96        Ok(match self.config.token_request_mode {
97            TokenRequestMode::Json => TokenRequest::Json(serde_json::json!({
98                "grant_type": "authorization_code",
99                "code": auth_code,
100                "state": state,
101                "client_id": self.config.client_id,
102                "redirect_uri": self.config.redirect_url,
103                "code_verifier": pkce.verifier,
104            })),
105            TokenRequestMode::FormUrlEncoded => TokenRequest::Form(vec![
106                // OpenAI's token endpoint rejects `state` in the form-encoded
107                // exchange request (`Unknown parameter: 'state'.`). State is
108                // still validated locally before building this request.
109                ("grant_type".to_string(), "authorization_code".to_string()),
110                ("code".to_string(), auth_code),
111                ("client_id".to_string(), self.config.client_id.clone()),
112                ("redirect_uri".to_string(), self.config.redirect_url.clone()),
113                ("code_verifier".to_string(), pkce.verifier.clone()),
114            ]),
115        })
116    }
117
118    fn build_token_refresh_request(&self, refresh_token: String) -> TokenRequest {
119        match self.config.token_request_mode {
120            TokenRequestMode::Json => TokenRequest::Json(serde_json::json!({
121                "grant_type": "refresh_token",
122                "refresh_token": refresh_token,
123                "client_id": self.config.client_id,
124            })),
125            TokenRequestMode::FormUrlEncoded => TokenRequest::Form(vec![
126                ("grant_type".to_string(), "refresh_token".to_string()),
127                ("refresh_token".to_string(), refresh_token),
128                ("client_id".to_string(), self.config.client_id.clone()),
129            ]),
130        }
131    }
132
133    /// Exchange authorization code for tokens.
134    ///
135    /// The string form is kept for manual copy/paste flows that return
136    /// `authorization_code#state`. Programmatic callers should prefer
137    /// `exchange_code_with_state` when they already have separate values.
138    pub async fn exchange_code(&self, code: &str) -> OAuthResult<TokenResponse> {
139        let (auth_code, state) = parse_auth_code(code)?;
140        self.exchange_code_with_state(&auth_code, &state).await
141    }
142
143    /// Exchange authorization code for tokens using separately supplied code and state.
144    pub async fn exchange_code_with_state(
145        &self,
146        auth_code: &str,
147        state: &str,
148    ) -> OAuthResult<TokenResponse> {
149        let _pkce = self.pkce.as_ref().ok_or(OAuthError::PkceNotInitialized)?;
150
151        let expected_state = self
152            .state
153            .as_deref()
154            .ok_or(OAuthError::PkceNotInitialized)?;
155
156        // Validate state matches the authorization request state before the
157        // token exchange request is built.
158        if state != expected_state {
159            return Err(OAuthError::invalid_code_format(
160                "State mismatch - possible CSRF attack",
161            ));
162        }
163
164        let token_request =
165            self.build_token_exchange_request(auth_code.to_string(), state.to_string())?;
166
167        let client =
168            crate::tls_client::create_tls_client(crate::tls_client::TlsClientConfig::default())
169                .expect("Failed to create TLS client for OAuth token exchange");
170        let response = match token_request {
171            TokenRequest::Json(body) => client.post(&self.config.token_url).json(&body),
172            TokenRequest::Form(body) => client.post(&self.config.token_url).form(&body),
173        }
174        .send()
175        .await?;
176
177        if !response.status().is_success() {
178            let status = response.status();
179            let error_text = response.text().await.unwrap_or_default();
180            return Err(OAuthError::token_exchange_failed(format!(
181                "HTTP {}: {}",
182                status, error_text
183            )));
184        }
185
186        response.json::<TokenResponse>().await.map_err(|e| {
187            OAuthError::token_exchange_failed(format!("Failed to parse token response: {}", e))
188        })
189    }
190
191    /// Refresh an expired access token
192    pub async fn refresh_token(&self, refresh_token: &str) -> OAuthResult<TokenResponse> {
193        let token_request = self.build_token_refresh_request(refresh_token.to_string());
194        let client =
195            crate::tls_client::create_tls_client(crate::tls_client::TlsClientConfig::default())
196                .expect("Failed to create TLS client for OAuth token refresh");
197        let response = match token_request {
198            TokenRequest::Json(body) => client.post(&self.config.token_url).json(&body),
199            TokenRequest::Form(body) => client.post(&self.config.token_url).form(&body),
200        }
201        .send()
202        .await?;
203
204        if !response.status().is_success() {
205            let status = response.status();
206            let error_text = response.text().await.unwrap_or_default();
207            return Err(OAuthError::token_refresh_failed(format!(
208                "HTTP {}: {}",
209                status, error_text
210            )));
211        }
212
213        response.json::<TokenResponse>().await.map_err(|e| {
214            OAuthError::token_refresh_failed(format!("Failed to parse token response: {}", e))
215        })
216    }
217
218    /// Get the PKCE verifier (for validation purposes)
219    pub fn pkce_verifier(&self) -> Option<&str> {
220        self.pkce.as_ref().map(|p| p.verifier.as_str())
221    }
222}
223
224/// Parse the authorization code from a provider callback format that embeds state.
225///
226/// Some providers return codes in the format: "authorization_code#state".
227#[allow(clippy::string_slice)] // pos from find('#') on same string, '#' is ASCII
228fn parse_auth_code(code: &str) -> OAuthResult<(String, String)> {
229    // Handle both "#" and "%23" (URL-encoded #)
230    let code = code.replace("%23", "#");
231
232    if let Some(pos) = code.find('#') {
233        let auth_code = code[..pos].to_string();
234        let state = code[pos + 1..].to_string();
235
236        if auth_code.is_empty() || state.is_empty() {
237            return Err(OAuthError::invalid_code_format(
238                "Authorization code or state is empty",
239            ));
240        }
241
242        Ok((auth_code, state))
243    } else {
244        Err(OAuthError::invalid_code_format(
245            "Expected format: authorization_code#state",
246        ))
247    }
248}
249
#[cfg(test)]
mod tests {
    use super::*;
    use crate::oauth::config::AuthorizationRequestMode;

    /// Baseline config shared by the tests; request modes are the config defaults.
    fn test_config() -> OAuthConfig {
        OAuthConfig::new(
            "test-client-id",
            "https://example.com/auth",
            "https://example.com/token",
            "https://example.com/callback",
            vec!["scope1".to_string(), "scope2".to_string()],
        )
    }

    #[test]
    fn test_generate_auth_url_standard_pkce() {
        let mut flow = OAuthFlow::new(test_config());
        let url = flow.generate_auth_url();

        assert!(url.starts_with("https://example.com/auth?"));
        assert!(url.contains("client_id=test-client-id"));
        assert!(url.contains("response_type=code"));
        assert!(url.contains("redirect_uri="));
        assert!(url.contains("scope=scope1%20scope2"));
        assert!(url.contains("code_challenge="));
        assert!(url.contains("code_challenge_method=S256"));
        assert!(url.contains("state="));
        assert!(!url.contains("code=true"));

        // PKCE should be initialized
        assert!(flow.pkce.is_some());
    }

    #[test]
    fn test_generate_auth_url_legacy_mode_includes_code_param() {
        let mut flow = OAuthFlow::new(
            test_config().with_authorization_request_mode(AuthorizationRequestMode::LegacyCode),
        );
        let url = flow.generate_auth_url();

        assert!(url.contains("code=true"));
        assert!(url.contains("response_type=code"));
    }

    #[test]
    fn test_generate_auth_url_includes_provider_specific_params() {
        let mut flow = OAuthFlow::new(test_config().with_authorization_params(vec![
            ("id_token_add_organizations", "true"),
            ("codex_cli_simplified_flow", "true"),
            ("originator", "stakpak"),
        ]));
        let url = flow.generate_auth_url();

        assert!(url.contains("id_token_add_organizations=true"));
        assert!(url.contains("codex_cli_simplified_flow=true"));
        assert!(url.contains("originator=stakpak"));
    }

    #[test]
    fn test_generate_auth_url_uses_separate_state_from_pkce_verifier() {
        let mut flow = OAuthFlow::new(test_config());
        let url = flow.generate_auth_url();
        let parsed = reqwest::Url::parse(&url).expect("parse auth url");
        let state = parsed
            .query_pairs()
            .find(|(key, _)| key == "state")
            .map(|(_, value)| value.to_string())
            .expect("state param");

        assert_ne!(Some(state.as_str()), flow.pkce_verifier());
        // The state placed in the URL is exactly the one kept for validation.
        assert_eq!(flow.state.as_deref(), Some(state.as_str()));
    }

    #[test]
    fn test_openai_token_exchange_request_uses_form_encoding_without_state() {
        let mut flow = OAuthFlow::new(
            test_config()
                .with_token_request_mode(crate::oauth::config::TokenRequestMode::FormUrlEncoded),
        );
        let _ = flow.generate_auth_url();
        let request = flow
            .build_token_exchange_request("auth-code".to_string(), "callback-state".to_string())
            .expect("token exchange request");

        match request {
            TokenRequest::Form(params) => {
                assert!(
                    params.contains(&("grant_type".to_string(), "authorization_code".to_string()))
                );
                assert!(params.contains(&("code".to_string(), "auth-code".to_string())));
                assert!(params.contains(&("client_id".to_string(), "test-client-id".to_string())));
                assert!(params.iter().all(|(key, _)| key != "state"));
            }
            TokenRequest::Json(_) => panic!("expected form request"),
        }
    }

    #[test]
    fn test_json_token_exchange_request_includes_state_and_verifier() {
        let mut flow = OAuthFlow::new(
            test_config().with_token_request_mode(crate::oauth::config::TokenRequestMode::Json),
        );
        let _ = flow.generate_auth_url();
        let verifier = flow
            .pkce_verifier()
            .expect("pkce initialized")
            .to_string();
        let request = flow
            .build_token_exchange_request("auth-code".to_string(), "callback-state".to_string())
            .expect("token exchange request");

        match request {
            TokenRequest::Json(body) => {
                assert_eq!(body["grant_type"], "authorization_code");
                assert_eq!(body["code"], "auth-code");
                assert_eq!(body["state"], "callback-state");
                assert_eq!(body["client_id"], "test-client-id");
                assert_eq!(body["code_verifier"], verifier.as_str());
            }
            TokenRequest::Form(_) => panic!("expected JSON request"),
        }
    }

    #[test]
    fn test_openai_token_refresh_request_uses_form_encoding() {
        let flow = OAuthFlow::new(
            test_config()
                .with_token_request_mode(crate::oauth::config::TokenRequestMode::FormUrlEncoded),
        );
        let request = flow.build_token_refresh_request("refresh-token".to_string());

        match request {
            TokenRequest::Form(params) => {
                assert!(params.contains(&("grant_type".to_string(), "refresh_token".to_string())));
                assert!(
                    params.contains(&("refresh_token".to_string(), "refresh-token".to_string()))
                );
                assert!(params.contains(&("client_id".to_string(), "test-client-id".to_string())));
            }
            TokenRequest::Json(_) => panic!("expected form request"),
        }
    }

    #[test]
    fn test_json_token_refresh_request() {
        let flow = OAuthFlow::new(
            test_config().with_token_request_mode(crate::oauth::config::TokenRequestMode::Json),
        );

        match flow.build_token_refresh_request("refresh-token".to_string()) {
            TokenRequest::Json(body) => {
                assert_eq!(body["grant_type"], "refresh_token");
                assert_eq!(body["refresh_token"], "refresh-token");
                assert_eq!(body["client_id"], "test-client-id");
            }
            TokenRequest::Form(_) => panic!("expected JSON request"),
        }
    }

    #[test]
    fn test_parse_auth_code_valid() {
        let (code, state) = parse_auth_code("abc123#verifier456").expect("valid code");
        assert_eq!(code, "abc123");
        assert_eq!(state, "verifier456");
    }

    #[test]
    fn test_parse_auth_code_url_encoded() {
        let (code, state) = parse_auth_code("abc123%23verifier456").expect("valid code");
        assert_eq!(code, "abc123");
        assert_eq!(state, "verifier456");
    }

    #[test]
    fn test_parse_auth_code_missing_separator() {
        let result = parse_auth_code("abc123verifier456");
        assert!(result.is_err());
    }

    #[test]
    fn test_parse_auth_code_empty_parts() {
        assert!(parse_auth_code("#state").is_err());
        assert!(parse_auth_code("code#").is_err());
        assert!(parse_auth_code("#").is_err());
    }

    #[test]
    fn test_exchange_code_without_pkce() {
        let flow = OAuthFlow::new(test_config());
        let result = tokio_test::block_on(flow.exchange_code("code#state"));
        assert!(matches!(result, Err(OAuthError::PkceNotInitialized)));
    }

    #[test]
    fn test_exchange_code_with_state_without_pkce() {
        let flow = OAuthFlow::new(test_config());
        let result = tokio_test::block_on(flow.exchange_code_with_state("code", "state"));
        assert!(matches!(result, Err(OAuthError::PkceNotInitialized)));
    }

    #[test]
    fn test_exchange_code_with_state_rejects_mismatched_state() {
        let mut flow = OAuthFlow::new(test_config());
        let _ = flow.generate_auth_url();
        // The mismatch is detected locally, before any network request is made.
        let result = tokio_test::block_on(flow.exchange_code_with_state("code", "wrong-state"));
        assert!(result.is_err());
        assert!(!matches!(result, Err(OAuthError::PkceNotInitialized)));
    }

    #[test]
    fn test_token_response_serde() {
        let json = r#"{
            "access_token": "access123",
            "refresh_token": "refresh456",
            "expires_in": 3600,
            "token_type": "Bearer"
        }"#;

        let response: TokenResponse = serde_json::from_str(json).unwrap();
        assert_eq!(response.access_token, "access123");
        assert_eq!(response.refresh_token, "refresh456");
        assert_eq!(response.expires_in, 3600);
        assert_eq!(response.token_type, "Bearer");
    }
}