fast_down/core/
prefetch.rs

1use content_disposition;
2use reqwest::{
3    header::{self, HeaderMap},
4    Client, IntoUrl, StatusCode, Url,
5};
6use sanitize_filename;
7
8#[derive(Debug, Clone)]
9pub struct UrlInfo {
10    pub file_size: u64,
11    pub file_name: String,
12    pub supports_range: bool,
13    pub can_fast_download: bool,
14    pub final_url: Url,
15    pub etag: Option<String>,
16    pub last_modified: Option<String>,
17}
18
19fn get_file_size(headers: &HeaderMap, status: &StatusCode) -> u64 {
20    if *status == StatusCode::PARTIAL_CONTENT {
21        headers
22            .get(header::CONTENT_RANGE)
23            .and_then(|hv| hv.to_str().ok())
24            .and_then(|s| s.rsplit('/').next())
25            .and_then(|total| total.parse().ok())
26            .unwrap_or_default()
27    } else {
28        headers
29            .get(header::CONTENT_LENGTH)
30            .and_then(|hv| hv.to_str().ok())
31            .and_then(|s| s.parse().ok())
32            .unwrap_or_default()
33    }
34}
35
36fn get_header_str(headers: &HeaderMap, header_name: &header::HeaderName) -> Option<String> {
37    headers
38        .get(header_name)
39        .and_then(|hv| hv.to_str().ok())
40        .map(String::from)
41}
42
43fn get_filename(headers: &HeaderMap, final_url: &Url) -> String {
44    let from_disposition = headers
45        .get(header::CONTENT_DISPOSITION)
46        .and_then(|hv| hv.to_str().ok())
47        .and_then(|s| content_disposition::parse_content_disposition(s).filename_full())
48        .filter(|s| !s.trim().is_empty());
49
50    let from_url = final_url
51        .path_segments()
52        .and_then(|segments| segments.last())
53        .and_then(|s| urlencoding::decode(s).ok())
54        .filter(|s| !s.trim().is_empty())
55        .map(|s| (&s).to_string());
56
57    let raw_name = from_disposition
58        .or(from_url)
59        .unwrap_or_else(|| final_url.to_string());
60
61    sanitize_filename::sanitize_with_options(
62        &raw_name,
63        sanitize_filename::Options {
64            windows: true,
65            truncate: true,
66            replacement: "_",
67        },
68    )
69}
70
71pub async fn get_url_info(url: impl IntoUrl, client: &Client) -> Result<UrlInfo, reqwest::Error> {
72    let url = url.into_url()?;
73    let resp = client.head(url.clone()).send().await?;
74    let resp = match resp.error_for_status() {
75        Ok(resp) => resp,
76        Err(_) => return get_url_info_fallback(url, client).await,
77    };
78
79    let status = resp.status();
80    let final_url = resp.url();
81
82    let resp_headers = resp.headers();
83    let file_size = get_file_size(resp_headers, &status);
84
85    let supports_range = match resp.headers().get(header::ACCEPT_RANGES) {
86        Some(accept_ranges) => accept_ranges
87            .to_str()
88            .ok()
89            .map(|v| v.split(' '))
90            .and_then(|supports| supports.into_iter().find(|&ty| ty == "bytes"))
91            .is_some(),
92        None => return get_url_info_fallback(url, client).await,
93    };
94
95    Ok(UrlInfo {
96        final_url: final_url.clone(),
97        file_name: get_filename(resp_headers, &final_url),
98        file_size,
99        supports_range,
100        can_fast_download: file_size > 0 && supports_range,
101        etag: get_header_str(resp_headers, &header::ETAG),
102        last_modified: get_header_str(resp_headers, &header::LAST_MODIFIED),
103    })
104}
105
106async fn get_url_info_fallback(url: Url, client: &Client) -> Result<UrlInfo, reqwest::Error> {
107    let resp = client
108        .get(url)
109        .header(header::RANGE, "bytes=0-")
110        .send()
111        .await?
112        .error_for_status()?;
113    let status = resp.status();
114    let final_url = resp.url();
115
116    let resp_headers = resp.headers();
117    let file_size = get_file_size(resp_headers, &status);
118    let supports_range = status == StatusCode::PARTIAL_CONTENT;
119    Ok(UrlInfo {
120        final_url: final_url.clone(),
121        file_name: get_filename(resp_headers, &final_url),
122        file_size,
123        supports_range,
124        can_fast_download: file_size > 0 && supports_range,
125        etag: get_header_str(resp_headers, &header::ETAG),
126        last_modified: get_header_str(resp_headers, &header::LAST_MODIFIED),
127    })
128}
129
#[cfg(test)]
mod tests {
    use super::*;

    // NOTE(review): every mock below is registered for GET only, while
    // `get_url_info` issues HEAD first. The tests therefore appear to rely on
    // the mock server answering the unmatched HEAD with an error status so
    // that the GET fallback runs — confirm against mockito's behaviour.

    /// A redirect is followed and the size comes from `Content-Range`.
    #[tokio::test]
    async fn test_redirect_and_content_range() {
        let mut server = mockito::Server::new_async().await;
        let base = server.url();

        let redirect_mock = server
            .mock("GET", "/redirect")
            .with_status(301)
            .with_header("Location", "/real-file.txt")
            .create_async()
            .await;

        let file_mock = server
            .mock("GET", "/real-file.txt")
            .with_status(206)
            .with_header("Content-Range", "bytes 0-1023/2048")
            .with_body(vec![0; 1024])
            .create_async()
            .await;

        let client = Client::new();
        let target = format!("{}/redirect", base);
        let info = get_url_info(&target, &client)
            .await
            .expect("Request should succeed");

        assert_eq!(info.file_size, 2048);
        assert_eq!(info.file_name, "real-file.txt");
        assert_eq!(
            info.final_url.as_str(),
            format!("{}/real-file.txt", base)
        );
        assert!(info.supports_range);

        redirect_mock.assert_async().await;
        file_mock.assert_async().await;
    }

    /// On a 206 reply the total in `Content-Range` is reported as the size.
    #[tokio::test]
    async fn test_content_range_priority() {
        let mut server = mockito::Server::new_async().await;
        let base = server.url();

        let ranged_mock = server
            .mock("GET", "/file")
            .with_status(206)
            .with_header("Content-Range", "bytes 0-1023/2048")
            .create_async()
            .await;

        let client = Client::new();
        let target = format!("{}/file", base);
        let info = get_url_info(&target, &client)
            .await
            .expect("Request should succeed");

        assert_eq!(info.file_size, 2048);
        ranged_mock.assert_async().await;
    }

    /// File names come from the header, then the URL path, and are sanitized.
    #[tokio::test]
    async fn test_filename_sources() {
        let mut server = mockito::Server::new_async().await;
        let base = server.url();

        // Name supplied by Content-Disposition.
        let disposition_mock = server
            .mock("GET", "/test1")
            .with_header("Content-Disposition", "attachment; filename=\"test.txt\"")
            .create_async()
            .await;
        let client = Client::new();
        let target = format!("{}/test1", base);
        let info = get_url_info(&target, &client).await.unwrap();
        assert_eq!(info.file_name, "test.txt");
        disposition_mock.assert_async().await;

        // Name derived from the last URL path segment.
        let path_mock = server.mock("GET", "/test2/file.pdf").create_async().await;
        let client = Client::new();
        let target = format!("{}/test2/file.pdf", base);
        let info = get_url_info(&target, &client).await.unwrap();
        assert_eq!(info.file_name, "file.pdf");
        path_mock.assert_async().await;

        // Characters that are unsafe in file names are replaced.
        let sanitize_mock = server
            .mock("GET", "/test3")
            .with_header(
                "Content-Disposition",
                "attachment; filename*=UTF-8''%E6%82%AA%E3%81%84%3C%3E%E3%83%95%E3%82%A1%E3%82%A4%E3%83%AB%3F%E5%90%8D.txt"
            )
            .create_async()
            .await;
        let client = Client::new();
        let target = format!("{}/test3", base);
        let info = get_url_info(&target, &client).await.unwrap();
        assert_eq!(info.file_name, "悪い__ファイル_名.txt");
        sanitize_mock.assert_async().await;
    }

    /// A 404 surfaces as a status error rather than a `UrlInfo`.
    #[tokio::test]
    async fn test_error_handling() {
        let mut server = mockito::Server::new_async().await;
        let base = server.url();

        let not_found_mock = server
            .mock("GET", "/404")
            .with_status(404)
            .create_async()
            .await;

        let client = Client::new();
        let target = format!("{}/404", base);

        let err = match get_url_info(&target, &client).await {
            Ok(info) => panic!("404 status code should not success: {:?}", info),
            Err(err) => err,
        };
        assert!(err.is_status(), "should be error about status code");
        assert_eq!(err.status(), Some(StatusCode::NOT_FOUND));

        not_found_mock.assert_async().await;
    }
}
247        mock1.assert_async().await;
248    }
249}