use std::{
borrow::Cow,
fmt,
hash::{Hash, Hasher},
};
use const_macros::{const_map_err, const_none, const_ok, const_try};
use constant_time_eq::constant_time_eq;
#[cfg(feature = "static")]
use into_static::IntoStatic;
#[cfg(feature = "diagnostics")]
use miette::Diagnostic;
#[cfg(feature = "serde")]
use serde::{Deserialize, Deserializer, Serialize, Serializer, de};
use thiserror::Error;
use crate::{
challenge::Challenge,
check::string::{self, const_check_str},
count::{self, Count},
encoding, generate,
length::{self, Length},
method::Method,
};
/// Panic message used by the `const_borrowed_verifier!` macro when the given string
/// is rejected by `Verifier::const_borrowed_ok`.
pub const ERROR: &str = "invalid verifier; check the length and characters";
/// Errors returned when a string fails verifier validation.
///
/// Each variant wraps the error of the check that failed, and can be created from it
/// via `From` (see the `#[from]` attributes).
#[derive(Debug, Error)]
#[cfg_attr(feature = "diagnostics", derive(Diagnostic))]
pub enum Error {
/// The length check failed; wraps the underlying `length::Error`.
#[error("invalid verifier length")]
#[cfg_attr(
feature = "diagnostics",
diagnostic(
code(pkce_std::verifier::length),
help("check the length of the verifier")
)
)]
Length(#[from] length::Error),
/// The character check failed; wraps the underlying `string::Error`.
#[error("verifier contains invalid character(s)")]
#[cfg_attr(
feature = "diagnostics",
diagnostic(
code(pkce_std::verifier::check),
help("make sure the verifier is composed of valid characters only")
)
)]
String(#[from] string::Error),
}
/// A PKCE code verifier, holding either a borrowed or an owned string.
///
/// The safe constructors (`new`, `borrowed`, `owned`, `const_borrowed`) run the length
/// and character checks before building a value; the `*_unchecked` constructors skip them.
///
/// Note that the derived `Debug` implementation prints the underlying string.
#[derive(Debug, Clone)]
pub struct Verifier<'v> {
// The underlying string; kept private so values are only built through the constructors.
value: Cow<'v, str>,
}
/// Serializes the verifier as a plain string.
#[cfg(feature = "serde")]
impl Serialize for Verifier<'_> {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        // Hand the underlying string straight to the serializer.
        serializer.serialize_str(self.get())
    }
}
/// Deserializes a verifier from a string, rejecting values that fail validation.
#[cfg(feature = "serde")]
impl<'de> Deserialize<'de> for Verifier<'_> {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        let raw = Cow::deserialize(deserializer)?;

        // Run the regular checks and report failures as custom deserialization errors.
        match Self::new(raw) {
            Ok(verifier) => Ok(verifier),
            Err(error) => Err(de::Error::custom(error)),
        }
    }
}
/// Formats the verifier exactly like its underlying string.
impl fmt::Display for Verifier<'_> {
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Delegate to the `str` implementation so that formatter flags are handled there.
        fmt::Display::fmt(self.get(), formatter)
    }
}
impl Verifier<'_> {
    /// Returns the verifier as a string slice.
    pub fn get(&self) -> &str {
        // `&Cow<str>` coerces to `&str` at the return site.
        &self.value
    }
}
/// Allows a verifier to be passed wherever a string reference is accepted.
impl AsRef<str> for Verifier<'_> {
    fn as_ref(&self) -> &str {
        // Same slice that `get` returns.
        &self.value
    }
}
/// Equality is decided by `constant_time_eq` over the raw bytes of both strings,
/// rather than by a derived comparison.
impl PartialEq for Verifier<'_> {
    fn eq(&self, other: &Self) -> bool {
        let left = self.get().as_bytes();
        let right = other.get().as_bytes();

        constant_time_eq(left, right)
    }
}

impl Eq for Verifier<'_> {}
/// Hashes the underlying string, so verifiers that compare equal hash identically.
impl Hash for Verifier<'_> {
    fn hash<H>(&self, hasher: &mut H)
    where
        H: Hasher,
    {
        Hash::hash(self.get(), hasher);
    }
}
impl Verifier<'_> {
    /// Generates a verifier from a string produced by `generate::string` for `length`.
    pub fn generate(length: Length) -> Self {
        // SAFETY: relies on `generate::string` only returning strings that would pass the
        // verifier checks for any `Length` (assumption; confirm against the `generate` module).
        unsafe {
            let string = generate::string(length);

            Self::owned_unchecked(string)
        }
    }

    /// Generates a verifier using the default [`Length`].
    pub fn generate_default() -> Self {
        let length = Length::default();

        Self::generate(length)
    }

    /// Generates bytes via `generate::bytes` for `count` and encodes them into a verifier.
    pub fn generate_encode(count: Count) -> Self {
        // SAFETY: relies on `generate::bytes` honoring `count`, so that the encoded output
        // is a valid verifier (assumption; confirm against the `generate` module).
        unsafe {
            let bytes = generate::bytes(count);

            Self::encode_unchecked(bytes)
        }
    }

    /// Generates and encodes bytes using the default [`Count`].
    pub fn generate_encode_default() -> Self {
        let count = Count::default();

        Self::generate_encode(count)
    }
}
impl Verifier<'_> {
    /// Creates the challenge for this verifier using the given `method`.
    pub fn challenge_using(&self, method: Method) -> Challenge {
        Challenge::create_using(method, self)
    }

    /// Creates the challenge for this verifier using the default [`Method`].
    pub fn challenge(&self) -> Challenge {
        let method = Method::default();

        self.challenge_using(method)
    }

    /// Checks whether `challenge` matches this verifier.
    ///
    /// The challenge is recomputed from this verifier with the method reported by
    /// `challenge`, and the two challenges are then compared for equality.
    pub fn verify(&self, challenge: &Challenge) -> bool {
        let method = challenge.method();
        let expected = self.challenge_using(method);

        *challenge == expected
    }
}
impl<'v> Verifier<'v> {
pub fn new(value: Cow<'v, str>) -> Result<Self, Error> {
Self::check(value.as_ref())?;
Ok(unsafe { Self::new_unchecked(value) })
}
pub const unsafe fn new_unchecked(value: Cow<'v, str>) -> Self {
Self { value }
}
pub fn borrowed(value: &'v str) -> Result<Self, Error> {
Self::new(Cow::Borrowed(value))
}
pub const unsafe fn borrowed_unchecked(value: &'v str) -> Self {
unsafe { Self::new_unchecked(Cow::Borrowed(value)) }
}
pub fn owned(value: String) -> Result<Self, Error> {
Self::new(Cow::Owned(value))
}
pub const unsafe fn owned_unchecked(value: String) -> Self {
unsafe { Self::new_unchecked(Cow::Owned(value)) }
}
pub const fn const_borrowed(value: &'v str) -> Result<Self, Error> {
const_try!(Self::const_check_str(value));
Ok(unsafe { Self::borrowed_unchecked(value) })
}
pub const fn const_borrowed_ok(value: &'v str) -> Option<Self> {
const_none!(const_ok!(Self::const_check_str(value)));
Some(unsafe { Self::borrowed_unchecked(value) })
}
pub const fn const_check_str(string: &str) -> Result<(), Error> {
const_try!(const_map_err!(Length::check(string.len()) => Error::Length));
const_try!(const_map_err!(const_check_str(string) => Error::String));
Ok(())
}
pub fn check_str(string: &str) -> Result<(), Error> {
Length::check(string.len())?;
string::check_str(string)?;
Ok(())
}
pub fn check<S: AsRef<str>>(value: S) -> Result<(), Error> {
Self::check_str(value.as_ref())
}
pub fn take(self) -> Cow<'v, str> {
self.value
}
}
impl Verifier<'_> {
pub fn encode<B: AsRef<[u8]>>(bytes: B) -> Result<Self, count::Error> {
Count::check(bytes.as_ref().len())?;
Ok(unsafe { Self::encode_unchecked(bytes) })
}
pub unsafe fn encode_unchecked<B: AsRef<[u8]>>(bytes: B) -> Self {
let string = encoding::encode(bytes);
unsafe { Self::owned_unchecked(string) }
}
}
/// Builds a borrowed `Verifier` from the given string expression.
///
/// Expands to a call to the `const fn` `Verifier::const_borrowed_ok`, followed by
/// `expect` with the `ERROR` message.
///
/// # Panics
///
/// Panics with `ERROR` if the string is not a valid verifier.
#[macro_export]
macro_rules! const_borrowed_verifier {
($value: expr) => {
$crate::verifier::Verifier::const_borrowed_ok($value).expect($crate::verifier::ERROR)
};
}
/// A [`Verifier`] with the `'static` lifetime.
#[cfg(feature = "static")]
pub type StaticVerifier = Verifier<'static>;
/// Converts a verifier of any lifetime into a [`StaticVerifier`].
#[cfg(feature = "static")]
impl IntoStatic for Verifier<'_> {
    type Static = StaticVerifier;

    fn into_static(self) -> Self::Static {
        let value = self.value.into_static();

        // SAFETY: the string contents come from an existing verifier and are carried over
        // as they are; only the lifetime of the container changes.
        unsafe { StaticVerifier::new_unchecked(value) }
    }
}