use std::sync::Arc;
use ferrotorch_core::autograd::no_grad::is_grad_enabled;
use ferrotorch_core::grad_fns::activation as act;
use ferrotorch_core::grad_fns::arithmetic;
use ferrotorch_core::grad_fns::linalg::mm_differentiable;
use ferrotorch_core::grad_fns::reduction as red;
use ferrotorch_core::grad_fns::shape::transpose_2d;
use ferrotorch_core::ops::elementwise::{binary_map, mean as elem_mean};
use ferrotorch_core::tensor::GradFn;
use ferrotorch_core::{Float, FerrotorchError, FerrotorchResult, Tensor, TensorStorage};
/// Functional affine layer: multiplies `input` by the transpose of `weight`
/// and optionally adds `bias`.
///
/// # Arguments
/// * `input`  - 2-D tensor; its second dimension is the feature count.
/// * `weight` - 2-D tensor laid out as `[out_features, in_features]`.
/// * `bias`   - optional 1-D tensor holding `out_features` values.
///
/// # Errors
/// Returns `FerrotorchError::ShapeMismatch` when `input` or `weight` is not
/// 2-D, when their feature dimensions disagree, or when `bias` is not a 1-D
/// tensor of length `out_features`. Failures from the transpose, matmul,
/// data access and add steps are passed through unchanged.
pub fn linear<T: Float>(
    input: &Tensor<T>,
    weight: &Tensor<T>,
    bias: Option<&Tensor<T>>,
) -> FerrotorchResult<Tensor<T>> {
    if input.ndim() != 2 {
        let message = format!(
            "functional::linear expects 2D input [B, in_features], got shape {:?}",
            input.shape()
        );
        return Err(FerrotorchError::ShapeMismatch { message });
    }
    if weight.ndim() != 2 {
        let message = format!(
            "functional::linear expects 2D weight [out, in], got shape {:?}",
            weight.shape()
        );
        return Err(FerrotorchError::ShapeMismatch { message });
    }

    let in_features = input.shape()[1];
    let (out_features, weight_in) = (weight.shape()[0], weight.shape()[1]);
    if in_features != weight_in {
        let message = format!(
            "functional::linear: input has {} features but weight expects {}",
            in_features, weight_in,
        );
        return Err(FerrotorchError::ShapeMismatch { message });
    }

    // [B, in] x [in, out] -> [B, out]
    let projected = mm_differentiable(input, &transpose_2d(weight)?)?;

    let b = match bias {
        Some(b) => b,
        None => return Ok(projected),
    };

    let bias_ok = b.ndim() == 1 && b.shape()[0] == out_features;
    if !bias_ok {
        let message = format!(
            "functional::linear: bias shape {:?} does not match out_features {}",
            b.shape(),
            out_features,
        );
        return Err(FerrotorchError::ShapeMismatch { message });
    }

    // The bias values are copied into a new `[1, out_features]` tensor so the
    // add sees two 2-D operands.
    // NOTE(review): the copy is a fresh tensor rather than a view of `b`; it
    // looks like gradients from the add would accumulate on the copy and not
    // on the caller's `bias` — confirm against the autograd engine.
    let row_bias = Tensor::from_storage(
        TensorStorage::cpu(b.data()?.to_vec()),
        vec![1, out_features],
        b.requires_grad(),
    )?;
    arithmetic::add(&projected, &row_bias)
}
/// Applies the ReLU activation by forwarding to the core `activation::relu`.
///
/// The result, including any error, is returned exactly as the core op
/// produced it.
#[inline]
pub fn relu<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
act::relu(input)
}
/// Applies the sigmoid activation by forwarding to the core
/// `activation::sigmoid`.
///
/// The result, including any error, is returned exactly as the core op
/// produced it.
#[inline]
pub fn sigmoid<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
act::sigmoid(input)
}
/// Applies the hyperbolic-tangent activation by forwarding to the core
/// `activation::tanh`.
///
/// The result, including any error, is returned exactly as the core op
/// produced it.
#[inline]
pub fn tanh<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
act::tanh(input)
}
/// Applies the GELU activation by forwarding to the core `activation::gelu`.
///
/// NOTE(review): whether the exact or the tanh-approximated GELU is used is
/// decided by the core op and is not visible here — confirm if it matters.
#[inline]
pub fn gelu<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
act::gelu(input)
}
/// Applies the SiLU activation by forwarding to the core `activation::silu`.
///
/// The result, including any error, is returned exactly as the core op
/// produced it.
#[inline]
pub fn silu<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
act::silu(input)
}
/// Applies softmax by forwarding to the core `activation::softmax`.
///
/// NOTE(review): the axis that is normalized is chosen by the core op and is
/// not visible here; the in-file test only covers a 1-D input.
#[inline]
pub fn softmax<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
act::softmax(input)
}
/// Applies log-softmax by forwarding to the core `activation::log_softmax`.
///
/// NOTE(review): the axis that is normalized is chosen by the core op and is
/// not visible here; the in-file test only covers a 1-D input.
#[inline]
pub fn log_softmax<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
act::log_softmax(input)
}
/// Leaky ReLU assembled from existing ops as
/// `(1 - negative_slope) * relu(x) + negative_slope * x`.
///
/// Two slopes are short-circuited: a slope within `f64::EPSILON` of `0.0`
/// returns plain `relu(input)`, and a slope within `f64::EPSILON` of `1.0`
/// returns a clone of `input`.
///
/// # Errors
/// Any error from the underlying relu, scalar construction, multiply or add
/// is returned unchanged.
///
/// # Panics
/// Panics if `negative_slope` (or `1 - negative_slope`) cannot be converted
/// to `T`.
pub fn leaky_relu<T: Float>(
    input: &Tensor<T>,
    negative_slope: f64,
) -> FerrotorchResult<Tensor<T>> {
    let slope_is = |target: f64| (negative_slope - target).abs() < f64::EPSILON;
    if slope_is(0.0) {
        return act::relu(input);
    }
    if slope_is(1.0) {
        return Ok(input.clone());
    }

    let rectified = act::relu(input)?;

    let (gain, slope) = (
        T::from(1.0 - negative_slope).unwrap(),
        T::from(negative_slope).unwrap(),
    );
    let gain_tensor = ferrotorch_core::scalar(gain)?;
    let slope_tensor = ferrotorch_core::scalar(slope)?;

    // For x > 0 the two terms sum to x; for x <= 0 only `slope * x` remains.
    let positive_part = arithmetic::mul(&rectified, &gain_tensor)?;
    let linear_part = arithmetic::mul(input, &slope_tensor)?;
    arithmetic::add(&positive_part, &linear_part)
}
/// Reduces `input` by forwarding to the core `reduction::sum`.
///
/// NOTE(review): presumably sums every element into a single value (the
/// in-file test checks this for a 1-D input) — confirm for higher ranks.
#[inline]
pub fn sum<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
red::sum(input)
}
/// Reduces `input` by forwarding to the core `reduction::mean`.
///
/// NOTE(review): presumably averages every element into a single value (the
/// in-file test checks this for a 1-D input) — confirm for higher ranks.
#[inline]
pub fn mean<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
red::mean(input)
}
/// Produces a non-zero 64-bit seed for the xorshift generator.
///
/// The seed is the `DefaultHasher` digest of the current system time and the
/// current thread id. A digest of exactly zero is replaced by a fixed
/// constant, because an all-zero xorshift state never changes.
fn xorshift_seed() -> u64 {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};
    use std::time::SystemTime;

    let mut digest = DefaultHasher::new();
    SystemTime::now().hash(&mut digest);
    std::thread::current().id().hash(&mut digest);

    match digest.finish() {
        0 => 0xdeadbeefcafe,
        seed => seed,
    }
}
/// Advances `state` by one xorshift step (shifts 13, 7, 17) and returns the
/// new state divided by `u64::MAX` as an `f64`.
///
/// A zero state is a fixed point: it stays zero and yields `0.0`.
#[inline]
fn xorshift_next(state: &mut u64) -> f64 {
    let mut x = *state;
    x ^= x << 13;
    x ^= x >> 7;
    x ^= x << 17;
    *state = x;

    let numerator = x as f64;
    let denominator = u64::MAX as f64;
    numerator / denominator
}
/// Autograd node recorded by `dropout`.
///
/// Keeps the original input (to report it as the graph parent and to know the
/// gradient's shape) together with the mask that was applied in the forward
/// pass, already multiplied by the survivor scale.
#[derive(Debug)]
struct DropoutBackward<T: Float> {
    input: Tensor<T>,
    scaled_mask: Vec<T>,
}

impl<T: Float> GradFn<T> for DropoutBackward<T> {
    /// Multiplies the incoming gradient element-wise by the stored mask.
    ///
    /// Yields a single-entry vector: `None` when the input does not require
    /// a gradient, otherwise a tensor with the input's shape.
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        if !self.input.requires_grad() {
            return Ok(vec![None]);
        }

        let upstream = grad_output.data()?;
        let mut masked: Vec<T> = Vec::with_capacity(self.scaled_mask.len());
        for (&g, &m) in upstream.iter().zip(self.scaled_mask.iter()) {
            masked.push(g * m);
        }

        let grad_input = Tensor::from_storage(
            TensorStorage::cpu(masked),
            self.input.shape().to_vec(),
            false,
        )?;
        Ok(vec![Some(grad_input)])
    }

    /// The only graph parent is the tensor dropout was applied to.
    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.input]
    }

    fn name(&self) -> &'static str {
        "DropoutBackward"
    }
}
/// Inverted dropout.
///
/// While `training` is true and `p > 0`, each element is independently set
/// to zero when a pseudo-random draw falls below `p`; every other element is
/// multiplied by `1 / (1 - p)`. Otherwise a clone of `input` is returned.
///
/// The mask comes from a xorshift generator that is freshly seeded on every
/// call via `xorshift_seed`.
///
/// # Errors
/// Returns `FerrotorchError::InvalidArgument` unless `0 <= p < 1` (NaN is
/// rejected as well). Errors from reading the input data or building the
/// output tensor are passed through.
pub fn dropout<T: Float>(
    input: &Tensor<T>,
    p: f64,
    training: bool,
) -> FerrotorchResult<Tensor<T>> {
    let p_in_range = (0.0..1.0).contains(&p);
    if !p_in_range {
        return Err(FerrotorchError::InvalidArgument {
            message: format!("dropout probability must be in [0, 1), got {p}"),
        });
    }
    if p == 0.0 || !training {
        return Ok(input.clone());
    }

    let count = input.numel();
    let keep_scale = T::from(1.0 / (1.0 - p)).unwrap();
    let dropped = <T as num_traits::Zero>::zero();

    // One draw per element: below `p` means dropped, otherwise kept and scaled.
    let mut rng_state = xorshift_seed();
    let mut scaled_mask: Vec<T> = Vec::with_capacity(count);
    for _ in 0..count {
        let draw = xorshift_next(&mut rng_state);
        scaled_mask.push(if draw < p { dropped } else { keep_scale });
    }

    let values = input.data()?;
    let mut output_data: Vec<T> = Vec::with_capacity(scaled_mask.len());
    for (&x, &m) in values.iter().zip(scaled_mask.iter()) {
        output_data.push(x * m);
    }

    let track_grad = is_grad_enabled() && input.requires_grad();
    if !track_grad {
        return Tensor::from_storage(
            TensorStorage::cpu(output_data),
            input.shape().to_vec(),
            false,
        );
    }

    // The backward node reuses the exact mask applied above.
    let grad_fn = Arc::new(DropoutBackward {
        input: input.clone(),
        scaled_mask,
    });
    Tensor::from_operation(
        TensorStorage::cpu(output_data),
        input.shape().to_vec(),
        grad_fn,
    )
}
/// Mean squared error between `pred` and `target`: the mean of the
/// element-wise squared differences.
///
/// When gradient tracking is enabled and `pred` requires a gradient, the
/// result carries an `MSEBackward` node; that node propagates a gradient to
/// `pred` only.
///
/// # Errors
/// Returns `FerrotorchError::ShapeMismatch` if the two shapes differ. Errors
/// from the element-wise ops, the mean and tensor construction are passed
/// through.
pub fn mse_loss<T: Float>(
    pred: &Tensor<T>,
    target: &Tensor<T>,
) -> FerrotorchResult<Tensor<T>> {
    let shapes_match = pred.shape() == target.shape();
    if !shapes_match {
        let message = format!(
            "mse_loss: pred shape {:?} != target shape {:?}",
            pred.shape(),
            target.shape(),
        );
        return Err(FerrotorchError::ShapeMismatch { message });
    }

    let residual = binary_map(pred, target, |p, t| p - t)?;
    let squared = ferrotorch_core::ops::elementwise::unary_map(&residual, |x| x * x)?;
    let reduced = elem_mean(&squared)?;

    let track_grad = is_grad_enabled() && pred.requires_grad();
    if !track_grad {
        return Ok(reduced);
    }

    // Rewrap the already-computed value so it is attached to the graph.
    let grad_fn = Arc::new(MSEBackward {
        pred: pred.clone(),
        target: target.clone(),
    });
    let storage = TensorStorage::cpu(reduced.data()?.to_vec());
    Tensor::from_operation(storage, reduced.shape().to_vec(), grad_fn)
}
/// Autograd node recorded by `mse_loss`.
///
/// Stores both operands of the loss; only `pred` is reported as a graph
/// parent, so only `pred` receives a gradient.
#[derive(Debug)]
struct MSEBackward<T: Float> {
    pred: Tensor<T>,
    target: Tensor<T>,
}

impl<T: Float> GradFn<T> for MSEBackward<T> {
    /// Computes `2 * (pred - target) * g / n` per element, where `g` is the
    /// first element of `grad_output` and `n` is the element count of `pred`.
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        let predictions = self.pred.data()?;
        let targets = self.target.data()?;
        let upstream_data = grad_output.data()?;

        let two = T::from(2.0).unwrap();
        let n = T::from(predictions.len()).unwrap();
        // The loss is a single value, so only the first upstream entry is read.
        let go = upstream_data[0];

        let mut grads: Vec<T> = Vec::with_capacity(predictions.len());
        for (&p, &t) in predictions.iter().zip(targets.iter()) {
            grads.push(two * (p - t) * go / n);
        }

        let grad_pred = Tensor::from_storage(
            TensorStorage::cpu(grads),
            self.pred.shape().to_vec(),
            false,
        )?;
        Ok(vec![Some(grad_pred)])
    }

    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.pred]
    }

    fn name(&self) -> &'static str {
        "MSEBackward"
    }
}
/// Cross-entropy loss over a batch of logits, averaged across the batch.
///
/// Each row of `logits` is passed through a max-shifted log-softmax and the
/// negative log-probability of the row's target class is accumulated; the
/// sum is then divided by the batch size.
///
/// # Arguments
/// * `logits`  - 2-D tensor of shape `[batch, classes]`.
/// * `targets` - 1-D tensor of shape `[batch]` holding class indices stored
///   as values of `T`. Each value is converted with `to_usize` and must land
///   in `0..classes`.
///
/// # Returns
/// A tensor with an empty shape holding the mean loss. When gradient tracking
/// is enabled and `logits` requires a gradient, the result carries a
/// `CrossEntropyBackward` node that reuses the softmax computed here.
///
/// # Errors
/// * `FerrotorchError::InvalidArgument` if `logits` is not 2-D, if it has a
///   non-empty batch but zero classes, or if any target is negative, NaN, or
///   not smaller than `classes`.
/// * `FerrotorchError::ShapeMismatch` if `targets` is not shaped `[batch]`.
///
/// Errors from data access and tensor construction are passed through.
pub fn cross_entropy<T: Float>(
    logits: &Tensor<T>,
    targets: &Tensor<T>,
) -> FerrotorchResult<Tensor<T>> {
    let shape = logits.shape();
    if shape.len() != 2 {
        return Err(FerrotorchError::InvalidArgument {
            message: format!(
                "cross_entropy: expected 2D logits [B, C], got shape {:?}",
                shape,
            ),
        });
    }
    let batch = shape[0];
    let classes = shape[1];
    if targets.shape() != [batch] {
        return Err(FerrotorchError::ShapeMismatch {
            message: format!(
                "cross_entropy: target shape {:?} does not match batch size {}",
                targets.shape(),
                batch,
            ),
        });
    }
    // A row with no classes has no maximum and no valid target; reject it
    // here instead of panicking on the first element access below.
    if batch > 0 && classes == 0 {
        return Err(FerrotorchError::InvalidArgument {
            message: format!(
                "cross_entropy: logits must have at least one class, got shape {:?}",
                shape,
            ),
        });
    }

    let logits_data = logits.data()?;
    let targets_data = targets.data()?;

    // Resolve every target before doing any math. An invalid value must be an
    // error: silently mapping it to class 0, or letting an oversized index
    // read a neighbouring row, would yield a plausible but wrong loss.
    let mut target_classes: Vec<usize> = Vec::with_capacity(batch);
    for (row, target) in targets_data.iter().take(batch).enumerate() {
        match target.to_usize() {
            Some(class) if class < classes => target_classes.push(class),
            _ => {
                return Err(FerrotorchError::InvalidArgument {
                    message: format!(
                        "cross_entropy: target at batch index {} is not a valid class index in [0, {})",
                        row, classes,
                    ),
                });
            }
        }
    }

    let zero = <T as num_traits::Zero>::zero();
    let mut softmax_out = vec![zero; batch * classes];
    let mut total_loss = zero;

    for (b, &target_class) in target_classes.iter().enumerate() {
        let base = b * classes;
        let row = &logits_data[base..base + classes];

        // Subtract the row maximum before exponentiating for stability.
        let max_val = row
            .iter()
            .skip(1)
            .fold(row[0], |m, &v| if v > m { v } else { m });

        let mut sum_exp = zero;
        for (slot, &logit) in softmax_out[base..base + classes].iter_mut().zip(row) {
            let e = (logit - max_val).exp();
            *slot = e;
            sum_exp = sum_exp + e;
        }
        let log_sum = sum_exp.ln();
        for slot in &mut softmax_out[base..base + classes] {
            *slot = *slot / sum_exp;
        }

        // Only the target's log-probability contributes to the loss.
        let target_log_prob = row[target_class] - max_val - log_sum;
        total_loss = total_loss - target_log_prob;
    }

    let loss_val = total_loss / T::from(batch).unwrap();
    let reduced = Tensor::from_storage(
        TensorStorage::cpu(vec![loss_val]),
        vec![],
        false,
    )?;

    if is_grad_enabled() && logits.requires_grad() {
        let softmax_tensor = Tensor::from_storage(
            TensorStorage::cpu(softmax_out),
            vec![batch, classes],
            false,
        )?;
        let grad_fn = Arc::new(CrossEntropyBackward {
            logits: logits.clone(),
            targets: targets.clone(),
            softmax: softmax_tensor,
        });
        Tensor::from_operation(
            TensorStorage::cpu(reduced.data()?.to_vec()),
            reduced.shape().to_vec(),
            grad_fn,
        )
    } else {
        Ok(reduced)
    }
}
/// Autograd node recorded by `cross_entropy`.
///
/// Holds the logits (the only graph parent), the targets, and the softmax
/// probabilities computed during the forward pass.
#[derive(Debug)]
struct CrossEntropyBackward<T: Float> {
    logits: Tensor<T>,
    targets: Tensor<T>,
    softmax: Tensor<T>,
}

impl<T: Float> GradFn<T> for CrossEntropyBackward<T> {
    /// Computes `(softmax - one_hot(target)) * (1 / batch) * g` for every
    /// logit, where `g` is the first element of `grad_output`.
    ///
    /// A target that does not convert to `usize` is treated as class 0 here.
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        let dims = self.logits.shape();
        let (batch, classes) = (dims[0], dims[1]);

        let probs = self.softmax.data()?;
        let targets_data = self.targets.data()?;
        let upstream = grad_output.data()?;
        // The loss is a single value, so only the first upstream entry is read.
        let go = upstream[0];

        let one = <T as num_traits::One>::one();
        let zero = <T as num_traits::Zero>::zero();
        let inv_batch = T::from(1.0).unwrap() / T::from(batch).unwrap();

        // Rows are emitted in order, so pushing fills the buffer row-major.
        let mut grads: Vec<T> = Vec::with_capacity(batch * classes);
        for b in 0..batch {
            let hot = targets_data[b].to_usize().unwrap_or(0);
            let row_start = b * classes;
            for c in 0..classes {
                let indicator = if c == hot { one } else { zero };
                grads.push((probs[row_start + c] - indicator) * inv_batch * go);
            }
        }

        let grad_logits = Tensor::from_storage(
            TensorStorage::cpu(grads),
            self.logits.shape().to_vec(),
            false,
        )?;
        Ok(vec![Some(grad_logits)])
    }

    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.logits]
    }

    fn name(&self) -> &'static str {
        "CrossEntropyBackward"
    }
}
// Unit tests for the functional API: shape validation and known values for
// `linear`, spot checks for the activation and reduction wrappers, the
// identity/validation/rate behaviour of `dropout`, and known values plus
// shape validation for the two losses.
#[cfg(test)]
mod tests {
use super::*;
use ferrotorch_core::TensorStorage;
// Builds a CPU `f32` tensor from a flat slice and a shape; panics if the
// tensor cannot be constructed.
fn leaf(data: &[f32], shape: &[usize], requires_grad: bool) -> Tensor<f32> {
Tensor::from_storage(
TensorStorage::cpu(data.to_vec()),
shape.to_vec(),
requires_grad,
)
.unwrap()
}
// Asserts that two slices have equal length and that every pair of
// elements differs by strictly less than `tol`.
fn assert_close(actual: &[f32], expected: &[f32], tol: f32) {
assert_eq!(
actual.len(),
expected.len(),
"length mismatch: {} vs {}",
actual.len(),
expected.len(),
);
for (i, (&a, &e)) in actual.iter().zip(expected.iter()).enumerate() {
assert!(
(a - e).abs() < tol,
"index {i}: actual={a} expected={e} diff={}",
(a - e).abs(),
);
}
}
// The weight rows select features 0 and 1, so each output row is the first
// two entries of the matching input row.
#[test]
fn test_linear_no_bias() {
let weight = leaf(&[1.0, 0.0, 0.0, 0.0, 1.0, 0.0], &[2, 3], false);
let input = leaf(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], false);
let output = linear(&input, &weight, None).unwrap();
assert_eq!(output.shape(), &[2, 2]);
assert_close(output.data().unwrap(), &[1.0, 2.0, 4.0, 5.0], 1e-6);
}
// Identity weight, so the output is the input with the bias added to every
// row.
#[test]
fn test_linear_with_bias() {
let weight = leaf(&[1.0, 0.0, 0.0, 1.0], &[2, 2], false);
let bias = leaf(&[10.0, 20.0], &[2], false);
let input = leaf(&[1.0, 2.0, 3.0, 4.0], &[2, 2], false);
let output = linear(&input, &weight, Some(&bias)).unwrap();
assert_eq!(output.shape(), &[2, 2]);
assert_close(output.data().unwrap(), &[11.0, 22.0, 13.0, 24.0], 1e-6);
}
// The functional form must agree with the `Linear` module when both use
// the same weight and bias.
#[test]
fn test_linear_matches_module() {
use crate::linear::Linear;
use crate::module::Module;
use crate::parameter::Parameter;
let mut layer = Linear::<f32>::new(3, 2, true).unwrap();
layer.weight = Parameter::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();
*layer.bias.as_mut().unwrap() = Parameter::from_slice(&[0.1, 0.2], &[2]).unwrap();
let input = leaf(&[1.0, 0.0, -1.0, 2.0, 1.0, 0.0], &[2, 3], false);
let module_out = layer.forward(&input).unwrap();
let func_out = linear(
&input,
layer.weight.tensor(),
Some(layer.bias.as_ref().unwrap().tensor()),
)
.unwrap();
assert_eq!(module_out.shape(), func_out.shape());
assert_close(module_out.data().unwrap(), func_out.data().unwrap(), 1e-5);
}
// A 1-D input is rejected.
#[test]
fn test_linear_wrong_input_dims() {
let weight = leaf(&[1.0; 6], &[2, 3], false);
let input_1d = leaf(&[1.0, 2.0, 3.0], &[3], false);
assert!(linear(&input_1d, &weight, None).is_err());
}
// A 1-D weight is rejected.
#[test]
fn test_linear_wrong_weight_dims() {
let weight = leaf(&[1.0; 6], &[6], false);
let input = leaf(&[1.0; 6], &[2, 3], false);
assert!(linear(&input, &weight, None).is_err());
}
// The input has 3 features but the weight expects 4.
#[test]
fn test_linear_feature_mismatch() {
let weight = leaf(&[1.0; 8], &[2, 4], false);
let input = leaf(&[1.0; 6], &[2, 3], false);
assert!(linear(&input, &weight, None).is_err());
}
// The bias has 3 entries while the weight produces 2 output features.
#[test]
fn test_linear_bias_shape_mismatch() {
let weight = leaf(&[1.0; 6], &[2, 3], false);
let bias = leaf(&[1.0; 3], &[3], false); let input = leaf(&[1.0; 6], &[2, 3], false);
assert!(linear(&input, &weight, Some(&bias)).is_err());
}
// The wrapper must return the same values as the core op it forwards to.
#[test]
fn test_relu_matches_core() {
let input = leaf(&[-2.0, -1.0, 0.0, 1.0, 2.0], &[5], false);
let func_out = relu(&input).unwrap();
let core_out = act::relu(&input).unwrap();
assert_close(func_out.data().unwrap(), core_out.data().unwrap(), 1e-7);
}
// Negative inputs and zero map to zero; positive inputs pass through.
#[test]
fn test_relu_values() {
let input = leaf(&[-3.0, -1.0, 0.0, 0.5, 2.0], &[5], false);
let output = relu(&input).unwrap();
assert_close(output.data().unwrap(), &[0.0, 0.0, 0.0, 0.5, 2.0], 1e-7);
}
// sigmoid(0) is 0.5.
#[test]
fn test_sigmoid_values() {
let input = leaf(&[0.0], &[1], false);
let output = sigmoid(&input).unwrap();
assert!((output.data().unwrap()[0] - 0.5).abs() < 1e-6);
}
// tanh(0) is 0.
#[test]
fn test_tanh_values() {
let input = leaf(&[0.0], &[1], false);
let output = tanh(&input).unwrap();
assert!(output.data().unwrap()[0].abs() < 1e-6);
}
// Only checks the sign: GELU of a positive input is positive.
#[test]
fn test_gelu_positive() {
let input = leaf(&[1.0, 2.0], &[2], false);
let output = gelu(&input).unwrap();
let d = output.data().unwrap();
assert!(d[0] > 0.0);
assert!(d[1] > 0.0);
}
// silu(0) is 0.
#[test]
fn test_silu_zero() {
let input = leaf(&[0.0], &[1], false);
let output = silu(&input).unwrap();
assert!(output.data().unwrap()[0].abs() < 1e-6);
}
// Softmax of a 1-D input sums to one.
#[test]
fn test_softmax_sums_to_one() {
let input = leaf(&[1.0, 2.0, 3.0], &[3], false);
let output = softmax(&input).unwrap();
let d = output.data().unwrap();
let total: f32 = d.iter().sum();
assert!((total - 1.0).abs() < 1e-5);
}
// Log-probabilities are never positive.
#[test]
fn test_log_softmax_negative() {
let input = leaf(&[1.0, 2.0, 3.0], &[3], false);
let output = log_softmax(&input).unwrap();
let d = output.data().unwrap();
assert!(d.iter().all(|&v| v <= 0.0));
}
// With slope 0.01, negative inputs are scaled by 0.01 and non-negative
// inputs are unchanged.
#[test]
fn test_leaky_relu_values() {
let input = leaf(&[-2.0, -1.0, 0.0, 1.0, 2.0], &[5], false);
let output = leaky_relu(&input, 0.01).unwrap();
let d = output.data().unwrap();
assert!((d[0] - (-0.02)).abs() < 1e-5);
assert!((d[1] - (-0.01)).abs() < 1e-5);
assert!((d[2] - 0.0).abs() < 1e-5);
assert!((d[3] - 1.0).abs() < 1e-5);
assert!((d[4] - 2.0).abs() < 1e-5);
}
// A slope of exactly 0 takes the ReLU shortcut.
#[test]
fn test_leaky_relu_zero_slope_is_relu() {
let input = leaf(&[-2.0, 0.0, 3.0], &[3], false);
let lrelu_out = leaky_relu(&input, 0.0).unwrap();
let relu_out = relu(&input).unwrap();
assert_close(lrelu_out.data().unwrap(), relu_out.data().unwrap(), 1e-7);
}
// A slope of exactly 1 takes the identity shortcut.
#[test]
fn test_leaky_relu_one_slope_is_identity() {
let input = leaf(&[-2.0, 0.0, 3.0], &[3], false);
let output = leaky_relu(&input, 1.0).unwrap();
assert_close(output.data().unwrap(), &[-2.0, 0.0, 3.0], 1e-7);
}
// 1 + 2 + 3 + 4 = 10.
#[test]
fn test_sum_values() {
let input = leaf(&[1.0, 2.0, 3.0, 4.0], &[4], false);
let output = sum(&input).unwrap();
assert!((output.item().unwrap() - 10.0).abs() < 1e-6);
}
// (1 + 2 + 3 + 4) / 4 = 2.5.
#[test]
fn test_mean_values() {
let input = leaf(&[1.0, 2.0, 3.0, 4.0], &[4], false);
let output = mean(&input).unwrap();
assert!((output.item().unwrap() - 2.5).abs() < 1e-6);
}
// With `training == false` the input is returned as-is, checked via
// `is_same`.
#[test]
fn test_dropout_eval_is_identity() {
let input = leaf(&[1.0, 2.0, 3.0, 4.0, 5.0], &[5], false);
let output = dropout(&input, 0.5, false).unwrap();
assert!(output.is_same(&input));
}
// With `p == 0` nothing is dropped even while training.
#[test]
fn test_dropout_zero_prob_is_identity() {
let input = leaf(&[1.0, 2.0, 3.0], &[3], false);
let output = dropout(&input, 0.0, true).unwrap();
assert!(output.is_same(&input));
}
// `p` must lie in [0, 1): 1.0, negative values and values above 1 are
// rejected.
#[test]
fn test_dropout_invalid_p() {
let input = leaf(&[1.0], &[1], false);
assert!(dropout(&input, 1.0, true).is_err());
assert!(dropout(&input, -0.1, true).is_err());
assert!(dropout(&input, 1.5, true).is_err());
}
// Statistical check on 100_000 ones with p = 0.5: about half of the
// elements are zeroed and every survivor equals 1 / (1 - 0.5) = 2.
#[test]
fn test_dropout_rate_approximately_correct() {
let input = ferrotorch_core::ones::<f32>(&[100_000]).unwrap();
let output = dropout(&input, 0.5, true).unwrap();
let data = output.data().unwrap();
let zeros = data.iter().filter(|&&x| x == 0.0).count();
let rate = zeros as f64 / data.len() as f64;
assert!(
(rate - 0.5).abs() < 0.05,
"dropout rate = {rate}, expected ~0.5"
);
let non_zero: Vec<f32> = data.iter().copied().filter(|&x| x != 0.0).collect();
assert!(!non_zero.is_empty());
for &v in &non_zero {
assert!(
(v - 2.0).abs() < 1e-6,
"surviving element = {v}, expected 2.0"
);
}
}
// Even a very high `p` has no effect when `training` is false.
#[test]
fn test_dropout_training_flag() {
let input = ferrotorch_core::ones::<f32>(&[1000]).unwrap();
let output = dropout(&input, 0.99, false).unwrap();
assert!(output.is_same(&input));
}
// Identical prediction and target give zero loss.
#[test]
fn test_mse_loss_zero() {
let pred = leaf(&[1.0, 2.0, 3.0], &[3], false);
let target = leaf(&[1.0, 2.0, 3.0], &[3], false);
let loss = mse_loss(&pred, &target).unwrap();
assert!(loss.item().unwrap().abs() < 1e-7);
}
// ((1 - 3)^2 + (2 - 4)^2) / 2 = 4.
#[test]
fn test_mse_loss_known_value() {
let pred = leaf(&[1.0, 2.0], &[2], false);
let target = leaf(&[3.0, 4.0], &[2], false);
let loss = mse_loss(&pred, &target).unwrap();
assert!((loss.item().unwrap() - 4.0).abs() < 1e-6);
}
// Shapes [2] and [3] do not match.
#[test]
fn test_mse_loss_shape_mismatch() {
let pred = leaf(&[1.0, 2.0], &[2], false);
let target = leaf(&[1.0, 2.0, 3.0], &[3], false);
assert!(mse_loss(&pred, &target).is_err());
}
// Each row puts a large logit on its target class, so the loss is close to
// zero.
#[test]
fn test_cross_entropy_basic() {
let logits = leaf(&[10.0, 0.0, 0.0, 10.0], &[2, 2], false);
let targets = leaf(&[0.0, 1.0], &[2], false);
let loss = cross_entropy(&logits, &targets).unwrap();
assert!(loss.item().unwrap() < 0.01);
}
// 1-D logits are rejected.
#[test]
fn test_cross_entropy_wrong_logits_shape() {
let logits = leaf(&[1.0, 2.0, 3.0], &[3], false);
let targets = leaf(&[0.0], &[1], false);
assert!(cross_entropy(&logits, &targets).is_err());
}
// Three targets for a batch of two rows is rejected.
#[test]
fn test_cross_entropy_target_batch_mismatch() {
let logits = leaf(&[1.0, 2.0, 3.0, 4.0], &[2, 2], false);
let targets = leaf(&[0.0, 1.0, 0.0], &[3], false);
assert!(cross_entropy(&logits, &targets).is_err());
}
// Uniform logits over two classes give probability 1/2 per class, so the
// loss is ln(2).
#[test]
fn test_cross_entropy_uniform_logits() {
let logits = leaf(&[0.0, 0.0, 0.0, 0.0], &[2, 2], false);
let targets = leaf(&[0.0, 1.0], &[2], false);
let loss = cross_entropy(&logits, &targets).unwrap();
let expected = (2.0_f32).ln();
assert!(
(loss.item().unwrap() - expected).abs() < 1e-5,
"loss = {}, expected ln(2) = {}",
loss.item().unwrap(),
expected,
);
}
}