use std::{
    cell::Cell,
    collections::HashMap,
    env,
    fmt::Display,
    fs,
    hash::Hash,
    marker::PhantomData,
    path::Path,
    process::Command,
    rc::Rc,
    sync::{Arc, RwLock, RwLockReadGuard},
};

use crate::{device::Dev, tensor::concretetensor::from_storage, DType, Result, Shape, Tensor};

use petgraph::Graph as PetGraph;
use petgraph::{dot::Dot, graph::NodeIndex};

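/// A single operation recorded in the lazy graph, together with the shape and
/// strides of its output and the id of the tensor it produces.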
#[derive(Clone, Debug)]
pub struct GraphNode<T: DType> {
    pub op: Op<T>,
    pub shape: Vec<usize>,
    pub strides: Vec<usize>,
    pub id: GraphTensorId,
}

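/// A lazily recorded computation graph. Cloning is cheap and shares the
/// underlying op list and id counter through `Arc`s.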
#[derive(Clone)]
pub struct Graph<T: DType> {
    data: Arc<RwLock<Vec<GraphNode<T>>>>,
    id: Arc<RwLock<usize>>,
}

impl<T: DType> Graph<T> {
    pub fn empty() -> Self {
        Self {
            data: Arc::new(RwLock::new(Vec::new())),
            id: Arc::new(RwLock::new(0)),
        }
    }

    pub fn get_ops(&self) -> RwLockReadGuard<Vec<GraphNode<T>>> {
        self.data.read().unwrap()
    }

    pub(crate) fn add_op<S: Shape>(&self, op: Op<T>, strides: &[usize], id: &GraphTensorId) {
        self.data.write().unwrap().push(GraphNode {
            op,
            shape: S::shape(),
            strides: strides.to_vec(),
            id: id.clone(),
        });
    }

    #[must_use]
    pub(crate) fn next_id(&mut self) -> GraphTensorId {
        // Hold the write lock across the read and the increment so two
        // concurrent callers can never be handed the same id.
        let mut id = self.id.write().unwrap();
        let next = GraphTensorId::out_of_place(*id);
        *id += 1;
        next
    }

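    /// Renders the graph as a `petgraph` graph for inspection. `NoOp` nodes are
    /// skipped; edges are labeled with the operand role (`l`/`r`, `v`, `a`/`b`/`c`,
    /// or `o`), and a trailing `*` marks an operand that is consumed in place.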
    pub fn to_petgraph(&self) -> PetGraph<String, String> {
        let ops = self.data.read().unwrap();
        let mut g = PetGraph::<String, String>::new();
        let mut idx_map: Vec<Option<NodeIndex>> = Vec::with_capacity(ops.len());

        for op in ops.iter() {
            match op.op {
                Op::NoOp => {
                    idx_map.push(None);
                }
                _ => {
                    let label = match &op.op {
                        Op::Fill { v, .. } => format!("Fill({v:?})"),
                        Op::Arange {
                            start, step, stop, ..
                        } => {
                            format!("Arange(start={start:?}, step={step:?}, stop={stop:?})")
                        }
                        Op::Rand => "Rand".to_string(),
                        Op::Randn { mean, std } => {
                            format!("Randn(mean={mean:?}, std={std:?})")
                        }
                        Op::BinaryOp { operator, .. } => format!("BinOp({})", operator.as_c_op()),
                        Op::UnaryOp { operator, .. } => format!("UnOp({operator:?})"),
                        Op::FusedMulAdd { .. } => "FMA".to_string(),
                        Op::MatMul { .. } => "MatMul".to_string(),
                        Op::Permute { v_id: _ } => "Permute".to_string(),
                        Op::NoOp => unreachable!(),
                    };
                    let node = g.add_node(label);
                    idx_map.push(Some(node));
                }
            }
        }

        for (i, op) in ops.iter().enumerate() {
            let dst = match idx_map[i] {
                Some(dst) => dst,
                None => continue,
            };
            // Add an edge from each operand to its consumer; a trailing `*`
            // marks an operand the op overwrites in place.
            let connect = |g: &mut PetGraph<String, String>, id: &GraphTensorId, role: &str| {
                if let Some(src) = idx_map[id.get()] {
                    let mut label = role.to_string();
                    if id.is_inplace() {
                        label.push('*');
                    }
                    g.add_edge(src, dst, label);
                }
            };
            match &op.op {
                Op::BinaryOp { l_id, r_id, .. } => {
                    connect(&mut g, l_id, "l");
                    connect(&mut g, r_id, "r");
                }
                Op::UnaryOp { v_id, .. } => {
                    connect(&mut g, v_id, "v");
                }
                Op::FusedMulAdd {
                    a_id, b_id, c_id, ..
                } => {
                    connect(&mut g, a_id, "a");
                    connect(&mut g, b_id, "b");
                    connect(&mut g, c_id, "c");
                }
                Op::MatMul {
                    l_id, r_id, o_id, ..
                } => {
                    connect(&mut g, l_id, "l");
                    connect(&mut g, r_id, "r");
                    if let Some(o_id) = o_id {
                        connect(&mut g, o_id, "o");
                    }
                }
                Op::Permute { v_id, .. } => {
                    connect(&mut g, v_id, "v");
                }
                Op::NoOp | Op::Fill { .. } | Op::Arange { .. } | Op::Rand | Op::Randn { .. } => {}
            }
        }

        g
    }

    pub fn to_dot(&self) -> String {
        let g = self.to_petgraph();
        format!("{:?}", Dot::with_config(&g, &[]))
    }

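    /// Renders the graph to a PNG at `filename` by writing a temporary `.dot`
    /// file and shelling out to Graphviz, which must be installed and on `PATH`.
    ///
    /// A minimal sketch, assuming `f32: DType` in this crate:
    ///
    /// ```ignore
    /// let graph: Graph<f32> = Graph::empty();
    /// // ... record some ops ...
    /// graph.visualize("graph.png")?;
    /// ```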
    pub fn visualize<P: AsRef<Path>>(&self, filename: P) -> Result<()> {
        let path = filename.as_ref();
        let tmp_dir = env::temp_dir();
        let dot_path = tmp_dir.join("graph.dot");
        let png_path = path.to_path_buf();

        fs::write(&dot_path, self.to_dot())?;
        let status = Command::new("dot")
            .args([
                "-Tpng",
                &dot_path.display().to_string(),
                "-o",
                &png_path.display().to_string(),
            ])
            .status()?;
        if !status.success() {
            // Surface the failure as an error instead of panicking.
            crate::bail!("Graphviz `dot` exited with {status}");
        }

        Ok(())
    }

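    /// Constant folding: a binary op whose inputs are both `Fill`s, or a unary op
    /// of a `Fill`, is replaced by a `Fill` of the computed value, e.g.
    /// `Fill(2) * Fill(3)` becomes `Fill(6)`. The now-unused inputs are left for
    /// `optimize_dead_code` to remove.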
    fn optimize_const(&mut self) {
        let ops = self.data.read().unwrap().clone();
        let mut new_ops = ops.clone();
        for (i, node) in ops.iter().enumerate() {
            match &node.op {
                Op::BinaryOp {
                    l_id,
                    r_id,
                    operator,
                } => {
                    let l_idx = l_id.get();
                    let r_idx = r_id.get();
                    // Read from `new_ops` so folds propagate through chains of
                    // already-folded nodes.
                    if let Op::Fill { v: v1 } = &new_ops[l_idx].op {
                        if let Op::Fill { v: v2 } = &new_ops[r_idx].op {
                            let v = operator.as_closure()(*v1, *v2);
                            new_ops[i] = GraphNode {
                                op: Op::Fill { v },
                                ..node.clone()
                            };
                        }
                    }
                }
                Op::UnaryOp { v_id, operator } => {
                    let idx = v_id.get();
                    if let Op::Fill { v: v0 } = &new_ops[idx].op {
                        let v = operator.to_closure()(*v0);
                        new_ops[i] = GraphNode {
                            op: Op::Fill { v },
                            ..node.clone()
                        };
                    }
                }
                _ => {}
            }
        }
        *self.data.write().unwrap() = new_ops;
    }

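    /// Fuses a `Mul` immediately followed by an `Add` that consumes it into a
    /// single `FusedMulAdd`. The fused op takes over the `Add`'s slot, so
    /// downstream references stay valid, and the `Mul` becomes an unreferenced
    /// `NoOp` for `optimize_dead_code` to sweep up.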
    fn optimize_fma(&mut self) {
        let ops = self.data.read().unwrap().clone();
        let mut new_ops = ops.clone();
        #[allow(clippy::mutable_key_type)]
        let usage = Self::count_input_usage(&ops);

        for (x_id, x) in ops.iter().enumerate() {
            if let Op::BinaryOp {
                l_id: a_id,
                r_id: b_id,
                operator: BinaryOpType::Mul,
            } = &x.op
            {
                let y_id = x_id + 1;
                let Some(y) = ops.get(y_id) else {
                    continue;
                };
                if let Op::BinaryOp {
                    l_id: l_y,
                    r_id: r_y,
                    operator: BinaryOpType::Add,
                } = &y.op
                {
                    // Fuse only if the Add consumes the product, the shapes match,
                    // and the Add is the product's sole consumer.
                    if (l_y.get() == x_id || r_y.get() == x_id)
                        && x.shape == y.shape
                        && usage.get(&x.id).copied().unwrap_or(0) <= 1
                    {
                        let rhs_add = if l_y.get() == x_id { r_y } else { l_y };
                        // The fused op replaces the Add in place, keeping the Add's
                        // id, so users of the Add need no rewriting. The orphaned
                        // Mul is removed by `optimize_dead_code`, which remaps all
                        // indices in one consistent pass; filtering it out here
                        // would invalidate every index past it.
                        new_ops[y_id] = GraphNode {
                            op: Op::FusedMulAdd {
                                a_id: a_id.clone(),
                                b_id: b_id.clone(),
                                c_id: rhs_add.clone(),
                            },
                            ..y.clone()
                        };
                        new_ops[x_id] = GraphNode {
                            op: Op::NoOp,
                            ..x.clone()
                        };
                    }
                }
            }
        }

        *self.data.write().unwrap() = new_ops;
    }

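    /// Counts how many times each tensor id appears as an input to any op. Used by
    /// the in-place passes: an input used at most once can safely be overwritten.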
    #[allow(clippy::mutable_key_type)]
    fn count_input_usage(ops: &[GraphNode<T>]) -> HashMap<GraphTensorId, usize> {
        #[allow(clippy::mutable_key_type)]
        let mut usage: HashMap<GraphTensorId, usize> = HashMap::new();
        for op in ops {
            match &op.op {
                Op::BinaryOp { l_id, r_id, .. } => {
                    *usage.entry(l_id.clone()).or_default() += 1;
                    *usage.entry(r_id.clone()).or_default() += 1;
                }
                Op::UnaryOp { v_id, .. } => {
                    *usage.entry(v_id.clone()).or_default() += 1;
                }
                Op::FusedMulAdd {
                    a_id, b_id, c_id, ..
                } => {
                    *usage.entry(a_id.clone()).or_default() += 1;
                    *usage.entry(b_id.clone()).or_default() += 1;
                    *usage.entry(c_id.clone()).or_default() += 1;
                }
                Op::MatMul {
                    l_id, r_id, o_id, ..
                } => {
                    *usage.entry(l_id.clone()).or_default() += 1;
                    *usage.entry(r_id.clone()).or_default() += 1;
                    if let Some(o_id) = o_id {
                        *usage.entry(o_id.clone()).or_default() += 1;
                    }
                }
                Op::Permute { v_id } => {
                    *usage.entry(v_id.clone()).or_default() += 1;
                }
                Op::NoOp | Op::Fill { .. } | Op::Arange { .. } | Op::Rand | Op::Randn { .. } => {}
            }
        }
        usage
    }

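    /// Marks one operand of each binary op as in-place when that operand is not
    /// used by any other op, letting the backend reuse its buffer for the output.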
    fn optimize_inplace_bin(&mut self) {
        let ops = self.data.read().unwrap().clone();
        let mut new_ops = ops.clone();
        #[allow(clippy::mutable_key_type)]
        let usage = Self::count_input_usage(&ops);
        for (i, op) in ops.iter().enumerate() {
            if let Op::BinaryOp {
                l_id,
                r_id,
                operator,
            } = &op.op
            {
                let l_use = usage.get(l_id).copied().unwrap_or(0);
                let r_use = usage.get(r_id).copied().unwrap_or(0);
                if l_use <= 1 || r_use <= 1 {
                    // Overwrite whichever operand has no other users; the guard
                    // above guarantees at least one of them qualifies.
                    let target = if l_use <= 1 {
                        l_id.clone()
                    } else {
                        r_id.clone()
                    };
                    new_ops[i] = GraphNode {
                        op: Op::BinaryOp {
                            l_id: l_id.clone().to_inplace_if(&target == l_id),
                            r_id: r_id.clone().to_inplace_if(&target == r_id),
                            operator: *operator,
                        },
                        ..op.clone()
                    };
                }
            }
        }
        *self.data.write().unwrap() = new_ops;
    }

    fn optimize_inplace_fma(&mut self) {
        let ops = self.data.read().unwrap().clone();
        let mut new_ops = ops.clone();
        #[allow(clippy::mutable_key_type)]
        let usage = Self::count_input_usage(&ops);
        for (i, op) in ops.iter().enumerate() {
            if let Op::FusedMulAdd { a_id, b_id, c_id } = &op.op {
                // Pick the first input with no other users as the in-place target.
                let mut target = None;
                if *usage.get(a_id).unwrap_or(&0) <= 1 {
                    target = Some(a_id.clone());
                } else if *usage.get(b_id).unwrap_or(&0) <= 1 {
                    target = Some(b_id.clone());
                } else if *usage.get(c_id).unwrap_or(&0) <= 1 {
                    target = Some(c_id.clone());
                }
                if let Some(out) = target {
                    new_ops[i] = GraphNode {
                        op: Op::FusedMulAdd {
                            a_id: a_id.clone().to_inplace_if(&out == a_id),
                            b_id: b_id.clone().to_inplace_if(&out == b_id),
                            c_id: c_id.clone().to_inplace_if(&out == c_id),
                        },
                        ..op.clone()
                    };
                }
            }
        }
        *self.data.write().unwrap() = new_ops;
    }

    fn optimize_inplace_matmul(&mut self) {
        let ops = self.data.read().unwrap().clone();
        let mut new_ops = ops.clone();
        #[allow(clippy::mutable_key_type)]
        let usage = Self::count_input_usage(&ops);
        for (i, op) in ops.iter().enumerate() {
            if let Op::MatMul {
                o_id: Some(o_id),
                l_id,
                r_id,
                k,
                alpha,
                beta,
            } = &op.op
            {
                // Accumulate into `o` directly when nothing else reads it.
                let o_use = usage.get(o_id).copied().unwrap_or(0);
                if o_use <= 1 {
                    new_ops[i] = GraphNode {
                        op: Op::MatMul {
                            o_id: Some(o_id.to_inplace()),
                            l_id: l_id.clone(),
                            r_id: r_id.clone(),
                            k: *k,
                            alpha: *alpha,
                            beta: *beta,
                        },
                        ..op.clone()
                    };
                }
            }
        }
        *self.data.write().unwrap() = new_ops;
    }

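    /// Dead-code elimination: walks backwards from the final node (the graph's
    /// output), keeps every op reachable from it, drops the rest, and remaps all
    /// tensor ids to the compacted indices.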
    fn optimize_dead_code(&mut self) {
        let old_ops = self.data.read().unwrap().clone();
        let n = old_ops.len();
        let mut keep = vec![false; n];
        if n > 0 {
            // The last node is the graph output and is always live.
            keep[n - 1] = true;
        }
        for i in (0..n).rev() {
            if keep[i] {
                match &old_ops[i].op {
                    Op::BinaryOp { l_id, r_id, .. } => {
                        keep[l_id.get()] = true;
                        keep[r_id.get()] = true;
                    }
                    Op::UnaryOp { v_id, .. } => {
                        keep[v_id.get()] = true;
                    }
                    Op::FusedMulAdd {
                        a_id, b_id, c_id, ..
                    } => {
                        keep[a_id.get()] = true;
                        keep[b_id.get()] = true;
                        keep[c_id.get()] = true;
                    }
                    Op::MatMul {
                        l_id, r_id, o_id, ..
                    } => {
                        keep[l_id.get()] = true;
                        keep[r_id.get()] = true;
                        if let Some(o_id) = o_id {
                            keep[o_id.get()] = true;
                        }
                    }
                    Op::Permute { v_id, .. } => {
                        keep[v_id.get()] = true;
                    }
                    Op::NoOp
                    | Op::Fill { .. }
                    | Op::Arange { .. }
                    | Op::Rand
                    | Op::Randn { .. } => (),
                }
            }
        }
        let mut index_map = HashMap::new();
        let mut new_ops = Vec::new();
        for (old_idx, node) in old_ops.into_iter().enumerate() {
            if keep[old_idx] {
                index_map.insert(old_idx, new_ops.len());
                new_ops.push(node);
            }
        }
        // Tensor ids are shared `Rc<Cell<_>>`s: ops consuming the same tensor hold
        // clones of one cell. Dedupe by pointer so each cell is remapped exactly
        // once; applying the index map twice would corrupt the id.
        let mut remapped = std::collections::HashSet::new();
        for node in new_ops.iter() {
            let ids: Vec<&GraphTensorId> = match &node.op {
                Op::BinaryOp { l_id, r_id, .. } => vec![l_id, r_id],
                Op::UnaryOp { v_id, .. } => vec![v_id],
                Op::FusedMulAdd {
                    a_id, b_id, c_id, ..
                } => vec![a_id, b_id, c_id],
                Op::MatMul {
                    l_id, r_id, o_id, ..
                } => o_id
                    .as_ref()
                    .map(|o| vec![l_id, r_id, o])
                    .unwrap_or_else(|| vec![l_id, r_id]),
                Op::Permute { v_id } => vec![v_id],
                _ => vec![],
            };
            for id in ids {
                let cell = match id {
                    GraphTensorId::OutOfPlace(c) | GraphTensorId::InPlace(c) => c,
                };
                if remapped.insert(Rc::as_ptr(cell)) {
                    id.set(*index_map.get(&id.get()).unwrap());
                }
            }
        }
        *self.data.write().unwrap() = new_ops;
    }

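    /// Runs all optimization passes in order: constant folding, mul-add fusion,
    /// the in-place rewrites, and finally dead-code elimination, which also sweeps
    /// up the nodes the earlier passes orphaned.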
    pub fn optimize(&mut self) {
        self.optimize_const();
        self.optimize_fma();
        self.optimize_inplace_bin();
        self.optimize_inplace_fma();
        self.optimize_inplace_matmul();
        self.optimize_dead_code();
    }

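    /// Checks that the graph's output shape matches `S`, then hands the op list to
    /// the device backend for compilation.
    ///
    /// A minimal sketch, assuming `f32: DType` and hypothetical `R1<4>` shape and
    /// `Cpu` device types from this crate:
    ///
    /// ```ignore
    /// let mut graph: Graph<f32> = Graph::empty();
    /// // ... record ops whose final output has shape [4] ...
    /// graph.optimize();
    /// let compiled = graph.compile::<R1<4>, Cpu>()?;
    /// let out = compiled.run()?;
    /// ```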
    pub fn compile<S: Shape, D: Dev>(self) -> Result<CompiledGraph<S, T, D>> {
        {
            let data = self.data.read().unwrap();
            if let Some(last) = data.last() {
                if last.shape != S::shape() {
                    crate::bail!(
                        "Graph compile shape {:?} does not match the last node shape {:?}!",
                        S::shape(),
                        &last.shape
                    );
                }
            }
        }

        let device = D::resolve()?;

        device.compile(self.data.read().unwrap().clone())
    }
}

pub enum CompiledGraph<S: Shape, T: DType, D: Dev> {
    Cpu {
        order: Vec<usize>,
        graph: Vec<GraphNode<T>>,
        ghost: PhantomData<(S, T, D)>,
    },
    #[cfg(feature = "cuda")]
    Cuda {
        kernels: Vec<crate::cuda_backend::CudaCompiledKernel<T>>,
        ghost: PhantomData<(S, T, D)>,
    },
}

impl<S: Shape, T: DType, D: Dev> CompiledGraph<S, T, D> {
    pub fn run(&self) -> Result<Tensor<S, T, D>> {
        let device = D::resolve()?;
        let storage = device.run_graph(self)?;
        Ok(from_storage(Arc::new(storage)))
    }
}

#[derive(PartialEq, Debug, Clone, Copy)]
pub enum BinaryOpType {
    Add,
    Div,
    Sub,
    Mul,
}

impl BinaryOpType {
    pub fn as_c_op(&self) -> &'static str {
        match self {
            Self::Add => "+",
            Self::Div => "/",
            Self::Sub => "-",
            Self::Mul => "*",
        }
    }

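    /// Returns the operator as a plain function, used by constant folding. A
    /// minimal sketch, assuming `f32: DType` in this crate:
    ///
    /// ```ignore
    /// let add = BinaryOpType::Add.as_closure::<f32>();
    /// assert_eq!(add(2.0, 3.0), 5.0);
    /// ```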
    pub fn as_closure<T: DType>(&self) -> impl Fn(T, T) -> T {
        match self {
            Self::Add => |x, y| x + y,
            Self::Div => |x, y| x / y,
            Self::Sub => |x, y| x - y,
            Self::Mul => |x, y| x * y,
        }
    }
}

#[derive(PartialEq, Debug, Clone)]
pub enum UnaryOpType {
    Neg,
    Sqrt,
}

impl UnaryOpType {
    pub fn fill_in_c_op(&self, val: impl Display) -> String {
        match self {
            Self::Neg => format!("-{val}"),
            Self::Sqrt => format!("static_cast<T>( sqrt( static_cast<double>({val}) ) )"),
        }
    }

    pub fn to_closure<T: DType>(&self) -> impl Fn(T) -> T {
        match self {
            Self::Neg => T::maybe_neg,
            Self::Sqrt => |x: T| x.sqrt(),
        }
    }
}

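/// The operations a graph node can perform. `MatMul` carries GEMM-style `k`,
/// `alpha`, and `beta` parameters whose exact semantics are defined by the
/// device backends.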
#[derive(PartialEq, Debug, Clone)]
pub enum Op<T: DType> {
    Fill {
        v: T,
    },
    Arange {
        start: T,
        step: T,
        stop: T,
    },
    BinaryOp {
        l_id: GraphTensorId,
        r_id: GraphTensorId,
        operator: BinaryOpType,
    },
    UnaryOp {
        v_id: GraphTensorId,
        operator: UnaryOpType,
    },
    FusedMulAdd {
        a_id: GraphTensorId,
        b_id: GraphTensorId,
        c_id: GraphTensorId,
    },
    MatMul {
        l_id: GraphTensorId,
        r_id: GraphTensorId,
        o_id: Option<GraphTensorId>,
        k: usize,
        alpha: T,
        beta: T,
    },
    Rand,
    Randn {
        mean: T,
        std: T,
    },
    Permute {
        v_id: GraphTensorId,
    },
    NoOp,
}

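/// The id of a tensor produced by a graph node. The index lives in a shared
/// `Rc<Cell<usize>>`, so clones of an id all observe updates made through `set`;
/// the `InPlace` variant marks an operand whose buffer the op may overwrite.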
#[derive(Clone, PartialEq, Debug, Eq)]
pub enum GraphTensorId {
    OutOfPlace(Rc<Cell<usize>>),
    InPlace(Rc<Cell<usize>>),
}

impl Hash for GraphTensorId {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        state.write_usize(self.get());
    }
}

impl GraphTensorId {
    pub fn out_of_place(value: usize) -> Self {
        Self::OutOfPlace(Rc::new(Cell::new(value)))
    }

    pub fn inplace(value: usize) -> Self {
        Self::InPlace(Rc::new(Cell::new(value)))
    }

    pub fn to_inplace(&self) -> Self {
        match self {
            Self::OutOfPlace(x) | Self::InPlace(x) => Self::inplace(x.get()),
        }
    }

    pub fn to_inplace_if(&self, predicate: bool) -> Self {
        match self {
            Self::OutOfPlace(x) | Self::InPlace(x) if predicate => Self::inplace(x.get()),
            _ => self.clone(),
        }
    }

    pub fn get(&self) -> usize {
        match self {
            GraphTensorId::InPlace(x) | GraphTensorId::OutOfPlace(x) => x.get(),
        }
    }

    pub fn set(&self, value: usize) {
        match self {
            GraphTensorId::InPlace(x) | GraphTensorId::OutOfPlace(x) => x.set(value),
        }
    }

    pub fn is_inplace(&self) -> bool {
        matches!(self, Self::InPlace(_))
    }
}
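
#[cfg(test)]
mod tests {
    use super::*;

    // Sketch tests for the id semantics this module relies on: clones of a
    // `GraphTensorId` share one `Rc<Cell<usize>>`, while `to_inplace*` detaches
    // into a fresh cell.
    #[test]
    fn graph_tensor_id_clones_share_the_index_cell() {
        let id = GraphTensorId::out_of_place(3);
        let clone = id.clone();
        id.set(7);
        // The clone observes the update because the cell is shared.
        assert_eq!(clone.get(), 7);

        // `to_inplace` copies the current value into a new, unshared cell.
        let inplace = id.to_inplace();
        assert!(inplace.is_inplace());
        id.set(9);
        assert_eq!(inplace.get(), 7);
    }

    #[test]
    fn to_inplace_if_only_converts_when_predicate_holds() {
        let id = GraphTensorId::out_of_place(0);
        assert!(!id.to_inplace_if(false).is_inplace());
        assert!(id.to_inplace_if(true).is_inplace());
    }
}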