use std::time::Instant;
use crate::broker::server::{
HandoffToken, HandoffTokenError, HandoffTokenStore, HANDOFF_TOKEN_BYTES,
};
/// A connection handed off by another party, bundled with the token material needed to
/// validate it.
///
/// * `expected_token` — the token the receiving side expects this connection to present.
/// * `presented_token` — the raw, not-yet-validated bytes the connection actually presented
///   (see [`parse_handoff_token`]).
/// * `connection` — the handed-off connection itself.
///
/// `Debug` is implemented by hand rather than derived. A derived impl would print the
/// presented token bytes verbatim, and the expected token via `HandoffToken`'s own `Debug`.
/// Both are credentials, and this payload is carried inside [`RejectedHandoff`], which is
/// likely to be logged. The manual impl redacts both tokens and shows only the presented
/// token's length.
#[derive(Clone, PartialEq, Eq)]
pub struct HandedOffPayload<T> {
    pub expected_token: HandoffToken,
    pub presented_token: Vec<u8>,
    pub connection: T,
}

impl<T: std::fmt::Debug> std::fmt::Debug for HandedOffPayload<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("HandedOffPayload")
            // Never render token bytes; they authorize the handoff.
            .field("expected_token", &format_args!("<redacted>"))
            .field(
                "presented_token",
                &format_args!("<{} bytes redacted>", self.presented_token.len()),
            )
            .field("connection", &self.connection)
            .finish()
    }
}
impl<T> HandedOffPayload<T> {
    /// Builds a payload from the expected token, the raw bytes the connection presented, and
    /// the connection itself.
    ///
    /// `presented_token` accepts anything convertible into `Vec<u8>` (for example a byte
    /// slice or an owned vector). It is stored as-is and is not validated here.
    pub fn new(
        expected_token: HandoffToken,
        presented_token: impl Into<Vec<u8>>,
        connection: T,
    ) -> Self {
        let presented_token: Vec<u8> = presented_token.into();
        Self {
            expected_token,
            presented_token,
            connection,
        }
    }

    /// Borrows the raw bytes that the connection presented.
    pub fn presented_token(&self) -> &[u8] {
        self.presented_token.as_slice()
    }

    /// Consumes the payload and returns only the connection; the token fields are dropped.
    pub fn into_connection(self) -> T {
        let Self { connection, .. } = self;
        connection
    }
}
/// A handoff whose presented token was accepted.
///
/// Built by `accept_handed_off` after `HandoffTokenStore::consume_matching` returned `Ok`.
// NOTE(review): the derived `Debug` prints `token` through `HandoffToken`'s own `Debug` impl;
// confirm that impl does not expose the token bytes before logging this type.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct AcceptedHandoff<T> {
    /// The expected token the presented token was matched against.
    pub token: HandoffToken,
    /// The connection taken from the accepted payload.
    pub connection: T,
}
/// A handoff that was refused, together with the reason.
///
/// The original payload is returned unchanged, so the caller regains ownership of the
/// connection and can decide what to do with it.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct RejectedHandoff<T> {
    /// The payload exactly as it was passed to `accept_handed_off`.
    pub payload: HandedOffPayload<T>,
    /// Why the handoff was refused.
    pub reason: HandoffRejectionReason,
}
/// Outcome of `accept_handed_off`: either an accepted connection or a rejected payload.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum HandoffAcceptance<T> {
    /// The presented token matched; the connection may be used.
    Accepted(AcceptedHandoff<T>),
    /// The presented token was missing, malformed, or refused by the token store.
    Rejected(RejectedHandoff<T>),
}
impl<T> HandoffAcceptance<T> {
    /// Returns `true` when this outcome is [`HandoffAcceptance::Accepted`].
    pub fn is_accepted(&self) -> bool {
        matches!(self, Self::Accepted(_))
    }

    /// Returns `true` when this outcome is [`HandoffAcceptance::Rejected`].
    ///
    /// There are exactly two variants, so this is the negation of [`Self::is_accepted`].
    pub fn is_rejected(&self) -> bool {
        !self.is_accepted()
    }

    /// Converts the outcome into a `Result` so callers can use `?` and other combinators.
    ///
    /// An accepted handoff becomes `Ok`; a rejected one becomes `Err`.
    pub fn into_result(self) -> Result<AcceptedHandoff<T>, RejectedHandoff<T>> {
        match self {
            Self::Rejected(rejection) => Err(rejection),
            Self::Accepted(handoff) => Ok(handoff),
        }
    }
}
/// Why a handed-off connection was refused.
///
/// The first two variants come from parsing the presented bytes (`parse_handoff_token`).
/// The rest come from the token store, via `From<HandoffTokenError>`.
#[derive(Clone, Debug, PartialEq, Eq, thiserror::Error)]
pub enum HandoffRejectionReason {
    /// The connection presented zero bytes.
    #[error("handoff token is missing")]
    MissingToken,
    /// The connection presented a non-empty token whose length is not `HANDOFF_TOKEN_BYTES`.
    #[error("handoff token length was {actual_len} bytes; expected {expected_len}")]
    InvalidTokenLength {
        /// Number of bytes actually presented.
        actual_len: usize,
        /// Required token length (`HANDOFF_TOKEN_BYTES`).
        expected_len: usize,
    },
    /// The store reported `HandoffTokenError::TokenMismatch`.
    #[error("handoff token mismatch")]
    TokenMismatch,
    /// The store reported `HandoffTokenError::TokenExpired`.
    #[error("handoff token expired")]
    TokenExpired,
    /// The store reported `HandoffTokenError::TokenNotPending`.
    #[error("handoff token is not pending")]
    TokenNotPending,
    /// Any other token-store error, carried through unchanged.
    #[error("handoff token store error: {error}")]
    TokenStore {
        /// The underlying store error.
        error: HandoffTokenError,
    },
}
impl From<HandoffTokenError> for HandoffRejectionReason {
fn from(value: HandoffTokenError) -> Self {
match value {
HandoffTokenError::TokenMismatch => Self::TokenMismatch,
HandoffTokenError::TokenExpired => Self::TokenExpired,
HandoffTokenError::TokenNotPending => Self::TokenNotPending,
error => Self::TokenStore { error },
}
}
}
/// Converts the raw bytes presented by a handed-off connection into a [`HandoffToken`].
///
/// # Errors
///
/// * [`HandoffRejectionReason::MissingToken`] if `token` is empty.
/// * [`HandoffRejectionReason::InvalidTokenLength`] if `token` is non-empty but not exactly
///   `HANDOFF_TOKEN_BYTES` long.
pub fn parse_handoff_token(token: &[u8]) -> Result<HandoffToken, HandoffRejectionReason> {
    match token.len() {
        0 => Err(HandoffRejectionReason::MissingToken),
        actual_len if actual_len != HANDOFF_TOKEN_BYTES => {
            Err(HandoffRejectionReason::InvalidTokenLength {
                actual_len,
                expected_len: HANDOFF_TOKEN_BYTES,
            })
        }
        _ => {
            // The guard above ensures the slice length equals the array length, so
            // `copy_from_slice` cannot panic here.
            let mut raw = [0_u8; HANDOFF_TOKEN_BYTES];
            raw.copy_from_slice(token);
            Ok(HandoffToken::from_bytes(raw))
        }
    }
}
pub fn accept_handed_off<T>(
pending_tokens: &mut HandoffTokenStore,
payload: HandedOffPayload<T>,
now: Instant,
) -> HandoffAcceptance<T> {
let presented = match parse_handoff_token(payload.presented_token()) {
Ok(token) => token,
Err(reason) => return reject(payload, reason),
};
match pending_tokens.consume_matching(&payload.expected_token, &presented, now) {
Ok(()) => HandoffAcceptance::Accepted(AcceptedHandoff {
token: payload.expected_token,
connection: payload.connection,
}),
Err(error) => reject(payload, error.into()),
}
}
/// Wraps `payload` and `reason` into a [`HandoffAcceptance::Rejected`] outcome.
fn reject<T>(payload: HandedOffPayload<T>, reason: HandoffRejectionReason) -> HandoffAcceptance<T> {
    let rejection = RejectedHandoff { reason, payload };
    HandoffAcceptance::Rejected(rejection)
}