1use crate::{
10 types::*,
11 Result,
12};
13use bytes::Bytes;
14use std::collections::{BTreeMap, VecDeque};
15use tokio::sync::{Mutex, RwLock};
16use tracing::{debug, warn, instrument};
17
18#[derive(Debug, Clone)]
20pub struct BufferedSegment {
21 pub segment: Segment,
23 pub data: Bytes,
25 pub start_time: f64,
27 pub end_time: f64,
29 pub consumed: bool,
31}
32
/// Tuning parameters for [`BufferManager`].
///
/// Time values are in seconds of buffered media ahead of the playhead.
#[derive(Clone, Debug)]
pub struct BufferConfig {
    /// Buffer level required before `can_start_playback` reports `true`.
    pub min_buffer_time: f64,
    /// `needs_data` reports `true` while the buffer level is below this.
    pub max_buffer_time: f64,
    /// `is_buffer_healthy` reports `false` once the level drops below this.
    pub rebuffer_threshold: f64,
    /// Soft cap on total payload bytes; exceeding it triggers eviction.
    pub max_memory_bytes: usize,
    /// NOTE(review): not read by `BufferManager` in this file; presumably
    /// consulted by the fetch scheduler — confirm.
    pub prefetch_enabled: bool,
    /// NOTE(review): not read by `BufferManager` in this file; presumably the
    /// number of segments to prefetch — confirm.
    pub prefetch_count: usize,
}
49
50impl Default for BufferConfig {
51 fn default() -> Self {
52 Self {
53 min_buffer_time: 10.0,
54 max_buffer_time: 30.0,
55 rebuffer_threshold: 2.0,
56 max_memory_bytes: 256 * 1024 * 1024, prefetch_enabled: true,
58 prefetch_count: 3,
59 }
60 }
61}
62
63pub struct BufferManager {
65 config: BufferConfig,
67 segments: RwLock<BTreeMap<u64, BufferedSegment>>,
69 playback_position: RwLock<f64>,
71 buffered_duration: RwLock<f64>,
73 memory_used: RwLock<usize>,
75 fetch_queue: Mutex<VecDeque<Segment>>,
77}
78
79impl BufferManager {
80 pub fn new(config: BufferConfig) -> Self {
82 Self {
83 config,
84 segments: RwLock::new(BTreeMap::new()),
85 playback_position: RwLock::new(0.0),
86 buffered_duration: RwLock::new(0.0),
87 memory_used: RwLock::new(0),
88 fetch_queue: Mutex::new(VecDeque::new()),
89 }
90 }
91
92 #[instrument(skip(self, data))]
94 pub async fn add_segment(&self, segment: Segment, data: Bytes) -> Result<()> {
95 let segment_duration = segment.duration.as_secs_f64();
96 let segment_size = data.len();
97
98 let current_memory = *self.memory_used.read().await;
100 if current_memory + segment_size > self.config.max_memory_bytes {
101 self.evict_segments(segment_size).await?;
103 }
104
105 let segments = self.segments.read().await;
106 let start_time = if let Some((_, last)) = segments.iter().last() {
107 last.end_time
108 } else {
109 0.0
110 };
111 drop(segments);
112
113 let buffered_segment = BufferedSegment {
114 segment: segment.clone(),
115 data,
116 start_time,
117 end_time: start_time + segment_duration,
118 consumed: false,
119 };
120
121 let mut segments = self.segments.write().await;
123 segments.insert(segment.number, buffered_segment);
124
125 *self.buffered_duration.write().await += segment_duration;
127 *self.memory_used.write().await += segment_size;
128
129 debug!(
130 segment = segment.number,
131 duration = segment_duration,
132 buffer_level = *self.buffered_duration.read().await,
133 "Segment added to buffer"
134 );
135
136 Ok(())
137 }
138
139 pub async fn get_next_segment(&self) -> Option<BufferedSegment> {
141 let playback_pos = *self.playback_position.read().await;
142
143 let segments = self.segments.read().await;
144 for (_, segment) in segments.iter() {
145 if !segment.consumed && segment.end_time > playback_pos {
146 return Some(segment.clone());
147 }
148 }
149 None
150 }
151
152 pub async fn get_segment_at(&self, time: f64) -> Option<BufferedSegment> {
154 let segments = self.segments.read().await;
155 for (_, segment) in segments.iter() {
156 if time >= segment.start_time && time < segment.end_time {
157 return Some(segment.clone());
158 }
159 }
160 None
161 }
162
163 pub async fn consume_segment(&self, sequence: u64) {
165 let mut segments = self.segments.write().await;
166 if let Some(segment) = segments.get_mut(&sequence) {
167 segment.consumed = true;
168 }
169 }
170
171 pub async fn update_position(&self, position: f64) {
173 *self.playback_position.write().await = position;
174
175 self.cleanup_consumed(position).await;
177 }
178
179 pub async fn buffer_level(&self) -> f64 {
181 let playback_pos = *self.playback_position.read().await;
182 let segments = self.segments.read().await;
183
184 let mut buffered = 0.0;
185 for (_, segment) in segments.iter() {
186 if segment.end_time > playback_pos && !segment.consumed {
187 let start = segment.start_time.max(playback_pos);
188 buffered += segment.end_time - start;
189 }
190 }
191 buffered
192 }
193
194 pub async fn is_buffer_healthy(&self) -> bool {
196 self.buffer_level().await >= self.config.rebuffer_threshold
197 }
198
199 pub async fn needs_data(&self) -> bool {
201 self.buffer_level().await < self.config.max_buffer_time
202 }
203
204 pub async fn can_start_playback(&self) -> bool {
206 self.buffer_level().await >= self.config.min_buffer_time
207 }
208
209 pub async fn buffered_ranges(&self) -> Vec<(f64, f64)> {
211 let segments = self.segments.read().await;
212 let mut ranges = Vec::new();
213
214 let mut current_start: Option<f64> = None;
215 let mut current_end: f64 = 0.0;
216
217 for (_, segment) in segments.iter() {
218 if !segment.consumed {
219 match current_start {
220 None => {
221 current_start = Some(segment.start_time);
222 current_end = segment.end_time;
223 }
224 Some(_) => {
225 if (segment.start_time - current_end).abs() < 0.1 {
227 current_end = segment.end_time;
229 } else {
230 ranges.push((current_start.unwrap(), current_end));
232 current_start = Some(segment.start_time);
233 current_end = segment.end_time;
234 }
235 }
236 }
237 }
238 }
239
240 if let Some(start) = current_start {
241 ranges.push((start, current_end));
242 }
243
244 ranges
245 }
246
247 pub async fn seek(&self, position: f64) -> Result<bool> {
249 *self.playback_position.write().await = position;
250
251 let is_buffered = self.get_segment_at(position).await.is_some();
253
254 if !is_buffered {
255 self.clear().await;
257 }
258
259 Ok(is_buffered)
260 }
261
262 pub async fn clear(&self) {
264 let mut segments = self.segments.write().await;
265 segments.clear();
266
267 *self.buffered_duration.write().await = 0.0;
268 *self.memory_used.write().await = 0;
269
270 let mut queue = self.fetch_queue.lock().await;
271 queue.clear();
272
273 debug!("Buffer cleared");
274 }
275
276 async fn evict_segments(&self, needed_bytes: usize) -> Result<()> {
278 let playback_pos = *self.playback_position.read().await;
279 let mut segments = self.segments.write().await;
280 let mut memory = self.memory_used.write().await;
281 let mut duration = self.buffered_duration.write().await;
282
283 let mut freed = 0;
284 let mut to_remove = Vec::new();
285
286 for (&seq, segment) in segments.iter() {
288 if freed >= needed_bytes {
289 break;
290 }
291 if segment.consumed || segment.end_time < playback_pos - 5.0 {
292 to_remove.push(seq);
293 freed += segment.data.len();
294 }
295 }
296
297 for seq in to_remove {
299 if let Some(segment) = segments.remove(&seq) {
300 *memory -= segment.data.len();
301 *duration -= segment.segment.duration.as_secs_f64();
302 debug!(segment = seq, "Evicted segment from buffer");
303 }
304 }
305
306 if freed < needed_bytes {
307 warn!(
308 needed = needed_bytes,
309 freed = freed,
310 "Could not free enough memory"
311 );
312 }
313
314 Ok(())
315 }
316
317 async fn cleanup_consumed(&self, playback_pos: f64) {
319 let threshold = playback_pos - 10.0; let mut segments = self.segments.write().await;
322 let mut memory = self.memory_used.write().await;
323 let mut duration = self.buffered_duration.write().await;
324
325 let to_remove: Vec<_> = segments
326 .iter()
327 .filter(|(_, s)| s.consumed && s.end_time < threshold)
328 .map(|(&seq, _)| seq)
329 .collect();
330
331 for seq in to_remove {
332 if let Some(segment) = segments.remove(&seq) {
333 *memory -= segment.data.len();
334 *duration -= segment.segment.duration.as_secs_f64();
335 }
336 }
337 }
338
339 pub async fn stats(&self) -> BufferStats {
341 let segments = self.segments.read().await;
342 let ranges = self.buffered_ranges().await;
343
344 BufferStats {
345 segment_count: segments.len(),
346 buffer_level: self.buffer_level().await,
347 memory_used: *self.memory_used.read().await,
348 buffered_ranges: ranges,
349 playback_position: *self.playback_position.read().await,
350 }
351 }
352
353 pub async fn queue_fetch(&self, segments: Vec<Segment>) {
355 let mut queue = self.fetch_queue.lock().await;
356 for segment in segments {
357 queue.push_back(segment);
358 }
359 }
360
361 pub async fn next_fetch(&self) -> Option<Segment> {
363 let mut queue = self.fetch_queue.lock().await;
364 queue.pop_front()
365 }
366}
367
/// Point-in-time snapshot of buffer state returned by `BufferManager::stats`.
#[derive(Clone, Debug)]
pub struct BufferStats {
    /// Number of segments currently held (consumed or not).
    pub segment_count: usize,
    /// Seconds of unconsumed media ahead of the playback position.
    pub buffer_level: f64,
    /// Total payload bytes accounted to held segments.
    pub memory_used: usize,
    /// Contiguous `(start, end)` ranges, in seconds, of unconsumed media.
    pub buffered_ranges: Vec<(f64, f64)>,
    /// Playback position, in seconds, at snapshot time.
    pub playback_position: f64,
}
377
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;
    use url::Url;

    /// Builds a 4-second test segment with sequence number `num`.
    fn create_test_segment(num: u64) -> Segment {
        let uri = Url::parse(&format!("https://example.com/seg{}.ts", num)).unwrap();
        Segment {
            number: num,
            uri,
            duration: Duration::from_secs(4),
            byte_range: None,
            encryption: None,
            discontinuity_sequence: 0,
            program_date_time: None,
        }
    }

    /// Returns a default-configured buffer holding segments `1..=count`,
    /// each with a 1 KiB zeroed payload.
    async fn filled_buffer(count: u64) -> BufferManager {
        let buffer = BufferManager::new(BufferConfig::default());
        for number in 1..=count {
            let payload = Bytes::from(vec![0u8; 1024]);
            buffer
                .add_segment(create_test_segment(number), payload)
                .await
                .unwrap();
        }
        buffer
    }

    #[tokio::test]
    async fn test_add_segment() {
        let buffer = filled_buffer(1).await;
        assert_eq!(buffer.buffer_level().await, 4.0);
    }

    #[tokio::test]
    async fn test_buffer_level() {
        let buffer = filled_buffer(5).await;
        assert_eq!(buffer.buffer_level().await, 20.0);

        buffer.update_position(8.0).await;
        let remaining = buffer.buffer_level().await;
        assert!((remaining - 12.0).abs() < 0.1);
    }

    #[tokio::test]
    async fn test_seek_buffered() {
        let buffer = filled_buffer(5).await;

        // 10.0 s lies inside segment 3 ([8, 12)).
        assert!(buffer.seek(10.0).await.unwrap());

        // 100.0 s is far past the 20 s of buffered media.
        assert!(!buffer.seek(100.0).await.unwrap());
    }
}
442}