1extern crate alloc;
2use crate::{UrlInfo, reqwest::ReqwestReader};
3use alloc::string::{String, ToString};
4use content_disposition;
5use reqwest::{
6 Client, IntoUrl, StatusCode, Url,
7 header::{self, HeaderMap},
8};
9use sanitize_filename;
10
11fn get_file_size(headers: &HeaderMap, status: &StatusCode) -> u64 {
12 if *status == StatusCode::PARTIAL_CONTENT {
13 headers
14 .get(header::CONTENT_RANGE)
15 .and_then(|hv| hv.to_str().ok())
16 .and_then(|s| s.rsplit('/').next())
17 .and_then(|total| total.parse().ok())
18 .unwrap_or(0)
19 } else {
20 headers
21 .get(header::CONTENT_LENGTH)
22 .and_then(|hv| hv.to_str().ok())
23 .and_then(|s| s.parse().ok())
24 .unwrap_or(0)
25 }
26}
27
28fn get_header_str(headers: &HeaderMap, header_name: &header::HeaderName) -> Option<String> {
29 headers
30 .get(header_name)
31 .and_then(|hv| hv.to_str().ok())
32 .map(String::from)
33}
34
35fn get_filename(headers: &HeaderMap, final_url: &Url) -> String {
36 let from_disposition = headers
37 .get(header::CONTENT_DISPOSITION)
38 .and_then(|hv| hv.to_str().ok())
39 .and_then(|s| content_disposition::parse_content_disposition(s).filename_full())
40 .filter(|s| !s.trim().is_empty());
41
42 let from_url = final_url
43 .path_segments()
44 .and_then(|mut segments| segments.next_back())
45 .and_then(|s| urlencoding::decode(s).ok())
46 .filter(|s| !s.trim().is_empty())
47 .map(|s| s.to_string());
48
49 let raw_name = from_disposition
50 .or(from_url)
51 .unwrap_or_else(|| final_url.to_string());
52
53 sanitize_filename::sanitize_with_options(
54 &raw_name,
55 sanitize_filename::Options {
56 windows: true,
57 truncate: true,
58 replacement: "_",
59 },
60 )
61}
62
/// Gathers download metadata for a URL as a [`UrlInfo`].
///
/// The implementations in this module describe the resource as reached after the
/// request completes: its final URL, a sanitized file name, its size, whether
/// byte-range requests are supported, and its `ETag` / `Last-Modified` values when
/// the server sent them.
pub trait Prefetch {
    /// Probes `url` and returns its [`UrlInfo`].
    ///
    /// The returned future is `Send`, so it can be moved to another thread
    /// (for example, spawned on a multi-threaded runtime).
    ///
    /// # Errors
    ///
    /// Returns a `reqwest::Error` when the URL is invalid or the probe fails;
    /// see the implementing type for which failures are retried.
    fn prefetch(
        &self,
        url: impl IntoUrl + Send,
    ) -> impl Future<Output = Result<UrlInfo, reqwest::Error>> + Send;
}
69
70impl Prefetch for Client {
71 async fn prefetch(&self, url: impl IntoUrl + Send) -> Result<UrlInfo, reqwest::Error> {
72 let url = url.into_url()?;
73 let resp = match self.head(url.clone()).send().await {
74 Ok(resp) => resp,
75 Err(_) => return prefetch_fallback(url, self).await,
76 };
77 let resp = match resp.error_for_status() {
78 Ok(resp) => resp,
79 Err(_) => return prefetch_fallback(url, self).await,
80 };
81
82 let status = resp.status();
83 let final_url = resp.url();
84
85 let resp_headers = resp.headers();
86 let size = get_file_size(resp_headers, &status);
87
88 let supports_range = match resp.headers().get(header::ACCEPT_RANGES) {
89 Some(accept_ranges) => accept_ranges
90 .to_str()
91 .ok()
92 .map(|v| v.split(' '))
93 .and_then(|supports| supports.into_iter().find(|&ty| ty == "bytes"))
94 .is_some(),
95 None => return prefetch_fallback(url, self).await,
96 };
97
98 Ok(UrlInfo {
99 final_url: final_url.clone(),
100 name: get_filename(resp_headers, final_url),
101 size,
102 supports_range,
103 fast_download: size > 0 && supports_range,
104 etag: get_header_str(resp_headers, &header::ETAG),
105 last_modified: get_header_str(resp_headers, &header::LAST_MODIFIED),
106 })
107 }
108}
109
110impl Prefetch for ReqwestReader {
111 fn prefetch(
112 &self,
113 url: impl IntoUrl + Send,
114 ) -> impl Future<Output = Result<UrlInfo, reqwest::Error>> + Send {
115 self.client.prefetch(url)
116 }
117}
118
119async fn prefetch_fallback(url: Url, client: &Client) -> Result<UrlInfo, reqwest::Error> {
120 let resp = client
121 .get(url)
122 .header(header::RANGE, "bytes=0-")
123 .send()
124 .await?
125 .error_for_status()?;
126 let status = resp.status();
127 let final_url = resp.url();
128
129 let resp_headers = resp.headers();
130 let size = get_file_size(resp_headers, &status);
131 let supports_range = status == StatusCode::PARTIAL_CONTENT;
132 Ok(UrlInfo {
133 final_url: final_url.clone(),
134 name: get_filename(resp_headers, final_url),
135 size,
136 supports_range,
137 fast_download: size > 0 && supports_range,
138 etag: get_header_str(resp_headers, &header::ETAG),
139 last_modified: get_header_str(resp_headers, &header::LAST_MODIFIED),
140 })
141}
142
#[cfg(test)]
mod tests {
    use alloc::{format, vec};

    use super::*;

    /// A redirect is followed, and the size comes from `Content-Range` of the final 206.
    ///
    /// Only `GET` mocks are registered, so the `HEAD` probe is expected to fail and
    /// `prefetch` to fall back to the ranged `GET`. This assumes mockito answers
    /// unmatched requests with a non-success status; confirm against the mockito
    /// version in use.
    #[tokio::test]
    async fn test_redirect_and_content_range() {
        let mut server = mockito::Server::new_async().await;

        let mock_redirect = server
            .mock("GET", "/redirect")
            .with_status(301)
            .with_header("Location", "/real-file.txt")
            .create_async()
            .await;

        let mock_file = server
            .mock("GET", "/real-file.txt")
            .with_status(206)
            .with_header("Content-Range", "bytes 0-1023/2048")
            .with_body(vec![0; 1024])
            .create_async()
            .await;

        let client = Client::new();
        let url_info = client
            .prefetch(&format!("{}/redirect", server.url()))
            .await
            .expect("Request should succeed");

        // The total after `/` in Content-Range wins over the 1024-byte body.
        assert_eq!(url_info.size, 2048);
        // Name and final URL reflect the redirect target, not the requested path.
        assert_eq!(url_info.name, "real-file.txt");
        assert_eq!(
            url_info.final_url.as_str(),
            format!("{}/real-file.txt", server.url())
        );
        // A 206 answer to the ranged GET means ranges are supported.
        assert!(url_info.supports_range);

        mock_redirect.assert_async().await;
        mock_file.assert_async().await;
    }

    /// On a 206 response the size is read from `Content-Range`, even without a body.
    #[tokio::test]
    async fn test_content_range_priority() {
        let mut server = mockito::Server::new_async().await;
        let mock = server
            .mock("GET", "/file")
            .with_status(206)
            .with_header("Content-Range", "bytes 0-1023/2048")
            .create_async()
            .await;

        let client = Client::new();
        let url_info = client
            .prefetch(&format!("{}/file", server.url()))
            .await
            .expect("Request should succeed");

        assert_eq!(url_info.size, 2048);
        mock.assert_async().await;
    }

    /// File name selection: `Content-Disposition`, then URL path, then sanitization.
    #[tokio::test]
    async fn test_filename_sources() {
        let mut server = mockito::Server::new_async().await;

        // Case 1: a plain `filename=` in Content-Disposition is used as-is.
        let mock1 = server
            .mock("GET", "/test1")
            .with_header("Content-Disposition", "attachment; filename=\"test.txt\"")
            .create_async()
            .await;
        let url_info = Client::new()
            .prefetch(&format!("{}/test1", server.url()))
            .await
            .unwrap();
        assert_eq!(url_info.name, "test.txt");
        mock1.assert_async().await;

        // Case 2: no Content-Disposition, so the last URL path segment is used.
        let mock2 = server.mock("GET", "/test2/file.pdf").create_async().await;
        let url_info = Client::new()
            .prefetch(&format!("{}/test2/file.pdf", server.url()))
            .await
            .unwrap();
        assert_eq!(url_info.name, "file.pdf");
        mock2.assert_async().await;

        // Case 3: a percent-encoded UTF-8 `filename*` is decoded, and the characters
        // `<`, `>` and `?` (illegal on Windows) are each replaced with `_`.
        let mock3 = server
            .mock("GET", "/test3")
            .with_header(
                "Content-Disposition",
                "attachment; filename*=UTF-8''%E6%82%AA%E3%81%84%3C%3E%E3%83%95%E3%82%A1%E3%82%A4%E3%83%AB%3F%E5%90%8D.txt",
            )
            .create_async()
            .await;
        let url_info = Client::new()
            .prefetch(&format!("{}/test3", server.url()))
            .await
            .unwrap();
        assert_eq!(url_info.name, "悪い__ファイル_名.txt");
        mock3.assert_async().await;
    }

    /// An error status on the fallback `GET` is surfaced as a status error carrying
    /// that code, rather than being swallowed.
    #[tokio::test]
    async fn test_error_handling() {
        let mut server = mockito::Server::new_async().await;
        let mock1 = server
            .mock("GET", "/404")
            .with_status(404)
            .create_async()
            .await;

        let client = Client::new();

        match client.prefetch(&format!("{}/404", server.url())).await {
            Ok(info) => panic!("404 status code should not success: {info:?}"),
            Err(err) => {
                assert!(err.is_status(), "should be error about status code");
                assert_eq!(err.status(), Some(StatusCode::NOT_FOUND));
            }
        }

        mock1.assert_async().await;
    }
}