Skip to main content

arcly_stream/protocol/
ingest.rs

1//! Shared ingest utilities for TCP-based protocol handlers.
2//!
3//! Ported from `sc-protocol-ingest`, freed of the `sc-metrics::Metrics::global()`
4//! calls — observability is injected via the engine's [`Observer`], not reached
5//! for here.
6//!
7//! [`Observer`]: crate::Observer
8
9use crate::Result;
10use std::future::Future;
11use std::net::SocketAddr;
12use std::sync::Arc;
13use std::time::Instant;
14use tokio::net::TcpListener;
15use tokio::sync::Semaphore;
16use tokio_util::sync::CancellationToken;
17use tracing::warn;
18
19// ── Frame size constants ────────────────────────────────────────────────────
20
/// Upper bound on the size of a single video frame: 8 MiB.
///
/// Anything larger should be discarded before it is parsed, so that a crafted
/// source cannot exhaust heap memory.
pub const MAX_VIDEO_FRAME: usize = 8 << 20;
24
/// Upper bound on the size of a single audio frame: 1 MiB.
pub const MAX_AUDIO_FRAME: usize = 1 << 20;
27
28// ── Annex-B NAL start-code scanning ──────────────────────────────────────────
29
/// Position and width of one Annex-B start code found in an H.264/H.265
/// bytestream.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct NalStart {
    /// Index of the first (zero) byte of the start code.
    pub offset: usize,
    /// Number of bytes the start code occupies: 3 for `00 00 01`, 4 for
    /// `00 00 00 01`.
    pub len: usize,
}
38
39/// Find the next Annex-B NAL start code (`00 00 01` or `00 00 00 01`) in `buf`
40/// at or after `from`.
41///
42/// Uses `memchr` to skip directly between candidate `0x01` bytes rather than
43/// scanning byte-by-byte — the difference is significant on the ingest hot path
44/// where every incoming H.264 packet is split into NAL units.
45///
46/// ```
47/// use arcly_stream::protocol::find_nal_start;
48///
49/// // 4-byte start code at offset 0, 3-byte start code at offset 7.
50/// let buf = [0, 0, 0, 1, 9, 0xF0, 0, 0, 0, 1, 0x65];
51/// let first = find_nal_start(&buf, 0).unwrap();
52/// assert_eq!((first.offset, first.len), (0, 4));
53/// let second = find_nal_start(&buf, first.offset + first.len).unwrap();
54/// assert_eq!((second.offset, second.len), (6, 4));
55/// ```
56pub fn find_nal_start(buf: &[u8], from: usize) -> Option<NalStart> {
57    let mut search = from;
58    while let Some(rel) = memchr::memchr(0x01, &buf[search..]) {
59        let one = search + rel;
60        // Need at least `00 00` immediately before the `01`.
61        if one >= 2 && buf[one - 1] == 0 && buf[one - 2] == 0 {
62            // Prefer the 4-byte form when a third leading zero is present.
63            if one >= 3 && buf[one - 3] == 0 {
64                return Some(NalStart {
65                    offset: one - 3,
66                    len: 4,
67                });
68            }
69            return Some(NalStart {
70                offset: one - 2,
71                len: 3,
72            });
73        }
74        search = one + 1;
75    }
76    None
77}
78
79// ── KeyframeGate ─────────────────────────────────────────────────────────────
80
/// Holds back delta frames until the first IDR (keyframe) has been seen, so
/// that late-joining or re-subscribing clients always begin at a clean decoder
/// boundary.
///
/// ```
/// use arcly_stream::protocol::KeyframeGate;
/// use arcly_stream::FrameType;
///
/// let mut gate = KeyframeGate::new();
/// // Before any keyframe, a delta frame is rejected:
/// assert!(!gate.admit(FrameType::Delta));
/// // The first keyframe is admitted and opens the gate:
/// assert!(gate.admit(FrameType::Key));
/// // From then on deltas pass:
/// assert!(gate.admit(FrameType::Delta));
/// // Audio passes even through a closed gate (it has no decode dependency on
/// // video IDRs):
/// assert!(KeyframeGate::new().admit(FrameType::Audio));
/// ```
#[derive(Default, Debug)]
pub struct KeyframeGate {
    // `true` once a keyframe has been admitted or the gate was opened by hand.
    open: bool,
}
103
104impl KeyframeGate {
105    /// Create a new gate (initially closed — non-audio frames are held).
106    pub fn new() -> Self {
107        Self { open: false }
108    }
109
110    /// Open the gate unconditionally.
111    pub fn open(&mut self) {
112        self.open = true;
113    }
114
115    /// Whether the gate is currently open.
116    pub fn is_open(&self) -> bool {
117        self.open
118    }
119
120    /// Decide whether a frame of the given type should be admitted, opening the
121    /// gate on the first keyframe. Audio is always admitted.
122    pub fn admit(&mut self, frame_type: crate::FrameType) -> bool {
123        match frame_type {
124            crate::FrameType::Audio => true,
125            crate::FrameType::Key => {
126                self.open = true;
127                true
128            }
129            crate::FrameType::Delta => self.open,
130        }
131    }
132}
133
134// ── IngestRateLimit ──────────────────────────────────────────────────────────
135
/// Per-connection fixed-window ingress rate limiter (bytes/second).
///
/// Bytes are counted against a budget within a one-second window. Once at
/// least a second has passed since the window began, the next recorded call
/// starts a new window with an empty count. The window does not slide
/// continuously.
#[derive(Debug)]
pub struct IngestRateLimit {
    // Byte budget per window; zero disables limiting.
    max_bytes_per_sec: u64,
    // Instant at which the current window began.
    window_start: Instant,
    // Bytes recorded since `window_start`, including bytes from calls that
    // were reported as over budget.
    bytes_in_window: u64,
}
143
144impl IngestRateLimit {
145    pub fn new(max_bytes_per_sec: u64) -> Self {
146        Self {
147            max_bytes_per_sec,
148            window_start: Instant::now(),
149            bytes_in_window: 0,
150        }
151    }
152
153    /// Record `len` ingested bytes; returns `false` if the per-second budget is
154    /// exceeded (the caller should drop the connection or backpressure).
155    pub fn allow(&mut self, len: usize) -> bool {
156        if self.max_bytes_per_sec == 0 {
157            return true; // unlimited
158        }
159        let now = Instant::now();
160        if now.duration_since(self.window_start).as_secs() >= 1 {
161            self.window_start = now;
162            self.bytes_in_window = 0;
163        }
164        self.bytes_in_window = self.bytes_in_window.saturating_add(len as u64);
165        self.bytes_in_window <= self.max_bytes_per_sec
166    }
167}
168
169// ── Generic TCP accept loop ──────────────────────────────────────────────────
170
171/// Generic TCP accept loop shared by RTMP/RTSP handlers.
172///
173/// Binds `addr`, then for each accepted connection spawns `handle` (bounded by
174/// `max_connections` via a semaphore). Runs until `shutdown` is cancelled.
175///
176/// `handle` receives the socket and peer address; it owns parsing the protocol
177/// and forwarding frames to a [`PublishRegistry`](crate::PublishRegistry).
178pub async fn run_tcp_ingest_server<F, Fut>(
179    addr: SocketAddr,
180    max_connections: usize,
181    shutdown: CancellationToken,
182    handle: F,
183) -> Result<()>
184where
185    F: Fn(tokio::net::TcpStream, SocketAddr) -> Fut + Send + Sync + 'static,
186    Fut: Future<Output = ()> + Send + 'static,
187{
188    let listener = TcpListener::bind(addr).await?;
189    let limiter = Arc::new(Semaphore::new(max_connections.max(1)));
190    let handle = Arc::new(handle);
191
192    loop {
193        tokio::select! {
194            _ = shutdown.cancelled() => return Ok(()),
195            accepted = listener.accept() => {
196                let (sock, peer) = match accepted {
197                    Ok(pair) => pair,
198                    Err(e) => { warn!(error = %e, "accept failed"); continue; }
199                };
200                let permit = match Arc::clone(&limiter).try_acquire_owned() {
201                    Ok(p) => p,
202                    Err(_) => {
203                        warn!(%peer, "connection limit reached; rejecting");
204                        continue;
205                    }
206                };
207                let handle = Arc::clone(&handle);
208                tokio::spawn(async move {
209                    let _permit = permit; // released on task completion
210                    handle(sock, peer).await;
211                });
212            }
213        }
214    }
215}
216
#[cfg(test)]
mod tests {
    use super::*;
    use crate::FrameType;

    /// Both start-code widths are located, in buffer order.
    #[test]
    fn nal_start_finds_three_and_four_byte_codes() {
        let buf = [0, 0, 0, 1, 9, 0xF0, 0, 0, 1, 0x65];

        let first = find_nal_start(&buf, 0).unwrap();
        assert_eq!((first.offset, first.len), (0, 4));

        // Resume scanning past the first; next is the 3-byte code at offset 6.
        let resume_at = first.offset + first.len;
        assert_eq!(
            find_nal_start(&buf, resume_at),
            Some(NalStart { offset: 6, len: 3 })
        );
    }

    /// Buffers with no `00 00 01` sequence, including an empty one, yield
    /// nothing.
    #[test]
    fn nal_start_returns_none_without_a_code() {
        assert_eq!(find_nal_start(&[0x01, 0x02, 0x00, 0x01], 0), None);
        assert_eq!(find_nal_start(&[], 0), None);
    }

    /// The budget is enforced within a window and refreshed in the next one.
    #[test]
    fn rate_limit_resets_each_window() {
        let mut rl = IngestRateLimit::new(100);
        assert!(rl.allow(60));
        // Exactly at budget.
        assert!(rl.allow(40));
        // Over budget within the same second.
        assert!(!rl.allow(1));

        // Force the window boundary and confirm the budget refreshes.
        rl.window_start = Instant::now() - std::time::Duration::from_secs(2);
        assert!(rl.allow(100));
    }

    /// A zero budget never rejects anything.
    #[test]
    fn rate_limit_zero_is_unlimited() {
        let mut rl = IngestRateLimit::new(0);
        assert!(rl.allow(usize::MAX));
    }

    /// Deltas are held until a keyframe arrives; audio always passes.
    #[test]
    fn keyframe_gate_holds_deltas_until_idr() {
        let mut gate = KeyframeGate::new();
        assert!(!gate.is_open());
        assert!(!gate.admit(FrameType::Delta));
        // Audio bypasses the gate.
        assert!(gate.admit(FrameType::Audio));
        // The keyframe is admitted and opens the gate.
        assert!(gate.admit(FrameType::Key));
        assert!(gate.is_open());
        assert!(gate.admit(FrameType::Delta));
    }
}