1use crate::core::{CpuStorage, Layout, Result, Shape, Tensor, D};
4use rayon::prelude::*;
5
6#[derive(Debug, Clone)]
12struct RotaryEmbI;
13
14impl crate::core::CustomOp3 for RotaryEmbI {
15 fn name(&self) -> &'static str {
16 "rotary-emb-int"
17 }
18
19 fn cpu_fwd(
20 &self,
21 s1: &CpuStorage,
22 l1: &Layout,
23 s2: &CpuStorage,
24 l2: &Layout,
25 s3: &CpuStorage,
26 l3: &Layout,
27 ) -> Result<(CpuStorage, Shape)> {
28 fn inner<T: crate::core::WithDType + num_traits::Float>(
29 src: &[T],
30 l_src: &Layout,
31 cos: &[T],
32 l_cos: &Layout,
33 sin: &[T],
34 l_sin: &Layout,
35 ) -> Result<(CpuStorage, Shape)> {
36 let src = match l_src.contiguous_offsets() {
37 None => crate::bail!("input src has to be contiguous"),
38 Some((o1, o2)) => &src[o1..o2],
39 };
40 let cos = match l_cos.contiguous_offsets() {
41 None => crate::bail!("input cos has to be contiguous"),
42 Some((o1, o2)) => &cos[o1..o2],
43 };
44 let sin = match l_sin.contiguous_offsets() {
45 None => crate::bail!("input sin has to be contiguous"),
46 Some((o1, o2)) => &sin[o1..o2],
47 };
48 let (b, h, t, d) = l_src.shape().dims4()?;
49 let el_count = b * h * t * d;
50 let mut dst = vec![T::zero(); el_count];
51 src.par_chunks(t * d)
52 .zip(dst.par_chunks_mut(t * d))
53 .for_each(|(src, dst)| {
54 for i_over_2 in 0..t * d / 2 {
55 let i = 2 * i_over_2;
56 dst[i] = src[i] * cos[i_over_2] - src[i + 1] * sin[i_over_2];
57 dst[i + 1] = src[i] * sin[i_over_2] + src[i + 1] * cos[i_over_2];
58 }
59 });
60 let storage = crate::core::WithDType::to_cpu_storage_owned(dst);
61 Ok((storage, (b, h, t, d).into()))
62 }
63
64 use crate::core::backend::BackendStorage;
65 use CpuStorage::{BF16, F16, F32, F64};
66 match (s1, s2, s3) {
67 (BF16(s1), BF16(s2), BF16(s3)) => inner(s1, l1, s2, l2, s3, l3),
68 (F16(s1), F16(s2), F16(s3)) => inner(s1, l1, s2, l2, s3, l3),
69 (F32(s1), F32(s2), F32(s3)) => inner(s1, l1, s2, l2, s3, l3),
70 (F64(s1), F64(s2), F64(s3)) => inner(s1, l1, s2, l2, s3, l3),
71 _ => crate::bail!(
72 "unsupported dtype for rope {:?} {:?} {:?}",
73 s1.dtype(),
74 s2.dtype(),
75 s3.dtype()
76 ),
77 }
78 }
79
80 #[cfg(feature = "cuda")]
81 fn cuda_fwd(
82 &self,
83 s1: &crate::core::CudaStorage,
84 l1: &Layout,
85 s2: &crate::core::CudaStorage,
86 l2: &Layout,
87 s3: &crate::core::CudaStorage,
88 l3: &Layout,
89 ) -> Result<(crate::core::CudaStorage, Shape)> {
90 use crate::core::cuda_backend::cudarc::driver::{
91 CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig,
92 };
93 use crate::core::cuda_backend::{kernel_name, kernels, WrapErr};
94 use crate::core::{CudaDevice, WithDType};
95
96 fn inner<T: DeviceRepr + WithDType>(
97 src: &CudaSlice<T>,
98 l_src: &Layout,
99 cos: &CudaSlice<T>,
100 l_cos: &Layout,
101 sin: &CudaSlice<T>,
102 l_sin: &Layout,
103 dev: &CudaDevice,
104 ) -> Result<CudaSlice<T>> {
105 let src = match l_src.contiguous_offsets() {
106 None => crate::bail!("src input has to be contiguous"),
107 Some((o1, o2)) => src.slice(o1..o2),
108 };
109 let cos = match l_cos.contiguous_offsets() {
110 None => crate::bail!("cos input has to be contiguous"),
111 Some((o1, o2)) => cos.slice(o1..o2),
112 };
113 let sin = match l_sin.contiguous_offsets() {
114 None => crate::bail!("sin input has to be contiguous"),
115 Some((o1, o2)) => sin.slice(o1..o2),
116 };
117 let (b, h, t, d) = l_src.shape().dims4()?;
118 let el = b * h * t * d;
119 let cfg = LaunchConfig::for_num_elems((el / 2) as u32);
120 let func = dev.get_or_load_func(&kernel_name::<T>("rope_i"), kernels::REDUCE)?;
121 let dst = unsafe { dev.alloc::<T>(el) }.w()?;
123 let params = (&src, &cos, &sin, &dst, (b * h) as u32, (t * d) as u32);
124 unsafe { func.launch(cfg, params) }.w()?;
126 Ok(dst)
127 }
128
129 use crate::core::backend::BackendStorage;
130 use crate::core::cuda_backend::CudaStorageSlice::{BF16, F16, F32, F64};
131 let dev = s1.device();
132 let slice = match (&s1.slice, &s2.slice, &s3.slice) {
133 (BF16(s1), BF16(s2), BF16(s3)) => BF16(inner(s1, l1, s2, l2, s3, l3, dev)?),
134 (F16(s1), F16(s2), F16(s3)) => F16(inner(s1, l1, s2, l2, s3, l3, dev)?),
135 (F32(s1), F32(s2), F32(s3)) => F32(inner(s1, l1, s2, l2, s3, l3, dev)?),
136 (F64(s1), F64(s2), F64(s3)) => F64(inner(s1, l1, s2, l2, s3, l3, dev)?),
137 _ => crate::bail!(
138 "unsupported dtype for rope {:?} {:?} {:?}",
139 s1.dtype(),
140 s2.dtype(),
141 s3.dtype()
142 ),
143 };
144 let dst = crate::core::cuda_backend::CudaStorage {
145 slice,
146 device: dev.clone(),
147 };
148 Ok((dst, l1.shape().clone()))
149 }
150
151 #[cfg(feature = "metal")]
152 fn metal_fwd(
153 &self,
154 src: &crate::core::MetalStorage,
155 l_src: &Layout,
156 cos: &crate::core::MetalStorage,
157 l_cos: &Layout,
158 sin: &crate::core::MetalStorage,
159 l_sin: &Layout,
160 ) -> Result<(crate::core::MetalStorage, Shape)> {
161 use crate::core::backend::BackendStorage;
162 let device = src.device();
163 let command_buffer = device.command_buffer()?;
164 let kernels = device.kernels();
165 if cos.dtype() != src.dtype() || sin.dtype() != src.dtype() {
166 crate::bail!(
167 "dtype mismatch in rope-i {:?} {:?} {:?}",
168 src.dtype(),
169 cos.dtype(),
170 sin.dtype()
171 )
172 }
173 let name = match src.dtype() {
174 crate::core::DType::F32 => "rope_i_f32",
175 crate::core::DType::F16 => "rope_i_f16",
176 crate::core::DType::BF16 => "rope_i_bf16",
177 dtype => crate::bail!("rope-i is not implemented for {dtype:?}"),
178 };
179 let (b, h, t, d) = l_src.shape().dims4()?;
180 let el = b * h * t * d;
181 let output = device.new_buffer(el, src.dtype(), "rope-i")?;
182 crate::metal_kernels::call_rope_i(
183 device.metal_device(),
184 &command_buffer,
185 kernels,
186 name,
187 b * h,
188 t * d,
189 src.buffer(),
190 l_src.start_offset() * src.dtype().size_in_bytes(),
191 cos.buffer(),
192 l_cos.start_offset() * cos.dtype().size_in_bytes(),
193 sin.buffer(),
194 l_sin.start_offset() * sin.dtype().size_in_bytes(),
195 &output,
196 )
197 .map_err(crate::core::Error::wrap)?;
198 let out = crate::core::MetalStorage::new(output, device.clone(), el, src.dtype());
199 Ok((out, l_src.shape().clone()))
200 }
201}
202
203pub fn rope_i(xs: &Tensor, cos: &Tensor, sin: &Tensor) -> Result<Tensor> {
204 let (_b_sz, _n_head, seq_len, n_embd) = xs.dims4()?;
205 let (cos_seq_len, cos_n_embd) = cos.dims2()?;
206 let (sin_seq_len, sin_n_embd) = cos.dims2()?;
207 if cos_n_embd * 2 != n_embd
208 || sin_n_embd * 2 != n_embd
209 || seq_len > cos_seq_len
210 || seq_len > sin_seq_len
211 {
212 crate::bail!(
213 "inconsistent last dim size in rope {:?} {:?} {:?}",
214 xs.shape(),
215 cos.shape(),
216 sin.shape()
217 )
218 }
219 if !xs.is_contiguous() {
220 crate::bail!("xs has to be contiguous in rope")
221 }
222 if !cos.is_contiguous() {
223 crate::bail!("cos has to be contiguous in rope")
224 }
225 if !sin.is_contiguous() {
226 crate::bail!("sin has to be contiguous in rope")
227 }
228 xs.apply_op3_no_bwd(cos, sin, &RotaryEmbI)
229}
230
231pub fn rope_i_slow(x: &Tensor, cos: &Tensor, sin: &Tensor) -> Result<Tensor> {
232 let (b_sz, n_head, seq_len, n_embd) = x.dims4()?;
233 let cos = cos
234 .narrow(0, 0, seq_len)?
235 .reshape((seq_len, n_embd / 2, 1))?;
236 let sin = sin
237 .narrow(0, 0, seq_len)?
238 .reshape((seq_len, n_embd / 2, 1))?;
239 let cos = cos.broadcast_as((b_sz, 1, seq_len, n_embd / 2, 1))?;
240 let sin = sin.broadcast_as((b_sz, 1, seq_len, n_embd / 2, 1))?;
241 let x = x.reshape((b_sz, n_head, seq_len, n_embd / 2, 2))?;
242 let x0 = x.narrow(D::Minus1, 0, 1)?;
243 let x1 = x.narrow(D::Minus1, 1, 1)?;
244 let y0 = (x0.broadcast_mul(&cos)? - x1.broadcast_mul(&sin)?)?;
245 let y1 = (x0.broadcast_mul(&sin)? + x1.broadcast_mul(&cos)?)?;
246 let rope = Tensor::cat(&[y0, y1], D::Minus1)?;
247 let rope = rope.flatten_from(D::Minus2)?;
248 Ok(rope)
249}
250
251#[derive(Debug, Clone)]
253struct RotaryEmb;
254
255impl crate::core::CustomOp3 for RotaryEmb {
256 fn name(&self) -> &'static str {
257 "rotary-emb"
258 }
259
260 fn cpu_fwd(
261 &self,
262 s1: &CpuStorage,
263 l1: &Layout,
264 s2: &CpuStorage,
265 l2: &Layout,
266 s3: &CpuStorage,
267 l3: &Layout,
268 ) -> Result<(CpuStorage, Shape)> {
269 fn inner<T: crate::core::WithDType + num_traits::Float>(
270 src: &[T],
271 l_src: &Layout,
272 cos: &[T],
273 l_cos: &Layout,
274 sin: &[T],
275 l_sin: &Layout,
276 ) -> Result<(CpuStorage, Shape)> {
277 let src = match l_src.contiguous_offsets() {
278 None => crate::bail!("input src has to be contiguous"),
279 Some((o1, o2)) => &src[o1..o2],
280 };
281 let cos = match l_cos.contiguous_offsets() {
282 None => crate::bail!("input cos has to be contiguous"),
283 Some((o1, o2)) => &cos[o1..o2],
284 };
285 let sin = match l_sin.contiguous_offsets() {
286 None => crate::bail!("input sin has to be contiguous"),
287 Some((o1, o2)) => &sin[o1..o2],
288 };
289 let (b, h, t, d) = l_src.shape().dims4()?;
290 let el_count = b * h * t * d;
291 let mut dst = vec![T::zero(); el_count];
292 src.par_chunks(t * d)
293 .zip(dst.par_chunks_mut(t * d))
294 .for_each(|(src, dst)| {
295 for i_t in 0..t {
296 for i_d in 0..d / 2 {
297 let i1 = i_t * d + i_d;
298 let i2 = i1 + d / 2;
299 let i_cs = i_t * (d / 2) + i_d;
300 dst[i1] = src[i1] * cos[i_cs] - src[i2] * sin[i_cs];
301 dst[i2] = src[i1] * sin[i_cs] + src[i2] * cos[i_cs];
302 }
303 }
304 });
305 let storage = crate::core::WithDType::to_cpu_storage_owned(dst);
306 Ok((storage, (b, h, t, d).into()))
307 }
308
309 use crate::core::backend::BackendStorage;
310 use CpuStorage::{BF16, F16, F32, F64};
311 match (s1, s2, s3) {
312 (BF16(s1), BF16(s2), BF16(s3)) => inner(s1, l1, s2, l2, s3, l3),
313 (F16(s1), F16(s2), F16(s3)) => inner(s1, l1, s2, l2, s3, l3),
314 (F32(s1), F32(s2), F32(s3)) => inner(s1, l1, s2, l2, s3, l3),
315 (F64(s1), F64(s2), F64(s3)) => inner(s1, l1, s2, l2, s3, l3),
316 _ => crate::bail!(
317 "unsupported dtype for rope {:?} {:?} {:?}",
318 s1.dtype(),
319 s2.dtype(),
320 s3.dtype()
321 ),
322 }
323 }
324
325 #[cfg(feature = "cuda")]
326 fn cuda_fwd(
327 &self,
328 s1: &crate::core::CudaStorage,
329 l1: &Layout,
330 s2: &crate::core::CudaStorage,
331 l2: &Layout,
332 s3: &crate::core::CudaStorage,
333 l3: &Layout,
334 ) -> Result<(crate::core::CudaStorage, Shape)> {
335 use crate::core::cuda_backend::cudarc::driver::{
336 CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig,
337 };
338 use crate::core::cuda_backend::{kernel_name, kernels, WrapErr};
339 use crate::core::{CudaDevice, WithDType};
340
341 fn inner<T: DeviceRepr + WithDType>(
342 src: &CudaSlice<T>,
343 l_src: &Layout,
344 cos: &CudaSlice<T>,
345 l_cos: &Layout,
346 sin: &CudaSlice<T>,
347 l_sin: &Layout,
348 dev: &CudaDevice,
349 ) -> Result<CudaSlice<T>> {
350 let src = match l_src.contiguous_offsets() {
351 None => crate::bail!("src input has to be contiguous"),
352 Some((o1, o2)) => src.slice(o1..o2),
353 };
354 let cos = match l_cos.contiguous_offsets() {
355 None => crate::bail!("cos input has to be contiguous"),
356 Some((o1, o2)) => cos.slice(o1..o2),
357 };
358 let sin = match l_sin.contiguous_offsets() {
359 None => crate::bail!("sin input has to be contiguous"),
360 Some((o1, o2)) => sin.slice(o1..o2),
361 };
362 let (b, h, t, d) = l_src.shape().dims4()?;
363 let el = b * h * t * d;
364 let cfg = LaunchConfig::for_num_elems((el / 2) as u32);
365 let func = dev.get_or_load_func(&kernel_name::<T>("rope"), kernels::REDUCE)?;
366 let dst = unsafe { dev.alloc::<T>(el) }.w()?;
368 let params = (
369 &src,
370 &cos,
371 &sin,
372 &dst,
373 (b * h) as u32,
374 (t * d) as u32,
375 d as u32,
376 );
377 unsafe { func.launch(cfg, params) }.w()?;
379 Ok(dst)
380 }
381
382 use crate::core::backend::BackendStorage;
383 use crate::core::cuda_backend::CudaStorageSlice::{BF16, F16, F32, F64};
384 let dev = s1.device();
385 let slice = match (&s1.slice, &s2.slice, &s3.slice) {
386 (BF16(s1), BF16(s2), BF16(s3)) => BF16(inner(s1, l1, s2, l2, s3, l3, dev)?),
387 (F16(s1), F16(s2), F16(s3)) => F16(inner(s1, l1, s2, l2, s3, l3, dev)?),
388 (F32(s1), F32(s2), F32(s3)) => F32(inner(s1, l1, s2, l2, s3, l3, dev)?),
389 (F64(s1), F64(s2), F64(s3)) => F64(inner(s1, l1, s2, l2, s3, l3, dev)?),
390 _ => crate::bail!(
391 "unsupported dtype for rope {:?} {:?} {:?}",
392 s1.dtype(),
393 s2.dtype(),
394 s3.dtype()
395 ),
396 };
397 let dst = crate::core::cuda_backend::CudaStorage {
398 slice,
399 device: dev.clone(),
400 };
401 Ok((dst, l1.shape().clone()))
402 }
403
404 #[cfg(feature = "metal")]
405 fn metal_fwd(
406 &self,
407 src: &crate::core::MetalStorage,
408 l_src: &Layout,
409 cos: &crate::core::MetalStorage,
410 l_cos: &Layout,
411 sin: &crate::core::MetalStorage,
412 l_sin: &Layout,
413 ) -> Result<(crate::core::MetalStorage, Shape)> {
414 use crate::core::backend::BackendStorage;
415 let device = src.device();
416 let command_buffer = device.command_buffer()?;
417 let kernels = device.kernels();
418 if cos.dtype() != src.dtype() || sin.dtype() != src.dtype() {
419 crate::bail!(
420 "dtype mismatch in rope {:?} {:?} {:?}",
421 src.dtype(),
422 cos.dtype(),
423 sin.dtype()
424 )
425 }
426 let name = match src.dtype() {
427 crate::core::DType::F32 => "rope_f32",
428 crate::core::DType::F16 => "rope_f16",
429 crate::core::DType::BF16 => "rope_bf16",
430 dtype => crate::bail!("rope is not implemented for {dtype:?}"),
431 };
432 let (b, h, t, d) = l_src.shape().dims4()?;
433 let el = b * h * t * d;
434 let output = device.new_buffer(el, src.dtype(), "rope-i")?;
435 crate::metal_kernels::call_rope(
436 device.metal_device(),
437 &command_buffer,
438 kernels,
439 name,
440 b * h,
441 t * d,
442 d,
443 src.buffer(),
444 l_src.start_offset() * src.dtype().size_in_bytes(),
445 cos.buffer(),
446 l_cos.start_offset() * cos.dtype().size_in_bytes(),
447 sin.buffer(),
448 l_sin.start_offset() * sin.dtype().size_in_bytes(),
449 &output,
450 )
451 .map_err(crate::core::Error::wrap)?;
452 let out = crate::core::MetalStorage::new(output, device.clone(), el, src.dtype());
453 Ok((out, l_src.shape().clone()))
454 }
455}
456
457pub fn rope(xs: &Tensor, cos: &Tensor, sin: &Tensor) -> Result<Tensor> {
458 let (_b_sz, _n_head, seq_len, n_embd) = xs.dims4()?;
459 let (cos_seq_len, cos_n_embd) = cos.dims2()?;
460 let (sin_seq_len, sin_n_embd) = sin.dims2()?;
461 if cos_n_embd * 2 != n_embd
462 || sin_n_embd * 2 != n_embd
463 || seq_len > cos_seq_len
464 || seq_len > sin_seq_len
465 {
466 crate::bail!(
467 "inconsistent last dim size in rope {:?} {:?} {:?}",
468 xs.shape(),
469 cos.shape(),
470 sin.shape()
471 )
472 }
473 if !xs.is_contiguous() {
474 crate::bail!("xs has to be contiguous in rope")
475 }
476 if !cos.is_contiguous() {
477 crate::bail!("cos has to be contiguous in rope")
478 }
479 if !sin.is_contiguous() {
480 crate::bail!("sin has to be contiguous in rope")
481 }
482 xs.apply_op3_no_bwd(cos, sin, &RotaryEmb)
483}
484
485fn rotate_half(xs: &Tensor) -> Result<Tensor> {
486 let last_dim = xs.dim(D::Minus1)?;
487 let xs1 = xs.narrow(D::Minus1, 0, last_dim / 2)?;
488 let xs2 = xs.narrow(D::Minus1, last_dim / 2, last_dim - last_dim / 2)?;
489 Tensor::cat(&[&xs2.neg()?, &xs1], D::Minus1)
490}
491
492pub fn rope_slow(x: &Tensor, cos: &Tensor, sin: &Tensor) -> Result<Tensor> {
493 let (_b_sz, _h, seq_len, _n_embd) = x.dims4()?;
494 let cos = Tensor::cat(&[cos, cos], D::Minus1)?;
495 let sin = Tensor::cat(&[sin, sin], D::Minus1)?;
496 let cos = cos.narrow(0, 0, seq_len)?;
497 let sin = sin.narrow(0, 0, seq_len)?;
498 let cos = cos.unsqueeze(0)?.unsqueeze(0)?;
499 let sin = sin.unsqueeze(0)?.unsqueeze(0)?;
500 x.broadcast_mul(&cos)? + rotate_half(x)?.broadcast_mul(&sin)?
501}
502
503#[derive(Debug, Clone)]
505struct RotaryEmbThd;
506
507impl crate::core::CustomOp3 for RotaryEmbThd {
508 fn name(&self) -> &'static str {
509 "rotary-emb"
510 }
511
512 fn cpu_fwd(
513 &self,
514 s1: &CpuStorage,
515 l1: &Layout,
516 s2: &CpuStorage,
517 l2: &Layout,
518 s3: &CpuStorage,
519 l3: &Layout,
520 ) -> Result<(CpuStorage, Shape)> {
521 fn inner<T: crate::core::WithDType + num_traits::Float>(
522 src: &[T],
523 l_src: &Layout,
524 cos: &[T],
525 l_cos: &Layout,
526 sin: &[T],
527 l_sin: &Layout,
528 ) -> Result<(CpuStorage, Shape)> {
529 let src = match l_src.contiguous_offsets() {
530 None => crate::bail!("input src has to be contiguous"),
531 Some((o1, o2)) => &src[o1..o2],
532 };
533 let cos = match l_cos.contiguous_offsets() {
534 None => crate::bail!("input cos has to be contiguous"),
535 Some((o1, o2)) => &cos[o1..o2],
536 };
537 let sin = match l_sin.contiguous_offsets() {
538 None => crate::bail!("input sin has to be contiguous"),
539 Some((o1, o2)) => &sin[o1..o2],
540 };
541 let (b, t, h, d) = l_src.shape().dims4()?;
542 let el_count = b * h * t * d;
543 let mut dst = vec![T::zero(); el_count];
544 src.par_chunks(t * h * d)
545 .zip(dst.par_chunks_mut(t * h * d))
546 .for_each(|(src, dst)| {
547 for i_t in 0..t {
548 for i_d in 0..d / 2 {
549 let i_cs = i_t * (d / 2) + i_d;
550 for i_h in 0..h {
551 let i1 = i_t * h * d + i_h * d + i_d;
552 let i2 = i1 + d / 2;
553 dst[i1] = src[i1] * cos[i_cs] - src[i2] * sin[i_cs];
554 dst[i2] = src[i1] * sin[i_cs] + src[i2] * cos[i_cs];
555 }
556 }
557 }
558 });
559 let storage = crate::core::WithDType::to_cpu_storage_owned(dst);
560 Ok((storage, (b, t, h, d).into()))
561 }
562
563 use crate::core::backend::BackendStorage;
564 use CpuStorage::{BF16, F16, F32, F64};
565 match (s1, s2, s3) {
566 (BF16(s1), BF16(s2), BF16(s3)) => inner(s1, l1, s2, l2, s3, l3),
567 (F16(s1), F16(s2), F16(s3)) => inner(s1, l1, s2, l2, s3, l3),
568 (F32(s1), F32(s2), F32(s3)) => inner(s1, l1, s2, l2, s3, l3),
569 (F64(s1), F64(s2), F64(s3)) => inner(s1, l1, s2, l2, s3, l3),
570 _ => crate::bail!(
571 "unsupported dtype for rope {:?} {:?} {:?}",
572 s1.dtype(),
573 s2.dtype(),
574 s3.dtype()
575 ),
576 }
577 }
578
579 #[cfg(feature = "cuda")]
580 fn cuda_fwd(
581 &self,
582 s1: &crate::core::CudaStorage,
583 l1: &Layout,
584 s2: &crate::core::CudaStorage,
585 l2: &Layout,
586 s3: &crate::core::CudaStorage,
587 l3: &Layout,
588 ) -> Result<(crate::core::CudaStorage, Shape)> {
589 use crate::core::cuda_backend::cudarc::driver::{
590 CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig,
591 };
592 use crate::core::cuda_backend::{kernel_name, kernels, WrapErr};
593 use crate::core::{CudaDevice, WithDType};
594
595 fn inner<T: DeviceRepr + WithDType>(
596 src: &CudaSlice<T>,
597 l_src: &Layout,
598 cos: &CudaSlice<T>,
599 l_cos: &Layout,
600 sin: &CudaSlice<T>,
601 l_sin: &Layout,
602 dev: &CudaDevice,
603 ) -> Result<CudaSlice<T>> {
604 let src = match l_src.contiguous_offsets() {
605 None => crate::bail!("src input has to be contiguous"),
606 Some((o1, o2)) => src.slice(o1..o2),
607 };
608 let cos = match l_cos.contiguous_offsets() {
609 None => crate::bail!("cos input has to be contiguous"),
610 Some((o1, o2)) => cos.slice(o1..o2),
611 };
612 let sin = match l_sin.contiguous_offsets() {
613 None => crate::bail!("sin input has to be contiguous"),
614 Some((o1, o2)) => sin.slice(o1..o2),
615 };
616 let (b, t, h, d) = l_src.shape().dims4()?;
617 let el = b * h * t * d;
618 let cfg = LaunchConfig::for_num_elems((el / 2) as u32);
619 let func = dev.get_or_load_func(&kernel_name::<T>("rope_thd"), kernels::REDUCE)?;
620 let dst = unsafe { dev.alloc::<T>(el) }.w()?;
622 let params = (
623 &src, &cos, &sin, &dst, b as u32, t as u32, h as u32, d as u32,
624 );
625 unsafe { func.launch(cfg, params) }.w()?;
627 Ok(dst)
628 }
629
630 use crate::core::backend::BackendStorage;
631 use crate::core::cuda_backend::CudaStorageSlice::{BF16, F16, F32, F64};
632 let dev = s1.device();
633 let slice = match (&s1.slice, &s2.slice, &s3.slice) {
634 (BF16(s1), BF16(s2), BF16(s3)) => BF16(inner(s1, l1, s2, l2, s3, l3, dev)?),
635 (F16(s1), F16(s2), F16(s3)) => F16(inner(s1, l1, s2, l2, s3, l3, dev)?),
636 (F32(s1), F32(s2), F32(s3)) => F32(inner(s1, l1, s2, l2, s3, l3, dev)?),
637 (F64(s1), F64(s2), F64(s3)) => F64(inner(s1, l1, s2, l2, s3, l3, dev)?),
638 _ => crate::bail!(
639 "unsupported dtype for rope {:?} {:?} {:?}",
640 s1.dtype(),
641 s2.dtype(),
642 s3.dtype()
643 ),
644 };
645 let dst = crate::core::cuda_backend::CudaStorage {
646 slice,
647 device: dev.clone(),
648 };
649 Ok((dst, l1.shape().clone()))
650 }
651
652 #[cfg(feature = "metal")]
653 fn metal_fwd(
654 &self,
655 src: &crate::core::MetalStorage,
656 l_src: &Layout,
657 cos: &crate::core::MetalStorage,
658 l_cos: &Layout,
659 sin: &crate::core::MetalStorage,
660 l_sin: &Layout,
661 ) -> Result<(crate::core::MetalStorage, Shape)> {
662 use crate::core::backend::BackendStorage;
663 let device = src.device();
664 let command_buffer = device.command_buffer()?;
665 let kernels = device.kernels();
666 if cos.dtype() != src.dtype() || sin.dtype() != src.dtype() {
667 crate::bail!(
668 "dtype mismatch in rope {:?} {:?} {:?}",
669 src.dtype(),
670 cos.dtype(),
671 sin.dtype()
672 )
673 }
674 let name = match src.dtype() {
675 crate::core::DType::F32 => "rope_thd_f32",
676 crate::core::DType::F16 => "rope_thd_f16",
677 crate::core::DType::BF16 => "rope_thd_bf16",
678 dtype => crate::bail!("rope_thd is not implemented for {dtype:?}"),
679 };
680 let (b, t, h, d) = l_src.shape().dims4()?;
681 let el = b * h * t * d;
682 let output = device.new_buffer(el, src.dtype(), "rope-thd")?;
683 crate::metal_kernels::call_rope_thd(
684 device.metal_device(),
685 &command_buffer,
686 kernels,
687 name,
688 b,
689 t,
690 h,
691 d,
692 src.buffer(),
693 l_src.start_offset() * src.dtype().size_in_bytes(),
694 cos.buffer(),
695 l_cos.start_offset() * cos.dtype().size_in_bytes(),
696 sin.buffer(),
697 l_sin.start_offset() * sin.dtype().size_in_bytes(),
698 &output,
699 )
700 .map_err(crate::core::Error::wrap)?;
701 let out = crate::core::MetalStorage::new(output, device.clone(), el, src.dtype());
702 Ok((out, l_src.shape().clone()))
703 }
704}
705
706pub fn rope_thd(xs: &Tensor, cos: &Tensor, sin: &Tensor) -> Result<Tensor> {
707 let (_b_sz, seq_len, _n_head, n_embd) = xs.dims4()?;
708 let (cos_seq_len, cos_n_embd) = cos.dims2()?;
709 let (sin_seq_len, sin_n_embd) = sin.dims2()?;
710 if cos_n_embd * 2 != n_embd
711 || sin_n_embd * 2 != n_embd
712 || seq_len > cos_seq_len
713 || seq_len > sin_seq_len
714 {
715 crate::bail!(
716 "inconsistent last dim size in rope {:?} {:?} {:?}",
717 xs.shape(),
718 cos.shape(),
719 sin.shape()
720 )
721 }
722 if !xs.is_contiguous() {
723 crate::bail!("xs has to be contiguous in rope")
724 }
725 if !cos.is_contiguous() {
726 crate::bail!("cos has to be contiguous in rope")
727 }
728 if !sin.is_contiguous() {
729 crate::bail!("sin has to be contiguous in rope")
730 }
731 xs.apply_op3_no_bwd(cos, sin, &RotaryEmbThd)
732}