dataset_core/utils.rs
1use crate::DatasetError;
2use sha2::{Digest, Sha256};
3use std::fs::File;
4use std::io;
5use std::io::Read;
6use std::path::{Path, PathBuf};
7use zip::ZipArchive;
8use zip::result::ZipError;
9
10/// Download a remote file into the given directory.
11///
12/// It downloads the content at `url` (using [`ureq`] crate) into `storage_path` using the file name
13/// extracted from the last segment of the URL, unless a custom filename is provided.
14///
15/// # Parameters
16///
17/// - `url` - The URL to download.
18/// - `storage_path` - The directory to store the downloaded file in.
19/// - `filename` - Optional custom filename (with extension). If `None`, the filename is extracted
20/// from the last segment of the URL.
21///
22/// # Errors
23///
24/// - `DatasetError` - Returned when the download fails or URL is invalid.
25///
26/// # Example
27/// ```rust
28/// use dataset_core::download_to;
29/// use std::path::Path;
30///
31/// let download_dir = "./download_example";
32/// std::fs::create_dir_all(download_dir).unwrap();
33///
34/// // Download a file from the internet
35/// let url = "https://gist.githubusercontent.com/curran/a08a1080b88344b0c8a7/raw/0e7a9b0a5d22642a06d3d5b9bcbad9890c8ee534/iris.csv";
36///
37/// // Use filename from URL
38/// download_to(url, Path::new(download_dir), None).unwrap();
39/// assert!(Path::new(download_dir).join("iris.csv").exists());
40///
41/// // Use custom filename
42/// download_to(url, Path::new(download_dir), Some("custom.csv")).unwrap();
43/// assert!(Path::new(download_dir).join("custom.csv").exists());
44///
45/// // Clean up (dispensable)
46/// std::fs::remove_dir_all(download_dir).unwrap();
47/// ```
48pub fn download_to(
49 url: &str,
50 storage_path: &Path,
51 filename: Option<&str>,
52) -> Result<(), DatasetError> {
53 // Get the filename: use provided name, or fall back to URL extraction
54 let filename = filename.or_else(|| url.split('/').last()).ok_or_else(|| {
55 DatasetError::ValidationError("Invalid URL: cannot extract filename from URL".to_string())
56 })?;
57
58 let save_path = storage_path.join(filename);
59
60 let mut response = ureq::get(url).call()?;
61 let mut body = response.body_mut().as_reader();
62
63 // create local file and write body to it
64 let mut file = File::create(save_path)?;
65 io::copy(&mut body, &mut file)?;
66
67 Ok(())
68}
69
70/// Extract a zip archive into a target directory using [`ZipArchive`] in [`zip`] crate.
71///
72/// # Parameters
73///
74/// - `file_path` - Path to the `.zip` file to extract.
75/// - `extract_dir` - Directory to extract the archive contents into.
76///
77/// # Errors
78///
79/// - `DatasetError` - Returned when opening the zip file fails or when extraction fails.
80///
81/// # Example
82/// ```rust
83/// use dataset_core::{download_to, unzip};
84/// use std::path::Path;
85///
86/// let work_dir = "./unzip_example";
87/// std::fs::create_dir_all(work_dir).unwrap();
88///
89/// // First download a file
90/// let url = "https://gist.githubusercontent.com/curran/a08a1080b88344b0c8a7/raw/0e7a9b0a5d22642a06d3d5b9bcbad9890c8ee534/iris.csv";
91/// download_to(url, Path::new(work_dir), None).unwrap();
92///
93/// // The file is already a CSV (no extraction needed in this example)
94/// assert!(Path::new(work_dir).join("iris.csv").exists());
95///
96/// // Clean up (dispensable)
97/// std::fs::remove_dir_all(work_dir).unwrap();
98/// ```
99pub fn unzip(file_path: &Path, extract_dir: &Path) -> Result<(), DatasetError> {
100 let file = File::open(file_path).map_err(|e| DatasetError::from(ZipError::Io(e)))?;
101
102 ZipArchive::new(file)?.extract(extract_dir)?;
103
104 Ok(())
105}
106
107/// Create a temporary directory under the given parent directory.
108///
109/// This is a small wrapper around [`tempfile::Builder`] used by dataset loaders to
110/// keep intermediate download/extraction artifacts isolated. The created directory
111/// is removed automatically when the returned [`tempfile::TempDir`] is dropped.
112///
113/// # Parameters
114///
115/// - `tempdir_in` - The parent directory in which the temporary directory will be created.
116///
117/// # Errors
118///
119/// - `DatasetError` - Returned if the temporary directory cannot be created.
120///
121/// # Example
122/// ```rust
123/// use dataset_core::create_temp_dir;
124/// use std::path::Path;
125///
126/// let parent_dir = "./temp_dir_example";
127/// std::fs::create_dir_all(parent_dir).unwrap();
128///
129/// // Create a temporary directory
130/// let temp_dir = create_temp_dir(Path::new(parent_dir)).unwrap();
131/// let temp_path = temp_dir.path();
132///
133/// // Use the temporary directory for intermediate operations
134/// let temp_file = temp_path.join("temp_file.txt");
135/// std::fs::write(&temp_file, "temporary content").unwrap();
136/// assert!(temp_file.exists());
137///
138/// // The temporary directory is automatically removed when `temp_dir` is dropped
139/// drop(temp_dir);
140///
141/// // Clean up parent directory
142/// std::fs::remove_dir_all(parent_dir).unwrap();
143/// ```
144pub fn create_temp_dir(tempdir_in: &Path) -> Result<tempfile::TempDir, DatasetError> {
145 let temp_dir = tempfile::Builder::new().tempdir_in(tempdir_in)?;
146
147 Ok(temp_dir)
148}
149
150/// Verify that a file's SHA256 hash matches an expected value.
151///
152/// This function computes the SHA256 hash of the file at the given path and compares
153/// it with the expected hexadecimal hash string (case-insensitive). It is used by
154/// dataset loaders to validate downloaded files before parsing.
155///
156/// # Parameters
157///
158/// - `path` - Path to the file to verify.
159/// - `expected_hex` - Expected SHA256 hash as a hexadecimal string.
160///
161/// # Returns
162///
163/// - `bool` - true if the computed hash matches the expected hash, false if the hashes don't match
164///
165/// # Errors
166///
167/// - `DatasetError::IoError` - Returned when file I/O operations fail (opening file, reading data).
168///
169/// # Example
170/// ```rust
171/// use dataset_core::file_sha256_matches;
172/// use std::path::Path;
173/// use std::io::Write;
174///
175/// let test_dir = "./sha256_example";
176/// std::fs::create_dir_all(test_dir).unwrap();
177///
178/// // Create a test file with known content
179/// let file_path = Path::new(test_dir).join("test.txt");
180/// let mut file = std::fs::File::create(&file_path).unwrap();
181/// file.write_all(b"hello world").unwrap();
182/// drop(file);
183///
184/// // SHA256 of "hello world" is:
185/// // b94d27b9934d3e08a52e52d7da7dabfac484efe37a5380ee9088f7ace2efcde9
186/// let expected_hash = "b94d27b9934d3e08a52e52d7da7dabfac484efe37a5380ee9088f7ace2efcde9";
187///
188/// // Verify the hash matches
189/// assert!(file_sha256_matches(&file_path, expected_hash).unwrap());
190///
191/// // Case-insensitive comparison also works
192/// let upper_hash = "B94D27B9934D3E08A52E52D7DA7DABFAC484EFE37A5380EE9088F7ACE2EFCDE9";
193/// assert!(file_sha256_matches(&file_path, upper_hash).unwrap());
194///
195/// // Wrong hash returns false
196/// assert!(!file_sha256_matches(&file_path, "0000000000000000000000000000000000000000000000000000000000000000").unwrap());
197///
198/// // Clean up (dispensable)
199/// std::fs::remove_dir_all(test_dir).unwrap();
200/// ```
201pub fn file_sha256_matches(path: &Path, expected_hex: &str) -> Result<bool, DatasetError> {
202 let mut file = File::open(path)?;
203
204 let mut hasher = Sha256::new();
205 let mut buf = [0u8; 8192];
206
207 loop {
208 let read = file.read(&mut buf)?;
209 if read == 0 {
210 break;
211 }
212 hasher.update(&buf[..read]);
213 }
214
215 let digest = hasher.finalize();
216 let actual_hex = digest
217 .iter()
218 .map(|b| format!("{:02x}", b))
219 .collect::<String>();
220 Ok(actual_hex.eq_ignore_ascii_case(expected_hex))
221}
222
223/// Evaluate the storage state for a dataset file.
224///
225/// This helper ensures the target directory exists and checks whether the destination
226/// file is already present and valid. If `expected_sha256` is `None`, any existing
227/// file at `dst` is accepted without validation.
228///
229/// # Parameters
230///
231/// - `path` - Directory path where the dataset will be stored.
232/// - `dst` - Destination file path for the dataset.
233/// - `expected_sha256` - Optional expected SHA256 hash for the dataset file. If `None`,
234/// any existing file at `dst` is accepted without validation.
235///
236/// # Returns
237///
238/// - `(need_acquire, need_overwrite)` - Flags indicating whether a new file needs to be
239/// prepared and whether an existing file should be overwritten.
240///
241/// # Errors
242///
243/// - `DatasetError::IoError` - Returned when creating the directory fails or when
244/// file I/O operations fail during hash verification.
245///
246/// # Example
247/// ```rust
248/// use dataset_core::utils::evaluate_storage;
249/// use std::path::Path;
250/// use std::io::Write;
251///
252/// let test_dir = "./evaluate_storage_example";
253/// let dir_path = Path::new(test_dir);
254/// let file_path = dir_path.join("data.txt");
255///
256/// // Case 1: Directory doesn't exist yet
257/// let (need_acquire, need_overwrite) = evaluate_storage(
258/// dir_path,
259/// &file_path,
260/// None,
261/// ).unwrap();
262/// assert!(need_acquire); // File doesn't exist, a new file must be prepared
263/// assert!(!need_overwrite); // Nothing to overwrite
264///
265/// // Case 2: File exists with correct hash
266/// let mut file = std::fs::File::create(&file_path).unwrap();
267/// file.write_all(b"hello world").unwrap();
268/// drop(file);
269///
270/// let correct_hash = "b94d27b9934d3e08a52e52d7da7dabfac484efe37a5380ee9088f7ace2efcde9";
271/// let (need_acquire, need_overwrite) = evaluate_storage(
272/// dir_path,
273/// &file_path,
274/// Some(correct_hash),
275/// ).unwrap();
276/// assert!(!need_acquire); // File exists with correct hash
277/// assert!(!need_overwrite);
278///
279/// // Case 3: File exists but hash doesn't match
280/// let wrong_hash = "0000000000000000000000000000000000000000000000000000000000000000";
281/// let (need_acquire, need_overwrite) = evaluate_storage(
282/// dir_path,
283/// &file_path,
284/// Some(wrong_hash),
285/// ).unwrap();
286/// assert!(need_acquire); // Hash mismatch, a new file must be prepared
287/// assert!(need_overwrite); // Existing file needs to be replaced
288///
289/// // Clean up (dispensable)
290/// std::fs::remove_dir_all(test_dir).unwrap();
291/// ```
292pub fn evaluate_storage(
293 path: &Path,
294 dst: &Path,
295 expected_sha256: Option<&str>,
296) -> Result<(bool, bool), DatasetError> {
297 let mut need_acquire = true;
298 let mut need_overwrite = false;
299
300 if !path.exists() {
301 std::fs::create_dir_all(path)?;
302 }
303
304 if dst.exists() {
305 if let Some(hash) = expected_sha256 {
306 // SHA256 validation enabled
307 if file_sha256_matches(dst, hash)? {
308 need_acquire = false;
309 } else {
310 need_overwrite = true;
311 }
312 } else {
313 // No SHA256 validation: accept existing file
314 need_acquire = false;
315 }
316 }
317
318 Ok((need_acquire, need_overwrite))
319}
320
321/// Acquire a dataset file using a caller-provided preparation closure.
322///
323/// This function orchestrates the dataset acquisition workflow: it checks whether
324/// the destination file can be reused, creates a temporary directory when a new
325/// file is needed, delegates file preparation to a user-provided closure,
326/// optionally validates the prepared file with SHA256, and moves it to the final
327/// destination.
328///
329/// The function itself does not perform network I/O. The `prepare_file` closure
330/// is responsible for preparing the dataset file, which may include downloading,
331/// extracting archives, or locating files within an extracted directory.
332///
333/// # Parameters
334///
335/// - `dir` - Target storage directory path.
336/// - `filename` - Final dataset filename (stored as `dir/filename`).
337/// Please give the filename with the extension (e.g., `"iris.csv"`).
338/// - `dataset_name` - Dataset name for error messages (e.g., `"iris"`).
339/// - `expected_sha256` - Optional expected SHA256 hash of the dataset file. If `None`,
340/// any existing file at the destination is accepted without validation, and newly
341/// prepared files skip SHA256 verification.
342/// - `prepare_file` - Closure that prepares the dataset file in the temporary directory.
343/// - Input: `temp_dir: &Path` - Path to the temporary directory.
344/// It is recommended to execute file operations within this directory, as it will be
345/// cleaned up automatically when the closure returns. But it is not required.
346/// (Please note that the file will be moved to the final destination, not copied.)
347/// - Output: `Result<PathBuf, DatasetError>` - Path to the prepared dataset file
348/// (which will be moved to `dir/filename`).
349/// - Responsibility: This closure can perform any operations needed to prepare the
350/// dataset file, such as downloading (you can use [`download_to`] provided in this crate),
351/// extracting archives (you can use [`unzip`] provided in this crate), or locating files
352/// within extracted folders. The returned `PathBuf` must point to the final dataset file
353/// ready for validation.
354///
355/// # Returns
356///
357/// - `PathBuf` - Path to the final dataset file (`dir/filename`).
358///
359/// # Errors
360///
361/// - `DatasetError::IoError` - Returned when directory creation, file operations, or
362/// hash verification fails.
363/// - `DatasetError::Sha256ValidationFailed` - Returned when `expected_sha256` is provided
364/// and the prepared file's SHA256 hash does not match it.
365/// - Any error returned by the `prepare_file` closure.
366///
367/// # Example
368/// ```rust
369/// // Implement the file preparation process for the Iris dataset.
370///
371/// /// The URL for the Iris dataset.
372/// ///
373/// /// # Citation
374/// ///
375/// /// R. A. Fisher. "Iris," UCI Machine Learning Repository, \[Online\].
376/// /// Available: <https://doi.org/10.24432/C56C76>
377/// const IRIS_DATA_URL: &str = "https://gist.githubusercontent.com/curran/a08a1080b88344b0c8a7/raw/0e7a9b0a5d22642a06d3d5b9bcbad9890c8ee534/iris.csv";
378///
379/// /// The name of the Iris dataset file.
380/// const IRIS_FILENAME: &str = "iris.csv";
381///
382/// /// The SHA256 hash of the Iris dataset file.
383/// const IRIS_SHA256: &str = "c52742e50315a99f956a383faedf7575552675f6409ef0f9a47076dd08479930";
384///
385/// /// The name of the dataset.
386/// const IRIS_DATASET_NAME: &str = "iris";
387///
388/// use dataset_core::acquire_dataset;
389/// use dataset_core::download_to;
390///
391/// fn main() {
392/// let dir = "./somewhere";
393///
394/// let file_path = acquire_dataset(
395/// // Target storage directory path
396/// dir,
397/// // Final dataset filename (will be stored as `dir/filename`)
398/// IRIS_FILENAME,
399/// // Dataset name for error messages
400/// IRIS_DATASET_NAME,
401/// // Expected SHA256 hash of the dataset file
402/// Some(IRIS_SHA256),
403/// // Closure that prepares the dataset file in the temporary directory
404/// |temp_path| {
405/// // Download the dataset into the temporary directory
406/// download_to(IRIS_DATA_URL, temp_path, None)?;
407/// Ok(temp_path.join(IRIS_FILENAME))
408/// },
409/// ).unwrap();
410///
411/// // `file_path` is now the path to the acquired Iris dataset file.
412/// // It can be used to locate or parse the dataset.
413///
414/// // cleanup (dispensable)
415/// std::fs::remove_dir_all(dir).unwrap();
416/// }
417/// ```
418pub fn acquire_dataset<F>(
419 dir: &str,
420 filename: &str,
421 dataset_name: &str,
422 expected_sha256: Option<&str>,
423 prepare_file: F,
424) -> Result<PathBuf, DatasetError>
425where
426 F: FnOnce(&Path) -> Result<PathBuf, DatasetError>,
427{
428 let dir_path = Path::new(dir);
429 let dst = dir_path.join(filename);
430 let (need_acquire, need_overwrite) = evaluate_storage(dir_path, &dst, expected_sha256)?;
431
432 if need_acquire {
433 let temp_dir = create_temp_dir(dir_path)?;
434 let temp_path = temp_dir.path();
435
436 // Call user closure: prepare the dataset file in temporary directory
437 let src = prepare_file(temp_path)?;
438
439 // Validate SHA256 hash if provided
440 if let Some(hash) = expected_sha256 {
441 if !file_sha256_matches(&src, hash)? {
442 drop(temp_dir); // Clean up temporary directory
443 return Err(DatasetError::sha256_validation_failed(
444 dataset_name,
445 filename,
446 ));
447 }
448 }
449
450 // Move file to final destination
451 if need_overwrite {
452 std::fs::remove_file(&dst)?;
453 }
454 std::fs::rename(&src, &dst)?;
455 }
456
457 Ok(dst)
458}