oxits 0.1.0

Time series classification and transformation library for Rust
Documentation
//! UCR/UEA Time Series Classification Archive dataset fetching.
//!
//! Loads datasets from a local cache and downloads missing dataset archives
//! from the UCR archive (ZIP extraction is currently manual).
//! Requires the `datasets` feature flag.
use std::path::PathBuf;

/// Get the default cache directory for downloaded datasets.
///
/// Resolves to `<home>/.oxits/datasets`. `<home>` is taken from `HOME` (Unix
/// convention) or, when that is unset or empty, from `USERPROFILE` (Windows
/// convention). When neither is available, the path falls back to
/// `./.oxits/datasets`, relative to the current working directory.
pub fn cache_dir() -> PathBuf {
    // `var_os` keeps non-UTF-8 home paths usable instead of discarding them.
    // Empty values are skipped so they don't produce a bare relative path.
    let home = ["HOME", "USERPROFILE"]
        .iter()
        .copied()
        .filter_map(std::env::var_os)
        .find(|value| !value.is_empty())
        .map_or_else(|| PathBuf::from("."), PathBuf::from);
    home.join(".oxits").join("datasets")
}

/// Names of the UCR datasets known to this module.
///
/// The names are returned in a fixed order as a freshly allocated vector of
/// `'static` string slices.
pub fn available_datasets() -> Vec<&'static str> {
    const UCR_DATASETS: &[&str] = &[
        "Adiac",
        "ArrowHead",
        "Beef",
        "BeetleFly",
        "BirdChicken",
        "Car",
        "CBF",
        "ChlorineConcentration",
        "CinCECGTorso",
        "Coffee",
        "Computers",
        "CricketX",
        "CricketY",
        "CricketZ",
        "DiatomSizeReduction",
        "DistalPhalanxOutlineCorrect",
        "ECG200",
        "ECG5000",
        "ECGFiveDays",
        "FaceAll",
        "FaceFour",
        "FacesUCR",
        "FiftyWords",
        "Fish",
        "FordA",
        "FordB",
        "GunPoint",
        "Ham",
        "Haptics",
        "Herring",
        "ItalyPowerDemand",
        "Lightning2",
        "Lightning7",
        "Mallat",
        "Meat",
        "MedicalImages",
        "MiddlePhalanxOutlineCorrect",
        "MoteStrain",
        "OliveOil",
        "OSULeaf",
        "PhalangesOutlinesCorrect",
        "Plane",
        "ProximalPhalanxOutlineCorrect",
        "ShapeletSim",
        "SonyAIBORobotSurface1",
        "SonyAIBORobotSurface2",
        "StarLightCurves",
        "SwedishLeaf",
        "Symbols",
        "SyntheticControl",
        "ToeSegmentation1",
        "ToeSegmentation2",
        "Trace",
        "TwoLeadECG",
        "TwoPatterns",
        "UWaveGestureLibraryAll",
        "Wafer",
        "Wine",
        "WordSynonyms",
        "Worms",
        "Yoga",
    ];

    UCR_DATASETS.to_vec()
}

/// Fetch a UCR dataset by name.
///
/// If `<cache_dir>/<name>/<name>_TRAIN.tsv` and `<name>_TEST.tsv` both exist,
/// they are parsed and returned as `(x_train, x_test, y_train, y_test)`.
///
/// Otherwise the dataset archive is downloaded from timeseriesclassification.com
/// and saved as `<cache_dir>/<name>/<name>.zip`. ZIP extraction is not
/// implemented yet, so this path always ends in an `Err`. The error tells the
/// caller where to place the extracted TSV files.
///
/// # Errors
///
/// Returns an error in any of these cases:
///
/// - `name` is empty or contains characters other than ASCII letters, digits,
///   `_` and `-`.
/// - A cached TSV file cannot be read or parsed.
/// - The download, cache-directory creation or ZIP write fails.
/// - The dataset was not already cached (currently this always produces an
///   error).
pub fn fetch_ucr_dataset(name: &str) -> Result<super::TrainTestSplit, String> {
    // `name` is interpolated into cache paths (which are later written to) and
    // into the download URL. Reject anything that could escape the cache
    // directory (`..`, `/`, `\`) or alter the URL before either is built.
    let is_plain_name = !name.is_empty()
        && name
            .chars()
            .all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '-');
    if !is_plain_name {
        return Err(format!(
            "Invalid dataset name {name:?}: expected only ASCII letters, digits, '_' or '-'"
        ));
    }

    let cache = cache_dir().join(name);

    // Serve from the cache when both splits have already been extracted.
    let train_path = cache.join(format!("{name}_TRAIN.tsv"));
    let test_path = cache.join(format!("{name}_TEST.tsv"));

    if train_path.exists() && test_path.exists() {
        let (x_train, y_train) = parse_tsv(&train_path)?;
        let (x_test, y_test) = parse_tsv(&test_path)?;
        return Ok((x_train, x_test, y_train, y_test));
    }

    // Download the archive.
    // NOTE(review): confirm the `{name}/{name}.zip` path layout against the server.
    let base_url = "https://www.timeseriesclassification.com/aeon-toolkit";
    let url = format!("{base_url}/{name}/{name}.zip");

    let response = ureq::get(&url)
        .call()
        .map_err(|e| format!("Failed to download {name}: {e}"))?;

    let mut body = Vec::new();
    response
        .into_reader()
        .read_to_end(&mut body)
        .map_err(|e| format!("Failed to read response: {e}"))?;

    // Create cache directory
    std::fs::create_dir_all(&cache).map_err(|e| format!("Failed to create cache dir: {e}"))?;

    // Only the raw ZIP is stored for now; unzipping into the TSV files checked
    // above is still a manual step.
    let zip_path = cache.join(format!("{name}.zip"));
    std::fs::write(&zip_path, &body).map_err(|e| format!("Failed to write zip: {e}"))?;

    Err(format!(
        "Dataset {name} downloaded to {zip_path:?}. ZIP extraction not yet implemented. \
         Please extract manually and place {name}_TRAIN.tsv and {name}_TEST.tsv in {cache:?}"
    ))
}

/// Split a line into fields.
///
/// Tab is the primary delimiter, matching the UCR archive's `.tsv` format. A
/// line that contains no tab is split on commas instead, so CSV exports are
/// accepted too. A line with neither delimiter comes back as a single field.
fn split_delimited(line: &str) -> Vec<&str> {
    let delimiter = if line.contains('\t') { '\t' } else { ',' };
    line.split(delimiter).collect()
}

/// Parse a UCR-style delimited file whose first column is the class label.
///
/// Every non-blank line with at least two fields becomes one sample. The first
/// field (trimmed) is the label and the remaining fields are the series values.
///
/// - Value fields are trimmed before parsing, so `"A, 1.0, 2.0"` is accepted.
/// - Empty fields, e.g. from a trailing delimiter, are ignored.
/// - `NaN` is accepted as a missing value.
/// - A leading UTF-8 byte-order mark is ignored.
///
/// Returns `(x, y)`, where `x[k]` holds the values and `y[k]` the label of the
/// k-th parsed sample.
///
/// # Errors
///
/// Returns an error if the file cannot be read as UTF-8 text. It also returns an
/// error if a non-empty value field is not a valid `f64`. Dropping such a field
/// instead would shift every later value of that series out of position.
fn parse_tsv(path: &std::path::Path) -> Result<(Vec<Vec<f64>>, Vec<String>), String> {
    let text =
        std::fs::read_to_string(path).map_err(|e| format!("Failed to read {path:?}: {e}"))?;
    // `trim` does not strip U+FEFF, so a BOM would otherwise end up in the first label.
    let text = text.strip_prefix('\u{feff}').unwrap_or(text.as_str());

    let mut x = Vec::new();
    let mut y = Vec::new();

    for (line_idx, raw_line) in text.lines().enumerate() {
        let line = raw_line.trim();
        if line.is_empty() {
            continue;
        }

        let parts = split_delimited(line);
        if parts.len() < 2 {
            continue;
        }

        let series = parts[1..]
            .iter()
            .map(|field| field.trim())
            .filter(|field| !field.is_empty())
            .map(|field| {
                field.parse::<f64>().map_err(|e| {
                    format!(
                        "Failed to parse value {field:?} on line {} of {path:?}: {e}",
                        line_idx + 1
                    )
                })
            })
            .collect::<Result<Vec<f64>, String>>()?;

        y.push(parts[0].trim().to_string());
        x.push(series);
    }

    Ok((x, y))
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Create a scratch directory under the system temp dir for one test.
    ///
    /// The process id is part of the name so concurrent `cargo test` runs do not
    /// overwrite or delete each other's fixtures.
    fn scratch_dir(tag: &str) -> std::path::PathBuf {
        let dir = std::env::temp_dir().join(format!("oxits_test_{tag}_{}", std::process::id()));
        std::fs::create_dir_all(&dir).unwrap();
        dir
    }

    #[test]
    fn test_available_datasets() {
        let datasets = available_datasets();
        assert!(datasets.contains(&"GunPoint"));
        assert!(datasets.contains(&"Coffee"));
        assert!(datasets.len() > 50);
    }

    #[test]
    fn test_cache_dir() {
        let dir = cache_dir();
        assert!(dir.to_string_lossy().contains(".oxits"));
    }

    #[test]
    fn test_parse_tsv_tab_separated() {
        let dir = scratch_dir("tsv");
        let path = dir.join("test.tsv");
        std::fs::write(&path, "A\t1.0\t2.0\t3.0\nB\t4.0\t5.0\t6.0\n").unwrap();
        let (x, y) = parse_tsv(&path).unwrap();
        assert_eq!(y, vec!["A", "B"]);
        assert_eq!(x[0], vec![1.0, 2.0, 3.0]);
        assert_eq!(x[1], vec![4.0, 5.0, 6.0]);
        std::fs::remove_dir_all(&dir).ok();
    }

    #[test]
    fn test_parse_tsv_comma_separated() {
        let dir = scratch_dir("csv");
        let path = dir.join("test.csv");
        std::fs::write(&path, "X,10.0,20.0\nY,30.0,40.0\n").unwrap();
        let (x, y) = parse_tsv(&path).unwrap();
        assert_eq!(y, vec!["X", "Y"]);
        assert_eq!(x[0], vec![10.0, 20.0]);
        std::fs::remove_dir_all(&dir).ok();
    }

    #[test]
    fn test_parse_tsv_empty_lines() {
        let dir = scratch_dir("empty");
        let path = dir.join("test.tsv");
        std::fs::write(&path, "\nA\t1.0\t2.0\n\nB\t3.0\t4.0\n\n").unwrap();
        let (x, y) = parse_tsv(&path).unwrap();
        assert_eq!(x.len(), 2);
        assert_eq!(y.len(), 2);
        std::fs::remove_dir_all(&dir).ok();
    }

    /// Exercise the error path for a dataset that is not in the local cache.
    ///
    /// `fetch_ucr_dataset` issues a real HTTP request for uncached names. If that
    /// request succeeds, it also writes the response under `cache_dir()`. The test
    /// is therefore opt-in: run it with `cargo test -- --ignored`.
    #[test]
    #[ignore = "performs a real network request and may write under ~/.oxits"]
    fn test_fetch_ucr_dataset_not_cached() {
        let result = fetch_ucr_dataset("NonExistentDataset12345");
        assert!(result.is_err());
    }
}