Skip to main content

mailledger_oauth/flow/
code.rs

1//! Authorization Code Flow implementation.
2
3use super::{OAuthClient, PkceChallenge};
4use crate::error::Result;
5use crate::token::Token;
6use url::Url;
7
8/// Authorization Code Flow for `OAuth2`.
9///
10/// This flow is suitable for applications that can open a browser
11/// and receive the authorization code via redirect.
12#[derive(Debug)]
13pub struct AuthorizationCodeFlow {
14    client: OAuthClient,
15    pkce: Option<PkceChallenge>,
16}
17
18impl AuthorizationCodeFlow {
19    /// Creates a new authorization code flow.
20    #[must_use]
21    pub const fn new(client: OAuthClient) -> Self {
22        Self { client, pkce: None }
23    }
24
25    /// Enables PKCE for enhanced security (recommended for public clients).
26    #[must_use]
27    pub fn with_pkce(mut self) -> Self {
28        self.pkce = Some(PkceChallenge::generate());
29        self
30    }
31
32    /// Builds the authorization URL for user consent.
33    ///
34    /// The user should be redirected to this URL to authorize the application.
35    ///
36    /// # Arguments
37    ///
38    /// * `scopes` - Optional scopes to request (uses provider defaults if None)
39    /// * `state` - Optional state parameter for CSRF protection
40    ///
41    /// # Errors
42    ///
43    /// Returns an error if the URL cannot be constructed.
44    pub fn authorization_url(&self, scopes: Option<&[String]>, state: Option<&str>) -> Result<Url> {
45        let mut url = self.client.provider.auth_url.clone();
46
47        {
48            let mut pairs = url.query_pairs_mut();
49            pairs
50                .append_pair("client_id", &self.client.client_id)
51                .append_pair("response_type", "code");
52
53            if let Some(redirect_uri) = &self.client.redirect_uri {
54                pairs.append_pair("redirect_uri", redirect_uri);
55            }
56
57            let scope_str = scopes.map_or_else(
58                || self.client.provider.default_scopes.join(" "),
59                |s| s.join(" "),
60            );
61
62            if !scope_str.is_empty() {
63                pairs.append_pair("scope", &scope_str);
64            }
65
66            if let Some(state_val) = state {
67                pairs.append_pair("state", state_val);
68            }
69
70            if let Some(pkce) = &self.pkce {
71                pairs
72                    .append_pair("code_challenge", pkce.challenge())
73                    .append_pair("code_challenge_method", pkce.method());
74            }
75
76            // Provider-specific parameters
77            match self.client.provider.name.as_str() {
78                "Google" => {
79                    pairs
80                        .append_pair("access_type", "offline")
81                        .append_pair("prompt", "consent");
82                }
83                "Microsoft" => {
84                    pairs.append_pair("prompt", "consent");
85                }
86                _ => {}
87            }
88        }
89
90        Ok(url)
91    }
92
93    /// Exchanges the authorization code for an access token.
94    ///
95    /// # Arguments
96    ///
97    /// * `code` - Authorization code from the redirect
98    /// * `redirect_uri` - Optional redirect URI (uses client config if None)
99    ///
100    /// # Errors
101    ///
102    /// Returns an error if the token exchange fails.
103    pub async fn exchange_code(&self, code: &str, redirect_uri: Option<&str>) -> Result<Token> {
104        let code_verifier = self.pkce.as_ref().map(PkceChallenge::verifier);
105        self.client
106            .exchange_code(code, redirect_uri, code_verifier)
107            .await
108    }
109
110    /// Returns the PKCE verifier if PKCE is enabled.
111    #[must_use]
112    pub fn pkce_verifier(&self) -> Option<&str> {
113        self.pkce.as_ref().map(PkceChallenge::verifier)
114    }
115}
116
// Lint relaxations for test code; the tests call `unwrap` directly on fallible setup.
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::redundant_clone, clippy::manual_string_new, clippy::needless_collect, clippy::unreadable_literal, clippy::used_underscore_items, clippy::similar_names)]
mod tests {
    use super::*;
    use crate::provider::Provider;

    /// Builds a client for the Google provider with the fixed id `test_client`.
    fn google_client() -> OAuthClient {
        OAuthClient::new("test_client", Provider::google().unwrap())
    }

    /// Core parameters, the state value and the encoded redirect URI all appear.
    #[test]
    fn test_authorization_url() {
        let flow =
            AuthorizationCodeFlow::new(google_client().with_redirect_uri("http://localhost:8080"));

        let url = flow.authorization_url(None, Some("random_state")).unwrap();
        let rendered = url.as_str();

        assert!(rendered.contains("client_id=test_client"));
        assert!(rendered.contains("response_type=code"));
        assert!(rendered.contains("state=random_state"));
        // The redirect URI shows up percent-encoded in the query string.
        assert!(rendered.contains("redirect_uri=http%3A%2F%2Flocalhost%3A8080"));
    }

    /// Enabling PKCE adds the challenge parameters and exposes a verifier.
    #[test]
    fn test_authorization_url_with_pkce() {
        let flow = AuthorizationCodeFlow::new(google_client()).with_pkce();

        let url = flow.authorization_url(None, None).unwrap();
        let rendered = url.as_str();

        assert!(rendered.contains("code_challenge="));
        assert!(rendered.contains("code_challenge_method=S256"));
        assert!(flow.pkce_verifier().is_some());
    }

    /// Caller-supplied scopes are used for the `scope` parameter.
    #[test]
    fn test_authorization_url_custom_scopes() {
        let flow = AuthorizationCodeFlow::new(google_client());
        let scopes = [String::from("email"), String::from("profile")];

        let url = flow
            .authorization_url(Some(scopes.as_slice()), None)
            .unwrap();

        // The separating space is serialized as `+` in the query string.
        assert!(url.as_str().contains("scope=email+profile"));
    }

    /// The Google provider gets its extra `access_type` and `prompt` parameters.
    #[test]
    fn test_google_specific_params() {
        let flow = AuthorizationCodeFlow::new(google_client());

        let url = flow.authorization_url(None, None).unwrap();
        let rendered = url.as_str();

        assert!(rendered.contains("access_type=offline"));
        assert!(rendered.contains("prompt=consent"));
    }
}
179}