pub struct MultiTargetSGBT<L: Loss = SquaredLoss> { /* private fields */ }
Available on crate feature alloc only.
Multi-target regression SGBT.
Wraps T independent SGBT<L> models (T = n_targets), one per target dimension.
Each model is trained and predicts independently, sharing the same
configuration and loss function.
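The one-model-per-target wrapping described above can be sketched in plain Rust. This is a minimal illustration, not irithyll's implementation: the SingleTarget trait and the toy RunningMean model are hypothetical stand-ins for SGBT<L>, chosen so the sketch compiles without the crate.

```rust
/// Hypothetical stand-in for a single-target streaming regressor such as SGBT<L>.
/// This is NOT irithyll's API; it only illustrates the "one model per target" idea.
trait SingleTarget {
    fn train_one(&mut self, features: &[f64], target: f64);
    fn predict(&self, features: &[f64]) -> f64;
}

/// Toy single-target model: a running mean of the targets it has seen.
/// It ignores the features entirely to keep the sketch short.
#[derive(Clone, Default)]
struct RunningMean {
    sum: f64,
    count: u64,
}

impl SingleTarget for RunningMean {
    fn train_one(&mut self, _features: &[f64], target: f64) {
        self.sum += target;
        self.count += 1;
    }

    fn predict(&self, _features: &[f64]) -> f64 {
        if self.count == 0 {
            0.0
        } else {
            self.sum / self.count as f64
        }
    }
}

/// Hypothetical wrapper: T independent models, one per target dimension,
/// all cloned from the same template (shared configuration).
struct MultiTarget<M: SingleTarget> {
    models: Vec<M>,
}

impl<M: SingleTarget + Clone> MultiTarget<M> {
    fn new(template: M, n_targets: usize) -> Self {
        Self { models: vec![template; n_targets] }
    }

    fn train_one(&mut self, features: &[f64], targets: &[f64]) {
        // Model t only ever sees target t; the models never share state.
        for (model, &y) in self.models.iter_mut().zip(targets) {
            model.train_one(features, y);
        }
    }

    fn predict(&self, features: &[f64]) -> Vec<f64> {
        // One prediction per wrapped model, i.e. one per target.
        self.models.iter().map(|m| m.predict(features)).collect()
    }
}
```

Because the models are fully independent, a target whose signal is weak cannot degrade the others; the trade-off is that correlations between targets are not exploited.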
§Examples
use irithyll::ensemble::multi_target::MultiTargetSGBT;
use irithyll::SGBTConfig;
let config = SGBTConfig::builder()
    .n_steps(10)
    .learning_rate(0.1)
    .grace_period(10)
    .build()
    .unwrap();
let mut model = MultiTargetSGBT::new(config, 3).unwrap();
model.train_one(&[1.0, 2.0], &[0.5, 1.0, 1.5]);
let preds = model.predict(&[1.0, 2.0]);
assert_eq!(preds.len(), 3);
Implementations§
impl MultiTargetSGBT<SquaredLoss>
pub fn new(config: SGBTConfig, n_targets: usize) -> Result<Self>
Create a new multi-target SGBT with squared loss (default).
§Errors
Returns IrithyllError::InvalidConfig if n_targets < 1.
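The documented error contract amounts to a guard on n_targets before any model is built. A minimal sketch of that check, using a hypothetical SketchError enum in place of IrithyllError (whose full definition is not shown on this page):

```rust
/// Hypothetical error type mirroring the documented
/// IrithyllError::InvalidConfig case; not irithyll's actual type.
#[derive(Debug, PartialEq)]
enum SketchError {
    InvalidConfig(String),
}

/// Guard corresponding to "n_targets < 1". For a usize this can only be 0.
fn validate_n_targets(n_targets: usize) -> Result<usize, SketchError> {
    if n_targets < 1 {
        return Err(SketchError::InvalidConfig(format!(
            "n_targets must be at least 1, got {n_targets}"
        )));
    }
    Ok(n_targets)
}
```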
impl<L: Loss + Clone> MultiTargetSGBT<L>
pub fn with_loss(config: SGBTConfig, loss: L, n_targets: usize) -> Result<Self>
Create a new multi-target SGBT with a custom loss function.
The loss is cloned for each target model.
§Errors
Returns IrithyllError::InvalidConfig if n_targets < 1.
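A sketch of what "the loss is cloned for each target model" implies, assuming only the L: Loss + Clone bound shown in the impl header. The Loss trait below is a hypothetical one-method stand-in (irithyll's real Loss trait is not reproduced here), and Huber is simply an example of a loss that carries a parameter.

```rust
/// Hypothetical loss trait; only the gradient is modeled in this sketch.
trait Loss {
    fn gradient(&self, target: f64, prediction: f64) -> f64;
}

/// Huber loss: quadratic for residuals within `delta`, linear beyond it.
#[derive(Clone, Debug, PartialEq)]
struct Huber {
    delta: f64,
}

impl Loss for Huber {
    fn gradient(&self, target: f64, prediction: f64) -> f64 {
        let r = prediction - target;
        if r.abs() <= self.delta {
            r
        } else {
            self.delta * r.signum()
        }
    }
}

/// One loss instance per target, each an independent clone of the caller's
/// loss, in the spirit of the Clone bound on with_loss.
fn per_target_losses<L: Loss + Clone>(loss: L, n_targets: usize) -> Vec<L> {
    (0..n_targets).map(|_| loss.clone()).collect()
}
```

Cloning rather than sharing a reference keeps each target model self-contained, which fits the "trained independently" description; it also explains why the Clone bound appears on this impl block.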
pub fn train_batch(
    &mut self,
    feature_matrix: &[Vec<f64>],
    target_matrix: &[Vec<f64>],
)
Train on a batch of multi-target samples.
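The one-line description does not state how the two matrices line up. A minimal sketch, assuming row i of feature_matrix pairs with row i of target_matrix (one value per target) and that a batch behaves like per-sample training applied row by row. MultiTarget and Acc are hypothetical toy types, and the counter is only in the spirit of n_samples_seen.

```rust
/// Hypothetical per-target model: just accumulates the targets it sees.
#[derive(Clone, Default)]
struct Acc {
    sum: f64,
}

/// Hypothetical multi-target wrapper with a sample counter.
struct MultiTarget {
    models: Vec<Acc>,
    n_samples_seen: u64,
}

impl MultiTarget {
    fn new(n_targets: usize) -> Self {
        Self { models: vec![Acc::default(); n_targets], n_samples_seen: 0 }
    }

    fn train_one(&mut self, _features: &[f64], targets: &[f64]) {
        // Target t goes to model t only.
        for (model, &y) in self.models.iter_mut().zip(targets) {
            model.sum += y;
        }
        // One multi-target sample counts once, not once per target.
        self.n_samples_seen += 1;
    }

    /// Assumption: rows are paired by index and trained in order.
    fn train_batch(&mut self, feature_matrix: &[Vec<f64>], target_matrix: &[Vec<f64>]) {
        assert_eq!(
            feature_matrix.len(),
            target_matrix.len(),
            "feature and target row counts must match"
        );
        for (x, y) in feature_matrix.iter().zip(target_matrix) {
            self.train_one(x, y);
        }
    }
}
```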
pub fn predict(&self, features: &[f64]) -> Vec<f64>
Predict all target values for a feature vector. The returned vector holds one prediction per target.
pub fn n_samples_seen(&self) -> u64
Total number of samples trained on so far.
Trait Implementations§
Auto Trait Implementations§
impl<L> Freeze for MultiTargetSGBT<L>
impl<L = SquaredLoss> !RefUnwindSafe for MultiTargetSGBT<L>
impl<L> Send for MultiTargetSGBT<L>
impl<L> Sync for MultiTargetSGBT<L>
impl<L> Unpin for MultiTargetSGBT<L>
where
    L: Unpin,
impl<L> UnsafeUnpin for MultiTargetSGBT<L>
impl<L = SquaredLoss> !UnwindSafe for MultiTargetSGBT<L>
Blanket Implementations§
impl<T> BorrowMut<T> for T
where
    T: ?Sized,
fn borrow_mut(&mut self) -> &mut T
Mutably borrows from an owned value.