use std::collections::VecDeque;
use web_time::Duration;
use crate::statistics::StatisticsCalculator;
use crate::{AudioPacket, Result};
/// Tuning knobs for the span-based "smart flush" performed by `PacketBuffer`.
///
/// A smart flush is considered when the timestamp span of the buffered packets exceeds
/// `max(target_level_threshold_ms, target_level_ms) * target_level_multiplier`.
#[derive(Clone, Debug)]
pub struct SmartFlushConfig {
    /// Lower bound, in milliseconds, applied to the caller's target level before scaling.
    pub target_level_threshold_ms: u32,
    /// Factor applied to the bounded target level to obtain the flush threshold.
    pub target_level_multiplier: u32,
}

impl Default for SmartFlushConfig {
    /// Defaults: a 500 ms threshold floor and a multiplier of 3.
    fn default() -> Self {
        let target_level_multiplier = 3;
        let target_level_threshold_ms = 500;
        Self {
            target_level_threshold_ms,
            target_level_multiplier,
        }
    }
}
/// Status codes reported by `PacketBuffer` operations.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum BufferReturnCode {
    /// The operation completed without a (partial) flush being reported.
    Ok,
    /// Presumably: a full flush was performed.
    /// NOTE(review): no visible `PacketBuffer` method returns this — confirm usage.
    Flushed,
    /// Some of the oldest packets were removed by `partial_flush`.
    PartialFlush,
    /// Presumably: a requested packet was not present.
    /// NOTE(review): no visible `PacketBuffer` method returns this — confirm usage.
    NotFound,
    /// The operation found no packets to act on.
    BufferEmpty,
    /// Presumably: a packet was rejected as malformed.
    /// NOTE(review): no visible `PacketBuffer` method returns this — confirm usage.
    InvalidPacket,
}
/// Buffer of audio packets kept in ascending RTP `header.timestamp` order.
///
/// Insertion discards stale packets, trims the buffer when it grows too long
/// (see `SmartFlushConfig`) and drops duplicates.
#[derive(Debug)]
pub struct PacketBuffer {
    /// Capacity limit; reaching it triggers a partial (and, if needed, full) flush on insert.
    max_packets: usize,
    /// Buffered packets, smallest timestamp at the front.
    buffer: VecDeque<AudioPacket>,
    /// Parameters of the span-based smart flush.
    smart_flush_config: SmartFlushConfig,
    /// Packets older than this are discarded at the start of each insertion.
    max_packet_age: Duration,
}
impl PacketBuffer {
    /// Creates an empty buffer holding at most `max_packets` packets, using the default
    /// [`SmartFlushConfig`] and a maximum packet age of 2 seconds.
    pub fn new(max_packets: usize) -> Self {
        Self::with_config(max_packets, SmartFlushConfig::default())
    }

    /// Creates an empty buffer holding at most `max_packets` packets with a custom
    /// smart-flush configuration.
    ///
    /// The maximum packet age starts at 2 seconds; change it with
    /// [`PacketBuffer::set_max_packet_age`].
    pub fn with_config(max_packets: usize, smart_flush_config: SmartFlushConfig) -> Self {
        Self {
            max_packets,
            buffer: VecDeque::with_capacity(max_packets),
            smart_flush_config,
            max_packet_age: Duration::from_secs(2),
        }
    }

    /// Sets the age beyond which buffered packets are treated as stale.
    ///
    /// Stale packets are dropped at the start of the next [`PacketBuffer::insert_packet`].
    pub fn set_max_packet_age(&mut self, max_age: Duration) {
        self.max_packet_age = max_age;
    }

    /// Returns `true` when no packets are buffered.
    pub fn is_empty(&self) -> bool {
        self.buffer.is_empty()
    }

    /// Returns the number of buffered packets.
    pub fn len(&self) -> usize {
        self.buffer.len()
    }

    /// Returns the buffer fill level as a percentage of `max_packets`.
    ///
    /// For a buffer created with `max_packets == 0` this returns `0.0` when empty and
    /// `100.0` otherwise, instead of the NaN/infinity a plain division would produce.
    pub fn utilization(&self) -> f32 {
        if self.max_packets == 0 {
            return if self.buffer.is_empty() { 0.0 } else { 100.0 };
        }
        (self.buffer.len() as f32 / self.max_packets as f32) * 100.0
    }

    /// Removes every buffered packet.
    ///
    /// Records one `buffer_flush` in `stats` if anything was removed. Unlike
    /// [`PacketBuffer::partial_flush`], individual packets are not reported as discarded.
    pub fn flush(&mut self, stats: &mut StatisticsCalculator) {
        let flushed_count = self.buffer.len();
        self.buffer.clear();
        if flushed_count > 0 {
            stats.buffer_flush();
            log::debug!("Flushed {flushed_count} packets from buffer");
        }
    }

    /// Drops the oldest packets so that roughly `target_level_ms` of audio remains.
    ///
    /// Packets are kept from the newest end until their summed `duration_ms` reaches
    /// `target_level_ms`. Every older packet is removed and reported to `stats` via
    /// `packet_discarded`. When the buffer is non-empty, at least one packet is always kept.
    /// `_sample_rate` is currently unused.
    ///
    /// Returns [`BufferReturnCode::BufferEmpty`] for an empty buffer,
    /// [`BufferReturnCode::PartialFlush`] when packets were removed, and
    /// [`BufferReturnCode::Ok`] otherwise. This implementation never returns `Err`.
    pub fn partial_flush(
        &mut self,
        target_level_ms: u32,
        _sample_rate: u32,
        stats: &mut StatisticsCalculator,
    ) -> Result<BufferReturnCode> {
        if self.buffer.is_empty() {
            return Ok(BufferReturnCode::BufferEmpty);
        }

        let target_duration = Duration::from_millis(u64::from(target_level_ms));
        let mut kept_duration = Duration::ZERO;
        let mut keep_count = 0;
        for packet in self.buffer.iter().rev() {
            kept_duration += Duration::from_millis(u64::from(packet.duration_ms));
            keep_count += 1;
            if kept_duration >= target_duration {
                break;
            }
        }

        let remove_count = self.buffer.len().saturating_sub(keep_count);
        if remove_count == 0 {
            return Ok(BufferReturnCode::Ok);
        }

        // Copy the age limit out so the drain below only borrows `self.buffer`.
        let max_age = self.max_packet_age;
        for packet in self.buffer.drain(..remove_count) {
            stats.packet_discarded(packet.is_older_than(max_age));
        }
        log::debug!("Partial flush: removed {remove_count} packets, kept {keep_count}");
        Ok(BufferReturnCode::PartialFlush)
    }

    /// Inserts `packet` in timestamp order after running housekeeping.
    ///
    /// Before inserting, this method:
    /// 1. drops packets older than the configured maximum age,
    /// 2. partially flushes down to `target_level_ms` if the buffered span exceeds the
    ///    smart-flush threshold, and
    /// 3. if the buffer is still at `max_packets`, partially flushes again and, if that is
    ///    not enough, performs a full flush.
    ///
    /// A packet with the same timestamp, sequence number and SSRC as one already buffered
    /// is discarded and reported to `stats`. Otherwise the packet is reported to `stats` as
    /// in-order or reordered (by how far from the back it was placed).
    ///
    /// Returns `Ok(BufferReturnCode::Ok)`. Errors from the internal partial flushes are
    /// propagated.
    pub fn insert_packet(
        &mut self,
        packet: AudioPacket,
        stats: &mut StatisticsCalculator,
        target_level_ms: u32,
    ) -> Result<BufferReturnCode> {
        self.discard_old_packets(stats);
        if self.should_smart_flush(target_level_ms) {
            self.partial_flush(target_level_ms, packet.sample_rate, stats)?;
        }
        if self.buffer.len() >= self.max_packets {
            self.partial_flush(target_level_ms, packet.sample_rate, stats)?;
            if self.buffer.len() >= self.max_packets {
                // `flush` already records the flush in `stats`; recording it again here
                // would double-count every overflow.
                self.flush(stats);
                log::warn!("Buffer overflow: performed full flush");
            }
        }

        let insert_pos = self.find_insert_position(&packet);
        if self.is_duplicate(&packet, insert_pos) {
            log::debug!(
                "Discarding duplicate packet: seq={}, ts={}",
                packet.header.sequence_number,
                packet.header.timestamp
            );
            stats.packet_discarded(false);
            return Ok(BufferReturnCode::Ok);
        }

        let expected_pos = self.buffer.len();
        if insert_pos < expected_pos {
            // Clamp instead of truncating so a very deep reorder cannot wrap to a small value.
            let distance = (expected_pos - insert_pos).min(usize::from(u16::MAX)) as u16;
            stats.packet_reordered(distance);
            log::debug!(
                "Reordered packet detected: seq={}, ts={}, insert_pos={}, expected_pos={}, distance={}",
                packet.header.sequence_number,
                packet.header.timestamp,
                insert_pos,
                expected_pos,
                distance
            );
        } else {
            stats.packet_in_order();
        }

        self.buffer.insert(insert_pos, packet);
        let arrival_delay = self.calculate_arrival_delay(insert_pos);
        stats.packet_arrived(arrival_delay);
        Ok(BufferReturnCode::Ok)
    }

    /// Returns the timestamp of the front (smallest-timestamp) packet, if any.
    pub fn peek_next_timestamp(&self) -> Option<u32> {
        self.buffer.front().map(|packet| packet.header.timestamp)
    }

    /// Returns the first buffered packet whose timestamp is `>= timestamp`, if any.
    ///
    /// Because the buffer is kept sorted by timestamp, this is a binary search.
    pub fn peek_next_packet_from_timestamp(&self, timestamp: u32) -> Option<&AudioPacket> {
        let idx = self
            .buffer
            .partition_point(|packet| packet.header.timestamp < timestamp);
        self.buffer.get(idx)
    }

    /// Removes and returns the front (smallest-timestamp) packet, if any.
    pub fn get_next_packet(&mut self) -> Option<AudioPacket> {
        self.buffer.pop_front()
    }

    /// Removes the front packet and reports it to `stats` as discarded.
    ///
    /// Returns [`BufferReturnCode::BufferEmpty`] if there was nothing to remove, otherwise
    /// [`BufferReturnCode::Ok`].
    pub fn discard_next_packet(
        &mut self,
        stats: &mut StatisticsCalculator,
    ) -> Result<BufferReturnCode> {
        if let Some(packet) = self.buffer.pop_front() {
            stats.packet_discarded(packet.is_older_than(self.max_packet_age));
            Ok(BufferReturnCode::Ok)
        } else {
            Ok(BufferReturnCode::BufferEmpty)
        }
    }

    /// Removes every packet whose timestamp is strictly below `timestamp_limit`.
    ///
    /// Each removal is reported to `stats` via `packet_discarded(true)`.
    /// NOTE(review): plain `u32` comparison, not RTP-wraparound aware.
    pub fn discard_old_packets_by_timestamp(
        &mut self,
        timestamp_limit: u32,
        stats: &mut StatisticsCalculator,
    ) {
        let discarded =
            self.discard_matching(stats, |packet| packet.header.timestamp < timestamp_limit);
        if discarded > 0 {
            log::debug!("Discarded {discarded} old packets by timestamp");
        }
    }

    /// Returns the timestamp distance between the oldest and newest buffered packets, in ms.
    ///
    /// The front packet's `sample_rate` converts samples to milliseconds. The sample span is
    /// computed with wrapping subtraction (RTP timestamps are modulo 2^32). The millisecond
    /// conversion is done in `u64` so large spans do not overflow, and the result saturates
    /// at `u32::MAX`.
    ///
    /// Returns 0 for an empty buffer or a zero sample rate (previously a divide-by-zero).
    pub fn get_span_duration_ms(&self) -> u32 {
        let (oldest, newest) = match (self.buffer.front(), self.buffer.back()) {
            (Some(oldest), Some(newest)) => (oldest, newest),
            _ => return 0,
        };
        let sample_rate = oldest.sample_rate;
        if sample_rate == 0 {
            return 0;
        }
        let span_samples = newest
            .header
            .timestamp
            .wrapping_sub(oldest.header.timestamp);
        let span_ms = u64::from(span_samples) * 1000 / u64::from(sample_rate);
        span_ms.min(u64::from(u32::MAX)) as u32
    }

    /// Returns the sum of `duration_ms` over all buffered packets.
    pub fn get_total_content_duration_ms(&self) -> u32 {
        self.buffer.iter().map(|packet| packet.duration_ms).sum()
    }

    /// Returns the sum of `expected_samples()` over all buffered packets.
    pub fn num_samples_in_buffer(&self) -> usize {
        self.buffer
            .iter()
            .map(|packet| packet.expected_samples())
            .sum()
    }

    /// Returns the index at which `packet` keeps the buffer sorted by timestamp.
    ///
    /// This is the index *after* any packets with an equal timestamp (upper bound), so equal
    /// timestamps preserve arrival order.
    /// NOTE(review): ordering uses plain `u32` comparison, so packets after an RTP timestamp
    /// wraparound sort to the front — confirm whether wraparound must be supported.
    fn find_insert_position(&self, packet: &AudioPacket) -> usize {
        let ts = packet.header.timestamp;
        self.buffer
            .partition_point(|existing| existing.header.timestamp <= ts)
    }

    /// Returns `true` if a packet with the same timestamp, sequence number and SSRC is
    /// already buffered.
    ///
    /// `insert_pos` must be the upper bound from `find_insert_position`. All packets with an
    /// equal timestamp therefore sit in the run immediately before it. The whole run is
    /// scanned, not just its last element, so a duplicate is found even when several
    /// buffered packets share the timestamp.
    fn is_duplicate(&self, packet: &AudioPacket, insert_pos: usize) -> bool {
        let ts = packet.header.timestamp;
        let end = insert_pos.min(self.buffer.len());
        self.buffer
            .range(..end)
            .rev()
            .take_while(|existing| existing.header.timestamp == ts)
            .any(|existing| {
                existing.header.sequence_number == packet.header.sequence_number
                    && existing.header.ssrc == packet.header.ssrc
            })
    }

    /// Heuristic arrival-delay value derived only from the insertion index: 10 per
    /// position, so 0 at the front.
    ///
    /// NOTE(review): this is not a measured network delay. The factor 10 presumably stands
    /// for roughly 10 ms per buffered packet — confirm against the statistics consumer.
    fn calculate_arrival_delay(&self, insert_pos: usize) -> i32 {
        let slots = insert_pos.min(i32::MAX as usize) as i32;
        slots.saturating_mul(10)
    }

    /// Returns `true` when the buffered timestamp span exceeds
    /// `max(target_level_threshold_ms, target_level_ms) * target_level_multiplier`.
    ///
    /// The threshold saturates at `u32::MAX` instead of overflowing.
    fn should_smart_flush(&self, target_level_ms: u32) -> bool {
        if self.buffer.is_empty() {
            return false;
        }
        let config = &self.smart_flush_config;
        let flush_threshold = config
            .target_level_threshold_ms
            .max(target_level_ms)
            .saturating_mul(config.target_level_multiplier);
        self.get_span_duration_ms() > flush_threshold
    }

    /// Removes packets older than `max_packet_age`, reporting each via
    /// `packet_discarded(true)`.
    fn discard_old_packets(&mut self, stats: &mut StatisticsCalculator) {
        let max_age = self.max_packet_age;
        let discarded = self.discard_matching(stats, |packet| packet.is_older_than(max_age));
        if discarded > 0 {
            log::debug!("Discarded {discarded} stale packets");
        }
    }

    /// Removes every packet for which `is_stale` returns `true` and returns how many were
    /// removed.
    ///
    /// Each removal calls `stats.packet_discarded(true)`, the same flag that
    /// `partial_flush` passes for packets older than the maximum age.
    fn discard_matching<F>(&mut self, stats: &mut StatisticsCalculator, mut is_stale: F) -> usize
    where
        F: FnMut(&AudioPacket) -> bool,
    {
        let initial_len = self.buffer.len();
        self.buffer.retain(|packet| {
            if is_stale(packet) {
                stats.packet_discarded(true);
                false
            } else {
                true
            }
        });
        initial_len - self.buffer.len()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::packet::{AudioPacket, RtpHeader};

    /// Builds a test packet with the given sequence number, RTP timestamp and duration.
    ///
    /// The remaining constructor arguments are fixed: a 160-byte zero payload and,
    /// presumably, a 16 kHz mono stream. The span test below relies on the 16 kHz rate.
    fn make_packet(seq: u16, ts: u32, duration_ms: u32) -> AudioPacket {
        let header = RtpHeader::new(seq, ts, 12345, 96, false);
        let payload = vec![0; 160];
        AudioPacket::new(header, payload, 16000, 1, duration_ms)
    }

    /// Inserts each `(seq, ts)` pair as a 20 ms packet with a 100 ms target level, in the
    /// order given.
    fn insert_all(
        buffer: &mut PacketBuffer,
        stats: &mut StatisticsCalculator,
        packets: &[(u16, u32)],
    ) {
        for &(seq, ts) in packets {
            buffer
                .insert_packet(make_packet(seq, ts, 20), stats, 100)
                .unwrap();
        }
    }

    #[test]
    fn test_buffer_creation() {
        let fresh = PacketBuffer::new(100);
        assert!(fresh.is_empty());
        assert_eq!(fresh.len(), 0);
        assert_eq!(fresh.utilization(), 0.0);
    }

    #[test]
    fn test_packet_insertion_and_ordering() {
        let mut buffer = PacketBuffer::new(10);
        let mut stats = StatisticsCalculator::new();
        insert_all(&mut buffer, &mut stats, &[(3, 3000), (1, 1000), (2, 2000)]);

        assert_eq!(buffer.len(), 3);
        assert_eq!(buffer.peek_next_timestamp(), Some(1000));
        for expected_ts in [1000, 2000, 3000] {
            let next = buffer.get_next_packet().unwrap();
            assert_eq!(next.header.timestamp, expected_ts);
        }
    }

    #[test]
    fn test_duplicate_detection() {
        let mut buffer = PacketBuffer::new(10);
        let mut stats = StatisticsCalculator::new();
        insert_all(&mut buffer, &mut stats, &[(1, 1000), (1, 1000)]);
        assert_eq!(buffer.len(), 1);
    }

    #[test]
    fn test_buffer_overflow() {
        let mut buffer = PacketBuffer::new(2);
        let mut stats = StatisticsCalculator::new();
        insert_all(&mut buffer, &mut stats, &[(1, 1000), (2, 2000), (3, 3000)]);
        assert!(buffer.len() <= 2);
    }

    #[test]
    fn test_span_duration_calculation() {
        let mut buffer = PacketBuffer::new(10);
        let mut stats = StatisticsCalculator::new();
        insert_all(&mut buffer, &mut stats, &[(1, 0), (2, 320), (3, 640)]);
        // 640 samples between first and last packet at 16 kHz -> 40 ms.
        assert_eq!(buffer.get_span_duration_ms(), 40);
    }
}