use candle::{CpuStorage, DType, Layout, Module, Result, Shape, Tensor, D};
use rayon::prelude::*;

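/// Softmax over dimension `dim`, computed in a numerically stable way: the
/// per-slice maximum is subtracted before exponentiation, then the exponentials
/// are normalized by their sum.
///
/// A usage sketch, assuming this module is exposed as `candle_nn::ops`
/// (marked `ignore` as it is illustrative rather than a checked doctest):
///
/// ```ignore
/// use candle::{Device, Tensor};
/// let a = Tensor::new(&[[0f32, 1., 2., 3.]], &Device::Cpu)?;
/// let sm = candle_nn::ops::softmax(&a, 1)?; // each row now sums to 1
/// # Ok::<(), candle::Error>(())
/// ```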
pub fn softmax<D: candle::shape::Dim>(xs: &Tensor, dim: D) -> Result<Tensor> {
    let dim = dim.to_index(xs.shape(), "softmax")?;
    let max = xs.max_keepdim(dim)?;
    let diff = xs.broadcast_sub(&max)?;
    let num = diff.exp()?;
    let den = num.sum_keepdim(dim)?;
    num.broadcast_div(&den)
}

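/// Log-softmax over dimension `d`, using the same max-subtraction trick as
/// `softmax`: the result is `x - max - log(sum(exp(x - max)))`, computed
/// without materializing the softmax itself.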
pub fn log_softmax<D: candle::shape::Dim>(xs: &Tensor, d: D) -> Result<Tensor> {
    let d = d.to_index(xs.shape(), "log-softmax")?;
    let max = xs.max_keepdim(d)?;
    let diff = xs.broadcast_sub(&max)?;
    let sum_exp = diff.exp()?.sum_keepdim(d)?;
    let log_sm = diff.broadcast_sub(&sum_exp.log()?)?;
    Ok(log_sm)
}

pub fn silu(xs: &Tensor) -> Result<Tensor> {
    xs.silu()
}

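/// SwiGLU activation: splits the last dimension into two halves `a` and `b` and
/// returns `silu(a) * b`, so the last dimension of the input must be even.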
pub fn swiglu(xs: &Tensor) -> Result<Tensor> {
    let xs = xs.chunk(2, D::Minus1)?;
    &xs[0].silu()? * &xs[1]
}

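// Sigmoid is implemented as a `CustomOp1` so that each backend (CPU, CUDA,
// Metal) dispatches to a dedicated kernel, with an analytic backward pass
// provided by `bwd` below.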
struct Sigmoid;

impl candle::CustomOp1 for Sigmoid {
    fn name(&self) -> &'static str {
        "sigmoid"
    }

    fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)> {
        use candle::backend::BackendStorage;

        fn fwd<T: num_traits::Float>(v: T) -> T {
            (v.neg().exp() + T::one()).recip()
        }

        let storage = match storage {
            CpuStorage::BF16(slice) => {
                CpuStorage::BF16(candle::cpu_backend::unary_map(slice, layout, fwd))
            }
            CpuStorage::F16(slice) => {
                CpuStorage::F16(candle::cpu_backend::unary_map(slice, layout, fwd))
            }
            CpuStorage::F32(slice) => {
                CpuStorage::F32(candle::cpu_backend::unary_map(slice, layout, fwd))
            }
            CpuStorage::F64(slice) => {
                CpuStorage::F64(candle::cpu_backend::unary_map(slice, layout, fwd))
            }
            _ => Err(candle::Error::UnsupportedDTypeForOp(
                storage.dtype(),
                self.name(),
            ))?,
        };
        Ok((storage, layout.shape().clone()))
    }

    #[cfg(feature = "cuda")]
    fn cuda_fwd(
        &self,
        storage: &candle::CudaStorage,
        layout: &Layout,
    ) -> Result<(candle::CudaStorage, Shape)> {
        use candle::backend::BackendStorage;
        use candle::cuda_backend::cudarc::driver::{
            CudaSlice, DeviceRepr, LaunchConfig, PushKernelArg, ValidAsZeroBits,
        };
        use candle::cuda_backend::SlicePtrOrNull;
        use candle::cuda_backend::{kernel_name, kernels, Map1, WrapErr};
        use candle::{CudaDevice, WithDType};

        struct S;
        impl Map1 for S {
            fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
                &self,
                src: &CudaSlice<T>,
                dev: &CudaDevice,
                layout: &Layout,
            ) -> Result<CudaSlice<T>> {
                let shape = layout.shape();
                let dims = shape.dims();
                let el_count = shape.elem_count();
                let cfg = LaunchConfig::for_num_elems(el_count as u32);
                let ds = SlicePtrOrNull::params_from_layout(dev, layout)?;
                let src = &src.slice(layout.start_offset()..);
                let func = dev.get_or_load_func(&kernel_name::<T>("usigmoid"), &kernels::UNARY)?;
                let out = unsafe { dev.alloc::<T>(el_count)? };

                let mut builder = func.builder();
                candle::builder_arg!(builder, el_count, dims.len());
                ds.builder_arg(&mut builder);
                builder.arg(src);
                builder.arg(&out);
                unsafe { builder.launch(cfg) }.w()?;
                Ok(out)
            }
        }

        let dev = storage.device();
        let slice = S.map(&storage.slice, dev, layout)?;
        let dst = candle::CudaStorage {
            slice,
            device: dev.clone(),
        };
        Ok((dst, layout.shape().clone()))
    }

    #[cfg(feature = "metal")]
    fn metal_fwd(
        &self,
        storage: &candle::MetalStorage,
        layout: &Layout,
    ) -> Result<(candle::MetalStorage, Shape)> {
        use candle::backend::BackendStorage;
        use candle::MetalError;
        let device = storage.device();
        let dtype = storage.dtype();
        let shape = layout.shape();
        let el_count = shape.elem_count();
        let buffer = device.new_buffer(el_count, dtype, "sigmoid")?;
        let encoder = device.command_encoder()?;
        encoder.set_label("sigmoid");
        let src = candle_metal_kernels::BufferOffset {
            buffer: storage.buffer(),
            offset_in_bytes: layout.start_offset() * storage.dtype().size_in_bytes(),
        };

        if layout.is_contiguous() {
            use candle_metal_kernels::unary::contiguous;
            let kernel_name = match dtype {
                DType::F16 => contiguous::sigmoid::HALF,
                DType::F32 => contiguous::sigmoid::FLOAT,
                DType::BF16 => contiguous::sigmoid::BFLOAT,
                dtype => {
                    candle::bail!("Metal contiguous unary sigmoid {dtype:?} not implemented")
                }
            };
            candle_metal_kernels::call_unary_contiguous(
                device.metal_device(),
                &encoder,
                device.kernels(),
                kernel_name,
                dtype.size_in_bytes(),
                el_count,
                src,
                &buffer,
            )
            .map_err(MetalError::from)?;
        } else {
            use candle_metal_kernels::unary::strided;
            let kernel_name = match dtype {
                DType::F16 => strided::sigmoid::HALF,
                DType::F32 => strided::sigmoid::FLOAT,
                DType::BF16 => strided::sigmoid::BFLOAT,
                dtype => {
                    candle::bail!("Metal strided unary sigmoid {dtype:?} not implemented")
                }
            };
            let dst = candle_metal_kernels::BufferOffset::zero_offset(&buffer);
            candle_metal_kernels::call_unary_strided(
                device.metal_device(),
                &encoder,
                device.kernels(),
                kernel_name,
                layout.dims(),
                src,
                layout.stride(),
                dst,
            )
            .map_err(MetalError::from)?;
        }

        let new_storage = candle::MetalStorage::new(buffer, device.clone(), el_count, dtype);
        Ok((new_storage, layout.shape().clone()))
    }

    fn bwd(&self, _arg: &Tensor, res: &Tensor, grad_res: &Tensor) -> Result<Option<Tensor>> {
        // d/dx sigmoid(x) = sigmoid(x) * (1 - sigmoid(x)), expressed in terms of
        // the forward result `res` so the input does not need to be kept around.
        let d_dx_sigmoid = res.ones_like()?.sub(res)?.mul(res)?;
        Ok(Some(grad_res.mul(&d_dx_sigmoid)?))
    }
}

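/// Element-wise sigmoid, `1 / (1 + exp(-x))`, backed by the custom op above.
///
/// An illustrative sketch (assuming this module is exposed as `candle_nn::ops`):
///
/// ```ignore
/// use candle::{Device, Tensor};
/// let xs = Tensor::new(&[0f32, 2., -2.], &Device::Cpu)?;
/// let ys = candle_nn::ops::sigmoid(&xs)?; // approximately [0.5, 0.881, 0.119]
/// # Ok::<(), candle::Error>(())
/// ```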
pub fn sigmoid(xs: &Tensor) -> Result<Tensor> {
    xs.apply_op1(Sigmoid)
}

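/// Hard sigmoid: the piecewise-linear approximation `clamp((x + 3) / 6, 0, 1)`.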
pub fn hard_sigmoid(xs: &Tensor) -> Result<Tensor> {
    ((xs + 3.0)? / 6.0)?.clamp(0f32, 1f32)
}

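/// Mish activation: `x * tanh(softplus(x))`, where `softplus(x) = ln(1 + exp(x))`.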
pub fn mish(xs: &Tensor) -> Result<Tensor> {
    xs * (1.0 + xs.exp()?)?.log()?.tanh()
}

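/// Leaky ReLU: `max(x, 0) + negative_slope * min(x, 0)`, i.e. the identity for
/// positive values and a small linear slope for negative ones.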
pub fn leaky_relu(xs: &Tensor, negative_slope: f64) -> Result<Tensor> {
    let zeros = xs.zeros_like()?;
    xs.maximum(&zeros)? + xs.minimum(&zeros)? * negative_slope
}

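/// SELU: `gamma * x` for positive inputs and `gamma * alpha * (exp(x) - 1)`
/// otherwise. The standard self-normalizing constants are `alpha ≈ 1.6733` and
/// `gamma ≈ 1.0507`.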
pub fn selu(xs: &Tensor, alpha: f32, gamma: f32) -> Result<Tensor> {
    let is_pos = xs.gt(0f32)?;
    let alpha_t = Tensor::full(alpha, xs.dims(), xs.device())?;
    let neg = xs.exp()?.mul(&alpha_t)?.sub(&alpha_t)?;
    let selu = is_pos.where_cond(xs, &neg)?;
    let gamma_t = Tensor::full(gamma, xs.dims(), xs.device())?;
    selu.broadcast_mul(&gamma_t)
}

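/// Inverted dropout: each element is zeroed with probability `drop_p` and the
/// survivors are scaled by `1 / (1 - drop_p)` so the expected value is
/// unchanged. Fails if `drop_p` is outside `[0, 1)`.
///
/// A usage sketch (assuming this module is exposed as `candle_nn::ops`):
///
/// ```ignore
/// use candle::{DType, Device, Tensor};
/// let xs = Tensor::ones((2, 4), DType::F32, &Device::Cpu)?;
/// let ys = candle_nn::ops::dropout(&xs, 0.5)?; // entries are either 0.0 or 2.0
/// # Ok::<(), candle::Error>(())
/// ```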
pub fn dropout(xs: &Tensor, drop_p: f32) -> Result<Tensor> {
    if !(0. ..1.).contains(&drop_p) {
        candle::bail!("dropout probability has to be in [0, 1), got {drop_p}")
    }
    let rand = Tensor::rand(0f32, 1f32, xs.shape(), xs.device())?;
    let scale = 1.0 / (1.0 - drop_p as f64);
    let drop_p = Tensor::new(drop_p, xs.device())?.broadcast_as(xs.shape())?;
    let mask = (rand.ge(&drop_p)?.to_dtype(xs.dtype())? * scale)?;
    xs * mask
}

#[derive(Clone, Debug)]
pub struct Dropout {
    drop_p: f32,
}

impl Dropout {
    pub fn new(drop_p: f32) -> Dropout {
        Self { drop_p }
    }

    pub fn forward(&self, xs: &Tensor, train: bool) -> Result<Tensor> {
        if train {
            dropout(xs, self.drop_p)
        } else {
            Ok(xs.clone())
        }
    }
}

impl candle::ModuleT for Dropout {
    fn forward_t(&self, xs: &Tensor, train: bool) -> Result<Tensor> {
        self.forward(xs, train)
    }
}

struct SoftmaxLastDim;

impl candle::CustomOp1 for SoftmaxLastDim {
    fn name(&self) -> &'static str {
        "softmax-last-dim"
    }

    fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)> {
        fn softmax<T: candle::WithDType + num_traits::Float>(
            src: &[T],
            layout: &Layout,
        ) -> Result<(CpuStorage, Shape)> {
            let src = match layout.contiguous_offsets() {
                None => candle::bail!("input has to be contiguous"),
                Some((o1, o2)) => &src[o1..o2],
            };
            let el_count = layout.shape().elem_count();
            let dims = layout.shape().dims();
            let dim_m1 = dims[dims.len() - 1];
            let mut dst = vec![T::zero(); el_count];
            src.par_chunks(dim_m1)
                .zip(dst.par_chunks_mut(dim_m1))
                .for_each(|(src, dst)| {
                    let mut max = T::neg_infinity();
                    unsafe { T::vec_reduce_max(src.as_ptr(), &mut max, dim_m1) };
                    for (s, d) in src.iter().zip(dst.iter_mut()) {
                        *d = (*s - max).exp();
                    }
                    let mut sum_exp = T::zero();
                    unsafe { T::vec_reduce_sum(dst.as_ptr(), &mut sum_exp, dim_m1) };
                    for d in dst.iter_mut() {
                        *d /= sum_exp
                    }
                });
            let storage = candle::WithDType::to_cpu_storage_owned(dst);
            Ok((storage, Shape::from_dims(dims)))
        }

        match storage {
            CpuStorage::BF16(slice) => softmax::<half::bf16>(slice, layout),
            CpuStorage::F16(slice) => softmax::<half::f16>(slice, layout),
            CpuStorage::F32(slice) => softmax::<f32>(slice, layout),
            CpuStorage::F64(slice) => softmax::<f64>(slice, layout),
            _ => candle::bail!("unsupported dtype for softmax {:?}", storage),
        }
    }

    #[cfg(feature = "cuda")]
    fn cuda_fwd(
        &self,
        storage: &candle::CudaStorage,
        layout: &Layout,
    ) -> Result<(candle::CudaStorage, Shape)> {
        use candle::cuda_backend::cudarc::driver::{
            CudaSlice, DeviceRepr, LaunchConfig, PushKernelArg,
        };
        use candle::cuda_backend::{kernel_name, kernels, Map1, WrapErr};
        use candle::{CudaDevice, WithDType};

        struct S;
        impl Map1 for S {
            fn f<T: DeviceRepr + WithDType>(
                &self,
                src: &CudaSlice<T>,
                dev: &CudaDevice,
                layout: &Layout,
            ) -> Result<CudaSlice<T>> {
                let src = match layout.contiguous_offsets() {
                    None => candle::bail!("input has to be contiguous"),
                    Some((o1, o2)) => src.slice(o1..o2),
                };
                let el = layout.shape().elem_count();
                let dims = layout.shape().dims();
                let dim_m1 = dims[dims.len() - 1];
                let (n_rows, n_cols) = (el / dim_m1, dim_m1);

                let cfg = LaunchConfig {
                    grid_dim: (n_rows as u32, 1, 1),
                    block_dim: (1, 32, 1),
                    shared_mem_bytes: 0,
                };
                let func = dev.get_or_load_func(&kernel_name::<T>("softmax"), &kernels::REDUCE)?;
                let dst = unsafe { dev.alloc::<T>(el)? };
                let mut builder = func.builder();
                builder.arg(&src);
                builder.arg(&dst);
                candle::builder_arg!(builder, n_cols as i32);
                unsafe { builder.launch(cfg) }.w()?;
                Ok(dst)
            }
        }

        use candle::backend::BackendStorage;
        let dev = storage.device();
        let slice = S.map(&storage.slice, dev, layout)?;
        let dst = candle::cuda_backend::CudaStorage {
            slice,
            device: dev.clone(),
        };
        Ok((dst, layout.shape().clone()))
    }

    #[cfg(feature = "metal")]
    fn metal_fwd(
        &self,
        storage: &candle::MetalStorage,
        layout: &Layout,
    ) -> Result<(candle::MetalStorage, Shape)> {
        use candle::backend::BackendStorage;
        let device = storage.device();
        let encoder = device.command_encoder()?;
        encoder.set_label("softmax");
        let kernels = device.kernels();
        let name = match storage.dtype() {
            DType::F32 => "softmax_f32",
            DType::F16 => "softmax_f16",
            DType::BF16 => "softmax_bf16",
            dtype => candle::bail!("softmax-last-dim is not implemented for {dtype:?}"),
        };

        let n = layout.stride().len();
        if !(layout.is_contiguous() && layout.stride()[n - 1] == 1) {
            candle::bail!("Non contiguous softmax-last-dim is not implemented");
        }

        let last_dim = layout.dims()[layout.shape().rank() - 1];
        let elem_count = layout.shape().elem_count();
        let output = device.new_buffer(elem_count, storage.dtype(), "softmax")?;
        candle_metal_kernels::call_last_softmax(
            device.metal_device(),
            &encoder,
            kernels,
            name,
            elem_count,
            last_dim,
            storage.buffer(),
            layout.start_offset() * storage.dtype().size_in_bytes(),
            &output,
        )
        .map_err(candle::Error::wrap)?;
        let newstorage =
            candle::MetalStorage::new(output, device.clone(), elem_count, storage.dtype());
        Ok((newstorage, layout.shape().clone()))
    }
}

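/// Fused softmax over the last dimension, dispatching to the `SoftmaxLastDim`
/// custom op above (a rayon-parallel CPU kernel, or a dedicated CUDA/Metal
/// kernel). Unlike `softmax`, no backward pass is registered for this op.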
pub fn softmax_last_dim(xs: &Tensor) -> Result<Tensor> {
    xs.apply_op1_no_bwd(&SoftmaxLastDim)
}

#[derive(Debug, Clone)]
struct RmsNorm {
    eps: f32,
}

impl candle::CustomOp2 for RmsNorm {
    fn name(&self) -> &'static str {
        "rms-norm"
    }

    fn cpu_fwd(
        &self,
        s1: &CpuStorage,
        l1: &Layout,
        s2: &CpuStorage,
        l2: &Layout,
    ) -> Result<(CpuStorage, Shape)> {
        use candle::backend::BackendStorage;

        let eps = self.eps;
        fn inner<
            T: candle::WithDType
                + num_traits::Float
                + num_traits::AsPrimitive<f32>
                + num_traits::FromPrimitive,
        >(
            src: &[T],
            layout: &Layout,
            alpha: &[T],
            alpha_layout: &Layout,
            eps: f32,
        ) -> Result<(CpuStorage, Shape)> {
            let src = match layout.contiguous_offsets() {
                None => candle::bail!("input has to be contiguous"),
                Some((o1, o2)) => &src[o1..o2],
            };
            let alpha = match alpha_layout.contiguous_offsets() {
                None => candle::bail!("alpha has to be contiguous"),
                Some((o1, o2)) => &alpha[o1..o2],
            };
            let el_count = layout.shape().elem_count();
            let dims = layout.shape().dims();
            let dim_m1 = dims[dims.len() - 1];
            let mut dst = vec![T::zero(); el_count];
            src.par_chunks(dim_m1)
                .zip(dst.par_chunks_mut(dim_m1))
                .for_each(|(src, dst)| {
                    let sum2 = src
                        .iter()
                        .map(|&v| {
                            let v = v.as_();
                            v * v
                        })
                        .sum::<f32>();
                    let m = (sum2 / dim_m1 as f32 + eps).sqrt();
                    let m = T::from_f32(m).unwrap_or_else(T::nan);
                    for ((d, s), alpha) in dst.iter_mut().zip(src.iter()).zip(alpha) {
                        *d = *s / m * *alpha
                    }
                });
            let storage = candle::WithDType::to_cpu_storage_owned(dst);
            Ok((storage, Shape::from_dims(dims)))
        }

        use CpuStorage as C;
        match (s1, s2) {
            (C::BF16(s1), C::BF16(s2)) => inner::<half::bf16>(s1, l1, s2, l2, eps),
            (C::F16(s1), C::F16(s2)) => inner::<half::f16>(s1, l1, s2, l2, eps),
            (C::F32(s1), C::F32(s2)) => inner::<f32>(s1, l1, s2, l2, eps),
            _ => candle::bail!("unsupported dtype for rmsnorm {:?}", s1.dtype()),
        }
    }

    #[cfg(feature = "cuda")]
    fn cuda_fwd(
        &self,
        s1: &candle::CudaStorage,
        l1: &Layout,
        s2: &candle::CudaStorage,
        l2: &Layout,
    ) -> Result<(candle::CudaStorage, Shape)> {
        use candle::cuda_backend::cudarc::driver::{
            CudaSlice, DeviceRepr, LaunchConfig, PushKernelArg,
        };
        use candle::cuda_backend::{kernel_name, kernels, Map2, WrapErr};
        use candle::{CudaDevice, WithDType};

        struct S {
            eps: f32,
        }
        impl Map2 for S {
            fn f<T: DeviceRepr + WithDType>(
                &self,
                src: &CudaSlice<T>,
                layout: &Layout,
                alpha: &CudaSlice<T>,
                alpha_layout: &Layout,
                dev: &CudaDevice,
            ) -> Result<CudaSlice<T>> {
                let src = match layout.contiguous_offsets() {
                    None => candle::bail!("input has to be contiguous"),
                    Some((o1, o2)) => src.slice(o1..o2),
                };
                let alpha = match alpha_layout.contiguous_offsets() {
                    None => candle::bail!("alpha has to be contiguous"),
                    Some((o1, o2)) => alpha.slice(o1..o2),
                };
                let el = layout.shape().elem_count();
                let dims = layout.shape().dims();
                let dim_m1 = dims[dims.len() - 1];
                let (n_rows, n_cols) = (el / dim_m1, dim_m1);

                let block_size = if n_cols < 1024 { 32 } else { 1024 };
                let cfg = LaunchConfig {
                    grid_dim: (n_rows as u32, 1, 1),
                    block_dim: (block_size, 1, 1),
                    shared_mem_bytes: 0,
                };
                let func = dev.get_or_load_func(&kernel_name::<T>("rmsnorm"), &kernels::REDUCE)?;
                let dst = unsafe { dev.alloc::<T>(el)? };
                let mut builder = func.builder();
                builder.arg(&src);
                builder.arg(&dst);
                builder.arg(&alpha);
                candle::builder_arg!(builder, n_cols as i32, block_size as i32, self.eps);
                unsafe { builder.launch(cfg) }.w()?;
                Ok(dst)
            }
        }

        use candle::backend::BackendStorage;
        let dev = s1.device();
        let slice = S { eps: self.eps }.map(&s1.slice, l1, &s2.slice, l2, dev)?;
        let dst = candle::cuda_backend::CudaStorage {
            slice,
            device: dev.clone(),
        };
        Ok((dst, l1.shape().clone()))
    }

    #[cfg(feature = "metal")]
    fn metal_fwd(
        &self,
        s1: &candle::MetalStorage,
        l1: &Layout,
        s2: &candle::MetalStorage,
        l2: &Layout,
    ) -> Result<(candle::MetalStorage, Shape)> {
        use candle::backend::BackendStorage;
        let device = s1.device();
        let encoder = device.command_encoder()?;
        encoder.set_label("rmsnorm");
        let kernels = device.kernels();
        let name = match (s1.dtype(), s2.dtype()) {
            (DType::F32, DType::F32) => "rmsnorm_f32",
            (DType::F16, DType::F16) => "rmsnorm_f16",
            (DType::BF16, DType::BF16) => "rmsnorm_bf16",
            (dt1, dt2) => candle::bail!("rmsnorm is not implemented for {dt1:?} {dt2:?}"),
        };

        if !(l1.is_contiguous() && l2.is_contiguous()) {
            candle::bail!("Non contiguous rmsnorm is not implemented");
        }

        let last_dim = l1.dims()[l1.shape().rank() - 1];
        let elem_count = l1.shape().elem_count();
        let output = device.new_buffer(elem_count, s1.dtype(), "rmsnorm")?;
        candle_metal_kernels::call_rms_norm(
            device.metal_device(),
            &encoder,
            kernels,
            name,
            elem_count,
            last_dim,
            self.eps,
            s1.buffer(),
            l1.start_offset() * s1.dtype().size_in_bytes(),
            s2.buffer(),
            l2.start_offset() * s2.dtype().size_in_bytes(),
            &output,
        )
        .map_err(candle::Error::wrap)?;
        let newstorage = candle::MetalStorage::new(output, device.clone(), elem_count, s1.dtype());
        Ok((newstorage, l1.shape().clone()))
    }
}

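/// Reference RMS normalization built from plain tensor ops:
/// `x / sqrt(mean(x^2) + eps) * alpha` over the last dimension. Half-precision
/// inputs are upcast to `f32` for the reduction and cast back afterwards.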
pub fn rms_norm_slow(x: &Tensor, alpha: &Tensor, eps: f32) -> Result<Tensor> {
    let x_dtype = x.dtype();
    let internal_dtype = match x_dtype {
        DType::F16 | DType::BF16 => DType::F32,
        d => d,
    };
    let hidden_size = x.dim(D::Minus1)?;
    let x = x.to_dtype(internal_dtype)?;
    let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
    let x_normed = x.broadcast_div(&(norm_x + eps as f64)?.sqrt()?)?;
    x_normed.to_dtype(x_dtype)?.broadcast_mul(alpha)
}

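/// Fused RMS normalization over the last dimension via the `RmsNorm` custom op.
/// `alpha` must be a 1-D scale whose length matches the last dimension of `xs`.
///
/// A usage sketch (assuming this module is exposed as `candle_nn::ops`):
///
/// ```ignore
/// use candle::{Device, Tensor};
/// let xs = Tensor::new(&[[1f32, 2., 3., 4.]], &Device::Cpu)?;
/// let alpha = Tensor::new(&[1f32, 1., 1., 1.], &Device::Cpu)?;
/// let ys = candle_nn::ops::rms_norm(&xs, &alpha, 1e-5)?;
/// # Ok::<(), candle::Error>(())
/// ```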
pub fn rms_norm(xs: &Tensor, alpha: &Tensor, eps: f32) -> Result<Tensor> {
    let hidden_size_xs = xs.dim(D::Minus1)?;
    let hidden_size_alpha = alpha.dims1()?;
    if hidden_size_xs != hidden_size_alpha {
        candle::bail!(
            "shape mismatch in rms-norm {:?} {:?}",
            xs.shape(),
            alpha.shape()
        )
    }
    xs.apply_op2_no_bwd(alpha, &RmsNorm { eps })
}

#[derive(Debug, Clone)]
struct LayerNorm {
    eps: f32,
}

impl candle::CustomOp3 for LayerNorm {
    fn name(&self) -> &'static str {
        "layer-norm"
    }

    fn cpu_fwd(
        &self,
        s1: &CpuStorage,
        l1: &Layout,
        s2: &CpuStorage,
        l2: &Layout,
        s3: &CpuStorage,
        l3: &Layout,
    ) -> Result<(CpuStorage, Shape)> {
        use candle::backend::BackendStorage;

        let eps = self.eps;
        fn inner<
            T: candle::WithDType
                + num_traits::Float
                + num_traits::AsPrimitive<f32>
                + num_traits::FromPrimitive,
        >(
            src: &[T],
            layout: &Layout,
            alpha: &[T],
            alpha_layout: &Layout,
            beta: &[T],
            beta_layout: &Layout,
            eps: f32,
        ) -> Result<(CpuStorage, Shape)> {
            let src = match layout.contiguous_offsets() {
                None => candle::bail!("input has to be contiguous"),
                Some((o1, o2)) => &src[o1..o2],
            };
            let alpha = match alpha_layout.contiguous_offsets() {
                None => candle::bail!("alpha has to be contiguous"),
                Some((o1, o2)) => &alpha[o1..o2],
            };
            let beta = match beta_layout.contiguous_offsets() {
                None => candle::bail!("beta has to be contiguous"),
                Some((o1, o2)) => &beta[o1..o2],
            };
            let el_count = layout.shape().elem_count();
            let dims = layout.shape().dims();
            let dim_m1 = dims[dims.len() - 1];
            let mut dst = vec![T::zero(); el_count];
            src.par_chunks(dim_m1)
                .zip(dst.par_chunks_mut(dim_m1))
                .for_each(|(src, dst)| {
                    let mut sum = 0f32;
                    let mut sum2 = 0f32;
                    for v in src {
                        let v = v.as_();
                        sum += v;
                        sum2 += v * v;
                    }
                    let mean = sum / dim_m1 as f32;
                    let var = sum2 / dim_m1 as f32 - mean * mean;
                    let inv_std = (var + eps).sqrt().recip();
                    for ((d, s), (alpha, beta)) in
                        dst.iter_mut().zip(src.iter()).zip(alpha.iter().zip(beta))
                    {
                        let alpha = alpha.as_();
                        let beta = beta.as_();
                        let d_ = (s.as_() - mean) * inv_std * alpha + beta;
                        *d = T::from_f32(d_).unwrap_or_else(T::nan);
                    }
                });
            let storage = candle::WithDType::to_cpu_storage_owned(dst);
            Ok((storage, Shape::from_dims(dims)))
        }

        use CpuStorage as C;
        match (s1, s2, s3) {
            (C::BF16(s1), C::BF16(s2), C::BF16(s3)) => {
                inner::<half::bf16>(s1, l1, s2, l2, s3, l3, eps)
            }
            (C::F16(s1), C::F16(s2), C::F16(s3)) => inner::<half::f16>(s1, l1, s2, l2, s3, l3, eps),
            (C::F32(s1), C::F32(s2), C::F32(s3)) => inner::<f32>(s1, l1, s2, l2, s3, l3, eps),
            _ => candle::bail!("unsupported dtype for layernorm {:?}", s1.dtype()),
        }
    }

    #[cfg(feature = "cuda")]
    fn cuda_fwd(
        &self,
        s1: &candle::CudaStorage,
        l1: &Layout,
        s2: &candle::CudaStorage,
        l2: &Layout,
        s3: &candle::CudaStorage,
        l3: &Layout,
    ) -> Result<(candle::CudaStorage, Shape)> {
        use candle::cuda_backend::cudarc::driver::{
            CudaSlice, DeviceRepr, LaunchConfig, PushKernelArg,
        };
        use candle::cuda_backend::{kernel_name, kernels, Map3, WrapErr};
        use candle::{CudaDevice, WithDType};

        struct S {
            eps: f32,
        }
        impl Map3 for S {
            fn f<T: DeviceRepr + WithDType>(
                &self,
                src: &CudaSlice<T>,
                layout: &Layout,
                alpha: &CudaSlice<T>,
                alpha_layout: &Layout,
                beta: &CudaSlice<T>,
                beta_layout: &Layout,
                dev: &CudaDevice,
            ) -> Result<CudaSlice<T>> {
                let src = match layout.contiguous_offsets() {
                    None => candle::bail!("input has to be contiguous"),
                    Some((o1, o2)) => src.slice(o1..o2),
                };
                let alpha = match alpha_layout.contiguous_offsets() {
                    None => candle::bail!("alpha has to be contiguous"),
                    Some((o1, o2)) => alpha.slice(o1..o2),
                };
                let beta = match beta_layout.contiguous_offsets() {
                    None => candle::bail!("beta has to be contiguous"),
                    Some((o1, o2)) => beta.slice(o1..o2),
                };
                let el = layout.shape().elem_count();
                let dims = layout.shape().dims();
                let dim_m1 = dims[dims.len() - 1];
                let (n_rows, n_cols) = (el / dim_m1, dim_m1);

                let block_size = if n_cols < 1024 { 32 } else { 1024 };
                let cfg = LaunchConfig {
                    grid_dim: (n_rows as u32, 1, 1),
                    block_dim: (block_size, 1, 1),
                    shared_mem_bytes: 0,
                };
                let func =
                    dev.get_or_load_func(&kernel_name::<T>("layernorm"), &kernels::REDUCE)?;
                let dst = unsafe { dev.alloc::<T>(el)? };
                let mut builder = func.builder();
                builder.arg(&src);
                builder.arg(&dst);
                builder.arg(&alpha);
                builder.arg(&beta);
                candle::builder_arg!(builder, n_cols as i32, block_size as i32, self.eps);
                unsafe { builder.launch(cfg) }.w()?;
                Ok(dst)
            }
        }

        use candle::backend::BackendStorage;
        let dev = s1.device();
        let slice = S { eps: self.eps }.map(&s1.slice, l1, &s2.slice, l2, &s3.slice, l3, dev)?;
        let dst = candle::cuda_backend::CudaStorage {
            slice,
            device: dev.clone(),
        };
        Ok((dst, l1.shape().clone()))
    }

    #[cfg(feature = "metal")]
    fn metal_fwd(
        &self,
        s1: &candle::MetalStorage,
        l1: &Layout,
        s2: &candle::MetalStorage,
        l2: &Layout,
        s3: &candle::MetalStorage,
        l3: &Layout,
    ) -> Result<(candle::MetalStorage, Shape)> {
        use candle::backend::BackendStorage;
        let device = s1.device();
        let encoder = device.command_encoder()?;
        encoder.set_label("layernorm");
        let kernels = device.kernels();
        let name = match (s1.dtype(), s2.dtype(), s3.dtype()) {
            (DType::F32, DType::F32, DType::F32) => "layernorm_f32",
            (DType::F16, DType::F16, DType::F16) => "layernorm_f16",
            (DType::BF16, DType::BF16, DType::BF16) => "layernorm_bf16",
            (dt1, dt2, dt3) => {
                candle::bail!("layernorm is not implemented for {dt1:?} {dt2:?} {dt3:?}")
            }
        };

        if !(l1.is_contiguous() && l2.is_contiguous() && l3.is_contiguous()) {
            candle::bail!("Non contiguous layernorm is not implemented");
        }

        let last_dim = l1.dims()[l1.shape().rank() - 1];
        let elem_count = l1.shape().elem_count();
        let output = device.new_buffer(elem_count, s1.dtype(), "layernorm")?;
        candle_metal_kernels::call_layer_norm(
            device.metal_device(),
            &encoder,
            kernels,
            name,
            elem_count,
            last_dim,
            self.eps,
            s1.buffer(),
            l1.start_offset() * s1.dtype().size_in_bytes(),
            s2.buffer(),
            l2.start_offset() * s2.dtype().size_in_bytes(),
            s3.buffer(),
            l3.start_offset() * s3.dtype().size_in_bytes(),
            &output,
        )
        .map_err(candle::Error::wrap)?;
        let newstorage = candle::MetalStorage::new(output, device.clone(), elem_count, s1.dtype());
        Ok((newstorage, l1.shape().clone()))
    }
}

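/// Reference layer normalization built from plain tensor ops: centers with the
/// per-slice mean, scales by `1 / sqrt(var + eps)`, then applies `alpha` and
/// `beta`.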
pub fn layer_norm_slow(x: &Tensor, alpha: &Tensor, beta: &Tensor, eps: f32) -> Result<Tensor> {
    let x_dtype = x.dtype();
    let internal_dtype = match x_dtype {
        DType::F16 | DType::BF16 => DType::F32,
        d => d,
    };
    let hidden_size = x.dim(D::Minus1)?;
    let x = x.to_dtype(internal_dtype)?;
    let x = {
        let mean_x = (x.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
        x.broadcast_sub(&mean_x)?
    };
    let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
    let x_normed = x.broadcast_div(&(norm_x + eps as f64)?.sqrt()?)?;
    x_normed
        .to_dtype(x_dtype)?
        .broadcast_mul(alpha)?
        .broadcast_add(beta)
}

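/// Fused layer normalization over the last dimension via the `LayerNorm` custom
/// op. `alpha` (scale) and `beta` (shift) must be 1-D with the same length as
/// the last dimension of `xs`.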
pub fn layer_norm(xs: &Tensor, alpha: &Tensor, beta: &Tensor, eps: f32) -> Result<Tensor> {
    let hidden_size_xs = xs.dim(D::Minus1)?;
    let hidden_size_alpha = alpha.dims1()?;
    let hidden_size_beta = beta.dims1()?;
    if hidden_size_xs != hidden_size_alpha || hidden_size_xs != hidden_size_beta {
        candle::bail!(
            "shape mismatch in layer-norm src: {:?} alpha: {:?} beta: {:?}",
            xs.shape(),
            alpha.shape(),
            beta.shape()
        )
    }
    xs.apply_op3_no_bwd(alpha, beta, &LayerNorm { eps })
}

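/// Pixel shuffle: rearranges `(b, c * r^2, h, w)` into `(b, c, h * r, w * r)`
/// for an upscale factor `r`, as used in sub-pixel convolution upsampling.
///
/// A shape sketch (assuming this module is exposed as `candle_nn::ops`):
///
/// ```ignore
/// use candle::{DType, Device, Tensor};
/// let xs = Tensor::zeros((1, 8, 4, 4), DType::F32, &Device::Cpu)?;
/// let ys = candle_nn::ops::pixel_shuffle(&xs, 2)?;
/// assert_eq!(ys.dims(), &[1, 2, 8, 8]);
/// # Ok::<(), candle::Error>(())
/// ```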
pub fn pixel_shuffle(xs: &Tensor, upscale_factor: usize) -> Result<Tensor> {
    let (b_size, c, h, w) = xs.dims4()?;
    let out_c = c / upscale_factor / upscale_factor;
    xs.reshape((b_size, out_c, upscale_factor, upscale_factor, h, w))?
        .permute((0, 1, 4, 2, 5, 3))?
        .reshape((b_size, out_c, h * upscale_factor, w * upscale_factor))
}

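/// Inverse of `pixel_shuffle`: rearranges `(b, c, h * r, w * r)` into
/// `(b, c * r^2, h, w)` for a downscale factor `r`.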
pub fn pixel_unshuffle(xs: &Tensor, downscale_factor: usize) -> Result<Tensor> {
    let (b_size, c, h, w) = xs.dims4()?;
    let out_c = c * downscale_factor * downscale_factor;
    xs.reshape((
        b_size,
        c,
        h / downscale_factor,
        downscale_factor,
        w / downscale_factor,
        downscale_factor,
    ))?
    .permute((0, 1, 3, 5, 2, 4))?
    .reshape((b_size, out_c, h / downscale_factor, w / downscale_factor))
}

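/// Replication padding over the last two dimensions: edge rows and columns are
/// repeated outward. Only `pad == 0` (a no-op) and `pad == 1` are supported.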
pub fn replication_pad2d(xs: &Tensor, pad: usize) -> Result<Tensor> {
    match pad {
        0 => Ok(xs.clone()),
        1 => {
            let (_b_size, _c, h, w) = xs.dims4()?;
            let (first, last) = (xs.narrow(3, 0, 1)?, xs.narrow(3, w - 1, 1)?);
            let xs = Tensor::cat(&[&first, xs, &last], 3)?;
            let (first, last) = (xs.narrow(2, 0, 1)?, xs.narrow(2, h - 1, 1)?);
            Tensor::cat(&[&first, &xs, &last], 2)
        }
        n => candle::bail!("replication-pad with a size of {n} is not supported"),
    }
}

#[derive(Clone, Debug)]
pub struct Identity;

impl Identity {
    pub fn new() -> Identity {
        Self
    }
}

impl Default for Identity {
    fn default() -> Self {
        Self
    }
}

impl Module for Identity {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        Ok(xs.clone())
    }
}

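// Scaled dot-product attention as a custom op. Only the Metal backend is
// implemented here; it selects between a single-pass vector kernel, a two-pass
// vector kernel for long key sequences, and a full-attention kernel.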
#[allow(dead_code)]
struct Sdpa {
    scale: f32,
    softcapping: f32,
    mask: Option<Tensor>,
    do_causal: bool,
}

impl candle::CustomOp3 for Sdpa {
    fn name(&self) -> &'static str {
        "metal-sdpa"
    }

    fn cpu_fwd(
        &self,
        _s1: &CpuStorage,
        _l1: &Layout,
        _s2: &CpuStorage,
        _l2: &Layout,
        _s3: &CpuStorage,
        _l3: &Layout,
    ) -> Result<(CpuStorage, Shape)> {
        candle::bail!("SDPA has no cpu impl")
    }

    #[cfg(feature = "metal")]
    fn metal_fwd(
        &self,
        q: &candle::MetalStorage,
        q_l: &Layout,
        k: &candle::MetalStorage,
        k_l: &Layout,
        v: &candle::MetalStorage,
        v_l: &Layout,
    ) -> Result<(candle::MetalStorage, Shape)> {
        use candle::backend::BackendStorage;
        use candle_metal_kernels::SdpaDType;

        let device = q.device();

        let out_dims = vec![q_l.dim(0)?, q_l.dim(1)?, q_l.dim(2)?, v_l.dim(3)?];
        let elem_count: usize = out_dims.iter().product();
        let out_shape = Shape::from_dims(&out_dims);
        let out_layout = Layout::contiguous(out_shape.clone());

        let output = device.new_buffer(elem_count, q.dtype(), "sdpa_o")?;

        if q_l.dim(D::Minus1)? != k_l.dim(D::Minus1)? {
            candle::bail!("`q` and `k` last dims must match");
        }

        if v_l.dim(D::Minus(3))? != k_l.dim(D::Minus(3))? {
            candle::bail!("`k` and `v` head dims must match");
        }

        if q_l.dim(D::Minus(3))? % k_l.dim(D::Minus(3))? != 0 {
            candle::bail!("query `n_heads` must be a multiple of `n_kv_heads`");
        }

        let k_head = k_l.dim(D::Minus1)?;
        let q_head = q_l.dim(D::Minus1)?;
        let q_seq = q_l.dim(2)?;
        let k_seq = k_l.dim(2)?;

        let mut implementation_supports_use_case = q_head == k_head;
        let supported_head_dim = q_head == 32
            || q_head == 64
            || q_head == 72
            || q_head == 80
            || q_head == 96
            || q_head == 128
            || q_head == 256;

        let supports_sdpa_full_mask = self.mask.is_none() || q_seq <= k_seq;
        let supports_sdpa_full = q_seq > 8 && supported_head_dim && supports_sdpa_full_mask;
        let supports_sdpa_vector = q_seq <= 8 && supported_head_dim && q_seq <= k_seq;

        implementation_supports_use_case &= supports_sdpa_full || supports_sdpa_vector;

        if !supported_head_dim {
            candle::bail!(
                "Metal SDPA does not support q head dim {q_head}: q dims {:?}, k dims {:?}, v dims {:?}.",
                q_l.dims(),
                k_l.dims(),
                v_l.dims()
            );
        }
        if !implementation_supports_use_case {
            candle::bail!(
                "Metal SDPA does not support q dims {:?}, k dims {:?}, v dims {:?}.",
                q_l.dims(),
                k_l.dims(),
                v_l.dims()
            );
        }

        for t in [k.dtype(), v.dtype()] {
            if q.dtype() != t {
                candle::bail!("all q, k, v dtypes must match.");
            }
        }

        let itype = match q.dtype() {
            DType::BF16 => SdpaDType::BF16,
            DType::F16 => SdpaDType::F16,
            DType::F32 => SdpaDType::F32,
            other => candle::bail!("unsupported sdpa type {other:?}"),
        };

        let encoder = q.device().command_encoder()?;
        if supports_sdpa_vector {
            // Short query sequences use the vector kernel; beyond this key
            // length the two-pass variant accumulates partial results, running
            // sums and running maxima in f32 buffers (one slot per block) and
            // reduces them in a second pass.
            const TWO_PASS_K_THRESHOLD: usize = 1024;
            if k_seq >= TWO_PASS_K_THRESHOLD {
                let mut intermediate_shape = [
                    &out_dims[0..out_dims.len() - 2],
                    &[candle_metal_kernels::SDPA_2PASS_BLOCKS],
                    &[out_dims[out_dims.len() - 1]],
                ]
                .concat();
                let intermediate = device.new_buffer(
                    intermediate_shape.iter().product::<usize>(),
                    DType::F32,
                    "sdpa_2pass_intermediate",
                )?;
                let _ = intermediate_shape.pop().unwrap();
                let sums = device.new_buffer(
                    intermediate_shape.iter().product::<usize>(),
                    DType::F32,
                    "sdpa_2pass_sums",
                )?;
                let maxs = device.new_buffer(
                    intermediate_shape.iter().product::<usize>(),
                    DType::F32,
                    "sdpa_2pass_maxs",
                )?;

                encoder.set_label("vector_attention");
                candle_metal_kernels::call_sdpa_vector_2pass(
                    q.device().device(),
                    &encoder,
                    q.device().kernels(),
                    q_l.start_offset(),
                    q_l.dims(),
                    q.buffer(),
                    k_l.start_offset(),
                    k_l.dims(),
                    k_l.stride(),
                    k.buffer(),
                    v_l.start_offset(),
                    v_l.stride(),
                    v.buffer(),
                    &output,
                    &intermediate,
                    &sums,
                    &maxs,
                    self.scale,
                    self.softcapping,
                    itype,
                )
                .map_err(candle::Error::wrap)?;
            } else {
                encoder.set_label("vector_attention");
                candle_metal_kernels::call_sdpa_vector(
                    q.device().device(),
                    &encoder,
                    q.device().kernels(),
                    q_l.start_offset(),
                    q_l.dims(),
                    q.buffer(),
                    k_l.start_offset(),
                    k_l.dims(),
                    k_l.stride(),
                    k.buffer(),
                    v_l.start_offset(),
                    v_l.stride(),
                    v.buffer(),
                    &output,
                    self.scale,
                    self.softcapping,
                    itype,
                )
                .map_err(candle::Error::wrap)?;
            }
        } else if supports_sdpa_full {
            encoder.set_label("full_attention");
            if self.softcapping != 1. {
                candle::bail!("SDPA full requires softcapping to be disabled (1.0)");
            }

            let mask_s_l = self.mask.as_ref().map(|m| m.storage_and_layout());

            let (mask_type, mask_buffer, mask_strides) = if let Some(mask) = &self.mask {
                let (mask_s, mask_l) = mask_s_l.as_ref().unwrap();

                let mask_buffer = match &**mask_s {
                    candle::Storage::Metal(m) => m.buffer(),
                    _ => candle::bail!("Expected metal device for mask"),
                };

                let mask_type = match mask.dtype() {
                    DType::BF16 => SdpaDType::BF16,
                    DType::F16 => SdpaDType::F16,
                    DType::F32 => SdpaDType::F32,
                    other => candle::bail!("unsupported sdpa type {other:?}"),
                };
                if mask_type != itype {
                    candle::bail!("Mask type {mask_type:?} must match q type {itype:?}");
                }

                if mask_l.dims() != [q_l.dim(0)?, q_l.dim(1)?, q_l.dim(2)?, k_seq] {
                    candle::bail!(
                        "Mask shape must be {:?} (bs, qheads, qseq, kseq), got {:?}",
                        [q_l.dim(0)?, q_head, q_l.dim(2)?, k_seq],
                        mask_l.dims()
                    );
                }

                (
                    Some(mask_type),
                    Some(mask_buffer),
                    Some(mask_l.stride().to_vec()),
                )
            } else {
                (None, None, None)
            };

            candle_metal_kernels::call_sdpa_full(
                q.device().device(),
                &encoder,
                q.device().kernels(),
                q_l.start_offset(),
                q_l.dims(),
                q_l.stride(),
                q.buffer(),
                k_l.start_offset(),
                k_l.dims(),
                k_l.stride(),
                k.buffer(),
                v_l.start_offset(),
                v.buffer(),
                v_l.stride(),
                mask_type,
                mask_buffer,
                mask_strides.as_deref(),
                &output,
                out_layout.stride(),
                self.scale,
                self.do_causal,
                itype,
            )
            .map_err(candle::Error::wrap)?;
        } else {
            candle::bail!("must be vector or full sdpa kernel");
        }

        let newstorage = candle::MetalStorage::new(output, device.clone(), elem_count, q.dtype());
        Ok((newstorage, out_shape))
    }
}

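/// Scaled dot-product attention, `softmax(q k^T * scale) v`, dispatched to the
/// Metal-only `Sdpa` custom op above (there is no CPU or CUDA path here).
///
/// `q`, `k` and `v` are laid out as `(bs, n_heads, seq_len, head_dim)` and must
/// share a dtype. A `softcapping` of `1.0` disables soft-capping; in this
/// implementation the optional `mask` and the `do_causal` flag are only
/// consumed by the full-attention kernel.
///
/// A call sketch (assuming this module is exposed as `candle_nn::ops`):
///
/// ```ignore
/// // q: (bs, n_heads, q_seq, head_dim), k/v: (bs, n_kv_heads, k_seq, head_dim)
/// let scale = 1.0 / (head_dim as f32).sqrt();
/// let out = candle_nn::ops::sdpa(&q, &k, &v, None, false, scale, 1.0)?;
/// ```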
pub fn sdpa(
    q: &Tensor,
    k: &Tensor,
    v: &Tensor,
    mask: Option<&Tensor>,
    do_causal: bool,
    scale: f32,
    softcapping: f32,
) -> Result<Tensor> {
    q.apply_op3_no_bwd(
        k,
        v,
        &Sdpa {
            scale,
            softcapping,
            mask: mask.cloned(),
            do_causal,
        },
    )
}