use {
crate::{
UniqueId,
discovery::PeerEntry,
primitives::{Abbreviated, Short},
},
bytes::Bytes,
chrono::{DateTime, Utc},
core::cmp::Ordering,
derive_more::Display,
humansize::{DECIMAL, format_size},
serde::{Deserialize, Serialize},
std::sync::Arc,
};
/// Identifier of the class a [`Ticket`] belongs to; a [`TicketValidator`]
/// reports its own class through [`TicketValidator::class`].
pub type TicketClass = UniqueId;
/// Re-export of `chrono::Duration`, presumably so callers building an
/// [`Expiration`] (e.g. via [`Expiration::after`]) need not depend on chrono
/// directly.
pub use chrono::Duration;
/// A ticket: raw `data` bytes tagged with the [`TicketClass`] they belong to.
///
/// This module never interprets `data`; judging its contents is presumably
/// the job of a [`TicketValidator`] for the same class.
///
/// Field order is significant: the derived `PartialOrd`/`Ord` compare `class`
/// first and `data` second, and serde encodes the fields in this order.
#[derive(
Clone, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord, Hash,
)]
pub struct Ticket {
/// Class this ticket belongs to.
pub class: TicketClass,
/// Raw ticket payload.
pub data: Bytes,
}
impl Ticket {
pub const fn new(class: TicketClass, data: Bytes) -> Self {
Self { class, data }
}
pub fn id(&self) -> UniqueId {
self.class.derive(&self.data)
}
pub fn data(&self) -> &[u8] {
&self.data
}
}
/// When something (e.g. a ticket accepted by [`TicketValidator::validate`])
/// stops being valid.
///
/// Ordering is implemented by hand so that `Never` sorts after every `At`.
/// A derived `Ord` would put `Never` first because it is declared first.
/// Keep the variant order unchanged: serde formats that encode the variant
/// index (rather than its name) depend on it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum Expiration {
/// No expiry; `is_expired` always returns `false`.
Never,
/// Expired once the current UTC time is later than this instant.
At(DateTime<Utc>),
}
impl PartialOrd for Expiration {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
Some(self.cmp(other))
}
}
impl Ord for Expiration {
    /// Orders expirations chronologically, treating `Never` as later than any
    /// concrete instant. As a result, `max()` over a set of expirations
    /// yields the longest-lived one.
    fn cmp(&self, other: &Self) -> Ordering {
        match (self, other) {
            (Self::At(lhs), Self::At(rhs)) => lhs.cmp(rhs),
            (Self::At(_), Self::Never) => Ordering::Less,
            (Self::Never, Self::At(_)) => Ordering::Greater,
            (Self::Never, Self::Never) => Ordering::Equal,
        }
    }
}
impl Expiration {
    /// Returns `true` once the expiry instant lies strictly in the past.
    ///
    /// `Never` is never expired. `At(t)` is expired when `t < Utc::now()`, so
    /// an expiry of exactly "now" still counts as valid.
    pub fn is_expired(&self) -> bool {
        match self {
            Self::Never => false,
            Self::At(t) => *t < Utc::now(),
        }
    }

    /// Time left until expiry, as a `core::time::Duration`.
    ///
    /// Returns `None` both for [`Expiration::Never`] **and** for an `At` that
    /// is already in the past: a negative delta cannot be converted by
    /// `to_std`. Callers that must tell these cases apart should match on the
    /// variant or check [`Expiration::is_expired`] first.
    pub fn remaining(&self) -> Option<core::time::Duration> {
        match self {
            Self::Never => None,
            Self::At(t) => (*t - Utc::now()).to_std().ok(),
        }
    }

    /// An expiration that never elapses.
    pub const fn never() -> Self {
        Self::Never
    }

    /// An expiration set one hour in the past, so it is already expired.
    pub fn already_expired() -> Self {
        Self::At(Utc::now() - chrono::Duration::hours(1))
    }

    /// Expires `duration` from now; a negative `duration` yields a past instant.
    ///
    /// `DateTime + Duration` panics when the result falls outside chrono's
    /// representable range, so the addition is checked instead. On
    /// overflow, a positive offset saturates to [`Expiration::Never`], since
    /// no representable instant is later. A negative offset saturates to an
    /// already-expired value.
    pub fn after(duration: impl Into<chrono::Duration>) -> Self {
        let duration = duration.into();
        match Utc::now().checked_add_signed(duration) {
            Some(at) => Self::At(at),
            None if duration > chrono::Duration::zero() => Self::Never,
            None => Self::already_expired(),
        }
    }

    /// Expires `hours` hours from now (negative values give a past instant).
    ///
    /// # Panics
    /// Panics if `hours` is outside the range accepted by
    /// `chrono::Duration::hours`.
    pub fn in_n_hours(hours: i64) -> Self {
        Self::after(chrono::Duration::hours(hours))
    }

    /// Shorthand for [`Expiration::in_n_hours`]`(1)`.
    pub fn in_1h() -> Self {
        Self::in_n_hours(1)
    }

    /// Shorthand for [`Expiration::in_n_hours`]`(3)`.
    pub fn in_3h() -> Self {
        Self::in_n_hours(3)
    }

    /// Shorthand for [`Expiration::in_n_hours`]`(6)`.
    pub fn in_6h() -> Self {
        Self::in_n_hours(6)
    }

    /// Shorthand for [`Expiration::in_n_hours`]`(24)`.
    pub fn in_24h() -> Self {
        Self::in_n_hours(24)
    }
}
impl core::fmt::Debug for Ticket {
    /// Full diagnostic view: the class, the whole payload hex-encoded, and the
    /// payload length in bytes.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let payload_hex = hex::encode(&self.data);
        let payload_len = self.data.len();
        f.debug_struct("Ticket")
            .field("class", &self.class)
            .field("data", &payload_hex)
            .field("len", &payload_len)
            .finish()
    }
}
impl core::fmt::Debug for Short<&Ticket> {
    /// Compact diagnostic view of a ticket.
    ///
    /// Shows the shortened class, a payload abbreviated to 8 units, and the
    /// human-readable combined size of class and payload.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let ticket = self.0;
        let total = ticket.data.len() + ticket.class.len();
        let size = format_size(total, DECIMAL);
        // `format_args!` prints the size via `Display`, i.e. without the
        // quotes a `Debug`-formatted `String` would get.
        f.debug_struct("Ticket")
            .field("class", &Short(&ticket.class))
            .field("data", &Abbreviated::<8, _>(&ticket.data))
            .field("size", &format_args!("{size}"))
            .finish()
    }
}
impl core::fmt::Display for Ticket {
    /// Renders `Ticket(<class>,<size>)`, where `<size>` is the combined length
    /// of class and payload, formatted with decimal units.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let Self { class, data } = self;
        let size = format_size(data.len() + class.len(), DECIMAL);
        write!(f, "Ticket({class},{size})")
    }
}
/// Error returned by [`TicketValidator::validate`] when a ticket is rejected.
///
/// Carries no detail about why the ticket was rejected. Its `Display` output
/// comes from `derive_more::Display`; for a unit struct this is presumably just
/// the type name — confirm.
#[derive(Debug, Clone, Copy, Display, thiserror::Error)]
pub struct InvalidTicket;
/// Checks tickets of one [`TicketClass`].
///
/// The `Send + Sync + 'static` bounds allow validators to be shared across
/// threads and held without borrowed data. The blanket impls for `Box<T>` and
/// `Arc<T>` (with `T: ?Sized`) also make `dyn TicketValidator` wrappers usable
/// directly.
pub trait TicketValidator: Send + Sync + 'static {
/// The ticket class this validator is for.
fn class(&self) -> TicketClass;
/// An identifier for this validator. Its exact semantics are not visible
/// here — TODO: document what distinguishes two signatures.
fn signature(&self) -> UniqueId;
/// Validates the raw `ticket` bytes in the context of `peer`.
///
/// # Errors
/// Returns [`InvalidTicket`] if the ticket is rejected.
///
/// On success, returns the [`Expiration`] after which the ticket should no
/// longer be trusted.
fn validate(
&self,
ticket: &[u8],
peer: &PeerEntry,
) -> Result<Expiration, InvalidTicket>;
}
/// Forwards every call to the boxed validator, so `Box<dyn TicketValidator>`
/// can be used wherever a validator is expected.
impl<T: TicketValidator + ?Sized> TicketValidator for Box<T> {
    fn class(&self) -> TicketClass {
        // `&Box<T>` deref-coerces to `&T` at the call site.
        T::class(self)
    }

    fn signature(&self) -> UniqueId {
        T::signature(self)
    }

    fn validate(
        &self,
        ticket: &[u8],
        peer: &PeerEntry,
    ) -> Result<Expiration, InvalidTicket> {
        T::validate(self, ticket, peer)
    }
}
/// Forwards every call to the shared validator, so `Arc<dyn TicketValidator>`
/// can be used wherever a validator is expected.
impl<T: TicketValidator + ?Sized> TicketValidator for Arc<T> {
    fn class(&self) -> TicketClass {
        // `&Arc<T>` deref-coerces to `&T` at the call site.
        T::class(self)
    }

    fn signature(&self) -> UniqueId {
        T::signature(self)
    }

    fn validate(
        &self,
        ticket: &[u8],
        peer: &PeerEntry,
    ) -> Result<Expiration, InvalidTicket> {
        T::validate(self, ticket, peer)
    }
}