use std::{
fs,
io::Write,
path::{Path, PathBuf},
};
use indicatif::{ProgressBar, ProgressStyle};
use sha2::{Digest, Sha256};
use tracing::{info, warn};
use crate::error::{Result, SensorLMError};
/// Static description of a downloadable dataset archive.
#[derive(Debug, Clone)]
pub struct DatasetEntry {
    /// Dataset name; `find_dataset` matches it ignoring ASCII case.
    pub name: &'static str,
    /// URL the archive is downloaded from.
    pub url: &'static str,
    /// Expected SHA-256 digest as a hex string.
    ///
    /// May be empty; `download_file` skips checksum verification when it is
    /// given an empty expected digest.
    pub sha256: &'static str,
    /// Archive size in bytes (the listed values look approximate — TODO confirm).
    pub size_bytes: u64,
}
/// Datasets this crate knows how to locate by name.
///
/// NOTE(review): both entries carry an empty `sha256`. If these digests are
/// what gets passed to `download_file`, the archives are downloaded without
/// any integrity check — fill in the real digests once they are known.
pub const KNOWN_DATASETS: &[DatasetEntry] = &[
    DatasetEntry {
        name: "PAMAP2",
        url: "https://archive.ics.uci.edu/static/public/231/pamap2+physical+activity+monitoring.zip",
        sha256: "",
        size_bytes: 680_000_000,
    },
    DatasetEntry {
        name: "WESAD",
        url: "https://uni-siegen.sciebo.de/s/HGdUkoNlW1Ub0Gx/download",
        sha256: "",
        size_bytes: 1_800_000_000,
    },
];
/// Looks up an entry of [`KNOWN_DATASETS`] by name, ignoring ASCII case.
///
/// Returns the first matching entry, or `None` when no dataset has that name.
pub fn find_dataset(name: &str) -> Option<&'static DatasetEntry> {
    // `eq_ignore_ascii_case` compares in place, so no lowercase copies of either
    // string are allocated for every entry inspected.
    KNOWN_DATASETS
        .iter()
        .find(|entry| entry.name.eq_ignore_ascii_case(name))
}
pub fn download_file(url: &str, dest_path: &Path, expected_sha256: &str) -> Result<()> {
if dest_path.exists() && !expected_sha256.is_empty() {
let existing_hash = sha256_of_file(dest_path)?;
if existing_hash.eq_ignore_ascii_case(expected_sha256) {
info!("✓ {} already downloaded and verified.", dest_path.display());
return Ok(());
} else {
warn!(
"Checksum mismatch for {}: expected {} got {}. Re-downloading.",
dest_path.display(),
expected_sha256,
existing_hash
);
}
}
info!("Downloading {} → {}", url, dest_path.display());
if let Some(parent) = dest_path.parent() {
fs::create_dir_all(parent)?;
}
let client = reqwest::blocking::Client::builder()
.timeout(std::time::Duration::from_secs(3600))
.build()
.map_err(|e| SensorLMError::DownloadError { url: url.to_string(), source: e })?;
let mut response = client
.get(url)
.send()
.map_err(|e| SensorLMError::DownloadError { url: url.to_string(), source: e })?;
let total_bytes = response.content_length().unwrap_or(0);
let pb = ProgressBar::new(total_bytes);
pb.set_style(
ProgressStyle::with_template(
"{spinner:.green} [{bar:40.cyan/blue}] {bytes}/{total_bytes} ({eta})",
)
.unwrap()
.progress_chars("=>-"),
);
let mut file = fs::File::create(dest_path)?;
let mut downloaded = 0u64;
let mut buf = vec![0u8; 8192];
loop {
use std::io::Read;
let n = response
.read(&mut buf)
.map_err(|e| SensorLMError::Io(e))?;
if n == 0 {
break;
}
file.write_all(&buf[..n])?;
downloaded += n as u64;
pb.set_position(downloaded);
}
pb.finish_with_message("Download complete");
if !expected_sha256.is_empty() {
let actual_hash = sha256_of_file(dest_path)?;
if !actual_hash.eq_ignore_ascii_case(expected_sha256) {
fs::remove_file(dest_path)?;
return Err(SensorLMError::DatasetError(format!(
"SHA-256 mismatch: expected {expected_sha256}, got {actual_hash}"
)));
}
info!("✓ Checksum verified.");
}
Ok(())
}
/// Computes the SHA-256 digest of the file at `path` as a hex-encoded string.
///
/// The file is read and hashed in fixed-size chunks, so memory use stays
/// constant regardless of file size instead of loading the whole file.
///
/// # Errors
///
/// Returns an error if the file cannot be opened or read.
fn sha256_of_file(path: &Path) -> Result<String> {
    use std::io::Read;

    let mut file = fs::File::open(path)?;
    let mut hasher = Sha256::new();
    let mut chunk = vec![0u8; 64 * 1024];
    loop {
        match file.read(&mut chunk) {
            Ok(0) => break,
            Ok(n) => hasher.update(&chunk[..n]),
            // Retry transient interruptions, as `fs::read` did internally.
            Err(e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
            Err(e) => return Err(e.into()),
        }
    }
    Ok(hex::encode(hasher.finalize()))
}
/// Returns the default data directory: a `sensorlm` folder inside the
/// directory reported by `dirs::data_local_dir()`.
///
/// Falls back to `./sensorlm` when `dirs::data_local_dir()` returns `None`.
pub fn default_data_dir() -> PathBuf {
    let base = match dirs::data_local_dir() {
        Some(dir) => dir,
        None => PathBuf::from("."),
    };
    base.join("sensorlm")
}
use ndarray::Array2;
use rand::{Rng as _, SeedableRng};
use rand_distr::{Distribution, Normal};
/// Options controlling `generate_synthetic_dataset`.
#[derive(Debug, Clone)]
pub struct SyntheticDataConfig {
    /// Number of samples to generate.
    pub num_samples: usize,
    /// Seed for the random number generator; the same seed reproduces the same data.
    pub seed: u64,
    /// Whether to add a sinusoidal component to channels 0 and 3.
    pub add_circadian: bool,
    /// Whether to flag randomly chosen time steps as missing in the mask.
    pub add_missingness: bool,
    /// Fraction of the time steps drawn as missing for each channel.
    ///
    /// Only used when `add_missingness` is set; at least one time step is
    /// always drawn per channel in that case.
    pub missingness_rate: f64,
}
impl Default for SyntheticDataConfig {
fn default() -> Self {
Self {
num_samples: 1000,
seed: 42,
add_circadian: true,
add_missingness: true,
missingness_rate: 0.1,
}
}
}
/// One generated recording together with its mask and caption.
#[derive(Debug, Clone)]
pub struct SyntheticSample {
    /// Sensor values; `generate_synthetic_dataset` builds this with shape
    /// `(TIME_STEPS, NUM_CHANNELS)`.
    pub sensor: Array2<f32>,
    /// Same shape as `sensor`; the generator sets an element to `1` where a
    /// time step was drawn as missing and leaves it `0` otherwise.
    pub mask: Array2<u8>,
    /// Text caption describing the recording.
    pub caption: String,
    /// Index of the sample within the generated dataset.
    pub id: usize,
}
/// Generates `cfg.num_samples` synthetic recordings, reproducibly from `cfg.seed`.
///
/// Each sample holds a `(TIME_STEPS, NUM_CHANNELS)` sensor array and a mask of
/// the same shape:
///
/// * Every sensor value is Gaussian noise with mean 0 and standard deviation 0.3.
/// * When `cfg.add_circadian` is set, channels 0 and 3 additionally receive one
///   full sine period of amplitude 0.5 spread across the time axis.
/// * When `cfg.add_missingness` is set, `max(1, floor(TIME_STEPS * rate))` time
///   steps per channel are drawn at random and flagged with `1` in the mask.
///   Draws are made with replacement, so fewer distinct steps may end up
///   flagged. Sensor values at flagged positions are left unchanged.
///
/// # Panics
///
/// Panics if `cfg.add_missingness` is set while `TIME_STEPS` is zero, because
/// an empty range cannot be sampled.
pub fn generate_synthetic_dataset(cfg: &SyntheticDataConfig) -> Vec<SyntheticSample> {
    use crate::constants::{NUM_CHANNELS, TIME_STEPS};

    let mut rng = rand::rngs::StdRng::seed_from_u64(cfg.seed);

    // Loop-invariant pieces are built once. None of them touches the RNG, so
    // the random stream (and the output for a given seed) is unaffected.
    let noise = Normal::new(0.0f64, 0.3)
        .expect("0.3 is a finite, non-negative standard deviation");
    let circadian_wave: Vec<f64> = (0..TIME_STEPS)
        .map(|t| 0.5 * (2.0 * std::f64::consts::PI * t as f64 / TIME_STEPS as f64).sin())
        .collect();
    let n_missing = ((TIME_STEPS as f64 * cfg.missingness_rate) as usize).max(1);

    let mut samples = Vec::with_capacity(cfg.num_samples);
    for id in 0..cfg.num_samples {
        let mut sensor = Array2::<f32>::zeros((TIME_STEPS, NUM_CHANNELS));
        let mut mask = Array2::<u8>::zeros((TIME_STEPS, NUM_CHANNELS));

        for ch in 0..NUM_CHANNELS {
            let has_circadian = cfg.add_circadian && (ch == 0 || ch == 3);

            // Draw order matters for reproducibility: all noise values of a
            // channel first, then that channel's missing positions.
            for t in 0..TIME_STEPS {
                let base: f64 = noise.sample(&mut rng);
                let circadian = if has_circadian { circadian_wave[t] } else { 0.0 };
                sensor[[t, ch]] = (base + circadian) as f32;
            }

            if cfg.add_missingness {
                for _ in 0..n_missing {
                    let t: usize = rng.gen_range(0..TIME_STEPS);
                    mask[[t, ch]] = 1;
                }
            }
        }

        // NOTE(review): the caption mentions circadian variation even when
        // `cfg.add_circadian` is false — confirm whether that is intended.
        let caption = format!(
            "Synthetic 24-hour recording for individual {id}. \
             Heart rate shows typical circadian variation. \
             Activity patterns reflect normal daily movement."
        );
        samples.push(SyntheticSample { sensor, mask, caption, id });
    }
    samples
}