1use metal::MTLSize;
7
8use crate::buffer::MlxBuffer;
9use crate::dtypes::DType;
10use crate::encoder::{CapturedOpKind, CommandEncoder};
11use crate::error::{MlxError, Result};
12use crate::kernel_registry::KernelRegistry;
13
14use super::encode_helpers::{as_bytes, encode_with_args, KernelArg};
15
/// Parameter block passed by value to 1-D elementwise and cast kernels.
/// `#[repr(C)]` + `Pod`/`Zeroable` so it can be handed to the GPU as raw
/// bytes; field layout must match the Metal-side struct.
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct GpuElementwiseParams {
    // Total element count the dispatch covers (kernel-side bounds guard).
    n_elements: u32,
}
24
/// Maximum threads per threadgroup for the 1-D elementwise dispatches below.
const ELEMENTWISE_TG_SIZE: u64 = 256;
27
28fn elementwise_kernel_name(op: &str, dtype: DType) -> Result<&'static str> {
30 match (op, dtype) {
31 ("add", DType::F32) => Ok("elementwise_add_f32"),
32 ("add", DType::F16) => Ok("elementwise_add_f16"),
33 ("add", DType::BF16) => Ok("elementwise_add_bf16"),
34 ("mul", DType::F32) => Ok("elementwise_mul_f32"),
35 ("mul", DType::F16) => Ok("elementwise_mul_f16"),
36 ("mul", DType::BF16) => Ok("elementwise_mul_bf16"),
37 _ => Err(MlxError::InvalidArgument(format!(
38 "elementwise_{op}: unsupported dtype {dtype}"
39 ))),
40 }
41}
42
43#[allow(clippy::too_many_arguments)]
55fn elementwise_binary(
56 encoder: &mut CommandEncoder,
57 registry: &mut KernelRegistry,
58 device: &metal::DeviceRef,
59 a: &MlxBuffer,
60 b: &MlxBuffer,
61 output: &MlxBuffer,
62 n_elements: usize,
63 op: &str,
64 dtype: DType,
65) -> Result<()> {
66 if n_elements == 0 {
67 return Err(MlxError::InvalidArgument(format!(
68 "elementwise_{op}: n_elements must be > 0"
69 )));
70 }
71
72 let elem_bytes = n_elements * dtype.size_of();
73 if a.byte_len() < elem_bytes {
74 return Err(MlxError::InvalidArgument(format!(
75 "elementwise_{op}: input 'a' buffer too small: need {} bytes, have {}",
76 elem_bytes,
77 a.byte_len()
78 )));
79 }
80 if b.byte_len() < elem_bytes {
81 return Err(MlxError::InvalidArgument(format!(
82 "elementwise_{op}: input 'b' buffer too small: need {} bytes, have {}",
83 elem_bytes,
84 b.byte_len()
85 )));
86 }
87 if output.byte_len() < elem_bytes {
88 return Err(MlxError::InvalidArgument(format!(
89 "elementwise_{op}: output buffer too small: need {} bytes, have {}",
90 elem_bytes,
91 output.byte_len()
92 )));
93 }
94
95 let kernel_name = elementwise_kernel_name(op, dtype)?;
96 let pipeline = registry.get_pipeline(kernel_name, device)?;
97
98 let gpu_params = GpuElementwiseParams {
99 n_elements: n_elements as u32,
100 };
101
102 let grid = MTLSize::new(n_elements as u64, 1, 1);
103 let tg = MTLSize::new(std::cmp::min(ELEMENTWISE_TG_SIZE, n_elements as u64), 1, 1);
104
105 let op_tag = match op {
108 "mul" => CapturedOpKind::ElemMul,
109 "add" => CapturedOpKind::ElemAdd,
110 _ => CapturedOpKind::Other,
111 };
112 encoder.set_op_kind(op_tag);
113
114 encode_with_args(
115 encoder,
116 pipeline,
117 &[
118 (0, KernelArg::Buffer(a)),
119 (1, KernelArg::Buffer(b)),
120 (2, KernelArg::Buffer(output)),
121 (3, KernelArg::Bytes(as_bytes(&gpu_params))),
122 ],
123 grid,
124 tg,
125 );
126
127 Ok(())
128}
129
130#[allow(clippy::too_many_arguments)]
134pub fn elementwise_add(
135 encoder: &mut CommandEncoder,
136 registry: &mut KernelRegistry,
137 device: &metal::DeviceRef,
138 a: &MlxBuffer,
139 b: &MlxBuffer,
140 output: &MlxBuffer,
141 n_elements: usize,
142 dtype: DType,
143) -> Result<()> {
144 elementwise_binary(encoder, registry, device, a, b, output, n_elements, "add", dtype)
145}
146
147#[allow(clippy::too_many_arguments)]
151pub fn elementwise_mul(
152 encoder: &mut CommandEncoder,
153 registry: &mut KernelRegistry,
154 device: &metal::DeviceRef,
155 a: &MlxBuffer,
156 b: &MlxBuffer,
157 output: &MlxBuffer,
158 n_elements: usize,
159 dtype: DType,
160) -> Result<()> {
161 elementwise_binary(encoder, registry, device, a, b, output, n_elements, "mul", dtype)
162}
163
/// Parameter block for the scalar-multiply kernels (buffer slot 2).
/// `#[repr(C)]` + `Pod`/`Zeroable` so it can be passed to the GPU as raw
/// bytes; field order must match the Metal-side struct.
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct GpuScalarMulParams {
    // Multiplier applied to every element (f32 even for bf16 data).
    scalar: f32,
    // Number of elements to process.
    count: u32,
}
173
174pub fn scalar_mul_bf16(
190 encoder: &mut CommandEncoder,
191 registry: &mut KernelRegistry,
192 device: &metal::DeviceRef,
193 input: &MlxBuffer,
194 output: &MlxBuffer,
195 n_elements: usize,
196 scalar: f32,
197) -> Result<()> {
198 if n_elements == 0 {
199 return Err(MlxError::InvalidArgument(
200 "scalar_mul_bf16: n_elements must be > 0".into(),
201 ));
202 }
203
204 let elem_bytes = n_elements * DType::BF16.size_of();
205 if input.byte_len() < elem_bytes {
206 return Err(MlxError::InvalidArgument(format!(
207 "scalar_mul_bf16: input buffer too small: need {} bytes, have {}",
208 elem_bytes,
209 input.byte_len()
210 )));
211 }
212 if output.byte_len() < elem_bytes {
213 return Err(MlxError::InvalidArgument(format!(
214 "scalar_mul_bf16: output buffer too small: need {} bytes, have {}",
215 elem_bytes,
216 output.byte_len()
217 )));
218 }
219
220 let pipeline = registry.get_pipeline("scalar_mul_bf16", device)?;
221
222 let gpu_params = GpuScalarMulParams {
223 scalar,
224 count: n_elements as u32,
225 };
226
227 let grid = MTLSize::new(n_elements as u64, 1, 1);
228 let tg = MTLSize::new(std::cmp::min(ELEMENTWISE_TG_SIZE, n_elements as u64), 1, 1);
229
230 encode_with_args(
231 encoder,
232 pipeline,
233 &[
234 (0, KernelArg::Buffer(input)),
235 (1, KernelArg::Buffer(output)),
236 (2, KernelArg::Bytes(as_bytes(&gpu_params))),
237 ],
238 grid,
239 tg,
240 );
241
242 Ok(())
243}
244
245pub fn scalar_mul_f32(
261 encoder: &mut CommandEncoder,
262 registry: &mut KernelRegistry,
263 device: &metal::DeviceRef,
264 input: &MlxBuffer,
265 output: &MlxBuffer,
266 n_elements: usize,
267 scalar: f32,
268) -> Result<()> {
269 if n_elements == 0 {
270 return Err(MlxError::InvalidArgument(
271 "scalar_mul_f32: n_elements must be > 0".into(),
272 ));
273 }
274
275 let elem_bytes = n_elements * DType::F32.size_of();
276 if input.byte_len() < elem_bytes {
277 return Err(MlxError::InvalidArgument(format!(
278 "scalar_mul_f32: input buffer too small: need {} bytes, have {}",
279 elem_bytes,
280 input.byte_len()
281 )));
282 }
283 if output.byte_len() < elem_bytes {
284 return Err(MlxError::InvalidArgument(format!(
285 "scalar_mul_f32: output buffer too small: need {} bytes, have {}",
286 elem_bytes,
287 output.byte_len()
288 )));
289 }
290
291 let pipeline = registry.get_pipeline("scalar_mul_f32", device)?;
292
293 let gpu_params = GpuScalarMulParams {
294 scalar,
295 count: n_elements as u32,
296 };
297
298 let grid = MTLSize::new(n_elements as u64, 1, 1);
299 let tg = MTLSize::new(std::cmp::min(ELEMENTWISE_TG_SIZE, n_elements as u64), 1, 1);
300
301 encode_with_args(
302 encoder,
303 pipeline,
304 &[
305 (0, KernelArg::Buffer(input)),
306 (1, KernelArg::Buffer(output)),
307 (2, KernelArg::Bytes(as_bytes(&gpu_params))),
308 ],
309 grid,
310 tg,
311 );
312
313 Ok(())
314}
315
/// Parameter block for the single-token embedding gather+scale kernel
/// (buffer slot 2). `#[repr(C)]` + `Pod`/`Zeroable` so it can be passed to
/// the GPU as raw bytes; field order must match the Metal-side struct.
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct GpuEmbedGatherScaleParams {
    // Multiplier applied to each gathered value.
    scale: f32,
    // Width of one embedding row.
    hidden_size: u32,
    // Row index to gather from the embedding table.
    token_id: u32,
}
324
325pub fn embedding_gather_scale_f32(
338 encoder: &mut CommandEncoder,
339 registry: &mut KernelRegistry,
340 device: &metal::DeviceRef,
341 embed_table: &MlxBuffer,
342 output: &MlxBuffer,
343 token_id: u32,
344 hidden_size: usize,
345 scale: f32,
346) -> Result<()> {
347 if hidden_size == 0 {
348 return Err(MlxError::InvalidArgument(
349 "embedding_gather_scale_f32: hidden_size must be > 0".into(),
350 ));
351 }
352 let out_bytes = hidden_size * std::mem::size_of::<f32>();
353 if output.byte_len() < out_bytes {
354 return Err(MlxError::InvalidArgument(format!(
355 "embedding_gather_scale_f32: output too small: need {} bytes, have {}",
356 out_bytes, output.byte_len()
357 )));
358 }
359
360 let pipeline = registry.get_pipeline("embedding_gather_scale_f32", device)?;
361
362 let gpu_params = GpuEmbedGatherScaleParams {
363 scale,
364 hidden_size: hidden_size as u32,
365 token_id,
366 };
367
368 let grid = MTLSize::new(hidden_size as u64, 1, 1);
369 let tg = MTLSize::new(std::cmp::min(ELEMENTWISE_TG_SIZE, hidden_size as u64), 1, 1);
370
371 encode_with_args(
372 encoder,
373 pipeline,
374 &[
375 (0, KernelArg::Buffer(embed_table)),
376 (1, KernelArg::Buffer(output)),
377 (2, KernelArg::Bytes(as_bytes(&gpu_params))),
378 ],
379 grid,
380 tg,
381 );
382
383 Ok(())
384}
385
/// Parameter block for the batched embedding gather+scale kernel
/// (buffer slot 3). `#[repr(C)]` + `Pod`/`Zeroable` so it can be passed to
/// the GPU as raw bytes; field order must match the Metal-side struct.
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct GpuEmbedGatherScaleBatchParams {
    // Multiplier applied to each gathered value.
    scale: f32,
    // Width of one embedding row.
    hidden_size: u32,
    // Number of token rows to gather (one grid row per token).
    n_tokens: u32,
}
395
396#[allow(clippy::too_many_arguments)]
406pub fn embedding_gather_scale_batch_f32(
407 encoder: &mut CommandEncoder,
408 registry: &mut KernelRegistry,
409 device: &metal::DeviceRef,
410 embed_table: &MlxBuffer,
411 token_ids: &MlxBuffer,
412 output: &MlxBuffer,
413 hidden_size: usize,
414 n_tokens: usize,
415 scale: f32,
416) -> Result<()> {
417 if hidden_size == 0 || n_tokens == 0 {
418 return Err(MlxError::InvalidArgument(
419 "embedding_gather_scale_batch_f32: hidden_size and n_tokens must be > 0".into(),
420 ));
421 }
422 let out_bytes = n_tokens * hidden_size * std::mem::size_of::<f32>();
423 if output.byte_len() < out_bytes {
424 return Err(MlxError::InvalidArgument(format!(
425 "embedding_gather_scale_batch_f32: output too small: need {} bytes, have {}",
426 out_bytes, output.byte_len()
427 )));
428 }
429 let ids_bytes = n_tokens * std::mem::size_of::<u32>();
430 if token_ids.byte_len() < ids_bytes {
431 return Err(MlxError::InvalidArgument(format!(
432 "embedding_gather_scale_batch_f32: token_ids too small: need {} bytes, have {}",
433 ids_bytes, token_ids.byte_len()
434 )));
435 }
436
437 let pipeline = registry.get_pipeline("embedding_gather_scale_batch_f32", device)?;
438
439 let gpu_params = GpuEmbedGatherScaleBatchParams {
440 scale,
441 hidden_size: hidden_size as u32,
442 n_tokens: n_tokens as u32,
443 };
444
445 let grid = MTLSize::new(hidden_size as u64, n_tokens as u64, 1);
446 let tg = MTLSize::new(
447 std::cmp::min(ELEMENTWISE_TG_SIZE, hidden_size as u64),
448 1, 1,
449 );
450
451 encode_with_args(
452 encoder,
453 pipeline,
454 &[
455 (0, KernelArg::Buffer(embed_table)),
456 (1, KernelArg::Buffer(token_ids)),
457 (2, KernelArg::Buffer(output)),
458 (3, KernelArg::Bytes(as_bytes(&gpu_params))),
459 ],
460 grid,
461 tg,
462 );
463
464 Ok(())
465}
466
467pub fn dispatch_cast_f32_to_bf16_with_encoder(
487 encoder: &mut CommandEncoder,
488 registry: &mut KernelRegistry,
489 device: &metal::DeviceRef,
490 input: &MlxBuffer,
491 output: &MlxBuffer,
492 n_elements: u32,
493) -> Result<()> {
494 cast(
495 encoder,
496 registry,
497 device,
498 input,
499 output,
500 n_elements as usize,
501 CastDirection::F32ToBF16,
502 )
503}
504
505pub fn dispatch_cast_bf16_to_f32_with_encoder(
525 encoder: &mut CommandEncoder,
526 registry: &mut KernelRegistry,
527 device: &metal::DeviceRef,
528 input: &MlxBuffer,
529 output: &MlxBuffer,
530 n_elements: u32,
531) -> Result<()> {
532 cast(
533 encoder,
534 registry,
535 device,
536 input,
537 output,
538 n_elements as usize,
539 CastDirection::BF16ToF32,
540 )
541}
542
543pub fn dispatch_scalar_mul_bf16_with_encoder(
564 encoder: &mut CommandEncoder,
565 registry: &mut KernelRegistry,
566 device: &metal::DeviceRef,
567 input: &MlxBuffer,
568 output: &MlxBuffer,
569 n_elements: u32,
570 scalar: f32,
571) -> Result<()> {
572 scalar_mul_bf16(
573 encoder,
574 registry,
575 device,
576 input,
577 output,
578 n_elements as usize,
579 scalar,
580 )
581}
582
/// Direction of an elementwise dtype-cast kernel.
pub enum CastDirection {
    /// Half precision to single precision.
    F16ToF32,
    /// Single precision to half precision.
    F32ToF16,
    /// bfloat16 to single precision.
    BF16ToF32,
    /// Single precision to bfloat16.
    F32ToBF16,
}

impl CastDirection {
    /// Registry name of the Metal kernel implementing this cast.
    fn kernel_name(&self) -> &'static str {
        match self {
            Self::F16ToF32 => "cast_f16_to_f32",
            Self::F32ToF16 => "cast_f32_to_f16",
            Self::BF16ToF32 => "cast_bf16_to_f32",
            Self::F32ToBF16 => "cast_f32_to_bf16",
        }
    }

    /// `(source, destination)` element widths in bytes.
    fn elem_widths(&self) -> (usize, usize) {
        match self {
            Self::F16ToF32 | Self::BF16ToF32 => (2, 4),
            Self::F32ToF16 | Self::F32ToBF16 => (4, 2),
        }
    }

    /// Size in bytes of one input element.
    fn input_elem_size(&self) -> usize {
        self.elem_widths().0
    }

    /// Size in bytes of one output element.
    fn output_elem_size(&self) -> usize {
        self.elem_widths().1
    }
}
619
620pub fn cast(
627 encoder: &mut CommandEncoder,
628 registry: &mut KernelRegistry,
629 device: &metal::DeviceRef,
630 input: &MlxBuffer,
631 output: &MlxBuffer,
632 n_elements: usize,
633 direction: CastDirection,
634) -> Result<()> {
635 if n_elements == 0 {
636 return Err(MlxError::InvalidArgument(
637 "cast: n_elements must be > 0".into(),
638 ));
639 }
640
641 let input_bytes = n_elements * direction.input_elem_size();
642 if input.byte_len() < input_bytes {
643 return Err(MlxError::InvalidArgument(format!(
644 "cast: input buffer too small: need {} bytes, have {}",
645 input_bytes,
646 input.byte_len()
647 )));
648 }
649
650 let output_bytes = n_elements * direction.output_elem_size();
651 if output.byte_len() < output_bytes {
652 return Err(MlxError::InvalidArgument(format!(
653 "cast: output buffer too small: need {} bytes, have {}",
654 output_bytes,
655 output.byte_len()
656 )));
657 }
658
659 let pipeline = registry.get_pipeline(direction.kernel_name(), device)?;
660
661 let gpu_params = GpuElementwiseParams {
662 n_elements: n_elements as u32,
663 };
664
665 let grid = MTLSize::new(n_elements as u64, 1, 1);
666 let tg = MTLSize::new(std::cmp::min(ELEMENTWISE_TG_SIZE, n_elements as u64), 1, 1);
667
668 encode_with_args(
669 encoder,
670 pipeline,
671 &[
672 (0, KernelArg::Buffer(input)),
673 (1, KernelArg::Buffer(output)),
674 (2, KernelArg::Bytes(as_bytes(&gpu_params))),
675 ],
676 grid,
677 tg,
678 );
679
680 Ok(())
681}