rtc_interceptor/twcc/sender.rs

1//! TWCC Sender Interceptor - adds transport-wide sequence numbers to outgoing packets.
2
3use super::stream_supports_twcc;
4use crate::stream_info::StreamInfo;
5use crate::{Interceptor, Packet, TaggedPacket, interceptor};
6use shared::error::Error;
7use shared::marshal::Marshal;
8use std::collections::HashMap;
9use std::marker::PhantomData;
10
/// Builder that produces a factory for [`TwccSenderInterceptor`].
///
/// The builder holds no configuration of its own; it only pins down the inner
/// interceptor type `P` so that [`TwccSenderBuilder::build`] can return a
/// closure suitable for `Registry::with`.
///
/// # Example
///
/// ```ignore
/// use rtc_interceptor::{Registry, TwccSenderBuilder};
///
/// let chain = Registry::new()
///     .with(TwccSenderBuilder::new().build())
///     .build();
/// ```
pub struct TwccSenderBuilder<P> {
    /// Ties the builder to the inner interceptor type without storing a value of it.
    _phantom: PhantomData<P>,
}
25
26impl<P> Default for TwccSenderBuilder<P> {
27    fn default() -> Self {
28        Self {
29            _phantom: PhantomData,
30        }
31    }
32}
33
34impl<P> TwccSenderBuilder<P> {
35    /// Create a new builder with default settings.
36    pub fn new() -> Self {
37        Self::default()
38    }
39
40    /// Build the interceptor factory function.
41    pub fn build(self) -> impl FnOnce(P) -> TwccSenderInterceptor<P> {
42        move |inner| TwccSenderInterceptor::new(inner)
43    }
44}
45
/// Sender-side state kept for each local stream that negotiated TWCC.
struct LocalStream {
    /// RTP header-extension ID negotiated for the transport-wide CC extension.
    /// Never 0: streams advertising ID 0 are not registered.
    hdr_ext_id: u8,
}
51
52/// Interceptor that adds transport-wide sequence numbers to outgoing RTP packets.
53///
54/// This interceptor examines the stream's RTP header extensions for the transport-wide
55/// CC extension URI and adds sequence numbers to each outgoing packet.
56#[derive(Interceptor)]
57pub struct TwccSenderInterceptor<P> {
58    #[next]
59    inner: P,
60    /// Transport-wide sequence number counter (shared across all streams).
61    next_sequence_number: u16,
62    /// Local stream state per SSRC.
63    streams: HashMap<u32, LocalStream>,
64}
65
66impl<P> TwccSenderInterceptor<P> {
67    fn new(inner: P) -> Self {
68        Self {
69            inner,
70            next_sequence_number: 0,
71            streams: HashMap::new(),
72        }
73    }
74}
75
76#[interceptor]
77impl<P: Interceptor> TwccSenderInterceptor<P> {
78    #[overrides]
79    fn handle_write(&mut self, mut msg: TaggedPacket) -> Result<(), Self::Error> {
80        // Add transport-wide CC sequence number to outgoing RTP packets
81        if let Packet::Rtp(ref mut rtp_packet) = msg.message
82            && let Some(stream) = self.streams.get(&rtp_packet.header.ssrc)
83        {
84            // Create transport CC extension
85            let seq = self.next_sequence_number;
86            self.next_sequence_number = self.next_sequence_number.wrapping_add(1);
87
88            let tcc_ext = rtp::extension::transport_cc_extension::TransportCcExtension {
89                transport_sequence: seq,
90            };
91
92            // Marshal the extension
93            if let Ok(ext_data) = tcc_ext.marshal() {
94                // Set the extension on the packet
95                let _ = rtp_packet
96                    .header
97                    .set_extension(stream.hdr_ext_id, ext_data.freeze());
98            }
99        }
100
101        self.inner.handle_write(msg)
102    }
103
104    #[overrides]
105    fn bind_local_stream(&mut self, info: &StreamInfo) {
106        if let Some(hdr_ext_id) = stream_supports_twcc(info) {
107            // Don't add header extension if ID is 0 (invalid)
108            if hdr_ext_id != 0 {
109                self.streams.insert(info.ssrc, LocalStream { hdr_ext_id });
110            }
111        }
112        self.inner.bind_local_stream(info);
113    }
114
115    #[overrides]
116    fn unbind_local_stream(&mut self, info: &StreamInfo) {
117        self.streams.remove(&info.ssrc);
118        self.inner.unbind_local_stream(info);
119    }
120}
121
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Registry;
    use crate::stream_info::RTPHeaderExtension;
    use sansio::Protocol;
    use shared::marshal::Unmarshal;
    use std::time::Instant;

    /// Header-extension ID used for transport-wide CC throughout these tests.
    const TWCC_EXT_ID: u8 = 5;

    fn make_rtp_packet(ssrc: u32, seq: u16) -> TaggedPacket {
        TaggedPacket {
            now: Instant::now(),
            transport: Default::default(),
            message: Packet::Rtp(rtp::Packet {
                header: rtp::header::Header {
                    ssrc,
                    sequence_number: seq,
                    ..Default::default()
                },
                payload: vec![].into(),
            }),
        }
    }

    /// Stream info advertising the transport-wide CC extension on `TWCC_EXT_ID`.
    fn twcc_stream_info(ssrc: u32) -> StreamInfo {
        StreamInfo {
            ssrc,
            rtp_header_extensions: vec![RTPHeaderExtension {
                uri: super::super::TRANSPORT_CC_URI.to_string(),
                id: TWCC_EXT_ID.into(),
            }],
            ..Default::default()
        }
    }

    /// Extracts the transport-wide sequence number from a written packet.
    ///
    /// Panics if the packet is not RTP or carries no TWCC extension, so callers
    /// can never pass vacuously.
    fn transport_sequence(out: TaggedPacket) -> u16 {
        let Packet::Rtp(rtp_packet) = out.message else {
            panic!("Expected RTP packet");
        };
        let ext = rtp_packet
            .header
            .get_extension(TWCC_EXT_ID)
            .expect("transport-wide CC extension should be present");
        rtp::extension::transport_cc_extension::TransportCcExtension::unmarshal(&mut ext.as_ref())
            .unwrap()
            .transport_sequence
    }

    #[test]
    fn test_twcc_sender_builder_defaults() {
        let chain = Registry::new()
            .with(TwccSenderBuilder::default().build())
            .build();

        assert!(chain.streams.is_empty());
    }

    #[test]
    fn test_twcc_sender_adds_extension() {
        let mut chain = Registry::new()
            .with(TwccSenderBuilder::new().build())
            .build();
        chain.bind_local_stream(&twcc_stream_info(12345));

        chain.handle_write(make_rtp_packet(12345, 1)).unwrap();
        let out1 = chain.poll_write().unwrap();
        chain.handle_write(make_rtp_packet(12345, 2)).unwrap();
        let out2 = chain.poll_write().unwrap();

        // Extensions are present and carry incrementing sequence numbers.
        assert_eq!(transport_sequence(out1), 0);
        assert_eq!(transport_sequence(out2), 1);
    }

    #[test]
    fn test_twcc_sender_no_extension_without_binding() {
        let mut chain = Registry::new()
            .with(TwccSenderBuilder::new().build())
            .build();

        // Send packet without binding (no TWCC)
        chain.handle_write(make_rtp_packet(12345, 1)).unwrap();
        let out = chain.poll_write().unwrap();

        let Packet::Rtp(rtp_packet) = out.message else {
            panic!("Expected RTP packet");
        };
        assert!(rtp_packet.header.get_extension(TWCC_EXT_ID).is_none());
    }

    #[test]
    fn test_twcc_sender_ignores_zero_extension_id() {
        let mut chain = Registry::new()
            .with(TwccSenderBuilder::new().build())
            .build();

        let info = StreamInfo {
            ssrc: 12345,
            rtp_header_extensions: vec![RTPHeaderExtension {
                uri: super::super::TRANSPORT_CC_URI.to_string(),
                id: 0,
            }],
            ..Default::default()
        };
        chain.bind_local_stream(&info);

        assert!(!chain.streams.contains_key(&12345));
    }

    #[test]
    fn test_twcc_sender_unbind_removes_stream() {
        let mut chain = Registry::new()
            .with(TwccSenderBuilder::new().build())
            .build();
        let info = twcc_stream_info(12345);

        chain.bind_local_stream(&info);
        assert!(chain.streams.contains_key(&12345));

        chain.unbind_local_stream(&info);
        assert!(!chain.streams.contains_key(&12345));
    }

    #[test]
    fn test_twcc_sender_sequence_wraparound() {
        let mut chain = Registry::new()
            .with(TwccSenderBuilder::new().build())
            .build();
        chain.bind_local_stream(&twcc_stream_info(12345));

        // Set sequence number near wraparound
        chain.next_sequence_number = 65534;

        for expected_seq in [65534u16, 65535, 0, 1] {
            chain.handle_write(make_rtp_packet(12345, 1)).unwrap();
            let out = chain.poll_write().unwrap();
            assert_eq!(transport_sequence(out), expected_seq);
        }
    }

    #[test]
    fn test_twcc_sender_multiple_streams_share_counter() {
        let mut chain = Registry::new()
            .with(TwccSenderBuilder::new().build())
            .build();
        chain.bind_local_stream(&twcc_stream_info(1111));
        chain.bind_local_stream(&twcc_stream_info(2222));

        // Alternate between streams; the counter is shared, so it keeps increasing.
        for (expected_seq, ssrc) in (0u16..).zip([1111u32, 2222, 1111, 2222]) {
            chain.handle_write(make_rtp_packet(ssrc, 1)).unwrap();
            let out = chain.poll_write().unwrap();
            assert_eq!(transport_sequence(out), expected_seq);
        }
    }
}
325}