// candle_nn/ops.rs

1//! Tensor ops.
2//!
3
4use candle::{CpuStorage, DType, Layout, Module, Result, Shape, Tensor, D};
5use rayon::prelude::*;
6
7/// Applies the softmax function to the input tensor, rescaling the element so that elements on
8/// a slice of fixed index on dimension `dim` are between 0 and 1 and sum to 1.
9///
10/// ```rust
11/// use candle::{Tensor, Device, test_utils::to_vec2_round};
12/// let a = Tensor::new(&[[0f32, 1., 0., 1.], [-2., 2., 3., -3.]], &Device::Cpu)?;
13/// let a = candle_nn::ops::softmax(&a, 1)?;
14/// assert_eq!(
15///     to_vec2_round(&a, 4)?,
16///     &[
17///         [0.1345, 0.3655, 0.1345, 0.3655],
18///         [0.0049, 0.2671, 0.7262, 0.0018]
19///     ]);
20/// # Ok::<(), candle::Error>(())
21/// ```
22pub fn softmax<D: candle::shape::Dim>(xs: &Tensor, dim: D) -> Result<Tensor> {
23    let dim = dim.to_index(xs.shape(), "softmax")?;
24    let max = xs.max_keepdim(dim)?;
25    let diff = xs.broadcast_sub(&max)?;
26    let num = diff.exp()?;
27    let den = num.sum_keepdim(dim)?;
28    num.broadcast_div(&den)
29}
30
31pub fn log_softmax<D: candle::shape::Dim>(xs: &Tensor, d: D) -> Result<Tensor> {
32    let d = d.to_index(xs.shape(), "log-softmax")?;
33    let max = xs.max_keepdim(d)?;
34    let diff = xs.broadcast_sub(&max)?;
35    let sum_exp = diff.exp()?.sum_keepdim(d)?;
36    let log_sm = diff.broadcast_sub(&sum_exp.log()?)?;
37    Ok(log_sm)
38}
39
/// Applies the SiLU (a.k.a. swish) activation by delegating to the tensor's
/// built-in `silu` implementation.
pub fn silu(xs: &Tensor) -> Result<Tensor> {
    xs.silu()
}
43
44pub fn swiglu(xs: &Tensor) -> Result<Tensor> {
45    let xs = xs.chunk(2, D::Minus1)?;
46    &xs[0].silu()? * &xs[1]
47}
48
49struct Sigmoid;
50
impl candle::CustomOp1 for Sigmoid {
    // Op identifier used in error messages and traces.
    fn name(&self) -> &'static str {
        "sigmoid"
    }

    // CPU forward: applies 1 / (1 + exp(-v)) element-wise via the generic
    // unary map, for the four supported float dtypes.
    fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)> {
        use candle::backend::BackendStorage;

        // sigmoid(v) = 1 / (1 + exp(-v))
        fn fwd<T: num_traits::Float>(v: T) -> T {
            (v.neg().exp() + T::one()).recip()
        }

        // FIXME: using `candle::map_dtype` causes compilation errors.
        let storage = match storage {
            CpuStorage::BF16(slice) => {
                CpuStorage::BF16(candle::cpu_backend::unary_map(slice, layout, fwd))
            }
            CpuStorage::F16(slice) => {
                CpuStorage::F16(candle::cpu_backend::unary_map(slice, layout, fwd))
            }
            CpuStorage::F32(slice) => {
                CpuStorage::F32(candle::cpu_backend::unary_map(slice, layout, fwd))
            }
            CpuStorage::F64(slice) => {
                CpuStorage::F64(candle::cpu_backend::unary_map(slice, layout, fwd))
            }
            // Integer dtypes are not supported for sigmoid.
            _ => Err(candle::Error::UnsupportedDTypeForOp(
                storage.dtype(),
                self.name(),
            ))?,
        };
        Ok((storage, layout.shape().clone()))
    }

    // CUDA forward: launches the "usigmoid" kernel from the UNARY module,
    // one thread per element.
    #[cfg(feature = "cuda")]
    fn cuda_fwd(
        &self,
        storage: &candle::CudaStorage,
        layout: &Layout,
    ) -> Result<(candle::CudaStorage, Shape)> {
        use candle::backend::BackendStorage;
        use candle::cuda_backend::cudarc::driver::{
            CudaSlice, DeviceRepr, LaunchConfig, PushKernelArg, ValidAsZeroBits,
        };
        use candle::cuda_backend::SlicePtrOrNull;
        use candle::cuda_backend::{kernel_name, kernels, Map1, WrapErr};
        use candle::{CudaDevice, WithDType};

        struct S;
        impl Map1 for S {
            fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
                &self,
                src: &CudaSlice<T>,
                dev: &CudaDevice,
                layout: &Layout,
            ) -> Result<CudaSlice<T>> {
                let shape = layout.shape();
                let dims = shape.dims();
                let el_count = shape.elem_count();
                let cfg = LaunchConfig::for_num_elems(el_count as u32);
                // Dims/strides parameters, or null when the layout is contiguous.
                let ds = SlicePtrOrNull::params_from_layout(dev, layout)?;
                let src = &src.slice(layout.start_offset()..);
                let func = dev.get_or_load_func(&kernel_name::<T>("usigmoid"), &kernels::UNARY)?;
                // SAFETY: Set later by running the kernel.
                let out = unsafe { dev.alloc::<T>(el_count)? };

                let mut builder = func.builder();
                candle::builder_arg!(builder, el_count, dims.len());
                ds.builder_arg(&mut builder);
                builder.arg(src);
                builder.arg(&out);
                // SAFETY: ffi.
                unsafe { builder.launch(cfg) }.w()?;
                Ok(out)
            }
        }

        let dev = storage.device();
        let slice = S.map(&storage.slice, dev, layout)?;
        let dst = candle::CudaStorage {
            slice,
            device: dev.clone(),
        };
        Ok((dst, layout.shape().clone()))
    }

    // Metal forward: picks the contiguous or strided unary sigmoid kernel
    // depending on the input layout.
    #[cfg(feature = "metal")]
    fn metal_fwd(
        &self,
        storage: &candle::MetalStorage,
        layout: &Layout,
    ) -> Result<(candle::MetalStorage, Shape)> {
        use candle::backend::BackendStorage;
        use candle::MetalError;
        let device = storage.device();
        let dtype = storage.dtype();
        let shape = layout.shape();
        let el_count = shape.elem_count();
        let buffer = device.new_buffer(el_count, dtype, "sigmoid")?;
        let encoder = device.command_encoder()?;
        encoder.set_label("sigmoid");
        // Source view starting at the layout's byte offset.
        let src = candle_metal_kernels::BufferOffset {
            buffer: storage.buffer(),
            offset_in_bytes: layout.start_offset() * storage.dtype().size_in_bytes(),
        };

        if layout.is_contiguous() {
            use candle_metal_kernels::unary::contiguous;
            let kernel_name = match dtype {
                DType::F16 => contiguous::sigmoid::HALF,
                DType::F32 => contiguous::sigmoid::FLOAT,
                DType::BF16 => contiguous::sigmoid::BFLOAT,
                dtype => {
                    candle::bail!("Metal contiguous unary sigmoid {dtype:?} not implemented")
                }
            };
            candle_metal_kernels::call_unary_contiguous(
                device.metal_device(),
                &encoder,
                device.kernels(),
                kernel_name,
                dtype.size_in_bytes(),
                el_count,
                src,
                &buffer,
            )
            .map_err(MetalError::from)?;
        } else {
            use candle_metal_kernels::unary::strided;
            let kernel_name = match dtype {
                DType::F16 => strided::sigmoid::HALF,
                DType::F32 => strided::sigmoid::FLOAT,
                DType::BF16 => strided::sigmoid::BFLOAT,
                dtype => {
                    candle::bail!("Metal strided unary sigmoid {dtype:?} not implemented")
                }
            };
            let dst = candle_metal_kernels::BufferOffset::zero_offset(&buffer);
            candle_metal_kernels::call_unary_strided(
                device.metal_device(),
                &encoder,
                device.kernels(),
                kernel_name,
                layout.dims(),
                src,
                layout.stride(),
                dst,
            )
            .map_err(MetalError::from)?;
        }

        let new_storage = candle::MetalStorage::new(buffer, device.clone(), el_count, dtype);
        Ok((new_storage, layout.shape().clone()))
    }

    // Backward pass expressed in terms of the forward result `res`, avoiding
    // a recomputation of sigmoid(x).
    fn bwd(&self, _arg: &Tensor, res: &Tensor, grad_res: &Tensor) -> Result<Option<Tensor>> {
        // d/dx sigmoid(x) = (1 - sigmoid(x)) * sigmoid(x)
        let d_dx_sigmoid = res.ones_like()?.sub(res)?.mul(res)?;
        Ok(Some(grad_res.mul(&d_dx_sigmoid)?))
    }
}
212
/// Element-wise sigmoid `1 / (1 + exp(-x))`, dispatched through the custom
/// [`Sigmoid`] op (which also registers a backward pass).
pub fn sigmoid(xs: &Tensor) -> Result<Tensor> {
    xs.apply_op1(Sigmoid)
}
216
217pub fn hard_sigmoid(xs: &Tensor) -> Result<Tensor> {
218    // TODO: Should we have a specialized op for this?
219    ((xs + 3.0)? / 6.0)?.clamp(0f32, 1f32)
220}
221
222pub fn mish(xs: &Tensor) -> Result<Tensor> {
223    xs * (1.0 + xs.exp()?)?.log()?.tanh()
224}
225
226pub fn leaky_relu(xs: &Tensor, negative_slope: f64) -> Result<Tensor> {
227    let zeros = xs.zeros_like()?;
228    xs.maximum(&zeros)? + xs.minimum(&zeros)? * negative_slope
229}
230
231pub fn selu(xs: &Tensor, alpha: f32, gamma: f32) -> Result<Tensor> {
232    let is_pos = xs.gt(0f32)?;
233    let alpha_t = Tensor::full(alpha, xs.dims(), xs.device())?;
234    let neg = xs.exp()?.mul(&alpha_t)?.sub(&alpha_t)?;
235    let selu = is_pos.where_cond(xs, &neg)?;
236    let gamma_t = Tensor::full(gamma, xs.dims(), xs.device())?;
237    selu.broadcast_mul(&gamma_t)
238}
239
240pub fn dropout(xs: &Tensor, drop_p: f32) -> Result<Tensor> {
241    // This implementation is inefficient as it stores the full mask for the backward pass.
242    // Instead we could just store the seed and have a specialized kernel that would both
243    // generate the random mask and apply it.
244    // Another easier optimization would be to be able to generate boolean mask using just a bit of
245    // entropy per element rather than generating a full float per element.
246    if !(0. ..1.).contains(&drop_p) {
247        candle::bail!("dropout probability has to be in [0, 1), got {drop_p}")
248    }
249    let rand = Tensor::rand(0f32, 1f32, xs.shape(), xs.device())?;
250    let scale = 1.0 / (1.0 - drop_p as f64);
251    let drop_p = Tensor::new(drop_p, xs.device())?.broadcast_as(xs.shape())?;
252    let mask = (rand.ge(&drop_p)?.to_dtype(xs.dtype())? * scale)?;
253    xs * mask
254}
255
/// Dropout layer; only active when `forward` is called with `train == true`.
#[derive(Clone, Debug)]
pub struct Dropout {
    // Probability of zeroing each element during training, in [0, 1).
    drop_p: f32,
}
260
261impl Dropout {
262    pub fn new(drop_p: f32) -> Dropout {
263        Self { drop_p }
264    }
265
266    pub fn forward(&self, xs: &Tensor, train: bool) -> Result<Tensor> {
267        if train {
268            dropout(xs, self.drop_p)
269        } else {
270            Ok(xs.clone())
271        }
272    }
273}
274
impl candle::ModuleT for Dropout {
    // Delegates to the train-aware forward above.
    fn forward_t(&self, xs: &Tensor, train: bool) -> Result<Tensor> {
        self.forward(xs, train)
    }
}
280
281struct SoftmaxLastDim;
282
impl candle::CustomOp1 for SoftmaxLastDim {
    // Op identifier used in error messages and traces.
    fn name(&self) -> &'static str {
        "softmax-last-dim"
    }

    // CPU forward: rows (slices along the last dim) are processed in parallel
    // with rayon; each row uses the usual max-shifted exp/normalize scheme.
    fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)> {
        fn softmax<T: candle::WithDType + num_traits::Float>(
            src: &[T],
            layout: &Layout,
        ) -> Result<(CpuStorage, Shape)> {
            let src = match layout.contiguous_offsets() {
                None => candle::bail!("input has to be contiguous"),
                Some((o1, o2)) => &src[o1..o2],
            };
            let el_count = layout.shape().elem_count();
            let dims = layout.shape().dims();
            let dim_m1 = dims[dims.len() - 1];
            let mut dst = vec![T::zero(); el_count];
            src.par_chunks(dim_m1)
                .zip(dst.par_chunks_mut(dim_m1))
                .for_each(|(src, dst)| {
                    // Row maximum, used to shift before exponentiation so the
                    // exponentials cannot overflow.
                    let mut max = T::neg_infinity();
                    // SAFETY-NOTE(review): both chunks are exactly dim_m1
                    // elements long, matching the length passed to the
                    // vectorized reductions.
                    unsafe { T::vec_reduce_max(src.as_ptr(), &mut max, dim_m1) };
                    for (s, d) in src.iter().zip(dst.iter_mut()) {
                        *d = (*s - max).exp();
                    }
                    let mut sum_exp = T::zero();
                    unsafe { T::vec_reduce_sum(dst.as_ptr(), &mut sum_exp, dim_m1) };
                    for d in dst.iter_mut() {
                        *d /= sum_exp
                    }
                });
            let storage = candle::WithDType::to_cpu_storage_owned(dst);
            Ok((storage, Shape::from_dims(dims)))
        }

        match storage {
            CpuStorage::BF16(slice) => softmax::<half::bf16>(slice, layout),
            CpuStorage::F16(slice) => softmax::<half::f16>(slice, layout),
            CpuStorage::F32(slice) => softmax::<f32>(slice, layout),
            CpuStorage::F64(slice) => softmax::<f64>(slice, layout),
            _ => candle::bail!("unsupported dtype for softmax {:?}", storage),
        }
    }

    // CUDA forward: launches the "softmax" kernel from the REDUCE module with
    // one grid row per softmax row.
    #[cfg(feature = "cuda")]
    fn cuda_fwd(
        &self,
        storage: &candle::CudaStorage,
        layout: &Layout,
    ) -> Result<(candle::CudaStorage, Shape)> {
        use candle::cuda_backend::cudarc::driver::{
            CudaSlice, DeviceRepr, LaunchConfig, PushKernelArg,
        };
        use candle::cuda_backend::{kernel_name, kernels, Map1, WrapErr};
        use candle::{CudaDevice, WithDType};

        struct S;
        impl Map1 for S {
            fn f<T: DeviceRepr + WithDType>(
                &self,
                src: &CudaSlice<T>,
                dev: &CudaDevice,
                layout: &Layout,
            ) -> Result<CudaSlice<T>> {
                let src = match layout.contiguous_offsets() {
                    None => candle::bail!("input has to be contiguous"),
                    Some((o1, o2)) => src.slice(o1..o2),
                };
                let el = layout.shape().elem_count();
                let dims = layout.shape().dims();
                let dim_m1 = dims[dims.len() - 1];
                let (n_rows, n_cols) = (el / dim_m1, dim_m1);

                // One block per row, 32 threads along y per block.
                let cfg = LaunchConfig {
                    grid_dim: (n_rows as u32, 1, 1),
                    block_dim: (1, 32, 1),
                    shared_mem_bytes: 0,
                };
                let func = dev.get_or_load_func(&kernel_name::<T>("softmax"), &kernels::REDUCE)?;
                // SAFETY: Set later by running the kernel.
                let dst = unsafe { dev.alloc::<T>(el)? };
                let mut builder = func.builder();
                builder.arg(&src);
                builder.arg(&dst);
                candle::builder_arg!(builder, n_cols as i32);
                // SAFETY: ffi.
                unsafe { builder.launch(cfg) }.w()?;
                Ok(dst)
            }
        }

        use candle::backend::BackendStorage;
        let dev = storage.device();
        let slice = S.map(&storage.slice, dev, layout)?;
        let dst = candle::cuda_backend::CudaStorage {
            slice,
            device: dev.clone(),
        };
        Ok((dst, layout.shape().clone()))
    }

    // Metal forward: dispatches the dtype-specific last-dim softmax kernel;
    // only contiguous inputs are supported.
    #[cfg(feature = "metal")]
    fn metal_fwd(
        &self,
        storage: &candle::MetalStorage,
        layout: &Layout,
    ) -> Result<(candle::MetalStorage, Shape)> {
        use candle::backend::BackendStorage;
        let device = storage.device();
        let encoder = device.command_encoder()?;
        encoder.set_label("softmax");
        let kernels = device.kernels();
        let name = match storage.dtype() {
            DType::F32 => "softmax_f32",
            DType::F16 => "softmax_f16",
            DType::BF16 => "softmax_bf16",
            dtype => candle::bail!("softmax-last-dim is not implemented for {dtype:?}"),
        };

        // The kernel assumes a contiguous input whose innermost stride is 1.
        let n = layout.stride().len();
        if !(layout.is_contiguous() && layout.stride()[n - 1] == 1) {
            candle::bail!("Non contiguous softmax-last-dim is not implemented");
        }

        let last_dim = layout.dims()[layout.shape().rank() - 1];
        let elem_count = layout.shape().elem_count();
        let output = device.new_buffer(elem_count, storage.dtype(), "softmax")?;
        candle_metal_kernels::call_last_softmax(
            device.metal_device(),
            &encoder,
            kernels,
            name,
            elem_count,
            last_dim,
            storage.buffer(),
            layout.start_offset() * storage.dtype().size_in_bytes(),
            &output,
        )
        .map_err(candle::Error::wrap)?;
        let newstorage =
            candle::MetalStorage::new(output, device.clone(), elem_count, storage.dtype());
        Ok((newstorage, layout.shape().clone()))
    }
}
428
/// Softmax over the last dimension via the fused [`SoftmaxLastDim`] custom op.
/// No backward pass is registered for this op.
pub fn softmax_last_dim(xs: &Tensor) -> Result<Tensor> {
    xs.apply_op1_no_bwd(&SoftmaxLastDim)
}
432
/// Custom-op implementation of RMS normalization over the last dimension.
#[derive(Debug, Clone)]
struct RmsNorm {
    // Added to the mean of squares before the square root, for numerical
    // stability.
    eps: f32,
}
437
impl candle::CustomOp2 for RmsNorm {
    // Op identifier used in error messages and traces.
    fn name(&self) -> &'static str {
        "rms-norm"
    }

    // CPU forward: (s1, l1) is the input, (s2, l2) the scale vector `alpha`.
    // Rows along the last dim are processed in parallel with rayon; the mean
    // of squares is accumulated in f32 even for half-precision inputs.
    fn cpu_fwd(
        &self,
        s1: &CpuStorage,
        l1: &Layout,
        s2: &CpuStorage,
        l2: &Layout,
    ) -> Result<(CpuStorage, Shape)> {
        use candle::backend::BackendStorage;

        let eps = self.eps;
        fn inner<
            T: candle::WithDType
                + num_traits::Float
                + num_traits::AsPrimitive<f32>
                + num_traits::FromPrimitive,
        >(
            src: &[T],
            layout: &Layout,
            alpha: &[T],
            alpha_layout: &Layout,
            eps: f32,
        ) -> Result<(CpuStorage, Shape)> {
            let src = match layout.contiguous_offsets() {
                None => candle::bail!("input has to be contiguous"),
                Some((o1, o2)) => &src[o1..o2],
            };
            let alpha = match alpha_layout.contiguous_offsets() {
                None => candle::bail!("alpha has to be contiguous"),
                Some((o1, o2)) => &alpha[o1..o2],
            };
            let el_count = layout.shape().elem_count();
            let dims = layout.shape().dims();
            let dim_m1 = dims[dims.len() - 1];
            let mut dst = vec![T::zero(); el_count];
            src.par_chunks(dim_m1)
                .zip(dst.par_chunks_mut(dim_m1))
                .for_each(|(src, dst)| {
                    // Sum of squares in f32 for accuracy.
                    let sum2 = src
                        .iter()
                        .map(|&v| {
                            let v = v.as_();
                            v * v
                        })
                        .sum::<f32>();
                    // RMS denominator: sqrt(mean(x^2) + eps), converted back
                    // to T (NaN if out of range for T).
                    let m = (sum2 / dim_m1 as f32 + eps).sqrt();
                    let m = T::from_f32(m).unwrap_or_else(T::nan);
                    for ((d, s), alpha) in dst.iter_mut().zip(src.iter()).zip(alpha) {
                        *d = *s / m * *alpha
                    }
                });
            let storage = candle::WithDType::to_cpu_storage_owned(dst);
            Ok((storage, Shape::from_dims(dims)))
        }

        use CpuStorage as C;
        match (s1, s2) {
            (C::BF16(s1), C::BF16(s2)) => inner::<half::bf16>(s1, l1, s2, l2, eps),
            (C::F16(s1), C::F16(s2)) => inner::<half::f16>(s1, l1, s2, l2, eps),
            (C::F32(s1), C::F32(s2)) => inner::<f32>(s1, l1, s2, l2, eps),
            _ => candle::bail!("unsupported dtype for rmsnorm {:?}", s1.dtype()),
        }
    }

    // CUDA forward: launches the "rmsnorm" kernel from the REDUCE module with
    // one block per row.
    #[cfg(feature = "cuda")]
    fn cuda_fwd(
        &self,
        s1: &candle::CudaStorage,
        l1: &Layout,
        s2: &candle::CudaStorage,
        l2: &Layout,
    ) -> Result<(candle::CudaStorage, Shape)> {
        use candle::cuda_backend::cudarc::driver::{
            CudaSlice, DeviceRepr, LaunchConfig, PushKernelArg,
        };
        use candle::cuda_backend::{kernel_name, kernels, Map2, WrapErr};
        use candle::{CudaDevice, WithDType};

        struct S {
            eps: f32,
        }
        impl Map2 for S {
            fn f<T: DeviceRepr + WithDType>(
                &self,
                src: &CudaSlice<T>,
                layout: &Layout,
                alpha: &CudaSlice<T>,
                alpha_layout: &Layout,
                dev: &CudaDevice,
            ) -> Result<CudaSlice<T>> {
                let src = match layout.contiguous_offsets() {
                    None => candle::bail!("input has to be contiguous"),
                    Some((o1, o2)) => src.slice(o1..o2),
                };
                let alpha = match alpha_layout.contiguous_offsets() {
                    None => candle::bail!("alpha has to be contiguous"),
                    Some((o1, o2)) => alpha.slice(o1..o2),
                };
                let el = layout.shape().elem_count();
                let dims = layout.shape().dims();
                let dim_m1 = dims[dims.len() - 1];
                let (n_rows, n_cols) = (el / dim_m1, dim_m1);

                // Small rows use a 32-thread block; wide rows a full 1024.
                let block_size = if n_cols < 1024 { 32 } else { 1024 };
                let cfg = LaunchConfig {
                    grid_dim: (n_rows as u32, 1, 1),
                    block_dim: (block_size, 1, 1),
                    shared_mem_bytes: 0,
                };
                let func = dev.get_or_load_func(&kernel_name::<T>("rmsnorm"), &kernels::REDUCE)?;
                // SAFETY: Set later by running the kernel.
                let dst = unsafe { dev.alloc::<T>(el)? };
                let mut builder = func.builder();
                builder.arg(&src);
                builder.arg(&dst);
                builder.arg(&alpha);
                candle::builder_arg!(builder, n_cols as i32, block_size as i32, self.eps);
                // SAFETY: ffi.
                unsafe { builder.launch(cfg) }.w()?;
                Ok(dst)
            }
        }

        use candle::backend::BackendStorage;
        let dev = s1.device();
        let slice = S { eps: self.eps }.map(&s1.slice, l1, &s2.slice, l2, dev)?;
        let dst = candle::cuda_backend::CudaStorage {
            slice,
            device: dev.clone(),
        };
        Ok((dst, l1.shape().clone()))
    }

    // Metal forward: dispatches the dtype-specific rmsnorm kernel; both input
    // and alpha must be contiguous and share the same dtype.
    #[cfg(feature = "metal")]
    fn metal_fwd(
        &self,
        s1: &candle::MetalStorage,
        l1: &Layout,
        s2: &candle::MetalStorage,
        l2: &Layout,
    ) -> Result<(candle::MetalStorage, Shape)> {
        use candle::backend::BackendStorage;
        let device = s1.device();
        let encoder = device.command_encoder()?;
        encoder.set_label("rmsnorm");
        let kernels = device.kernels();
        let name = match (s1.dtype(), s2.dtype()) {
            (DType::F32, DType::F32) => "rmsnorm_f32",
            (DType::F16, DType::F16) => "rmsnorm_f16",
            (DType::BF16, DType::BF16) => "rmsnorm_bf16",
            (dt1, dt2) => candle::bail!("rmsnorm is not implemented for {dt1:?} {dt2:?}"),
        };

        if !(l1.is_contiguous() && l2.is_contiguous()) {
            candle::bail!("Non contiguous rmsnorm is not implemented");
        }

        let last_dim = l1.dims()[l1.shape().rank() - 1];
        let elem_count = l1.shape().elem_count();
        let output = device.new_buffer(elem_count, s1.dtype(), "rmsnorm")?;
        candle_metal_kernels::call_rms_norm(
            device.metal_device(),
            &encoder,
            kernels,
            name,
            elem_count,
            last_dim,
            self.eps,
            s1.buffer(),
            l1.start_offset() * s1.dtype().size_in_bytes(),
            s2.buffer(),
            l2.start_offset() * s2.dtype().size_in_bytes(),
            &output,
        )
        .map_err(candle::Error::wrap)?;
        let newstorage = candle::MetalStorage::new(output, device.clone(), elem_count, s1.dtype());
        Ok((newstorage, l1.shape().clone()))
    }
}
621
622pub fn rms_norm_slow(x: &Tensor, alpha: &Tensor, eps: f32) -> Result<Tensor> {
623    let x_dtype = x.dtype();
624    let internal_dtype = match x_dtype {
625        DType::F16 | DType::BF16 => DType::F32,
626        d => d,
627    };
628    let hidden_size = x.dim(D::Minus1)?;
629    let x = x.to_dtype(internal_dtype)?;
630    let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
631    let x_normed = x.broadcast_div(&(norm_x + eps as f64)?.sqrt()?)?;
632    x_normed.to_dtype(x_dtype)?.broadcast_mul(alpha)
633}
634
635pub fn rms_norm(xs: &Tensor, alpha: &Tensor, eps: f32) -> Result<Tensor> {
636    let hidden_size_xs = xs.dim(D::Minus1)?;
637    let hidden_size_alpha = alpha.dims1()?;
638    if hidden_size_xs != hidden_size_alpha {
639        candle::bail!(
640            "shape mismatch in rms-norm {:?} {:?}",
641            xs.shape(),
642            alpha.shape()
643        )
644    }
645    xs.apply_op2_no_bwd(alpha, &RmsNorm { eps })
646}
647
/// Custom-op implementation of layer normalization over the last dimension.
#[derive(Debug, Clone)]
struct LayerNorm {
    // Added to the variance before the square root, for numerical stability.
    eps: f32,
}
652
653impl candle::CustomOp3 for LayerNorm {
    // Op identifier used in error messages and traces.
    fn name(&self) -> &'static str {
        "layer-norm"
    }
657
658    fn cpu_fwd(
659        &self,
660        s1: &CpuStorage,
661        l1: &Layout,
662        s2: &CpuStorage,
663        l2: &Layout,
664        s3: &CpuStorage,
665        l3: &Layout,
666    ) -> Result<(CpuStorage, Shape)> {
667        use candle::backend::BackendStorage;
668
669        let eps = self.eps;
670        fn inner<
671            T: candle::WithDType
672                + num_traits::Float
673                + num_traits::AsPrimitive<f32>
674                + num_traits::FromPrimitive,
675        >(
676            src: &[T],
677            layout: &Layout,
678            alpha: &[T],
679            alpha_layout: &Layout,
680            beta: &[T],
681            beta_layout: &Layout,
682            eps: f32,
683        ) -> Result<(CpuStorage, Shape)> {
684            let src = match layout.contiguous_offsets() {
685                None => candle::bail!("input has to be contiguous"),
686                Some((o1, o2)) => &src[o1..o2],
687            };
688            let alpha = match alpha_layout.contiguous_offsets() {
689                None => candle::bail!("alpha has to be contiguous"),
690                Some((o1, o2)) => &alpha[o1..o2],
691            };
692            let beta = match beta_layout.contiguous_offsets() {
693                None => candle::bail!("beta has to be contiguous"),
694                Some((o1, o2)) => &beta[o1..o2],
695            };
696            let el_count = layout.shape().elem_count();
697            let dims = layout.shape().dims();
698            let dim_m1 = dims[dims.len() - 1];
699            let mut dst = vec![T::zero(); el_count];
700            src.par_chunks(dim_m1)
701                .zip(dst.par_chunks_mut(dim_m1))
702                .for_each(|(src, dst)| {
703                    let mut sum = 0f32;
704                    let mut sum2 = 0f32;
705                    for v in src {
706                        let v = v.as_();
707                        sum += v;
708                        sum2 += v * v;
709                    }
710                    let mean = sum / dim_m1 as f32;
711                    let var = sum2 / dim_m1 as f32 - mean * mean;
712                    let inv_std = (var + eps).sqrt().recip();
713                    for ((d, s), (alpha, beta)) in
714                        dst.iter_mut().zip(src.iter()).zip(alpha.iter().zip(beta))
715                    {
716                        let alpha = alpha.as_();
717                        let beta = beta.as_();
718                        let d_ = (s.as_() - mean) * inv_std * alpha + beta;
719                        *d = T::from_f32(d_).unwrap_or_else(T::nan);
720                    }
721                });
722            let storage = candle::WithDType::to_cpu_storage_owned(dst);
723            Ok((storage, Shape::from_dims(dims)))
724        }
725
726        use CpuStorage as C;
727        match (s1, s2, s3) {
728            (C::BF16(s1), C::BF16(s2), C::BF16(s3)) => {
729                inner::<half::bf16>(s1, l1, s2, l2, s3, l3, eps)
730            }
731            (C::F16(s1), C::F16(s2), C::F16(s3)) => inner::<half::f16>(s1, l1, s2, l2, s3, l3, eps),
732            (C::F32(s1), C::F32(s2), C::F32(s3)) => inner::<f32>(s1, l1, s2, l2, s3, l3, eps),
733            _ => candle::bail!("unsupported dtype for rmsnorm {:?}", s1.dtype()),
734        }
735    }
736
    // CUDA forward: launches the "layernorm" kernel from the REDUCE module
    // with one block per row; alpha (scale) and beta (shift) are passed as
    // additional contiguous buffers.
    #[cfg(feature = "cuda")]
    fn cuda_fwd(
        &self,
        s1: &candle::CudaStorage,
        l1: &Layout,
        s2: &candle::CudaStorage,
        l2: &Layout,
        s3: &candle::CudaStorage,
        l3: &Layout,
    ) -> Result<(candle::CudaStorage, Shape)> {
        use candle::cuda_backend::cudarc::driver::{
            CudaSlice, DeviceRepr, LaunchConfig, PushKernelArg,
        };
        use candle::cuda_backend::{kernel_name, kernels, Map3, WrapErr};
        use candle::{CudaDevice, WithDType};

        struct S {
            eps: f32,
        }
        impl Map3 for S {
            fn f<T: DeviceRepr + WithDType>(
                &self,
                src: &CudaSlice<T>,
                layout: &Layout,
                alpha: &CudaSlice<T>,
                alpha_layout: &Layout,
                beta: &CudaSlice<T>,
                beta_layout: &Layout,
                dev: &CudaDevice,
            ) -> Result<CudaSlice<T>> {
                let src = match layout.contiguous_offsets() {
                    None => candle::bail!("input has to be contiguous"),
                    Some((o1, o2)) => src.slice(o1..o2),
                };
                let alpha = match alpha_layout.contiguous_offsets() {
                    None => candle::bail!("alpha has to be contiguous"),
                    Some((o1, o2)) => alpha.slice(o1..o2),
                };
                let beta = match beta_layout.contiguous_offsets() {
                    None => candle::bail!("beta has to be contiguous"),
                    Some((o1, o2)) => beta.slice(o1..o2),
                };
                let el = layout.shape().elem_count();
                let dims = layout.shape().dims();
                let dim_m1 = dims[dims.len() - 1];
                let (n_rows, n_cols) = (el / dim_m1, dim_m1);

                // Small rows use a 32-thread block; wide rows a full 1024.
                let block_size = if n_cols < 1024 { 32 } else { 1024 };
                let cfg = LaunchConfig {
                    grid_dim: (n_rows as u32, 1, 1),
                    block_dim: (block_size, 1, 1),
                    shared_mem_bytes: 0,
                };
                let func =
                    dev.get_or_load_func(&kernel_name::<T>("layernorm"), &kernels::REDUCE)?;
                // SAFETY: Set later by running the kernel.
                let dst = unsafe { dev.alloc::<T>(el)? };
                let mut builder = func.builder();
                builder.arg(&src);
                builder.arg(&dst);
                builder.arg(&alpha);
                builder.arg(&beta);
                candle::builder_arg!(builder, n_cols as i32, block_size as i32, self.eps);
                // SAFETY: ffi.
                unsafe { builder.launch(cfg) }.w()?;
                Ok(dst)
            }
        }

        use candle::backend::BackendStorage;
        let dev = s1.device();
        let slice = S { eps: self.eps }.map(&s1.slice, l1, &s2.slice, l2, &s3.slice, l3, dev)?;
        let dst = candle::cuda_backend::CudaStorage {
            slice,
            device: dev.clone(),
        };
        Ok((dst, l1.shape().clone()))
    }
815
    // Metal forward pass for the fused layer-norm op.
    //
    // Arguments follow the CustomOp3 convention:
    //   s1/l1: input tensor storage + layout,
    //   s2/l2: per-channel scale (alpha),
    //   s3/l3: per-channel shift (beta).
    // Normalization happens over the last dimension; the output has the same
    // shape and dtype as the input.
    #[cfg(feature = "metal")]
    fn metal_fwd(
        &self,
        s1: &candle::MetalStorage,
        l1: &Layout,
        s2: &candle::MetalStorage,
        l2: &Layout,
        s3: &candle::MetalStorage,
        l3: &Layout,
    ) -> Result<(candle::MetalStorage, Shape)> {
        use candle::backend::BackendStorage;
        let device = s1.device();
        let encoder = device.command_encoder()?;
        encoder.set_label("layernorm");
        let kernels = device.kernels();
        // The kernel is specialized per dtype; all three inputs must share one
        // of f32/f16/bf16.
        let name = match (s1.dtype(), s2.dtype(), s3.dtype()) {
            (DType::F32, DType::F32, DType::F32) => "layernorm_f32",
            (DType::F16, DType::F16, DType::F16) => "layernorm_f16",
            (DType::BF16, DType::BF16, DType::BF16) => "layernorm_bf16",
            (dt1, dt2, dt3) => {
                candle::bail!("layernorm is not implemented for {dt1:?} {dt2:?} {dt3:?}")
            }
        };

        // The kernel indexes buffers directly and only handles contiguous data.
        if !(l1.is_contiguous() && l2.is_contiguous() && l3.is_contiguous()) {
            candle::bail!("Non contiguous layernorm is not implemented");
        }

        let last_dim = l1.dims()[l1.shape().rank() - 1];
        let elem_count = l1.shape().elem_count();
        let output = device.new_buffer(elem_count, s1.dtype(), "layernorm")?;
        // Buffer offsets are passed in bytes, hence the size_in_bytes scaling.
        candle_metal_kernels::call_layer_norm(
            device.metal_device(),
            &encoder,
            kernels,
            name,
            elem_count,
            last_dim,
            self.eps,
            s1.buffer(),
            l1.start_offset() * s1.dtype().size_in_bytes(),
            s2.buffer(),
            l2.start_offset() * s2.dtype().size_in_bytes(),
            s3.buffer(),
            l3.start_offset() * s3.dtype().size_in_bytes(),
            &output,
        )
        .map_err(candle::Error::wrap)?;
        let newstorage = candle::MetalStorage::new(output, device.clone(), elem_count, s1.dtype());
        Ok((newstorage, l1.shape().clone()))
    }
867}
868
869pub fn layer_norm_slow(x: &Tensor, alpha: &Tensor, beta: &Tensor, eps: f32) -> Result<Tensor> {
870    let x_dtype = x.dtype();
871    let internal_dtype = match x_dtype {
872        DType::F16 | DType::BF16 => DType::F32,
873        d => d,
874    };
875    let hidden_size = x.dim(D::Minus1)?;
876    let x = x.to_dtype(internal_dtype)?;
877    let x = {
878        let mean_x = (x.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
879        x.broadcast_sub(&mean_x)?
880    };
881    let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
882    let x_normed = x.broadcast_div(&(norm_x + eps as f64)?.sqrt()?)?;
883    x_normed
884        .to_dtype(x_dtype)?
885        .broadcast_mul(alpha)?
886        .broadcast_add(beta)
887}
888
889pub fn layer_norm(xs: &Tensor, alpha: &Tensor, beta: &Tensor, eps: f32) -> Result<Tensor> {
890    let hidden_size_xs = xs.dim(D::Minus1)?;
891    let hidden_size_alpha = alpha.dims1()?;
892    let hidden_size_beta = beta.dims1()?;
893    if hidden_size_xs != hidden_size_alpha || hidden_size_xs != hidden_size_beta {
894        candle::bail!(
895            "shape mismatch in layer-norm src: {:?} alpha: {:?} beta: {:?}",
896            xs.shape(),
897            alpha.shape(),
898            beta.shape()
899        )
900    }
901    xs.apply_op3_no_bwd(alpha, beta, &LayerNorm { eps })
902}
903
904// https://pytorch.org/docs/stable/generated/torch.nn.PixelShuffle.html
905pub fn pixel_shuffle(xs: &Tensor, upscale_factor: usize) -> Result<Tensor> {
906    let (b_size, c, h, w) = xs.dims4()?;
907    let out_c = c / upscale_factor / upscale_factor;
908    xs.reshape((b_size, out_c, upscale_factor, upscale_factor, h, w))?
909        .permute((0, 1, 4, 2, 5, 3))?
910        .reshape((b_size, out_c, h * upscale_factor, w * upscale_factor))
911}
912
913pub fn pixel_unshuffle(xs: &Tensor, downscale_factor: usize) -> Result<Tensor> {
914    let (b_size, c, h, w) = xs.dims4()?;
915    let out_c = c * downscale_factor * downscale_factor;
916    xs.reshape((
917        b_size,
918        c,
919        h / downscale_factor,
920        downscale_factor,
921        w / downscale_factor,
922        downscale_factor,
923    ))?
924    .permute((0, 1, 3, 5, 2, 4))?
925    .reshape((b_size, out_c, h / downscale_factor, w / downscale_factor))
926}
927
928// https://pytorch.org/docs/stable/generated/torch.nn.ReplicationPad2d.html
929pub fn replication_pad2d(xs: &Tensor, pad: usize) -> Result<Tensor> {
930    match pad {
931        0 => Ok(xs.clone()),
932        1 => {
933            let (_b_size, _c, h, w) = xs.dims4()?;
934            let (first, last) = (xs.narrow(3, 0, 1)?, xs.narrow(3, w - 1, 1)?);
935            let xs = Tensor::cat(&[&first, xs, &last], 3)?;
936            let (first, last) = (xs.narrow(2, 0, 1)?, xs.narrow(2, h - 1, 1)?);
937            Tensor::cat(&[&first, &xs, &last], 2)
938        }
939        n => candle::bail!("replication-pad with a size of {n} is not supported"),
940    }
941}
942
/// A no-op module that returns its input unchanged, useful as a placeholder
/// layer in model definitions.
#[derive(Clone, Debug)]
pub struct Identity;
945
946impl Identity {
947    pub fn new() -> Identity {
948        Self
949    }
950}
951
952impl Default for Identity {
953    fn default() -> Self {
954        Self
955    }
956}
957
958impl Module for Identity {
959    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
960        Ok(xs.clone())
961    }
962}
963
// Parameters for the fused scaled-dot-product-attention custom op, see [`sdpa`].
#[allow(dead_code)]
struct Sdpa {
    // Scaling applied to q.k^T before the softmax.
    scale: f32,
    // Softcapping value; 1.0 disables softcapping.
    softcapping: f32,
    // Optional attention mask of shape (bs, qhead, seq, kv_seq).
    mask: Option<Tensor>,
    // Whether to apply causal masking inside the kernel.
    do_causal: bool,
}
971
972impl candle::CustomOp3 for Sdpa {
973    fn name(&self) -> &'static str {
974        "metal-sdpa"
975    }
976
977    fn cpu_fwd(
978        &self,
979        _s1: &CpuStorage,
980        _l1: &Layout,
981        _s2: &CpuStorage,
982        _l2: &Layout,
983        _s3: &CpuStorage,
984        _l3: &Layout,
985    ) -> Result<(CpuStorage, Shape)> {
986        candle::bail!("SDPA has no cpu impl")
987    }
988
989    #[cfg(feature = "metal")]
990    fn metal_fwd(
991        &self,
992        q: &candle::MetalStorage,
993        q_l: &Layout,
994        k: &candle::MetalStorage,
995        k_l: &Layout,
996        v: &candle::MetalStorage,
997        v_l: &Layout,
998    ) -> Result<(candle::MetalStorage, Shape)> {
999        use candle::backend::BackendStorage;
1000        use candle_metal_kernels::SdpaDType;
1001
1002        let device = q.device();
1003
1004        let out_dims = vec![q_l.dim(0)?, q_l.dim(1)?, q_l.dim(2)?, v_l.dim(3)?];
1005        let elem_count: usize = out_dims.iter().product();
1006        let out_shape = Shape::from_dims(&out_dims);
1007        let out_layout = Layout::contiguous(out_shape.clone());
1008
1009        let output = device.new_buffer(elem_count, q.dtype(), "sdpa_o")?;
1010
1011        // q,k must have matching emb dim
1012        if q_l.dim(D::Minus1)? != k_l.dim(D::Minus1)? {
1013            candle::bail!("`q` and `k` last dims must match");
1014        }
1015
1016        // k,v must have matching n kv heads
1017        if v_l.dim(D::Minus(3))? != k_l.dim(D::Minus(3))? {
1018            candle::bail!("`k` and `v` head dims must match");
1019        }
1020
1021        // n_heads % n_kv_heads == 0; n_heads >= 1, n_kv_heads >= 1.
1022        if q_l.dim(D::Minus(3))? % k_l.dim(D::Minus(3))? != 0 {
1023            candle::bail!("query `n_heads` must be a multiple of `n_kv_heads`");
1024        }
1025
1026        let k_head = k_l.dim(D::Minus1)?;
1027        let q_head = q_l.dim(D::Minus1)?;
1028        let q_seq = q_l.dim(2)?;
1029        let k_seq = k_l.dim(2)?;
1030
1031        let mut implementation_supports_use_case = q_head == k_head;
1032        let supported_head_dim = q_head == 32
1033            || q_head == 64
1034            || q_head == 72
1035            || q_head == 80
1036            || q_head == 96
1037            || q_head == 128
1038            || q_head == 256;
1039
1040        let supports_sdpa_full_mask = self.mask.is_none() || q_seq <= k_seq;
1041        let supports_sdpa_full = q_seq > 8 && supported_head_dim && supports_sdpa_full_mask;
1042        let supports_sdpa_vector = q_seq <= 8 && supported_head_dim && q_seq <= k_seq;
1043
1044        implementation_supports_use_case &= supports_sdpa_full || supports_sdpa_vector;
1045
1046        if !supported_head_dim {
1047            candle::bail!(
1048                "Meta SDPA does not support q head dim {q_head}: q dims {:?}, k dims {:?}, v dims {:?}.",
1049                q_l.dims(),
1050                k_l.dims(),
1051                v_l.dims()
1052            );
1053        }
1054        if !implementation_supports_use_case {
1055            candle::bail!(
1056                "Meta SDPA does not support q dims {:?}, k dims {:?}, v dims {:?}.",
1057                q_l.dims(),
1058                k_l.dims(),
1059                v_l.dims()
1060            );
1061        }
1062
1063        for t in [k.dtype(), v.dtype()] {
1064            if q.dtype() != t {
1065                candle::bail!("all q, k, v dtypes must match.");
1066            }
1067        }
1068
1069        let itype = match q.dtype() {
1070            DType::BF16 => SdpaDType::BF16,
1071            DType::F16 => SdpaDType::F16,
1072            DType::F32 => SdpaDType::F32,
1073            other => candle::bail!("unsupported sdpa type {other:?}"),
1074        };
1075
1076        let encoder = q.device().command_encoder()?;
1077        if supports_sdpa_vector {
1078            // Route to the 2 pass fused attention if the k seqlen is large.
1079            // https://github.com/ml-explore/mlx/pull/1597
1080            const TWO_PASS_K_THRESHOLD: usize = 1024;
1081            if k_seq >= TWO_PASS_K_THRESHOLD {
1082                let mut intermediate_shape = [
1083                    &out_dims[0..out_dims.len() - 2],
1084                    &[candle_metal_kernels::SDPA_2PASS_BLOCKS],
1085                    &[out_dims[out_dims.len() - 1]],
1086                ]
1087                .concat();
1088                let intermediate = device.new_buffer(
1089                    intermediate_shape.iter().product::<usize>(),
1090                    DType::F32,
1091                    "sdpa_2pass_intermediate",
1092                )?;
1093                let _ = intermediate_shape.pop().unwrap();
1094                let sums = device.new_buffer(
1095                    intermediate_shape.iter().product::<usize>(),
1096                    DType::F32,
1097                    "sdpa_2pass_sums",
1098                )?;
1099                let maxs = device.new_buffer(
1100                    intermediate_shape.iter().product::<usize>(),
1101                    DType::F32,
1102                    "sdpa_2pass_maxs",
1103                )?;
1104
1105                encoder.set_label("vector_attention");
1106                candle_metal_kernels::call_sdpa_vector_2pass(
1107                    q.device().device(),
1108                    &encoder,
1109                    q.device().kernels(),
1110                    q_l.start_offset(),
1111                    q_l.dims(),
1112                    q.buffer(),
1113                    k_l.start_offset(),
1114                    k_l.dims(),
1115                    k_l.stride(),
1116                    k.buffer(),
1117                    v_l.start_offset(),
1118                    v_l.stride(),
1119                    v.buffer(),
1120                    &output,
1121                    &intermediate,
1122                    &sums,
1123                    &maxs,
1124                    self.scale,
1125                    self.softcapping,
1126                    itype,
1127                )
1128                .map_err(candle::Error::wrap)?;
1129            } else {
1130                encoder.set_label("vector_attention");
1131                candle_metal_kernels::call_sdpa_vector(
1132                    q.device().device(),
1133                    &encoder,
1134                    q.device().kernels(),
1135                    q_l.start_offset(),
1136                    q_l.dims(),
1137                    q.buffer(),
1138                    k_l.start_offset(),
1139                    k_l.dims(),
1140                    k_l.stride(),
1141                    k.buffer(),
1142                    v_l.start_offset(),
1143                    v_l.stride(),
1144                    v.buffer(),
1145                    &output,
1146                    self.scale,
1147                    self.softcapping,
1148                    itype,
1149                )
1150                .map_err(candle::Error::wrap)?;
1151            }
1152        } else if supports_sdpa_full {
1153            encoder.set_label("full_attention");
1154            if self.softcapping != 1. {
1155                candle::bail!("SDPA full requires softcapping to be disabled (1.0)");
1156            }
1157
1158            let mask_s_l = self.mask.as_ref().map(|m| m.storage_and_layout());
1159
1160            let (mask_type, mask_buffer, mask_strides) = if let Some(mask) = &self.mask {
1161                let (mask_s, mask_l) = mask_s_l.as_ref().unwrap();
1162
1163                let mask_buffer = match &**mask_s {
1164                    candle::Storage::Metal(m) => m.buffer(),
1165                    _ => candle::bail!("Expected metal device for mask"),
1166                };
1167
1168                let mask_type = match mask.dtype() {
1169                    DType::BF16 => SdpaDType::BF16,
1170                    DType::F16 => SdpaDType::F16,
1171                    DType::F32 => SdpaDType::F32,
1172                    other => candle::bail!("unsupported sdpa type {other:?}"),
1173                };
1174                if mask_type != itype {
1175                    candle::bail!("Mask type {mask_type:?} must match q type {itype:?}");
1176                }
1177
1178                if mask_l.dims() != [q_l.dim(0)?, q_l.dim(1)?, q_l.dim(2)?, k_seq] {
1179                    candle::bail!(
1180                        "Mask shape must be {:?} (bs, qheads, qseq, kseq), got {:?}",
1181                        [q_l.dim(0)?, q_head, q_l.dim(2)?, k_seq],
1182                        mask_l.dims()
1183                    );
1184                }
1185
1186                (
1187                    Some(mask_type),
1188                    Some(mask_buffer),
1189                    Some(mask_l.stride().to_vec()),
1190                )
1191            } else {
1192                (None, None, None)
1193            };
1194
1195            candle_metal_kernels::call_sdpa_full(
1196                q.device().device(),
1197                &encoder,
1198                q.device().kernels(),
1199                q_l.start_offset(),
1200                q_l.dims(),
1201                q_l.stride(),
1202                q.buffer(),
1203                k_l.start_offset(),
1204                k_l.dims(),
1205                k_l.stride(),
1206                k.buffer(),
1207                v_l.start_offset(),
1208                v.buffer(),
1209                v_l.stride(),
1210                mask_type,
1211                mask_buffer,
1212                mask_strides.as_deref(),
1213                &output,
1214                out_layout.stride(),
1215                self.scale,
1216                self.do_causal,
1217                itype,
1218            )
1219            .map_err(candle::Error::wrap)?;
1220        } else {
1221            candle::bail!("must be vector or full sdpa kernel");
1222        }
1223
1224        let newstorage = candle::MetalStorage::new(output, device.clone(), elem_count, q.dtype());
1225        Ok((newstorage, out_shape))
1226    }
1227}
1228
/// Scaled dot product attention with a fused kernel.
///
/// Computes softmax(qk^T*scale)v.
///
/// **Inputs shapes:**
/// - `q`: (bs, qhead, seq, hidden)
/// - `k`: (bs, kv_head, kv_seq, hidden)
/// - `v`: (bs, kv_head, kv_seq, v_hidden)
/// - `mask`: (bs, qhead, seq, kv_seq)
/// - `do_causal`: Apply causal masking. If this is true, the mask does not need to be provided.
/// - `scale` is applied before softmax.
/// - If `softcapping` != 1.0:
///      - Computation is: softmax(tanh(qk^T*scale/cap)*cap)v
///
/// **Output shape:** (bs, qhead, seq, v_hidden)
///
/// Note: For Grouped Query Attention and Multi-Query Attention, the k and v inputs should not be pre-tiled to match q.
///
/// ## On Metal:
/// - If `seq` <= 8:
///     - Use a vectorized kernel
///     - Supports `seq` != `kv_seq` (cross attn. support)
///     - Supports GQA when `qhead` is a multiple of `kv_head`
/// - Otherwise:
///     - Masking is supported
///     - Supports `seq` != `kv_seq` (cross attn. support)
///     - Supports GQA when `qhead` is a multiple of `kv_head`
///     - Softcapping is not supported.
pub fn sdpa(
    q: &Tensor,
    k: &Tensor,
    v: &Tensor,
    mask: Option<&Tensor>,
    do_causal: bool,
    scale: f32,
    softcapping: f32,
) -> Result<Tensor> {
    q.apply_op3_no_bwd(
        k,
        v,
        &Sdpa {
            scale,
            softcapping,
            mask: mask.cloned(),
            do_causal,
        },
    )
}