use oxicuda_blas::GpuFloat;
use crate::error::{DnnError, DnnResult};
use crate::types::{TensorDesc, TensorDescMut};
/// Returns the four extents of a rank-4 (read-only) tensor descriptor as a tuple.
///
/// The extents come back in the same order they are stored in `desc.dims`.
/// The function name reflects the intended NCHW interpretation, but only the
/// rank is checked here; the layout itself is not verified.
///
/// # Errors
///
/// Returns [`DnnError::InvalidDimension`] when `desc.dims` does not hold exactly
/// four entries.
pub(crate) fn nchw_dims<T: GpuFloat>(desc: &TensorDesc<T>) -> DnnResult<(u32, u32, u32, u32)> {
    // A single slice pattern both checks the rank and destructures the extents.
    match desc.dims[..] {
        [n, c, h, w] => Ok((n, c, h, w)),
        _ => Err(DnnError::InvalidDimension(format!(
            "expected 4-D tensor, got {}-D",
            desc.dims.len()
        ))),
    }
}
/// Returns the four extents of a rank-4 mutable tensor descriptor as a tuple.
///
/// This is the [`TensorDescMut`] counterpart of `nchw_dims`. The extents are
/// returned in the order stored in `desc.dims`, and only the rank is validated.
///
/// # Errors
///
/// Returns [`DnnError::InvalidDimension`] when `desc.dims` does not hold exactly
/// four entries.
pub(crate) fn nchw_dims_mut<T: GpuFloat>(
    desc: &TensorDescMut<T>,
) -> DnnResult<(u32, u32, u32, u32)> {
    // The slice pattern matches only when there are exactly four extents.
    match desc.dims[..] {
        [n, c, h, w] => Ok((n, c, h, w)),
        _ => Err(DnnError::InvalidDimension(format!(
            "expected 4-D tensor, got {}-D",
            desc.dims.len()
        ))),
    }
}
/// Returns the four extents of a rank-4 attention tensor descriptor.
///
/// This delegates to `nchw_dims`, so the rank check and error are identical; the
/// attention-specific name only documents intent at call sites.
/// NOTE(review): the axes are presumably (batch, heads, seq_len, head_dim) —
/// confirm against callers, since nothing here enforces that meaning.
///
/// # Errors
///
/// Returns [`DnnError::InvalidDimension`] when the descriptor is not 4-D.
pub(crate) fn attn_dims<T: GpuFloat>(desc: &TensorDesc<T>) -> DnnResult<(u32, u32, u32, u32)> {
    let extents = nchw_dims::<T>(desc)?;
    Ok(extents)
}
/// Returns the four extents of a rank-4 mutable attention tensor descriptor.
///
/// This delegates to `nchw_dims_mut`, so the rank check and error are identical.
/// NOTE(review): the axes are presumably (batch, heads, seq_len, head_dim) —
/// confirm against callers.
///
/// # Errors
///
/// Returns [`DnnError::InvalidDimension`] when the descriptor is not 4-D.
pub(crate) fn attn_dims_mut<T: GpuFloat>(
    desc: &TensorDescMut<T>,
) -> DnnResult<(u32, u32, u32, u32)> {
    let extents = nchw_dims_mut::<T>(desc)?;
    Ok(extents)
}