rtc_interceptor/twcc/receiver.rs

1//! TWCC Receiver Interceptor - tracks incoming packets and generates feedback.
2
3use super::recorder::Recorder;
4use super::stream_supports_twcc;
5use crate::stream_info::StreamInfo;
6use crate::{Interceptor, Packet, TaggedPacket};
7use shared::TransportContext;
8use shared::error::Error;
9use shared::marshal::Unmarshal;
10use std::collections::{HashMap, VecDeque};
11use std::marker::PhantomData;
12use std::time::{Duration, Instant};
13
/// Feedback cadence (100 ms) used when the builder is not given an explicit
/// interval.
const DEFAULT_INTERVAL: Duration = Duration::from_micros(100_000);
16
/// Configures and produces a [`TwccReceiverInterceptor`].
///
/// # Example
///
/// ```ignore
/// use rtc_interceptor::{Registry, TwccReceiverBuilder};
/// use std::time::Duration;
///
/// let chain = Registry::new()
///     .with(TwccReceiverBuilder::new()
///         .with_interval(Duration::from_millis(100))
///         .build())
///     .build();
/// ```
pub struct TwccReceiverBuilder<P> {
    /// How often TransportLayerCC feedback reports are emitted.
    interval: Duration,
    /// Ties the builder to the inner interceptor type `P` without storing one.
    _phantom: PhantomData<P>,
}
36
37impl<P> Default for TwccReceiverBuilder<P> {
38    fn default() -> Self {
39        Self {
40            interval: DEFAULT_INTERVAL,
41            _phantom: PhantomData,
42        }
43    }
44}
45
46impl<P> TwccReceiverBuilder<P> {
47    /// Create a new builder with default settings.
48    pub fn new() -> Self {
49        Self::default()
50    }
51
52    /// Set the interval between feedback reports.
53    pub fn with_interval(mut self, interval: Duration) -> Self {
54        self.interval = interval;
55        self
56    }
57
58    /// Build the interceptor factory function.
59    pub fn build(self) -> impl FnOnce(P) -> TwccReceiverInterceptor<P> {
60        move |inner| TwccReceiverInterceptor::new(inner, self.interval)
61    }
62}
63
/// Tracking state kept for each bound remote stream that negotiated TWCC.
struct RemoteStream {
    /// RTP header-extension ID that carries the transport-wide sequence number.
    hdr_ext_id: u8,
}
69
70/// Interceptor that tracks incoming RTP packets and generates TWCC feedback.
71///
72/// This interceptor examines incoming RTP packets for transport-wide CC sequence
73/// numbers and periodically generates TransportLayerCC feedback packets.
74pub struct TwccReceiverInterceptor<P> {
75    inner: P,
76
77    /// Configuration
78    interval: Duration,
79
80    /// Start time for calculating arrival times.
81    start_time: Option<Instant>,
82
83    /// TWCC recorder for building feedback.
84    recorder: Option<Recorder>,
85
86    /// Remote stream state per SSRC.
87    streams: HashMap<u32, RemoteStream>,
88
89    /// Queue for feedback packets.
90    write_queue: VecDeque<TaggedPacket>,
91
92    /// Next timeout for sending feedback.
93    next_timeout: Option<Instant>,
94}
95
96impl<P> TwccReceiverInterceptor<P> {
97    fn new(inner: P, interval: Duration) -> Self {
98        Self {
99            inner,
100            interval,
101            start_time: None,
102            recorder: None,
103            streams: HashMap::new(),
104            write_queue: VecDeque::new(),
105            next_timeout: None,
106        }
107    }
108
109    fn generate_feedback(&mut self, now: Instant) {
110        let Some(recorder) = self.recorder.as_mut() else {
111            return;
112        };
113
114        let packets = recorder.build_feedback_packet();
115        for pkt in packets {
116            self.write_queue.push_back(TaggedPacket {
117                now,
118                transport: TransportContext::default(),
119                message: Packet::Rtcp(vec![pkt]),
120            });
121        }
122    }
123}
124
125impl<P: Interceptor> sansio::Protocol<TaggedPacket, TaggedPacket, ()>
126    for TwccReceiverInterceptor<P>
127{
128    type Rout = TaggedPacket;
129    type Wout = TaggedPacket;
130    type Eout = ();
131    type Error = Error;
132    type Time = Instant;
133
134    fn handle_read(&mut self, msg: TaggedPacket) -> Result<(), Self::Error> {
135        // Process incoming RTP packets with TWCC extension
136        if let Packet::Rtp(ref rtp_packet) = msg.message
137            && let Some(stream) = self.streams.get(&rtp_packet.header.ssrc)
138        {
139            // Initialize recorder on first packet
140            if self.recorder.is_none() {
141                // Use a random sender SSRC for feedback
142                self.recorder = Some(Recorder::new(rand::random()));
143                self.start_time = Some(msg.now);
144                self.next_timeout = Some(msg.now + self.interval);
145            }
146
147            // Extract transport CC sequence number
148            if let Some(ext_data) = rtp_packet.header.get_extension(stream.hdr_ext_id)
149                && let Ok(tcc) =
150                    rtp::extension::transport_cc_extension::TransportCcExtension::unmarshal(
151                        &mut ext_data.as_ref(),
152                    )
153            {
154                // Calculate arrival time in microseconds since start
155                let arrival_time = self
156                    .start_time
157                    .map(|start| msg.now.duration_since(start).as_micros() as i64)
158                    .unwrap_or(0);
159
160                if let Some(recorder) = self.recorder.as_mut() {
161                    recorder.record(rtp_packet.header.ssrc, tcc.transport_sequence, arrival_time);
162                }
163            }
164        }
165
166        self.inner.handle_read(msg)
167    }
168
169    fn poll_read(&mut self) -> Option<Self::Rout> {
170        self.inner.poll_read()
171    }
172
173    fn handle_write(&mut self, msg: TaggedPacket) -> Result<(), Self::Error> {
174        self.inner.handle_write(msg)
175    }
176
177    fn poll_write(&mut self) -> Option<Self::Wout> {
178        // First drain feedback packets
179        if let Some(pkt) = self.write_queue.pop_front() {
180            return Some(pkt);
181        }
182        self.inner.poll_write()
183    }
184
185    fn handle_timeout(&mut self, now: Self::Time) -> Result<(), Self::Error> {
186        // Check if we need to send feedback
187        if let Some(timeout) = self.next_timeout
188            && now >= timeout
189        {
190            self.generate_feedback(now);
191            self.next_timeout = Some(now + self.interval);
192        }
193        self.inner.handle_timeout(now)
194    }
195
196    fn poll_timeout(&mut self) -> Option<Self::Time> {
197        let inner_timeout = self.inner.poll_timeout();
198
199        match (self.next_timeout, inner_timeout) {
200            (Some(a), Some(b)) => Some(a.min(b)),
201            (Some(a), None) => Some(a),
202            (None, Some(b)) => Some(b),
203            (None, None) => None,
204        }
205    }
206}
207
208impl<P: Interceptor> Interceptor for TwccReceiverInterceptor<P> {
209    fn bind_local_stream(&mut self, info: &StreamInfo) {
210        self.inner.bind_local_stream(info);
211    }
212
213    fn unbind_local_stream(&mut self, info: &StreamInfo) {
214        self.inner.unbind_local_stream(info);
215    }
216
217    fn bind_remote_stream(&mut self, info: &StreamInfo) {
218        if let Some(hdr_ext_id) = stream_supports_twcc(info) {
219            // Don't track if ID is 0 (invalid)
220            if hdr_ext_id != 0 {
221                self.streams.insert(info.ssrc, RemoteStream { hdr_ext_id });
222            }
223        }
224        self.inner.bind_remote_stream(info);
225    }
226
227    fn unbind_remote_stream(&mut self, info: &StreamInfo) {
228        self.streams.remove(&info.ssrc);
229        self.inner.unbind_remote_stream(info);
230    }
231}
232
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Registry;
    use crate::stream_info::RTPHeaderExtension;
    use sansio::Protocol;
    use shared::marshal::Marshal;

    /// Builds an RTP packet for `ssrc`/`seq` with a transport-wide CC header
    /// extension carrying `twcc_seq` under extension ID `hdr_ext_id`.
    ///
    /// Panics if the extension cannot be encoded or attached. A broken fixture
    /// therefore fails loudly instead of silently producing a packet with no
    /// TWCC extension.
    fn make_rtp_packet_with_twcc(
        ssrc: u32,
        seq: u16,
        twcc_seq: u16,
        hdr_ext_id: u8,
    ) -> rtp::Packet {
        let mut pkt = rtp::Packet {
            header: rtp::header::Header {
                ssrc,
                sequence_number: seq,
                ..Default::default()
            },
            payload: vec![].into(),
        };

        let ext_data = rtp::extension::transport_cc_extension::TransportCcExtension {
            transport_sequence: twcc_seq,
        }
        .marshal()
        .expect("TWCC extension should marshal");
        pkt.header
            .set_extension(hdr_ext_id, ext_data.freeze())
            .expect("TWCC extension should attach to the RTP header");

        pkt
    }

    /// Wraps an RTP packet as an inbound `TaggedPacket` observed at `now`.
    fn tagged_rtp(now: Instant, pkt: rtp::Packet) -> TaggedPacket {
        TaggedPacket {
            now,
            transport: Default::default(),
            message: Packet::Rtp(pkt),
        }
    }

    /// Stream info for SSRC 12345 advertising TWCC under extension ID 5.
    fn twcc_stream_info() -> StreamInfo {
        StreamInfo {
            ssrc: 12345,
            rtp_header_extensions: vec![RTPHeaderExtension {
                uri: super::super::TRANSPORT_CC_URI.to_string(),
                id: 5,
            }],
            ..Default::default()
        }
    }

    #[test]
    fn test_twcc_receiver_builder_defaults() {
        let chain = Registry::new()
            .with(TwccReceiverBuilder::default().build())
            .build();

        assert_eq!(chain.interval, DEFAULT_INTERVAL);
        assert!(chain.recorder.is_none());
    }

    #[test]
    fn test_twcc_receiver_builder_custom_interval() {
        let chain = Registry::new()
            .with(
                TwccReceiverBuilder::new()
                    .with_interval(Duration::from_millis(50))
                    .build(),
            )
            .build();

        assert_eq!(chain.interval, Duration::from_millis(50));
    }

    #[test]
    fn test_twcc_receiver_records_packets() {
        let mut chain = Registry::new()
            .with(TwccReceiverBuilder::new().build())
            .build();
        chain.bind_remote_stream(&twcc_stream_info());

        let now = Instant::now();
        chain
            .handle_read(tagged_rtp(now, make_rtp_packet_with_twcc(12345, 1, 0, 5)))
            .unwrap();

        // The first tracked packet creates the recorder, fixes the arrival-time
        // epoch and arms the first feedback timer.
        assert!(chain.recorder.is_some());
        assert_eq!(chain.start_time, Some(now));
        assert_eq!(chain.next_timeout, Some(now + DEFAULT_INTERVAL));
    }

    #[test]
    fn test_twcc_receiver_generates_feedback_on_timeout() {
        let interval = Duration::from_millis(100);
        let mut chain = Registry::new()
            .with(TwccReceiverBuilder::new().with_interval(interval).build())
            .build();
        chain.bind_remote_stream(&twcc_stream_info());

        let start = Instant::now();

        // Receive some packets 10 ms apart.
        for i in 0..5u16 {
            let now = start + Duration::from_millis(i as u64 * 10);
            chain
                .handle_read(tagged_rtp(now, make_rtp_packet_with_twcc(12345, i, i, 5)))
                .unwrap();
        }

        // Trigger timeout
        let timeout_time = start + Duration::from_millis(150);
        chain.handle_timeout(timeout_time).unwrap();

        let Some(tagged) = chain.poll_write() else {
            panic!("expected a feedback packet after the timeout");
        };
        let Packet::Rtcp(rtcp_packets) = tagged.message else {
            panic!("Expected RTCP packet");
        };
        // Each queued feedback entry wraps exactly one RTCP packet.
        assert_eq!(rtcp_packets.len(), 1);

        // The timer is re-armed one interval after the handled timeout.
        assert_eq!(chain.next_timeout, Some(timeout_time + interval));
    }

    #[test]
    fn test_twcc_receiver_no_feedback_without_binding() {
        let mut chain = Registry::new()
            .with(TwccReceiverBuilder::new().build())
            .build();

        let now = Instant::now();

        // Receive packet without binding (no TWCC tracking)
        chain
            .handle_read(tagged_rtp(now, make_rtp_packet_with_twcc(12345, 1, 0, 5)))
            .unwrap();

        // Nothing is initialized for an unbound SSRC.
        assert!(chain.recorder.is_none());
        assert!(chain.next_timeout.is_none());
    }

    #[test]
    fn test_twcc_receiver_unbind_removes_stream() {
        let mut chain = Registry::new()
            .with(TwccReceiverBuilder::new().build())
            .build();

        let info = twcc_stream_info();

        chain.bind_remote_stream(&info);
        assert!(chain.streams.contains_key(&12345));

        chain.unbind_remote_stream(&info);
        assert!(!chain.streams.contains_key(&12345));
    }

    #[test]
    fn test_twcc_receiver_poll_timeout() {
        let mut chain = Registry::new()
            .with(TwccReceiverBuilder::new().build())
            .build();

        // No timeout initially
        assert!(chain.poll_timeout().is_none());

        chain.bind_remote_stream(&twcc_stream_info());

        let now = Instant::now();

        // Receive a packet to initialize recorder
        chain
            .handle_read(tagged_rtp(now, make_rtp_packet_with_twcc(12345, 1, 0, 5)))
            .unwrap();

        // Should have timeout now, one interval after the first packet.
        assert!(chain.poll_timeout().is_some());
        assert_eq!(chain.next_timeout, Some(now + DEFAULT_INTERVAL));
    }
}