//! MATLAB-compatible `max` builtin with GPU-aware semantics for RunMat.
//!
//! Source path: `runmat_runtime/builtins/math/reduction/max.rs`.
2
3use std::cmp::Ordering;
4use std::collections::BTreeSet;
5
6use runmat_accelerate_api::{AccelProvider, GpuTensorHandle, ReduceDimResult};
7use runmat_builtins::{ComplexTensor, ResolveContext, Tensor, Type, Value};
8use runmat_macros::runtime_builtin;
9
10use crate::{build_runtime_error, BuiltinResult, RuntimeError};
11
12const NAME: &str = "max";
13
/// Type resolver registered for the `max` builtin.
///
/// Forwards `args` and `ctx` unchanged to the shared `min_max_type` resolver
/// and returns whatever type it reports.
fn max_type(args: &[Type], ctx: &ResolveContext) -> Type {
    min_max_type(args, ctx)
}
17
/// Builds a [`RuntimeError`] from `message`, tagged with this builtin's name (`NAME`).
fn max_error(message: impl Into<String>) -> RuntimeError {
    build_runtime_error(message).with_builtin(NAME).build()
}
21
22use crate::builtins::common::arg_tokens::tokens_from_values;
23use crate::builtins::common::broadcast::BroadcastPlan;
24use crate::builtins::common::random_args::{complex_tensor_into_value, keyword_of};
25use crate::builtins::common::spec::{
26    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, FusionError,
27    FusionExprContext, FusionKernelTemplate, GpuOpKind, ProviderHook, ReductionNaN,
28    ResidencyPolicy, ScalarType, ShapeRequirements,
29};
30use crate::builtins::common::{
31    gpu_helpers,
32    shape::{is_scalar_shape, normalize_scalar_shape},
33    tensor,
34};
35use crate::builtins::math::reduction::type_resolvers::min_max_type;
36
// GPU capability descriptor for `max`, registered via `register_gpu_spec`.
// It declares a reduction op over F32/F64 with two provider hooks
// (`reduce_max_dim` and `reduce_max`). `nan_mode` is `Include` and
// `accepts_nan_mode` is false, in line with the note below that omitnan
// requests fall back to the host implementation.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::math::reduction::max")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "max",
    op_kind: GpuOpKind::Reduction,
    supported_precisions: &[ScalarType::F32, ScalarType::F64],
    broadcast: BroadcastSemantics::Matlab,
    provider_hooks: &[
        ProviderHook::Reduction {
            name: "reduce_max_dim",
        },
        ProviderHook::Reduction {
            name: "reduce_max",
        },
    ],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::NewHandle,
    nan_mode: ReductionNaN::Include,
    // NOTE(review): the units/meaning of these two thresholds are defined by
    // `BuiltinGpuSpec`, which is not visible here — confirm before tuning.
    two_pass_threshold: Some(256),
    workgroup_size: Some(256),
    accepts_nan_mode: false,
    notes:
        "Providers should implement reduce_max_dim / reduce_max. Requests that require omitnan, comparisonmethod overrides, or complex inputs fall back to the host implementation.",
};
60
// Fusion descriptor for `max`, registered via `register_fusion_spec`.
// Only a reduction template is provided (`elementwise` is `None`).
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::math::reduction::max")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "max",
    shape: ShapeRequirements::BroadcastCompatible,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: Some(FusionKernelTemplate {
        scalar_precisions: &[ScalarType::F32, ScalarType::F64],
        // Emits one WGSL statement folding the first input into `accumulator`;
        // fails with `FusionError::MissingInput(0)` when no input is supplied.
        wgsl_body: |ctx: &FusionExprContext| {
            let input = ctx.inputs.first().ok_or(FusionError::MissingInput(0))?;
            Ok(format!("accumulator = max(accumulator, {input});"))
        },
    }),
    emits_nan: true,
    notes: "Fusion planner emits canonical reduction kernels; providers may substitute custom WGSL via reduce_max_dim hooks.",
};
77
/// Evaluation artifact returned by `max` that carries both values and indices.
///
/// Produced by [`evaluate`]; callers pick the outputs they need through the
/// accessor methods.
#[derive(Debug, Clone)]
pub struct MaxEvaluation {
    /// The maxima (first output of `max`).
    values: Value,
    /// The positions of the maxima (second output of `max`). The host reduction
    /// paths store these as 1-based numbers.
    indices: Value,
}
84
85impl MaxEvaluation {
86    /// Consume the evaluation and return only the maximum values (single-output call).
87    pub fn into_value(self) -> Value {
88        self.values
89    }
90
91    /// Consume the evaluation and return both maxima and indices.
92    pub fn into_pair(self) -> (Value, Value) {
93        (self.values, self.indices)
94    }
95
96    /// Peek at the indices without consuming.
97    pub fn indices_value(&self) -> Value {
98        self.indices.clone()
99    }
100}
101
102#[runtime_builtin(
103    name = "max",
104    category = "math/reduction",
105    summary = "Return the maximum elements of scalars, vectors, matrices, or N-D tensors.",
106    keywords = "max,maximum,reduction,gpu,comparisonmethod,omitnan",
107    accel = "reduction",
108    type_resolver(max_type),
109    builtin_path = "crate::builtins::math::reduction::max"
110)]
111async fn max_builtin(value: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
112    evaluate(value, &rest).await.map(|eval| eval.into_value())
113}
114
115/// Evaluate the builtin once and expose both outputs (value + indices).
116pub async fn evaluate(value: Value, rest: &[Value]) -> BuiltinResult<MaxEvaluation> {
117    let parsed = parse_call(rest).await?;
118    if std::env::var("RUNMAT_DEBUG_MAX").is_ok() {
119        let call_label = match &parsed {
120            ParsedCall::Reduction(_) => "reduction",
121            ParsedCall::Elementwise(_) => "elementwise",
122        };
123        let first_arg = rest.first().map(debug_value_kind).unwrap_or("None");
124        tracing::debug!(
125            call_type = call_label,
126            rest_len = rest.len(),
127            first_arg = first_arg,
128            "[runmat-debug-max]"
129        );
130    }
131    match parsed {
132        ParsedCall::Elementwise(args) => elementwise_max(value, args).await,
133        ParsedCall::Reduction(args) => reduction_max(value, args).await,
134    }
135}
136
/// The two call shapes `max` can take, as decided by `parse_call`.
#[derive(Debug, Clone)]
enum ParsedCall {
    /// `max(A)` or `max(A, [], ...)`: reduce a single input.
    Reduction(ReductionArgs),
    /// `max(A, B, ...)`: compare two operands.
    Elementwise(ElementwiseArgs),
}
142
/// Options collected for a reduction-style `max` call.
#[derive(Debug, Clone)]
struct ReductionArgs {
    /// Which dimension(s) the reduction runs over.
    selection: DimSelection,
    /// Whether NaN values are included (`'includenan'`) or skipped (`'omitnan'`).
    nan_mode: ReductionNaN,
    /// Comparison rule requested through `'ComparisonMethod'`.
    comparison: ComparisonMethod,
    /// Set by the `'linear'` keyword; the host reductions then report each index
    /// as a 1-based linear position in the whole input.
    linear_index: bool,
}
150
impl Default for ReductionArgs {
    /// Options for a bare `max(A)` call: automatic dimension choice, NaNs
    /// included, automatic comparison method, per-dimension indices.
    fn default() -> Self {
        Self {
            selection: DimSelection::Auto,
            nan_mode: ReductionNaN::Include,
            comparison: ComparisonMethod::Auto,
            linear_index: false,
        }
    }
}
161
/// How the caller selected the dimension(s) to reduce.
#[derive(Debug, Clone)]
enum DimSelection {
    /// No explicit dimension; the host path picks the first dimension whose
    /// length exceeds 1, falling back to the first dimension.
    Auto,
    /// A single 1-based dimension.
    Dim(usize),
    /// A list of 1-based dimensions (duplicates already removed by the parser).
    Vec(Vec<usize>),
    /// Reduce over every dimension (`'all'`).
    All,
}
169
/// Value accepted after the `'ComparisonMethod'` keyword.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ComparisonMethod {
    /// `'auto'`; also the default when the keyword is absent.
    Auto,
    /// `'real'`.
    Real,
    /// `'abs'` (the parser also accepts `'magnitude'`).
    Abs,
}
176
/// Options collected for an elementwise `max(A, B, ...)` call.
#[derive(Debug, Clone)]
struct ElementwiseArgs {
    /// The second operand, cloned from the first trailing argument.
    other: Value,
    /// Comparison rule requested through `'ComparisonMethod'`.
    comparison: ComparisonMethod,
}
182
183async fn parse_call(rest: &[Value]) -> BuiltinResult<ParsedCall> {
184    if rest.is_empty() {
185        return Ok(ParsedCall::Reduction(ReductionArgs::default()));
186    }
187
188    let first = &rest[0];
189    if !is_empty_placeholder(first) {
190        let comparison = parse_elementwise_options(&rest[1..])?;
191        return Ok(ParsedCall::Elementwise(ElementwiseArgs {
192            other: first.clone(),
193            comparison,
194        }));
195    }
196
197    let mut args = ReductionArgs::default();
198    parse_reduction_options(&mut args, &rest[1..]).await?;
199    Ok(ParsedCall::Reduction(args))
200}
201
202fn debug_value_kind(value: &Value) -> &'static str {
203    match value {
204        Value::Num(_) => "Num",
205        Value::Int(_) => "Int",
206        Value::Bool(_) => "Bool",
207        Value::Tensor(t) => {
208            if t.data.is_empty() {
209                "Tensor(empty)"
210            } else {
211                "Tensor"
212            }
213        }
214        Value::GpuTensor(_) => "GpuTensor",
215        Value::String(_) => "String",
216        Value::CharArray(_) => "CharArray",
217        Value::StringArray(sa) => {
218            if sa.data.is_empty() {
219                "StringArray(empty)"
220            } else {
221                "StringArray"
222            }
223        }
224        Value::LogicalArray(l) => {
225            if l.data.is_empty() {
226                "LogicalArray(empty)"
227            } else {
228                "LogicalArray"
229            }
230        }
231        Value::Cell(c) => {
232            if c.data.is_empty() {
233                "Cell(empty)"
234            } else {
235                "Cell"
236            }
237        }
238        _ => "Other",
239    }
240}
241
242fn is_empty_placeholder(value: &Value) -> bool {
243    match value {
244        Value::Tensor(t) => t.data.is_empty(),
245        Value::LogicalArray(l) => l.data.is_empty(),
246        Value::StringArray(sa) => sa.data.is_empty(),
247        Value::CharArray(ca) => ca.data.is_empty(),
248        Value::Cell(cell) => cell.data.is_empty(),
249        Value::String(s) => s.is_empty(),
250        _ => false,
251    }
252}
253
254async fn parse_reduction_options(args: &mut ReductionArgs, rest: &[Value]) -> BuiltinResult<()> {
255    let mut idx = 0usize;
256    let mut selection_set = !matches!(args.selection, DimSelection::Auto);
257    let mut comparison_set = matches!(args.comparison, ComparisonMethod::Auto);
258    let tokens = tokens_from_values(rest);
259    while idx < rest.len() {
260        if let Some(crate::builtins::common::arg_tokens::ArgToken::String(text)) = tokens.get(idx) {
261            match text.as_str() {
262                "omitnan" => {
263                    args.nan_mode = ReductionNaN::Omit;
264                    idx += 1;
265                    continue;
266                }
267                "includenan" => {
268                    args.nan_mode = ReductionNaN::Include;
269                    idx += 1;
270                    continue;
271                }
272                "all" => {
273                    if selection_set {
274                        return Err(max_error(
275                            "max: 'all' cannot be combined with an explicit dimension",
276                        ));
277                    }
278                    args.selection = DimSelection::All;
279                    selection_set = true;
280                    idx += 1;
281                    continue;
282                }
283                _ => {}
284            }
285        }
286        if let Some(keyword) = keyword_of(&rest[idx]) {
287            match keyword.as_str() {
288                "omitnan" => {
289                    args.nan_mode = ReductionNaN::Omit;
290                    idx += 1;
291                    continue;
292                }
293                "includenan" => {
294                    args.nan_mode = ReductionNaN::Include;
295                    idx += 1;
296                    continue;
297                }
298                "all" => {
299                    if selection_set {
300                        return Err(max_error(
301                            "max: 'all' cannot be combined with an explicit dimension",
302                        ));
303                    }
304                    args.selection = DimSelection::All;
305                    selection_set = true;
306                    idx += 1;
307                    continue;
308                }
309                "linear" => {
310                    if selection_set {
311                        return Err(max_error(
312                            "max: 'linear' cannot be combined with an explicit dimension",
313                        ));
314                    }
315                    args.selection = DimSelection::All;
316                    args.linear_index = true;
317                    selection_set = true;
318                    idx += 1;
319                    continue;
320                }
321                "comparisonmethod" => {
322                    let Some(value) = rest.get(idx + 1) else {
323                        return Err(max_error("max: expected a value after 'ComparisonMethod'"));
324                    };
325                    args.comparison = parse_comparison_method(value)?;
326                    comparison_set = true;
327                    idx += 2;
328                    continue;
329                }
330                _ => {}
331            }
332        }
333
334        if !selection_set {
335            if let Some(selection) = parse_dimension_value(&rest[idx]).await? {
336                args.selection = selection;
337                selection_set = true;
338                idx += 1;
339                continue;
340            }
341        }
342
343        return Err(max_error(format!(
344            "max: unrecognised argument {:?}",
345            rest[idx]
346        )));
347    }
348
349    if !comparison_set {
350        args.comparison = ComparisonMethod::Auto;
351    }
352
353    Ok(())
354}
355
356fn parse_elementwise_options(rest: &[Value]) -> BuiltinResult<ComparisonMethod> {
357    let mut comparison = ComparisonMethod::Auto;
358    let mut comparison_set = false;
359    let mut idx = 0usize;
360    while idx < rest.len() {
361        if let Some(keyword) = keyword_of(&rest[idx]) {
362            match keyword.as_str() {
363                "comparisonmethod" => {
364                    let Some(value) = rest.get(idx + 1) else {
365                        return Err(max_error("max: expected a value after 'ComparisonMethod'"));
366                    };
367                    comparison = parse_comparison_method(value)?;
368                    comparison_set = true;
369                    idx += 2;
370                    continue;
371                }
372                "omitnan" | "includenan" | "all" | "linear" => {
373                    return Err(max_error(format!(
374                        "max: '{}' is only supported for reduction calls",
375                        keyword
376                    )));
377                }
378                _ => {}
379            }
380        }
381        return Err(max_error(format!(
382            "max: unrecognised argument {:?}",
383            rest[idx]
384        )));
385    }
386    if !comparison_set {
387        comparison = ComparisonMethod::Auto;
388    }
389    Ok(comparison)
390}
391
392fn parse_comparison_method(value: &Value) -> BuiltinResult<ComparisonMethod> {
393    let Some(keyword) = keyword_of(value) else {
394        return Err(max_error("max: 'ComparisonMethod' expects a string value"));
395    };
396    match keyword.as_str() {
397        "auto" => Ok(ComparisonMethod::Auto),
398        "abs" | "magnitude" => Ok(ComparisonMethod::Abs),
399        "real" => Ok(ComparisonMethod::Real),
400        other => Err(max_error(format!(
401            "max: unsupported ComparisonMethod '{other}'"
402        ))),
403    }
404}
405
406async fn parse_dimension_value(value: &Value) -> BuiltinResult<Option<DimSelection>> {
407    match value {
408        Value::Int(_) | Value::Num(_) => tensor::dimension_from_value_async(value, "max", false)
409            .await
410            .map_err(map_scalar_dim_error)
411            .map(|dim| dim.map(DimSelection::Dim)),
412        Value::Tensor(t) => parse_dimension_tensor(value, &t.shape).await,
413        Value::LogicalArray(logical) => parse_dimension_tensor(value, &logical.shape).await,
414        Value::GpuTensor(_) => Err(max_error(
415            "max: dimension arguments must reside on the host",
416        )),
417        _ => Ok(None),
418    }
419}
420
421async fn parse_dimension_tensor(
422    value: &Value,
423    shape: &[usize],
424) -> BuiltinResult<Option<DimSelection>> {
425    if tensor::element_count(shape) == 0 {
426        return Ok(Some(DimSelection::Auto));
427    }
428    let is_vector = shape.len() == 1
429        || shape.get(0).copied().unwrap_or(1) == 1
430        || shape.get(1).copied().unwrap_or(1) == 1;
431    if !is_vector {
432        return Err(max_error(
433            "max: dimension vector must be a row or column vector",
434        ));
435    }
436    let dims = tensor::dims_from_value_async(value)
437        .await
438        .map_err(map_vector_dim_error)?;
439    let Some(dims) = dims else {
440        return Ok(None);
441    };
442    if dims.is_empty() {
443        return Ok(Some(DimSelection::Auto));
444    }
445    let mut seen = BTreeSet::new();
446    let mut uniq = Vec::with_capacity(dims.len());
447    for dim in dims {
448        if dim < 1 {
449            return Err(max_error("max: dimension indices must be >= 1"));
450        }
451        if seen.insert(dim) {
452            uniq.push(dim);
453        }
454    }
455    Ok(Some(DimSelection::Vec(uniq)))
456}
457
458fn map_scalar_dim_error(message: String) -> RuntimeError {
459    if message.contains("integer") {
460        return max_error("max: dimension must be integral");
461    }
462    max_error(message)
463}
464
465fn map_vector_dim_error(message: String) -> RuntimeError {
466    if message.contains("non-negative") {
467        return max_error("max: dimension indices must be >= 1");
468    }
469    if message.contains("finite") {
470        return max_error("max: dimension entries must be finite");
471    }
472    if message.contains("integer") {
473        return max_error("max: dimension entries must be integers");
474    }
475    max_error(message)
476}
477
478async fn reduction_max(value: Value, args: ReductionArgs) -> BuiltinResult<MaxEvaluation> {
479    match value {
480        Value::GpuTensor(handle) => {
481            if let Some(eval) = reduction_max_gpu(handle.clone(), &args).await? {
482                return Ok(eval);
483            }
484            // Fall back to host if GPU path is unavailable.
485            let tensor = gpu_helpers::gather_tensor_async(&handle).await?;
486            reduction_max_host(Value::Tensor(tensor), &args)
487        }
488        other => reduction_max_host(other, &args),
489    }
490}
491
492async fn reduction_max_gpu(
493    handle: GpuTensorHandle,
494    args: &ReductionArgs,
495) -> BuiltinResult<Option<MaxEvaluation>> {
496    #[cfg(all(test, feature = "wgpu"))]
497    {
498        if handle.device_id != 0 {
499            let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
500                runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
501            );
502        }
503    }
504    if args.nan_mode == ReductionNaN::Omit {
505        log::trace!("max: gpu path disabled (nan_mode=omit)");
506        return Ok(None);
507    }
508    if args.comparison != ComparisonMethod::Auto {
509        log::trace!("max: gpu path disabled (comparison != auto)");
510        return Ok(None);
511    }
512    if args.linear_index {
513        log::trace!("max: gpu path disabled (linear_index=true)");
514        return Ok(None);
515    }
516    let provider = match runmat_accelerate_api::provider() {
517        Some(p) => p,
518        None => {
519            log::trace!(
520                "max: gpu path unavailable (provider() is None) handle_shape={:?} device_id={}",
521                handle.shape,
522                handle.device_id
523            );
524            return Ok(None);
525        }
526    };
527    let target_dim = match args.selection {
528        DimSelection::Auto => default_dimension_from_shape(&handle.shape),
529        DimSelection::Dim(dim) => dim,
530        DimSelection::Vec(ref dims) if dims.len() == 1 => dims[0],
531        DimSelection::All => {
532            if handle.shape.len() <= 1 {
533                1
534            } else {
535                return Ok(None);
536            }
537        }
538        _ => return Ok(None),
539    };
540    if target_dim == 0 {
541        return Ok(None);
542    }
543    // MATLAB dimensions are 1-based; `reduce_max_dim` expects zero-based.
544    let zero_based = target_dim.saturating_sub(1);
545    if zero_based >= handle.shape.len() {
546        return Ok(None);
547    }
548    log::trace!(
549        "max: attempting reduce_max_dim dim={} (zero_based={}) shape={:?} device_id={}",
550        target_dim,
551        zero_based,
552        handle.shape,
553        handle.device_id
554    );
555    match provider.reduce_max_dim(&handle, zero_based).await {
556        Ok(ReduceDimResult { values, indices }) => Ok(Some(MaxEvaluation {
557            values: Value::GpuTensor(values),
558            indices: Value::GpuTensor(indices),
559        })),
560        Err(err) => {
561            log::trace!("max: reduce_max_dim failed: {err}");
562            Ok(None)
563        }
564    }
565}
566
/// Host implementation of a reduction-style `max`.
///
/// Converts `value` into real or complex tensor data and dispatches to the
/// matching reducer; conversion errors are returned unchanged.
fn reduction_max_host(value: Value, args: &ReductionArgs) -> BuiltinResult<MaxEvaluation> {
    match materialize_for_max("max", value)? {
        InputData::Real(tensor) => reduce_real_tensor(tensor, args),
        InputData::Complex(tensor) => reduce_complex_tensor(tensor, args),
    }
}
573
/// Host-side input of a reduction after conversion by `materialize_for_max`.
enum InputData {
    /// Real data (tensors, logical arrays, and numeric/boolean scalars).
    Real(Tensor),
    /// Complex data (complex tensors and complex scalars).
    Complex(ComplexTensor),
}
578
579fn materialize_for_max(name: &str, value: Value) -> BuiltinResult<InputData> {
580    match value {
581        Value::Tensor(t) => Ok(InputData::Real(t)),
582        Value::LogicalArray(logical) => {
583            let tensor = tensor::logical_to_tensor(&logical).map_err(|err| max_error(err))?;
584            Ok(InputData::Real(tensor))
585        }
586        Value::Num(n) => {
587            let tensor =
588                Tensor::new(vec![n], vec![1, 1]).map_err(|e| max_error(format!("{name}: {e}")))?;
589            Ok(InputData::Real(tensor))
590        }
591        Value::Int(i) => {
592            let tensor = Tensor::new(vec![i.to_f64()], vec![1, 1])
593                .map_err(|e| max_error(format!("{name}: {e}")))?;
594            Ok(InputData::Real(tensor))
595        }
596        Value::Bool(b) => {
597            let tensor = Tensor::new(vec![if b { 1.0 } else { 0.0 }], vec![1, 1])
598                .map_err(|e| max_error(format!("{name}: {e}")))?;
599            Ok(InputData::Real(tensor))
600        }
601        Value::Complex(re, im) => {
602            let tensor = ComplexTensor::new(vec![(re, im)], vec![1, 1])
603                .map_err(|e| max_error(format!("{name}: {e}")))?;
604            Ok(InputData::Complex(tensor))
605        }
606        Value::ComplexTensor(ct) => Ok(InputData::Complex(ct)),
607        Value::String(_) | Value::StringArray(_) | Value::CharArray(_) | Value::Cell(_) => {
608            Err(max_error(format!(
609                "{name}: expected numeric or logical input, received non-numeric value"
610            )))
611        }
612        Value::GpuTensor(_) => Err(max_error(format!(
613            "{name}: internal error – GPU tensors must be gathered before host execution"
614        ))),
615        Value::Object(_) | Value::HandleObject(_) | Value::Struct(_) | Value::Listener(_) => {
616            Err(max_error(format!("{name}: unsupported input type")))
617        }
618        Value::FunctionHandle(_)
619        | Value::Closure(_)
620        | Value::ClassRef(_)
621        | Value::MException(_)
622        | Value::OutputList(_) => Err(max_error(format!("{name}: unsupported input type"))),
623    }
624}
625
626fn reduce_real_tensor(tensor: Tensor, args: &ReductionArgs) -> BuiltinResult<MaxEvaluation> {
627    let shape = tensor.shape.clone();
628    if tensor.data.is_empty() {
629        let output_shape = resolve_output_shape(&shape, &args.selection, &[])?;
630        let values = Tensor::new(Vec::new(), output_shape.clone())
631            .map_err(|e| max_error(format!("max: {e}")))?;
632        let indices =
633            Tensor::new(Vec::new(), output_shape).map_err(|e| max_error(format!("max: {e}")))?;
634        return Ok(MaxEvaluation {
635            values: tensor::tensor_into_value(values),
636            indices: tensor::tensor_into_value(indices),
637        });
638    }
639    let resolved = resolve_reduction_dims(&shape, &args.selection)?;
640    let output_shape = resolved.output_shape.clone();
641    let output_len = tensor::element_count(&output_shape);
642
643    if output_len == 0 {
644        let values = Tensor::new(Vec::new(), output_shape.clone())
645            .map_err(|e| max_error(format!("max: {e}")))?;
646        let indices =
647            Tensor::new(Vec::new(), output_shape).map_err(|e| max_error(format!("max: {e}")))?;
648        return Ok(MaxEvaluation {
649            values: tensor::tensor_into_value(values),
650            indices: tensor::tensor_into_value(indices),
651        });
652    }
653
654    let strides = compute_strides(&shape);
655    let output_strides = compute_strides(&output_shape);
656    let dims_mask = resolved.dims_mask.clone();
657    let reduce_strides = resolved.reduce_strides.clone();
658
659    let mut best = vec![BestReal::new(); output_len];
660    let mut coords = vec![0usize; shape.len()];
661    for &value in &tensor.data {
662        let out_idx = map_output_index(&coords, &output_strides, &dims_mask);
663        let reduce_idx = map_reduce_index(
664            &coords,
665            &resolved.reduced_dims,
666            &reduce_strides,
667            resolved.reduce_all,
668        );
669        let full_idx = map_linear_index(&coords, &strides);
670
671        update_best_real(
672            &mut best[out_idx],
673            value,
674            reduce_idx,
675            full_idx,
676            args.nan_mode,
677            args.comparison,
678        );
679        increment_coords(&mut coords, &shape);
680    }
681
682    let mut values = vec![0.0f64; output_len];
683    let mut indices = vec![0.0f64; output_len];
684
685    for (i, entry) in best.iter().enumerate() {
686        if entry.nan_fixed {
687            values[i] = f64::NAN;
688            indices[i] = if args.linear_index || resolved.reduce_all {
689                (entry.full_index + 1) as f64
690            } else if resolved.reduced_dims.is_empty() {
691                1.0
692            } else {
693                (entry.reduce_index + 1) as f64
694            };
695            continue;
696        }
697        if !entry.has_value {
698            values[i] = f64::NAN;
699            indices[i] = f64::NAN;
700            continue;
701        }
702        values[i] = entry.value;
703        indices[i] = if args.linear_index || resolved.reduce_all {
704            (entry.full_index + 1) as f64
705        } else if resolved.reduced_dims.is_empty() {
706            1.0
707        } else {
708            (entry.reduce_index + 1) as f64
709        };
710    }
711
712    let value_tensor =
713        Tensor::new(values, output_shape.clone()).map_err(|e| max_error(format!("max: {e}")))?;
714    let index_tensor =
715        Tensor::new(indices, output_shape).map_err(|e| max_error(format!("max: {e}")))?;
716
717    Ok(MaxEvaluation {
718        values: tensor::tensor_into_value(value_tensor),
719        indices: tensor::tensor_into_value(index_tensor),
720    })
721}
722
723fn reduce_complex_tensor(
724    tensor: ComplexTensor,
725    args: &ReductionArgs,
726) -> BuiltinResult<MaxEvaluation> {
727    let shape = tensor.shape.clone();
728    if tensor.data.is_empty() {
729        let output_shape = resolve_output_shape(&shape, &args.selection, &[])?;
730        let values = ComplexTensor::new(Vec::new(), output_shape.clone())
731            .map_err(|e| max_error(format!("max: {e}")))?;
732        let indices =
733            Tensor::new(Vec::new(), output_shape).map_err(|e| max_error(format!("max: {e}")))?;
734        return Ok(MaxEvaluation {
735            values: complex_tensor_into_value(values),
736            indices: tensor::tensor_into_value(indices),
737        });
738    }
739
740    let resolved = resolve_reduction_dims(&shape, &args.selection)?;
741    let output_shape = resolved.output_shape.clone();
742    let output_len = tensor::element_count(&output_shape);
743
744    if output_len == 0 {
745        let values = ComplexTensor::new(Vec::new(), output_shape.clone())
746            .map_err(|e| max_error(format!("max: {e}")))?;
747        let indices =
748            Tensor::new(Vec::new(), output_shape).map_err(|e| max_error(format!("max: {e}")))?;
749        return Ok(MaxEvaluation {
750            values: complex_tensor_into_value(values),
751            indices: tensor::tensor_into_value(indices),
752        });
753    }
754
755    let strides = compute_strides(&shape);
756    let output_strides = compute_strides(&output_shape);
757    let dims_mask = resolved.dims_mask.clone();
758    let reduce_strides = resolved.reduce_strides.clone();
759
760    let mut best = vec![BestComplex::new(); output_len];
761    let mut coords = vec![0usize; shape.len()];
762
763    for &(re, im) in &tensor.data {
764        let out_idx = map_output_index(&coords, &output_strides, &dims_mask);
765        let reduce_idx = map_reduce_index(
766            &coords,
767            &resolved.reduced_dims,
768            &reduce_strides,
769            resolved.reduce_all,
770        );
771        let full_idx = map_linear_index(&coords, &strides);
772        update_best_complex(
773            &mut best[out_idx],
774            (re, im),
775            reduce_idx,
776            full_idx,
777            args.nan_mode,
778            args.comparison,
779        );
780        increment_coords(&mut coords, &shape);
781    }
782
783    let mut values = vec![(0.0f64, 0.0f64); output_len];
784    let mut indices = vec![0.0f64; output_len];
785
786    for (i, entry) in best.iter().enumerate() {
787        if entry.nan_fixed {
788            values[i] = (f64::NAN, f64::NAN);
789            indices[i] = if args.linear_index || resolved.reduce_all {
790                (entry.full_index + 1) as f64
791            } else if resolved.reduced_dims.is_empty() {
792                1.0
793            } else {
794                (entry.reduce_index + 1) as f64
795            };
796            continue;
797        }
798        if !entry.has_value {
799            values[i] = (f64::NAN, f64::NAN);
800            indices[i] = f64::NAN;
801            continue;
802        }
803        values[i] = entry.value;
804        indices[i] = if args.linear_index || resolved.reduce_all {
805            (entry.full_index + 1) as f64
806        } else if resolved.reduced_dims.is_empty() {
807            1.0
808        } else {
809            (entry.reduce_index + 1) as f64
810        };
811    }
812
813    let value_tensor = ComplexTensor::new(values, output_shape.clone())
814        .map_err(|e| max_error(format!("max: {e}")))?;
815    let index_tensor =
816        Tensor::new(indices, output_shape).map_err(|e| max_error(format!("max: {e}")))?;
817    Ok(MaxEvaluation {
818        values: complex_tensor_into_value(value_tensor),
819        indices: tensor::tensor_into_value(index_tensor),
820    })
821}
822
/// Running best candidate for one output slot of a real reduction.
///
/// NOTE(review): the fields are updated by `update_best_real`, which is not
/// visible here; the descriptions below reflect how the reducer reads them.
#[derive(Debug, Clone)]
struct BestReal {
    /// Current best value; meaningful only when `has_value` is true.
    value: f64,
    /// Zero-based position within the reduced dimension(s).
    reduce_index: usize,
    /// Zero-based linear position in the whole input.
    full_index: usize,
    /// False until a value has been recorded; such slots are reported as NaN
    /// with a NaN index.
    has_value: bool,
    /// When true the slot is reported as NaN, but still with a real index.
    nan_fixed: bool,
}

impl BestReal {
    /// Creates an empty slot: zeroed value and indices, both flags cleared.
    fn new() -> Self {
        Self {
            value: 0.0,
            reduce_index: 0,
            full_index: 0,
            has_value: false,
            nan_fixed: false,
        }
    }
}
843
/// Running best candidate for one output slot of a complex reduction.
///
/// NOTE(review): the fields are updated by `update_best_complex`, which is not
/// visible here; the descriptions below reflect how the reducer reads them.
#[derive(Debug, Clone)]
struct BestComplex {
    /// Current best value as `(re, im)`; meaningful only when `has_value` is true.
    value: (f64, f64),
    /// Zero-based position within the reduced dimension(s).
    reduce_index: usize,
    /// Zero-based linear position in the whole input.
    full_index: usize,
    /// False until a value has been recorded; such slots are reported as NaN
    /// with a NaN index.
    has_value: bool,
    /// When true the slot is reported as NaN, but still with a real index.
    nan_fixed: bool,
}

impl BestComplex {
    /// Creates an empty slot: zeroed value and indices, both flags cleared.
    fn new() -> Self {
        Self {
            value: (0.0, 0.0),
            reduce_index: 0,
            full_index: 0,
            has_value: false,
            nan_fixed: false,
        }
    }
}
864
865fn resolve_output_shape(
866    shape: &[usize],
867    selection: &DimSelection,
868    reduced_dims: &[usize],
869) -> BuiltinResult<Vec<usize>> {
870    if is_scalar_shape(shape) {
871        return Ok(normalize_scalar_shape(shape));
872    }
873    let mut output = shape.to_vec();
874    match selection {
875        DimSelection::All => {
876            output.fill(1);
877        }
878        _ => {
879            for &dim in reduced_dims {
880                if dim < output.len() {
881                    output[dim] = 1;
882                }
883            }
884        }
885    }
886    Ok(output)
887}
888
/// Result of resolving a `DimSelection` against a concrete input shape.
struct ResolvedDims {
    /// Shape of the reduction result.
    output_shape: Vec<usize>,
    /// Zero-based dimensions being reduced.
    reduced_dims: Vec<usize>,
    /// When true, reported indices are linear positions in the whole input.
    reduce_all: bool,
    /// Passed to `map_output_index` together with the output strides.
    /// NOTE(review): presumably one flag per input dimension marking the reduced
    /// ones — confirm against the helper.
    dims_mask: Vec<bool>,
    /// Passed to `map_reduce_index` to compute the position within the reduced
    /// dimension(s).
    reduce_strides: Vec<usize>,
}
896
897fn resolve_reduction_dims(
898    shape: &[usize],
899    selection: &DimSelection,
900) -> BuiltinResult<ResolvedDims> {
901    if is_scalar_shape(shape) {
902        return Ok(ResolvedDims {
903            output_shape: normalize_scalar_shape(shape),
904            reduced_dims: Vec::new(),
905            reduce_all: true,
906            dims_mask: Vec::new(),
907            reduce_strides: Vec::new(),
908        });
909    }
910
911    let mut reduced_dims = match selection {
912        DimSelection::Auto => {
913            let mut dim = None;
914            for (index, &len) in shape.iter().enumerate() {
915                if len > 1 {
916                    dim = Some(index);
917                    break;
918                }
919            }
920            vec![dim.unwrap_or(0)]
921        }
922        DimSelection::Dim(dim) => {
923            if *dim == 0 {
924                return Err(max_error("max: dimension must be >= 1"));
925            }
926            let index = dim.saturating_sub(1);
927            if index >= shape.len() {
928                Vec::new()
929            } else {
930                vec![index]
931            }
932        }
933        DimSelection::Vec(dims) => {
934            if dims.is_empty() {
935                Vec::new()
936            } else {
937                dims.iter()
938                    .filter_map(|dim| {
939                        if *dim == 0 {
940                            None
941                        } else {
942                            let idx = dim - 1;
943                            if idx < shape.len() {
944                                Some(idx)
945                            } else {
946                                None
947                            }
948                        }
949                    })
950                    .collect()
951            }
952        }
953        DimSelection::All => (0..shape.len()).collect(),
954    };
955
956    reduced_dims.sort_unstable();
957    reduced_dims.dedup();
958
959    let reduce_all = !reduced_dims.is_empty()
960        && reduced_dims.len() == shape.len()
961        && reduced_dims.iter().enumerate().all(|(i, &d)| i == d);
962
963    let output_shape = resolve_output_shape(shape, selection, &reduced_dims)?;
964    let mut dims_mask = vec![false; shape.len()];
965    for &dim in &reduced_dims {
966        if dim < dims_mask.len() {
967            dims_mask[dim] = true;
968        }
969    }
970    let reduce_strides = compute_subspace_strides(shape, &reduced_dims);
971
972    Ok(ResolvedDims {
973        output_shape,
974        reduced_dims,
975        reduce_all,
976        dims_mask,
977        reduce_strides,
978    })
979}
980
/// Returns the strides for `shape` with the first dimension varying fastest.
///
/// Zero-length dimensions are treated as length 1 and the running product
/// saturates instead of overflowing.
fn compute_strides(shape: &[usize]) -> Vec<usize> {
    shape
        .iter()
        .scan(1usize, |running, &len| {
            let current = *running;
            *running = running.saturating_mul(len.max(1));
            Some(current)
        })
        .collect()
}
990
/// Returns strides for the sub-space spanned by `dims`, in the order given.
///
/// Each listed dimension contributes its extent from `shape`; extents that are
/// missing or zero count as 1, and the running product saturates.
fn compute_subspace_strides(shape: &[usize], dims: &[usize]) -> Vec<usize> {
    let mut running = 1usize;
    dims.iter()
        .map(|&dim| {
            let extent = shape.get(dim).map_or(1, |&len| len.max(1));
            let stride = running;
            running = running.saturating_mul(extent);
            stride
        })
        .collect()
}
1004
/// Maps full input coordinates to the linear index of the output element that
/// the coordinate reduces into.
///
/// Dimensions flagged in `dims_mask` are reduced and therefore contribute
/// coordinate 0. Missing mask entries count as "not reduced". Missing
/// coordinates count as 0, so mismatched slice lengths never panic, which
/// matches how `map_reduce_index` and `map_linear_index` behave. Arithmetic
/// saturates.
fn map_output_index(coords: &[usize], output_strides: &[usize], dims_mask: &[bool]) -> usize {
    output_strides
        .iter()
        .enumerate()
        .fold(0usize, |index, (dim, &stride)| {
            let reduced = dims_mask.get(dim).copied().unwrap_or(false);
            let coord = if reduced {
                0
            } else {
                // Checked lookup: `output_strides` may be longer than `coords`.
                coords.get(dim).copied().unwrap_or(0)
            };
            index.saturating_add(coord.saturating_mul(stride))
        })
}
1020
/// Maps full input coordinates to the linear position inside the reduced
/// sub-space.
///
/// Returns 0 when nothing is reduced or when every dimension is reduced (in
/// the latter case callers use the full linear index instead). Reduced
/// dimensions without a matching coordinate or stride are skipped, and the
/// arithmetic saturates.
fn map_reduce_index(
    coords: &[usize],
    reduced_dims: &[usize],
    reduce_strides: &[usize],
    reduce_all: bool,
) -> usize {
    if reduce_all || reduced_dims.is_empty() {
        return 0;
    }
    reduced_dims
        .iter()
        .zip(reduce_strides)
        .filter_map(|(&dim, &stride)| coords.get(dim).map(|&coord| coord.saturating_mul(stride)))
        .fold(0usize, |index, term| index.saturating_add(term))
}
1044
/// Computes the linear index `sum(coords[i] * strides[i])` with saturating
/// arithmetic; surplus entries in the longer slice are ignored.
fn map_linear_index(coords: &[usize], strides: &[usize]) -> usize {
    let mut linear = 0usize;
    for (coord, stride) in coords.iter().zip(strides) {
        linear = linear.saturating_add(coord.saturating_mul(*stride));
    }
    linear
}
1053
/// Advances `coords` to the next position with the first dimension varying
/// fastest, wrapping back to all zeros after the last position.
///
/// Dimensions whose extent in `shape` is 0 are skipped. `shape` must be at
/// least as long as `coords`.
fn increment_coords(coords: &mut [usize], shape: &[usize]) {
    for (axis, coord) in coords.iter_mut().enumerate() {
        let extent = shape[axis];
        if extent == 0 {
            continue;
        }
        *coord += 1;
        if *coord < extent {
            // No carry needed: the odometer has advanced.
            return;
        }
        // Carry into the next dimension.
        *coord = 0;
    }
}
1066
1067fn update_best_real(
1068    best: &mut BestReal,
1069    value: f64,
1070    reduce_index: usize,
1071    full_index: usize,
1072    nan_mode: ReductionNaN,
1073    comparison: ComparisonMethod,
1074) {
1075    if value.is_nan() {
1076        match nan_mode {
1077            ReductionNaN::Include => {
1078                if !best.nan_fixed {
1079                    best.value = f64::NAN;
1080                    best.reduce_index = reduce_index;
1081                    best.full_index = full_index;
1082                    best.has_value = true;
1083                    best.nan_fixed = true;
1084                }
1085            }
1086            ReductionNaN::Omit => {}
1087        }
1088        return;
1089    }
1090    if best.nan_fixed {
1091        return;
1092    }
1093
1094    if !best.has_value {
1095        best.value = value;
1096        best.reduce_index = reduce_index;
1097        best.full_index = full_index;
1098        best.has_value = true;
1099        return;
1100    }
1101
1102    if should_replace_real(best.value, value, comparison) {
1103        best.value = value;
1104        best.reduce_index = reduce_index;
1105        best.full_index = full_index;
1106    }
1107}
1108
1109fn update_best_complex(
1110    best: &mut BestComplex,
1111    value: (f64, f64),
1112    reduce_index: usize,
1113    full_index: usize,
1114    nan_mode: ReductionNaN,
1115    comparison: ComparisonMethod,
1116) {
1117    if value.0.is_nan() || value.1.is_nan() {
1118        match nan_mode {
1119            ReductionNaN::Include => {
1120                if !best.nan_fixed {
1121                    best.value = (f64::NAN, f64::NAN);
1122                    best.reduce_index = reduce_index;
1123                    best.full_index = full_index;
1124                    best.has_value = true;
1125                    best.nan_fixed = true;
1126                }
1127            }
1128            ReductionNaN::Omit => {}
1129        }
1130        return;
1131    }
1132    if best.nan_fixed {
1133        return;
1134    }
1135
1136    if !best.has_value {
1137        best.value = value;
1138        best.reduce_index = reduce_index;
1139        best.full_index = full_index;
1140        best.has_value = true;
1141        return;
1142    }
1143
1144    if should_replace_complex(best.value, value, comparison) {
1145        best.value = value;
1146        best.reduce_index = reduce_index;
1147        best.full_index = full_index;
1148    }
1149}
1150
1151fn should_replace_real(current: f64, candidate: f64, comparison: ComparisonMethod) -> bool {
1152    match comparison {
1153        ComparisonMethod::Auto | ComparisonMethod::Real => {
1154            if candidate > current {
1155                return true;
1156            }
1157            if candidate < current {
1158                return false;
1159            }
1160            if candidate == 0.0 && current == 0.0 {
1161                return candidate.is_sign_positive() && !current.is_sign_positive();
1162            }
1163            false
1164        }
1165        ComparisonMethod::Abs => {
1166            let curr_abs = current.abs();
1167            let cand_abs = candidate.abs();
1168            if cand_abs > curr_abs {
1169                return true;
1170            }
1171            if cand_abs < curr_abs {
1172                return false;
1173            }
1174            if candidate > current {
1175                return true;
1176            }
1177            if candidate < current {
1178                return false;
1179            }
1180            if candidate == 0.0 && current == 0.0 {
1181                return candidate.is_sign_positive() && !current.is_sign_positive();
1182            }
1183            false
1184        }
1185    }
1186}
1187
1188fn should_replace_complex(
1189    current: (f64, f64),
1190    candidate: (f64, f64),
1191    comparison: ComparisonMethod,
1192) -> bool {
1193    match comparison {
1194        ComparisonMethod::Auto | ComparisonMethod::Abs => {
1195            compare_complex_auto(current, candidate) == Ordering::Less
1196        }
1197        ComparisonMethod::Real => compare_complex_real(current, candidate) == Ordering::Less,
1198    }
1199}
1200
1201fn compare_complex_auto(a: (f64, f64), b: (f64, f64)) -> Ordering {
1202    let a_mag = magnitude_squared(a);
1203    let b_mag = magnitude_squared(b);
1204    if a_mag < b_mag {
1205        return Ordering::Less;
1206    }
1207    if a_mag > b_mag {
1208        return Ordering::Greater;
1209    }
1210    // Equal magnitude: tie-break using phase angle.
1211    let a_angle = a.1.atan2(a.0);
1212    let b_angle = b.1.atan2(b.0);
1213    if a_angle < b_angle {
1214        Ordering::Less
1215    } else if a_angle > b_angle {
1216        Ordering::Greater
1217    } else {
1218        Ordering::Equal
1219    }
1220}
1221
/// Orders two complex numbers for the `'real'` comparison method: by real
/// part first, and by imaginary part when the real parts are equal.
///
/// This follows MATLAB's documented tie-break for
/// `ComparisonMethod = 'real'`. Incomparable (NaN) components are treated as
/// equal.
fn compare_complex_real(a: (f64, f64), b: (f64, f64)) -> Ordering {
    match a.0.partial_cmp(&b.0) {
        // Equal (or incomparable) real parts: the imaginary part decides.
        Some(Ordering::Equal) | None => a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal),
        Some(order) => order,
    }
}
1232
/// Returns `re² + im²` for `z = (re, im)`, using a fused multiply-add for the
/// final accumulation.
fn magnitude_squared(z: (f64, f64)) -> f64 {
    let (re, im) = z;
    let im_sq = im * im;
    re.mul_add(re, im_sq)
}
1236
1237fn default_dimension_from_shape(shape: &[usize]) -> usize {
1238    if is_scalar_shape(shape) {
1239        return 1;
1240    }
1241    for (i, &len) in shape.iter().enumerate() {
1242        if len > 1 {
1243            return i + 1;
1244        }
1245    }
1246    1
1247}
1248
1249async fn elementwise_max(value: Value, args: ElementwiseArgs) -> BuiltinResult<MaxEvaluation> {
1250    let ElementwiseArgs { other, comparison } = args;
1251    match (value, other) {
1252        (Value::GpuTensor(handle_a), Value::GpuTensor(handle_b)) => {
1253            if gpu_tensor_is_scalar(&handle_b) {
1254                if let Some(num) = gpu_tensor_scalar_value(&handle_b).await {
1255                    let scalar = Value::Num(num);
1256                    if let Some(eval) =
1257                        elementwise_max_gpu_scalar_left(&handle_a, &scalar, comparison).await
1258                    {
1259                        return Ok(eval);
1260                    }
1261                    if let Ok(ta) = gpu_helpers::gather_tensor_async(&handle_a).await {
1262                        if let Ok(eval) = elementwise_real_or_complex(
1263                            Value::Tensor(ta),
1264                            scalar.clone(),
1265                            comparison,
1266                        ) {
1267                            return Ok(eval);
1268                        }
1269                    }
1270                    return Err(max_error("max: elementwise GPU scalar path failed"));
1271                }
1272            }
1273            if gpu_tensor_is_scalar(&handle_a) {
1274                if let Some(num) = gpu_tensor_scalar_value(&handle_a).await {
1275                    let scalar = Value::Num(num);
1276                    if let Some(eval) =
1277                        elementwise_max_gpu_scalar_right(&scalar, &handle_b, comparison).await
1278                    {
1279                        return Ok(eval);
1280                    }
1281                    if let Ok(tb) = gpu_helpers::gather_tensor_async(&handle_b).await {
1282                        if let Ok(eval) = elementwise_real_or_complex(
1283                            scalar.clone(),
1284                            Value::Tensor(tb),
1285                            comparison,
1286                        ) {
1287                            return Ok(eval);
1288                        }
1289                    }
1290                    return Err(max_error("max: elementwise GPU scalar path failed"));
1291                }
1292            }
1293            if let Some(eval) = elementwise_max_gpu_pair(&handle_a, &handle_b, comparison).await {
1294                return Ok(eval);
1295            }
1296            if let (Ok(ta), Ok(tb)) = (
1297                gpu_helpers::gather_tensor_async(&handle_a).await,
1298                gpu_helpers::gather_tensor_async(&handle_b).await,
1299            ) {
1300                if let Ok(eval) =
1301                    elementwise_real_or_complex(Value::Tensor(ta), Value::Tensor(tb), comparison)
1302                {
1303                    return Ok(eval);
1304                }
1305            }
1306            Err(max_error("max: elementwise GPU path failed"))
1307        }
1308        (Value::GpuTensor(handle), other) => {
1309            if let Some(eval) = elementwise_max_gpu_scalar_left(&handle, &other, comparison).await {
1310                return Ok(eval);
1311            }
1312            let t = gpu_helpers::gather_tensor_async(&handle)
1313                .await
1314                .map_err(|_| max_error("max: elementwise GPU scalar path failed"))?;
1315            elementwise_real_or_complex(Value::Tensor(t), other, comparison)
1316        }
1317        (other, Value::GpuTensor(handle)) => {
1318            if let Some(eval) = elementwise_max_gpu_scalar_right(&other, &handle, comparison).await
1319            {
1320                return Ok(eval);
1321            }
1322            let t = gpu_helpers::gather_tensor_async(&handle)
1323                .await
1324                .map_err(|_| max_error("max: elementwise GPU scalar path failed"))?;
1325            elementwise_real_or_complex(other, Value::Tensor(t), comparison)
1326        }
1327        (lhs, rhs) => elementwise_real_or_complex(lhs, rhs, comparison),
1328    }
1329}
1330
1331async fn elementwise_max_gpu_pair(
1332    a: &GpuTensorHandle,
1333    b: &GpuTensorHandle,
1334    comparison: ComparisonMethod,
1335) -> Option<MaxEvaluation> {
1336    if comparison != ComparisonMethod::Auto {
1337        return None;
1338    }
1339    let provider = runmat_accelerate_api::provider()?;
1340    // Equal-shape fast path
1341    if a.shape == b.shape {
1342        let values = provider.elem_max(a, b).await.ok()?;
1343        // Try device mask first; if unavailable, compute indices on host while keeping values on device
1344        if let Ok(mask) = provider.elem_ge(a, b).await {
1345            let indices = gpu_mask_indices(provider, &mask)?;
1346            let _ = provider.free(&mask);
1347            return Some(MaxEvaluation {
1348                values: Value::GpuTensor(values),
1349                indices: Value::GpuTensor(indices),
1350            });
1351        } else {
1352            // Host path for indices only
1353            let ta = gpu_helpers::gather_tensor_async(a).await.ok()?;
1354            let tb = gpu_helpers::gather_tensor_async(b).await.ok()?;
1355            let mut indices = Vec::with_capacity(ta.data.len());
1356            for i in 0..ta.data.len() {
1357                indices.push(if ta.data[i] >= tb.data[i] { 1.0 } else { 2.0 });
1358            }
1359            let index_tensor = Tensor::new(indices, ta.shape.clone()).ok()?;
1360            return Some(MaxEvaluation {
1361                values: Value::GpuTensor(values),
1362                indices: tensor::tensor_into_value(index_tensor),
1363            });
1364        }
1365    }
1366    // Broadcast-compatible path via repmat, then device compare
1367    let (out_shape, reps_a, reps_b) = broadcast_reps(&a.shape, &b.shape)?;
1368    let a_exp = if reps_a.iter().any(|&r| r != 1) {
1369        provider.repmat(a, &reps_a).ok()?
1370    } else {
1371        a.clone()
1372    };
1373    let b_exp = if reps_b.iter().any(|&r| r != 1) {
1374        provider.repmat(b, &reps_b).ok()?
1375    } else {
1376        b.clone()
1377    };
1378    let values = provider.elem_max(&a_exp, &b_exp).await.ok();
1379    let mask = provider.elem_ge(&a_exp, &b_exp).await.ok();
1380    if !std::ptr::eq(&a_exp, a) {
1381        let _ = provider.free(&a_exp);
1382    }
1383    if !std::ptr::eq(&b_exp, b) {
1384        let _ = provider.free(&b_exp);
1385    }
1386    let values = values?;
1387    if values.shape != out_shape {
1388        let _ = provider.free(&values);
1389        return None;
1390    }
1391    let index_tensor = if let Some(mask) = mask {
1392        let mask_host = gpu_helpers::gather_tensor_async(&mask).await.ok()?;
1393        let _ = provider.free(&mask);
1394        let mut indices = Vec::with_capacity(mask_host.data.len());
1395        for &m in &mask_host.data {
1396            indices.push(if m != 0.0 { 1.0 } else { 2.0 });
1397        }
1398        Tensor::new(indices, out_shape).ok()?
1399    } else {
1400        // Host indices fallback
1401        let ta = gpu_helpers::gather_tensor_async(&a_exp).await.ok()?;
1402        let tb = gpu_helpers::gather_tensor_async(&b_exp).await.ok()?;
1403        let mut indices = Vec::with_capacity(ta.data.len());
1404        for i in 0..ta.data.len() {
1405            indices.push(if ta.data[i] >= tb.data[i] { 1.0 } else { 2.0 });
1406        }
1407        Tensor::new(indices, out_shape).ok()?
1408    };
1409    Some(MaxEvaluation {
1410        values: Value::GpuTensor(values),
1411        indices: tensor::tensor_into_value(index_tensor),
1412    })
1413}
1414
/// Computes the broadcast output shape of `a` and `b` plus the per-dimension
/// repetition counts needed to expand each operand to that shape.
///
/// Shapes are padded with trailing 1s to a common rank (at least 1). Two
/// extents are compatible when they are equal or one of them is 1. A
/// repetition count is 1 where the operand already has the output extent and
/// the output extent otherwise.
///
/// Returns `None` when the shapes are not broadcast-compatible.
fn broadcast_reps(a: &[usize], b: &[usize]) -> Option<(Vec<usize>, Vec<usize>, Vec<usize>)> {
    let rank = a.len().max(b.len()).max(1);
    let mut out = Vec::with_capacity(rank);
    let mut reps_a = Vec::with_capacity(rank);
    let mut reps_b = Vec::with_capacity(rank);

    for axis in 0..rank {
        let ad = a.get(axis).copied().unwrap_or(1);
        let bd = b.get(axis).copied().unwrap_or(1);
        let extent = match (ad, bd) {
            (x, y) if x == y => x,
            (1, y) => y,
            (x, 1) => x,
            _ => return None,
        };
        out.push(extent);
        reps_a.push(if ad == extent { 1 } else { extent });
        reps_b.push(if bd == extent { 1 } else { extent });
    }

    Some((out, reps_a, reps_b))
}
1444
1445async fn elementwise_max_gpu_scalar_left(
1446    a: &GpuTensorHandle,
1447    other: &Value,
1448    comparison: ComparisonMethod,
1449) -> Option<MaxEvaluation> {
1450    if comparison != ComparisonMethod::Auto {
1451        return None;
1452    }
1453    let provider = runmat_accelerate_api::provider()?;
1454    let scalar = extract_scalar(other)?;
1455    // Prefer tensorize + elem_max for broader provider compatibility
1456    let values = if let Ok(fill) = provider.fill_like(a, scalar) {
1457        let vals = provider.elem_max(a, &fill).await.ok();
1458        let _ = provider.free(&fill);
1459        vals?
1460    } else {
1461        provider.scalar_max(a, scalar).ok()?
1462    };
1463    // Try device mask; if unavailable, compute on host
1464    let index_tensor = if let Ok(fill) = provider.fill_like(a, scalar) {
1465        if let Ok(mask) = provider.elem_ge(a, &fill).await {
1466            let _ = provider.free(&fill);
1467            let indices = gpu_mask_indices(provider, &mask)?;
1468            let _ = provider.free(&mask);
1469            return Some(MaxEvaluation {
1470                values: Value::GpuTensor(values),
1471                indices: Value::GpuTensor(indices),
1472            });
1473        } else {
1474            let _ = provider.free(&fill);
1475            let ta = gpu_helpers::gather_tensor_async(a).await.ok()?;
1476            let mut indices = Vec::with_capacity(ta.data.len());
1477            for &v in &ta.data {
1478                indices.push(if v >= scalar { 1.0 } else { 2.0 });
1479            }
1480            Tensor::new(indices, ta.shape.clone()).ok()?
1481        }
1482    } else {
1483        let ta = gpu_helpers::gather_tensor_async(a).await.ok()?;
1484        let mut indices = Vec::with_capacity(ta.data.len());
1485        for &v in &ta.data {
1486            indices.push(if v >= scalar { 1.0 } else { 2.0 });
1487        }
1488        Tensor::new(indices, ta.shape.clone()).ok()?
1489    };
1490    Some(MaxEvaluation {
1491        values: Value::GpuTensor(values),
1492        indices: tensor::tensor_into_value(index_tensor),
1493    })
1494}
1495
1496async fn elementwise_max_gpu_scalar_right(
1497    other: &Value,
1498    b: &GpuTensorHandle,
1499    comparison: ComparisonMethod,
1500) -> Option<MaxEvaluation> {
1501    if comparison != ComparisonMethod::Auto {
1502        return None;
1503    }
1504    let provider = runmat_accelerate_api::provider()?;
1505    let scalar = extract_scalar(other)?;
1506    let values = if let Ok(fill) = provider.fill_like(b, scalar) {
1507        let vals = provider.elem_max(&fill, b).await.ok();
1508        let _ = provider.free(&fill);
1509        vals?
1510    } else {
1511        provider.scalar_max(b, scalar).ok()?
1512    };
1513    // Try device mask; if unavailable, compute on host (origin 1 if scalar >= b)
1514    let index_tensor = if let Ok(fill) = provider.fill_like(b, scalar) {
1515        if let Ok(mask) = provider.elem_ge(&fill, b).await {
1516            let _ = provider.free(&fill);
1517            let indices = gpu_mask_indices(provider, &mask)?;
1518            let _ = provider.free(&mask);
1519            return Some(MaxEvaluation {
1520                values: Value::GpuTensor(values),
1521                indices: Value::GpuTensor(indices),
1522            });
1523        } else {
1524            let _ = provider.free(&fill);
1525            let tb = gpu_helpers::gather_tensor_async(b).await.ok()?;
1526            let mut indices = Vec::with_capacity(tb.data.len());
1527            for &v in &tb.data {
1528                indices.push(if scalar >= v { 1.0 } else { 2.0 });
1529            }
1530            Tensor::new(indices, tb.shape.clone()).ok()?
1531        }
1532    } else {
1533        let tb = gpu_helpers::gather_tensor_async(b).await.ok()?;
1534        let mut indices = Vec::with_capacity(tb.data.len());
1535        for &v in &tb.data {
1536            indices.push(if scalar >= v { 1.0 } else { 2.0 });
1537        }
1538        Tensor::new(indices, tb.shape.clone()).ok()?
1539    };
1540    Some(MaxEvaluation {
1541        values: Value::GpuTensor(values),
1542        indices: tensor::tensor_into_value(index_tensor),
1543    })
1544}
1545
1546fn extract_scalar(v: &Value) -> Option<f64> {
1547    match v {
1548        Value::Num(n) => Some(*n),
1549        Value::Int(i) => Some(i.to_f64()),
1550        Value::Bool(b) => Some(if *b { 1.0 } else { 0.0 }),
1551        Value::Tensor(t) if t.data.len() == 1 => t.data.first().copied(),
1552        Value::LogicalArray(l) if l.data.len() == 1 => Some(if l.data[0] != 0 { 1.0 } else { 0.0 }),
1553        _ => None,
1554    }
1555}
1556
1557fn gpu_tensor_is_scalar(handle: &GpuTensorHandle) -> bool {
1558    handle.shape.iter().copied().product::<usize>().max(1) == 1
1559}
1560
1561async fn gpu_tensor_scalar_value(handle: &GpuTensorHandle) -> Option<f64> {
1562    let tensor = gpu_helpers::gather_tensor_async(handle).await.ok()?;
1563    tensor.data.first().copied()
1564}
1565
1566fn gpu_mask_indices(
1567    provider: &dyn AccelProvider,
1568    mask: &GpuTensorHandle,
1569) -> Option<GpuTensorHandle> {
1570    let scaled = provider.scalar_mul(mask, -1.0).ok()?;
1571    let shifted = provider.scalar_add(&scaled, 2.0).ok()?;
1572    let _ = provider.free(&scaled);
1573    Some(shifted)
1574}
1575
1576fn elementwise_real_or_complex(
1577    lhs: Value,
1578    rhs: Value,
1579    comparison: ComparisonMethod,
1580) -> BuiltinResult<MaxEvaluation> {
1581    if let Some(eval) = scalar_elementwise_max(&lhs, &rhs, comparison) {
1582        return Ok(eval);
1583    }
1584    match (
1585        materialize_for_max("max", lhs)?,
1586        materialize_for_max("max", rhs)?,
1587    ) {
1588        (InputData::Complex(a), InputData::Complex(b)) => elementwise_complex_max(a, b, comparison),
1589        (InputData::Complex(a), InputData::Real(b)) => {
1590            let converted = promote_real_tensor_to_complex(b);
1591            elementwise_complex_max(a, converted, comparison)
1592        }
1593        (InputData::Real(a), InputData::Complex(b)) => {
1594            let converted = promote_real_tensor_to_complex(a);
1595            elementwise_complex_max(converted, b, comparison)
1596        }
1597        (InputData::Real(a), InputData::Real(b)) => elementwise_real_max(a, b, comparison),
1598    }
1599}
1600
1601fn scalar_real_value(value: &Value) -> Option<f64> {
1602    match value {
1603        Value::Num(n) => Some(*n),
1604        Value::Int(i) => Some(i.to_f64()),
1605        Value::Bool(b) => Some(if *b { 1.0 } else { 0.0 }),
1606        Value::Tensor(t) if t.data.len() == 1 => t.data.first().copied(),
1607        Value::LogicalArray(l) if l.data.len() == 1 => Some(if l.data[0] != 0 { 1.0 } else { 0.0 }),
1608        _ => None,
1609    }
1610}
1611
1612fn scalar_complex_value(value: &Value) -> Option<(f64, f64)> {
1613    match value {
1614        Value::Complex(re, im) => Some((*re, *im)),
1615        Value::ComplexTensor(ct) if ct.data.len() == 1 => ct.data.first().copied(),
1616        _ => None,
1617    }
1618}
1619
1620fn scalar_elementwise_max(
1621    lhs: &Value,
1622    rhs: &Value,
1623    comparison: ComparisonMethod,
1624) -> Option<MaxEvaluation> {
1625    let left = scalar_complex_value(lhs).or_else(|| scalar_real_value(lhs).map(|v| (v, 0.0)))?;
1626    let right = scalar_complex_value(rhs).or_else(|| scalar_real_value(rhs).map(|v| (v, 0.0)))?;
1627    let (ar, ai) = left;
1628    let (br, bi) = right;
1629    if ai != 0.0 || bi != 0.0 {
1630        let (value, origin) = choose_complex_elementwise((ar, ai), (br, bi), comparison);
1631        return Some(MaxEvaluation {
1632            values: Value::Complex(value.0, value.1),
1633            indices: Value::Num(origin),
1634        });
1635    }
1636    let (value, origin) = choose_real_elementwise(ar, br, comparison);
1637    Some(MaxEvaluation {
1638        values: Value::Num(value),
1639        indices: Value::Num(origin),
1640    })
1641}
1642
1643fn elementwise_real_max(
1644    lhs: Tensor,
1645    rhs: Tensor,
1646    comparison: ComparisonMethod,
1647) -> BuiltinResult<MaxEvaluation> {
1648    let plan = BroadcastPlan::new(&lhs.shape, &rhs.shape).map_err(|err| format!("max: {}", err))?;
1649    let mut values = vec![0.0f64; plan.len()];
1650    let mut indices = vec![0.0f64; plan.len()];
1651
1652    for (offset, index_a, index_b) in plan.iter() {
1653        let a = lhs.data.get(index_a).copied().unwrap_or(f64::NAN);
1654        let b = rhs.data.get(index_b).copied().unwrap_or(f64::NAN);
1655        let (value, origin) = choose_real_elementwise(a, b, comparison);
1656        values[offset] = value;
1657        indices[offset] = origin;
1658    }
1659
1660    let value_tensor = Tensor::new(values, plan.output_shape().to_vec())
1661        .map_err(|e| max_error(format!("max: {e}")))?;
1662    let index_tensor = Tensor::new(indices, plan.output_shape().to_vec())
1663        .map_err(|e| max_error(format!("max: {e}")))?;
1664
1665    Ok(MaxEvaluation {
1666        values: tensor::tensor_into_value(value_tensor),
1667        indices: tensor::tensor_into_value(index_tensor),
1668    })
1669}
1670
1671fn elementwise_complex_max(
1672    lhs: ComplexTensor,
1673    rhs: ComplexTensor,
1674    comparison: ComparisonMethod,
1675) -> BuiltinResult<MaxEvaluation> {
1676    let plan = BroadcastPlan::new(&lhs.shape, &rhs.shape).map_err(|err| format!("max: {}", err))?;
1677    let mut values = vec![(0.0f64, 0.0f64); plan.len()];
1678    let mut indices = vec![0.0f64; plan.len()];
1679
1680    for (offset, index_a, index_b) in plan.iter() {
1681        let a = lhs
1682            .data
1683            .get(index_a)
1684            .copied()
1685            .unwrap_or((f64::NAN, f64::NAN));
1686        let b = rhs
1687            .data
1688            .get(index_b)
1689            .copied()
1690            .unwrap_or((f64::NAN, f64::NAN));
1691        let (value, origin) = choose_complex_elementwise(a, b, comparison);
1692        values[offset] = value;
1693        indices[offset] = origin;
1694    }
1695
1696    let value_tensor = ComplexTensor::new(values, plan.output_shape().to_vec())
1697        .map_err(|e| max_error(format!("max: {e}")))?;
1698    let index_tensor = Tensor::new(indices, plan.output_shape().to_vec())
1699        .map_err(|e| max_error(format!("max: {e}")))?;
1700
1701    Ok(MaxEvaluation {
1702        values: complex_tensor_into_value(value_tensor),
1703        indices: tensor::tensor_into_value(index_tensor),
1704    })
1705}
1706
1707fn promote_real_tensor_to_complex(tensor: Tensor) -> ComplexTensor {
1708    let data = tensor
1709        .data
1710        .iter()
1711        .copied()
1712        .map(|re| (re, 0.0))
1713        .collect::<Vec<_>>();
1714    ComplexTensor {
1715        data,
1716        shape: tensor.shape.clone(),
1717        rows: tensor.rows,
1718        cols: tensor.cols,
1719    }
1720}
1721
1722fn choose_real_elementwise(a: f64, b: f64, comparison: ComparisonMethod) -> (f64, f64) {
1723    match (a.is_nan(), b.is_nan()) {
1724        (true, true) => (f64::NAN, 1.0),
1725        (true, false) => (f64::NAN, 1.0),
1726        (false, true) => (f64::NAN, 2.0),
1727        (false, false) => {
1728            if should_replace_real(a, b, comparison) {
1729                (b, 2.0)
1730            } else {
1731                (a, 1.0)
1732            }
1733        }
1734    }
1735}
1736
1737fn choose_complex_elementwise(
1738    a: (f64, f64),
1739    b: (f64, f64),
1740    comparison: ComparisonMethod,
1741) -> ((f64, f64), f64) {
1742    let a_nan = a.0.is_nan() || a.1.is_nan();
1743    let b_nan = b.0.is_nan() || b.1.is_nan();
1744    match (a_nan, b_nan) {
1745        (true, true) => ((f64::NAN, f64::NAN), 1.0),
1746        (true, false) => ((f64::NAN, f64::NAN), 1.0),
1747        (false, true) => ((f64::NAN, f64::NAN), 2.0),
1748        (false, false) => {
1749            if should_replace_complex(a, b, comparison) {
1750                (b, 2.0)
1751            } else {
1752                (a, 1.0)
1753            }
1754        }
1755    }
1756}
1757
#[cfg(test)]
pub(crate) mod tests {
    //! Unit tests for the `max` builtin: type resolution, reductions,
    //! element-wise comparisons, option parsing, and (behind the `wgpu`
    //! feature) CPU/GPU parity.
    //!
    //! Every test carries the `wasm_bindgen_test` attribute on `wasm32` so the
    //! whole suite is picked up by the wasm test runner as well as by the
    //! native `#[test]` harness.

    use super::*;
    #[cfg(feature = "wgpu")]
    use crate::builtins::common::test_support;
    use futures::executor::block_on;
    #[cfg(feature = "wgpu")]
    use runmat_accelerate_api::HostTensorView;
    use runmat_builtins::{IntValue, Tensor, Value};

    /// Blocking wrapper around the async `max_builtin` entry point.
    fn max_builtin(value: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
        block_on(super::max_builtin(value, rest))
    }

    /// Blocking wrapper around the async `evaluate` entry point, which exposes
    /// both outputs (values and indices) via `MaxEvaluation::into_pair`.
    fn evaluate(value: Value, rest: &[Value]) -> BuiltinResult<MaxEvaluation> {
        block_on(super::evaluate(value, rest))
    }

    /// Builds an empty `0x0` tensor value. The tests pass it as the first
    /// trailing argument when the remaining arguments are options rather than
    /// a second operand (presumably mirroring MATLAB's `max(A, [], ...)`).
    fn placeholder() -> Value {
        let tensor = Tensor::new(Vec::<f64>::new(), vec![0, 0]).unwrap();
        Value::Tensor(tensor)
    }

    /// The type resolver maps two tensor arguments of unknown shape to
    /// `Type::tensor()`.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn max_type_with_two_args_returns_tensor() {
        let out = max_type(
            &[Type::Tensor { shape: None }, Type::Tensor { shape: None }],
            &ResolveContext::new(Vec::new()),
        );
        assert_eq!(out, Type::tensor());
    }

    /// A lone scalar with no extra arguments is returned unchanged.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn max_scalar_returns_input() {
        let result = max_builtin(Value::Num(5.0), Vec::new()).expect("max");
        assert_eq!(result, Value::Num(5.0));
    }

    /// Reducing a 3x1 vector yields the maximum and its 1-based position.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn max_vector_with_indices() {
        let tensor = Tensor::new(vec![3.0, 1.0, 5.0], vec![3, 1]).unwrap();
        let eval = evaluate(Value::Tensor(tensor), &[]).expect("evaluate");
        let (values, indices) = eval.into_pair();
        assert_eq!(values, Value::Num(5.0));
        assert_eq!(indices, Value::Num(3.0));
    }

    /// With no dimension given, a 2x3 matrix reduces to a 1x3 row; every
    /// winning element sits in the second row of its column.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn max_matrix_default_dimension() {
        let tensor = Tensor::new(vec![3.0, 4.0, 1.0, 2.0, 5.0, 6.0], vec![2, 3]).unwrap();
        let eval = evaluate(Value::Tensor(tensor), &[]).expect("evaluate");
        let (values, indices) = eval.into_pair();
        match values {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![1, 3]);
                assert_eq!(t.data, vec![4.0, 2.0, 6.0]);
            }
            other => panic!("expected tensor, got {other:?}"),
        }
        match indices {
            Value::Tensor(t) => {
                assert_eq!(t.data, vec![2.0, 2.0, 2.0]);
            }
            other => panic!("expected tensor, got {other:?}"),
        }
    }

    /// The `"all"` and `"linear"` keywords both produce a scalar maximum and
    /// a scalar index for the inputs used here.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn max_all_linear_index() {
        let tensor =
            Tensor::new((1..=12).map(|v| v as f64).collect::<Vec<_>>(), vec![3, 4]).unwrap();
        let args = vec![placeholder(), Value::from("all")];
        let eval = evaluate(Value::Tensor(tensor), &args).expect("evaluate");
        let (values, indices) = eval.into_pair();
        assert_eq!(values, Value::Num(12.0));
        assert_eq!(indices, Value::Num(12.0));

        let args_linear = vec![placeholder(), Value::from("linear")];
        let eval = evaluate(
            Value::Tensor(Tensor::new(vec![2.0, 3.0], vec![1, 2]).unwrap()),
            &args_linear,
        )
        .expect("evaluate");
        let (values, indices) = eval.into_pair();
        assert_eq!(values, Value::Num(3.0));
        assert_eq!(indices, Value::Num(2.0));
    }

    /// `"omitnan"` skips the leading NaN and reports the position of the
    /// largest remaining element.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn max_with_omitnan() {
        let tensor = Tensor::new(vec![f64::NAN, 4.0, 2.0], vec![3, 1]).unwrap();
        let args = vec![placeholder(), Value::from("omitnan")];
        let eval = evaluate(Value::Tensor(tensor), &args).expect("evaluate");
        let (values, indices) = eval.into_pair();
        assert_eq!(values, Value::Num(4.0));
        assert_eq!(indices, Value::Num(2.0));
    }

    /// When every element of the slice is NaN, `"omitnan"` yields NaN for
    /// both the value and the index.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn max_omitnan_all_nan_slice() {
        let tensor = Tensor::new(vec![f64::NAN, f64::NAN], vec![2, 1]).unwrap();
        let args = vec![placeholder(), Value::from("omitnan")];
        let eval = evaluate(Value::Tensor(tensor), &args).expect("evaluate");
        let (values, indices) = eval.into_pair();
        match values {
            Value::Num(v) => assert!(v.is_nan()),
            other => panic!("expected scalar NaN, got {other:?}"),
        }
        match indices {
            Value::Num(v) => assert!(v.is_nan()),
            other => panic!("expected scalar NaN index, got {other:?}"),
        }
    }

    /// `ComparisonMethod = "abs"` ranks by magnitude but returns the original
    /// signed values (`-3.0` beats `1.0`).
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn max_reduction_abs_comparison() {
        let tensor = Tensor::new(vec![1.0, -3.0, -2.0, 4.0], vec![2, 2]).unwrap();
        let args = vec![
            placeholder(),
            Value::from("ComparisonMethod"),
            Value::from("abs"),
        ];
        let eval = evaluate(Value::Tensor(tensor), &args).expect("evaluate");
        let (values, indices) = eval.into_pair();
        match values {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![1, 2]);
                assert_eq!(t.data, vec![-3.0, 4.0]);
            }
            other => panic!("expected tensor result, got {other:?}"),
        }
        match indices {
            Value::Tensor(t) => {
                assert_eq!(t.data, vec![2.0, 2.0]);
            }
            other => panic!("expected tensor indices, got {other:?}"),
        }
    }

    /// `ComparisonMethod = "real"` on a complex reduction selects the element
    /// with the larger real part even though the other has larger magnitude.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn max_reduction_complex_real_comparison() {
        let tensor = ComplexTensor::new(vec![(1.0, 2.0), (0.5, 5.0)], vec![2, 1]).expect("tensor");
        let args = vec![
            placeholder(),
            Value::from("ComparisonMethod"),
            Value::from("real"),
        ];
        let eval = evaluate(Value::ComplexTensor(tensor), &args).expect("evaluate");
        let (values, indices) = eval.into_pair();
        match values {
            Value::Complex(re, im) => {
                assert!((re - 1.0).abs() < 1e-12);
                assert!((im - 2.0).abs() < 1e-12);
            }
            other => panic!("expected complex scalar, got {other:?}"),
        }
        assert_eq!(indices, Value::Num(1.0));
    }

    /// A 1x3 operand and a 3x1 operand broadcast to a 3x3 result; the second
    /// output marks which operand (1 or 2) supplied each element.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn max_elementwise_broadcast() {
        let lhs = Tensor::new(vec![1.0, 4.0, 7.0], vec![1, 3]).unwrap();
        let rhs = Tensor::new(vec![2.0, 3.0, 5.0], vec![3, 1]).unwrap();
        let eval = evaluate(Value::Tensor(lhs), &[Value::Tensor(rhs)]).expect("evaluate");
        let (values, indices) = eval.into_pair();
        match values {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![3, 3]);
                // Each assertion reads one row of the result (stride 3).
                assert_eq!([t.data[0], t.data[3], t.data[6]], [2.0, 4.0, 7.0]);
                assert_eq!([t.data[1], t.data[4], t.data[7]], [3.0, 4.0, 7.0]);
                assert_eq!([t.data[2], t.data[5], t.data[8]], [5.0, 5.0, 7.0]);
            }
            other => panic!("expected tensor, got {other:?}"),
        }
        match indices {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![3, 3]);
                assert_eq!([t.data[0], t.data[3], t.data[6]], [2.0, 1.0, 1.0]);
                assert_eq!([t.data[1], t.data[4], t.data[7]], [2.0, 1.0, 1.0]);
                assert_eq!([t.data[2], t.data[5], t.data[8]], [2.0, 2.0, 1.0]);
            }
            other => panic!("expected tensor, got {other:?}"),
        }
    }

    /// Element-wise `ComparisonMethod = "abs"` picks the larger magnitude from
    /// each pair and keeps the original sign.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn max_elementwise_abs_comparison() {
        let lhs = Tensor::new(vec![-2.0, 1.0], vec![2, 1]).unwrap();
        let rhs = Tensor::new(vec![1.5, -3.0], vec![2, 1]).unwrap();
        let args = vec![
            Value::Tensor(rhs),
            Value::from("ComparisonMethod"),
            Value::from("abs"),
        ];
        let eval = evaluate(Value::Tensor(lhs), &args).expect("evaluate");
        let (values, indices) = eval.into_pair();
        match values {
            Value::Tensor(t) => {
                assert_eq!(t.data, vec![-2.0, -3.0]);
            }
            other => panic!("expected tensor, got {other:?}"),
        }
        match indices {
            Value::Tensor(t) => {
                assert_eq!(t.data, vec![1.0, 2.0]);
            }
            other => panic!("expected tensor, got {other:?}"),
        }
    }

    /// Passing `"omitnan"` together with a second operand is rejected with a
    /// "only supported for reduction" error.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn max_elementwise_rejects_reduction_only_keywords() {
        let lhs = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
        let rhs = Tensor::new(vec![3.0, 4.0], vec![2, 1]).unwrap();
        let err = evaluate(
            Value::Tensor(lhs),
            &[Value::Tensor(rhs), Value::from("omitnan")],
        )
        .expect_err("expected error");
        assert!(err.message().contains("only supported for reduction"));
    }

    /// Element-wise complex comparison by real part: the operand with the
    /// larger real component wins and the result collapses to a scalar.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn max_complex_real_comparison() {
        let lhs = ComplexTensor::new(vec![(1.0, 2.0)], vec![1, 1]).unwrap();
        let rhs = ComplexTensor::new(vec![(0.5, 5.0)], vec![1, 1]).unwrap();
        let args = vec![
            Value::ComplexTensor(rhs),
            Value::from("ComparisonMethod"),
            Value::from("real"),
        ];
        let eval = evaluate(Value::ComplexTensor(lhs), &args).expect("evaluate");
        let (values, indices) = eval.into_pair();
        assert_eq!(values, Value::Complex(1.0, 2.0));
        assert_eq!(indices, Value::Num(1.0));
    }

    /// A tensor of dimensions (`[1, 2]`) is accepted as the dimension
    /// argument and reduces the 2x2 input to a scalar.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn max_dimension_argument_parsing() {
        let tensor = Tensor::new(vec![3.0, 4.0, 1.0, 2.0], vec![2, 2]).unwrap();
        let dims = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
        let args = vec![placeholder(), Value::Tensor(dims)];
        let eval = evaluate(Value::Tensor(tensor), &args).expect("evaluate");
        let (values, indices) = eval.into_pair();
        assert_eq!(values, Value::Num(4.0));
        assert_eq!(indices, Value::Num(2.0));
    }

    /// Repeated entries in the dimension vector (`[1, 1, 2]`) are tolerated
    /// and still reduce the 2x2 input to a scalar.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn max_vecdim_duplicate_entries() {
        let tensor = Tensor::new(vec![5.0, 2.0, 7.0, 1.0], vec![2, 2]).unwrap();
        let dims = Tensor::new(vec![1.0, 1.0, 2.0], vec![3, 1]).unwrap();
        let args = vec![placeholder(), Value::Tensor(dims)];
        let eval = evaluate(Value::Tensor(tensor), &args).expect("evaluate");
        let (values, indices) = eval.into_pair();
        assert_eq!(values, Value::Num(7.0));
        assert_eq!(indices, Value::Num(3.0));
    }

    /// A GPU tensor handle supplied as the dimension argument is rejected
    /// with a "must reside on the host" error.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn max_dimension_gpu_argument_errors() {
        let tensor = Tensor::new(vec![3.0, 1.0], vec![2, 1]).unwrap();
        let dim_handle = Value::GpuTensor(runmat_accelerate_api::GpuTensorHandle {
            shape: vec![1, 1],
            device_id: 0,
            buffer_id: 42,
        });
        let err = evaluate(Value::Tensor(tensor), &[placeholder(), dim_handle])
            .expect_err("expected error");
        assert!(err
            .message()
            .contains("dimension arguments must reside on the host"));
    }

    /// An unknown `ComparisonMethod` value produces an
    /// "unsupported ComparisonMethod" error.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn max_invalid_comparison_method_errors() {
        let tensor = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
        let args = vec![
            placeholder(),
            Value::from("ComparisonMethod"),
            Value::from("chebyshev"),
        ];
        let err = evaluate(Value::Tensor(tensor), &args).expect_err("expected error");
        assert!(err.message().contains("unsupported ComparisonMethod"));
    }

    /// Uploads a 2x2 tensor through the test provider and checks that the
    /// GPU path returns GPU-resident outputs whose gathered shape and data
    /// match the CPU evaluation of the same input.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn max_gpu_dim1_matches_cpu() {
        let tensor = Tensor::new(vec![3.0, 1.0, 2.0, 4.0], vec![2, 2]).unwrap();
        let eval_cpu = evaluate(Value::Tensor(tensor.clone()), &[]).expect("cpu");
        let (values_cpu, indices_cpu) = eval_cpu.into_pair();

        test_support::with_test_provider(|provider| {
            let view = HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let eval_gpu = evaluate(Value::GpuTensor(handle), &[]).expect("gpu");
            let (values_gpu, indices_gpu) = eval_gpu.into_pair();
            // Both outputs are expected to stay on the device until gathered.
            match (&values_gpu, &indices_gpu) {
                (Value::GpuTensor(_), Value::GpuTensor(_)) => {}
                other => panic!("expected GPU tensors, got {other:?}"),
            }
            let gathered_vals = test_support::gather(values_gpu).expect("gather values");
            let gathered_idx = test_support::gather(indices_gpu).expect("gather indices");
            let expected_vals = match values_cpu {
                Value::Tensor(t) => t,
                other => panic!("expected tensor values from cpu eval, got {other:?}"),
            };
            let expected_idx = match indices_cpu {
                Value::Tensor(t) => t,
                other => panic!("expected tensor indices from cpu eval, got {other:?}"),
            };
            assert_eq!(gathered_vals.shape, expected_vals.shape);
            assert_eq!(gathered_vals.data, expected_vals.data);
            assert_eq!(gathered_idx.shape, expected_idx.shape);
            assert_eq!(gathered_idx.data, expected_idx.data);
        });
    }

    /// A numeric dimension of `2` reduces each row of the 2x2 input, giving a
    /// 2x1 result whose winners all come from the first column.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn max_dimension_numeric_argument() {
        let tensor = Tensor::new(vec![3.0, 4.0, 1.0, 2.0], vec![2, 2]).unwrap();
        let args = vec![placeholder(), Value::Num(2.0)];
        let eval = evaluate(Value::Tensor(tensor), &args).expect("evaluate");
        let (values, indices) = eval.into_pair();
        match values {
            Value::Tensor(t) => {
                assert_eq!(t.shape, vec![2, 1]);
                assert_eq!(t.data, vec![3.0, 4.0]);
            }
            other => panic!("expected tensor, got {other:?}"),
        }
        match indices {
            Value::Tensor(t) => {
                assert_eq!(t.data, vec![1.0, 1.0]);
            }
            other => panic!("expected tensor, got {other:?}"),
        }
    }

    /// With no explicit comparison method, two complex operands of equal
    /// magnitude (`1+2i` and `2+1i`) resolve to the first operand.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn max_complex_auto_comparison() {
        let lhs = ComplexTensor::new(vec![(1.0, 2.0)], vec![1, 1]).unwrap();
        let rhs = ComplexTensor::new(vec![(2.0, 1.0)], vec![1, 1]).unwrap();
        let eval =
            evaluate(Value::ComplexTensor(lhs), &[Value::ComplexTensor(rhs)]).expect("evaluate");
        let (values, indices) = eval.into_pair();
        assert_eq!(values, Value::Complex(1.0, 2.0));
        assert_eq!(indices, Value::Num(1.0));
    }

    /// Two scalar operands go through the element-wise path and return the
    /// larger scalar.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn max_scalar_pair_arguments() {
        let args = vec![Value::Num(2.0)];
        let result = max_builtin(Value::Num(3.0), args).expect("max");
        assert_eq!(result, Value::Num(3.0));
    }

    /// An integer dimension of `0` is rejected with a
    /// "dimension must be >= 1" error.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn max_rejects_invalid_dimension() {
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap();
        let args = vec![placeholder(), Value::Int(IntValue::I32(0))];
        let err = evaluate(Value::Tensor(tensor), &args).expect_err("expected error");
        assert!(err.message().contains("dimension must be >= 1"));
    }
}