Skip to main content

rlx_runtime/
graph_io.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Static IO / sync profile for compiled graphs (Phase 0 — fusion planning).
17
18use rlx_ir::{Graph, Node, Op, OpKind};
19use serde::{Deserialize, Serialize};
20use std::collections::HashSet;
21
22/// Tuning for static IO analysis.
23#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
24pub struct GraphIoOptions {
25    /// Count each `Op::Fft` as a host-sync boundary (non-native fallback).
26    pub fft_host_sync: bool,
27}
28
29/// Host-visible traffic and sync points for one forward pass.
30#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
31pub struct GraphIoProfile {
32    /// Kernel / thunk dispatches (one per non-view executable node).
33    pub kernel_launches: usize,
34    /// GPU flush + host-side thunk boundaries (Metal LogMel, host FFT fallback, …).
35    pub sync_points: usize,
36    /// Bytes returned to the caller via graph outputs (`CompiledGraph::run`).
37    pub host_output_bytes: u64,
38    /// Bytes moved inside the device arena (read inputs + write outputs per node).
39    pub device_traffic_bytes: u64,
40}
41
42impl GraphIoProfile {
43    pub fn host_readback_bytes(&self, unified_memory: bool) -> u64 {
44        if unified_memory {
45            self.host_output_bytes
46        } else {
47            self.host_output_bytes
48                .saturating_add(self.device_traffic_bytes / 4)
49        }
50    }
51}
52
53/// Ops that force a GPU sync + host thunk on Metal today.
54pub fn metal_host_sync_kinds() -> &'static [OpKind] {
55    &[
56        OpKind::LogMel,
57        OpKind::LogMelBackward,
58        OpKind::Custom,
59        OpKind::WelchPeaks,
60    ]
61}
62
63/// Profile a graph before compile (conservative static estimate).
64pub fn profile_graph_io(graph: &Graph) -> GraphIoProfile {
65    profile_graph_io_with_options(graph, GraphIoOptions::default())
66}
67
68pub fn profile_graph_io_with_options(graph: &Graph, opts: GraphIoOptions) -> GraphIoProfile {
69    let mut profile = GraphIoProfile::default();
70    let output_nodes: HashSet<_> = graph.outputs.iter().copied().collect();
71
72    for node in graph.nodes() {
73        if is_metadata_op(&node.op) {
74            continue;
75        }
76        profile.kernel_launches += 1;
77        profile.device_traffic_bytes += node_io_bytes(node, graph);
78
79        let kind = node.op.kind();
80        if metal_host_sync_kinds().contains(&kind) {
81            profile.sync_points += 1;
82        }
83        if opts.fft_host_sync && kind == OpKind::Fft {
84            profile.sync_points += 1;
85        }
86
87        if output_nodes.contains(&node.id) {
88            profile.host_output_bytes += tensor_bytes(&node.shape);
89        }
90    }
91
92    profile
93}
94
95/// Profile with only selected outputs materialized on the host (peaks-only, logits-only, …).
96pub fn profile_graph_io_outputs(graph: &Graph, output_indices: &[usize]) -> GraphIoProfile {
97    let mut profile = profile_graph_io(graph);
98    profile.host_output_bytes = graph
99        .outputs
100        .iter()
101        .enumerate()
102        .filter(|(i, _)| output_indices.contains(i))
103        .filter_map(|(_, id)| graph.node(*id).shape.num_elements())
104        .map(|n| (n * 4) as u64)
105        .sum();
106    profile
107}
108
109fn is_metadata_op(op: &Op) -> bool {
110    matches!(
111        op,
112        Op::Input { .. }
113            | Op::Param { .. }
114            | Op::Constant { .. }
115            | Op::Reshape { .. }
116            | Op::Transpose { .. }
117            | Op::Narrow { .. }
118    )
119}
120
121fn node_io_bytes(node: &Node, graph: &Graph) -> u64 {
122    let out = tensor_bytes(&node.shape);
123    let inputs: u64 = node
124        .inputs
125        .iter()
126        .map(|&id| tensor_bytes(&graph.node(id).shape))
127        .sum();
128    inputs.saturating_add(out)
129}
130
131fn tensor_bytes(shape: &rlx_ir::Shape) -> u64 {
132    shape
133        .num_elements()
134        .map(|n| (n * shape.dtype().size_bytes()) as u64)
135        .unwrap_or(0)
136}
137
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_ir::infer::GraphExt;
    use rlx_ir::{DType, Shape};

    #[test]
    fn fft_graph_io_profile() {
        let mut g = Graph::new("fft");
        let x = g.input("x", Shape::new(&[8, 512], DType::F32));
        let zeros = g.sub(x, x);
        let block = g.concat_(vec![x, zeros], 1);
        let y = g.fft(block, false);
        g.set_outputs(vec![y]);
        let p = profile_graph_io(&g);
        assert!(p.kernel_launches >= 3);
        assert_eq!(p.host_output_bytes, (8 * 512 * 2 * 4) as u64);
    }

    #[test]
    fn peaks_only_output_smaller_readback() {
        let mut g = Graph::new("peaks");
        let spec = g.input("spec", Shape::new(&[4, 512], DType::F32));
        let peaks = g.welch_peaks(spec, 16, 2);
        g.set_outputs(vec![peaks]);
        let full = profile_graph_io(&g);
        assert_eq!(full.host_output_bytes, (2 * 16 * 2 * 4) as u64);
        assert!(full.device_traffic_bytes > full.host_output_bytes);

        // Selecting outputs only changes host readback; launches, syncs and
        // device traffic must match the full profile.
        let selected = profile_graph_io_outputs(&g, &[0]);
        assert_eq!(selected.kernel_launches, full.kernel_launches);
        assert_eq!(selected.sync_points, full.sync_points);
        assert_eq!(selected.device_traffic_bytes, full.device_traffic_bytes);
        assert!(selected.host_output_bytes <= full.host_output_bytes);

        // Materializing no outputs reads nothing back to the host.
        let none = profile_graph_io_outputs(&g, &[]);
        assert_eq!(none.host_output_bytes, 0);
        assert!(none.host_readback_bytes(true) < full.host_readback_bytes(true));
    }
}
167}