openauth-oauth 0.0.2

OAuth support for OpenAuth.
Documentation
use std::collections::BTreeMap;

use serde_json::json;
use url::Url;

use super::error::OAuthError;
use super::tokens::{get_primary_client_id, ProviderOptions};
use super::utils::generate_code_challenge;

/// Inputs for [`create_authorization_url`]: everything needed to build a provider's
/// OAuth 2.0 authorization-request URL.
///
/// `Default` yields empty strings and collections, every optional parameter unset, and a
/// single-space `scope_joiner`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AuthorizationUrlRequest {
    /// Provider identifier. Not read by `create_authorization_url`.
    pub id: String,
    /// Provider configuration. Its `client_id` supplies the `client_id` parameter (via
    /// `get_primary_client_id`). Its `authorization_endpoint` / `redirect_uri`, when set, take
    /// precedence over the fields of the same name on this struct.
    pub options: ProviderOptions,
    /// Authorization endpoint used when `options.authorization_endpoint` is `None`.
    pub authorization_endpoint: String,
    /// Redirect URI used when `options.redirect_uri` is `None`.
    pub redirect_uri: String,
    /// Value sent as the `state` parameter.
    pub state: String,
    /// PKCE code verifier. When set, the challenge derived by `generate_code_challenge` is sent
    /// as `code_challenge`, together with `code_challenge_method=S256`.
    pub code_verifier: Option<String>,
    /// Requested scopes, joined with `scope_joiner`. The `scope` parameter is omitted when this
    /// is empty.
    pub scopes: Vec<String>,
    /// Extra ID-token claim names. When non-empty, a `claims` JSON parameter is sent that
    /// requests these names plus `email` and `email_verified` under `id_token`.
    pub claims: Vec<String>,
    /// Sent as `duration` when set.
    pub duration: Option<String>,
    /// Sent as `prompt` when set.
    pub prompt: Option<String>,
    /// Sent as `access_type` when set.
    pub access_type: Option<String>,
    /// Sent as `response_type`; `"code"` is used when this is `None`.
    pub response_type: Option<String>,
    /// Sent as `display` when set.
    pub display: Option<String>,
    /// Sent as `login_hint` when set.
    pub login_hint: Option<String>,
    /// Sent as `hd` when set.
    pub hd: Option<String>,
    /// Sent as `response_mode` when set.
    pub response_mode: Option<String>,
    /// Extra query parameters, applied last. Each entry replaces any earlier parameter with the
    /// same name.
    pub additional_params: BTreeMap<String, String>,
    /// Separator placed between `scopes` entries when building the `scope` parameter.
    pub scope_joiner: String,
}

impl Default for AuthorizationUrlRequest {
    /// Returns an empty request.
    ///
    /// Every string and collection is empty and every optional query parameter is unset. The
    /// scope separator is a single space, because the OAuth 2.0 `scope` parameter is
    /// space-delimited (RFC 6749 §3.3).
    fn default() -> Self {
        Self {
            // Separator for the `scope` parameter.
            scope_joiner: String::from(" "),

            // Provider configuration and required string inputs.
            options: ProviderOptions::default(),
            id: String::default(),
            authorization_endpoint: String::default(),
            redirect_uri: String::default(),
            state: String::default(),

            // Collections start out empty.
            scopes: Vec::default(),
            claims: Vec::default(),
            additional_params: BTreeMap::default(),

            // Optional query parameters start out unset.
            code_verifier: None,
            duration: None,
            prompt: None,
            access_type: None,
            response_type: None,
            display: None,
            login_hint: None,
            hd: None,
            response_mode: None,
        }
    }
}

pub fn create_authorization_url(input: AuthorizationUrlRequest) -> Result<Url, OAuthError> {
    let endpoint = input
        .options
        .authorization_endpoint
        .as_deref()
        .unwrap_or(&input.authorization_endpoint);
    let mut url = Url::parse(endpoint)?;
    let client_id = get_primary_client_id(&input.options.client_id)
        .ok_or(OAuthError::MissingOption("client_id"))?;
    {
        let mut query = url.query_pairs_mut();
        query.append_pair(
            "response_type",
            input.response_type.as_deref().unwrap_or("code"),
        );
        query.append_pair("client_id", client_id);
        query.append_pair("state", &input.state);
        if !input.scopes.is_empty() {
            query.append_pair("scope", &input.scopes.join(&input.scope_joiner));
        }
        query.append_pair(
            "redirect_uri",
            input
                .options
                .redirect_uri
                .as_deref()
                .unwrap_or(&input.redirect_uri),
        );
        append_optional(&mut query, "duration", input.duration.as_deref());
        append_optional(&mut query, "display", input.display.as_deref());
        append_optional(&mut query, "login_hint", input.login_hint.as_deref());
        append_optional(&mut query, "prompt", input.prompt.as_deref());
        append_optional(&mut query, "hd", input.hd.as_deref());
        append_optional(&mut query, "access_type", input.access_type.as_deref());
        append_optional(&mut query, "response_mode", input.response_mode.as_deref());
        if let Some(code_verifier) = input.code_verifier {
            query.append_pair("code_challenge_method", "S256");
            query.append_pair("code_challenge", &generate_code_challenge(&code_verifier)?);
        }
        if !input.claims.is_empty() {
            let mut id_token = serde_json::Map::from_iter([
                ("email".to_owned(), serde_json::Value::Null),
                ("email_verified".to_owned(), serde_json::Value::Null),
            ]);
            for claim in input.claims {
                id_token.insert(claim, serde_json::Value::Null);
            }
            query.append_pair("claims", &json!({ "id_token": id_token }).to_string());
        }
    }
    if !input.additional_params.is_empty() {
        let mut pairs = url.query_pairs().into_owned().collect::<Vec<_>>();
        for (key, value) in input.additional_params {
            pairs.retain(|(existing, _)| existing != &key);
            pairs.push((key, value));
        }
        url.set_query(None);
        for (key, value) in pairs {
            url.query_pairs_mut().append_pair(&key, &value);
        }
    }
    Ok(url)
}

fn append_optional(
    query: &mut url::form_urlencoded::Serializer<'_, url::UrlQuery<'_>>,
    key: &str,
    value: Option<&str>,
) {
    if let Some(value) = value {
        query.append_pair(key, value);
    }
}