msf_ice/
socket.rs

1use std::{
2    future::Future,
3    io,
4    mem::MaybeUninit,
5    net::{IpAddr, Ipv4Addr, SocketAddr},
6    pin::Pin,
7    sync::{Arc, Mutex},
8    task::{Context, Poll},
9    time::Duration,
10};
11
12use bytes::Bytes;
13use futures::{
14    channel::{mpsc, oneshot},
15    ready, Sink, SinkExt, Stream, StreamExt,
16};
17use msf_stun as stun;
18use tokio::{io::ReadBuf, net::UdpSocket, task::JoinHandle};
19
20use crate::log::Logger;
21
/// Data packet.
///
/// A single received datagram together with the addresses it travelled
/// between.
#[derive(Clone)]
pub struct Packet {
    /// Address of the local socket on which the datagram was received.
    local_addr: SocketAddr,
    /// Address of the peer the datagram was received from.
    remote_addr: SocketAddr,
    /// Payload of the datagram.
    data: Bytes,
}
29
30impl Packet {
31    /// Get the local address where the packet was received.
32    #[inline]
33    pub fn local_addr(&self) -> SocketAddr {
34        self.local_addr
35    }
36
37    /// Get the remote address where the packet was sent from.
38    #[inline]
39    pub fn remote_addr(&self) -> SocketAddr {
40        self.remote_addr
41    }
42
43    /// Get packet data.
44    #[inline]
45    pub fn data(&self) -> &Bytes {
46        &self.data
47    }
48
49    /// Take the packet data.
50    #[inline]
51    pub fn take_data(self) -> Bytes {
52        self.data
53    }
54}
55
/// Packet received from a UDP socket.
type InputPacket = Packet;

/// Outgoing packet: the destination address and the payload to be sent.
type OutputPacket = (SocketAddr, Bytes);

/// Sending half of an unbounded queue of outgoing packets.
type OutputPacketTx = mpsc::UnboundedSender<OutputPacket>;
64
/// ICE socket manager.
///
/// Owns all UDP sockets used by an ICE session and multiplexes their
/// bindings and incoming packets into single channels.
pub struct ICESockets {
    /// Logger used for diagnostics.
    logger: Logger,
    /// Sockets that have been successfully created so far.
    open_sockets: Vec<Socket>,
    /// Bindings reported by the sockets.
    binding_rx: mpsc::Receiver<Binding>,
    /// Sockets delivered by the background tasks that create them.
    socket_rx: mpsc::Receiver<Socket>,
    /// Packets received on any of the sockets.
    packet_rx: mpsc::Receiver<Packet>,
}
73
74impl ICESockets {
75    /// Create a new socket manager.
76    pub fn new(logger: Logger, local_addresses: &[IpAddr], stun_servers: &[SocketAddr]) -> Self {
77        let (binding_tx, binding_rx) = mpsc::channel(4);
78        let (socket_tx, socket_rx) = mpsc::channel(4);
79        let (packet_tx, packet_rx) = mpsc::channel(4);
80
81        let unspecified = &[IpAddr::from(Ipv4Addr::UNSPECIFIED)][..];
82
83        let local_addresses = if local_addresses.is_empty() {
84            unspecified
85        } else {
86            local_addresses
87        };
88
89        let stun_servers = Arc::new(stun_servers.to_vec());
90
91        for addr in local_addresses {
92            let logger = logger.clone();
93            let addr = SocketAddr::from((*addr, 0));
94            let binding_tx = binding_tx.clone();
95            let packet_tx = packet_tx.clone();
96            let stun_servers = stun_servers.clone();
97
98            let mut socket_tx = socket_tx.clone();
99
100            tokio::spawn(async move {
101                let socket =
102                    Socket::new(logger.clone(), addr, &stun_servers, packet_tx, binding_tx);
103
104                match socket.await {
105                    Ok(socket) => {
106                        let _ = socket_tx.send(socket).await;
107                    }
108                    Err(err) => {
109                        warn!(logger, "unable to create a new UDP socket"; "cause" => %err);
110                    }
111                }
112            });
113        }
114
115        Self {
116            logger,
117            open_sockets: Vec::with_capacity(local_addresses.len()),
118            binding_rx,
119            socket_rx,
120            packet_rx,
121        }
122    }
123
124    /// Get the next local binding.
125    pub fn poll_next_binding(&mut self, cx: &mut Context<'_>) -> Poll<Option<Binding>> {
126        let sockets = self.poll_sockets(cx);
127
128        if let Some(binding) = ready!(self.binding_rx.poll_next_unpin(cx)) {
129            Poll::Ready(Some(binding))
130        } else if sockets.is_pending() {
131            Poll::Pending
132        } else {
133            Poll::Ready(None)
134        }
135    }
136
137    /*/// Close all sockets matching a given filter function.
138    ///
139    /// TODO: use this
140    pub fn close_sockets<F>(&mut self, mut filter: F)
141    where
142        F: FnMut(SocketAddr) -> bool,
143    {
144        self.open_sockets.retain(|socket| !filter(socket.local_addr()));
145    }*/
146
147    /// Receive the next packet.
148    pub fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Packet> {
149        loop {
150            match self.poll_next_binding(cx) {
151                Poll::Ready(Some(_)) => (),
152                Poll::Ready(None) => break,
153                Poll::Pending => break,
154            }
155        }
156
157        if let Poll::Ready(Some(packet)) = self.packet_rx.poll_next_unpin(cx) {
158            Poll::Ready(packet)
159        } else {
160            Poll::Pending
161        }
162    }
163
164    /// Send given data from a given local binding to a given destination.
165    pub fn send(&mut self, local_addr: SocketAddr, remote_addr: SocketAddr, data: Bytes) {
166        let socket = self
167            .open_sockets
168            .iter_mut()
169            .find(|socket| socket.is_bound_to(local_addr));
170
171        if let Some(socket) = socket {
172            let _ = socket.send(remote_addr, data);
173        } else {
174            debug!(self.logger, "unknown socket for local binding"; "binding" => %local_addr);
175        }
176    }
177
178    /// Poll pending sockets.
179    fn poll_sockets(&mut self, cx: &mut Context<'_>) -> Poll<()> {
180        while let Poll::Ready(ready) = self.socket_rx.poll_next_unpin(cx) {
181            if let Some(socket) = ready {
182                self.open_sockets.push(socket);
183            } else {
184                return Poll::Ready(());
185            }
186        }
187
188        Poll::Pending
189    }
190}
191
192/// Socket binding.
193#[derive(Copy, Clone)]
194pub enum Binding {
195    Local(LocalBinding),
196    Reflexive(ReflexiveBinding),
197}
198
199impl Binding {
200    /// Create a new local binding.
201    fn local(addr: SocketAddr) -> Self {
202        Self::Local(LocalBinding::new(addr))
203    }
204
205    /// Create a new reflexive binding.
206    fn reflexive(base: SocketAddr, addr: SocketAddr, source: SocketAddr) -> Self {
207        Self::Reflexive(ReflexiveBinding::new(base, addr, source))
208    }
209}
210
/// Local socket binding.
#[derive(Copy, Clone)]
pub struct LocalBinding {
    /// Local address of the socket.
    addr: SocketAddr,
}

impl LocalBinding {
    /// Wrap a given local socket address.
    fn new(addr: SocketAddr) -> Self {
        LocalBinding { addr }
    }

    /// Local address the socket is bound to.
    pub fn addr(self) -> SocketAddr {
        let LocalBinding { addr } = self;

        addr
    }
}
228
/// Reflexive socket binding.
#[derive(Copy, Clone)]
pub struct ReflexiveBinding {
    /// Local address of the socket.
    base: SocketAddr,
    /// Public address of the socket.
    addr: SocketAddr,
    /// Origin of the binding information.
    source: SocketAddr,
}

impl ReflexiveBinding {
    /// Put together a binding from the local address (`base`), the public
    /// address (`addr`) and the origin of the information (`source`).
    fn new(base: SocketAddr, addr: SocketAddr, source: SocketAddr) -> Self {
        ReflexiveBinding { base, addr, source }
    }

    /// Local address the socket is bound to.
    pub fn base(&self) -> SocketAddr {
        let Self { base, .. } = self;

        *base
    }

    /// Public address the socket is bound to.
    pub fn addr(&self) -> SocketAddr {
        let Self { addr, .. } = self;

        *addr
    }

    /// Origin of the binding information (e.g. a STUN server).
    pub fn source(&self) -> SocketAddr {
        let Self { source, .. } = self;

        *source
    }
}
258
/// ICE socket.
///
/// Handle to a UDP socket that is served by background tasks. The reader and
/// the keep-alive tasks are aborted when the handle is dropped.
struct Socket {
    /// Address the underlying UDP socket is bound to.
    local_addr: SocketAddr,
    /// Queue of outgoing packets consumed by the writer task.
    output_packet_tx: OutputPacketTx,
    /// Task reading packets from the UDP socket.
    reader: JoinHandle<()>,
    /// Task obtaining the server-reflexive address and keeping it alive.
    keep_alive: JoinHandle<()>,
}
266
267impl Socket {
268    /// Create a new ICE socket.
269    ///
270    /// A new UDP socket will be created and it will be bound to a given local
271    /// address (the port will be assigned automatically if the given port is
272    /// 0).
273    ///
274    /// Once the socket is created a server-reflexive address will be
275    /// automatically obtained from one of the given STUN servers.
276    ///
277    /// All incoming packets will be passed to a given packet sink and all
278    /// bindings (the local binding and optionally the server-reflexive
279    /// binding) will be passed to a given binding sink.
280    async fn new<S, B>(
281        logger: Logger,
282        local_addr: SocketAddr,
283        stun_servers: &[SocketAddr],
284        input_packet_tx: S,
285        mut binding_tx: B,
286    ) -> io::Result<Self>
287    where
288        S: Sink<InputPacket> + Send + Unpin + 'static,
289        B: Sink<Binding> + Send + Unpin + 'static,
290    {
291        let socket = UdpSocketWrapper::bind(local_addr).await?;
292
293        let local_addr = socket.local_addr();
294
295        let _ = binding_tx.send(Binding::local(local_addr)).await;
296
297        let (output_packet_tx, output_packet_rx) = mpsc::unbounded();
298
299        tokio::spawn(socket.write_all(logger.clone(), output_packet_rx));
300
301        let mut stun_context = StunContext::new(output_packet_tx.clone());
302
303        let ctx = stun_context.clone();
304
305        let reader = tokio::spawn(async move {
306            let _ = socket.read_all(logger, input_packet_tx, ctx).await;
307        });
308
309        let stun_servers = stun_servers
310            .iter()
311            .copied()
312            .filter(|addr| local_addr.is_ipv4() == addr.is_ipv4())
313            .collect::<Vec<_>>();
314
315        let keep_alive = tokio::spawn(async move {
316            let reflexive_addr = stun_context.get_reflexive_addr(stun_servers);
317
318            if let Some((reflexive_addr, stun_server)) = reflexive_addr.await {
319                let binding = Binding::reflexive(local_addr, reflexive_addr, stun_server);
320
321                let _ = binding_tx.send(binding).await;
322
323                // there will be no more bindings from us
324                std::mem::drop(binding_tx);
325
326                // TODO: check the timing
327                stun_context
328                    .keep_alive(stun_server, Duration::from_secs(10))
329                    .await;
330            }
331        });
332
333        let res = Self {
334            local_addr,
335            output_packet_tx,
336            reader,
337            keep_alive,
338        };
339
340        Ok(res)
341    }
342
343    /// Check if the socket is bound to a given address.
344    fn is_bound_to(&self, local_addr: SocketAddr) -> bool {
345        self.local_addr == local_addr
346            || (local_addr.port() == 0 && self.local_addr.ip() == local_addr.ip())
347    }
348
349    /// Send given data to a given remote destination.
350    fn send(&self, remote_addr: SocketAddr, data: Bytes) -> io::Result<()> {
351        self.output_packet_tx
352            .unbounded_send((remote_addr, data))
353            .map_err(|_| io::Error::from(io::ErrorKind::BrokenPipe))
354    }
355}
356
357impl Drop for Socket {
358    fn drop(&mut self) {
359        self.keep_alive.abort();
360        self.reader.abort();
361    }
362}
363
/// Helper struct.
///
/// A bound UDP socket together with its local address. The socket is
/// reference-counted so that it can be shared between the reader and the
/// writer.
struct UdpSocketWrapper {
    /// The underlying UDP socket.
    inner: Arc<UdpSocket>,
    /// Address the socket is bound to.
    local_addr: SocketAddr,
}
369
370impl UdpSocketWrapper {
371    /// Create a new UDP socket bound to a given local address.
372    async fn bind(local_addr: SocketAddr) -> io::Result<Self> {
373        let socket = UdpSocket::bind(local_addr).await?;
374
375        let local_addr = socket.local_addr()?;
376
377        let res = Self {
378            inner: Arc::new(socket),
379            local_addr,
380        };
381
382        Ok(res)
383    }
384
385    /// Get the socket binding.
386    fn local_addr(&self) -> SocketAddr {
387        self.local_addr
388    }
389
390    /// Send all packets from a given stream using the underlying UDP socket.
391    fn write_all<S>(&self, logger: Logger, mut stream: S) -> impl Future<Output = ()>
392    where
393        S: Stream<Item = OutputPacket> + Unpin,
394    {
395        let socket = self.inner.clone();
396
397        async move {
398            while let Some((peer, data)) = stream.next().await {
399                if let Err(err) = socket.send_to(&data, peer).await {
400                    // log the error
401                    warn!(logger, "socket write error"; "cause" => %err);
402
403                    // ... and terminate the loop
404                    break;
405                }
406            }
407        }
408    }
409
410    /// Read all packets from the underlying UDP socket and feed them to a
411    /// given sink.
412    async fn read_all<S>(
413        self,
414        logger: Logger,
415        mut sink: S,
416        mut stun_context: StunContext,
417    ) -> Result<(), S::Error>
418    where
419        S: Sink<Packet> + Unpin,
420    {
421        let stream = UdpSocketStream::from(self);
422
423        let mut filtered = stream.filter_map(move |item| {
424            let res = match item {
425                Ok(packet) => {
426                    if let Err(packet) = stun_context.process_packet(packet) {
427                        Some(Ok(packet))
428                    } else {
429                        None
430                    }
431                }
432                Err(err) => Some(Err(err)),
433            };
434
435            futures::future::ready(res)
436        });
437
438        while let Some(item) = filtered.next().await {
439            match item {
440                Ok(packet) => sink.send(packet).await?,
441                Err(err) => {
442                    warn!(logger, "socket read error"; "cause" => %err);
443                }
444            }
445        }
446
447        Ok(())
448    }
449}
450
/// Helper struct.
///
/// Stream of packets received from a UDP socket.
struct UdpSocketStream {
    /// The underlying UDP socket; `None` once the stream has been terminated
    /// by a read error.
    socket: Option<Arc<UdpSocket>>,
    /// Address the socket is bound to (used as the local address of all
    /// received packets).
    local_addr: SocketAddr,
}
456
457impl Stream for UdpSocketStream {
458    type Item = io::Result<Packet>;
459
460    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
461        if let Some(socket) = self.socket.as_ref() {
462            // XXX: use MaybeUninit::uninit_array() once stabilized
463            let mut buffer: [MaybeUninit<u8>; 65_536] =
464                unsafe { MaybeUninit::uninit().assume_init() };
465
466            let mut buffer = ReadBuf::uninit(&mut buffer);
467
468            match ready!(socket.poll_recv_from(cx, &mut buffer)) {
469                Ok(peer) => {
470                    let packet = Packet {
471                        local_addr: self.local_addr,
472                        remote_addr: peer,
473                        data: Bytes::copy_from_slice(buffer.filled()),
474                    };
475
476                    Poll::Ready(Some(Ok(packet)))
477                }
478                Err(err) => {
479                    // drop the socket, we don't want to poll it again
480                    self.socket = None;
481
482                    Poll::Ready(Some(Err(err)))
483                }
484            }
485        } else {
486            Poll::Ready(None)
487        }
488    }
489}
490
491impl From<UdpSocketWrapper> for UdpSocketStream {
492    fn from(socket: UdpSocketWrapper) -> Self {
493        Self {
494            socket: Some(socket.inner),
495            local_addr: socket.local_addr,
496        }
497    }
498}
499
// TODO: make these configurable
// NOTE(review): the names and the values look like the retransmission
// parameters from RFC 5389 (RTO, Rm, Rc) - confirm

/// Initial retransmission timeout of a STUN request in milliseconds.
const RTO: u64 = 500;
/// Multiplier of `RTO` giving the time we wait for a response after sending
/// the last request of a STUN transaction.
const RM: u64 = 16;
/// Maximum number of requests sent within a single STUN transaction when
/// obtaining the server-reflexive address.
const RC: u32 = 7;

/// STUN transaction ID (12 bytes).
type StunTransactionId = [u8; 12];
507
/// Socket STUN context.
///
/// Keeps track of pending STUN transactions of a single socket. Clones of
/// the context share the same set of transactions.
#[derive(Clone)]
struct StunContext {
    /// Shared set of pending STUN transactions.
    inner: Arc<Mutex<InnerStunContext>>,
    /// Queue of outgoing packets used for sending STUN requests.
    output_packet_tx: OutputPacketTx,
}
514
515impl StunContext {
516    /// Create a new STUN context.
517    fn new(output_packet_tx: OutputPacketTx) -> Self {
518        Self {
519            inner: Arc::new(Mutex::new(InnerStunContext::new())),
520            output_packet_tx,
521        }
522    }
523
524    /// Get server-reflexive address using one of the given STUN servers.
525    async fn get_reflexive_addr<I>(&mut self, stun_servers: I) -> Option<(SocketAddr, SocketAddr)>
526    where
527        I: IntoIterator<Item = SocketAddr>,
528    {
529        let stun_servers = stun_servers.into_iter();
530
531        let reflexive_addrs = futures::stream::iter(stun_servers.enumerate())
532            .then(|(index, addr)| async move {
533                if index > 0 {
534                    tokio::time::sleep(Duration::from_millis(RTO << 1)).await;
535                }
536
537                addr
538            })
539            .map(|stun_server| {
540                let request = self.new_binding_request(stun_server, RC);
541
542                async move {
543                    if let Ok(reflexive_addr) = request.await {
544                        Some((reflexive_addr, stun_server))
545                    } else {
546                        None
547                    }
548                }
549            })
550            .buffered((((1 << (RC - 1)) + RM) * RTO / 1_000) as usize)
551            .filter_map(futures::future::ready);
552
553        futures::pin_mut!(reflexive_addrs);
554
555        reflexive_addrs.next().await
556    }
557
558    /// Keep alive the server-reflexive binding by sending STUN requests to a
559    /// given STUN server in a given interval.
560    async fn keep_alive(&mut self, stun_server: SocketAddr, interval: Duration) {
561        loop {
562            tokio::time::sleep(interval).await;
563
564            let _ = self.new_binding_request(stun_server, 1).await;
565        }
566    }
567
568    /// Create a new binding request.
569    fn new_binding_request(
570        &mut self,
571        stun_server: SocketAddr,
572        attempts: u32,
573    ) -> impl Future<Output = io::Result<SocketAddr>> {
574        let transaction_id = rand::random();
575
576        let (reflexive_addr_tx, reflexive_addr_rx) = oneshot::channel();
577
578        let transaction = StunTransaction {
579            context: self.clone(),
580            output_packet_tx: self.output_packet_tx.clone(),
581            reflexive_addr_rx,
582            stun_server,
583            transaction_id,
584            next_timeout: Duration::from_millis(RTO),
585            last_timeout: Duration::from_millis(RTO * RM),
586            remaining_attempts: attempts,
587        };
588
589        let handle = StunTransactionHandle {
590            transaction_id,
591            reflexive_addr_tx,
592        };
593
594        self.inner.lock().unwrap().add_handle(handle);
595
596        transaction.resolve()
597    }
598
599    /// Remove a given STUN transaction handle.
600    fn remove_handle(&mut self, id: StunTransactionId) {
601        self.inner.lock().unwrap().remove_handle(id);
602    }
603
604    /// Process and consume a given input packet or return it back for further
605    /// processing by the ICE channel.
606    fn process_packet(&mut self, packet: InputPacket) -> Result<(), InputPacket> {
607        self.inner.lock().unwrap().process_packet(packet)
608    }
609}
610
/// Inner STUN context.
///
/// The state shared by all clones of a [`StunContext`].
struct InnerStunContext {
    /// Handles of the STUN transactions waiting for a response.
    transactions: Vec<StunTransactionHandle>,
}
615
616impl InnerStunContext {
617    /// Create a new context.
618    fn new() -> Self {
619        Self {
620            transactions: Vec::new(),
621        }
622    }
623
624    /// Add a given transaction handle.
625    fn add_handle(&mut self, handle: StunTransactionHandle) {
626        self.transactions.push(handle);
627    }
628
629    /// Remove a given transaction handle and return it.
630    fn remove_handle(
631        &mut self,
632        transaction_id: StunTransactionId,
633    ) -> Option<StunTransactionHandle> {
634        self.transactions
635            .iter()
636            .position(|t| t.transaction_id() == transaction_id)
637            .map(|i| self.transactions.swap_remove(i))
638    }
639
640    /// Process a given input packet.
641    fn process_packet(&mut self, packet: InputPacket) -> Result<(), InputPacket> {
642        let data = packet.data();
643
644        if let Ok(msg) = stun::Message::from_frame(data.clone()) {
645            if msg.is_rfc5389_message()
646                && msg.is_response()
647                && msg.method() == stun::Method::Binding
648            {
649                if let Some(handle) = self.remove_handle(msg.transaction_id()) {
650                    let attrs = msg.attributes();
651
652                    if let Some(addr) = attrs.get_any_mapped_address() {
653                        handle.resolve(addr);
654                    }
655
656                    return Ok(());
657                }
658            }
659        }
660
661        Err(packet)
662    }
663}
664
/// STUN transaction.
///
/// A single binding request together with its retransmission state. The
/// corresponding handle is removed from the context when the transaction is
/// dropped.
struct StunTransaction<S, F> {
    /// Context where the transaction handle is registered.
    context: StunContext,
    /// Sink for the outgoing requests.
    output_packet_tx: S,
    /// Future resolved with the address reported by the STUN server.
    reflexive_addr_rx: F,
    /// Address of the STUN server.
    stun_server: SocketAddr,
    /// ID of the transaction.
    transaction_id: StunTransactionId,
    /// Time to wait for a response to the next request (doubled after every
    /// unanswered request).
    next_timeout: Duration,
    /// Time to wait for a response to the last request.
    last_timeout: Duration,
    /// Number of requests that can still be sent.
    remaining_attempts: u32,
}
676
677impl<S, F, E> StunTransaction<S, F>
678where
679    S: Sink<OutputPacket> + Unpin,
680    F: Future<Output = Result<SocketAddr, E>> + Unpin,
681{
682    /// Resolve the transaction.
683    async fn resolve(mut self) -> io::Result<SocketAddr> {
684        let builder = stun::MessageBuilder::binding_request(self.transaction_id);
685
686        let msg = builder.build();
687
688        while self.remaining_attempts > 0 {
689            self.output_packet_tx
690                .send((self.stun_server, msg.clone()))
691                .await
692                .map_err(|_| io::Error::from(io::ErrorKind::BrokenPipe))?;
693
694            let timeout = if self.remaining_attempts > 1 {
695                self.next_timeout
696            } else {
697                self.last_timeout
698            };
699
700            let addr = tokio::time::timeout(timeout, &mut self.reflexive_addr_rx);
701
702            if let Ok(res) = addr.await {
703                return res.map_err(|_| io::Error::from(io::ErrorKind::BrokenPipe));
704            }
705
706            self.remaining_attempts -= 1;
707            self.next_timeout *= 2;
708        }
709
710        Err(io::Error::from(io::ErrorKind::TimedOut))
711    }
712}
713
714impl<S, F> Drop for StunTransaction<S, F> {
715    fn drop(&mut self) {
716        self.context.remove_handle(self.transaction_id);
717    }
718}
719
/// Sending half of the channel used for delivering the address reported by a
/// STUN server to the corresponding transaction.
type ReflexiveAddrTx = oneshot::Sender<SocketAddr>;

/// STUN transaction handle.
///
/// The part of a pending transaction that is kept in the STUN context. It is
/// used for passing the response to the transaction.
struct StunTransactionHandle {
    /// ID of the transaction.
    transaction_id: StunTransactionId,
    /// Channel for delivering the reported address.
    reflexive_addr_tx: ReflexiveAddrTx,
}
728
729impl StunTransactionHandle {
730    /// Get the transaction ID.
731    fn transaction_id(&self) -> StunTransactionId {
732        self.transaction_id
733    }
734
735    /// Resolve the STUN transaction.
736    fn resolve(self, reflexive_addr: SocketAddr) {
737        let _ = self.reflexive_addr_tx.send(reflexive_addr);
738    }
739}