Skip to main content

runmat_accelerate/
graph.rs

1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3use std::fmt;
4
5use runmat_builtins::{Type, Value as BuiltinValue};
6
/// Identifier of a node; [`AccelGraph::node`] uses it directly as a position in `nodes`.
pub type NodeId = u32;
/// Identifier of a value; [`AccelGraph::value`] uses it directly as a position in `values`.
pub type ValueId = u32;
9
10#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct AccelGraph {
12    pub nodes: Vec<AccelNode>,
13    pub values: Vec<ValueInfo>,
14    pub var_bindings: HashMap<ValueId, VarBinding>,
15    pub node_bindings: HashMap<NodeId, VarBinding>,
16}
17
18impl AccelGraph {
19    pub fn is_empty(&self) -> bool {
20        self.nodes.is_empty()
21    }
22
23    pub fn node(&self, id: NodeId) -> Option<&AccelNode> {
24        self.nodes.get(id as usize)
25    }
26
27    pub fn value(&self, id: ValueId) -> Option<&ValueInfo> {
28        self.values.get(id as usize)
29    }
30
31    pub fn var_binding(&self, id: ValueId) -> Option<&VarBinding> {
32        self.var_bindings.get(&id)
33    }
34
35    pub fn node_binding(&self, id: NodeId) -> Option<&VarBinding> {
36        self.node_bindings.get(&id)
37    }
38
39    pub fn detect_fusion_groups(&self) -> Vec<crate::fusion::FusionGroup> {
40        crate::fusion::detect_fusion_groups(self)
41    }
42}
43
44#[derive(Debug, Clone, Serialize, Deserialize)]
45pub struct AccelNode {
46    pub id: NodeId,
47    pub label: AccelNodeLabel,
48    pub category: AccelOpCategory,
49    pub inputs: Vec<ValueId>,
50    pub outputs: Vec<ValueId>,
51    pub span: InstrSpan,
52    pub tags: Vec<AccelGraphTag>,
53}
54
55impl AccelNode {
56    pub fn is_elementwise(&self) -> bool {
57        self.category == AccelOpCategory::Elementwise
58    }
59
60    pub fn is_reduction(&self) -> bool {
61        self.category == AccelOpCategory::Reduction
62    }
63}
64
65#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
66pub struct InstrSpan {
67    pub start: usize,
68    pub end: usize,
69}
70
71#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
72pub enum AccelOpCategory {
73    Elementwise,
74    Reduction,
75    MatMul,
76    Transpose,
77    Other,
78}
79
80#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
81pub enum AccelNodeLabel {
82    Primitive(PrimitiveOp),
83    Builtin { name: String },
84    Unknown,
85}
86
87#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
88pub enum PrimitiveOp {
89    Add,
90    Sub,
91    Mul,
92    Pow,
93    Neg,
94    UPlus,
95    ElemMul,
96    ElemDiv,
97    ElemPow,
98    ElemLeftDiv,
99    LessEqual,
100    Less,
101    Greater,
102    GreaterEqual,
103    Equal,
104    NotEqual,
105    Transpose,
106}
107
108impl fmt::Display for PrimitiveOp {
109    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
110        let name = match self {
111            PrimitiveOp::Add => "Add",
112            PrimitiveOp::Sub => "Sub",
113            PrimitiveOp::Mul => "Mul",
114            PrimitiveOp::Pow => "Pow",
115            PrimitiveOp::Neg => "Neg",
116            PrimitiveOp::UPlus => "UPlus",
117            PrimitiveOp::ElemMul => "ElemMul",
118            PrimitiveOp::ElemDiv => "ElemDiv",
119            PrimitiveOp::ElemPow => "ElemPow",
120            PrimitiveOp::ElemLeftDiv => "ElemLeftDiv",
121            PrimitiveOp::LessEqual => "LessEqual",
122            PrimitiveOp::Less => "Less",
123            PrimitiveOp::Greater => "Greater",
124            PrimitiveOp::GreaterEqual => "GreaterEqual",
125            PrimitiveOp::Equal => "Equal",
126            PrimitiveOp::NotEqual => "NotEqual",
127            PrimitiveOp::Transpose => "Transpose",
128        };
129        write!(f, "{}", name)
130    }
131}
132
133#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
134pub enum AccelGraphTag {
135    Unary,
136    Elementwise,
137    Reduction,
138    MatMul,
139    Transpose,
140    ArrayConstruct,
141}
142
143#[derive(Debug, Clone, Serialize, Deserialize)]
144pub struct ValueInfo {
145    pub id: ValueId,
146    pub origin: ValueOrigin,
147    pub ty: Type,
148    pub shape: ShapeInfo,
149    #[serde(skip)]
150    pub constant: Option<BuiltinValue>,
151}
152
153#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
154pub struct VarBinding {
155    pub kind: VarKind,
156    pub index: usize,
157}
158
159impl ValueInfo {
160    pub fn update_type(&mut self, ty: &Type) {
161        self.ty = match (&self.ty, ty) {
162            (Type::Unknown, other) => other.clone(),
163            (existing, other) => existing.unify(other),
164        };
165        self.shape = ShapeInfo::from_type(&self.ty);
166    }
167}
168
169#[derive(Debug, Clone, Serialize, Deserialize)]
170pub enum ValueOrigin {
171    Variable { kind: VarKind, index: usize },
172    NodeOutput { node: NodeId, output: usize },
173    Constant,
174    Unknown,
175}
176
177#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
178pub enum VarKind {
179    Global,
180    Local,
181}
182
183#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
184pub enum ShapeInfo {
185    Unknown,
186    Scalar,
187    Tensor(Vec<Option<usize>>),
188}
189
190impl ShapeInfo {
191    pub fn from_type(ty: &Type) -> Self {
192        match ty {
193            Type::Int | Type::Num | Type::Bool => ShapeInfo::Scalar,
194            Type::Logical { shape } => match shape {
195                Some(dims) => ShapeInfo::Tensor(dims.clone()),
196                None => ShapeInfo::Tensor(Vec::new()),
197            },
198            Type::Tensor { shape } => match shape {
199                Some(dims) => ShapeInfo::Tensor(dims.clone()),
200                None => ShapeInfo::Tensor(Vec::new()),
201            },
202            _ => ShapeInfo::Unknown,
203        }
204    }
205
206    pub fn unify(&self, other: &ShapeInfo) -> ShapeInfo {
207        match (self, other) {
208            (ShapeInfo::Unknown, _) | (_, ShapeInfo::Unknown) => ShapeInfo::Unknown,
209            (ShapeInfo::Scalar, ShapeInfo::Scalar) => ShapeInfo::Scalar,
210            (ShapeInfo::Scalar, ShapeInfo::Tensor(dims))
211            | (ShapeInfo::Tensor(dims), ShapeInfo::Scalar) => ShapeInfo::Tensor(dims.clone()),
212            (ShapeInfo::Tensor(a), ShapeInfo::Tensor(b)) => {
213                ShapeInfo::Tensor(runmat_builtins::shape_rules::broadcast_shapes(a, b))
214            }
215        }
216    }
217
218    pub fn to_type(&self) -> Type {
219        match self {
220            ShapeInfo::Unknown => Type::Unknown,
221            ShapeInfo::Scalar => Type::Num,
222            ShapeInfo::Tensor(dims) => {
223                if dims.is_empty() {
224                    Type::Tensor { shape: None }
225                } else {
226                    Type::Tensor {
227                        shape: Some(dims.clone()),
228                    }
229                }
230            }
231        }
232    }
233
234    pub fn is_scalar(&self) -> bool {
235        matches!(self, ShapeInfo::Scalar)
236    }
237}
238
#[cfg(test)]
mod tests {
    use super::ShapeInfo;
    use runmat_builtins::shape_rules::broadcast_shapes;

    #[test]
    fn test_unify_dims_basic() {
        // Identical shapes broadcast to themselves.
        assert_eq!(
            broadcast_shapes(&[Some(4), Some(3)], &[Some(4), Some(3)]),
            vec![Some(4), Some(3)]
        );
        // Treat vectors as 2D shapes (column-major): 4x1 broadcast with 1x3 -> 4x3.
        assert_eq!(
            broadcast_shapes(&[Some(4), Some(1)], &[Some(1), Some(3)]),
            vec![Some(4), Some(3)]
        );
        // An unknown extent combined with a known one takes the known extent.
        assert_eq!(broadcast_shapes(&[None], &[Some(5)]), vec![Some(5)]);
        // Treat vectors as 2D shapes (column-major): 2x3 broadcast with 2x1 -> 2x3.
        assert_eq!(
            broadcast_shapes(&[Some(2), Some(3)], &[Some(2), Some(1)]),
            vec![Some(2), Some(3)]
        );
    }

    #[test]
    fn test_shape_unify() {
        // A scalar operand leaves the tensor's dims untouched, in either order.
        let a = ShapeInfo::Tensor(vec![Some(4), Some(3)]);
        let b = ShapeInfo::Scalar;
        assert_eq!(a.unify(&b), ShapeInfo::Tensor(vec![Some(4), Some(3)]));
        assert_eq!(b.unify(&a), ShapeInfo::Tensor(vec![Some(4), Some(3)]));
    }

    #[test]
    fn test_shape_unify_broadcasts() {
        let a = ShapeInfo::Tensor(vec![Some(1), Some(3)]);
        let b = ShapeInfo::Tensor(vec![Some(2), Some(1)]);
        assert_eq!(a.unify(&b), ShapeInfo::Tensor(vec![Some(2), Some(3)]));
    }

    #[test]
    fn test_shape_unify_scalar_and_unknown() {
        assert_eq!(ShapeInfo::Scalar.unify(&ShapeInfo::Scalar), ShapeInfo::Scalar);

        // Unknown absorbs every other shape, whichever side it is on.
        let t = ShapeInfo::Tensor(vec![Some(2)]);
        assert_eq!(ShapeInfo::Unknown.unify(&t), ShapeInfo::Unknown);
        assert_eq!(t.unify(&ShapeInfo::Unknown), ShapeInfo::Unknown);
        assert_eq!(ShapeInfo::Scalar.unify(&ShapeInfo::Unknown), ShapeInfo::Unknown);
    }

    #[test]
    fn test_shape_type_round_trip() {
        // Every shape that `to_type` can produce is recovered exactly by `from_type`,
        // including the empty-dims encoding of an unshaped tensor.
        for shape in [
            ShapeInfo::Unknown,
            ShapeInfo::Scalar,
            ShapeInfo::Tensor(vec![Some(2), None]),
            ShapeInfo::Tensor(Vec::new()),
        ] {
            assert_eq!(ShapeInfo::from_type(&shape.to_type()), shape);
        }
    }
}