use std::{
net::SocketAddr,
pin::Pin,
task::{Context, Poll, Waker},
time::{Duration, Instant},
};
use bytes::Bytes;
use futures::{ready, FutureExt};
use msf_stun as stun;
use tokio::time::Sleep;
use crate::{
candidate::{CandidateKind, CandidatePair, LocalCandidate, RemoteCandidate},
AgentRole,
};
// Retransmission parameters of a check's binding-request transaction. The
// names presumably follow the STUN parameters RTO, Rm and Rc of RFC 8489
// (section 6.2.1) — confirm.

/// Initial retransmission timeout in milliseconds: the wait after the first
/// request, doubled after every further non-final request.
const RTO: u64 = 500;
/// Multiplier of `RTO` giving the wait after the final request (`RTO * RM` ms).
const RM: u64 = 16;
/// Maximum number of requests sent by one transaction.
const RC: u32 = 7;
/// A connectivity check for a single candidate pair.
pub struct Check {
    /// The local/remote candidate pair being checked.
    pair: CandidatePair,
    /// Whether binding requests carry USE-CANDIDATE when the local agent is
    /// controlling.
    nominated: bool,
    /// Current lifecycle state, including the in-flight transaction, if any.
    state: CheckState,
    /// Waker saved by `Check::poll`; taken and woken when the check is
    /// scheduled, finished or cancelled.
    task: Option<Waker>,
}
impl Check {
    /// Creates a check for `pair` in the initial `Frozen` state.
    ///
    /// When the local agent is controlling, a `nominated` check adds
    /// USE-CANDIDATE to the binding requests it sends (see [`Check::schedule`]).
    pub fn new(pair: CandidatePair, nominated: bool) -> Self {
        Self {
            pair,
            state: CheckState::Frozen,
            task: None,
            nominated,
        }
    }

    /// Returns the in-flight transaction if the check is `InProgress`.
    fn transaction(&self) -> Option<&CheckTransaction> {
        if let CheckState::InProgress(t) = &self.state {
            Some(t)
        } else {
            None
        }
    }

    /// Wakes (and forgets) the task registered by the last pending
    /// [`Check::poll`].
    ///
    /// Must be called after every external state change that can alter what
    /// `poll` returns; otherwise the polling task may never be woken again.
    fn wake_task(&mut self) {
        if let Some(task) = self.task.take() {
            task.wake();
        }
    }

    /// Returns the STUN transaction ID of the in-flight request, or `None`
    /// when the check is not `InProgress`.
    pub fn transaction_id(&self) -> Option<TransactionId> {
        self.transaction().map(|t| t.id())
    }

    /// Returns whether this check is nominated.
    pub fn is_nominated(&self) -> bool {
        self.nominated
    }

    /// Returns the local candidate of the checked pair.
    pub fn local_candidate(&self) -> &LocalCandidate {
        self.pair.local()
    }

    /// Returns the remote candidate of the checked pair.
    pub fn remote_candidate(&self) -> &RemoteCandidate {
        self.pair.remote()
    }

    /// Returns the checked candidate pair.
    pub fn candidate_pair(&self) -> &CandidatePair {
        &self.pair
    }

    /// Returns the component ID of the checked pair.
    pub fn component(&self) -> u8 {
        self.pair.component()
    }

    /// Returns the priority of the checked pair for the given local role.
    pub fn priority(&self, local_role: AgentRole) -> u64 {
        self.pair.priority(local_role)
    }

    /// Returns the foundation of the checked pair.
    pub fn foundation(&self) -> &str {
        self.pair.foundation()
    }

    /// Returns `true` if the check is `Frozen`.
    pub fn is_frozen(&self) -> bool {
        matches!(self.state, CheckState::Frozen)
    }

    /// Returns `true` if the check is `Waiting` to be scheduled.
    pub fn is_waiting(&self) -> bool {
        matches!(self.state, CheckState::Waiting)
    }

    /// Returns `true` if the check has succeeded.
    pub fn is_success(&self) -> bool {
        matches!(self.state, CheckState::Succeeded)
    }

    /// Returns `true` if the check reached a final state
    /// (`Succeeded`, `Cancelled` or `Failed`).
    pub fn is_done(&self) -> bool {
        matches!(
            self.state,
            CheckState::Succeeded | CheckState::Cancelled | CheckState::Failed
        )
    }

    /// Moves a `Frozen` check to `Waiting`.
    ///
    /// Calling this on a check that is not frozen is a logic error (checked in
    /// debug builds only).
    pub fn unfreeze(&mut self) {
        debug_assert!(self.is_frozen());
        self.state = CheckState::Waiting;
    }

    /// Puts the check back into `Waiting` regardless of its current state,
    /// dropping any in-flight transaction.
    pub fn trigger(&mut self) {
        self.state = CheckState::Waiting;
    }

    /// Finishes the check.
    ///
    /// A check that has not been scheduled (`Frozen`/`Waiting`) is cancelled,
    /// an `InProgress` check has its transaction wound down (see
    /// `CheckTransaction::finish`), and a check in a final state is left
    /// unchanged. The registered task, if any, is woken.
    pub fn finish(&mut self) {
        match &mut self.state {
            CheckState::Frozen | CheckState::Waiting => self.state = CheckState::Cancelled,
            CheckState::InProgress(t) => t.finish(),
            _ => (),
        }
        self.wake_task();
    }

    /// Cancels the check unless it is already in a final state, waking the
    /// registered task.
    pub fn cancel(&mut self) {
        if !self.is_done() {
            self.state = CheckState::Cancelled;
            self.wake_task();
        }
    }

    /// Builds the binding request for this check and moves it to `InProgress`.
    ///
    /// The request is built with a username, a priority computed for a
    /// peer-reflexive candidate, message integrity keyed with `password`, a
    /// fingerprint, and either ICE-CONTROLLING (plus USE-CANDIDATE for
    /// nominated checks) or ICE-CONTROLLED with `tie_breaker`. The check must
    /// be `Frozen` or `Waiting` (checked in debug builds only). The registered
    /// task, if any, is woken so it starts polling the new transaction.
    pub fn schedule(
        &mut self,
        username: &str,
        password: &str,
        agent_role: AgentRole,
        tie_breaker: u64,
    ) {
        debug_assert!(matches!(
            self.state,
            CheckState::Frozen | CheckState::Waiting
        ));

        let transaction_id: TransactionId = rand::random();

        let local = self.pair.local();
        let remote = self.pair.remote();

        // PRIORITY is computed as if the local candidate were peer-reflexive.
        let priority = LocalCandidate::calculate_priority(
            remote.component(),
            CandidateKind::PeerReflexive,
            local.addr(),
        );

        let mut builder = stun::MessageBuilder::binding_request(transaction_id);
        builder
            .username(username)
            .priority(priority)
            .message_integrity(password.as_bytes())
            .fingerprint(true);

        if agent_role == AgentRole::Controlling {
            if self.nominated {
                builder.use_candidate(true);
            }
            builder.ice_controlling(tie_breaker);
        } else {
            builder.ice_controlled(tie_breaker);
        }

        let msg = CheckMessage {
            local_addr: local.base(),
            remote_addr: remote.addr(),
            component: remote.component(),
            data: builder.build(),
        };

        let transaction = CheckTransaction::new(agent_role, transaction_id, msg);
        self.state = CheckState::InProgress(Box::new(transaction));
        self.wake_task();
    }

    /// Processes a STUN response to this check's binding request.
    ///
    /// `local_addr`/`remote_addr` are the addresses the response was received
    /// on/from. A success response is accepted only if they match the pair's
    /// local base and remote address. On success the check becomes
    /// `Succeeded`; on a role conflict it goes back to `Waiting`; on any other
    /// failure it becomes `Failed`. The registered task is woken after the
    /// state change.
    ///
    /// # Errors
    ///
    /// * [`CheckError::UnknownTransaction`] if the check is not `InProgress`
    ///   or `response` belongs to a different transaction; the check's state
    ///   is left unchanged.
    /// * [`CheckError::RoleConflict`] carrying the role the request was sent
    ///   with, if the response has error code 487 (Role Conflict).
    /// * [`CheckError::Failed`] for any other error response or an address
    ///   mismatch.
    pub fn process_stun_response(
        &mut self,
        local_addr: SocketAddr,
        remote_addr: SocketAddr,
        response: &stun::Message,
    ) -> Result<(), CheckError> {
        // A response can arrive after the check was cancelled or re-triggered
        // (or be routed here by mistake): report it instead of panicking.
        let agent_role = match self.transaction() {
            Some(t) if t.id() == response.transaction_id() => t.agent_role(),
            _ => return Err(CheckError::UnknownTransaction),
        };

        let res = if let Some(err) = response.attributes().get_error_code() {
            // 487 = Role Conflict.
            if err.code() == 487 {
                Err(CheckError::RoleConflict(agent_role))
            } else {
                Err(CheckError::Failed)
            }
        } else if local_addr == self.pair.local().base()
            && remote_addr == self.pair.remote().addr()
        {
            // The response came back on the same address pair the request
            // was sent on.
            Ok(())
        } else {
            Err(CheckError::Failed)
        };

        self.state = match &res {
            Ok(_) => CheckState::Succeeded,
            Err(CheckError::RoleConflict(_)) => CheckState::Waiting,
            Err(_) => CheckState::Failed,
        };

        // The in-flight transaction (and any waker it held) was just dropped.
        self.wake_task();

        res
    }

    /// Polls the check for the next binding request to transmit.
    ///
    /// * `Pending` while `Frozen`/`Waiting`, or while an `InProgress`
    ///   transaction waits for its next retransmission. In both cases the
    ///   caller's waker is saved so that state changes made through this type
    ///   can wake it.
    /// * `Ready(Some(msg))` when a (re)transmission of the request is due.
    /// * `Ready(None)` once the check is not in progress; if the transaction
    ///   ran out of attempts, the check becomes `Failed` first.
    pub fn poll(&mut self, cx: &mut Context<'_>) -> Poll<Option<CheckMessage>> {
        match &mut self.state {
            CheckState::Frozen | CheckState::Waiting => {
                self.task = Some(cx.waker().clone());
                Poll::Pending
            }
            CheckState::InProgress(t) => {
                let res = t.poll(cx);

                // Remember the waker here too: `cancel`, `trigger` and
                // `process_stun_response` replace the transaction together
                // with the waker stored inside it, so without this a task
                // waiting on an in-progress check could never be woken again.
                if res.is_pending() {
                    self.task = Some(cx.waker().clone());
                }

                if let Some(msg) = ready!(res) {
                    Poll::Ready(Some(msg))
                } else {
                    self.state = CheckState::Failed;
                    Poll::Ready(None)
                }
            }
            _ => Poll::Ready(None),
        }
    }

    /// Merges `other` into this check (presumably both refer to the same
    /// candidate pair — this is not verified here).
    ///
    /// `other` replaces `self` when it is nominated and `self` is not. When
    /// both have the same nomination, it replaces `self` in three cases:
    /// `self` is `Frozen`; `self` is `Waiting` and `other` is `InProgress`;
    /// or `other` is in a final state. A task registered on the replaced
    /// check is woken so it re-polls the new state.
    pub fn update(&mut self, other: Check) {
        let replace = if self.nominated != other.nominated {
            // Only an upgrade to a nominated check is accepted.
            other.nominated
        } else {
            matches!(
                (&self.state, &other.state),
                (CheckState::Frozen, _)
                    | (CheckState::Waiting, CheckState::InProgress(_))
                    | (_, CheckState::Succeeded)
                    | (_, CheckState::Failed)
                    | (_, CheckState::Cancelled)
            )
        };

        if replace {
            // A plain `*self = other` would silently drop the waker of the
            // task currently polling this check.
            let task = self.task.take();
            *self = other;
            if let Some(task) = task {
                task.wake();
            }
        }
    }
}
/// Lifecycle state of a [`Check`].
///
/// Transitions made in this file: `Frozen` (set by `Check::new`) →
/// `Waiting` (`unfreeze`, `trigger`, or a role-conflict response) →
/// `InProgress` (`schedule`) → `Succeeded` / `Failed`. `Cancelled` is set by
/// `cancel` and by `finish` on unscheduled checks, and `trigger` moves any
/// state back to `Waiting`. The names presumably mirror the ICE
/// candidate-pair states of RFC 8445 — confirm.
enum CheckState {
    /// Initial state of a new check.
    Frozen,
    /// Ready to be scheduled; no request has been built yet.
    Waiting,
    /// A binding-request transaction is in flight (boxed to keep the enum small).
    InProgress(Box<CheckTransaction>),
    /// A matching success response was received.
    Succeeded,
    /// Error response, address mismatch, or the transaction ran out of attempts.
    Failed,
    /// The check was cancelled before reaching another final state.
    Cancelled,
}
/// STUN transaction ID: 12 bytes (96 bits). `Check::schedule` generates a
/// random one for every binding request it builds.
pub type TransactionId = [u8; 12];
/// One STUN binding-request transaction and its retransmission schedule.
struct CheckTransaction {
    /// STUN transaction ID of the request.
    id: TransactionId,
    /// Role of the local agent when the request was built.
    agent_role: AgentRole,
    /// The request that is (re)transmitted.
    message: CheckMessage,
    /// Number of requests that may still be sent.
    remaining_attempts: u32,
    /// Wait after the next non-final request; doubled after every send.
    next_timeout: Duration,
    /// Wait after the final request.
    last_timeout: Duration,
    /// Timer that fires when the next request, or the final timeout, is due.
    timeout: Pin<Box<Sleep>>,
    /// When the most recent request was emitted, if any.
    last_attempt: Option<Instant>,
    /// Waker saved by `poll` when it returns `Pending`; woken by `finish`.
    task: Option<Waker>,
}
impl CheckTransaction {
    /// Creates a transaction that (re)transmits `message` under the STUN
    /// transaction ID `id`.
    ///
    /// Schedule: at most `RC` requests. The wait after each request starts at
    /// `RTO` ms and doubles after every send, except that the wait after the
    /// final request is `RTO * RM` ms. `agent_role` is only stored.
    fn new(agent_role: AgentRole, id: TransactionId, message: CheckMessage) -> Self {
        Self {
            id,
            // Zero-length timer: the first poll after it elapses emits the
            // first request.
            timeout: Box::pin(tokio::time::sleep(Duration::from_millis(0))),
            next_timeout: Duration::from_millis(RTO),
            last_timeout: Duration::from_millis(RTO * RM),
            remaining_attempts: RC,
            message,
            agent_role,
            last_attempt: None,
            task: None,
        }
    }

    /// Returns the STUN transaction ID.
    fn id(&self) -> TransactionId {
        self.id
    }

    /// Returns the agent role the request was built with.
    fn agent_role(&self) -> AgentRole {
        self.agent_role
    }

    /// Winds the transaction down early.
    ///
    /// If a request has already been sent, no more are sent. The timer is
    /// re-armed to the earlier of the final-request deadline
    /// (`last_attempt + last_timeout`) and one second from now, and `poll`
    /// returns `Ready(None)` when it fires. If nothing has been sent yet, a
    /// single request is still allowed (followed by the `last_timeout` wait).
    /// The task saved by `poll`, if any, is woken.
    fn finish(&mut self) {
        if let Some(last_attempt) = self.last_attempt {
            // Cap the remaining wait at one second.
            let deadline = std::cmp::min(
                last_attempt + self.last_timeout,
                Instant::now() + Duration::from_millis(1_000),
            );
            self.timeout.as_mut().reset(deadline.into());
            self.remaining_attempts = 0;
        } else {
            // Nothing sent yet: allow exactly one request.
            self.remaining_attempts = 1;
        }
        if let Some(task) = self.task.take() {
            task.wake();
        }
    }

    /// Drives the retransmission timer.
    ///
    /// Returns `Ready(Some(msg))` when a (re)transmission is due; the timer is
    /// re-armed for the following deadline before returning. Returns
    /// `Ready(None)` when the timer fires with no attempts left (the
    /// transaction timed out). Otherwise returns `Pending`, in which case the
    /// waker is also saved in `self.task`.
    fn poll(&mut self, cx: &mut Context<'_>) -> Poll<Option<CheckMessage>> {
        let mut res = Poll::Pending;
        loop {
            // Polling the timer registers cx's waker with it while not yet due.
            let poll = self.timeout.poll_unpin(cx);
            if poll.is_pending() {
                let task = cx.waker();
                // The waker is only saved when nothing is being returned.
                if res.is_pending() {
                    self.task = Some(task.clone());
                }
                return res;
            }
            // Choose the wait that follows this send: `last_timeout` after the
            // final request, `next_timeout` otherwise.
            let timeout = if self.remaining_attempts == 0 {
                // Timer fired after the final wait: the transaction timed out.
                return Poll::Ready(None);
            } else if self.remaining_attempts == 1 {
                self.last_timeout
            } else {
                self.next_timeout
            };
            let now = Instant::now();
            self.last_attempt = Some(now);
            let deadline = now + timeout;
            self.timeout.as_mut().reset(deadline.into());
            // NOTE(review): if the re-armed timer were already elapsed, the
            // next loop iteration would overwrite this message with another
            // attempt (or return `Ready(None)`); not expected with these timeouts.
            res = Poll::Ready(Some(self.message.clone()));
            self.remaining_attempts -= 1;
            self.next_timeout *= 2;
        }
    }
}
/// A binding request built for a check, together with where to send it.
#[derive(Clone)]
pub struct CheckMessage {
    /// Component ID of the remote candidate.
    component: u8,
    /// Local address to send from (the local candidate's base).
    local_addr: SocketAddr,
    /// Remote candidate address to send to.
    remote_addr: SocketAddr,
    /// The encoded request.
    data: Bytes,
}
impl CheckMessage {
    /// Local address the request should be sent from.
    pub fn local_addr(&self) -> SocketAddr {
        let Self { local_addr, .. } = *self;
        local_addr
    }

    /// Remote address the request should be sent to.
    pub fn remote_addr(&self) -> SocketAddr {
        let Self { remote_addr, .. } = *self;
        remote_addr
    }

    /// Component ID the request belongs to.
    pub fn component(&self) -> u8 {
        let Self { component, .. } = *self;
        component
    }

    /// Consumes the message and returns the encoded request bytes.
    pub fn take_data(self) -> Bytes {
        let Self { data, .. } = self;
        data
    }
}
/// Errors reported while processing connectivity-check responses
/// (e.g. by `Check::process_stun_response`).
#[derive(Copy, Clone)]
pub enum CheckError {
    /// The check failed: an error response other than 487, or a success
    /// response whose addresses do not match the candidate pair.
    Failed,
    /// The peer answered with error 487 (Role Conflict); carries the agent
    /// role the request was sent with.
    RoleConflict(AgentRole),
    /// NOTE(review): presumably a response that does not belong to any
    /// in-progress transaction — confirm at the call sites.
    UnknownTransaction,
}