Skip to main content

termusic_stream/
source.rs

1use async_trait::async_trait;
2use bytes::Bytes;
3use futures::{Stream, StreamExt};
4use parking_lot::{Condvar, Mutex, RwLock, RwLockReadGuard};
5use rangemap::RangeSet;
6use std::{
7    error::Error,
8    fs::File,
9    io::{self, BufWriter, Seek, SeekFrom, Write},
10    sync::{
11        atomic::{AtomicI64, Ordering},
12        Arc,
13    },
14};
15use tap::TapFallible;
16use tokio::sync::mpsc;
17use tracing::{debug, error, info, trace};
18
19#[async_trait]
20pub trait SourceStream:
21    Stream<Item = Result<Bytes, Self::Error>> + Unpin + Send + Sync + Sized + 'static
22{
23    type Url: Send;
24    type Error: Error + Send;
25
26    async fn create(
27        url: Self::Url,
28        is_radio: bool,
29        radio_title: Arc<Mutex<String>>,
30    ) -> io::Result<Self>;
31    async fn content_length(&self) -> Option<u64>;
32    async fn seek_range(&mut self, start: u64, end: Option<u64>) -> io::Result<()>;
33}
34
35#[derive(Debug, Clone)]
36pub struct SourceHandle {
37    downloaded: Arc<RwLock<RangeSet<u64>>>,
38    requested_position: Arc<AtomicI64>,
39    position_reached: Arc<(Mutex<Waiter>, Condvar)>,
40    content_length_retrieved: Arc<(Mutex<bool>, Condvar)>,
41    content_length: Arc<AtomicI64>,
42    seek_tx: mpsc::Sender<u64>,
43}
44
45impl SourceHandle {
46    pub fn downloaded(&self) -> RwLockReadGuard<rangemap::RangeSet<u64>> {
47        self.downloaded.read()
48    }
49
50    pub fn request_position(&self, position: u64) {
51        self.requested_position
52            .store(position as i64, Ordering::SeqCst);
53    }
54
55    pub fn wait_for_requested_position(&self) {
56        let (mutex, cvar) = &*self.position_reached;
57        let mut waiter = mutex.lock();
58        if !waiter.stream_done {
59            debug!("Waiting for requested position");
60            cvar.wait_while(&mut waiter, |waiter| {
61                !waiter.stream_done && !waiter.position_reached
62            });
63            if !waiter.stream_done {
64                waiter.position_reached = false;
65            }
66            debug!("Position reached");
67        }
68    }
69
70    pub fn seek(&self, position: u64) {
71        self.seek_tx.try_send(position).ok();
72    }
73
74    pub fn content_length(&self) -> Option<u64> {
75        let (mutex, cvar) = &*self.content_length_retrieved;
76        let mut done = mutex.lock();
77        if !*done {
78            cvar.wait_while(&mut done, |done| !*done);
79        }
80        let length = self.content_length.load(Ordering::SeqCst);
81        if length > -1 {
82            Some(length as u64)
83        } else {
84            None
85        }
86    }
87}
88
/// Flags shared between the downloader and a reader blocked in
/// `SourceHandle::wait_for_requested_position`.
#[derive(Debug, Default)]
struct Waiter {
    /// Set by the downloader once the requested position has been written;
    /// cleared again by the reader after it wakes.
    position_reached: bool,
    /// Set once the downloader will write no more data.
    stream_done: bool,
}
94
95pub struct Source {
96    writer: BufWriter<File>,
97    downloaded: Arc<RwLock<RangeSet<u64>>>,
98    requested_position: Arc<AtomicI64>,
99    position_reached: Arc<(Mutex<Waiter>, Condvar)>,
100    content_length_retrieved: Arc<(Mutex<bool>, Condvar)>,
101    content_length: Arc<AtomicI64>,
102    seek_tx: mpsc::Sender<u64>,
103    seek_rx: mpsc::Receiver<u64>,
104}
105
/// Bytes downloaded up front (256 KiB) before seek requests are serviced.
const PREFETCH_BYTES: u64 = 256 * 1024;
107
108impl Source {
109    pub fn new(tempfile: File) -> Self {
110        let (seek_tx, seek_rx) = mpsc::channel(32);
111        Self {
112            writer: BufWriter::new(tempfile),
113            downloaded: Default::default(),
114            requested_position: Arc::new(AtomicI64::new(-1)),
115            position_reached: Default::default(),
116            content_length_retrieved: Default::default(),
117            seek_tx,
118            seek_rx,
119            content_length: Default::default(),
120        }
121    }
122
123    pub async fn download<S: SourceStream>(
124        mut self,
125        mut stream: S,
126        radio_downloaded: Arc<Mutex<u64>>,
127    ) -> io::Result<()> {
128        info!("Starting file download");
129        let content_length = stream.content_length().await;
130        if let Some(content_length) = content_length {
131            self.content_length
132                .swap(content_length as i64, Ordering::SeqCst);
133        } else {
134            self.content_length.swap(-1, Ordering::SeqCst);
135        }
136        {
137            let (mutex, cvar) = &*self.content_length_retrieved;
138            *mutex.lock() = true;
139            cvar.notify_all();
140        }
141        loop {
142            if let Some(Ok(bytes)) = stream
143                .next()
144                .await
145                .map(|b| b.tap_err(|e| error!("Error reading stream: {e}")))
146            {
147                self.writer.write_all(&bytes)?;
148                let stream_position = self.writer.stream_position()?;
149                trace!("Prefetch: {}/{} bytes", stream_position, PREFETCH_BYTES);
150                if stream_position >= PREFETCH_BYTES {
151                    self.downloaded.write().insert(0..stream_position);
152                    break;
153                }
154            } else {
155                info!("File shorter than prefetch length");
156                self.writer.flush()?;
157                self.downloaded
158                    .write()
159                    .insert(0..self.writer.stream_position()?);
160                let (mutex, cvar) = &*self.position_reached;
161                (mutex.lock()).stream_done = true;
162                cvar.notify_all();
163                return Ok(());
164            }
165        }
166        info!("Prefetch complete");
167        loop {
168            tokio::select! {
169                bytes = stream.next() => {
170                    if let Some(Ok(bytes)) =
171                        bytes.map(|b| b.tap_err(|e| error!("Error reading from stream: {e}"))) {
172                        let position = self.writer.stream_position()?;
173                        *radio_downloaded.lock() = position;
174                        self.writer.write_all(&bytes)?;
175                        let new_position = self.writer.stream_position()?;
176                        // *radio_downloaded.lock() = new_position;
177                    // eprintln!("downloaded: {new_position}");
178                        // trace!("Received response chunk. position={}", new_position);
179                        self.downloaded.write().insert(position .. new_position);
180                        let requested = self.requested_position.load(Ordering::SeqCst);
181                        if requested > -1 {
182                            debug!("downloader: requested {requested} current {}", new_position);
183                        }
184                        if requested > -1 && new_position as i64 >= requested {
185                            info!("Notifying requested position reached: {requested}. New position: {new_position}");
186                            self.requested_position.store(-1, Ordering::SeqCst);
187                            let (mutex, cvar) = &*self.position_reached;
188                            (mutex.lock()).position_reached = true;
189                            cvar.notify_all();
190                        }
191                    } else {
192                        info!("Stream finished downloading");
193                        if let Some(content_length) = content_length {
194                            let gap = {
195                                let downloaded = self.downloaded.read();
196                                let range = 0 .. content_length;
197                                let mut gaps = downloaded.gaps(&range);
198                                gaps.next()
199                            };
200                            if let Some(gap) = gap {
201                                debug!("Downloading missing stream chunk: {gap:?}.");
202                                stream.seek_range(gap.start, Some(gap.end)).await?;
203                                self.writer.seek(SeekFrom::Start(gap.start))?;
204                                continue;
205                            }
206                        }
207                        self.writer.flush()?;
208                        let (mutex, cvar) = &*self.position_reached;
209                        (mutex.lock()).stream_done = true;
210                        cvar.notify_all();
211                        return Ok(());
212                    }
213                },
214                pos = self.seek_rx.recv() => {
215                    if let Some(pos) = pos {
216                        debug!("Received seek position {pos}");
217                        let do_seek = {
218                            let downloaded = self.downloaded.read();
219                            if let Some(range) = downloaded.get(&pos) {
220                                !range.contains(&self.writer.stream_position()?)
221                            } else {
222                                true
223                            }
224                        };
225                        if do_seek {
226                            debug!("Seek position not yet downloaded");
227                            stream.seek_range(pos, None).await?;
228                            self.writer.seek(SeekFrom::Start(pos))?;
229                        }
230                    }
231                }
232            }
233        }
234    }
235
236    pub fn source_handle(&self) -> SourceHandle {
237        SourceHandle {
238            downloaded: self.downloaded.clone(),
239            requested_position: self.requested_position.clone(),
240            position_reached: self.position_reached.clone(),
241            seek_tx: self.seek_tx.clone(),
242            content_length_retrieved: self.content_length_retrieved.clone(),
243            content_length: self.content_length.clone(),
244        }
245    }
246}