use crate::dtype::DType;
use crate::error::{Error, Result};
use crate::runtime::Runtime;
use crate::tensor::Tensor;
pub fn validate_make_complex_inputs<R: Runtime<DType = DType>>(
real: &Tensor<R>,
imag: &Tensor<R>,
) -> Result<()> {
if real.shape() != imag.shape() {
return Err(Error::ShapeMismatch {
expected: real.shape().to_vec(),
got: imag.shape().to_vec(),
});
}
if real.dtype() != imag.dtype() {
return Err(Error::DTypeMismatch {
lhs: real.dtype(),
rhs: imag.dtype(),
});
}
match real.dtype() {
DType::F32 | DType::F64 => Ok(()),
dtype => Err(Error::UnsupportedDType {
dtype,
op: "make_complex",
}),
}
}
/// F32-only variant of the `make_complex` input validation.
///
/// Only compiled when the `wgpu` feature is enabled. Checks are performed in
/// order and the first failure is returned:
/// 1. both tensors must have the same shape,
/// 2. both tensors must have the same dtype,
/// 3. that shared dtype must be `F32`.
///
/// # Errors
///
/// - `Error::ShapeMismatch` when the shapes differ; `expected` carries the
///   shape of `real` and `got` the shape of `imag`.
/// - `Error::DTypeMismatch` when the dtypes differ.
/// - `Error::UnsupportedDType` for every dtype other than `F32`. For `F64`
///   the `op` string additionally states that WebGPU does not support F64.
#[cfg(feature = "wgpu")]
pub fn validate_make_complex_inputs_f32_only<R: Runtime<DType = DType>>(
    real: &Tensor<R>,
    imag: &Tensor<R>,
) -> Result<()> {
    let (real_shape, imag_shape) = (real.shape(), imag.shape());
    if real_shape != imag_shape {
        return Err(Error::ShapeMismatch {
            expected: real_shape.to_vec(),
            got: imag_shape.to_vec(),
        });
    }

    let (real_dtype, imag_dtype) = (real.dtype(), imag.dtype());
    if real_dtype != imag_dtype {
        return Err(Error::DTypeMismatch {
            lhs: real_dtype,
            rhs: imag_dtype,
        });
    }

    if matches!(real_dtype, DType::F32) {
        return Ok(());
    }

    // Everything else is rejected; F64 gets a more specific explanation.
    let op = if matches!(real_dtype, DType::F64) {
        "make_complex (WebGPU does not support F64)"
    } else {
        "make_complex"
    };
    Err(Error::UnsupportedDType {
        dtype: real_dtype,
        op,
    })
}
pub fn validate_complex_real_inputs<R: Runtime<DType = DType>>(
complex: &Tensor<R>,
real: &Tensor<R>,
op: &'static str,
) -> Result<()> {
if complex.shape() != real.shape() {
return Err(Error::ShapeMismatch {
expected: complex.shape().to_vec(),
got: real.shape().to_vec(),
});
}
match (complex.dtype(), real.dtype()) {
(DType::Complex64, DType::F32) => Ok(()),
(DType::Complex128, DType::F64) => Ok(()),
(DType::Complex64, other) => Err(Error::DTypeMismatch {
lhs: DType::Complex64,
rhs: other,
}),
(DType::Complex128, other) => Err(Error::DTypeMismatch {
lhs: DType::Complex128,
rhs: other,
}),
(other, _) => Err(Error::UnsupportedDType { dtype: other, op }),
}
}
/// F32-only variant of the complex/real operand validation.
///
/// Only compiled when the `wgpu` feature is enabled. The shapes must be
/// equal, and the only accepted dtype pair is `(Complex64, F32)`.
///
/// `op` is the operation name reported in `Error::UnsupportedDType`.
///
/// # Errors
///
/// - `Error::ShapeMismatch` when the shapes differ; `expected` carries the
///   shape of `complex` and `got` the shape of `real`.
/// - `Error::DTypeMismatch` when `complex` is `Complex64` but `real` is not
///   `F32`; `lhs` is `Complex64` and `rhs` the dtype of `real`.
/// - `Error::UnsupportedDType` when `complex` is anything other than
///   `Complex64` (this includes `Complex128`); the reported dtype is that of
///   `complex`.
#[cfg(feature = "wgpu")]
pub fn validate_complex_real_inputs_f32_only<R: Runtime<DType = DType>>(
    complex: &Tensor<R>,
    real: &Tensor<R>,
    op: &'static str,
) -> Result<()> {
    if complex.shape() != real.shape() {
        return Err(Error::ShapeMismatch {
            expected: complex.shape().to_vec(),
            got: real.shape().to_vec(),
        });
    }

    match (complex.dtype(), real.dtype()) {
        (DType::Complex64, DType::F32) => Ok(()),
        // Report the actual dtype of the left-hand (complex) operand, as the
        // non-F32-only validator does; it is `Complex64` here, never `F32`.
        (DType::Complex64, other) => Err(Error::DTypeMismatch {
            lhs: DType::Complex64,
            rhs: other,
        }),
        // Any other dtype for `complex`, including `Complex128`, is rejected
        // with the dtype of `complex` attached.
        (dtype, _) => Err(Error::UnsupportedDType { dtype, op }),
    }
}