1use std::cmp::Ordering;
4use std::collections::BTreeSet;
5
6use runmat_accelerate_api::{AccelProvider, GpuTensorHandle, ReduceDimResult};
7use runmat_builtins::{ComplexTensor, Tensor, Value};
8use runmat_macros::runtime_builtin;
9
10use crate::builtins::common::broadcast::BroadcastPlan;
11use crate::builtins::common::random_args::{complex_tensor_into_value, keyword_of};
12use crate::builtins::common::{gpu_helpers, tensor};
13#[cfg(feature = "doc_export")]
14use crate::register_builtin_doc_text;
15use crate::{register_builtin_fusion_spec, register_builtin_gpu_spec};
16
17use crate::builtins::common::spec::{
18 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, FusionError,
19 FusionExprContext, FusionKernelTemplate, GpuOpKind, ProviderHook, ReductionNaN,
20 ResidencyPolicy, ScalarType, ShapeRequirements,
21};
22
23#[cfg(feature = "doc_export")]
24pub const DOC_MD: &str = r#"---
25title: "max"
26category: "math/reduction"
27keywords: ["max", "maximum", "reduction", "comparisonmethod", "omitnan", "gpu"]
28summary: "Return the maximum elements of scalars, vectors, matrices, or N-D tensors with MATLAB-compatible options."
29references: []
30gpu_support:
31 elementwise: false
32 reduction: true
33 precisions: ["f32", "f64"]
34 broadcasting: "matlab"
35 notes: "Uses provider reduce_max_dim / reduce_max when available. Fallback gathers data to the host for omitnan, custom comparison modes, or complex inputs."
36fusion:
37 elementwise: false
38 reduction: true
39 max_inputs: 1
40 constants: "inline"
41requires_feature: null
42tested:
43 unit: "builtins::math::reduction::max::tests"
44 integration: "builtins::math::reduction::max::tests::max_gpu_dim1_matches_cpu"
45---
46
47# What does the `max` function do in MATLAB / RunMat?
48`max` returns the largest values in its input while preserving MATLAB semantics for reductions, elementwise comparisons, NaN handling, complex magnitude comparisons, and linear indexing.
49
50## How does the `max` function behave in MATLAB / RunMat?
- `max(X)` on an `m × n` array reduces along the first non-singleton dimension, returning a row vector of column-wise maxima and the corresponding indices (when requested).
52- `max(X, [], dim)` reduces along the specified dimension; `max(X, [], vecdim)` reduces along each dimension listed in `vecdim`.
53- `max(X, [], 'all')` collapses every element into a scalar and returns the linear index when two outputs are requested.
54- `max(X, [], 'linear')` is equivalent to `'all'` but guarantees that the matching index is linear over `X(:)`.
55- `max(X, [], ..., 'omitnan')` ignores `NaN` values inside each slice. If every element in a slice is `NaN`, the result for that slice is `NaN` and the index is `NaN`.
56- `max(X, [], ..., 'includenan')` (default) propagates `NaN` whenever a slice contains any `NaN` element, returning the index of the first `NaN`.
57- `max(A, B)` performs elementwise comparison using MATLAB's implicit expansion rules. The second output indicates whether the maximum came from `A` (index `1`) or `B` (index `2`).
58- Complex inputs follow MATLAB ordering: `'ComparisonMethod','auto'` (default) compares magnitudes and breaks ties using phase angles, while `'real'` compares real components first. `'abs'` is an explicit alias for magnitude ordering on real and complex inputs.
59
60## `max` Function GPU Execution Behaviour
61When RunMat Accelerate is active, tensors that already reside on the GPU stay on the device whenever the provider exposes `reduce_max_dim` (for dimension reductions) or `reduce_max` (for whole-array reductions). Requests that require `omitnan`, custom comparison modes, `'linear'` indices, or complex arithmetic gather the data to the host, compute the MATLAB-compatible result, and return the output on the host. Elementwise `max(A, B)` currently executes on the host; the planner rematerializes tensors on the GPU when follow-on fused kernels make it profitable.
62
63## Examples of using the `max` function in MATLAB / RunMat
64
65### Finding column-wise maxima of a matrix
66```matlab
67A = [3 1 5; 4 2 6];
68[m, idx] = max(A);
69```
70Expected output:
71```matlab
72m = [4 2 6];
73idx = [2 2 2];
74```
75
76### Reducing along the second dimension
77```matlab
78A = [3 1 5; 4 2 6];
79[m, idx] = max(A, [], 2);
80```
81Expected output:
82```matlab
83m = [5; 6];
84idx = [3; 3];
85```
86
87### Collapsing all elements with linear indices
88```matlab
89A = reshape(1:12, [3 4]);
90[m, idx] = max(A, [], 'all');
91```
92Expected output:
93```matlab
94m = 12;
95idx = 12; % linear index into A(:)
96```
97
98### Ignoring NaN values during reduction
99```matlab
100values = [NaN 4 2; 3 NaN 1];
101[m, idx] = max(values, [], 1, 'omitnan');
102```
103Expected output:
104```matlab
105m = [3 4 2];
106idx = [2 1 1];
107```
108
109### Elementwise maximum with broadcasting
110```matlab
111A = [1 4 7];
112B = [2; 3; 5];
113[C, origin] = max(A, B);
114```
115Expected output:
116```matlab
117C =
118 2 4 7
119 3 4 7
120 5 5 7
121
122origin =
123 2 1 1
124 2 1 1
125 2 2 1
126```
127
128### Comparing complex values by magnitude
129```matlab
130Z = [1+2i, 2+1i, -2+2i];
131M = max(Z); % magnitude ordering
132R = max(Z, [], 'ComparisonMethod', 'real');
133```
134Expected output:
135```matlab
136M = -2.0000 + 2.0000i
137R = 2.0000 + 1.0000i
138```
139
140## GPU residency in RunMat (Do I need `gpuArray`?)
141You typically do **not** need to call `gpuArray` manually. The fusion planner keeps tensors on the GPU between compatible kernels. When a reduction is supported by the active provider, the maximum values and indices stay on device. If a provider lacks the necessary hook, RunMat gathers data to the host, computes the result, and returns host tensors—subsequent fused GPU kernels can re-upload data when profitable.
142
143## FAQ
144
145### Can I request the linear index of the global maximum?
146Yes. Use either `max(X, [], 'all')` or `max(X, [], 'linear')`. Both return a scalar maximum and the linear index into `X(:)` when you request two outputs.
147
148### Does `max` support `'ComparisonMethod'` for real and complex arrays?
149Absolutely. `'auto'` or `'abs'` compare magnitudes; `'real'` compares the real component first. The returned values always match MATLAB, including tie-breaking rules.
150
151### What happens when all elements are `NaN` and `'omitnan'` is requested?
152The value result is `NaN` and the index is `NaN`, matching MATLAB's behavior for empty slices.
153
154### Can I mix elementwise comparisons with dimension reductions?
155No. `max(A, B)` performs elementwise comparisons only. Use `max(A, [], dim)` when you want reductions along specific dimensions.
156
157### Do GPU reductions support `'omitnan'` or custom comparison methods?
158Not yet. Those requests fall back to the host implementation, which still honors MATLAB semantics. The output remains a host tensor in that case.
159
160### Are logical and integer inputs supported?
161Yes. Logical arrays are promoted to double precision, and integer inputs are converted to double before comparison, matching MATLAB's numeric tower.
162
163## See Also
164[min](./min), [sum](./sum), [mean](./mean), [gpuArray](../../acceleration/gpu/gpuArray), [gather](../../acceleration/gpu/gather)
165
166## Source & Feedback
167- The full source code for the implementation of the `max` function is available at: [`crates/runmat-runtime/src/builtins/math/reduction/max.rs`](https://github.com/runmat-org/runmat/blob/main/crates/runmat-runtime/src/builtins/math/reduction/max.rs)
168- Found a bug or behavioral difference? Please [open an issue](https://github.com/runmat-org/runmat/issues/new/choose) with details and a minimal repro.
169"#;
/// GPU execution metadata for the `max` builtin.
///
/// Describes `max` as a reduction-kind operation supporting `f32`/`f64`, with
/// two provider hooks (`reduce_max_dim` and `reduce_max`) and NaN-including
/// semantics. The `notes` field records which requests fall back to the host.
/// The descriptor is registered immediately below via
/// `register_builtin_gpu_spec!`.
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "max",
    op_kind: GpuOpKind::Reduction,
    supported_precisions: &[ScalarType::F32, ScalarType::F64],
    broadcast: BroadcastSemantics::Matlab,
    provider_hooks: &[
        ProviderHook::Reduction {
            name: "reduce_max_dim",
        },
        ProviderHook::Reduction {
            name: "reduce_max",
        },
    ],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::NewHandle,
    nan_mode: ReductionNaN::Include,
    // NOTE(review): presumably element-count / thread-count tuning hints for
    // provider kernels; confirm against the `BuiltinGpuSpec` definition.
    two_pass_threshold: Some(256),
    workgroup_size: Some(256),
    accepts_nan_mode: false,
    notes:
        "Providers should implement reduce_max_dim / reduce_max. Requests that require omitnan, comparisonmethod overrides, or complex inputs fall back to the host implementation.",
};

register_builtin_gpu_spec!(GPU_SPEC);
194
/// Fusion metadata for the `max` builtin.
///
/// Declares no elementwise template and a single reduction template whose
/// WGSL body folds the first input into `accumulator` with `max`. The
/// descriptor is registered immediately below via
/// `register_builtin_fusion_spec!`.
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "max",
    shape: ShapeRequirements::BroadcastCompatible,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: Some(FusionKernelTemplate {
        scalar_precisions: &[ScalarType::F32, ScalarType::F64],
        // Fails with `FusionError::MissingInput(0)` when the context carries
        // no inputs; otherwise emits one accumulator-update statement.
        wgsl_body: |ctx: &FusionExprContext| {
            let input = ctx.inputs.first().ok_or(FusionError::MissingInput(0))?;
            Ok(format!("accumulator = max(accumulator, {input});"))
        },
    }),
    emits_nan: true,
    notes: "Fusion planner emits canonical reduction kernels; providers may substitute custom WGSL via reduce_max_dim hooks.",
};

register_builtin_fusion_spec!(FUSION_SPEC);

// The Markdown documentation is only registered when docs are being exported.
#[cfg(feature = "doc_export")]
register_builtin_doc_text!("max", DOC_MD);
215
216#[derive(Debug, Clone)]
218pub struct MaxEvaluation {
219 values: Value,
220 indices: Value,
221}
222
223impl MaxEvaluation {
224 pub fn into_value(self) -> Value {
226 self.values
227 }
228
229 pub fn into_pair(self) -> (Value, Value) {
231 (self.values, self.indices)
232 }
233
234 pub fn indices_value(&self) -> Value {
236 self.indices.clone()
237 }
238}
239
240#[runtime_builtin(
241 name = "max",
242 category = "math/reduction",
243 summary = "Return the maximum elements of scalars, vectors, matrices, or N-D tensors.",
244 keywords = "max,maximum,reduction,gpu,comparisonmethod,omitnan",
245 accel = "reduction"
246)]
247fn max_builtin(value: Value, rest: Vec<Value>) -> Result<Value, String> {
248 evaluate(value, &rest).map(|eval| eval.into_value())
249}
250
251pub fn evaluate(value: Value, rest: &[Value]) -> Result<MaxEvaluation, String> {
253 let parsed = parse_call(rest)?;
254 if std::env::var("RUNMAT_DEBUG_MAX").is_ok() {
255 let call_label = match &parsed {
256 ParsedCall::Reduction(_) => "reduction",
257 ParsedCall::Elementwise(_) => "elementwise",
258 };
259 let first_arg = rest.first().map(debug_value_kind).unwrap_or("None");
260 eprintln!(
261 "[runmat-debug-max] call_type={call_label} rest_len={} first_arg={first_arg}",
262 rest.len()
263 );
264 }
265 match parsed {
266 ParsedCall::Elementwise(args) => elementwise_max(value, args),
267 ParsedCall::Reduction(args) => reduction_max(value, args),
268 }
269}
270
/// Outcome of argument parsing: which of the two `max` call forms was requested.
#[derive(Debug, Clone)]
enum ParsedCall {
    /// Reduction form (no extra arguments, or an empty placeholder as the
    /// second argument), carrying its parsed options.
    Reduction(ReductionArgs),
    /// Two-operand form (a non-empty second argument), carrying the other
    /// operand and comparison method.
    Elementwise(ElementwiseArgs),
}
276
/// Options that control a reduction-style `max` call.
#[derive(Debug, Clone)]
struct ReductionArgs {
    /// Which dimension(s) to reduce over.
    selection: DimSelection,
    /// Whether NaN values are included or omitted (`'includenan'` / `'omitnan'`).
    nan_mode: ReductionNaN,
    /// Ordering selected through `'ComparisonMethod'`.
    comparison: ComparisonMethod,
    /// Set by the `'linear'` keyword; makes the reported indices linear
    /// indices into the whole input.
    linear_index: bool,
}

impl Default for ReductionArgs {
    /// Defaults used when no options are supplied: automatic dimension
    /// selection, NaN included, `Auto` comparison, non-linear indices.
    fn default() -> Self {
        Self {
            selection: DimSelection::Auto,
            nan_mode: ReductionNaN::Include,
            comparison: ComparisonMethod::Auto,
            linear_index: false,
        }
    }
}
295
/// Dimension selection for a reduction.
#[derive(Debug, Clone)]
enum DimSelection {
    /// Pick the first dimension whose length exceeds one (falling back to the
    /// first dimension when there is none).
    Auto,
    /// A single one-based dimension.
    Dim(usize),
    /// A list of one-based dimensions to reduce together.
    Vec(Vec<usize>),
    /// Reduce over every dimension.
    All,
}
303
/// Ordering requested through the `'ComparisonMethod'` option.
///
/// The comparison logic itself lives outside this enum; the variants only
/// record which keyword was parsed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ComparisonMethod {
    /// `'auto'`, also the default when the option is absent.
    Auto,
    /// `'real'`.
    Real,
    /// `'abs'` (the parser also accepts `'magnitude'`).
    Abs,
}
310
/// Arguments of a two-operand `max(A, B, ...)` call.
#[derive(Debug, Clone)]
struct ElementwiseArgs {
    /// The second operand (`B`), cloned from the argument list.
    other: Value,
    /// Ordering selected through `'ComparisonMethod'`.
    comparison: ComparisonMethod,
}
316
317fn parse_call(rest: &[Value]) -> Result<ParsedCall, String> {
318 if rest.is_empty() {
319 return Ok(ParsedCall::Reduction(ReductionArgs::default()));
320 }
321
322 let first = &rest[0];
323 if !is_empty_placeholder(first) {
324 let comparison = parse_elementwise_options(&rest[1..])?;
325 return Ok(ParsedCall::Elementwise(ElementwiseArgs {
326 other: first.clone(),
327 comparison,
328 }));
329 }
330
331 let mut args = ReductionArgs::default();
332 parse_reduction_options(&mut args, &rest[1..])?;
333 Ok(ParsedCall::Reduction(args))
334}
335
336fn debug_value_kind(value: &Value) -> &'static str {
337 match value {
338 Value::Num(_) => "Num",
339 Value::Int(_) => "Int",
340 Value::Bool(_) => "Bool",
341 Value::Tensor(t) => {
342 if t.data.is_empty() {
343 "Tensor(empty)"
344 } else {
345 "Tensor"
346 }
347 }
348 Value::GpuTensor(_) => "GpuTensor",
349 Value::String(_) => "String",
350 Value::CharArray(_) => "CharArray",
351 Value::StringArray(sa) => {
352 if sa.data.is_empty() {
353 "StringArray(empty)"
354 } else {
355 "StringArray"
356 }
357 }
358 Value::LogicalArray(l) => {
359 if l.data.is_empty() {
360 "LogicalArray(empty)"
361 } else {
362 "LogicalArray"
363 }
364 }
365 Value::Cell(c) => {
366 if c.data.is_empty() {
367 "Cell(empty)"
368 } else {
369 "Cell"
370 }
371 }
372 _ => "Other",
373 }
374}
375
376fn is_empty_placeholder(value: &Value) -> bool {
377 match value {
378 Value::Tensor(t) => t.data.is_empty(),
379 Value::LogicalArray(l) => l.data.is_empty(),
380 Value::StringArray(sa) => sa.data.is_empty(),
381 Value::CharArray(ca) => ca.data.is_empty(),
382 Value::Cell(cell) => cell.data.is_empty(),
383 Value::String(s) => s.is_empty(),
384 _ => false,
385 }
386}
387
388fn parse_reduction_options(args: &mut ReductionArgs, rest: &[Value]) -> Result<(), String> {
389 let mut idx = 0usize;
390 let mut selection_set = !matches!(args.selection, DimSelection::Auto);
391 let mut comparison_set = matches!(args.comparison, ComparisonMethod::Auto);
392 while idx < rest.len() {
393 if let Some(keyword) = keyword_of(&rest[idx]) {
394 match keyword.as_str() {
395 "omitnan" => {
396 args.nan_mode = ReductionNaN::Omit;
397 idx += 1;
398 continue;
399 }
400 "includenan" => {
401 args.nan_mode = ReductionNaN::Include;
402 idx += 1;
403 continue;
404 }
405 "all" => {
406 if selection_set {
407 return Err(
408 "max: 'all' cannot be combined with an explicit dimension".to_string()
409 );
410 }
411 args.selection = DimSelection::All;
412 selection_set = true;
413 idx += 1;
414 continue;
415 }
416 "linear" => {
417 if selection_set {
418 return Err(
419 "max: 'linear' cannot be combined with an explicit dimension"
420 .to_string(),
421 );
422 }
423 args.selection = DimSelection::All;
424 args.linear_index = true;
425 selection_set = true;
426 idx += 1;
427 continue;
428 }
429 "comparisonmethod" => {
430 let Some(value) = rest.get(idx + 1) else {
431 return Err("max: expected a value after 'ComparisonMethod'".to_string());
432 };
433 args.comparison = parse_comparison_method(value)?;
434 comparison_set = true;
435 idx += 2;
436 continue;
437 }
438 _ => {}
439 }
440 }
441
442 if !selection_set {
443 if let Some(selection) = parse_dimension_value(&rest[idx])? {
444 args.selection = selection;
445 selection_set = true;
446 idx += 1;
447 continue;
448 }
449 }
450
451 return Err(format!("max: unrecognised argument {:?}", rest[idx]));
452 }
453
454 if !comparison_set {
455 args.comparison = ComparisonMethod::Auto;
456 }
457
458 Ok(())
459}
460
461fn parse_elementwise_options(rest: &[Value]) -> Result<ComparisonMethod, String> {
462 let mut comparison = ComparisonMethod::Auto;
463 let mut comparison_set = false;
464 let mut idx = 0usize;
465 while idx < rest.len() {
466 if let Some(keyword) = keyword_of(&rest[idx]) {
467 match keyword.as_str() {
468 "comparisonmethod" => {
469 let Some(value) = rest.get(idx + 1) else {
470 return Err("max: expected a value after 'ComparisonMethod'".to_string());
471 };
472 comparison = parse_comparison_method(value)?;
473 comparison_set = true;
474 idx += 2;
475 continue;
476 }
477 "omitnan" | "includenan" | "all" | "linear" => {
478 return Err(format!(
479 "max: '{}' is only supported for reduction calls",
480 keyword
481 ));
482 }
483 _ => {}
484 }
485 }
486 return Err(format!("max: unrecognised argument {:?}", rest[idx]));
487 }
488 if !comparison_set {
489 comparison = ComparisonMethod::Auto;
490 }
491 Ok(comparison)
492}
493
494fn parse_comparison_method(value: &Value) -> Result<ComparisonMethod, String> {
495 let Some(keyword) = keyword_of(value) else {
496 return Err("max: 'ComparisonMethod' expects a string value".to_string());
497 };
498 match keyword.as_str() {
499 "auto" => Ok(ComparisonMethod::Auto),
500 "abs" | "magnitude" => Ok(ComparisonMethod::Abs),
501 "real" => Ok(ComparisonMethod::Real),
502 other => Err(format!("max: unsupported ComparisonMethod '{other}'")),
503 }
504}
505
506fn parse_dimension_value(value: &Value) -> Result<Option<DimSelection>, String> {
507 match value {
508 Value::Int(i) => {
509 let raw = i.to_i64();
510 if raw < 1 {
511 return Err("max: dimension must be >= 1".to_string());
512 }
513 Ok(Some(DimSelection::Dim(raw as usize)))
514 }
515 Value::Num(n) => {
516 if !n.is_finite() {
517 return Err("max: dimension must be finite".to_string());
518 }
519 let rounded = n.round();
520 if (rounded - n).abs() > f64::EPSILON {
521 return Err("max: dimension must be integral".to_string());
522 }
523 if rounded < 1.0 {
524 return Err("max: dimension must be >= 1".to_string());
525 }
526 Ok(Some(DimSelection::Dim(rounded as usize)))
527 }
528 Value::Tensor(t) => parse_dimension_tensor(t),
529 Value::LogicalArray(logical) => {
530 let tensor = tensor::logical_to_tensor(logical)?;
531 parse_dimension_tensor(&tensor)
532 }
533 Value::GpuTensor(_) => Err(
534 "max: dimension arguments must reside on the host (they cannot be gpuArray values)"
535 .to_string(),
536 ),
537 _ => Ok(None),
538 }
539}
540
541fn parse_dimension_tensor(tensor: &Tensor) -> Result<Option<DimSelection>, String> {
542 if tensor.data.is_empty() {
543 return Ok(Some(DimSelection::Auto));
544 }
545 if tensor.rows() != 1 && tensor.cols() != 1 && tensor.shape.len() != 1 {
546 return Err("max: dimension vector must be a row or column vector".to_string());
547 }
548 let mut dims = Vec::with_capacity(tensor.data.len());
549 for &value in &tensor.data {
550 if !value.is_finite() {
551 return Err("max: dimension entries must be finite".to_string());
552 }
553 let rounded = value.round();
554 if (rounded - value).abs() > f64::EPSILON {
555 return Err("max: dimension entries must be integers".to_string());
556 }
557 if rounded < 1.0 {
558 return Err("max: dimension indices must be >= 1".to_string());
559 }
560 dims.push(rounded as usize);
561 }
562 if dims.is_empty() {
563 Ok(Some(DimSelection::Auto))
564 } else {
565 let mut seen = BTreeSet::new();
567 let mut uniq = Vec::with_capacity(dims.len());
568 for dim in dims {
569 if seen.insert(dim) {
570 uniq.push(dim);
571 }
572 }
573 Ok(Some(DimSelection::Vec(uniq)))
574 }
575}
576
577fn reduction_max(value: Value, args: ReductionArgs) -> Result<MaxEvaluation, String> {
578 match value {
579 Value::GpuTensor(handle) => {
580 if let Some(eval) = reduction_max_gpu(handle.clone(), &args)? {
581 return Ok(eval);
582 }
583 let tensor = gpu_helpers::gather_tensor(&handle)?;
585 reduction_max_host(Value::Tensor(tensor), &args)
586 }
587 other => reduction_max_host(other, &args),
588 }
589}
590
591fn reduction_max_gpu(
592 handle: GpuTensorHandle,
593 args: &ReductionArgs,
594) -> Result<Option<MaxEvaluation>, String> {
595 #[cfg(all(test, feature = "wgpu"))]
596 {
597 if handle.device_id != 0 {
598 let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
599 runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
600 );
601 }
602 }
603 if args.nan_mode == ReductionNaN::Omit {
604 return Ok(None);
605 }
606 if args.comparison != ComparisonMethod::Auto {
607 return Ok(None);
608 }
609 if args.linear_index {
610 return Ok(None);
611 }
612 let provider = match runmat_accelerate_api::provider() {
613 Some(p) => p,
614 None => return Ok(None),
615 };
616 let target_dim = match args.selection {
617 DimSelection::Auto => default_dimension_from_shape(&handle.shape),
618 DimSelection::Dim(dim) => dim,
619 DimSelection::Vec(ref dims) if dims.len() == 1 => dims[0],
620 DimSelection::All => {
621 if handle.shape.len() <= 1 {
622 1
623 } else {
624 return Ok(None);
625 }
626 }
627 _ => return Ok(None),
628 };
629 if target_dim == 0 {
630 return Ok(None);
631 }
632 let zero_based = target_dim.saturating_sub(1);
634 if zero_based >= handle.shape.len() {
635 return Ok(None);
636 }
637 match provider.reduce_max_dim(&handle, zero_based) {
638 Ok(ReduceDimResult { values, indices }) => Ok(Some(MaxEvaluation {
639 values: Value::GpuTensor(values),
640 indices: Value::GpuTensor(indices),
641 })),
642 Err(_) => Ok(None),
643 }
644}
645
646fn reduction_max_host(value: Value, args: &ReductionArgs) -> Result<MaxEvaluation, String> {
647 match materialize_for_max("max", value)? {
648 InputData::Real(tensor) => reduce_real_tensor(tensor, args),
649 InputData::Complex(tensor) => reduce_complex_tensor(tensor, args),
650 }
651}
652
/// Host-side input after materialisation: either real or complex data.
enum InputData {
    /// Real-valued tensor (also used for logical, integer, and scalar inputs).
    Real(Tensor),
    /// Complex-valued tensor (also used for complex scalars).
    Complex(ComplexTensor),
}
657
658fn materialize_for_max(name: &str, value: Value) -> Result<InputData, String> {
659 match value {
660 Value::Tensor(t) => Ok(InputData::Real(t)),
661 Value::LogicalArray(logical) => {
662 let tensor = tensor::logical_to_tensor(&logical)?;
663 Ok(InputData::Real(tensor))
664 }
665 Value::Num(n) => {
666 let tensor = Tensor::new(vec![n], vec![1, 1]).map_err(|e| format!("{name}: {e}"))?;
667 Ok(InputData::Real(tensor))
668 }
669 Value::Int(i) => {
670 let tensor =
671 Tensor::new(vec![i.to_f64()], vec![1, 1]).map_err(|e| format!("{name}: {e}"))?;
672 Ok(InputData::Real(tensor))
673 }
674 Value::Bool(b) => {
675 let tensor = Tensor::new(vec![if b { 1.0 } else { 0.0 }], vec![1, 1])
676 .map_err(|e| format!("{name}: {e}"))?;
677 Ok(InputData::Real(tensor))
678 }
679 Value::Complex(re, im) => {
680 let tensor = ComplexTensor::new(vec![(re, im)], vec![1, 1])
681 .map_err(|e| format!("{name}: {e}"))?;
682 Ok(InputData::Complex(tensor))
683 }
684 Value::ComplexTensor(ct) => Ok(InputData::Complex(ct)),
685 Value::String(_) | Value::StringArray(_) | Value::CharArray(_) | Value::Cell(_) => Err(
686 format!("{name}: expected numeric or logical input, received non-numeric value"),
687 ),
688 Value::GpuTensor(_) => Err(format!(
689 "{name}: internal error – GPU tensors must be gathered before host execution"
690 )),
691 Value::Object(_) | Value::HandleObject(_) | Value::Struct(_) | Value::Listener(_) => {
692 Err(format!("{name}: unsupported input type"))
693 }
694 Value::FunctionHandle(_)
695 | Value::Closure(_)
696 | Value::ClassRef(_)
697 | Value::MException(_) => Err(format!("{name}: unsupported input type")),
698 }
699}
700
701fn reduce_real_tensor(tensor: Tensor, args: &ReductionArgs) -> Result<MaxEvaluation, String> {
702 let shape = tensor.shape.clone();
703 if tensor.data.is_empty() {
704 let output_shape = resolve_output_shape(&shape, &args.selection, &[])?;
705 let values =
706 Tensor::new(Vec::new(), output_shape.clone()).map_err(|e| format!("max: {e}"))?;
707 let indices = Tensor::new(Vec::new(), output_shape).map_err(|e| format!("max: {e}"))?;
708 return Ok(MaxEvaluation {
709 values: tensor::tensor_into_value(values),
710 indices: tensor::tensor_into_value(indices),
711 });
712 }
713 let resolved = resolve_reduction_dims(&shape, &args.selection)?;
714 let output_shape = resolved.output_shape.clone();
715 let output_len = tensor::element_count(&output_shape);
716
717 if output_len == 0 {
718 let values =
719 Tensor::new(Vec::new(), output_shape.clone()).map_err(|e| format!("max: {e}"))?;
720 let indices = Tensor::new(Vec::new(), output_shape).map_err(|e| format!("max: {e}"))?;
721 return Ok(MaxEvaluation {
722 values: tensor::tensor_into_value(values),
723 indices: tensor::tensor_into_value(indices),
724 });
725 }
726
727 let strides = compute_strides(&shape);
728 let output_strides = compute_strides(&output_shape);
729 let dims_mask = resolved.dims_mask.clone();
730 let reduce_strides = resolved.reduce_strides.clone();
731
732 let mut best = vec![BestReal::new(); output_len];
733 let mut coords = vec![0usize; shape.len()];
734 for &value in &tensor.data {
735 let out_idx = map_output_index(&coords, &output_strides, &dims_mask);
736 let reduce_idx = map_reduce_index(
737 &coords,
738 &resolved.reduced_dims,
739 &reduce_strides,
740 resolved.reduce_all,
741 );
742 let full_idx = map_linear_index(&coords, &strides);
743
744 update_best_real(
745 &mut best[out_idx],
746 value,
747 reduce_idx,
748 full_idx,
749 args.nan_mode,
750 args.comparison,
751 );
752 increment_coords(&mut coords, &shape);
753 }
754
755 let mut values = vec![0.0f64; output_len];
756 let mut indices = vec![0.0f64; output_len];
757
758 for (i, entry) in best.iter().enumerate() {
759 if entry.nan_fixed {
760 values[i] = f64::NAN;
761 indices[i] = if args.linear_index || resolved.reduce_all {
762 (entry.full_index + 1) as f64
763 } else if resolved.reduced_dims.is_empty() {
764 1.0
765 } else {
766 (entry.reduce_index + 1) as f64
767 };
768 continue;
769 }
770 if !entry.has_value {
771 values[i] = f64::NAN;
772 indices[i] = f64::NAN;
773 continue;
774 }
775 values[i] = entry.value;
776 indices[i] = if args.linear_index || resolved.reduce_all {
777 (entry.full_index + 1) as f64
778 } else if resolved.reduced_dims.is_empty() {
779 1.0
780 } else {
781 (entry.reduce_index + 1) as f64
782 };
783 }
784
785 let value_tensor =
786 Tensor::new(values, output_shape.clone()).map_err(|e| format!("max: {e}"))?;
787 let index_tensor = Tensor::new(indices, output_shape).map_err(|e| format!("max: {e}"))?;
788
789 Ok(MaxEvaluation {
790 values: tensor::tensor_into_value(value_tensor),
791 indices: tensor::tensor_into_value(index_tensor),
792 })
793}
794
795fn reduce_complex_tensor(
796 tensor: ComplexTensor,
797 args: &ReductionArgs,
798) -> Result<MaxEvaluation, String> {
799 let shape = tensor.shape.clone();
800 if tensor.data.is_empty() {
801 let output_shape = resolve_output_shape(&shape, &args.selection, &[])?;
802 let values = ComplexTensor::new(Vec::new(), output_shape.clone())
803 .map_err(|e| format!("max: {e}"))?;
804 let indices = Tensor::new(Vec::new(), output_shape).map_err(|e| format!("max: {e}"))?;
805 return Ok(MaxEvaluation {
806 values: complex_tensor_into_value(values),
807 indices: tensor::tensor_into_value(indices),
808 });
809 }
810
811 let resolved = resolve_reduction_dims(&shape, &args.selection)?;
812 let output_shape = resolved.output_shape.clone();
813 let output_len = tensor::element_count(&output_shape);
814
815 if output_len == 0 {
816 let values = ComplexTensor::new(Vec::new(), output_shape.clone())
817 .map_err(|e| format!("max: {e}"))?;
818 let indices = Tensor::new(Vec::new(), output_shape).map_err(|e| format!("max: {e}"))?;
819 return Ok(MaxEvaluation {
820 values: complex_tensor_into_value(values),
821 indices: tensor::tensor_into_value(indices),
822 });
823 }
824
825 let strides = compute_strides(&shape);
826 let output_strides = compute_strides(&output_shape);
827 let dims_mask = resolved.dims_mask.clone();
828 let reduce_strides = resolved.reduce_strides.clone();
829
830 let mut best = vec![BestComplex::new(); output_len];
831 let mut coords = vec![0usize; shape.len()];
832
833 for &(re, im) in &tensor.data {
834 let out_idx = map_output_index(&coords, &output_strides, &dims_mask);
835 let reduce_idx = map_reduce_index(
836 &coords,
837 &resolved.reduced_dims,
838 &reduce_strides,
839 resolved.reduce_all,
840 );
841 let full_idx = map_linear_index(&coords, &strides);
842 update_best_complex(
843 &mut best[out_idx],
844 (re, im),
845 reduce_idx,
846 full_idx,
847 args.nan_mode,
848 args.comparison,
849 );
850 increment_coords(&mut coords, &shape);
851 }
852
853 let mut values = vec![(0.0f64, 0.0f64); output_len];
854 let mut indices = vec![0.0f64; output_len];
855
856 for (i, entry) in best.iter().enumerate() {
857 if entry.nan_fixed {
858 values[i] = (f64::NAN, f64::NAN);
859 indices[i] = if args.linear_index || resolved.reduce_all {
860 (entry.full_index + 1) as f64
861 } else if resolved.reduced_dims.is_empty() {
862 1.0
863 } else {
864 (entry.reduce_index + 1) as f64
865 };
866 continue;
867 }
868 if !entry.has_value {
869 values[i] = (f64::NAN, f64::NAN);
870 indices[i] = f64::NAN;
871 continue;
872 }
873 values[i] = entry.value;
874 indices[i] = if args.linear_index || resolved.reduce_all {
875 (entry.full_index + 1) as f64
876 } else if resolved.reduced_dims.is_empty() {
877 1.0
878 } else {
879 (entry.reduce_index + 1) as f64
880 };
881 }
882
883 let value_tensor =
884 ComplexTensor::new(values, output_shape.clone()).map_err(|e| format!("max: {e}"))?;
885 let index_tensor = Tensor::new(indices, output_shape).map_err(|e| format!("max: {e}"))?;
886 Ok(MaxEvaluation {
887 values: complex_tensor_into_value(value_tensor),
888 indices: tensor::tensor_into_value(index_tensor),
889 })
890}
891
/// Running best candidate for one output slot of a real reduction.
#[derive(Debug, Clone)]
struct BestReal {
    /// Current best value; only meaningful once `has_value` is set.
    value: f64,
    /// Zero-based position of the candidate within the reduced slice.
    reduce_index: usize,
    /// Zero-based linear position of the candidate in the whole input.
    full_index: usize,
    /// `false` until a candidate has been recorded; a slot that ends without
    /// one is reported as NaN value and NaN index.
    has_value: bool,
    /// When set, the slot's result is reported as NaN, with the index taken
    /// from the stored positions.
    // NOTE(review): presumably set by the update routine when a NaN is
    // encountered under `includenan` - confirm against `update_best_real`.
    nan_fixed: bool,
}

impl BestReal {
    /// Creates an empty slot: no candidate recorded, no NaN latched.
    fn new() -> Self {
        Self {
            value: 0.0,
            reduce_index: 0,
            full_index: 0,
            has_value: false,
            nan_fixed: false,
        }
    }
}
912
/// Running best candidate for one output slot of a complex reduction.
#[derive(Debug, Clone)]
struct BestComplex {
    /// Current best value as `(re, im)`; only meaningful once `has_value` is set.
    value: (f64, f64),
    /// Zero-based position of the candidate within the reduced slice.
    reduce_index: usize,
    /// Zero-based linear position of the candidate in the whole input.
    full_index: usize,
    /// `false` until a candidate has been recorded; a slot that ends without
    /// one is reported as NaN value and NaN index.
    has_value: bool,
    /// When set, the slot's result is reported as NaN, with the index taken
    /// from the stored positions.
    // NOTE(review): presumably set by the update routine when a NaN is
    // encountered under `includenan` - confirm against `update_best_complex`.
    nan_fixed: bool,
}

impl BestComplex {
    /// Creates an empty slot: no candidate recorded, no NaN latched.
    fn new() -> Self {
        Self {
            value: (0.0, 0.0),
            reduce_index: 0,
            full_index: 0,
            has_value: false,
            nan_fixed: false,
        }
    }
}
933
934fn resolve_output_shape(
935 shape: &[usize],
936 selection: &DimSelection,
937 reduced_dims: &[usize],
938) -> Result<Vec<usize>, String> {
939 if shape.is_empty() {
940 return Ok(Vec::new());
941 }
942 let mut output = shape.to_vec();
943 match selection {
944 DimSelection::All => {
945 output.fill(1);
946 }
947 _ => {
948 for &dim in reduced_dims {
949 if dim < output.len() {
950 output[dim] = 1;
951 }
952 }
953 }
954 }
955 Ok(output)
956}
957
/// Reduction plan derived from an input shape and a dimension selection.
struct ResolvedDims {
    /// Shape of the result (reduced dimensions collapsed to length one).
    output_shape: Vec<usize>,
    /// Zero-based reduced dimensions, sorted and without duplicates.
    reduced_dims: Vec<usize>,
    /// `true` when every dimension of the input is reduced.
    reduce_all: bool,
    /// Per-dimension flag: `true` for dimensions that are reduced.
    dims_mask: Vec<bool>,
    /// Strides used to linearise a coordinate within the reduced dimensions,
    /// one entry per element of `reduced_dims`.
    reduce_strides: Vec<usize>,
}
965
966fn resolve_reduction_dims(
967 shape: &[usize],
968 selection: &DimSelection,
969) -> Result<ResolvedDims, String> {
970 if shape.is_empty() {
971 return Ok(ResolvedDims {
972 output_shape: Vec::new(),
973 reduced_dims: Vec::new(),
974 reduce_all: true,
975 dims_mask: Vec::new(),
976 reduce_strides: Vec::new(),
977 });
978 }
979
980 let mut reduced_dims = match selection {
981 DimSelection::Auto => {
982 let mut dim = None;
983 for (index, &len) in shape.iter().enumerate() {
984 if len > 1 {
985 dim = Some(index);
986 break;
987 }
988 }
989 vec![dim.unwrap_or(0)]
990 }
991 DimSelection::Dim(dim) => {
992 if *dim == 0 {
993 return Err("max: dimension must be >= 1".to_string());
994 }
995 let index = dim.saturating_sub(1);
996 if index >= shape.len() {
997 Vec::new()
998 } else {
999 vec![index]
1000 }
1001 }
1002 DimSelection::Vec(dims) => {
1003 if dims.is_empty() {
1004 Vec::new()
1005 } else {
1006 dims.iter()
1007 .filter_map(|dim| {
1008 if *dim == 0 {
1009 None
1010 } else {
1011 let idx = dim - 1;
1012 if idx < shape.len() {
1013 Some(idx)
1014 } else {
1015 None
1016 }
1017 }
1018 })
1019 .collect()
1020 }
1021 }
1022 DimSelection::All => (0..shape.len()).collect(),
1023 };
1024
1025 reduced_dims.sort_unstable();
1026 reduced_dims.dedup();
1027
1028 let reduce_all = !reduced_dims.is_empty()
1029 && reduced_dims.len() == shape.len()
1030 && reduced_dims.iter().enumerate().all(|(i, &d)| i == d);
1031
1032 let output_shape = resolve_output_shape(shape, selection, &reduced_dims)?;
1033 let mut dims_mask = vec![false; shape.len()];
1034 for &dim in &reduced_dims {
1035 if dim < dims_mask.len() {
1036 dims_mask[dim] = true;
1037 }
1038 }
1039 let reduce_strides = compute_subspace_strides(shape, &reduced_dims);
1040
1041 Ok(ResolvedDims {
1042 output_shape,
1043 reduced_dims,
1044 reduce_all,
1045 dims_mask,
1046 reduce_strides,
1047 })
1048}
1049
/// Returns one stride per axis of `shape`, with the first axis varying fastest.
///
/// Each stride is the saturating product of the preceding extents; a zero extent is
/// counted as 1 so that later strides do not collapse to zero.
fn compute_strides(shape: &[usize]) -> Vec<usize> {
    shape
        .iter()
        .scan(1usize, |running, &len| {
            let current = *running;
            *running = running.saturating_mul(len.max(1));
            Some(current)
        })
        .collect()
}
1059
/// Returns strides for the sub-space spanned by the axes listed in `dims`, in list order.
///
/// The stride of each listed axis is the saturating product of the extents of the axes
/// listed before it. An axis missing from `shape`, or with extent zero, counts as 1.
fn compute_subspace_strides(shape: &[usize], dims: &[usize]) -> Vec<usize> {
    dims.iter()
        .scan(1usize, |span, &dim| {
            let stride = *span;
            let extent = shape.get(dim).map_or(1, |&len| len.max(1));
            *span = span.saturating_mul(extent);
            Some(stride)
        })
        .collect()
}
1073
/// Maps input coordinates to a linear offset in the output.
///
/// Axes flagged in `dims_mask` contribute coordinate 0; axes without a mask entry are
/// treated as not flagged. Empty `coords` map to offset 0. Arithmetic saturates.
///
/// # Panics
/// Panics when a non-flagged axis of `output_strides` has no entry in non-empty `coords`.
fn map_output_index(coords: &[usize], output_strides: &[usize], dims_mask: &[bool]) -> usize {
    if coords.is_empty() {
        return 0;
    }
    output_strides
        .iter()
        .enumerate()
        .fold(0usize, |offset, (axis, &stride)| {
            let reduced = dims_mask.get(axis).copied().unwrap_or(false);
            let coord = if reduced { 0 } else { coords[axis] };
            offset.saturating_add(coord.saturating_mul(stride))
        })
}
1089
/// Maps input coordinates to an offset inside the reduced sub-space.
///
/// Returns 0 when no axis is reduced or when `reduce_all` is set. Otherwise sums
/// `coords[dim] * stride` over the reduced axes with saturating arithmetic, skipping
/// axes that have no coordinate or no stride.
fn map_reduce_index(
    coords: &[usize],
    reduced_dims: &[usize],
    reduce_strides: &[usize],
    reduce_all: bool,
) -> usize {
    if reduce_all || reduced_dims.is_empty() {
        return 0;
    }
    reduced_dims
        .iter()
        .zip(reduce_strides)
        .fold(0usize, |offset, (&dim, &stride)| match coords.get(dim) {
            Some(coord) => offset.saturating_add(coord.saturating_mul(stride)),
            None => offset,
        })
}
1113
/// Computes the saturating dot product of `coords` and `strides`.
///
/// Pairs beyond the shorter of the two slices are ignored.
fn map_linear_index(coords: &[usize], strides: &[usize]) -> usize {
    let mut offset = 0usize;
    for (&coord, &stride) in coords.iter().zip(strides) {
        offset = offset.saturating_add(coord.saturating_mul(stride));
    }
    offset
}
1122
/// Advances `coords` to the next position within `shape`, first axis fastest.
///
/// Axes with extent zero are skipped. When an axis overflows its extent it wraps to 0
/// and the carry moves to the next axis; past the last position everything wraps to 0.
///
/// # Panics
/// Panics if `shape` has fewer entries than `coords` and the carry reaches a missing axis.
fn increment_coords(coords: &mut [usize], shape: &[usize]) {
    for (axis, coord) in coords.iter_mut().enumerate() {
        let extent = shape[axis];
        if extent == 0 {
            continue;
        }
        *coord += 1;
        if *coord < extent {
            return;
        }
        // Overflowed this axis: reset it and carry into the next one.
        *coord = 0;
    }
}
1135
1136fn update_best_real(
1137 best: &mut BestReal,
1138 value: f64,
1139 reduce_index: usize,
1140 full_index: usize,
1141 nan_mode: ReductionNaN,
1142 comparison: ComparisonMethod,
1143) {
1144 if value.is_nan() {
1145 match nan_mode {
1146 ReductionNaN::Include => {
1147 if !best.nan_fixed {
1148 best.value = f64::NAN;
1149 best.reduce_index = reduce_index;
1150 best.full_index = full_index;
1151 best.has_value = true;
1152 best.nan_fixed = true;
1153 }
1154 }
1155 ReductionNaN::Omit => {}
1156 }
1157 return;
1158 }
1159 if best.nan_fixed {
1160 return;
1161 }
1162
1163 if !best.has_value {
1164 best.value = value;
1165 best.reduce_index = reduce_index;
1166 best.full_index = full_index;
1167 best.has_value = true;
1168 return;
1169 }
1170
1171 if should_replace_real(best.value, value, comparison) {
1172 best.value = value;
1173 best.reduce_index = reduce_index;
1174 best.full_index = full_index;
1175 }
1176}
1177
1178fn update_best_complex(
1179 best: &mut BestComplex,
1180 value: (f64, f64),
1181 reduce_index: usize,
1182 full_index: usize,
1183 nan_mode: ReductionNaN,
1184 comparison: ComparisonMethod,
1185) {
1186 if value.0.is_nan() || value.1.is_nan() {
1187 match nan_mode {
1188 ReductionNaN::Include => {
1189 if !best.nan_fixed {
1190 best.value = (f64::NAN, f64::NAN);
1191 best.reduce_index = reduce_index;
1192 best.full_index = full_index;
1193 best.has_value = true;
1194 best.nan_fixed = true;
1195 }
1196 }
1197 ReductionNaN::Omit => {}
1198 }
1199 return;
1200 }
1201 if best.nan_fixed {
1202 return;
1203 }
1204
1205 if !best.has_value {
1206 best.value = value;
1207 best.reduce_index = reduce_index;
1208 best.full_index = full_index;
1209 best.has_value = true;
1210 return;
1211 }
1212
1213 if should_replace_complex(best.value, value, comparison) {
1214 best.value = value;
1215 best.reduce_index = reduce_index;
1216 best.full_index = full_index;
1217 }
1218}
1219
1220fn should_replace_real(current: f64, candidate: f64, comparison: ComparisonMethod) -> bool {
1221 match comparison {
1222 ComparisonMethod::Auto | ComparisonMethod::Real => {
1223 if candidate > current {
1224 return true;
1225 }
1226 if candidate < current {
1227 return false;
1228 }
1229 if candidate == 0.0 && current == 0.0 {
1230 return candidate.is_sign_positive() && !current.is_sign_positive();
1231 }
1232 false
1233 }
1234 ComparisonMethod::Abs => {
1235 let curr_abs = current.abs();
1236 let cand_abs = candidate.abs();
1237 if cand_abs > curr_abs {
1238 return true;
1239 }
1240 if cand_abs < curr_abs {
1241 return false;
1242 }
1243 if candidate > current {
1244 return true;
1245 }
1246 if candidate < current {
1247 return false;
1248 }
1249 if candidate == 0.0 && current == 0.0 {
1250 return candidate.is_sign_positive() && !current.is_sign_positive();
1251 }
1252 false
1253 }
1254 }
1255}
1256
1257fn should_replace_complex(
1258 current: (f64, f64),
1259 candidate: (f64, f64),
1260 comparison: ComparisonMethod,
1261) -> bool {
1262 match comparison {
1263 ComparisonMethod::Auto | ComparisonMethod::Abs => {
1264 compare_complex_auto(current, candidate) == Ordering::Less
1265 }
1266 ComparisonMethod::Real => compare_complex_real(current, candidate) == Ordering::Less,
1267 }
1268}
1269
1270fn compare_complex_auto(a: (f64, f64), b: (f64, f64)) -> Ordering {
1271 let a_mag = magnitude_squared(a);
1272 let b_mag = magnitude_squared(b);
1273 if a_mag < b_mag {
1274 return Ordering::Less;
1275 }
1276 if a_mag > b_mag {
1277 return Ordering::Greater;
1278 }
1279 let a_angle = a.1.atan2(a.0);
1281 let b_angle = b.1.atan2(b.0);
1282 if a_angle < b_angle {
1283 Ordering::Less
1284 } else if a_angle > b_angle {
1285 Ordering::Greater
1286 } else {
1287 Ordering::Equal
1288 }
1289}
1290
1291fn compare_complex_real(a: (f64, f64), b: (f64, f64)) -> Ordering {
1292 if a.0 < b.0 {
1293 return Ordering::Less;
1294 }
1295 if a.0 > b.0 {
1296 return Ordering::Greater;
1297 }
1298 compare_complex_auto(a, b)
1300}
1301
/// Returns `re² + im²` for the complex number `z = (re, im)`.
///
/// Evaluated as a fused multiply-add of `re * re` onto the product `im * im`.
fn magnitude_squared(z: (f64, f64)) -> f64 {
    let (re, im) = z;
    let im_squared = im * im;
    f64::mul_add(re, re, im_squared)
}
1305
/// Returns the 1-based index of the first axis whose extent exceeds 1.
///
/// Falls back to 1 when `shape` is empty or has no such axis.
fn default_dimension_from_shape(shape: &[usize]) -> usize {
    shape
        .iter()
        .position(|&len| len > 1)
        .map_or(1, |axis| axis + 1)
}
1317
1318fn elementwise_max(value: Value, args: ElementwiseArgs) -> Result<MaxEvaluation, String> {
1319 let ElementwiseArgs { other, comparison } = args;
1320 match (value, other) {
1321 (Value::GpuTensor(handle_a), Value::GpuTensor(handle_b)) => {
1322 if gpu_tensor_is_scalar(&handle_b) {
1323 if let Some(num) = gpu_tensor_scalar_value(&handle_b) {
1324 let scalar = Value::Num(num);
1325 return elementwise_max_gpu_scalar_left(&handle_a, &scalar, comparison)
1326 .or_else(|| {
1327 let ta = gpu_helpers::gather_tensor(&handle_a).ok()?;
1328 elementwise_real_or_complex(
1329 Value::Tensor(ta),
1330 scalar.clone(),
1331 comparison,
1332 )
1333 .ok()
1334 })
1335 .ok_or_else(|| "max: elementwise GPU scalar path failed".to_string());
1336 }
1337 }
1338 if gpu_tensor_is_scalar(&handle_a) {
1339 if let Some(num) = gpu_tensor_scalar_value(&handle_a) {
1340 let scalar = Value::Num(num);
1341 return elementwise_max_gpu_scalar_right(&scalar, &handle_b, comparison)
1342 .or_else(|| {
1343 let tb = gpu_helpers::gather_tensor(&handle_b).ok()?;
1344 elementwise_real_or_complex(
1345 scalar.clone(),
1346 Value::Tensor(tb),
1347 comparison,
1348 )
1349 .ok()
1350 })
1351 .ok_or_else(|| "max: elementwise GPU scalar path failed".to_string());
1352 }
1353 }
1354 elementwise_max_gpu_pair(&handle_a, &handle_b, comparison)
1355 .or_else(|| {
1356 let ta = gpu_helpers::gather_tensor(&handle_a).ok()?;
1358 let tb = gpu_helpers::gather_tensor(&handle_b).ok()?;
1359 elementwise_real_or_complex(Value::Tensor(ta), Value::Tensor(tb), comparison)
1360 .ok()
1361 })
1362 .ok_or_else(|| "max: elementwise GPU path failed".to_string())
1363 }
1364 (Value::GpuTensor(handle), other) => {
1365 elementwise_max_gpu_scalar_left(&handle, &other, comparison)
1366 .or_else(|| {
1367 let t = gpu_helpers::gather_tensor(&handle).ok()?;
1368 elementwise_real_or_complex(Value::Tensor(t), other, comparison).ok()
1369 })
1370 .ok_or_else(|| "max: elementwise GPU scalar path failed".to_string())
1371 }
1372 (other, Value::GpuTensor(handle)) => {
1373 elementwise_max_gpu_scalar_right(&other, &handle, comparison)
1374 .or_else(|| {
1375 let t = gpu_helpers::gather_tensor(&handle).ok()?;
1376 elementwise_real_or_complex(other, Value::Tensor(t), comparison).ok()
1377 })
1378 .ok_or_else(|| "max: elementwise GPU scalar path failed".to_string())
1379 }
1380 (lhs, rhs) => elementwise_real_or_complex(lhs, rhs, comparison),
1381 }
1382}
1383
1384fn elementwise_max_gpu_pair(
1385 a: &GpuTensorHandle,
1386 b: &GpuTensorHandle,
1387 comparison: ComparisonMethod,
1388) -> Option<MaxEvaluation> {
1389 if comparison != ComparisonMethod::Auto {
1390 return None;
1391 }
1392 let provider = runmat_accelerate_api::provider()?;
1393 if a.shape == b.shape {
1395 let values = provider.elem_max(a, b).ok()?;
1396 if let Ok(mask) = provider.elem_ge(a, b) {
1398 let indices = gpu_mask_indices(provider, &mask)?;
1399 let _ = provider.free(&mask);
1400 return Some(MaxEvaluation {
1401 values: Value::GpuTensor(values),
1402 indices: Value::GpuTensor(indices),
1403 });
1404 } else {
1405 let ta = gpu_helpers::gather_tensor(a).ok()?;
1407 let tb = gpu_helpers::gather_tensor(b).ok()?;
1408 let mut indices = Vec::with_capacity(ta.data.len());
1409 for i in 0..ta.data.len() {
1410 indices.push(if ta.data[i] >= tb.data[i] { 1.0 } else { 2.0 });
1411 }
1412 let index_tensor = Tensor::new(indices, ta.shape.clone()).ok()?;
1413 return Some(MaxEvaluation {
1414 values: Value::GpuTensor(values),
1415 indices: tensor::tensor_into_value(index_tensor),
1416 });
1417 }
1418 }
1419 let (out_shape, reps_a, reps_b) = broadcast_reps(&a.shape, &b.shape)?;
1421 let a_exp = if reps_a.iter().any(|&r| r != 1) {
1422 provider.repmat(a, &reps_a).ok()?
1423 } else {
1424 a.clone()
1425 };
1426 let b_exp = if reps_b.iter().any(|&r| r != 1) {
1427 provider.repmat(b, &reps_b).ok()?
1428 } else {
1429 b.clone()
1430 };
1431 let values = provider.elem_max(&a_exp, &b_exp).ok();
1432 let mask = provider.elem_ge(&a_exp, &b_exp).ok();
1433 if !std::ptr::eq(&a_exp, a) {
1434 let _ = provider.free(&a_exp);
1435 }
1436 if !std::ptr::eq(&b_exp, b) {
1437 let _ = provider.free(&b_exp);
1438 }
1439 let values = values?;
1440 if values.shape != out_shape {
1441 let _ = provider.free(&values);
1442 return None;
1443 }
1444 let index_tensor = if let Some(mask) = mask {
1445 let mask_host = gpu_helpers::gather_tensor(&mask).ok()?;
1446 let _ = provider.free(&mask);
1447 let mut indices = Vec::with_capacity(mask_host.data.len());
1448 for &m in &mask_host.data {
1449 indices.push(if m != 0.0 { 1.0 } else { 2.0 });
1450 }
1451 Tensor::new(indices, out_shape).ok()?
1452 } else {
1453 let ta = gpu_helpers::gather_tensor(&a_exp).ok()?;
1455 let tb = gpu_helpers::gather_tensor(&b_exp).ok()?;
1456 let mut indices = Vec::with_capacity(ta.data.len());
1457 for i in 0..ta.data.len() {
1458 indices.push(if ta.data[i] >= tb.data[i] { 1.0 } else { 2.0 });
1459 }
1460 Tensor::new(indices, out_shape).ok()?
1461 };
1462 Some(MaxEvaluation {
1463 values: Value::GpuTensor(values),
1464 indices: tensor::tensor_into_value(index_tensor),
1465 })
1466}
1467
/// Computes the broadcast shape of `a` and `b` and the per-axis repetition counts that
/// expand each operand to that shape.
///
/// Shapes are padded with trailing 1s up to the larger rank (at least 1). Two extents are
/// compatible when they are equal or one of them is 1. Returns
/// `(out_shape, reps_a, reps_b)`, where a repetition count is 1 on axes where the operand
/// already matches the output and the output extent otherwise. Returns `None` when some
/// axis is incompatible.
fn broadcast_reps(a: &[usize], b: &[usize]) -> Option<(Vec<usize>, Vec<usize>, Vec<usize>)> {
    let rank = a.len().max(b.len()).max(1);
    let extent = |dims: &[usize], axis: usize| dims.get(axis).copied().unwrap_or(1);

    let mut out = Vec::with_capacity(rank);
    let mut reps_a = Vec::with_capacity(rank);
    let mut reps_b = Vec::with_capacity(rank);
    for axis in 0..rank {
        let lhs = extent(a, axis);
        let rhs = extent(b, axis);
        let merged = match (lhs, rhs) {
            (l, r) if l == r => l,
            (1, r) => r,
            (l, 1) => l,
            _ => return None,
        };
        out.push(merged);
        reps_a.push(if lhs == merged { 1 } else { merged });
        reps_b.push(if rhs == merged { 1 } else { merged });
    }
    Some((out, reps_a, reps_b))
}
1497
1498fn elementwise_max_gpu_scalar_left(
1499 a: &GpuTensorHandle,
1500 other: &Value,
1501 comparison: ComparisonMethod,
1502) -> Option<MaxEvaluation> {
1503 if comparison != ComparisonMethod::Auto {
1504 return None;
1505 }
1506 let provider = runmat_accelerate_api::provider()?;
1507 let scalar = extract_scalar(other)?;
1508 let values = if let Ok(fill) = provider.fill_like(a, scalar) {
1510 let vals = provider.elem_max(a, &fill).ok();
1511 let _ = provider.free(&fill);
1512 vals?
1513 } else {
1514 provider.scalar_max(a, scalar).ok()?
1515 };
1516 let index_tensor = if let Ok(fill) = provider.fill_like(a, scalar) {
1518 if let Ok(mask) = provider.elem_ge(a, &fill) {
1519 let _ = provider.free(&fill);
1520 let indices = gpu_mask_indices(provider, &mask)?;
1521 let _ = provider.free(&mask);
1522 return Some(MaxEvaluation {
1523 values: Value::GpuTensor(values),
1524 indices: Value::GpuTensor(indices),
1525 });
1526 } else {
1527 let _ = provider.free(&fill);
1528 let ta = gpu_helpers::gather_tensor(a).ok()?;
1529 let mut indices = Vec::with_capacity(ta.data.len());
1530 for &v in &ta.data {
1531 indices.push(if v >= scalar { 1.0 } else { 2.0 });
1532 }
1533 Tensor::new(indices, ta.shape.clone()).ok()?
1534 }
1535 } else {
1536 let ta = gpu_helpers::gather_tensor(a).ok()?;
1537 let mut indices = Vec::with_capacity(ta.data.len());
1538 for &v in &ta.data {
1539 indices.push(if v >= scalar { 1.0 } else { 2.0 });
1540 }
1541 Tensor::new(indices, ta.shape.clone()).ok()?
1542 };
1543 Some(MaxEvaluation {
1544 values: Value::GpuTensor(values),
1545 indices: tensor::tensor_into_value(index_tensor),
1546 })
1547}
1548
1549fn elementwise_max_gpu_scalar_right(
1550 other: &Value,
1551 b: &GpuTensorHandle,
1552 comparison: ComparisonMethod,
1553) -> Option<MaxEvaluation> {
1554 if comparison != ComparisonMethod::Auto {
1555 return None;
1556 }
1557 let provider = runmat_accelerate_api::provider()?;
1558 let scalar = extract_scalar(other)?;
1559 let values = if let Ok(fill) = provider.fill_like(b, scalar) {
1560 let vals = provider.elem_max(&fill, b).ok();
1561 let _ = provider.free(&fill);
1562 vals?
1563 } else {
1564 provider.scalar_max(b, scalar).ok()?
1565 };
1566 let index_tensor = if let Ok(fill) = provider.fill_like(b, scalar) {
1568 if let Ok(mask) = provider.elem_ge(&fill, b) {
1569 let _ = provider.free(&fill);
1570 let indices = gpu_mask_indices(provider, &mask)?;
1571 let _ = provider.free(&mask);
1572 return Some(MaxEvaluation {
1573 values: Value::GpuTensor(values),
1574 indices: Value::GpuTensor(indices),
1575 });
1576 } else {
1577 let _ = provider.free(&fill);
1578 let tb = gpu_helpers::gather_tensor(b).ok()?;
1579 let mut indices = Vec::with_capacity(tb.data.len());
1580 for &v in &tb.data {
1581 indices.push(if scalar >= v { 1.0 } else { 2.0 });
1582 }
1583 Tensor::new(indices, tb.shape.clone()).ok()?
1584 }
1585 } else {
1586 let tb = gpu_helpers::gather_tensor(b).ok()?;
1587 let mut indices = Vec::with_capacity(tb.data.len());
1588 for &v in &tb.data {
1589 indices.push(if scalar >= v { 1.0 } else { 2.0 });
1590 }
1591 Tensor::new(indices, tb.shape.clone()).ok()?
1592 };
1593 Some(MaxEvaluation {
1594 values: Value::GpuTensor(values),
1595 indices: tensor::tensor_into_value(index_tensor),
1596 })
1597}
1598
1599fn extract_scalar(v: &Value) -> Option<f64> {
1600 match v {
1601 Value::Num(n) => Some(*n),
1602 Value::Int(i) => Some(i.to_f64()),
1603 Value::Bool(b) => Some(if *b { 1.0 } else { 0.0 }),
1604 Value::Tensor(t) if t.data.len() == 1 => t.data.first().copied(),
1605 Value::LogicalArray(l) if l.data.len() == 1 => Some(if l.data[0] != 0 { 1.0 } else { 0.0 }),
1606 _ => None,
1607 }
1608}
1609
1610fn gpu_tensor_is_scalar(handle: &GpuTensorHandle) -> bool {
1611 handle.shape.iter().copied().product::<usize>().max(1) == 1
1612}
1613
1614fn gpu_tensor_scalar_value(handle: &GpuTensorHandle) -> Option<f64> {
1615 let tensor = gpu_helpers::gather_tensor(handle).ok()?;
1616 tensor.data.first().copied()
1617}
1618
1619fn gpu_mask_indices(
1620 provider: &dyn AccelProvider,
1621 mask: &GpuTensorHandle,
1622) -> Option<GpuTensorHandle> {
1623 let scaled = provider.scalar_mul(mask, -1.0).ok()?;
1624 let shifted = provider.scalar_add(&scaled, 2.0).ok()?;
1625 let _ = provider.free(&scaled);
1626 Some(shifted)
1627}
1628
1629fn elementwise_real_or_complex(
1630 lhs: Value,
1631 rhs: Value,
1632 comparison: ComparisonMethod,
1633) -> Result<MaxEvaluation, String> {
1634 match (
1635 materialize_for_max("max", lhs)?,
1636 materialize_for_max("max", rhs)?,
1637 ) {
1638 (InputData::Complex(a), InputData::Complex(b)) => elementwise_complex_max(a, b, comparison),
1639 (InputData::Complex(a), InputData::Real(b)) => {
1640 let converted = promote_real_tensor_to_complex(b);
1641 elementwise_complex_max(a, converted, comparison)
1642 }
1643 (InputData::Real(a), InputData::Complex(b)) => {
1644 let converted = promote_real_tensor_to_complex(a);
1645 elementwise_complex_max(converted, b, comparison)
1646 }
1647 (InputData::Real(a), InputData::Real(b)) => elementwise_real_max(a, b, comparison),
1648 }
1649}
1650
1651fn elementwise_real_max(
1652 lhs: Tensor,
1653 rhs: Tensor,
1654 comparison: ComparisonMethod,
1655) -> Result<MaxEvaluation, String> {
1656 let plan = BroadcastPlan::new(&lhs.shape, &rhs.shape).map_err(|err| format!("max: {}", err))?;
1657 let mut values = vec![0.0f64; plan.len()];
1658 let mut indices = vec![0.0f64; plan.len()];
1659
1660 for (offset, index_a, index_b) in plan.iter() {
1661 let a = lhs.data.get(index_a).copied().unwrap_or(f64::NAN);
1662 let b = rhs.data.get(index_b).copied().unwrap_or(f64::NAN);
1663 let (value, origin) = choose_real_elementwise(a, b, comparison);
1664 values[offset] = value;
1665 indices[offset] = origin;
1666 }
1667
1668 let value_tensor =
1669 Tensor::new(values, plan.output_shape().to_vec()).map_err(|e| format!("max: {e}"))?;
1670 let index_tensor =
1671 Tensor::new(indices, plan.output_shape().to_vec()).map_err(|e| format!("max: {e}"))?;
1672
1673 Ok(MaxEvaluation {
1674 values: tensor::tensor_into_value(value_tensor),
1675 indices: tensor::tensor_into_value(index_tensor),
1676 })
1677}
1678
1679fn elementwise_complex_max(
1680 lhs: ComplexTensor,
1681 rhs: ComplexTensor,
1682 comparison: ComparisonMethod,
1683) -> Result<MaxEvaluation, String> {
1684 let plan = BroadcastPlan::new(&lhs.shape, &rhs.shape).map_err(|err| format!("max: {}", err))?;
1685 let mut values = vec![(0.0f64, 0.0f64); plan.len()];
1686 let mut indices = vec![0.0f64; plan.len()];
1687
1688 for (offset, index_a, index_b) in plan.iter() {
1689 let a = lhs
1690 .data
1691 .get(index_a)
1692 .copied()
1693 .unwrap_or((f64::NAN, f64::NAN));
1694 let b = rhs
1695 .data
1696 .get(index_b)
1697 .copied()
1698 .unwrap_or((f64::NAN, f64::NAN));
1699 let (value, origin) = choose_complex_elementwise(a, b, comparison);
1700 values[offset] = value;
1701 indices[offset] = origin;
1702 }
1703
1704 let value_tensor = ComplexTensor::new(values, plan.output_shape().to_vec())
1705 .map_err(|e| format!("max: {e}"))?;
1706 let index_tensor =
1707 Tensor::new(indices, plan.output_shape().to_vec()).map_err(|e| format!("max: {e}"))?;
1708
1709 Ok(MaxEvaluation {
1710 values: complex_tensor_into_value(value_tensor),
1711 indices: tensor::tensor_into_value(index_tensor),
1712 })
1713}
1714
1715fn promote_real_tensor_to_complex(tensor: Tensor) -> ComplexTensor {
1716 let data = tensor
1717 .data
1718 .iter()
1719 .copied()
1720 .map(|re| (re, 0.0))
1721 .collect::<Vec<_>>();
1722 ComplexTensor {
1723 data,
1724 shape: tensor.shape.clone(),
1725 rows: tensor.rows,
1726 cols: tensor.cols,
1727 }
1728}
1729
1730fn choose_real_elementwise(a: f64, b: f64, comparison: ComparisonMethod) -> (f64, f64) {
1731 match (a.is_nan(), b.is_nan()) {
1732 (true, true) => (f64::NAN, 1.0),
1733 (true, false) => (f64::NAN, 1.0),
1734 (false, true) => (f64::NAN, 2.0),
1735 (false, false) => {
1736 if should_replace_real(a, b, comparison) {
1737 (b, 2.0)
1738 } else {
1739 (a, 1.0)
1740 }
1741 }
1742 }
1743}
1744
1745fn choose_complex_elementwise(
1746 a: (f64, f64),
1747 b: (f64, f64),
1748 comparison: ComparisonMethod,
1749) -> ((f64, f64), f64) {
1750 let a_nan = a.0.is_nan() || a.1.is_nan();
1751 let b_nan = b.0.is_nan() || b.1.is_nan();
1752 match (a_nan, b_nan) {
1753 (true, true) => ((f64::NAN, f64::NAN), 1.0),
1754 (true, false) => ((f64::NAN, f64::NAN), 1.0),
1755 (false, true) => ((f64::NAN, f64::NAN), 2.0),
1756 (false, false) => {
1757 if should_replace_complex(a, b, comparison) {
1758 (b, 2.0)
1759 } else {
1760 (a, 1.0)
1761 }
1762 }
1763 }
1764}
1765
// Unit tests for the `max` builtin. Most tests call `evaluate` and split the result into
// a (values, indices) pair; tensor data in the fixtures is laid out column-major.
#[cfg(test)]
mod tests {
    use super::*;
    #[cfg(any(feature = "doc_export", feature = "wgpu"))]
    use crate::builtins::common::test_support;
    #[cfg(feature = "wgpu")]
    use runmat_accelerate_api::HostTensorView;
    use runmat_builtins::{IntValue, Tensor, Value};

    // Builds an empty 0x0 tensor, passed as the first extra argument in the reduction
    // tests below (the `[]` slot of `max(A, [], ...)`).
    fn placeholder() -> Value {
        let tensor = Tensor::new(Vec::<f64>::new(), vec![0, 0]).unwrap();
        Value::Tensor(tensor)
    }

    // A scalar input is returned unchanged.
    #[test]
    fn max_scalar_returns_input() {
        let result = max_builtin(Value::Num(5.0), Vec::new()).expect("max");
        assert_eq!(result, Value::Num(5.0));
    }

    // A column vector reduces to its largest element and that element's 1-based position.
    #[test]
    fn max_vector_with_indices() {
        let tensor = Tensor::new(vec![3.0, 1.0, 5.0], vec![3, 1]).unwrap();
        let eval = evaluate(Value::Tensor(tensor), &[]).expect("evaluate");
        let (values, indices) = eval.into_pair();
        assert_eq!(values, Value::Num(5.0));
        assert_eq!(indices, Value::Num(3.0));
    }

    // Without a dimension argument a 2x3 matrix is reduced down each column.
    #[test]
    fn max_matrix_default_dimension() {
        let tensor = Tensor::new(vec![3.0, 4.0, 1.0, 2.0, 5.0, 6.0], vec![2, 3]).unwrap();
        let eval = evaluate(Value::Tensor(tensor), &[]).expect("evaluate");
        let (values, indices) = eval.into_pair();
        match values {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![1, 3]);
                assert_eq!(t.data, vec![4.0, 2.0, 6.0]);
            }
            other => panic!("expected tensor, got {other:?}"),
        }
        match indices {
            Value::Tensor(t) => {
                assert_eq!(t.data, vec![2.0, 2.0, 2.0]);
            }
            other => panic!("expected tensor, got {other:?}"),
        }
    }

    // The "all" and "linear" keywords both yield a scalar value with a linear index.
    #[test]
    fn max_all_linear_index() {
        let tensor =
            Tensor::new((1..=12).map(|v| v as f64).collect::<Vec<_>>(), vec![3, 4]).unwrap();
        let args = vec![placeholder(), Value::from("all")];
        let eval = evaluate(Value::Tensor(tensor), &args).expect("evaluate");
        let (values, indices) = eval.into_pair();
        assert_eq!(values, Value::Num(12.0));
        assert_eq!(indices, Value::Num(12.0));

        let args_linear = vec![placeholder(), Value::from("linear")];
        let eval = evaluate(
            Value::Tensor(Tensor::new(vec![2.0, 3.0], vec![1, 2]).unwrap()),
            &args_linear,
        )
        .expect("evaluate");
        let (values, indices) = eval.into_pair();
        assert_eq!(values, Value::Num(3.0));
        assert_eq!(indices, Value::Num(2.0));
    }

    // With "omitnan" a leading NaN is skipped and the index refers to the real maximum.
    #[test]
    fn max_with_omitnan() {
        let tensor = Tensor::new(vec![f64::NAN, 4.0, 2.0], vec![3, 1]).unwrap();
        let args = vec![placeholder(), Value::from("omitnan")];
        let eval = evaluate(Value::Tensor(tensor), &args).expect("evaluate");
        let (values, indices) = eval.into_pair();
        assert_eq!(values, Value::Num(4.0));
        assert_eq!(indices, Value::Num(2.0));
    }

    // With "omitnan" an all-NaN slice yields NaN for both the value and the index.
    #[test]
    fn max_omitnan_all_nan_slice() {
        let tensor = Tensor::new(vec![f64::NAN, f64::NAN], vec![2, 1]).unwrap();
        let args = vec![placeholder(), Value::from("omitnan")];
        let eval = evaluate(Value::Tensor(tensor), &args).expect("evaluate");
        let (values, indices) = eval.into_pair();
        match values {
            Value::Num(v) => assert!(v.is_nan()),
            other => panic!("expected scalar NaN, got {other:?}"),
        }
        match indices {
            Value::Num(v) => assert!(v.is_nan()),
            other => panic!("expected scalar NaN index, got {other:?}"),
        }
    }

    // ComparisonMethod "abs" selects by magnitude but returns the original signed value.
    #[test]
    fn max_reduction_abs_comparison() {
        let tensor = Tensor::new(vec![1.0, -3.0, -2.0, 4.0], vec![2, 2]).unwrap();
        let args = vec![
            placeholder(),
            Value::from("ComparisonMethod"),
            Value::from("abs"),
        ];
        let eval = evaluate(Value::Tensor(tensor), &args).expect("evaluate");
        let (values, indices) = eval.into_pair();
        match values {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![1, 2]);
                assert_eq!(t.data, vec![-3.0, 4.0]);
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
        match indices {
            Value::Tensor(t) => {
                assert_eq!(t.data, vec![2.0, 2.0]);
            }
            other => panic!("expected tensor indices, got {other:?}"),
        }
    }

    // ComparisonMethod "real" on complex input picks the element with the larger real
    // part even though the other element has the larger magnitude.
    #[test]
    fn max_reduction_complex_real_comparison() {
        let tensor = ComplexTensor::new(vec![(1.0, 2.0), (0.5, 5.0)], vec![2, 1]).expect("tensor");
        let args = vec![
            placeholder(),
            Value::from("ComparisonMethod"),
            Value::from("real"),
        ];
        let eval = evaluate(Value::ComplexTensor(tensor), &args).expect("evaluate");
        let (values, indices) = eval.into_pair();
        match values {
            Value::Complex(re, im) => {
                assert!((re - 1.0).abs() < 1e-12);
                assert!((im - 2.0).abs() < 1e-12);
            }
            other => panic!("expected complex scalar, got {other:?}"),
        }
        assert_eq!(indices, Value::Num(1.0));
    }

    // A 1x3 row against a 3x1 column broadcasts to 3x3; the second output reports which
    // operand (1 or 2) supplied each element. Assertions read the result row by row.
    #[test]
    fn max_elementwise_broadcast() {
        let lhs = Tensor::new(vec![1.0, 4.0, 7.0], vec![1, 3]).unwrap();
        let rhs = Tensor::new(vec![2.0, 3.0, 5.0], vec![3, 1]).unwrap();
        let eval = evaluate(Value::Tensor(lhs), &[Value::Tensor(rhs)]).expect("evaluate");
        let (values, indices) = eval.into_pair();
        match values {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![3, 3]);
                assert_eq!([t.data[0], t.data[3], t.data[6]], [2.0, 4.0, 7.0]);
                assert_eq!([t.data[1], t.data[4], t.data[7]], [3.0, 4.0, 7.0]);
                assert_eq!([t.data[2], t.data[5], t.data[8]], [5.0, 5.0, 7.0]);
            }
            other => panic!("expected tensor, got {other:?}"),
        }
        match indices {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![3, 3]);
                assert_eq!([t.data[0], t.data[3], t.data[6]], [2.0, 1.0, 1.0]);
                assert_eq!([t.data[1], t.data[4], t.data[7]], [2.0, 1.0, 1.0]);
                assert_eq!([t.data[2], t.data[5], t.data[8]], [2.0, 2.0, 1.0]);
            }
            other => panic!("expected tensor, got {other:?}"),
        }
    }

    // Elementwise "abs" comparison keeps the signed value of the larger-magnitude operand.
    #[test]
    fn max_elementwise_abs_comparison() {
        let lhs = Tensor::new(vec![-2.0, 1.0], vec![2, 1]).unwrap();
        let rhs = Tensor::new(vec![1.5, -3.0], vec![2, 1]).unwrap();
        let args = vec![
            Value::Tensor(rhs),
            Value::from("ComparisonMethod"),
            Value::from("abs"),
        ];
        let eval = evaluate(Value::Tensor(lhs), &args).expect("evaluate");
        let (values, indices) = eval.into_pair();
        match values {
            Value::Tensor(t) => {
                assert_eq!(t.data, vec![-2.0, -3.0]);
            }
            other => panic!("expected tensor, got {other:?}"),
        }
        match indices {
            Value::Tensor(t) => {
                assert_eq!(t.data, vec![1.0, 2.0]);
            }
            other => panic!("expected tensor, got {other:?}"),
        }
    }

    // Reduction-only keywords such as "omitnan" are rejected in the two-input form.
    #[test]
    fn max_elementwise_rejects_reduction_only_keywords() {
        let lhs = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
        let rhs = Tensor::new(vec![3.0, 4.0], vec![2, 1]).unwrap();
        let err = evaluate(
            Value::Tensor(lhs),
            &[Value::Tensor(rhs), Value::from("omitnan")],
        )
        .expect_err("expected error");
        assert!(err.contains("only supported for reduction"));
    }

    // Elementwise "real" comparison of two complex scalars picks the larger real part.
    #[test]
    fn max_complex_real_comparison() {
        let lhs = ComplexTensor::new(vec![(1.0, 2.0)], vec![1, 1]).unwrap();
        let rhs = ComplexTensor::new(vec![(0.5, 5.0)], vec![1, 1]).unwrap();
        let args = vec![
            Value::ComplexTensor(rhs),
            Value::from("ComparisonMethod"),
            Value::from("real"),
        ];
        let eval = evaluate(Value::ComplexTensor(lhs), &args).expect("evaluate");
        let (values, indices) = eval.into_pair();
        assert_eq!(values, Value::Complex(1.0, 2.0));
        assert_eq!(indices, Value::Num(1.0));
    }

    // A vector of dimensions [1 2] reduces over both axes of a 2x2 matrix.
    #[test]
    fn max_dimension_argument_parsing() {
        let tensor = Tensor::new(vec![3.0, 4.0, 1.0, 2.0], vec![2, 2]).unwrap();
        let dims = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
        let args = vec![placeholder(), Value::Tensor(dims)];
        let eval = evaluate(Value::Tensor(tensor), &args).expect("evaluate");
        let (values, indices) = eval.into_pair();
        assert_eq!(values, Value::Num(4.0));
        assert_eq!(indices, Value::Num(2.0));
    }

    // Repeated entries in the dimension vector are tolerated and behave like [1 2].
    #[test]
    fn max_vecdim_duplicate_entries() {
        let tensor = Tensor::new(vec![5.0, 2.0, 7.0, 1.0], vec![2, 2]).unwrap();
        let dims = Tensor::new(vec![1.0, 1.0, 2.0], vec![3, 1]).unwrap();
        let args = vec![placeholder(), Value::Tensor(dims)];
        let eval = evaluate(Value::Tensor(tensor), &args).expect("evaluate");
        let (values, indices) = eval.into_pair();
        assert_eq!(values, Value::Num(7.0));
        assert_eq!(indices, Value::Num(3.0));
    }

    // A dimension argument supplied as a GPU handle is rejected with a host-only error.
    #[test]
    fn max_dimension_gpu_argument_errors() {
        let tensor = Tensor::new(vec![3.0, 1.0], vec![2, 1]).unwrap();
        let dim_handle = Value::GpuTensor(runmat_accelerate_api::GpuTensorHandle {
            shape: vec![1, 1],
            device_id: 0,
            buffer_id: 42,
        });
        let err = evaluate(Value::Tensor(tensor), &[placeholder(), dim_handle])
            .expect_err("expected error");
        assert!(err.contains("dimension arguments must reside on the host"));
    }

    // An unknown ComparisonMethod value produces an error.
    #[test]
    fn max_invalid_comparison_method_errors() {
        let tensor = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
        let args = vec![
            placeholder(),
            Value::from("ComparisonMethod"),
            Value::from("chebyshev"),
        ];
        let err = evaluate(Value::Tensor(tensor), &args).expect_err("expected error");
        assert!(err.contains("unsupported ComparisonMethod"));
    }

    // The embedded documentation must contain at least one example block.
    #[test]
    #[cfg(feature = "doc_export")]
    fn max_doc_examples_present() {
        let blocks = test_support::doc_examples(super::DOC_MD);
        assert!(!blocks.is_empty());
    }

    // The default-dimension reduction of an uploaded tensor must stay on the GPU and,
    // once gathered, match the host result in both shape and data.
    #[test]
    #[cfg(feature = "wgpu")]
    fn max_gpu_dim1_matches_cpu() {
        let tensor = Tensor::new(vec![3.0, 1.0, 2.0, 4.0], vec![2, 2]).unwrap();
        let eval_cpu = evaluate(Value::Tensor(tensor.clone()), &[]).expect("cpu");
        let (values_cpu, indices_cpu) = eval_cpu.into_pair();

        test_support::with_test_provider(|provider| {
            let view = HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let eval_gpu = evaluate(Value::GpuTensor(handle), &[]).expect("gpu");
            let (values_gpu, indices_gpu) = eval_gpu.into_pair();
            match (&values_gpu, &indices_gpu) {
                (Value::GpuTensor(_), Value::GpuTensor(_)) => {}
                other => panic!("expected GPU tensors, got {other:?}"),
            }
            let gathered_vals = test_support::gather(values_gpu).expect("gather values");
            let gathered_idx = test_support::gather(indices_gpu).expect("gather indices");
            let expected_vals = match values_cpu {
                Value::Tensor(t) => t,
                other => panic!("expected tensor values from cpu eval, got {other:?}"),
            };
            let expected_idx = match indices_cpu {
                Value::Tensor(t) => t,
                other => panic!("expected tensor indices from cpu eval, got {other:?}"),
            };
            assert_eq!(gathered_vals.shape, expected_vals.shape);
            assert_eq!(gathered_vals.data, expected_vals.data);
            assert_eq!(gathered_idx.shape, expected_idx.shape);
            assert_eq!(gathered_idx.data, expected_idx.data);
        });
    }

    // A numeric dimension of 2 reduces along rows, producing a 2x1 column.
    #[test]
    fn max_dimension_numeric_argument() {
        let tensor = Tensor::new(vec![3.0, 4.0, 1.0, 2.0], vec![2, 2]).unwrap();
        let args = vec![placeholder(), Value::Num(2.0)];
        let eval = evaluate(Value::Tensor(tensor), &args).expect("evaluate");
        let (values, indices) = eval.into_pair();
        match values {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![2, 1]);
                assert_eq!(t.data, vec![3.0, 4.0]);
            }
            other => panic!("expected tensor, got {other:?}"),
        }
        match indices {
            Value::Tensor(t) => {
                assert_eq!(t.data, vec![1.0, 1.0]);
            }
            other => panic!("expected tensor, got {other:?}"),
        }
    }

    // Two complex scalars of equal magnitude: the one with the larger phase angle wins.
    #[test]
    fn max_complex_auto_comparison() {
        let lhs = ComplexTensor::new(vec![(1.0, 2.0)], vec![1, 1]).unwrap();
        let rhs = ComplexTensor::new(vec![(2.0, 1.0)], vec![1, 1]).unwrap();
        let eval =
            evaluate(Value::ComplexTensor(lhs), &[Value::ComplexTensor(rhs)]).expect("evaluate");
        let (values, indices) = eval.into_pair();
        assert_eq!(values, Value::Complex(1.0, 2.0));
        assert_eq!(indices, Value::Num(1.0));
    }

    // Two scalar arguments go through the elementwise form and return the larger one.
    #[test]
    fn max_scalar_pair_arguments() {
        let args = vec![Value::Num(2.0)];
        let result = max_builtin(Value::Num(3.0), args).expect("max");
        assert_eq!(result, Value::Num(3.0));
    }

    // A dimension of zero is rejected.
    #[test]
    fn max_rejects_invalid_dimension() {
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap();
        let args = vec![placeholder(), Value::Int(IntValue::I32(0))];
        let err = evaluate(Value::Tensor(tensor), &args).expect_err("expected error");
        assert!(err.contains("dimension must be >= 1"));
    }
}