forestfire-core 0.3.0

Core tree-learning algorithms for the ForestFire project.
Documentation
use super::*;

/// Minimum batch size (in rows) before `InferenceExecutor` hands work to its
/// thread pool; smaller batches are evaluated on the calling thread.
pub(crate) const PARALLEL_INFERENCE_ROW_THRESHOLD: usize = 256;

/// Rows per chunk, presumably for parallel chunked inference.
/// NOTE(review): not referenced in this section — confirm against callers.
pub(crate) const PARALLEL_INFERENCE_CHUNK_ROWS: usize = 256;

/// Rows per chunk, presumably for batch inference over standard (non-oblivious)
/// trees. NOTE(review): not referenced in this section — confirm against callers.
pub(crate) const STANDARD_BATCH_INFERENCE_CHUNK_ROWS: usize = 4_096;

/// Lane width, presumably for vectorized oblivious-tree evaluation.
/// NOTE(review): not referenced in this section — confirm against callers.
pub(crate) const OBLIVIOUS_SIMD_LANES: usize = 8;

/// Inference-oriented ("optimized") representation of a trained model.
///
/// Each variant stores one model family in a layout used for prediction; the
/// ensemble variants (`Forest*`, `Boosted*`) nest further `OptimizedRuntime`
/// values as their member trees.
///
/// NOTE(review): this type derives serde `Serialize`/`Deserialize`; for compact,
/// order-sensitive formats the variant and field order may be part of the
/// serialized layout, so avoid reordering.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) enum OptimizedRuntime {
    /// Classification tree stored as a flat list of binary branch/leaf nodes.
    BinaryClassifier {
        /// Tree nodes. NOTE(review): there is no explicit root field here —
        /// presumably evaluation starts at index 0; confirm against the evaluator.
        nodes: Vec<OptimizedBinaryClassifierNode>,
        /// Class label values; presumably ordered like the per-class entries of
        /// leaf vectors — confirm.
        class_labels: Vec<f64>,
    },
    /// Classification tree whose internal nodes may be binary or multiway splits.
    StandardClassifier {
        /// Tree nodes; child references are presumably indices into this `Vec`.
        nodes: Vec<OptimizedClassifierNode>,
        /// Index of the node where evaluation starts (presumably into `nodes`).
        root: usize,
        /// Class label values; presumably ordered like the per-class entries of
        /// leaf vectors — confirm.
        class_labels: Vec<f64>,
    },
    /// Oblivious classification tree. Presumably one `(feature, threshold)` test
    /// per depth level, shared by all nodes at that level — confirm.
    ObliviousClassifier {
        /// Feature tested at each level; presumably parallel to `threshold_bins`.
        feature_indices: Vec<usize>,
        /// Binned split threshold for each level.
        threshold_bins: Vec<u16>,
        /// One value vector per leaf; presumably indexed by the combined
        /// outcomes of the level tests — confirm.
        leaf_values: Vec<Vec<f64>>,
        /// Class label values.
        class_labels: Vec<f64>,
    },
    /// Regression tree stored as a flat list of binary branch/leaf nodes.
    BinaryRegressor {
        /// Tree nodes. NOTE(review): no explicit root field — presumably index 0.
        nodes: Vec<OptimizedBinaryRegressorNode>,
    },
    /// Oblivious regression tree (same level layout as `ObliviousClassifier`,
    /// presumably), with a single scalar per leaf.
    ObliviousRegressor {
        /// Feature tested at each level; presumably parallel to `threshold_bins`.
        feature_indices: Vec<usize>,
        /// Binned split threshold for each level.
        threshold_bins: Vec<u16>,
        /// One scalar prediction per leaf.
        leaf_values: Vec<f64>,
    },
    /// Ensemble of member classification trees.
    /// NOTE(review): how member outputs are combined is decided by the evaluator,
    /// which is not visible here.
    ForestClassifier {
        /// Member trees, each itself an `OptimizedRuntime`.
        trees: Vec<OptimizedRuntime>,
        /// Class label values shared by the ensemble.
        class_labels: Vec<f64>,
    },
    /// Ensemble of member regression trees (combination rule not visible here).
    ForestRegressor {
        /// Member trees, each itself an `OptimizedRuntime`.
        trees: Vec<OptimizedRuntime>,
    },
    /// Gradient-boosted binary classifier.
    BoostedBinaryClassifier {
        /// Member trees, each itself an `OptimizedRuntime`.
        trees: Vec<OptimizedRuntime>,
        /// Per-tree weights; presumably parallel to `trees`.
        tree_weights: Vec<f64>,
        /// Starting score; presumably the raw prediction before any tree
        /// contribution is added — confirm.
        base_score: f64,
        /// Class label values.
        class_labels: Vec<f64>,
    },
    /// Gradient-boosted regressor.
    BoostedRegressor {
        /// Member trees, each itself an `OptimizedRuntime`.
        trees: Vec<OptimizedRuntime>,
        /// Per-tree weights; presumably parallel to `trees`.
        tree_weights: Vec<f64>,
        /// Starting score; presumably the raw prediction before any tree
        /// contribution is added — confirm.
        base_score: f64,
    },
}

/// Node type stored in `OptimizedRuntime::StandardClassifier`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) enum OptimizedClassifierNode {
    /// Terminal node holding a per-class value vector (presumably class
    /// probabilities ordered like the runtime's `class_labels` — confirm).
    Leaf(Vec<f64>),
    /// Two-way split on the binned value of `feature_index`.
    Binary {
        /// Feature whose bin is compared.
        feature_index: usize,
        /// Split threshold expressed as a bin index.
        threshold_bin: u16,
        /// Indices of the two child nodes. NOTE(review): which comparison outcome
        /// maps to which slot is decided by the evaluator — confirm.
        children: [usize; 2],
        /// Bin value that presumably marks a missing feature value, if any.
        missing_bin: Option<u16>,
        /// Child to follow for missing values, when they are routed to a child.
        missing_child: Option<usize>,
        /// Values presumably returned directly for missing values when they are
        /// not routed to a child — confirm.
        missing_probabilities: Option<Vec<f64>>,
    },
    /// Multiway split: presumably one child per bin of `feature_index`.
    Multiway {
        /// Feature whose bin selects the child.
        feature_index: usize,
        /// Child node index per bin (presumably indexed directly by bin value).
        child_lookup: Vec<usize>,
        /// Largest bin index covered by `child_lookup` (presumably).
        max_bin_index: usize,
        /// Bin value that presumably marks a missing feature value, if any.
        missing_bin: Option<u16>,
        /// Child to follow for missing values, when they are routed to a child.
        missing_child: Option<usize>,
        /// Values used when a row cannot be routed to a child (presumably bins
        /// outside `child_lookup` or unrouted missing values) — confirm.
        fallback_probabilities: Vec<f64>,
    },
}

/// Node type stored in `OptimizedRuntime::BinaryClassifier`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) enum OptimizedBinaryClassifierNode {
    /// Terminal node holding a per-class value vector (presumably class
    /// probabilities ordered like the runtime's `class_labels` — confirm).
    Leaf(Vec<f64>),
    /// Binary split on the binned value of `feature_index`.
    Branch {
        /// Feature whose bin is compared.
        feature_index: usize,
        /// Split threshold expressed as a bin index.
        threshold_bin: u16,
        /// Index of the node to jump to. NOTE(review): presumably the other
        /// outcome falls through to the next node in the list — confirm.
        jump_index: usize,
        /// Selects which comparison outcome takes the jump (presumably `true`
        /// means "jump when the bin is greater than `threshold_bin`").
        jump_if_greater: bool,
        /// Bin value that presumably marks a missing feature value, if any.
        missing_bin: Option<u16>,
        /// Node to jump to for missing values, when they are routed to a node.
        missing_jump_index: Option<usize>,
        /// Values presumably returned directly for missing values when they are
        /// not routed to a node — confirm.
        missing_probabilities: Option<Vec<f64>>,
    },
}

/// Node type stored in `OptimizedRuntime::BinaryRegressor`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) enum OptimizedBinaryRegressorNode {
    /// Terminal node holding a scalar prediction.
    Leaf(f64),
    /// Binary split on the binned value of `feature_index`.
    Branch {
        /// Feature whose bin is compared.
        feature_index: usize,
        /// Split threshold expressed as a bin index.
        threshold_bin: u16,
        /// Index of the node to jump to. NOTE(review): presumably the other
        /// outcome falls through to the next node in the list — confirm.
        jump_index: usize,
        /// Selects which comparison outcome takes the jump (presumably `true`
        /// means "jump when the bin is greater than `threshold_bin`").
        jump_if_greater: bool,
        /// Bin value that presumably marks a missing feature value, if any.
        missing_bin: Option<u16>,
        /// Node to jump to for missing values, when they are routed to a node.
        missing_jump_index: Option<usize>,
        /// Value presumably returned directly for missing values when they are
        /// not routed to a node — confirm.
        missing_value: Option<f64>,
    },
}

/// Runs per-row inference either on the calling thread or on a rayon thread pool.
#[derive(Debug, Clone)]
pub(crate) struct InferenceExecutor {
    /// Number of worker threads the executor was configured with.
    pub(crate) thread_count: usize,
    /// Thread pool used for parallel inference; `new` only builds one when
    /// `thread_count > 1`. Held in an `Arc`, so cloning the executor shares
    /// the same pool rather than creating a new one.
    pub(crate) pool: Option<Arc<rayon::ThreadPool>>,
}

impl InferenceExecutor {
    /// Builds an executor configured for `thread_count` worker threads.
    ///
    /// A dedicated rayon thread pool is created only when `thread_count > 1`;
    /// with 0 or 1 threads every call runs sequentially on the calling thread.
    ///
    /// # Errors
    ///
    /// Returns `OptimizeError::ThreadPoolBuildFailed` if the thread pool cannot
    /// be built.
    pub(crate) fn new(thread_count: usize) -> Result<Self, OptimizeError> {
        let pool = if thread_count > 1 {
            Some(Arc::new(
                ThreadPoolBuilder::new()
                    .num_threads(thread_count)
                    .build()
                    .map_err(|err| OptimizeError::ThreadPoolBuildFailed(err.to_string()))?,
            ))
        } else {
            None
        };

        Ok(Self { thread_count, pool })
    }

    /// Evaluates `predict_row` for every row index in `0..n_rows` and returns the
    /// results in row order.
    ///
    /// Work is spread over the pool only when the executor is multi-threaded, a
    /// pool exists, and `n_rows >= PARALLEL_INFERENCE_ROW_THRESHOLD`; otherwise
    /// rows are evaluated sequentially on the calling thread.
    pub(crate) fn predict_rows<F>(&self, n_rows: usize, predict_row: F) -> Vec<f64>
    where
        F: Fn(usize) -> f64 + Sync + Send,
    {
        // Matching on the pool (instead of `expect`-ing it) means an executor
        // built with `thread_count == 0`, or assembled without a pool, degrades
        // to sequential evaluation instead of panicking on large batches.
        match self.pool.as_ref() {
            Some(pool) if self.parallel_dispatch_enabled(n_rows) => {
                pool.install(|| (0..n_rows).into_par_iter().map(predict_row).collect())
            }
            _ => (0..n_rows).map(predict_row).collect(),
        }
    }

    /// Splits `outputs` into consecutive chunks of at most `chunk_rows` rows and
    /// calls `fill_chunk(first_row_index, chunk)` on each one.
    ///
    /// Chunks are processed in parallel under the same conditions as
    /// `predict_rows` (judged on `outputs.len()`), otherwise sequentially in order.
    ///
    /// # Panics
    ///
    /// Panics if `chunk_rows` is zero (the slice chunking adapters reject a zero
    /// chunk size).
    pub(crate) fn fill_chunks<F>(&self, outputs: &mut [f64], chunk_rows: usize, fill_chunk: F)
    where
        F: Fn(usize, &mut [f64]) + Sync + Send,
    {
        match self.pool.as_ref() {
            Some(pool) if self.parallel_dispatch_enabled(outputs.len()) => pool.install(|| {
                outputs
                    .par_chunks_mut(chunk_rows)
                    .enumerate()
                    .for_each(|(chunk_index, chunk)| fill_chunk(chunk_index * chunk_rows, chunk));
            }),
            _ => {
                for (chunk_index, chunk) in outputs.chunks_mut(chunk_rows).enumerate() {
                    fill_chunk(chunk_index * chunk_rows, chunk);
                }
            }
        }
    }

    /// Returns `true` when the executor is multi-threaded and a batch of `n_rows`
    /// is large enough to be worth dispatching to the thread pool.
    fn parallel_dispatch_enabled(&self, n_rows: usize) -> bool {
        self.thread_count > 1 && n_rows >= PARALLEL_INFERENCE_ROW_THRESHOLD
    }
}

/// Determines how many worker threads inference should use.
///
/// `None` selects every physical core reported by `num_cpus::get_physical`
/// (at least one). An explicit request larger than that count is clamped down
/// to the available number of physical cores.
///
/// # Errors
///
/// Returns `OptimizeError::InvalidPhysicalCoreCount` when `physical_cores` is
/// `Some(0)`.
pub(crate) fn resolve_inference_thread_count(
    physical_cores: Option<usize>,
) -> Result<usize, OptimizeError> {
    // Never report fewer than one usable core, even if detection returns zero.
    let available = num_cpus::get_physical().max(1);

    match physical_cores {
        Some(0) => Err(OptimizeError::InvalidPhysicalCoreCount {
            requested: 0,
            available,
        }),
        Some(requested) => Ok(requested.min(available)),
        None => Ok(available),
    }
}