1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3use std::fmt;
4
5use runmat_builtins::{Type, Value as BuiltinValue};
6
/// Identifier of an [`AccelNode`]; [`AccelGraph::node`] uses it directly as a position in `AccelGraph::nodes`.
pub type NodeId = u32;
/// Identifier of a value; [`AccelGraph::value`] uses it directly as a position in `AccelGraph::values`.
pub type ValueId = u32;
9
/// Graph of operations (nodes) and the values flowing between them, used for
/// acceleration analyses such as fusion-group detection
/// ([`AccelGraph::detect_fusion_groups`]).
///
/// Nodes and values are looked up positionally by their id (see
/// [`AccelGraph::node`] and [`AccelGraph::value`]).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AccelGraph {
    /// Operations in the graph; [`AccelGraph::node`] indexes this by [`NodeId`].
    pub nodes: Vec<AccelNode>,
    /// Value metadata; [`AccelGraph::value`] indexes this by [`ValueId`].
    pub values: Vec<ValueInfo>,
    /// Variable binding recorded for a value, keyed by value id.
    pub var_bindings: HashMap<ValueId, VarBinding>,
    /// Variable binding recorded for a node, keyed by node id.
    // NOTE(review): presumably the variable a node's result is stored to — confirm with the graph builder.
    pub node_bindings: HashMap<NodeId, VarBinding>,
}
17
18impl AccelGraph {
19 pub fn is_empty(&self) -> bool {
20 self.nodes.is_empty()
21 }
22
23 pub fn node(&self, id: NodeId) -> Option<&AccelNode> {
24 self.nodes.get(id as usize)
25 }
26
27 pub fn value(&self, id: ValueId) -> Option<&ValueInfo> {
28 self.values.get(id as usize)
29 }
30
31 pub fn var_binding(&self, id: ValueId) -> Option<&VarBinding> {
32 self.var_bindings.get(&id)
33 }
34
35 pub fn node_binding(&self, id: NodeId) -> Option<&VarBinding> {
36 self.node_bindings.get(&id)
37 }
38
39 pub fn detect_fusion_groups(&self) -> Vec<crate::fusion::FusionGroup> {
40 crate::fusion::detect_fusion_groups(self)
41 }
42}
43
/// A single operation in an [`AccelGraph`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AccelNode {
    /// This node's id.
    // NOTE(review): `AccelGraph::node` looks nodes up by position, so this presumably
    // equals the node's index in `AccelGraph::nodes` — confirm with the builder.
    pub id: NodeId,
    /// Which operation the node performs.
    pub label: AccelNodeLabel,
    /// Coarse operation class, queried by [`AccelNode::is_elementwise`] and
    /// [`AccelNode::is_reduction`].
    pub category: AccelOpCategory,
    /// Ids of the values the node consumes.
    pub inputs: Vec<ValueId>,
    /// Ids of the values the node produces.
    pub outputs: Vec<ValueId>,
    /// Range of instructions this node corresponds to.
    pub span: InstrSpan,
    /// Additional classification tags.
    pub tags: Vec<AccelGraphTag>,
}
54
55impl AccelNode {
56 pub fn is_elementwise(&self) -> bool {
57 self.category == AccelOpCategory::Elementwise
58 }
59
60 pub fn is_reduction(&self) -> bool {
61 self.category == AccelOpCategory::Reduction
62 }
63}
64
/// A `start`/`end` pair of instruction positions that an [`AccelNode`] was built from.
// NOTE(review): whether `end` is inclusive or exclusive is not established in this
// file — confirm with the code that constructs spans.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct InstrSpan {
    /// First instruction position of the span.
    pub start: usize,
    /// Final instruction position of the span (see inclusivity note above the struct).
    pub end: usize,
}
70
/// Coarse classification of an [`AccelNode`]'s operation.
// Variant order is part of the serialized form for index-based serde formats;
// append new variants rather than reordering.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum AccelOpCategory {
    /// Elementwise operation (checked by [`AccelNode::is_elementwise`]).
    Elementwise,
    /// Reduction (checked by [`AccelNode::is_reduction`]).
    Reduction,
    /// Matrix multiplication.
    MatMul,
    /// Transposition.
    Transpose,
    /// Anything not covered by the categories above.
    Other,
}
79
/// Identifies which operation an [`AccelNode`] performs.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum AccelNodeLabel {
    /// A primitive operator.
    Primitive(PrimitiveOp),
    /// A builtin function, identified by its name.
    Builtin { name: String },
    /// The operation is not identified.
    Unknown,
}
86
/// Primitive operators that can label an [`AccelNode`] (via [`AccelNodeLabel::Primitive`]).
///
/// Its `Display` form is the variant name (e.g. `ElemMul`).
// Variant order is part of the serialized form for index-based serde formats;
// append new variants rather than reordering.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub enum PrimitiveOp {
    // Arithmetic operators.
    Add,
    Sub,
    Mul,
    Div,
    Pow,
    // Unary operators.
    Neg,
    UPlus,
    // Elementwise variants of the arithmetic operators.
    ElemMul,
    ElemDiv,
    ElemPow,
    ElemLeftDiv,
    // Comparisons.
    LessEqual,
    Less,
    Greater,
    GreaterEqual,
    Equal,
    NotEqual,
    // Transposition.
    Transpose,
}
108
109impl fmt::Display for PrimitiveOp {
110 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
111 let name = match self {
112 PrimitiveOp::Add => "Add",
113 PrimitiveOp::Sub => "Sub",
114 PrimitiveOp::Mul => "Mul",
115 PrimitiveOp::Div => "Div",
116 PrimitiveOp::Pow => "Pow",
117 PrimitiveOp::Neg => "Neg",
118 PrimitiveOp::UPlus => "UPlus",
119 PrimitiveOp::ElemMul => "ElemMul",
120 PrimitiveOp::ElemDiv => "ElemDiv",
121 PrimitiveOp::ElemPow => "ElemPow",
122 PrimitiveOp::ElemLeftDiv => "ElemLeftDiv",
123 PrimitiveOp::LessEqual => "LessEqual",
124 PrimitiveOp::Less => "Less",
125 PrimitiveOp::Greater => "Greater",
126 PrimitiveOp::GreaterEqual => "GreaterEqual",
127 PrimitiveOp::Equal => "Equal",
128 PrimitiveOp::NotEqual => "NotEqual",
129 PrimitiveOp::Transpose => "Transpose",
130 };
131 write!(f, "{}", name)
132 }
133}
134
/// Classification tags attached to an [`AccelNode`] through its `tags` list.
// Variant order is part of the serialized form for index-based serde formats;
// append new variants rather than reordering.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub enum AccelGraphTag {
    Unary,
    Elementwise,
    Reduction,
    MatMul,
    Transpose,
    ArrayConstruct,
}
144
/// Metadata describing one value in an [`AccelGraph`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ValueInfo {
    /// This value's id.
    pub id: ValueId,
    /// Where the value comes from.
    pub origin: ValueOrigin,
    /// Inferred static type; refined by `ValueInfo::update_type`.
    pub ty: Type,
    /// Shape derived from `ty`; `update_type` recomputes it from the refined type.
    pub shape: ShapeInfo,
    /// Constant payload, if known. Excluded from serialization: serde fills
    /// `#[serde(skip)]` fields with `Default::default()` (here `None`) on deserialize.
    #[serde(skip)]
    pub constant: Option<BuiltinValue>,
}
154
/// Identifies a variable slot: its kind (global or local) and its index.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct VarBinding {
    /// Whether the slot is global or local.
    pub kind: VarKind,
    /// Slot index within the given kind.
    pub index: usize,
}
160
161impl ValueInfo {
162 pub fn update_type(&mut self, ty: &Type) {
163 self.ty = match (&self.ty, ty) {
164 (Type::Unknown, other) => other.clone(),
165 (existing, other) => existing.unify(other),
166 };
167 self.shape = ShapeInfo::from_type(&self.ty);
168 }
169}
170
/// Where a value in an [`AccelGraph`] comes from.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ValueOrigin {
    /// A variable slot of the given kind and index.
    Variable { kind: VarKind, index: usize },
    /// Output number `output` of node `node`.
    NodeOutput { node: NodeId, output: usize },
    /// A constant.
    Constant,
    /// Origin not determined.
    Unknown,
}
178
/// Kind of a variable slot referenced by [`VarBinding`] and [`ValueOrigin::Variable`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum VarKind {
    Global,
    Local,
}
184
/// Coarse shape knowledge about a value.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum ShapeInfo {
    /// Nothing is known about the shape.
    Unknown,
    /// A single scalar.
    Scalar,
    /// Per-dimension extents; `None` marks an unknown extent. `ShapeInfo::from_type`
    /// produces an empty list for a tensor type that carries no shape information.
    Tensor(Vec<Option<usize>>),
}
191
192impl ShapeInfo {
193 pub fn from_type(ty: &Type) -> Self {
194 match ty {
195 Type::Int | Type::Num | Type::Bool | Type::Logical => ShapeInfo::Scalar,
196 Type::Tensor { shape } => match shape {
197 Some(dims) => ShapeInfo::Tensor(dims.clone()),
198 None => ShapeInfo::Tensor(Vec::new()),
199 },
200 _ => ShapeInfo::Unknown,
201 }
202 }
203
204 pub fn unify(&self, other: &ShapeInfo) -> ShapeInfo {
205 match (self, other) {
206 (ShapeInfo::Unknown, _) | (_, ShapeInfo::Unknown) => ShapeInfo::Unknown,
207 (ShapeInfo::Scalar, ShapeInfo::Scalar) => ShapeInfo::Scalar,
208 (ShapeInfo::Scalar, ShapeInfo::Tensor(dims))
209 | (ShapeInfo::Tensor(dims), ShapeInfo::Scalar) => ShapeInfo::Tensor(dims.clone()),
210 (ShapeInfo::Tensor(a), ShapeInfo::Tensor(b)) => ShapeInfo::Tensor(unify_dims(a, b)),
211 }
212 }
213
214 pub fn to_type(&self) -> Type {
215 match self {
216 ShapeInfo::Unknown => Type::Unknown,
217 ShapeInfo::Scalar => Type::Num,
218 ShapeInfo::Tensor(dims) => {
219 if dims.is_empty() {
220 Type::Tensor { shape: None }
221 } else {
222 Type::Tensor {
223 shape: Some(dims.clone()),
224 }
225 }
226 }
227 }
228 }
229
230 pub fn is_scalar(&self) -> bool {
231 matches!(self, ShapeInfo::Scalar)
232 }
233}
234
/// Broadcast-unifies two per-dimension extent lists.
///
/// Each entry is `Some(extent)` for a known extent or `None` for an unknown one.
/// A dimension missing from the shorter list is treated as unknown. The result
/// has `max(a.len(), b.len())` entries. Per dimension:
///
/// - equal known extents → that extent;
/// - a known `1` and another known extent → the other extent (singleton broadcast);
/// - two different known extents, neither `1` → `None` (incompatible);
/// - a known extent `x != 1` and an unknown → `x` (a valid broadcast partner is `1` or `x`);
/// - a known `1` and an unknown → `None`, because the unknown side may expand the
///   result to any extent;
/// - both unknown → `None`.
fn unify_dims(a: &[Option<usize>], b: &[Option<usize>]) -> Vec<Option<usize>> {
    let len = a.len().max(b.len());
    (0..len)
        .map(|i| {
            // Out-of-range positions and explicit `None` both mean "unknown".
            let da = a.get(i).copied().flatten();
            let db = b.get(i).copied().flatten();
            match (da, db) {
                (Some(x), Some(y)) if x == y => Some(x),
                (Some(1), Some(y)) => Some(y),
                (Some(x), Some(1)) => Some(x),
                // Two distinct non-singleton extents cannot be broadcast together.
                (Some(_), Some(_)) => None,
                // A known singleton says nothing about an unknown partner's extent.
                (Some(1), None) | (None, Some(1)) => None,
                (Some(x), None) | (None, Some(x)) => Some(x),
                (None, None) => None,
            }
        })
        .collect()
}
255
#[cfg(test)]
mod tests {
    use super::{unify_dims, ShapeInfo};

    /// `unify_dims` on matching dims, singleton broadcast, an unknown dim, and
    /// lists of different length.
    #[test]
    fn test_unify_dims_basic() {
        type Dims = Vec<Option<usize>>;
        let cases: [(Dims, Dims, Dims); 4] = [
            (vec![Some(4), Some(3)], vec![Some(4), Some(3)], vec![Some(4), Some(3)]),
            (vec![Some(4)], vec![Some(1), Some(3)], vec![Some(4), Some(3)]),
            (vec![None], vec![Some(5)], vec![Some(5)]),
            (vec![Some(2), Some(3)], vec![Some(2)], vec![Some(2), Some(3)]),
        ];
        for (lhs, rhs, expected) in &cases {
            assert_eq!(
                &unify_dims(lhs, rhs),
                expected,
                "unify_dims({:?}, {:?})",
                lhs,
                rhs
            );
        }
    }

    /// A tensor unified with a scalar stays a tensor.
    #[test]
    fn test_shape_unify() {
        let tensor = ShapeInfo::Tensor(vec![Some(4), Some(3)]);
        let unified = tensor.unify(&ShapeInfo::Scalar);
        assert!(matches!(unified, ShapeInfo::Tensor(_)));
    }
}