Skip to main content

mlx_native/ops/
gelu.rs

//! GELU activation (pytorch_tanh variant) GPU dispatch.
//!
//! Computes: `0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`
//!
//! This is the exact variant used by Gemma 4. It is **not** the erf-based
//! GELU approximation.
7
8use metal::MTLSize;
9
10use crate::buffer::MlxBuffer;
11use crate::dtypes::DType;
12use crate::encoder::CommandEncoder;
13use crate::error::{MlxError, Result};
14use crate::kernel_registry::KernelRegistry;
15
/// MSL source for the GELU kernels (embedded at compile time).
///
/// A single `.metal` file expected to contain all three dtype entry points
/// (`gelu_f32`, `gelu_f16`, `gelu_bf16`) — [`register`] registers this same
/// source string under each of those kernel names.
pub static GELU_SHADER_SOURCE: &str = include_str!("../shaders/gelu.metal");
18
19/// Register GELU shader sources with the given kernel registry.
20///
21/// This must be called before dispatching any GELU operations.
22pub fn register(registry: &mut KernelRegistry) {
23    registry.register_source("gelu_f32", GELU_SHADER_SOURCE);
24    registry.register_source("gelu_f16", GELU_SHADER_SOURCE);
25    registry.register_source("gelu_bf16", GELU_SHADER_SOURCE);
26}
27
28/// Dispatch a GELU activation on the GPU.
29///
30/// # Arguments
31///
32/// * `encoder`  - Command encoder to record the dispatch into.
33/// * `registry` - Kernel registry (must have GELU sources registered).
34/// * `device`   - Metal device for pipeline compilation.
35/// * `input`    - Input buffer (f32, f16, or bf16).
36/// * `output`   - Output buffer (same dtype and shape as input).
37///
38/// # Errors
39///
40/// Returns `MlxError::InvalidArgument` if:
41/// - Input dtype is not f32, f16, or bf16.
42/// - Input and output element counts do not match.
43pub fn dispatch_gelu(
44    encoder: &mut CommandEncoder,
45    registry: &mut KernelRegistry,
46    device: &metal::DeviceRef,
47    input: &MlxBuffer,
48    output: &MlxBuffer,
49) -> Result<()> {
50    let n = input.element_count();
51    if n == 0 {
52        return Err(MlxError::InvalidArgument(
53            "GELU input must have at least one element".into(),
54        ));
55    }
56    if output.element_count() != n {
57        return Err(MlxError::InvalidArgument(format!(
58            "GELU output element count {} != input element count {}",
59            output.element_count(),
60            n
61        )));
62    }
63
64    let kernel_name = match input.dtype() {
65        DType::F32 => "gelu_f32",
66        DType::F16 => "gelu_f16",
67        DType::BF16 => "gelu_bf16",
68        _ => {
69            return Err(MlxError::InvalidArgument(format!(
70                "GELU unsupported dtype: {}",
71                input.dtype()
72            )));
73        }
74    };
75
76    let pipeline = registry.get_pipeline(kernel_name, device)?;
77    let thread_count = n as u64;
78    let threadgroup_size = std::cmp::min(256, thread_count);
79
80    encoder.encode(
81        pipeline,
82        &[(0, input), (1, output)],
83        MTLSize::new(thread_count, 1, 1),
84        MTLSize::new(threadgroup_size, 1, 1),
85    );
86
87    Ok(())
88}