use crate::{Dataset, DatasetError, acquire_dataset, download_to};
use csv::ReaderBuilder;
use ndarray::{Array1, Array2};
use std::fs::File;
/// Identifier used for this dataset in error reporting and dataset acquisition.
const BOSTON_HOUSING_DATASET_NAME: &str = "boston_housing";

/// Name of the CSV file, both on disk and at the end of the download URL.
const BOSTON_HOUSING_FILENAME: &str = "BostonHousing.csv";

/// Remote location the CSV is downloaded from when it is not already present.
const BOSTON_HOUSING_DATA_URL: &str =
    "https://github.com/selva86/datasets/raw/master/BostonHousing.csv";

/// Expected SHA-256 hex digest of the CSV, handed to `acquire_dataset`.
const BOSTON_HOUSING_SHA256: &str =
    "ab16ba38fbbbbcc69fe930aab1293104f1442c8279c130d9eba03dd864bef675";
/// Boston Housing regression dataset, backed by a CSV file fetched from a remote URL.
///
/// The `(features, targets)` arrays are produced by `BostonHousing::load_data` and obtained
/// through the wrapped `Dataset`; the public accessors return borrows into it.
#[derive(Debug)]
pub struct BostonHousing {
/// Storage for the parsed `(features, targets)` pair, populated via `Dataset::load`.
dataset: Dataset<(Array2<f64>, Array1<f64>)>,
}
impl BostonHousing {
    /// Creates a handle to the Boston Housing dataset stored under `storage_dir`.
    ///
    /// This only wraps `Dataset::new(storage_dir)`. The CSV is acquired and parsed by
    /// `load_data` when one of the accessors calls `Dataset::load`. Whether that happens once
    /// and is cached depends on `Dataset` (presumably lazy — confirm).
    pub fn new(storage_dir: &str) -> Self {
        BostonHousing {
            dataset: Dataset::new(storage_dir),
        }
    }

    /// Acquires `BostonHousing.csv` in `dir` and parses it into `(features, targets)`.
    ///
    /// The file is obtained through `acquire_dataset`, which receives the expected SHA-256
    /// digest and a closure that downloads the CSV into a temporary directory. The CSV must
    /// have a header row. In every data row, all columns except the last are features, and
    /// the last column is the regression target.
    ///
    /// # Errors
    ///
    /// Returns a [`DatasetError`] if any of the following happens:
    ///
    /// - acquiring or opening the file fails;
    /// - a row cannot be read as CSV;
    /// - the first row has fewer than two columns;
    /// - a later row's column count differs from the first row's;
    /// - a field is not a valid `f64`;
    /// - the file has no data rows;
    /// - the feature matrix cannot be shaped.
    fn load_data(dir: &str) -> Result<(Array2<f64>, Array1<f64>), DatasetError> {
        let file_path = acquire_dataset(
            dir,
            BOSTON_HOUSING_FILENAME,
            BOSTON_HOUSING_DATASET_NAME,
            Some(BOSTON_HOUSING_SHA256),
            |temp_path| {
                download_to(BOSTON_HOUSING_DATA_URL, temp_path, None)?;
                Ok(temp_path.join(BOSTON_HOUSING_FILENAME))
            },
        )?;

        let file = File::open(&file_path)?;
        let mut rdr = ReaderBuilder::new().has_headers(true).from_reader(file);

        let mut features = Vec::new();
        let mut targets = Vec::new();
        // The column layout is fixed by the first data row and enforced on every later row.
        let mut num_features: Option<usize> = None;

        for (idx, result) in rdr.records().enumerate() {
            let record =
                result.map_err(|e| DatasetError::csv_read_error(BOSTON_HOUSING_DATASET_NAME, e))?;
            // `idx` is 0-based; +1 makes it 1-based and +1 more skips the header row.
            // This assumes each record occupies exactly one physical line.
            let line_num = idx + 2;

            let n_features = match num_features {
                Some(n) => n,
                None => {
                    // At least one feature column plus the target column is required.
                    if record.len() < 2 {
                        return Err(DatasetError::invalid_column_count(
                            BOSTON_HOUSING_DATASET_NAME,
                            2,
                            record.len(),
                            line_num,
                            &format!("{:?}", record),
                        ));
                    }
                    let n = record.len() - 1;
                    num_features = Some(n);
                    n
                }
            };

            if record.len() != n_features + 1 {
                return Err(DatasetError::invalid_column_count(
                    BOSTON_HOUSING_DATASET_NAME,
                    n_features + 1,
                    record.len(),
                    line_num,
                    &format!("{:?}", record),
                ));
            }

            for (i, value) in record.iter().take(n_features).enumerate() {
                let parsed = value.parse::<f64>().map_err(|e| {
                    DatasetError::parse_failed(
                        BOSTON_HOUSING_DATASET_NAME,
                        &format!("feature[{i}]"),
                        line_num,
                        &format!("{:?}", record),
                        e,
                    )
                })?;
                features.push(parsed);
            }

            // In bounds: the column count was checked to be exactly `n_features + 1` above.
            let target = record[n_features].parse::<f64>().map_err(|e| {
                DatasetError::parse_failed(
                    BOSTON_HOUSING_DATASET_NAME,
                    "target",
                    line_num,
                    &format!("{:?}", record),
                    e,
                )
            })?;
            targets.push(target);
        }

        // `num_features` is set on the first row that gets past the width check. Any later
        // failure returns early, so reaching this point with `None` means no data rows.
        let n_features = num_features
            .ok_or_else(|| DatasetError::empty_dataset(BOSTON_HOUSING_DATASET_NAME))?;
        let n_samples = targets.len();

        // `features` was filled row by row, so a row-major (n_samples, n_features) shape fits.
        let features_array =
            Array2::from_shape_vec((n_samples, n_features), features).map_err(|e| {
                DatasetError::array_shape_error(BOSTON_HOUSING_DATASET_NAME, "features", e)
            })?;
        let targets_array = Array1::from_vec(targets);
        Ok((features_array, targets_array))
    }

    /// Returns the feature matrix, with one row per sample and one column per feature.
    ///
    /// # Errors
    ///
    /// Propagates any [`DatasetError`] raised while loading the dataset.
    pub fn features(&self) -> Result<&Array2<f64>, DatasetError> {
        self.data().map(|(features, _)| features)
    }

    /// Returns the target vector, with one entry per sample (the last CSV column).
    ///
    /// # Errors
    ///
    /// Propagates any [`DatasetError`] raised while loading the dataset.
    pub fn targets(&self) -> Result<&Array1<f64>, DatasetError> {
        self.data().map(|(_, targets)| targets)
    }

    /// Returns borrowed `(features, targets)` for the whole dataset.
    ///
    /// # Errors
    ///
    /// Propagates any [`DatasetError`] raised while loading the dataset.
    pub fn data(&self) -> Result<(&Array2<f64>, &Array1<f64>), DatasetError> {
        let data = self.dataset.load(Self::load_data)?;
        Ok((&data.0, &data.1))
    }
}