1use content_disposition;
2use reqwest::{
3 header::{self, HeaderMap},
4 Client, IntoUrl, StatusCode, Url,
5};
6use sanitize_filename;
7
/// Metadata about a remote resource, gathered by probing its URL.
#[derive(Debug, Clone)]
pub struct UrlInfo {
    /// Size of the resource in bytes, taken from the `Content-Range` total
    /// (for `206` responses) or from `Content-Length`; `0` when neither
    /// could be read or parsed.
    pub file_size: u64,
    /// Sanitized file name, derived from `Content-Disposition`, the last path
    /// segment of the response URL, or the whole URL as a last resort.
    pub file_name: String,
    /// Whether the server indicated support for byte-range requests.
    pub supports_range: bool,
    /// `true` only when `file_size > 0` and `supports_range` is set.
    pub can_fast_download: bool,
    /// URL reported by the response that was used to build this value.
    pub final_url: Url,
    /// Value of the `ETag` response header, if present and readable as text.
    pub etag: Option<String>,
    /// Value of the `Last-Modified` response header, if present and readable as text.
    pub last_modified: Option<String>,
}
18
19fn get_file_size(headers: &HeaderMap, status: &StatusCode) -> u64 {
20 if *status == StatusCode::PARTIAL_CONTENT {
21 headers
22 .get(header::CONTENT_RANGE)
23 .and_then(|hv| hv.to_str().ok())
24 .and_then(|s| s.rsplit('/').next())
25 .and_then(|total| total.parse().ok())
26 .unwrap_or_default()
27 } else {
28 headers
29 .get(header::CONTENT_LENGTH)
30 .and_then(|hv| hv.to_str().ok())
31 .and_then(|s| s.parse().ok())
32 .unwrap_or_default()
33 }
34}
35
36fn get_header_str(headers: &HeaderMap, header_name: &header::HeaderName) -> Option<String> {
37 headers
38 .get(header_name)
39 .and_then(|hv| hv.to_str().ok())
40 .map(String::from)
41}
42
43fn get_filename(headers: &HeaderMap, final_url: &Url) -> String {
44 let from_disposition = headers
45 .get(header::CONTENT_DISPOSITION)
46 .and_then(|hv| hv.to_str().ok())
47 .and_then(|s| content_disposition::parse_content_disposition(s).filename_full())
48 .filter(|s| !s.trim().is_empty());
49
50 let from_url = final_url
51 .path_segments()
52 .and_then(|segments| segments.last())
53 .and_then(|s| urlencoding::decode(s).ok())
54 .filter(|s| !s.trim().is_empty())
55 .map(|s| (&s).to_string());
56
57 let raw_name = from_disposition
58 .or(from_url)
59 .unwrap_or_else(|| final_url.to_string());
60
61 sanitize_filename::sanitize_with_options(
62 &raw_name,
63 sanitize_filename::Options {
64 windows: true,
65 truncate: true,
66 replacement: "_",
67 },
68 )
69}
70
71pub async fn get_url_info(url: impl IntoUrl, client: &Client) -> Result<UrlInfo, reqwest::Error> {
72 let url = url.into_url()?;
73 let resp = client.head(url.clone()).send().await?;
74 let resp = match resp.error_for_status() {
75 Ok(resp) => resp,
76 Err(_) => return get_url_info_fallback(url, client).await,
77 };
78
79 let status = resp.status();
80 let final_url = resp.url();
81
82 let resp_headers = resp.headers();
83 let file_size = get_file_size(resp_headers, &status);
84
85 let supports_range = match resp.headers().get(header::ACCEPT_RANGES) {
86 Some(accept_ranges) => accept_ranges
87 .to_str()
88 .ok()
89 .map(|v| v.split(' '))
90 .and_then(|supports| supports.into_iter().find(|&ty| ty == "bytes"))
91 .is_some(),
92 None => return get_url_info_fallback(url, client).await,
93 };
94
95 Ok(UrlInfo {
96 final_url: final_url.clone(),
97 file_name: get_filename(resp_headers, &final_url),
98 file_size,
99 supports_range,
100 can_fast_download: file_size > 0 && supports_range,
101 etag: get_header_str(resp_headers, &header::ETAG),
102 last_modified: get_header_str(resp_headers, &header::LAST_MODIFIED),
103 })
104}
105
106async fn get_url_info_fallback(url: Url, client: &Client) -> Result<UrlInfo, reqwest::Error> {
107 let resp = client
108 .get(url)
109 .header(header::RANGE, "bytes=0-")
110 .send()
111 .await?
112 .error_for_status()?;
113 let status = resp.status();
114 let final_url = resp.url();
115
116 let resp_headers = resp.headers();
117 let file_size = get_file_size(resp_headers, &status);
118 let supports_range = status == StatusCode::PARTIAL_CONTENT;
119 Ok(UrlInfo {
120 final_url: final_url.clone(),
121 file_name: get_filename(resp_headers, &final_url),
122 file_size,
123 supports_range,
124 can_fast_download: file_size > 0 && supports_range,
125 etag: get_header_str(resp_headers, &header::ETAG),
126 last_modified: get_header_str(resp_headers, &header::LAST_MODIFIED),
127 })
128}
129
// Integration-style tests that run `get_url_info` against a local mockito server.
//
// NOTE(review): only `GET` routes are mocked, so these tests presumably reach
// the ranged-GET fallback because the unmatched `HEAD` request is answered
// with an error status — confirm against mockito's behavior for unmatched routes.
#[cfg(test)]
mod tests {
    use super::*;

    /// A redirect is followed, and size, name, and final URL come from the
    /// redirected `206` response.
    #[tokio::test]
    async fn test_redirect_and_content_range() {
        let mut server = mockito::Server::new_async().await;

        let mock_redirect = server
            .mock("GET", "/redirect")
            .with_status(301)
            .with_header("Location", "/real-file.txt")
            .create_async()
            .await;

        let mock_file = server
            .mock("GET", "/real-file.txt")
            .with_status(206)
            .with_header("Content-Range", "bytes 0-1023/2048")
            .with_body(vec![0; 1024])
            .create_async()
            .await;

        let client = Client::new();
        let url_info = get_url_info(&format!("{}/redirect", server.url()), &client)
            .await
            .expect("Request should succeed");

        // The size is the total after the `/` in Content-Range, not the 1024-byte body.
        assert_eq!(url_info.file_size, 2048);
        assert_eq!(url_info.file_name, "real-file.txt");
        assert_eq!(
            url_info.final_url.as_str(),
            format!("{}/real-file.txt", server.url())
        );
        assert!(url_info.supports_range);

        // Both mocked routes must have been requested.
        mock_redirect.assert_async().await;
        mock_file.assert_async().await;
    }

    /// For a `206` response the size is read from `Content-Range`.
    #[tokio::test]
    async fn test_content_range_priority() {
        let mut server = mockito::Server::new_async().await;
        let mock = server
            .mock("GET", "/file")
            .with_status(206)
            .with_header("Content-Range", "bytes 0-1023/2048")
            .create_async()
            .await;

        let client = Client::new();
        let url_info = get_url_info(&format!("{}/file", server.url()), &client)
            .await
            .expect("Request should succeed");

        assert_eq!(url_info.file_size, 2048);
        mock.assert_async().await;
    }

    /// Covers the file-name sources: a plain `Content-Disposition` file name,
    /// the URL path, and an encoded `filename*` value that needs sanitizing.
    #[tokio::test]
    async fn test_filename_sources() {
        let mut server = mockito::Server::new_async().await;

        // Case 1: quoted `filename` parameter in Content-Disposition.
        let mock1 = server
            .mock("GET", "/test1")
            .with_header("Content-Disposition", "attachment; filename=\"test.txt\"")
            .create_async()
            .await;
        let url_info = get_url_info(&format!("{}/test1", server.url()), &Client::new())
            .await
            .unwrap();
        assert_eq!(url_info.file_name, "test.txt");
        mock1.assert_async().await;

        // Case 2: no Content-Disposition is configured; the last path segment is expected.
        let mock2 = server.mock("GET", "/test2/file.pdf").create_async().await;
        let url_info = get_url_info(&format!("{}/test2/file.pdf", server.url()), &Client::new())
            .await
            .unwrap();
        assert_eq!(url_info.file_name, "file.pdf");
        mock2.assert_async().await;

        // Case 3: percent-encoded UTF-8 `filename*`; the encoded `<`, `>` and `?`
        // are expected to come out as `_` after sanitizing.
        let mock3 = server
            .mock("GET", "/test3")
            .with_header(
                "Content-Disposition",
                "attachment; filename*=UTF-8''%E6%82%AA%E3%81%84%3C%3E%E3%83%95%E3%82%A1%E3%82%A4%E3%83%AB%3F%E5%90%8D.txt"
            )
            .create_async().await;
        let url_info = get_url_info(&format!("{}/test3", server.url()), &Client::new())
            .await
            .unwrap();
        assert_eq!(url_info.file_name, "悪い__ファイル_名.txt");
        mock3.assert_async().await;
    }

    /// A `404` on the `GET` route must surface as a status error.
    #[tokio::test]
    async fn test_error_handling() {
        let mut server = mockito::Server::new_async().await;
        let mock1 = server
            .mock("GET", "/404")
            .with_status(404)
            .create_async()
            .await;

        let client = Client::new();

        match get_url_info(&format!("{}/404", server.url()), &client).await {
            Ok(info) => panic!("404 status code should not success: {:?}", info),
            Err(err) => {
                assert!(err.is_status(), "should be error about status code");
                assert_eq!(err.status(), Some(StatusCode::NOT_FOUND));
            }
        }

        mock1.assert_async().await;
    }
}