use crate::shapes::TensorLayoutBuffers;
use crate::tensor::{AsTensorMut, AsTensorRef};
use khal::Shader;
use khal::backend::{GpuBackend, GpuBackendError, GpuPass};
use crate::shaders::linalg::Repeat as GpuRepeat;
/// Shader wrapper around the `Repeat` kernel from `crate::shaders::linalg`.
///
/// Use [`Repeat::launch`] to dispatch the kernel on a pair of tensors.
#[derive(Shader)]
pub struct Repeat {
/// The underlying kernel object that [`Repeat::launch`] calls.
pub repeat: GpuRepeat,
}
impl Repeat {
pub fn launch(
&self,
backend: &GpuBackend,
#[cfg_attr(feature = "push_constants", allow(unused_variables))]
shapes: &mut TensorLayoutBuffers,
pass: &mut GpuPass,
mut destination: impl AsTensorMut<f32>,
source: impl AsTensorRef<f32>,
) -> Result<(), GpuBackendError> {
let mut result = destination.as_tensor_mut();
let source = source.as_tensor_ref();
let Some((mut result_shape, mut source_shape)) =
result.layout().broadcast_assign(source.layout())
else {
panic!(
"destination: {:?} is incompatible with source: {:?}",
result.layout(),
source.layout()
)
};
result_shape = result_shape.canonicalize();
source_shape = source_shape.canonicalize();
let num_threads = result.len() as u32;
#[cfg(not(feature = "push_constants"))]
{
shapes.insert(backend, source_shape)?;
shapes.insert(backend, result_shape)?;
let shape_source = shapes.get(source_shape).unwrap_or_else(|| unreachable!());
let shape_result = shapes.get(result_shape).unwrap_or_else(|| unreachable!());
let mut buf_result = result.buffer_mut();
self.repeat.call(
pass,
num_threads,
&shape_result.as_slice(),
&shape_source.as_slice(),
&mut buf_result,
&source.buffer(),
)
}
#[cfg(feature = "push_constants")]
{
let mut buf_result = result.buffer_mut();
self.repeat.call(
pass,
num_threads,
&mut buf_result,
&source.buffer(),
crate::shaders::linalg::Shapes2 {
shape_a: result_shape.into(),
shape_b: source_shape.into(),
},
)
}
}
}