use crate::dtype::DType;
use crate::error::{Error, Result};
use crate::ops::TypeConversionOps;
use crate::ops::broadcast_shape;
use crate::runtime::cuda::kernels::compute_broadcast_strides;
use crate::runtime::cuda::{CudaClient, CudaDevice, CudaRuntime};
use crate::tensor::Tensor;
pub fn validate_mask_dtype(mask: &Tensor<CudaRuntime>) -> Result<()> {
if mask.dtype() != DType::U8 {
return Err(Error::DTypeMismatch {
lhs: DType::U8,
rhs: mask.dtype(),
});
}
Ok(())
}
pub fn normalize_indices_to_i64(
client: &CudaClient,
indices: &Tensor<CudaRuntime>,
) -> Result<Tensor<CudaRuntime>> {
match indices.dtype() {
DType::I64 => Ok(indices.clone()),
DType::I32 => client.cast(indices, DType::I64),
other => Err(Error::DTypeMismatch {
lhs: DType::I64,
rhs: other,
}),
}
}
/// Broadcast metadata for applying a mask tensor to a tensor `a`, as built by
/// `BroadcastContext::prepare`.
///
/// When the two shapes are equal, `needs_broadcast` is `false` and both tensors are `None`.
pub struct BroadcastContext {
/// `true` when the shape of `a` differs from the shape of the mask.
pub needs_broadcast: bool,
/// Tensor of length `ndim` holding the output of
/// `compute_broadcast_strides(mask.shape(), a.shape())`; `None` when no broadcast is needed.
pub strides_tensor: Option<Tensor<CudaRuntime>>,
/// Tensor of length `ndim` holding the dimensions of `a` as `u32` values; `None` when no
/// broadcast is needed.
pub shape_tensor: Option<Tensor<CudaRuntime>>,
/// Number of dimensions of `a`.
pub ndim: usize,
}
impl BroadcastContext {
pub fn prepare(
a: &Tensor<CudaRuntime>,
mask: &Tensor<CudaRuntime>,
device: &CudaDevice,
) -> Result<Self> {
let needs_broadcast = a.shape() != mask.shape();
if !needs_broadcast {
return Ok(Self {
needs_broadcast: false,
strides_tensor: None,
shape_tensor: None,
ndim: a.shape().len(),
});
}
let broadcast_result = broadcast_shape(a.shape(), mask.shape());
match broadcast_result {
Some(ref bcast_shape) if bcast_shape == a.shape() => {
}
_ => {
return Err(Error::BroadcastError {
lhs: a.shape().to_vec(),
rhs: mask.shape().to_vec(),
});
}
}
let mask_strides = compute_broadcast_strides(mask.shape(), a.shape());
let out_shape_u32: Vec<u32> = a.shape().iter().map(|&x| x as u32).collect();
let ndim = a.shape().len();
let strides_tensor = Tensor::<CudaRuntime>::from_slice(&mask_strides, &[ndim], device);
let shape_tensor = Tensor::<CudaRuntime>::from_slice(&out_shape_u32, &[ndim], device);
Ok(Self {
needs_broadcast: true,
strides_tensor: Some(strides_tensor),
shape_tensor: Some(shape_tensor),
ndim,
})
}
#[inline]
pub fn strides_ptr(&self) -> u64 {
debug_assert!(
self.needs_broadcast,
"strides_ptr() called on non-broadcast context"
);
self.strides_tensor.as_ref().map(|t| t.ptr()).unwrap_or(0)
}
#[inline]
pub fn shape_ptr(&self) -> u64 {
debug_assert!(
self.needs_broadcast,
"shape_ptr() called on non-broadcast context"
);
self.shape_tensor.as_ref().map(|t| t.ptr()).unwrap_or(0)
}
}