pub mod parameter;
pub mod buffer;
pub mod init;
pub mod linear;
pub mod activation;
pub mod loss;
pub mod optim;
pub mod clip;
pub mod scheduler;
pub mod dropout;
pub mod padding;
pub mod layernorm;
pub mod rmsnorm;
pub mod embedding;
pub mod grucell;
pub mod gru;
pub mod lstmcell;
pub mod lstm;
pub mod conv1d;
pub mod conv2d;
pub mod conv_transpose1d;
pub mod conv_transpose2d;
pub mod conv3d;
pub mod conv_transpose3d;
pub mod groupnorm;
pub mod batchnorm;
pub mod instancenorm;
pub mod pooling;
pub mod bilinear;
pub mod attention;
pub mod checkpoint;
pub mod amp;
pub mod cuda_graph;
pub mod functional;
pub use parameter::Parameter;
pub use buffer::Buffer;
pub use linear::Linear;
pub use activation::{
Identity, ReLU, Sigmoid, Tanh, GELU, SiLU,
LeakyReLU, ELU, Softplus, Mish,
SELU, Hardswish, Hardsigmoid, PReLU,
Softmax, LogSoftmax, Flatten,
};
pub use loss::{
mse_loss, cross_entropy_loss, bce_loss, bce_with_logits_loss,
l1_loss, smooth_l1_loss, kl_div_loss,
nll_loss, ctc_loss, focal_loss,
triplet_margin_loss, cosine_embedding_loss,
hinge_embedding_loss, margin_ranking_loss, poisson_nll_loss,
};
pub use optim::{
Optimizer, Stateful, SGD, SGDBuilder, Adam, AdamBuilder, AdamW, AdamWBuilder,
RMSprop, RMSpropBuilder, Adagrad, AdagradBuilder, RAdam, NAdam,
};
pub use checkpoint::{
save_checkpoint, load_checkpoint, save_checkpoint_file, load_checkpoint_file,
migrate_checkpoint, migrate_checkpoint_file, checkpoint_version,
LoadReport, MigrateReport,
};
pub use amp::{GradScaler, cast_parameters, AutocastGuard, autocast, is_autocast_enabled};
pub use clip::{clip_grad_norm, clip_grad_value};
pub use scheduler::{
Scheduler, StepDecay, CosineScheduler, WarmupScheduler, PlateauScheduler,
ExponentialLR, MultiStepLR, OneCycleLR, CyclicLR,
};
pub use dropout::{Dropout, Dropout2d, AlphaDropout};
pub use padding::{ZeroPad2d, ReflectionPad2d};
pub use layernorm::LayerNorm;
pub use rmsnorm::RMSNorm;
pub use embedding::{Embedding, EmbeddingBag};
pub use grucell::GRUCell;
pub use gru::GRU;
pub use lstmcell::LSTMCell;
pub use lstm::LSTM;
pub use conv1d::{Conv1d, Conv1dBuilder};
pub use conv2d::{Conv2d, Conv2dBuilder};
pub use conv_transpose1d::{ConvTranspose1d, ConvTranspose1dBuilder};
pub use conv_transpose2d::{ConvTranspose2d, ConvTranspose2dBuilder};
pub use conv3d::{Conv3d, Conv3dBuilder};
pub use conv_transpose3d::{ConvTranspose3d, ConvTranspose3dBuilder};
pub use groupnorm::GroupNorm;
pub use batchnorm::{BatchNorm, BatchNorm2d};
pub use instancenorm::InstanceNorm;
pub use pooling::{
MaxPool2d, AvgPool2d, MaxPool1d, AvgPool1d, AdaptiveMaxPool2d, AdaptiveAvgPool2d,
PixelShuffle, PixelUnshuffle, Upsample, Unfold, Fold,
};
pub use bilinear::Bilinear;
pub use attention::MultiheadAttention;
pub use init::{
xavier_uniform, xavier_normal, kaiming_uniform, kaiming_normal,
uniform_bias, uniform, normal, orthogonal, trunc_normal,
};
pub use functional::{gaussian_blur_2d, GaussianBlur};
pub use cuda_graph::{CudaGraph, MemPoolId, CaptureMode, cuda_graph_capture, cuda_graph_pool_handle};
use std::collections::{HashMap, HashSet};
use std::rc::Rc;
use crate::autograd::Variable;
use crate::graph::Graph;
use crate::tensor::Result;
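/// Core building block of the module tree: a module owns parameters and
/// buffers, exposes a `forward` pass, and may contain `sub_modules`.
///
/// A minimal sketch of a composite module, using only the `Linear` and `ReLU`
/// constructors exercised in the tests below (marked `ignore` since it omits
/// device setup):
///
/// ```ignore
/// struct Mlp {
///     fc1: Rc<Linear>,
///     act: Rc<ReLU>,
///     fc2: Rc<Linear>,
/// }
///
/// impl Module for Mlp {
///     fn forward(&self, input: &Variable) -> Result<Variable> {
///         let h = self.fc1.forward(input)?;
///         let h = self.act.forward(&h)?;
///         self.fc2.forward(&h)
///     }
///     fn sub_modules(&self) -> Vec<Rc<dyn Module>> {
///         vec![self.fc1.clone(), self.act.clone(), self.fc2.clone()]
///     }
/// }
/// ```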
pub trait Module {
fn forward(&self, input: &Variable) -> Result<Variable>;
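/// Collects every trainable `Parameter` in the subtree, de-duplicating by
/// the underlying `Rc` pointer so a parameter shared between modules (e.g.
/// tied weights) is returned exactly once. Leaf modules that own parameters
/// directly are expected to override this.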
fn parameters(&self) -> Vec<Parameter> {
let subs = self.sub_modules();
if subs.is_empty() {
return vec![];
}
let mut params = Vec::new();
let mut seen = HashSet::new();
let mut visited = HashSet::new();
for child in &subs {
walk_modules_visited(child.as_ref(), &mut visited, &mut |m| {
for p in m.parameters() {
let ptr = Rc::as_ptr(&p.variable.inner) as usize;
if seen.insert(ptr) {
params.push(p);
}
}
});
}
params
}
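/// Collects every non-trainable `Buffer` in the subtree (running statistics
/// and the like), using the same pointer de-duplication as `parameters`.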
fn buffers(&self) -> Vec<Buffer> {
let subs = self.sub_modules();
if subs.is_empty() {
return vec![];
}
let mut bufs = Vec::new();
let mut seen = HashSet::new();
let mut visited = HashSet::new();
for child in &subs {
walk_modules_visited(child.as_ref(), &mut visited, &mut |m| {
for b in m.buffers() {
let ptr = Rc::as_ptr(&b.inner) as usize;
if seen.insert(ptr) {
bufs.push(b);
}
}
});
}
bufs
}
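// Optional hooks: the defaults below are no-ops or empty, so a simple
// module only has to implement `forward`.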
fn name(&self) -> &str { "module" }
fn sub_modules(&self) -> Vec<Rc<dyn Module>> { vec![] }
fn move_to_device(&self, _device: crate::tensor::Device) {}
fn set_training(&self, _training: bool) {}
fn train(&self) { self.set_training(true); }
fn eval(&self) { self.set_training(false); }
fn trace(&self) -> Option<Variable> { None }
fn as_named_input(&self) -> Option<&dyn NamedInputModule> { None }
fn as_graph(&self) -> Option<&Graph> { None }
fn structural_hash(&self) -> Option<String> { None }
fn reset(&self) {}
fn detach_state(&self) {}
}
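// Delegate every trait method through the box so a `Box<dyn Module>` can be
// used anywhere a concrete module can.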
impl Module for Box<dyn Module> {
fn forward(&self, input: &Variable) -> Result<Variable> {
(**self).forward(input)
}
fn parameters(&self) -> Vec<Parameter> {
(**self).parameters()
}
fn buffers(&self) -> Vec<Buffer> {
(**self).buffers()
}
fn name(&self) -> &str {
(**self).name()
}
fn sub_modules(&self) -> Vec<Rc<dyn Module>> {
(**self).sub_modules()
}
fn move_to_device(&self, device: crate::tensor::Device) {
(**self).move_to_device(device);
}
fn set_training(&self, training: bool) {
(**self).set_training(training);
}
fn trace(&self) -> Option<Variable> {
(**self).trace()
}
fn as_named_input(&self) -> Option<&dyn NamedInputModule> {
(**self).as_named_input()
}
fn as_graph(&self) -> Option<&Graph> {
(**self).as_graph()
}
fn structural_hash(&self) -> Option<String> {
(**self).structural_hash()
}
fn reset(&self) {
(**self).reset();
}
fn detach_state(&self) {
(**self).detach_state();
}
}
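/// A module whose forward pass takes additional named inputs (for example
/// masks or cached states) alongside the primary input.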
pub trait NamedInputModule: Module {
fn forward_named(
&self,
input: &Variable,
refs: &HashMap<String, Variable>,
) -> Result<Variable>;
}
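/// Depth-first traversal of `module` and its transitive `sub_modules`,
/// invoking `f` once per distinct module.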
pub fn walk_modules(module: &dyn Module, f: &mut dyn FnMut(&dyn Module)) {
let mut visited = HashSet::new();
walk_modules_visited(module, &mut visited, f);
}
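/// Traversal core shared with `walk_modules`: `visited` carries the set of
/// module pointers seen so far, so diamond-shaped module graphs (shared
/// sub-modules) are visited once and the recursion terminates.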
pub fn walk_modules_visited(
module: &dyn Module,
visited: &mut HashSet<usize>,
f: &mut dyn FnMut(&dyn Module),
) {
let ptr = module as *const dyn Module as *const () as usize;
if !visited.insert(ptr) {
return;
}
f(module);
for child in module.sub_modules() {
walk_modules_visited(child.as_ref(), visited, f);
}
}
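/// Concatenates `parameters()` from several root modules in order. Unlike
/// `Module::parameters`, no de-duplication is performed across roots.
///
/// A minimal sketch, mirroring `test_collect_parameters` below (marked
/// `ignore` since it omits device setup):
///
/// ```ignore
/// let l1 = Linear::on_device(3, 4, device)?;
/// let l2 = Linear::on_device(4, 2, device)?;
/// let params = collect_parameters(&[&l1, &l2]); // two weights + two biases
/// ```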
pub fn collect_parameters(modules: &[&dyn Module]) -> Vec<Parameter> {
let mut params = Vec::new();
for m in modules {
params.extend(m.parameters());
}
params
}
#[cfg(test)]
mod tests {
use super::*;
use crate::autograd::Variable;
use crate::tensor::Tensor;
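// Test helper: build an f32 tensor on the shared test device.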
fn from_f32(data: &[f32], shape: &[i64]) -> Tensor {
Tensor::from_f32(data, shape, crate::tensor::test_device()).unwrap()
}
#[test]
fn test_linear_forward() {
let model = Linear::on_device(3, 2, crate::tensor::test_device()).unwrap();
model.weight.variable.set_data(
from_f32(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]),
);
model.bias.as_ref().unwrap().variable.set_data(
from_f32(&[0.1, 0.2], &[2]),
);
let x = Variable::new(from_f32(&[1.0, 1.0, 1.0], &[1, 3]), false);
let y = model.forward(&x).unwrap();
let data = y.data().to_f32_vec().unwrap();
assert!((data[0] - 6.1).abs() < 1e-5);
assert!((data[1] - 15.2).abs() < 1e-5);
}
#[test]
fn test_linear_backward() {
let model = Linear::on_device(3, 2, crate::tensor::test_device()).unwrap();
model.weight.variable.set_data(
from_f32(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]),
);
model.bias.as_ref().unwrap().variable.set_data(
from_f32(&[0.0, 0.0], &[2]),
);
let x = Variable::new(from_f32(&[1.0, 1.0, 1.0], &[1, 3]), true);
let y = model.forward(&x).unwrap();
let loss = y.sum().unwrap();
loss.backward().unwrap();
let gw = model.weight.variable.grad().unwrap().to_f32_vec().unwrap();
assert_eq!(gw, vec![1.0, 1.0, 1.0, 1.0, 1.0, 1.0]);
let gb = model.bias.as_ref().unwrap().variable.grad().unwrap().to_f32_vec().unwrap();
assert_eq!(gb, vec![1.0, 1.0]);
let gx = x.grad().unwrap().to_f32_vec().unwrap();
assert_eq!(gx, vec![5.0, 7.0, 9.0]);
}
#[test]
fn test_mse_loss() {
let pred = Variable::new(from_f32(&[1.0, 2.0, 3.0], &[3]), false);
let target = Variable::new(from_f32(&[1.0, 2.0, 3.0], &[3]), false);
let loss = mse_loss(&pred, &target).unwrap();
assert!((loss.item().unwrap()).abs() < 1e-7);
let pred2 = Variable::new(from_f32(&[2.0, 3.0, 4.0], &[3]), false);
let loss2 = mse_loss(&pred2, &target).unwrap();
assert!((loss2.item().unwrap() - 1.0).abs() < 1e-5);
}
#[test]
fn test_sgd_step() {
let model = Linear::on_device(2, 1, crate::tensor::test_device()).unwrap();
model.weight.variable.set_data(from_f32(&[1.0, 1.0], &[1, 2]));
model.bias.as_ref().unwrap().variable.set_data(from_f32(&[0.0], &[1]));
let params = model.parameters();
let mut optim = SGD::new(&params, 0.1, 0.0);
let x = Variable::new(from_f32(&[1.0, 2.0], &[1, 2]), false);
let target = Variable::new(from_f32(&[5.0], &[1, 1]), false);
let pred = model.forward(&x).unwrap();
let loss = mse_loss(&pred, &target).unwrap();
let loss_before = loss.item().unwrap();
loss.backward().unwrap();
optim.step().unwrap();
optim.zero_grad();
let pred2 = model.forward(&x).unwrap();
let loss2 = mse_loss(&pred2, &target).unwrap();
assert!(loss2.item().unwrap() < loss_before, "loss should decrease");
}
#[test]
fn test_linear_regression_sgd() {
let model = Linear::on_device(1, 1, crate::tensor::test_device()).unwrap();
let params = model.parameters();
let mut optim = SGD::new(&params, 0.01, 0.0);
let x = Variable::new(
from_f32(&[1.0, 2.0, 3.0, 4.0], &[4, 1]),
false,
);
let target = Variable::new(
from_f32(&[3.0, 5.0, 7.0, 9.0], &[4, 1]),
false,
);
let mut last_loss = f64::MAX;
for _ in 0..800 {
optim.zero_grad();
let pred = model.forward(&x).unwrap();
let loss = mse_loss(&pred, &target).unwrap();
last_loss = loss.item().unwrap();
loss.backward().unwrap();
optim.step().unwrap();
}
assert!(
last_loss < 0.01,
"SGD should converge on linear regression, got loss={}",
last_loss
);
}
#[test]
fn test_linear_regression_adam() {
let model = Linear::on_device(1, 1, crate::tensor::test_device()).unwrap();
let params = model.parameters();
let mut optim = Adam::new(&params, 0.1);
let x = Variable::new(
from_f32(&[1.0, 2.0, 3.0, 4.0], &[4, 1]),
false,
);
let target = Variable::new(
from_f32(&[3.0, 5.0, 7.0, 9.0], &[4, 1]),
false,
);
let mut last_loss = f64::MAX;
for _ in 0..500 {
optim.zero_grad();
let pred = model.forward(&x).unwrap();
let loss = mse_loss(&pred, &target).unwrap();
last_loss = loss.item().unwrap();
loss.backward().unwrap();
optim.step().unwrap();
}
assert!(
last_loss < 0.02,
"Adam should converge on linear regression, got loss={}",
last_loss
);
}
#[test]
fn test_relu_module() {
let relu = ReLU::new();
let x = Variable::new(from_f32(&[1.0, -1.0, 2.0, -2.0], &[4]), false);
let y = relu.forward(&x).unwrap();
assert_eq!(y.data().to_f32_vec().unwrap(), vec![1.0, 0.0, 2.0, 0.0]);
assert!(relu.parameters().is_empty());
}
#[test]
fn test_collect_parameters() {
let l1 = Linear::on_device(3, 4, crate::tensor::test_device()).unwrap();
let l2 = Linear::on_device(4, 2, crate::tensor::test_device()).unwrap();
let params = collect_parameters(&[&l1, &l2]);
assert_eq!(params.len(), 4);
}
#[test]
fn test_sgd_momentum() {
let model = Linear::on_device(1, 1, crate::tensor::test_device()).unwrap();
let params = model.parameters();
let mut optim = SGD::new(&params, 0.01, 0.9);
let x = Variable::new(from_f32(&[1.0, 2.0, 3.0, 4.0], &[4, 1]), false);
let target = Variable::new(from_f32(&[3.0, 5.0, 7.0, 9.0], &[4, 1]), false);
let mut last_loss = f64::MAX;
for _ in 0..200 {
optim.zero_grad();
let pred = model.forward(&x).unwrap();
let loss = mse_loss(&pred, &target).unwrap();
last_loss = loss.item().unwrap();
loss.backward().unwrap();
optim.step().unwrap();
}
assert!(
last_loss < 0.01,
"SGD with momentum should converge, got loss={}",
last_loss
);
}
#[test]
fn test_cross_entropy_loss() {
let pred = Variable::new(
from_f32(&[2.0, 1.0, 1.0, 3.0], &[2, 2]),
true,
);
let target = Variable::new(
from_f32(&[1.0, 0.0, 0.0, 1.0], &[2, 2]),
false,
);
let loss = cross_entropy_loss(&pred, &target).unwrap();
let val = loss.item().unwrap();
assert!(val > 0.0, "cross entropy should be positive");
assert!((val - 0.22).abs() < 0.02, "expected ~0.22, got {}", val);
loss.backward().unwrap();
assert!(pred.grad().is_some());
}
#[test]
fn test_cross_entropy_converges() {
let model = Linear::on_device(2, 2, crate::tensor::test_device()).unwrap();
let params = model.parameters();
let mut optim = SGD::new(&params, 0.1, 0.0);
let x = Variable::new(from_f32(&[1.0, 0.0, 0.0, 1.0], &[2, 2]), false);
let target = Variable::new(from_f32(&[1.0, 0.0, 0.0, 1.0], &[2, 2]), false);
let mut last_loss = f64::MAX;
for _ in 0..200 {
optim.zero_grad();
let pred = model.forward(&x).unwrap();
let loss = cross_entropy_loss(&pred, &target).unwrap();
last_loss = loss.item().unwrap();
loss.backward().unwrap();
optim.step().unwrap();
}
assert!(last_loss < 0.1, "cross entropy should converge, got {}", last_loss);
}
#[test]
fn test_bce_with_logits_loss() {
let pred = Variable::new(from_f32(&[0.0], &[1]), true);
let target = Variable::new(from_f32(&[1.0], &[1]), false);
let loss = bce_with_logits_loss(&pred, &target).unwrap();
let val = loss.item().unwrap();
assert!(
(val - 0.693).abs() < 0.01,
"expected ~0.693, got {}",
val
);
let pred2 = Variable::new(from_f32(&[10.0], &[1]), false);
let loss2 = bce_with_logits_loss(&pred2, &target).unwrap();
assert!(loss2.item().unwrap() < 0.001);
loss.backward().unwrap();
assert!(pred.grad().is_some());
}
#[test]
fn test_l1_loss() {
let pred = Variable::new(from_f32(&[1.0, 2.0, 3.0], &[3]), true);
let target = Variable::new(from_f32(&[1.0, 2.0, 3.0], &[3]), false);
let loss = l1_loss(&pred, &target).unwrap();
assert!((loss.item().unwrap()).abs() < 1e-6);
let pred2 = Variable::new(from_f32(&[2.0, 4.0, 6.0], &[3]), true);
let loss2 = l1_loss(&pred2, &target).unwrap();
assert!((loss2.item().unwrap() - 2.0).abs() < 1e-5);
loss2.backward().unwrap();
assert!(pred2.grad().is_some());
}
#[test]
fn test_smooth_l1_loss() {
let pred = Variable::new(from_f32(&[1.5], &[1]), true);
let target = Variable::new(from_f32(&[1.0], &[1]), false);
let loss = smooth_l1_loss(&pred, &target, 1.0).unwrap();
assert!((loss.item().unwrap() - 0.125).abs() < 1e-5, "got {}", loss.item().unwrap());
let pred2 = Variable::new(from_f32(&[3.0], &[1]), true);
let loss2 = smooth_l1_loss(&pred2, &target, 1.0).unwrap();
assert!((loss2.item().unwrap() - 1.5).abs() < 1e-5);
loss2.backward().unwrap();
assert!(pred2.grad().is_some());
}
#[test]
fn test_kl_div_loss() {
let log_probs = Variable::new(
from_f32(&[-0.693, -0.693, -0.693, -0.693], &[2, 2]),
true,
);
let probs = Variable::new(
from_f32(&[0.5, 0.5, 0.5, 0.5], &[2, 2]),
false,
);
let loss = kl_div_loss(&log_probs, &probs).unwrap();
assert!(loss.item().unwrap().abs() < 0.01, "KL should be ~0, got {}", loss.item().unwrap());
loss.backward().unwrap();
assert!(log_probs.grad().is_some());
}
#[test]
fn test_clip_grad_norm() {
crate::manual_seed(42);
let model = Linear::on_device(2, 1, crate::tensor::test_device()).unwrap();
let params = model.parameters();
let x = Variable::new(from_f32(&[10.0, 20.0], &[1, 2]), false);
let target = Variable::new(from_f32(&[0.0], &[1, 1]), false);
let pred = model.forward(&x).unwrap();
let loss = mse_loss(&pred, &target).unwrap();
loss.backward().unwrap();
let norm_before = clip_grad_norm(&params, 1.0).unwrap();
assert!(norm_before > 1.0, "large input should produce large gradients");
let mut total_sq = 0.0f64;
for p in &params {
if let Some(g) = p.variable.grad() {
for &v in &g.to_f32_vec().unwrap() {
total_sq += (v as f64) * (v as f64);
}
}
}
let clipped_norm = total_sq.sqrt();
assert!(
clipped_norm <= 1.0 + 1e-5,
"clipped norm should be <= 1.0, got {}",
clipped_norm
);
}
#[test]
fn test_clip_grad_value() {
let model = Linear::on_device(2, 1, crate::tensor::test_device()).unwrap();
let params = model.parameters();
let x = Variable::new(from_f32(&[10.0, 20.0], &[1, 2]), false);
let target = Variable::new(from_f32(&[0.0], &[1, 1]), false);
let pred = model.forward(&x).unwrap();
let loss = mse_loss(&pred, &target).unwrap();
loss.backward().unwrap();
clip_grad_value(&params, 0.5).unwrap();
for p in &params {
if let Some(g) = p.variable.grad() {
for &v in &g.to_f32_vec().unwrap() {
assert!(
v.abs() <= 0.5 + 1e-6,
"all grads should be clamped to [-0.5, 0.5], got {}",
v
);
}
}
}
}
#[test]
fn test_step_decay_scheduler() {
let sched = StepDecay::new(0.1, 3, 0.5);
assert!((sched.lr(0) - 0.1).abs() < 1e-10);
assert!((sched.lr(1) - 0.1).abs() < 1e-10);
assert!((sched.lr(2) - 0.1).abs() < 1e-10);
assert!((sched.lr(3) - 0.05).abs() < 1e-10);
assert!((sched.lr(4) - 0.05).abs() < 1e-10);
assert!((sched.lr(5) - 0.05).abs() < 1e-10);
assert!((sched.lr(6) - 0.025).abs() < 1e-10);
}
#[test]
fn test_cosine_scheduler() {
let sched = CosineScheduler::new(0.1, 0.001, 100);
assert!((sched.lr(0) - 0.1).abs() < 1e-10);
let mid_lr = sched.lr(50);
assert!(mid_lr > 0.001 && mid_lr < 0.1, "mid lr={}", mid_lr);
assert!((sched.lr(100) - 0.001).abs() < 1e-5, "end lr={}", sched.lr(100));
}
#[test]
fn test_plateau_scheduler() {
let mut sched = PlateauScheduler::new(0.1, 3, 0.5, 0.001);
sched.observe(1.0);
sched.observe(0.9);
sched.observe(0.8);
assert!((sched.lr() - 0.1).abs() < 1e-10);
sched.observe(0.81);
sched.observe(0.82);
sched.observe(0.83);
assert!((sched.lr() - 0.05).abs() < 1e-10);
}
#[test]
fn test_gelu() {
let gelu = GELU::new();
let x = Variable::new(from_f32(&[0.0, 1.0, -1.0], &[3]), true);
let y = gelu.forward(&x).unwrap();
let data = y.data().to_f32_vec().unwrap();
assert!(data[0].abs() < 0.01, "GELU(0)={}", data[0]);
assert!((data[1] - 0.841).abs() < 0.01, "GELU(1)={}", data[1]);
assert!((data[2] - (-0.159)).abs() < 0.01, "GELU(-1)={}", data[2]);
let loss = y.sum().unwrap();
loss.backward().unwrap();
assert!(x.grad().is_some());
}
#[test]
fn test_silu() {
let silu = SiLU::new();
let x = Variable::new(from_f32(&[0.0, 2.0, -2.0], &[3]), true);
let y = silu.forward(&x).unwrap();
let data = y.data().to_f32_vec().unwrap();
assert!(data[0].abs() < 0.01);
assert!((data[1] - 1.762).abs() < 0.02, "SiLU(2)={}", data[1]);
let loss = y.sum().unwrap();
loss.backward().unwrap();
assert!(x.grad().is_some());
}
#[test]
fn test_dropout() {
let drop = Dropout::new(0.5);
let x = Variable::new(from_f32(&[1.0; 100], &[10, 10]), false);
let y = drop.forward(&x).unwrap();
let data = y.data().to_f32_vec().unwrap();
let zeros = data.iter().filter(|&&v| v == 0.0).count();
assert!(zeros > 10 && zeros < 90, "zeros={} of 100", zeros);
for &v in &data {
if v != 0.0 {
assert!((v - 2.0).abs() < 1e-5, "scaled value should be 2.0, got {}", v);
}
}
drop.set_training(false);
let y_eval = drop.forward(&x).unwrap();
let eval_data = y_eval.data().to_f32_vec().unwrap();
assert!(eval_data.iter().all(|&v| (v - 1.0).abs() < 1e-5));
}
#[test]
fn test_layernorm() {
let ln = LayerNorm::on_device(4, crate::tensor::test_device()).unwrap();
let x = Variable::new(from_f32(&[1.0, 2.0, 3.0, 4.0], &[1, 4]), true);
let y = ln.forward(&x).unwrap();
let data = y.data().to_f32_vec().unwrap();
assert_eq!(y.shape(), vec![1, 4]);
let mean: f32 = data.iter().sum::<f32>() / 4.0;
assert!(mean.abs() < 0.1, "mean should be ~0, got {}", mean);
let loss = y.sum().unwrap();
loss.backward().unwrap();
assert!(x.grad().is_some());
assert_eq!(ln.parameters().len(), 2);
}
#[test]
fn test_embedding() {
let emb = Embedding::on_device(5, 3, crate::tensor::test_device()).unwrap();
let indices = Variable::new(from_f32(&[0.0, 2.0, 4.0], &[3]), false);
let y = emb.forward(&indices).unwrap();
assert_eq!(y.shape(), vec![3, 3]);
let idx_same = Variable::new(from_f32(&[1.0, 1.0], &[2]), false);
let y2 = emb.forward(&idx_same).unwrap();
let data = y2.data().to_f32_vec().unwrap();
assert_eq!(&data[0..3], &data[3..6]);
assert_eq!(emb.parameters().len(), 1);
}
#[test]
fn test_embedding_backward() {
let emb = Embedding::on_device(5, 3, crate::tensor::test_device()).unwrap();
let indices = Variable::new(from_f32(&[0.0, 2.0], &[2]), false);
let y = emb.forward(&indices).unwrap();
let loss = y.sum().unwrap();
loss.backward().unwrap();
let grad = emb.weight.variable.grad().unwrap();
let grad_data = grad.to_f32_vec().unwrap();
assert!(grad_data[0..3].iter().all(|&v| (v - 1.0).abs() < 1e-5));
assert!(grad_data[3..6].iter().all(|&v| v.abs() < 1e-5));
assert!(grad_data[6..9].iter().all(|&v| (v - 1.0).abs() < 1e-5));
}
#[test]
fn test_grucell() {
let gru = GRUCell::on_device(4, 3, crate::tensor::test_device()).unwrap();
let x = Variable::new(from_f32(&[1.0, 2.0, 3.0, 4.0], &[1, 4]), true);
let h1 = gru.forward_step(&x, None).unwrap();
assert_eq!(h1.shape(), vec![1, 3]);
let h2 = gru.forward_step(&x, Some(&h1)).unwrap();
assert_eq!(h2.shape(), vec![1, 3]);
let loss = h2.sum().unwrap();
loss.backward().unwrap();
assert!(x.grad().is_some());
assert_eq!(gru.parameters().len(), 4);
}
#[test]
fn test_lstmcell() {
let lstm = LSTMCell::on_device(4, 3, crate::tensor::test_device()).unwrap();
let x = Variable::new(from_f32(&[1.0, 2.0, 3.0, 4.0], &[1, 4]), true);
let state1 = lstm.forward_step(&x, None).unwrap();
assert_eq!(state1.shape(), vec![1, 6]);
let h1 = state1.narrow(1, 0, 3).unwrap();
assert_eq!(h1.shape(), vec![1, 3]);
let state2 = lstm.forward_step(&x, Some(&state1)).unwrap();
assert_eq!(state2.shape(), vec![1, 6]);
let loss = state2.sum().unwrap();
loss.backward().unwrap();
assert!(x.grad().is_some());
assert_eq!(lstm.parameters().len(), 4);
}
#[test]
fn test_conv2d() {
let conv = Conv2d::build(1, 2, 3, true, [1, 1], [0, 0], [1, 1], 1, crate::tensor::test_device()).unwrap();
let x = Variable::new(
Tensor::randn(&[1, 1, 5, 5], crate::tensor::test_opts()).unwrap(),
true,
);
let out = conv.forward(&x).unwrap();
assert_eq!(out.shape(), vec![1, 2, 3, 3]);
let loss = out.sum().unwrap();
loss.backward().unwrap();
assert!(x.grad().is_some());
assert_eq!(conv.parameters().len(), 2);
}
#[test]
fn test_conv2d_no_bias() {
let conv = Conv2d::build(3, 8, 3, false, [1, 1], [0, 0], [1, 1], 1, crate::tensor::test_device()).unwrap();
let x = Variable::new(
Tensor::randn(&[2, 3, 8, 8], crate::tensor::test_opts()).unwrap(),
true,
);
let out = conv.forward(&x).unwrap();
assert_eq!(out.shape(), vec![2, 8, 6, 6]);
assert_eq!(conv.parameters().len(), 1);
}
#[test]
fn test_conv2d_with_padding() {
let conv = Conv2d::build(1, 1, 3, true, [1, 1], [1, 1], [1, 1], 1, crate::tensor::test_device()).unwrap();
let x = Variable::new(
Tensor::randn(&[1, 1, 5, 5], crate::tensor::test_opts()).unwrap(),
true,
);
let out = conv.forward(&x).unwrap();
assert_eq!(out.shape(), vec![1, 1, 5, 5]);
}
#[test]
fn test_conv_transpose2d() {
let conv = ConvTranspose2d::build(2, 1, 3, true, [1, 1], [0, 0], [0, 0], [1, 1], 1, crate::tensor::test_device()).unwrap();
let x = Variable::new(
Tensor::randn(&[1, 2, 3, 3], crate::tensor::test_opts()).unwrap(),
true,
);
let out = conv.forward(&x).unwrap();
assert_eq!(out.shape(), vec![1, 1, 5, 5]);
let loss = out.sum().unwrap();
loss.backward().unwrap();
assert!(x.grad().is_some());
assert_eq!(conv.parameters().len(), 2);
}
#[test]
fn test_batchnorm_training() {
let bn = BatchNorm::on_device(4, crate::tensor::test_device()).unwrap();
let x = Variable::new(
Tensor::randn(&[8, 4], crate::tensor::test_opts()).unwrap(),
true,
);
let out = bn.forward(&x).unwrap();
assert_eq!(out.shape(), vec![8, 4]);
let out_data = out.data().to_f32_vec().unwrap();
let mean: f32 = out_data.iter().sum::<f32>() / out_data.len() as f32;
assert!(mean.abs() < 0.5, "mean should be close to 0, got {}", mean);
let loss = out.sum().unwrap();
loss.backward().unwrap();
assert!(x.grad().is_some());
assert_eq!(bn.parameters().len(), 2);
}
#[test]
fn test_batchnorm_eval() {
let bn = BatchNorm::on_device(3, crate::tensor::test_device()).unwrap();
for _ in 0..5 {
let x = Variable::new(
Tensor::randn(&[4, 3], crate::tensor::test_opts()).unwrap(),
false,
);
bn.forward(&x).unwrap();
}
bn.set_training(false);
let x = Variable::new(
Tensor::randn(&[2, 3], crate::tensor::test_opts()).unwrap(),
false,
);
let out = bn.forward(&x).unwrap();
assert_eq!(out.shape(), vec![2, 3]);
}
#[test]
fn test_adamw() {
let w = Parameter {
variable: Variable::new(from_f32(&[1.0, 2.0, 3.0], &[3]), true),
name: "w".into(),
};
let params = vec![w.clone()];
let mut opt = AdamW::new(&params, 0.01, 0.01);
let loss = w.variable.mul_scalar(2.0).unwrap().sum().unwrap();
loss.backward().unwrap();
let before = w.variable.data().to_f32_vec().unwrap();
opt.step().unwrap();
let after = w.variable.data().to_f32_vec().unwrap();
assert!(before != after, "AdamW should update parameters");
opt.zero_grad();
}
#[test]
fn test_xavier_init() {
let t = init::xavier_uniform(&[10, 20], 10, 20, crate::tensor::test_device()).unwrap();
assert_eq!(t.shape(), vec![10, 20]);
let data = t.to_f32_vec().unwrap();
let bound = (6.0 / 30.0_f64).sqrt() as f32;
for &v in &data {
assert!(v >= -bound - 0.01 && v <= bound + 0.01,
"xavier_uniform value {} out of bounds [{}, {}]", v, -bound, bound);
}
let t = init::xavier_normal(&[10, 20], 10, 20, crate::tensor::test_device()).unwrap();
assert_eq!(t.shape(), vec![10, 20]);
}
#[test]
fn test_linspace_arange() {
let t = Tensor::linspace(0.0, 1.0, 5, crate::tensor::test_opts()).unwrap();
assert_eq!(t.shape(), vec![5]);
let data = t.to_f32_vec().unwrap();
assert!((data[0] - 0.0).abs() < 1e-5);
assert!((data[4] - 1.0).abs() < 1e-5);
let t = Tensor::arange(0.0, 5.0, 1.0, crate::tensor::test_opts()).unwrap();
assert_eq!(t.shape(), vec![5]);
let data = t.to_f32_vec().unwrap();
assert_eq!(data, vec![0.0, 1.0, 2.0, 3.0, 4.0]);
}
#[test]
fn test_min_max_argmax() {
let t = from_f32(&[3.0, 1.0, 4.0, 1.0, 5.0, 9.0], &[2, 3]);
assert!((t.min().unwrap().item().unwrap() - 1.0).abs() < 1e-5);
let min_d1 = t.min_dim(1, false).unwrap();
assert_eq!(min_d1.shape(), vec![2]);
let data = min_d1.to_f32_vec().unwrap();
assert!((data[0] - 1.0).abs() < 1e-5);
assert!((data[1] - 1.0).abs() < 1e-5);
let max_d1 = t.max_dim(1, false).unwrap();
let data = max_d1.to_f32_vec().unwrap();
assert!((data[0] - 4.0).abs() < 1e-5);
assert!((data[1] - 9.0).abs() < 1e-5);
let am = t.argmax(1, false).unwrap();
assert_eq!(am.shape(), vec![2]);
}
#[test]
fn test_comparisons() {
let t = from_f32(&[1.0, 2.0, 3.0, 4.0], &[4]);
let ge = t.ge_scalar(2.0).unwrap().to_f32_vec().unwrap();
assert_eq!(ge, vec![0.0, 1.0, 1.0, 1.0]);
let le = t.le_scalar(2.0).unwrap().to_f32_vec().unwrap();
assert_eq!(le, vec![1.0, 1.0, 0.0, 0.0]);
let lt = t.lt_scalar(2.0).unwrap().to_f32_vec().unwrap();
assert_eq!(lt, vec![1.0, 0.0, 0.0, 0.0]);
}
#[test]
fn test_squeeze_unsqueeze() {
let x = Variable::new(from_f32(&[1.0, 2.0, 3.0], &[1, 3]), true);
let squeezed = x.squeeze(0).unwrap();
assert_eq!(squeezed.shape(), vec![3]);
let unsqueezed = squeezed.unsqueeze(1).unwrap();
assert_eq!(unsqueezed.shape(), vec![3, 1]);
let loss = unsqueezed.sum().unwrap();
loss.backward().unwrap();
assert_eq!(x.grad().unwrap().shape(), vec![1, 3]);
}
#[test]
fn test_where_cond() {
let cond = from_f32(&[1.0, 0.0, 1.0, 0.0], &[4]);
let x = from_f32(&[10.0, 20.0, 30.0, 40.0], &[4]);
let y = from_f32(&[-1.0, -2.0, -3.0, -4.0], &[4]);
let result = Tensor::where_cond(&cond, &x, &y).unwrap().to_f32_vec().unwrap();
assert_eq!(result, vec![10.0, -2.0, 30.0, -4.0]);
}
#[test]
fn test_to_dtype() {
use crate::tensor::DType;
let t = from_f32(&[1.5, 2.7], &[2]);
let t64 = t.to_dtype(DType::Float64).unwrap();
assert_eq!(t64.dtype(), DType::Float64);
}
#[test]
fn test_all_finite() {
let t = from_f32(&[1.0, 2.0, 3.0], &[3]);
assert!(t.all_finite().unwrap());
}
#[test]
fn test_save_load_checkpoint() {
let model = Linear::on_device(3, 2, crate::tensor::test_device()).unwrap();
let params = model.parameters();
params[0].variable.set_data(from_f32(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]));
params[1].variable.set_data(from_f32(&[0.1, 0.2], &[2]));
let named: Vec<(String, Parameter)> = params.into_iter()
.map(|p| (format!("linear/{}", p.name), p))
.collect();
let mut buf = Vec::new();
checkpoint::save_checkpoint(&mut buf, &named, &[], None).unwrap();
let model2 = Linear::on_device(3, 2, crate::tensor::test_device()).unwrap();
let named2: Vec<(String, Parameter)> = model2.parameters().into_iter()
.map(|p| (format!("linear/{}", p.name), p))
.collect();
let mut cursor = std::io::Cursor::new(&buf);
let report = checkpoint::load_checkpoint(&mut cursor, &named2, &[], None).unwrap();
assert_eq!(report.loaded.len(), 2);
assert!(report.missing.is_empty());
let w = named2[0].1.variable.data().to_f32_vec().unwrap();
assert_eq!(w, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
let b = named2[1].1.variable.data().to_f32_vec().unwrap();
assert!((b[0] - 0.1).abs() < 1e-5 && (b[1] - 0.2).abs() < 1e-5);
}
#[test]
fn test_save_load_sgd_state() {
use optim::Stateful;
let model = Linear::on_device(2, 1, crate::tensor::test_device()).unwrap();
let params = model.parameters();
let mut optim = SGD::new(&params, 0.1, 0.9);
let x = Variable::new(from_f32(&[1.0, 2.0], &[1, 2]), false);
let target = Variable::new(from_f32(&[5.0], &[1, 1]), false);
let pred = model.forward(&x).unwrap();
let loss = mse_loss(&pred, &target).unwrap();
loss.backward().unwrap();
optim.step().unwrap();
let mut buf = Vec::new();
optim.save_state(&mut buf).unwrap();
let mut optim2 = SGD::new(&params, 0.5, 0.9);
let mut cursor = std::io::Cursor::new(&buf);
optim2.load_state(&mut cursor).unwrap();
assert!((optim2.lr() - 0.1).abs() < 1e-10, "lr should be restored");
}
#[test]
fn test_save_load_adam_state() {
use optim::Stateful;
let model = Linear::on_device(2, 1, crate::tensor::test_device()).unwrap();
let params = model.parameters();
let mut optim = Adam::new(&params, 0.01);
for _ in 0..2 {
optim.zero_grad();
let x = Variable::new(from_f32(&[1.0, 2.0], &[1, 2]), false);
let target = Variable::new(from_f32(&[5.0], &[1, 1]), false);
let pred = model.forward(&x).unwrap();
let loss = mse_loss(&pred, &target).unwrap();
loss.backward().unwrap();
optim.step().unwrap();
}
let mut buf = Vec::new();
optim.save_state(&mut buf).unwrap();
let mut optim2 = Adam::new(&params, 0.5);
let mut cursor = std::io::Cursor::new(&buf);
optim2.load_state(&mut cursor).unwrap();
assert!((optim2.lr() - 0.01).abs() < 1e-10, "lr should be restored");
}
#[test]
fn test_cast_parameters() {
use crate::tensor::DType;
let model = Linear::on_device(3, 2, crate::tensor::test_device()).unwrap();
let params = model.parameters();
assert_eq!(params[0].variable.data().dtype(), DType::Float32);
amp::cast_parameters(&params, DType::Float64);
assert_eq!(params[0].variable.data().dtype(), DType::Float64);
assert_eq!(params[1].variable.data().dtype(), DType::Float64);
amp::cast_parameters(&params, DType::Float32);
assert_eq!(params[0].variable.data().dtype(), DType::Float32);
}
#[test]
fn test_grad_scaler_finite() {
use optim::Stateful;
let model = Linear::on_device(2, 1, crate::tensor::test_device()).unwrap();
let params = model.parameters();
let mut optim = SGD::new(&params, 0.1, 0.0);
let x = Variable::new(from_f32(&[1.0, 2.0], &[1, 2]), false);
let target = Variable::new(from_f32(&[5.0], &[1, 1]), false);
let mut scaler = amp::GradScaler::new();
let pred = model.forward(&x).unwrap();
let loss = mse_loss(&pred, &target).unwrap();
let scaled_loss = scaler.scale(&loss).unwrap();
scaled_loss.backward().unwrap();
let params_for_step = model.parameters();
let success = scaler.step(&params_for_step, &mut || optim.step()).unwrap();
assert!(success, "step should succeed with finite gradients");
scaler.update();
let mut buf = Vec::new();
scaler.save_state(&mut buf).unwrap();
let mut scaler2 = amp::GradScaler::new();
let mut cursor = std::io::Cursor::new(&buf);
scaler2.load_state(&mut cursor).unwrap();
assert!((scaler2.scale_factor() - scaler.scale_factor()).abs() < 1e-10);
}
#[test]
fn test_autocast_guard_toggle() {
assert!(!amp::is_autocast_enabled());
{
let _guard = amp::AutocastGuard::new(crate::tensor::DType::Float16);
assert!(amp::is_autocast_enabled());
}
assert!(!amp::is_autocast_enabled());
}
#[test]
fn test_autocast_nesting() {
assert!(!amp::is_autocast_enabled());
let outer = amp::AutocastGuard::new(crate::tensor::DType::Float16);
assert!(amp::is_autocast_enabled());
{
let _inner = amp::AutocastGuard::new(crate::tensor::DType::Float16);
assert!(amp::is_autocast_enabled());
}
assert!(amp::is_autocast_enabled());
drop(outer);
assert!(!amp::is_autocast_enabled());
}
#[test]
fn test_autocast_closure() {
assert!(!amp::is_autocast_enabled());
amp::autocast(crate::tensor::DType::Float16, || {
assert!(amp::is_autocast_enabled());
});
assert!(!amp::is_autocast_enabled());
}
#[test]
fn test_autocast_matmul_output_dtype() {
let opts = crate::tensor::test_opts();
let dev = crate::tensor::test_device();
let (device_type, _) = dev.to_ffi();
let a = Tensor::randn(&[4, 4], opts).unwrap();
let b = Tensor::randn(&[4, 4], opts).unwrap();
let c = a.matmul(&b).unwrap();
assert_eq!(c.dtype(), crate::tensor::DType::Float32);
let cast_dtype = if dev.is_cuda() {
crate::tensor::DType::Float16
} else {
crate::tensor::DType::BFloat16
};
let _guard = amp::AutocastGuard::for_device(device_type, cast_dtype);
let c_amp = a.matmul(&b).unwrap();
assert_eq!(c_amp.dtype(), cast_dtype,
"matmul under autocast should produce {:?}, got {:?}", cast_dtype, c_amp.dtype());
}
#[test]
fn test_adaptive_avg_pool2d() {
use crate::autograd::adaptive_avg_pool2d;
let x = Variable::new(
Tensor::randn(&[1, 1, 4, 4], crate::tensor::test_opts()).unwrap(),
true,
);
let out = adaptive_avg_pool2d(&x, [2, 2]).unwrap();
assert_eq!(out.shape(), vec![1, 1, 2, 2]);
let loss = out.sum().unwrap();
loss.backward().unwrap();
assert!(x.grad().is_some());
assert_eq!(x.grad().unwrap().shape(), vec![1, 1, 4, 4]);
}
#[test]
fn test_grid_sample() {
use crate::autograd::grid_sample;
let input = Variable::new(
Tensor::randn(&[1, 1, 4, 4], crate::tensor::test_opts()).unwrap(),
true,
);
let grid = Variable::new(
Tensor::rand(&[1, 2, 2, 2], crate::tensor::test_opts()).unwrap()
.mul_scalar(2.0).unwrap()
.add_scalar(-1.0).unwrap(), true,
);
let out = grid_sample(&input, &grid, 0, 0, true).unwrap();
assert_eq!(out.shape(), vec![1, 1, 2, 2]);
let loss = out.sum().unwrap();
loss.backward().unwrap();
assert!(input.grad().is_some());
assert!(grid.grad().is_some());
}
#[test]
fn test_identity() {
let id = activation::Identity;
let x = Variable::new(from_f32(&[1.0, 2.0, 3.0], &[3]), true);
let y = id.forward(&x).unwrap();
assert_eq!(y.data().to_f32_vec().unwrap(), vec![1.0, 2.0, 3.0]);
assert!(id.parameters().is_empty());
}
#[test]
fn test_cross_entropy_indices() {
let pred = Variable::new(
from_f32(&[2.0, 1.0, 1.0, 3.0], &[2, 2]),
true,
);
let target_idx = Variable::new(
Tensor::from_i64(&[0, 1], &[2], crate::tensor::test_device()).unwrap(),
false,
);
let loss_idx = cross_entropy_loss(&pred, &target_idx).unwrap();
let pred2 = Variable::new(
from_f32(&[2.0, 1.0, 1.0, 3.0], &[2, 2]),
true,
);
let target_oh = Variable::new(
from_f32(&[1.0, 0.0, 0.0, 1.0], &[2, 2]),
false,
);
let loss_oh = cross_entropy_loss(&pred2, &target_oh).unwrap();
let v1 = loss_idx.item().unwrap();
let v2 = loss_oh.item().unwrap();
assert!(
(v1 - v2).abs() < 1e-5,
"index loss ({}) should match one-hot loss ({})", v1, v2
);
loss_idx.backward().unwrap();
assert!(pred.grad().is_some());
}
#[test]
fn test_cross_entropy_indices_converges() {
let model = Linear::on_device(2, 3, crate::tensor::test_device()).unwrap();
let params = model.parameters();
let mut optim = Adam::new(&params, 0.05);
let x = Variable::new(from_f32(&[1.0, 0.0, 0.0, 1.0, 0.5, 0.5], &[3, 2]), false);
let target = Variable::new(
Tensor::from_i64(&[0, 1, 2], &[3], crate::tensor::test_device()).unwrap(),
false,
);
let mut last_loss = f64::MAX;
for _ in 0..200 {
optim.zero_grad();
let pred = model.forward(&x).unwrap();
let loss = cross_entropy_loss(&pred, &target).unwrap();
last_loss = loss.item().unwrap();
loss.backward().unwrap();
optim.step().unwrap();
}
assert!(last_loss < 0.5, "cross entropy with indices should converge, got {}", last_loss);
}
#[test]
fn test_batchnorm2d() {
let bn = BatchNorm2d::on_device(3, crate::tensor::test_device()).unwrap();
let x = Variable::new(
Tensor::randn(&[4, 3, 8, 8], crate::tensor::test_opts()).unwrap(),
true,
);
let out = bn.forward(&x).unwrap();
assert_eq!(out.shape(), vec![4, 3, 8, 8]);
let loss = out.sum().unwrap();
loss.backward().unwrap();
assert!(x.grad().is_some());
assert_eq!(bn.parameters().len(), 2);
}
#[test]
fn test_batchnorm2d_eval() {
let bn = BatchNorm2d::on_device(4, crate::tensor::test_device()).unwrap();
for _ in 0..3 {
let x = Variable::new(
Tensor::randn(&[4, 4, 6, 6], crate::tensor::test_opts()).unwrap(),
false,
);
bn.forward(&x).unwrap();
}
bn.set_training(false);
let x = Variable::new(
Tensor::randn(&[2, 4, 6, 6], crate::tensor::test_opts()).unwrap(),
false,
);
let out = bn.forward(&x).unwrap();
assert_eq!(out.shape(), vec![2, 4, 6, 6]);
}
#[test]
fn test_conv1d() {
let conv = Conv1d::build(1, 2, 3, true, 1, 0, 1, 1, crate::tensor::test_device()).unwrap();
let x = Variable::new(
Tensor::randn(&[1, 1, 10], crate::tensor::test_opts()).unwrap(),
true,
);
let out = conv.forward(&x).unwrap();
assert_eq!(out.shape(), vec![1, 2, 8]);
let loss = out.sum().unwrap();
loss.backward().unwrap();
assert!(x.grad().is_some());
assert_eq!(conv.parameters().len(), 2);
}
#[test]
fn test_conv1d_no_bias() {
let conv = Conv1d::build(3, 8, 3, false, 1, 0, 1, 1, crate::tensor::test_device()).unwrap();
let x = Variable::new(
Tensor::randn(&[2, 3, 20], crate::tensor::test_opts()).unwrap(),
true,
);
let out = conv.forward(&x).unwrap();
assert_eq!(out.shape(), vec![2, 8, 18]);
assert_eq!(conv.parameters().len(), 1);
}
#[test]
fn test_conv1d_with_padding() {
let conv = Conv1d::build(1, 1, 3, true, 1, 1, 1, 1, crate::tensor::test_device()).unwrap();
let x = Variable::new(
Tensor::randn(&[1, 1, 10], crate::tensor::test_opts()).unwrap(),
true,
);
let out = conv.forward(&x).unwrap();
assert_eq!(out.shape(), vec![1, 1, 10]);
}
#[test]
fn test_conv1d_builder() {
let conv = Conv1d::configure(3, 16, 5)
.with_stride(2)
.with_padding(2)
.on_device(crate::tensor::test_device())
.done()
.unwrap();
let x = Variable::new(
Tensor::randn(&[1, 3, 100], crate::tensor::test_opts()).unwrap(),
false,
);
let out = conv.forward(&x).unwrap();
assert_eq!(out.shape(), vec![1, 16, 50]);
}
#[test]
fn test_conv_transpose1d() {
let conv = ConvTranspose1d::build(2, 1, 3, true, 1, 0, 0, 1, 1, crate::tensor::test_device()).unwrap();
let x = Variable::new(
Tensor::randn(&[1, 2, 5], crate::tensor::test_opts()).unwrap(),
true,
);
let out = conv.forward(&x).unwrap();
assert_eq!(out.shape(), vec![1, 1, 7]);
let loss = out.sum().unwrap();
loss.backward().unwrap();
assert!(x.grad().is_some());
assert_eq!(conv.parameters().len(), 2);
}
#[test]
fn test_groupnorm() {
let gn = GroupNorm::on_device(4, 8, crate::tensor::test_device()).unwrap();
let x = Variable::new(
Tensor::randn(&[2, 8, 4, 4], crate::tensor::test_opts()).unwrap(),
true,
);
let out = gn.forward(&x).unwrap();
assert_eq!(out.shape(), vec![2, 8, 4, 4]);
let loss = out.sum().unwrap();
loss.backward().unwrap();
assert!(x.grad().is_some());
assert_eq!(gn.parameters().len(), 2);
}
#[test]
fn test_groupnorm_1d() {
let gn = GroupNorm::on_device(2, 4, crate::tensor::test_device()).unwrap();
let x = Variable::new(
Tensor::randn(&[3, 4, 10], crate::tensor::test_opts()).unwrap(),
false,
);
let out = gn.forward(&x).unwrap();
assert_eq!(out.shape(), vec![3, 4, 10]);
}
#[test]
fn test_cosine_similarity() {
let a = Tensor::from_f32(&[1.0, 0.0, 0.0, 0.0, 1.0, 0.0], &[2, 3], crate::tensor::test_device()).unwrap();
let b = Tensor::from_f32(&[1.0, 0.0, 0.0, 0.0, 0.0, 1.0], &[2, 3], crate::tensor::test_device()).unwrap();
let sim = a.cosine_similarity(&b, 1, 1e-8).unwrap();
let data = sim.to_f32_vec().unwrap();
assert!((data[0] - 1.0).abs() < 1e-5);
assert!(data[1].abs() < 1e-5);
}
#[test]
fn test_cosine_similarity_autograd() {
let a = Variable::new(
Tensor::from_f32(&[1.0, 2.0, 3.0], &[1, 3], crate::tensor::test_device()).unwrap(),
true,
);
let b = Variable::new(
Tensor::from_f32(&[4.0, 5.0, 6.0], &[1, 3], crate::tensor::test_device()).unwrap(),
false,
);
let sim = a.cosine_similarity(&b, 1, 1e-8).unwrap();
let loss = sim.sum().unwrap();
loss.backward().unwrap();
assert!(a.grad().is_some());
}
#[test]
fn test_bce_loss() {
let pred = Variable::new(
Tensor::from_f32(&[0.8, 0.2, 0.9, 0.1], &[4], crate::tensor::test_device()).unwrap(),
true,
);
let target = Variable::new(
Tensor::from_f32(&[1.0, 0.0, 1.0, 0.0], &[4], crate::tensor::test_device()).unwrap(),
false,
);
let loss = bce_loss(&pred, &target).unwrap();
let val = loss.data().item().unwrap();
assert!(val > 0.0 && val < 1.0, "bce_loss = {}", val);
loss.backward().unwrap();
assert!(pred.grad().is_some());
}
#[test]
fn test_pad_mode_constant() {
let t = Tensor::from_f32(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[1, 1, 6], crate::tensor::test_device()).unwrap();
let padded = t.pad_mode(&[1, 1], 0, 0.0).unwrap();
assert_eq!(padded.shape(), vec![1, 1, 8]);
}
#[test]
fn test_pad_mode_reflect() {
let t = Tensor::from_f32(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[1, 1, 6], crate::tensor::test_device()).unwrap();
let padded = t.pad_mode(&[2, 2], 1, 0.0).unwrap();
assert_eq!(padded.shape(), vec![1, 1, 10]);
let data = padded.to_f32_vec().unwrap();
assert!((data[0] - 3.0).abs() < 1e-5);
assert!((data[1] - 2.0).abs() < 1e-5);
}
#[test]
fn test_pad_mode_replicate() {
let t = Tensor::from_f32(&[1.0, 2.0, 3.0, 4.0], &[1, 1, 4], crate::tensor::test_device()).unwrap();
let padded = t.pad_mode(&[1, 1], 2, 0.0).unwrap();
assert_eq!(padded.shape(), vec![1, 1, 6]);
let data = padded.to_f32_vec().unwrap();
assert!((data[0] - 1.0).abs() < 1e-5);
assert!((data[5] - 4.0).abs() < 1e-5);
}
#[test]
fn test_interpolate_nearest() {
let t = Tensor::randn(&[1, 1, 4, 4], crate::tensor::test_opts()).unwrap();
let up = t.interpolate(&[8, 8], 0, false).unwrap();
assert_eq!(up.shape(), vec![1, 1, 8, 8]);
}
#[test]
fn test_interpolate_bilinear() {
let t = Tensor::randn(&[1, 3, 4, 4], crate::tensor::test_opts()).unwrap();
let up = t.interpolate(&[8, 8], 1, false).unwrap();
assert_eq!(up.shape(), vec![1, 3, 8, 8]);
}
#[test]
fn test_interpolate_bicubic() {
let t = Tensor::randn(&[1, 3, 8, 8], crate::tensor::test_opts()).unwrap();
let down = t.interpolate(&[4, 4], 2, false).unwrap();
assert_eq!(down.shape(), vec![1, 3, 4, 4]);
}
#[test]
fn test_eq_tensor_int64() {
let a = Tensor::from_i64(&[1, 2, 3], &[3], crate::tensor::test_device()).unwrap();
let b = Tensor::from_i64(&[1, 5, 3], &[3], crate::tensor::test_device()).unwrap();
let eq = a.eq_tensor(&b).unwrap();
assert_eq!(eq.dtype(), crate::tensor::DType::Float32);
let data = eq.to_f32_vec().unwrap();
assert_eq!(data, vec![1.0, 0.0, 1.0]);
let m = eq.mean().unwrap().item().unwrap();
assert!((m - 2.0 / 3.0).abs() < 1e-5);
}
}