use rlx_ir::{Graph, Node, Op, OpKind};
use serde::{Deserialize, Serialize};
use std::collections::HashSet;
/// Knobs that tweak how [`profile_graph_io_with_options`] accounts for a graph.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct GraphIoOptions {
    /// When `true`, every FFT node adds one extra host sync point to the profile.
    /// Defaults to `false`.
    pub fft_host_sync: bool,
}
/// Aggregated I/O cost estimate for a graph, as filled in by
/// [`profile_graph_io_with_options`] / [`profile_graph_io_outputs`].
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct GraphIoProfile {
    /// Number of non-metadata nodes; each is counted as one kernel launch.
    pub kernel_launches: usize,
    /// Number of points where execution is counted as synchronizing with the host.
    pub sync_points: usize,
    /// Bytes of graph outputs that the host reads back.
    pub host_output_bytes: u64,
    /// Sum over counted nodes of input bytes plus output bytes.
    pub device_traffic_bytes: u64,
}
impl GraphIoProfile {
    /// Estimated number of bytes the host has to read back after execution.
    ///
    /// With `unified_memory` only the graph outputs are counted. Otherwise a
    /// quarter of `device_traffic_bytes` is added on top; the sum saturates at
    /// `u64::MAX`. NOTE(review): the 1/4 factor is an unexplained heuristic —
    /// confirm where it comes from.
    pub fn host_readback_bytes(&self, unified_memory: bool) -> u64 {
        let staging_estimate = if unified_memory {
            0
        } else {
            self.device_traffic_bytes / 4
        };
        self.host_output_bytes.saturating_add(staging_estimate)
    }
}
/// Op kinds that the profiler counts as a host synchronization point.
///
/// NOTE(review): judging by the name, these are ops the Metal backend
/// presumably cannot keep fully on-device — confirm against the backend.
pub fn metal_host_sync_kinds() -> &'static [OpKind] {
    const HOST_SYNC_KINDS: &[OpKind] = &[
        OpKind::LogMel,
        OpKind::LogMelBackward,
        OpKind::Custom,
        OpKind::WelchPeaks,
    ];
    HOST_SYNC_KINDS
}
/// Profiles `graph` using the default [`GraphIoOptions`], i.e. FFT nodes are
/// not treated as extra host sync points.
pub fn profile_graph_io(graph: &Graph) -> GraphIoProfile {
    let defaults = GraphIoOptions::default();
    profile_graph_io_with_options(graph, defaults)
}
/// Walks every node of `graph` and accumulates an I/O cost profile.
///
/// * Metadata-only ops (see `is_metadata_op`) are not counted as kernel
///   launches, add no device traffic and never add sync points.
/// * Every other node counts as one kernel launch and adds its input + output
///   bytes to `device_traffic_bytes`.
/// * A sync point is added for each node whose kind is listed in
///   [`metal_host_sync_kinds`], plus one per FFT node when
///   `opts.fft_host_sync` is set.
/// * Every graph output contributes its byte size to `host_output_bytes`,
///   including outputs produced by a metadata op (e.g. a reshaped result),
///   because the host still has to read that tensor back. A node listed
///   several times in `graph.outputs` is counted once.
///
/// Byte counters saturate at `u64::MAX` instead of overflowing.
pub fn profile_graph_io_with_options(graph: &Graph, opts: GraphIoOptions) -> GraphIoProfile {
    let mut profile = GraphIoProfile::default();
    let output_nodes: HashSet<_> = graph.outputs.iter().copied().collect();
    let host_sync_kinds = metal_host_sync_kinds();
    for node in graph.nodes() {
        // Account output readback before the metadata skip: previously an
        // output produced by e.g. a Reshape node was silently dropped here.
        if output_nodes.contains(&node.id) {
            profile.host_output_bytes = profile
                .host_output_bytes
                .saturating_add(tensor_bytes(&node.shape));
        }
        if is_metadata_op(&node.op) {
            continue;
        }
        profile.kernel_launches += 1;
        profile.device_traffic_bytes = profile
            .device_traffic_bytes
            .saturating_add(node_io_bytes(node, graph));
        let kind = node.op.kind();
        if host_sync_kinds.contains(&kind) {
            profile.sync_points += 1;
        }
        if opts.fft_host_sync && kind == OpKind::Fft {
            profile.sync_points += 1;
        }
    }
    profile
}
pub fn profile_graph_io_outputs(graph: &Graph, output_indices: &[usize]) -> GraphIoProfile {
let mut profile = profile_graph_io(graph);
profile.host_output_bytes = graph
.outputs
.iter()
.enumerate()
.filter(|(i, _)| output_indices.contains(i))
.filter_map(|(_, id)| graph.node(*id).shape.num_elements())
.map(|n| (n * 4) as u64)
.sum();
profile
}
/// Whether `op` is treated as metadata-only by the profiler, i.e. it is not
/// counted as a kernel launch, contributes no device traffic and adds no
/// sync points.
fn is_metadata_op(op: &Op) -> bool {
    // Graph leaves: inputs, parameters and constants.
    let is_source = matches!(op, Op::Input { .. } | Op::Param { .. } | Op::Constant { .. });
    // Shape manipulation ops.
    let is_shape_op = matches!(op, Op::Reshape { .. } | Op::Transpose { .. } | Op::Narrow { .. });
    is_source || is_shape_op
}
/// Estimated device traffic for one node: the byte size of each input operand
/// plus the byte size of the node's own output.
///
/// An operand used more than once (e.g. `x - x`) is counted once per use.
/// The whole sum saturates at `u64::MAX`. Previously only the final add
/// saturated, while the `.sum()` over inputs could panic on overflow in
/// debug builds.
fn node_io_bytes(node: &Node, graph: &Graph) -> u64 {
    node.inputs
        .iter()
        .map(|&id| tensor_bytes(&graph.node(id).shape))
        .fold(tensor_bytes(&node.shape), u64::saturating_add)
}
/// Byte size of a tensor with `shape`: element count times the dtype width.
///
/// Returns 0 when `num_elements()` yields `None` (NOTE(review): presumably a
/// shape with unknown/dynamic dimensions — confirm). The product is computed
/// in `u64` and saturates. Before this change it was computed in `usize`,
/// which could overflow (panic in debug, wrap in release, and much sooner on
/// 32-bit targets).
fn tensor_bytes(shape: &rlx_ir::Shape) -> u64 {
    shape
        .num_elements()
        .map(|n| (n as u64).saturating_mul(shape.dtype().size_bytes() as u64))
        .unwrap_or(0)
}
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_ir::infer::GraphExt;
    use rlx_ir::{DType, Shape};

    /// Builds `fft(concat([x, x - x], axis 1))` over an 8x512 f32 input and
    /// checks launch count and output readback size.
    #[test]
    fn fft_graph_io_profile() {
        let mut graph = Graph::new("fft");
        let signal = graph.input("x", Shape::new(&[8, 512], DType::F32));
        let zeroed = graph.sub(signal, signal);
        let padded = graph.concat_(vec![signal, zeroed], 1);
        let spectrum = graph.fft(padded, false);
        graph.set_outputs(vec![spectrum]);

        let profile = profile_graph_io(&graph);
        let expected_output_bytes = (8 * 512 * 2 * 4) as u64;

        assert!(profile.kernel_launches >= 3);
        assert_eq!(profile.host_output_bytes, expected_output_bytes);
    }

    /// A peaks-only output should be much smaller than the traffic needed
    /// to produce it.
    #[test]
    fn peaks_only_output_smaller_readback() {
        let mut graph = Graph::new("peaks");
        let spectrum = graph.input("spec", Shape::new(&[4, 512], DType::F32));
        let peak_table = graph.welch_peaks(spectrum, 16, 2);
        graph.set_outputs(vec![peak_table]);

        let profile = profile_graph_io(&graph);
        let expected_output_bytes = (2 * 16 * 2 * 4) as u64;

        assert_eq!(profile.host_output_bytes, expected_output_bytes);
        assert!(profile.device_traffic_bytes > profile.host_output_bytes);
    }
}