1use super::recorder::Recorder;
4use super::stream_supports_twcc;
5use crate::stream_info::StreamInfo;
6use crate::{Interceptor, Packet, TaggedPacket, interceptor};
7use shared::TransportContext;
8use shared::error::Error;
9use shared::marshal::Unmarshal;
10use std::collections::{HashMap, VecDeque};
11use std::marker::PhantomData;
12use std::time::{Duration, Instant};
13
14const DEFAULT_INTERVAL: Duration = Duration::from_millis(100);
16
17pub struct TwccReceiverBuilder<P> {
32 interval: Duration,
34 _phantom: PhantomData<P>,
35}
36
37impl<P> Default for TwccReceiverBuilder<P> {
38 fn default() -> Self {
39 Self {
40 interval: DEFAULT_INTERVAL,
41 _phantom: PhantomData,
42 }
43 }
44}
45
46impl<P> TwccReceiverBuilder<P> {
47 pub fn new() -> Self {
49 Self::default()
50 }
51
52 pub fn with_interval(mut self, interval: Duration) -> Self {
54 self.interval = interval;
55 self
56 }
57
58 pub fn build(self) -> impl FnOnce(P) -> TwccReceiverInterceptor<P> {
60 move |inner| TwccReceiverInterceptor::new(inner, self.interval)
61 }
62}
63
/// Per-SSRC state for a bound remote stream that negotiated the transport-wide
/// congestion control header extension.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct RemoteStream {
    /// RTP header-extension id under which the transport-wide sequence number is
    /// carried on this stream's packets.
    hdr_ext_id: u8,
}
69
/// Receiver-side transport-wide congestion control (TWCC) interceptor.
///
/// For every remote stream bound with a TWCC header-extension id, it records the
/// transport-wide sequence number and receive time of each incoming RTP packet.
/// Every `interval` it queues the RTCP feedback built by its `Recorder` for writing.
/// Incoming messages are always forwarded unchanged to the next interceptor.
#[derive(Interceptor)]
pub struct TwccReceiverInterceptor<P> {
    /// Next interceptor in the chain; reads, writes, timeouts and stream
    /// (un)binding are forwarded to it.
    #[next]
    inner: P,

    /// Period between feedback reports.
    interval: Duration,

    /// Receive time of the first packet from a bound stream; arrival times passed
    /// to the recorder are measured relative to it.
    start_time: Option<Instant>,

    /// Feedback recorder, created lazily on the first packet from a bound stream.
    recorder: Option<Recorder>,

    /// Bound remote streams that negotiated TWCC, keyed by SSRC.
    streams: HashMap<u32, RemoteStream>,

    /// Feedback packets waiting to be returned by `poll_write`, ahead of the
    /// inner interceptor's own output.
    write_queue: VecDeque<TaggedPacket>,

    /// Deadline of the next feedback report; `None` until recording starts.
    next_timeout: Option<Instant>,
}
97
98impl<P> TwccReceiverInterceptor<P> {
99 fn new(inner: P, interval: Duration) -> Self {
100 Self {
101 inner,
102 interval,
103 start_time: None,
104 recorder: None,
105 streams: HashMap::new(),
106 write_queue: VecDeque::new(),
107 next_timeout: None,
108 }
109 }
110
111 fn generate_feedback(&mut self, now: Instant) {
112 let Some(recorder) = self.recorder.as_mut() else {
113 return;
114 };
115
116 let packets = recorder.build_feedback_packet();
117 for pkt in packets {
118 self.write_queue.push_back(TaggedPacket {
119 now,
120 transport: TransportContext::default(),
121 message: Packet::Rtcp(vec![pkt]),
122 });
123 }
124 }
125}
126
127#[interceptor]
128impl<P: Interceptor> TwccReceiverInterceptor<P> {
129 #[overrides]
130 fn handle_read(&mut self, msg: TaggedPacket) -> Result<(), Self::Error> {
131 if let Packet::Rtp(ref rtp_packet) = msg.message
133 && let Some(stream) = self.streams.get(&rtp_packet.header.ssrc)
134 {
135 if self.recorder.is_none() {
137 self.recorder = Some(Recorder::new(rand::random()));
139 self.start_time = Some(msg.now);
140 self.next_timeout = Some(msg.now + self.interval);
141 }
142
143 if let Some(ext_data) = rtp_packet.header.get_extension(stream.hdr_ext_id)
145 && let Ok(tcc) =
146 rtp::extension::transport_cc_extension::TransportCcExtension::unmarshal(
147 &mut ext_data.as_ref(),
148 )
149 {
150 let arrival_time = self
152 .start_time
153 .map(|start| msg.now.duration_since(start).as_micros() as i64)
154 .unwrap_or(0);
155
156 if let Some(recorder) = self.recorder.as_mut() {
157 recorder.record(rtp_packet.header.ssrc, tcc.transport_sequence, arrival_time);
158 }
159 }
160 }
161
162 self.inner.handle_read(msg)
163 }
164
165 #[overrides]
166 fn poll_write(&mut self) -> Option<Self::Wout> {
167 if let Some(pkt) = self.write_queue.pop_front() {
169 return Some(pkt);
170 }
171 self.inner.poll_write()
172 }
173
174 #[overrides]
175 fn handle_timeout(&mut self, now: Self::Time) -> Result<(), Self::Error> {
176 if let Some(timeout) = self.next_timeout
178 && now >= timeout
179 {
180 self.generate_feedback(now);
181 self.next_timeout = Some(now + self.interval);
182 }
183 self.inner.handle_timeout(now)
184 }
185
186 #[overrides]
187 fn poll_timeout(&mut self) -> Option<Self::Time> {
188 let inner_timeout = self.inner.poll_timeout();
189
190 match (self.next_timeout, inner_timeout) {
191 (Some(a), Some(b)) => Some(a.min(b)),
192 (Some(a), None) => Some(a),
193 (None, Some(b)) => Some(b),
194 (None, None) => None,
195 }
196 }
197
198 #[overrides]
199 fn bind_remote_stream(&mut self, info: &StreamInfo) {
200 if let Some(hdr_ext_id) = stream_supports_twcc(info) {
201 if hdr_ext_id != 0 {
203 self.streams.insert(info.ssrc, RemoteStream { hdr_ext_id });
204 }
205 }
206 self.inner.bind_remote_stream(info);
207 }
208
209 #[overrides]
210 fn unbind_remote_stream(&mut self, info: &StreamInfo) {
211 self.streams.remove(&info.ssrc);
212 self.inner.unbind_remote_stream(info);
213 }
214}
215
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Registry;
    use crate::stream_info::RTPHeaderExtension;
    use sansio::Protocol;
    use shared::marshal::Marshal;

    /// SSRC of the remote stream used throughout these tests.
    const SSRC: u32 = 12345;
    /// Header-extension id under which TWCC is negotiated in these tests.
    const TWCC_EXT_ID: u8 = 5;

    /// Builds a `StreamInfo` for `ssrc` that advertises the transport-wide CC header
    /// extension.
    fn twcc_stream_info(ssrc: u32) -> StreamInfo {
        StreamInfo {
            ssrc,
            rtp_header_extensions: vec![RTPHeaderExtension {
                uri: super::super::TRANSPORT_CC_URI.to_string(),
                // Must match TWCC_EXT_ID, which the packets below are tagged with.
                id: 5,
            }],
            ..Default::default()
        }
    }

    /// Builds an RTP packet whose transport-wide CC extension (id `hdr_ext_id`)
    /// carries `twcc_seq`.
    ///
    /// # Panics
    /// Panics if the extension cannot be encoded or attached. Silently returning a
    /// packet without the extension would let tests pass without exercising TWCC.
    fn make_rtp_packet_with_twcc(
        ssrc: u32,
        seq: u16,
        twcc_seq: u16,
        hdr_ext_id: u8,
    ) -> rtp::Packet {
        let mut pkt = rtp::Packet {
            header: rtp::header::Header {
                ssrc,
                sequence_number: seq,
                ..Default::default()
            },
            payload: vec![].into(),
        };

        let ext_data = rtp::extension::transport_cc_extension::TransportCcExtension {
            transport_sequence: twcc_seq,
        }
        .marshal()
        .expect("TWCC extension should marshal");
        pkt.header
            .set_extension(hdr_ext_id, ext_data.freeze())
            .expect("TWCC extension should be attachable to the RTP header");

        pkt
    }

    /// Wraps an RTP packet in a `TaggedPacket` received at `now`.
    fn tagged_rtp(now: Instant, rtp: rtp::Packet) -> TaggedPacket {
        TaggedPacket {
            now,
            transport: Default::default(),
            message: Packet::Rtp(rtp),
        }
    }

    #[test]
    fn test_twcc_receiver_builder_defaults() {
        let chain = Registry::new()
            .with(TwccReceiverBuilder::default().build())
            .build();

        assert_eq!(chain.interval, DEFAULT_INTERVAL);
        assert!(chain.recorder.is_none());
    }

    #[test]
    fn test_twcc_receiver_builder_custom_interval() {
        let chain = Registry::new()
            .with(
                TwccReceiverBuilder::new()
                    .with_interval(Duration::from_millis(50))
                    .build(),
            )
            .build();

        assert_eq!(chain.interval, Duration::from_millis(50));
    }

    #[test]
    fn test_twcc_receiver_records_packets() {
        let mut chain = Registry::new()
            .with(TwccReceiverBuilder::new().build())
            .build();
        chain.bind_remote_stream(&twcc_stream_info(SSRC));

        let now = Instant::now();
        chain
            .handle_read(tagged_rtp(now, make_rtp_packet_with_twcc(SSRC, 1, 0, TWCC_EXT_ID)))
            .unwrap();

        assert!(chain.recorder.is_some());
        assert_eq!(chain.start_time, Some(now));
        assert_eq!(chain.next_timeout, Some(now + DEFAULT_INTERVAL));
    }

    #[test]
    fn test_twcc_receiver_generates_feedback_on_timeout() {
        let interval = Duration::from_millis(100);
        let mut chain = Registry::new()
            .with(TwccReceiverBuilder::new().with_interval(interval).build())
            .build();
        chain.bind_remote_stream(&twcc_stream_info(SSRC));

        let start = Instant::now();
        for i in 0..5u16 {
            let now = start + Duration::from_millis(u64::from(i) * 10);
            chain
                .handle_read(tagged_rtp(now, make_rtp_packet_with_twcc(SSRC, i, i, TWCC_EXT_ID)))
                .unwrap();
        }

        let timeout_time = start + Duration::from_millis(150);
        chain.handle_timeout(timeout_time).unwrap();

        let tagged = chain
            .poll_write()
            .expect("expected TWCC feedback once the interval elapsed");
        let Packet::Rtcp(rtcp_packets) = tagged.message else {
            panic!("Expected RTCP packet");
        };
        assert!(!rtcp_packets.is_empty());
        assert_eq!(chain.next_timeout, Some(timeout_time + interval));
    }

    #[test]
    fn test_twcc_receiver_no_feedback_before_interval() {
        let mut chain = Registry::new()
            .with(TwccReceiverBuilder::new().build())
            .build();
        chain.bind_remote_stream(&twcc_stream_info(SSRC));

        let start = Instant::now();
        chain
            .handle_read(tagged_rtp(start, make_rtp_packet_with_twcc(SSRC, 1, 0, TWCC_EXT_ID)))
            .unwrap();

        chain.handle_timeout(start + DEFAULT_INTERVAL / 2).unwrap();

        assert!(chain.write_queue.is_empty());
        assert_eq!(chain.next_timeout, Some(start + DEFAULT_INTERVAL));
    }

    #[test]
    fn test_twcc_receiver_no_feedback_without_binding() {
        let mut chain = Registry::new()
            .with(TwccReceiverBuilder::new().build())
            .build();

        chain
            .handle_read(tagged_rtp(
                Instant::now(),
                make_rtp_packet_with_twcc(SSRC, 1, 0, TWCC_EXT_ID),
            ))
            .unwrap();

        assert!(chain.recorder.is_none());
        assert!(chain.next_timeout.is_none());
    }

    #[test]
    fn test_twcc_receiver_ignores_stream_without_twcc_extension() {
        let mut chain = Registry::new()
            .with(TwccReceiverBuilder::new().build())
            .build();

        let info = StreamInfo {
            ssrc: SSRC,
            ..Default::default()
        };
        chain.bind_remote_stream(&info);

        assert!(!chain.streams.contains_key(&SSRC));
    }

    #[test]
    fn test_twcc_receiver_unbind_removes_stream() {
        let mut chain = Registry::new()
            .with(TwccReceiverBuilder::new().build())
            .build();
        let info = twcc_stream_info(SSRC);

        chain.bind_remote_stream(&info);
        assert!(chain.streams.contains_key(&SSRC));

        chain.unbind_remote_stream(&info);
        assert!(!chain.streams.contains_key(&SSRC));

        // Packets from an unbound SSRC must not start recording.
        chain
            .handle_read(tagged_rtp(
                Instant::now(),
                make_rtp_packet_with_twcc(SSRC, 1, 0, TWCC_EXT_ID),
            ))
            .unwrap();
        assert!(chain.recorder.is_none());
        assert!(chain.next_timeout.is_none());
    }

    #[test]
    fn test_twcc_receiver_poll_timeout() {
        let mut chain = Registry::new()
            .with(TwccReceiverBuilder::new().build())
            .build();

        assert!(chain.poll_timeout().is_none());

        chain.bind_remote_stream(&twcc_stream_info(SSRC));

        let now = Instant::now();
        chain
            .handle_read(tagged_rtp(now, make_rtp_packet_with_twcc(SSRC, 1, 0, TWCC_EXT_ID)))
            .unwrap();

        let timeout = chain
            .poll_timeout()
            .expect("feedback timer should be armed after the first packet");
        assert!(timeout <= now + DEFAULT_INTERVAL);
    }
}