use super::{
BernoulliNBParameters, CategoricalNBParameters, DecisionTreeClassifierParameters,
FinalAlgorithm, GaussianNBParameters, KNNParameters, LogisticRegressionParameters, Metric,
MultinomialNBParameters, PreProcessing, RandomForestClassifierParameters, SVCParameters,
SettingsError, SupervisedSettings, WithSupervisedSettings,
};
use crate::settings::macros::with_settings_methods;
use smartcore::linalg::basic::arrays::Array1;
use smartcore::metrics::accuracy;
use smartcore::model_selection::KFold;
use smartcore::numbers::basenum::Number;
/// Settings for classification runs.
///
/// Combines the task-independent [`SupervisedSettings`] with one optional parameter set per
/// supported classifier. See the `Default` impl for which entries start populated.
///
/// NOTE(review): a `None` entry appears to mean that classifier is left out of the run —
/// confirm against the code that consumes these settings.
pub struct ClassificationSettings {
/// Shared supervised-learning settings (fold count, shuffling, verbosity, preprocessing,
/// final-model choice and the `sort_by` ranking metric).
pub(crate) supervised: SupervisedSettings,
/// Parameters for the k-nearest-neighbours classifier, or `None` when not configured.
pub(crate) knn_classifier_settings: Option<KNNParameters>,
/// Parameters for the decision-tree classifier, or `None` when not configured.
pub(crate) decision_tree_classifier_settings: Option<DecisionTreeClassifierParameters>,
/// Parameters for the random-forest classifier, or `None` when not configured.
pub(crate) random_forest_classifier_settings: Option<RandomForestClassifierParameters>,
/// Parameters for logistic regression (`f64` hyper-parameters), or `None` when not configured.
pub(crate) logistic_regression_settings: Option<LogisticRegressionParameters<f64>>,
/// Parameters for Bernoulli naive Bayes (`f64` hyper-parameters), or `None` when not configured.
pub(crate) bernoulli_nb_settings: Option<BernoulliNBParameters<f64>>,
/// Parameters for Gaussian naive Bayes, or `None` when not configured.
pub(crate) gaussian_nb_settings: Option<GaussianNBParameters>,
/// Parameters for categorical naive Bayes, or `None` when not configured.
pub(crate) categorical_nb_settings: Option<CategoricalNBParameters>,
/// Parameters for multinomial naive Bayes, or `None` when not configured.
pub(crate) multinomial_nb_settings: Option<MultinomialNBParameters>,
/// Parameters for the support-vector classifier, or `None` when not configured.
pub(crate) svc_settings: Option<SVCParameters>,
}
impl Default for ClassificationSettings {
    /// Builds settings ranked by [`Metric::Accuracy`].
    ///
    /// KNN, decision tree, random forest, logistic regression and Gaussian naive Bayes start
    /// with their default parameters. Bernoulli, categorical and multinomial naive Bayes and
    /// SVC start unset (`None`).
    fn default() -> Self {
        // Everything except the ranking metric comes from the generic supervised defaults.
        let supervised = SupervisedSettings {
            sort_by: Metric::Accuracy,
            ..SupervisedSettings::default()
        };

        Self {
            supervised,
            // Populated with each parameter type's own `Default`.
            knn_classifier_settings: Some(Default::default()),
            decision_tree_classifier_settings: Some(Default::default()),
            random_forest_classifier_settings: Some(Default::default()),
            logistic_regression_settings: Some(Default::default()),
            // Left unset.
            bernoulli_nb_settings: None,
            gaussian_nb_settings: Some(Default::default()),
            categorical_nb_settings: None,
            multinomial_nb_settings: None,
            svc_settings: None,
        }
    }
}
impl ClassificationSettings {
pub fn get_metric<OUTPUT, OutputArray>(
&self,
) -> Result<fn(&OutputArray, &OutputArray) -> f64, SettingsError>
where
OUTPUT: Number + Ord,
OutputArray: Array1<OUTPUT>,
{
match self.supervised.sort_by {
Metric::Accuracy => Ok(accuracy),
Metric::None => Err(SettingsError::MetricNotSet),
m => Err(SettingsError::UnsupportedMetric(m)),
}
}
with_settings_methods! {
with_knn_classifier_settings, knn_classifier_settings, KNNParameters;
with_decision_tree_classifier_settings, decision_tree_classifier_settings, DecisionTreeClassifierParameters;
with_random_forest_classifier_settings, random_forest_classifier_settings, RandomForestClassifierParameters;
with_logistic_regression_settings, logistic_regression_settings, LogisticRegressionParameters<f64>;
with_svc_settings, svc_settings, SVCParameters;
}
#[must_use]
pub fn with_bernoulli_nb_settings(mut self, settings: BernoulliNBParameters<f64>) -> Self {
self.bernoulli_nb_settings = Some(settings);
self
}
#[must_use]
pub fn with_gaussian_nb_settings(mut self, settings: GaussianNBParameters) -> Self {
self.gaussian_nb_settings = Some(settings);
self
}
#[must_use]
pub fn with_categorical_nb_settings(mut self, settings: CategoricalNBParameters) -> Self {
self.categorical_nb_settings = Some(settings);
self
}
#[must_use]
pub fn with_multinomial_nb_settings(mut self, settings: MultinomialNBParameters) -> Self {
self.multinomial_nb_settings = Some(settings);
self
}
#[must_use]
pub fn with_number_of_folds(self, n: usize) -> Self {
<Self as WithSupervisedSettings>::with_number_of_folds(self, n)
}
#[must_use]
pub fn shuffle_data(self, shuffle: bool) -> Self {
<Self as WithSupervisedSettings>::shuffle_data(self, shuffle)
}
#[must_use]
pub fn verbose(self, verbose: bool) -> Self {
<Self as WithSupervisedSettings>::verbose(self, verbose)
}
#[must_use]
pub fn with_preprocessing(self, pre: PreProcessing) -> Self {
<Self as WithSupervisedSettings>::with_preprocessing(self, pre)
}
#[must_use]
pub fn with_final_model(self, approach: FinalAlgorithm) -> Self {
<Self as WithSupervisedSettings>::with_final_model(self, approach)
}
#[must_use]
pub fn sorted_by(self, sort_by: Metric) -> Self {
<Self as WithSupervisedSettings>::sorted_by(self, sort_by)
}
#[must_use]
pub fn get_kfolds(&self) -> KFold {
<Self as WithSupervisedSettings>::get_kfolds(self)
}
}
/// Exposes the embedded [`SupervisedSettings`] so the trait's provided builder methods can
/// operate on `ClassificationSettings`.
impl WithSupervisedSettings for ClassificationSettings {
    /// Mutable access to the shared supervised settings.
    fn supervised_mut(&mut self) -> &mut SupervisedSettings {
        let shared = &mut self.supervised;
        shared
    }

    /// Read-only access to the shared supervised settings.
    fn supervised(&self) -> &SupervisedSettings {
        let shared = &self.supervised;
        shared
    }
}
mod serde_impls {
    //! Hand-written `serde` support for [`ClassificationSettings`].
    //!
    //! Deserialization starts from `ClassificationSettings::default()` and overrides only the
    //! fields present in the input, so partial documents are accepted. For the per-algorithm
    //! settings, an explicit null/`None` value replaces the default with `None`. This differs
    //! from omitting the field, which keeps the default.
    //!
    //! Both map-shaped input (self-describing formats such as JSON) and sequence-shaped input
    //! (formats that encode structs positionally, in `FIELDS` order) are accepted. This lets
    //! values written by the `Serialize` impl round-trip through either kind of format.
    use super::{
        BernoulliNBParameters, CategoricalNBParameters, ClassificationSettings,
        DecisionTreeClassifierParameters, GaussianNBParameters, KNNParameters,
        LogisticRegressionParameters, MultinomialNBParameters, RandomForestClassifierParameters,
        SVCParameters, SupervisedSettings,
    };
    use serde::de::{self, MapAccess, SeqAccess, Unexpected, Visitor};
    use serde::ser::SerializeStruct;
    use serde::{Deserialize, Deserializer, Serialize, Serializer};
    use std::fmt;

    /// Serialized field names, in the order `Serialize` writes them. Sequence-shaped input is
    /// read positionally in this same order, and field indices (`visit_u64`) follow it too.
    const FIELDS: &[&str] = &[
        "supervised",
        "knn_classifier_settings",
        "decision_tree_classifier_settings",
        "random_forest_classifier_settings",
        "logistic_regression_settings",
        "bernoulli_nb_settings",
        "gaussian_nb_settings",
        "categorical_nb_settings",
        "multinomial_nb_settings",
        "svc_settings",
    ];

    impl Serialize for ClassificationSettings {
        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
        where
            S: Serializer,
        {
            // Field order must match `FIELDS` so positional formats can read it back.
            let mut state = serializer.serialize_struct("ClassificationSettings", FIELDS.len())?;
            state.serialize_field("supervised", &self.supervised)?;
            state.serialize_field("knn_classifier_settings", &self.knn_classifier_settings)?;
            state.serialize_field(
                "decision_tree_classifier_settings",
                &self.decision_tree_classifier_settings,
            )?;
            state.serialize_field(
                "random_forest_classifier_settings",
                &self.random_forest_classifier_settings,
            )?;
            state.serialize_field(
                "logistic_regression_settings",
                &self.logistic_regression_settings,
            )?;
            state.serialize_field("bernoulli_nb_settings", &self.bernoulli_nb_settings)?;
            state.serialize_field("gaussian_nb_settings", &self.gaussian_nb_settings)?;
            state.serialize_field("categorical_nb_settings", &self.categorical_nb_settings)?;
            state.serialize_field("multinomial_nb_settings", &self.multinomial_nb_settings)?;
            state.serialize_field("svc_settings", &self.svc_settings)?;
            state.end()
        }
    }

    /// Identifies one serialized field of `ClassificationSettings`.
    enum Field {
        Supervised,
        KnnClassifierSettings,
        DecisionTreeClassifierSettings,
        RandomForestClassifierSettings,
        LogisticRegressionSettings,
        BernoulliNbSettings,
        GaussianNbSettings,
        CategoricalNbSettings,
        MultinomialNbSettings,
        SvcSettings,
    }

    impl<'de> Deserialize<'de> for Field {
        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
        where
            D: Deserializer<'de>,
        {
            deserializer.deserialize_identifier(FieldVisitor)
        }
    }

    /// Maps a field name, or a positional field index for formats that use one, to a
    /// [`Field`].
    struct FieldVisitor;

    #[allow(clippy::elidable_lifetime_names)]
    impl<'de> Visitor<'de> for FieldVisitor {
        type Value = Field;

        fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
            formatter.write_str("a valid field name for ClassificationSettings")
        }

        fn visit_u64<E>(self, value: u64) -> Result<Self::Value, E>
        where
            E: de::Error,
        {
            // Indices follow the order of `FIELDS`.
            match value {
                0 => Ok(Field::Supervised),
                1 => Ok(Field::KnnClassifierSettings),
                2 => Ok(Field::DecisionTreeClassifierSettings),
                3 => Ok(Field::RandomForestClassifierSettings),
                4 => Ok(Field::LogisticRegressionSettings),
                5 => Ok(Field::BernoulliNbSettings),
                6 => Ok(Field::GaussianNbSettings),
                7 => Ok(Field::CategoricalNbSettings),
                8 => Ok(Field::MultinomialNbSettings),
                9 => Ok(Field::SvcSettings),
                _ => Err(de::Error::invalid_value(
                    Unexpected::Unsigned(value),
                    &"field index 0 <= i < 10",
                )),
            }
        }

        fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
        where
            E: de::Error,
        {
            match value {
                "supervised" => Ok(Field::Supervised),
                "knn_classifier_settings" => Ok(Field::KnnClassifierSettings),
                "decision_tree_classifier_settings" => Ok(Field::DecisionTreeClassifierSettings),
                "random_forest_classifier_settings" => Ok(Field::RandomForestClassifierSettings),
                "logistic_regression_settings" => Ok(Field::LogisticRegressionSettings),
                "bernoulli_nb_settings" => Ok(Field::BernoulliNbSettings),
                "gaussian_nb_settings" => Ok(Field::GaussianNbSettings),
                "categorical_nb_settings" => Ok(Field::CategoricalNbSettings),
                "multinomial_nb_settings" => Ok(Field::MultinomialNbSettings),
                "svc_settings" => Ok(Field::SvcSettings),
                other => Err(de::Error::unknown_field(other, FIELDS)),
            }
        }
    }

    /// Fields collected while deserializing.
    ///
    /// The outer `Option` records whether the field appeared in the input at all. For the
    /// per-algorithm settings, the inner `Option` is the deserialized value itself, so an
    /// explicit null/`None` can clear a default.
    #[derive(Default)]
    struct PartialSettings {
        supervised: Option<SupervisedSettings>,
        knn_classifier_settings: Option<Option<KNNParameters>>,
        decision_tree_classifier_settings: Option<Option<DecisionTreeClassifierParameters>>,
        random_forest_classifier_settings: Option<Option<RandomForestClassifierParameters>>,
        logistic_regression_settings: Option<Option<LogisticRegressionParameters<f64>>>,
        bernoulli_nb_settings: Option<Option<BernoulliNBParameters<f64>>>,
        gaussian_nb_settings: Option<Option<GaussianNBParameters>>,
        categorical_nb_settings: Option<Option<CategoricalNBParameters>>,
        multinomial_nb_settings: Option<Option<MultinomialNBParameters>>,
        svc_settings: Option<Option<SVCParameters>>,
    }

    impl PartialSettings {
        /// Overlays every field that was present onto `ClassificationSettings::default()`.
        fn into_settings(self) -> ClassificationSettings {
            let mut settings = ClassificationSettings::default();
            if let Some(value) = self.supervised {
                settings.supervised = value;
            }
            if let Some(value) = self.knn_classifier_settings {
                settings.knn_classifier_settings = value;
            }
            if let Some(value) = self.decision_tree_classifier_settings {
                settings.decision_tree_classifier_settings = value;
            }
            if let Some(value) = self.random_forest_classifier_settings {
                settings.random_forest_classifier_settings = value;
            }
            if let Some(value) = self.logistic_regression_settings {
                settings.logistic_regression_settings = value;
            }
            if let Some(value) = self.bernoulli_nb_settings {
                settings.bernoulli_nb_settings = value;
            }
            if let Some(value) = self.gaussian_nb_settings {
                settings.gaussian_nb_settings = value;
            }
            if let Some(value) = self.categorical_nb_settings {
                settings.categorical_nb_settings = value;
            }
            if let Some(value) = self.multinomial_nb_settings {
                settings.multinomial_nb_settings = value;
            }
            if let Some(value) = self.svc_settings {
                settings.svc_settings = value;
            }
            settings
        }
    }

    /// Reads the next map value into `slot`.
    ///
    /// # Errors
    ///
    /// Returns a `duplicate_field` error if `slot` was already filled by an earlier key.
    /// Otherwise propagates any error from deserializing the value.
    fn read_once<'de, A, T>(
        map: &mut A,
        slot: &mut Option<T>,
        name: &'static str,
    ) -> Result<(), A::Error>
    where
        A: MapAccess<'de>,
        T: Deserialize<'de>,
    {
        if slot.is_some() {
            return Err(de::Error::duplicate_field(name));
        }
        *slot = Some(map.next_value()?);
        Ok(())
    }

    /// Builds `ClassificationSettings` from either map-shaped or sequence-shaped input.
    struct ClassificationSettingsVisitor;

    #[allow(clippy::elidable_lifetime_names)]
    impl<'de> Visitor<'de> for ClassificationSettingsVisitor {
        type Value = ClassificationSettings;

        fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
            formatter.write_str("a map describing ClassificationSettings")
        }

        fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
        where
            A: MapAccess<'de>,
        {
            let mut partial = PartialSettings::default();
            while let Some(key) = map.next_key::<Field>()? {
                match key {
                    Field::Supervised => {
                        read_once(&mut map, &mut partial.supervised, "supervised")?;
                    }
                    Field::KnnClassifierSettings => read_once(
                        &mut map,
                        &mut partial.knn_classifier_settings,
                        "knn_classifier_settings",
                    )?,
                    Field::DecisionTreeClassifierSettings => read_once(
                        &mut map,
                        &mut partial.decision_tree_classifier_settings,
                        "decision_tree_classifier_settings",
                    )?,
                    Field::RandomForestClassifierSettings => read_once(
                        &mut map,
                        &mut partial.random_forest_classifier_settings,
                        "random_forest_classifier_settings",
                    )?,
                    Field::LogisticRegressionSettings => read_once(
                        &mut map,
                        &mut partial.logistic_regression_settings,
                        "logistic_regression_settings",
                    )?,
                    Field::BernoulliNbSettings => read_once(
                        &mut map,
                        &mut partial.bernoulli_nb_settings,
                        "bernoulli_nb_settings",
                    )?,
                    Field::GaussianNbSettings => read_once(
                        &mut map,
                        &mut partial.gaussian_nb_settings,
                        "gaussian_nb_settings",
                    )?,
                    Field::CategoricalNbSettings => read_once(
                        &mut map,
                        &mut partial.categorical_nb_settings,
                        "categorical_nb_settings",
                    )?,
                    Field::MultinomialNbSettings => read_once(
                        &mut map,
                        &mut partial.multinomial_nb_settings,
                        "multinomial_nb_settings",
                    )?,
                    Field::SvcSettings => {
                        read_once(&mut map, &mut partial.svc_settings, "svc_settings")?;
                    }
                }
            }
            Ok(partial.into_settings())
        }

        fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
        where
            A: SeqAccess<'de>,
        {
            // Positional formats (e.g. bincode, postcard) hand structs over as a sequence in
            // `FIELDS` order. Struct-literal fields are evaluated in the order written, so they
            // must stay in that order. A sequence that ends early leaves the remaining fields
            // at their defaults, mirroring how missing keys are treated in `visit_map`.
            let partial = PartialSettings {
                supervised: seq.next_element()?,
                knn_classifier_settings: seq.next_element()?,
                decision_tree_classifier_settings: seq.next_element()?,
                random_forest_classifier_settings: seq.next_element()?,
                logistic_regression_settings: seq.next_element()?,
                bernoulli_nb_settings: seq.next_element()?,
                gaussian_nb_settings: seq.next_element()?,
                categorical_nb_settings: seq.next_element()?,
                multinomial_nb_settings: seq.next_element()?,
                svc_settings: seq.next_element()?,
            };
            Ok(partial.into_settings())
        }
    }

    impl<'de> Deserialize<'de> for ClassificationSettings {
        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
        where
            D: Deserializer<'de>,
        {
            deserializer.deserialize_struct(
                "ClassificationSettings",
                FIELDS,
                ClassificationSettingsVisitor,
            )
        }
    }
}