1use std::cmp::Ordering;
4use std::collections::BTreeSet;
5
6use runmat_accelerate_api::{AccelProvider, GpuTensorHandle, ReduceDimResult};
7use runmat_builtins::{ComplexTensor, ResolveContext, Tensor, Type, Value};
8use runmat_macros::runtime_builtin;
9
10use crate::{build_runtime_error, BuiltinResult, RuntimeError};
11
12const NAME: &str = "max";
13
/// Type resolver registered for the `max` builtin.
///
/// Forwards the argument types and the resolve context unchanged to the shared
/// `min_max_type` resolver and returns whatever it reports.
fn max_type(args: &[Type], ctx: &ResolveContext) -> Type {
    min_max_type(args, ctx)
}
17
/// Builds a [`RuntimeError`] from `message`, tagged with this builtin's name (`NAME`).
fn max_error(message: impl Into<String>) -> RuntimeError {
    build_runtime_error(message).with_builtin(NAME).build()
}
21
22use crate::builtins::common::arg_tokens::tokens_from_values;
23use crate::builtins::common::broadcast::BroadcastPlan;
24use crate::builtins::common::random_args::{complex_tensor_into_value, keyword_of};
25use crate::builtins::common::spec::{
26 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, FusionError,
27 FusionExprContext, FusionKernelTemplate, GpuOpKind, ProviderHook, ReductionNaN,
28 ResidencyPolicy, ScalarType, ShapeRequirements,
29};
30use crate::builtins::common::{
31 gpu_helpers,
32 shape::{is_scalar_shape, normalize_scalar_shape},
33 tensor,
34};
35use crate::builtins::math::reduction::type_resolvers::min_max_type;
36
// GPU metadata for `max`, registered under this module's builtin path.
//
// Declares `max` as a reduction with two provider hooks (`reduce_max_dim` and
// `reduce_max`), NaN-including semantics, and no provider-side NaN-mode switch.
// Per `notes`, omitnan, ComparisonMethod overrides and complex inputs are served
// by the host implementation instead.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::math::reduction::max")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "max",
    op_kind: GpuOpKind::Reduction,
    supported_precisions: &[ScalarType::F32, ScalarType::F64],
    broadcast: BroadcastSemantics::Matlab,
    provider_hooks: &[
        ProviderHook::Reduction {
            name: "reduce_max_dim",
        },
        ProviderHook::Reduction {
            name: "reduce_max",
        },
    ],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::NewHandle,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: Some(256),
    workgroup_size: Some(256),
    accepts_nan_mode: false,
    notes:
        "Providers should implement reduce_max_dim / reduce_max. Requests that require omitnan, comparisonmethod overrides, or complex inputs fall back to the host implementation.",
};
60
// Fusion metadata for `max`, registered under this module's builtin path.
//
// Only a reduction template is provided (`elementwise` is `None`). The WGSL body
// folds the first fused input into `accumulator` with `max(...)` and reports
// `FusionError::MissingInput(0)` when the context carries no inputs.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::math::reduction::max")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "max",
    shape: ShapeRequirements::BroadcastCompatible,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: Some(FusionKernelTemplate {
        scalar_precisions: &[ScalarType::F32, ScalarType::F64],
        wgsl_body: |ctx: &FusionExprContext| {
            let input = ctx.inputs.first().ok_or(FusionError::MissingInput(0))?;
            Ok(format!("accumulator = max(accumulator, {input});"))
        },
    }),
    emits_nan: true,
    notes: "Fusion planner emits canonical reduction kernels; providers may substitute custom WGSL via reduce_max_dim hooks.",
};
77
78#[derive(Debug, Clone)]
80pub struct MaxEvaluation {
81 values: Value,
82 indices: Value,
83}
84
85impl MaxEvaluation {
86 pub fn into_value(self) -> Value {
88 self.values
89 }
90
91 pub fn into_pair(self) -> (Value, Value) {
93 (self.values, self.indices)
94 }
95
96 pub fn indices_value(&self) -> Value {
98 self.indices.clone()
99 }
100}
101
102#[runtime_builtin(
103 name = "max",
104 category = "math/reduction",
105 summary = "Return the maximum elements of scalars, vectors, matrices, or N-D tensors.",
106 keywords = "max,maximum,reduction,gpu,comparisonmethod,omitnan",
107 accel = "reduction",
108 type_resolver(max_type),
109 builtin_path = "crate::builtins::math::reduction::max"
110)]
111async fn max_builtin(value: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
112 evaluate(value, &rest).await.map(|eval| eval.into_value())
113}
114
115pub async fn evaluate(value: Value, rest: &[Value]) -> BuiltinResult<MaxEvaluation> {
117 let parsed = parse_call(rest).await?;
118 if std::env::var("RUNMAT_DEBUG_MAX").is_ok() {
119 let call_label = match &parsed {
120 ParsedCall::Reduction(_) => "reduction",
121 ParsedCall::Elementwise(_) => "elementwise",
122 };
123 let first_arg = rest.first().map(debug_value_kind).unwrap_or("None");
124 tracing::debug!(
125 call_type = call_label,
126 rest_len = rest.len(),
127 first_arg = first_arg,
128 "[runmat-debug-max]"
129 );
130 }
131 match parsed {
132 ParsedCall::Elementwise(args) => elementwise_max(value, args).await,
133 ParsedCall::Reduction(args) => reduction_max(value, args).await,
134 }
135}
136
/// The two call forms of `max`, as classified by `parse_call`.
#[derive(Debug, Clone)]
enum ParsedCall {
    /// Reduce a single input along one or more dimensions.
    Reduction(ReductionArgs),
    /// Compare the input against a second operand.
    Elementwise(ElementwiseArgs),
}
142
/// Options controlling a reduction-style `max` call.
#[derive(Debug, Clone)]
struct ReductionArgs {
    /// Which dimension(s) to reduce over.
    selection: DimSelection,
    /// Whether NaN entries are included in or omitted from the comparison.
    nan_mode: ReductionNaN,
    /// How candidates are ranked against each other.
    comparison: ComparisonMethod,
    /// When set, reported indices are 1-based linear indices into the whole input
    /// rather than positions within the reduced dimension(s).
    linear_index: bool,
}

impl Default for ReductionArgs {
    /// Defaults: automatic dimension, NaN included, automatic comparison,
    /// per-dimension indices.
    fn default() -> Self {
        Self {
            selection: DimSelection::Auto,
            nan_mode: ReductionNaN::Include,
            comparison: ComparisonMethod::Auto,
            linear_index: false,
        }
    }
}
161
/// Dimension selection for a reduction.
#[derive(Debug, Clone)]
enum DimSelection {
    /// No explicit dimension: the first dimension whose extent exceeds 1 is used.
    Auto,
    /// A single 1-based dimension.
    Dim(usize),
    /// Several 1-based dimensions reduced together.
    Vec(Vec<usize>),
    /// Every dimension of the input.
    All,
}
169
/// Ranking rule applied when comparing candidates.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ComparisonMethod {
    /// Default rule: handled like `Real` for real data and like `Abs` for complex data.
    Auto,
    /// Rank by (real-part) value.
    Real,
    /// Rank by magnitude.
    Abs,
}
176
/// Options for an elementwise `max(a, b, ...)` call.
#[derive(Debug, Clone)]
struct ElementwiseArgs {
    /// The second operand (the first trailing argument of the call).
    other: Value,
    /// Ranking rule requested through `'ComparisonMethod'`, or `Auto` if absent.
    comparison: ComparisonMethod,
}
182
183async fn parse_call(rest: &[Value]) -> BuiltinResult<ParsedCall> {
184 if rest.is_empty() {
185 return Ok(ParsedCall::Reduction(ReductionArgs::default()));
186 }
187
188 let first = &rest[0];
189 if !is_empty_placeholder(first) {
190 let comparison = parse_elementwise_options(&rest[1..])?;
191 return Ok(ParsedCall::Elementwise(ElementwiseArgs {
192 other: first.clone(),
193 comparison,
194 }));
195 }
196
197 let mut args = ReductionArgs::default();
198 parse_reduction_options(&mut args, &rest[1..]).await?;
199 Ok(ParsedCall::Reduction(args))
200}
201
202fn debug_value_kind(value: &Value) -> &'static str {
203 match value {
204 Value::Num(_) => "Num",
205 Value::Int(_) => "Int",
206 Value::Bool(_) => "Bool",
207 Value::Tensor(t) => {
208 if t.data.is_empty() {
209 "Tensor(empty)"
210 } else {
211 "Tensor"
212 }
213 }
214 Value::GpuTensor(_) => "GpuTensor",
215 Value::String(_) => "String",
216 Value::CharArray(_) => "CharArray",
217 Value::StringArray(sa) => {
218 if sa.data.is_empty() {
219 "StringArray(empty)"
220 } else {
221 "StringArray"
222 }
223 }
224 Value::LogicalArray(l) => {
225 if l.data.is_empty() {
226 "LogicalArray(empty)"
227 } else {
228 "LogicalArray"
229 }
230 }
231 Value::Cell(c) => {
232 if c.data.is_empty() {
233 "Cell(empty)"
234 } else {
235 "Cell"
236 }
237 }
238 _ => "Other",
239 }
240}
241
242fn is_empty_placeholder(value: &Value) -> bool {
243 match value {
244 Value::Tensor(t) => t.data.is_empty(),
245 Value::LogicalArray(l) => l.data.is_empty(),
246 Value::StringArray(sa) => sa.data.is_empty(),
247 Value::CharArray(ca) => ca.data.is_empty(),
248 Value::Cell(cell) => cell.data.is_empty(),
249 Value::String(s) => s.is_empty(),
250 _ => false,
251 }
252}
253
254async fn parse_reduction_options(args: &mut ReductionArgs, rest: &[Value]) -> BuiltinResult<()> {
255 let mut idx = 0usize;
256 let mut selection_set = !matches!(args.selection, DimSelection::Auto);
257 let mut comparison_set = matches!(args.comparison, ComparisonMethod::Auto);
258 let tokens = tokens_from_values(rest);
259 while idx < rest.len() {
260 if let Some(crate::builtins::common::arg_tokens::ArgToken::String(text)) = tokens.get(idx) {
261 match text.as_str() {
262 "omitnan" => {
263 args.nan_mode = ReductionNaN::Omit;
264 idx += 1;
265 continue;
266 }
267 "includenan" => {
268 args.nan_mode = ReductionNaN::Include;
269 idx += 1;
270 continue;
271 }
272 "all" => {
273 if selection_set {
274 return Err(max_error(
275 "max: 'all' cannot be combined with an explicit dimension",
276 ));
277 }
278 args.selection = DimSelection::All;
279 selection_set = true;
280 idx += 1;
281 continue;
282 }
283 _ => {}
284 }
285 }
286 if let Some(keyword) = keyword_of(&rest[idx]) {
287 match keyword.as_str() {
288 "omitnan" => {
289 args.nan_mode = ReductionNaN::Omit;
290 idx += 1;
291 continue;
292 }
293 "includenan" => {
294 args.nan_mode = ReductionNaN::Include;
295 idx += 1;
296 continue;
297 }
298 "all" => {
299 if selection_set {
300 return Err(max_error(
301 "max: 'all' cannot be combined with an explicit dimension",
302 ));
303 }
304 args.selection = DimSelection::All;
305 selection_set = true;
306 idx += 1;
307 continue;
308 }
309 "linear" => {
310 if selection_set {
311 return Err(max_error(
312 "max: 'linear' cannot be combined with an explicit dimension",
313 ));
314 }
315 args.selection = DimSelection::All;
316 args.linear_index = true;
317 selection_set = true;
318 idx += 1;
319 continue;
320 }
321 "comparisonmethod" => {
322 let Some(value) = rest.get(idx + 1) else {
323 return Err(max_error("max: expected a value after 'ComparisonMethod'"));
324 };
325 args.comparison = parse_comparison_method(value)?;
326 comparison_set = true;
327 idx += 2;
328 continue;
329 }
330 _ => {}
331 }
332 }
333
334 if !selection_set {
335 if let Some(selection) = parse_dimension_value(&rest[idx]).await? {
336 args.selection = selection;
337 selection_set = true;
338 idx += 1;
339 continue;
340 }
341 }
342
343 return Err(max_error(format!(
344 "max: unrecognised argument {:?}",
345 rest[idx]
346 )));
347 }
348
349 if !comparison_set {
350 args.comparison = ComparisonMethod::Auto;
351 }
352
353 Ok(())
354}
355
356fn parse_elementwise_options(rest: &[Value]) -> BuiltinResult<ComparisonMethod> {
357 let mut comparison = ComparisonMethod::Auto;
358 let mut comparison_set = false;
359 let mut idx = 0usize;
360 while idx < rest.len() {
361 if let Some(keyword) = keyword_of(&rest[idx]) {
362 match keyword.as_str() {
363 "comparisonmethod" => {
364 let Some(value) = rest.get(idx + 1) else {
365 return Err(max_error("max: expected a value after 'ComparisonMethod'"));
366 };
367 comparison = parse_comparison_method(value)?;
368 comparison_set = true;
369 idx += 2;
370 continue;
371 }
372 "omitnan" | "includenan" | "all" | "linear" => {
373 return Err(max_error(format!(
374 "max: '{}' is only supported for reduction calls",
375 keyword
376 )));
377 }
378 _ => {}
379 }
380 }
381 return Err(max_error(format!(
382 "max: unrecognised argument {:?}",
383 rest[idx]
384 )));
385 }
386 if !comparison_set {
387 comparison = ComparisonMethod::Auto;
388 }
389 Ok(comparison)
390}
391
392fn parse_comparison_method(value: &Value) -> BuiltinResult<ComparisonMethod> {
393 let Some(keyword) = keyword_of(value) else {
394 return Err(max_error("max: 'ComparisonMethod' expects a string value"));
395 };
396 match keyword.as_str() {
397 "auto" => Ok(ComparisonMethod::Auto),
398 "abs" | "magnitude" => Ok(ComparisonMethod::Abs),
399 "real" => Ok(ComparisonMethod::Real),
400 other => Err(max_error(format!(
401 "max: unsupported ComparisonMethod '{other}'"
402 ))),
403 }
404}
405
406async fn parse_dimension_value(value: &Value) -> BuiltinResult<Option<DimSelection>> {
407 match value {
408 Value::Int(_) | Value::Num(_) => tensor::dimension_from_value_async(value, "max", false)
409 .await
410 .map_err(map_scalar_dim_error)
411 .map(|dim| dim.map(DimSelection::Dim)),
412 Value::Tensor(t) => parse_dimension_tensor(value, &t.shape).await,
413 Value::LogicalArray(logical) => parse_dimension_tensor(value, &logical.shape).await,
414 Value::GpuTensor(_) => Err(max_error(
415 "max: dimension arguments must reside on the host",
416 )),
417 _ => Ok(None),
418 }
419}
420
421async fn parse_dimension_tensor(
422 value: &Value,
423 shape: &[usize],
424) -> BuiltinResult<Option<DimSelection>> {
425 if tensor::element_count(shape) == 0 {
426 return Ok(Some(DimSelection::Auto));
427 }
428 let is_vector = shape.len() == 1
429 || shape.get(0).copied().unwrap_or(1) == 1
430 || shape.get(1).copied().unwrap_or(1) == 1;
431 if !is_vector {
432 return Err(max_error(
433 "max: dimension vector must be a row or column vector",
434 ));
435 }
436 let dims = tensor::dims_from_value_async(value)
437 .await
438 .map_err(map_vector_dim_error)?;
439 let Some(dims) = dims else {
440 return Ok(None);
441 };
442 if dims.is_empty() {
443 return Ok(Some(DimSelection::Auto));
444 }
445 let mut seen = BTreeSet::new();
446 let mut uniq = Vec::with_capacity(dims.len());
447 for dim in dims {
448 if dim < 1 {
449 return Err(max_error("max: dimension indices must be >= 1"));
450 }
451 if seen.insert(dim) {
452 uniq.push(dim);
453 }
454 }
455 Ok(Some(DimSelection::Vec(uniq)))
456}
457
458fn map_scalar_dim_error(message: String) -> RuntimeError {
459 if message.contains("integer") {
460 return max_error("max: dimension must be integral");
461 }
462 max_error(message)
463}
464
465fn map_vector_dim_error(message: String) -> RuntimeError {
466 if message.contains("non-negative") {
467 return max_error("max: dimension indices must be >= 1");
468 }
469 if message.contains("finite") {
470 return max_error("max: dimension entries must be finite");
471 }
472 if message.contains("integer") {
473 return max_error("max: dimension entries must be integers");
474 }
475 max_error(message)
476}
477
478async fn reduction_max(value: Value, args: ReductionArgs) -> BuiltinResult<MaxEvaluation> {
479 match value {
480 Value::GpuTensor(handle) => {
481 if let Some(eval) = reduction_max_gpu(handle.clone(), &args).await? {
482 return Ok(eval);
483 }
484 let tensor = gpu_helpers::gather_tensor_async(&handle).await?;
486 reduction_max_host(Value::Tensor(tensor), &args)
487 }
488 other => reduction_max_host(other, &args),
489 }
490}
491
/// Attempts the reduction through the registered acceleration provider.
///
/// Returns `Ok(Some(_))` holding GPU-resident values and indices when the
/// provider's `reduce_max_dim` succeeds. Returns `Ok(None)` whenever the request
/// cannot be served here: omitnan, a non-default comparison method, linear
/// indices, no provider, a multi-dimension selection, a dimension outside the
/// handle's shape, or a provider error. None of these cases is reported as an
/// error by this function.
async fn reduction_max_gpu(
    handle: GpuTensorHandle,
    args: &ReductionArgs,
) -> BuiltinResult<Option<MaxEvaluation>> {
    // Test builds with the wgpu feature: register the wgpu provider when the handle
    // carries a non-zero device id; the registration result is ignored.
    #[cfg(all(test, feature = "wgpu"))]
    {
        if handle.device_id != 0 {
            let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
                runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
            );
        }
    }
    if args.nan_mode == ReductionNaN::Omit {
        log::trace!("max: gpu path disabled (nan_mode=omit)");
        return Ok(None);
    }
    if args.comparison != ComparisonMethod::Auto {
        log::trace!("max: gpu path disabled (comparison != auto)");
        return Ok(None);
    }
    if args.linear_index {
        log::trace!("max: gpu path disabled (linear_index=true)");
        return Ok(None);
    }
    let provider = match runmat_accelerate_api::provider() {
        Some(p) => p,
        None => {
            log::trace!(
                "max: gpu path unavailable (provider() is None) handle_shape={:?} device_id={}",
                handle.shape,
                handle.device_id
            );
            return Ok(None);
        }
    };
    // Resolve the single 1-based dimension to reduce; anything that needs more than
    // one dimension is declined.
    let target_dim = match args.selection {
        DimSelection::Auto => default_dimension_from_shape(&handle.shape),
        DimSelection::Dim(dim) => dim,
        DimSelection::Vec(ref dims) if dims.len() == 1 => dims[0],
        DimSelection::All => {
            // 'all' collapses to a single-dimension reduction only for rank <= 1.
            if handle.shape.len() <= 1 {
                1
            } else {
                return Ok(None);
            }
        }
        _ => return Ok(None),
    };
    if target_dim == 0 {
        return Ok(None);
    }
    // The provider is called with a 0-based dimension.
    let zero_based = target_dim.saturating_sub(1);
    if zero_based >= handle.shape.len() {
        return Ok(None);
    }
    log::trace!(
        "max: attempting reduce_max_dim dim={} (zero_based={}) shape={:?} device_id={}",
        target_dim,
        zero_based,
        handle.shape,
        handle.device_id
    );
    match provider.reduce_max_dim(&handle, zero_based).await {
        Ok(ReduceDimResult { values, indices }) => Ok(Some(MaxEvaluation {
            values: Value::GpuTensor(values),
            indices: Value::GpuTensor(indices),
        })),
        Err(err) => {
            // Provider failures are only traced; the result is "declined".
            log::trace!("max: reduce_max_dim failed: {err}");
            Ok(None)
        }
    }
}
566
567fn reduction_max_host(value: Value, args: &ReductionArgs) -> BuiltinResult<MaxEvaluation> {
568 match materialize_for_max("max", value)? {
569 InputData::Real(tensor) => reduce_real_tensor(tensor, args),
570 InputData::Complex(tensor) => reduce_complex_tensor(tensor, args),
571 }
572}
573
/// Host input after conversion to a dense tensor, split by element kind.
enum InputData {
    /// Real-valued data (numeric, logical and boolean inputs end up here).
    Real(Tensor),
    /// Complex-valued data.
    Complex(ComplexTensor),
}
578
579fn materialize_for_max(name: &str, value: Value) -> BuiltinResult<InputData> {
580 match value {
581 Value::Tensor(t) => Ok(InputData::Real(t)),
582 Value::LogicalArray(logical) => {
583 let tensor = tensor::logical_to_tensor(&logical).map_err(|err| max_error(err))?;
584 Ok(InputData::Real(tensor))
585 }
586 Value::Num(n) => {
587 let tensor =
588 Tensor::new(vec![n], vec![1, 1]).map_err(|e| max_error(format!("{name}: {e}")))?;
589 Ok(InputData::Real(tensor))
590 }
591 Value::Int(i) => {
592 let tensor = Tensor::new(vec![i.to_f64()], vec![1, 1])
593 .map_err(|e| max_error(format!("{name}: {e}")))?;
594 Ok(InputData::Real(tensor))
595 }
596 Value::Bool(b) => {
597 let tensor = Tensor::new(vec![if b { 1.0 } else { 0.0 }], vec![1, 1])
598 .map_err(|e| max_error(format!("{name}: {e}")))?;
599 Ok(InputData::Real(tensor))
600 }
601 Value::Complex(re, im) => {
602 let tensor = ComplexTensor::new(vec![(re, im)], vec![1, 1])
603 .map_err(|e| max_error(format!("{name}: {e}")))?;
604 Ok(InputData::Complex(tensor))
605 }
606 Value::ComplexTensor(ct) => Ok(InputData::Complex(ct)),
607 Value::String(_) | Value::StringArray(_) | Value::CharArray(_) | Value::Cell(_) => {
608 Err(max_error(format!(
609 "{name}: expected numeric or logical input, received non-numeric value"
610 )))
611 }
612 Value::GpuTensor(_) => Err(max_error(format!(
613 "{name}: internal error – GPU tensors must be gathered before host execution"
614 ))),
615 Value::Object(_) | Value::HandleObject(_) | Value::Struct(_) | Value::Listener(_) => {
616 Err(max_error(format!("{name}: unsupported input type")))
617 }
618 Value::FunctionHandle(_)
619 | Value::Closure(_)
620 | Value::ClassRef(_)
621 | Value::MException(_)
622 | Value::OutputList(_) => Err(max_error(format!("{name}: unsupported input type"))),
623 }
624}
625
626fn reduce_real_tensor(tensor: Tensor, args: &ReductionArgs) -> BuiltinResult<MaxEvaluation> {
627 let shape = tensor.shape.clone();
628 if tensor.data.is_empty() {
629 let output_shape = resolve_output_shape(&shape, &args.selection, &[])?;
630 let values = Tensor::new(Vec::new(), output_shape.clone())
631 .map_err(|e| max_error(format!("max: {e}")))?;
632 let indices =
633 Tensor::new(Vec::new(), output_shape).map_err(|e| max_error(format!("max: {e}")))?;
634 return Ok(MaxEvaluation {
635 values: tensor::tensor_into_value(values),
636 indices: tensor::tensor_into_value(indices),
637 });
638 }
639 let resolved = resolve_reduction_dims(&shape, &args.selection)?;
640 let output_shape = resolved.output_shape.clone();
641 let output_len = tensor::element_count(&output_shape);
642
643 if output_len == 0 {
644 let values = Tensor::new(Vec::new(), output_shape.clone())
645 .map_err(|e| max_error(format!("max: {e}")))?;
646 let indices =
647 Tensor::new(Vec::new(), output_shape).map_err(|e| max_error(format!("max: {e}")))?;
648 return Ok(MaxEvaluation {
649 values: tensor::tensor_into_value(values),
650 indices: tensor::tensor_into_value(indices),
651 });
652 }
653
654 let strides = compute_strides(&shape);
655 let output_strides = compute_strides(&output_shape);
656 let dims_mask = resolved.dims_mask.clone();
657 let reduce_strides = resolved.reduce_strides.clone();
658
659 let mut best = vec![BestReal::new(); output_len];
660 let mut coords = vec![0usize; shape.len()];
661 for &value in &tensor.data {
662 let out_idx = map_output_index(&coords, &output_strides, &dims_mask);
663 let reduce_idx = map_reduce_index(
664 &coords,
665 &resolved.reduced_dims,
666 &reduce_strides,
667 resolved.reduce_all,
668 );
669 let full_idx = map_linear_index(&coords, &strides);
670
671 update_best_real(
672 &mut best[out_idx],
673 value,
674 reduce_idx,
675 full_idx,
676 args.nan_mode,
677 args.comparison,
678 );
679 increment_coords(&mut coords, &shape);
680 }
681
682 let mut values = vec![0.0f64; output_len];
683 let mut indices = vec![0.0f64; output_len];
684
685 for (i, entry) in best.iter().enumerate() {
686 if entry.nan_fixed {
687 values[i] = f64::NAN;
688 indices[i] = if args.linear_index || resolved.reduce_all {
689 (entry.full_index + 1) as f64
690 } else if resolved.reduced_dims.is_empty() {
691 1.0
692 } else {
693 (entry.reduce_index + 1) as f64
694 };
695 continue;
696 }
697 if !entry.has_value {
698 values[i] = f64::NAN;
699 indices[i] = f64::NAN;
700 continue;
701 }
702 values[i] = entry.value;
703 indices[i] = if args.linear_index || resolved.reduce_all {
704 (entry.full_index + 1) as f64
705 } else if resolved.reduced_dims.is_empty() {
706 1.0
707 } else {
708 (entry.reduce_index + 1) as f64
709 };
710 }
711
712 let value_tensor =
713 Tensor::new(values, output_shape.clone()).map_err(|e| max_error(format!("max: {e}")))?;
714 let index_tensor =
715 Tensor::new(indices, output_shape).map_err(|e| max_error(format!("max: {e}")))?;
716
717 Ok(MaxEvaluation {
718 values: tensor::tensor_into_value(value_tensor),
719 indices: tensor::tensor_into_value(index_tensor),
720 })
721}
722
723fn reduce_complex_tensor(
724 tensor: ComplexTensor,
725 args: &ReductionArgs,
726) -> BuiltinResult<MaxEvaluation> {
727 let shape = tensor.shape.clone();
728 if tensor.data.is_empty() {
729 let output_shape = resolve_output_shape(&shape, &args.selection, &[])?;
730 let values = ComplexTensor::new(Vec::new(), output_shape.clone())
731 .map_err(|e| max_error(format!("max: {e}")))?;
732 let indices =
733 Tensor::new(Vec::new(), output_shape).map_err(|e| max_error(format!("max: {e}")))?;
734 return Ok(MaxEvaluation {
735 values: complex_tensor_into_value(values),
736 indices: tensor::tensor_into_value(indices),
737 });
738 }
739
740 let resolved = resolve_reduction_dims(&shape, &args.selection)?;
741 let output_shape = resolved.output_shape.clone();
742 let output_len = tensor::element_count(&output_shape);
743
744 if output_len == 0 {
745 let values = ComplexTensor::new(Vec::new(), output_shape.clone())
746 .map_err(|e| max_error(format!("max: {e}")))?;
747 let indices =
748 Tensor::new(Vec::new(), output_shape).map_err(|e| max_error(format!("max: {e}")))?;
749 return Ok(MaxEvaluation {
750 values: complex_tensor_into_value(values),
751 indices: tensor::tensor_into_value(indices),
752 });
753 }
754
755 let strides = compute_strides(&shape);
756 let output_strides = compute_strides(&output_shape);
757 let dims_mask = resolved.dims_mask.clone();
758 let reduce_strides = resolved.reduce_strides.clone();
759
760 let mut best = vec![BestComplex::new(); output_len];
761 let mut coords = vec![0usize; shape.len()];
762
763 for &(re, im) in &tensor.data {
764 let out_idx = map_output_index(&coords, &output_strides, &dims_mask);
765 let reduce_idx = map_reduce_index(
766 &coords,
767 &resolved.reduced_dims,
768 &reduce_strides,
769 resolved.reduce_all,
770 );
771 let full_idx = map_linear_index(&coords, &strides);
772 update_best_complex(
773 &mut best[out_idx],
774 (re, im),
775 reduce_idx,
776 full_idx,
777 args.nan_mode,
778 args.comparison,
779 );
780 increment_coords(&mut coords, &shape);
781 }
782
783 let mut values = vec![(0.0f64, 0.0f64); output_len];
784 let mut indices = vec![0.0f64; output_len];
785
786 for (i, entry) in best.iter().enumerate() {
787 if entry.nan_fixed {
788 values[i] = (f64::NAN, f64::NAN);
789 indices[i] = if args.linear_index || resolved.reduce_all {
790 (entry.full_index + 1) as f64
791 } else if resolved.reduced_dims.is_empty() {
792 1.0
793 } else {
794 (entry.reduce_index + 1) as f64
795 };
796 continue;
797 }
798 if !entry.has_value {
799 values[i] = (f64::NAN, f64::NAN);
800 indices[i] = f64::NAN;
801 continue;
802 }
803 values[i] = entry.value;
804 indices[i] = if args.linear_index || resolved.reduce_all {
805 (entry.full_index + 1) as f64
806 } else if resolved.reduced_dims.is_empty() {
807 1.0
808 } else {
809 (entry.reduce_index + 1) as f64
810 };
811 }
812
813 let value_tensor = ComplexTensor::new(values, output_shape.clone())
814 .map_err(|e| max_error(format!("max: {e}")))?;
815 let index_tensor =
816 Tensor::new(indices, output_shape).map_err(|e| max_error(format!("max: {e}")))?;
817 Ok(MaxEvaluation {
818 values: complex_tensor_into_value(value_tensor),
819 indices: tensor::tensor_into_value(index_tensor),
820 })
821}
822
/// Running best candidate for one output slot of a real-valued reduction.
#[derive(Debug, Clone)]
struct BestReal {
    /// Value of the current best candidate.
    value: f64,
    /// 0-based position of the candidate within the reduced dimensions.
    reduce_index: usize,
    /// 0-based linear position of the candidate within the whole input.
    full_index: usize,
    /// True once any candidate has been recorded.
    has_value: bool,
    /// True once a NaN has been locked in as the result for this slot.
    nan_fixed: bool,
}

impl BestReal {
    /// Creates an empty slot: no candidate recorded, no NaN locked in.
    fn new() -> Self {
        BestReal {
            has_value: false,
            nan_fixed: false,
            value: 0.0,
            reduce_index: 0,
            full_index: 0,
        }
    }
}
843
/// Running best candidate for one output slot of a complex-valued reduction.
#[derive(Debug, Clone)]
struct BestComplex {
    /// `(re, im)` of the current best candidate.
    value: (f64, f64),
    /// 0-based position of the candidate within the reduced dimensions.
    reduce_index: usize,
    /// 0-based linear position of the candidate within the whole input.
    full_index: usize,
    /// True once any candidate has been recorded.
    has_value: bool,
    /// True once a NaN has been locked in as the result for this slot.
    nan_fixed: bool,
}

impl BestComplex {
    /// Creates an empty slot: no candidate recorded, no NaN locked in.
    fn new() -> Self {
        BestComplex {
            has_value: false,
            nan_fixed: false,
            value: (0.0, 0.0),
            reduce_index: 0,
            full_index: 0,
        }
    }
}
864
865fn resolve_output_shape(
866 shape: &[usize],
867 selection: &DimSelection,
868 reduced_dims: &[usize],
869) -> BuiltinResult<Vec<usize>> {
870 if is_scalar_shape(shape) {
871 return Ok(normalize_scalar_shape(shape));
872 }
873 let mut output = shape.to_vec();
874 match selection {
875 DimSelection::All => {
876 output.fill(1);
877 }
878 _ => {
879 for &dim in reduced_dims {
880 if dim < output.len() {
881 output[dim] = 1;
882 }
883 }
884 }
885 }
886 Ok(output)
887}
888
/// Reduction plan derived from an input shape and a dimension selection.
struct ResolvedDims {
    /// Shape of the reduction result.
    output_shape: Vec<usize>,
    /// 0-based axes being reduced, sorted and free of duplicates.
    reduced_dims: Vec<usize>,
    /// True when every axis of the input is reduced (also set for scalar inputs).
    reduce_all: bool,
    /// Per-axis flag: true for axes listed in `reduced_dims`.
    dims_mask: Vec<bool>,
    /// Strides for numbering positions within the reduced axes.
    reduce_strides: Vec<usize>,
}
896
897fn resolve_reduction_dims(
898 shape: &[usize],
899 selection: &DimSelection,
900) -> BuiltinResult<ResolvedDims> {
901 if is_scalar_shape(shape) {
902 return Ok(ResolvedDims {
903 output_shape: normalize_scalar_shape(shape),
904 reduced_dims: Vec::new(),
905 reduce_all: true,
906 dims_mask: Vec::new(),
907 reduce_strides: Vec::new(),
908 });
909 }
910
911 let mut reduced_dims = match selection {
912 DimSelection::Auto => {
913 let mut dim = None;
914 for (index, &len) in shape.iter().enumerate() {
915 if len > 1 {
916 dim = Some(index);
917 break;
918 }
919 }
920 vec![dim.unwrap_or(0)]
921 }
922 DimSelection::Dim(dim) => {
923 if *dim == 0 {
924 return Err(max_error("max: dimension must be >= 1"));
925 }
926 let index = dim.saturating_sub(1);
927 if index >= shape.len() {
928 Vec::new()
929 } else {
930 vec![index]
931 }
932 }
933 DimSelection::Vec(dims) => {
934 if dims.is_empty() {
935 Vec::new()
936 } else {
937 dims.iter()
938 .filter_map(|dim| {
939 if *dim == 0 {
940 None
941 } else {
942 let idx = dim - 1;
943 if idx < shape.len() {
944 Some(idx)
945 } else {
946 None
947 }
948 }
949 })
950 .collect()
951 }
952 }
953 DimSelection::All => (0..shape.len()).collect(),
954 };
955
956 reduced_dims.sort_unstable();
957 reduced_dims.dedup();
958
959 let reduce_all = !reduced_dims.is_empty()
960 && reduced_dims.len() == shape.len()
961 && reduced_dims.iter().enumerate().all(|(i, &d)| i == d);
962
963 let output_shape = resolve_output_shape(shape, selection, &reduced_dims)?;
964 let mut dims_mask = vec![false; shape.len()];
965 for &dim in &reduced_dims {
966 if dim < dims_mask.len() {
967 dims_mask[dim] = true;
968 }
969 }
970 let reduce_strides = compute_subspace_strides(shape, &reduced_dims);
971
972 Ok(ResolvedDims {
973 output_shape,
974 reduced_dims,
975 reduce_all,
976 dims_mask,
977 reduce_strides,
978 })
979}
980
/// Returns the stride of each axis of `shape`, with the first axis varying fastest.
///
/// Zero-length axes are treated as length 1 and products saturate on overflow.
fn compute_strides(shape: &[usize]) -> Vec<usize> {
    let mut running = 1usize;
    shape
        .iter()
        .map(|&len| {
            let current = running;
            running = running.saturating_mul(len.max(1));
            current
        })
        .collect()
}
990
/// Returns strides for numbering positions within the sub-space spanned by `dims`.
///
/// `dims` holds 0-based axes of `shape`. Axes missing from `shape` or of length
/// zero count as length 1, and products saturate on overflow.
fn compute_subspace_strides(shape: &[usize], dims: &[usize]) -> Vec<usize> {
    dims.iter()
        .scan(1usize, |accum, &dim| {
            let current = *accum;
            let len = shape.get(dim).copied().unwrap_or(1).max(1);
            *accum = accum.saturating_mul(len);
            Some(current)
        })
        .collect()
}
1004
/// Maps input coordinates to the linear index of the output slot they contribute to.
///
/// Reduced axes (flagged in `dims_mask`) contribute coordinate 0; every other axis
/// contributes `coords[axis] * output_strides[axis]`. Axes present in
/// `output_strides` but missing from `coords` are treated as coordinate 0, so an
/// output shape with more axes than the input cannot cause an out-of-bounds panic.
/// Arithmetic saturates on overflow.
fn map_output_index(coords: &[usize], output_strides: &[usize], dims_mask: &[bool]) -> usize {
    if coords.is_empty() {
        return 0;
    }
    output_strides
        .iter()
        .enumerate()
        .fold(0usize, |index, (dim, &stride)| {
            let reduced = dims_mask.get(dim).copied().unwrap_or(false);
            let coord = if reduced {
                0
            } else {
                // Missing trailing coordinates belong to implicit singleton axes.
                coords.get(dim).copied().unwrap_or(0)
            };
            index.saturating_add(coord.saturating_mul(stride))
        })
}
1020
/// Computes the 0-based position of `coords` within the reduced axes.
///
/// Returns 0 when nothing is reduced or when everything is reduced. Otherwise
/// sums `coords[axis] * stride` over the reduced axes, skipping axes or strides
/// that are missing. Arithmetic saturates on overflow.
fn map_reduce_index(
    coords: &[usize],
    reduced_dims: &[usize],
    reduce_strides: &[usize],
    reduce_all: bool,
) -> usize {
    if reduce_all || reduced_dims.is_empty() {
        return 0;
    }
    reduced_dims
        .iter()
        .zip(reduce_strides)
        .filter_map(|(&dim, &stride)| coords.get(dim).map(|&coord| coord.saturating_mul(stride)))
        .fold(0usize, usize::saturating_add)
}
1044
/// Computes the 0-based linear index of `coords` given per-axis `strides`.
///
/// Only axes present in both slices contribute; arithmetic saturates on overflow.
fn map_linear_index(coords: &[usize], strides: &[usize]) -> usize {
    let mut linear = 0usize;
    for (&coord, &stride) in coords.iter().zip(strides) {
        linear = linear.saturating_add(coord.saturating_mul(stride));
    }
    linear
}
1053
/// Advances `coords` to the next element, with the first axis varying fastest.
///
/// Axes of extent zero are skipped. When every axis wraps, `coords` returns to
/// all zeros.
fn increment_coords(coords: &mut [usize], shape: &[usize]) {
    for (dim, coord) in coords.iter_mut().enumerate() {
        let extent = shape[dim];
        if extent == 0 {
            continue;
        }
        *coord += 1;
        if *coord < extent {
            // No carry needed: done.
            return;
        }
        // Wrap this axis and carry into the next one.
        *coord = 0;
    }
}
1066
1067fn update_best_real(
1068 best: &mut BestReal,
1069 value: f64,
1070 reduce_index: usize,
1071 full_index: usize,
1072 nan_mode: ReductionNaN,
1073 comparison: ComparisonMethod,
1074) {
1075 if value.is_nan() {
1076 match nan_mode {
1077 ReductionNaN::Include => {
1078 if !best.nan_fixed {
1079 best.value = f64::NAN;
1080 best.reduce_index = reduce_index;
1081 best.full_index = full_index;
1082 best.has_value = true;
1083 best.nan_fixed = true;
1084 }
1085 }
1086 ReductionNaN::Omit => {}
1087 }
1088 return;
1089 }
1090 if best.nan_fixed {
1091 return;
1092 }
1093
1094 if !best.has_value {
1095 best.value = value;
1096 best.reduce_index = reduce_index;
1097 best.full_index = full_index;
1098 best.has_value = true;
1099 return;
1100 }
1101
1102 if should_replace_real(best.value, value, comparison) {
1103 best.value = value;
1104 best.reduce_index = reduce_index;
1105 best.full_index = full_index;
1106 }
1107}
1108
1109fn update_best_complex(
1110 best: &mut BestComplex,
1111 value: (f64, f64),
1112 reduce_index: usize,
1113 full_index: usize,
1114 nan_mode: ReductionNaN,
1115 comparison: ComparisonMethod,
1116) {
1117 if value.0.is_nan() || value.1.is_nan() {
1118 match nan_mode {
1119 ReductionNaN::Include => {
1120 if !best.nan_fixed {
1121 best.value = (f64::NAN, f64::NAN);
1122 best.reduce_index = reduce_index;
1123 best.full_index = full_index;
1124 best.has_value = true;
1125 best.nan_fixed = true;
1126 }
1127 }
1128 ReductionNaN::Omit => {}
1129 }
1130 return;
1131 }
1132 if best.nan_fixed {
1133 return;
1134 }
1135
1136 if !best.has_value {
1137 best.value = value;
1138 best.reduce_index = reduce_index;
1139 best.full_index = full_index;
1140 best.has_value = true;
1141 return;
1142 }
1143
1144 if should_replace_complex(best.value, value, comparison) {
1145 best.value = value;
1146 best.reduce_index = reduce_index;
1147 best.full_index = full_index;
1148 }
1149}
1150
1151fn should_replace_real(current: f64, candidate: f64, comparison: ComparisonMethod) -> bool {
1152 match comparison {
1153 ComparisonMethod::Auto | ComparisonMethod::Real => {
1154 if candidate > current {
1155 return true;
1156 }
1157 if candidate < current {
1158 return false;
1159 }
1160 if candidate == 0.0 && current == 0.0 {
1161 return candidate.is_sign_positive() && !current.is_sign_positive();
1162 }
1163 false
1164 }
1165 ComparisonMethod::Abs => {
1166 let curr_abs = current.abs();
1167 let cand_abs = candidate.abs();
1168 if cand_abs > curr_abs {
1169 return true;
1170 }
1171 if cand_abs < curr_abs {
1172 return false;
1173 }
1174 if candidate > current {
1175 return true;
1176 }
1177 if candidate < current {
1178 return false;
1179 }
1180 if candidate == 0.0 && current == 0.0 {
1181 return candidate.is_sign_positive() && !current.is_sign_positive();
1182 }
1183 false
1184 }
1185 }
1186}
1187
1188fn should_replace_complex(
1189 current: (f64, f64),
1190 candidate: (f64, f64),
1191 comparison: ComparisonMethod,
1192) -> bool {
1193 match comparison {
1194 ComparisonMethod::Auto | ComparisonMethod::Abs => {
1195 compare_complex_auto(current, candidate) == Ordering::Less
1196 }
1197 ComparisonMethod::Real => compare_complex_real(current, candidate) == Ordering::Less,
1198 }
1199}
1200
1201fn compare_complex_auto(a: (f64, f64), b: (f64, f64)) -> Ordering {
1202 let a_mag = magnitude_squared(a);
1203 let b_mag = magnitude_squared(b);
1204 if a_mag < b_mag {
1205 return Ordering::Less;
1206 }
1207 if a_mag > b_mag {
1208 return Ordering::Greater;
1209 }
1210 let a_angle = a.1.atan2(a.0);
1212 let b_angle = b.1.atan2(b.0);
1213 if a_angle < b_angle {
1214 Ordering::Less
1215 } else if a_angle > b_angle {
1216 Ordering::Greater
1217 } else {
1218 Ordering::Equal
1219 }
1220}
1221
/// Orders two complex numbers for `ComparisonMethod::Real`.
///
/// The real parts are compared first. When they are equal the imaginary parts
/// break the tie, which is the rule MATLAB documents for the `'real'`
/// comparison method (previously ties fell back to magnitude/phase ordering,
/// so e.g. `1-3i` outranked `1+2i`).
fn compare_complex_real(a: (f64, f64), b: (f64, f64)) -> Ordering {
    if a.0 < b.0 {
        return Ordering::Less;
    }
    if a.0 > b.0 {
        return Ordering::Greater;
    }
    // Real parts tie: fall back to the imaginary parts.
    if a.1 < b.1 {
        Ordering::Less
    } else if a.1 > b.1 {
        Ordering::Greater
    } else {
        Ordering::Equal
    }
}
1232
/// Squared modulus `re² + im²` of a complex number given as `(re, im)`,
/// evaluated with a fused multiply-add on the real part.
fn magnitude_squared(z: (f64, f64)) -> f64 {
    let (re, im) = z;
    re.mul_add(re, im * im)
}
1236
1237fn default_dimension_from_shape(shape: &[usize]) -> usize {
1238 if is_scalar_shape(shape) {
1239 return 1;
1240 }
1241 for (i, &len) in shape.iter().enumerate() {
1242 if len > 1 {
1243 return i + 1;
1244 }
1245 }
1246 1
1247}
1248
/// Element-wise maximum of `value` and `args.other`, dispatching on where the
/// operands live.
///
/// * GPU/GPU: if either handle holds a single element, that element is
///   gathered and the scalar GPU helpers are tried first, then a host
///   evaluation of the gathered tensor. Otherwise `elementwise_max_gpu_pair`
///   is tried, followed by a host evaluation of both gathered tensors.
/// * GPU/host and host/GPU: the scalar GPU helper is tried; on `None` the GPU
///   operand is gathered and evaluated on the host.
/// * host/host: evaluated directly by `elementwise_real_or_complex`.
///
/// Returns an error built with `max_error` when a GPU branch exhausts its
/// fallbacks, or whatever error the host evaluation produces in the mixed and
/// host-only branches.
async fn elementwise_max(value: Value, args: ElementwiseArgs) -> BuiltinResult<MaxEvaluation> {
    let ElementwiseArgs { other, comparison } = args;
    match (value, other) {
        (Value::GpuTensor(handle_a), Value::GpuTensor(handle_b)) => {
            // Right operand is a single-element handle: treat it as a host scalar.
            if gpu_tensor_is_scalar(&handle_b) {
                if let Some(num) = gpu_tensor_scalar_value(&handle_b).await {
                    let scalar = Value::Num(num);
                    if let Some(eval) =
                        elementwise_max_gpu_scalar_left(&handle_a, &scalar, comparison).await
                    {
                        return Ok(eval);
                    }
                    // Device path declined: gather the tensor and evaluate on the host.
                    if let Ok(ta) = gpu_helpers::gather_tensor_async(&handle_a).await {
                        if let Ok(eval) = elementwise_real_or_complex(
                            Value::Tensor(ta),
                            scalar.clone(),
                            comparison,
                        ) {
                            return Ok(eval);
                        }
                    }
                    return Err(max_error("max: elementwise GPU scalar path failed"));
                }
            }
            // Left operand is a single-element handle: mirrored handling.
            if gpu_tensor_is_scalar(&handle_a) {
                if let Some(num) = gpu_tensor_scalar_value(&handle_a).await {
                    let scalar = Value::Num(num);
                    if let Some(eval) =
                        elementwise_max_gpu_scalar_right(&scalar, &handle_b, comparison).await
                    {
                        return Ok(eval);
                    }
                    if let Ok(tb) = gpu_helpers::gather_tensor_async(&handle_b).await {
                        if let Ok(eval) = elementwise_real_or_complex(
                            scalar.clone(),
                            Value::Tensor(tb),
                            comparison,
                        ) {
                            return Ok(eval);
                        }
                    }
                    return Err(max_error("max: elementwise GPU scalar path failed"));
                }
            }
            // General tensor/tensor case on the device.
            if let Some(eval) = elementwise_max_gpu_pair(&handle_a, &handle_b, comparison).await {
                return Ok(eval);
            }
            // Last resort: gather both operands and evaluate on the host.
            if let (Ok(ta), Ok(tb)) = (
                gpu_helpers::gather_tensor_async(&handle_a).await,
                gpu_helpers::gather_tensor_async(&handle_b).await,
            ) {
                if let Ok(eval) =
                    elementwise_real_or_complex(Value::Tensor(ta), Value::Tensor(tb), comparison)
                {
                    return Ok(eval);
                }
            }
            Err(max_error("max: elementwise GPU path failed"))
        }
        (Value::GpuTensor(handle), other) => {
            // The helper returns `None` unless `other` is a real scalar it can handle.
            if let Some(eval) = elementwise_max_gpu_scalar_left(&handle, &other, comparison).await {
                return Ok(eval);
            }
            let t = gpu_helpers::gather_tensor_async(&handle)
                .await
                .map_err(|_| max_error("max: elementwise GPU scalar path failed"))?;
            elementwise_real_or_complex(Value::Tensor(t), other, comparison)
        }
        (other, Value::GpuTensor(handle)) => {
            if let Some(eval) = elementwise_max_gpu_scalar_right(&other, &handle, comparison).await
            {
                return Ok(eval);
            }
            let t = gpu_helpers::gather_tensor_async(&handle)
                .await
                .map_err(|_| max_error("max: elementwise GPU scalar path failed"))?;
            elementwise_real_or_complex(other, Value::Tensor(t), comparison)
        }
        // Neither operand is on the GPU.
        (lhs, rhs) => elementwise_real_or_complex(lhs, rhs, comparison),
    }
}
1330
1331async fn elementwise_max_gpu_pair(
1332 a: &GpuTensorHandle,
1333 b: &GpuTensorHandle,
1334 comparison: ComparisonMethod,
1335) -> Option<MaxEvaluation> {
1336 if comparison != ComparisonMethod::Auto {
1337 return None;
1338 }
1339 let provider = runmat_accelerate_api::provider()?;
1340 if a.shape == b.shape {
1342 let values = provider.elem_max(a, b).await.ok()?;
1343 if let Ok(mask) = provider.elem_ge(a, b).await {
1345 let indices = gpu_mask_indices(provider, &mask)?;
1346 let _ = provider.free(&mask);
1347 return Some(MaxEvaluation {
1348 values: Value::GpuTensor(values),
1349 indices: Value::GpuTensor(indices),
1350 });
1351 } else {
1352 let ta = gpu_helpers::gather_tensor_async(a).await.ok()?;
1354 let tb = gpu_helpers::gather_tensor_async(b).await.ok()?;
1355 let mut indices = Vec::with_capacity(ta.data.len());
1356 for i in 0..ta.data.len() {
1357 indices.push(if ta.data[i] >= tb.data[i] { 1.0 } else { 2.0 });
1358 }
1359 let index_tensor = Tensor::new(indices, ta.shape.clone()).ok()?;
1360 return Some(MaxEvaluation {
1361 values: Value::GpuTensor(values),
1362 indices: tensor::tensor_into_value(index_tensor),
1363 });
1364 }
1365 }
1366 let (out_shape, reps_a, reps_b) = broadcast_reps(&a.shape, &b.shape)?;
1368 let a_exp = if reps_a.iter().any(|&r| r != 1) {
1369 provider.repmat(a, &reps_a).ok()?
1370 } else {
1371 a.clone()
1372 };
1373 let b_exp = if reps_b.iter().any(|&r| r != 1) {
1374 provider.repmat(b, &reps_b).ok()?
1375 } else {
1376 b.clone()
1377 };
1378 let values = provider.elem_max(&a_exp, &b_exp).await.ok();
1379 let mask = provider.elem_ge(&a_exp, &b_exp).await.ok();
1380 if !std::ptr::eq(&a_exp, a) {
1381 let _ = provider.free(&a_exp);
1382 }
1383 if !std::ptr::eq(&b_exp, b) {
1384 let _ = provider.free(&b_exp);
1385 }
1386 let values = values?;
1387 if values.shape != out_shape {
1388 let _ = provider.free(&values);
1389 return None;
1390 }
1391 let index_tensor = if let Some(mask) = mask {
1392 let mask_host = gpu_helpers::gather_tensor_async(&mask).await.ok()?;
1393 let _ = provider.free(&mask);
1394 let mut indices = Vec::with_capacity(mask_host.data.len());
1395 for &m in &mask_host.data {
1396 indices.push(if m != 0.0 { 1.0 } else { 2.0 });
1397 }
1398 Tensor::new(indices, out_shape).ok()?
1399 } else {
1400 let ta = gpu_helpers::gather_tensor_async(&a_exp).await.ok()?;
1402 let tb = gpu_helpers::gather_tensor_async(&b_exp).await.ok()?;
1403 let mut indices = Vec::with_capacity(ta.data.len());
1404 for i in 0..ta.data.len() {
1405 indices.push(if ta.data[i] >= tb.data[i] { 1.0 } else { 2.0 });
1406 }
1407 Tensor::new(indices, out_shape).ok()?
1408 };
1409 Some(MaxEvaluation {
1410 values: Value::GpuTensor(values),
1411 indices: tensor::tensor_into_value(index_tensor),
1412 })
1413}
1414
/// Computes the broadcast output shape of `a` and `b` together with the
/// per-dimension repetition counts needed to expand each operand to it.
///
/// Missing trailing dimensions count as 1, and the rank is at least 1. Two
/// extents are compatible when they are equal or one of them is 1; otherwise
/// `None` is returned. A repetition count is 1 where the operand already
/// matches the output extent and the output extent otherwise.
fn broadcast_reps(a: &[usize], b: &[usize]) -> Option<(Vec<usize>, Vec<usize>, Vec<usize>)> {
    let rank = a.len().max(b.len()).max(1);
    let extent_at = |shape: &[usize], axis: usize| shape.get(axis).copied().unwrap_or(1);

    let mut out = Vec::with_capacity(rank);
    let mut reps_a = Vec::with_capacity(rank);
    let mut reps_b = Vec::with_capacity(rank);

    for axis in 0..rank {
        let ad = extent_at(a, axis);
        let bd = extent_at(b, axis);
        let extent = match (ad, bd) {
            _ if ad == bd => ad,
            (1, other) | (other, 1) => other,
            _ => return None,
        };
        out.push(extent);
        reps_a.push(if ad == extent { 1 } else { extent });
        reps_b.push(if bd == extent { 1 } else { extent });
    }

    Some((out, reps_a, reps_b))
}
1444
1445async fn elementwise_max_gpu_scalar_left(
1446 a: &GpuTensorHandle,
1447 other: &Value,
1448 comparison: ComparisonMethod,
1449) -> Option<MaxEvaluation> {
1450 if comparison != ComparisonMethod::Auto {
1451 return None;
1452 }
1453 let provider = runmat_accelerate_api::provider()?;
1454 let scalar = extract_scalar(other)?;
1455 let values = if let Ok(fill) = provider.fill_like(a, scalar) {
1457 let vals = provider.elem_max(a, &fill).await.ok();
1458 let _ = provider.free(&fill);
1459 vals?
1460 } else {
1461 provider.scalar_max(a, scalar).ok()?
1462 };
1463 let index_tensor = if let Ok(fill) = provider.fill_like(a, scalar) {
1465 if let Ok(mask) = provider.elem_ge(a, &fill).await {
1466 let _ = provider.free(&fill);
1467 let indices = gpu_mask_indices(provider, &mask)?;
1468 let _ = provider.free(&mask);
1469 return Some(MaxEvaluation {
1470 values: Value::GpuTensor(values),
1471 indices: Value::GpuTensor(indices),
1472 });
1473 } else {
1474 let _ = provider.free(&fill);
1475 let ta = gpu_helpers::gather_tensor_async(a).await.ok()?;
1476 let mut indices = Vec::with_capacity(ta.data.len());
1477 for &v in &ta.data {
1478 indices.push(if v >= scalar { 1.0 } else { 2.0 });
1479 }
1480 Tensor::new(indices, ta.shape.clone()).ok()?
1481 }
1482 } else {
1483 let ta = gpu_helpers::gather_tensor_async(a).await.ok()?;
1484 let mut indices = Vec::with_capacity(ta.data.len());
1485 for &v in &ta.data {
1486 indices.push(if v >= scalar { 1.0 } else { 2.0 });
1487 }
1488 Tensor::new(indices, ta.shape.clone()).ok()?
1489 };
1490 Some(MaxEvaluation {
1491 values: Value::GpuTensor(values),
1492 indices: tensor::tensor_into_value(index_tensor),
1493 })
1494}
1495
1496async fn elementwise_max_gpu_scalar_right(
1497 other: &Value,
1498 b: &GpuTensorHandle,
1499 comparison: ComparisonMethod,
1500) -> Option<MaxEvaluation> {
1501 if comparison != ComparisonMethod::Auto {
1502 return None;
1503 }
1504 let provider = runmat_accelerate_api::provider()?;
1505 let scalar = extract_scalar(other)?;
1506 let values = if let Ok(fill) = provider.fill_like(b, scalar) {
1507 let vals = provider.elem_max(&fill, b).await.ok();
1508 let _ = provider.free(&fill);
1509 vals?
1510 } else {
1511 provider.scalar_max(b, scalar).ok()?
1512 };
1513 let index_tensor = if let Ok(fill) = provider.fill_like(b, scalar) {
1515 if let Ok(mask) = provider.elem_ge(&fill, b).await {
1516 let _ = provider.free(&fill);
1517 let indices = gpu_mask_indices(provider, &mask)?;
1518 let _ = provider.free(&mask);
1519 return Some(MaxEvaluation {
1520 values: Value::GpuTensor(values),
1521 indices: Value::GpuTensor(indices),
1522 });
1523 } else {
1524 let _ = provider.free(&fill);
1525 let tb = gpu_helpers::gather_tensor_async(b).await.ok()?;
1526 let mut indices = Vec::with_capacity(tb.data.len());
1527 for &v in &tb.data {
1528 indices.push(if scalar >= v { 1.0 } else { 2.0 });
1529 }
1530 Tensor::new(indices, tb.shape.clone()).ok()?
1531 }
1532 } else {
1533 let tb = gpu_helpers::gather_tensor_async(b).await.ok()?;
1534 let mut indices = Vec::with_capacity(tb.data.len());
1535 for &v in &tb.data {
1536 indices.push(if scalar >= v { 1.0 } else { 2.0 });
1537 }
1538 Tensor::new(indices, tb.shape.clone()).ok()?
1539 };
1540 Some(MaxEvaluation {
1541 values: Value::GpuTensor(values),
1542 indices: tensor::tensor_into_value(index_tensor),
1543 })
1544}
1545
1546fn extract_scalar(v: &Value) -> Option<f64> {
1547 match v {
1548 Value::Num(n) => Some(*n),
1549 Value::Int(i) => Some(i.to_f64()),
1550 Value::Bool(b) => Some(if *b { 1.0 } else { 0.0 }),
1551 Value::Tensor(t) if t.data.len() == 1 => t.data.first().copied(),
1552 Value::LogicalArray(l) if l.data.len() == 1 => Some(if l.data[0] != 0 { 1.0 } else { 0.0 }),
1553 _ => None,
1554 }
1555}
1556
1557fn gpu_tensor_is_scalar(handle: &GpuTensorHandle) -> bool {
1558 handle.shape.iter().copied().product::<usize>().max(1) == 1
1559}
1560
1561async fn gpu_tensor_scalar_value(handle: &GpuTensorHandle) -> Option<f64> {
1562 let tensor = gpu_helpers::gather_tensor_async(handle).await.ok()?;
1563 tensor.data.first().copied()
1564}
1565
1566fn gpu_mask_indices(
1567 provider: &dyn AccelProvider,
1568 mask: &GpuTensorHandle,
1569) -> Option<GpuTensorHandle> {
1570 let scaled = provider.scalar_mul(mask, -1.0).ok()?;
1571 let shifted = provider.scalar_add(&scaled, 2.0).ok()?;
1572 let _ = provider.free(&scaled);
1573 Some(shifted)
1574}
1575
1576fn elementwise_real_or_complex(
1577 lhs: Value,
1578 rhs: Value,
1579 comparison: ComparisonMethod,
1580) -> BuiltinResult<MaxEvaluation> {
1581 if let Some(eval) = scalar_elementwise_max(&lhs, &rhs, comparison) {
1582 return Ok(eval);
1583 }
1584 match (
1585 materialize_for_max("max", lhs)?,
1586 materialize_for_max("max", rhs)?,
1587 ) {
1588 (InputData::Complex(a), InputData::Complex(b)) => elementwise_complex_max(a, b, comparison),
1589 (InputData::Complex(a), InputData::Real(b)) => {
1590 let converted = promote_real_tensor_to_complex(b);
1591 elementwise_complex_max(a, converted, comparison)
1592 }
1593 (InputData::Real(a), InputData::Complex(b)) => {
1594 let converted = promote_real_tensor_to_complex(a);
1595 elementwise_complex_max(converted, b, comparison)
1596 }
1597 (InputData::Real(a), InputData::Real(b)) => elementwise_real_max(a, b, comparison),
1598 }
1599}
1600
1601fn scalar_real_value(value: &Value) -> Option<f64> {
1602 match value {
1603 Value::Num(n) => Some(*n),
1604 Value::Int(i) => Some(i.to_f64()),
1605 Value::Bool(b) => Some(if *b { 1.0 } else { 0.0 }),
1606 Value::Tensor(t) if t.data.len() == 1 => t.data.first().copied(),
1607 Value::LogicalArray(l) if l.data.len() == 1 => Some(if l.data[0] != 0 { 1.0 } else { 0.0 }),
1608 _ => None,
1609 }
1610}
1611
1612fn scalar_complex_value(value: &Value) -> Option<(f64, f64)> {
1613 match value {
1614 Value::Complex(re, im) => Some((*re, *im)),
1615 Value::ComplexTensor(ct) if ct.data.len() == 1 => ct.data.first().copied(),
1616 _ => None,
1617 }
1618}
1619
1620fn scalar_elementwise_max(
1621 lhs: &Value,
1622 rhs: &Value,
1623 comparison: ComparisonMethod,
1624) -> Option<MaxEvaluation> {
1625 let left = scalar_complex_value(lhs).or_else(|| scalar_real_value(lhs).map(|v| (v, 0.0)))?;
1626 let right = scalar_complex_value(rhs).or_else(|| scalar_real_value(rhs).map(|v| (v, 0.0)))?;
1627 let (ar, ai) = left;
1628 let (br, bi) = right;
1629 if ai != 0.0 || bi != 0.0 {
1630 let (value, origin) = choose_complex_elementwise((ar, ai), (br, bi), comparison);
1631 return Some(MaxEvaluation {
1632 values: Value::Complex(value.0, value.1),
1633 indices: Value::Num(origin),
1634 });
1635 }
1636 let (value, origin) = choose_real_elementwise(ar, br, comparison);
1637 Some(MaxEvaluation {
1638 values: Value::Num(value),
1639 indices: Value::Num(origin),
1640 })
1641}
1642
1643fn elementwise_real_max(
1644 lhs: Tensor,
1645 rhs: Tensor,
1646 comparison: ComparisonMethod,
1647) -> BuiltinResult<MaxEvaluation> {
1648 let plan = BroadcastPlan::new(&lhs.shape, &rhs.shape).map_err(|err| format!("max: {}", err))?;
1649 let mut values = vec![0.0f64; plan.len()];
1650 let mut indices = vec![0.0f64; plan.len()];
1651
1652 for (offset, index_a, index_b) in plan.iter() {
1653 let a = lhs.data.get(index_a).copied().unwrap_or(f64::NAN);
1654 let b = rhs.data.get(index_b).copied().unwrap_or(f64::NAN);
1655 let (value, origin) = choose_real_elementwise(a, b, comparison);
1656 values[offset] = value;
1657 indices[offset] = origin;
1658 }
1659
1660 let value_tensor = Tensor::new(values, plan.output_shape().to_vec())
1661 .map_err(|e| max_error(format!("max: {e}")))?;
1662 let index_tensor = Tensor::new(indices, plan.output_shape().to_vec())
1663 .map_err(|e| max_error(format!("max: {e}")))?;
1664
1665 Ok(MaxEvaluation {
1666 values: tensor::tensor_into_value(value_tensor),
1667 indices: tensor::tensor_into_value(index_tensor),
1668 })
1669}
1670
1671fn elementwise_complex_max(
1672 lhs: ComplexTensor,
1673 rhs: ComplexTensor,
1674 comparison: ComparisonMethod,
1675) -> BuiltinResult<MaxEvaluation> {
1676 let plan = BroadcastPlan::new(&lhs.shape, &rhs.shape).map_err(|err| format!("max: {}", err))?;
1677 let mut values = vec![(0.0f64, 0.0f64); plan.len()];
1678 let mut indices = vec![0.0f64; plan.len()];
1679
1680 for (offset, index_a, index_b) in plan.iter() {
1681 let a = lhs
1682 .data
1683 .get(index_a)
1684 .copied()
1685 .unwrap_or((f64::NAN, f64::NAN));
1686 let b = rhs
1687 .data
1688 .get(index_b)
1689 .copied()
1690 .unwrap_or((f64::NAN, f64::NAN));
1691 let (value, origin) = choose_complex_elementwise(a, b, comparison);
1692 values[offset] = value;
1693 indices[offset] = origin;
1694 }
1695
1696 let value_tensor = ComplexTensor::new(values, plan.output_shape().to_vec())
1697 .map_err(|e| max_error(format!("max: {e}")))?;
1698 let index_tensor = Tensor::new(indices, plan.output_shape().to_vec())
1699 .map_err(|e| max_error(format!("max: {e}")))?;
1700
1701 Ok(MaxEvaluation {
1702 values: complex_tensor_into_value(value_tensor),
1703 indices: tensor::tensor_into_value(index_tensor),
1704 })
1705}
1706
1707fn promote_real_tensor_to_complex(tensor: Tensor) -> ComplexTensor {
1708 let data = tensor
1709 .data
1710 .iter()
1711 .copied()
1712 .map(|re| (re, 0.0))
1713 .collect::<Vec<_>>();
1714 ComplexTensor {
1715 data,
1716 shape: tensor.shape.clone(),
1717 rows: tensor.rows,
1718 cols: tensor.cols,
1719 }
1720}
1721
1722fn choose_real_elementwise(a: f64, b: f64, comparison: ComparisonMethod) -> (f64, f64) {
1723 match (a.is_nan(), b.is_nan()) {
1724 (true, true) => (f64::NAN, 1.0),
1725 (true, false) => (f64::NAN, 1.0),
1726 (false, true) => (f64::NAN, 2.0),
1727 (false, false) => {
1728 if should_replace_real(a, b, comparison) {
1729 (b, 2.0)
1730 } else {
1731 (a, 1.0)
1732 }
1733 }
1734 }
1735}
1736
1737fn choose_complex_elementwise(
1738 a: (f64, f64),
1739 b: (f64, f64),
1740 comparison: ComparisonMethod,
1741) -> ((f64, f64), f64) {
1742 let a_nan = a.0.is_nan() || a.1.is_nan();
1743 let b_nan = b.0.is_nan() || b.1.is_nan();
1744 match (a_nan, b_nan) {
1745 (true, true) => ((f64::NAN, f64::NAN), 1.0),
1746 (true, false) => ((f64::NAN, f64::NAN), 1.0),
1747 (false, true) => ((f64::NAN, f64::NAN), 2.0),
1748 (false, false) => {
1749 if should_replace_complex(a, b, comparison) {
1750 (b, 2.0)
1751 } else {
1752 (a, 1.0)
1753 }
1754 }
1755 }
1756}
1757
1758#[cfg(test)]
1759pub(crate) mod tests {
1760 use super::*;
1761 #[cfg(feature = "wgpu")]
1762 use crate::builtins::common::test_support;
1763 use futures::executor::block_on;
1764 #[cfg(feature = "wgpu")]
1765 use runmat_accelerate_api::HostTensorView;
1766 use runmat_builtins::{IntValue, Tensor, Value};
1767
1768 fn max_builtin(value: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
1769 block_on(super::max_builtin(value, rest))
1770 }
1771
1772 #[test]
1773 fn max_type_with_two_args_returns_tensor() {
1774 let out = max_type(
1775 &[Type::Tensor { shape: None }, Type::Tensor { shape: None }],
1776 &ResolveContext::new(Vec::new()),
1777 );
1778 assert_eq!(out, Type::tensor());
1779 }
1780
1781 fn evaluate(value: Value, rest: &[Value]) -> BuiltinResult<MaxEvaluation> {
1782 block_on(super::evaluate(value, rest))
1783 }
1784
1785 fn placeholder() -> Value {
1786 let tensor = Tensor::new(Vec::<f64>::new(), vec![0, 0]).unwrap();
1787 Value::Tensor(tensor)
1788 }
1789
1790 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1791 #[test]
1792 fn max_scalar_returns_input() {
1793 let result = max_builtin(Value::Num(5.0), Vec::new()).expect("max");
1794 assert_eq!(result, Value::Num(5.0));
1795 }
1796
1797 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1798 #[test]
1799 fn max_vector_with_indices() {
1800 let tensor = Tensor::new(vec![3.0, 1.0, 5.0], vec![3, 1]).unwrap();
1801 let eval = evaluate(Value::Tensor(tensor), &[]).expect("evaluate");
1802 let (values, indices) = eval.into_pair();
1803 assert_eq!(values, Value::Num(5.0));
1804 assert_eq!(indices, Value::Num(3.0));
1805 }
1806
1807 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1808 #[test]
1809 fn max_matrix_default_dimension() {
1810 let tensor = Tensor::new(vec![3.0, 4.0, 1.0, 2.0, 5.0, 6.0], vec![2, 3]).unwrap();
1811 let eval = evaluate(Value::Tensor(tensor), &[]).expect("evaluate");
1812 let (values, indices) = eval.into_pair();
1813 match values {
1814 Value::Tensor(t) => {
1815 assert_eq!(t.shape, vec![1, 3]);
1816 assert_eq!(t.data, vec![4.0, 2.0, 6.0]);
1817 }
1818 other => panic!("expected tensor, got {other:?}"),
1819 }
1820 match indices {
1821 Value::Tensor(t) => {
1822 assert_eq!(t.data, vec![2.0, 2.0, 2.0]);
1823 }
1824 other => panic!("expected tensor, got {other:?}"),
1825 }
1826 }
1827
1828 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1829 #[test]
1830 fn max_all_linear_index() {
1831 let tensor =
1832 Tensor::new((1..=12).map(|v| v as f64).collect::<Vec<_>>(), vec![3, 4]).unwrap();
1833 let args = vec![placeholder(), Value::from("all")];
1834 let eval = evaluate(Value::Tensor(tensor), &args).expect("evaluate");
1835 let (values, indices) = eval.into_pair();
1836 assert_eq!(values, Value::Num(12.0));
1837 assert_eq!(indices, Value::Num(12.0));
1838
1839 let args_linear = vec![placeholder(), Value::from("linear")];
1840 let eval = evaluate(
1841 Value::Tensor(Tensor::new(vec![2.0, 3.0], vec![1, 2]).unwrap()),
1842 &args_linear,
1843 )
1844 .expect("evaluate");
1845 let (values, indices) = eval.into_pair();
1846 assert_eq!(values, Value::Num(3.0));
1847 assert_eq!(indices, Value::Num(2.0));
1848 }
1849
1850 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1851 #[test]
1852 fn max_with_omitnan() {
1853 let tensor = Tensor::new(vec![f64::NAN, 4.0, 2.0], vec![3, 1]).unwrap();
1854 let args = vec![placeholder(), Value::from("omitnan")];
1855 let eval = evaluate(Value::Tensor(tensor), &args).expect("evaluate");
1856 let (values, indices) = eval.into_pair();
1857 assert_eq!(values, Value::Num(4.0));
1858 assert_eq!(indices, Value::Num(2.0));
1859 }
1860
1861 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1862 #[test]
1863 fn max_omitnan_all_nan_slice() {
1864 let tensor = Tensor::new(vec![f64::NAN, f64::NAN], vec![2, 1]).unwrap();
1865 let args = vec![placeholder(), Value::from("omitnan")];
1866 let eval = evaluate(Value::Tensor(tensor), &args).expect("evaluate");
1867 let (values, indices) = eval.into_pair();
1868 match values {
1869 Value::Num(v) => assert!(v.is_nan()),
1870 other => panic!("expected scalar NaN, got {other:?}"),
1871 }
1872 match indices {
1873 Value::Num(v) => assert!(v.is_nan()),
1874 other => panic!("expected scalar NaN index, got {other:?}"),
1875 }
1876 }
1877
1878 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1879 #[test]
1880 fn max_reduction_abs_comparison() {
1881 let tensor = Tensor::new(vec![1.0, -3.0, -2.0, 4.0], vec![2, 2]).unwrap();
1882 let args = vec![
1883 placeholder(),
1884 Value::from("ComparisonMethod"),
1885 Value::from("abs"),
1886 ];
1887 let eval = evaluate(Value::Tensor(tensor), &args).expect("evaluate");
1888 let (values, indices) = eval.into_pair();
1889 match values {
1890 Value::Tensor(t) => {
1891 assert_eq!(t.shape, vec![1, 2]);
1892 assert_eq!(t.data, vec![-3.0, 4.0]);
1893 }
1894 other => panic!("expected tensor result, got {other:?}"),
1895 }
1896 match indices {
1897 Value::Tensor(t) => {
1898 assert_eq!(t.data, vec![2.0, 2.0]);
1899 }
1900 other => panic!("expected tensor indices, got {other:?}"),
1901 }
1902 }
1903
1904 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1905 #[test]
1906 fn max_reduction_complex_real_comparison() {
1907 let tensor = ComplexTensor::new(vec![(1.0, 2.0), (0.5, 5.0)], vec![2, 1]).expect("tensor");
1908 let args = vec![
1909 placeholder(),
1910 Value::from("ComparisonMethod"),
1911 Value::from("real"),
1912 ];
1913 let eval = evaluate(Value::ComplexTensor(tensor), &args).expect("evaluate");
1914 let (values, indices) = eval.into_pair();
1915 match values {
1916 Value::Complex(re, im) => {
1917 assert!((re - 1.0).abs() < 1e-12);
1918 assert!((im - 2.0).abs() < 1e-12);
1919 }
1920 other => panic!("expected complex scalar, got {other:?}"),
1921 }
1922 assert_eq!(indices, Value::Num(1.0));
1923 }
1924
1925 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1926 #[test]
1927 fn max_elementwise_broadcast() {
1928 let lhs = Tensor::new(vec![1.0, 4.0, 7.0], vec![1, 3]).unwrap();
1929 let rhs = Tensor::new(vec![2.0, 3.0, 5.0], vec![3, 1]).unwrap();
1930 let eval = evaluate(Value::Tensor(lhs), &[Value::Tensor(rhs)]).expect("evaluate");
1931 let (values, indices) = eval.into_pair();
1932 match values {
1933 Value::Tensor(t) => {
1934 assert_eq!(t.shape, vec![3, 3]);
1935 assert_eq!([t.data[0], t.data[3], t.data[6]], [2.0, 4.0, 7.0]);
1936 assert_eq!([t.data[1], t.data[4], t.data[7]], [3.0, 4.0, 7.0]);
1937 assert_eq!([t.data[2], t.data[5], t.data[8]], [5.0, 5.0, 7.0]);
1938 }
1939 other => panic!("expected tensor, got {other:?}"),
1940 }
1941 match indices {
1942 Value::Tensor(t) => {
1943 assert_eq!(t.shape, vec![3, 3]);
1944 assert_eq!([t.data[0], t.data[3], t.data[6]], [2.0, 1.0, 1.0]);
1945 assert_eq!([t.data[1], t.data[4], t.data[7]], [2.0, 1.0, 1.0]);
1946 assert_eq!([t.data[2], t.data[5], t.data[8]], [2.0, 2.0, 1.0]);
1947 }
1948 other => panic!("expected tensor, got {other:?}"),
1949 }
1950 }
1951
1952 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1953 #[test]
1954 fn max_elementwise_abs_comparison() {
1955 let lhs = Tensor::new(vec![-2.0, 1.0], vec![2, 1]).unwrap();
1956 let rhs = Tensor::new(vec![1.5, -3.0], vec![2, 1]).unwrap();
1957 let args = vec![
1958 Value::Tensor(rhs),
1959 Value::from("ComparisonMethod"),
1960 Value::from("abs"),
1961 ];
1962 let eval = evaluate(Value::Tensor(lhs), &args).expect("evaluate");
1963 let (values, indices) = eval.into_pair();
1964 match values {
1965 Value::Tensor(t) => {
1966 assert_eq!(t.data, vec![-2.0, -3.0]);
1967 }
1968 other => panic!("expected tensor, got {other:?}"),
1969 }
1970 match indices {
1971 Value::Tensor(t) => {
1972 assert_eq!(t.data, vec![1.0, 2.0]);
1973 }
1974 other => panic!("expected tensor, got {other:?}"),
1975 }
1976 }
1977
1978 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1979 #[test]
1980 fn max_elementwise_rejects_reduction_only_keywords() {
1981 let lhs = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
1982 let rhs = Tensor::new(vec![3.0, 4.0], vec![2, 1]).unwrap();
1983 let err = evaluate(
1984 Value::Tensor(lhs),
1985 &[Value::Tensor(rhs), Value::from("omitnan")],
1986 )
1987 .expect_err("expected error");
1988 assert!(err.message().contains("only supported for reduction"));
1989 }
1990
1991 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1992 #[test]
1993 fn max_complex_real_comparison() {
1994 let lhs = ComplexTensor::new(vec![(1.0, 2.0)], vec![1, 1]).unwrap();
1995 let rhs = ComplexTensor::new(vec![(0.5, 5.0)], vec![1, 1]).unwrap();
1996 let args = vec![
1997 Value::ComplexTensor(rhs),
1998 Value::from("ComparisonMethod"),
1999 Value::from("real"),
2000 ];
2001 let eval = evaluate(Value::ComplexTensor(lhs), &args).expect("evaluate");
2002 let (values, indices) = eval.into_pair();
2003 assert_eq!(values, Value::Complex(1.0, 2.0));
2004 assert_eq!(indices, Value::Num(1.0));
2005 }
2006
2007 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
2008 #[test]
2009 fn max_dimension_argument_parsing() {
2010 let tensor = Tensor::new(vec![3.0, 4.0, 1.0, 2.0], vec![2, 2]).unwrap();
2011 let dims = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
2012 let args = vec![placeholder(), Value::Tensor(dims)];
2013 let eval = evaluate(Value::Tensor(tensor), &args).expect("evaluate");
2014 let (values, indices) = eval.into_pair();
2015 assert_eq!(values, Value::Num(4.0));
2016 assert_eq!(indices, Value::Num(2.0));
2017 }
2018
2019 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
2020 #[test]
2021 fn max_vecdim_duplicate_entries() {
2022 let tensor = Tensor::new(vec![5.0, 2.0, 7.0, 1.0], vec![2, 2]).unwrap();
2023 let dims = Tensor::new(vec![1.0, 1.0, 2.0], vec![3, 1]).unwrap();
2024 let args = vec![placeholder(), Value::Tensor(dims)];
2025 let eval = evaluate(Value::Tensor(tensor), &args).expect("evaluate");
2026 let (values, indices) = eval.into_pair();
2027 assert_eq!(values, Value::Num(7.0));
2028 assert_eq!(indices, Value::Num(3.0));
2029 }
2030
2031 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
2032 #[test]
2033 fn max_dimension_gpu_argument_errors() {
2034 let tensor = Tensor::new(vec![3.0, 1.0], vec![2, 1]).unwrap();
2035 let dim_handle = Value::GpuTensor(runmat_accelerate_api::GpuTensorHandle {
2036 shape: vec![1, 1],
2037 device_id: 0,
2038 buffer_id: 42,
2039 });
2040 let err = evaluate(Value::Tensor(tensor), &[placeholder(), dim_handle])
2041 .expect_err("expected error");
2042 assert!(err
2043 .message()
2044 .contains("dimension arguments must reside on the host"));
2045 }
2046
2047 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
2048 #[test]
2049 fn max_invalid_comparison_method_errors() {
2050 let tensor = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
2051 let args = vec![
2052 placeholder(),
2053 Value::from("ComparisonMethod"),
2054 Value::from("chebyshev"),
2055 ];
2056 let err = evaluate(Value::Tensor(tensor), &args).expect_err("expected error");
2057 assert!(err.message().contains("unsupported ComparisonMethod"));
2058 }
2059
2060 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
2061 #[test]
2062 #[cfg(feature = "wgpu")]
2063 fn max_gpu_dim1_matches_cpu() {
2064 let tensor = Tensor::new(vec![3.0, 1.0, 2.0, 4.0], vec![2, 2]).unwrap();
2065 let eval_cpu = evaluate(Value::Tensor(tensor.clone()), &[]).expect("cpu");
2066 let (values_cpu, indices_cpu) = eval_cpu.into_pair();
2067
2068 test_support::with_test_provider(|provider| {
2069 let view = HostTensorView {
2070 data: &tensor.data,
2071 shape: &tensor.shape,
2072 };
2073 let handle = provider.upload(&view).expect("upload");
2074 let eval_gpu = evaluate(Value::GpuTensor(handle), &[]).expect("gpu");
2075 let (values_gpu, indices_gpu) = eval_gpu.into_pair();
2076 match (&values_gpu, &indices_gpu) {
2077 (Value::GpuTensor(_), Value::GpuTensor(_)) => {}
2078 other => panic!("expected GPU tensors, got {other:?}"),
2079 }
2080 let gathered_vals = test_support::gather(values_gpu).expect("gather values");
2081 let gathered_idx = test_support::gather(indices_gpu).expect("gather indices");
2082 let expected_vals = match values_cpu {
2083 Value::Tensor(t) => t,
2084 other => panic!("expected tensor values from cpu eval, got {other:?}"),
2085 };
2086 let expected_idx = match indices_cpu {
2087 Value::Tensor(t) => t,
2088 other => panic!("expected tensor indices from cpu eval, got {other:?}"),
2089 };
2090 assert_eq!(gathered_vals.shape, expected_vals.shape);
2091 assert_eq!(gathered_vals.data, expected_vals.data);
2092 assert_eq!(gathered_idx.shape, expected_idx.shape);
2093 assert_eq!(gathered_idx.data, expected_idx.data);
2094 });
2095 }
2096
2097 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
2098 #[test]
2099 fn max_dimension_numeric_argument() {
2100 let tensor = Tensor::new(vec![3.0, 4.0, 1.0, 2.0], vec![2, 2]).unwrap();
2101 let args = vec![placeholder(), Value::Num(2.0)];
2102 let eval = evaluate(Value::Tensor(tensor), &args).expect("evaluate");
2103 let (values, indices) = eval.into_pair();
2104 match values {
2105 Value::Tensor(t) => {
2106 assert_eq!(t.shape, vec![2, 1]);
2107 assert_eq!(t.data, vec![3.0, 4.0]);
2108 }
2109 other => panic!("expected tensor, got {other:?}"),
2110 }
2111 match indices {
2112 Value::Tensor(t) => {
2113 assert_eq!(t.data, vec![1.0, 1.0]);
2114 }
2115 other => panic!("expected tensor, got {other:?}"),
2116 }
2117 }
2118
2119 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
2120 #[test]
2121 fn max_complex_auto_comparison() {
2122 let lhs = ComplexTensor::new(vec![(1.0, 2.0)], vec![1, 1]).unwrap();
2123 let rhs = ComplexTensor::new(vec![(2.0, 1.0)], vec![1, 1]).unwrap();
2124 let eval =
2125 evaluate(Value::ComplexTensor(lhs), &[Value::ComplexTensor(rhs)]).expect("evaluate");
2126 let (values, indices) = eval.into_pair();
2127 assert_eq!(values, Value::Complex(1.0, 2.0));
2128 assert_eq!(indices, Value::Num(1.0));
2129 }
2130
2131 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
2132 #[test]
2133 fn max_scalar_pair_arguments() {
2134 let args = vec![Value::Num(2.0)];
2135 let result = max_builtin(Value::Num(3.0), args).expect("max");
2136 assert_eq!(result, Value::Num(3.0));
2137 }
2138
2139 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
2140 #[test]
2141 fn max_rejects_invalid_dimension() {
2142 let tensor = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap();
2143 let args = vec![placeholder(), Value::Int(IntValue::I32(0))];
2144 let err = evaluate(Value::Tensor(tensor), &args).expect_err("expected error");
2145 assert!(err.message().contains("dimension must be >= 1"));
2146 }
2147}