1use metal::MTLSize;
44
45use crate::buffer::MlxBuffer;
46use crate::dtypes::DType;
47use crate::encoder::CommandEncoder;
48use crate::error::{MlxError, Result};
49use crate::kernel_registry::KernelRegistry;
50
/// Metal shader source for the SSM convolution kernels, embedded at compile
/// time from `../shaders/ssm_conv.metal`.
///
/// [`register`] registers this same source string under every kernel name
/// used by this module.
pub static SSM_CONV_SHADER_SOURCE: &str = include_str!("../shaders/ssm_conv.metal");
52
53pub fn register(registry: &mut KernelRegistry) {
55 registry.register_source("ssm_conv_forward_f32", SSM_CONV_SHADER_SOURCE);
56 registry.register_source("ssm_conv_forward_bf16", SSM_CONV_SHADER_SOURCE);
57 registry.register_source("ssm_conv_state_update_f32", SSM_CONV_SHADER_SOURCE);
58 registry.register_source("ssm_conv_state_update_bf16", SSM_CONV_SHADER_SOURCE);
59}
60
/// Problem dimensions for one SSM convolution dispatch.
///
/// `validate` (called by [`dispatch_ssm_conv`]) requires every field to be
/// non-zero and `k_width >= 2`, and checks the buffer element counts against
/// the products described on each field.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SsmConvParams {
    /// Number of channels. `x` and `y` must hold
    /// `channels * n_tokens * n_seqs` elements.
    pub channels: u32,
    /// Number of tokens per sequence.
    pub n_tokens: u32,
    /// Number of sequences.
    pub n_seqs: u32,
    /// Convolution kernel width `K` (must be at least 2). `kernel_w` must hold
    /// `k_width * channels` elements and each state buffer
    /// `(k_width - 1) * channels * n_seqs` elements.
    pub k_width: u32,
}
69
70fn validate(
71 params: &SsmConvParams,
72 x: &MlxBuffer,
73 kernel_w: &MlxBuffer,
74 old_state: &MlxBuffer,
75 new_state: &MlxBuffer,
76 y: &MlxBuffer,
77) -> Result<()> {
78 if params.channels == 0 || params.n_tokens == 0 || params.n_seqs == 0 {
79 return Err(MlxError::InvalidArgument(
80 "ssm_conv: channels, n_tokens, n_seqs must all be > 0".into(),
81 ));
82 }
83 if params.k_width < 2 {
84 return Err(MlxError::InvalidArgument(
85 "ssm_conv: k_width must be >= 2 (K=1 has empty state)".into(),
86 ));
87 }
88 let x_elems = (params.channels as usize)
89 .checked_mul(params.n_tokens as usize)
90 .and_then(|v| v.checked_mul(params.n_seqs as usize))
91 .ok_or_else(|| MlxError::InvalidArgument("ssm_conv: shape overflow".into()))?;
92 let w_elems = (params.k_width as usize) * (params.channels as usize);
93 let s_elems = ((params.k_width - 1) as usize)
94 * (params.channels as usize)
95 * (params.n_seqs as usize);
96
97 if x.element_count() != x_elems {
98 return Err(MlxError::InvalidArgument(format!(
99 "ssm_conv: x element count {} != channels({}) * n_tokens({}) * n_seqs({})",
100 x.element_count(),
101 params.channels,
102 params.n_tokens,
103 params.n_seqs
104 )));
105 }
106 if y.element_count() != x_elems {
107 return Err(MlxError::InvalidArgument(format!(
108 "ssm_conv: y element count {} != expected {}",
109 y.element_count(),
110 x_elems
111 )));
112 }
113 if kernel_w.element_count() != w_elems {
114 return Err(MlxError::InvalidArgument(format!(
115 "ssm_conv: kernel_w element count {} != K({}) * channels({})",
116 kernel_w.element_count(),
117 params.k_width,
118 params.channels
119 )));
120 }
121 if old_state.element_count() != s_elems || new_state.element_count() != s_elems {
122 return Err(MlxError::InvalidArgument(format!(
123 "ssm_conv: state element count mismatch; old={} new={} expected {}",
124 old_state.element_count(),
125 new_state.element_count(),
126 s_elems
127 )));
128 }
129
130 let dt = x.dtype();
131 for (name, buf) in [
132 ("kernel_w", kernel_w),
133 ("old_state", old_state),
134 ("new_state", new_state),
135 ("y", y),
136 ] {
137 if buf.dtype() != dt {
138 return Err(MlxError::InvalidArgument(format!(
139 "ssm_conv: dtype mismatch — x is {}, {} is {}",
140 dt,
141 name,
142 buf.dtype()
143 )));
144 }
145 }
146 Ok(())
147}
148
149pub fn dispatch_ssm_conv(
166 encoder: &mut CommandEncoder,
167 registry: &mut KernelRegistry,
168 device: &metal::DeviceRef,
169 x: &MlxBuffer,
170 kernel_w: &MlxBuffer,
171 old_state: &MlxBuffer,
172 new_state: &MlxBuffer,
173 y: &MlxBuffer,
174 params_buf: &MlxBuffer,
175 params: SsmConvParams,
176) -> Result<()> {
177 validate(¶ms, x, kernel_w, old_state, new_state, y)?;
178
179 let (fwd_name, state_name) = match x.dtype() {
180 DType::F32 => ("ssm_conv_forward_f32", "ssm_conv_state_update_f32"),
181 DType::BF16 => ("ssm_conv_forward_bf16", "ssm_conv_state_update_bf16"),
182 other => {
183 return Err(MlxError::InvalidArgument(format!(
184 "ssm_conv: unsupported dtype {}",
185 other
186 )))
187 }
188 };
189
190 let fwd_pipeline = registry.get_pipeline(fwd_name, device)?;
192 let fwd_grid = MTLSize::new(
193 params.channels as u64,
194 params.n_tokens as u64,
195 params.n_seqs as u64,
196 );
197 let tg_c = std::cmp::min(params.channels, 256).max(1);
199 let remain = 256u32 / tg_c;
200 let tg_t = std::cmp::min(params.n_tokens, remain).max(1);
201 let remain2 = (256u32 / (tg_c * tg_t)).max(1);
202 let tg_s = std::cmp::min(params.n_seqs, remain2).max(1);
203 let fwd_tg = MTLSize::new(tg_c as u64, tg_t as u64, tg_s as u64);
204
205 encoder.encode(
206 fwd_pipeline,
207 &[
208 (0, x),
209 (1, kernel_w),
210 (2, old_state),
211 (3, y),
212 (4, params_buf),
213 ],
214 fwd_grid,
215 fwd_tg,
216 );
217
218 let state_pipeline = registry.get_pipeline(state_name, device)?;
220 let state_grid = MTLSize::new(
221 (params.k_width - 1) as u64,
222 params.channels as u64,
223 params.n_seqs as u64,
224 );
225 let su_tg_i = (params.k_width - 1).max(1);
226 let su_remain = (256u32 / su_tg_i).max(1);
227 let su_tg_c_raw = std::cmp::min(params.channels, su_remain).max(1);
228 fn gcd_u32(mut a: u32, mut b: u32) -> u32 {
243 while b != 0 { let t = b; b = a % b; a = t; }
244 a
245 }
246 let step = 32u32 / gcd_u32(su_tg_i, 32);
247 let su_tg_c = if step <= 1 {
248 su_tg_c_raw
249 } else if su_tg_c_raw >= step {
250 (su_tg_c_raw / step) * step } else {
252 su_tg_c_raw
257 };
258 let su_remain2 = (256u32 / (su_tg_i * su_tg_c)).max(1);
259 let su_tg_s = std::cmp::min(params.n_seqs, su_remain2).max(1);
260 let state_tg = MTLSize::new(su_tg_i as u64, su_tg_c as u64, su_tg_s as u64);
261
262 encoder.encode(
263 state_pipeline,
264 &[
265 (0, x),
266 (1, old_state),
267 (2, new_state),
268 (3, params_buf),
269 ],
270 state_grid,
271 state_tg,
272 );
273
274 Ok(())
275}