use std::{fmt, str::FromStr};
use bon::Builder;
use const_macros::const_early;
use miette::Diagnostic;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use thiserror::Error;
use crate::{
auth::{
part::{self, Part, SEPARATOR},
query::Query,
url::{self, Url},
utf8,
},
macros::errors,
};
/// Error describing an empty label.
///
/// Carries no data; it is wrapped into a [`ParseError`] through
/// [`ParseErrorSource::Empty`].
#[derive(Debug, Error, Diagnostic)]
#[error("empty label encountered")]
#[diagnostic(
code(otp_std::auth::label::empty),
help("make sure the label is non-empty")
)]
pub struct EmptyError;
/// The possible causes of a [`ParseError`].
///
/// Both the error message and the diagnostic are forwarded transparently
/// to the wrapped error.
#[derive(Debug, Error, Diagnostic)]
#[error(transparent)]
#[diagnostic(transparent)]
pub enum ParseErrorSource {
/// The label string was empty.
Empty(#[from] EmptyError),
/// One of the label parts (issuer or user) failed to parse.
Part(#[from] part::Error),
}
/// Error returned when a string cannot be parsed into a [`Label`].
///
/// The concrete cause is kept in [`source`](Self::source) and is reported
/// both as the error source and as the diagnostic source.
///
/// The diagnostic code is `otp_std::auth::label::parse`, so that it is distinct
/// from the code of the top-level label [`Error`] and follows the same
/// `label::<kind>` scheme as the other label errors in this module.
#[derive(Debug, Error, Diagnostic)]
#[error("failed to parse label")]
#[diagnostic(
    code(otp_std::auth::label::parse),
    help("make sure the label is formatted correctly")
)]
pub struct ParseError {
    /// The underlying cause of the parse failure.
    #[source]
    #[diagnostic_source]
    pub source: ParseErrorSource,
}
impl ParseError {
pub const fn new(source: ParseErrorSource) -> Self {
Self { source }
}
pub fn empty(error: EmptyError) -> Self {
Self::new(error.into())
}
pub fn part(error: part::Error) -> Self {
Self::new(error.into())
}
pub fn new_empty() -> Self {
Self::empty(EmptyError)
}
}
/// A label made of an optional issuer and a user.
///
/// When displayed, the issuer (if any) comes first, followed by [`SEPARATOR`]
/// and then the user.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Builder)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct Label<'l> {
/// The optional issuer part of the label.
pub issuer: Option<Part<'l>>,
/// The user part of the label.
pub user: Part<'l>,
}
/// The `(issuer, user)` pair a [`Label`] can be built from or split into.
pub type Parts<'p> = (Option<Part<'p>>, Part<'p>);
/// [`Parts`] with the `'static` lifetime.
pub type OwnedParts = Parts<'static>;
impl<'l> Label<'l> {
    /// Assembles a label from its `(issuer, user)` parts.
    pub fn from_parts(parts: Parts<'l>) -> Self {
        Self {
            issuer: parts.0,
            user: parts.1,
        }
    }

    /// Consumes the label and returns its `(issuer, user)` parts.
    pub fn into_parts(self) -> Parts<'l> {
        let Self { issuer, user } = self;

        (issuer, user)
    }
}
impl<'p> From<Parts<'p>> for Label<'p> {
    /// Builds a label from its `(issuer, user)` parts.
    fn from(parts: Parts<'p>) -> Self {
        let (issuer, user) = parts;

        Self { issuer, user }
    }
}
impl<'l> From<Label<'l>> for Parts<'l> {
    /// Splits a label into its `(issuer, user)` parts.
    fn from(label: Label<'l>) -> Self {
        (label.issuer, label.user)
    }
}
impl fmt::Display for Label<'_> {
fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
if let Some(issuer) = self.issuer.as_ref() {
issuer.fmt(formatter)?;
formatter.write_str(SEPARATOR)?;
};
self.user.fmt(formatter)
}
}
// Declares the `empty_error!` helper used by the `FromStr` implementation below.
// NOTE(review): presumably it expands to `ParseError::new_empty()`, going by the
// `Type = ParseError` and `empty_error => new_empty()` entries; confirm against
// the definition of `crate::macros::errors`.
errors! {
Type = ParseError,
Hack = $,
empty_error => new_empty(),
}
impl FromStr for Label<'_> {
type Err = ParseError;
fn from_str(string: &str) -> Result<Self, Self::Err> {
const_early!(string.is_empty() => empty_error!());
if let Some((issuer_string, user_string)) = string.split_once(SEPARATOR) {
let issuer = issuer_string.parse().map_err(Self::Err::part)?;
let user = user_string.parse().map_err(Self::Err::part)?;
Ok(Self::builder().issuer(issuer).user(user).build())
} else {
let user = string.parse().map_err(Self::Err::part)?;
Ok(Self::builder().user(user).build())
}
}
}
/// The possible causes of a [`DecodeError`].
///
/// Both the error message and the diagnostic are forwarded transparently
/// to the wrapped error.
#[derive(Debug, Error, Diagnostic)]
#[error(transparent)]
#[diagnostic(transparent)]
pub enum DecodeErrorSource {
/// URL-decoding the input failed with a [`utf8::Error`].
Utf8(#[from] utf8::Error),
/// The decoded string could not be parsed into a label.
Parse(#[from] ParseError),
}
/// Error returned by [`Label::decode`].
///
/// The concrete cause is kept in [`source`](Self::source) and is reported
/// both as the error source and as the diagnostic source.
#[derive(Debug, Error, Diagnostic)]
#[error("failed to decode label")]
#[diagnostic(
code(otp_std::auth::label::decode),
help("make sure the label is correctly formatted")
)]
pub struct DecodeError {
/// The underlying cause of the decode failure.
#[source]
#[diagnostic_source]
pub source: DecodeErrorSource,
}
impl DecodeError {
pub const fn new(source: DecodeErrorSource) -> Self {
Self { source }
}
pub fn utf8(error: utf8::Error) -> Self {
Self::new(error.into())
}
pub fn label(error: ParseError) -> Self {
Self::new(error.into())
}
}
impl Label<'_> {
pub fn decode<S: AsRef<str>>(string: S) -> Result<Self, DecodeError> {
let string = string.as_ref();
let decoded = url::decode(string)
.map_err(utf8::wrap)
.map_err(DecodeError::utf8)?;
decoded.parse().map_err(DecodeError::label)
}
}
impl Label<'_> {
    /// Encodes the label into an owned string.
    ///
    /// The result is exactly the [`Display`](fmt::Display) rendering of the label.
    // NOTE(review): no URL-encoding is applied here, whereas `decode` URL-decodes
    // its input; presumably the caller escapes the result when placing it into
    // a URL. Confirm against the call sites.
    pub fn encode(&self) -> String {
        ToString::to_string(self)
    }
}
/// Error describing an issuer mismatch between the label and the query.
///
/// Returned by [`try_match`] when both issuers are present but differ.
#[derive(Debug, Error, Diagnostic)]
#[error("issuer mismatch: `{label}` in label, `{query}` in query")]
#[diagnostic(
code(otp_std::auth::label::mismatch),
help("if the issuer is present both in the label and the query, they must match")
)]
pub struct MismatchError {
/// The issuer found in the label.
pub label: String,
/// The issuer found in the query.
pub query: String,
}
impl MismatchError {
pub const fn new(label: String, query: String) -> Self {
Self { label, query }
}
}
// Declares the `mismatch_error!` helper used by `try_match` below.
// NOTE(review): presumably it expands to `MismatchError::new(..)` with `into_owned`
// applied to both arguments, going by the `label => into_owned` and
// `query => into_owned` entries; confirm against `crate::macros::errors`.
errors! {
Type = MismatchError,
Hack = $,
mismatch_error => new(label => into_owned, query => into_owned),
}
/// Reconciles the issuer found in the label with the issuer found in the query.
///
/// If only one of them is present, that one is returned. If both are present
/// and equal, the label issuer is returned. If neither is present, the result
/// is `None`.
///
/// # Errors
///
/// Returns [`MismatchError`] if both issuers are present and differ.
pub fn try_match<'p>(
    label_issuer: Option<Part<'p>>,
    query_issuer: Option<Part<'p>>,
) -> Result<Option<Part<'p>>, MismatchError> {
    let (label, query) = match (label_issuer, query_issuer) {
        (Some(label), Some(query)) => (label, query),
        // At most one issuer is present: take whichever exists.
        (label_option, query_option) => return Ok(label_option.or(query_option)),
    };

    if label != query {
        return Err(mismatch_error!(label.get(), query.get()));
    }

    // Both issuers agree, so the label one is kept.
    Ok(Some(label))
}
/// The possible causes of a label extraction [`Error`].
///
/// Both the error message and the diagnostic are forwarded transparently
/// to the wrapped error.
#[derive(Debug, Error, Diagnostic)]
#[error(transparent)]
#[diagnostic(transparent)]
pub enum ErrorSource {
/// The label taken from the URL path could not be decoded.
Decode(#[from] DecodeError),
/// The issuer taken from the query could not be decoded.
Issuer(#[from] part::DecodeError),
/// The label issuer and the query issuer are both present but differ.
Mismatch(#[from] MismatchError),
}
/// Error returned by [`Label::extract_from`].
///
/// The concrete cause is kept in [`source`](Self::source) and is reported
/// both as the error source and as the diagnostic source.
#[derive(Debug, Error, Diagnostic)]
#[error("failed to extract label from OTP URL")]
#[diagnostic(
code(otp_std::auth::label),
help("see the report for more information")
)]
pub struct Error {
/// The underlying cause of the extraction failure.
#[source]
#[diagnostic_source]
pub source: ErrorSource,
}
impl Error {
pub const fn new(source: ErrorSource) -> Self {
Self { source }
}
pub fn decode(error: DecodeError) -> Self {
Self::new(error.into())
}
pub fn mismatch(error: MismatchError) -> Self {
Self::new(error.into())
}
pub fn issuer(error: part::DecodeError) -> Self {
Self::new(error.into())
}
}
/// Name of the query parameter that carries the issuer.
pub const ISSUER: &str = "issuer";
/// The slash stripped from the front of the URL path before the label is decoded.
pub const SLASH: &str = "/";
impl Label<'_> {
pub fn query_for(&self, url: &mut Url) {
if let Some(issuer) = self.issuer.as_ref() {
url.query_pairs_mut()
.append_pair(ISSUER, issuer.encode().as_ref());
};
}
pub fn extract_from(query: &mut Query<'_>, url: &Url) -> Result<Self, Error> {
let path = url.path().trim_start_matches(SLASH);
let label = Self::decode(path).map_err(Error::decode)?;
let (label_issuer, user) = label.into_parts();
let query_issuer = query
.remove(ISSUER)
.map(Part::decode)
.transpose()
.map_err(Error::issuer)?;
let issuer = try_match(label_issuer, query_issuer).map_err(Error::mismatch)?;
Ok(Self::from_parts((issuer, user)))
}
}
/// A [`Label`] with the `'static` lifetime, as returned by [`Label::into_owned`].
pub type Owned = Label<'static>;
impl Label<'_> {
    /// Converts the label into its [`Owned`] (`'static`) form by converting
    /// the issuer, if any, and the user with [`Part::into_owned`].
    pub fn into_owned(self) -> Owned {
        let (issuer, user) = self.into_parts();

        let owned_issuer = issuer.map(Part::into_owned);
        let owned_user = user.into_owned();

        Owned::from_parts((owned_issuer, owned_user))
    }
}