// fast_pull/reqwest/prefetch.rs

1extern crate alloc;
2use crate::{UrlInfo, reqwest::ReqwestPuller};
3use alloc::string::{String, ToString};
4use content_disposition;
5use reqwest::{
6    Client, IntoUrl, StatusCode, Url,
7    header::{self, HeaderMap},
8};
9use sanitize_filename;
10
11fn get_file_size(headers: &HeaderMap, status: &StatusCode) -> u64 {
12    if *status == StatusCode::PARTIAL_CONTENT {
13        headers
14            .get(header::CONTENT_RANGE)
15            .and_then(|hv| hv.to_str().ok())
16            .and_then(|s| s.rsplit('/').next())
17            .and_then(|total| total.parse().ok())
18            .unwrap_or(0)
19    } else {
20        headers
21            .get(header::CONTENT_LENGTH)
22            .and_then(|hv| hv.to_str().ok())
23            .and_then(|s| s.parse().ok())
24            .unwrap_or(0)
25    }
26}
27
28fn get_header_str(headers: &HeaderMap, header_name: &header::HeaderName) -> Option<String> {
29    headers
30        .get(header_name)
31        .and_then(|hv| hv.to_str().ok())
32        .map(String::from)
33}
34
35fn get_filename(headers: &HeaderMap, final_url: &Url) -> String {
36    let from_disposition = headers
37        .get(header::CONTENT_DISPOSITION)
38        .and_then(|hv| hv.to_str().ok())
39        .and_then(|s| content_disposition::parse_content_disposition(s).filename_full())
40        .filter(|s| !s.trim().is_empty());
41
42    let from_url = final_url
43        .path_segments()
44        .and_then(|mut segments| segments.next_back())
45        .and_then(|s| urlencoding::decode(s).ok())
46        .filter(|s| !s.trim().is_empty())
47        .map(|s| s.to_string());
48
49    let raw_name = from_disposition
50        .or(from_url)
51        .unwrap_or_else(|| final_url.to_string());
52
53    sanitize_filename::sanitize_with_options(
54        &raw_name,
55        sanitize_filename::Options {
56            windows: true,
57            truncate: true,
58            replacement: "_",
59        },
60    )
61}
62
/// Gathers download metadata ([`UrlInfo`]) for a URL before the actual download starts.
///
/// In this module it is implemented for [`Client`] (a `HEAD` probe with a ranged-`GET`
/// fallback) and for [`ReqwestPuller`] (which delegates to its `client` field).
pub trait Prefetch {
    /// Resolves `url` and returns the metadata reported by the server.
    ///
    /// The returned future is required to be `Send`.
    ///
    /// # Errors
    ///
    /// Implementations in this module return a [`reqwest::Error`] when `url` cannot be
    /// converted into a URL or when the probe request fails or yields an error status.
    fn prefetch(
        &self,
        url: impl IntoUrl + Send,
    ) -> impl Future<Output = Result<UrlInfo, reqwest::Error>> + Send;
}
69
70impl Prefetch for Client {
71    async fn prefetch(&self, url: impl IntoUrl + Send) -> Result<UrlInfo, reqwest::Error> {
72        let url = url.into_url()?;
73        let resp = match self.head(url.clone()).send().await {
74            Ok(resp) => resp,
75            Err(_) => return prefetch_fallback(url, self).await,
76        };
77        let resp = match resp.error_for_status() {
78            Ok(resp) => resp,
79            Err(_) => return prefetch_fallback(url, self).await,
80        };
81
82        let status = resp.status();
83        let final_url = resp.url();
84
85        let resp_headers = resp.headers();
86        let size = get_file_size(resp_headers, &status);
87
88        let supports_range = match resp.headers().get(header::ACCEPT_RANGES) {
89            Some(accept_ranges) => accept_ranges
90                .to_str()
91                .ok()
92                .map(|v| v.split(' '))
93                .and_then(|supports| supports.into_iter().find(|&ty| ty == "bytes"))
94                .is_some(),
95            None => return prefetch_fallback(url, self).await,
96        };
97
98        Ok(UrlInfo {
99            final_url: final_url.clone(),
100            name: get_filename(resp_headers, final_url),
101            size,
102            supports_range,
103            fast_download: size > 0 && supports_range,
104            etag: get_header_str(resp_headers, &header::ETAG),
105            last_modified: get_header_str(resp_headers, &header::LAST_MODIFIED),
106        })
107    }
108}
109
110impl Prefetch for ReqwestPuller {
111    fn prefetch(
112        &self,
113        url: impl IntoUrl + Send,
114    ) -> impl Future<Output = Result<UrlInfo, reqwest::Error>> + Send {
115        self.client.prefetch(url)
116    }
117}
118
119async fn prefetch_fallback(url: Url, client: &Client) -> Result<UrlInfo, reqwest::Error> {
120    let resp = client
121        .get(url)
122        .header(header::RANGE, "bytes=0-")
123        .send()
124        .await?
125        .error_for_status()?;
126    let status = resp.status();
127    let final_url = resp.url();
128
129    let resp_headers = resp.headers();
130    let size = get_file_size(resp_headers, &status);
131    let supports_range = status == StatusCode::PARTIAL_CONTENT;
132    Ok(UrlInfo {
133        final_url: final_url.clone(),
134        name: get_filename(resp_headers, final_url),
135        size,
136        supports_range,
137        fast_download: size > 0 && supports_range,
138        etag: get_header_str(resp_headers, &header::ETAG),
139        last_modified: get_header_str(resp_headers, &header::LAST_MODIFIED),
140    })
141}
142
#[cfg(test)]
mod tests {
    use alloc::{format, vec};

    use super::*;

    /// Appends `path` (which starts with `/`) to the mock server's base URL.
    fn endpoint(base: &str, path: &str) -> String {
        format!("{base}{path}")
    }

    #[tokio::test]
    async fn test_redirect_and_content_range() {
        let mut server = mockito::Server::new_async().await;
        let base = server.url();

        // The redirect target answers with a partial body whose Content-Range carries the
        // full length.
        let redirect = server
            .mock("GET", "/redirect")
            .with_status(301)
            .with_header("Location", "/real-file.txt")
            .create_async()
            .await;
        let target = server
            .mock("GET", "/real-file.txt")
            .with_status(206)
            .with_header("Content-Range", "bytes 0-1023/2048")
            .with_body(vec![0; 1024])
            .create_async()
            .await;

        let info = Client::new()
            .prefetch(&endpoint(&base, "/redirect"))
            .await
            .expect("Request should succeed");

        assert_eq!(info.size, 2048);
        assert_eq!(info.name, "real-file.txt");
        assert_eq!(info.final_url.as_str(), endpoint(&base, "/real-file.txt"));
        assert!(info.supports_range);

        redirect.assert_async().await;
        target.assert_async().await;
    }

    #[tokio::test]
    async fn test_content_range_priority() {
        let mut server = mockito::Server::new_async().await;
        let base = server.url();
        let partial = server
            .mock("GET", "/file")
            .with_status(206)
            .with_header("Content-Range", "bytes 0-1023/2048")
            .create_async()
            .await;

        let info = Client::new()
            .prefetch(&endpoint(&base, "/file"))
            .await
            .expect("Request should succeed");

        assert_eq!(info.size, 2048);
        partial.assert_async().await;
    }

    #[tokio::test]
    async fn test_filename_sources() {
        let mut server = mockito::Server::new_async().await;
        let base = server.url();

        // Content-Disposition provides the name.
        let disposition = server
            .mock("GET", "/test1")
            .with_header("Content-Disposition", "attachment; filename=\"test.txt\"")
            .create_async()
            .await;
        let info = Client::new()
            .prefetch(&endpoint(&base, "/test1"))
            .await
            .unwrap();
        assert_eq!(info.name, "test.txt");
        disposition.assert_async().await;

        // Without Content-Disposition, the last URL path segment is used.
        let path_only = server.mock("GET", "/test2/file.pdf").create_async().await;
        let info = Client::new()
            .prefetch(&endpoint(&base, "/test2/file.pdf"))
            .await
            .unwrap();
        assert_eq!(info.name, "file.pdf");
        path_only.assert_async().await;

        // A percent-encoded `filename*` is decoded and then sanitized.
        let unsafe_name = server
            .mock("GET", "/test3")
            .with_header(
                "Content-Disposition",
                "attachment; filename*=UTF-8''%E6%82%AA%E3%81%84%3C%3E%E3%83%95%E3%82%A1%E3%82%A4%E3%83%AB%3F%E5%90%8D.txt",
            )
            .create_async()
            .await;
        let info = Client::new()
            .prefetch(&endpoint(&base, "/test3"))
            .await
            .unwrap();
        assert_eq!(info.name, "悪い__ファイル_名.txt");
        unsafe_name.assert_async().await;
    }

    #[tokio::test]
    async fn test_error_handling() {
        let mut server = mockito::Server::new_async().await;
        let not_found = server
            .mock("GET", "/404")
            .with_status(404)
            .create_async()
            .await;

        let outcome = Client::new()
            .prefetch(&endpoint(&server.url(), "/404"))
            .await;
        let err = match outcome {
            Ok(info) => panic!("404 status code should not success: {info:?}"),
            Err(err) => err,
        };
        assert!(err.is_status(), "should be error about status code");
        assert_eq!(err.status(), Some(StatusCode::NOT_FOUND));

        not_found.assert_async().await;
    }
}