sctp_async/association/
mod.rs

1//TODO: remove this conditional test
2#[cfg(not(target_os = "macos"))]
3#[cfg(test)]
4mod association_test;
5
6mod association_internal;
7mod association_stats;
8
9use crate::chunk::chunk_abort::ChunkAbort;
10use crate::chunk::chunk_cookie_ack::ChunkCookieAck;
11use crate::chunk::chunk_cookie_echo::ChunkCookieEcho;
12use crate::chunk::chunk_error::ChunkError;
13use crate::chunk::chunk_forward_tsn::{ChunkForwardTsn, ChunkForwardTsnStream};
14use crate::chunk::chunk_heartbeat::ChunkHeartbeat;
15use crate::chunk::chunk_heartbeat_ack::ChunkHeartbeatAck;
16use crate::chunk::chunk_init::ChunkInit;
17use crate::chunk::chunk_payload_data::{ChunkPayloadData, PayloadProtocolIdentifier};
18use crate::chunk::chunk_reconfig::ChunkReconfig;
19use crate::chunk::chunk_selective_ack::ChunkSelectiveAck;
20use crate::chunk::chunk_shutdown::ChunkShutdown;
21use crate::chunk::chunk_shutdown_ack::ChunkShutdownAck;
22use crate::chunk::chunk_shutdown_complete::ChunkShutdownComplete;
23use crate::chunk::chunk_type::*;
24use crate::chunk::Chunk;
25use crate::error::{Error, Result};
26use crate::error_cause::*;
27use crate::packet::Packet;
28use crate::param::param_heartbeat_info::ParamHeartbeatInfo;
29use crate::param::param_outgoing_reset_request::ParamOutgoingResetRequest;
30use crate::param::param_reconfig_response::{ParamReconfigResponse, ReconfigResult};
31use crate::param::param_state_cookie::ParamStateCookie;
32use crate::param::param_supported_extensions::ParamSupportedExtensions;
33use crate::param::Param;
34use crate::queue::control_queue::ControlQueue;
35use crate::queue::payload_queue::PayloadQueue;
36use crate::queue::pending_queue::PendingQueue;
37use crate::stream::*;
38use crate::timer::ack_timer::*;
39use crate::timer::rtx_timer::*;
40use crate::util::*;
41
42use association_internal::*;
43use association_stats::*;
44
45use bytes::Bytes;
46use rand::random;
47use std::collections::{HashMap, VecDeque};
48use std::fmt;
49use std::sync::atomic::{AtomicBool, AtomicU32, AtomicU8, AtomicUsize, Ordering};
50use std::sync::Arc;
51use std::time::SystemTime;
52use tokio::sync::{broadcast, mpsc, Mutex};
53use util::Conn;
54
/// MTU for inbound packets (from DTLS); also the size of the read loop's
/// receive buffer.
pub(crate) const RECEIVE_MTU: usize = 8192;
/// Initial MTU for outgoing packets (to DTLS).
pub(crate) const INITIAL_MTU: u32 = 1228;
/// Initial receive buffer size in bytes (1 MiB).
pub(crate) const INITIAL_RECV_BUF_SIZE: u32 = 1024 * 1024;
/// Size in bytes of the SCTP common header.
pub(crate) const COMMON_HEADER_SIZE: u32 = 12;
/// Size in bytes of a DATA chunk header.
pub(crate) const DATA_CHUNK_HEADER_SIZE: u32 = 16;
/// Default upper bound, in bytes, for a single message.
pub(crate) const DEFAULT_MAX_MESSAGE_SIZE: u32 = 65536;

// other constants

/// Capacity of the channel through which accepted streams are handed to
/// `Association::accept_stream`.
pub(crate) const ACCEPT_CH_SIZE: usize = 16;
66
/// States an SCTP association can be in.
///
/// The explicit discriminants are the values stored in the association's
/// atomic `u8` state cell, so they must stay in sync with `From<u8>` below.
#[derive(Debug, Copy, Clone, PartialEq)]
pub(crate) enum AssociationState {
    Closed = 0,
    CookieWait = 1,
    CookieEchoed = 2,
    Established = 3,
    ShutdownAckSent = 4,
    ShutdownPending = 5,
    ShutdownReceived = 6,
    ShutdownSent = 7,
}

impl From<u8> for AssociationState {
    /// Decodes a raw state value; anything that is not a known discriminant
    /// (including 0) maps to `Closed`.
    fn from(v: u8) -> Self {
        match v {
            1 => Self::CookieWait,
            2 => Self::CookieEchoed,
            3 => Self::Established,
            4 => Self::ShutdownAckSent,
            5 => Self::ShutdownPending,
            6 => Self::ShutdownReceived,
            7 => Self::ShutdownSent,
            _ => Self::Closed,
        }
    }
}

impl fmt::Display for AssociationState {
    /// Writes the variant name, e.g. `Established`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(match self {
            Self::Closed => "Closed",
            Self::CookieWait => "CookieWait",
            Self::CookieEchoed => "CookieEchoed",
            Self::Established => "Established",
            Self::ShutdownPending => "ShutdownPending",
            Self::ShutdownSent => "ShutdownSent",
            Self::ShutdownReceived => "ShutdownReceived",
            Self::ShutdownAckSent => "ShutdownAckSent",
        })
    }
}
110
/// Identifiers of the association's retransmission timers.
#[derive(Debug, Copy, Clone, PartialEq)]
pub(crate) enum RtxTimerId {
    T1Init,
    T1Cookie,
    T2Shutdown,
    T3RTX,
    Reconfig,
}

impl Default for RtxTimerId {
    /// The default timer id is `T1Init`.
    fn default() -> Self {
        Self::T1Init
    }
}

impl fmt::Display for RtxTimerId {
    /// Writes the variant name, e.g. `T3RTX`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(match self {
            Self::T1Init => "T1Init",
            Self::T1Cookie => "T1Cookie",
            Self::T2Shutdown => "T2Shutdown",
            Self::T3RTX => "T3RTX",
            Self::Reconfig => "Reconfig",
        })
    }
}
139
/// Acknowledgement mode (exists for testing).
#[derive(Debug, Copy, Clone, PartialEq)]
pub(crate) enum AckMode {
    Normal,
    NoDelay,
    AlwaysDelay,
}

impl Default for AckMode {
    /// The default mode is `Normal`.
    fn default() -> Self {
        Self::Normal
    }
}

impl fmt::Display for AckMode {
    /// Writes the variant name, e.g. `NoDelay`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(match self {
            Self::Normal => "Normal",
            Self::NoDelay => "NoDelay",
            Self::AlwaysDelay => "AlwaysDelay",
        })
    }
}
163
/// State of acknowledgement transmission.
#[derive(Debug, Copy, Clone, PartialEq)]
pub(crate) enum AckState {
    /// The ack timer is off.
    Idle,
    /// An ack will be sent immediately.
    Immediate,
    /// The ack timer is on (the ack is being delayed).
    Delay,
}

impl Default for AckState {
    /// The default state is `Idle`.
    fn default() -> Self {
        Self::Idle
    }
}

impl fmt::Display for AckState {
    /// Writes the variant name, e.g. `Immediate`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(match self {
            Self::Idle => "Idle",
            Self::Immediate => "Immediate",
            Self::Delay => "Delay",
        })
    }
}
188
189/// Config collects the arguments to create_association construction into
190/// a single structure
191pub struct Config {
192    pub net_conn: Arc<dyn Conn + Send + Sync>,
193    pub max_receive_buffer_size: u32,
194    pub max_message_size: u32,
195    pub name: String,
196}
197
198///Association represents an SCTP association
199///13.2.  Parameters Necessary per Association (i.e., the TCB)
200///Peer : Tag value to be sent in every packet and is received
201///Verification: in the INIT or INIT ACK chunk.
202///Tag :
203///
204///My : Tag expected in every inbound packet and sent in the
205///Verification: INIT or INIT ACK chunk.
206///
207///Tag :
208///State : A state variable indicating what state the association
209/// : is in, i.e., COOKIE-WAIT, COOKIE-ECHOED, ESTABLISHED,
210/// : SHUTDOWN-PENDING, SHUTDOWN-SENT, SHUTDOWN-RECEIVED,
211/// : SHUTDOWN-ACK-SENT.
212///
213/// No Closed state is illustrated since if a
214/// association is Closed its TCB SHOULD be removed.
215pub struct Association {
216    name: String,
217    state: Arc<AtomicU8>,
218    max_message_size: Arc<AtomicU32>,
219    inflight_queue_length: Arc<AtomicUsize>,
220    will_send_shutdown: Arc<AtomicBool>,
221    awake_write_loop_ch: Arc<mpsc::Sender<()>>,
222    close_loop_ch_rx: Mutex<broadcast::Receiver<()>>,
223    accept_ch_rx: Mutex<mpsc::Receiver<Arc<Stream>>>,
224    net_conn: Arc<dyn Conn + Send + Sync>,
225    bytes_received: Arc<AtomicUsize>,
226    bytes_sent: Arc<AtomicUsize>,
227
228    pub(crate) association_internal: Arc<Mutex<AssociationInternal>>,
229}
230
231impl Association {
232    /// server accepts a SCTP stream over a conn
233    pub async fn server(config: Config) -> Result<Self> {
234        let (a, mut handshake_completed_ch_rx) = Association::new(config, false).await?;
235
236        if let Some(err_opt) = handshake_completed_ch_rx.recv().await {
237            if let Some(err) = err_opt {
238                Err(err)
239            } else {
240                Ok(a)
241            }
242        } else {
243            Err(Error::ErrAssociationHandshakeClosed)
244        }
245    }
246
247    /// Client opens a SCTP stream over a conn
248    pub async fn client(config: Config) -> Result<Self> {
249        let (a, mut handshake_completed_ch_rx) = Association::new(config, true).await?;
250
251        if let Some(err_opt) = handshake_completed_ch_rx.recv().await {
252            if let Some(err) = err_opt {
253                Err(err)
254            } else {
255                Ok(a)
256            }
257        } else {
258            Err(Error::ErrAssociationHandshakeClosed)
259        }
260    }
261
262    /// Shutdown initiates the shutdown sequence. The method blocks until the
263    /// shutdown sequence is completed and the connection is closed, or until the
264    /// passed context is done, in which case the context's error is returned.
265    pub async fn shutdown(&self) -> Result<()> {
266        log::debug!("[{}] closing association..", self.name);
267
268        let state = self.get_state();
269        if state != AssociationState::Established {
270            return Err(Error::ErrShutdownNonEstablished);
271        }
272
273        // Attempt a graceful shutdown.
274        self.set_state(AssociationState::ShutdownPending);
275
276        if self.inflight_queue_length.load(Ordering::SeqCst) == 0 {
277            // No more outstanding, send shutdown.
278            self.will_send_shutdown.store(true, Ordering::SeqCst);
279            let _ = self.awake_write_loop_ch.try_send(());
280            self.set_state(AssociationState::ShutdownSent);
281        }
282
283        {
284            let mut close_loop_ch_rx = self.close_loop_ch_rx.lock().await;
285            let _ = close_loop_ch_rx.recv().await;
286        }
287
288        Ok(())
289    }
290
291    /// Close ends the SCTP Association and cleans up any state
292    pub async fn close(&self) -> Result<()> {
293        log::debug!("[{}] closing association..", self.name);
294
295        let _ = self.net_conn.close().await;
296
297        let mut ai = self.association_internal.lock().await;
298        ai.close().await
299    }
300
301    async fn new(config: Config, is_client: bool) -> Result<(Self, mpsc::Receiver<Option<Error>>)> {
302        let net_conn = Arc::clone(&config.net_conn);
303
304        let (awake_write_loop_ch_tx, awake_write_loop_ch_rx) = mpsc::channel(1);
305        let (accept_ch_tx, accept_ch_rx) = mpsc::channel(ACCEPT_CH_SIZE);
306        let (handshake_completed_ch_tx, handshake_completed_ch_rx) = mpsc::channel(1);
307        let (close_loop_ch_tx, close_loop_ch_rx) = broadcast::channel(1);
308        let (close_loop_ch_rx1, close_loop_ch_rx2) =
309            (close_loop_ch_tx.subscribe(), close_loop_ch_tx.subscribe());
310        let awake_write_loop_ch = Arc::new(awake_write_loop_ch_tx);
311
312        let ai = AssociationInternal::new(
313            config,
314            close_loop_ch_tx,
315            accept_ch_tx,
316            handshake_completed_ch_tx,
317            Arc::clone(&awake_write_loop_ch),
318        );
319
320        let bytes_received = Arc::new(AtomicUsize::new(0));
321        let bytes_sent = Arc::new(AtomicUsize::new(0));
322        let name = ai.name.clone();
323        let state = Arc::clone(&ai.state);
324        let max_message_size = Arc::clone(&ai.max_message_size);
325        let inflight_queue_length = Arc::clone(&ai.inflight_queue_length);
326        let will_send_shutdown = Arc::clone(&ai.will_send_shutdown);
327
328        let mut init = ChunkInit {
329            initial_tsn: ai.my_next_tsn,
330            num_outbound_streams: ai.my_max_num_outbound_streams,
331            num_inbound_streams: ai.my_max_num_inbound_streams,
332            initiate_tag: ai.my_verification_tag,
333            advertised_receiver_window_credit: ai.max_receive_buffer_size,
334            ..Default::default()
335        };
336        init.set_supported_extensions();
337
338        let name1 = name.clone();
339        let name2 = name.clone();
340
341        let bytes_received1 = Arc::clone(&bytes_received);
342        let bytes_sent2 = Arc::clone(&bytes_sent);
343
344        let net_conn1 = Arc::clone(&net_conn);
345        let net_conn2 = Arc::clone(&net_conn);
346
347        let association_internal = Arc::new(Mutex::new(ai));
348        let association_internal1 = Arc::clone(&association_internal);
349        let association_internal2 = Arc::clone(&association_internal);
350
351        {
352            let association_internal3 = Arc::clone(&association_internal);
353
354            let mut ai = association_internal.lock().await;
355            ai.t1init = Some(RtxTimer::new(
356                Arc::downgrade(&association_internal3),
357                RtxTimerId::T1Init,
358                MAX_INIT_RETRANS,
359            ));
360            ai.t1cookie = Some(RtxTimer::new(
361                Arc::downgrade(&association_internal3),
362                RtxTimerId::T1Cookie,
363                MAX_INIT_RETRANS,
364            ));
365            ai.t2shutdown = Some(RtxTimer::new(
366                Arc::downgrade(&association_internal3),
367                RtxTimerId::T2Shutdown,
368                NO_MAX_RETRANS,
369            )); // retransmit forever
370            ai.t3rtx = Some(RtxTimer::new(
371                Arc::downgrade(&association_internal3),
372                RtxTimerId::T3RTX,
373                NO_MAX_RETRANS,
374            )); // retransmit forever
375            ai.treconfig = Some(RtxTimer::new(
376                Arc::downgrade(&association_internal3),
377                RtxTimerId::Reconfig,
378                NO_MAX_RETRANS,
379            )); // retransmit forever
380            ai.ack_timer = Some(AckTimer::new(
381                Arc::downgrade(&association_internal3),
382                ACK_INTERVAL,
383            ));
384        }
385
386        tokio::spawn(async move {
387            Association::read_loop(
388                name1,
389                bytes_received1,
390                net_conn1,
391                close_loop_ch_rx1,
392                association_internal1,
393            )
394            .await;
395        });
396
397        tokio::spawn(async move {
398            Association::write_loop(
399                name2,
400                bytes_sent2,
401                net_conn2,
402                close_loop_ch_rx2,
403                association_internal2,
404                awake_write_loop_ch_rx,
405            )
406            .await;
407        });
408
409        if is_client {
410            let mut ai = association_internal.lock().await;
411            ai.set_state(AssociationState::CookieWait);
412            ai.stored_init = Some(init);
413            ai.send_init()?;
414            let rto = ai.rto_mgr.get_rto();
415            if let Some(t1init) = &ai.t1init {
416                t1init.start(rto).await;
417            }
418        }
419
420        Ok((
421            Association {
422                name,
423                state,
424                max_message_size,
425                inflight_queue_length,
426                will_send_shutdown,
427                awake_write_loop_ch,
428                close_loop_ch_rx: Mutex::new(close_loop_ch_rx),
429                accept_ch_rx: Mutex::new(accept_ch_rx),
430                net_conn,
431                bytes_received,
432                bytes_sent,
433                association_internal,
434            },
435            handshake_completed_ch_rx,
436        ))
437    }
438
439    async fn read_loop(
440        name: String,
441        bytes_received: Arc<AtomicUsize>,
442        net_conn: Arc<dyn Conn + Send + Sync>,
443        mut close_loop_ch: broadcast::Receiver<()>,
444        association_internal: Arc<Mutex<AssociationInternal>>,
445    ) {
446        log::debug!("[{}] read_loop entered", name);
447
448        let mut buffer = vec![0u8; RECEIVE_MTU];
449        let mut done = false;
450        let mut n;
451        while !done {
452            tokio::select! {
453                _ = close_loop_ch.recv() => break,
454                result = net_conn.recv(&mut buffer) => {
455                    match result {
456                        Ok(m) => {
457                            n=m;
458                        }
459                        Err(err) => {
460                            log::warn!("[{}] failed to read packets on net_conn: {}", name, err);
461                            break;
462                        }
463                    }
464                }
465            };
466
467            // Make a buffer sized to what we read, then copy the data we
468            // read from the underlying transport. We do this because the
469            // user data is passed to the reassembly queue without
470            // copying.
471            log::debug!("[{}] recving {} bytes", name, n);
472            let inbound = Bytes::from(buffer[..n].to_vec());
473            bytes_received.fetch_add(n, Ordering::SeqCst);
474
475            {
476                let mut ai = association_internal.lock().await;
477                if let Err(err) = ai.handle_inbound(&inbound).await {
478                    log::warn!("[{}] failed to handle_inbound: {:?}", name, err);
479                    done = true;
480                }
481            }
482        }
483
484        {
485            let mut ai = association_internal.lock().await;
486            if let Err(err) = ai.close().await {
487                log::warn!("[{}] failed to close association: {:?}", name, err);
488            }
489        }
490
491        log::debug!("[{}] read_loop exited", name);
492    }
493
494    async fn write_loop(
495        name: String,
496        bytes_sent: Arc<AtomicUsize>,
497        net_conn: Arc<dyn Conn + Send + Sync>,
498        mut close_loop_ch: broadcast::Receiver<()>,
499        association_internal: Arc<Mutex<AssociationInternal>>,
500        mut awake_write_loop_ch: mpsc::Receiver<()>,
501    ) {
502        log::debug!("[{}] write_loop entered", name);
503        let mut done = false;
504        while !done {
505            //log::debug!("[{}] gather_outbound begin", name);
506            let (raw_packets, mut ok) = {
507                let mut ai = association_internal.lock().await;
508                ai.gather_outbound().await
509            };
510            //log::debug!("[{}] gather_outbound done with {}", name, raw_packets.len());
511
512            for raw in &raw_packets {
513                log::debug!("[{}] sending {} bytes", name, raw.len());
514                if let Err(err) = net_conn.send(raw).await {
515                    log::warn!("[{}] failed to write packets on net_conn: {}", name, err);
516                    ok = false;
517                    break;
518                } else {
519                    bytes_sent.fetch_add(raw.len(), Ordering::SeqCst);
520                }
521                //log::debug!("[{}] sending {} bytes done", name, raw.len());
522            }
523
524            if !ok {
525                break;
526            }
527
528            //log::debug!("[{}] wait awake_write_loop_ch", name);
529            tokio::select! {
530                _ = awake_write_loop_ch.recv() =>{}
531                _ = close_loop_ch.recv() => {
532                    done = true;
533                }
534            };
535            //log::debug!("[{}] wait awake_write_loop_ch done", name);
536        }
537
538        {
539            let mut ai = association_internal.lock().await;
540            if let Err(err) = ai.close().await {
541                log::warn!("[{}] failed to close association: {:?}", name, err);
542            }
543        }
544
545        log::debug!("[{}] write_loop exited", name);
546    }
547
548    /// bytes_sent returns the number of bytes sent
549    pub fn bytes_sent(&self) -> usize {
550        self.bytes_sent.load(Ordering::SeqCst)
551    }
552
553    /// bytes_received returns the number of bytes received
554    pub fn bytes_received(&self) -> usize {
555        self.bytes_received.load(Ordering::SeqCst)
556    }
557
558    /// open_stream opens a stream
559    pub async fn open_stream(
560        &self,
561        stream_identifier: u16,
562        default_payload_type: PayloadProtocolIdentifier,
563    ) -> Result<Arc<Stream>> {
564        let mut ai = self.association_internal.lock().await;
565        ai.open_stream(stream_identifier, default_payload_type)
566    }
567
568    /// accept_stream accepts a stream
569    pub async fn accept_stream(&self) -> Option<Arc<Stream>> {
570        let mut accept_ch_rx = self.accept_ch_rx.lock().await;
571        accept_ch_rx.recv().await
572    }
573
574    /// max_message_size returns the maximum message size you can send.
575    pub fn max_message_size(&self) -> u32 {
576        self.max_message_size.load(Ordering::SeqCst)
577    }
578
579    /// set_max_message_size sets the maximum message size you can send.
580    pub fn set_max_message_size(&self, max_message_size: u32) {
581        self.max_message_size
582            .store(max_message_size, Ordering::SeqCst);
583    }
584
585    /// set_state atomically sets the state of the Association.
586    fn set_state(&self, new_state: AssociationState) {
587        let old_state = AssociationState::from(self.state.swap(new_state as u8, Ordering::SeqCst));
588        if new_state != old_state {
589            log::debug!(
590                "[{}] state change: '{}' => '{}'",
591                self.name,
592                old_state,
593                new_state,
594            );
595        }
596    }
597
598    /// get_state atomically returns the state of the Association.
599    fn get_state(&self) -> AssociationState {
600        self.state.load(Ordering::SeqCst).into()
601    }
602}