use crate::error::{Result, SklearsError};
use rayon::prelude::*;
use scirs2_core::ndarray::{s, Array1, Array2, ArrayView1, ArrayView2};
use scirs2_core::random::Random;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use std::time::{Duration, Instant};
#[derive(Debug)]
pub struct ParallelEnsemble {
config: EnsembleConfig,
base_learners: Vec<Arc<dyn BaseEstimator>>,
training_state: Arc<RwLock<TrainingState>>,
}
impl ParallelEnsemble {
pub fn new(config: EnsembleConfig) -> Self {
let base_learners = Self::create_base_learners(&config);
Self {
config,
base_learners,
training_state: Arc::new(RwLock::new(TrainingState::new())),
}
}
fn create_base_learners(config: &EnsembleConfig) -> Vec<Arc<dyn BaseEstimator>> {
let mut learners = Vec::new();
for i in 0..config.n_estimators {
let learner: Arc<dyn BaseEstimator> = match &config.ensemble_type {
EnsembleType::RandomForest => {
Arc::new(RandomForestEstimator::new(i, &config.base_config))
}
EnsembleType::GradientBoosting => {
Arc::new(GradientBoostingEstimator::new(i, &config.base_config))
}
EnsembleType::AdaBoost => Arc::new(AdaBoostEstimator::new(i, &config.base_config)),
EnsembleType::Voting => Arc::new(VotingEstimator::new(i, &config.base_config)),
EnsembleType::Stacking => Arc::new(StackingEstimator::new(i, &config.base_config)),
};
learners.push(learner);
}
learners
}
pub fn n_estimators(&self) -> usize {
self.base_learners.len()
}
pub fn parallel_fit(
&self,
x: &ArrayView2<f64>,
y: &ArrayView1<f64>,
) -> Result<TrainedParallelEnsemble> {
let start_time = Instant::now();
{
let mut state = self
.training_state
.write()
.unwrap_or_else(|e| e.into_inner());
state.start_training(x.nrows(), self.n_estimators());
}
let pool = rayon::ThreadPoolBuilder::new()
.num_threads(self.config.parallel_config.num_threads)
.build()
.map_err(|e| {
SklearsError::InvalidInput(format!("Failed to create thread pool: {e}"))
})?;
let trained_learners = pool.install(|| {
self.base_learners
.par_iter()
.enumerate()
.map(|(i, learner)| {
let result = self.train_single_learner(learner.as_ref(), x, y, i);
{
let mut state = self
.training_state
.write()
.unwrap_or_else(|e| e.into_inner());
state.update_progress(i, result.is_ok());
}
result
})
.collect::<Result<Vec<_>>>()
})?;
{
let mut state = self
.training_state
.write()
.unwrap_or_else(|e| e.into_inner());
state.complete_training(start_time.elapsed());
}
Ok(TrainedParallelEnsemble {
config: self.config.clone(),
trained_learners,
training_metrics: self
.training_state
.read()
.unwrap_or_else(|e| e.into_inner())
.clone(),
})
}
fn train_single_learner(
&self,
learner: &dyn BaseEstimator,
x: &ArrayView2<f64>,
y: &ArrayView1<f64>,
learner_id: usize,
) -> Result<TrainedBaseEstimator> {
let (train_x, train_y) = self.prepare_training_data(x, y, learner_id)?;
let start_time = Instant::now();
let trained = learner.fit(&train_x.view(), &train_y.view())?;
let training_time = start_time.elapsed();
let training_accuracy =
self.compute_training_accuracy(trained.as_ref(), &train_x, &train_y)?;
Ok(TrainedBaseEstimator {
learner_id,
model: trained,
training_time,
training_accuracy,
})
}
fn prepare_training_data(
&self,
x: &ArrayView2<f64>,
y: &ArrayView1<f64>,
learner_id: usize,
) -> Result<(Array2<f64>, Array1<f64>)> {
match self.config.sampling_strategy {
SamplingStrategy::Bootstrap => self.bootstrap_sample(x, y, learner_id),
SamplingStrategy::Bagging => self.bag_sample(x, y, learner_id),
SamplingStrategy::None => Ok((x.to_owned(), y.to_owned())),
}
}
fn bootstrap_sample(
&self,
x: &ArrayView2<f64>,
y: &ArrayView1<f64>,
seed: usize,
) -> Result<(Array2<f64>, Array1<f64>)> {
let mut rng = Random::seed(self.config.random_seed + seed as u64);
let n_samples = x.nrows();
let mut sampled_x = Array2::zeros((n_samples, x.ncols()));
let mut sampled_y = Array1::zeros(n_samples);
for i in 0..n_samples {
let sample_idx = rng.gen_range(0..n_samples);
sampled_x.row_mut(i).assign(&x.row(sample_idx));
sampled_y[i] = y[sample_idx];
}
Ok((sampled_x, sampled_y))
}
fn bag_sample(
&self,
x: &ArrayView2<f64>,
y: &ArrayView1<f64>,
seed: usize,
) -> Result<(Array2<f64>, Array1<f64>)> {
let mut rng = Random::seed(self.config.random_seed + seed as u64);
let n_samples = x.nrows();
let sample_size = (n_samples as f64 * self.config.subsample_ratio).round() as usize;
let mut indices: Vec<usize> = (0..n_samples).collect();
rng.shuffle(&mut indices);
indices.truncate(sample_size);
let mut sampled_x = Array2::zeros((sample_size, x.ncols()));
let mut sampled_y = Array1::zeros(sample_size);
for (i, &idx) in indices.iter().enumerate() {
sampled_x.row_mut(i).assign(&x.row(idx));
sampled_y[i] = y[idx];
}
Ok((sampled_x, sampled_y))
}
fn compute_training_accuracy(
&self,
model: &dyn TrainedBaseModel,
x: &Array2<f64>,
y: &Array1<f64>,
) -> Result<f64> {
let predictions = model.predict(&x.view())?;
let correct = predictions
.iter()
.zip(y.iter())
.map(|(pred, actual)| {
if (pred - actual).abs() < 0.5 {
1.0
} else {
0.0
}
})
.sum::<f64>();
Ok(correct / y.len() as f64)
}
}
#[derive(Debug)]
pub struct TrainedParallelEnsemble {
config: EnsembleConfig,
trained_learners: Vec<TrainedBaseEstimator>,
training_metrics: TrainingState,
}
impl TrainedParallelEnsemble {
pub fn n_estimators(&self) -> usize {
self.trained_learners.len()
}
pub fn training_metrics(&self) -> &TrainingState {
&self.training_metrics
}
pub fn parallel_predict(&self, x: &ArrayView2<f64>) -> Result<Array1<f64>> {
let n_samples = x.nrows();
let _n_estimators = self.trained_learners.len();
let all_predictions: Vec<Array1<f64>> = self
.trained_learners
.par_iter()
.map(|learner| learner.model.predict(x))
.collect::<Result<Vec<_>>>()?;
let mut final_predictions = Array1::zeros(n_samples);
match self.config.aggregation_method {
AggregationMethod::Voting => {
self.voting_aggregation(&all_predictions, &mut final_predictions)?;
}
AggregationMethod::Averaging => {
self.averaging_aggregation(&all_predictions, &mut final_predictions)?;
}
AggregationMethod::WeightedVoting => {
self.weighted_voting_aggregation(&all_predictions, &mut final_predictions)?;
}
AggregationMethod::Stacking => {
return self.stacking_aggregation(&all_predictions, x);
}
}
Ok(final_predictions)
}
fn voting_aggregation(
&self,
predictions: &[Array1<f64>],
output: &mut Array1<f64>,
) -> Result<()> {
let n_samples = output.len();
for i in 0..n_samples {
let mut votes = HashMap::new();
for pred_array in predictions {
let vote = pred_array[i].round() as i32;
*votes.entry(vote).or_insert(0) += 1;
}
let majority_vote = votes
.into_iter()
.max_by_key(|(_, count)| *count)
.map(|(vote, _)| vote as f64)
.unwrap_or(0.0);
output[i] = majority_vote;
}
Ok(())
}
fn averaging_aggregation(
&self,
predictions: &[Array1<f64>],
output: &mut Array1<f64>,
) -> Result<()> {
let n_estimators = predictions.len() as f64;
output.fill(0.0);
for pred_array in predictions {
for (out, pred) in output.iter_mut().zip(pred_array.iter()) {
*out += pred;
}
}
for out in output.iter_mut() {
*out /= n_estimators;
}
Ok(())
}
fn weighted_voting_aggregation(
&self,
predictions: &[Array1<f64>],
output: &mut Array1<f64>,
) -> Result<()> {
let n_samples = output.len();
let weights: Vec<f64> = self
.trained_learners
.iter()
.map(|learner| learner.training_accuracy)
.collect();
let weight_sum: f64 = weights.iter().sum();
output.fill(0.0);
for i in 0..n_samples {
for (j, pred_array) in predictions.iter().enumerate() {
output[i] += pred_array[i] * weights[j];
}
output[i] /= weight_sum;
}
Ok(())
}
fn stacking_aggregation(
&self,
predictions: &[Array1<f64>],
original_features: &ArrayView2<f64>,
) -> Result<Array1<f64>> {
let n_samples = original_features.nrows();
let n_base_features = original_features.ncols();
let n_meta_features = n_base_features + predictions.len();
let mut meta_features = Array2::zeros((n_samples, n_meta_features));
meta_features
.slice_mut(s![.., 0..n_base_features])
.assign(original_features);
for (i, pred_array) in predictions.iter().enumerate() {
meta_features
.column_mut(n_base_features + i)
.assign(pred_array);
}
let mut result = Array1::zeros(n_samples);
self.averaging_aggregation(predictions, &mut result)?;
Ok(result)
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EnsembleConfig {
pub ensemble_type: EnsembleType,
pub n_estimators: usize,
pub parallel_config: ParallelConfig,
pub sampling_strategy: SamplingStrategy,
pub aggregation_method: AggregationMethod,
pub base_config: BaseEstimatorConfig,
pub random_seed: u64,
pub subsample_ratio: f64,
}
impl EnsembleConfig {
pub fn random_forest() -> Self {
Self {
ensemble_type: EnsembleType::RandomForest,
n_estimators: 100,
parallel_config: ParallelConfig::default(),
sampling_strategy: SamplingStrategy::Bootstrap,
aggregation_method: AggregationMethod::Voting,
base_config: BaseEstimatorConfig::decision_tree(),
random_seed: 42,
subsample_ratio: 1.0,
}
}
pub fn gradient_boosting() -> Self {
Self {
ensemble_type: EnsembleType::GradientBoosting,
n_estimators: 100,
parallel_config: ParallelConfig::default(),
sampling_strategy: SamplingStrategy::None,
aggregation_method: AggregationMethod::Averaging,
base_config: BaseEstimatorConfig::decision_tree(),
random_seed: 42,
subsample_ratio: 1.0,
}
}
pub fn with_n_estimators(mut self, n: usize) -> Self {
self.n_estimators = n;
self
}
pub fn with_parallel_config(mut self, config: ParallelConfig) -> Self {
self.parallel_config = config;
self
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum EnsembleType {
RandomForest,
GradientBoosting,
AdaBoost,
Voting,
Stacking,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SamplingStrategy {
Bootstrap,
Bagging,
None,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AggregationMethod {
Voting,
Averaging,
WeightedVoting,
Stacking,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParallelConfig {
pub num_threads: usize,
pub batch_size: usize,
pub work_stealing: bool,
pub load_balancing: LoadBalancingStrategy,
}
impl ParallelConfig {
pub fn new() -> Self {
Self::default()
}
pub fn with_num_threads(mut self, threads: usize) -> Self {
self.num_threads = threads;
self
}
pub fn with_batch_size(mut self, size: usize) -> Self {
self.batch_size = size;
self
}
pub fn with_work_stealing(mut self, enabled: bool) -> Self {
self.work_stealing = enabled;
self
}
}
impl Default for ParallelConfig {
fn default() -> Self {
Self {
num_threads: num_cpus::get(),
batch_size: 1000,
work_stealing: true,
load_balancing: LoadBalancingStrategy::Dynamic,
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum LoadBalancingStrategy {
Static,
Dynamic,
WorkStealing,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BaseEstimatorConfig {
pub estimator_type: BaseEstimatorType,
pub parameters: HashMap<String, f64>,
}
impl BaseEstimatorConfig {
pub fn decision_tree() -> Self {
let mut params = HashMap::new();
params.insert("max_depth".to_string(), 10.0);
params.insert("min_samples_split".to_string(), 2.0);
params.insert("min_samples_leaf".to_string(), 1.0);
Self {
estimator_type: BaseEstimatorType::DecisionTree,
parameters: params,
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum BaseEstimatorType {
DecisionTree,
LinearModel,
NeuralNetwork,
SVM,
}
#[derive(Debug, Clone)]
pub struct TrainingState {
pub total_estimators: usize,
pub completed_estimators: usize,
pub failed_estimators: usize,
pub training_start_time: Option<Instant>,
pub training_duration: Option<Duration>,
pub data_size: (usize, usize), pub parallel_efficiency: f64,
}
impl TrainingState {
pub fn new() -> Self {
Self {
total_estimators: 0,
completed_estimators: 0,
failed_estimators: 0,
training_start_time: None,
training_duration: None,
data_size: (0, 0),
parallel_efficiency: 0.0,
}
}
pub fn start_training(&mut self, n_samples: usize, n_estimators: usize) {
self.total_estimators = n_estimators;
self.data_size = (n_samples, 0); self.training_start_time = Some(Instant::now());
self.completed_estimators = 0;
self.failed_estimators = 0;
}
pub fn update_progress(&mut self, _learner_id: usize, success: bool) {
if success {
self.completed_estimators += 1;
} else {
self.failed_estimators += 1;
}
}
pub fn complete_training(&mut self, duration: Duration) {
self.training_duration = Some(duration);
let sequential_time_estimate = duration.as_secs_f64() * self.total_estimators as f64;
let actual_time = duration.as_secs_f64();
self.parallel_efficiency = if actual_time > 0.0 {
(sequential_time_estimate / actual_time).min(1.0)
} else {
0.0
};
}
pub fn progress_percentage(&self) -> f64 {
if self.total_estimators == 0 {
0.0
} else {
(self.completed_estimators as f64 / self.total_estimators as f64) * 100.0
}
}
}
impl Default for TrainingState {
fn default() -> Self {
Self::new()
}
}
pub trait BaseEstimator: Send + Sync + std::fmt::Debug {
fn fit(&self, x: &ArrayView2<f64>, y: &ArrayView1<f64>) -> Result<Box<dyn TrainedBaseModel>>;
fn get_config(&self) -> &BaseEstimatorConfig;
}
pub trait TrainedBaseModel: Send + Sync + std::fmt::Debug {
fn predict(&self, x: &ArrayView2<f64>) -> Result<Array1<f64>>;
fn get_importance(&self) -> Option<Array1<f64>> {
None
}
}
#[derive(Debug)]
pub struct TrainedBaseEstimator {
pub learner_id: usize,
pub model: Box<dyn TrainedBaseModel>,
pub training_time: Duration,
pub training_accuracy: f64,
}
#[derive(Debug)]
pub struct RandomForestEstimator {
id: usize,
config: BaseEstimatorConfig,
}
impl RandomForestEstimator {
pub fn new(id: usize, config: &BaseEstimatorConfig) -> Self {
Self {
id,
config: config.clone(),
}
}
}
impl BaseEstimator for RandomForestEstimator {
fn fit(&self, x: &ArrayView2<f64>, _y: &ArrayView1<f64>) -> Result<Box<dyn TrainedBaseModel>> {
std::thread::sleep(Duration::from_millis(10));
Ok(Box::new(TrainedRandomForestModel {
id: self.id,
feature_count: x.ncols(),
sample_count: x.nrows(),
}))
}
fn get_config(&self) -> &BaseEstimatorConfig {
&self.config
}
}
#[derive(Debug)]
#[allow(dead_code)]
pub struct TrainedRandomForestModel {
id: usize,
feature_count: usize,
sample_count: usize,
}
impl TrainedBaseModel for TrainedRandomForestModel {
fn predict(&self, x: &ArrayView2<f64>) -> Result<Array1<f64>> {
let mut rng = Random::seed(self.id as u64);
let predictions =
Array1::from_iter((0..x.nrows()).map(|_| rng.random_range(0.0_f64..3.0_f64).round()));
Ok(predictions)
}
}
#[derive(Debug)]
pub struct GradientBoostingEstimator {
id: usize,
config: BaseEstimatorConfig,
}
impl GradientBoostingEstimator {
pub fn new(id: usize, config: &BaseEstimatorConfig) -> Self {
Self {
id,
config: config.clone(),
}
}
}
impl BaseEstimator for GradientBoostingEstimator {
fn fit(&self, x: &ArrayView2<f64>, _y: &ArrayView1<f64>) -> Result<Box<dyn TrainedBaseModel>> {
std::thread::sleep(Duration::from_millis(15));
Ok(Box::new(TrainedGradientBoostingModel {
id: self.id,
feature_count: x.ncols(),
}))
}
fn get_config(&self) -> &BaseEstimatorConfig {
&self.config
}
}
#[derive(Debug)]
#[allow(dead_code)]
pub struct TrainedGradientBoostingModel {
id: usize,
feature_count: usize,
}
impl TrainedBaseModel for TrainedGradientBoostingModel {
fn predict(&self, x: &ArrayView2<f64>) -> Result<Array1<f64>> {
let predictions = Array1::from_iter(x.rows().into_iter().map(|row| row.sum() * 0.1));
Ok(predictions)
}
}
macro_rules! impl_base_estimator {
($estimator:ident, $model:ident, $sleep_ms:expr, $prediction_fn:expr) => {
#[derive(Debug)]
pub struct $estimator {
id: usize,
config: BaseEstimatorConfig,
}
impl $estimator {
pub fn new(id: usize, config: &BaseEstimatorConfig) -> Self {
Self {
id,
config: config.clone(),
}
}
}
impl BaseEstimator for $estimator {
fn fit(
&self,
x: &ArrayView2<f64>,
_y: &ArrayView1<f64>,
) -> Result<Box<dyn TrainedBaseModel>> {
std::thread::sleep(Duration::from_millis($sleep_ms));
Ok(Box::new($model {
id: self.id,
feature_count: x.ncols(),
}))
}
fn get_config(&self) -> &BaseEstimatorConfig {
&self.config
}
}
#[derive(Debug)]
#[allow(dead_code)]
pub struct $model {
id: usize,
feature_count: usize,
}
impl TrainedBaseModel for $model {
fn predict(&self, x: &ArrayView2<f64>) -> Result<Array1<f64>> {
let predictions = Array1::from_iter(x.rows().into_iter().map($prediction_fn));
Ok(predictions)
}
}
};
}
impl_base_estimator!(AdaBoostEstimator, TrainedAdaBoostModel, 12, |row| row
.mean()
.unwrap_or(0.0));
impl_base_estimator!(VotingEstimator, TrainedVotingModel, 8, |row| row
.iter()
.max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
.unwrap_or(&0.0)
* 0.5);
impl_base_estimator!(StackingEstimator, TrainedStackingModel, 20, |row| row.sum()
/ row.len() as f64);
#[derive(Debug)]
pub struct DistributedEnsemble {
config: DistributedConfig,
}
impl DistributedEnsemble {
pub fn new(config: DistributedConfig) -> Self {
Self { config }
}
pub async fn join_cluster(&self) -> Result<()> {
println!("Joining cluster at {}", self.config.coordinator_address);
Ok(())
}
pub async fn distributed_fit(
&self,
_x: &ArrayView2<'_, f64>,
_y: &ArrayView1<'_, f64>,
) -> Result<TrainedDistributedEnsemble> {
Ok(TrainedDistributedEnsemble {
cluster_size: self.config.cluster_size,
})
}
}
#[derive(Debug, Clone)]
pub struct DistributedConfig {
pub cluster_size: usize,
pub node_role: NodeRole,
pub coordinator_address: String,
pub fault_tolerance: bool,
pub checkpointing_interval: Duration,
}
impl Default for DistributedConfig {
fn default() -> Self {
Self::new()
}
}
impl DistributedConfig {
pub fn new() -> Self {
Self {
cluster_size: 1,
node_role: NodeRole::Coordinator,
coordinator_address: "localhost:8080".to_string(),
fault_tolerance: false,
checkpointing_interval: Duration::from_secs(300),
}
}
pub fn with_cluster_size(mut self, size: usize) -> Self {
self.cluster_size = size;
self
}
pub fn with_node_role(mut self, role: NodeRole) -> Self {
self.node_role = role;
self
}
pub fn with_coordinator_address(mut self, address: &str) -> Self {
self.coordinator_address = address.to_string();
self
}
pub fn with_fault_tolerance(mut self, enabled: bool) -> Self {
self.fault_tolerance = enabled;
self
}
pub fn with_checkpointing_interval(mut self, interval: Duration) -> Self {
self.checkpointing_interval = interval;
self
}
}
#[derive(Debug, Clone)]
pub enum NodeRole {
Coordinator,
Worker,
}
#[derive(Debug)]
pub struct TrainedDistributedEnsemble {
cluster_size: usize,
}
impl TrainedDistributedEnsemble {
pub fn cluster_size(&self) -> usize {
self.cluster_size
}
}
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_ensemble_config_creation() {
let config = EnsembleConfig::random_forest()
.with_n_estimators(50)
.with_parallel_config(ParallelConfig::new().with_num_threads(4));
assert_eq!(config.n_estimators, 50);
assert_eq!(config.parallel_config.num_threads, 4);
assert!(matches!(config.ensemble_type, EnsembleType::RandomForest));
}
#[test]
fn test_parallel_config() {
let config = ParallelConfig::new()
.with_num_threads(8)
.with_batch_size(2000)
.with_work_stealing(false);
assert_eq!(config.num_threads, 8);
assert_eq!(config.batch_size, 2000);
assert!(!config.work_stealing);
}
#[test]
fn test_training_state() {
let mut state = TrainingState::new();
state.start_training(1000, 10);
assert_eq!(state.total_estimators, 10);
assert_eq!(state.progress_percentage(), 0.0);
state.update_progress(0, true);
state.update_progress(1, true);
state.update_progress(2, false);
assert_eq!(state.completed_estimators, 2);
assert_eq!(state.failed_estimators, 1);
assert_eq!(state.progress_percentage(), 20.0);
}
#[test]
fn test_base_estimator_creation() {
let config = BaseEstimatorConfig::decision_tree();
let estimator = RandomForestEstimator::new(0, &config);
assert!(estimator.get_config().parameters.contains_key("max_depth"));
}
#[test]
fn test_parallel_ensemble_creation() {
let config = EnsembleConfig::random_forest().with_n_estimators(5);
let ensemble = ParallelEnsemble::new(config);
assert_eq!(ensemble.n_estimators(), 5);
}
#[test]
fn test_sampling_strategies() {
let config = EnsembleConfig::random_forest();
let ensemble = ParallelEnsemble::new(config);
let x = Array2::from_shape_vec((10, 3), (0..30).map(|i| i as f64).collect())
.expect("valid array shape");
let y = Array1::from_shape_vec(10, (0..10).map(|i| i as f64).collect())
.expect("valid array shape");
let (sampled_x, sampled_y) = ensemble
.bootstrap_sample(&x.view(), &y.view(), 0)
.expect("expected valid value");
assert_eq!(sampled_x.shape(), x.shape());
assert_eq!(sampled_y.len(), y.len());
}
#[test]
fn test_aggregation_methods() {
let config = EnsembleConfig::random_forest();
let trained_learners = vec![
TrainedBaseEstimator {
learner_id: 0,
model: Box::new(TrainedRandomForestModel {
id: 0,
feature_count: 3,
sample_count: 10,
}),
training_time: Duration::from_millis(100),
training_accuracy: 0.8,
},
TrainedBaseEstimator {
learner_id: 1,
model: Box::new(TrainedRandomForestModel {
id: 1,
feature_count: 3,
sample_count: 10,
}),
training_time: Duration::from_millis(120),
training_accuracy: 0.9,
},
];
let ensemble = TrainedParallelEnsemble {
config,
trained_learners,
training_metrics: TrainingState::new(),
};
let x = Array2::zeros((5, 3));
let result = ensemble.parallel_predict(&x.view());
assert!(result.is_ok());
let predictions = result.expect("expected valid value");
assert_eq!(predictions.len(), 5);
}
#[test]
fn test_distributed_config() {
let config = DistributedConfig::new()
.with_cluster_size(4)
.with_node_role(NodeRole::Worker)
.with_coordinator_address("192.168.1.100:8080")
.with_fault_tolerance(true);
assert_eq!(config.cluster_size, 4);
assert!(matches!(config.node_role, NodeRole::Worker));
assert_eq!(config.coordinator_address, "192.168.1.100:8080");
assert!(config.fault_tolerance);
}
}