use metal::foreign_types::ForeignType;
use crate::device::MlxDevice;
use crate::encoder::{CapturedNode, CapturedOpKind, CommandEncoder, MemRange, RecordedBinding};
use crate::error::Result;
use crate::kernel_registry::KernelRegistry;
use crate::ops;
pub use crate::buffer::MlxBuffer;
pub use crate::dtypes::DType;
/// Classification of a captured dispatch, used to decide whether the
/// reordering pass may hoist it past other dispatches.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum OpKind {
    MatMul,
    MatMulId,
    Norm,
    Rope,
    Elementwise,
    Copy,
    Gather,
    Sdpa,
    Softmax,
    MoeGate,
    Other,
}

impl OpKind {
    /// Whether the reorder pass is allowed to move this op relative to
    /// other dispatches (subject to memory-range conflict checks).
    pub fn is_reorderable(&self) -> bool {
        // Exhaustive match: adding a variant forces an explicit decision here.
        match self {
            Self::MatMul
            | Self::MatMulId
            | Self::Norm
            | Self::Rope
            | Self::Elementwise
            | Self::Copy
            | Self::Gather => true,
            Self::Sdpa | Self::Softmax | Self::MoeGate | Self::Other => false,
        }
    }
}
/// An ordered list of captured GPU commands (dispatches and barriers) that
/// can be optimized (fused, reordered) and then replayed onto a
/// `CommandEncoder`.
pub struct ComputeGraph {
    /// Nodes in recording order; `Barrier` nodes separate dependency groups.
    nodes: Vec<CapturedNode>,
}
impl ComputeGraph {
/// Creates an empty graph, pre-sized for a typical capture.
pub fn new() -> Self {
    Self {
        nodes: Vec::with_capacity(128),
    }
}

/// Wraps an already-captured node list.
pub fn from_nodes(nodes: Vec<CapturedNode>) -> Self {
    Self { nodes }
}

/// Appends a node at the end of the graph.
pub fn record(&mut self, node: CapturedNode) {
    self.nodes.push(node);
}

/// Total number of nodes (dispatches plus barriers).
pub fn len(&self) -> usize {
    self.nodes.len()
}

/// True when no nodes have been recorded.
pub fn is_empty(&self) -> bool {
    self.nodes.is_empty()
}

/// Number of dispatch nodes in the graph.
pub fn dispatch_count(&self) -> usize {
    self.nodes
        .iter()
        .filter(|node| matches!(node, CapturedNode::Dispatch { .. }))
        .count()
}

/// Number of barrier nodes in the graph.
pub fn barrier_count(&self) -> usize {
    self.nodes
        .iter()
        .filter(|node| matches!(node, CapturedNode::Barrier))
        .count()
}

/// Read-only view of the captured nodes.
pub fn nodes(&self) -> &[CapturedNode] {
    self.nodes.as_slice()
}

/// Dispatches missing read or write range annotations; such dispatches make
/// the reordering pass unsafe to run.
pub fn unannotated_dispatch_count(&self) -> usize {
    self.nodes
        .iter()
        .filter(|node| {
            matches!(
                node,
                CapturedNode::Dispatch { reads, writes, .. }
                    if reads.is_empty() || writes.is_empty()
            )
        })
        .count()
}

/// Consumes the graph, yielding the raw node list.
pub fn into_nodes(self) -> Vec<CapturedNode> {
    self.nodes
}
/// Replays every node in capture order, emitting a memory barrier for each
/// recorded `Barrier` node. Returns the number of barriers encoded.
pub fn encode_sequential(&self, encoder: &mut CommandEncoder) -> u32 {
    let mut barrier_count = 0u32;
    for node in &self.nodes {
        match node {
            CapturedNode::Barrier => {
                encoder.memory_barrier();
                barrier_count += 1;
            }
            CapturedNode::Dispatch {
                pipeline,
                bindings,
                threads_per_grid,
                threads_per_threadgroup,
                threadgroup_memory,
                dispatch_kind,
                ..
            } => {
                // Re-issue the dispatch exactly as it was captured.
                encoder.replay_dispatch(
                    pipeline,
                    bindings,
                    threadgroup_memory,
                    *threads_per_grid,
                    *threads_per_threadgroup,
                    *dispatch_kind,
                );
            }
        }
    }
    barrier_count
}
pub fn encode_with_barriers(&self, encoder: &mut CommandEncoder) -> u32 {
let mut tracker = ReorderConflictTracker::new();
let mut barrier_count = 0u32;
for node in &self.nodes {
match node {
CapturedNode::Dispatch {
pipeline,
bindings,
threads_per_grid,
threads_per_threadgroup,
threadgroup_memory,
dispatch_kind,
reads,
writes,
..
} => {
let has_ranges = !reads.is_empty() || !writes.is_empty();
if has_ranges && tracker.conflicts(reads, writes) {
encoder.memory_barrier();
tracker.reset();
barrier_count += 1;
}
if has_ranges {
tracker.add(reads, writes);
}
encoder.replay_dispatch(
pipeline,
bindings,
threadgroup_memory,
*threads_per_grid,
*threads_per_threadgroup,
*dispatch_kind,
);
}
CapturedNode::Barrier => {
encoder.memory_barrier();
tracker.reset();
barrier_count += 1;
}
}
}
barrier_count
}
/// Splits the graph across two command encoders so the GPU can start
/// executing a small prefix while the (larger) remainder is still being
/// encoded on the CPU.
///
/// The prefix receives `max(64, dispatch_total / 10)` dispatches and
/// `encoder0` is committed here; committing `encoder1` is left to the
/// caller. Returns the barrier counts encoded into each half.
pub fn encode_dual_buffer(
    &self,
    encoder0: &mut CommandEncoder,
    encoder1: &mut CommandEncoder,
) -> (u32, u32) {
    let dispatch_total = self.dispatch_count();
    // Prefix size: at least 64 dispatches, or 10% of the graph if larger.
    let n0 = std::cmp::max(64, dispatch_total / 10);
    let split_idx = find_dispatch_split_index(&self.nodes, n0);
    let barriers0 = encode_chunk_with_barriers(&self.nodes[..split_idx], encoder0);
    // Commit the prefix immediately so GPU work overlaps with encoding below.
    encoder0.commit();
    let barriers1 = encode_chunk_with_barriers(&self.nodes[split_idx..], encoder1);
    (barriers0, barriers1)
}
/// Fuses `RmsNorm` → barrier(s) → `ElemMul` sequences into a single
/// `rms_norm_mul_*` dispatch when the norm's output buffer feeds exactly
/// one input of the multiply. Returns the number of fusions performed.
pub fn fuse(
    &mut self,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
) -> Result<u32> {
    let mut result: Vec<CapturedNode> = Vec::with_capacity(self.nodes.len());
    let mut fusions = 0u32;
    let mut i = 0;
    while i < self.nodes.len() {
        // A candidate pattern starts with an RmsNorm dispatch.
        let is_rms_norm = matches!(
            &self.nodes[i],
            CapturedNode::Dispatch { op_kind: CapturedOpKind::RmsNorm, .. }
        );
        if !is_rms_norm {
            result.push(self.nodes[i].clone());
            i += 1;
            continue;
        }
        // Skip over the barrier(s) separating the norm from the next op.
        // (On any fallback below, only nodes[i] is pushed; the barriers are
        // re-visited and copied on subsequent loop iterations.)
        let mut j = i + 1;
        let mut barrier_count = 0usize;
        while j < self.nodes.len() && matches!(&self.nodes[j], CapturedNode::Barrier) {
            barrier_count += 1;
            j += 1;
        }
        if barrier_count == 0 || j >= self.nodes.len() {
            result.push(self.nodes[i].clone());
            i += 1;
            continue;
        }
        // The op after the barrier(s) must be an elementwise multiply.
        let is_elem_mul = matches!(
            &self.nodes[j],
            CapturedNode::Dispatch { op_kind: CapturedOpKind::ElemMul, .. }
        );
        if !is_elem_mul {
            result.push(self.nodes[i].clone());
            i += 1;
            continue;
        }
        // The fused kernel reuses the norm's launch configuration unchanged;
        // the multiply's grid sizes are intentionally ignored.
        let (norm_pipeline, norm_bindings, norm_tpg, norm_tptg, norm_tgmem, norm_dk) =
            match &self.nodes[i] {
                CapturedNode::Dispatch {
                    pipeline,
                    bindings,
                    threads_per_grid,
                    threads_per_threadgroup,
                    threadgroup_memory,
                    dispatch_kind,
                    ..
                } => (pipeline, bindings, threads_per_grid, threads_per_threadgroup, threadgroup_memory, dispatch_kind),
                _ => unreachable!(),
            };
        let (mul_bindings, _mul_tpg, _mul_tptg) = match &self.nodes[j] {
            CapturedNode::Dispatch {
                bindings,
                threads_per_grid,
                threads_per_threadgroup,
                ..
            } => (bindings, threads_per_grid, threads_per_threadgroup),
            _ => unreachable!(),
        };
        // Binding conventions assumed here: norm output at slot 2, multiply
        // inputs at slots 0/1, multiply output at slot 2, norm params at
        // slot 3. The norm output must match one multiply input by pointer.
        let norm_output_ptr = Self::buffer_ptr_for_slot(norm_bindings, 2);
        let mul_a_ptr = Self::buffer_ptr_for_slot(mul_bindings, 0);
        let mul_b_ptr = Self::buffer_ptr_for_slot(mul_bindings, 1);
        if norm_output_ptr.is_none() || (norm_output_ptr != mul_a_ptr && norm_output_ptr != mul_b_ptr) {
            result.push(self.nodes[i].clone());
            i += 1;
            continue;
        }
        // The multiply input that is NOT the norm output becomes the scale.
        let scale_slot = if norm_output_ptr == mul_a_ptr { 1 } else { 0 };
        let (norm_input, norm_weight, scale, mul_output, norm_params) = match (
            Self::get_binding(norm_bindings, 0),
            Self::get_binding(norm_bindings, 1),
            Self::get_binding(mul_bindings, scale_slot),
            Self::get_binding(mul_bindings, 2),
            Self::get_binding(norm_bindings, 3),
        ) {
            (Some(a), Some(b), Some(c), Some(d), Some(e)) => (a, b, c, d, e),
            _ => {
                result.push(self.nodes[i].clone());
                i += 1;
                continue;
            }
        };
        // Only known rms_norm pipelines have a fused counterpart.
        let fused_name = match Self::fused_pipeline_name(norm_pipeline) {
            Some(name) => name,
            None => {
                result.push(self.nodes[i].clone());
                i += 1;
                continue;
            }
        };
        let fused_pipeline = registry.get_pipeline(fused_name, device)?;
        let fused_bindings = vec![
            (0, norm_input),
            (1, norm_weight),
            (2, scale),
            (3, mul_output),
            (4, norm_params),
        ];
        // Conservative range union: the fused kernel is annotated with both
        // ops' reads and only the multiply's writes (the intermediate norm
        // output buffer is no longer written).
        let (fused_reads, fused_writes) = match (&self.nodes[i], &self.nodes[j]) {
            (
                CapturedNode::Dispatch { reads: nr, writes: _nw, .. },
                CapturedNode::Dispatch { reads: mr, writes: mw, .. },
            ) => {
                let mut reads = nr.clone();
                reads.extend_from_slice(mr);
                (reads, mw.clone())
            }
            _ => (Vec::new(), Vec::new()),
        };
        result.push(CapturedNode::Dispatch {
            pipeline: fused_pipeline.to_owned(),
            bindings: fused_bindings,
            threads_per_grid: *norm_tpg,
            threads_per_threadgroup: *norm_tptg,
            threadgroup_memory: norm_tgmem.clone(),
            dispatch_kind: *norm_dk,
            op_kind: CapturedOpKind::Other,
            reads: fused_reads,
            writes: fused_writes,
        });
        fusions += 1;
        // Skip past the multiply; the barriers in between are dropped since
        // the intermediate buffer no longer needs fencing.
        i = j + 1;
    }
    self.nodes = result;
    Ok(fusions)
}
pub fn reorder(&mut self) -> u32 {
self.nodes.retain(|n| !matches!(n, CapturedNode::Barrier));
let n = self.nodes.len();
if n == 0 {
return 0;
}
let mut result: Vec<usize> = Vec::with_capacity(n);
let mut used = vec![false; n];
let mut mrs0 = ReorderConflictTracker::new();
let mut mrs1 = ReorderConflictTracker::new();
const N_FORWARD: usize = 64;
for i0 in 0..n {
if used[i0] {
continue;
}
let node0 = &self.nodes[i0];
let (reads0, writes0, op_kind0) = match node0 {
CapturedNode::Dispatch { reads, writes, op_kind, .. } => {
(reads.as_slice(), writes.as_slice(), *op_kind)
}
CapturedNode::Barrier => continue, };
let has_ranges = !reads0.is_empty() || !writes0.is_empty();
if has_ranges && mrs0.conflicts(reads0, writes0) {
mrs1.reset();
mrs1.add(reads0, writes0);
let end = (i0 + N_FORWARD).min(n);
for i1 in (i0 + 1)..end {
if used[i1] {
continue;
}
let node1 = &self.nodes[i1];
let (reads1, writes1, op_kind1) = match node1 {
CapturedNode::Dispatch { reads, writes, op_kind, .. } => {
(reads.as_slice(), writes.as_slice(), *op_kind)
}
CapturedNode::Barrier => continue,
};
if !op_kind1.is_reorderable() {
break;
}
let is_empty1 = reads1.is_empty() && writes1.is_empty();
if (is_empty1 || !mrs0.conflicts(reads1, writes1))
&& !mrs1.conflicts(reads1, writes1)
{
mrs0.add(reads1, writes1);
result.push(i1);
used[i1] = true;
} else {
mrs1.add(reads1, writes1);
}
}
mrs0.reset();
}
let _ = op_kind0; mrs0.add(reads0, writes0);
result.push(i0);
}
let mut reordered_count = 0u32;
for (pos, &orig_idx) in result.iter().enumerate() {
if orig_idx != pos {
reordered_count += 1;
}
}
let old_nodes = std::mem::take(&mut self.nodes);
self.nodes = result.iter().map(|&idx| old_nodes[idx].clone()).collect();
if std::env::var("HF2Q_REORDER_DUMP").is_ok() {
eprintln!(
" [REORDER] nodes={} reordered={} ({:.1}%)",
n,
reordered_count,
100.0 * reordered_count as f64 / n as f64,
);
}
reordered_count
}
/// Raw Metal buffer pointer bound at `slot`, used for identity comparisons
/// between bindings of different dispatches. Returns `None` when no buffer
/// binding occupies that slot.
fn buffer_ptr_for_slot(bindings: &[(u64, RecordedBinding)], slot: u64) -> Option<*const std::ffi::c_void> {
    bindings.iter().find_map(|(idx, binding)| match binding {
        // Non-buffer entries at the slot are skipped, matching a linear scan.
        RecordedBinding::Buffer { metal_buffer, .. } if *idx == slot => {
            Some(metal_buffer.as_ptr() as *const std::ffi::c_void)
        }
        _ => None,
    })
}
/// Clone of the first binding recorded at `slot`, if any.
fn get_binding(bindings: &[(u64, RecordedBinding)], slot: u64) -> Option<RecordedBinding> {
    bindings
        .iter()
        .find(|(idx, _)| *idx == slot)
        .map(|(_, binding)| binding.clone())
}
/// Maps an rms_norm pipeline (by label) to its fused norm+mul counterpart;
/// `None` when no fused kernel exists for this pipeline.
fn fused_pipeline_name(pipeline: &metal::ComputePipelineState) -> Option<&'static str> {
    let fused = match pipeline.label() {
        "rms_norm_f32" => "rms_norm_mul_f32",
        "rms_norm_f16" => "rms_norm_mul_f16",
        "rms_norm_bf16" => "rms_norm_mul_bf16",
        _ => return None,
    };
    Some(fused)
}
}
impl Default for ComputeGraph {
fn default() -> Self {
Self::new()
}
}
/// Index just past the `n0`-th dispatch node, or `nodes.len()` when fewer
/// than `n0` dispatches exist (or `n0` is zero).
fn find_dispatch_split_index(nodes: &[CapturedNode], n0: usize) -> usize {
    // n0 == 0 never matched in the original counting loop either.
    if n0 == 0 {
        return nodes.len();
    }
    nodes
        .iter()
        .enumerate()
        .filter(|(_, node)| matches!(node, CapturedNode::Dispatch { .. }))
        .nth(n0 - 1)
        .map(|(i, _)| i + 1)
        .unwrap_or(nodes.len())
}
/// Replays `nodes` into `encoder`, inserting a barrier only when a
/// dispatch's annotated read/write ranges conflict with ranges encoded
/// since the last barrier. Recorded `Barrier` nodes are always honored.
/// Returns the number of barriers encoded.
fn encode_chunk_with_barriers(nodes: &[CapturedNode], encoder: &mut CommandEncoder) -> u32 {
    let mut tracker = ReorderConflictTracker::new();
    let mut barrier_count = 0u32;
    for node in nodes {
        match node {
            CapturedNode::Dispatch {
                pipeline,
                bindings,
                threads_per_grid,
                threads_per_threadgroup,
                threadgroup_memory,
                dispatch_kind,
                reads,
                writes,
                ..
            } => {
                // Unannotated dispatches are never checked; they rely on
                // the recorded Barrier nodes instead.
                let has_ranges = !reads.is_empty() || !writes.is_empty();
                if has_ranges && tracker.conflicts(reads, writes) {
                    // Hazard against the current barrier-free group: fence.
                    encoder.memory_barrier();
                    tracker.reset();
                    barrier_count += 1;
                }
                if has_ranges {
                    tracker.add(reads, writes);
                }
                encoder.replay_dispatch(
                    pipeline,
                    bindings,
                    threadgroup_memory,
                    *threads_per_grid,
                    *threads_per_threadgroup,
                    *dispatch_kind,
                );
            }
            CapturedNode::Barrier => {
                encoder.memory_barrier();
                tracker.reset();
                barrier_count += 1;
            }
        }
    }
    barrier_count
}
/// Tracks the memory ranges touched since the last barrier and reports
/// hazards (RAW / WAR / WAW) against them, operating on pre-computed
/// `MemRange` annotations.
struct ReorderConflictTracker {
    /// `(start, end, is_write)` ranges; `end` is exclusive.
    ranges: Vec<(usize, usize, bool)>,
}

impl ReorderConflictTracker {
    fn new() -> Self {
        Self {
            ranges: Vec::with_capacity(64),
        }
    }

    /// Forgets all tracked ranges (call after emitting a barrier).
    fn reset(&mut self) {
        self.ranges.clear();
    }

    /// True when any read overlaps a tracked write, or any write overlaps
    /// any tracked range.
    fn conflicts(&self, reads: &[MemRange], writes: &[MemRange]) -> bool {
        let raw_hazard = reads.iter().any(|&(r_start, r_end)| {
            self.ranges
                .iter()
                .any(|&(s, e, is_write)| is_write && r_start < e && r_end > s)
        });
        if raw_hazard {
            return true;
        }
        writes.iter().any(|&(w_start, w_end)| {
            self.ranges
                .iter()
                .any(|&(s, e, _)| w_start < e && w_end > s)
        })
    }

    /// Records the given ranges as touched by the current group.
    fn add(&mut self, reads: &[MemRange], writes: &[MemRange]) {
        self.ranges
            .extend(reads.iter().map(|&(start, end)| (start, end, false)));
        self.ranges
            .extend(writes.iter().map(|&(start, end)| (start, end, true)));
    }
}
/// Thin factory that opens `GraphSession`s against a shared device.
pub struct GraphExecutor {
    device: MlxDevice,
}
impl GraphExecutor {
    /// Wraps a device for repeated session creation.
    pub fn new(device: MlxDevice) -> Self {
        Self { device }
    }

    /// Shared constructor for `begin` / `begin_recorded`; the two previously
    /// duplicated this initialization verbatim apart from `start_capture`.
    fn make_session(&self, recording: bool) -> Result<GraphSession<'_>> {
        let mut encoder = self.device.command_encoder()?;
        if recording {
            // Capture nodes so the session can fuse/reorder before replay.
            encoder.start_capture();
        }
        Ok(GraphSession {
            encoder,
            device: &self.device,
            barrier_count: 0,
            tracker: ConflictTracker::new(),
            dispatch_in_group: 0,
            total_dispatches: 0,
            group_sizes: [0; 8],
            recording,
        })
    }

    /// Opens a session that encodes dispatches immediately (no capture).
    pub fn begin(&self) -> Result<GraphSession<'_>> {
        self.make_session(false)
    }

    /// Opens a session that records dispatches for later graph optimization.
    pub fn begin_recorded(&self) -> Result<GraphSession<'_>> {
        self.make_session(true)
    }

    /// The underlying device.
    pub fn device(&self) -> &MlxDevice {
        &self.device
    }
}
/// Tracks buffer address ranges touched since the last barrier, for hazard
/// detection on live (non-captured) dispatch streams.
pub struct ConflictTracker {
    /// `(start, end, is_write)` byte ranges derived from buffer pointers.
    ranges: Vec<(usize, usize, bool)>,
}

impl ConflictTracker {
    fn new() -> Self {
        Self {
            ranges: Vec::with_capacity(32),
        }
    }

    /// Clears all tracked ranges (call after a barrier).
    fn reset(&mut self) {
        self.ranges.clear();
    }

    /// When `reads`/`writes` conflict with a tracked range, returns the
    /// hazard kind ("RAW" / "WAR" / "WAW") plus the new range's start and
    /// the conflicting tracked range's start; `None` when hazard-free.
    fn conflicts_reason(&self, reads: &[&MlxBuffer], writes: &[&MlxBuffer])
        -> Option<(&'static str, usize, usize)>
    {
        for buf in reads {
            let start = buf.contents_ptr() as usize;
            let end = start + buf.byte_len();
            for &(s, e, is_write) in &self.ranges {
                // A read only hazards against an earlier write.
                if is_write && start < e && end > s {
                    return Some(("RAW", start, s));
                }
            }
        }
        for buf in writes {
            let start = buf.contents_ptr() as usize;
            let end = start + buf.byte_len();
            for &(s, e, is_write) in &self.ranges {
                if start < e && end > s {
                    let kind = if is_write { "WAW" } else { "WAR" };
                    return Some((kind, start, s));
                }
            }
        }
        None
    }

    /// Records the buffers' address ranges as touched.
    fn add(&mut self, reads: &[&MlxBuffer], writes: &[&MlxBuffer]) {
        let span = |b: &MlxBuffer| {
            let start = b.contents_ptr() as usize;
            (start, start + b.byte_len())
        };
        for r in reads {
            let (start, end) = span(r);
            self.ranges.push((start, end, false));
        }
        for w in writes {
            let (start, end) = span(w);
            self.ranges.push((start, end, true));
        }
    }
}
/// A single encoding (or recording) session against one command encoder.
pub struct GraphSession<'a> {
    encoder: CommandEncoder,
    device: &'a MlxDevice,
    /// Barriers emitted so far in this session.
    barrier_count: u32,
    /// Hazard tracker driving `barrier_between`.
    tracker: ConflictTracker,
    /// Dispatches since the last barrier (current group size).
    dispatch_in_group: u32,
    total_dispatches: u32,
    /// Histogram of closed group sizes; sizes >= 8 land in the last bucket.
    group_sizes: [u32; 8],
    /// When true, the encoder captures nodes for later graph optimization.
    recording: bool,
}
impl<'a> GraphSession<'a> {
/// Encodes an RMS-norm over `rows` rows of width `dim`; delegates to
/// `ops::rms_norm::dispatch_rms_norm`.
pub fn rms_norm(
    &mut self,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    input: &MlxBuffer,
    weight: &MlxBuffer,
    output: &MlxBuffer,
    params_buf: &MlxBuffer,
    rows: u32,
    dim: u32,
) -> Result<()> {
    ops::rms_norm::dispatch_rms_norm(
        &mut self.encoder,
        registry,
        device,
        input,
        weight,
        output,
        params_buf,
        rows,
        dim,
    )
}

/// Encodes a quantized matmul; the op allocates and returns the output
/// buffer. Delegates to `ops::quantized_matmul::quantized_matmul`.
pub fn quantized_matmul(
    &mut self,
    registry: &mut KernelRegistry,
    device: &MlxDevice,
    input: &MlxBuffer,
    weight: &MlxBuffer,
    scales: &MlxBuffer,
    biases: &MlxBuffer,
    params: &ops::quantized_matmul::QuantizedMatmulParams,
) -> Result<MlxBuffer> {
    ops::quantized_matmul::quantized_matmul(
        &mut self.encoder,
        registry,
        device,
        input,
        weight,
        scales,
        biases,
        params,
    )
}

/// SIMD-group variant of [`Self::quantized_matmul`]; delegates to
/// `ops::quantized_matmul::quantized_matmul_simd`.
pub fn quantized_matmul_simd(
    &mut self,
    registry: &mut KernelRegistry,
    device: &MlxDevice,
    input: &MlxBuffer,
    weight: &MlxBuffer,
    scales: &MlxBuffer,
    biases: &MlxBuffer,
    params: &ops::quantized_matmul::QuantizedMatmulParams,
) -> Result<MlxBuffer> {
    ops::quantized_matmul::quantized_matmul_simd(
        &mut self.encoder,
        registry,
        device,
        input,
        weight,
        scales,
        biases,
        params,
    )
}

/// GGML-format quantized matmul writing into a caller-provided `output`;
/// delegates to `ops::quantized_matmul_ggml::quantized_matmul_ggml`.
pub fn quantized_matmul_ggml(
    &mut self,
    registry: &mut KernelRegistry,
    device: &MlxDevice,
    input: &MlxBuffer,
    weight: &MlxBuffer,
    output: &mut MlxBuffer,
    params: &ops::quantized_matmul_ggml::GgmlQuantizedMatmulParams,
) -> Result<()> {
    ops::quantized_matmul_ggml::quantized_matmul_ggml(
        &mut self.encoder,
        registry,
        device,
        input,
        weight,
        output,
        params,
    )
}

/// GGML quantized matmul with expert `ids` indirection (MoE); delegates to
/// `ops::quantized_matmul_id_ggml::quantized_matmul_id_ggml`.
#[allow(clippy::too_many_arguments)]
pub fn quantized_matmul_id_ggml(
    &mut self,
    registry: &mut KernelRegistry,
    device: &MlxDevice,
    input: &MlxBuffer,
    weight: &MlxBuffer,
    ids: &MlxBuffer,
    output: &mut MlxBuffer,
    params: &ops::quantized_matmul_id_ggml::GgmlQuantizedMatmulIdParams,
) -> Result<()> {
    ops::quantized_matmul_id_ggml::quantized_matmul_id_ggml(
        &mut self.encoder,
        registry,
        device,
        input,
        weight,
        ids,
        output,
        params,
    )
}

/// Scaled dot-product attention over q/k/v; delegates to `ops::sdpa::sdpa`.
pub fn sdpa(
    &mut self,
    registry: &mut KernelRegistry,
    device: &MlxDevice,
    q: &MlxBuffer,
    k: &MlxBuffer,
    v: &MlxBuffer,
    output: &MlxBuffer,
    params: &ops::sdpa::SdpaParams,
    batch_size: u32,
) -> Result<()> {
    ops::sdpa::sdpa(
        &mut self.encoder,
        registry,
        device,
        q,
        k,
        v,
        output,
        params,
        batch_size,
    )
}

/// Flash-attention (vector variant) using `tmp` as scratch; delegates to
/// `ops::flash_attn_vec::flash_attn_vec`.
pub fn flash_attn_vec(
    &mut self,
    registry: &mut KernelRegistry,
    device: &MlxDevice,
    q: &MlxBuffer,
    k: &MlxBuffer,
    v: &MlxBuffer,
    output: &MlxBuffer,
    tmp: &MlxBuffer,
    params: &ops::flash_attn_vec::FlashAttnVecParams,
) -> Result<()> {
    ops::flash_attn_vec::flash_attn_vec(
        &mut self.encoder,
        registry,
        device,
        q,
        k,
        v,
        output,
        tmp,
        params,
    )
}
/// Elementwise `output = a + b` over `n_elements`; delegates to
/// `ops::elementwise::elementwise_add`.
pub fn elementwise_add(
    &mut self,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    a: &MlxBuffer,
    b: &MlxBuffer,
    output: &MlxBuffer,
    n_elements: usize,
    dtype: DType,
) -> Result<()> {
    ops::elementwise::elementwise_add(
        &mut self.encoder,
        registry,
        device,
        a,
        b,
        output,
        n_elements,
        dtype,
    )
}

/// Elementwise `output = a * b` over `n_elements`; delegates to
/// `ops::elementwise::elementwise_mul`.
pub fn elementwise_mul(
    &mut self,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    a: &MlxBuffer,
    b: &MlxBuffer,
    output: &MlxBuffer,
    n_elements: usize,
    dtype: DType,
) -> Result<()> {
    ops::elementwise::elementwise_mul(
        &mut self.encoder,
        registry,
        device,
        a,
        b,
        output,
        n_elements,
        dtype,
    )
}

/// Rotary position embedding; delegates to `ops::rope::dispatch_rope`.
pub fn rope(
    &mut self,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    input: &MlxBuffer,
    output: &MlxBuffer,
    params_buf: &MlxBuffer,
    positions_buf: &MlxBuffer,
    seq_len: u32,
    head_dim: u32,
) -> Result<()> {
    ops::rope::dispatch_rope(
        &mut self.encoder,
        registry,
        device,
        input,
        output,
        params_buf,
        positions_buf,
        seq_len,
        head_dim,
    )
}

/// GELU activation; delegates to `ops::gelu::dispatch_gelu`.
pub fn gelu(
    &mut self,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    input: &MlxBuffer,
    output: &MlxBuffer,
) -> Result<()> {
    ops::gelu::dispatch_gelu(
        &mut self.encoder,
        registry,
        device,
        input,
        output,
    )
}

/// Row-wise softmax over `rows` x `cols`; delegates to
/// `ops::softmax::dispatch_softmax`.
pub fn softmax(
    &mut self,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    input: &MlxBuffer,
    output: &MlxBuffer,
    params_buf: &MlxBuffer,
    rows: u32,
    cols: u32,
) -> Result<()> {
    ops::softmax::dispatch_softmax(
        &mut self.encoder,
        registry,
        device,
        input,
        output,
        params_buf,
        rows,
        cols,
    )
}

/// Logit soft-capping with cap value `cap`; delegates to
/// `ops::softcap::dispatch_softcap`.
pub fn softcap(
    &mut self,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    input: &MlxBuffer,
    output: &MlxBuffer,
    params_buf: &MlxBuffer,
    cap: f32,
) -> Result<()> {
    ops::softcap::dispatch_softcap(
        &mut self.encoder,
        registry,
        device,
        input,
        output,
        params_buf,
        cap,
    )
}

/// RMS-norm without a weight (scale) buffer, f32 only; delegates to
/// `ops::rms_norm::dispatch_rms_norm_no_scale_f32`.
pub fn rms_norm_no_scale_f32(
    &mut self,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    input: &MlxBuffer,
    output: &MlxBuffer,
    params_buf: &MlxBuffer,
    rows: u32,
    dim: u32,
) -> Result<()> {
    ops::rms_norm::dispatch_rms_norm_no_scale_f32(
        &mut self.encoder,
        registry,
        device,
        input,
        output,
        params_buf,
        rows,
        dim,
    )
}

/// NeoX-style RoPE (f32) with optional `freq_factors`; delegates to
/// `ops::rope::dispatch_rope_neox_f32`.
#[allow(clippy::too_many_arguments)]
pub fn rope_neox_f32(
    &mut self,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    input: &MlxBuffer,
    output: &MlxBuffer,
    params_buf: &MlxBuffer,
    positions_buf: &MlxBuffer,
    freq_factors: Option<&MlxBuffer>,
    seq_len: u32,
    n_heads: u32,
    head_dim: u32,
    rope_dim: u32,
) -> Result<()> {
    ops::rope::dispatch_rope_neox_f32(
        &mut self.encoder,
        registry,
        device,
        input,
        output,
        params_buf,
        positions_buf,
        freq_factors,
        seq_len,
        n_heads,
        head_dim,
        rope_dim,
    )
}
/// Unconditionally emits a memory barrier, closing the current dispatch
/// group and updating the group-size histogram.
#[inline]
pub fn barrier(&mut self) {
    if self.dispatch_in_group > 0 {
        // Record the size of the group this barrier closes.
        let bucket = (self.dispatch_in_group as usize).min(self.group_sizes.len()) - 1;
        self.group_sizes[bucket] += 1;
    }
    self.encoder.memory_barrier();
    self.barrier_count += 1;
    self.tracker.reset();
    self.dispatch_in_group = 0;
}
/// Registers a dispatch's buffer usage and emits a barrier only when it
/// conflicts (RAW/WAR/WAW) with dispatches since the last barrier.
///
/// Call immediately BEFORE encoding the dispatch performing these
/// reads/writes; in recording mode the byte ranges are also attached to the
/// upcoming captured dispatch so graph optimization can run.
#[inline]
pub fn barrier_between(&mut self, reads: &[&MlxBuffer], writes: &[&MlxBuffer]) {
    if self.recording {
        // Annotate the pending capture first, regardless of conflicts.
        let read_ranges: Vec<MemRange> = reads
            .iter()
            .map(|b| {
                let start = b.contents_ptr() as usize;
                (start, start + b.byte_len())
            })
            .collect();
        let write_ranges: Vec<MemRange> = writes
            .iter()
            .map(|b| {
                let start = b.contents_ptr() as usize;
                (start, start + b.byte_len())
            })
            .collect();
        self.encoder.set_pending_buffer_ranges(read_ranges, write_ranges);
    }
    let reason = self.tracker.conflicts_reason(reads, writes);
    if let Some((_kind, _new_ptr, _existing_ptr)) = reason {
        // Hazard detected: close the current group with a barrier.
        if self.dispatch_in_group > 0 {
            let idx = (self.dispatch_in_group as usize).min(self.group_sizes.len()) - 1;
            self.group_sizes[idx] += 1;
        }
        self.encoder.memory_barrier();
        self.tracker.reset();
        self.barrier_count += 1;
        self.dispatch_in_group = 0;
    }
    // The dispatch about to be encoded joins the (possibly fresh) group.
    self.dispatch_in_group += 1;
    self.total_dispatches += 1;
    self.tracker.add(reads, writes);
}
/// Prints the dispatch-group size histogram to stderr (diagnostics only).
pub fn dump_group_stats(&self) {
    let mut histogram = self.group_sizes;
    // Account for the still-open group, if any.
    if self.dispatch_in_group > 0 {
        let bucket = (self.dispatch_in_group as usize).min(histogram.len()) - 1;
        histogram[bucket] += 1;
    }
    let total_groups: u32 = histogram.iter().sum();
    let ratio = if total_groups > 0 {
        self.total_dispatches as f64 / total_groups as f64
    } else {
        0.0
    };
    eprintln!(" [GROUP_STATS] dispatches={} barriers={} groups={} ratio={:.2}",
        self.total_dispatches, self.barrier_count, total_groups, ratio);
    for (size_minus_one, &count) in histogram.iter().enumerate() {
        if count > 0 {
            eprintln!(" size {}: {} groups", size_minus_one + 1, count);
        }
    }
}
/// Registers buffer usage with the hazard tracker WITHOUT emitting a
/// barrier or checking for conflicts.
#[inline]
pub fn track_dispatch(&mut self, reads: &[&MlxBuffer], writes: &[&MlxBuffer]) {
    self.tracker.add(reads, writes);
}
/// Barriers emitted so far in this session.
#[inline]
pub fn barrier_count(&self) -> u32 {
    self.barrier_count
}
/// Tracker bookkeeping overhead; not currently measured, always 0.
pub fn tracker_overhead_ns(&self) -> u64 {
    0
}
/// Direct mutable access to the underlying encoder.
pub fn encoder_mut(&mut self) -> &mut CommandEncoder {
    &mut self.encoder
}
/// The device this session encodes on.
pub fn device(&self) -> &MlxDevice {
    self.device
}
/// True when dispatches are being captured for graph optimization.
pub fn is_recording(&self) -> bool {
    self.recording
}
/// Flushes the session: replays any captured graph sequentially, then
/// commits and blocks until the GPU finishes.
pub fn finish(mut self) -> Result<()> {
    let captured = if self.recording {
        self.encoder.take_capture()
    } else {
        None
    };
    if let Some(nodes) = captured {
        ComputeGraph::from_nodes(nodes).encode_sequential(&mut self.encoder);
    }
    self.encoder.commit_and_wait()
}
/// Replays any captured graph, commits the encoder WITHOUT waiting, and
/// hands the encoder back to the caller.
pub fn commit(mut self) -> CommandEncoder {
    let captured = if self.recording {
        self.encoder.take_capture()
    } else {
        None
    };
    if let Some(nodes) = captured {
        ComputeGraph::from_nodes(nodes).encode_sequential(&mut self.encoder);
    }
    self.encoder.commit();
    self.encoder
}
/// Like [`Self::finish`], returning `(encoding_ns, gpu_wait_ns)`:
/// `session_begin` → commit, and commit → GPU completion.
pub fn finish_with_timing(mut self, session_begin: std::time::Instant) -> Result<(u64, u64)> {
    if self.recording {
        if let Some(nodes) = self.encoder.take_capture() {
            let graph = ComputeGraph::from_nodes(nodes);
            graph.encode_sequential(&mut self.encoder);
        }
    }
    let commit_start = std::time::Instant::now();
    let encoding_ns = commit_start.duration_since(session_begin).as_nanos() as u64;
    self.encoder.commit();
    self.encoder.wait_until_completed()?;
    let gpu_wait_ns = commit_start.elapsed().as_nanos() as u64;
    Ok((encoding_ns, gpu_wait_ns))
}
/// Runs the fusion pass over the captured graph before replaying and
/// waiting. Returns the number of fusions performed (0 when not recording).
pub fn finish_with_fusion(
    mut self,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
) -> Result<u32> {
    let captured = if self.recording {
        self.encoder.take_capture()
    } else {
        None
    };
    let fusions = match captured {
        Some(nodes) => {
            let mut graph = ComputeGraph::from_nodes(nodes);
            let count = graph.fuse(registry, device)?;
            graph.encode_sequential(&mut self.encoder);
            count
        }
        None => 0,
    };
    self.encoder.commit_and_wait()?;
    Ok(fusions)
}
/// [`Self::finish_with_fusion`] plus `(encoding_ns, gpu_wait_ns)` timing
/// measured from `session_begin`.
pub fn finish_with_fusion_and_timing(
    mut self,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    session_begin: std::time::Instant,
) -> Result<(u64, u64, u32)> {
    let mut fusions = 0;
    if self.recording {
        if let Some(nodes) = self.encoder.take_capture() {
            let mut graph = ComputeGraph::from_nodes(nodes);
            fusions = graph.fuse(registry, device)?;
            graph.encode_sequential(&mut self.encoder);
        }
    }
    let commit_start = std::time::Instant::now();
    let encoding_ns = commit_start.duration_since(session_begin).as_nanos() as u64;
    self.encoder.commit();
    self.encoder.wait_until_completed()?;
    let gpu_wait_ns = commit_start.elapsed().as_nanos() as u64;
    Ok((encoding_ns, gpu_wait_ns, fusions))
}
/// Fuses then reorders the captured graph, replaying with re-derived
/// barriers. Returns `(fusions, reordered)`.
pub fn finish_with_fusion_and_reorder(
    mut self,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
) -> Result<(u32, u32)> {
    let mut fusions = 0;
    let mut reordered = 0;
    if self.recording {
        if let Some(nodes) = self.encoder.take_capture() {
            let mut graph = ComputeGraph::from_nodes(nodes);
            fusions = graph.fuse(registry, device)?;
            // NOTE(review): unlike finish_optimized_with_timing, there is no
            // unannotated-dispatch guard before reorder() here — confirm
            // callers only use this path on fully annotated graphs.
            reordered = graph.reorder();
            graph.encode_with_barriers(&mut self.encoder);
        }
    }
    self.encoder.commit_and_wait()?;
    Ok((fusions, reordered))
}
/// [`Self::finish_with_fusion_and_reorder`] plus
/// `(encoding_ns, gpu_wait_ns)` timing measured from `session_begin`.
pub fn finish_with_fusion_reorder_and_timing(
    mut self,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    session_begin: std::time::Instant,
) -> Result<(u64, u64, u32, u32)> {
    let mut fusions = 0;
    let mut reordered = 0;
    if self.recording {
        if let Some(nodes) = self.encoder.take_capture() {
            let mut graph = ComputeGraph::from_nodes(nodes);
            fusions = graph.fuse(registry, device)?;
            // NOTE(review): no unannotated-dispatch guard before reorder()
            // (cf. finish_optimized_with_timing) — confirm callers annotate
            // all dispatches on this path.
            reordered = graph.reorder();
            graph.encode_with_barriers(&mut self.encoder);
        }
    }
    let commit_start = std::time::Instant::now();
    let encoding_ns = commit_start.duration_since(session_begin).as_nanos() as u64;
    self.encoder.commit();
    self.encoder.wait_until_completed()?;
    let gpu_wait_ns = commit_start.elapsed().as_nanos() as u64;
    Ok((encoding_ns, gpu_wait_ns, fusions, reordered))
}
/// Full optimization pipeline: fuse, reorder, then dual-buffer encode so
/// the GPU starts on the first chunk while the second is still being
/// encoded. Returns `(fusions, reordered, barriers0, barriers1)`.
pub fn finish_optimized(
    mut self,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
) -> Result<(u32, u32, u32, u32)> {
    let mut fusions = 0;
    let mut reordered = 0;
    let mut barriers0 = 0u32;
    let mut barriers1 = 0u32;
    if self.recording {
        if let Some(nodes) = self.encoder.take_capture() {
            // The session encoder only captured nodes; commit it (empty)
            // and replay onto two fresh encoders instead.
            self.encoder.commit();
            let mut graph = ComputeGraph::from_nodes(nodes);
            fusions = graph.fuse(registry, device)?;
            // NOTE(review): reorder runs without checking for unannotated
            // dispatches (cf. finish_optimized_with_timing) — confirm
            // callers only use this on fully annotated graphs.
            reordered = graph.reorder();
            let mut enc0 = self.device.command_encoder()?;
            let mut enc1 = self.device.command_encoder()?;
            // encode_dual_buffer commits enc0 internally; enc1 below.
            let (b0, b1) = graph.encode_dual_buffer(&mut enc0, &mut enc1);
            barriers0 = b0;
            barriers1 = b1;
            // Only enc1 is awaited — presumably command buffers on the same
            // queue complete in commit order so enc0 is done too; confirm.
            enc1.commit_and_wait()?;
            return Ok((fusions, reordered, barriers0, barriers1));
        }
    }
    self.encoder.commit_and_wait()?;
    Ok((fusions, reordered, barriers0, barriers1))
}
/// [`Self::finish_optimized`] plus detailed timing. Returns
/// `(encoding_ns, gpu_wait_ns, fusions, reordered, barriers0, barriers1)`.
/// Set `HF2Q_GRAPH_DIAG` for a per-phase breakdown on stderr.
pub fn finish_optimized_with_timing(
    mut self,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    session_begin: std::time::Instant,
) -> Result<(u64, u64, u32, u32, u32, u32)> {
    let mut fusions = 0;
    let mut reordered = 0;
    let mut barriers0 = 0u32;
    let mut barriers1 = 0u32;
    if self.recording {
        if let Some(nodes) = self.encoder.take_capture() {
            // The session encoder only captured nodes; commit it empty.
            self.encoder.commit();
            let opt_t0 = std::time::Instant::now();
            let mut graph = ComputeGraph::from_nodes(nodes);
            let fuse_t0 = std::time::Instant::now();
            fusions = graph.fuse(registry, device)?;
            let fuse_us = fuse_t0.elapsed().as_micros();
            let reorder_t0 = std::time::Instant::now();
            // Reordering is only sound when every dispatch carries
            // read/write range annotations.
            let unannotated = graph.unannotated_dispatch_count();
            if unannotated == 0 {
                reordered = graph.reorder();
            } else if std::env::var("HF2Q_MLX_TIMING").is_ok() {
                eprintln!(" [GRAPH_OPT] WARN: skipping reorder — {} of {} dispatches lack range annotations",
                    unannotated, graph.dispatch_count());
            }
            let reorder_us = reorder_t0.elapsed().as_micros();
            let opt_us = opt_t0.elapsed().as_micros();
            let diag = std::env::var("HF2Q_GRAPH_DIAG").is_ok();
            let t0 = std::time::Instant::now();
            let mut enc0 = self.device.command_encoder()?;
            let mut enc1 = self.device.command_encoder()?;
            let enc_create_us = t0.elapsed().as_micros();
            let t1 = std::time::Instant::now();
            // encode_dual_buffer commits enc0 internally; enc1 below.
            let (b0, b1) = graph.encode_dual_buffer(&mut enc0, &mut enc1);
            barriers0 = b0;
            barriers1 = b1;
            let encode_us = t1.elapsed().as_micros();
            let encoding_ns = session_begin.elapsed().as_nanos() as u64;
            let wait_start = std::time::Instant::now();
            // Only enc1 is awaited — presumably same-queue command buffers
            // complete in commit order so enc0 is done too; confirm.
            enc1.commit_and_wait()?;
            let gpu_wait_ns = wait_start.elapsed().as_nanos() as u64;
            if diag {
                eprintln!(" [DIAG] fuse={:.1}ms reorder={:.1}ms opt_total={:.1}ms enc_create={:.1}ms encode={:.1}ms gpu_wait={:.1}ms barriers={}+{}",
                    fuse_us as f64 / 1e3, reorder_us as f64 / 1e3, opt_us as f64 / 1e3,
                    enc_create_us as f64 / 1e3, encode_us as f64 / 1e3,
                    gpu_wait_ns as f64 / 1e6, b0, b1);
            }
            return Ok((encoding_ns, gpu_wait_ns, fusions, reordered, barriers0, barriers1));
        }
    }
    // Non-recording fallback: plain commit + wait with coarse timing.
    let commit_start = std::time::Instant::now();
    let encoding_ns = commit_start.duration_since(session_begin).as_nanos() as u64;
    self.encoder.commit();
    self.encoder.wait_until_completed()?;
    let gpu_wait_ns = commit_start.elapsed().as_nanos() as u64;
    Ok((encoding_ns, gpu_wait_ns, fusions, reordered, barriers0, barriers1))
}
}