use std::collections::BTreeMap;
use std::ops::Index;
use serde::Deserialize;
use serde::Serialize;
use crate::data::Category;
use crate::data::Container;
use crate::data::Datum;
use crate::data::FeatureData;
#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)]
pub struct DataStore(pub BTreeMap<usize, FeatureData>);
impl Index<usize> for DataStore {
type Output = FeatureData;
fn index(&self, ix: usize) -> &Self::Output {
&self.0[&ix]
}
}
impl DataStore {
pub fn new(data: BTreeMap<usize, FeatureData>) -> Self {
DataStore(data)
}
pub fn get(&self, row_ix: usize, col_ix: usize) -> Datum {
match self.0[&col_ix] {
FeatureData::Binary(ref xs) => {
xs.get(row_ix).map(Datum::Binary).unwrap_or(Datum::Missing)
}
FeatureData::Continuous(ref xs) => xs
.get(row_ix)
.map(Datum::Continuous)
.unwrap_or(Datum::Missing),
FeatureData::Categorical(ref xs) => xs
.get(row_ix)
.map(|x| Datum::Categorical(Category::UInt(x)))
.unwrap_or(Datum::Missing),
FeatureData::Count(ref xs) => {
xs.get(row_ix).map(Datum::Count).unwrap_or(Datum::Missing)
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::data::SparseContainer;
fn fixture() -> DataStore {
let dc1: SparseContainer<f64> = SparseContainer::from(vec![
(4.0, true),
(3.0, false),
(2.0, true),
(1.0, true),
(0.0, true),
]);
let dc2: SparseContainer<u32> = SparseContainer::from(vec![
(5, true),
(3, true),
(2, true),
(1, false),
(4, true),
]);
let mut data = BTreeMap::<usize, FeatureData>::new();
data.insert(0, FeatureData::Continuous(dc1));
data.insert(1, FeatureData::Categorical(dc2));
DataStore(data)
}
#[test]
fn gets_present_continuous_data() {
let ds = fixture();
assert_eq!(ds.get(0, 0), Datum::Continuous(4.0));
assert_eq!(ds.get(2, 0), Datum::Continuous(2.0));
}
#[test]
fn gets_present_categorical_data() {
let ds = fixture();
assert_eq!(ds.get(0, 1), Datum::Categorical(5_u32.into()));
assert_eq!(ds.get(4, 1), Datum::Categorical(4_u32.into()));
}
#[test]
fn gets_missing_continuous_data() {
let ds = fixture();
assert_eq!(ds.get(1, 0), Datum::Missing);
}
#[test]
fn gets_missing_categorical_data() {
let ds = fixture();
assert_eq!(ds.get(3, 1), Datum::Missing);
}
}