native_neural_network 0.1.6

A no_std Rust library for native neural networks (.rnn)
Documentation
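/// Selects the loss computed by `loss_and_gradient`; `Huber { delta }`
/// is quadratic for small errors and linear beyond `delta`.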
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum LossKind {
    Mse,
    Mae,
    Huber { delta: f32 },
}
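
// Per-element formulas, read directly from loss_and_gradient below
// (with d = prediction - target and n = prediction.len()):
//   Mse:   loss = mean(d^2),  grad = 2 * d / n
//   Mae:   loss = mean(|d|),  grad = sign(d) / n
//   Huber: loss = mean(0.5 * d^2 if |d| <= delta, else delta * (|d| - delta / 2)),
//          grad = d / n in the quadratic region, sign(d) * delta / n in the linear one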

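/// Reasons `loss_and_gradient` can fail: empty predictions, mismatched
/// slice lengths, or a non-finite result (also returned when a Huber
/// `delta` is not finite and positive).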
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum LossError {
    Empty,
    ShapeMismatch,
    NonFinite,
}

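/// Computes the mean loss between `prediction` and `target` and writes
/// d(loss)/d(prediction[i]) into `grad_out[i]`. `grad_out` may be longer
/// than `prediction`; any extra elements are left untouched.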
pub fn loss_and_gradient(
    kind: LossKind,
    prediction: &[f32],
    target: &[f32],
    grad_out: &mut [f32],
) -> Result<f32, LossError> {
    // Validate shapes before touching any data.
    if prediction.is_empty() {
        return Err(LossError::Empty);
    }
    if prediction.len() != target.len() || grad_out.len() < prediction.len() {
        return Err(LossError::ShapeMismatch);
    }

    let n = prediction.len() as f32;
    let mut loss = 0.0f32;

    match kind {
        LossKind::Mse => {
            // d/dp of (1/n) * sum((p - t)^2) is 2 * (p - t) / n.
            let scale = 2.0 / n;
            for i in 0..prediction.len() {
                let diff = prediction[i] - target[i];
                loss += diff * diff;
                grad_out[i] = scale * diff;
            }
            loss /= n;
        }
        LossKind::Mae => {
            // Subgradient of |p - t|: sign(p - t), with 0 chosen at p == t.
            let inv_n = 1.0 / n;
            for i in 0..prediction.len() {
                let diff = prediction[i] - target[i];
                loss += diff.abs();
                grad_out[i] = if diff > 0.0 {
                    inv_n
                } else if diff < 0.0 {
                    -inv_n
                } else {
                    0.0
                };
            }
            loss *= inv_n;
        }
        LossKind::Huber { delta } => {
            // A Huber delta must be finite and positive to be meaningful.
            if !delta.is_finite() || delta <= 0.0 {
                return Err(LossError::NonFinite);
            }
            let inv_n = 1.0 / n;
            for i in 0..prediction.len() {
                let diff = prediction[i] - target[i];
                let ad = diff.abs();
                if ad <= delta {
                    // Quadratic region: identical to a scaled MSE term.
                    loss += 0.5 * diff * diff;
                    grad_out[i] = diff * inv_n;
                } else {
                    // Linear region: the gradient magnitude is capped at delta.
                    loss += delta * (ad - 0.5 * delta);
                    grad_out[i] = if diff > 0.0 {
                        delta * inv_n
                    } else {
                        -delta * inv_n
                    };
                }
            }
            loss *= inv_n;
        }
    }

    // Non-finite inputs or overflow surface as a NaN/inf loss here.
    if !loss.is_finite() {
        return Err(LossError::NonFinite);
    }

    Ok(loss)
}
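
A minimal usage sketch follows. The sample values and the main wrapper are illustrative, and the assertions assume a std or test context; loss_and_gradient itself needs only core.

fn main() {
    let prediction = [0.5f32, 1.0, 2.0];
    let target = [0.0f32, 1.0, 1.0];
    let mut grad = [0.0f32; 3];

    // MSE: loss = (0.25 + 0.0 + 1.0) / 3, grad[i] = 2 * (p - t) / 3.
    let loss = loss_and_gradient(LossKind::Mse, &prediction, &target, &mut grad).unwrap();
    assert!((loss - 1.25 / 3.0).abs() < 1e-6);
    assert!((grad[0] - 1.0 / 3.0).abs() < 1e-6);

    // Huber with delta = 0.75: the third element (|d| = 1.0) falls in the
    // linear region, so its gradient is clamped to delta / n.
    let loss =
        loss_and_gradient(LossKind::Huber { delta: 0.75 }, &prediction, &target, &mut grad)
            .unwrap();
    assert!((loss - (0.125 + 0.0 + 0.46875) / 3.0).abs() < 1e-6);
    assert!((grad[2] - 0.25).abs() < 1e-6);

    // Mismatched slice lengths are rejected up front.
    let err = loss_and_gradient(LossKind::Mse, &prediction, &target[..2], &mut grad);
    assert_eq!(err, Err(LossError::ShapeMismatch));
}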