use candle::{CpuStorage, DType, Layout, Module, Result, Shape, Tensor, D};
use rayon::prelude::*;

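/// Applies the softmax function along dimension `dim`, rescaling the elements so that they lie
/// in `[0, 1]` and sum to 1 over that dimension. The maximum along `dim` is subtracted first for
/// numerical stability.
///
/// Illustrative usage sketch (not a checked doctest), assuming this module is exposed as
/// `candle_nn::ops`:
///
/// ```ignore
/// use candle::{Device, Tensor};
/// let a = Tensor::new(&[[0f32, 1., 0., 1.], [-2., 2., 3., -3.]], &Device::Cpu)?;
/// let a = candle_nn::ops::softmax(&a, 1)?; // each row now sums to 1
/// # Ok::<(), candle::Error>(())
/// ```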
pub fn softmax<D: candle::shape::Dim>(xs: &Tensor, dim: D) -> Result<Tensor> {
    let dim = dim.to_index(xs.shape(), "softmax")?;
    let max = xs.max_keepdim(dim)?;
    let diff = xs.broadcast_sub(&max)?;
    let num = diff.exp()?;
    let den = num.sum_keepdim(dim)?;
    num.broadcast_div(&den)
}

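/// Computes `log(softmax(xs))` along dimension `d` in a numerically stable way: the maximum is
/// subtracted before exponentiating and the result stays in log space.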
pub fn log_softmax<D: candle::shape::Dim>(xs: &Tensor, d: D) -> Result<Tensor> {
    let d = d.to_index(xs.shape(), "log-softmax")?;
    let max = xs.max_keepdim(d)?;
    let diff = xs.broadcast_sub(&max)?;
    let sum_exp = diff.exp()?.sum_keepdim(d)?;
    let log_sm = diff.broadcast_sub(&sum_exp.log()?)?;
    Ok(log_sm)
}

pub fn silu(xs: &Tensor) -> Result<Tensor> {
    xs.silu()
}

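/// SwiGLU activation: splits the last dimension into two halves `a` and `b` and returns
/// `silu(a) * b`. The last dimension of the input must therefore be even.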
pub fn swiglu(xs: &Tensor) -> Result<Tensor> {
    let xs = xs.chunk(2, D::Minus1)?;
    &xs[0].silu()? * &xs[1]
}

struct Sigmoid;

impl candle::CustomOp1 for Sigmoid {
    fn name(&self) -> &'static str {
        "sigmoid"
    }

    fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)> {
        use candle::backend::BackendStorage;

        fn fwd<T: num_traits::Float>(v: T) -> T {
            (v.neg().exp() + T::one()).recip()
        }

        let storage = match storage {
            CpuStorage::BF16(slice) => {
                CpuStorage::BF16(candle::cpu_backend::unary_map(slice, layout, fwd))
            }
            CpuStorage::F16(slice) => {
                CpuStorage::F16(candle::cpu_backend::unary_map(slice, layout, fwd))
            }
            CpuStorage::F32(slice) => {
                CpuStorage::F32(candle::cpu_backend::unary_map(slice, layout, fwd))
            }
            CpuStorage::F64(slice) => {
                CpuStorage::F64(candle::cpu_backend::unary_map(slice, layout, fwd))
            }
            _ => Err(candle::Error::UnsupportedDTypeForOp(
                storage.dtype(),
                self.name(),
            ))?,
        };
        Ok((storage, layout.shape().clone()))
    }

    #[cfg(feature = "cuda")]
    fn cuda_fwd(
        &self,
        storage: &candle::CudaStorage,
        layout: &Layout,
    ) -> Result<(candle::CudaStorage, Shape)> {
        use candle::backend::BackendStorage;
        use candle::cuda_backend::cudarc::driver::{
            CudaSlice, DeviceRepr, LaunchConfig, PushKernelArg, ValidAsZeroBits,
        };
        use candle::cuda_backend::SlicePtrOrNull;
        use candle::cuda_backend::{kernel_name, kernels, Map1, WrapErr};
        use candle::{CudaDevice, WithDType};

        struct S;
        impl Map1 for S {
            fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
                &self,
                src: &CudaSlice<T>,
                dev: &CudaDevice,
                layout: &Layout,
            ) -> Result<CudaSlice<T>> {
                let shape = layout.shape();
                let dims = shape.dims();
                let el_count = shape.elem_count();
                let cfg = LaunchConfig::for_num_elems(el_count as u32);
                let ds = SlicePtrOrNull::params_from_layout(dev, layout)?;
                let src = &src.slice(layout.start_offset()..);
                let func = dev.get_or_load_func(&kernel_name::<T>("usigmoid"), &kernels::UNARY)?;
                let out = unsafe { dev.alloc::<T>(el_count)? };

                let mut builder = func.builder();
                candle::builder_arg!(builder, el_count, dims.len());
                ds.builder_arg(&mut builder);
                builder.arg(src);
                builder.arg(&out);
                unsafe { builder.launch(cfg) }.w()?;
                Ok(out)
            }
        }

        let dev = storage.device();
        let slice = S.map(&storage.slice, dev, layout)?;
        let dst = candle::CudaStorage {
            slice,
            device: dev.clone(),
        };
        Ok((dst, layout.shape().clone()))
    }

    #[cfg(feature = "metal")]
    fn metal_fwd(
        &self,
        storage: &candle::MetalStorage,
        layout: &Layout,
    ) -> Result<(candle::MetalStorage, Shape)> {
        use candle::backend::BackendStorage;
        use candle::MetalError;
        let device = storage.device();
        let dtype = storage.dtype();
        let shape = layout.shape();
        let el_count = shape.elem_count();
        let buffer = device.new_buffer(el_count, dtype, "sigmoid")?;
        let command_buffer = device.command_buffer()?;
        command_buffer.set_label("sigmoid");
        let src = candle_metal_kernels::BufferOffset {
            buffer: storage.buffer(),
            offset_in_bytes: layout.start_offset() * storage.dtype().size_in_bytes(),
        };

        match (el_count % 2, dtype, layout.is_contiguous()) {
            (0, DType::BF16 | DType::F16, true) => {
                use candle_metal_kernels::unary::contiguous_tiled;
                let kernel_name = match dtype {
                    DType::F16 => contiguous_tiled::sigmoid::HALF,
                    DType::F32 => contiguous_tiled::sigmoid::FLOAT,
                    DType::BF16 => contiguous_tiled::sigmoid::BFLOAT,
                    dtype => {
                        candle::bail!(
                            "Metal contiguous_tiled unary sigmoid {dtype:?} not implemented"
                        )
                    }
                };
                candle_metal_kernels::call_unary_contiguous_tiled(
                    device.metal_device(),
                    &command_buffer,
                    device.kernels(),
                    kernel_name,
                    el_count,
                    src,
                    &buffer,
                )
                .map_err(MetalError::from)?;
            }
            (_, _, true) => {
                use candle_metal_kernels::unary::contiguous;
                let kernel_name = match dtype {
                    DType::F16 => contiguous::sigmoid::HALF,
                    DType::F32 => contiguous::sigmoid::FLOAT,
                    DType::BF16 => contiguous::sigmoid::BFLOAT,
                    dtype => {
                        candle::bail!("Metal contiguous unary sigmoid {dtype:?} not implemented")
                    }
                };
                candle_metal_kernels::call_unary_contiguous(
                    device.metal_device(),
                    &command_buffer,
                    device.kernels(),
                    kernel_name,
                    el_count,
                    src,
                    &buffer,
                )
                .map_err(MetalError::from)?;
            }
            (_, _, false) => {
                use candle_metal_kernels::unary::strided;
                let kernel_name = match dtype {
                    DType::F16 => strided::sigmoid::HALF,
                    DType::F32 => strided::sigmoid::FLOAT,
                    DType::BF16 => strided::sigmoid::BFLOAT,
                    dtype => {
                        candle::bail!("Metal strided unary sigmoid {dtype:?} not implemented")
                    }
                };
                let dst = candle_metal_kernels::BufferOffset::zero_offset(&buffer);
                candle_metal_kernels::call_unary_strided(
                    device.metal_device(),
                    &command_buffer,
                    device.kernels(),
                    kernel_name,
                    layout.dims(),
                    src,
                    layout.stride(),
                    dst,
                )
                .map_err(MetalError::from)?;
            }
        }

        let new_storage = candle::MetalStorage::new(buffer, device.clone(), el_count, dtype);
        Ok((new_storage, layout.shape().clone()))
    }

    fn bwd(&self, _arg: &Tensor, res: &Tensor, grad_res: &Tensor) -> Result<Option<Tensor>> {
        let d_dx_sigmoid = res.ones_like()?.sub(res)?.mul(res)?;
        Ok(Some(grad_res.mul(&d_dx_sigmoid)?))
    }
}

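/// Element-wise sigmoid, `1 / (1 + exp(-x))`, backed by the custom `Sigmoid` op above so that
/// dedicated CPU/CUDA/Metal kernels are used and the backward pass is `grad * s * (1 - s)`.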
pub fn sigmoid(xs: &Tensor) -> Result<Tensor> {
    xs.apply_op1(Sigmoid)
}

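/// Hard sigmoid, a piecewise-linear approximation of the sigmoid: `clamp((x + 3) / 6, 0, 1)`.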
pub fn hard_sigmoid(xs: &Tensor) -> Result<Tensor> {
    ((xs + 3.0)? / 6.0)?.clamp(0f32, 1f32)
}

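/// Leaky ReLU: returns `x` for positive values and `negative_slope * x` for negative values.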
pub fn leaky_relu(xs: &Tensor, negative_slope: f64) -> Result<Tensor> {
    let zeros = xs.zeros_like()?;
    xs.maximum(&zeros)? + xs.minimum(&zeros)? * negative_slope
}

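/// Randomly zeroes elements with probability `drop_p` and scales the surviving ones by
/// `1 / (1 - drop_p)` so the expected value is preserved. Fails if `drop_p` is outside `[0, 1)`.
///
/// Illustrative usage sketch (not a checked doctest, the mask is random):
///
/// ```ignore
/// use candle::{Device, Tensor};
/// let xs = Tensor::ones((2, 4), candle::DType::F32, &Device::Cpu)?;
/// let ys = candle_nn::ops::dropout(&xs, 0.5)?; // surviving values are scaled to 2.0
/// # Ok::<(), candle::Error>(())
/// ```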
pub fn dropout(xs: &Tensor, drop_p: f32) -> Result<Tensor> {
    if !(0. ..1.).contains(&drop_p) {
        candle::bail!("dropout probability has to be in [0, 1), got {drop_p}")
    }
    let rand = Tensor::rand(0f32, 1f32, xs.shape(), xs.device())?;
    let scale = 1.0 / (1.0 - drop_p as f64);
    let drop_p = Tensor::new(drop_p, xs.device())?.broadcast_as(xs.shape())?;
    let mask = (rand.ge(&drop_p)?.to_dtype(xs.dtype())? * scale)?;
    xs * mask
}

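/// A dropout layer: applies [`dropout`] when `train` is true and is a no-op at inference time.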
#[derive(Clone, Debug)]
pub struct Dropout {
    drop_p: f32,
}

impl Dropout {
    pub fn new(drop_p: f32) -> Dropout {
        Self { drop_p }
    }

    pub fn forward(&self, xs: &Tensor, train: bool) -> Result<Tensor> {
        if train {
            dropout(xs, self.drop_p)
        } else {
            Ok(xs.clone())
        }
    }
}

impl candle::ModuleT for Dropout {
    fn forward_t(&self, xs: &Tensor, train: bool) -> Result<Tensor> {
        self.forward(xs, train)
    }
}

struct SoftmaxLastDim;

impl candle::CustomOp1 for SoftmaxLastDim {
    fn name(&self) -> &'static str {
        "softmax-last-dim"
    }

    fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)> {
        fn softmax<T: candle::WithDType + num_traits::Float>(
            src: &[T],
            layout: &Layout,
        ) -> Result<(CpuStorage, Shape)> {
            let src = match layout.contiguous_offsets() {
                None => candle::bail!("input has to be contiguous"),
                Some((o1, o2)) => &src[o1..o2],
            };
            let el_count = layout.shape().elem_count();
            let dims = layout.shape().dims();
            let dim_m1 = dims[dims.len() - 1];
            let mut dst = vec![T::zero(); el_count];
            src.par_chunks(dim_m1)
                .zip(dst.par_chunks_mut(dim_m1))
                .for_each(|(src, dst)| {
                    let mut max = T::neg_infinity();
                    unsafe { T::vec_reduce_max(src.as_ptr(), &mut max, dim_m1) };
                    for (s, d) in src.iter().zip(dst.iter_mut()) {
                        *d = (*s - max).exp();
                    }
                    let mut sum_exp = T::zero();
                    unsafe { T::vec_reduce_sum(dst.as_ptr(), &mut sum_exp, dim_m1) };
                    for d in dst.iter_mut() {
                        *d /= sum_exp
                    }
                });
            let storage = candle::WithDType::to_cpu_storage_owned(dst);
            Ok((storage, Shape::from_dims(dims)))
        }

        match storage {
            CpuStorage::BF16(slice) => softmax::<half::bf16>(slice, layout),
            CpuStorage::F16(slice) => softmax::<half::f16>(slice, layout),
            CpuStorage::F32(slice) => softmax::<f32>(slice, layout),
            CpuStorage::F64(slice) => softmax::<f64>(slice, layout),
            _ => candle::bail!("unsupported dtype for softmax {:?}", storage),
        }
    }

    #[cfg(feature = "cuda")]
    fn cuda_fwd(
        &self,
        storage: &candle::CudaStorage,
        layout: &Layout,
    ) -> Result<(candle::CudaStorage, Shape)> {
        use candle::cuda_backend::cudarc::driver::{
            CudaSlice, DeviceRepr, LaunchConfig, PushKernelArg,
        };
        use candle::cuda_backend::{kernel_name, kernels, Map1, WrapErr};
        use candle::{CudaDevice, WithDType};

        struct S;
        impl Map1 for S {
            fn f<T: DeviceRepr + WithDType>(
                &self,
                src: &CudaSlice<T>,
                dev: &CudaDevice,
                layout: &Layout,
            ) -> Result<CudaSlice<T>> {
                let src = match layout.contiguous_offsets() {
                    None => candle::bail!("input has to be contiguous"),
                    Some((o1, o2)) => src.slice(o1..o2),
                };
                let el = layout.shape().elem_count();
                let dims = layout.shape().dims();
                let dim_m1 = dims[dims.len() - 1];
                let (n_rows, n_cols) = (el / dim_m1, dim_m1);

                let cfg = LaunchConfig {
                    grid_dim: (n_rows as u32, 1, 1),
                    block_dim: (1, 32, 1),
                    shared_mem_bytes: 0,
                };
                let func = dev.get_or_load_func(&kernel_name::<T>("softmax"), &kernels::REDUCE)?;
                let dst = unsafe { dev.alloc::<T>(el)? };
                let mut builder = func.builder();
                builder.arg(&src);
                builder.arg(&dst);
                candle::builder_arg!(builder, n_cols as i32);
                unsafe { builder.launch(cfg) }.w()?;
                Ok(dst)
            }
        }

        use candle::backend::BackendStorage;
        let dev = storage.device();
        let slice = S.map(&storage.slice, dev, layout)?;
        let dst = candle::cuda_backend::CudaStorage {
            slice,
            device: dev.clone(),
        };
        Ok((dst, layout.shape().clone()))
    }

    #[cfg(feature = "metal")]
    fn metal_fwd(
        &self,
        storage: &candle::MetalStorage,
        layout: &Layout,
    ) -> Result<(candle::MetalStorage, Shape)> {
        use candle::backend::BackendStorage;
        let device = storage.device();
        let command_buffer = device.command_buffer()?;
        let kernels = device.kernels();
        let name = match storage.dtype() {
            DType::F32 => "softmax_f32",
            DType::F16 => "softmax_f16",
            DType::BF16 => "softmax_bf16",
            dtype => candle::bail!("softmax-last-dim is not implemented for {dtype:?}"),
        };

        let n = layout.stride().len();
        if !(layout.is_contiguous() && layout.stride()[n - 1] == 1) {
            candle::bail!("Non contiguous softmax-last-dim is not implemented");
        }

        let last_dim = layout.dims()[layout.shape().rank() - 1];
        let elem_count = layout.shape().elem_count();
        let output = device.new_buffer(elem_count, storage.dtype(), "softmax")?;
        candle_metal_kernels::call_last_softmax(
            device.metal_device(),
            &command_buffer,
            kernels,
            name,
            elem_count,
            last_dim,
            storage.buffer(),
            layout.start_offset() * storage.dtype().size_in_bytes(),
            &output,
        )
        .map_err(candle::Error::wrap)?;
        let newstorage =
            candle::MetalStorage::new(output, device.clone(), elem_count, storage.dtype());
        Ok((newstorage, layout.shape().clone()))
    }
}

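/// Softmax over the last dimension of a contiguous tensor, using the fused `SoftmaxLastDim`
/// custom op above (parallel CPU path plus dedicated CUDA/Metal kernels) instead of composing
/// elementary tensor ops. No backward pass is registered.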
pub fn softmax_last_dim(xs: &Tensor) -> Result<Tensor> {
    xs.apply_op1_no_bwd(&SoftmaxLastDim)
}

#[derive(Debug, Clone)]
struct RmsNorm {
    eps: f32,
}

impl candle::CustomOp2 for RmsNorm {
    fn name(&self) -> &'static str {
        "rms-norm"
    }

    fn cpu_fwd(
        &self,
        s1: &CpuStorage,
        l1: &Layout,
        s2: &CpuStorage,
        l2: &Layout,
    ) -> Result<(CpuStorage, Shape)> {
        use candle::backend::BackendStorage;

        let eps = self.eps;
        fn inner<
            T: candle::WithDType
                + num_traits::Float
                + num_traits::AsPrimitive<f32>
                + num_traits::FromPrimitive,
        >(
            src: &[T],
            layout: &Layout,
            alpha: &[T],
            alpha_layout: &Layout,
            eps: f32,
        ) -> Result<(CpuStorage, Shape)> {
            let src = match layout.contiguous_offsets() {
                None => candle::bail!("input has to be contiguous"),
                Some((o1, o2)) => &src[o1..o2],
            };
            let alpha = match alpha_layout.contiguous_offsets() {
                None => candle::bail!("alpha has to be contiguous"),
                Some((o1, o2)) => &alpha[o1..o2],
            };
            let el_count = layout.shape().elem_count();
            let dims = layout.shape().dims();
            let dim_m1 = dims[dims.len() - 1];
            let mut dst = vec![T::zero(); el_count];
            src.par_chunks(dim_m1)
                .zip(dst.par_chunks_mut(dim_m1))
                .for_each(|(src, dst)| {
                    let sum2 = src
                        .iter()
                        .map(|&v| {
                            let v = v.as_();
                            v * v
                        })
                        .sum::<f32>();
                    let m = (sum2 / dim_m1 as f32 + eps).sqrt();
                    let m = T::from_f32(m).unwrap_or_else(T::nan);
                    for ((d, s), alpha) in dst.iter_mut().zip(src.iter()).zip(alpha) {
                        *d = *s / m * *alpha
                    }
                });
            let storage = candle::WithDType::to_cpu_storage_owned(dst);
            Ok((storage, Shape::from_dims(dims)))
        }

        use CpuStorage as C;
        match (s1, s2) {
            (C::BF16(s1), C::BF16(s2)) => inner::<half::bf16>(s1, l1, s2, l2, eps),
            (C::F16(s1), C::F16(s2)) => inner::<half::f16>(s1, l1, s2, l2, eps),
            (C::F32(s1), C::F32(s2)) => inner::<f32>(s1, l1, s2, l2, eps),
            _ => candle::bail!("unsupported dtype for rmsnorm {:?}", s1.dtype()),
        }
    }

    #[cfg(feature = "cuda")]
    fn cuda_fwd(
        &self,
        s1: &candle::CudaStorage,
        l1: &Layout,
        s2: &candle::CudaStorage,
        l2: &Layout,
    ) -> Result<(candle::CudaStorage, Shape)> {
        use candle::cuda_backend::cudarc::driver::{
            CudaSlice, DeviceRepr, LaunchConfig, PushKernelArg,
        };
        use candle::cuda_backend::{kernel_name, kernels, Map2, WrapErr};
        use candle::{CudaDevice, WithDType};

        struct S {
            eps: f32,
        }
        impl Map2 for S {
            fn f<T: DeviceRepr + WithDType>(
                &self,
                src: &CudaSlice<T>,
                layout: &Layout,
                alpha: &CudaSlice<T>,
                alpha_layout: &Layout,
                dev: &CudaDevice,
            ) -> Result<CudaSlice<T>> {
                let src = match layout.contiguous_offsets() {
                    None => candle::bail!("input has to be contiguous"),
                    Some((o1, o2)) => src.slice(o1..o2),
                };
                let alpha = match alpha_layout.contiguous_offsets() {
                    None => candle::bail!("alpha has to be contiguous"),
                    Some((o1, o2)) => alpha.slice(o1..o2),
                };
                let el = layout.shape().elem_count();
                let dims = layout.shape().dims();
                let dim_m1 = dims[dims.len() - 1];
                let (n_rows, n_cols) = (el / dim_m1, dim_m1);

                let block_size = if n_cols < 1024 { 32 } else { 1024 };
                let cfg = LaunchConfig {
                    grid_dim: (n_rows as u32, 1, 1),
                    block_dim: (block_size, 1, 1),
                    shared_mem_bytes: 0,
                };
                let func = dev.get_or_load_func(&kernel_name::<T>("rmsnorm"), &kernels::REDUCE)?;
                let dst = unsafe { dev.alloc::<T>(el)? };
                let mut builder = func.builder();
                builder.arg(&src);
                builder.arg(&dst);
                builder.arg(&alpha);
                candle::builder_arg!(builder, n_cols as i32, block_size as i32, self.eps);
                unsafe { builder.launch(cfg) }.w()?;
                Ok(dst)
            }
        }

        use candle::backend::BackendStorage;
        let dev = s1.device();
        let slice = S { eps: self.eps }.map(&s1.slice, l1, &s2.slice, l2, dev)?;
        let dst = candle::cuda_backend::CudaStorage {
            slice,
            device: dev.clone(),
        };
        Ok((dst, l1.shape().clone()))
    }

    #[cfg(feature = "metal")]
    fn metal_fwd(
        &self,
        s1: &candle::MetalStorage,
        l1: &Layout,
        s2: &candle::MetalStorage,
        l2: &Layout,
    ) -> Result<(candle::MetalStorage, Shape)> {
        use candle::backend::BackendStorage;
        let device = s1.device();
        let command_buffer = device.command_buffer()?;
        let kernels = device.kernels();
        let name = match (s1.dtype(), s2.dtype()) {
            (DType::F32, DType::F32) => "rmsnorm_f32",
            (DType::F16, DType::F16) => "rmsnorm_f16",
            (DType::BF16, DType::BF16) => "rmsnorm_bf16",
            (dt1, dt2) => candle::bail!("rmsnorm is not implemented for {dt1:?} {dt2:?}"),
        };

        if !(l1.is_contiguous() && l2.is_contiguous()) {
            candle::bail!("Non contiguous rmsnorm is not implemented");
        }

        let last_dim = l1.dims()[l1.shape().rank() - 1];
        let elem_count = l1.shape().elem_count();
        let output = device.new_buffer(elem_count, s1.dtype(), "rmsnorm")?;
        candle_metal_kernels::call_rms_norm(
            device.metal_device(),
            &command_buffer,
            kernels,
            name,
            elem_count,
            last_dim,
            self.eps,
            s1.buffer(),
            l1.start_offset() * s1.dtype().size_in_bytes(),
            s2.buffer(),
            l2.start_offset() * s2.dtype().size_in_bytes(),
            &output,
        )
        .map_err(candle::Error::wrap)?;
        let newstorage = candle::MetalStorage::new(output, device.clone(), elem_count, s1.dtype());
        Ok((newstorage, l1.shape().clone()))
    }
}

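/// Reference RMS normalization built from elementary tensor ops: the input is upcast to f32 for
/// reduced-precision dtypes, divided by `sqrt(mean(x^2) + eps)` over the last dimension, then
/// scaled by `alpha`. Mostly useful for testing against the fused [`rms_norm`] below.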
pub fn rms_norm_slow(x: &Tensor, alpha: &Tensor, eps: f32) -> Result<Tensor> {
    let x_dtype = x.dtype();
    let internal_dtype = match x_dtype {
        DType::F16 | DType::BF16 => DType::F32,
        d => d,
    };
    let hidden_size = x.dim(D::Minus1)?;
    let x = x.to_dtype(internal_dtype)?;
    let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
    let x_normed = x.broadcast_div(&(norm_x + eps as f64)?.sqrt()?)?;
    x_normed.to_dtype(x_dtype)?.broadcast_mul(alpha)
}

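/// Fused RMS normalization over the last dimension using the `RmsNorm` custom op above.
/// `alpha` must be a 1-d tensor matching the last dimension of `xs`.
///
/// Illustrative usage sketch (not a checked doctest):
///
/// ```ignore
/// use candle::{Device, Tensor};
/// let xs = Tensor::new(&[[1f32, 2., 3.], [4., 5., 6.]], &Device::Cpu)?;
/// let alpha = Tensor::ones(3, candle::DType::F32, &Device::Cpu)?;
/// let ys = candle_nn::ops::rms_norm(&xs, &alpha, 1e-5)?;
/// # Ok::<(), candle::Error>(())
/// ```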
pub fn rms_norm(xs: &Tensor, alpha: &Tensor, eps: f32) -> Result<Tensor> {
    let hidden_size_xs = xs.dim(D::Minus1)?;
    let hidden_size_alpha = alpha.dims1()?;
    if hidden_size_xs != hidden_size_alpha {
        candle::bail!(
            "shape mismatch in rms-norm {:?} {:?}",
            xs.shape(),
            alpha.shape()
        )
    }
    xs.apply_op2_no_bwd(alpha, &RmsNorm { eps })
}

#[derive(Debug, Clone)]
struct LayerNorm {
    eps: f32,
}

impl candle::CustomOp3 for LayerNorm {
    fn name(&self) -> &'static str {
        "layer-norm"
    }

    fn cpu_fwd(
        &self,
        s1: &CpuStorage,
        l1: &Layout,
        s2: &CpuStorage,
        l2: &Layout,
        s3: &CpuStorage,
        l3: &Layout,
    ) -> Result<(CpuStorage, Shape)> {
        use candle::backend::BackendStorage;

        let eps = self.eps;
        fn inner<
            T: candle::WithDType
                + num_traits::Float
                + num_traits::AsPrimitive<f32>
                + num_traits::FromPrimitive,
        >(
            src: &[T],
            layout: &Layout,
            alpha: &[T],
            alpha_layout: &Layout,
            beta: &[T],
            beta_layout: &Layout,
            eps: f32,
        ) -> Result<(CpuStorage, Shape)> {
            let src = match layout.contiguous_offsets() {
                None => candle::bail!("input has to be contiguous"),
                Some((o1, o2)) => &src[o1..o2],
            };
            let alpha = match alpha_layout.contiguous_offsets() {
                None => candle::bail!("alpha has to be contiguous"),
                Some((o1, o2)) => &alpha[o1..o2],
            };
            let beta = match beta_layout.contiguous_offsets() {
                None => candle::bail!("beta has to be contiguous"),
                Some((o1, o2)) => &beta[o1..o2],
            };
            let el_count = layout.shape().elem_count();
            let dims = layout.shape().dims();
            let dim_m1 = dims[dims.len() - 1];
            let mut dst = vec![T::zero(); el_count];
            src.par_chunks(dim_m1)
                .zip(dst.par_chunks_mut(dim_m1))
                .for_each(|(src, dst)| {
                    let mut sum = 0f32;
                    let mut sum2 = 0f32;
                    for v in src {
                        let v = v.as_();
                        sum += v;
                        sum2 += v * v;
                    }
                    let mean = sum / dim_m1 as f32;
                    let var = sum2 / dim_m1 as f32 - mean * mean;
                    let inv_std = (var + eps).sqrt().recip();
                    for ((d, s), (alpha, beta)) in
                        dst.iter_mut().zip(src.iter()).zip(alpha.iter().zip(beta))
                    {
                        let alpha = alpha.as_();
                        let beta = beta.as_();
                        let d_ = (s.as_() - mean) * inv_std * alpha + beta;
                        *d = T::from_f32(d_).unwrap_or_else(T::nan);
                    }
                });
            let storage = candle::WithDType::to_cpu_storage_owned(dst);
            Ok((storage, Shape::from_dims(dims)))
        }

        use CpuStorage as C;
        match (s1, s2, s3) {
            (C::BF16(s1), C::BF16(s2), C::BF16(s3)) => {
                inner::<half::bf16>(s1, l1, s2, l2, s3, l3, eps)
            }
            (C::F16(s1), C::F16(s2), C::F16(s3)) => inner::<half::f16>(s1, l1, s2, l2, s3, l3, eps),
            (C::F32(s1), C::F32(s2), C::F32(s3)) => inner::<f32>(s1, l1, s2, l2, s3, l3, eps),
            _ => candle::bail!("unsupported dtype for layernorm {:?}", s1.dtype()),
        }
    }

    #[cfg(feature = "cuda")]
    fn cuda_fwd(
        &self,
        s1: &candle::CudaStorage,
        l1: &Layout,
        s2: &candle::CudaStorage,
        l2: &Layout,
        s3: &candle::CudaStorage,
        l3: &Layout,
    ) -> Result<(candle::CudaStorage, Shape)> {
        use candle::cuda_backend::cudarc::driver::{
            CudaSlice, DeviceRepr, LaunchConfig, PushKernelArg,
        };
        use candle::cuda_backend::{kernel_name, kernels, Map3, WrapErr};
        use candle::{CudaDevice, WithDType};

        struct S {
            eps: f32,
        }
        impl Map3 for S {
            fn f<T: DeviceRepr + WithDType>(
                &self,
                src: &CudaSlice<T>,
                layout: &Layout,
                alpha: &CudaSlice<T>,
                alpha_layout: &Layout,
                beta: &CudaSlice<T>,
                beta_layout: &Layout,
                dev: &CudaDevice,
            ) -> Result<CudaSlice<T>> {
                let src = match layout.contiguous_offsets() {
                    None => candle::bail!("input has to be contiguous"),
                    Some((o1, o2)) => src.slice(o1..o2),
                };
                let alpha = match alpha_layout.contiguous_offsets() {
                    None => candle::bail!("alpha has to be contiguous"),
                    Some((o1, o2)) => alpha.slice(o1..o2),
                };
                let beta = match beta_layout.contiguous_offsets() {
                    None => candle::bail!("beta has to be contiguous"),
                    Some((o1, o2)) => beta.slice(o1..o2),
                };
                let el = layout.shape().elem_count();
                let dims = layout.shape().dims();
                let dim_m1 = dims[dims.len() - 1];
                let (n_rows, n_cols) = (el / dim_m1, dim_m1);

                let block_size = if n_cols < 1024 { 32 } else { 1024 };
                let cfg = LaunchConfig {
                    grid_dim: (n_rows as u32, 1, 1),
                    block_dim: (block_size, 1, 1),
                    shared_mem_bytes: 0,
                };
                let func =
                    dev.get_or_load_func(&kernel_name::<T>("layernorm"), &kernels::REDUCE)?;
                let dst = unsafe { dev.alloc::<T>(el)? };
                let mut builder = func.builder();
                builder.arg(&src);
                builder.arg(&dst);
                builder.arg(&alpha);
                builder.arg(&beta);
                candle::builder_arg!(builder, n_cols as i32, block_size as i32, self.eps);
                unsafe { builder.launch(cfg) }.w()?;
                Ok(dst)
            }
        }

        use candle::backend::BackendStorage;
        let dev = s1.device();
        let slice = S { eps: self.eps }.map(&s1.slice, l1, &s2.slice, l2, &s3.slice, l3, dev)?;
        let dst = candle::cuda_backend::CudaStorage {
            slice,
            device: dev.clone(),
        };
        Ok((dst, l1.shape().clone()))
    }

    #[cfg(feature = "metal")]
    fn metal_fwd(
        &self,
        s1: &candle::MetalStorage,
        l1: &Layout,
        s2: &candle::MetalStorage,
        l2: &Layout,
        s3: &candle::MetalStorage,
        l3: &Layout,
    ) -> Result<(candle::MetalStorage, Shape)> {
        use candle::backend::BackendStorage;
        let device = s1.device();
        let command_buffer = device.command_buffer()?;
        let kernels = device.kernels();
        let name = match (s1.dtype(), s2.dtype(), s3.dtype()) {
            (DType::F32, DType::F32, DType::F32) => "layernorm_f32",
            (DType::F16, DType::F16, DType::F16) => "layernorm_f16",
            (DType::BF16, DType::BF16, DType::BF16) => "layernorm_bf16",
            (dt1, dt2, dt3) => {
                candle::bail!("layernorm is not implemented for {dt1:?} {dt2:?} {dt3:?}")
            }
        };

        if !(l1.is_contiguous() && l2.is_contiguous() && l3.is_contiguous()) {
            candle::bail!("Non contiguous layernorm is not implemented");
        }

        let last_dim = l1.dims()[l1.shape().rank() - 1];
        let elem_count = l1.shape().elem_count();
        let output = device.new_buffer(elem_count, s1.dtype(), "layernorm")?;
        candle_metal_kernels::call_layer_norm(
            device.metal_device(),
            &command_buffer,
            kernels,
            name,
            elem_count,
            last_dim,
            self.eps,
            s1.buffer(),
            l1.start_offset() * s1.dtype().size_in_bytes(),
            s2.buffer(),
            l2.start_offset() * s2.dtype().size_in_bytes(),
            s3.buffer(),
            l3.start_offset() * s3.dtype().size_in_bytes(),
            &output,
        )
        .map_err(candle::Error::wrap)?;
        let newstorage = candle::MetalStorage::new(output, device.clone(), elem_count, s1.dtype());
        Ok((newstorage, l1.shape().clone()))
    }
}

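/// Reference layer normalization built from elementary tensor ops: subtracts the mean over the
/// last dimension, divides by `sqrt(var + eps)`, then applies the `alpha` scale and `beta` shift.
/// Mostly useful for testing against the fused [`layer_norm`] below.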
pub fn layer_norm_slow(x: &Tensor, alpha: &Tensor, beta: &Tensor, eps: f32) -> Result<Tensor> {
    let x_dtype = x.dtype();
    let internal_dtype = match x_dtype {
        DType::F16 | DType::BF16 => DType::F32,
        d => d,
    };
    let hidden_size = x.dim(D::Minus1)?;
    let x = x.to_dtype(internal_dtype)?;
    let x = {
        let mean_x = (x.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
        x.broadcast_sub(&mean_x)?
    };
    let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
    let x_normed = x.broadcast_div(&(norm_x + eps as f64)?.sqrt()?)?;
    x_normed
        .to_dtype(x_dtype)?
        .broadcast_mul(alpha)?
        .broadcast_add(beta)
}

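/// Fused layer normalization over the last dimension using the `LayerNorm` custom op above.
/// `alpha` and `beta` must be 1-d tensors matching the last dimension of `xs`.
///
/// Illustrative usage sketch (not a checked doctest):
///
/// ```ignore
/// use candle::{Device, Tensor};
/// let xs = Tensor::new(&[[1f32, 2., 3.], [4., 5., 6.]], &Device::Cpu)?;
/// let alpha = Tensor::ones(3, candle::DType::F32, &Device::Cpu)?;
/// let beta = Tensor::zeros(3, candle::DType::F32, &Device::Cpu)?;
/// let ys = candle_nn::ops::layer_norm(&xs, &alpha, &beta, 1e-5)?;
/// # Ok::<(), candle::Error>(())
/// ```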
pub fn layer_norm(xs: &Tensor, alpha: &Tensor, beta: &Tensor, eps: f32) -> Result<Tensor> {
    let hidden_size_xs = xs.dim(D::Minus1)?;
    let hidden_size_alpha = alpha.dims1()?;
    let hidden_size_beta = beta.dims1()?;
    if hidden_size_xs != hidden_size_alpha || hidden_size_xs != hidden_size_beta {
        candle::bail!(
            "shape mismatch in layer-norm src: {:?} alpha: {:?} beta: {:?}",
            xs.shape(),
            alpha.shape(),
            beta.shape()
        )
    }
    xs.apply_op3_no_bwd(alpha, beta, &LayerNorm { eps })
}

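/// Rearranges a `(batch, c * r^2, h, w)` tensor into `(batch, c, h * r, w * r)` where `r` is
/// `upscale_factor`, as used by sub-pixel convolutions. The channel dimension must be divisible
/// by `upscale_factor^2`.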
pub fn pixel_shuffle(xs: &Tensor, upscale_factor: usize) -> Result<Tensor> {
    let (b_size, c, h, w) = xs.dims4()?;
    let out_c = c / upscale_factor / upscale_factor;
    xs.reshape((b_size, out_c, upscale_factor, upscale_factor, h, w))?
        .permute((0, 1, 4, 2, 5, 3))?
        .reshape((b_size, out_c, h * upscale_factor, w * upscale_factor))
}

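/// The inverse of [`pixel_shuffle`]: rearranges a `(batch, c, h * r, w * r)` tensor into
/// `(batch, c * r^2, h, w)` where `r` is `downscale_factor`. The spatial dimensions must be
/// divisible by `downscale_factor`.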
pub fn pixel_unshuffle(xs: &Tensor, downscale_factor: usize) -> Result<Tensor> {
    let (b_size, c, h, w) = xs.dims4()?;
    let out_c = c * downscale_factor * downscale_factor;
    xs.reshape((
        b_size,
        c,
        h / downscale_factor,
        downscale_factor,
        w / downscale_factor,
        downscale_factor,
    ))?
    .permute((0, 1, 3, 5, 2, 4))?
    .reshape((b_size, out_c, h / downscale_factor, w / downscale_factor))
}

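/// Pads the last two dimensions by replicating the edge rows and columns. Only `pad` values of
/// 0 (no-op) and 1 are currently supported.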
pub fn replication_pad2d(xs: &Tensor, pad: usize) -> Result<Tensor> {
    match pad {
        0 => Ok(xs.clone()),
        1 => {
            let (_b_size, _c, h, w) = xs.dims4()?;
            let (first, last) = (xs.narrow(3, 0, 1)?, xs.narrow(3, w - 1, 1)?);
            let xs = Tensor::cat(&[&first, xs, &last], 3)?;
            let (first, last) = (xs.narrow(2, 0, 1)?, xs.narrow(2, h - 1, 1)?);
            Tensor::cat(&[&first, &xs, &last], 2)
        }
        n => candle::bail!("replication-pad with a size of {n} is not supported"),
    }
}

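/// A module that returns its input unchanged, handy as a placeholder in module lists or
/// configurable architectures.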
#[derive(Clone, Debug)]
pub struct Identity;

impl Identity {
    pub fn new() -> Identity {
        Self
    }
}

impl Default for Identity {
    fn default() -> Self {
        Self
    }
}

impl Module for Identity {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        Ok(xs.clone())
    }
}

#[allow(dead_code)]
struct Sdpa {
    scale: f32,
    softcapping: f32,
}

impl candle::CustomOp3 for Sdpa {
    fn name(&self) -> &'static str {
        "metal-sdpa"
    }

    fn cpu_fwd(
        &self,
        _s1: &CpuStorage,
        _l1: &Layout,
        _s2: &CpuStorage,
        _l2: &Layout,
        _s3: &CpuStorage,
        _l3: &Layout,
    ) -> Result<(CpuStorage, Shape)> {
        candle::bail!("SDPA has no cpu impl")
    }

    #[cfg(feature = "metal")]
    fn metal_fwd(
        &self,
        q: &candle::MetalStorage,
        q_l: &Layout,
        k: &candle::MetalStorage,
        k_l: &Layout,
        v: &candle::MetalStorage,
        v_l: &Layout,
    ) -> Result<(candle::MetalStorage, Shape)> {
        use candle::backend::BackendStorage;
        use candle_metal_kernels::SdpaDType;

        let device = q.device();

        let out_dims = vec![q_l.dim(0)?, q_l.dim(1)?, q_l.dim(2)?, v_l.dim(3)?];
        let elem_count: usize = out_dims.iter().product();

        let output = device.new_buffer(elem_count, q.dtype(), "sdpa_o")?;

        if q_l.dim(D::Minus1)? != k_l.dim(D::Minus1)? {
            candle::bail!("`q` and `k` last dims must match");
        }

        if v_l.dim(D::Minus(3))? != k_l.dim(D::Minus(3))? {
            candle::bail!("`k` and `v` head dims must match");
        }

        if q_l.dim(D::Minus(3))? % k_l.dim(D::Minus(3))? != 0 {
            candle::bail!("query `n_heads` must be a multiple of `n_kv_heads`");
        }

        let k_head = k_l.dim(D::Minus1)?;
        let q_head = q_l.dim(D::Minus1)?;
        let q_seq = q_l.dim(2)?;

        let mut implementation_supports_use_case = q_head == k_head;
        let supported_head_dim =
            q_head == 32 || q_head == 64 || q_head == 96 || q_head == 128 || q_head == 256;

        const SDPA_FULL_THRESHOLD: usize = 2;

        let supports_sdpa_full =
            q_seq >= SDPA_FULL_THRESHOLD && supported_head_dim && q_head == k_head;
        let supports_sdpa_vector = q_seq == 1 && supported_head_dim;

        implementation_supports_use_case &= supports_sdpa_full || supports_sdpa_vector;

        if !supported_head_dim {
            candle::bail!(
                "Metal SDPA does not support q head dim {q_head}: q dims {:?}, k dims {:?}, v dims {:?}.",
                q_l.dims(),
                k_l.dims(),
                v_l.dims()
            );
        }
        if !implementation_supports_use_case {
            candle::bail!(
                "Metal SDPA does not support q dims {:?}, k dims {:?}, v dims {:?}.",
                q_l.dims(),
                k_l.dims(),
                v_l.dims()
            );
        }

        for t in [k.dtype(), v.dtype()] {
            if q.dtype() != t {
                candle::bail!("all q, k, v dtypes must match.");
            }
        }

        let itype = match q.dtype() {
            DType::BF16 => SdpaDType::BF16,
            DType::F16 => SdpaDType::F16,
            DType::F32 => SdpaDType::F32,
            other => candle::bail!("unsupported sdpa type {other:?}"),
        };

        let command_buffer = q.device().command_buffer()?;
        if supports_sdpa_vector {
            const TWO_PASS_K_THRESHOLD: usize = 1024;
            if k_l.dim(2)? >= TWO_PASS_K_THRESHOLD {
                let mut intermediate_shape = [
                    &out_dims[0..out_dims.len() - 2],
                    &[candle_metal_kernels::SDPA_2PASS_BLOCKS],
                    &[out_dims[out_dims.len() - 1]],
                ]
                .concat();
                let intermediate = device.new_buffer(
                    intermediate_shape.iter().product::<usize>(),
                    DType::F32,
                    "sdpa_2pass_intermediate",
                )?;
                let _ = intermediate_shape.pop().unwrap();
                let sums = device.new_buffer(
                    intermediate_shape.iter().product::<usize>(),
                    DType::F32,
                    "sdpa_2pass_sums",
                )?;
                let maxs = device.new_buffer(
                    intermediate_shape.iter().product::<usize>(),
                    DType::F32,
                    "sdpa_2pass_maxs",
                )?;

                command_buffer.set_label("vector_attention");
                candle_metal_kernels::call_sdpa_vector_2pass(
                    q.device().device(),
                    &command_buffer,
                    q.device().kernels(),
                    q_l.start_offset(),
                    q_l.dims(),
                    q.buffer(),
                    k_l.start_offset(),
                    k_l.dims(),
                    k_l.stride(),
                    k.buffer(),
                    v_l.start_offset(),
                    v_l.stride(),
                    v.buffer(),
                    &output,
                    &intermediate,
                    &sums,
                    &maxs,
                    self.scale,
                    self.softcapping,
                    itype,
                )
                .map_err(candle::Error::wrap)?;
            } else {
                command_buffer.set_label("vector_attention");
                candle_metal_kernels::call_sdpa_vector(
                    q.device().device(),
                    &command_buffer,
                    q.device().kernels(),
                    q_l.start_offset(),
                    q_l.dims(),
                    q.buffer(),
                    k_l.start_offset(),
                    k_l.dims(),
                    k_l.stride(),
                    k.buffer(),
                    v_l.start_offset(),
                    v_l.stride(),
                    v.buffer(),
                    &output,
                    self.scale,
                    self.softcapping,
                    itype,
                )
                .map_err(candle::Error::wrap)?;
            }
        } else if supports_sdpa_full {
            if q_l.dim(2)? != k_l.dim(2)? {
                candle::bail!(
                    "query and key sequence length must be equal if using full metal sdpa"
                )
            }

            command_buffer.set_label("full_attention");
            candle_metal_kernels::call_sdpa_full(
                q.device().device(),
                &command_buffer,
                q.device().kernels(),
                q_l.start_offset(),
                q_l.dims(),
                q.buffer(),
                k_l.start_offset(),
                k.buffer(),
                v_l.start_offset(),
                v.buffer(),
                &output,
                self.scale,
                self.softcapping,
                itype,
            )
            .map_err(candle::Error::wrap)?;
        } else {
            candle::bail!("must be vector or full sdpa kernel");
        }

        let newstorage = candle::MetalStorage::new(output, device.clone(), elem_count, q.dtype());
        Ok((newstorage, Shape::from_dims(&out_dims)))
    }
}

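/// Scaled dot-product attention backed by the Metal `Sdpa` custom op above; there is no CPU or
/// CUDA implementation. Attention scores are multiplied by `scale` and optionally soft-capped
/// via `softcapping` (a value of 1.0 is typically used to disable capping). `q`, `k` and `v` are
/// expected in `(batch, heads, seq, head_dim)` layout with matching dtypes, and only head dims
/// of 32/64/96/128/256 are supported.
///
/// Illustrative usage sketch (not a checked doctest, Metal only); the shapes below are an
/// assumption based on the checks performed in the op:
///
/// ```ignore
/// use candle::{Device, Tensor};
/// let dev = Device::new_metal(0)?;
/// let q = Tensor::randn(0f32, 1., (1, 8, 1, 64), &dev)?;
/// let k = Tensor::randn(0f32, 1., (1, 8, 128, 64), &dev)?;
/// let v = Tensor::randn(0f32, 1., (1, 8, 128, 64), &dev)?;
/// let out = candle_nn::ops::sdpa(&q, &k, &v, 1. / (64f32).sqrt(), 1.0)?;
/// # Ok::<(), candle::Error>(())
/// ```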
pub fn sdpa(q: &Tensor, k: &Tensor, v: &Tensor, scale: f32, softcapping: f32) -> Result<Tensor> {
    q.apply_op3_no_bwd(k, v, &Sdpa { scale, softcapping })
}