Struct linfa_trees::DecisionTreeParams[][src]

pub struct DecisionTreeParams<F, L> {
    pub split_quality: SplitQuality,
    pub max_depth: Option<usize>,
    pub min_weight_split: f32,
    pub min_weight_leaf: f32,
    pub min_impurity_decrease: F,
    pub phantom: PhantomData<L>,
}

The set of hyperparameters that can be specified for fitting a decision tree.

Example

use linfa_trees::{DecisionTree, SplitQuality};
use linfa_datasets::iris;
use linfa::prelude::*;

// Initialize the default set of parameters
let params = DecisionTree::params();
// Set the parameters to the desired values
let params = params.split_quality(SplitQuality::Entropy).max_depth(Some(5)).min_weight_leaf(2.);

// Load the data
let (train, val) = linfa_datasets::iris().split_with_ratio(0.9);
// Fit the decision tree on the training data
let tree = params.fit(&train).unwrap();
// Predict on validation and check accuracy
let val_accuracy = tree.predict(&val).confusion_matrix(&val).unwrap().accuracy();
assert!(val_accuracy > 0.99);

Fields

split_quality: SplitQuality
max_depth: Option<usize>
min_weight_split: f32
min_weight_leaf: f32
min_impurity_decrease: F
phantom: PhantomData<L>

Implementations

impl<F: Float, L: Label> DecisionTreeParams<F, L>[src]

pub fn split_quality(self, split_quality: SplitQuality) -> Self[src]

Sets the metric used to decide the feature on which to split a node.

pub fn max_depth(self, max_depth: Option<usize>) -> Self[src]

Sets an optional limit on the depth of the decision tree (`None` means no limit).

pub fn min_weight_split(self, min_weight_split: f32) -> Self[src]

Sets the minimum weight of samples required to split a node.

If the observations do not have associated weights, this value represents the minimum number of samples required to split a node.

pub fn min_weight_leaf(self, min_weight_leaf: f32) -> Self[src]

Sets the minimum weight of samples that a split has to place in each leaf.

If the observations do not have associated weights, this value represents the minimum number of samples that a split has to place in each leaf.

pub fn min_impurity_decrease(self, min_impurity_decrease: F) -> Self[src]

Sets the minimum decrease in impurity that a split must achieve in order to be applied.

pub fn validate(&self) -> Result<()>[src]

Checks the correctness of the hyperparameters

Errors

Returns an error if the minimum impurity decrease is not greater than zero.

Trait Implementations

impl<F: Clone, L: Clone> Clone for DecisionTreeParams<F, L>[src]

impl<F: Copy, L: Copy> Copy for DecisionTreeParams<F, L>[src]

impl<F: Debug, L: Debug> Debug for DecisionTreeParams<F, L>[src]

impl<'a, F: Float, L: Label + 'a + Debug, D, T> Fit<'a, ArrayBase<D, Dim<[usize; 2]>>, T> for DecisionTreeParams<F, L> where
    D: Data<Elem = F>,
    T: AsTargets<Elem = L> + Labels<Elem = L>, 
[src]

type Object = Result<DecisionTree<F, L>>

fn fit(&self, dataset: &DatasetBase<ArrayBase<D, Ix2>, T>) -> Self::Object[src]

Fits a decision tree using these hyperparameters on a dataset consisting of a feature matrix `x` and an array of labels `y`.

Auto Trait Implementations

impl<F, L> RefUnwindSafe for DecisionTreeParams<F, L> where
    F: RefUnwindSafe,
    L: RefUnwindSafe

impl<F, L> Send for DecisionTreeParams<F, L> where
    F: Send,
    L: Send

impl<F, L> Sync for DecisionTreeParams<F, L> where
    F: Sync,
    L: Sync

impl<F, L> Unpin for DecisionTreeParams<F, L> where
    F: Unpin,
    L: Unpin

impl<F, L> UnwindSafe for DecisionTreeParams<F, L> where
    F: UnwindSafe,
    L: UnwindSafe

Blanket Implementations

impl<T> Any for T where
    T: 'static + ?Sized
[src]

impl<T> Borrow<T> for T where
    T: ?Sized
[src]

impl<T> BorrowMut<T> for T where
    T: ?Sized
[src]

impl<T> From<T> for T[src]

impl<T, U> Into<U> for T where
    U: From<T>, 
[src]

impl<T> Pointable for T

type Init = T

The type for initializers.

impl<T> ToOwned for T where
    T: Clone
[src]

type Owned = T

The resulting type after obtaining ownership.

impl<T, U> TryFrom<U> for T where
    U: Into<T>, 
[src]

type Error = Infallible

The type returned in the event of a conversion error.

impl<T, U> TryInto<U> for T where
    U: TryFrom<T>, 
[src]

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.

impl<V, T> VZip<V> for T where
    V: MultiLane<T>,