downloader_http_rs/
http.rs

1use crate::config::Config;
2use anyhow::{anyhow, Error};
3use bounded_join_set::JoinSet;
4use reqwest::header::{self, HeaderMap, HeaderValue};
5use reqwest::{Client, Request};
6use std::fmt;
7use std::fs::{self, File, OpenOptions};
8use std::io::{self, BufRead, BufReader};
9use std::io::{BufWriter, Seek, SeekFrom, Write};
10use std::path::{Path, PathBuf};
11use std::time::Duration;
12use tokio::{sync::mpsc, time::timeout};
13use url::{ParseError, Url};
14
/// State for downloading a single HTTP resource to a local file, either as
/// one stream or as several ranged chunks.
pub struct HttpDownload {
    /// Download configuration (URL, headers, save directory, callbacks, ...).
    conf: Config,
    /// Name of the target file, filled in by `set_url`.
    filename: String,
    /// Number of chunk retries performed so far by `chunk_download`.
    retries: u8,
    /// Whether the server advertised `Accept-Ranges: bytes` (chunked download).
    support_chunk: bool,
    /// HTTP client used for the download requests.
    httpclient: Client,
    /// Handle of the file being downloaded.
    file: Option<BufWriter<fs::File>>,
    /// Handle of the `.st` state file that records finished chunk pieces.
    st_file: Option<BufWriter<fs::File>>,
}
31
32impl fmt::Debug for HttpDownload {
33    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
34        write!(
35            f,
36            "http download url: {}",
37            self.conf.url.as_ref().map(|url| url.as_str()).unwrap_or("")
38        )
39    }
40}
41
42impl HttpDownload {
43    pub fn new() -> Self {
44        Self {
45            conf: Config::default(),
46            retries: 0,
47            filename: "".to_string(),
48            st_file: None,
49            file: None,
50            support_chunk: false,
51            httpclient: Client::new(),
52        }
53    }
54
55    pub fn new_with_config(conf: Config) -> Self {
56        Self {
57            conf,
58            retries: 0,
59            filename: "".to_string(),
60            st_file: None,
61            file: None,
62            support_chunk: false,
63            httpclient: Client::new(),
64        }
65    }
66
    /// Sets the per-request timeout in seconds for the download requests;
    /// `0` leaves those requests without an explicit timeout.
    pub fn set_timeout(&mut self, timeout: u64) -> &mut Self {
        self.conf.timeout = timeout;
        self
    }
71
    /// Registers (or clears, with `None`) the callback that receives the
    /// download progress as a percentage whenever it changes.
    pub fn on_down_progress(
        &mut self,
        cb: Option<Box<dyn Fn(u8) + Send + Sync + 'static>>,
    ) -> &mut Self {
        self.conf.on_down_progress = cb;
        self
    }
    /// Registers (or clears, with `None`) the callback that receives the path
    /// of the downloaded file once the download has finished.
    pub fn on_down_finish(&mut self, cb: Option<fn(String)>) -> &mut Self {
        self.conf.on_down_finish = cb;
        self
    }
83
    /// Replaces the set of extra HTTP headers sent with every request.
    pub fn set_headers(&mut self, headers: HeaderMap) -> &mut Self {
        self.conf.headers = headers;
        self
    }
    /// Sets the directory the file is saved into; an empty string means the
    /// file name is used as the path on its own.
    pub fn set_save_dir(&mut self, save_dir: String) -> &mut Self {
        self.conf.save_dir = save_dir;
        self
    }
    /// Sets the expected MD5 of the file (compared case-insensitively);
    /// an empty string disables the verification.
    pub fn set_file_md5(&mut self, md5: String) -> &mut Self {
        self.conf.file_md5 = md5;
        self
    }
97
    /// Sets the size in bytes of each ranged chunk used by the chunked download.
    pub fn set_chunk_size(&mut self, chunk_size: u64) -> &mut Self {
        self.conf.chunk_size = chunk_size;
        self
    }
    /// Sets how many failed-chunk retries are tolerated before the chunked
    /// download gives up with a "max retries" error.
    pub fn set_max_retries(&mut self, max_retries: u8) -> &mut Self {
        self.conf.max_retries = max_retries;
        self
    }
106    pub fn set_user_agent(&mut self, user_agent: &str) -> &mut Self {
107        if let Ok(ag) = header::HeaderValue::from_str(user_agent) {
108            self.conf.headers.insert(header::USER_AGENT, ag);
109        }
110        self
111    }
    /// Sets the worker count handed to the bounded join set that runs the
    /// chunk requests; `0` is treated as `1` by `chunk_download`.
    pub fn set_num_workers(&mut self, num_workers: usize) -> &mut Self {
        self.conf.num_workers = num_workers;
        self
    }
    /// Enables or disables debug output; when enabled, `set_url` prints the
    /// response headers it received.
    pub fn debug(&mut self, debug: bool) -> &mut Self {
        self.conf.debug = debug;
        self
    }
120
121    pub async fn url_headers_info(&self, url: &str) -> Result<(), Error> {
122        let url = self.parse_url(url)?;
123        let headers = self.get_headers_from_url(&url).await?;
124        self.print_headers(&headers);
125        Ok(())
126    }
127
128    fn print_headers(&self, headers: &HeaderMap) {
129        for (hdr, val) in headers.iter() {
130            println!("{}: {}", hdr.as_str(), val.to_str().unwrap_or("<..>"));
131        }
132    }
133
134    //从url 服务中获取headers 信息
135    async fn get_headers_from_url(&self, url: &Url) -> Result<HeaderMap, Error> {
136        let resp = Client::new()
137            .get(url.as_ref())
138            .timeout(Duration::from_secs(10))
139            .headers(self.conf.headers.clone())
140            .header(header::ACCEPT, HeaderValue::from_str("*/*")?);
141
142        let resp = resp.send().await?;
143        Ok(resp.headers().clone())
144    }
145
146    //检查文件是否已下载过
147    fn compari_file_md5(&self, file_path: &str, md5: &str) -> bool {
148        if md5.is_empty() {
149            return false;
150        }
151
152        if let Ok(file_md5) = crate::get_file_md5(&file_path) {
153            log::debug!("file_md5:{},md5:{}", file_md5, md5);
154            return file_md5.to_lowercase().eq(&md5.to_lowercase());
155        }
156
157        false
158    }
159
160    fn parse_url(&self, url: &str) -> Result<Url, ParseError> {
161        match Url::parse(url) {
162            Ok(url) => Ok(url),
163            Err(ParseError::RelativeUrlWithoutBase) => {
164                let url_with_base = format!("{}{}", "http://", url);
165                Url::parse(url_with_base.as_str())
166            }
167            Err(error) => Err(error),
168        }
169    }
170
171    //设置下载url
172    pub async fn set_url(&mut self, url: &str) -> Result<&mut Self, Error> {
173        let url = self
174            .parse_url(url)
175            .map_err(|e| anyhow!("url不合法: {} err:{}", url, e))?;
176
177        let headers = self.get_headers_from_url(&url).await?;
178        //打印http头
179        if self.conf.debug {
180            self.print_headers(&headers);
181        }
182
183        //获下载文件名
184        let fname = gen_filename(&url, Some(&headers));
185        if fname.is_empty() {
186            return Err(anyhow!("filename is empty"));
187        }
188        self.filename = fname;
189
190        //判断服务器是否支持分片下载
191        let server_acccept_ranges = match headers.get(header::ACCEPT_RANGES) {
192            Some(val) => val == "bytes",
193            None => false,
194        };
195
196        self.support_chunk = server_acccept_ranges;
197
198        //文件长度
199        let mut content_len = 0;
200        if let Some(val) = headers.get(header::CONTENT_LENGTH) {
201            content_len = val.to_str().unwrap_or("").parse::<u64>()?;
202        }
203        self.conf.content_len = content_len;
204
205        self.conf.url = Some(url);
206
207        Ok(self)
208    }
209
210    fn on_progress(&mut self) {
211        let mut pro = if self.conf.content_len > 0 {
212            if self.conf.download_len >= self.conf.content_len {
213                100
214            } else {
215                let r = (self.conf.download_len as f64 / self.conf.content_len as f64) * 100.0;
216                r.ceil() as u8
217            }
218        } else {
219            0
220        };
221
222        if pro > 100 {
223            pro = 100;
224        }
225
226        if self.conf.progress != pro {
227            self.conf.progress = pro;
228            if let Some(evt) = &self.conf.on_down_progress {
229                evt(pro);
230            }
231            log::trace!("download progress :{}", pro);
232        }
233    }
234    fn on_finish(&self, file_path: String) {
235        if let Some(evt) = &self.conf.on_down_progress {
236            evt(100);
237        }
238        if let Some(evt) = &self.conf.on_down_finish {
239            evt(file_path);
240        }
241        let filepath = self.get_file_path(format!("{}.st", self.filename));
242        let _ = fs::remove_file(&filepath);
243    }
244
245    fn write_content(&mut self, content: &[u8]) -> Result<(), Error> {
246        if self.file.is_none() {
247            return Err(anyhow!("file handler is none"));
248        }
249
250        //写入文件字节
251        if let Some(ref mut file) = self.file {
252            match file.write_all(content) {
253                Ok(()) => {}
254                Err(e) => {
255                    return Err(anyhow!("write_content file.write_all err {}", e));
256                }
257            }
258        }
259
260        self.conf.download_len += content.len() as u64;
261        //设置进度条
262        self.on_progress();
263        Ok(())
264    }
265
266    fn chunk_write_content(&mut self, content: (u64, u64, &[u8])) -> Result<(), Error> {
267        if self.file.is_none() {
268            return Err(anyhow!("file handler is none"));
269        }
270
271        let (byte_count, offset, buf) = content;
272        //写入文件
273        if let Some(ref mut file) = self.file {
274            file.seek(SeekFrom::Start(offset))?;
275            file.write_all(buf)?;
276            file.flush()?;
277        }
278
279        //记录进度临时文件
280        if let Some(ref mut file) = self.st_file {
281            writeln!(file, "{}:{}", byte_count, offset)?;
282            file.flush()?;
283        }
284
285        self.conf.download_len += byte_count;
286        //设置进度条
287        self.on_progress();
288        Ok(())
289    }
290
291    pub async fn start(&mut self) -> Result<String, Error> {
292        if self.support_chunk {
293            log::info!("use chunk download..");
294            self.chunk_download().await
295        } else {
296            log::info!("use general download..");
297            self.gener_download().await
298        }
299    }
300
301    //普通下载
302    #[allow(unused)]
303    pub async fn gener_download(&mut self) -> Result<String, Error> {
304        let filepath = self.get_file_path(self.filename.clone());
305
306        if self.compari_file_md5(&filepath, &self.conf.file_md5) {
307            self.on_finish(filepath.clone());
308            return Ok(filepath);
309        }
310
311        //不是分片下载不支持,文件续传
312        self.file = Some(create_filehandler(
313            &self.filename,
314            &self.conf.save_dir,
315            false,
316        )?);
317
318        let timeout = self.conf.timeout;
319
320        let headers = self.conf.headers.clone();
321
322        let Some(url) = self.conf.url.as_ref() else {
323            return Err(anyhow!("url is empty"));
324        };
325
326        let mut req = self.httpclient.get(url.clone());
327        if timeout > 0 {
328            req = req.timeout(Duration::from_secs(timeout));
329        }
330        let req = req.headers(headers).build()?;
331
332        let mut resp = self.httpclient.execute(req).await?;
333
334        let ct_len = if let Some(val) = resp.headers().get(header::CONTENT_LENGTH) {
335            Some(val.to_str()?.parse::<usize>()?)
336        } else {
337            None
338        };
339
340        let mut cnt = 0;
341        let mut total_read = 0;
342
343        while let Some(chunk) = resp.chunk().await? {
344            let chunk_buffer = chunk.to_vec();
345            let bcount = chunk_buffer.len();
346
347            cnt += bcount;
348            total_read += bcount;
349
350            if !chunk_buffer.is_empty() {
351                self.write_content(&chunk_buffer)?;
352            }
353
354            if let Some(ct_len) = ct_len {
355                if total_read >= ct_len {
356                    break;
357                }
358            } else if bcount == 0 {
359                break;
360            }
361        }
362
363        //如果配置较验验,则校验文件
364        if !self.conf.file_md5.is_empty() && !self.compari_file_md5(&filepath, &self.conf.file_md5)
365        {
366            return Err(anyhow!("download ok but, file md5 not match"));
367        }
368
369        self.on_finish(filepath.clone());
370
371        Ok(filepath)
372    }
373
374    //获取分片分段
375    #[allow(dead_code)]
376    fn get_chunk_offsets(&self) -> Vec<(u64, u64)> {
377        let ct_len = self.conf.content_len;
378
379        let chunk_size = self.conf.chunk_size;
380
381        let num_chunks = ct_len / chunk_size;
382
383        log::info!("num_chunks:{}", num_chunks);
384
385        let mut sizes = Vec::new();
386
387        for i in 0..num_chunks {
388            let bound = if i == num_chunks - 1 {
389                ct_len
390            } else {
391                ((i + 1) * chunk_size) - 1
392            };
393
394            sizes.push((i * chunk_size, bound));
395        }
396
397        if sizes.is_empty() {
398            sizes.push((0, ct_len));
399        }
400
401        sizes
402    }
403
404    fn get_file_path(&self, filename: String) -> String {
405        let mut path = PathBuf::from(filename.clone());
406        if !self.conf.save_dir.is_empty() {
407            path = PathBuf::from(&self.conf.save_dir);
408            path.push(&filename);
409        }
410        path.to_str().unwrap_or_default().to_string()
411    }
412
413    //获取下灰复下载的位置
414    fn get_resume_chunk_offsets(&self) -> Result<(Vec<(u64, u64)>, u64), Error> {
415        if self.st_file.is_none() {
416            return Err(anyhow!("st_file is none"));
417        }
418
419        let fname = format!("{}.st", self.filename);
420        let mut path = PathBuf::from(fname.clone());
421
422        if !self.conf.save_dir.is_empty() {
423            path = PathBuf::from(&self.conf.save_dir);
424            path.push(&fname);
425        }
426
427        let ct_len = self.conf.content_len;
428        let chunk_size = self.conf.chunk_size;
429
430        let input = fs::File::open(&path)?;
431        let buf = BufReader::new(input);
432        let mut already_downloaded_bytes = 0u64;
433
434        let mut downloaded = vec![];
435        for line in buf.lines() {
436            let l = line?;
437            let l = l.split(':').collect::<Vec<_>>();
438            let n = (l[0].parse::<u64>()?, l[1].parse::<u64>()?);
439            // 已下载字节数
440            already_downloaded_bytes += n.0;
441            //已下载位置
442            downloaded.push(n);
443        }
444        downloaded.sort_by_key(|a| a.1);
445
446        let mut chunks = vec![];
447
448        let mut i: u64 = 0;
449        for (bc, offset) in downloaded {
450            if i == offset {
451                i = offset + bc;
452            } else {
453                chunks.push((i, offset - 1));
454                i = offset + bc;
455            }
456        }
457
458        while (ct_len - i) > chunk_size {
459            chunks.push((i, i + chunk_size - 1));
460            i += chunk_size;
461        }
462        chunks.push((i, ct_len));
463
464        Ok((chunks, already_downloaded_bytes))
465    }
466
467    //分片下载
468    pub async fn chunk_download(&mut self) -> Result<String, Error> {
469        let filepath = self.get_file_path(self.filename.clone());
470        if self.compari_file_md5(&filepath, &self.conf.file_md5) {
471            self.on_finish(filepath.clone());
472            return Ok(filepath);
473        }
474
475        if !self.support_chunk {
476            return Err(anyhow!("chunk download not support"));
477        }
478
479        //创建分片状态文件
480        let filename = format!("{}.st", self.filename);
481        self.st_file = Some(create_filehandler(&filename, &self.conf.save_dir, true)?);
482        //分段内容
483        let chunk_offsets_info = self.get_resume_chunk_offsets()?;
484
485        //剩余分片
486        let chunk_offsets = chunk_offsets_info.0;
487        //文件已下载的进度
488        let already_download = chunk_offsets_info.1;
489
490        log::info!("already_download len :{}", already_download);
491        log::info!("chunk_offsets count :{}", chunk_offsets.len());
492
493        let mut append = false;
494
495        //灰复下载,设置进度条进度
496        if already_download > 0 {
497            self.conf.download_len = already_download;
498            self.on_progress();
499            append = true;
500        }
501
502        self.file = Some(create_filehandler(
503            &self.filename,
504            &self.conf.save_dir,
505            append,
506        )?);
507
508        let mut headers = self.conf.headers.clone();
509        let mut num_workers = self.conf.num_workers;
510        let max_retries = self.conf.max_retries;
511        if num_workers == 0 {
512            num_workers = 1;
513        }
514
515        if headers.contains_key(header::RANGE) {
516            headers.remove(header::RANGE);
517        }
518
519        let Some(url) = self.conf.url.as_ref() else {
520            return Err(anyhow!("url is empty"));
521        };
522        let mut req = self.httpclient.get(url.clone());
523
524        if self.conf.timeout > 0 {
525            req = req.timeout(Duration::from_secs(self.conf.timeout));
526        }
527        let req = req.headers(headers).build()?;
528
529        let (data_tx, mut data_rx) = mpsc::channel::<(u64, u64, Vec<u8>)>(32);
530        let (errors_tx, mut errors_rx) = mpsc::channel::<(u64, u64)>(32);
531
532        let mut join_set = JoinSet::new(num_workers);
533
534        for offsets in chunk_offsets {
535            let p_data_tx = data_tx.clone();
536            let p_errors_tx = errors_tx.clone();
537            let Some(p_req) = req.try_clone() else {
538                return Err(anyhow!("req.try_clone() err"));
539            };
540
541            join_set.spawn(async move {
542                download_chunk(p_req, offsets, p_data_tx.clone(), p_errors_tx).await;
543            });
544        }
545
546        let mut count = already_download;
547        loop {
548            if count == self.conf.content_len {
549                break;
550            }
551
552            if let Some((byte_count, offset, buf)) = data_rx.recv().await {
553                count += byte_count;
554
555                self.chunk_write_content((byte_count, offset, &buf))?;
556
557                match timeout(Duration::from_micros(1), errors_rx.recv()).await {
558                    Ok(Some(offsets)) => {
559                        if self.retries > max_retries {
560                            if let Some(ref mut file) = self.file {
561                                let _ = file.flush();
562                            }
563                            if let Some(ref mut file) = self.st_file {
564                                let _ = file.flush();
565                            }
566                            return Err(anyhow!("max retries"));
567                        }
568
569                        self.retries += 1;
570                        let data_tx = data_tx.clone();
571                        let errors_tx = errors_tx.clone();
572                        let Some(req) = req.try_clone() else {
573                            return Err(anyhow!("req.try_clone() err"));
574                        };
575
576                        join_set.spawn(async move {
577                            download_chunk(req, offsets, data_tx.clone(), errors_tx).await;
578                        });
579                    }
580                    _ => {}
581                }
582            }
583        }
584
585        join_set.join_next().await;
586
587        //如果配置较验验,则校验文件
588        if !self.conf.file_md5.is_empty() && !self.compari_file_md5(&filepath, &self.conf.file_md5)
589        {
590            return Err(anyhow!("download ok but, file md5 not match"));
591        }
592
593        self.on_finish(filepath.clone());
594
595        log::debug!("[downloader] download finish....");
596        Ok(filepath)
597    }
598}
599
600//分片下载
601async fn download_chunk(
602    req: Request,
603    offsets: (u64, u64),
604    sender: mpsc::Sender<(u64, u64, Vec<u8>)>,
605    errors: mpsc::Sender<(u64, u64)>,
606) {
607    async fn inner(
608        mut req: Request,
609        offsets: (u64, u64),
610        sender: mpsc::Sender<(u64, u64, Vec<u8>)>,
611        start_offset: &mut u64,
612    ) -> Result<(), Error> {
613        log::trace!("download chunk:{}-{}", offsets.0, offsets.1);
614        //0-10485759
615        let byte_range = format!("bytes={}-{}", offsets.0, offsets.1);
616        let headers = req.headers_mut();
617        headers.insert(header::RANGE, HeaderValue::from_str(&byte_range)?);
618        headers.insert(header::ACCEPT, HeaderValue::from_str("*/*")?);
619        headers.insert(header::CONNECTION, HeaderValue::from_str("keep-alive")?);
620        let mut resp = Client::new().execute(req).await?;
621
622        let chunk_sz = offsets.1 - offsets.0;
623        let mut cnt = 0u64;
624
625        while let Some(chunk) = resp.chunk().await? {
626            let byte_count = chunk.len() as u64;
627
628            cnt += byte_count;
629
630            sender
631                .send((byte_count, *start_offset, chunk.to_vec()))
632                .await?;
633
634            *start_offset += byte_count;
635
636            if cnt >= chunk_sz + 1 {
637                break;
638            }
639        }
640        log::trace!("[downloader] download chunk:ok...");
641        Ok(())
642    }
643
644    let mut start_offset = offsets.0;
645    let end_offset = offsets.1;
646
647    if inner(req, offsets, sender, &mut start_offset)
648        .await
649        .is_err()
650    {
651        let _ = errors.send((start_offset, end_offset));
652    }
653}
654
/// Opens `fname` for writing, creating it when it does not exist.
///
/// * `append == true` and the file exists: opened in append mode, existing
///   content is kept and every write goes to the end of the file.
/// * otherwise: the file is created if needed and truncated, so content left
///   over from an earlier, longer file cannot survive behind the new data.
fn get_file_handle(fname: &str, append: bool) -> io::Result<File> {
    if append && Path::new(fname).exists() {
        OpenOptions::new().append(true).open(fname)
    } else {
        OpenOptions::new()
            .write(true)
            .create(true)
            .truncate(true)
            .open(fname)
    }
}
671
672//创建文件句柄
673fn create_filehandler(
674    filename: &str,
675    save_dir: &str,
676    append: bool,
677) -> Result<BufWriter<File>, Error> {
678    let mut fpath = filename.to_owned();
679    // 创建保存文件目录
680    if !save_dir.is_empty() {
681        let path = Path::new(save_dir);
682        if !path.exists() {
683            fs::create_dir(save_dir)?;
684        }
685        let mut path = PathBuf::from(save_dir);
686        path.push(filename);
687
688        fpath = path
689            .to_str()
690            .map(|p| p.to_string())
691            .unwrap_or("".to_string());
692    }
693
694    let handler = get_file_handle(fpath.as_str(), append)?;
695    Ok(BufWriter::new(handler))
696}
697
/// Returns the extension of `file_path` (the part after the last `.` of the
/// file name), or `None` when there is none.
#[allow(dead_code)]
fn get_file_extension(file_path: &str) -> Option<&str> {
    Path::new(file_path).extension()?.to_str()
}
703//生成文件名
704//val:inline; filename="4f319a7a8cd6d322c6d938f7b8c2adb9.zip"; filename*=utf-8''4f319a7a8cd6d322c6d938f7b8c2adb9.zip
705fn gen_filename(url: &Url, headers: Option<&HeaderMap>) -> String {
706    let content = headers
707        .and_then(|hdrs| hdrs.get(header::CONTENT_DISPOSITION))
708        .and_then(|val| {
709            let val = val.to_str().unwrap_or("");
710            if val.contains("filename=") {
711                Some(val)
712            } else {
713                None
714            }
715        })
716        .and_then(|s| {
717            let parts: Vec<&str> = s.rsplit(';').collect();
718            let mut filename: Option<String> = None;
719            for part in parts {
720                if part.trim().starts_with("filename=") {
721                    let name = part.trim().split('=').nth(1).unwrap_or("");
722                    if !name.is_empty() {
723                        let name = name.trim_start_matches('"').trim_end_matches('"');
724                        filename = Some(name.to_owned());
725                    }
726                    break;
727                }
728            }
729            filename
730        });
731
732    let filename = match content {
733        Some(val) => val,
734        None => {
735            let name = &url.path().split('/').last().unwrap_or("");
736            if !name.is_empty() {
737                match decode_percent_encoded_data(name) {
738                    Ok(val) => val,
739                    _ => name.to_string(),
740                }
741            } else {
742                "index.html".to_owned()
743            }
744        }
745    };
746    filename.trim().to_owned()
747}
748
749//url  decode
750fn decode_percent_encoded_data(data: &str) -> Result<String, Error> {
751    let mut unescaped_bytes: Vec<u8> = Vec::new();
752    let mut bytes = data.bytes();
753
754    while let Some(b) = bytes.next() {
755        match b as char {
756            '%' => {
757                let bytes_to_decode = &[bytes.next().unwrap_or(0), bytes.next().unwrap_or(0)];
758                let hex_str = std::str::from_utf8(bytes_to_decode)?;
759                unescaped_bytes.push(u8::from_str_radix(hex_str, 16)?);
760            }
761            _ => {
762                unescaped_bytes.push(b);
763            }
764        }
765    }
766
767    String::from_utf8(unescaped_bytes).map_err(|e| anyhow!(format!("String::from_utf8 ERR:{}", e)))
768}