mailledger_oauth/flow/
code.rs1use super::{OAuthClient, PkceChallenge};
4use crate::error::Result;
5use crate::token::Token;
6use url::Url;
7
8#[derive(Debug)]
13pub struct AuthorizationCodeFlow {
14 client: OAuthClient,
15 pkce: Option<PkceChallenge>,
16}
17
18impl AuthorizationCodeFlow {
19 #[must_use]
21 pub const fn new(client: OAuthClient) -> Self {
22 Self { client, pkce: None }
23 }
24
25 #[must_use]
27 pub fn with_pkce(mut self) -> Self {
28 self.pkce = Some(PkceChallenge::generate());
29 self
30 }
31
32 pub fn authorization_url(&self, scopes: Option<&[String]>, state: Option<&str>) -> Result<Url> {
45 let mut url = self.client.provider.auth_url.clone();
46
47 {
48 let mut pairs = url.query_pairs_mut();
49 pairs
50 .append_pair("client_id", &self.client.client_id)
51 .append_pair("response_type", "code");
52
53 if let Some(redirect_uri) = &self.client.redirect_uri {
54 pairs.append_pair("redirect_uri", redirect_uri);
55 }
56
57 let scope_str = scopes.map_or_else(
58 || self.client.provider.default_scopes.join(" "),
59 |s| s.join(" "),
60 );
61
62 if !scope_str.is_empty() {
63 pairs.append_pair("scope", &scope_str);
64 }
65
66 if let Some(state_val) = state {
67 pairs.append_pair("state", state_val);
68 }
69
70 if let Some(pkce) = &self.pkce {
71 pairs
72 .append_pair("code_challenge", pkce.challenge())
73 .append_pair("code_challenge_method", pkce.method());
74 }
75
76 match self.client.provider.name.as_str() {
78 "Google" => {
79 pairs
80 .append_pair("access_type", "offline")
81 .append_pair("prompt", "consent");
82 }
83 "Microsoft" => {
84 pairs.append_pair("prompt", "consent");
85 }
86 _ => {}
87 }
88 }
89
90 Ok(url)
91 }
92
93 pub async fn exchange_code(&self, code: &str, redirect_uri: Option<&str>) -> Result<Token> {
104 let code_verifier = self.pkce.as_ref().map(PkceChallenge::verifier);
105 self.client
106 .exchange_code(code, redirect_uri, code_verifier)
107 .await
108 }
109
110 #[must_use]
112 pub fn pkce_verifier(&self) -> Option<&str> {
113 self.pkce.as_ref().map(PkceChallenge::verifier)
114 }
115}
116
#[cfg(test)]
// Test-only code: unwraps on known-good fixtures and similar pedantic lints are acceptable.
#[allow(clippy::unwrap_used, clippy::redundant_clone, clippy::manual_string_new, clippy::needless_collect, clippy::unreadable_literal, clippy::used_underscore_items, clippy::similar_names)]
mod tests {
    use super::*;
    use crate::provider::Provider;

    /// Builds a client for the built-in Google provider with the test client id.
    fn google_client() -> OAuthClient {
        OAuthClient::new("test_client", Provider::google().unwrap())
    }

    /// Returns the decoded value of the first query parameter named `key`, if any.
    ///
    /// Comparing decoded key/value pairs is stricter than substring-searching the
    /// serialized URL, which can match text inside a different parameter.
    fn query_param(url: &Url, key: &str) -> Option<String> {
        url.query_pairs()
            .find_map(|(k, v)| (k == key).then(|| v.into_owned()))
    }

    #[test]
    fn test_authorization_url() {
        let client = google_client().with_redirect_uri("http://localhost:8080");
        let flow = AuthorizationCodeFlow::new(client);
        let url = flow.authorization_url(None, Some("random_state")).unwrap();

        assert_eq!(query_param(&url, "client_id").as_deref(), Some("test_client"));
        assert_eq!(query_param(&url, "response_type").as_deref(), Some("code"));
        assert_eq!(query_param(&url, "state").as_deref(), Some("random_state"));
        assert_eq!(
            query_param(&url, "redirect_uri").as_deref(),
            Some("http://localhost:8080")
        );
        // The serialized form must be percent-encoded.
        assert!(
            url.as_str()
                .contains("redirect_uri=http%3A%2F%2Flocalhost%3A8080")
        );
    }

    #[test]
    fn test_authorization_url_with_pkce() {
        let flow = AuthorizationCodeFlow::new(google_client()).with_pkce();
        let url = flow.authorization_url(None, None).unwrap();

        assert!(!query_param(&url, "code_challenge").unwrap().is_empty());
        assert_eq!(
            query_param(&url, "code_challenge_method").as_deref(),
            Some("S256")
        );
        assert!(flow.pkce_verifier().is_some());
    }

    #[test]
    fn test_authorization_url_without_pkce() {
        let flow = AuthorizationCodeFlow::new(google_client());
        let url = flow.authorization_url(None, None).unwrap();

        assert!(query_param(&url, "code_challenge").is_none());
        assert!(query_param(&url, "code_challenge_method").is_none());
        assert!(query_param(&url, "state").is_none());
        assert!(flow.pkce_verifier().is_none());
    }

    #[test]
    fn test_authorization_url_custom_scopes() {
        let flow = AuthorizationCodeFlow::new(google_client());
        let scopes = vec!["email".to_string(), "profile".to_string()];
        let url = flow.authorization_url(Some(&scopes), None).unwrap();

        assert_eq!(query_param(&url, "scope").as_deref(), Some("email profile"));
        // Form encoding serializes the separating space as '+'.
        assert!(url.as_str().contains("scope=email+profile"));
    }

    #[test]
    fn test_authorization_url_empty_scopes_omits_scope() {
        let flow = AuthorizationCodeFlow::new(google_client());
        let no_scopes: &[String] = &[];
        let url = flow.authorization_url(Some(no_scopes), None).unwrap();

        assert!(query_param(&url, "scope").is_none());
    }

    #[test]
    fn test_google_specific_params() {
        let flow = AuthorizationCodeFlow::new(google_client());
        let url = flow.authorization_url(None, None).unwrap();

        assert_eq!(query_param(&url, "access_type").as_deref(), Some("offline"));
        assert_eq!(query_param(&url, "prompt").as_deref(), Some("consent"));
    }
}