pub struct GraphSession<'a> { /* private fields */ }

Implementations
impl<'a> GraphSession<'a>
impl<'a> GraphSession<'a>
Sourcepub fn rms_norm(
&mut self,
registry: &mut KernelRegistry,
device: &DeviceRef,
input: &MlxBuffer,
weight: &MlxBuffer,
output: &MlxBuffer,
params_buf: &MlxBuffer,
rows: u32,
dim: u32,
) -> Result<()>
pub fn rms_norm( &mut self, registry: &mut KernelRegistry, device: &DeviceRef, input: &MlxBuffer, weight: &MlxBuffer, output: &MlxBuffer, params_buf: &MlxBuffer, rows: u32, dim: u32, ) -> Result<()>
Encode an RMS normalization into this session’s encoder.
Delegates to ops::rms_norm::dispatch_rms_norm.
Sourcepub fn quantized_matmul(
&mut self,
registry: &mut KernelRegistry,
device: &MlxDevice,
input: &MlxBuffer,
weight: &MlxBuffer,
scales: &MlxBuffer,
biases: &MlxBuffer,
params: &QuantizedMatmulParams,
) -> Result<MlxBuffer>
pub fn quantized_matmul( &mut self, registry: &mut KernelRegistry, device: &MlxDevice, input: &MlxBuffer, weight: &MlxBuffer, scales: &MlxBuffer, biases: &MlxBuffer, params: &QuantizedMatmulParams, ) -> Result<MlxBuffer>
Encode a quantized matrix multiplication into this session’s encoder.
Delegates to ops::quantized_matmul::quantized_matmul.
Returns the freshly allocated output buffer.
Sourcepub fn quantized_matmul_simd(
&mut self,
registry: &mut KernelRegistry,
device: &MlxDevice,
input: &MlxBuffer,
weight: &MlxBuffer,
scales: &MlxBuffer,
biases: &MlxBuffer,
params: &QuantizedMatmulParams,
) -> Result<MlxBuffer>
pub fn quantized_matmul_simd( &mut self, registry: &mut KernelRegistry, device: &MlxDevice, input: &MlxBuffer, weight: &MlxBuffer, scales: &MlxBuffer, biases: &MlxBuffer, params: &QuantizedMatmulParams, ) -> Result<MlxBuffer>
Encode a SIMD-optimized quantized matmul into this session’s encoder.
Delegates to ops::quantized_matmul::quantized_matmul_simd.
Returns the freshly allocated output buffer.
Sourcepub fn quantized_matmul_ggml(
&mut self,
registry: &mut KernelRegistry,
device: &MlxDevice,
input: &MlxBuffer,
weight: &MlxBuffer,
output: &mut MlxBuffer,
params: &GgmlQuantizedMatmulParams,
) -> Result<()>
pub fn quantized_matmul_ggml( &mut self, registry: &mut KernelRegistry, device: &MlxDevice, input: &MlxBuffer, weight: &MlxBuffer, output: &mut MlxBuffer, params: &GgmlQuantizedMatmulParams, ) -> Result<()>
Encode a GGML block-format quantized mat-vec into this session’s encoder.
Delegates to ops::quantized_matmul_ggml::quantized_matmul_ggml.
Sourcepub fn quantized_matmul_id_ggml(
&mut self,
registry: &mut KernelRegistry,
device: &MlxDevice,
input: &MlxBuffer,
weight: &MlxBuffer,
ids: &MlxBuffer,
output: &mut MlxBuffer,
params: &GgmlQuantizedMatmulIdParams,
) -> Result<()>
pub fn quantized_matmul_id_ggml( &mut self, registry: &mut KernelRegistry, device: &MlxDevice, input: &MlxBuffer, weight: &MlxBuffer, ids: &MlxBuffer, output: &mut MlxBuffer, params: &GgmlQuantizedMatmulIdParams, ) -> Result<()>
Encode an expert-routed GGML block-format quantized mat-vec into this session’s encoder.
Delegates to ops::quantized_matmul_id_ggml::quantized_matmul_id_ggml.
Sourcepub fn sdpa(
&mut self,
registry: &mut KernelRegistry,
device: &MlxDevice,
q: &MlxBuffer,
k: &MlxBuffer,
v: &MlxBuffer,
output: &MlxBuffer,
params: &SdpaParams,
batch_size: u32,
) -> Result<()>
pub fn sdpa( &mut self, registry: &mut KernelRegistry, device: &MlxDevice, q: &MlxBuffer, k: &MlxBuffer, v: &MlxBuffer, output: &MlxBuffer, params: &SdpaParams, batch_size: u32, ) -> Result<()>
Encode scaled dot-product attention into this session’s encoder.
Delegates to ops::sdpa::sdpa.
Sourcepub fn flash_attn_vec(
&mut self,
registry: &mut KernelRegistry,
device: &MlxDevice,
q: &MlxBuffer,
k: &MlxBuffer,
v: &MlxBuffer,
output: &MlxBuffer,
tmp: &MlxBuffer,
params: &FlashAttnVecParams,
) -> Result<()>
pub fn flash_attn_vec( &mut self, registry: &mut KernelRegistry, device: &MlxDevice, q: &MlxBuffer, k: &MlxBuffer, v: &MlxBuffer, output: &MlxBuffer, tmp: &MlxBuffer, params: &FlashAttnVecParams, ) -> Result<()>
Encode flash attention vector (SIMD-vectorized decode-path SDPA).
Delegates to ops::flash_attn_vec::flash_attn_vec.
Sourcepub fn elementwise_add(
&mut self,
registry: &mut KernelRegistry,
device: &DeviceRef,
a: &MlxBuffer,
b: &MlxBuffer,
output: &MlxBuffer,
n_elements: usize,
dtype: DType,
) -> Result<()>
pub fn elementwise_add( &mut self, registry: &mut KernelRegistry, device: &DeviceRef, a: &MlxBuffer, b: &MlxBuffer, output: &MlxBuffer, n_elements: usize, dtype: DType, ) -> Result<()>
Encode an elementwise add into this session’s encoder.
Delegates to ops::elementwise::elementwise_add.
Sourcepub fn elementwise_mul(
&mut self,
registry: &mut KernelRegistry,
device: &DeviceRef,
a: &MlxBuffer,
b: &MlxBuffer,
output: &MlxBuffer,
n_elements: usize,
dtype: DType,
) -> Result<()>
pub fn elementwise_mul( &mut self, registry: &mut KernelRegistry, device: &DeviceRef, a: &MlxBuffer, b: &MlxBuffer, output: &MlxBuffer, n_elements: usize, dtype: DType, ) -> Result<()>
Encode an elementwise multiply into this session’s encoder.
Delegates to ops::elementwise::elementwise_mul.
Sourcepub fn rope(
&mut self,
registry: &mut KernelRegistry,
device: &DeviceRef,
input: &MlxBuffer,
output: &MlxBuffer,
params_buf: &MlxBuffer,
positions_buf: &MlxBuffer,
seq_len: u32,
head_dim: u32,
) -> Result<()>
pub fn rope( &mut self, registry: &mut KernelRegistry, device: &DeviceRef, input: &MlxBuffer, output: &MlxBuffer, params_buf: &MlxBuffer, positions_buf: &MlxBuffer, seq_len: u32, head_dim: u32, ) -> Result<()>
Encode a RoPE transform into this session’s encoder.
Delegates to ops::rope::dispatch_rope.
Sourcepub fn gelu(
&mut self,
registry: &mut KernelRegistry,
device: &DeviceRef,
input: &MlxBuffer,
output: &MlxBuffer,
) -> Result<()>
pub fn gelu( &mut self, registry: &mut KernelRegistry, device: &DeviceRef, input: &MlxBuffer, output: &MlxBuffer, ) -> Result<()>
Encode a GELU activation into this session’s encoder.
Delegates to ops::gelu::dispatch_gelu.
Sourcepub fn softmax(
&mut self,
registry: &mut KernelRegistry,
device: &DeviceRef,
input: &MlxBuffer,
output: &MlxBuffer,
params_buf: &MlxBuffer,
rows: u32,
cols: u32,
) -> Result<()>
pub fn softmax( &mut self, registry: &mut KernelRegistry, device: &DeviceRef, input: &MlxBuffer, output: &MlxBuffer, params_buf: &MlxBuffer, rows: u32, cols: u32, ) -> Result<()>
Encode a softmax into this session’s encoder.
Delegates to ops::softmax::dispatch_softmax.
Sourcepub fn softcap(
&mut self,
registry: &mut KernelRegistry,
device: &DeviceRef,
input: &MlxBuffer,
output: &MlxBuffer,
params_buf: &MlxBuffer,
cap: f32,
) -> Result<()>
pub fn softcap( &mut self, registry: &mut KernelRegistry, device: &DeviceRef, input: &MlxBuffer, output: &MlxBuffer, params_buf: &MlxBuffer, cap: f32, ) -> Result<()>
Encode a softcap into this session’s encoder.
Delegates to ops::softcap::dispatch_softcap.
Sourcepub fn rms_norm_no_scale_f32(
&mut self,
registry: &mut KernelRegistry,
device: &DeviceRef,
input: &MlxBuffer,
output: &MlxBuffer,
params_buf: &MlxBuffer,
rows: u32,
dim: u32,
) -> Result<()>
pub fn rms_norm_no_scale_f32( &mut self, registry: &mut KernelRegistry, device: &DeviceRef, input: &MlxBuffer, output: &MlxBuffer, params_buf: &MlxBuffer, rows: u32, dim: u32, ) -> Result<()>
Encode an RMS norm without learned scale (f32) into this session’s encoder.
Delegates to ops::rms_norm::dispatch_rms_norm_no_scale_f32.
Sourcepub fn rope_neox_f32(
&mut self,
registry: &mut KernelRegistry,
device: &DeviceRef,
input: &MlxBuffer,
output: &MlxBuffer,
params_buf: &MlxBuffer,
positions_buf: &MlxBuffer,
freq_factors: Option<&MlxBuffer>,
seq_len: u32,
n_heads: u32,
head_dim: u32,
rope_dim: u32,
) -> Result<()>
pub fn rope_neox_f32( &mut self, registry: &mut KernelRegistry, device: &DeviceRef, input: &MlxBuffer, output: &MlxBuffer, params_buf: &MlxBuffer, positions_buf: &MlxBuffer, freq_factors: Option<&MlxBuffer>, seq_len: u32, n_heads: u32, head_dim: u32, rope_dim: u32, ) -> Result<()>
Encode a NeoX RoPE (f32) with optional freq_factors into this session’s encoder.
Delegates to ops::rope::dispatch_rope_neox_f32.
Sourcepub fn barrier(&mut self)
pub fn barrier(&mut self)
Insert a GPU memory barrier (MTLBarrierScopeBuffers).
Unconditional barrier — always emits. Use barrier_between for
automatic conflict detection that can elide unnecessary barriers.
Sourcepub fn barrier_between(&mut self, reads: &[&MlxBuffer], writes: &[&MlxBuffer])
pub fn barrier_between(&mut self, reads: &[&MlxBuffer], writes: &[&MlxBuffer])
Smart barrier with conflict detection.
Checks if the next dispatch (with the given read and write buffers) actually conflicts with any dispatch in the current concurrent group. If yes, emits a Metal barrier and resets the tracker. If no, the barrier is elided and the dispatch can run concurrently.
This mirrors llama.cpp’s ggml_metal_op_concurrency_check +
ggml_metal_op_concurrency_reset pattern.
Sourcepub fn dump_group_stats(&self)
pub fn dump_group_stats(&self)
Print group size histogram to stderr (for HF2Q_MLX_TIMING debug).
Sourcepub fn track_dispatch(&mut self, reads: &[&MlxBuffer], writes: &[&MlxBuffer])
pub fn track_dispatch(&mut self, reads: &[&MlxBuffer], writes: &[&MlxBuffer])
Register a dispatch’s buffer ranges without checking for conflicts.
Use after dispatching an op that doesn’t need a barrier check (e.g., the first dispatch in a session, or dispatches known to be concurrent).
Sourcepub fn barrier_count(&self) -> u32
pub fn barrier_count(&self) -> u32
Return the number of barriers inserted so far in this session.
Sourcepub fn tracker_overhead_ns(&self) -> u64
pub fn tracker_overhead_ns(&self) -> u64
Cumulative nanoseconds spent in ConflictTracker checks (diagnostic). Returns 0 when timing is not compiled in.
Sourcepub fn encoder_mut(&mut self) -> &mut CommandEncoder
pub fn encoder_mut(&mut self) -> &mut CommandEncoder
Borrow the underlying command encoder for direct op dispatch.
Use this when you need to call an op function that is not wrapped by
a GraphSession method. The returned encoder is the same shared
encoder — all dispatches still go into the same command buffer.
Sourcepub fn is_recording(&self) -> bool
pub fn is_recording(&self) -> bool
Whether this session is in capture/record mode.
Sourcepub fn finish(self) -> Result<()>
pub fn finish(self) -> Result<()>
Commit the command buffer and wait for GPU completion.
This is the ONLY sync point per forward pass. After this call, all output buffers are readable by the CPU.
In recording mode: extracts the captured graph, replays it into
the encoder via ComputeGraph::encode_sequential(), then commits
and waits. The result is identical to the direct-dispatch path.
Consumes the session — no further ops can be encoded.
Sourcepub fn commit(self) -> CommandEncoder
pub fn commit(self) -> CommandEncoder
Commit the command buffer WITHOUT waiting.
The GPU begins executing immediately. Use this for fire-and-forget dispatch when you do not need results until later.
In recording mode: replays the captured graph before committing.
Consumes the session.
Sourcepub fn finish_with_timing(self, session_begin: Instant) -> Result<(u64, u64)>
pub fn finish_with_timing(self, session_begin: Instant) -> Result<(u64, u64)>
Commit the command buffer and wait, returning split timing.
Returns (encoding_ns, gpu_wait_ns) where:
- `encoding_ns` is the time from session begin to commit (CPU encoding)
- `gpu_wait_ns` is the time from commit to GPU completion
The session_begin instant should be captured right after exec.begin().
In recording mode: replays the captured graph before committing.
Consumes the session.
Sourcepub fn finish_with_fusion(
self,
registry: &mut KernelRegistry,
device: &DeviceRef,
) -> Result<u32>
pub fn finish_with_fusion( self, registry: &mut KernelRegistry, device: &DeviceRef, ) -> Result<u32>
Finish with fusion: run the RMS norm + MUL fusion pass before replaying the graph.
Only meaningful in recording mode. In direct-dispatch mode, this
behaves identically to finish().
Returns `fusions_applied` (the number of fusions applied) on success.
Sourcepub fn finish_with_fusion_and_timing(
self,
registry: &mut KernelRegistry,
device: &DeviceRef,
session_begin: Instant,
) -> Result<(u64, u64, u32)>
pub fn finish_with_fusion_and_timing( self, registry: &mut KernelRegistry, device: &DeviceRef, session_begin: Instant, ) -> Result<(u64, u64, u32)>
Finish with fusion and split timing.
Like finish_with_timing but runs the fusion pass first.
Returns (encoding_ns, gpu_wait_ns, fusions_applied).
Sourcepub fn finish_with_fusion_and_reorder(
self,
registry: &mut KernelRegistry,
device: &DeviceRef,
) -> Result<(u32, u32)>
pub fn finish_with_fusion_and_reorder( self, registry: &mut KernelRegistry, device: &DeviceRef, ) -> Result<(u32, u32)>
Finish with fusion AND reorder: run both graph optimization passes before replaying the graph.
Only meaningful in recording mode. In direct-dispatch mode, this
behaves identically to finish().
Returns (fusions_applied, nodes_reordered) on success.
Sourcepub fn finish_with_fusion_reorder_and_timing(
self,
registry: &mut KernelRegistry,
device: &DeviceRef,
session_begin: Instant,
) -> Result<(u64, u64, u32, u32)>
pub fn finish_with_fusion_reorder_and_timing( self, registry: &mut KernelRegistry, device: &DeviceRef, session_begin: Instant, ) -> Result<(u64, u64, u32, u32)>
Finish with fusion, reorder, and split timing.
Like finish_with_fusion_and_timing but also runs the reorder pass.
Returns (encoding_ns, gpu_wait_ns, fusions_applied, nodes_reordered).
Sourcepub fn finish_optimized(
self,
registry: &mut KernelRegistry,
device: &DeviceRef,
) -> Result<(u32, u32, u32, u32)>
pub fn finish_optimized( self, registry: &mut KernelRegistry, device: &DeviceRef, ) -> Result<(u32, u32, u32, u32)>
Finish with the full optimization pipeline: fuse, reorder, dual-buffer encode.
Runs the fusion pass, reorder pass, then encodes the graph into two Metal command buffers for CPU/GPU overlap. The first ~10% of dispatches are committed immediately so the GPU can start executing while the CPU encodes the remaining ~90%.
Only meaningful in recording mode. In direct-dispatch mode, this
behaves identically to finish().
Returns (fusions_applied, nodes_reordered, barriers_buf0, barriers_buf1).
Sourcepub fn finish_optimized_with_timing(
self,
registry: &mut KernelRegistry,
device: &DeviceRef,
session_begin: Instant,
) -> Result<(u64, u64, u32, u32, u32, u32)>
pub fn finish_optimized_with_timing( self, registry: &mut KernelRegistry, device: &DeviceRef, session_begin: Instant, ) -> Result<(u64, u64, u32, u32, u32, u32)>
Finish with the full optimization pipeline and split timing.
Like finish_optimized but returns timing information.
Returns (encoding_ns, gpu_wait_ns, fusions, reordered, barriers_buf0, barriers_buf1).
Timing breakdown:
- `encoding_ns`: CPU time from session begin to first buffer commit (fusion + reorder + encode chunk 0)
- `gpu_wait_ns`: wall time from second buffer commit to GPU completion (includes GPU execution of both buffers, overlapped with chunk 1 encoding)