// webrtc_sctp/association/mod.rs

1#[cfg(test)]
2mod association_test;
3
4mod association_internal;
5mod association_stats;
6
7use std::collections::{HashMap, VecDeque};
8use std::fmt;
9use std::sync::atomic::Ordering;
10use std::sync::Arc;
11use std::time::SystemTime;
12
13use association_internal::*;
14use association_stats::*;
15use bytes::{Bytes, BytesMut};
16use portable_atomic::{AtomicBool, AtomicU32, AtomicU8, AtomicUsize};
17use rand::random;
18use tokio::sync::{broadcast, mpsc, Mutex};
19use util::Conn;
20
21use crate::chunk::chunk_abort::ChunkAbort;
22use crate::chunk::chunk_cookie_ack::ChunkCookieAck;
23use crate::chunk::chunk_cookie_echo::ChunkCookieEcho;
24use crate::chunk::chunk_error::ChunkError;
25use crate::chunk::chunk_forward_tsn::{ChunkForwardTsn, ChunkForwardTsnStream};
26use crate::chunk::chunk_heartbeat::ChunkHeartbeat;
27use crate::chunk::chunk_heartbeat_ack::ChunkHeartbeatAck;
28use crate::chunk::chunk_init::ChunkInit;
29use crate::chunk::chunk_payload_data::{ChunkPayloadData, PayloadProtocolIdentifier};
30use crate::chunk::chunk_reconfig::ChunkReconfig;
31use crate::chunk::chunk_selective_ack::ChunkSelectiveAck;
32use crate::chunk::chunk_shutdown::ChunkShutdown;
33use crate::chunk::chunk_shutdown_ack::ChunkShutdownAck;
34use crate::chunk::chunk_shutdown_complete::ChunkShutdownComplete;
35use crate::chunk::chunk_type::*;
36use crate::chunk::Chunk;
37use crate::error::{Error, Result};
38use crate::error_cause::*;
39use crate::packet::Packet;
40use crate::param::param_heartbeat_info::ParamHeartbeatInfo;
41use crate::param::param_outgoing_reset_request::ParamOutgoingResetRequest;
42use crate::param::param_reconfig_response::{ParamReconfigResponse, ReconfigResult};
43use crate::param::param_state_cookie::ParamStateCookie;
44use crate::param::param_supported_extensions::ParamSupportedExtensions;
45use crate::param::Param;
46use crate::queue::control_queue::ControlQueue;
47use crate::queue::payload_queue::PayloadQueue;
48use crate::queue::pending_queue::PendingQueue;
49use crate::stream::*;
50use crate::timer::ack_timer::*;
51use crate::timer::rtx_timer::*;
52use crate::util::*;
53
/// MTU for inbound packets (from DTLS); the read loop sizes its receive buffer with this value.
pub(crate) const RECEIVE_MTU: usize = 8192;
/// Initial MTU for outgoing packets (to DTLS).
pub(crate) const INITIAL_MTU: u32 = 1228;
/// Initial receive buffer size (1024 * 1024, presumably bytes — confirm at the use site).
pub(crate) const INITIAL_RECV_BUF_SIZE: u32 = 1024 * 1024;
/// Size of the SCTP common header (RFC 4960 §3.1): source port, destination port,
/// verification tag and checksum.
pub(crate) const COMMON_HEADER_SIZE: u32 = 12;
/// Size of the DATA chunk header (RFC 4960 §3.3.1) that precedes the user data.
pub(crate) const DATA_CHUNK_HEADER_SIZE: u32 = 16;
/// Default maximum message size (64 KiB); presumably applied when none is configured.
pub(crate) const DEFAULT_MAX_MESSAGE_SIZE: u32 = 65536;

/// Capacity of the channel that hands newly accepted streams to `Association::accept_stream`.
pub(crate) const ACCEPT_CH_SIZE: usize = 16;
65
/// States an SCTP association can be in (RFC 4960 §4).
///
/// The explicit discriminants are the values stored in the association's shared `AtomicU8`
/// (`state as u8`); the `From<u8>` impl below decodes them and must stay in sync.
#[derive(Debug, Copy, Clone, PartialEq)]
pub(crate) enum AssociationState {
    Closed = 0,
    CookieWait = 1,
    CookieEchoed = 2,
    Established = 3,
    ShutdownAckSent = 4,
    ShutdownPending = 5,
    ShutdownReceived = 6,
    ShutdownSent = 7,
}

impl From<u8> for AssociationState {
    /// Decodes a stored discriminant; any value outside `0..=7` is treated as `Closed`.
    fn from(v: u8) -> AssociationState {
        // Indexed by discriminant, so entry `i` is the variant declared with `= i`.
        const BY_DISCRIMINANT: [AssociationState; 8] = [
            AssociationState::Closed,
            AssociationState::CookieWait,
            AssociationState::CookieEchoed,
            AssociationState::Established,
            AssociationState::ShutdownAckSent,
            AssociationState::ShutdownPending,
            AssociationState::ShutdownReceived,
            AssociationState::ShutdownSent,
        ];
        BY_DISCRIMINANT
            .get(usize::from(v))
            .copied()
            .unwrap_or(AssociationState::Closed)
    }
}

impl fmt::Display for AssociationState {
    /// Writes the variant name verbatim (no padding/alignment is applied).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(match self {
            Self::Closed => "Closed",
            Self::CookieWait => "CookieWait",
            Self::CookieEchoed => "CookieEchoed",
            Self::Established => "Established",
            Self::ShutdownAckSent => "ShutdownAckSent",
            Self::ShutdownPending => "ShutdownPending",
            Self::ShutdownReceived => "ShutdownReceived",
            Self::ShutdownSent => "ShutdownSent",
        })
    }
}
109
/// Identifies one of the association's retransmission timers: the RFC 4960 T1-init,
/// T1-cookie, T2-shutdown and T3-rtx timers, plus the stream-reconfiguration timer.
#[derive(Default, Debug, Copy, Clone, PartialEq)]
pub(crate) enum RtxTimerId {
    #[default]
    T1Init,
    T1Cookie,
    T2Shutdown,
    T3RTX,
    Reconfig,
}

impl fmt::Display for RtxTimerId {
    /// Writes the variant name verbatim.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::T1Init => "T1Init",
            Self::T1Cookie => "T1Cookie",
            Self::T2Shutdown => "T2Shutdown",
            Self::T3RTX => "T3RTX",
            Self::Reconfig => "Reconfig",
        };
        f.write_str(label)
    }
}
133
/// Acknowledgement policy (for testing); `Normal` is the default.
#[derive(Default, Debug, Copy, Clone, PartialEq)]
pub(crate) enum AckMode {
    #[default]
    Normal,
    NoDelay,
    AlwaysDelay,
}

impl fmt::Display for AckMode {
    /// Writes the variant name verbatim.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::Normal => "Normal",
            Self::NoDelay => "NoDelay",
            Self::AlwaysDelay => "AlwaysDelay",
        };
        f.write_str(label)
    }
}
153
/// Acknowledgement transmission state; `Idle` is the default.
#[derive(Default, Debug, Copy, Clone, PartialEq)]
pub(crate) enum AckState {
    /// The ack timer is off.
    #[default]
    Idle,
    /// An ack will be sent immediately.
    Immediate,
    /// The ack timer is on, i.e. the ack is being delayed.
    Delay,
}

impl fmt::Display for AckState {
    /// Writes the variant name verbatim.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(match self {
            Self::Idle => "Idle",
            Self::Immediate => "Immediate",
            Self::Delay => "Delay",
        })
    }
}
173
/// Config collects the arguments needed to construct an [`Association`] (via
/// `Association::server` or `Association::client`) into a single structure.
pub struct Config {
    /// Transport the association reads inbound packets from and writes outbound packets to.
    pub net_conn: Arc<dyn Conn + Send + Sync>,
    /// Maximum receive buffer size; presumably advertised to the peer as the receiver
    /// window — confirm in `AssociationInternal::new`.
    pub max_receive_buffer_size: u32,
    /// Initial value of the limit reported by `Association::max_message_size`
    /// (presumably; confirm in `AssociationInternal::new`).
    pub max_message_size: u32,
    /// Name used as the `[name]` prefix of this association's log messages
    /// (presumably copied into the association's `name`; confirm).
    pub name: String,
    /// SCTP port of the remote endpoint (NOTE(review): confirm how it is applied to packets).
    pub remote_port: u16,
    /// SCTP port of the local endpoint (NOTE(review): confirm how it is applied to packets).
    pub local_port: u16,
}
184
/// Association represents an SCTP association.
///
/// RFC 4960 §13.2, "Parameters Necessary per Association (i.e., the TCB)":
///
/// ```text
/// Peer        : Tag value to be sent in every packet and is received
/// Verification: in the INIT or INIT ACK chunk.
/// Tag         :
///
/// My          : Tag expected in every inbound packet and sent in the
/// Verification: INIT or INIT ACK chunk.
/// Tag         :
///
/// State       : A state variable indicating what state the association
///             : is in, i.e., COOKIE-WAIT, COOKIE-ECHOED, ESTABLISHED,
///             : SHUTDOWN-PENDING, SHUTDOWN-SENT, SHUTDOWN-RECEIVED,
///             : SHUTDOWN-ACK-SENT.
///
///               Note: No "CLOSED" state is illustrated since if a
///               association is "CLOSED" its TCB SHOULD be removed.
/// ```
///
/// The TCB itself (verification tags, queues, timers, ...) lives in `association_internal`.
/// This handle additionally keeps clones of a few shared atomics so that cheap queries do not
/// need to take the internal lock.
pub struct Association {
    // Prefix for this association's log messages.
    name: String,
    // Current `AssociationState`, stored as its `u8` discriminant; shared with
    // `AssociationInternal`.
    state: Arc<AtomicU8>,
    // TODO: Convert into `u32`, as there is no reason why the `max_message_size` should need to be
    // changed after the Association has been created. Note that even if there was a use case for
    // this, it is not used anywhere in the code base.
    //
    // Using atomics where not necessary -- especially in a hot path such as `prepare_write` -- may
    // negatively impact performance, and adds unneeded complexity to the code.
    //
    // NOTE(review): `set_max_message_size` below is public API, so this conversion would be a
    // breaking change for external callers.
    max_message_size: Arc<AtomicU32>,
    // Length of the in-flight queue, shared with `AssociationInternal`; `shutdown` reads it to
    // decide whether SHUTDOWN can be requested right away.
    inflight_queue_length: Arc<AtomicUsize>,
    // Set by `shutdown` to request a SHUTDOWN chunk (presumably consumed by the write loop's
    // `gather_outbound`).
    will_send_shutdown: Arc<AtomicBool>,
    // Wakes the write loop so it gathers and sends outbound packets.
    awake_write_loop_ch: Arc<mpsc::Sender<()>>,
    // Close notification that `shutdown` waits on.
    close_loop_ch_rx: Mutex<broadcast::Receiver<()>>,
    // Streams opened by the peer, handed out by `accept_stream`.
    accept_ch_rx: Mutex<mpsc::Receiver<Arc<Stream>>>,
    // Underlying transport (shared with the read and write loops).
    net_conn: Arc<dyn Conn + Send + Sync>,
    // Running byte counters updated by the read and write loops.
    bytes_received: Arc<AtomicUsize>,
    bytes_sent: Arc<AtomicUsize>,

    pub(crate) association_internal: Arc<Mutex<AssociationInternal>>,
}
223
224impl Association {
225    /// server accepts a SCTP stream over a conn
226    pub async fn server(config: Config) -> Result<Self> {
227        let (a, mut handshake_completed_ch_rx) = Association::new(config, false).await?;
228
229        if let Some(err_opt) = handshake_completed_ch_rx.recv().await {
230            if let Some(err) = err_opt {
231                Err(err)
232            } else {
233                Ok(a)
234            }
235        } else {
236            Err(Error::ErrAssociationHandshakeClosed)
237        }
238    }
239
240    /// Client opens a SCTP stream over a conn
241    pub async fn client(config: Config) -> Result<Self> {
242        let (a, mut handshake_completed_ch_rx) = Association::new(config, true).await?;
243
244        if let Some(err_opt) = handshake_completed_ch_rx.recv().await {
245            if let Some(err) = err_opt {
246                Err(err)
247            } else {
248                Ok(a)
249            }
250        } else {
251            Err(Error::ErrAssociationHandshakeClosed)
252        }
253    }
254
255    /// Shutdown initiates the shutdown sequence. The method blocks until the
256    /// shutdown sequence is completed and the connection is closed, or until the
257    /// passed context is done, in which case the context's error is returned.
258    pub async fn shutdown(&self) -> Result<()> {
259        log::debug!("[{}] closing association..", self.name);
260
261        let state = self.get_state();
262        if state != AssociationState::Established {
263            return Err(Error::ErrShutdownNonEstablished);
264        }
265
266        // Attempt a graceful shutdown.
267        self.set_state(AssociationState::ShutdownPending);
268
269        if self.inflight_queue_length.load(Ordering::SeqCst) == 0 {
270            // No more outstanding, send shutdown.
271            self.will_send_shutdown.store(true, Ordering::SeqCst);
272            let _ = self.awake_write_loop_ch.try_send(());
273            self.set_state(AssociationState::ShutdownSent);
274        }
275
276        {
277            let mut close_loop_ch_rx = self.close_loop_ch_rx.lock().await;
278            let _ = close_loop_ch_rx.recv().await;
279        }
280
281        Ok(())
282    }
283
284    /// Close ends the SCTP Association and cleans up any state
285    pub async fn close(&self) -> Result<()> {
286        log::debug!("[{}] closing association..", self.name);
287
288        let _ = self.net_conn.close().await;
289
290        let mut ai = self.association_internal.lock().await;
291        ai.close().await
292    }
293
294    async fn new(config: Config, is_client: bool) -> Result<(Self, mpsc::Receiver<Option<Error>>)> {
295        let net_conn = Arc::clone(&config.net_conn);
296
297        let (awake_write_loop_ch_tx, awake_write_loop_ch_rx) = mpsc::channel(1);
298        let (accept_ch_tx, accept_ch_rx) = mpsc::channel(ACCEPT_CH_SIZE);
299        let (handshake_completed_ch_tx, handshake_completed_ch_rx) = mpsc::channel(1);
300        let (close_loop_ch_tx, close_loop_ch_rx) = broadcast::channel(1);
301        let (close_loop_ch_rx1, close_loop_ch_rx2) =
302            (close_loop_ch_tx.subscribe(), close_loop_ch_tx.subscribe());
303        let awake_write_loop_ch = Arc::new(awake_write_loop_ch_tx);
304
305        let ai = AssociationInternal::new(
306            config,
307            close_loop_ch_tx,
308            accept_ch_tx,
309            handshake_completed_ch_tx,
310            Arc::clone(&awake_write_loop_ch),
311        );
312
313        let bytes_received = Arc::new(AtomicUsize::new(0));
314        let bytes_sent = Arc::new(AtomicUsize::new(0));
315        let name = ai.name.clone();
316        let state = Arc::clone(&ai.state);
317        let max_message_size = Arc::clone(&ai.max_message_size);
318        let inflight_queue_length = Arc::clone(&ai.inflight_queue_length);
319        let will_send_shutdown = Arc::clone(&ai.will_send_shutdown);
320
321        let mut init = ChunkInit {
322            initial_tsn: ai.my_next_tsn,
323            num_outbound_streams: ai.my_max_num_outbound_streams,
324            num_inbound_streams: ai.my_max_num_inbound_streams,
325            initiate_tag: ai.my_verification_tag,
326            advertised_receiver_window_credit: ai.max_receive_buffer_size,
327            ..Default::default()
328        };
329        init.set_supported_extensions();
330
331        let association_internal = Arc::new(Mutex::new(ai));
332        {
333            let weak = Arc::downgrade(&association_internal);
334            let mut ai = association_internal.lock().await;
335            ai.t1init = Some(RtxTimer::new(
336                weak.clone(),
337                RtxTimerId::T1Init,
338                MAX_INIT_RETRANS,
339            ));
340            ai.t1cookie = Some(RtxTimer::new(
341                weak.clone(),
342                RtxTimerId::T1Cookie,
343                MAX_INIT_RETRANS,
344            ));
345            ai.t2shutdown = Some(RtxTimer::new(
346                weak.clone(),
347                RtxTimerId::T2Shutdown,
348                NO_MAX_RETRANS,
349            )); // retransmit forever
350            ai.t3rtx = Some(RtxTimer::new(
351                weak.clone(),
352                RtxTimerId::T3RTX,
353                NO_MAX_RETRANS,
354            )); // retransmit forever
355            ai.treconfig = Some(RtxTimer::new(
356                weak.clone(),
357                RtxTimerId::Reconfig,
358                NO_MAX_RETRANS,
359            )); // retransmit forever
360            ai.ack_timer = Some(AckTimer::new(weak, ACK_INTERVAL));
361
362            tokio::spawn(Association::read_loop(
363                name.clone(),
364                Arc::clone(&bytes_received),
365                Arc::clone(&net_conn),
366                close_loop_ch_rx1,
367                Arc::clone(&association_internal),
368            ));
369
370            tokio::spawn(Association::write_loop(
371                name.clone(),
372                Arc::clone(&bytes_sent),
373                Arc::clone(&net_conn),
374                close_loop_ch_rx2,
375                Arc::clone(&association_internal),
376                awake_write_loop_ch_rx,
377            ));
378
379            if is_client {
380                ai.set_state(AssociationState::CookieWait);
381                ai.stored_init = Some(init);
382                ai.send_init()?;
383                let rto = ai.rto_mgr.get_rto();
384                if let Some(t1init) = &ai.t1init {
385                    t1init.start(rto).await;
386                }
387            }
388        }
389
390        Ok((
391            Association {
392                name,
393                state,
394                max_message_size,
395                inflight_queue_length,
396                will_send_shutdown,
397                awake_write_loop_ch,
398                close_loop_ch_rx: Mutex::new(close_loop_ch_rx),
399                accept_ch_rx: Mutex::new(accept_ch_rx),
400                net_conn,
401                bytes_received,
402                bytes_sent,
403                association_internal,
404            },
405            handshake_completed_ch_rx,
406        ))
407    }
408
409    async fn read_loop(
410        name: String,
411        bytes_received: Arc<AtomicUsize>,
412        net_conn: Arc<dyn Conn + Send + Sync>,
413        mut close_loop_ch: broadcast::Receiver<()>,
414        association_internal: Arc<Mutex<AssociationInternal>>,
415    ) {
416        log::debug!("[{name}] read_loop entered");
417
418        let mut buffer = vec![0u8; RECEIVE_MTU];
419        let mut done = false;
420        let mut n;
421        while !done {
422            tokio::select! {
423                _ = close_loop_ch.recv() => break,
424                result = net_conn.recv(&mut buffer) => {
425                    match result {
426                        Ok(m) => {
427                            n=m;
428                        }
429                        Err(err) => {
430                            log::warn!("[{name}] failed to read packets on net_conn: {err}");
431                            break;
432                        }
433                    }
434                }
435            };
436
437            // Make a buffer sized to what we read, then copy the data we
438            // read from the underlying transport. We do this because the
439            // user data is passed to the reassembly queue without
440            // copying.
441            log::debug!("[{name}] recving {n} bytes");
442            let inbound = Bytes::from(buffer[..n].to_vec());
443            bytes_received.fetch_add(n, Ordering::SeqCst);
444
445            {
446                let mut ai = association_internal.lock().await;
447                if let Err(err) = ai.handle_inbound(&inbound).await {
448                    log::warn!("[{name}] failed to handle_inbound: {err:?}");
449                    done = true;
450                }
451            }
452        }
453
454        {
455            let mut ai = association_internal.lock().await;
456            if let Err(err) = ai.close().await {
457                log::warn!("[{name}] failed to close association: {err:?}");
458            }
459        }
460
461        log::debug!("[{name}] read_loop exited");
462    }
463
464    async fn write_loop(
465        name: String,
466        bytes_sent: Arc<AtomicUsize>,
467        net_conn: Arc<dyn Conn + Send + Sync>,
468        mut close_loop_ch: broadcast::Receiver<()>,
469        association_internal: Arc<Mutex<AssociationInternal>>,
470        mut awake_write_loop_ch: mpsc::Receiver<()>,
471    ) {
472        log::debug!("[{name}] write_loop entered");
473        let done = Arc::new(AtomicBool::new(false));
474        let name = Arc::new(name);
475
476        'outer: while !done.load(Ordering::Relaxed) {
477            //log::debug!("[{}] gather_outbound begin", name);
478            let (packets, continue_loop) = {
479                let mut ai = association_internal.lock().await;
480                ai.gather_outbound().await
481            };
482            //log::debug!("[{}] gather_outbound done with {}", name, packets.len());
483
484            let net_conn = Arc::clone(&net_conn);
485            let bytes_sent = Arc::clone(&bytes_sent);
486            let name2 = Arc::clone(&name);
487            let done2 = Arc::clone(&done);
488            let mut buffer = None;
489            for raw in packets {
490                let mut buf = buffer
491                    .take()
492                    .unwrap_or_else(|| BytesMut::with_capacity(16 * 1024));
493
494                // We do the marshalling work in a blocking task here for a reason:
495                // If we don't tokio tends to run the write_loop and read_loop of one connection on the same OS thread
496                // This means that even though we release the lock above, the read_loop isn't able to take it, simply because it is not being scheduled by tokio
497                // Doing it this way, tokio schedules this work on a dedicated blocking thread, this future is suspended, and the read_loop can make progress
498                match tokio::task::spawn_blocking(move || raw.marshal_to(&mut buf).map(|_| buf))
499                    .await
500                {
501                    Ok(Ok(mut buf)) => {
502                        let raw = buf.as_ref();
503                        if let Err(err) = net_conn.send(raw.as_ref()).await {
504                            log::warn!("[{name2}] failed to write packets on net_conn: {err}");
505                            done2.store(true, Ordering::Relaxed)
506                        } else {
507                            bytes_sent.fetch_add(raw.len(), Ordering::SeqCst);
508                        }
509
510                        // Reuse allocation. Have to use options, since spawn blocking can't borrow, has to take ownership.
511                        buf.clear();
512                        buffer = Some(buf);
513                    }
514                    Ok(Err(err)) => {
515                        log::warn!("[{name2}] failed to serialize a packet: {err:?}");
516                    }
517                    Err(err) => {
518                        if err.is_cancelled() {
519                            log::debug!(
520                                "[{name}] task cancelled while serializing a packet: {err:?}"
521                            );
522                            break 'outer;
523                        } else {
524                            log::error!("[{name}] panic while serializing a packet: {err:?}");
525                        }
526                    }
527                }
528                //log::debug!("[{}] sending {} bytes done", name, raw.len());
529            }
530
531            if !continue_loop {
532                break;
533            }
534
535            //log::debug!("[{}] wait awake_write_loop_ch", name);
536            tokio::select! {
537                _ = awake_write_loop_ch.recv() =>{}
538                _ = close_loop_ch.recv() => {
539                    done.store(true, Ordering::Relaxed);
540                }
541            };
542            //log::debug!("[{}] wait awake_write_loop_ch done", name);
543        }
544
545        {
546            let mut ai = association_internal.lock().await;
547            if let Err(err) = ai.close().await {
548                log::warn!("[{name}] failed to close association: {err:?}");
549            }
550        }
551
552        log::debug!("[{name}] write_loop exited");
553    }
554
555    /// bytes_sent returns the number of bytes sent
556    pub fn bytes_sent(&self) -> usize {
557        self.bytes_sent.load(Ordering::SeqCst)
558    }
559
560    /// bytes_received returns the number of bytes received
561    pub fn bytes_received(&self) -> usize {
562        self.bytes_received.load(Ordering::SeqCst)
563    }
564
565    /// open_stream opens a stream
566    pub async fn open_stream(
567        &self,
568        stream_identifier: u16,
569        default_payload_type: PayloadProtocolIdentifier,
570    ) -> Result<Arc<Stream>> {
571        let mut ai = self.association_internal.lock().await;
572        ai.open_stream(stream_identifier, default_payload_type)
573    }
574
575    /// accept_stream accepts a stream
576    pub async fn accept_stream(&self) -> Option<Arc<Stream>> {
577        let mut accept_ch_rx = self.accept_ch_rx.lock().await;
578        accept_ch_rx.recv().await
579    }
580
581    /// max_message_size returns the maximum message size you can send.
582    pub fn max_message_size(&self) -> u32 {
583        self.max_message_size.load(Ordering::SeqCst)
584    }
585
586    /// set_max_message_size sets the maximum message size you can send.
587    pub fn set_max_message_size(&self, max_message_size: u32) {
588        self.max_message_size
589            .store(max_message_size, Ordering::SeqCst);
590    }
591
592    /// set_state atomically sets the state of the Association.
593    fn set_state(&self, new_state: AssociationState) {
594        let old_state = AssociationState::from(self.state.swap(new_state as u8, Ordering::SeqCst));
595        if new_state != old_state {
596            log::debug!(
597                "[{}] state change: '{}' => '{}'",
598                self.name,
599                old_state,
600                new_state,
601            );
602        }
603    }
604
605    /// get_state atomically returns the state of the Association.
606    fn get_state(&self) -> AssociationState {
607        self.state.load(Ordering::SeqCst).into()
608    }
609}