use std::borrow::Cow;
use crate::primitives::grant::{GrantExtension, Value};
use base64::{self, engine::general_purpose::URL_SAFE_NO_PAD, Engine};
use sha2::{Digest, Sha256};
use subtle::ConstantTimeEq;
/// Proof Key for Code Exchange (RFC 7636) grant extension.
///
/// Construct with [`Pkce::required`] or [`Pkce::optional`].
#[derive(Debug, Clone)]
pub struct Pkce {
    /// Whether every authorization request must carry a `code_challenge`.
    required: bool,
    /// Whether the `plain` challenge method is accepted in addition to `S256`.
    allow_plain: bool,
}
/// A pkce challenge together with the transformation the client applied to its verifier.
enum Method {
    /// `S256`: the challenge must equal the unpadded URL-safe base64 of SHA-256(verifier).
    Sha256(String),
    /// `plain`: the challenge must equal the verifier itself.
    Plain(String),
}
impl Pkce {
    /// Create a pkce extension that rejects authorization requests without a `code_challenge`.
    ///
    /// Only the `S256` method is accepted unless [`Pkce::allow_plain`] is called.
    pub fn required() -> Pkce {
        Pkce {
            required: true,
            allow_plain: false,
        }
    }

    /// Create a pkce extension for which a `code_challenge` is optional.
    ///
    /// Only the `S256` method is accepted unless [`Pkce::allow_plain`] is called.
    pub fn optional() -> Pkce {
        Pkce {
            required: false,
            allow_plain: false,
        }
    }

    /// Also accept challenges that use the `plain` method.
    pub fn allow_plain(&mut self) {
        self.allow_plain = true;
    }

    /// Validate the `code_challenge_method` and `code_challenge` of an authorization request.
    ///
    /// A missing method defaults to `plain` (RFC 7636 §4.3). On success this returns
    /// `Ok(None)` when no challenge was sent and pkce is optional. Otherwise it returns
    /// `Ok(Some(value))`, where `value` is a private value holding the encoded challenge.
    ///
    /// # Errors
    ///
    /// Fails when pkce is required but no challenge was sent, when the method is neither
    /// `plain` nor `S256`, or when the method is `plain` and plain has not been allowed.
    pub fn challenge(
        &self, method: Option<Cow<str>>, challenge: Option<Cow<str>>,
    ) -> Result<Option<Value>, ()> {
        let method = method.unwrap_or(Cow::Borrowed("plain"));
        let challenge = match challenge {
            None if self.required => return Err(()),
            None => return Ok(None),
            Some(challenge) => challenge,
        };
        let method = Method::from_parameter(method, challenge)?;
        let method = method.assert_supported_method(self.allow_plain)?;
        Ok(Some(Value::private(Some(method.encode()))))
    }

    /// Check the `code_verifier` of a token request against the value produced by
    /// [`Pkce::challenge`].
    ///
    /// Succeeds without any check only when pkce is optional and neither a stored
    /// challenge nor a verifier is present.
    ///
    /// # Errors
    ///
    /// Fails in any of these cases:
    /// - pkce is required but no challenge was stored;
    /// - a challenge was stored but no verifier was sent;
    /// - a verifier was sent but no challenge was stored (PKCE downgrade);
    /// - the stored value cannot be decoded;
    /// - the verifier does not match the challenge.
    pub fn verify(&self, method: Option<Value>, verifier: Option<Cow<str>>) -> Result<(), ()> {
        let (method, verifier) = match (method, verifier) {
            (None, _) if self.required => return Err(()),
            // Pkce was not used for this grant, and the client does not claim otherwise.
            (None, None) => return Ok(()),
            // A verifier without a stored challenge must be rejected. Accepting it would let a
            // code issued without pkce be redeemed by a pkce-using client, which defeats pkce's
            // protection against code injection (RFC 9700 §4.8.2, PKCE downgrade).
            (None, Some(_)) => return Err(()),
            // A challenge was stored, but the client did not present the matching verifier.
            (Some(_), None) => return Err(()),
            (Some(method), Some(verifier)) => (method, verifier),
        };
        let method = match method.into_private_value() {
            Ok(Some(method)) => method,
            _ => return Err(()),
        };
        let method = Method::from_encoded(Cow::Owned(method))?;
        method.verify(&verifier)
    }
}
impl GrantExtension for Pkce {
    /// The fixed name identifying this extension.
    ///
    /// NOTE(review): presumably used as the key for this extension's grant data; confirm
    /// against the `GrantExtension` trait.
    fn identifier(&self) -> &'static str {
        const IDENTIFIER: &str = "pkce";
        IDENTIFIER
    }
}
fn b64encode(data: &[u8]) -> String {
URL_SAFE_NO_PAD.encode(data)
}
impl Method {
fn from_parameter(method: Cow<str>, challenge: Cow<str>) -> Result<Self, ()> {
match method.as_ref() {
"plain" => Ok(Method::Plain(challenge.into_owned())),
"S256" => Ok(Method::Sha256(challenge.into_owned())),
_ => Err(()),
}
}
fn assert_supported_method(self, allow_plain: bool) -> Result<Self, ()> {
match (self, allow_plain) {
(this, true) => Ok(this),
(Method::Sha256(content), false) => Ok(Method::Sha256(content)),
(Method::Plain(_), false) => Err(()),
}
}
fn encode(self) -> String {
match self {
Method::Plain(challenge) => challenge + "p",
Method::Sha256(challenge) => challenge + "S",
}
}
fn from_encoded(encoded: Cow<str>) -> Result<Method, ()> {
let mut encoded = encoded.into_owned();
match encoded.pop() {
None => Err(()),
Some('p') => Ok(Method::Plain(encoded)),
Some('S') => Ok(Method::Sha256(encoded)),
_ => Err(()),
}
}
fn verify(&self, verifier: &str) -> Result<(), ()> {
match self {
Method::Plain(encoded) => {
if encoded.as_bytes().ct_eq(verifier.as_bytes()).into() {
Ok(())
} else {
Err(())
}
}
Method::Sha256(encoded) => {
let mut hasher = Sha256::new();
hasher.update(verifier.as_bytes());
let b64digest = b64encode(&hasher.finalize());
if encoded.as_bytes().ct_eq(b64digest.as_bytes()).into() {
Ok(())
} else {
Err(())
}
}
}
}
}