use super::{OAuthClient, PkceChallenge};
use crate::error::Result;
use crate::token::Token;
use url::Url;
/// Authorization-code flow built around an [`OAuthClient`], optionally carrying a
/// [`PkceChallenge`].
///
/// The flow builds the authorization URL and exchanges the returned code for a
/// [`Token`] by delegating to the wrapped client.
#[derive(Debug)]
pub struct AuthorizationCodeFlow {
/// Client whose provider settings, client id and redirect URI are read when building requests.
client: OAuthClient,
/// PKCE challenge; `None` after [`AuthorizationCodeFlow::new`], `Some` after
/// [`AuthorizationCodeFlow::with_pkce`].
pkce: Option<PkceChallenge>,
}
impl AuthorizationCodeFlow {
    /// Creates a flow for `client` with PKCE disabled.
    #[must_use]
    pub const fn new(client: OAuthClient) -> Self {
        Self { pkce: None, client }
    }

    /// Enables PKCE by storing a challenge produced by [`PkceChallenge::generate`].
    ///
    /// A challenge already stored on the flow is replaced.
    #[must_use]
    pub fn with_pkce(mut self) -> Self {
        let challenge = PkceChallenge::generate();
        self.pkce = Some(challenge);
        self
    }

    /// Builds the URL the user agent should be sent to in order to start the flow.
    ///
    /// Starting from a copy of the provider's `auth_url`, the following query
    /// parameters are appended, in this order:
    ///
    /// 1. `client_id` and `response_type=code`;
    /// 2. `redirect_uri`, when the client has one configured;
    /// 3. `scope`, the space-joined `scopes` (or the provider's default scopes when
    ///    `scopes` is `None`), omitted when the joined string is empty;
    /// 4. `state`, when given;
    /// 5. `code_challenge` and `code_challenge_method`, when PKCE is enabled;
    /// 6. provider-specific extras selected by the provider's name.
    ///
    /// # Errors
    ///
    /// The body shown here has no failing path and always returns `Ok`.
    pub fn authorization_url(&self, scopes: Option<&[String]>, state: Option<&str>) -> Result<Url> {
        let client = &self.client;
        let provider = &client.provider;

        // An explicit scope list replaces the provider defaults even when it is empty.
        let scope_param = scopes.map_or_else(
            || provider.default_scopes.join(" "),
            |requested| requested.join(" "),
        );

        // Extra parameters keyed on the exact (case-sensitive) provider name.
        let provider_extras: &[(&str, &str)] = match provider.name.as_str() {
            "Google" => &[("access_type", "offline"), ("prompt", "consent")],
            "Microsoft" => &[("prompt", "consent")],
            _ => &[],
        };

        let mut url = provider.auth_url.clone();
        {
            // The serializer mutably borrows `url`; the inner scope ends that borrow
            // before `url` is returned.
            let mut query = url.query_pairs_mut();
            query.append_pair("client_id", &client.client_id);
            query.append_pair("response_type", "code");

            if let Some(redirect_uri) = client.redirect_uri.as_ref() {
                query.append_pair("redirect_uri", redirect_uri);
            }
            if !scope_param.is_empty() {
                query.append_pair("scope", &scope_param);
            }
            if let Some(state_val) = state {
                query.append_pair("state", state_val);
            }
            if let Some(challenge) = self.pkce.as_ref() {
                query.append_pair("code_challenge", challenge.challenge());
                query.append_pair("code_challenge_method", challenge.method());
            }
            for &(key, value) in provider_extras {
                query.append_pair(key, value);
            }
        }

        Ok(url)
    }

    /// Exchanges an authorization `code` for a [`Token`] through the wrapped client.
    ///
    /// The PKCE verifier is passed along when PKCE is enabled, `None` otherwise.
    ///
    /// NOTE(review): `redirect_uri` is forwarded exactly as given; this method does
    /// not substitute the client's configured redirect URI when it is `None`.
    /// Confirm whether `OAuthClient::exchange_code` applies such a fallback.
    ///
    /// # Errors
    ///
    /// Returns whatever error `OAuthClient::exchange_code` produces.
    pub async fn exchange_code(&self, code: &str, redirect_uri: Option<&str>) -> Result<Token> {
        let verifier = self.pkce_verifier();
        self.client.exchange_code(code, redirect_uri, verifier).await
    }

    /// Returns the PKCE verifier, or `None` when PKCE has not been enabled.
    #[must_use]
    pub fn pkce_verifier(&self) -> Option<&str> {
        let challenge = self.pkce.as_ref()?;
        Some(challenge.verifier())
    }
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::redundant_clone, clippy::manual_string_new, clippy::needless_collect, clippy::unreadable_literal, clippy::used_underscore_items, clippy::similar_names)]
mod tests {
    use super::*;
    use crate::provider::Provider;

    /// Client with id `test_client` on the Google provider preset, shared by every test.
    fn google_client() -> OAuthClient {
        OAuthClient::new("test_client", Provider::google().unwrap())
    }

    /// Base parameters, `state` and the percent-encoded redirect URI all show up.
    #[test]
    fn test_authorization_url() {
        let flow =
            AuthorizationCodeFlow::new(google_client().with_redirect_uri("http://localhost:8080"));

        let url = flow.authorization_url(None, Some("random_state")).unwrap();
        let rendered = url.as_str();

        assert!(rendered.contains("client_id=test_client"));
        assert!(rendered.contains("response_type=code"));
        assert!(rendered.contains("state=random_state"));
        assert!(rendered.contains("redirect_uri=http%3A%2F%2Flocalhost%3A8080"));
    }

    /// Enabling PKCE adds the challenge parameters and exposes a verifier.
    #[test]
    fn test_authorization_url_with_pkce() {
        let flow = AuthorizationCodeFlow::new(google_client()).with_pkce();

        let url = flow.authorization_url(None, None).unwrap();
        let rendered = url.as_str();

        assert!(rendered.contains("code_challenge="));
        assert!(rendered.contains("code_challenge_method=S256"));
        assert!(flow.pkce_verifier().is_some());
    }

    /// Explicit scopes are space-joined (form-encoded as `+`) into `scope`.
    #[test]
    fn test_authorization_url_custom_scopes() {
        let flow = AuthorizationCodeFlow::new(google_client());
        let scopes = ["email", "profile"].map(String::from);

        let url = flow.authorization_url(Some(&scopes[..]), None).unwrap();

        assert!(url.as_str().contains("scope=email+profile"));
    }

    /// The Google provider gets `access_type=offline` and `prompt=consent`.
    #[test]
    fn test_google_specific_params() {
        let flow = AuthorizationCodeFlow::new(google_client());

        let url = flow.authorization_url(None, None).unwrap();
        let rendered = url.as_str();

        assert!(rendered.contains("access_type=offline"));
        assert!(rendered.contains("prompt=consent"));
    }
}