Skip to main content

dataset_core/
utils.rs

1use crate::DatasetError;
2use sha2::{Digest, Sha256};
3use std::fs::File;
4use std::io;
5use std::io::Read;
6use std::path::{Path, PathBuf};
7use zip::ZipArchive;
8use zip::result::ZipError;
9
10/// Download a remote file into the given directory.
11///
12/// It downloads the content at `url` (using [`ureq`] crate) into `storage_path` using the file name
13/// extracted from the last segment of the URL, unless a custom filename is provided.
14///
15/// # Parameters
16///
17/// - `url` - The URL to download.
18/// - `storage_path` - The directory to store the downloaded file in.
19/// - `filename` - Optional custom filename (with extension). If `None`, the filename is extracted
20///   from the last segment of the URL.
21///
22/// # Errors
23///
24/// - `DatasetError` - Returned when the download fails or URL is invalid.
25///
26/// # Example
27/// ```rust
28/// use dataset_core::download_to;
29/// use std::path::Path;
30///
31/// let download_dir = "./download_example";
32/// std::fs::create_dir_all(download_dir).unwrap();
33///
34/// // Download a file from the internet
35/// let url = "https://gist.githubusercontent.com/curran/a08a1080b88344b0c8a7/raw/0e7a9b0a5d22642a06d3d5b9bcbad9890c8ee534/iris.csv";
36///
37/// // Use filename from URL
38/// download_to(url, Path::new(download_dir), None).unwrap();
39/// assert!(Path::new(download_dir).join("iris.csv").exists());
40///
41/// // Use custom filename
42/// download_to(url, Path::new(download_dir), Some("custom.csv")).unwrap();
43/// assert!(Path::new(download_dir).join("custom.csv").exists());
44///
45/// // Clean up (dispensable)
46/// std::fs::remove_dir_all(download_dir).unwrap();
47/// ```
48pub fn download_to(
49    url: &str,
50    storage_path: &Path,
51    filename: Option<&str>,
52) -> Result<(), DatasetError> {
53    // Get the filename: use provided name, or fall back to URL extraction
54    let filename = filename.or_else(|| url.split('/').last()).ok_or_else(|| {
55        DatasetError::ValidationError("Invalid URL: cannot extract filename from URL".to_string())
56    })?;
57
58    let save_path = storage_path.join(filename);
59
60    let mut response = ureq::get(url).call()?;
61    let mut body = response.body_mut().as_reader();
62
63    // create local file and write body to it
64    let mut file = File::create(save_path)?;
65    io::copy(&mut body, &mut file)?;
66
67    Ok(())
68}
69
70/// Extract a zip archive into a target directory using [`ZipArchive`] in [`zip`] crate.
71///
72/// # Parameters
73///
74/// - `file_path` - Path to the `.zip` file to extract.
75/// - `extract_dir` - Directory to extract the archive contents into.
76///
77/// # Errors
78///
79/// - `DatasetError` - Returned when opening the zip file fails or when extraction fails.
80///
81/// # Example
82/// ```rust
83/// use dataset_core::{download_to, unzip};
84/// use std::path::Path;
85///
86/// let work_dir = "./unzip_example";
87/// std::fs::create_dir_all(work_dir).unwrap();
88///
89/// // First download a file
90/// let url = "https://gist.githubusercontent.com/curran/a08a1080b88344b0c8a7/raw/0e7a9b0a5d22642a06d3d5b9bcbad9890c8ee534/iris.csv";
91/// download_to(url, Path::new(work_dir), None).unwrap();
92///
93/// // The file is already a CSV (no extraction needed in this example)
94/// assert!(Path::new(work_dir).join("iris.csv").exists());
95///
96/// // Clean up (dispensable)
97/// std::fs::remove_dir_all(work_dir).unwrap();
98/// ```
99pub fn unzip(file_path: &Path, extract_dir: &Path) -> Result<(), DatasetError> {
100    let file = File::open(file_path).map_err(|e| DatasetError::from(ZipError::Io(e)))?;
101
102    ZipArchive::new(file)?.extract(extract_dir)?;
103
104    Ok(())
105}
106
107/// Create a temporary directory under the given parent directory.
108///
109/// This is a small wrapper around [`tempfile::Builder`] used by dataset loaders to
110/// keep intermediate download/extraction artifacts isolated. The created directory
111/// is removed automatically when the returned [`tempfile::TempDir`] is dropped.
112///
113/// # Parameters
114///
115/// - `tempdir_in` - The parent directory in which the temporary directory will be created.
116///
117/// # Errors
118///
119/// - `DatasetError` - Returned if the temporary directory cannot be created.
120///
121/// # Example
122/// ```rust
123/// use dataset_core::create_temp_dir;
124/// use std::path::Path;
125///
126/// let parent_dir = "./temp_dir_example";
127/// std::fs::create_dir_all(parent_dir).unwrap();
128///
129/// // Create a temporary directory
130/// let temp_dir = create_temp_dir(Path::new(parent_dir)).unwrap();
131/// let temp_path = temp_dir.path();
132///
133/// // Use the temporary directory for intermediate operations
134/// let temp_file = temp_path.join("temp_file.txt");
135/// std::fs::write(&temp_file, "temporary content").unwrap();
136/// assert!(temp_file.exists());
137///
138/// // The temporary directory is automatically removed when `temp_dir` is dropped
139/// drop(temp_dir);
140///
141/// // Clean up parent directory
142/// std::fs::remove_dir_all(parent_dir).unwrap();
143/// ```
144pub fn create_temp_dir(tempdir_in: &Path) -> Result<tempfile::TempDir, DatasetError> {
145    let temp_dir = tempfile::Builder::new().tempdir_in(tempdir_in)?;
146
147    Ok(temp_dir)
148}
149
150/// Verify that a file's SHA256 hash matches an expected value.
151///
152/// This function computes the SHA256 hash of the file at the given path and compares
153/// it with the expected hexadecimal hash string (case-insensitive). It is used by
154/// dataset loaders to validate downloaded files before parsing.
155///
156/// # Parameters
157///
158/// - `path` - Path to the file to verify.
159/// - `expected_hex` - Expected SHA256 hash as a hexadecimal string.
160///
161/// # Returns
162///
163/// - `bool` - true if the computed hash matches the expected hash, false if the hashes don't match
164///
165/// # Errors
166///
167/// - `DatasetError::IoError` - Returned when file I/O operations fail (opening file, reading data).
168///
169/// # Example
170/// ```rust
171/// use dataset_core::file_sha256_matches;
172/// use std::path::Path;
173/// use std::io::Write;
174///
175/// let test_dir = "./sha256_example";
176/// std::fs::create_dir_all(test_dir).unwrap();
177///
178/// // Create a test file with known content
179/// let file_path = Path::new(test_dir).join("test.txt");
180/// let mut file = std::fs::File::create(&file_path).unwrap();
181/// file.write_all(b"hello world").unwrap();
182/// drop(file);
183///
184/// // SHA256 of "hello world" is:
185/// // b94d27b9934d3e08a52e52d7da7dabfac484efe37a5380ee9088f7ace2efcde9
186/// let expected_hash = "b94d27b9934d3e08a52e52d7da7dabfac484efe37a5380ee9088f7ace2efcde9";
187///
188/// // Verify the hash matches
189/// assert!(file_sha256_matches(&file_path, expected_hash).unwrap());
190///
191/// // Case-insensitive comparison also works
192/// let upper_hash = "B94D27B9934D3E08A52E52D7DA7DABFAC484EFE37A5380EE9088F7ACE2EFCDE9";
193/// assert!(file_sha256_matches(&file_path, upper_hash).unwrap());
194///
195/// // Wrong hash returns false
196/// assert!(!file_sha256_matches(&file_path, "0000000000000000000000000000000000000000000000000000000000000000").unwrap());
197///
198/// // Clean up (dispensable)
199/// std::fs::remove_dir_all(test_dir).unwrap();
200/// ```
201pub fn file_sha256_matches(path: &Path, expected_hex: &str) -> Result<bool, DatasetError> {
202    let mut file = File::open(path)?;
203
204    let mut hasher = Sha256::new();
205    let mut buf = [0u8; 8192];
206
207    loop {
208        let read = file.read(&mut buf)?;
209        if read == 0 {
210            break;
211        }
212        hasher.update(&buf[..read]);
213    }
214
215    let digest = hasher.finalize();
216    let actual_hex = digest
217        .iter()
218        .map(|b| format!("{:02x}", b))
219        .collect::<String>();
220    Ok(actual_hex.eq_ignore_ascii_case(expected_hex))
221}
222
223/// Evaluate the storage state for a dataset file.
224///
225/// This helper ensures the target directory exists and checks whether the destination
226/// file is already present and valid. If `expected_sha256` is `None`, any existing
227/// file at `dst` is accepted without validation.
228///
229/// # Parameters
230///
231/// - `path` - Directory path where the dataset will be stored.
232/// - `dst` - Destination file path for the dataset.
233/// - `expected_sha256` - Optional expected SHA256 hash for the dataset file. If `None`,
234///   any existing file at `dst` is accepted without validation.
235///
236/// # Returns
237///
238/// - `(need_acquire, need_overwrite)` - Flags indicating whether a new file needs to be
239///   prepared and whether an existing file should be overwritten.
240///
241/// # Errors
242///
243/// - `DatasetError::IoError` - Returned when creating the directory fails or when
244///   file I/O operations fail during hash verification.
245///
246/// # Example
247/// ```rust
248/// use dataset_core::utils::evaluate_storage;
249/// use std::path::Path;
250/// use std::io::Write;
251///
252/// let test_dir = "./evaluate_storage_example";
253/// let dir_path = Path::new(test_dir);
254/// let file_path = dir_path.join("data.txt");
255///
256/// // Case 1: Directory doesn't exist yet
257/// let (need_acquire, need_overwrite) = evaluate_storage(
258///     dir_path,
259///     &file_path,
260///     None,
261/// ).unwrap();
262/// assert!(need_acquire);     // File doesn't exist, a new file must be prepared
263/// assert!(!need_overwrite);  // Nothing to overwrite
264///
265/// // Case 2: File exists with correct hash
266/// let mut file = std::fs::File::create(&file_path).unwrap();
267/// file.write_all(b"hello world").unwrap();
268/// drop(file);
269///
270/// let correct_hash = "b94d27b9934d3e08a52e52d7da7dabfac484efe37a5380ee9088f7ace2efcde9";
271/// let (need_acquire, need_overwrite) = evaluate_storage(
272///     dir_path,
273///     &file_path,
274///     Some(correct_hash),
275/// ).unwrap();
276/// assert!(!need_acquire);    // File exists with correct hash
277/// assert!(!need_overwrite);
278///
279/// // Case 3: File exists but hash doesn't match
280/// let wrong_hash = "0000000000000000000000000000000000000000000000000000000000000000";
281/// let (need_acquire, need_overwrite) = evaluate_storage(
282///     dir_path,
283///     &file_path,
284///     Some(wrong_hash),
285/// ).unwrap();
286/// assert!(need_acquire);     // Hash mismatch, a new file must be prepared
287/// assert!(need_overwrite);   // Existing file needs to be replaced
288///
289/// // Clean up (dispensable)
290/// std::fs::remove_dir_all(test_dir).unwrap();
291/// ```
292pub fn evaluate_storage(
293    path: &Path,
294    dst: &Path,
295    expected_sha256: Option<&str>,
296) -> Result<(bool, bool), DatasetError> {
297    let mut need_acquire = true;
298    let mut need_overwrite = false;
299
300    if !path.exists() {
301        std::fs::create_dir_all(path)?;
302    }
303
304    if dst.exists() {
305        if let Some(hash) = expected_sha256 {
306            // SHA256 validation enabled
307            if file_sha256_matches(dst, hash)? {
308                need_acquire = false;
309            } else {
310                need_overwrite = true;
311            }
312        } else {
313            // No SHA256 validation: accept existing file
314            need_acquire = false;
315        }
316    }
317
318    Ok((need_acquire, need_overwrite))
319}
320
321/// Acquire a dataset file using a caller-provided preparation closure.
322///
323/// This function orchestrates the dataset acquisition workflow: it checks whether
324/// the destination file can be reused, creates a temporary directory when a new
325/// file is needed, delegates file preparation to a user-provided closure,
326/// optionally validates the prepared file with SHA256, and moves it to the final
327/// destination.
328///
329/// The function itself does not perform network I/O. The `prepare_file` closure
330/// is responsible for preparing the dataset file, which may include downloading,
331/// extracting archives, or locating files within an extracted directory.
332///
333/// # Parameters
334///
335/// - `dir` - Target storage directory path.
336/// - `filename` - Final dataset filename (stored as `dir/filename`).
337///   Please give the filename with the extension (e.g., `"iris.csv"`).
338/// - `dataset_name` - Dataset name for error messages (e.g., `"iris"`).
339/// - `expected_sha256` - Optional expected SHA256 hash of the dataset file. If `None`,
340///   any existing file at the destination is accepted without validation, and newly
341///   prepared files skip SHA256 verification.
342/// - `prepare_file` - Closure that prepares the dataset file in the temporary directory.
343///   - Input: `temp_dir: &Path` - Path to the temporary directory.
344///     It is recommended to execute file operations within this directory, as it will be
345///     cleaned up automatically when the closure returns. But it is not required.
346///     (Please note that the file will be moved to the final destination, not copied.)
347///   - Output: `Result<PathBuf, DatasetError>` - Path to the prepared dataset file
348///     (which will be moved to `dir/filename`).
349///   - Responsibility: This closure can perform any operations needed to prepare the
350///     dataset file, such as downloading (you can use [`download_to`] provided in this crate),
351///     extracting archives (you can use [`unzip`] provided in this crate), or locating files
352///     within extracted folders. The returned `PathBuf` must point to the final dataset file
353///     ready for validation.
354///
355/// # Returns
356///
357/// - `PathBuf` - Path to the final dataset file (`dir/filename`).
358///
359/// # Errors
360///
361/// - `DatasetError::IoError` - Returned when directory creation, file operations, or
362///   hash verification fails.
363/// - `DatasetError::Sha256ValidationFailed` - Returned when `expected_sha256` is provided
364///   and the prepared file's SHA256 hash does not match it.
365/// - Any error returned by the `prepare_file` closure.
366///
367/// # Example
368/// ```rust
369/// // Implement the file preparation process for the Iris dataset.
370///
371/// /// The URL for the Iris dataset.
372/// ///
373/// /// # Citation
374/// ///
375/// /// R. A. Fisher. "Iris," UCI Machine Learning Repository, \[Online\].
376/// /// Available: <https://doi.org/10.24432/C56C76>
377/// const IRIS_DATA_URL: &str = "https://gist.githubusercontent.com/curran/a08a1080b88344b0c8a7/raw/0e7a9b0a5d22642a06d3d5b9bcbad9890c8ee534/iris.csv";
378///
379/// /// The name of the Iris dataset file.
380/// const IRIS_FILENAME: &str = "iris.csv";
381///
382/// /// The SHA256 hash of the Iris dataset file.
383/// const IRIS_SHA256: &str = "c52742e50315a99f956a383faedf7575552675f6409ef0f9a47076dd08479930";
384///
385/// /// The name of the dataset.
386/// const IRIS_DATASET_NAME: &str = "iris";
387///
388/// use dataset_core::acquire_dataset;
389/// use dataset_core::download_to;
390///
391/// fn main() {
392///     let dir = "./somewhere";
393///
394///     let file_path = acquire_dataset(
395///         // Target storage directory path
396///         dir,
397///         // Final dataset filename (will be stored as `dir/filename`)
398///         IRIS_FILENAME,
399///         // Dataset name for error messages
400///         IRIS_DATASET_NAME,
401///         // Expected SHA256 hash of the dataset file
402///         Some(IRIS_SHA256),
403///         // Closure that prepares the dataset file in the temporary directory
404///         |temp_path| {
405///             // Download the dataset into the temporary directory
406///             download_to(IRIS_DATA_URL, temp_path, None)?;
407///             Ok(temp_path.join(IRIS_FILENAME))
408///         },
409///     ).unwrap();
410///
411///     // `file_path` is now the path to the acquired Iris dataset file.
412///     // It can be used to locate or parse the dataset.
413///
414///     // cleanup (dispensable)
415///     std::fs::remove_dir_all(dir).unwrap();
416/// }
417/// ```
418pub fn acquire_dataset<F>(
419    dir: &str,
420    filename: &str,
421    dataset_name: &str,
422    expected_sha256: Option<&str>,
423    prepare_file: F,
424) -> Result<PathBuf, DatasetError>
425where
426    F: FnOnce(&Path) -> Result<PathBuf, DatasetError>,
427{
428    let dir_path = Path::new(dir);
429    let dst = dir_path.join(filename);
430    let (need_acquire, need_overwrite) = evaluate_storage(dir_path, &dst, expected_sha256)?;
431
432    if need_acquire {
433        let temp_dir = create_temp_dir(dir_path)?;
434        let temp_path = temp_dir.path();
435
436        // Call user closure: prepare the dataset file in temporary directory
437        let src = prepare_file(temp_path)?;
438
439        // Validate SHA256 hash if provided
440        if let Some(hash) = expected_sha256 {
441            if !file_sha256_matches(&src, hash)? {
442                drop(temp_dir); // Clean up temporary directory
443                return Err(DatasetError::sha256_validation_failed(
444                    dataset_name,
445                    filename,
446                ));
447            }
448        }
449
450        // Move file to final destination
451        if need_overwrite {
452            std::fs::remove_file(&dst)?;
453        }
454        std::fs::rename(&src, &dst)?;
455    }
456
457    Ok(dst)
458}