use std::collections::VecDeque;
use std::time::{Duration, Instant};
/// Maximum number of packets a `JitterBuffer` holds. When the buffer is full,
/// `push` evicts the packet at the front of the queue (the lowest sequence
/// number) before inserting the new one.
const MAX_BUFFER_SIZE: usize = 256;
/// A packet held by `JitterBuffer` until its playout time, as returned by `pop`.
#[derive(Debug)]
pub struct BufferedPacket {
/// Sequence number used to order packets (compared with wraparound).
pub sequence: u32,
/// Media timestamp exactly as passed to `JitterBuffer::push`.
pub timestamp: u32,
/// Packet payload, returned untouched.
pub data: Vec<u8>,
/// Instant at which `push` accepted the packet; `pop` releases the packet no
/// earlier than this instant plus the buffer's current delay.
pub received_at: Instant,
}
/// Counters and gauges describing `JitterBuffer` behaviour; read via `stats()`.
#[derive(Debug, Clone, Default)]
pub struct JitterStats {
/// Every call to `push`, including packets that were subsequently rejected.
pub received: u64,
/// Packets handed out by `pop`.
pub played: u64,
/// Packets rejected because their playout time had already passed.
pub dropped_late: u64,
/// Packets rejected by `push` as duplicates (e.g. a sequence number that was
/// already played).
pub dropped_duplicate: u64,
/// Incremented when `pop` skips over missing sequence numbers.
pub concealed: u64,
/// Smoothed interarrival jitter estimate, in microseconds.
pub jitter_us: u32,
/// Number of packets buffered, as of the most recent update.
pub buffer_depth: usize,
}
/// Reorders packets by sequence number and releases them after an adaptive
/// playout delay.
///
/// Packet timestamps are interpreted as milliseconds (see `ts_to_duration`).
/// Jitter is tracked with the RFC 3550 estimator `J += (|D| - J) / 16`, and
/// the playout delay is nudged between `min_delay` and `max_delay` based on
/// that estimate.
pub struct JitterBuffer {
    /// Pending packets, kept sorted by sequence number (wraparound-aware).
    buffer: VecDeque<BufferedPacket>,
    /// Nominal delay; `reset` restores `current_delay` to this value.
    target_delay: Duration,
    /// Lower bound applied when the delay is decreased.
    min_delay: Duration,
    /// Upper bound applied when the delay is increased.
    max_delay: Duration,
    /// Delay added to each packet's arrival time before it may be played.
    current_delay: Duration,
    /// Sequence number `pop` expects next.
    next_seq: u32,
    /// Whether priming has finished and playout has started.
    playing: bool,
    /// Wall-clock instant playout started, paired with the timestamp of the
    /// first packet played. Together they define each later packet's nominal
    /// playout time for the late-packet check.
    play_start: Option<(Instant, u32)>,
    stats: JitterStats,
    /// Smoothed jitter estimate, in microseconds.
    jitter_ema: f64,
    last_arrival: Option<Instant>,
    last_timestamp: Option<u32>,
}

impl JitterBuffer {
    /// Number of packets that must be buffered before playout starts.
    const PRIME_PACKETS: usize = 3;

    /// Creates a buffer whose delay starts at `target_delay` and adapts between
    /// `target_delay / 2` and `target_delay * 3`.
    pub fn new(target_delay: Duration) -> Self {
        Self {
            buffer: VecDeque::with_capacity(64),
            target_delay,
            min_delay: target_delay / 2,
            max_delay: target_delay * 3,
            current_delay: target_delay,
            next_seq: 0,
            playing: false,
            play_start: None,
            stats: JitterStats::default(),
            jitter_ema: 0.0,
            last_arrival: None,
            last_timestamp: None,
        }
    }

    /// Creates a buffer with explicit bounds for the adaptive delay.
    ///
    /// The delay still starts at `target`; it is not clamped to `[min, max]`
    /// here.
    pub fn with_bounds(target: Duration, min: Duration, max: Duration) -> Self {
        Self {
            min_delay: min,
            max_delay: max,
            ..Self::new(target)
        }
    }

    /// Inserts a packet, keeping the buffer ordered by sequence number.
    ///
    /// Returns `false`, and bumps the matching counter, if the packet is a
    /// duplicate (already played, or already waiting in the buffer) or if it
    /// arrived after its nominal playout time. When the buffer is full, the
    /// packet at the front is evicted to make room.
    pub fn push(&mut self, sequence: u32, timestamp: u32, data: Vec<u8>) -> bool {
        let now = Instant::now();
        self.stats.received += 1;
        self.update_jitter(now, timestamp);

        let already_played = self.playing && self.seq_before(sequence, self.next_seq);
        // A second copy of a still-buffered packet would be left behind
        // `next_seq` after the first copy plays and would block `pop` forever.
        let already_buffered = self.buffer.iter().any(|p| p.sequence == sequence);
        if already_played || already_buffered {
            self.stats.dropped_duplicate += 1;
            return false;
        }

        if self.is_late(now, timestamp) {
            self.stats.dropped_late += 1;
            return false;
        }

        if self.buffer.len() >= MAX_BUFFER_SIZE {
            self.buffer.pop_front();
        }
        let packet = BufferedPacket {
            sequence,
            timestamp,
            data,
            received_at: now,
        };
        let pos = self
            .buffer
            .iter()
            .position(|p| self.seq_before(sequence, p.sequence))
            .unwrap_or(self.buffer.len());
        self.buffer.insert(pos, packet);
        self.stats.buffer_depth = self.buffer.len();
        true
    }

    /// Returns the next packet once its playout delay has elapsed.
    ///
    /// Playout starts only after `PRIME_PACKETS` packets are buffered. If the
    /// expected sequence number is missing, `pop` waits until the next buffered
    /// packet is itself due. It then skips the gap, adding each missing
    /// sequence number to `stats.concealed`, and returns that packet.
    pub fn pop(&mut self) -> Option<BufferedPacket> {
        let now = Instant::now();
        if !self.playing {
            if self.buffer.len() < Self::PRIME_PACKETS {
                return None;
            }
            let first = self.buffer.front()?;
            self.next_seq = first.sequence;
            self.play_start = Some((now, first.timestamp));
            self.playing = true;
        }

        // Safeguard: anything whose slot has already passed would otherwise sit
        // at the head forever. `push` rejects such packets, so this should not
        // normally trigger.
        while let Some(head) = self.buffer.front() {
            if !self.seq_before(head.sequence, self.next_seq) {
                break;
            }
            self.buffer.pop_front();
            self.stats.dropped_late += 1;
            self.stats.buffer_depth = self.buffer.len();
        }

        let head = self.buffer.front()?;
        let (head_seq, head_arrival) = (head.sequence, head.received_at);
        // Waiting for the head's own deadline also gives a missing predecessor
        // the same grace period to arrive before it is written off.
        if now < head_arrival + self.current_delay {
            return None;
        }
        if head_seq != self.next_seq {
            // After the sweep above, `head_seq` is strictly ahead of `next_seq`.
            self.stats.concealed += u64::from(head_seq.wrapping_sub(self.next_seq));
        }

        let packet = self.buffer.pop_front()?;
        self.next_seq = packet.sequence.wrapping_add(1);
        self.stats.played += 1;
        self.stats.buffer_depth = self.buffer.len();
        self.adapt_delay();
        Some(packet)
    }

    /// Current statistics snapshot.
    #[inline]
    pub fn stats(&self) -> &JitterStats {
        &self.stats
    }

    /// Delay currently applied before packets are released.
    #[inline]
    pub fn current_delay(&self) -> Duration {
        self.current_delay
    }

    /// Number of packets currently buffered.
    #[inline]
    pub fn depth(&self) -> usize {
        self.buffer.len()
    }

    /// Drops all buffered packets and restarts priming, delay adaptation and
    /// jitter tracking.
    ///
    /// Cumulative counters (`received`, `played`, drops, `concealed`) are kept.
    pub fn reset(&mut self) {
        self.buffer.clear();
        self.playing = false;
        self.play_start = None;
        self.next_seq = 0;
        self.current_delay = self.target_delay;
        // Forget the previous arrival so the pause around a reset is not
        // measured as jitter.
        self.jitter_ema = 0.0;
        self.last_arrival = None;
        self.last_timestamp = None;
        self.stats.jitter_us = 0;
        self.stats.buffer_depth = 0;
    }

    /// Returns `true` if a packet with `timestamp` has missed its nominal
    /// playout time. That time is measured from the instant playout started
    /// and the timestamp of the first packet played.
    fn is_late(&self, now: Instant, timestamp: u32) -> bool {
        if !self.playing {
            return false;
        }
        match self.play_start {
            Some((start, start_ts)) => {
                let elapsed = now.duration_since(start);
                self.ts_to_duration(timestamp, start_ts) + self.current_delay < elapsed
            }
            None => false,
        }
    }

    /// Updates the RFC 3550 interarrival jitter estimate with one arrival.
    fn update_jitter(&mut self, arrival: Instant, timestamp: u32) {
        if let (Some(last_arr), Some(last_ts)) = (self.last_arrival, self.last_timestamp) {
            let arrival_diff_us = arrival.duration_since(last_arr).as_micros() as i64;
            // Timestamps are milliseconds; scale them to microseconds so both
            // deltas share a unit. The i32 cast gives a signed delta, so a
            // reordered packet yields a small negative value rather than ~4e9.
            let ts_diff_us = i64::from(timestamp.wrapping_sub(last_ts) as i32) * 1000;
            let d = (arrival_diff_us - ts_diff_us).abs() as f64;
            self.jitter_ema += (d - self.jitter_ema) / 16.0;
            self.stats.jitter_us = self.jitter_ema as u32;
        }
        self.last_arrival = Some(arrival);
        self.last_timestamp = Some(timestamp);
    }

    /// Nudges `current_delay`. It grows by 5 ms (capped at `max_delay`) when
    /// jitter exceeds 80% of the target delay. It shrinks by 2 ms (floored at
    /// `min_delay`) when jitter is below 30% of the target and more than five
    /// packets are queued.
    fn adapt_delay(&mut self) {
        let jitter_ms = self.jitter_ema / 1000.0;
        let target_ms = self.target_delay.as_millis() as f64;
        if jitter_ms > target_ms * 0.8 {
            let new_delay = self.current_delay.saturating_add(Duration::from_millis(5));
            self.current_delay = new_delay.min(self.max_delay);
        } else if jitter_ms < target_ms * 0.3 && self.buffer.len() > 5 {
            let new_delay = self.current_delay.saturating_sub(Duration::from_millis(2));
            self.current_delay = new_delay.max(self.min_delay);
        }
    }

    /// Wraparound-aware ordering: `a` is before `b` when `a - b` (mod 2^32)
    /// is negative as an `i32`.
    #[inline]
    fn seq_before(&self, a: u32, b: u32) -> bool {
        let diff = a.wrapping_sub(b) as i32;
        diff < 0
    }

    /// Converts the millisecond timestamp distance from `base_ts` to `ts` into
    /// a `Duration`.
    ///
    /// The subtraction wraps unsigned, so a `ts` behind `base_ts` maps far into
    /// the future and such packets are never considered late.
    #[inline]
    fn ts_to_duration(&self, ts: u32, base_ts: u32) -> Duration {
        let diff = ts.wrapping_sub(base_ts);
        Duration::from_millis(u64::from(diff))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Spacing between consecutive packet timestamps used by these tests.
    const FRAME_MS: u32 = 20;

    /// Pushes packet `seq` with timestamp `seq * FRAME_MS` and payload `[seq + 1]`.
    fn push_frame(jb: &mut JitterBuffer, seq: u32) -> bool {
        jb.push(seq, seq * FRAME_MS, vec![seq as u8 + 1])
    }

    #[test]
    fn test_jitter_buffer_basic() {
        let mut jb = JitterBuffer::new(Duration::from_millis(50));
        for seq in 0..3 {
            assert!(push_frame(&mut jb, seq));
        }
        assert_eq!(jb.depth(), 3);
    }

    #[test]
    fn test_jitter_buffer_ordering() {
        let mut jb = JitterBuffer::new(Duration::from_millis(10));
        // Deliberately delivered out of order.
        for &seq in &[2u32, 0, 1] {
            push_frame(&mut jb, seq);
        }
        std::thread::sleep(Duration::from_millis(15));
        let first = jb.pop().expect("head packet should be due after the delay");
        assert_eq!(first.sequence, 0);
    }

    #[test]
    fn test_jitter_buffer_duplicate() {
        let mut jb = JitterBuffer::new(Duration::from_millis(10));
        (0..3u32).for_each(|seq| {
            push_frame(&mut jb, seq);
        });
        std::thread::sleep(Duration::from_millis(15));
        jb.pop();
        let accepted = push_frame(&mut jb, 0);
        assert!(!accepted);
        assert_eq!(jb.stats().dropped_duplicate, 1);
    }

    #[test]
    fn test_jitter_stats() {
        let mut jb = JitterBuffer::new(Duration::from_millis(5));
        for seq in 0..10u32 {
            jb.push(seq, seq * FRAME_MS, vec![seq as u8]);
        }
        let stats = jb.stats();
        assert_eq!(stats.received, 10);
        assert!(stats.buffer_depth > 0);
    }
}