use std::cmp::Ordering;
use std::sync::Arc;
use std::sync::atomic::{self, AtomicU64};
use std::time::{Duration, SystemTime};
use tor_cell::relaycell::{AnyRelayMsgOuter, RelayCmd, StreamId, UnparsedRelayMsg};
use tor_error::{Bug, internal};
use crate::Error;
use crate::crypto::cell::HopNum;
/// Wrapper around an [`AbstractConfluxMsgHandler`] implementation that also
/// tracks how many sequenced messages have been delivered.
///
/// NOTE(review): `last_seq_delivered` is an `Arc<AtomicU64>`, so it is
/// presumably shared with another component that performs the actual
/// delivery — confirm against the constructor's callers.
pub(crate) struct ConfluxMsgHandler {
    /// The implementation that handles conflux messages and keeps the
    /// sent/received sequence-number counters.
    handler: Box<dyn AbstractConfluxMsgHandler + Send + Sync>,
    /// Sequence number of the last delivered message. This type reads it
    /// with `Acquire` and increments it with `AcqRel`.
    last_seq_delivered: Arc<AtomicU64>,
}
impl ConfluxMsgHandler {
pub(crate) fn new(
handler: Box<dyn AbstractConfluxMsgHandler + Send + Sync>,
last_seq_delivered: Arc<AtomicU64>,
) -> Self {
Self {
handler,
last_seq_delivered,
}
}
fn validate_source_hop(&self, msg: &UnparsedRelayMsg, hop: HopNum) -> crate::Result<()> {
self.handler.validate_source_hop(msg, hop)
}
pub(crate) fn handle_conflux_msg(
&mut self,
msg: UnparsedRelayMsg,
hop: HopNum,
) -> Option<ConfluxCmd> {
let res = (|| {
let () = self.validate_source_hop(&msg, hop)?;
self.handler.handle_msg(msg, hop)
})();
match res {
Ok(cmd) => cmd,
Err(e) => {
Some(ConfluxCmd::RemoveLeg(RemoveLegReason::ConfluxHandshakeErr(
e,
)))
}
}
}
pub(crate) fn note_link_sent(&mut self, ts: SystemTime) -> Result<(), Bug> {
self.handler.note_link_sent(ts)
}
pub(crate) fn handshake_timeout(&self) -> Option<SystemTime> {
self.handler.handshake_timeout()
}
pub(crate) fn status(&self) -> ConfluxStatus {
self.handler.status()
}
fn is_msg_in_order(&self) -> Result<bool, Bug> {
let last_seq_delivered = self.last_seq_delivered.load(atomic::Ordering::Acquire);
match self.handler.last_seq_recv().cmp(&(last_seq_delivered + 1)) {
Ordering::Less => {
Err(internal!(
"Got a conflux cell with a sequence number less than the last delivered"
))
}
Ordering::Equal => Ok(true),
Ordering::Greater => Ok(false),
}
}
fn prepare_ooo_entry(
&self,
hopnum: HopNum,
cell_counts_towards_windows: bool,
streamid: StreamId,
msg: UnparsedRelayMsg,
) -> OooRelayMsg {
OooRelayMsg {
seqno: self.handler.last_seq_recv(),
hopnum,
cell_counts_towards_windows,
streamid,
msg,
}
}
#[cfg(feature = "conflux")]
pub(crate) fn action_for_msg(
&mut self,
hopnum: HopNum,
cell_counts_towards_windows: bool,
streamid: StreamId,
msg: UnparsedRelayMsg,
) -> Result<ConfluxAction, Bug> {
if !super::cmd_counts_towards_seqno(msg.cmd()) {
return Ok(ConfluxAction::Deliver(msg));
}
self.handler.inc_last_seq_recv();
let action = if self.is_msg_in_order()? {
ConfluxAction::Deliver(msg)
} else {
ConfluxAction::Enqueue(self.prepare_ooo_entry(
hopnum,
cell_counts_towards_windows,
streamid,
msg,
))
};
Ok(action)
}
pub(crate) fn inc_last_seq_delivered(&self, msg: &UnparsedRelayMsg) {
if super::cmd_counts_towards_seqno(msg.cmd()) {
self.last_seq_delivered
.fetch_add(1, atomic::Ordering::AcqRel);
}
}
pub(crate) fn init_rtt(&self) -> Option<Duration> {
self.handler.init_rtt()
}
pub(crate) fn last_seq_sent(&self) -> u64 {
self.handler.last_seq_sent()
}
pub(crate) fn set_last_seq_sent(&mut self, n: u64) {
self.handler.set_last_seq_sent(n);
}
pub(crate) fn last_seq_recv(&self) -> u64 {
self.handler.last_seq_recv()
}
pub(crate) fn note_cell_sent(&mut self, cmd: RelayCmd) {
if super::cmd_counts_towards_seqno(cmd) {
self.handler.inc_last_seq_sent();
}
}
}
/// What to do with a received message, as decided by
/// `ConfluxMsgHandler::action_for_msg`.
#[derive(Debug)]
#[cfg(feature = "conflux")]
pub(crate) enum ConfluxAction {
    /// The message can be delivered immediately.
    Deliver(UnparsedRelayMsg),
    /// The message arrived ahead of earlier sequence numbers and must be
    /// queued. It carries the prepared out-of-order entry.
    Enqueue(OooRelayMsg),
}
/// Behaviour that a concrete conflux message handler must provide.
///
/// `ConfluxMsgHandler` wraps a boxed implementation of this trait and
/// delegates to it.
pub(crate) trait AbstractConfluxMsgHandler {
    /// Check that `msg` is acceptable when received from hop `hop`.
    ///
    /// `ConfluxMsgHandler::handle_conflux_msg` turns an error from this
    /// method into a request to remove the leg.
    fn validate_source_hop(&self, msg: &UnparsedRelayMsg, hop: HopNum) -> crate::Result<()>;
    /// Handle the conflux message `msg` received from `hop`. Returns an
    /// optional command for the caller to act on.
    fn handle_msg(
        &mut self,
        msg: UnparsedRelayMsg,
        hop: HopNum,
    ) -> crate::Result<Option<ConfluxCmd>>;
    /// The current link status.
    fn status(&self) -> ConfluxStatus;
    /// Record that a link message was sent at `ts`.
    /// NOTE(review): presumably the conflux LINK message — confirm.
    fn note_link_sent(&mut self, ts: SystemTime) -> Result<(), Bug>;
    /// The handshake timeout, if any.
    /// NOTE(review): presumably an absolute deadline, given the `SystemTime`
    /// type — confirm.
    fn handshake_timeout(&self) -> Option<SystemTime>;
    /// A round-trip time, if one is known.
    /// NOTE(review): presumably measured during the link handshake — confirm.
    fn init_rtt(&self) -> Option<Duration>;
    /// The last sequence number received.
    fn last_seq_recv(&self) -> u64;
    /// The last sequence number sent.
    fn last_seq_sent(&self) -> u64;
    /// Overwrite the last sequence number sent with `n`.
    fn set_last_seq_sent(&mut self, n: u64);
    /// Increment the last sequence number received.
    fn inc_last_seq_recv(&mut self);
    /// Increment the last sequence number sent.
    fn inc_last_seq_sent(&mut self);
}
/// A relay message that arrived out of order, buffered together with what is
/// needed to process it later.
///
/// Ordering and equality are based on `seqno` only (see the trait impls for
/// this type).
#[derive(Debug)]
pub(crate) struct OooRelayMsg {
    /// The sequence number assigned to this message on receipt.
    pub(crate) seqno: u64,
    /// The hop the message was received from.
    pub(crate) hopnum: HopNum,
    /// Whether the cell counts towards flow-control windows.
    /// NOTE(review): semantics inferred from the name — confirm.
    pub(crate) cell_counts_towards_windows: bool,
    /// The stream the message belongs to.
    pub(crate) streamid: StreamId,
    /// The buffered message itself.
    pub(crate) msg: UnparsedRelayMsg,
}
impl Ord for OooRelayMsg {
    /// Order entries by *descending* sequence number, ignoring other fields.
    ///
    /// NOTE(review): presumably reversed so that a max-heap such as
    /// `BinaryHeap` pops the lowest sequence number first — confirm against
    /// the queue that stores these entries.
    fn cmp(&self, other: &Self) -> Ordering {
        // Comparing `other` against `self` yields the reverse of the natural order.
        other.seqno.cmp(&self.seqno)
    }
}
impl PartialOrd for OooRelayMsg {
    /// Always `Some`: delegates to the total order defined by the `Ord` impl.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        let total = Ord::cmp(self, other);
        Some(total)
    }
}
impl PartialEq for OooRelayMsg {
    /// Two entries are equal exactly when their sequence numbers match. All
    /// other fields are ignored, which keeps equality consistent with the
    /// `Ord` impl.
    fn eq(&self, other: &Self) -> bool {
        other.seqno == self.seqno
    }
}
// Equality compares only the `u64` sequence numbers, so it is a full
// equivalence relation.
impl Eq for OooRelayMsg {}
/// A command produced while handling a conflux message, for the caller to act on.
#[derive(Debug)]
pub(crate) enum ConfluxCmd {
    /// Remove this leg for the given reason. This is produced, for example,
    /// when validating or handling a conflux message fails.
    RemoveLeg(RemoveLegReason),
    /// The conflux handshake has completed.
    HandshakeComplete {
        /// The hop involved in the handshake.
        /// NOTE(review): presumably the hop the completing message came from — confirm.
        hop: HopNum,
        /// NOTE(review): meaning not evident from this file — confirm.
        early: bool,
        /// A relay message associated with handshake completion.
        /// NOTE(review): presumably one the caller should send — confirm.
        cell: AnyRelayMsgOuter,
    },
}
/// Why a conflux leg is being removed.
///
/// The derived `Display` impl renders a short human-readable description.
#[derive(Debug, derive_more::Display)]
pub(crate) enum RemoveLegReason {
    /// The conflux handshake timed out.
    #[display("conflux handshake timed out")]
    ConfluxHandshakeTimeout,
    /// An error occurred while validating or handling a conflux message.
    /// It is displayed as the wrapped error.
    #[display("{}", _0)]
    ConfluxHandshakeErr(Error),
    /// The underlying channel was closed.
    #[display("channel closed")]
    ChannelClosed,
}
/// Link status, as reported by [`AbstractConfluxMsgHandler::status`].
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub(crate) enum ConfluxStatus {
    /// Not linked.
    /// NOTE(review): presumably no link handshake has started yet — confirm.
    Unlinked,
    /// NOTE(review): presumably a link handshake is in progress but not yet
    /// complete — confirm.
    Pending,
    /// Linked.
    Linked,
}