LossFunction

Trait LossFunction 

Source
/// Loss function trait.
///
/// Implementors provide a scalar loss, its gradients, and a human-readable
/// name. Every implementor must be `Send + Sync`, so trait objects
/// (`dyn LossFunction`) can be shared across threads.
pub trait LossFunction
where
    Self: Send + Sync,
{
    /// Returns the name of this loss function.
    fn name(&self) -> &str;

    /// Computes the scalar loss value for `predictions` against `targets`.
    ///
    /// # Errors
    ///
    /// Returns the `Err` variant of `DeviceResult` when the implementation
    /// cannot compute the loss. The exact failure conditions (e.g. slices of
    /// different lengths) are implementation-defined.
    fn compute_loss(&self, predictions: &[f64], targets: &[f64]) -> DeviceResult<f64>;

    /// Computes the gradients of the loss for `predictions` against `targets`.
    ///
    /// NOTE(review): presumably the returned vector holds one entry per
    /// prediction — confirm against the implementors.
    ///
    /// # Errors
    ///
    /// Returns the `Err` variant of `DeviceResult` when the implementation
    /// cannot compute the gradients; conditions are implementation-defined.
    fn compute_gradients(&self, predictions: &[f64], targets: &[f64]) -> DeviceResult<Vec<f64>>;
}
Expand description

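Loss function trait: implementors compute a scalar loss, its gradients, and report a name. Implementors must be `Send + Sync`.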

Required Methods

Source

fn compute_loss(&self, predictions: &[f64], targets: &[f64]) -> DeviceResult<f64>

Computes the scalar loss value for `predictions` against `targets`.

Source

fn compute_gradients(&self, predictions: &[f64], targets: &[f64]) -> DeviceResult<Vec<f64>>

Computes the gradients of the loss for `predictions` against `targets`.

Source

fn name(&self) -> &str

Returns the name of the loss function.

Implementors