use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll, Waker};
use std::time::Instant;
use librice_proto::agent::{AgentError as ProtoAgentError, AgentPoll, AgentTransmit};
use librice_proto::component::ComponentConnectionState;
use crate::component::{Component, SelectedPair};
use crate::stream::Stream;
use librice_proto::candidate::TransportType;
#[derive(Debug)]
pub enum AgentError {
Proto(ProtoAgentError),
IoError(std::io::Error),
}
impl From<ProtoAgentError> for AgentError {
fn from(value: ProtoAgentError) -> Self {
Self::Proto(value)
}
}
impl From<std::io::Error> for AgentError {
fn from(value: std::io::Error) -> Self {
Self::IoError(value)
}
}
impl std::fmt::Display for AgentError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{:?}", self)
}
}
#[derive(Debug)]
pub struct Agent {
agent: Arc<Mutex<librice_proto::agent::Agent>>,
inner: Arc<Mutex<AgentInner>>,
}
#[derive(Debug, Default)]
pub struct AgentBuilder {
trickle_ice: bool,
controlling: bool,
}
impl AgentBuilder {
pub fn trickle_ice(mut self, trickle_ice: bool) -> Self {
self.trickle_ice = trickle_ice;
self
}
pub fn controlling(mut self, controlling: bool) -> Self {
self.controlling = controlling;
self
}
pub fn build(self) -> Agent {
let agent = Arc::new(Mutex::new(
librice_proto::agent::Agent::builder()
.trickle_ice(self.trickle_ice)
.controlling(self.controlling)
.build(),
));
Agent {
agent,
inner: Arc::new(Mutex::new(AgentInner::default())),
}
}
}
impl Default for Agent {
fn default() -> Self {
Agent::builder().build()
}
}
impl Agent {
pub fn builder() -> AgentBuilder {
AgentBuilder::default()
}
fn id(&self) -> usize {
self.agent.lock().unwrap().id()
}
pub fn messages(&self) -> impl futures::Stream<Item = AgentMessage> {
AgentStream {
agent: self.agent.clone(),
timer: None,
pending_transmit: None,
inner: self.inner.clone(),
}
}
#[tracing::instrument(
name = "ice_add_stream",
skip(self),
fields(
ice.id = self.id()
)
)]
pub fn add_stream(&self) -> Stream {
let stream_id = self.agent.lock().unwrap().add_stream();
let weak_proto_agent = Arc::downgrade(&self.agent);
let weak_inner = Arc::downgrade(&self.inner);
let ret = crate::stream::Stream::new(weak_proto_agent, weak_inner, stream_id);
let mut inner = self.inner.lock().unwrap();
inner.streams.push(ret.clone());
ret
}
pub fn stream(&self, id: usize) -> Option<Stream> {
let inner = self.inner.lock().unwrap();
inner.streams.get(id).cloned()
}
pub fn close(&self) -> Result<(), AgentError> {
let ret = self.agent.lock().unwrap().close();
let mut inner = self.inner.lock().unwrap();
if let Some(waker) = inner.waker.take() {
waker.wake();
}
ret.map_err(AgentError::Proto)
}
pub fn controlling(&self) -> bool {
self.agent.lock().unwrap().controlling()
}
pub fn add_stun_server(&self, transport: TransportType, addr: SocketAddr) {
self.agent.lock().unwrap().add_stun_server(transport, addr)
}
}
#[derive(Debug, Default)]
pub(crate) struct AgentInner {
pub(crate) waker: Option<Waker>,
streams: Vec<Stream>,
}
pub enum AgentMessage {
ComponentStateChange(Component, ComponentConnectionState),
}
#[derive(Debug)]
struct AgentStream {
agent: Arc<Mutex<librice_proto::agent::Agent>>,
inner: Arc<Mutex<AgentInner>>,
timer: Option<Pin<Box<async_io::Timer>>>,
pending_transmit: Option<AgentTransmit<'static>>,
}
impl futures::stream::Stream for AgentStream {
type Item = AgentMessage;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
if let Some(mut transmit) = self.pending_transmit.take() {
let mut inner = self.inner.lock().unwrap();
if let Some(stream) = inner.streams.get(transmit.stream_id) {
if let Some(retry) = stream.handle_transmit(transmit.transmit, cx.waker().clone()) {
transmit.transmit = retry;
inner.waker = Some(cx.waker().clone());
drop(inner);
self.as_mut().pending_transmit = Some(transmit.into_owned());
return Poll::Pending;
}
}
}
let weak_proto_agent = Arc::downgrade(&self.agent);
let weak_agent_inner = Arc::downgrade(&self.inner);
let mut agent = self.agent.lock().unwrap();
let now = Instant::now();
let wait = match agent.poll(now) {
AgentPoll::Closed => return Poll::Ready(None),
AgentPoll::Transmit(transmit) => {
let mut transmit = transmit.into_owned();
drop(agent);
let mut inner = self.inner.lock().unwrap();
if let Some(stream) = inner.streams.get(transmit.stream_id) {
if let Some(retry) =
stream.handle_transmit(transmit.transmit, cx.waker().clone())
{
transmit.transmit = retry;
inner.waker = Some(cx.waker().clone());
drop(inner);
self.as_mut().pending_transmit = Some(transmit);
return Poll::Pending;
}
}
cx.waker().wake_by_ref();
return Poll::Pending;
}
AgentPoll::TcpConnect(tcp_connect) => {
drop(agent);
let inner = self.inner.lock().unwrap();
if let Some(stream) = inner.streams.get(tcp_connect.stream_id) {
let weak_stream = Arc::downgrade(&stream.inner);
drop(inner);
Stream::handle_tcp_connect(
weak_stream,
weak_proto_agent,
weak_agent_inner,
tcp_connect.stream_id,
tcp_connect.component_id,
tcp_connect.from,
tcp_connect.to,
cx.waker().clone(),
);
}
cx.waker().wake_by_ref();
return Poll::Pending;
}
AgentPoll::WaitUntil(time) => Some(time),
AgentPoll::SelectedPair(pair) => {
drop(agent);
let inner = self.inner.lock().unwrap();
if let Some(stream) = inner.streams.get(pair.stream_id) {
if let Some(component) = stream.component(pair.component_id) {
if let Some(socket) = stream.socket_for_pair(pair.selected.candidate_pair())
{
component.set_selected_pair(SelectedPair::new(*pair.selected, socket));
}
}
}
cx.waker().wake_by_ref();
return Poll::Pending;
}
AgentPoll::ComponentStateChange(state) => {
drop(agent);
let inner = self.inner.lock().unwrap();
if let Some(stream) = inner.streams.get(state.stream_id) {
if let Some(component) = stream.component(state.component_id) {
return Poll::Ready(Some(AgentMessage::ComponentStateChange(
component.clone(),
state.state,
)));
}
}
cx.waker().wake_by_ref();
return Poll::Pending;
}
};
drop(agent);
if let Some(wait) = wait {
match self.as_mut().timer.as_mut() {
Some(timer) => timer.set_at(wait),
None => self.as_mut().timer = Some(Box::pin(async_io::Timer::at(wait))),
}
if core::future::Future::poll(self.as_mut().timer.as_mut().unwrap().as_mut(), cx)
.is_pending()
{
let mut agent_inner = self.inner.lock().unwrap();
agent_inner.waker = Some(cx.waker().clone());
return Poll::Pending;
}
cx.waker().wake_by_ref();
return Poll::Pending;
}
let mut agent_inner = self.inner.lock().unwrap();
agent_inner.waker = Some(cx.waker().clone());
Poll::Pending
}
}
#[cfg(test)]
mod tests {
use super::*;
fn init() {
crate::tests::test_init_log();
}
#[test]
fn controlling() {
init();
let agent = Agent::builder().controlling(true).build();
assert!(agent.controlling());
let agent = Agent::builder().controlling(false).build();
assert!(!agent.controlling());
}
}