Skip to main content

runmat_runtime/builtins/math/reduction/
max.rs

1//! MATLAB-compatible `max` builtin with GPU-aware semantics for RunMat.
2
3use std::cmp::Ordering;
4use std::collections::BTreeSet;
5
6use runmat_accelerate_api::{AccelProvider, GpuTensorHandle, ReduceDimResult};
7use runmat_builtins::{
8    BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
9    BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor,
10    ComplexTensor, ResolveContext, Tensor, Type, Value,
11};
12use runmat_macros::runtime_builtin;
13
14use crate::{build_runtime_error, BuiltinResult, RuntimeError};
15
16const NAME: &str = "max";
17
18fn max_type(args: &[Type], ctx: &ResolveContext) -> Type {
19    min_max_type(args, ctx)
20}
21
22const MAX_OUTPUT_M: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
23    name: "M",
24    ty: BuiltinParamType::NumericArray,
25    arity: BuiltinParamArity::Required,
26    default: None,
27    description: "Maximum values.",
28}];
29
30const MAX_OUTPUT_MI: [BuiltinParamDescriptor; 2] = [
31    BuiltinParamDescriptor {
32        name: "M",
33        ty: BuiltinParamType::NumericArray,
34        arity: BuiltinParamArity::Required,
35        default: None,
36        description: "Maximum values.",
37    },
38    BuiltinParamDescriptor {
39        name: "I",
40        ty: BuiltinParamType::NumericArray,
41        arity: BuiltinParamArity::Required,
42        default: None,
43        description: "One-based maximum indices/origins.",
44    },
45];
46
47const MAX_PARAM_A: BuiltinParamDescriptor = BuiltinParamDescriptor {
48    name: "A",
49    ty: BuiltinParamType::Any,
50    arity: BuiltinParamArity::Required,
51    default: None,
52    description: "Input scalar or array.",
53};
54
55const MAX_PARAM_B: BuiltinParamDescriptor = BuiltinParamDescriptor {
56    name: "B",
57    ty: BuiltinParamType::Any,
58    arity: BuiltinParamArity::Required,
59    default: None,
60    description: "Second operand for element-wise maximum.",
61};
62
63const MAX_PARAM_EMPTY: BuiltinParamDescriptor = BuiltinParamDescriptor {
64    name: "placeholder",
65    ty: BuiltinParamType::Any,
66    arity: BuiltinParamArity::Optional,
67    default: Some("[]"),
68    description: "Empty placeholder selecting reduction-argument grammar.",
69};
70
71const MAX_PARAM_DIM: BuiltinParamDescriptor = BuiltinParamDescriptor {
72    name: "dim",
73    ty: BuiltinParamType::Any,
74    arity: BuiltinParamArity::Optional,
75    default: Some("[]"),
76    description: "Reduction dimension selector (scalar or dimension vector).",
77};
78
79const MAX_PARAM_REDUCTION_FLAG: BuiltinParamDescriptor = BuiltinParamDescriptor {
80    name: "flag",
81    ty: BuiltinParamType::StringScalar,
82    arity: BuiltinParamArity::Optional,
83    default: Some("\"all\""),
84    description: "Reduction mode flag: \"all\" or \"linear\".",
85};
86
87const MAX_PARAM_NANFLAG: BuiltinParamDescriptor = BuiltinParamDescriptor {
88    name: "nanflag",
89    ty: BuiltinParamType::StringScalar,
90    arity: BuiltinParamArity::Optional,
91    default: Some("\"includenan\""),
92    description: "Missing-value mode: \"includenan\" or \"omitnan\".",
93};
94
95const MAX_PARAM_COMPARISON_NAME: BuiltinParamDescriptor = BuiltinParamDescriptor {
96    name: "optionName",
97    ty: BuiltinParamType::StringScalar,
98    arity: BuiltinParamArity::Optional,
99    default: Some("\"ComparisonMethod\""),
100    description: "Option name (currently \"ComparisonMethod\").",
101};
102
103const MAX_PARAM_COMPARISON_VALUE: BuiltinParamDescriptor = BuiltinParamDescriptor {
104    name: "method",
105    ty: BuiltinParamType::StringScalar,
106    arity: BuiltinParamArity::Optional,
107    default: Some("\"auto\""),
108    description: "Comparison method: \"auto\", \"abs\"/\"magnitude\", or \"real\".",
109};
110
111const MAX_PARAM_OPTION_NAME: BuiltinParamDescriptor = BuiltinParamDescriptor {
112    name: "optionName",
113    ty: BuiltinParamType::StringScalar,
114    arity: BuiltinParamArity::Variadic,
115    default: None,
116    description: "Name-value option name.",
117};
118
119const MAX_PARAM_OPTION_VALUE: BuiltinParamDescriptor = BuiltinParamDescriptor {
120    name: "optionValue",
121    ty: BuiltinParamType::Any,
122    arity: BuiltinParamArity::Variadic,
123    default: None,
124    description: "Name-value option value.",
125};
126
127const MAX_INPUTS_A: [BuiltinParamDescriptor; 1] = [MAX_PARAM_A];
128const MAX_INPUTS_A_B: [BuiltinParamDescriptor; 2] = [MAX_PARAM_A, MAX_PARAM_B];
129const MAX_INPUTS_A_EMPTY_DIM: [BuiltinParamDescriptor; 3] =
130    [MAX_PARAM_A, MAX_PARAM_EMPTY, MAX_PARAM_DIM];
131const MAX_INPUTS_A_EMPTY_FLAG: [BuiltinParamDescriptor; 3] =
132    [MAX_PARAM_A, MAX_PARAM_EMPTY, MAX_PARAM_REDUCTION_FLAG];
133const MAX_INPUTS_A_EMPTY_NANFLAG: [BuiltinParamDescriptor; 3] =
134    [MAX_PARAM_A, MAX_PARAM_EMPTY, MAX_PARAM_NANFLAG];
135const MAX_INPUTS_A_EMPTY_COMPARISON: [BuiltinParamDescriptor; 4] = [
136    MAX_PARAM_A,
137    MAX_PARAM_EMPTY,
138    MAX_PARAM_COMPARISON_NAME,
139    MAX_PARAM_COMPARISON_VALUE,
140];
141const MAX_INPUTS_A_B_COMPARISON: [BuiltinParamDescriptor; 4] = [
142    MAX_PARAM_A,
143    MAX_PARAM_B,
144    MAX_PARAM_COMPARISON_NAME,
145    MAX_PARAM_COMPARISON_VALUE,
146];
147const MAX_INPUTS_A_EMPTY_OPTIONS: [BuiltinParamDescriptor; 4] = [
148    MAX_PARAM_A,
149    MAX_PARAM_EMPTY,
150    MAX_PARAM_OPTION_NAME,
151    MAX_PARAM_OPTION_VALUE,
152];
153const MAX_INPUTS_A_B_OPTIONS: [BuiltinParamDescriptor; 4] = [
154    MAX_PARAM_A,
155    MAX_PARAM_B,
156    MAX_PARAM_OPTION_NAME,
157    MAX_PARAM_OPTION_VALUE,
158];
159
160const MAX_SIGNATURES: [BuiltinSignatureDescriptor; 22] = [
161    BuiltinSignatureDescriptor {
162        label: "M = max(A)",
163        inputs: &MAX_INPUTS_A,
164        outputs: &MAX_OUTPUT_M,
165    },
166    BuiltinSignatureDescriptor {
167        label: "[M, I] = max(A)",
168        inputs: &MAX_INPUTS_A,
169        outputs: &MAX_OUTPUT_MI,
170    },
171    BuiltinSignatureDescriptor {
172        label: "M = max(A, B)",
173        inputs: &MAX_INPUTS_A_B,
174        outputs: &MAX_OUTPUT_M,
175    },
176    BuiltinSignatureDescriptor {
177        label: "[M, I] = max(A, B)",
178        inputs: &MAX_INPUTS_A_B,
179        outputs: &MAX_OUTPUT_MI,
180    },
181    BuiltinSignatureDescriptor {
182        label: "M = max(A, [], dim)",
183        inputs: &MAX_INPUTS_A_EMPTY_DIM,
184        outputs: &MAX_OUTPUT_M,
185    },
186    BuiltinSignatureDescriptor {
187        label: "[M, I] = max(A, [], dim)",
188        inputs: &MAX_INPUTS_A_EMPTY_DIM,
189        outputs: &MAX_OUTPUT_MI,
190    },
191    BuiltinSignatureDescriptor {
192        label: "M = max(A, [], vecdim)",
193        inputs: &MAX_INPUTS_A_EMPTY_DIM,
194        outputs: &MAX_OUTPUT_M,
195    },
196    BuiltinSignatureDescriptor {
197        label: "[M, I] = max(A, [], vecdim)",
198        inputs: &MAX_INPUTS_A_EMPTY_DIM,
199        outputs: &MAX_OUTPUT_MI,
200    },
201    BuiltinSignatureDescriptor {
202        label: "M = max(A, [], \"all\")",
203        inputs: &MAX_INPUTS_A_EMPTY_FLAG,
204        outputs: &MAX_OUTPUT_M,
205    },
206    BuiltinSignatureDescriptor {
207        label: "[M, I] = max(A, [], \"all\")",
208        inputs: &MAX_INPUTS_A_EMPTY_FLAG,
209        outputs: &MAX_OUTPUT_MI,
210    },
211    BuiltinSignatureDescriptor {
212        label: "M = max(A, [], \"linear\")",
213        inputs: &MAX_INPUTS_A_EMPTY_FLAG,
214        outputs: &MAX_OUTPUT_M,
215    },
216    BuiltinSignatureDescriptor {
217        label: "[M, I] = max(A, [], \"linear\")",
218        inputs: &MAX_INPUTS_A_EMPTY_FLAG,
219        outputs: &MAX_OUTPUT_MI,
220    },
221    BuiltinSignatureDescriptor {
222        label: "M = max(A, [], nanflag)",
223        inputs: &MAX_INPUTS_A_EMPTY_NANFLAG,
224        outputs: &MAX_OUTPUT_M,
225    },
226    BuiltinSignatureDescriptor {
227        label: "[M, I] = max(A, [], nanflag)",
228        inputs: &MAX_INPUTS_A_EMPTY_NANFLAG,
229        outputs: &MAX_OUTPUT_MI,
230    },
231    BuiltinSignatureDescriptor {
232        label: "M = max(A, [], \"ComparisonMethod\", method)",
233        inputs: &MAX_INPUTS_A_EMPTY_COMPARISON,
234        outputs: &MAX_OUTPUT_M,
235    },
236    BuiltinSignatureDescriptor {
237        label: "[M, I] = max(A, [], \"ComparisonMethod\", method)",
238        inputs: &MAX_INPUTS_A_EMPTY_COMPARISON,
239        outputs: &MAX_OUTPUT_MI,
240    },
241    BuiltinSignatureDescriptor {
242        label: "M = max(A, B, \"ComparisonMethod\", method)",
243        inputs: &MAX_INPUTS_A_B_COMPARISON,
244        outputs: &MAX_OUTPUT_M,
245    },
246    BuiltinSignatureDescriptor {
247        label: "[M, I] = max(A, B, \"ComparisonMethod\", method)",
248        inputs: &MAX_INPUTS_A_B_COMPARISON,
249        outputs: &MAX_OUTPUT_MI,
250    },
251    BuiltinSignatureDescriptor {
252        label: "M = max(A, [], optionName, optionValue, ...)",
253        inputs: &MAX_INPUTS_A_EMPTY_OPTIONS,
254        outputs: &MAX_OUTPUT_M,
255    },
256    BuiltinSignatureDescriptor {
257        label: "[M, I] = max(A, [], optionName, optionValue, ...)",
258        inputs: &MAX_INPUTS_A_EMPTY_OPTIONS,
259        outputs: &MAX_OUTPUT_MI,
260    },
261    BuiltinSignatureDescriptor {
262        label: "M = max(A, B, optionName, optionValue, ...)",
263        inputs: &MAX_INPUTS_A_B_OPTIONS,
264        outputs: &MAX_OUTPUT_M,
265    },
266    BuiltinSignatureDescriptor {
267        label: "[M, I] = max(A, B, optionName, optionValue, ...)",
268        inputs: &MAX_INPUTS_A_B_OPTIONS,
269        outputs: &MAX_OUTPUT_MI,
270    },
271];
272
273const MAX_ERROR_INVALID_ARGUMENT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
274    code: "RM.MAX.INVALID_ARGUMENT",
275    identifier: Some("RunMat:max:InvalidArgument"),
276    when: "Argument grammar, dimensions, or option names/values are invalid.",
277    message: "max: invalid argument",
278};
279
280const MAX_ERROR_INVALID_INPUT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
281    code: "RM.MAX.INVALID_INPUT",
282    identifier: Some("RunMat:max:InvalidInput"),
283    when: "Input values cannot be converted to supported max domains.",
284    message: "max: invalid input",
285};
286
287const MAX_ERROR_SIZE_MISMATCH: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
288    code: "RM.MAX.SIZE_MISMATCH",
289    identifier: Some("RunMat:max:SizeMismatch"),
290    when: "Element-wise operands are not broadcast-compatible.",
291    message: "max: size mismatch",
292};
293
294const MAX_ERROR_INTERNAL: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
295    code: "RM.MAX.INTERNAL",
296    identifier: Some("RunMat:max:Internal"),
297    when: "Execution fails due to gather, provider, allocation, or conversion internals.",
298    message: "max: internal failure",
299};
300
301const MAX_ERRORS: [BuiltinErrorDescriptor; 4] = [
302    MAX_ERROR_INVALID_ARGUMENT,
303    MAX_ERROR_INVALID_INPUT,
304    MAX_ERROR_SIZE_MISMATCH,
305    MAX_ERROR_INTERNAL,
306];
307
308pub const MAX_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
309    signatures: &MAX_SIGNATURES,
310    output_mode: BuiltinOutputMode::ByRequestedOutputCount,
311    completion_policy: BuiltinCompletionPolicy::Public,
312    errors: &MAX_ERRORS,
313};
314
315fn max_descriptor_error_with_message(
316    message: impl Into<String>,
317    error: &'static BuiltinErrorDescriptor,
318) -> RuntimeError {
319    let mut builder = build_runtime_error(message).with_builtin(NAME);
320    if let Some(identifier) = error.identifier {
321        builder = builder.with_identifier(identifier);
322    }
323    builder.build()
324}
325
326fn max_descriptor_error_with_detail(
327    error: &'static BuiltinErrorDescriptor,
328    detail: impl AsRef<str>,
329) -> RuntimeError {
330    max_descriptor_error_with_message(format!("{}: {}", error.message, detail.as_ref()), error)
331}
332
333fn max_invalid_argument(detail: impl AsRef<str>) -> RuntimeError {
334    max_descriptor_error_with_detail(&MAX_ERROR_INVALID_ARGUMENT, detail)
335}
336
337fn max_invalid_input(detail: impl AsRef<str>) -> RuntimeError {
338    max_descriptor_error_with_detail(&MAX_ERROR_INVALID_INPUT, detail)
339}
340
341fn max_size_mismatch(detail: impl AsRef<str>) -> RuntimeError {
342    max_descriptor_error_with_detail(&MAX_ERROR_SIZE_MISMATCH, detail)
343}
344
345fn max_internal_error(detail: impl AsRef<str>) -> RuntimeError {
346    max_descriptor_error_with_detail(&MAX_ERROR_INTERNAL, detail)
347}
348
349use crate::builtins::common::arg_tokens::tokens_from_values;
350use crate::builtins::common::broadcast::BroadcastPlan;
351use crate::builtins::common::random_args::{complex_tensor_into_value, keyword_of};
352use crate::builtins::common::spec::{
353    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, FusionError,
354    FusionExprContext, FusionKernelTemplate, GpuOpKind, ProviderHook, ReductionNaN,
355    ResidencyPolicy, ScalarType, ShapeRequirements,
356};
357use crate::builtins::common::{
358    gpu_helpers,
359    shape::{is_scalar_shape, normalize_scalar_shape},
360    tensor,
361};
362use crate::builtins::math::reduction::type_resolvers::min_max_type;
363
364#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::math::reduction::max")]
365pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
366    name: "max",
367    op_kind: GpuOpKind::Reduction,
368    supported_precisions: &[ScalarType::F32, ScalarType::F64],
369    broadcast: BroadcastSemantics::Matlab,
370    provider_hooks: &[
371        ProviderHook::Reduction {
372            name: "reduce_max_dim",
373        },
374        ProviderHook::Reduction {
375            name: "reduce_max",
376        },
377    ],
378    constant_strategy: ConstantStrategy::InlineLiteral,
379    residency: ResidencyPolicy::NewHandle,
380    nan_mode: ReductionNaN::Include,
381    two_pass_threshold: Some(256),
382    workgroup_size: Some(256),
383    accepts_nan_mode: false,
384    notes:
385        "Providers should implement reduce_max_dim / reduce_max. Requests that require omitnan, comparisonmethod overrides, or complex inputs fall back to the host implementation.",
386};
387
388#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::math::reduction::max")]
389pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
390    name: "max",
391    shape: ShapeRequirements::BroadcastCompatible,
392    constant_strategy: ConstantStrategy::InlineLiteral,
393    elementwise: None,
394    reduction: Some(FusionKernelTemplate {
395        scalar_precisions: &[ScalarType::F32, ScalarType::F64],
396        wgsl_body: |ctx: &FusionExprContext| {
397            let input = ctx.inputs.first().ok_or(FusionError::MissingInput(0))?;
398            Ok(format!("accumulator = max(accumulator, {input});"))
399        },
400    }),
401    emits_nan: true,
402    notes: "Fusion planner emits canonical reduction kernels; providers may substitute custom WGSL via reduce_max_dim hooks.",
403};
404
405/// Evaluation artifact returned by `max` that carries both values and indices.
406#[derive(Debug, Clone)]
407pub struct MaxEvaluation {
408    values: Value,
409    indices: Value,
410}
411
412impl MaxEvaluation {
413    /// Consume the evaluation and return only the maximum values (single-output call).
414    pub fn into_value(self) -> Value {
415        self.values
416    }
417
418    /// Consume the evaluation and return both maxima and indices.
419    pub fn into_pair(self) -> (Value, Value) {
420        (self.values, self.indices)
421    }
422
423    /// Peek at the indices without consuming.
424    pub fn indices_value(&self) -> Value {
425        self.indices.clone()
426    }
427}
428
429#[runtime_builtin(
430    name = "max",
431    category = "math/reduction",
432    summary = "Return maximum elements along dimensions or pairwise comparisons.",
433    keywords = "max,maximum,reduction,gpu,comparisonmethod,omitnan",
434    accel = "reduction",
435    type_resolver(max_type),
436    descriptor(crate::builtins::math::reduction::max::MAX_DESCRIPTOR),
437    builtin_path = "crate::builtins::math::reduction::max"
438)]
439async fn max_builtin(value: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
440    let eval = evaluate(value, &rest).await?;
441    if let Some(out_count) = crate::output_count::current_output_count() {
442        if out_count == 0 {
443            return Ok(Value::OutputList(Vec::new()));
444        }
445        if out_count == 1 {
446            return Ok(Value::OutputList(vec![eval.into_value()]));
447        }
448        let (values, indices) = eval.into_pair();
449        return Ok(crate::output_count::output_list_with_padding(
450            out_count,
451            vec![values, indices],
452        ));
453    }
454    Ok(eval.into_value())
455}
456
457/// Evaluate the builtin once and expose both outputs (value + indices).
458pub async fn evaluate(value: Value, rest: &[Value]) -> BuiltinResult<MaxEvaluation> {
459    let parsed = parse_call(rest).await?;
460    if std::env::var("RUNMAT_DEBUG_MAX").is_ok() {
461        let call_label = match &parsed {
462            ParsedCall::Reduction(_) => "reduction",
463            ParsedCall::Elementwise(_) => "elementwise",
464        };
465        let first_arg = rest.first().map(debug_value_kind).unwrap_or("None");
466        tracing::debug!(
467            call_type = call_label,
468            rest_len = rest.len(),
469            first_arg = first_arg,
470            "[runmat-debug-max]"
471        );
472    }
473    match parsed {
474        ParsedCall::Elementwise(args) => elementwise_max(value, args).await,
475        ParsedCall::Reduction(args) => reduction_max(value, args).await,
476    }
477}
478
479#[derive(Debug, Clone)]
480enum ParsedCall {
481    Reduction(ReductionArgs),
482    Elementwise(ElementwiseArgs),
483}
484
485#[derive(Debug, Clone)]
486struct ReductionArgs {
487    selection: DimSelection,
488    nan_mode: ReductionNaN,
489    comparison: ComparisonMethod,
490    linear_index: bool,
491}
492
493impl Default for ReductionArgs {
494    fn default() -> Self {
495        Self {
496            selection: DimSelection::Auto,
497            nan_mode: ReductionNaN::Include,
498            comparison: ComparisonMethod::Auto,
499            linear_index: false,
500        }
501    }
502}
503
504#[derive(Debug, Clone)]
505enum DimSelection {
506    Auto,
507    Dim(usize),
508    Vec(Vec<usize>),
509    All,
510}
511
512#[derive(Debug, Clone, Copy, PartialEq, Eq)]
513enum ComparisonMethod {
514    Auto,
515    Real,
516    Abs,
517}
518
519#[derive(Debug, Clone)]
520struct ElementwiseArgs {
521    other: Value,
522    comparison: ComparisonMethod,
523}
524
525async fn parse_call(rest: &[Value]) -> BuiltinResult<ParsedCall> {
526    if rest.is_empty() {
527        return Ok(ParsedCall::Reduction(ReductionArgs::default()));
528    }
529
530    let first = &rest[0];
531    if !is_empty_placeholder(first) {
532        let comparison = parse_elementwise_options(&rest[1..])?;
533        return Ok(ParsedCall::Elementwise(ElementwiseArgs {
534            other: first.clone(),
535            comparison,
536        }));
537    }
538
539    let mut args = ReductionArgs::default();
540    parse_reduction_options(&mut args, &rest[1..]).await?;
541    Ok(ParsedCall::Reduction(args))
542}
543
544fn debug_value_kind(value: &Value) -> &'static str {
545    match value {
546        Value::Num(_) => "Num",
547        Value::Int(_) => "Int",
548        Value::Bool(_) => "Bool",
549        Value::Tensor(t) => {
550            if t.data.is_empty() {
551                "Tensor(empty)"
552            } else {
553                "Tensor"
554            }
555        }
556        Value::GpuTensor(_) => "GpuTensor",
557        Value::String(_) => "String",
558        Value::CharArray(_) => "CharArray",
559        Value::StringArray(sa) => {
560            if sa.data.is_empty() {
561                "StringArray(empty)"
562            } else {
563                "StringArray"
564            }
565        }
566        Value::LogicalArray(l) => {
567            if l.data.is_empty() {
568                "LogicalArray(empty)"
569            } else {
570                "LogicalArray"
571            }
572        }
573        Value::Cell(c) => {
574            if c.data.is_empty() {
575                "Cell(empty)"
576            } else {
577                "Cell"
578            }
579        }
580        _ => "Other",
581    }
582}
583
584fn is_empty_placeholder(value: &Value) -> bool {
585    match value {
586        Value::Tensor(t) => t.data.is_empty(),
587        Value::LogicalArray(l) => l.data.is_empty(),
588        Value::StringArray(sa) => sa.data.is_empty(),
589        Value::CharArray(ca) => ca.data.is_empty(),
590        Value::Cell(cell) => cell.data.is_empty(),
591        Value::String(s) => s.is_empty(),
592        _ => false,
593    }
594}
595
596async fn parse_reduction_options(args: &mut ReductionArgs, rest: &[Value]) -> BuiltinResult<()> {
597    let mut idx = 0usize;
598    let mut selection_set = !matches!(args.selection, DimSelection::Auto);
599    let mut comparison_set = matches!(args.comparison, ComparisonMethod::Auto);
600    let tokens = tokens_from_values(rest);
601    while idx < rest.len() {
602        if let Some(crate::builtins::common::arg_tokens::ArgToken::String(text)) = tokens.get(idx) {
603            match text.as_str() {
604                "omitnan" => {
605                    args.nan_mode = ReductionNaN::Omit;
606                    idx += 1;
607                    continue;
608                }
609                "includenan" => {
610                    args.nan_mode = ReductionNaN::Include;
611                    idx += 1;
612                    continue;
613                }
614                "all" => {
615                    if selection_set {
616                        return Err(max_invalid_argument(
617                            "max: 'all' cannot be combined with an explicit dimension",
618                        ));
619                    }
620                    args.selection = DimSelection::All;
621                    selection_set = true;
622                    idx += 1;
623                    continue;
624                }
625                _ => {}
626            }
627        }
628        if let Some(keyword) = keyword_of(&rest[idx]) {
629            match keyword.as_str() {
630                "omitnan" => {
631                    args.nan_mode = ReductionNaN::Omit;
632                    idx += 1;
633                    continue;
634                }
635                "includenan" => {
636                    args.nan_mode = ReductionNaN::Include;
637                    idx += 1;
638                    continue;
639                }
640                "all" => {
641                    if selection_set {
642                        return Err(max_invalid_argument(
643                            "max: 'all' cannot be combined with an explicit dimension",
644                        ));
645                    }
646                    args.selection = DimSelection::All;
647                    selection_set = true;
648                    idx += 1;
649                    continue;
650                }
651                "linear" => {
652                    if selection_set {
653                        return Err(max_invalid_argument(
654                            "max: 'linear' cannot be combined with an explicit dimension",
655                        ));
656                    }
657                    args.selection = DimSelection::All;
658                    args.linear_index = true;
659                    selection_set = true;
660                    idx += 1;
661                    continue;
662                }
663                "comparisonmethod" => {
664                    let Some(value) = rest.get(idx + 1) else {
665                        return Err(max_invalid_argument(
666                            "max: expected a value after 'ComparisonMethod'",
667                        ));
668                    };
669                    args.comparison = parse_comparison_method(value)?;
670                    comparison_set = true;
671                    idx += 2;
672                    continue;
673                }
674                _ => {}
675            }
676        }
677
678        if !selection_set {
679            if let Some(selection) = parse_dimension_value(&rest[idx]).await? {
680                args.selection = selection;
681                selection_set = true;
682                idx += 1;
683                continue;
684            }
685        }
686
687        return Err(max_invalid_argument(format!(
688            "max: unrecognised argument {:?}",
689            rest[idx]
690        )));
691    }
692
693    if !comparison_set {
694        args.comparison = ComparisonMethod::Auto;
695    }
696
697    Ok(())
698}
699
700fn parse_elementwise_options(rest: &[Value]) -> BuiltinResult<ComparisonMethod> {
701    let mut comparison = ComparisonMethod::Auto;
702    let mut comparison_set = false;
703    let mut idx = 0usize;
704    while idx < rest.len() {
705        if let Some(keyword) = keyword_of(&rest[idx]) {
706            match keyword.as_str() {
707                "comparisonmethod" => {
708                    let Some(value) = rest.get(idx + 1) else {
709                        return Err(max_invalid_argument(
710                            "max: expected a value after 'ComparisonMethod'",
711                        ));
712                    };
713                    comparison = parse_comparison_method(value)?;
714                    comparison_set = true;
715                    idx += 2;
716                    continue;
717                }
718                "omitnan" | "includenan" | "all" | "linear" => {
719                    return Err(max_invalid_argument(format!(
720                        "max: '{}' is only supported for reduction calls",
721                        keyword
722                    )));
723                }
724                _ => {}
725            }
726        }
727        return Err(max_invalid_argument(format!(
728            "max: unrecognised argument {:?}",
729            rest[idx]
730        )));
731    }
732    if !comparison_set {
733        comparison = ComparisonMethod::Auto;
734    }
735    Ok(comparison)
736}
737
738fn parse_comparison_method(value: &Value) -> BuiltinResult<ComparisonMethod> {
739    let Some(keyword) = keyword_of(value) else {
740        return Err(max_invalid_argument(
741            "max: 'ComparisonMethod' expects a string value",
742        ));
743    };
744    match keyword.as_str() {
745        "auto" => Ok(ComparisonMethod::Auto),
746        "abs" | "magnitude" => Ok(ComparisonMethod::Abs),
747        "real" => Ok(ComparisonMethod::Real),
748        other => Err(max_invalid_argument(format!(
749            "max: unsupported ComparisonMethod '{other}'"
750        ))),
751    }
752}
753
754async fn parse_dimension_value(value: &Value) -> BuiltinResult<Option<DimSelection>> {
755    match value {
756        Value::Int(_) | Value::Num(_) => tensor::dimension_from_value_async(value, "max", false)
757            .await
758            .map_err(map_scalar_dim_error)
759            .map(|dim| dim.map(DimSelection::Dim)),
760        Value::Tensor(t) => parse_dimension_tensor(value, &t.shape).await,
761        Value::LogicalArray(logical) => parse_dimension_tensor(value, &logical.shape).await,
762        Value::GpuTensor(_) => Err(max_invalid_argument(
763            "max: dimension arguments must reside on the host",
764        )),
765        _ => Ok(None),
766    }
767}
768
769async fn parse_dimension_tensor(
770    value: &Value,
771    shape: &[usize],
772) -> BuiltinResult<Option<DimSelection>> {
773    if tensor::element_count(shape) == 0 {
774        return Ok(Some(DimSelection::Auto));
775    }
776    let is_vector = shape.len() == 1
777        || shape.get(0).copied().unwrap_or(1) == 1
778        || shape.get(1).copied().unwrap_or(1) == 1;
779    if !is_vector {
780        return Err(max_invalid_argument(
781            "max: dimension vector must be a row or column vector",
782        ));
783    }
784    let dims = tensor::dims_from_value_async(value)
785        .await
786        .map_err(map_vector_dim_error)?;
787    let Some(dims) = dims else {
788        return Ok(None);
789    };
790    if dims.is_empty() {
791        return Ok(Some(DimSelection::Auto));
792    }
793    let mut seen = BTreeSet::new();
794    let mut uniq = Vec::with_capacity(dims.len());
795    for dim in dims {
796        if dim < 1 {
797            return Err(max_invalid_argument("max: dimension indices must be >= 1"));
798        }
799        if seen.insert(dim) {
800            uniq.push(dim);
801        }
802    }
803    Ok(Some(DimSelection::Vec(uniq)))
804}
805
806fn map_scalar_dim_error(message: String) -> RuntimeError {
807    if message.contains("integer") {
808        return max_invalid_argument("max: dimension must be integral");
809    }
810    max_invalid_argument(message)
811}
812
813fn map_vector_dim_error(message: String) -> RuntimeError {
814    if message.contains("non-negative") {
815        return max_invalid_argument("max: dimension indices must be >= 1");
816    }
817    if message.contains("finite") {
818        return max_invalid_argument("max: dimension entries must be finite");
819    }
820    if message.contains("integer") {
821        return max_invalid_argument("max: dimension entries must be integers");
822    }
823    max_invalid_argument(message)
824}
825
826async fn reduction_max(value: Value, args: ReductionArgs) -> BuiltinResult<MaxEvaluation> {
827    match value {
828        Value::GpuTensor(handle) => {
829            if let Some(eval) = reduction_max_gpu(handle.clone(), &args).await? {
830                return Ok(eval);
831            }
832            // Fall back to host if GPU path is unavailable.
833            let tensor = gpu_helpers::gather_tensor_async(&handle)
834                .await
835                .map_err(|e| max_internal_error(format!("max: {e}")))?;
836            reduction_max_host(Value::Tensor(tensor), &args)
837        }
838        other => reduction_max_host(other, &args),
839    }
840}
841
842async fn reduction_max_gpu(
843    handle: GpuTensorHandle,
844    args: &ReductionArgs,
845) -> BuiltinResult<Option<MaxEvaluation>> {
846    #[cfg(all(test, feature = "wgpu"))]
847    {
848        if handle.device_id != 0 {
849            let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
850                runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
851            );
852        }
853    }
854    if args.nan_mode == ReductionNaN::Omit {
855        log::trace!("max: gpu path disabled (nan_mode=omit)");
856        return Ok(None);
857    }
858    if args.comparison != ComparisonMethod::Auto {
859        log::trace!("max: gpu path disabled (comparison != auto)");
860        return Ok(None);
861    }
862    if args.linear_index {
863        log::trace!("max: gpu path disabled (linear_index=true)");
864        return Ok(None);
865    }
866    let provider = match runmat_accelerate_api::provider() {
867        Some(p) => p,
868        None => {
869            log::trace!(
870                "max: gpu path unavailable (provider() is None) handle_shape={:?} device_id={}",
871                handle.shape,
872                handle.device_id
873            );
874            return Ok(None);
875        }
876    };
877    let target_dim = match args.selection {
878        DimSelection::Auto => default_dimension_from_shape(&handle.shape),
879        DimSelection::Dim(dim) => dim,
880        DimSelection::Vec(ref dims) if dims.len() == 1 => dims[0],
881        DimSelection::All => {
882            if handle.shape.len() <= 1 {
883                1
884            } else {
885                return Ok(None);
886            }
887        }
888        _ => return Ok(None),
889    };
890    if target_dim == 0 {
891        return Ok(None);
892    }
893    // MATLAB dimensions are 1-based; `reduce_max_dim` expects zero-based.
894    let zero_based = target_dim.saturating_sub(1);
895    if zero_based >= handle.shape.len() {
896        return Ok(None);
897    }
898    log::trace!(
899        "max: attempting reduce_max_dim dim={} (zero_based={}) shape={:?} device_id={}",
900        target_dim,
901        zero_based,
902        handle.shape,
903        handle.device_id
904    );
905    match provider.reduce_max_dim(&handle, zero_based).await {
906        Ok(ReduceDimResult { values, indices }) => Ok(Some(MaxEvaluation {
907            values: Value::GpuTensor(values),
908            indices: Value::GpuTensor(indices),
909        })),
910        Err(err) => {
911            log::trace!("max: reduce_max_dim failed: {err}");
912            Ok(None)
913        }
914    }
915}
916
917fn reduction_max_host(value: Value, args: &ReductionArgs) -> BuiltinResult<MaxEvaluation> {
918    match materialize_for_max("max", value)? {
919        InputData::Real(tensor) => reduce_real_tensor(tensor, args),
920        InputData::Complex(tensor) => reduce_complex_tensor(tensor, args),
921    }
922}
923
/// Host-side operand for `max`, classified by element type.
///
/// Produced by `materialize_for_max`, which also wraps logical arrays and
/// real scalars into the `Real` variant and complex scalars into `Complex`.
enum InputData {
    /// Real-valued dense tensor.
    Real(Tensor),
    /// Complex-valued dense tensor.
    Complex(ComplexTensor),
}
928
929fn materialize_for_max(name: &str, value: Value) -> BuiltinResult<InputData> {
930    match value {
931        Value::Tensor(t) => Ok(InputData::Real(t)),
932        Value::LogicalArray(logical) => {
933            let tensor = tensor::logical_to_tensor(&logical).map_err(max_invalid_input)?;
934            Ok(InputData::Real(tensor))
935        }
936        Value::Num(n) => {
937            let tensor = Tensor::new(vec![n], vec![1, 1])
938                .map_err(|e| max_internal_error(format!("{name}: {e}")))?;
939            Ok(InputData::Real(tensor))
940        }
941        Value::Int(i) => {
942            let tensor = Tensor::new(vec![i.to_f64()], vec![1, 1])
943                .map_err(|e| max_internal_error(format!("{name}: {e}")))?;
944            Ok(InputData::Real(tensor))
945        }
946        Value::Bool(b) => {
947            let tensor = Tensor::new(vec![if b { 1.0 } else { 0.0 }], vec![1, 1])
948                .map_err(|e| max_internal_error(format!("{name}: {e}")))?;
949            Ok(InputData::Real(tensor))
950        }
951        Value::Complex(re, im) => {
952            let tensor = ComplexTensor::new(vec![(re, im)], vec![1, 1])
953                .map_err(|e| max_internal_error(format!("{name}: {e}")))?;
954            Ok(InputData::Complex(tensor))
955        }
956        Value::ComplexTensor(ct) => Ok(InputData::Complex(ct)),
957        Value::String(_)
958        | Value::StringArray(_)
959        | Value::CharArray(_)
960        | Value::SparseTensor(_)
961        | Value::Cell(_) => Err(max_invalid_input(format!(
962            "{name}: expected numeric or logical dense input"
963        ))),
964        Value::GpuTensor(_) => Err(max_internal_error(format!(
965            "{name}: internal error – GPU tensors must be gathered before host execution"
966        ))),
967        Value::Object(_) | Value::HandleObject(_) | Value::Struct(_) | Value::Listener(_) => {
968            Err(max_invalid_input(format!("{name}: unsupported input type")))
969        }
970        Value::FunctionHandle(_)
971        | Value::ExternalFunctionHandle(_)
972        | Value::MethodFunctionHandle(_)
973        | Value::BoundFunctionHandle { .. }
974        | Value::Closure(_)
975        | Value::ClassRef(_)
976        | Value::MException(_)
977        | Value::OutputList(_) => Err(max_invalid_input(format!("{name}: unsupported input type"))),
978    }
979}
980
981fn reduce_real_tensor(tensor: Tensor, args: &ReductionArgs) -> BuiltinResult<MaxEvaluation> {
982    let shape = tensor.shape.clone();
983    if tensor.data.is_empty() {
984        let output_shape = resolve_output_shape(&shape, &args.selection, &[])?;
985        let values = Tensor::new(Vec::new(), output_shape.clone())
986            .map_err(|e| max_internal_error(format!("max: {e}")))?;
987        let indices = Tensor::new(Vec::new(), output_shape)
988            .map_err(|e| max_internal_error(format!("max: {e}")))?;
989        return Ok(MaxEvaluation {
990            values: tensor::tensor_into_value(values),
991            indices: tensor::tensor_into_value(indices),
992        });
993    }
994    let resolved = resolve_reduction_dims(&shape, &args.selection)?;
995    let output_shape = resolved.output_shape.clone();
996    let output_len = tensor::element_count(&output_shape);
997
998    if output_len == 0 {
999        let values = Tensor::new(Vec::new(), output_shape.clone())
1000            .map_err(|e| max_internal_error(format!("max: {e}")))?;
1001        let indices = Tensor::new(Vec::new(), output_shape)
1002            .map_err(|e| max_internal_error(format!("max: {e}")))?;
1003        return Ok(MaxEvaluation {
1004            values: tensor::tensor_into_value(values),
1005            indices: tensor::tensor_into_value(indices),
1006        });
1007    }
1008
1009    let strides = compute_strides(&shape);
1010    let output_strides = compute_strides(&output_shape);
1011    let dims_mask = resolved.dims_mask.clone();
1012    let reduce_strides = resolved.reduce_strides.clone();
1013
1014    let mut best = vec![BestReal::new(); output_len];
1015    let mut coords = vec![0usize; shape.len()];
1016    for &value in &tensor.data {
1017        let out_idx = map_output_index(&coords, &output_strides, &dims_mask);
1018        let reduce_idx = map_reduce_index(
1019            &coords,
1020            &resolved.reduced_dims,
1021            &reduce_strides,
1022            resolved.reduce_all,
1023        );
1024        let full_idx = map_linear_index(&coords, &strides);
1025
1026        update_best_real(
1027            &mut best[out_idx],
1028            value,
1029            reduce_idx,
1030            full_idx,
1031            args.nan_mode,
1032            args.comparison,
1033        );
1034        increment_coords(&mut coords, &shape);
1035    }
1036
1037    let mut values = vec![0.0f64; output_len];
1038    let mut indices = vec![0.0f64; output_len];
1039
1040    for (i, entry) in best.iter().enumerate() {
1041        if entry.nan_fixed {
1042            values[i] = f64::NAN;
1043            indices[i] = if args.linear_index || resolved.reduce_all {
1044                (entry.full_index + 1) as f64
1045            } else if resolved.reduced_dims.is_empty() {
1046                1.0
1047            } else {
1048                (entry.reduce_index + 1) as f64
1049            };
1050            continue;
1051        }
1052        if !entry.has_value {
1053            values[i] = f64::NAN;
1054            indices[i] = f64::NAN;
1055            continue;
1056        }
1057        values[i] = entry.value;
1058        indices[i] = if args.linear_index || resolved.reduce_all {
1059            (entry.full_index + 1) as f64
1060        } else if resolved.reduced_dims.is_empty() {
1061            1.0
1062        } else {
1063            (entry.reduce_index + 1) as f64
1064        };
1065    }
1066
1067    let value_tensor = Tensor::new(values, output_shape.clone())
1068        .map_err(|e| max_internal_error(format!("max: {e}")))?;
1069    let index_tensor =
1070        Tensor::new(indices, output_shape).map_err(|e| max_internal_error(format!("max: {e}")))?;
1071
1072    Ok(MaxEvaluation {
1073        values: tensor::tensor_into_value(value_tensor),
1074        indices: tensor::tensor_into_value(index_tensor),
1075    })
1076}
1077
1078fn reduce_complex_tensor(
1079    tensor: ComplexTensor,
1080    args: &ReductionArgs,
1081) -> BuiltinResult<MaxEvaluation> {
1082    let shape = tensor.shape.clone();
1083    if tensor.data.is_empty() {
1084        let output_shape = resolve_output_shape(&shape, &args.selection, &[])?;
1085        let values = ComplexTensor::new(Vec::new(), output_shape.clone())
1086            .map_err(|e| max_internal_error(format!("max: {e}")))?;
1087        let indices = Tensor::new(Vec::new(), output_shape)
1088            .map_err(|e| max_internal_error(format!("max: {e}")))?;
1089        return Ok(MaxEvaluation {
1090            values: complex_tensor_into_value(values),
1091            indices: tensor::tensor_into_value(indices),
1092        });
1093    }
1094
1095    let resolved = resolve_reduction_dims(&shape, &args.selection)?;
1096    let output_shape = resolved.output_shape.clone();
1097    let output_len = tensor::element_count(&output_shape);
1098
1099    if output_len == 0 {
1100        let values = ComplexTensor::new(Vec::new(), output_shape.clone())
1101            .map_err(|e| max_internal_error(format!("max: {e}")))?;
1102        let indices = Tensor::new(Vec::new(), output_shape)
1103            .map_err(|e| max_internal_error(format!("max: {e}")))?;
1104        return Ok(MaxEvaluation {
1105            values: complex_tensor_into_value(values),
1106            indices: tensor::tensor_into_value(indices),
1107        });
1108    }
1109
1110    let strides = compute_strides(&shape);
1111    let output_strides = compute_strides(&output_shape);
1112    let dims_mask = resolved.dims_mask.clone();
1113    let reduce_strides = resolved.reduce_strides.clone();
1114
1115    let mut best = vec![BestComplex::new(); output_len];
1116    let mut coords = vec![0usize; shape.len()];
1117
1118    for &(re, im) in &tensor.data {
1119        let out_idx = map_output_index(&coords, &output_strides, &dims_mask);
1120        let reduce_idx = map_reduce_index(
1121            &coords,
1122            &resolved.reduced_dims,
1123            &reduce_strides,
1124            resolved.reduce_all,
1125        );
1126        let full_idx = map_linear_index(&coords, &strides);
1127        update_best_complex(
1128            &mut best[out_idx],
1129            (re, im),
1130            reduce_idx,
1131            full_idx,
1132            args.nan_mode,
1133            args.comparison,
1134        );
1135        increment_coords(&mut coords, &shape);
1136    }
1137
1138    let mut values = vec![(0.0f64, 0.0f64); output_len];
1139    let mut indices = vec![0.0f64; output_len];
1140
1141    for (i, entry) in best.iter().enumerate() {
1142        if entry.nan_fixed {
1143            values[i] = (f64::NAN, f64::NAN);
1144            indices[i] = if args.linear_index || resolved.reduce_all {
1145                (entry.full_index + 1) as f64
1146            } else if resolved.reduced_dims.is_empty() {
1147                1.0
1148            } else {
1149                (entry.reduce_index + 1) as f64
1150            };
1151            continue;
1152        }
1153        if !entry.has_value {
1154            values[i] = (f64::NAN, f64::NAN);
1155            indices[i] = f64::NAN;
1156            continue;
1157        }
1158        values[i] = entry.value;
1159        indices[i] = if args.linear_index || resolved.reduce_all {
1160            (entry.full_index + 1) as f64
1161        } else if resolved.reduced_dims.is_empty() {
1162            1.0
1163        } else {
1164            (entry.reduce_index + 1) as f64
1165        };
1166    }
1167
1168    let value_tensor = ComplexTensor::new(values, output_shape.clone())
1169        .map_err(|e| max_internal_error(format!("max: {e}")))?;
1170    let index_tensor =
1171        Tensor::new(indices, output_shape).map_err(|e| max_internal_error(format!("max: {e}")))?;
1172    Ok(MaxEvaluation {
1173        values: complex_tensor_into_value(value_tensor),
1174        indices: tensor::tensor_into_value(index_tensor),
1175    })
1176}
1177
/// Running maximum for one output element of a real reduction.
#[derive(Debug, Clone)]
struct BestReal {
    /// Best value seen so far (meaningful only when `has_value` is set).
    value: f64,
    /// Zero-based position of `value` within the reduced subspace.
    reduce_index: usize,
    /// Zero-based linear position of `value` in the full input.
    full_index: usize,
    /// Whether any element has been recorded yet.
    has_value: bool,
    /// Whether a NaN has pinned this slot, blocking further updates.
    nan_fixed: bool,
}

impl BestReal {
    /// Creates an accumulator that has not recorded any element.
    fn new() -> Self {
        Self {
            has_value: false,
            nan_fixed: false,
            value: 0.0,
            reduce_index: 0,
            full_index: 0,
        }
    }
}
1198
/// Running maximum for one output element of a complex reduction.
#[derive(Debug, Clone)]
struct BestComplex {
    /// Best `(re, im)` value seen so far (meaningful only when `has_value` is set).
    value: (f64, f64),
    /// Zero-based position of `value` within the reduced subspace.
    reduce_index: usize,
    /// Zero-based linear position of `value` in the full input.
    full_index: usize,
    /// Whether any element has been recorded yet.
    has_value: bool,
    /// Whether a NaN has pinned this slot, blocking further updates.
    nan_fixed: bool,
}

impl BestComplex {
    /// Creates an accumulator that has not recorded any element.
    fn new() -> Self {
        Self {
            has_value: false,
            nan_fixed: false,
            value: (0.0, 0.0),
            reduce_index: 0,
            full_index: 0,
        }
    }
}
1219
1220fn resolve_output_shape(
1221    shape: &[usize],
1222    selection: &DimSelection,
1223    reduced_dims: &[usize],
1224) -> BuiltinResult<Vec<usize>> {
1225    if is_scalar_shape(shape) {
1226        return Ok(normalize_scalar_shape(shape));
1227    }
1228    let mut output = shape.to_vec();
1229    match selection {
1230        DimSelection::All => {
1231            output.fill(1);
1232        }
1233        _ => {
1234            for &dim in reduced_dims {
1235                if dim < output.len() {
1236                    output[dim] = 1;
1237                }
1238            }
1239        }
1240    }
1241    Ok(output)
1242}
1243
/// Result of resolving a dimension selection against a concrete input shape.
struct ResolvedDims {
    /// Shape of the reduction output.
    output_shape: Vec<usize>,
    /// Zero-based dimensions being reduced, sorted and de-duplicated
    /// (empty for scalar inputs or when nothing is reduced).
    reduced_dims: Vec<usize>,
    /// True when every dimension is reduced; also set for scalar inputs.
    reduce_all: bool,
    /// Per-dimension flag, true for dimensions listed in `reduced_dims`
    /// (empty for scalar inputs).
    dims_mask: Vec<bool>,
    /// Strides of the reduced dimensions within the reduced subspace,
    /// parallel to `reduced_dims`.
    reduce_strides: Vec<usize>,
}
1251
1252fn resolve_reduction_dims(
1253    shape: &[usize],
1254    selection: &DimSelection,
1255) -> BuiltinResult<ResolvedDims> {
1256    if is_scalar_shape(shape) {
1257        return Ok(ResolvedDims {
1258            output_shape: normalize_scalar_shape(shape),
1259            reduced_dims: Vec::new(),
1260            reduce_all: true,
1261            dims_mask: Vec::new(),
1262            reduce_strides: Vec::new(),
1263        });
1264    }
1265
1266    let mut reduced_dims = match selection {
1267        DimSelection::Auto => {
1268            let mut dim = None;
1269            for (index, &len) in shape.iter().enumerate() {
1270                if len > 1 {
1271                    dim = Some(index);
1272                    break;
1273                }
1274            }
1275            vec![dim.unwrap_or(0)]
1276        }
1277        DimSelection::Dim(dim) => {
1278            if *dim == 0 {
1279                return Err(max_invalid_argument("max: dimension must be >= 1"));
1280            }
1281            let index = dim.saturating_sub(1);
1282            if index >= shape.len() {
1283                Vec::new()
1284            } else {
1285                vec![index]
1286            }
1287        }
1288        DimSelection::Vec(dims) => {
1289            if dims.is_empty() {
1290                Vec::new()
1291            } else {
1292                dims.iter()
1293                    .filter_map(|dim| {
1294                        if *dim == 0 {
1295                            None
1296                        } else {
1297                            let idx = dim - 1;
1298                            if idx < shape.len() {
1299                                Some(idx)
1300                            } else {
1301                                None
1302                            }
1303                        }
1304                    })
1305                    .collect()
1306            }
1307        }
1308        DimSelection::All => (0..shape.len()).collect(),
1309    };
1310
1311    reduced_dims.sort_unstable();
1312    reduced_dims.dedup();
1313
1314    let reduce_all = !reduced_dims.is_empty()
1315        && reduced_dims.len() == shape.len()
1316        && reduced_dims.iter().enumerate().all(|(i, &d)| i == d);
1317
1318    let output_shape = resolve_output_shape(shape, selection, &reduced_dims)?;
1319    let mut dims_mask = vec![false; shape.len()];
1320    for &dim in &reduced_dims {
1321        if dim < dims_mask.len() {
1322            dims_mask[dim] = true;
1323        }
1324    }
1325    let reduce_strides = compute_subspace_strides(shape, &reduced_dims);
1326
1327    Ok(ResolvedDims {
1328        output_shape,
1329        reduced_dims,
1330        reduce_all,
1331        dims_mask,
1332        reduce_strides,
1333    })
1334}
1335
/// Returns the stride of each dimension of `shape`, first dimension fastest.
///
/// Each stride is the saturating product of the preceding extents, where
/// zero-length extents count as 1.
fn compute_strides(shape: &[usize]) -> Vec<usize> {
    shape
        .iter()
        .scan(1usize, |running, &len| {
            let current = *running;
            *running = running.saturating_mul(len.max(1));
            Some(current)
        })
        .collect()
}
1345
/// Returns strides for the subspace spanned by `dims` of `shape`.
///
/// The result is parallel to `dims`: each entry is the saturating product of
/// the extents of the preceding listed dimensions. Dimensions outside the
/// shape and zero-length extents count as 1.
fn compute_subspace_strides(shape: &[usize], dims: &[usize]) -> Vec<usize> {
    dims.iter()
        .scan(1usize, |running, &dim| {
            let current = *running;
            let extent = shape.get(dim).copied().unwrap_or(1).max(1);
            *running = running.saturating_mul(extent);
            Some(current)
        })
        .collect()
}
1359
/// Maps input coordinates to the linear index of the output element they
/// contribute to.
///
/// * `coords` - zero-based coordinates of the current input element.
/// * `output_strides` - strides of the output shape.
/// * `dims_mask` - true for dimensions that are reduced; those contribute
///   coordinate 0.
///
/// Dimensions missing from `dims_mask` are treated as not reduced, and
/// dimensions missing from `coords` (output rank larger than the coordinate
/// rank) are treated as coordinate 0 instead of panicking. Returns 0 for
/// empty `coords`. All arithmetic saturates.
fn map_output_index(coords: &[usize], output_strides: &[usize], dims_mask: &[bool]) -> usize {
    if coords.is_empty() {
        return 0;
    }
    output_strides
        .iter()
        .enumerate()
        .fold(0usize, |index, (dim, &stride)| {
            let reduced = dims_mask.get(dim).copied().unwrap_or(false);
            let coord = if reduced {
                0
            } else {
                // A missing trailing coordinate addresses a singleton dimension.
                coords.get(dim).copied().unwrap_or(0)
            };
            index.saturating_add(coord.saturating_mul(stride))
        })
}
1375
/// Maps input coordinates to a zero-based position within the reduced subspace.
///
/// * `coords` - zero-based coordinates of the current input element.
/// * `reduced_dims` - zero-based dimensions being reduced.
/// * `reduce_strides` - subspace strides parallel to `reduced_dims`.
/// * `reduce_all` - whether every dimension is reduced.
///
/// Returns 0 when nothing is reduced, and also when everything is reduced
/// (the full linear index is tracked separately in that case). Dimensions
/// without a coordinate or a stride contribute nothing. Arithmetic saturates.
fn map_reduce_index(
    coords: &[usize],
    reduced_dims: &[usize],
    reduce_strides: &[usize],
    reduce_all: bool,
) -> usize {
    if reduce_all || reduced_dims.is_empty() {
        return 0;
    }
    reduced_dims
        .iter()
        .zip(reduce_strides)
        .filter_map(|(&dim, &stride)| coords.get(dim).map(|&coord| coord.saturating_mul(stride)))
        .fold(0usize, usize::saturating_add)
}
1399
/// Converts coordinates to a zero-based linear index using `strides`.
///
/// Coordinates and strides are paired positionally; unpaired trailing entries
/// of the longer slice are ignored. Arithmetic saturates.
fn map_linear_index(coords: &[usize], strides: &[usize]) -> usize {
    let mut linear = 0usize;
    for (&coord, &stride) in coords.iter().zip(strides) {
        linear = linear.saturating_add(coord.saturating_mul(stride));
    }
    linear
}
1408
/// Advances `coords` to the next element, first dimension fastest.
///
/// Works like an odometer: the first dimension is incremented, and on
/// reaching its extent it wraps to 0 and carries into the next one.
/// Zero-length dimensions are skipped. After the last element the
/// coordinates wrap around to all zeros. `shape` must be at least as long as
/// `coords`.
fn increment_coords(coords: &mut [usize], shape: &[usize]) {
    for (dim, slot) in coords.iter_mut().enumerate() {
        let extent = shape[dim];
        if extent == 0 {
            continue;
        }
        *slot += 1;
        if *slot < extent {
            // No carry needed; higher dimensions stay untouched.
            return;
        }
        *slot = 0;
    }
}
1421
1422fn update_best_real(
1423    best: &mut BestReal,
1424    value: f64,
1425    reduce_index: usize,
1426    full_index: usize,
1427    nan_mode: ReductionNaN,
1428    comparison: ComparisonMethod,
1429) {
1430    if value.is_nan() {
1431        match nan_mode {
1432            ReductionNaN::Include => {
1433                if !best.nan_fixed {
1434                    best.value = f64::NAN;
1435                    best.reduce_index = reduce_index;
1436                    best.full_index = full_index;
1437                    best.has_value = true;
1438                    best.nan_fixed = true;
1439                }
1440            }
1441            ReductionNaN::Omit => {}
1442        }
1443        return;
1444    }
1445    if best.nan_fixed {
1446        return;
1447    }
1448
1449    if !best.has_value {
1450        best.value = value;
1451        best.reduce_index = reduce_index;
1452        best.full_index = full_index;
1453        best.has_value = true;
1454        return;
1455    }
1456
1457    if should_replace_real(best.value, value, comparison) {
1458        best.value = value;
1459        best.reduce_index = reduce_index;
1460        best.full_index = full_index;
1461    }
1462}
1463
1464fn update_best_complex(
1465    best: &mut BestComplex,
1466    value: (f64, f64),
1467    reduce_index: usize,
1468    full_index: usize,
1469    nan_mode: ReductionNaN,
1470    comparison: ComparisonMethod,
1471) {
1472    if value.0.is_nan() || value.1.is_nan() {
1473        match nan_mode {
1474            ReductionNaN::Include => {
1475                if !best.nan_fixed {
1476                    best.value = (f64::NAN, f64::NAN);
1477                    best.reduce_index = reduce_index;
1478                    best.full_index = full_index;
1479                    best.has_value = true;
1480                    best.nan_fixed = true;
1481                }
1482            }
1483            ReductionNaN::Omit => {}
1484        }
1485        return;
1486    }
1487    if best.nan_fixed {
1488        return;
1489    }
1490
1491    if !best.has_value {
1492        best.value = value;
1493        best.reduce_index = reduce_index;
1494        best.full_index = full_index;
1495        best.has_value = true;
1496        return;
1497    }
1498
1499    if should_replace_complex(best.value, value, comparison) {
1500        best.value = value;
1501        best.reduce_index = reduce_index;
1502        best.full_index = full_index;
1503    }
1504}
1505
1506fn should_replace_real(current: f64, candidate: f64, comparison: ComparisonMethod) -> bool {
1507    match comparison {
1508        ComparisonMethod::Auto | ComparisonMethod::Real => {
1509            if candidate > current {
1510                return true;
1511            }
1512            if candidate < current {
1513                return false;
1514            }
1515            if candidate == 0.0 && current == 0.0 {
1516                return candidate.is_sign_positive() && !current.is_sign_positive();
1517            }
1518            false
1519        }
1520        ComparisonMethod::Abs => {
1521            let curr_abs = current.abs();
1522            let cand_abs = candidate.abs();
1523            if cand_abs > curr_abs {
1524                return true;
1525            }
1526            if cand_abs < curr_abs {
1527                return false;
1528            }
1529            if candidate > current {
1530                return true;
1531            }
1532            if candidate < current {
1533                return false;
1534            }
1535            if candidate == 0.0 && current == 0.0 {
1536                return candidate.is_sign_positive() && !current.is_sign_positive();
1537            }
1538            false
1539        }
1540    }
1541}
1542
1543fn should_replace_complex(
1544    current: (f64, f64),
1545    candidate: (f64, f64),
1546    comparison: ComparisonMethod,
1547) -> bool {
1548    match comparison {
1549        ComparisonMethod::Auto | ComparisonMethod::Abs => {
1550            compare_complex_auto(current, candidate) == Ordering::Less
1551        }
1552        ComparisonMethod::Real => compare_complex_real(current, candidate) == Ordering::Less,
1553    }
1554}
1555
1556fn compare_complex_auto(a: (f64, f64), b: (f64, f64)) -> Ordering {
1557    let a_mag = magnitude_squared(a);
1558    let b_mag = magnitude_squared(b);
1559    if a_mag < b_mag {
1560        return Ordering::Less;
1561    }
1562    if a_mag > b_mag {
1563        return Ordering::Greater;
1564    }
1565    // Equal magnitude: tie-break using phase angle.
1566    let a_angle = a.1.atan2(a.0);
1567    let b_angle = b.1.atan2(b.0);
1568    if a_angle < b_angle {
1569        Ordering::Less
1570    } else if a_angle > b_angle {
1571        Ordering::Greater
1572    } else {
1573        Ordering::Equal
1574    }
1575}
1576
/// Orders two complex numbers for the `'real'` comparison method.
///
/// Real parts are compared first. Equal real parts are ordered by imaginary
/// part, which is the tie-break MATLAB documents for
/// `'ComparisonMethod','real'` (previously the tie went to magnitude and
/// phase, which picks `1-3i` over `1+2i`). Incomparable components (NaN) are
/// treated as equal.
fn compare_complex_real(a: (f64, f64), b: (f64, f64)) -> Ordering {
    match a.0.partial_cmp(&b.0) {
        Some(Ordering::Equal) | None => a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal),
        Some(order) => order,
    }
}
1587
/// Returns `re^2 + im^2` for the complex number `z = (re, im)`.
///
/// The sum is formed with a fused multiply-add.
fn magnitude_squared(z: (f64, f64)) -> f64 {
    let (re, im) = z;
    re.mul_add(re, im * im)
}
1591
1592fn default_dimension_from_shape(shape: &[usize]) -> usize {
1593    if is_scalar_shape(shape) {
1594        return 1;
1595    }
1596    for (i, &len) in shape.iter().enumerate() {
1597        if len > 1 {
1598            return i + 1;
1599        }
1600    }
1601    1
1602}
1603
1604async fn elementwise_max(value: Value, args: ElementwiseArgs) -> BuiltinResult<MaxEvaluation> {
1605    let ElementwiseArgs { other, comparison } = args;
1606    match (value, other) {
1607        (Value::GpuTensor(handle_a), Value::GpuTensor(handle_b)) => {
1608            if gpu_tensor_is_scalar(&handle_b) {
1609                if let Some(num) = gpu_tensor_scalar_value(&handle_b).await {
1610                    let scalar = Value::Num(num);
1611                    if let Some(eval) =
1612                        elementwise_max_gpu_scalar_left(&handle_a, &scalar, comparison).await
1613                    {
1614                        return Ok(eval);
1615                    }
1616                    if let Ok(ta) = gpu_helpers::gather_tensor_async(&handle_a).await {
1617                        if let Ok(eval) = elementwise_real_or_complex(
1618                            Value::Tensor(ta),
1619                            scalar.clone(),
1620                            comparison,
1621                        ) {
1622                            return Ok(eval);
1623                        }
1624                    }
1625                    return Err(max_internal_error(
1626                        "max: elementwise GPU scalar path failed",
1627                    ));
1628                }
1629            }
1630            if gpu_tensor_is_scalar(&handle_a) {
1631                if let Some(num) = gpu_tensor_scalar_value(&handle_a).await {
1632                    let scalar = Value::Num(num);
1633                    if let Some(eval) =
1634                        elementwise_max_gpu_scalar_right(&scalar, &handle_b, comparison).await
1635                    {
1636                        return Ok(eval);
1637                    }
1638                    if let Ok(tb) = gpu_helpers::gather_tensor_async(&handle_b).await {
1639                        if let Ok(eval) = elementwise_real_or_complex(
1640                            scalar.clone(),
1641                            Value::Tensor(tb),
1642                            comparison,
1643                        ) {
1644                            return Ok(eval);
1645                        }
1646                    }
1647                    return Err(max_internal_error(
1648                        "max: elementwise GPU scalar path failed",
1649                    ));
1650                }
1651            }
1652            if let Some(eval) = elementwise_max_gpu_pair(&handle_a, &handle_b, comparison).await {
1653                return Ok(eval);
1654            }
1655            if let (Ok(ta), Ok(tb)) = (
1656                gpu_helpers::gather_tensor_async(&handle_a).await,
1657                gpu_helpers::gather_tensor_async(&handle_b).await,
1658            ) {
1659                if let Ok(eval) =
1660                    elementwise_real_or_complex(Value::Tensor(ta), Value::Tensor(tb), comparison)
1661                {
1662                    return Ok(eval);
1663                }
1664            }
1665            Err(max_internal_error("max: elementwise GPU path failed"))
1666        }
1667        (Value::GpuTensor(handle), other) => {
1668            if let Some(eval) = elementwise_max_gpu_scalar_left(&handle, &other, comparison).await {
1669                return Ok(eval);
1670            }
1671            let t = gpu_helpers::gather_tensor_async(&handle)
1672                .await
1673                .map_err(|_| max_internal_error("max: elementwise GPU scalar path failed"))?;
1674            elementwise_real_or_complex(Value::Tensor(t), other, comparison)
1675        }
1676        (other, Value::GpuTensor(handle)) => {
1677            if let Some(eval) = elementwise_max_gpu_scalar_right(&other, &handle, comparison).await
1678            {
1679                return Ok(eval);
1680            }
1681            let t = gpu_helpers::gather_tensor_async(&handle)
1682                .await
1683                .map_err(|_| max_internal_error("max: elementwise GPU scalar path failed"))?;
1684            elementwise_real_or_complex(other, Value::Tensor(t), comparison)
1685        }
1686        (lhs, rhs) => elementwise_real_or_complex(lhs, rhs, comparison),
1687    }
1688}
1689
1690async fn elementwise_max_gpu_pair(
1691    a: &GpuTensorHandle,
1692    b: &GpuTensorHandle,
1693    comparison: ComparisonMethod,
1694) -> Option<MaxEvaluation> {
1695    if comparison != ComparisonMethod::Auto {
1696        return None;
1697    }
1698    let provider = runmat_accelerate_api::provider()?;
1699    // Equal-shape fast path
1700    if a.shape == b.shape {
1701        let values = provider.elem_max(a, b).await.ok()?;
1702        // Try device mask first; if unavailable, compute indices on host while keeping values on device
1703        if let Ok(mask) = provider.elem_ge(a, b).await {
1704            let indices = gpu_mask_indices(provider, &mask)?;
1705            let _ = provider.free(&mask);
1706            return Some(MaxEvaluation {
1707                values: Value::GpuTensor(values),
1708                indices: Value::GpuTensor(indices),
1709            });
1710        } else {
1711            // Host path for indices only
1712            let ta = gpu_helpers::gather_tensor_async(a).await.ok()?;
1713            let tb = gpu_helpers::gather_tensor_async(b).await.ok()?;
1714            let mut indices = Vec::with_capacity(ta.data.len());
1715            for i in 0..ta.data.len() {
1716                indices.push(if ta.data[i] >= tb.data[i] { 1.0 } else { 2.0 });
1717            }
1718            let index_tensor = Tensor::new(indices, ta.shape.clone()).ok()?;
1719            return Some(MaxEvaluation {
1720                values: Value::GpuTensor(values),
1721                indices: tensor::tensor_into_value(index_tensor),
1722            });
1723        }
1724    }
1725    // Broadcast-compatible path via repmat, then device compare
1726    let (out_shape, reps_a, reps_b) = broadcast_reps(&a.shape, &b.shape)?;
1727    let a_exp = if reps_a.iter().any(|&r| r != 1) {
1728        provider.repmat(a, &reps_a).ok()?
1729    } else {
1730        a.clone()
1731    };
1732    let b_exp = if reps_b.iter().any(|&r| r != 1) {
1733        provider.repmat(b, &reps_b).ok()?
1734    } else {
1735        b.clone()
1736    };
1737    let values = provider.elem_max(&a_exp, &b_exp).await.ok();
1738    let mask = provider.elem_ge(&a_exp, &b_exp).await.ok();
1739    if !std::ptr::eq(&a_exp, a) {
1740        let _ = provider.free(&a_exp);
1741    }
1742    if !std::ptr::eq(&b_exp, b) {
1743        let _ = provider.free(&b_exp);
1744    }
1745    let values = values?;
1746    if values.shape != out_shape {
1747        let _ = provider.free(&values);
1748        return None;
1749    }
1750    let index_tensor = if let Some(mask) = mask {
1751        let mask_host = gpu_helpers::gather_tensor_async(&mask).await.ok()?;
1752        let _ = provider.free(&mask);
1753        let mut indices = Vec::with_capacity(mask_host.data.len());
1754        for &m in &mask_host.data {
1755            indices.push(if m != 0.0 { 1.0 } else { 2.0 });
1756        }
1757        Tensor::new(indices, out_shape).ok()?
1758    } else {
1759        // Host indices fallback
1760        let ta = gpu_helpers::gather_tensor_async(&a_exp).await.ok()?;
1761        let tb = gpu_helpers::gather_tensor_async(&b_exp).await.ok()?;
1762        let mut indices = Vec::with_capacity(ta.data.len());
1763        for i in 0..ta.data.len() {
1764            indices.push(if ta.data[i] >= tb.data[i] { 1.0 } else { 2.0 });
1765        }
1766        Tensor::new(indices, out_shape).ok()?
1767    };
1768    Some(MaxEvaluation {
1769        values: Value::GpuTensor(values),
1770        indices: tensor::tensor_into_value(index_tensor),
1771    })
1772}
1773
/// Computes the broadcast shape of two operand shapes together with the
/// per-axis repetition counts needed to expand each operand to that shape.
///
/// Shapes shorter than the common rank (at least 1) are padded with trailing
/// extents of 1. Two extents are compatible on an axis when they are equal or
/// when either of them is 1; any other combination yields `None`.
///
/// Returns `(out_shape, reps_a, reps_b)`. A repetition count is 1 when the
/// operand already has the output extent on that axis, and the output extent
/// otherwise.
fn broadcast_reps(a: &[usize], b: &[usize]) -> Option<(Vec<usize>, Vec<usize>, Vec<usize>)> {
    let rank = a.len().max(b.len()).max(1);
    // Missing trailing axes behave like singleton axes.
    let extent = |shape: &[usize], axis: usize| -> usize { shape.get(axis).copied().unwrap_or(1) };

    let mut out = Vec::with_capacity(rank);
    let mut reps_a = Vec::with_capacity(rank);
    let mut reps_b = Vec::with_capacity(rank);

    for axis in 0..rank {
        let ad = extent(a, axis);
        let bd = extent(b, axis);
        let merged = match (ad, bd) {
            (x, y) if x == y => x,
            (1, y) => y,
            (x, 1) => x,
            _ => return None,
        };
        out.push(merged);
        reps_a.push(if ad == merged { 1 } else { merged });
        reps_b.push(if bd == merged { 1 } else { merged });
    }

    Some((out, reps_a, reps_b))
}
1803
1804async fn elementwise_max_gpu_scalar_left(
1805    a: &GpuTensorHandle,
1806    other: &Value,
1807    comparison: ComparisonMethod,
1808) -> Option<MaxEvaluation> {
1809    if comparison != ComparisonMethod::Auto {
1810        return None;
1811    }
1812    let provider = runmat_accelerate_api::provider()?;
1813    let scalar = extract_scalar(other)?;
1814    // Prefer tensorize + elem_max for broader provider compatibility
1815    let values = if let Ok(fill) = provider.fill_like(a, scalar) {
1816        let vals = provider.elem_max(a, &fill).await.ok();
1817        let _ = provider.free(&fill);
1818        vals?
1819    } else {
1820        provider.scalar_max(a, scalar).ok()?
1821    };
1822    // Try device mask; if unavailable, compute on host
1823    let index_tensor = if let Ok(fill) = provider.fill_like(a, scalar) {
1824        if let Ok(mask) = provider.elem_ge(a, &fill).await {
1825            let _ = provider.free(&fill);
1826            let indices = gpu_mask_indices(provider, &mask)?;
1827            let _ = provider.free(&mask);
1828            return Some(MaxEvaluation {
1829                values: Value::GpuTensor(values),
1830                indices: Value::GpuTensor(indices),
1831            });
1832        } else {
1833            let _ = provider.free(&fill);
1834            let ta = gpu_helpers::gather_tensor_async(a).await.ok()?;
1835            let mut indices = Vec::with_capacity(ta.data.len());
1836            for &v in &ta.data {
1837                indices.push(if v >= scalar { 1.0 } else { 2.0 });
1838            }
1839            Tensor::new(indices, ta.shape.clone()).ok()?
1840        }
1841    } else {
1842        let ta = gpu_helpers::gather_tensor_async(a).await.ok()?;
1843        let mut indices = Vec::with_capacity(ta.data.len());
1844        for &v in &ta.data {
1845            indices.push(if v >= scalar { 1.0 } else { 2.0 });
1846        }
1847        Tensor::new(indices, ta.shape.clone()).ok()?
1848    };
1849    Some(MaxEvaluation {
1850        values: Value::GpuTensor(values),
1851        indices: tensor::tensor_into_value(index_tensor),
1852    })
1853}
1854
1855async fn elementwise_max_gpu_scalar_right(
1856    other: &Value,
1857    b: &GpuTensorHandle,
1858    comparison: ComparisonMethod,
1859) -> Option<MaxEvaluation> {
1860    if comparison != ComparisonMethod::Auto {
1861        return None;
1862    }
1863    let provider = runmat_accelerate_api::provider()?;
1864    let scalar = extract_scalar(other)?;
1865    let values = if let Ok(fill) = provider.fill_like(b, scalar) {
1866        let vals = provider.elem_max(&fill, b).await.ok();
1867        let _ = provider.free(&fill);
1868        vals?
1869    } else {
1870        provider.scalar_max(b, scalar).ok()?
1871    };
1872    // Try device mask; if unavailable, compute on host (origin 1 if scalar >= b)
1873    let index_tensor = if let Ok(fill) = provider.fill_like(b, scalar) {
1874        if let Ok(mask) = provider.elem_ge(&fill, b).await {
1875            let _ = provider.free(&fill);
1876            let indices = gpu_mask_indices(provider, &mask)?;
1877            let _ = provider.free(&mask);
1878            return Some(MaxEvaluation {
1879                values: Value::GpuTensor(values),
1880                indices: Value::GpuTensor(indices),
1881            });
1882        } else {
1883            let _ = provider.free(&fill);
1884            let tb = gpu_helpers::gather_tensor_async(b).await.ok()?;
1885            let mut indices = Vec::with_capacity(tb.data.len());
1886            for &v in &tb.data {
1887                indices.push(if scalar >= v { 1.0 } else { 2.0 });
1888            }
1889            Tensor::new(indices, tb.shape.clone()).ok()?
1890        }
1891    } else {
1892        let tb = gpu_helpers::gather_tensor_async(b).await.ok()?;
1893        let mut indices = Vec::with_capacity(tb.data.len());
1894        for &v in &tb.data {
1895            indices.push(if scalar >= v { 1.0 } else { 2.0 });
1896        }
1897        Tensor::new(indices, tb.shape.clone()).ok()?
1898    };
1899    Some(MaxEvaluation {
1900        values: Value::GpuTensor(values),
1901        indices: tensor::tensor_into_value(index_tensor),
1902    })
1903}
1904
1905fn extract_scalar(v: &Value) -> Option<f64> {
1906    match v {
1907        Value::Num(n) => Some(*n),
1908        Value::Int(i) => Some(i.to_f64()),
1909        Value::Bool(b) => Some(if *b { 1.0 } else { 0.0 }),
1910        Value::Tensor(t) if t.data.len() == 1 => t.data.first().copied(),
1911        Value::LogicalArray(l) if l.data.len() == 1 => Some(if l.data[0] != 0 { 1.0 } else { 0.0 }),
1912        _ => None,
1913    }
1914}
1915
1916fn gpu_tensor_is_scalar(handle: &GpuTensorHandle) -> bool {
1917    handle.shape.iter().copied().product::<usize>().max(1) == 1
1918}
1919
1920async fn gpu_tensor_scalar_value(handle: &GpuTensorHandle) -> Option<f64> {
1921    let tensor = gpu_helpers::gather_tensor_async(handle).await.ok()?;
1922    tensor.data.first().copied()
1923}
1924
1925fn gpu_mask_indices(
1926    provider: &dyn AccelProvider,
1927    mask: &GpuTensorHandle,
1928) -> Option<GpuTensorHandle> {
1929    let scaled = provider.scalar_mul(mask, -1.0).ok()?;
1930    let shifted = provider.scalar_add(&scaled, 2.0).ok()?;
1931    let _ = provider.free(&scaled);
1932    Some(shifted)
1933}
1934
1935fn elementwise_real_or_complex(
1936    lhs: Value,
1937    rhs: Value,
1938    comparison: ComparisonMethod,
1939) -> BuiltinResult<MaxEvaluation> {
1940    if let Some(eval) = scalar_elementwise_max(&lhs, &rhs, comparison) {
1941        return Ok(eval);
1942    }
1943    match (
1944        materialize_for_max("max", lhs)?,
1945        materialize_for_max("max", rhs)?,
1946    ) {
1947        (InputData::Complex(a), InputData::Complex(b)) => elementwise_complex_max(a, b, comparison),
1948        (InputData::Complex(a), InputData::Real(b)) => {
1949            let converted = promote_real_tensor_to_complex(b);
1950            elementwise_complex_max(a, converted, comparison)
1951        }
1952        (InputData::Real(a), InputData::Complex(b)) => {
1953            let converted = promote_real_tensor_to_complex(a);
1954            elementwise_complex_max(converted, b, comparison)
1955        }
1956        (InputData::Real(a), InputData::Real(b)) => elementwise_real_max(a, b, comparison),
1957    }
1958}
1959
1960fn scalar_real_value(value: &Value) -> Option<f64> {
1961    match value {
1962        Value::Num(n) => Some(*n),
1963        Value::Int(i) => Some(i.to_f64()),
1964        Value::Bool(b) => Some(if *b { 1.0 } else { 0.0 }),
1965        Value::Tensor(t) if t.data.len() == 1 => t.data.first().copied(),
1966        Value::LogicalArray(l) if l.data.len() == 1 => Some(if l.data[0] != 0 { 1.0 } else { 0.0 }),
1967        _ => None,
1968    }
1969}
1970
1971fn scalar_complex_value(value: &Value) -> Option<(f64, f64)> {
1972    match value {
1973        Value::Complex(re, im) => Some((*re, *im)),
1974        Value::ComplexTensor(ct) if ct.data.len() == 1 => ct.data.first().copied(),
1975        _ => None,
1976    }
1977}
1978
1979fn scalar_elementwise_max(
1980    lhs: &Value,
1981    rhs: &Value,
1982    comparison: ComparisonMethod,
1983) -> Option<MaxEvaluation> {
1984    let left = scalar_complex_value(lhs).or_else(|| scalar_real_value(lhs).map(|v| (v, 0.0)))?;
1985    let right = scalar_complex_value(rhs).or_else(|| scalar_real_value(rhs).map(|v| (v, 0.0)))?;
1986    let (ar, ai) = left;
1987    let (br, bi) = right;
1988    if ai != 0.0 || bi != 0.0 {
1989        let (value, origin) = choose_complex_elementwise((ar, ai), (br, bi), comparison);
1990        return Some(MaxEvaluation {
1991            values: Value::Complex(value.0, value.1),
1992            indices: Value::Num(origin),
1993        });
1994    }
1995    let (value, origin) = choose_real_elementwise(ar, br, comparison);
1996    Some(MaxEvaluation {
1997        values: Value::Num(value),
1998        indices: Value::Num(origin),
1999    })
2000}
2001
2002fn elementwise_real_max(
2003    lhs: Tensor,
2004    rhs: Tensor,
2005    comparison: ComparisonMethod,
2006) -> BuiltinResult<MaxEvaluation> {
2007    let plan = BroadcastPlan::new(&lhs.shape, &rhs.shape)
2008        .map_err(|err| max_size_mismatch(format!("max: {err}")))?;
2009    let mut values = vec![0.0f64; plan.len()];
2010    let mut indices = vec![0.0f64; plan.len()];
2011
2012    for (offset, index_a, index_b) in plan.iter() {
2013        let a = lhs.data.get(index_a).copied().unwrap_or(f64::NAN);
2014        let b = rhs.data.get(index_b).copied().unwrap_or(f64::NAN);
2015        let (value, origin) = choose_real_elementwise(a, b, comparison);
2016        values[offset] = value;
2017        indices[offset] = origin;
2018    }
2019
2020    let value_tensor = Tensor::new(values, plan.output_shape().to_vec())
2021        .map_err(|e| max_internal_error(format!("max: {e}")))?;
2022    let index_tensor = Tensor::new(indices, plan.output_shape().to_vec())
2023        .map_err(|e| max_internal_error(format!("max: {e}")))?;
2024
2025    Ok(MaxEvaluation {
2026        values: tensor::tensor_into_value(value_tensor),
2027        indices: tensor::tensor_into_value(index_tensor),
2028    })
2029}
2030
2031fn elementwise_complex_max(
2032    lhs: ComplexTensor,
2033    rhs: ComplexTensor,
2034    comparison: ComparisonMethod,
2035) -> BuiltinResult<MaxEvaluation> {
2036    let plan = BroadcastPlan::new(&lhs.shape, &rhs.shape)
2037        .map_err(|err| max_size_mismatch(format!("max: {err}")))?;
2038    let mut values = vec![(0.0f64, 0.0f64); plan.len()];
2039    let mut indices = vec![0.0f64; plan.len()];
2040
2041    for (offset, index_a, index_b) in plan.iter() {
2042        let a = lhs
2043            .data
2044            .get(index_a)
2045            .copied()
2046            .unwrap_or((f64::NAN, f64::NAN));
2047        let b = rhs
2048            .data
2049            .get(index_b)
2050            .copied()
2051            .unwrap_or((f64::NAN, f64::NAN));
2052        let (value, origin) = choose_complex_elementwise(a, b, comparison);
2053        values[offset] = value;
2054        indices[offset] = origin;
2055    }
2056
2057    let value_tensor = ComplexTensor::new(values, plan.output_shape().to_vec())
2058        .map_err(|e| max_internal_error(format!("max: {e}")))?;
2059    let index_tensor = Tensor::new(indices, plan.output_shape().to_vec())
2060        .map_err(|e| max_internal_error(format!("max: {e}")))?;
2061
2062    Ok(MaxEvaluation {
2063        values: complex_tensor_into_value(value_tensor),
2064        indices: tensor::tensor_into_value(index_tensor),
2065    })
2066}
2067
2068fn promote_real_tensor_to_complex(tensor: Tensor) -> ComplexTensor {
2069    let data = tensor
2070        .data
2071        .iter()
2072        .copied()
2073        .map(|re| (re, 0.0))
2074        .collect::<Vec<_>>();
2075    ComplexTensor {
2076        data,
2077        shape: tensor.shape.clone(),
2078        rows: tensor.rows,
2079        cols: tensor.cols,
2080    }
2081}
2082
2083fn choose_real_elementwise(a: f64, b: f64, comparison: ComparisonMethod) -> (f64, f64) {
2084    match (a.is_nan(), b.is_nan()) {
2085        (true, true) => (f64::NAN, 1.0),
2086        (true, false) => (f64::NAN, 1.0),
2087        (false, true) => (f64::NAN, 2.0),
2088        (false, false) => {
2089            if should_replace_real(a, b, comparison) {
2090                (b, 2.0)
2091            } else {
2092                (a, 1.0)
2093            }
2094        }
2095    }
2096}
2097
2098fn choose_complex_elementwise(
2099    a: (f64, f64),
2100    b: (f64, f64),
2101    comparison: ComparisonMethod,
2102) -> ((f64, f64), f64) {
2103    let a_nan = a.0.is_nan() || a.1.is_nan();
2104    let b_nan = b.0.is_nan() || b.1.is_nan();
2105    match (a_nan, b_nan) {
2106        (true, true) => ((f64::NAN, f64::NAN), 1.0),
2107        (true, false) => ((f64::NAN, f64::NAN), 1.0),
2108        (false, true) => ((f64::NAN, f64::NAN), 2.0),
2109        (false, false) => {
2110            if should_replace_complex(a, b, comparison) {
2111                (b, 2.0)
2112            } else {
2113                (a, 1.0)
2114            }
2115        }
2116    }
2117}
2118
// Unit tests for the `max` builtin: descriptor metadata, reductions,
// element-wise forms, option parsing, and (behind the `wgpu` feature) a
// device-vs-host comparison.
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    #[cfg(feature = "wgpu")]
    use crate::builtins::common::test_support;
    use futures::executor::block_on;
    #[cfg(feature = "wgpu")]
    use runmat_accelerate_api::HostTensorView;
    use runmat_builtins::{IntValue, Tensor, Value};

    // Blocking wrapper: drives the async `super::max_builtin` to completion.
    fn max_builtin(value: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
        block_on(super::max_builtin(value, rest))
    }

    // Two tensor-typed arguments with unknown shapes resolve to `Type::tensor()`.
    #[test]
    fn max_type_with_two_args_returns_tensor() {
        let out = max_type(
            &[Type::Tensor { shape: None }, Type::Tensor { shape: None }],
            &ResolveContext::new(Vec::new()),
        );
        assert_eq!(out, Type::tensor());
    }

    // The descriptor must advertise each of the listed call forms by label.
    #[test]
    fn max_descriptor_signatures_cover_core_forms() {
        let labels: Vec<&str> = MAX_DESCRIPTOR
            .signatures
            .iter()
            .map(|sig| sig.label)
            .collect();
        assert!(labels.contains(&"M = max(A)"));
        assert!(labels.contains(&"[M, I] = max(A)"));
        assert!(labels.contains(&"M = max(A, B)"));
        assert!(labels.contains(&"[M, I] = max(A, B)"));
        assert!(labels.contains(&"M = max(A, [], dim)"));
        assert!(labels.contains(&"M = max(A, [], \"all\")"));
        assert!(labels.contains(&"M = max(A, [], \"ComparisonMethod\", method)"));
        assert!(labels.contains(&"M = max(A, B, \"ComparisonMethod\", method)"));
    }

    // Each of the four error descriptors must be listed in the builtin descriptor.
    #[test]
    fn max_descriptor_errors_have_stable_codes() {
        assert!(MAX_DESCRIPTOR
            .errors
            .iter()
            .any(|error| error.code == MAX_ERROR_INVALID_ARGUMENT.code));
        assert!(MAX_DESCRIPTOR
            .errors
            .iter()
            .any(|error| error.code == MAX_ERROR_INVALID_INPUT.code));
        assert!(MAX_DESCRIPTOR
            .errors
            .iter()
            .any(|error| error.code == MAX_ERROR_SIZE_MISMATCH.code));
        assert!(MAX_DESCRIPTOR
            .errors
            .iter()
            .any(|error| error.code == MAX_ERROR_INTERNAL.code));
    }

    // Blocking wrapper around the async `super::evaluate`, which exposes both
    // the values and the indices of the result.
    fn evaluate(value: Value, rest: &[Value]) -> BuiltinResult<MaxEvaluation> {
        block_on(super::evaluate(value, rest))
    }

    // Empty 0x0 tensor; the tests below pass it as the second argument when
    // exercising reduction options (the `[]` slot in `max(A, [], ...)`).
    fn placeholder() -> Value {
        let tensor = Tensor::new(Vec::<f64>::new(), vec![0, 0]).unwrap();
        Value::Tensor(tensor)
    }

    // A lone scalar is returned unchanged.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn max_scalar_returns_input() {
        let result = max_builtin(Value::Num(5.0), Vec::new()).expect("max");
        assert_eq!(result, Value::Num(5.0));
    }

    // Reducing a 3x1 vector yields the maximum and its one-based position.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn max_vector_with_indices() {
        let tensor = Tensor::new(vec![3.0, 1.0, 5.0], vec![3, 1]).unwrap();
        let eval = evaluate(Value::Tensor(tensor), &[]).expect("evaluate");
        let (values, indices) = eval.into_pair();
        assert_eq!(values, Value::Num(5.0));
        assert_eq!(indices, Value::Num(3.0));
    }

    // With no dimension given, a 2x3 matrix reduces to a 1x3 row; the expected
    // values treat each consecutive pair of data entries as one column.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn max_matrix_default_dimension() {
        let tensor = Tensor::new(vec![3.0, 4.0, 1.0, 2.0, 5.0, 6.0], vec![2, 3]).unwrap();
        let eval = evaluate(Value::Tensor(tensor), &[]).expect("evaluate");
        let (values, indices) = eval.into_pair();
        match values {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![1, 3]);
                assert_eq!(t.data, vec![4.0, 2.0, 6.0]);
            }
            other => panic!("expected tensor, got {other:?}"),
        }
        match indices {
            Value::Tensor(t) => {
                assert_eq!(t.data, vec![2.0, 2.0, 2.0]);
            }
            other => panic!("expected tensor, got {other:?}"),
        }
    }

    // The "all" keyword reduces every element and reports a linear index; the
    // "linear" keyword on a 1x2 row likewise reports position 2 for the maximum.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn max_all_linear_index() {
        let tensor =
            Tensor::new((1..=12).map(|v| v as f64).collect::<Vec<_>>(), vec![3, 4]).unwrap();
        let args = vec![placeholder(), Value::from("all")];
        let eval = evaluate(Value::Tensor(tensor), &args).expect("evaluate");
        let (values, indices) = eval.into_pair();
        assert_eq!(values, Value::Num(12.0));
        assert_eq!(indices, Value::Num(12.0));

        let args_linear = vec![placeholder(), Value::from("linear")];
        let eval = evaluate(
            Value::Tensor(Tensor::new(vec![2.0, 3.0], vec![1, 2]).unwrap()),
            &args_linear,
        )
        .expect("evaluate");
        let (values, indices) = eval.into_pair();
        assert_eq!(values, Value::Num(3.0));
        assert_eq!(indices, Value::Num(2.0));
    }

    // "omitnan" skips the leading NaN and reports the position of 4.0.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn max_with_omitnan() {
        let tensor = Tensor::new(vec![f64::NAN, 4.0, 2.0], vec![3, 1]).unwrap();
        let args = vec![placeholder(), Value::from("omitnan")];
        let eval = evaluate(Value::Tensor(tensor), &args).expect("evaluate");
        let (values, indices) = eval.into_pair();
        assert_eq!(values, Value::Num(4.0));
        assert_eq!(indices, Value::Num(2.0));
    }

    // When every element is NaN, "omitnan" yields NaN for both value and index.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn max_omitnan_all_nan_slice() {
        let tensor = Tensor::new(vec![f64::NAN, f64::NAN], vec![2, 1]).unwrap();
        let args = vec![placeholder(), Value::from("omitnan")];
        let eval = evaluate(Value::Tensor(tensor), &args).expect("evaluate");
        let (values, indices) = eval.into_pair();
        match values {
            Value::Num(v) => assert!(v.is_nan()),
            other => panic!("expected scalar NaN, got {other:?}"),
        }
        match indices {
            Value::Num(v) => assert!(v.is_nan()),
            other => panic!("expected scalar NaN index, got {other:?}"),
        }
    }

    // ComparisonMethod "abs" selects by magnitude but returns the signed value
    // (-3.0 wins over 1.0 in the first column).
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn max_reduction_abs_comparison() {
        let tensor = Tensor::new(vec![1.0, -3.0, -2.0, 4.0], vec![2, 2]).unwrap();
        let args = vec![
            placeholder(),
            Value::from("ComparisonMethod"),
            Value::from("abs"),
        ];
        let eval = evaluate(Value::Tensor(tensor), &args).expect("evaluate");
        let (values, indices) = eval.into_pair();
        match values {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![1, 2]);
                assert_eq!(t.data, vec![-3.0, 4.0]);
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
        match indices {
            Value::Tensor(t) => {
                assert_eq!(t.data, vec![2.0, 2.0]);
            }
            other => panic!("expected tensor indices, got {other:?}"),
        }
    }

    // ComparisonMethod "real" on complex input: the entry with real part 1.0
    // is expected to win over the one with real part 0.5.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn max_reduction_complex_real_comparison() {
        let tensor = ComplexTensor::new(vec![(1.0, 2.0), (0.5, 5.0)], vec![2, 1]).expect("tensor");
        let args = vec![
            placeholder(),
            Value::from("ComparisonMethod"),
            Value::from("real"),
        ];
        let eval = evaluate(Value::ComplexTensor(tensor), &args).expect("evaluate");
        let (values, indices) = eval.into_pair();
        match values {
            Value::Complex(re, im) => {
                assert!((re - 1.0).abs() < 1e-12);
                assert!((im - 2.0).abs() < 1e-12);
            }
            other => panic!("expected complex scalar, got {other:?}"),
        }
        assert_eq!(indices, Value::Num(1.0));
    }

    // A 1x3 row against a 3x1 column broadcasts to 3x3; the origin output is
    // 1.0 where the left operand supplied the value and 2.0 for the right.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn max_elementwise_broadcast() {
        let lhs = Tensor::new(vec![1.0, 4.0, 7.0], vec![1, 3]).unwrap();
        let rhs = Tensor::new(vec![2.0, 3.0, 5.0], vec![3, 1]).unwrap();
        let eval = evaluate(Value::Tensor(lhs), &[Value::Tensor(rhs)]).expect("evaluate");
        let (values, indices) = eval.into_pair();
        match values {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![3, 3]);
                assert_eq!([t.data[0], t.data[3], t.data[6]], [2.0, 4.0, 7.0]);
                assert_eq!([t.data[1], t.data[4], t.data[7]], [3.0, 4.0, 7.0]);
                assert_eq!([t.data[2], t.data[5], t.data[8]], [5.0, 5.0, 7.0]);
            }
            other => panic!("expected tensor, got {other:?}"),
        }
        match indices {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![3, 3]);
                assert_eq!([t.data[0], t.data[3], t.data[6]], [2.0, 1.0, 1.0]);
                assert_eq!([t.data[1], t.data[4], t.data[7]], [2.0, 1.0, 1.0]);
                assert_eq!([t.data[2], t.data[5], t.data[8]], [2.0, 2.0, 1.0]);
            }
            other => panic!("expected tensor, got {other:?}"),
        }
    }

    // Element-wise "abs" comparison picks the larger magnitude per element and
    // keeps its sign.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn max_elementwise_abs_comparison() {
        let lhs = Tensor::new(vec![-2.0, 1.0], vec![2, 1]).unwrap();
        let rhs = Tensor::new(vec![1.5, -3.0], vec![2, 1]).unwrap();
        let args = vec![
            Value::Tensor(rhs),
            Value::from("ComparisonMethod"),
            Value::from("abs"),
        ];
        let eval = evaluate(Value::Tensor(lhs), &args).expect("evaluate");
        let (values, indices) = eval.into_pair();
        match values {
            Value::Tensor(t) => {
                assert_eq!(t.data, vec![-2.0, -3.0]);
            }
            other => panic!("expected tensor, got {other:?}"),
        }
        match indices {
            Value::Tensor(t) => {
                assert_eq!(t.data, vec![1.0, 2.0]);
            }
            other => panic!("expected tensor, got {other:?}"),
        }
    }

    // "omitnan" combined with a non-empty second operand must be rejected as
    // an invalid argument.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn max_elementwise_rejects_reduction_only_keywords() {
        let lhs = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
        let rhs = Tensor::new(vec![3.0, 4.0], vec![2, 1]).unwrap();
        let err = evaluate(
            Value::Tensor(lhs),
            &[Value::Tensor(rhs), Value::from("omitnan")],
        )
        .expect_err("expected error");
        assert_eq!(err.identifier(), MAX_ERROR_INVALID_ARGUMENT.identifier);
        assert!(err.message().contains("only supported for reduction"));
    }

    // Element-wise complex maximum with "real" comparison: the left operand
    // (real part 1.0) is expected to win over the right (real part 0.5).
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn max_complex_real_comparison() {
        let lhs = ComplexTensor::new(vec![(1.0, 2.0)], vec![1, 1]).unwrap();
        let rhs = ComplexTensor::new(vec![(0.5, 5.0)], vec![1, 1]).unwrap();
        let args = vec![
            Value::ComplexTensor(rhs),
            Value::from("ComparisonMethod"),
            Value::from("real"),
        ];
        let eval = evaluate(Value::ComplexTensor(lhs), &args).expect("evaluate");
        let (values, indices) = eval.into_pair();
        assert_eq!(values, Value::Complex(1.0, 2.0));
        assert_eq!(indices, Value::Num(1.0));
    }

    // A tensor of dimensions ([1, 2]) reduces a 2x2 matrix to one scalar.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn max_dimension_argument_parsing() {
        let tensor = Tensor::new(vec![3.0, 4.0, 1.0, 2.0], vec![2, 2]).unwrap();
        let dims = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
        let args = vec![placeholder(), Value::Tensor(dims)];
        let eval = evaluate(Value::Tensor(tensor), &args).expect("evaluate");
        let (values, indices) = eval.into_pair();
        assert_eq!(values, Value::Num(4.0));
        assert_eq!(indices, Value::Num(2.0));
    }

    // A repeated entry in the dimension vector ([1, 1, 2]) is accepted and
    // still reduces the 2x2 matrix to one scalar.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn max_vecdim_duplicate_entries() {
        let tensor = Tensor::new(vec![5.0, 2.0, 7.0, 1.0], vec![2, 2]).unwrap();
        let dims = Tensor::new(vec![1.0, 1.0, 2.0], vec![3, 1]).unwrap();
        let args = vec![placeholder(), Value::Tensor(dims)];
        let eval = evaluate(Value::Tensor(tensor), &args).expect("evaluate");
        let (values, indices) = eval.into_pair();
        assert_eq!(values, Value::Num(7.0));
        assert_eq!(indices, Value::Num(3.0));
    }

    // A GPU handle in the dimension slot must be rejected as an invalid
    // argument; the handle here is hand-built rather than uploaded.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn max_dimension_gpu_argument_errors() {
        let tensor = Tensor::new(vec![3.0, 1.0], vec![2, 1]).unwrap();
        let dim_handle = Value::GpuTensor(runmat_accelerate_api::GpuTensorHandle {
            shape: vec![1, 1],
            device_id: 0,
            buffer_id: 42,
        });
        let err = evaluate(Value::Tensor(tensor), &[placeholder(), dim_handle])
            .expect_err("expected error");
        assert_eq!(err.identifier(), MAX_ERROR_INVALID_ARGUMENT.identifier);
        assert!(err
            .message()
            .contains("dimension arguments must reside on the host"));
    }

    // An unknown ComparisonMethod name must be rejected as an invalid argument.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn max_invalid_comparison_method_errors() {
        let tensor = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
        let args = vec![
            placeholder(),
            Value::from("ComparisonMethod"),
            Value::from("chebyshev"),
        ];
        let err = evaluate(Value::Tensor(tensor), &args).expect_err("expected error");
        assert_eq!(err.identifier(), MAX_ERROR_INVALID_ARGUMENT.identifier);
        assert!(err.message().contains("unsupported ComparisonMethod"));
    }

    // Only built with the `wgpu` feature: the default reduction of an uploaded
    // tensor must stay on the device for both outputs and, once gathered,
    // match the host result in shape and data.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn max_gpu_dim1_matches_cpu() {
        let tensor = Tensor::new(vec![3.0, 1.0, 2.0, 4.0], vec![2, 2]).unwrap();
        let eval_cpu = evaluate(Value::Tensor(tensor.clone()), &[]).expect("cpu");
        let (values_cpu, indices_cpu) = eval_cpu.into_pair();

        test_support::with_test_provider(|provider| {
            let view = HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let eval_gpu = evaluate(Value::GpuTensor(handle), &[]).expect("gpu");
            let (values_gpu, indices_gpu) = eval_gpu.into_pair();
            match (&values_gpu, &indices_gpu) {
                (Value::GpuTensor(_), Value::GpuTensor(_)) => {}
                other => panic!("expected GPU tensors, got {other:?}"),
            }
            let gathered_vals = test_support::gather(values_gpu).expect("gather values");
            let gathered_idx = test_support::gather(indices_gpu).expect("gather indices");
            let expected_vals = match values_cpu {
                Value::Tensor(t) => t,
                other => panic!("expected tensor values from cpu eval, got {other:?}"),
            };
            let expected_idx = match indices_cpu {
                Value::Tensor(t) => t,
                other => panic!("expected tensor indices from cpu eval, got {other:?}"),
            };
            assert_eq!(gathered_vals.shape, expected_vals.shape);
            assert_eq!(gathered_vals.data, expected_vals.data);
            assert_eq!(gathered_idx.shape, expected_idx.shape);
            assert_eq!(gathered_idx.data, expected_idx.data);
        });
    }

    // A numeric dimension of 2 reduces a 2x2 matrix to a 2x1 column.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn max_dimension_numeric_argument() {
        let tensor = Tensor::new(vec![3.0, 4.0, 1.0, 2.0], vec![2, 2]).unwrap();
        let args = vec![placeholder(), Value::Num(2.0)];
        let eval = evaluate(Value::Tensor(tensor), &args).expect("evaluate");
        let (values, indices) = eval.into_pair();
        match values {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![2, 1]);
                assert_eq!(t.data, vec![3.0, 4.0]);
            }
            other => panic!("expected tensor, got {other:?}"),
        }
        match indices {
            Value::Tensor(t) => {
                assert_eq!(t.data, vec![1.0, 1.0]);
            }
            other => panic!("expected tensor, got {other:?}"),
        }
    }

    // Default comparison on two complex scalars: the test expects the left
    // operand (1 + 2i) to be selected over (2 + 1i).
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn max_complex_auto_comparison() {
        let lhs = ComplexTensor::new(vec![(1.0, 2.0)], vec![1, 1]).unwrap();
        let rhs = ComplexTensor::new(vec![(2.0, 1.0)], vec![1, 1]).unwrap();
        let eval =
            evaluate(Value::ComplexTensor(lhs), &[Value::ComplexTensor(rhs)]).expect("evaluate");
        let (values, indices) = eval.into_pair();
        assert_eq!(values, Value::Complex(1.0, 2.0));
        assert_eq!(indices, Value::Num(1.0));
    }

    // Two real scalars: the builtin returns the larger one.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn max_scalar_pair_arguments() {
        let args = vec![Value::Num(2.0)];
        let result = max_builtin(Value::Num(3.0), args).expect("max");
        assert_eq!(result, Value::Num(3.0));
    }

    // A dimension of 0 must be rejected as an invalid argument.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn max_rejects_invalid_dimension() {
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap();
        let args = vec![placeholder(), Value::Int(IntValue::I32(0))];
        let err = evaluate(Value::Tensor(tensor), &args).expect_err("expected error");
        assert_eq!(err.identifier(), MAX_ERROR_INVALID_ARGUMENT.identifier);
        assert!(err.message().contains("dimension must be >= 1"));
    }
}