torsh-tensor
PyTorch-compatible tensor implementation for ToRSh, built on top of scirs2.
Overview
This crate provides the core Tensor type with a PyTorch-like API, wrapping scirs2's autograd functionality for gradient computation.
Features
- PyTorch-compatible tensor operations
- Automatic differentiation support
- Broadcasting and shape manipulation
- Indexing, slicing, and boolean masking
- Integration with scirs2 for optimized computation
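This README lists broadcasting as a feature but does not spell out its rules. Assuming ToRSh follows PyTorch's NumPy-style semantics, the result shape of a broadcast can be sketched as follows. `broadcast_shape` is a hypothetical helper written for illustration, not part of the crate's API.

```rust
/// Computes the broadcast shape of two shapes under the NumPy/PyTorch rule
/// (assumed here, not confirmed for ToRSh): align shapes from the trailing
/// dimension; each pair of sizes must be equal, or one of them must be 1.
/// Returns `None` when the shapes are incompatible.
fn broadcast_shape(a: &[usize], b: &[usize]) -> Option<Vec<usize>> {
    let ndim = a.len().max(b.len());
    let mut out = vec![0; ndim];
    for i in 0..ndim {
        // Missing leading dimensions are treated as size 1.
        let da = if i < a.len() { a[a.len() - 1 - i] } else { 1 };
        let db = if i < b.len() { b[b.len() - 1 - i] } else { 1 };
        out[ndim - 1 - i] = match (da, db) {
            (x, y) if x == y => x,
            (1, y) => y,
            (x, 1) => x,
            _ => return None,
        };
    }
    Some(out)
}
```

For example, shapes `[3, 1]` and `[4]` broadcast to `[3, 4]`, while `[2, 3]` and `[4, 3]` are incompatible.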
Usage
Basic Tensor Creation
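Argument values, shapes, and constructor signatures in these examples are illustrative; check the crate's API documentation for the exact signatures.

```rust
use torsh_tensor::prelude::*; // assumed prelude path

// Create tensors using the tensor! macro
let a = tensor![[1.0, 2.0], [3.0, 4.0]];
let b = tensor![[5.0, 6.0], [7.0, 8.0]];

// Create tensors with specific shapes (PyTorch-style constructors assumed)
let zeros = zeros(&[2, 3]);
let ones = ones(&[2, 3]);
let eye = eye(3);

// Random tensors: uniform on [0, 1) and standard normal
let uniform = rand(&[2, 3]);
let normal = randn(&[2, 3]);
```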
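Tensors created with a given shape store their elements in a flat buffer. Assuming a row-major (C-order) contiguous layout, as PyTorch uses by default, the mapping from a multi-dimensional index to a buffer offset can be sketched like this; `contiguous_strides` and `flat_offset` are hypothetical helpers, not ToRSh functions.

```rust
/// Row-major (C-order) strides for a contiguous tensor of the given shape:
/// the last dimension has stride 1, and each earlier stride is the product
/// of all later dimension sizes.
fn contiguous_strides(shape: &[usize]) -> Vec<usize> {
    let mut strides = vec![1; shape.len()];
    for i in (0..shape.len().saturating_sub(1)).rev() {
        strides[i] = strides[i + 1] * shape[i + 1];
    }
    strides
}

/// Flat buffer offset of a multi-dimensional index: sum of index * stride.
fn flat_offset(index: &[usize], strides: &[usize]) -> usize {
    index.iter().zip(strides).map(|(i, s)| i * s).sum()
}
```

Under these strides, the diagonal element (i, i) of an n x n identity sits at offset i * (n + 1).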
Tensor Operations
```rust
// Element-wise operations
let c = a.add(&b)?;
let d = a.mul(&b)?;

// Matrix multiplication
let e = a.matmul(&b)?;

// Reductions over all elements
let sum = a.sum();
let mean = a.mean();
let max = a.max();

// Activation functions
let relu = a.relu();
let sigmoid = a.sigmoid();
```
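For reference, the math behind 2-D `matmul`, `relu`, and `sigmoid` can be sketched on plain row-major `Vec<f32>` buffers. These helpers are hypothetical and say nothing about how ToRSh or scirs2 actually compute these operations.

```rust
/// Naive 2-D matrix product of an (m x k) and a (k x n) row-major matrix.
/// Panics if the buffer lengths do not match the given shapes.
fn matmul(a: &[f32], b: &[f32], m: usize, k: usize, n: usize) -> Vec<f32> {
    assert_eq!(a.len(), m * k, "lhs must be m x k");
    assert_eq!(b.len(), k * n, "rhs must be k x n");
    let mut out = vec![0.0_f32; m * n];
    for i in 0..m {
        for p in 0..k {
            let a_ip = a[i * k + p];
            // Accumulate row i of `a` times row p of `b` into row i of `out`.
            for j in 0..n {
                out[i * n + j] += a_ip * b[p * n + j];
            }
        }
    }
    out
}

/// Element-wise ReLU: max(x, 0).
fn relu(x: &[f32]) -> Vec<f32> {
    x.iter().map(|&v| v.max(0.0)).collect()
}

/// Element-wise logistic sigmoid: 1 / (1 + e^(-x)).
fn sigmoid(x: &[f32]) -> Vec<f32> {
    x.iter().map(|&v| 1.0 / (1.0 + (-v).exp())).collect()
}
```

The i-p-j loop order walks both `b` and `out` along contiguous rows, which is the usual cache-friendly choice for a naive row-major matmul.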
Shape Manipulation
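```rust
// Reshape (the new shape must hold the same number of elements)
let reshaped = a.view(&[4, 1])?;

// Transpose a 2-D tensor
let transposed = a.t()?;

// Squeeze (drop size-1 dimensions) and unsqueeze (insert one at dim 0)
let squeezed = a.squeeze();
let unsqueezed = a.unsqueeze(0)?;
```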
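The shape bookkeeping behind `view`, `squeeze`, and `unsqueeze` can be sketched as pure functions on shapes. This assumes PyTorch conventions (one `-1` in a view shape means "infer this size"; a no-argument squeeze drops every size-1 dimension), which this README does not confirm for ToRSh; all three helpers are hypothetical.

```rust
/// Resolves a view-style target shape where at most one entry is -1
/// ("infer this size from the element count"). Returns `None` for more
/// than one -1, any other negative size, or a mismatched element count.
fn infer_view_shape(numel: usize, target: &[i64]) -> Option<Vec<usize>> {
    let mut inferred_at = None;
    let mut known: usize = 1;
    for (i, &d) in target.iter().enumerate() {
        if d == -1 {
            if inferred_at.is_some() {
                return None; // only one dimension may be inferred
            }
            inferred_at = Some(i);
        } else if d < 0 {
            return None;
        } else {
            known *= d as usize;
        }
    }
    let mut shape: Vec<usize> = target.iter().map(|&d| d.max(0) as usize).collect();
    match inferred_at {
        Some(i) => {
            if known == 0 || numel % known != 0 {
                return None;
            }
            shape[i] = numel / known;
        }
        None if known != numel => return None,
        None => {}
    }
    Some(shape)
}

/// Removes every dimension of size 1.
fn squeeze_shape(shape: &[usize]) -> Vec<usize> {
    shape.iter().copied().filter(|&d| d != 1).collect()
}

/// Inserts a size-1 dimension at `dim` (valid for 0 <= dim <= shape.len()).
fn unsqueeze_shape(shape: &[usize], dim: usize) -> Option<Vec<usize>> {
    if dim > shape.len() {
        return None;
    }
    let mut out = shape.to_vec();
    out.insert(dim, 1);
    Some(out)
}
```

None of these touch element data: for a contiguous tensor, reshaping and adding or removing size-1 dimensions only change the shape metadata.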
Automatic Differentiation
```rust
// Enable gradient computation on a leaf tensor
let x = tensor![2.0].requires_grad_(true);

// Forward pass: y = x^2 + x
let y = x.pow(2.0)?.add(&x)?;

// Backward pass
y.backward()?;

// Access the gradient dy/dx stored on x
let grad = x.grad().unwrap();
```
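Conceptually, `backward` performs reverse-mode differentiation. The sketch below is a hypothetical, scalar-only tape (`Tape`, `var`, `powf`, `add`, and `backward` are illustrative names, not ToRSh or scirs2 APIs) showing how gradients flow back through a forward pass built from `pow` and `add` via the chain rule.

```rust
/// A minimal scalar reverse-mode autodiff tape (illustrative only).
/// Each node stores its value plus (parent id, local derivative) pairs.
struct Tape {
    values: Vec<f64>,
    parents: Vec<Vec<(usize, f64)>>,
}

impl Tape {
    fn new() -> Self {
        Tape { values: Vec::new(), parents: Vec::new() }
    }

    /// Creates a leaf variable and returns its node id.
    fn var(&mut self, value: f64) -> usize {
        self.push(value, Vec::new())
    }

    fn push(&mut self, value: f64, parents: Vec<(usize, f64)>) -> usize {
        self.values.push(value);
        self.parents.push(parents);
        self.values.len() - 1
    }

    /// y = x^p, with local derivative dy/dx = p * x^(p - 1).
    fn powf(&mut self, x: usize, p: f64) -> usize {
        let v = self.values[x];
        self.push(v.powf(p), vec![(x, p * v.powf(p - 1.0))])
    }

    /// z = a + b, with dz/da = dz/db = 1.
    fn add(&mut self, a: usize, b: usize) -> usize {
        let v = self.values[a] + self.values[b];
        self.push(v, vec![(a, 1.0), (b, 1.0)])
    }

    /// Returns d(output)/d(node) for every node. Nodes are appended in
    /// creation order, so a reverse sweep visits each node only after all
    /// nodes that depend on it have passed their gradient down.
    fn backward(&self, output: usize) -> Vec<f64> {
        let mut grads = vec![0.0; self.values.len()];
        grads[output] = 1.0;
        for node in (0..=output).rev() {
            for &(parent, local) in &self.parents[node] {
                grads[parent] += grads[node] * local;
            }
        }
        grads
    }
}
```

Building y = x^2 + x at x = 3 with `var`, `powf`, and `add` and then calling `backward` yields a gradient of 2 * 3 + 1 = 7 for x. The two contributions (one through the square, one direct) are summed in `grads`.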
Indexing and Slicing
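```rust
let t = tensor![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];

// Basic indexing
let element = t.get(0)?;
let element_2d = t.get_2d(1, 2)?;

// Slicing with macros (assumes an ndarray-style s! macro)
let slice = t.index(s![0..1, ..])?;

// Boolean masking: keep elements greater than 3.0
let mask = t.gt(&tensor![3.0])?;
let selected = t.masked_select(&mask)?;
```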
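Comparison ops such as `gt` produce a boolean mask, and `masked_select` gathers the elements where the mask is true. The std-only sketch below shows those semantics on a flat buffer, assuming ToRSh matches PyTorch in returning a 1-D result; `gt_mask` and `masked_select` here are hypothetical helpers, not the crate's functions.

```rust
/// Element-wise "greater than" against a scalar, producing a boolean mask.
fn gt_mask(data: &[f32], threshold: f32) -> Vec<bool> {
    data.iter().map(|&v| v > threshold).collect()
}

/// Returns the elements where `mask` is true, in row-major order, as a flat
/// (1-D) result. Panics if the mask and data lengths differ.
fn masked_select(data: &[f32], mask: &[bool]) -> Vec<f32> {
    assert_eq!(data.len(), mask.len(), "mask must match data length");
    data.iter()
        .zip(mask)
        .filter(|&(_, &keep)| keep)
        .map(|(&v, _)| v)
        .collect()
}
```

Because the number of selected elements depends on the data, the result cannot keep the input's shape, which is why a flat output is the natural choice.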
License
Licensed under the Apache License, Version 2.0. See LICENSE for details.