Skip to main content

whatsapp_rust/
download.rs

1use crate::client::Client;
2use crate::mediaconn::{MEDIA_AUTH_REFRESH_RETRY_ATTEMPTS, MediaConn, is_media_auth_error};
3use anyhow::{Result, anyhow};
4use std::io::{Seek, SeekFrom, Write};
5
6pub use wacore::download::{
7    DownloadUtils, Downloadable, MediaDecryption, MediaDecryptionError, MediaType,
8};
9
10impl From<&MediaConn> for wacore::download::MediaConnection {
11    fn from(conn: &MediaConn) -> Self {
12        wacore::download::MediaConnection {
13            hosts: conn
14                .hosts
15                .iter()
16                .map(|h| wacore::download::MediaHost {
17                    hostname: h.hostname.clone(),
18                })
19                .collect(),
20            auth: conn.auth.clone(),
21        }
22    }
23}
24
25/// Implements `Downloadable` from raw media parameters.
26struct DownloadParams {
27    direct_path: String,
28    media_key: Option<Vec<u8>>,
29    file_sha256: Vec<u8>,
30    file_enc_sha256: Option<Vec<u8>>,
31    file_length: u64,
32    media_type: MediaType,
33}
34
35impl Downloadable for DownloadParams {
36    fn direct_path(&self) -> Option<&str> {
37        Some(&self.direct_path)
38    }
39    fn media_key(&self) -> Option<&[u8]> {
40        self.media_key.as_deref()
41    }
42    fn file_enc_sha256(&self) -> Option<&[u8]> {
43        self.file_enc_sha256.as_deref()
44    }
45    fn file_sha256(&self) -> Option<&[u8]> {
46        Some(&self.file_sha256)
47    }
48    fn file_length(&self) -> Option<u64> {
49        Some(self.file_length)
50    }
51    fn app_info(&self) -> MediaType {
52        self.media_type
53    }
54}
55
56#[derive(Debug)]
57enum DownloadRequestError {
58    Auth(anyhow::Error),
59    /// 404/410 — media URL expired or not found. Needs fresh auth + URL re-derivation.
60    /// Matches WA Web's `MediaNotFoundError` handling.
61    NotFound(anyhow::Error),
62    Other(anyhow::Error),
63}
64
65impl DownloadRequestError {
66    fn auth(status_code: u16) -> Self {
67        Self::Auth(anyhow!("Download failed with status: {}", status_code))
68    }
69
70    fn not_found(status_code: u16) -> Self {
71        Self::NotFound(anyhow!(
72            "Download media not found/expired with status: {}",
73            status_code
74        ))
75    }
76
77    fn other(err: impl Into<anyhow::Error>) -> Self {
78        Self::Other(err.into())
79    }
80
81    fn is_auth(&self) -> bool {
82        matches!(self, Self::Auth(_))
83    }
84
85    /// Returns true for 404/410 (expired URL) — should trigger auth refresh like auth errors.
86    fn is_not_found(&self) -> bool {
87        matches!(self, Self::NotFound(_))
88    }
89
90    fn into_anyhow(self) -> anyhow::Error {
91        match self {
92            Self::Auth(err) | Self::NotFound(err) | Self::Other(err) => err,
93        }
94    }
95}
96
97async fn download_media_with_retry<
98    PrepareRequests,
99    PrepareRequestsFut,
100    InvalidateMediaConn,
101    InvalidateMediaConnFut,
102    ExecuteRequest,
103    ExecuteRequestFut,
104>(
105    mut prepare_requests: PrepareRequests,
106    mut invalidate_media_conn: InvalidateMediaConn,
107    mut execute_request: ExecuteRequest,
108) -> Result<Vec<u8>>
109where
110    PrepareRequests: FnMut(bool) -> PrepareRequestsFut,
111    PrepareRequestsFut:
112        std::future::Future<Output = Result<Vec<wacore::download::DownloadRequest>>>,
113    InvalidateMediaConn: FnMut() -> InvalidateMediaConnFut,
114    InvalidateMediaConnFut: std::future::Future<Output = ()>,
115    ExecuteRequest: FnMut(wacore::download::DownloadRequest) -> ExecuteRequestFut,
116    ExecuteRequestFut:
117        std::future::Future<Output = std::result::Result<Vec<u8>, DownloadRequestError>>,
118{
119    let mut force_refresh = false;
120    let mut last_err: Option<anyhow::Error> = None;
121
122    for attempt in 0..=MEDIA_AUTH_REFRESH_RETRY_ATTEMPTS {
123        let requests = prepare_requests(force_refresh).await?;
124        let mut retry_with_fresh_auth = false;
125
126        for request in requests {
127            match execute_request(request.clone()).await {
128                Ok(data) => return Ok(data),
129                Err(err) if (err.is_auth() || err.is_not_found()) && attempt == 0 => {
130                    // Auth error or 404/410 (expired URL): refresh media conn and re-derive URLs.
131                    // Matches WA Web's MediaNotFoundError → forceRefresh flow.
132                    invalidate_media_conn().await;
133                    force_refresh = true;
134                    retry_with_fresh_auth = true;
135                    break;
136                }
137                Err(err) if err.is_auth() || err.is_not_found() => return Err(err.into_anyhow()),
138                Err(err) => {
139                    let err = err.into_anyhow();
140                    log::warn!(
141                        "Failed to download from URL {}: {:?}. Trying next host.",
142                        request.url,
143                        err
144                    );
145                    last_err = Some(err);
146                }
147            }
148        }
149
150        if !retry_with_fresh_auth {
151            break;
152        }
153    }
154
155    match last_err {
156        Some(err) => Err(err),
157        None => Err(anyhow!("Failed to download from all available media hosts")),
158    }
159}
160
161async fn download_to_writer_with_retry<
162    W,
163    PrepareRequests,
164    PrepareRequestsFut,
165    InvalidateMediaConn,
166    InvalidateMediaConnFut,
167    ExecuteRequest,
168    ExecuteRequestFut,
169>(
170    mut writer: W,
171    mut prepare_requests: PrepareRequests,
172    mut invalidate_media_conn: InvalidateMediaConn,
173    mut execute_request: ExecuteRequest,
174) -> Result<W>
175where
176    W: Write + Seek + Send + 'static,
177    PrepareRequests: FnMut(bool) -> PrepareRequestsFut,
178    PrepareRequestsFut:
179        std::future::Future<Output = Result<Vec<wacore::download::DownloadRequest>>>,
180    InvalidateMediaConn: FnMut() -> InvalidateMediaConnFut,
181    InvalidateMediaConnFut: std::future::Future<Output = ()>,
182    ExecuteRequest: FnMut(wacore::download::DownloadRequest, W) -> ExecuteRequestFut,
183    ExecuteRequestFut:
184        std::future::Future<Output = Result<(W, std::result::Result<(), DownloadRequestError>)>>,
185{
186    let mut force_refresh = false;
187    let mut last_err: Option<anyhow::Error> = None;
188
189    for attempt in 0..=MEDIA_AUTH_REFRESH_RETRY_ATTEMPTS {
190        let requests = prepare_requests(force_refresh).await?;
191        let mut retry_with_fresh_auth = false;
192
193        for request in requests {
194            let (next_writer, result) = execute_request(request.clone(), writer).await?;
195            writer = next_writer;
196
197            match result {
198                Ok(()) => return Ok(writer),
199                Err(err) if (err.is_auth() || err.is_not_found()) && attempt == 0 => {
200                    invalidate_media_conn().await;
201                    force_refresh = true;
202                    retry_with_fresh_auth = true;
203                    break;
204                }
205                Err(err) if err.is_auth() || err.is_not_found() => return Err(err.into_anyhow()),
206                Err(err) => {
207                    let err = err.into_anyhow();
208                    log::warn!(
209                        "Failed to stream-download from URL {}: {:?}. Trying next host.",
210                        request.url,
211                        err
212                    );
213                    last_err = Some(err);
214                }
215            }
216        }
217
218        if !retry_with_fresh_auth {
219            break;
220        }
221    }
222
223    match last_err {
224        Some(err) => Err(err),
225        None => Err(anyhow!("Failed to download from all available media hosts")),
226    }
227}
228
229impl Client {
230    pub async fn download(&self, downloadable: &dyn Downloadable) -> Result<Vec<u8>> {
231        download_media_with_retry(
232            |force| self.prepare_requests(downloadable, force),
233            || async { self.invalidate_media_conn().await },
234            |request| async move { self.download_with_request(&request).await },
235        )
236        .await
237    }
238
239    pub async fn download_to_file<W: Write + Seek + Send + Unpin>(
240        &self,
241        downloadable: &dyn Downloadable,
242        mut writer: W,
243    ) -> Result<()> {
244        let data = self.download(downloadable).await?;
245        writer.seek(SeekFrom::Start(0))?;
246        writer.write_all(&data)?;
247        Ok(())
248    }
249
250    /// Downloads and decrypts media from raw parameters without needing the original message.
251    pub async fn download_from_params(
252        &self,
253        direct_path: &str,
254        media_key: &[u8],
255        file_sha256: &[u8],
256        file_enc_sha256: &[u8],
257        file_length: u64,
258        media_type: MediaType,
259    ) -> Result<Vec<u8>> {
260        let params = DownloadParams {
261            direct_path: direct_path.to_string(),
262            media_key: Some(media_key.to_vec()),
263            file_sha256: file_sha256.to_vec(),
264            file_enc_sha256: Some(file_enc_sha256.to_vec()),
265            file_length,
266            media_type,
267        };
268        self.download(&params).await
269    }
270
271    async fn prepare_requests(
272        &self,
273        downloadable: &dyn Downloadable,
274        force_refresh: bool,
275    ) -> Result<Vec<wacore::download::DownloadRequest>> {
276        let media_conn = self.refresh_media_conn(force_refresh).await?;
277        let core_media_conn = wacore::download::MediaConnection::from(&media_conn);
278        DownloadUtils::prepare_download_requests(downloadable, &core_media_conn)
279    }
280
281    async fn download_with_request(
282        &self,
283        request: &wacore::download::DownloadRequest,
284    ) -> std::result::Result<Vec<u8>, DownloadRequestError> {
285        let url = request.url.clone();
286        let decryption = request.decryption.clone();
287        let http_request = crate::http::HttpRequest::get(url);
288        let response = self
289            .http_client
290            .execute(http_request)
291            .await
292            .map_err(DownloadRequestError::other)?;
293
294        if response.status_code >= 300 {
295            return Err(if is_media_auth_error(response.status_code) {
296                DownloadRequestError::auth(response.status_code)
297            } else if matches!(response.status_code, 404 | 410) {
298                DownloadRequestError::not_found(response.status_code)
299            } else {
300                DownloadRequestError::other(anyhow!(
301                    "Download failed with status: {}",
302                    response.status_code
303                ))
304            });
305        }
306
307        match decryption {
308            MediaDecryption::Encrypted {
309                media_key,
310                media_type,
311            } => wacore::runtime::blocking(&*self.runtime, move || {
312                DownloadUtils::decrypt_stream(&response.body[..], &media_key, media_type)
313            })
314            .await
315            .map_err(DownloadRequestError::other),
316            MediaDecryption::Plaintext { file_sha256 } => {
317                let body = response.body;
318                wacore::runtime::blocking(&*self.runtime, move || {
319                    DownloadUtils::validate_plaintext_sha256(&body, &file_sha256)?;
320                    Ok::<Vec<u8>, anyhow::Error>(body)
321                })
322                .await
323                .map_err(DownloadRequestError::other)
324            }
325        }
326    }
327
328    /// Downloads and decrypts media with streaming (constant memory usage).
329    ///
330    /// The entire HTTP download, decryption, and file write happen in a single
331    /// blocking thread. The writer is seeked back to position 0 before returning.
332    ///
333    /// Memory usage: ~40KB regardless of file size (8KB read buffer + decrypt state).
334    pub async fn download_to_writer<W: Write + Seek + Send + 'static>(
335        &self,
336        downloadable: &dyn Downloadable,
337        writer: W,
338    ) -> Result<W> {
339        download_to_writer_with_retry(
340            writer,
341            |force| self.prepare_requests(downloadable, force),
342            || async { self.invalidate_media_conn().await },
343            |request, writer| async move { self.streaming_download_and_decrypt(&request, writer).await },
344        )
345        .await
346    }
347
348    /// Streaming variant of `download_from_params` that writes to a writer
349    /// instead of buffering in memory.
350    #[allow(clippy::too_many_arguments)]
351    pub async fn download_from_params_to_writer<W: Write + Seek + Send + 'static>(
352        &self,
353        direct_path: &str,
354        media_key: &[u8],
355        file_sha256: &[u8],
356        file_enc_sha256: &[u8],
357        file_length: u64,
358        media_type: MediaType,
359        writer: W,
360    ) -> Result<W> {
361        let params = DownloadParams {
362            direct_path: direct_path.to_string(),
363            media_key: Some(media_key.to_vec()),
364            file_sha256: file_sha256.to_vec(),
365            file_enc_sha256: Some(file_enc_sha256.to_vec()),
366            file_length,
367            media_type,
368        };
369        self.download_to_writer(&params, writer).await
370    }
371
372    /// Internal: stream download + decrypt to a writer in one blocking thread.
373    /// Always returns the writer (even on failure) so the caller can retry.
374    async fn streaming_download_and_decrypt<W: Write + Seek + Send + 'static>(
375        &self,
376        request: &wacore::download::DownloadRequest,
377        writer: W,
378    ) -> Result<(W, std::result::Result<(), DownloadRequestError>)> {
379        let http_client = self.http_client.clone();
380        let url = request.url.clone();
381        let decryption = request.decryption.clone();
382
383        wacore::runtime::blocking(&*self.runtime, move || {
384            let mut writer = writer;
385
386            // Seek to start before each attempt so retries start fresh
387            if let Err(e) = writer.seek(SeekFrom::Start(0)) {
388                return Ok((writer, Err(DownloadRequestError::other(e))));
389            }
390
391            let result = (|| -> std::result::Result<(), DownloadRequestError> {
392                let http_request = crate::http::HttpRequest::get(url);
393                let resp = http_client
394                    .execute_streaming(http_request)
395                    .map_err(DownloadRequestError::other)?;
396
397                if resp.status_code >= 300 {
398                    return Err(if is_media_auth_error(resp.status_code) {
399                        DownloadRequestError::auth(resp.status_code)
400                    } else if matches!(resp.status_code, 404 | 410) {
401                        DownloadRequestError::not_found(resp.status_code)
402                    } else {
403                        DownloadRequestError::other(anyhow!(
404                            "Download failed with status: {}",
405                            resp.status_code
406                        ))
407                    });
408                }
409
410                match &decryption {
411                    MediaDecryption::Encrypted {
412                        media_key,
413                        media_type,
414                    } => {
415                        DownloadUtils::decrypt_stream_to_writer(
416                            resp.body,
417                            media_key,
418                            *media_type,
419                            &mut writer,
420                        )
421                        .map_err(DownloadRequestError::other)?;
422                    }
423                    MediaDecryption::Plaintext { file_sha256 } => {
424                        DownloadUtils::copy_and_validate_plaintext_to_writer(
425                            resp.body,
426                            file_sha256,
427                            &mut writer,
428                        )
429                        .map_err(DownloadRequestError::other)?;
430                    }
431                }
432                writer
433                    .seek(SeekFrom::Start(0))
434                    .map_err(DownloadRequestError::other)?;
435                Ok(())
436            })();
437
438            Ok((writer, result))
439        })
440        .await
441    }
442}
443
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mediaconn::{MediaConn, MediaConnHost};
    use async_lock::Mutex;
    use std::io::Cursor;
    use std::sync::Arc;
    use wacore::time::Instant;

    /// Test `Downloadable` for unencrypted media. It has no media key and no
    /// encrypted hash, only the plaintext SHA-256 used for validation.
    struct PlaintextDownloadable {
        direct_path: String,
        file_sha256: Vec<u8>,
    }

    impl Downloadable for PlaintextDownloadable {
        fn direct_path(&self) -> Option<&str> {
            Some(&self.direct_path)
        }

        fn media_key(&self) -> Option<&[u8]> {
            None
        }

        fn file_enc_sha256(&self) -> Option<&[u8]> {
            None
        }

        fn file_sha256(&self) -> Option<&[u8]> {
            Some(&self.file_sha256)
        }

        fn file_length(&self) -> Option<u64> {
            None
        }

        fn app_info(&self) -> MediaType {
            MediaType::Image
        }
    }

    /// Builds a `MediaConn` with the given auth token and host names.
    fn media_conn(auth: &str, hosts: &[&str]) -> MediaConn {
        MediaConn {
            auth: auth.to_string(),
            ttl: 60,
            auth_ttl: None,
            hosts: hosts
                .iter()
                .map(|hostname| MediaConnHost::new((*hostname).to_string()))
                .collect(),
            fetched_at: Instant::now(),
        }
    }

    /// Plaintext SHA-256 of `data`, taken from the upload encryption helper.
    fn plaintext_sha256(data: &[u8]) -> Vec<u8> {
        wacore::upload::encrypt_media(data, MediaType::Image)
            .expect("hash derivation should succeed")
            .file_sha256
            .to_vec()
    }

    #[test]
    fn process_downloaded_media_ok() {
        let data = b"Hello media test";
        let enc = wacore::upload::encrypt_media(data, MediaType::Image)
            .expect("encryption should succeed");
        let plaintext = DownloadUtils::verify_and_decrypt(
            &enc.data_to_upload,
            &enc.media_key,
            MediaType::Image,
        )
        .expect("decryption should succeed");
        assert_eq!(&plaintext[..], &data[..]);
    }

    #[test]
    fn process_downloaded_media_bad_mac() {
        let data = b"Tamper";
        let mut enc = wacore::upload::encrypt_media(data, MediaType::Image)
            .expect("encryption should succeed");
        let last = enc.data_to_upload.len() - 1;
        enc.data_to_upload[last] ^= 0x01;

        let err = DownloadUtils::verify_and_decrypt(
            &enc.data_to_upload,
            &enc.media_key,
            MediaType::Image,
        )
        .unwrap_err();

        assert!(
            matches!(&err, wacore::download::MediaDecryptionError::InvalidMac),
            "Expected InvalidMac, got: {}",
            err
        );
    }

    #[tokio::test]
    async fn download_retries_with_forced_media_conn_refresh_after_auth_error() {
        let body = b"download me".to_vec();
        let downloadable = PlaintextDownloadable {
            direct_path: "/v/t62.7118-24/123".to_string(),
            file_sha256: plaintext_sha256(&body),
        };
        let first_conn = media_conn("stale-auth", &["cdn1.example.com"]);
        let refreshed_conn = media_conn("fresh-auth", &["cdn2.example.com"]);
        let refresh_calls = Arc::new(Mutex::new(Vec::new()));
        let invalidations = Arc::new(Mutex::new(0usize));
        let seen_urls = Arc::new(Mutex::new(Vec::new()));

        let downloaded = download_media_with_retry(
            {
                let refresh_calls = Arc::clone(&refresh_calls);
                let downloadable = &downloadable;
                move |force| {
                    let refresh_calls = Arc::clone(&refresh_calls);
                    let first_conn = first_conn.clone();
                    let refreshed_conn = refreshed_conn.clone();
                    async move {
                        refresh_calls.lock().await.push(force);
                        let media_conn = if force { refreshed_conn } else { first_conn };
                        DownloadUtils::prepare_download_requests(
                            downloadable,
                            &wacore::download::MediaConnection::from(&media_conn),
                        )
                    }
                }
            },
            {
                let invalidations = Arc::clone(&invalidations);
                move || {
                    let invalidations = Arc::clone(&invalidations);
                    async move {
                        *invalidations.lock().await += 1;
                    }
                }
            },
            {
                let seen_urls = Arc::clone(&seen_urls);
                let body = body.clone();
                move |request| {
                    let seen_urls = Arc::clone(&seen_urls);
                    let body = body.clone();
                    let url = request.url.clone();
                    async move {
                        seen_urls.lock().await.push(url.clone());
                        if url.contains("stale-auth") {
                            Err(DownloadRequestError::auth(401))
                        } else {
                            Ok(body)
                        }
                    }
                }
            },
        )
        .await
        .expect("download should succeed after refreshing media auth");

        assert_eq!(downloaded, body);
        assert_eq!(*refresh_calls.lock().await, vec![false, true]);
        assert_eq!(*invalidations.lock().await, 1);

        let seen_urls = seen_urls.lock().await.clone();
        assert_eq!(seen_urls.len(), 2);
        assert!(seen_urls[0].contains("auth=stale-auth"));
        assert!(seen_urls[1].contains("auth=fresh-auth"));
    }

    /// A 404/410 (expired URL) must trigger the same invalidate + forced
    /// refresh flow as an auth error.
    #[tokio::test]
    async fn download_retries_with_forced_media_conn_refresh_after_not_found() {
        let body = b"expired url".to_vec();
        let downloadable = PlaintextDownloadable {
            direct_path: "/v/t62.7118-24/expired".to_string(),
            file_sha256: plaintext_sha256(&body),
        };
        let first_conn = media_conn("stale-auth", &["cdn1.example.com"]);
        let refreshed_conn = media_conn("fresh-auth", &["cdn2.example.com"]);
        let refresh_calls = Arc::new(Mutex::new(Vec::new()));
        let invalidations = Arc::new(Mutex::new(0usize));

        let downloaded = download_media_with_retry(
            {
                let refresh_calls = Arc::clone(&refresh_calls);
                let downloadable = &downloadable;
                move |force| {
                    let refresh_calls = Arc::clone(&refresh_calls);
                    let first_conn = first_conn.clone();
                    let refreshed_conn = refreshed_conn.clone();
                    async move {
                        refresh_calls.lock().await.push(force);
                        let media_conn = if force { refreshed_conn } else { first_conn };
                        DownloadUtils::prepare_download_requests(
                            downloadable,
                            &wacore::download::MediaConnection::from(&media_conn),
                        )
                    }
                }
            },
            {
                let invalidations = Arc::clone(&invalidations);
                move || {
                    let invalidations = Arc::clone(&invalidations);
                    async move {
                        *invalidations.lock().await += 1;
                    }
                }
            },
            {
                let body = body.clone();
                move |request| {
                    let body = body.clone();
                    let url = request.url.clone();
                    async move {
                        if url.contains("stale-auth") {
                            Err(DownloadRequestError::not_found(404))
                        } else {
                            Ok(body)
                        }
                    }
                }
            },
        )
        .await
        .expect("download should succeed after re-deriving the expired URL");

        assert_eq!(downloaded, body);
        assert_eq!(*refresh_calls.lock().await, vec![false, true]);
        assert_eq!(*invalidations.lock().await, 1);
    }

    #[tokio::test]
    async fn download_to_writer_retries_with_forced_media_conn_refresh_after_auth_error() {
        let body = b"stream me".to_vec();
        let downloadable = PlaintextDownloadable {
            direct_path: "/v/t62.7118-24/stream".to_string(),
            file_sha256: plaintext_sha256(&body),
        };
        let first_conn = media_conn("stale-auth", &["cdn1.example.com"]);
        let refreshed_conn = media_conn("fresh-auth", &["cdn2.example.com"]);
        let refresh_calls = Arc::new(Mutex::new(Vec::new()));
        let invalidations = Arc::new(Mutex::new(0usize));
        let seen_urls = Arc::new(Mutex::new(Vec::new()));

        let writer = download_to_writer_with_retry(
            Cursor::new(Vec::<u8>::new()),
            {
                let refresh_calls = Arc::clone(&refresh_calls);
                let downloadable = &downloadable;
                move |force| {
                    let refresh_calls = Arc::clone(&refresh_calls);
                    let first_conn = first_conn.clone();
                    let refreshed_conn = refreshed_conn.clone();
                    async move {
                        refresh_calls.lock().await.push(force);
                        let media_conn = if force { refreshed_conn } else { first_conn };
                        DownloadUtils::prepare_download_requests(
                            downloadable,
                            &wacore::download::MediaConnection::from(&media_conn),
                        )
                    }
                }
            },
            {
                let invalidations = Arc::clone(&invalidations);
                move || {
                    let invalidations = Arc::clone(&invalidations);
                    async move {
                        *invalidations.lock().await += 1;
                    }
                }
            },
            {
                let seen_urls = Arc::clone(&seen_urls);
                let body = body.clone();
                move |request, mut writer| {
                    let seen_urls = Arc::clone(&seen_urls);
                    let body = body.clone();
                    let url = request.url.clone();
                    async move {
                        seen_urls.lock().await.push(url.clone());
                        writer.seek(SeekFrom::Start(0))?;
                        if url.contains("stale-auth") {
                            Ok((writer, Err(DownloadRequestError::auth(403))))
                        } else {
                            writer.write_all(&body)?;
                            writer.seek(SeekFrom::Start(0))?;
                            Ok((writer, Ok(())))
                        }
                    }
                }
            },
        )
        .await
        .expect("streaming download should succeed after refreshing media auth");

        assert_eq!(writer.into_inner(), body);
        assert_eq!(*refresh_calls.lock().await, vec![false, true]);
        assert_eq!(*invalidations.lock().await, 1);

        let seen_urls = seen_urls.lock().await.clone();
        assert_eq!(seen_urls.len(), 2);
        assert!(seen_urls[0].contains("auth=stale-auth"));
        assert!(seen_urls[1].contains("auth=fresh-auth"));
    }
}