1use metal::MTLSize;
7
8use crate::buffer::MlxBuffer;
9use crate::dtypes::DType;
10use crate::encoder::{CapturedOpKind, CommandEncoder};
11use crate::error::{MlxError, Result};
12use crate::kernel_registry::KernelRegistry;
13
14use super::encode_helpers::{as_bytes, encode_with_args, KernelArg};
15
/// Kernel argument block shared by the elementwise and cast kernels.
/// `#[repr(C)]` so the byte layout matches the corresponding Metal-side
/// struct — field order/padding must stay in sync with the shader; confirm
/// against the Metal source when editing.
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct GpuElementwiseParams {
    // Total number of elements the kernel should process.
    n_elements: u32,
}
24
25const ELEMENTWISE_TG_SIZE: u64 = 256;
27
28fn elementwise_kernel_name(op: &str, dtype: DType) -> Result<&'static str> {
30 match (op, dtype) {
31 ("add", DType::F32) => Ok("elementwise_add_f32"),
32 ("add", DType::F16) => Ok("elementwise_add_f16"),
33 ("add", DType::BF16) => Ok("elementwise_add_bf16"),
34 ("mul", DType::F32) => Ok("elementwise_mul_f32"),
35 ("mul", DType::F16) => Ok("elementwise_mul_f16"),
36 ("mul", DType::BF16) => Ok("elementwise_mul_bf16"),
37 _ => Err(MlxError::InvalidArgument(format!(
38 "elementwise_{op}: unsupported dtype {dtype}"
39 ))),
40 }
41}
42
43#[allow(clippy::too_many_arguments)]
55fn elementwise_binary(
56 encoder: &mut CommandEncoder,
57 registry: &mut KernelRegistry,
58 device: &metal::DeviceRef,
59 a: &MlxBuffer,
60 b: &MlxBuffer,
61 output: &MlxBuffer,
62 n_elements: usize,
63 op: &str,
64 dtype: DType,
65) -> Result<()> {
66 if n_elements == 0 {
67 return Err(MlxError::InvalidArgument(format!(
68 "elementwise_{op}: n_elements must be > 0"
69 )));
70 }
71
72 let elem_bytes = n_elements * dtype.size_of();
73 if a.byte_len() < elem_bytes {
74 return Err(MlxError::InvalidArgument(format!(
75 "elementwise_{op}: input 'a' buffer too small: need {} bytes, have {}",
76 elem_bytes,
77 a.byte_len()
78 )));
79 }
80 if b.byte_len() < elem_bytes {
81 return Err(MlxError::InvalidArgument(format!(
82 "elementwise_{op}: input 'b' buffer too small: need {} bytes, have {}",
83 elem_bytes,
84 b.byte_len()
85 )));
86 }
87 if output.byte_len() < elem_bytes {
88 return Err(MlxError::InvalidArgument(format!(
89 "elementwise_{op}: output buffer too small: need {} bytes, have {}",
90 elem_bytes,
91 output.byte_len()
92 )));
93 }
94
95 let kernel_name = elementwise_kernel_name(op, dtype)?;
96 let pipeline = registry.get_pipeline(kernel_name, device)?;
97
98 let gpu_params = GpuElementwiseParams {
99 n_elements: n_elements as u32,
100 };
101
102 let grid = MTLSize::new(n_elements as u64, 1, 1);
103 let tg = MTLSize::new(std::cmp::min(ELEMENTWISE_TG_SIZE, n_elements as u64), 1, 1);
104
105 let op_tag = match op {
108 "mul" => CapturedOpKind::ElemMul,
109 "add" => CapturedOpKind::ElemAdd,
110 _ => CapturedOpKind::Other,
111 };
112 encoder.set_op_kind(op_tag);
113
114 encode_with_args(
115 encoder,
116 pipeline,
117 &[
118 (0, KernelArg::Buffer(a)),
119 (1, KernelArg::Buffer(b)),
120 (2, KernelArg::Buffer(output)),
121 (3, KernelArg::Bytes(as_bytes(&gpu_params))),
122 ],
123 grid,
124 tg,
125 );
126
127 Ok(())
128}
129
130#[allow(clippy::too_many_arguments)]
134pub fn elementwise_add(
135 encoder: &mut CommandEncoder,
136 registry: &mut KernelRegistry,
137 device: &metal::DeviceRef,
138 a: &MlxBuffer,
139 b: &MlxBuffer,
140 output: &MlxBuffer,
141 n_elements: usize,
142 dtype: DType,
143) -> Result<()> {
144 elementwise_binary(encoder, registry, device, a, b, output, n_elements, "add", dtype)
145}
146
147#[allow(clippy::too_many_arguments)]
151pub fn elementwise_mul(
152 encoder: &mut CommandEncoder,
153 registry: &mut KernelRegistry,
154 device: &metal::DeviceRef,
155 a: &MlxBuffer,
156 b: &MlxBuffer,
157 output: &MlxBuffer,
158 n_elements: usize,
159 dtype: DType,
160) -> Result<()> {
161 elementwise_binary(encoder, registry, device, a, b, output, n_elements, "mul", dtype)
162}
163
/// Kernel argument block for the scalar-multiply kernels. `#[repr(C)]` so
/// the byte layout matches the Metal-side struct — keep field order in sync
/// with the shader.
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct GpuScalarMulParams {
    // Multiplier applied to every element (always passed as f32, even for
    // the bf16 kernel).
    scalar: f32,
    // Number of elements to process.
    count: u32,
}
173
174pub fn scalar_mul_bf16(
190 encoder: &mut CommandEncoder,
191 registry: &mut KernelRegistry,
192 device: &metal::DeviceRef,
193 input: &MlxBuffer,
194 output: &MlxBuffer,
195 n_elements: usize,
196 scalar: f32,
197) -> Result<()> {
198 if n_elements == 0 {
199 return Err(MlxError::InvalidArgument(
200 "scalar_mul_bf16: n_elements must be > 0".into(),
201 ));
202 }
203
204 let elem_bytes = n_elements * DType::BF16.size_of();
205 if input.byte_len() < elem_bytes {
206 return Err(MlxError::InvalidArgument(format!(
207 "scalar_mul_bf16: input buffer too small: need {} bytes, have {}",
208 elem_bytes,
209 input.byte_len()
210 )));
211 }
212 if output.byte_len() < elem_bytes {
213 return Err(MlxError::InvalidArgument(format!(
214 "scalar_mul_bf16: output buffer too small: need {} bytes, have {}",
215 elem_bytes,
216 output.byte_len()
217 )));
218 }
219
220 let pipeline = registry.get_pipeline("scalar_mul_bf16", device)?;
221
222 let gpu_params = GpuScalarMulParams {
223 scalar,
224 count: n_elements as u32,
225 };
226
227 let grid = MTLSize::new(n_elements as u64, 1, 1);
228 let tg = MTLSize::new(std::cmp::min(ELEMENTWISE_TG_SIZE, n_elements as u64), 1, 1);
229
230 encode_with_args(
231 encoder,
232 pipeline,
233 &[
234 (0, KernelArg::Buffer(input)),
235 (1, KernelArg::Buffer(output)),
236 (2, KernelArg::Bytes(as_bytes(&gpu_params))),
237 ],
238 grid,
239 tg,
240 );
241
242 Ok(())
243}
244
245pub fn scalar_mul_f32(
261 encoder: &mut CommandEncoder,
262 registry: &mut KernelRegistry,
263 device: &metal::DeviceRef,
264 input: &MlxBuffer,
265 output: &MlxBuffer,
266 n_elements: usize,
267 scalar: f32,
268) -> Result<()> {
269 if n_elements == 0 {
270 return Err(MlxError::InvalidArgument(
271 "scalar_mul_f32: n_elements must be > 0".into(),
272 ));
273 }
274
275 let elem_bytes = n_elements * DType::F32.size_of();
276 if input.byte_len() < elem_bytes {
277 return Err(MlxError::InvalidArgument(format!(
278 "scalar_mul_f32: input buffer too small: need {} bytes, have {}",
279 elem_bytes,
280 input.byte_len()
281 )));
282 }
283 if output.byte_len() < elem_bytes {
284 return Err(MlxError::InvalidArgument(format!(
285 "scalar_mul_f32: output buffer too small: need {} bytes, have {}",
286 elem_bytes,
287 output.byte_len()
288 )));
289 }
290
291 let pipeline = registry.get_pipeline("scalar_mul_f32", device)?;
292
293 let gpu_params = GpuScalarMulParams {
294 scalar,
295 count: n_elements as u32,
296 };
297
298 let grid = MTLSize::new(n_elements as u64, 1, 1);
299 let tg = MTLSize::new(std::cmp::min(ELEMENTWISE_TG_SIZE, n_elements as u64), 1, 1);
300
301 encode_with_args(
302 encoder,
303 pipeline,
304 &[
305 (0, KernelArg::Buffer(input)),
306 (1, KernelArg::Buffer(output)),
307 (2, KernelArg::Bytes(as_bytes(&gpu_params))),
308 ],
309 grid,
310 tg,
311 );
312
313 Ok(())
314}
315
/// Kernel argument block for the embedding gather+scale kernel. `#[repr(C)]`
/// so the byte layout matches the Metal-side struct — keep field order in
/// sync with the shader.
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct GpuEmbedGatherScaleParams {
    // Multiplier applied to each gathered value.
    scale: f32,
    // Width of one embedding row in elements.
    hidden_size: u32,
    // Index of the row to gather.
    token_id: u32,
}
324
325pub fn embedding_gather_scale_f32(
338 encoder: &mut CommandEncoder,
339 registry: &mut KernelRegistry,
340 device: &metal::DeviceRef,
341 embed_table: &MlxBuffer,
342 output: &MlxBuffer,
343 token_id: u32,
344 hidden_size: usize,
345 scale: f32,
346) -> Result<()> {
347 if hidden_size == 0 {
348 return Err(MlxError::InvalidArgument(
349 "embedding_gather_scale_f32: hidden_size must be > 0".into(),
350 ));
351 }
352 let out_bytes = hidden_size * std::mem::size_of::<f32>();
353 if output.byte_len() < out_bytes {
354 return Err(MlxError::InvalidArgument(format!(
355 "embedding_gather_scale_f32: output too small: need {} bytes, have {}",
356 out_bytes, output.byte_len()
357 )));
358 }
359
360 let pipeline = registry.get_pipeline("embedding_gather_scale_f32", device)?;
361
362 let gpu_params = GpuEmbedGatherScaleParams {
363 scale,
364 hidden_size: hidden_size as u32,
365 token_id,
366 };
367
368 let grid = MTLSize::new(hidden_size as u64, 1, 1);
369 let tg = MTLSize::new(std::cmp::min(ELEMENTWISE_TG_SIZE, hidden_size as u64), 1, 1);
370
371 encode_with_args(
372 encoder,
373 pipeline,
374 &[
375 (0, KernelArg::Buffer(embed_table)),
376 (1, KernelArg::Buffer(output)),
377 (2, KernelArg::Bytes(as_bytes(&gpu_params))),
378 ],
379 grid,
380 tg,
381 );
382
383 Ok(())
384}
385
386pub fn dispatch_cast_f32_to_bf16_with_encoder(
406 encoder: &mut CommandEncoder,
407 registry: &mut KernelRegistry,
408 device: &metal::DeviceRef,
409 input: &MlxBuffer,
410 output: &MlxBuffer,
411 n_elements: u32,
412) -> Result<()> {
413 cast(
414 encoder,
415 registry,
416 device,
417 input,
418 output,
419 n_elements as usize,
420 CastDirection::F32ToBF16,
421 )
422}
423
424pub fn dispatch_cast_bf16_to_f32_with_encoder(
444 encoder: &mut CommandEncoder,
445 registry: &mut KernelRegistry,
446 device: &metal::DeviceRef,
447 input: &MlxBuffer,
448 output: &MlxBuffer,
449 n_elements: u32,
450) -> Result<()> {
451 cast(
452 encoder,
453 registry,
454 device,
455 input,
456 output,
457 n_elements as usize,
458 CastDirection::BF16ToF32,
459 )
460}
461
462pub fn dispatch_scalar_mul_bf16_with_encoder(
483 encoder: &mut CommandEncoder,
484 registry: &mut KernelRegistry,
485 device: &metal::DeviceRef,
486 input: &MlxBuffer,
487 output: &MlxBuffer,
488 n_elements: u32,
489 scalar: f32,
490) -> Result<()> {
491 scalar_mul_bf16(
492 encoder,
493 registry,
494 device,
495 input,
496 output,
497 n_elements as usize,
498 scalar,
499 )
500}
501
/// The dtype conversions supported by the [`cast`] dispatcher.
pub enum CastDirection {
    /// f16 -> f32 (widening).
    F16ToF32,
    /// f32 -> f16 (narrowing).
    F32ToF16,
    /// bf16 -> f32 (widening).
    BF16ToF32,
    /// f32 -> bf16 (narrowing).
    F32ToBF16,
}

impl CastDirection {
    /// Name of the compiled Metal kernel implementing this conversion.
    fn kernel_name(&self) -> &'static str {
        match self {
            Self::F16ToF32 => "cast_f16_to_f32",
            Self::F32ToF16 => "cast_f32_to_f16",
            Self::BF16ToF32 => "cast_bf16_to_f32",
            Self::F32ToBF16 => "cast_f32_to_bf16",
        }
    }

    /// Byte width of one input element (2 for f16/bf16 sources, 4 for f32).
    fn input_elem_size(&self) -> usize {
        match self {
            Self::F32ToF16 | Self::F32ToBF16 => 4,
            Self::F16ToF32 | Self::BF16ToF32 => 2,
        }
    }

    /// Byte width of one output element. Every supported cast converts
    /// between a 2-byte and a 4-byte type, so the output width is whichever
    /// of {2, 4} the input width is not.
    fn output_elem_size(&self) -> usize {
        if self.input_elem_size() == 2 { 4 } else { 2 }
    }
}
538
539pub fn cast(
546 encoder: &mut CommandEncoder,
547 registry: &mut KernelRegistry,
548 device: &metal::DeviceRef,
549 input: &MlxBuffer,
550 output: &MlxBuffer,
551 n_elements: usize,
552 direction: CastDirection,
553) -> Result<()> {
554 if n_elements == 0 {
555 return Err(MlxError::InvalidArgument(
556 "cast: n_elements must be > 0".into(),
557 ));
558 }
559
560 let input_bytes = n_elements * direction.input_elem_size();
561 if input.byte_len() < input_bytes {
562 return Err(MlxError::InvalidArgument(format!(
563 "cast: input buffer too small: need {} bytes, have {}",
564 input_bytes,
565 input.byte_len()
566 )));
567 }
568
569 let output_bytes = n_elements * direction.output_elem_size();
570 if output.byte_len() < output_bytes {
571 return Err(MlxError::InvalidArgument(format!(
572 "cast: output buffer too small: need {} bytes, have {}",
573 output_bytes,
574 output.byte_len()
575 )));
576 }
577
578 let pipeline = registry.get_pipeline(direction.kernel_name(), device)?;
579
580 let gpu_params = GpuElementwiseParams {
581 n_elements: n_elements as u32,
582 };
583
584 let grid = MTLSize::new(n_elements as u64, 1, 1);
585 let tg = MTLSize::new(std::cmp::min(ELEMENTWISE_TG_SIZE, n_elements as u64), 1, 1);
586
587 encode_with_args(
588 encoder,
589 pipeline,
590 &[
591 (0, KernelArg::Buffer(input)),
592 (1, KernelArg::Buffer(output)),
593 (2, KernelArg::Bytes(as_bytes(&gpu_params))),
594 ],
595 grid,
596 tg,
597 );
598
599 Ok(())
600}