1use metal::MTLSize;
9
10use crate::buffer::MlxBuffer;
11use crate::dtypes::DType;
12use crate::encoder::CommandEncoder;
13use crate::error::{MlxError, Result};
14use crate::kernel_registry::KernelRegistry;
15
/// Text of `../shaders/gelu.metal`, embedded into the binary at compile time.
///
/// [`register`] associates this same source string with every GELU kernel name.
pub static GELU_SHADER_SOURCE: &str = include_str!("../shaders/gelu.metal");
18
19pub fn register(registry: &mut KernelRegistry) {
23 registry.register_source("gelu_f32", GELU_SHADER_SOURCE);
24 registry.register_source("gelu_f16", GELU_SHADER_SOURCE);
25 registry.register_source("gelu_bf16", GELU_SHADER_SOURCE);
26}
27
28pub fn dispatch_gelu(
44 encoder: &mut CommandEncoder,
45 registry: &mut KernelRegistry,
46 device: &metal::DeviceRef,
47 input: &MlxBuffer,
48 output: &MlxBuffer,
49) -> Result<()> {
50 let n = input.element_count();
51 if n == 0 {
52 return Err(MlxError::InvalidArgument(
53 "GELU input must have at least one element".into(),
54 ));
55 }
56 if output.element_count() != n {
57 return Err(MlxError::InvalidArgument(format!(
58 "GELU output element count {} != input element count {}",
59 output.element_count(),
60 n
61 )));
62 }
63
64 let kernel_name = match input.dtype() {
65 DType::F32 => "gelu_f32",
66 DType::F16 => "gelu_f16",
67 DType::BF16 => "gelu_bf16",
68 _ => {
69 return Err(MlxError::InvalidArgument(format!(
70 "GELU unsupported dtype: {}",
71 input.dtype()
72 )));
73 }
74 };
75
76 let pipeline = registry.get_pipeline(kernel_name, device)?;
77 let thread_count = n as u64;
78 let threadgroup_size = std::cmp::min(256, thread_count);
79
80 encoder.encode(
81 pipeline,
82 &[(0, input), (1, output)],
83 MTLSize::new(thread_count, 1, 1),
84 MTLSize::new(threadgroup_size, 1, 1),
85 );
86
87 Ok(())
88}