#[cfg(any(feature = "cuda", feature = "wgpu"))]
use crate::algorithm::linalg::helpers::{
validate_linalg_dtype, validate_matrix_2d, validate_square_matrix,
};
#[cfg(any(feature = "cuda", feature = "wgpu"))]
use crate::algorithm::linalg::{LinearAlgebraAlgorithms, SlogdetResult};
#[cfg(any(feature = "cuda", feature = "wgpu"))]
use crate::dtype::DType;
#[cfg(any(feature = "cuda", feature = "wgpu"))]
use crate::error::Result;
#[cfg(any(feature = "cuda", feature = "wgpu"))]
use crate::ops::{BinaryOps, CompareOps, ReduceOps, ScalarOps, TypeConversionOps, UtilityOps};
#[cfg(any(feature = "cuda", feature = "wgpu"))]
use crate::runtime::Runtime;
#[cfg(any(feature = "cuda", feature = "wgpu"))]
use crate::tensor::Tensor;
// Selects which triangular half of a matrix `triangular_mask_impl` keeps.
// Private to this module; only compiled for the `cuda` / `wgpu` backends.
#[cfg(any(feature = "cuda", feature = "wgpu"))]
enum Triangle {
// Keep entries whose column index is >= row index + diagonal offset.
Upper,
// Keep entries whose column index is <= row index + diagonal offset.
Lower,
}
/// Shared implementation behind [`triu_impl`] and [`tril_impl`].
///
/// Builds a broadcastable comparison mask from row and column index ranges
/// and multiplies it element-wise into `a`:
///
/// * `Triangle::Upper` selects positions where `col >= row + diagonal`.
/// * `Triangle::Lower` selects positions where `col <= row + diagonal`.
///
/// The index ranges are generated as `F32`; the resulting mask is cast to the
/// dtype of `a` when that dtype is not `F32`.
///
/// # Errors
///
/// Propagates any error returned by `validate_matrix_2d` or by the client
/// operations used (`arange`, `reshape`, `add_scalar`, `ge`/`le`, `cast`,
/// `mul`).
///
/// NOTE(review): this relies on the comparison ops yielding a numeric mask
/// (presumably 1 for true, 0 for false) — confirm against `CompareOps`.
/// Because masking is done by multiplication, non-finite values in the
/// discarded region would likely stay non-finite (`0 * NaN`); verify whether
/// callers care.
#[cfg(any(feature = "cuda", feature = "wgpu"))]
fn triangular_mask_impl<R, C>(
    client: &C,
    a: &Tensor<R>,
    diagonal: i64,
    triangle: Triangle,
) -> Result<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: UtilityOps<R> + ScalarOps<R> + CompareOps<R> + TypeConversionOps<R> + BinaryOps<R>,
{
    let (rows, cols) = validate_matrix_2d(a.shape())?;
    let target_dtype = a.dtype();

    // Row indices as an (rows, 1) column and column indices as a (1, cols)
    // row, so the comparison below broadcasts to the full (rows, cols) grid.
    let row_range = client.arange(0.0, rows as f64, 1.0, DType::F32)?;
    let row_ids = row_range.reshape(&[rows, 1])?;
    let col_range = client.arange(0.0, cols as f64, 1.0, DType::F32)?;
    let col_ids = col_range.reshape(&[1, cols])?;

    // Per-row column threshold: row + diagonal.
    let threshold = client.add_scalar(&row_ids, diagonal as f64)?;

    let keep = match triangle {
        Triangle::Lower => client.le(&col_ids, &threshold)?,
        Triangle::Upper => client.ge(&col_ids, &threshold)?,
    };

    // The mask was built from F32 indices; bring it to the input's dtype
    // before multiplying.
    let keep = if target_dtype == DType::F32 {
        keep
    } else {
        client.cast(&keep, target_dtype)?
    };

    client.mul(a, &keep)
}
/// Upper-triangular part of the 2-D tensor `a`.
///
/// Delegates to `triangular_mask_impl` with `Triangle::Upper`, which keeps
/// the positions where `col >= row + diagonal` and masks out the rest by
/// multiplication.
///
/// # Errors
///
/// Returns whatever error `triangular_mask_impl` produces.
#[cfg(any(feature = "cuda", feature = "wgpu"))]
pub fn triu_impl<R, C>(client: &C, a: &Tensor<R>, diagonal: i64) -> Result<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: UtilityOps<R> + ScalarOps<R> + CompareOps<R> + TypeConversionOps<R> + BinaryOps<R>,
{
    let half = Triangle::Upper;
    triangular_mask_impl::<R, C>(client, a, diagonal, half)
}
/// Lower-triangular part of the 2-D tensor `a`.
///
/// Delegates to `triangular_mask_impl` with `Triangle::Lower`, which keeps
/// the positions where `col <= row + diagonal` and masks out the rest by
/// multiplication.
///
/// # Errors
///
/// Returns whatever error `triangular_mask_impl` produces.
#[cfg(any(feature = "cuda", feature = "wgpu"))]
pub fn tril_impl<R, C>(client: &C, a: &Tensor<R>, diagonal: i64) -> Result<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: UtilityOps<R> + ScalarOps<R> + CompareOps<R> + TypeConversionOps<R> + BinaryOps<R>,
{
    let half = Triangle::Lower;
    triangular_mask_impl::<R, C>(client, a, diagonal, half)
}
/// Sign and log-absolute-value of the determinant of the square matrix `a`.
///
/// The matrix is LU-decomposed through the client and the diagonal of the
/// combined LU factor is extracted. From that diagonal:
///
/// * `logabsdet` is the sum of `log(|d_i|)`.
/// * `sign` is the product of the per-entry signs (`+1`, `-1`, or `0`
///   computed as `gt(d, 0) - lt(d, 0)`), multiplied by `-1` when the
///   decomposition reports an odd number of row swaps.
///
/// For a 0×0 input the function short-circuits and returns `sign = 1`,
/// `logabsdet = 0`, both as scalar tensors of the input dtype.
///
/// # Errors
///
/// Propagates errors from `validate_linalg_dtype`, `validate_square_matrix`,
/// and from every client operation invoked (`fill`, `lu_decompose`, `diag`,
/// `abs`, `log`, `sum`, `gt`, `lt`, `sub`, `prod`, `mul_scalar`).
///
/// NOTE(review): the reductions are called with an empty dims slice and
/// `keepdim = false`; presumably that means "reduce over all axes" — confirm
/// against `ReduceOps`.
#[cfg(any(feature = "cuda", feature = "wgpu"))]
pub fn slogdet_impl<R, C>(client: &C, a: &Tensor<R>) -> Result<SlogdetResult<R>>
where
    R: Runtime<DType = DType>,
    C: LinearAlgebraAlgorithms<R>
        + UtilityOps<R>
        + BinaryOps<R>
        + CompareOps<R>
        + ReduceOps<R>
        + ScalarOps<R>
        + crate::ops::UnaryOps<R>,
{
    let dtype = a.dtype();
    validate_linalg_dtype(dtype)?;
    let n = validate_square_matrix(a.shape())?;

    // Empty matrix: report sign 1 and log|det| 0 without decomposing.
    if n == 0 {
        return Ok(SlogdetResult {
            sign: client.fill(&[], 1.0, dtype)?,
            logabsdet: client.fill(&[], 0.0, dtype)?,
        });
    }

    let factors = client.lu_decompose(a)?;
    let pivots = LinearAlgebraAlgorithms::diag(client, &factors.lu)?;

    // log|det| = sum of log|pivot|.
    let logabsdet = {
        let magnitudes = client.abs(&pivots)?;
        let log_magnitudes = client.log(&magnitudes)?;
        client.sum(&log_magnitudes, &[], false)?
    };

    // Per-pivot sign as (pivot > 0) - (pivot < 0), then multiplied together.
    let pivot_sign_product = {
        let origin = client.fill(&[], 0.0, dtype)?;
        let is_positive = client.gt(&pivots, &origin)?;
        let is_negative = client.lt(&pivots, &origin)?;
        let signum = client.sub(&is_positive, &is_negative)?;
        client.prod(&signum, &[], false)?
    };

    // Each row swap during pivoting flips the determinant's sign.
    let permutation_sign = if factors.num_swaps % 2 != 0 {
        -1.0
    } else {
        1.0
    };

    let sign = client.mul_scalar(&pivot_sign_product, permutation_sign)?;
    Ok(SlogdetResult { sign, logabsdet })
}