mlx_native/ops/
softcap.rs1use metal::MTLSize;
9
10use crate::buffer::MlxBuffer;
11use crate::dtypes::DType;
12use crate::encoder::CommandEncoder;
13use crate::error::{MlxError, Result};
14use crate::kernel_registry::KernelRegistry;
15
16pub static SOFTCAP_SHADER_SOURCE: &str = include_str!("../shaders/softcap.metal");
18
19pub fn register(registry: &mut KernelRegistry) {
21 registry.register_source("softcap_f32", SOFTCAP_SHADER_SOURCE);
22 registry.register_source("softcap_f16", SOFTCAP_SHADER_SOURCE);
23 registry.register_source("softcap_bf16", SOFTCAP_SHADER_SOURCE);
24}
25
26pub fn dispatch_softcap(
45 encoder: &mut CommandEncoder,
46 registry: &mut KernelRegistry,
47 device: &metal::DeviceRef,
48 input: &MlxBuffer,
49 output: &MlxBuffer,
50 params_buf: &MlxBuffer,
51 cap: f32,
52) -> Result<()> {
53 if cap <= 0.0 {
54 return Err(MlxError::InvalidArgument(format!(
55 "Softcap cap must be positive, got {}",
56 cap
57 )));
58 }
59
60 let n = input.element_count();
61 if n == 0 {
62 return Err(MlxError::InvalidArgument(
63 "Softcap input must have at least one element".into(),
64 ));
65 }
66 if output.element_count() != n {
67 return Err(MlxError::InvalidArgument(format!(
68 "Softcap output element count {} != input element count {}",
69 output.element_count(),
70 n
71 )));
72 }
73
74 let _ = cap; let kernel_name = match input.dtype() {
77 DType::F32 => "softcap_f32",
78 DType::F16 => "softcap_f16",
79 DType::BF16 => "softcap_bf16",
80 _ => {
81 return Err(MlxError::InvalidArgument(format!(
82 "Softcap unsupported dtype: {}",
83 input.dtype()
84 )));
85 }
86 };
87
88 let pipeline = registry.get_pipeline(kernel_name, device)?;
89 let threadgroup_size: u64 = std::cmp::min(256, n as u64);
90 let threadgroup_count = (n as u64 + threadgroup_size - 1) / threadgroup_size;
91
92 encoder.encode_threadgroups(
93 pipeline,
94 &[(0, input), (1, output), (2, params_buf)],
95 MTLSize::new(threadgroup_count, 1, 1),
96 MTLSize::new(threadgroup_size, 1, 1),
97 );
98
99 Ok(())
100}