1use std::{
2 future::Future,
3 io,
4 mem::MaybeUninit,
5 net::{IpAddr, Ipv4Addr, SocketAddr},
6 pin::Pin,
7 sync::{Arc, Mutex},
8 task::{Context, Poll},
9 time::Duration,
10};
11
12use bytes::Bytes;
13use futures::{
14 channel::{mpsc, oneshot},
15 ready, Sink, SinkExt, Stream, StreamExt,
16};
17use msf_stun as stun;
18use tokio::{io::ReadBuf, net::UdpSocket, task::JoinHandle};
19
20use crate::log::Logger;
21
22#[derive(Clone)]
24pub struct Packet {
25 local_addr: SocketAddr,
26 remote_addr: SocketAddr,
27 data: Bytes,
28}
29
30impl Packet {
31 #[inline]
33 pub fn local_addr(&self) -> SocketAddr {
34 self.local_addr
35 }
36
37 #[inline]
39 pub fn remote_addr(&self) -> SocketAddr {
40 self.remote_addr
41 }
42
43 #[inline]
45 pub fn data(&self) -> &Bytes {
46 &self.data
47 }
48
49 #[inline]
51 pub fn take_data(self) -> Bytes {
52 self.data
53 }
54}
55
/// Packet read from a socket (the item type accepted by the input packet
/// sink and by the STUN packet filter).
type InputPacket = Packet;

/// Outgoing datagram: destination address and payload.
type OutputPacket = (SocketAddr, Bytes);

/// Sending half of the unbounded queue of outgoing datagrams.
type OutputPacketTx = mpsc::UnboundedSender<OutputPacket>;
64
65pub struct ICESockets {
67 logger: Logger,
68 open_sockets: Vec<Socket>,
69 binding_rx: mpsc::Receiver<Binding>,
70 socket_rx: mpsc::Receiver<Socket>,
71 packet_rx: mpsc::Receiver<Packet>,
72}
73
74impl ICESockets {
75 pub fn new(logger: Logger, local_addresses: &[IpAddr], stun_servers: &[SocketAddr]) -> Self {
77 let (binding_tx, binding_rx) = mpsc::channel(4);
78 let (socket_tx, socket_rx) = mpsc::channel(4);
79 let (packet_tx, packet_rx) = mpsc::channel(4);
80
81 let unspecified = &[IpAddr::from(Ipv4Addr::UNSPECIFIED)][..];
82
83 let local_addresses = if local_addresses.is_empty() {
84 unspecified
85 } else {
86 local_addresses
87 };
88
89 let stun_servers = Arc::new(stun_servers.to_vec());
90
91 for addr in local_addresses {
92 let logger = logger.clone();
93 let addr = SocketAddr::from((*addr, 0));
94 let binding_tx = binding_tx.clone();
95 let packet_tx = packet_tx.clone();
96 let stun_servers = stun_servers.clone();
97
98 let mut socket_tx = socket_tx.clone();
99
100 tokio::spawn(async move {
101 let socket =
102 Socket::new(logger.clone(), addr, &stun_servers, packet_tx, binding_tx);
103
104 match socket.await {
105 Ok(socket) => {
106 let _ = socket_tx.send(socket).await;
107 }
108 Err(err) => {
109 warn!(logger, "unable to create a new UDP socket"; "cause" => %err);
110 }
111 }
112 });
113 }
114
115 Self {
116 logger,
117 open_sockets: Vec::with_capacity(local_addresses.len()),
118 binding_rx,
119 socket_rx,
120 packet_rx,
121 }
122 }
123
124 pub fn poll_next_binding(&mut self, cx: &mut Context<'_>) -> Poll<Option<Binding>> {
126 let sockets = self.poll_sockets(cx);
127
128 if let Some(binding) = ready!(self.binding_rx.poll_next_unpin(cx)) {
129 Poll::Ready(Some(binding))
130 } else if sockets.is_pending() {
131 Poll::Pending
132 } else {
133 Poll::Ready(None)
134 }
135 }
136
137 pub fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Packet> {
149 loop {
150 match self.poll_next_binding(cx) {
151 Poll::Ready(Some(_)) => (),
152 Poll::Ready(None) => break,
153 Poll::Pending => break,
154 }
155 }
156
157 if let Poll::Ready(Some(packet)) = self.packet_rx.poll_next_unpin(cx) {
158 Poll::Ready(packet)
159 } else {
160 Poll::Pending
161 }
162 }
163
164 pub fn send(&mut self, local_addr: SocketAddr, remote_addr: SocketAddr, data: Bytes) {
166 let socket = self
167 .open_sockets
168 .iter_mut()
169 .find(|socket| socket.is_bound_to(local_addr));
170
171 if let Some(socket) = socket {
172 let _ = socket.send(remote_addr, data);
173 } else {
174 debug!(self.logger, "unknown socket for local binding"; "binding" => %local_addr);
175 }
176 }
177
178 fn poll_sockets(&mut self, cx: &mut Context<'_>) -> Poll<()> {
180 while let Poll::Ready(ready) = self.socket_rx.poll_next_unpin(cx) {
181 if let Some(socket) = ready {
182 self.open_sockets.push(socket);
183 } else {
184 return Poll::Ready(());
185 }
186 }
187
188 Poll::Pending
189 }
190}
191
192#[derive(Copy, Clone)]
194pub enum Binding {
195 Local(LocalBinding),
196 Reflexive(ReflexiveBinding),
197}
198
199impl Binding {
200 fn local(addr: SocketAddr) -> Self {
202 Self::Local(LocalBinding::new(addr))
203 }
204
205 fn reflexive(base: SocketAddr, addr: SocketAddr, source: SocketAddr) -> Self {
207 Self::Reflexive(ReflexiveBinding::new(base, addr, source))
208 }
209}
210
/// Binding consisting of a single local socket address.
#[derive(Copy, Clone)]
pub struct LocalBinding {
    addr: SocketAddr,
}

impl LocalBinding {
    /// Wrap the given local address.
    fn new(addr: SocketAddr) -> Self {
        LocalBinding { addr }
    }

    /// The local address.
    pub fn addr(self) -> SocketAddr {
        let Self { addr } = self;

        addr
    }
}
228
/// Binding described by three addresses: a base, the bound address itself and
/// the source the binding was obtained from.
#[derive(Copy, Clone)]
pub struct ReflexiveBinding {
    base: SocketAddr,
    addr: SocketAddr,
    source: SocketAddr,
}

impl ReflexiveBinding {
    /// Create a new binding from its three addresses.
    fn new(base: SocketAddr, addr: SocketAddr, source: SocketAddr) -> Self {
        ReflexiveBinding { base, addr, source }
    }

    /// The base address.
    pub fn base(&self) -> SocketAddr {
        let Self { base, .. } = self;

        *base
    }

    /// The bound address.
    pub fn addr(&self) -> SocketAddr {
        let Self { addr, .. } = self;

        *addr
    }

    /// The address of the source of this binding.
    pub fn source(&self) -> SocketAddr {
        let Self { source, .. } = self;

        *source
    }
}
258
259struct Socket {
261 local_addr: SocketAddr,
262 output_packet_tx: OutputPacketTx,
263 reader: JoinHandle<()>,
264 keep_alive: JoinHandle<()>,
265}
266
267impl Socket {
268 async fn new<S, B>(
281 logger: Logger,
282 local_addr: SocketAddr,
283 stun_servers: &[SocketAddr],
284 input_packet_tx: S,
285 mut binding_tx: B,
286 ) -> io::Result<Self>
287 where
288 S: Sink<InputPacket> + Send + Unpin + 'static,
289 B: Sink<Binding> + Send + Unpin + 'static,
290 {
291 let socket = UdpSocketWrapper::bind(local_addr).await?;
292
293 let local_addr = socket.local_addr();
294
295 let _ = binding_tx.send(Binding::local(local_addr)).await;
296
297 let (output_packet_tx, output_packet_rx) = mpsc::unbounded();
298
299 tokio::spawn(socket.write_all(logger.clone(), output_packet_rx));
300
301 let mut stun_context = StunContext::new(output_packet_tx.clone());
302
303 let ctx = stun_context.clone();
304
305 let reader = tokio::spawn(async move {
306 let _ = socket.read_all(logger, input_packet_tx, ctx).await;
307 });
308
309 let stun_servers = stun_servers
310 .iter()
311 .copied()
312 .filter(|addr| local_addr.is_ipv4() == addr.is_ipv4())
313 .collect::<Vec<_>>();
314
315 let keep_alive = tokio::spawn(async move {
316 let reflexive_addr = stun_context.get_reflexive_addr(stun_servers);
317
318 if let Some((reflexive_addr, stun_server)) = reflexive_addr.await {
319 let binding = Binding::reflexive(local_addr, reflexive_addr, stun_server);
320
321 let _ = binding_tx.send(binding).await;
322
323 std::mem::drop(binding_tx);
325
326 stun_context
328 .keep_alive(stun_server, Duration::from_secs(10))
329 .await;
330 }
331 });
332
333 let res = Self {
334 local_addr,
335 output_packet_tx,
336 reader,
337 keep_alive,
338 };
339
340 Ok(res)
341 }
342
343 fn is_bound_to(&self, local_addr: SocketAddr) -> bool {
345 self.local_addr == local_addr
346 || (local_addr.port() == 0 && self.local_addr.ip() == local_addr.ip())
347 }
348
349 fn send(&self, remote_addr: SocketAddr, data: Bytes) -> io::Result<()> {
351 self.output_packet_tx
352 .unbounded_send((remote_addr, data))
353 .map_err(|_| io::Error::from(io::ErrorKind::BrokenPipe))
354 }
355}
356
357impl Drop for Socket {
358 fn drop(&mut self) {
359 self.keep_alive.abort();
360 self.reader.abort();
361 }
362}
363
364struct UdpSocketWrapper {
366 inner: Arc<UdpSocket>,
367 local_addr: SocketAddr,
368}
369
370impl UdpSocketWrapper {
371 async fn bind(local_addr: SocketAddr) -> io::Result<Self> {
373 let socket = UdpSocket::bind(local_addr).await?;
374
375 let local_addr = socket.local_addr()?;
376
377 let res = Self {
378 inner: Arc::new(socket),
379 local_addr,
380 };
381
382 Ok(res)
383 }
384
385 fn local_addr(&self) -> SocketAddr {
387 self.local_addr
388 }
389
390 fn write_all<S>(&self, logger: Logger, mut stream: S) -> impl Future<Output = ()>
392 where
393 S: Stream<Item = OutputPacket> + Unpin,
394 {
395 let socket = self.inner.clone();
396
397 async move {
398 while let Some((peer, data)) = stream.next().await {
399 if let Err(err) = socket.send_to(&data, peer).await {
400 warn!(logger, "socket write error"; "cause" => %err);
402
403 break;
405 }
406 }
407 }
408 }
409
410 async fn read_all<S>(
413 self,
414 logger: Logger,
415 mut sink: S,
416 mut stun_context: StunContext,
417 ) -> Result<(), S::Error>
418 where
419 S: Sink<Packet> + Unpin,
420 {
421 let stream = UdpSocketStream::from(self);
422
423 let mut filtered = stream.filter_map(move |item| {
424 let res = match item {
425 Ok(packet) => {
426 if let Err(packet) = stun_context.process_packet(packet) {
427 Some(Ok(packet))
428 } else {
429 None
430 }
431 }
432 Err(err) => Some(Err(err)),
433 };
434
435 futures::future::ready(res)
436 });
437
438 while let Some(item) = filtered.next().await {
439 match item {
440 Ok(packet) => sink.send(packet).await?,
441 Err(err) => {
442 warn!(logger, "socket read error"; "cause" => %err);
443 }
444 }
445 }
446
447 Ok(())
448 }
449}
450
451struct UdpSocketStream {
453 socket: Option<Arc<UdpSocket>>,
454 local_addr: SocketAddr,
455}
456
457impl Stream for UdpSocketStream {
458 type Item = io::Result<Packet>;
459
460 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
461 if let Some(socket) = self.socket.as_ref() {
462 let mut buffer: [MaybeUninit<u8>; 65_536] =
464 unsafe { MaybeUninit::uninit().assume_init() };
465
466 let mut buffer = ReadBuf::uninit(&mut buffer);
467
468 match ready!(socket.poll_recv_from(cx, &mut buffer)) {
469 Ok(peer) => {
470 let packet = Packet {
471 local_addr: self.local_addr,
472 remote_addr: peer,
473 data: Bytes::copy_from_slice(buffer.filled()),
474 };
475
476 Poll::Ready(Some(Ok(packet)))
477 }
478 Err(err) => {
479 self.socket = None;
481
482 Poll::Ready(Some(Err(err)))
483 }
484 }
485 } else {
486 Poll::Ready(None)
487 }
488 }
489}
490
491impl From<UdpSocketWrapper> for UdpSocketStream {
492 fn from(socket: UdpSocketWrapper) -> Self {
493 Self {
494 socket: Some(socket.inner),
495 local_addr: socket.local_addr,
496 }
497 }
498}
499
/// Timeout (in milliseconds) used for the first attempt of a STUN
/// transaction; it is doubled after every attempt that times out.
const RTO: u64 = 500;
/// Multiplier applied to `RTO` to get the timeout of the last attempt of a
/// STUN transaction.
const RM: u64 = 16;
/// Number of attempts made by a binding request during reflexive address
/// discovery.
const RC: u32 = 7;

/// Identifier of a STUN transaction (12 bytes).
type StunTransactionId = [u8; 12];
507
508#[derive(Clone)]
510struct StunContext {
511 inner: Arc<Mutex<InnerStunContext>>,
512 output_packet_tx: OutputPacketTx,
513}
514
515impl StunContext {
516 fn new(output_packet_tx: OutputPacketTx) -> Self {
518 Self {
519 inner: Arc::new(Mutex::new(InnerStunContext::new())),
520 output_packet_tx,
521 }
522 }
523
524 async fn get_reflexive_addr<I>(&mut self, stun_servers: I) -> Option<(SocketAddr, SocketAddr)>
526 where
527 I: IntoIterator<Item = SocketAddr>,
528 {
529 let stun_servers = stun_servers.into_iter();
530
531 let reflexive_addrs = futures::stream::iter(stun_servers.enumerate())
532 .then(|(index, addr)| async move {
533 if index > 0 {
534 tokio::time::sleep(Duration::from_millis(RTO << 1)).await;
535 }
536
537 addr
538 })
539 .map(|stun_server| {
540 let request = self.new_binding_request(stun_server, RC);
541
542 async move {
543 if let Ok(reflexive_addr) = request.await {
544 Some((reflexive_addr, stun_server))
545 } else {
546 None
547 }
548 }
549 })
550 .buffered((((1 << (RC - 1)) + RM) * RTO / 1_000) as usize)
551 .filter_map(futures::future::ready);
552
553 futures::pin_mut!(reflexive_addrs);
554
555 reflexive_addrs.next().await
556 }
557
558 async fn keep_alive(&mut self, stun_server: SocketAddr, interval: Duration) {
561 loop {
562 tokio::time::sleep(interval).await;
563
564 let _ = self.new_binding_request(stun_server, 1).await;
565 }
566 }
567
568 fn new_binding_request(
570 &mut self,
571 stun_server: SocketAddr,
572 attempts: u32,
573 ) -> impl Future<Output = io::Result<SocketAddr>> {
574 let transaction_id = rand::random();
575
576 let (reflexive_addr_tx, reflexive_addr_rx) = oneshot::channel();
577
578 let transaction = StunTransaction {
579 context: self.clone(),
580 output_packet_tx: self.output_packet_tx.clone(),
581 reflexive_addr_rx,
582 stun_server,
583 transaction_id,
584 next_timeout: Duration::from_millis(RTO),
585 last_timeout: Duration::from_millis(RTO * RM),
586 remaining_attempts: attempts,
587 };
588
589 let handle = StunTransactionHandle {
590 transaction_id,
591 reflexive_addr_tx,
592 };
593
594 self.inner.lock().unwrap().add_handle(handle);
595
596 transaction.resolve()
597 }
598
599 fn remove_handle(&mut self, id: StunTransactionId) {
601 self.inner.lock().unwrap().remove_handle(id);
602 }
603
604 fn process_packet(&mut self, packet: InputPacket) -> Result<(), InputPacket> {
607 self.inner.lock().unwrap().process_packet(packet)
608 }
609}
610
611struct InnerStunContext {
613 transactions: Vec<StunTransactionHandle>,
614}
615
616impl InnerStunContext {
617 fn new() -> Self {
619 Self {
620 transactions: Vec::new(),
621 }
622 }
623
624 fn add_handle(&mut self, handle: StunTransactionHandle) {
626 self.transactions.push(handle);
627 }
628
629 fn remove_handle(
631 &mut self,
632 transaction_id: StunTransactionId,
633 ) -> Option<StunTransactionHandle> {
634 self.transactions
635 .iter()
636 .position(|t| t.transaction_id() == transaction_id)
637 .map(|i| self.transactions.swap_remove(i))
638 }
639
640 fn process_packet(&mut self, packet: InputPacket) -> Result<(), InputPacket> {
642 let data = packet.data();
643
644 if let Ok(msg) = stun::Message::from_frame(data.clone()) {
645 if msg.is_rfc5389_message()
646 && msg.is_response()
647 && msg.method() == stun::Method::Binding
648 {
649 if let Some(handle) = self.remove_handle(msg.transaction_id()) {
650 let attrs = msg.attributes();
651
652 if let Some(addr) = attrs.get_any_mapped_address() {
653 handle.resolve(addr);
654 }
655
656 return Ok(());
657 }
658 }
659 }
660
661 Err(packet)
662 }
663}
664
665struct StunTransaction<S, F> {
667 context: StunContext,
668 output_packet_tx: S,
669 reflexive_addr_rx: F,
670 stun_server: SocketAddr,
671 transaction_id: StunTransactionId,
672 next_timeout: Duration,
673 last_timeout: Duration,
674 remaining_attempts: u32,
675}
676
677impl<S, F, E> StunTransaction<S, F>
678where
679 S: Sink<OutputPacket> + Unpin,
680 F: Future<Output = Result<SocketAddr, E>> + Unpin,
681{
682 async fn resolve(mut self) -> io::Result<SocketAddr> {
684 let builder = stun::MessageBuilder::binding_request(self.transaction_id);
685
686 let msg = builder.build();
687
688 while self.remaining_attempts > 0 {
689 self.output_packet_tx
690 .send((self.stun_server, msg.clone()))
691 .await
692 .map_err(|_| io::Error::from(io::ErrorKind::BrokenPipe))?;
693
694 let timeout = if self.remaining_attempts > 1 {
695 self.next_timeout
696 } else {
697 self.last_timeout
698 };
699
700 let addr = tokio::time::timeout(timeout, &mut self.reflexive_addr_rx);
701
702 if let Ok(res) = addr.await {
703 return res.map_err(|_| io::Error::from(io::ErrorKind::BrokenPipe));
704 }
705
706 self.remaining_attempts -= 1;
707 self.next_timeout *= 2;
708 }
709
710 Err(io::Error::from(io::ErrorKind::TimedOut))
711 }
712}
713
714impl<S, F> Drop for StunTransaction<S, F> {
715 fn drop(&mut self) {
716 self.context.remove_handle(self.transaction_id);
717 }
718}
719
720type ReflexiveAddrTx = oneshot::Sender<SocketAddr>;
722
723struct StunTransactionHandle {
725 transaction_id: StunTransactionId,
726 reflexive_addr_tx: ReflexiveAddrTx,
727}
728
729impl StunTransactionHandle {
730 fn transaction_id(&self) -> StunTransactionId {
732 self.transaction_id
733 }
734
735 fn resolve(self, reflexive_addr: SocketAddr) {
737 let _ = self.reflexive_addr_tx.send(reflexive_addr);
738 }
739}