1use std::cmp::Ordering;
4use std::collections::BTreeSet;
5
6use runmat_accelerate_api::{AccelProvider, GpuTensorHandle, ReduceDimResult};
7use runmat_builtins::{
8 BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
9 BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor,
10 ComplexTensor, ResolveContext, Tensor, Type, Value,
11};
12use runmat_macros::runtime_builtin;
13
14use crate::{build_runtime_error, BuiltinResult, RuntimeError};
15
/// Builtin name attached to every runtime error built by the helpers in this file.
const NAME: &str = "max";
17
/// Type resolver registered for `max`.
///
/// Forwards the argument types and resolve context unchanged to the shared
/// `min_max_type` resolver and returns whatever it produces.
fn max_type(args: &[Type], ctx: &ResolveContext) -> Type {
    min_max_type(args, ctx)
}
21
/// Output list for the single-output forms (`M = max(...)`): only the maxima.
const MAX_OUTPUT_M: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
    name: "M",
    ty: BuiltinParamType::NumericArray,
    arity: BuiltinParamArity::Required,
    default: None,
    description: "Maximum values.",
}];
29
/// Output list for the two-output forms (`[M, I] = max(...)`): the maxima
/// followed by their one-based indices.
const MAX_OUTPUT_MI: [BuiltinParamDescriptor; 2] = [
    BuiltinParamDescriptor {
        name: "M",
        ty: BuiltinParamType::NumericArray,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Maximum values.",
    },
    BuiltinParamDescriptor {
        name: "I",
        ty: BuiltinParamType::NumericArray,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "One-based maximum indices/origins.",
    },
];
46
/// First positional input `A`, shared by every documented signature.
const MAX_PARAM_A: BuiltinParamDescriptor = BuiltinParamDescriptor {
    name: "A",
    ty: BuiltinParamType::Any,
    arity: BuiltinParamArity::Required,
    default: None,
    description: "Input scalar or array.",
};
54
/// Second positional input `B` for the element-wise forms (`max(A, B, ...)`).
const MAX_PARAM_B: BuiltinParamDescriptor = BuiltinParamDescriptor {
    name: "B",
    ty: BuiltinParamType::Any,
    arity: BuiltinParamArity::Required,
    default: None,
    description: "Second operand for element-wise maximum.",
};
62
/// The `[]` placeholder occupying the second position of the reduction forms
/// (`max(A, [], ...)`). `parse_call` treats an empty second argument as the
/// switch from element-wise parsing to reduction parsing.
const MAX_PARAM_EMPTY: BuiltinParamDescriptor = BuiltinParamDescriptor {
    name: "placeholder",
    ty: BuiltinParamType::Any,
    arity: BuiltinParamArity::Optional,
    default: Some("[]"),
    description: "Empty placeholder selecting reduction-argument grammar.",
};
70
/// Dimension selector for the reduction forms; used for both the scalar `dim`
/// and the vector `vecdim` signatures.
const MAX_PARAM_DIM: BuiltinParamDescriptor = BuiltinParamDescriptor {
    name: "dim",
    ty: BuiltinParamType::Any,
    arity: BuiltinParamArity::Optional,
    default: Some("[]"),
    description: "Reduction dimension selector (scalar or dimension vector).",
};
78
/// String flag for the `"all"` / `"linear"` reduction signatures.
const MAX_PARAM_REDUCTION_FLAG: BuiltinParamDescriptor = BuiltinParamDescriptor {
    name: "flag",
    ty: BuiltinParamType::StringScalar,
    arity: BuiltinParamArity::Optional,
    default: Some("\"all\""),
    description: "Reduction mode flag: \"all\" or \"linear\".",
};
86
/// NaN-handling flag for the reduction signatures (`"includenan"` / `"omitnan"`).
const MAX_PARAM_NANFLAG: BuiltinParamDescriptor = BuiltinParamDescriptor {
    name: "nanflag",
    ty: BuiltinParamType::StringScalar,
    arity: BuiltinParamArity::Optional,
    default: Some("\"includenan\""),
    description: "Missing-value mode: \"includenan\" or \"omitnan\".",
};
94
/// Name half of the explicit `"ComparisonMethod", method` pair. It shares the
/// descriptor name `optionName` with the generic `MAX_PARAM_OPTION_NAME`.
const MAX_PARAM_COMPARISON_NAME: BuiltinParamDescriptor = BuiltinParamDescriptor {
    name: "optionName",
    ty: BuiltinParamType::StringScalar,
    arity: BuiltinParamArity::Optional,
    default: Some("\"ComparisonMethod\""),
    description: "Option name (currently \"ComparisonMethod\").",
};
102
/// Value half of the `"ComparisonMethod", method` pair. The accepted spellings
/// listed in the description are the ones `parse_comparison_method` matches.
const MAX_PARAM_COMPARISON_VALUE: BuiltinParamDescriptor = BuiltinParamDescriptor {
    name: "method",
    ty: BuiltinParamType::StringScalar,
    arity: BuiltinParamArity::Optional,
    default: Some("\"auto\""),
    description: "Comparison method: \"auto\", \"abs\"/\"magnitude\", or \"real\".",
};
110
/// Name half of the generic, repeatable `optionName, optionValue, ...` tail.
const MAX_PARAM_OPTION_NAME: BuiltinParamDescriptor = BuiltinParamDescriptor {
    name: "optionName",
    ty: BuiltinParamType::StringScalar,
    arity: BuiltinParamArity::Variadic,
    default: None,
    description: "Name-value option name.",
};
118
/// Value half of the generic, repeatable `optionName, optionValue, ...` tail.
const MAX_PARAM_OPTION_VALUE: BuiltinParamDescriptor = BuiltinParamDescriptor {
    name: "optionValue",
    ty: BuiltinParamType::Any,
    arity: BuiltinParamArity::Variadic,
    default: None,
    description: "Name-value option value.",
};
126
// Input parameter lists referenced by `MAX_SIGNATURES`. Each array spells out
// one documented call shape by combining the `MAX_PARAM_*` descriptors in
// positional order; every list is shared by an `M = ...` and an
// `[M, I] = ...` signature.

/// `max(A)`
const MAX_INPUTS_A: [BuiltinParamDescriptor; 1] = [MAX_PARAM_A];
/// `max(A, B)`
const MAX_INPUTS_A_B: [BuiltinParamDescriptor; 2] = [MAX_PARAM_A, MAX_PARAM_B];
/// `max(A, [], dim)` and `max(A, [], vecdim)`
const MAX_INPUTS_A_EMPTY_DIM: [BuiltinParamDescriptor; 3] =
    [MAX_PARAM_A, MAX_PARAM_EMPTY, MAX_PARAM_DIM];
/// `max(A, [], "all")` and `max(A, [], "linear")`
const MAX_INPUTS_A_EMPTY_FLAG: [BuiltinParamDescriptor; 3] =
    [MAX_PARAM_A, MAX_PARAM_EMPTY, MAX_PARAM_REDUCTION_FLAG];
/// `max(A, [], nanflag)`
const MAX_INPUTS_A_EMPTY_NANFLAG: [BuiltinParamDescriptor; 3] =
    [MAX_PARAM_A, MAX_PARAM_EMPTY, MAX_PARAM_NANFLAG];
/// `max(A, [], "ComparisonMethod", method)`
const MAX_INPUTS_A_EMPTY_COMPARISON: [BuiltinParamDescriptor; 4] = [
    MAX_PARAM_A,
    MAX_PARAM_EMPTY,
    MAX_PARAM_COMPARISON_NAME,
    MAX_PARAM_COMPARISON_VALUE,
];
/// `max(A, B, "ComparisonMethod", method)`
const MAX_INPUTS_A_B_COMPARISON: [BuiltinParamDescriptor; 4] = [
    MAX_PARAM_A,
    MAX_PARAM_B,
    MAX_PARAM_COMPARISON_NAME,
    MAX_PARAM_COMPARISON_VALUE,
];
/// `max(A, [], optionName, optionValue, ...)`
const MAX_INPUTS_A_EMPTY_OPTIONS: [BuiltinParamDescriptor; 4] = [
    MAX_PARAM_A,
    MAX_PARAM_EMPTY,
    MAX_PARAM_OPTION_NAME,
    MAX_PARAM_OPTION_VALUE,
];
/// `max(A, B, optionName, optionValue, ...)`
const MAX_INPUTS_A_B_OPTIONS: [BuiltinParamDescriptor; 4] = [
    MAX_PARAM_A,
    MAX_PARAM_B,
    MAX_PARAM_OPTION_NAME,
    MAX_PARAM_OPTION_VALUE,
];
159
/// Every documented call form of `max`.
///
/// Entries come in pairs: each input list appears once with the single-output
/// list (`MAX_OUTPUT_M`) and once with the two-output list (`MAX_OUTPUT_MI`).
/// The `dim` / `vecdim` labels and the `"all"` / `"linear"` labels share the
/// same input arrays; only the label text differs.
const MAX_SIGNATURES: [BuiltinSignatureDescriptor; 22] = [
    BuiltinSignatureDescriptor {
        label: "M = max(A)",
        inputs: &MAX_INPUTS_A,
        outputs: &MAX_OUTPUT_M,
    },
    BuiltinSignatureDescriptor {
        label: "[M, I] = max(A)",
        inputs: &MAX_INPUTS_A,
        outputs: &MAX_OUTPUT_MI,
    },
    BuiltinSignatureDescriptor {
        label: "M = max(A, B)",
        inputs: &MAX_INPUTS_A_B,
        outputs: &MAX_OUTPUT_M,
    },
    BuiltinSignatureDescriptor {
        label: "[M, I] = max(A, B)",
        inputs: &MAX_INPUTS_A_B,
        outputs: &MAX_OUTPUT_MI,
    },
    BuiltinSignatureDescriptor {
        label: "M = max(A, [], dim)",
        inputs: &MAX_INPUTS_A_EMPTY_DIM,
        outputs: &MAX_OUTPUT_M,
    },
    BuiltinSignatureDescriptor {
        label: "[M, I] = max(A, [], dim)",
        inputs: &MAX_INPUTS_A_EMPTY_DIM,
        outputs: &MAX_OUTPUT_MI,
    },
    BuiltinSignatureDescriptor {
        label: "M = max(A, [], vecdim)",
        inputs: &MAX_INPUTS_A_EMPTY_DIM,
        outputs: &MAX_OUTPUT_M,
    },
    BuiltinSignatureDescriptor {
        label: "[M, I] = max(A, [], vecdim)",
        inputs: &MAX_INPUTS_A_EMPTY_DIM,
        outputs: &MAX_OUTPUT_MI,
    },
    BuiltinSignatureDescriptor {
        label: "M = max(A, [], \"all\")",
        inputs: &MAX_INPUTS_A_EMPTY_FLAG,
        outputs: &MAX_OUTPUT_M,
    },
    BuiltinSignatureDescriptor {
        label: "[M, I] = max(A, [], \"all\")",
        inputs: &MAX_INPUTS_A_EMPTY_FLAG,
        outputs: &MAX_OUTPUT_MI,
    },
    BuiltinSignatureDescriptor {
        label: "M = max(A, [], \"linear\")",
        inputs: &MAX_INPUTS_A_EMPTY_FLAG,
        outputs: &MAX_OUTPUT_M,
    },
    BuiltinSignatureDescriptor {
        label: "[M, I] = max(A, [], \"linear\")",
        inputs: &MAX_INPUTS_A_EMPTY_FLAG,
        outputs: &MAX_OUTPUT_MI,
    },
    BuiltinSignatureDescriptor {
        label: "M = max(A, [], nanflag)",
        inputs: &MAX_INPUTS_A_EMPTY_NANFLAG,
        outputs: &MAX_OUTPUT_M,
    },
    BuiltinSignatureDescriptor {
        label: "[M, I] = max(A, [], nanflag)",
        inputs: &MAX_INPUTS_A_EMPTY_NANFLAG,
        outputs: &MAX_OUTPUT_MI,
    },
    BuiltinSignatureDescriptor {
        label: "M = max(A, [], \"ComparisonMethod\", method)",
        inputs: &MAX_INPUTS_A_EMPTY_COMPARISON,
        outputs: &MAX_OUTPUT_M,
    },
    BuiltinSignatureDescriptor {
        label: "[M, I] = max(A, [], \"ComparisonMethod\", method)",
        inputs: &MAX_INPUTS_A_EMPTY_COMPARISON,
        outputs: &MAX_OUTPUT_MI,
    },
    BuiltinSignatureDescriptor {
        label: "M = max(A, B, \"ComparisonMethod\", method)",
        inputs: &MAX_INPUTS_A_B_COMPARISON,
        outputs: &MAX_OUTPUT_M,
    },
    BuiltinSignatureDescriptor {
        label: "[M, I] = max(A, B, \"ComparisonMethod\", method)",
        inputs: &MAX_INPUTS_A_B_COMPARISON,
        outputs: &MAX_OUTPUT_MI,
    },
    BuiltinSignatureDescriptor {
        label: "M = max(A, [], optionName, optionValue, ...)",
        inputs: &MAX_INPUTS_A_EMPTY_OPTIONS,
        outputs: &MAX_OUTPUT_M,
    },
    BuiltinSignatureDescriptor {
        label: "[M, I] = max(A, [], optionName, optionValue, ...)",
        inputs: &MAX_INPUTS_A_EMPTY_OPTIONS,
        outputs: &MAX_OUTPUT_MI,
    },
    BuiltinSignatureDescriptor {
        label: "M = max(A, B, optionName, optionValue, ...)",
        inputs: &MAX_INPUTS_A_B_OPTIONS,
        outputs: &MAX_OUTPUT_M,
    },
    BuiltinSignatureDescriptor {
        label: "[M, I] = max(A, B, optionName, optionValue, ...)",
        inputs: &MAX_INPUTS_A_B_OPTIONS,
        outputs: &MAX_OUTPUT_MI,
    },
];
272
/// Error descriptor used by `max_invalid_argument` for argument-parsing
/// failures (bad flags, dimensions, or option values).
const MAX_ERROR_INVALID_ARGUMENT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.MAX.INVALID_ARGUMENT",
    identifier: Some("RunMat:max:InvalidArgument"),
    when: "Argument grammar, dimensions, or option names/values are invalid.",
    message: "max: invalid argument",
};
279
/// Error descriptor used by `max_invalid_input`, e.g. when
/// `materialize_for_max` rejects a non-numeric operand.
const MAX_ERROR_INVALID_INPUT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.MAX.INVALID_INPUT",
    identifier: Some("RunMat:max:InvalidInput"),
    when: "Input values cannot be converted to supported max domains.",
    message: "max: invalid input",
};
286
/// Error descriptor used by `max_size_mismatch` for element-wise operands
/// whose shapes cannot be broadcast together.
const MAX_ERROR_SIZE_MISMATCH: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.MAX.SIZE_MISMATCH",
    identifier: Some("RunMat:max:SizeMismatch"),
    when: "Element-wise operands are not broadcast-compatible.",
    message: "max: size mismatch",
};
293
/// Error descriptor used by `max_internal_error` for failures that are not
/// the caller's fault (gather, tensor construction, and similar).
const MAX_ERROR_INTERNAL: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.MAX.INTERNAL",
    identifier: Some("RunMat:max:Internal"),
    when: "Execution fails due to gather, provider, allocation, or conversion internals.",
    message: "max: internal failure",
};
300
/// All error descriptors published through `MAX_DESCRIPTOR`.
const MAX_ERRORS: [BuiltinErrorDescriptor; 4] = [
    MAX_ERROR_INVALID_ARGUMENT,
    MAX_ERROR_INVALID_INPUT,
    MAX_ERROR_SIZE_MISMATCH,
    MAX_ERROR_INTERNAL,
];
307
/// Public descriptor for `max`, referenced from the `runtime_builtin`
/// registration. It bundles the signature table and the error table and
/// declares that outputs depend on the number of outputs the caller requests.
pub const MAX_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
    signatures: &MAX_SIGNATURES,
    output_mode: BuiltinOutputMode::ByRequestedOutputCount,
    completion_policy: BuiltinCompletionPolicy::Public,
    errors: &MAX_ERRORS,
};
314
315fn max_descriptor_error_with_message(
316 message: impl Into<String>,
317 error: &'static BuiltinErrorDescriptor,
318) -> RuntimeError {
319 let mut builder = build_runtime_error(message).with_builtin(NAME);
320 if let Some(identifier) = error.identifier {
321 builder = builder.with_identifier(identifier);
322 }
323 builder.build()
324}
325
326fn max_descriptor_error_with_detail(
327 error: &'static BuiltinErrorDescriptor,
328 detail: impl AsRef<str>,
329) -> RuntimeError {
330 max_descriptor_error_with_message(format!("{}: {}", error.message, detail.as_ref()), error)
331}
332
/// Shorthand for an `RM.MAX.INVALID_ARGUMENT` error with the given detail text.
fn max_invalid_argument(detail: impl AsRef<str>) -> RuntimeError {
    max_descriptor_error_with_detail(&MAX_ERROR_INVALID_ARGUMENT, detail)
}
336
/// Shorthand for an `RM.MAX.INVALID_INPUT` error with the given detail text.
fn max_invalid_input(detail: impl AsRef<str>) -> RuntimeError {
    max_descriptor_error_with_detail(&MAX_ERROR_INVALID_INPUT, detail)
}
340
/// Shorthand for an `RM.MAX.SIZE_MISMATCH` error with the given detail text.
fn max_size_mismatch(detail: impl AsRef<str>) -> RuntimeError {
    max_descriptor_error_with_detail(&MAX_ERROR_SIZE_MISMATCH, detail)
}
344
/// Shorthand for an `RM.MAX.INTERNAL` error with the given detail text.
fn max_internal_error(detail: impl AsRef<str>) -> RuntimeError {
    max_descriptor_error_with_detail(&MAX_ERROR_INTERNAL, detail)
}
348
349use crate::builtins::common::arg_tokens::tokens_from_values;
350use crate::builtins::common::broadcast::BroadcastPlan;
351use crate::builtins::common::random_args::{complex_tensor_into_value, keyword_of};
352use crate::builtins::common::spec::{
353 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, FusionError,
354 FusionExprContext, FusionKernelTemplate, GpuOpKind, ProviderHook, ReductionNaN,
355 ResidencyPolicy, ScalarType, ShapeRequirements,
356};
357use crate::builtins::common::{
358 gpu_helpers,
359 shape::{is_scalar_shape, normalize_scalar_shape},
360 tensor,
361};
362use crate::builtins::math::reduction::type_resolvers::min_max_type;
363
// GPU execution spec registered for `max`.
//
// Declares `max` as a reduction with two provider hooks (`reduce_max_dim`,
// `reduce_max`) for F32/F64 data. `accepts_nan_mode` is false and `nan_mode`
// is `Include`; this matches `reduction_max_gpu`, which declines the device
// path for `omitnan`, a non-auto comparison method, or linear indexing and
// lets the host implementation handle those requests.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::math::reduction::max")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "max",
    op_kind: GpuOpKind::Reduction,
    supported_precisions: &[ScalarType::F32, ScalarType::F64],
    broadcast: BroadcastSemantics::Matlab,
    provider_hooks: &[
        ProviderHook::Reduction {
            name: "reduce_max_dim",
        },
        ProviderHook::Reduction {
            name: "reduce_max",
        },
    ],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::NewHandle,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: Some(256),
    workgroup_size: Some(256),
    accepts_nan_mode: false,
    notes:
        "Providers should implement reduce_max_dim / reduce_max. Requests that require omitnan, comparisonmethod overrides, or complex inputs fall back to the host implementation.",
};
387
// Fusion spec registered for `max`.
//
// Only a reduction template is provided (`elementwise` is `None`). The WGSL
// body folds the first fused input into `accumulator` with the WGSL `max`
// builtin and reports `FusionError::MissingInput(0)` when no input is bound.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::math::reduction::max")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "max",
    shape: ShapeRequirements::BroadcastCompatible,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: Some(FusionKernelTemplate {
        scalar_precisions: &[ScalarType::F32, ScalarType::F64],
        wgsl_body: |ctx: &FusionExprContext| {
            let input = ctx.inputs.first().ok_or(FusionError::MissingInput(0))?;
            Ok(format!("accumulator = max(accumulator, {input});"))
        },
    }),
    emits_nan: true,
    notes: "Fusion planner emits canonical reduction kernels; providers may substitute custom WGSL via reduce_max_dim hooks.",
};
404
405#[derive(Debug, Clone)]
407pub struct MaxEvaluation {
408 values: Value,
409 indices: Value,
410}
411
412impl MaxEvaluation {
413 pub fn into_value(self) -> Value {
415 self.values
416 }
417
418 pub fn into_pair(self) -> (Value, Value) {
420 (self.values, self.indices)
421 }
422
423 pub fn indices_value(&self) -> Value {
425 self.indices.clone()
426 }
427}
428
429#[runtime_builtin(
430 name = "max",
431 category = "math/reduction",
432 summary = "Return maximum elements along dimensions or pairwise comparisons.",
433 keywords = "max,maximum,reduction,gpu,comparisonmethod,omitnan",
434 accel = "reduction",
435 type_resolver(max_type),
436 descriptor(crate::builtins::math::reduction::max::MAX_DESCRIPTOR),
437 builtin_path = "crate::builtins::math::reduction::max"
438)]
439async fn max_builtin(value: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
440 let eval = evaluate(value, &rest).await?;
441 if let Some(out_count) = crate::output_count::current_output_count() {
442 if out_count == 0 {
443 return Ok(Value::OutputList(Vec::new()));
444 }
445 if out_count == 1 {
446 return Ok(Value::OutputList(vec![eval.into_value()]));
447 }
448 let (values, indices) = eval.into_pair();
449 return Ok(crate::output_count::output_list_with_padding(
450 out_count,
451 vec![values, indices],
452 ));
453 }
454 Ok(eval.into_value())
455}
456
457pub async fn evaluate(value: Value, rest: &[Value]) -> BuiltinResult<MaxEvaluation> {
459 let parsed = parse_call(rest).await?;
460 if std::env::var("RUNMAT_DEBUG_MAX").is_ok() {
461 let call_label = match &parsed {
462 ParsedCall::Reduction(_) => "reduction",
463 ParsedCall::Elementwise(_) => "elementwise",
464 };
465 let first_arg = rest.first().map(debug_value_kind).unwrap_or("None");
466 tracing::debug!(
467 call_type = call_label,
468 rest_len = rest.len(),
469 first_arg = first_arg,
470 "[runmat-debug-max]"
471 );
472 }
473 match parsed {
474 ParsedCall::Elementwise(args) => elementwise_max(value, args).await,
475 ParsedCall::Reduction(args) => reduction_max(value, args).await,
476 }
477}
478
/// Outcome of parsing the arguments that follow the first operand.
#[derive(Debug, Clone)]
enum ParsedCall {
    /// `max(A)` or `max(A, [], ...)`: reduce `A` along one or more dimensions.
    Reduction(ReductionArgs),
    /// `max(A, B, ...)`: pairwise maximum of two operands.
    Elementwise(ElementwiseArgs),
}
484
/// Options controlling a reduction call.
#[derive(Debug, Clone)]
struct ReductionArgs {
    /// Which dimension(s) to reduce over.
    selection: DimSelection,
    /// NaN handling (`"includenan"` / `"omitnan"`).
    nan_mode: ReductionNaN,
    /// Requested `ComparisonMethod`.
    comparison: ComparisonMethod,
    /// Set by the `"linear"` flag; makes the reported indices linear indices
    /// into the full input.
    linear_index: bool,
}

impl Default for ReductionArgs {
    /// Defaults used for a plain `max(A)`: automatic dimension, NaNs included,
    /// automatic comparison, per-dimension indices.
    fn default() -> Self {
        Self {
            selection: DimSelection::Auto,
            nan_mode: ReductionNaN::Include,
            comparison: ComparisonMethod::Auto,
            linear_index: false,
        }
    }
}
503
/// Dimension selection for a reduction.
#[derive(Debug, Clone)]
enum DimSelection {
    /// No explicit dimension; the implementation picks a default from the
    /// input shape.
    Auto,
    /// A single explicit dimension (as supplied by the caller, one-based).
    Dim(usize),
    /// An explicit list of dimensions; `parse_dimension_tensor` stores it
    /// de-duplicated in first-seen order.
    Vec(Vec<usize>),
    /// Reduce over every element (selected by the `"all"` and `"linear"` flags).
    All,
}
511
/// Value of the `ComparisonMethod` option, as parsed by `parse_comparison_method`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ComparisonMethod {
    /// `"auto"` (the default).
    Auto,
    /// `"real"`.
    Real,
    /// `"abs"` or `"magnitude"`.
    Abs,
}
518
/// Options controlling an element-wise call (`max(A, B, ...)`).
#[derive(Debug, Clone)]
struct ElementwiseArgs {
    /// The second operand `B`.
    other: Value,
    /// Requested `ComparisonMethod`.
    comparison: ComparisonMethod,
}
524
525async fn parse_call(rest: &[Value]) -> BuiltinResult<ParsedCall> {
526 if rest.is_empty() {
527 return Ok(ParsedCall::Reduction(ReductionArgs::default()));
528 }
529
530 let first = &rest[0];
531 if !is_empty_placeholder(first) {
532 let comparison = parse_elementwise_options(&rest[1..])?;
533 return Ok(ParsedCall::Elementwise(ElementwiseArgs {
534 other: first.clone(),
535 comparison,
536 }));
537 }
538
539 let mut args = ReductionArgs::default();
540 parse_reduction_options(&mut args, &rest[1..]).await?;
541 Ok(ParsedCall::Reduction(args))
542}
543
544fn debug_value_kind(value: &Value) -> &'static str {
545 match value {
546 Value::Num(_) => "Num",
547 Value::Int(_) => "Int",
548 Value::Bool(_) => "Bool",
549 Value::Tensor(t) => {
550 if t.data.is_empty() {
551 "Tensor(empty)"
552 } else {
553 "Tensor"
554 }
555 }
556 Value::GpuTensor(_) => "GpuTensor",
557 Value::String(_) => "String",
558 Value::CharArray(_) => "CharArray",
559 Value::StringArray(sa) => {
560 if sa.data.is_empty() {
561 "StringArray(empty)"
562 } else {
563 "StringArray"
564 }
565 }
566 Value::LogicalArray(l) => {
567 if l.data.is_empty() {
568 "LogicalArray(empty)"
569 } else {
570 "LogicalArray"
571 }
572 }
573 Value::Cell(c) => {
574 if c.data.is_empty() {
575 "Cell(empty)"
576 } else {
577 "Cell"
578 }
579 }
580 _ => "Other",
581 }
582}
583
/// Reports whether `value` can stand in for the `[]` placeholder of
/// `max(A, [], ...)`.
///
/// Any empty tensor, logical array, string array, char array, cell, or empty
/// string counts. Every other variant (scalars and GPU tensors included) is
/// not a placeholder.
fn is_empty_placeholder(value: &Value) -> bool {
    match value {
        Value::Tensor(t) => t.data.is_empty(),
        Value::LogicalArray(l) => l.data.is_empty(),
        Value::StringArray(sa) => sa.data.is_empty(),
        Value::CharArray(ca) => ca.data.is_empty(),
        Value::Cell(cell) => cell.data.is_empty(),
        Value::String(s) => s.is_empty(),
        _ => false,
    }
}
595
596async fn parse_reduction_options(args: &mut ReductionArgs, rest: &[Value]) -> BuiltinResult<()> {
597 let mut idx = 0usize;
598 let mut selection_set = !matches!(args.selection, DimSelection::Auto);
599 let mut comparison_set = matches!(args.comparison, ComparisonMethod::Auto);
600 let tokens = tokens_from_values(rest);
601 while idx < rest.len() {
602 if let Some(crate::builtins::common::arg_tokens::ArgToken::String(text)) = tokens.get(idx) {
603 match text.as_str() {
604 "omitnan" => {
605 args.nan_mode = ReductionNaN::Omit;
606 idx += 1;
607 continue;
608 }
609 "includenan" => {
610 args.nan_mode = ReductionNaN::Include;
611 idx += 1;
612 continue;
613 }
614 "all" => {
615 if selection_set {
616 return Err(max_invalid_argument(
617 "max: 'all' cannot be combined with an explicit dimension",
618 ));
619 }
620 args.selection = DimSelection::All;
621 selection_set = true;
622 idx += 1;
623 continue;
624 }
625 _ => {}
626 }
627 }
628 if let Some(keyword) = keyword_of(&rest[idx]) {
629 match keyword.as_str() {
630 "omitnan" => {
631 args.nan_mode = ReductionNaN::Omit;
632 idx += 1;
633 continue;
634 }
635 "includenan" => {
636 args.nan_mode = ReductionNaN::Include;
637 idx += 1;
638 continue;
639 }
640 "all" => {
641 if selection_set {
642 return Err(max_invalid_argument(
643 "max: 'all' cannot be combined with an explicit dimension",
644 ));
645 }
646 args.selection = DimSelection::All;
647 selection_set = true;
648 idx += 1;
649 continue;
650 }
651 "linear" => {
652 if selection_set {
653 return Err(max_invalid_argument(
654 "max: 'linear' cannot be combined with an explicit dimension",
655 ));
656 }
657 args.selection = DimSelection::All;
658 args.linear_index = true;
659 selection_set = true;
660 idx += 1;
661 continue;
662 }
663 "comparisonmethod" => {
664 let Some(value) = rest.get(idx + 1) else {
665 return Err(max_invalid_argument(
666 "max: expected a value after 'ComparisonMethod'",
667 ));
668 };
669 args.comparison = parse_comparison_method(value)?;
670 comparison_set = true;
671 idx += 2;
672 continue;
673 }
674 _ => {}
675 }
676 }
677
678 if !selection_set {
679 if let Some(selection) = parse_dimension_value(&rest[idx]).await? {
680 args.selection = selection;
681 selection_set = true;
682 idx += 1;
683 continue;
684 }
685 }
686
687 return Err(max_invalid_argument(format!(
688 "max: unrecognised argument {:?}",
689 rest[idx]
690 )));
691 }
692
693 if !comparison_set {
694 args.comparison = ComparisonMethod::Auto;
695 }
696
697 Ok(())
698}
699
700fn parse_elementwise_options(rest: &[Value]) -> BuiltinResult<ComparisonMethod> {
701 let mut comparison = ComparisonMethod::Auto;
702 let mut comparison_set = false;
703 let mut idx = 0usize;
704 while idx < rest.len() {
705 if let Some(keyword) = keyword_of(&rest[idx]) {
706 match keyword.as_str() {
707 "comparisonmethod" => {
708 let Some(value) = rest.get(idx + 1) else {
709 return Err(max_invalid_argument(
710 "max: expected a value after 'ComparisonMethod'",
711 ));
712 };
713 comparison = parse_comparison_method(value)?;
714 comparison_set = true;
715 idx += 2;
716 continue;
717 }
718 "omitnan" | "includenan" | "all" | "linear" => {
719 return Err(max_invalid_argument(format!(
720 "max: '{}' is only supported for reduction calls",
721 keyword
722 )));
723 }
724 _ => {}
725 }
726 }
727 return Err(max_invalid_argument(format!(
728 "max: unrecognised argument {:?}",
729 rest[idx]
730 )));
731 }
732 if !comparison_set {
733 comparison = ComparisonMethod::Auto;
734 }
735 Ok(comparison)
736}
737
738fn parse_comparison_method(value: &Value) -> BuiltinResult<ComparisonMethod> {
739 let Some(keyword) = keyword_of(value) else {
740 return Err(max_invalid_argument(
741 "max: 'ComparisonMethod' expects a string value",
742 ));
743 };
744 match keyword.as_str() {
745 "auto" => Ok(ComparisonMethod::Auto),
746 "abs" | "magnitude" => Ok(ComparisonMethod::Abs),
747 "real" => Ok(ComparisonMethod::Real),
748 other => Err(max_invalid_argument(format!(
749 "max: unsupported ComparisonMethod '{other}'"
750 ))),
751 }
752}
753
754async fn parse_dimension_value(value: &Value) -> BuiltinResult<Option<DimSelection>> {
755 match value {
756 Value::Int(_) | Value::Num(_) => tensor::dimension_from_value_async(value, "max", false)
757 .await
758 .map_err(map_scalar_dim_error)
759 .map(|dim| dim.map(DimSelection::Dim)),
760 Value::Tensor(t) => parse_dimension_tensor(value, &t.shape).await,
761 Value::LogicalArray(logical) => parse_dimension_tensor(value, &logical.shape).await,
762 Value::GpuTensor(_) => Err(max_invalid_argument(
763 "max: dimension arguments must reside on the host",
764 )),
765 _ => Ok(None),
766 }
767}
768
769async fn parse_dimension_tensor(
770 value: &Value,
771 shape: &[usize],
772) -> BuiltinResult<Option<DimSelection>> {
773 if tensor::element_count(shape) == 0 {
774 return Ok(Some(DimSelection::Auto));
775 }
776 let is_vector = shape.len() == 1
777 || shape.get(0).copied().unwrap_or(1) == 1
778 || shape.get(1).copied().unwrap_or(1) == 1;
779 if !is_vector {
780 return Err(max_invalid_argument(
781 "max: dimension vector must be a row or column vector",
782 ));
783 }
784 let dims = tensor::dims_from_value_async(value)
785 .await
786 .map_err(map_vector_dim_error)?;
787 let Some(dims) = dims else {
788 return Ok(None);
789 };
790 if dims.is_empty() {
791 return Ok(Some(DimSelection::Auto));
792 }
793 let mut seen = BTreeSet::new();
794 let mut uniq = Vec::with_capacity(dims.len());
795 for dim in dims {
796 if dim < 1 {
797 return Err(max_invalid_argument("max: dimension indices must be >= 1"));
798 }
799 if seen.insert(dim) {
800 uniq.push(dim);
801 }
802 }
803 Ok(Some(DimSelection::Vec(uniq)))
804}
805
806fn map_scalar_dim_error(message: String) -> RuntimeError {
807 if message.contains("integer") {
808 return max_invalid_argument("max: dimension must be integral");
809 }
810 max_invalid_argument(message)
811}
812
813fn map_vector_dim_error(message: String) -> RuntimeError {
814 if message.contains("non-negative") {
815 return max_invalid_argument("max: dimension indices must be >= 1");
816 }
817 if message.contains("finite") {
818 return max_invalid_argument("max: dimension entries must be finite");
819 }
820 if message.contains("integer") {
821 return max_invalid_argument("max: dimension entries must be integers");
822 }
823 max_invalid_argument(message)
824}
825
826async fn reduction_max(value: Value, args: ReductionArgs) -> BuiltinResult<MaxEvaluation> {
827 match value {
828 Value::GpuTensor(handle) => {
829 if let Some(eval) = reduction_max_gpu(handle.clone(), &args).await? {
830 return Ok(eval);
831 }
832 let tensor = gpu_helpers::gather_tensor_async(&handle)
834 .await
835 .map_err(|e| max_internal_error(format!("max: {e}")))?;
836 reduction_max_host(Value::Tensor(tensor), &args)
837 }
838 other => reduction_max_host(other, &args),
839 }
840}
841
842async fn reduction_max_gpu(
843 handle: GpuTensorHandle,
844 args: &ReductionArgs,
845) -> BuiltinResult<Option<MaxEvaluation>> {
846 #[cfg(all(test, feature = "wgpu"))]
847 {
848 if handle.device_id != 0 {
849 let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
850 runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
851 );
852 }
853 }
854 if args.nan_mode == ReductionNaN::Omit {
855 log::trace!("max: gpu path disabled (nan_mode=omit)");
856 return Ok(None);
857 }
858 if args.comparison != ComparisonMethod::Auto {
859 log::trace!("max: gpu path disabled (comparison != auto)");
860 return Ok(None);
861 }
862 if args.linear_index {
863 log::trace!("max: gpu path disabled (linear_index=true)");
864 return Ok(None);
865 }
866 let provider = match runmat_accelerate_api::provider() {
867 Some(p) => p,
868 None => {
869 log::trace!(
870 "max: gpu path unavailable (provider() is None) handle_shape={:?} device_id={}",
871 handle.shape,
872 handle.device_id
873 );
874 return Ok(None);
875 }
876 };
877 let target_dim = match args.selection {
878 DimSelection::Auto => default_dimension_from_shape(&handle.shape),
879 DimSelection::Dim(dim) => dim,
880 DimSelection::Vec(ref dims) if dims.len() == 1 => dims[0],
881 DimSelection::All => {
882 if handle.shape.len() <= 1 {
883 1
884 } else {
885 return Ok(None);
886 }
887 }
888 _ => return Ok(None),
889 };
890 if target_dim == 0 {
891 return Ok(None);
892 }
893 let zero_based = target_dim.saturating_sub(1);
895 if zero_based >= handle.shape.len() {
896 return Ok(None);
897 }
898 log::trace!(
899 "max: attempting reduce_max_dim dim={} (zero_based={}) shape={:?} device_id={}",
900 target_dim,
901 zero_based,
902 handle.shape,
903 handle.device_id
904 );
905 match provider.reduce_max_dim(&handle, zero_based).await {
906 Ok(ReduceDimResult { values, indices }) => Ok(Some(MaxEvaluation {
907 values: Value::GpuTensor(values),
908 indices: Value::GpuTensor(indices),
909 })),
910 Err(err) => {
911 log::trace!("max: reduce_max_dim failed: {err}");
912 Ok(None)
913 }
914 }
915}
916
917fn reduction_max_host(value: Value, args: &ReductionArgs) -> BuiltinResult<MaxEvaluation> {
918 match materialize_for_max("max", value)? {
919 InputData::Real(tensor) => reduce_real_tensor(tensor, args),
920 InputData::Complex(tensor) => reduce_complex_tensor(tensor, args),
921 }
922}
923
/// Dense host data accepted by the host `max` kernels, as produced by
/// `materialize_for_max`.
enum InputData {
    /// Real-valued data (numeric, logical, or scalar inputs).
    Real(Tensor),
    /// Complex-valued data (complex scalar or complex tensor inputs).
    Complex(ComplexTensor),
}
928
929fn materialize_for_max(name: &str, value: Value) -> BuiltinResult<InputData> {
930 match value {
931 Value::Tensor(t) => Ok(InputData::Real(t)),
932 Value::LogicalArray(logical) => {
933 let tensor = tensor::logical_to_tensor(&logical).map_err(max_invalid_input)?;
934 Ok(InputData::Real(tensor))
935 }
936 Value::Num(n) => {
937 let tensor = Tensor::new(vec![n], vec![1, 1])
938 .map_err(|e| max_internal_error(format!("{name}: {e}")))?;
939 Ok(InputData::Real(tensor))
940 }
941 Value::Int(i) => {
942 let tensor = Tensor::new(vec![i.to_f64()], vec![1, 1])
943 .map_err(|e| max_internal_error(format!("{name}: {e}")))?;
944 Ok(InputData::Real(tensor))
945 }
946 Value::Bool(b) => {
947 let tensor = Tensor::new(vec![if b { 1.0 } else { 0.0 }], vec![1, 1])
948 .map_err(|e| max_internal_error(format!("{name}: {e}")))?;
949 Ok(InputData::Real(tensor))
950 }
951 Value::Complex(re, im) => {
952 let tensor = ComplexTensor::new(vec![(re, im)], vec![1, 1])
953 .map_err(|e| max_internal_error(format!("{name}: {e}")))?;
954 Ok(InputData::Complex(tensor))
955 }
956 Value::ComplexTensor(ct) => Ok(InputData::Complex(ct)),
957 Value::String(_)
958 | Value::StringArray(_)
959 | Value::CharArray(_)
960 | Value::SparseTensor(_)
961 | Value::Cell(_) => Err(max_invalid_input(format!(
962 "{name}: expected numeric or logical dense input"
963 ))),
964 Value::GpuTensor(_) => Err(max_internal_error(format!(
965 "{name}: internal error – GPU tensors must be gathered before host execution"
966 ))),
967 Value::Object(_) | Value::HandleObject(_) | Value::Struct(_) | Value::Listener(_) => {
968 Err(max_invalid_input(format!("{name}: unsupported input type")))
969 }
970 Value::FunctionHandle(_)
971 | Value::ExternalFunctionHandle(_)
972 | Value::MethodFunctionHandle(_)
973 | Value::BoundFunctionHandle { .. }
974 | Value::Closure(_)
975 | Value::ClassRef(_)
976 | Value::MException(_)
977 | Value::OutputList(_) => Err(max_invalid_input(format!("{name}: unsupported input type"))),
978 }
979}
980
981fn reduce_real_tensor(tensor: Tensor, args: &ReductionArgs) -> BuiltinResult<MaxEvaluation> {
982 let shape = tensor.shape.clone();
983 if tensor.data.is_empty() {
984 let output_shape = resolve_output_shape(&shape, &args.selection, &[])?;
985 let values = Tensor::new(Vec::new(), output_shape.clone())
986 .map_err(|e| max_internal_error(format!("max: {e}")))?;
987 let indices = Tensor::new(Vec::new(), output_shape)
988 .map_err(|e| max_internal_error(format!("max: {e}")))?;
989 return Ok(MaxEvaluation {
990 values: tensor::tensor_into_value(values),
991 indices: tensor::tensor_into_value(indices),
992 });
993 }
994 let resolved = resolve_reduction_dims(&shape, &args.selection)?;
995 let output_shape = resolved.output_shape.clone();
996 let output_len = tensor::element_count(&output_shape);
997
998 if output_len == 0 {
999 let values = Tensor::new(Vec::new(), output_shape.clone())
1000 .map_err(|e| max_internal_error(format!("max: {e}")))?;
1001 let indices = Tensor::new(Vec::new(), output_shape)
1002 .map_err(|e| max_internal_error(format!("max: {e}")))?;
1003 return Ok(MaxEvaluation {
1004 values: tensor::tensor_into_value(values),
1005 indices: tensor::tensor_into_value(indices),
1006 });
1007 }
1008
1009 let strides = compute_strides(&shape);
1010 let output_strides = compute_strides(&output_shape);
1011 let dims_mask = resolved.dims_mask.clone();
1012 let reduce_strides = resolved.reduce_strides.clone();
1013
1014 let mut best = vec![BestReal::new(); output_len];
1015 let mut coords = vec![0usize; shape.len()];
1016 for &value in &tensor.data {
1017 let out_idx = map_output_index(&coords, &output_strides, &dims_mask);
1018 let reduce_idx = map_reduce_index(
1019 &coords,
1020 &resolved.reduced_dims,
1021 &reduce_strides,
1022 resolved.reduce_all,
1023 );
1024 let full_idx = map_linear_index(&coords, &strides);
1025
1026 update_best_real(
1027 &mut best[out_idx],
1028 value,
1029 reduce_idx,
1030 full_idx,
1031 args.nan_mode,
1032 args.comparison,
1033 );
1034 increment_coords(&mut coords, &shape);
1035 }
1036
1037 let mut values = vec![0.0f64; output_len];
1038 let mut indices = vec![0.0f64; output_len];
1039
1040 for (i, entry) in best.iter().enumerate() {
1041 if entry.nan_fixed {
1042 values[i] = f64::NAN;
1043 indices[i] = if args.linear_index || resolved.reduce_all {
1044 (entry.full_index + 1) as f64
1045 } else if resolved.reduced_dims.is_empty() {
1046 1.0
1047 } else {
1048 (entry.reduce_index + 1) as f64
1049 };
1050 continue;
1051 }
1052 if !entry.has_value {
1053 values[i] = f64::NAN;
1054 indices[i] = f64::NAN;
1055 continue;
1056 }
1057 values[i] = entry.value;
1058 indices[i] = if args.linear_index || resolved.reduce_all {
1059 (entry.full_index + 1) as f64
1060 } else if resolved.reduced_dims.is_empty() {
1061 1.0
1062 } else {
1063 (entry.reduce_index + 1) as f64
1064 };
1065 }
1066
1067 let value_tensor = Tensor::new(values, output_shape.clone())
1068 .map_err(|e| max_internal_error(format!("max: {e}")))?;
1069 let index_tensor =
1070 Tensor::new(indices, output_shape).map_err(|e| max_internal_error(format!("max: {e}")))?;
1071
1072 Ok(MaxEvaluation {
1073 values: tensor::tensor_into_value(value_tensor),
1074 indices: tensor::tensor_into_value(index_tensor),
1075 })
1076}
1077
1078fn reduce_complex_tensor(
1079 tensor: ComplexTensor,
1080 args: &ReductionArgs,
1081) -> BuiltinResult<MaxEvaluation> {
1082 let shape = tensor.shape.clone();
1083 if tensor.data.is_empty() {
1084 let output_shape = resolve_output_shape(&shape, &args.selection, &[])?;
1085 let values = ComplexTensor::new(Vec::new(), output_shape.clone())
1086 .map_err(|e| max_internal_error(format!("max: {e}")))?;
1087 let indices = Tensor::new(Vec::new(), output_shape)
1088 .map_err(|e| max_internal_error(format!("max: {e}")))?;
1089 return Ok(MaxEvaluation {
1090 values: complex_tensor_into_value(values),
1091 indices: tensor::tensor_into_value(indices),
1092 });
1093 }
1094
1095 let resolved = resolve_reduction_dims(&shape, &args.selection)?;
1096 let output_shape = resolved.output_shape.clone();
1097 let output_len = tensor::element_count(&output_shape);
1098
1099 if output_len == 0 {
1100 let values = ComplexTensor::new(Vec::new(), output_shape.clone())
1101 .map_err(|e| max_internal_error(format!("max: {e}")))?;
1102 let indices = Tensor::new(Vec::new(), output_shape)
1103 .map_err(|e| max_internal_error(format!("max: {e}")))?;
1104 return Ok(MaxEvaluation {
1105 values: complex_tensor_into_value(values),
1106 indices: tensor::tensor_into_value(indices),
1107 });
1108 }
1109
1110 let strides = compute_strides(&shape);
1111 let output_strides = compute_strides(&output_shape);
1112 let dims_mask = resolved.dims_mask.clone();
1113 let reduce_strides = resolved.reduce_strides.clone();
1114
1115 let mut best = vec![BestComplex::new(); output_len];
1116 let mut coords = vec![0usize; shape.len()];
1117
1118 for &(re, im) in &tensor.data {
1119 let out_idx = map_output_index(&coords, &output_strides, &dims_mask);
1120 let reduce_idx = map_reduce_index(
1121 &coords,
1122 &resolved.reduced_dims,
1123 &reduce_strides,
1124 resolved.reduce_all,
1125 );
1126 let full_idx = map_linear_index(&coords, &strides);
1127 update_best_complex(
1128 &mut best[out_idx],
1129 (re, im),
1130 reduce_idx,
1131 full_idx,
1132 args.nan_mode,
1133 args.comparison,
1134 );
1135 increment_coords(&mut coords, &shape);
1136 }
1137
1138 let mut values = vec![(0.0f64, 0.0f64); output_len];
1139 let mut indices = vec![0.0f64; output_len];
1140
1141 for (i, entry) in best.iter().enumerate() {
1142 if entry.nan_fixed {
1143 values[i] = (f64::NAN, f64::NAN);
1144 indices[i] = if args.linear_index || resolved.reduce_all {
1145 (entry.full_index + 1) as f64
1146 } else if resolved.reduced_dims.is_empty() {
1147 1.0
1148 } else {
1149 (entry.reduce_index + 1) as f64
1150 };
1151 continue;
1152 }
1153 if !entry.has_value {
1154 values[i] = (f64::NAN, f64::NAN);
1155 indices[i] = f64::NAN;
1156 continue;
1157 }
1158 values[i] = entry.value;
1159 indices[i] = if args.linear_index || resolved.reduce_all {
1160 (entry.full_index + 1) as f64
1161 } else if resolved.reduced_dims.is_empty() {
1162 1.0
1163 } else {
1164 (entry.reduce_index + 1) as f64
1165 };
1166 }
1167
1168 let value_tensor = ComplexTensor::new(values, output_shape.clone())
1169 .map_err(|e| max_internal_error(format!("max: {e}")))?;
1170 let index_tensor =
1171 Tensor::new(indices, output_shape).map_err(|e| max_internal_error(format!("max: {e}")))?;
1172 Ok(MaxEvaluation {
1173 values: complex_tensor_into_value(value_tensor),
1174 indices: tensor::tensor_into_value(index_tensor),
1175 })
1176}
1177
/// Running best candidate for one output slot of a real-valued reduction.
#[derive(Debug, Clone)]
struct BestReal {
    /// Value of the current candidate.
    value: f64,
    /// Zero-based position of the candidate within the reduced sub-space.
    reduce_index: usize,
    /// Zero-based linear position of the candidate in the whole input.
    full_index: usize,
    /// Whether any candidate has been recorded yet.
    has_value: bool,
    /// Whether a NaN has been locked in as the result for this slot.
    nan_fixed: bool,
}

impl BestReal {
    /// Creates a slot that has not seen any candidate yet.
    fn new() -> Self {
        let (has_value, nan_fixed) = (false, false);
        Self {
            value: 0.0,
            reduce_index: 0,
            full_index: 0,
            has_value,
            nan_fixed,
        }
    }
}
1198
/// Running best candidate for one output slot of a complex-valued reduction.
#[derive(Debug, Clone)]
struct BestComplex {
    /// Current candidate as a `(real, imaginary)` pair.
    value: (f64, f64),
    /// Zero-based position of the candidate within the reduced sub-space.
    reduce_index: usize,
    /// Zero-based linear position of the candidate in the whole input.
    full_index: usize,
    /// Whether any candidate has been recorded yet.
    has_value: bool,
    /// Whether a NaN has been locked in as the result for this slot.
    nan_fixed: bool,
}

impl BestComplex {
    /// Creates a slot that has not seen any candidate yet.
    fn new() -> Self {
        let (has_value, nan_fixed) = (false, false);
        Self {
            value: (0.0, 0.0),
            reduce_index: 0,
            full_index: 0,
            has_value,
            nan_fixed,
        }
    }
}
1219
1220fn resolve_output_shape(
1221 shape: &[usize],
1222 selection: &DimSelection,
1223 reduced_dims: &[usize],
1224) -> BuiltinResult<Vec<usize>> {
1225 if is_scalar_shape(shape) {
1226 return Ok(normalize_scalar_shape(shape));
1227 }
1228 let mut output = shape.to_vec();
1229 match selection {
1230 DimSelection::All => {
1231 output.fill(1);
1232 }
1233 _ => {
1234 for &dim in reduced_dims {
1235 if dim < output.len() {
1236 output[dim] = 1;
1237 }
1238 }
1239 }
1240 }
1241 Ok(output)
1242}
1243
/// Result of resolving a dimension selection against a concrete input shape,
/// as produced by `resolve_reduction_dims`.
struct ResolvedDims {
    /// Shape of the reduction output.
    output_shape: Vec<usize>,
    /// Zero-based dimensions being reduced, sorted and de-duplicated.
    reduced_dims: Vec<usize>,
    /// True when every dimension is reduced (also set for scalar shapes).
    reduce_all: bool,
    /// Per-dimension flag, `true` where the dimension is reduced.
    dims_mask: Vec<bool>,
    /// Strides for linearising coordinates over `reduced_dims` only.
    reduce_strides: Vec<usize>,
}
1251
1252fn resolve_reduction_dims(
1253 shape: &[usize],
1254 selection: &DimSelection,
1255) -> BuiltinResult<ResolvedDims> {
1256 if is_scalar_shape(shape) {
1257 return Ok(ResolvedDims {
1258 output_shape: normalize_scalar_shape(shape),
1259 reduced_dims: Vec::new(),
1260 reduce_all: true,
1261 dims_mask: Vec::new(),
1262 reduce_strides: Vec::new(),
1263 });
1264 }
1265
1266 let mut reduced_dims = match selection {
1267 DimSelection::Auto => {
1268 let mut dim = None;
1269 for (index, &len) in shape.iter().enumerate() {
1270 if len > 1 {
1271 dim = Some(index);
1272 break;
1273 }
1274 }
1275 vec![dim.unwrap_or(0)]
1276 }
1277 DimSelection::Dim(dim) => {
1278 if *dim == 0 {
1279 return Err(max_invalid_argument("max: dimension must be >= 1"));
1280 }
1281 let index = dim.saturating_sub(1);
1282 if index >= shape.len() {
1283 Vec::new()
1284 } else {
1285 vec![index]
1286 }
1287 }
1288 DimSelection::Vec(dims) => {
1289 if dims.is_empty() {
1290 Vec::new()
1291 } else {
1292 dims.iter()
1293 .filter_map(|dim| {
1294 if *dim == 0 {
1295 None
1296 } else {
1297 let idx = dim - 1;
1298 if idx < shape.len() {
1299 Some(idx)
1300 } else {
1301 None
1302 }
1303 }
1304 })
1305 .collect()
1306 }
1307 }
1308 DimSelection::All => (0..shape.len()).collect(),
1309 };
1310
1311 reduced_dims.sort_unstable();
1312 reduced_dims.dedup();
1313
1314 let reduce_all = !reduced_dims.is_empty()
1315 && reduced_dims.len() == shape.len()
1316 && reduced_dims.iter().enumerate().all(|(i, &d)| i == d);
1317
1318 let output_shape = resolve_output_shape(shape, selection, &reduced_dims)?;
1319 let mut dims_mask = vec![false; shape.len()];
1320 for &dim in &reduced_dims {
1321 if dim < dims_mask.len() {
1322 dims_mask[dim] = true;
1323 }
1324 }
1325 let reduce_strides = compute_subspace_strides(shape, &reduced_dims);
1326
1327 Ok(ResolvedDims {
1328 output_shape,
1329 reduced_dims,
1330 reduce_all,
1331 dims_mask,
1332 reduce_strides,
1333 })
1334}
1335
/// Returns the linear stride of every dimension of `shape`, with the first
/// dimension varying fastest.
///
/// Zero-length dimensions count as length 1 and the running product
/// saturates instead of overflowing.
fn compute_strides(shape: &[usize]) -> Vec<usize> {
    shape
        .iter()
        .scan(1usize, |running, &len| {
            let current = *running;
            *running = running.saturating_mul(len.max(1));
            Some(current)
        })
        .collect()
}
1345
/// Returns strides for linearising coordinates restricted to `dims`.
///
/// Dimensions are taken in the order given. A dimension missing from `shape`
/// or of length zero counts as length 1, and the running product saturates.
fn compute_subspace_strides(shape: &[usize], dims: &[usize]) -> Vec<usize> {
    dims.iter()
        .scan(1usize, |running, &dim| {
            let current = *running;
            let extent = shape.get(dim).copied().unwrap_or(1).max(1);
            *running = running.saturating_mul(extent);
            Some(current)
        })
        .collect()
}
1359
/// Maps input coordinates to the linear index of the output slot they
/// contribute to.
///
/// Dimensions flagged in `dims_mask` contribute coordinate 0; a dimension
/// missing from the mask is treated as not reduced. Empty `coords` map to 0.
fn map_output_index(coords: &[usize], output_strides: &[usize], dims_mask: &[bool]) -> usize {
    if coords.is_empty() {
        return 0;
    }
    output_strides
        .iter()
        .enumerate()
        .fold(0usize, |acc, (dim, &stride)| {
            let reduced = dims_mask.get(dim).copied().unwrap_or(false);
            let coord = if reduced { 0 } else { coords[dim] };
            acc.saturating_add(coord.saturating_mul(stride))
        })
}
1375
/// Linearises `coords` over the reduced dimensions only.
///
/// Returns 0 when nothing is reduced or when `reduce_all` is set. Reduced
/// dimensions with no matching coordinate or stride are skipped.
fn map_reduce_index(
    coords: &[usize],
    reduced_dims: &[usize],
    reduce_strides: &[usize],
    reduce_all: bool,
) -> usize {
    if reduce_all || reduced_dims.is_empty() {
        return 0;
    }
    reduced_dims
        .iter()
        .enumerate()
        .filter_map(|(pos, &dim)| {
            let coord = coords.get(dim)?;
            let stride = reduce_strides.get(pos)?;
            Some(coord.saturating_mul(*stride))
        })
        .fold(0usize, |acc, term| acc.saturating_add(term))
}
1399
/// Converts coordinates to a linear index using the supplied strides.
///
/// Only as many dimensions as both slices provide are combined, and the
/// arithmetic saturates instead of overflowing.
fn map_linear_index(coords: &[usize], strides: &[usize]) -> usize {
    let mut linear = 0usize;
    for (&coord, &stride) in coords.iter().zip(strides) {
        linear = linear.saturating_add(coord.saturating_mul(stride));
    }
    linear
}
1408
/// Advances `coords` to the next position, first dimension varying fastest.
///
/// Zero-length dimensions are skipped. After the last position the
/// coordinates wrap back to all zeros.
fn increment_coords(coords: &mut [usize], shape: &[usize]) {
    for (dim, coord) in coords.iter_mut().enumerate() {
        let extent = shape[dim];
        if extent == 0 {
            continue;
        }
        *coord += 1;
        if *coord < extent {
            return;
        }
        // Carry into the next dimension.
        *coord = 0;
    }
}
1421
1422fn update_best_real(
1423 best: &mut BestReal,
1424 value: f64,
1425 reduce_index: usize,
1426 full_index: usize,
1427 nan_mode: ReductionNaN,
1428 comparison: ComparisonMethod,
1429) {
1430 if value.is_nan() {
1431 match nan_mode {
1432 ReductionNaN::Include => {
1433 if !best.nan_fixed {
1434 best.value = f64::NAN;
1435 best.reduce_index = reduce_index;
1436 best.full_index = full_index;
1437 best.has_value = true;
1438 best.nan_fixed = true;
1439 }
1440 }
1441 ReductionNaN::Omit => {}
1442 }
1443 return;
1444 }
1445 if best.nan_fixed {
1446 return;
1447 }
1448
1449 if !best.has_value {
1450 best.value = value;
1451 best.reduce_index = reduce_index;
1452 best.full_index = full_index;
1453 best.has_value = true;
1454 return;
1455 }
1456
1457 if should_replace_real(best.value, value, comparison) {
1458 best.value = value;
1459 best.reduce_index = reduce_index;
1460 best.full_index = full_index;
1461 }
1462}
1463
1464fn update_best_complex(
1465 best: &mut BestComplex,
1466 value: (f64, f64),
1467 reduce_index: usize,
1468 full_index: usize,
1469 nan_mode: ReductionNaN,
1470 comparison: ComparisonMethod,
1471) {
1472 if value.0.is_nan() || value.1.is_nan() {
1473 match nan_mode {
1474 ReductionNaN::Include => {
1475 if !best.nan_fixed {
1476 best.value = (f64::NAN, f64::NAN);
1477 best.reduce_index = reduce_index;
1478 best.full_index = full_index;
1479 best.has_value = true;
1480 best.nan_fixed = true;
1481 }
1482 }
1483 ReductionNaN::Omit => {}
1484 }
1485 return;
1486 }
1487 if best.nan_fixed {
1488 return;
1489 }
1490
1491 if !best.has_value {
1492 best.value = value;
1493 best.reduce_index = reduce_index;
1494 best.full_index = full_index;
1495 best.has_value = true;
1496 return;
1497 }
1498
1499 if should_replace_complex(best.value, value, comparison) {
1500 best.value = value;
1501 best.reduce_index = reduce_index;
1502 best.full_index = full_index;
1503 }
1504}
1505
1506fn should_replace_real(current: f64, candidate: f64, comparison: ComparisonMethod) -> bool {
1507 match comparison {
1508 ComparisonMethod::Auto | ComparisonMethod::Real => {
1509 if candidate > current {
1510 return true;
1511 }
1512 if candidate < current {
1513 return false;
1514 }
1515 if candidate == 0.0 && current == 0.0 {
1516 return candidate.is_sign_positive() && !current.is_sign_positive();
1517 }
1518 false
1519 }
1520 ComparisonMethod::Abs => {
1521 let curr_abs = current.abs();
1522 let cand_abs = candidate.abs();
1523 if cand_abs > curr_abs {
1524 return true;
1525 }
1526 if cand_abs < curr_abs {
1527 return false;
1528 }
1529 if candidate > current {
1530 return true;
1531 }
1532 if candidate < current {
1533 return false;
1534 }
1535 if candidate == 0.0 && current == 0.0 {
1536 return candidate.is_sign_positive() && !current.is_sign_positive();
1537 }
1538 false
1539 }
1540 }
1541}
1542
1543fn should_replace_complex(
1544 current: (f64, f64),
1545 candidate: (f64, f64),
1546 comparison: ComparisonMethod,
1547) -> bool {
1548 match comparison {
1549 ComparisonMethod::Auto | ComparisonMethod::Abs => {
1550 compare_complex_auto(current, candidate) == Ordering::Less
1551 }
1552 ComparisonMethod::Real => compare_complex_real(current, candidate) == Ordering::Less,
1553 }
1554}
1555
1556fn compare_complex_auto(a: (f64, f64), b: (f64, f64)) -> Ordering {
1557 let a_mag = magnitude_squared(a);
1558 let b_mag = magnitude_squared(b);
1559 if a_mag < b_mag {
1560 return Ordering::Less;
1561 }
1562 if a_mag > b_mag {
1563 return Ordering::Greater;
1564 }
1565 let a_angle = a.1.atan2(a.0);
1567 let b_angle = b.1.atan2(b.0);
1568 if a_angle < b_angle {
1569 Ordering::Less
1570 } else if a_angle > b_angle {
1571 Ordering::Greater
1572 } else {
1573 Ordering::Equal
1574 }
1575}
1576
1577fn compare_complex_real(a: (f64, f64), b: (f64, f64)) -> Ordering {
1578 if a.0 < b.0 {
1579 return Ordering::Less;
1580 }
1581 if a.0 > b.0 {
1582 return Ordering::Greater;
1583 }
1584 compare_complex_auto(a, b)
1586}
1587
/// Returns `re² + im²` for the complex number `z`, evaluated with a fused
/// multiply-add.
fn magnitude_squared(z: (f64, f64)) -> f64 {
    let (re, im) = z;
    re.mul_add(re, im * im)
}
1591
1592fn default_dimension_from_shape(shape: &[usize]) -> usize {
1593 if is_scalar_shape(shape) {
1594 return 1;
1595 }
1596 for (i, &len) in shape.iter().enumerate() {
1597 if len > 1 {
1598 return i + 1;
1599 }
1600 }
1601 1
1602}
1603
1604async fn elementwise_max(value: Value, args: ElementwiseArgs) -> BuiltinResult<MaxEvaluation> {
1605 let ElementwiseArgs { other, comparison } = args;
1606 match (value, other) {
1607 (Value::GpuTensor(handle_a), Value::GpuTensor(handle_b)) => {
1608 if gpu_tensor_is_scalar(&handle_b) {
1609 if let Some(num) = gpu_tensor_scalar_value(&handle_b).await {
1610 let scalar = Value::Num(num);
1611 if let Some(eval) =
1612 elementwise_max_gpu_scalar_left(&handle_a, &scalar, comparison).await
1613 {
1614 return Ok(eval);
1615 }
1616 if let Ok(ta) = gpu_helpers::gather_tensor_async(&handle_a).await {
1617 if let Ok(eval) = elementwise_real_or_complex(
1618 Value::Tensor(ta),
1619 scalar.clone(),
1620 comparison,
1621 ) {
1622 return Ok(eval);
1623 }
1624 }
1625 return Err(max_internal_error(
1626 "max: elementwise GPU scalar path failed",
1627 ));
1628 }
1629 }
1630 if gpu_tensor_is_scalar(&handle_a) {
1631 if let Some(num) = gpu_tensor_scalar_value(&handle_a).await {
1632 let scalar = Value::Num(num);
1633 if let Some(eval) =
1634 elementwise_max_gpu_scalar_right(&scalar, &handle_b, comparison).await
1635 {
1636 return Ok(eval);
1637 }
1638 if let Ok(tb) = gpu_helpers::gather_tensor_async(&handle_b).await {
1639 if let Ok(eval) = elementwise_real_or_complex(
1640 scalar.clone(),
1641 Value::Tensor(tb),
1642 comparison,
1643 ) {
1644 return Ok(eval);
1645 }
1646 }
1647 return Err(max_internal_error(
1648 "max: elementwise GPU scalar path failed",
1649 ));
1650 }
1651 }
1652 if let Some(eval) = elementwise_max_gpu_pair(&handle_a, &handle_b, comparison).await {
1653 return Ok(eval);
1654 }
1655 if let (Ok(ta), Ok(tb)) = (
1656 gpu_helpers::gather_tensor_async(&handle_a).await,
1657 gpu_helpers::gather_tensor_async(&handle_b).await,
1658 ) {
1659 if let Ok(eval) =
1660 elementwise_real_or_complex(Value::Tensor(ta), Value::Tensor(tb), comparison)
1661 {
1662 return Ok(eval);
1663 }
1664 }
1665 Err(max_internal_error("max: elementwise GPU path failed"))
1666 }
1667 (Value::GpuTensor(handle), other) => {
1668 if let Some(eval) = elementwise_max_gpu_scalar_left(&handle, &other, comparison).await {
1669 return Ok(eval);
1670 }
1671 let t = gpu_helpers::gather_tensor_async(&handle)
1672 .await
1673 .map_err(|_| max_internal_error("max: elementwise GPU scalar path failed"))?;
1674 elementwise_real_or_complex(Value::Tensor(t), other, comparison)
1675 }
1676 (other, Value::GpuTensor(handle)) => {
1677 if let Some(eval) = elementwise_max_gpu_scalar_right(&other, &handle, comparison).await
1678 {
1679 return Ok(eval);
1680 }
1681 let t = gpu_helpers::gather_tensor_async(&handle)
1682 .await
1683 .map_err(|_| max_internal_error("max: elementwise GPU scalar path failed"))?;
1684 elementwise_real_or_complex(other, Value::Tensor(t), comparison)
1685 }
1686 (lhs, rhs) => elementwise_real_or_complex(lhs, rhs, comparison),
1687 }
1688}
1689
1690async fn elementwise_max_gpu_pair(
1691 a: &GpuTensorHandle,
1692 b: &GpuTensorHandle,
1693 comparison: ComparisonMethod,
1694) -> Option<MaxEvaluation> {
1695 if comparison != ComparisonMethod::Auto {
1696 return None;
1697 }
1698 let provider = runmat_accelerate_api::provider()?;
1699 if a.shape == b.shape {
1701 let values = provider.elem_max(a, b).await.ok()?;
1702 if let Ok(mask) = provider.elem_ge(a, b).await {
1704 let indices = gpu_mask_indices(provider, &mask)?;
1705 let _ = provider.free(&mask);
1706 return Some(MaxEvaluation {
1707 values: Value::GpuTensor(values),
1708 indices: Value::GpuTensor(indices),
1709 });
1710 } else {
1711 let ta = gpu_helpers::gather_tensor_async(a).await.ok()?;
1713 let tb = gpu_helpers::gather_tensor_async(b).await.ok()?;
1714 let mut indices = Vec::with_capacity(ta.data.len());
1715 for i in 0..ta.data.len() {
1716 indices.push(if ta.data[i] >= tb.data[i] { 1.0 } else { 2.0 });
1717 }
1718 let index_tensor = Tensor::new(indices, ta.shape.clone()).ok()?;
1719 return Some(MaxEvaluation {
1720 values: Value::GpuTensor(values),
1721 indices: tensor::tensor_into_value(index_tensor),
1722 });
1723 }
1724 }
1725 let (out_shape, reps_a, reps_b) = broadcast_reps(&a.shape, &b.shape)?;
1727 let a_exp = if reps_a.iter().any(|&r| r != 1) {
1728 provider.repmat(a, &reps_a).ok()?
1729 } else {
1730 a.clone()
1731 };
1732 let b_exp = if reps_b.iter().any(|&r| r != 1) {
1733 provider.repmat(b, &reps_b).ok()?
1734 } else {
1735 b.clone()
1736 };
1737 let values = provider.elem_max(&a_exp, &b_exp).await.ok();
1738 let mask = provider.elem_ge(&a_exp, &b_exp).await.ok();
1739 if !std::ptr::eq(&a_exp, a) {
1740 let _ = provider.free(&a_exp);
1741 }
1742 if !std::ptr::eq(&b_exp, b) {
1743 let _ = provider.free(&b_exp);
1744 }
1745 let values = values?;
1746 if values.shape != out_shape {
1747 let _ = provider.free(&values);
1748 return None;
1749 }
1750 let index_tensor = if let Some(mask) = mask {
1751 let mask_host = gpu_helpers::gather_tensor_async(&mask).await.ok()?;
1752 let _ = provider.free(&mask);
1753 let mut indices = Vec::with_capacity(mask_host.data.len());
1754 for &m in &mask_host.data {
1755 indices.push(if m != 0.0 { 1.0 } else { 2.0 });
1756 }
1757 Tensor::new(indices, out_shape).ok()?
1758 } else {
1759 let ta = gpu_helpers::gather_tensor_async(&a_exp).await.ok()?;
1761 let tb = gpu_helpers::gather_tensor_async(&b_exp).await.ok()?;
1762 let mut indices = Vec::with_capacity(ta.data.len());
1763 for i in 0..ta.data.len() {
1764 indices.push(if ta.data[i] >= tb.data[i] { 1.0 } else { 2.0 });
1765 }
1766 Tensor::new(indices, out_shape).ok()?
1767 };
1768 Some(MaxEvaluation {
1769 values: Value::GpuTensor(values),
1770 indices: tensor::tensor_into_value(index_tensor),
1771 })
1772}
1773
/// Computes the broadcast shape of `a` and `b` and the per-dimension
/// repetition counts that expand each operand to it.
///
/// Shapes are padded with trailing 1s to a common rank of at least 1.
/// Returns `(output_shape, reps_a, reps_b)`, or `None` when some dimension
/// has differing extents and neither is 1.
fn broadcast_reps(a: &[usize], b: &[usize]) -> Option<(Vec<usize>, Vec<usize>, Vec<usize>)> {
    let rank = a.len().max(b.len()).max(1);
    let extent_of = |dims: &[usize], axis: usize| dims.get(axis).copied().unwrap_or(1);

    let mut out = Vec::with_capacity(rank);
    let mut reps_a = Vec::with_capacity(rank);
    let mut reps_b = Vec::with_capacity(rank);
    for axis in 0..rank {
        let ad = extent_of(a, axis);
        let bd = extent_of(b, axis);
        let merged = if ad == bd || bd == 1 {
            ad
        } else if ad == 1 {
            bd
        } else {
            return None;
        };
        out.push(merged);
        reps_a.push(if ad == merged { 1 } else { merged });
        reps_b.push(if bd == merged { 1 } else { merged });
    }
    Some((out, reps_a, reps_b))
}
1803
1804async fn elementwise_max_gpu_scalar_left(
1805 a: &GpuTensorHandle,
1806 other: &Value,
1807 comparison: ComparisonMethod,
1808) -> Option<MaxEvaluation> {
1809 if comparison != ComparisonMethod::Auto {
1810 return None;
1811 }
1812 let provider = runmat_accelerate_api::provider()?;
1813 let scalar = extract_scalar(other)?;
1814 let values = if let Ok(fill) = provider.fill_like(a, scalar) {
1816 let vals = provider.elem_max(a, &fill).await.ok();
1817 let _ = provider.free(&fill);
1818 vals?
1819 } else {
1820 provider.scalar_max(a, scalar).ok()?
1821 };
1822 let index_tensor = if let Ok(fill) = provider.fill_like(a, scalar) {
1824 if let Ok(mask) = provider.elem_ge(a, &fill).await {
1825 let _ = provider.free(&fill);
1826 let indices = gpu_mask_indices(provider, &mask)?;
1827 let _ = provider.free(&mask);
1828 return Some(MaxEvaluation {
1829 values: Value::GpuTensor(values),
1830 indices: Value::GpuTensor(indices),
1831 });
1832 } else {
1833 let _ = provider.free(&fill);
1834 let ta = gpu_helpers::gather_tensor_async(a).await.ok()?;
1835 let mut indices = Vec::with_capacity(ta.data.len());
1836 for &v in &ta.data {
1837 indices.push(if v >= scalar { 1.0 } else { 2.0 });
1838 }
1839 Tensor::new(indices, ta.shape.clone()).ok()?
1840 }
1841 } else {
1842 let ta = gpu_helpers::gather_tensor_async(a).await.ok()?;
1843 let mut indices = Vec::with_capacity(ta.data.len());
1844 for &v in &ta.data {
1845 indices.push(if v >= scalar { 1.0 } else { 2.0 });
1846 }
1847 Tensor::new(indices, ta.shape.clone()).ok()?
1848 };
1849 Some(MaxEvaluation {
1850 values: Value::GpuTensor(values),
1851 indices: tensor::tensor_into_value(index_tensor),
1852 })
1853}
1854
1855async fn elementwise_max_gpu_scalar_right(
1856 other: &Value,
1857 b: &GpuTensorHandle,
1858 comparison: ComparisonMethod,
1859) -> Option<MaxEvaluation> {
1860 if comparison != ComparisonMethod::Auto {
1861 return None;
1862 }
1863 let provider = runmat_accelerate_api::provider()?;
1864 let scalar = extract_scalar(other)?;
1865 let values = if let Ok(fill) = provider.fill_like(b, scalar) {
1866 let vals = provider.elem_max(&fill, b).await.ok();
1867 let _ = provider.free(&fill);
1868 vals?
1869 } else {
1870 provider.scalar_max(b, scalar).ok()?
1871 };
1872 let index_tensor = if let Ok(fill) = provider.fill_like(b, scalar) {
1874 if let Ok(mask) = provider.elem_ge(&fill, b).await {
1875 let _ = provider.free(&fill);
1876 let indices = gpu_mask_indices(provider, &mask)?;
1877 let _ = provider.free(&mask);
1878 return Some(MaxEvaluation {
1879 values: Value::GpuTensor(values),
1880 indices: Value::GpuTensor(indices),
1881 });
1882 } else {
1883 let _ = provider.free(&fill);
1884 let tb = gpu_helpers::gather_tensor_async(b).await.ok()?;
1885 let mut indices = Vec::with_capacity(tb.data.len());
1886 for &v in &tb.data {
1887 indices.push(if scalar >= v { 1.0 } else { 2.0 });
1888 }
1889 Tensor::new(indices, tb.shape.clone()).ok()?
1890 }
1891 } else {
1892 let tb = gpu_helpers::gather_tensor_async(b).await.ok()?;
1893 let mut indices = Vec::with_capacity(tb.data.len());
1894 for &v in &tb.data {
1895 indices.push(if scalar >= v { 1.0 } else { 2.0 });
1896 }
1897 Tensor::new(indices, tb.shape.clone()).ok()?
1898 };
1899 Some(MaxEvaluation {
1900 values: Value::GpuTensor(values),
1901 indices: tensor::tensor_into_value(index_tensor),
1902 })
1903}
1904
1905fn extract_scalar(v: &Value) -> Option<f64> {
1906 match v {
1907 Value::Num(n) => Some(*n),
1908 Value::Int(i) => Some(i.to_f64()),
1909 Value::Bool(b) => Some(if *b { 1.0 } else { 0.0 }),
1910 Value::Tensor(t) if t.data.len() == 1 => t.data.first().copied(),
1911 Value::LogicalArray(l) if l.data.len() == 1 => Some(if l.data[0] != 0 { 1.0 } else { 0.0 }),
1912 _ => None,
1913 }
1914}
1915
1916fn gpu_tensor_is_scalar(handle: &GpuTensorHandle) -> bool {
1917 handle.shape.iter().copied().product::<usize>().max(1) == 1
1918}
1919
1920async fn gpu_tensor_scalar_value(handle: &GpuTensorHandle) -> Option<f64> {
1921 let tensor = gpu_helpers::gather_tensor_async(handle).await.ok()?;
1922 tensor.data.first().copied()
1923}
1924
1925fn gpu_mask_indices(
1926 provider: &dyn AccelProvider,
1927 mask: &GpuTensorHandle,
1928) -> Option<GpuTensorHandle> {
1929 let scaled = provider.scalar_mul(mask, -1.0).ok()?;
1930 let shifted = provider.scalar_add(&scaled, 2.0).ok()?;
1931 let _ = provider.free(&scaled);
1932 Some(shifted)
1933}
1934
1935fn elementwise_real_or_complex(
1936 lhs: Value,
1937 rhs: Value,
1938 comparison: ComparisonMethod,
1939) -> BuiltinResult<MaxEvaluation> {
1940 if let Some(eval) = scalar_elementwise_max(&lhs, &rhs, comparison) {
1941 return Ok(eval);
1942 }
1943 match (
1944 materialize_for_max("max", lhs)?,
1945 materialize_for_max("max", rhs)?,
1946 ) {
1947 (InputData::Complex(a), InputData::Complex(b)) => elementwise_complex_max(a, b, comparison),
1948 (InputData::Complex(a), InputData::Real(b)) => {
1949 let converted = promote_real_tensor_to_complex(b);
1950 elementwise_complex_max(a, converted, comparison)
1951 }
1952 (InputData::Real(a), InputData::Complex(b)) => {
1953 let converted = promote_real_tensor_to_complex(a);
1954 elementwise_complex_max(converted, b, comparison)
1955 }
1956 (InputData::Real(a), InputData::Real(b)) => elementwise_real_max(a, b, comparison),
1957 }
1958}
1959
1960fn scalar_real_value(value: &Value) -> Option<f64> {
1961 match value {
1962 Value::Num(n) => Some(*n),
1963 Value::Int(i) => Some(i.to_f64()),
1964 Value::Bool(b) => Some(if *b { 1.0 } else { 0.0 }),
1965 Value::Tensor(t) if t.data.len() == 1 => t.data.first().copied(),
1966 Value::LogicalArray(l) if l.data.len() == 1 => Some(if l.data[0] != 0 { 1.0 } else { 0.0 }),
1967 _ => None,
1968 }
1969}
1970
1971fn scalar_complex_value(value: &Value) -> Option<(f64, f64)> {
1972 match value {
1973 Value::Complex(re, im) => Some((*re, *im)),
1974 Value::ComplexTensor(ct) if ct.data.len() == 1 => ct.data.first().copied(),
1975 _ => None,
1976 }
1977}
1978
1979fn scalar_elementwise_max(
1980 lhs: &Value,
1981 rhs: &Value,
1982 comparison: ComparisonMethod,
1983) -> Option<MaxEvaluation> {
1984 let left = scalar_complex_value(lhs).or_else(|| scalar_real_value(lhs).map(|v| (v, 0.0)))?;
1985 let right = scalar_complex_value(rhs).or_else(|| scalar_real_value(rhs).map(|v| (v, 0.0)))?;
1986 let (ar, ai) = left;
1987 let (br, bi) = right;
1988 if ai != 0.0 || bi != 0.0 {
1989 let (value, origin) = choose_complex_elementwise((ar, ai), (br, bi), comparison);
1990 return Some(MaxEvaluation {
1991 values: Value::Complex(value.0, value.1),
1992 indices: Value::Num(origin),
1993 });
1994 }
1995 let (value, origin) = choose_real_elementwise(ar, br, comparison);
1996 Some(MaxEvaluation {
1997 values: Value::Num(value),
1998 indices: Value::Num(origin),
1999 })
2000}
2001
2002fn elementwise_real_max(
2003 lhs: Tensor,
2004 rhs: Tensor,
2005 comparison: ComparisonMethod,
2006) -> BuiltinResult<MaxEvaluation> {
2007 let plan = BroadcastPlan::new(&lhs.shape, &rhs.shape)
2008 .map_err(|err| max_size_mismatch(format!("max: {err}")))?;
2009 let mut values = vec![0.0f64; plan.len()];
2010 let mut indices = vec![0.0f64; plan.len()];
2011
2012 for (offset, index_a, index_b) in plan.iter() {
2013 let a = lhs.data.get(index_a).copied().unwrap_or(f64::NAN);
2014 let b = rhs.data.get(index_b).copied().unwrap_or(f64::NAN);
2015 let (value, origin) = choose_real_elementwise(a, b, comparison);
2016 values[offset] = value;
2017 indices[offset] = origin;
2018 }
2019
2020 let value_tensor = Tensor::new(values, plan.output_shape().to_vec())
2021 .map_err(|e| max_internal_error(format!("max: {e}")))?;
2022 let index_tensor = Tensor::new(indices, plan.output_shape().to_vec())
2023 .map_err(|e| max_internal_error(format!("max: {e}")))?;
2024
2025 Ok(MaxEvaluation {
2026 values: tensor::tensor_into_value(value_tensor),
2027 indices: tensor::tensor_into_value(index_tensor),
2028 })
2029}
2030
2031fn elementwise_complex_max(
2032 lhs: ComplexTensor,
2033 rhs: ComplexTensor,
2034 comparison: ComparisonMethod,
2035) -> BuiltinResult<MaxEvaluation> {
2036 let plan = BroadcastPlan::new(&lhs.shape, &rhs.shape)
2037 .map_err(|err| max_size_mismatch(format!("max: {err}")))?;
2038 let mut values = vec![(0.0f64, 0.0f64); plan.len()];
2039 let mut indices = vec![0.0f64; plan.len()];
2040
2041 for (offset, index_a, index_b) in plan.iter() {
2042 let a = lhs
2043 .data
2044 .get(index_a)
2045 .copied()
2046 .unwrap_or((f64::NAN, f64::NAN));
2047 let b = rhs
2048 .data
2049 .get(index_b)
2050 .copied()
2051 .unwrap_or((f64::NAN, f64::NAN));
2052 let (value, origin) = choose_complex_elementwise(a, b, comparison);
2053 values[offset] = value;
2054 indices[offset] = origin;
2055 }
2056
2057 let value_tensor = ComplexTensor::new(values, plan.output_shape().to_vec())
2058 .map_err(|e| max_internal_error(format!("max: {e}")))?;
2059 let index_tensor = Tensor::new(indices, plan.output_shape().to_vec())
2060 .map_err(|e| max_internal_error(format!("max: {e}")))?;
2061
2062 Ok(MaxEvaluation {
2063 values: complex_tensor_into_value(value_tensor),
2064 indices: tensor::tensor_into_value(index_tensor),
2065 })
2066}
2067
2068fn promote_real_tensor_to_complex(tensor: Tensor) -> ComplexTensor {
2069 let data = tensor
2070 .data
2071 .iter()
2072 .copied()
2073 .map(|re| (re, 0.0))
2074 .collect::<Vec<_>>();
2075 ComplexTensor {
2076 data,
2077 shape: tensor.shape.clone(),
2078 rows: tensor.rows,
2079 cols: tensor.cols,
2080 }
2081}
2082
2083fn choose_real_elementwise(a: f64, b: f64, comparison: ComparisonMethod) -> (f64, f64) {
2084 match (a.is_nan(), b.is_nan()) {
2085 (true, true) => (f64::NAN, 1.0),
2086 (true, false) => (f64::NAN, 1.0),
2087 (false, true) => (f64::NAN, 2.0),
2088 (false, false) => {
2089 if should_replace_real(a, b, comparison) {
2090 (b, 2.0)
2091 } else {
2092 (a, 1.0)
2093 }
2094 }
2095 }
2096}
2097
2098fn choose_complex_elementwise(
2099 a: (f64, f64),
2100 b: (f64, f64),
2101 comparison: ComparisonMethod,
2102) -> ((f64, f64), f64) {
2103 let a_nan = a.0.is_nan() || a.1.is_nan();
2104 let b_nan = b.0.is_nan() || b.1.is_nan();
2105 match (a_nan, b_nan) {
2106 (true, true) => ((f64::NAN, f64::NAN), 1.0),
2107 (true, false) => ((f64::NAN, f64::NAN), 1.0),
2108 (false, true) => ((f64::NAN, f64::NAN), 2.0),
2109 (false, false) => {
2110 if should_replace_complex(a, b, comparison) {
2111 (b, 2.0)
2112 } else {
2113 (a, 1.0)
2114 }
2115 }
2116 }
2117}
2118
#[cfg(test)]
pub(crate) mod tests {
    //! Unit tests for the `max` builtin: descriptor metadata, reductions,
    //! element-wise comparison, argument parsing, and (with `wgpu`) GPU parity.

    use super::*;
    #[cfg(feature = "wgpu")]
    use crate::builtins::common::test_support;
    use futures::executor::block_on;
    #[cfg(feature = "wgpu")]
    use runmat_accelerate_api::HostTensorView;
    use runmat_builtins::{IntValue, Tensor, Value};

    /// Blocking wrapper around the async builtin entry point.
    fn max_builtin(value: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
        block_on(super::max_builtin(value, rest))
    }

    /// Blocking wrapper around the async evaluator.
    fn evaluate(value: Value, rest: &[Value]) -> BuiltinResult<MaxEvaluation> {
        block_on(super::evaluate(value, rest))
    }

    /// Empty 0x0 tensor standing in for the `[]` placeholder argument.
    fn placeholder() -> Value {
        Value::Tensor(Tensor::new(Vec::<f64>::new(), vec![0, 0]).unwrap())
    }

    /// Evaluates and splits the outcome into its `(values, indices)` pair,
    /// panicking if evaluation fails.
    fn evaluate_pair(value: Value, rest: &[Value]) -> (Value, Value) {
        evaluate(value, rest).expect("evaluate").into_pair()
    }

    /// Extracts a real tensor or panics with `expected {what}, got ...`.
    fn tensor_or_panic(value: Value, what: &str) -> Tensor {
        match value {
            Value::Tensor(t) => t,
            other => panic!("expected {what}, got {other:?}"),
        }
    }

    /// Extracts a real scalar or panics with `expected {what}, got ...`.
    fn num_or_panic(value: Value, what: &str) -> f64 {
        match value {
            Value::Num(v) => v,
            other => panic!("expected {what}, got {other:?}"),
        }
    }

    #[test]
    fn max_type_with_two_args_returns_tensor() {
        let arg_types = [Type::Tensor { shape: None }, Type::Tensor { shape: None }];
        let resolved = max_type(&arg_types, &ResolveContext::new(Vec::new()));
        assert_eq!(resolved, Type::tensor());
    }

    #[test]
    fn max_descriptor_signatures_cover_core_forms() {
        let labels: Vec<&str> = MAX_DESCRIPTOR
            .signatures
            .iter()
            .map(|sig| sig.label)
            .collect();
        for expected in [
            "M = max(A)",
            "[M, I] = max(A)",
            "M = max(A, B)",
            "[M, I] = max(A, B)",
            "M = max(A, [], dim)",
            "M = max(A, [], \"all\")",
            "M = max(A, [], \"ComparisonMethod\", method)",
            "M = max(A, B, \"ComparisonMethod\", method)",
        ] {
            assert!(labels.contains(&expected));
        }
    }

    #[test]
    fn max_descriptor_errors_have_stable_codes() {
        for descriptor in [
            &MAX_ERROR_INVALID_ARGUMENT,
            &MAX_ERROR_INVALID_INPUT,
            &MAX_ERROR_SIZE_MISMATCH,
            &MAX_ERROR_INTERNAL,
        ] {
            assert!(MAX_DESCRIPTOR
                .errors
                .iter()
                .any(|error| error.code == descriptor.code));
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn max_scalar_returns_input() {
        let outcome = max_builtin(Value::Num(5.0), Vec::new()).expect("max");
        assert_eq!(outcome, Value::Num(5.0));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn max_vector_with_indices() {
        let column = Tensor::new(vec![3.0, 1.0, 5.0], vec![3, 1]).unwrap();
        let (best, position) = evaluate_pair(Value::Tensor(column), &[]);
        assert_eq!(best, Value::Num(5.0));
        assert_eq!(position, Value::Num(3.0));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn max_matrix_default_dimension() {
        let matrix = Tensor::new(vec![3.0, 4.0, 1.0, 2.0, 5.0, 6.0], vec![2, 3]).unwrap();
        let (values, indices) = evaluate_pair(Value::Tensor(matrix), &[]);

        let maxima = tensor_or_panic(values, "tensor");
        assert_eq!(maxima.shape, vec![1, 3]);
        assert_eq!(maxima.data, vec![4.0, 2.0, 6.0]);

        let positions = tensor_or_panic(indices, "tensor");
        assert_eq!(positions.data, vec![2.0, 2.0, 2.0]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn max_all_linear_index() {
        // "all" reduces across every element of the 3x4 input.
        let counting: Vec<f64> = (1..=12).map(|n| n as f64).collect::<Vec<_>>();
        let grid = Tensor::new(counting, vec![3, 4]).unwrap();
        let all_args = vec![placeholder(), Value::from("all")];
        let (best, position) = evaluate_pair(Value::Tensor(grid), &all_args);
        assert_eq!(best, Value::Num(12.0));
        assert_eq!(position, Value::Num(12.0));

        // "linear" is exercised on a small row vector.
        let linear_args = vec![placeholder(), Value::from("linear")];
        let pair = Value::Tensor(Tensor::new(vec![2.0, 3.0], vec![1, 2]).unwrap());
        let (best, position) = evaluate_pair(pair, &linear_args);
        assert_eq!(best, Value::Num(3.0));
        assert_eq!(position, Value::Num(2.0));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn max_with_omitnan() {
        let column = Tensor::new(vec![f64::NAN, 4.0, 2.0], vec![3, 1]).unwrap();
        let flags = vec![placeholder(), Value::from("omitnan")];
        let (best, position) = evaluate_pair(Value::Tensor(column), &flags);
        assert_eq!(best, Value::Num(4.0));
        assert_eq!(position, Value::Num(2.0));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn max_omitnan_all_nan_slice() {
        let all_nan = Tensor::new(vec![f64::NAN, f64::NAN], vec![2, 1]).unwrap();
        let flags = vec![placeholder(), Value::from("omitnan")];
        let (values, indices) = evaluate_pair(Value::Tensor(all_nan), &flags);
        assert!(num_or_panic(values, "scalar NaN").is_nan());
        assert!(num_or_panic(indices, "scalar NaN index").is_nan());
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn max_reduction_abs_comparison() {
        let matrix = Tensor::new(vec![1.0, -3.0, -2.0, 4.0], vec![2, 2]).unwrap();
        let options = vec![
            placeholder(),
            Value::from("ComparisonMethod"),
            Value::from("abs"),
        ];
        let (values, indices) = evaluate_pair(Value::Tensor(matrix), &options);

        let maxima = tensor_or_panic(values, "tensor result");
        assert_eq!(maxima.shape, vec![1, 2]);
        assert_eq!(maxima.data, vec![-3.0, 4.0]);

        let positions = tensor_or_panic(indices, "tensor indices");
        assert_eq!(positions.data, vec![2.0, 2.0]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn max_reduction_complex_real_comparison() {
        let column = ComplexTensor::new(vec![(1.0, 2.0), (0.5, 5.0)], vec![2, 1]).expect("tensor");
        let options = vec![
            placeholder(),
            Value::from("ComparisonMethod"),
            Value::from("real"),
        ];
        let (values, indices) = evaluate_pair(Value::ComplexTensor(column), &options);

        let (re, im) = match values {
            Value::Complex(re, im) => (re, im),
            other => panic!("expected complex scalar, got {other:?}"),
        };
        assert!((re - 1.0).abs() < 1e-12);
        assert!((im - 2.0).abs() < 1e-12);
        assert_eq!(indices, Value::Num(1.0));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn max_elementwise_broadcast() {
        let row = Tensor::new(vec![1.0, 4.0, 7.0], vec![1, 3]).unwrap();
        let column = Tensor::new(vec![2.0, 3.0, 5.0], vec![3, 1]).unwrap();
        let (values, origins) = evaluate_pair(Value::Tensor(row), &[Value::Tensor(column)]);

        // Samples the 3x3 result at offsets r, r + 3 and r + 6.
        let strided = |t: &Tensor, r: usize| [t.data[r], t.data[r + 3], t.data[r + 6]];

        let maxima = tensor_or_panic(values, "tensor");
        assert_eq!(maxima.shape, vec![3, 3]);
        assert_eq!(strided(&maxima, 0), [2.0, 4.0, 7.0]);
        assert_eq!(strided(&maxima, 1), [3.0, 4.0, 7.0]);
        assert_eq!(strided(&maxima, 2), [5.0, 5.0, 7.0]);

        let sources = tensor_or_panic(origins, "tensor");
        assert_eq!(sources.shape, vec![3, 3]);
        assert_eq!(strided(&sources, 0), [2.0, 1.0, 1.0]);
        assert_eq!(strided(&sources, 1), [2.0, 1.0, 1.0]);
        assert_eq!(strided(&sources, 2), [2.0, 2.0, 1.0]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn max_elementwise_abs_comparison() {
        let left = Tensor::new(vec![-2.0, 1.0], vec![2, 1]).unwrap();
        let right = Tensor::new(vec![1.5, -3.0], vec![2, 1]).unwrap();
        let rest = vec![
            Value::Tensor(right),
            Value::from("ComparisonMethod"),
            Value::from("abs"),
        ];
        let (values, origins) = evaluate_pair(Value::Tensor(left), &rest);

        let maxima = tensor_or_panic(values, "tensor");
        assert_eq!(maxima.data, vec![-2.0, -3.0]);

        let sources = tensor_or_panic(origins, "tensor");
        assert_eq!(sources.data, vec![1.0, 2.0]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn max_elementwise_rejects_reduction_only_keywords() {
        let left = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
        let right = Tensor::new(vec![3.0, 4.0], vec![2, 1]).unwrap();
        let rest = [Value::Tensor(right), Value::from("omitnan")];
        let failure = evaluate(Value::Tensor(left), &rest).expect_err("expected error");
        assert_eq!(failure.identifier(), MAX_ERROR_INVALID_ARGUMENT.identifier);
        assert!(failure.message().contains("only supported for reduction"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn max_complex_real_comparison() {
        let left = ComplexTensor::new(vec![(1.0, 2.0)], vec![1, 1]).unwrap();
        let right = ComplexTensor::new(vec![(0.5, 5.0)], vec![1, 1]).unwrap();
        let rest = vec![
            Value::ComplexTensor(right),
            Value::from("ComparisonMethod"),
            Value::from("real"),
        ];
        let (winner, origin) = evaluate_pair(Value::ComplexTensor(left), &rest);
        assert_eq!(winner, Value::Complex(1.0, 2.0));
        assert_eq!(origin, Value::Num(1.0));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn max_dimension_argument_parsing() {
        let matrix = Tensor::new(vec![3.0, 4.0, 1.0, 2.0], vec![2, 2]).unwrap();
        let vecdim = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
        let rest = vec![placeholder(), Value::Tensor(vecdim)];
        let (best, position) = evaluate_pair(Value::Tensor(matrix), &rest);
        assert_eq!(best, Value::Num(4.0));
        assert_eq!(position, Value::Num(2.0));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn max_vecdim_duplicate_entries() {
        let matrix = Tensor::new(vec![5.0, 2.0, 7.0, 1.0], vec![2, 2]).unwrap();
        let vecdim = Tensor::new(vec![1.0, 1.0, 2.0], vec![3, 1]).unwrap();
        let rest = vec![placeholder(), Value::Tensor(vecdim)];
        let (best, position) = evaluate_pair(Value::Tensor(matrix), &rest);
        assert_eq!(best, Value::Num(7.0));
        assert_eq!(position, Value::Num(3.0));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn max_dimension_gpu_argument_errors() {
        let column = Tensor::new(vec![3.0, 1.0], vec![2, 1]).unwrap();
        // A device-resident handle is not acceptable as a dimension argument.
        let gpu_dim = Value::GpuTensor(GpuTensorHandle {
            shape: vec![1, 1],
            device_id: 0,
            buffer_id: 42,
        });
        let rest = [placeholder(), gpu_dim];
        let failure = evaluate(Value::Tensor(column), &rest).expect_err("expected error");
        assert_eq!(failure.identifier(), MAX_ERROR_INVALID_ARGUMENT.identifier);
        assert!(failure
            .message()
            .contains("dimension arguments must reside on the host"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn max_invalid_comparison_method_errors() {
        let column = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
        let rest = vec![
            placeholder(),
            Value::from("ComparisonMethod"),
            Value::from("chebyshev"),
        ];
        let failure = evaluate(Value::Tensor(column), &rest).expect_err("expected error");
        assert_eq!(failure.identifier(), MAX_ERROR_INVALID_ARGUMENT.identifier);
        assert!(failure.message().contains("unsupported ComparisonMethod"));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    #[cfg(feature = "wgpu")]
    fn max_gpu_dim1_matches_cpu() {
        let host = Tensor::new(vec![3.0, 1.0, 2.0, 4.0], vec![2, 2]).unwrap();
        let (values_cpu, indices_cpu) = evaluate(Value::Tensor(host.clone()), &[])
            .expect("cpu")
            .into_pair();

        test_support::with_test_provider(|provider| {
            let view = HostTensorView {
                data: &host.data,
                shape: &host.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let (values_gpu, indices_gpu) = evaluate(Value::GpuTensor(handle), &[])
                .expect("gpu")
                .into_pair();

            // Both outputs must stay on the device.
            let on_device = matches!(
                (&values_gpu, &indices_gpu),
                (Value::GpuTensor(_), Value::GpuTensor(_))
            );
            if !on_device {
                panic!(
                    "expected GPU tensors, got {:?}",
                    (&values_gpu, &indices_gpu)
                );
            }

            let gathered_vals = test_support::gather(values_gpu).expect("gather values");
            let gathered_idx = test_support::gather(indices_gpu).expect("gather indices");
            let expected_vals = tensor_or_panic(values_cpu, "tensor values from cpu eval");
            let expected_idx = tensor_or_panic(indices_cpu, "tensor indices from cpu eval");

            assert_eq!(gathered_vals.shape, expected_vals.shape);
            assert_eq!(gathered_vals.data, expected_vals.data);
            assert_eq!(gathered_idx.shape, expected_idx.shape);
            assert_eq!(gathered_idx.data, expected_idx.data);
        });
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn max_dimension_numeric_argument() {
        let matrix = Tensor::new(vec![3.0, 4.0, 1.0, 2.0], vec![2, 2]).unwrap();
        let rest = vec![placeholder(), Value::Num(2.0)];
        let (values, indices) = evaluate_pair(Value::Tensor(matrix), &rest);

        let maxima = tensor_or_panic(values, "tensor");
        assert_eq!(maxima.shape, vec![2, 1]);
        assert_eq!(maxima.data, vec![3.0, 4.0]);

        let positions = tensor_or_panic(indices, "tensor");
        assert_eq!(positions.data, vec![1.0, 1.0]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn max_complex_auto_comparison() {
        let left = ComplexTensor::new(vec![(1.0, 2.0)], vec![1, 1]).unwrap();
        let right = ComplexTensor::new(vec![(2.0, 1.0)], vec![1, 1]).unwrap();
        let (winner, origin) =
            evaluate_pair(Value::ComplexTensor(left), &[Value::ComplexTensor(right)]);
        assert_eq!(winner, Value::Complex(1.0, 2.0));
        assert_eq!(origin, Value::Num(1.0));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn max_scalar_pair_arguments() {
        let outcome = max_builtin(Value::Num(3.0), vec![Value::Num(2.0)]).expect("max");
        assert_eq!(outcome, Value::Num(3.0));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn max_rejects_invalid_dimension() {
        let column = Tensor::new(vec![1.0, 2.0, 3.0], vec![3, 1]).unwrap();
        let rest = vec![placeholder(), Value::Int(IntValue::I32(0))];
        let failure = evaluate(Value::Tensor(column), &rest).expect_err("expected error");
        assert_eq!(failure.identifier(), MAX_ERROR_INVALID_ARGUMENT.identifier);
        assert!(failure.message().contains("dimension must be >= 1"));
    }
}