// constensor_core/graph.rs

1use std::{
2    cell::Cell,
3    collections::HashMap,
4    env,
5    fmt::Display,
6    fs,
7    hash::Hash,
8    marker::PhantomData,
9    path::Path,
10    process::Command,
11    rc::Rc,
12    sync::{Arc, RwLock, RwLockReadGuard},
13};
14
15use crate::{device::Dev, tensor::concretetensor::from_storage, DType, Result, Shape, Tensor};
16
17use petgraph::Graph as PetGraph;
18use petgraph::{dot::Dot, graph::NodeIndex};
19
/// A single node in the lazy computation graph: the operation to perform plus
/// the shape/strides of the tensor it produces and the id by which other
/// nodes reference its output.
#[derive(Clone, Debug)]
pub struct GraphNode<T: DType> {
    /// The operation this node performs.
    pub op: Op<T>,
    /// Shape of this node's output tensor.
    pub shape: Vec<usize>,
    /// Strides of this node's output buffer (layout convention not visible
    /// here — see the backend that consumes them).
    pub strides: Vec<usize>,
    /// Id under which downstream nodes refer to this node's output.
    pub id: GraphTensorId,
}
27
/// A lazily-built computation graph of [`GraphNode`]s.
///
/// Cloning is cheap: clones share the same op list and id counter through
/// `Arc`s, so ops added via any clone are visible to all.
#[derive(Clone)]
pub struct Graph<T: DType> {
    // Ordered list of nodes; a node's position in this Vec doubles as the
    // index stored in its GraphTensorId.
    data: Arc<RwLock<Vec<GraphNode<T>>>>,
    // Next id to hand out from `next_id`.
    id: Arc<RwLock<usize>>,
}
33
34impl<T: DType> Graph<T> {
35    /// Create an empty Graph
36    pub fn empty() -> Self {
37        Self {
38            data: Arc::new(RwLock::new(Vec::new())),
39            id: Arc::new(RwLock::new(0)),
40        }
41    }
42
    /// Read-only access to the list of operations.
    ///
    /// The returned guard holds the graph's read lock for its lifetime;
    /// drop it before calling anything that takes the write lock.
    pub fn get_ops(&self) -> RwLockReadGuard<Vec<GraphNode<T>>> {
        self.data.read().unwrap()
    }
47
48    /// Append an operation to the graph
49    pub(crate) fn add_op<S: Shape>(&self, op: Op<T>, strides: &[usize], id: &GraphTensorId) {
50        self.data.write().unwrap().push(GraphNode {
51            op,
52            shape: S::shape(),
53            strides: strides.to_vec(),
54            id: id.clone(),
55        });
56    }
57
58    /// Generate the next unique tensor ID
59    #[must_use]
60    pub(crate) fn next_id(&mut self) -> GraphTensorId {
61        let next = GraphTensorId::out_of_place(*self.id.read().unwrap());
62        *self.id.write().unwrap() += 1;
63        next
64    }
65
    /// Build a `petgraph` representation of this graph for visualization.
    ///
    /// Each non-`NoOp` node becomes a graph node labelled with a short
    /// description of its op. Edges run from an operand's producer to its
    /// consumer and are labelled with the operand role (`l`/`r`/`v`/`a`/`b`/
    /// `c`/`o`), with a trailing `*` when that operand is consumed in place.
    pub fn to_petgraph(&self) -> PetGraph<String, String> {
        let ops = self.data.read().unwrap();
        let mut g = PetGraph::<String, String>::new();
        // map from op‐index → Some(node) if we created a node, or None if it was a NoOp
        let mut idx_map: Vec<Option<NodeIndex>> = Vec::with_capacity(ops.len());

        // 1) Add only non‐NoOp nodes
        for op in ops.iter() {
            match op.op {
                Op::NoOp => {
                    // Keep a placeholder so idx_map stays index-aligned with `ops`.
                    idx_map.push(None);
                }
                _ => {
                    let label = match &op.op {
                        Op::Fill { v, .. } => format!("Fill({v:?})"),
                        Op::Arange {
                            start, step, stop, ..
                        } => {
                            format!("Arange(start={start:?}, step={step:?}, stop={stop:?})")
                        }
                        Op::Rand => "Rand".to_string(),
                        Op::Randn { mean, std } => {
                            format!("Randn(mean={mean:?}, std={std:?})")
                        }
                        Op::BinaryOp { operator, .. } => format!("BinOp({})", operator.as_c_op()),
                        Op::UnaryOp { operator, .. } => format!("UnOp({operator:?})"),
                        Op::FusedMulAdd { .. } => "FMA".to_string(),
                        // Matrix multiplication
                        Op::MatMul { .. } => "MatMul".to_string(),
                        Op::Permute { v_id: _ } => "Permute".to_string(),
                        // we already matched NoOp above
                        Op::NoOp => unreachable!(),
                    };
                    let node = g.add_node(label);
                    idx_map.push(Some(node));
                }
            }
        }

        // 2) Walk ops again and only connect edges for those dst nodes that exist
        for (i, op) in ops.iter().enumerate() {
            // if this op was NoOp, skip entirely
            let dst = match idx_map[i] {
                Some(dst) => dst,
                None => continue,
            };
            match &op.op {
                Op::BinaryOp { l_id, r_id, .. } => {
                    if let Some(src) = idx_map[l_id.get()] {
                        let mut label = "l".to_string();
                        if l_id.is_inplace() {
                            label.push('*');
                        }
                        g.add_edge(src, dst, label.clone());
                    }
                    if let Some(src) = idx_map[r_id.get()] {
                        let mut label = "r".to_string();
                        if r_id.is_inplace() {
                            label.push('*');
                        }
                        g.add_edge(src, dst, label.clone());
                    }
                }
                Op::UnaryOp { v_id, .. } => {
                    if let Some(src) = idx_map[v_id.get()] {
                        let mut label = "v".to_string();
                        if v_id.is_inplace() {
                            label.push('*');
                        }
                        g.add_edge(src, dst, label.clone());
                    }
                }
                Op::FusedMulAdd {
                    a_id, b_id, c_id, ..
                } => {
                    // a * b + c: one edge per operand, labelled by role.
                    for (prefix, src_id) in [("a", a_id), ("b", b_id), ("c", c_id)].iter() {
                        if let Some(src) = idx_map[src_id.get()] {
                            let mut label = prefix.to_string();
                            if src_id.is_inplace() {
                                label.push('*');
                            }
                            g.add_edge(src, dst, label.clone());
                        }
                    }
                }
                Op::MatMul {
                    l_id, r_id, o_id, ..
                } => {
                    if let Some(src) = idx_map[l_id.get()] {
                        let mut label = "l".to_string();
                        if l_id.is_inplace() {
                            label.push('*');
                        }
                        g.add_edge(src, dst, label.clone());
                    }
                    if let Some(src) = idx_map[r_id.get()] {
                        let mut label = "r".to_string();
                        if r_id.is_inplace() {
                            label.push('*');
                        }
                        g.add_edge(src, dst, label.clone());
                    }
                    // Optional accumulator input.
                    if let Some(o_id) = o_id {
                        if let Some(src) = idx_map[o_id.get()] {
                            let mut label = "o".to_string();
                            if o_id.is_inplace() {
                                label.push('*');
                            }
                            g.add_edge(src, dst, label.clone());
                        }
                    }
                }
                Op::Permute { v_id, .. } => {
                    if let Some(src) = idx_map[v_id.get()] {
                        let mut label = "v".to_string();
                        if v_id.is_inplace() {
                            label.push('*');
                        }
                        g.add_edge(src, dst, label.clone());
                    }
                }
                // NoOp, Fill/Arange, Rand/Randn don’t create incoming edges
                Op::NoOp | Op::Fill { .. } | Op::Arange { .. } | Op::Rand | Op::Randn { .. } => {}
            }
        }

        g
    }
194
195    /// Produce a DOT format string of this graph.
196    pub fn to_dot(&self) -> String {
197        let g = self.to_petgraph();
198        format!("{:?}", Dot::with_config(&g, &[]))
199    }
200
201    /// Visualize the graph by saving it to this file.
202    ///
203    /// Install graphvis:
204    /// - brew install graphviz
205    /// - apt install graphviz
206    pub fn visualize<P: AsRef<Path>>(&self, filename: P) -> Result<()> {
207        let path = filename.as_ref();
208        let tmp_dir = env::temp_dir();
209        let dot_path = tmp_dir.join("graph.dot");
210        let png_path = path.to_path_buf();
211
212        fs::write(&dot_path, self.to_dot())?;
213        let status = Command::new("dot")
214            .args([
215                "-Tpng",
216                &dot_path.display().to_string(),
217                "-o",
218                &png_path.display().to_string(),
219            ])
220            .status()?;
221        if !status.success() {
222            panic!("Graphviz failed");
223        }
224
225        Ok(())
226    }
227
228    /// Optimize by performing constant folding:
229    ///   - Fold BinaryOp and UnaryOp when all operands are constant Fill ops.
230    fn optimize_const(&mut self) {
231        // Clone current ops for inspection
232        let ops = self.data.read().unwrap().clone();
233        let mut new_ops = ops.clone();
234        for (i, node) in ops.iter().enumerate() {
235            match &node.op {
236                Op::BinaryOp {
237                    l_id,
238                    r_id,
239                    operator,
240                } => {
241                    let l_idx = l_id.get();
242                    let r_idx = r_id.get();
243                    // both operands are constant fills
244                    if let Op::Fill { v: v1 } = &new_ops[l_idx].op {
245                        if let Op::Fill { v: v2 } = &new_ops[r_idx].op {
246                            let v = operator.as_closure()(*v1, *v2);
247                            new_ops[i] = GraphNode {
248                                op: Op::Fill { v },
249                                ..node.clone()
250                            };
251                        }
252                    }
253                }
254                Op::UnaryOp { v_id, operator } => {
255                    let idx = v_id.get();
256                    // operand is a constant fill
257                    if let Op::Fill { v: v0 } = &new_ops[idx].op {
258                        let v = operator.to_closure()(*v0);
259                        new_ops[i] = GraphNode {
260                            op: Op::Fill { v },
261                            ..node.clone()
262                        };
263                    }
264                }
265                _ => {}
266            }
267        }
268        // Commit folded constants
269        *self.data.write().unwrap() = new_ops;
270    }
271
272    /// Optimize by looking for mul-add pairs, convert to FMA
273    fn optimize_fma(&mut self) {
274        let ops = self.data.write().unwrap().clone();
275        let mut new_ops = ops.clone();
276
277        // This contains the indices of the first of the pair.
278        for (x_id, x) in ops.iter().enumerate() {
279            if let Op::BinaryOp {
280                l_id: a_id,
281                r_id: b_id,
282                operator: BinaryOpType::Mul,
283            } = &x.op
284            {
285                // Check if next op uses this
286                if let Op::BinaryOp {
287                    l_id: l_y,
288                    r_id: r_y,
289                    operator: BinaryOpType::Add,
290                } = &ops[x_id + 1].op
291                {
292                    let y_id = x_id + 1;
293                    if l_y.get() == x_id || r_y.get() == x_id && x.shape == ops[x_id + 1].shape {
294                        // Want to see what is being added to the result of the mul
295                        let rhs_add = if l_y.get() == x_id { r_y } else { l_y };
296                        new_ops[y_id] = GraphNode {
297                            op: Op::FusedMulAdd {
298                                a_id: a_id.clone(),
299                                b_id: b_id.clone(),
300                                c_id: rhs_add.clone(),
301                            },
302                            ..x.clone()
303                        };
304                        new_ops[x_id] = GraphNode {
305                            op: Op::NoOp,
306                            ..x.clone()
307                        };
308
309                        // Look for ops which actually use this one
310                        for user in new_ops.iter() {
311                            let ids = match &user.op {
312                                Op::Arange {
313                                    start: _,
314                                    step: _,
315                                    stop: _,
316                                    ..
317                                } => vec![],
318                                Op::Rand => vec![],
319                                Op::Randn { mean: _, std: _ } => vec![],
320                                Op::BinaryOp { l_id, r_id, .. } => vec![l_id, r_id],
321                                Op::Fill { v: _, .. } => vec![],
322                                Op::UnaryOp {
323                                    v_id, operator: _, ..
324                                } => vec![v_id],
325                                Op::FusedMulAdd {
326                                    a_id, b_id, c_id, ..
327                                } => {
328                                    vec![a_id, b_id, c_id]
329                                }
330                                Op::MatMul {
331                                    l_id, r_id, o_id, ..
332                                } => o_id
333                                    .as_ref()
334                                    .map(|o| vec![l_id, r_id, o])
335                                    .unwrap_or(vec![l_id, r_id]),
336                                Op::Permute { v_id } => vec![v_id],
337                                Op::NoOp => vec![],
338                            };
339
340                            // We are going to remove the noop so this is necessary to fix the indices.
341                            let used_ids = ids
342                                .into_iter()
343                                .filter(|id| id.get() == y_id)
344                                .collect::<Vec<_>>();
345                            if !used_ids.is_empty() {
346                                for id in used_ids {
347                                    // Tell the ops which use the result of the fma to source from there
348                                    id.set(x_id);
349                                }
350                            }
351                        }
352                    }
353                }
354            }
355        }
356
357        // Remove any NoOp entries before storing back to the graph
358        let filtered_ops = new_ops
359            .into_iter()
360            .filter(|op| !matches!(op.op, Op::NoOp))
361            .collect::<Vec<_>>();
362        *self.data.write().unwrap() = filtered_ops;
363    }
364
365    /// Count how often each tensor id is used as an input.
366    #[allow(clippy::mutable_key_type)]
367    fn count_input_usage(ops: &[GraphNode<T>]) -> HashMap<GraphTensorId, usize> {
368        #[allow(clippy::mutable_key_type)]
369        let mut usage: HashMap<GraphTensorId, usize> = HashMap::new();
370        for op in ops {
371            match &op.op {
372                Op::BinaryOp { l_id, r_id, .. } => {
373                    *usage.entry(l_id.clone()).or_default() += 1;
374                    *usage.entry(r_id.clone()).or_default() += 1;
375                }
376                Op::UnaryOp { v_id, .. } => {
377                    *usage.entry(v_id.clone()).or_default() += 1;
378                }
379                Op::FusedMulAdd {
380                    a_id, b_id, c_id, ..
381                } => {
382                    *usage.entry(a_id.clone()).or_default() += 1;
383                    *usage.entry(b_id.clone()).or_default() += 1;
384                    *usage.entry(c_id.clone()).or_default() += 1;
385                }
386                Op::MatMul {
387                    l_id, r_id, o_id, ..
388                } => {
389                    *usage.entry(l_id.clone()).or_default() += 1;
390                    *usage.entry(r_id.clone()).or_default() += 1;
391                    if let Some(o_id) = o_id {
392                        *usage.entry(o_id.clone()).or_default() += 1;
393                    }
394                }
395                Op::Permute { v_id } => {
396                    *usage.entry(v_id.clone()).or_default() += 1;
397                }
398                // No input usage for these ops
399                Op::NoOp | Op::Fill { .. } | Op::Arange { .. } | Op::Rand | Op::Randn { .. } => {}
400            }
401        }
402        usage
403    }
404
405    /// Optimize by inplacing binary operations when inputs are not reused.
406    fn optimize_inplace_bin(&mut self) {
407        let ops = self.data.write().unwrap().clone();
408        let mut new_ops = ops.clone();
409        #[allow(clippy::mutable_key_type)]
410        let usage = Self::count_input_usage(&ops);
411        // Transform eligible BinaryOps into InplaceBinaryOps.
412        for (i, op) in ops.iter().enumerate() {
413            if let Op::BinaryOp {
414                l_id,
415                r_id,
416                operator,
417            } = &op.op
418            {
419                let l_use = usage.get(l_id).copied().unwrap_or(0);
420                let r_use = usage.get(r_id).copied().unwrap_or(0);
421                if l_use <= 1 || r_use <= 1 {
422                    // Choose target for in-place: if both, default to lhs.
423                    let target = if r_use > l_use {
424                        r_id.clone()
425                    } else {
426                        l_id.clone()
427                    };
428                    // Replace with InplaceBinaryOp
429                    new_ops[i] = GraphNode {
430                        op: Op::BinaryOp {
431                            l_id: l_id.clone().to_inplace_if(&target == l_id),
432                            r_id: r_id.clone().to_inplace_if(&target == r_id),
433                            operator: *operator,
434                        },
435                        ..op.clone()
436                    };
437                }
438            }
439        }
440        // Commit the transformed op list.
441        *self.data.write().unwrap() = new_ops;
442    }
443
    /// Optimize by inplacing fused multiply-add (FMA) operations when inputs are not reused.
    fn optimize_inplace_fma(&mut self) {
        let ops = self.data.write().unwrap().clone();
        let mut new_ops = ops.clone();
        #[allow(clippy::mutable_key_type)]
        let usage = Self::count_input_usage(&ops);
        for (i, op) in ops.iter().enumerate() {
            if let Op::FusedMulAdd { a_id, b_id, c_id } = &op.op {
                let mut target = None;
                // If an input is used only once, we can reuse its buffer; default order: a_id, then b_id, then c_id
                if *usage.get(a_id).unwrap_or(&0) <= 1 {
                    target = Some(a_id.clone());
                } else if *usage.get(b_id).unwrap_or(&0) <= 1 {
                    target = Some(b_id.clone());
                } else if *usage.get(c_id).unwrap_or(&0) <= 1 {
                    target = Some(c_id.clone());
                }
                if let Some(out) = target {
                    // Re-emit the FMA with the chosen operand flagged in-place.
                    // NOTE(review): if two operands share the same id value,
                    // `to_inplace_if` flags both — presumably harmless since
                    // they alias the same buffer; confirm in the backends.
                    new_ops[i] = GraphNode {
                        op: Op::FusedMulAdd {
                            a_id: a_id.clone().to_inplace_if(&out == a_id),
                            b_id: b_id.clone().to_inplace_if(&out == b_id),
                            c_id: c_id.clone().to_inplace_if(&out == c_id),
                        },
                        ..op.clone()
                    };
                }
            }
        }
        *self.data.write().unwrap() = new_ops;
    }
475
    /// Optimize by inplacing the output of a matmul when inputs are not reused.
    fn optimize_inplace_matmul(&mut self) {
        let ops = self.data.write().unwrap().clone();
        let mut new_ops = ops.clone();
        #[allow(clippy::mutable_key_type)]
        let usage = Self::count_input_usage(&ops);
        // Mark the accumulator `o_id` as in-place when no other op reads it.
        for (i, op) in ops.iter().enumerate() {
            if let Op::MatMul {
                o_id: Some(o_id),
                l_id,
                r_id,
                k,
                alpha,
                beta,
            } = &op.op
            {
                let o_use = usage.get(o_id).copied().unwrap_or(0);
                if o_use <= 1 {
                    // Safe to accumulate directly into `o_id`'s buffer.
                    new_ops[i] = GraphNode {
                        op: Op::MatMul {
                            o_id: Some(o_id.to_inplace()),
                            l_id: l_id.clone(),
                            r_id: r_id.clone(),
                            k: *k,
                            alpha: *alpha,
                            beta: *beta,
                        },
                        ..op.clone()
                    };
                }
            }
        }
        // Commit the transformed op list.
        *self.data.write().unwrap() = new_ops;
    }
513
    /// Remove nodes whose outputs are never used, except the final output node.
    ///
    /// Two passes: (1) mark every node reachable from the last node (the
    /// graph output); (2) rebuild the op list with only kept nodes and
    /// rewrite all tensor ids to the compacted indices.
    fn optimize_dead_code(&mut self) {
        // Clone current ops
        let old_ops = self.data.read().unwrap().clone();
        let n = old_ops.len();
        // Mark reachable nodes: start from final output
        let mut keep = vec![false; n];
        if n > 0 {
            keep[n - 1] = true;
        }
        // Propagate reachability backwards. A single reverse sweep suffices
        // assuming inputs always precede their consumers in the list — TODO
        // confirm this invariant holds after the other optimization passes.
        for i in (0..n).rev() {
            if keep[i] {
                match &old_ops[i].op {
                    Op::BinaryOp { l_id, r_id, .. } => {
                        keep[l_id.get()] = true;
                        keep[r_id.get()] = true;
                    }
                    Op::UnaryOp { v_id, .. } => {
                        keep[v_id.get()] = true;
                    }
                    Op::FusedMulAdd {
                        a_id, b_id, c_id, ..
                    } => {
                        keep[a_id.get()] = true;
                        keep[b_id.get()] = true;
                        keep[c_id.get()] = true;
                    }
                    Op::MatMul {
                        l_id, r_id, o_id, ..
                    } => {
                        keep[l_id.get()] = true;
                        keep[r_id.get()] = true;
                        if let Some(o_id) = o_id {
                            keep[o_id.get()] = true;
                        }
                    }
                    Op::Permute { v_id, .. } => {
                        keep[v_id.get()] = true;
                    }
                    // Source ops have no inputs to mark.
                    Op::NoOp
                    | Op::Fill { .. }
                    | Op::Arange { .. }
                    | Op::Rand
                    | Op::Randn { .. } => (),
                }
            }
        }
        // Build new ops and map old indices to new indices
        let mut index_map = std::collections::HashMap::new();
        let mut new_ops = Vec::new();
        for (old_idx, node) in old_ops.into_iter().enumerate() {
            if keep[old_idx] {
                let new_idx = new_ops.len();
                index_map.insert(old_idx, new_idx);
                new_ops.push(node);
            }
        }
        // Update tensor IDs in remaining ops.
        // NOTE(review): ids are shared `Rc<Cell<_>>`s; if two nodes hold
        // clones of the same cell, the second visit re-reads an already
        // remapped index and looks it up again — TODO confirm ids are never
        // shared across nodes at this point.
        for node in new_ops.iter_mut() {
            match &mut node.op {
                Op::BinaryOp { l_id, r_id, .. } => {
                    let old_l = l_id.get();
                    let old_r = r_id.get();
                    l_id.set(*index_map.get(&old_l).unwrap());
                    r_id.set(*index_map.get(&old_r).unwrap());
                }
                Op::UnaryOp { v_id, .. } => {
                    let old_v = v_id.get();
                    v_id.set(*index_map.get(&old_v).unwrap());
                }
                Op::FusedMulAdd {
                    a_id, b_id, c_id, ..
                } => {
                    let old_a = a_id.get();
                    let old_b = b_id.get();
                    let old_c = c_id.get();
                    a_id.set(*index_map.get(&old_a).unwrap());
                    b_id.set(*index_map.get(&old_b).unwrap());
                    c_id.set(*index_map.get(&old_c).unwrap());
                }
                Op::MatMul {
                    l_id, r_id, o_id, ..
                } => {
                    let old_l = l_id.get();
                    let old_r = r_id.get();
                    l_id.set(*index_map.get(&old_l).unwrap());
                    r_id.set(*index_map.get(&old_r).unwrap());
                    if let Some(o_id) = o_id {
                        let old_o = o_id.get();
                        o_id.set(*index_map.get(&old_o).unwrap());
                    }
                }
                _ => {}
            }
        }
        // Commit pruned graph
        *self.data.write().unwrap() = new_ops;
    }
613
    /// Optimize this graph.
    ///
    /// Apply the following optimizations, in order (the order matters:
    /// folding exposes fusions, fusions must exist before they can be
    /// in-placed, and dead-code removal runs last to compact indices):
    /// - Constant folding of elementwise fills
    /// - Fuse mul-add into FMA
    /// - Inplace binary operations when safe
    /// - Inplace fused multiply-add when safe
    /// - Inplace matrix-multiplication when safe
    /// - Dead code removal
    pub fn optimize(&mut self) {
        // Constant folding first
        self.optimize_const();
        // Fuse mul-add into FMA
        self.optimize_fma();
        self.optimize_inplace_bin();
        self.optimize_inplace_fma();
        self.optimize_inplace_matmul();
        // Remove dead code
        self.optimize_dead_code();
    }
634
635    /// Compile this graph and insert device-specific optimizations such as CUDA streams.
636    pub fn compile<S: Shape, D: Dev>(self) -> Result<CompiledGraph<S, T, D>> {
637        if self
638            .data
639            .read()
640            .unwrap()
641            .last()
642            .is_some_and(|last| last.shape != S::shape())
643        {
644            let read = self.data.read();
645            let last = read.as_ref().unwrap().last().unwrap();
646
647            crate::bail!(
648                "Graph compiled shape is {:?} does not match the last node shape {:?}!",
649                &last.shape,
650                S::shape()
651            );
652        }
653
654        let device = D::resolve()?;
655
656        device.compile(self.data.read().unwrap().clone())
657    }
658}
659
/// A representation of the compiled graph. The shape is the output shape.
pub enum CompiledGraph<S: Shape, T: DType, D: Dev> {
    /// CPU execution plan.
    Cpu {
        /// Node evaluation order — inferred from the name as indices into
        /// `graph`; confirm against the CPU backend.
        order: Vec<usize>,
        /// The (optimized) graph nodes to execute.
        graph: Vec<GraphNode<T>>,
        /// Carries the `S`/`T`/`D` type parameters without storing data.
        ghost: PhantomData<(S, T, D)>,
    },
    /// CUDA execution plan of pre-built kernels.
    #[cfg(feature = "cuda")]
    Cuda {
        kernels: Vec<crate::cuda_backend::CudaCompiledKernel<T>>,
        /// Carries the `S`/`T`/`D` type parameters without storing data.
        ghost: PhantomData<(S, T, D)>,
    },
}
673
674impl<S: Shape, T: DType, D: Dev> CompiledGraph<S, T, D> {
675    /// Run the precompiled graph. This executes all nodes on the specified backend device and returns a concrete tensor.
676    pub fn run(&self) -> Result<Tensor<S, T, D>> {
677        let device = D::resolve()?;
678        let storage = device.run_graph(self)?;
679        Ok(from_storage(Arc::new(storage)))
680    }
681}
682
/// Elementwise binary operators supported by the graph.
#[derive(PartialEq, Debug, Clone, Copy)]
pub enum BinaryOpType {
    Add,
    Div,
    Sub,
    Mul,
}
690
691impl BinaryOpType {
692    pub fn as_c_op(&self) -> &'static str {
693        match self {
694            Self::Add => "+",
695            Self::Div => "/",
696            Self::Sub => "-",
697            Self::Mul => "*",
698        }
699    }
700
701    pub fn as_closure<T: DType>(&self) -> impl Fn(T, T) -> T {
702        match self {
703            Self::Add => |x, y| x + y,
704            Self::Div => |x, y| x / y,
705            Self::Sub => |x, y| x - y,
706            Self::Mul => |x, y| x * y,
707        }
708    }
709}
710
/// Elementwise unary operators supported by the graph.
#[derive(PartialEq, Debug, Clone)]
pub enum UnaryOpType {
    Neg,
    Sqrt,
}
716
impl UnaryOpType {
    /// Substitute `val` into the C source template for this operator
    /// (used for kernel codegen).
    pub fn fill_in_c_op(&self, val: impl Display) -> String {
        match self {
            Self::Neg => format!("-{val}"),
            Self::Sqrt => format!("static_cast<T>( sqrt( static_cast<double>({val}) ) )"),
        }
    }

    /// A host-side function computing this operator (used for constant folding).
    pub fn to_closure<T: DType>(&self) -> impl Fn(T) -> T {
        match self {
            // NOTE(review): `maybe_neg` presumably handles types without a
            // `Neg` impl (e.g. unsigned) — confirm in the DType definitions.
            Self::Neg => T::maybe_neg,
            Self::Sqrt => |x: T| x.sqrt(),
        }
    }
}
732
/// An operation recorded in the lazy computation graph. Input operands are
/// referenced by [`GraphTensorId`].
#[derive(PartialEq, Debug, Clone)]
pub enum Op<T: DType> {
    /// Fill every element with the constant `v`.
    Fill {
        v: T,
    },
    /// Values starting at `start`, advancing by `step` toward `stop`.
    Arange {
        start: T,
        step: T,
        stop: T,
    },
    /// Elementwise binary operation on two inputs.
    BinaryOp {
        l_id: GraphTensorId,
        r_id: GraphTensorId,
        operator: BinaryOpType,
    },
    /// Elementwise unary operation on one input.
    UnaryOp {
        v_id: GraphTensorId,
        operator: UnaryOpType,
    },
    /// a * b + c
    FusedMulAdd {
        a_id: GraphTensorId,
        b_id: GraphTensorId,
        c_id: GraphTensorId,
    },
    /// (B x M x K) * (B x K x N) = (B x M x N)
    /// out = out * alpha + beta * lhs * rhs
    MatMul {
        l_id: GraphTensorId,
        r_id: GraphTensorId,
        /// Optional accumulator input (`out` in the formula above).
        o_id: Option<GraphTensorId>,
        /// The shared (contraction) dimension K.
        k: usize,
        alpha: T,
        beta: T,
    },
    /// Fill with uniform random values in [0, 1).
    Rand,
    /// Fill with normally distributed random values (mean, std).
    Randn {
        mean: T,
        std: T,
    },
    /// Permutation operator.
    Permute {
        v_id: GraphTensorId,
    },
    /// Placeholder left behind by optimization passes (e.g. FMA fusion);
    /// skipped during visualization and filtered out of the op list.
    NoOp,
}
781
#[derive(Clone, PartialEq, Debug, Eq)]
/// Graph tensor IDs can be cloned cheaply: the wrapped value is the node's
/// index in the op list, stored in a shared `Rc<Cell<_>>` so that every
/// clone observes `set` updates when the graph is re-indexed.
pub enum GraphTensorId {
    /// Operand read normally; the consumer writes to its own buffer.
    OutOfPlace(Rc<Cell<usize>>),
    /// Operand whose buffer the consumer may reuse for its output (set by
    /// the inplace optimization passes).
    InPlace(Rc<Cell<usize>>),
}
788
impl Hash for GraphTensorId {
    /// Hash only the current index value; the in-place/out-of-place variant
    /// is ignored. Equal ids (same variant and value under `Eq`) still hash
    /// equally, so the `Hash`/`Eq` contract holds. Note the interior `Cell`
    /// means mutating an id stored as a map key would corrupt lookups —
    /// hence the `clippy::mutable_key_type` allowances at the call sites.
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        state.write_usize(self.get());
    }
}
794
795impl GraphTensorId {
796    pub fn out_of_place(value: usize) -> Self {
797        Self::OutOfPlace(Rc::new(Cell::new(value)))
798    }
799
800    pub fn inplace(value: usize) -> Self {
801        Self::InPlace(Rc::new(Cell::new(value)))
802    }
803
804    pub fn to_inplace(&self) -> Self {
805        match self {
806            Self::OutOfPlace(x) | Self::InPlace(x) => Self::inplace(x.get()),
807        }
808    }
809
810    pub fn to_inplace_if(&self, predicate: bool) -> Self {
811        match self {
812            Self::OutOfPlace(x) | Self::InPlace(x) if predicate => Self::inplace(x.get()),
813            _ => self.clone(),
814        }
815    }
816
817    pub fn get(&self) -> usize {
818        match self {
819            GraphTensorId::InPlace(x) | GraphTensorId::OutOfPlace(x) => x.get(),
820        }
821    }
822
823    pub fn set(&self, value: usize) {
824        match self {
825            GraphTensorId::InPlace(x) | GraphTensorId::OutOfPlace(x) => x.set(value),
826        }
827    }
828
829    pub fn is_inplace(&self) -> bool {
830        matches!(self, Self::InPlace(_))
831    }
832}