use url::Url;
use crate::{
config::{AuthEndPoint, ClientID, Config, RedirectURI},
csrf_token::{CSRFToken, RawCSRFToken},
error::Error,
nonce::Nonce,
};
use std::{collections::HashMap, iter::Iterator};
#[derive(Debug, Clone, PartialEq)]
pub struct Code(pub(crate) String);
impl Code {
pub fn new_with_verify_csrf(res: RawCodeResponse, csrf_token_val: &str) -> Result<Self, Error> {
if res.state.0 == csrf_token_val {
Ok(res.code)
} else {
Err(Error::CSRFNotMatch)
}
}
}
/// Parameters of an authorization request for the authorization-code flow.
///
/// Built with [`CodeRequest::new`], which borrows the endpoint, client id and
/// redirect URI from a [`Config`], and turned into the authorization URL with
/// [`CodeRequest::try_into_url`].
#[derive(Debug, Clone)]
pub struct CodeRequest<'a> {
    /// Authorization endpoint the request parameters are attached to.
    auth_endpoint: &'a AuthEndPoint,
    /// Sent as the `client_id` parameter.
    client_id: &'a ClientID,
    /// Sent as the `response_type` parameter; `new` always sets it to `"code"`.
    response_type: &'a str,
    /// Scopes requested in addition to `openid`.
    scope: AdditionalScope,
    /// Sent as the `redirect_uri` parameter.
    redirect_uri: &'a RedirectURI,
    /// Sent as the `access_type` parameter (`online` / `offline`).
    access_type: AccessType,
    /// CSRF token sent as the `state` parameter.
    state: &'a CSRFToken,
    /// Sent as the `nonce` parameter.
    nonce: &'a Nonce,
}
impl<'a> CodeRequest<'a> {
pub fn new(
access_type: AccessType,
config: &'a Config,
scope: AdditionalScope,
state: &'a CSRFToken,
nonce: &'a Nonce,
) -> Self {
Self {
auth_endpoint: &config.auth_endpoint,
client_id: &config.client_id,
response_type: "code",
scope,
redirect_uri: &config.redirect_uri,
access_type,
state,
nonce,
}
}
pub fn try_into_url(self) -> Result<Url, Error> {
let access_type = match self.access_type {
AccessType::Online => "online",
AccessType::Offline => "offline",
};
let scope = match self.scope {
AdditionalScope::Email => "openid email",
AdditionalScope::Profile => "openid profile",
AdditionalScope::Both => "openid email profile",
AdditionalScope::None => "openid",
};
let url = format!(
"{}?response_type={}&client_id={}&scope={}&access_type={}&redirect_uri={}&state={}&nonce={}",
self.auth_endpoint.0,
self.response_type,
self.client_id.0,
scope,
access_type,
self.redirect_uri.0,
self.state.0,
self.nonce.0,
);
let url = Url::parse(&url).map_err(|_| Error::ParseURL)?;
Ok(url)
}
}
#[derive(Debug, Clone)]
pub struct RawCodeResponse {
state: RawCSRFToken,
code: Code,
}
impl RawCodeResponse {
pub fn new<Q>(query_src: Q) -> Result<Self, Error>
where
Q: QueryExtractor,
{
let query_str = query_src.extract_query().ok_or(Error::ParamsNotFound)?;
let params: HashMap<_, _> = url::form_urlencoded::parse(query_str.as_bytes()).collect();
Ok(Self {
state: RawCSRFToken(
params
.get("state")
.ok_or(Error::ParamsNotFound)?
.to_string(),
),
code: Code(params.get("code").ok_or(Error::ParamsNotFound)?.to_string()),
})
}
pub fn exchange_with_code(self, csrf_token_val: &str) -> Result<Code, Error> {
if self.state.0 == csrf_token_val {
Ok(self.code)
} else {
Err(Error::CSRFNotMatch)
}
}
}
/// Source of a raw URL query string.
///
/// [`RawCodeResponse::new`] reads the redirect parameters through this trait and
/// form-decodes the returned text itself.
pub trait QueryExtractor {
    /// Returns the query part of the URL, or `None` when there is none.
    ///
    /// Presumably without the leading `?` and still percent-encoded, as
    /// `http::Uri::query` returns it — confirm for other implementors.
    fn extract_query(&self) -> Option<&str>;
}

/// Uses the query component of the request URI.
impl<T> QueryExtractor for http::Request<T> {
    fn extract_query(&self) -> Option<&str> {
        let uri = self.uri();
        uri.query()
    }
}

/// Forwards to the referenced extractor, so a borrowed request can be passed
/// wherever an extractor is expected.
impl<T> QueryExtractor for &T
where
    T: QueryExtractor + ?Sized,
{
    fn extract_query(&self) -> Option<&str> {
        T::extract_query(*self)
    }
}
/// Value of the `access_type` parameter sent with the authorization request.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum AccessType {
    /// Sent as `access_type=online`.
    Online,
    /// Sent as `access_type=offline`.
    Offline,
}

/// Scopes requested on top of the always-requested `openid` scope.
///
/// Fieldless, so it derives `Eq` and `Hash` like [`AccessType`] and can be used as a
/// map key or in `Eq`-bounded generic code.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum AdditionalScope {
    /// Requests `openid email`.
    Email,
    /// Requests `openid profile`.
    Profile,
    /// Requests `openid email profile`.
    Both,
    /// Requests only `openid`.
    None,
}
#[cfg(test)]
mod tests {
    use url::Url;

    use crate::{
        code::AccessType,
        config::{Config, ConfigBuilder},
        csrf_token::CSRFToken,
        error::Error,
        nonce::Nonce,
    };

    use super::{AdditionalScope, CodeRequest, RawCodeResponse};

    const AUTH_ENDPOINT: &str = "https://auth.example.com/auth";
    const CLIENT_ID: &str = "my_client_id";
    const CLIENT_SECRET: &str = "my_secret";
    const TOKEN_ENDPOINT: &str = "https://token.example.com";
    const REDIRECT_URI: &str = "https://redirect.example.com";

    /// Configuration shared by every test in this module.
    fn test_config() -> Config {
        ConfigBuilder::new()
            .auth_endpoint(AUTH_ENDPOINT)
            .client_id(CLIENT_ID)
            .client_secret(CLIENT_SECRET)
            .token_endpoint(TOKEN_ENDPOINT)
            .redirect_uri(REDIRECT_URI)
            .build()
    }

    /// Builds a request and checks that every input was wired into the right field.
    fn check_code_request_fields(access_type: AccessType, scope: AdditionalScope) {
        let config = test_config();
        let state = CSRFToken::new().unwrap();
        let nonce = Nonce::new();
        let code_req =
            CodeRequest::new(access_type.clone(), &config, scope.clone(), &state, &nonce);
        assert_eq!(code_req.access_type, access_type);
        assert_eq!(code_req.auth_endpoint.0, AUTH_ENDPOINT);
        assert_eq!(code_req.client_id.0, CLIENT_ID);
        assert_eq!(code_req.redirect_uri.0, REDIRECT_URI);
        assert_eq!(code_req.response_type, "code");
        assert_eq!(*code_req.state, state);
        assert_eq!(*code_req.nonce, nonce);
        assert_eq!(code_req.scope, scope);
    }

    /// Builds the authorization URL and checks its endpoint and decoded parameters.
    ///
    /// Parameters are compared after form-decoding, so the check pins the values the
    /// authorization server will see rather than one particular percent-encoding.
    fn check_code_request_url(
        access_type: AccessType,
        expected_access_type: &str,
        scope: AdditionalScope,
        expected_scope: &str,
    ) {
        let config = test_config();
        let state = CSRFToken::new().unwrap();
        let nonce = Nonce::new();
        let url: Url = CodeRequest::new(access_type, &config, scope, &state, &nonce)
            .try_into_url()
            .unwrap();

        assert_eq!(url.scheme(), "https");
        assert_eq!(url.host_str(), Some("auth.example.com"));
        assert_eq!(url.path(), "/auth");
        assert_eq!(url.fragment(), None);

        let expected_pairs: [(&str, &str); 7] = [
            ("response_type", "code"),
            ("client_id", CLIENT_ID),
            ("scope", expected_scope),
            ("access_type", expected_access_type),
            ("redirect_uri", REDIRECT_URI),
            ("state", &state.0),
            ("nonce", &nonce.0),
        ];
        let expected: Vec<(String, String)> = expected_pairs
            .iter()
            .map(|(key, value)| (key.to_string(), value.to_string()))
            .collect();
        let actual: Vec<(String, String)> = url.query_pairs().into_owned().collect();
        assert_eq!(actual, expected);
    }

    /// Wraps `uri` in a body-less HTTP request.
    fn request_for(uri: &str) -> http::Request<()> {
        http::Request::builder().uri(uri).body(()).unwrap()
    }

    #[test]
    fn test_code_req_offline() {
        check_code_request_fields(AccessType::Offline, AdditionalScope::Both);
    }

    #[test]
    fn test_code_req_new_some_scope() {
        check_code_request_fields(AccessType::Online, AdditionalScope::Both);
    }

    #[test]
    fn test_code_req_new_none_scope() {
        check_code_request_fields(AccessType::Online, AdditionalScope::None);
    }

    #[test]
    fn test_code_req_into_url() {
        check_code_request_url(
            AccessType::Online,
            "online",
            AdditionalScope::Both,
            "openid email profile",
        );
    }

    #[test]
    fn test_code_req_into_url_scope_one() {
        check_code_request_url(AccessType::Online, "online", AdditionalScope::Email, "openid email");
    }

    #[test]
    fn test_code_req_into_url_scope_profile() {
        check_code_request_url(
            AccessType::Online,
            "online",
            AdditionalScope::Profile,
            "openid profile",
        );
    }

    #[test]
    fn test_code_req_into_url_scope_none() {
        check_code_request_url(AccessType::Online, "online", AdditionalScope::None, "openid");
    }

    #[test]
    fn test_code_req_into_url_offline() {
        check_code_request_url(AccessType::Offline, "offline", AdditionalScope::None, "openid");
    }

    #[test]
    fn test_construct_uncheck_code_res() {
        let code = "mycode";
        let state = "mystate";
        let uri = format!("https://www.example.com/auth?code={}&state={}", code, state);
        let raw_code_res = RawCodeResponse::new(request_for(&uri)).unwrap();
        assert_eq!(raw_code_res.state.0, state);
        assert_eq!(raw_code_res.code.0, code);
    }

    #[test]
    fn test_construct_uncheck_code_res_decodes_values() {
        let http_req = request_for("https://www.example.com/auth?code=a%2Fb%3Dc&state=x+y");
        // Passing a reference exercises the blanket `QueryExtractor for &T` impl.
        let raw_code_res = RawCodeResponse::new(&http_req).unwrap();
        assert_eq!(raw_code_res.state.0, "x y");
        assert_eq!(raw_code_res.code.0, "a/b=c");
    }

    #[test]
    fn test_construct_uncheck_code_res_none_params() {
        let raw_code_res = RawCodeResponse::new(request_for("https://www.example.com/"));
        assert!(raw_code_res.is_err());
    }

    #[test]
    fn test_construct_uncheck_code_res_missing_state() {
        let raw_code_res =
            RawCodeResponse::new(request_for("https://www.example.com/auth?code=mycode"));
        assert!(raw_code_res.is_err());
    }

    #[test]
    fn test_exchange_with_code() {
        let http_req = request_for("https://www.example.com/auth?code=mycode&state=mystate");
        let raw_code_res = RawCodeResponse::new(&http_req).unwrap();
        assert!(matches!(
            raw_code_res.clone().exchange_with_code("otherstate"),
            Err(Error::CSRFNotMatch)
        ));
        assert_eq!(raw_code_res.exchange_with_code("mystate").unwrap().0, "mycode");
    }
}