use crate::dtype::DType;
use crate::error::{Error, Result};
use crate::ops::broadcast_shape;
use crate::runtime::Runtime;
use crate::tensor::Tensor;
/// Resolves a possibly negative axis index `dim` against a tensor of rank `ndim`.
///
/// Negative values count back from the last axis, so `-1` maps to `ndim - 1`
/// and `-(ndim as isize)` maps to `0`.
///
/// # Errors
///
/// Returns `Error::InvalidDimension { dim, ndim }` (carrying the caller's original
/// `dim`) when the resolved index falls outside `0..ndim`.
#[inline]
pub fn normalize_dim(dim: isize, ndim: usize) -> Result<usize> {
    // Only negative indices are shifted; non-negative ones are used as-is.
    let shifted = if dim < 0 { ndim as isize + dim } else { dim };

    // A still-negative shift means `dim` reached past the first axis.
    if shifted < 0 || shifted as usize >= ndim {
        return Err(Error::InvalidDimension { dim, ndim });
    }
    Ok(shifted as usize)
}
/// Validates `arange(start, stop, step)` arguments and returns the element count.
///
/// The count is `ceil((stop - start) / step)`, i.e. the number of values
/// `start, start + step, ...` strictly before `stop`. `start == stop` yields 0.
///
/// # Errors
///
/// Returns `Error::InvalidArgument` when:
/// - any of `start`, `stop`, `step` is NaN or infinite,
/// - `step` is zero,
/// - the sign of `step` does not point from `start` towards `stop`, or
/// - the resulting count is not finite or does not fit in `usize`.
#[inline]
pub fn validate_arange(start: f64, stop: f64, step: f64) -> Result<usize> {
    // Reject NaN/±inf up front. With NaN every comparison below is false and the
    // final `as usize` cast silently saturates (NaN -> 0, +inf -> usize::MAX), so
    // bad input would otherwise turn into an empty or absurdly large range.
    for (arg, value) in [("start", start), ("stop", stop), ("step", step)] {
        if !value.is_finite() {
            return Err(Error::InvalidArgument {
                arg,
                reason: format!("{} must be finite, got {}", arg, value),
            });
        }
    }
    if step == 0.0 {
        return Err(Error::InvalidArgument {
            arg: "step",
            reason: "step cannot be zero".to_string(),
        });
    }
    if (stop > start && step < 0.0) || (stop < start && step > 0.0) {
        return Err(Error::InvalidArgument {
            arg: "step",
            reason: "step sign must match direction from start to stop".to_string(),
        });
    }
    if start == stop {
        return Ok(0);
    }
    // Finite inputs can still overflow: `stop - start` may exceed f64::MAX, or a
    // tiny `step` may produce a count beyond usize. Casting such a value with
    // `as usize` would saturate silently, so report it instead.
    let count = ((stop - start) / step).ceil();
    if !count.is_finite() || count >= usize::MAX as f64 {
        return Err(Error::InvalidArgument {
            arg: "step",
            reason: "range from start to stop with this step has too many elements".to_string(),
        });
    }
    Ok(count as usize)
}
/// Resolves the `(rows, cols)` shape of an identity-style matrix.
///
/// `n` is the row count; `m` optionally overrides the column count and
/// defaults to `n` (a square matrix) when `None`.
#[inline]
pub fn validate_eye(n: usize, m: Option<usize>) -> (usize, usize) {
    match m {
        Some(cols) => (n, cols),
        None => (n, n),
    }
}
/// Returns a contiguous tensor for `tensor`.
///
/// If `tensor.is_contiguous()` already holds, this is simply `tensor.clone()`;
/// otherwise it returns the result of `tensor.contiguous()`.
#[inline]
pub fn ensure_contiguous<R: Runtime<DType = DType>>(tensor: &Tensor<R>) -> Tensor<R> {
    if !tensor.is_contiguous() {
        return tensor.contiguous();
    }
    tensor.clone()
}
#[inline]
pub fn validate_binary_dtypes<R: Runtime<DType = DType>>(
a: &Tensor<R>,
b: &Tensor<R>,
) -> Result<DType> {
if a.dtype() != b.dtype() {
return Err(Error::DTypeMismatch {
lhs: a.dtype(),
rhs: b.dtype(),
});
}
Ok(a.dtype())
}
#[inline]
pub fn compute_broadcast_shape<R: Runtime<DType = DType>>(
a: &Tensor<R>,
b: &Tensor<R>,
) -> Result<Vec<usize>> {
broadcast_shape(a.shape(), b.shape()).ok_or_else(|| Error::BroadcastError {
lhs: a.shape().to_vec(),
rhs: b.shape().to_vec(),
})
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_normalize_dim_positive() {
        assert_eq!(normalize_dim(0, 3).unwrap(), 0);
        assert_eq!(normalize_dim(2, 3).unwrap(), 2);
    }

    #[test]
    fn test_normalize_dim_negative() {
        assert_eq!(normalize_dim(-1, 3).unwrap(), 2);
        assert_eq!(normalize_dim(-3, 3).unwrap(), 0);
    }

    #[test]
    fn test_normalize_dim_out_of_bounds() {
        assert!(normalize_dim(3, 3).is_err());
        assert!(normalize_dim(-4, 3).is_err());
        // A rank-0 tensor has no valid axis at all.
        assert!(normalize_dim(0, 0).is_err());
        assert!(normalize_dim(-1, 0).is_err());
    }

    #[test]
    fn test_validate_arange_counts() {
        assert_eq!(validate_arange(0.0, 5.0, 1.0).unwrap(), 5);
        // Partial final step still yields an element: ceil(1.0 / 0.3) == 4.
        assert_eq!(validate_arange(0.0, 1.0, 0.3).unwrap(), 4);
        assert_eq!(validate_arange(5.0, 0.0, -2.0).unwrap(), 3);
        assert_eq!(validate_arange(2.0, 2.0, 1.0).unwrap(), 0);
    }

    #[test]
    fn test_validate_arange_invalid_step() {
        assert!(validate_arange(0.0, 1.0, 0.0).is_err());
        assert!(validate_arange(0.0, 1.0, -1.0).is_err());
        assert!(validate_arange(1.0, 0.0, 1.0).is_err());
    }

    #[test]
    fn test_validate_eye() {
        assert_eq!(validate_eye(3, None), (3, 3));
        assert_eq!(validate_eye(2, Some(4)), (2, 4));
    }

    #[test]
    fn test_ensure_contiguous() {
        use crate::runtime::cpu::{CpuDevice, CpuRuntime};
        use crate::tensor::Tensor;
        let device = CpuDevice::new();
        let a = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[2, 2], &device);
        assert!(a.is_contiguous());
        let c = ensure_contiguous(&a);
        assert!(c.is_contiguous());
    }
}