rtc_interceptor/twcc/
sender.rs1use super::stream_supports_twcc;
4use crate::stream_info::StreamInfo;
5use crate::{Interceptor, Packet, TaggedPacket, interceptor};
6use shared::error::Error;
7use shared::marshal::Marshal;
8use std::collections::HashMap;
9use std::marker::PhantomData;
10
/// Builder for [`TwccSenderInterceptor`].
///
/// The builder carries no configuration of its own; `P` is the type of the
/// inner interceptor that the built sender will wrap.
pub struct TwccSenderBuilder<P> {
    _phantom: PhantomData<P>,
}

impl<P> Default for TwccSenderBuilder<P> {
    /// Creates an (empty) builder.
    fn default() -> Self {
        TwccSenderBuilder {
            _phantom: PhantomData,
        }
    }
}
33
34impl<P> TwccSenderBuilder<P> {
35 pub fn new() -> Self {
37 Self::default()
38 }
39
40 pub fn build(self) -> impl FnOnce(P) -> TwccSenderInterceptor<P> {
42 move |inner| TwccSenderInterceptor::new(inner)
43 }
44}
45
/// Per-SSRC state for a locally bound stream that negotiated transport-cc.
struct LocalStream {
    /// RTP header extension ID under which the transport-wide sequence
    /// number is written on this stream's packets.
    hdr_ext_id: u8,
}
51
52#[derive(Interceptor)]
57pub struct TwccSenderInterceptor<P> {
58 #[next]
59 inner: P,
60 next_sequence_number: u16,
62 streams: HashMap<u32, LocalStream>,
64}
65
66impl<P> TwccSenderInterceptor<P> {
67 fn new(inner: P) -> Self {
68 Self {
69 inner,
70 next_sequence_number: 0,
71 streams: HashMap::new(),
72 }
73 }
74}
75
76#[interceptor]
77impl<P: Interceptor> TwccSenderInterceptor<P> {
78 #[overrides]
79 fn handle_write(&mut self, mut msg: TaggedPacket) -> Result<(), Self::Error> {
80 if let Packet::Rtp(ref mut rtp_packet) = msg.message
82 && let Some(stream) = self.streams.get(&rtp_packet.header.ssrc)
83 {
84 let seq = self.next_sequence_number;
86 self.next_sequence_number = self.next_sequence_number.wrapping_add(1);
87
88 let tcc_ext = rtp::extension::transport_cc_extension::TransportCcExtension {
89 transport_sequence: seq,
90 };
91
92 if let Ok(ext_data) = tcc_ext.marshal() {
94 let _ = rtp_packet
96 .header
97 .set_extension(stream.hdr_ext_id, ext_data.freeze());
98 }
99 }
100
101 self.inner.handle_write(msg)
102 }
103
104 #[overrides]
105 fn bind_local_stream(&mut self, info: &StreamInfo) {
106 if let Some(hdr_ext_id) = stream_supports_twcc(info) {
107 if hdr_ext_id != 0 {
109 self.streams.insert(info.ssrc, LocalStream { hdr_ext_id });
110 }
111 }
112 self.inner.bind_local_stream(info);
113 }
114
115 #[overrides]
116 fn unbind_local_stream(&mut self, info: &StreamInfo) {
117 self.streams.remove(&info.ssrc);
118 self.inner.unbind_local_stream(info);
119 }
120}
121
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Registry;
    use crate::stream_info::RTPHeaderExtension;
    use rtp::extension::transport_cc_extension::TransportCcExtension;
    use sansio::Protocol;
    use shared::marshal::Unmarshal;
    use std::time::Instant;

    /// Builds an empty-payload RTP packet for `ssrc` with RTP sequence `seq`.
    fn make_rtp_packet(ssrc: u32, seq: u16) -> TaggedPacket {
        TaggedPacket {
            now: Instant::now(),
            transport: Default::default(),
            message: Packet::Rtp(rtp::Packet {
                header: rtp::header::Header {
                    ssrc,
                    sequence_number: seq,
                    ..Default::default()
                },
                payload: vec![].into(),
            }),
        }
    }

    /// Stream info advertising the transport-cc header extension under ID 5.
    fn twcc_stream_info(ssrc: u32) -> StreamInfo {
        StreamInfo {
            ssrc,
            rtp_header_extensions: vec![RTPHeaderExtension {
                uri: super::super::TRANSPORT_CC_URI.to_string(),
                id: 5,
            }],
            ..Default::default()
        }
    }

    /// Decodes the transport-wide sequence number from header extension 5.
    ///
    /// Returns `None` when the packet carries no such extension; panics if
    /// the message is not RTP, so a test can never pass vacuously.
    fn transport_seq(pkt: TaggedPacket) -> Option<u16> {
        let Packet::Rtp(rtp) = pkt.message else {
            panic!("Expected RTP packet");
        };
        let ext = rtp.header.get_extension(5)?;
        let tcc = TransportCcExtension::unmarshal(&mut ext.as_ref())
            .expect("malformed transport-cc extension");
        Some(tcc.transport_sequence)
    }

    #[test]
    fn test_twcc_sender_builder_defaults() {
        let chain = Registry::new()
            .with(TwccSenderBuilder::default().build())
            .build();

        assert!(chain.streams.is_empty());
    }

    #[test]
    fn test_twcc_sender_adds_extension() {
        let mut chain = Registry::new()
            .with(TwccSenderBuilder::new().build())
            .build();
        chain.bind_local_stream(&twcc_stream_info(12345));

        chain.handle_write(make_rtp_packet(12345, 1)).unwrap();
        let out1 = chain.poll_write().unwrap();

        chain.handle_write(make_rtp_packet(12345, 2)).unwrap();
        let out2 = chain.poll_write().unwrap();

        assert_eq!(transport_seq(out1), Some(0));
        assert_eq!(transport_seq(out2), Some(1));
    }

    #[test]
    fn test_twcc_sender_no_extension_without_binding() {
        let mut chain = Registry::new()
            .with(TwccSenderBuilder::new().build())
            .build();

        chain.handle_write(make_rtp_packet(12345, 1)).unwrap();
        let out = chain.poll_write().unwrap();

        assert_eq!(transport_seq(out), None);
    }

    #[test]
    fn test_twcc_sender_unbound_ssrc_does_not_consume_sequence() {
        let mut chain = Registry::new()
            .with(TwccSenderBuilder::new().build())
            .build();
        chain.bind_local_stream(&twcc_stream_info(12345));

        chain.handle_write(make_rtp_packet(999, 1)).unwrap();
        let untagged = chain.poll_write().unwrap();
        assert_eq!(transport_seq(untagged), None);

        chain.handle_write(make_rtp_packet(12345, 1)).unwrap();
        let tagged = chain.poll_write().unwrap();
        assert_eq!(transport_seq(tagged), Some(0));
    }

    #[test]
    fn test_twcc_sender_ignores_zero_extension_id() {
        let mut chain = Registry::new()
            .with(TwccSenderBuilder::new().build())
            .build();

        let info = StreamInfo {
            ssrc: 12345,
            rtp_header_extensions: vec![RTPHeaderExtension {
                uri: super::super::TRANSPORT_CC_URI.to_string(),
                id: 0,
            }],
            ..Default::default()
        };
        chain.bind_local_stream(&info);

        assert!(chain.streams.is_empty());
    }

    #[test]
    fn test_twcc_sender_ignores_other_extensions() {
        let mut chain = Registry::new()
            .with(TwccSenderBuilder::new().build())
            .build();

        let info = StreamInfo {
            ssrc: 12345,
            rtp_header_extensions: vec![RTPHeaderExtension {
                uri: "urn:ietf:params:rtp-hdrext:sdes:mid".to_string(),
                id: 5,
            }],
            ..Default::default()
        };
        chain.bind_local_stream(&info);

        assert!(chain.streams.is_empty());
    }

    #[test]
    fn test_twcc_sender_unbind_removes_stream() {
        let mut chain = Registry::new()
            .with(TwccSenderBuilder::new().build())
            .build();
        let info = twcc_stream_info(12345);

        chain.bind_local_stream(&info);
        assert!(chain.streams.contains_key(&12345));

        chain.unbind_local_stream(&info);
        assert!(!chain.streams.contains_key(&12345));

        // Packets for the unbound SSRC must no longer be tagged.
        chain.handle_write(make_rtp_packet(12345, 1)).unwrap();
        let out = chain.poll_write().unwrap();
        assert_eq!(transport_seq(out), None);
    }

    #[test]
    fn test_twcc_sender_sequence_wraparound() {
        let mut chain = Registry::new()
            .with(TwccSenderBuilder::new().build())
            .build();
        chain.bind_local_stream(&twcc_stream_info(12345));

        chain.next_sequence_number = 65534;

        for expected_seq in [65534u16, 65535, 0, 1] {
            chain.handle_write(make_rtp_packet(12345, 1)).unwrap();
            let out = chain.poll_write().unwrap();
            assert_eq!(transport_seq(out), Some(expected_seq));
        }
    }

    #[test]
    fn test_twcc_sender_multiple_streams_share_counter() {
        let mut chain = Registry::new()
            .with(TwccSenderBuilder::new().build())
            .build();
        chain.bind_local_stream(&twcc_stream_info(1111));
        chain.bind_local_stream(&twcc_stream_info(2222));

        // Interleaved streams draw from one transport-wide counter.
        for (expected, ssrc) in (0u16..).zip([1111u32, 2222, 1111, 2222]) {
            chain.handle_write(make_rtp_packet(ssrc, 1)).unwrap();
            let out = chain.poll_write().unwrap();
            assert_eq!(transport_seq(out), Some(expected));
        }
    }
}