use std::cmp;
use std::io;
use std::io::{IoSliceMut, Read, Seek};
use std::ops::Sub;
use super::SeekBuffered;
use super::{MediaSource, ReadBytes};
/// Builds the `UnexpectedEof` error returned when the stream ends before a read can be
/// satisfied. Generic over `T` so it can be returned directly from any `io::Result` function.
#[inline(always)]
fn unexpected_eof_error<T>() -> io::Result<T> {
    let kind = io::ErrorKind::UnexpectedEof;
    Err(kind.into())
}
/// Options controlling how a `MediaSourceStream` buffers its source.
#[derive(Copy, Clone, Debug)]
pub struct MediaSourceStreamOptions {
    /// Length of the ring buffer in bytes.
    ///
    /// `MediaSourceStream::new` panics unless this is a power of two and strictly greater than
    /// the maximum fetch size (32 KiB).
    pub buffer_len: usize,
}

impl Default for MediaSourceStreamOptions {
    /// Returns options with a 64 KiB ring buffer.
    fn default() -> Self {
        MediaSourceStreamOptions { buffer_len: 64 * 1024 }
    }
}
/// A buffered reader over a `MediaSource`.
///
/// Bytes are fetched from the inner source into a ring buffer. Consumed bytes stay in the ring
/// until overwritten, so recently read data can be revisited through the `SeekBuffered` methods
/// without touching the inner source.
pub struct MediaSourceStream<'s> {
/// The wrapped source that bytes are fetched from.
inner: Box<dyn MediaSource + 's>,
/// Ring buffer storage. Its length is always a power of two.
ring: Box<[u8]>,
/// `ring.len() - 1`; ANDing an index with it wraps the index around the ring.
ring_mask: usize,
/// Index in `ring` of the next byte to be consumed.
read_pos: usize,
/// Index in `ring` where the next fetched byte will be written. `read_pos == write_pos` means
/// there are no unread bytes.
write_pos: usize,
/// Number of bytes requested by the next fetch. Starts at `MIN_BLOCK_LEN` and doubles after
/// each fetch, up to `MAX_BLOCK_LEN`.
read_block_len: usize,
/// Position in the inner source just past the last fetched byte. It is set to 0 at
/// construction; NOTE(review): this presumably assumes the source is at its start when
/// wrapped — confirm with callers.
abs_pos: u64,
/// Number of bytes fetched since construction or the last `reset`. This bounds how much of the
/// ring holds valid history.
rel_pos: u64,
}
impl<'s> MediaSourceStream<'s> {
    /// Number of bytes requested by the first fetch after construction or a reset.
    const MIN_BLOCK_LEN: usize = 1024;
    /// Upper bound on the number of bytes requested by a single fetch.
    const MAX_BLOCK_LEN: usize = 32 * 1024;

    /// Wraps `source` in a ring-buffered stream configured by `options`.
    ///
    /// # Panics
    ///
    /// Panics if `options.buffer_len` is not a power of two, or is not strictly greater than
    /// `MAX_BLOCK_LEN` (32 KiB).
    pub fn new(source: Box<dyn MediaSource + 's>, options: MediaSourceStreamOptions) -> Self {
        // The ring is indexed with a bit mask, which requires a power-of-two length.
        assert!(options.buffer_len.is_power_of_two(), "buffer_len must be a power of two");
        // A single fetch must never be able to fill the whole ring. Otherwise `write_pos` would
        // wrap onto `read_pos` and the buffer would look empty.
        assert!(
            options.buffer_len > Self::MAX_BLOCK_LEN,
            "buffer_len must be greater than MAX_BLOCK_LEN"
        );

        MediaSourceStream {
            inner: source,
            ring: vec![0; options.buffer_len].into_boxed_slice(),
            ring_mask: options.buffer_len - 1,
            read_pos: 0,
            write_pos: 0,
            read_block_len: Self::MIN_BLOCK_LEN,
            abs_pos: 0,
            rel_pos: 0,
        }
    }

    /// Returns `true` when every fetched byte has been consumed.
    #[inline(always)]
    fn is_buffer_exhausted(&self) -> bool {
        self.read_pos == self.write_pos
    }

    /// Refills the ring from the inner source, but only when the buffer is exhausted.
    ///
    /// Up to `read_block_len` bytes are read starting at `write_pos`. When that region wraps past
    /// the end of the ring, a vectored read is used. After a successful read, `read_block_len`
    /// doubles, capped at `MAX_BLOCK_LEN`.
    ///
    /// Reads failing with `ErrorKind::Interrupted` are retried. Any other error is returned and
    /// leaves the positions and counters unchanged. A zero-length read (end of stream) leaves the
    /// buffer exhausted.
    fn fetch(&mut self) -> io::Result<()> {
        if !self.is_buffer_exhausted() {
            return Ok(());
        }

        let actual_read_len = loop {
            // `vec0` runs from the write position to the end of the ring. `vec1` is the region
            // before it, which is used only when the read has to wrap around.
            let (vec1, vec0) = self.ring.split_at_mut(self.write_pos);

            let result = if vec0.len() >= self.read_block_len {
                self.inner.read(&mut vec0[..self.read_block_len])
            }
            else {
                // `rem < write_pos` because the ring is larger than any block (see `new`).
                let rem = self.read_block_len - vec0.len();
                let ring_vectors = &mut [IoSliceMut::new(vec0), IoSliceMut::new(&mut vec1[..rem])];
                self.inner.read_vectored(ring_vectors)
            };

            match result {
                Ok(len) => break len,
                // Retry rather than surfacing a spurious interruption to byte-level readers
                // such as `read_byte`, which have no retry loop of their own.
                Err(ref err) if err.kind() == io::ErrorKind::Interrupted => continue,
                Err(err) => return Err(err),
            }
        };

        self.write_pos = (self.write_pos + actual_read_len) & self.ring_mask;
        self.abs_pos += actual_read_len as u64;
        self.rel_pos += actual_read_len as u64;

        // Grow the request size geometrically. After a reset, the first reads stay small and
        // cheap, while sustained reading converges on MAX_BLOCK_LEN.
        self.read_block_len = cmp::min(self.read_block_len << 1, Self::MAX_BLOCK_LEN);

        Ok(())
    }

    /// Like `fetch`, but returns an `UnexpectedEof` error if no unread bytes are available
    /// afterwards.
    fn fetch_or_eof(&mut self) -> io::Result<()> {
        self.fetch()?;
        if self.is_buffer_exhausted() {
            return unexpected_eof_error();
        }
        Ok(())
    }

    /// Advances the read position by `len` bytes, wrapping around the ring.
    ///
    /// `len` is not checked against the number of unread bytes. Every caller passes a length it
    /// has already bounded by the available data.
    #[inline(always)]
    fn consume(&mut self, len: usize) {
        self.read_pos = (self.read_pos + len) & self.ring_mask;
    }

    /// Returns the unread bytes that are contiguous in memory from `read_pos`.
    ///
    /// This is everything up to `write_pos`, or up to the end of the ring when the unread data
    /// wraps around. The result is empty when the buffer is exhausted.
    #[inline(always)]
    fn continguous_buf(&self) -> &[u8] {
        if self.write_pos >= self.read_pos {
            &self.ring[self.read_pos..self.write_pos]
        }
        else {
            &self.ring[self.read_pos..]
        }
    }

    /// Discards all buffered data after the inner source has been repositioned to `pos`.
    fn reset(&mut self, pos: u64) {
        self.read_pos = 0;
        self.write_pos = 0;
        self.read_block_len = Self::MIN_BLOCK_LEN;
        self.abs_pos = pos;
        self.rel_pos = 0;
    }
}
impl MediaSource for MediaSourceStream<'_> {
#[inline]
fn is_seekable(&self) -> bool {
self.inner.is_seekable()
}
#[inline]
fn byte_len(&self) -> Option<u64> {
self.inner.byte_len()
}
}
impl io::Read for MediaSourceStream<'_> {
fn read(&mut self, mut buf: &mut [u8]) -> io::Result<usize> {
let read_len = buf.len();
while !buf.is_empty() {
self.fetch()?;
match self.continguous_buf().read(buf) {
Ok(0) => break,
Ok(count) => {
buf = &mut buf[count..];
self.consume(count);
}
Err(ref e) if e.kind() == io::ErrorKind::Interrupted => {}
Err(e) => return Err(e),
}
}
Ok(read_len - buf.len())
}
}
impl io::Seek for MediaSourceStream<'_> {
    /// Seeks the inner source and discards the ring buffer.
    ///
    /// `SeekFrom::Current(0)` is answered from the buffered position without touching the inner
    /// source.
    ///
    /// # Errors
    ///
    /// - Returns `ErrorKind::InvalidInput` if a relative offset overflows `i64` once it is
    ///   adjusted for the buffered-but-unread bytes.
    /// - Otherwise propagates any error from the inner source's seek. In that case the buffer is
    ///   left intact.
    fn seek(&mut self, pos: io::SeekFrom) -> io::Result<u64> {
        let new_pos = match pos {
            io::SeekFrom::Current(0) => return Ok(self.pos()),
            io::SeekFrom::Current(delta_pos) => {
                // The inner source sits past the unread bytes still held in the ring, so the
                // relative offset must be shifted back by that many bytes.
                let unread = self.unread_buffer_len() as i64;
                let delta = delta_pos.checked_sub(unread).ok_or_else(|| {
                    io::Error::new(io::ErrorKind::InvalidInput, "seek offset overflows i64")
                })?;
                self.inner.seek(io::SeekFrom::Current(delta))?
            }
            _ => self.inner.seek(pos)?,
        };

        self.reset(new_pos);
        Ok(new_pos)
    }
}
impl ReadBytes for MediaSourceStream<'_> {
    /// Reads one byte, refilling the ring first when it has been fully consumed.
    ///
    /// Fails with `UnexpectedEof` when the inner source has no more data.
    #[inline(always)]
    fn read_byte(&mut self) -> io::Result<u8> {
        if self.is_buffer_exhausted() {
            self.fetch_or_eof()?;
        }
        let byte = self.ring[self.read_pos];
        self.consume(1);
        Ok(byte)
    }

    /// Reads two bytes.
    ///
    /// They are copied straight out of the ring when both are contiguous in memory; otherwise
    /// this falls back to byte-at-a-time reads.
    fn read_double_bytes(&mut self) -> io::Result<[u8; 2]> {
        let mut out = [0u8; 2];
        match self.continguous_buf().get(..2) {
            Some(head) => {
                out.copy_from_slice(head);
                self.consume(2);
            }
            None => {
                for slot in out.iter_mut() {
                    *slot = self.read_byte()?;
                }
            }
        }
        Ok(out)
    }

    /// Reads three bytes, using the contiguous fast path when possible.
    fn read_triple_bytes(&mut self) -> io::Result<[u8; 3]> {
        let mut out = [0u8; 3];
        match self.continguous_buf().get(..3) {
            Some(head) => {
                out.copy_from_slice(head);
                self.consume(3);
            }
            None => {
                for slot in out.iter_mut() {
                    *slot = self.read_byte()?;
                }
            }
        }
        Ok(out)
    }

    /// Reads four bytes, using the contiguous fast path when possible.
    fn read_quad_bytes(&mut self) -> io::Result<[u8; 4]> {
        let mut out = [0u8; 4];
        match self.continguous_buf().get(..4) {
            Some(head) => {
                out.copy_from_slice(head);
                self.consume(4);
            }
            None => {
                for slot in out.iter_mut() {
                    *slot = self.read_byte()?;
                }
            }
        }
        Ok(out)
    }

    /// Reads as many bytes as possible into `buf`.
    ///
    /// Fails with `UnexpectedEof` only when `buf` is non-empty and nothing could be read.
    fn read_buf(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        let wanted = buf.len();
        match self.read(buf)? {
            0 if wanted > 0 => unexpected_eof_error(),
            n => Ok(n),
        }
    }

    /// Fills `buf` completely, retrying interrupted reads.
    ///
    /// Fails with `UnexpectedEof` if the stream ends first.
    fn read_buf_exact(&mut self, mut buf: &mut [u8]) -> io::Result<()> {
        while !buf.is_empty() {
            let count = match self.read(buf) {
                Ok(0) => return unexpected_eof_error(),
                Ok(n) => n,
                Err(ref e) if e.kind() == io::ErrorKind::Interrupted => continue,
                Err(e) => return Err(e),
            };
            buf = &mut buf[count..];
        }
        Ok(())
    }

    fn scan_bytes_aligned<'a>(
        &mut self,
        _: &[u8],
        _: usize,
        _: &'a mut [u8],
    ) -> io::Result<&'a mut [u8]> {
        unimplemented!();
    }

    /// Skips `count` bytes.
    ///
    /// For seekable sources with `count` of at least twice the ring length, most of the distance
    /// is covered by seeking. The final stretch (at least one ring length) is still consumed
    /// through the buffer, so the ring ends up holding the most recently skipped bytes.
    fn ignore_bytes(&mut self, mut count: u64) -> io::Result<()> {
        let ring_len = self.ring.len() as u64;

        while count >= 2 * ring_len && self.is_seekable() {
            // `SeekFrom::Current` takes an `i64`, so cap each jump accordingly.
            let skip = cmp::min(count, i64::MAX as u64).sub(ring_len);
            self.seek(io::SeekFrom::Current(skip as i64))?;
            count -= skip;
        }

        while count > 0 {
            self.fetch_or_eof()?;
            let available = self.unread_buffer_len() as u64;
            let step = available.min(count);
            self.consume(step as usize);
            count -= step;
        }

        Ok(())
    }

    /// Position of the next byte to be read: the inner position minus the bytes that are
    /// buffered but not yet consumed.
    fn pos(&self) -> u64 {
        let unread = self.unread_buffer_len() as u64;
        self.abs_pos - unread
    }
}
impl SeekBuffered for MediaSourceStream<'_> {
    /// Grows the ring, if needed, so that at least `len` already-read bytes can always be
    /// rewound over with `seek_buffered` / `seek_buffered_rel`.
    ///
    /// The ring must hold at least `len + MAX_BLOCK_LEN + 1` bytes, rounded up to a power of two:
    /// - a fetch can overwrite up to `MAX_BLOCK_LEN` bytes of history;
    /// - one slot is reserved so that a full rewind never makes `read_pos == write_pos`, which
    ///   would look like an empty buffer.
    ///
    /// The ring never shrinks. When it grows, both the unread bytes and the already-read history
    /// are carried over, so previously available seekback is preserved.
    fn ensure_seekback_buffer(&mut self, len: usize) {
        let ring_len = self.ring.len();
        let new_ring_len = (Self::MAX_BLOCK_LEN + len + 1).next_power_of_two();

        if ring_len < new_ring_len {
            // Bytes fetched since the last reset that are still held in the ring. They end at
            // `write_pos` and start `valid_len` bytes earlier, wrapping around the ring.
            let valid_len = cmp::min(ring_len as u64, self.rel_pos) as usize;
            let unread_len = self.unread_buffer_len();
            let start = (self.write_pos + ring_len - valid_len) & self.ring_mask;

            // Lay the valid bytes out oldest-first at the beginning of the new ring.
            let mut new_ring = vec![0; new_ring_len].into_boxed_slice();
            let head_len = cmp::min(valid_len, ring_len - start);
            new_ring[..head_len].copy_from_slice(&self.ring[start..start + head_len]);
            new_ring[head_len..valid_len].copy_from_slice(&self.ring[..valid_len - head_len]);

            self.ring = new_ring;
            self.ring_mask = new_ring_len - 1;
            self.write_pos = valid_len;
            // The unread bytes are the newest `unread_len` bytes of the valid region.
            self.read_pos = valid_len - unread_len;
            // Keep `rel_pos` in step with the ring contents, so that `read_buffer_len` reports
            // exactly the history that was carried over.
            self.rel_pos = valid_len as u64;
        }
    }

    /// Number of fetched bytes that have not been consumed yet.
    fn unread_buffer_len(&self) -> usize {
        if self.write_pos >= self.read_pos {
            self.write_pos - self.read_pos
        }
        else {
            self.write_pos + (self.ring.len() - self.read_pos)
        }
    }

    /// Number of already-consumed bytes that can still be rewound over.
    ///
    /// One ring slot is never counted. Rewinding over it would make `read_pos == write_pos`,
    /// which is indistinguishable from an empty buffer.
    fn read_buffer_len(&self) -> usize {
        let unread_len = self.unread_buffer_len();
        // Compare in u64 so that `rel_pos` cannot be truncated on 32-bit targets.
        let history_cap = (self.ring.len() - 1) as u64;
        cmp::min(history_cap, self.rel_pos) as usize - unread_len
    }

    /// Moves the read position toward absolute position `pos` within the buffered range.
    ///
    /// Returns the resulting position, which is clamped to the bytes held in the ring.
    fn seek_buffered(&mut self, pos: u64) -> u64 {
        let old_pos = self.pos();
        // `seek_buffered_rel` clamps to the buffered range anyway, so saturating a distance that
        // does not fit in an `isize` gives the same result the exact distance would.
        let delta = if pos >= old_pos {
            let dist = pos - old_pos;
            if dist > isize::MAX as u64 { isize::MAX } else { dist as isize }
        }
        else {
            let dist = old_pos - pos;
            if dist > isize::MAX as u64 { -isize::MAX } else { -(dist as isize) }
        };
        self.seek_buffered_rel(delta)
    }

    /// Moves the read position by `delta` bytes.
    ///
    /// The move is clamped to the rewindable history when `delta` is negative, and to the unread
    /// bytes when it is positive. Returns the resulting position.
    fn seek_buffered_rel(&mut self, delta: isize) -> u64 {
        if delta < 0 {
            // `wrapping_neg` gives the correct magnitude even for `isize::MIN`, where `-delta`
            // would overflow.
            let abs_delta = cmp::min(delta.wrapping_neg() as usize, self.read_buffer_len());
            self.read_pos = (self.read_pos + self.ring.len() - abs_delta) & self.ring_mask;
        }
        else if delta > 0 {
            let abs_delta = cmp::min(delta as usize, self.unread_buffer_len());
            self.read_pos = (self.read_pos + abs_delta) & self.ring_mask;
        }
        self.pos()
    }
}
// Unit tests for `MediaSourceStream`. Every test wraps an in-memory `Cursor` over
// deterministically generated bytes.
#[cfg(test)]
mod tests {
use super::{MediaSourceStream, ReadBytes, SeekBuffered};
use std::io::{Cursor, Read};
// Produces `len` pseudo-random bytes from a fixed-seed 32-bit linear congruential generator,
// advancing the generator once per 4-byte chunk. The output is identical on every run, which the
// hard-coded expected values in the tests below rely on.
fn generate_random_bytes(len: usize) -> Box<[u8]> {
let mut lcg: u32 = 0xec57c4bf;
let mut bytes = vec![0; len];
for quad in bytes.chunks_mut(4) {
lcg = lcg.wrapping_mul(1664525).wrapping_add(1013904223);
// Note: `src` is the output slot and `dest` the generator byte (the names are swapped).
// `zip` stops at the shorter side, so a final partial chunk is handled.
for (src, dest) in quad.iter_mut().zip(&lcg.to_le_bytes()) {
*src = *dest;
}
}
bytes.into_boxed_slice()
}
// Interleaves single/double/triple/quad-byte reads with small `ignore_bytes` skips. Each read is
// checked against the corresponding slice of the source data.
#[test]
fn verify_mss_read() {
let data = generate_random_bytes(5 * 96 * 1024);
let ms = Cursor::new(data.clone());
let mut mss = MediaSourceStream::new(Box::new(ms), Default::default());
let mut buf = &data[..];
for byte in &buf[..96 * 1024] {
assert_eq!(*byte, mss.read_byte().unwrap());
}
mss.ignore_bytes(11).unwrap();
buf = &buf[11 + (96 * 1024)..];
for bytes in buf[..2 * 48 * 1024].chunks_exact(2) {
assert_eq!(bytes, &mss.read_double_bytes().unwrap());
}
mss.ignore_bytes(33).unwrap();
buf = &buf[33 + (2 * 48 * 1024)..];
for bytes in buf[..3 * 32 * 1024].chunks_exact(3) {
assert_eq!(bytes, &mss.read_triple_bytes().unwrap());
}
mss.ignore_bytes(55).unwrap();
buf = &buf[55 + (3 * 32 * 1024)..];
for bytes in buf[..4 * 24 * 1024].chunks_exact(4) {
assert_eq!(bytes, &mss.read_quad_bytes().unwrap());
}
}
// `read_to_end` through the `io::Read` impl must return exactly the source bytes.
#[test]
fn verify_mss_read_to_end() {
let data = generate_random_bytes(5 * 96 * 1024);
let ms = Cursor::new(data.clone());
let mut mss = MediaSourceStream::new(Box::new(ms), Default::default());
let mut output: Vec<u8> = Vec::new();
assert_eq!(mss.read_to_end(&mut output).unwrap(), data.len());
assert_eq!(output.into_boxed_slice(), data);
}
// Rewinds within the buffered history, then moves forward again, checking positions, the
// reported history length, and that the same byte is read back.
#[test]
fn verify_mss_seek_buffered() {
let data = generate_random_bytes(1024 * 1024);
let ms = Cursor::new(data);
let mut mss = MediaSourceStream::new(Box::new(ms), Default::default());
assert_eq!(mss.read_buffer_len(), 0);
assert_eq!(mss.unread_buffer_len(), 0);
mss.ignore_bytes(5122).unwrap();
assert_eq!(5122, mss.pos());
assert_eq!(mss.read_buffer_len(), 5122);
// Reads the byte at position 5122 (advancing to 5123), so the rewind below starts at 5123.
let upper = mss.read_byte().unwrap();
assert_eq!(mss.seek_buffered_rel(-1000), 4123);
assert_eq!(mss.pos(), 4123);
assert_eq!(mss.read_buffer_len(), 4123);
assert_eq!(mss.seek_buffered_rel(999), 5122);
assert_eq!(mss.pos(), 5122);
assert_eq!(mss.read_buffer_len(), 5122);
assert_eq!(upper, mss.read_byte().unwrap());
}
// Big-endian reads at offset 2. The expected values presumably were captured from this
// generator's output; they must be regenerated if `generate_random_bytes` changes.
#[test]
fn verify_reading_be() {
let data = generate_random_bytes(1024 * 1024);
let ms = Cursor::new(data);
let mut mss = MediaSourceStream::new(Box::new(ms), Default::default());
mss.ignore_bytes(2).unwrap();
assert_eq!(mss.read_be_f32().unwrap(), -72818055000000000000000000000.0);
assert_eq!(mss.read_be_f64().unwrap(), -0.000000000000011582640453292664);
assert_eq!(mss.read_be_u16().unwrap(), 32624);
assert_eq!(mss.read_be_u24().unwrap(), 6739677);
assert_eq!(mss.read_be_u32().unwrap(), 1569552917);
assert_eq!(mss.read_be_u64().unwrap(), 6091217585348000864);
}
// Little-endian reads at offset 1024; the expected values are tied to the generator output in
// the same way as above.
#[test]
fn verify_reading_le() {
let data = generate_random_bytes(1024 * 1024);
let ms = Cursor::new(data);
let mut mss = MediaSourceStream::new(Box::new(ms), Default::default());
mss.ignore_bytes(1024).unwrap();
assert_eq!(mss.read_f32().unwrap(), -0.00000000000000000000000000048426285);
assert_eq!(mss.read_f64().unwrap(), -6444325820119113.0);
assert_eq!(mss.read_u16().unwrap(), 36195);
assert_eq!(mss.read_u24().unwrap(), 6710386);
assert_eq!(mss.read_u32().unwrap(), 2378776723);
assert_eq!(mss.read_u64().unwrap(), 5170196279331153683);
}
}