use std::sync::Arc;
use ferrotorch_core::autograd::no_grad::is_grad_enabled;
use ferrotorch_core::grad_fns::activation as act;
use ferrotorch_core::grad_fns::arithmetic;
use ferrotorch_core::grad_fns::linalg::mm_differentiable;
use ferrotorch_core::grad_fns::reduction as red;
use ferrotorch_core::grad_fns::shape::transpose_2d;
use ferrotorch_core::grad_fns::transcendental as trans;
use ferrotorch_core::ops::elementwise::{binary_map, mean as elem_mean};
use ferrotorch_core::tensor::GradFn;
use ferrotorch_core::{FerrotorchError, FerrotorchResult, Float, Tensor, TensorStorage};
/// Affine transform `input @ weight^T + bias`, the functional form of a
/// fully-connected layer.
///
/// * `input`  — 2-D tensor `[B, in_features]`.
/// * `weight` — 2-D tensor `[out_features, in_features]`.
/// * `bias`   — optional 1-D tensor `[out_features]`, added to every row.
///
/// The matmul goes through `mm_differentiable`. The bias is added with
/// `arithmetic::add` on the caller's tensor itself, so gradients reach `bias`.
///
/// # Errors
/// `FerrotorchError::ShapeMismatch` if `input` or `weight` is not 2-D, if their
/// feature dimensions disagree, or if `bias` is not `[out_features]`.
pub fn linear<T: Float>(
    input: &Tensor<T>,
    weight: &Tensor<T>,
    bias: Option<&Tensor<T>>,
) -> FerrotorchResult<Tensor<T>> {
    if input.ndim() != 2 {
        return Err(FerrotorchError::ShapeMismatch {
            message: format!(
                "functional::linear expects 2D input [B, in_features], got shape {:?}",
                input.shape()
            ),
        });
    }
    if weight.ndim() != 2 {
        return Err(FerrotorchError::ShapeMismatch {
            message: format!(
                "functional::linear expects 2D weight [out, in], got shape {:?}",
                weight.shape()
            ),
        });
    }
    let in_features = input.shape()[1];
    let weight_in = weight.shape()[1];
    let out_features = weight.shape()[0];
    if in_features != weight_in {
        return Err(FerrotorchError::ShapeMismatch {
            message: format!(
                "functional::linear: input has {} features but weight expects {}",
                in_features, weight_in,
            ),
        });
    }
    // Validate the bias before the matmul so malformed arguments fail fast.
    if let Some(b) = bias {
        if b.ndim() != 1 || b.shape()[0] != out_features {
            return Err(FerrotorchError::ShapeMismatch {
                message: format!(
                    "functional::linear: bias shape {:?} does not match out_features {}",
                    b.shape(),
                    out_features,
                ),
            });
        }
    }
    let weight_t = transpose_2d(weight)?;
    let output = mm_differentiable(input, &weight_t)?;
    match bias {
        // Add the caller's bias tensor directly. Copying its data into a fresh
        // leaf would detach it from the autograd graph, so its gradient would be
        // lost. This relies on `arithmetic::add` broadcasting `[out]` against
        // `[B, out]`, the same rank-padding broadcast the 0-d scalar operands
        // elsewhere in this file depend on.
        Some(b) => arithmetic::add(&output, b),
        None => Ok(output),
    }
}
/// Rectified linear unit, applied elementwise: negative entries become `0`,
/// everything else passes through.
///
/// Delegates to `grad_fns::activation::relu`.
#[inline]
pub fn relu<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    act::relu(input)
}
/// Logistic sigmoid `1 / (1 + exp(-x))`, applied elementwise.
///
/// Delegates to `grad_fns::activation::sigmoid`.
#[inline]
pub fn sigmoid<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    act::sigmoid(input)
}
/// Hyperbolic tangent, applied elementwise.
///
/// Delegates to `grad_fns::activation::tanh`.
#[inline]
pub fn tanh<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    act::tanh(input)
}
/// GELU activation using whatever default approximation `act::gelu` applies.
///
/// Use [`gelu_with`] to select the approximation explicitly.
#[inline]
pub fn gelu<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    act::gelu(input)
}
/// GELU activation with an explicitly chosen approximation mode.
///
/// `approximate` is forwarded unchanged to `act::gelu_with`.
#[inline]
pub fn gelu_with<T: Float>(
    input: &Tensor<T>,
    approximate: act::GeluApproximate,
) -> FerrotorchResult<Tensor<T>> {
    act::gelu_with(input, approximate)
}
/// SiLU (a.k.a. swish), `x * sigmoid(x)`, applied elementwise.
///
/// Delegates to `grad_fns::activation::silu`.
#[inline]
pub fn silu<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    act::silu(input)
}
/// Softmax normalisation via `act::softmax`.
///
/// NOTE(review): the reduction axis is decided by `act::softmax` (presumably
/// the last dimension). Confirm before relying on it for N-D inputs.
#[inline]
pub fn softmax<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    act::softmax(input)
}
/// Log-softmax via `act::log_softmax`; every output is `<= 0`.
///
/// NOTE(review): the reduction axis matches `act::softmax` (presumably the
/// last dimension). Confirm.
#[inline]
pub fn log_softmax<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    act::log_softmax(input)
}
/// Leaky ReLU: `x` for `x >= 0`, `negative_slope * x` otherwise.
///
/// Built from differentiable primitives as
/// `(1 - s) * relu(x) + s * x`, which equals `x` for positive inputs and
/// `s * x` for negative ones. Slopes within `f64::EPSILON` of `0` or `1`
/// short-circuit to plain ReLU or to the input itself.
pub fn leaky_relu<T: Float>(input: &Tensor<T>, negative_slope: f64) -> FerrotorchResult<Tensor<T>> {
    let is_near = |anchor: f64| (negative_slope - anchor).abs() < f64::EPSILON;
    if is_near(0.0) {
        return act::relu(input);
    }
    if is_near(1.0) {
        return Ok(input.clone());
    }

    let rectified = act::relu(input)?;
    let rectified_weight = ferrotorch_core::scalar(T::from(1.0 - negative_slope).unwrap())?;
    let identity_weight = ferrotorch_core::scalar(T::from(negative_slope).unwrap())?;

    let positive_part = arithmetic::mul(&rectified, &rectified_weight)?;
    let linear_part = arithmetic::mul(input, &identity_weight)?;
    arithmetic::add(&positive_part, &linear_part)
}
/// Sums every element of `input` into a single value (full reduction).
///
/// Delegates to `grad_fns::reduction::sum`.
#[inline]
pub fn sum<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    red::sum(input)
}
/// Arithmetic mean of every element of `input` (full reduction).
///
/// Delegates to `grad_fns::reduction::mean`.
#[inline]
pub fn mean<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    red::mean(input)
}
/// Produces a fresh non-zero seed for [`xorshift_next`].
///
/// Hashes the wall clock, the current thread id and a process-wide call
/// counter. The counter guarantees that two seeds requested in the same clock
/// tick on the same thread still differ; without it, back-to-back dropout calls
/// could receive identical masks. Zero is remapped because xorshift gets stuck
/// at zero.
fn xorshift_seed() -> u64 {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::time::SystemTime;

    // Only uniqueness of the fetched value matters, so Relaxed ordering suffices.
    static CALLS: AtomicU64 = AtomicU64::new(0);

    let mut hasher = DefaultHasher::new();
    SystemTime::now().hash(&mut hasher);
    std::thread::current().id().hash(&mut hasher);
    CALLS.fetch_add(1, Ordering::Relaxed).hash(&mut hasher);
    match hasher.finish() {
        0 => 0xdead_beef_cafe,
        state => state,
    }
}
/// Advances a xorshift64 state (shift triple 13, 7, 17) and returns a
/// uniform sample in `[0, 1)`.
///
/// `state` must be non-zero (zero is a fixed point of the recurrence).
#[inline]
fn xorshift_next(state: &mut u64) -> f64 {
    *state ^= *state << 13;
    *state ^= *state >> 7;
    *state ^= *state << 17;
    // Keep the top 53 bits (an f64 mantissa's width) and scale by 2^-53. Every
    // result is exactly representable, evenly spaced, and strictly below 1.0.
    // Dividing by `u64::MAX as f64` could round up to exactly 1.0.
    (*state >> 11) as f64 * (1.0 / (1u64 << 53) as f64)
}
/// Backward node for [`dropout`]: multiplies the incoming gradient by the same
/// per-element mask that was applied in the forward pass.
#[derive(Debug)]
struct DropoutBackward<T: Float> {
    /// The tensor dropout was applied to; its shape and device are reused for
    /// the returned gradient.
    input: Tensor<T>,
    /// Forward multiplier per element: `0` where dropped, `1 / (1 - p)` where kept.
    scaled_mask: Vec<T>,
}

impl<T: Float> GradFn<T> for DropoutBackward<T> {
    /// Returns `[grad_input]`, or `[None]` when `input` does not require a gradient.
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        if !self.input.requires_grad() {
            return Ok(vec![None]);
        }

        let upstream = grad_output.data_vec()?;
        let masked: Vec<T> = upstream
            .iter()
            .zip(self.scaled_mask.iter())
            .map(|(&grad, &factor)| grad * factor)
            .collect();

        let host_grad = Tensor::from_storage(
            TensorStorage::cpu(masked),
            self.input.shape().to_vec(),
            false,
        )?;
        // The gradient is assembled on the host; move it back if the input lives on CUDA.
        let grad = match self.input.is_cuda() {
            true => host_grad.to(self.input.device())?,
            false => host_grad,
        };
        Ok(vec![Some(grad)])
    }

    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.input]
    }

    fn name(&self) -> &'static str {
        "DropoutBackward"
    }
}
/// Inverted dropout: during training, zeroes each element with probability `p`
/// and scales survivors by `1 / (1 - p)`.
///
/// Returns `input.clone()` unchanged when `training` is false or `p == 0`.
/// The mask is drawn from a freshly seeded xorshift generator on every call,
/// so masks are not reproducible across calls. The computation runs on host
/// data, and the result is moved back to CUDA if `input` lives there.
///
/// # Errors
/// `FerrotorchError::InvalidArgument` if `p` is outside `[0, 1)` (including NaN).
pub fn dropout<T: Float>(input: &Tensor<T>, p: f64, training: bool) -> FerrotorchResult<Tensor<T>> {
    let p_is_valid = (0.0..1.0).contains(&p);
    if !p_is_valid {
        return Err(FerrotorchError::InvalidArgument {
            message: format!("dropout probability must be in [0, 1), got {p}"),
        });
    }
    if p == 0.0 || !training {
        return Ok(input.clone());
    }

    let keep_scale = T::from(1.0 / (1.0 - p)).unwrap();
    let dropped = <T as num_traits::Zero>::zero();
    let mut rng_state = xorshift_seed();
    let scaled_mask: Vec<T> = std::iter::repeat_with(|| {
        if xorshift_next(&mut rng_state) < p {
            dropped
        } else {
            keep_scale
        }
    })
    .take(input.numel())
    .collect();

    let device = input.device();
    let host_values = input.data_vec()?;
    let masked_values: Vec<T> = host_values
        .iter()
        .zip(scaled_mask.iter())
        .map(|(&value, &factor)| value * factor)
        .collect();

    let out_shape = input.shape().to_vec();
    let out_storage = TensorStorage::cpu(masked_values);
    let host_result = if is_grad_enabled() && input.requires_grad() {
        let node = Arc::new(DropoutBackward {
            input: input.clone(),
            scaled_mask,
        });
        Tensor::from_operation(out_storage, out_shape, node)?
    } else {
        Tensor::from_storage(out_storage, out_shape, false)?
    };

    match device.is_cuda() {
        true => host_result.to(device),
        false => Ok(host_result),
    }
}
/// Mean squared error `mean((pred - target)^2)` over every element.
///
/// `pred` and `target` must have identical shapes (no broadcasting). When
/// autograd is enabled and *either* operand requires a gradient, the result is
/// attached to an [`MSEBackward`] node. That node produces gradients for
/// whichever operands need them, as PyTorch's `mse_loss` does.
///
/// # Errors
/// `FerrotorchError::ShapeMismatch` if the shapes differ; otherwise propagates
/// errors from the element-wise kernels.
pub fn mse_loss<T: Float>(pred: &Tensor<T>, target: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    if pred.shape() != target.shape() {
        return Err(FerrotorchError::ShapeMismatch {
            message: format!(
                "mse_loss: pred shape {:?} != target shape {:?}",
                pred.shape(),
                target.shape(),
            ),
        });
    }
    let diff = binary_map(pred, target, |p, t| p - t)?;
    let sq = ferrotorch_core::ops::elementwise::unary_map(&diff, |x| x * x)?;
    let reduced = elem_mean(&sq)?;
    // A trainable target (e.g. a learned reference) must also get a gradient
    // path, so record the node if either side needs one.
    if is_grad_enabled() && (pred.requires_grad() || target.requires_grad()) {
        let grad_fn = Arc::new(MSEBackward {
            pred: pred.clone(),
            target: target.clone(),
        });
        Tensor::from_operation(
            TensorStorage::cpu(reduced.data_vec()?),
            reduced.shape().to_vec(),
            grad_fn,
        )
    } else {
        Ok(reduced)
    }
}

/// Backward node for [`mse_loss`].
///
/// For `L = mean((p - t)^2)` over `n` elements: `dL/dp = 2 (p - t) / n` and
/// `dL/dt = -dL/dp`, each scaled by the upstream (scalar) gradient.
#[derive(Debug)]
struct MSEBackward<T: Float> {
    pred: Tensor<T>,
    target: Tensor<T>,
}

impl<T: Float> GradFn<T> for MSEBackward<T> {
    /// Returns `[grad_pred, grad_target]`; an entry is `None` when that operand
    /// does not require a gradient.
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        let needs_pred = self.pred.requires_grad();
        let needs_target = self.target.requires_grad();
        if !needs_pred && !needs_target {
            return Ok(vec![None, None]);
        }
        let pred_data = self.pred.data_vec()?;
        let target_data = self.target.data_vec()?;
        let grad_data = grad_output.data_vec()?;
        // The loss is a single value, so only the first upstream entry is used.
        let go = *grad_data
            .first()
            .ok_or_else(|| FerrotorchError::InvalidArgument {
                message: "MSEBackward: grad_output is empty".into(),
            })?;
        let two = T::from(2.0).unwrap();
        let n = T::from(pred_data.len()).unwrap();
        let grad_pred_vals: Vec<T> = pred_data
            .iter()
            .zip(target_data.iter())
            .map(|(&p, &t)| two * (p - t) * go / n)
            .collect();

        // Build the gradient on the host, then move it to `like`'s device.
        let to_grad = |values: Vec<T>, like: &Tensor<T>| -> FerrotorchResult<Tensor<T>> {
            Tensor::from_storage(TensorStorage::cpu(values), like.shape().to_vec(), false)?
                .to(like.device())
        };

        let grad_target = if needs_target {
            Some(to_grad(
                grad_pred_vals.iter().map(|&g| -g).collect(),
                &self.target,
            )?)
        } else {
            None
        };
        let grad_pred = if needs_pred {
            Some(to_grad(grad_pred_vals, &self.pred)?)
        } else {
            None
        };
        Ok(vec![grad_pred, grad_target])
    }

    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.pred, &self.target]
    }

    fn name(&self) -> &'static str {
        "MSEBackward"
    }
}
/// Mean cross-entropy loss between 2-D `logits` `[B, C]` and class-index
/// `targets` `[B]` (indices stored as floats).
///
/// Uses a max-shifted log-sum-exp per row for numerical stability. When
/// autograd is enabled and `logits` requires a gradient, the per-row softmax is
/// saved for [`CrossEntropyBackward`].
///
/// # Errors
/// * `FerrotorchError::InvalidArgument` if `logits` is not 2-D, if `C == 0`
///   with a non-empty batch, or if a target is not a valid class index in
///   `[0, C)` (NaN, infinite, negative or too large).
/// * `FerrotorchError::ShapeMismatch` if `targets` is not `[B]`.
pub fn cross_entropy<T: Float>(
    logits: &Tensor<T>,
    targets: &Tensor<T>,
) -> FerrotorchResult<Tensor<T>> {
    let shape = logits.shape();
    if shape.len() != 2 {
        return Err(FerrotorchError::InvalidArgument {
            message: format!(
                "cross_entropy: expected 2D logits [B, C], got shape {:?}",
                shape,
            ),
        });
    }
    let batch = shape[0];
    let classes = shape[1];
    if targets.shape() != [batch] {
        return Err(FerrotorchError::ShapeMismatch {
            message: format!(
                "cross_entropy: target shape {:?} does not match batch size {}",
                targets.shape(),
                batch,
            ),
        });
    }
    // A row with no classes has no maximum and no valid target.
    if classes == 0 && batch > 0 {
        return Err(FerrotorchError::InvalidArgument {
            message: "cross_entropy: logits must have at least one class (C = 0)".into(),
        });
    }
    let logits_data = logits.data_vec()?;
    let targets_data = targets.data_vec()?;

    // Resolve every target to a class index before touching the logits.
    // `to_usize` is `None` for NaN, infinite and negative values. An index
    // >= C would otherwise read into the next row or past the buffer.
    let mut target_classes = Vec::with_capacity(batch);
    for (b, &target) in targets_data.iter().enumerate() {
        match target.to_usize() {
            Some(class) if class < classes => target_classes.push(class),
            _ => {
                return Err(FerrotorchError::InvalidArgument {
                    message: format!(
                        "cross_entropy: target at batch index {b} is {}, must be a class index in [0, {classes})",
                        target.to_f64().unwrap_or(f64::NAN),
                    ),
                });
            }
        }
    }

    let zero = <T as num_traits::Zero>::zero();
    let mut log_probs = vec![zero; batch * classes];
    let mut softmax_out = vec![zero; batch * classes];
    for b in 0..batch {
        let base = b * classes;
        let row = &logits_data[base..base + classes];
        let max_val = row[1..]
            .iter()
            .fold(row[0], |m, &x| if x > m { x } else { m });
        let mut sum_exp = zero;
        for (c, &x) in row.iter().enumerate() {
            let e = (x - max_val).exp();
            softmax_out[base + c] = e;
            sum_exp += e;
        }
        let log_sum = sum_exp.ln();
        for (c, &x) in row.iter().enumerate() {
            softmax_out[base + c] = softmax_out[base + c] / sum_exp;
            log_probs[base + c] = x - max_val - log_sum;
        }
    }

    let mut total_loss = zero;
    for (b, &class) in target_classes.iter().enumerate() {
        total_loss = total_loss - log_probs[b * classes + class];
    }
    let loss_val = total_loss / T::from(batch).unwrap();

    if is_grad_enabled() && logits.requires_grad() {
        let softmax_tensor =
            Tensor::from_storage(TensorStorage::cpu(softmax_out), vec![batch, classes], false)?;
        let grad_fn = Arc::new(CrossEntropyBackward {
            logits: logits.clone(),
            targets: targets.clone(),
            softmax: softmax_tensor,
        });
        Tensor::from_operation(TensorStorage::cpu(vec![loss_val]), vec![], grad_fn)
    } else {
        Tensor::from_storage(TensorStorage::cpu(vec![loss_val]), vec![], false)
    }
}
/// Backward node for [`cross_entropy`].
///
/// The gradient w.r.t. the logits is `(softmax - one_hot(target)) / B`,
/// scaled by the upstream scalar gradient.
#[derive(Debug)]
struct CrossEntropyBackward<T: Float> {
    /// The `[B, C]` logits passed to the forward pass.
    logits: Tensor<T>,
    /// The `[B]` class-index targets passed to the forward pass.
    targets: Tensor<T>,
    /// Row-wise softmax of `logits`, `[B, C]`, computed in the forward pass.
    softmax: Tensor<T>,
}

impl<T: Float> GradFn<T> for CrossEntropyBackward<T> {
    /// Returns `[grad_logits]`, built on the host and moved to the logits' device.
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        let dims = self.logits.shape();
        let (batch, classes) = (dims[0], dims[1]);
        let probs = self.softmax.data_vec()?;
        let labels = self.targets.data_vec()?;
        let upstream = grad_output.data_vec()?[0];

        let one = <T as num_traits::One>::one();
        let zero = <T as num_traits::Zero>::zero();
        let inv_batch = T::from(1.0).unwrap() / T::from(batch).unwrap();
        let mut grad = vec![zero; batch * classes];

        for (row, &label) in labels.iter().enumerate() {
            let hot = label.to_usize().unwrap_or(0);
            let offset = row * classes;
            for (col, slot) in grad[offset..offset + classes].iter_mut().enumerate() {
                let indicator = if col == hot { one } else { zero };
                *slot = (probs[offset + col] - indicator) * inv_batch * upstream;
            }
        }

        let host_grad = Tensor::from_storage(
            TensorStorage::cpu(grad),
            self.logits.shape().to_vec(),
            false,
        )?;
        Ok(vec![Some(host_grad.to(self.logits.device())?)])
    }

    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.logits]
    }

    fn name(&self) -> &'static str {
        "CrossEntropyBackward"
    }
}
pub use crate::upsample::{GridSampleMode, GridSamplePaddingMode, InterpolateMode};
pub fn interpolate<T: Float>(
input: &Tensor<T>,
size: Option<[usize; 2]>,
scale_factor: Option<[f64; 2]>,
mode: InterpolateMode,
align_corners: bool,
) -> FerrotorchResult<Tensor<T>> {
crate::upsample::interpolate(input, size, scale_factor, mode, align_corners)
}
pub fn grid_sample<T: Float>(
input: &Tensor<T>,
grid: &Tensor<T>,
mode: GridSampleMode,
padding_mode: GridSamplePaddingMode,
align_corners: bool,
) -> FerrotorchResult<Tensor<T>> {
crate::upsample::grid_sample(input, grid, mode, padding_mode, align_corners)
}
pub fn affine_grid<T: Float>(
theta: &Tensor<T>,
size: [usize; 4],
align_corners: bool,
) -> FerrotorchResult<Tensor<T>> {
crate::upsample::affine_grid(theta, size, align_corners)
}
pub fn pixel_shuffle<T: Float>(
input: &Tensor<T>,
upscale_factor: usize,
) -> FerrotorchResult<Tensor<T>> {
crate::upsample::pixel_shuffle(input, upscale_factor)
}
pub fn pixel_unshuffle<T: Float>(
input: &Tensor<T>,
downscale_factor: usize,
) -> FerrotorchResult<Tensor<T>> {
crate::upsample::pixel_unshuffle(input, downscale_factor)
}
pub fn unfold<T: Float>(
input: &Tensor<T>,
kernel_size: [usize; 2],
dilation: [usize; 2],
padding: [usize; 2],
stride: [usize; 2],
) -> FerrotorchResult<Tensor<T>> {
crate::upsample::unfold(input, kernel_size, dilation, padding, stride)
}
pub fn fold<T: Float>(
input: &Tensor<T>,
output_size: [usize; 2],
kernel_size: [usize; 2],
dilation: [usize; 2],
padding: [usize; 2],
stride: [usize; 2],
) -> FerrotorchResult<Tensor<T>> {
crate::upsample::fold(input, output_size, kernel_size, dilation, padding, stride)
}
/// Hard tanh with the default bounds: clamps every element to `[-1, 1]`.
#[inline]
pub fn hardtanh<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    let lower = T::from(-1.0).unwrap();
    let upper = T::from(1.0).unwrap();
    hardtanh_with(input, lower, upper)
}
/// Hard tanh with custom bounds: clamps every element to `[min_val, max_val]`
/// via `trans::clamp`.
///
/// NOTE(review): no check is made that `min_val <= max_val`; behaviour for
/// inverted bounds is whatever `trans::clamp` does.
pub fn hardtanh_with<T: Float>(
    input: &Tensor<T>,
    min_val: T,
    max_val: T,
) -> FerrotorchResult<Tensor<T>> {
    let (lo, hi) = (min_val, max_val);
    trans::clamp(input, lo, hi)
}
/// ReLU capped at six: clamps every element to `[0, 6]`.
#[inline]
pub fn relu6<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    let floor = T::from(0.0).unwrap();
    let ceiling = T::from(6.0).unwrap();
    trans::clamp(input, floor, ceiling)
}
/// Piecewise-linear sigmoid approximation `clamp((x + 3) / 6, 0, 1)`.
pub fn hardsigmoid<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    let offset = ferrotorch_core::scalar(T::from(3.0).unwrap())?;
    let sixth = ferrotorch_core::scalar(T::from(1.0 / 6.0).unwrap())?;
    let ramp = arithmetic::mul(&arithmetic::add(input, &offset)?, &sixth)?;
    trans::clamp(&ramp, T::from(0.0).unwrap(), T::from(1.0).unwrap())
}
/// Hard swish `x * hardsigmoid(x)`: zero for `x <= -3`, identity for `x >= 3`.
pub fn hardswish<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    let gate = hardsigmoid(input)?;
    arithmetic::mul(input, &gate)
}
/// `log(sigmoid(x))`, computed as `-softplus(-x)` (beta = 1, threshold = 20).
///
/// This avoids taking the log of a sigmoid that has underflowed to zero.
pub fn log_sigmoid<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    let softplus_of_neg = act::softplus(&arithmetic::neg(input)?, 1.0, 20.0)?;
    arithmetic::neg(&softplus_of_neg)
}
/// Softmin: `softmax(-x)`, along the same axis `act::softmax` uses.
pub fn softmin<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    act::softmax(&arithmetic::neg(input)?)
}
/// Softsign `x / (1 + |x|)`; outputs lie strictly inside `(-1, 1)`.
pub fn softsign<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    let magnitude = arithmetic::abs(input)?;
    let one = ferrotorch_core::scalar(T::from(1.0).unwrap())?;
    arithmetic::div(input, &arithmetic::add(&magnitude, &one)?)
}
/// Tanh shrink `x - tanh(x)`.
pub fn tanhshrink<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    arithmetic::sub(input, &act::tanh(input)?)
}
/// Scaled ELU: `SCALE * elu(x, ALPHA)` with the self-normalising constants
/// from Klambauer et al. (2017).
pub fn selu<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    const SELU_ALPHA: f64 = 1.6732632423543772;
    const SELU_SCALE: f64 = 1.0507009873554805;
    let elu_out = act::elu(input, SELU_ALPHA)?;
    let gain = ferrotorch_core::scalar(T::from(SELU_SCALE).unwrap())?;
    arithmetic::mul(&elu_out, &gain)
}
#[inline]
pub fn softplus<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
act::softplus(input, 1.0, 20.0)
}
/// Softplus with explicit `beta` and `threshold`, both forwarded unchanged to
/// `act::softplus`.
#[inline]
pub fn softplus_with<T: Float>(
    input: &Tensor<T>,
    beta: f64,
    threshold: f64,
) -> FerrotorchResult<Tensor<T>> {
    act::softplus(input, beta, threshold)
}
#[inline]
pub fn elu<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
act::elu(input, 1.0)
}
/// ELU with an explicit `alpha`, forwarded unchanged to `act::elu`.
#[inline]
pub fn elu_with<T: Float>(input: &Tensor<T>, alpha: f64) -> FerrotorchResult<Tensor<T>> {
    act::elu(input, alpha)
}
/// Mish activation, delegated to `act::mish`.
#[inline]
pub fn mish<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    act::mish(input)
}
/// Gated linear unit along `dim` (forwarded unchanged to `act::glu`; negative
/// values presumably count from the last axis — confirm there).
pub fn glu<T: Float>(input: &Tensor<T>, dim: i64) -> FerrotorchResult<Tensor<T>> {
    act::glu(input, dim)
}
/// Parametric ReLU with learnable negative slope(s) `alpha`, delegated to
/// `act::prelu`.
pub fn prelu<T: Float>(input: &Tensor<T>, alpha: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    act::prelu(input, alpha)
}
fn apply_reduction<T: Float>(
loss: Tensor<T>,
reduction: crate::module::Reduction,
) -> FerrotorchResult<Tensor<T>> {
match reduction {
crate::module::Reduction::None => Ok(loss),
crate::module::Reduction::Sum => red::sum(&loss),
crate::module::Reduction::Mean => red::mean(&loss),
}
}
/// L1 (mean absolute error) loss: `|pred - target|`, then `reduction`.
///
/// NOTE(review): no explicit shape check. Mismatched shapes are handled (or
/// rejected) by `arithmetic::sub`'s broadcasting rules.
pub fn l1_loss<T: Float>(
    pred: &Tensor<T>,
    target: &Tensor<T>,
    reduction: crate::module::Reduction,
) -> FerrotorchResult<Tensor<T>> {
    let residual = arithmetic::sub(pred, target)?;
    apply_reduction(arithmetic::abs(&residual)?, reduction)
}
/// Binary cross-entropy on probabilities:
/// `-(y * log(p) + (1 - y) * log(1 - p))`, followed by `reduction`.
///
/// Each log term is clamped to be `>= -100`, matching PyTorch's `BCELoss`.
/// Without the clamp, a prediction of exactly `0` or `1` produces
/// `0 * log(0) = 0 * -inf = NaN`, poisoning the whole reduced loss even when
/// the prediction is perfect.
///
/// NOTE(review): the gradient at exactly `p = 0` or `p = 1` still passes
/// through `log`'s backward at an infinite value. Whether it comes out finite
/// depends on how `trans::clamp`/`trans::log` backward treat infinities.
pub fn binary_cross_entropy<T: Float>(
    pred: &Tensor<T>,
    target: &Tensor<T>,
    reduction: crate::module::Reduction,
) -> FerrotorchResult<Tensor<T>> {
    let log_floor = T::from(-100.0).unwrap();
    let log_ceil = T::from(f64::INFINITY).unwrap();
    let one = ferrotorch_core::scalar(T::from(1.0).unwrap())?;
    let log_p = trans::clamp(&trans::log(pred)?, log_floor, log_ceil)?;
    let one_minus_p = arithmetic::sub(&one, pred)?;
    let log_1mp = trans::clamp(&trans::log(&one_minus_p)?, log_floor, log_ceil)?;
    let one_minus_y = arithmetic::sub(&one, target)?;
    let term1 = arithmetic::mul(target, &log_p)?;
    let term2 = arithmetic::mul(&one_minus_y, &log_1mp)?;
    let sum = arithmetic::add(&term1, &term2)?;
    let neg = arithmetic::neg(&sum)?;
    apply_reduction(neg, reduction)
}
/// Binary cross-entropy on raw logits, computed as `-y * x + softplus(x)`.
///
/// This is algebraically equal to `-(y log σ(x) + (1 - y) log(1 - σ(x)))`, but
/// it never forms `σ(x)` explicitly, so no log of an underflowed probability
/// occurs. `reduction` is applied to the elementwise result.
pub fn binary_cross_entropy_with_logits<T: Float>(
    logits: &Tensor<T>,
    target: &Tensor<T>,
    reduction: crate::module::Reduction,
) -> FerrotorchResult<Tensor<T>> {
    let negated_target = arithmetic::neg(target)?;
    let cross_term = arithmetic::mul(&negated_target, logits)?;
    let log_partition = act::softplus(logits, 1.0, 20.0)?;
    apply_reduction(arithmetic::add(&cross_term, &log_partition)?, reduction)
}
/// KL divergence `target * (log(target + 1e-12) - pred)`, then `reduction`.
///
/// As the formula shows, `pred` is treated as *log*-probabilities while
/// `target` holds plain probabilities. The `1e-12` offset keeps `log(0)` finite
/// for zero targets, whose contribution is then `0`.
pub fn kl_div<T: Float>(
    pred: &Tensor<T>,
    target: &Tensor<T>,
    reduction: crate::module::Reduction,
) -> FerrotorchResult<Tensor<T>> {
    let offset = ferrotorch_core::scalar(T::from(1e-12_f64).unwrap())?;
    let log_target = trans::log(&arithmetic::add(target, &offset)?)?;
    let log_ratio = arithmetic::sub(&log_target, pred)?;
    apply_reduction(arithmetic::mul(target, &log_ratio)?, reduction)
}
/// Lp-normalises `input` along `dim`: `x / max(||x||_p, eps)`.
///
/// The norm is `(sum |x|^p)^(1/p)`, reduced with `keepdim = true` so it
/// broadcasts back against `input`.
///
/// NOTE(review): for an all-zero slice the gradient of `pow(·, 1/p)` at `0` is
/// infinite, and the clamp passes back a zero, so gradients may become NaN
/// (`0 * inf`). Also, `p = inf` is not handled specially.
pub fn normalize<T: Float>(
    input: &Tensor<T>,
    p: f64,
    dim: i64,
    eps: f64,
) -> FerrotorchResult<Tensor<T>> {
    let powered = arithmetic::pow(&arithmetic::abs(input)?, p)?;
    let norm = arithmetic::pow(&red::sum_dim(&powered, dim, true)?, 1.0 / p)?;
    let floor = T::from(eps).unwrap();
    let safe_norm = trans::clamp(&norm, floor, T::from(f64::INFINITY).unwrap())?;
    arithmetic::div(input, &safe_norm)
}
/// Cosine similarity along `dim`: `x·y / max(||x|| * ||y||, eps)`.
///
/// The reduced dimension is dropped (`keepdim = false`).
///
/// NOTE(review): `sqrt` of a zero squared norm has an infinite derivative, so
/// gradients for an all-zero slice may come out NaN.
pub fn cosine_similarity<T: Float>(
    x: &Tensor<T>,
    y: &Tensor<T>,
    dim: i64,
    eps: f64,
) -> FerrotorchResult<Tensor<T>> {
    let l2_norm = |t: &Tensor<T>| -> FerrotorchResult<Tensor<T>> {
        let squares = arithmetic::mul(t, t)?;
        arithmetic::sqrt(&red::sum_dim(&squares, dim, false)?)
    };
    let dot = red::sum_dim(&arithmetic::mul(x, y)?, dim, false)?;
    let norm_x = l2_norm(x)?;
    let norm_y = l2_norm(y)?;
    let norm_product = arithmetic::mul(&norm_x, &norm_y)?;
    let floor = T::from(eps).unwrap();
    let denom = trans::clamp(&norm_product, floor, T::from(f64::INFINITY).unwrap())?;
    arithmetic::div(&dot, &denom)
}
/// p-norm distance between `x` and `y` over the last dimension:
/// `(sum (|x - y| + eps)^p)^(1/p)`.
///
/// `eps` is added to the absolute difference (not the signed difference), and
/// the last dimension is dropped from the result.
pub fn pairwise_distance<T: Float>(
    x: &Tensor<T>,
    y: &Tensor<T>,
    p: f64,
    eps: f64,
) -> FerrotorchResult<Tensor<T>> {
    let gap = arithmetic::abs(&arithmetic::sub(x, y)?)?;
    let eps_t = ferrotorch_core::scalar(T::from(eps).unwrap())?;
    let powered = arithmetic::pow(&arithmetic::add(&gap, &eps_t)?, p)?;
    let total = red::sum_dim(&powered, -1, false)?;
    arithmetic::pow(&total, 1.0 / p)
}
/// One-hot encodes the (float-stored) class indices in `input`.
///
/// The output has shape `input.shape() + [num_classes]`, lives on the CPU and
/// does not require a gradient. Indices are rounded to the nearest integer.
///
/// # Errors
/// `FerrotorchError::InvalidArgument` if `num_classes == 0`, or if any index is
/// non-finite, negative, or rounds to a value `>= num_classes`.
pub fn one_hot<T: Float>(input: &Tensor<T>, num_classes: usize) -> FerrotorchResult<Tensor<T>> {
    if num_classes == 0 {
        return Err(FerrotorchError::InvalidArgument {
            message: "one_hot: num_classes must be > 0".into(),
        });
    }
    let indices = input.data_vec()?;
    let out_shape: Vec<usize> = input
        .shape()
        .iter()
        .copied()
        .chain(std::iter::once(num_classes))
        .collect();
    let zero = T::from(0.0).unwrap();
    let one = T::from(1.0).unwrap();
    let mut encoded = vec![zero; indices.len() * num_classes];

    for (i, (val, row)) in indices
        .iter()
        .zip(encoded.chunks_exact_mut(num_classes))
        .enumerate()
    {
        let f = val.to_f64().unwrap_or(-1.0);
        if !f.is_finite() || f < 0.0 {
            return Err(FerrotorchError::InvalidArgument {
                message: format!(
                    "one_hot: index at flat position {i} is {f}, must be in [0, {num_classes})"
                ),
            });
        }
        let idx = f.round() as usize;
        if idx >= num_classes {
            return Err(FerrotorchError::InvalidArgument {
                message: format!("one_hot: index {idx} out of range (num_classes = {num_classes})"),
            });
        }
        row[idx] = one;
    }

    Tensor::from_storage(TensorStorage::cpu(encoded), out_shape, false)
}
use crate::conv::{Conv1d, Conv2d, Conv3d, ConvTranspose1d, ConvTranspose2d, ConvTranspose3d};
use crate::module::Module;
/// Functional 1-D convolution: builds a temporary [`Conv1d`] from `weight` /
/// `bias` via `Conv1d::from_parts` and runs its `forward` on `input`.
///
/// NOTE(review): whether gradients reach the caller's `weight`/`bias` depends
/// on `from_parts` keeping those tensors rather than copying them. Confirm.
pub fn conv1d<T: Float>(
    input: &Tensor<T>,
    weight: &Tensor<T>,
    bias: Option<&Tensor<T>>,
    stride: usize,
    padding: usize,
) -> FerrotorchResult<Tensor<T>> {
    Conv1d::from_parts(weight.clone(), bias.cloned(), stride, padding)?.forward(input)
}
/// Functional 2-D convolution: builds a temporary [`Conv2d`] from `weight` /
/// `bias` via `Conv2d::from_parts` and runs its `forward` on `input`.
///
/// `stride` and `padding` are per-spatial-axis pairs passed through unchanged.
pub fn conv2d<T: Float>(
    input: &Tensor<T>,
    weight: &Tensor<T>,
    bias: Option<&Tensor<T>>,
    stride: (usize, usize),
    padding: (usize, usize),
) -> FerrotorchResult<Tensor<T>> {
    Conv2d::from_parts(weight.clone(), bias.cloned(), stride, padding)?.forward(input)
}
/// Functional 3-D convolution: builds a temporary [`Conv3d`] from `weight` /
/// `bias` via `Conv3d::from_parts` and runs its `forward` on `input`.
///
/// `stride` and `padding` are per-spatial-axis triples passed through unchanged.
pub fn conv3d<T: Float>(
    input: &Tensor<T>,
    weight: &Tensor<T>,
    bias: Option<&Tensor<T>>,
    stride: (usize, usize, usize),
    padding: (usize, usize, usize),
) -> FerrotorchResult<Tensor<T>> {
    Conv3d::from_parts(weight.clone(), bias.cloned(), stride, padding)?.forward(input)
}
/// Functional 1-D transposed convolution: builds a temporary
/// [`ConvTranspose1d`] via `ConvTranspose1d::from_parts` and runs `forward`.
pub fn conv_transpose1d<T: Float>(
    input: &Tensor<T>,
    weight: &Tensor<T>,
    bias: Option<&Tensor<T>>,
    stride: usize,
    padding: usize,
    output_padding: usize,
) -> FerrotorchResult<Tensor<T>> {
    let owned_weight = weight.clone();
    let owned_bias = bias.cloned();
    ConvTranspose1d::from_parts(owned_weight, owned_bias, stride, padding, output_padding)?
        .forward(input)
}
/// Functional 2-D transposed convolution: builds a temporary
/// [`ConvTranspose2d`] via `ConvTranspose2d::from_parts` and runs `forward`.
pub fn conv_transpose2d<T: Float>(
    input: &Tensor<T>,
    weight: &Tensor<T>,
    bias: Option<&Tensor<T>>,
    stride: (usize, usize),
    padding: (usize, usize),
    output_padding: (usize, usize),
) -> FerrotorchResult<Tensor<T>> {
    let owned_weight = weight.clone();
    let owned_bias = bias.cloned();
    ConvTranspose2d::from_parts(owned_weight, owned_bias, stride, padding, output_padding)?
        .forward(input)
}
/// Functional 3-D transposed convolution: builds a temporary
/// [`ConvTranspose3d`] via `ConvTranspose3d::from_parts` and runs `forward`.
pub fn conv_transpose3d<T: Float>(
    input: &Tensor<T>,
    weight: &Tensor<T>,
    bias: Option<&Tensor<T>>,
    stride: (usize, usize, usize),
    padding: (usize, usize, usize),
    output_padding: (usize, usize, usize),
) -> FerrotorchResult<Tensor<T>> {
    let owned_weight = weight.clone();
    let owned_bias = bias.cloned();
    ConvTranspose3d::from_parts(owned_weight, owned_bias, stride, padding, output_padding)?
        .forward(input)
}
pub use crate::pooling::{
adaptive_avg_pool1d, adaptive_avg_pool2d, adaptive_avg_pool3d, adaptive_max_pool1d,
adaptive_max_pool2d, adaptive_max_pool3d, avg_pool1d, avg_pool2d, avg_pool3d, lp_pool1d,
lp_pool2d, max_pool1d, max_pool2d, max_pool3d,
};
pub use crate::padding::{
PaddingMode, functional_pad_1d as pad1d, functional_pad_2d as pad2d, functional_pad_3d as pad3d,
};
/// Functional embedding lookup: wraps `weight` in a temporary
/// `Embedding::from_pretrained(weight.clone(), padding_idx)` and runs `forward`
/// on the index tensor `input`.
///
/// NOTE(review): if `from_pretrained` freezes the table (as PyTorch's does by
/// default), no gradient reaches `weight`. Confirm before training through this.
pub fn embedding<T: Float>(
    input: &Tensor<T>,
    weight: &Tensor<T>,
    padding_idx: Option<usize>,
) -> FerrotorchResult<Tensor<T>> {
    crate::embedding::Embedding::from_pretrained(weight.clone(), padding_idx)?.forward(input)
}
/// Scaled dot-product attention, delegated to
/// `crate::flash_attention::flash_attention`.
///
/// No attention mask, dropout or custom scale is exposed; `is_causal` is the
/// only option. The literal `64` is forwarded as flash_attention's last
/// argument (presumably its tile/block size — confirm).
pub fn scaled_dot_product_attention<T: Float>(
    query: &Tensor<T>,
    key: &Tensor<T>,
    value: &Tensor<T>,
    is_causal: bool,
) -> FerrotorchResult<Tensor<T>> {
    const BLOCK: usize = 64;
    crate::flash_attention::flash_attention(query, key, value, is_causal, BLOCK)
}
#[cfg(test)]
mod tests {
use super::*;
use ferrotorch_core::TensorStorage;
fn leaf(data: &[f32], shape: &[usize], requires_grad: bool) -> Tensor<f32> {
Tensor::from_storage(
TensorStorage::cpu(data.to_vec()),
shape.to_vec(),
requires_grad,
)
.unwrap()
}
fn assert_close(actual: &[f32], expected: &[f32], tol: f32) {
assert_eq!(
actual.len(),
expected.len(),
"length mismatch: {} vs {}",
actual.len(),
expected.len(),
);
for (i, (&a, &e)) in actual.iter().zip(expected.iter()).enumerate() {
assert!(
(a - e).abs() < tol,
"index {i}: actual={a} expected={e} diff={}",
(a - e).abs(),
);
}
}
#[test]
fn test_linear_no_bias() {
let weight = leaf(&[1.0, 0.0, 0.0, 0.0, 1.0, 0.0], &[2, 3], false);
let input = leaf(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], false);
let output = linear(&input, &weight, None).unwrap();
assert_eq!(output.shape(), &[2, 2]);
assert_close(output.data().unwrap(), &[1.0, 2.0, 4.0, 5.0], 1e-6);
}
#[test]
fn test_linear_with_bias() {
let weight = leaf(&[1.0, 0.0, 0.0, 1.0], &[2, 2], false);
let bias = leaf(&[10.0, 20.0], &[2], false);
let input = leaf(&[1.0, 2.0, 3.0, 4.0], &[2, 2], false);
let output = linear(&input, &weight, Some(&bias)).unwrap();
assert_eq!(output.shape(), &[2, 2]);
assert_close(output.data().unwrap(), &[11.0, 22.0, 13.0, 24.0], 1e-6);
}
#[test]
fn test_linear_matches_module() {
use crate::linear::Linear;
use crate::module::Module;
use crate::parameter::Parameter;
let mut layer = Linear::<f32>::new(3, 2, true).unwrap();
layer.weight = Parameter::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();
*layer.bias.as_mut().unwrap() = Parameter::from_slice(&[0.1, 0.2], &[2]).unwrap();
let input = leaf(&[1.0, 0.0, -1.0, 2.0, 1.0, 0.0], &[2, 3], false);
let module_out = layer.forward(&input).unwrap();
let func_out = linear(
&input,
layer.weight.tensor(),
Some(layer.bias.as_ref().unwrap().tensor()),
)
.unwrap();
assert_eq!(module_out.shape(), func_out.shape());
assert_close(module_out.data().unwrap(), func_out.data().unwrap(), 1e-5);
}
#[test]
fn test_linear_wrong_input_dims() {
let weight = leaf(&[1.0; 6], &[2, 3], false);
let input_1d = leaf(&[1.0, 2.0, 3.0], &[3], false);
assert!(linear(&input_1d, &weight, None).is_err());
}
#[test]
fn test_linear_wrong_weight_dims() {
let weight = leaf(&[1.0; 6], &[6], false);
let input = leaf(&[1.0; 6], &[2, 3], false);
assert!(linear(&input, &weight, None).is_err());
}
#[test]
fn test_linear_feature_mismatch() {
let weight = leaf(&[1.0; 8], &[2, 4], false);
let input = leaf(&[1.0; 6], &[2, 3], false);
assert!(linear(&input, &weight, None).is_err());
}
#[test]
fn test_linear_bias_shape_mismatch() {
let weight = leaf(&[1.0; 6], &[2, 3], false);
let bias = leaf(&[1.0; 3], &[3], false); let input = leaf(&[1.0; 6], &[2, 3], false);
assert!(linear(&input, &weight, Some(&bias)).is_err());
}
#[test]
fn test_relu_matches_core() {
let input = leaf(&[-2.0, -1.0, 0.0, 1.0, 2.0], &[5], false);
let func_out = relu(&input).unwrap();
let core_out = act::relu(&input).unwrap();
assert_close(func_out.data().unwrap(), core_out.data().unwrap(), 1e-7);
}
#[test]
fn test_relu_values() {
let input = leaf(&[-3.0, -1.0, 0.0, 0.5, 2.0], &[5], false);
let output = relu(&input).unwrap();
assert_close(output.data().unwrap(), &[0.0, 0.0, 0.0, 0.5, 2.0], 1e-7);
}
#[test]
fn test_sigmoid_values() {
let input = leaf(&[0.0], &[1], false);
let output = sigmoid(&input).unwrap();
assert!((output.data().unwrap()[0] - 0.5).abs() < 1e-6);
}
#[test]
fn test_tanh_values() {
let input = leaf(&[0.0], &[1], false);
let output = tanh(&input).unwrap();
assert!(output.data().unwrap()[0].abs() < 1e-6);
}
#[test]
fn test_gelu_positive() {
let input = leaf(&[1.0, 2.0], &[2], false);
let output = gelu(&input).unwrap();
let d = output.data().unwrap();
assert!(d[0] > 0.0);
assert!(d[1] > 0.0);
}
#[test]
fn test_silu_zero() {
let input = leaf(&[0.0], &[1], false);
let output = silu(&input).unwrap();
assert!(output.data().unwrap()[0].abs() < 1e-6);
}
#[test]
fn test_softmax_sums_to_one() {
let input = leaf(&[1.0, 2.0, 3.0], &[3], false);
let output = softmax(&input).unwrap();
let d = output.data().unwrap();
let total: f32 = d.iter().sum();
assert!((total - 1.0).abs() < 1e-5);
}
#[test]
fn test_log_softmax_negative() {
let input = leaf(&[1.0, 2.0, 3.0], &[3], false);
let output = log_softmax(&input).unwrap();
let d = output.data().unwrap();
assert!(d.iter().all(|&v| v <= 0.0));
}
#[test]
fn test_leaky_relu_values() {
let input = leaf(&[-2.0, -1.0, 0.0, 1.0, 2.0], &[5], false);
let output = leaky_relu(&input, 0.01).unwrap();
let d = output.data().unwrap();
assert!((d[0] - (-0.02)).abs() < 1e-5);
assert!((d[1] - (-0.01)).abs() < 1e-5);
assert!((d[2] - 0.0).abs() < 1e-5);
assert!((d[3] - 1.0).abs() < 1e-5);
assert!((d[4] - 2.0).abs() < 1e-5);
}
#[test]
fn test_leaky_relu_zero_slope_is_relu() {
let input = leaf(&[-2.0, 0.0, 3.0], &[3], false);
let lrelu_out = leaky_relu(&input, 0.0).unwrap();
let relu_out = relu(&input).unwrap();
assert_close(lrelu_out.data().unwrap(), relu_out.data().unwrap(), 1e-7);
}
#[test]
fn test_leaky_relu_one_slope_is_identity() {
let input = leaf(&[-2.0, 0.0, 3.0], &[3], false);
let output = leaky_relu(&input, 1.0).unwrap();
assert_close(output.data().unwrap(), &[-2.0, 0.0, 3.0], 1e-7);
}
#[test]
fn test_sum_values() {
let input = leaf(&[1.0, 2.0, 3.0, 4.0], &[4], false);
let output = sum(&input).unwrap();
assert!((output.item().unwrap() - 10.0).abs() < 1e-6);
}
#[test]
fn test_mean_values() {
let input = leaf(&[1.0, 2.0, 3.0, 4.0], &[4], false);
let output = mean(&input).unwrap();
assert!((output.item().unwrap() - 2.5).abs() < 1e-6);
}
#[test]
fn test_dropout_eval_is_identity() {
let input = leaf(&[1.0, 2.0, 3.0, 4.0, 5.0], &[5], false);
let output = dropout(&input, 0.5, false).unwrap();
assert!(output.is_same(&input));
}
#[test]
fn test_dropout_zero_prob_is_identity() {
let input = leaf(&[1.0, 2.0, 3.0], &[3], false);
let output = dropout(&input, 0.0, true).unwrap();
assert!(output.is_same(&input));
}
#[test]
fn test_dropout_invalid_p() {
let input = leaf(&[1.0], &[1], false);
assert!(dropout(&input, 1.0, true).is_err());
assert!(dropout(&input, -0.1, true).is_err());
assert!(dropout(&input, 1.5, true).is_err());
}
#[test]
fn test_dropout_rate_approximately_correct() {
let input = ferrotorch_core::ones::<f32>(&[100_000]).unwrap();
let output = dropout(&input, 0.5, true).unwrap();
let data = output.data().unwrap();
let zeros = data.iter().filter(|&&x| x == 0.0).count();
let rate = zeros as f64 / data.len() as f64;
assert!(
(rate - 0.5).abs() < 0.05,
"dropout rate = {rate}, expected ~0.5"
);
let non_zero: Vec<f32> = data.iter().copied().filter(|&x| x != 0.0).collect();
assert!(!non_zero.is_empty());
for &v in &non_zero {
assert!(
(v - 2.0).abs() < 1e-6,
"surviving element = {v}, expected 2.0"
);
}
}
#[test]
fn test_dropout_training_flag() {
let input = ferrotorch_core::ones::<f32>(&[1000]).unwrap();
let output = dropout(&input, 0.99, false).unwrap();
assert!(output.is_same(&input));
}
#[test]
fn test_mse_loss_zero() {
let pred = leaf(&[1.0, 2.0, 3.0], &[3], false);
let target = leaf(&[1.0, 2.0, 3.0], &[3], false);
let loss = mse_loss(&pred, &target).unwrap();
assert!(loss.item().unwrap().abs() < 1e-7);
}
#[test]
fn test_mse_loss_known_value() {
let pred = leaf(&[1.0, 2.0], &[2], false);
let target = leaf(&[3.0, 4.0], &[2], false);
let loss = mse_loss(&pred, &target).unwrap();
assert!((loss.item().unwrap() - 4.0).abs() < 1e-6);
}
#[test]
fn test_mse_loss_shape_mismatch() {
let pred = leaf(&[1.0, 2.0], &[2], false);
let target = leaf(&[1.0, 2.0, 3.0], &[3], false);
assert!(mse_loss(&pred, &target).is_err());
}
#[test]
fn test_cross_entropy_basic() {
let logits = leaf(&[10.0, 0.0, 0.0, 10.0], &[2, 2], false);
let targets = leaf(&[0.0, 1.0], &[2], false);
let loss = cross_entropy(&logits, &targets).unwrap();
assert!(loss.item().unwrap() < 0.01);
}
#[test]
fn test_cross_entropy_wrong_logits_shape() {
let logits = leaf(&[1.0, 2.0, 3.0], &[3], false);
let targets = leaf(&[0.0], &[1], false);
assert!(cross_entropy(&logits, &targets).is_err());
}
#[test]
fn test_cross_entropy_target_batch_mismatch() {
let logits = leaf(&[1.0, 2.0, 3.0, 4.0], &[2, 2], false);
let targets = leaf(&[0.0, 1.0, 0.0], &[3], false);
assert!(cross_entropy(&logits, &targets).is_err());
}
#[test]
fn test_cross_entropy_uniform_logits() {
let logits = leaf(&[0.0, 0.0, 0.0, 0.0], &[2, 2], false);
let targets = leaf(&[0.0, 1.0], &[2], false);
let loss = cross_entropy(&logits, &targets).unwrap();
let expected = (2.0_f32).ln();
assert!(
(loss.item().unwrap() - expected).abs() < 1e-5,
"loss = {}, expected ln(2) = {}",
loss.item().unwrap(),
expected,
);
}
fn close(a: f32, b: f32, tol: f32) -> bool {
(a - b).abs() < tol
}
#[test]
fn test_hardtanh_default_clamps_to_minus_one_one() {
let x = leaf(&[-3.0, -1.0, 0.5, 1.5], &[4], false);
let out = hardtanh(&x).unwrap();
let d = out.data().unwrap();
assert_eq!(d, &[-1.0, -1.0, 0.5, 1.0]);
}
#[test]
fn test_hardtanh_with_custom_bounds() {
let x = leaf(&[-3.0, -0.5, 1.0, 5.0], &[4], false);
let out = hardtanh_with(&x, -2.0, 3.0).unwrap();
let d = out.data().unwrap();
assert_eq!(d, &[-2.0, -0.5, 1.0, 3.0]);
}
#[test]
fn test_relu6_clamps_top_at_6() {
let x = leaf(&[-1.0, 0.0, 3.0, 7.0], &[4], false);
let out = relu6(&x).unwrap();
let d = out.data().unwrap();
assert_eq!(d, &[0.0, 0.0, 3.0, 6.0]);
}
/// `hardsigmoid` is 0 at -5 and -3, 0.5 at the origin, and 1 at 3 and 5.
#[test]
fn test_hardsigmoid_endpoints() {
    let input = leaf(&[-5.0, -3.0, 0.0, 3.0, 5.0], &[5], false);
    let out = hardsigmoid(&input).unwrap();
    let values = out.data().unwrap();
    let expected: [f32; 5] = [0.0, 0.0, 0.5, 1.0, 1.0];
    for (i, &want) in expected.iter().enumerate() {
        assert!(close(values[i], want, 1e-6));
    }
}
/// `hardswish` is 0 at and below -3 (and at the origin) and the identity at 5.
/// On the ramp between -3 and 3 it follows the reference (PyTorch) definition
/// `x * relu6(x + 3) / 6`, which is asserted at x = 1.
#[test]
fn test_hardswish_zero_at_minus_three_and_below() {
    let x = leaf(&[-5.0, -3.0, 0.0, 1.0, 5.0], &[5], false);
    let out = hardswish(&x).unwrap();
    let d = out.data().unwrap();
    assert!(close(d[0], 0.0, 1e-6));
    assert!(close(d[1], 0.0, 1e-6));
    assert!(close(d[2], 0.0, 1e-6));
    // x = 1 lies on the non-linear ramp: 1 * (1 + 3) / 6 = 2/3.
    assert!(close(d[3], 4.0 / 6.0, 1e-5), "hardswish(1) = {}", d[3]);
    assert!(close(d[4], 5.0, 1e-6));
}
/// `log_sigmoid(x)` must agree with `ln(1 / (1 + e^-x))` computed directly in f32.
#[test]
fn test_log_sigmoid_matches_log_of_sigmoid() {
    // One array feeds both the tensor and the reference computation.
    const XS: [f32; 5] = [-2.0, -0.5, 0.0, 0.5, 2.0];
    let out = log_sigmoid(&leaf(&XS, &[5], false)).unwrap();
    let d = out.data().unwrap();
    for (i, xi) in XS.iter().copied().enumerate() {
        let sigmoid = 1.0_f32 / (1.0 + (-xi).exp());
        let ref_val = sigmoid.ln();
        assert!(
            close(d[i], ref_val, 1e-5),
            "log_sigmoid({xi}) = {} vs {ref_val}",
            d[i]
        );
    }
}
/// `softmin(x)` should equal `softmax(-x)` element by element.
#[test]
fn test_softmin_inverts_softmax() {
    let reference = softmax(&leaf(&[-1.0, -2.0, -3.0], &[3], false)).unwrap();
    let actual = softmin(&leaf(&[1.0, 2.0, 3.0], &[3], false)).unwrap();
    let (got, want) = (actual.data().unwrap(), reference.data().unwrap());
    assert!((0..3).all(|i| close(got[i], want[i], 1e-6)));
}
/// `softsign` stays bounded. The expected values are ±0.5 at ±1 and 0 at 0, with
/// ±1000 mapping to within 1e-2 of ±1 (consistent with `x / (1 + |x|)`).
#[test]
fn test_softsign_bounded() {
    let x = leaf(&[-1000.0, -1.0, 0.0, 1.0, 1000.0], &[5], false);
    let out = softsign(&x).unwrap();
    let d = out.data().unwrap();
    // (expected, tolerance) per element; the ±1000 entries only approach ±1.
    let checks: [(f32, f32); 5] = [
        (-1.0, 1e-2),
        (-0.5, 1e-6),
        (0.0, 1e-6),
        (0.5, 1e-6),
        (1.0, 1e-2),
    ];
    for (i, &(want, tol)) in checks.iter().enumerate() {
        assert!(close(d[i], want, tol));
    }
}
/// `tanhshrink(x)` is checked against the closed form `x - tanh(x)` at 0, 1 and 2.
#[test]
fn test_tanhshrink() {
    let out = tanhshrink(&leaf(&[0.0, 1.0, 2.0], &[3], false)).unwrap();
    let d = out.data().unwrap();
    let reference = |v: f32| v - v.tanh();
    for (i, v) in [0.0_f32, 1.0, 2.0].iter().copied().enumerate() {
        assert!(close(d[i], reference(v), 1e-6));
    }
}
/// SELU is 0 at the origin, and at x = 1 it returns the SELU scale constant (~1.0507009).
#[test]
fn test_selu_scale_constants() {
    const SELU_SCALE: f32 = 1.050_700_9;
    let x = leaf(&[0.0, 1.0], &[2], false);
    let out = selu(&x).unwrap();
    let d = out.data().unwrap();
    assert!(close(d[0], 0.0, 1e-6));
    assert!(close(d[1], SELU_SCALE, 1e-5));
}
/// `softplus` with defaults must match `softplus_with(x, 1.0, 20.0)`. The explicit
/// arguments are presumably beta and threshold; confirm against the signature.
#[test]
fn test_softplus_default_matches_explicit() {
    let x = leaf(&[-1.0, 0.0, 1.0, 2.0], &[4], false);
    let implicit = softplus(&x).unwrap();
    let explicit = softplus_with(&x, 1.0, 20.0).unwrap();
    let (lhs, rhs) = (implicit.data().unwrap(), explicit.data().unwrap());
    assert!((0..4).all(|i| close(lhs[i], rhs[i], 1e-6)));
}
/// `elu` gives `e^-1 - 1` at x = -1 (implying a default alpha of 1), 0 at the origin,
/// and 1 at x = 1.
#[test]
fn test_elu_zero_at_origin() {
    let x = leaf(&[-1.0, 0.0, 1.0], &[3], false);
    let out = elu(&x).unwrap();
    let d = out.data().unwrap();
    let expected: [f32; 3] = [(-1.0_f32).exp() - 1.0, 0.0, 1.0];
    let tolerances: [f32; 3] = [1e-5, 1e-6, 1e-6];
    for (i, (&want, &tol)) in expected.iter().zip(tolerances.iter()).enumerate() {
        assert!(close(d[i], want, tol));
    }
}
/// GLU over the last dim of `[1, 4]` splits `[1, 2 | 3, 4]` and returns `a * sigmoid(b)`,
/// which halves that dim to `[1, 2]`.
#[test]
fn test_glu_halves_input() {
    let x = leaf(&[1.0, 2.0, 3.0, 4.0], &[1, 4], false);
    let out = glu(&x, -1).unwrap();
    assert_eq!(out.shape(), &[1, 2]);
    let sigmoid = |v: f32| 1.0 / (1.0 + (-v).exp());
    let d = out.data().unwrap();
    assert!(close(d[0], 1.0 * sigmoid(3.0), 1e-5));
    assert!(close(d[1], 2.0 * sigmoid(4.0), 1e-5));
}
/// GLU cannot split an odd-sized dimension in half; it must report `InvalidArgument`.
#[test]
fn test_glu_rejects_odd_dim() {
    let odd = leaf(&[1.0, 2.0, 3.0], &[3], false);
    let result = glu(&odd, 0);
    assert!(matches!(result, Err(FerrotorchError::InvalidArgument { .. })));
}
use crate::module::Reduction;
/// The mean absolute error of `[1, 2, 3]` against zeros is `(1 + 2 + 3) / 3 = 2`.
#[test]
fn test_l1_loss_mean() {
    let zeros = leaf(&[0.0; 3], &[3], false);
    let pred = leaf(&[1.0, 2.0, 3.0], &[3], false);
    let value = l1_loss(&pred, &zeros, Reduction::Mean).unwrap().item().unwrap();
    assert!(close(value, 2.0, 1e-5));
}
/// The summed absolute error of `[1, 2, 3]` against zeros is `6`.
#[test]
fn test_l1_loss_sum() {
    let zeros = leaf(&[0.0; 3], &[3], false);
    let pred = leaf(&[1.0, 2.0, 3.0], &[3], false);
    let value = l1_loss(&pred, &zeros, Reduction::Sum).unwrap().item().unwrap();
    assert!(close(value, 6.0, 1e-5));
}
/// With `Reduction::None`, `l1_loss` keeps the input shape and returns `|pred - target|`
/// for each element.
#[test]
fn test_l1_loss_none_returns_per_element() {
    let pred = leaf(&[1.0, -2.0, 3.0], &[3], false);
    let zeros = leaf(&[0.0; 3], &[3], false);
    let per_elem = l1_loss(&pred, &zeros, Reduction::None).unwrap();
    assert_eq!(per_elem.shape(), &[3]);
    // The sign of -2 is dropped by the absolute value.
    assert_eq!(per_elem.data().unwrap(), &[1.0, 2.0, 3.0]);
}
/// Predicting 0.5 for a positive label costs `-ln(0.5) = ln 2`.
#[test]
fn test_binary_cross_entropy_log2_loss_for_uniform() {
    let prob = leaf(&[0.5], &[1], false);
    let label = leaf(&[1.0], &[1], false);
    let bce = binary_cross_entropy(&prob, &label, Reduction::Mean).unwrap();
    let expected = 2.0_f32.ln();
    assert!(close(bce.item().unwrap(), expected, 1e-5));
}
/// BCE on raw logits must agree (within 1e-4) with BCE on `sigmoid(logits)`.
#[test]
fn test_binary_cross_entropy_with_logits_matches_bce_for_sigmoid_input() {
    let z = leaf(&[-1.0, 0.5, 2.0], &[3], false);
    let y = leaf(&[0.0, 1.0, 1.0], &[3], false);
    let probs = sigmoid(&z).unwrap();
    let on_logits = binary_cross_entropy_with_logits(&z, &y, Reduction::Mean).unwrap();
    let on_probs = binary_cross_entropy(&probs, &y, Reduction::Mean).unwrap();
    let (lhs, rhs) = (on_logits.item().unwrap(), on_probs.item().unwrap());
    assert!(close(lhs, rhs, 1e-4), "logit-bce={lhs} vs bce(sig)={rhs}");
}
/// When `pred` holds `ln(0.5)` for the target distribution `[0.5, 0.5]`, the summed
/// KL divergence must be 0.
#[test]
fn test_kl_div_zero_for_matched_distributions() {
    let log_half = 0.5_f32.ln();
    let target = leaf(&[0.5, 0.5], &[2], false);
    let pred = leaf(&[log_half, log_half], &[2], false);
    let divergence = kl_div(&pred, &target, Reduction::Sum).unwrap();
    assert!(close(divergence.item().unwrap(), 0.0, 1e-5));
}
/// L2-normalising `[3, 4]` (norm 5) along dim 0 yields `[0.6, 0.8]`.
#[test]
fn test_normalize_l2_unit_norm() {
    let v = leaf(&[3.0, 4.0], &[2], false);
    let unit = normalize(&v, 2.0, 0, 1e-12).unwrap();
    let d = unit.data().unwrap();
    for (i, &want) in [0.6_f32, 0.8].iter().enumerate() {
        assert!(close(d[i], want, 1e-5));
    }
}
/// Normalising along dim 1 treats each row independently. Both `[3, 4]` and `[6, 8]`
/// become `[0.6, 0.8]`.
#[test]
fn test_normalize_l2_2d_per_row() {
    let rows = leaf(&[3.0, 4.0, 6.0, 8.0], &[2, 2], false);
    let unit_rows = normalize(&rows, 2.0, 1, 1e-12).unwrap();
    let d = unit_rows.data().unwrap();
    for row in 0..2 {
        assert!(close(d[2 * row], 0.6, 1e-5));
        assert!(close(d[2 * row + 1], 0.8, 1e-5));
    }
}
/// `y = 2x` points the same way as `x`, so the cosine similarity is 1.
#[test]
fn test_cosine_similarity_aligned_vectors() {
    let x = leaf(&[1.0, 2.0, 3.0], &[3], false);
    let y = leaf(&[2.0, 4.0, 6.0], &[3], false);
    let similarity = cosine_similarity(&x, &y, 0, 1e-8).unwrap().item().unwrap();
    assert!(close(similarity, 1.0, 1e-4));
}
/// The basis vectors `e0` and `e1` are orthogonal, so their cosine similarity is 0.
#[test]
fn test_cosine_similarity_orthogonal_vectors() {
    let e0 = leaf(&[1.0, 0.0], &[2], false);
    let e1 = leaf(&[0.0, 1.0], &[2], false);
    let similarity = cosine_similarity(&e0, &e1, 0, 1e-8).unwrap().item().unwrap();
    assert!(close(similarity, 0.0, 1e-5));
}
/// Identical `[1, 3]` rows are at distance 0 under the p = 2 norm with zero eps.
#[test]
fn test_pairwise_distance_zero_for_equal_vectors() {
    let a = leaf(&[1.0, 2.0, 3.0], &[1, 3], false);
    let b = leaf(&[1.0, 2.0, 3.0], &[1, 3], false);
    let dist = pairwise_distance(&a, &b, 2.0, 0.0).unwrap().item().unwrap();
    assert!(close(dist, 0.0, 1e-5));
}
/// The Euclidean distance from `(3, 4)` to the origin is 5.
#[test]
fn test_pairwise_distance_l2_simple() {
    let origin = leaf(&[0.0, 0.0], &[1, 2], false);
    let point = leaf(&[3.0, 4.0], &[1, 2], false);
    let dist = pairwise_distance(&point, &origin, 2.0, 0.0).unwrap().item().unwrap();
    assert!(close(dist, 5.0, 1e-4));
}
/// Indices `[0, 2, 1]` with 3 classes produce a `[3, 3]` tensor whose rows are
/// the matching unit vectors, stored row-major.
#[test]
fn test_one_hot_basic() {
    let idx = leaf(&[0.0, 2.0, 1.0], &[3], false);
    let encoded = one_hot(&idx, 3).unwrap();
    assert_eq!(encoded.shape(), &[3, 3]);
    #[rustfmt::skip]
    let expected: [f32; 9] = [
        1.0, 0.0, 0.0,
        0.0, 0.0, 1.0,
        0.0, 1.0, 0.0,
    ];
    assert_eq!(encoded.data().unwrap(), &expected);
}
/// A `[2, 2]` index tensor gains a trailing class axis of size 3. Each index selects
/// its class column, with the output stored row-major (the same layout as
/// `test_one_hot_basic`).
#[test]
fn test_one_hot_2d_input() {
    let idx = leaf(&[0.0, 1.0, 2.0, 0.0], &[2, 2], false);
    let oh = one_hot(&idx, 3).unwrap();
    assert_eq!(oh.shape(), &[2, 2, 3]);
    // Checking values as well as shape catches a wrong encoding or layout for
    // rank > 1 inputs.
    #[rustfmt::skip]
    let expected: [f32; 12] = [
        1.0, 0.0, 0.0,
        0.0, 1.0, 0.0,
        0.0, 0.0, 1.0,
        1.0, 0.0, 0.0,
    ];
    assert_eq!(oh.data().unwrap(), &expected);
}
/// Index 5 is outside `0..3`, so `one_hot` must return `InvalidArgument`.
#[test]
fn test_one_hot_rejects_out_of_range() {
    let idx = leaf(&[0.0, 5.0], &[2], false);
    let result = one_hot(&idx, 3);
    assert!(matches!(result, Err(FerrotorchError::InvalidArgument { .. })));
}
/// A negative index (-1) must be rejected with `InvalidArgument`.
#[test]
fn test_one_hot_rejects_negative() {
    let idx = leaf(&[-1.0, 0.0], &[2], false);
    let result = one_hot(&idx, 3);
    assert!(matches!(result, Err(FerrotorchError::InvalidArgument { .. })));
}
/// A 3x3 all-ones kernel over a 4x4 all-ones image gives a `[1, 1, 2, 2]` output.
/// Each output position is the sum of nine ones, i.e. 9.
#[test]
fn conv2d_forwarder_matches_module() {
    let image = leaf(&[1.0; 16], &[1, 1, 4, 4], false);
    let kernel = leaf(&[1.0; 9], &[1, 1, 3, 3], false);
    let zero_bias = leaf(&[0.0], &[1], false);
    let out = super::conv2d(&image, &kernel, Some(&zero_bias), (1, 1), (0, 0)).unwrap();
    assert_eq!(out.shape(), &[1, 1, 2, 2]);
    assert_close(out.data().unwrap(), &[9.0; 4], 1e-5);
}
/// A 2-D weight (`[2, 2]`) passed to `conv2d` must be rejected with `ShapeMismatch`.
#[test]
fn conv2d_rejects_bad_weight_shape() {
    let input = leaf(&[0.0; 16], &[1, 1, 4, 4], false);
    let flat_weight = leaf(&[1.0; 4], &[2, 2], false);
    let outcome = super::conv2d(&input, &flat_weight, None, (1, 1), (0, 0));
    assert!(matches!(outcome, Err(FerrotorchError::ShapeMismatch { .. })));
}
/// A width-2 all-ones kernel over `[1, 2, 3, 4]` produces the sums of adjacent pairs.
#[test]
fn conv1d_forwarder_smoke() {
    let signal = leaf(&[1.0, 2.0, 3.0, 4.0], &[1, 1, 4], false);
    let kernel = leaf(&[1.0, 1.0], &[1, 1, 2], false);
    let out = super::conv1d(&signal, &kernel, None, 1, 0).unwrap();
    assert_eq!(out.shape(), &[1, 1, 3]);
    assert_close(out.data().unwrap(), &[1.0 + 2.0, 2.0 + 3.0, 3.0 + 4.0], 1e-5);
}
/// Looking up ids `[0, 2, 1]` gathers those rows of the `[3, 2]` table, in that order.
#[test]
fn embedding_forwarder_matches_module() {
    let table = leaf(&[0.0, 0.1, 1.0, 1.1, 2.0, 2.1], &[3, 2], false);
    let ids = leaf(&[0.0, 2.0, 1.0], &[3], false);
    let rows = super::embedding(&ids, &table, None).unwrap();
    assert_eq!(rows.shape(), &[3, 2]);
    assert_close(rows.data().unwrap(), &[0.0, 0.1, 2.0, 2.1, 1.0, 1.1], 1e-5);
}
/// Zero-padding a 2x2 image by one pixel on every side yields a 4x4 image. The
/// original values sit in the centre and the border is 0.
#[test]
fn pad_re_export_2d() {
    let input = leaf(&[1.0, 2.0, 3.0, 4.0], &[1, 1, 2, 2], false);
    let out = super::pad2d(&input, 1, 1, 1, 1, super::PaddingMode::Zeros, 0.0).unwrap();
    assert_eq!(out.shape(), &[1, 1, 4, 4]);
    // Checking values as well as shape verifies both the fill and where the
    // interior lands; all four pad amounts are equal, so argument order does not matter.
    #[rustfmt::skip]
    let expected = [
        0.0, 0.0, 0.0, 0.0,
        0.0, 1.0, 2.0, 0.0,
        0.0, 3.0, 4.0, 0.0,
        0.0, 0.0, 0.0, 0.0,
    ];
    assert_close(out.data().unwrap(), &expected, 1e-5);
}
/// With a single key, the softmax attention weights reduce to `[1.0]`, so the output
/// equals `v`. The trailing `false` flag is presumably the causal-mask switch.
#[test]
fn sdpa_forwarder_smoke() {
    let query = leaf(&[1.0, 0.0], &[1, 1, 2], false);
    let key = leaf(&[1.0, 0.0], &[1, 1, 2], false);
    let value = leaf(&[3.0, 4.0], &[1, 1, 2], false);
    let attended = super::scaled_dot_product_attention(&query, &key, &value, false).unwrap();
    assert_eq!(attended.shape(), &[1, 1, 2]);
    assert_close(attended.data().unwrap(), &[3.0, 4.0], 1e-5);
}
}