1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3use std::fmt;
4
5use runmat_builtins::{Type, Value as BuiltinValue};
6
/// Identifier of an [`AccelNode`]. [`AccelGraph::node`] uses it directly as an
/// index into `AccelGraph::nodes`.
pub type NodeId = u32;
/// Identifier of a [`ValueInfo`]. [`AccelGraph::value`] uses it directly as an
/// index into `AccelGraph::values`.
pub type ValueId = u32;
9
/// Graph of operations ([`AccelNode`]) and the values ([`ValueInfo`]) they
/// consume and produce, used by the fusion pass (see
/// [`AccelGraph::detect_fusion_groups`]).
///
/// Nodes and values are stored densely: the lookup helpers treat a
/// [`NodeId`]/[`ValueId`] as a position in `nodes`/`values`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AccelGraph {
    /// Operations, looked up by [`NodeId`] as an index.
    pub nodes: Vec<AccelNode>,
    /// Per-value metadata, looked up by [`ValueId`] as an index.
    pub values: Vec<ValueInfo>,
    /// Variable binding associated with a value, keyed by [`ValueId`].
    pub var_bindings: HashMap<ValueId, VarBinding>,
    /// Variable binding associated with a node, keyed by [`NodeId`].
    /// NOTE(review): presumably the variable the node's result is written to — TODO confirm.
    pub node_bindings: HashMap<NodeId, VarBinding>,
}
17
18impl AccelGraph {
19 pub fn is_empty(&self) -> bool {
20 self.nodes.is_empty()
21 }
22
23 pub fn node(&self, id: NodeId) -> Option<&AccelNode> {
24 self.nodes.get(id as usize)
25 }
26
27 pub fn value(&self, id: ValueId) -> Option<&ValueInfo> {
28 self.values.get(id as usize)
29 }
30
31 pub fn var_binding(&self, id: ValueId) -> Option<&VarBinding> {
32 self.var_bindings.get(&id)
33 }
34
35 pub fn node_binding(&self, id: NodeId) -> Option<&VarBinding> {
36 self.node_bindings.get(&id)
37 }
38
39 pub fn detect_fusion_groups(&self) -> Vec<crate::fusion::FusionGroup> {
40 crate::fusion::detect_fusion_groups(self)
41 }
42}
43
/// A single operation in an [`AccelGraph`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AccelNode {
    /// This node's id. `AccelGraph::node` finds nodes by using the id as an
    /// index into `AccelGraph::nodes`, so presumably it equals that position.
    pub id: NodeId,
    /// What the operation is: a primitive operator, a named builtin, or unknown.
    pub label: AccelNodeLabel,
    /// Coarse classification; queried by `is_elementwise` / `is_reduction`.
    pub category: AccelOpCategory,
    /// Ids of the values this operation reads.
    pub inputs: Vec<ValueId>,
    /// Ids of the values this operation produces.
    pub outputs: Vec<ValueId>,
    /// Instruction range associated with this node (see [`InstrSpan`]).
    pub span: InstrSpan,
    /// Additional classification tags.
    pub tags: Vec<AccelGraphTag>,
}
54
55impl AccelNode {
56 pub fn is_elementwise(&self) -> bool {
57 self.category == AccelOpCategory::Elementwise
58 }
59
60 pub fn is_reduction(&self) -> bool {
61 self.category == AccelOpCategory::Reduction
62 }
63}
64
/// A `start`..`end` range of instruction positions.
/// NOTE(review): whether `end` is inclusive or exclusive is not established here — TODO confirm.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct InstrSpan {
    /// First position of the range.
    pub start: usize,
    /// Last (or one-past-last — TODO confirm) position of the range.
    pub end: usize,
}
70
/// Coarse category of an [`AccelNode`].
///
/// Variant order is part of the serialized form for index-based serde
/// formats; append new variants rather than reordering.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum AccelOpCategory {
    Elementwise,
    Reduction,
    MatMul,
    Transpose,
    /// Anything not covered by the categories above.
    Other,
}
79
/// Identifies which operation an [`AccelNode`] performs.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum AccelNodeLabel {
    /// A built-in operator from [`PrimitiveOp`].
    Primitive(PrimitiveOp),
    /// A call to a builtin function, identified by name.
    Builtin { name: String },
    /// The operation could not be identified.
    Unknown,
}
86
/// Primitive operators that can label an [`AccelNode`].
///
/// `Display` renders each variant as its own name (e.g. `ElemMul`). Variant
/// order is part of the serialized form for index-based serde formats;
/// append new variants rather than reordering.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub enum PrimitiveOp {
    Add,
    Sub,
    Mul,
    Pow,
    Neg,
    UPlus,
    ElemMul,
    ElemDiv,
    ElemPow,
    ElemLeftDiv,
    LessEqual,
    Less,
    Greater,
    GreaterEqual,
    Equal,
    NotEqual,
    Transpose,
}
107
108impl fmt::Display for PrimitiveOp {
109 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
110 let name = match self {
111 PrimitiveOp::Add => "Add",
112 PrimitiveOp::Sub => "Sub",
113 PrimitiveOp::Mul => "Mul",
114 PrimitiveOp::Pow => "Pow",
115 PrimitiveOp::Neg => "Neg",
116 PrimitiveOp::UPlus => "UPlus",
117 PrimitiveOp::ElemMul => "ElemMul",
118 PrimitiveOp::ElemDiv => "ElemDiv",
119 PrimitiveOp::ElemPow => "ElemPow",
120 PrimitiveOp::ElemLeftDiv => "ElemLeftDiv",
121 PrimitiveOp::LessEqual => "LessEqual",
122 PrimitiveOp::Less => "Less",
123 PrimitiveOp::Greater => "Greater",
124 PrimitiveOp::GreaterEqual => "GreaterEqual",
125 PrimitiveOp::Equal => "Equal",
126 PrimitiveOp::NotEqual => "NotEqual",
127 PrimitiveOp::Transpose => "Transpose",
128 };
129 write!(f, "{}", name)
130 }
131}
132
/// Classification tags attached to an [`AccelNode`] via its `tags` list.
///
/// Variant order is part of the serialized form for index-based serde
/// formats; append new variants rather than reordering.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub enum AccelGraphTag {
    Unary,
    Elementwise,
    Reduction,
    MatMul,
    Transpose,
    ArrayConstruct,
}
142
/// Metadata tracked for one value in an [`AccelGraph`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ValueInfo {
    /// This value's id.
    pub id: ValueId,
    /// Where the value comes from (variable, node output, constant, or unknown).
    pub origin: ValueOrigin,
    /// Inferred static type; refined via `ValueInfo::update_type`.
    pub ty: Type,
    /// Shape summary; `update_type` recomputes it from `ty`.
    pub shape: ShapeInfo,
    /// Known constant value, if any. Excluded from serialization
    /// (`#[serde(skip)]`), so it is `None` after deserializing.
    #[serde(skip)]
    pub constant: Option<BuiltinValue>,
}
152
/// Associates a value or node with a program variable.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct VarBinding {
    /// Whether the variable is global or local.
    pub kind: VarKind,
    /// Slot of the variable.
    /// NOTE(review): presumably an index into the global or local variable table selected by `kind` — TODO confirm.
    pub index: usize,
}
158
159impl ValueInfo {
160 pub fn update_type(&mut self, ty: &Type) {
161 self.ty = match (&self.ty, ty) {
162 (Type::Unknown, other) => other.clone(),
163 (existing, other) => existing.unify(other),
164 };
165 self.shape = ShapeInfo::from_type(&self.ty);
166 }
167}
168
/// Provenance of a value in an [`AccelGraph`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ValueOrigin {
    /// Read from a program variable (see [`VarBinding`] for the same pair of fields).
    Variable { kind: VarKind, index: usize },
    /// Produced by node `node`.
    /// NOTE(review): `output` is presumably the position within that node's `outputs` — TODO confirm.
    NodeOutput { node: NodeId, output: usize },
    /// A constant; the concrete value, when known, is kept in `ValueInfo::constant`.
    Constant,
    /// Provenance could not be determined.
    Unknown,
}
176
/// Storage class of a program variable.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum VarKind {
    Global,
    Local,
}
182
/// Coarse shape summary of a value.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum ShapeInfo {
    /// Nothing is known about the shape (not even scalar vs. tensor).
    Unknown,
    /// A single element.
    Scalar,
    /// An array with the given per-dimension extents; `None` marks an unknown
    /// extent. `ShapeInfo::from_type` stores an empty list for a tensor type
    /// whose shape is `None`, and `ShapeInfo::to_type` maps an empty list back
    /// to `shape: None`.
    Tensor(Vec<Option<usize>>),
}
189
190impl ShapeInfo {
191 pub fn from_type(ty: &Type) -> Self {
192 match ty {
193 Type::Int | Type::Num | Type::Bool => ShapeInfo::Scalar,
194 Type::Logical { shape } => match shape {
195 Some(dims) => ShapeInfo::Tensor(dims.clone()),
196 None => ShapeInfo::Tensor(Vec::new()),
197 },
198 Type::Tensor { shape } => match shape {
199 Some(dims) => ShapeInfo::Tensor(dims.clone()),
200 None => ShapeInfo::Tensor(Vec::new()),
201 },
202 _ => ShapeInfo::Unknown,
203 }
204 }
205
206 pub fn unify(&self, other: &ShapeInfo) -> ShapeInfo {
207 match (self, other) {
208 (ShapeInfo::Unknown, _) | (_, ShapeInfo::Unknown) => ShapeInfo::Unknown,
209 (ShapeInfo::Scalar, ShapeInfo::Scalar) => ShapeInfo::Scalar,
210 (ShapeInfo::Scalar, ShapeInfo::Tensor(dims))
211 | (ShapeInfo::Tensor(dims), ShapeInfo::Scalar) => ShapeInfo::Tensor(dims.clone()),
212 (ShapeInfo::Tensor(a), ShapeInfo::Tensor(b)) => {
213 ShapeInfo::Tensor(runmat_builtins::shape_rules::broadcast_shapes(a, b))
214 }
215 }
216 }
217
218 pub fn to_type(&self) -> Type {
219 match self {
220 ShapeInfo::Unknown => Type::Unknown,
221 ShapeInfo::Scalar => Type::Num,
222 ShapeInfo::Tensor(dims) => {
223 if dims.is_empty() {
224 Type::Tensor { shape: None }
225 } else {
226 Type::Tensor {
227 shape: Some(dims.clone()),
228 }
229 }
230 }
231 }
232 }
233
234 pub fn is_scalar(&self) -> bool {
235 matches!(self, ShapeInfo::Scalar)
236 }
237}
238
#[cfg(test)]
mod tests {
    use super::ShapeInfo;
    use runmat_builtins::shape_rules::broadcast_shapes;
    use runmat_builtins::Type;

    /// `broadcast_shapes` (used by `ShapeInfo::unify` for two known tensors):
    /// equal extents pass through, size-1 extents stretch, and an unknown
    /// extent yields to a known one.
    #[test]
    fn test_broadcast_shapes_basic() {
        type Dims = Vec<Option<usize>>;
        let cases: [(Dims, Dims, Dims); 4] = [
            (vec![Some(4), Some(3)], vec![Some(4), Some(3)], vec![Some(4), Some(3)]),
            (vec![Some(4), Some(1)], vec![Some(1), Some(3)], vec![Some(4), Some(3)]),
            (vec![None], vec![Some(5)], vec![Some(5)]),
            (vec![Some(2), Some(3)], vec![Some(2), Some(1)], vec![Some(2), Some(3)]),
        ];
        for (lhs, rhs, expected) in cases {
            assert_eq!(
                broadcast_shapes(&lhs, &rhs),
                expected,
                "lhs={:?} rhs={:?}",
                lhs,
                rhs
            );
        }
    }

    /// A scalar is the identity for `unify`, on either side.
    #[test]
    fn test_shape_unify() {
        let tensor = ShapeInfo::Tensor(vec![Some(4), Some(3)]);
        assert_eq!(tensor.unify(&ShapeInfo::Scalar), tensor);
        assert_eq!(ShapeInfo::Scalar.unify(&tensor), tensor);
        assert_eq!(ShapeInfo::Scalar.unify(&ShapeInfo::Scalar), ShapeInfo::Scalar);
    }

    /// `Unknown` absorbs every other shape, regardless of operand order.
    #[test]
    fn test_shape_unify_unknown_absorbs() {
        let tensor = ShapeInfo::Tensor(vec![Some(2), Some(2)]);
        for other in [ShapeInfo::Scalar, tensor, ShapeInfo::Unknown] {
            assert_eq!(ShapeInfo::Unknown.unify(&other), ShapeInfo::Unknown);
            assert_eq!(other.unify(&ShapeInfo::Unknown), ShapeInfo::Unknown);
        }
    }

    /// Two known tensors broadcast against each other.
    #[test]
    fn test_shape_unify_broadcasts() {
        let a = ShapeInfo::Tensor(vec![Some(1), Some(3)]);
        let b = ShapeInfo::Tensor(vec![Some(2), Some(1)]);
        assert_eq!(a.unify(&b), ShapeInfo::Tensor(vec![Some(2), Some(3)]));
    }

    /// `from_type` / `to_type` agree on the scalar and tensor encodings,
    /// including "empty dimension list == shape None".
    #[test]
    fn test_shape_type_mapping() {
        assert_eq!(ShapeInfo::from_type(&Type::Num), ShapeInfo::Scalar);
        assert_eq!(ShapeInfo::from_type(&Type::Bool), ShapeInfo::Scalar);
        assert_eq!(ShapeInfo::from_type(&Type::Unknown), ShapeInfo::Unknown);
        assert_eq!(
            ShapeInfo::from_type(&Type::Tensor { shape: None }),
            ShapeInfo::Tensor(Vec::new())
        );

        let dims = vec![Some(2), None];
        let shape = ShapeInfo::from_type(&Type::Tensor {
            shape: Some(dims.clone()),
        });
        assert_eq!(shape, ShapeInfo::Tensor(dims.clone()));
        assert!(matches!(
            shape.to_type(),
            Type::Tensor { shape: Some(ref d) } if *d == dims
        ));
        assert!(matches!(
            ShapeInfo::Tensor(Vec::new()).to_type(),
            Type::Tensor { shape: None }
        ));
        assert!(matches!(ShapeInfo::Scalar.to_type(), Type::Num));
        assert!(matches!(ShapeInfo::Unknown.to_type(), Type::Unknown));
    }
}