rtc_interceptor/twcc/
sender.rs1use super::stream_supports_twcc;
4use crate::stream_info::StreamInfo;
5use crate::{Interceptor, Packet, TaggedPacket};
6use shared::error::Error;
7use shared::marshal::Marshal;
8use std::collections::HashMap;
9use std::marker::PhantomData;
10use std::time::Instant;
11
12pub struct TwccSenderBuilder<P> {
24 _phantom: PhantomData<P>,
25}
26
27impl<P> Default for TwccSenderBuilder<P> {
28 fn default() -> Self {
29 Self {
30 _phantom: PhantomData,
31 }
32 }
33}
34
35impl<P> TwccSenderBuilder<P> {
36 pub fn new() -> Self {
38 Self::default()
39 }
40
41 pub fn build(self) -> impl FnOnce(P) -> TwccSenderInterceptor<P> {
43 move |inner| TwccSenderInterceptor::new(inner)
44 }
45}
46
/// Per-SSRC state for a bound local stream whose outgoing packets are tagged.
struct LocalStream {
    /// RTP header-extension id under which the transport-wide sequence number
    /// is written.
    hdr_ext_id: u8,
}
52
53pub struct TwccSenderInterceptor<P> {
58 inner: P,
59 next_sequence_number: u16,
61 streams: HashMap<u32, LocalStream>,
63}
64
65impl<P> TwccSenderInterceptor<P> {
66 fn new(inner: P) -> Self {
67 Self {
68 inner,
69 next_sequence_number: 0,
70 streams: HashMap::new(),
71 }
72 }
73}
74
75impl<P: Interceptor> sansio::Protocol<TaggedPacket, TaggedPacket, ()> for TwccSenderInterceptor<P> {
76 type Rout = TaggedPacket;
77 type Wout = TaggedPacket;
78 type Eout = ();
79 type Error = Error;
80 type Time = Instant;
81
82 fn handle_read(&mut self, msg: TaggedPacket) -> Result<(), Self::Error> {
83 self.inner.handle_read(msg)
84 }
85
86 fn poll_read(&mut self) -> Option<Self::Rout> {
87 self.inner.poll_read()
88 }
89
90 fn handle_write(&mut self, mut msg: TaggedPacket) -> Result<(), Self::Error> {
91 if let Packet::Rtp(ref mut rtp_packet) = msg.message
93 && let Some(stream) = self.streams.get(&rtp_packet.header.ssrc)
94 {
95 let seq = self.next_sequence_number;
97 self.next_sequence_number = self.next_sequence_number.wrapping_add(1);
98
99 let tcc_ext = rtp::extension::transport_cc_extension::TransportCcExtension {
100 transport_sequence: seq,
101 };
102
103 if let Ok(ext_data) = tcc_ext.marshal() {
105 let _ = rtp_packet
107 .header
108 .set_extension(stream.hdr_ext_id, ext_data.freeze());
109 }
110 }
111
112 self.inner.handle_write(msg)
113 }
114
115 fn poll_write(&mut self) -> Option<Self::Wout> {
116 self.inner.poll_write()
117 }
118
119 fn handle_timeout(&mut self, now: Self::Time) -> Result<(), Self::Error> {
120 self.inner.handle_timeout(now)
121 }
122
123 fn poll_timeout(&mut self) -> Option<Self::Time> {
124 self.inner.poll_timeout()
125 }
126}
127
128impl<P: Interceptor> Interceptor for TwccSenderInterceptor<P> {
129 fn bind_local_stream(&mut self, info: &StreamInfo) {
130 if let Some(hdr_ext_id) = stream_supports_twcc(info) {
131 if hdr_ext_id != 0 {
133 self.streams.insert(info.ssrc, LocalStream { hdr_ext_id });
134 }
135 }
136 self.inner.bind_local_stream(info);
137 }
138
139 fn unbind_local_stream(&mut self, info: &StreamInfo) {
140 self.streams.remove(&info.ssrc);
141 self.inner.unbind_local_stream(info);
142 }
143
144 fn bind_remote_stream(&mut self, info: &StreamInfo) {
145 self.inner.bind_remote_stream(info);
146 }
147
148 fn unbind_remote_stream(&mut self, info: &StreamInfo) {
149 self.inner.unbind_remote_stream(info);
150 }
151}
152
#[cfg(test)]
mod tests {
    //! Unit tests for the TWCC sender interceptor.

    use super::*;
    use crate::Registry;
    use crate::stream_info::RTPHeaderExtension;
    use sansio::Protocol;
    use shared::marshal::Unmarshal;

    fn make_rtp_packet(ssrc: u32, seq: u16) -> TaggedPacket {
        TaggedPacket {
            now: Instant::now(),
            transport: Default::default(),
            message: Packet::Rtp(rtp::Packet {
                header: rtp::header::Header {
                    ssrc,
                    sequence_number: seq,
                    ..Default::default()
                },
                payload: vec![].into(),
            }),
        }
    }

    /// `StreamInfo` for `ssrc` that negotiated transport-cc on extension id 5.
    fn twcc_stream_info(ssrc: u32) -> StreamInfo {
        StreamInfo {
            ssrc,
            rtp_header_extensions: vec![RTPHeaderExtension {
                uri: super::super::TRANSPORT_CC_URI.to_string(),
                id: 5,
            }],
            ..Default::default()
        }
    }

    /// Decodes the transport-wide sequence number from extension id 5 of `out`,
    /// or returns `None` when the packet carries no such extension.
    ///
    /// Panics if `out` is not an RTP packet or the extension payload is malformed,
    /// so tests cannot pass vacuously on an unexpected packet kind.
    fn transport_seq_of(out: TaggedPacket) -> Option<u16> {
        let Packet::Rtp(rtp) = out.message else {
            panic!("Expected RTP packet");
        };
        let ext = rtp.header.get_extension(5)?;
        let tcc = rtp::extension::transport_cc_extension::TransportCcExtension::unmarshal(
            &mut ext.as_ref(),
        )
        .unwrap();
        Some(tcc.transport_sequence)
    }

    #[test]
    fn test_twcc_sender_builder_defaults() {
        let chain = Registry::new()
            .with(TwccSenderBuilder::default().build())
            .build();

        assert!(chain.streams.is_empty());
    }

    #[test]
    fn test_twcc_sender_adds_extension() {
        let mut chain = Registry::new()
            .with(TwccSenderBuilder::new().build())
            .build();
        chain.bind_local_stream(&twcc_stream_info(12345));

        chain.handle_write(make_rtp_packet(12345, 1)).unwrap();
        let out1 = chain.poll_write().unwrap();
        chain.handle_write(make_rtp_packet(12345, 2)).unwrap();
        let out2 = chain.poll_write().unwrap();

        assert_eq!(transport_seq_of(out1), Some(0));
        assert_eq!(transport_seq_of(out2), Some(1));
    }

    #[test]
    fn test_twcc_sender_no_extension_without_binding() {
        let mut chain = Registry::new()
            .with(TwccSenderBuilder::new().build())
            .build();

        chain.handle_write(make_rtp_packet(12345, 1)).unwrap();
        let out = chain.poll_write().unwrap();

        assert_eq!(transport_seq_of(out), None);
    }

    #[test]
    fn test_twcc_sender_ignores_zero_extension_id() {
        let mut chain = Registry::new()
            .with(TwccSenderBuilder::new().build())
            .build();

        let info = StreamInfo {
            ssrc: 12345,
            rtp_header_extensions: vec![RTPHeaderExtension {
                uri: super::super::TRANSPORT_CC_URI.to_string(),
                id: 0,
            }],
            ..Default::default()
        };
        chain.bind_local_stream(&info);

        assert!(chain.streams.is_empty());
    }

    #[test]
    fn test_twcc_sender_unbind_removes_stream() {
        let mut chain = Registry::new()
            .with(TwccSenderBuilder::new().build())
            .build();
        let info = twcc_stream_info(12345);

        chain.bind_local_stream(&info);
        assert!(chain.streams.contains_key(&12345));

        chain.unbind_local_stream(&info);
        assert!(!chain.streams.contains_key(&12345));

        // Packets for an unbound SSRC must go out untagged.
        chain.handle_write(make_rtp_packet(12345, 1)).unwrap();
        let out = chain.poll_write().unwrap();
        assert_eq!(transport_seq_of(out), None);
    }

    #[test]
    fn test_twcc_sender_sequence_wraparound() {
        let mut chain = Registry::new()
            .with(TwccSenderBuilder::new().build())
            .build();
        chain.bind_local_stream(&twcc_stream_info(12345));

        chain.next_sequence_number = 65534;

        for expected_seq in [65534u16, 65535, 0, 1] {
            chain.handle_write(make_rtp_packet(12345, 1)).unwrap();
            let out = chain.poll_write().unwrap();
            assert_eq!(transport_seq_of(out), Some(expected_seq));
        }
    }

    #[test]
    fn test_twcc_sender_multiple_streams_share_counter() {
        let mut chain = Registry::new()
            .with(TwccSenderBuilder::new().build())
            .build();
        chain.bind_local_stream(&twcc_stream_info(1111));
        chain.bind_local_stream(&twcc_stream_info(2222));

        // One transport-wide counter: interleaved SSRCs get consecutive numbers.
        for (expected_seq, ssrc) in (0u16..).zip([1111u32, 2222, 1111, 2222]) {
            chain.handle_write(make_rtp_packet(ssrc, 1)).unwrap();
            let out = chain.poll_write().unwrap();
            assert_eq!(transport_seq_of(out), Some(expected_seq));
        }
    }
}