weld_codegen/
loader.rs

1use std::{
2    path::{Path, PathBuf},
3    time::Duration,
4};
5
6use atelier_core::model::Model;
7use reqwest::Url;
8use rustc_hash::FxHasher;
9
10use crate::{
11    config::ModelSource,
12    error::{Error, Result},
13};
14
/// Upper bound on how many downloads the parallel downloader performs at once.
const MAX_PARALLEL_DOWNLOADS: u16 = 8;
/// How long a cached smithy file may be reused before another download is attempted.
const CACHED_FILE_MAX_AGE: Duration = Duration::from_secs(24 * 60 * 60); // 24 hours
/// Environment variable read by `cache_expired` to override cache expiry.
const SMITHY_CACHE_ENV_VAR: &str = "SMITHY_CACHE";
/// Value of `SMITHY_CACHE` that makes cached files never expire.
const SMITHY_CACHE_NO_EXPIRE: &str = "NO_EXPIRE";
21
22/// Load all model sources and merge into single model.
23/// - Sources may be a combination of files, directories, and urls.
24/// - Model files may be .smithy or .json
25/// See the codegen.toml documentation on `[[models]]` for
26/// a description of valid ModelSources.
27/// If `relative_dir` is provided, all relative paths read will be made relative to that folder,
28/// (Relative paths in codegen.toml are relative to the file codegen.toml, not
29/// necessarily the current directory of the OS process)
30/// Returns single merged model.
31pub fn sources_to_model(sources: &[ModelSource], base_dir: &Path, verbose: u8) -> Result<Model> {
32    let paths = sources_to_paths(sources, base_dir, verbose)?;
33    let mut assembler = atelier_assembler::ModelAssembler::default();
34    for path in paths.iter() {
35        if !path.exists() {
36            return Err(Error::MissingFile(format!(
37                "'{}' is not a valid path to a file or directory",
38                path.display(),
39            )));
40        }
41        let _ = assembler.push(path);
42    }
43    let model: Model = assembler
44        .try_into()
45        .map_err(|e| Error::Model(format!("assembling model: {e:#?}")))?;
46    Ok(model)
47}
48
49/// Flatten source lists and collect list of paths to local files.
50/// All returned paths that were relative have been joined to base_dir.
51/// Download any urls to cache dir if they aren't already cached.
52/// If any of the source paths are local directories, they are passed
53/// to the result and the caller is expected to traverse them
54/// or pass them to an Assembler for traversal.
55#[doc(hidden)]
56pub(crate) fn sources_to_paths(
57    sources: &[ModelSource],
58    base_dir: &Path,
59    verbose: u8,
60) -> Result<Vec<PathBuf>> {
61    let mut results = Vec::new();
62    let mut urls = Vec::new();
63
64    for source in sources.iter() {
65        match source {
66            ModelSource::Path { path, files } => {
67                let prefix = if path.is_absolute() {
68                    path.to_path_buf()
69                } else {
70                    base_dir.join(path)
71                };
72                if files.is_empty() {
73                    // If path is a file, it will be added; if a directory, and source.files is empty,
74                    // the directory will be traversed to find model files
75                    if verbose > 0 {
76                        println!("DEBUG: adding path: {}", &prefix.display());
77                    }
78                    results.push(prefix)
79                } else {
80                    for file in files.iter() {
81                        let path = prefix.join(file);
82                        if verbose > 0 {
83                            println!("DEBUG: adding path: {}", &path.display());
84                        }
85                        results.push(path);
86                    }
87                }
88            }
89            ModelSource::Url { url, files } => {
90                if files.is_empty() {
91                    if verbose > 0 {
92                        println!("DEBUG: adding url: {url}");
93                    }
94                    urls.push(url.to_string());
95                } else {
96                    for file in files.iter() {
97                        let url = format!(
98                            "{}{}{}",
99                            url,
100                            if !url.ends_with('/') && !file.starts_with('/') { "/" } else { "" },
101                            file
102                        );
103                        if verbose > 0 {
104                            println!("DEBUG: adding url: {}", &url);
105                        }
106                        urls.push(url);
107                    }
108                }
109            }
110        }
111    }
112    if !urls.is_empty() {
113        let cached = urls_to_cached_files(urls)?;
114        results.extend_from_slice(&cached);
115    }
116    Ok(results)
117}
118
119/// Returns cache_path, relative to download directory
120/// format: host_dir/file_stem.HASH.ext
121fn url_to_cache_path(url: &str) -> Result<PathBuf> {
122    let origin = url.parse::<Url>().map_err(|e| bad_url(url, e))?;
123    let host_dir = origin.host_str().ok_or_else(|| bad_url(url, "no-host"))?;
124    let file_name = PathBuf::from(
125        origin
126            .path_segments()
127            .ok_or_else(|| bad_url(url, "path"))?
128            .last()
129            .map(|s| s.to_string())
130            .ok_or_else(|| bad_url(url, "last-path"))?,
131    );
132    let file_stem = file_name
133        .file_stem()
134        .map(|s| s.to_str())
135        .unwrap_or_default()
136        .unwrap_or("index");
137    let file_ext = file_name
138        .extension()
139        .map(|s| s.to_str())
140        .unwrap_or_default()
141        .unwrap_or("raw");
142    let new_file_name = format!("{}.{:x}.{}", file_stem, hash(origin.path()), file_ext);
143    let path = PathBuf::from(host_dir).join(new_file_name);
144    Ok(path)
145}
146
147/// Locate the weld cache directory
148#[doc(hidden)]
149pub fn weld_cache_dir() -> Result<PathBuf> {
150    let dirs = directories::BaseDirs::new()
151        .ok_or_else(|| Error::Other("invalid home directory".to_string()))?;
152    let weld_cache = dirs.cache_dir().join("smithy");
153    Ok(weld_cache)
154}
155
156/// Returns true if the file is older than the specified cache age.
157/// If the environment contains SMITHY_CACHE=NO_EXPIRE, the file age is ignored and false is returned.
158pub fn cache_expired(path: &Path) -> bool {
159    if let Ok(cache_flag) = std::env::var(SMITHY_CACHE_ENV_VAR) {
160        if cache_flag == SMITHY_CACHE_NO_EXPIRE {
161            return false;
162        }
163    }
164    if let Ok(md) = std::fs::metadata(path) {
165        if let Ok(modified) = md.modified() {
166            if let Ok(age) = modified.elapsed() {
167                return age >= CACHED_FILE_MAX_AGE;
168            }
169        }
170    }
171    // If the OS can't read the file timestamp, assume it's expired and return true.
172    true
173}
174
175/// Returns a list of cached files for a list of urls. Files that are not present in the cache are fetched
176/// with a parallel downloader. This function fails if any file cannot be retrieved.
177/// Files are downloaded into a temp dir, so that if there's a download error they don't overwrite
178/// any cached values
179fn urls_to_cached_files(urls: Vec<String>) -> Result<Vec<PathBuf>> {
180    let mut results = Vec::new();
181    let mut to_download = Vec::new();
182
183    let weld_cache = weld_cache_dir()?;
184
185    let tmpdir =
186        tempfile::tempdir().map_err(|e| Error::Io(format!("creating temp folder: {e}")))?;
187    for url in urls.iter() {
188        let rel_path = url_to_cache_path(url)?;
189        let cache_path = weld_cache.join(&rel_path);
190        if cache_path.is_file() && !cache_expired(&cache_path) {
191            // found cached file
192            results.push(cache_path);
193        } else {
194            // no cache file (or expired), download to temp dir
195            let temp_path = tmpdir.path().join(&rel_path);
196            std::fs::create_dir_all(temp_path.parent().unwrap()).map_err(|e| {
197                crate::Error::Io(format!(
198                    "creating folder {}: {}",
199                    &temp_path.parent().unwrap().display(),
200                    e,
201                ))
202            })?;
203            let dl = downloader::Download::new(url).file_name(&temp_path);
204            to_download.push(dl);
205        }
206    }
207
208    if !to_download.is_empty() {
209        let mut downloader = downloader::Downloader::builder()
210            .download_folder(tmpdir.path())
211            .parallel_requests(MAX_PARALLEL_DOWNLOADS)
212            .build()
213            .map_err(|e| Error::Other(format!("internal error: download failure: {e}")))?;
214        // invoke parallel downloader, returns when all have been read
215        let result = downloader
216            .download(&to_download)
217            .map_err(|e| Error::Other(format!("download error: {e}")))?;
218
219        for r in result.iter() {
220            match r {
221                Err(e) => {
222                    println!("Failure downloading: {e}");
223                }
224                Ok(summary) => {
225                    for status in summary.status.iter() {
226                        if (200..300).contains(&status.1) {
227                            // take first with status ok
228                            let downloaded_file = &summary.file_name;
229                            let rel_path = downloaded_file.strip_prefix(&tmpdir).map_err(|e| {
230                                Error::Other(format!("internal download error {e}"))
231                            })?;
232                            let cache_file = weld_cache.join(rel_path);
233                            std::fs::create_dir_all(cache_file.parent().unwrap()).map_err(|e| {
234                                Error::Io(format!(
235                                    "creating folder {}: {}",
236                                    &cache_file.parent().unwrap().display(),
237                                    e
238                                ))
239                            })?;
240                            std::fs::copy(downloaded_file, &cache_file).map_err(|e| {
241                                Error::Other(format!(
242                                    "writing cache file {}: {}",
243                                    &cache_file.display(),
244                                    e
245                                ))
246                            })?;
247                            results.push(cache_file);
248                            break;
249                        } else {
250                            println!("Warning: url '{}' got status {}", status.0, status.1);
251                        }
252                    }
253                }
254            };
255        }
256    }
257    if results.len() != urls.len() {
258        Err(Error::Other(format!(
259            "Quitting - {} model files could not be downloaded and were not found in the cache. \
260             If you have previously built this project and are working \"offline\", try setting \
261             SMITHY_CACHE=NO_EXPIRE in the environment",
262            urls.len() - results.len()
263        )))
264    } else {
265        Ok(results)
266    }
267}
268
269fn bad_url<E: std::fmt::Display>(s: &str, e: E) -> Error {
270    Error::Other(format!("bad url {s}: {e}"))
271}
272
/// Return type for tests that may fail with any boxed error.
#[cfg(test)]
type TestResult = core::result::Result<(), Box<dyn std::error::Error>>;
275
#[test]
fn test_cache_path() -> TestResult {
    // Render the cache path for a url as an owned string so it can be compared directly.
    let cache_path_of =
        |url: &str| -> String { url_to_cache_path(url).unwrap().to_str().unwrap().to_owned() };

    assert_eq!(
        "localhost/file.1dc75e4e94bec8fd.smithy",
        cache_path_of("http://localhost/path/file.smithy")
    );

    assert_eq!(
        "localhost/file.cd93a55565eb790a.smithy",
        cache_path_of("http://localhost/path/to/file.smithy"),
        "hash changes with path"
    );

    assert_eq!(
        "localhost/file.1dc75e4e94bec8fd.smithy",
        cache_path_of("http://localhost:8080/path/file.smithy"),
        "hash is not dependent on port",
    );

    assert_eq!(
        "127.0.0.1/file.1dc75e4e94bec8fd.smithy",
        cache_path_of("http://127.0.0.1/path/file.smithy"),
        "hash is not dependent on host",
    );

    assert_eq!(
        "127.0.0.1/foo.3f066558cb61d00f.raw",
        cache_path_of("http://127.0.0.1/path/foo"),
        "generate .raw for missing extension",
    );

    assert_eq!(
        "127.0.0.1/index.ce34ccb3ff9b34cd.raw",
        cache_path_of("http://127.0.0.1/dir/"),
        "generate index.raw for missing filename",
    );

    Ok(())
}
327
328fn hash(s: &str) -> u64 {
329    use std::hash::Hasher;
330    let mut hasher = FxHasher::default();
331    hasher.write(s.as_bytes());
332    hasher.finish()
333}
334
#[test]
fn test_hash() {
    // (input, expected FxHasher digest)
    let expectations: [(&str, u64); 2] = [("", 0), ("hello", 18099358241699475913)];
    for (input, expected) in expectations {
        assert_eq!(expected, hash(input));
    }
}