use crate::{Dataset, DatasetError, acquire_dataset, download_to};
use csv::ReaderBuilder;
use ndarray::{Array1, Array2};
use std::fs::File;
/// Parsed Titanic dataset as produced by `Titanic::load_data`:
/// `(string_features, numeric_features, labels)`.
///
/// The two feature matrices have one row per CSV record; the label vector has one
/// entry per CSV record.
type TitanicData = (Array2<String>, Array2<f64>, Array1<f64>);
/// URL handed to `download_to` when the dataset has to be fetched.
const TITANIC_DATA_URL: &str =
"https://raw.githubusercontent.com/datasciencedojo/datasets/master/titanic.csv";
/// File name passed to `acquire_dataset` and joined onto the temporary download path.
const TITANIC_FILENAME: &str = "titanic.csv";
/// Value passed to `acquire_dataset` as `Some(..)`; presumably the expected SHA-256
/// digest of the CSV file used for integrity verification — confirm against
/// `acquire_dataset`.
const TITANIC_SHA256: &str = "4a437fde05fe5264e1701a7387ac6fb75393772ba38bb2c9c566405af5af4bd7";
/// Dataset identifier passed to `acquire_dataset` and to every `DatasetError` constructor.
const TITANIC_DATASET_NAME: &str = "titanic";
/// Handle to the Titanic dataset.
///
/// Wraps a `Dataset<TitanicData>`; the accessor methods obtain the parsed data by
/// calling `Dataset::load` with the CSV loader defined in the `impl` block.
#[derive(Debug)]
pub struct Titanic {
// Underlying dataset container, constructed from the storage directory in `new`.
dataset: Dataset<TitanicData>,
}
impl Titanic {
    /// Creates a Titanic dataset handle backed by `Dataset::new(storage_dir)`.
    ///
    /// # Parameters
    /// - `storage_dir`: forwarded unchanged to `Dataset::new`; presumably the directory in
    ///   which the CSV file is stored — confirm against `Dataset`.
    pub fn new(storage_dir: &str) -> Self {
        Titanic {
            dataset: Dataset::new(storage_dir),
        }
    }

    /// Acquires the Titanic CSV file and parses it into feature matrices and labels.
    ///
    /// Column `1` (`survived`) becomes the label vector, columns `0, 2, 5, 6, 7, 9` become
    /// the numeric feature matrix, and columns `3, 4, 8, 10, 11` become the string feature
    /// matrix. Numeric fields are trimmed before parsing; an empty numeric field (including
    /// an empty label) is stored as `f64::NAN`. String fields are copied verbatim.
    ///
    /// # Parameters
    /// - `dir`: forwarded to `acquire_dataset` as the dataset directory.
    ///
    /// # Errors
    /// Returns a `DatasetError` when:
    /// - `acquire_dataset` or opening the resulting file fails,
    /// - the CSV reader reports an error for a record,
    /// - a record has fewer than 12 columns,
    /// - a non-empty numeric field cannot be parsed as `f64`,
    /// - the file contains no data records,
    /// - the collected values cannot be shaped into the output arrays.
    fn load_data(dir: &str) -> Result<TitanicData, DatasetError> {
        /// Minimum number of columns a record must have for all indices below to be valid.
        const EXPECTED_COLUMNS: usize = 12;
        /// Column index and error-reporting name of the label.
        const LABEL_COLUMN: (usize, &str) = (1, "survived");
        /// Column index and error-reporting name of each numeric feature, in output order.
        const NUMERIC_COLUMNS: [(usize, &str); 6] = [
            (0, "passenger_id"),
            (2, "pclass"),
            (5, "age"),
            (6, "sib_sp"),
            (7, "parch"),
            (9, "fare"),
        ];
        /// Column indices of the string features, in output order.
        /// NOTE(review): presumably name, sex, ticket, cabin, embarked in the upstream
        /// CSV header — confirm against the file.
        const STRING_COLUMNS: [usize; 5] = [3, 4, 8, 10, 11];

        let file_path = acquire_dataset(
            dir,
            TITANIC_FILENAME,
            TITANIC_DATASET_NAME,
            Some(TITANIC_SHA256),
            |temp_path| {
                download_to(TITANIC_DATA_URL, temp_path, None)?;
                Ok(temp_path.join(TITANIC_FILENAME))
            },
        )?;

        let file = File::open(&file_path)?;
        let mut rdr = ReaderBuilder::new().has_headers(true).from_reader(file);

        // Row-major flat buffers; reshaped into 2-D arrays once the row count is known.
        let mut string_features = Vec::new();
        let mut numeric_features = Vec::new();
        let mut labels = Vec::new();

        for (idx, result) in rdr.records().enumerate() {
            let record =
                result.map_err(|e| DatasetError::csv_read_error(TITANIC_DATASET_NAME, e))?;
            // `idx` is zero-based and the header is skipped by the reader, so the first
            // data record is reported as line 2.
            let line_num = idx + 2;

            // Validate every row (not only the first) so that the indexing below can never
            // panic, independent of how the reader treats ragged records.
            if record.len() < EXPECTED_COLUMNS {
                return Err(DatasetError::invalid_column_count(
                    TITANIC_DATASET_NAME,
                    EXPECTED_COLUMNS,
                    record.len(),
                    line_num,
                    &format!("{:?}", record),
                ));
            }

            // Parses one column of the current record; blank fields map to NaN.
            let parse_numeric = |index: usize, field_name: &str| -> Result<f64, DatasetError> {
                let val = record[index].trim();
                if val.is_empty() {
                    Ok(f64::NAN)
                } else {
                    val.parse::<f64>().map_err(|e| {
                        DatasetError::parse_failed(
                            TITANIC_DATASET_NAME,
                            field_name,
                            line_num,
                            &format!("{:?}", record),
                            e,
                        )
                    })
                }
            };

            labels.push(parse_numeric(LABEL_COLUMN.0, LABEL_COLUMN.1)?);

            for &(col_idx, field_name) in &NUMERIC_COLUMNS {
                numeric_features.push(parse_numeric(col_idx, field_name)?);
            }

            string_features.extend(STRING_COLUMNS.iter().map(|&col_idx| record[col_idx].to_string()));
        }

        if labels.is_empty() {
            return Err(DatasetError::empty_dataset(TITANIC_DATASET_NAME));
        }
        let n_samples = labels.len();

        let string_array =
            Array2::from_shape_vec((n_samples, STRING_COLUMNS.len()), string_features).map_err(
                |e| DatasetError::array_shape_error(TITANIC_DATASET_NAME, "string_features", e),
            )?;
        let numeric_array =
            Array2::from_shape_vec((n_samples, NUMERIC_COLUMNS.len()), numeric_features).map_err(
                |e| DatasetError::array_shape_error(TITANIC_DATASET_NAME, "numeric_features", e),
            )?;
        let labels_array = Array1::from_vec(labels);

        Ok((string_array, numeric_array, labels_array))
    }

    /// Returns references to the string feature matrix and the numeric feature matrix.
    ///
    /// The data is obtained through `Dataset::load` with `Self::load_data` as the loader
    /// (presumably loaded once and cached by `Dataset` — confirm there).
    ///
    /// # Errors
    /// Propagates any `DatasetError` returned by `Dataset::load`.
    pub fn features(&self) -> Result<(&Array2<String>, &Array2<f64>), DatasetError> {
        let data = self.dataset.load(Self::load_data)?;
        Ok((&data.0, &data.1))
    }

    /// Returns a reference to the label vector (parsed from the `survived` column).
    ///
    /// # Errors
    /// Propagates any `DatasetError` returned by `Dataset::load`.
    pub fn labels(&self) -> Result<&Array1<f64>, DatasetError> {
        Ok(&self.dataset.load(Self::load_data)?.2)
    }

    /// Returns references to all three parts of the dataset:
    /// `(string_features, numeric_features, labels)`.
    ///
    /// # Errors
    /// Propagates any `DatasetError` returned by `Dataset::load`.
    pub fn data(&self) -> Result<(&Array2<String>, &Array2<f64>, &Array1<f64>), DatasetError> {
        let data = self.dataset.load(Self::load_data)?;
        Ok((&data.0, &data.1, &data.2))
    }
}