use std::{
collections::VecDeque,
io::SeekFrom,
sync::{
atomic::{AtomicUsize, Ordering},
Arc,
},
task::{Poll, Waker},
};
use anyhow::Context;
use dashmap::DashMap;
use librqbit_core::lengths::{CurrentPiece, Lengths, ValidPieceIndex};
use tokio::io::{AsyncRead, AsyncSeek};
use tracing::{debug, trace};
use crate::{
file_info::FileInfo, spawn_utils::BlockingSpawner, storage::TorrentStorage, ManagedTorrent,
};
use super::{ManagedTorrentHandle, TorrentMetadata};
type StreamId = usize;
const PER_STREAM_BUF_DEFAULT: u64 = 32 * 1024 * 1024;
struct StreamState {
file_id: usize,
file_len: u64,
file_abs_offset: u64,
position: u64,
waker: Option<Waker>,
}
impl StreamState {
fn current_piece(&self, lengths: &Lengths) -> Option<CurrentPiece> {
lengths.compute_current_piece(self.position, self.file_abs_offset)
}
fn queue<'a>(&self, lengths: &'a Lengths) -> impl Iterator<Item = ValidPieceIndex> + 'a {
let start = self.file_abs_offset + self.position;
let end = (start + PER_STREAM_BUF_DEFAULT).min(self.file_abs_offset + self.file_len);
let dpl = lengths.default_piece_length();
let start_id = (start / dpl as u64).try_into().unwrap();
let end_id = end.div_ceil(dpl as u64).try_into().unwrap();
(start_id..end_id).filter_map(|i| lengths.validate_piece_index(i))
}
}
#[derive(Default)]
pub(crate) struct TorrentStreams {
next_stream_id: AtomicUsize,
streams: DashMap<StreamId, StreamState>,
}
impl TorrentStreams {
fn next_id(&self) -> usize {
self.next_stream_id.fetch_add(1, Ordering::Relaxed)
}
fn register_waker(&self, stream_id: StreamId, waker: Waker) {
if let Some(mut s) = self.streams.get_mut(&stream_id) {
let vm = s.value_mut();
vm.waker = Some(waker);
}
}
pub(crate) fn iter_next_pieces<'a>(
&'a self,
lengths: &'a Lengths,
) -> impl Iterator<Item = ValidPieceIndex> + 'a {
struct Interleave<I> {
all: VecDeque<I>,
}
impl<I: Iterator<Item = ValidPieceIndex>> Iterator for Interleave<I> {
type Item = ValidPieceIndex;
fn next(&mut self) -> Option<Self::Item> {
while let Some(mut it) = self.all.pop_front() {
if let Some(piece) = it.next() {
self.all.push_back(it);
return Some(piece);
}
}
None
}
}
let mut all: Vec<_> = self.streams.iter().map(|s| s.queue(lengths)).collect();
use rand::seq::SliceRandom;
all.shuffle(&mut rand::rng());
Interleave { all: all.into() }
}
pub(crate) fn wake_streams_on_piece_completed(
&self,
piece_id: ValidPieceIndex,
lengths: &Lengths,
) {
for mut w in self.streams.iter_mut() {
if w.value().current_piece(lengths).map(|p| p.id) == Some(piece_id) {
if let Some(waker) = w.value_mut().waker.take() {
trace!(
stream_id = *w.key(),
piece_id = piece_id.get(),
"waking stream"
);
waker.wake();
}
}
}
}
fn drop_stream(&self, stream_id: StreamId) -> Option<StreamState> {
trace!(stream_id, "dropping stream");
self.streams.remove(&stream_id).map(|s| s.1)
}
pub(crate) fn streamed_file_ids(&self) -> impl Iterator<Item = usize> + '_ {
self.streams.iter().map(|s| s.value().file_id)
}
}
pub struct FileStream {
torrent: ManagedTorrentHandle,
metadata: Arc<TorrentMetadata>,
streams: Arc<TorrentStreams>,
stream_id: usize,
file_id: usize,
position: u64,
file_len: u64,
file_torrent_abs_offset: u64,
spawner: BlockingSpawner,
}
macro_rules! map_io_err {
($e:expr) => {
$e.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))
};
}
macro_rules! poll_try_io {
($e:expr) => {{
let e = map_io_err!($e);
match e {
Ok(r) => r,
Err(e) => {
debug!("stream error {e:#}");
return Poll::Ready(Err(e));
}
}
}};
}
impl AsyncRead for FileStream {
fn poll_read(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
tbuf: &mut tokio::io::ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
if self.position == self.file_len {
trace!(
stream_id = self.stream_id,
file_id = self.file_id,
"stream completed, EOF"
);
return Poll::Ready(Ok(()));
}
let current = poll_try_io!(self
.metadata
.lengths
.compute_current_piece(self.position, self.file_torrent_abs_offset)
.context("invalid position"));
let have = poll_try_io!(self.torrent.with_chunk_tracker(|ct| {
let have = ct.get_have_pieces().as_slice()[current.id.get() as usize];
if !have {
self.streams
.register_waker(self.stream_id, cx.waker().clone());
}
have
}));
if !have {
trace!(stream_id = self.stream_id, file_id = self.file_id, piece_id = %current.id, "poll pending, not have");
return Poll::Pending;
}
let buf = tbuf.initialize_unfilled();
let file_remaining = self.file_len - self.position;
let bytes_to_read: usize = poll_try_io!((buf.len() as u64)
.min(current.piece_remaining as u64)
.min(file_remaining)
.try_into());
let buf = &mut buf[..bytes_to_read];
trace!(
buflen = buf.len(),
stream_id = self.stream_id,
file_id = self.file_id,
"will write bytes"
);
poll_try_io!(poll_try_io!(self.spawner.spawn_block_in_place(|| {
self.torrent.with_storage_and_file(
self.file_id,
|files, _fi| {
files.pread_exact(self.file_id, self.position, buf)?;
Ok::<_, anyhow::Error>(())
},
&self.metadata,
)
})));
self.as_mut().advance(bytes_to_read as u64);
tbuf.advance(bytes_to_read);
Poll::Ready(Ok(()))
}
}
impl AsyncSeek for FileStream {
fn start_seek(
mut self: std::pin::Pin<&mut Self>,
position: std::io::SeekFrom,
) -> std::io::Result<()> {
let end_i64 = map_io_err!(TryInto::<i64>::try_into(self.file_len))?;
let new_pos: i64 = match position {
SeekFrom::Start(s) => map_io_err!(s.try_into())?,
SeekFrom::End(e) => map_io_err!(TryInto::<i64>::try_into(self.file_len))? + e,
SeekFrom::Current(o) => map_io_err!(TryInto::<i64>::try_into(self.position))? + o,
};
if new_pos < 0 || new_pos > end_i64 {
return Err(std::io::Error::new(
std::io::ErrorKind::InvalidInput,
anyhow::anyhow!("invalid seek"),
));
}
self.as_mut().set_position(map_io_err!(new_pos.try_into())?);
trace!(stream_id = self.stream_id, position = self.position, "seek");
Ok(())
}
fn poll_complete(
self: std::pin::Pin<&mut Self>,
_cx: &mut std::task::Context<'_>,
) -> Poll<std::io::Result<u64>> {
Poll::Ready(Ok(self.position))
}
}
impl Drop for FileStream {
fn drop(&mut self) {
self.streams.drop_stream(self.stream_id);
}
}
impl ManagedTorrent {
fn with_storage_and_file<F, R>(
&self,
file_id: usize,
f: F,
metadata: &TorrentMetadata,
) -> anyhow::Result<R>
where
F: FnOnce(&dyn TorrentStorage, &FileInfo) -> R,
{
self.with_state(|s| {
let files = match s {
crate::ManagedTorrentState::Paused(p) => &*p.files,
crate::ManagedTorrentState::Live(l) => &*l.files,
s => anyhow::bail!("with_storage_and_file: invalid state: {}", s.name()),
};
let fi = metadata.file_infos.get(file_id).context("invalid file")?;
Ok(f(files, fi))
})
}
fn streams(&self) -> anyhow::Result<Arc<TorrentStreams>> {
self.with_state(|s| match s {
crate::ManagedTorrentState::Paused(p) => Ok(p.streams.clone()),
crate::ManagedTorrentState::Live(l) => Ok(l.streams.clone()),
s => anyhow::bail!("streams: invalid state {}", s.name()),
})
}
fn maybe_reconnect_needed_peers_for_file(&self, file_id: usize) -> bool {
if self.is_file_finished(file_id) {
return false;
}
self.with_state(|state| {
if let crate::ManagedTorrentState::Live(l) = &state {
l.reconnect_all_not_needed_peers();
}
});
true
}
fn is_file_finished(&self, file_id: usize) -> bool {
let metadata = self.metadata.load();
let metadata = match metadata.as_ref() {
Some(r) => r,
None => return false,
};
self.with_chunk_tracker(|ct| ct.is_file_finished(&metadata.file_infos[file_id]))
.unwrap_or(false)
}
pub fn stream(self: Arc<Self>, file_id: usize) -> anyhow::Result<FileStream> {
let metadata = self
.metadata
.load_full()
.context("torrent metadata is not resolved")?;
let (fd_len, fd_offset) = self.with_storage_and_file(
file_id,
|_fd, fi| (fi.len, fi.offset_in_torrent),
&metadata,
)?;
let streams = self.streams()?;
let s = FileStream {
stream_id: streams.next_id(),
streams: streams.clone(),
file_id,
position: 0,
file_len: fd_len,
file_torrent_abs_offset: fd_offset,
torrent: self,
spawner: BlockingSpawner::default(),
metadata,
};
s.torrent.maybe_reconnect_needed_peers_for_file(file_id);
streams.streams.insert(
s.stream_id,
StreamState {
file_id,
position: 0,
waker: None,
file_len: fd_len,
file_abs_offset: fd_offset,
},
);
debug!(stream_id = s.stream_id, file_id, "started stream");
Ok(s)
}
}
impl FileStream {
pub fn position(&self) -> u64 {
self.position
}
fn advance(&mut self, diff: u64) {
self.set_position(self.position + diff)
}
fn set_position(&mut self, new_pos: u64) {
self.position = new_pos;
self.streams
.streams
.get_mut(&self.stream_id)
.unwrap()
.value_mut()
.position = new_pos;
}
pub fn len(&self) -> u64 {
self.file_len
}
}