Skip to main content

GraphSession

Struct GraphSession 

Source
pub struct GraphSession<'a> { /* private fields */ }

Implementations§

Source§

impl<'a> GraphSession<'a>

Source

pub fn rms_norm( &mut self, registry: &mut KernelRegistry, device: &DeviceRef, input: &MlxBuffer, weight: &MlxBuffer, output: &MlxBuffer, params_buf: &MlxBuffer, rows: u32, dim: u32, ) -> Result<()>

Encode an RMS normalization into this session’s encoder.

Delegates to ops::rms_norm::dispatch_rms_norm.

Source

pub fn quantized_matmul( &mut self, registry: &mut KernelRegistry, device: &MlxDevice, input: &MlxBuffer, weight: &MlxBuffer, scales: &MlxBuffer, biases: &MlxBuffer, params: &QuantizedMatmulParams, ) -> Result<MlxBuffer>

Encode a quantized matrix multiplication into this session’s encoder.

Delegates to ops::quantized_matmul::quantized_matmul. Returns the freshly allocated output buffer.

Source

pub fn quantized_matmul_simd( &mut self, registry: &mut KernelRegistry, device: &MlxDevice, input: &MlxBuffer, weight: &MlxBuffer, scales: &MlxBuffer, biases: &MlxBuffer, params: &QuantizedMatmulParams, ) -> Result<MlxBuffer>

Encode a SIMD-optimized quantized matmul into this session’s encoder.

Delegates to ops::quantized_matmul::quantized_matmul_simd. Returns the freshly allocated output buffer.

Source

pub fn quantized_matmul_ggml( &mut self, registry: &mut KernelRegistry, device: &MlxDevice, input: &MlxBuffer, weight: &MlxBuffer, output: &mut MlxBuffer, params: &GgmlQuantizedMatmulParams, ) -> Result<()>

Encode a GGML block-format quantized mat-vec into this session’s encoder.

Delegates to ops::quantized_matmul_ggml::quantized_matmul_ggml.

Source

pub fn quantized_matmul_id_ggml( &mut self, registry: &mut KernelRegistry, device: &MlxDevice, input: &MlxBuffer, weight: &MlxBuffer, ids: &MlxBuffer, output: &mut MlxBuffer, params: &GgmlQuantizedMatmulIdParams, ) -> Result<()>

Encode an expert-routed GGML block-format quantized mat-vec into this session’s encoder.

Delegates to ops::quantized_matmul_id_ggml::quantized_matmul_id_ggml.

Source

pub fn sdpa( &mut self, registry: &mut KernelRegistry, device: &MlxDevice, q: &MlxBuffer, k: &MlxBuffer, v: &MlxBuffer, output: &MlxBuffer, params: &SdpaParams, batch_size: u32, ) -> Result<()>

Encode scaled dot-product attention into this session’s encoder.

Delegates to ops::sdpa::sdpa.

Source

pub fn flash_attn_vec( &mut self, registry: &mut KernelRegistry, device: &MlxDevice, q: &MlxBuffer, k: &MlxBuffer, v: &MlxBuffer, output: &MlxBuffer, tmp: &MlxBuffer, params: &FlashAttnVecParams, ) -> Result<()>

Encode flash attention vector (SIMD-vectorized decode-path SDPA).

Delegates to ops::flash_attn_vec::flash_attn_vec.

Source

pub fn elementwise_add( &mut self, registry: &mut KernelRegistry, device: &DeviceRef, a: &MlxBuffer, b: &MlxBuffer, output: &MlxBuffer, n_elements: usize, dtype: DType, ) -> Result<()>

Encode an elementwise add into this session’s encoder.

Delegates to ops::elementwise::elementwise_add.

Source

pub fn elementwise_mul( &mut self, registry: &mut KernelRegistry, device: &DeviceRef, a: &MlxBuffer, b: &MlxBuffer, output: &MlxBuffer, n_elements: usize, dtype: DType, ) -> Result<()>

Encode an elementwise multiply into this session’s encoder.

Delegates to ops::elementwise::elementwise_mul.

Source

pub fn rope( &mut self, registry: &mut KernelRegistry, device: &DeviceRef, input: &MlxBuffer, output: &MlxBuffer, params_buf: &MlxBuffer, positions_buf: &MlxBuffer, seq_len: u32, head_dim: u32, ) -> Result<()>

Encode a RoPE transform into this session’s encoder.

Delegates to ops::rope::dispatch_rope.

Source

pub fn gelu( &mut self, registry: &mut KernelRegistry, device: &DeviceRef, input: &MlxBuffer, output: &MlxBuffer, ) -> Result<()>

Encode a GELU activation into this session’s encoder.

Delegates to ops::gelu::dispatch_gelu.

Source

pub fn softmax( &mut self, registry: &mut KernelRegistry, device: &DeviceRef, input: &MlxBuffer, output: &MlxBuffer, params_buf: &MlxBuffer, rows: u32, cols: u32, ) -> Result<()>

Encode a softmax into this session’s encoder.

Delegates to ops::softmax::dispatch_softmax.

Source

pub fn softcap( &mut self, registry: &mut KernelRegistry, device: &DeviceRef, input: &MlxBuffer, output: &MlxBuffer, params_buf: &MlxBuffer, cap: f32, ) -> Result<()>

Encode a softcap into this session’s encoder.

Delegates to ops::softcap::dispatch_softcap.

Source

pub fn rms_norm_no_scale_f32( &mut self, registry: &mut KernelRegistry, device: &DeviceRef, input: &MlxBuffer, output: &MlxBuffer, params_buf: &MlxBuffer, rows: u32, dim: u32, ) -> Result<()>

Encode an RMS norm without learned scale (f32) into this session’s encoder.

Delegates to ops::rms_norm::dispatch_rms_norm_no_scale_f32.

Source

pub fn rope_neox_f32( &mut self, registry: &mut KernelRegistry, device: &DeviceRef, input: &MlxBuffer, output: &MlxBuffer, params_buf: &MlxBuffer, positions_buf: &MlxBuffer, freq_factors: Option<&MlxBuffer>, seq_len: u32, n_heads: u32, head_dim: u32, rope_dim: u32, ) -> Result<()>

Encode a NeoX RoPE (f32) with optional freq_factors into this session’s encoder.

Delegates to ops::rope::dispatch_rope_neox_f32.

Source

pub fn barrier(&mut self)

Insert a GPU memory barrier (MTLBarrierScopeBuffers).

Unconditional barrier — always emits. Use barrier_between for automatic conflict detection that can elide unnecessary barriers.

Source

pub fn barrier_between(&mut self, reads: &[&MlxBuffer], writes: &[&MlxBuffer])

Smart barrier with conflict detection.

Checks if the next dispatch (with the given read and write buffers) actually conflicts with any dispatch in the current concurrent group. If yes, emits a Metal barrier and resets the tracker. If no, the barrier is elided and the dispatch can run concurrently.

This mirrors llama.cpp’s ggml_metal_op_concurrency_check + ggml_metal_op_concurrency_reset pattern.

Source

pub fn dump_group_stats(&self)

Print group size histogram to stderr (for HF2Q_MLX_TIMING debug).

Source

pub fn track_dispatch(&mut self, reads: &[&MlxBuffer], writes: &[&MlxBuffer])

Register a dispatch’s buffer ranges without checking for conflicts.

Use after dispatching an op that doesn’t need a barrier check (e.g., the first dispatch in a session, or dispatches known to be concurrent).

In recording mode, also retroactively annotates the most recently captured dispatch node with these ranges if it was missing them. That keeps the reorder pass able to reason about dispatches that were preceded by track_dispatch rather than barrier_between.

Source

pub fn barrier_count(&self) -> u32

Return the number of barriers inserted so far in this session.

Source

pub fn tracker_overhead_ns(&self) -> u64

Cumulative nanoseconds spent in ConflictTracker checks (diagnostic). Returns 0 when timing is not compiled in.

Source

pub fn encoder_mut(&mut self) -> &mut CommandEncoder

Borrow the underlying command encoder for direct op dispatch.

Use this when you need to call an op function that is not wrapped by a GraphSession method. The returned encoder is the same shared encoder — all dispatches still go into the same command buffer.

Source

pub fn device(&self) -> &MlxDevice

Borrow the device reference.

Source

pub fn is_recording(&self) -> bool

Whether this session is in capture/record mode.

Source

pub fn finish(self) -> Result<()>

Commit the command buffer and wait for GPU completion.

This is the ONLY sync point per forward pass. After this call, all output buffers are readable by the CPU.

In recording mode: extracts the captured graph, replays it into the encoder via ComputeGraph::encode_sequential(), then commits and waits. The result is identical to the direct-dispatch path.

Consumes the session — no further ops can be encoded.

Source

pub fn commit(self) -> CommandEncoder

Commit the command buffer WITHOUT waiting.

The GPU begins executing immediately. Use this for fire-and-forget dispatch when you do not need results until later.

In recording mode: replays the captured graph before committing.

Consumes the session.

Source

pub fn finish_with_timing(self, session_begin: Instant) -> Result<(u64, u64)>

Commit the command buffer and wait, returning split timing.

Returns (encoding_ns, gpu_wait_ns) where:

  • encoding_ns is the time from session begin to commit (CPU encoding)
  • gpu_wait_ns is the time from commit to GPU completion

The session_begin instant should be captured right after exec.begin().

In recording mode: replays the captured graph before committing.

Consumes the session.

Source

pub fn finish_with_fusion( self, registry: &mut KernelRegistry, device: &DeviceRef, ) -> Result<u32>

Finish with fusion: run the RMS norm + MUL fusion pass before replaying the graph.

Only meaningful in recording mode. In direct-dispatch mode, this behaves identically to finish().

Returns the number of fusions applied (fusions_applied) on success.

Source

pub fn finish_with_fusion_and_timing( self, registry: &mut KernelRegistry, device: &DeviceRef, session_begin: Instant, ) -> Result<(u64, u64, u32)>

Finish with fusion and split timing.

Like finish_with_timing but runs the fusion pass first. Returns (encoding_ns, gpu_wait_ns, fusions_applied).

Source

pub fn finish_with_fusion_and_reorder( self, registry: &mut KernelRegistry, device: &DeviceRef, ) -> Result<(u32, u32)>

Finish with fusion AND reorder: run both graph optimization passes before replaying the graph.

Only meaningful in recording mode. In direct-dispatch mode, this behaves identically to finish().

Returns (fusions_applied, nodes_reordered) on success.

Source

pub fn finish_with_fusion_reorder_and_timing( self, registry: &mut KernelRegistry, device: &DeviceRef, session_begin: Instant, ) -> Result<(u64, u64, u32, u32)>

Finish with fusion, reorder, and split timing.

Like finish_with_fusion_and_timing but also runs the reorder pass. Returns (encoding_ns, gpu_wait_ns, fusions_applied, nodes_reordered).

Source

pub fn finish_optimized( self, registry: &mut KernelRegistry, device: &DeviceRef, ) -> Result<(u32, u32, u32, u32)>

Finish with the full optimization pipeline: fuse, reorder, dual-buffer encode.

Runs the fusion pass, reorder pass, then encodes the graph into two Metal command buffers for CPU/GPU overlap. The first ~10% of dispatches are committed immediately so the GPU can start executing while the CPU encodes the remaining ~90%.

Only meaningful in recording mode. In direct-dispatch mode, this behaves identically to finish().

Returns (fusions_applied, nodes_reordered, barriers_buf0, barriers_buf1).

Source

pub fn finish_optimized_with_timing( self, registry: &mut KernelRegistry, device: &DeviceRef, session_begin: Instant, ) -> Result<(u64, u64, u32, u32, u32, u32)>

Finish with the full optimization pipeline and split timing.

Like finish_optimized but returns timing information. Returns (encoding_ns, gpu_wait_ns, fusions, reordered, barriers_buf0, barriers_buf1).

Timing breakdown:

  • encoding_ns: CPU time from session begin to first buffer commit (fusion + reorder + encode chunk 0)
  • gpu_wait_ns: wall time from second buffer commit to GPU completion (includes GPU execution of both buffers, overlapped with chunk 1 encoding)

Auto Trait Implementations§

§

impl<'a> Freeze for GraphSession<'a>

§

impl<'a> !RefUnwindSafe for GraphSession<'a>

§

impl<'a> Send for GraphSession<'a>

§

impl<'a> !Sync for GraphSession<'a>

§

impl<'a> Unpin for GraphSession<'a>

§

impl<'a> UnsafeUnpin for GraphSession<'a>

§

impl<'a> !UnwindSafe for GraphSession<'a>

Blanket Implementations§

Source§

impl<T> Any for T
where T: 'static + ?Sized,

Source§

fn type_id(&self) -> TypeId

Gets the TypeId of self. Read more
Source§

impl<T> Borrow<T> for T
where T: ?Sized,

Source§

fn borrow(&self) -> &T

Immutably borrows from an owned value. Read more
Source§

impl<T> BorrowMut<T> for T
where T: ?Sized,

Source§

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value. Read more
Source§

impl<T> From<T> for T

Source§

fn from(t: T) -> T

Returns the argument unchanged.

Source§

impl<T, U> Into<U> for T
where U: From<T>,

Source§

fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

Source§

impl<T, U> TryFrom<U> for T
where U: Into<T>,

Source§

type Error = Infallible

The type returned in the event of a conversion error.
Source§

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.
Source§

impl<T, U> TryInto<U> for T
where U: TryFrom<T>,

Source§

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.
Source§

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.