use crate::packet::ControlTypes::*;
use crate::protocol::connection::{Connection, ConnectionAction};
use crate::protocol::handshake::Handshake;
use crate::protocol::receiver::{Receiver, ReceiverAlgorithmAction};
use crate::protocol::sender::{Sender, SenderAlgorithmAction};
use crate::protocol::TimeBase;
use crate::Packet::*;
use crate::{ConnectionSettings, ControlPacket, Packet, SrtCongestCtrl};
use std::net::SocketAddr;
use std::pin::Pin;
use std::task::{Context, Poll, Waker};
use std::{
io, mem,
sync::{Arc, Mutex},
time::Instant,
};
use bytes::Bytes;
use futures::channel::{mpsc, oneshot};
use futures::prelude::*;
use futures::{future, ready, select};
use log::{debug, error, info, trace};
use tokio::time::delay_until;
/// User-facing handle to an SRT connection driven by a background task.
///
/// Created by `create_bidrectional_srt`. Reading (`Stream`) yields data released by the
/// task's receiver. Writing (`Sink`) hands data to the task's sender.
///
/// Field order is also drop order. Dropping this handle drops `sender` (closing the outgoing
/// data channel) and `_drop_oneshot` (signalling the task to close its sender).
pub struct SrtSocket {
    // Data released by the background task's receiver.
    recvr: mpsc::Receiver<(Instant, Bytes)>,
    // Outgoing data consumed by the background task and queued on its sender.
    sender: mpsc::Sender<(Instant, Bytes)>,
    // Resolves (with `Err`) once the background task has ended and dropped its oneshot sender.
    close: oneshot::Receiver<()>,
    // Copy of the connection settings, exposed via `settings()`.
    settings: ConnectionSettings,
    // (waker of a pending `poll_flush`, "sender is flushed" flag), updated by the task.
    flush_wakeup: Arc<Mutex<(Option<Waker>, bool)>>,
    // Never used directly; dropping it tells the background task to close the sender.
    _drop_oneshot: oneshot::Sender<()>,
}
/// The event chosen by the connection task's `select!`, applied once the select has finished.
// Clippy flags this enum because one variant (presumably `DelegatePacket`, which holds a whole
// `Packet`) is much larger than the others. Each value lives for only a single loop iteration,
// so the size difference is not worth a heap allocation per event.
#[allow(clippy::large_enum_variant)]
enum Action {
    /// The underlying transport yielded a packet (`Some`) or ended (`None`).
    DelegatePacket(Option<(Packet, SocketAddr)>),
    /// The user supplied outgoing data (`Some`) or closed the data channel (`None`).
    Send(Option<(Instant, Bytes)>),
    /// The user-facing socket was dropped; the sender should start closing.
    CloseSender,
    /// The wakeup timer fired; there is nothing to dispatch.
    Nothing,
}
pub fn create_bidrectional_srt<T>(sock: T, conn: crate::Connection) -> SrtSocket
where
T: Stream<Item = (Packet, SocketAddr)>
+ Sink<(Packet, SocketAddr), Error = io::Error>
+ Send
+ Unpin
+ 'static,
{
let (mut release, recvr) = mpsc::channel(128);
let (sender, new_data) = mpsc::channel(128);
let (_drop_oneshot, close_oneshot) = oneshot::channel();
let (close_send, close_recv) = oneshot::channel();
let conn_copy = conn.clone();
let fw = Arc::new(Mutex::new((None as Option<Waker>, true)));
let flush_wakeup = fw.clone();
tokio::spawn(async move {
let mut close_receiver = close_oneshot.fuse();
let _close_sender = close_send;
let mut new_data = new_data.fuse();
let mut sock = sock.fuse();
let time_base = TimeBase::new(conn_copy.settings.socket_start_time);
let mut connection = Connection::new(conn_copy.settings);
let mut sender = Sender::new(conn_copy.settings, conn_copy.handshake, SrtCongestCtrl);
let mut receiver = Receiver::new(conn_copy.settings, Handshake::Connector);
let mut flushed = true;
loop {
let (sender_timeout, close) = match sender.next_action(Instant::now()) {
SenderAlgorithmAction::WaitUntilAck | SenderAlgorithmAction::WaitForData => {
(None, false)
}
SenderAlgorithmAction::WaitUntil(t) => (Some(t), false),
SenderAlgorithmAction::Close => {
trace!("{:?} Send returned close", sender.settings().local_sockid);
(None, true)
}
};
while let Some(out) = sender.pop_output() {
if let Err(e) = sock.send(out).await {
error!("Error while seding packet: {:?}", e);
}
}
if close && receiver.is_flushed() {
trace!(
"{:?} Send returned close and receiver flushed",
sender.settings().local_sockid
);
return;
}
let recvr_timeout = loop {
match receiver.next_algorithm_action(Instant::now()) {
ReceiverAlgorithmAction::TimeBoundedReceive(t2) => {
break Some(t2);
}
ReceiverAlgorithmAction::SendControl(cp, addr) => {
if let Err(e) = sock.send((Packet::Control(cp), addr)).await {
error!("Error while sending packet {:?}", e);
}
}
ReceiverAlgorithmAction::OutputData(ib) => {
if let Err(e) = release.send(ib).await {
error!("Error while releasing packet {:?}", e);
}
}
ReceiverAlgorithmAction::Close => {
if sender.is_flushed() {
trace!("Recv returned close and sender flushed");
return;
}
}
};
};
let connection_timeout = loop {
match connection.next_action(Instant::now()) {
ConnectionAction::ContinueUntil(timeout) => break Some(timeout),
ConnectionAction::Close => {
if receiver.is_flushed() {
info!(
"{:?} Receiver flush and connectiont timeout",
sender.settings().local_sockid
);
return;
}
break None;
}
ConnectionAction::SendKeepAlive => sock
.send((
Control(ControlPacket {
timestamp: time_base.timestamp_from(Instant::now()),
dest_sockid: sender.settings().remote_sockid,
control_type: KeepAlive,
}),
sender.settings().remote,
))
.await
.unwrap(),
}
};
if sender.is_flushed() != flushed {
let mut l = fw.lock().unwrap();
flushed = sender.is_flushed();
l.1 = sender.is_flushed();
if sender.is_flushed() {
if let Some(waker) = mem::replace(&mut l.0, None) {
waker.wake();
}
}
}
let timeout = [sender_timeout, recvr_timeout, connection_timeout]
.iter()
.filter_map(|&x| x)
.min();
let timeout_fut = async {
if let Some(to) = timeout {
let now = Instant::now();
trace!(
"{:?} scheduling wakeup at {}{:?} from {}{}",
sender.settings().local_sockid,
if to > now { "+" } else { "-" },
if to > now { to - now } else { now - to },
if sender_timeout.is_some() {
"sender "
} else {
""
},
if recvr_timeout.is_some() {
"receiver"
} else {
""
}
);
delay_until(to.into()).await
} else {
trace!(
"{:?} not scheduling wakeup!!!",
sender.settings().local_sockid
);
future::pending().await
}
};
let action = select! {
_ = timeout_fut.fuse() => Action::Nothing,
res = sock.next() =>
Action::DelegatePacket(res),
res = new_data.next() => {
Action::Send(res)
}
_ = close_receiver => {
Action::CloseSender
}
};
match action {
Action::Nothing => {}
Action::DelegatePacket(res) => {
match res {
Some((pack, from)) => {
connection.on_packet(Instant::now());
match &pack {
Data(_) => receiver.handle_packet(Instant::now(), (pack, from)),
Control(cp) => match &cp.control_type {
Handshake(_) | Ack { .. } | Nak(_) | DropRequest { .. } => {
sender.handle_packet((pack, from), Instant::now()).unwrap();
}
Ack2(_) => receiver.handle_packet(Instant::now(), (pack, from)),
Shutdown => {
sender
.handle_packet((pack.clone(), from), Instant::now())
.unwrap();
receiver.handle_packet(Instant::now(), (pack, from));
}
KeepAlive => {}
Srt(_) => unimplemented!(),
},
}
}
None => {
info!(
"{:?} Exiting because underlying stream ended",
sender.settings().local_sockid
);
break;
}
}
}
Action::Send(res) => match res {
Some(item) => {
trace!("{:?} queued packet to send", sender.settings().local_sockid);
sender.handle_data(item);
}
None => {
debug!("Incoming data stream closed");
sender.handle_close();
}
},
Action::CloseSender => sender.handle_close(),
}
}
});
SrtSocket {
recvr,
sender,
close: close_recv,
settings: conn.settings,
flush_wakeup,
_drop_oneshot,
}
}
impl SrtSocket {
pub fn settings(&self) -> &ConnectionSettings {
&self.settings
}
}
/// Reading half of the socket: yields data released by the background task's receiver.
///
/// Items are always `Ok`. The stream ends (`None`) once the background task has dropped its
/// end of the release channel.
impl Stream for SrtSocket {
    type Item = Result<(Instant, Bytes), io::Error>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
        let incoming = Pin::new(&mut self.recvr);
        incoming.poll_next(cx).map(|next| next.map(Ok))
    }
}
/// Writing half of the socket: each `(Instant, Bytes)` item is handed to the background task,
/// which queues it on the SRT sender.
///
/// Failures of the underlying channel are reported as `io::ErrorKind::NotConnected`.
impl Sink<(Instant, Bytes)> for SrtSocket {
    type Error = io::Error;

    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
        Pin::new(&mut self.sender)
            .poll_ready(cx)
            .map_err(|e| io::Error::new(io::ErrorKind::NotConnected, e))
    }

    fn start_send(mut self: Pin<&mut Self>, item: (Instant, Bytes)) -> Result<(), Self::Error> {
        self.sender
            .start_send(item)
            .map_err(|e| io::Error::new(io::ErrorKind::NotConnected, e))
    }

    /// Resolves once the background task reports the SRT sender as flushed, or once the task
    /// has exited.
    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
        ready!(Pin::new(&mut self.sender).poll_flush(cx))
            .map_err(|e| io::Error::new(io::ErrorKind::NotConnected, e))?;
        {
            let mut state = self.flush_wakeup.lock().unwrap();
            if state.1 {
                return Poll::Ready(Ok(()));
            }
            // The task wakes this waker when the sender next becomes flushed.
            state.0 = Some(cx.waker().clone());
        }
        // If the background task has exited, nobody will update the flag or wake the stored
        // waker, so waiting on the flag alone would hang forever. `close` resolves when the
        // task ends (it drops the oneshot sender). Treat that as done, as `poll_close` does.
        match Pin::new(&mut self.close).poll(cx) {
            Poll::Pending => Poll::Pending,
            Poll::Ready(_) => Poll::Ready(Ok(())),
        }
    }

    /// Closes the outgoing data channel, then waits for the background task to finish.
    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
        ready!(Pin::new(&mut self.sender).poll_close(cx))
            .map_err(|e| io::Error::new(io::ErrorKind::NotConnected, e))?;
        match Pin::new(&mut self.close).poll(cx) {
            Poll::Pending => Poll::Pending,
            // The task never sends on this oneshot; it only drops the sender when it ends.
            Poll::Ready(Err(_)) => Poll::Ready(Ok(())),
            Poll::Ready(Ok(_)) => unreachable!(),
        }
    }
}