mlx_native/ops/
compute_g_beta.rs1use metal::MTLSize;
11
12use crate::buffer::MlxBuffer;
13use crate::device::MlxDevice;
14use crate::dtypes::DType;
15use crate::encoder::CommandEncoder;
16use crate::error::{MlxError, Result};
17use crate::kernel_registry::KernelRegistry;
18
/// Metal shader source for the `compute_g_beta` kernel, embedded into the
/// binary at compile time from `../shaders/compute_g_beta.metal`.
pub static COMPUTE_G_BETA_SHADER_SOURCE: &str = include_str!("../shaders/compute_g_beta.metal");
20
21pub fn register(registry: &mut KernelRegistry) {
23 registry.register_source("compute_g_beta_f32", COMPUTE_G_BETA_SHADER_SOURCE);
24}
25
26#[allow(clippy::too_many_arguments)]
32pub fn dispatch_compute_g_beta(
33 encoder: &mut CommandEncoder,
34 registry: &mut KernelRegistry,
35 device: &metal::DeviceRef,
36 alpha_logit: &MlxBuffer,
37 beta_logit: &MlxBuffer,
38 dt_bias: &MlxBuffer,
39 ssm_a: &MlxBuffer,
40 g_out: &MlxBuffer,
41 beta_out: &MlxBuffer,
42 params_buf: &MlxBuffer,
43 seq: u32,
44 nv: u32,
45) -> Result<()> {
46 let n = seq * nv;
47 if n == 0 {
48 return Err(MlxError::InvalidArgument("compute_g_beta: n must be > 0".into()));
49 }
50 let pipeline = registry.get_pipeline("compute_g_beta_f32", device)?;
51 let tg = MTLSize::new(std::cmp::min(n as u64, 256), 1, 1);
52 let grid = MTLSize::new(n as u64, 1, 1);
53 encoder.encode(
54 pipeline,
55 &[
56 (0, alpha_logit),
57 (1, beta_logit),
58 (2, dt_bias),
59 (3, ssm_a),
60 (4, g_out),
61 (5, beta_out),
62 (6, params_buf),
63 ],
64 grid,
65 tg,
66 );
67 Ok(())
68}
69
70#[allow(clippy::too_many_arguments)]
74pub fn compute_g_beta_gpu(
75 registry: &mut KernelRegistry,
76 device: &MlxDevice,
77 alpha_logit: &MlxBuffer,
78 beta_logit: &MlxBuffer,
79 dt_bias: &MlxBuffer,
80 ssm_a: &MlxBuffer,
81 seq: u32,
82 nv: u32,
83) -> Result<(MlxBuffer, MlxBuffer)> {
84 let n = (seq * nv) as usize;
85 let g_out = device
86 .alloc_buffer(n * 4, DType::F32, vec![n])
87 .map_err(|e| MlxError::InvalidArgument(format!("compute_g_beta_gpu: alloc g: {e}")))?;
88 let beta_out = device
89 .alloc_buffer(n * 4, DType::F32, vec![n])
90 .map_err(|e| MlxError::InvalidArgument(format!("compute_g_beta_gpu: alloc beta: {e}")))?;
91
92 let mut params_buf = device
93 .alloc_buffer(8, DType::U32, vec![2])
94 .map_err(|e| MlxError::InvalidArgument(format!("compute_g_beta_gpu: alloc params: {e}")))?;
95 params_buf
96 .as_mut_slice::<u32>()
97 .map_err(|e| MlxError::InvalidArgument(format!("compute_g_beta_gpu: write params: {e}")))?
98 [0] = nv;
99 params_buf
100 .as_mut_slice::<u32>()
101 .map_err(|e| MlxError::InvalidArgument(format!("compute_g_beta_gpu: write params2: {e}")))?
102 [1] = seq;
103
104 let mut enc = device
105 .command_encoder()
106 .map_err(|e| MlxError::InvalidArgument(format!("compute_g_beta_gpu: command_encoder: {e}")))?;
107 dispatch_compute_g_beta(
108 &mut enc,
109 registry,
110 device.metal_device(),
111 alpha_logit,
112 beta_logit,
113 dt_bias,
114 ssm_a,
115 &g_out,
116 &beta_out,
117 ¶ms_buf,
118 seq,
119 nv,
120 )?;
121 enc.commit_and_wait()
122 .map_err(|e| MlxError::InvalidArgument(format!("compute_g_beta_gpu: commit: {e}")))?;
123
124 Ok((g_out, beta_out))
125}