/// Aggregate I/O counters for one graph, as filled in by
/// `profile_graph_io_outputs`.
///
/// `Default` yields an all-zero profile.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct GraphIoProfile {
    /// Number of nodes counted as kernel launches.
    pub kernel_launches: usize,
    /// Number of nodes whose op kind is listed by `host_sync_op_kinds`.
    pub sync_points: usize,
    /// Total bytes of the counted nodes that are graph outputs.
    pub host_output_bytes: u64,
    /// Total input + output bytes over all counted nodes.
    pub device_traffic_bytes: u64,
}
/// Signed difference between two graph I/O profiles (before minus after).
///
/// Positive fields mean the "after" graph needs less of that resource;
/// negative fields mean it needs more.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct FusionBenefit {
    /// Kernel launches removed (negative when launches were added).
    pub launches_saved: isize,
    /// Sync points removed (negative when sync points were added).
    pub sync_points_saved: isize,
    /// Host output bytes removed (negative when more bytes are produced).
    pub host_readback_bytes_saved: i64,
    /// Device traffic bytes removed (negative when traffic grew).
    pub device_traffic_bytes_saved: i64,
}

impl FusionBenefit {
    /// Collapses the benefit into a single score.
    ///
    /// The score is the sum, in this order, of:
    /// * `launches_saved * dispatch_ns`
    /// * `sync_points_saved * roundtrip_ns`
    /// * `device_traffic_bytes_saved / max(memory_bw, 1.0)`
    /// * `host_readback_bytes_saved / readback_divisor`
    ///
    /// `readback_divisor` is `max(host_readback_bw, 1.0)`, except that it is
    /// fixed at `1.0` when `unified_memory` is true.
    ///
    /// Bandwidths are presumably in bytes per nanosecond so that every term
    /// is in nanoseconds; confirm against callers.
    pub fn score_ns(
        &self,
        dispatch_ns: f64,
        roundtrip_ns: f64,
        memory_bw: f64,
        host_readback_bw: f64,
        unified_memory: bool,
    ) -> f64 {
        // NOTE(review): with unified memory every saved readback byte adds
        // exactly 1.0 to the score, rather than making readback free.
        // Confirm this weighting is intended.
        let readback_divisor = if unified_memory {
            1.0
        } else {
            host_readback_bw.max(1.0)
        };

        let launch_gain = self.launches_saved as f64 * dispatch_ns;
        let sync_gain = self.sync_points_saved as f64 * roundtrip_ns;
        let traffic_gain = self.device_traffic_bytes_saved as f64 / memory_bw.max(1.0);
        let readback_gain = self.host_readback_bytes_saved as f64 / readback_divisor;

        launch_gain + sync_gain + traffic_gain + readback_gain
    }

    /// Returns `true` when [`FusionBenefit::score_ns`] for the given
    /// parameters is at least `min_gain_ns`.
    pub fn should_fuse(
        &self,
        dispatch_ns: f64,
        roundtrip_ns: f64,
        memory_bw: f64,
        host_readback_bw: f64,
        unified_memory: bool,
        min_gain_ns: f64,
    ) -> bool {
        let gain = self.score_ns(
            dispatch_ns,
            roundtrip_ns,
            memory_bw,
            host_readback_bw,
            unified_memory,
        );
        gain >= min_gain_ns
    }
}
/// Computes how much `after` saves relative to `before`, field by field.
///
/// Each field of the result is `before - after` for the matching counter, so
/// positive values are savings and negative values are regressions.
///
/// The subtraction is carried out in `i128` and clamped to the target signed
/// range. Casting the unsigned counters directly with `as` would wrap for
/// values above the signed maximum, and the subtraction itself could
/// overflow (a panic in debug builds, a wrong sign in release builds).
pub fn fusion_benefit(before: &GraphIoProfile, after: &GraphIoProfile) -> FusionBenefit {
    // `before - after` for counts, clamped to the `isize` range.
    fn count_delta(before: usize, after: usize) -> isize {
        // `usize` is at most 64 bits wide, so widening to `i128` is lossless.
        (before as i128 - after as i128).clamp(isize::MIN as i128, isize::MAX as i128) as isize
    }

    // `before - after` for byte totals, clamped to the `i64` range.
    fn byte_delta(before: u64, after: u64) -> i64 {
        (i128::from(before) - i128::from(after))
            .clamp(i128::from(i64::MIN), i128::from(i64::MAX)) as i64
    }

    FusionBenefit {
        launches_saved: count_delta(before.kernel_launches, after.kernel_launches),
        sync_points_saved: count_delta(before.sync_points, after.sync_points),
        host_readback_bytes_saved: byte_delta(before.host_output_bytes, after.host_output_bytes),
        device_traffic_bytes_saved: byte_delta(
            before.device_traffic_bytes,
            after.device_traffic_bytes,
        ),
    }
}
/// Cost-model parameters used to accept or reject a fusion based on the
/// change in graph I/O.
///
/// The first five fields are forwarded unchanged to
/// `FusionBenefit::score_ns`; the remaining two are applied by the gate
/// itself.
#[derive(Debug, Clone, Copy)]
pub struct IoFusionGate {
    /// Weight applied to each saved kernel launch.
    pub dispatch_ns: f64,
    /// Weight applied to each saved sync point.
    pub roundtrip_ns: f64,
    /// Divisor for saved device traffic bytes.
    pub memory_bw: f64,
    /// Divisor for saved host readback bytes.
    pub host_readback_bw: f64,
    /// Passed through as the `unified_memory` flag of the score.
    pub unified_memory: bool,
    /// Extra amount subtracted from the score per sync point the fusion adds.
    pub host_thunk_penalty_ns: f64,
    /// Minimum score at which the gate approves a fusion.
    pub min_gain_ns: f64,
}
use rlx_ir::{Graph, Node, Op, OpKind};
use std::collections::HashSet;
/// Returns `true` for ops that `profile_graph_io_outputs` leaves out of the
/// profile entirely: `Input`, `Param`, `Constant`, `Reshape`, `Transpose`
/// and `Narrow`.
fn is_metadata_op(op: &Op) -> bool {
    let skipped_by_profile = matches!(
        op,
        Op::Input { .. }
            | Op::Param { .. }
            | Op::Constant { .. }
            | Op::Reshape { .. }
            | Op::Transpose { .. }
            | Op::Narrow { .. }
    );
    skipped_by_profile
}
/// Size in bytes of a tensor with the given shape: element count times the
/// dtype's element size.
///
/// Returns `0` when `num_elements()` yields `None`.
///
/// Both factors are widened to `u64` before multiplying and the product
/// saturates at `u64::MAX`. Multiplying in the native integer type first
/// could overflow before the widening cast, which matters most on targets
/// with a 32-bit `usize`.
fn tensor_bytes(shape: &rlx_ir::Shape) -> u64 {
    shape
        .num_elements()
        .map(|n| (n as u64).saturating_mul(shape.dtype().size_bytes() as u64))
        .unwrap_or(0)
}
/// Bytes attributed to one node: its own output tensor plus the tensors of
/// all of its inputs, looked up in `graph`.
///
/// An input that appears several times in `node.inputs` is counted each time.
/// The total saturates at `u64::MAX`. Previously only the last addition
/// saturated, while the sum over inputs could still overflow.
fn node_io_bytes(node: &Node, graph: &Graph) -> u64 {
    let output_bytes = tensor_bytes(&node.shape);
    node.inputs
        .iter()
        .map(|&id| tensor_bytes(&graph.node(id).shape))
        .fold(output_bytes, u64::saturating_add)
}
/// Op kinds that `profile_graph_io_outputs` counts as sync points.
pub fn host_sync_op_kinds() -> &'static [OpKind] {
    const HOST_SYNC_KINDS: &[OpKind] = &[
        OpKind::LogMel,
        OpKind::LogMelBackward,
        OpKind::Custom,
        OpKind::WelchPeaks,
    ];
    HOST_SYNC_KINDS
}
pub fn profile_graph_io(graph: &Graph) -> GraphIoProfile {
profile_graph_io_outputs(graph, &graph.outputs)
}
/// Builds a [`GraphIoProfile`] for `graph`, treating the nodes listed in
/// `output_ids` as the graph's outputs.
///
/// Every node for which `is_metadata_op` is false contributes:
/// * one kernel launch,
/// * its `node_io_bytes` to `device_traffic_bytes`,
/// * one sync point if its op kind is in `host_sync_op_kinds`,
/// * its own tensor size to `host_output_bytes` if its id is in `output_ids`.
///
/// Byte totals saturate at `u64::MAX` instead of overflowing, matching the
/// saturating arithmetic in `node_io_bytes`.
pub fn profile_graph_io_outputs(graph: &Graph, output_ids: &[rlx_ir::NodeId]) -> GraphIoProfile {
    let output_nodes: HashSet<_> = output_ids.iter().copied().collect();
    // The list is a static slice, so fetch it once rather than per node.
    let sync_kinds = host_sync_op_kinds();
    let mut profile = GraphIoProfile::default();

    for node in graph.nodes() {
        // NOTE(review): metadata ops are skipped before the output check, so
        // an output that is itself a metadata op (e.g. a `Reshape`) adds
        // nothing to `host_output_bytes`. Confirm this is intended.
        if is_metadata_op(&node.op) {
            continue;
        }

        profile.kernel_launches += 1;
        profile.device_traffic_bytes = profile
            .device_traffic_bytes
            .saturating_add(node_io_bytes(node, graph));

        if sync_kinds.contains(&node.op.kind()) {
            profile.sync_points += 1;
        }

        if output_nodes.contains(&node.id) {
            profile.host_output_bytes = profile
                .host_output_bytes
                .saturating_add(tensor_bytes(&node.shape));
        }
    }

    profile
}
impl IoFusionGate {
    /// Scores `benefit` with this gate's parameters.
    ///
    /// Starts from `FusionBenefit::score_ns` and, when the fusion *adds* sync
    /// points (`sync_points_saved < 0`), subtracts `host_thunk_penalty_ns`
    /// for each added sync point.
    pub fn score_ns(&self, benefit: &FusionBenefit) -> f64 {
        let base = benefit.score_ns(
            self.dispatch_ns,
            self.roundtrip_ns,
            self.memory_bw,
            self.host_readback_bw,
            self.unified_memory,
        );

        if benefit.sync_points_saved < 0 {
            // `unsigned_abs` avoids negating the `isize` directly, which
            // overflows for `isize::MIN` (debug panic; in release the wrapped
            // value would turn the penalty into a bonus).
            let added_sync_points = benefit.sync_points_saved.unsigned_abs() as f64;
            base - added_sync_points * self.host_thunk_penalty_ns
        } else {
            base
        }
    }

    /// Returns `true` when going from `before` to `after` scores at least
    /// `min_gain_ns`.
    pub fn should_fuse(&self, before: &GraphIoProfile, after: &GraphIoProfile) -> bool {
        let benefit = fusion_benefit(before, after);
        self.score_ns(&benefit) >= self.min_gain_ns
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// "Before" profile shared by the two Welch-peaks gate tests.
    fn welch_dense() -> GraphIoProfile {
        GraphIoProfile {
            kernel_launches: 3,
            sync_points: 0,
            host_output_bytes: 33_554_432,
            device_traffic_bytes: 184_549_376,
        }
    }

    /// "After" profile shared by the two Welch-peaks gate tests: one more
    /// launch and sync point, far fewer host output bytes.
    fn welch_fused() -> GraphIoProfile {
        GraphIoProfile {
            kernel_launches: 4,
            sync_points: 1,
            host_output_bytes: 1_048_576,
            device_traffic_bytes: 219_152_384,
        }
    }

    #[test]
    fn peaks_fusion_saves_readback() {
        let before = GraphIoProfile {
            kernel_launches: 2,
            sync_points: 1,
            host_output_bytes: 8192,
            device_traffic_bytes: 16384,
        };
        let after = GraphIoProfile {
            kernel_launches: 1,
            sync_points: 1,
            host_output_bytes: 512,
            device_traffic_bytes: 16384,
        };

        let benefit = fusion_benefit(&before, &after);

        assert!(benefit.host_readback_bytes_saved > 0);
        assert!(benefit.launches_saved >= 0);
    }

    #[test]
    fn io_gate_favors_welch_peaks_fusion_on_metal() {
        let metal_gate = IoFusionGate {
            dispatch_ns: 500.0,
            roundtrip_ns: 5_000.0,
            memory_bw: 200.0,
            host_readback_bw: 200.0,
            unified_memory: true,
            host_thunk_penalty_ns: 2_000_000.0,
            min_gain_ns: 1_000.0,
        };

        assert!(metal_gate.should_fuse(&welch_dense(), &welch_fused()));
    }

    #[test]
    fn io_gate_rejects_welch_peaks_fusion_on_wgpu() {
        let wgpu_gate = IoFusionGate {
            dispatch_ns: 3_000.0,
            roundtrip_ns: 30_000.0,
            memory_bw: 100.0,
            host_readback_bw: 40.0,
            unified_memory: false,
            host_thunk_penalty_ns: 25_000_000.0,
            min_gain_ns: 10_000.0,
        };

        assert!(!wgpu_gate.should_fuse(&welch_dense(), &welch_fused()));
    }
}