Skip to main content

dataset_core/
utils.rs

1use crate::DatasetError;
2use sha2::{Digest, Sha256};
3use std::fs::File;
4use std::io;
5use std::io::Read;
6use std::path::{Path, PathBuf};
7use zip::ZipArchive;
8use zip::result::ZipError;
9
10/// Download a remote file into the given directory.
11///
12/// It downloads the content at `url` (using [`ureq`] crate) into `storage_path` using the file name
13/// extracted from the last segment of the URL, unless a custom filename is provided.
14///
15/// # Parameters
16///
17/// - `url` - The URL to download.
18/// - `storage_path` - The directory to store the downloaded file in.
19/// - `filename` - Optional custom filename (with extension). If `None`, the filename is extracted
20///   from the last segment of the URL.
21///
22/// # Errors
23///
24/// - `DatasetError` - Returned when the download fails or URL is invalid.
25///
26/// # Example
27/// ```no_run
28/// use dataset_core::download_to;
29/// use std::path::Path;
30///
31/// let download_dir = "./download_example";
32/// std::fs::create_dir_all(download_dir).unwrap();
33///
34/// // Download a file from the internet
35/// let url = "https://gist.githubusercontent.com/curran/a08a1080b88344b0c8a7/raw/0e7a9b0a5d22642a06d3d5b9bcbad9890c8ee534/iris.csv";
36///
37/// // Use filename from URL
38/// download_to(url, Path::new(download_dir), None).unwrap();
39/// assert!(Path::new(download_dir).join("iris.csv").exists());
40///
41/// // Use custom filename
42/// download_to(url, Path::new(download_dir), Some("custom.csv")).unwrap();
43/// assert!(Path::new(download_dir).join("custom.csv").exists());
44/// ```
45pub fn download_to(
46    url: &str,
47    storage_path: &Path,
48    filename: Option<&str>,
49) -> Result<(), DatasetError> {
50    // Get the filename: use provided name, or fall back to URL extraction
51    let filename = filename.or_else(|| url.split('/').next_back()).ok_or_else(|| {
52        DatasetError::ValidationError("Invalid URL: cannot extract filename from URL".to_string())
53    })?;
54
55    let save_path = storage_path.join(filename);
56
57    let mut response = ureq::get(url).call()?;
58    let mut body = response.body_mut().as_reader();
59
60    // create local file and write body to it
61    let mut file = File::create(save_path)?;
62    io::copy(&mut body, &mut file)?;
63
64    Ok(())
65}
66
67/// Extract a zip archive into a target directory using [`ZipArchive`] in [`zip`] crate.
68///
69/// # Parameters
70///
71/// - `file_path` - Path to the `.zip` file to extract.
72/// - `extract_dir` - Directory to extract the archive contents into.
73///
74/// # Errors
75///
76/// - `DatasetError` - Returned when opening the zip file fails or when extraction fails.
77///
78/// # Example
79/// ```no_run
80/// use dataset_core::{download_to, unzip};
81/// use std::path::Path;
82///
83/// let work_dir = "./unzip_example";
84/// std::fs::create_dir_all(work_dir).unwrap();
85///
86/// // First download a file
87/// let url = "https://gist.githubusercontent.com/curran/a08a1080b88344b0c8a7/raw/0e7a9b0a5d22642a06d3d5b9bcbad9890c8ee534/iris.csv";
88/// download_to(url, Path::new(work_dir), None).unwrap();
89///
90/// // The file is already a CSV (no extraction needed in this example)
91/// assert!(Path::new(work_dir).join("iris.csv").exists());
92/// ```
93pub fn unzip(file_path: &Path, extract_dir: &Path) -> Result<(), DatasetError> {
94    let file = File::open(file_path).map_err(|e| DatasetError::from(ZipError::Io(e)))?;
95
96    ZipArchive::new(file)?.extract(extract_dir)?;
97
98    Ok(())
99}
100
101/// Create a temporary directory under the given parent directory.
102///
103/// This is a small wrapper around [`tempfile::Builder`] used by dataset loaders to
104/// keep intermediate download/extraction artifacts isolated. The created directory
105/// is removed automatically when the returned [`tempfile::TempDir`] is dropped.
106///
107/// # Parameters
108///
109/// - `tempdir_in` - The parent directory in which the temporary directory will be created.
110///
111/// # Errors
112///
113/// - `DatasetError` - Returned if the temporary directory cannot be created.
114///
115/// # Example
116/// ```no_run
117/// use dataset_core::create_temp_dir;
118/// use std::path::Path;
119///
120/// let parent_dir = "./temp_dir_example";
121/// std::fs::create_dir_all(parent_dir).unwrap();
122///
123/// // Create a temporary directory
124/// let temp_dir = create_temp_dir(Path::new(parent_dir)).unwrap();
125/// let temp_path = temp_dir.path();
126///
127/// // Use the temporary directory for intermediate operations
128/// let temp_file = temp_path.join("temp_file.txt");
129/// std::fs::write(&temp_file, "temporary content").unwrap();
130/// assert!(temp_file.exists());
131///
132/// // The temporary directory is automatically removed when `temp_dir` is dropped
133/// drop(temp_dir);
134/// ```
135pub fn create_temp_dir(tempdir_in: &Path) -> Result<tempfile::TempDir, DatasetError> {
136    let temp_dir = tempfile::Builder::new().tempdir_in(tempdir_in)?;
137
138    Ok(temp_dir)
139}
140
141/// Verify that a file's SHA256 hash matches an expected value.
142///
143/// This function computes the SHA256 hash of the file at the given path and compares
144/// it with the expected hexadecimal hash string (case-insensitive). It is used by
145/// dataset loaders to validate downloaded files before parsing.
146///
147/// # Parameters
148///
149/// - `path` - Path to the file to verify.
150/// - `expected_hex` - Expected SHA256 hash as a hexadecimal string.
151///
152/// # Returns
153///
154/// - `bool` - true if the computed hash matches the expected hash, false if the hashes don't match
155///
156/// # Errors
157///
158/// - `DatasetError::IoError` - Returned when file I/O operations fail (opening file, reading data).
159///
160/// # Example
161/// ```no_run
162/// use dataset_core::file_sha256_matches;
163/// use std::path::Path;
164/// use std::io::Write;
165///
166/// let test_dir = "./sha256_example";
167/// std::fs::create_dir_all(test_dir).unwrap();
168///
169/// // Create a test file with known content
170/// let file_path = Path::new(test_dir).join("test.txt");
171/// let mut file = std::fs::File::create(&file_path).unwrap();
172/// file.write_all(b"hello world").unwrap();
173/// drop(file);
174///
175/// // SHA256 of "hello world" is:
176/// // b94d27b9934d3e08a52e52d7da7dabfac484efe37a5380ee9088f7ace2efcde9
177/// let expected_hash = "b94d27b9934d3e08a52e52d7da7dabfac484efe37a5380ee9088f7ace2efcde9";
178///
179/// // Verify the hash matches
180/// assert!(file_sha256_matches(&file_path, expected_hash).unwrap());
181///
182/// // Case-insensitive comparison also works
183/// let upper_hash = "B94D27B9934D3E08A52E52D7DA7DABFAC484EFE37A5380EE9088F7ACE2EFCDE9";
184/// assert!(file_sha256_matches(&file_path, upper_hash).unwrap());
185///
186/// // Wrong hash returns false
187/// assert!(!file_sha256_matches(&file_path, "0000000000000000000000000000000000000000000000000000000000000000").unwrap());
188/// ```
189pub fn file_sha256_matches(path: &Path, expected_hex: &str) -> Result<bool, DatasetError> {
190    let mut file = File::open(path)?;
191
192    let mut hasher = Sha256::new();
193    let mut buf = [0u8; 8192];
194
195    loop {
196        let read = file.read(&mut buf)?;
197        if read == 0 {
198            break;
199        }
200        hasher.update(&buf[..read]);
201    }
202
203    let digest = hasher.finalize();
204    let actual_hex = digest
205        .iter()
206        .map(|b| format!("{:02x}", b))
207        .collect::<String>();
208    Ok(actual_hex.eq_ignore_ascii_case(expected_hex))
209}
210
211/// Evaluate the storage state for a dataset file.
212///
213/// This helper ensures the target directory exists and checks whether the destination
214/// file is already present and valid. If `expected_sha256` is `None`, any existing
215/// file at `dst` is accepted without validation.
216///
217/// # Parameters
218///
219/// - `path` - Directory path where the dataset will be stored.
220/// - `dst` - Destination file path for the dataset.
221/// - `expected_sha256` - Optional expected SHA256 hash for the dataset file. If `None`,
222///   any existing file at `dst` is accepted without validation.
223///
224/// # Returns
225///
226/// - `(need_acquire, need_overwrite)` - Flags indicating whether a new file needs to be
227///   prepared and whether an existing file should be overwritten.
228///
229/// # Errors
230///
231/// - `DatasetError::IoError` - Returned when creating the directory fails or when
232///   file I/O operations fail during hash verification.
233///
234/// # Example
235/// ```no_run
236/// use dataset_core::utils::evaluate_storage;
237/// use std::path::Path;
238/// use std::io::Write;
239///
240/// let test_dir = "./evaluate_storage_example";
241/// let dir_path = Path::new(test_dir);
242/// let file_path = dir_path.join("data.txt");
243///
244/// // Case 1: Directory doesn't exist yet
245/// let (need_acquire, need_overwrite) = evaluate_storage(
246///     dir_path,
247///     &file_path,
248///     None,
249/// ).unwrap();
250/// assert!(need_acquire);     // File doesn't exist, a new file must be prepared
251/// assert!(!need_overwrite);  // Nothing to overwrite
252///
253/// // Case 2: File exists with correct hash
254/// let mut file = std::fs::File::create(&file_path).unwrap();
255/// file.write_all(b"hello world").unwrap();
256/// drop(file);
257///
258/// let correct_hash = "b94d27b9934d3e08a52e52d7da7dabfac484efe37a5380ee9088f7ace2efcde9";
259/// let (need_acquire, need_overwrite) = evaluate_storage(
260///     dir_path,
261///     &file_path,
262///     Some(correct_hash),
263/// ).unwrap();
264/// assert!(!need_acquire);    // File exists with correct hash
265/// assert!(!need_overwrite);
266///
267/// // Case 3: File exists but hash doesn't match
268/// let wrong_hash = "0000000000000000000000000000000000000000000000000000000000000000";
269/// let (need_acquire, need_overwrite) = evaluate_storage(
270///     dir_path,
271///     &file_path,
272///     Some(wrong_hash),
273/// ).unwrap();
274/// assert!(need_acquire);     // Hash mismatch, a new file must be prepared
275/// assert!(need_overwrite);   // Existing file needs to be replaced
276/// ```
277pub fn evaluate_storage(
278    path: &Path,
279    dst: &Path,
280    expected_sha256: Option<&str>,
281) -> Result<(bool, bool), DatasetError> {
282    let mut need_acquire = true;
283    let mut need_overwrite = false;
284
285    if !path.exists() {
286        std::fs::create_dir_all(path)?;
287    }
288
289    if dst.exists() {
290        if let Some(hash) = expected_sha256 {
291            // SHA256 validation enabled
292            if file_sha256_matches(dst, hash)? {
293                need_acquire = false;
294            } else {
295                need_overwrite = true;
296            }
297        } else {
298            // No SHA256 validation: accept existing file
299            need_acquire = false;
300        }
301    }
302
303    Ok((need_acquire, need_overwrite))
304}
305
306/// Acquire a dataset file using a caller-provided preparation closure.
307///
308/// This function orchestrates the dataset acquisition workflow: it checks whether
309/// the destination file can be reused, creates a temporary directory when a new
310/// file is needed, delegates file preparation to a user-provided closure,
311/// optionally validates the prepared file with SHA256, and moves it to the final
312/// destination.
313///
314/// The function itself does not perform network I/O. The `prepare_file` closure
315/// is responsible for preparing the dataset file, which may include downloading,
316/// extracting archives, or locating files within an extracted directory.
317///
318/// # Parameters
319///
320/// - `dir` - Target storage directory path.
321/// - `filename` - Final dataset filename (stored as `dir/filename`).
322///   Please give the filename with the extension (e.g., `"iris.csv"`).
323/// - `dataset_name` - Dataset name for error messages (e.g., `"iris"`).
324/// - `expected_sha256` - Optional expected SHA256 hash of the dataset file. If `None`,
325///   any existing file at the destination is accepted without validation, and newly
326///   prepared files skip SHA256 verification.
327/// - `prepare_file` - Closure that prepares the dataset file in the temporary directory.
328///   - Input: `temp_dir: &Path` - Path to the temporary directory.
329///     It is recommended to execute file operations within this directory, as it will be
330///     cleaned up automatically when the closure returns. But it is not required.
331///     (Please note that the file will be moved to the final destination, not copied.)
332///   - Output: `Result<PathBuf, DatasetError>` - Path to the prepared dataset file
333///     (which will be moved to `dir/filename`).
334///   - Responsibility: This closure can perform any operations needed to prepare the
335///     dataset file, such as downloading (you can use [`download_to`] provided in this crate),
336///     extracting archives (you can use [`unzip`] provided in this crate), or locating files
337///     within extracted folders. The returned `PathBuf` must point to the final dataset file
338///     ready for validation.
339///
340/// # Returns
341///
342/// - `PathBuf` - Path to the final dataset file (`dir/filename`).
343///
344/// # Errors
345///
346/// - `DatasetError::IoError` - Returned when directory creation, file operations, or
347///   hash verification fails.
348/// - `DatasetError::Sha256ValidationFailed` - Returned when `expected_sha256` is provided
349///   and the prepared file's SHA256 hash does not match it.
350/// - Any error returned by the `prepare_file` closure.
351///
352/// # Example
353/// ```no_run
354/// // Implement the file preparation process for the Iris dataset.
355///
356/// /// The URL for the Iris dataset.
357/// ///
358/// /// # Citation
359/// ///
360/// /// R. A. Fisher. "Iris," UCI Machine Learning Repository, \[Online\].
361/// /// Available: <https://doi.org/10.24432/C56C76>
362/// const IRIS_DATA_URL: &str = "https://gist.githubusercontent.com/curran/a08a1080b88344b0c8a7/raw/0e7a9b0a5d22642a06d3d5b9bcbad9890c8ee534/iris.csv";
363///
364/// /// The name of the Iris dataset file.
365/// const IRIS_FILENAME: &str = "iris.csv";
366///
367/// /// The SHA256 hash of the Iris dataset file.
368/// const IRIS_SHA256: &str = "c52742e50315a99f956a383faedf7575552675f6409ef0f9a47076dd08479930";
369///
370/// /// The name of the dataset.
371/// const IRIS_DATASET_NAME: &str = "iris";
372///
373/// use dataset_core::acquire_dataset;
374/// use dataset_core::download_to;
375///
376/// fn main() {
377///     let dir = "./somewhere";
378///
379///     let file_path = acquire_dataset(
380///         // Target storage directory path
381///         dir,
382///         // Final dataset filename (will be stored as `dir/filename`)
383///         IRIS_FILENAME,
384///         // Dataset name for error messages
385///         IRIS_DATASET_NAME,
386///         // Expected SHA256 hash of the dataset file
387///         Some(IRIS_SHA256),
388///         // Closure that prepares the dataset file in the temporary directory
389///         |temp_path| {
390///             // Download the dataset into the temporary directory
391///             download_to(IRIS_DATA_URL, temp_path, None)?;
392///             Ok(temp_path.join(IRIS_FILENAME))
393///         },
394///     ).unwrap();
395///
396///     // `file_path` is now the path to the acquired Iris dataset file.
397///     // It can be used to locate or parse the dataset.
398/// }
399/// ```
400pub fn acquire_dataset<F>(
401    dir: &str,
402    filename: &str,
403    dataset_name: &str,
404    expected_sha256: Option<&str>,
405    prepare_file: F,
406) -> Result<PathBuf, DatasetError>
407where
408    F: FnOnce(&Path) -> Result<PathBuf, DatasetError>,
409{
410    let dir_path = Path::new(dir);
411    let dst = dir_path.join(filename);
412    let (need_acquire, need_overwrite) = evaluate_storage(dir_path, &dst, expected_sha256)?;
413
414    if need_acquire {
415        let temp_dir = create_temp_dir(dir_path)?;
416        let temp_path = temp_dir.path();
417
418        // Call user closure: prepare the dataset file in temporary directory
419        let src = prepare_file(temp_path)?;
420
421        // Validate SHA256 hash if provided
422        if let Some(hash) = expected_sha256 {
423            if !file_sha256_matches(&src, hash)? {
424                drop(temp_dir); // Clean up temporary directory
425                return Err(DatasetError::sha256_validation_failed(
426                    dataset_name,
427                    filename,
428                ));
429            }
430        }
431
432        // Move file to final destination
433        if need_overwrite {
434            std::fs::remove_file(&dst)?;
435        }
436        std::fs::rename(&src, &dst)?;
437    }
438
439    Ok(dst)
440}