use std::net::SocketAddr;
use std::time::{Duration, Instant};
use tokio::net::UdpSocket;
use tokio::sync::mpsc;
use tokio::time::timeout;
use tokio_util::sync::CancellationToken;
use tracing::{debug, info};
use super::handshake::{caller_conclusion, caller_induction, SrtHandshake};
use super::packet::build_data_packet;
use crate::{Result, StreamError};
/// Maximum number of payload bytes placed in a single SRT data packet (1316).
///
/// NOTE(review): presumably seven 188-byte MPEG-TS packets per datagram — confirm.
const TS_BYTES_PER_DATAGRAM: usize = 188 * 7;

/// How long the caller waits for each handshake response before giving up.
const SETUP_TIMEOUT: Duration = Duration::from_secs(5);
/// Caller-side endpoint that connects to a remote address over UDP and pushes
/// the bytes it receives on a channel as SRT data packets.
#[derive(Debug)]
pub struct SrtCaller {
    /// Remote address the caller connects to.
    addr: SocketAddr,
    /// Local SRT socket id sent in both handshake packets.
    socket_id: u32,
}
impl SrtCaller {
    /// Creates a caller that will connect to `addr`.
    ///
    /// The local socket id is derived from the destination port: the port is
    /// rotated into the upper 16 bits, XOR-ed with a fixed constant, and the
    /// lowest bit is forced on, so the id is never zero.
    pub fn new(addr: SocketAddr) -> Self {
        let socket_id = (0x4152_434C ^ u32::from(addr.port()).rotate_left(16)) | 1;
        Self { addr, socket_id }
    }

    /// Performs the two-step caller handshake and then forwards every chunk
    /// received on `ts` as SRT data packets until `shutdown` is cancelled or
    /// the channel is closed.
    ///
    /// Each chunk is split into pieces of at most `TS_BYTES_PER_DATAGRAM`
    /// bytes; every piece becomes one data packet.
    ///
    /// # Errors
    ///
    /// Returns an error when the socket cannot be bound or connected, when a
    /// handshake packet cannot be sent or received, when a handshake response
    /// does not arrive within `SETUP_TIMEOUT`, or when a response cannot be
    /// parsed. A failed send during the streaming phase is logged and ends
    /// the session with `Ok(())`.
    pub async fn run(
        self,
        mut ts: mpsc::Receiver<bytes::Bytes>,
        shutdown: CancellationToken,
    ) -> Result<()> {
        // Bind the wildcard address of the same family as the peer: an IPv4
        // socket cannot be connected to an IPv6 destination.
        let local = match self.addr {
            SocketAddr::V4(_) => SocketAddr::from((std::net::Ipv4Addr::UNSPECIFIED, 0)),
            SocketAddr::V6(_) => SocketAddr::from((std::net::Ipv6Addr::UNSPECIFIED, 0)),
        };
        let sock = UdpSocket::bind(local).await?;
        sock.connect(self.addr).await?;

        let mut buf = [0u8; 1500];

        // Step 1: induction. The response carries the cookie echoed in step 2.
        sock.send(&caller_induction(self.socket_id, 0)).await?;
        let n = timeout(SETUP_TIMEOUT, sock.recv(&mut buf))
            .await
            .map_err(|_| StreamError::protocol("srt induction timed out"))??;
        let induction = SrtHandshake::parse(&buf[..n])
            .ok_or_else(|| StreamError::protocol("malformed srt induction response"))?;

        // Step 2: conclusion. The response names the peer's socket id, which
        // becomes the destination id of every data packet.
        sock.send(&caller_conclusion(self.socket_id, 0, induction.cookie))
            .await?;
        let n = timeout(SETUP_TIMEOUT, sock.recv(&mut buf))
            .await
            .map_err(|_| StreamError::protocol("srt conclusion timed out"))??;
        let agreement = SrtHandshake::parse(&buf[..n])
            .ok_or_else(|| StreamError::protocol("malformed srt conclusion response"))?;
        let dest = agreement.socket_id;
        info!(addr = %self.addr, dest, "srt caller connected");

        let start = Instant::now();
        let mut seq = 0u32;
        let mut msg = 0u32;
        loop {
            tokio::select! {
                _ = shutdown.cancelled() => break,
                chunk = ts.recv() => match chunk {
                    Some(bytes) => {
                        for piece in bytes.chunks(TS_BYTES_PER_DATAGRAM) {
                            // The truncating cast is intentional: the timestamp
                            // is a 32-bit microsecond counter that wraps.
                            let ts_us = start.elapsed().as_micros() as u32;
                            let pkt = build_data_packet(seq, msg, ts_us, dest, piece);
                            if let Err(err) = sock.send(&pkt).await {
                                // A failed send ends the session quietly, but
                                // the cause is recorded instead of discarded.
                                debug!(
                                    addr = %self.addr,
                                    error = %err,
                                    "srt caller send failed, stopping"
                                );
                                return Ok(());
                            }
                            // Sequence numbers stay within 31 bits and message
                            // numbers within 26 bits.
                            seq = seq.wrapping_add(1) & 0x7FFF_FFFF;
                            msg = msg.wrapping_add(1) & 0x03FF_FFFF;
                        }
                    }
                    // All senders dropped: nothing more to forward.
                    None => break,
                }
            }
        }
        debug!(addr = %self.addr, "srt caller finished");
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::srt::{handshake::respond, SrtPacket};

    /// Upper bound on every wait, so a broken caller fails the test instead
    /// of hanging it.
    const STEP_TIMEOUT: Duration = Duration::from_secs(5);

    /// Receives one datagram on `sock`, panicking with the name of the
    /// awaited step if nothing arrives within `STEP_TIMEOUT`.
    async fn recv_step(sock: &UdpSocket, buf: &mut [u8], step: &str) -> (usize, SocketAddr) {
        timeout(STEP_TIMEOUT, sock.recv_from(buf))
            .await
            .unwrap_or_else(|_| panic!("timed out waiting for {}", step))
            .unwrap()
    }

    #[tokio::test]
    async fn caller_handshakes_and_sends_data_over_loopback() {
        let listener = UdpSocket::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let (tx, rx) = mpsc::channel(4);
        let shutdown = CancellationToken::new();
        let sh = shutdown.clone();
        let caller = SrtCaller::new(addr);
        let caller_task = tokio::spawn(async move { caller.run(rx, sh).await });

        // Answer both handshake steps the caller sends.
        let mut buf = [0u8; 1500];
        for step in ["induction", "conclusion"] {
            let (n, peer) = recv_step(&listener, &mut buf, step).await;
            let reply = respond(&buf[..n]).unwrap();
            listener.send_to(&reply, peer).await.unwrap();
        }

        // One full-size chunk must come out as a single data packet.
        tx.send(bytes::Bytes::from(vec![0x47u8; TS_BYTES_PER_DATAGRAM]))
            .await
            .unwrap();
        let (n, _) = recv_step(&listener, &mut buf, "data packet").await;
        assert!(
            matches!(SrtPacket::parse(&buf[..n]), Some(SrtPacket::Data { .. })),
            "caller sent an SRT data packet after the handshake"
        );

        // Cancellation must end the caller promptly and without an error.
        shutdown.cancel();
        let joined = timeout(STEP_TIMEOUT, caller_task)
            .await
            .expect("caller stopped after cancellation");
        assert!(
            matches!(joined, Ok(Ok(()))),
            "caller task finished cleanly after shutdown"
        );
    }
}