openauth-oauth 0.0.2

OAuth support for OpenAuth.
Documentation
use std::collections::BTreeMap;

use base64::engine::general_purpose::STANDARD;
use base64::Engine;
use url::form_urlencoded::Serializer;

use super::error::OAuthError;
use super::tokens::{get_primary_client_id, ProviderOptions};

/// Where the client credentials are placed when talking to the token endpoint.
///
/// The default is [`ClientAuthentication::Post`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ClientAuthentication {
    /// Send `client_id` / `client_secret` as fields of the form body.
    Post,
    /// Send the credentials in an `authorization: Basic ...` header
    /// (base64 of `client_id:client_secret`).
    Basic,
}

impl Default for ClientAuthentication {
    /// Form-body authentication is used unless a caller asks otherwise.
    fn default() -> Self {
        Self::Post
    }
}

/// An `application/x-www-form-urlencoded` request body plus the headers to send with it.
///
/// The body is an ordered list of pairs rather than a map, so insertion order and
/// repeated keys survive serialization. Header names written through
/// [`OAuthFormRequest::set_header`] are lowercased, and [`OAuthFormRequest::header`]
/// lowercases its lookup key, so lookups through these two methods are case-insensitive.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OAuthFormRequest {
    /// Form fields in insertion order; duplicate keys are allowed.
    pub body: Vec<(String, String)>,
    /// HTTP headers keyed by name. Entries inserted directly into this map (instead of
    /// via `set_header`) are not lowercased, so `header()` may not find them.
    pub headers: BTreeMap<String, String>,
}

impl Default for OAuthFormRequest {
    /// Identical to [`OAuthFormRequest::new`].
    ///
    /// A derived `Default` would yield an empty header map, silently dropping the
    /// `content-type: application/x-www-form-urlencoded` header that a form-encoded
    /// token request needs; delegating keeps both constructors consistent.
    fn default() -> Self {
        Self::new()
    }
}

impl OAuthFormRequest {
    /// Creates a request with an empty body and two headers preset:
    /// `content-type: application/x-www-form-urlencoded` and `accept: application/json`.
    pub fn new() -> Self {
        Self {
            body: Vec::new(),
            headers: BTreeMap::from([
                (
                    "content-type".to_owned(),
                    "application/x-www-form-urlencoded".to_owned(),
                ),
                ("accept".to_owned(), "application/json".to_owned()),
            ]),
        }
    }

    /// Appends `key=value` to the body, keeping any existing entries with the same key.
    pub fn push_body(&mut self, key: impl Into<String>, value: impl Into<String>) {
        self.body.push((key.into(), value.into()));
    }

    /// Replaces every existing entry for `key` with a single `key=value` pair.
    ///
    /// The new pair is appended at the end of the body, so the field's position moves
    /// to the end even if it was present before.
    pub fn set_body(&mut self, key: impl Into<String>, value: impl Into<String>) {
        let key = key.into();
        self.body.retain(|(existing, _)| existing != &key);
        self.body.push((key, value.into()));
    }

    /// Returns `true` if the body contains at least one entry for `key` (exact match).
    pub fn has_body(&self, key: &str) -> bool {
        self.body.iter().any(|(existing, _)| existing == key)
    }

    /// Returns the first value stored for `key`, or `None` if the key is absent.
    pub fn form_value(&self, key: &str) -> Option<&str> {
        self.body
            .iter()
            .find(|(existing, _)| existing == key)
            .map(|(_, value)| value.as_str())
    }

    /// Returns every value stored for `key`, in insertion order (empty if absent).
    pub fn form_values(&self, key: &str) -> Vec<&str> {
        self.body
            .iter()
            .filter(|(existing, _)| existing == key)
            .map(|(_, value)| value.as_str())
            .collect()
    }

    /// Looks up a header by name, lowercasing `key` first so that it matches names
    /// stored by [`OAuthFormRequest::set_header`].
    pub fn header(&self, key: &str) -> Option<&str> {
        self.headers
            .get(&key.to_ascii_lowercase())
            .map(String::as_str)
    }

    /// Inserts or replaces a header, storing its name lowercased.
    pub fn set_header(&mut self, key: impl Into<String>, value: impl Into<String>) {
        self.headers
            .insert(key.into().to_ascii_lowercase(), value.into());
    }

    /// Serializes the body pairs, in order, as an `application/x-www-form-urlencoded`
    /// string using `url::form_urlencoded::Serializer`.
    pub fn to_form_urlencoded(&self) -> String {
        let mut serializer = Serializer::new(String::new());
        for (key, value) in &self.body {
            serializer.append_pair(key, value);
        }
        serializer.finish()
    }
}

/// Attaches the client's credentials to `request` using the selected `authentication` mode.
///
/// * [`ClientAuthentication::Basic`] sets an `authorization: Basic <base64(id:secret)>`
///   header from the primary client id and `options.client_secret`. A missing id, or a
///   missing secret when `require_secret` is `false`, is encoded as an empty string.
/// * [`ClientAuthentication::Post`] writes `client_id` and `client_secret` into the form
///   body via `set_body` (replacing earlier values); absent credentials are omitted.
///
/// # Errors
///
/// Returns `OAuthError::MissingOption("client_secret")` when `require_secret` is `true`
/// and no client secret is configured. In `Post` mode, `client_id` has already been
/// written to the body when this error is returned.
///
/// NOTE(review): RFC 6749 §2.3.1 asks for the id and secret to be form-urlencoded before
/// being joined for Basic auth; they are used verbatim here — confirm this matches the
/// providers this crate targets.
pub fn apply_client_authentication(
    request: &mut OAuthFormRequest,
    options: &ProviderOptions,
    authentication: ClientAuthentication,
    require_secret: bool,
) -> Result<(), OAuthError> {
    let primary_id = get_primary_client_id(&options.client_id);
    let maybe_secret = options.client_secret.as_deref();
    let missing_secret = || OAuthError::MissingOption("client_secret");

    match authentication {
        ClientAuthentication::Post => {
            // The id goes in first, independently of whether a secret is available.
            if let Some(id) = primary_id {
                request.set_body("client_id", id);
            }
            match maybe_secret {
                Some(secret) => request.set_body("client_secret", secret),
                None if require_secret => return Err(missing_secret()),
                None => {}
            }
        }
        ClientAuthentication::Basic => {
            let password = match maybe_secret {
                Some(secret) => secret,
                None if require_secret => return Err(missing_secret()),
                None => "",
            };
            let username = primary_id.unwrap_or("");
            let encoded = STANDARD.encode(format!("{username}:{password}"));
            request.set_header("authorization", format!("Basic {encoded}"));
        }
    }

    Ok(())
}

/// POSTs `request` to `token_endpoint` and parses the response body as JSON.
///
/// Every entry of `request.headers` is sent as an HTTP header, and the body is
/// `request.to_form_urlencoded()`.
///
/// # Errors
///
/// Fails (via `From<reqwest::Error> for OAuthError`) when:
/// * the request cannot be sent;
/// * the server answers with a status that `reqwest::Response::error_for_status` treats
///   as an error — the response body is not read in that case;
/// * the body is not valid JSON.
///
/// NOTE(review): a fresh `reqwest::Client` is built per call, so connections are not
/// pooled across calls.
pub async fn post_form(
    token_endpoint: &str,
    request: OAuthFormRequest,
) -> Result<serde_json::Value, OAuthError> {
    let http = reqwest::Client::new();
    let with_headers = request
        .headers
        .iter()
        .fold(http.post(token_endpoint), |pending, (name, value)| {
            pending.header(name, value)
        });
    let encoded_body = request.to_form_urlencoded();

    let raw_response = with_headers.body(encoded_body).send().await?;
    let success = raw_response.error_for_status()?;
    let payload = success.json::<serde_json::Value>().await?;
    Ok(payload)
}