forestfire-core 0.1.0

Core tree-learning algorithms for the ForestFire project.
Documentation
use crate::{
    Criterion, GradientBoostedTrees, Model, Parallelism, RandomForest, Task, TrainAlgorithm,
    TrainConfig, TrainError, TreeType, tree,
};
use forestfire_data::TableAccess;
use rayon::ThreadPoolBuilder;

/// Trains a model on `train_set` as described by `config`.
///
/// The split criterion and the thread count are resolved first. The depth and
/// minimum-sample limits are then defaulted (8, 2 and 1 respectively) and
/// rejected when they are zero. Finally the trainer selected by
/// `config.algorithm` is run through `run_with_parallelism`.
///
/// # Errors
///
/// Propagates the error from criterion or parallelism resolution, returns
/// `TrainError::InvalidMaxDepth`, `TrainError::InvalidMinSamplesSplit` or
/// `TrainError::InvalidMinSamplesLeaf` for a zero limit (checked in that
/// order), and otherwise forwards whatever the dispatched trainer reports.
pub fn train(train_set: &dyn TableAccess, config: TrainConfig) -> Result<Model, TrainError> {
    const DEFAULT_MAX_DEPTH: usize = 8;
    const DEFAULT_MIN_SAMPLES_SPLIT: usize = 2;
    const DEFAULT_MIN_SAMPLES_LEAF: usize = 1;
    const DEFAULT_N_TREES: usize = 1000;

    let criterion = resolve_criterion(
        config.algorithm,
        config.task,
        config.tree_type,
        config.criterion,
    )?;
    let parallelism = resolve_parallelism(config.physical_cores)?;

    // Each limit falls back to its default when unset; zero is never accepted.
    let max_depth = match config.max_depth.unwrap_or(DEFAULT_MAX_DEPTH) {
        0 => return Err(TrainError::InvalidMaxDepth(0)),
        depth => depth,
    };
    let min_samples_split = match config.min_samples_split.unwrap_or(DEFAULT_MIN_SAMPLES_SPLIT) {
        0 => return Err(TrainError::InvalidMinSamplesSplit(0)),
        count => count,
    };
    let min_samples_leaf = match config.min_samples_leaf.unwrap_or(DEFAULT_MIN_SAMPLES_LEAF) {
        0 => return Err(TrainError::InvalidMinSamplesLeaf(0)),
        count => count,
    };

    // The dispatch is deferred so that it executes under `run_with_parallelism`.
    let dispatch = || match config.algorithm {
        TrainAlgorithm::Dt => {
            let tree_config = SingleModelConfig {
                task: config.task,
                tree_type: config.tree_type,
                criterion,
                parallelism,
                max_depth,
                min_samples_split,
                min_samples_leaf,
            };
            train_single_model(train_set, tree_config)
        }
        TrainAlgorithm::Rf => {
            let forest_config = RandomForestConfig {
                task: config.task,
                tree_type: config.tree_type,
                criterion,
                parallelism,
                n_trees: config.n_trees.unwrap_or(DEFAULT_N_TREES),
                max_depth,
                min_samples_split,
                min_samples_leaf,
                max_features: config.max_features,
                seed: config.seed,
                compute_oob: config.compute_oob,
            };
            train_random_forest(train_set, forest_config)
        }
        TrainAlgorithm::Gbm => {
            // Boosting receives the caller's config with only the criterion
            // replaced by its resolved value.
            let boosting_config = TrainConfig {
                criterion,
                ..config
            };
            train_gradient_boosting(train_set, boosting_config, parallelism)
        }
    };

    run_with_parallelism(parallelism, dispatch)
}

/// Settings for training one decision tree.
///
/// `train` builds this for `TrainAlgorithm::Dt` after resolving the criterion
/// and applying defaults; `train_single_model` consumes it.
pub(crate) struct SingleModelConfig {
    /// Task the tree is trained for; selects the classifier or regressor trainer.
    pub(crate) task: Task,
    /// Tree variant; together with `task` and `criterion` it selects the trainer.
    pub(crate) tree_type: TreeType,
    /// Split criterion handed to the selected trainer.
    pub(crate) criterion: Criterion,
    /// Parallelism settings handed to the selected trainer.
    pub(crate) parallelism: Parallelism,
    /// Copied into the tree options as `max_depth`.
    pub(crate) max_depth: usize,
    /// Copied into the tree options as `min_samples_split`.
    pub(crate) min_samples_split: usize,
    /// Copied into the tree options as `min_samples_leaf`.
    pub(crate) min_samples_leaf: usize,
}

/// A [`SingleModelConfig`] extended with feature-subset options.
///
/// Consumed by `train_single_model_with_feature_subset`, which copies the two
/// extra fields into the classifier and regressor tree options.
pub(crate) struct SingleModelFeatureSubsetConfig {
    /// The underlying single-tree settings.
    pub(crate) base: SingleModelConfig,
    /// Copied into the tree options as `max_features`; `train_single_model`
    /// passes `None`.
    pub(crate) max_features: Option<usize>,
    /// Copied into the tree options as `random_seed`; `train_single_model`
    /// passes `0`.
    pub(crate) random_seed: u64,
}

/// Settings gathered by `train` for the random-forest path.
///
/// `train_random_forest` repackages these fields into a `TrainConfig` for
/// `RandomForest::train`.
struct RandomForestConfig {
    /// Task forwarded to the forest trainer.
    task: Task,
    /// Tree variant forwarded to the forest trainer.
    tree_type: TreeType,
    /// Split criterion; passed both inside the `TrainConfig` and as a separate
    /// argument to `RandomForest::train`.
    criterion: Criterion,
    /// Parallelism settings passed to `RandomForest::train`.
    parallelism: Parallelism,
    /// Number of trees; `train` falls back to 1000 when the caller sets none.
    n_trees: usize,
    /// Forwarded as `Some(max_depth)`.
    max_depth: usize,
    /// Forwarded as `Some(min_samples_split)`.
    min_samples_split: usize,
    /// Forwarded as `Some(min_samples_leaf)`.
    min_samples_leaf: usize,
    /// Forwarded unchanged from the caller's `TrainConfig`.
    max_features: crate::MaxFeatures,
    /// Forwarded unchanged from the caller's `TrainConfig`.
    seed: Option<u64>,
    /// Forwarded unchanged from the caller's `TrainConfig`.
    compute_oob: bool,
}

/// Trains a single tree from `config` without any feature-subset options.
///
/// This wraps `config` in a [`SingleModelFeatureSubsetConfig`] with
/// `max_features: None` and `random_seed: 0`, then delegates to
/// [`train_single_model_with_feature_subset`].
///
/// # Errors
///
/// Returns whatever the delegated call returns.
pub(crate) fn train_single_model(
    train_set: &dyn TableAccess,
    config: SingleModelConfig,
) -> Result<Model, TrainError> {
    let without_subset = SingleModelFeatureSubsetConfig {
        base: config,
        max_features: None,
        random_seed: 0,
    };

    train_single_model_with_feature_subset(train_set, without_subset)
}

pub(crate) fn train_single_model_with_feature_subset(
    train_set: &dyn TableAccess,
    config: SingleModelFeatureSubsetConfig,
) -> Result<Model, TrainError> {
    let SingleModelFeatureSubsetConfig {
        base:
            SingleModelConfig {
                task,
                tree_type,
                criterion,
                parallelism,
                max_depth,
                min_samples_split,
                min_samples_leaf,
            },
        max_features,
        random_seed,
    } = config;
    let classifier_options = tree::classifier::DecisionTreeOptions {
        max_depth,
        min_samples_split,
        min_samples_leaf,
        max_features,
        random_seed,
    };
    let regressor_options = tree::regressor::RegressionTreeOptions {
        max_depth,
        min_samples_split,
        min_samples_leaf,
        max_features,
        random_seed,
    };

    match (task, tree_type, criterion) {
        (Task::Classification, TreeType::Id3, Criterion::Gini)
        | (Task::Classification, TreeType::Id3, Criterion::Entropy) => {
            tree::classifier::train_id3_with_criterion_parallelism_and_options(
                train_set,
                criterion,
                parallelism,
                classifier_options,
            )
            .map(Model::DecisionTreeClassifier)
            .map_err(TrainError::DecisionTree)
        }
        (Task::Classification, TreeType::C45, Criterion::Gini)
        | (Task::Classification, TreeType::C45, Criterion::Entropy) => {
            tree::classifier::train_c45_with_criterion_parallelism_and_options(
                train_set,
                criterion,
                parallelism,
                classifier_options,
            )
            .map(Model::DecisionTreeClassifier)
            .map_err(TrainError::DecisionTree)
        }
        (Task::Classification, TreeType::Cart, Criterion::Gini)
        | (Task::Classification, TreeType::Cart, Criterion::Entropy) => {
            tree::classifier::train_cart_with_criterion_parallelism_and_options(
                train_set,
                criterion,
                parallelism,
                classifier_options,
            )
            .map(Model::DecisionTreeClassifier)
            .map_err(TrainError::DecisionTree)
        }
        (Task::Classification, TreeType::Randomized, Criterion::Gini)
        | (Task::Classification, TreeType::Randomized, Criterion::Entropy) => {
            tree::classifier::train_randomized_with_criterion_parallelism_and_options(
                train_set,
                criterion,
                parallelism,
                classifier_options,
            )
            .map(Model::DecisionTreeClassifier)
            .map_err(TrainError::DecisionTree)
        }
        (Task::Classification, TreeType::Oblivious, Criterion::Gini)
        | (Task::Classification, TreeType::Oblivious, Criterion::Entropy) => {
            tree::classifier::train_oblivious_with_criterion_parallelism_and_options(
                train_set,
                criterion,
                parallelism,
                classifier_options,
            )
            .map(Model::DecisionTreeClassifier)
            .map_err(TrainError::DecisionTree)
        }
        (Task::Regression, TreeType::Cart, Criterion::Mean)
        | (Task::Regression, TreeType::Cart, Criterion::Median) => {
            tree::regressor::train_cart_regressor_with_criterion_parallelism_and_options(
                train_set,
                criterion,
                parallelism,
                regressor_options,
            )
            .map(Model::DecisionTreeRegressor)
            .map_err(TrainError::RegressionTree)
        }
        (Task::Regression, TreeType::Randomized, Criterion::Mean)
        | (Task::Regression, TreeType::Randomized, Criterion::Median) => {
            tree::regressor::train_randomized_regressor_with_criterion_parallelism_and_options(
                train_set,
                criterion,
                parallelism,
                regressor_options,
            )
            .map(Model::DecisionTreeRegressor)
            .map_err(TrainError::RegressionTree)
        }
        (Task::Regression, TreeType::Oblivious, Criterion::Mean)
        | (Task::Regression, TreeType::Oblivious, Criterion::Median) => {
            tree::regressor::train_oblivious_regressor_with_criterion_parallelism_and_options(
                train_set,
                criterion,
                parallelism,
                regressor_options,
            )
            .map(Model::DecisionTreeRegressor)
            .map_err(TrainError::RegressionTree)
        }
        (task, tree_type, criterion) => Err(TrainError::UnsupportedConfiguration {
            task,
            tree_type,
            criterion,
        }),
    }
}

fn train_random_forest(
    train_set: &dyn TableAccess,
    config: RandomForestConfig,
) -> Result<Model, TrainError> {
    RandomForest::train(
        train_set,
        TrainConfig {
            algorithm: TrainAlgorithm::Rf,
            task: config.task,
            tree_type: config.tree_type,
            criterion: config.criterion,
            max_depth: Some(config.max_depth),
            min_samples_split: Some(config.min_samples_split),
            min_samples_leaf: Some(config.min_samples_leaf),
            physical_cores: None,
            n_trees: Some(config.n_trees),
            max_features: config.max_features,
            seed: config.seed,
            compute_oob: config.compute_oob,
            learning_rate: None,
            bootstrap: false,
            top_gradient_fraction: None,
            other_gradient_fraction: None,
        },
        config.criterion,
        config.parallelism,
    )
    .map(Model::RandomForest)
}

/// Trains gradient-boosted trees via `GradientBoostedTrees::train`.
///
/// # Errors
///
/// Wraps a failure from the boosting trainer in `TrainError::Boosting`.
fn train_gradient_boosting(
    train_set: &dyn TableAccess,
    config: TrainConfig,
    parallelism: Parallelism,
) -> Result<Model, TrainError> {
    match GradientBoostedTrees::train(train_set, config, parallelism) {
        Ok(ensemble) => Ok(Model::GradientBoostedTrees(ensemble)),
        Err(cause) => Err(TrainError::Boosting(cause)),
    }
}

/// Picks the concrete split criterion for the requested configuration.
///
/// For `Dt` and `Rf`, an explicit criterion that the task supports is returned
/// unchanged, and `Criterion::Auto` is replaced by a default:
///
/// * regression with `Cart`/`Randomized`/`Oblivious` -> `Mean`
/// * classification with `Id3`/`C45` -> `Entropy`
/// * classification with `Cart`/`Randomized`/`Oblivious` -> `Gini`
///
/// For `Gbm`, only `Criterion::Auto` with `Cart`/`Randomized`/`Oblivious` is
/// accepted, and it resolves to `SecondOrder`.
///
/// # Errors
///
/// Returns `TrainError::UnsupportedConfiguration` for every other combination.
fn resolve_criterion(
    algorithm: TrainAlgorithm,
    task: Task,
    tree_type: TreeType,
    criterion: Criterion,
) -> Result<Criterion, TrainError> {
    match (algorithm, task, tree_type, criterion) {
        // Explicit, supported criteria pass straight through.
        (
            TrainAlgorithm::Dt | TrainAlgorithm::Rf,
            Task::Regression,
            TreeType::Cart | TreeType::Randomized | TreeType::Oblivious,
            Criterion::Mean | Criterion::Median,
        )
        | (
            TrainAlgorithm::Dt | TrainAlgorithm::Rf,
            Task::Classification,
            TreeType::Id3
            | TreeType::C45
            | TreeType::Cart
            | TreeType::Randomized
            | TreeType::Oblivious,
            Criterion::Gini | Criterion::Entropy,
        ) => Ok(criterion),

        // `Auto` defaults for single trees and forests.
        (
            TrainAlgorithm::Dt | TrainAlgorithm::Rf,
            Task::Regression,
            TreeType::Cart | TreeType::Randomized | TreeType::Oblivious,
            Criterion::Auto,
        ) => Ok(Criterion::Mean),
        (
            TrainAlgorithm::Dt | TrainAlgorithm::Rf,
            Task::Classification,
            TreeType::Id3 | TreeType::C45,
            Criterion::Auto,
        ) => Ok(Criterion::Entropy),
        (
            TrainAlgorithm::Dt | TrainAlgorithm::Rf,
            Task::Classification,
            TreeType::Cart | TreeType::Randomized | TreeType::Oblivious,
            Criterion::Auto,
        ) => Ok(Criterion::Gini),

        // `Auto` default for boosting.
        (
            TrainAlgorithm::Gbm,
            Task::Regression | Task::Classification,
            TreeType::Cart | TreeType::Randomized | TreeType::Oblivious,
            Criterion::Auto,
        ) => Ok(Criterion::SecondOrder),

        _ => Err(TrainError::UnsupportedConfiguration {
            task,
            tree_type,
            criterion,
        }),
    }
}

/// Turns the optional core request into a [`Parallelism`] value.
///
/// The available count is `num_cpus::get_physical()`, raised to at least 1.
/// With no request, the available count is used. An explicit request is capped
/// at the available count.
///
/// # Errors
///
/// Returns `TrainError::InvalidPhysicalCoreCount` when zero cores are requested.
fn resolve_parallelism(physical_cores: Option<usize>) -> Result<Parallelism, TrainError> {
    let available = num_cpus::get_physical().max(1);

    let thread_count = match physical_cores {
        Some(requested) if requested == 0 => {
            return Err(TrainError::InvalidPhysicalCoreCount {
                requested,
                available,
            });
        }
        Some(requested) => requested.min(available),
        // `available` is at least 1, so the default can never be zero.
        None => available,
    };

    Ok(Parallelism { thread_count })
}

/// Runs `train_fn`, inside a freshly built rayon thread pool when
/// `parallelism.enabled()` is true and directly on the calling thread otherwise.
///
/// The pool is built with `parallelism.thread_count` threads.
///
/// # Errors
///
/// Returns `TrainError::ThreadPoolBuildFailed` carrying the builder's error
/// text when the pool cannot be created; otherwise returns the result of
/// `train_fn`.
fn run_with_parallelism<T, F>(parallelism: Parallelism, train_fn: F) -> Result<T, TrainError>
where
    T: Send,
    F: FnOnce() -> Result<T, TrainError> + Send,
{
    if parallelism.enabled() {
        let pool = match ThreadPoolBuilder::new()
            .num_threads(parallelism.thread_count)
            .build()
        {
            Ok(pool) => pool,
            Err(err) => return Err(TrainError::ThreadPoolBuildFailed(err.to_string())),
        };
        pool.install(train_fn)
    } else {
        train_fn()
    }
}