Skip to main content

hanzo_nn/
ops.rs

1//! Tensor ops.
2//!
3
4use hanzo_ml::{CpuStorage, DType, Layout, Module, Result, Shape, Tensor, D};
5use rayon::prelude::*;
6
7/// Applies the softmax function to the input tensor, rescaling the element so that elements on
8/// a slice of fixed index on dimension `dim` are between 0 and 1 and sum to 1.
9///
10/// ```rust
11/// use hanzo_ml::{Tensor, Device, test_utils::to_vec2_round};
12/// let a = Tensor::new(&[[0f32, 1., 0., 1.], [-2., 2., 3., -3.]], &Device::Cpu)?;
13/// let a = hanzo_nn::ops::softmax(&a, 1)?;
14/// assert_eq!(
15///     to_vec2_round(&a, 4)?,
16///     &[
17///         [0.1345, 0.3655, 0.1345, 0.3655],
18///         [0.0049, 0.2671, 0.7262, 0.0018]
19///     ]);
20/// # Ok::<(), hanzo_ml::Error>(())
21/// ```
22pub fn softmax<D: hanzo_ml::shape::Dim>(xs: &Tensor, dim: D) -> Result<Tensor> {
23    let dim = dim.to_index(xs.shape(), "softmax")?;
24    let max = xs.max_keepdim(dim)?;
25    let diff = xs.broadcast_sub(&max)?;
26    let num = diff.exp()?;
27    let den = num.sum_keepdim(dim)?;
28    num.broadcast_div(&den)
29}
30
31pub fn log_softmax<D: hanzo_ml::shape::Dim>(xs: &Tensor, d: D) -> Result<Tensor> {
32    let d = d.to_index(xs.shape(), "log-softmax")?;
33    let max = xs.max_keepdim(d)?;
34    let diff = xs.broadcast_sub(&max)?;
35    let sum_exp = diff.exp()?.sum_keepdim(d)?;
36    let log_sm = diff.broadcast_sub(&sum_exp.log()?)?;
37    Ok(log_sm)
38}
39
40pub fn silu(xs: &Tensor) -> Result<Tensor> {
41    xs.silu()
42}
43
44pub fn swiglu(xs: &Tensor) -> Result<Tensor> {
45    let xs = xs.chunk(2, D::Minus1)?;
46    &xs[0].silu()? * &xs[1]
47}
48
/// Element-wise logistic sigmoid implemented as a custom op, with CPU, CUDA and Metal forward
/// passes and an analytic backward pass.
struct Sigmoid;

impl hanzo_ml::CustomOp1 for Sigmoid {
    /// Name of the op; also reported in the unsupported-dtype error of `cpu_fwd`.
    fn name(&self) -> &'static str {
        "sigmoid"
    }

    /// CPU forward pass: applies `1 / (1 + exp(-v))` to every element through `unary_map`.
    ///
    /// BF16, F16, F32 and F64 storages are handled; any other dtype returns
    /// `UnsupportedDTypeForOp`. The output shape is the shape of `layout`.
    fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)> {
        use hanzo_ml::backend::BackendStorage;

        // Scalar sigmoid, generic over the float type of the storage.
        fn fwd<T: num_traits::Float>(v: T) -> T {
            (v.neg().exp() + T::one()).recip()
        }

        // FIXME: using `hanzo_ml::map_dtype` causes compilation errors.
        let storage = match storage {
            CpuStorage::BF16(slice) => {
                CpuStorage::BF16(hanzo_ml::cpu_backend::unary_map(slice, layout, fwd))
            }
            CpuStorage::F16(slice) => {
                CpuStorage::F16(hanzo_ml::cpu_backend::unary_map(slice, layout, fwd))
            }
            CpuStorage::F32(slice) => {
                CpuStorage::F32(hanzo_ml::cpu_backend::unary_map(slice, layout, fwd))
            }
            CpuStorage::F64(slice) => {
                CpuStorage::F64(hanzo_ml::cpu_backend::unary_map(slice, layout, fwd))
            }
            _ => Err(hanzo_ml::Error::UnsupportedDTypeForOp(
                storage.dtype(),
                self.name(),
            ))?,
        };
        Ok((storage, layout.shape().clone()))
    }

    /// CUDA forward pass: launches the `usigmoid` kernel from the `UNARY` kernel module over
    /// all elements of the input and returns a freshly allocated output buffer.
    #[cfg(feature = "cuda")]
    fn cuda_fwd(
        &self,
        storage: &hanzo_ml::CudaStorage,
        layout: &Layout,
    ) -> Result<(hanzo_ml::CudaStorage, Shape)> {
        use hanzo_ml::backend::BackendStorage;
        use hanzo_ml::cuda_backend::cudarc::driver::{
            CudaSlice, DeviceRepr, LaunchConfig, PushKernelArg, ValidAsZeroBits,
        };
        use hanzo_ml::cuda_backend::SlicePtrOrNull;
        use hanzo_ml::cuda_backend::{kernel_name, kernels, Map1, WrapErr};
        use hanzo_ml::{CudaDevice, WithDType};

        // Local `Map1` implementor so the launch code is written once and dispatched per dtype
        // by `Map1::map`.
        struct S;
        impl Map1 for S {
            fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
                &self,
                src: &CudaSlice<T>,
                dev: &CudaDevice,
                layout: &Layout,
            ) -> Result<CudaSlice<T>> {
                let shape = layout.shape();
                let dims = shape.dims();
                let el_count = shape.elem_count();
                // NOTE(review): `el_count as u32` truncates for tensors with more than
                // u32::MAX elements — confirm this cannot happen in practice.
                let cfg = LaunchConfig::for_num_elems(el_count as u32);
                // Layout parameters handed to the kernel (presumably dims/strides, or a null
                // pointer for contiguous layouts — confirm against `SlicePtrOrNull`).
                let ds = SlicePtrOrNull::params_from_layout(dev, layout)?;
                // Skip to the first element addressed by the layout.
                let src = &src.slice(layout.start_offset()..);
                let func = dev.get_or_load_func(&kernel_name::<T>("usigmoid"), &kernels::UNARY)?;
                // SAFETY: Set later by running the kernel.
                let out = unsafe { dev.alloc::<T>(el_count)? };

                // Kernel arguments, in order: element count, number of dims, layout params,
                // source slice, destination slice.
                let mut builder = func.builder();
                hanzo_ml::builder_arg!(builder, el_count, dims.len());
                ds.builder_arg(&mut builder);
                builder.arg(src);
                builder.arg(&out);
                // SAFETY: ffi.
                unsafe { builder.launch(cfg) }.w()?;
                Ok(out)
            }
        }

        let dev = storage.device();
        let slice = S.map(&storage.slice, dev, layout)?;
        let dst = hanzo_ml::CudaStorage {
            slice,
            device: dev.clone(),
        };
        Ok((dst, layout.shape().clone()))
    }

    /// Metal forward pass: encodes either the contiguous or the strided unary sigmoid kernel,
    /// depending on `layout`, writing into a new buffer of `el_count` elements.
    ///
    /// Only F16, F32 and BF16 are supported; other dtypes bail with an error.
    #[cfg(feature = "metal")]
    fn metal_fwd(
        &self,
        storage: &hanzo_ml::MetalStorage,
        layout: &Layout,
    ) -> Result<(hanzo_ml::MetalStorage, Shape)> {
        use hanzo_ml::backend::BackendStorage;
        use hanzo_ml::MetalError;
        let device = storage.device();
        let dtype = storage.dtype();
        let shape = layout.shape();
        let el_count = shape.elem_count();
        let buffer = device.new_buffer(el_count, dtype, "sigmoid")?;
        let encoder = device.command_encoder()?;
        encoder.set_label("sigmoid");
        // The layout's start offset is in elements; the kernel API takes it in bytes.
        let src = hanzo_metal_kernels::BufferOffset {
            buffer: storage.buffer(),
            offset_in_bytes: layout.start_offset() * storage.dtype().size_in_bytes(),
        };

        if layout.is_contiguous() {
            use hanzo_metal_kernels::unary::contiguous;
            let kernel_name = match dtype {
                DType::F16 => contiguous::sigmoid::HALF,
                DType::F32 => contiguous::sigmoid::FLOAT,
                DType::BF16 => contiguous::sigmoid::BFLOAT,
                dtype => {
                    hanzo_ml::bail!("Metal contiguous unary sigmoid {dtype:?} not implemented")
                }
            };
            hanzo_metal_kernels::call_unary_contiguous(
                device.metal_device(),
                &encoder,
                device.kernels(),
                kernel_name,
                dtype.size_in_bytes(),
                el_count,
                src,
                &buffer,
            )
            .map_err(MetalError::from)?;
        } else {
            // Non-contiguous input: the strided kernel receives the dims and strides
            // explicitly and writes to the start of the output buffer.
            use hanzo_metal_kernels::unary::strided;
            let kernel_name = match dtype {
                DType::F16 => strided::sigmoid::HALF,
                DType::F32 => strided::sigmoid::FLOAT,
                DType::BF16 => strided::sigmoid::BFLOAT,
                dtype => {
                    hanzo_ml::bail!("Metal strided unary sigmoid {dtype:?} not implemented")
                }
            };
            let dst = hanzo_metal_kernels::BufferOffset::zero_offset(&buffer);
            hanzo_metal_kernels::call_unary_strided(
                device.metal_device(),
                &encoder,
                device.kernels(),
                kernel_name,
                layout.dims(),
                src,
                layout.stride(),
                dst,
            )
            .map_err(MetalError::from)?;
        }

        let new_storage = hanzo_ml::MetalStorage::new(buffer, device.clone(), el_count, dtype);
        Ok((new_storage, layout.shape().clone()))
    }

    /// Backward pass: computes the input gradient from the forward result `res` and the
    /// incoming gradient `grad_res`; the original argument is not needed.
    fn bwd(&self, _arg: &Tensor, res: &Tensor, grad_res: &Tensor) -> Result<Option<Tensor>> {
        // d/dx sigmoid(x) = (1 - sigmoid(x)) * sigmoid(x)
        let d_dx_sigmoid = res.ones_like()?.sub(res)?.mul(res)?;
        Ok(Some(grad_res.mul(&d_dx_sigmoid)?))
    }
}
212
213pub fn sigmoid(xs: &Tensor) -> Result<Tensor> {
214    xs.apply_op1(Sigmoid)
215}
216
217pub fn hard_sigmoid(xs: &Tensor) -> Result<Tensor> {
218    // TODO: Should we have a specialized op for this?
219    ((xs + 3.0)? / 6.0)?.clamp(0f32, 1f32)
220}
221
222pub fn mish(xs: &Tensor) -> Result<Tensor> {
223    xs * (1.0 + xs.exp()?)?.log()?.tanh()
224}
225
226pub fn leaky_relu(xs: &Tensor, negative_slope: f64) -> Result<Tensor> {
227    let zeros = xs.zeros_like()?;
228    xs.maximum(&zeros)? + xs.minimum(&zeros)? * negative_slope
229}
230
231pub fn selu(xs: &Tensor, alpha: f32, gamma: f32) -> Result<Tensor> {
232    let is_pos = xs.gt(0f32)?;
233    let alpha_t = Tensor::full(alpha, xs.dims(), xs.device())?;
234    let neg = xs.exp()?.mul(&alpha_t)?.sub(&alpha_t)?;
235    let selu = is_pos.where_cond(xs, &neg)?;
236    let gamma_t = Tensor::full(gamma, xs.dims(), xs.device())?;
237    selu.broadcast_mul(&gamma_t)
238}
239
240pub fn dropout(xs: &Tensor, drop_p: f32) -> Result<Tensor> {
241    // This implementation is inefficient as it stores the full mask for the backward pass.
242    // Instead we could just store the seed and have a specialized kernel that would both
243    // generate the random mask and apply it.
244    // Another easier optimization would be to be able to generate boolean mask using just a bit of
245    // entropy per element rather than generating a full float per element.
246    if !(0. ..1.).contains(&drop_p) {
247        hanzo_ml::bail!("dropout probability has to be in [0, 1), got {drop_p}")
248    }
249    let rand = Tensor::rand(0f32, 1f32, xs.shape(), xs.device())?;
250    let scale = 1.0 / (1.0 - drop_p as f64);
251    let drop_p = Tensor::new(drop_p, xs.device())?.broadcast_as(xs.shape())?;
252    let mask = (rand.ge(&drop_p)?.to_dtype(xs.dtype())? * scale)?;
253    xs * mask
254}
255
256#[derive(Clone, Debug)]
257pub struct Dropout {
258    drop_p: f32,
259}
260
261impl Dropout {
262    pub fn new(drop_p: f32) -> Dropout {
263        Self { drop_p }
264    }
265
266    pub fn forward(&self, xs: &Tensor, train: bool) -> Result<Tensor> {
267        if train {
268            dropout(xs, self.drop_p)
269        } else {
270            Ok(xs.clone())
271        }
272    }
273}
274
275impl hanzo_ml::ModuleT for Dropout {
276    fn forward_t(&self, xs: &Tensor, train: bool) -> Result<Tensor> {
277        self.forward(xs, train)
278    }
279}
280
/// Softmax over the last dimension of a contiguous tensor, implemented as a custom op with
/// CPU, Vulkan, CUDA and Metal forward passes.
struct SoftmaxLastDim;

impl hanzo_ml::CustomOp1 for SoftmaxLastDim {
    /// Name of the op.
    fn name(&self) -> &'static str {
        "softmax-last-dim"
    }

    /// CPU forward pass: processes the input row by row (a row being one slice of the last
    /// dimension) in parallel with rayon.
    ///
    /// Requires a contiguous layout and a BF16, F16, F32 or F64 storage; otherwise bails.
    fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)> {
        // Dtype-generic kernel; the match below picks the instantiation.
        fn softmax<T: hanzo_ml::WithDType + num_traits::Float>(
            src: &[T],
            layout: &Layout,
        ) -> Result<(CpuStorage, Shape)> {
            let src = match layout.contiguous_offsets() {
                None => hanzo_ml::bail!("input has to be contiguous"),
                Some((o1, o2)) => &src[o1..o2],
            };
            let el_count = layout.shape().elem_count();
            let dims = layout.shape().dims();
            // NOTE(review): `dims.len() - 1` underflows for a rank-0 shape, and a last
            // dimension of size 0 would make the chunk size below 0 — confirm callers never
            // pass such tensors.
            let dim_m1 = dims[dims.len() - 1];
            let mut dst = vec![T::zero(); el_count];
            src.par_chunks(dim_m1)
                .zip(dst.par_chunks_mut(dim_m1))
                .for_each(|(src, dst)| {
                    // Row maximum, used to shift the inputs before exponentiating.
                    // The reduction is given the row pointer and `dim_m1` as the length
                    // (presumably it reads exactly that many elements — confirm).
                    let mut max = T::neg_infinity();
                    unsafe { T::vec_reduce_max(src.as_ptr(), &mut max, dim_m1) };
                    for (s, d) in src.iter().zip(dst.iter_mut()) {
                        *d = (*s - max).exp();
                    }
                    // Normalize the exponentials by their sum.
                    let mut sum_exp = T::zero();
                    unsafe { T::vec_reduce_sum(dst.as_ptr(), &mut sum_exp, dim_m1) };
                    for d in dst.iter_mut() {
                        *d /= sum_exp
                    }
                });
            let storage = hanzo_ml::WithDType::to_cpu_storage_owned(dst);
            Ok((storage, Shape::from_dims(dims)))
        }

        match storage {
            CpuStorage::BF16(slice) => softmax::<half::bf16>(slice, layout),
            CpuStorage::F16(slice) => softmax::<half::f16>(slice, layout),
            CpuStorage::F32(slice) => softmax::<f32>(slice, layout),
            CpuStorage::F64(slice) => softmax::<f64>(slice, layout),
            _ => hanzo_ml::bail!("unsupported dtype for softmax {:?}", storage),
        }
    }

    /// Vulkan forward pass: delegates to the storage's `softmax_last_dim`.
    #[cfg(feature = "vulkan")]
    fn vulkan_fwd(
        &self,
        storage: &hanzo_ml::VulkanStorage,
        layout: &Layout,
    ) -> Result<(hanzo_ml::VulkanStorage, Shape)> {
        let out = storage.softmax_last_dim(layout)?;
        Ok((out, layout.shape().clone()))
    }

    /// CUDA forward pass: launches the `softmax` kernel from the `REDUCE` kernel module with
    /// one block per row. Requires a contiguous layout.
    #[cfg(feature = "cuda")]
    fn cuda_fwd(
        &self,
        storage: &hanzo_ml::CudaStorage,
        layout: &Layout,
    ) -> Result<(hanzo_ml::CudaStorage, Shape)> {
        use hanzo_ml::cuda_backend::cudarc::driver::{
            CudaSlice, DeviceRepr, LaunchConfig, PushKernelArg,
        };
        use hanzo_ml::cuda_backend::{kernel_name, kernels, Map1, WrapErr};
        use hanzo_ml::{CudaDevice, WithDType};

        // Local `Map1` implementor; `Map1::map` dispatches `f` on the storage dtype.
        struct S;
        impl Map1 for S {
            fn f<T: DeviceRepr + WithDType>(
                &self,
                src: &CudaSlice<T>,
                dev: &CudaDevice,
                layout: &Layout,
            ) -> Result<CudaSlice<T>> {
                let src = match layout.contiguous_offsets() {
                    None => hanzo_ml::bail!("input has to be contiguous"),
                    Some((o1, o2)) => src.slice(o1..o2),
                };
                let el = layout.shape().elem_count();
                let dims = layout.shape().dims();
                let dim_m1 = dims[dims.len() - 1];
                // Treat the tensor as a (n_rows, n_cols) matrix over the last dimension.
                let (n_rows, n_cols) = (el / dim_m1, dim_m1);

                let cfg = LaunchConfig {
                    grid_dim: (n_rows as u32, 1, 1),
                    block_dim: (1, 32, 1),
                    shared_mem_bytes: 0,
                };
                let func = dev.get_or_load_func(&kernel_name::<T>("softmax"), &kernels::REDUCE)?;
                // SAFETY: Set later by running the kernel.
                let dst = unsafe { dev.alloc::<T>(el)? };
                // Kernel arguments, in order: source, destination, number of columns.
                let mut builder = func.builder();
                builder.arg(&src);
                builder.arg(&dst);
                hanzo_ml::builder_arg!(builder, n_cols as i32);
                // SAFETY: ffi.
                unsafe { builder.launch(cfg) }.w()?;
                Ok(dst)
            }
        }

        use hanzo_ml::backend::BackendStorage;
        let dev = storage.device();
        let slice = S.map(&storage.slice, dev, layout)?;
        let dst = hanzo_ml::cuda_backend::CudaStorage {
            slice,
            device: dev.clone(),
        };
        Ok((dst, layout.shape().clone()))
    }

    /// Metal forward pass: encodes the `softmax_{f32,f16,bf16}` kernel over a contiguous
    /// input. Other dtypes and non-contiguous layouts bail with an error.
    #[cfg(feature = "metal")]
    fn metal_fwd(
        &self,
        storage: &hanzo_ml::MetalStorage,
        layout: &Layout,
    ) -> Result<(hanzo_ml::MetalStorage, Shape)> {
        use hanzo_ml::backend::BackendStorage;
        let device = storage.device();
        let encoder = device.command_encoder()?;
        encoder.set_label("softmax");
        let kernels = device.kernels();
        let name = match storage.dtype() {
            DType::F32 => "softmax_f32",
            DType::F16 => "softmax_f16",
            DType::BF16 => "softmax_bf16",
            dtype => hanzo_ml::bail!("softmax-last-dim is not implemented for {dtype:?}"),
        };

        // NOTE(review): `n - 1` underflows if the layout has no dimensions — confirm rank-0
        // tensors never reach this point.
        let n = layout.stride().len();
        if !(layout.is_contiguous() && layout.stride()[n - 1] == 1) {
            hanzo_ml::bail!("Non contiguous softmax-last-dim is not implemented");
        }

        let last_dim = layout.dims()[layout.shape().rank() - 1];
        let elem_count = layout.shape().elem_count();
        let output = device.new_buffer(elem_count, storage.dtype(), "softmax")?;
        // The source offset is converted from elements to bytes for the kernel call.
        hanzo_metal_kernels::call_last_softmax(
            device.metal_device(),
            &encoder,
            kernels,
            name,
            elem_count,
            last_dim,
            storage.buffer(),
            layout.start_offset() * storage.dtype().size_in_bytes(),
            &output,
        )
        .map_err(hanzo_ml::Error::wrap)?;
        let newstorage =
            hanzo_ml::MetalStorage::new(output, device.clone(), elem_count, storage.dtype());
        Ok((newstorage, layout.shape().clone()))
    }
}
438
439pub fn softmax_last_dim(xs: &Tensor) -> Result<Tensor> {
440    if xs.device().is_rocm() {
441        return softmax(xs, D::Minus1);
442    }
443    xs.apply_op1_no_bwd(&SoftmaxLastDim)
444}
445
/// RMS normalization over the last dimension, implemented as a two-input custom op
/// (input and per-column scale `alpha`) with CPU, Vulkan, CUDA and Metal forward passes.
#[derive(Debug, Clone)]
struct RmsNorm {
    /// Value added to the mean of squares before taking the square root.
    eps: f32,
}

impl hanzo_ml::CustomOp2 for RmsNorm {
    /// Name of the op.
    fn name(&self) -> &'static str {
        "rms-norm"
    }

    /// CPU forward pass: for each row of the last dimension computes
    /// `src / sqrt(mean(src^2) + eps) * alpha`, rows being processed in parallel with rayon.
    ///
    /// Both inputs must be contiguous and share one of the dtypes BF16, F16 or F32;
    /// anything else bails with an error.
    fn cpu_fwd(
        &self,
        s1: &CpuStorage,
        l1: &Layout,
        s2: &CpuStorage,
        l2: &Layout,
    ) -> Result<(CpuStorage, Shape)> {
        use hanzo_ml::backend::BackendStorage;

        let eps = self.eps;
        // Dtype-generic kernel; the match below picks the instantiation.
        fn inner<
            T: hanzo_ml::WithDType
                + num_traits::Float
                + num_traits::AsPrimitive<f32>
                + num_traits::FromPrimitive,
        >(
            src: &[T],
            layout: &Layout,
            alpha: &[T],
            alpha_layout: &Layout,
            eps: f32,
        ) -> Result<(CpuStorage, Shape)> {
            let src = match layout.contiguous_offsets() {
                None => hanzo_ml::bail!("input has to be contiguous"),
                Some((o1, o2)) => &src[o1..o2],
            };
            let alpha = match alpha_layout.contiguous_offsets() {
                None => hanzo_ml::bail!("alpha has to be contiguous"),
                Some((o1, o2)) => &alpha[o1..o2],
            };
            let el_count = layout.shape().elem_count();
            let dims = layout.shape().dims();
            // NOTE(review): a rank-0 shape or a last dimension of size 0 would break the
            // indexing/chunking below — confirm callers never pass such tensors.
            let dim_m1 = dims[dims.len() - 1];
            let mut dst = vec![T::zero(); el_count];
            src.par_chunks(dim_m1)
                .zip(dst.par_chunks_mut(dim_m1))
                .for_each(|(src, dst)| {
                    // The sum of squares is accumulated in f32 regardless of `T`.
                    let sum2 = src
                        .iter()
                        .map(|&v| {
                            let v = v.as_();
                            v * v
                        })
                        .sum::<f32>();
                    let m = (sum2 / dim_m1 as f32 + eps).sqrt();
                    // Falls back to NaN if the f32 value cannot be converted to `T`.
                    let m = T::from_f32(m).unwrap_or_else(T::nan);
                    // `alpha` is zipped element-wise with the row, so only as many outputs
                    // as `alpha` has entries are written per row.
                    for ((d, s), alpha) in dst.iter_mut().zip(src.iter()).zip(alpha) {
                        *d = *s / m * *alpha
                    }
                });
            let storage = hanzo_ml::WithDType::to_cpu_storage_owned(dst);
            Ok((storage, Shape::from_dims(dims)))
        }

        use CpuStorage as C;
        match (s1, s2) {
            (C::BF16(s1), C::BF16(s2)) => inner::<half::bf16>(s1, l1, s2, l2, eps),
            (C::F16(s1), C::F16(s2)) => inner::<half::f16>(s1, l1, s2, l2, eps),
            (C::F32(s1), C::F32(s2)) => inner::<f32>(s1, l1, s2, l2, eps),
            _ => hanzo_ml::bail!("unsupported dtype for rmsnorm {:?}", s1.dtype()),
        }
    }

    /// Vulkan forward pass: delegates to the storage's `rms_norm`.
    #[cfg(feature = "vulkan")]
    fn vulkan_fwd(
        &self,
        s1: &hanzo_ml::VulkanStorage,
        l1: &Layout,
        s2: &hanzo_ml::VulkanStorage,
        l2: &Layout,
    ) -> Result<(hanzo_ml::VulkanStorage, Shape)> {
        let out = s1.rms_norm(l1, s2, l2, self.eps)?;
        Ok((out, l1.shape().clone()))
    }

    /// CUDA forward pass: launches the `rmsnorm` kernel from the `REDUCE` kernel module with
    /// one block per row. Both inputs must be contiguous.
    #[cfg(feature = "cuda")]
    fn cuda_fwd(
        &self,
        s1: &hanzo_ml::CudaStorage,
        l1: &Layout,
        s2: &hanzo_ml::CudaStorage,
        l2: &Layout,
    ) -> Result<(hanzo_ml::CudaStorage, Shape)> {
        use hanzo_ml::cuda_backend::cudarc::driver::{
            CudaSlice, DeviceRepr, LaunchConfig, PushKernelArg,
        };
        use hanzo_ml::cuda_backend::{kernel_name, kernels, Map2, WrapErr};
        use hanzo_ml::{CudaDevice, WithDType};

        // Local `Map2` implementor carrying `eps`; `Map2::map` dispatches `f` on the dtype.
        struct S {
            eps: f32,
        }
        impl Map2 for S {
            fn f<T: DeviceRepr + WithDType>(
                &self,
                src: &CudaSlice<T>,
                layout: &Layout,
                alpha: &CudaSlice<T>,
                alpha_layout: &Layout,
                dev: &CudaDevice,
            ) -> Result<CudaSlice<T>> {
                let src = match layout.contiguous_offsets() {
                    None => hanzo_ml::bail!("input has to be contiguous"),
                    Some((o1, o2)) => src.slice(o1..o2),
                };
                let alpha = match alpha_layout.contiguous_offsets() {
                    None => hanzo_ml::bail!("alpha has to be contiguous"),
                    Some((o1, o2)) => alpha.slice(o1..o2),
                };
                let el = layout.shape().elem_count();
                let dims = layout.shape().dims();
                let dim_m1 = dims[dims.len() - 1];
                // Treat the tensor as a (n_rows, n_cols) matrix over the last dimension.
                let (n_rows, n_cols) = (el / dim_m1, dim_m1);

                // 32 threads per block for rows shorter than 1024 columns, 1024 otherwise.
                let block_size = if n_cols < 1024 { 32 } else { 1024 };
                let cfg = LaunchConfig {
                    grid_dim: (n_rows as u32, 1, 1),
                    block_dim: (block_size, 1, 1),
                    shared_mem_bytes: 0,
                };
                let func = dev.get_or_load_func(&kernel_name::<T>("rmsnorm"), &kernels::REDUCE)?;
                // SAFETY: Set later by running the kernel.
                let dst = unsafe { dev.alloc::<T>(el)? };
                // Kernel arguments, in order: source, destination, alpha, number of columns,
                // block size, eps.
                let mut builder = func.builder();
                builder.arg(&src);
                builder.arg(&dst);
                builder.arg(&alpha);
                hanzo_ml::builder_arg!(builder, n_cols as i32, block_size as i32, self.eps);
                // SAFETY: ffi.
                unsafe { builder.launch(cfg) }.w()?;
                Ok(dst)
            }
        }

        use hanzo_ml::backend::BackendStorage;
        let dev = s1.device();
        let slice = S { eps: self.eps }.map(&s1.slice, l1, &s2.slice, l2, dev)?;
        let dst = hanzo_ml::cuda_backend::CudaStorage {
            slice,
            device: dev.clone(),
        };
        Ok((dst, l1.shape().clone()))
    }

    /// Metal forward pass: encodes the `rmsnorm_{f32,f16,bf16}` kernel. Both inputs must be
    /// contiguous and share the same supported dtype; otherwise bails with an error.
    #[cfg(feature = "metal")]
    fn metal_fwd(
        &self,
        s1: &hanzo_ml::MetalStorage,
        l1: &Layout,
        s2: &hanzo_ml::MetalStorage,
        l2: &Layout,
    ) -> Result<(hanzo_ml::MetalStorage, Shape)> {
        use hanzo_ml::backend::BackendStorage;
        let device = s1.device();
        let encoder = device.command_encoder()?;
        encoder.set_label("rmsnorm");
        let kernels = device.kernels();
        let name = match (s1.dtype(), s2.dtype()) {
            (DType::F32, DType::F32) => "rmsnorm_f32",
            (DType::F16, DType::F16) => "rmsnorm_f16",
            (DType::BF16, DType::BF16) => "rmsnorm_bf16",
            (dt1, dt2) => hanzo_ml::bail!("rmsnorm is not implemented for {dt1:?} {dt2:?}"),
        };

        if !(l1.is_contiguous() && l2.is_contiguous()) {
            hanzo_ml::bail!("Non contiguous rmsnorm is not implemented");
        }

        let last_dim = l1.dims()[l1.shape().rank() - 1];
        let elem_count = l1.shape().elem_count();
        let output = device.new_buffer(elem_count, s1.dtype(), "rmsnorm")?;
        // Start offsets are converted from elements to bytes for the kernel call.
        hanzo_metal_kernels::call_rms_norm(
            device.metal_device(),
            &encoder,
            kernels,
            name,
            elem_count,
            last_dim,
            self.eps,
            s1.buffer(),
            l1.start_offset() * s1.dtype().size_in_bytes(),
            s2.buffer(),
            l2.start_offset() * s2.dtype().size_in_bytes(),
            &output,
        )
        .map_err(hanzo_ml::Error::wrap)?;
        let newstorage =
            hanzo_ml::MetalStorage::new(output, device.clone(), elem_count, s1.dtype());
        Ok((newstorage, l1.shape().clone()))
    }
}
647
648pub fn rms_norm_slow(x: &Tensor, alpha: &Tensor, eps: f32) -> Result<Tensor> {
649    let x_dtype = x.dtype();
650    let internal_dtype = match x_dtype {
651        DType::F16 | DType::BF16 => DType::F32,
652        d => d,
653    };
654    let hidden_size = x.dim(D::Minus1)?;
655    let x = x.to_dtype(internal_dtype)?;
656    let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
657    let x_normed = x.broadcast_div(&(norm_x + eps as f64)?.sqrt()?)?;
658    x_normed.to_dtype(x_dtype)?.broadcast_mul(alpha)
659}
660
661pub fn rms_norm(xs: &Tensor, alpha: &Tensor, eps: f32) -> Result<Tensor> {
662    let hidden_size_xs = xs.dim(D::Minus1)?;
663    let hidden_size_alpha = alpha.dims1()?;
664    if hidden_size_xs != hidden_size_alpha {
665        hanzo_ml::bail!(
666            "shape mismatch in rms-norm {:?} {:?}",
667            xs.shape(),
668            alpha.shape()
669        )
670    }
671    // ROCm has no fused rms-norm kernel; use the unfused tensor-op path
672    // (real HIP kernels for each sub-op).
673    if xs.device().is_rocm() {
674        return rms_norm_slow(xs, alpha, eps);
675    }
676    xs.apply_op2_no_bwd(alpha, &RmsNorm { eps })
677}
678
679// Fused SwiGLU: silu(a) * b, elementwise (both same shape). One op instead of silu + mul.
680struct SiluMul;
681
682impl hanzo_ml::CustomOp2 for SiluMul {
683    fn name(&self) -> &'static str {
684        "silu-mul"
685    }
686
687    fn cpu_fwd(
688        &self,
689        s1: &CpuStorage,
690        l1: &Layout,
691        s2: &CpuStorage,
692        l2: &Layout,
693    ) -> Result<(CpuStorage, Shape)> {
694        fn inner<
695            T: hanzo_ml::WithDType
696                + num_traits::Float
697                + num_traits::AsPrimitive<f32>
698                + num_traits::FromPrimitive,
699        >(
700            a: &[T],
701            la: &Layout,
702            b: &[T],
703            lb: &Layout,
704        ) -> Result<(CpuStorage, Shape)> {
705            let a = match la.contiguous_offsets() {
706                Some((o1, o2)) => &a[o1..o2],
707                None => hanzo_ml::bail!("silu-mul: a must be contiguous"),
708            };
709            let b = match lb.contiguous_offsets() {
710                Some((o1, o2)) => &b[o1..o2],
711                None => hanzo_ml::bail!("silu-mul: b must be contiguous"),
712            };
713            let dst: Vec<T> = a
714                .iter()
715                .zip(b.iter())
716                .map(|(&x, &y)| {
717                    let xf = x.as_();
718                    T::from_f32(xf / (1.0 + (-xf).exp()) * y.as_()).unwrap_or_else(T::nan)
719                })
720                .collect();
721            Ok((
722                hanzo_ml::WithDType::to_cpu_storage_owned(dst),
723                Shape::from_dims(la.shape().dims()),
724            ))
725        }
726        use hanzo_ml::backend::BackendStorage;
727        use CpuStorage as C;
728        match (s1, s2) {
729            (C::BF16(a), C::BF16(b)) => inner::<half::bf16>(a, l1, b, l2),
730            (C::F16(a), C::F16(b)) => inner::<half::f16>(a, l1, b, l2),
731            (C::F32(a), C::F32(b)) => inner::<f32>(a, l1, b, l2),
732            _ => hanzo_ml::bail!("silu-mul: unsupported dtype {:?}", s1.dtype()),
733        }
734    }
735
736    #[cfg(feature = "vulkan")]
737    fn vulkan_fwd(
738        &self,
739        s1: &hanzo_ml::VulkanStorage,
740        l1: &Layout,
741        s2: &hanzo_ml::VulkanStorage,
742        l2: &Layout,
743    ) -> Result<(hanzo_ml::VulkanStorage, Shape)> {
744        let out = s1.silu_mul(l1, s2, l2)?;
745        Ok((out, l1.shape().clone()))
746    }
747}
748
749/// Fused SwiGLU: `silu(gate) * up`. Falls back to the unfused tensor ops where there's no kernel.
750pub fn silu_mul(gate: &Tensor, up: &Tensor) -> Result<Tensor> {
751    if gate.device().is_cuda() || gate.device().is_metal() {
752        // No fused kernel on these yet; compose (silu then mul).
753        return silu(gate)?.mul(up);
754    }
755    gate.apply_op2_no_bwd(up, &SiluMul)
756}
757
/// Parameters of the layer-normalization custom op.
#[derive(Clone, Debug)]
struct LayerNorm {
    /// Epsilon used by the op; the CPU forward pass adds it to the variance before taking
    /// the square root.
    eps: f32,
}
762
763impl hanzo_ml::CustomOp3 for LayerNorm {
764    fn name(&self) -> &'static str {
765        "layer-norm"
766    }
767
768    fn cpu_fwd(
769        &self,
770        s1: &CpuStorage,
771        l1: &Layout,
772        s2: &CpuStorage,
773        l2: &Layout,
774        s3: &CpuStorage,
775        l3: &Layout,
776    ) -> Result<(CpuStorage, Shape)> {
777        use hanzo_ml::backend::BackendStorage;
778
779        let eps = self.eps;
780        fn inner<
781            T: hanzo_ml::WithDType
782                + num_traits::Float
783                + num_traits::AsPrimitive<f32>
784                + num_traits::FromPrimitive,
785        >(
786            src: &[T],
787            layout: &Layout,
788            alpha: &[T],
789            alpha_layout: &Layout,
790            beta: &[T],
791            beta_layout: &Layout,
792            eps: f32,
793        ) -> Result<(CpuStorage, Shape)> {
794            let src = match layout.contiguous_offsets() {
795                None => hanzo_ml::bail!("input has to be contiguous"),
796                Some((o1, o2)) => &src[o1..o2],
797            };
798            let alpha = match alpha_layout.contiguous_offsets() {
799                None => hanzo_ml::bail!("alpha has to be contiguous"),
800                Some((o1, o2)) => &alpha[o1..o2],
801            };
802            let beta = match beta_layout.contiguous_offsets() {
803                None => hanzo_ml::bail!("beta has to be contiguous"),
804                Some((o1, o2)) => &beta[o1..o2],
805            };
806            let el_count = layout.shape().elem_count();
807            let dims = layout.shape().dims();
808            let dim_m1 = dims[dims.len() - 1];
809            let mut dst = vec![T::zero(); el_count];
810            src.par_chunks(dim_m1)
811                .zip(dst.par_chunks_mut(dim_m1))
812                .for_each(|(src, dst)| {
813                    let mut sum = 0f32;
814                    let mut sum2 = 0f32;
815                    for v in src {
816                        let v = v.as_();
817                        sum += v;
818                        sum2 += v * v;
819                    }
820                    let mean = sum / dim_m1 as f32;
821                    let var = sum2 / dim_m1 as f32 - mean * mean;
822                    let inv_std = (var + eps).sqrt().recip();
823                    for ((d, s), (alpha, beta)) in
824                        dst.iter_mut().zip(src.iter()).zip(alpha.iter().zip(beta))
825                    {
826                        let alpha = alpha.as_();
827                        let beta = beta.as_();
828                        let d_ = (s.as_() - mean) * inv_std * alpha + beta;
829                        *d = T::from_f32(d_).unwrap_or_else(T::nan);
830                    }
831                });
832            let storage = hanzo_ml::WithDType::to_cpu_storage_owned(dst);
833            Ok((storage, Shape::from_dims(dims)))
834        }
835
836        use CpuStorage as C;
837        match (s1, s2, s3) {
838            (C::BF16(s1), C::BF16(s2), C::BF16(s3)) => {
839                inner::<half::bf16>(s1, l1, s2, l2, s3, l3, eps)
840            }
841            (C::F16(s1), C::F16(s2), C::F16(s3)) => inner::<half::f16>(s1, l1, s2, l2, s3, l3, eps),
842            (C::F32(s1), C::F32(s2), C::F32(s3)) => inner::<f32>(s1, l1, s2, l2, s3, l3, eps),
843            _ => hanzo_ml::bail!("unsupported dtype for rmsnorm {:?}", s1.dtype()),
844        }
845    }
846
847    #[cfg(feature = "cuda")]
848    fn cuda_fwd(
849        &self,
850        s1: &hanzo_ml::CudaStorage,
851        l1: &Layout,
852        s2: &hanzo_ml::CudaStorage,
853        l2: &Layout,
854        s3: &hanzo_ml::CudaStorage,
855        l3: &Layout,
856    ) -> Result<(hanzo_ml::CudaStorage, Shape)> {
857        use hanzo_ml::cuda_backend::cudarc::driver::{
858            CudaSlice, DeviceRepr, LaunchConfig, PushKernelArg,
859        };
860        use hanzo_ml::cuda_backend::{kernel_name, kernels, Map3, WrapErr};
861        use hanzo_ml::{CudaDevice, WithDType};
862
863        struct S {
864            eps: f32,
865        }
866        impl Map3 for S {
867            fn f<T: DeviceRepr + WithDType>(
868                &self,
869                src: &CudaSlice<T>,
870                layout: &Layout,
871                alpha: &CudaSlice<T>,
872                alpha_layout: &Layout,
873                beta: &CudaSlice<T>,
874                beta_layout: &Layout,
875                dev: &CudaDevice,
876            ) -> Result<CudaSlice<T>> {
877                let src = match layout.contiguous_offsets() {
878                    None => hanzo_ml::bail!("input has to be contiguous"),
879                    Some((o1, o2)) => src.slice(o1..o2),
880                };
881                let alpha = match alpha_layout.contiguous_offsets() {
882                    None => hanzo_ml::bail!("alpha has to be contiguous"),
883                    Some((o1, o2)) => alpha.slice(o1..o2),
884                };
885                let beta = match beta_layout.contiguous_offsets() {
886                    None => hanzo_ml::bail!("beta has to be contiguous"),
887                    Some((o1, o2)) => beta.slice(o1..o2),
888                };
889                let el = layout.shape().elem_count();
890                let dims = layout.shape().dims();
891                let dim_m1 = dims[dims.len() - 1];
892                let (n_rows, n_cols) = (el / dim_m1, dim_m1);
893
894                let block_size = if n_cols < 1024 { 32 } else { 1024 };
895                let cfg = LaunchConfig {
896                    grid_dim: (n_rows as u32, 1, 1),
897                    block_dim: (block_size, 1, 1),
898                    shared_mem_bytes: 0,
899                };
900                let func =
901                    dev.get_or_load_func(&kernel_name::<T>("layernorm"), &kernels::REDUCE)?;
902                // SAFETY: Set later by running the kernel.
903                let dst = unsafe { dev.alloc::<T>(el)? };
904                let mut builder = func.builder();
905                builder.arg(&src);
906                builder.arg(&dst);
907                builder.arg(&alpha);
908                builder.arg(&beta);
909                hanzo_ml::builder_arg!(builder, n_cols as i32, block_size as i32, self.eps);
910                // SAFETY: ffi.
911                unsafe { builder.launch(cfg) }.w()?;
912                Ok(dst)
913            }
914        }
915
916        use hanzo_ml::backend::BackendStorage;
917        let dev = s1.device();
918        let slice = S { eps: self.eps }.map(&s1.slice, l1, &s2.slice, l2, &s3.slice, l3, dev)?;
919        let dst = hanzo_ml::cuda_backend::CudaStorage {
920            slice,
921            device: dev.clone(),
922        };
923        Ok((dst, l1.shape().clone()))
924    }
925
926    #[cfg(feature = "metal")]
927    fn metal_fwd(
928        &self,
929        s1: &hanzo_ml::MetalStorage,
930        l1: &Layout,
931        s2: &hanzo_ml::MetalStorage,
932        l2: &Layout,
933        s3: &hanzo_ml::MetalStorage,
934        l3: &Layout,
935    ) -> Result<(hanzo_ml::MetalStorage, Shape)> {
936        use hanzo_ml::backend::BackendStorage;
937        let device = s1.device();
938        let encoder = device.command_encoder()?;
939        encoder.set_label("layernorm");
940        let kernels = device.kernels();
941        let name = match (s1.dtype(), s2.dtype(), s3.dtype()) {
942            (DType::F32, DType::F32, DType::F32) => "layernorm_f32",
943            (DType::F16, DType::F16, DType::F16) => "layernorm_f16",
944            (DType::BF16, DType::BF16, DType::BF16) => "layernorm_bf16",
945            (dt1, dt2, dt3) => {
946                hanzo_ml::bail!("layernorm is not implemented for {dt1:?} {dt2:?} {dt3:?}")
947            }
948        };
949
950        if !(l1.is_contiguous() && l2.is_contiguous() && l3.is_contiguous()) {
951            hanzo_ml::bail!("Non contiguous layernorm is not implemented");
952        }
953
954        let last_dim = l1.dims()[l1.shape().rank() - 1];
955        let elem_count = l1.shape().elem_count();
956        let output = device.new_buffer(elem_count, s1.dtype(), "layernorm")?;
957        hanzo_metal_kernels::call_layer_norm(
958            device.metal_device(),
959            &encoder,
960            kernels,
961            name,
962            elem_count,
963            last_dim,
964            self.eps,
965            s1.buffer(),
966            l1.start_offset() * s1.dtype().size_in_bytes(),
967            s2.buffer(),
968            l2.start_offset() * s2.dtype().size_in_bytes(),
969            s3.buffer(),
970            l3.start_offset() * s3.dtype().size_in_bytes(),
971            &output,
972        )
973        .map_err(hanzo_ml::Error::wrap)?;
974        let newstorage =
975            hanzo_ml::MetalStorage::new(output, device.clone(), elem_count, s1.dtype());
976        Ok((newstorage, l1.shape().clone()))
977    }
978}
979
980pub fn layer_norm_slow(x: &Tensor, alpha: &Tensor, beta: &Tensor, eps: f32) -> Result<Tensor> {
981    let x_dtype = x.dtype();
982    let internal_dtype = match x_dtype {
983        DType::F16 | DType::BF16 => DType::F32,
984        d => d,
985    };
986    let hidden_size = x.dim(D::Minus1)?;
987    let x = x.to_dtype(internal_dtype)?;
988    let x = {
989        let mean_x = (x.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
990        x.broadcast_sub(&mean_x)?
991    };
992    let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
993    let x_normed = x.broadcast_div(&(norm_x + eps as f64)?.sqrt()?)?;
994    x_normed
995        .to_dtype(x_dtype)?
996        .broadcast_mul(alpha)?
997        .broadcast_add(beta)
998}
999
1000pub fn layer_norm(xs: &Tensor, alpha: &Tensor, beta: &Tensor, eps: f32) -> Result<Tensor> {
1001    let hidden_size_xs = xs.dim(D::Minus1)?;
1002    let hidden_size_alpha = alpha.dims1()?;
1003    let hidden_size_beta = beta.dims1()?;
1004    if hidden_size_xs != hidden_size_alpha || hidden_size_xs != hidden_size_beta {
1005        hanzo_ml::bail!(
1006            "shape mismatch in layer-norm src: {:?} alpha: {:?} beta: {:?}",
1007            xs.shape(),
1008            alpha.shape(),
1009            beta.shape()
1010        )
1011    }
1012    if xs.device().is_rocm() {
1013        return layer_norm_slow(xs, alpha, beta, eps);
1014    }
1015    xs.apply_op3_no_bwd(alpha, beta, &LayerNorm { eps })
1016}
1017
1018// https://pytorch.org/docs/stable/generated/torch.nn.PixelShuffle.html
1019pub fn pixel_shuffle(xs: &Tensor, upscale_factor: usize) -> Result<Tensor> {
1020    let (b_size, c, h, w) = xs.dims4()?;
1021    let out_c = c / upscale_factor / upscale_factor;
1022    xs.reshape((b_size, out_c, upscale_factor, upscale_factor, h, w))?
1023        .permute((0, 1, 4, 2, 5, 3))?
1024        .reshape((b_size, out_c, h * upscale_factor, w * upscale_factor))
1025}
1026
1027pub fn pixel_unshuffle(xs: &Tensor, downscale_factor: usize) -> Result<Tensor> {
1028    let (b_size, c, h, w) = xs.dims4()?;
1029    let out_c = c * downscale_factor * downscale_factor;
1030    xs.reshape((
1031        b_size,
1032        c,
1033        h / downscale_factor,
1034        downscale_factor,
1035        w / downscale_factor,
1036        downscale_factor,
1037    ))?
1038    .permute((0, 1, 3, 5, 2, 4))?
1039    .reshape((b_size, out_c, h / downscale_factor, w / downscale_factor))
1040}
1041
1042// https://pytorch.org/docs/stable/generated/torch.nn.ReplicationPad2d.html
1043pub fn replication_pad2d(xs: &Tensor, pad: usize) -> Result<Tensor> {
1044    match pad {
1045        0 => Ok(xs.clone()),
1046        1 => {
1047            let (_b_size, _c, h, w) = xs.dims4()?;
1048            let (first, last) = (xs.narrow(3, 0, 1)?, xs.narrow(3, w - 1, 1)?);
1049            let xs = Tensor::cat(&[&first, xs, &last], 3)?;
1050            let (first, last) = (xs.narrow(2, 0, 1)?, xs.narrow(2, h - 1, 1)?);
1051            Tensor::cat(&[&first, &xs, &last], 2)
1052        }
1053        n => hanzo_ml::bail!("replication-pad with a size of {n} is not supported"),
1054    }
1055}
1056
/// A module that passes its input through unchanged.
///
/// Its `Module::forward` implementation returns a clone of the input tensor.
#[derive(Clone, Debug, Default)]
pub struct Identity;

impl Identity {
    /// Creates a new [`Identity`]; equivalent to `Identity::default()`.
    pub fn new() -> Identity {
        Self
    }
}
1071
1072impl Module for Identity {
1073    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
1074        Ok(xs.clone())
1075    }
1076}
1077
/// Parameters of the fused scaled-dot-product-attention custom op built by `sdpa`.
// NOTE(review): the fields are only read in `metal_fwd`, which is gated behind the
// `metal` feature — presumably the reason for the `dead_code` allowance.
#[allow(dead_code)]
struct Sdpa {
    /// Scale forwarded to the attention kernels.
    scale: f32,
    /// Softcapping value; forwarded to the vector kernels, while the full-attention
    /// path rejects anything other than `1.0`.
    softcapping: f32,
    /// Optional attention mask; forwarded only to the full-attention kernel.
    mask: Option<Tensor>,
    /// Causal-masking flag; forwarded only to the full-attention kernel.
    do_causal: bool,
}
1085
1086impl hanzo_ml::CustomOp3 for Sdpa {
1087    fn name(&self) -> &'static str {
1088        "metal-sdpa"
1089    }
1090
1091    fn cpu_fwd(
1092        &self,
1093        _s1: &CpuStorage,
1094        _l1: &Layout,
1095        _s2: &CpuStorage,
1096        _l2: &Layout,
1097        _s3: &CpuStorage,
1098        _l3: &Layout,
1099    ) -> Result<(CpuStorage, Shape)> {
1100        hanzo_ml::bail!("SDPA has no cpu impl")
1101    }
1102
1103    #[cfg(feature = "metal")]
1104    fn metal_fwd(
1105        &self,
1106        q: &hanzo_ml::MetalStorage,
1107        q_l: &Layout,
1108        k: &hanzo_ml::MetalStorage,
1109        k_l: &Layout,
1110        v: &hanzo_ml::MetalStorage,
1111        v_l: &Layout,
1112    ) -> Result<(hanzo_ml::MetalStorage, Shape)> {
1113        use hanzo_metal_kernels::SdpaDType;
1114        use hanzo_ml::backend::BackendStorage;
1115
1116        let device = q.device();
1117
1118        let out_dims = vec![q_l.dim(0)?, q_l.dim(1)?, q_l.dim(2)?, v_l.dim(3)?];
1119        let elem_count: usize = out_dims.iter().product();
1120        let out_shape = Shape::from_dims(&out_dims);
1121        let out_layout = Layout::contiguous(out_shape.clone());
1122
1123        let output = device.new_buffer(elem_count, q.dtype(), "sdpa_o")?;
1124
1125        // q,k must have matching emb dim
1126        if q_l.dim(D::Minus1)? != k_l.dim(D::Minus1)? {
1127            hanzo_ml::bail!("`q` and `k` last dims must match");
1128        }
1129
1130        // k,v must have matching n kv heads
1131        if v_l.dim(D::Minus(3))? != k_l.dim(D::Minus(3))? {
1132            hanzo_ml::bail!("`k` and `v` head dims must match");
1133        }
1134
1135        // n_heads % n_kv_heads == 0; n_heads >= 1, n_kv_heads >= 1.
1136        if q_l.dim(D::Minus(3))? % k_l.dim(D::Minus(3))? != 0 {
1137            hanzo_ml::bail!("query `n_heads` must be a multiple of `n_kv_heads`");
1138        }
1139
1140        let k_head = k_l.dim(D::Minus1)?;
1141        let q_head = q_l.dim(D::Minus1)?;
1142        let q_seq = q_l.dim(2)?;
1143        let k_seq = k_l.dim(2)?;
1144
1145        let mut implementation_supports_use_case = q_head == k_head;
1146        let supported_head_dim = q_head == 32
1147            || q_head == 64
1148            || q_head == 72
1149            || q_head == 80
1150            || q_head == 96
1151            || q_head == 128
1152            || q_head == 256
1153            || q_head == 512;
1154
1155        let supports_sdpa_full_mask = self.mask.is_none() || q_seq <= k_seq;
1156        // F32 full attention at head_dim=512 exceeds 32KB Metal threadgroup memory
1157        let supports_sdpa_full_dtype = !(q_head == 512 && q.dtype() == DType::F32);
1158        let supports_sdpa_full =
1159            q_seq > 8 && supported_head_dim && supports_sdpa_full_mask && supports_sdpa_full_dtype;
1160        let supports_sdpa_vector = q_seq <= 8 && supported_head_dim && q_seq <= k_seq;
1161
1162        implementation_supports_use_case &= supports_sdpa_full || supports_sdpa_vector;
1163
1164        if !supported_head_dim {
1165            hanzo_ml::bail!(
1166                "Meta SDPA does not support q head dim {q_head}: q dims {:?}, k dims {:?}, v dims {:?}.",
1167                q_l.dims(),
1168                k_l.dims(),
1169                v_l.dims()
1170            );
1171        }
1172        if !implementation_supports_use_case {
1173            hanzo_ml::bail!(
1174                "Meta SDPA does not support q dims {:?}, k dims {:?}, v dims {:?}.",
1175                q_l.dims(),
1176                k_l.dims(),
1177                v_l.dims()
1178            );
1179        }
1180
1181        for t in [k.dtype(), v.dtype()] {
1182            if q.dtype() != t {
1183                hanzo_ml::bail!("all q, k, v dtypes must match.");
1184            }
1185        }
1186
1187        let itype = match q.dtype() {
1188            DType::BF16 => SdpaDType::BF16,
1189            DType::F16 => SdpaDType::F16,
1190            DType::F32 => SdpaDType::F32,
1191            other => hanzo_ml::bail!("unsupported sdpa type {other:?}"),
1192        };
1193
1194        let encoder = q.device().command_encoder()?;
1195        if supports_sdpa_vector {
1196            // Route to the 2 pass fused attention if the k seqlen is large.
1197            // https://github.com/ml-explore/mlx/pull/1597
1198            const TWO_PASS_K_THRESHOLD: usize = 1024;
1199            if k_seq >= TWO_PASS_K_THRESHOLD {
1200                let mut intermediate_shape = [
1201                    &out_dims[0..out_dims.len() - 2],
1202                    &[hanzo_metal_kernels::SDPA_2PASS_BLOCKS],
1203                    &[out_dims[out_dims.len() - 1]],
1204                ]
1205                .concat();
1206                let intermediate = device.new_buffer(
1207                    intermediate_shape.iter().product::<usize>(),
1208                    DType::F32,
1209                    "sdpa_2pass_intermediate",
1210                )?;
1211                let _ = intermediate_shape.pop().unwrap();
1212                let sums = device.new_buffer(
1213                    intermediate_shape.iter().product::<usize>(),
1214                    DType::F32,
1215                    "sdpa_2pass_sums",
1216                )?;
1217                let maxs = device.new_buffer(
1218                    intermediate_shape.iter().product::<usize>(),
1219                    DType::F32,
1220                    "sdpa_2pass_maxs",
1221                )?;
1222
1223                encoder.set_label("vector_attention");
1224                hanzo_metal_kernels::call_sdpa_vector_2pass(
1225                    q.device().device(),
1226                    &encoder,
1227                    q.device().kernels(),
1228                    q_l.start_offset(),
1229                    q_l.dims(),
1230                    q.buffer(),
1231                    k_l.start_offset(),
1232                    k_l.dims(),
1233                    k_l.stride(),
1234                    k.buffer(),
1235                    v_l.start_offset(),
1236                    v_l.stride(),
1237                    v.buffer(),
1238                    &output,
1239                    &intermediate,
1240                    &sums,
1241                    &maxs,
1242                    self.scale,
1243                    self.softcapping,
1244                    itype,
1245                )
1246                .map_err(hanzo_ml::Error::wrap)?;
1247            } else {
1248                encoder.set_label("vector_attention");
1249                hanzo_metal_kernels::call_sdpa_vector(
1250                    q.device().device(),
1251                    &encoder,
1252                    q.device().kernels(),
1253                    q_l.start_offset(),
1254                    q_l.dims(),
1255                    q.buffer(),
1256                    k_l.start_offset(),
1257                    k_l.dims(),
1258                    k_l.stride(),
1259                    k.buffer(),
1260                    v_l.start_offset(),
1261                    v_l.stride(),
1262                    v.buffer(),
1263                    &output,
1264                    self.scale,
1265                    self.softcapping,
1266                    itype,
1267                )
1268                .map_err(hanzo_ml::Error::wrap)?;
1269            }
1270        } else if supports_sdpa_full {
1271            encoder.set_label("full_attention");
1272            if self.softcapping != 1. {
1273                hanzo_ml::bail!("SDPA full requires softcapping to be disabled (1.0)");
1274            }
1275
1276            let mask_s_l = self.mask.as_ref().map(|m| m.storage_and_layout());
1277
1278            let (mask_type, mask_buffer, mask_strides) = if let Some(mask) = &self.mask {
1279                let (mask_s, mask_l) = mask_s_l.as_ref().unwrap();
1280
1281                let mask_buffer = match &**mask_s {
1282                    hanzo_ml::Storage::Metal(m) => m.buffer(),
1283                    _ => hanzo_ml::bail!("Expected metal device for mask"),
1284                };
1285
1286                let mask_type = match mask.dtype() {
1287                    DType::BF16 => SdpaDType::BF16,
1288                    DType::F16 => SdpaDType::F16,
1289                    DType::F32 => SdpaDType::F32,
1290                    other => hanzo_ml::bail!("unsupported sdpa type {other:?}"),
1291                };
1292                if mask_type != itype {
1293                    hanzo_ml::bail!("Mask type {mask_type:?} must match q type {itype:?}");
1294                }
1295
1296                if mask_l.dims() != [q_l.dim(0)?, q_l.dim(1)?, q_l.dim(2)?, k_seq] {
1297                    hanzo_ml::bail!(
1298                        "Mask shape must be {:?} (bs, qheads, qseq, kseq), got {:?}",
1299                        [q_l.dim(0)?, q_head, q_l.dim(2)?, k_seq],
1300                        mask_l.dims()
1301                    );
1302                }
1303
1304                (
1305                    Some(mask_type),
1306                    Some(mask_buffer),
1307                    Some(mask_l.stride().to_vec()),
1308                )
1309            } else {
1310                (None, None, None)
1311            };
1312
1313            hanzo_metal_kernels::call_sdpa_full(
1314                q.device().device(),
1315                &encoder,
1316                q.device().kernels(),
1317                q_l.start_offset(),
1318                q_l.dims(),
1319                q_l.stride(),
1320                q.buffer(),
1321                k_l.start_offset(),
1322                k_l.dims(),
1323                k_l.stride(),
1324                k.buffer(),
1325                v_l.start_offset(),
1326                v.buffer(),
1327                v_l.stride(),
1328                mask_type,
1329                mask_buffer,
1330                mask_strides.as_deref(),
1331                &output,
1332                out_layout.stride(),
1333                self.scale,
1334                self.do_causal,
1335                itype,
1336            )
1337            .map_err(hanzo_ml::Error::wrap)?;
1338        } else {
1339            hanzo_ml::bail!("must be vector or full sdpa kernel");
1340        }
1341
1342        let newstorage = hanzo_ml::MetalStorage::new(output, device.clone(), elem_count, q.dtype());
1343        Ok((newstorage, out_shape))
1344    }
1345}
1346
1347/// Scaled dot product attention with a fused kernel.
1348///
1349/// Computes softmax(qk^T*scale)v.
1350///
1351/// **Inputs shapes:**
1352/// - `q`: (bs, qhead, seq, hidden)
1353/// - `k`: (bs, kv_head, kv_seq, hidden)
1354/// - `k`: (bs, kv_head, kv_seq, v_hidden)
1355/// - `mask`: (bs, qhead, seq, kv_seq)
1356/// - `do_causal`: Apply causal masking. If this is true, the mask does not need to be provided.
1357/// - `scale` is applied before softmax.
1358/// - If `softcapping` != 1.0:
1359///      - Computation is: softmax(tanh(qk^T*scale/cap)*cap)v
1360///
1361/// **Output shape:** (bs, qhead, seq, v_hidden)
1362///
1363/// Note: For Grouped Query Attention and Multi-Query Attention, the k and v inputs should not be pre-tiled to match q.
1364///
1365/// ## On Metal:
1366/// - If `seq` == 1:
1367///     - Use a vectorized kernel
1368///     - Supports `seq` != `kv_seq` (cross attn. support)
1369///     - Supports GQA when `qhead` is a multiple of `kv_head`
1370/// - Otherwise:
1371///     - Masking is supported
1372///     - Supports `seq` != `kv_seq` (cross attn. support)
1373///     - Supports GQA when `qhead` is a multiple of `kv_head`
1374///     - Softcapping is not supported.
1375pub fn sdpa(
1376    q: &Tensor,
1377    k: &Tensor,
1378    v: &Tensor,
1379    mask: Option<&Tensor>,
1380    do_causal: bool,
1381    scale: f32,
1382    softcapping: f32,
1383) -> Result<Tensor> {
1384    q.apply_op3_no_bwd(
1385        k,
1386        v,
1387        &Sdpa {
1388            scale,
1389            softcapping,
1390            mask: mask.cloned(),
1391            do_causal,
1392        },
1393    )
1394}