rtc_interceptor/twcc/
sender.rs

//! TWCC sender interceptor: stamps outgoing RTP packets with transport-wide congestion control (TWCC) sequence numbers.
2
3use super::stream_supports_twcc;
4use crate::stream_info::StreamInfo;
5use crate::{Interceptor, Packet, TaggedPacket};
6use shared::error::Error;
7use shared::marshal::Marshal;
8use std::collections::HashMap;
9use std::marker::PhantomData;
10use std::time::Instant;
11
/// Builder that produces a [`TwccSenderInterceptor`] factory for a `Registry`.
///
/// The builder carries no configuration; it only fixes the type of the inner
/// interceptor `P` that the resulting factory will wrap.
///
/// # Example
///
/// ```ignore
/// use rtc_interceptor::{Registry, TwccSenderBuilder};
///
/// let chain = Registry::new()
///     .with(TwccSenderBuilder::new().build())
///     .build();
/// ```
pub struct TwccSenderBuilder<P> {
    // Zero-sized marker tying the builder to the inner interceptor type.
    _phantom: PhantomData<P>,
}
26
27impl<P> Default for TwccSenderBuilder<P> {
28    fn default() -> Self {
29        Self {
30            _phantom: PhantomData,
31        }
32    }
33}
34
35impl<P> TwccSenderBuilder<P> {
36    /// Create a new builder with default settings.
37    pub fn new() -> Self {
38        Self::default()
39    }
40
41    /// Build the interceptor factory function.
42    pub fn build(self) -> impl FnOnce(P) -> TwccSenderInterceptor<P> {
43        move |inner| TwccSenderInterceptor::new(inner)
44    }
45}
46
/// State kept for each locally bound (outgoing) stream that negotiated TWCC.
struct LocalStream {
    /// Negotiated RTP header extension ID under which the transport-wide
    /// sequence number is written.
    hdr_ext_id: u8,
}
52
53/// Interceptor that adds transport-wide sequence numbers to outgoing RTP packets.
54///
55/// This interceptor examines the stream's RTP header extensions for the transport-wide
56/// CC extension URI and adds sequence numbers to each outgoing packet.
57pub struct TwccSenderInterceptor<P> {
58    inner: P,
59    /// Transport-wide sequence number counter (shared across all streams).
60    next_sequence_number: u16,
61    /// Local stream state per SSRC.
62    streams: HashMap<u32, LocalStream>,
63}
64
65impl<P> TwccSenderInterceptor<P> {
66    fn new(inner: P) -> Self {
67        Self {
68            inner,
69            next_sequence_number: 0,
70            streams: HashMap::new(),
71        }
72    }
73}
74
75impl<P: Interceptor> sansio::Protocol<TaggedPacket, TaggedPacket, ()> for TwccSenderInterceptor<P> {
76    type Rout = TaggedPacket;
77    type Wout = TaggedPacket;
78    type Eout = ();
79    type Error = Error;
80    type Time = Instant;
81
82    fn handle_read(&mut self, msg: TaggedPacket) -> Result<(), Self::Error> {
83        self.inner.handle_read(msg)
84    }
85
86    fn poll_read(&mut self) -> Option<Self::Rout> {
87        self.inner.poll_read()
88    }
89
90    fn handle_write(&mut self, mut msg: TaggedPacket) -> Result<(), Self::Error> {
91        // Add transport-wide CC sequence number to outgoing RTP packets
92        if let Packet::Rtp(ref mut rtp_packet) = msg.message
93            && let Some(stream) = self.streams.get(&rtp_packet.header.ssrc)
94        {
95            // Create transport CC extension
96            let seq = self.next_sequence_number;
97            self.next_sequence_number = self.next_sequence_number.wrapping_add(1);
98
99            let tcc_ext = rtp::extension::transport_cc_extension::TransportCcExtension {
100                transport_sequence: seq,
101            };
102
103            // Marshal the extension
104            if let Ok(ext_data) = tcc_ext.marshal() {
105                // Set the extension on the packet
106                let _ = rtp_packet
107                    .header
108                    .set_extension(stream.hdr_ext_id, ext_data.freeze());
109            }
110        }
111
112        self.inner.handle_write(msg)
113    }
114
115    fn poll_write(&mut self) -> Option<Self::Wout> {
116        self.inner.poll_write()
117    }
118
119    fn handle_timeout(&mut self, now: Self::Time) -> Result<(), Self::Error> {
120        self.inner.handle_timeout(now)
121    }
122
123    fn poll_timeout(&mut self) -> Option<Self::Time> {
124        self.inner.poll_timeout()
125    }
126}
127
128impl<P: Interceptor> Interceptor for TwccSenderInterceptor<P> {
129    fn bind_local_stream(&mut self, info: &StreamInfo) {
130        if let Some(hdr_ext_id) = stream_supports_twcc(info) {
131            // Don't add header extension if ID is 0 (invalid)
132            if hdr_ext_id != 0 {
133                self.streams.insert(info.ssrc, LocalStream { hdr_ext_id });
134            }
135        }
136        self.inner.bind_local_stream(info);
137    }
138
139    fn unbind_local_stream(&mut self, info: &StreamInfo) {
140        self.streams.remove(&info.ssrc);
141        self.inner.unbind_local_stream(info);
142    }
143
144    fn bind_remote_stream(&mut self, info: &StreamInfo) {
145        self.inner.bind_remote_stream(info);
146    }
147
148    fn unbind_remote_stream(&mut self, info: &StreamInfo) {
149        self.inner.unbind_remote_stream(info);
150    }
151}
152
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Registry;
    use crate::stream_info::RTPHeaderExtension;
    use sansio::Protocol;
    use shared::marshal::Unmarshal;

    /// Header extension ID negotiated for TWCC in these tests
    /// (must match the `id` used by `twcc_stream_info`).
    const TWCC_EXT_ID: u8 = 5;

    fn make_rtp_packet(ssrc: u32, seq: u16) -> TaggedPacket {
        TaggedPacket {
            now: Instant::now(),
            transport: Default::default(),
            message: Packet::Rtp(rtp::Packet {
                header: rtp::header::Header {
                    ssrc,
                    sequence_number: seq,
                    ..Default::default()
                },
                payload: vec![].into(),
            }),
        }
    }

    /// Builds a `StreamInfo` for `ssrc` that negotiated TWCC under extension ID 5.
    fn twcc_stream_info(ssrc: u32) -> StreamInfo {
        StreamInfo {
            ssrc,
            rtp_header_extensions: vec![RTPHeaderExtension {
                uri: super::super::TRANSPORT_CC_URI.to_string(),
                id: 5,
            }],
            ..Default::default()
        }
    }

    /// Decodes the transport-wide sequence number carried by `pkt`.
    ///
    /// Returns `None` when the packet has no TWCC extension; panics when the
    /// packet is not RTP or the extension payload does not decode, so callers
    /// can never pass vacuously.
    fn transport_sequence(pkt: TaggedPacket) -> Option<u16> {
        let Packet::Rtp(rtp) = pkt.message else {
            panic!("Expected RTP packet");
        };
        let ext = rtp.header.get_extension(TWCC_EXT_ID)?;
        let tcc = rtp::extension::transport_cc_extension::TransportCcExtension::unmarshal(
            &mut ext.as_ref(),
        )
        .expect("TWCC extension should decode");
        Some(tcc.transport_sequence)
    }

    #[test]
    fn test_twcc_sender_builder_defaults() {
        let chain = Registry::new()
            .with(TwccSenderBuilder::default().build())
            .build();

        assert!(chain.streams.is_empty());
    }

    #[test]
    fn test_twcc_sender_adds_extension() {
        let mut chain = Registry::new()
            .with(TwccSenderBuilder::new().build())
            .build();
        chain.bind_local_stream(&twcc_stream_info(12345));

        chain.handle_write(make_rtp_packet(12345, 1)).unwrap();
        let out1 = chain.poll_write().unwrap();
        chain.handle_write(make_rtp_packet(12345, 2)).unwrap();
        let out2 = chain.poll_write().unwrap();

        // Sequence numbers start at 0 and increment per stamped packet.
        assert_eq!(transport_sequence(out1), Some(0));
        assert_eq!(transport_sequence(out2), Some(1));
    }

    #[test]
    fn test_twcc_sender_no_extension_without_binding() {
        let mut chain = Registry::new()
            .with(TwccSenderBuilder::new().build())
            .build();

        chain.handle_write(make_rtp_packet(12345, 1)).unwrap();
        let out = chain.poll_write().unwrap();

        assert_eq!(transport_sequence(out), None);
        assert_eq!(chain.next_sequence_number, 0);
    }

    #[test]
    fn test_twcc_sender_ignores_zero_extension_id() {
        let mut chain = Registry::new()
            .with(TwccSenderBuilder::new().build())
            .build();

        let info = StreamInfo {
            ssrc: 12345,
            rtp_header_extensions: vec![RTPHeaderExtension {
                uri: super::super::TRANSPORT_CC_URI.to_string(),
                id: 0,
            }],
            ..Default::default()
        };
        chain.bind_local_stream(&info);
        assert!(!chain.streams.contains_key(&12345));

        chain.handle_write(make_rtp_packet(12345, 1)).unwrap();
        let out = chain.poll_write().unwrap();
        assert_eq!(transport_sequence(out), None);
        assert_eq!(chain.next_sequence_number, 0);
    }

    #[test]
    fn test_twcc_sender_unbind_removes_stream() {
        let mut chain = Registry::new()
            .with(TwccSenderBuilder::new().build())
            .build();
        let info = twcc_stream_info(12345);

        chain.bind_local_stream(&info);
        assert!(chain.streams.contains_key(&12345));

        chain.unbind_local_stream(&info);
        assert!(!chain.streams.contains_key(&12345));

        // Packets written after unbinding are no longer stamped.
        chain.handle_write(make_rtp_packet(12345, 1)).unwrap();
        let out = chain.poll_write().unwrap();
        assert_eq!(transport_sequence(out), None);
    }

    #[test]
    fn test_twcc_sender_sequence_wraparound() {
        let mut chain = Registry::new()
            .with(TwccSenderBuilder::new().build())
            .build();
        chain.bind_local_stream(&twcc_stream_info(12345));

        // Start just below the u16 boundary.
        chain.next_sequence_number = 65534;

        for expected_seq in [65534u16, 65535, 0, 1] {
            chain.handle_write(make_rtp_packet(12345, 1)).unwrap();
            let out = chain.poll_write().unwrap();
            assert_eq!(transport_sequence(out), Some(expected_seq));
        }
    }

    #[test]
    fn test_twcc_sender_multiple_streams_share_counter() {
        let mut chain = Registry::new()
            .with(TwccSenderBuilder::new().build())
            .build();
        chain.bind_local_stream(&twcc_stream_info(1111));
        chain.bind_local_stream(&twcc_stream_info(2222));

        // Alternate between streams; the counter is transport-wide, not per SSRC.
        for (expected_seq, ssrc) in (0u16..).zip([1111u32, 2222, 1111, 2222]) {
            chain.handle_write(make_rtp_packet(ssrc, 1)).unwrap();
            let out = chain.poll_write().unwrap();
            assert_eq!(transport_sequence(out), Some(expected_seq));
        }
    }
}