/// Top-level configuration, deserialized with serde.
///
/// The container-level `#[serde(default)]` makes every section optional: a
/// section missing from the input is taken from `Config::default()`, which in
/// turn uses each section's own `Default` implementation.
/// `deny_unknown_fields` turns any unrecognized key into a deserialization
/// error.
#[derive(Debug, Default, serde::Deserialize)]
#[serde(default, deny_unknown_fields)]
pub struct Config {
/// Dataset section.
pub dataset: Dataset,
/// Features section.
pub features: Features,
/// Training section.
pub train: Train,
}
/// Dataset section of the configuration.
///
/// Fields missing from the input fall back to the values in this type's
/// `Default` implementation (container-level `#[serde(default)]`); unknown
/// keys are rejected.
#[derive(Debug, serde::Deserialize)]
#[serde(default, deny_unknown_fields)]
pub struct Dataset {
/// Shuffle settings.
pub shuffle: Shuffle,
/// Defaults to `0.1`. Presumably the share of the data held out for
/// comparing models; TODO confirm against the consumer.
pub comparison_fraction: f32,
/// Defaults to `0.2`. Presumably the share of the data held out for
/// testing; TODO confirm against the consumer.
pub test_fraction: f32,
/// Per-column declarations. Defaults to an empty list.
pub columns: Vec<Column>,
}
/// Shuffle settings for the dataset.
///
/// The container-level `#[serde(default)]` lets a configuration specify only
/// some of the fields: anything left out is taken from `Shuffle::default()`
/// instead of failing with a "missing field" error. This matches how
/// `Config` and `Dataset` are deserialized. Unknown keys are still rejected.
#[derive(Debug, serde::Deserialize)]
#[serde(default, deny_unknown_fields)]
pub struct Shuffle {
    /// Whether shuffling is enabled. Defaults to `true`.
    pub enable: bool,
    /// Seed value. Defaults to `42`.
    pub seed: u64,
}
/// A column declaration.
///
/// Internally tagged: the variant is selected by the value of the `type` key
/// (`"unknown"`, `"number"`, `"enum"` or `"text"`), and the remaining keys of
/// the same map populate the variant's payload struct.
#[derive(Debug, serde::Deserialize)]
#[serde(tag = "type")]
pub enum Column {
/// Selected by `type = "unknown"`.
#[serde(rename = "unknown")]
Unknown(UnknownColumn),
/// Selected by `type = "number"`.
#[serde(rename = "number")]
Number(NumberColumn),
/// Selected by `type = "enum"`.
#[serde(rename = "enum")]
Enum(EnumColumn),
/// Selected by `type = "text"`.
#[serde(rename = "text")]
Text(TextColumn),
}
/// Payload of `Column::Unknown`. Unknown keys are rejected.
#[derive(Debug, serde::Deserialize)]
#[serde(deny_unknown_fields)]
pub struct UnknownColumn {
/// Column name (required).
pub name: String,
}
/// Payload of `Column::Number`. Unknown keys are rejected.
#[derive(Debug, serde::Deserialize)]
#[serde(deny_unknown_fields)]
pub struct NumberColumn {
/// Column name (required).
pub name: String,
}
/// Payload of `Column::Enum`. Unknown keys are rejected.
#[derive(Debug, serde::Deserialize)]
#[serde(deny_unknown_fields)]
pub struct EnumColumn {
/// Column name (required).
pub name: String,
/// Required list of strings; presumably the values the column may take.
/// TODO confirm against the consumer.
pub variants: Vec<String>,
}
/// Payload of `Column::Text`. Unknown keys are rejected.
#[derive(Debug, serde::Deserialize)]
#[serde(deny_unknown_fields)]
pub struct TextColumn {
/// Column name (required).
pub name: String,
}
/// Features section of the configuration.
///
/// The container-level `#[serde(default)]` fills any field left out of the
/// input from `Features::default()`, so a configuration may provide only
/// `include` without also having to spell out `auto`. This matches how
/// `Config` and `Dataset` are deserialized. Unknown keys are still rejected.
#[derive(Debug, serde::Deserialize)]
#[serde(default, deny_unknown_fields)]
pub struct Features {
    /// Automatic feature settings. Defaults to `AutoFeatures::default()`.
    pub auto: AutoFeatures,
    /// Explicitly listed feature groups. `None` when absent.
    pub include: Option<Vec<FeatureGroup>>,
}
/// Automatic feature settings.
///
/// The container-level `#[serde(default)]` fills any field left out of the
/// input from `AutoFeatures::default()`, so a configuration may provide only
/// `exclude_columns` and keep the default for `enable`. This matches how
/// `Config` and `Dataset` are deserialized. Unknown keys are still rejected.
#[derive(Debug, serde::Deserialize)]
#[serde(default, deny_unknown_fields)]
pub struct AutoFeatures {
    /// Whether automatic features are enabled. Defaults to `true`.
    pub enable: bool,
    /// Optional list of column names; presumably columns skipped by automatic
    /// feature generation. TODO confirm against the consumer. `None` when
    /// absent.
    pub exclude_columns: Option<Vec<String>>,
}
/// A feature group declaration.
///
/// Internally tagged: the variant is selected by the value of the `type` key,
/// and the remaining keys of the same map populate the variant's payload
/// struct.
#[derive(Debug, serde::Deserialize)]
#[serde(tag = "type")]
pub enum FeatureGroup {
/// Selected by `type = "identity"`.
#[serde(rename = "identity")]
Identity(IdentityFeatureGroup),
/// Selected by `type = "normalized"`.
#[serde(rename = "normalized")]
Normalized(NormalizedFeatureGroup),
/// Selected by `type = "one_hot_encoded"`.
#[serde(rename = "one_hot_encoded")]
OneHotEncoded(OneHotEncodedFeatureGroup),
/// Selected by `type = "bag_of_words"`.
#[serde(rename = "bag_of_words")]
BagOfWords(BagOfWordsFeatureGroup),
/// Selected by `type = "bag_of_words_cosine_similarity"`.
#[serde(rename = "bag_of_words_cosine_similarity")]
BagOfWordsCosineSimilarity(BagOfWordsCosineSimilarityFeatureGroup),
}
/// Payload of `FeatureGroup::Identity`. Unknown keys are rejected.
#[derive(Debug, serde::Deserialize)]
#[serde(deny_unknown_fields)]
pub struct IdentityFeatureGroup {
/// Name of the source column (required).
pub source_column_name: String,
}
/// Payload of `FeatureGroup::Normalized`. Unknown keys are rejected.
#[derive(Debug, serde::Deserialize)]
#[serde(deny_unknown_fields)]
pub struct NormalizedFeatureGroup {
/// Name of the source column (required).
pub source_column_name: String,
}
/// Payload of `FeatureGroup::OneHotEncoded`. Unknown keys are rejected.
#[derive(Debug, serde::Deserialize)]
#[serde(deny_unknown_fields)]
pub struct OneHotEncodedFeatureGroup {
/// Name of the source column (required).
pub source_column_name: String,
}
/// Payload of `FeatureGroup::BagOfWords`. Unknown keys are rejected.
#[derive(Debug, serde::Deserialize)]
#[serde(deny_unknown_fields)]
pub struct BagOfWordsFeatureGroup {
/// Name of the source column (required).
pub source_column_name: String,
/// Optional strategy. `None` when absent; what an unset strategy resolves to
/// is decided by the consumer and is not visible here.
pub strategy: Option<BagOfWordsFeatureGroupStrategy>,
}
/// Strategy option for a bag-of-words feature group.
///
/// Deserialized from one of the strings `"present"`, `"count"` or `"tfidf"`.
/// Presumably selects how token occurrences become feature values; TODO
/// confirm against the feature code.
#[derive(Debug, serde::Deserialize)]
pub enum BagOfWordsFeatureGroupStrategy {
/// Written as `"present"`.
#[serde(rename = "present")]
Present,
/// Written as `"count"`.
#[serde(rename = "count")]
Count,
/// Written as `"tfidf"`.
#[serde(rename = "tfidf")]
TfIdf,
}
/// Payload of `FeatureGroup::BagOfWordsCosineSimilarity`. Both column names
/// are required and unknown keys are rejected.
#[derive(Debug, serde::Deserialize)]
#[serde(deny_unknown_fields)]
pub struct BagOfWordsCosineSimilarityFeatureGroup {
/// Name of the first source column.
pub source_column_name_a: String,
/// Name of the second source column.
pub source_column_name_b: String,
}
/// Training section of the configuration.
///
/// Both fields are `Option`s, so each may be omitted and is then `None`.
/// Unknown keys are rejected.
#[derive(Debug, serde::Deserialize)]
#[serde(deny_unknown_fields)]
pub struct Train {
/// Optional list of grid items.
pub grid: Option<Vec<GridItem>>,
/// Optional comparison metric.
pub comparison_metric: Option<ComparisonMetric>,
}
/// One entry of the training grid.
///
/// Internally tagged: the variant is selected by the value of the `model`
/// key (`"linear"` or `"tree"`), and the remaining keys of the same map
/// populate the variant's payload struct.
#[derive(Debug, serde::Deserialize)]
#[serde(tag = "model")]
pub enum GridItem {
/// Selected by `model = "linear"`.
#[serde(rename = "linear")]
Linear(LinearGridItem),
/// Selected by `model = "tree"`.
#[serde(rename = "tree")]
Tree(TreeGridItem),
}
/// Payload of `GridItem::Linear`.
///
/// Every field is optional and is `None` when absent from the input; how an
/// unset value is resolved is up to the consumer and is not visible here.
/// Unknown keys are rejected.
#[derive(Debug, serde::Deserialize)]
#[serde(deny_unknown_fields)]
pub struct LinearGridItem {
/// Optional early stopping settings.
pub early_stopping_options: Option<EarlyStoppingOptions>,
pub l2_regularization: Option<f32>,
pub learning_rate: Option<f32>,
pub max_epochs: Option<u64>,
pub n_examples_per_batch: Option<u64>,
}
/// Payload of `GridItem::Tree`.
///
/// Every field is optional and is `None` when absent from the input; how an
/// unset value is resolved is up to the consumer and is not visible here.
/// Unknown keys are rejected.
#[derive(Debug, serde::Deserialize)]
#[serde(deny_unknown_fields)]
pub struct TreeGridItem {
/// Optional binned features layout.
pub binned_features_layout: Option<BinnedFeaturesLayout>,
/// Optional early stopping settings.
pub early_stopping_options: Option<EarlyStoppingOptions>,
pub l2_regularization_for_continuous_splits: Option<f32>,
pub l2_regularization_for_discrete_splits: Option<f32>,
pub learning_rate: Option<f32>,
pub max_depth: Option<u64>,
pub max_examples_for_computing_bin_thresholds: Option<u64>,
pub max_leaf_nodes: Option<u64>,
pub max_rounds: Option<u64>,
/// Stored as `u8`, so values above 255 fail to deserialize.
pub max_valid_bins_for_number_features: Option<u8>,
pub min_examples_per_node: Option<u64>,
pub min_gain_to_split: Option<f32>,
pub min_sum_hessians_per_node: Option<f32>,
pub smoothing_factor_for_discrete_bin_sorting: Option<f32>,
}
#[derive(Debug, serde::Deserialize)]
pub enum BinnedFeaturesLayout {
RowMajor,
ColumnMajor,
}
/// Early stopping settings shared by the linear and tree grid items.
///
/// There is no `#[serde(default)]` here and no field is an `Option`, so all
/// three fields must be present whenever this section is given. Unknown keys
/// are rejected.
#[derive(Debug, serde::Deserialize)]
#[serde(deny_unknown_fields)]
pub struct EarlyStoppingOptions {
/// Presumably the share of the training data held out to monitor the loss;
/// TODO confirm against the trainer.
pub early_stopping_fraction: f32,
/// Note this is a `usize`, unlike the `u64` counters in the grid items.
pub n_rounds_without_improvement_to_stop: usize,
pub min_decrease_in_loss_for_significant_change: f32,
}
/// Metric option for the training section.
///
/// Deserialized from the lowercase strings given in the `rename` attributes.
/// The `Display` implementation yields a human-readable label for each
/// variant.
#[derive(Debug, Clone, serde::Deserialize)]
pub enum ComparisonMetric {
/// Written as `"mae"`; displayed as "Mean Absolute Error".
#[serde(rename = "mae")]
Mae,
/// Written as `"mse"`; displayed as "Mean Squared Error".
#[serde(rename = "mse")]
Mse,
/// Written as `"rmse"`; displayed as "Root Mean Squared Error".
#[serde(rename = "rmse")]
Rmse,
/// Written as `"r2"`; displayed as "R2".
#[serde(rename = "r2")]
R2,
/// Written as `"accuracy"`; displayed as "Accuracy".
#[serde(rename = "accuracy")]
Accuracy,
/// Written as `"auc"`; displayed as "Area Under the Receiver Operating
/// Characteristic Curve".
#[serde(rename = "auc")]
Auc,
/// Written as `"f1"`; displayed as "F1".
#[serde(rename = "f1")]
F1,
}
impl Default for Dataset {
fn default() -> Self {
Dataset {
comparison_fraction: 0.1,
test_fraction: 0.2,
shuffle: Default::default(),
columns: Default::default(),
}
}
}
impl Default for Shuffle {
fn default() -> Self {
Shuffle {
enable: true,
seed: 42,
}
}
}
impl Default for Features {
fn default() -> Self {
Features {
auto: Default::default(),
include: Default::default(),
}
}
}
impl Default for AutoFeatures {
fn default() -> Self {
AutoFeatures {
enable: true,
exclude_columns: Default::default(),
}
}
}
impl Default for Train {
fn default() -> Self {
Train {
grid: Default::default(),
comparison_metric: Default::default(),
}
}
}
impl std::fmt::Display for ComparisonMetric {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let s = match self {
ComparisonMetric::Mae => "Mean Absolute Error",
ComparisonMetric::Mse => "Mean Squared Error",
ComparisonMetric::Rmse => "Root Mean Squared Error",
ComparisonMetric::R2 => "R2",
ComparisonMetric::Accuracy => "Accuracy",
ComparisonMetric::Auc => "Area Under the Receiver Operating Characteristic Curve",
ComparisonMetric::F1 => "F1",
};
write!(f, "{}", s)
}
}