use alloc::boxed::Box;
use alloc::vec::Vec;
use core::error::Error;
use core::fmt::Display;
use core::net::SocketAddr;
use core::sync::atomic::AtomicU64;
use core::time::Duration;
use stun_proto::Instant;
use stun_proto::types::data::Data;
use crate::candidate::{ParseCandidateError, TransportType};
use crate::component::ComponentConnectionState;
use crate::conncheck::{CheckListSetPollRet, ConnCheckEvent, ConnCheckListSet, SelectedPair};
use crate::gathering::{GatherPoll, GatheredCandidate};
use crate::rand::rand_u64;
use crate::stream::{Stream, StreamMut, StreamState};
use crate::turn::TurnConfig;
use stun_proto::agent::{StunError, Transmit};
use stun_proto::types::message::StunParseError;
use tracing::{info, warn};
/// Errors produced by the ICE agent.
///
/// `From` conversions are provided for `ParseCandidateError`, `StunError` and
/// `StunParseError`.
#[derive(Debug)]
pub enum AgentError {
    /// Generic failure; also the fallback for `StunError` variants that have no
    /// dedicated mapping.
    Failed,
    AlreadyExists,
    /// Also produced from `StunError::AlreadyInProgress`.
    AlreadyInProgress,
    NotInProgress,
    /// Also produced from `StunError::ResourceNotFound`.
    ResourceNotFound,
    Malformed,
    WrongImplementation,
    ConnectionClosed,
    /// STUN parsing error; produced from `StunError::ParseError(_)` and from any
    /// `StunParseError` (the inner detail is discarded).
    StunParse,
    /// STUN writing error; produced from `StunError::WriteError(_)` (the inner
    /// detail is discarded).
    StunWrite,
    /// Wraps a candidate parsing error.
    CandidateParse(ParseCandidateError),
    /// Also produced from `StunError::ProtocolViolation`.
    ProtocolViolation,
}
// `AgentError` carries no source chain; the default `Error` methods suffice.
impl Error for AgentError {}

impl Display for AgentError {
    /// Renders the error with its `Debug` representation (variant name plus
    /// any payload).
    fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
        f.write_fmt(format_args!("{:?}", self))
    }
}
impl From<ParseCandidateError> for AgentError {
fn from(e: ParseCandidateError) -> Self {
Self::CandidateParse(e)
}
}
impl From<StunError> for AgentError {
    /// Maps a STUN error onto the matching `AgentError` variant.
    ///
    /// Parse/write errors lose their inner detail. Any variant without a
    /// dedicated mapping becomes `AgentError::Failed`.
    fn from(stun_err: StunError) -> Self {
        match stun_err {
            StunError::AlreadyInProgress => Self::AlreadyInProgress,
            StunError::ResourceNotFound => Self::ResourceNotFound,
            StunError::ProtocolViolation => Self::ProtocolViolation,
            StunError::WriteError(_) => Self::StunWrite,
            StunError::ParseError(_) => Self::StunParse,
            _ => Self::Failed,
        }
    }
}
impl From<StunParseError> for AgentError {
fn from(_e: StunParseError) -> Self {
Self::StunParse
}
}
/// An ICE agent: owns the streams, the connectivity-check state and the
/// configured STUN/TURN servers.
///
/// Construct one with [`Agent::builder`] or [`Agent::default`].
#[derive(Debug)]
pub struct Agent {
    // Identifier taken from the global `AGENT_COUNT` counter in `AgentBuilder::build`.
    id: u64,
    // Connectivity-check lists; each stream owns one list (see `add_stream`).
    pub(crate) checklistset: ConnCheckListSet,
    // Servers registered with `add_stun_server`.
    pub(crate) stun_servers: Vec<(TransportType, SocketAddr)>,
    // Servers registered with `add_turn_server`.
    pub(crate) turn_servers: Vec<TurnConfig>,
    // Stream state; a stream's id is its index in this vector (see `add_stream`).
    streams: Vec<StreamState>,
}
/// Builder for [`Agent`]. Both options default to `false`.
#[derive(Debug, Default)]
pub struct AgentBuilder {
    // Passed to the checklist set builder's `trickle_ice` option.
    trickle_ice: bool,
    // Initial controlling role passed to the checklist set builder.
    controlling: bool,
}
impl AgentBuilder {
pub fn trickle_ice(mut self, trickle_ice: bool) -> Self {
self.trickle_ice = trickle_ice;
self
}
pub fn controlling(mut self, controlling: bool) -> Self {
self.controlling = controlling;
self
}
pub fn build(self) -> Agent {
turn_client_proto::types::debug_init();
rice_stun_types::debug_init();
let id = AGENT_COUNT.fetch_add(1, core::sync::atomic::Ordering::SeqCst);
let tie_breaker = rand_u64();
let controlling = self.controlling;
Agent {
id,
checklistset: ConnCheckListSet::builder(tie_breaker, controlling)
.trickle_ice(self.trickle_ice)
.build(),
stun_servers: Vec::new(),
turn_servers: Vec::new(),
streams: Vec::new(),
}
}
}
/// Source of agent ids; incremented once per `AgentBuilder::build` call.
static AGENT_COUNT: AtomicU64 = AtomicU64::new(0);

impl Default for Agent {
    /// An agent built with all [`AgentBuilder`] defaults.
    fn default() -> Self {
        Self::builder().build()
    }
}
impl Agent {
    /// Starts building a new [`Agent`].
    pub fn builder() -> AgentBuilder {
        AgentBuilder::default()
    }

    /// The identifier assigned to this agent when it was built.
    pub fn id(&self) -> u64 {
        self.id
    }

    /// Adds a new stream and returns its id.
    ///
    /// A fresh checklist is created in the checklist set for the stream. Ids are
    /// assigned sequentially and equal the stream's index in the agent's list.
    #[tracing::instrument(
        name = "ice_add_stream",
        skip(self),
        fields(
            ice.id = self.id
        )
    )]
    pub fn add_stream(&mut self) -> usize {
        let checklist_id = self.checklistset.new_list();
        let id = self.streams.len();
        let stream = crate::stream::StreamState::new(id, checklist_id);
        self.streams.push(stream);
        id
    }

    /// Closes the agent's checklist set at `now`.
    ///
    /// [`Agent::poll`] returns [`AgentPoll::Closed`] once the checklist set
    /// reports that it is closed.
    #[tracing::instrument(
        name = "ice_close",
        skip(self),
        fields(
            ice.id = self.id
        )
    )]
    pub fn close(&mut self, now: Instant) {
        info!("closing agent");
        self.checklistset.close(now);
    }

    /// Whether the checklist set currently reports the controlling role.
    pub fn controlling(&self) -> bool {
        self.checklistset.controlling()
    }

    /// Records a STUN server; the list is exposed through [`Agent::stun_servers`].
    #[tracing::instrument(
        name = "ice_add_stun_server",
        skip(self)
        fields(ice.id = self.id)
    )]
    pub fn add_stun_server(&mut self, transport: TransportType, addr: SocketAddr) {
        info!("Adding stun server");
        self.stun_servers.push((transport, addr));
    }

    /// The STUN servers registered so far, in insertion order.
    pub fn stun_servers(&self) -> &[(TransportType, SocketAddr)] {
        &self.stun_servers
    }

    /// Records a TURN server; the list is exposed through [`Agent::turn_servers`].
    #[tracing::instrument(
        name = "ice_add_turn_server",
        skip(self)
        fields(ice.id = self.id)
    )]
    pub fn add_turn_server(&mut self, config: TurnConfig) {
        info!("Adding turn server");
        self.turn_servers.push(config);
    }

    /// The TURN servers registered so far, in insertion order.
    pub fn turn_servers(&self) -> &[TurnConfig] {
        &self.turn_servers
    }

    /// A read-only view of stream `id`, or `None` if no such stream exists.
    pub fn stream(&self, id: usize) -> Option<crate::stream::Stream<'_>> {
        (id < self.streams.len()).then(|| Stream::from_agent(self, id))
    }

    pub(crate) fn stream_state(&self, id: usize) -> Option<&crate::stream::StreamState> {
        self.streams.get(id)
    }

    /// A mutable view of stream `id`, or `None` if no such stream exists.
    pub fn mut_stream(&mut self, id: usize) -> Option<StreamMut<'_>> {
        // A plain `if` is used instead of `bool::then`: a closure capturing
        // `&mut self` could not hand the resulting borrow back out.
        if id < self.streams.len() {
            Some(StreamMut::from_agent(self, id))
        } else {
            None
        }
    }

    pub(crate) fn mut_stream_state(
        &mut self,
        id: usize,
    ) -> Option<&mut crate::stream::StreamState> {
        self.streams.get_mut(id)
    }

    /// Folds `wait` into `lowest`, keeping whichever deadline comes first.
    fn merge_lowest_wait(lowest: &mut Option<Instant>, wait: Instant) {
        match *lowest {
            Some(current) if current <= wait => (),
            _ => *lowest = Some(wait),
        }
    }

    /// Advances the agent and returns the next action for the caller.
    ///
    /// Gathering is polled for every stream first, and the first gathering
    /// event found is returned. Otherwise the checklist set is polled
    /// repeatedly until one of three things happens:
    /// - it reports that it is closed;
    /// - it yields an event for a known stream;
    /// - it asks to wait.
    ///
    /// Events for unknown streams, and state updates that do not change the
    /// component's state, are skipped. The returned [`AgentPoll::WaitUntil`]
    /// carries the earliest deadline reported by gathering or the checklist set.
    #[tracing::instrument(
        name = "agent_poll",
        ret
        skip(self)
        fields(
            id = self.id,
        )
    )]
    pub fn poll(&mut self, now: Instant) -> AgentPoll {
        let mut lowest_wait = None;
        for stream in self.streams.iter_mut() {
            let stream_id = stream.id();
            match stream.poll_gather(now) {
                GatherPoll::AllocateSocket {
                    component_id,
                    transport,
                    local_addr,
                    remote_addr,
                } => {
                    return AgentPoll::AllocateSocket(AgentSocket {
                        stream_id,
                        component_id,
                        transport,
                        from: local_addr,
                        to: remote_addr,
                    });
                }
                GatherPoll::WaitUntil(earliest_wait) => {
                    Self::merge_lowest_wait(&mut lowest_wait, earliest_wait);
                }
                GatherPoll::NewCandidate(candidate) => {
                    return AgentPoll::GatheredCandidate(AgentGatheredCandidate {
                        stream_id,
                        gathered: candidate,
                    });
                }
                GatherPoll::Complete(component_id) => {
                    return AgentPoll::GatheringComplete(AgentGatheringComplete {
                        stream_id,
                        component_id,
                    });
                }
                GatherPoll::Finished => (),
            }
        }
        loop {
            match self.checklistset.poll(now) {
                CheckListSetPollRet::Closed => return AgentPoll::Closed,
                CheckListSetPollRet::Completed => continue,
                CheckListSetPollRet::WaitUntil(earliest_wait) => {
                    Self::merge_lowest_wait(&mut lowest_wait, earliest_wait);
                    break;
                }
                CheckListSetPollRet::AllocateSocket {
                    checklist_id,
                    component_id: cid,
                    transport,
                    local_addr: from,
                    remote_addr: to,
                } => {
                    if let Some(stream) =
                        self.streams.iter().find(|s| s.checklist_id == checklist_id)
                    {
                        return AgentPoll::AllocateSocket(AgentSocket {
                            stream_id: stream.id(),
                            component_id: cid,
                            transport,
                            from,
                            to,
                        });
                    } else {
                        warn!("did not find stream for allocate socket {from:?} -> {to:?}");
                    }
                }
                CheckListSetPollRet::RemoveSocket {
                    checklist_id,
                    component_id: cid,
                    transport,
                    local_addr: from,
                    remote_addr: to,
                } => {
                    if let Some(stream) =
                        self.streams.iter().find(|s| s.checklist_id == checklist_id)
                    {
                        return AgentPoll::RemoveSocket(AgentSocket {
                            stream_id: stream.id(),
                            component_id: cid,
                            transport,
                            from,
                            to,
                        });
                    } else {
                        warn!("did not find stream for remove socket {from:?} -> {to:?}");
                    }
                }
                CheckListSetPollRet::Event {
                    checklist_id,
                    event: ConnCheckEvent::ComponentState(cid, state),
                } => {
                    if let Some(stream) = self
                        .streams
                        .iter_mut()
                        .find(|s| s.checklist_id == checklist_id)
                    {
                        if let Some(component) = stream.mut_component_state(cid) {
                            // Only report the state if the component's stored state
                            // actually changed.
                            if component.set_state(state) {
                                return AgentPoll::ComponentStateChange(
                                    AgentComponentStateChange {
                                        stream_id: stream.id(),
                                        component_id: cid,
                                        state,
                                    },
                                );
                            }
                        }
                    }
                }
                CheckListSetPollRet::Event {
                    checklist_id,
                    event: ConnCheckEvent::SelectedPair(cid, selected),
                } => {
                    if let Some(stream) =
                        self.streams.iter().find(|s| s.checklist_id == checklist_id)
                    {
                        if stream.component_state(cid).is_some() {
                            return AgentPoll::SelectedPair(AgentSelectedPair {
                                stream_id: stream.id(),
                                component_id: cid,
                                selected,
                            });
                        }
                    }
                }
            }
        }
        // The loop only breaks after `lowest_wait` has been set. The fallback
        // guards against future changes to the loop's exit paths.
        AgentPoll::WaitUntil(lowest_wait.unwrap_or_else(|| now + Duration::from_secs(600)))
    }

    /// Returns the next piece of data the caller should send, if any.
    ///
    /// Gathering transmits are returned first, checking streams in id order.
    /// After that come connectivity-check transmits from the checklist set.
    /// A checklist transmit with no matching stream is logged and skipped,
    /// and polling continues. Returning `None` there would end a caller's
    /// drain loop while other transmits may still be queued.
    pub fn poll_transmit(&mut self, now: Instant) -> Option<AgentTransmit> {
        for stream in self.streams.iter_mut() {
            let stream_id = stream.id();
            if let Some((_component_id, transmit)) = stream.poll_gather_transmit(now) {
                return Some(AgentTransmit::from_data(stream_id, transmit));
            }
        }
        while let Some(transmit) = self.checklistset.poll_transmit(now) {
            if let Some(stream) = self
                .streams
                .iter()
                .find(|s| s.checklist_id == transmit.checklist_id)
            {
                return Some(AgentTransmit {
                    stream_id: stream.id(),
                    transmit: transmit.transmit,
                });
            }
            warn!(
                "did not find stream for transmit {:?} -> {:?}",
                transmit.transmit.from, transmit.transmit.to
            );
        }
        None
    }
}
/// Result of [`Agent::poll`]: the next action the caller needs to handle.
#[derive(Debug)]
pub enum AgentPoll {
    /// The checklist set reported that it is closed.
    Closed,
    /// Earliest wait deadline reported by gathering and the checklist set.
    WaitUntil(Instant),
    /// A socket allocation request from gathering or the checklist set.
    AllocateSocket(AgentSocket),
    /// A socket removal request from the checklist set.
    RemoveSocket(AgentSocket),
    /// The checklist set reported a selected pair for a component.
    SelectedPair(AgentSelectedPair),
    /// A component's connection state changed. Emitted only when the
    /// component's `set_state` reports a change.
    ComponentStateChange(AgentComponentStateChange),
    /// Gathering produced a new candidate.
    GatheredCandidate(AgentGatheredCandidate),
    /// Gathering completed for a component.
    GatheringComplete(AgentGatheringComplete),
}
/// Data the caller should send on behalf of a stream, returned by
/// `Agent::poll_transmit`.
#[derive(Debug)]
pub struct AgentTransmit {
    /// Id of the stream this transmit belongs to.
    pub stream_id: usize,
    /// Addresses and owned payload of the data to send.
    pub transmit: Transmit<Box<[u8]>>,
}

impl AgentTransmit {
    /// Builds an owned transmit from a gathering transmit whose payload may be
    /// borrowed.
    fn from_data(stream_id: usize, transmit: Transmit<Data<'_>>) -> Self {
        let owned_transmit = transmit.reinterpret_data(|payload| match payload.into_owned() {
            Data::Owned(bytes) => bytes.take(),
            // `into_owned` is relied on to always yield the owned variant.
            _ => unreachable!(),
        });
        Self {
            stream_id,
            transmit: owned_transmit,
        }
    }
}
/// Socket description carried by [`AgentPoll::AllocateSocket`] and
/// [`AgentPoll::RemoveSocket`].
#[derive(Debug)]
pub struct AgentSocket {
    /// Stream the socket belongs to.
    pub stream_id: usize,
    /// Component within the stream.
    pub component_id: usize,
    /// Transport of the socket.
    pub transport: TransportType,
    /// Local address.
    pub from: SocketAddr,
    /// Remote address.
    pub to: SocketAddr,
}
/// Payload of [`AgentPoll::SelectedPair`].
#[derive(Debug)]
pub struct AgentSelectedPair {
    /// Stream the component belongs to.
    pub stream_id: usize,
    /// Component the pair was selected for.
    pub component_id: usize,
    /// The selected pair as reported by the checklist set.
    pub selected: Box<SelectedPair>,
}
/// Payload of [`AgentPoll::ComponentStateChange`].
#[derive(Debug)]
#[repr(C)]
pub struct AgentComponentStateChange {
    /// Stream the component belongs to.
    pub stream_id: usize,
    /// Component whose state changed.
    pub component_id: usize,
    /// The component's new state.
    pub state: ComponentConnectionState,
}
/// Payload of [`AgentPoll::GatheredCandidate`].
#[derive(Debug)]
#[repr(C)]
pub struct AgentGatheredCandidate {
    /// Stream the candidate was gathered for.
    pub stream_id: usize,
    /// The newly gathered candidate.
    pub gathered: GatheredCandidate,
}
/// Payload of [`AgentPoll::GatheringComplete`].
#[derive(Debug)]
#[repr(C)]
pub struct AgentGatheringComplete {
    /// Stream whose gathering finished.
    pub stream_id: usize,
    /// Component whose gathering finished.
    pub component_id: usize,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The builder's `controlling` option is reflected by `Agent::controlling()`.
    #[test]
    fn controlling() {
        let _log = crate::tests::test_init_log();
        for expected in [true, false] {
            let agent = Agent::builder().controlling(expected).build();
            assert_eq!(agent.controlling(), expected);
        }
    }
}