runmat_accelerate/
graph.rs

1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3use std::fmt;
4
5use runmat_builtins::{Type, Value as BuiltinValue};
6
/// Identifier of an [`AccelNode`]. [`AccelGraph::node`] resolves it by using it
/// directly as an index into [`AccelGraph::nodes`].
pub type NodeId = u32;
/// Identifier of a [`ValueInfo`]. [`AccelGraph::value`] resolves it by using it
/// directly as an index into [`AccelGraph::values`].
pub type ValueId = u32;
9
/// Graph of operations (nodes) and the values that flow between them, as used
/// by the acceleration layer.
///
/// `nodes` and `values` are addressed positionally: [`AccelGraph::node`] and
/// [`AccelGraph::value`] treat a [`NodeId`] / [`ValueId`] as a plain index.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AccelGraph {
    /// All operation nodes; `NodeId` `i` refers to `nodes[i]`.
    pub nodes: Vec<AccelNode>,
    /// All values; `ValueId` `i` refers to `values[i]`.
    pub values: Vec<ValueInfo>,
    /// Variable slot associated with a value, keyed by value id.
    pub var_bindings: HashMap<ValueId, VarBinding>,
    /// Variable slot associated with a node, keyed by node id.
    /// NOTE(review): presumably the variable a node's result is stored into — confirm with the graph builder.
    pub node_bindings: HashMap<NodeId, VarBinding>,
}
17
18impl AccelGraph {
19    pub fn is_empty(&self) -> bool {
20        self.nodes.is_empty()
21    }
22
23    pub fn node(&self, id: NodeId) -> Option<&AccelNode> {
24        self.nodes.get(id as usize)
25    }
26
27    pub fn value(&self, id: ValueId) -> Option<&ValueInfo> {
28        self.values.get(id as usize)
29    }
30
31    pub fn var_binding(&self, id: ValueId) -> Option<&VarBinding> {
32        self.var_bindings.get(&id)
33    }
34
35    pub fn node_binding(&self, id: NodeId) -> Option<&VarBinding> {
36        self.node_bindings.get(&id)
37    }
38
39    pub fn detect_fusion_groups(&self) -> Vec<crate::fusion::FusionGroup> {
40        crate::fusion::detect_fusion_groups(self)
41    }
42}
43
/// One operation in an [`AccelGraph`], with the values it consumes and produces.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AccelNode {
    /// This node's id.
    /// NOTE(review): `AccelGraph::node` assumes this equals the node's index in `AccelGraph::nodes` — confirm builders keep them in sync.
    pub id: NodeId,
    /// Which operation the node performs.
    pub label: AccelNodeLabel,
    /// Coarse classification read by `is_elementwise` / `is_reduction`.
    pub category: AccelOpCategory,
    /// Ids of the values this node reads.
    pub inputs: Vec<ValueId>,
    /// Ids of the values this node produces.
    pub outputs: Vec<ValueId>,
    /// Instruction range this node corresponds to (see [`InstrSpan`]).
    pub span: InstrSpan,
    /// Extra classification tags; not consulted by `is_elementwise` / `is_reduction`.
    pub tags: Vec<AccelGraphTag>,
}
54
55impl AccelNode {
56    pub fn is_elementwise(&self) -> bool {
57        self.category == AccelOpCategory::Elementwise
58    }
59
60    pub fn is_reduction(&self) -> bool {
61        self.category == AccelOpCategory::Reduction
62    }
63}
64
/// A `start`..`end` range of instruction positions associated with a node.
/// NOTE(review): whether `end` is inclusive or exclusive is not visible here — confirm with the graph builder.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct InstrSpan {
    /// First instruction position of the range.
    pub start: usize,
    /// Last (or one-past-last — TODO confirm) instruction position of the range.
    pub end: usize,
}
70
/// Coarse category of an [`AccelNode`]'s operation.
///
/// Variant order is part of the serialized form for non-self-describing
/// serde formats, so do not reorder.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum AccelOpCategory {
    /// Element-wise operation (checked by `AccelNode::is_elementwise`).
    Elementwise,
    /// Reduction (checked by `AccelNode::is_reduction`).
    Reduction,
    /// Matrix multiplication.
    MatMul,
    /// Transpose.
    Transpose,
    /// Anything not covered above.
    Other,
}
79
/// Identifies the concrete operation an [`AccelNode`] performs.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum AccelNodeLabel {
    /// A built-in operator.
    Primitive(PrimitiveOp),
    /// A call to a builtin function, identified by name.
    Builtin { name: String },
    /// Operation could not be identified.
    Unknown,
}
86
/// Operators that an [`AccelNodeLabel::Primitive`] node can carry.
///
/// `Display` prints the variant name verbatim (e.g. `ElemMul`).
///
/// NOTE(review): the `Mul`/`ElemMul`, `Div`/`ElemDiv` and `Pow`/`ElemPow` pairs
/// presumably mirror MATLAB's matrix (`*`, `/`, `^`) versus element-wise
/// (`.*`, `./`, `.^`) operators — confirm against the lowering code.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub enum PrimitiveOp {
    Add,
    Sub,
    Mul,
    Div,
    Pow,
    // Unary operators.
    Neg,
    UPlus,
    // Element-wise arithmetic.
    ElemMul,
    ElemDiv,
    ElemPow,
    ElemLeftDiv,
    // Comparisons.
    LessEqual,
    Less,
    Greater,
    GreaterEqual,
    Equal,
    NotEqual,
    Transpose,
}
108
109impl fmt::Display for PrimitiveOp {
110    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
111        let name = match self {
112            PrimitiveOp::Add => "Add",
113            PrimitiveOp::Sub => "Sub",
114            PrimitiveOp::Mul => "Mul",
115            PrimitiveOp::Div => "Div",
116            PrimitiveOp::Pow => "Pow",
117            PrimitiveOp::Neg => "Neg",
118            PrimitiveOp::UPlus => "UPlus",
119            PrimitiveOp::ElemMul => "ElemMul",
120            PrimitiveOp::ElemDiv => "ElemDiv",
121            PrimitiveOp::ElemPow => "ElemPow",
122            PrimitiveOp::ElemLeftDiv => "ElemLeftDiv",
123            PrimitiveOp::LessEqual => "LessEqual",
124            PrimitiveOp::Less => "Less",
125            PrimitiveOp::Greater => "Greater",
126            PrimitiveOp::GreaterEqual => "GreaterEqual",
127            PrimitiveOp::Equal => "Equal",
128            PrimitiveOp::NotEqual => "NotEqual",
129            PrimitiveOp::Transpose => "Transpose",
130        };
131        write!(f, "{}", name)
132    }
133}
134
/// Additional classification tags attached to an [`AccelNode`] via its `tags` list.
///
/// These are independent of the node's [`AccelOpCategory`]; a node may carry several.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub enum AccelGraphTag {
    Unary,
    Elementwise,
    Reduction,
    MatMul,
    Transpose,
    ArrayConstruct,
}
144
/// Static information tracked for one value in an [`AccelGraph`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ValueInfo {
    /// This value's id.
    pub id: ValueId,
    /// Where the value comes from.
    pub origin: ValueOrigin,
    /// Inferred type; `update_type` refines it.
    pub ty: Type,
    /// Shape derived from `ty` (recomputed by `update_type`).
    pub shape: ShapeInfo,
    /// Known constant payload, if any.
    // `#[serde(skip)]`: never serialized; after deserialization this is `None`.
    #[serde(skip)]
    pub constant: Option<BuiltinValue>,
}
154
/// Reference to a variable slot: its kind (global or local) and slot index.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct VarBinding {
    /// Whether the slot is global or local.
    pub kind: VarKind,
    /// Index of the slot within its kind's storage.
    pub index: usize,
}
160
161impl ValueInfo {
162    pub fn update_type(&mut self, ty: &Type) {
163        self.ty = match (&self.ty, ty) {
164            (Type::Unknown, other) => other.clone(),
165            (existing, other) => existing.unify(other),
166        };
167        self.shape = ShapeInfo::from_type(&self.ty);
168    }
169}
170
/// Provenance of a [`ValueInfo`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ValueOrigin {
    /// Read from the variable slot `index` of the given `kind`.
    Variable { kind: VarKind, index: usize },
    /// Produced by node `node`; `output` is presumably the position in that node's `outputs` — TODO confirm.
    NodeOutput { node: NodeId, output: usize },
    /// A constant (its payload, when known, is presumably kept in `ValueInfo::constant`).
    Constant,
    /// Provenance not determined.
    Unknown,
}
178
/// Storage class of a variable slot.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum VarKind {
    Global,
    Local,
}
184
/// Statically known shape of a value.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum ShapeInfo {
    /// Nothing is known about the shape.
    Unknown,
    /// A single scalar.
    Scalar,
    /// A tensor with per-dimension extents; `None` marks an unknown extent.
    /// An empty vector means the tensor's shape is entirely unknown
    /// (`from_type`/`to_type` map `Type::Tensor { shape: None }` to and from it).
    Tensor(Vec<Option<usize>>),
}
191
192impl ShapeInfo {
193    pub fn from_type(ty: &Type) -> Self {
194        match ty {
195            Type::Int | Type::Num | Type::Bool | Type::Logical => ShapeInfo::Scalar,
196            Type::Tensor { shape } => match shape {
197                Some(dims) => ShapeInfo::Tensor(dims.clone()),
198                None => ShapeInfo::Tensor(Vec::new()),
199            },
200            _ => ShapeInfo::Unknown,
201        }
202    }
203
204    pub fn unify(&self, other: &ShapeInfo) -> ShapeInfo {
205        match (self, other) {
206            (ShapeInfo::Unknown, _) | (_, ShapeInfo::Unknown) => ShapeInfo::Unknown,
207            (ShapeInfo::Scalar, ShapeInfo::Scalar) => ShapeInfo::Scalar,
208            (ShapeInfo::Scalar, ShapeInfo::Tensor(dims))
209            | (ShapeInfo::Tensor(dims), ShapeInfo::Scalar) => ShapeInfo::Tensor(dims.clone()),
210            (ShapeInfo::Tensor(a), ShapeInfo::Tensor(b)) => ShapeInfo::Tensor(unify_dims(a, b)),
211        }
212    }
213
214    pub fn to_type(&self) -> Type {
215        match self {
216            ShapeInfo::Unknown => Type::Unknown,
217            ShapeInfo::Scalar => Type::Num,
218            ShapeInfo::Tensor(dims) => {
219                if dims.is_empty() {
220                    Type::Tensor { shape: None }
221                } else {
222                    Type::Tensor {
223                        shape: Some(dims.clone()),
224                    }
225                }
226            }
227        }
228    }
229
230    pub fn is_scalar(&self) -> bool {
231        matches!(self, ShapeInfo::Scalar)
232    }
233}
234
/// Unifies two per-dimension extent lists under broadcasting rules.
///
/// The result has as many entries as the longer input. A dimension absent
/// from the shorter input is treated as a trailing singleton (`Some(1)`),
/// following MATLAB's implicit trailing dimensions. Position by position:
///
/// * equal known extents are kept;
/// * a known `1` against a known `n` broadcasts to `n`;
/// * two different non-singleton known extents are incompatible and give `None`;
/// * an unknown extent against a known non-singleton `n` gives `Some(n)`: any
///   compatible value is either `1` or `n`, and both broadcast to `n`;
/// * an unknown extent against `1` (or against another unknown) stays unknown,
///   because the broadcast result is whatever the unknown extent turns out to be.
///
/// NOTE(review): `ShapeInfo::Tensor(vec![])` encodes a wholly unknown shape, but
/// here an empty list behaves like all-singletons (as it effectively did
/// before). Callers wanting stricter handling must check for it.
fn unify_dims(a: &[Option<usize>], b: &[Option<usize>]) -> Vec<Option<usize>> {
    /// Unifies a single pair of extents (see the rules above).
    fn unify_dim(x: Option<usize>, y: Option<usize>) -> Option<usize> {
        match (x, y) {
            (Some(p), Some(q)) if p == q => Some(p),
            (Some(1), Some(q)) => Some(q),
            (Some(p), Some(1)) => Some(p),
            // Distinct non-singleton extents cannot broadcast.
            (Some(_), Some(_)) => None,
            // Broadcasting against 1 yields the other operand's (unknown) extent.
            (Some(1), None) | (None, Some(1)) | (None, None) => None,
            (Some(n), None) | (None, Some(n)) => Some(n),
        }
    }

    // Absent trailing dimensions behave like singletons, not like unknowns.
    let dim_at = |dims: &[Option<usize>], i: usize| dims.get(i).copied().unwrap_or(Some(1));

    (0..a.len().max(b.len()))
        .map(|i| unify_dim(dim_at(a, i), dim_at(b, i)))
        .collect()
}
255
#[cfg(test)]
mod tests {
    use super::{unify_dims, ShapeInfo};

    #[test]
    fn test_unify_dims_basic() {
        assert_eq!(
            unify_dims(&[Some(4), Some(3)], &[Some(4), Some(3)]),
            vec![Some(4), Some(3)]
        );
        assert_eq!(
            unify_dims(&[Some(4)], &[Some(1), Some(3)]),
            vec![Some(4), Some(3)]
        );
        assert_eq!(unify_dims(&[None], &[Some(5)]), vec![Some(5)]);
        assert_eq!(
            unify_dims(&[Some(2), Some(3)], &[Some(2)]),
            vec![Some(2), Some(3)]
        );
    }

    #[test]
    fn test_unify_dims_incompatible_and_empty() {
        // Distinct non-singleton extents cannot broadcast; the result is unknown.
        assert_eq!(unify_dims(&[Some(2)], &[Some(3)]), vec![None]);
        assert_eq!(unify_dims(&[None], &[None]), vec![None]);
        assert!(unify_dims(&[], &[]).is_empty());
    }

    #[test]
    fn test_shape_unify() {
        let tensor = ShapeInfo::Tensor(vec![Some(4), Some(3)]);
        // A scalar broadcasts against a tensor, keeping the tensor's dims, in either order.
        assert_eq!(tensor.unify(&ShapeInfo::Scalar), tensor);
        assert_eq!(ShapeInfo::Scalar.unify(&tensor), tensor);
        assert_eq!(
            ShapeInfo::Scalar.unify(&ShapeInfo::Scalar),
            ShapeInfo::Scalar
        );
        // Unknown on either side absorbs the other shape.
        assert_eq!(tensor.unify(&ShapeInfo::Unknown), ShapeInfo::Unknown);
        assert_eq!(
            ShapeInfo::Unknown.unify(&ShapeInfo::Scalar),
            ShapeInfo::Unknown
        );
    }
}