candle_nn/
ops.rs

//! Tensor ops.
//!

use candle::{CpuStorage, DType, Layout, Module, Result, Shape, Tensor, D};
use rayon::prelude::*;

/// Applies the softmax function to the input tensor, rescaling the elements so that all elements
/// on a slice of fixed index on dimension `dim` are between 0 and 1 and sum to 1.
///
/// ```rust
/// use candle::{Tensor, Device, test_utils::to_vec2_round};
/// let a = Tensor::new(&[[0f32, 1., 0., 1.], [-2., 2., 3., -3.]], &Device::Cpu)?;
/// let a = candle_nn::ops::softmax(&a, 1)?;
/// assert_eq!(
///     to_vec2_round(&a, 4)?,
///     &[
///         [0.1345, 0.3655, 0.1345, 0.3655],
///         [0.0049, 0.2671, 0.7262, 0.0018]
///     ]);
/// # Ok::<(), candle::Error>(())
/// ```
pub fn softmax<D: candle::shape::Dim>(xs: &Tensor, dim: D) -> Result<Tensor> {
    let dim = dim.to_index(xs.shape(), "softmax")?;
    let max = xs.max_keepdim(dim)?;
    let diff = xs.broadcast_sub(&max)?;
    let num = diff.exp()?;
    let den = num.sum_keepdim(dim)?;
    num.broadcast_div(&den)
}

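/// Computes log(softmax(xs)) along dimension `d` in a numerically stable way,
/// i.e. `x - max(x) - log(sum(exp(x - max(x))))` over each slice of `d`.
///
/// A minimal sketch with illustrative values: exponentiating the result
/// recovers the softmax values.
///
/// ```rust
/// use candle::{Tensor, Device, test_utils::to_vec2_round};
/// let a = Tensor::new(&[[0f32, 1., 0., 1.], [-2., 2., 3., -3.]], &Device::Cpu)?;
/// let log_sm = candle_nn::ops::log_softmax(&a, 1)?;
/// let sm = candle_nn::ops::softmax(&a, 1)?;
/// assert_eq!(to_vec2_round(&log_sm.exp()?, 4)?, to_vec2_round(&sm, 4)?);
/// # Ok::<(), candle::Error>(())
/// ```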
pub fn log_softmax<D: candle::shape::Dim>(xs: &Tensor, d: D) -> Result<Tensor> {
    let d = d.to_index(xs.shape(), "log-softmax")?;
    let max = xs.max_keepdim(d)?;
    let diff = xs.broadcast_sub(&max)?;
    let sum_exp = diff.exp()?.sum_keepdim(d)?;
    let log_sm = diff.broadcast_sub(&sum_exp.log()?)?;
    Ok(log_sm)
}

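/// SiLU (swish) activation, `x * sigmoid(x)`, deferring to the tensor method.
///
/// A tiny sketch with an illustrative value: `silu(0) == 0`.
///
/// ```rust
/// use candle::{Tensor, Device, test_utils::to_vec1_round};
/// let xs = Tensor::new(&[0f32], &Device::Cpu)?;
/// assert_eq!(to_vec1_round(&candle_nn::ops::silu(&xs)?, 4)?, &[0.0]);
/// # Ok::<(), candle::Error>(())
/// ```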
pub fn silu(xs: &Tensor) -> Result<Tensor> {
    xs.silu()
}

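/// SwiGLU activation: splits the last dimension into two halves `a` and `b`
/// and returns `silu(a) * b`, halving the last dimension of the input.
///
/// A shape-oriented sketch with illustrative values:
///
/// ```rust
/// use candle::{Tensor, DType, Device};
/// let xs = Tensor::ones((2, 8), DType::F32, &Device::Cpu)?;
/// let ys = candle_nn::ops::swiglu(&xs)?;
/// assert_eq!(ys.dims(), &[2, 4]);
/// # Ok::<(), candle::Error>(())
/// ```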
pub fn swiglu(xs: &Tensor) -> Result<Tensor> {
    let xs = xs.chunk(2, D::Minus1)?;
    &xs[0].silu()? * &xs[1]
}

struct Sigmoid;

impl candle::CustomOp1 for Sigmoid {
    fn name(&self) -> &'static str {
        "sigmoid"
    }

    fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)> {
        use candle::backend::BackendStorage;

        fn fwd<T: num_traits::Float>(v: T) -> T {
            (v.neg().exp() + T::one()).recip()
        }

        // FIXME: using `candle::map_dtype` causes compilation errors.
        let storage = match storage {
            CpuStorage::BF16(slice) => {
                CpuStorage::BF16(candle::cpu_backend::unary_map(slice, layout, fwd))
            }
            CpuStorage::F16(slice) => {
                CpuStorage::F16(candle::cpu_backend::unary_map(slice, layout, fwd))
            }
            CpuStorage::F32(slice) => {
                CpuStorage::F32(candle::cpu_backend::unary_map(slice, layout, fwd))
            }
            CpuStorage::F64(slice) => {
                CpuStorage::F64(candle::cpu_backend::unary_map(slice, layout, fwd))
            }
            _ => Err(candle::Error::UnsupportedDTypeForOp(
                storage.dtype(),
                self.name(),
            ))?,
        };
        Ok((storage, layout.shape().clone()))
    }

    #[cfg(feature = "cuda")]
    fn cuda_fwd(
        &self,
        storage: &candle::CudaStorage,
        layout: &Layout,
    ) -> Result<(candle::CudaStorage, Shape)> {
        use candle::backend::BackendStorage;
        use candle::cuda_backend::cudarc::driver::{
            CudaSlice, DeviceRepr, LaunchConfig, PushKernelArg, ValidAsZeroBits,
        };
        use candle::cuda_backend::SlicePtrOrNull;
        use candle::cuda_backend::{kernel_name, kernels, Map1, WrapErr};
        use candle::{CudaDevice, WithDType};

        struct S;
        impl Map1 for S {
            fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
                &self,
                src: &CudaSlice<T>,
                dev: &CudaDevice,
                layout: &Layout,
            ) -> Result<CudaSlice<T>> {
                let shape = layout.shape();
                let dims = shape.dims();
                let el_count = shape.elem_count();
                let cfg = LaunchConfig::for_num_elems(el_count as u32);
                let ds = SlicePtrOrNull::params_from_layout(dev, layout)?;
                let src = &src.slice(layout.start_offset()..);
                let func = dev.get_or_load_func(&kernel_name::<T>("usigmoid"), &kernels::UNARY)?;
                // SAFETY: Set later by running the kernel.
                let out = unsafe { dev.alloc::<T>(el_count)? };

                let mut builder = func.builder();
                candle::builder_arg!(builder, el_count, dims.len());
                ds.builder_arg(&mut builder);
                builder.arg(src);
                builder.arg(&out);
                // SAFETY: ffi.
                unsafe { builder.launch(cfg) }.w()?;
                Ok(out)
            }
        }

        let dev = storage.device();
        let slice = S.map(&storage.slice, dev, layout)?;
        let dst = candle::CudaStorage {
            slice,
            device: dev.clone(),
        };
        Ok((dst, layout.shape().clone()))
    }

    #[cfg(feature = "metal")]
    fn metal_fwd(
        &self,
        storage: &candle::MetalStorage,
        layout: &Layout,
    ) -> Result<(candle::MetalStorage, Shape)> {
        use candle::backend::BackendStorage;
        use candle::MetalError;
        let device = storage.device();
        let dtype = storage.dtype();
        let shape = layout.shape();
        let el_count = shape.elem_count();
        let buffer = device.new_buffer(el_count, dtype, "sigmoid")?;
        let command_buffer = device.command_buffer()?;
        command_buffer.set_label("sigmoid");
        let src = candle_metal_kernels::BufferOffset {
            buffer: storage.buffer(),
            offset_in_bytes: layout.start_offset() * storage.dtype().size_in_bytes(),
        };

        match (el_count % 2, dtype, layout.is_contiguous()) {
            (0, DType::BF16 | DType::F16, true) => {
                use candle_metal_kernels::unary::contiguous_tiled;
                let kernel_name = match dtype {
                    DType::F16 => contiguous_tiled::sigmoid::HALF,
                    DType::F32 => contiguous_tiled::sigmoid::FLOAT,
                    DType::BF16 => contiguous_tiled::sigmoid::BFLOAT,
                    dtype => {
                        candle::bail!(
                            "Metal contiguous_tiled unary sigmoid {dtype:?} not implemented"
                        )
                    }
                };
                candle_metal_kernels::call_unary_contiguous_tiled(
                    device.metal_device(),
                    &command_buffer,
                    device.kernels(),
                    kernel_name,
                    el_count,
                    src,
                    &buffer,
                )
                .map_err(MetalError::from)?;
            }
            (_, _, true) => {
                use candle_metal_kernels::unary::contiguous;
                let kernel_name = match dtype {
                    DType::F16 => contiguous::sigmoid::HALF,
                    DType::F32 => contiguous::sigmoid::FLOAT,
                    DType::BF16 => contiguous::sigmoid::BFLOAT,
                    dtype => {
                        candle::bail!("Metal contiguous unary sigmoid {dtype:?} not implemented")
                    }
                };
                candle_metal_kernels::call_unary_contiguous(
                    device.metal_device(),
                    &command_buffer,
                    device.kernels(),
                    kernel_name,
                    el_count,
                    src,
                    &buffer,
                )
                .map_err(MetalError::from)?;
            }
            (_, _, false) => {
                use candle_metal_kernels::unary::strided;
                let kernel_name = match dtype {
                    DType::F16 => strided::sigmoid::HALF,
                    DType::F32 => strided::sigmoid::FLOAT,
                    DType::BF16 => strided::sigmoid::BFLOAT,
                    dtype => {
                        candle::bail!("Metal strided unary sigmoid {dtype:?} not implemented")
                    }
                };
                let dst = candle_metal_kernels::BufferOffset::zero_offset(&buffer);
                candle_metal_kernels::call_unary_strided(
                    device.metal_device(),
                    &command_buffer,
                    device.kernels(),
                    kernel_name,
                    layout.dims(),
                    src,
                    layout.stride(),
                    dst,
                )
                .map_err(MetalError::from)?;
            }
        }

        let new_storage = candle::MetalStorage::new(buffer, device.clone(), el_count, dtype);
        Ok((new_storage, layout.shape().clone()))
    }

    fn bwd(&self, _arg: &Tensor, res: &Tensor, grad_res: &Tensor) -> Result<Option<Tensor>> {
        // d/dx sigmoid(x) = (1 - sigmoid(x)) * sigmoid(x)
        let d_dx_sigmoid = res.ones_like()?.sub(res)?.mul(res)?;
        Ok(Some(grad_res.mul(&d_dx_sigmoid)?))
    }
}

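/// Element-wise sigmoid, `1 / (1 + exp(-x))`, backed by the custom op above so
/// it gets the specialized CPU/CUDA/Metal kernels and a backward pass.
///
/// A tiny sketch with an illustrative value: `sigmoid(0) == 0.5`.
///
/// ```rust
/// use candle::{Tensor, Device, test_utils::to_vec1_round};
/// let xs = Tensor::new(&[0f32], &Device::Cpu)?;
/// let ys = candle_nn::ops::sigmoid(&xs)?;
/// assert_eq!(to_vec1_round(&ys, 4)?, &[0.5]);
/// # Ok::<(), candle::Error>(())
/// ```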
pub fn sigmoid(xs: &Tensor) -> Result<Tensor> {
    xs.apply_op1(Sigmoid)
}

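/// Hard sigmoid, `clamp((x + 3) / 6, 0, 1)`, a piecewise-linear approximation
/// of the sigmoid.
///
/// A tiny sketch with illustrative values:
///
/// ```rust
/// use candle::{Tensor, Device, test_utils::to_vec1_round};
/// let xs = Tensor::new(&[-4f32, 0., 4.], &Device::Cpu)?;
/// let ys = candle_nn::ops::hard_sigmoid(&xs)?;
/// assert_eq!(to_vec1_round(&ys, 4)?, &[0.0, 0.5, 1.0]);
/// # Ok::<(), candle::Error>(())
/// ```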
pub fn hard_sigmoid(xs: &Tensor) -> Result<Tensor> {
    // TODO: Should we have a specialized op for this?
    ((xs + 3.0)? / 6.0)?.clamp(0f32, 1f32)
}

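/// Leaky ReLU, `max(0, x) + negative_slope * min(0, x)`.
///
/// A tiny sketch with illustrative values:
///
/// ```rust
/// use candle::{Tensor, Device, test_utils::to_vec1_round};
/// let xs = Tensor::new(&[-1f32, 2.], &Device::Cpu)?;
/// let ys = candle_nn::ops::leaky_relu(&xs, 0.1)?;
/// assert_eq!(to_vec1_round(&ys, 4)?, &[-0.1, 2.0]);
/// # Ok::<(), candle::Error>(())
/// ```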
pub fn leaky_relu(xs: &Tensor, negative_slope: f64) -> Result<Tensor> {
    let zeros = xs.zeros_like()?;
    xs.maximum(&zeros)? + xs.minimum(&zeros)? * negative_slope
}

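/// Inverted dropout: each element is zeroed with probability `drop_p` and the
/// surviving elements are scaled by `1 / (1 - drop_p)` so the expected value is
/// preserved.
///
/// A minimal sketch with illustrative values; with `drop_p == 0.` the input is
/// returned unchanged.
///
/// ```rust
/// use candle::{Tensor, Device, test_utils::to_vec1_round};
/// let xs = Tensor::new(&[1f32, 2., 3.], &Device::Cpu)?;
/// let ys = candle_nn::ops::dropout(&xs, 0.)?;
/// assert_eq!(to_vec1_round(&ys, 4)?, &[1.0, 2.0, 3.0]);
/// # Ok::<(), candle::Error>(())
/// ```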
pub fn dropout(xs: &Tensor, drop_p: f32) -> Result<Tensor> {
    // This implementation is inefficient as it stores the full mask for the backward pass.
    // Instead we could just store the seed and have a specialized kernel that would both
    // generate the random mask and apply it.
    // Another, easier optimization would be to generate the boolean mask using just one bit of
    // entropy per element rather than a full float per element.
    if !(0. ..1.).contains(&drop_p) {
        candle::bail!("dropout probability has to be in [0, 1), got {drop_p}")
    }
    let rand = Tensor::rand(0f32, 1f32, xs.shape(), xs.device())?;
    let scale = 1.0 / (1.0 - drop_p as f64);
    let drop_p = Tensor::new(drop_p, xs.device())?.broadcast_as(xs.shape())?;
    let mask = (rand.ge(&drop_p)?.to_dtype(xs.dtype())? * scale)?;
    xs * mask
}

#[derive(Clone, Debug)]
pub struct Dropout {
    drop_p: f32,
}

impl Dropout {
    pub fn new(drop_p: f32) -> Dropout {
        Self { drop_p }
    }

    pub fn forward(&self, xs: &Tensor, train: bool) -> Result<Tensor> {
        if train {
            dropout(xs, self.drop_p)
        } else {
            Ok(xs.clone())
        }
    }
}

impl candle::ModuleT for Dropout {
    fn forward_t(&self, xs: &Tensor, train: bool) -> Result<Tensor> {
        self.forward(xs, train)
    }
}

struct SoftmaxLastDim;

impl candle::CustomOp1 for SoftmaxLastDim {
    fn name(&self) -> &'static str {
        "softmax-last-dim"
    }

    fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)> {
        fn softmax<T: candle::WithDType + num_traits::Float>(
            src: &[T],
            layout: &Layout,
        ) -> Result<(CpuStorage, Shape)> {
            let src = match layout.contiguous_offsets() {
                None => candle::bail!("input has to be contiguous"),
                Some((o1, o2)) => &src[o1..o2],
            };
            let el_count = layout.shape().elem_count();
            let dims = layout.shape().dims();
            let dim_m1 = dims[dims.len() - 1];
            let mut dst = vec![T::zero(); el_count];
            src.par_chunks(dim_m1)
                .zip(dst.par_chunks_mut(dim_m1))
                .for_each(|(src, dst)| {
                    let mut max = T::neg_infinity();
                    unsafe { T::vec_reduce_max(src.as_ptr(), &mut max, dim_m1) };
                    for (s, d) in src.iter().zip(dst.iter_mut()) {
                        *d = (*s - max).exp();
                    }
                    let mut sum_exp = T::zero();
                    unsafe { T::vec_reduce_sum(dst.as_ptr(), &mut sum_exp, dim_m1) };
                    for d in dst.iter_mut() {
                        *d /= sum_exp
                    }
                });
            let storage = candle::WithDType::to_cpu_storage_owned(dst);
            Ok((storage, Shape::from_dims(dims)))
        }

        match storage {
            CpuStorage::BF16(slice) => softmax::<half::bf16>(slice, layout),
            CpuStorage::F16(slice) => softmax::<half::f16>(slice, layout),
            CpuStorage::F32(slice) => softmax::<f32>(slice, layout),
            CpuStorage::F64(slice) => softmax::<f64>(slice, layout),
            _ => candle::bail!("unsupported dtype for softmax {:?}", storage),
        }
    }

    #[cfg(feature = "cuda")]
    fn cuda_fwd(
        &self,
        storage: &candle::CudaStorage,
        layout: &Layout,
    ) -> Result<(candle::CudaStorage, Shape)> {
        use candle::cuda_backend::cudarc::driver::{
            CudaSlice, DeviceRepr, LaunchConfig, PushKernelArg,
        };
        use candle::cuda_backend::{kernel_name, kernels, Map1, WrapErr};
        use candle::{CudaDevice, WithDType};

        struct S;
        impl Map1 for S {
            fn f<T: DeviceRepr + WithDType>(
                &self,
                src: &CudaSlice<T>,
                dev: &CudaDevice,
                layout: &Layout,
            ) -> Result<CudaSlice<T>> {
                let src = match layout.contiguous_offsets() {
                    None => candle::bail!("input has to be contiguous"),
                    Some((o1, o2)) => src.slice(o1..o2),
                };
                let el = layout.shape().elem_count();
                let dims = layout.shape().dims();
                let dim_m1 = dims[dims.len() - 1];
                let (n_rows, n_cols) = (el / dim_m1, dim_m1);

                let cfg = LaunchConfig {
                    grid_dim: (n_rows as u32, 1, 1),
                    block_dim: (1, 32, 1),
                    shared_mem_bytes: 0,
                };
                let func = dev.get_or_load_func(&kernel_name::<T>("softmax"), &kernels::REDUCE)?;
                // SAFETY: Set later by running the kernel.
                let dst = unsafe { dev.alloc::<T>(el)? };
                let mut builder = func.builder();
                builder.arg(&src);
                builder.arg(&dst);
                candle::builder_arg!(builder, n_cols as i32);
                // SAFETY: ffi.
                unsafe { builder.launch(cfg) }.w()?;
                Ok(dst)
            }
        }

        use candle::backend::BackendStorage;
        let dev = storage.device();
        let slice = S.map(&storage.slice, dev, layout)?;
        let dst = candle::cuda_backend::CudaStorage {
            slice,
            device: dev.clone(),
        };
        Ok((dst, layout.shape().clone()))
    }

    #[cfg(feature = "metal")]
    fn metal_fwd(
        &self,
        storage: &candle::MetalStorage,
        layout: &Layout,
    ) -> Result<(candle::MetalStorage, Shape)> {
        use candle::backend::BackendStorage;
        let device = storage.device();
        let command_buffer = device.command_buffer()?;
        let kernels = device.kernels();
        let name = match storage.dtype() {
            DType::F32 => "softmax_f32",
            DType::F16 => "softmax_f16",
            DType::BF16 => "softmax_bf16",
            dtype => candle::bail!("softmax-last-dim is not implemented for {dtype:?}"),
        };

        let n = layout.stride().len();
        if !(layout.is_contiguous() && layout.stride()[n - 1] == 1) {
            candle::bail!("Non contiguous softmax-last-dim is not implemented");
        }

        let last_dim = layout.dims()[layout.shape().rank() - 1];
        let elem_count = layout.shape().elem_count();
        let output = device.new_buffer(elem_count, storage.dtype(), "softmax")?;
        candle_metal_kernels::call_last_softmax(
            device.metal_device(),
            &command_buffer,
            kernels,
            name,
            elem_count,
            last_dim,
            storage.buffer(),
            layout.start_offset() * storage.dtype().size_in_bytes(),
            &output,
        )
        .map_err(candle::Error::wrap)?;
        let newstorage =
            candle::MetalStorage::new(output, device.clone(), elem_count, storage.dtype());
        Ok((newstorage, layout.shape().clone()))
    }
}

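/// Softmax over the last dimension, fused via the `SoftmaxLastDim` custom op
/// above; it is meant to match `softmax(xs, D::Minus1)` on contiguous inputs.
///
/// A comparison sketch with illustrative values:
///
/// ```rust
/// use candle::{Tensor, Device, test_utils::to_vec2_round};
/// let a = Tensor::new(&[[0f32, 1., 0., 1.], [-2., 2., 3., -3.]], &Device::Cpu)?;
/// let fused = candle_nn::ops::softmax_last_dim(&a)?;
/// let reference = candle_nn::ops::softmax(&a, 1)?;
/// assert_eq!(to_vec2_round(&fused, 4)?, to_vec2_round(&reference, 4)?);
/// # Ok::<(), candle::Error>(())
/// ```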
pub fn softmax_last_dim(xs: &Tensor) -> Result<Tensor> {
    xs.apply_op1_no_bwd(&SoftmaxLastDim)
}

#[derive(Debug, Clone)]
struct RmsNorm {
    eps: f32,
}

impl candle::CustomOp2 for RmsNorm {
    fn name(&self) -> &'static str {
        "rms-norm"
    }

    fn cpu_fwd(
        &self,
        s1: &CpuStorage,
        l1: &Layout,
        s2: &CpuStorage,
        l2: &Layout,
    ) -> Result<(CpuStorage, Shape)> {
        use candle::backend::BackendStorage;

        let eps = self.eps;
        fn inner<
            T: candle::WithDType
                + num_traits::Float
                + num_traits::AsPrimitive<f32>
                + num_traits::FromPrimitive,
        >(
            src: &[T],
            layout: &Layout,
            alpha: &[T],
            alpha_layout: &Layout,
            eps: f32,
        ) -> Result<(CpuStorage, Shape)> {
            let src = match layout.contiguous_offsets() {
                None => candle::bail!("input has to be contiguous"),
                Some((o1, o2)) => &src[o1..o2],
            };
            let alpha = match alpha_layout.contiguous_offsets() {
                None => candle::bail!("alpha has to be contiguous"),
                Some((o1, o2)) => &alpha[o1..o2],
            };
            let el_count = layout.shape().elem_count();
            let dims = layout.shape().dims();
            let dim_m1 = dims[dims.len() - 1];
            let mut dst = vec![T::zero(); el_count];
            src.par_chunks(dim_m1)
                .zip(dst.par_chunks_mut(dim_m1))
                .for_each(|(src, dst)| {
                    let sum2 = src
                        .iter()
                        .map(|&v| {
                            let v = v.as_();
                            v * v
                        })
                        .sum::<f32>();
                    let m = (sum2 / dim_m1 as f32 + eps).sqrt();
                    let m = T::from_f32(m).unwrap_or_else(T::nan);
                    for ((d, s), alpha) in dst.iter_mut().zip(src.iter()).zip(alpha) {
                        *d = *s / m * *alpha
                    }
                });
            let storage = candle::WithDType::to_cpu_storage_owned(dst);
            Ok((storage, Shape::from_dims(dims)))
        }

        use CpuStorage as C;
        match (s1, s2) {
            (C::BF16(s1), C::BF16(s2)) => inner::<half::bf16>(s1, l1, s2, l2, eps),
            (C::F16(s1), C::F16(s2)) => inner::<half::f16>(s1, l1, s2, l2, eps),
            (C::F32(s1), C::F32(s2)) => inner::<f32>(s1, l1, s2, l2, eps),
            _ => candle::bail!("unsupported dtype for rmsnorm {:?}", s1.dtype()),
        }
    }

    #[cfg(feature = "cuda")]
    fn cuda_fwd(
        &self,
        s1: &candle::CudaStorage,
        l1: &Layout,
        s2: &candle::CudaStorage,
        l2: &Layout,
    ) -> Result<(candle::CudaStorage, Shape)> {
        use candle::cuda_backend::cudarc::driver::{
            CudaSlice, DeviceRepr, LaunchConfig, PushKernelArg,
        };
        use candle::cuda_backend::{kernel_name, kernels, Map2, WrapErr};
        use candle::{CudaDevice, WithDType};

        struct S {
            eps: f32,
        }
        impl Map2 for S {
            fn f<T: DeviceRepr + WithDType>(
                &self,
                src: &CudaSlice<T>,
                layout: &Layout,
                alpha: &CudaSlice<T>,
                alpha_layout: &Layout,
                dev: &CudaDevice,
            ) -> Result<CudaSlice<T>> {
                let src = match layout.contiguous_offsets() {
                    None => candle::bail!("input has to be contiguous"),
                    Some((o1, o2)) => src.slice(o1..o2),
                };
                let alpha = match alpha_layout.contiguous_offsets() {
                    None => candle::bail!("alpha has to be contiguous"),
                    Some((o1, o2)) => alpha.slice(o1..o2),
                };
                let el = layout.shape().elem_count();
                let dims = layout.shape().dims();
                let dim_m1 = dims[dims.len() - 1];
                let (n_rows, n_cols) = (el / dim_m1, dim_m1);

                let block_size = if n_cols < 1024 { 32 } else { 1024 };
                let cfg = LaunchConfig {
                    grid_dim: (n_rows as u32, 1, 1),
                    block_dim: (block_size, 1, 1),
                    shared_mem_bytes: 0,
                };
                let func = dev.get_or_load_func(&kernel_name::<T>("rmsnorm"), &kernels::REDUCE)?;
                // SAFETY: Set later by running the kernel.
                let dst = unsafe { dev.alloc::<T>(el)? };
                let mut builder = func.builder();
                builder.arg(&src);
                builder.arg(&dst);
                builder.arg(&alpha);
                candle::builder_arg!(builder, n_cols as i32, block_size as i32, self.eps);
                // SAFETY: ffi.
                unsafe { builder.launch(cfg) }.w()?;
                Ok(dst)
            }
        }

        use candle::backend::BackendStorage;
        let dev = s1.device();
        let slice = S { eps: self.eps }.map(&s1.slice, l1, &s2.slice, l2, dev)?;
        let dst = candle::cuda_backend::CudaStorage {
            slice,
            device: dev.clone(),
        };
        Ok((dst, l1.shape().clone()))
    }

    #[cfg(feature = "metal")]
    fn metal_fwd(
        &self,
        s1: &candle::MetalStorage,
        l1: &Layout,
        s2: &candle::MetalStorage,
        l2: &Layout,
    ) -> Result<(candle::MetalStorage, Shape)> {
        use candle::backend::BackendStorage;
        let device = s1.device();
        let command_buffer = device.command_buffer()?;
        let kernels = device.kernels();
        let name = match (s1.dtype(), s2.dtype()) {
            (DType::F32, DType::F32) => "rmsnorm_f32",
            (DType::F16, DType::F16) => "rmsnorm_f16",
            (DType::BF16, DType::BF16) => "rmsnorm_bf16",
            (dt1, dt2) => candle::bail!("rmsnorm is not implemented for {dt1:?} {dt2:?}"),
        };

        if !(l1.is_contiguous() && l2.is_contiguous()) {
            candle::bail!("Non contiguous rmsnorm is not implemented");
        }

        let last_dim = l1.dims()[l1.shape().rank() - 1];
        let elem_count = l1.shape().elem_count();
        let output = device.new_buffer(elem_count, s1.dtype(), "rmsnorm")?;
        candle_metal_kernels::call_rms_norm(
            device.metal_device(),
            &command_buffer,
            kernels,
            name,
            elem_count,
            last_dim,
            self.eps,
            s1.buffer(),
            l1.start_offset() * s1.dtype().size_in_bytes(),
            s2.buffer(),
            l2.start_offset() * s2.dtype().size_in_bytes(),
            &output,
        )
        .map_err(candle::Error::wrap)?;
        let newstorage = candle::MetalStorage::new(output, device.clone(), elem_count, s1.dtype());
        Ok((newstorage, l1.shape().clone()))
    }
}

pub fn rms_norm_slow(x: &Tensor, alpha: &Tensor, eps: f32) -> Result<Tensor> {
    let x_dtype = x.dtype();
    let internal_dtype = match x_dtype {
        DType::F16 | DType::BF16 => DType::F32,
        d => d,
    };
    let hidden_size = x.dim(D::Minus1)?;
    let x = x.to_dtype(internal_dtype)?;
    let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
    let x_normed = x.broadcast_div(&(norm_x + eps as f64)?.sqrt()?)?;
    x_normed.to_dtype(x_dtype)?.broadcast_mul(alpha)
}

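/// RMS normalization over the last dimension: each row is scaled by
/// `1 / sqrt(mean(x^2) + eps)` and then multiplied element-wise by `alpha`.
///
/// A comparison sketch with illustrative values; the fused op is expected to
/// match `rms_norm_slow` up to rounding.
///
/// ```rust
/// use candle::{Tensor, Device, test_utils::to_vec2_round};
/// let xs = Tensor::new(&[[1f32, 2., 3., 4.], [4., 3., 2., 1.]], &Device::Cpu)?;
/// let alpha = Tensor::new(&[1f32, 1., 1., 1.], &Device::Cpu)?;
/// let fast = candle_nn::ops::rms_norm(&xs, &alpha, 1e-5)?;
/// let slow = candle_nn::ops::rms_norm_slow(&xs, &alpha, 1e-5)?;
/// assert_eq!(to_vec2_round(&fast, 4)?, to_vec2_round(&slow, 4)?);
/// # Ok::<(), candle::Error>(())
/// ```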
pub fn rms_norm(xs: &Tensor, alpha: &Tensor, eps: f32) -> Result<Tensor> {
    let hidden_size_xs = xs.dim(D::Minus1)?;
    let hidden_size_alpha = alpha.dims1()?;
    if hidden_size_xs != hidden_size_alpha {
        candle::bail!(
            "shape mismatch in rms-norm {:?} {:?}",
            xs.shape(),
            alpha.shape()
        )
    }
    xs.apply_op2_no_bwd(alpha, &RmsNorm { eps })
}

#[derive(Debug, Clone)]
struct LayerNorm {
    eps: f32,
}

impl candle::CustomOp3 for LayerNorm {
    fn name(&self) -> &'static str {
        "layer-norm"
    }

    fn cpu_fwd(
        &self,
        s1: &CpuStorage,
        l1: &Layout,
        s2: &CpuStorage,
        l2: &Layout,
        s3: &CpuStorage,
        l3: &Layout,
    ) -> Result<(CpuStorage, Shape)> {
        use candle::backend::BackendStorage;

        let eps = self.eps;
        fn inner<
            T: candle::WithDType
                + num_traits::Float
                + num_traits::AsPrimitive<f32>
                + num_traits::FromPrimitive,
        >(
            src: &[T],
            layout: &Layout,
            alpha: &[T],
            alpha_layout: &Layout,
            beta: &[T],
            beta_layout: &Layout,
            eps: f32,
        ) -> Result<(CpuStorage, Shape)> {
            let src = match layout.contiguous_offsets() {
                None => candle::bail!("input has to be contiguous"),
                Some((o1, o2)) => &src[o1..o2],
            };
            let alpha = match alpha_layout.contiguous_offsets() {
                None => candle::bail!("alpha has to be contiguous"),
                Some((o1, o2)) => &alpha[o1..o2],
            };
            let beta = match beta_layout.contiguous_offsets() {
                None => candle::bail!("beta has to be contiguous"),
                Some((o1, o2)) => &beta[o1..o2],
            };
            let el_count = layout.shape().elem_count();
            let dims = layout.shape().dims();
            let dim_m1 = dims[dims.len() - 1];
            let mut dst = vec![T::zero(); el_count];
            src.par_chunks(dim_m1)
                .zip(dst.par_chunks_mut(dim_m1))
                .for_each(|(src, dst)| {
                    let mut sum = 0f32;
                    let mut sum2 = 0f32;
                    for v in src {
                        let v = v.as_();
                        sum += v;
                        sum2 += v * v;
                    }
                    let mean = sum / dim_m1 as f32;
                    let var = sum2 / dim_m1 as f32 - mean * mean;
                    let inv_std = (var + eps).sqrt().recip();
                    for ((d, s), (alpha, beta)) in
                        dst.iter_mut().zip(src.iter()).zip(alpha.iter().zip(beta))
                    {
                        let alpha = alpha.as_();
                        let beta = beta.as_();
                        let d_ = (s.as_() - mean) * inv_std * alpha + beta;
                        *d = T::from_f32(d_).unwrap_or_else(T::nan);
                    }
                });
            let storage = candle::WithDType::to_cpu_storage_owned(dst);
            Ok((storage, Shape::from_dims(dims)))
        }

        use CpuStorage as C;
        match (s1, s2, s3) {
            (C::BF16(s1), C::BF16(s2), C::BF16(s3)) => {
                inner::<half::bf16>(s1, l1, s2, l2, s3, l3, eps)
            }
            (C::F16(s1), C::F16(s2), C::F16(s3)) => inner::<half::f16>(s1, l1, s2, l2, s3, l3, eps),
            (C::F32(s1), C::F32(s2), C::F32(s3)) => inner::<f32>(s1, l1, s2, l2, s3, l3, eps),
            _ => candle::bail!("unsupported dtype for layernorm {:?}", s1.dtype()),
        }
    }

    #[cfg(feature = "cuda")]
    fn cuda_fwd(
        &self,
        s1: &candle::CudaStorage,
        l1: &Layout,
        s2: &candle::CudaStorage,
        l2: &Layout,
        s3: &candle::CudaStorage,
        l3: &Layout,
    ) -> Result<(candle::CudaStorage, Shape)> {
        use candle::cuda_backend::cudarc::driver::{
            CudaSlice, DeviceRepr, LaunchConfig, PushKernelArg,
        };
        use candle::cuda_backend::{kernel_name, kernels, Map3, WrapErr};
        use candle::{CudaDevice, WithDType};

        struct S {
            eps: f32,
        }
        impl Map3 for S {
            fn f<T: DeviceRepr + WithDType>(
                &self,
                src: &CudaSlice<T>,
                layout: &Layout,
                alpha: &CudaSlice<T>,
                alpha_layout: &Layout,
                beta: &CudaSlice<T>,
                beta_layout: &Layout,
                dev: &CudaDevice,
            ) -> Result<CudaSlice<T>> {
                let src = match layout.contiguous_offsets() {
                    None => candle::bail!("input has to be contiguous"),
                    Some((o1, o2)) => src.slice(o1..o2),
                };
                let alpha = match alpha_layout.contiguous_offsets() {
                    None => candle::bail!("alpha has to be contiguous"),
                    Some((o1, o2)) => alpha.slice(o1..o2),
                };
                let beta = match beta_layout.contiguous_offsets() {
                    None => candle::bail!("beta has to be contiguous"),
                    Some((o1, o2)) => beta.slice(o1..o2),
                };
                let el = layout.shape().elem_count();
                let dims = layout.shape().dims();
                let dim_m1 = dims[dims.len() - 1];
                let (n_rows, n_cols) = (el / dim_m1, dim_m1);

                let block_size = if n_cols < 1024 { 32 } else { 1024 };
                let cfg = LaunchConfig {
                    grid_dim: (n_rows as u32, 1, 1),
                    block_dim: (block_size, 1, 1),
                    shared_mem_bytes: 0,
                };
                let func =
                    dev.get_or_load_func(&kernel_name::<T>("layernorm"), &kernels::REDUCE)?;
                // SAFETY: Set later by running the kernel.
                let dst = unsafe { dev.alloc::<T>(el)? };
                let mut builder = func.builder();
                builder.arg(&src);
                builder.arg(&dst);
                builder.arg(&alpha);
                builder.arg(&beta);
                candle::builder_arg!(builder, n_cols as i32, block_size as i32, self.eps);
                // SAFETY: ffi.
                unsafe { builder.launch(cfg) }.w()?;
                Ok(dst)
            }
        }

        use candle::backend::BackendStorage;
        let dev = s1.device();
        let slice = S { eps: self.eps }.map(&s1.slice, l1, &s2.slice, l2, &s3.slice, l3, dev)?;
        let dst = candle::cuda_backend::CudaStorage {
            slice,
            device: dev.clone(),
        };
        Ok((dst, l1.shape().clone()))
    }

    #[cfg(feature = "metal")]
    fn metal_fwd(
        &self,
        s1: &candle::MetalStorage,
        l1: &Layout,
        s2: &candle::MetalStorage,
        l2: &Layout,
        s3: &candle::MetalStorage,
        l3: &Layout,
    ) -> Result<(candle::MetalStorage, Shape)> {
        use candle::backend::BackendStorage;
        let device = s1.device();
        let command_buffer = device.command_buffer()?;
        let kernels = device.kernels();
        let name = match (s1.dtype(), s2.dtype(), s3.dtype()) {
            (DType::F32, DType::F32, DType::F32) => "layernorm_f32",
            (DType::F16, DType::F16, DType::F16) => "layernorm_f16",
            (DType::BF16, DType::BF16, DType::BF16) => "layernorm_bf16",
            (dt1, dt2, dt3) => {
                candle::bail!("layernorm is not implemented for {dt1:?} {dt2:?} {dt3:?}")
            }
        };

        if !(l1.is_contiguous() && l2.is_contiguous() && l3.is_contiguous()) {
            candle::bail!("Non contiguous layernorm is not implemented");
        }

        let last_dim = l1.dims()[l1.shape().rank() - 1];
        let elem_count = l1.shape().elem_count();
        let output = device.new_buffer(elem_count, s1.dtype(), "layernorm")?;
        candle_metal_kernels::call_layer_norm(
            device.metal_device(),
            &command_buffer,
            kernels,
            name,
            elem_count,
            last_dim,
            self.eps,
            s1.buffer(),
            l1.start_offset() * s1.dtype().size_in_bytes(),
            s2.buffer(),
            l2.start_offset() * s2.dtype().size_in_bytes(),
            s3.buffer(),
            l3.start_offset() * s3.dtype().size_in_bytes(),
            &output,
        )
        .map_err(candle::Error::wrap)?;
        let newstorage = candle::MetalStorage::new(output, device.clone(), elem_count, s1.dtype());
        Ok((newstorage, l1.shape().clone()))
    }
}

pub fn layer_norm_slow(x: &Tensor, alpha: &Tensor, beta: &Tensor, eps: f32) -> Result<Tensor> {
    let x_dtype = x.dtype();
    let internal_dtype = match x_dtype {
        DType::F16 | DType::BF16 => DType::F32,
        d => d,
    };
    let hidden_size = x.dim(D::Minus1)?;
    let x = x.to_dtype(internal_dtype)?;
    let x = {
        let mean_x = (x.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
        x.broadcast_sub(&mean_x)?
    };
    let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
    let x_normed = x.broadcast_div(&(norm_x + eps as f64)?.sqrt()?)?;
    x_normed
        .to_dtype(x_dtype)?
        .broadcast_mul(alpha)?
        .broadcast_add(beta)
}

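/// Layer normalization over the last dimension with scale `alpha` and bias
/// `beta`: `(x - mean) / sqrt(var + eps) * alpha + beta`.
///
/// A comparison sketch with illustrative values; the fused op is expected to
/// match `layer_norm_slow` up to rounding.
///
/// ```rust
/// use candle::{Tensor, Device, test_utils::to_vec2_round};
/// let xs = Tensor::new(&[[1f32, 2., 3., 4.], [4., 3., 2., 1.]], &Device::Cpu)?;
/// let alpha = Tensor::new(&[1f32, 1., 1., 1.], &Device::Cpu)?;
/// let beta = Tensor::new(&[0f32, 0., 0., 0.], &Device::Cpu)?;
/// let fast = candle_nn::ops::layer_norm(&xs, &alpha, &beta, 1e-5)?;
/// let slow = candle_nn::ops::layer_norm_slow(&xs, &alpha, &beta, 1e-5)?;
/// assert_eq!(to_vec2_round(&fast, 4)?, to_vec2_round(&slow, 4)?);
/// # Ok::<(), candle::Error>(())
/// ```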
pub fn layer_norm(xs: &Tensor, alpha: &Tensor, beta: &Tensor, eps: f32) -> Result<Tensor> {
    let hidden_size_xs = xs.dim(D::Minus1)?;
    let hidden_size_alpha = alpha.dims1()?;
    let hidden_size_beta = beta.dims1()?;
    if hidden_size_xs != hidden_size_alpha || hidden_size_xs != hidden_size_beta {
        candle::bail!(
            "shape mismatch in layer-norm src: {:?} alpha: {:?} beta: {:?}",
            xs.shape(),
            alpha.shape(),
            beta.shape()
        )
    }
    xs.apply_op3_no_bwd(alpha, beta, &LayerNorm { eps })
}

// https://pytorch.org/docs/stable/generated/torch.nn.PixelShuffle.html
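/// A shape-oriented sketch with illustrative values: with an upscale factor of
/// 2, `(b, 4 * c, h, w)` becomes `(b, c, 2 * h, 2 * w)`.
///
/// ```rust
/// use candle::{Tensor, DType, Device};
/// let xs = Tensor::ones((1, 8, 2, 3), DType::F32, &Device::Cpu)?;
/// let ys = candle_nn::ops::pixel_shuffle(&xs, 2)?;
/// assert_eq!(ys.dims(), &[1, 2, 4, 6]);
/// # Ok::<(), candle::Error>(())
/// ```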
pub fn pixel_shuffle(xs: &Tensor, upscale_factor: usize) -> Result<Tensor> {
    let (b_size, c, h, w) = xs.dims4()?;
    let out_c = c / upscale_factor / upscale_factor;
    xs.reshape((b_size, out_c, upscale_factor, upscale_factor, h, w))?
        .permute((0, 1, 4, 2, 5, 3))?
        .reshape((b_size, out_c, h * upscale_factor, w * upscale_factor))
}

pub fn pixel_unshuffle(xs: &Tensor, downscale_factor: usize) -> Result<Tensor> {
    let (b_size, c, h, w) = xs.dims4()?;
    let out_c = c * downscale_factor * downscale_factor;
    xs.reshape((
        b_size,
        c,
        h / downscale_factor,
        downscale_factor,
        w / downscale_factor,
        downscale_factor,
    ))?
    .permute((0, 1, 3, 5, 2, 4))?
    .reshape((b_size, out_c, h / downscale_factor, w / downscale_factor))
}

// https://pytorch.org/docs/stable/generated/torch.nn.ReplicationPad2d.html
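/// A shape-oriented sketch with illustrative values: padding by 1 replicates
/// the border row/column on each side, so `(b, c, h, w)` becomes
/// `(b, c, h + 2, w + 2)`.
///
/// ```rust
/// use candle::{Tensor, DType, Device};
/// let xs = Tensor::ones((1, 1, 2, 3), DType::F32, &Device::Cpu)?;
/// let ys = candle_nn::ops::replication_pad2d(&xs, 1)?;
/// assert_eq!(ys.dims(), &[1, 1, 4, 5]);
/// # Ok::<(), candle::Error>(())
/// ```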
pub fn replication_pad2d(xs: &Tensor, pad: usize) -> Result<Tensor> {
    match pad {
        0 => Ok(xs.clone()),
        1 => {
            let (_b_size, _c, h, w) = xs.dims4()?;
            let (first, last) = (xs.narrow(3, 0, 1)?, xs.narrow(3, w - 1, 1)?);
            let xs = Tensor::cat(&[&first, xs, &last], 3)?;
            let (first, last) = (xs.narrow(2, 0, 1)?, xs.narrow(2, h - 1, 1)?);
            Tensor::cat(&[&first, &xs, &last], 2)
        }
        n => candle::bail!("replication-pad with a size of {n} is not supported"),
    }
}

#[derive(Clone, Debug)]
pub struct Identity;

impl Identity {
    pub fn new() -> Identity {
        Self
    }
}

impl Default for Identity {
    fn default() -> Self {
        Self
    }
}

impl Module for Identity {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        Ok(xs.clone())
    }
}

#[allow(dead_code)]
struct Sdpa {
    scale: f32,
    softcapping: f32,
}

impl candle::CustomOp3 for Sdpa {
    fn name(&self) -> &'static str {
        "metal-sdpa"
    }

    fn cpu_fwd(
        &self,
        _s1: &CpuStorage,
        _l1: &Layout,
        _s2: &CpuStorage,
        _l2: &Layout,
        _s3: &CpuStorage,
        _l3: &Layout,
    ) -> Result<(CpuStorage, Shape)> {
        candle::bail!("SDPA has no cpu impl")
    }

    #[cfg(feature = "metal")]
    fn metal_fwd(
        &self,
        q: &candle::MetalStorage,
        q_l: &Layout,
        k: &candle::MetalStorage,
        k_l: &Layout,
        v: &candle::MetalStorage,
        v_l: &Layout,
    ) -> Result<(candle::MetalStorage, Shape)> {
        use candle::backend::BackendStorage;
        use candle_metal_kernels::SdpaDType;

        let device = q.device();

        let out_dims = vec![q_l.dim(0)?, q_l.dim(1)?, q_l.dim(2)?, v_l.dim(3)?];
        let elem_count: usize = out_dims.iter().product();

        let output = device.new_buffer(elem_count, q.dtype(), "sdpa_o")?;

        // q,k must have matching emb dim
        if q_l.dim(D::Minus1)? != k_l.dim(D::Minus1)? {
            candle::bail!("`q` and `k` last dims must match");
        }

        // k,v must have matching n kv heads
        if v_l.dim(D::Minus(3))? != k_l.dim(D::Minus(3))? {
            candle::bail!("`k` and `v` head dims must match");
        }

        // n_heads % n_kv_heads == 0; n_heads >= 1, n_kv_heads >= 1.
        if q_l.dim(D::Minus(3))? % k_l.dim(D::Minus(3))? != 0 {
            candle::bail!("query `n_heads` must be a multiple of `n_kv_heads`");
        }

        let k_head = k_l.dim(D::Minus1)?;
        let q_head = q_l.dim(D::Minus1)?;
        let q_seq = q_l.dim(2)?;

        let mut implementation_supports_use_case = q_head == k_head;
        let supported_head_dim =
            q_head == 32 || q_head == 64 || q_head == 96 || q_head == 128 || q_head == 256;

        const SDPA_FULL_THRESHOLD: usize = 2;

        let supports_sdpa_full =
            q_seq >= SDPA_FULL_THRESHOLD && supported_head_dim && q_head == k_head;
        let supports_sdpa_vector = q_seq == 1 && supported_head_dim;

        implementation_supports_use_case &= supports_sdpa_full || supports_sdpa_vector;

        if !supported_head_dim {
            candle::bail!(
                "Metal SDPA does not support q head dim {q_head}: q dims {:?}, k dims {:?}, v dims {:?}.",
                q_l.dims(),
                k_l.dims(),
                v_l.dims()
            );
        }
        if !implementation_supports_use_case {
            candle::bail!(
                "Metal SDPA does not support q dims {:?}, k dims {:?}, v dims {:?}.",
                q_l.dims(),
                k_l.dims(),
                v_l.dims()
            );
        }

        for t in [k.dtype(), v.dtype()] {
            if q.dtype() != t {
                candle::bail!("all q, k, v dtypes must match.");
            }
        }

        let itype = match q.dtype() {
            DType::BF16 => SdpaDType::BF16,
            DType::F16 => SdpaDType::F16,
            DType::F32 => SdpaDType::F32,
            other => candle::bail!("unsupported sdpa type {other:?}"),
        };

        let command_buffer = q.device().command_buffer()?;
        if supports_sdpa_vector {
            // Route to the 2 pass fused attention if the k seqlen is large.
            // https://github.com/ml-explore/mlx/pull/1597
            const TWO_PASS_K_THRESHOLD: usize = 1024;
            if k_l.dim(2)? >= TWO_PASS_K_THRESHOLD {
                let mut intermediate_shape = [
                    &out_dims[0..out_dims.len() - 2],
                    &[candle_metal_kernels::SDPA_2PASS_BLOCKS],
                    &[out_dims[out_dims.len() - 1]],
                ]
                .concat();
                let intermediate = device.new_buffer(
                    intermediate_shape.iter().product::<usize>(),
                    DType::F32,
                    "sdpa_2pass_intermediate",
                )?;
                let _ = intermediate_shape.pop().unwrap();
                let sums = device.new_buffer(
                    intermediate_shape.iter().product::<usize>(),
                    DType::F32,
                    "sdpa_2pass_sums",
                )?;
                let maxs = device.new_buffer(
                    intermediate_shape.iter().product::<usize>(),
                    DType::F32,
                    "sdpa_2pass_maxs",
                )?;

                command_buffer.set_label("vector_attention");
                candle_metal_kernels::call_sdpa_vector_2pass(
                    q.device().device(),
                    &command_buffer,
                    q.device().kernels(),
                    q_l.start_offset(),
                    q_l.dims(),
                    q.buffer(),
                    k_l.start_offset(),
                    k_l.dims(),
                    k_l.stride(),
                    k.buffer(),
                    v_l.start_offset(),
                    v_l.stride(),
                    v.buffer(),
                    &output,
                    &intermediate,
                    &sums,
                    &maxs,
                    self.scale,
                    self.softcapping,
                    itype,
                )
                .map_err(candle::Error::wrap)?;
            } else {
                command_buffer.set_label("vector_attention");
                candle_metal_kernels::call_sdpa_vector(
                    q.device().device(),
                    &command_buffer,
                    q.device().kernels(),
                    q_l.start_offset(),
                    q_l.dims(),
                    q.buffer(),
                    k_l.start_offset(),
                    k_l.dims(),
                    k_l.stride(),
                    k.buffer(),
                    v_l.start_offset(),
                    v_l.stride(),
                    v.buffer(),
                    &output,
                    self.scale,
                    self.softcapping,
                    itype,
                )
                .map_err(candle::Error::wrap)?;
            }
        } else if supports_sdpa_full {
            if q_l.dim(2)? != k_l.dim(2)? {
                candle::bail!(
                    "query and key sequence length must be equal if using full metal sdpa"
                )
            }

            command_buffer.set_label("full_attention");
            candle_metal_kernels::call_sdpa_full(
                q.device().device(),
                &command_buffer,
                q.device().kernels(),
                q_l.start_offset(),
                q_l.dims(),
                q.buffer(),
                k_l.start_offset(),
                k.buffer(),
                v_l.start_offset(),
                v.buffer(),
                &output,
                self.scale,
                self.softcapping,
                itype,
            )
            .map_err(candle::Error::wrap)?;
        } else {
            candle::bail!("must be vector or full sdpa kernel");
        }

        let newstorage = candle::MetalStorage::new(output, device.clone(), elem_count, q.dtype());
        Ok((newstorage, Shape::from_dims(&out_dims)))
    }
}

/// Scaled dot product attention with a fused kernel.
///
/// Computes softmax(qk^T*scale)v.
///
/// **Input shapes:**
/// - `q`: (bs, qhead, seq, hidden)
/// - `k`: (bs, kv_head, kv_seq, hidden)
/// - `v`: (bs, kv_head, kv_seq, v_hidden)
/// - `scale` is applied before softmax.
/// - If `softcapping` != 1.0:
///      - Computation is: softmax(tanh(qk^T*scale/cap)*cap)v
///
/// **Output shape:** (bs, qhead, seq, v_hidden)
///
/// **Supported head dims:** 32, 64, 96, 128, 256.
///
/// ## On Metal:
/// - If `seq` == 1:
///     - Use a vectorized kernel
///     - Supports `seq` != `kv_seq` (cross attn. support)
///     - Supports GQA when `qhead` is a multiple of `kv_head`
/// - Otherwise:
///     - Use an alternate kernel
///     - Requires `seq` == `kv_seq`
///     - GQA is not supported (requires `qhead` == `kv_head`)
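///
/// A minimal call sketch with shapes chosen for illustration (the op only has
/// Metal kernels, so this snippet is compile-only):
///
/// ```no_run
/// use candle::{Tensor, DType, Device};
/// let device = Device::new_metal(0)?;
/// // (bs, qhead, seq, hidden) with seq == 1 routes to the vector kernel.
/// let q = Tensor::zeros((1, 8, 1, 64), DType::F32, &device)?;
/// let k = Tensor::zeros((1, 8, 32, 64), DType::F32, &device)?;
/// let v = Tensor::zeros((1, 8, 32, 64), DType::F32, &device)?;
/// let out = candle_nn::ops::sdpa(&q, &k, &v, 0.125, 1.0)?;
/// assert_eq!(out.dims(), &[1, 8, 1, 64]);
/// # Ok::<(), candle::Error>(())
/// ```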
pub fn sdpa(q: &Tensor, k: &Tensor, v: &Tensor, scale: f32, softcapping: f32) -> Result<Tensor> {
    q.apply_op3_no_bwd(k, v, &Sdpa { scale, softcapping })
}