Skip to main content

candle_nn/
rotary_emb.rs

1//! Rotary Embeddings
2//!
3use candle::{CpuStorage, Layout, Result, Shape, Tensor, D};
4use rayon::prelude::*;
5
/// Rotary embeddings where the two members of every rotated pair are adjacent
/// (interleaved) along the last, `head_dim`, axis.
///
/// Each neighbouring pair `(x0, x1)` becomes
///   y0 = x0*cos - x1*sin
///   y1 = x0*sin + x1*cos
/// and the result is stored interleaved in the same two slots.
#[derive(Clone, Debug)]
struct RotaryEmbI;
13
14impl candle::CustomOp3 for RotaryEmbI {
15    fn name(&self) -> &'static str {
16        "rotary-emb-int"
17    }
18
19    fn cpu_fwd(
20        &self,
21        s1: &CpuStorage,
22        l1: &Layout,
23        s2: &CpuStorage,
24        l2: &Layout,
25        s3: &CpuStorage,
26        l3: &Layout,
27    ) -> Result<(CpuStorage, Shape)> {
28        fn inner<T: candle::WithDType + num_traits::Float>(
29            src: &[T],
30            l_src: &Layout,
31            cos: &[T],
32            l_cos: &Layout,
33            sin: &[T],
34            l_sin: &Layout,
35        ) -> Result<(CpuStorage, Shape)> {
36            let src = match l_src.contiguous_offsets() {
37                None => candle::bail!("input src has to be contiguous"),
38                Some((o1, o2)) => &src[o1..o2],
39            };
40            let cos = match l_cos.contiguous_offsets() {
41                None => candle::bail!("input cos has to be contiguous"),
42                Some((o1, o2)) => &cos[o1..o2],
43            };
44            let sin = match l_sin.contiguous_offsets() {
45                None => candle::bail!("input sin has to be contiguous"),
46                Some((o1, o2)) => &sin[o1..o2],
47            };
48            let (b, h, t, d) = l_src.shape().dims4()?;
49            let unbatched_rope = l_cos.dims().len() == 3 && l_sin.dims().len() == 3;
50            let el_count = b * h * t * d;
51            let mut dst = vec![T::zero(); el_count];
52            src.par_chunks(t * d)
53                .zip(dst.par_chunks_mut(t * d))
54                .enumerate()
55                .for_each(|(bh_i, (src, dst))| {
56                    for i_over_2 in 0..t * d / 2 {
57                        let i = 2 * i_over_2;
58                        let rope_i = if unbatched_rope {
59                            let b_i = bh_i / h;
60                            i_over_2 + b_i * t * d / 2
61                        } else {
62                            i_over_2
63                        };
64                        dst[i] = src[i] * cos[rope_i] - src[i + 1] * sin[rope_i];
65                        dst[i + 1] = src[i] * sin[rope_i] + src[i + 1] * cos[rope_i];
66                    }
67                });
68            let storage = candle::WithDType::to_cpu_storage_owned(dst);
69            Ok((storage, (b, h, t, d).into()))
70        }
71
72        use candle::backend::BackendStorage;
73        use CpuStorage::{BF16, F16, F32, F64};
74        match (s1, s2, s3) {
75            (BF16(s1), BF16(s2), BF16(s3)) => inner(s1, l1, s2, l2, s3, l3),
76            (F16(s1), F16(s2), F16(s3)) => inner(s1, l1, s2, l2, s3, l3),
77            (F32(s1), F32(s2), F32(s3)) => inner(s1, l1, s2, l2, s3, l3),
78            (F64(s1), F64(s2), F64(s3)) => inner(s1, l1, s2, l2, s3, l3),
79            _ => candle::bail!(
80                "unsupported dtype for rope {:?} {:?} {:?}",
81                s1.dtype(),
82                s2.dtype(),
83                s3.dtype()
84            ),
85        }
86    }
87
88    #[cfg(feature = "cuda")]
89    fn cuda_fwd(
90        &self,
91        s1: &candle::CudaStorage,
92        l1: &Layout,
93        s2: &candle::CudaStorage,
94        l2: &Layout,
95        s3: &candle::CudaStorage,
96        l3: &Layout,
97    ) -> Result<(candle::CudaStorage, Shape)> {
98        use candle::cuda_backend::cudarc::driver::{
99            CudaSlice, DeviceRepr, LaunchConfig, PushKernelArg,
100        };
101        use candle::cuda_backend::{kernel_name, kernels, WrapErr};
102        use candle::{CudaDevice, WithDType};
103
104        fn inner<T: DeviceRepr + WithDType>(
105            src: &CudaSlice<T>,
106            l_src: &Layout,
107            cos: &CudaSlice<T>,
108            l_cos: &Layout,
109            sin: &CudaSlice<T>,
110            l_sin: &Layout,
111            dev: &CudaDevice,
112        ) -> Result<CudaSlice<T>> {
113            let src = match l_src.contiguous_offsets() {
114                None => candle::bail!("src input has to be contiguous"),
115                Some((o1, o2)) => src.slice(o1..o2),
116            };
117            let cos = match l_cos.contiguous_offsets() {
118                None => candle::bail!("cos input has to be contiguous"),
119                Some((o1, o2)) => cos.slice(o1..o2),
120            };
121            let sin = match l_sin.contiguous_offsets() {
122                None => candle::bail!("sin input has to be contiguous"),
123                Some((o1, o2)) => sin.slice(o1..o2),
124            };
125            let (b, h, t, d) = l_src.shape().dims4()?;
126            let stride_b = if l_cos.dims().len() == 3 && l_sin.dims().len() == 3 {
127                (h * t * d) as u32
128            } else {
129                0u32
130            };
131            let el = b * h * t * d;
132            let cfg = LaunchConfig::for_num_elems((el / 2) as u32);
133            let func = dev.get_or_load_func(&kernel_name::<T>("rope_i"), &kernels::REDUCE)?;
134            // SAFETY: Set later by running the kernel.
135            let dst = unsafe { dev.alloc::<T>(el)? };
136            let mut builder = func.builder();
137            builder.arg(&src);
138            builder.arg(&cos);
139            builder.arg(&sin);
140            builder.arg(&dst);
141            candle::builder_arg!(builder, (b * h) as u32, (t * d) as u32, stride_b);
142            // SAFETY: ffi.
143            unsafe { builder.launch(cfg) }.w()?;
144            Ok(dst)
145        }
146
147        use candle::backend::BackendStorage;
148        use candle::cuda_backend::CudaStorageSlice::{BF16, F16, F32, F64};
149        let dev = s1.device();
150        let slice = match (&s1.slice, &s2.slice, &s3.slice) {
151            (BF16(s1), BF16(s2), BF16(s3)) => BF16(inner(s1, l1, s2, l2, s3, l3, dev)?),
152            (F16(s1), F16(s2), F16(s3)) => F16(inner(s1, l1, s2, l2, s3, l3, dev)?),
153            (F32(s1), F32(s2), F32(s3)) => F32(inner(s1, l1, s2, l2, s3, l3, dev)?),
154            (F64(s1), F64(s2), F64(s3)) => F64(inner(s1, l1, s2, l2, s3, l3, dev)?),
155            _ => candle::bail!(
156                "unsupported dtype for rope {:?} {:?} {:?}",
157                s1.dtype(),
158                s2.dtype(),
159                s3.dtype()
160            ),
161        };
162        let dst = candle::cuda_backend::CudaStorage {
163            slice,
164            device: dev.clone(),
165        };
166        Ok((dst, l1.shape().clone()))
167    }
168
169    #[cfg(feature = "metal")]
170    fn metal_fwd(
171        &self,
172        src: &candle::MetalStorage,
173        l_src: &Layout,
174        cos: &candle::MetalStorage,
175        l_cos: &Layout,
176        sin: &candle::MetalStorage,
177        l_sin: &Layout,
178    ) -> Result<(candle::MetalStorage, Shape)> {
179        use candle::backend::BackendStorage;
180        let device = src.device();
181        let encoder = device.command_encoder()?;
182        encoder.set_label("rope_i");
183        let kernels = device.kernels();
184        if cos.dtype() != src.dtype() || sin.dtype() != src.dtype() {
185            candle::bail!(
186                "dtype mismatch in rope-i {:?} {:?} {:?}",
187                src.dtype(),
188                cos.dtype(),
189                sin.dtype()
190            )
191        }
192        let name = match src.dtype() {
193            candle::DType::F32 => "rope_i_f32",
194            candle::DType::F16 => "rope_i_f16",
195            candle::DType::BF16 => "rope_i_bf16",
196            dtype => candle::bail!("rope-i is not implemented for {dtype:?}"),
197        };
198        let (b, h, t, d) = l_src.shape().dims4()?;
199        let stride_b = if l_cos.dims().len() == 3 && l_sin.dims().len() == 3 {
200            h * t * d
201        } else {
202            0usize
203        };
204        let el = b * h * t * d;
205        let output = device.new_buffer(el, src.dtype(), "rope_i")?;
206        candle_metal_kernels::call_rope_i(
207            device.metal_device(),
208            &encoder,
209            kernels,
210            name,
211            b * h,
212            t * d,
213            stride_b,
214            src.buffer(),
215            l_src.start_offset() * src.dtype().size_in_bytes(),
216            cos.buffer(),
217            l_cos.start_offset() * cos.dtype().size_in_bytes(),
218            sin.buffer(),
219            l_sin.start_offset() * sin.dtype().size_in_bytes(),
220            &output,
221        )
222        .map_err(candle::Error::wrap)?;
223        let out = candle::MetalStorage::new(output, device.clone(), el, src.dtype());
224        Ok((out, l_src.shape().clone()))
225    }
226}
227
228fn rope_check_cs(cs: &Tensor, b_sz: usize) -> Result<(usize, usize)> {
229    match *cs.dims() {
230        [t, d] => Ok((t, d)),
231        [b, t, d] => {
232            if b != b_sz {
233                candle::bail!("inconsistent batch size in rope {b_sz} {cs:?}",)
234            }
235            Ok((t, d))
236        }
237        _ => candle::bail!("cos/sin has to be 2D or 3D in rope {b_sz} {cs:?}"),
238    }
239}
240
241pub fn rope_i(xs: &Tensor, cos: &Tensor, sin: &Tensor) -> Result<Tensor> {
242    let (b_sz, _n_head, seq_len, n_embd) = xs.dims4()?;
243    let (cos_seq_len, cos_n_embd) = rope_check_cs(cos, b_sz)?;
244    let (sin_seq_len, sin_n_embd) = rope_check_cs(sin, b_sz)?;
245    if cos_n_embd * 2 != n_embd
246        || sin_n_embd * 2 != n_embd
247        || seq_len > cos_seq_len
248        || seq_len > sin_seq_len
249    {
250        candle::bail!(
251            "inconsistent last dim size in rope {:?} {:?} {:?}",
252            xs.shape(),
253            cos.shape(),
254            sin.shape()
255        )
256    }
257    if !xs.is_contiguous() {
258        candle::bail!("xs has to be contiguous in rope")
259    }
260    if !cos.is_contiguous() {
261        candle::bail!("cos has to be contiguous in rope")
262    }
263    if !sin.is_contiguous() {
264        candle::bail!("sin has to be contiguous in rope")
265    }
266    xs.apply_op3_no_bwd(cos, sin, &RotaryEmbI)
267}
268
269pub fn rope_i_slow(x: &Tensor, cos: &Tensor, sin: &Tensor) -> Result<Tensor> {
270    let (b_sz, n_head, seq_len, n_embd) = x.dims4()?;
271    let cos = cos
272        .narrow(0, 0, seq_len)?
273        .reshape((seq_len, n_embd / 2, 1))?;
274    let sin = sin
275        .narrow(0, 0, seq_len)?
276        .reshape((seq_len, n_embd / 2, 1))?;
277    let cos = cos.broadcast_as((b_sz, 1, seq_len, n_embd / 2, 1))?;
278    let sin = sin.broadcast_as((b_sz, 1, seq_len, n_embd / 2, 1))?;
279    let x = x.reshape((b_sz, n_head, seq_len, n_embd / 2, 2))?;
280    let x0 = x.narrow(D::Minus1, 0, 1)?;
281    let x1 = x.narrow(D::Minus1, 1, 1)?;
282    let y0 = (x0.broadcast_mul(&cos)? - x1.broadcast_mul(&sin)?)?;
283    let y1 = (x0.broadcast_mul(&sin)? + x1.broadcast_mul(&cos)?)?;
284    let rope = Tensor::cat(&[y0, y1], D::Minus1)?;
285    let rope = rope.flatten_from(D::Minus2)?;
286    Ok(rope)
287}
288
/// Half-split ("contiguous") rotary embeddings on a `(b, h, t, d)` input: element `i`
/// of the last axis is rotated together with element `i + d / 2`.
#[derive(Clone, Debug)]
struct RotaryEmb;
292
293impl candle::CustomOp3 for RotaryEmb {
294    fn name(&self) -> &'static str {
295        "rotary-emb"
296    }
297
298    fn cpu_fwd(
299        &self,
300        s1: &CpuStorage,
301        l1: &Layout,
302        s2: &CpuStorage,
303        l2: &Layout,
304        s3: &CpuStorage,
305        l3: &Layout,
306    ) -> Result<(CpuStorage, Shape)> {
307        fn inner<T: candle::WithDType + num_traits::Float>(
308            src: &[T],
309            l_src: &Layout,
310            cos: &[T],
311            l_cos: &Layout,
312            sin: &[T],
313            l_sin: &Layout,
314        ) -> Result<(CpuStorage, Shape)> {
315            let src = match l_src.contiguous_offsets() {
316                None => candle::bail!("input src has to be contiguous"),
317                Some((o1, o2)) => &src[o1..o2],
318            };
319            let cos = match l_cos.contiguous_offsets() {
320                None => candle::bail!("input cos has to be contiguous"),
321                Some((o1, o2)) => &cos[o1..o2],
322            };
323            let sin = match l_sin.contiguous_offsets() {
324                None => candle::bail!("input sin has to be contiguous"),
325                Some((o1, o2)) => &sin[o1..o2],
326            };
327            let (b, h, t, d) = l_src.shape().dims4()?;
328            let unbatched_rope = l_cos.dims().len() == 3 && l_sin.dims().len() == 3;
329            let el_count = b * h * t * d;
330            let mut dst = vec![T::zero(); el_count];
331            src.par_chunks(t * d)
332                .zip(dst.par_chunks_mut(t * d))
333                .enumerate()
334                .for_each(|(bh_i, (src, dst))| {
335                    for i_t in 0..t {
336                        for i_d in 0..d / 2 {
337                            let i1 = i_t * d + i_d;
338                            let i2 = i1 + d / 2;
339                            let i_cs = i_t * (d / 2) + i_d;
340                            let i_cs = if unbatched_rope {
341                                let b_i = bh_i / h;
342                                i_cs + b_i * t * d / 2
343                            } else {
344                                i_cs
345                            };
346                            dst[i1] = src[i1] * cos[i_cs] - src[i2] * sin[i_cs];
347                            dst[i2] = src[i1] * sin[i_cs] + src[i2] * cos[i_cs];
348                        }
349                    }
350                });
351            let storage = candle::WithDType::to_cpu_storage_owned(dst);
352            Ok((storage, (b, h, t, d).into()))
353        }
354
355        use candle::backend::BackendStorage;
356        use CpuStorage::{BF16, F16, F32, F64};
357        match (s1, s2, s3) {
358            (BF16(s1), BF16(s2), BF16(s3)) => inner(s1, l1, s2, l2, s3, l3),
359            (F16(s1), F16(s2), F16(s3)) => inner(s1, l1, s2, l2, s3, l3),
360            (F32(s1), F32(s2), F32(s3)) => inner(s1, l1, s2, l2, s3, l3),
361            (F64(s1), F64(s2), F64(s3)) => inner(s1, l1, s2, l2, s3, l3),
362            _ => candle::bail!(
363                "unsupported dtype for rope {:?} {:?} {:?}",
364                s1.dtype(),
365                s2.dtype(),
366                s3.dtype()
367            ),
368        }
369    }
370
371    #[cfg(feature = "cuda")]
372    fn cuda_fwd(
373        &self,
374        s1: &candle::CudaStorage,
375        l1: &Layout,
376        s2: &candle::CudaStorage,
377        l2: &Layout,
378        s3: &candle::CudaStorage,
379        l3: &Layout,
380    ) -> Result<(candle::CudaStorage, Shape)> {
381        use candle::cuda_backend::cudarc::driver::{
382            CudaSlice, DeviceRepr, LaunchConfig, PushKernelArg,
383        };
384        use candle::cuda_backend::{kernel_name, kernels, WrapErr};
385        use candle::{CudaDevice, WithDType};
386
387        fn inner<T: DeviceRepr + WithDType>(
388            src: &CudaSlice<T>,
389            l_src: &Layout,
390            cos: &CudaSlice<T>,
391            l_cos: &Layout,
392            sin: &CudaSlice<T>,
393            l_sin: &Layout,
394            dev: &CudaDevice,
395        ) -> Result<CudaSlice<T>> {
396            let src = match l_src.contiguous_offsets() {
397                None => candle::bail!("src input has to be contiguous"),
398                Some((o1, o2)) => src.slice(o1..o2),
399            };
400            let cos = match l_cos.contiguous_offsets() {
401                None => candle::bail!("cos input has to be contiguous"),
402                Some((o1, o2)) => cos.slice(o1..o2),
403            };
404            let sin = match l_sin.contiguous_offsets() {
405                None => candle::bail!("sin input has to be contiguous"),
406                Some((o1, o2)) => sin.slice(o1..o2),
407            };
408            let (b, h, t, d) = l_src.shape().dims4()?;
409            let stride_b = if l_cos.dims().len() == 3 && l_sin.dims().len() == 3 {
410                (h * t * d) as u32
411            } else {
412                0u32
413            };
414            let el = b * h * t * d;
415            let cfg = LaunchConfig::for_num_elems((el / 2) as u32);
416            let func = dev.get_or_load_func(&kernel_name::<T>("rope"), &kernels::REDUCE)?;
417            // SAFETY: Set later by running the kernel.
418            let dst = unsafe { dev.alloc::<T>(el)? };
419            let mut builder = func.builder();
420            builder.arg(&src);
421            builder.arg(&cos);
422            builder.arg(&sin);
423            builder.arg(&dst);
424            candle::builder_arg!(builder, (b * h) as u32, (t * d) as u32, d as u32, stride_b);
425            // SAFETY: ffi.
426            unsafe { builder.launch(cfg) }.w()?;
427            Ok(dst)
428        }
429
430        use candle::backend::BackendStorage;
431        use candle::cuda_backend::CudaStorageSlice::{BF16, F16, F32, F64};
432        let dev = s1.device();
433        let slice = match (&s1.slice, &s2.slice, &s3.slice) {
434            (BF16(s1), BF16(s2), BF16(s3)) => BF16(inner(s1, l1, s2, l2, s3, l3, dev)?),
435            (F16(s1), F16(s2), F16(s3)) => F16(inner(s1, l1, s2, l2, s3, l3, dev)?),
436            (F32(s1), F32(s2), F32(s3)) => F32(inner(s1, l1, s2, l2, s3, l3, dev)?),
437            (F64(s1), F64(s2), F64(s3)) => F64(inner(s1, l1, s2, l2, s3, l3, dev)?),
438            _ => candle::bail!(
439                "unsupported dtype for rope {:?} {:?} {:?}",
440                s1.dtype(),
441                s2.dtype(),
442                s3.dtype()
443            ),
444        };
445        let dst = candle::cuda_backend::CudaStorage {
446            slice,
447            device: dev.clone(),
448        };
449        Ok((dst, l1.shape().clone()))
450    }
451
452    #[cfg(feature = "metal")]
453    fn metal_fwd(
454        &self,
455        src: &candle::MetalStorage,
456        l_src: &Layout,
457        cos: &candle::MetalStorage,
458        l_cos: &Layout,
459        sin: &candle::MetalStorage,
460        l_sin: &Layout,
461    ) -> Result<(candle::MetalStorage, Shape)> {
462        use candle::backend::BackendStorage;
463        let device = src.device();
464        let encoder = device.command_encoder()?;
465        encoder.set_label("rope");
466        let kernels = device.kernels();
467        if cos.dtype() != src.dtype() || sin.dtype() != src.dtype() {
468            candle::bail!(
469                "dtype mismatch in rope {:?} {:?} {:?}",
470                src.dtype(),
471                cos.dtype(),
472                sin.dtype()
473            )
474        }
475        let name = match src.dtype() {
476            candle::DType::F32 => "rope_f32",
477            candle::DType::F16 => "rope_f16",
478            candle::DType::BF16 => "rope_bf16",
479            dtype => candle::bail!("rope is not implemented for {dtype:?}"),
480        };
481        let (b, h, t, d) = l_src.shape().dims4()?;
482        let stride_b = if l_cos.dims().len() == 3 && l_sin.dims().len() == 3 {
483            h * t * d
484        } else {
485            0usize
486        };
487        let el = b * h * t * d;
488        let output = device.new_buffer(el, src.dtype(), "rope")?;
489        candle_metal_kernels::call_rope(
490            device.metal_device(),
491            &encoder,
492            kernels,
493            name,
494            b * h,
495            t * d,
496            d,
497            stride_b,
498            src.buffer(),
499            l_src.start_offset() * src.dtype().size_in_bytes(),
500            cos.buffer(),
501            l_cos.start_offset() * cos.dtype().size_in_bytes(),
502            sin.buffer(),
503            l_sin.start_offset() * sin.dtype().size_in_bytes(),
504            &output,
505        )
506        .map_err(candle::Error::wrap)?;
507        let out = candle::MetalStorage::new(output, device.clone(), el, src.dtype());
508        Ok((out, l_src.shape().clone()))
509    }
510}
511
512pub fn rope(xs: &Tensor, cos: &Tensor, sin: &Tensor) -> Result<Tensor> {
513    let (b_sz, _n_head, seq_len, n_embd) = xs.dims4()?;
514    let (cos_seq_len, cos_n_embd) = rope_check_cs(cos, b_sz)?;
515    let (sin_seq_len, sin_n_embd) = rope_check_cs(sin, b_sz)?;
516    if cos_n_embd * 2 != n_embd
517        || sin_n_embd * 2 != n_embd
518        || seq_len > cos_seq_len
519        || seq_len > sin_seq_len
520    {
521        candle::bail!(
522            "inconsistent last dim size in rope {:?} {:?} {:?}",
523            xs.shape(),
524            cos.shape(),
525            sin.shape()
526        )
527    }
528    if !xs.is_contiguous() {
529        candle::bail!("xs has to be contiguous in rope")
530    }
531    if !cos.is_contiguous() {
532        candle::bail!("cos has to be contiguous in rope")
533    }
534    if !sin.is_contiguous() {
535        candle::bail!("sin has to be contiguous in rope")
536    }
537    xs.apply_op3_no_bwd(cos, sin, &RotaryEmb)
538}
539
540fn rotate_half(xs: &Tensor) -> Result<Tensor> {
541    let last_dim = xs.dim(D::Minus1)?;
542    let xs1 = xs.narrow(D::Minus1, 0, last_dim / 2)?;
543    let xs2 = xs.narrow(D::Minus1, last_dim / 2, last_dim - last_dim / 2)?;
544    Tensor::cat(&[&xs2.neg()?, &xs1], D::Minus1)
545}
546
547pub fn rope_slow(x: &Tensor, cos: &Tensor, sin: &Tensor) -> Result<Tensor> {
548    let (_b_sz, _h, seq_len, _n_embd) = x.dims4()?;
549    let cos = Tensor::cat(&[cos, cos], D::Minus1)?;
550    let sin = Tensor::cat(&[sin, sin], D::Minus1)?;
551    let cos = cos.narrow(0, 0, seq_len)?;
552    let sin = sin.narrow(0, 0, seq_len)?;
553    let cos = cos.unsqueeze(0)?.unsqueeze(0)?;
554    let sin = sin.unsqueeze(0)?.unsqueeze(0)?;
555    x.broadcast_mul(&cos)? + rotate_half(x)?.broadcast_mul(&sin)?
556}
557
/// Half-split rotary embeddings for inputs laid out as `(b, t, h, d)`: sequence-major,
/// with all heads of one position stored next to each other.
#[derive(Clone, Debug)]
struct RotaryEmbThd;
561
562impl candle::CustomOp3 for RotaryEmbThd {
563    fn name(&self) -> &'static str {
564        "rotary-emb"
565    }
566
567    fn cpu_fwd(
568        &self,
569        s1: &CpuStorage,
570        l1: &Layout,
571        s2: &CpuStorage,
572        l2: &Layout,
573        s3: &CpuStorage,
574        l3: &Layout,
575    ) -> Result<(CpuStorage, Shape)> {
576        fn inner<T: candle::WithDType + num_traits::Float>(
577            src: &[T],
578            l_src: &Layout,
579            cos: &[T],
580            l_cos: &Layout,
581            sin: &[T],
582            l_sin: &Layout,
583        ) -> Result<(CpuStorage, Shape)> {
584            let src = match l_src.contiguous_offsets() {
585                None => candle::bail!("input src has to be contiguous"),
586                Some((o1, o2)) => &src[o1..o2],
587            };
588            let cos = match l_cos.contiguous_offsets() {
589                None => candle::bail!("input cos has to be contiguous"),
590                Some((o1, o2)) => &cos[o1..o2],
591            };
592            let sin = match l_sin.contiguous_offsets() {
593                None => candle::bail!("input sin has to be contiguous"),
594                Some((o1, o2)) => &sin[o1..o2],
595            };
596            let (b, t, h, d) = l_src.shape().dims4()?;
597            let unbatched_rope = l_cos.dims().len() == 3 && l_sin.dims().len() == 3;
598            let el_count = b * h * t * d;
599            let mut dst = vec![T::zero(); el_count];
600            src.par_chunks(t * h * d)
601                .zip(dst.par_chunks_mut(t * h * d))
602                .enumerate()
603                .for_each(|(b_i, (src, dst))| {
604                    for i_t in 0..t {
605                        for i_d in 0..d / 2 {
606                            let i_cs = i_t * (d / 2) + i_d;
607                            let i_cs = if unbatched_rope {
608                                i_cs + b_i * t * d / 2
609                            } else {
610                                i_cs
611                            };
612                            for i_h in 0..h {
613                                let i1 = i_t * h * d + i_h * d + i_d;
614                                let i2 = i1 + d / 2;
615                                dst[i1] = src[i1] * cos[i_cs] - src[i2] * sin[i_cs];
616                                dst[i2] = src[i1] * sin[i_cs] + src[i2] * cos[i_cs];
617                            }
618                        }
619                    }
620                });
621            let storage = candle::WithDType::to_cpu_storage_owned(dst);
622            Ok((storage, (b, t, h, d).into()))
623        }
624
625        use candle::backend::BackendStorage;
626        use CpuStorage::{BF16, F16, F32, F64};
627        match (s1, s2, s3) {
628            (BF16(s1), BF16(s2), BF16(s3)) => inner(s1, l1, s2, l2, s3, l3),
629            (F16(s1), F16(s2), F16(s3)) => inner(s1, l1, s2, l2, s3, l3),
630            (F32(s1), F32(s2), F32(s3)) => inner(s1, l1, s2, l2, s3, l3),
631            (F64(s1), F64(s2), F64(s3)) => inner(s1, l1, s2, l2, s3, l3),
632            _ => candle::bail!(
633                "unsupported dtype for rope {:?} {:?} {:?}",
634                s1.dtype(),
635                s2.dtype(),
636                s3.dtype()
637            ),
638        }
639    }
640
641    #[cfg(feature = "cuda")]
642    fn cuda_fwd(
643        &self,
644        s1: &candle::CudaStorage,
645        l1: &Layout,
646        s2: &candle::CudaStorage,
647        l2: &Layout,
648        s3: &candle::CudaStorage,
649        l3: &Layout,
650    ) -> Result<(candle::CudaStorage, Shape)> {
651        use candle::cuda_backend::cudarc::driver::{
652            CudaSlice, DeviceRepr, LaunchConfig, PushKernelArg,
653        };
654        use candle::cuda_backend::{kernel_name, kernels, WrapErr};
655        use candle::{CudaDevice, WithDType};
656
657        fn inner<T: DeviceRepr + WithDType>(
658            src: &CudaSlice<T>,
659            l_src: &Layout,
660            cos: &CudaSlice<T>,
661            l_cos: &Layout,
662            sin: &CudaSlice<T>,
663            l_sin: &Layout,
664            dev: &CudaDevice,
665        ) -> Result<CudaSlice<T>> {
666            let src = match l_src.contiguous_offsets() {
667                None => candle::bail!("src input has to be contiguous"),
668                Some((o1, o2)) => src.slice(o1..o2),
669            };
670            let cos = match l_cos.contiguous_offsets() {
671                None => candle::bail!("cos input has to be contiguous"),
672                Some((o1, o2)) => cos.slice(o1..o2),
673            };
674            let sin = match l_sin.contiguous_offsets() {
675                None => candle::bail!("sin input has to be contiguous"),
676                Some((o1, o2)) => sin.slice(o1..o2),
677            };
678            let (b, t, h, d) = l_src.shape().dims4()?;
679            let stride_b = if l_cos.dims().len() == 3 && l_sin.dims().len() == 3 {
680                (h * t * d) as u32
681            } else {
682                0u32
683            };
684            let el = b * h * t * d;
685            let cfg = LaunchConfig::for_num_elems((el / 2) as u32);
686            let func = dev.get_or_load_func(&kernel_name::<T>("rope_thd"), &kernels::REDUCE)?;
687            // SAFETY: Set later by running the kernel.
688            let dst = unsafe { dev.alloc::<T>(el)? };
689            let mut builder = func.builder();
690            builder.arg(&src);
691            builder.arg(&cos);
692            builder.arg(&sin);
693            builder.arg(&dst);
694            candle::builder_arg!(builder, b as u32, t as u32, h as u32, d as u32, stride_b);
695            // SAFETY: ffi.
696            unsafe { builder.launch(cfg) }.w()?;
697            Ok(dst)
698        }
699
700        use candle::backend::BackendStorage;
701        use candle::cuda_backend::CudaStorageSlice::{BF16, F16, F32, F64};
702        let dev = s1.device();
703        let slice = match (&s1.slice, &s2.slice, &s3.slice) {
704            (BF16(s1), BF16(s2), BF16(s3)) => BF16(inner(s1, l1, s2, l2, s3, l3, dev)?),
705            (F16(s1), F16(s2), F16(s3)) => F16(inner(s1, l1, s2, l2, s3, l3, dev)?),
706            (F32(s1), F32(s2), F32(s3)) => F32(inner(s1, l1, s2, l2, s3, l3, dev)?),
707            (F64(s1), F64(s2), F64(s3)) => F64(inner(s1, l1, s2, l2, s3, l3, dev)?),
708            _ => candle::bail!(
709                "unsupported dtype for rope {:?} {:?} {:?}",
710                s1.dtype(),
711                s2.dtype(),
712                s3.dtype()
713            ),
714        };
715        let dst = candle::cuda_backend::CudaStorage {
716            slice,
717            device: dev.clone(),
718        };
719        Ok((dst, l1.shape().clone()))
720    }
721
722    #[cfg(feature = "metal")]
723    fn metal_fwd(
724        &self,
725        src: &candle::MetalStorage,
726        l_src: &Layout,
727        cos: &candle::MetalStorage,
728        l_cos: &Layout,
729        sin: &candle::MetalStorage,
730        l_sin: &Layout,
731    ) -> Result<(candle::MetalStorage, Shape)> {
732        use candle::backend::BackendStorage;
733        let device = src.device();
734        let encoder = device.command_encoder()?;
735        encoder.set_label("rope_thd");
736        let kernels = device.kernels();
737        if cos.dtype() != src.dtype() || sin.dtype() != src.dtype() {
738            candle::bail!(
739                "dtype mismatch in rope {:?} {:?} {:?}",
740                src.dtype(),
741                cos.dtype(),
742                sin.dtype()
743            )
744        }
745        let name = match src.dtype() {
746            candle::DType::F32 => "rope_thd_f32",
747            candle::DType::F16 => "rope_thd_f16",
748            candle::DType::BF16 => "rope_thd_bf16",
749            dtype => candle::bail!("rope_thd is not implemented for {dtype:?}"),
750        };
751        let (b, t, h, d) = l_src.shape().dims4()?;
752        let stride_b = if l_cos.dims().len() == 3 && l_sin.dims().len() == 3 {
753            h * t * d
754        } else {
755            0usize
756        };
757        let el = b * h * t * d;
758        let output = device.new_buffer(el, src.dtype(), "rope_thd")?;
759        candle_metal_kernels::call_rope_thd(
760            device.metal_device(),
761            &encoder,
762            kernels,
763            name,
764            b,
765            t,
766            h,
767            d,
768            stride_b,
769            src.buffer(),
770            l_src.start_offset() * src.dtype().size_in_bytes(),
771            cos.buffer(),
772            l_cos.start_offset() * cos.dtype().size_in_bytes(),
773            sin.buffer(),
774            l_sin.start_offset() * sin.dtype().size_in_bytes(),
775            &output,
776        )
777        .map_err(candle::Error::wrap)?;
778        let out = candle::MetalStorage::new(output, device.clone(), el, src.dtype());
779        Ok((out, l_src.shape().clone()))
780    }
781}
782
783pub fn rope_thd(xs: &Tensor, cos: &Tensor, sin: &Tensor) -> Result<Tensor> {
784    let (b_sz, seq_len, _n_head, n_embd) = xs.dims4()?;
785    let (cos_seq_len, cos_n_embd) = rope_check_cs(cos, b_sz)?;
786    let (sin_seq_len, sin_n_embd) = rope_check_cs(sin, b_sz)?;
787    if cos_n_embd * 2 != n_embd
788        || sin_n_embd * 2 != n_embd
789        || seq_len > cos_seq_len
790        || seq_len > sin_seq_len
791    {
792        candle::bail!(
793            "inconsistent last dim size in rope {:?} {:?} {:?}",
794            xs.shape(),
795            cos.shape(),
796            sin.shape()
797        )
798    }
799    if !xs.is_contiguous() {
800        candle::bail!("xs has to be contiguous in rope")
801    }
802    if !cos.is_contiguous() {
803        candle::bail!("cos has to be contiguous in rope")
804    }
805    if !sin.is_contiguous() {
806        candle::bail!("sin has to be contiguous in rope")
807    }
808    xs.apply_op3_no_bwd(cos, sin, &RotaryEmbThd)
809}