use std::cmp::Ordering;
use runmat_accelerate_api::{AccelProvider, GpuTensorHandle};
use runmat_builtins::{Tensor, Type, Value};
use runmat_macros::runtime_builtin;
use crate::{build_runtime_error, BuiltinResult, RuntimeError};
/// Builtin name used to tag every runtime error raised from this module.
const NAME: &str = "median";
use runmat_builtins::ResolveContext;
/// Type resolver for `median`: delegates to the shared numeric-reduction
/// resolver, which collapses the reduced dimension in the inferred shape
/// (see `median_type_reduces_first_dim` in the tests below).
fn median_type(args: &[Type], ctx: &ResolveContext) -> Type {
    reduce_numeric_type(args, ctx)
}
use crate::builtins::common::arg_tokens::tokens_from_values;
use crate::builtins::common::random_args::keyword_of;
use crate::builtins::common::spec::{
BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
};
use crate::builtins::common::{gpu_helpers, tensor};
use crate::builtins::math::reduction::type_resolvers::reduce_numeric_type;
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::math::reduction::median")]
// Device-execution contract for `median`. Providers may implement either the
// per-dimension hook, the whole-tensor hook, or both; any missing hook makes
// the runtime gather to the host (see `median_gpu` below).
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "median",
    op_kind: GpuOpKind::Reduction,
    supported_precisions: &[ScalarType::F32, ScalarType::F64],
    broadcast: BroadcastSemantics::Matlab,
    provider_hooks: &[
        // Reduce along one explicit dimension (0-based on the provider side).
        ProviderHook::Reduction {
            name: "reduce_median_dim",
        },
        // Reduce the entire tensor to a single value ('all' semantics).
        ProviderHook::Reduction {
            name: "reduce_median",
        },
    ],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::NewHandle,
    // Device execution is only attempted for includenan semantics; 'omitnan'
    // always falls back to the host path (accepts_nan_mode is false).
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes:
        "Providers may execute medians entirely on device; runtimes fall back to host when hooks are missing or omitnan is requested.",
};
/// Wrap `message` in a `RuntimeError` attributed to the `median` builtin.
fn median_error(message: impl Into<String>) -> RuntimeError {
    let builder = build_runtime_error(message);
    builder.with_builtin(NAME).build()
}
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::math::reduction::median")]
// Fusion-planner contract: `median` currently exposes neither an elementwise
// nor a reduction kernel, so fused pipelines gather to the host around it.
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "median",
    shape: ShapeRequirements::BroadcastCompatible,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    // With includenan semantics a NaN input propagates to the output.
    emits_nan: true,
    notes:
        "Fusion planner gathers to the host; future kernels may expose order-statistic reductions.",
};
/// Which dimensions to reduce, as parsed from the trailing arguments.
#[derive(Clone)]
enum MedianAxes {
    /// No explicit dimension: reduce the first non-singleton dimension.
    Default,
    /// A single 1-based dimension.
    Dim(usize),
    /// Multiple 1-based dimensions (MATLAB "vecdim" form).
    Vec(Vec<usize>),
    /// The 'all' keyword: collapse every dimension to a scalar.
    All,
}
/// Options extracted from the trailing `median(...)` arguments.
struct ParsedArguments {
    /// Dimension selection (default / dim / vecdim / all).
    axes: MedianAxes,
    /// NaN handling requested via 'omitnan' / 'includenan'.
    nan_mode: ReductionNaN,
}
#[runtime_builtin(
    name = "median",
    category = "math/reduction",
    summary = "Median of scalars, vectors, matrices, or N-D tensors.",
    keywords = "median,reduction,omitnan,includenan,statistics,gpu",
    accel = "reduction",
    type_resolver(median_type),
    builtin_path = "crate::builtins::math::reduction::median"
)]
/// Entry point for the `median` builtin: parses the trailing options once,
/// then dispatches GPU-resident tensors to the device path and every other
/// value to the host path.
async fn median_builtin(value: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
    let options = parse_arguments(&rest).await?;
    match value {
        Value::GpuTensor(handle) => median_gpu(handle, &options).await,
        host_value => median_host(host_value, &options),
    }
}
/// Parse the trailing `median` arguments: an optional dimension specification
/// (scalar dim, vecdim tensor, or 'all') plus optional NaN-handling flags
/// ('omitnan' / 'includenan'), accepted in any order.
///
/// Keywords are recognised through two channels — the pre-tokenised string
/// arguments and `keyword_of` — before any value is considered as a dimension.
/// A second dimension specification is rejected; a later NaN flag overrides an
/// earlier one.
async fn parse_arguments(args: &[Value]) -> BuiltinResult<ParsedArguments> {
    let mut axes = MedianAxes::Default;
    // True once an explicit dimension spec (dim / vecdim / 'all') was consumed.
    let mut axes_set = false;
    let mut nan_mode = ReductionNaN::Include;
    let tokens = tokens_from_values(args);
    let mut idx = 0;
    while idx < args.len() {
        let arg = &args[idx];
        // Channel 1: pre-tokenised string arguments (char rows, string scalars).
        if let Some(crate::builtins::common::arg_tokens::ArgToken::String(text)) = tokens.get(idx) {
            match text.as_str() {
                "omitnan" => {
                    nan_mode = ReductionNaN::Omit;
                    idx += 1;
                    continue;
                }
                "includenan" => {
                    nan_mode = ReductionNaN::Include;
                    idx += 1;
                    continue;
                }
                "all" => {
                    // 'all' may not follow an explicit dim/vecdim selection.
                    if axes_set && !matches!(axes, MedianAxes::Default) {
                        return Err(median_error(
                            "median: 'all' cannot be combined with an explicit dimension",
                        ));
                    }
                    axes = MedianAxes::All;
                    axes_set = true;
                    idx += 1;
                    continue;
                }
                // Unrecognised strings fall through to the checks below.
                _ => {}
            }
        }
        // Channel 2: keyword extraction for values `tokens_from_values` missed.
        if let Some(keyword) = keyword_of(arg) {
            match keyword.as_str() {
                "omitnan" => {
                    nan_mode = ReductionNaN::Omit;
                    idx += 1;
                    continue;
                }
                "includenan" => {
                    nan_mode = ReductionNaN::Include;
                    idx += 1;
                    continue;
                }
                "all" => {
                    if axes_set && !matches!(axes, MedianAxes::Default) {
                        return Err(median_error(
                            "median: 'all' cannot be combined with an explicit dimension",
                        ));
                    }
                    axes = MedianAxes::All;
                    axes_set = true;
                    idx += 1;
                    continue;
                }
                "" => {
                    return Err(median_error(
                        "median: keyword arguments must not be empty strings",
                    ));
                }
                _ => {
                    // Any other keyword-like value is an error; prefer echoing
                    // the original string form when one is available.
                    if let Some(original) = value_as_str(arg) {
                        return Err(median_error(format!(
                            "median: unrecognised argument '{original}'"
                        )));
                    } else {
                        return Err(median_error(format!(
                            "median: unrecognised argument {arg:?}"
                        )));
                    }
                }
            }
        }
        // Not a keyword: try to interpret the value as a dimension spec.
        if !axes_set || matches!(axes, MedianAxes::Default) {
            if let Some(selection) = parse_axes(arg).await? {
                if matches!(selection, MedianAxes::All) {
                    if axes_set && !matches!(axes, MedianAxes::Default) {
                        return Err(median_error(
                            "median: 'all' cannot be combined with an explicit dimension",
                        ));
                    }
                    axes = MedianAxes::All;
                } else {
                    axes = selection;
                }
                axes_set = true;
                idx += 1;
                continue;
            }
        } else if parse_axes(arg).await?.is_some() {
            // A dimension spec was already consumed earlier.
            return Err(median_error(
                "median: multiple dimension specifications provided",
            ));
        }
        return Err(median_error(format!(
            "median: unrecognised argument {arg:?}"
        )));
    }
    Ok(ParsedArguments { axes, nan_mode })
}
/// Host path: coerce `value` into a dense tensor, reduce it, and convert the
/// result back into a runtime `Value` (scalar for 1x1 outputs).
fn median_host(value: Value, args: &ParsedArguments) -> BuiltinResult<Value> {
    let host_tensor = tensor::value_into_tensor_for("median", value).map_err(median_error)?;
    median_tensor(host_tensor, args.axes.clone(), args.nan_mode).map(tensor::tensor_into_value)
}
/// GPU path: attempt a fully on-device median first (only valid for
/// includenan semantics, per `GPU_SPEC`); otherwise gather the tensor to the
/// host and reduce there.
async fn median_gpu(handle: GpuTensorHandle, args: &ParsedArguments) -> BuiltinResult<Value> {
    let device_eligible = args.nan_mode == ReductionNaN::Include;
    if device_eligible {
        if let Some(provider) = runmat_accelerate_api::provider() {
            if let Some(on_device) = median_gpu_try(provider, &handle, &args.axes).await {
                return Ok(Value::GpuTensor(on_device));
            }
        }
    }
    // Host fallback: download, reduce, and rewrap.
    let host_tensor = gpu_helpers::gather_tensor_async(&handle).await?;
    let reduced = median_tensor(host_tensor, args.axes.clone(), args.nan_mode)?;
    Ok(tensor::tensor_into_value(reduced))
}
/// Try to compute the median entirely on device. Returns `None` whenever the
/// provider cannot service the request, which makes the caller fall back to
/// the host implementation.
async fn median_gpu_try(
    provider: &dyn AccelProvider,
    handle: &GpuTensorHandle,
    axes: &MedianAxes,
) -> Option<GpuTensorHandle> {
    match axes {
        MedianAxes::Default => {
            // Scalar-shaped handles are already their own median.
            if handle.shape.is_empty() {
                return Some(handle.clone());
            }
            let dim = default_dimension_from_shape(&handle.shape);
            reduce_median_dim_gpu(provider, handle.clone(), dim).await
        }
        MedianAxes::Dim(dim) => reduce_median_dim_gpu(provider, handle.clone(), *dim).await,
        MedianAxes::Vec(dims) => {
            // Reduce each distinct dimension once, in ascending order; abort
            // the whole device attempt if any single step fails.
            let mut ordered = dims.clone();
            ordered.sort_unstable();
            ordered.dedup();
            let mut current = handle.clone();
            for dim in ordered {
                current = reduce_median_dim_gpu(provider, current, dim).await?;
            }
            Some(current)
        }
        MedianAxes::All => {
            if handle.shape.is_empty() {
                return Some(handle.clone());
            }
            match provider.reduce_median(handle).await {
                Ok(reduced) => Some(reduced),
                Err(err) => {
                    log::trace!("median: provider reduce_median fallback triggered: {err}");
                    None
                }
            }
        }
    }
}
/// Run the provider's per-dimension median hook. `dim` is 1-based (MATLAB
/// convention); the provider hook takes a 0-based dimension. A dimension past
/// the handle's rank is a no-op, and `dim == 0` is invalid (`None`).
async fn reduce_median_dim_gpu(
    provider: &dyn AccelProvider,
    handle: GpuTensorHandle,
    dim: usize,
) -> Option<GpuTensorHandle> {
    if dim == 0 {
        return None;
    }
    if dim > handle.shape.len() {
        // Reducing beyond the rank leaves the tensor unchanged.
        return Some(handle);
    }
    match provider.reduce_median_dim(&handle, dim - 1).await {
        Ok(reduced) => Some(reduced),
        Err(err) => {
            log::trace!("median: provider reduce_median_dim fallback triggered: {err}");
            None
        }
    }
}
/// Reduce `tensor` with the median along the requested axes on the host.
///
/// `Default` reduces the first non-singleton dimension; `Dim`/`Vec` reduce
/// explicit 1-based dimensions; `All` collapses every dimension. An empty
/// `Vec` selection (`median(X, [])`) falls back to the default dimension.
///
/// Bug fix: the original text contained `¤t` where `&current` was intended —
/// the `&curren` prefix of `&current` had been decoded as the HTML entity for
/// the currency sign (U+00A4), leaving the function uncompilable. The four
/// affected call sites below now read `&current` again.
fn median_tensor(
    tensor: Tensor,
    axes: MedianAxes,
    nan_mode: ReductionNaN,
) -> BuiltinResult<Tensor> {
    match axes {
        MedianAxes::Default => {
            let dim = default_dimension(&tensor);
            reduce_tensor_median_dim(&tensor, dim, nan_mode)
        }
        MedianAxes::Dim(dim) => reduce_tensor_median_dim(&tensor, dim, nan_mode),
        MedianAxes::Vec(mut dims) => {
            let mut current = tensor;
            // Reduce each distinct dimension once, in ascending order.
            dims.sort_unstable();
            dims.dedup();
            if dims.is_empty() {
                // `median(X, [])` behaves like the default reduction.
                let dim = default_dimension(&current);
                return reduce_tensor_median_dim(&current, dim, nan_mode);
            }
            for dim in dims {
                current = reduce_tensor_median_dim(&current, dim, nan_mode)?;
            }
            Ok(current)
        }
        MedianAxes::All => {
            if tensor.shape.is_empty() {
                Ok(tensor)
            } else {
                // Collapse every dimension in turn; each pass leaves extent 1.
                let mut current = tensor;
                let rank = current.shape.len();
                for dim in 1..=rank {
                    current = reduce_tensor_median_dim(&current, dim, nan_mode)?;
                }
                Ok(current)
            }
        }
    }
}
/// Try to interpret `value` as a dimension specification.
///
/// Returns `Ok(Some(_))` for a recognised spec ('all', scalar dim, vecdim
/// vector, or an empty numeric array meaning "default"), `Ok(None)` when the
/// value is not a dimension spec at all (so the caller can report it), and
/// `Err` for malformed specs (non-positive, non-integer, empty string, ...).
async fn parse_axes(value: &Value) -> BuiltinResult<Option<MedianAxes>> {
    // String forms: only 'all' selects axes here; NaN-mode keywords are
    // handled by the caller, so they yield `None` rather than an error.
    if let Some(text) = value_as_str(value) {
        let trimmed = text.trim();
        if trimmed.is_empty() {
            return Err(median_error("median: dimension string must not be empty"));
        }
        let lowered = trimmed.to_ascii_lowercase();
        return match lowered.as_str() {
            "all" => Ok(Some(MedianAxes::All)),
            "omitnan" | "includenan" => Ok(None),
            _ => Err(median_error(format!(
                "median: unrecognised argument '{trimmed}'"
            ))),
        };
    }
    // Classify the numeric candidate: is it scalar-like (affects error
    // wording) and is it empty (empty dim spec means "default axes")?
    let (scalar_hint, is_empty) = match value {
        Value::Num(_) | Value::Int(_) => (true, false),
        Value::Tensor(t) => (t.data.len() == 1, t.data.is_empty()),
        Value::LogicalArray(logical) => (logical.data.len() == 1, logical.data.is_empty()),
        Value::GpuTensor(handle) => {
            let count = tensor::element_count(&handle.shape);
            (handle.shape.is_empty() || count == 1, count == 0)
        }
        _ => (false, false),
    };
    if is_empty {
        return Ok(Some(MedianAxes::Default));
    }
    // Extract integer dimensions; `map_dims_error` rewords validation
    // failures using the scalar/vector hint computed above.
    let dims = match value {
        Value::Tensor(_)
        | Value::LogicalArray(_)
        | Value::Int(_)
        | Value::Num(_)
        | Value::GpuTensor(_) => tensor::dims_from_value_async(value)
            .await
            .map_err(|err| map_dims_error(err, scalar_hint))?,
        Value::Bool(_) => {
            return Err(median_error("median: dimension must be numeric"));
        }
        // Any other value kind is not a dimension spec.
        _ => return Ok(None),
    };
    let Some(dims) = dims else {
        return Ok(None);
    };
    if dims.is_empty() {
        return Ok(Some(MedianAxes::Default));
    }
    if dims.len() == 1 {
        let dim = dims[0];
        if dim < 1 {
            return Err(median_error("median: dimension must be >= 1"));
        }
        return Ok(Some(MedianAxes::Dim(dim)));
    }
    // Vecdim form: every entry must be a valid 1-based dimension.
    for &dim in &dims {
        if dim < 1 {
            return Err(median_error("median: dimension entries must be >= 1"));
        }
    }
    Ok(Some(MedianAxes::Vec(dims)))
}
/// Reword a dimension-extraction error from `dims_from_value_async` into the
/// `median`-specific phrasing, choosing singular/plural wording based on
/// whether the offending value was scalar-like. Unrecognised messages pass
/// through unchanged. Checks run in the same precedence as before:
/// non-negative, then finite, then integer.
fn map_dims_error(message: String, scalar: bool) -> RuntimeError {
    let reworded = if message.contains("non-negative") {
        if scalar {
            "median: dimension must be >= 1"
        } else {
            "median: dimension entries must be >= 1"
        }
    } else if message.contains("finite") {
        if scalar {
            "median: dimension must be finite"
        } else {
            "median: dimension entries must be finite integers"
        }
    } else if message.contains("integer") {
        if scalar {
            "median: dimension must be an integer"
        } else {
            "median: dimension entries must be integers"
        }
    } else {
        // No known marker: surface the original message verbatim.
        return median_error(message);
    };
    median_error(reworded)
}
/// Extract a single string from `value`: string scalars, one-element string
/// arrays, and single-row char arrays qualify; everything else yields `None`.
fn value_as_str(value: &Value) -> Option<String> {
    match value {
        Value::String(text) => Some(text.clone()),
        Value::StringArray(strings) if strings.data.len() == 1 => Some(strings.data[0].clone()),
        Value::CharArray(chars) if chars.rows == 1 => Some(chars.data.iter().collect()),
        _ => None,
    }
}
/// Median-reduce `tensor` along one 1-based dimension.
///
/// Edge cases: `dim == 0` errors; a scalar-shaped tensor returns a 1x1 result;
/// `dim` past the rank is a no-op clone; a zero-length reduction dimension (or
/// empty data) fills the output with NaN; a length-1 reduction dimension is an
/// identity copy. With includenan semantics any NaN in a lane yields NaN for
/// that lane; with omitnan, NaNs are skipped and an all-NaN lane yields NaN.
fn reduce_tensor_median_dim(
    tensor: &Tensor,
    dim: usize,
    nan_mode: ReductionNaN,
) -> BuiltinResult<Tensor> {
    if dim == 0 {
        return Err(median_error("median: dimension must be >= 1"));
    }
    if tensor.shape.is_empty() {
        // Scalar-shaped input: promote to an explicit 1x1 tensor.
        let value = tensor.data.first().copied().unwrap_or(f64::NAN);
        return Tensor::new(vec![value], vec![1, 1])
            .map_err(|e| median_error(format!("median: {e}")));
    }
    if dim > tensor.shape.len() {
        // Reducing beyond the rank leaves the tensor unchanged.
        return Ok(tensor.clone());
    }
    let dim_index = dim - 1;
    let reduce_len = tensor.shape[dim_index];
    let Some(output_shape) = reduction_shape(&tensor.shape, dim) else {
        return Ok(tensor.clone());
    };
    if reduce_len == 0 || tensor.data.is_empty() {
        // Nothing to reduce over: every output lane is NaN (MATLAB behaviour).
        let fill = vec![f64::NAN; tensor::element_count(&output_shape)];
        return Tensor::new(fill, output_shape).map_err(|e| median_error(format!("median: {e}")));
    }
    if reduce_len == 1 {
        // Length-1 reduction is an identity copy.
        return Tensor::new(tensor.data.clone(), tensor.shape.clone())
            .map_err(|e| median_error(format!("median: {e}")));
    }
    // Column-major layout: `stride_before` is the product of extents before
    // the reduced dimension, `stride_after` the product of extents after it.
    let stride_before = dim_product(&tensor.shape[..dim_index]);
    let stride_after = dim_product(&tensor.shape[dim..]);
    let mut output = vec![0.0f64; tensor::element_count(&output_shape)];
    for after in 0..stride_after {
        for before in 0..stride_before {
            // Gather one lane along the reduced dimension.
            let mut slice = Vec::with_capacity(reduce_len);
            let mut saw_nan = false;
            for k in 0..reduce_len {
                let idx = before + k * stride_before + after * stride_before * reduce_len;
                let value = tensor.data[idx];
                match nan_mode {
                    ReductionNaN::Include => {
                        // First NaN poisons the lane; stop gathering early.
                        if value.is_nan() {
                            saw_nan = true;
                            break;
                        }
                        slice.push(value);
                    }
                    ReductionNaN::Omit => {
                        if value.is_nan() {
                            continue;
                        }
                        slice.push(value);
                    }
                }
            }
            let out_idx = after * stride_before + before;
            if saw_nan {
                output[out_idx] = f64::NAN;
                continue;
            }
            if slice.is_empty() {
                // omitnan with an all-NaN lane still produces NaN.
                output[out_idx] = f64::NAN;
                continue;
            }
            let median = compute_median_inplace(&mut slice);
            output[out_idx] = median;
        }
    }
    Tensor::new(output, output_shape).map_err(|e| median_error(format!("median: {e}")))
}
/// Sort `values` in place and return the median: the middle element for an
/// odd count, or the mean of the two middle elements for an even count.
/// NaN ordering is delegated to `partial_cmp_f64`; callers in this module
/// filter NaNs out before calling. Panics on an empty slice, as before.
pub fn compute_median_inplace(values: &mut [f64]) -> f64 {
    values.sort_by(|x, y| partial_cmp_f64(*x, *y));
    let count = values.len();
    let mid = count / 2;
    if count % 2 == 0 {
        let lower = values[mid - 1];
        let upper = values[mid];
        0.5 * (lower + upper)
    } else {
        values[mid]
    }
}
/// Total order over `f64` for sorting: falls back to `Less` when the pair is
/// unordered (i.e. either operand is NaN), matching the previous behaviour.
fn partial_cmp_f64(a: f64, b: f64) -> Ordering {
    match a.partial_cmp(&b) {
        Some(order) => order,
        None => Ordering::Less,
    }
}
/// Output shape of reducing `shape` along 1-based `dim`: the reduced extent
/// becomes 1. Returns `None` for `dim == 0` or a dimension past the rank;
/// an empty shape reduces to `[1, 1]`.
fn reduction_shape(shape: &[usize], dim: usize) -> Option<Vec<usize>> {
    match dim {
        0 => None,
        _ if shape.is_empty() => Some(vec![1, 1]),
        d if d > shape.len() => None,
        d => {
            let mut reduced = shape.to_vec();
            reduced[d - 1] = 1;
            Some(reduced)
        }
    }
}
/// Product of extents, saturating at `usize::MAX` instead of overflowing.
/// An empty slice yields 1 (the empty product).
fn dim_product(dims: &[usize]) -> usize {
    let mut product = 1usize;
    for &extent in dims {
        product = product.saturating_mul(extent);
    }
    product
}
/// Default reduction dimension for `tensor`: its first non-singleton
/// dimension (1-based), or 1 when every extent is 1 or the shape is empty.
fn default_dimension(tensor: &Tensor) -> usize {
    default_dimension_from_shape(&tensor.shape)
}
/// 1-based index of the first dimension whose extent is not 1; defaults to 1
/// for empty or all-singleton shapes (MATLAB's default reduction dimension).
fn default_dimension_from_shape(shape: &[usize]) -> usize {
    for (index, &extent) in shape.iter().enumerate() {
        if extent != 1 {
            return index + 1;
        }
    }
    1
}
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use futures::executor::block_on;
    use runmat_builtins::IntValue;

    // Type resolver collapses the default (first) dimension in the shape.
    #[test]
    fn median_type_reduces_first_dim() {
        let out = median_type(
            &[Type::Tensor {
                shape: Some(vec![Some(2), Some(5)]),
            }],
            &ResolveContext::new(Vec::new()),
        );
        assert_eq!(
            out,
            Type::Tensor {
                shape: Some(vec![Some(1), Some(5)])
            }
        );
    }

    // Blocking convenience wrapper around the async builtin; shadows the
    // outer name inside this module only.
    fn median_builtin(value: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
        block_on(super::median_builtin(value, rest))
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn median_scalar_num() {
        let result = median_builtin(Value::Num(5.0), Vec::new()).expect("median");
        assert_eq!(result, Value::Num(5.0));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn median_vector_odd_length() {
        let tensor = Tensor::new(vec![7.0, 2.0, 9.0, 4.0, 5.0], vec![5, 1]).unwrap();
        let result = median_builtin(Value::Tensor(tensor), Vec::new()).expect("median");
        assert_eq!(result, Value::Num(5.0));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn median_vector_even_length() {
        // Even length averages the two middle elements: (4 + 9) / 2 = 6.5.
        let tensor = Tensor::new(vec![1.0, 4.0, 9.0, 10.0], vec![4, 1]).unwrap();
        let result = median_builtin(Value::Tensor(tensor), Vec::new()).expect("median");
        assert_eq!(result, Value::Num(6.5));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn median_matrix_default_dimension() {
        // Default reduces dim 1 (columns of the 3x2 matrix).
        let tensor = Tensor::new(vec![1.0, 7.0, 2.0, 9.0, 5.0, 11.0], vec![3, 2]).expect("tensor");
        let result = median_builtin(Value::Tensor(tensor), Vec::new()).expect("median");
        match result {
            Value::Tensor(out) => {
                assert_eq!(out.shape, vec![1, 2]);
                assert_eq!(out.data, vec![2.0, 9.0]);
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn median_matrix_dimension_two() {
        let tensor = Tensor::new(vec![1.0, 3.0, 5.0, 7.0, 9.0, 11.0], vec![3, 2]).expect("tensor");
        let result = median_builtin(Value::Tensor(tensor), vec![Value::Int(IntValue::I32(2))])
            .expect("median");
        match result {
            Value::Tensor(out) => {
                assert_eq!(out.shape, vec![3, 1]);
                assert_eq!(out.data, vec![4.0, 6.0, 8.0]);
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn median_all_across_matrix() {
        // 'all' collapses to a scalar: median of 1..=6 is 3.5.
        let tensor = Tensor::new(vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0], vec![3, 2]).unwrap();
        let result =
            median_builtin(Value::Tensor(tensor), vec![Value::from("all")]).expect("median");
        match result {
            Value::Num(v) => assert!((v - 3.5).abs() < 1e-12),
            other => panic!("expected scalar result, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn median_vecdim_multiple_axes() {
        // Vecdim [1, 3] reduces dims 1 and 3 of a 2x2x2 tensor.
        let tensor =
            Tensor::new((1..=8).map(|v| v as f64).collect::<Vec<_>>(), vec![2, 2, 2]).unwrap();
        let dims = Tensor::new(vec![1.0, 3.0], vec![1, 2]).unwrap();
        let result =
            median_builtin(Value::Tensor(tensor), vec![Value::Tensor(dims)]).expect("median");
        match result {
            Value::Tensor(out) => {
                assert_eq!(out.shape, vec![1, 2, 1]);
                assert_eq!(out.data.len(), 2);
                assert!((out.data[0] - 3.5).abs() < 1e-12);
                assert!((out.data[1] - 5.5).abs() < 1e-12);
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn median_with_omit_nan() {
        let tensor = Tensor::new(vec![1.0, f64::NAN, 5.0], vec![3, 1]).unwrap();
        let result =
            median_builtin(Value::Tensor(tensor), vec![Value::from("omitnan")]).expect("median");
        assert_eq!(result, Value::Num(3.0));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn median_with_include_nan_propagates() {
        // Default includenan: any NaN in the lane poisons the result.
        let tensor = Tensor::new(vec![1.0, f64::NAN, 5.0], vec![3, 1]).unwrap();
        let result = median_builtin(Value::Tensor(tensor), Vec::new()).expect("median");
        match result {
            Value::Num(n) => assert!(n.is_nan()),
            other => panic!("expected scalar NaN, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn median_empty_returns_nan() {
        let tensor = Tensor::new(vec![], vec![0, 1]).unwrap();
        let result = median_builtin(Value::Tensor(tensor), Vec::new()).expect("median");
        match result {
            Value::Num(n) => assert!(n.is_nan()),
            other => panic!("expected NaN scalar, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn median_dimension_greater_than_ndims_returns_input() {
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap();
        let original = tensor.clone();
        let result = median_builtin(Value::Tensor(tensor), vec![Value::Int(IntValue::I32(5))])
            .expect("median");
        match result {
            Value::Tensor(out) => assert_eq!(out, original),
            Value::Num(n) => assert_eq!(n, original.data[0]),
            other => panic!("expected tensor result, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn median_rejects_unknown_keyword() {
        let err = median_builtin(Value::Num(1.0), vec![Value::from("like")]).unwrap_err();
        assert!(
            err.message().contains("unrecognised argument"),
            "unexpected error message: {err}"
        );
    }

    // Device round-trip through the test provider (includenan path).
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn median_gpu_provider_roundtrip() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![1.0, 4.0, 9.0, 16.0], vec![4, 1]).unwrap();
            let view = runmat_accelerate_api::HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let result = median_builtin(Value::GpuTensor(handle), Vec::new()).expect("median");
            let gathered = test_support::gather(result).expect("gather");
            assert_eq!(gathered.shape, vec![1, 1]);
            assert_eq!(gathered.data[0], 6.5);
        });
    }

    // 'omitnan' must bypass the device path and reduce on the host.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn median_gpu_omit_nan_falls_back_to_host() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![f64::NAN, 2.0, f64::NAN, 4.0], vec![4, 1]).unwrap();
            let view = runmat_accelerate_api::HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let result = median_builtin(Value::GpuTensor(handle), vec![Value::from("omitnan")])
                .expect("median");
            let gathered = test_support::gather(result).expect("gather");
            assert_eq!(gathered.shape, vec![1, 1]);
            assert_eq!(gathered.data[0], 3.0);
        });
    }

    // Compare a real wgpu provider against the CPU path for dim and 'all'
    // reductions; tolerance is chosen per provider precision.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn median_wgpu_dim_matches_cpu() {
        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );
        let tensor = Tensor::new(vec![1.0, 5.0, 9.0, 2.0, 6.0, 10.0], vec![3, 2]).unwrap();
        let args_dim1 = ParsedArguments {
            axes: MedianAxes::Dim(1),
            nan_mode: ReductionNaN::Include,
        };
        let cpu = median_host(Value::Tensor(tensor.clone()), &args_dim1).expect("cpu median");
        let view = runmat_accelerate_api::HostTensorView {
            data: &tensor.data,
            shape: &tensor.shape,
        };
        let handle = runmat_accelerate_api::provider()
            .unwrap()
            .upload(&view)
            .expect("upload");
        let gpu_value = block_on(median_gpu(handle, &args_dim1)).expect("gpu median");
        let gathered = test_support::gather(gpu_value).expect("gather");
        match (cpu, gathered) {
            (Value::Tensor(ct), gt) => {
                assert_eq!(ct.shape, gt.shape);
                let tol = match runmat_accelerate_api::provider().unwrap().precision() {
                    runmat_accelerate_api::ProviderPrecision::F64 => 1e-12,
                    runmat_accelerate_api::ProviderPrecision::F32 => 5e-5,
                };
                for (a, b) in ct.data.iter().zip(gt.data.iter()) {
                    assert!((a - b).abs() < tol, "|{} - {}| >= {}", a, b, tol);
                }
            }
            _ => panic!("unexpected shapes"),
        }
        let args_all = ParsedArguments {
            axes: MedianAxes::All,
            nan_mode: ReductionNaN::Include,
        };
        let cpu_all =
            median_host(Value::Tensor(tensor.clone()), &args_all).expect("cpu median all");
        let gpu_all = block_on(median_gpu(
            runmat_accelerate_api::provider()
                .unwrap()
                .upload(&runmat_accelerate_api::HostTensorView {
                    data: &tensor.data,
                    shape: &tensor.shape,
                })
                .expect("upload"),
            &args_all,
        ))
        .expect("gpu median all");
        let gathered_all = test_support::gather(gpu_all).expect("gather");
        match cpu_all {
            Value::Num(a) => {
                assert_eq!(gathered_all.data.len(), 1);
                assert!((a - gathered_all.data[0]).abs() < 1e-12);
            }
            Value::Tensor(t) => {
                assert_eq!(t.data.len(), gathered_all.data.len());
                for (a, b) in t.data.iter().zip(gathered_all.data.iter()) {
                    assert!((a - b).abs() < 1e-12);
                }
            }
            other => panic!("unexpected CPU output for all: {other:?}"),
        }
    }
}