runmat_runtime/builtins/math/reduction/max.rs

//! MATLAB-compatible `max` builtin with GPU-aware semantics for RunMat.

use std::cmp::Ordering;
use std::collections::BTreeSet;

use runmat_accelerate_api::{AccelProvider, GpuTensorHandle, ReduceDimResult};
use runmat_builtins::{ComplexTensor, Tensor, Value};
use runmat_macros::runtime_builtin;

use crate::builtins::common::broadcast::BroadcastPlan;
use crate::builtins::common::random_args::{complex_tensor_into_value, keyword_of};
use crate::builtins::common::{gpu_helpers, tensor};
#[cfg(feature = "doc_export")]
use crate::register_builtin_doc_text;
use crate::{register_builtin_fusion_spec, register_builtin_gpu_spec};

use crate::builtins::common::spec::{
    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, FusionError,
    FusionExprContext, FusionKernelTemplate, GpuOpKind, ProviderHook, ReductionNaN,
    ResidencyPolicy, ScalarType, ShapeRequirements,
};

#[cfg(feature = "doc_export")]
pub const DOC_MD: &str = r#"---
title: "max"
category: "math/reduction"
keywords: ["max", "maximum", "reduction", "comparisonmethod", "omitnan", "gpu"]
summary: "Return the maximum elements of scalars, vectors, matrices, or N-D tensors with MATLAB-compatible options."
references: []
gpu_support:
  elementwise: false
  reduction: true
  precisions: ["f32", "f64"]
  broadcasting: "matlab"
  notes: "Uses provider reduce_max_dim / reduce_max when available. Fallback gathers data to the host for omitnan, custom comparison modes, or complex inputs."
fusion:
  elementwise: false
  reduction: true
  max_inputs: 1
  constants: "inline"
requires_feature: null
tested:
  unit: "builtins::math::reduction::max::tests"
  integration: "builtins::math::reduction::max::tests::max_gpu_dim1_matches_cpu"
---

# What does the `max` function do in MATLAB / RunMat?
`max` returns the largest values in its input while preserving MATLAB semantics for reductions, elementwise comparisons, NaN handling, complex magnitude comparisons, and linear indexing.

## How does the `max` function behave in MATLAB / RunMat?
- `max(X)` on an `m × n` array reduces along the first non-singleton dimension, returning a row vector of column-wise maxima and the corresponding indices (when requested).
- `max(X, [], dim)` reduces along the specified dimension; `max(X, [], vecdim)` reduces along each dimension listed in `vecdim` (see the sketch after this list).
- `max(X, [], 'all')` collapses every element into a scalar and returns the linear index when two outputs are requested.
- `max(X, [], 'linear')` is equivalent to `'all'` but guarantees that the matching index is linear over `X(:)`.
- `max(X, [], ..., 'omitnan')` ignores `NaN` values inside each slice. If every element in a slice is `NaN`, the result for that slice is `NaN` and the index is `NaN`.
- `max(X, [], ..., 'includenan')` (default) propagates `NaN` whenever a slice contains any `NaN` element, returning the index of the first `NaN`.
- `max(A, B)` performs elementwise comparison using MATLAB's implicit expansion rules. The second output indicates whether the maximum came from `A` (index `1`) or `B` (index `2`).
- Complex inputs follow MATLAB ordering: `'ComparisonMethod','auto'` (default) compares magnitudes and breaks ties using phase angles, while `'real'` compares real components first. `'abs'` is an explicit alias for magnitude ordering on real and complex inputs.
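
A minimal `vecdim` sketch; the expected values follow from MATLAB's column-major `reshape`:
```matlab
T = reshape(1:8, [2 2 2]);
M = max(T, [], [1 2]);   % collapse rows and columns of each page
% M(:,:,1) = 4, M(:,:,2) = 8
```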

## `max` Function GPU Execution Behaviour
When RunMat Accelerate is active, tensors that already reside on the GPU stay on the device whenever the provider exposes `reduce_max_dim` (for dimension reductions) or `reduce_max` (for whole-array reductions). Requests that require `omitnan`, custom comparison modes, `'linear'` indices, or complex arithmetic gather the data to the host, compute the MATLAB-compatible result, and return the output on the host. Elementwise `max(A, B)` currently executes on the host; the planner rematerializes tensors on the GPU when follow-on fused kernels make it profitable.
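
An illustrative sketch of the device-resident path, assuming the active provider exposes `reduce_max_dim`:
```matlab
G = gpuArray(rand(4096, 4096));
[m, idx] = max(G, [], 1);   % values and indices stay on the device
m_host = gather(m);         % gather explicitly only when host data is needed
```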

## Examples of using the `max` function in MATLAB / RunMat

### Finding column-wise maxima of a matrix
```matlab
A = [3 1 5; 4 2 6];
[m, idx] = max(A);
```
Expected output:
```matlab
m   = [4 2 6];
idx = [2 2 2];
```

### Reducing along the second dimension
```matlab
A = [3 1 5; 4 2 6];
[m, idx] = max(A, [], 2);
```
Expected output:
```matlab
m   = [5; 6];
idx = [3; 3];
```

### Collapsing all elements with linear indices
```matlab
A = reshape(1:12, [3 4]);
[m, idx] = max(A, [], 'all');
```
Expected output:
```matlab
m   = 12;
idx = 12;  % linear index into A(:)
```

### Ignoring NaN values during reduction
```matlab
values = [NaN 4 2; 3 NaN 1];
[m, idx] = max(values, [], 1, 'omitnan');
```
Expected output:
```matlab
m   = [3 4 2];
idx = [2 1 1];
```

### Elementwise maximum with broadcasting
```matlab
A = [1 4 7];
B = [2; 3; 5];
[C, origin] = max(A, B);
```
Expected output:
```matlab
C =
     2     4     7
     3     4     7
     5     5     7

origin =
     2     1     1
     2     1     1
     2     2     1
```

### Comparing complex values by magnitude
```matlab
Z = [1+2i, 2+1i, -2+2i];
M = max(Z);                         % magnitude ordering
R = max(Z, [], 'ComparisonMethod', 'real');
```
Expected output:
```matlab
M = -2.0000 + 2.0000i
R = 2.0000 + 1.0000i
```

## GPU residency in RunMat (Do I need `gpuArray`?)
You typically do **not** need to call `gpuArray` manually. The fusion planner keeps tensors on the GPU between compatible kernels. When a reduction is supported by the active provider, the maximum values and indices stay on device. If a provider lacks the necessary hook, RunMat gathers data to the host, computes the result, and returns host tensors; subsequent fused GPU kernels can re-upload data when profitable.

## FAQ

### Can I request the linear index of the global maximum?
Yes. Use either `max(X, [], 'all')` or `max(X, [], 'linear')`. Both return a scalar maximum and the linear index into `X(:)` when you request two outputs.
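
A quick check using the linear-index rule above; `magic(3)` places its maximum, `9`, at linear index `6`:
```matlab
[m, idx] = max(magic(3), [], 'linear');
% m = 9, idx = 6
```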

### Does `max` support `'ComparisonMethod'` for real and complex arrays?
Absolutely. For complex arrays, `'auto'` and `'abs'` compare magnitudes and `'real'` compares the real component first; for real arrays, `'abs'` compares absolute values while `'auto'` orders by signed value. The returned values always match MATLAB, including tie-breaking rules.
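
For example, on real data the magnitude rule picks the element with the largest absolute value:
```matlab
m = max([-3 2], [], 'ComparisonMethod', 'abs');
% m = -3, since |-3| > |2|
```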

### What happens when all elements are `NaN` and `'omitnan'` is requested?
The value result is `NaN` and the index is `NaN`, matching MATLAB's behavior for empty slices.
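
A tiny example of the all-`NaN` case; the result follows the rule above:
```matlab
[m, idx] = max([NaN NaN], [], 'omitnan');
% m = NaN, idx = NaN
```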

### Can I mix elementwise comparisons with dimension reductions?
No. `max(A, B)` performs elementwise comparisons only. Use `max(A, [], dim)` when you want reductions along specific dimensions.

### Do GPU reductions support `'omitnan'` or custom comparison methods?
Not yet. Those requests fall back to the host implementation, which still honors MATLAB semantics. The output remains a host tensor in that case.

### Are logical and integer inputs supported?
Yes. Logical arrays are promoted to double precision, and integer inputs are converted to double before comparison, matching MATLAB's numeric tower.
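
Promotion in practice, per the conversion rule above:
```matlab
m = max(logical([0 1 0]), [], 'all');
% m = 1 (a double scalar)
```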

## See Also
[min](./min), [sum](./sum), [mean](./mean), [gpuArray](../../acceleration/gpu/gpuArray), [gather](../../acceleration/gpu/gather)

## Source & Feedback
- The full source code for the implementation of the `max` function is available at: [`crates/runmat-runtime/src/builtins/math/reduction/max.rs`](https://github.com/runmat-org/runmat/blob/main/crates/runmat-runtime/src/builtins/math/reduction/max.rs)
- Found a bug or behavioral difference? Please [open an issue](https://github.com/runmat-org/runmat/issues/new/choose) with details and a minimal repro.
"#;
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "max",
    op_kind: GpuOpKind::Reduction,
    supported_precisions: &[ScalarType::F32, ScalarType::F64],
    broadcast: BroadcastSemantics::Matlab,
    provider_hooks: &[
        ProviderHook::Reduction {
            name: "reduce_max_dim",
        },
        ProviderHook::Reduction {
            name: "reduce_max",
        },
    ],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::NewHandle,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: Some(256),
    workgroup_size: Some(256),
    accepts_nan_mode: false,
    notes:
        "Providers should implement reduce_max_dim / reduce_max. Requests that require omitnan, comparisonmethod overrides, or complex inputs fall back to the host implementation.",
};

register_builtin_gpu_spec!(GPU_SPEC);

pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "max",
    shape: ShapeRequirements::BroadcastCompatible,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: Some(FusionKernelTemplate {
        scalar_precisions: &[ScalarType::F32, ScalarType::F64],
        wgsl_body: |ctx: &FusionExprContext| {
            let input = ctx.inputs.first().ok_or(FusionError::MissingInput(0))?;
            Ok(format!("accumulator = max(accumulator, {input});"))
        },
    }),
    emits_nan: true,
    notes: "Fusion planner emits canonical reduction kernels; providers may substitute custom WGSL via reduce_max_dim hooks.",
};

register_builtin_fusion_spec!(FUSION_SPEC);

#[cfg(feature = "doc_export")]
register_builtin_doc_text!("max", DOC_MD);

/// Evaluation artifact returned by `max` that carries both values and indices.
#[derive(Debug, Clone)]
pub struct MaxEvaluation {
    values: Value,
    indices: Value,
}

impl MaxEvaluation {
    /// Consume the evaluation and return only the maximum values (single-output call).
    pub fn into_value(self) -> Value {
        self.values
    }

    /// Consume the evaluation and return both maxima and indices.
    pub fn into_pair(self) -> (Value, Value) {
        (self.values, self.indices)
    }

    /// Peek at the indices without consuming.
    pub fn indices_value(&self) -> Value {
        self.indices.clone()
    }
}

#[runtime_builtin(
    name = "max",
    category = "math/reduction",
    summary = "Return the maximum elements of scalars, vectors, matrices, or N-D tensors.",
    keywords = "max,maximum,reduction,gpu,comparisonmethod,omitnan",
    accel = "reduction"
)]
fn max_builtin(value: Value, rest: Vec<Value>) -> Result<Value, String> {
    evaluate(value, &rest).map(|eval| eval.into_value())
}

/// Evaluate the builtin once and expose both outputs (value + indices).
pub fn evaluate(value: Value, rest: &[Value]) -> Result<MaxEvaluation, String> {
    let parsed = parse_call(rest)?;
    if std::env::var("RUNMAT_DEBUG_MAX").is_ok() {
        let call_label = match &parsed {
            ParsedCall::Reduction(_) => "reduction",
            ParsedCall::Elementwise(_) => "elementwise",
        };
        let first_arg = rest.first().map(debug_value_kind).unwrap_or("None");
        eprintln!(
            "[runmat-debug-max] call_type={call_label} rest_len={} first_arg={first_arg}",
            rest.len()
        );
    }
    match parsed {
        ParsedCall::Elementwise(args) => elementwise_max(value, args),
        ParsedCall::Reduction(args) => reduction_max(value, args),
    }
}

#[derive(Debug, Clone)]
enum ParsedCall {
    Reduction(ReductionArgs),
    Elementwise(ElementwiseArgs),
}

#[derive(Debug, Clone)]
struct ReductionArgs {
    selection: DimSelection,
    nan_mode: ReductionNaN,
    comparison: ComparisonMethod,
    linear_index: bool,
}

impl Default for ReductionArgs {
    fn default() -> Self {
        Self {
            selection: DimSelection::Auto,
            nan_mode: ReductionNaN::Include,
            comparison: ComparisonMethod::Auto,
            linear_index: false,
        }
    }
}

#[derive(Debug, Clone)]
enum DimSelection {
    Auto,
    Dim(usize),
    Vec(Vec<usize>),
    All,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ComparisonMethod {
    Auto,
    Real,
    Abs,
}

#[derive(Debug, Clone)]
struct ElementwiseArgs {
    other: Value,
    comparison: ComparisonMethod,
}

fn parse_call(rest: &[Value]) -> Result<ParsedCall, String> {
    if rest.is_empty() {
        return Ok(ParsedCall::Reduction(ReductionArgs::default()));
    }

    let first = &rest[0];
    if !is_empty_placeholder(first) {
        let comparison = parse_elementwise_options(&rest[1..])?;
        return Ok(ParsedCall::Elementwise(ElementwiseArgs {
            other: first.clone(),
            comparison,
        }));
    }

    let mut args = ReductionArgs::default();
    parse_reduction_options(&mut args, &rest[1..])?;
    Ok(ParsedCall::Reduction(args))
}

fn debug_value_kind(value: &Value) -> &'static str {
    match value {
        Value::Num(_) => "Num",
        Value::Int(_) => "Int",
        Value::Bool(_) => "Bool",
        Value::Tensor(t) => {
            if t.data.is_empty() {
                "Tensor(empty)"
            } else {
                "Tensor"
            }
        }
        Value::GpuTensor(_) => "GpuTensor",
        Value::String(_) => "String",
        Value::CharArray(_) => "CharArray",
        Value::StringArray(sa) => {
            if sa.data.is_empty() {
                "StringArray(empty)"
            } else {
                "StringArray"
            }
        }
        Value::LogicalArray(l) => {
            if l.data.is_empty() {
                "LogicalArray(empty)"
            } else {
                "LogicalArray"
            }
        }
        Value::Cell(c) => {
            if c.data.is_empty() {
                "Cell(empty)"
            } else {
                "Cell"
            }
        }
        _ => "Other",
    }
}

fn is_empty_placeholder(value: &Value) -> bool {
    match value {
        Value::Tensor(t) => t.data.is_empty(),
        Value::LogicalArray(l) => l.data.is_empty(),
        Value::StringArray(sa) => sa.data.is_empty(),
        Value::CharArray(ca) => ca.data.is_empty(),
        Value::Cell(cell) => cell.data.is_empty(),
        Value::String(s) => s.is_empty(),
        _ => false,
    }
}

fn parse_reduction_options(args: &mut ReductionArgs, rest: &[Value]) -> Result<(), String> {
    let mut idx = 0usize;
    let mut selection_set = !matches!(args.selection, DimSelection::Auto);
    let mut comparison_set = !matches!(args.comparison, ComparisonMethod::Auto);
    while idx < rest.len() {
        if let Some(keyword) = keyword_of(&rest[idx]) {
            match keyword.as_str() {
                "omitnan" => {
                    args.nan_mode = ReductionNaN::Omit;
                    idx += 1;
                    continue;
                }
                "includenan" => {
                    args.nan_mode = ReductionNaN::Include;
                    idx += 1;
                    continue;
                }
                "all" => {
                    if selection_set {
                        return Err(
                            "max: 'all' cannot be combined with an explicit dimension".to_string()
                        );
                    }
                    args.selection = DimSelection::All;
                    selection_set = true;
                    idx += 1;
                    continue;
                }
                "linear" => {
                    if selection_set {
                        return Err(
                            "max: 'linear' cannot be combined with an explicit dimension"
                                .to_string(),
                        );
                    }
                    args.selection = DimSelection::All;
                    args.linear_index = true;
                    selection_set = true;
                    idx += 1;
                    continue;
                }
                "comparisonmethod" => {
                    let Some(value) = rest.get(idx + 1) else {
                        return Err("max: expected a value after 'ComparisonMethod'".to_string());
                    };
                    args.comparison = parse_comparison_method(value)?;
                    comparison_set = true;
                    idx += 2;
                    continue;
                }
                _ => {}
            }
        }

        if !selection_set {
            if let Some(selection) = parse_dimension_value(&rest[idx])? {
                args.selection = selection;
                selection_set = true;
                idx += 1;
                continue;
            }
        }

        return Err(format!("max: unrecognised argument {:?}", rest[idx]));
    }

    if !comparison_set {
        args.comparison = ComparisonMethod::Auto;
    }

    Ok(())
}

fn parse_elementwise_options(rest: &[Value]) -> Result<ComparisonMethod, String> {
    let mut comparison = ComparisonMethod::Auto;
    let mut comparison_set = false;
    let mut idx = 0usize;
    while idx < rest.len() {
        if let Some(keyword) = keyword_of(&rest[idx]) {
            match keyword.as_str() {
                "comparisonmethod" => {
                    let Some(value) = rest.get(idx + 1) else {
                        return Err("max: expected a value after 'ComparisonMethod'".to_string());
                    };
                    comparison = parse_comparison_method(value)?;
                    comparison_set = true;
                    idx += 2;
                    continue;
                }
                "omitnan" | "includenan" | "all" | "linear" => {
                    return Err(format!(
                        "max: '{}' is only supported for reduction calls",
                        keyword
                    ));
                }
                _ => {}
            }
        }
        return Err(format!("max: unrecognised argument {:?}", rest[idx]));
    }
    if !comparison_set {
        comparison = ComparisonMethod::Auto;
    }
    Ok(comparison)
}

fn parse_comparison_method(value: &Value) -> Result<ComparisonMethod, String> {
    let Some(keyword) = keyword_of(value) else {
        return Err("max: 'ComparisonMethod' expects a string value".to_string());
    };
    match keyword.as_str() {
        "auto" => Ok(ComparisonMethod::Auto),
        "abs" | "magnitude" => Ok(ComparisonMethod::Abs),
        "real" => Ok(ComparisonMethod::Real),
        other => Err(format!("max: unsupported ComparisonMethod '{other}'")),
    }
}

fn parse_dimension_value(value: &Value) -> Result<Option<DimSelection>, String> {
    match value {
        Value::Int(i) => {
            let raw = i.to_i64();
            if raw < 1 {
                return Err("max: dimension must be >= 1".to_string());
            }
            Ok(Some(DimSelection::Dim(raw as usize)))
        }
        Value::Num(n) => {
            if !n.is_finite() {
                return Err("max: dimension must be finite".to_string());
            }
            let rounded = n.round();
            if (rounded - n).abs() > f64::EPSILON {
                return Err("max: dimension must be integral".to_string());
            }
            if rounded < 1.0 {
                return Err("max: dimension must be >= 1".to_string());
            }
            Ok(Some(DimSelection::Dim(rounded as usize)))
        }
        Value::Tensor(t) => parse_dimension_tensor(t),
        Value::LogicalArray(logical) => {
            let tensor = tensor::logical_to_tensor(logical)?;
            parse_dimension_tensor(&tensor)
        }
        Value::GpuTensor(_) => Err(
            "max: dimension arguments must reside on the host (they cannot be gpuArray values)"
                .to_string(),
        ),
        _ => Ok(None),
    }
}

fn parse_dimension_tensor(tensor: &Tensor) -> Result<Option<DimSelection>, String> {
    if tensor.data.is_empty() {
        return Ok(Some(DimSelection::Auto));
    }
    if tensor.rows() != 1 && tensor.cols() != 1 && tensor.shape.len() != 1 {
        return Err("max: dimension vector must be a row or column vector".to_string());
    }
    let mut dims = Vec::with_capacity(tensor.data.len());
    for &value in &tensor.data {
        if !value.is_finite() {
            return Err("max: dimension entries must be finite".to_string());
        }
        let rounded = value.round();
        if (rounded - value).abs() > f64::EPSILON {
            return Err("max: dimension entries must be integers".to_string());
        }
        if rounded < 1.0 {
            return Err("max: dimension indices must be >= 1".to_string());
        }
        dims.push(rounded as usize);
    }
    if dims.is_empty() {
        Ok(Some(DimSelection::Auto))
    } else {
        // MATLAB treats duplicate entries gracefully; remove duplicates while preserving order.
        let mut seen = BTreeSet::new();
        let mut uniq = Vec::with_capacity(dims.len());
        for dim in dims {
            if seen.insert(dim) {
                uniq.push(dim);
            }
        }
        Ok(Some(DimSelection::Vec(uniq)))
    }
}

fn reduction_max(value: Value, args: ReductionArgs) -> Result<MaxEvaluation, String> {
    match value {
        Value::GpuTensor(handle) => {
            if let Some(eval) = reduction_max_gpu(handle.clone(), &args)? {
                return Ok(eval);
            }
            // Fall back to host if GPU path is unavailable.
            let tensor = gpu_helpers::gather_tensor(&handle)?;
            reduction_max_host(Value::Tensor(tensor), &args)
        }
        other => reduction_max_host(other, &args),
    }
}

fn reduction_max_gpu(
    handle: GpuTensorHandle,
    args: &ReductionArgs,
) -> Result<Option<MaxEvaluation>, String> {
    #[cfg(all(test, feature = "wgpu"))]
    {
        if handle.device_id != 0 {
            let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
                runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
            );
        }
    }
    if args.nan_mode == ReductionNaN::Omit {
        return Ok(None);
    }
    if args.comparison != ComparisonMethod::Auto {
        return Ok(None);
    }
    if args.linear_index {
        return Ok(None);
    }
    let provider = match runmat_accelerate_api::provider() {
        Some(p) => p,
        None => return Ok(None),
    };
    let target_dim = match args.selection {
        DimSelection::Auto => default_dimension_from_shape(&handle.shape),
        DimSelection::Dim(dim) => dim,
        DimSelection::Vec(ref dims) if dims.len() == 1 => dims[0],
        DimSelection::All => {
            if handle.shape.len() <= 1 {
                1
            } else {
                return Ok(None);
            }
        }
        _ => return Ok(None),
    };
    if target_dim == 0 {
        return Ok(None);
    }
    // MATLAB dimensions are 1-based; `reduce_max_dim` expects zero-based.
    let zero_based = target_dim.saturating_sub(1);
    if zero_based >= handle.shape.len() {
        return Ok(None);
    }
    match provider.reduce_max_dim(&handle, zero_based) {
        Ok(ReduceDimResult { values, indices }) => Ok(Some(MaxEvaluation {
            values: Value::GpuTensor(values),
            indices: Value::GpuTensor(indices),
        })),
        Err(_) => Ok(None),
    }
}

fn reduction_max_host(value: Value, args: &ReductionArgs) -> Result<MaxEvaluation, String> {
    match materialize_for_max("max", value)? {
        InputData::Real(tensor) => reduce_real_tensor(tensor, args),
        InputData::Complex(tensor) => reduce_complex_tensor(tensor, args),
    }
}

enum InputData {
    Real(Tensor),
    Complex(ComplexTensor),
}

fn materialize_for_max(name: &str, value: Value) -> Result<InputData, String> {
    match value {
        Value::Tensor(t) => Ok(InputData::Real(t)),
        Value::LogicalArray(logical) => {
            let tensor = tensor::logical_to_tensor(&logical)?;
            Ok(InputData::Real(tensor))
        }
        Value::Num(n) => {
            let tensor = Tensor::new(vec![n], vec![1, 1]).map_err(|e| format!("{name}: {e}"))?;
            Ok(InputData::Real(tensor))
        }
        Value::Int(i) => {
            let tensor =
                Tensor::new(vec![i.to_f64()], vec![1, 1]).map_err(|e| format!("{name}: {e}"))?;
            Ok(InputData::Real(tensor))
        }
        Value::Bool(b) => {
            let tensor = Tensor::new(vec![if b { 1.0 } else { 0.0 }], vec![1, 1])
                .map_err(|e| format!("{name}: {e}"))?;
            Ok(InputData::Real(tensor))
        }
        Value::Complex(re, im) => {
            let tensor = ComplexTensor::new(vec![(re, im)], vec![1, 1])
                .map_err(|e| format!("{name}: {e}"))?;
            Ok(InputData::Complex(tensor))
        }
        Value::ComplexTensor(ct) => Ok(InputData::Complex(ct)),
        Value::String(_) | Value::StringArray(_) | Value::CharArray(_) | Value::Cell(_) => Err(
            format!("{name}: expected numeric or logical input, received non-numeric value"),
        ),
        Value::GpuTensor(_) => Err(format!(
            "{name}: internal error – GPU tensors must be gathered before host execution"
        )),
        Value::Object(_) | Value::HandleObject(_) | Value::Struct(_) | Value::Listener(_) => {
            Err(format!("{name}: unsupported input type"))
        }
        Value::FunctionHandle(_)
        | Value::Closure(_)
        | Value::ClassRef(_)
        | Value::MException(_) => Err(format!("{name}: unsupported input type")),
    }
}

fn reduce_real_tensor(tensor: Tensor, args: &ReductionArgs) -> Result<MaxEvaluation, String> {
    let shape = tensor.shape.clone();
    if tensor.data.is_empty() {
        let output_shape = resolve_output_shape(&shape, &args.selection, &[])?;
        let values =
            Tensor::new(Vec::new(), output_shape.clone()).map_err(|e| format!("max: {e}"))?;
        let indices = Tensor::new(Vec::new(), output_shape).map_err(|e| format!("max: {e}"))?;
        return Ok(MaxEvaluation {
            values: tensor::tensor_into_value(values),
            indices: tensor::tensor_into_value(indices),
        });
    }
    let resolved = resolve_reduction_dims(&shape, &args.selection)?;
    let output_shape = resolved.output_shape.clone();
    let output_len = tensor::element_count(&output_shape);

    if output_len == 0 {
        let values =
            Tensor::new(Vec::new(), output_shape.clone()).map_err(|e| format!("max: {e}"))?;
        let indices = Tensor::new(Vec::new(), output_shape).map_err(|e| format!("max: {e}"))?;
        return Ok(MaxEvaluation {
            values: tensor::tensor_into_value(values),
            indices: tensor::tensor_into_value(indices),
        });
    }

    let strides = compute_strides(&shape);
    let output_strides = compute_strides(&output_shape);
    let dims_mask = resolved.dims_mask.clone();
    let reduce_strides = resolved.reduce_strides.clone();

    let mut best = vec![BestReal::new(); output_len];
    let mut coords = vec![0usize; shape.len()];
    for &value in &tensor.data {
        let out_idx = map_output_index(&coords, &output_strides, &dims_mask);
        let reduce_idx = map_reduce_index(
            &coords,
            &resolved.reduced_dims,
            &reduce_strides,
            resolved.reduce_all,
        );
        let full_idx = map_linear_index(&coords, &strides);

        update_best_real(
            &mut best[out_idx],
            value,
            reduce_idx,
            full_idx,
            args.nan_mode,
            args.comparison,
        );
        increment_coords(&mut coords, &shape);
    }

    let mut values = vec![0.0f64; output_len];
    let mut indices = vec![0.0f64; output_len];

    for (i, entry) in best.iter().enumerate() {
        if entry.nan_fixed {
            values[i] = f64::NAN;
            indices[i] = if args.linear_index || resolved.reduce_all {
                (entry.full_index + 1) as f64
            } else if resolved.reduced_dims.is_empty() {
                1.0
            } else {
                (entry.reduce_index + 1) as f64
            };
            continue;
        }
        if !entry.has_value {
            values[i] = f64::NAN;
            indices[i] = f64::NAN;
            continue;
        }
        values[i] = entry.value;
        indices[i] = if args.linear_index || resolved.reduce_all {
            (entry.full_index + 1) as f64
        } else if resolved.reduced_dims.is_empty() {
            1.0
        } else {
            (entry.reduce_index + 1) as f64
        };
    }

    let value_tensor =
        Tensor::new(values, output_shape.clone()).map_err(|e| format!("max: {e}"))?;
    let index_tensor = Tensor::new(indices, output_shape).map_err(|e| format!("max: {e}"))?;

    Ok(MaxEvaluation {
        values: tensor::tensor_into_value(value_tensor),
        indices: tensor::tensor_into_value(index_tensor),
    })
}

fn reduce_complex_tensor(
    tensor: ComplexTensor,
    args: &ReductionArgs,
) -> Result<MaxEvaluation, String> {
    let shape = tensor.shape.clone();
    if tensor.data.is_empty() {
        let output_shape = resolve_output_shape(&shape, &args.selection, &[])?;
        let values = ComplexTensor::new(Vec::new(), output_shape.clone())
            .map_err(|e| format!("max: {e}"))?;
        let indices = Tensor::new(Vec::new(), output_shape).map_err(|e| format!("max: {e}"))?;
        return Ok(MaxEvaluation {
            values: complex_tensor_into_value(values),
            indices: tensor::tensor_into_value(indices),
        });
    }

    let resolved = resolve_reduction_dims(&shape, &args.selection)?;
    let output_shape = resolved.output_shape.clone();
    let output_len = tensor::element_count(&output_shape);

    if output_len == 0 {
        let values = ComplexTensor::new(Vec::new(), output_shape.clone())
            .map_err(|e| format!("max: {e}"))?;
        let indices = Tensor::new(Vec::new(), output_shape).map_err(|e| format!("max: {e}"))?;
        return Ok(MaxEvaluation {
            values: complex_tensor_into_value(values),
            indices: tensor::tensor_into_value(indices),
        });
    }

    let strides = compute_strides(&shape);
    let output_strides = compute_strides(&output_shape);
    let dims_mask = resolved.dims_mask.clone();
    let reduce_strides = resolved.reduce_strides.clone();

    let mut best = vec![BestComplex::new(); output_len];
    let mut coords = vec![0usize; shape.len()];

    for &(re, im) in &tensor.data {
        let out_idx = map_output_index(&coords, &output_strides, &dims_mask);
        let reduce_idx = map_reduce_index(
            &coords,
            &resolved.reduced_dims,
            &reduce_strides,
            resolved.reduce_all,
        );
        let full_idx = map_linear_index(&coords, &strides);
        update_best_complex(
            &mut best[out_idx],
            (re, im),
            reduce_idx,
            full_idx,
            args.nan_mode,
            args.comparison,
        );
        increment_coords(&mut coords, &shape);
    }

    let mut values = vec![(0.0f64, 0.0f64); output_len];
    let mut indices = vec![0.0f64; output_len];

    for (i, entry) in best.iter().enumerate() {
        if entry.nan_fixed {
            values[i] = (f64::NAN, f64::NAN);
            indices[i] = if args.linear_index || resolved.reduce_all {
                (entry.full_index + 1) as f64
            } else if resolved.reduced_dims.is_empty() {
                1.0
            } else {
                (entry.reduce_index + 1) as f64
            };
            continue;
        }
        if !entry.has_value {
            values[i] = (f64::NAN, f64::NAN);
            indices[i] = f64::NAN;
            continue;
        }
        values[i] = entry.value;
        indices[i] = if args.linear_index || resolved.reduce_all {
            (entry.full_index + 1) as f64
        } else if resolved.reduced_dims.is_empty() {
            1.0
        } else {
            (entry.reduce_index + 1) as f64
        };
    }

    let value_tensor =
        ComplexTensor::new(values, output_shape.clone()).map_err(|e| format!("max: {e}"))?;
    let index_tensor = Tensor::new(indices, output_shape).map_err(|e| format!("max: {e}"))?;
    Ok(MaxEvaluation {
        values: complex_tensor_into_value(value_tensor),
        indices: tensor::tensor_into_value(index_tensor),
    })
}

#[derive(Debug, Clone)]
struct BestReal {
    value: f64,
    reduce_index: usize,
    full_index: usize,
    has_value: bool,
    nan_fixed: bool,
}

impl BestReal {
    fn new() -> Self {
        Self {
            value: 0.0,
            reduce_index: 0,
            full_index: 0,
            has_value: false,
            nan_fixed: false,
        }
    }
}

#[derive(Debug, Clone)]
struct BestComplex {
    value: (f64, f64),
    reduce_index: usize,
    full_index: usize,
    has_value: bool,
    nan_fixed: bool,
}

impl BestComplex {
    fn new() -> Self {
        Self {
            value: (0.0, 0.0),
            reduce_index: 0,
            full_index: 0,
            has_value: false,
            nan_fixed: false,
        }
    }
}

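/// Shape of the reduction result: reduced dimensions collapse to length 1,
/// and `'all'` collapses every dimension.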
fn resolve_output_shape(
    shape: &[usize],
    selection: &DimSelection,
    reduced_dims: &[usize],
) -> Result<Vec<usize>, String> {
    if shape.is_empty() {
        return Ok(Vec::new());
    }
    let mut output = shape.to_vec();
    match selection {
        DimSelection::All => {
            output.fill(1);
        }
        _ => {
            for &dim in reduced_dims {
                if dim < output.len() {
                    output[dim] = 1;
                }
            }
        }
    }
    Ok(output)
}

struct ResolvedDims {
    output_shape: Vec<usize>,
    reduced_dims: Vec<usize>,
    reduce_all: bool,
    dims_mask: Vec<bool>,
    reduce_strides: Vec<usize>,
}

fn resolve_reduction_dims(
    shape: &[usize],
    selection: &DimSelection,
) -> Result<ResolvedDims, String> {
    if shape.is_empty() {
        return Ok(ResolvedDims {
            output_shape: Vec::new(),
            reduced_dims: Vec::new(),
            reduce_all: true,
            dims_mask: Vec::new(),
            reduce_strides: Vec::new(),
        });
    }

    let mut reduced_dims = match selection {
        DimSelection::Auto => {
            let mut dim = None;
            for (index, &len) in shape.iter().enumerate() {
                if len > 1 {
                    dim = Some(index);
                    break;
                }
            }
            vec![dim.unwrap_or(0)]
        }
        DimSelection::Dim(dim) => {
            if *dim == 0 {
                return Err("max: dimension must be >= 1".to_string());
            }
            let index = dim.saturating_sub(1);
            if index >= shape.len() {
                Vec::new()
            } else {
                vec![index]
            }
        }
        DimSelection::Vec(dims) => {
            if dims.is_empty() {
                Vec::new()
            } else {
                dims.iter()
                    .filter_map(|dim| {
                        if *dim == 0 {
                            None
                        } else {
                            let idx = dim - 1;
                            if idx < shape.len() {
                                Some(idx)
                            } else {
                                None
                            }
                        }
                    })
                    .collect()
            }
        }
        DimSelection::All => (0..shape.len()).collect(),
    };

    reduced_dims.sort_unstable();
    reduced_dims.dedup();

    let reduce_all = !reduced_dims.is_empty()
        && reduced_dims.len() == shape.len()
        && reduced_dims.iter().enumerate().all(|(i, &d)| i == d);

    let output_shape = resolve_output_shape(shape, selection, &reduced_dims)?;
    let mut dims_mask = vec![false; shape.len()];
    for &dim in &reduced_dims {
        if dim < dims_mask.len() {
            dims_mask[dim] = true;
        }
    }
    let reduce_strides = compute_subspace_strides(shape, &reduced_dims);

    Ok(ResolvedDims {
        output_shape,
        reduced_dims,
        reduce_all,
        dims_mask,
        reduce_strides,
    })
}

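/// Column-major strides for `shape`: e.g. `[2, 3, 4]` yields `[1, 2, 6]`, so
/// advancing one step along dimension `k` moves the linear offset by `strides[k]`.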
fn compute_strides(shape: &[usize]) -> Vec<usize> {
    let mut strides = Vec::with_capacity(shape.len());
    let mut stride = 1usize;
    for &len in shape {
        strides.push(stride);
        stride = stride.saturating_mul(len.max(1));
    }
    strides
}

fn compute_subspace_strides(shape: &[usize], dims: &[usize]) -> Vec<usize> {
    if dims.is_empty() {
        return Vec::new();
    }
    let mut strides = Vec::with_capacity(dims.len());
    let mut accum = 1usize;
    for &dim in dims {
        let len = shape.get(dim).copied().unwrap_or(1).max(1);
        strides.push(accum);
        accum = accum.saturating_mul(len);
    }
    strides
}

fn map_output_index(coords: &[usize], output_strides: &[usize], dims_mask: &[bool]) -> usize {
    if coords.is_empty() {
        return 0;
    }
    let mut index = 0usize;
    for (dim, stride) in output_strides.iter().enumerate() {
        let coord = if *dims_mask.get(dim).unwrap_or(&false) {
            0
        } else {
            coords[dim]
        };
        index = index.saturating_add(coord.saturating_mul(*stride));
    }
    index
}

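/// Zero-based position of the current element within the reduced subspace only
/// (its coordinates along `reduced_dims`), used to report per-slice indices.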
fn map_reduce_index(
    coords: &[usize],
    reduced_dims: &[usize],
    reduce_strides: &[usize],
    reduce_all: bool,
) -> usize {
    if reduced_dims.is_empty() {
        return 0;
    }
    if reduce_all {
        // When all dimensions are reduced, the full index is used separately.
        return 0;
    }
    let mut index = 0usize;
    for (pos, &dim) in reduced_dims.iter().enumerate() {
        if let Some(coord) = coords.get(dim) {
            if let Some(stride) = reduce_strides.get(pos) {
                index = index.saturating_add(coord.saturating_mul(*stride));
            }
        }
    }
    index
}

fn map_linear_index(coords: &[usize], strides: &[usize]) -> usize {
    coords
        .iter()
        .zip(strides.iter())
        .fold(0usize, |acc, (&coord, &stride)| {
            acc.saturating_add(coord.saturating_mul(stride))
        })
}

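/// Odometer-style advance through the tensor in column-major order: bump the
/// first dimension and carry into the next whenever a coordinate wraps.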
fn increment_coords(coords: &mut [usize], shape: &[usize]) {
    for dim in 0..coords.len() {
        if shape[dim] == 0 {
            continue;
        }
        coords[dim] += 1;
        if coords[dim] < shape[dim] {
            break;
        }
        coords[dim] = 0;
    }
}

fn update_best_real(
    best: &mut BestReal,
    value: f64,
    reduce_index: usize,
    full_index: usize,
    nan_mode: ReductionNaN,
    comparison: ComparisonMethod,
) {
    if value.is_nan() {
        match nan_mode {
            ReductionNaN::Include => {
                if !best.nan_fixed {
                    best.value = f64::NAN;
                    best.reduce_index = reduce_index;
                    best.full_index = full_index;
                    best.has_value = true;
                    best.nan_fixed = true;
                }
            }
            ReductionNaN::Omit => {}
        }
        return;
    }
    if best.nan_fixed {
        return;
    }

    if !best.has_value {
        best.value = value;
        best.reduce_index = reduce_index;
        best.full_index = full_index;
        best.has_value = true;
        return;
    }

    if should_replace_real(best.value, value, comparison) {
        best.value = value;
        best.reduce_index = reduce_index;
        best.full_index = full_index;
    }
}

fn update_best_complex(
    best: &mut BestComplex,
    value: (f64, f64),
    reduce_index: usize,
    full_index: usize,
    nan_mode: ReductionNaN,
    comparison: ComparisonMethod,
) {
    if value.0.is_nan() || value.1.is_nan() {
        match nan_mode {
            ReductionNaN::Include => {
                if !best.nan_fixed {
                    best.value = (f64::NAN, f64::NAN);
                    best.reduce_index = reduce_index;
                    best.full_index = full_index;
                    best.has_value = true;
                    best.nan_fixed = true;
                }
            }
            ReductionNaN::Omit => {}
        }
        return;
    }
    if best.nan_fixed {
        return;
    }

    if !best.has_value {
        best.value = value;
        best.reduce_index = reduce_index;
        best.full_index = full_index;
        best.has_value = true;
        return;
    }

    if should_replace_complex(best.value, value, comparison) {
        best.value = value;
        best.reduce_index = reduce_index;
        best.full_index = full_index;
    }
}

fn should_replace_real(current: f64, candidate: f64, comparison: ComparisonMethod) -> bool {
    match comparison {
        ComparisonMethod::Auto | ComparisonMethod::Real => {
            if candidate > current {
                return true;
            }
            if candidate < current {
                return false;
            }
            if candidate == 0.0 && current == 0.0 {
                return candidate.is_sign_positive() && !current.is_sign_positive();
            }
            false
        }
        ComparisonMethod::Abs => {
            let curr_abs = current.abs();
            let cand_abs = candidate.abs();
            if cand_abs > curr_abs {
                return true;
            }
            if cand_abs < curr_abs {
                return false;
            }
            if candidate > current {
                return true;
            }
            if candidate < current {
                return false;
            }
            if candidate == 0.0 && current == 0.0 {
                return candidate.is_sign_positive() && !current.is_sign_positive();
            }
            false
        }
    }
}

fn should_replace_complex(
    current: (f64, f64),
    candidate: (f64, f64),
    comparison: ComparisonMethod,
) -> bool {
    match comparison {
        ComparisonMethod::Auto | ComparisonMethod::Abs => {
            compare_complex_auto(current, candidate) == Ordering::Less
        }
        ComparisonMethod::Real => compare_complex_real(current, candidate) == Ordering::Less,
    }
}

fn compare_complex_auto(a: (f64, f64), b: (f64, f64)) -> Ordering {
    let a_mag = magnitude_squared(a);
    let b_mag = magnitude_squared(b);
    if a_mag < b_mag {
        return Ordering::Less;
    }
    if a_mag > b_mag {
        return Ordering::Greater;
    }
    // Equal magnitude: tie-break using phase angle.
    let a_angle = a.1.atan2(a.0);
    let b_angle = b.1.atan2(b.0);
    if a_angle < b_angle {
        Ordering::Less
    } else if a_angle > b_angle {
        Ordering::Greater
    } else {
        Ordering::Equal
    }
}

fn compare_complex_real(a: (f64, f64), b: (f64, f64)) -> Ordering {
    if a.0 < b.0 {
        return Ordering::Less;
    }
    if a.0 > b.0 {
        return Ordering::Greater;
    }
    // Equal real parts: use magnitude and phase tie-breakers.
    compare_complex_auto(a, b)
}

fn magnitude_squared(z: (f64, f64)) -> f64 {
    z.0.mul_add(z.0, z.1 * z.1)
}

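/// First non-singleton dimension as a 1-based MATLAB dimension; e.g. shape
/// `[1, 3]` yields 2, and an all-singleton shape yields 1.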
fn default_dimension_from_shape(shape: &[usize]) -> usize {
    if shape.is_empty() {
        return 1;
    }
    for (i, &len) in shape.iter().enumerate() {
        if len > 1 {
            return i + 1;
        }
    }
    1
}

fn elementwise_max(value: Value, args: ElementwiseArgs) -> Result<MaxEvaluation, String> {
    let ElementwiseArgs { other, comparison } = args;
    match (value, other) {
        (Value::GpuTensor(handle_a), Value::GpuTensor(handle_b)) => {
            if gpu_tensor_is_scalar(&handle_b) {
                if let Some(num) = gpu_tensor_scalar_value(&handle_b) {
                    let scalar = Value::Num(num);
                    return elementwise_max_gpu_scalar_left(&handle_a, &scalar, comparison)
                        .or_else(|| {
                            let ta = gpu_helpers::gather_tensor(&handle_a).ok()?;
                            elementwise_real_or_complex(
                                Value::Tensor(ta),
                                scalar.clone(),
                                comparison,
                            )
                            .ok()
                        })
                        .ok_or_else(|| "max: elementwise GPU scalar path failed".to_string());
                }
            }
            if gpu_tensor_is_scalar(&handle_a) {
                if let Some(num) = gpu_tensor_scalar_value(&handle_a) {
                    let scalar = Value::Num(num);
                    return elementwise_max_gpu_scalar_right(&scalar, &handle_b, comparison)
                        .or_else(|| {
                            let tb = gpu_helpers::gather_tensor(&handle_b).ok()?;
                            elementwise_real_or_complex(
                                scalar.clone(),
                                Value::Tensor(tb),
                                comparison,
                            )
                            .ok()
                        })
                        .ok_or_else(|| "max: elementwise GPU scalar path failed".to_string());
                }
            }
            elementwise_max_gpu_pair(&handle_a, &handle_b, comparison)
                .or_else(|| {
                    // Fallback to host path if provider path unavailable or unsupported
                    let ta = gpu_helpers::gather_tensor(&handle_a).ok()?;
                    let tb = gpu_helpers::gather_tensor(&handle_b).ok()?;
                    elementwise_real_or_complex(Value::Tensor(ta), Value::Tensor(tb), comparison)
                        .ok()
                })
                .ok_or_else(|| "max: elementwise GPU path failed".to_string())
        }
        (Value::GpuTensor(handle), other) => {
            elementwise_max_gpu_scalar_left(&handle, &other, comparison)
                .or_else(|| {
                    let t = gpu_helpers::gather_tensor(&handle).ok()?;
                    elementwise_real_or_complex(Value::Tensor(t), other, comparison).ok()
                })
                .ok_or_else(|| "max: elementwise GPU scalar path failed".to_string())
        }
        (other, Value::GpuTensor(handle)) => {
            elementwise_max_gpu_scalar_right(&other, &handle, comparison)
                .or_else(|| {
                    let t = gpu_helpers::gather_tensor(&handle).ok()?;
                    elementwise_real_or_complex(other, Value::Tensor(t), comparison).ok()
                })
                .ok_or_else(|| "max: elementwise GPU scalar path failed".to_string())
        }
        (lhs, rhs) => elementwise_real_or_complex(lhs, rhs, comparison),
    }
}

fn elementwise_max_gpu_pair(
    a: &GpuTensorHandle,
    b: &GpuTensorHandle,
    comparison: ComparisonMethod,
) -> Option<MaxEvaluation> {
    if comparison != ComparisonMethod::Auto {
        return None;
    }
    let provider = runmat_accelerate_api::provider()?;
    // Equal-shape fast path
    if a.shape == b.shape {
        let values = provider.elem_max(a, b).ok()?;
        // Try device mask first; if unavailable, compute indices on host while keeping values on device
        if let Ok(mask) = provider.elem_ge(a, b) {
            let indices = gpu_mask_indices(provider, &mask)?;
            let _ = provider.free(&mask);
            return Some(MaxEvaluation {
                values: Value::GpuTensor(values),
                indices: Value::GpuTensor(indices),
            });
        } else {
            // Host path for indices only
            let ta = gpu_helpers::gather_tensor(a).ok()?;
            let tb = gpu_helpers::gather_tensor(b).ok()?;
            let mut indices = Vec::with_capacity(ta.data.len());
            for i in 0..ta.data.len() {
                indices.push(if ta.data[i] >= tb.data[i] { 1.0 } else { 2.0 });
            }
            let index_tensor = Tensor::new(indices, ta.shape.clone()).ok()?;
            return Some(MaxEvaluation {
                values: Value::GpuTensor(values),
                indices: tensor::tensor_into_value(index_tensor),
            });
        }
    }
    // Broadcast-compatible path via repmat, then device compare
    let (out_shape, reps_a, reps_b) = broadcast_reps(&a.shape, &b.shape)?;
    let a_expanded = reps_a.iter().any(|&r| r != 1);
    let b_expanded = reps_b.iter().any(|&r| r != 1);
    let a_exp = if a_expanded {
        provider.repmat(a, &reps_a).ok()?
    } else {
        a.clone()
    };
    let b_exp = if b_expanded {
        provider.repmat(b, &reps_b).ok()?
    } else {
        b.clone()
    };
    let values = provider.elem_max(&a_exp, &b_exp).ok();
    let mask = provider.elem_ge(&a_exp, &b_exp).ok();
    let index_tensor = if let Some(mask) = &mask {
        let mask_host = gpu_helpers::gather_tensor(mask).ok()?;
        let _ = provider.free(mask);
        let mut indices = Vec::with_capacity(mask_host.data.len());
        for &m in &mask_host.data {
            indices.push(if m != 0.0 { 1.0 } else { 2.0 });
        }
        Tensor::new(indices, out_shape.clone()).ok()?
    } else {
        // Host indices fallback; gather while the expanded operands are still alive.
        let ta = gpu_helpers::gather_tensor(&a_exp).ok()?;
        let tb = gpu_helpers::gather_tensor(&b_exp).ok()?;
        let mut indices = Vec::with_capacity(ta.data.len());
        for i in 0..ta.data.len() {
            indices.push(if ta.data[i] >= tb.data[i] { 1.0 } else { 2.0 });
        }
        Tensor::new(indices, out_shape.clone()).ok()?
    };
    // Free only the temporaries materialised by repmat; freeing a cloned handle
    // would release the caller's buffer.
    if a_expanded {
        let _ = provider.free(&a_exp);
    }
    if b_expanded {
        let _ = provider.free(&b_exp);
    }
    let values = values?;
    if values.shape != out_shape {
        let _ = provider.free(&values);
        return None;
    }
    Some(MaxEvaluation {
        values: Value::GpuTensor(values),
        indices: tensor::tensor_into_value(index_tensor),
    })
}

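/// Broadcast output shape plus per-operand `repmat` factors: shapes `[3, 1]`
/// and `[1, 4]` yield output `[3, 4]` with reps `[1, 4]` and `[3, 1]`.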
1468fn broadcast_reps(a: &[usize], b: &[usize]) -> Option<(Vec<usize>, Vec<usize>, Vec<usize>)> {
1469    let rank = a.len().max(b.len()).max(1);
1470    let mut out = vec![1usize; rank];
1471    let mut aa = vec![1usize; rank];
1472    let mut bb = vec![1usize; rank];
1473    for i in 0..rank {
1474        aa[i] = *a.get(i).unwrap_or(&1);
1475        bb[i] = *b.get(i).unwrap_or(&1);
1476    }
1477    for i in 0..rank {
1478        let (ad, bd) = (aa[i], bb[i]);
1479        if ad == bd {
1480            out[i] = ad;
1481        } else if ad == 1 {
1482            out[i] = bd;
1483        } else if bd == 1 {
1484            out[i] = ad;
1485        } else {
1486            return None;
1487        }
1488    }
1489    let reps_a: Vec<usize> = (0..rank)
1490        .map(|i| if aa[i] == out[i] { 1 } else { out[i] })
1491        .collect();
1492    let reps_b: Vec<usize> = (0..rank)
1493        .map(|i| if bb[i] == out[i] { 1 } else { out[i] })
1494        .collect();
1495    Some((out, reps_a, reps_b))
1496}

fn elementwise_max_gpu_scalar_left(
    a: &GpuTensorHandle,
    other: &Value,
    comparison: ComparisonMethod,
) -> Option<MaxEvaluation> {
    if comparison != ComparisonMethod::Auto {
        return None;
    }
    let provider = runmat_accelerate_api::provider()?;
    let scalar = extract_scalar(other)?;
    // Prefer tensorize + elem_max for broader provider compatibility
    let values = if let Ok(fill) = provider.fill_like(a, scalar) {
        let vals = provider.elem_max(a, &fill).ok();
        let _ = provider.free(&fill);
        vals?
    } else {
        provider.scalar_max(a, scalar).ok()?
    };
    // Try device mask; if unavailable, compute on host
    let index_tensor = if let Ok(fill) = provider.fill_like(a, scalar) {
        if let Ok(mask) = provider.elem_ge(a, &fill) {
            let _ = provider.free(&fill);
            let indices = gpu_mask_indices(provider, &mask)?;
            let _ = provider.free(&mask);
            return Some(MaxEvaluation {
                values: Value::GpuTensor(values),
                indices: Value::GpuTensor(indices),
            });
        } else {
            let _ = provider.free(&fill);
            let ta = gpu_helpers::gather_tensor(a).ok()?;
            let mut indices = Vec::with_capacity(ta.data.len());
            for &v in &ta.data {
                indices.push(if v >= scalar { 1.0 } else { 2.0 });
            }
            Tensor::new(indices, ta.shape.clone()).ok()?
        }
    } else {
        let ta = gpu_helpers::gather_tensor(a).ok()?;
        let mut indices = Vec::with_capacity(ta.data.len());
        for &v in &ta.data {
            indices.push(if v >= scalar { 1.0 } else { 2.0 });
        }
        Tensor::new(indices, ta.shape.clone()).ok()?
    };
    Some(MaxEvaluation {
        values: Value::GpuTensor(values),
        indices: tensor::tensor_into_value(index_tensor),
    })
}

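// Mirrors the device-first ladder in `elementwise_max_gpu_scalar_left`:
// prefer `fill_like` + `elem_max` for values (the widest provider
// coverage), fall back to `scalar_max`, and derive origin indices from an
// `elem_ge` mask on the device, gathering to the host only when the mask
// kernels are unavailable. The operand order is flipped so origin `1`
// still refers to the first (scalar) argument.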
fn elementwise_max_gpu_scalar_right(
    other: &Value,
    b: &GpuTensorHandle,
    comparison: ComparisonMethod,
) -> Option<MaxEvaluation> {
    if comparison != ComparisonMethod::Auto {
        return None;
    }
    let provider = runmat_accelerate_api::provider()?;
    let scalar = extract_scalar(other)?;
    let values = if let Ok(fill) = provider.fill_like(b, scalar) {
        let vals = provider.elem_max(&fill, b).ok();
        let _ = provider.free(&fill);
        vals?
    } else {
        provider.scalar_max(b, scalar).ok()?
    };
    // Try device mask; if unavailable, compute on host (origin 1 if scalar >= b)
    let index_tensor = if let Ok(fill) = provider.fill_like(b, scalar) {
        if let Ok(mask) = provider.elem_ge(&fill, b) {
            let _ = provider.free(&fill);
            let indices = gpu_mask_indices(provider, &mask)?;
            let _ = provider.free(&mask);
            return Some(MaxEvaluation {
                values: Value::GpuTensor(values),
                indices: Value::GpuTensor(indices),
            });
        } else {
            let _ = provider.free(&fill);
            let tb = gpu_helpers::gather_tensor(b).ok()?;
            let mut indices = Vec::with_capacity(tb.data.len());
            for &v in &tb.data {
                indices.push(if scalar >= v { 1.0 } else { 2.0 });
            }
            Tensor::new(indices, tb.shape.clone()).ok()?
        }
    } else {
        let tb = gpu_helpers::gather_tensor(b).ok()?;
        let mut indices = Vec::with_capacity(tb.data.len());
        for &v in &tb.data {
            indices.push(if scalar >= v { 1.0 } else { 2.0 });
        }
        Tensor::new(indices, tb.shape.clone()).ok()?
    };
    Some(MaxEvaluation {
        values: Value::GpuTensor(values),
        indices: tensor::tensor_into_value(index_tensor),
    })
}

fn extract_scalar(v: &Value) -> Option<f64> {
    match v {
        Value::Num(n) => Some(*n),
        Value::Int(i) => Some(i.to_f64()),
        Value::Bool(b) => Some(if *b { 1.0 } else { 0.0 }),
        Value::Tensor(t) if t.data.len() == 1 => t.data.first().copied(),
        Value::LogicalArray(l) if l.data.len() == 1 => Some(if l.data[0] != 0 { 1.0 } else { 0.0 }),
        _ => None,
    }
}
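
// Hand-checked examples of `extract_scalar`: numeric scalars, logical
// scalars, and single-element tensors all extract; anything larger (or
// non-numeric) falls through to the `_` arm:
//     extract_scalar(&Value::Num(2.5)) -> Some(2.5)
//     extract_scalar(&Value::Bool(true)) -> Some(1.0)
//     a 2x1 `Value::Tensor` -> None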

fn gpu_tensor_is_scalar(handle: &GpuTensorHandle) -> bool {
    handle.shape.iter().copied().product::<usize>().max(1) == 1
}

fn gpu_tensor_scalar_value(handle: &GpuTensorHandle) -> Option<f64> {
    let tensor = gpu_helpers::gather_tensor(handle).ok()?;
    tensor.data.first().copied()
}

fn gpu_mask_indices(
    provider: &dyn AccelProvider,
    mask: &GpuTensorHandle,
) -> Option<GpuTensorHandle> {
    let scaled = provider.scalar_mul(mask, -1.0).ok()?;
    let shifted = provider.scalar_add(&scaled, 2.0).ok()?;
    let _ = provider.free(&scaled);
    Some(shifted)
}
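
// `gpu_mask_indices` maps a 0/1 comparison mask to MATLAB origin indices
// with the affine transform `index = mask * -1 + 2`, entirely on the
// device. Worked by hand: mask 1 (left operand won) -> -1 + 2 = 1; mask 0
// (right operand won) -> 0 + 2 = 2, so a mask of [1, 0, 1] becomes [1, 2, 1].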

fn elementwise_real_or_complex(
    lhs: Value,
    rhs: Value,
    comparison: ComparisonMethod,
) -> Result<MaxEvaluation, String> {
    match (
        materialize_for_max("max", lhs)?,
        materialize_for_max("max", rhs)?,
    ) {
        (InputData::Complex(a), InputData::Complex(b)) => elementwise_complex_max(a, b, comparison),
        (InputData::Complex(a), InputData::Real(b)) => {
            let converted = promote_real_tensor_to_complex(b);
            elementwise_complex_max(a, converted, comparison)
        }
        (InputData::Real(a), InputData::Complex(b)) => {
            let converted = promote_real_tensor_to_complex(a);
            elementwise_complex_max(converted, b, comparison)
        }
        (InputData::Real(a), InputData::Real(b)) => elementwise_real_max(a, b, comparison),
    }
}

fn elementwise_real_max(
    lhs: Tensor,
    rhs: Tensor,
    comparison: ComparisonMethod,
) -> Result<MaxEvaluation, String> {
    let plan = BroadcastPlan::new(&lhs.shape, &rhs.shape).map_err(|err| format!("max: {}", err))?;
    let mut values = vec![0.0f64; plan.len()];
    let mut indices = vec![0.0f64; plan.len()];

    for (offset, index_a, index_b) in plan.iter() {
        let a = lhs.data.get(index_a).copied().unwrap_or(f64::NAN);
        let b = rhs.data.get(index_b).copied().unwrap_or(f64::NAN);
        let (value, origin) = choose_real_elementwise(a, b, comparison);
        values[offset] = value;
        indices[offset] = origin;
    }

    let value_tensor =
        Tensor::new(values, plan.output_shape().to_vec()).map_err(|e| format!("max: {e}"))?;
    let index_tensor =
        Tensor::new(indices, plan.output_shape().to_vec()).map_err(|e| format!("max: {e}"))?;

    Ok(MaxEvaluation {
        values: tensor::tensor_into_value(value_tensor),
        indices: tensor::tensor_into_value(index_tensor),
    })
}

fn elementwise_complex_max(
    lhs: ComplexTensor,
    rhs: ComplexTensor,
    comparison: ComparisonMethod,
) -> Result<MaxEvaluation, String> {
    let plan = BroadcastPlan::new(&lhs.shape, &rhs.shape).map_err(|err| format!("max: {}", err))?;
    let mut values = vec![(0.0f64, 0.0f64); plan.len()];
    let mut indices = vec![0.0f64; plan.len()];

    for (offset, index_a, index_b) in plan.iter() {
        let a = lhs
            .data
            .get(index_a)
            .copied()
            .unwrap_or((f64::NAN, f64::NAN));
        let b = rhs
            .data
            .get(index_b)
            .copied()
            .unwrap_or((f64::NAN, f64::NAN));
        let (value, origin) = choose_complex_elementwise(a, b, comparison);
        values[offset] = value;
        indices[offset] = origin;
    }

    let value_tensor = ComplexTensor::new(values, plan.output_shape().to_vec())
        .map_err(|e| format!("max: {e}"))?;
    let index_tensor =
        Tensor::new(indices, plan.output_shape().to_vec()).map_err(|e| format!("max: {e}"))?;

    Ok(MaxEvaluation {
        values: complex_tensor_into_value(value_tensor),
        indices: tensor::tensor_into_value(index_tensor),
    })
}

fn promote_real_tensor_to_complex(tensor: Tensor) -> ComplexTensor {
    let data = tensor
        .data
        .iter()
        .copied()
        .map(|re| (re, 0.0))
        .collect::<Vec<_>>();
    ComplexTensor {
        data,
        shape: tensor.shape.clone(),
        rows: tensor.rows,
        cols: tensor.cols,
    }
}

fn choose_real_elementwise(a: f64, b: f64, comparison: ComparisonMethod) -> (f64, f64) {
    match (a.is_nan(), b.is_nan()) {
        // NaN propagates; the origin index points at the NaN operand
        // (the left operand when both are NaN).
        (true, true) => (f64::NAN, 1.0),
        (true, false) => (f64::NAN, 1.0),
        (false, true) => (f64::NAN, 2.0),
        (false, false) => {
            if should_replace_real(a, b, comparison) {
                (b, 2.0)
            } else {
                (a, 1.0)
            }
        }
    }
}
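
// Hand-worked examples (mirroring the unit tests below, not an exhaustive
// contract):
//     choose_real_elementwise(3.0, 2.0, ComparisonMethod::Auto) -> (3.0, 1.0)
//     choose_real_elementwise(5.0, f64::NAN, ComparisonMethod::Auto) -> (NaN, 2.0)
// Under the magnitude ('abs') method, |-3.0| beats |1.0|, so the pair
// (1.0, -3.0) resolves to (-3.0, 2.0); see `max_elementwise_abs_comparison`.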

fn choose_complex_elementwise(
    a: (f64, f64),
    b: (f64, f64),
    comparison: ComparisonMethod,
) -> ((f64, f64), f64) {
    let a_nan = a.0.is_nan() || a.1.is_nan();
    let b_nan = b.0.is_nan() || b.1.is_nan();
    match (a_nan, b_nan) {
        // A complex value counts as NaN if either component is NaN; the
        // origin index points at the NaN operand (left when both are).
        (true, true) => ((f64::NAN, f64::NAN), 1.0),
        (true, false) => ((f64::NAN, f64::NAN), 1.0),
        (false, true) => ((f64::NAN, f64::NAN), 2.0),
        (false, false) => {
            if should_replace_complex(a, b, comparison) {
                (b, 2.0)
            } else {
                (a, 1.0)
            }
        }
    }
}
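
// Hand-worked example matching `max_complex_auto_comparison` below: under
// the default 'auto' method, (1 + 2i) and (2 + 1i) tie on magnitude (both
// sqrt(5)) and the tie breaks toward the larger phase angle, so
//     choose_complex_elementwise((1.0, 2.0), (2.0, 1.0), ComparisonMethod::Auto)
// keeps the left operand and returns ((1.0, 2.0), 1.0).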

#[cfg(test)]
mod tests {
    use super::*;
    #[cfg(any(feature = "doc_export", feature = "wgpu"))]
    use crate::builtins::common::test_support;
    #[cfg(feature = "wgpu")]
    use runmat_accelerate_api::HostTensorView;
    use runmat_builtins::{IntValue, Tensor, Value};

    fn placeholder() -> Value {
        let tensor = Tensor::new(Vec::<f64>::new(), vec![0, 0]).unwrap();
        Value::Tensor(tensor)
    }

    #[test]
    fn max_scalar_returns_input() {
        let result = max_builtin(Value::Num(5.0), Vec::new()).expect("max");
        assert_eq!(result, Value::Num(5.0));
    }

    #[test]
    fn max_vector_with_indices() {
        let tensor = Tensor::new(vec![3.0, 1.0, 5.0], vec![3, 1]).unwrap();
        let eval = evaluate(Value::Tensor(tensor), &[]).expect("evaluate");
        let (values, indices) = eval.into_pair();
        assert_eq!(values, Value::Num(5.0));
        assert_eq!(indices, Value::Num(3.0));
    }

    #[test]
    fn max_matrix_default_dimension() {
        let tensor = Tensor::new(vec![3.0, 4.0, 1.0, 2.0, 5.0, 6.0], vec![2, 3]).unwrap();
        let eval = evaluate(Value::Tensor(tensor), &[]).expect("evaluate");
        let (values, indices) = eval.into_pair();
        match values {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![1, 3]);
                assert_eq!(t.data, vec![4.0, 2.0, 6.0]);
            }
            other => panic!("expected tensor, got {other:?}"),
        }
        match indices {
            Value::Tensor(t) => {
                assert_eq!(t.data, vec![2.0, 2.0, 2.0]);
            }
            other => panic!("expected tensor, got {other:?}"),
        }
    }

    #[test]
    fn max_all_linear_index() {
        let tensor =
            Tensor::new((1..=12).map(|v| v as f64).collect::<Vec<_>>(), vec![3, 4]).unwrap();
        let args = vec![placeholder(), Value::from("all")];
        let eval = evaluate(Value::Tensor(tensor), &args).expect("evaluate");
        let (values, indices) = eval.into_pair();
        assert_eq!(values, Value::Num(12.0));
        assert_eq!(indices, Value::Num(12.0));

        let args_linear = vec![placeholder(), Value::from("linear")];
        let eval = evaluate(
            Value::Tensor(Tensor::new(vec![2.0, 3.0], vec![1, 2]).unwrap()),
            &args_linear,
        )
        .expect("evaluate");
        let (values, indices) = eval.into_pair();
        assert_eq!(values, Value::Num(3.0));
        assert_eq!(indices, Value::Num(2.0));
    }

    #[test]
    fn max_with_omitnan() {
        let tensor = Tensor::new(vec![f64::NAN, 4.0, 2.0], vec![3, 1]).unwrap();
        let args = vec![placeholder(), Value::from("omitnan")];
        let eval = evaluate(Value::Tensor(tensor), &args).expect("evaluate");
        let (values, indices) = eval.into_pair();
        assert_eq!(values, Value::Num(4.0));
        assert_eq!(indices, Value::Num(2.0));
    }

    #[test]
    fn max_omitnan_all_nan_slice() {
        let tensor = Tensor::new(vec![f64::NAN, f64::NAN], vec![2, 1]).unwrap();
        let args = vec![placeholder(), Value::from("omitnan")];
        let eval = evaluate(Value::Tensor(tensor), &args).expect("evaluate");
        let (values, indices) = eval.into_pair();
        match values {
            Value::Num(v) => assert!(v.is_nan()),
            other => panic!("expected scalar NaN, got {other:?}"),
        }
        match indices {
            Value::Num(v) => assert!(v.is_nan()),
            other => panic!("expected scalar NaN index, got {other:?}"),
        }
    }

    #[test]
    fn max_reduction_abs_comparison() {
        let tensor = Tensor::new(vec![1.0, -3.0, -2.0, 4.0], vec![2, 2]).unwrap();
        let args = vec![
            placeholder(),
            Value::from("ComparisonMethod"),
            Value::from("abs"),
        ];
        let eval = evaluate(Value::Tensor(tensor), &args).expect("evaluate");
        let (values, indices) = eval.into_pair();
        match values {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![1, 2]);
                assert_eq!(t.data, vec![-3.0, 4.0]);
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
        match indices {
            Value::Tensor(t) => {
                assert_eq!(t.data, vec![2.0, 2.0]);
            }
            other => panic!("expected tensor indices, got {other:?}"),
        }
    }

    #[test]
    fn max_reduction_complex_real_comparison() {
        let tensor = ComplexTensor::new(vec![(1.0, 2.0), (0.5, 5.0)], vec![2, 1]).expect("tensor");
        let args = vec![
            placeholder(),
            Value::from("ComparisonMethod"),
            Value::from("real"),
        ];
        let eval = evaluate(Value::ComplexTensor(tensor), &args).expect("evaluate");
        let (values, indices) = eval.into_pair();
        match values {
            Value::Complex(re, im) => {
                assert!((re - 1.0).abs() < 1e-12);
                assert!((im - 2.0).abs() < 1e-12);
            }
            other => panic!("expected complex scalar, got {other:?}"),
        }
        assert_eq!(indices, Value::Num(1.0));
    }

    #[test]
    fn max_elementwise_broadcast() {
        let lhs = Tensor::new(vec![1.0, 4.0, 7.0], vec![1, 3]).unwrap();
        let rhs = Tensor::new(vec![2.0, 3.0, 5.0], vec![3, 1]).unwrap();
        let eval = evaluate(Value::Tensor(lhs), &[Value::Tensor(rhs)]).expect("evaluate");
        let (values, indices) = eval.into_pair();
        match values {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![3, 3]);
                assert_eq!([t.data[0], t.data[3], t.data[6]], [2.0, 4.0, 7.0]);
                assert_eq!([t.data[1], t.data[4], t.data[7]], [3.0, 4.0, 7.0]);
                assert_eq!([t.data[2], t.data[5], t.data[8]], [5.0, 5.0, 7.0]);
            }
            other => panic!("expected tensor, got {other:?}"),
        }
        match indices {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![3, 3]);
                assert_eq!([t.data[0], t.data[3], t.data[6]], [2.0, 1.0, 1.0]);
                assert_eq!([t.data[1], t.data[4], t.data[7]], [2.0, 1.0, 1.0]);
                assert_eq!([t.data[2], t.data[5], t.data[8]], [2.0, 2.0, 1.0]);
            }
            other => panic!("expected tensor, got {other:?}"),
        }
    }

    #[test]
    fn max_elementwise_abs_comparison() {
        let lhs = Tensor::new(vec![-2.0, 1.0], vec![2, 1]).unwrap();
        let rhs = Tensor::new(vec![1.5, -3.0], vec![2, 1]).unwrap();
        let args = vec![
            Value::Tensor(rhs),
            Value::from("ComparisonMethod"),
            Value::from("abs"),
        ];
        let eval = evaluate(Value::Tensor(lhs), &args).expect("evaluate");
        let (values, indices) = eval.into_pair();
        match values {
            Value::Tensor(t) => {
                assert_eq!(t.data, vec![-2.0, -3.0]);
            }
            other => panic!("expected tensor, got {other:?}"),
        }
        match indices {
            Value::Tensor(t) => {
                assert_eq!(t.data, vec![1.0, 2.0]);
            }
            other => panic!("expected tensor, got {other:?}"),
        }
    }

    #[test]
    fn max_elementwise_rejects_reduction_only_keywords() {
        let lhs = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
        let rhs = Tensor::new(vec![3.0, 4.0], vec![2, 1]).unwrap();
        let err = evaluate(
            Value::Tensor(lhs),
            &[Value::Tensor(rhs), Value::from("omitnan")],
        )
        .expect_err("expected error");
        assert!(err.contains("only supported for reduction"));
    }

    #[test]
    fn max_complex_real_comparison() {
        let lhs = ComplexTensor::new(vec![(1.0, 2.0)], vec![1, 1]).unwrap();
        let rhs = ComplexTensor::new(vec![(0.5, 5.0)], vec![1, 1]).unwrap();
        let args = vec![
            Value::ComplexTensor(rhs),
            Value::from("ComparisonMethod"),
            Value::from("real"),
        ];
        let eval = evaluate(Value::ComplexTensor(lhs), &args).expect("evaluate");
        let (values, indices) = eval.into_pair();
        assert_eq!(values, Value::Complex(1.0, 2.0));
        assert_eq!(indices, Value::Num(1.0));
    }

    #[test]
    fn max_dimension_argument_parsing() {
        let tensor = Tensor::new(vec![3.0, 4.0, 1.0, 2.0], vec![2, 2]).unwrap();
        let dims = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
        let args = vec![placeholder(), Value::Tensor(dims)];
        let eval = evaluate(Value::Tensor(tensor), &args).expect("evaluate");
        let (values, indices) = eval.into_pair();
        assert_eq!(values, Value::Num(4.0));
        assert_eq!(indices, Value::Num(2.0));
    }

    #[test]
    fn max_vecdim_duplicate_entries() {
        let tensor = Tensor::new(vec![5.0, 2.0, 7.0, 1.0], vec![2, 2]).unwrap();
        let dims = Tensor::new(vec![1.0, 1.0, 2.0], vec![3, 1]).unwrap();
        let args = vec![placeholder(), Value::Tensor(dims)];
        let eval = evaluate(Value::Tensor(tensor), &args).expect("evaluate");
        let (values, indices) = eval.into_pair();
        assert_eq!(values, Value::Num(7.0));
        assert_eq!(indices, Value::Num(3.0));
    }

    #[test]
    fn max_dimension_gpu_argument_errors() {
        let tensor = Tensor::new(vec![3.0, 1.0], vec![2, 1]).unwrap();
        let dim_handle = Value::GpuTensor(runmat_accelerate_api::GpuTensorHandle {
            shape: vec![1, 1],
            device_id: 0,
            buffer_id: 42,
        });
        let err = evaluate(Value::Tensor(tensor), &[placeholder(), dim_handle])
            .expect_err("expected error");
        assert!(err.contains("dimension arguments must reside on the host"));
    }

    #[test]
    fn max_invalid_comparison_method_errors() {
        let tensor = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
        let args = vec![
            placeholder(),
            Value::from("ComparisonMethod"),
            Value::from("chebyshev"),
        ];
        let err = evaluate(Value::Tensor(tensor), &args).expect_err("expected error");
        assert!(err.contains("unsupported ComparisonMethod"));
    }

    #[test]
    #[cfg(feature = "doc_export")]
    fn max_doc_examples_present() {
        let blocks = test_support::doc_examples(super::DOC_MD);
        assert!(!blocks.is_empty());
    }

    #[test]
    #[cfg(feature = "wgpu")]
    fn max_gpu_dim1_matches_cpu() {
        let tensor = Tensor::new(vec![3.0, 1.0, 2.0, 4.0], vec![2, 2]).unwrap();
        let eval_cpu = evaluate(Value::Tensor(tensor.clone()), &[]).expect("cpu");
        let (values_cpu, indices_cpu) = eval_cpu.into_pair();

        test_support::with_test_provider(|provider| {
            let view = HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let eval_gpu = evaluate(Value::GpuTensor(handle), &[]).expect("gpu");
            let (values_gpu, indices_gpu) = eval_gpu.into_pair();
            match (&values_gpu, &indices_gpu) {
                (Value::GpuTensor(_), Value::GpuTensor(_)) => {}
                other => panic!("expected GPU tensors, got {other:?}"),
            }
            let gathered_vals = test_support::gather(values_gpu).expect("gather values");
            let gathered_idx = test_support::gather(indices_gpu).expect("gather indices");
            let expected_vals = match values_cpu {
                Value::Tensor(t) => t,
                other => panic!("expected tensor values from cpu eval, got {other:?}"),
            };
            let expected_idx = match indices_cpu {
                Value::Tensor(t) => t,
                other => panic!("expected tensor indices from cpu eval, got {other:?}"),
            };
            assert_eq!(gathered_vals.shape, expected_vals.shape);
            assert_eq!(gathered_vals.data, expected_vals.data);
            assert_eq!(gathered_idx.shape, expected_idx.shape);
            assert_eq!(gathered_idx.data, expected_idx.data);
        });
    }

    #[test]
    fn max_dimension_numeric_argument() {
        let tensor = Tensor::new(vec![3.0, 4.0, 1.0, 2.0], vec![2, 2]).unwrap();
        let args = vec![placeholder(), Value::Num(2.0)];
        let eval = evaluate(Value::Tensor(tensor), &args).expect("evaluate");
        let (values, indices) = eval.into_pair();
        match values {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![2, 1]);
                assert_eq!(t.data, vec![3.0, 4.0]);
            }
            other => panic!("expected tensor, got {other:?}"),
        }
        match indices {
            Value::Tensor(t) => {
                assert_eq!(t.data, vec![1.0, 1.0]);
            }
            other => panic!("expected tensor, got {other:?}"),
        }
    }

    #[test]
    fn max_complex_auto_comparison() {
        let lhs = ComplexTensor::new(vec![(1.0, 2.0)], vec![1, 1]).unwrap();
        let rhs = ComplexTensor::new(vec![(2.0, 1.0)], vec![1, 1]).unwrap();
        let eval =
            evaluate(Value::ComplexTensor(lhs), &[Value::ComplexTensor(rhs)]).expect("evaluate");
        let (values, indices) = eval.into_pair();
        assert_eq!(values, Value::Complex(1.0, 2.0));
        assert_eq!(indices, Value::Num(1.0));
    }

    #[test]
    fn max_scalar_pair_arguments() {
        let args = vec![Value::Num(2.0)];
        let result = max_builtin(Value::Num(3.0), args).expect("max");
        assert_eq!(result, Value::Num(3.0));
    }

    #[test]
    fn max_rejects_invalid_dimension() {
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap();
        let args = vec![placeholder(), Value::Int(IntValue::I32(0))];
        let err = evaluate(Value::Tensor(tensor), &args).expect_err("expected error");
        assert!(err.contains("dimension must be >= 1"));
    }
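
    // A minimal sketch test exercising the private `broadcast_reps` helper
    // directly. The expected tuples were worked out by hand from the
    // MATLAB-style implicit-expansion rules implemented above; they are
    // illustrative, not a provider contract.
    #[test]
    fn broadcast_reps_expansion_and_mismatch() {
        let (out, reps_a, reps_b) = broadcast_reps(&[2, 1], &[1, 3]).expect("compatible shapes");
        assert_eq!(out, vec![2, 3]);
        assert_eq!(reps_a, vec![1, 3]);
        assert_eq!(reps_b, vec![2, 1]);
        // Conflicting non-singleton extents cannot broadcast.
        assert!(broadcast_reps(&[2, 2], &[3, 1]).is_none());
    }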
}