1use crate::types::{DType, Shape};
8use smallvec::SmallVec;
9use std::collections::HashMap;
10use std::hash::{Hash, Hasher};
11
/// Opaque handle identifying a [`Node`] within its owning [`Graph`].
///
/// Ids are handed out densely in insertion order by `Graph::add_node_raw`,
/// so `NodeId(i)` corresponds to the `i`-th node created.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct NodeId(pub(crate) u64);
15
/// Static metadata (shape and element type) attached to every graph node.
#[derive(Clone, Debug)]
pub struct TensorMeta {
    /// Logical shape of the node's output tensor.
    pub shape: Shape,
    /// Element data type of the node's output tensor.
    pub dtype: DType,
}
22
/// A single operation node in the computation DAG.
#[derive(Clone, Debug)]
pub struct Node {
    /// This node's id; equals its index into `Graph::nodes` (ids are
    /// assigned densely in insertion order).
    pub id: NodeId,
    /// The operation this node performs.
    pub op: OpKind,
    /// Operand node ids, in positional order. Inline capacity of 2 keeps
    /// the common unary/binary case off the heap.
    pub inputs: SmallVec<[NodeId; 2]>,
    /// Shape/dtype of this node's output.
    pub meta: TensorMeta,
}
31
32#[derive(Clone, Debug)]
34pub enum OpKind {
35 Constant,
38 Parameter,
40
41 Add,
43 Sub,
44 Mul,
45 Div,
46 Neg,
47 Exp,
48 Log,
49
50 Sum {
52 axis: Option<i32>,
53 },
54 Mean {
55 axis: Option<i32>,
56 },
57 Max {
58 axis: Option<i32>,
59 },
60
61 MatMul,
63
64 Reshape {
66 new_shape: Shape,
67 },
68 Transpose {
69 axes: Option<Vec<usize>>,
70 },
71
72 Softmax {
74 axis: i32,
75 },
76 Silu,
77 Gelu,
78
79 LayerNorm {
81 eps: f32,
82 },
83 RmsNorm {
84 eps: f32,
85 },
86
87 Rope {
91 rotary_dim: usize,
92 pos_offset: usize,
93 theta: f32,
94 },
95
96 Broadcast {
99 target_shape: Shape,
100 },
101
102 ScaledMaskedSoftmax {
106 scale: f32,
107 causal: bool,
108 },
109
110 Attention {
114 scale: f32,
115 causal: bool,
116 },
117
118 LayerNormVjp {
121 eps: f32,
122 },
123 RmsNormVjp {
125 eps: f32,
126 },
127 SoftmaxVjp {
129 axis: i32,
130 },
131 SiluVjp,
133 GeluVjp,
135
136 Sqrt,
139
140 #[cfg_attr(target_os = "macos", doc = "Apply rotary positional embeddings.")]
142 RoPE {
143 base: f32,
144 offset: usize,
145 traditional: bool,
146 },
147
148 Embedding,
153
154 Narrow {
158 axis: i32,
159 start: i64,
160 length: i64,
161 },
162
163 Concatenate {
167 axis: i32,
168 },
169}
170
/// An append-only computation DAG with optional common-subexpression
/// elimination (see `Graph::intern_node`).
#[derive(Debug, Default)]
pub struct Graph {
    // All nodes in insertion order; `nodes[i].id == NodeId(i as u64)`.
    nodes: Vec<Node>,
    // Next id to hand out; always equals `nodes.len()`.
    next_id: u64,
    // Structural-dedup table: node signature -> first node with that signature.
    cse: HashMap<CseKey, NodeId>,
    // Dense f32 payloads for constant nodes, keyed by node id.
    const_payloads: HashMap<NodeId, Vec<f32>>,
}
179
180impl Graph {
181 pub fn new() -> Self {
182 Self::default()
183 }
184
185 pub fn add_node(
187 &mut self,
188 op: OpKind,
189 inputs: SmallVec<[NodeId; 2]>,
190 meta: TensorMeta,
191 ) -> NodeId {
192 self.add_node_raw(op, inputs, meta)
193 }
194
195 pub fn add_node_raw(
197 &mut self,
198 op: OpKind,
199 inputs: SmallVec<[NodeId; 2]>,
200 meta: TensorMeta,
201 ) -> NodeId {
202 let id = NodeId(self.next_id);
203 self.next_id += 1;
204 self.nodes.push(Node {
205 id,
206 op,
207 inputs,
208 meta,
209 });
210 id
211 }
212
213 pub fn intern_node(
215 &mut self,
216 op: OpKind,
217 inputs: SmallVec<[NodeId; 2]>,
218 meta: TensorMeta,
219 const_payload: Option<&[f32]>,
220 ) -> NodeId {
221 if !is_cse_eligible(&op) {
222 return self.add_node_raw(op, inputs, meta);
223 }
224
225 let mut inputs = inputs;
226 normalize_inputs_for_cse(&op, &mut inputs);
227
228 let const_hash = const_payload.map(hash_f32_payload);
229 let key = CseKey {
230 op_key: OpKey::from_op(&op),
231 inputs: inputs.clone(),
232 meta_sig: MetaSig::new(&meta),
233 const_hash,
234 };
235
236 if let Some(&existing) = self.cse.get(&key) {
237 if matches!(op, OpKind::Constant) {
238 if let (Some(payload), Some(existing_payload)) =
239 (const_payload, self.const_payload(existing))
240 && existing_payload == payload
241 {
242 return existing;
243 }
244 } else {
245 return existing;
246 }
247 }
248
249 let id = self.add_node_raw(op, inputs, meta);
250 if matches!(key.op_key, OpKey::Constant)
251 && let Some(payload) = const_payload
252 {
253 self.const_payloads.insert(id, payload.to_vec());
254 }
255 self.cse.insert(key, id);
256 id
257 }
258
259 pub fn const_payload(&self, id: NodeId) -> Option<&[f32]> {
260 self.const_payloads.get(&id).map(|v| v.as_slice())
261 }
262
263 pub fn get(&self, id: NodeId) -> Option<&Node> {
265 self.nodes.iter().find(|n| n.id == id)
266 }
267
268 pub fn topo_sort(&self, outputs: &[NodeId]) -> Vec<NodeId> {
270 let mut visited = std::collections::HashSet::new();
271 let mut order = Vec::new();
272
273 for &out in outputs {
274 self.topo_visit(out, &mut visited, &mut order);
275 }
276
277 order
278 }
279
280 fn topo_visit(
281 &self,
282 id: NodeId,
283 visited: &mut std::collections::HashSet<NodeId>,
284 order: &mut Vec<NodeId>,
285 ) {
286 if !visited.insert(id) {
287 return;
288 }
289 if let Some(node) = self.get(id) {
290 for &input in &node.inputs {
291 self.topo_visit(input, visited, order);
292 }
293 }
294 order.push(id);
295 }
296
297 pub fn len(&self) -> usize {
299 self.nodes.len()
300 }
301
302 pub fn is_empty(&self) -> bool {
304 self.nodes.is_empty()
305 }
306}
307
/// Hashable signature of a [`TensorMeta`] (dtype plus raw shape dims),
/// used inside `CseKey` because `TensorMeta` itself derives neither `Eq`
/// nor `Hash`.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
struct MetaSig {
    dtype: DType,
    shape: Vec<i64>,
}
313
314impl MetaSig {
315 fn new(meta: &TensorMeta) -> Self {
316 Self {
317 dtype: meta.dtype,
318 shape: meta.shape.0.clone(),
319 }
320 }
321}
322
/// Hashable mirror of [`OpKind`], used inside `CseKey`.
///
/// `f32` attributes are stored as raw bit patterns (`*_bits`) so the key
/// can derive `Eq`/`Hash`; note this distinguishes `-0.0` from `0.0` and
/// treats each NaN bit pattern as distinct. `Shape` attributes are
/// flattened to their raw dimension vectors.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
enum OpKey {
    Constant,
    Parameter,
    Add,
    Sub,
    Mul,
    Div,
    Neg,
    Exp,
    Log,
    Sum {
        axis: Option<i32>,
    },
    Mean {
        axis: Option<i32>,
    },
    Max {
        axis: Option<i32>,
    },
    MatMul,
    Reshape {
        new_shape: Vec<i64>,
    },
    Transpose {
        axes: Option<Vec<usize>>,
    },
    Softmax {
        axis: i32,
    },
    Silu,
    Gelu,
    LayerNorm {
        eps_bits: u32,
    },
    RmsNorm {
        eps_bits: u32,
    },
    Broadcast {
        target_shape: Vec<i64>,
    },
    LayerNormVjp {
        eps_bits: u32,
    },
    RmsNormVjp {
        eps_bits: u32,
    },
    ScaledMaskedSoftmax {
        scale_bits: u32,
        causal: bool,
    },
    Attention {
        scale_bits: u32,
        causal: bool,
    },
    Rope {
        rotary_dim: usize,
        pos_offset: usize,
        theta_bits: u32,
    },
    RoPE {
        base_bits: u32,
        offset: usize,
        traditional: bool,
    },
    SoftmaxVjp {
        axis: i32,
    },
    SiluVjp,
    GeluVjp,
    Sqrt,
    Embedding,
    Narrow {
        axis: i32,
        start: i64,
        length: i64,
    },
    Concatenate {
        axis: i32,
    },
}
404
impl OpKey {
    /// Projects an [`OpKind`] onto its hashable key form.
    ///
    /// Purely mechanical 1:1 mapping: `f32` attributes are converted via
    /// `to_bits()`, `Shape`s are flattened to their raw dim vectors, and
    /// everything else is copied/cloned as-is. When adding an `OpKind`
    /// variant, the compiler forces a matching arm here (no catch-all).
    fn from_op(op: &OpKind) -> Self {
        match op {
            OpKind::Constant => OpKey::Constant,
            OpKind::Parameter => OpKey::Parameter,
            OpKind::Add => OpKey::Add,
            OpKind::Sub => OpKey::Sub,
            OpKind::Mul => OpKey::Mul,
            OpKind::Div => OpKey::Div,
            OpKind::Neg => OpKey::Neg,
            OpKind::Exp => OpKey::Exp,
            OpKind::Log => OpKey::Log,
            OpKind::Sum { axis } => OpKey::Sum { axis: *axis },
            OpKind::Mean { axis } => OpKey::Mean { axis: *axis },
            OpKind::Max { axis } => OpKey::Max { axis: *axis },
            OpKind::MatMul => OpKey::MatMul,
            OpKind::Reshape { new_shape } => OpKey::Reshape {
                new_shape: new_shape.0.clone(),
            },
            OpKind::Transpose { axes } => OpKey::Transpose { axes: axes.clone() },
            OpKind::Softmax { axis } => OpKey::Softmax { axis: *axis },
            OpKind::Silu => OpKey::Silu,
            OpKind::Gelu => OpKey::Gelu,
            OpKind::LayerNorm { eps } => OpKey::LayerNorm {
                eps_bits: eps.to_bits(),
            },
            OpKind::RmsNorm { eps } => OpKey::RmsNorm {
                eps_bits: eps.to_bits(),
            },
            OpKind::Broadcast { target_shape } => OpKey::Broadcast {
                target_shape: target_shape.0.clone(),
            },
            OpKind::LayerNormVjp { eps } => OpKey::LayerNormVjp {
                eps_bits: eps.to_bits(),
            },
            OpKind::RmsNormVjp { eps } => OpKey::RmsNormVjp {
                eps_bits: eps.to_bits(),
            },
            OpKind::ScaledMaskedSoftmax { scale, causal } => OpKey::ScaledMaskedSoftmax {
                scale_bits: scale.to_bits(),
                causal: *causal,
            },
            OpKind::Attention { scale, causal } => OpKey::Attention {
                scale_bits: scale.to_bits(),
                causal: *causal,
            },
            OpKind::Rope {
                rotary_dim,
                pos_offset,
                theta,
            } => OpKey::Rope {
                rotary_dim: *rotary_dim,
                pos_offset: *pos_offset,
                theta_bits: theta.to_bits(),
            },
            OpKind::RoPE {
                base,
                offset,
                traditional,
            } => OpKey::RoPE {
                base_bits: base.to_bits(),
                offset: *offset,
                traditional: *traditional,
            },
            OpKind::SoftmaxVjp { axis } => OpKey::SoftmaxVjp { axis: *axis },
            OpKind::SiluVjp => OpKey::SiluVjp,
            OpKind::GeluVjp => OpKey::GeluVjp,
            OpKind::Sqrt => OpKey::Sqrt,
            OpKind::Embedding => OpKey::Embedding,
            OpKind::Narrow {
                axis,
                start,
                length,
            } => OpKey::Narrow {
                axis: *axis,
                start: *start,
                length: *length,
            },
            OpKind::Concatenate { axis } => OpKey::Concatenate { axis: *axis },
        }
    }
}
487
/// Full structural signature of a candidate node, used as the CSE-table key.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
struct CseKey {
    op_key: OpKey,
    // Operand ids, canonicalized for commutative ops
    // (see `normalize_inputs_for_cse`).
    inputs: SmallVec<[NodeId; 2]>,
    meta_sig: MetaSig,
    // Hash of the constant payload, when one was supplied.
    const_hash: Option<u64>,
}
495
496fn is_cse_eligible(op: &OpKind) -> bool {
497 !matches!(op, OpKind::Constant | OpKind::Parameter)
501}
502
/// Order-sensitive hash of an `f32` slice.
///
/// Folds the slice length and each element's raw bit pattern into a
/// `DefaultHasher`, so `-0.0`/`0.0` and distinct NaN encodings hash
/// differently. Deterministic within a process; the underlying algorithm
/// is unspecified across Rust versions.
pub fn hash_f32_payload(data: &[f32]) -> u64 {
    let mut hasher = std::collections::hash_map::DefaultHasher::new();
    data.len().hash(&mut hasher);
    data.iter().for_each(|value| value.to_bits().hash(&mut hasher));
    hasher.finish()
}
511
512fn normalize_inputs_for_cse(op: &OpKind, inputs: &mut SmallVec<[NodeId; 2]>) {
513 if matches!(op, OpKind::Add | OpKind::Mul) && inputs.len() == 2 && inputs[0].0 > inputs[1].0 {
514 inputs.swap(0, 1);
515 }
516}
517
#[cfg(test)]
mod tests {
    use super::*;

    // Shorthand for an F32 TensorMeta with the given dims.
    fn f32_meta(dims: Vec<i64>) -> TensorMeta {
        TensorMeta {
            shape: Shape::new(dims),
            dtype: DType::F32,
        }
    }

    #[test]
    fn test_graph_topo_sort() {
        let mut g = Graph::new();
        let lhs = g.add_node(OpKind::Constant, SmallVec::new(), f32_meta(vec![2, 3]));
        let rhs = g.add_node(OpKind::Constant, SmallVec::new(), f32_meta(vec![2, 3]));
        let sum = g.add_node(
            OpKind::Add,
            SmallVec::from_slice(&[lhs, rhs]),
            f32_meta(vec![2, 3]),
        );

        let order = g.topo_sort(&[sum]);
        assert_eq!(order.len(), 3);

        // Both operands must be scheduled before the op that consumes them.
        let position = |needle: NodeId| order.iter().position(|&id| id == needle).unwrap();
        assert!(position(lhs) < position(sum));
        assert!(position(rhs) < position(sum));
    }

    #[test]
    fn test_cse_does_not_dedup_constants() {
        let mut g = Graph::new();
        let meta = f32_meta(vec![2]);

        // Identical payloads must still yield two distinct nodes.
        let first = g.intern_node(
            OpKind::Constant,
            SmallVec::new(),
            meta.clone(),
            Some(&[1.0, 2.0]),
        );
        let second = g.intern_node(OpKind::Constant, SmallVec::new(), meta, Some(&[1.0, 2.0]));

        assert_ne!(first, second);
        assert_eq!(g.len(), 2);
    }

    #[test]
    fn test_cse_dedups_ops() {
        let mut g = Graph::new();
        let meta = f32_meta(vec![2]);

        let lhs = g.intern_node(
            OpKind::Constant,
            SmallVec::new(),
            meta.clone(),
            Some(&[1.0, 2.0]),
        );
        let rhs = g.intern_node(
            OpKind::Constant,
            SmallVec::new(),
            meta.clone(),
            Some(&[3.0, 4.0]),
        );

        // Interning the same Add twice must return the same node.
        let first_add = g.intern_node(
            OpKind::Add,
            SmallVec::from_slice(&[lhs, rhs]),
            meta.clone(),
            None,
        );
        let second_add = g.intern_node(OpKind::Add, SmallVec::from_slice(&[lhs, rhs]), meta, None);

        assert_eq!(first_add, second_add);
        assert_eq!(g.len(), 3);
    }
}