use std::cmp::min;
use std::io::{Read, Seek};
use bytes::Bytes;
use flume::{Receiver, Sender, bounded};
use futures::{Stream, StreamExt};
use symphonia::core::io::MediaSource;
use tokio::task::JoinHandle;
use tokio_util::sync::CancellationToken;
/// Boxed stream of byte chunks consumed by the fetcher.
///
/// Each item is either a chunk of [`Bytes`] or the I/O error the stream
/// produced; the stream must be `Send` and `Unpin`.
type ByteStreamType = Box<dyn Stream<Item = std::io::Result<Bytes>> + Send + Unpin>;
/// Blocking `Read` + `Seek` media source backed by an asynchronous byte stream.
///
/// Chunks are delivered by a [`ByteStreamSourceFetcher`] and accumulated in its
/// buffer; reads are served from that buffer at `read_position`.
pub struct ByteStreamSource {
/// Set once the end-of-stream marker (an empty chunk) has been received.
finished: bool,
/// Whether the creator declared the source seekable (see `is_seekable`).
seekable: bool,
/// Total size in bytes, if known; reported by `byte_len` and used for `SeekFrom::End`.
size: Option<u64>,
/// Current byte position, advanced by `read` and set by `seek`.
read_position: usize,
/// Owns the channels, the accumulated buffer and the background fetch task.
fetcher: ByteStreamSourceFetcher,
/// Cancellation token supplied by the creator; checked while reading and
/// also handed (as a clone) to the fetcher.
abort: CancellationToken,
}
/// State shared between the blocking reader and the background task that
/// pulls chunks from the byte stream.
struct ByteStreamSourceFetcher {
/// Byte offset that index 0 of `buffer` corresponds to.
start: u64,
/// End of the requested range, if known.
end: Option<u64>,
/// Bytes received so far, appended by the reader as chunks arrive.
buffer: Vec<u8>,
/// Receiving half of the "ready" channel; the fetch task waits on a clone of it
/// after sending the end-of-stream marker.
ready_receiver: Receiver<()>,
/// Sending half of the "ready" channel; the reader signals it once it has seen
/// the end-of-stream marker.
ready: Sender<()>,
/// Receiving half of the chunk channel, read by the blocking reader.
receiver: Receiver<Bytes>,
/// Sending half of the chunk channel; cloned into the fetch task.
sender: Sender<Bytes>,
/// Handle of the spawned fetch task, or `None` if none is running.
abort_handle: Option<JoinHandle<()>>,
/// Token cancelled by `abort()` to stop the current fetch task.
abort: CancellationToken,
/// Externally supplied token that also stops the fetch task when cancelled.
stream_abort: CancellationToken,
}
#[cfg_attr(feature = "profiling", profiling::all_functions)]
impl ByteStreamSourceFetcher {
pub fn new(
stream: ByteStreamType,
start: u64,
end: Option<u64>,
autostart: bool,
stream_abort: CancellationToken,
) -> Self {
let (tx, rx) = bounded(1);
let (tx_ready, rx_ready) = bounded(1);
let mut fetcher = Self {
start,
end,
buffer: vec![],
ready_receiver: rx_ready,
ready: tx_ready,
receiver: rx,
sender: tx,
abort_handle: None,
abort: CancellationToken::new(),
stream_abort,
};
if autostart {
fetcher.start_fetch(stream);
}
fetcher
}
fn start_fetch(&mut self, mut stream: ByteStreamType) {
let sender = self.sender.clone();
let ready_receiver = self.ready_receiver.clone();
let abort = self.abort.clone();
let stream_abort = self.stream_abort.clone();
let start = self.start;
let end = self.end;
log::debug!("Starting fetch for byte stream with range start={start} end={end:?}");
self.abort_handle = Some(moosicbox_task::spawn(
"audio_decoder: ByteStreamSource Fetcher",
async move {
log::debug!("Fetching byte stream with range start={start} end={end:?}");
while let Some(item) = tokio::select! {
resp = stream.next() => resp,
() = abort.cancelled() => {
log::debug!("Aborted");
None
}
() = stream_abort.cancelled() => {
log::debug!("Stream aborted");
None
}
} {
log::trace!("Received more bytes from stream");
let bytes = item.unwrap();
if let Err(err) = sender.send_async(bytes).await {
log::info!("Aborted byte stream read: {err:?}");
return;
}
}
log::debug!("Finished reading from stream");
if sender.send_async(Bytes::new()).await.is_ok()
&& ready_receiver.recv_async().await.is_err()
{
log::info!("Byte stream read has been aborted");
}
},
));
}
fn abort(&mut self) {
self.abort.cancel();
if let Some(handle) = &self.abort_handle {
log::debug!("Aborting request");
handle.abort();
self.abort_handle = None;
} else {
log::debug!("No join handle for request");
}
self.abort = CancellationToken::new();
}
}
impl Drop for ByteStreamSourceFetcher {
    /// Stops the background fetch task when the fetcher goes away.
    fn drop(&mut self) {
        Self::abort(self);
    }
}
impl ByteStreamSource {
    /// Builds a source that reads from `stream`, starting at position 0.
    ///
    /// * `size` - total size in bytes, if known; also passed to the fetcher as
    ///   the end of its range.
    /// * `autostart_fetch` - forwarded to the fetcher to start fetching immediately.
    /// * `seekable` - whether the source should report itself as seekable.
    /// * `abort` - cancellation token kept by the source and shared with the fetcher.
    #[must_use]
    pub fn new(
        stream: ByteStreamType,
        size: Option<u64>,
        autostart_fetch: bool,
        seekable: bool,
        abort: CancellationToken,
    ) -> Self {
        let fetcher =
            ByteStreamSourceFetcher::new(stream, 0, size, autostart_fetch, abort.clone());

        Self {
            finished: false,
            seekable,
            size,
            read_position: 0,
            fetcher,
            abort,
        }
    }
}
#[cfg_attr(feature = "profiling", profiling::all_functions)]
impl Read for ByteStreamSource {
    /// Fills `buf` with bytes starting at the current read position.
    ///
    /// Bytes already held in the fetcher's buffer are served first. When the
    /// read position is at or beyond the end of the buffer, this blocks on the
    /// chunk channel and appends each received chunk to the buffer until the
    /// position is covered, `buf` is full, the end-of-stream marker (an empty
    /// chunk) arrives, or a cancellation token is observed.
    ///
    /// Returns the number of bytes written; `Ok(0)` for a non-empty `buf`
    /// means no more data is available at the current position.
    ///
    /// # Errors
    ///
    /// * `InvalidInput` if the read position lies before the start of the
    ///   buffered range.
    /// * `BrokenPipe` if the chunk channel is disconnected before any byte was
    ///   written for this call.
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        let write_max = buf.len();
        let mut written = 0;
        let mut read_position = self.read_position;

        while written < write_max {
            let fetcher = &mut self.fetcher;
            let buffer_len = fetcher.buffer.len();
            #[allow(clippy::cast_possible_truncation)]
            let fetcher_start = fetcher.start as usize;

            log::debug!(
                "Read: read_pos[{read_position}] write_max[{write_max}] fetcher_start[{fetcher_start}] buffer_len[{buffer_len}] written[{written}]"
            );

            // Index of `read_position` within the buffer; a position before
            // the buffered range cannot be served.
            let fetcher_buf_start = read_position.checked_sub(fetcher_start).ok_or_else(|| {
                std::io::Error::new(
                    std::io::ErrorKind::InvalidInput,
                    format!(
                        "Read position {read_position} is before the buffered range start {fetcher_start}"
                    ),
                )
            })?;

            if fetcher_buf_start < buffer_len {
                let bytes_to_read_from_buf = buffer_len - fetcher_buf_start;
                log::trace!(
                    "Reading bytes from buffer: {bytes_to_read_from_buf} (max {write_max})"
                );
                // Never copy more than the space still left in `buf`.
                let bytes_to_write = min(bytes_to_read_from_buf, write_max - written);
                buf[written..written + bytes_to_write].copy_from_slice(
                    &fetcher.buffer[fetcher_buf_start..fetcher_buf_start + bytes_to_write],
                );
                written += bytes_to_write;
                read_position += bytes_to_write;
                continue;
            }

            // Nothing buffered at this position and the stream has ended.
            if self.finished {
                break;
            }

            log::trace!("Waiting for bytes...");
            let new_bytes = match fetcher.receiver.recv() {
                Ok(bytes) => bytes,
                Err(err) => {
                    // Report bytes already delivered first; an error return
                    // must mean that nothing was read by this call.
                    if written > 0 {
                        break;
                    }
                    return Err(std::io::Error::new(
                        std::io::ErrorKind::BrokenPipe,
                        format!("Byte stream channel disconnected: {err:?}"),
                    ));
                }
            };

            if fetcher.abort.is_cancelled() || self.abort.is_cancelled() {
                // Fall through to the common exit so the bytes already
                // written are reflected in `read_position`.
                break;
            }

            let len = new_bytes.len();
            log::trace!("Received bytes {len}");

            if len == 0 {
                self.finished = true;
                // Let the fetch task know the end marker was consumed. If
                // nobody is listening any more there is nothing left to do.
                if let Err(err) = fetcher.ready.send(()) {
                    log::debug!("Failed to acknowledge end of byte stream: {err:?}");
                }
                break;
            }

            // Only append here; the next iteration serves the data from the
            // buffer, which keeps positions correct after a forward seek.
            fetcher.buffer.extend_from_slice(&new_bytes);
        }

        self.read_position = read_position;
        Ok(written)
    }
}
#[cfg_attr(feature = "profiling", profiling::all_functions)]
impl Seek for ByteStreamSource {
    /// Moves the read position and returns the new absolute position.
    ///
    /// Only the position is updated; no data is fetched or discarded here.
    /// `SeekFrom::End(n)` resolves to `size + n`, as specified by
    /// [`std::io::SeekFrom`], so negative `n` addresses bytes before the end.
    ///
    /// # Errors
    ///
    /// * `Unsupported` for `SeekFrom::End` when the stream size is unknown.
    /// * `InvalidInput` when the resulting position is negative, overflows,
    ///   or does not fit in `usize`.
    fn seek(&mut self, pos: std::io::SeekFrom) -> std::io::Result<u64> {
        let invalid_seek = |detail: String| {
            std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                format!("Invalid seek: {detail}"),
            )
        };

        let seek_position: usize = match pos {
            std::io::SeekFrom::Start(offset) => {
                usize::try_from(offset).map_err(|_| invalid_seek(offset.to_string()))?
            }
            std::io::SeekFrom::Current(delta) => {
                let target = i64::try_from(self.read_position)
                    .ok()
                    .and_then(|current| current.checked_add(delta))
                    .ok_or_else(|| invalid_seek(format!("{pos:?}")))?;
                usize::try_from(target).map_err(|_| invalid_seek(target.to_string()))?
            }
            std::io::SeekFrom::End(delta) => {
                let size = self.size.ok_or_else(|| {
                    std::io::Error::new(
                        std::io::ErrorKind::Unsupported,
                        "Cannot seek relative to the end of a stream of unknown size",
                    )
                })?;
                let target = i64::try_from(size)
                    .ok()
                    .and_then(|size| size.checked_add(delta))
                    .ok_or_else(|| invalid_seek(format!("{pos:?}")))?;
                usize::try_from(target).map_err(|_| invalid_seek(target.to_string()))?
            }
        };

        log::info!(
            "Seeking: pos[{seek_position}] current=[{}] type[{pos:?}]",
            self.read_position
        );

        self.read_position = seek_position;

        Ok(seek_position as u64)
    }
}
impl MediaSource for ByteStreamSource {
    /// Reports the source as seekable only when it was declared seekable
    /// and its total size is known.
    fn is_seekable(&self) -> bool {
        let (seekable, size) = (self.seekable, self.size);
        log::debug!("seekable={} size={:?}", seekable, size);
        seekable && size.is_some()
    }

    /// Returns the total size in bytes, if known.
    fn byte_len(&self) -> Option<u64> {
        let size = self.size;
        log::debug!("byte_len={:?}", size);
        size
    }
}