runmat_accelerate/
reduction_meta.rs1use crate::graph::{AccelGraph, AccelNode, AccelOpCategory, ValueId};
2use runmat_builtins::{IntValue, Tensor, Type, Value};
3
/// How a reduction accumulates its data, inferred from the builtin's name.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReductionBehavior {
    /// Selected for `sum` and used as the fallback for every reduction not recognised as `mean`.
    SumLike,
    /// Selected for the `mean` builtin (name matched ASCII case-insensitively).
    MeanLike,
}
9
/// The axes a reduction runs over, as far as they can be resolved from constant arguments.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ReductionAxes {
    /// No constant argument could be interpreted as axes.
    Unspecified,
    /// The `"all"` keyword was supplied.
    All,
    /// Dimensions exactly as written in the call; every entry is >= 1
    /// (presumably 1-based MATLAB dimensions — confirm with consumers).
    Explicit(Vec<usize>),
}
16
17#[derive(Debug, Clone)]
18pub struct ReductionSignature {
19 pub data_input: ValueId,
20 pub dim_arg: Option<ValueId>,
21 pub behavior: ReductionBehavior,
22 pub axes: ReductionAxes,
23}
24
25pub fn detect_reduction_signature(
31 graph: &AccelGraph,
32 node: &AccelNode,
33) -> Option<ReductionSignature> {
34 if node.category != AccelOpCategory::Reduction {
35 return None;
36 }
37 let (name_opt, inputs) = match &node.label {
38 crate::graph::AccelNodeLabel::Builtin { name } => {
39 (Some(name.as_str()), node.inputs.as_slice())
40 }
41 _ => (None, node.inputs.as_slice()),
42 };
43 if inputs.is_empty() {
44 return None;
45 }
46
47 let mut data_input = inputs[0];
49 for &vid in inputs {
50 if let Some(info) = graph.value(vid) {
51 if matches!(info.ty, Type::Tensor { .. }) {
52 data_input = vid;
53 break;
54 }
55 }
56 }
57
58 let mut dim_arg: Option<ValueId> = None;
60 for &vid in inputs {
61 if vid == data_input {
62 continue;
63 }
64 if let Some(info) = graph.value(vid) {
65 if matches!(info.origin, crate::graph::ValueOrigin::Constant) {
67 if matches!(info.ty, Type::Num | Type::Int) {
69 dim_arg = Some(vid);
70 break;
71 }
72 }
73 }
74 }
75
76 let behavior = name_opt
78 .map(|n| match n.to_ascii_lowercase().as_str() {
79 "mean" => ReductionBehavior::MeanLike,
80 "sum" => ReductionBehavior::SumLike,
82 _ => ReductionBehavior::SumLike,
83 })
84 .unwrap_or(ReductionBehavior::SumLike);
85
86 let mut axes = ReductionAxes::Unspecified;
87 if let Some(dim_vid) = dim_arg {
89 if let Some(value) = graph.value(dim_vid).and_then(|info| info.constant.clone()) {
90 if value_is_all_keyword(&value) {
91 axes = ReductionAxes::All;
92 } else if let Some(dims) = parse_dims_from_value(&value) {
93 axes = ReductionAxes::Explicit(dims);
94 }
95 }
96 }
97 if matches!(axes, ReductionAxes::Unspecified) {
99 for &vid in inputs {
100 if vid == data_input {
101 continue;
102 }
103 if let Some(value) = graph.value(vid).and_then(|info| info.constant.clone()) {
104 if value_is_all_keyword(&value) {
105 axes = ReductionAxes::All;
106 break;
107 } else if let Some(dims) = parse_dims_from_value(&value) {
108 axes = ReductionAxes::Explicit(dims);
109 break;
110 }
111 }
112 }
113 }
114
115 Some(ReductionSignature {
116 data_input,
117 dim_arg,
118 behavior,
119 axes,
120 })
121}
122
123pub fn value_is_all_keyword(value: &Value) -> bool {
124 match value {
125 Value::String(s) => s.eq_ignore_ascii_case("all"),
126 Value::CharArray(ca) => {
127 if ca.rows == 1 {
128 let candidate: String = ca.data.iter().collect();
129 candidate.trim().eq_ignore_ascii_case("all")
130 } else {
131 false
132 }
133 }
134 Value::StringArray(sa) => sa.data.len() == 1 && sa.data[0].eq_ignore_ascii_case("all"),
135 _ => false,
136 }
137}
138
139fn parse_dims_from_value(value: &Value) -> Option<Vec<usize>> {
140 match value {
141 Value::Int(int_val) => parse_single_int(int_val),
142 Value::Num(n) => parse_single_float(*n),
143 Value::Tensor(t) => parse_tensor_dims(t),
144 _ => None,
145 }
146}
147
148fn parse_single_int(int_val: &IntValue) -> Option<Vec<usize>> {
149 let raw = int_val.to_i64();
150 if raw >= 1 {
151 Some(vec![raw as usize])
152 } else {
153 None
154 }
155}
156
/// Converts a float constant into a one-element dimension list.
///
/// The value must be finite, lie within `f64::EPSILON` of an integer, and round to
/// at least 1. Anything else yields `None`.
fn parse_single_float(value: f64) -> Option<Vec<usize>> {
    let nearest = value.round();
    // Evaluated only for finite inputs (short-circuit below), so no NaN reaches the comparison.
    let accepted = value.is_finite() && (nearest - value).abs() <= f64::EPSILON && nearest >= 1.0;
    accepted.then(|| vec![nearest as usize])
}
167
168fn parse_tensor_dims(tensor: &Tensor) -> Option<Vec<usize>> {
169 if tensor.data.is_empty() {
170 return None;
171 }
172 let mut dims = Vec::with_capacity(tensor.data.len());
173 for value in &tensor.data {
174 if let Some(parsed) = parse_single_float(*value) {
175 dims.extend(parsed);
176 } else {
177 return None;
178 }
179 }
180 if dims.is_empty() {
181 None
182 } else {
183 Some(dims)
184 }
185}