1use shadow_core::{Packet, PacketType, PeerInfo, Transport, TransportStats, CoverProtocol};
7use shadow_core::error::{Result, ShadowError};
8use bytes::Bytes;
9use async_trait::async_trait;
10use tokio::sync::mpsc;
11use std::time::{Duration, Instant};
12use rand::Rng;
13
/// An RTP packet consisting of the fixed header fields plus an opaque payload.
///
/// Only the fixed 12-byte header is modelled; there is no storage for a CSRC list
/// or a header extension.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RTPPacket {
    /// Protocol version (2 bits on the wire).
    pub version: u8,
    /// Padding flag (P bit).
    pub padding: bool,
    /// Header-extension flag (X bit).
    pub extension: bool,
    /// CSRC count (4 bits on the wire).
    pub csrc_count: u8,
    /// Marker bit (M).
    pub marker: bool,
    /// Payload type (7 bits on the wire).
    pub payload_type: u8,
    /// Sequence number, serialized big-endian.
    pub sequence: u16,
    /// Media timestamp, serialized big-endian.
    pub timestamp: u32,
    /// Synchronization source identifier, serialized big-endian.
    pub ssrc: u32,
    /// Bytes carried after the fixed header.
    pub payload: Vec<u8>,
}
38
39impl RTPPacket {
40 pub fn new(sequence: u16, timestamp: u32, ssrc: u32, payload: Vec<u8>) -> Self {
42 Self {
43 version: 2,
44 padding: false,
45 extension: false,
46 csrc_count: 0,
47 marker: false,
48 payload_type: 96, sequence,
50 timestamp,
51 ssrc,
52 payload,
53 }
54 }
55
56 pub fn to_bytes(&self) -> Vec<u8> {
58 let mut bytes = Vec::with_capacity(12 + self.payload.len());
59
60 bytes.push(
62 (self.version << 6) |
63 ((self.padding as u8) << 5) |
64 ((self.extension as u8) << 4) |
65 self.csrc_count
66 );
67
68 bytes.push(
70 ((self.marker as u8) << 7) |
71 self.payload_type
72 );
73
74 bytes.extend_from_slice(&self.sequence.to_be_bytes());
76
77 bytes.extend_from_slice(&self.timestamp.to_be_bytes());
79
80 bytes.extend_from_slice(&self.ssrc.to_be_bytes());
82
83 bytes.extend_from_slice(&self.payload);
85
86 bytes
87 }
88
89 pub fn from_bytes(data: &[u8]) -> Result<Self> {
91 if data.len() < 12 {
92 return Err(ShadowError::InvalidPacket("RTP packet too short".into()));
93 }
94
95 let version = (data[0] >> 6) & 0x03;
96 let padding = (data[0] & 0x20) != 0;
97 let extension = (data[0] & 0x10) != 0;
98 let csrc_count = data[0] & 0x0F;
99
100 let marker = (data[1] & 0x80) != 0;
101 let payload_type = data[1] & 0x7F;
102
103 let sequence = u16::from_be_bytes([data[2], data[3]]);
104 let timestamp = u32::from_be_bytes([data[4], data[5], data[6], data[7]]);
105 let ssrc = u32::from_be_bytes([data[8], data[9], data[10], data[11]]);
106
107 let payload = data[12..].to_vec();
108
109 Ok(Self {
110 version,
111 padding,
112 extension,
113 csrc_count,
114 marker,
115 payload_type,
116 sequence,
117 timestamp,
118 ssrc,
119 payload,
120 })
121 }
122}
123
/// Transport that frames outgoing data as RTP packets (WebRTC cover protocol) and
/// delivers packets over an in-process channel owned by the transport itself.
pub struct WebRTCTransport {
    /// Synchronization source id, chosen randomly when the transport is created.
    ssrc: u32,
    /// Sequence number of the last RTP packet built; incremented (wrapping) per packet.
    sequence: u16,
    /// RTP timestamp of the last packet built; advanced (wrapping) per packet.
    timestamp: u32,
    /// Instant the transport was created.
    /// NOTE(review): not read by any code visible here — confirm it is still needed.
    start_time: Instant,
    /// Sending half of the loopback channel; `send` pushes packets here.
    tx: mpsc::UnboundedSender<Packet>,
    /// Receiving half of the same channel; `recv` pops packets from here, so sent
    /// packets come back out of this same transport.
    rx: mpsc::UnboundedReceiver<Packet>,
    /// Counters updated by `send` and returned (cloned) by `stats`.
    stats: TransportStats,
    /// Flag toggled by `start_background_traffic` / `stop_background_traffic`.
    /// NOTE(review): no code visible here generates traffic based on it.
    dummy_traffic: bool,
    /// Value stored by `set_bandwidth_limit`, whose parameter is named
    /// `bytes_per_sec` even though this field is named "bitrate".
    /// NOTE(review): confirm the intended unit; the field is not read in this file.
    target_bitrate: u64,
}
145
146impl WebRTCTransport {
147 pub fn new() -> Self {
149 let (tx, rx) = mpsc::unbounded_channel();
150 let mut rng = rand::thread_rng();
151
152 Self {
153 ssrc: rng.gen(),
154 sequence: 0,
155 timestamp: 0,
156 start_time: Instant::now(),
157 tx,
158 rx,
159 stats: TransportStats::default(),
160 dummy_traffic: true,
161 target_bitrate: 500_000, }
163 }
164
165 fn create_rtp_packet(&mut self, data: &[u8]) -> RTPPacket {
167 let max_payload = 1200; let chunks: Vec<&[u8]> = data.chunks(max_payload).collect();
170
171 let payload = if !chunks.is_empty() {
173 chunks[0].to_vec()
174 } else {
175 vec![]
176 };
177
178 self.sequence = self.sequence.wrapping_add(1);
179
180 self.timestamp = self.timestamp.wrapping_add(3000);
182
183 RTPPacket::new(self.sequence, self.timestamp, self.ssrc, payload)
184 }
185
186 fn generate_packet_size(&self) -> usize {
188 let mut rng = rand::thread_rng();
189
190 let r: f64 = rng.gen();
196 if r < 0.03 {
197 rng.gen_range(5000..15000)
199 } else if r < 0.7 {
200 rng.gen_range(500..3000)
202 } else {
203 rng.gen_range(100..500)
205 }
206 }
207
208 async fn add_jitter(&self) {
210 let jitter_ms = {
212 let mut rng = rand::thread_rng();
213 (rng.gen::<f64>() * 10.0 + 20.0).max(0.0) as u64
215 };
216
217 tokio::time::sleep(Duration::from_millis(jitter_ms)).await;
218 }
219}
220
221impl Default for WebRTCTransport {
222 fn default() -> Self {
223 Self::new()
224 }
225}
226
227#[async_trait]
228impl Transport for WebRTCTransport {
229 fn protocol(&self) -> CoverProtocol {
230 CoverProtocol::WebRTC
231 }
232
233 async fn send(&mut self, packet: Packet, _peer: &PeerInfo) -> Result<()> {
234 let rtp = self.create_rtp_packet(&packet.payload);
236 let rtp_bytes = rtp.to_bytes();
237
238 self.add_jitter().await;
240
241 self.stats.packets_sent += 1;
243 self.stats.bytes_sent += rtp_bytes.len() as u64;
244
245 self.tx.send(packet).map_err(|e| {
248 ShadowError::Transport(format!("Failed to send packet: {}", e))
249 })?;
250
251 Ok(())
252 }
253
254 async fn recv(&mut self) -> Result<Packet> {
255 self.rx.recv().await.ok_or_else(|| {
257 ShadowError::Transport("Channel closed".into())
258 })
259 }
260
261 fn stats(&self) -> TransportStats {
262 self.stats.clone()
263 }
264
265 async fn start_background_traffic(&mut self) -> Result<()> {
266 self.dummy_traffic = true;
267
268 Ok(())
272 }
273
274 async fn stop_background_traffic(&mut self) -> Result<()> {
275 self.dummy_traffic = false;
276 Ok(())
277 }
278
279 async fn set_bandwidth_limit(&mut self, bytes_per_sec: u64) -> Result<()> {
280 self.target_bitrate = bytes_per_sec;
281 Ok(())
282 }
283}
284
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly built packet survives serialization and parsing unchanged.
    #[test]
    fn test_rtp_packet_serialization() {
        let packet = RTPPacket::new(12345, 67890, 11111, vec![1, 2, 3, 4, 5]);
        let bytes = packet.to_bytes();
        // Fixed 12-byte header followed by the 5-byte payload.
        assert_eq!(bytes.len(), 12 + 5);

        let parsed = RTPPacket::from_bytes(&bytes).unwrap();
        assert_eq!(parsed.version, 2);
        assert_eq!(parsed.payload_type, 96);
        assert!(!parsed.marker);
        assert_eq!(parsed.sequence, 12345);
        assert_eq!(parsed.timestamp, 67890);
        assert_eq!(parsed.ssrc, 11111);
        assert_eq!(parsed.payload, vec![1, 2, 3, 4, 5]);
    }

    /// The marker bit and a non-default payload type share byte 1 without clobbering.
    #[test]
    fn test_rtp_marker_and_payload_type_round_trip() {
        let mut packet = RTPPacket::new(1, 2, 3, vec![]);
        packet.marker = true;
        packet.payload_type = 111;

        let parsed = RTPPacket::from_bytes(&packet.to_bytes()).unwrap();
        assert!(parsed.marker);
        assert_eq!(parsed.payload_type, 111);
        assert!(parsed.payload.is_empty());
    }

    /// Input shorter than the fixed header is rejected.
    #[test]
    fn test_rtp_rejects_short_input() {
        assert!(RTPPacket::from_bytes(&[0u8; 11]).is_err());
    }

    /// Sending updates the stats and the packet loops back through `recv`.
    #[tokio::test]
    async fn test_webrtc_transport() {
        let mut transport = WebRTCTransport::new();

        assert_eq!(transport.protocol(), CoverProtocol::WebRTC);

        let packet = Packet::new(
            PacketType::Data,
            None,
            None,
            Bytes::from(vec![1, 2, 3, 4]),
        );

        let peer = PeerInfo::new(
            shadow_core::PeerId::random(),
            vec!["127.0.0.1:9000".to_string()],
            [0u8; 32],
            [0u8; 32],
        );

        transport.send(packet, &peer).await.unwrap();

        let stats = transport.stats();
        assert_eq!(stats.packets_sent, 1);
        // 12-byte RTP header + 4-byte payload.
        assert_eq!(stats.bytes_sent, 16);

        let received = transport.recv().await.unwrap();
        assert_eq!(received.payload.to_vec(), vec![1u8, 2, 3, 4]);
    }
}