oximedia-net 0.1.7

Network streaming for OxiMedia
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
//! SRT connection management with UDP socket.
//!
//! Provides high-level connection handling with async I/O.

use super::congestion::CongestionControl;
use super::crypto::AesContext;
use super::loss::{LossList, ReceiveBuffer};
use super::packet::{ControlPacket, DataPacket, SrtPacket};
use super::socket::{SrtConfig, SrtSocket};
use crate::error::{NetError, NetResult};
use bytes::Bytes;
use std::collections::VecDeque;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::net::UdpSocket;
use tokio::sync::Mutex;
use tokio::time;

/// Send queue entry: one data packet that has been handed to the socket and is
/// kept so it can be retransmitted.
#[derive(Debug, Clone)]
struct SendQueueEntry {
    /// Sequence number, copied from `packet.sequence_number` when the entry is queued.
    /// NOTE(review): not read anywhere in this file.
    seq: u32,
    /// The packet exactly as sent on the wire (already encrypted when a crypto
    /// context is active), so retransmissions resend identical bytes.
    packet: DataPacket,
    /// Time of the most recent (re)transmission; reset each time the packet is
    /// retransmitted, so it is the start of the current RTO window.
    sent_at: Option<Instant>,
    /// Number of retransmissions performed so far.
    retransmit_count: u32,
}

/// SRT connection with UDP transport.
///
/// Every piece of mutable state sits behind its own `tokio::sync::Mutex` so the
/// async methods (`send`, `recv`, `run_background_tasks`, ...) can be called
/// through `&self`. Each method takes these locks in short, separate scopes.
pub struct SrtConnection {
    /// UDP socket, connected to `peer_addr`.
    socket: Arc<UdpSocket>,
    /// Remote peer address.
    peer_addr: SocketAddr,
    /// SRT state machine (handshake, sequencing, shutdown).
    state: Arc<Mutex<SrtSocket>>,
    /// Congestion control: supplies the send window, RTT and RTO.
    congestion: Arc<Mutex<CongestionControl>>,
    /// Loss list; `detect_loss` adds the gaps it NAKs.
    loss_list: Arc<Mutex<LossList>>,
    /// Receive buffer used for gap detection.
    /// NOTE(review): nothing in this file inserts received packets into it — confirm.
    recv_buffer: Arc<Mutex<ReceiveBuffer>>,
    /// Send queue (packets kept for retransmission).
    /// NOTE(review): no code in this file removes entries when an ACK arrives.
    send_queue: Arc<Mutex<VecDeque<SendQueueEntry>>>,
    /// Encryption context; `None` until `initialize_crypto` finds a passphrase.
    crypto: Arc<Mutex<Option<AesContext>>>,
    /// Last keepalive sent time.
    last_keepalive: Arc<Mutex<Instant>>,
    /// Received payload bytes that `recv` hands out before reading the socket again.
    read_buffer: Arc<Mutex<VecDeque<Bytes>>>,
}

impl SrtConnection {
    /// Creates a new SRT connection.
    ///
    /// Binds a UDP socket to `local_addr` and connects it to `peer_addr`, so
    /// `send()` / `recv()` work without explicit addressing. It also builds a
    /// fresh state machine from `config`. No packets are exchanged here; use
    /// `connect` (caller) or `accept` (listener) to run the handshake.
    ///
    /// # Errors
    ///
    /// Returns an error if binding the UDP socket or connecting it to
    /// `peer_addr` fails.
    pub async fn new(
        local_addr: SocketAddr,
        peer_addr: SocketAddr,
        config: SrtConfig,
    ) -> NetResult<Self> {
        let udp = UdpSocket::bind(local_addr).await?;
        udp.connect(peer_addr).await?;

        // `flow_window` is Copy; read it before `config` moves into the state machine.
        let window = config.flow_window;
        let machine = SrtSocket::new(config);
        // NOTE(review): the receive buffer is seeded with our own send sequence
        // number; confirm this is intended rather than the peer's initial sequence.
        let first_seq = machine.send_seq;

        Ok(Self {
            socket: Arc::new(udp),
            peer_addr,
            state: Arc::new(Mutex::new(machine)),
            congestion: Arc::new(Mutex::new(CongestionControl::new(window, window))),
            loss_list: Arc::new(Mutex::new(LossList::new(1000))),
            recv_buffer: Arc::new(Mutex::new(ReceiveBuffer::new(first_seq, 1000))),
            send_queue: Arc::new(Mutex::new(VecDeque::new())),
            crypto: Arc::new(Mutex::new(None)),
            last_keepalive: Arc::new(Mutex::new(Instant::now())),
            read_buffer: Arc::new(Mutex::new(VecDeque::new())),
        })
    }

    /// Connects to a remote SRT peer (caller mode).
    ///
    /// Sends the caller handshake, then reads datagrams until the state machine
    /// reports the connection as established or `timeout` elapses. Every decoded
    /// packet goes through the state machine and its replies are sent back.
    /// Undecodable datagrams are ignored. Once connected, the crypto context is
    /// initialized from the configured passphrase, if any.
    ///
    /// NOTE(review): the initial handshake is sent once and is not resent if it
    /// is lost, so a single dropped datagram makes this call time out.
    ///
    /// # Errors
    ///
    /// Returns an error if the handshake fails or times out.
    pub async fn connect(&self, timeout: Duration) -> NetResult<()> {
        let opener = self.state.lock().await.generate_caller_handshake();
        self.send_packet(&opener).await?;

        let deadline = Instant::now() + timeout;
        let mut datagram = vec![0u8; 2048];

        loop {
            if Instant::now() > deadline {
                return Err(NetError::timeout("Connection timeout"));
            }

            let budget = deadline.saturating_duration_since(Instant::now());
            let len = match time::timeout(budget, self.socket.recv(&mut datagram)).await {
                // The per-read timer fired; the deadline check above decides whether to give up.
                Err(_elapsed) => continue,
                Ok(Err(io_err)) => return Err(io_err.into()),
                Ok(Ok(len)) => len,
            };

            let packet = match SrtPacket::decode(&datagram[..len]) {
                Ok(packet) => packet,
                Err(_) => continue,
            };

            let replies = self.state.lock().await.process_packet(packet)?;
            for reply in replies {
                self.send_packet(&reply).await?;
            }

            let connected = self.state.lock().await.is_connected();
            if connected {
                self.initialize_crypto().await?;
                return Ok(());
            }
        }
    }

    /// Creates an `SrtConnection` from an already-bound UDP socket that has
    /// received its first inbound packet.
    ///
    /// The socket is connected to `peer_addr` so that `send()` / `recv()`
    /// work without explicit addressing. The first raw UDP datagram
    /// (`first_packet`) is decoded and fed through the SRT state machine
    /// (INDUCTION phase) before returning. Any replies are sent right away,
    /// so callers can immediately proceed with `accept()` for the CONCLUSION
    /// phase. If `first_packet` does not decode, it is ignored.
    ///
    /// # Errors
    ///
    /// Returns an error if `socket.connect` fails, if the first packet
    /// triggers a protocol error, or if sending a reply fails.
    pub async fn from_inbound(
        socket: UdpSocket,
        peer_addr: SocketAddr,
        config: SrtConfig,
        first_packet: Vec<u8>,
    ) -> NetResult<Self> {
        socket.connect(peer_addr).await?;

        // `flow_window` is Copy; read it before `config` moves into the state machine.
        let window = config.flow_window;
        let machine = SrtSocket::new(config);
        let first_seq = machine.send_seq;

        let conn = Self {
            socket: Arc::new(socket),
            peer_addr,
            state: Arc::new(Mutex::new(machine)),
            congestion: Arc::new(Mutex::new(CongestionControl::new(window, window))),
            loss_list: Arc::new(Mutex::new(LossList::new(1000))),
            recv_buffer: Arc::new(Mutex::new(ReceiveBuffer::new(first_seq, 1000))),
            send_queue: Arc::new(Mutex::new(VecDeque::new())),
            crypto: Arc::new(Mutex::new(None)),
            last_keepalive: Arc::new(Mutex::new(Instant::now())),
            read_buffer: Arc::new(Mutex::new(VecDeque::new())),
        };

        // Answer the INDUCTION datagram that already arrived on the pre-bound
        // socket, so the peer is not left waiting for a reply.
        if let Ok(induction) = SrtPacket::decode(&first_packet) {
            let replies = conn.state.lock().await.process_packet(induction)?;
            for reply in replies {
                conn.send_packet(&reply).await?;
            }
        }

        Ok(conn)
    }

    /// Accepts an incoming SRT connection (listener mode).
    ///
    /// Reads datagrams with no time limit and feeds each decodable one to the
    /// state machine, sending back any replies. Returns once the state machine
    /// reports the connection as established, after initializing the crypto
    /// context. Undecodable datagrams are skipped.
    ///
    /// # Errors
    ///
    /// Returns an error if the handshake fails.
    pub async fn accept(&self) -> NetResult<()> {
        let mut datagram = vec![0u8; 2048];

        loop {
            let (len, _from) = self.socket.recv_from(&mut datagram).await?;

            let packet = match SrtPacket::decode(&datagram[..len]) {
                Ok(packet) => packet,
                Err(_) => continue,
            };

            let replies = self.state.lock().await.process_packet(packet)?;
            for reply in replies {
                self.send_packet(&reply).await?;
            }

            let connected = self.state.lock().await.is_connected();
            if connected {
                self.initialize_crypto().await?;
                return Ok(());
            }
        }
    }

    /// Sends data over the SRT connection.
    ///
    /// Wraps `data` in a new data packet (encrypted when a crypto context is
    /// set) and records it in the send queue for retransmission. The packet is
    /// then written to the socket. New data is refused while the queue holds at
    /// least as many packets as the congestion window.
    ///
    /// NOTE(review): no code in this file removes queue entries when an ACK
    /// arrives, so the queue length may not reflect what is actually in flight.
    ///
    /// # Errors
    ///
    /// Returns an error if not connected, if the send queue is full, if
    /// encryption fails, or if the socket send fails.
    pub async fn send(&self, data: &[u8]) -> NetResult<usize> {
        let connected = self.state.lock().await.is_connected();
        if !connected {
            return Err(NetError::invalid_state("Not connected"));
        }

        let window = self.congestion.lock().await.window_size();
        let in_flight = self.send_queue.lock().await.len();
        if in_flight >= window as usize {
            return Err(NetError::buffer("Send queue full"));
        }

        let mut packet = self
            .state
            .lock()
            .await
            .create_data_packet(Bytes::copy_from_slice(data));

        {
            let crypto = self.crypto.lock().await;
            if let Some(ctx) = crypto.as_ref() {
                // The IV is derived from the packet's sequence number.
                let iv = generate_iv(packet.sequence_number);
                packet.payload = ctx.encrypt(&packet.payload, &iv)?;
            }
        }

        let entry = SendQueueEntry {
            seq: packet.sequence_number,
            packet: packet.clone(),
            sent_at: Some(Instant::now()),
            retransmit_count: 0,
        };
        self.send_queue.lock().await.push_back(entry);

        self.send_packet(&SrtPacket::Data(packet)).await?;
        Ok(data.len())
    }

    /// Receives data from the SRT connection.
    ///
    /// # Errors
    ///
    /// Returns an error if not connected or receive fails.
    pub async fn recv(&self, buf: &mut [u8]) -> NetResult<usize> {
        loop {
            // Check read buffer first
            {
                let mut read_buf = self.read_buffer.lock().await;
                if let Some(data) = read_buf.pop_front() {
                    let len = data.len().min(buf.len());
                    buf[..len].copy_from_slice(&data[..len]);
                    return Ok(len);
                }
            }

            // Receive from network
            let mut recv_buf = vec![0u8; 2048];
            let len = self.socket.recv(&mut recv_buf).await?;

            if let Ok(packet) = SrtPacket::decode(&recv_buf[..len]) {
                match packet {
                    SrtPacket::Data(data_packet) => {
                        // Decrypt if needed
                        let payload = if let Some(crypto) = self.crypto.lock().await.as_ref() {
                            let iv = generate_iv(data_packet.sequence_number);
                            crypto.decrypt(&data_packet.payload, &iv)?
                        } else {
                            data_packet.payload.clone()
                        };

                        let copy_len = payload.len().min(buf.len());
                        buf[..copy_len].copy_from_slice(&payload[..copy_len]);

                        // Process packet for sequencing
                        let responses = {
                            let mut state = self.state.lock().await;
                            state.process_packet(SrtPacket::Data(data_packet))?
                        };

                        for response in responses {
                            self.send_packet(&response).await?;
                        }

                        return Ok(copy_len);
                    }
                    SrtPacket::Control(ctrl) => {
                        let responses = {
                            let mut state = self.state.lock().await;
                            state.process_packet(SrtPacket::Control(ctrl))?
                        };

                        for response in responses {
                            self.send_packet(&response).await?;
                        }

                        // No data received, loop again to try receiving data
                    }
                }
            } else {
                return Err(NetError::protocol("Invalid packet"));
            }
        }
    }

    /// Runs background tasks (keepalive, retransmission, etc.).
    ///
    /// Wakes every 10 ms. On each tick it first consults the state machine: it
    /// returns `Ok(())` once the state reports finished, and fails if the peer
    /// has timed out. Otherwise it sends a keepalive when one is due, retransmits
    /// packets whose RTO expired, and NAKs gaps in the receive buffer.
    ///
    /// # Errors
    ///
    /// Returns a timeout error when the peer times out, or propagates the first
    /// error from the keepalive, retransmission or loss-detection steps.
    pub async fn run_background_tasks(&self) -> NetResult<()> {
        let mut ticker = time::interval(Duration::from_millis(10));

        loop {
            ticker.tick().await;

            {
                let machine = self.state.lock().await;
                if machine.state().is_finished() {
                    return Ok(());
                }
                if machine.check_timeout() {
                    return Err(NetError::timeout("Peer timeout"));
                }
            }

            self.send_keepalive_if_needed().await?;
            self.check_retransmissions().await?;
            self.detect_loss().await?;
        }
    }

    /// Closes the connection gracefully.
    ///
    /// Asks the state machine to close. If it produces a shutdown packet, that
    /// packet is sent to the peer; otherwise nothing is sent.
    ///
    /// # Errors
    ///
    /// Returns an error if sending shutdown packet fails.
    pub async fn close(&self) -> NetResult<()> {
        let shutdown = self.state.lock().await.close();
        match shutdown {
            Some(packet) => self.send_packet(&packet).await,
            None => Ok(()),
        }
    }

    /// Address of the remote peer this connection's socket is connected to.
    #[must_use]
    pub const fn peer_addr(&self) -> SocketAddr {
        self.peer_addr
    }

    /// Reports whether the SRT state machine currently considers the connection established.
    pub async fn is_connected(&self) -> bool {
        self.state.lock().await.is_connected()
    }

    /// Current round-trip-time estimate from congestion control, in microseconds.
    pub async fn rtt(&self) -> u32 {
        self.congestion.lock().await.rtt()
    }

    /// Encodes `packet` and writes it to the connected UDP socket.
    ///
    /// # Errors
    ///
    /// Returns an error if the socket send fails.
    async fn send_packet(&self, packet: &SrtPacket) -> NetResult<()> {
        let wire = packet.encode();
        let _bytes_written = self.socket.send(&wire).await?;
        Ok(())
    }

    /// Installs an AES context derived from the configured passphrase, if any.
    ///
    /// Reads `passphrase` and `key_size` from the state machine's config. With
    /// no passphrase configured, the crypto slot is left untouched.
    ///
    /// # Errors
    ///
    /// Propagates any error from `AesContext::from_passphrase`.
    async fn initialize_crypto(&self) -> NetResult<()> {
        // Copy out only the two settings needed instead of cloning the whole config.
        let (passphrase, key_size) = {
            let machine = self.state.lock().await;
            let cfg = machine.config();
            (cfg.passphrase.clone(), cfg.key_size)
        };

        match passphrase {
            Some(secret) => {
                let ctx = AesContext::from_passphrase(&secret, key_size as usize)?;
                *self.crypto.lock().await = Some(ctx);
                Ok(())
            }
            None => Ok(()),
        }
    }

    /// Sends a keepalive control packet if more than one second has passed since the last one.
    ///
    /// The keepalive timestamp is updated only after a successful send, so a
    /// failed send is retried on the next call.
    ///
    /// # Errors
    ///
    /// Returns an error if sending the keepalive fails.
    async fn send_keepalive_if_needed(&self) -> NetResult<()> {
        let mut last_sent = self.last_keepalive.lock().await;
        if last_sent.elapsed() <= Duration::from_secs(1) {
            return Ok(());
        }

        let peer_id = self.state.lock().await.peer_socket_id();
        let probe = ControlPacket::keepalive(peer_id);
        self.send_packet(&SrtPacket::Control(probe)).await?;

        *last_sent = Instant::now();
        Ok(())
    }

    /// Retransmits queued packets whose retransmission timer has expired.
    ///
    /// An entry is resent when the time since its last (re)transmission exceeds
    /// the congestion controller's RTO (in microseconds) and it has been resent
    /// fewer than `MAX_RETRANSMITS` times. Its timer is then reset and its
    /// counter incremented. `on_loss` is reported to congestion control once per
    /// resent packet.
    ///
    /// Entries that have used up their retransmissions are evicted once a further
    /// RTO passes. Otherwise they would stay queued forever, and because `send`
    /// refuses new data while the queue length is at least the congestion window,
    /// the connection would eventually stop accepting data.
    /// NOTE(review): the real release path should be ACK processing, which is not
    /// wired to the send queue in this file.
    ///
    /// # Errors
    ///
    /// Returns an error if sending a retransmitted packet fails.
    async fn check_retransmissions(&self) -> NetResult<()> {
        /// Retransmission attempts per packet before it is given up on.
        const MAX_RETRANSMITS: u32 = 5;

        /// Microseconds elapsed since `t`, saturating at `u32::MAX`.
        /// A plain `as u32` cast would wrap after about 71 minutes and make a
        /// stale entry look freshly sent.
        fn micros_since(t: Instant) -> u32 {
            u32::try_from(t.elapsed().as_micros()).unwrap_or(u32::MAX)
        }

        let rto = {
            let cc = self.congestion.lock().await;
            cc.rto()
        };

        let mut to_retransmit = Vec::new();

        {
            let mut queue = self.send_queue.lock().await;

            for entry in queue.iter_mut() {
                if let Some(sent_at) = entry.sent_at {
                    if micros_since(sent_at) > rto && entry.retransmit_count < MAX_RETRANSMITS {
                        to_retransmit.push(entry.packet.clone());
                        entry.sent_at = Some(Instant::now());
                        entry.retransmit_count += 1;
                    }
                }
            }

            // Drop packets that exhausted their retransmissions and whose final
            // RTO has elapsed, so they stop occupying congestion-window slots.
            queue.retain(|entry| match entry.sent_at {
                Some(sent_at) => {
                    entry.retransmit_count < MAX_RETRANSMITS || micros_since(sent_at) <= rto
                }
                None => true,
            });
        }

        for packet in to_retransmit {
            self.send_packet(&SrtPacket::Data(packet)).await?;

            let mut cc = self.congestion.lock().await;
            cc.on_loss();
        }

        Ok(())
    }

    /// Sends a NAK for any gaps the receive buffer reports and records them in the loss list.
    ///
    /// NOTE(review): nothing in this method clears the reported gaps, so a
    /// persistent gap is presumably NAKed and re-added on every call — confirm
    /// against `ReceiveBuffer::detect_gaps` / `LossList::add`.
    ///
    /// # Errors
    ///
    /// Returns an error if sending the NAK fails.
    async fn detect_loss(&self) -> NetResult<()> {
        let gaps = self.recv_buffer.lock().await.detect_gaps();
        if gaps.is_empty() {
            return Ok(());
        }

        let peer_id = self.state.lock().await.peer_socket_id();
        let nak = ControlPacket::nak(&gaps, peer_id);
        self.send_packet(&SrtPacket::Control(nak)).await?;

        let mut losses = self.loss_list.lock().await;
        for gap in gaps {
            losses.add(gap);
        }

        Ok(())
    }
}

/// Builds the 16-byte IV for a packet from its sequence number.
///
/// Layout: bytes 0..4 hold `seq` in big-endian order; bytes 4..16 are zero.
///
/// NOTE(review): the IV depends only on `seq`, so any two payloads encrypted
/// under the same key with the same sequence number share an IV (for example
/// after the sequence space wraps). Whether that is acceptable depends on the
/// cipher mode `AesContext` uses — confirm.
fn generate_iv(seq: u32) -> [u8; 16] {
    let seq_bytes = seq.to_be_bytes();
    let mut iv = [0u8; 16];
    for (dst, src) in iv.iter_mut().zip(seq_bytes.iter()) {
        *dst = *src;
    }
    iv
}

#[cfg(test)]
mod tests {
    use super::*;

    /// The IV is a pure function of the sequence number.
    #[test]
    fn test_generate_iv() {
        let iv1 = generate_iv(12345);
        let iv2 = generate_iv(12345);
        assert_eq!(iv1, iv2);

        let iv3 = generate_iv(54321);
        assert_ne!(iv1, iv3);
    }

    /// Pins the exact byte layout both peers must agree on: the sequence number
    /// big-endian in bytes 0..4, zeros afterwards.
    #[test]
    fn test_generate_iv_layout() {
        assert_eq!(
            generate_iv(0x0102_0304),
            [1, 2, 3, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
        );
        assert_eq!(generate_iv(0), [0u8; 16]);
    }
}