use std::path::PathBuf;
/// Returns the directory under which downloaded datasets are cached.
///
/// The result is `<HOME>/.oxits/datasets`, where `<HOME>` is the value of the
/// `HOME` environment variable. When `HOME` cannot be read as a Unicode string
/// (unset or not valid Unicode), the current directory `.` is used as the base.
///
/// This function only builds the path; it does not create the directory.
pub fn cache_dir() -> PathBuf {
    let mut dir = match std::env::var("HOME") {
        Ok(home) => PathBuf::from(home),
        Err(_) => PathBuf::from("."),
    };
    dir.push(".oxits");
    dir.push("datasets");
    dir
}
/// Returns the names of the datasets this module knows about.
///
/// The names are returned in a fixed order as a freshly allocated vector of
/// `'static` string slices.
pub fn available_datasets() -> Vec<&'static str> {
    // Kept as a constant table so the list itself lives in static memory;
    // only the returned vector is allocated per call.
    const DATASET_NAMES: &[&str] = &[
        "Adiac",
        "ArrowHead",
        "Beef",
        "BeetleFly",
        "BirdChicken",
        "Car",
        "CBF",
        "ChlorineConcentration",
        "CinCECGTorso",
        "Coffee",
        "Computers",
        "CricketX",
        "CricketY",
        "CricketZ",
        "DiatomSizeReduction",
        "DistalPhalanxOutlineCorrect",
        "ECG200",
        "ECG5000",
        "ECGFiveDays",
        "FaceAll",
        "FaceFour",
        "FacesUCR",
        "FiftyWords",
        "Fish",
        "FordA",
        "FordB",
        "GunPoint",
        "Ham",
        "Haptics",
        "Herring",
        "ItalyPowerDemand",
        "Lightning2",
        "Lightning7",
        "Mallat",
        "Meat",
        "MedicalImages",
        "MiddlePhalanxOutlineCorrect",
        "MoteStrain",
        "OliveOil",
        "OSULeaf",
        "PhalangesOutlinesCorrect",
        "Plane",
        "ProximalPhalanxOutlineCorrect",
        "ShapeletSim",
        "SonyAIBORobotSurface1",
        "SonyAIBORobotSurface2",
        "StarLightCurves",
        "SwedishLeaf",
        "Symbols",
        "SyntheticControl",
        "ToeSegmentation1",
        "ToeSegmentation2",
        "Trace",
        "TwoLeadECG",
        "TwoPatterns",
        "UWaveGestureLibraryAll",
        "Wafer",
        "Wine",
        "WordSynonyms",
        "Worms",
        "Yoga",
    ];
    DATASET_NAMES.to_vec()
}
/// Loads the dataset `name`, using the on-disk cache when it is populated.
///
/// The cache location is `cache_dir()/<name>/`. When both `<name>_TRAIN.tsv`
/// and `<name>_TEST.tsv` exist there, they are parsed with `parse_tsv` and
/// returned as `(x_train, x_test, y_train, y_test)`.
///
/// Otherwise `<name>.zip` is downloaded and stored in the cache directory.
/// ZIP extraction is not implemented, so this path always ends in an `Err`
/// telling the user where the archive was saved and which files to extract.
///
/// # Errors
///
/// Returns a human-readable message when:
/// - `name` is empty or contains anything other than ASCII letters, digits,
///   `_` or `-`;
/// - a cached file exists but cannot be read or parsed;
/// - the download, the cache directory creation, or writing the archive fails;
/// - the archive was downloaded successfully (extraction is manual, see above).
pub fn fetch_ucr_dataset(name: &str) -> Result<super::TrainTestSplit, String> {
    // `name` is used both as a path component and as a URL segment.
    // `Path::join` replaces the base when given an absolute path and keeps
    // `..` components, so an unchecked name could make the directory creation
    // and archive write below land outside the cache directory.
    let is_plain_name = !name.is_empty()
        && name
            .chars()
            .all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '-');
    if !is_plain_name {
        return Err(format!(
            "Invalid dataset name {name:?}: only ASCII letters, digits, '_' and '-' are allowed"
        ));
    }

    let cache = cache_dir().join(name);
    let train_path = cache.join(format!("{name}_TRAIN.tsv"));
    let test_path = cache.join(format!("{name}_TEST.tsv"));

    // Fast path: both splits are already extracted in the cache.
    if train_path.exists() && test_path.exists() {
        let (x_train, y_train) = parse_tsv(&train_path)?;
        let (x_test, y_test) = parse_tsv(&test_path)?;
        return Ok((x_train, x_test, y_train, y_test));
    }

    let base_url = "https://www.timeseriesclassification.com/aeon-toolkit";
    let url = format!("{base_url}/{name}/{name}.zip");
    let response = ureq::get(&url)
        .call()
        .map_err(|e| format!("Failed to download {name}: {e}"))?;
    let mut body = Vec::new();
    response
        .into_reader()
        .read_to_end(&mut body)
        .map_err(|e| format!("Failed to read response: {e}"))?;

    // Only touch the filesystem once the download has fully succeeded.
    std::fs::create_dir_all(&cache).map_err(|e| format!("Failed to create cache dir: {e}"))?;
    let zip_path = cache.join(format!("{name}.zip"));
    std::fs::write(&zip_path, &body).map_err(|e| format!("Failed to write zip: {e}"))?;

    Err(format!(
        "Dataset {name} downloaded to {zip_path:?}. ZIP extraction not yet implemented. \
         Please extract manually and place {name}_TRAIN.tsv and {name}_TEST.tsv in {cache:?}"
    ))
}
/// Splits one record into its fields.
///
/// Tabs take precedence: if `line` contains at least one tab it is split on
/// tabs only (commas inside fields are left alone). Otherwise it is split on
/// commas. A line with neither delimiter comes back as a single field.
fn split_delimited(line: &str) -> Vec<&str> {
    let delimiter = if line.contains('\t') { '\t' } else { ',' };
    line.split(delimiter).collect()
}
/// Parses a delimited text file of labelled series.
///
/// Every non-blank line is split with `split_delimited` (tab-delimited, or
/// comma-delimited when the line has no tab). The first field is the label and
/// the remaining fields are the `f64` values of the series. Lines that yield
/// fewer than two fields are skipped.
///
/// Fields are trimmed before parsing, so padded records such as
/// `"X, 1.0, 2.0"` are read correctly. Empty fields (for example after a
/// trailing delimiter) are skipped.
///
/// Returns `(series, labels)`, with one entry per parsed line in file order.
///
/// # Errors
///
/// Returns a message when the file cannot be read, or when a non-empty value
/// field is not a valid `f64`. Silently dropping such a field would shorten
/// the series and shift every later value, so it is reported with its line
/// number instead.
fn parse_tsv(path: &std::path::Path) -> Result<(Vec<Vec<f64>>, Vec<String>), String> {
    let content =
        std::fs::read_to_string(path).map_err(|e| format!("Failed to read {path:?}: {e}"))?;
    let mut x = Vec::new();
    let mut y = Vec::new();
    for (index, raw_line) in content.lines().enumerate() {
        let line = raw_line.trim();
        if line.is_empty() {
            continue;
        }
        let parts = split_delimited(line);
        if parts.len() < 2 {
            continue;
        }

        let mut ts = Vec::with_capacity(parts.len() - 1);
        for field in &parts[1..] {
            // `f64::from_str` rejects surrounding whitespace, so trim first.
            let field = field.trim();
            if field.is_empty() {
                continue;
            }
            let value: f64 = field.parse().map_err(|e| {
                let line_no = index + 1;
                format!("Failed to parse value {field:?} on line {line_no} of {path:?}: {e}")
            })?;
            ts.push(value);
        }

        // Push label and series together, only after the whole line parsed.
        y.push(parts[0].trim().to_string());
        x.push(ts);
    }
    Ok((x, y))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Creates `<system temp dir>/<dir_name>/<file_name>` containing `contents`.
    ///
    /// Returns the directory (so the caller can delete it afterwards) followed
    /// by the path of the file that was written.
    fn write_fixture(
        dir_name: &str,
        file_name: &str,
        contents: &str,
    ) -> (std::path::PathBuf, std::path::PathBuf) {
        let dir = std::env::temp_dir().join(dir_name);
        std::fs::create_dir_all(&dir).unwrap();
        let file = dir.join(file_name);
        std::fs::write(&file, contents).unwrap();
        (dir, file)
    }

    /// The dataset list contains two well-known names and more than 50 entries.
    #[test]
    fn test_available_datasets() {
        let names = available_datasets();
        assert!(names.iter().any(|&name| name == "GunPoint"));
        assert!(names.iter().any(|&name| name == "Coffee"));
        assert!(names.len() > 50);
    }

    /// The cache path mentions the `.oxits` directory.
    #[test]
    fn test_cache_dir() {
        let rendered = cache_dir().to_string_lossy().into_owned();
        assert!(rendered.contains(".oxits"));
    }

    /// Tab-delimited rows give one label and one series per line.
    #[test]
    fn test_parse_tsv_tab_separated() {
        let (dir, file) = write_fixture(
            "oxits_test_tsv",
            "test.tsv",
            "A\t1.0\t2.0\t3.0\nB\t4.0\t5.0\t6.0\n",
        );
        let (series, labels) = parse_tsv(&file).unwrap();
        assert_eq!(labels, vec!["A", "B"]);
        assert_eq!(series[0], vec![1.0, 2.0, 3.0]);
        assert_eq!(series[1], vec![4.0, 5.0, 6.0]);
        std::fs::remove_dir_all(&dir).ok();
    }

    /// Rows without tabs are split on commas.
    #[test]
    fn test_parse_tsv_comma_separated() {
        let (dir, file) = write_fixture("oxits_test_csv", "test.csv", "X,10.0,20.0\nY,30.0,40.0\n");
        let (series, labels) = parse_tsv(&file).unwrap();
        assert_eq!(labels, vec!["X", "Y"]);
        assert_eq!(series[0], vec![10.0, 20.0]);
        std::fs::remove_dir_all(&dir).ok();
    }

    /// Blank lines before, between and after records are ignored.
    #[test]
    fn test_parse_tsv_empty_lines() {
        let (dir, file) = write_fixture(
            "oxits_test_empty",
            "test.tsv",
            "\nA\t1.0\t2.0\n\nB\t3.0\t4.0\n\n",
        );
        let (series, labels) = parse_tsv(&file).unwrap();
        assert_eq!(series.len(), 2);
        assert_eq!(labels.len(), 2);
        std::fs::remove_dir_all(&dir).ok();
    }

    /// Asking for a dataset that is not cached ends in an error.
    #[test]
    fn test_fetch_ucr_dataset_not_cached() {
        let outcome = fetch_ucr_dataset("NonExistentDataset12345");
        assert!(outcome.is_err());
    }
}