// fast_pull/reqwest/prefetch.rs
use crate::{UrlInfo, reqwest::ReqwestReader};
2use content_disposition;
3use reqwest::{
4 Client, IntoUrl, StatusCode, Url,
5 header::{self, HeaderMap},
6};
7use sanitize_filename;
8
9fn get_file_size(headers: &HeaderMap, status: &StatusCode) -> u64 {
10 if *status == StatusCode::PARTIAL_CONTENT {
11 headers
12 .get(header::CONTENT_RANGE)
13 .and_then(|hv| hv.to_str().ok())
14 .and_then(|s| s.rsplit('/').next())
15 .and_then(|total| total.parse().ok())
16 .unwrap_or(0)
17 } else {
18 headers
19 .get(header::CONTENT_LENGTH)
20 .and_then(|hv| hv.to_str().ok())
21 .and_then(|s| s.parse().ok())
22 .unwrap_or(0)
23 }
24}
25
26fn get_header_str(headers: &HeaderMap, header_name: &header::HeaderName) -> Option<String> {
27 headers
28 .get(header_name)
29 .and_then(|hv| hv.to_str().ok())
30 .map(String::from)
31}
32
33fn get_filename(headers: &HeaderMap, final_url: &Url) -> String {
34 let from_disposition = headers
35 .get(header::CONTENT_DISPOSITION)
36 .and_then(|hv| hv.to_str().ok())
37 .and_then(|s| content_disposition::parse_content_disposition(s).filename_full())
38 .filter(|s| !s.trim().is_empty());
39
40 let from_url = final_url
41 .path_segments()
42 .and_then(|mut segments| segments.next_back())
43 .and_then(|s| urlencoding::decode(s).ok())
44 .filter(|s| !s.trim().is_empty())
45 .map(|s| s.to_string());
46
47 let raw_name = from_disposition
48 .or(from_url)
49 .unwrap_or_else(|| final_url.to_string());
50
51 sanitize_filename::sanitize_with_options(
52 &raw_name,
53 sanitize_filename::Options {
54 windows: true,
55 truncate: true,
56 replacement: "_",
57 },
58 )
59}
60
61pub trait Prefetch {
62 fn prefetch(
63 &self,
64 url: impl IntoUrl + Send,
65 ) -> impl Future<Output = Result<UrlInfo, reqwest::Error>> + Send;
66}
67
68impl Prefetch for Client {
69 async fn prefetch(&self, url: impl IntoUrl + Send) -> Result<UrlInfo, reqwest::Error> {
70 let url = url.into_url()?;
71 let resp = self.head(url.clone()).send().await?;
72 let resp = match resp.error_for_status() {
73 Ok(resp) => resp,
74 Err(_) => return prefetch_fallback(url, self).await,
75 };
76
77 let status = resp.status();
78 let final_url = resp.url();
79
80 let resp_headers = resp.headers();
81 let size = get_file_size(resp_headers, &status);
82
83 let supports_range = match resp.headers().get(header::ACCEPT_RANGES) {
84 Some(accept_ranges) => accept_ranges
85 .to_str()
86 .ok()
87 .map(|v| v.split(' '))
88 .and_then(|supports| supports.into_iter().find(|&ty| ty == "bytes"))
89 .is_some(),
90 None => return prefetch_fallback(url, self).await,
91 };
92
93 Ok(UrlInfo {
94 final_url: final_url.clone(),
95 name: Some(get_filename(resp_headers, final_url)),
96 size,
97 supports_range,
98 fast_download: size > 0 && supports_range,
99 etag: get_header_str(resp_headers, &header::ETAG),
100 last_modified: get_header_str(resp_headers, &header::LAST_MODIFIED),
101 })
102 }
103}
104
105impl Prefetch for ReqwestReader {
106 fn prefetch(
107 &self,
108 url: impl IntoUrl + Send,
109 ) -> impl Future<Output = Result<UrlInfo, reqwest::Error>> + Send {
110 self.client.prefetch(url)
111 }
112}
113
114async fn prefetch_fallback(url: Url, client: &Client) -> Result<UrlInfo, reqwest::Error> {
115 let resp = client
116 .get(url)
117 .header(header::RANGE, "bytes=0-")
118 .send()
119 .await?
120 .error_for_status()?;
121 let status = resp.status();
122 let final_url = resp.url();
123
124 let resp_headers = resp.headers();
125 let size = get_file_size(resp_headers, &status);
126 let supports_range = status == StatusCode::PARTIAL_CONTENT;
127 Ok(UrlInfo {
128 final_url: final_url.clone(),
129 name: Some(get_filename(resp_headers, final_url)),
130 size,
131 supports_range,
132 fast_download: size > 0 && supports_range,
133 etag: get_header_str(resp_headers, &header::ETAG),
134 last_modified: get_header_str(resp_headers, &header::LAST_MODIFIED),
135 })
136}
137
#[cfg(test)]
mod tests {
    use super::*;

    // NOTE(review): every mock below is registered for GET only, so these
    // tests presumably reach the ranged-GET fallback rather than the HEAD
    // path — confirm against mockito's handling of unmatched requests.

    /// A redirected request reports the redirect target as the final URL and
    /// takes the total size from `Content-Range`.
    #[tokio::test]
    async fn test_redirect_and_content_range() {
        let mut server = mockito::Server::new_async().await;
        let base = server.url();

        let redirect_mock = server
            .mock("GET", "/redirect")
            .with_status(301)
            .with_header("Location", "/real-file.txt")
            .create_async()
            .await;

        let file_mock = server
            .mock("GET", "/real-file.txt")
            .with_status(206)
            .with_header("Content-Range", "bytes 0-1023/2048")
            .with_body(vec![0; 1024])
            .create_async()
            .await;

        let info = Client::new()
            .prefetch(format!("{base}/redirect"))
            .await
            .expect("Request should succeed");

        assert_eq!(info.size, 2048);
        assert_eq!(info.name.as_deref(), Some("real-file.txt"));
        assert_eq!(info.final_url.as_str(), format!("{base}/real-file.txt"));
        assert!(info.supports_range);

        redirect_mock.assert_async().await;
        file_mock.assert_async().await;
    }

    /// A `206` response yields the total from `Content-Range`.
    #[tokio::test]
    async fn test_content_range_priority() {
        let mut server = mockito::Server::new_async().await;
        let base = server.url();

        let file_mock = server
            .mock("GET", "/file")
            .with_status(206)
            .with_header("Content-Range", "bytes 0-1023/2048")
            .create_async()
            .await;

        let info = Client::new()
            .prefetch(format!("{base}/file"))
            .await
            .expect("Request should succeed");

        assert_eq!(info.size, 2048);
        file_mock.assert_async().await;
    }

    /// The file name comes from `Content-Disposition` when present, otherwise
    /// from the URL path, and is sanitized either way.
    #[tokio::test]
    async fn test_filename_sources() {
        let mut server = mockito::Server::new_async().await;
        let base = server.url();

        // Case 1: plain quoted `filename` parameter.
        let quoted_mock = server
            .mock("GET", "/test1")
            .with_header("Content-Disposition", "attachment; filename=\"test.txt\"")
            .create_async()
            .await;
        let info = Client::new()
            .prefetch(format!("{base}/test1"))
            .await
            .unwrap();
        assert_eq!(info.name.as_deref(), Some("test.txt"));
        quoted_mock.assert_async().await;

        // Case 2: no header, so the last path segment is used.
        let path_mock = server.mock("GET", "/test2/file.pdf").create_async().await;
        let info = Client::new()
            .prefetch(format!("{base}/test2/file.pdf"))
            .await
            .unwrap();
        assert_eq!(info.name.as_deref(), Some("file.pdf"));
        path_mock.assert_async().await;

        // Case 3: percent-encoded `filename*` containing characters that the
        // sanitizer replaces with `_`.
        let encoded_mock = server
            .mock("GET", "/test3")
            .with_header(
                "Content-Disposition",
                "attachment; filename*=UTF-8''%E6%82%AA%E3%81%84%3C%3E%E3%83%95%E3%82%A1%E3%82%A4%E3%83%AB%3F%E5%90%8D.txt",
            )
            .create_async()
            .await;
        let info = Client::new()
            .prefetch(format!("{base}/test3"))
            .await
            .unwrap();
        assert_eq!(info.name.as_deref(), Some("悪い__ファイル_名.txt"));
        encoded_mock.assert_async().await;
    }

    /// A `404` from the server surfaces as a status error.
    #[tokio::test]
    async fn test_error_handling() {
        let mut server = mockito::Server::new_async().await;
        let base = server.url();

        let not_found_mock = server
            .mock("GET", "/404")
            .with_status(404)
            .create_async()
            .await;

        let client = Client::new();

        let err = match client.prefetch(format!("{base}/404")).await {
            Ok(info) => panic!("404 status code should not success: {info:?}"),
            Err(err) => err,
        };
        assert!(err.is_status(), "should be error about status code");
        assert_eq!(err.status(), Some(StatusCode::NOT_FOUND));

        not_found_mock.assert_async().await;
    }
}