// ac_rustube/stream/mod.rs

1use std::ops::Range;
2#[cfg(feature = "download")]
3use std::path::{Path, PathBuf};
4use std::sync::Arc;
5use std::sync::atomic::{AtomicU64, Ordering};
6
7use chrono::{DateTime, Utc};
8use mime::Mime;
9use reqwest::Client;
10use serde_with::{DisplayFromStr, serde_as};
11#[cfg(feature = "download")]
12use tokio::{
13    fs::File,
14    io::AsyncWriteExt,
15};
16#[cfg(feature = "callback")]
17use tokio::sync::mpsc::error::TrySendError;
18#[cfg(feature = "download")]
19use tokio_stream::StreamExt;
20
21#[cfg(feature = "callback")]
22use callback::{InternalSender, InternalSignal};
23#[cfg(all(feature = "callback", feature = "stream", feature = "blocking"))]
24use callback::Callback;
25
26#[cfg(feature = "download")]
27use crate::{Error, Result};
28use crate::{
29    video_info::player_response::streaming_data::{
30        AudioQuality, ColorInfo, FormatType, ProjectionType,
31        Quality, QualityLabel, RawFormat, SignatureCipher,
32    },
33    VideoDetails,
34};
35
36#[cfg(feature = "callback")]
37pub mod callback;
38
// todo:
//  There are different types of streams: video, audio, and video + audio.
//  Make `Stream` and `RawFormat` enums, so that each variant carries fewer options.
42
// Without the `callback` feature there is no real channel type. This unit stand-in keeps
// the `Option<InternalSender>` parameters of the download methods compiling unchanged.
#[cfg(all(not(feature = "callback"), feature = "download"))]
type InternalSender = ();
45
46/// A downloadable video Stream, that contains all the important information.
47#[serde_as]
48#[derive(Clone, derivative::Derivative, serde::Deserialize, serde::Serialize)]
49#[derivative(Debug, PartialEq)]
50pub struct Stream {
51    #[serde_as(as = "DisplayFromStr")]
52    pub mime: Mime,
53    pub codecs: Vec<String>,
54    pub is_progressive: bool,
55    pub includes_video_track: bool,
56    pub includes_audio_track: bool,
57    pub format_type: Option<FormatType>,
58    pub approx_duration_ms: Option<u64>,
59    pub audio_channels: Option<u8>,
60    pub audio_quality: Option<AudioQuality>,
61    pub audio_sample_rate: Option<u64>,
62    pub average_bitrate: Option<u64>,
63    pub bitrate: Option<u64>,
64    pub color_info: Option<ColorInfo>,
65    #[derivative(PartialEq(compare_with = "atomic_u64_is_eq"))]
66    content_length: Arc<AtomicU64>,
67    pub fps: u8,
68    pub height: Option<u64>,
69    pub high_replication: Option<bool>,
70    pub index_range: Option<Range<u64>>,
71    pub init_range: Option<Range<u64>>,
72    pub is_otf: bool,
73    pub itag: u64,
74    pub last_modified: Option<DateTime<Utc>>,
75    pub loudness_db: Option<f64>,
76    pub projection_type: ProjectionType,
77    pub quality: Quality,
78    pub quality_label: Option<QualityLabel>,
79    pub signature_cipher: SignatureCipher,
80    pub width: Option<u64>,
81    pub video_details: Arc<VideoDetails>,
82    #[allow(dead_code)]
83    #[serde(skip)]
84    #[derivative(Debug = "ignore", PartialEq = "ignore")]
85    client: Client,
86}
87
88
89impl Stream {
90    // maybe deserialize RawFormat seeded with client and VideoDetails
91    pub(crate) fn from_raw_format(raw_format: RawFormat, client: Client, video_details: Arc<VideoDetails>) -> Self {
92        Self {
93            is_progressive: is_progressive(&raw_format.mime_type.codecs),
94            includes_video_track: includes_video_track(&raw_format.mime_type.codecs, &raw_format.mime_type.mime),
95            includes_audio_track: includes_audio_track(&raw_format.mime_type.codecs, &raw_format.mime_type.mime),
96            mime: raw_format.mime_type.mime,
97            codecs: raw_format.mime_type.codecs,
98            format_type: raw_format.format_type,
99            approx_duration_ms: raw_format.approx_duration_ms,
100            audio_channels: raw_format.audio_channels,
101            audio_quality: raw_format.audio_quality,
102            audio_sample_rate: raw_format.audio_sample_rate,
103            average_bitrate: raw_format.average_bitrate,
104            bitrate: raw_format.bitrate,
105            color_info: raw_format.color_info,
106            content_length: Arc::new(AtomicU64::new(raw_format.content_length.unwrap_or(0))),
107            fps: raw_format.fps,
108            height: raw_format.height,
109            high_replication: raw_format.high_replication,
110            index_range: raw_format.index_range,
111            init_range: raw_format.init_range,
112            is_otf: matches!(raw_format.format_type, Some(FormatType::Otf)),
113            itag: raw_format.itag,
114            last_modified: raw_format.last_modified,
115            loudness_db: raw_format.loudness_db,
116            projection_type: raw_format.projection_type,
117            quality: raw_format.quality,
118            quality_label: raw_format.quality_label,
119            signature_cipher: raw_format.signature_cipher,
120            width: raw_format.width,
121            client,
122            video_details,
123        }
124    }
125}
126
// todo: download in ranges
// todo: blocking download
129
#[cfg(feature = "download")]
impl Stream {
    /// The content length of the video.
    /// If the content length was not included in the [`RawFormat`], this method will make a `HEAD`
    /// request, to try to figure it out. A length found that way is cached, so later calls
    /// return it without another request.
    ///
    /// ### Errors:
    /// - When the content length was not included in the [`RawFormat`], and the request fails.
    /// - When the response does not carry a parsable `Content-Length` header.
    #[inline]
    pub async fn content_length(&self) -> Result<u64> {
        // `0` is the marker for "not known yet" (see `from_raw_format`).
        let cl = self.content_length.load(Ordering::SeqCst);
        if cl != 0 { return Ok(cl); }

        self.client
            .head(self.signature_cipher.url.as_str())
            .send()
            .await?
            .error_for_status()?
            .headers()
            .get(reqwest::header::CONTENT_LENGTH)
            .and_then(|cl| cl.to_str().ok())
            .and_then(|cl| cl.parse::<u64>().ok())
            .map(|cl| {
                log::trace!("content length of {:?} is {}", self, cl);
                self.content_length.store(cl, Ordering::SeqCst);
                cl
            })
            .ok_or_else(|| Error::UnexpectedResponse(
                "the response did not contain a valid content-length field".into()
            ))
    }

    /// Attempts to download the [`Stream`]s resource.
    /// This will download the video to `<video_id>.<mime subtype>` (e.g. `<video_id>.mp4`)
    /// in the current working directory, and returns the path of the written file.
    #[inline]
    pub async fn download(&self) -> Result<PathBuf> {
        self.internal_download(None).await
    }

    /// Like [`Stream::download`], but optionally reports progress through `channel`.
    #[inline]
    async fn internal_download(&self, channel: Option<InternalSender>) -> Result<PathBuf> {
        let path = Path::new(self.video_details.video_id.as_str())
            .with_extension(self.mime.subtype().as_str());
        self.internal_download_to(&path, channel)
            .await
    }

    /// Attempts to download the [`Stream`]s resource.
    /// This will download the video to `<video_id>.<mime subtype>` (e.g. `<video_id>.mp4`)
    /// in the provided directory, and returns the path of the written file.
    #[inline]
    pub async fn download_to_dir<P: AsRef<Path>>(&self, dir: P) -> Result<PathBuf> {
        self.internal_download_to_dir(dir, None).await
    }

    /// Like [`Stream::download_to_dir`], but optionally reports progress through `channel`.
    #[inline]
    async fn internal_download_to_dir<P: AsRef<Path>>(
        &self,
        dir: P,
        channel: Option<InternalSender>,
    ) -> Result<PathBuf> {
        let mut path = dir
            .as_ref()
            .join(self.video_details.video_id.as_str());
        path.set_extension(self.mime.subtype().as_str());
        self.internal_download_to(&path, channel)
            .await
    }

    /// Attempts to download the [`Stream`]s resource.
    /// This will download the video to the provided file path.
    #[inline]
    pub async fn download_to<P: AsRef<Path>>(&self, path: P) -> Result<()> {
        let _ = self.internal_download_to(path, None).await?;
        Ok(())
    }

    /// Creates (or truncates) the file at `path` and downloads the stream into it.
    ///
    /// On success the path of the written file is returned. On any failure the incomplete
    /// file is removed again (best effort) and the download error is returned. When a
    /// `channel` is present, it receives a final `Finished` signal in both cases.
    ///
    /// ### Errors:
    /// - When the file cannot be created.
    /// - When the download, or flushing the downloaded data to the file, fails.
    #[allow(unused_mut, clippy::let_and_return)]
    async fn internal_download_to<P: AsRef<Path>>(&self, path: P, channel: Option<InternalSender>) -> Result<PathBuf> {
        let path = path.as_ref();
        log::trace!("download_to: {:?}", path);
        log::debug!("start downloading {}", self.video_details.video_id);
        let mut file = File::create(path).await?;

        let result = match self.download_into(&mut file, &channel).await {
            Ok(()) => {
                log::info!(
                    "downloaded {} successfully to {:?}",
                    self.video_details.video_id, path
                );
                log::debug!("downloaded stream {:?}", &self);
                Ok(path.to_path_buf())
            }
            Err(e) => {
                // Close the handle before removing the incomplete file.
                drop(file);
                // A failing clean-up must neither hide the actual download error nor keep
                // the `Finished` signal below from being sent, so it is only logged.
                if let Err(remove_err) = tokio::fs::remove_file(path).await {
                    log::warn!(
                        "failed to remove incomplete download {:?}: {:?}",
                        path, remove_err
                    );
                }
                Err(e)
            }
        };

        #[cfg(feature = "callback")]
        if let Some(channel) = channel {
            let _ = channel.send(InternalSignal::Finished).await;
        }

        result
    }

    /// Downloads the stream into `file`: first with a single request and, if that request
    /// is answered with `404 Not Found`, with a sequenced download. The file is flushed
    /// before success is reported.
    async fn download_into(&self, file: &mut File, channel: &Option<InternalSender>) -> Result<()> {
        match self.download_full(&self.signature_cipher.url, file, channel, 0).await {
            Ok(_) => {}
            Err(Error::Request(e)) if matches!(e.status(), Some(reqwest::StatusCode::NOT_FOUND)) => {
                log::error!("failed to download {}: {:?}", self.video_details.video_id, e);
                log::info!("try to download {} using sequenced download", self.video_details.video_id);
                // Some adaptive streams need to be requested with sequence numbers
                self.download_full_seq(file, channel)
                    .await
                    .map_err(|e| {
                        log::error!(
                            "failed to download {} using sequenced download: {:?}",
                            self.video_details.video_id, e
                        );
                        e
                    })?;
            }
            Err(e) => {
                log::error!("failed to download {}: {:?}", self.video_details.video_id, e);
                return Err(e);
            }
        }

        // Writes to a tokio `File` may still be buffered or in flight when `write_all`
        // returns. Flushing makes sure they completed (and surfaces their errors) before
        // the download is reported as successful.
        file.flush().await?;
        Ok(())
    }

    /// Downloads the stream segment by segment, using the `sq` query parameter.
    ///
    /// The request with `sq=0` is expected to carry a `Segment-Count` header. Its body is
    /// written to `file` without progress reporting. Afterwards the segments
    /// `1..segment_count` are requested and appended in order.
    async fn download_full_seq(&self, file: &mut File, channel: &Option<InternalSender>) -> Result<()> {
        // fixme: this implementation is **not** tested yet!
        // To test it, I would need an url of a video, which does require sequenced downloading.
        // NOTE(review): whether the last segment is `segment_count - 1` or `segment_count`
        // could not be verified either - confirm once a test url is available.
        log::warn!(
            "`download_full_seq` is not tested yet and probably broken!\n\
            Please open a GitHub issue (https://github.com/DzenanJupic/rustube/issues) and paste \
            the whole warning message in:\n\
            id: {}\n\
            url: {}",
            self.video_details.video_id,
            self.signature_cipher.url.as_str()
        );

        let mut url = self.signature_cipher.url.clone();
        let base_query = url
            .query()
            .map(str::to_owned)
            .unwrap_or_default();

        // The 0th sequential request provides the file headers, which tell us
        // information about how the file is segmented.
        Self::set_url_seq_query(&mut url, &base_query, 0);
        let res = self.get(&url).await?;
        let segment_count = Stream::extract_segment_count(&res)?;
        // No callback action since this is not really part of the progress
        self.write_stream_to_file(res.bytes_stream(), file, &None, 0).await?;
        let mut count = 0;

        for i in 1..segment_count {
            Self::set_url_seq_query(&mut url, &base_query, i);
            count = self.download_full(&url, file, channel, count).await?;
        }

        Ok(())
    }

    /// Requests `url` and appends the whole response body to `file`.
    ///
    /// `count` is the progress counter to continue from. The updated counter is returned.
    #[inline]
    async fn download_full(
        &self,
        url: &url::Url,
        file: &mut File,
        channel: &Option<InternalSender>,
        count: usize,
    ) -> Result<usize> {
        let res = self.get(url).await?;
        self.write_stream_to_file(res.bytes_stream(), file, channel, count).await
    }

    /// Sends a `GET` request to `url`, turning non-success status codes into errors.
    #[inline]
    async fn get(&self, url: &url::Url) -> Result<reqwest::Response> {
        log::trace!("get: {}", url.as_str());
        Ok(
            self.client
                .get(url.as_str())
                .send()
                .await?
                .error_for_status()?
        )
    }

    /// Writes every chunk of `stream` to `file`.
    ///
    /// With the `callback` feature enabled and a `channel` present, `counter` is increased
    /// by the size of each chunk and sent through the channel after every write. The final
    /// value of `counter` is returned.
    ///
    /// ### Errors:
    /// - When receiving a chunk or writing it to `file` fails.
    /// - When the receiving side of `channel` was closed.
    #[inline]
    #[allow(unused_variables, unused_mut)]
    async fn write_stream_to_file(
        &self,
        mut stream: impl tokio_stream::Stream<Item=reqwest::Result<bytes::Bytes>> + Unpin,
        file: &mut File,
        channel: &Option<InternalSender>,
        mut counter: usize,
    ) -> Result<usize> {
        // Counter will be 0 if callback is not enabled
        while let Some(chunk) = stream.next().await {
            let chunk = chunk?;
            let len = chunk.len();
            log::trace!("received {} byte chunk ", len);

            file.write_all(&chunk).await?;
            #[cfg(feature = "callback")]
            if let Some(channel) = channel {
                // network chunks of ~10kb size
                counter += len;
                // Will abort if the receiver is closed
                // Will ignore if the channel is full and thus not slow down the download
                if let Err(TrySendError::Closed(_)) =
                    channel.try_send(InternalSignal::Value(counter))
                {
                    return Err(Error::ChannelClosed);
                }
            }
        }
        Ok(counter)
    }

    /// Resets the query of `url` to `base_query` and appends the sequence number `sq`.
    #[inline]
    fn set_url_seq_query(url: &mut url::Url, base_query: &str, sq: u64) {
        url.set_query(Some(base_query));
        url
            .query_pairs_mut()
            .append_pair("sq", &sq.to_string());
    }

    /// Reads the number of segments from the `Segment-Count` header of `res`.
    ///
    /// ### Errors:
    /// - When the header is missing, is not a valid string, or is not an integer.
    #[inline]
    fn extract_segment_count(res: &reqwest::Response) -> Result<u64> {
        res
            .headers()
            .get("Segment-Count")
            .ok_or_else(|| Error::UnexpectedResponse(
                "sequence download request did not contain a Segment-Count".into()
            ))?
            .to_str()
            .map_err(|_| Error::UnexpectedResponse(
                "Segment-Count is not valid utf-8".into()
            ))?
            .parse::<u64>()
            .map_err(|_| Error::UnexpectedResponse(
                "Segment-Count could not be parsed into an integer".into()
            ))
    }
}
369
// Synchronous counterparts of the async download API. Each method hands the matching
// async call to `crate::block!` and returns its result unchanged.
#[cfg(all(feature = "download", feature = "blocking"))]
impl Stream {
    /// A synchronous wrapper around [`Stream::download`](crate::Stream::download).
    #[inline]
    pub fn blocking_download(&self) -> Result<PathBuf> {
        crate::block!(self.download())
    }

    /// A synchronous wrapper around [`Stream::download_with_callback`](crate::Stream::download_with_callback).
    #[cfg(feature = "callback")]
    #[inline]
    pub fn blocking_download_with_callback<'a>(&'a self, callback: Callback<'a>) -> Result<PathBuf> {
        crate::block!(self.download_with_callback(callback))
    }

    /// A synchronous wrapper around [`Stream::download_to_dir`](crate::Stream::download_to_dir).
    #[inline]
    pub fn blocking_download_to_dir<P: AsRef<Path>>(&self, dir: P) -> Result<PathBuf> {
        crate::block!(self.download_to_dir(dir))
    }

    /// A synchronous wrapper around [`Stream::download_to_dir_with_callback`](crate::Stream::download_to_dir_with_callback).
    #[cfg(feature = "callback")]
    #[inline]
    pub fn blocking_download_to_dir_with_callback<'a, P: AsRef<Path>>(
        &'a self,
        dir: P,
        callback: Callback<'a>,
    ) -> Result<PathBuf> {
        crate::block!(self.download_to_dir_with_callback(dir, callback))
    }

    /// A synchronous wrapper around [`Stream::download_to`](crate::Stream::download_to).
    pub fn blocking_download_to<P: AsRef<Path>>(&self, path: P) -> Result<()> {
        crate::block!(self.download_to(path))
    }

    /// A synchronous wrapper around [`Stream::download_to_with_callback`](crate::Stream::download_to_with_callback).
    #[cfg(feature = "callback")]
    pub fn blocking_download_to_with_callback<'a, P: AsRef<Path>>(&'a self, path: P, callback: Callback<'a>) -> Result<()> {
        crate::block!(self.download_to_with_callback(path, callback))
    }

    /// A synchronous wrapper around [`Stream::content_length`](crate::Stream::content_length).
    #[inline]
    pub fn blocking_content_length(&self) -> Result<u64> {
        crate::block!(self.content_length())
    }
}
419
/// Returns `true` when `codecs` holds an odd number of entries, which this module treats
/// as the mark of an adaptive stream. An empty list is therefore not adaptive.
#[inline]
fn is_adaptive(codecs: &[String]) -> bool {
    codecs.len() % 2 != 0
}
424
425#[inline]
426fn includes_video_track(codecs: &[String], mime: &Mime) -> bool {
427    is_progressive(codecs) || mime.type_() == "video"
428}
429
430#[inline]
431fn includes_audio_track(codecs: &[String], mime: &Mime) -> bool {
432    is_progressive(codecs) || mime.type_() == "audio"
433}
434
435#[inline]
436fn is_progressive(codecs: &[String]) -> bool {
437    !is_adaptive(codecs)
438}
439
/// Compares two shared atomic counters by the values they currently hold.
///
/// Used as the `PartialEq` comparison for `Stream::content_length`, because `AtomicU64`
/// does not implement `PartialEq` itself.
#[inline]
fn atomic_u64_is_eq(lhs: &Arc<AtomicU64>, rhs: &Arc<AtomicU64>) -> bool {
    // Same ordering as every other access to this atomic in the file.
    lhs.load(Ordering::SeqCst) == rhs.load(Ordering::SeqCst)
}