use std::borrow::Cow;
use primitives::grant::{GrantExtension, Value};
use base64;
use ring::digest::{SHA256, digest};
use ring::constant_time::verify_slices_are_equal;
/// Grant extension enforcing Proof Key for Code Exchange (PKCE, RFC 7636).
///
/// Construct it with `Pkce::required()` or `Pkce::optional()`, and optionally enable the
/// `plain` method with `Pkce::allow_plain()`.
#[derive(Clone, Debug)]
pub struct Pkce {
    /// Whether every authorization request must carry a `code_challenge`.
    required: bool,
    /// Whether the `plain` challenge method is accepted in addition to `S256`.
    allow_plain: bool,
}
/// A PKCE challenge together with the transformation the client declared for it.
enum Method {
    /// `S256`: the stored challenge is compared with the unpadded URL-safe base64 encoding
    /// of SHA-256(verifier).
    Sha256(String),
    /// `plain`: the stored challenge is compared with the verifier byte for byte.
    Plain(String),
}
impl Pkce {
    /// A PKCE extension that demands a code challenge on every authorization request.
    ///
    /// Only the `S256` method is accepted unless [`Pkce::allow_plain`] is called afterwards.
    pub fn required() -> Pkce {
        Pkce {
            required: true,
            allow_plain: false,
        }
    }

    /// A PKCE extension under which the code challenge may be omitted.
    ///
    /// A challenge that *is* supplied is still validated and later enforced. Only the `S256`
    /// method is accepted unless [`Pkce::allow_plain`] is called afterwards.
    pub fn optional() -> Pkce {
        Pkce {
            required: false,
            allow_plain: false,
        }
    }

    /// Additionally accept the `plain` challenge method.
    ///
    /// With `plain` the challenge is the verifier itself. RFC 7636 §4.2 permits clients to
    /// use it only when they cannot support `S256`.
    pub fn allow_plain(&mut self) {
        self.allow_plain = true;
    }

    /// Validate the PKCE parameters of an authorization request.
    ///
    /// `method` is the `code_challenge_method` parameter. It defaults to `"plain"` when
    /// absent, as RFC 7636 §4.3 prescribes. `challenge` is the `code_challenge` parameter.
    ///
    /// Returns `Ok(None)` when no challenge was supplied and PKCE is optional. Otherwise it
    /// returns `Ok(Some(value))`, where `value` is built with `Value::private` and encodes
    /// both the method and the challenge so that [`Pkce::verify`] can check them later.
    ///
    /// # Errors
    ///
    /// Returns `Err(())` in any of these cases:
    /// * a challenge is required but missing;
    /// * the method is neither `plain` nor `S256`;
    /// * the method is `plain` and plain has not been allowed.
    pub fn challenge(&self, method: Option<Cow<str>>, challenge: Option<Cow<str>>)
        -> Result<Option<Value>, ()>
    {
        let method = method.unwrap_or(Cow::Borrowed("plain"));

        let challenge = match challenge {
            None if self.required => return Err(()),
            None => return Ok(None),
            Some(challenge) => challenge,
        };

        let method = Method::from_parameter(method, challenge)?;
        let method = method.assert_supported_method(self.allow_plain)?;
        Ok(Some(Value::private(Some(method.encode()))))
    }

    /// Check the `code_verifier` of a token request against the data stored by
    /// [`Pkce::challenge`].
    ///
    /// `method` is the extension value previously produced by `challenge`, if any.
    /// `verifier` is the `code_verifier` parameter of the token request.
    ///
    /// # Errors
    ///
    /// Returns `Err(())` when any of the following holds:
    /// * PKCE is required but no challenge was stored;
    /// * a challenge was stored but no verifier was presented;
    /// * a verifier was presented although no challenge was stored (PKCE downgrade
    ///   mitigation, RFC 9700 §4.8.2);
    /// * the stored value is not a private value with content, or cannot be decoded;
    /// * the verifier does not match the stored challenge.
    pub fn verify(&self, method: Option<Value>, verifier: Option<Cow<str>>) -> Result<(), ()>
    {
        let (method, verifier) = match (method, verifier) {
            // A challenge is mandatory but none was recorded during authorization.
            (None, _) if self.required => return Err(()),
            // Neither request used PKCE; nothing to check.
            (None, None) => return Ok(()),
            // A verifier without a recorded challenge indicates the challenge was stripped
            // from the authorization request (downgrade attack). Reject it rather than
            // silently ignoring the verifier.
            (None, Some(_)) => return Err(()),
            // A challenge was agreed upon but the client did not prove possession.
            (Some(_), None) => return Err(()),
            (Some(method), Some(verifier)) => (method, verifier),
        };

        let method = match method.as_private() {
            Ok(Some(method)) => method,
            _ => return Err(()),
        };

        let method = Method::from_encoded(Cow::Owned(method))?;
        method.verify(&verifier)
    }
}
impl GrantExtension for Pkce {
    /// Name identifying this extension, always `"pkce"`.
    ///
    /// NOTE(review): presumably used as the key for this extension's `Value` in a grant;
    /// confirm against the `GrantExtension` trait definition.
    fn identifier(&self) -> &'static str {
        const IDENTIFIER: &str = "pkce";
        IDENTIFIER
    }
}
/// Encode `data` as URL-safe base64 without `=` padding.
///
/// This is the encoding used when comparing an `S256` challenge with a hashed verifier.
fn b64encode(data: &[u8]) -> String {
    let url_safe_unpadded = base64::URL_SAFE_NO_PAD;
    base64::encode_config(data, url_safe_unpadded)
}
impl Method {
    /// Build a `Method` from the `code_challenge_method` and `code_challenge` parameters.
    ///
    /// Only the exact method names `"plain"` and `"S256"` are recognised. Any other name
    /// yields `Err(())`.
    fn from_parameter(method: Cow<str>, challenge: Cow<str>) -> Result<Self, ()> {
        let build: fn(String) -> Method = match &*method {
            "plain" => Method::Plain,
            "S256" => Method::Sha256,
            _ => return Err(()),
        };
        Ok(build(challenge.into_owned()))
    }

    /// Reject the `plain` method unless `allow_plain` is set; otherwise return `self`.
    fn assert_supported_method(self, allow_plain: bool) -> Result<Self, ()> {
        match self {
            Method::Plain(_) if !allow_plain => Err(()),
            supported => Ok(supported),
        }
    }

    /// Serialise into a single string: the challenge followed by a one-character tag.
    ///
    /// The tag is `'p'` for plain and `'S'` for SHA-256; `from_encoded` strips it off again.
    fn encode(self) -> String {
        let (mut encoded, tag) = match self {
            Method::Plain(challenge) => (challenge, 'p'),
            Method::Sha256(challenge) => (challenge, 'S'),
        };
        encoded.push(tag);
        encoded
    }

    /// Inverse of `encode`: remove the trailing tag character and rebuild the method.
    ///
    /// Returns `Err(())` for an empty string or an unrecognised tag.
    fn from_encoded(encoded: Cow<str>) -> Result<Method, ()> {
        let mut challenge = encoded.into_owned();
        let tag = challenge.pop();
        match tag {
            Some('p') => Ok(Method::Plain(challenge)),
            Some('S') => Ok(Method::Sha256(challenge)),
            Some(_) | None => Err(()),
        }
    }

    /// Check a `code_verifier` against the stored challenge.
    ///
    /// For plain, the verifier must equal the challenge. For SHA-256, the challenge must
    /// equal the unpadded URL-safe base64 encoding of SHA-256(verifier).
    fn verify(&self, verifier: &str) -> Result<(), ()> {
        match self {
            Method::Plain(challenge) => Self::slices_match(challenge, verifier),
            Method::Sha256(challenge) => {
                let hashed = digest(&SHA256, verifier.as_bytes());
                Self::slices_match(challenge, &b64encode(hashed.as_ref()))
            }
        }
    }

    /// Compare two strings with `ring::constant_time::verify_slices_are_equal`, discarding
    /// ring's error type.
    fn slices_match(stored: &str, candidate: &str) -> Result<(), ()> {
        verify_slices_are_equal(stored.as_bytes(), candidate.as_bytes()).map_err(|_| ())
    }
}