Skip to main content

mlx_native/ops/
softcap.rs

1//! Softcap (tanh-based logit capping) GPU dispatch.
2//!
3//! Computes: `tanh(logits / cap) * cap`
4//!
5//! This bounds output logits to the range `(-cap, +cap)`.  Gemma 4 uses
6//! `cap = 30.0` for final logit capping.
7
8use metal::MTLSize;
9
10use crate::buffer::MlxBuffer;
11use crate::dtypes::DType;
12use crate::encoder::CommandEncoder;
13use crate::error::{MlxError, Result};
14use crate::kernel_registry::KernelRegistry;
15
16/// MSL source for the softcap kernels (embedded at compile time).
17pub static SOFTCAP_SHADER_SOURCE: &str = include_str!("../shaders/softcap.metal");
18
19/// Register softcap shader sources with the given kernel registry.
20pub fn register(registry: &mut KernelRegistry) {
21    registry.register_source("softcap_f32", SOFTCAP_SHADER_SOURCE);
22    registry.register_source("softcap_f16", SOFTCAP_SHADER_SOURCE);
23    registry.register_source("softcap_bf16", SOFTCAP_SHADER_SOURCE);
24}
25
26/// Dispatch a softcap operation on the GPU.
27///
28/// # Arguments
29///
30/// * `encoder`    - Command encoder to record the dispatch into.
31/// * `registry`   - Kernel registry (must have softcap sources registered).
32/// * `device`     - Metal device for pipeline compilation.
33/// * `input`      - Input buffer (f32, f16, or bf16).
34/// * `output`     - Output buffer (same dtype and shape as input).
35/// * `params_buf` - Params buffer containing `[cap, n_elements_as_f32_bits]` as two f32 values.
36/// * `cap`        - The capping value (e.g. 30.0).
37///
38/// # Errors
39///
40/// Returns `MlxError::InvalidArgument` if:
41/// - Input dtype is not f32, f16, or bf16.
42/// - Input and output element counts do not match.
43/// - Cap is not positive.
44pub fn dispatch_softcap(
45    encoder: &mut CommandEncoder,
46    registry: &mut KernelRegistry,
47    device: &metal::DeviceRef,
48    input: &MlxBuffer,
49    output: &MlxBuffer,
50    params_buf: &MlxBuffer,
51    cap: f32,
52) -> Result<()> {
53    if cap <= 0.0 {
54        return Err(MlxError::InvalidArgument(format!(
55            "Softcap cap must be positive, got {}",
56            cap
57        )));
58    }
59
60    let n = input.element_count();
61    if n == 0 {
62        return Err(MlxError::InvalidArgument(
63            "Softcap input must have at least one element".into(),
64        ));
65    }
66    if output.element_count() != n {
67        return Err(MlxError::InvalidArgument(format!(
68            "Softcap output element count {} != input element count {}",
69            output.element_count(),
70            n
71        )));
72    }
73
74    let _ = cap; // cap value is passed via params_buf
75
76    let kernel_name = match input.dtype() {
77        DType::F32 => "softcap_f32",
78        DType::F16 => "softcap_f16",
79        DType::BF16 => "softcap_bf16",
80        _ => {
81            return Err(MlxError::InvalidArgument(format!(
82                "Softcap unsupported dtype: {}",
83                input.dtype()
84            )));
85        }
86    };
87
88    let pipeline = registry.get_pipeline(kernel_name, device)?;
89    let threadgroup_size: u64 = std::cmp::min(256, n as u64);
90    let threadgroup_count = (n as u64 + threadgroup_size - 1) / threadgroup_size;
91
92    encoder.encode_threadgroups(
93        pipeline,
94        &[(0, input), (1, output), (2, params_buf)],
95        MTLSize::new(threadgroup_count, 1, 1),
96        MTLSize::new(threadgroup_size, 1, 1),
97    );
98
99    Ok(())
100}