1use hanzo_ml::{
2 backend::BackendStorage, shape::Dim, CpuStorage, CustomOp1, CustomOp2, DType, Error, Layout,
3 Result, Shape, Tensor, WithDType,
4};
5use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator};
6use rayon::slice::ParallelSliceMut;
7
8use std::{
9 fmt::Display,
10 ops::{BitAnd, BitOr, BitXor, Not, Shl},
11};
12
13#[cfg(feature = "cuda")]
14use crate::utils::{ffi, slice_ptr};
15#[cfg(feature = "cuda")]
16use hanzo_ml::cuda::{cudarc::driver::DevicePtr, CudaStorage};
17#[cfg(feature = "cuda")]
18use std::ffi::c_void;
19
20#[cfg(feature = "metal")]
21use crate::metal_kernels::SortScratchCache; #[cfg(feature = "metal")]
23use std::sync::OnceLock;
24
/// Process-wide, lazily created `SortScratchCache` shared by the Metal `sort`
/// and `argsort` paths in this file (each calls `get_or_init` on it before
/// checking out scratch space). Only compiled when the `metal` feature is on.
#[cfg(feature = "metal")]
static SORT_SCRATCH_CACHE: OnceLock<SortScratchCache> = OnceLock::new();
27
28struct Leftshift(usize);
29
30impl Leftshift {
31 fn leftshift<T: WithDType + Shl<Output = T>>(&self, vs: &[T]) -> Vec<T> {
32 let offset = T::from_f64(self.0 as f64);
33 vs.into_par_iter().map(|v| *v << offset).collect()
34 }
35}
36
37impl CustomOp1 for Leftshift {
38 fn name(&self) -> &'static str {
39 "left"
40 }
41
42 fn cpu_fwd(&self, s1: &CpuStorage, l1: &Layout) -> Result<(CpuStorage, Shape)> {
43 match s1 {
44 CpuStorage::U8(vs1) => {
45 let vs1 = match l1.contiguous_offsets() {
46 Some((a, b)) => &vs1[a..b],
47 None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
48 };
49 let result = self.leftshift(vs1);
50 let result = CpuStorage::U8(result);
51 Ok((result, l1.shape().clone()))
52 }
53 CpuStorage::I16(vs1) => {
54 let vs1 = match l1.contiguous_offsets() {
55 Some((a, b)) => &vs1[a..b],
56 None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
57 };
58 let result = self.leftshift(vs1);
59 let result = CpuStorage::I16(result);
60 Ok((result, l1.shape().clone()))
61 }
62 CpuStorage::U32(vs1) => {
63 let vs1 = match l1.contiguous_offsets() {
64 Some((a, b)) => &vs1[a..b],
65 None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
66 };
67 let result = self.leftshift(vs1);
68 let result = CpuStorage::U32(result);
69 Ok((result, l1.shape().clone()))
70 }
71 CpuStorage::I64(vs1) => {
72 let vs1 = match l1.contiguous_offsets() {
73 Some((a, b)) => &vs1[a..b],
74 None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
75 };
76 let result = self.leftshift(vs1);
77 let result = CpuStorage::I64(result);
78 Ok((result, l1.shape().clone()))
79 }
80 CpuStorage::I32(vs1) => {
81 let vs1 = match l1.contiguous_offsets() {
82 Some((a, b)) => &vs1[a..b],
83 None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
84 };
85 let result = self.leftshift(vs1);
86 let result = CpuStorage::I32(result);
87 Ok((result, l1.shape().clone()))
88 }
89 _ => Err(Error::UnsupportedDTypeForOp(s1.dtype(), "leftshift")),
90 }
91 }
92
93 #[cfg(feature = "cuda")]
94 fn cuda_fwd(&self, s1: &CudaStorage, l1: &Layout) -> Result<(CudaStorage, Shape)> {
95 if !l1.is_contiguous() {
96 hanzo_ml::bail!("Input tensor s1 must be contiguous");
97 }
98 let dev = s1.device().clone();
99 let (d_in1_ptr, _d_guard, elem_count) = match s1.dtype() {
100 DType::U8 => {
101 let (d_in1, d_in1_guard) = slice_ptr(s1.as_cuda_slice::<u8>()?, l1.start_offset());
102 let elem_count = l1.shape().elem_count();
103 (d_in1 as *const c_void, d_in1_guard, elem_count)
104 }
105 DType::I32 => {
106 let (d_in1, d_in1_guard) = slice_ptr(s1.as_cuda_slice::<i32>()?, l1.start_offset());
107 let elem_count = l1.shape().elem_count();
108 (d_in1 as *const c_void, d_in1_guard, elem_count)
109 }
110 other => {
111 return Err(Error::UnsupportedDTypeForOp(other, "leftshift"));
112 }
113 };
114 let dst = match s1.dtype() {
115 DType::U8 => {
116 let d_out = unsafe { dev.alloc::<u8>(elem_count) }?;
117 let (d_out_ptr, d_out_guard) = d_out.device_ptr(d_out.stream());
118 unsafe {
119 ffi::leftshift_u8(
120 d_in1_ptr,
121 d_out_ptr as *mut std::ffi::c_void,
122 u32::try_from(elem_count)?,
123 self.0 as i32,
124 )
125 };
126 drop(d_out_guard);
127 CudaStorage::wrap_cuda_slice(d_out, dev)
128 }
129 DType::I32 => {
130 let d_out = unsafe { dev.alloc::<i32>(elem_count) }?;
131 let (d_out_ptr, d_out_guard) = d_out.device_ptr(d_out.stream());
132 unsafe {
133 ffi::leftshift_i32(
134 d_in1_ptr,
135 d_out_ptr as *mut std::ffi::c_void,
136 u32::try_from(elem_count)?,
137 self.0 as i32,
138 )
139 };
140 drop(d_out_guard);
141 CudaStorage::wrap_cuda_slice(d_out, dev)
142 }
143 _ => unreachable!(),
144 };
145 Ok((dst, l1.shape().clone()))
146 }
147
148 #[cfg(feature = "metal")]
149 fn metal_fwd(
150 &self,
151 s1: &hanzo_ml::MetalStorage,
152 l1: &Layout,
153 ) -> Result<(hanzo_ml::MetalStorage, Shape)> {
154 if !l1.is_contiguous() {
155 hanzo_ml::bail!("Input tensor s1 must be contiguous");
156 }
157
158 let encoder = s1.device().command_encoder()?;
159 encoder.set_label("bitwise-leftshift");
160
161 let device = s1.device();
162
163 let out_shape = l1.shape().clone();
164
165 let output = device.new_buffer(out_shape.elem_count(), s1.dtype(), "bitwise-leftshift")?;
166
167 crate::metal_kernels::call_bitwise_leftshift(
168 device.device(),
169 &encoder,
170 &crate::metal_kernels::Kernels::new(),
171 s1.dtype(),
172 s1.buffer(),
173 l1.start_offset(),
174 self.0 as u32,
175 out_shape.elem_count(),
176 &output,
177 )
178 .map_err(hanzo_ml::Error::wrap)?;
179
180 let newstorage =
181 hanzo_ml::MetalStorage::new(output, device.clone(), out_shape.elem_count(), s1.dtype());
182 Ok((newstorage, out_shape))
183 }
184}
185
186#[allow(dead_code)]
187pub trait LeftshiftOp {
188 fn leftshift(&self, n: usize) -> Result<Tensor>;
189}
190
191impl LeftshiftOp for Tensor {
192 fn leftshift(&self, n: usize) -> Result<Tensor> {
193 self.apply_op1_no_bwd(&Leftshift(n))
194 }
195}
196
/// The binary bitwise operations supported by the `BitWise` custom op.
pub enum BitWiseBinaryOpEnum {
    And,
    Or,
    Xor,
}

impl Display for BitWiseBinaryOpEnum {
    /// Writes the variant name (`And`, `Or` or `Xor`) verbatim.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::And => "And",
            Self::Or => "Or",
            Self::Xor => "Xor",
        };
        f.write_str(label)
    }
}
212
/// The unary bitwise operations supported by the `BitWiseUnary` custom op.
pub enum BitWiseUnaryOpEnum {
    Not,
}

impl Display for BitWiseUnaryOpEnum {
    /// Writes the variant name (`Not`) verbatim.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Not => "Not",
        };
        f.write_str(label)
    }
}
224
225struct BitWise {
226 pub op: BitWiseBinaryOpEnum,
227}
228
229impl BitWise {
230 pub fn new(op: BitWiseBinaryOpEnum) -> Self {
231 Self { op }
232 }
233
234 fn bitwise<T: WithDType + BitAnd<Output = T> + BitOr<Output = T> + BitXor<Output = T>>(
235 &self,
236 vs1: &[T],
237 vs2: &[T],
238 ) -> Vec<T> {
239 vs1.into_par_iter()
240 .zip_eq(vs2)
241 .map(|(v1, v2)| match self.op {
242 BitWiseBinaryOpEnum::And => *v1 & *v2,
243 BitWiseBinaryOpEnum::Or => *v1 | *v2,
244 BitWiseBinaryOpEnum::Xor => *v1 ^ *v2,
245 })
246 .collect()
247 }
248}
249
250impl CustomOp2 for BitWise {
251 fn name(&self) -> &'static str {
252 "bitwise"
253 }
254
255 fn cpu_fwd(
256 &self,
257 s1: &CpuStorage,
258 l1: &Layout,
259 s2: &CpuStorage,
260 l2: &Layout,
261 ) -> Result<(CpuStorage, Shape)> {
262 if l1.shape() != l2.shape() || l1.stride() != l2.stride() {
263 return Err(Error::ShapeMismatchBinaryOp {
264 lhs: l1.shape().clone(),
265 rhs: l2.shape().clone(),
266 op: "bitwise-op",
267 });
268 }
269 if s1.dtype() != s2.dtype() {
270 return Err(Error::DTypeMismatchBinaryOp {
271 lhs: s1.dtype(),
272 rhs: s2.dtype(),
273 op: "bitwise-op",
274 });
275 }
276 if !l1.is_contiguous() {
277 hanzo_ml::bail!("Input tensor s1 must be contiguous");
278 }
279 if !l2.is_contiguous() {
280 hanzo_ml::bail!("Input tensor s2 must be contiguous");
281 }
282
283 match s1 {
284 CpuStorage::U8(vs1) => {
285 let vs2 = s2.as_slice::<u8>().unwrap();
286 let vs1 = match l1.contiguous_offsets() {
287 Some((a, b)) => &vs1[a..b],
288 None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
289 };
290 let vs2 = match l2.contiguous_offsets() {
291 Some((a, b)) => &vs2[a..b],
292 None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
293 };
294 let result = self.bitwise(vs1, vs2);
295 let result = CpuStorage::U8(result);
296 Ok((result, l1.shape().clone()))
297 }
298 CpuStorage::U32(vs1) => {
299 let vs2 = s2.as_slice::<u32>().unwrap();
300 let vs1 = match l1.contiguous_offsets() {
301 Some((a, b)) => &vs1[a..b],
302 None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
303 };
304 let vs2 = match l2.contiguous_offsets() {
305 Some((a, b)) => &vs2[a..b],
306 None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
307 };
308 let result = self.bitwise(vs1, vs2);
309 let result = CpuStorage::U32(result);
310 Ok((result, l1.shape().clone()))
311 }
312 CpuStorage::I64(vs1) => {
313 let vs2 = s2.as_slice::<i64>().unwrap();
314 let vs1 = match l1.contiguous_offsets() {
315 Some((a, b)) => &vs1[a..b],
316 None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
317 };
318 let vs2 = match l2.contiguous_offsets() {
319 Some((a, b)) => &vs2[a..b],
320 None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
321 };
322 let result = self.bitwise(vs1, vs2);
323 let result = CpuStorage::I64(result);
324 Ok((result, l1.shape().clone()))
325 }
326 CpuStorage::I16(vs1) => {
327 let vs2 = s2.as_slice::<i16>().unwrap();
328 let vs1 = match l1.contiguous_offsets() {
329 Some((a, b)) => &vs1[a..b],
330 None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
331 };
332 let vs2 = match l2.contiguous_offsets() {
333 Some((a, b)) => &vs2[a..b],
334 None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
335 };
336 let result = self.bitwise(vs1, vs2);
337 let result = CpuStorage::I16(result);
338 Ok((result, l1.shape().clone()))
339 }
340 CpuStorage::I32(vs1) => {
341 let vs2 = s2.as_slice::<i32>().unwrap();
342 let vs1 = match l1.contiguous_offsets() {
343 Some((a, b)) => &vs1[a..b],
344 None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
345 };
346 let vs2 = match l2.contiguous_offsets() {
347 Some((a, b)) => &vs2[a..b],
348 None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
349 };
350 let result = self.bitwise(vs1, vs2);
351 let result = CpuStorage::I32(result);
352 Ok((result, l1.shape().clone()))
353 }
354 _ => Err(Error::UnsupportedDTypeForOp(s1.dtype(), "bitwise")),
355 }
356 }
357
358 #[cfg(feature = "cuda")]
359 fn cuda_fwd(
360 &self,
361 s1: &CudaStorage,
362 l1: &Layout,
363 s2: &CudaStorage,
364 l2: &Layout,
365 ) -> Result<(CudaStorage, Shape)> {
366 if l1.shape() != l2.shape() || l1.stride() != l2.stride() {
367 return Err(Error::ShapeMismatchBinaryOp {
368 lhs: l1.shape().clone(),
369 rhs: l2.shape().clone(),
370 op: "bitwise-op",
371 });
372 }
373 if s1.dtype() != s2.dtype() {
374 return Err(Error::DTypeMismatchBinaryOp {
375 lhs: s1.dtype(),
376 rhs: s2.dtype(),
377 op: "bitwise-op",
378 });
379 }
380 if !l1.is_contiguous() {
381 hanzo_ml::bail!("Input tensor s1 must be contiguous");
382 }
383 if !l2.is_contiguous() {
384 hanzo_ml::bail!("Input tensor s2 must be contiguous");
385 }
386
387 let dev = s1.device().clone();
388 let (d_in1_ptr, d_in2_ptr, _d_in1_guard, _d_in2_guard, elem_count) = match s1.dtype() {
389 DType::U8 => {
390 let (d_in1, d_in1_guard) = slice_ptr(s1.as_cuda_slice::<u8>()?, l1.start_offset());
391 let (d_in2, d_in2_guard) = slice_ptr(s2.as_cuda_slice::<u8>()?, l2.start_offset());
392 let elem_count = l1.shape().elem_count();
393 (
394 d_in1 as *const std::ffi::c_void,
395 d_in2 as *const std::ffi::c_void,
396 d_in1_guard,
397 d_in2_guard,
398 elem_count,
399 )
400 }
401 DType::U32 => {
402 let (d_in1, d_in1_guard) = slice_ptr(s1.as_cuda_slice::<u32>()?, l1.start_offset());
403 let (d_in2, d_in2_guard) = slice_ptr(s2.as_cuda_slice::<u32>()?, l2.start_offset());
404 let elem_count = l1.shape().elem_count();
405 (
406 d_in1 as *const std::ffi::c_void,
407 d_in2 as *const std::ffi::c_void,
408 d_in1_guard,
409 d_in2_guard,
410 elem_count,
411 )
412 }
413 DType::I64 => {
414 let (d_in1, d_in1_guard) = slice_ptr(s1.as_cuda_slice::<i64>()?, l1.start_offset());
415 let (d_in2, d_in2_guard) = slice_ptr(s2.as_cuda_slice::<i64>()?, l2.start_offset());
416 let elem_count = l1.shape().elem_count();
417 (
418 d_in1 as *const std::ffi::c_void,
419 d_in2 as *const std::ffi::c_void,
420 d_in1_guard,
421 d_in2_guard,
422 elem_count,
423 )
424 }
425 DType::I32 => {
426 let (d_in1, d_in1_guard) = slice_ptr(s1.as_cuda_slice::<i32>()?, l1.start_offset());
427 let (d_in2, d_in2_guard) = slice_ptr(s2.as_cuda_slice::<i32>()?, l2.start_offset());
428 let elem_count = l1.shape().elem_count();
429 (
430 d_in1 as *const std::ffi::c_void,
431 d_in2 as *const std::ffi::c_void,
432 d_in1_guard,
433 d_in2_guard,
434 elem_count,
435 )
436 }
437 DType::I16 => {
438 let (d_in1, d_in1_guard) = slice_ptr(s1.as_cuda_slice::<i16>()?, l1.start_offset());
439 let (d_in2, d_in2_guard) = slice_ptr(s2.as_cuda_slice::<i16>()?, l2.start_offset());
440 let elem_count = l1.shape().elem_count();
441 (
442 d_in1 as *const std::ffi::c_void,
443 d_in2 as *const std::ffi::c_void,
444 d_in1_guard,
445 d_in2_guard,
446 elem_count,
447 )
448 }
449 other => {
450 return Err(Error::UnsupportedDTypeForOp(other, "bitwise"));
451 }
452 };
453 let dst = match s1.dtype() {
454 DType::U8 => {
455 let d_out = unsafe { dev.alloc::<u8>(elem_count) }?;
456 let (d_out_ptr, d_out_guard) = d_out.device_ptr(d_out.stream());
457 unsafe {
458 match self.op {
459 BitWiseBinaryOpEnum::And => ffi::bitwise_and_u8(
460 d_in1_ptr,
461 d_in2_ptr,
462 d_out_ptr as *mut c_void,
463 u32::try_from(elem_count)?,
464 ),
465 BitWiseBinaryOpEnum::Or => ffi::bitwise_or_u8(
466 d_in1_ptr,
467 d_in2_ptr,
468 d_out_ptr as *mut c_void,
469 u32::try_from(elem_count)?,
470 ),
471 BitWiseBinaryOpEnum::Xor => ffi::bitwise_xor_u8(
472 d_in1_ptr,
473 d_in2_ptr,
474 d_out_ptr as *mut c_void,
475 u32::try_from(elem_count)?,
476 ),
477 }
478 };
479 drop(d_out_guard);
480 CudaStorage::wrap_cuda_slice(d_out, dev)
481 }
482 DType::U32 => {
483 let d_out = unsafe { dev.alloc::<u32>(elem_count) }?;
484 let (d_out_ptr, d_out_guard) = d_out.device_ptr(d_out.stream());
485 unsafe {
486 match self.op {
487 BitWiseBinaryOpEnum::And => ffi::bitwise_and_u32(
488 d_in1_ptr,
489 d_in2_ptr,
490 d_out_ptr as *mut c_void,
491 u32::try_from(elem_count)?,
492 ),
493 BitWiseBinaryOpEnum::Or => ffi::bitwise_or_u32(
494 d_in1_ptr,
495 d_in2_ptr,
496 d_out_ptr as *mut c_void,
497 u32::try_from(elem_count)?,
498 ),
499 BitWiseBinaryOpEnum::Xor => ffi::bitwise_xor_u32(
500 d_in1_ptr,
501 d_in2_ptr,
502 d_out_ptr as *mut c_void,
503 u32::try_from(elem_count)?,
504 ),
505 }
506 };
507 drop(d_out_guard);
508 CudaStorage::wrap_cuda_slice(d_out, dev)
509 }
510 DType::I64 => {
511 let d_out = unsafe { dev.alloc::<i64>(elem_count) }?;
512 let (d_out_ptr, d_out_guard) = d_out.device_ptr(d_out.stream());
513 unsafe {
514 match self.op {
515 BitWiseBinaryOpEnum::And => ffi::bitwise_and_i64(
516 d_in1_ptr,
517 d_in2_ptr,
518 d_out_ptr as *mut c_void,
519 u32::try_from(elem_count)?,
520 ),
521 BitWiseBinaryOpEnum::Or => ffi::bitwise_or_i64(
522 d_in1_ptr,
523 d_in2_ptr,
524 d_out_ptr as *mut c_void,
525 u32::try_from(elem_count)?,
526 ),
527 BitWiseBinaryOpEnum::Xor => ffi::bitwise_xor_i64(
528 d_in1_ptr,
529 d_in2_ptr,
530 d_out_ptr as *mut c_void,
531 u32::try_from(elem_count)?,
532 ),
533 }
534 };
535 drop(d_out_guard);
536 CudaStorage::wrap_cuda_slice(d_out, dev)
537 }
538 DType::I32 => {
539 let d_out = unsafe { dev.alloc::<i64>(elem_count) }?;
540 let (d_out_ptr, d_out_guard) = d_out.device_ptr(d_out.stream());
541 unsafe {
542 match self.op {
543 BitWiseBinaryOpEnum::And => ffi::bitwise_and_i32(
544 d_in1_ptr,
545 d_in2_ptr,
546 d_out_ptr as *mut c_void,
547 u32::try_from(elem_count)?,
548 ),
549 BitWiseBinaryOpEnum::Or => ffi::bitwise_or_i32(
550 d_in1_ptr,
551 d_in2_ptr,
552 d_out_ptr as *mut c_void,
553 u32::try_from(elem_count)?,
554 ),
555 BitWiseBinaryOpEnum::Xor => ffi::bitwise_xor_i32(
556 d_in1_ptr,
557 d_in2_ptr,
558 d_out_ptr as *mut c_void,
559 u32::try_from(elem_count)?,
560 ),
561 }
562 };
563 drop(d_out_guard);
564 CudaStorage::wrap_cuda_slice(d_out, dev)
565 }
566 _ => unreachable!(),
567 };
568 Ok((dst, l1.shape().clone()))
569 }
570
571 #[cfg(feature = "metal")]
572 fn metal_fwd(
573 &self,
574 s1: &hanzo_ml::MetalStorage,
575 l1: &Layout,
576 s2: &hanzo_ml::MetalStorage,
577 l2: &Layout,
578 ) -> Result<(hanzo_ml::MetalStorage, Shape)> {
579 if l1.shape() != l2.shape() || l1.stride() != l2.stride() {
580 return Err(Error::ShapeMismatchBinaryOp {
581 lhs: l1.shape().clone(),
582 rhs: l2.shape().clone(),
583 op: "bitwise-op",
584 });
585 }
586 if s1.dtype() != s2.dtype() {
587 return Err(Error::DTypeMismatchBinaryOp {
588 lhs: s1.dtype(),
589 rhs: s2.dtype(),
590 op: "bitwise-op",
591 });
592 }
593 if !l1.is_contiguous() {
594 hanzo_ml::bail!("Input tensor s1 must be contiguous");
595 }
596 if !l2.is_contiguous() {
597 hanzo_ml::bail!("Input tensor s2 must be contiguous");
598 }
599
600 let encoder = s1.device().command_encoder()?;
601 encoder.set_label("bitwise-op");
602
603 let device = s1.device();
604
605 let out_shape = l1.shape().clone();
606
607 let output = device.new_buffer(out_shape.elem_count(), s1.dtype(), "bitwise-op")?;
608
609 match self.op {
610 BitWiseBinaryOpEnum::Or => crate::metal_kernels::call_bitwise_or(
611 device.device(),
612 &encoder,
613 &crate::metal_kernels::Kernels::new(),
614 s1.dtype(),
615 s1.buffer(),
616 s2.buffer(),
617 l1.start_offset() * s1.dtype().size_in_bytes(),
618 l2.start_offset() * s2.dtype().size_in_bytes(),
619 out_shape.elem_count(),
620 &output,
621 )
622 .map_err(hanzo_ml::Error::wrap)?,
623 BitWiseBinaryOpEnum::And => crate::metal_kernels::call_bitwise_and(
624 device.device(),
625 &encoder,
626 &crate::metal_kernels::Kernels::new(),
627 s1.dtype(),
628 s1.buffer(),
629 s2.buffer(),
630 l1.start_offset() * s1.dtype().size_in_bytes(),
631 l2.start_offset() * s2.dtype().size_in_bytes(),
632 out_shape.elem_count(),
633 &output,
634 )
635 .map_err(hanzo_ml::Error::wrap)?,
636 BitWiseBinaryOpEnum::Xor => crate::metal_kernels::call_bitwise_xor(
637 device.device(),
638 &encoder,
639 &crate::metal_kernels::Kernels::new(),
640 s1.dtype(),
641 s1.buffer(),
642 s2.buffer(),
643 l1.start_offset() * s1.dtype().size_in_bytes(),
644 l2.start_offset() * s2.dtype().size_in_bytes(),
645 out_shape.elem_count(),
646 &output,
647 )
648 .map_err(hanzo_ml::Error::wrap)?,
649 }
650
651 let newstorage =
652 hanzo_ml::MetalStorage::new(output, device.clone(), out_shape.elem_count(), s1.dtype());
653 Ok((newstorage, out_shape))
654 }
655}
656
657struct BitWiseUnary {
658 pub op: BitWiseUnaryOpEnum,
659}
660
661impl BitWiseUnary {
662 pub fn new(op: BitWiseUnaryOpEnum) -> Self {
663 Self { op }
664 }
665
666 fn bitwise<T: WithDType + Not<Output = T>>(&self, vs1: &[T]) -> Vec<T> {
667 vs1.into_par_iter()
668 .map(|v1| match self.op {
669 BitWiseUnaryOpEnum::Not => !*v1,
670 })
671 .collect()
672 }
673}
674
675impl CustomOp1 for BitWiseUnary {
676 fn name(&self) -> &'static str {
677 "bitwise-unary"
678 }
679
680 fn cpu_fwd(&self, s1: &CpuStorage, l1: &Layout) -> Result<(CpuStorage, Shape)> {
681 if !l1.is_contiguous() {
682 hanzo_ml::bail!("Input tensor s1 must be contiguous");
683 }
684
685 match s1 {
686 CpuStorage::U8(vs1) => {
687 let vs1 = match l1.contiguous_offsets() {
688 Some((a, b)) => &vs1[a..b],
689 None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
690 };
691 let result = self.bitwise(vs1);
692 let result = CpuStorage::U8(result);
693 Ok((result, l1.shape().clone()))
694 }
695 CpuStorage::U32(vs1) => {
696 let vs1 = match l1.contiguous_offsets() {
697 Some((a, b)) => &vs1[a..b],
698 None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
699 };
700 let result = self.bitwise(vs1);
701 let result = CpuStorage::U32(result);
702 Ok((result, l1.shape().clone()))
703 }
704 CpuStorage::I64(vs1) => {
705 let vs1 = match l1.contiguous_offsets() {
706 Some((a, b)) => &vs1[a..b],
707 None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
708 };
709 let result = self.bitwise(vs1);
710 let result = CpuStorage::I64(result);
711 Ok((result, l1.shape().clone()))
712 }
713 CpuStorage::I16(vs1) => {
714 let vs1 = match l1.contiguous_offsets() {
715 Some((a, b)) => &vs1[a..b],
716 None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
717 };
718 let result = self.bitwise(vs1);
719 let result = CpuStorage::I16(result);
720 Ok((result, l1.shape().clone()))
721 }
722 CpuStorage::I32(vs1) => {
723 let vs1 = match l1.contiguous_offsets() {
724 Some((a, b)) => &vs1[a..b],
725 None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
726 };
727 let result = self.bitwise(vs1);
728 let result = CpuStorage::I32(result);
729 Ok((result, l1.shape().clone()))
730 }
731 _ => Err(Error::UnsupportedDTypeForOp(s1.dtype(), "bitwise")),
732 }
733 }
734
735 #[cfg(feature = "cuda")]
736 fn cuda_fwd(&self, _s1: &CudaStorage, _l1: &Layout) -> Result<(CudaStorage, Shape)> {
737 todo!()
738 }
739
740 #[cfg(feature = "metal")]
741 fn metal_fwd(
742 &self,
743 s1: &hanzo_ml::MetalStorage,
744 l1: &Layout,
745 ) -> Result<(hanzo_ml::MetalStorage, Shape)> {
746 if !l1.is_contiguous() {
747 hanzo_ml::bail!("Input tensor s1 must be contiguous");
748 }
749
750 let encoder = s1.device().command_encoder()?;
751 encoder.set_label("bitwise-unary-op");
752
753 let device = s1.device();
754
755 let out_shape = l1.shape().clone();
756
757 let output = device.new_buffer(out_shape.elem_count(), s1.dtype(), "bitwise-op")?;
758
759 match self.op {
760 BitWiseUnaryOpEnum::Not => crate::metal_kernels::call_bitwise_not(
761 device.device(),
762 &encoder,
763 &crate::metal_kernels::Kernels::new(),
764 s1.dtype(),
765 s1.buffer(),
766 l1.start_offset() * s1.dtype().size_in_bytes(),
767 out_shape.elem_count(),
768 &output,
769 )
770 .map_err(hanzo_ml::Error::wrap)?,
771 }
772
773 let newstorage =
774 hanzo_ml::MetalStorage::new(output, device.clone(), out_shape.elem_count(), s1.dtype());
775 Ok((newstorage, out_shape))
776 }
777}
778
779#[allow(dead_code)]
780pub trait BitWiseOp {
781 fn bitwise_and(&self, rhs: &Tensor) -> Result<Tensor>;
782 fn bitwise_or(&self, rhs: &Tensor) -> Result<Tensor>;
783 fn bitwise_xor(&self, rhs: &Tensor) -> Result<Tensor>;
784 fn bitwise_not(&self) -> Result<Tensor>;
785}
786
787impl BitWiseOp for Tensor {
788 fn bitwise_and(&self, rhs: &Tensor) -> Result<Tensor> {
789 self.apply_op2_no_bwd(rhs, &BitWise::new(BitWiseBinaryOpEnum::And))
790 }
791
792 fn bitwise_or(&self, rhs: &Tensor) -> Result<Tensor> {
793 self.apply_op2_no_bwd(rhs, &BitWise::new(BitWiseBinaryOpEnum::Or))
794 }
795
796 fn bitwise_xor(&self, rhs: &Tensor) -> Result<Tensor> {
797 self.apply_op2_no_bwd(rhs, &BitWise::new(BitWiseBinaryOpEnum::Xor))
798 }
799
800 fn bitwise_not(&self) -> Result<Tensor> {
801 self.apply_op1_no_bwd(&BitWiseUnary::new(BitWiseUnaryOpEnum::Not))
802 }
803}
804
/// Custom op behind `SortOp::fast_argsort_asc`. Its Metal path writes `U32`
/// output; the CPU and CUDA paths in this file return "not implemented" errors.
// `unused` is allowed because the field is only read by the Metal path.
#[allow(unused)]
struct ArgSort {
    /// Index of the dimension to sort along.
    axis: usize,
}

/// Custom op behind `SortOp::fast_sort_asc`. Its Metal path writes output with
/// the input dtype; the CPU and CUDA paths return "not implemented" errors.
// `unused` is allowed because the field is only read by the Metal path.
#[allow(unused)]
struct Sort {
    /// Index of the dimension to sort along.
    axis: usize,
}
818
819impl CustomOp1 for ArgSort {
820 fn name(&self) -> &'static str {
821 "argsort"
822 }
823
824 fn cpu_fwd(&self, _s1: &CpuStorage, _l1: &Layout) -> Result<(CpuStorage, Shape)> {
826 hanzo_ml::bail!("ArgSort is not implemented for the CPU backend");
827 }
828
829 #[cfg(feature = "cuda")]
831 fn cuda_fwd(&self, _s1: &CudaStorage, _l1: &Layout) -> Result<(CudaStorage, Shape)> {
832 hanzo_ml::bail!("ArgSort is not implemented for the CUDA backend");
833 }
834
835 #[cfg(feature = "metal")]
837 fn metal_fwd(
838 &self,
839 s1: &hanzo_ml::MetalStorage,
840 l1: &Layout,
841 ) -> Result<(hanzo_ml::MetalStorage, Shape)> {
842 if !l1.is_contiguous() {
844 hanzo_ml::bail!("Input tensor s1 must be contiguous");
845 }
846
847 let encoder = s1.device().command_encoder()?;
849 encoder.set_label("argsort");
850
851 let device = s1.device();
852 let out_shape = l1.shape().clone();
853 let elem_count = out_shape.elem_count();
854
855 let output = device.new_buffer(elem_count, hanzo_ml::DType::U32, "argsort")?;
857
858 let cache = SORT_SCRATCH_CACHE.get_or_init(|| SortScratchCache::new(4));
862
863 let dims = l1.dims();
864 let size_sorted_axis = dims[self.axis];
865 let n_rows = l1.shape().elem_count() / size_sorted_axis;
866
867 let tn = 4usize;
869 let mut bn = match size_sorted_axis.div_ceil(tn) {
870 v if v > 256 => 512,
871 v if v > 128 => 256,
872 v if v > 64 => 128,
873 v if v > 32 => 64,
874 _ => 32,
875 };
876 if bn == 512 && s1.dtype().size_in_bytes() > 4 {
877 bn = 256;
878 }
879 let n_per_block = bn * tn;
880 let n_blocks = size_sorted_axis.div_ceil(n_per_block);
881
882 let scratch = cache.checkout(device, n_rows, size_sorted_axis, s1.dtype(), n_blocks);
884
885 let sort_args = crate::metal_kernels::SortArgs {
889 axis: self.axis,
890 shape: l1.dims(),
891 strides: l1.stride(),
892 out_shape: l1.dims(), out_strides: l1.stride(),
894 in_contiguous: l1.is_contiguous(),
895 in_ty: s1.dtype(),
896 out_ty: hanzo_ml::DType::U32,
897 src: s1.buffer(),
898 src_offset: l1.start_offset(), dst: &output,
900 bn,
901 tn,
902 n_blocks,
903 };
904
905 crate::metal_kernels::call_argsort(
907 device.device(),
908 &encoder, &crate::metal_kernels::Kernels::new(),
910 &sort_args,
911 &scratch,
912 )
913 .map_err(hanzo_ml::Error::wrap)?;
914
915 let newstorage =
917 hanzo_ml::MetalStorage::new(output, device.clone(), elem_count, hanzo_ml::DType::U32);
918 Ok((newstorage, out_shape))
919 }
920}
921
922impl CustomOp1 for Sort {
923 fn name(&self) -> &'static str {
924 "sort"
925 }
926
927 fn cpu_fwd(&self, _s1: &CpuStorage, _l1: &Layout) -> Result<(CpuStorage, Shape)> {
929 hanzo_ml::bail!("Sort is not implemented for the CPU backend");
930 }
931
932 #[cfg(feature = "cuda")]
934 fn cuda_fwd(&self, _s1: &CudaStorage, _l1: &Layout) -> Result<(CudaStorage, Shape)> {
935 hanzo_ml::bail!("Sort is not implemented for the CUDA backend");
936 }
937
938 #[cfg(feature = "metal")]
940 fn metal_fwd(
941 &self,
942 s1: &hanzo_ml::MetalStorage,
943 l1: &Layout,
944 ) -> Result<(hanzo_ml::MetalStorage, Shape)> {
945 if !l1.is_contiguous() {
947 hanzo_ml::bail!("Input tensor s1 must be contiguous");
948 }
949
950 let encoder = s1.device().command_encoder()?;
952 encoder.set_label("sort");
953
954 let device = s1.device();
955 let out_shape = l1.shape().clone();
956 let elem_count = out_shape.elem_count();
957
958 let output = device.new_buffer(elem_count, s1.dtype(), "sort")?;
960
961 let cache = SORT_SCRATCH_CACHE.get_or_init(|| SortScratchCache::new(4));
965
966 let dims = l1.dims();
967 let size_sorted_axis = dims[self.axis];
968 let n_rows = l1.shape().elem_count() / size_sorted_axis;
969
970 let tn = 4usize;
972 let mut bn = match size_sorted_axis.div_ceil(tn) {
973 v if v > 256 => 512,
974 v if v > 128 => 256,
975 v if v > 64 => 128,
976 v if v > 32 => 64,
977 _ => 32,
978 };
979 if bn == 512 && s1.dtype().size_in_bytes() > 4 {
980 bn = 256;
981 }
982 let n_per_block = bn * tn;
983 let n_blocks = size_sorted_axis.div_ceil(n_per_block);
984
985 let scratch = cache.checkout(device, n_rows, size_sorted_axis, s1.dtype(), n_blocks);
987
988 let sort_args = crate::metal_kernels::SortArgs {
992 axis: self.axis,
993 shape: l1.dims(),
994 strides: l1.stride(),
995 out_shape: l1.dims(), out_strides: l1.stride(),
997 in_contiguous: l1.is_contiguous(),
998 in_ty: s1.dtype(),
999 out_ty: s1.dtype(),
1000 src: s1.buffer(),
1001 src_offset: l1.start_offset(), dst: &output,
1003 bn,
1004 tn,
1005 n_blocks,
1006 };
1007
1008 crate::metal_kernels::call_sort(
1010 device.device(),
1011 &encoder, &crate::metal_kernels::Kernels::new(),
1013 &sort_args,
1014 &scratch,
1015 )
1016 .map_err(hanzo_ml::Error::wrap)?;
1017
1018 let newstorage =
1020 hanzo_ml::MetalStorage::new(output, device.clone(), elem_count, s1.dtype());
1021 Ok((newstorage, out_shape))
1022 }
1023}
1024
1025pub trait SortOp {
1027 fn fast_argsort_asc<D: Dim>(&self, axis: D) -> Result<Tensor>;
1029 fn fast_sort_asc<D: Dim>(&self, axis: D) -> Result<Tensor>;
1031}
1032
1033impl SortOp for Tensor {
1034 fn fast_argsort_asc<D: Dim>(&self, axis: D) -> Result<Tensor> {
1035 if self.device().is_cpu() || self.device().is_cuda() {
1036 return self.arg_sort_last_dim(true);
1037 }
1038 self.apply_op1_no_bwd(&ArgSort {
1039 axis: axis.to_index(self.shape(), "argsort")?,
1040 })
1041 }
1042
1043 fn fast_sort_asc<D: Dim>(&self, axis: D) -> Result<Tensor> {
1044 if self.device().is_cpu() || self.device().is_cuda() {
1045 return Ok(self.sort_last_dim(true)?.0);
1046 }
1047 self.apply_op1_no_bwd(&Sort {
1048 axis: axis.to_index(self.shape(), "sort")?,
1049 })
1050 }
1051}
1052
1053struct NonZero;
1054
1055impl NonZero {
1056 fn nonzero<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Vec<u32> {
1058 let n = layout.dims().len();
1059 let mut result = Vec::new();
1060 let mut indices = vec![0u32; n];
1061 for (i, v) in vs.iter().enumerate() {
1062 if !v.is_zero() {
1063 let mut idx = i;
1064 for (dim_index, dim) in layout.dims().iter().enumerate().rev() {
1065 let d = idx % dim;
1066 indices[dim_index] = u32::try_from(d).unwrap();
1067 idx /= dim;
1068 }
1069 result.extend_from_slice(&indices);
1070 }
1071 }
1072 result
1073 }
1074}
1075
/// Dtype dispatch for the CUDA non-zero kernels, compiled when the `cuda`
/// feature is enabled and `cuda-13000` is not.
#[cfg(all(feature = "cuda", not(feature = "cuda-13000")))]
mod cuda_ops_cccl2 {
    use super::*;

    /// Calls the `count_nonzero_*` FFI function matching `dtype` with the
    /// device input pointer `d_in`, element count `n` and `stream`, and
    /// returns its result.
    ///
    /// # Panics
    /// Panics (`unreachable!`) for dtypes without a matching FFI function.
    pub(super) fn count_nonzero_cuda(
        dtype: hanzo_ml::DType,
        d_in: *const c_void,
        n: u32,
        stream: hanzo_ml::cuda::cudarc::driver::sys::CUstream,
    ) -> u32 {
        // SAFETY: the caller must pass a device pointer to `n` elements of
        // `dtype` and a stream valid for that allocation.
        unsafe {
            match dtype {
                hanzo_ml::DType::U8 => ffi::count_nonzero_u8(d_in, n, stream),
                hanzo_ml::DType::U32 => ffi::count_nonzero_u32(d_in, n, stream),
                hanzo_ml::DType::I64 => ffi::count_nonzero_i64(d_in, n, stream),
                hanzo_ml::DType::I16 => ffi::count_nonzero_i16(d_in, n, stream),
                hanzo_ml::DType::I32 => ffi::count_nonzero_i32(d_in, n, stream),
                hanzo_ml::DType::BF16 => ffi::count_nonzero_bf16(d_in, n, stream),
                hanzo_ml::DType::F16 => ffi::count_nonzero_f16(d_in, n, stream),
                hanzo_ml::DType::F32 => ffi::count_nonzero_f32(d_in, n, stream),
                hanzo_ml::DType::F64 => ffi::count_nonzero_f64(d_in, n, stream),
                _ => unreachable!(),
            }
        }
    }

    /// Calls the `nonzero_*` FFI function matching `dtype`, forwarding all
    /// arguments unchanged.
    ///
    /// `I32` input is dispatched to the `i32` kernel; it was previously sent
    /// to the `i64` kernel, which would read the buffer as 8-byte elements.
    ///
    /// # Panics
    /// Panics (`unreachable!`) for dtypes without a matching FFI function.
    #[allow(clippy::too_many_arguments)]
    pub(super) fn nonzero_cuda(
        dtype: hanzo_ml::DType,
        d_in: *const c_void,
        n: u32,
        num_nonzero: u32,
        dims: *const c_void,
        num_dims: u32,
        d_out: *mut c_void,
        stream: hanzo_ml::cuda::cudarc::driver::sys::CUstream,
    ) {
        // SAFETY: the caller must pass device pointers that are valid for the
        // given `dtype`, `n`, `num_nonzero` and `num_dims`, and a stream valid
        // for those allocations.
        unsafe {
            match dtype {
                hanzo_ml::DType::U8 => {
                    ffi::nonzero_u8(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
                }
                hanzo_ml::DType::U32 => {
                    ffi::nonzero_u32(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
                }
                hanzo_ml::DType::I64 => {
                    ffi::nonzero_i64(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
                }
                hanzo_ml::DType::I32 => {
                    ffi::nonzero_i32(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
                }
                hanzo_ml::DType::I16 => {
                    ffi::nonzero_i16(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
                }
                hanzo_ml::DType::BF16 => {
                    ffi::nonzero_bf16(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
                }
                hanzo_ml::DType::F16 => {
                    ffi::nonzero_f16(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
                }
                hanzo_ml::DType::F32 => {
                    ffi::nonzero_f32(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
                }
                hanzo_ml::DType::F64 => {
                    ffi::nonzero_f64(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
                }
                _ => unreachable!(),
            }
        }
    }
}
1147
/// CUDA dispatch helpers compiled when both the `cuda` and `cuda-13000`
/// features are enabled.
#[cfg(all(feature = "cuda", feature = "cuda-13000"))]
mod cuda_ops_cccl3 {
    use super::*;

    /// Calls the `ffi::count_nonzero_*` kernel matching `dtype` and returns its
    /// result.
    ///
    /// # Panics
    /// Reaches `unreachable!()` when `dtype` has no kernel listed below.
    pub(super) fn count_nonzero_cuda(
        dtype: hanzo_ml::DType,
        d_in: *const c_void,
        n: u32,
        stream: hanzo_ml::cuda::cudarc::driver::sys::CUstream,
    ) -> u32 {
        macro_rules! launch {
            ($kernel:ident) => {
                ffi::$kernel(d_in, n, stream)
            };
        }

        // SAFETY: nothing is validated here; the caller must pass a pointer and
        // a stream that are valid for the selected kernel.
        unsafe {
            match dtype {
                hanzo_ml::DType::U8 => launch!(count_nonzero_u8),
                hanzo_ml::DType::U32 => launch!(count_nonzero_u32),
                hanzo_ml::DType::I64 => launch!(count_nonzero_i64),
                hanzo_ml::DType::I16 => launch!(count_nonzero_i16),
                hanzo_ml::DType::I32 => launch!(count_nonzero_i32),
                hanzo_ml::DType::BF16 => launch!(count_nonzero_bf16),
                hanzo_ml::DType::F16 => launch!(count_nonzero_f16),
                hanzo_ml::DType::F32 => launch!(count_nonzero_f32),
                hanzo_ml::DType::F64 => launch!(count_nonzero_f64),
                _ => unreachable!(),
            }
        }
    }

    /// Dispatches the CUDA `nonzero` kernel that matches `dtype`.
    ///
    /// Every argument after `dtype` is forwarded unchanged to the selected
    /// `ffi::nonzero_*` entry point.
    ///
    /// # Panics
    /// Reaches `unreachable!()` when `dtype` has no kernel listed below.
    #[allow(clippy::too_many_arguments)]
    pub(super) fn nonzero_cuda(
        dtype: hanzo_ml::DType,
        d_in: *const c_void,
        n: u32,
        num_nonzero: u32,
        dims: *const c_void,
        num_dims: u32,
        d_out: *mut c_void,
        stream: hanzo_ml::cuda::cudarc::driver::sys::CUstream,
    ) {
        macro_rules! launch {
            ($kernel:ident) => {
                ffi::$kernel(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
            };
        }

        // SAFETY: nothing is validated here; the caller must pass pointers and
        // a stream that are valid for the selected kernel.
        unsafe {
            match dtype {
                hanzo_ml::DType::U8 => launch!(nonzero_u8),
                hanzo_ml::DType::U32 => launch!(nonzero_u32),
                hanzo_ml::DType::I64 => launch!(nonzero_i64),
                // I32 used to be routed to the i64 kernel, which would read the
                // 32-bit input as 64-bit values.
                // NOTE(review): assumes `ffi::nonzero_i32` is exported just like
                // `ffi::count_nonzero_i32` — confirm against the ffi module.
                hanzo_ml::DType::I32 => launch!(nonzero_i32),
                hanzo_ml::DType::I16 => launch!(nonzero_i16),
                hanzo_ml::DType::BF16 => launch!(nonzero_bf16),
                hanzo_ml::DType::F16 => launch!(nonzero_f16),
                hanzo_ml::DType::F32 => launch!(nonzero_f32),
                hanzo_ml::DType::F64 => launch!(nonzero_f64),
                _ => unreachable!(),
            }
        }
    }
}
1219
1220#[cfg(all(feature = "cuda", not(feature = "cuda-13000")))]
1221use cuda_ops_cccl2::{count_nonzero_cuda, nonzero_cuda};
1222#[cfg(all(feature = "cuda", feature = "cuda-13000"))]
1223use cuda_ops_cccl3::{count_nonzero_cuda, nonzero_cuda};
1224
1225impl CustomOp1 for NonZero {
1226 fn name(&self) -> &'static str {
1227 "nonzero"
1228 }
1229
1230 fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)> {
1231 if !layout.is_contiguous() {
1232 return Err(Error::RequiresContiguous { op: "nonzero" });
1233 }
1234 let result = match storage {
1235 hanzo_ml::CpuStorage::U8(vs) => self.nonzero(vs, layout),
1236 hanzo_ml::CpuStorage::U32(vs) => self.nonzero(vs, layout),
1237 hanzo_ml::CpuStorage::I16(vs) => self.nonzero(vs, layout),
1238 hanzo_ml::CpuStorage::I32(vs) => self.nonzero(vs, layout),
1239 hanzo_ml::CpuStorage::I64(vs) => self.nonzero(vs, layout),
1240 hanzo_ml::CpuStorage::BF16(vs) => self.nonzero(vs, layout),
1241 hanzo_ml::CpuStorage::F16(vs) => self.nonzero(vs, layout),
1242 hanzo_ml::CpuStorage::F32(vs) => self.nonzero(vs, layout),
1243 hanzo_ml::CpuStorage::F64(vs) => self.nonzero(vs, layout),
1244 _ => unreachable!(),
1245 };
1246 let index_len = layout.dims().len();
1247 let result_len = result.len() / index_len;
1248 let result = CpuStorage::U32(result);
1249 let shape = Shape::from_dims(&[result_len, index_len]);
1250 Ok((result, shape))
1251 }
1252
1253 #[cfg(feature = "cuda")]
1254 fn cuda_fwd(
1255 &self,
1256 storage: &hanzo_ml::CudaStorage,
1257 layout: &Layout,
1258 ) -> Result<(hanzo_ml::CudaStorage, Shape)> {
1259 if !layout.is_contiguous() {
1260 return Err(hanzo_ml::Error::RequiresContiguous { op: "nonzero" });
1261 }
1262 let dev = storage.device().clone();
1263 let (d_in, _d_in_guard) = match storage.dtype() {
1264 hanzo_ml::DType::U8 => {
1265 let slice = storage.as_cuda_slice::<u8>()?;
1266 let (d_in, d_in_guard) = slice_ptr(slice, 0);
1267 (d_in as *const std::ffi::c_void, d_in_guard)
1268 }
1269 hanzo_ml::DType::U32 => {
1270 let slice = storage.as_cuda_slice::<u32>()?;
1271 let (d_in, d_in_guard) = slice_ptr(slice, 0);
1272 (d_in as *const std::ffi::c_void, d_in_guard)
1273 }
1274 hanzo_ml::DType::I32 => {
1275 let slice = storage.as_cuda_slice::<i32>()?;
1276 let (d_in, d_in_guard) = slice_ptr(slice, 0);
1277 (d_in as *const std::ffi::c_void, d_in_guard)
1278 }
1279 hanzo_ml::DType::I16 => {
1280 let slice = storage.as_cuda_slice::<i16>()?;
1281 let (d_in, d_in_guard) = slice_ptr(slice, 0);
1282 (d_in as *const std::ffi::c_void, d_in_guard)
1283 }
1284 hanzo_ml::DType::I64 => {
1285 let slice = storage.as_cuda_slice::<i64>()?;
1286 let (d_in, d_in_guard) = slice_ptr(slice, 0);
1287 (d_in as *const std::ffi::c_void, d_in_guard)
1288 }
1289 hanzo_ml::DType::BF16 => {
1290 let slice = storage.as_cuda_slice::<half::bf16>()?;
1291 let (d_in, d_in_guard) = slice_ptr(slice, 0);
1292 (d_in as *const std::ffi::c_void, d_in_guard)
1293 }
1294 hanzo_ml::DType::F16 => {
1295 let slice = storage.as_cuda_slice::<half::f16>()?;
1296 let (d_in, d_in_guard) = slice_ptr(slice, 0);
1297 (d_in as *const std::ffi::c_void, d_in_guard)
1298 }
1299 hanzo_ml::DType::F32 => {
1300 let slice = storage.as_cuda_slice::<f32>()?;
1301 let (d_in, d_in_guard) = slice_ptr(slice, 0);
1302 (d_in as *const std::ffi::c_void, d_in_guard)
1303 }
1304 hanzo_ml::DType::F64 => {
1305 let slice = storage.as_cuda_slice::<f64>()?;
1306 let (d_in, d_in_guard) = slice_ptr(slice, 0);
1307 (d_in as *const std::ffi::c_void, d_in_guard)
1308 }
1309 _ => unreachable!(),
1310 };
1311 let n = layout.shape().elem_count();
1312
1313 let num_nonzero = count_nonzero_cuda(
1314 storage.dtype(),
1315 d_in,
1316 u32::try_from(n)?,
1317 dev.cuda_stream().cu_stream(),
1318 );
1319 let d_out = unsafe { dev.alloc::<u32>(num_nonzero as usize * layout.dims().len()) }
1320 .map_err(|_| Error::Msg("Failed to allocate memory for nonzero result".to_string()))?;
1321 if num_nonzero != 0 {
1322 let (d_out, _d_out_guard) = d_out.device_ptr(d_out.stream());
1323 let dims = layout
1324 .dims()
1325 .iter()
1326 .map(|&x| u32::try_from(x).unwrap())
1327 .collect::<Vec<u32>>();
1328 let mut d_dims = unsafe { dev.alloc::<u32>(dims.len()) }?;
1329 dev.memcpy_htod(&dims, &mut d_dims)?;
1330 let (d_dims_ptr, _d_dims_guard) = d_dims.device_ptr(d_dims.stream());
1331 nonzero_cuda(
1332 storage.dtype(),
1333 d_in,
1334 u32::try_from(n)?,
1335 num_nonzero,
1336 d_dims_ptr as *const c_void,
1337 u32::try_from(layout.dims().len())?,
1338 d_out as *mut c_void,
1339 dev.cuda_stream().cu_stream(),
1340 );
1341 }
1342 let shape = Shape::from_dims(&[num_nonzero as usize, layout.dims().len()]);
1343 let dst = hanzo_ml::CudaStorage::wrap_cuda_slice(d_out, dev);
1344 Ok((dst, shape))
1345 }
1346}
1347
/// Extension trait exposing the `NonZero` custom op as a method.
pub trait NonZeroOp {
    /// Runs the `nonzero` op on `self` and returns the resulting tensor.
    ///
    /// The op's forward passes in this file produce a `U32` result with one
    /// row per reported element and one column per input dimension.
    fn nonzero(&self) -> Result<Tensor>;
}
1351
1352impl NonZeroOp for Tensor {
1353 #[cfg(feature = "metal")]
1354 fn nonzero(&self) -> Result<Tensor> {
1355 if !self.is_contiguous() {
1356 return Err(hanzo_ml::Error::RequiresContiguous { op: "nonzero" });
1357 }
1358 let original_device = self.device();
1359 self.to_device(&hanzo_ml::Device::Cpu)?
1360 .apply_op1_no_bwd(&NonZero)?
1361 .to_device(original_device)
1362 }
1363
1364 #[cfg(not(feature = "metal"))]
1365 fn nonzero(&self) -> Result<Tensor> {
1366 if !self.is_contiguous() {
1367 return Err(hanzo_ml::Error::RequiresContiguous { op: "nonzero" });
1368 }
1369 self.apply_op1_no_bwd(&NonZero)
1370 }
1371}
1372
/// Parameters of the cumulative-sum custom op.
struct CumSum {
    // When true each output element includes the input at the same position;
    // when false the running sum is emitted before the input is added.
    inclusive: bool,
    // When true the scan walks the axis from its last element to its first.
    reverse: bool,
    // Index of the dimension that is scanned.
    axis: usize,
}
1378
1379impl CustomOp1 for CumSum {
1380 fn name(&self) -> &'static str {
1381 "cumsum"
1382 }
1383
1384 fn cpu_fwd(&self, s1: &CpuStorage, l1: &Layout) -> Result<(CpuStorage, Shape)> {
1385 use std::ops::Add;
1386 if !l1.is_contiguous() {
1387 hanzo_ml::bail!("Input tensor s1 must be contiguous");
1388 }
1389 let dims = l1.dims();
1390 let axis = self.axis;
1391 let axis_len = dims[axis];
1392 let (start, end) = l1
1393 .contiguous_offsets()
1394 .ok_or(Error::RequiresContiguous { op: "cumsum" })?;
1395
1396 macro_rules! scan_block {
1398 ($vt:ident, $ty:ty, $add:ident, $init:expr) => {{
1399 let vs: &[$ty] = $vt;
1400 let input = &vs[start..end];
1401 let count = input.len() / axis_len;
1402 let mut result = Vec::<$ty>::with_capacity(input.len());
1403 if !self.reverse {
1404 if self.inclusive {
1405 for block in 0..count {
1406 let base = block * axis_len;
1407 let mut sum = input[base];
1408 result.push(sum);
1409 for j in 1..axis_len {
1410 sum = sum.$add(input[base + j]);
1411 result.push(sum);
1412 }
1413 }
1414 } else {
1415 let init: $ty = $init;
1416 for block in 0..count {
1417 let base = block * axis_len;
1418 let mut sum = init;
1419 for j in 0..axis_len {
1420 result.push(sum);
1421 sum = sum.$add(input[base + j]);
1422 }
1423 }
1424 }
1425 } else {
1426 if self.inclusive {
1427 for block in 0..count {
1428 let base = block * axis_len;
1429 let mut temp = Vec::<$ty>::with_capacity(axis_len);
1430 let mut sum = input[base + axis_len - 1];
1431 temp.push(sum);
1432 for k in 1..axis_len {
1433 let idx = axis_len - 1 - k;
1434 sum = sum.$add(input[base + idx]);
1435 temp.push(sum);
1436 }
1437 temp.reverse();
1438 result.extend(temp);
1439 }
1440 } else {
1441 let init: $ty = $init;
1442 for block in 0..count {
1443 let base = block * axis_len;
1444 let mut temp = Vec::<$ty>::with_capacity(axis_len);
1445 let mut sum = init;
1446 for k in 0..axis_len {
1447 let idx = axis_len - 1 - k;
1448 temp.push(sum);
1449 sum = sum.$add(input[base + idx]);
1450 }
1451 temp.reverse();
1452 result.extend(temp);
1453 }
1454 }
1455 }
1456 result
1457 }};
1458 }
1459 match s1 {
1460 CpuStorage::U8(vs) => {
1461 let result = scan_block!(vs, u8, wrapping_add, 0u8);
1462 Ok((CpuStorage::U8(result), l1.shape().clone()))
1463 }
1464 CpuStorage::I16(vs) => {
1465 let result = scan_block!(vs, i16, add, 0i16);
1466 Ok((CpuStorage::I16(result), l1.shape().clone()))
1467 }
1468 CpuStorage::U32(vs) => {
1469 let result = scan_block!(vs, u32, wrapping_add, 0u32);
1470 Ok((CpuStorage::U32(result), l1.shape().clone()))
1471 }
1472 CpuStorage::I32(vs) => {
1473 let result = scan_block!(vs, i32, add, 0i32);
1474 Ok((CpuStorage::I32(result), l1.shape().clone()))
1475 }
1476 CpuStorage::I64(vs) => {
1477 let result = scan_block!(vs, i64, add, 0i64);
1478 Ok((CpuStorage::I64(result), l1.shape().clone()))
1479 }
1480 CpuStorage::F32(vs) => {
1481 let result = scan_block!(vs, f32, add, 0.0f32);
1482 Ok((CpuStorage::F32(result), l1.shape().clone()))
1483 }
1484 CpuStorage::F64(vs) => {
1485 let result = scan_block!(vs, f64, add, 0.0f64);
1486 Ok((CpuStorage::F64(result), l1.shape().clone()))
1487 }
1488 _ => Err(Error::UnsupportedDTypeForOp(DType::F32, "cumsum")),
1489 }
1490 }
1491
1492 #[cfg(feature = "cuda")]
1493 fn cuda_fwd(&self, _s1: &CudaStorage, _l1: &Layout) -> Result<(CudaStorage, Shape)> {
1494 todo!()
1495 }
1496
1497 #[cfg(feature = "metal")]
1498 fn metal_fwd(
1499 &self,
1500 s1: &hanzo_ml::MetalStorage,
1501 l1: &Layout,
1502 ) -> Result<(hanzo_ml::MetalStorage, Shape)> {
1503 use crate::metal_kernels::ScanType;
1504
1505 let encoder = s1.device().command_encoder()?;
1506 encoder.set_label("cumsum");
1507
1508 let device = s1.device();
1509
1510 let out_shape = l1.shape().clone();
1511
1512 let output = device.new_buffer(out_shape.elem_count(), s1.dtype(), "cumsum")?;
1513
1514 crate::metal_kernels::call_scan(
1515 device.device(),
1516 &encoder,
1517 &crate::metal_kernels::Kernels::new(),
1518 s1.dtype(),
1519 ScanType::Sum,
1520 s1.buffer(),
1521 l1.start_offset() * s1.dtype().size_in_bytes(),
1522 self.axis,
1523 l1.dims(),
1524 l1.stride(),
1525 self.reverse,
1526 self.inclusive,
1527 &output,
1528 )
1529 .map_err(hanzo_ml::Error::wrap)?;
1530
1531 let newstorage =
1532 hanzo_ml::MetalStorage::new(output, device.clone(), out_shape.elem_count(), s1.dtype());
1533 Ok((newstorage, out_shape))
1534 }
1535}
1536
/// Extension trait exposing the `CumSum` custom op as methods.
#[allow(dead_code)]
pub trait CumSumOp {
    /// Cumulative sum along `axis` with default settings. The `Tensor`
    /// implementation in this file forwards to `fast_cumsum_config` with
    /// `inclusive = false` and `reverse = false`.
    fn fast_cumsum<D: Dim>(&self, axis: D) -> Result<Tensor>;

    /// Cumulative sum along `axis`.
    ///
    /// `inclusive` selects whether each output includes the input at the same
    /// position; `reverse` selects whether the axis is scanned from its last
    /// element towards its first.
    fn fast_cumsum_config<D: Dim>(&self, axis: D, inclusive: bool, reverse: bool)
        -> Result<Tensor>;
}
1545
1546impl CumSumOp for Tensor {
1547 fn fast_cumsum<D: Dim>(&self, axis: D) -> Result<Tensor> {
1548 self.fast_cumsum_config(axis, false, false)
1549 }
1550
1551 fn fast_cumsum_config<D: Dim>(
1552 &self,
1553 axis: D,
1554 inclusive: bool,
1555 reverse: bool,
1556 ) -> Result<Tensor> {
1557 self.apply_op1_no_bwd(&CumSum {
1558 inclusive,
1559 reverse,
1560 axis: axis.to_index(self.shape(), "cumsum")?,
1561 })
1562 }
1563}
1564
/// Runs the fused `gptoss_swiglu` CUDA kernel on `gate` and `up`.
///
/// Both inputs are made contiguous and must share a shape. The output is a
/// newly allocated CUDA tensor with `gate`'s shape and dtype. `alpha` and
/// `limit` are passed through to the kernel unchanged.
///
/// # Errors
/// * Shapes differ, `gate` is not on a CUDA device, or either storage is not
///   CUDA storage.
/// * The dtype is not `F16`, `BF16` or `F32`.
/// * The element count does not fit in `u32` (the kernel's size argument).
#[cfg(feature = "cuda")]
pub fn gptoss_swiglu_fused(gate: &Tensor, up: &Tensor, alpha: f32, limit: f32) -> Result<Tensor> {
    use half::{bf16, f16};

    let gate = gate.contiguous()?;
    let up = up.contiguous()?;

    if gate.shape() != up.shape() {
        hanzo_ml::bail!(
            "gptoss_swiglu: gate and up must have same shape, got {:?} vs {:?}",
            gate.shape(),
            up.shape()
        );
    }

    let device = match gate.device() {
        hanzo_ml::Device::Cuda(dev) => dev,
        _ => hanzo_ml::bail!("gptoss_swiglu requires CUDA device"),
    };

    let dtype = gate.dtype();
    let n_elements = gate.elem_count();
    // The kernel takes a u32 count; refuse sizes that would be truncated.
    let n_kernel = u32::try_from(n_elements)?;

    let (gate_storage, gate_layout) = gate.storage_and_layout();
    let (up_storage, up_layout) = up.storage_and_layout();
    // A contiguous tensor can still start part-way into its storage, so each
    // input pointer has to be advanced by its layout's start offset.
    let gate_offset = gate_layout.start_offset();
    let up_offset = up_layout.start_offset();

    let gate_cuda = match &*gate_storage {
        hanzo_ml::Storage::Cuda(s) => s,
        _ => hanzo_ml::bail!("Expected CUDA storage for gate"),
    };
    let up_cuda = match &*up_storage {
        hanzo_ml::Storage::Cuda(s) => s,
        _ => hanzo_ml::bail!("Expected CUDA storage for up"),
    };

    let stream = device.cuda_stream().cu_stream();

    // Allocate the output, launch the kernel for one element type and wrap the
    // buffer as CUDA storage. The pointer guards live in the inner block so
    // they are released before `output` is moved.
    macro_rules! launch {
        ($ty:ty, $kernel:ident) => {{
            let output = device.alloc_zeros::<$ty>(n_elements)?;
            let gate_slice = gate_cuda.as_cuda_slice::<$ty>()?;
            let up_slice = up_cuda.as_cuda_slice::<$ty>()?;
            {
                let (gate_ptr, _g_guard) = slice_ptr(gate_slice, gate_offset);
                let (up_ptr, _u_guard) = slice_ptr(up_slice, up_offset);
                let (out_ptr, _o_guard) = slice_ptr(&output, 0);

                // SAFETY: relies on the three buffers holding at least
                // `n_kernel` elements of this type from the given pointers;
                // the inputs are contiguous with equal shapes and the output
                // was allocated with `n_elements` just above.
                unsafe {
                    ffi::$kernel(
                        gate_ptr as *const c_void,
                        up_ptr as *const c_void,
                        out_ptr as *mut c_void,
                        n_kernel,
                        alpha,
                        limit,
                        stream,
                    );
                }
            }
            CudaStorage::wrap_cuda_slice(output, device.clone())
        }};
    }

    let out_storage = match dtype {
        DType::F16 => launch!(f16, gptoss_swiglu_f16),
        DType::BF16 => launch!(bf16, gptoss_swiglu_bf16),
        DType::F32 => launch!(f32, gptoss_swiglu_f32),
        _ => hanzo_ml::bail!("gptoss_swiglu: unsupported dtype {:?}", dtype),
    };

    Ok(Tensor::from((
        hanzo_ml::Storage::Cuda(out_storage),
        gate.shape().clone(),
    )))
}
1693
/// Runs the `gptoss_swiglu_interleaved` CUDA kernel on a combined gate/up
/// tensor.
///
/// `gate_up` is made contiguous and must be shaped
/// `[N, intermediate_size, 2]`. The output is a newly allocated CUDA tensor
/// shaped `[N, intermediate_size]` with the input's dtype. `alpha` and `limit`
/// are passed through to the kernel unchanged.
///
/// # Errors
/// * The shape is not `[N, intermediate_size, 2]`.
/// * `gate_up` is not on a CUDA device or its storage is not CUDA storage.
/// * The dtype is not `F16`, `BF16` or `F32`.
/// * `N` or `intermediate_size` does not fit in `u32`.
#[cfg(feature = "cuda")]
pub fn gptoss_swiglu_interleaved(
    gate_up: &Tensor,
    intermediate_size: usize,
    alpha: f32,
    limit: f32,
) -> Result<Tensor> {
    use half::{bf16, f16};
    use std::ffi::c_void;

    let gate_up = gate_up.contiguous()?;

    let dims = gate_up.dims();
    if dims.len() != 3 || dims[2] != 2 {
        hanzo_ml::bail!(
            "gptoss_swiglu_interleaved: expected gate_up shape [N, intermediate_size, 2], got {:?}",
            dims
        );
    }
    // The kernel is told the row width separately from the tensor; if the two
    // disagree it would index the buffer with the wrong row stride.
    if dims[1] != intermediate_size {
        hanzo_ml::bail!(
            "gptoss_swiglu_interleaved: intermediate_size {} does not match gate_up shape {:?}",
            intermediate_size,
            dims
        );
    }

    let n = dims[0];
    let device = match gate_up.device() {
        hanzo_ml::Device::Cuda(dev) => dev,
        _ => hanzo_ml::bail!("gptoss_swiglu_interleaved requires CUDA device"),
    };

    let dtype = gate_up.dtype();
    let n_output_elements = n * intermediate_size;
    // The kernel takes u32 sizes; refuse values that would be truncated.
    let n_kernel = u32::try_from(n)?;
    let width_kernel = u32::try_from(intermediate_size)?;

    let (gate_up_storage, gate_up_layout) = gate_up.storage_and_layout();
    // A contiguous tensor can still start part-way into its storage.
    let gate_up_offset = gate_up_layout.start_offset();
    let gate_up_cuda = match &*gate_up_storage {
        hanzo_ml::Storage::Cuda(s) => s,
        _ => hanzo_ml::bail!("Expected CUDA storage for gate_up"),
    };

    let stream = device.cuda_stream().cu_stream();

    // Allocate the output, launch the kernel for one element type and wrap the
    // buffer as CUDA storage. The pointer guards live in the inner block so
    // they are released before `output` is moved.
    macro_rules! launch {
        ($ty:ty, $kernel:ident) => {{
            let output = device.alloc_zeros::<$ty>(n_output_elements)?;
            let gate_up_slice = gate_up_cuda.as_cuda_slice::<$ty>()?;
            {
                let (gate_up_ptr, _gu_guard) = slice_ptr(gate_up_slice, gate_up_offset);
                let (out_ptr, _o_guard) = slice_ptr(&output, 0);

                // SAFETY: relies on the input holding `n * intermediate_size * 2`
                // elements from `gate_up_ptr` (shape validated above) and on the
                // output holding `n * intermediate_size` elements (allocated
                // just above).
                unsafe {
                    ffi::$kernel(
                        gate_up_ptr as *const c_void,
                        out_ptr as *mut c_void,
                        n_kernel,
                        width_kernel,
                        alpha,
                        limit,
                        stream,
                    );
                }
            }
            CudaStorage::wrap_cuda_slice(output, device.clone())
        }};
    }

    let out_storage = match dtype {
        DType::F16 => launch!(f16, gptoss_swiglu_interleaved_f16),
        DType::BF16 => launch!(bf16, gptoss_swiglu_interleaved_bf16),
        DType::F32 => launch!(f32, gptoss_swiglu_interleaved_f32),
        _ => hanzo_ml::bail!("gptoss_swiglu_interleaved: unsupported dtype {:?}", dtype),
    };

    Ok(Tensor::from((
        hanzo_ml::Storage::Cuda(out_storage),
        Shape::from(vec![n, intermediate_size]),
    )))
}
1824
/// Parameters of the softmax-with-sinks custom op.
struct SoftmaxWithSinks {
    // One sink value per head; the CPU path indexes it by head and folds it
    // into both the row maximum and the normalisation sum.
    sinks: Tensor,
    // Number of heads; used to derive the head index of a row.
    num_heads: usize,
    // Number of rows per head; used to derive the head index of a row.
    q_len: usize,
    // Length of each row that is normalised.
    k_len: usize,
}
1842
1843impl CustomOp1 for SoftmaxWithSinks {
1844 fn name(&self) -> &'static str {
1845 "softmax-with-sinks"
1846 }
1847
1848 fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)> {
1849 use half::{bf16, f16};
1850
1851 let out_shape = layout.shape().clone();
1852 let total_rows = out_shape.elem_count() / self.k_len;
1853 let k_len = self.k_len;
1854 let num_heads = self.num_heads;
1855 let q_len = self.q_len;
1856 let offset = layout.start_offset();
1857
1858 let sinks_data = self.sinks.storage_and_layout();
1859 let sinks_cpu = match &*sinks_data.0 {
1860 hanzo_ml::Storage::Cpu(s) => s,
1861 _ => hanzo_ml::bail!("softmax_with_sinks cpu_fwd: sinks must be on CPU"),
1862 };
1863 let sinks_offset = sinks_data.1.start_offset();
1864
1865 match storage.dtype() {
1866 DType::F32 => {
1867 let logits = storage.as_slice::<f32>()?;
1868 let sinks_vals = sinks_cpu.as_slice::<f32>()?;
1869
1870 let mut result = vec![0f32; total_rows * k_len];
1871 result
1872 .par_chunks_mut(k_len)
1873 .enumerate()
1874 .for_each(|(row, out_row)| {
1875 let h = (row / q_len) % num_heads;
1876 let sink_val = sinks_vals[sinks_offset + h];
1877 let row_start = offset + row * k_len;
1878
1879 let mut max_val = sink_val;
1880 for k in 0..k_len {
1881 let v = logits[row_start + k];
1882 if v > max_val {
1883 max_val = v;
1884 }
1885 }
1886
1887 let mut sum = (sink_val - max_val).exp();
1888 for k in 0..k_len {
1889 let e = (logits[row_start + k] - max_val).exp();
1890 out_row[k] = e;
1891 sum += e;
1892 }
1893
1894 let inv_sum = 1.0 / sum;
1895 for item in out_row.iter_mut().take(k_len) {
1896 *item *= inv_sum;
1897 }
1898 });
1899
1900 Ok((CpuStorage::F32(result), out_shape))
1901 }
1902 DType::F16 => {
1903 let logits = storage.as_slice::<f16>()?;
1904 let sinks_vals = sinks_cpu.as_slice::<f16>()?;
1905
1906 let mut result = vec![f16::ZERO; total_rows * k_len];
1907 result
1908 .par_chunks_mut(k_len)
1909 .enumerate()
1910 .for_each(|(row, out_row)| {
1911 let h = (row / q_len) % num_heads;
1912 let sink_val = sinks_vals[sinks_offset + h].to_f32();
1913 let row_start = offset + row * k_len;
1914
1915 let mut max_val = sink_val;
1916 for k in 0..k_len {
1917 let v = logits[row_start + k].to_f32();
1918 if v > max_val {
1919 max_val = v;
1920 }
1921 }
1922
1923 let mut sum = (sink_val - max_val).exp();
1924 for k in 0..k_len {
1925 let e = (logits[row_start + k].to_f32() - max_val).exp();
1926 out_row[k] = f16::from_f32(e);
1927 sum += e;
1928 }
1929
1930 let inv_sum = 1.0f32 / sum;
1931 for item in out_row.iter_mut().take(k_len) {
1932 *item = f16::from_f32(item.to_f32() * inv_sum);
1933 }
1934 });
1935
1936 Ok((CpuStorage::F16(result), out_shape))
1937 }
1938 DType::BF16 => {
1939 let logits = storage.as_slice::<bf16>()?;
1940 let sinks_vals = sinks_cpu.as_slice::<bf16>()?;
1941
1942 let mut result = vec![bf16::ZERO; total_rows * k_len];
1943 result
1944 .par_chunks_mut(k_len)
1945 .enumerate()
1946 .for_each(|(row, out_row)| {
1947 let h = (row / q_len) % num_heads;
1948 let sink_val = sinks_vals[sinks_offset + h].to_f32();
1949 let row_start = offset + row * k_len;
1950
1951 let mut max_val = sink_val;
1952 for k in 0..k_len {
1953 let v = logits[row_start + k].to_f32();
1954 if v > max_val {
1955 max_val = v;
1956 }
1957 }
1958
1959 let mut sum = (sink_val - max_val).exp();
1960 for k in 0..k_len {
1961 let e = (logits[row_start + k].to_f32() - max_val).exp();
1962 out_row[k] = bf16::from_f32(e);
1963 sum += e;
1964 }
1965
1966 let inv_sum = 1.0f32 / sum;
1967 for item in out_row.iter_mut().take(k_len) {
1968 *item = bf16::from_f32(item.to_f32() * inv_sum);
1969 }
1970 });
1971
1972 Ok((CpuStorage::BF16(result), out_shape))
1973 }
1974 other => hanzo_ml::bail!("softmax_with_sinks: unsupported dtype {:?}", other),
1975 }
1976 }
1977
1978 #[cfg(feature = "cuda")]
1979 fn cuda_fwd(&self, storage: &CudaStorage, layout: &Layout) -> Result<(CudaStorage, Shape)> {
1980 use half::{bf16, f16};
1981
1982 let device = storage.device();
1983 let dtype = storage.dtype();
1984 let n_elements = layout.shape().elem_count();
1985 let out_shape = layout.shape().clone();
1986 let stream = device.cuda_stream().cu_stream();
1987 let logits_offset = layout.start_offset();
1988
1989 let batch_size = out_shape.dims()[0];
1990
1991 let sinks_data = self.sinks.storage_and_layout();
1992 let sinks_cuda = match &*sinks_data.0 {
1993 hanzo_ml::Storage::Cuda(s) => s,
1994 _ => hanzo_ml::bail!("softmax_with_sinks cuda_fwd: sinks must be on CUDA"),
1995 };
1996 let sinks_offset = sinks_data.1.start_offset();
1997
1998 match dtype {
1999 DType::F16 => {
2000 let output = device.alloc_zeros::<f16>(n_elements)?;
2001 let logits_slice = storage.as_cuda_slice::<f16>()?;
2002 let sinks_slice = sinks_cuda.as_cuda_slice::<f16>()?;
2003
2004 let (logits_ptr, _l_guard) = slice_ptr(logits_slice, logits_offset);
2005 let (sinks_ptr, _s_guard) = slice_ptr(sinks_slice, sinks_offset);
2006 let (out_ptr, _o_guard) = slice_ptr(&output, 0);
2007
2008 unsafe {
2009 ffi::softmax_with_sinks_f16(
2010 logits_ptr as *const c_void,
2011 sinks_ptr as *const c_void,
2012 std::ptr::null(), out_ptr as *mut c_void,
2014 batch_size as i32,
2015 self.num_heads as i32,
2016 self.q_len as i32,
2017 self.k_len as i32,
2018 1.0,
2019 stream,
2020 );
2021 }
2022
2023 drop(_o_guard);
2024 let out_storage = CudaStorage::wrap_cuda_slice(output, device.clone());
2025 Ok((out_storage, out_shape))
2026 }
2027 DType::BF16 => {
2028 let output = device.alloc_zeros::<bf16>(n_elements)?;
2029 let logits_slice = storage.as_cuda_slice::<bf16>()?;
2030 let sinks_slice = sinks_cuda.as_cuda_slice::<bf16>()?;
2031
2032 let (logits_ptr, _l_guard) = slice_ptr(logits_slice, logits_offset);
2033 let (sinks_ptr, _s_guard) = slice_ptr(sinks_slice, sinks_offset);
2034 let (out_ptr, _o_guard) = slice_ptr(&output, 0);
2035
2036 unsafe {
2037 ffi::softmax_with_sinks_bf16(
2038 logits_ptr as *const c_void,
2039 sinks_ptr as *const c_void,
2040 std::ptr::null(),
2041 out_ptr as *mut c_void,
2042 batch_size as i32,
2043 self.num_heads as i32,
2044 self.q_len as i32,
2045 self.k_len as i32,
2046 1.0,
2047 stream,
2048 );
2049 }
2050
2051 drop(_o_guard);
2052 let out_storage = CudaStorage::wrap_cuda_slice(output, device.clone());
2053 Ok((out_storage, out_shape))
2054 }
2055 DType::F32 => {
2056 let output = device.alloc_zeros::<f32>(n_elements)?;
2057 let logits_slice = storage.as_cuda_slice::<f32>()?;
2058 let sinks_slice = sinks_cuda.as_cuda_slice::<f32>()?;
2059
2060 let (logits_ptr, _l_guard) = slice_ptr(logits_slice, logits_offset);
2061 let (sinks_ptr, _s_guard) = slice_ptr(sinks_slice, sinks_offset);
2062 let (out_ptr, _o_guard) = slice_ptr(&output, 0);
2063
2064 unsafe {
2065 ffi::softmax_with_sinks_f32(
2066 logits_ptr as *const c_void,
2067 sinks_ptr as *const c_void,
2068 std::ptr::null(),
2069 out_ptr as *mut c_void,
2070 batch_size as i32,
2071 self.num_heads as i32,
2072 self.q_len as i32,
2073 self.k_len as i32,
2074 1.0,
2075 stream,
2076 );
2077 }
2078
2079 drop(_o_guard);
2080 let out_storage = CudaStorage::wrap_cuda_slice(output, device.clone());
2081 Ok((out_storage, out_shape))
2082 }
2083 _ => hanzo_ml::bail!("softmax_with_sinks: unsupported dtype {:?}", dtype),
2084 }
2085 }
2086
2087 #[cfg(feature = "metal")]
2088 fn metal_fwd(
2089 &self,
2090 storage: &hanzo_ml::MetalStorage,
2091 layout: &Layout,
2092 ) -> Result<(hanzo_ml::MetalStorage, Shape)> {
2093 let dtype = storage.dtype();
2094 let n_elements = layout.shape().elem_count();
2095 let out_shape = layout.shape().clone();
2096 let total_rows = n_elements / self.k_len;
2097
2098 let device = storage.device();
2099 let encoder = device.command_encoder()?;
2100 encoder.set_label("softmax-with-sinks");
2101
2102 let output = device.new_buffer(n_elements, dtype, "softmax-with-sinks-output")?;
2103
2104 let sinks_data = self.sinks.storage_and_layout();
2105 let sinks_metal = match &*sinks_data.0 {
2106 hanzo_ml::Storage::Metal(s) => s,
2107 _ => hanzo_ml::bail!("softmax_with_sinks metal_fwd: sinks must be on Metal"),
2108 };
2109 let sinks_offset = sinks_data.1.start_offset() * self.sinks.dtype().size_in_bytes();
2110
2111 crate::metal_kernels::call_softmax_with_sinks(
2112 device.device(),
2113 &encoder,
2114 &crate::metal_kernels::Kernels::new(),
2115 dtype,
2116 storage.buffer(),
2117 layout.start_offset() * dtype.size_in_bytes(),
2118 sinks_metal.buffer(),
2119 sinks_offset,
2120 &output,
2121 self.num_heads as u32,
2122 self.q_len as u32,
2123 self.k_len as u32,
2124 total_rows,
2125 )
2126 .map_err(hanzo_ml::Error::wrap)?;
2127
2128 let newstorage = hanzo_ml::MetalStorage::new(output, device.clone(), n_elements, dtype);
2129 Ok((newstorage, out_shape))
2130 }
2131}
2132
2133pub fn softmax_with_sinks(
2134 logits: &Tensor,
2135 sinks: &Tensor,
2136 mask: Option<&Tensor>,
2137) -> Result<Tensor> {
2138 let logits = if let Some(mask) = mask {
2139 logits.broadcast_add(mask)?
2140 } else {
2141 logits.clone()
2142 };
2143 let logits = logits.contiguous()?;
2144 let sinks = sinks.contiguous()?;
2145
2146 let dims = logits.dims();
2147 if dims.len() != 4 {
2148 hanzo_ml::bail!(
2149 "softmax_with_sinks: expected logits to have 4 dims [b, h, q, k], got {:?}",
2150 dims
2151 );
2152 }
2153
2154 let num_heads = dims[1];
2155 let q_len = dims[2];
2156 let k_len = dims[3];
2157
2158 if sinks.dims() != [num_heads] {
2159 hanzo_ml::bail!(
2160 "softmax_with_sinks: expected sinks shape [{}], got {:?}",
2161 num_heads,
2162 sinks.dims()
2163 );
2164 }
2165
2166 logits.apply_op1_no_bwd(&SoftmaxWithSinks {
2167 sinks: sinks.clone(),
2168 num_heads,
2169 q_len,
2170 k_len,
2171 })
2172}
2173
/// Parameters of the Metal flash-attention-with-sinks custom op. The query is
/// the tensor the op is applied to; the remaining inputs are carried here.
#[allow(dead_code)]
struct FlashAttnSinksMetal {
    // Key tensor; `metal_fwd` reads it as 4-D and requires Metal storage.
    key: Tensor,
    // Value tensor; `metal_fwd` requires Metal storage.
    value: Tensor,
    // Sink values; `metal_fwd` requires Metal storage.
    sinks: Tensor,
    // Scale forwarded unchanged to the Metal kernels.
    softmax_scale: f32,
    // Forwarded only to the prefill kernel (the `q_len != 1` branch).
    window_size: usize,
}
2186
impl CustomOp1 for FlashAttnSinksMetal {
    fn name(&self) -> &'static str {
        "flash-attn-sinks-metal"
    }

    /// There is no CPU implementation; this always returns an error.
    fn cpu_fwd(&self, _storage: &CpuStorage, _layout: &Layout) -> Result<(CpuStorage, Shape)> {
        hanzo_ml::bail!("flash_attn_sinks_metal: no CPU support, use softmax_with_sinks fallback")
    }

    /// Metal forward pass.
    ///
    /// The query layout must be 4-D and the output has the query's shape and
    /// dtype. One of three kernels is selected:
    /// * `q_len == 1` and `k_len >= 1024`: `call_sdpa_vector_with_sinks_2pass`,
    ///   with extra `F32` scratch buffers;
    /// * `q_len == 1` otherwise: `call_sdpa_vector_with_sinks`;
    /// * any other `q_len`: `call_flash_attn_sinks_prefill`.
    ///
    /// # Errors
    /// Fails when the query or key layout is not 4-D, when key, value or sinks
    /// is not a Metal tensor, or when buffer allocation or a kernel call fails.
    #[cfg(feature = "metal")]
    fn metal_fwd(
        &self,
        q_storage: &hanzo_ml::MetalStorage,
        q_layout: &Layout,
    ) -> Result<(hanzo_ml::MetalStorage, Shape)> {
        let dtype = q_storage.dtype();
        let out_shape = q_layout.shape().clone();
        let (batch_size, num_heads, q_len, head_dim) = q_layout.shape().dims4()?;

        let (k_s, k_l) = self.key.storage_and_layout();
        let k_metal = match &*k_s {
            hanzo_ml::Storage::Metal(s) => s,
            _ => hanzo_ml::bail!("flash_attn_sinks_metal: key must be a Metal tensor"),
        };
        // Only dimensions 1 and 2 of the key shape are used here.
        let (_, num_kv_heads, k_len, _) = k_l.shape().dims4()?;

        let (v_s, v_l) = self.value.storage_and_layout();
        let v_metal = match &*v_s {
            hanzo_ml::Storage::Metal(s) => s,
            _ => hanzo_ml::bail!("flash_attn_sinks_metal: value must be a Metal tensor"),
        };

        let (s_s, s_l) = self.sinks.storage_and_layout();
        let sinks_metal = match &*s_s {
            hanzo_ml::Storage::Metal(s) => s,
            _ => hanzo_ml::bail!("flash_attn_sinks_metal: sinks must be a Metal tensor"),
        };
        // Byte offset, scaled by the sinks tensor's own element size.
        let sinks_offset = s_l.start_offset() * self.sinks.dtype().size_in_bytes();

        let device = q_storage.device();
        let elem_count = out_shape.elem_count();
        let output = device.new_buffer(elem_count, dtype, "flash-attn-sinks-output")?;

        let encoder = device.command_encoder()?;
        encoder.set_label("flash-attn-sinks");

        let kernels = crate::metal_kernels::Kernels::new();

        // Byte offsets; key and value are scaled by the query's element size.
        // NOTE(review): presumably key/value share the query dtype — confirm.
        let q_offset = q_layout.start_offset() * dtype.size_in_bytes();
        let k_offset = k_l.start_offset() * dtype.size_in_bytes();
        let v_offset = v_l.start_offset() * dtype.size_in_bytes();

        if q_len == 1 {
            // Query heads per key/value head (integer division).
            let gqa_factor = (num_heads / num_kv_heads) as i32;
            let b = batch_size * num_heads;

            // Strides of dimension 1 of the key and value layouts.
            let k_stride = k_l.stride()[1];
            let v_stride = v_l.stride()[1];

            let two_pass_threshold = 1024;
            if k_len >= two_pass_threshold {
                // Scratch buffers handed to the two-pass kernel, sized from
                // `b` and a fixed block count.
                let blocks: usize = 32;
                let intermediate = device.new_buffer(
                    b * blocks * head_dim,
                    DType::F32,
                    "sdpa-sinks-intermediate",
                )?;
                let sums = device.new_buffer(b * blocks, DType::F32, "sdpa-sinks-sums")?;
                let maxs = device.new_buffer(b * blocks, DType::F32, "sdpa-sinks-maxs")?;

                crate::metal_kernels::call_sdpa_vector_with_sinks_2pass(
                    device.device(),
                    &encoder,
                    &kernels,
                    dtype,
                    q_storage.buffer(),
                    q_offset,
                    k_metal.buffer(),
                    k_offset,
                    v_metal.buffer(),
                    v_offset,
                    sinks_metal.buffer(),
                    sinks_offset,
                    &output,
                    &intermediate,
                    &sums,
                    &maxs,
                    head_dim,
                    gqa_factor,
                    k_len as i32,
                    k_stride,
                    v_stride,
                    self.softmax_scale,
                    b,
                )
                .map_err(hanzo_ml::Error::wrap)?;
            } else {
                crate::metal_kernels::call_sdpa_vector_with_sinks(
                    device.device(),
                    &encoder,
                    &kernels,
                    dtype,
                    q_storage.buffer(),
                    q_offset,
                    k_metal.buffer(),
                    k_offset,
                    v_metal.buffer(),
                    v_offset,
                    sinks_metal.buffer(),
                    sinks_offset,
                    &output,
                    head_dim,
                    gqa_factor,
                    k_len as i32,
                    k_stride,
                    v_stride,
                    self.softmax_scale,
                    b,
                )
                .map_err(hanzo_ml::Error::wrap)?;
            }
        } else {
            // Any other query length goes to the prefill kernel, the only one
            // that receives `window_size`.
            crate::metal_kernels::call_flash_attn_sinks_prefill(
                device.device(),
                &encoder,
                &kernels,
                dtype,
                q_storage.buffer(),
                q_offset,
                k_metal.buffer(),
                k_offset,
                v_metal.buffer(),
                v_offset,
                sinks_metal.buffer(),
                sinks_offset,
                &output,
                self.softmax_scale,
                batch_size,
                q_len,
                k_len,
                num_heads,
                num_kv_heads,
                head_dim,
                self.window_size,
            )
            .map_err(hanzo_ml::Error::wrap)?;
        }

        let newstorage = hanzo_ml::MetalStorage::new(output, device.clone(), elem_count, dtype);
        Ok((newstorage, out_shape))
    }
}
2348
2349#[allow(clippy::too_many_arguments)]
2369pub fn flash_attn_sinks_metal(
2370 q: &Tensor,
2371 k: &Tensor,
2372 v: &Tensor,
2373 sinks: Option<&Tensor>,
2374 softmax_scale: f32,
2375 window_size: usize,
2376) -> Result<Tensor> {
2377 let q = q.contiguous()?;
2378 let k = k.contiguous()?;
2379 let v = v.contiguous()?;
2380
2381 let sinks = match sinks {
2382 Some(s) => s.to_dtype(DType::F32)?.contiguous()?,
2383 None => {
2384 let num_heads = q.dim(1)?;
2386 Tensor::zeros(num_heads, DType::F32, q.device())?
2387 }
2388 };
2389
2390 let op = FlashAttnSinksMetal {
2391 key: k.clone(),
2392 value: v.clone(),
2393 sinks,
2394 softmax_scale,
2395 window_size,
2396 };
2397 q.apply_op1_no_bwd(&op)
2398}
2399
/// Parameters captured by the variable-length Metal flash-attention-with-sinks op.
///
/// The query is not stored here: it is the tensor the op is applied to.
#[allow(dead_code)]
struct FlashAttnSinksVarlenMetal {
    /// Key tensor; the Metal forward reads it as rank 3 and takes the number of
    /// KV heads from its second dimension.
    key: Tensor,
    /// Value tensor.
    value: Tensor,
    /// Sink values; `flash_attn_sinks_varlen_metal` builds this as a contiguous `F32` tensor.
    sinks: Tensor,
    /// Query-side sequence-length tensor. Its byte offset is computed with the `U32`
    /// element size. NOTE(review): presumably cumulative sequence lengths stored as
    /// U32 - confirm against callers.
    cu_seqlens_q: Tensor,
    /// Key-side counterpart of `cu_seqlens_q` (same `U32` offset arithmetic).
    cu_seqlens_k: Tensor,
    /// Softmax scale forwarded to the kernel unchanged.
    softmax_scale: f32,
    /// Window size forwarded to the kernel unchanged.
    window_size: usize,
}
2410
2411impl CustomOp1 for FlashAttnSinksVarlenMetal {
2412 fn name(&self) -> &'static str {
2413 "flash-attn-sinks-varlen-metal"
2414 }
2415
2416 fn cpu_fwd(&self, _storage: &CpuStorage, _layout: &Layout) -> Result<(CpuStorage, Shape)> {
2417 hanzo_ml::bail!("flash_attn_sinks_varlen_metal: no CPU support")
2418 }
2419
2420 #[cfg(feature = "metal")]
2421 fn metal_fwd(
2422 &self,
2423 q_storage: &hanzo_ml::MetalStorage,
2424 q_layout: &Layout,
2425 ) -> Result<(hanzo_ml::MetalStorage, Shape)> {
2426 let dtype = q_storage.dtype();
2427 let out_shape = q_layout.shape().clone();
2428 let (batch_size, num_heads, max_q_len, head_dim) = q_layout.shape().dims4()?;
2429
2430 let (k_s, k_l) = self.key.storage_and_layout();
2432 let k_metal = match &*k_s {
2433 hanzo_ml::Storage::Metal(s) => s,
2434 _ => hanzo_ml::bail!("flash_attn_sinks_varlen_metal: key must be a Metal tensor"),
2435 };
2436 let (_, num_kv_heads, _) = k_l.shape().dims3()?;
2437
2438 let (v_s, v_l) = self.value.storage_and_layout();
2440 let v_metal = match &*v_s {
2441 hanzo_ml::Storage::Metal(s) => s,
2442 _ => hanzo_ml::bail!("flash_attn_sinks_varlen_metal: value must be a Metal tensor"),
2443 };
2444
2445 let (s_s, s_l) = self.sinks.storage_and_layout();
2447 let sinks_metal = match &*s_s {
2448 hanzo_ml::Storage::Metal(s) => s,
2449 _ => hanzo_ml::bail!("flash_attn_sinks_varlen_metal: sinks must be a Metal tensor"),
2450 };
2451 let sinks_offset = s_l.start_offset() * self.sinks.dtype().size_in_bytes();
2452
2453 let (csq_s, csq_l) = self.cu_seqlens_q.storage_and_layout();
2455 let csq_metal = match &*csq_s {
2456 hanzo_ml::Storage::Metal(s) => s,
2457 _ => hanzo_ml::bail!(
2458 "flash_attn_sinks_varlen_metal: cu_seqlens_q must be a Metal tensor"
2459 ),
2460 };
2461 let csq_offset = csq_l.start_offset() * DType::U32.size_in_bytes();
2462
2463 let (csk_s, csk_l) = self.cu_seqlens_k.storage_and_layout();
2465 let csk_metal = match &*csk_s {
2466 hanzo_ml::Storage::Metal(s) => s,
2467 _ => hanzo_ml::bail!(
2468 "flash_attn_sinks_varlen_metal: cu_seqlens_k must be a Metal tensor"
2469 ),
2470 };
2471 let csk_offset = csk_l.start_offset() * DType::U32.size_in_bytes();
2472
2473 let device = q_storage.device();
2474 let elem_count = out_shape.elem_count();
2475 let output = device.new_buffer(elem_count, dtype, "flash-attn-sinks-varlen-output")?;
2476
2477 let encoder = device.command_encoder()?;
2478 encoder.set_label("flash-attn-sinks-varlen");
2479
2480 let kernels = crate::metal_kernels::Kernels::new();
2481
2482 let q_offset = q_layout.start_offset() * dtype.size_in_bytes();
2483 let k_offset = k_l.start_offset() * dtype.size_in_bytes();
2484 let v_offset = v_l.start_offset() * dtype.size_in_bytes();
2485
2486 crate::metal_kernels::call_flash_attn_sinks_varlen_prefill(
2487 device.device(),
2488 &encoder,
2489 &kernels,
2490 dtype,
2491 q_storage.buffer(),
2492 q_offset,
2493 k_metal.buffer(),
2494 k_offset,
2495 v_metal.buffer(),
2496 v_offset,
2497 sinks_metal.buffer(),
2498 sinks_offset,
2499 &output,
2500 csq_metal.buffer(),
2501 csq_offset,
2502 csk_metal.buffer(),
2503 csk_offset,
2504 self.softmax_scale,
2505 batch_size,
2506 max_q_len,
2507 num_heads,
2508 num_kv_heads,
2509 head_dim,
2510 self.window_size,
2511 )
2512 .map_err(hanzo_ml::Error::wrap)?;
2513
2514 let newstorage = hanzo_ml::MetalStorage::new(output, device.clone(), elem_count, dtype);
2515 Ok((newstorage, out_shape))
2516 }
2517}
2518
2519#[allow(clippy::too_many_arguments)]
2537pub fn flash_attn_sinks_varlen_metal(
2538 q: &Tensor,
2539 k: &Tensor,
2540 v: &Tensor,
2541 sinks: Option<&Tensor>,
2542 cu_seqlens_q: &Tensor,
2543 cu_seqlens_k: &Tensor,
2544 softmax_scale: f32,
2545 window_size: usize,
2546) -> Result<Tensor> {
2547 let q = q.contiguous()?;
2548 let k = k.contiguous()?;
2549 let v = v.contiguous()?;
2550
2551 let sinks = match sinks {
2552 Some(s) => s.to_dtype(DType::F32)?.contiguous()?,
2553 None => {
2554 let num_heads = q.dim(1)?;
2555 Tensor::zeros(num_heads, DType::F32, q.device())?
2556 }
2557 };
2558
2559 let op = FlashAttnSinksVarlenMetal {
2560 key: k.clone(),
2561 value: v.clone(),
2562 sinks,
2563 cu_seqlens_q: cu_seqlens_q.clone(),
2564 cu_seqlens_k: cu_seqlens_k.clone(),
2565 softmax_scale,
2566 window_size,
2567 };
2568 q.apply_op1_no_bwd(&op)
2569}
2570
/// Activation applied to the first operand of a fused GLU.
///
/// The enum is `#[repr(i32)]` and its value is cast to `i32` when handed to the
/// CUDA and Metal kernels, so the discriminants below must not be renumbered.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[repr(i32)]
pub enum GluActivationType {
    /// SiLU: `x / (1 + exp(-x))`.
    Silu = 0,
    /// GELU, tanh approximation.
    Gelu = 1,
    /// ReLU: `max(x, 0)`.
    Relu = 2,
    /// GELU computed with the error function.
    GeluErf = 3,
}
2581
/// Scalar SiLU: `x / (1 + e^(-x))`.
fn cpu_silu(x: f32) -> f32 {
    let denominator = 1.0 + f32::exp(-x);
    x / denominator
}
2586
/// Scalar GELU using the tanh approximation:
/// `0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`.
fn cpu_gelu(x: f32) -> f32 {
    #[allow(clippy::excessive_precision)]
    const SQRT_2_OVER_PI: f32 = 0.7978845608;
    const COEFF: f32 = 0.044715;

    // The cube is built by repeated multiplication and the factors are combined
    // in this exact order so rounding stays stable.
    let cube = x * x * x;
    let tanh_arg = SQRT_2_OVER_PI * (x + COEFF * cube);
    let half_x = 0.5 * x;
    half_x * (1.0 + tanh_arg.tanh())
}
2595
/// Scalar ReLU: the larger of `x` and zero.
fn cpu_relu(x: f32) -> f32 {
    f32::max(x, 0.0)
}
2599
2600fn cpu_gelu_erf(x: f32) -> f32 {
2601 x * (1.0 + hanzo_ml::cpu::erf::erf_f32(x * std::f32::consts::FRAC_1_SQRT_2)) / 2.0
2603}
2604
2605fn apply_cpu_activation(x: f32, activation: GluActivationType) -> f32 {
2606 match activation {
2607 GluActivationType::Silu => cpu_silu(x),
2608 GluActivationType::Gelu => cpu_gelu(x),
2609 GluActivationType::Relu => cpu_relu(x),
2610 GluActivationType::GeluErf => cpu_gelu_erf(x),
2611 }
2612}
2613
/// Two-input custom op computing `activation(a) * b` element-wise; the wrapped
/// value selects the activation.
struct FusedGlu(GluActivationType);
2615
/// Backend implementations of the fused GLU (`activation(a) * b`).
///
/// Every backend reads `n = l1.shape().elem_count()` consecutive elements starting
/// at each layout's start offset, i.e. it ignores strides. `fused_glu` makes both
/// inputs contiguous before applying the op.
impl CustomOp2 for FusedGlu {
    /// Identifier of this custom op.
    fn name(&self) -> &'static str {
        "fused-glu"
    }

    /// CPU forward pass for `F32`, `F16` and `BF16`; other dtypes are rejected.
    ///
    /// The output has the shape of `l1`. The dtype is taken from `s1`; `s2` is
    /// read with the same element type, so `as_slice` fails on a dtype mismatch.
    fn cpu_fwd(
        &self,
        s1: &CpuStorage,
        l1: &Layout,
        s2: &CpuStorage,
        l2: &Layout,
    ) -> Result<(CpuStorage, Shape)> {
        use half::{bf16, f16};

        let activation = self.0;
        let out_shape = l1.shape().clone();
        let len = out_shape.elem_count();

        let result_storage = match s1.dtype() {
            DType::F32 => {
                let a_slice = s1.as_slice::<f32>()?;
                let b_slice = s2.as_slice::<f32>()?;
                let a_offset = l1.start_offset();
                let b_offset = l2.start_offset();

                let result: Vec<f32> = (0..len)
                    .into_par_iter()
                    .map(|i| {
                        let a_val = a_slice[a_offset + i];
                        let b_val = b_slice[b_offset + i];
                        apply_cpu_activation(a_val, activation) * b_val
                    })
                    .collect();
                CpuStorage::F32(result)
            }
            DType::F16 => {
                let a_slice = s1.as_slice::<f16>()?;
                let b_slice = s2.as_slice::<f16>()?;
                let a_offset = l1.start_offset();
                let b_offset = l2.start_offset();

                let result: Vec<f16> = (0..len)
                    .into_par_iter()
                    .map(|i| {
                        let a_val = a_slice[a_offset + i].to_f32();
                        // The activation result is rounded to f16 before the multiply,
                        // so the product is formed from two f16-representable values.
                        // NOTE(review): presumably mirrors an unfused activation-then-mul - confirm.
                        let activated = f16::from_f32(apply_cpu_activation(a_val, activation));
                        f16::from_f32(activated.to_f32() * b_slice[b_offset + i].to_f32())
                    })
                    .collect();
                CpuStorage::F16(result)
            }
            DType::BF16 => {
                let a_slice = s1.as_slice::<bf16>()?;
                let b_slice = s2.as_slice::<bf16>()?;
                let a_offset = l1.start_offset();
                let b_offset = l2.start_offset();

                let result: Vec<bf16> = (0..len)
                    .into_par_iter()
                    .map(|i| {
                        let a_val = a_slice[a_offset + i].to_f32();
                        // Same intermediate rounding as the F16 branch, in bf16.
                        let activated = bf16::from_f32(apply_cpu_activation(a_val, activation));
                        bf16::from_f32(activated.to_f32() * b_slice[b_offset + i].to_f32())
                    })
                    .collect();
                CpuStorage::BF16(result)
            }
            other => hanzo_ml::bail!("fused_glu: unsupported dtype {:?}", other),
        };

        Ok((result_storage, out_shape))
    }

    /// CUDA forward pass for `F16`, `BF16` and `F32`; other dtypes are rejected.
    ///
    /// Each branch allocates a zeroed output of `n_elements`, obtains pointers for
    /// both inputs (at their start offsets) and the output through `slice_ptr`,
    /// and calls the matching `ffi::fused_glu_*` entry point on the device stream.
    /// The element count is passed to the kernel as `u32`.
    #[cfg(feature = "cuda")]
    fn cuda_fwd(
        &self,
        s1: &CudaStorage,
        l1: &Layout,
        s2: &CudaStorage,
        l2: &Layout,
    ) -> Result<(CudaStorage, Shape)> {
        use half::{bf16, f16};

        let activation = self.0;
        let device = s1.device();
        let n_elements = l1.shape().elem_count();
        let dtype = s1.dtype();
        let out_shape = l1.shape().clone();
        let stream = device.cuda_stream().cu_stream();
        let a_offset = l1.start_offset();
        let b_offset = l2.start_offset();

        match dtype {
            DType::F16 => {
                let output = device.alloc_zeros::<f16>(n_elements)?;
                let a_slice = s1.as_cuda_slice::<f16>()?;
                let b_slice = s2.as_cuda_slice::<f16>()?;

                // The guards returned next to the pointers stay bound until the end
                // of this arm, i.e. across the kernel call.
                let (a_ptr, _a_guard) = slice_ptr(a_slice, a_offset);
                let (b_ptr, _b_guard) = slice_ptr(b_slice, b_offset);
                let (out_ptr, _o_guard) = slice_ptr(&output, 0);

                // NOTE(review): soundness relies on both inputs holding at least
                // `n_elements` elements past their offsets and on the kernel only
                // touching that range - confirm against the kernel source.
                unsafe {
                    ffi::fused_glu_f16(
                        a_ptr as *const c_void,
                        b_ptr as *const c_void,
                        out_ptr as *mut c_void,
                        n_elements as u32,
                        activation as i32,
                        stream,
                    );
                }

                // The output guard is derived from `&output`; release it before
                // `output` is moved into the returned storage.
                drop(_o_guard);
                let out_storage = CudaStorage::wrap_cuda_slice(output, device.clone());
                Ok((out_storage, out_shape))
            }
            DType::BF16 => {
                let output = device.alloc_zeros::<bf16>(n_elements)?;
                let a_slice = s1.as_cuda_slice::<bf16>()?;
                let b_slice = s2.as_cuda_slice::<bf16>()?;

                let (a_ptr, _a_guard) = slice_ptr(a_slice, a_offset);
                let (b_ptr, _b_guard) = slice_ptr(b_slice, b_offset);
                let (out_ptr, _o_guard) = slice_ptr(&output, 0);

                // Same contract as the F16 branch.
                unsafe {
                    ffi::fused_glu_bf16(
                        a_ptr as *const c_void,
                        b_ptr as *const c_void,
                        out_ptr as *mut c_void,
                        n_elements as u32,
                        activation as i32,
                        stream,
                    );
                }

                drop(_o_guard);
                let out_storage = CudaStorage::wrap_cuda_slice(output, device.clone());
                Ok((out_storage, out_shape))
            }
            DType::F32 => {
                let output = device.alloc_zeros::<f32>(n_elements)?;
                let a_slice = s1.as_cuda_slice::<f32>()?;
                let b_slice = s2.as_cuda_slice::<f32>()?;

                let (a_ptr, _a_guard) = slice_ptr(a_slice, a_offset);
                let (b_ptr, _b_guard) = slice_ptr(b_slice, b_offset);
                let (out_ptr, _o_guard) = slice_ptr(&output, 0);

                // Same contract as the F16 branch.
                unsafe {
                    ffi::fused_glu_f32(
                        a_ptr as *const c_void,
                        b_ptr as *const c_void,
                        out_ptr as *mut c_void,
                        n_elements as u32,
                        activation as i32,
                        stream,
                    );
                }

                drop(_o_guard);
                let out_storage = CudaStorage::wrap_cuda_slice(output, device.clone());
                Ok((out_storage, out_shape))
            }
            _ => hanzo_ml::bail!("fused_glu: unsupported dtype {:?}", dtype),
        }
    }

    /// Metal forward pass: encodes `call_fused_glu` into the device's command encoder.
    ///
    /// The kernel receives a single dtype (that of `s1`), which is also used to turn
    /// both start offsets into byte offsets; `s2`'s dtype is not inspected here.
    #[cfg(feature = "metal")]
    fn metal_fwd(
        &self,
        s1: &hanzo_ml::MetalStorage,
        l1: &Layout,
        s2: &hanzo_ml::MetalStorage,
        l2: &Layout,
    ) -> Result<(hanzo_ml::MetalStorage, Shape)> {
        let activation = self.0;
        let n_elements = l1.shape().elem_count();
        let dtype = s1.dtype();
        let out_shape = l1.shape().clone();

        let device = s1.device();
        let encoder = device.command_encoder()?;
        encoder.set_label("fused-glu");

        let output = device.new_buffer(n_elements, dtype, "fused-glu-output")?;

        crate::metal_kernels::call_fused_glu(
            device.device(),
            &encoder,
            &crate::metal_kernels::Kernels::new(),
            dtype,
            s1.buffer(),
            s2.buffer(),
            l1.start_offset() * dtype.size_in_bytes(),
            l2.start_offset() * dtype.size_in_bytes(),
            n_elements,
            activation as i32,
            &output,
        )
        .map_err(hanzo_ml::Error::wrap)?;

        let newstorage = hanzo_ml::MetalStorage::new(output, device.clone(), n_elements, dtype);
        Ok((newstorage, out_shape))
    }
}
2827
2828pub fn fused_glu(a: &Tensor, b: &Tensor, activation: GluActivationType) -> Result<Tensor> {
2834 let a = a.contiguous()?;
2835 let b = b.contiguous()?;
2836
2837 if a.shape() != b.shape() {
2838 hanzo_ml::bail!(
2839 "fused_glu: a and b must have same shape, got {:?} vs {:?}",
2840 a.shape(),
2841 b.shape()
2842 );
2843 }
2844
2845 if a.device().is_rocm() || a.device().is_vulkan() {
2848 let act = match activation {
2849 GluActivationType::Silu => a.silu()?,
2850 GluActivationType::Gelu => a.gelu()?,
2851 GluActivationType::GeluErf => a.gelu_erf()?,
2852 GluActivationType::Relu => a.relu()?,
2853 };
2854 return act.mul(&b);
2855 }
2856
2857 a.apply_op2_no_bwd(&b, &FusedGlu(activation))
2858}
2859
2860mod tests {
/// `fast_cumsum` along dim 0 on the CPU yields the exclusive forward prefix sums.
#[test]
fn test_cumsum_exclusive_forward_cpu() {
    use crate::utils::ops::CumSumOp;
    use hanzo_ml::{Device, Tensor};

    let input = Tensor::from_vec(vec![1i64, 2, 3, 4], &[4], &Device::Cpu).unwrap();
    let summed = input.fast_cumsum(0).unwrap();
    let expected: [i64; 4] = [0, 1, 3, 6];
    assert_eq!(summed.to_vec1::<i64>().unwrap(), expected);
}
2870
/// `fast_cumsum_config(0, true, false)` on the CPU yields the inclusive forward prefix sums.
#[test]
fn test_cumsum_inclusive_forward_cpu() {
    use crate::utils::ops::CumSumOp;
    use hanzo_ml::{Device, Tensor};

    let input = Tensor::from_vec(vec![1i64, 2, 3, 4], &[4], &Device::Cpu).unwrap();
    let summed = input.fast_cumsum_config(0, true, false).unwrap();
    let expected: [i64; 4] = [1, 3, 6, 10];
    assert_eq!(summed.to_vec1::<i64>().unwrap(), expected);
}
2884
/// `fast_cumsum_config(0, false, true)` on the CPU yields the exclusive reverse sums.
#[test]
fn test_cumsum_exclusive_reverse_cpu() {
    use crate::utils::ops::CumSumOp;
    use hanzo_ml::{Device, Tensor};

    let input = Tensor::from_vec(vec![1i64, 2, 3, 4], &[4], &Device::Cpu).unwrap();
    let summed = input.fast_cumsum_config(0, false, true).unwrap();
    let expected: [i64; 4] = [9, 7, 4, 0];
    assert_eq!(summed.to_vec1::<i64>().unwrap(), expected);
}
2898
/// `fast_cumsum_config(0, true, true)` on the CPU yields the inclusive reverse sums.
#[test]
fn test_cumsum_inclusive_reverse_cpu() {
    use crate::utils::ops::CumSumOp;
    use hanzo_ml::{Device, Tensor};

    let input = Tensor::from_vec(vec![1i64, 2, 3, 4], &[4], &Device::Cpu).unwrap();
    let summed = input.fast_cumsum_config(0, true, true).unwrap();
    let expected: [i64; 4] = [10, 9, 7, 4];
    assert_eq!(summed.to_vec1::<i64>().unwrap(), expected);
}
2912
/// `fast_cumsum` along dim 0 on Metal yields the exclusive forward prefix sums.
#[cfg(feature = "metal")]
#[test]
fn test_cumsum_exclusive_forward_metal() {
    use crate::utils::ops::CumSumOp;
    use hanzo_ml::{Device, Tensor};

    let metal = Device::new_metal(0).unwrap();
    let input = Tensor::from_vec(vec![1i64, 2, 3, 4], &[4], &metal).unwrap();
    let summed = input.fast_cumsum(0).unwrap();
    let expected: [i64; 4] = [0, 1, 3, 6];
    assert_eq!(summed.to_vec1::<i64>().unwrap(), expected);
}
2923
/// `fast_cumsum_config(0, true, false)` on Metal yields the inclusive forward prefix sums.
#[cfg(feature = "metal")]
#[test]
fn test_cumsum_inclusive_forward_metal() {
    use crate::utils::ops::CumSumOp;
    use hanzo_ml::{Device, Tensor};

    let metal = Device::new_metal(0).unwrap();
    let input = Tensor::from_vec(vec![1i64, 2, 3, 4], &[4], &metal).unwrap();
    let summed = input.fast_cumsum_config(0, true, false).unwrap();
    let expected: [i64; 4] = [1, 3, 6, 10];
    assert_eq!(summed.to_vec1::<i64>().unwrap(), expected);
}
2938
/// `fast_cumsum_config(0, false, true)` on Metal yields the exclusive reverse sums.
#[cfg(feature = "metal")]
#[test]
fn test_cumsum_exclusive_reverse_metal() {
    use crate::utils::ops::CumSumOp;
    use hanzo_ml::{Device, Tensor};

    let metal = Device::new_metal(0).unwrap();
    let input = Tensor::from_vec(vec![1i64, 2, 3, 4], &[4], &metal).unwrap();
    let summed = input.fast_cumsum_config(0, false, true).unwrap();
    let expected: [i64; 4] = [9, 7, 4, 0];
    assert_eq!(summed.to_vec1::<i64>().unwrap(), expected);
}
2953
/// `fast_cumsum_config(0, true, true)` on Metal yields the inclusive reverse sums.
#[cfg(feature = "metal")]
#[test]
fn test_cumsum_inclusive_reverse_metal() {
    use crate::utils::ops::CumSumOp;
    use hanzo_ml::{Device, Tensor};

    let metal = Device::new_metal(0).unwrap();
    let input = Tensor::from_vec(vec![1i64, 2, 3, 4], &[4], &metal).unwrap();
    let summed = input.fast_cumsum_config(0, true, true).unwrap();
    let expected: [i64; 4] = [10, 9, 7, 4];
    assert_eq!(summed.to_vec1::<i64>().unwrap(), expected);
}
2968
/// `nonzero` on a 2x4 CPU tensor returns the `[row, col]` index of every non-zero entry.
#[test]
fn test_nonzero_cpu() {
    use crate::utils::ops::NonZeroOp;
    use hanzo_ml::{Device, Tensor};

    let values = vec![1f32, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0, 0.0];
    let input = Tensor::from_vec(values, &[2, 4], &Device::Cpu).unwrap();
    let indices = input.nonzero().unwrap().to_vec2::<u32>().unwrap();
    let expected: [[u32; 2]; 4] = [[0, 0], [0, 2], [1, 0], [1, 2]];
    assert_eq!(indices, expected);
}
2983
/// `nonzero` on a 2x4 CUDA tensor returns the `[row, col]` index of every non-zero entry.
#[cfg(feature = "cuda")]
#[test]
fn test_nonzero_cuda() {
    use crate::utils::ops::NonZeroOp;
    use hanzo_ml::{Device, Tensor};

    let cuda = Device::new_cuda(0).unwrap();
    let values = vec![1f32, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0, 0.0];
    let input = Tensor::from_vec(values, &[2, 4], &cuda).unwrap();
    let indices = input.nonzero().unwrap().to_vec2::<u32>().unwrap();
    let expected: [[u32; 2]; 4] = [[0, 0], [0, 2], [1, 0], [1, 2]];
    assert_eq!(indices, expected);
}
2999
/// Element-wise `bitwise_and` of two 5x2 `i64` tensors on the CPU.
#[test]
fn test_bitwise_and_cpu() {
    use crate::utils::ops::BitWiseOp;
    use hanzo_ml::{Device, Tensor};

    let cpu = Device::Cpu;
    let lhs = Tensor::from_vec(vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7], (5, 2), &cpu).unwrap();
    let rhs = Tensor::from_vec(vec![-1i64, 2, 3, -1, 1, -1, -1, 4, 5, 7], (5, 2), &cpu).unwrap();

    let anded = lhs.bitwise_and(&rhs).unwrap().to_vec2::<i64>().unwrap();
    let expected: [[i64; 2]; 5] = [[1, 2], [3, -1], [1, -1], [-1, 4], [5, 7]];
    assert_eq!(anded, expected);
}
3012
/// Element-wise `bitwise_and` of two 5x2 `i64` tensors on CUDA.
#[cfg(feature = "cuda")]
#[test]
fn test_bitwise_and_cuda() {
    use crate::utils::ops::BitWiseOp;
    use hanzo_ml::{Device, Tensor};

    let cuda = Device::new_cuda(0).unwrap();
    let lhs = Tensor::from_vec(vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7], (5, 2), &cuda).unwrap();
    let rhs = Tensor::from_vec(vec![-1i64, 2, 3, -1, 1, -1, -1, 4, 0, 7], (5, 2), &cuda).unwrap();

    let anded = lhs.bitwise_and(&rhs).unwrap().to_vec2::<i64>().unwrap();
    let expected: [[i64; 2]; 5] = [[1, 2], [3, -1], [1, -1], [-1, 4], [0, 7]];
    assert_eq!(anded, expected);
}
3026
/// Element-wise `bitwise_or` of two 5x2 `i64` tensors on the CPU.
#[test]
fn test_bitwise_or_cpu() {
    use crate::utils::ops::BitWiseOp;
    use hanzo_ml::{Device, Tensor};

    let cpu = Device::Cpu;
    let lhs = Tensor::from_vec(vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7], (5, 2), &cpu).unwrap();
    let rhs = Tensor::from_vec(vec![-1i64, 0, 0, 0, 0, 0, 0, 0, 0, 8], (5, 2), &cpu).unwrap();

    let ored = lhs.bitwise_or(&rhs).unwrap().to_vec2::<i64>().unwrap();
    let expected: [[i64; 2]; 5] = [[-1, 2], [3, -1], [-1, -1], [-1, 4], [5, 15]];
    assert_eq!(ored, expected);
}
3038
/// Element-wise `bitwise_or` of two 5x2 `i64` tensors on CUDA.
#[cfg(feature = "cuda")]
#[test]
fn test_bitwise_or_cuda() {
    use crate::utils::ops::BitWiseOp;
    use hanzo_ml::{Device, Tensor};

    let cuda = Device::new_cuda(0).unwrap();
    let lhs = Tensor::from_vec(vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7], (5, 2), &cuda).unwrap();
    let rhs = Tensor::from_vec(vec![-1i64, 0, 0, 0, 0, 0, 0, 0, 0, 8], (5, 2), &cuda).unwrap();

    let ored = lhs.bitwise_or(&rhs).unwrap().to_vec2::<i64>().unwrap();
    let expected: [[i64; 2]; 5] = [[-1, 2], [3, -1], [-1, -1], [-1, 4], [5, 15]];
    assert_eq!(ored, expected);
}
3051
/// Element-wise `bitwise_xor` of two 5x2 `i64` tensors on the CPU.
#[test]
fn test_bitwise_xor_cpu() {
    use crate::utils::ops::BitWiseOp;
    use hanzo_ml::{Device, Tensor};

    let cpu = Device::Cpu;
    let lhs = Tensor::from_vec(vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7], (5, 2), &cpu).unwrap();
    let rhs = Tensor::from_vec(vec![-1i64, 0, 0, 0, 0, 0, 0, 0, 0, 8], (5, 2), &cpu).unwrap();

    let xored = lhs.bitwise_xor(&rhs).unwrap().to_vec2::<i64>().unwrap();
    let expected: [[i64; 2]; 5] = [[-2, 2], [3, -1], [-1, -1], [-1, 4], [5, 15]];
    assert_eq!(xored, expected);
}
3063
/// Element-wise `bitwise_xor` of two 5x2 `i64` tensors on CUDA.
#[cfg(feature = "cuda")]
#[test]
fn test_bitwise_xor_cuda() {
    use crate::utils::ops::BitWiseOp;
    use hanzo_ml::{Device, Tensor};

    let cuda = Device::new_cuda(0).unwrap();
    let lhs = Tensor::from_vec(vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7], (5, 2), &cuda).unwrap();
    let rhs = Tensor::from_vec(vec![-1i64, 0, 0, 0, 0, 0, 0, 0, 0, 8], (5, 2), &cuda).unwrap();

    let xored = lhs.bitwise_xor(&rhs).unwrap().to_vec2::<i64>().unwrap();
    let expected: [[i64; 2]; 5] = [[-2, 2], [3, -1], [-1, -1], [-1, 4], [5, 15]];
    assert_eq!(xored, expected);
}
3076
/// Combines two comparison masks with `bitwise_and` and checks that `nonzero`
/// reports exactly the positions where the stacked input is negative and above -10.
#[test]
fn test_nonzero_and() {
    use crate::utils::ops::{BitWiseOp, NonZeroOp};
    use hanzo_ml::{Device, Tensor};

    let input1 = Tensor::from_vec(
        vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7],
        (10,),
        &Device::Cpu,
    )
    .unwrap();
    let input2 = Tensor::from_vec(
        vec![-1i64, 2, 3, -1, 1, -1, -1, 4, 5, 7],
        (10,),
        &Device::Cpu,
    )
    .unwrap();
    let input = Tensor::stack(&[input1, input2], 0).unwrap();

    let lt = input.lt(0.0).unwrap();
    let gt = input.gt(-10.0).unwrap();
    // The argument is a reference to the `gt` mask.
    let res = lt
        .bitwise_and(&gt)
        .unwrap()
        .nonzero()
        .unwrap()
        .to_vec2::<u32>()
        .unwrap();

    assert_eq!(
        res,
        [
            [0, 3],
            [0, 4],
            [0, 5],
            [0, 6],
            [1, 0],
            [1, 3],
            [1, 5],
            [1, 6]
        ]
    );
}
3120
/// CUDA variant: combines two comparison masks with `bitwise_and` and checks that
/// `nonzero` reports exactly the positions where the stacked input is negative
/// and above -10.
#[cfg(feature = "cuda")]
#[test]
fn nonzero_and_cuda() {
    use crate::utils::ops::{BitWiseOp, NonZeroOp};
    use hanzo_ml::{Device, Tensor};

    let device = Device::new_cuda(0).unwrap();
    let input1 =
        Tensor::from_vec(vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7], (10,), &device).unwrap();
    let input2 =
        Tensor::from_vec(vec![-1i64, 2, 3, -1, 1, -1, -1, 4, 5, 7], (10,), &device).unwrap();
    let input = Tensor::stack(&[input1, input2], 0).unwrap();

    let lt = input.lt(0.0).unwrap();
    let gt = input.gt(-10.0).unwrap();
    // The argument is a reference to the `gt` mask.
    let res = lt
        .bitwise_and(&gt)
        .unwrap()
        .nonzero()
        .unwrap()
        .to_vec2::<u32>()
        .unwrap();

    assert_eq!(
        res,
        [
            [0, 3],
            [0, 4],
            [0, 5],
            [0, 6],
            [1, 0],
            [1, 3],
            [1, 5],
            [1, 6]
        ]
    );
}
3158
/// 8-bit packing of `i32` inputs on the CPU keeps the low byte of each value.
#[test]
fn test_bitpack_8bit_cpu() {
    use crate::HqqBits;
    use hanzo_ml::{Device, Tensor};

    let pack = HqqBits::Eight.bitpack_type();
    let wq =
        Tensor::from_vec(vec![257_i32, 258, 259, 260, 511, 512], (3, 2), &Device::Cpu).unwrap();

    let packed = pack(wq).unwrap().to_vec2::<u8>().unwrap();
    let expected: [[u8; 2]; 3] = [[1, 2], [3, 4], [255, 0]];
    assert_eq!(packed, expected);
}
3172
/// 8-bit packing of `u8` inputs on CUDA returns the values unchanged.
#[cfg(feature = "cuda")]
#[test]
fn test_bitpack_8bit_cuda() {
    use crate::HqqBits;
    use hanzo_ml::{Device, Tensor};

    let cuda = Device::new_cuda(0).unwrap();
    let pack = HqqBits::Eight.bitpack_type();
    let wq = Tensor::from_vec(vec![1_u8, 2, 3, 4, 255, 0], (3, 2), &cuda).unwrap();

    let packed = pack(wq).unwrap().to_vec2::<u8>().unwrap();
    let expected: [[u8; 2]; 3] = [[1, 2], [3, 4], [255, 0]];
    assert_eq!(packed, expected);
}
3189
/// 8-bit packing of `i32` inputs on Metal keeps the low byte of each value.
#[cfg(feature = "metal")]
#[test]
fn test_bitpack_8bit_metal() {
    use crate::HqqBits;
    use hanzo_ml::{Device, Tensor};

    let metal = Device::new_metal(0).unwrap();
    let pack = HqqBits::Eight.bitpack_type();
    let wq = Tensor::from_vec(vec![257_i32, 258, 259, 260, 511, 512], (3, 2), &metal).unwrap();

    let packed = pack(wq).unwrap().to_vec2::<u8>().unwrap();
    let expected: [[u8; 2]; 3] = [[1, 2], [3, 4], [255, 0]];
    assert_eq!(packed, expected);
}
3204
/// 4-bit packing of a 3x2 `u8` tensor on the CPU produces the single row `[19, 36]`.
#[test]
fn test_bitpack_4bit() {
    use crate::HqqBits;
    use hanzo_ml::{Device, Tensor};

    let pack = HqqBits::Four.bitpack_type();
    let wq = Tensor::from_vec(vec![1_u8, 2, 3, 4, 5, 6], (3, 2), &Device::Cpu).unwrap();

    let packed = pack(wq).unwrap().to_vec2::<u8>().unwrap();
    let expected: [[u8; 2]; 1] = [[19, 36]];
    assert_eq!(packed, expected);
}
3218
/// 4-bit packing of a 3x2 `u8` tensor on CUDA produces the single row `[19, 36]`.
#[cfg(feature = "cuda")]
#[test]
fn test_bitpack_4bit_cuda() {
    use crate::HqqBits;
    use hanzo_ml::{Device, Tensor};

    let cuda = Device::new_cuda(0).unwrap();
    let pack = HqqBits::Four.bitpack_type();
    let wq = Tensor::from_vec(vec![1_u8, 2, 3, 4, 5, 6], (3, 2), &cuda).unwrap();

    let packed = pack(wq).unwrap().to_vec2::<u8>().unwrap();
    let expected: [[u8; 2]; 1] = [[19, 36]];
    assert_eq!(packed, expected);
}
3233
/// 4-bit packing of a 3x2 `u8` tensor on Metal produces the single row `[19, 36]`.
#[cfg(feature = "metal")]
#[test]
fn test_bitpack_4bit_metal() {
    use crate::HqqBits;
    use hanzo_ml::{Device, Tensor};

    let metal = Device::new_metal(0).unwrap();
    let pack = HqqBits::Four.bitpack_type();
    let wq = Tensor::from_vec(vec![1_u8, 2, 3, 4, 5, 6], (3, 2), &metal).unwrap();

    let packed = pack(wq).unwrap().to_vec2::<u8>().unwrap();
    let expected: [[u8; 2]; 1] = [[19, 36]];
    assert_eq!(packed, expected);
}
/// Ascending sort and argsort of a 4-element `i32` vector on Metal.
#[cfg(feature = "metal")]
#[test]
fn test_sort_and_argsort_vector_metal() {
    use crate::utils::ops::SortOp;
    use hanzo_ml::{Device, Tensor};

    let metal = Device::new_metal(0).unwrap();
    let input = Tensor::from_vec(vec![3i32, 1, 4, 2], &[4], &metal).unwrap();

    let sorted_values = input.fast_sort_asc(0).unwrap().to_vec1::<i32>().unwrap();
    let expected_values: [i32; 4] = [1, 2, 3, 4];
    assert_eq!(sorted_values, expected_values);

    let order = input.fast_argsort_asc(0).unwrap().to_vec1::<u32>().unwrap();
    let expected_order: [u32; 4] = [1, 3, 0, 2];
    assert_eq!(order, expected_order);
}
3266
/// Ascending sort and argsort along axis 1 of a 2x3 `i32` matrix on Metal.
#[cfg(feature = "metal")]
#[test]
fn test_sort_and_argsort_matrix_axis1_metal() {
    use crate::utils::ops::SortOp;
    use hanzo_ml::{Device, Tensor};

    let metal = Device::new_metal(0).unwrap();
    let input = Tensor::from_vec(vec![3i32, 1, 2, 0, 4, 5], &[2, 3], &metal).unwrap();

    let sorted_rows = input.fast_sort_asc(1).unwrap().to_vec2::<i32>().unwrap();
    let expected_rows: [[i32; 3]; 2] = [[1, 2, 3], [0, 4, 5]];
    assert_eq!(sorted_rows, expected_rows);

    let order = input.fast_argsort_asc(1).unwrap().to_vec2::<u32>().unwrap();
    let expected_order: [[u32; 3]; 2] = [[1, 2, 0], [0, 1, 2]];
    assert_eq!(order, expected_order);
}
3287
/// Ascending sort and argsort of a descending `i32` vector on Metal.
///
/// Despite the `2048` in the name, the vector holds `N = 4096` elements.
#[cfg(feature = "metal")]
#[test]
fn test_sort_and_argsort_vector_2048_metal() {
    use crate::utils::ops::SortOp;
    use hanzo_ml::{Device, Tensor};

    const N: usize = 4096;

    let metal = Device::new_metal(0).expect("Metal device");

    // Input is N-1, N-2, ..., 0.
    let descending: Vec<i32> = (0..N as i32).rev().collect();
    let input = Tensor::from_vec(descending, &[N], &metal).unwrap();

    let sorted_values = input.fast_sort_asc(0).unwrap().to_vec1::<i32>().unwrap();
    let ascending: Vec<i32> = (0..N as i32).collect();
    assert_eq!(sorted_values, ascending);

    // Sorting a reversed range maps output position `pos` to source index N - 1 - pos.
    let order = input.fast_argsort_asc(0).unwrap().to_vec1::<u32>().unwrap();
    for (pos, &src) in order.iter().enumerate() {
        assert_eq!(src as usize, N - 1 - pos);
    }
}
3315
/// Fused SiLU-GLU on Metal agrees with the CPU result to within 1e-4 for F32 inputs.
#[cfg(feature = "metal")]
#[test]
fn test_fused_glu_metal_silu_f32() {
    use super::{fused_glu, GluActivationType};
    use hanzo_ml::{Device, Tensor};

    let cpu = Device::Cpu;
    let metal = Device::new_metal(0).unwrap();

    let a_data: Vec<f32> = (0..256).map(|i| (i as f32 - 128.0) / 64.0).collect();
    let b_data: Vec<f32> = (0..256).map(|i| (i as f32 * 0.7 - 90.0) / 50.0).collect();

    // Builds a 4x64 tensor holding `data` on the given device.
    let tensor_on =
        |data: &[f32], dev: &Device| Tensor::from_vec(data.to_vec(), &[4, 64], dev).unwrap();

    let a_cpu = tensor_on(&a_data, &cpu);
    let b_cpu = tensor_on(&b_data, &cpu);
    let a_metal = tensor_on(&a_data, &metal);
    let b_metal = tensor_on(&b_data, &metal);

    let cpu_out = fused_glu(&a_cpu, &b_cpu, GluActivationType::Silu).unwrap();
    let cpu_result = cpu_out.to_vec2::<f32>().unwrap();

    let metal_out = fused_glu(&a_metal, &b_metal, GluActivationType::Silu).unwrap();
    let metal_result = metal_out.to_device(&cpu).unwrap().to_vec2::<f32>().unwrap();

    for (row_cpu, row_metal) in cpu_result.iter().zip(metal_result.iter()) {
        for (c, m) in row_cpu.iter().zip(row_metal.iter()) {
            let diff = (c - m).abs();
            assert!(
                diff < 1e-4,
                "SiLU F32 mismatch: cpu={c}, metal={m}, diff={diff}"
            );
        }
    }
}
3354
/// Fused SiLU-GLU on Metal agrees with the CPU result to within 1e-2 for F16 inputs.
#[cfg(feature = "metal")]
#[test]
fn test_fused_glu_metal_silu_f16() {
    use super::{fused_glu, GluActivationType};
    use hanzo_ml::{DType, Device, Tensor};

    let cpu = Device::Cpu;
    let metal = Device::new_metal(0).unwrap();

    let a_data: Vec<f32> = (0..256).map(|i| (i as f32 - 128.0) / 64.0).collect();
    let b_data: Vec<f32> = (0..256).map(|i| (i as f32 * 0.7 - 90.0) / 50.0).collect();

    // Uploads `data` as F32 to the device, then converts it to F16 there.
    let half_on = |data: &[f32], dev: &Device| {
        Tensor::from_vec(data.to_vec(), &[256], dev)
            .unwrap()
            .to_dtype(DType::F16)
            .unwrap()
    };

    let a_cpu = half_on(&a_data, &cpu);
    let b_cpu = half_on(&b_data, &cpu);
    let a_metal = half_on(&a_data, &metal);
    let b_metal = half_on(&b_data, &metal);

    let cpu_out = fused_glu(&a_cpu, &b_cpu, GluActivationType::Silu).unwrap();
    let cpu_result = cpu_out
        .to_dtype(DType::F32)
        .unwrap()
        .to_vec1::<f32>()
        .unwrap();

    let metal_out = fused_glu(&a_metal, &b_metal, GluActivationType::Silu).unwrap();
    let metal_result = metal_out
        .to_device(&cpu)
        .unwrap()
        .to_dtype(DType::F32)
        .unwrap()
        .to_vec1::<f32>()
        .unwrap();

    for (i, (c, m)) in cpu_result.iter().zip(metal_result.iter()).enumerate() {
        let diff = (c - m).abs();
        assert!(
            diff < 1e-2,
            "SiLU F16 mismatch at {i}: cpu={c}, metal={m}, diff={diff}"
        );
    }
}
3407
/// For every activation variant, fused GLU on Metal agrees with the CPU result
/// to within 1e-4 for F32 inputs.
#[cfg(feature = "metal")]
#[test]
fn test_fused_glu_metal_all_activations() {
    use super::{fused_glu, GluActivationType};
    use hanzo_ml::{Device, Tensor};

    let cpu = Device::Cpu;
    let metal = Device::new_metal(0).unwrap();

    let a_data: Vec<f32> = (0..128).map(|i| (i as f32 - 64.0) / 32.0).collect();
    let b_data: Vec<f32> = (0..128).map(|i| (i as f32 * 0.5 - 32.0) / 20.0).collect();

    // Builds a 128-element tensor holding `data` on the given device.
    let tensor_on =
        |data: &[f32], dev: &Device| Tensor::from_vec(data.to_vec(), &[128], dev).unwrap();

    let activations = [
        GluActivationType::Silu,
        GluActivationType::Gelu,
        GluActivationType::Relu,
        GluActivationType::GeluErf,
    ];

    for act in activations {
        let a_cpu = tensor_on(&a_data, &cpu);
        let b_cpu = tensor_on(&b_data, &cpu);
        let a_metal = tensor_on(&a_data, &metal);
        let b_metal = tensor_on(&b_data, &metal);

        let cpu_out = fused_glu(&a_cpu, &b_cpu, act).unwrap();
        let cpu_result = cpu_out.to_vec1::<f32>().unwrap();

        let metal_out = fused_glu(&a_metal, &b_metal, act).unwrap();
        let metal_result = metal_out.to_device(&cpu).unwrap().to_vec1::<f32>().unwrap();

        for (i, (c, m)) in cpu_result.iter().zip(metal_result.iter()).enumerate() {
            let diff = (c - m).abs();
            assert!(
                diff < 1e-4,
                "{act:?} F32 mismatch at {i}: cpu={c}, metal={m}, diff={diff}"
            );
        }
    }
}
3451
3452 #[cfg(feature = "metal")]
3455 #[test]
3456 fn test_fused_glu_matches_fallback_bf16() {
3457 use super::{fused_glu, GluActivationType};
3458 use hanzo_ml::{DType, Tensor};
3459
3460 let metal = hanzo_ml::Device::new_metal(0).unwrap();
3461
3462 let n = 10240;
3464 let a_data: Vec<f32> = (0..n).map(|i| (i as f32 - 5120.0) / 2560.0).collect();
3465 let b_data: Vec<f32> = (0..n).map(|i| (i as f32 * 0.3 - 1500.0) / 1000.0).collect();
3466
3467 let a_metal = Tensor::from_vec(a_data.clone(), &[1, 2, n / 2], &metal)
3468 .unwrap()
3469 .to_dtype(DType::BF16)
3470 .unwrap();
3471 let b_metal = Tensor::from_vec(b_data.clone(), &[1, 2, n / 2], &metal)
3472 .unwrap()
3473 .to_dtype(DType::BF16)
3474 .unwrap();
3475
3476 let fused = fused_glu(&a_metal, &b_metal, GluActivationType::Gelu).unwrap();
3478
3479 let fallback = (a_metal.gelu().unwrap() * &b_metal).unwrap();
3481
3482 let fused_f32 = fused
3483 .to_dtype(DType::F32)
3484 .unwrap()
3485 .flatten_all()
3486 .unwrap()
3487 .to_vec1::<f32>()
3488 .unwrap();
3489 let fallback_f32 = fallback
3490 .to_dtype(DType::F32)
3491 .unwrap()
3492 .flatten_all()
3493 .unwrap()
3494 .to_vec1::<f32>()
3495 .unwrap();
3496
3497 let mut max_diff: f32 = 0.0;
3498 let mut num_mismatches = 0;
3499 for (f, fb) in fused_f32.iter().zip(fallback_f32.iter()) {
3500 let diff = (f - fb).abs();
3501 if diff > max_diff {
3502 max_diff = diff;
3503 }
3504 if diff > 0.0 {
3505 num_mismatches += 1;
3506 }
3507 }
3508 eprintln!(
3509 "BF16 Gelu fused vs fallback: max_diff={max_diff}, mismatches={num_mismatches}/{}",
3510 fused_f32.len()
3511 );
3512 assert!(
3515 max_diff <= 0.015625,
3516 "BF16 Gelu fused vs reference fallback max_diff {max_diff} exceeds 1 BF16 ULP"
3517 );
3518 }
3519
3520 #[cfg(feature = "cuda")]
3521 #[test]
3522 fn test_fused_glu_cuda_silu_f32() {
3523 use super::{fused_glu, GluActivationType};
3524 use hanzo_ml::Tensor;
3525
3526 let cpu = hanzo_ml::Device::Cpu;
3527 let cuda = hanzo_ml::Device::new_cuda(0).unwrap();
3528
3529 let a_data: Vec<f32> = (0..256).map(|i| (i as f32 - 128.0) / 64.0).collect();
3530 let b_data: Vec<f32> = (0..256).map(|i| (i as f32 * 0.7 - 90.0) / 50.0).collect();
3531
3532 let a_cpu = Tensor::from_vec(a_data.clone(), &[4, 64], &cpu).unwrap();
3533 let b_cpu = Tensor::from_vec(b_data.clone(), &[4, 64], &cpu).unwrap();
3534 let a_cuda = Tensor::from_vec(a_data, &[4, 64], &cuda).unwrap();
3535 let b_cuda = Tensor::from_vec(b_data, &[4, 64], &cuda).unwrap();
3536
3537 let cpu_result = fused_glu(&a_cpu, &b_cpu, GluActivationType::Silu)
3538 .unwrap()
3539 .to_vec2::<f32>()
3540 .unwrap();
3541 let cuda_result = fused_glu(&a_cuda, &b_cuda, GluActivationType::Silu)
3542 .unwrap()
3543 .to_device(&cpu)
3544 .unwrap()
3545 .to_vec2::<f32>()
3546 .unwrap();
3547
3548 for (row_cpu, row_cuda) in cpu_result.iter().zip(cuda_result.iter()) {
3549 for (c, g) in row_cpu.iter().zip(row_cuda.iter()) {
3550 let diff = (c - g).abs();
3551 assert!(
3552 diff < 1e-4,
3553 "SiLU F32 mismatch: cpu={c}, cuda={g}, diff={diff}"
3554 );
3555 }
3556 }
3557 }
3558
3559 #[cfg(feature = "cuda")]
3560 #[test]
3561 fn test_fused_glu_cuda_silu_f16() {
3562 use super::{fused_glu, GluActivationType};
3563 use hanzo_ml::{DType, Tensor};
3564
3565 let cpu = hanzo_ml::Device::Cpu;
3566 let cuda = hanzo_ml::Device::new_cuda(0).unwrap();
3567
3568 let a_data: Vec<f32> = (0..256).map(|i| (i as f32 - 128.0) / 64.0).collect();
3569 let b_data: Vec<f32> = (0..256).map(|i| (i as f32 * 0.7 - 90.0) / 50.0).collect();
3570
3571 let a_cpu = Tensor::from_vec(a_data.clone(), &[256], &cpu)
3572 .unwrap()
3573 .to_dtype(DType::F16)
3574 .unwrap();
3575 let b_cpu = Tensor::from_vec(b_data.clone(), &[256], &cpu)
3576 .unwrap()
3577 .to_dtype(DType::F16)
3578 .unwrap();
3579 let a_cuda = Tensor::from_vec(a_data, &[256], &cuda)
3580 .unwrap()
3581 .to_dtype(DType::F16)
3582 .unwrap();
3583 let b_cuda = Tensor::from_vec(b_data, &[256], &cuda)
3584 .unwrap()
3585 .to_dtype(DType::F16)
3586 .unwrap();
3587
3588 let cpu_result = fused_glu(&a_cpu, &b_cpu, GluActivationType::Silu)
3589 .unwrap()
3590 .to_dtype(DType::F32)
3591 .unwrap()
3592 .to_vec1::<f32>()
3593 .unwrap();
3594 let cuda_result = fused_glu(&a_cuda, &b_cuda, GluActivationType::Silu)
3595 .unwrap()
3596 .to_device(&cpu)
3597 .unwrap()
3598 .to_dtype(DType::F32)
3599 .unwrap()
3600 .to_vec1::<f32>()
3601 .unwrap();
3602
3603 for (i, (c, g)) in cpu_result.iter().zip(cuda_result.iter()).enumerate() {
3604 let diff = (c - g).abs();
3605 assert!(
3606 diff < 1e-2,
3607 "SiLU F16 mismatch at {i}: cpu={c}, cuda={g}, diff={diff}"
3608 );
3609 }
3610 }
3611
3612 #[cfg(feature = "cuda")]
3613 #[test]
3614 fn test_fused_glu_cuda_all_activations() {
3615 use super::{fused_glu, GluActivationType};
3616 use hanzo_ml::Tensor;
3617
3618 let cpu = hanzo_ml::Device::Cpu;
3619 let cuda = hanzo_ml::Device::new_cuda(0).unwrap();
3620
3621 let a_data: Vec<f32> = (0..128).map(|i| (i as f32 - 64.0) / 32.0).collect();
3622 let b_data: Vec<f32> = (0..128).map(|i| (i as f32 * 0.5 - 32.0) / 20.0).collect();
3623
3624 for act in [
3625 GluActivationType::Silu,
3626 GluActivationType::Gelu,
3627 GluActivationType::Relu,
3628 GluActivationType::GeluErf,
3629 ] {
3630 let a_cpu = Tensor::from_vec(a_data.clone(), &[128], &cpu).unwrap();
3631 let b_cpu = Tensor::from_vec(b_data.clone(), &[128], &cpu).unwrap();
3632 let a_cuda = Tensor::from_vec(a_data.clone(), &[128], &cuda).unwrap();
3633 let b_cuda = Tensor::from_vec(b_data.clone(), &[128], &cuda).unwrap();
3634
3635 let cpu_result = fused_glu(&a_cpu, &b_cpu, act)
3636 .unwrap()
3637 .to_vec1::<f32>()
3638 .unwrap();
3639 let cuda_result = fused_glu(&a_cuda, &b_cuda, act)
3640 .unwrap()
3641 .to_device(&cpu)
3642 .unwrap()
3643 .to_vec1::<f32>()
3644 .unwrap();
3645
3646 for (i, (c, g)) in cpu_result.iter().zip(cuda_result.iter()).enumerate() {
3647 let diff = (c - g).abs();
3648 assert!(
3649 diff < 1e-4,
3650 "{act:?} F32 mismatch at {i}: cpu={c}, cuda={g}, diff={diff}"
3651 );
3652 }
3653 }
3654 }
3655
3656 #[cfg(feature = "cuda")]
3657 #[test]
3658 fn test_fused_glu_matches_fallback_bf16_cuda() {
3659 use super::{fused_glu, GluActivationType};
3660 use hanzo_ml::{DType, Tensor};
3661
3662 let cuda = hanzo_ml::Device::new_cuda(0).unwrap();
3663
3664 let n = 10240;
3665 let a_data: Vec<f32> = (0..n).map(|i| (i as f32 - 5120.0) / 2560.0).collect();
3666 let b_data: Vec<f32> = (0..n).map(|i| (i as f32 * 0.3 - 1500.0) / 1000.0).collect();
3667
3668 let a_cuda = Tensor::from_vec(a_data.clone(), &[1, 2, n / 2], &cuda)
3669 .unwrap()
3670 .to_dtype(DType::BF16)
3671 .unwrap();
3672 let b_cuda = Tensor::from_vec(b_data.clone(), &[1, 2, n / 2], &cuda)
3673 .unwrap()
3674 .to_dtype(DType::BF16)
3675 .unwrap();
3676
3677 let fused = fused_glu(&a_cuda, &b_cuda, GluActivationType::Gelu).unwrap();
3679
3680 let fallback = (a_cuda.gelu().unwrap() * &b_cuda).unwrap();
3682
3683 let fused_f32 = fused
3684 .to_dtype(DType::F32)
3685 .unwrap()
3686 .flatten_all()
3687 .unwrap()
3688 .to_vec1::<f32>()
3689 .unwrap();
3690 let fallback_f32 = fallback
3691 .to_dtype(DType::F32)
3692 .unwrap()
3693 .flatten_all()
3694 .unwrap()
3695 .to_vec1::<f32>()
3696 .unwrap();
3697
3698 let mut max_diff: f32 = 0.0;
3699 let mut num_mismatches = 0;
3700 for (f, fb) in fused_f32.iter().zip(fallback_f32.iter()) {
3701 let diff = (f - fb).abs();
3702 if diff > max_diff {
3703 max_diff = diff;
3704 }
3705 if diff > 0.0 {
3706 num_mismatches += 1;
3707 }
3708 }
3709 eprintln!(
3710 "CUDA BF16 Gelu fused vs fallback: max_diff={max_diff}, mismatches={num_mismatches}/{}",
3711 fused_f32.len()
3712 );
3713 assert!(
3715 max_diff <= 0.015625,
3716 "CUDA BF16 Gelu fused vs reference fallback max_diff {max_diff} exceeds 1 BF16 ULP"
3717 );
3718 }
3719}