pub mod activation;
pub mod conv;
pub mod core;
pub mod linear;
pub mod loss;
pub mod loss_advanced;
pub mod norm;
pub mod pooling;
pub use core::*;
pub use activation::{
dropout, elu, gelu, leaky_relu, log_softmax, mish, relu, relu_inplace, selu, sigmoid, softmax,
swish, tanh,
};
pub use activation::{
LeakyReLU, LogSoftmax, Mish, ReLU, Sigmoid, Softmax, Swish, Tanh, ELU, GELU, SELU,
};
pub use activation::configured::{
gelu_configured, mish_configured, relu_configured, sigmoid_configured, softmax_configured,
swish_configured, tanh_configured,
};
pub use conv::{conv1d, conv2d, conv3d, conv_transpose1d, conv_transpose2d, conv_transpose3d};
pub use conv::{conv_output_size, conv_transpose_output_size, validate_conv_params};
pub use pooling::{
adaptive_max_pool1d, adaptive_max_pool2d, adaptive_max_pool3d, global_max_pool1d,
global_max_pool2d, global_max_pool3d, max_pool1d, max_pool2d, max_pool3d,
};
pub use pooling::{
adaptive_avg_pool1d, adaptive_avg_pool2d, adaptive_avg_pool3d, avg_pool1d, avg_pool2d,
avg_pool3d, global_avg_pool1d, global_avg_pool2d, global_avg_pool3d,
};
pub use pooling::{
pad, reflection_pad1d, reflection_pad2d, replication_pad1d, replication_pad2d, zero_pad2d,
};
pub use pooling::{adaptive_pool_params, pool_output_size};
pub use linear::{bilinear, linear};
pub use linear::{embedding, embedding_bag, one_hot};
pub use linear::{
grouped_query_attention, multi_head_attention, multi_query_attention,
scaled_dot_product_attention,
};
pub use linear::{
learnable_positional_encoding, rotary_positional_encoding, sinusoidal_positional_encoding,
};
pub use linear::{post_norm_layer_norm, pre_norm_layer_norm, rms_norm};
pub use linear::{geglu, glu, swiglu};
pub use linear::{apply_attention_mask, create_causal_mask, create_padding_mask};
pub use loss::{
binary_cross_entropy, binary_cross_entropy_with_logits, cross_entropy, focal_loss,
multi_margin_loss, multilabel_margin_loss, nll_loss,
};
pub use loss::{huber_loss, l1_loss, mse_loss, smooth_l1_loss};
pub use loss::kl_div;
pub use loss::{contrastive_loss, cosine_embedding_loss, triplet_margin_loss};
pub use loss::{center_loss, dice_loss, infonce_loss, tversky_loss, wing_loss};
pub use loss_advanced::{CustomLoss, LossBuilder, LossFactory, Reduction};
pub use loss_advanced::{
AdaptiveLoss, CombinedLoss, DiceLoss, IoULoss, SmoothL1Loss, WeightedLoss,
};
pub use loss_advanced::{
BinaryCrossEntropy, CategoricalCrossEntropy, CosineEmbeddingLoss, FocalLoss, HingeLoss,
HuberLoss, KLDivLoss, L1Loss, MSELoss, NLLLoss, TripletMarginLoss,
};
pub use loss_advanced::validation as loss_validation;
pub use norm::{
batch_norm, batch_norm_1d, batch_norm_2d, batch_norm_2d_with_config, batch_norm_3d,
};
pub use norm::{layer_norm, layer_norm_configured, layer_norm_enhanced};
pub use norm::{
group_norm, instance_norm, local_response_norm, rms_norm as rms_norm_standalone, spectral_norm,
weight_norm,
};
pub use norm::configured::batch_norm_configured;
pub use norm::{create_affine_params, get_norm_features, validate_norm_params};
pub mod activations {
pub use super::activation::configured::*;
pub use super::activation::{gelu, mish, relu, sigmoid, softmax, swish, tanh};
}
pub mod losses {
pub use super::loss::*;
pub fn mse_loss_configured(
input: &crate::Tensor,
target: &crate::Tensor,
reduction: &str,
config: &super::FunctionalConfig,
) -> super::FuncResult<crate::Tensor> {
crate::validate_inputs!(
config,
super::validation::validate_not_empty(input, "input"),
super::validation::validate_not_empty(target, "target"),
super::validation::validate_compatible_shapes(input, target, "MSE loss")
);
crate::func_error!(super::mse_loss(input, target, reduction), "MSE loss")
}
pub fn l1_loss_configured(
input: &crate::Tensor,
target: &crate::Tensor,
reduction: &str,
config: &super::FunctionalConfig,
) -> super::FuncResult<crate::Tensor> {
crate::validate_inputs!(
config,
super::validation::validate_not_empty(input, "input"),
super::validation::validate_not_empty(target, "target"),
super::validation::validate_compatible_shapes(input, target, "L1 loss")
);
crate::func_error!(super::l1_loss(input, target, reduction), "L1 loss")
}
pub fn cross_entropy_configured(
input: &crate::Tensor,
target: &crate::Tensor<i64>,
weight: Option<&crate::Tensor>,
ignore_index: Option<i64>,
reduction: &str,
config: &super::FunctionalConfig,
) -> super::FuncResult<crate::Tensor> {
crate::validate_inputs!(
config,
super::validation::validate_not_empty(input, "input"),
super::validation::validate_not_empty(target, "target"),
super::validation::validate_min_ndim(input, 2, "input")
);
crate::func_error!(
super::cross_entropy(input, target, weight, reduction, ignore_index),
"Cross entropy loss"
)
}
}
pub mod normalization {
pub use super::norm::configured::*;
pub use super::norm::{batch_norm_2d, layer_norm_enhanced};
}
pub mod prelude {
pub use super::{
activations, default_config, losses, normalization, numerics, optimized, performance, safe,
validation, Activation, ActivationConfig, CustomLoss, FunctionalBuilder, FunctionalConfig,
LossBuilder, MemoryOptLevel, Reduction,
};
}
use torsh_core::error::Result;
use torsh_tensor::Tensor;
#[allow(dead_code)]
trait TensorCast {
fn cast_i64(&self) -> Result<Tensor<i64>>;
}
#[allow(dead_code)]
impl TensorCast for Tensor {
fn cast_i64(&self) -> Result<Tensor<i64>> {
let data = self.to_vec()?;
let i64_data: Vec<i64> = data.into_iter().map(|x| x as i64).collect();
Ok(Tensor::from_data(
i64_data,
self.shape().dims().to_vec(),
self.device(),
)?)
}
}
pub struct SparseMatrix;
impl SparseMatrix {
pub fn new() -> Self {
Self
}
}
impl Default for SparseMatrix {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_modular_functional_system() {
let input = torsh_tensor::creation::randn::<f32>(&[2, 4]).unwrap();
let _relu_result = relu(&input).unwrap();
let _sigmoid_result = sigmoid(&input).unwrap();
let _tanh_result = tanh(&input).unwrap();
let config = FunctionalConfig::default();
let _configured_relu = activations::relu_configured(&input, &config).unwrap();
let _optimized_config = optimized().build();
let _safe_config = safe().build();
}
#[test]
fn test_backward_compatibility() {
let input = torsh_tensor::creation::randn::<f32>(&[4, 3, 32, 32]).unwrap();
let weight = torsh_tensor::creation::ones(&[3]).unwrap();
let bias = torsh_tensor::creation::zeros(&[3]).unwrap();
let _batch_norm_result = batch_norm_2d(
&input,
Some(&weight),
Some(&bias),
None,
None,
true,
0.1,
1e-5,
)
.unwrap();
let activation_input = torsh_tensor::creation::randn::<f32>(&[2, 4]).unwrap();
let _relu_result = relu(&activation_input).unwrap();
let _gelu_result = gelu(&activation_input).unwrap();
let _swish_result = swish(&activation_input).unwrap();
}
#[test]
fn test_modular_structure_integrity() {
let config = FunctionalConfig::default();
assert_eq!(config.validate_inputs, true);
assert_eq!(config.eps, 1e-8);
let custom_config = FunctionalBuilder::new().eps(1e-6).inplace(true).build();
assert_eq!(custom_config.eps, 1e-6);
assert_eq!(custom_config.inplace, true);
let _default_conf = prelude::default_config();
}
#[test]
fn test_loss_framework() {
let predictions = torsh_tensor::creation::randn::<f32>(&[4, 10]).unwrap();
let targets = torsh_tensor::creation::randn::<f32>(&[4, 10]).unwrap();
let mse = MSELoss::new(Reduction::Mean);
let _loss_result = mse.compute_loss(&predictions, &targets).unwrap();
let _smooth_l1 = LossBuilder::new()
.with_reduction(Reduction::Sum)
.smooth_l1(1.0);
}
}
#[cfg(test)]
mod examples {
use super::*;
#[test]
fn example_basic_usage() {
let input = torsh_tensor::creation::randn::<f32>(&[4, 3, 32, 32]).unwrap();
let target = torsh_tensor::creation::randn::<f32>(&[4, 10]).unwrap();
let activated = relu(&input).unwrap();
let _softmax_result = softmax(&activated, Some(-1)).unwrap();
let weight = torsh_tensor::creation::ones(&[3]).unwrap();
let bias = torsh_tensor::creation::zeros(&[3]).unwrap();
let _normalized = batch_norm_2d(
&input,
Some(&weight),
Some(&bias),
None,
None,
true,
0.1,
1e-5,
)
.unwrap();
let predictions = torsh_tensor::creation::randn::<f32>(&[4, 10]).unwrap();
let _mse_loss = mse_loss(&predictions, &target, "mean").unwrap();
}
#[test]
fn example_configured_usage() {
let config = FunctionalBuilder::new()
.validate(true)
.eps(1e-6)
.memory_opt(MemoryOptLevel::Maximum)
.build();
let input = torsh_tensor::creation::randn::<f32>(&[4, 8]).unwrap();
let _relu_result = activations::relu_configured(&input, &config).unwrap();
let _sigmoid_result = activations::sigmoid_configured(&input, &config).unwrap();
}
#[test]
fn example_advanced_loss_usage() {
let predictions = torsh_tensor::creation::randn::<f32>(&[4, 10]).unwrap();
let targets = torsh_tensor::creation::randn::<f32>(&[4, 10]).unwrap();
let dice_loss = LossBuilder::new()
.with_reduction(Reduction::Mean)
.dice(1e-6);
let _loss_result = dice_loss.compute_loss(&predictions, &targets).unwrap();
let mse = Box::new(MSELoss::new(Reduction::None));
let l1 = Box::new(L1Loss::new(Reduction::None));
let combined = LossBuilder::new()
.with_reduction(Reduction::Mean)
.combined(vec![mse, l1], vec![0.7, 0.3]);
let _combined_loss = combined.compute_loss(&predictions, &targets).unwrap();
}
}