Skip to main content

rlx_ir/
provenance.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Cross-stage node provenance — HIR block → MIR node → fusion pass.
17
18use std::fmt;
19
20use crate::hir::HirNodeId;
21use crate::{Graph, NodeId};
22
23/// Where a MIR node came from and how it was produced.
24#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
25#[derive(Debug, Clone, Default, PartialEq, Eq)]
26pub struct NodeOrigin {
27    /// Source HIR block node, when lowered from [`HirModule`].
28    pub hir: Option<HirNodeId>,
29    /// Human label (`layer0.ffn`, `swiglu_ffn`, param name, …).
30    pub label: Option<String>,
31    /// Optimizer pass that last created or fused this node.
32    pub pass: Option<String>,
33}
34
35impl NodeOrigin {
36    pub fn from_hir(hir: HirNodeId, label: Option<String>) -> Self {
37        Self {
38            hir: Some(hir),
39            label,
40            pass: None,
41        }
42    }
43
44    pub fn inherit_from_graph(graph: &Graph, inputs: &[NodeId], pass: &str) -> Self {
45        let mut out = Self::default();
46        for &id in inputs {
47            let node = graph.node(id);
48            if let Some(ref o) = node.origin {
49                if out.hir.is_none() {
50                    out.hir = o.hir;
51                }
52                if out.label.is_none() {
53                    out.label = o.label.clone();
54                }
55            }
56            if out.label.is_none() {
57                out.label = node.name.clone();
58            }
59            if out.hir.is_some() && out.label.is_some() {
60                break;
61            }
62        }
63        out.pass = Some(pass.to_string());
64        out
65    }
66}
67
68impl fmt::Display for NodeOrigin {
69    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
70        let mut parts = Vec::new();
71        if let Some(h) = self.hir {
72            parts.push(format!("hir={h}"));
73        }
74        if let Some(l) = &self.label {
75            parts.push(format!("\"{l}\""));
76        }
77        if let Some(p) = &self.pass {
78            parts.push(format!("pass={p}"));
79        }
80        if parts.is_empty() {
81            write!(f, "—")
82        } else {
83            write!(f, "{}", parts.join(", "))
84        }
85    }
86}
87
88/// Best-effort label for diagnostics (origin label, node name, or id).
89pub fn node_label(graph: &Graph, id: NodeId) -> String {
90    let node = graph.node(id);
91    if let Some(ref o) = node.origin {
92        if let Some(ref l) = o.label {
93            return l.clone();
94        }
95        if let Some(h) = o.hir {
96            return format!("{h}");
97        }
98    }
99    node.name.clone().unwrap_or_else(|| format!("{id}"))
100}
101
102/// Stamp nodes created by a pass (no origin yet) by inheriting from inputs.
103pub fn stamp_pass_origins(graph: &mut Graph, pass: &str) {
104    let ids: Vec<NodeId> = graph.nodes().iter().map(|n| n.id).collect();
105    for id in ids {
106        if graph.node(id).origin.is_some() {
107            continue;
108        }
109        let inputs = graph.node(id).inputs.clone();
110        let origin = NodeOrigin::inherit_from_graph(graph, &inputs, pass);
111        graph.node_mut(id).origin = Some(origin);
112    }
113}