/// Generates a small, deterministic, synthetic two-class time-series dataset
/// made of linear ramps.
///
/// Every series has 20 timestamps. In each split, the first half of the
/// series is labelled `"1"` and the second half `"2"`:
///
/// * class `"1"`: rising ramp `t / 20 + step * k`
/// * class `"2"`: falling ramp `1 - t / 20 + step * k`
///
/// Here `t` is the time index and `k` is the series' index *within its class*
/// (starting at 0 for both classes). The training split (10 series) uses
/// `step = 0.1`; the test split (10 series) uses `step = 0.05`, so test series
/// are shifted less than their training counterparts.
///
/// Returns `(x_train, x_test, y_train, y_test)`.
pub fn load_gunpoint_synthetic() -> super::TrainTestSplit {
    // Builds `n_series` ramps of length `n_timestamps`: the first half rising
    // (label "1"), the rest falling (label "2"), each shifted by `step * k`.
    fn ramps(n_series: usize, n_timestamps: usize, step: f64) -> (Vec<Vec<f64>>, Vec<String>) {
        let half = n_series / 2;
        let len = n_timestamps as f64;
        (0..n_series)
            .map(|i| {
                let rising = i < half;
                // The offset index restarts at 0 for the second class.
                let k = (if rising { i } else { i - half }) as f64;
                let ts: Vec<f64> = (0..n_timestamps)
                    .map(|t| {
                        let frac = t as f64 / len;
                        // Same evaluation order as `a + step * k`, so values are bit-exact.
                        if rising {
                            frac + step * k
                        } else {
                            1.0 - frac + step * k
                        }
                    })
                    .collect();
                let label = if rising { "1" } else { "2" };
                (ts, label.to_string())
            })
            .unzip()
    }

    let n_timestamps = 20;
    let n_train = 10;
    let n_test = 10;
    let (x_train, y_train) = ramps(n_train, n_timestamps, 0.1);
    let (x_test, y_test) = ramps(n_test, n_timestamps, 0.05);
    (x_train, x_test, y_train, y_test)
}
/// Generates a small, deterministic, synthetic two-class time-series dataset
/// made of single-period sinusoids.
///
/// Every series has 15 timestamps sampled over one period (`2π`). In each
/// split, the first half of the series is labelled `"A"` and the second half
/// `"B"`:
///
/// * class `"A"`: `sin(2π t / 15) + offset`
/// * class `"B"`: `cos(2π t / 15) + offset`
///
/// The training split (8 series) has no offset. The test split (8 series)
/// uses `offset = 0.1`. All series within one class and split are identical.
///
/// Returns `(x_train, x_test, y_train, y_test)`.
pub fn load_coffee_synthetic() -> super::TrainTestSplit {
    // Builds `n_series` waves of length `n_timestamps`: the first half sines
    // (label "A"), the rest cosines (label "B"), each shifted by `offset`.
    fn waves(n_series: usize, n_timestamps: usize, offset: f64) -> (Vec<Vec<f64>>, Vec<String>) {
        let half = n_series / 2;
        let len = n_timestamps as f64;
        (0..n_series)
            .map(|i| {
                let is_a = i < half;
                let ts: Vec<f64> = (0..n_timestamps)
                    .map(|t| {
                        let phase = 2.0 * std::f64::consts::PI * t as f64 / len;
                        let base = if is_a { phase.sin() } else { phase.cos() };
                        // Adding 0.0 is exact here: no sampled value is -0.0.
                        base + offset
                    })
                    .collect();
                let label = if is_a { "A" } else { "B" };
                (ts, label.to_string())
            })
            .unzip()
    }

    let n_timestamps = 15;
    let n_train = 8;
    let n_test = 8;
    let (x_train, y_train) = waves(n_train, n_timestamps, 0.0);
    let (x_test, y_test) = waves(n_test, n_timestamps, 0.1);
    (x_train, x_test, y_train, y_test)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks split sizes, series length, label layout, ramp direction and
    /// a few exactly-representable start values of the ramp dataset.
    #[test]
    fn test_load_gunpoint_synthetic() {
        let (x_train, x_test, y_train, y_test) = load_gunpoint_synthetic();
        assert_eq!(x_train.len(), 10);
        assert_eq!(x_test.len(), 10);
        assert_eq!(y_train.len(), 10);
        assert_eq!(y_test.len(), 10);
        // Every series in both splits has 20 timestamps.
        assert!(x_train.iter().chain(&x_test).all(|ts| ts.len() == 20));
        // First half of each split is class "1", second half class "2".
        for labels in [&y_train, &y_test] {
            assert!(labels[..5].iter().all(|l| l == "1"));
            assert!(labels[5..].iter().all(|l| l == "2"));
        }
        // Class "1" ramps rise strictly; class "2" ramps fall strictly.
        for xs in [&x_train, &x_test] {
            assert!(xs[..5].iter().all(|ts| ts.windows(2).all(|w| w[0] < w[1])));
            assert!(xs[5..].iter().all(|ts| ts.windows(2).all(|w| w[0] > w[1])));
        }
        // Exact start points (all exactly representable sums).
        assert_eq!(x_train[0][0], 0.0);
        assert_eq!(x_train[5][0], 1.0);
        assert_eq!(x_test[1][0], 0.05);
    }

    /// Checks split sizes, series length, label layout, sin/cos start values
    /// and the constant 0.1 shift between train and test series.
    #[test]
    fn test_load_coffee_synthetic() {
        let (x_train, x_test, y_train, y_test) = load_coffee_synthetic();
        assert_eq!(x_train.len(), 8);
        assert_eq!(x_test.len(), 8);
        assert_eq!(y_train.len(), 8);
        assert_eq!(y_test.len(), 8);
        // Every series in both splits has 15 timestamps.
        assert!(x_train.iter().chain(&x_test).all(|ts| ts.len() == 15));
        // First half of each split is class "A", second half class "B".
        for labels in [&y_train, &y_test] {
            assert!(labels[..4].iter().all(|l| l == "A"));
            assert!(labels[4..].iter().all(|l| l == "B"));
        }
        // At t = 0: sin(0) = 0 and cos(0) = 1; the test split is shifted by 0.1.
        assert_eq!(x_train[0][0], 0.0);
        assert_eq!(x_train[4][0], 1.0);
        assert_eq!(x_test[0][0], 0.1);
        // Each test series is its training counterpart shifted up by 0.1.
        for (tr, te) in x_train.iter().zip(&x_test) {
            for (a, b) in tr.iter().zip(te) {
                assert!((b - a - 0.1).abs() < 1e-12);
            }
        }
    }
}