1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
//! Module controlling models downloadable from ONNX Model Zoom
//!
//! Pre-trained models are available from the
//! [ONNX Model Zoo](https://github.com/onnx/models).
//!
//! A pre-trained model can be downloaded automatically using the
//! [`SessionBuilder`](../session/struct.SessionBuilder.html)'s
//! [`with_model_downloaded()`](../session/struct.SessionBuilder.html#method.with_model_downloaded) method.
//!
//! See [`AvailableOnnxModel`](enum.AvailableOnnxModel.html) for the different models available
//! to download.

#[cfg(feature = "model-fetching")]
use std::{
    fs, io,
    path::{Path, PathBuf},
    time::Duration,
};

#[cfg(feature = "model-fetching")]
use crate::error::{OrtDownloadError, Result};

#[cfg(feature = "model-fetching")]
use tracing::info;

pub mod language;
pub mod vision;

/// Available pre-trained models to download from [ONNX Model Zoo](https://github.com/onnx/models).
///
/// According to [ONNX Model Zoo](https://github.com/onnx/models)'s GitHub page:
///
/// > The ONNX Model Zoo is a collection of pre-trained, state-of-the-art models in the ONNX format
/// > contributed by community members like you.
#[derive(Debug, Clone)]
pub enum AvailableOnnxModel {
    /// Computer vision model
    Vision(vision::Vision),
    /// Natural language model
    Language(language::Language),
}

trait ModelUrl {
    fn fetch_url(&self) -> &'static str;
}

impl ModelUrl for AvailableOnnxModel {
    fn fetch_url(&self) -> &'static str {
        match self {
            AvailableOnnxModel::Vision(model) => model.fetch_url(),
            AvailableOnnxModel::Language(model) => model.fetch_url(),
        }
    }
}

impl AvailableOnnxModel {
    #[cfg(feature = "model-fetching")]
    #[tracing::instrument]
    pub(crate) fn download_to<P>(&self, download_dir: P) -> Result<PathBuf>
    where
        P: AsRef<Path> + std::fmt::Debug,
    {
        let url = self.fetch_url();

        let model_filename = PathBuf::from(url.split('/').last().unwrap());
        let model_filepath = download_dir.as_ref().join(model_filename);

        if model_filepath.exists() {
            info!(
                model_filepath = format!("{}", model_filepath.display()).as_str(),
                "File already exists, not re-downloading.",
            );
            Ok(model_filepath)
        } else {
            info!(
                model_filepath = format!("{}", model_filepath.display()).as_str(),
                url = format!("{:?}", url).as_str(),
                "Downloading file, please wait....",
            );

            let resp = ureq::get(url)
                .timeout(Duration::from_secs(180)) // 3 minutes
                .call()
                .map_err(Box::new)
                .map_err(OrtDownloadError::UreqError)?;

            assert!(resp.has("Content-Length"));
            let len = resp
                .header("Content-Length")
                .and_then(|s| s.parse::<usize>().ok())
                .unwrap();
            info!(len, "Downloading {} bytes...", len);

            let mut reader = resp.into_reader();

            let f = fs::File::create(&model_filepath).unwrap();
            let mut writer = io::BufWriter::new(f);

            let bytes_io_count =
                io::copy(&mut reader, &mut writer).map_err(OrtDownloadError::IoError)?;

            if bytes_io_count == len as u64 {
                Ok(model_filepath)
            } else {
                Err(OrtDownloadError::CopyError {
                    expected: len as u64,
                    io: bytes_io_count,
                }
                .into())
            }
        }
    }
}