use std::sync::Arc;
use smallvec::smallvec;
use crate::error::Error;
use crate::uop::UOp;
use crate::{BinaryOp, DType, Op, Result};
pub mod compute;
pub mod control;
pub mod data;
pub mod graph;
pub mod hardware;
pub mod memory;
pub mod reduce;
pub mod shape;
impl UOp {
pub(crate) fn promote_and_cast(lhs: Arc<Self>, rhs: Arc<Self>) -> Result<(Arc<Self>, Arc<Self>, DType)> {
let lhs_dtype = lhs.dtype();
let rhs_dtype = rhs.dtype();
if lhs_dtype == DType::Void || rhs_dtype == DType::Void {
return Err(Error::VoidTypeInOp);
}
let target_dtype = DType::least_upper_dtype(&[lhs_dtype.clone(), rhs_dtype.clone()])
.ok_or(Error::TypePromotionFailed { lhs: lhs_dtype.clone(), rhs: rhs_dtype.clone() })?;
let lhs = if lhs_dtype != target_dtype { lhs.cast(target_dtype.clone()) } else { lhs };
let rhs = if rhs_dtype != target_dtype { rhs.cast(target_dtype.clone()) } else { rhs };
Ok((lhs, rhs, target_dtype))
}
pub(crate) fn check_bitwise_dtype(dtype: DType, operation: BinaryOp) -> Result<()> {
let is_valid = dtype.is_bool() || dtype.is_int();
if !is_valid { Err(Error::InvalidDTypeForBinaryOp { operation, dtypes: smallvec![dtype] }) } else { Ok(()) }
}
pub(crate) fn check_division_by_zero(divisor: &Arc<Self>) -> Result<()> {
use crate::ConstValue;
use crate::error::DivisionByZeroSnafu;
use snafu::ensure;
if let Op::Const(const_hash) = divisor.op() {
let is_zero = match const_hash.0 {
ConstValue::Int(v) => v == 0,
ConstValue::UInt(v) => v == 0,
ConstValue::Float(v) => v == 0.0,
ConstValue::Bool(_) => false,
};
ensure!(!is_zero, DivisionByZeroSnafu);
}
Ok(())
}
pub(crate) fn validate_binary_shapes(lhs: &Arc<Self>, rhs: &Arc<Self>, op: crate::BinaryOp) -> Result<()> {
use crate::error::BinaryShapeMismatchSnafu;
use crate::shape::shapes_equal;
let lhs_shape = lhs.shape()?;
let rhs_shape = rhs.shape()?;
match (lhs_shape, rhs_shape) {
(Some(ls), Some(rs)) if !shapes_equal(ls, rs) => {
BinaryShapeMismatchSnafu { op, lhs: Box::new(ls.clone()), rhs: Box::new(rs.clone()) }.fail()
}
_ => Ok(()), }
}
pub(crate) fn validate_ternary_shapes(true_val: &Arc<Self>, false_val: &Arc<Self>) -> Result<()> {
use crate::error::TernaryBranchShapeMismatchSnafu;
use crate::shape::shapes_equal;
let true_shape = true_val.shape()?;
let false_shape = false_val.shape()?;
match (true_shape, false_shape) {
(Some(ts), Some(fs)) if !shapes_equal(ts, fs) && !ts.is_empty() && !fs.is_empty() => {
TernaryBranchShapeMismatchSnafu {
true_branch: Box::new(ts.clone()),
false_branch: Box::new(fs.clone()),
}
.fail()
}
_ => Ok(()),
}
}
pub(crate) fn validate_permutation(axes: &[usize], expected_dims: usize) -> Result<()> {
use crate::error::PermuteInvalidPermutationSnafu;
use snafu::ensure;
ensure!(
axes.len() == expected_dims,
PermuteInvalidPermutationSnafu { permutation: axes.to_vec(), expected_dims }
);
if expected_dims <= 64 {
let mut seen = 0u64;
for &axis in axes {
ensure!(
axis < expected_dims,
PermuteInvalidPermutationSnafu { permutation: axes.to_vec(), expected_dims }
);
let bit = 1u64 << axis;
ensure!(seen & bit == 0, PermuteInvalidPermutationSnafu { permutation: axes.to_vec(), expected_dims });
seen |= bit;
}
ensure!(
seen == (1u64 << expected_dims) - 1,
PermuteInvalidPermutationSnafu { permutation: axes.to_vec(), expected_dims }
);
} else {
let mut seen = vec![false; expected_dims];
for &axis in axes {
ensure!(
axis < expected_dims,
PermuteInvalidPermutationSnafu { permutation: axes.to_vec(), expected_dims }
);
ensure!(!seen[axis], PermuteInvalidPermutationSnafu { permutation: axes.to_vec(), expected_dims });
seen[axis] = true;
}
}
Ok(())
}
pub(crate) fn validate_reduce_axes(axes: &[usize], shape_dims: usize) -> Result<()> {
use crate::error::ReduceAxisInvalidSnafu;
use snafu::ensure;
for &axis in axes {
ensure!(axis < shape_dims, ReduceAxisInvalidSnafu { axis: axis as i32, shape_dims });
}
Ok(())
}
pub(crate) fn validate_flip_axes(axes: &[bool], expected_dims: usize) -> Result<()> {
use crate::error::FlipInvalidSpecSnafu;
use snafu::ensure;
ensure!(axes.len() == expected_dims, FlipInvalidSpecSnafu { expected_dims, got_dims: axes.len() });
Ok(())
}
}