diffusion_rs_common/nn/
rotary_emb.rs

1//! Rotary Embeddings
2//!
3use crate::core::{CpuStorage, Layout, Result, Shape, Tensor, D};
4use rayon::prelude::*;
5
6/// Interleaved variant of rotary embeddings.
7/// The x0 and x1 value are interleaved on the n_embd (= head_dim) dimension.
8/// The resulting y0 and y1 are also interleaved with:
9///   y0 = x0*cos - x1*sin
10///   y1 = x0*sin + x1*cos
11#[derive(Debug, Clone)]
12struct RotaryEmbI;
13
14impl crate::core::CustomOp3 for RotaryEmbI {
15    fn name(&self) -> &'static str {
16        "rotary-emb-int"
17    }
18
19    fn cpu_fwd(
20        &self,
21        s1: &CpuStorage,
22        l1: &Layout,
23        s2: &CpuStorage,
24        l2: &Layout,
25        s3: &CpuStorage,
26        l3: &Layout,
27    ) -> Result<(CpuStorage, Shape)> {
28        fn inner<T: crate::core::WithDType + num_traits::Float>(
29            src: &[T],
30            l_src: &Layout,
31            cos: &[T],
32            l_cos: &Layout,
33            sin: &[T],
34            l_sin: &Layout,
35        ) -> Result<(CpuStorage, Shape)> {
36            let src = match l_src.contiguous_offsets() {
37                None => crate::bail!("input src has to be contiguous"),
38                Some((o1, o2)) => &src[o1..o2],
39            };
40            let cos = match l_cos.contiguous_offsets() {
41                None => crate::bail!("input cos has to be contiguous"),
42                Some((o1, o2)) => &cos[o1..o2],
43            };
44            let sin = match l_sin.contiguous_offsets() {
45                None => crate::bail!("input sin has to be contiguous"),
46                Some((o1, o2)) => &sin[o1..o2],
47            };
48            let (b, h, t, d) = l_src.shape().dims4()?;
49            let el_count = b * h * t * d;
50            let mut dst = vec![T::zero(); el_count];
51            src.par_chunks(t * d)
52                .zip(dst.par_chunks_mut(t * d))
53                .for_each(|(src, dst)| {
54                    for i_over_2 in 0..t * d / 2 {
55                        let i = 2 * i_over_2;
56                        dst[i] = src[i] * cos[i_over_2] - src[i + 1] * sin[i_over_2];
57                        dst[i + 1] = src[i] * sin[i_over_2] + src[i + 1] * cos[i_over_2];
58                    }
59                });
60            let storage = crate::core::WithDType::to_cpu_storage_owned(dst);
61            Ok((storage, (b, h, t, d).into()))
62        }
63
64        use crate::core::backend::BackendStorage;
65        use CpuStorage::{BF16, F16, F32, F64};
66        match (s1, s2, s3) {
67            (BF16(s1), BF16(s2), BF16(s3)) => inner(s1, l1, s2, l2, s3, l3),
68            (F16(s1), F16(s2), F16(s3)) => inner(s1, l1, s2, l2, s3, l3),
69            (F32(s1), F32(s2), F32(s3)) => inner(s1, l1, s2, l2, s3, l3),
70            (F64(s1), F64(s2), F64(s3)) => inner(s1, l1, s2, l2, s3, l3),
71            _ => crate::bail!(
72                "unsupported dtype for rope {:?} {:?} {:?}",
73                s1.dtype(),
74                s2.dtype(),
75                s3.dtype()
76            ),
77        }
78    }
79
80    #[cfg(feature = "cuda")]
81    fn cuda_fwd(
82        &self,
83        s1: &crate::core::CudaStorage,
84        l1: &Layout,
85        s2: &crate::core::CudaStorage,
86        l2: &Layout,
87        s3: &crate::core::CudaStorage,
88        l3: &Layout,
89    ) -> Result<(crate::core::CudaStorage, Shape)> {
90        use crate::core::cuda_backend::cudarc::driver::{
91            CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig,
92        };
93        use crate::core::cuda_backend::{kernel_name, kernels, WrapErr};
94        use crate::core::{CudaDevice, WithDType};
95
96        fn inner<T: DeviceRepr + WithDType>(
97            src: &CudaSlice<T>,
98            l_src: &Layout,
99            cos: &CudaSlice<T>,
100            l_cos: &Layout,
101            sin: &CudaSlice<T>,
102            l_sin: &Layout,
103            dev: &CudaDevice,
104        ) -> Result<CudaSlice<T>> {
105            let src = match l_src.contiguous_offsets() {
106                None => crate::bail!("src input has to be contiguous"),
107                Some((o1, o2)) => src.slice(o1..o2),
108            };
109            let cos = match l_cos.contiguous_offsets() {
110                None => crate::bail!("cos input has to be contiguous"),
111                Some((o1, o2)) => cos.slice(o1..o2),
112            };
113            let sin = match l_sin.contiguous_offsets() {
114                None => crate::bail!("sin input has to be contiguous"),
115                Some((o1, o2)) => sin.slice(o1..o2),
116            };
117            let (b, h, t, d) = l_src.shape().dims4()?;
118            let el = b * h * t * d;
119            let cfg = LaunchConfig::for_num_elems((el / 2) as u32);
120            let func = dev.get_or_load_func(&kernel_name::<T>("rope_i"), kernels::REDUCE)?;
121            // SAFETY: Set later by running the kernel.
122            let dst = unsafe { dev.alloc::<T>(el) }.w()?;
123            let params = (&src, &cos, &sin, &dst, (b * h) as u32, (t * d) as u32);
124            // SAFETY: ffi.
125            unsafe { func.launch(cfg, params) }.w()?;
126            Ok(dst)
127        }
128
129        use crate::core::backend::BackendStorage;
130        use crate::core::cuda_backend::CudaStorageSlice::{BF16, F16, F32, F64};
131        let dev = s1.device();
132        let slice = match (&s1.slice, &s2.slice, &s3.slice) {
133            (BF16(s1), BF16(s2), BF16(s3)) => BF16(inner(s1, l1, s2, l2, s3, l3, dev)?),
134            (F16(s1), F16(s2), F16(s3)) => F16(inner(s1, l1, s2, l2, s3, l3, dev)?),
135            (F32(s1), F32(s2), F32(s3)) => F32(inner(s1, l1, s2, l2, s3, l3, dev)?),
136            (F64(s1), F64(s2), F64(s3)) => F64(inner(s1, l1, s2, l2, s3, l3, dev)?),
137            _ => crate::bail!(
138                "unsupported dtype for rope {:?} {:?} {:?}",
139                s1.dtype(),
140                s2.dtype(),
141                s3.dtype()
142            ),
143        };
144        let dst = crate::core::cuda_backend::CudaStorage {
145            slice,
146            device: dev.clone(),
147        };
148        Ok((dst, l1.shape().clone()))
149    }
150
151    #[cfg(feature = "metal")]
152    fn metal_fwd(
153        &self,
154        src: &crate::core::MetalStorage,
155        l_src: &Layout,
156        cos: &crate::core::MetalStorage,
157        l_cos: &Layout,
158        sin: &crate::core::MetalStorage,
159        l_sin: &Layout,
160    ) -> Result<(crate::core::MetalStorage, Shape)> {
161        use crate::core::backend::BackendStorage;
162        let device = src.device();
163        let command_buffer = device.command_buffer()?;
164        let kernels = device.kernels();
165        if cos.dtype() != src.dtype() || sin.dtype() != src.dtype() {
166            crate::bail!(
167                "dtype mismatch in rope-i {:?} {:?} {:?}",
168                src.dtype(),
169                cos.dtype(),
170                sin.dtype()
171            )
172        }
173        let name = match src.dtype() {
174            crate::core::DType::F32 => "rope_i_f32",
175            crate::core::DType::F16 => "rope_i_f16",
176            crate::core::DType::BF16 => "rope_i_bf16",
177            dtype => crate::bail!("rope-i is not implemented for {dtype:?}"),
178        };
179        let (b, h, t, d) = l_src.shape().dims4()?;
180        let el = b * h * t * d;
181        let output = device.new_buffer(el, src.dtype(), "rope-i")?;
182        crate::metal_kernels::call_rope_i(
183            device.metal_device(),
184            &command_buffer,
185            kernels,
186            name,
187            b * h,
188            t * d,
189            src.buffer(),
190            l_src.start_offset() * src.dtype().size_in_bytes(),
191            cos.buffer(),
192            l_cos.start_offset() * cos.dtype().size_in_bytes(),
193            sin.buffer(),
194            l_sin.start_offset() * sin.dtype().size_in_bytes(),
195            &output,
196        )
197        .map_err(crate::core::Error::wrap)?;
198        let out = crate::core::MetalStorage::new(output, device.clone(), el, src.dtype());
199        Ok((out, l_src.shape().clone()))
200    }
201}
202
203pub fn rope_i(xs: &Tensor, cos: &Tensor, sin: &Tensor) -> Result<Tensor> {
204    let (_b_sz, _n_head, seq_len, n_embd) = xs.dims4()?;
205    let (cos_seq_len, cos_n_embd) = cos.dims2()?;
206    let (sin_seq_len, sin_n_embd) = cos.dims2()?;
207    if cos_n_embd * 2 != n_embd
208        || sin_n_embd * 2 != n_embd
209        || seq_len > cos_seq_len
210        || seq_len > sin_seq_len
211    {
212        crate::bail!(
213            "inconsistent last dim size in rope {:?} {:?} {:?}",
214            xs.shape(),
215            cos.shape(),
216            sin.shape()
217        )
218    }
219    if !xs.is_contiguous() {
220        crate::bail!("xs has to be contiguous in rope")
221    }
222    if !cos.is_contiguous() {
223        crate::bail!("cos has to be contiguous in rope")
224    }
225    if !sin.is_contiguous() {
226        crate::bail!("sin has to be contiguous in rope")
227    }
228    xs.apply_op3_no_bwd(cos, sin, &RotaryEmbI)
229}
230
231pub fn rope_i_slow(x: &Tensor, cos: &Tensor, sin: &Tensor) -> Result<Tensor> {
232    let (b_sz, n_head, seq_len, n_embd) = x.dims4()?;
233    let cos = cos
234        .narrow(0, 0, seq_len)?
235        .reshape((seq_len, n_embd / 2, 1))?;
236    let sin = sin
237        .narrow(0, 0, seq_len)?
238        .reshape((seq_len, n_embd / 2, 1))?;
239    let cos = cos.broadcast_as((b_sz, 1, seq_len, n_embd / 2, 1))?;
240    let sin = sin.broadcast_as((b_sz, 1, seq_len, n_embd / 2, 1))?;
241    let x = x.reshape((b_sz, n_head, seq_len, n_embd / 2, 2))?;
242    let x0 = x.narrow(D::Minus1, 0, 1)?;
243    let x1 = x.narrow(D::Minus1, 1, 1)?;
244    let y0 = (x0.broadcast_mul(&cos)? - x1.broadcast_mul(&sin)?)?;
245    let y1 = (x0.broadcast_mul(&sin)? + x1.broadcast_mul(&cos)?)?;
246    let rope = Tensor::cat(&[y0, y1], D::Minus1)?;
247    let rope = rope.flatten_from(D::Minus2)?;
248    Ok(rope)
249}
250
251/// Contiguous variant of rope embeddings.
252#[derive(Debug, Clone)]
253struct RotaryEmb;
254
255impl crate::core::CustomOp3 for RotaryEmb {
256    fn name(&self) -> &'static str {
257        "rotary-emb"
258    }
259
260    fn cpu_fwd(
261        &self,
262        s1: &CpuStorage,
263        l1: &Layout,
264        s2: &CpuStorage,
265        l2: &Layout,
266        s3: &CpuStorage,
267        l3: &Layout,
268    ) -> Result<(CpuStorage, Shape)> {
269        fn inner<T: crate::core::WithDType + num_traits::Float>(
270            src: &[T],
271            l_src: &Layout,
272            cos: &[T],
273            l_cos: &Layout,
274            sin: &[T],
275            l_sin: &Layout,
276        ) -> Result<(CpuStorage, Shape)> {
277            let src = match l_src.contiguous_offsets() {
278                None => crate::bail!("input src has to be contiguous"),
279                Some((o1, o2)) => &src[o1..o2],
280            };
281            let cos = match l_cos.contiguous_offsets() {
282                None => crate::bail!("input cos has to be contiguous"),
283                Some((o1, o2)) => &cos[o1..o2],
284            };
285            let sin = match l_sin.contiguous_offsets() {
286                None => crate::bail!("input sin has to be contiguous"),
287                Some((o1, o2)) => &sin[o1..o2],
288            };
289            let (b, h, t, d) = l_src.shape().dims4()?;
290            let el_count = b * h * t * d;
291            let mut dst = vec![T::zero(); el_count];
292            src.par_chunks(t * d)
293                .zip(dst.par_chunks_mut(t * d))
294                .for_each(|(src, dst)| {
295                    for i_t in 0..t {
296                        for i_d in 0..d / 2 {
297                            let i1 = i_t * d + i_d;
298                            let i2 = i1 + d / 2;
299                            let i_cs = i_t * (d / 2) + i_d;
300                            dst[i1] = src[i1] * cos[i_cs] - src[i2] * sin[i_cs];
301                            dst[i2] = src[i1] * sin[i_cs] + src[i2] * cos[i_cs];
302                        }
303                    }
304                });
305            let storage = crate::core::WithDType::to_cpu_storage_owned(dst);
306            Ok((storage, (b, h, t, d).into()))
307        }
308
309        use crate::core::backend::BackendStorage;
310        use CpuStorage::{BF16, F16, F32, F64};
311        match (s1, s2, s3) {
312            (BF16(s1), BF16(s2), BF16(s3)) => inner(s1, l1, s2, l2, s3, l3),
313            (F16(s1), F16(s2), F16(s3)) => inner(s1, l1, s2, l2, s3, l3),
314            (F32(s1), F32(s2), F32(s3)) => inner(s1, l1, s2, l2, s3, l3),
315            (F64(s1), F64(s2), F64(s3)) => inner(s1, l1, s2, l2, s3, l3),
316            _ => crate::bail!(
317                "unsupported dtype for rope {:?} {:?} {:?}",
318                s1.dtype(),
319                s2.dtype(),
320                s3.dtype()
321            ),
322        }
323    }
324
325    #[cfg(feature = "cuda")]
326    fn cuda_fwd(
327        &self,
328        s1: &crate::core::CudaStorage,
329        l1: &Layout,
330        s2: &crate::core::CudaStorage,
331        l2: &Layout,
332        s3: &crate::core::CudaStorage,
333        l3: &Layout,
334    ) -> Result<(crate::core::CudaStorage, Shape)> {
335        use crate::core::cuda_backend::cudarc::driver::{
336            CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig,
337        };
338        use crate::core::cuda_backend::{kernel_name, kernels, WrapErr};
339        use crate::core::{CudaDevice, WithDType};
340
341        fn inner<T: DeviceRepr + WithDType>(
342            src: &CudaSlice<T>,
343            l_src: &Layout,
344            cos: &CudaSlice<T>,
345            l_cos: &Layout,
346            sin: &CudaSlice<T>,
347            l_sin: &Layout,
348            dev: &CudaDevice,
349        ) -> Result<CudaSlice<T>> {
350            let src = match l_src.contiguous_offsets() {
351                None => crate::bail!("src input has to be contiguous"),
352                Some((o1, o2)) => src.slice(o1..o2),
353            };
354            let cos = match l_cos.contiguous_offsets() {
355                None => crate::bail!("cos input has to be contiguous"),
356                Some((o1, o2)) => cos.slice(o1..o2),
357            };
358            let sin = match l_sin.contiguous_offsets() {
359                None => crate::bail!("sin input has to be contiguous"),
360                Some((o1, o2)) => sin.slice(o1..o2),
361            };
362            let (b, h, t, d) = l_src.shape().dims4()?;
363            let el = b * h * t * d;
364            let cfg = LaunchConfig::for_num_elems((el / 2) as u32);
365            let func = dev.get_or_load_func(&kernel_name::<T>("rope"), kernels::REDUCE)?;
366            // SAFETY: Set later by running the kernel.
367            let dst = unsafe { dev.alloc::<T>(el) }.w()?;
368            let params = (
369                &src,
370                &cos,
371                &sin,
372                &dst,
373                (b * h) as u32,
374                (t * d) as u32,
375                d as u32,
376            );
377            // SAFETY: ffi.
378            unsafe { func.launch(cfg, params) }.w()?;
379            Ok(dst)
380        }
381
382        use crate::core::backend::BackendStorage;
383        use crate::core::cuda_backend::CudaStorageSlice::{BF16, F16, F32, F64};
384        let dev = s1.device();
385        let slice = match (&s1.slice, &s2.slice, &s3.slice) {
386            (BF16(s1), BF16(s2), BF16(s3)) => BF16(inner(s1, l1, s2, l2, s3, l3, dev)?),
387            (F16(s1), F16(s2), F16(s3)) => F16(inner(s1, l1, s2, l2, s3, l3, dev)?),
388            (F32(s1), F32(s2), F32(s3)) => F32(inner(s1, l1, s2, l2, s3, l3, dev)?),
389            (F64(s1), F64(s2), F64(s3)) => F64(inner(s1, l1, s2, l2, s3, l3, dev)?),
390            _ => crate::bail!(
391                "unsupported dtype for rope {:?} {:?} {:?}",
392                s1.dtype(),
393                s2.dtype(),
394                s3.dtype()
395            ),
396        };
397        let dst = crate::core::cuda_backend::CudaStorage {
398            slice,
399            device: dev.clone(),
400        };
401        Ok((dst, l1.shape().clone()))
402    }
403
404    #[cfg(feature = "metal")]
405    fn metal_fwd(
406        &self,
407        src: &crate::core::MetalStorage,
408        l_src: &Layout,
409        cos: &crate::core::MetalStorage,
410        l_cos: &Layout,
411        sin: &crate::core::MetalStorage,
412        l_sin: &Layout,
413    ) -> Result<(crate::core::MetalStorage, Shape)> {
414        use crate::core::backend::BackendStorage;
415        let device = src.device();
416        let command_buffer = device.command_buffer()?;
417        let kernels = device.kernels();
418        if cos.dtype() != src.dtype() || sin.dtype() != src.dtype() {
419            crate::bail!(
420                "dtype mismatch in rope {:?} {:?} {:?}",
421                src.dtype(),
422                cos.dtype(),
423                sin.dtype()
424            )
425        }
426        let name = match src.dtype() {
427            crate::core::DType::F32 => "rope_f32",
428            crate::core::DType::F16 => "rope_f16",
429            crate::core::DType::BF16 => "rope_bf16",
430            dtype => crate::bail!("rope is not implemented for {dtype:?}"),
431        };
432        let (b, h, t, d) = l_src.shape().dims4()?;
433        let el = b * h * t * d;
434        let output = device.new_buffer(el, src.dtype(), "rope-i")?;
435        crate::metal_kernels::call_rope(
436            device.metal_device(),
437            &command_buffer,
438            kernels,
439            name,
440            b * h,
441            t * d,
442            d,
443            src.buffer(),
444            l_src.start_offset() * src.dtype().size_in_bytes(),
445            cos.buffer(),
446            l_cos.start_offset() * cos.dtype().size_in_bytes(),
447            sin.buffer(),
448            l_sin.start_offset() * sin.dtype().size_in_bytes(),
449            &output,
450        )
451        .map_err(crate::core::Error::wrap)?;
452        let out = crate::core::MetalStorage::new(output, device.clone(), el, src.dtype());
453        Ok((out, l_src.shape().clone()))
454    }
455}
456
457pub fn rope(xs: &Tensor, cos: &Tensor, sin: &Tensor) -> Result<Tensor> {
458    let (_b_sz, _n_head, seq_len, n_embd) = xs.dims4()?;
459    let (cos_seq_len, cos_n_embd) = cos.dims2()?;
460    let (sin_seq_len, sin_n_embd) = sin.dims2()?;
461    if cos_n_embd * 2 != n_embd
462        || sin_n_embd * 2 != n_embd
463        || seq_len > cos_seq_len
464        || seq_len > sin_seq_len
465    {
466        crate::bail!(
467            "inconsistent last dim size in rope {:?} {:?} {:?}",
468            xs.shape(),
469            cos.shape(),
470            sin.shape()
471        )
472    }
473    if !xs.is_contiguous() {
474        crate::bail!("xs has to be contiguous in rope")
475    }
476    if !cos.is_contiguous() {
477        crate::bail!("cos has to be contiguous in rope")
478    }
479    if !sin.is_contiguous() {
480        crate::bail!("sin has to be contiguous in rope")
481    }
482    xs.apply_op3_no_bwd(cos, sin, &RotaryEmb)
483}
484
485fn rotate_half(xs: &Tensor) -> Result<Tensor> {
486    let last_dim = xs.dim(D::Minus1)?;
487    let xs1 = xs.narrow(D::Minus1, 0, last_dim / 2)?;
488    let xs2 = xs.narrow(D::Minus1, last_dim / 2, last_dim - last_dim / 2)?;
489    Tensor::cat(&[&xs2.neg()?, &xs1], D::Minus1)
490}
491
492pub fn rope_slow(x: &Tensor, cos: &Tensor, sin: &Tensor) -> Result<Tensor> {
493    let (_b_sz, _h, seq_len, _n_embd) = x.dims4()?;
494    let cos = Tensor::cat(&[cos, cos], D::Minus1)?;
495    let sin = Tensor::cat(&[sin, sin], D::Minus1)?;
496    let cos = cos.narrow(0, 0, seq_len)?;
497    let sin = sin.narrow(0, 0, seq_len)?;
498    let cos = cos.unsqueeze(0)?.unsqueeze(0)?;
499    let sin = sin.unsqueeze(0)?.unsqueeze(0)?;
500    x.broadcast_mul(&cos)? + rotate_half(x)?.broadcast_mul(&sin)?
501}
502
503/// T (seqlen)/H (num-heads)/D (head-dim) contiguous variant of rope embeddings.
504#[derive(Debug, Clone)]
505struct RotaryEmbThd;
506
507impl crate::core::CustomOp3 for RotaryEmbThd {
508    fn name(&self) -> &'static str {
509        "rotary-emb"
510    }
511
512    fn cpu_fwd(
513        &self,
514        s1: &CpuStorage,
515        l1: &Layout,
516        s2: &CpuStorage,
517        l2: &Layout,
518        s3: &CpuStorage,
519        l3: &Layout,
520    ) -> Result<(CpuStorage, Shape)> {
521        fn inner<T: crate::core::WithDType + num_traits::Float>(
522            src: &[T],
523            l_src: &Layout,
524            cos: &[T],
525            l_cos: &Layout,
526            sin: &[T],
527            l_sin: &Layout,
528        ) -> Result<(CpuStorage, Shape)> {
529            let src = match l_src.contiguous_offsets() {
530                None => crate::bail!("input src has to be contiguous"),
531                Some((o1, o2)) => &src[o1..o2],
532            };
533            let cos = match l_cos.contiguous_offsets() {
534                None => crate::bail!("input cos has to be contiguous"),
535                Some((o1, o2)) => &cos[o1..o2],
536            };
537            let sin = match l_sin.contiguous_offsets() {
538                None => crate::bail!("input sin has to be contiguous"),
539                Some((o1, o2)) => &sin[o1..o2],
540            };
541            let (b, t, h, d) = l_src.shape().dims4()?;
542            let el_count = b * h * t * d;
543            let mut dst = vec![T::zero(); el_count];
544            src.par_chunks(t * h * d)
545                .zip(dst.par_chunks_mut(t * h * d))
546                .for_each(|(src, dst)| {
547                    for i_t in 0..t {
548                        for i_d in 0..d / 2 {
549                            let i_cs = i_t * (d / 2) + i_d;
550                            for i_h in 0..h {
551                                let i1 = i_t * h * d + i_h * d + i_d;
552                                let i2 = i1 + d / 2;
553                                dst[i1] = src[i1] * cos[i_cs] - src[i2] * sin[i_cs];
554                                dst[i2] = src[i1] * sin[i_cs] + src[i2] * cos[i_cs];
555                            }
556                        }
557                    }
558                });
559            let storage = crate::core::WithDType::to_cpu_storage_owned(dst);
560            Ok((storage, (b, t, h, d).into()))
561        }
562
563        use crate::core::backend::BackendStorage;
564        use CpuStorage::{BF16, F16, F32, F64};
565        match (s1, s2, s3) {
566            (BF16(s1), BF16(s2), BF16(s3)) => inner(s1, l1, s2, l2, s3, l3),
567            (F16(s1), F16(s2), F16(s3)) => inner(s1, l1, s2, l2, s3, l3),
568            (F32(s1), F32(s2), F32(s3)) => inner(s1, l1, s2, l2, s3, l3),
569            (F64(s1), F64(s2), F64(s3)) => inner(s1, l1, s2, l2, s3, l3),
570            _ => crate::bail!(
571                "unsupported dtype for rope {:?} {:?} {:?}",
572                s1.dtype(),
573                s2.dtype(),
574                s3.dtype()
575            ),
576        }
577    }
578
579    #[cfg(feature = "cuda")]
580    fn cuda_fwd(
581        &self,
582        s1: &crate::core::CudaStorage,
583        l1: &Layout,
584        s2: &crate::core::CudaStorage,
585        l2: &Layout,
586        s3: &crate::core::CudaStorage,
587        l3: &Layout,
588    ) -> Result<(crate::core::CudaStorage, Shape)> {
589        use crate::core::cuda_backend::cudarc::driver::{
590            CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig,
591        };
592        use crate::core::cuda_backend::{kernel_name, kernels, WrapErr};
593        use crate::core::{CudaDevice, WithDType};
594
595        fn inner<T: DeviceRepr + WithDType>(
596            src: &CudaSlice<T>,
597            l_src: &Layout,
598            cos: &CudaSlice<T>,
599            l_cos: &Layout,
600            sin: &CudaSlice<T>,
601            l_sin: &Layout,
602            dev: &CudaDevice,
603        ) -> Result<CudaSlice<T>> {
604            let src = match l_src.contiguous_offsets() {
605                None => crate::bail!("src input has to be contiguous"),
606                Some((o1, o2)) => src.slice(o1..o2),
607            };
608            let cos = match l_cos.contiguous_offsets() {
609                None => crate::bail!("cos input has to be contiguous"),
610                Some((o1, o2)) => cos.slice(o1..o2),
611            };
612            let sin = match l_sin.contiguous_offsets() {
613                None => crate::bail!("sin input has to be contiguous"),
614                Some((o1, o2)) => sin.slice(o1..o2),
615            };
616            let (b, t, h, d) = l_src.shape().dims4()?;
617            let el = b * h * t * d;
618            let cfg = LaunchConfig::for_num_elems((el / 2) as u32);
619            let func = dev.get_or_load_func(&kernel_name::<T>("rope_thd"), kernels::REDUCE)?;
620            // SAFETY: Set later by running the kernel.
621            let dst = unsafe { dev.alloc::<T>(el) }.w()?;
622            let params = (
623                &src, &cos, &sin, &dst, b as u32, t as u32, h as u32, d as u32,
624            );
625            // SAFETY: ffi.
626            unsafe { func.launch(cfg, params) }.w()?;
627            Ok(dst)
628        }
629
630        use crate::core::backend::BackendStorage;
631        use crate::core::cuda_backend::CudaStorageSlice::{BF16, F16, F32, F64};
632        let dev = s1.device();
633        let slice = match (&s1.slice, &s2.slice, &s3.slice) {
634            (BF16(s1), BF16(s2), BF16(s3)) => BF16(inner(s1, l1, s2, l2, s3, l3, dev)?),
635            (F16(s1), F16(s2), F16(s3)) => F16(inner(s1, l1, s2, l2, s3, l3, dev)?),
636            (F32(s1), F32(s2), F32(s3)) => F32(inner(s1, l1, s2, l2, s3, l3, dev)?),
637            (F64(s1), F64(s2), F64(s3)) => F64(inner(s1, l1, s2, l2, s3, l3, dev)?),
638            _ => crate::bail!(
639                "unsupported dtype for rope {:?} {:?} {:?}",
640                s1.dtype(),
641                s2.dtype(),
642                s3.dtype()
643            ),
644        };
645        let dst = crate::core::cuda_backend::CudaStorage {
646            slice,
647            device: dev.clone(),
648        };
649        Ok((dst, l1.shape().clone()))
650    }
651
652    #[cfg(feature = "metal")]
653    fn metal_fwd(
654        &self,
655        src: &crate::core::MetalStorage,
656        l_src: &Layout,
657        cos: &crate::core::MetalStorage,
658        l_cos: &Layout,
659        sin: &crate::core::MetalStorage,
660        l_sin: &Layout,
661    ) -> Result<(crate::core::MetalStorage, Shape)> {
662        use crate::core::backend::BackendStorage;
663        let device = src.device();
664        let command_buffer = device.command_buffer()?;
665        let kernels = device.kernels();
666        if cos.dtype() != src.dtype() || sin.dtype() != src.dtype() {
667            crate::bail!(
668                "dtype mismatch in rope {:?} {:?} {:?}",
669                src.dtype(),
670                cos.dtype(),
671                sin.dtype()
672            )
673        }
674        let name = match src.dtype() {
675            crate::core::DType::F32 => "rope_thd_f32",
676            crate::core::DType::F16 => "rope_thd_f16",
677            crate::core::DType::BF16 => "rope_thd_bf16",
678            dtype => crate::bail!("rope_thd is not implemented for {dtype:?}"),
679        };
680        let (b, t, h, d) = l_src.shape().dims4()?;
681        let el = b * h * t * d;
682        let output = device.new_buffer(el, src.dtype(), "rope-thd")?;
683        crate::metal_kernels::call_rope_thd(
684            device.metal_device(),
685            &command_buffer,
686            kernels,
687            name,
688            b,
689            t,
690            h,
691            d,
692            src.buffer(),
693            l_src.start_offset() * src.dtype().size_in_bytes(),
694            cos.buffer(),
695            l_cos.start_offset() * cos.dtype().size_in_bytes(),
696            sin.buffer(),
697            l_sin.start_offset() * sin.dtype().size_in_bytes(),
698            &output,
699        )
700        .map_err(crate::core::Error::wrap)?;
701        let out = crate::core::MetalStorage::new(output, device.clone(), el, src.dtype());
702        Ok((out, l_src.shape().clone()))
703    }
704}
705
706pub fn rope_thd(xs: &Tensor, cos: &Tensor, sin: &Tensor) -> Result<Tensor> {
707    let (_b_sz, seq_len, _n_head, n_embd) = xs.dims4()?;
708    let (cos_seq_len, cos_n_embd) = cos.dims2()?;
709    let (sin_seq_len, sin_n_embd) = sin.dims2()?;
710    if cos_n_embd * 2 != n_embd
711        || sin_n_embd * 2 != n_embd
712        || seq_len > cos_seq_len
713        || seq_len > sin_seq_len
714    {
715        crate::bail!(
716            "inconsistent last dim size in rope {:?} {:?} {:?}",
717            xs.shape(),
718            cos.shape(),
719            sin.shape()
720        )
721    }
722    if !xs.is_contiguous() {
723        crate::bail!("xs has to be contiguous in rope")
724    }
725    if !cos.is_contiguous() {
726        crate::bail!("cos has to be contiguous in rope")
727    }
728    if !sin.is_contiguous() {
729        crate::bail!("sin has to be contiguous in rope")
730    }
731    xs.apply_op3_no_bwd(cos, sin, &RotaryEmbThd)
732}