use crate::{
types::*,
Result,
};
use bytes::Bytes;
use std::collections::{BTreeMap, VecDeque};
use tokio::sync::{Mutex, RwLock};
use tracing::{debug, warn, instrument};
/// A downloaded segment held by the buffer, together with its placement on the
/// buffer's playback timeline.
#[derive(Debug, Clone)]
pub struct BufferedSegment {
    /// Playlist metadata for the segment (number, URI, duration, ...).
    pub segment: Segment,
    /// Raw payload; its length is what counts against the memory budget.
    pub data: Bytes,
    /// Timeline position (seconds) at which this segment begins.
    pub start_time: f64,
    /// Timeline position (seconds) at which this segment ends
    /// (`start_time` + the segment's duration when added).
    pub end_time: f64,
    /// Set once the segment has been marked played via `consume_segment`.
    pub consumed: bool,
}
/// Tuning knobs for [`BufferManager`].
#[derive(Debug, Clone)]
pub struct BufferConfig {
    /// Seconds of buffered media required before `can_start_playback` reports true.
    pub min_buffer_time: f64,
    /// `needs_data` reports true while the buffer level is below this many seconds.
    pub max_buffer_time: f64,
    /// `is_buffer_healthy` requires at least this many seconds buffered.
    pub rebuffer_threshold: f64,
    /// Payload byte budget; exceeding it makes `add_segment` try to evict
    /// consumed or already-played segments first (best effort).
    pub max_memory_bytes: usize,
    /// Whether prefetching is enabled.
    /// NOTE(review): not read by `BufferManager` itself; presumably consumed by the
    /// fetch layer — confirm.
    pub prefetch_enabled: bool,
    /// How many segments to prefetch.
    /// NOTE(review): not read by `BufferManager` itself — confirm the consumer.
    pub prefetch_count: usize,
}
impl Default for BufferConfig {
fn default() -> Self {
Self {
min_buffer_time: 10.0,
max_buffer_time: 30.0,
rebuffer_threshold: 2.0,
max_memory_bytes: 256 * 1024 * 1024, prefetch_enabled: true,
prefetch_count: 3,
}
}
}
/// Holds downloaded segments on a playback timeline and tracks how much media
/// is buffered ahead of the playhead, how much memory the payloads use, and
/// which segments are queued for fetching.
pub struct BufferManager {
    /// Thresholds and limits this manager enforces.
    config: BufferConfig,
    /// Buffered segments keyed by segment number; the `BTreeMap` keeps them in
    /// ascending number order, which the timeline logic relies on.
    segments: RwLock<BTreeMap<u64, BufferedSegment>>,
    /// Current playhead position in seconds on the buffer timeline.
    playback_position: RwLock<f64>,
    /// Running total of the durations (seconds) of segments stored in `segments`,
    /// including consumed ones that have not been removed yet.
    buffered_duration: RwLock<f64>,
    /// Running total of `data.len()` over segments stored in `segments`.
    memory_used: RwLock<usize>,
    /// FIFO of segments waiting to be fetched (`queue_fetch` / `next_fetch`).
    fetch_queue: Mutex<VecDeque<Segment>>,
}
impl BufferManager {
    /// Creates an empty buffer manager governed by `config`.
    ///
    /// Playback position, buffered duration and memory usage all start at zero.
    pub fn new(config: BufferConfig) -> Self {
        Self {
            config,
            segments: RwLock::new(BTreeMap::new()),
            playback_position: RwLock::new(0.0),
            buffered_duration: RwLock::new(0.0),
            memory_used: RwLock::new(0),
            fetch_queue: Mutex::new(VecDeque::new()),
        }
    }

    /// Adds a downloaded segment and its payload to the buffer.
    ///
    /// Timeline placement:
    /// - a segment number that is already buffered is replaced in place, keeps its
    ///   existing `start_time`, and has its old bytes/duration removed from the totals;
    /// - otherwise the segment starts where the highest-numbered buffered segment ends;
    /// - when the buffer is empty it starts at the current playback position (which is
    ///   `0.0` until playback moves, and the seek target after an unbuffered `seek`).
    ///
    /// If accepting the payload would push memory usage above
    /// `config.max_memory_bytes`, consumed / already-played segments are evicted
    /// first. Eviction is best effort: the segment is added even if not enough
    /// memory could be freed.
    ///
    /// # Errors
    /// Propagates any error returned by eviction (currently none).
    #[instrument(skip(self, data))]
    pub async fn add_segment(&self, segment: Segment, data: Bytes) -> Result<()> {
        let number = segment.number;
        let segment_duration = segment.duration.as_secs_f64();
        let segment_size = data.len();

        let projected = (*self.memory_used.read().await).saturating_add(segment_size);
        if projected > self.config.max_memory_bytes {
            // Only the overflow beyond the budget has to be reclaimed.
            self.evict_segments(projected - self.config.max_memory_bytes)
                .await?;
        }

        let playback_pos = *self.playback_position.read().await;

        // Keep the write lock from choosing the start time through the insert so two
        // concurrent callers cannot both append after the same "last" segment.
        let mut segments = self.segments.write().await;
        let start_time = match segments.get(&number) {
            Some(existing) => existing.start_time,
            None => segments
                .values()
                .next_back()
                .map_or(playback_pos, |last| last.end_time),
        };
        let replaced = segments.insert(
            number,
            BufferedSegment {
                segment,
                data,
                start_time,
                end_time: start_time + segment_duration,
                consumed: false,
            },
        );

        // A replaced entry was already counted; back it out so the totals describe
        // only what is actually stored.
        let (old_bytes, old_secs) = replaced.map_or((0, 0.0), |old| {
            (old.data.len(), old.segment.duration.as_secs_f64())
        });
        {
            let mut memory = self.memory_used.write().await;
            *memory = memory.saturating_sub(old_bytes) + segment_size;
        }
        let buffer_level = {
            let mut duration = self.buffered_duration.write().await;
            *duration += segment_duration - old_secs;
            *duration
        };
        drop(segments);

        debug!(
            segment = number,
            duration = segment_duration,
            buffer_level = buffer_level,
            "Segment added to buffer"
        );
        Ok(())
    }

    /// Returns a copy of the lowest-numbered segment that is not consumed and ends
    /// after the current playback position, if any.
    pub async fn get_next_segment(&self) -> Option<BufferedSegment> {
        let playback_pos = *self.playback_position.read().await;
        let segments = self.segments.read().await;
        segments
            .values()
            .find(|s| !s.consumed && s.end_time > playback_pos)
            .cloned()
    }

    /// Returns a copy of the segment whose `[start_time, end_time)` interval contains
    /// `time`, whether or not it has been consumed.
    pub async fn get_segment_at(&self, time: f64) -> Option<BufferedSegment> {
        let segments = self.segments.read().await;
        segments
            .values()
            .find(|s| time >= s.start_time && time < s.end_time)
            .cloned()
    }

    /// Marks segment `sequence` as consumed; unknown numbers are ignored.
    pub async fn consume_segment(&self, sequence: u64) {
        let mut segments = self.segments.write().await;
        if let Some(segment) = segments.get_mut(&sequence) {
            segment.consumed = true;
        }
    }

    /// Moves the playhead to `position` and drops consumed segments that ended more
    /// than 10 s before it.
    pub async fn update_position(&self, position: f64) {
        *self.playback_position.write().await = position;
        self.cleanup_consumed(position).await;
    }

    /// Seconds of unconsumed media buffered ahead of the playhead.
    ///
    /// A segment straddling the playhead contributes only its remaining part.
    pub async fn buffer_level(&self) -> f64 {
        let playback_pos = *self.playback_position.read().await;
        let segments = self.segments.read().await;
        Self::level_at(&segments, playback_pos)
    }

    /// True when at least `config.rebuffer_threshold` seconds are buffered.
    pub async fn is_buffer_healthy(&self) -> bool {
        self.buffer_level().await >= self.config.rebuffer_threshold
    }

    /// True while fewer than `config.max_buffer_time` seconds are buffered.
    pub async fn needs_data(&self) -> bool {
        self.buffer_level().await < self.config.max_buffer_time
    }

    /// True once at least `config.min_buffer_time` seconds are buffered.
    pub async fn can_start_playback(&self) -> bool {
        self.buffer_level().await >= self.config.min_buffer_time
    }

    /// Contiguous `(start, end)` spans covered by unconsumed segments, in segment
    /// number order. Neighbours whose boundaries differ by less than 0.1 s are merged.
    pub async fn buffered_ranges(&self) -> Vec<(f64, f64)> {
        let segments = self.segments.read().await;
        Self::ranges_of(&segments)
    }

    /// Moves the playhead to `position`.
    ///
    /// Returns `Ok(true)` when a buffered segment covers `position`. Otherwise the
    /// buffer and fetch queue are cleared and `Ok(false)` is returned.
    pub async fn seek(&self, position: f64) -> Result<bool> {
        *self.playback_position.write().await = position;
        let is_buffered = self.get_segment_at(position).await.is_some();
        if !is_buffered {
            self.clear().await;
        }
        Ok(is_buffered)
    }

    /// Drops every buffered segment, resets the duration/memory totals and empties the
    /// fetch queue. The playback position is left unchanged.
    pub async fn clear(&self) {
        let mut segments = self.segments.write().await;
        segments.clear();
        *self.buffered_duration.write().await = 0.0;
        *self.memory_used.write().await = 0;
        let mut queue = self.fetch_queue.lock().await;
        queue.clear();
        debug!("Buffer cleared");
    }

    /// Best-effort eviction of at least `needed_bytes` of payload.
    ///
    /// Walks segments in ascending number order and removes those that are consumed
    /// or that ended more than 5 s before the playhead. It stops once enough bytes
    /// have been selected. Unconsumed segments that are ahead of the playhead, or that
    /// ended within 5 s of it, are never evicted. Because of that, less may be freed
    /// than requested; that case is logged rather than returned as an error.
    ///
    /// # Errors
    /// Currently always returns `Ok(())`.
    async fn evict_segments(&self, needed_bytes: usize) -> Result<()> {
        let playback_pos = *self.playback_position.read().await;
        // Lock order segments -> memory -> duration, shared with `cleanup_consumed`.
        let mut segments = self.segments.write().await;
        let mut memory = self.memory_used.write().await;
        let mut duration = self.buffered_duration.write().await;

        let mut freed = 0;
        let mut to_remove = Vec::new();
        for (&seq, segment) in segments.iter() {
            if freed >= needed_bytes {
                break;
            }
            if segment.consumed || segment.end_time < playback_pos - 5.0 {
                to_remove.push(seq);
                freed += segment.data.len();
            }
        }

        for seq in to_remove {
            if let Some(segment) = segments.remove(&seq) {
                *memory = memory.saturating_sub(segment.data.len());
                *duration -= segment.segment.duration.as_secs_f64();
                debug!(segment = seq, "Evicted segment from buffer");
            }
        }

        if freed < needed_bytes {
            warn!(
                needed = needed_bytes,
                freed = freed,
                "Could not free enough memory"
            );
        }
        Ok(())
    }

    /// Removes consumed segments that ended more than 10 s before `playback_pos`,
    /// updating the memory and duration totals accordingly.
    async fn cleanup_consumed(&self, playback_pos: f64) {
        let threshold = playback_pos - 10.0;
        let mut segments = self.segments.write().await;
        let mut memory = self.memory_used.write().await;
        let mut duration = self.buffered_duration.write().await;
        segments.retain(|_, s| {
            let stale = s.consumed && s.end_time < threshold;
            if stale {
                *memory = memory.saturating_sub(s.data.len());
                *duration -= s.segment.duration.as_secs_f64();
            }
            !stale
        });
    }

    /// Returns a consistent snapshot of the buffer state.
    ///
    /// Every segment-derived figure comes from a single read of the segment map.
    /// The original implementation re-acquired that read lock while still holding it.
    /// tokio's `RwLock` is fair, so a writer queued in between would deadlock.
    pub async fn stats(&self) -> BufferStats {
        let playback_position = *self.playback_position.read().await;
        let segments = self.segments.read().await;
        let memory_used = *self.memory_used.read().await;
        BufferStats {
            segment_count: segments.len(),
            buffer_level: Self::level_at(&segments, playback_position),
            memory_used,
            buffered_ranges: Self::ranges_of(&segments),
            playback_position,
        }
    }

    /// Appends `segments` to the back of the fetch queue, preserving their order.
    pub async fn queue_fetch(&self, segments: Vec<Segment>) {
        self.fetch_queue.lock().await.extend(segments);
    }

    /// Pops the oldest queued segment, if any.
    pub async fn next_fetch(&self) -> Option<Segment> {
        self.fetch_queue.lock().await.pop_front()
    }

    /// Seconds of unconsumed media in `segments` that lies after `playback_pos`.
    fn level_at(segments: &BTreeMap<u64, BufferedSegment>, playback_pos: f64) -> f64 {
        segments
            .values()
            .filter(|s| !s.consumed && s.end_time > playback_pos)
            .fold(0.0, |acc, s| {
                acc + (s.end_time - s.start_time.max(playback_pos))
            })
    }

    /// Merges unconsumed segments into contiguous `(start, end)` ranges; a boundary gap
    /// smaller than 0.1 s counts as contiguous.
    fn ranges_of(segments: &BTreeMap<u64, BufferedSegment>) -> Vec<(f64, f64)> {
        let mut ranges: Vec<(f64, f64)> = Vec::new();
        for s in segments.values().filter(|s| !s.consumed) {
            if let Some(last) = ranges.last_mut() {
                if (s.start_time - last.1).abs() < 0.1 {
                    last.1 = s.end_time;
                    continue;
                }
            }
            ranges.push((s.start_time, s.end_time));
        }
        ranges
    }
}
/// Point-in-time snapshot of buffer state, as returned by `BufferManager::stats`.
#[derive(Debug, Clone)]
pub struct BufferStats {
    /// Number of entries in the segment map, including consumed ones not yet removed.
    pub segment_count: usize,
    /// Seconds of unconsumed media buffered ahead of the playhead.
    pub buffer_level: f64,
    /// Payload bytes currently accounted as stored.
    pub memory_used: usize,
    /// Contiguous `(start, end)` spans (seconds) covered by unconsumed segments.
    pub buffered_ranges: Vec<(f64, f64)>,
    /// Playhead position in seconds on the buffer timeline.
    pub playback_position: f64,
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;
    use url::Url;

    /// Builds a 4-second segment numbered `num` with a distinct URI.
    fn create_test_segment(num: u64) -> Segment {
        let uri = Url::parse(&format!("https://example.com/seg{}.ts", num)).unwrap();
        Segment {
            number: num,
            uri,
            duration: Duration::from_secs(4),
            byte_range: None,
            encryption: None,
            discontinuity_sequence: 0,
            program_date_time: None,
        }
    }

    /// Returns a default-configured buffer holding segments `1..=count`, each with a
    /// 1 KiB zeroed payload.
    async fn buffer_with_segments(count: u64) -> BufferManager {
        let manager = BufferManager::new(BufferConfig::default());
        for num in 1..=count {
            let payload = Bytes::from(vec![0u8; 1024]);
            manager
                .add_segment(create_test_segment(num), payload)
                .await
                .unwrap();
        }
        manager
    }

    #[tokio::test]
    async fn test_add_segment() {
        let buffer = buffer_with_segments(1).await;
        assert_eq!(buffer.buffer_level().await, 4.0);
    }

    #[tokio::test]
    async fn test_buffer_level() {
        let buffer = buffer_with_segments(5).await;
        assert_eq!(buffer.buffer_level().await, 20.0);

        // With the playhead at 8 s, only segments 3..=5 (8 s..20 s) remain ahead.
        buffer.update_position(8.0).await;
        assert!((buffer.buffer_level().await - 12.0).abs() < 0.1);
    }

    #[tokio::test]
    async fn test_seek_buffered() {
        let buffer = buffer_with_segments(5).await;

        // 10 s falls inside segment 3 (8 s..12 s).
        assert!(buffer.seek(10.0).await.unwrap());

        // 100 s is far beyond the 20 s of buffered media.
        assert!(!buffer.seek(100.0).await.unwrap());
    }
}