Skip to main content

mlx_native/ops/
compute_g_beta.rs

1//! Fused GPU kernel for DeltaNet g and beta computation.
2//!
3//! `g[t, vh]    = softplus(alpha_logit[t, vh] + dt_bias[vh]) * (-ssm_a[vh])`
4//! `beta[t, vh] = sigmoid(beta_logit[t, vh])`
5//!
6//! Replaces the CPU bridge in `compute_g_and_beta_cpu`.
7//! For seq=1 this dispatches 1×nv threads — trivial GPU work but eliminates
8//! 2 CPU-GPU buffer downloads (alpha_logit, beta_logit) per delta-net layer.
9
10use metal::MTLSize;
11
12use crate::buffer::MlxBuffer;
13use crate::device::MlxDevice;
14use crate::dtypes::DType;
15use crate::encoder::CommandEncoder;
16use crate::error::{MlxError, Result};
17use crate::kernel_registry::KernelRegistry;
18
19pub static COMPUTE_G_BETA_SHADER_SOURCE: &str = include_str!("../shaders/compute_g_beta.metal");
20
21/// Register `compute_g_beta_f32` shader with the kernel registry.
22pub fn register(registry: &mut KernelRegistry) {
23    registry.register_source("compute_g_beta_f32", COMPUTE_G_BETA_SHADER_SOURCE);
24}
25
26/// Dispatch `compute_g_beta` kernel into `encoder`.
27///
28/// All f32 buffers. `alpha_logit` and `beta_logit` are `[seq, nv]`.
29/// `dt_bias` and `ssm_a` are `[nv]`. `g_out` and `beta_out` are `[seq, nv]`.
30/// `params_buf` must hold `[nv: u32, seq: u32]` (8 bytes).
31#[allow(clippy::too_many_arguments)]
32pub fn dispatch_compute_g_beta(
33    encoder: &mut CommandEncoder,
34    registry: &mut KernelRegistry,
35    device: &metal::DeviceRef,
36    alpha_logit: &MlxBuffer,
37    beta_logit: &MlxBuffer,
38    dt_bias: &MlxBuffer,
39    ssm_a: &MlxBuffer,
40    g_out: &MlxBuffer,
41    beta_out: &MlxBuffer,
42    params_buf: &MlxBuffer,
43    seq: u32,
44    nv: u32,
45) -> Result<()> {
46    let n = seq * nv;
47    if n == 0 {
48        return Err(MlxError::InvalidArgument("compute_g_beta: n must be > 0".into()));
49    }
50    let pipeline = registry.get_pipeline("compute_g_beta_f32", device)?;
51    let tg = MTLSize::new(std::cmp::min(n as u64, 256), 1, 1);
52    let grid = MTLSize::new(n as u64, 1, 1);
53    encoder.encode(
54        pipeline,
55        &[
56            (0, alpha_logit),
57            (1, beta_logit),
58            (2, dt_bias),
59            (3, ssm_a),
60            (4, g_out),
61            (5, beta_out),
62            (6, params_buf),
63        ],
64        grid,
65        tg,
66    );
67    Ok(())
68}
69
70/// Allocate output buffers, dispatch the kernel, commit, and return `(g, beta)`.
71///
72/// Both returned buffers are `[seq * nv]` f32 in flat token-major layout.
73#[allow(clippy::too_many_arguments)]
74pub fn compute_g_beta_gpu(
75    registry: &mut KernelRegistry,
76    device: &MlxDevice,
77    alpha_logit: &MlxBuffer,
78    beta_logit: &MlxBuffer,
79    dt_bias: &MlxBuffer,
80    ssm_a: &MlxBuffer,
81    seq: u32,
82    nv: u32,
83) -> Result<(MlxBuffer, MlxBuffer)> {
84    let n = (seq * nv) as usize;
85    let g_out = device
86        .alloc_buffer(n * 4, DType::F32, vec![n])
87        .map_err(|e| MlxError::InvalidArgument(format!("compute_g_beta_gpu: alloc g: {e}")))?;
88    let beta_out = device
89        .alloc_buffer(n * 4, DType::F32, vec![n])
90        .map_err(|e| MlxError::InvalidArgument(format!("compute_g_beta_gpu: alloc beta: {e}")))?;
91
92    let mut params_buf = device
93        .alloc_buffer(8, DType::U32, vec![2])
94        .map_err(|e| MlxError::InvalidArgument(format!("compute_g_beta_gpu: alloc params: {e}")))?;
95    params_buf
96        .as_mut_slice::<u32>()
97        .map_err(|e| MlxError::InvalidArgument(format!("compute_g_beta_gpu: write params: {e}")))?
98        [0] = nv;
99    params_buf
100        .as_mut_slice::<u32>()
101        .map_err(|e| MlxError::InvalidArgument(format!("compute_g_beta_gpu: write params2: {e}")))?
102        [1] = seq;
103
104    let mut enc = device
105        .command_encoder()
106        .map_err(|e| MlxError::InvalidArgument(format!("compute_g_beta_gpu: command_encoder: {e}")))?;
107    dispatch_compute_g_beta(
108        &mut enc,
109        registry,
110        device.metal_device(),
111        alpha_logit,
112        beta_logit,
113        dt_bias,
114        ssm_a,
115        &g_out,
116        &beta_out,
117        &params_buf,
118        seq,
119        nv,
120    )?;
121    enc.commit_and_wait()
122        .map_err(|e| MlxError::InvalidArgument(format!("compute_g_beta_gpu: commit: {e}")))?;
123
124    Ok((g_out, beta_out))
125}