runmat_accelerate/
reduction_meta.rs

1use crate::graph::{AccelGraph, AccelNode, AccelOpCategory, ValueId};
2use runmat_builtins::{IntValue, Tensor, Type, Value};
3
/// How the accumulated result of a reduction is finished off.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReductionBehavior {
    /// Plain accumulation. `detect_reduction_signature` also falls back to this
    /// variant whenever it does not recognise the builtin.
    SumLike,
    /// Sum-like accumulation followed by a `1 / reduce_len` post-scale.
    MeanLike,
}
9
/// Which dimensions a reduction collapses, as far as the constant inputs reveal.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ReductionAxes {
    /// No constant input could be read as an `"all"` keyword or a dimension list.
    Unspecified,
    /// A constant input spelled the `"all"` keyword.
    All,
    /// Dimension numbers parsed from a constant input, in the order they were given.
    /// Every entry is at least 1; duplicates are not removed.
    Explicit(Vec<usize>),
}
16
17#[derive(Debug, Clone)]
18pub struct ReductionSignature {
19    pub data_input: ValueId,
20    pub dim_arg: Option<ValueId>,
21    pub behavior: ReductionBehavior,
22    pub axes: ReductionAxes,
23}
24
25/// Attempt to derive a generic reduction signature from a builtin reduction node without name checks.
26/// Heuristics:
27/// - Data input: the first input whose type is Tensor; otherwise fall back to the first input.
28/// - Dim arg: the first input (after data) that is a scalar numeric or int constant (by ValueId; resolution is done by callers).
29/// - Behavior: inferred via a minimal registry keyed by builtin name for post-scale choices (mean -> MeanLike). This is centralized here.
30pub fn detect_reduction_signature(
31    graph: &AccelGraph,
32    node: &AccelNode,
33) -> Option<ReductionSignature> {
34    if node.category != AccelOpCategory::Reduction {
35        return None;
36    }
37    let (name_opt, inputs) = match &node.label {
38        crate::graph::AccelNodeLabel::Builtin { name } => {
39            (Some(name.as_str()), node.inputs.as_slice())
40        }
41        _ => (None, node.inputs.as_slice()),
42    };
43    if inputs.is_empty() {
44        return None;
45    }
46
47    // 1) Pick data input: first tensor-typed input, else inputs[0]
48    let mut data_input = inputs[0];
49    for &vid in inputs {
50        if let Some(info) = graph.value(vid) {
51            if matches!(info.ty, Type::Tensor { .. }) {
52                data_input = vid;
53                break;
54            }
55        }
56    }
57
58    // 2) Pick dim argument if present: first scalar numeric/int constant input after data input
59    let mut dim_arg: Option<ValueId> = None;
60    for &vid in inputs {
61        if vid == data_input {
62            continue;
63        }
64        if let Some(info) = graph.value(vid) {
65            // constants resolved by callers; here we only pass the ValueId through
66            if matches!(info.origin, crate::graph::ValueOrigin::Constant) {
67                // allow numeric or integer constants (type system may already have Num/Int)
68                if matches!(info.ty, Type::Num | Type::Int) {
69                    dim_arg = Some(vid);
70                    break;
71                }
72            }
73        }
74    }
75
76    // 3) Behavior via centralized minimal registry
77    let behavior = name_opt
78        .map(|n| match n.to_ascii_lowercase().as_str() {
79            "mean" => ReductionBehavior::MeanLike,
80            // Add more here as behavior needs expand; default to SumLike when unsure
81            "sum" => ReductionBehavior::SumLike,
82            _ => ReductionBehavior::SumLike,
83        })
84        .unwrap_or(ReductionBehavior::SumLike);
85
86    let mut axes = ReductionAxes::Unspecified;
87    // Inspect the dimension argument (if constant) first
88    if let Some(dim_vid) = dim_arg {
89        if let Some(value) = graph.value(dim_vid).and_then(|info| info.constant.clone()) {
90            if value_is_all_keyword(&value) {
91                axes = ReductionAxes::All;
92            } else if let Some(dims) = parse_dims_from_value(&value) {
93                axes = ReductionAxes::Explicit(dims);
94            }
95        }
96    }
97    // Fallback: look for any constant input resembling a dimension/all keyword
98    if matches!(axes, ReductionAxes::Unspecified) {
99        for &vid in inputs {
100            if vid == data_input {
101                continue;
102            }
103            if let Some(value) = graph.value(vid).and_then(|info| info.constant.clone()) {
104                if value_is_all_keyword(&value) {
105                    axes = ReductionAxes::All;
106                    break;
107                } else if let Some(dims) = parse_dims_from_value(&value) {
108                    axes = ReductionAxes::Explicit(dims);
109                    break;
110                }
111            }
112        }
113    }
114
115    Some(ReductionSignature {
116        data_input,
117        dim_arg,
118        behavior,
119        axes,
120    })
121}
122
123pub fn value_is_all_keyword(value: &Value) -> bool {
124    match value {
125        Value::String(s) => s.eq_ignore_ascii_case("all"),
126        Value::CharArray(ca) => {
127            if ca.rows == 1 {
128                let candidate: String = ca.data.iter().collect();
129                candidate.trim().eq_ignore_ascii_case("all")
130            } else {
131                false
132            }
133        }
134        Value::StringArray(sa) => sa.data.len() == 1 && sa.data[0].eq_ignore_ascii_case("all"),
135        _ => false,
136    }
137}
138
139fn parse_dims_from_value(value: &Value) -> Option<Vec<usize>> {
140    match value {
141        Value::Int(int_val) => parse_single_int(int_val),
142        Value::Num(n) => parse_single_float(*n),
143        Value::Tensor(t) => parse_tensor_dims(t),
144        _ => None,
145    }
146}
147
148fn parse_single_int(int_val: &IntValue) -> Option<Vec<usize>> {
149    let raw = int_val.to_i64();
150    if raw >= 1 {
151        Some(vec![raw as usize])
152    } else {
153        None
154    }
155}
156
/// Converts a floating-point constant into a single-element dimension list.
///
/// The value must be finite, lie within `f64::EPSILON` of its rounded value, and round to at
/// least 1; anything else yields `None`.
fn parse_single_float(value: f64) -> Option<Vec<usize>> {
    let nearest = value.round();
    let is_integral = (nearest - value).abs() <= f64::EPSILON;
    if value.is_finite() && is_integral && nearest >= 1.0 {
        Some(vec![nearest as usize])
    } else {
        None
    }
}
167
168fn parse_tensor_dims(tensor: &Tensor) -> Option<Vec<usize>> {
169    if tensor.data.is_empty() {
170        return None;
171    }
172    let mut dims = Vec::with_capacity(tensor.data.len());
173    for value in &tensor.data {
174        if let Some(parsed) = parse_single_float(*value) {
175            dims.extend(parsed);
176        } else {
177            return None;
178        }
179    }
180    if dims.is_empty() {
181        None
182    } else {
183        Some(dims)
184    }
185}