simian 0.2.1

A command-line tool for exploring and implementing Machine Learning algorithms in Rust.
1
2
3
4
5
6
7
8
9
10
use candle_core::{Result, Tensor};

pub mod bce;
pub mod mse;

/// Common interface for loss functions: a single operation that evaluates a
/// loss and its gradient from a pair of prediction/target tensors.
///
/// NOTE(review): the `bce` and `mse` submodules declared alongside this trait
/// presumably hold its implementors — confirm against those modules.
pub trait Loss {
  /// Computes the loss value (a scalar) and the gradient of the loss with
  /// respect to the predictions.
  ///
  /// # Arguments
  ///
  /// * `y_pred` - Tensor of predictions.
  /// * `y_true` - Tensor of target values the predictions are compared to.
  ///   Presumably expected to match `y_pred` in shape — TODO confirm per
  ///   implementation.
  ///
  /// # Returns
  ///
  /// A tuple `(loss_value, d_loss_d_y_pred)`: the loss as an `f32`, and a
  /// tensor holding the gradient of the loss with respect to `y_pred`.
  ///
  /// # Errors
  ///
  /// Failures are reported through `candle_core::Result`; the exact
  /// conditions depend on the implementing type.
  fn compute(&self, y_pred: &Tensor, y_true: &Tensor) -> Result<(f32, Tensor)>;
}