use crate::LibraryName;
use crate::encoder::EncoderExt;
use crate::kernels::utils;
use metal::MTLSize;
use tract_core::internal::*;
use tract_gpu::tensor::DeviceTensor;
pub use tract_gpu::ops::reduce::Reducer;
/// Builds the fully-qualified Metal kernel function name for a reducer and
/// datum type, e.g. `nn_ops::reduce_<reducer>_nd3_<tname>`.
///
/// # Errors
/// Fails when `reducer` does not support `dt`, or when `dt` has no Metal
/// type name.
pub fn kernel_name(reducer: &Reducer, dt: DatumType) -> TractResult<String> {
    ensure!(
        reducer.is_supported_dt(dt),
        "Unsupported dt {dt:?} for metal reduceop {reducer:?}",
    );
    let type_name = DeviceTensor::tname(dt)?;
    Ok(format!("nn_ops::reduce_{reducer}_nd3_{type_name}"))
}
/// Encodes and dispatches the nd3 reduce kernel for `reducer` on the current
/// Metal stream.
///
/// `input` is reduced along `axis` into `output`, which must share `input`'s
/// datum type and have extent 1 on `axis`. Both shapes are collapsed to rank 3
/// around the reduced axis before dispatch (the reduced extent appears to land
/// in the middle dimension, given the group-size computation below — TODO
/// confirm against `utils::reshape_to_rank_3`).
pub fn metal_reduce_launch(
    reducer: &Reducer,
    input: &DeviceTensor,
    axis: usize,
    output: &DeviceTensor,
) -> TractResult<()> {
    crate::with_metal_stream(|stream| {
        // Keep both tensors alive until the command buffer has executed.
        stream.retain_tensor(input);
        stream.retain_tensor(output);
        ensure!(output.datum_type() == input.datum_type());
        ensure!(output.shape()[axis] == 1);
        // Collapse input/output to rank-3 views around `axis`, with natural
        // (contiguous) strides for each collapsed shape.
        let input_shape_nd3 = utils::reshape_to_rank_3(input.shape(), axis);
        let input_strides_nd3 = Tensor::natural_strides(&input_shape_nd3);
        let output_shape_nd3 = utils::reshape_to_rank_3(output.shape(), axis);
        let output_strides_nd3 = Tensor::natural_strides(&output_shape_nd3);
        let pipeline =
            stream.load_pipeline(LibraryName::NNOps, &kernel_name(reducer, input.datum_type())?)?;
        let command_buffer = stream.command_buffer();
        command_buffer.encode(|encoder| {
            encoder.set_compute_pipeline_state(&pipeline);
            // Buffer binding order is a contract with the shader: tensors at
            // slots 0/1, then input shape, input strides, output strides.
            encoder.set_metal_tensor(0, input, metal::MTLResourceUsage::Read);
            encoder.set_metal_tensor(1, output, metal::MTLResourceUsage::Write);
            encoder.set_slice(2, &input_shape_nd3);
            encoder.set_slice(3, &input_strides_nd3);
            encoder.set_slice(4, &output_strides_nd3);
            // One threadgroup per output element; up to 32 threads cooperate
            // along the reduced (middle) input dimension.
            let grid_size = utils::build_metal_size_for_shape(&output_shape_nd3);
            let group_size =
                MTLSize { width: usize::min(32, input_shape_nd3[1]) as _, height: 1, depth: 1 };
            encoder.dispatch_thread_groups(grid_size, group_size);
        });
        Ok(())
    })
}
// Registers the Metal rewrite rule for tract-core's `Reduce` op: when the
// reducer has a GPU equivalent and supports the input datum type, the node is
// replaced by a `GpuReduce` driven by `metal_reduce_launch`.
crate::register_metal_op!(tract_core::ops::nn::Reduce, |source, node, op| {
    let dt = source.node_input_facts(node.id)?[0].datum_type;
    if let Ok(gpu_op) =
        tract_gpu::ops::reduce::GpuReduce::from_tract_core(op, "Metal", metal_reduce_launch)
    {
        // NOTE(review): `rule_if!` presumably declines the rewrite (without
        // erroring) when dt is unsupported — confirm macro semantics.
        rule_if!(gpu_op.reducer.is_supported_dt(dt));
        return Ok(Some(Box::new(gpu_op)));
    }
    Ok(None)
});
#[cfg(test)]
mod tests {
    //! Checks the Metal reduce kernels against tract-core's CPU reducers,
    //! both on hand-picked shapes/axes and on proptest-generated problems.
    use super::*;
    use crate::utils::with_borrowed_metal_stream;
    use derive_new::new;
    use num_traits::AsPrimitive;
    use num_traits::Float;
    use proptest::collection::vec;
    use proptest::prelude::*;
    use tract_core::internal::Tensor;
    use tract_core::ops::nn::Reducer as TractReducer;
    use tract_core::tract_data::itertools::Itertools;
    use tract_gpu::tensor::IntoDevice;

    /// Runs one reduction on Metal and compares it with the tract-core CPU
    /// implementation of the same reducer.
    ///
    /// The input of `shape` is a deterministic ramp (`value = index * scale`,
    /// cast to `F`); `axis` is the reduced axis. `scale` keeps f16 inputs
    /// small enough to avoid overflow (and is negated in some callers to
    /// exercise negative values). On mismatch, the error context dumps both
    /// tensors.
    fn test_case<F>(
        reducer: Reducer,
        tract_reducer: TractReducer,
        shape: &[usize],
        axis: usize,
        scale: f32,
    ) -> TractResult<()>
    where
        F: Float + Datum,
        usize: AsPrimitive<f32>,
        f32: AsPrimitive<F>,
    {
        with_borrowed_metal_stream(|stream| {
            let len = shape.iter().product::<usize>();
            // Build the ramp input and upload it to the device.
            let a = Tensor::from_shape(
                shape,
                &(0..len)
                    .map(|f| -> F {
                        let v: f32 = f.as_();
                        (v * scale).as_()
                    })
                    .collect::<Vec<_>>(),
            )?
            .into_device()?;
            // CPU reference result computed by tract-core.
            let cpu_output = tract_reducer.reduce(&[axis], &a.to_host()?.into_tensor())?;
            // Output keeps the rank, with the reduced axis collapsed to 1.
            let mut o_shape = a.shape().to_vec();
            o_shape[axis] = 1;
            // Left uninitialized: the kernel writes every output element —
            // TODO confirm full coverage for all shapes.
            let output = unsafe { DeviceTensor::uninitialized_dt(a.datum_type(), &o_shape)? };
            metal_reduce_launch(&reducer, &a, axis, &output)?;
            // Block until the GPU work is done before reading back.
            stream.wait_until_completed()?;
            let metal_output = output;
            cpu_output
                .close_enough(&metal_output.to_host()?.into_tensor(), Approximation::Approximate)
                .with_context(|| {
                    format!(
                        "A: {:?}, scale: {:?} Cpu: {:?}, Metal: {:?}",
                        a.to_host().and_then(|it| it.dump(true)),
                        scale,
                        cpu_output.dump(true),
                        metal_output.to_host().and_then(|it| it.dump(true))
                    )
                })?;
            Ok(())
        })
    }

    #[test]
    fn test_reduce_mean_of_squares() -> TractResult<()> {
        test_case::<f32>(Reducer::MeanOfSquares, TractReducer::MeanOfSquares, &[4, 4], 1, 1.0)?;
        test_case::<f16>(
            Reducer::MeanOfSquares,
            TractReducer::MeanOfSquares,
            &[4, 4],
            1,
            1.0 / 100.0,
        )?;
        test_case::<f16>(
            Reducer::MeanOfSquares,
            TractReducer::MeanOfSquares,
            &[1, 10],
            0,
            1.0 / 100.0,
        )?;
        test_case::<f32>(
            Reducer::MeanOfSquares,
            TractReducer::MeanOfSquares,
            &[1, 10],
            0,
            1.0 / 100.0,
        )?;
        test_case::<f16>(
            Reducer::MeanOfSquares,
            TractReducer::MeanOfSquares,
            &[2, 1],
            1,
            1.0 / 100.0,
        )?;
        test_case::<f32>(
            Reducer::MeanOfSquares,
            TractReducer::MeanOfSquares,
            &[2, 1],
            1,
            1.0 / 100.0,
        )?;
        test_case::<f16>(
            Reducer::MeanOfSquares,
            TractReducer::MeanOfSquares,
            &[2, 2, 82, 38],
            1,
            1.0 / 100.0,
        )?;
        test_case::<f16>(
            Reducer::MeanOfSquares,
            TractReducer::MeanOfSquares,
            &[2, 2, 82, 38],
            2,
            1.0 / 100.0,
        )?;
        test_case::<f32>(
            Reducer::MeanOfSquares,
            TractReducer::MeanOfSquares,
            &[2, 2, 82, 38],
            1,
            1.0 / 100.0,
        )?;
        test_case::<f32>(
            Reducer::MeanOfSquares,
            TractReducer::MeanOfSquares,
            &[2, 2, 82, 38],
            2,
            1.0 / 100.0,
        )?;
        Ok(())
    }

    #[test]
    fn test_reduce_sum() -> TractResult<()> {
        test_case::<f32>(Reducer::Sum, TractReducer::Sum, &[4, 4], 1, 1.0)?;
        test_case::<f16>(Reducer::Sum, TractReducer::Sum, &[4, 4], 1, 1.0 / 100.0)?;
        test_case::<f16>(Reducer::Sum, TractReducer::Sum, &[1, 10], 0, 1.0 / 100.0)?;
        test_case::<f32>(Reducer::Sum, TractReducer::Sum, &[1, 10], 0, 1.0 / 100.0)?;
        test_case::<f16>(Reducer::Sum, TractReducer::Sum, &[2, 1], 1, 1.0 / 100.0)?;
        test_case::<f32>(Reducer::Sum, TractReducer::Sum, &[2, 1], 1, 1.0 / 100.0)?;
        test_case::<f16>(Reducer::Sum, TractReducer::Sum, &[2, 2, 82, 38], 1, 1.0 / 100.0)?;
        test_case::<f16>(Reducer::Sum, TractReducer::Sum, &[2, 2, 82, 38], 2, 1.0 / 100.0)?;
        test_case::<f32>(Reducer::Sum, TractReducer::Sum, &[2, 2, 82, 38], 1, 1.0 / 100.0)?;
        test_case::<f32>(Reducer::Sum, TractReducer::Sum, &[2, 2, 82, 38], 2, 1.0 / 100.0)?;
        Ok(())
    }

    #[test]
    fn test_reduce_prod() -> TractResult<()> {
        // Prod uses smaller scales on long axes to avoid overflowing the
        // product (especially for f16).
        test_case::<f32>(Reducer::Prod, TractReducer::Prod, &[4, 4], 1, 1.0)?;
        test_case::<f16>(Reducer::Prod, TractReducer::Prod, &[4, 4], 1, 1.0 / 100.0)?;
        test_case::<f16>(Reducer::Prod, TractReducer::Prod, &[1, 10], 0, 1.0 / 100.0)?;
        test_case::<f32>(Reducer::Prod, TractReducer::Prod, &[1, 10], 0, 1.0 / 100.0)?;
        test_case::<f16>(Reducer::Prod, TractReducer::Prod, &[2, 1], 1, 1.0 / 100.0)?;
        test_case::<f32>(Reducer::Prod, TractReducer::Prod, &[2, 1], 1, 1.0 / 100.0)?;
        test_case::<f16>(Reducer::Prod, TractReducer::Prod, &[2, 2, 82, 38], 1, 1.0 / 100.0)?;
        test_case::<f16>(Reducer::Prod, TractReducer::Prod, &[2, 2, 82, 38], 2, 1.0 / 100000.0)?;
        test_case::<f32>(Reducer::Prod, TractReducer::Prod, &[2, 2, 82, 38], 1, 1.0 / 100.0)?;
        test_case::<f32>(Reducer::Prod, TractReducer::Prod, &[2, 2, 82, 38], 2, 1.0 / 1000.0)?;
        Ok(())
    }

    #[test]
    fn test_reduce_max() -> TractResult<()> {
        // Negative scales exercise Max over all-negative inputs.
        test_case::<f32>(Reducer::Max, TractReducer::Max, &[4, 4], 1, 1.0)?;
        test_case::<f16>(Reducer::Max, TractReducer::Max, &[4, 4], 1, 1.0 / 100.0)?;
        test_case::<f16>(Reducer::Max, TractReducer::Max, &[1, 10], 0, -1.0 / 100.0)?;
        test_case::<f32>(Reducer::Max, TractReducer::Max, &[1, 10], 0, 1.0 / 100.0)?;
        test_case::<f16>(Reducer::Max, TractReducer::Max, &[2, 1], 1, -1.0 / 100.0)?;
        test_case::<f32>(Reducer::Max, TractReducer::Max, &[2, 1], 1, 1.0 / 100.0)?;
        test_case::<f16>(Reducer::Max, TractReducer::Max, &[2, 2, 82, 38], 1, -1.0 / 100.0)?;
        test_case::<f16>(Reducer::Max, TractReducer::Max, &[2, 2, 82, 38], 2, 1.0 / 100.0)?;
        test_case::<f32>(Reducer::Max, TractReducer::Max, &[2, 2, 82, 38], 1, 1.0 / 100.0)?;
        test_case::<f32>(Reducer::Max, TractReducer::Max, &[2, 2, 82, 38], 2, -1.0 / 100.0)?;
        Ok(())
    }

    #[test]
    fn test_reduce_min() -> TractResult<()> {
        // Negative scales exercise Min over all-negative inputs.
        test_case::<f32>(Reducer::Min, TractReducer::Min, &[4, 4], 1, 1.0)?;
        test_case::<f16>(Reducer::Min, TractReducer::Min, &[4, 4], 1, 1.0 / 100.0)?;
        test_case::<f16>(Reducer::Min, TractReducer::Min, &[1, 10], 0, -1.0 / 100.0)?;
        test_case::<f32>(Reducer::Min, TractReducer::Min, &[1, 10], 0, 1.0 / 100.0)?;
        test_case::<f16>(Reducer::Min, TractReducer::Min, &[2, 1], 1, 1.0 / 100.0)?;
        test_case::<f32>(Reducer::Min, TractReducer::Min, &[2, 1], 1, 1.0 / 100.0)?;
        test_case::<f16>(Reducer::Min, TractReducer::Min, &[2, 2, 82, 38], 1, -1.0 / 100.0)?;
        test_case::<f16>(Reducer::Min, TractReducer::Min, &[2, 2, 82, 38], 2, 1.0 / 100.0)?;
        test_case::<f32>(Reducer::Min, TractReducer::Min, &[2, 2, 82, 38], 1, -1.0 / 100.0)?;
        test_case::<f32>(Reducer::Min, TractReducer::Min, &[2, 2, 82, 38], 2, 1.0 / 100.0)?;
        Ok(())
    }

    // Property tests: random (reducer, shape, axis) problems, Metal result
    // must approximately match the CPU reference.
    proptest::proptest! {
        #[test]
        fn reduce_prop_f32(pb in any::<ReduceProblem<f32>>()) {
            fn run(pb: ReduceProblem<f32>) -> TractResult<()> {
                let out = pb.run()?;
                let reference = pb.reference()?;
                out.close_enough(&reference, Approximation::Approximate)
                    .with_context(|| format!("Cpu: {:?}, Metal: {:?}", reference.dump(true), out.dump(true)))
            }
            // Map TractResult failures into proptest failures.
            run(pb).map_err(|e| TestCaseError::Fail(format!("{:?}", e).into()))?;
        }

        #[test]
        fn reduce_prop_f16(pb in any::<ReduceProblem<f16>>()) {
            fn run(pb: ReduceProblem<f16>) -> TractResult<()> {
                let out = pb.run()?;
                let reference = pb.reference()?;
                out.close_enough(&reference, Approximation::Approximate)
                    .with_context(|| format!("Cpu: {:?}, Metal: {:?}", reference.dump(true), out.dump(true)))
            }
            // Map TractResult failures into proptest failures.
            run(pb).map_err(|e| TestCaseError::Fail(format!("{:?}", e).into()))?;
        }
    }

    /// One randomly-generated reduce problem: a reducer, an input shape,
    /// the reduced axis, and the flat input data.
    #[derive(Debug, new)]
    pub struct ReduceProblem<F: Datum + Float>
    where
        F: Datum + Float,
        usize: AsPrimitive<F>,
    {
        pub op: Reducer,       // reducer under test (non-logic reducers only)
        pub shape: Vec<usize>, // input shape, rank 1..=4, dims in 1..10
        pub axis: usize,       // reduced axis, always < shape.len()
        pub input: Vec<F>,     // flat row-major input, value = index / 1000
    }

    impl<F> Arbitrary for ReduceProblem<F>
    where
        F: Datum + Float,
        usize: AsPrimitive<F>,
    {
        type Parameters = ();
        type Strategy = BoxedStrategy<Self>;
        fn arbitrary_with(_: ()) -> Self::Strategy {
            // Exclude logic reducers (e.g. Any/All-style) from the pool.
            let reducers = Reducer::ALL.into_iter().filter(|r| !r.is_logic()).collect_vec();
            (0..reducers.len(), 0usize..3, 0usize..3)
                .prop_flat_map(move |(op_idx, left, right)| {
                    // `left` dims precede the reduced axis, `right` follow it,
                    // so `axis = left` is always valid for the chosen rank.
                    let axis = left;
                    let shape_len = usize::min(left + right + 1, 4);
                    let shape = 1usize..10;
                    let op = reducers[op_idx];
                    (Just(op), vec(shape, shape_len..=shape_len), Just(axis))
                })
                .prop_map(|(op, shape, axis)| {
                    // Small ramp values (index / 1000) keep f16 well-behaved.
                    let input = (0..shape.iter().product::<usize>())
                        .map(|f| f.as_() / 1000.as_())
                        .collect::<Vec<_>>();
                    Self { op, shape, axis, input }
                })
                .boxed()
        }
    }

    impl<F> ReduceProblem<F>
    where
        F: Datum + Float + std::ops::AddAssign,
        usize: AsPrimitive<F>,
    {
        /// Computes the CPU reference output with tract-core's reducer.
        pub fn reference(&self) -> TractResult<Tensor> {
            let a = Tensor::from_shape(self.shape.as_slice(), &self.input)?;
            let cpu_output = match self.op {
                Reducer::Sum => TractReducer::Sum.reduce(&[self.axis], &a)?,
                Reducer::Prod => TractReducer::Prod.reduce(&[self.axis], &a)?,
                Reducer::MeanOfSquares => TractReducer::MeanOfSquares.reduce(&[self.axis], &a)?,
                Reducer::Min => TractReducer::Min.reduce(&[self.axis], &a)?,
                Reducer::Max => TractReducer::Max.reduce(&[self.axis], &a)?,
                // Logic reducers are filtered out by `arbitrary_with`.
                _ => unreachable!(),
            };
            Ok(cpu_output)
        }

        /// Runs the problem on the Metal backend and returns the host tensor.
        pub fn run(&self) -> TractResult<Tensor> {
            with_borrowed_metal_stream(|stream| {
                let a = Tensor::from_shape(self.shape.as_slice(), &self.input)?.into_device()?;
                let mut o_shape = a.shape().to_vec();
                o_shape[self.axis] = 1;
                // Left uninitialized: the kernel writes every output element —
                // TODO confirm full coverage for all shapes.
                let output = unsafe { DeviceTensor::uninitialized_dt(a.datum_type(), &o_shape)? };
                metal_reduce_launch(&self.op, &a, self.axis, &output)?;
                stream.wait_until_completed()?;
                Ok(output.to_host()?.into_tensor())
            })
        }
    }
}