// Tensor operations module: SIMD-accelerated BLAS shims, shape/indexing and
// gradient helpers, random tensor generators, and the public re-export
// surface for the op sub-modules declared below.
use crate::graph::AsGraph;
use crate::ndarray;
use crate::ndarray_ext::{ArrayRng, NdArray};
use crate::tensor::{AsTensor, Tensor};
use crate::Float;
use scirs2_core::ndarray::{Array1, ArrayView1, ArrayView2, ArrayViewMut2};
use scirs2_core::random::{Rng, RngExt};
#[cfg(feature = "simd")]
use scirs2_core::simd::{
    simd_add_f32_adaptive, simd_dot_f32_ultra, simd_fma_f32_ultra, simd_mul_f32_hyperoptimized,
};
use scirs2_core::simd_ops::{PlatformCapabilities, SimdUnifiedOps};
/// Integer type used for BLAS dimensions and leading strides.
pub(crate) type BlasIF = i32;
/// Integer type used by the MKL-style vector-math shims below.
pub(crate) type MklInt = BlasIF;
// Standard CBLAS enum values (CblasRowMajor = 101, CblasNoTrans = 111,
// CblasTrans = 112).
#[allow(dead_code)]
pub(crate) const CBLAS_ROW_MAJOR: i32 = 101;
#[allow(dead_code)]
pub(crate) const CBLAS_NO_TRANS: i32 = 111;
#[allow(dead_code)]
pub(crate) const CBLAS_TRANS: i32 = 112;
/// Single-precision GEMM shim: `C = alpha * A * B + beta * C`, dispatching to
/// a tiled SIMD kernel, a smaller tiled kernel, or a pure-`ndarray` fallback
/// based on problem size and detected CPU features.
///
/// The `layout`/`transa`/`transb` arguments are currently ignored: all paths
/// assume row-major, non-transposed inputs.
///
/// # Safety
///
/// `a`, `b`, and `c` must point to valid buffers of at least `m*k`, `k*n`,
/// and `m*n` elements respectively, with `c` valid for writes.
#[allow(dead_code)]
pub(crate) unsafe fn cblas_sgemm_simd_ultra(
    _layout: i32,
    _transa: i32,
    _transb: i32,
    m: BlasIF,
    n: BlasIF,
    k: BlasIF,
    alpha: f32,
    a: *const f32,
    lda: BlasIF,
    b: *const f32,
    ldb: BlasIF,
    beta: f32,
    c: *mut f32,
    ldc: BlasIF,
) {
let m_usize = m as usize;
let n_usize = n as usize;
let k_usize = k as usize;
    let caps = PlatformCapabilities::detect();
    // Heuristic dispatch: large problems get the 64x64-tiled AVX2 kernel,
    // mid-sized problems the 32x32-tiled SSE kernel, and everything else the
    // pure-ndarray fallback.
    if m_usize >= 256 && n_usize >= 256 && k_usize >= 256 && caps.has_avx2() {
        cblas_sgemm_large_simd_ultra(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
    } else if m_usize >= 64 && n_usize >= 64 && caps.has_sse() {
        cblas_sgemm_medium_simd(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
    } else {
        cblas_sgemm_fallback(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
    }
}
/// Large-matrix GEMM kernel: 64x64 output tiling with a SIMD dot product per
/// output element. Assumes row-major, non-transposed inputs (see the
/// dispatcher above).
unsafe fn cblas_sgemm_large_simd_ultra(
m: BlasIF,
n: BlasIF,
k: BlasIF,
alpha: f32,
a: *const f32,
lda: BlasIF,
b: *const f32,
ldb: BlasIF,
beta: f32,
c: *mut f32,
ldc: BlasIF,
) {
let m_usize = m as usize;
let n_usize = n as usize;
let k_usize = k as usize;
    // 64x64 output tiles keep the A-row/B-column working set cache-resident.
    const TILE_SIZE: usize = 64;
for i_tile in (0..m_usize).step_by(TILE_SIZE) {
for j_tile in (0..n_usize).step_by(TILE_SIZE) {
let i_end = (i_tile + TILE_SIZE).min(m_usize);
let j_end = (j_tile + TILE_SIZE).min(n_usize);
for i in i_tile..i_end {
for j in j_tile..j_end {
let c_offset = i * ldc as usize + j;
let mut sum = 0.0f32;
                    if k_usize >= 16 {
                        // Copy the A row and B column into contiguous buffers so
                        // the SIMD dot kernel runs on unit-stride data (B is
                        // walked column-wise, which is strided in row-major).
                        let mut a_row = Vec::with_capacity(k_usize);
                        let mut b_col = Vec::with_capacity(k_usize);
                        for ki in 0..k_usize {
                            a_row.push(*a.add(i * lda as usize + ki));
                            b_col.push(*b.add(ki * ldb as usize + j));
                        }
let a_array = Array1::from_vec(a_row);
let b_array = Array1::from_vec(b_col);
#[cfg(feature = "simd")]
{
sum = simd_dot_f32_ultra(&a_array.view(), &b_array.view());
}
#[cfg(not(feature = "simd"))]
{
sum = a_array
.iter()
.zip(b_array.iter())
.map(|(&a, &b)| a * b)
.sum();
}
} else {
for ki in 0..k_usize {
sum += *a.add(i * lda as usize + ki) * *b.add(ki * ldb as usize + j);
}
}
let c_ptr = c.add(c_offset);
if beta == 0.0 {
*c_ptr = alpha * sum;
} else {
*c_ptr = alpha * sum + beta * *c_ptr;
}
}
}
}
}
}
/// Mid-sized GEMM kernel: 32x32 output tiling; the per-element products are
/// materialized and reduced with the unified SIMD horizontal sum. Assumes
/// row-major, non-transposed inputs.
unsafe fn cblas_sgemm_medium_simd(
m: BlasIF,
n: BlasIF,
k: BlasIF,
alpha: f32,
a: *const f32,
lda: BlasIF,
b: *const f32,
ldb: BlasIF,
beta: f32,
c: *mut f32,
ldc: BlasIF,
) {
let m_usize = m as usize;
let n_usize = n as usize;
let k_usize = k as usize;
    // Smaller 32x32 tiles for matrices that fit mid-level caches.
    const TILE_SIZE: usize = 32;
for i_tile in (0..m_usize).step_by(TILE_SIZE) {
for j_tile in (0..n_usize).step_by(TILE_SIZE) {
let i_end = (i_tile + TILE_SIZE).min(m_usize);
let j_end = (j_tile + TILE_SIZE).min(n_usize);
for i in i_tile..i_end {
for j in j_tile..j_end {
let c_offset = i * ldc as usize + j;
let mut sum = 0.0f32;
                    if k_usize >= 8 {
                        // Materialize the elementwise products, then reduce with
                        // the unified SIMD horizontal sum.
                        let mut products = Vec::with_capacity(k_usize);
                        for ki in 0..k_usize {
                            products.push(
                                *a.add(i * lda as usize + ki) * *b.add(ki * ldb as usize + j),
                            );
                        }
                        let products_array = Array1::from_vec(products);
                        sum = f32::simd_sum(&products_array.view());
} else {
for ki in 0..k_usize {
sum += *a.add(i * lda as usize + ki) * *b.add(ki * ldb as usize + j);
}
}
let c_ptr = c.add(c_offset);
if beta == 0.0 {
*c_ptr = alpha * sum;
} else {
*c_ptr = alpha * sum + beta * *c_ptr;
}
}
}
}
}
}
/// Pure-`ndarray` GEMM fallback. Note that `lda`/`ldb`/`ldc` are ignored:
/// the slices are reinterpreted as densely packed row-major matrices, so the
/// leading dimensions must equal the matrix widths (`lda == k`, `ldb == n`,
/// `ldc == n`).
unsafe fn cblas_sgemm_fallback(
m: BlasIF,
n: BlasIF,
k: BlasIF,
alpha: f32,
a: *const f32,
lda: BlasIF,
b: *const f32,
ldb: BlasIF,
beta: f32,
c: *mut f32,
ldc: BlasIF,
) {
    let a_slice = std::slice::from_raw_parts(a, (m * k) as usize);
    let b_slice = std::slice::from_raw_parts(b, (k * n) as usize);
    let c_slice = std::slice::from_raw_parts_mut(c, (m * n) as usize);
    let a_mat = ArrayView2::from_shape((m as usize, k as usize), a_slice)
        .expect("A must contain m*k densely packed elements");
    let b_mat = ArrayView2::from_shape((k as usize, n as usize), b_slice)
        .expect("B must contain k*n densely packed elements");
    let mut c_mat = ArrayViewMut2::from_shape((m as usize, n as usize), c_slice)
        .expect("C must contain m*n densely packed elements");
    // Apply the beta scaling first, then accumulate alpha * A * B.
    if beta == 0.0 {
        c_mat.fill(0.0);
    } else if beta != 1.0 {
        c_mat.mapv_inplace(|x| x * beta);
    }
    let result = alpha * a_mat.dot(&b_mat);
    c_mat += &result;
}
/// Public single-precision GEMM entry point; delegates to the SIMD-aware
/// dispatcher above.
///
/// # Safety
///
/// Same buffer requirements as [`cblas_sgemm_simd_ultra`].
#[allow(dead_code)]
pub(crate) unsafe fn cblas_sgemm(
layout: i32,
transa: i32,
transb: i32,
m: BlasIF,
n: BlasIF,
k: BlasIF,
alpha: f32,
a: *const f32,
lda: BlasIF,
b: *const f32,
ldb: BlasIF,
beta: f32,
c: *mut f32,
ldc: BlasIF,
) {
cblas_sgemm_simd_ultra(
layout, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
);
}
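// Minimal sanity check of the GEMM shim, added as an inline test; it
// exercises only the dense, row-major, no-transpose path (the only one the
// dispatcher currently honors).
#[cfg(test)]
mod sgemm_shim_tests {
    use super::*;

    #[test]
    fn sgemm_2x2_no_trans() {
        let a = [1.0f32, 2.0, 3.0, 4.0]; // [[1, 2], [3, 4]]
        let b = [5.0f32, 6.0, 7.0, 8.0]; // [[5, 6], [7, 8]]
        let mut c = [0.0f32; 4];
        unsafe {
            cblas_sgemm(
                CBLAS_ROW_MAJOR,
                CBLAS_NO_TRANS,
                CBLAS_NO_TRANS,
                2,
                2,
                2,
                1.0,
                a.as_ptr(),
                2,
                b.as_ptr(),
                2,
                0.0,
                c.as_mut_ptr(),
                2,
            );
        }
        // A * B = [[19, 22], [43, 50]]
        assert_eq!(c, [19.0, 22.0, 43.0, 50.0]);
    }
}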
/// Double-precision GEMM: `C = alpha * A * B + beta * C` via `ndarray::dot`
/// only (no SIMD path). The layout/transpose arguments are ignored and the
/// leading dimensions are assumed dense, as in [`cblas_sgemm_fallback`].
///
/// # Safety
///
/// `a`, `b`, and `c` must point to valid buffers of at least `m*k`, `k*n`,
/// and `m*n` elements respectively, with `c` valid for writes.
#[allow(dead_code)]
pub(crate) unsafe fn cblas_dgemm(
_layout: i32,
_transa: i32,
_transb: i32,
m: BlasIF,
n: BlasIF,
k: BlasIF,
alpha: f64,
a: *const f64,
lda: BlasIF,
b: *const f64,
ldb: BlasIF,
beta: f64,
c: *mut f64,
ldc: BlasIF,
) {
let a_slice = std::slice::from_raw_parts(a, (m * k) as usize);
let b_slice = std::slice::from_raw_parts(b, (k * n) as usize);
let c_slice = std::slice::from_raw_parts_mut(c, (m * n) as usize);
    let a_mat = scirs2_core::ndarray::ArrayView2::from_shape((m as usize, k as usize), a_slice)
        .expect("A must contain m*k densely packed elements");
    let b_mat = scirs2_core::ndarray::ArrayView2::from_shape((k as usize, n as usize), b_slice)
        .expect("B must contain k*n densely packed elements");
    let mut c_mat =
        scirs2_core::ndarray::ArrayViewMut2::from_shape((m as usize, n as usize), c_slice)
            .expect("C must contain m*n densely packed elements");
if beta == 0.0 {
c_mat.fill(0.0);
} else if beta != 1.0 {
c_mat.mapv_inplace(|x| x * beta);
}
let result = alpha * a_mat.dot(&b_mat);
c_mat += &result;
}
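// ---------------------------------------------------------------------------
// MKL-style vector math shims (vsAdd/vdAdd, vsMul/vdMul, vsExp/vdExp,
// vsLn/vdLn, vsTanh/vdTanh). Each reads `n` elements from the input
// pointer(s) and writes `n` results to `y`; callers must guarantee the
// buffers are valid for `n` elements and that `y` is valid for writes.
// ---------------------------------------------------------------------------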
#[allow(dead_code)]
#[allow(non_snake_case)]
pub(crate) unsafe fn vsAdd(n: MklInt, a: *const f32, b: *const f32, y: *mut f32) {
let a_slice = std::slice::from_raw_parts(a, n as usize);
let b_slice = std::slice::from_raw_parts(b, n as usize);
let y_slice = std::slice::from_raw_parts_mut(y, n as usize);
for i in 0..n as usize {
y_slice[i] = a_slice[i] + b_slice[i];
}
}
#[allow(dead_code)]
#[allow(non_snake_case)]
pub(crate) unsafe fn vdAdd(n: MklInt, a: *const f64, b: *const f64, y: *mut f64) {
let a_slice = std::slice::from_raw_parts(a, n as usize);
let b_slice = std::slice::from_raw_parts(b, n as usize);
let y_slice = std::slice::from_raw_parts_mut(y, n as usize);
for i in 0..n as usize {
y_slice[i] = a_slice[i] + b_slice[i];
}
}
#[allow(dead_code)]
#[allow(non_snake_case)]
pub(crate) unsafe fn vsMul(n: MklInt, a: *const f32, b: *const f32, y: *mut f32) {
let a_slice = std::slice::from_raw_parts(a, n as usize);
let b_slice = std::slice::from_raw_parts(b, n as usize);
let y_slice = std::slice::from_raw_parts_mut(y, n as usize);
for i in 0..n as usize {
y_slice[i] = a_slice[i] * b_slice[i];
}
}
#[allow(dead_code)]
#[allow(non_snake_case)]
pub(crate) unsafe fn vdMul(n: MklInt, a: *const f64, b: *const f64, y: *mut f64) {
let a_slice = std::slice::from_raw_parts(a, n as usize);
let b_slice = std::slice::from_raw_parts(b, n as usize);
let y_slice = std::slice::from_raw_parts_mut(y, n as usize);
for i in 0..n as usize {
y_slice[i] = a_slice[i] * b_slice[i];
}
}
#[allow(dead_code)]
#[allow(non_snake_case)]
pub(crate) unsafe fn vsExp(n: MklInt, a: *const f32, y: *mut f32) {
let a_slice = std::slice::from_raw_parts(a, n as usize);
let y_slice = std::slice::from_raw_parts_mut(y, n as usize);
for i in 0..n as usize {
y_slice[i] = a_slice[i].exp();
}
}
#[allow(dead_code)]
#[allow(non_snake_case)]
pub(crate) unsafe fn vdExp(n: MklInt, a: *const f64, y: *mut f64) {
let a_slice = std::slice::from_raw_parts(a, n as usize);
let y_slice = std::slice::from_raw_parts_mut(y, n as usize);
for i in 0..n as usize {
y_slice[i] = a_slice[i].exp();
}
}
#[allow(dead_code)]
#[allow(non_snake_case)]
pub(crate) unsafe fn vsLn(n: MklInt, a: *const f32, y: *mut f32) {
let a_slice = std::slice::from_raw_parts(a, n as usize);
let y_slice = std::slice::from_raw_parts_mut(y, n as usize);
for i in 0..n as usize {
y_slice[i] = a_slice[i].ln();
}
}
#[allow(dead_code)]
#[allow(non_snake_case)]
pub(crate) unsafe fn vdLn(n: MklInt, a: *const f64, y: *mut f64) {
let a_slice = std::slice::from_raw_parts(a, n as usize);
let y_slice = std::slice::from_raw_parts_mut(y, n as usize);
for i in 0..n as usize {
y_slice[i] = a_slice[i].ln();
}
}
#[allow(dead_code)]
#[allow(non_snake_case)]
pub(crate) unsafe fn vsTanh(n: MklInt, a: *const f32, y: *mut f32) {
let a_slice = std::slice::from_raw_parts(a, n as usize);
let y_slice = std::slice::from_raw_parts_mut(y, n as usize);
for i in 0..n as usize {
y_slice[i] = a_slice[i].tanh();
}
}
#[allow(dead_code)]
#[allow(non_snake_case)]
pub(crate) unsafe fn vdTanh(n: MklInt, a: *const f64, y: *mut f64) {
let a_slice = std::slice::from_raw_parts(a, n as usize);
let y_slice = std::slice::from_raw_parts_mut(y, n as usize);
for i in 0..n as usize {
y_slice[i] = a_slice[i].tanh();
}
}
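// Minimal inline check of the shims' calling convention (mirrors MKL's
// vsAdd); the other shims follow the same pointer-and-length contract.
#[cfg(test)]
mod mkl_shim_tests {
    use super::*;

    #[test]
    fn vs_add_basic() {
        let a = [1.0f32, 2.0, 3.0];
        let b = [4.0f32, 5.0, 6.0];
        let mut y = [0.0f32; 3];
        unsafe { vsAdd(3, a.as_ptr(), b.as_ptr(), y.as_mut_ptr()) };
        assert_eq!(y, [5.0, 7.0, 9.0]);
    }
}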
// Public op namespaces.
pub mod activation;
pub mod arithmetic;
pub mod linear_algebra;
pub mod reduction;
// Core op implementations.
mod activation_ops;
pub(crate) mod array_ops;
pub(crate) mod basic_source_ops;
pub(crate) mod binary_ops;
pub(crate) mod const_gen_ops;
mod conv_ops;
pub(crate) mod dot_ops;
pub(crate) mod gradient_descent_ops;
mod gradient_ops;
mod graph_ops;
pub(crate) mod higher_order_ops;
pub(crate) mod hook_ops;
mod math_ops;
mod random_ops;
pub(crate) mod reduction_ops;
mod xent_ops;
// Linear algebra: decompositions, solvers, norms, and matrix functions.
mod decomposition_ops;
mod eigen_ops;
mod linalg_ops;
mod matrix_ops;
mod norm_ops;
mod scalar_ops;
pub(crate) mod solver_ops;
mod special_matrices;
mod advanced_tensor_ops;
mod matrix_norms;
mod matrix_solvers;
mod special_decompositions;
mod symmetric_ops;
mod advanced_decompositions;
mod iterative_solvers;
mod matrix_functions;
mod matrix_trig_functions;
// Checkpointing, indexing, memory/performance, and misc. numeric helpers.
mod checkpoint_ops;
mod debug_ops;
mod advanced_indexing;
mod broadcast_ops;
mod memory_optimization;
mod efficient_ops;
mod custom_activations;
mod performance_ops;
mod graph_enhancements;
mod numerical_props;
mod kronecker_ops;
#[cfg(feature = "simd")]
pub mod simd_ops;
impl<'graph, F: Float> Tensor<'graph, F> {
    /// Returns a tensor selecting the `i`-th element of this tensor
    /// (as indexed by `array_ops::IndexOp`).
    pub fn access_elem(self, i: isize) -> Tensor<'graph, F> {
        let op = array_ops::IndexOp { index: i };
        Tensor::builder(self.graph)
            .append_input(self, false)
            .build(op)
    }
}
/// Symbolic gradients of `ys` with respect to `xs` (reverse-mode autodiff).
///
/// Each objective is first reduced with `sum_all`, so the result is
/// d(sum(y))/dx for every `x`. An `x` that does not influence any objective
/// yields a zero tensor of matching shape, keeping the output aligned with
/// `xs`.
#[allow(dead_code)]
pub fn grad<'graph, F: Float, A, B>(ys: &[A], xs: &[B]) -> Vec<Tensor<'graph, F>>
where
    A: AsRef<Tensor<'graph, F>>,
    B: AsRef<Tensor<'graph, F>>,
{
    use crate::gradient::compute_gradients;
    let g = ys[0].as_ref().graph();
    // Reduce each objective to a scalar before differentiating.
    let ys_summed: Vec<_> = ys.iter().map(|y| sum_all(y)).collect();
    let mut grads = compute_gradients(ys_summed.as_slice(), xs, None, g);
    let mut ret = Vec::with_capacity(xs.len());
    for x in xs {
        if let Some(gx) = grads.extract_grad(x) {
            ret.push(gx);
        } else {
            // `x` is unreachable from the objectives: emit `x * 0` as a
            // shape-matched zero gradient.
            let zero_tensor = crate::tensor_ops::arithmetic::mul(x.as_ref(), scalar(F::zero(), g));
            ret.push(zero_tensor);
        }
    }
    ret
}
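// A minimal sketch of `grad` (hedged: graph construction and evaluation
// entry points live elsewhere in the crate and may differ):
//
//     // y = sum(x^2)  =>  dy/dx = 2x
//     let y = sum_all(square(x));
//     let gx = grad(&[y], &[x]);
//     // gx[0] evaluates to 2 * x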
/// Like [`grad`], but uses caller-supplied output gradients `ys_grads`
/// (one per objective) instead of implicitly summing each `y`; i.e. it
/// computes vector-Jacobian products.
#[allow(dead_code)]
pub fn grad_with_default<'graph, F: Float, A, B>(
ys: &[A],
xs: &[B],
ys_grads: &[Tensor<'graph, F>],
) -> Vec<Tensor<'graph, F>>
where
A: AsRef<Tensor<'graph, F>>,
B: AsRef<Tensor<'graph, F>>,
{
use crate::gradient::compute_gradients;
let g = ys[0].as_ref().graph();
let mut grads = compute_gradients(ys, xs, Some(ys_grads), g);
let mut ret = Vec::with_capacity(xs.len());
for x in xs {
if let Some(gx) = grads.extract_grad(x) {
ret.push(gx);
} else {
let zero_tensor = crate::tensor_ops::arithmetic::mul(x.as_ref(), scalar(F::zero(), g));
ret.push(zero_tensor);
}
}
ret
}
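// Sketch: a vector-Jacobian product with an explicit seed gradient (hedged:
// whether `&Tensor` satisfies the `shape`/`ones` bounds here depends on the
// crate's `AsTensor` impls):
//
//     let gy = ones(&shape(y), g); // seed dL/dy = 1
//     let gx = grad_with_default(&[y], &[x], &[gy]);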
/// Jacobians of `y_` with respect to each tensor in `xs_`.
///
/// `objective_len` is the number of elements in `y_`; the `j`-th result has
/// shape `(objective_len, xs_[j].size())`, with row `i` holding the flattened
/// gradient of `y[i]`.
#[allow(dead_code)]
pub fn jacobians<'graph, A, B, F: Float>(
    y_: A,
    xs_: &[B],
    objective_len: usize,
) -> Vec<Tensor<'graph, F>>
where
    A: AsRef<Tensor<'graph, F>>,
    B: AsRef<Tensor<'graph, F>>,
{
    let y = y_.as_ref();
    // One reverse pass per output element of `y`.
    let mut per_element_grads = Vec::with_capacity(objective_len);
    for i in 0..objective_len as isize {
        per_element_grads.push(grad(&[y.access_elem(i)], xs_));
    }
    let xs_len = xs_.len();
    let mut ret = Vec::with_capacity(xs_len);
    for i in 0..xs_len {
        // Stack the per-element gradients as rows of the Jacobian.
        let mut jac = Vec::with_capacity(objective_len);
        for grads in &per_element_grads {
            jac.push(expand_dims(flatten(grads[i]), &[0]));
        }
        ret.push(concat(&jac, 0));
    }
    ret
}
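// Sketch: for a length-m objective `y` and parameter `x` (hedged, shapes
// illustrative):
//
//     let jacs = jacobians(y, &[x], m);
//     // jacs[0] has shape (m, x.size()); row i is d y[i] / d x, flattened.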
/// Hessian-vector products via double backpropagation: for each `x`, this
/// computes grad_x of sum(grad_x(y) * v), which equals `H v` when `v` is
/// treated as a constant.
#[allow(dead_code)]
pub fn _hessian_vector_product<'graph, A, B, C, F: Float>(
    ys: &[A],
    xs: &[B],
    vectors: &[C],
) -> Vec<Tensor<'graph, F>>
where
    A: AsRef<Tensor<'graph, F>>,
    B: AsRef<Tensor<'graph, F>>,
    C: AsRef<Tensor<'graph, F>>,
{
    let grads = grad(ys, xs);
    // Inner product <grad, v>, taken elementwise and then re-differentiated.
    let products = grads
        .into_iter()
        .zip(vectors)
        .map(|(g, v)| *g.as_ref() * *v.as_ref())
        .collect::<Vec<_>>();
    grad(products.as_slice(), xs)
}
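// Sketch: for y = 0.5 * x^T A x with symmetric A, the Hessian is A, so the
// call below evaluates to A . v (hedged; `v` must match `x`'s shape):
//
//     let hv = _hessian_vector_product(&[y], &[x], &[v]);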
/// Returns a tensor that evaluates to `x` but blocks gradient flow through it.
#[allow(dead_code)]
pub fn stop_gradient<'graph, A, F: Float>(x: A) -> Tensor<'graph, F>
where
A: AsRef<Tensor<'graph, F>> + Copy,
{
let x = x.as_ref();
let g = x.graph();
Tensor::builder(g)
.append_input(x, false)
.set_differentiable(false)
.build(gradient_ops::StopGradient)
}
/// Returns the shape of `x` as a 1-D tensor (not differentiable).
#[allow(dead_code)]
pub fn shape<'graph, A, F: Float>(x: A) -> Tensor<'graph, F>
where
    A: AsRef<Tensor<'graph, F>> + Copy,
{
    let x = x.as_ref();
    let g = x.graph();
    // Reuse the graph's cached shape tensor when one already exists for `x`.
    if let Some(id) = x.inner().shape {
        return g.tensor(id);
    }
Tensor::builder(g)
.append_input(x.as_ref(), false)
.set_differentiable(false)
.build(array_ops::Shape)
}
/// Returns the number of elements in `x` as a scalar tensor (not differentiable).
#[allow(dead_code)]
pub fn size<'graph, A, F: Float>(x: A) -> Tensor<'graph, F>
where
A: AsRef<Tensor<'graph, F>> + Copy,
{
let x = x.as_ref();
let g = x.graph();
Tensor::builder(g)
.append_input(x.as_ref(), false)
.set_differentiable(false)
.build(array_ops::Size)
}
/// Returns the rank (number of axes) of `x` as a scalar tensor (not differentiable).
#[allow(dead_code)]
pub fn rank<'graph, A, F: Float>(x: A) -> Tensor<'graph, F>
where
A: AsRef<Tensor<'graph, F>> + Copy,
{
let x = x.as_ref();
let g = x.graph();
Tensor::builder(g)
.append_input(x.as_ref(), false)
.set_differentiable(false)
.build(array_ops::Rank)
}
/// Selects the `n`-th output of a multi-output node as a standalone tensor.
#[doc(hidden)]
#[allow(dead_code)]
pub fn nth_tensor<'graph, A, F: Float>(x: A, n: usize) -> Tensor<'graph, F>
where
A: AsRef<Tensor<'graph, F>> + Copy,
{
let x = x.as_ref();
let g = x.graph();
Tensor::builder(g)
.append_input_with_selector(x, false, n)
.build(activation_ops::Identity)
}
/// Identity op: returns a tensor that evaluates to `x` (gradients pass through).
#[allow(dead_code)]
pub fn identity<'graph, A, F: Float>(x: A) -> Tensor<'graph, F>
where
A: AsRef<Tensor<'graph, F>> + Copy,
{
let x = x.as_ref();
let g = x.graph();
Tensor::builder(g)
.append_input(x.as_ref(), false)
.setshape(&shape(x))
.build(activation_ops::Identity)
}
/// Returns the elements of the 1-D tensor `a` that are not present in `b`.
#[allow(dead_code)]
pub fn setdiff1d<'graph, A, B, F: Float>(a: A, b: B) -> Tensor<'graph, F>
where
A: AsRef<Tensor<'graph, F>> + Copy,
B: AsRef<Tensor<'graph, F>> + Copy,
{
let a = a.as_ref();
let g = a.graph();
let op = array_ops::SetDiff1D;
Tensor::builder(g)
.append_input(a.as_ref(), false)
.append_input(b.as_ref(), false)
.build(op)
}
/// Slices `x` along each axis from `starts[i]` (inclusive) to `ends[i]`.
/// Nonnegative ends are exclusive; an end of `-1` slices to the end of the
/// axis, and other negative ends address elements counted from the back.
#[allow(dead_code)]
pub fn slice<'graph, A, S, E, F: Float>(x: A, starts: S, ends: E) -> Tensor<'graph, F>
where
A: AsRef<Tensor<'graph, F>> + Copy,
S: AsRef<[isize]>,
E: AsRef<[isize]>,
{
let x = x.as_ref();
let g = x.graph();
let starts = starts.as_ref();
let ends = ends.as_ref();
    assert_eq!(
        starts.len(),
        ends.len(),
        "starts and ends must have the same length"
    );
    let starts_ends = starts.iter().zip(ends.iter());
    let indices = starts_ends
        .map(|(s, &e)| {
            // `-1` means "slice to the end of the axis"; other negative ends
            // are shifted by one so they address the intended element when
            // counted from the back.
            let e = if e == -1 {
                None
            } else {
                Some(if e < -1 { e + 1 } else { e })
            };
            let slice = scirs2_core::ndarray::Slice::new(*s, e, 1);
            scirs2_core::ndarray::SliceInfoElem::from(slice)
        })
        .collect::<Vec<scirs2_core::ndarray::SliceInfoElem>>();
Tensor::builder(g)
.append_input(x.as_ref(), false)
.build(array_ops::Slice { indices })
}
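// Sketch: rows 1..3 and all columns of a 2-D tensor:
//
//     let s = slice(x, &[1, 0], &[3, -1]); // end 3 is exclusive; -1 = "to the end"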
/// Gathers slices of `param` along `axis` at `indices`, normalizing negative
/// indices. See [`gather`] for the non-normalizing variant.
#[allow(dead_code)]
pub fn gather_common<'graph, A, B, F: Float>(param: A, indices: B, axis: isize) -> Tensor<'graph, F>
where
A: AsRef<Tensor<'graph, F>> + Copy,
B: AsRef<Tensor<'graph, F>> + Copy,
{
let _param = param.as_ref();
let g = _param.graph();
let op = array_ops::Gather {
axis,
should_normalize_negative_indices: true,
};
Tensor::builder(g)
.append_input(indices.as_ref(), false)
.append_input(_param.as_ref(), false)
.build(op)
}
/// Gathers slices of `param` along `axis` at `indices` (negative indices are
/// not normalized).
#[allow(dead_code)]
pub fn gather<'graph, A, B, F: Float>(param: A, indices: B, axis: isize) -> Tensor<'graph, F>
where
A: AsRef<Tensor<'graph, F>> + Copy,
B: AsRef<Tensor<'graph, F>> + Copy,
{
let _param = param.as_ref();
let g = _param.graph();
let op = array_ops::Gather {
axis,
should_normalize_negative_indices: false,
};
Tensor::builder(g)
.append_input(indices.as_ref(), false)
.append_input(_param, false)
.build(op)
}
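// Sketch: embedding-style row lookup (hedged, shapes illustrative):
//
//     // params: (vocab, dim), ids: 1-D tensor of row indices
//     let rows = gather(params, ids, 0); // shape: (ids.len(), dim)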
/// Reshapes `x` to `shape`; one dimension may be `-1`, in which case it is
/// inferred from the remaining dimensions.
#[allow(dead_code)]
pub fn reshape<'graph, A, AT, F: Float>(x: A, shape: &AT) -> Tensor<'graph, F>
where
A: AsRef<Tensor<'graph, F>> + Copy,
AT: AsTensor<'graph, F>,
{
let x = x.as_ref();
let g = x.graph();
let t = shape.as_tensor(g);
Tensor::builder(g)
.append_input(x.as_ref(), false)
.append_input(t, false)
.setshape(&t)
.build(array_ops::Reshape)
}
/// Flattens `x` into a 1-D tensor (a reshape to `[-1]`).
#[allow(dead_code)]
pub fn flatten<'graph, A, F: Float>(x: A) -> Tensor<'graph, F>
where
    A: AsRef<Tensor<'graph, F>> + Copy,
{
    // `-1` lets `reshape` infer the single output dimension.
    reshape(x.as_ref(), &[-1i32])
}
/// Inserts new axes of size 1 at the given `axes`.
#[allow(dead_code)]
pub fn expand_dims<'graph, A, AT, F: Float>(x: A, axes: &AT) -> Tensor<'graph, F>
where
A: AsRef<Tensor<'graph, F>> + Copy,
AT: AsTensor<'graph, F>,
{
let x = x.as_ref();
let g = x.graph();
Tensor::builder(g)
.append_input(x.as_ref(), false)
.append_input(axes.as_tensor(g), false)
.build(array_ops::ExpandDims)
}
/// Removes the size-1 axes listed in `axes`.
#[allow(dead_code)]
pub fn squeeze<'graph, A, AT, F: Float>(x: A, axes: &AT) -> Tensor<'graph, F>
where
A: AsRef<Tensor<'graph, F>> + Copy,
AT: AsTensor<'graph, F>,
{
let x = x.as_ref();
let g = x.graph();
Tensor::builder(g)
.append_input(x.as_ref(), false)
.append_input(axes.as_tensor(g), false)
.build(array_ops::Squeeze)
}
/// Dropout with the crate's default RNG; elements are dropped with
/// probability `dropout_ratio` when `train` is true.
#[allow(dead_code)]
pub fn dropout<'graph, A, F: Float>(x: A, dropout_ratio: F, train: bool) -> Tensor<'graph, F>
where
    A: AsRef<Tensor<'graph, F>> + Copy,
{
    dropout_rng(
        x,
        dropout_ratio,
        train,
        crate::ndarray_ext::get_default_rng::<F>(),
    )
}
/// Dropout driven by a caller-supplied RNG (useful for reproducibility).
#[allow(dead_code)]
pub fn dropout_rng<'graph, A, F: Float, R: Rng + 'static>(
    x: A,
    dropout_ratio: F,
    train: bool,
    mut rng: R,
) -> Tensor<'graph, F>
where
    A: AsRef<Tensor<'graph, F>> + Copy,
{
    let x = x.as_ref();
    let g = x.graph();
    // Draw a seed from the caller's RNG to build the op's internal ArrayRng.
    let seed = rng.random::<u64>();
Tensor::builder(g)
.append_input(x.as_ref(), false)
.build(random_ops::Dropout {
train,
arr_rng: ArrayRng::from_seed(seed),
dropout_ratio,
})
}
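// Sketch: dropout in train vs. eval mode (hedged: any rescaling of the kept
// activations is defined by `random_ops::Dropout`):
//
//     let h_train = dropout(h, ratio, true);  // random masking
//     let h_eval = dropout(h, ratio, false);  // pass-through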
/// Applies an arbitrary array-level function `f` to the evaluated value of `x`.
#[allow(dead_code)]
pub fn map<'graph, A, F: Float>(
    x: A,
    f: fn(crate::ndarray_ext::NdArrayView<F>) -> NdArray<F>,
) -> Tensor<'graph, F>
where
A: AsRef<Tensor<'graph, F>> + Copy,
{
use std::marker::PhantomData;
let x = x.as_ref();
let g = x.graph();
Tensor::builder(g)
.append_input(x.as_ref(), false)
.build(higher_order_ops::MapOp {
phantom: PhantomData,
f,
})
}
/// Ensures every tensor in `deps` is evaluated before `x`'s input.
///
/// Implemented by rewiring `x`'s first incoming edge through a
/// `ControlDependency` node that also takes `deps` as inputs. Panics if `x`
/// is a source tensor, since it then has no incoming edge to rewire.
#[allow(dead_code)]
pub fn control_dependencies<'graph, A, F: Float>(
    x: Tensor<'graph, F>,
    deps: &[A],
) -> Tensor<'graph, F>
where
    A: AsRef<Tensor<'graph, F>> + Copy,
{
    let g = x.graph();
    if let Some(x_input) = x.get_incoming_tensor(0, g) {
        let mut ctrl_deps = Tensor::builder(g).append_input(x_input, false);
        for dep in deps {
            ctrl_deps = ctrl_deps.append_input(dep.as_ref(), false);
        }
        let new_x_input = ctrl_deps.build(graph_ops::ControlDependency);
        // Splice the dependency node in front of `x`.
        g.access_inner_mut(x.id).incoming_nodes[0].id = new_x_input.id;
        x
    } else {
        panic!("Source tensor cannot depend on any other tensors.");
    }
}
/// Assigns the value of `y` into `x` in place; `x` is registered as a
/// mutable input.
#[allow(dead_code)]
pub fn assign<'graph, A, B, F: Float>(x: A, y: B) -> Tensor<'graph, F>
where
A: AsRef<Tensor<'graph, F>> + Copy,
B: AsRef<Tensor<'graph, F>> + Copy,
{
let x = x.as_ref();
let y = y.as_ref();
let g = x.graph();
Tensor::builder(g)
.append_input(x, true)
.append_input(y, false)
.build(array_ops::Assign)
}
/// Wraps an owned `ndarray` array as a constant (non-differentiable) tensor
/// with a known shape.
#[allow(dead_code)]
pub fn convert_to_tensor<F: Float, D>(
    arr: scirs2_core::ndarray::Array<F, D>,
    graph: &impl AsGraph<F>,
) -> Tensor<F>
where
    D: scirs2_core::ndarray::Dimension,
{
    let original_shape = arr.shape().to_vec();
    let arr = arr.into_dyn();
    let shape_isize: Vec<isize> = original_shape.iter().map(|&s| s as isize).collect();
    let tensor = Tensor::builder(graph)
        .set_knownshape(&shape_isize)
        .set_differentiable(false)
        .build(const_gen_ops::ConvertToTensor { arr });
    // Debug sanity check: eagerly evaluate once and compare shapes.
    if let Some(ctx) = crate::graph::AsGraph::context_ref(graph) {
        if let Ok(eval_result) = tensor.eval(ctx) {
            if eval_result.shape() != original_shape.as_slice() {
                eprintln!(
                    "DEBUG: convert_to_tensor shape mismatch: expected {:?}, got {:?}",
                    original_shape,
                    eval_result.shape()
                );
            }
        }
    }
    tensor
}
/// Creates a 0-dimensional (scalar) constant tensor holding `val`.
#[allow(dead_code)]
pub fn scalar<F: Float>(val: F, graph: &impl AsGraph<F>) -> Tensor<F> {
let op = const_gen_ops::Scalar { val };
Tensor::builder(graph).set_knownshape(&[]).build(op)
}
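// ---------------------------------------------------------------------------
// Random tensor generators. Each `foo(shape, ..., graph)` delegates to a
// `foo_rng(arr_rng, shape, ..., graph)` variant that accepts a caller-supplied
// `ArrayRng` for reproducible sampling; `shape` is anything implementing
// `AsTensor` (e.g. another tensor or a slice of dimensions).
// ---------------------------------------------------------------------------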
/// Samples a tensor of the given `shape` from N(`mean`, `stddev`^2) using the
/// default RNG.
#[allow(dead_code)]
pub fn random_normal<'graph, A, F: Float>(
shape: &A,
mean: f64,
stddev: f64,
graph: &'graph impl AsGraph<F>,
) -> Tensor<'graph, F>
where
A: AsTensor<'graph, F>,
{
random_normal_rng(Default::default(), shape, mean, stddev, graph)
}
/// Samples from N(`mean`, `stddev`^2) using a caller-supplied `ArrayRng`.
#[allow(dead_code)]
pub fn random_normal_rng<'graph, A, F: Float>(
arr_rng: ArrayRng<F>,
shape: &A,
mean: f64,
stddev: f64,
graph: &'graph impl AsGraph<F>,
) -> Tensor<'graph, F>
where
A: AsTensor<'graph, F>,
{
let t = shape.as_tensor(graph);
Tensor::builder(graph)
.append_input(t, false)
.setshape(&t)
.build(random_ops::RandomNormal::new(arr_rng, mean, stddev))
}
/// Samples uniformly between `min` and `max` using the default RNG.
#[allow(dead_code)]
pub fn random_uniform<'graph, A, F: Float>(
shape: &A,
min: f64,
max: f64,
graph: &'graph impl AsGraph<F>,
) -> Tensor<'graph, F>
where
A: AsTensor<'graph, F>,
{
random_uniform_rng(Default::default(), shape, min, max, graph)
}
#[allow(dead_code)]
pub fn random_uniform_rng<'graph, A, F: Float>(
arr_rng: ArrayRng<F>,
shape: &A,
min: f64,
max: f64,
graph: &'graph impl AsGraph<F>,
) -> Tensor<'graph, F>
where
A: AsTensor<'graph, F>,
{
let t = shape.as_tensor(graph);
Tensor::builder(graph)
.append_input(t, false)
.setshape(&t)
.build(random_ops::RandomUniform::new(arr_rng, min, max))
}
/// Samples from the standard normal distribution N(0, 1).
#[allow(dead_code)]
pub fn standard_normal<'graph, A, F: Float>(
shape: &A,
graph: &'graph impl AsGraph<F>,
) -> Tensor<'graph, F>
where
A: AsTensor<'graph, F>,
{
standard_normal_rng(Default::default(), shape, graph)
}
#[allow(dead_code)]
pub fn standard_normal_rng<'graph, A, F: Float>(
arr_rng: ArrayRng<F>,
shape: &A,
graph: &'graph impl AsGraph<F>,
) -> Tensor<'graph, F>
where
A: AsTensor<'graph, F>,
{
let t = shape.as_tensor(graph);
Tensor::builder(graph)
.append_input(t, false)
.setshape(&t)
.build(random_ops::StandardNormal::new(arr_rng))
}
/// Samples from the standard uniform distribution on the unit interval.
#[allow(dead_code)]
pub fn standard_uniform<'graph, A, F: Float>(
shape: &A,
graph: &'graph impl AsGraph<F>,
) -> Tensor<'graph, F>
where
A: AsTensor<'graph, F>,
{
standard_uniform_rng(Default::default(), shape, graph)
}
#[allow(dead_code)]
pub fn standard_uniform_rng<'graph, F: Float, A>(
arr_rng: ArrayRng<F>,
shape: &A,
graph: &'graph impl AsGraph<F>,
) -> Tensor<'graph, F>
where
A: AsTensor<'graph, F>,
{
let t = shape.as_tensor(graph);
Tensor::builder(graph)
.append_input(t, false)
.setshape(&t)
.build(random_ops::StandardUniform::new(arr_rng))
}
/// Samples 0/1 values, each equal to 1 with probability `p`.
#[allow(dead_code)]
pub fn bernoulli<'graph, A, F: Float>(
shape: &A,
p: f64,
graph: &'graph impl AsGraph<F>,
) -> Tensor<'graph, F>
where
A: AsTensor<'graph, F>,
{
bernoulli_rng(Default::default(), shape, p, graph)
}
#[allow(dead_code)]
pub fn bernoulli_rng<'graph, A, F: Float>(
arr_rng: ArrayRng<F>,
shape: &A,
p: f64,
graph: &'graph impl AsGraph<F>,
) -> Tensor<'graph, F>
where
A: AsTensor<'graph, F>,
{
let t = shape.as_tensor(graph);
Tensor::builder(graph)
.append_input(t, false)
.setshape(&t)
.build(random_ops::Bernoulli::new(arr_rng, p))
}
/// Samples from the exponential distribution with rate `lambda`.
#[allow(dead_code)]
pub fn random_exp<'graph, A, F: Float>(
shape: &A,
lambda: f64,
graph: &'graph impl AsGraph<F>,
) -> Tensor<'graph, F>
where
A: AsTensor<'graph, F>,
{
random_exp_rng(Default::default(), shape, lambda, graph)
}
#[allow(dead_code)]
pub fn random_exp_rng<'graph, A, F: Float>(
arr_rng: ArrayRng<F>,
shape: &A,
lambda: f64,
graph: &'graph impl AsGraph<F>,
) -> Tensor<'graph, F>
where
A: AsTensor<'graph, F>,
{
let t = shape.as_tensor(graph);
Tensor::builder(graph)
.append_input(t, false)
.setshape(&t)
.build(random_ops::Exponential::new(arr_rng, lambda))
}
/// Samples from the gamma distribution with shape `shape_param` and scale `scale`.
#[allow(dead_code)]
pub fn random_gamma<'graph, A, F: Float>(
shape: &A,
shape_param: f64,
scale: f64,
graph: &'graph impl AsGraph<F>,
) -> Tensor<'graph, F>
where
A: AsTensor<'graph, F>,
{
random_gamma_rng(Default::default(), shape, shape_param, scale, graph)
}
#[allow(dead_code)]
pub fn random_gamma_rng<'graph, A, F: Float>(
arr_rng: ArrayRng<F>,
shape: &A,
shape_param: f64,
scale: f64,
graph: &'graph impl AsGraph<F>,
) -> Tensor<'graph, F>
where
A: AsTensor<'graph, F>,
{
let t = shape.as_tensor(graph);
Tensor::builder(graph)
.append_input(t, false)
.setshape(&t)
.build(random_ops::Gamma::new(arr_rng, shape_param, scale))
}
/// Samples from a log-normal distribution; `mean` and `stddev` parameterize
/// the underlying normal.
#[allow(dead_code)]
pub fn log_normal<'graph, A, F: Float>(
shape: &A,
mean: f64,
stddev: f64,
graph: &'graph impl AsGraph<F>,
) -> Tensor<'graph, F>
where
A: AsTensor<'graph, F>,
{
log_normal_rng(Default::default(), shape, mean, stddev, graph)
}
#[allow(dead_code)]
pub fn log_normal_rng<'graph, A, F: Float>(
arr_rng: ArrayRng<F>,
shape: &A,
mean: f64,
stddev: f64,
graph: &'graph impl AsGraph<F>,
) -> Tensor<'graph, F>
where
A: AsTensor<'graph, F>,
{
let t = shape.as_tensor(graph);
Tensor::builder(graph)
.append_input(t, false)
.setshape(&t)
.build(random_ops::LogNormal::new(arr_rng, mean, stddev))
}
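// Sketch: reproducible sampling with an explicit RNG (hedged: seeding
// follows the `ArrayRng::from_seed` constructor used by `dropout_rng` above;
// the turbofish type annotation is illustrative):
//
//     let rng = ArrayRng::<f32>::from_seed(42);
//     let w = random_normal_rng(rng, &[64, 64], 0.0, 0.02, g);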
/// Creates a tensor of zeros with the given (symbolic) `shape`.
#[allow(dead_code)]
pub fn zeros<'graph, A, F: Float>(shape: &A, graph: &'graph impl AsGraph<F>) -> Tensor<'graph, F>
where
A: AsTensor<'graph, F>,
{
let shape_tensor = shape.as_tensor(graph);
Tensor::builder(graph)
.append_input(shape_tensor, false)
.build(const_gen_ops::Zeros)
}
/// Creates a tensor of ones with the given (symbolic) `shape`.
#[allow(dead_code)]
pub fn ones<'graph, A, F: Float>(shape: &A, graph: &'graph impl AsGraph<F>) -> Tensor<'graph, F>
where
A: AsTensor<'graph, F>,
{
Tensor::builder(graph)
.append_input(shape.as_tensor(graph), false)
.build(const_gen_ops::Ones)
}
/// Wraps an owned array as a variable tensor.
///
/// Unlike [`convert_to_tensor`], the result is built without a declared known
/// shape and remains differentiable. When a context is available, the tensor
/// is eagerly evaluated once as a shape sanity check.
#[allow(dead_code)]
pub fn variable<F: Float, D>(
    arr: scirs2_core::ndarray::Array<F, D>,
    graph: &impl AsGraph<F>,
) -> Tensor<F>
where
    D: scirs2_core::ndarray::Dimension,
{
    let original_shape = arr.shape().to_vec();
    let arr_dyn = arr.into_dyn();
    let tensor = Tensor::builder(graph).build(const_gen_ops::ConvertToTensor { arr: arr_dyn });
    if let Some(ctx) = crate::graph::AsGraph::context_ref(graph) {
        if let Ok(eval_result) = tensor.eval(ctx) {
            if eval_result.shape() != original_shape.as_slice() {
                eprintln!(
                    "WARNING: variable shape mismatch: expected {:?}, got {:?}",
                    original_shape,
                    eval_result.shape()
                );
            }
        }
    }
    tensor
}
// Method-style conveniences mirroring the free functions above.
impl<'g, F: Float> Tensor<'g, F> {
#[inline]
pub fn reshape<AT: AsTensor<'g, F>>(&self, shape: &AT) -> Tensor<'g, F> {
reshape(self, shape)
}
#[inline]
pub fn flatten(&self) -> Tensor<'g, F> {
flatten(self)
}
#[inline]
pub fn squeeze<AT: AsTensor<'g, F>>(&self, axes: &AT) -> Tensor<'g, F> {
squeeze(self, axes)
}
#[inline]
pub fn expand_dims<AT: AsTensor<'g, F>>(&self, axes: &AT) -> Tensor<'g, F> {
expand_dims(self, axes)
}
#[inline]
pub fn transpose<AT: AsTensor<'g, F>>(&self, axes: &AT) -> Tensor<'g, F> {
transpose(self, axes)
}
#[inline]
pub fn size(&self) -> Tensor<'g, F> {
size(self)
}
#[inline]
pub fn rank(&self) -> Tensor<'g, F> {
rank(self)
}
#[inline]
pub fn shape_tensor(&self) -> Tensor<'g, F> {
shape(self)
}
#[inline]
pub fn reduce_sum<AT: AsTensor<'g, F>>(&self, axes: &AT, keepdims: bool) -> Tensor<'g, F> {
reduce_sum(self, axes, keepdims)
}
#[inline]
pub fn reduce_mean<AT: AsTensor<'g, F>>(&self, axes: &AT, keepdims: bool) -> Tensor<'g, F> {
reduce_mean(self, axes, keepdims)
}
#[inline]
pub fn reduce_prod<AT: AsTensor<'g, F>>(&self, axes: &AT, keepdims: bool) -> Tensor<'g, F> {
reduce_prod(self, axes, keepdims)
}
#[inline]
pub fn reduce_min<AT: AsTensor<'g, F>>(&self, axes: &AT, keepdims: bool) -> Tensor<'g, F> {
reduce_min(self, axes, keepdims)
}
#[inline]
pub fn reduce_max<AT: AsTensor<'g, F>>(&self, axes: &AT, keepdims: bool) -> Tensor<'g, F> {
reduce_max(self, axes, keepdims)
}
#[inline]
pub fn reduce_variance<AT: AsTensor<'g, F>>(
&self,
axes: &AT,
keep_dims: bool,
) -> Tensor<'g, F> {
reduce_variance(self, axes, keep_dims)
}
    #[inline]
    pub fn sum_all(&self) -> Tensor<'g, F> {
        sum_all(self)
    }
    #[inline]
    pub fn mean_all(&self) -> Tensor<'g, F> {
        mean_all(self)
    }
pub fn trace(&self) -> Tensor<'g, F> {
trace(self)
}
pub fn diag(&self) -> Tensor<'g, F> {
extract_diag(self)
}
pub fn frobenius_norm(&self) -> Tensor<'g, F> {
frobenius_norm(self)
}
pub fn scalar_mul(&self, scalar: F) -> Tensor<'g, F> {
scalar_mul(self, scalar)
}
}
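// ---------------------------------------------------------------------------
// Public re-export surface. Several items are deliberately re-exported under
// more than one name (e.g. `eigen`/`eig`, `determinant`/`det`,
// `pseudo_inverse`/`pinv`) for convenience.
// ---------------------------------------------------------------------------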
pub use arithmetic::{
abs as neg_abs, acos, acosh, add, asin, asinh, atan, atanh, ceil, clip, cos, cosh, digamma_f32,
digamma_f64, div, equal, exp, exp10, exp2, floor, greater, greater_equal, inv, inv_sqrt,
lesser, lesser_equal, lgamma_f32, lgamma_f64, ln, log10, log2, maximum, minimum, mul, neg,
not_equal, pow, sign, sin, sinh, sqrt, square, sub, tan, tanh,
};
pub use reduction::{
add_n, argmax, argmin, frobenius_norm, l1_norm, l2_norm, lp_norm, mean_all, reduce_all,
reduce_any, reduce_logsumexp, reduce_max, reduce_mean, reduce_min, reduce_prod, reduce_std,
reduce_sum, reduce_variance, sum_all,
};
pub use linear_algebra::{
batch_matmul, batch_matmul_t, concat, conv2d, conv2d_transpose, determinant, diag,
dilated_conv2d, eigen, eigenvalues, extract_diag, eye, lstsq, matmul, matrix_inverse,
max_pool2d, qr, scalar_mul, solve, split, svd, tensordot, trace, transpose,
};
pub use activation::{
batch_norm, elu, gelu, hard_sigmoid, hard_tanh, leaky_relu, log_softmax, mean_squared_error,
mish, normalize, relu, relu6, sigmoid, sigmoid_cross_entropy, softmax, softmax_cross_entropy,
softplus, sparse_softmax_cross_entropy, swish,
};
pub use debug_ops::{debug_identity_with_gradient, debug_scalar_one};
pub use decomposition_ops::matrix_exp;
pub use decomposition_ops::{lu, qr as decomp_qr, svd as decomp_svd};
pub use eigen_ops::{eigen as eigen_decomp, eigenvalues as eigen_vals};
pub use linalg_ops::{
diag as linalg_diag, extract_diag as linalg_extract_diag, eye as linalg_eye,
trace as linalg_trace,
};
pub use matrix_ops::{
determinant as matrix_det, matrix_inverse as matrix_inv,
pseudo_inverse as matrix_pseudo_inverse,
};
pub use norm_ops::{frobenius_norm as norm_frobenius, nuclear_norm, spectral_norm};
pub use scalar_ops::scalar_mul as scalar_multiply;
pub use solver_ops::{lstsq as linalg_lstsq, solve as linalg_solve};
pub use special_matrices::{band_matrix, cholesky, symmetrize, tril, triu};
pub use eigen_ops::eigen as eig;
pub use matrix_ops::determinant as det;
pub use matrix_ops::matrix_inverse as matinv;
pub use matrix_ops::pseudo_inverse as pinv;
pub use matrix_functions::{logm, powm, sqrtm};
pub use matrix_functions::{matrix_log, matrix_power, matrix_sqrt};
pub use numerical_props::{
cond, cond_1, cond_2, cond_fro, cond_inf, logdet, matrix_rank, slogdet, ConditionType,
};
pub use kronecker_ops::kron;
pub use matrix_norms::{norm1, norm2, normfro, norminf};
pub use matrix_solvers::{cholesky_solve, solve_lyapunov, solve_sylvester};
pub use symmetric_ops::{eigh, eigvalsh};
pub use special_decompositions::{polar, schur};
pub use advanced_tensor_ops::{einsum, kron as kron_tensor, tensor_solve};
pub use matrix_ops::{expm2, expm3};
pub use advanced_decompositions::{generalized_eigen, qr_pivot, randomized_svd, svd_jacobi};
pub use iterative_solvers::{
bicgstab_solve, conjugate_gradient_solve, gmres_solve, pcg_solve, PreconditionerType,
};
pub use matrix_trig_functions::{coshm, cosm, funm, signm, sinhm, sinm};
pub use advanced_tensor_ops::kron as kronecker_product;
pub use checkpoint_ops::{
adaptive_checkpoint, checkpoint, checkpoint_segment, checkpoint_segment_flex, detach,
CheckpointGroup, CheckpointProfiler,
};
pub use advanced_indexing::{
advanced_gather, boolean_mask, get_at_coords, scatter, select_columns, select_rows, take,
where_op,
};
pub use broadcast_ops::{
analyze_broadcast, broadcast_add, broadcast_div, broadcast_maximum, broadcast_minimum,
broadcast_mul, broadcast_pow, broadcast_sub, clear_broadcast_cache, get_broadcast_cache_stats,
BroadcastInfo, BroadcastStrategy,
};
pub use memory_optimization::{
clear_memory_pool, configure_memory_pool, disable_memory_tracking, efficient_ones,
efficient_view, efficient_zeros, enable_memory_tracking, get_memory_pool_stats,
get_memory_tracking_stats, get_pooled_buffer, inplace_abs, inplace_add, inplace_div,
inplace_mul, inplace_neg, inplace_scalar_mul, inplace_sub, reset_memory_tracking,
return_pooled_buffer, set_memory_pool_enabled, MemoryOptimizer, MemoryPoolStats,
MemoryTrackerStats,
};
pub use efficient_ops::{
clear_reshape_cache, efficient_concat, efficient_reshape, efficient_reshape_withshape,
efficient_slice, efficient_transpose, get_reshape_cache_stats, EfficientOpsManager,
EfficientOpsStats, SliceRange,
};
pub use custom_activations::{
create_custom_activation, custom_activation, is_activation_registered,
list_activation_functions, parameterized_activation, register_activation, ActivationProperties,
CustomActivation, CustomActivationBuilder,
};
pub use performance_ops::{
cache_friendly_matmul, is_parallel_enabled, is_simd_enabled, parallel_sum,
set_parallel_enabled, set_simd_enabled, simd_add, simd_mul, simd_relu, simd_sigmoid,
PerformanceConfig, ReductionOperation, SimdBinaryOperation, SimdUnaryOperation,
};
pub use graph_enhancements::{
cached_op, clear_computation_cache, conditional, configure_cache, get_cache_stats,
get_gc_stats, run_garbage_collection, smart_checkpoint, CacheStats, GcStats, GraphEnhancer,
GraphStats, PredicateType,
};
#[cfg(feature = "simd")]
pub use simd_ops::{
simd_activation_relu, simd_activation_sigmoid, simd_activation_tanh, simd_broadcast_add,
simd_broadcast_mul, simd_dot_product, simd_elementwise_add, simd_elementwise_div,
simd_elementwise_mul, simd_elementwise_sub, simd_gradient_accumulate, simd_reduction_sum,
simd_scaled_gradient_accumulate, SimdConfig, SimdDotProduct, SimdElementwiseAdd,
SimdElementwiseDiv, SimdElementwiseMul, SimdElementwiseSub, SimdGradientAccumulate, SimdReLU,
SimdReductionSum, SimdSigmoid, SimdTanh,
};
#[cfg(test)]
#[path = "mod_tests.rs"]
mod tests;