1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
use lazy_static::lazy_static;
use std::collections::HashMap;
use std::fs;
use std::io::Cursor;
use std::path::PathBuf;
use thiserror::Error;

lazy_static! {
    /// Maps each known model name to the base URL its resource files are fetched from.
    ///
    /// Parsed from `models.csv`, which is embedded into the binary at compile time. Each
    /// non-blank line is read as `<model_name>,<base_url>` (any further comma-separated fields
    /// are ignored). If a name appears more than once, the last entry wins.
    ///
    /// # Panics
    ///
    /// Panics on first access if a non-blank line contains no comma. Since the file is embedded
    /// at build time, this indicates a packaging bug rather than a runtime condition.
    static ref MODEL_DATA: HashMap<&'static str, &'static str> = {
        // this is checked at compile time so a relative path is ok
        let raw_csv = include_str!("../models.csv");

        raw_csv
            .lines()
            // tolerate blank lines (e.g. stray whitespace at the end of the file) instead of
            // panicking on them
            .filter(|line| !line.trim().is_empty())
            .map(|line| {
                let mut fields = line.split(',');
                match (fields.next(), fields.next()) {
                    (Some(name), Some(url)) => (name, url),
                    _ => panic!(
                        "malformed line in models.csv (expected `name,url`): {:?}",
                        line
                    ),
                }
            })
            .collect()
    };
}

/// An error retrieving a resource.
#[derive(Error, Debug)]
pub enum ResourceError {
    /// Fetching a resource file over the network failed.
    #[error("network error fetching \"{file_name}\" for \"{model_name}\": {source}")]
    NetworkError {
        /// Name of the model the file belongs to.
        model_name: String,
        /// Name of the file that was requested.
        file_name: String,
        /// The underlying error reported by the HTTP client.
        source: minreq::Error,
    },
    /// The requested model name is not known.
    #[error("model not found: \"{model_name}\"")]
    ModelNotFoundError {
        /// The model name that was looked up.
        model_name: String,
    },
    /// A URL could not be parsed or joined.
    #[error(transparent)]
    UrlParseError {
        /// The underlying URL parse error.
        source: url::ParseError,
    },
    /// An I/O operation failed (e.g. while reading or writing a cached resource file).
    #[error(transparent)]
    IoError {
        /// The underlying I/O error.
        source: std::io::Error,
    },
}

impl From<url::ParseError> for ResourceError {
    fn from(source: url::ParseError) -> Self {
        ResourceError::UrlParseError { source }
    }
}

impl From<std::io::Error> for ResourceError {
    fn from(source: std::io::Error) -> Self {
        ResourceError::IoError { source }
    }
}

/// Loads the file for the given model, either retrieving it from the cache or downloading it if it is not found.
pub fn get_resource(
    model_name: &str,
    file: &str,
) -> Result<(impl std::io::Read, Option<PathBuf>), ResourceError> {
    let base_url = url::Url::parse(MODEL_DATA.get(model_name).ok_or_else(|| {
        ResourceError::ModelNotFoundError {
            model_name: model_name.to_owned(),
        }
    })?)?;
    let url = base_url.join(file)?;
    let mut cache_path: Option<PathBuf> = None;

    // try to find a file at which to cache the data
    if let Some(project_dirs) = directories::ProjectDirs::from("", "", "nnsplit") {
        let cache_dir = project_dirs.cache_dir();

        cache_path = Some(cache_dir.join(model_name).join(file));
    }

    // if the file can be read, the data is already cached ...
    if let Some(path) = &cache_path {
        if let Ok(bytes) = fs::read(path) {
            return Ok((Cursor::new(bytes), cache_path));
        }
    }

    // ... otherwise, request the data from the URL ...
    let bytes = minreq::get(&url.to_string())
        .send()
        .map_err(|source| ResourceError::NetworkError {
            model_name: model_name.to_owned(),
            file_name: file.to_owned(),
            source,
        })?
        .into_bytes();

    // ... and then cache the data at the provided file, if one was found
    if let Some(path) = &cache_path {
        fs::create_dir_all(path.parent().unwrap())?;
        fs::write(path, &bytes)?;
    }

    Ok((Cursor::new(bytes), cache_path))
}