use crate::Result;
use std::future::Future;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Instant;
use tokio::net::TcpListener;
use tokio::sync::Semaphore;
use tokio_util::sync::CancellationToken;
use tracing::warn;
/// Upper bound for a single video frame: 8 × 2^20 (8 MiB, presumably bytes — confirm with callers).
pub const MAX_VIDEO_FRAME: usize = 8 << 20;
/// Upper bound for a single audio frame: 2^20 (1 MiB, presumably bytes — confirm with callers).
pub const MAX_AUDIO_FRAME: usize = 1 << 20;
/// Position of one start code located by [`find_nal_start`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct NalStart {
    /// Index in the scanned buffer at which the start code begins.
    pub offset: usize,
    /// Length of the start code itself (3 for `00 00 01`, 4 for `00 00 00 01`).
    pub len: usize,
}
/// Finds the next start code in `buf`, scanning from index `from`.
///
/// Returns the code's position and length, or `None` when no start code is found.
/// The actual byte scanning is delegated to `crate::bytescan::scan_start_code`.
pub fn find_nal_start(buf: &[u8], from: usize) -> Option<NalStart> {
    // NOTE(review): the meaning of the trailing `0` argument is defined in `bytescan`; confirm.
    let (offset, len) = crate::bytescan::scan_start_code(buf, from, 0)?;
    Some(NalStart { offset, len })
}
#[derive(Debug, Default)]
pub struct KeyframeGate {
open: bool,
}
impl KeyframeGate {
pub fn new() -> Self {
Self { open: false }
}
pub fn open(&mut self) {
self.open = true;
}
pub fn is_open(&self) -> bool {
self.open
}
pub fn admit(&mut self, frame_type: crate::FrameType) -> bool {
match frame_type {
crate::FrameType::Audio => true,
crate::FrameType::Key => {
self.open = true;
true
}
crate::FrameType::Delta => self.open,
}
}
}
/// Fixed-window byte budget for ingested data.
///
/// Each call to [`IngestRateLimit::allow`] asks whether `len` more bytes fit in the
/// current window. A window resets once at least one whole second has elapsed since it
/// started. A budget of `0` disables limiting entirely.
#[derive(Debug)]
pub struct IngestRateLimit {
    /// Byte budget per one-second window; `0` means unlimited.
    max_bytes_per_sec: u64,
    /// When the current accounting window began.
    window_start: Instant,
    /// Bytes admitted so far in the current window.
    bytes_in_window: u64,
}

impl IngestRateLimit {
    /// Creates a limiter that admits up to `max_bytes_per_sec` bytes per window.
    ///
    /// Pass `0` to make the limiter unlimited. The first window starts now.
    pub fn new(max_bytes_per_sec: u64) -> Self {
        Self {
            max_bytes_per_sec,
            window_start: Instant::now(),
            bytes_in_window: 0,
        }
    }

    /// Returns `true` and charges `len` bytes if they fit in the current window's budget.
    ///
    /// Rejected chunks are *not* charged. The caller drops them, so they must not consume
    /// budget that later, smaller chunks in the same window could still use.
    pub fn allow(&mut self, len: usize) -> bool {
        if self.max_bytes_per_sec == 0 {
            return true;
        }
        let now = Instant::now();
        if now.duration_since(self.window_start).as_secs() >= 1 {
            self.window_start = now;
            self.bytes_in_window = 0;
        }
        // usize fits in u64 on every supported target; saturate instead of an `as` cast.
        let len = u64::try_from(len).unwrap_or(u64::MAX);
        let charged = self.bytes_in_window.saturating_add(len);
        if charged > self.max_bytes_per_sec {
            return false;
        }
        self.bytes_in_window = charged;
        true
    }
}
/// Accepts TCP connections on `addr` and runs `handle` for each one on its own spawned task.
///
/// At most `max_connections` handlers run at once. The value is clamped to at least one and
/// at most [`Semaphore::MAX_PERMITS`]. A connection that arrives while the limit is reached is
/// logged and its socket dropped immediately; it is not queued. A failed `accept` is logged,
/// and the loop keeps accepting.
///
/// Returns `Ok(())` as soon as `shutdown` is cancelled. Handler tasks that were already
/// spawned are detached: they are neither awaited nor cancelled here. Capture `shutdown`
/// in `handle` if they must stop too.
///
/// # Errors
///
/// Returns an error if binding the listener to `addr` fails.
pub async fn run_tcp_ingest_server<F, Fut>(
    addr: SocketAddr,
    max_connections: usize,
    shutdown: CancellationToken,
    handle: F,
) -> Result<()>
where
    F: Fn(tokio::net::TcpStream, SocketAddr) -> Fut + Send + Sync + 'static,
    Fut: Future<Output = ()> + Send + 'static,
{
    let listener = TcpListener::bind(addr).await?;
    // `Semaphore::new` panics above MAX_PERMITS, so a caller passing e.g. `usize::MAX` to mean
    // "unlimited" would crash at startup. Zero is raised to one so some traffic gets through.
    let permits = max_connections.clamp(1, Semaphore::MAX_PERMITS);
    let limiter = Arc::new(Semaphore::new(permits));
    let handle = Arc::new(handle);
    loop {
        tokio::select! {
            _ = shutdown.cancelled() => return Ok(()),
            accepted = listener.accept() => {
                let (sock, peer) = match accepted {
                    Ok(pair) => pair,
                    Err(e) => {
                        // NOTE(review): persistent errors (e.g. fd exhaustion) return immediately
                        // and make this loop spin; consider a short backoff if tokio's `time`
                        // feature is available.
                        warn!(error = %e, "accept failed");
                        continue;
                    }
                };
                let permit = match Arc::clone(&limiter).try_acquire_owned() {
                    Ok(p) => p,
                    Err(_) => {
                        // `sock` is dropped at the end of this arm, closing the connection.
                        warn!(%peer, "connection limit reached; rejecting");
                        continue;
                    }
                };
                let handle = Arc::clone(&handle);
                tokio::spawn(async move {
                    // The permit lives as long as the task, so the slot is released when the
                    // handler finishes (or the task unwinds).
                    let _permit = permit;
                    handle(sock, peer).await;
                });
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::FrameType;

    /// A four-byte start code at the front, then a three-byte one further along.
    #[test]
    fn nal_start_finds_three_and_four_byte_codes() {
        let stream = [0, 0, 0, 1, 9, 0xF0, 0, 0, 1, 0x65];

        let leading = find_nal_start(&stream, 0).unwrap();
        assert_eq!(leading, NalStart { offset: 0, len: 4 });

        let resume_at = leading.offset + leading.len;
        let trailing = find_nal_start(&stream, resume_at).unwrap();
        assert_eq!(trailing, NalStart { offset: 6, len: 3 });
    }

    /// Buffers without any start code, including the empty one, yield `None`.
    #[test]
    fn nal_start_returns_none_without_a_code() {
        let no_code: &[u8] = &[0x01, 0x02, 0x00, 0x01];
        assert!(find_nal_start(no_code, 0).is_none());
        assert!(find_nal_start(&[], 0).is_none());
    }

    /// The budget is exhausted within a window and restored once the window rolls over.
    #[test]
    fn rate_limit_resets_each_window() {
        let mut limiter = IngestRateLimit::new(100);
        assert!(limiter.allow(60));
        assert!(limiter.allow(40));
        assert!(!limiter.allow(1));

        // Backdate the window start so the next `allow` sees more than a second elapsed.
        limiter.window_start = Instant::now() - std::time::Duration::from_secs(2);
        assert!(limiter.allow(100));
    }

    /// A zero budget disables limiting, even for the largest possible chunk.
    #[test]
    fn rate_limit_zero_is_unlimited() {
        let mut limiter = IngestRateLimit::new(0);
        assert!(limiter.allow(usize::MAX));
    }

    /// Deltas are dropped until a keyframe arrives; audio always passes.
    #[test]
    fn keyframe_gate_holds_deltas_until_idr() {
        let mut gate = KeyframeGate::new();
        assert!(!gate.is_open());

        assert!(!gate.admit(FrameType::Delta));
        assert!(gate.admit(FrameType::Audio));

        assert!(gate.admit(FrameType::Key));
        assert!(gate.is_open());

        assert!(gate.admit(FrameType::Delta));
    }
}