// fast_pull/reqwest/prefetch.rs

1use crate::{UrlInfo, reqwest::ReqwestReader};
2use content_disposition;
3use reqwest::{
4    Client, IntoUrl, StatusCode, Url,
5    header::{self, HeaderMap},
6};
7use sanitize_filename;
8
9fn get_file_size(headers: &HeaderMap, status: &StatusCode) -> u64 {
10    if *status == StatusCode::PARTIAL_CONTENT {
11        headers
12            .get(header::CONTENT_RANGE)
13            .and_then(|hv| hv.to_str().ok())
14            .and_then(|s| s.rsplit('/').next())
15            .and_then(|total| total.parse().ok())
16            .unwrap_or(0)
17    } else {
18        headers
19            .get(header::CONTENT_LENGTH)
20            .and_then(|hv| hv.to_str().ok())
21            .and_then(|s| s.parse().ok())
22            .unwrap_or(0)
23    }
24}
25
26fn get_header_str(headers: &HeaderMap, header_name: &header::HeaderName) -> Option<String> {
27    headers
28        .get(header_name)
29        .and_then(|hv| hv.to_str().ok())
30        .map(String::from)
31}
32
33fn get_filename(headers: &HeaderMap, final_url: &Url) -> String {
34    let from_disposition = headers
35        .get(header::CONTENT_DISPOSITION)
36        .and_then(|hv| hv.to_str().ok())
37        .and_then(|s| content_disposition::parse_content_disposition(s).filename_full())
38        .filter(|s| !s.trim().is_empty());
39
40    let from_url = final_url
41        .path_segments()
42        .and_then(|mut segments| segments.next_back())
43        .and_then(|s| urlencoding::decode(s).ok())
44        .filter(|s| !s.trim().is_empty())
45        .map(|s| s.to_string());
46
47    let raw_name = from_disposition
48        .or(from_url)
49        .unwrap_or_else(|| final_url.to_string());
50
51    sanitize_filename::sanitize_with_options(
52        &raw_name,
53        sanitize_filename::Options {
54            windows: true,
55            truncate: true,
56            replacement: "_",
57        },
58    )
59}
60
61pub trait Prefetch {
62    fn prefetch(
63        &self,
64        url: impl IntoUrl + Send,
65    ) -> impl Future<Output = Result<UrlInfo, reqwest::Error>> + Send;
66}
67
68impl Prefetch for Client {
69    async fn prefetch(&self, url: impl IntoUrl + Send) -> Result<UrlInfo, reqwest::Error> {
70        let url = url.into_url()?;
71        let resp = self.head(url.clone()).send().await?;
72        let resp = match resp.error_for_status() {
73            Ok(resp) => resp,
74            Err(_) => return prefetch_fallback(url, self).await,
75        };
76
77        let status = resp.status();
78        let final_url = resp.url();
79
80        let resp_headers = resp.headers();
81        let size = get_file_size(resp_headers, &status);
82
83        let supports_range = match resp.headers().get(header::ACCEPT_RANGES) {
84            Some(accept_ranges) => accept_ranges
85                .to_str()
86                .ok()
87                .map(|v| v.split(' '))
88                .and_then(|supports| supports.into_iter().find(|&ty| ty == "bytes"))
89                .is_some(),
90            None => return prefetch_fallback(url, self).await,
91        };
92
93        Ok(UrlInfo {
94            final_url: final_url.clone(),
95            name: Some(get_filename(resp_headers, final_url)),
96            size,
97            supports_range,
98            fast_download: size > 0 && supports_range,
99            etag: get_header_str(resp_headers, &header::ETAG),
100            last_modified: get_header_str(resp_headers, &header::LAST_MODIFIED),
101        })
102    }
103}
104
105impl Prefetch for ReqwestReader {
106    fn prefetch(
107        &self,
108        url: impl IntoUrl + Send,
109    ) -> impl Future<Output = Result<UrlInfo, reqwest::Error>> + Send {
110        self.client.prefetch(url)
111    }
112}
113
114async fn prefetch_fallback(url: Url, client: &Client) -> Result<UrlInfo, reqwest::Error> {
115    let resp = client
116        .get(url)
117        .header(header::RANGE, "bytes=0-")
118        .send()
119        .await?
120        .error_for_status()?;
121    let status = resp.status();
122    let final_url = resp.url();
123
124    let resp_headers = resp.headers();
125    let size = get_file_size(resp_headers, &status);
126    let supports_range = status == StatusCode::PARTIAL_CONTENT;
127    Ok(UrlInfo {
128        final_url: final_url.clone(),
129        name: Some(get_filename(resp_headers, final_url)),
130        size,
131        supports_range,
132        fast_download: size > 0 && supports_range,
133        etag: get_header_str(resp_headers, &header::ETAG),
134        last_modified: get_header_str(resp_headers, &header::LAST_MODIFIED),
135    })
136}
137
#[cfg(test)]
mod tests {
    //! Integration-style tests against a local `mockito` server.
    //!
    //! Only `GET` routes are mocked; each test asserts that its `GET` mock was
    //! hit, so these tests exercise the ranged-`GET` fallback path of
    //! `prefetch` (presumably because the un-mocked `HEAD` probe is answered
    //! with an error status by the mock server).

    use super::*;

    /// A redirect is followed and the total size is taken from `Content-Range`.
    #[tokio::test]
    async fn test_redirect_and_content_range() {
        let mut server = mockito::Server::new_async().await;

        let redirect_mock = server
            .mock("GET", "/redirect")
            .with_status(301)
            .with_header("Location", "/real-file.txt")
            .create_async()
            .await;

        let file_mock = server
            .mock("GET", "/real-file.txt")
            .with_status(206)
            .with_header("Content-Range", "bytes 0-1023/2048")
            .with_body(vec![0; 1024])
            .create_async()
            .await;

        let target = format!("{}/redirect", server.url());
        let client = Client::new();
        let info = client
            .prefetch(&target)
            .await
            .expect("Request should succeed");

        assert_eq!(info.size, 2048);
        assert_eq!(info.name.as_deref(), Some("real-file.txt"));
        assert_eq!(
            info.final_url.as_str(),
            format!("{}/real-file.txt", server.url())
        );
        assert!(info.supports_range);

        redirect_mock.assert_async().await;
        file_mock.assert_async().await;
    }

    /// On a 206 response the size comes from the `Content-Range` total.
    #[tokio::test]
    async fn test_content_range_priority() {
        let mut server = mockito::Server::new_async().await;
        let file_mock = server
            .mock("GET", "/file")
            .with_status(206)
            .with_header("Content-Range", "bytes 0-1023/2048")
            .create_async()
            .await;

        let target = format!("{}/file", server.url());
        let client = Client::new();
        let info = client
            .prefetch(&target)
            .await
            .expect("Request should succeed");

        assert_eq!(info.size, 2048);
        file_mock.assert_async().await;
    }

    /// The file name comes from `Content-Disposition`, else the URL path, and
    /// is sanitized.
    #[tokio::test]
    async fn test_filename_sources() {
        let mut server = mockito::Server::new_async().await;

        // Source 1: Content-Disposition header.
        let disposition_mock = server
            .mock("GET", "/test1")
            .with_header("Content-Disposition", "attachment; filename=\"test.txt\"")
            .create_async()
            .await;
        let target = format!("{}/test1", server.url());
        let info = Client::new().prefetch(&target).await.unwrap();
        assert_eq!(info.name.as_deref(), Some("test.txt"));
        disposition_mock.assert_async().await;

        // Source 2: last segment of the URL path.
        let path_mock = server.mock("GET", "/test2/file.pdf").create_async().await;
        let target = format!("{}/test2/file.pdf", server.url());
        let info = Client::new().prefetch(&target).await.unwrap();
        assert_eq!(info.name.as_deref(), Some("file.pdf"));
        path_mock.assert_async().await;

        // Sanitization: the encoded name contains `<`, `>` and `?`, each of
        // which is expected to become `_`.
        let sanitize_mock = server
            .mock("GET", "/test3")
            .with_header(
                "Content-Disposition",
                "attachment; filename*=UTF-8''%E6%82%AA%E3%81%84%3C%3E%E3%83%95%E3%82%A1%E3%82%A4%E3%83%AB%3F%E5%90%8D.txt",
            )
            .create_async()
            .await;
        let target = format!("{}/test3", server.url());
        let info = Client::new().prefetch(&target).await.unwrap();
        assert_eq!(info.name.as_deref(), Some("悪い__ファイル_名.txt"));
        sanitize_mock.assert_async().await;
    }

    /// A 404 from the server surfaces as a status error.
    #[tokio::test]
    async fn test_error_handling() {
        let mut server = mockito::Server::new_async().await;
        let not_found_mock = server
            .mock("GET", "/404")
            .with_status(404)
            .create_async()
            .await;

        let target = format!("{}/404", server.url());
        let client = Client::new();

        match client.prefetch(&target).await {
            Err(err) => {
                assert!(err.is_status(), "should be error about status code");
                assert_eq!(err.status(), Some(StatusCode::NOT_FOUND));
            }
            Ok(info) => panic!("404 status code should not success: {info:?}"),
        }

        not_found_mock.assert_async().await;
    }
}
263}