1use candle::{CpuStorage, Layout, Result, Shape, Tensor, D};
4use rayon::prelude::*;
5
/// Custom op applying *interleaved* rotary embeddings: each adjacent pair
/// `(x[2i], x[2i + 1])` along the last dimension is rotated by `(cos[i], sin[i])`.
///
/// The source is read as `(b, h, t, d)`. `cos`/`sin` hold either a single table of
/// `t * d / 2` values shared by every batch entry or, when both are 3D, one such table
/// per batch entry.
#[derive(Debug, Clone)]
struct RotaryEmbI;

impl candle::CustomOp3 for RotaryEmbI {
    fn name(&self) -> &'static str {
        "rotary-emb-int"
    }

    /// CPU forward pass: `s1`/`l1` is the source, `s2`/`l2` is cos, `s3`/`l3` is sin.
    fn cpu_fwd(
        &self,
        s1: &CpuStorage,
        l1: &Layout,
        s2: &CpuStorage,
        l2: &Layout,
        s3: &CpuStorage,
        l3: &Layout,
    ) -> Result<(CpuStorage, Shape)> {
        // Dtype-generic kernel; the `match` below only dispatches on the storage variant.
        fn inner<T: candle::WithDType + num_traits::Float>(
            src: &[T],
            l_src: &Layout,
            cos: &[T],
            l_cos: &Layout,
            sin: &[T],
            l_sin: &Layout,
        ) -> Result<(CpuStorage, Shape)> {
            // Narrow each buffer to the span covered by its contiguous layout.
            let src = match l_src.contiguous_offsets() {
                None => candle::bail!("input src has to be contiguous"),
                Some((o1, o2)) => &src[o1..o2],
            };
            let cos = match l_cos.contiguous_offsets() {
                None => candle::bail!("input cos has to be contiguous"),
                Some((o1, o2)) => &cos[o1..o2],
            };
            let sin = match l_sin.contiguous_offsets() {
                None => candle::bail!("input sin has to be contiguous"),
                Some((o1, o2)) => &sin[o1..o2],
            };
            let (b, h, t, d) = l_src.shape().dims4()?;
            // NOTE: despite its name, this flag is true when cos/sin are 3D, i.e. when
            // they carry a separate table per batch entry.
            let unbatched_rope = l_cos.dims().len() == 3 && l_sin.dims().len() == 3;
            let el_count = b * h * t * d;
            let mut dst = vec![T::zero(); el_count];
            // One chunk per (batch, head) pair; `bh_i` enumerates them in row-major order.
            // NOTE(review): rayon's `par_chunks` panics on a zero chunk size, so `t * d == 0`
            // would panic here — confirm callers never pass an empty sequence/embedding.
            src.par_chunks(t * d)
                .zip(dst.par_chunks_mut(t * d))
                .enumerate()
                .for_each(|(bh_i, (src, dst))| {
                    for i_over_2 in 0..t * d / 2 {
                        let i = 2 * i_over_2;
                        // Offset into cos/sin. Per-batch tables are assumed to hold
                        // exactly `t * d / 2` values per batch entry.
                        let rope_i = if unbatched_rope {
                            let b_i = bh_i / h;
                            i_over_2 + b_i * t * d / 2
                        } else {
                            i_over_2
                        };
                        // Rotate the pair (src[i], src[i + 1]) by the angle at `rope_i`.
                        dst[i] = src[i] * cos[rope_i] - src[i + 1] * sin[rope_i];
                        dst[i + 1] = src[i] * sin[rope_i] + src[i + 1] * cos[rope_i];
                    }
                });
            let storage = candle::WithDType::to_cpu_storage_owned(dst);
            Ok((storage, (b, h, t, d).into()))
        }

        use candle::backend::BackendStorage;
        use CpuStorage::{BF16, F16, F32, F64};
        // All three inputs must share the same floating-point dtype.
        match (s1, s2, s3) {
            (BF16(s1), BF16(s2), BF16(s3)) => inner(s1, l1, s2, l2, s3, l3),
            (F16(s1), F16(s2), F16(s3)) => inner(s1, l1, s2, l2, s3, l3),
            (F32(s1), F32(s2), F32(s3)) => inner(s1, l1, s2, l2, s3, l3),
            (F64(s1), F64(s2), F64(s3)) => inner(s1, l1, s2, l2, s3, l3),
            _ => candle::bail!(
                "unsupported dtype for rope {:?} {:?} {:?}",
                s1.dtype(),
                s2.dtype(),
                s3.dtype()
            ),
        }
    }

    /// CUDA forward pass; launches the `rope_i` kernel loaded from candle's `REDUCE` module.
    #[cfg(feature = "cuda")]
    fn cuda_fwd(
        &self,
        s1: &candle::CudaStorage,
        l1: &Layout,
        s2: &candle::CudaStorage,
        l2: &Layout,
        s3: &candle::CudaStorage,
        l3: &Layout,
    ) -> Result<(candle::CudaStorage, Shape)> {
        use candle::cuda_backend::cudarc::driver::{
            CudaSlice, DeviceRepr, LaunchConfig, PushKernelArg,
        };
        use candle::cuda_backend::{kernel_name, kernels, WrapErr};
        use candle::{CudaDevice, WithDType};

        fn inner<T: DeviceRepr + WithDType>(
            src: &CudaSlice<T>,
            l_src: &Layout,
            cos: &CudaSlice<T>,
            l_cos: &Layout,
            sin: &CudaSlice<T>,
            l_sin: &Layout,
            dev: &CudaDevice,
        ) -> Result<CudaSlice<T>> {
            // View each device buffer through its contiguous offsets.
            let src = match l_src.contiguous_offsets() {
                None => candle::bail!("src input has to be contiguous"),
                Some((o1, o2)) => src.slice(o1..o2),
            };
            let cos = match l_cos.contiguous_offsets() {
                None => candle::bail!("cos input has to be contiguous"),
                Some((o1, o2)) => cos.slice(o1..o2),
            };
            let sin = match l_sin.contiguous_offsets() {
                None => candle::bail!("sin input has to be contiguous"),
                Some((o1, o2)) => sin.slice(o1..o2),
            };
            let (b, h, t, d) = l_src.shape().dims4()?;
            // Source batch stride (in elements) when cos/sin are per-batch, 0 otherwise;
            // presumably the kernel reads 0 as "shared tables" — kernel source not shown.
            // NOTE(review): `as u32` silently truncates for tensors above u32::MAX elements.
            let stride_b = if l_cos.dims().len() == 3 && l_sin.dims().len() == 3 {
                (h * t * d) as u32
            } else {
                0u32
            };
            let el = b * h * t * d;
            // Launch sized for `el / 2` items, i.e. one per rotated pair.
            let cfg = LaunchConfig::for_num_elems((el / 2) as u32);
            let func = dev.get_or_load_func(&kernel_name::<T>("rope_i"), &kernels::REDUCE)?;
            // SAFETY: uninitialized allocation; presumably every element is written by the
            // kernel launch below — confirm against the kernel source.
            let dst = unsafe { dev.alloc::<T>(el)? };

            let mut builder = func.builder();
            builder.arg(&src);
            builder.arg(&cos);
            builder.arg(&sin);
            builder.arg(&dst);
            candle::builder_arg!(builder, (b * h) as u32, (t * d) as u32, stride_b);
            // SAFETY: FFI kernel launch; arguments are pushed in the kernel's expected order.
            unsafe { builder.launch(cfg) }.w()?;

            Ok(dst)
        }

        use candle::backend::BackendStorage;
        use candle::cuda_backend::CudaStorageSlice::{BF16, F16, F32, F64};
        let dev = s1.device();
        let slice = match (&s1.slice, &s2.slice, &s3.slice) {
            (BF16(s1), BF16(s2), BF16(s3)) => BF16(inner(s1, l1, s2, l2, s3, l3, dev)?),
            (F16(s1), F16(s2), F16(s3)) => F16(inner(s1, l1, s2, l2, s3, l3, dev)?),
            (F32(s1), F32(s2), F32(s3)) => F32(inner(s1, l1, s2, l2, s3, l3, dev)?),
            (F64(s1), F64(s2), F64(s3)) => F64(inner(s1, l1, s2, l2, s3, l3, dev)?),
            _ => candle::bail!(
                "unsupported dtype for rope {:?} {:?} {:?}",
                s1.dtype(),
                s2.dtype(),
                s3.dtype()
            ),
        };
        let dst = candle::cuda_backend::CudaStorage {
            slice,
            device: dev.clone(),
        };
        Ok((dst, l1.shape().clone()))
    }

    /// Metal forward pass; dispatches `candle_metal_kernels::call_rope_i` (F32/F16/BF16 only).
    #[cfg(feature = "metal")]
    fn metal_fwd(
        &self,
        src: &candle::MetalStorage,
        l_src: &Layout,
        cos: &candle::MetalStorage,
        l_cos: &Layout,
        sin: &candle::MetalStorage,
        l_sin: &Layout,
    ) -> Result<(candle::MetalStorage, Shape)> {
        use candle::backend::BackendStorage;
        let device = src.device();
        let encoder = device.command_encoder()?;
        encoder.set_label("rope_i");
        let kernels = device.kernels();
        if cos.dtype() != src.dtype() || sin.dtype() != src.dtype() {
            candle::bail!(
                "dtype mismatch in rope-i {:?} {:?} {:?}",
                src.dtype(),
                cos.dtype(),
                sin.dtype()
            )
        }
        let name = match src.dtype() {
            candle::DType::F32 => "rope_i_f32",
            candle::DType::F16 => "rope_i_f16",
            candle::DType::BF16 => "rope_i_bf16",
            dtype => candle::bail!("rope-i is not implemented for {dtype:?}"),
        };
        let (b, h, t, d) = l_src.shape().dims4()?;
        // Source batch stride when cos/sin are per-batch (3D), 0 when they are shared.
        let stride_b = if l_cos.dims().len() == 3 && l_sin.dims().len() == 3 {
            h * t * d
        } else {
            0usize
        };
        let el = b * h * t * d;
        let output = device.new_buffer(el, src.dtype(), "rope_i")?;
        // NOTE: unlike the CPU/CUDA paths, contiguity is not re-checked here; only each
        // layout's start offset (converted to bytes) is passed to the kernel.
        candle_metal_kernels::call_rope_i(
            device.metal_device(),
            &encoder,
            kernels,
            name,
            b * h,
            t * d,
            stride_b,
            src.buffer(),
            l_src.start_offset() * src.dtype().size_in_bytes(),
            cos.buffer(),
            l_cos.start_offset() * cos.dtype().size_in_bytes(),
            sin.buffer(),
            l_sin.start_offset() * sin.dtype().size_in_bytes(),
            &output,
        )
        .map_err(candle::Error::wrap)?;
        let out = candle::MetalStorage::new(output, device.clone(), el, src.dtype());
        Ok((out, l_src.shape().clone()))
    }
}
227
228fn rope_check_cs(cs: &Tensor, b_sz: usize) -> Result<(usize, usize)> {
229 match *cs.dims() {
230 [t, d] => Ok((t, d)),
231 [b, t, d] => {
232 if b != b_sz {
233 candle::bail!("inconsistent batch size in rope {b_sz} {cs:?}",)
234 }
235 Ok((t, d))
236 }
237 _ => candle::bail!("cos/sin has to be 2D or 3D in rope {b_sz} {cs:?}"),
238 }
239}
240
241pub fn rope_i(xs: &Tensor, cos: &Tensor, sin: &Tensor) -> Result<Tensor> {
242 let (b_sz, _n_head, seq_len, n_embd) = xs.dims4()?;
243 let (cos_seq_len, cos_n_embd) = rope_check_cs(cos, b_sz)?;
244 let (sin_seq_len, sin_n_embd) = rope_check_cs(sin, b_sz)?;
245 if cos_n_embd * 2 != n_embd
246 || sin_n_embd * 2 != n_embd
247 || seq_len > cos_seq_len
248 || seq_len > sin_seq_len
249 {
250 candle::bail!(
251 "inconsistent last dim size in rope {:?} {:?} {:?}",
252 xs.shape(),
253 cos.shape(),
254 sin.shape()
255 )
256 }
257 if !xs.is_contiguous() {
258 candle::bail!("xs has to be contiguous in rope")
259 }
260 if !cos.is_contiguous() {
261 candle::bail!("cos has to be contiguous in rope")
262 }
263 if !sin.is_contiguous() {
264 candle::bail!("sin has to be contiguous in rope")
265 }
266 xs.apply_op3_no_bwd(cos, sin, &RotaryEmbI)
267}
268
269pub fn rope_i_slow(x: &Tensor, cos: &Tensor, sin: &Tensor) -> Result<Tensor> {
270 let (b_sz, n_head, seq_len, n_embd) = x.dims4()?;
271 let cos = cos
272 .narrow(0, 0, seq_len)?
273 .reshape((seq_len, n_embd / 2, 1))?;
274 let sin = sin
275 .narrow(0, 0, seq_len)?
276 .reshape((seq_len, n_embd / 2, 1))?;
277 let cos = cos.broadcast_as((b_sz, 1, seq_len, n_embd / 2, 1))?;
278 let sin = sin.broadcast_as((b_sz, 1, seq_len, n_embd / 2, 1))?;
279 let x = x.reshape((b_sz, n_head, seq_len, n_embd / 2, 2))?;
280 let x0 = x.narrow(D::Minus1, 0, 1)?;
281 let x1 = x.narrow(D::Minus1, 1, 1)?;
282 let y0 = (x0.broadcast_mul(&cos)? - x1.broadcast_mul(&sin)?)?;
283 let y1 = (x0.broadcast_mul(&sin)? + x1.broadcast_mul(&cos)?)?;
284 let rope = Tensor::cat(&[y0, y1], D::Minus1)?;
285 let rope = rope.flatten_from(D::Minus2)?;
286 Ok(rope)
287}
288
/// Custom op applying rotary embeddings in the "rotate half" layout: within each row of
/// length `d`, element `i` is paired with element `i + d / 2` and rotated by
/// `(cos[i], sin[i])`.
///
/// The source is read as `(b, h, t, d)`. `cos`/`sin` hold either a single table of
/// `t * d / 2` values shared by every batch entry or, when both are 3D, one such table
/// per batch entry.
#[derive(Debug, Clone)]
struct RotaryEmb;

impl candle::CustomOp3 for RotaryEmb {
    fn name(&self) -> &'static str {
        "rotary-emb"
    }

    /// CPU forward pass: `s1`/`l1` is the source, `s2`/`l2` is cos, `s3`/`l3` is sin.
    fn cpu_fwd(
        &self,
        s1: &CpuStorage,
        l1: &Layout,
        s2: &CpuStorage,
        l2: &Layout,
        s3: &CpuStorage,
        l3: &Layout,
    ) -> Result<(CpuStorage, Shape)> {
        // Dtype-generic kernel; the `match` below only dispatches on the storage variant.
        fn inner<T: candle::WithDType + num_traits::Float>(
            src: &[T],
            l_src: &Layout,
            cos: &[T],
            l_cos: &Layout,
            sin: &[T],
            l_sin: &Layout,
        ) -> Result<(CpuStorage, Shape)> {
            // Narrow each buffer to the span covered by its contiguous layout.
            let src = match l_src.contiguous_offsets() {
                None => candle::bail!("input src has to be contiguous"),
                Some((o1, o2)) => &src[o1..o2],
            };
            let cos = match l_cos.contiguous_offsets() {
                None => candle::bail!("input cos has to be contiguous"),
                Some((o1, o2)) => &cos[o1..o2],
            };
            let sin = match l_sin.contiguous_offsets() {
                None => candle::bail!("input sin has to be contiguous"),
                Some((o1, o2)) => &sin[o1..o2],
            };
            let (b, h, t, d) = l_src.shape().dims4()?;
            // NOTE: despite its name, this flag is true when cos/sin are 3D, i.e. when
            // they carry a separate table per batch entry.
            let unbatched_rope = l_cos.dims().len() == 3 && l_sin.dims().len() == 3;
            let el_count = b * h * t * d;
            let mut dst = vec![T::zero(); el_count];
            // One chunk per (batch, head) pair; `bh_i` enumerates them in row-major order.
            // NOTE(review): rayon's `par_chunks` panics on a zero chunk size, so `t * d == 0`
            // would panic here — confirm callers never pass an empty sequence/embedding.
            src.par_chunks(t * d)
                .zip(dst.par_chunks_mut(t * d))
                .enumerate()
                .for_each(|(bh_i, (src, dst))| {
                    for i_t in 0..t {
                        for i_d in 0..d / 2 {
                            // Pair element `i_d` of row `i_t` with element `i_d + d / 2`.
                            let i1 = i_t * d + i_d;
                            let i2 = i1 + d / 2;
                            // cos/sin are indexed as a (t, d / 2) table...
                            let i_cs = i_t * (d / 2) + i_d;
                            // ...offset by the batch entry for per-batch tables, which are
                            // assumed to hold exactly `t * d / 2` values per batch entry.
                            let i_cs = if unbatched_rope {
                                let b_i = bh_i / h;
                                i_cs + b_i * t * d / 2
                            } else {
                                i_cs
                            };
                            dst[i1] = src[i1] * cos[i_cs] - src[i2] * sin[i_cs];
                            dst[i2] = src[i1] * sin[i_cs] + src[i2] * cos[i_cs];
                        }
                    }
                });
            let storage = candle::WithDType::to_cpu_storage_owned(dst);
            Ok((storage, (b, h, t, d).into()))
        }

        use candle::backend::BackendStorage;
        use CpuStorage::{BF16, F16, F32, F64};
        // All three inputs must share the same floating-point dtype.
        match (s1, s2, s3) {
            (BF16(s1), BF16(s2), BF16(s3)) => inner(s1, l1, s2, l2, s3, l3),
            (F16(s1), F16(s2), F16(s3)) => inner(s1, l1, s2, l2, s3, l3),
            (F32(s1), F32(s2), F32(s3)) => inner(s1, l1, s2, l2, s3, l3),
            (F64(s1), F64(s2), F64(s3)) => inner(s1, l1, s2, l2, s3, l3),
            _ => candle::bail!(
                "unsupported dtype for rope {:?} {:?} {:?}",
                s1.dtype(),
                s2.dtype(),
                s3.dtype()
            ),
        }
    }

    /// CUDA forward pass; launches the `rope` kernel loaded from candle's `REDUCE` module.
    #[cfg(feature = "cuda")]
    fn cuda_fwd(
        &self,
        s1: &candle::CudaStorage,
        l1: &Layout,
        s2: &candle::CudaStorage,
        l2: &Layout,
        s3: &candle::CudaStorage,
        l3: &Layout,
    ) -> Result<(candle::CudaStorage, Shape)> {
        use candle::cuda_backend::cudarc::driver::{
            CudaSlice, DeviceRepr, LaunchConfig, PushKernelArg,
        };
        use candle::cuda_backend::{kernel_name, kernels, WrapErr};
        use candle::{CudaDevice, WithDType};

        fn inner<T: DeviceRepr + WithDType>(
            src: &CudaSlice<T>,
            l_src: &Layout,
            cos: &CudaSlice<T>,
            l_cos: &Layout,
            sin: &CudaSlice<T>,
            l_sin: &Layout,
            dev: &CudaDevice,
        ) -> Result<CudaSlice<T>> {
            // View each device buffer through its contiguous offsets.
            let src = match l_src.contiguous_offsets() {
                None => candle::bail!("src input has to be contiguous"),
                Some((o1, o2)) => src.slice(o1..o2),
            };
            let cos = match l_cos.contiguous_offsets() {
                None => candle::bail!("cos input has to be contiguous"),
                Some((o1, o2)) => cos.slice(o1..o2),
            };
            let sin = match l_sin.contiguous_offsets() {
                None => candle::bail!("sin input has to be contiguous"),
                Some((o1, o2)) => sin.slice(o1..o2),
            };
            let (b, h, t, d) = l_src.shape().dims4()?;
            // Source batch stride (in elements) when cos/sin are per-batch, 0 otherwise;
            // presumably the kernel reads 0 as "shared tables" — kernel source not shown.
            // NOTE(review): `as u32` silently truncates for tensors above u32::MAX elements.
            let stride_b = if l_cos.dims().len() == 3 && l_sin.dims().len() == 3 {
                (h * t * d) as u32
            } else {
                0u32
            };
            let el = b * h * t * d;
            // Launch sized for `el / 2` items, i.e. one per rotated pair.
            let cfg = LaunchConfig::for_num_elems((el / 2) as u32);
            let func = dev.get_or_load_func(&kernel_name::<T>("rope"), &kernels::REDUCE)?;
            // SAFETY: uninitialized allocation; presumably every element is written by the
            // kernel launch below — confirm against the kernel source.
            let dst = unsafe { dev.alloc::<T>(el)? };

            let mut builder = func.builder();
            builder.arg(&src);
            builder.arg(&cos);
            builder.arg(&sin);
            builder.arg(&dst);
            candle::builder_arg!(builder, (b * h) as u32, (t * d) as u32, d as u32, stride_b);
            // SAFETY: FFI kernel launch; arguments are pushed in the kernel's expected order.
            unsafe { builder.launch(cfg) }.w()?;

            Ok(dst)
        }

        use candle::backend::BackendStorage;
        use candle::cuda_backend::CudaStorageSlice::{BF16, F16, F32, F64};
        let dev = s1.device();
        let slice = match (&s1.slice, &s2.slice, &s3.slice) {
            (BF16(s1), BF16(s2), BF16(s3)) => BF16(inner(s1, l1, s2, l2, s3, l3, dev)?),
            (F16(s1), F16(s2), F16(s3)) => F16(inner(s1, l1, s2, l2, s3, l3, dev)?),
            (F32(s1), F32(s2), F32(s3)) => F32(inner(s1, l1, s2, l2, s3, l3, dev)?),
            (F64(s1), F64(s2), F64(s3)) => F64(inner(s1, l1, s2, l2, s3, l3, dev)?),
            _ => candle::bail!(
                "unsupported dtype for rope {:?} {:?} {:?}",
                s1.dtype(),
                s2.dtype(),
                s3.dtype()
            ),
        };
        let dst = candle::cuda_backend::CudaStorage {
            slice,
            device: dev.clone(),
        };
        Ok((dst, l1.shape().clone()))
    }

    /// Metal forward pass; dispatches `candle_metal_kernels::call_rope` (F32/F16/BF16 only).
    #[cfg(feature = "metal")]
    fn metal_fwd(
        &self,
        src: &candle::MetalStorage,
        l_src: &Layout,
        cos: &candle::MetalStorage,
        l_cos: &Layout,
        sin: &candle::MetalStorage,
        l_sin: &Layout,
    ) -> Result<(candle::MetalStorage, Shape)> {
        use candle::backend::BackendStorage;
        let device = src.device();
        let encoder = device.command_encoder()?;
        encoder.set_label("rope");
        let kernels = device.kernels();
        if cos.dtype() != src.dtype() || sin.dtype() != src.dtype() {
            candle::bail!(
                "dtype mismatch in rope {:?} {:?} {:?}",
                src.dtype(),
                cos.dtype(),
                sin.dtype()
            )
        }
        let name = match src.dtype() {
            candle::DType::F32 => "rope_f32",
            candle::DType::F16 => "rope_f16",
            candle::DType::BF16 => "rope_bf16",
            dtype => candle::bail!("rope is not implemented for {dtype:?}"),
        };
        let (b, h, t, d) = l_src.shape().dims4()?;
        // Source batch stride when cos/sin are per-batch (3D), 0 when they are shared.
        let stride_b = if l_cos.dims().len() == 3 && l_sin.dims().len() == 3 {
            h * t * d
        } else {
            0usize
        };
        let el = b * h * t * d;
        let output = device.new_buffer(el, src.dtype(), "rope")?;
        // NOTE: unlike the CPU/CUDA paths, contiguity is not re-checked here; only each
        // layout's start offset (converted to bytes) is passed to the kernel.
        candle_metal_kernels::call_rope(
            device.metal_device(),
            &encoder,
            kernels,
            name,
            b * h,
            t * d,
            d,
            stride_b,
            src.buffer(),
            l_src.start_offset() * src.dtype().size_in_bytes(),
            cos.buffer(),
            l_cos.start_offset() * cos.dtype().size_in_bytes(),
            sin.buffer(),
            l_sin.start_offset() * sin.dtype().size_in_bytes(),
            &output,
        )
        .map_err(candle::Error::wrap)?;
        let out = candle::MetalStorage::new(output, device.clone(), el, src.dtype());
        Ok((out, l_src.shape().clone()))
    }
}
511
512pub fn rope(xs: &Tensor, cos: &Tensor, sin: &Tensor) -> Result<Tensor> {
513 let (b_sz, _n_head, seq_len, n_embd) = xs.dims4()?;
514 let (cos_seq_len, cos_n_embd) = rope_check_cs(cos, b_sz)?;
515 let (sin_seq_len, sin_n_embd) = rope_check_cs(sin, b_sz)?;
516 if cos_n_embd * 2 != n_embd
517 || sin_n_embd * 2 != n_embd
518 || seq_len > cos_seq_len
519 || seq_len > sin_seq_len
520 {
521 candle::bail!(
522 "inconsistent last dim size in rope {:?} {:?} {:?}",
523 xs.shape(),
524 cos.shape(),
525 sin.shape()
526 )
527 }
528 if !xs.is_contiguous() {
529 candle::bail!("xs has to be contiguous in rope")
530 }
531 if !cos.is_contiguous() {
532 candle::bail!("cos has to be contiguous in rope")
533 }
534 if !sin.is_contiguous() {
535 candle::bail!("sin has to be contiguous in rope")
536 }
537 xs.apply_op3_no_bwd(cos, sin, &RotaryEmb)
538}
539
540fn rotate_half(xs: &Tensor) -> Result<Tensor> {
541 let last_dim = xs.dim(D::Minus1)?;
542 let xs1 = xs.narrow(D::Minus1, 0, last_dim / 2)?;
543 let xs2 = xs.narrow(D::Minus1, last_dim / 2, last_dim - last_dim / 2)?;
544 Tensor::cat(&[&xs2.neg()?, &xs1], D::Minus1)
545}
546
547pub fn rope_slow(x: &Tensor, cos: &Tensor, sin: &Tensor) -> Result<Tensor> {
548 let (_b_sz, _h, seq_len, _n_embd) = x.dims4()?;
549 let cos = Tensor::cat(&[cos, cos], D::Minus1)?;
550 let sin = Tensor::cat(&[sin, sin], D::Minus1)?;
551 let cos = cos.narrow(0, 0, seq_len)?;
552 let sin = sin.narrow(0, 0, seq_len)?;
553 let cos = cos.unsqueeze(0)?.unsqueeze(0)?;
554 let sin = sin.unsqueeze(0)?.unsqueeze(0)?;
555 x.broadcast_mul(&cos)? + rotate_half(x)?.broadcast_mul(&sin)?
556}
557
558#[derive(Debug, Clone)]
560struct RotaryEmbThd;
561
562impl candle::CustomOp3 for RotaryEmbThd {
563 fn name(&self) -> &'static str {
564 "rotary-emb"
565 }
566
567 fn cpu_fwd(
568 &self,
569 s1: &CpuStorage,
570 l1: &Layout,
571 s2: &CpuStorage,
572 l2: &Layout,
573 s3: &CpuStorage,
574 l3: &Layout,
575 ) -> Result<(CpuStorage, Shape)> {
576 fn inner<T: candle::WithDType + num_traits::Float>(
577 src: &[T],
578 l_src: &Layout,
579 cos: &[T],
580 l_cos: &Layout,
581 sin: &[T],
582 l_sin: &Layout,
583 ) -> Result<(CpuStorage, Shape)> {
584 let src = match l_src.contiguous_offsets() {
585 None => candle::bail!("input src has to be contiguous"),
586 Some((o1, o2)) => &src[o1..o2],
587 };
588 let cos = match l_cos.contiguous_offsets() {
589 None => candle::bail!("input cos has to be contiguous"),
590 Some((o1, o2)) => &cos[o1..o2],
591 };
592 let sin = match l_sin.contiguous_offsets() {
593 None => candle::bail!("input sin has to be contiguous"),
594 Some((o1, o2)) => &sin[o1..o2],
595 };
596 let (b, t, h, d) = l_src.shape().dims4()?;
597 let unbatched_rope = l_cos.dims().len() == 3 && l_sin.dims().len() == 3;
598 let el_count = b * h * t * d;
599 let mut dst = vec![T::zero(); el_count];
600 src.par_chunks(t * h * d)
601 .zip(dst.par_chunks_mut(t * h * d))
602 .enumerate()
603 .for_each(|(b_i, (src, dst))| {
604 for i_t in 0..t {
605 for i_d in 0..d / 2 {
606 let i_cs = i_t * (d / 2) + i_d;
607 let i_cs = if unbatched_rope {
608 i_cs + b_i * t * d / 2
609 } else {
610 i_cs
611 };
612 for i_h in 0..h {
613 let i1 = i_t * h * d + i_h * d + i_d;
614 let i2 = i1 + d / 2;
615 dst[i1] = src[i1] * cos[i_cs] - src[i2] * sin[i_cs];
616 dst[i2] = src[i1] * sin[i_cs] + src[i2] * cos[i_cs];
617 }
618 }
619 }
620 });
621 let storage = candle::WithDType::to_cpu_storage_owned(dst);
622 Ok((storage, (b, t, h, d).into()))
623 }
624
625 use candle::backend::BackendStorage;
626 use CpuStorage::{BF16, F16, F32, F64};
627 match (s1, s2, s3) {
628 (BF16(s1), BF16(s2), BF16(s3)) => inner(s1, l1, s2, l2, s3, l3),
629 (F16(s1), F16(s2), F16(s3)) => inner(s1, l1, s2, l2, s3, l3),
630 (F32(s1), F32(s2), F32(s3)) => inner(s1, l1, s2, l2, s3, l3),
631 (F64(s1), F64(s2), F64(s3)) => inner(s1, l1, s2, l2, s3, l3),
632 _ => candle::bail!(
633 "unsupported dtype for rope {:?} {:?} {:?}",
634 s1.dtype(),
635 s2.dtype(),
636 s3.dtype()
637 ),
638 }
639 }
640
641 #[cfg(feature = "cuda")]
642 fn cuda_fwd(
643 &self,
644 s1: &candle::CudaStorage,
645 l1: &Layout,
646 s2: &candle::CudaStorage,
647 l2: &Layout,
648 s3: &candle::CudaStorage,
649 l3: &Layout,
650 ) -> Result<(candle::CudaStorage, Shape)> {
651 use candle::cuda_backend::cudarc::driver::{
652 CudaSlice, DeviceRepr, LaunchConfig, PushKernelArg,
653 };
654 use candle::cuda_backend::{kernel_name, kernels, WrapErr};
655 use candle::{CudaDevice, WithDType};
656
657 fn inner<T: DeviceRepr + WithDType>(
658 src: &CudaSlice<T>,
659 l_src: &Layout,
660 cos: &CudaSlice<T>,
661 l_cos: &Layout,
662 sin: &CudaSlice<T>,
663 l_sin: &Layout,
664 dev: &CudaDevice,
665 ) -> Result<CudaSlice<T>> {
666 let src = match l_src.contiguous_offsets() {
667 None => candle::bail!("src input has to be contiguous"),
668 Some((o1, o2)) => src.slice(o1..o2),
669 };
670 let cos = match l_cos.contiguous_offsets() {
671 None => candle::bail!("cos input has to be contiguous"),
672 Some((o1, o2)) => cos.slice(o1..o2),
673 };
674 let sin = match l_sin.contiguous_offsets() {
675 None => candle::bail!("sin input has to be contiguous"),
676 Some((o1, o2)) => sin.slice(o1..o2),
677 };
678 let (b, t, h, d) = l_src.shape().dims4()?;
679 let stride_b = if l_cos.dims().len() == 3 && l_sin.dims().len() == 3 {
680 (h * t * d) as u32
681 } else {
682 0u32
683 };
684 let el = b * h * t * d;
685 let cfg = LaunchConfig::for_num_elems((el / 2) as u32);
686 let func = dev.get_or_load_func(&kernel_name::<T>("rope_thd"), &kernels::REDUCE)?;
687 let dst = unsafe { dev.alloc::<T>(el)? };
689 let mut builder = func.builder();
690 builder.arg(&src);
691 builder.arg(&cos);
692 builder.arg(&sin);
693 builder.arg(&dst);
694 candle::builder_arg!(builder, b as u32, t as u32, h as u32, d as u32, stride_b);
695 unsafe { builder.launch(cfg) }.w()?;
697 Ok(dst)
698 }
699
700 use candle::backend::BackendStorage;
701 use candle::cuda_backend::CudaStorageSlice::{BF16, F16, F32, F64};
702 let dev = s1.device();
703 let slice = match (&s1.slice, &s2.slice, &s3.slice) {
704 (BF16(s1), BF16(s2), BF16(s3)) => BF16(inner(s1, l1, s2, l2, s3, l3, dev)?),
705 (F16(s1), F16(s2), F16(s3)) => F16(inner(s1, l1, s2, l2, s3, l3, dev)?),
706 (F32(s1), F32(s2), F32(s3)) => F32(inner(s1, l1, s2, l2, s3, l3, dev)?),
707 (F64(s1), F64(s2), F64(s3)) => F64(inner(s1, l1, s2, l2, s3, l3, dev)?),
708 _ => candle::bail!(
709 "unsupported dtype for rope {:?} {:?} {:?}",
710 s1.dtype(),
711 s2.dtype(),
712 s3.dtype()
713 ),
714 };
715 let dst = candle::cuda_backend::CudaStorage {
716 slice,
717 device: dev.clone(),
718 };
719 Ok((dst, l1.shape().clone()))
720 }
721
722 #[cfg(feature = "metal")]
723 fn metal_fwd(
724 &self,
725 src: &candle::MetalStorage,
726 l_src: &Layout,
727 cos: &candle::MetalStorage,
728 l_cos: &Layout,
729 sin: &candle::MetalStorage,
730 l_sin: &Layout,
731 ) -> Result<(candle::MetalStorage, Shape)> {
732 use candle::backend::BackendStorage;
733 let device = src.device();
734 let encoder = device.command_encoder()?;
735 encoder.set_label("rope_thd");
736 let kernels = device.kernels();
737 if cos.dtype() != src.dtype() || sin.dtype() != src.dtype() {
738 candle::bail!(
739 "dtype mismatch in rope {:?} {:?} {:?}",
740 src.dtype(),
741 cos.dtype(),
742 sin.dtype()
743 )
744 }
745 let name = match src.dtype() {
746 candle::DType::F32 => "rope_thd_f32",
747 candle::DType::F16 => "rope_thd_f16",
748 candle::DType::BF16 => "rope_thd_bf16",
749 dtype => candle::bail!("rope_thd is not implemented for {dtype:?}"),
750 };
751 let (b, t, h, d) = l_src.shape().dims4()?;
752 let stride_b = if l_cos.dims().len() == 3 && l_sin.dims().len() == 3 {
753 h * t * d
754 } else {
755 0usize
756 };
757 let el = b * h * t * d;
758 let output = device.new_buffer(el, src.dtype(), "rope_thd")?;
759 candle_metal_kernels::call_rope_thd(
760 device.metal_device(),
761 &encoder,
762 kernels,
763 name,
764 b,
765 t,
766 h,
767 d,
768 stride_b,
769 src.buffer(),
770 l_src.start_offset() * src.dtype().size_in_bytes(),
771 cos.buffer(),
772 l_cos.start_offset() * cos.dtype().size_in_bytes(),
773 sin.buffer(),
774 l_sin.start_offset() * sin.dtype().size_in_bytes(),
775 &output,
776 )
777 .map_err(candle::Error::wrap)?;
778 let out = candle::MetalStorage::new(output, device.clone(), el, src.dtype());
779 Ok((out, l_src.shape().clone()))
780 }
781}
782
783pub fn rope_thd(xs: &Tensor, cos: &Tensor, sin: &Tensor) -> Result<Tensor> {
784 let (b_sz, seq_len, _n_head, n_embd) = xs.dims4()?;
785 let (cos_seq_len, cos_n_embd) = rope_check_cs(cos, b_sz)?;
786 let (sin_seq_len, sin_n_embd) = rope_check_cs(sin, b_sz)?;
787 if cos_n_embd * 2 != n_embd
788 || sin_n_embd * 2 != n_embd
789 || seq_len > cos_seq_len
790 || seq_len > sin_seq_len
791 {
792 candle::bail!(
793 "inconsistent last dim size in rope {:?} {:?} {:?}",
794 xs.shape(),
795 cos.shape(),
796 sin.shape()
797 )
798 }
799 if !xs.is_contiguous() {
800 candle::bail!("xs has to be contiguous in rope")
801 }
802 if !cos.is_contiguous() {
803 candle::bail!("cos has to be contiguous in rope")
804 }
805 if !sin.is_contiguous() {
806 candle::bail!("sin has to be contiguous in rope")
807 }
808 xs.apply_op3_no_bwd(cos, sin, &RotaryEmbThd)
809}