rtc_interceptor/twcc/receiver.rs

1//! TWCC Receiver Interceptor - tracks incoming packets and generates feedback.
2
3use super::recorder::Recorder;
4use super::stream_supports_twcc;
5use crate::stream_info::StreamInfo;
6use crate::{Interceptor, Packet, TaggedPacket, interceptor};
7use shared::TransportContext;
8use shared::error::Error;
9use shared::marshal::Unmarshal;
10use std::collections::{HashMap, VecDeque};
11use std::marker::PhantomData;
12use std::time::{Duration, Instant};
13
/// Feedback cadence a builder starts with when `with_interval` is never called (100 ms).
const DEFAULT_INTERVAL: Duration = Duration::from_millis(100);
16
/// Configures and constructs a [`TwccReceiverInterceptor`].
///
/// The only tunable is how often TWCC feedback is emitted; the recorder and
/// feedback timer themselves are created lazily when the first tracked RTP
/// packet arrives.
///
/// # Example
///
/// ```ignore
/// use rtc_interceptor::{Registry, TwccReceiverBuilder};
/// use std::time::Duration;
///
/// let chain = Registry::new()
///     .with(TwccReceiverBuilder::new()
///         .with_interval(Duration::from_millis(100))
///         .build())
///     .build();
/// ```
pub struct TwccReceiverBuilder<P> {
    /// Delay between two consecutive feedback reports.
    interval: Duration,
    /// Binds the builder to the inner interceptor type `P` without storing one.
    _phantom: PhantomData<P>,
}
36
37impl<P> Default for TwccReceiverBuilder<P> {
38    fn default() -> Self {
39        Self {
40            interval: DEFAULT_INTERVAL,
41            _phantom: PhantomData,
42        }
43    }
44}
45
46impl<P> TwccReceiverBuilder<P> {
47    /// Create a new builder with default settings.
48    pub fn new() -> Self {
49        Self::default()
50    }
51
52    /// Set the interval between feedback reports.
53    pub fn with_interval(mut self, interval: Duration) -> Self {
54        self.interval = interval;
55        self
56    }
57
58    /// Build the interceptor factory function.
59    pub fn build(self) -> impl FnOnce(P) -> TwccReceiverInterceptor<P> {
60        move |inner| TwccReceiverInterceptor::new(inner, self.interval)
61    }
62}
63
/// Receive-side bookkeeping for one remote SSRC that negotiated TWCC.
struct RemoteStream {
    /// RTP header-extension ID that carries the transport-wide sequence number.
    /// Streams whose negotiated ID is 0 are never inserted, so this is non-zero.
    hdr_ext_id: u8,
}
69
70/// Interceptor that tracks incoming RTP packets and generates TWCC feedback.
71///
72/// This interceptor examines incoming RTP packets for transport-wide CC sequence
73/// numbers and periodically generates TransportLayerCC feedback packets.
74#[derive(Interceptor)]
75pub struct TwccReceiverInterceptor<P> {
76    #[next]
77    inner: P,
78
79    /// Configuration
80    interval: Duration,
81
82    /// Start time for calculating arrival times.
83    start_time: Option<Instant>,
84
85    /// TWCC recorder for building feedback.
86    recorder: Option<Recorder>,
87
88    /// Remote stream state per SSRC.
89    streams: HashMap<u32, RemoteStream>,
90
91    /// Queue for feedback packets.
92    write_queue: VecDeque<TaggedPacket>,
93
94    /// Next timeout for sending feedback.
95    next_timeout: Option<Instant>,
96}
97
98impl<P> TwccReceiverInterceptor<P> {
99    fn new(inner: P, interval: Duration) -> Self {
100        Self {
101            inner,
102            interval,
103            start_time: None,
104            recorder: None,
105            streams: HashMap::new(),
106            write_queue: VecDeque::new(),
107            next_timeout: None,
108        }
109    }
110
111    fn generate_feedback(&mut self, now: Instant) {
112        let Some(recorder) = self.recorder.as_mut() else {
113            return;
114        };
115
116        let packets = recorder.build_feedback_packet();
117        for pkt in packets {
118            self.write_queue.push_back(TaggedPacket {
119                now,
120                transport: TransportContext::default(),
121                message: Packet::Rtcp(vec![pkt]),
122            });
123        }
124    }
125}
126
127#[interceptor]
128impl<P: Interceptor> TwccReceiverInterceptor<P> {
129    #[overrides]
130    fn handle_read(&mut self, msg: TaggedPacket) -> Result<(), Self::Error> {
131        // Process incoming RTP packets with TWCC extension
132        if let Packet::Rtp(ref rtp_packet) = msg.message
133            && let Some(stream) = self.streams.get(&rtp_packet.header.ssrc)
134        {
135            // Initialize recorder on first packet
136            if self.recorder.is_none() {
137                // Use a random sender SSRC for feedback
138                self.recorder = Some(Recorder::new(rand::random()));
139                self.start_time = Some(msg.now);
140                self.next_timeout = Some(msg.now + self.interval);
141            }
142
143            // Extract transport CC sequence number
144            if let Some(ext_data) = rtp_packet.header.get_extension(stream.hdr_ext_id)
145                && let Ok(tcc) =
146                    rtp::extension::transport_cc_extension::TransportCcExtension::unmarshal(
147                        &mut ext_data.as_ref(),
148                    )
149            {
150                // Calculate arrival time in microseconds since start
151                let arrival_time = self
152                    .start_time
153                    .map(|start| msg.now.duration_since(start).as_micros() as i64)
154                    .unwrap_or(0);
155
156                if let Some(recorder) = self.recorder.as_mut() {
157                    recorder.record(rtp_packet.header.ssrc, tcc.transport_sequence, arrival_time);
158                }
159            }
160        }
161
162        self.inner.handle_read(msg)
163    }
164
165    #[overrides]
166    fn poll_write(&mut self) -> Option<Self::Wout> {
167        // First drain feedback packets
168        if let Some(pkt) = self.write_queue.pop_front() {
169            return Some(pkt);
170        }
171        self.inner.poll_write()
172    }
173
174    #[overrides]
175    fn handle_timeout(&mut self, now: Self::Time) -> Result<(), Self::Error> {
176        // Check if we need to send feedback
177        if let Some(timeout) = self.next_timeout
178            && now >= timeout
179        {
180            self.generate_feedback(now);
181            self.next_timeout = Some(now + self.interval);
182        }
183        self.inner.handle_timeout(now)
184    }
185
186    #[overrides]
187    fn poll_timeout(&mut self) -> Option<Self::Time> {
188        let inner_timeout = self.inner.poll_timeout();
189
190        match (self.next_timeout, inner_timeout) {
191            (Some(a), Some(b)) => Some(a.min(b)),
192            (Some(a), None) => Some(a),
193            (None, Some(b)) => Some(b),
194            (None, None) => None,
195        }
196    }
197
198    #[overrides]
199    fn bind_remote_stream(&mut self, info: &StreamInfo) {
200        if let Some(hdr_ext_id) = stream_supports_twcc(info) {
201            // Don't track if ID is 0 (invalid)
202            if hdr_ext_id != 0 {
203                self.streams.insert(info.ssrc, RemoteStream { hdr_ext_id });
204            }
205        }
206        self.inner.bind_remote_stream(info);
207    }
208
209    #[overrides]
210    fn unbind_remote_stream(&mut self, info: &StreamInfo) {
211        self.streams.remove(&info.ssrc);
212        self.inner.unbind_remote_stream(info);
213    }
214}
215
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Registry;
    use crate::stream_info::RTPHeaderExtension;
    use sansio::Protocol;
    use shared::marshal::Marshal;

    /// SSRC shared by every test stream.
    const SSRC: u32 = 12345;

    /// Builds an RTP packet carrying a transport-wide CC header extension.
    fn make_rtp_packet_with_twcc(
        ssrc: u32,
        seq: u16,
        twcc_seq: u16,
        hdr_ext_id: u8,
    ) -> rtp::Packet {
        let tcc_ext = rtp::extension::transport_cc_extension::TransportCcExtension {
            transport_sequence: twcc_seq,
        };
        let mut pkt = rtp::Packet {
            header: rtp::header::Header {
                ssrc,
                sequence_number: seq,
                ..Default::default()
            },
            payload: vec![].into(),
        };
        if let Ok(ext_data) = tcc_ext.marshal() {
            let _ = pkt.header.set_extension(hdr_ext_id, ext_data.freeze());
        }
        pkt
    }

    /// Stream info advertising transport-wide CC on header-extension ID 5.
    fn twcc_stream_info(ssrc: u32) -> StreamInfo {
        StreamInfo {
            ssrc,
            rtp_header_extensions: vec![RTPHeaderExtension {
                uri: super::super::TRANSPORT_CC_URI.to_string(),
                id: 5,
            }],
            ..Default::default()
        }
    }

    /// Wraps an RTP packet as an inbound message received at `now`.
    fn inbound(now: Instant, rtp: rtp::Packet) -> TaggedPacket {
        TaggedPacket {
            now,
            transport: Default::default(),
            message: Packet::Rtp(rtp),
        }
    }

    #[test]
    fn test_twcc_receiver_builder_defaults() {
        let chain = Registry::new()
            .with(TwccReceiverBuilder::default().build())
            .build();

        assert_eq!(chain.interval, DEFAULT_INTERVAL);
        assert!(chain.recorder.is_none());
    }

    #[test]
    fn test_twcc_receiver_builder_custom_interval() {
        let builder = TwccReceiverBuilder::new().with_interval(Duration::from_millis(50));
        let chain = Registry::new().with(builder.build()).build();

        assert_eq!(chain.interval, Duration::from_millis(50));
    }

    #[test]
    fn test_twcc_receiver_records_packets() {
        let mut chain = Registry::new()
            .with(TwccReceiverBuilder::new().build())
            .build();
        chain.bind_remote_stream(&twcc_stream_info(SSRC));

        let rtp = make_rtp_packet_with_twcc(SSRC, 1, 0, 5);
        chain.handle_read(inbound(Instant::now(), rtp)).unwrap();

        // The first tracked packet starts the recorder and the feedback timer.
        assert!(chain.recorder.is_some());
        assert!(chain.next_timeout.is_some());
    }

    #[test]
    fn test_twcc_receiver_generates_feedback_on_timeout() {
        let builder = TwccReceiverBuilder::new().with_interval(Duration::from_millis(100));
        let mut chain = Registry::new().with(builder.build()).build();
        chain.bind_remote_stream(&twcc_stream_info(SSRC));

        let start = Instant::now();
        for i in 0..5u16 {
            let arrival = start + Duration::from_millis(u64::from(i) * 10);
            let rtp = make_rtp_packet_with_twcc(SSRC, i, i, 5);
            chain.handle_read(inbound(arrival, rtp)).unwrap();
        }

        // Fire the timer well past the 100 ms feedback deadline.
        chain
            .handle_timeout(start + Duration::from_millis(150))
            .unwrap();

        let tagged = chain
            .poll_write()
            .expect("a feedback packet should be queued after the timeout");
        match tagged.message {
            Packet::Rtcp(rtcp_packets) => assert!(!rtcp_packets.is_empty()),
            _ => panic!("Expected RTCP packet"),
        }
    }

    #[test]
    fn test_twcc_receiver_no_feedback_without_binding() {
        let mut chain = Registry::new()
            .with(TwccReceiverBuilder::new().build())
            .build();

        // Nothing was bound, so the packet must not start TWCC tracking.
        let rtp = make_rtp_packet_with_twcc(SSRC, 1, 0, 5);
        chain.handle_read(inbound(Instant::now(), rtp)).unwrap();

        assert!(chain.recorder.is_none());
    }

    #[test]
    fn test_twcc_receiver_unbind_removes_stream() {
        let mut chain = Registry::new()
            .with(TwccReceiverBuilder::new().build())
            .build();
        let info = twcc_stream_info(SSRC);

        chain.bind_remote_stream(&info);
        assert!(chain.streams.contains_key(&SSRC));

        chain.unbind_remote_stream(&info);
        assert!(!chain.streams.contains_key(&SSRC));
    }

    #[test]
    fn test_twcc_receiver_poll_timeout() {
        let mut chain = Registry::new()
            .with(TwccReceiverBuilder::new().build())
            .build();

        // Nothing is scheduled before any tracked packet arrives.
        assert!(chain.poll_timeout().is_none());

        chain.bind_remote_stream(&twcc_stream_info(SSRC));
        let rtp = make_rtp_packet_with_twcc(SSRC, 1, 0, 5);
        chain.handle_read(inbound(Instant::now(), rtp)).unwrap();

        assert!(chain.poll_timeout().is_some());
    }
}
432}