1use std::ops::Range;
2#[cfg(feature = "download")]
3use std::path::{Path, PathBuf};
4use std::sync::Arc;
5use std::sync::atomic::{AtomicU64, Ordering};
6
7use chrono::{DateTime, Utc};
8use mime::Mime;
9use reqwest::Client;
10use serde_with::{DisplayFromStr, serde_as};
11#[cfg(feature = "download")]
12use tokio::{
13 fs::File,
14 io::AsyncWriteExt,
15};
16#[cfg(feature = "callback")]
17use tokio::sync::mpsc::error::TrySendError;
18#[cfg(feature = "download")]
19use tokio_stream::StreamExt;
20
21#[cfg(feature = "callback")]
22use callback::{InternalSender, InternalSignal};
23#[cfg(all(feature = "callback", feature = "stream", feature = "blocking"))]
24use callback::Callback;
25
26#[cfg(feature = "download")]
27use crate::{Error, Result};
28use crate::{
29 video_info::player_response::streaming_data::{
30 AudioQuality, ColorInfo, FormatType, ProjectionType,
31 Quality, QualityLabel, RawFormat, SignatureCipher,
32 },
33 VideoDetails,
34};
35
36#[cfg(feature = "callback")]
37pub mod callback;
38
// Without the `callback` feature there is no progress channel, but the download
// helpers still take an `Option<InternalSender>`. Aliasing it to `()` keeps their
// signatures identical across feature sets; the public entry points in this file
// pass `None`.
#[cfg(all(not(feature = "callback"), feature = "download"))]
type InternalSender = ();
45
/// One format (stream) of a video: either a progressive stream, treated as carrying
/// both audio and video, or an adaptive stream that carries the track of its MIME type.
///
/// Instances are built from the player response's [`RawFormat`]s by
/// `Stream::from_raw_format`. The cached content length lives behind an `Arc`, so
/// clones share it: a length fetched through one clone is visible to the others.
#[serde_as]
#[derive(Clone, derivative::Derivative, serde::Deserialize, serde::Serialize)]
#[derivative(Debug, PartialEq)]
pub struct Stream {
    /// MIME type of the format. Its subtype is used as the file extension by the
    /// download helpers. (De)serialized through its `Display` / `FromStr` string form.
    #[serde_as(as = "DisplayFromStr")]
    pub mime: Mime,
    /// Codecs listed in the format's MIME type.
    pub codecs: Vec<String>,
    /// Whether the format is progressive. `from_raw_format` sets this when `codecs`
    /// has an even number of entries.
    pub is_progressive: bool,
    /// `from_raw_format` sets this when the stream is progressive or its MIME
    /// top-level type is `video`.
    pub includes_video_track: bool,
    /// `from_raw_format` sets this when the stream is progressive or its MIME
    /// top-level type is `audio`.
    pub includes_audio_track: bool,
    pub format_type: Option<FormatType>,
    pub approx_duration_ms: Option<u64>,
    pub audio_channels: Option<u8>,
    pub audio_quality: Option<AudioQuality>,
    pub audio_sample_rate: Option<u64>,
    pub average_bitrate: Option<u64>,
    pub bitrate: Option<u64>,
    pub color_info: Option<ColorInfo>,
    /// Cached length in bytes, where `0` means "not known yet". It is shared between
    /// clones and filled lazily by `Stream::content_length` (with the `download`
    /// feature). `PartialEq` compares its current value, not the `Arc` identity.
    #[derivative(PartialEq(compare_with = "atomic_u64_is_eq"))]
    content_length: Arc<AtomicU64>,
    pub fps: u8,
    pub height: Option<u64>,
    pub high_replication: Option<bool>,
    pub index_range: Option<Range<u64>>,
    pub init_range: Option<Range<u64>>,
    /// `from_raw_format` sets this when `format_type` is `Some(FormatType::Otf)`.
    pub is_otf: bool,
    pub itag: u64,
    pub last_modified: Option<DateTime<Utc>>,
    pub loudness_db: Option<f64>,
    pub projection_type: ProjectionType,
    pub quality: Quality,
    pub quality_label: Option<QualityLabel>,
    /// Signature cipher data of the format. Its `url` is the URL the download and
    /// content-length requests are sent to.
    pub signature_cipher: SignatureCipher,
    pub width: Option<u64>,
    /// Details of the video this stream belongs to. Its `video_id` names downloaded files.
    pub video_details: Arc<VideoDetails>,
    /// HTTP client used for this stream's requests. It is skipped by serde, so
    /// deserialization fills in `Client::default()`. It is only read by the
    /// `download`-feature methods, hence the `dead_code` allowance.
    #[allow(dead_code)]
    #[serde(skip)]
    #[derivative(Debug = "ignore", PartialEq = "ignore")]
    client: Client,
}
87
88
89impl Stream {
90 pub(crate) fn from_raw_format(raw_format: RawFormat, client: Client, video_details: Arc<VideoDetails>) -> Self {
92 Self {
93 is_progressive: is_progressive(&raw_format.mime_type.codecs),
94 includes_video_track: includes_video_track(&raw_format.mime_type.codecs, &raw_format.mime_type.mime),
95 includes_audio_track: includes_audio_track(&raw_format.mime_type.codecs, &raw_format.mime_type.mime),
96 mime: raw_format.mime_type.mime,
97 codecs: raw_format.mime_type.codecs,
98 format_type: raw_format.format_type,
99 approx_duration_ms: raw_format.approx_duration_ms,
100 audio_channels: raw_format.audio_channels,
101 audio_quality: raw_format.audio_quality,
102 audio_sample_rate: raw_format.audio_sample_rate,
103 average_bitrate: raw_format.average_bitrate,
104 bitrate: raw_format.bitrate,
105 color_info: raw_format.color_info,
106 content_length: Arc::new(AtomicU64::new(raw_format.content_length.unwrap_or(0))),
107 fps: raw_format.fps,
108 height: raw_format.height,
109 high_replication: raw_format.high_replication,
110 index_range: raw_format.index_range,
111 init_range: raw_format.init_range,
112 is_otf: matches!(raw_format.format_type, Some(FormatType::Otf)),
113 itag: raw_format.itag,
114 last_modified: raw_format.last_modified,
115 loudness_db: raw_format.loudness_db,
116 projection_type: raw_format.projection_type,
117 quality: raw_format.quality,
118 quality_label: raw_format.quality_label,
119 signature_cipher: raw_format.signature_cipher,
120 width: raw_format.width,
121 client,
122 video_details,
123 }
124 }
125}
126
#[cfg(feature = "download")]
impl Stream {
    /// Returns the length of the stream in bytes.
    ///
    /// A non-zero length cached from the raw format data, or from an earlier call, is
    /// returned without any network access. Otherwise a `HEAD` request is sent to the
    /// stream URL, and the parsed `Content-Length` header is cached for later calls.
    /// The cache is shared with every clone of this `Stream`.
    ///
    /// # Errors
    /// Fails if the request cannot be sent, if the server answers with an error status,
    /// or if the response has no parseable `Content-Length` header
    /// ([`Error::UnexpectedResponse`]).
    #[inline]
    pub async fn content_length(&self) -> Result<u64> {
        // `0` doubles as "unknown", so a genuinely empty stream is re-queried on each call.
        let cl = self.content_length.load(Ordering::SeqCst);
        if cl != 0 {
            return Ok(cl);
        }

        self.client
            .head(self.signature_cipher.url.as_str())
            .send()
            .await?
            .error_for_status()?
            .headers()
            .get(reqwest::header::CONTENT_LENGTH)
            .and_then(|cl| cl.to_str().ok())
            .and_then(|cl| cl.parse::<u64>().ok())
            .map(|cl| {
                log::trace!("content length of {:?} is {}", self, cl);
                self.content_length.store(cl, Ordering::SeqCst);
                cl
            })
            .ok_or_else(|| {
                Error::UnexpectedResponse(
                    "the response did not contain a valid content-length field".into(),
                )
            })
    }

    /// Downloads the stream into the current working directory.
    ///
    /// The file is named `<video_id>.<mime subtype>`, and that relative path is returned.
    ///
    /// # Errors
    /// See [`Stream::download_to`].
    #[inline]
    pub async fn download(&self) -> Result<PathBuf> {
        self.internal_download(None).await
    }

    /// Like [`Stream::download`], with an optional progress channel.
    #[inline]
    async fn internal_download(&self, channel: Option<InternalSender>) -> Result<PathBuf> {
        let path = Path::new(self.video_details.video_id.as_str())
            .with_extension(self.mime.subtype().as_str());
        self.internal_download_to(&path, channel).await
    }

    /// Downloads the stream into `dir` as `<video_id>.<mime subtype>` and returns the
    /// path of the written file.
    ///
    /// # Errors
    /// See [`Stream::download_to`].
    #[inline]
    pub async fn download_to_dir<P: AsRef<Path>>(&self, dir: P) -> Result<PathBuf> {
        self.internal_download_to_dir(dir, None).await
    }

    /// Like [`Stream::download_to_dir`], with an optional progress channel.
    #[inline]
    async fn internal_download_to_dir<P: AsRef<Path>>(
        &self,
        dir: P,
        channel: Option<InternalSender>,
    ) -> Result<PathBuf> {
        let mut path = dir.as_ref().join(self.video_details.video_id.as_str());
        path.set_extension(self.mime.subtype().as_str());
        self.internal_download_to(&path, channel).await
    }

    /// Downloads the stream to `path`, creating the file or truncating an existing one.
    ///
    /// If the direct request is answered with `404 Not Found`, a sequenced
    /// (segment-by-segment) download is attempted instead.
    ///
    /// # Errors
    /// Fails if the file cannot be created or written, or if the request(s) fail. On a
    /// failed download the partially written file is removed.
    #[inline]
    pub async fn download_to<P: AsRef<Path>>(&self, path: P) -> Result<()> {
        let _ = self.internal_download_to(path, None).await?;
        Ok(())
    }

    /// Shared implementation of all download entry points.
    ///
    /// On any failure the partially written file is removed, and the download error,
    /// not a cleanup error, is returned. With the `callback` feature,
    /// `InternalSignal::Finished` is sent on `channel` once the attempt is over, whether
    /// it succeeded or not. No signal is sent if the file cannot be created.
    async fn internal_download_to<P: AsRef<Path>>(
        &self,
        path: P,
        channel: Option<InternalSender>,
    ) -> Result<PathBuf> {
        let path = path.as_ref();
        log::trace!("download_to: {:?}", path);
        log::debug!("start downloading {}", self.video_details.video_id);
        let mut file = File::create(path).await?;

        let downloaded = self.download_with_fallback(&mut file, &channel).await;
        // `tokio::fs::File` hands each write to a background task, so the last
        // `write_all` may still be in flight here. Flushing waits for it, surfaces its
        // error (e.g. a full disk) and lets the handle close as soon as it is dropped.
        // The removal below needs that on platforms that refuse to delete open files.
        let flushed = file.flush().await.map_err(Error::from);
        drop(file);

        // The download error takes precedence; a flush error only matters after success.
        let result = downloaded.and(flushed);
        match &result {
            Ok(()) => {
                log::info!(
                    "downloaded {} successfully to {:?}",
                    self.video_details.video_id, path
                );
                log::debug!("downloaded stream {:?}", &self);
            }
            Err(_) => {
                // Don't leave a truncated file behind. The caller needs the download
                // error, so a failed removal is only logged.
                if let Err(e) = tokio::fs::remove_file(path).await {
                    log::warn!("failed to remove incomplete download {:?}: {:?}", path, e);
                }
            }
        }

        #[cfg(feature = "callback")]
        if let Some(channel) = channel {
            let _ = channel.send(InternalSignal::Finished).await;
        }

        result.map(|()| path.to_path_buf())
    }

    /// Writes the whole stream into `file`. If the direct request for the stream URL
    /// is answered with `404 Not Found`, it falls back to [`Stream::download_full_seq`].
    /// Failures are logged before they are returned.
    async fn download_with_fallback(
        &self,
        file: &mut File,
        channel: &Option<InternalSender>,
    ) -> Result<()> {
        match self.download_full(&self.signature_cipher.url, file, channel, 0).await {
            Ok(_) => Ok(()),
            Err(Error::Request(e)) if matches!(e.status(), Some(reqwest::StatusCode::NOT_FOUND)) => {
                log::error!("failed to download {}: {:?}", self.video_details.video_id, e);
                log::info!("try to download {} using sequenced download", self.video_details.video_id);
                self.download_full_seq(file, channel).await.map_err(|e| {
                    log::error!(
                        "failed to download {} using sequenced download: {:?}",
                        self.video_details.video_id, e
                    );
                    e
                })
            }
            Err(e) => {
                log::error!("failed to download {}: {:?}", self.video_details.video_id, e);
                Err(e)
            }
        }
    }

    /// Downloads the stream segment by segment, requesting `sq=0`, `sq=1`, … on top of
    /// the stream URL's original query.
    ///
    /// The number of segments is taken from the `Segment-Count` header of the `sq=0`
    /// response. Progress sent through `channel` is cumulative over all segments.
    async fn download_full_seq(&self, file: &mut File, channel: &Option<InternalSender>) -> Result<()> {
        log::warn!(
            "`download_full_seq` is not tested yet and probably broken!\n\
            Please open a GitHub issue (https://github.com/DzenanJupic/rustube/issues) and paste \
            the whole warning message in:\n\
            id: {}\n\
            url: {}",
            self.video_details.video_id,
            self.signature_cipher.url.as_str()
        );

        let mut url = self.signature_cipher.url.clone();
        let base_query = url.query().map(str::to_owned).unwrap_or_default();

        // Segment 0 is requested on its own because its response carries the
        // `Segment-Count` header that says how many segments there are.
        Self::set_url_seq_query(&mut url, &base_query, 0);
        let res = self.get(&url).await?;
        let segment_count = Self::extract_segment_count(&res)?;
        // Report segment 0 through the channel as well, and carry its byte count
        // forward so the progress counter covers every byte written to the file.
        let mut count = self.write_stream_to_file(res.bytes_stream(), file, channel, 0).await?;

        for sq in 1..segment_count {
            Self::set_url_seq_query(&mut url, &base_query, sq);
            count = self.download_full(&url, file, channel, count).await?;
        }

        Ok(())
    }

    /// GETs `url` and appends the response body to `file`. Returns the updated
    /// progress counter (see [`Stream::write_stream_to_file`]).
    #[inline]
    async fn download_full(
        &self,
        url: &url::Url,
        file: &mut File,
        channel: &Option<InternalSender>,
        count: usize,
    ) -> Result<usize> {
        let res = self.get(url).await?;
        self.write_stream_to_file(res.bytes_stream(), file, channel, count).await
    }

    /// Sends a GET request to `url`. HTTP error statuses are turned into errors.
    #[inline]
    async fn get(&self, url: &url::Url) -> Result<reqwest::Response> {
        log::trace!("get: {}", url.as_str());
        Ok(
            self.client
                .get(url.as_str())
                .send()
                .await?
                .error_for_status()?
        )
    }

    /// Writes every chunk of `stream` to `file` and returns the byte counter.
    ///
    /// With the `callback` feature and a `channel`, `counter` advances by each chunk's
    /// length and the new total is offered with `try_send`. A full channel just drops
    /// that update; a closed channel aborts with `Error::ChannelClosed`. Without a
    /// channel, `counter` is returned unchanged.
    // `channel` is unused and `counter` never mutated without the `callback` feature.
    #[inline]
    #[allow(unused_variables, unused_mut)]
    async fn write_stream_to_file(
        &self,
        mut stream: impl tokio_stream::Stream<Item=reqwest::Result<bytes::Bytes>> + Unpin,
        file: &mut File,
        channel: &Option<InternalSender>,
        mut counter: usize,
    ) -> Result<usize> {
        while let Some(chunk) = stream.next().await {
            let chunk = chunk?;
            let len = chunk.len();
            log::trace!("received {} byte chunk ", len);

            file.write_all(&chunk).await?;
            #[cfg(feature = "callback")]
            if let Some(channel) = &channel {
                counter += len;
                if let Err(TrySendError::Closed(_)) =
                    channel.try_send(InternalSignal::Value(counter))
                {
                    return Err(Error::ChannelClosed);
                }
            }
        }
        Ok(counter)
    }

    /// Resets `url`'s query to `base_query` and appends `sq=<sq>`.
    #[inline]
    fn set_url_seq_query(url: &mut url::Url, base_query: &str, sq: u64) {
        url.set_query(Some(base_query));
        url
            .query_pairs_mut()
            .append_pair("sq", &sq.to_string());
    }

    /// Reads and parses the `Segment-Count` header of a sequenced-download response.
    ///
    /// # Errors
    /// [`Error::UnexpectedResponse`] if the header is missing, not valid UTF-8, or not
    /// an unsigned integer.
    #[inline]
    fn extract_segment_count(res: &reqwest::Response) -> Result<u64> {
        res
            .headers()
            .get("Segment-Count")
            .ok_or_else(|| Error::UnexpectedResponse(
                "sequence download request did not contain a Segment-Count".into()
            ))?
            .to_str()
            .map_err(|_| Error::UnexpectedResponse(
                "Segment-Count is not valid utf-8".into()
            ))?
            .parse::<u64>()
            .map_err(|_| Error::UnexpectedResponse(
                "Segment-Count could not be parsed into an integer".into()
            ))
    }
}
369
#[cfg(all(feature = "download", feature = "blocking"))]
impl Stream {
    /// Synchronous counterpart of [`Stream::download`]: drives that future to
    /// completion with `crate::block!` and returns its result.
    #[inline]
    pub fn blocking_download(&self) -> Result<PathBuf> {
        crate::block!(Self::download(self))
    }

    /// Synchronous counterpart of `Stream::download_with_callback`, driven by
    /// `crate::block!`.
    #[cfg(feature = "callback")]
    #[inline]
    pub fn blocking_download_with_callback<'a>(&'a self, callback: Callback<'a>) -> Result<PathBuf> {
        crate::block!(Self::download_with_callback(self, callback))
    }

    /// Synchronous counterpart of [`Stream::download_to_dir`], driven by `crate::block!`.
    #[inline]
    pub fn blocking_download_to_dir<P: AsRef<Path>>(&self, dir: P) -> Result<PathBuf> {
        crate::block!(Self::download_to_dir(self, dir))
    }

    /// Synchronous counterpart of `Stream::download_to_dir_with_callback`, driven by
    /// `crate::block!`.
    #[cfg(feature = "callback")]
    #[inline]
    pub fn blocking_download_to_dir_with_callback<'a, P: AsRef<Path>>(
        &'a self,
        dir: P,
        callback: Callback<'a>,
    ) -> Result<PathBuf> {
        crate::block!(Self::download_to_dir_with_callback(self, dir, callback))
    }

    /// Synchronous counterpart of [`Stream::download_to`], driven by `crate::block!`.
    pub fn blocking_download_to<P: AsRef<Path>>(&self, path: P) -> Result<()> {
        crate::block!(Self::download_to(self, path))
    }

    /// Synchronous counterpart of `Stream::download_to_with_callback`, driven by
    /// `crate::block!`.
    #[cfg(feature = "callback")]
    pub fn blocking_download_to_with_callback<'a, P: AsRef<Path>>(&'a self, path: P, callback: Callback<'a>) -> Result<()> {
        crate::block!(Self::download_to_with_callback(self, path, callback))
    }

    /// Synchronous counterpart of [`Stream::content_length`], driven by `crate::block!`.
    #[inline]
    pub fn blocking_content_length(&self) -> Result<u64> {
        crate::block!(Self::content_length(self))
    }
}
419
/// Returns `true` when the codec list has an odd number of entries (typically a
/// single codec). Such formats are the non-progressive ones, whose tracks are
/// inferred from the MIME type alone.
#[inline]
fn is_adaptive(codecs: &[String]) -> bool {
    (codecs.len() & 1) == 1
}
424
425#[inline]
426fn includes_video_track(codecs: &[String], mime: &Mime) -> bool {
427 is_progressive(codecs) || mime.type_() == "video"
428}
429
430#[inline]
431fn includes_audio_track(codecs: &[String], mime: &Mime) -> bool {
432 is_progressive(codecs) || mime.type_() == "audio"
433}
434
435#[inline]
436fn is_progressive(codecs: &[String]) -> bool {
437 !is_adaptive(codecs)
438}
439
/// Compares the values currently stored in two shared atomics.
///
/// This is the `PartialEq` comparison for `Stream`'s cached content length. Two
/// streams compare equal on that field when the stored lengths match, whether or not
/// they share the same `Arc`.
#[inline]
fn atomic_u64_is_eq(lhs: &Arc<AtomicU64>, rhs: &Arc<AtomicU64>) -> bool {
    let left = lhs.load(Ordering::Acquire);
    let right = rhs.load(Ordering::Acquire);
    left == right
}