1use rlx_ir::{Graph, Node, Op, OpKind};
19use serde::{Deserialize, Serialize};
20use std::collections::HashSet;
21
22#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
24pub struct GraphIoOptions {
25 pub fft_host_sync: bool,
27}
28
29#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
31pub struct GraphIoProfile {
32 pub kernel_launches: usize,
34 pub sync_points: usize,
36 pub host_output_bytes: u64,
38 pub device_traffic_bytes: u64,
40}
41
42impl GraphIoProfile {
43 pub fn host_readback_bytes(&self, unified_memory: bool) -> u64 {
44 if unified_memory {
45 self.host_output_bytes
46 } else {
47 self.host_output_bytes
48 .saturating_add(self.device_traffic_bytes / 4)
49 }
50 }
51}
52
53pub fn metal_host_sync_kinds() -> &'static [OpKind] {
55 &[
56 OpKind::LogMel,
57 OpKind::LogMelBackward,
58 OpKind::Custom,
59 OpKind::WelchPeaks,
60 ]
61}
62
63pub fn profile_graph_io(graph: &Graph) -> GraphIoProfile {
65 profile_graph_io_with_options(graph, GraphIoOptions::default())
66}
67
68pub fn profile_graph_io_with_options(graph: &Graph, opts: GraphIoOptions) -> GraphIoProfile {
69 let mut profile = GraphIoProfile::default();
70 let output_nodes: HashSet<_> = graph.outputs.iter().copied().collect();
71
72 for node in graph.nodes() {
73 if is_metadata_op(&node.op) {
74 continue;
75 }
76 profile.kernel_launches += 1;
77 profile.device_traffic_bytes += node_io_bytes(node, graph);
78
79 let kind = node.op.kind();
80 if metal_host_sync_kinds().contains(&kind) {
81 profile.sync_points += 1;
82 }
83 if opts.fft_host_sync && kind == OpKind::Fft {
84 profile.sync_points += 1;
85 }
86
87 if output_nodes.contains(&node.id) {
88 profile.host_output_bytes += tensor_bytes(&node.shape);
89 }
90 }
91
92 profile
93}
94
95pub fn profile_graph_io_outputs(graph: &Graph, output_indices: &[usize]) -> GraphIoProfile {
97 let mut profile = profile_graph_io(graph);
98 profile.host_output_bytes = graph
99 .outputs
100 .iter()
101 .enumerate()
102 .filter(|(i, _)| output_indices.contains(i))
103 .filter_map(|(_, id)| graph.node(*id).shape.num_elements())
104 .map(|n| (n * 4) as u64)
105 .sum();
106 profile
107}
108
109fn is_metadata_op(op: &Op) -> bool {
110 matches!(
111 op,
112 Op::Input { .. }
113 | Op::Param { .. }
114 | Op::Constant { .. }
115 | Op::Reshape { .. }
116 | Op::Transpose { .. }
117 | Op::Narrow { .. }
118 )
119}
120
121fn node_io_bytes(node: &Node, graph: &Graph) -> u64 {
122 let out = tensor_bytes(&node.shape);
123 let inputs: u64 = node
124 .inputs
125 .iter()
126 .map(|&id| tensor_bytes(&graph.node(id).shape))
127 .sum();
128 inputs.saturating_add(out)
129}
130
131fn tensor_bytes(shape: &rlx_ir::Shape) -> u64 {
132 shape
133 .num_elements()
134 .map(|n| (n * shape.dtype().size_bytes()) as u64)
135 .unwrap_or(0)
136}
137
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_ir::infer::GraphExt;
    use rlx_ir::{DType, Shape};

    /// A sub -> concat -> fft chain should register at least three launches,
    /// and the FFT output should be the only thing counted as host output.
    #[test]
    fn fft_graph_io_profile() {
        let mut graph = Graph::new("fft");
        let signal = graph.input("x", Shape::new(&[8, 512], DType::F32));
        let padding = graph.sub(signal, signal);
        let padded = graph.concat_(vec![signal, padding], 1);
        let spectrum = graph.fft(padded, false);
        graph.set_outputs(vec![spectrum]);

        let profile = profile_graph_io(&graph);

        assert!(profile.kernel_launches >= 3);
        // Expected size assumes the FFT output is 8 x 1024 four-byte elements.
        assert_eq!(profile.host_output_bytes, (8 * 512 * 2 * 4) as u64);
    }

    /// When only the peaks tensor is an output, the host output is small
    /// compared to the traffic the node generates on the device.
    #[test]
    fn peaks_only_output_smaller_readback() {
        let mut graph = Graph::new("peaks");
        let spectrum = graph.input("spec", Shape::new(&[4, 512], DType::F32));
        let peak_node = graph.welch_peaks(spectrum, 16, 2);
        graph.set_outputs(vec![peak_node]);

        let profile = profile_graph_io(&graph);

        assert_eq!(profile.host_output_bytes, (2 * 16 * 2 * 4) as u64);
        assert!(profile.device_traffic_bytes > profile.host_output_bytes);
    }
}