1use super::recorder::Recorder;
4use super::stream_supports_twcc;
5use crate::stream_info::StreamInfo;
6use crate::{Interceptor, Packet, TaggedPacket};
7use shared::TransportContext;
8use shared::error::Error;
9use shared::marshal::Unmarshal;
10use std::collections::{HashMap, VecDeque};
11use std::marker::PhantomData;
12use std::time::{Duration, Instant};
13
/// Feedback period used when the builder is not given an explicit interval
/// via `TwccReceiverBuilder::with_interval`.
const DEFAULT_INTERVAL: Duration = Duration::from_millis(100);
16
/// Configuration builder for the TWCC receiver interceptor.
///
/// `P` is the type of the inner interceptor that the closure returned by
/// `build` will wrap; no value of `P` is stored here.
pub struct TwccReceiverBuilder<P> {
    /// Marks the builder as producing interceptors around a `P`.
    _phantom: PhantomData<P>,
    /// How often feedback reports are generated.
    interval: Duration,
}
36
37impl<P> Default for TwccReceiverBuilder<P> {
38 fn default() -> Self {
39 Self {
40 interval: DEFAULT_INTERVAL,
41 _phantom: PhantomData,
42 }
43 }
44}
45
46impl<P> TwccReceiverBuilder<P> {
47 pub fn new() -> Self {
49 Self::default()
50 }
51
52 pub fn with_interval(mut self, interval: Duration) -> Self {
54 self.interval = interval;
55 self
56 }
57
58 pub fn build(self) -> impl FnOnce(P) -> TwccReceiverInterceptor<P> {
60 move |inner| TwccReceiverInterceptor::new(inner, self.interval)
61 }
62}
63
/// Per-SSRC state for a remote stream that negotiated transport-wide CC.
struct RemoteStream {
    /// RTP header extension id under which the transport sequence number is carried.
    hdr_ext_id: u8,
}
69
70pub struct TwccReceiverInterceptor<P> {
75 inner: P,
76
77 interval: Duration,
79
80 start_time: Option<Instant>,
82
83 recorder: Option<Recorder>,
85
86 streams: HashMap<u32, RemoteStream>,
88
89 write_queue: VecDeque<TaggedPacket>,
91
92 next_timeout: Option<Instant>,
94}
95
96impl<P> TwccReceiverInterceptor<P> {
97 fn new(inner: P, interval: Duration) -> Self {
98 Self {
99 inner,
100 interval,
101 start_time: None,
102 recorder: None,
103 streams: HashMap::new(),
104 write_queue: VecDeque::new(),
105 next_timeout: None,
106 }
107 }
108
109 fn generate_feedback(&mut self, now: Instant) {
110 let Some(recorder) = self.recorder.as_mut() else {
111 return;
112 };
113
114 let packets = recorder.build_feedback_packet();
115 for pkt in packets {
116 self.write_queue.push_back(TaggedPacket {
117 now,
118 transport: TransportContext::default(),
119 message: Packet::Rtcp(vec![pkt]),
120 });
121 }
122 }
123}
124
125impl<P: Interceptor> sansio::Protocol<TaggedPacket, TaggedPacket, ()>
126 for TwccReceiverInterceptor<P>
127{
128 type Rout = TaggedPacket;
129 type Wout = TaggedPacket;
130 type Eout = ();
131 type Error = Error;
132 type Time = Instant;
133
134 fn handle_read(&mut self, msg: TaggedPacket) -> Result<(), Self::Error> {
135 if let Packet::Rtp(ref rtp_packet) = msg.message
137 && let Some(stream) = self.streams.get(&rtp_packet.header.ssrc)
138 {
139 if self.recorder.is_none() {
141 self.recorder = Some(Recorder::new(rand::random()));
143 self.start_time = Some(msg.now);
144 self.next_timeout = Some(msg.now + self.interval);
145 }
146
147 if let Some(ext_data) = rtp_packet.header.get_extension(stream.hdr_ext_id)
149 && let Ok(tcc) =
150 rtp::extension::transport_cc_extension::TransportCcExtension::unmarshal(
151 &mut ext_data.as_ref(),
152 )
153 {
154 let arrival_time = self
156 .start_time
157 .map(|start| msg.now.duration_since(start).as_micros() as i64)
158 .unwrap_or(0);
159
160 if let Some(recorder) = self.recorder.as_mut() {
161 recorder.record(rtp_packet.header.ssrc, tcc.transport_sequence, arrival_time);
162 }
163 }
164 }
165
166 self.inner.handle_read(msg)
167 }
168
169 fn poll_read(&mut self) -> Option<Self::Rout> {
170 self.inner.poll_read()
171 }
172
173 fn handle_write(&mut self, msg: TaggedPacket) -> Result<(), Self::Error> {
174 self.inner.handle_write(msg)
175 }
176
177 fn poll_write(&mut self) -> Option<Self::Wout> {
178 if let Some(pkt) = self.write_queue.pop_front() {
180 return Some(pkt);
181 }
182 self.inner.poll_write()
183 }
184
185 fn handle_timeout(&mut self, now: Self::Time) -> Result<(), Self::Error> {
186 if let Some(timeout) = self.next_timeout
188 && now >= timeout
189 {
190 self.generate_feedback(now);
191 self.next_timeout = Some(now + self.interval);
192 }
193 self.inner.handle_timeout(now)
194 }
195
196 fn poll_timeout(&mut self) -> Option<Self::Time> {
197 let inner_timeout = self.inner.poll_timeout();
198
199 match (self.next_timeout, inner_timeout) {
200 (Some(a), Some(b)) => Some(a.min(b)),
201 (Some(a), None) => Some(a),
202 (None, Some(b)) => Some(b),
203 (None, None) => None,
204 }
205 }
206}
207
208impl<P: Interceptor> Interceptor for TwccReceiverInterceptor<P> {
209 fn bind_local_stream(&mut self, info: &StreamInfo) {
210 self.inner.bind_local_stream(info);
211 }
212
213 fn unbind_local_stream(&mut self, info: &StreamInfo) {
214 self.inner.unbind_local_stream(info);
215 }
216
217 fn bind_remote_stream(&mut self, info: &StreamInfo) {
218 if let Some(hdr_ext_id) = stream_supports_twcc(info) {
219 if hdr_ext_id != 0 {
221 self.streams.insert(info.ssrc, RemoteStream { hdr_ext_id });
222 }
223 }
224 self.inner.bind_remote_stream(info);
225 }
226
227 fn unbind_remote_stream(&mut self, info: &StreamInfo) {
228 self.streams.remove(&info.ssrc);
229 self.inner.unbind_remote_stream(info);
230 }
231}
232
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Registry;
    use crate::stream_info::RTPHeaderExtension;
    use sansio::Protocol;
    use shared::marshal::Marshal;

    /// Builds an RTP packet whose header carries a transport-cc extension holding
    /// `twcc_seq` under extension id `hdr_ext_id`.
    ///
    /// Panics if the extension cannot be encoded or attached, so a test can never
    /// silently run against a packet that lacks the extension it means to exercise.
    fn make_rtp_packet_with_twcc(
        ssrc: u32,
        seq: u16,
        twcc_seq: u16,
        hdr_ext_id: u8,
    ) -> rtp::Packet {
        let mut pkt = rtp::Packet {
            header: rtp::header::Header {
                ssrc,
                sequence_number: seq,
                ..Default::default()
            },
            payload: vec![].into(),
        };

        let tcc_ext = rtp::extension::transport_cc_extension::TransportCcExtension {
            transport_sequence: twcc_seq,
        };
        let ext_data = tcc_ext
            .marshal()
            .expect("transport-cc extension should marshal");
        pkt.header
            .set_extension(hdr_ext_id, ext_data.freeze())
            .expect("transport-cc extension should attach to the RTP header");

        pkt
    }

    /// `StreamInfo` for a remote stream that negotiated transport-cc under id 5.
    fn twcc_stream_info(ssrc: u32) -> StreamInfo {
        StreamInfo {
            ssrc,
            rtp_header_extensions: vec![RTPHeaderExtension {
                uri: super::super::TRANSPORT_CC_URI.to_string(),
                id: 5,
            }],
            ..Default::default()
        }
    }

    /// Wraps `rtp` in a `TaggedPacket` received at `now`.
    fn tagged_rtp(rtp: rtp::Packet, now: Instant) -> TaggedPacket {
        TaggedPacket {
            now,
            transport: Default::default(),
            message: Packet::Rtp(rtp),
        }
    }

    #[test]
    fn test_twcc_receiver_builder_defaults() {
        let chain = Registry::new()
            .with(TwccReceiverBuilder::default().build())
            .build();

        assert_eq!(chain.interval, DEFAULT_INTERVAL);
        assert!(chain.recorder.is_none());
        assert!(chain.next_timeout.is_none());
    }

    #[test]
    fn test_twcc_receiver_builder_custom_interval() {
        let chain = Registry::new()
            .with(
                TwccReceiverBuilder::new()
                    .with_interval(Duration::from_millis(50))
                    .build(),
            )
            .build();

        assert_eq!(chain.interval, Duration::from_millis(50));
    }

    #[test]
    fn test_twcc_receiver_records_packets() {
        let mut chain = Registry::new()
            .with(TwccReceiverBuilder::new().build())
            .build();
        chain.bind_remote_stream(&twcc_stream_info(12345));

        let now = Instant::now();
        chain
            .handle_read(tagged_rtp(make_rtp_packet_with_twcc(12345, 1, 0, 5), now))
            .unwrap();

        // The first TWCC packet starts the recorder, anchors the clock and arms
        // the feedback timer one interval later.
        assert!(chain.recorder.is_some());
        assert_eq!(chain.start_time, Some(now));
        assert_eq!(chain.next_timeout, Some(now + DEFAULT_INTERVAL));
    }

    #[test]
    fn test_twcc_receiver_generates_feedback_on_timeout() {
        let interval = Duration::from_millis(100);
        let mut chain = Registry::new()
            .with(TwccReceiverBuilder::new().with_interval(interval).build())
            .build();
        chain.bind_remote_stream(&twcc_stream_info(12345));

        let start = Instant::now();
        for i in 0..5u16 {
            let arrival = start + Duration::from_millis(u64::from(i) * 10);
            chain
                .handle_read(tagged_rtp(make_rtp_packet_with_twcc(12345, i, i, 5), arrival))
                .unwrap();
        }

        // Past the first deadline: feedback is queued and the timer re-armed.
        let timeout_time = start + Duration::from_millis(150);
        chain.handle_timeout(timeout_time).unwrap();
        assert_eq!(chain.next_timeout, Some(timeout_time + interval));

        let feedback = chain
            .poll_write()
            .expect("feedback should be queued once the interval has elapsed");
        let Packet::Rtcp(rtcp_packets) = feedback.message else {
            panic!("Expected RTCP packet");
        };
        assert!(!rtcp_packets.is_empty());
    }

    #[test]
    fn test_twcc_receiver_no_feedback_before_interval() {
        let mut chain = Registry::new()
            .with(TwccReceiverBuilder::new().build())
            .build();
        chain.bind_remote_stream(&twcc_stream_info(12345));

        let now = Instant::now();
        chain
            .handle_read(tagged_rtp(make_rtp_packet_with_twcc(12345, 1, 0, 5), now))
            .unwrap();

        // Halfway to the deadline nothing is generated and the deadline is unchanged.
        chain
            .handle_timeout(now + Duration::from_millis(50))
            .unwrap();
        assert!(chain.write_queue.is_empty());
        assert_eq!(chain.next_timeout, Some(now + DEFAULT_INTERVAL));
    }

    #[test]
    fn test_twcc_receiver_no_feedback_without_binding() {
        let mut chain = Registry::new()
            .with(TwccReceiverBuilder::new().build())
            .build();

        let now = Instant::now();
        chain
            .handle_read(tagged_rtp(make_rtp_packet_with_twcc(12345, 1, 0, 5), now))
            .unwrap();

        assert!(chain.recorder.is_none());
        assert!(chain.next_timeout.is_none());
    }

    #[test]
    fn test_twcc_receiver_ignores_zero_ext_id() {
        let mut chain = Registry::new()
            .with(TwccReceiverBuilder::new().build())
            .build();

        let info = StreamInfo {
            ssrc: 12345,
            rtp_header_extensions: vec![RTPHeaderExtension {
                uri: super::super::TRANSPORT_CC_URI.to_string(),
                id: 0,
            }],
            ..Default::default()
        };
        chain.bind_remote_stream(&info);

        assert!(!chain.streams.contains_key(&12345));
    }

    #[test]
    fn test_twcc_receiver_unbind_removes_stream() {
        let mut chain = Registry::new()
            .with(TwccReceiverBuilder::new().build())
            .build();
        let info = twcc_stream_info(12345);

        chain.bind_remote_stream(&info);
        assert!(chain.streams.contains_key(&12345));

        chain.unbind_remote_stream(&info);
        assert!(!chain.streams.contains_key(&12345));
    }

    #[test]
    fn test_twcc_receiver_poll_timeout() {
        let mut chain = Registry::new()
            .with(TwccReceiverBuilder::new().build())
            .build();

        // Neither this interceptor nor the inner one has a deadline yet.
        assert!(chain.poll_timeout().is_none());

        chain.bind_remote_stream(&twcc_stream_info(12345));

        let now = Instant::now();
        chain
            .handle_read(tagged_rtp(make_rtp_packet_with_twcc(12345, 1, 0, 5), now))
            .unwrap();

        assert_eq!(chain.poll_timeout(), Some(now + DEFAULT_INTERVAL));
    }
}