onnxruntime_ng/
download.rs

//! Module for downloading pre-trained models from the ONNX Model Zoo.
2//!
3//! Pre-trained models are available from the
4//! [ONNX Model Zoo](https://github.com/onnx/models).
5//!
6//! A pre-trained model can be downloaded automatically using the
7//! [`SessionBuilder`](../session/struct.SessionBuilder.html)'s
8//! [`with_model_downloaded()`](../session/struct.SessionBuilder.html#method.with_model_downloaded) method.
9//!
10//! See [`AvailableOnnxModel`](enum.AvailableOnnxModel.html) for the different models available
11//! to download.
12
13#[cfg(feature = "model-fetching")]
14use std::{
15    fs, io,
16    path::{Path, PathBuf},
17    time::Duration,
18};
19
20#[cfg(feature = "model-fetching")]
21use crate::error::{OrtDownloadError, Result};
22
23#[cfg(feature = "model-fetching")]
24use tracing::info;
25
26pub mod language;
27pub mod vision;
28
29/// Available pre-trained models to download from [ONNX Model Zoo](https://github.com/onnx/models).
30///
31/// According to [ONNX Model Zoo](https://github.com/onnx/models)'s GitHub page:
32///
33/// > The ONNX Model Zoo is a collection of pre-trained, state-of-the-art models in the ONNX format
34/// > contributed by community members like you.
35#[derive(Debug, Clone)]
36pub enum AvailableOnnxModel {
37    /// Computer vision model
38    Vision(vision::Vision),
39    /// Natural language model
40    Language(language::Language),
41}
42
/// Anything that knows the remote location its ONNX file is downloaded from.
trait ModelUrl {
    /// The URL that is requested when this model is downloaded.
    fn fetch_url(&self) -> &'static str;
}
46
47impl ModelUrl for AvailableOnnxModel {
48    fn fetch_url(&self) -> &'static str {
49        match self {
50            AvailableOnnxModel::Vision(model) => model.fetch_url(),
51            AvailableOnnxModel::Language(model) => model.fetch_url(),
52        }
53    }
54}
55
56impl AvailableOnnxModel {
57    #[cfg(feature = "model-fetching")]
58    #[tracing::instrument]
59    pub(crate) fn download_to<P>(&self, download_dir: P) -> Result<PathBuf>
60    where
61        P: AsRef<Path> + std::fmt::Debug,
62    {
63        let url = self.fetch_url();
64
65        let model_filename = PathBuf::from(url.split('/').last().unwrap());
66        let model_filepath = download_dir.as_ref().join(model_filename);
67
68        if model_filepath.exists() {
69            info!(
70                model_filepath = format!("{}", model_filepath.display()).as_str(),
71                "File already exists, not re-downloading.",
72            );
73            Ok(model_filepath)
74        } else {
75            info!(
76                model_filepath = format!("{}", model_filepath.display()).as_str(),
77                url = format!("{:?}", url).as_str(),
78                "Downloading file, please wait....",
79            );
80
81            let resp = ureq::get(url)
82                .timeout(Duration::from_secs(180)) // 3 minutes
83                .call()
84                .map_err(Box::new)
85                .map_err(OrtDownloadError::UreqError)?;
86
87            assert!(resp.has("Content-Length"));
88            let len = resp
89                .header("Content-Length")
90                .and_then(|s| s.parse::<usize>().ok())
91                .unwrap();
92            info!(len, "Downloading {} bytes...", len);
93
94            let mut reader = resp.into_reader();
95
96            let f = fs::File::create(&model_filepath).unwrap();
97            let mut writer = io::BufWriter::new(f);
98
99            let bytes_io_count =
100                io::copy(&mut reader, &mut writer).map_err(OrtDownloadError::IoError)?;
101
102            if bytes_io_count == len as u64 {
103                Ok(model_filepath)
104            } else {
105                Err(OrtDownloadError::CopyError {
106                    expected: len as u64,
107                    io: bytes_io_count,
108                }
109                .into())
110            }
111        }
112    }
113}