use std::path::PathBuf;
use serde::{Deserialize, Serialize};
use serde_aux::field_attributes::bool_true;
use crate::{
datatable::DataTable,
preprocessing::map::{MapOp, MapSelector},
};
/// An ordered list of [`Feature`] descriptions.
///
/// The type is serializable/deserializable through serde, cloneable and
/// hashable; its `Default` value is a dataset with no features.
#[derive(Serialize, Debug, Deserialize, Clone, Hash, Default)]
pub struct Dataset {
/// The features of the dataset, kept in insertion order.
pub features: Vec<Feature>,
}
impl Dataset {
    /// Returns a copy of this dataset with `feature` appended at the end.
    pub fn with_added_feature(&self, feature: Feature) -> Self {
        // Reserve the final size up front so the push never reallocates.
        let mut features = Vec::with_capacity(self.features.len() + 1);
        features.extend_from_slice(&self.features);
        features.push(feature);
        Self { features }
    }

    /// Returns a copy of this dataset without any feature named `feature_name`.
    ///
    /// Every feature with a matching name is dropped; an unknown name yields
    /// a plain copy.
    pub fn without_feature(&self, feature_name: String) -> Self {
        // Filter before cloning so that dropped features are never copied.
        let features = self
            .features
            .iter()
            .filter(|f| f.name != feature_name)
            .cloned()
            .collect();
        Self { features }
    }

    /// Returns a copy of this dataset in which the first feature named
    /// `old_feature_name` is replaced by `feature` (same position).
    ///
    /// # Panics
    ///
    /// Panics if the dataset contains no feature named `old_feature_name`.
    pub fn with_replaced_feature(&self, old_feature_name: &str, feature: Feature) -> Self {
        let index = self
            .features
            .iter()
            .position(|f| f.name == old_feature_name)
            .unwrap_or_else(|| {
                panic!(
                    "cannot replace feature `{}`: no feature with that name in the dataset",
                    old_feature_name
                )
            });
        let mut features = self.features.clone();
        features[index] = feature;
        Self { features }
    }

    /// Returns the names of all features, in dataset order.
    pub fn feature_names(&self) -> Vec<&str> {
        self.features.iter().map(|f| f.name.as_str()).collect()
    }

    /// Returns the names of the features that are neither predicted, nor an
    /// id, nor carrying a date format.
    pub fn in_features_names(&self) -> Vec<&str> {
        self.features
            .iter()
            .filter(|f| !f.predicted && !f.is_id && f.date_format.is_none())
            .map(|f| f.name.as_str())
            .collect()
    }

    /// Returns the names of the features flagged as `predicted`.
    pub fn predicted_features_names(&self) -> Vec<&str> {
        self.features
            .iter()
            .filter(|f| f.predicted)
            .map(|f| f.name.as_str())
            .collect()
    }

    /// Builds a dataset holding a copy of `features`.
    pub fn new(features: &[Feature]) -> Self {
        Self {
            features: features.to_vec(),
        }
    }

    /// Builds a dataset with one feature per column name reported by
    /// `DataTable::columns_names_from_file` for `path`.
    ///
    /// Each feature is created from a single `FeatureTags::Name` tag, so it
    /// only carries its name and the defaults set by `Feature::from_tags`.
    pub fn from_file<P: Into<PathBuf>>(path: P) -> Self {
        let mut features = Vec::new();
        for column_name in DataTable::columns_names_from_file(path) {
            features.push(Feature::from_tags(&[FeatureTags::Name(
                column_name.as_str(),
            )]));
        }
        // The vector is already owned: move it instead of cloning it again.
        Self { features }
    }

    /// Removes, in place, every feature whose name appears in `feature_names`.
    ///
    /// Returns `self` to allow call chaining.
    pub fn remove_features(&mut self, feature_names: &[&str]) -> &mut Self {
        // `retain` filters in place, so kept features are not cloned.
        self.features
            .retain(|f| !feature_names.contains(&f.name.as_str()));
        self
    }

    /// Applies `tag` to every feature named `feature_name`.
    ///
    /// Nothing happens when no feature has that name. Returns `self` to allow
    /// call chaining.
    pub fn tag_feature(&mut self, feature_name: &str, tag: FeatureTags) -> &mut Self {
        for feature in self
            .features
            .iter_mut()
            .filter(|f| f.name == feature_name)
        {
            tag.apply(feature);
        }
        self
    }

    /// Applies `tag` to every feature of the dataset.
    ///
    /// Returns `self` to allow call chaining.
    pub fn tag_all(&mut self, tag: FeatureTags) -> &mut Self {
        for feature in &mut self.features {
            tag.apply(feature);
        }
        self
    }

    /// Builds a dataset from a list of tag lists, one feature per tag list.
    pub fn from_features_tags(features: &[&[FeatureTags]]) -> Self {
        Self {
            features: features
                .iter()
                .map(|tags| Feature::from_tags(tags))
                .collect(),
        }
    }

    /// Returns the name of the first feature flagged as id, if any.
    pub fn get_id_column(&self) -> Option<&str> {
        self.features
            .iter()
            .find(|f| f.is_id)
            .map(|f| f.name.as_str())
    }
}
/// Description of a single dataset feature: its name, a set of boolean
/// flags, an optional mapping and optional nested derived features.
///
/// Fields marked `#[serde(default)]` may be omitted from serialized input
/// and then take their type's default value. How the processing flags are
/// consumed is not visible in this file.
#[derive(Default, Serialize, Debug, Deserialize, Clone, Hash, Eq, PartialEq)]
pub struct Feature {
/// Name of the feature; `Dataset` looks features up by this value.
pub name: String,
/// Set by `FeatureTags::Predicted`. Predicted features are listed by
/// `Dataset::predicted_features_names` and excluded from
/// `Dataset::in_features_names`.
#[serde(default)]
pub predicted: bool,
/// Set by `FeatureTags::DateFormat`. Features with a date format are
/// excluded from `Dataset::in_features_names`.
/// NOTE(review): the expected format syntax is not visible here.
pub date_format: Option<String>,
/// Flag toggled by `FeatureTags::ToTimestamp`.
#[serde(default)]
pub to_timestamp: bool,
/// Flag toggled by `FeatureTags::ExtractMonth`.
#[serde(default)]
pub extract_month: bool,
/// Flag toggled by `FeatureTags::Log10`.
#[serde(default)]
pub log10: bool,
/// Flag toggled by `FeatureTags::Normalized`.
#[serde(default)]
pub normalized: bool,
/// Flag toggled by `FeatureTags::FilterOutliers`.
#[serde(default)]
pub filter_outliers: bool,
/// Selector/operation pair stored by `FeatureTags::Mapped`.
#[serde(default)]
pub mapped: Option<(MapSelector, MapOp)>,
/// Flag toggled by `FeatureTags::Squared`.
#[serde(default)]
pub squared: bool,
/// Derived feature set by `FeatureTags::AddExtractedTimestamp` or
/// `FeatureTags::AddFeatureExtractedTimestamp`.
pub with_extracted_timestamp: Option<Box<Feature>>,
/// Derived feature set by `FeatureTags::AddExtractedMonth` or
/// `FeatureTags::AddFeatureExtractedMonth`.
pub with_extracted_month: Option<Box<Feature>>,
/// Derived feature set by `FeatureTags::AddLog10` or
/// `FeatureTags::AddFeatureLog10`.
pub with_log10: Option<Box<Feature>>,
/// Derived feature set by `FeatureTags::AddNormalized` or
/// `FeatureTags::AddFeatureNormalized`.
pub with_normalized: Option<Box<Feature>>,
/// Derived feature set by `FeatureTags::AddSquared` or
/// `FeatureTags::AddFeatureSquared`.
pub with_squared: Option<Box<Feature>>,
/// Flag toggled by `FeatureTags::UsedInModel`. Deserializes to `true` when
/// omitted, and `Feature::from_tags` also starts from `true`; note that
/// the derived `Default` leaves it `false`.
#[serde(default = "bool_true")]
pub used_in_model: bool,
/// Flag toggled by `FeatureTags::OneHotEncode`.
#[serde(default)]
pub one_hot_encoded: bool,
/// Set by `FeatureTags::IsId`. Id features are excluded from
/// `Dataset::in_features_names`; `Dataset::get_id_column` returns the
/// first one.
#[serde(default)]
pub is_id: bool,
}
impl Feature {
pub fn from_tags(feature_tags: &[FeatureTags]) -> Self {
let mut feature = Feature::default();
feature.used_in_model = true;
for feature_tag in feature_tags {
feature_tag.apply(&mut feature);
}
feature
}
pub fn get_extracted_features_mut(&mut self) -> Vec<&mut Feature> {
let mut extracted_features = Vec::new();
if let Some(ref mut feature) = self.with_extracted_month {
extracted_features.push(feature.as_mut());
}
if let Some(ref mut feature) = self.with_extracted_timestamp {
extracted_features.push(feature.as_mut());
}
if let Some(ref mut feature) = self.with_log10 {
extracted_features.push(feature.as_mut());
}
if let Some(ref mut feature) = self.with_normalized {
extracted_features.push(feature.as_mut());
}
if let Some(ref mut feature) = self.with_squared {
extracted_features.push(feature.as_mut());
}
extracted_features
}
}
/// Declarative modifications that can be applied to a [`Feature`] through
/// `FeatureTags::apply`.
///
/// Flag-style variants set one boolean field (to `true`, or to `false` when
/// wrapped in `Not`). Value-carrying variants (`Name`, `DateFormat`,
/// `Mapped` and the `Add*` family) always store their value and ignore
/// that boolean.
#[derive(Debug)]
pub enum FeatureTags<'a> {
/// Sets the feature's `name`.
Name(&'a str),
/// Toggles the `predicted` flag.
Predicted,
/// Sets `date_format` to the given string.
DateFormat(&'a str),
/// Toggles the `one_hot_encoded` flag.
OneHotEncode,
/// Toggles the `to_timestamp` flag.
ToTimestamp,
/// Toggles the `extract_month` flag.
ExtractMonth,
/// Toggles the `log10` flag.
Log10,
/// Toggles the `normalized` flag.
Normalized,
/// Toggles the `filter_outliers` flag.
FilterOutliers,
/// Toggles the `squared` flag.
Squared,
/// Toggles the `used_in_model` flag.
UsedInModel,
/// Toggles the `is_id` flag.
IsId,
/// Sets `with_extracted_month` to a new feature named `<name>_month`.
AddExtractedMonth,
/// Sets `with_extracted_timestamp` to a new feature named `<name>_timestamp`.
AddExtractedTimestamp,
/// Sets `with_log10` to a new feature named `log10(<name>)`.
AddLog10,
/// Sets `with_normalized` to a new feature named `<name>_normalized`.
AddNormalized,
/// Sets `with_squared` to a new feature named `<name>^2`.
AddSquared,
/// Sets `with_extracted_month` to a feature built from the given tags.
AddFeatureExtractedMonth(&'a [FeatureTags<'a>]),
/// Sets `with_extracted_timestamp` to a feature built from the given tags.
AddFeatureExtractedTimestamp(&'a [FeatureTags<'a>]),
/// Sets `with_log10` to a feature built from the given tags.
AddFeatureLog10(&'a [FeatureTags<'a>]),
/// Sets `with_normalized` to a feature built from the given tags.
AddFeatureNormalized(&'a [FeatureTags<'a>]),
/// Sets `with_squared` to a feature built from the given tags.
AddFeatureSquared(&'a [FeatureTags<'a>]),
/// Stores a clone of the selector/operation pair in `mapped`.
Mapped(MapSelector, MapOp),
/// Applies the inner tag with the boolean inverted, so a flag tag is
/// cleared instead of set; nested `Not`s invert it again.
Not(&'a FeatureTags<'a>),
/// Applies the inner tag to the existing derived features (recursively,
/// through their own derived features) and then to the feature itself.
RecurseAdded(&'a FeatureTags<'a>),
/// Applies the inner tag unless the feature's name is in the list.
ExceptFeatures(&'a FeatureTags<'a>, &'a [&'a str]),
/// Applies the inner tag only when the feature's name is in the list.
OnlyFeatures(&'a FeatureTags<'a>, &'a [&'a str]),
}
impl<'a> FeatureTags<'a> {
pub fn apply(&self, feature: &mut Feature) {
self.apply_bool(feature, true)
}
pub fn except(&'a self, exceptions: &'a [&'a str]) -> FeatureTags<'a> {
FeatureTags::ExceptFeatures(self, exceptions)
}
pub fn only(&'a self, features: &'a [&'a str]) -> FeatureTags<'a> {
FeatureTags::OnlyFeatures(self, features)
}
pub fn incl_added_features(&'a self) -> FeatureTags<'a> {
FeatureTags::RecurseAdded(self)
}
fn apply_bool(&self, feature: &mut Feature, value: bool) {
match self {
FeatureTags::Name(name) => feature.name = name.to_string(),
FeatureTags::Predicted => feature.predicted = value,
FeatureTags::DateFormat(date_format) => {
feature.date_format = Some(date_format.to_string())
}
FeatureTags::ToTimestamp => feature.to_timestamp = value,
FeatureTags::ExtractMonth => feature.extract_month = value,
FeatureTags::Log10 => feature.log10 = value,
FeatureTags::Normalized => feature.normalized = value,
FeatureTags::FilterOutliers => feature.filter_outliers = value,
FeatureTags::Squared => feature.squared = value,
FeatureTags::OneHotEncode => feature.one_hot_encoded = value,
FeatureTags::UsedInModel => feature.used_in_model = value,
FeatureTags::IsId => feature.is_id = value,
FeatureTags::AddFeatureExtractedMonth(with_extracted_month) => {
feature.with_extracted_month =
Some(Box::new(Feature::from_tags(with_extracted_month)))
}
FeatureTags::AddFeatureExtractedTimestamp(with_extracted_timestamp) => {
feature.with_extracted_timestamp =
Some(Box::new(Feature::from_tags(with_extracted_timestamp)))
}
FeatureTags::AddFeatureLog10(with_log10) => {
feature.with_log10 = Some(Box::new(Feature::from_tags(with_log10)))
}
FeatureTags::AddFeatureNormalized(with_normalized) => {
feature.with_normalized = Some(Box::new(Feature::from_tags(with_normalized)))
}
FeatureTags::AddFeatureSquared(with_squared) => {
feature.with_squared = Some(Box::new(Feature::from_tags(with_squared)))
}
FeatureTags::Mapped(map_selector, map_op) => {
feature.mapped = Some((map_selector.clone(), map_op.clone()))
}
FeatureTags::Not(feature_tag) => feature_tag.apply_bool(feature, !value),
FeatureTags::ExceptFeatures(feature_tag, exceptions) => {
if !exceptions.contains(&feature.name.as_str()) {
feature_tag.apply_bool(feature, value)
}
}
FeatureTags::OnlyFeatures(feature_tag, inclusions) => {
if inclusions.contains(&feature.name.as_str()) {
feature_tag.apply_bool(feature, value)
}
}
FeatureTags::AddExtractedMonth => {
feature.with_extracted_month =
Some(Box::new(Feature::from_tags(&[FeatureTags::Name(
&format!("{}_month", feature.name),
)])))
}
FeatureTags::AddExtractedTimestamp => {
feature.with_extracted_timestamp =
Some(Box::new(Feature::from_tags(&[FeatureTags::Name(
&format!("{}_timestamp", feature.name),
)])))
}
FeatureTags::AddLog10 => {
feature.with_log10 = Some(Box::new(Feature::from_tags(&[FeatureTags::Name(
&format!("log10({})", feature.name),
)])))
}
FeatureTags::AddNormalized => {
feature.with_normalized =
Some(Box::new(Feature::from_tags(&[FeatureTags::Name(
&format!("{}_normalized", feature.name),
)])))
}
FeatureTags::AddSquared => {
feature.with_squared =
Some(Box::new(Feature::from_tags(&[FeatureTags::Name(
&format!("{}^2", feature.name),
)])))
}
FeatureTags::RecurseAdded(feature_tag) => {
for extracted_feature in feature.get_extracted_features_mut().into_iter() {
self.apply_bool(extracted_feature, value)
}
feature_tag.apply_bool(feature, value);
}
};
}
}