Skip to main content

hanzo_quant/utils/
ops.rs

1use hanzo_ml::{
2    backend::BackendStorage, shape::Dim, CpuStorage, CustomOp1, CustomOp2, DType, Error, Layout,
3    Result, Shape, Tensor, WithDType,
4};
5use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator};
6use rayon::slice::ParallelSliceMut;
7
8use std::{
9    fmt::Display,
10    ops::{BitAnd, BitOr, BitXor, Not, Shl},
11};
12
13#[cfg(feature = "cuda")]
14use crate::utils::{ffi, slice_ptr};
15#[cfg(feature = "cuda")]
16use hanzo_ml::cuda::{cudarc::driver::DevicePtr, CudaStorage};
17#[cfg(feature = "cuda")]
18use std::ffi::c_void;
19
20#[cfg(feature = "metal")]
21use crate::metal_kernels::SortScratchCache; // re‑export for clarity
22#[cfg(feature = "metal")]
23use std::sync::OnceLock;
24
/// Process-wide slot for a [`SortScratchCache`]; it starts out empty and is only compiled
/// for Metal builds.
///
/// NOTE(review): nothing in the visible part of this file initialises or reads the slot —
/// presumably the Metal sort / argsort paths fill it on first use; confirm.
#[cfg(feature = "metal")]
static SORT_SCRATCH_CACHE: OnceLock<SortScratchCache> = OnceLock::new();
27
28struct Leftshift(usize);
29
30impl Leftshift {
31    fn leftshift<T: WithDType + Shl<Output = T>>(&self, vs: &[T]) -> Vec<T> {
32        let offset = T::from_f64(self.0 as f64);
33        vs.into_par_iter().map(|v| *v << offset).collect()
34    }
35}
36
37impl CustomOp1 for Leftshift {
38    fn name(&self) -> &'static str {
39        "left"
40    }
41
42    fn cpu_fwd(&self, s1: &CpuStorage, l1: &Layout) -> Result<(CpuStorage, Shape)> {
43        match s1 {
44            CpuStorage::U8(vs1) => {
45                let vs1 = match l1.contiguous_offsets() {
46                    Some((a, b)) => &vs1[a..b],
47                    None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
48                };
49                let result = self.leftshift(vs1);
50                let result = CpuStorage::U8(result);
51                Ok((result, l1.shape().clone()))
52            }
53            CpuStorage::I16(vs1) => {
54                let vs1 = match l1.contiguous_offsets() {
55                    Some((a, b)) => &vs1[a..b],
56                    None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
57                };
58                let result = self.leftshift(vs1);
59                let result = CpuStorage::I16(result);
60                Ok((result, l1.shape().clone()))
61            }
62            CpuStorage::U32(vs1) => {
63                let vs1 = match l1.contiguous_offsets() {
64                    Some((a, b)) => &vs1[a..b],
65                    None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
66                };
67                let result = self.leftshift(vs1);
68                let result = CpuStorage::U32(result);
69                Ok((result, l1.shape().clone()))
70            }
71            CpuStorage::I64(vs1) => {
72                let vs1 = match l1.contiguous_offsets() {
73                    Some((a, b)) => &vs1[a..b],
74                    None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
75                };
76                let result = self.leftshift(vs1);
77                let result = CpuStorage::I64(result);
78                Ok((result, l1.shape().clone()))
79            }
80            CpuStorage::I32(vs1) => {
81                let vs1 = match l1.contiguous_offsets() {
82                    Some((a, b)) => &vs1[a..b],
83                    None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
84                };
85                let result = self.leftshift(vs1);
86                let result = CpuStorage::I32(result);
87                Ok((result, l1.shape().clone()))
88            }
89            _ => Err(Error::UnsupportedDTypeForOp(s1.dtype(), "leftshift")),
90        }
91    }
92
93    #[cfg(feature = "cuda")]
94    fn cuda_fwd(&self, s1: &CudaStorage, l1: &Layout) -> Result<(CudaStorage, Shape)> {
95        if !l1.is_contiguous() {
96            hanzo_ml::bail!("Input tensor s1 must be contiguous");
97        }
98        let dev = s1.device().clone();
99        let (d_in1_ptr, _d_guard, elem_count) = match s1.dtype() {
100            DType::U8 => {
101                let (d_in1, d_in1_guard) = slice_ptr(s1.as_cuda_slice::<u8>()?, l1.start_offset());
102                let elem_count = l1.shape().elem_count();
103                (d_in1 as *const c_void, d_in1_guard, elem_count)
104            }
105            DType::I32 => {
106                let (d_in1, d_in1_guard) = slice_ptr(s1.as_cuda_slice::<i32>()?, l1.start_offset());
107                let elem_count = l1.shape().elem_count();
108                (d_in1 as *const c_void, d_in1_guard, elem_count)
109            }
110            other => {
111                return Err(Error::UnsupportedDTypeForOp(other, "leftshift"));
112            }
113        };
114        let dst = match s1.dtype() {
115            DType::U8 => {
116                let d_out = unsafe { dev.alloc::<u8>(elem_count) }?;
117                let (d_out_ptr, d_out_guard) = d_out.device_ptr(d_out.stream());
118                unsafe {
119                    ffi::leftshift_u8(
120                        d_in1_ptr,
121                        d_out_ptr as *mut std::ffi::c_void,
122                        u32::try_from(elem_count)?,
123                        self.0 as i32,
124                    )
125                };
126                drop(d_out_guard);
127                CudaStorage::wrap_cuda_slice(d_out, dev)
128            }
129            DType::I32 => {
130                let d_out = unsafe { dev.alloc::<i32>(elem_count) }?;
131                let (d_out_ptr, d_out_guard) = d_out.device_ptr(d_out.stream());
132                unsafe {
133                    ffi::leftshift_i32(
134                        d_in1_ptr,
135                        d_out_ptr as *mut std::ffi::c_void,
136                        u32::try_from(elem_count)?,
137                        self.0 as i32,
138                    )
139                };
140                drop(d_out_guard);
141                CudaStorage::wrap_cuda_slice(d_out, dev)
142            }
143            _ => unreachable!(),
144        };
145        Ok((dst, l1.shape().clone()))
146    }
147
148    #[cfg(feature = "metal")]
149    fn metal_fwd(
150        &self,
151        s1: &hanzo_ml::MetalStorage,
152        l1: &Layout,
153    ) -> Result<(hanzo_ml::MetalStorage, Shape)> {
154        if !l1.is_contiguous() {
155            hanzo_ml::bail!("Input tensor s1 must be contiguous");
156        }
157
158        let encoder = s1.device().command_encoder()?;
159        encoder.set_label("bitwise-leftshift");
160
161        let device = s1.device();
162
163        let out_shape = l1.shape().clone();
164
165        let output = device.new_buffer(out_shape.elem_count(), s1.dtype(), "bitwise-leftshift")?;
166
167        crate::metal_kernels::call_bitwise_leftshift(
168            device.device(),
169            &encoder,
170            &crate::metal_kernels::Kernels::new(),
171            s1.dtype(),
172            s1.buffer(),
173            l1.start_offset(),
174            self.0 as u32,
175            out_shape.elem_count(),
176            &output,
177        )
178        .map_err(hanzo_ml::Error::wrap)?;
179
180        let newstorage =
181            hanzo_ml::MetalStorage::new(output, device.clone(), out_shape.elem_count(), s1.dtype());
182        Ok((newstorage, out_shape))
183    }
184}
185
/// Extension trait that adds an element-wise left shift to a tensor type.
// NOTE(review): `dead_code` is allowed here — presumably because not every build
// configuration calls this trait; confirm before removing the attribute.
#[allow(dead_code)]
pub trait LeftshiftOp {
    /// Returns a tensor holding the elements of `self` shifted left by `n`.
    fn leftshift(&self, n: usize) -> Result<Tensor>;
}
190
191impl LeftshiftOp for Tensor {
192    fn leftshift(&self, n: usize) -> Result<Tensor> {
193        self.apply_op1_no_bwd(&Leftshift(n))
194    }
195}
196
/// Selects which binary bitwise operation a `BitWise` op performs.
pub enum BitWiseBinaryOpEnum {
    And,
    Or,
    Xor,
}

impl Display for BitWiseBinaryOpEnum {
    /// Writes the variant name: `And`, `Or` or `Xor`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::And => "And",
            Self::Or => "Or",
            Self::Xor => "Xor",
        };
        f.write_str(label)
    }
}
212
/// Selects which unary bitwise operation a `BitWiseUnary` op performs.
pub enum BitWiseUnaryOpEnum {
    Not,
}

impl Display for BitWiseUnaryOpEnum {
    /// Writes the variant name (`Not`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Not => "Not",
        };
        f.write_str(label)
    }
}
224
225struct BitWise {
226    pub op: BitWiseBinaryOpEnum,
227}
228
229impl BitWise {
230    pub fn new(op: BitWiseBinaryOpEnum) -> Self {
231        Self { op }
232    }
233
234    fn bitwise<T: WithDType + BitAnd<Output = T> + BitOr<Output = T> + BitXor<Output = T>>(
235        &self,
236        vs1: &[T],
237        vs2: &[T],
238    ) -> Vec<T> {
239        vs1.into_par_iter()
240            .zip_eq(vs2)
241            .map(|(v1, v2)| match self.op {
242                BitWiseBinaryOpEnum::And => *v1 & *v2,
243                BitWiseBinaryOpEnum::Or => *v1 | *v2,
244                BitWiseBinaryOpEnum::Xor => *v1 ^ *v2,
245            })
246            .collect()
247    }
248}
249
250impl CustomOp2 for BitWise {
251    fn name(&self) -> &'static str {
252        "bitwise"
253    }
254
255    fn cpu_fwd(
256        &self,
257        s1: &CpuStorage,
258        l1: &Layout,
259        s2: &CpuStorage,
260        l2: &Layout,
261    ) -> Result<(CpuStorage, Shape)> {
262        if l1.shape() != l2.shape() || l1.stride() != l2.stride() {
263            return Err(Error::ShapeMismatchBinaryOp {
264                lhs: l1.shape().clone(),
265                rhs: l2.shape().clone(),
266                op: "bitwise-op",
267            });
268        }
269        if s1.dtype() != s2.dtype() {
270            return Err(Error::DTypeMismatchBinaryOp {
271                lhs: s1.dtype(),
272                rhs: s2.dtype(),
273                op: "bitwise-op",
274            });
275        }
276        if !l1.is_contiguous() {
277            hanzo_ml::bail!("Input tensor s1 must be contiguous");
278        }
279        if !l2.is_contiguous() {
280            hanzo_ml::bail!("Input tensor s2 must be contiguous");
281        }
282
283        match s1 {
284            CpuStorage::U8(vs1) => {
285                let vs2 = s2.as_slice::<u8>().unwrap();
286                let vs1 = match l1.contiguous_offsets() {
287                    Some((a, b)) => &vs1[a..b],
288                    None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
289                };
290                let vs2 = match l2.contiguous_offsets() {
291                    Some((a, b)) => &vs2[a..b],
292                    None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
293                };
294                let result = self.bitwise(vs1, vs2);
295                let result = CpuStorage::U8(result);
296                Ok((result, l1.shape().clone()))
297            }
298            CpuStorage::U32(vs1) => {
299                let vs2 = s2.as_slice::<u32>().unwrap();
300                let vs1 = match l1.contiguous_offsets() {
301                    Some((a, b)) => &vs1[a..b],
302                    None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
303                };
304                let vs2 = match l2.contiguous_offsets() {
305                    Some((a, b)) => &vs2[a..b],
306                    None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
307                };
308                let result = self.bitwise(vs1, vs2);
309                let result = CpuStorage::U32(result);
310                Ok((result, l1.shape().clone()))
311            }
312            CpuStorage::I64(vs1) => {
313                let vs2 = s2.as_slice::<i64>().unwrap();
314                let vs1 = match l1.contiguous_offsets() {
315                    Some((a, b)) => &vs1[a..b],
316                    None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
317                };
318                let vs2 = match l2.contiguous_offsets() {
319                    Some((a, b)) => &vs2[a..b],
320                    None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
321                };
322                let result = self.bitwise(vs1, vs2);
323                let result = CpuStorage::I64(result);
324                Ok((result, l1.shape().clone()))
325            }
326            CpuStorage::I16(vs1) => {
327                let vs2 = s2.as_slice::<i16>().unwrap();
328                let vs1 = match l1.contiguous_offsets() {
329                    Some((a, b)) => &vs1[a..b],
330                    None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
331                };
332                let vs2 = match l2.contiguous_offsets() {
333                    Some((a, b)) => &vs2[a..b],
334                    None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
335                };
336                let result = self.bitwise(vs1, vs2);
337                let result = CpuStorage::I16(result);
338                Ok((result, l1.shape().clone()))
339            }
340            CpuStorage::I32(vs1) => {
341                let vs2 = s2.as_slice::<i32>().unwrap();
342                let vs1 = match l1.contiguous_offsets() {
343                    Some((a, b)) => &vs1[a..b],
344                    None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
345                };
346                let vs2 = match l2.contiguous_offsets() {
347                    Some((a, b)) => &vs2[a..b],
348                    None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
349                };
350                let result = self.bitwise(vs1, vs2);
351                let result = CpuStorage::I32(result);
352                Ok((result, l1.shape().clone()))
353            }
354            _ => Err(Error::UnsupportedDTypeForOp(s1.dtype(), "bitwise")),
355        }
356    }
357
358    #[cfg(feature = "cuda")]
359    fn cuda_fwd(
360        &self,
361        s1: &CudaStorage,
362        l1: &Layout,
363        s2: &CudaStorage,
364        l2: &Layout,
365    ) -> Result<(CudaStorage, Shape)> {
366        if l1.shape() != l2.shape() || l1.stride() != l2.stride() {
367            return Err(Error::ShapeMismatchBinaryOp {
368                lhs: l1.shape().clone(),
369                rhs: l2.shape().clone(),
370                op: "bitwise-op",
371            });
372        }
373        if s1.dtype() != s2.dtype() {
374            return Err(Error::DTypeMismatchBinaryOp {
375                lhs: s1.dtype(),
376                rhs: s2.dtype(),
377                op: "bitwise-op",
378            });
379        }
380        if !l1.is_contiguous() {
381            hanzo_ml::bail!("Input tensor s1 must be contiguous");
382        }
383        if !l2.is_contiguous() {
384            hanzo_ml::bail!("Input tensor s2 must be contiguous");
385        }
386
387        let dev = s1.device().clone();
388        let (d_in1_ptr, d_in2_ptr, _d_in1_guard, _d_in2_guard, elem_count) = match s1.dtype() {
389            DType::U8 => {
390                let (d_in1, d_in1_guard) = slice_ptr(s1.as_cuda_slice::<u8>()?, l1.start_offset());
391                let (d_in2, d_in2_guard) = slice_ptr(s2.as_cuda_slice::<u8>()?, l2.start_offset());
392                let elem_count = l1.shape().elem_count();
393                (
394                    d_in1 as *const std::ffi::c_void,
395                    d_in2 as *const std::ffi::c_void,
396                    d_in1_guard,
397                    d_in2_guard,
398                    elem_count,
399                )
400            }
401            DType::U32 => {
402                let (d_in1, d_in1_guard) = slice_ptr(s1.as_cuda_slice::<u32>()?, l1.start_offset());
403                let (d_in2, d_in2_guard) = slice_ptr(s2.as_cuda_slice::<u32>()?, l2.start_offset());
404                let elem_count = l1.shape().elem_count();
405                (
406                    d_in1 as *const std::ffi::c_void,
407                    d_in2 as *const std::ffi::c_void,
408                    d_in1_guard,
409                    d_in2_guard,
410                    elem_count,
411                )
412            }
413            DType::I64 => {
414                let (d_in1, d_in1_guard) = slice_ptr(s1.as_cuda_slice::<i64>()?, l1.start_offset());
415                let (d_in2, d_in2_guard) = slice_ptr(s2.as_cuda_slice::<i64>()?, l2.start_offset());
416                let elem_count = l1.shape().elem_count();
417                (
418                    d_in1 as *const std::ffi::c_void,
419                    d_in2 as *const std::ffi::c_void,
420                    d_in1_guard,
421                    d_in2_guard,
422                    elem_count,
423                )
424            }
425            DType::I32 => {
426                let (d_in1, d_in1_guard) = slice_ptr(s1.as_cuda_slice::<i32>()?, l1.start_offset());
427                let (d_in2, d_in2_guard) = slice_ptr(s2.as_cuda_slice::<i32>()?, l2.start_offset());
428                let elem_count = l1.shape().elem_count();
429                (
430                    d_in1 as *const std::ffi::c_void,
431                    d_in2 as *const std::ffi::c_void,
432                    d_in1_guard,
433                    d_in2_guard,
434                    elem_count,
435                )
436            }
437            DType::I16 => {
438                let (d_in1, d_in1_guard) = slice_ptr(s1.as_cuda_slice::<i16>()?, l1.start_offset());
439                let (d_in2, d_in2_guard) = slice_ptr(s2.as_cuda_slice::<i16>()?, l2.start_offset());
440                let elem_count = l1.shape().elem_count();
441                (
442                    d_in1 as *const std::ffi::c_void,
443                    d_in2 as *const std::ffi::c_void,
444                    d_in1_guard,
445                    d_in2_guard,
446                    elem_count,
447                )
448            }
449            other => {
450                return Err(Error::UnsupportedDTypeForOp(other, "bitwise"));
451            }
452        };
453        let dst = match s1.dtype() {
454            DType::U8 => {
455                let d_out = unsafe { dev.alloc::<u8>(elem_count) }?;
456                let (d_out_ptr, d_out_guard) = d_out.device_ptr(d_out.stream());
457                unsafe {
458                    match self.op {
459                        BitWiseBinaryOpEnum::And => ffi::bitwise_and_u8(
460                            d_in1_ptr,
461                            d_in2_ptr,
462                            d_out_ptr as *mut c_void,
463                            u32::try_from(elem_count)?,
464                        ),
465                        BitWiseBinaryOpEnum::Or => ffi::bitwise_or_u8(
466                            d_in1_ptr,
467                            d_in2_ptr,
468                            d_out_ptr as *mut c_void,
469                            u32::try_from(elem_count)?,
470                        ),
471                        BitWiseBinaryOpEnum::Xor => ffi::bitwise_xor_u8(
472                            d_in1_ptr,
473                            d_in2_ptr,
474                            d_out_ptr as *mut c_void,
475                            u32::try_from(elem_count)?,
476                        ),
477                    }
478                };
479                drop(d_out_guard);
480                CudaStorage::wrap_cuda_slice(d_out, dev)
481            }
482            DType::U32 => {
483                let d_out = unsafe { dev.alloc::<u32>(elem_count) }?;
484                let (d_out_ptr, d_out_guard) = d_out.device_ptr(d_out.stream());
485                unsafe {
486                    match self.op {
487                        BitWiseBinaryOpEnum::And => ffi::bitwise_and_u32(
488                            d_in1_ptr,
489                            d_in2_ptr,
490                            d_out_ptr as *mut c_void,
491                            u32::try_from(elem_count)?,
492                        ),
493                        BitWiseBinaryOpEnum::Or => ffi::bitwise_or_u32(
494                            d_in1_ptr,
495                            d_in2_ptr,
496                            d_out_ptr as *mut c_void,
497                            u32::try_from(elem_count)?,
498                        ),
499                        BitWiseBinaryOpEnum::Xor => ffi::bitwise_xor_u32(
500                            d_in1_ptr,
501                            d_in2_ptr,
502                            d_out_ptr as *mut c_void,
503                            u32::try_from(elem_count)?,
504                        ),
505                    }
506                };
507                drop(d_out_guard);
508                CudaStorage::wrap_cuda_slice(d_out, dev)
509            }
510            DType::I64 => {
511                let d_out = unsafe { dev.alloc::<i64>(elem_count) }?;
512                let (d_out_ptr, d_out_guard) = d_out.device_ptr(d_out.stream());
513                unsafe {
514                    match self.op {
515                        BitWiseBinaryOpEnum::And => ffi::bitwise_and_i64(
516                            d_in1_ptr,
517                            d_in2_ptr,
518                            d_out_ptr as *mut c_void,
519                            u32::try_from(elem_count)?,
520                        ),
521                        BitWiseBinaryOpEnum::Or => ffi::bitwise_or_i64(
522                            d_in1_ptr,
523                            d_in2_ptr,
524                            d_out_ptr as *mut c_void,
525                            u32::try_from(elem_count)?,
526                        ),
527                        BitWiseBinaryOpEnum::Xor => ffi::bitwise_xor_i64(
528                            d_in1_ptr,
529                            d_in2_ptr,
530                            d_out_ptr as *mut c_void,
531                            u32::try_from(elem_count)?,
532                        ),
533                    }
534                };
535                drop(d_out_guard);
536                CudaStorage::wrap_cuda_slice(d_out, dev)
537            }
538            DType::I32 => {
539                let d_out = unsafe { dev.alloc::<i64>(elem_count) }?;
540                let (d_out_ptr, d_out_guard) = d_out.device_ptr(d_out.stream());
541                unsafe {
542                    match self.op {
543                        BitWiseBinaryOpEnum::And => ffi::bitwise_and_i32(
544                            d_in1_ptr,
545                            d_in2_ptr,
546                            d_out_ptr as *mut c_void,
547                            u32::try_from(elem_count)?,
548                        ),
549                        BitWiseBinaryOpEnum::Or => ffi::bitwise_or_i32(
550                            d_in1_ptr,
551                            d_in2_ptr,
552                            d_out_ptr as *mut c_void,
553                            u32::try_from(elem_count)?,
554                        ),
555                        BitWiseBinaryOpEnum::Xor => ffi::bitwise_xor_i32(
556                            d_in1_ptr,
557                            d_in2_ptr,
558                            d_out_ptr as *mut c_void,
559                            u32::try_from(elem_count)?,
560                        ),
561                    }
562                };
563                drop(d_out_guard);
564                CudaStorage::wrap_cuda_slice(d_out, dev)
565            }
566            _ => unreachable!(),
567        };
568        Ok((dst, l1.shape().clone()))
569    }
570
571    #[cfg(feature = "metal")]
572    fn metal_fwd(
573        &self,
574        s1: &hanzo_ml::MetalStorage,
575        l1: &Layout,
576        s2: &hanzo_ml::MetalStorage,
577        l2: &Layout,
578    ) -> Result<(hanzo_ml::MetalStorage, Shape)> {
579        if l1.shape() != l2.shape() || l1.stride() != l2.stride() {
580            return Err(Error::ShapeMismatchBinaryOp {
581                lhs: l1.shape().clone(),
582                rhs: l2.shape().clone(),
583                op: "bitwise-op",
584            });
585        }
586        if s1.dtype() != s2.dtype() {
587            return Err(Error::DTypeMismatchBinaryOp {
588                lhs: s1.dtype(),
589                rhs: s2.dtype(),
590                op: "bitwise-op",
591            });
592        }
593        if !l1.is_contiguous() {
594            hanzo_ml::bail!("Input tensor s1 must be contiguous");
595        }
596        if !l2.is_contiguous() {
597            hanzo_ml::bail!("Input tensor s2 must be contiguous");
598        }
599
600        let encoder = s1.device().command_encoder()?;
601        encoder.set_label("bitwise-op");
602
603        let device = s1.device();
604
605        let out_shape = l1.shape().clone();
606
607        let output = device.new_buffer(out_shape.elem_count(), s1.dtype(), "bitwise-op")?;
608
609        match self.op {
610            BitWiseBinaryOpEnum::Or => crate::metal_kernels::call_bitwise_or(
611                device.device(),
612                &encoder,
613                &crate::metal_kernels::Kernels::new(),
614                s1.dtype(),
615                s1.buffer(),
616                s2.buffer(),
617                l1.start_offset() * s1.dtype().size_in_bytes(),
618                l2.start_offset() * s2.dtype().size_in_bytes(),
619                out_shape.elem_count(),
620                &output,
621            )
622            .map_err(hanzo_ml::Error::wrap)?,
623            BitWiseBinaryOpEnum::And => crate::metal_kernels::call_bitwise_and(
624                device.device(),
625                &encoder,
626                &crate::metal_kernels::Kernels::new(),
627                s1.dtype(),
628                s1.buffer(),
629                s2.buffer(),
630                l1.start_offset() * s1.dtype().size_in_bytes(),
631                l2.start_offset() * s2.dtype().size_in_bytes(),
632                out_shape.elem_count(),
633                &output,
634            )
635            .map_err(hanzo_ml::Error::wrap)?,
636            BitWiseBinaryOpEnum::Xor => crate::metal_kernels::call_bitwise_xor(
637                device.device(),
638                &encoder,
639                &crate::metal_kernels::Kernels::new(),
640                s1.dtype(),
641                s1.buffer(),
642                s2.buffer(),
643                l1.start_offset() * s1.dtype().size_in_bytes(),
644                l2.start_offset() * s2.dtype().size_in_bytes(),
645                out_shape.elem_count(),
646                &output,
647            )
648            .map_err(hanzo_ml::Error::wrap)?,
649        }
650
651        let newstorage =
652            hanzo_ml::MetalStorage::new(output, device.clone(), out_shape.elem_count(), s1.dtype());
653        Ok((newstorage, out_shape))
654    }
655}
656
657struct BitWiseUnary {
658    pub op: BitWiseUnaryOpEnum,
659}
660
661impl BitWiseUnary {
662    pub fn new(op: BitWiseUnaryOpEnum) -> Self {
663        Self { op }
664    }
665
666    fn bitwise<T: WithDType + Not<Output = T>>(&self, vs1: &[T]) -> Vec<T> {
667        vs1.into_par_iter()
668            .map(|v1| match self.op {
669                BitWiseUnaryOpEnum::Not => !*v1,
670            })
671            .collect()
672    }
673}
674
675impl CustomOp1 for BitWiseUnary {
676    fn name(&self) -> &'static str {
677        "bitwise-unary"
678    }
679
680    fn cpu_fwd(&self, s1: &CpuStorage, l1: &Layout) -> Result<(CpuStorage, Shape)> {
681        if !l1.is_contiguous() {
682            hanzo_ml::bail!("Input tensor s1 must be contiguous");
683        }
684
685        match s1 {
686            CpuStorage::U8(vs1) => {
687                let vs1 = match l1.contiguous_offsets() {
688                    Some((a, b)) => &vs1[a..b],
689                    None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
690                };
691                let result = self.bitwise(vs1);
692                let result = CpuStorage::U8(result);
693                Ok((result, l1.shape().clone()))
694            }
695            CpuStorage::U32(vs1) => {
696                let vs1 = match l1.contiguous_offsets() {
697                    Some((a, b)) => &vs1[a..b],
698                    None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
699                };
700                let result = self.bitwise(vs1);
701                let result = CpuStorage::U32(result);
702                Ok((result, l1.shape().clone()))
703            }
704            CpuStorage::I64(vs1) => {
705                let vs1 = match l1.contiguous_offsets() {
706                    Some((a, b)) => &vs1[a..b],
707                    None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
708                };
709                let result = self.bitwise(vs1);
710                let result = CpuStorage::I64(result);
711                Ok((result, l1.shape().clone()))
712            }
713            CpuStorage::I16(vs1) => {
714                let vs1 = match l1.contiguous_offsets() {
715                    Some((a, b)) => &vs1[a..b],
716                    None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
717                };
718                let result = self.bitwise(vs1);
719                let result = CpuStorage::I16(result);
720                Ok((result, l1.shape().clone()))
721            }
722            CpuStorage::I32(vs1) => {
723                let vs1 = match l1.contiguous_offsets() {
724                    Some((a, b)) => &vs1[a..b],
725                    None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
726                };
727                let result = self.bitwise(vs1);
728                let result = CpuStorage::I32(result);
729                Ok((result, l1.shape().clone()))
730            }
731            _ => Err(Error::UnsupportedDTypeForOp(s1.dtype(), "bitwise")),
732        }
733    }
734
735    #[cfg(feature = "cuda")]
736    fn cuda_fwd(&self, _s1: &CudaStorage, _l1: &Layout) -> Result<(CudaStorage, Shape)> {
737        todo!()
738    }
739
740    #[cfg(feature = "metal")]
741    fn metal_fwd(
742        &self,
743        s1: &hanzo_ml::MetalStorage,
744        l1: &Layout,
745    ) -> Result<(hanzo_ml::MetalStorage, Shape)> {
746        if !l1.is_contiguous() {
747            hanzo_ml::bail!("Input tensor s1 must be contiguous");
748        }
749
750        let encoder = s1.device().command_encoder()?;
751        encoder.set_label("bitwise-unary-op");
752
753        let device = s1.device();
754
755        let out_shape = l1.shape().clone();
756
757        let output = device.new_buffer(out_shape.elem_count(), s1.dtype(), "bitwise-op")?;
758
759        match self.op {
760            BitWiseUnaryOpEnum::Not => crate::metal_kernels::call_bitwise_not(
761                device.device(),
762                &encoder,
763                &crate::metal_kernels::Kernels::new(),
764                s1.dtype(),
765                s1.buffer(),
766                l1.start_offset() * s1.dtype().size_in_bytes(),
767                out_shape.elem_count(),
768                &output,
769            )
770            .map_err(hanzo_ml::Error::wrap)?,
771        }
772
773        let newstorage =
774            hanzo_ml::MetalStorage::new(output, device.clone(), out_shape.elem_count(), s1.dtype());
775        Ok((newstorage, out_shape))
776    }
777}
778
/// Extension trait that adds element-wise bitwise operations to a tensor type.
// NOTE(review): `dead_code` is allowed here — presumably because not every build
// configuration calls every method; confirm before removing the attribute.
#[allow(dead_code)]
pub trait BitWiseOp {
    /// Element-wise bitwise AND of `self` and `rhs`.
    fn bitwise_and(&self, rhs: &Tensor) -> Result<Tensor>;
    /// Element-wise bitwise OR of `self` and `rhs`.
    fn bitwise_or(&self, rhs: &Tensor) -> Result<Tensor>;
    /// Element-wise bitwise XOR of `self` and `rhs`.
    fn bitwise_xor(&self, rhs: &Tensor) -> Result<Tensor>;
    /// Element-wise bitwise NOT of `self`.
    fn bitwise_not(&self) -> Result<Tensor>;
}
786
787impl BitWiseOp for Tensor {
788    fn bitwise_and(&self, rhs: &Tensor) -> Result<Tensor> {
789        self.apply_op2_no_bwd(rhs, &BitWise::new(BitWiseBinaryOpEnum::And))
790    }
791
792    fn bitwise_or(&self, rhs: &Tensor) -> Result<Tensor> {
793        self.apply_op2_no_bwd(rhs, &BitWise::new(BitWiseBinaryOpEnum::Or))
794    }
795
796    fn bitwise_xor(&self, rhs: &Tensor) -> Result<Tensor> {
797        self.apply_op2_no_bwd(rhs, &BitWise::new(BitWiseBinaryOpEnum::Xor))
798    }
799
800    fn bitwise_not(&self) -> Result<Tensor> {
801        self.apply_op1_no_bwd(&BitWiseUnary::new(BitWiseUnaryOpEnum::Not))
802    }
803}
804
805// ────────────────────────────── ArgSort / Sort ────────────────────────────────
806
/// Custom-op parameters for an **argsort**: the op yields `U32` indices rather than values.
#[allow(unused)]
struct ArgSort {
    /// Resolved (non-negative) index of the dimension to sort along.
    axis: usize,
}
812
/// Custom-op parameters for a **sort**: the op yields the re-ordered values in the input dtype.
#[allow(unused)]
struct Sort {
    /// Resolved (non-negative) index of the dimension to sort along.
    axis: usize,
}
818
819impl CustomOp1 for ArgSort {
820    fn name(&self) -> &'static str {
821        "argsort"
822    }
823
824    // -------- CPU ------------------------------------------------------------
825    fn cpu_fwd(&self, _s1: &CpuStorage, _l1: &Layout) -> Result<(CpuStorage, Shape)> {
826        hanzo_ml::bail!("ArgSort is not implemented for the CPU backend");
827    }
828
829    // -------- CUDA -----------------------------------------------------------
830    #[cfg(feature = "cuda")]
831    fn cuda_fwd(&self, _s1: &CudaStorage, _l1: &Layout) -> Result<(CudaStorage, Shape)> {
832        hanzo_ml::bail!("ArgSort is not implemented for the CUDA backend");
833    }
834
835    // -------- Metal ----------------------------------------------------------
836    #[cfg(feature = "metal")]
837    fn metal_fwd(
838        &self,
839        s1: &hanzo_ml::MetalStorage,
840        l1: &Layout,
841    ) -> Result<(hanzo_ml::MetalStorage, Shape)> {
842        // Require contiguous input (same as other metal ops in this file)
843        if !l1.is_contiguous() {
844            hanzo_ml::bail!("Input tensor s1 must be contiguous");
845        }
846
847        // Create a command encoder and label it for easy debugging in Xcode’s GPU frame‑capture
848        let encoder = s1.device().command_encoder()?;
849        encoder.set_label("argsort");
850
851        let device = s1.device();
852        let out_shape = l1.shape().clone();
853        let elem_count = out_shape.elem_count();
854
855        // Output buffer holds the sorted indices -> always `U32`
856        let output = device.new_buffer(elem_count, hanzo_ml::DType::U32, "argsort")?;
857
858        // ------------------------------------------------------------------
859        // Obtain a scratch‑buffer set from the global LRU cache (cap=4)
860        // ------------------------------------------------------------------
861        let cache = SORT_SCRATCH_CACHE.get_or_init(|| SortScratchCache::new(4));
862
863        let dims = l1.dims();
864        let size_sorted_axis = dims[self.axis];
865        let n_rows = l1.shape().elem_count() / size_sorted_axis;
866
867        // Replicate the kernel’s internal block sizing to derive `n_blocks`
868        let tn = 4usize;
869        let mut bn = match size_sorted_axis.div_ceil(tn) {
870            v if v > 256 => 512,
871            v if v > 128 => 256,
872            v if v > 64 => 128,
873            v if v > 32 => 64,
874            _ => 32,
875        };
876        if bn == 512 && s1.dtype().size_in_bytes() > 4 {
877            bn = 256;
878        }
879        let n_per_block = bn * tn;
880        let n_blocks = size_sorted_axis.div_ceil(n_per_block);
881
882        // Borrow the buffers for this launch
883        let scratch = cache.checkout(device, n_rows, size_sorted_axis, s1.dtype(), n_blocks);
884
885        // ------------------------------------------------------------------
886        // Build the unified SortArgs payload
887        // ------------------------------------------------------------------
888        let sort_args = crate::metal_kernels::SortArgs {
889            axis: self.axis,
890            shape: l1.dims(),
891            strides: l1.stride(),
892            out_shape: l1.dims(), // same as input for argsort
893            out_strides: l1.stride(),
894            in_contiguous: l1.is_contiguous(),
895            in_ty: s1.dtype(),
896            out_ty: hanzo_ml::DType::U32,
897            src: s1.buffer(),
898            src_offset: l1.start_offset(), // element offset
899            dst: &output,
900            bn,
901            tn,
902            n_blocks,
903        };
904
905        // Launch the Metal kernel via the new API
906        crate::metal_kernels::call_argsort(
907            device.device(),
908            &encoder, // impl EncoderProvider
909            &crate::metal_kernels::Kernels::new(),
910            &sort_args,
911            &scratch,
912        )
913        .map_err(hanzo_ml::Error::wrap)?;
914
915        // Wrap and return as a new MetalStorage
916        let newstorage =
917            hanzo_ml::MetalStorage::new(output, device.clone(), elem_count, hanzo_ml::DType::U32);
918        Ok((newstorage, out_shape))
919    }
920}
921
922impl CustomOp1 for Sort {
923    fn name(&self) -> &'static str {
924        "sort"
925    }
926
927    // -------- CPU ------------------------------------------------------------
928    fn cpu_fwd(&self, _s1: &CpuStorage, _l1: &Layout) -> Result<(CpuStorage, Shape)> {
929        hanzo_ml::bail!("Sort is not implemented for the CPU backend");
930    }
931
932    // -------- CUDA -----------------------------------------------------------
933    #[cfg(feature = "cuda")]
934    fn cuda_fwd(&self, _s1: &CudaStorage, _l1: &Layout) -> Result<(CudaStorage, Shape)> {
935        hanzo_ml::bail!("Sort is not implemented for the CUDA backend");
936    }
937
938    // -------- Metal ----------------------------------------------------------
939    #[cfg(feature = "metal")]
940    fn metal_fwd(
941        &self,
942        s1: &hanzo_ml::MetalStorage,
943        l1: &Layout,
944    ) -> Result<(hanzo_ml::MetalStorage, Shape)> {
945        // Require contiguous input (same as other metal ops in this file)
946        if !l1.is_contiguous() {
947            hanzo_ml::bail!("Input tensor s1 must be contiguous");
948        }
949
950        // Create a command encoder and label it for easy debugging in Xcode’s GPU frame‑capture
951        let encoder = s1.device().command_encoder()?;
952        encoder.set_label("sort");
953
954        let device = s1.device();
955        let out_shape = l1.shape().clone();
956        let elem_count = out_shape.elem_count();
957
958        // Output buffer keeps the same dtype as the input (these are the reordered values)
959        let output = device.new_buffer(elem_count, s1.dtype(), "sort")?;
960
961        // ------------------------------------------------------------------
962        // Obtain a scratch‑buffer set from the global LRU cache (cap=4)
963        // ------------------------------------------------------------------
964        let cache = SORT_SCRATCH_CACHE.get_or_init(|| SortScratchCache::new(4));
965
966        let dims = l1.dims();
967        let size_sorted_axis = dims[self.axis];
968        let n_rows = l1.shape().elem_count() / size_sorted_axis;
969
970        // Replicate the kernel’s internal block sizing to derive `n_blocks`
971        let tn = 4usize;
972        let mut bn = match size_sorted_axis.div_ceil(tn) {
973            v if v > 256 => 512,
974            v if v > 128 => 256,
975            v if v > 64 => 128,
976            v if v > 32 => 64,
977            _ => 32,
978        };
979        if bn == 512 && s1.dtype().size_in_bytes() > 4 {
980            bn = 256;
981        }
982        let n_per_block = bn * tn;
983        let n_blocks = size_sorted_axis.div_ceil(n_per_block);
984
985        // Borrow the buffers for this launch
986        let scratch = cache.checkout(device, n_rows, size_sorted_axis, s1.dtype(), n_blocks);
987
988        // ------------------------------------------------------------------
989        // Build the unified SortArgs payload
990        // ------------------------------------------------------------------
991        let sort_args = crate::metal_kernels::SortArgs {
992            axis: self.axis,
993            shape: l1.dims(),
994            strides: l1.stride(),
995            out_shape: l1.dims(), // same shape for value sort
996            out_strides: l1.stride(),
997            in_contiguous: l1.is_contiguous(),
998            in_ty: s1.dtype(),
999            out_ty: s1.dtype(),
1000            src: s1.buffer(),
1001            src_offset: l1.start_offset(), // element offset
1002            dst: &output,
1003            bn,
1004            tn,
1005            n_blocks,
1006        };
1007
1008        // Launch the Metal kernel via the new API
1009        crate::metal_kernels::call_sort(
1010            device.device(),
1011            &encoder, // impl EncoderProvider
1012            &crate::metal_kernels::Kernels::new(),
1013            &sort_args,
1014            &scratch,
1015        )
1016        .map_err(hanzo_ml::Error::wrap)?;
1017
1018        // Wrap and return as a new MetalStorage
1019        let newstorage =
1020            hanzo_ml::MetalStorage::new(output, device.clone(), elem_count, s1.dtype());
1021        Ok((newstorage, out_shape))
1022    }
1023}
1024
1025/// Extension trait adding `argsort` / `sort` convenience calls on `Tensor`.
1026pub trait SortOp {
1027    /// Returns the indices that would (ascending) sort the tensor along `axis`.
1028    fn fast_argsort_asc<D: Dim>(&self, axis: D) -> Result<Tensor>;
1029    /// Returns the tensor's values (ascending) sorted along `axis`.
1030    fn fast_sort_asc<D: Dim>(&self, axis: D) -> Result<Tensor>;
1031}
1032
1033impl SortOp for Tensor {
1034    fn fast_argsort_asc<D: Dim>(&self, axis: D) -> Result<Tensor> {
1035        if self.device().is_cpu() || self.device().is_cuda() {
1036            return self.arg_sort_last_dim(true);
1037        }
1038        self.apply_op1_no_bwd(&ArgSort {
1039            axis: axis.to_index(self.shape(), "argsort")?,
1040        })
1041    }
1042
1043    fn fast_sort_asc<D: Dim>(&self, axis: D) -> Result<Tensor> {
1044        if self.device().is_cpu() || self.device().is_cuda() {
1045            return Ok(self.sort_last_dim(true)?.0);
1046        }
1047        self.apply_op1_no_bwd(&Sort {
1048            axis: axis.to_index(self.shape(), "sort")?,
1049        })
1050    }
1051}
1052
/// Stateless custom op that lists the multi-dimensional indices of a tensor's non-zero elements.
struct NonZero;
1054
1055impl NonZero {
1056    // Sequential version
1057    fn nonzero<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Vec<u32> {
1058        let n = layout.dims().len();
1059        let mut result = Vec::new();
1060        let mut indices = vec![0u32; n];
1061        for (i, v) in vs.iter().enumerate() {
1062            if !v.is_zero() {
1063                let mut idx = i;
1064                for (dim_index, dim) in layout.dims().iter().enumerate().rev() {
1065                    let d = idx % dim;
1066                    indices[dim_index] = u32::try_from(d).unwrap();
1067                    idx /= dim;
1068                }
1069                result.extend_from_slice(&indices);
1070            }
1071        }
1072        result
1073    }
1074}
1075
#[cfg(all(feature = "cuda", not(feature = "cuda-13000")))]
/// Dtype dispatch onto the `ffi` non-zero kernels, compiled when `cuda` is enabled and
/// `cuda-13000` is not.
mod cuda_ops_cccl2 {
    use super::*;

    /// Counts the non-zero elements among the `n` elements at device pointer `d_in`.
    ///
    /// # Panics
    /// Panics (`unreachable!`) for a dtype without a matching kernel.
    pub(super) fn count_nonzero_cuda(
        dtype: hanzo_ml::DType,
        d_in: *const c_void,
        n: u32,
        stream: hanzo_ml::cuda::cudarc::driver::sys::CUstream,
    ) -> u32 {
        // SAFETY: not checked here. The caller must pass a device pointer to at least `n`
        // elements of `dtype` that stays alive for the duration of the call.
        unsafe {
            match dtype {
                hanzo_ml::DType::U8 => ffi::count_nonzero_u8(d_in, n, stream),
                hanzo_ml::DType::U32 => ffi::count_nonzero_u32(d_in, n, stream),
                hanzo_ml::DType::I64 => ffi::count_nonzero_i64(d_in, n, stream),
                hanzo_ml::DType::I16 => ffi::count_nonzero_i16(d_in, n, stream),
                hanzo_ml::DType::I32 => ffi::count_nonzero_i32(d_in, n, stream),
                hanzo_ml::DType::BF16 => ffi::count_nonzero_bf16(d_in, n, stream),
                hanzo_ml::DType::F16 => ffi::count_nonzero_f16(d_in, n, stream),
                hanzo_ml::DType::F32 => ffi::count_nonzero_f32(d_in, n, stream),
                hanzo_ml::DType::F64 => ffi::count_nonzero_f64(d_in, n, stream),
                _ => unreachable!(),
            }
        }
    }

    /// Writes the indices of the non-zero elements of `d_in` into `d_out`.
    ///
    /// `dims` is a device pointer to `num_dims` `u32` extents, and `num_nonzero` is the value
    /// previously obtained from [`count_nonzero_cuda`] for the same input.
    ///
    /// # Panics
    /// Panics (`unreachable!`) for a dtype without a matching kernel.
    #[allow(clippy::too_many_arguments)]
    pub(super) fn nonzero_cuda(
        dtype: hanzo_ml::DType,
        d_in: *const c_void,
        n: u32,
        num_nonzero: u32,
        dims: *const c_void,
        num_dims: u32,
        d_out: *mut c_void,
        stream: hanzo_ml::cuda::cudarc::driver::sys::CUstream,
    ) {
        // SAFETY: not checked here. The caller must pass live device pointers: `d_in` to `n`
        // elements of `dtype`, `dims` to `num_dims` u32 values, and `d_out` with room for
        // `num_nonzero * num_dims` u32 values.
        unsafe {
            match dtype {
                hanzo_ml::DType::U8 => {
                    ffi::nonzero_u8(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
                }
                hanzo_ml::DType::U32 => {
                    ffi::nonzero_u32(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
                }
                hanzo_ml::DType::I64 => {
                    ffi::nonzero_i64(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
                }
                // I32 must use the 32-bit kernel: the i64 kernel would read the 4-byte input as
                // 8-byte elements (wrong values and an out-of-bounds read).
                hanzo_ml::DType::I32 => {
                    ffi::nonzero_i32(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
                }
                hanzo_ml::DType::I16 => {
                    ffi::nonzero_i16(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
                }
                hanzo_ml::DType::BF16 => {
                    ffi::nonzero_bf16(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
                }
                hanzo_ml::DType::F16 => {
                    ffi::nonzero_f16(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
                }
                hanzo_ml::DType::F32 => {
                    ffi::nonzero_f32(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
                }
                hanzo_ml::DType::F64 => {
                    ffi::nonzero_f64(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
                }
                _ => unreachable!(),
            }
        }
    }
}
1147
#[cfg(all(feature = "cuda", feature = "cuda-13000"))]
/// Dtype dispatch onto the `ffi` non-zero kernels, compiled when both `cuda` and `cuda-13000`
/// are enabled.
mod cuda_ops_cccl3 {
    use super::*;

    /// Counts the non-zero elements among the `n` elements at device pointer `d_in`.
    ///
    /// # Panics
    /// Panics (`unreachable!`) for a dtype without a matching kernel.
    pub(super) fn count_nonzero_cuda(
        dtype: hanzo_ml::DType,
        d_in: *const c_void,
        n: u32,
        stream: hanzo_ml::cuda::cudarc::driver::sys::CUstream,
    ) -> u32 {
        // SAFETY: not checked here. The caller must pass a device pointer to at least `n`
        // elements of `dtype` that stays alive for the duration of the call.
        unsafe {
            match dtype {
                hanzo_ml::DType::U8 => ffi::count_nonzero_u8(d_in, n, stream),
                hanzo_ml::DType::U32 => ffi::count_nonzero_u32(d_in, n, stream),
                hanzo_ml::DType::I64 => ffi::count_nonzero_i64(d_in, n, stream),
                hanzo_ml::DType::I16 => ffi::count_nonzero_i16(d_in, n, stream),
                hanzo_ml::DType::I32 => ffi::count_nonzero_i32(d_in, n, stream),
                hanzo_ml::DType::BF16 => ffi::count_nonzero_bf16(d_in, n, stream),
                hanzo_ml::DType::F16 => ffi::count_nonzero_f16(d_in, n, stream),
                hanzo_ml::DType::F32 => ffi::count_nonzero_f32(d_in, n, stream),
                hanzo_ml::DType::F64 => ffi::count_nonzero_f64(d_in, n, stream),
                _ => unreachable!(),
            }
        }
    }

    /// Writes the indices of the non-zero elements of `d_in` into `d_out`.
    ///
    /// `dims` is a device pointer to `num_dims` `u32` extents, and `num_nonzero` is the value
    /// previously obtained from [`count_nonzero_cuda`] for the same input.
    ///
    /// # Panics
    /// Panics (`unreachable!`) for a dtype without a matching kernel.
    #[allow(clippy::too_many_arguments)]
    pub(super) fn nonzero_cuda(
        dtype: hanzo_ml::DType,
        d_in: *const c_void,
        n: u32,
        num_nonzero: u32,
        dims: *const c_void,
        num_dims: u32,
        d_out: *mut c_void,
        stream: hanzo_ml::cuda::cudarc::driver::sys::CUstream,
    ) {
        // SAFETY: not checked here. The caller must pass live device pointers: `d_in` to `n`
        // elements of `dtype`, `dims` to `num_dims` u32 values, and `d_out` with room for
        // `num_nonzero * num_dims` u32 values.
        unsafe {
            match dtype {
                hanzo_ml::DType::U8 => {
                    ffi::nonzero_u8(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
                }
                hanzo_ml::DType::U32 => {
                    ffi::nonzero_u32(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
                }
                hanzo_ml::DType::I64 => {
                    ffi::nonzero_i64(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
                }
                // I32 must use the 32-bit kernel: the i64 kernel would read the 4-byte input as
                // 8-byte elements (wrong values and an out-of-bounds read).
                hanzo_ml::DType::I32 => {
                    ffi::nonzero_i32(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
                }
                hanzo_ml::DType::I16 => {
                    ffi::nonzero_i16(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
                }
                hanzo_ml::DType::BF16 => {
                    ffi::nonzero_bf16(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
                }
                hanzo_ml::DType::F16 => {
                    ffi::nonzero_f16(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
                }
                hanzo_ml::DType::F32 => {
                    ffi::nonzero_f32(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
                }
                hanzo_ml::DType::F64 => {
                    ffi::nonzero_f64(d_in, n, num_nonzero, dims, num_dims, d_out, stream)
                }
                _ => unreachable!(),
            }
        }
    }
}
1219
1220#[cfg(all(feature = "cuda", not(feature = "cuda-13000")))]
1221use cuda_ops_cccl2::{count_nonzero_cuda, nonzero_cuda};
1222#[cfg(all(feature = "cuda", feature = "cuda-13000"))]
1223use cuda_ops_cccl3::{count_nonzero_cuda, nonzero_cuda};
1224
1225impl CustomOp1 for NonZero {
1226    fn name(&self) -> &'static str {
1227        "nonzero"
1228    }
1229
1230    fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)> {
1231        if !layout.is_contiguous() {
1232            return Err(Error::RequiresContiguous { op: "nonzero" });
1233        }
1234        let result = match storage {
1235            hanzo_ml::CpuStorage::U8(vs) => self.nonzero(vs, layout),
1236            hanzo_ml::CpuStorage::U32(vs) => self.nonzero(vs, layout),
1237            hanzo_ml::CpuStorage::I16(vs) => self.nonzero(vs, layout),
1238            hanzo_ml::CpuStorage::I32(vs) => self.nonzero(vs, layout),
1239            hanzo_ml::CpuStorage::I64(vs) => self.nonzero(vs, layout),
1240            hanzo_ml::CpuStorage::BF16(vs) => self.nonzero(vs, layout),
1241            hanzo_ml::CpuStorage::F16(vs) => self.nonzero(vs, layout),
1242            hanzo_ml::CpuStorage::F32(vs) => self.nonzero(vs, layout),
1243            hanzo_ml::CpuStorage::F64(vs) => self.nonzero(vs, layout),
1244            _ => unreachable!(),
1245        };
1246        let index_len = layout.dims().len();
1247        let result_len = result.len() / index_len;
1248        let result = CpuStorage::U32(result);
1249        let shape = Shape::from_dims(&[result_len, index_len]);
1250        Ok((result, shape))
1251    }
1252
1253    #[cfg(feature = "cuda")]
1254    fn cuda_fwd(
1255        &self,
1256        storage: &hanzo_ml::CudaStorage,
1257        layout: &Layout,
1258    ) -> Result<(hanzo_ml::CudaStorage, Shape)> {
1259        if !layout.is_contiguous() {
1260            return Err(hanzo_ml::Error::RequiresContiguous { op: "nonzero" });
1261        }
1262        let dev = storage.device().clone();
1263        let (d_in, _d_in_guard) = match storage.dtype() {
1264            hanzo_ml::DType::U8 => {
1265                let slice = storage.as_cuda_slice::<u8>()?;
1266                let (d_in, d_in_guard) = slice_ptr(slice, 0);
1267                (d_in as *const std::ffi::c_void, d_in_guard)
1268            }
1269            hanzo_ml::DType::U32 => {
1270                let slice = storage.as_cuda_slice::<u32>()?;
1271                let (d_in, d_in_guard) = slice_ptr(slice, 0);
1272                (d_in as *const std::ffi::c_void, d_in_guard)
1273            }
1274            hanzo_ml::DType::I32 => {
1275                let slice = storage.as_cuda_slice::<i32>()?;
1276                let (d_in, d_in_guard) = slice_ptr(slice, 0);
1277                (d_in as *const std::ffi::c_void, d_in_guard)
1278            }
1279            hanzo_ml::DType::I16 => {
1280                let slice = storage.as_cuda_slice::<i16>()?;
1281                let (d_in, d_in_guard) = slice_ptr(slice, 0);
1282                (d_in as *const std::ffi::c_void, d_in_guard)
1283            }
1284            hanzo_ml::DType::I64 => {
1285                let slice = storage.as_cuda_slice::<i64>()?;
1286                let (d_in, d_in_guard) = slice_ptr(slice, 0);
1287                (d_in as *const std::ffi::c_void, d_in_guard)
1288            }
1289            hanzo_ml::DType::BF16 => {
1290                let slice = storage.as_cuda_slice::<half::bf16>()?;
1291                let (d_in, d_in_guard) = slice_ptr(slice, 0);
1292                (d_in as *const std::ffi::c_void, d_in_guard)
1293            }
1294            hanzo_ml::DType::F16 => {
1295                let slice = storage.as_cuda_slice::<half::f16>()?;
1296                let (d_in, d_in_guard) = slice_ptr(slice, 0);
1297                (d_in as *const std::ffi::c_void, d_in_guard)
1298            }
1299            hanzo_ml::DType::F32 => {
1300                let slice = storage.as_cuda_slice::<f32>()?;
1301                let (d_in, d_in_guard) = slice_ptr(slice, 0);
1302                (d_in as *const std::ffi::c_void, d_in_guard)
1303            }
1304            hanzo_ml::DType::F64 => {
1305                let slice = storage.as_cuda_slice::<f64>()?;
1306                let (d_in, d_in_guard) = slice_ptr(slice, 0);
1307                (d_in as *const std::ffi::c_void, d_in_guard)
1308            }
1309            _ => unreachable!(),
1310        };
1311        let n = layout.shape().elem_count();
1312
1313        let num_nonzero = count_nonzero_cuda(
1314            storage.dtype(),
1315            d_in,
1316            u32::try_from(n)?,
1317            dev.cuda_stream().cu_stream(),
1318        );
1319        let d_out = unsafe { dev.alloc::<u32>(num_nonzero as usize * layout.dims().len()) }
1320            .map_err(|_| Error::Msg("Failed to allocate memory for nonzero result".to_string()))?;
1321        if num_nonzero != 0 {
1322            let (d_out, _d_out_guard) = d_out.device_ptr(d_out.stream());
1323            let dims = layout
1324                .dims()
1325                .iter()
1326                .map(|&x| u32::try_from(x).unwrap())
1327                .collect::<Vec<u32>>();
1328            let mut d_dims = unsafe { dev.alloc::<u32>(dims.len()) }?;
1329            dev.memcpy_htod(&dims, &mut d_dims)?;
1330            let (d_dims_ptr, _d_dims_guard) = d_dims.device_ptr(d_dims.stream());
1331            nonzero_cuda(
1332                storage.dtype(),
1333                d_in,
1334                u32::try_from(n)?,
1335                num_nonzero,
1336                d_dims_ptr as *const c_void,
1337                u32::try_from(layout.dims().len())?,
1338                d_out as *mut c_void,
1339                dev.cuda_stream().cu_stream(),
1340            );
1341        }
1342        let shape = Shape::from_dims(&[num_nonzero as usize, layout.dims().len()]);
1343        let dst = hanzo_ml::CudaStorage::wrap_cuda_slice(d_out, dev);
1344        Ok((dst, shape))
1345    }
1346}
1347
1348pub trait NonZeroOp {
1349    fn nonzero(&self) -> Result<Tensor>;
1350}
1351
1352impl NonZeroOp for Tensor {
1353    #[cfg(feature = "metal")]
1354    fn nonzero(&self) -> Result<Tensor> {
1355        if !self.is_contiguous() {
1356            return Err(hanzo_ml::Error::RequiresContiguous { op: "nonzero" });
1357        }
1358        let original_device = self.device();
1359        self.to_device(&hanzo_ml::Device::Cpu)?
1360            .apply_op1_no_bwd(&NonZero)?
1361            .to_device(original_device)
1362    }
1363
1364    #[cfg(not(feature = "metal"))]
1365    fn nonzero(&self) -> Result<Tensor> {
1366        if !self.is_contiguous() {
1367            return Err(hanzo_ml::Error::RequiresContiguous { op: "nonzero" });
1368        }
1369        self.apply_op1_no_bwd(&NonZero)
1370    }
1371}
1372
/// Custom-op parameters for a cumulative sum along one axis.
struct CumSum {
    /// Resolved (non-negative) index of the dimension to scan along.
    axis: usize,
    /// When `true`, each output includes the input element at the same position; when `false`,
    /// it is the sum of the elements preceding it in scan order.
    inclusive: bool,
    /// When `true`, the scan runs from the end of the axis toward its start.
    reverse: bool,
}
1378
1379impl CustomOp1 for CumSum {
1380    fn name(&self) -> &'static str {
1381        "cumsum"
1382    }
1383
1384    fn cpu_fwd(&self, s1: &CpuStorage, l1: &Layout) -> Result<(CpuStorage, Shape)> {
1385        use std::ops::Add;
1386        if !l1.is_contiguous() {
1387            hanzo_ml::bail!("Input tensor s1 must be contiguous");
1388        }
1389        let dims = l1.dims();
1390        let axis = self.axis;
1391        let axis_len = dims[axis];
1392        let (start, end) = l1
1393            .contiguous_offsets()
1394            .ok_or(Error::RequiresContiguous { op: "cumsum" })?;
1395
1396        // helper to execute scan for a slice of T
1397        macro_rules! scan_block {
1398            ($vt:ident, $ty:ty, $add:ident, $init:expr) => {{
1399                let vs: &[$ty] = $vt;
1400                let input = &vs[start..end];
1401                let count = input.len() / axis_len;
1402                let mut result = Vec::<$ty>::with_capacity(input.len());
1403                if !self.reverse {
1404                    if self.inclusive {
1405                        for block in 0..count {
1406                            let base = block * axis_len;
1407                            let mut sum = input[base];
1408                            result.push(sum);
1409                            for j in 1..axis_len {
1410                                sum = sum.$add(input[base + j]);
1411                                result.push(sum);
1412                            }
1413                        }
1414                    } else {
1415                        let init: $ty = $init;
1416                        for block in 0..count {
1417                            let base = block * axis_len;
1418                            let mut sum = init;
1419                            for j in 0..axis_len {
1420                                result.push(sum);
1421                                sum = sum.$add(input[base + j]);
1422                            }
1423                        }
1424                    }
1425                } else {
1426                    if self.inclusive {
1427                        for block in 0..count {
1428                            let base = block * axis_len;
1429                            let mut temp = Vec::<$ty>::with_capacity(axis_len);
1430                            let mut sum = input[base + axis_len - 1];
1431                            temp.push(sum);
1432                            for k in 1..axis_len {
1433                                let idx = axis_len - 1 - k;
1434                                sum = sum.$add(input[base + idx]);
1435                                temp.push(sum);
1436                            }
1437                            temp.reverse();
1438                            result.extend(temp);
1439                        }
1440                    } else {
1441                        let init: $ty = $init;
1442                        for block in 0..count {
1443                            let base = block * axis_len;
1444                            let mut temp = Vec::<$ty>::with_capacity(axis_len);
1445                            let mut sum = init;
1446                            for k in 0..axis_len {
1447                                let idx = axis_len - 1 - k;
1448                                temp.push(sum);
1449                                sum = sum.$add(input[base + idx]);
1450                            }
1451                            temp.reverse();
1452                            result.extend(temp);
1453                        }
1454                    }
1455                }
1456                result
1457            }};
1458        }
1459        match s1 {
1460            CpuStorage::U8(vs) => {
1461                let result = scan_block!(vs, u8, wrapping_add, 0u8);
1462                Ok((CpuStorage::U8(result), l1.shape().clone()))
1463            }
1464            CpuStorage::I16(vs) => {
1465                let result = scan_block!(vs, i16, add, 0i16);
1466                Ok((CpuStorage::I16(result), l1.shape().clone()))
1467            }
1468            CpuStorage::U32(vs) => {
1469                let result = scan_block!(vs, u32, wrapping_add, 0u32);
1470                Ok((CpuStorage::U32(result), l1.shape().clone()))
1471            }
1472            CpuStorage::I32(vs) => {
1473                let result = scan_block!(vs, i32, add, 0i32);
1474                Ok((CpuStorage::I32(result), l1.shape().clone()))
1475            }
1476            CpuStorage::I64(vs) => {
1477                let result = scan_block!(vs, i64, add, 0i64);
1478                Ok((CpuStorage::I64(result), l1.shape().clone()))
1479            }
1480            CpuStorage::F32(vs) => {
1481                let result = scan_block!(vs, f32, add, 0.0f32);
1482                Ok((CpuStorage::F32(result), l1.shape().clone()))
1483            }
1484            CpuStorage::F64(vs) => {
1485                let result = scan_block!(vs, f64, add, 0.0f64);
1486                Ok((CpuStorage::F64(result), l1.shape().clone()))
1487            }
1488            _ => Err(Error::UnsupportedDTypeForOp(DType::F32, "cumsum")),
1489        }
1490    }
1491
1492    #[cfg(feature = "cuda")]
1493    fn cuda_fwd(&self, _s1: &CudaStorage, _l1: &Layout) -> Result<(CudaStorage, Shape)> {
1494        todo!()
1495    }
1496
1497    #[cfg(feature = "metal")]
1498    fn metal_fwd(
1499        &self,
1500        s1: &hanzo_ml::MetalStorage,
1501        l1: &Layout,
1502    ) -> Result<(hanzo_ml::MetalStorage, Shape)> {
1503        use crate::metal_kernels::ScanType;
1504
1505        let encoder = s1.device().command_encoder()?;
1506        encoder.set_label("cumsum");
1507
1508        let device = s1.device();
1509
1510        let out_shape = l1.shape().clone();
1511
1512        let output = device.new_buffer(out_shape.elem_count(), s1.dtype(), "cumsum")?;
1513
1514        crate::metal_kernels::call_scan(
1515            device.device(),
1516            &encoder,
1517            &crate::metal_kernels::Kernels::new(),
1518            s1.dtype(),
1519            ScanType::Sum,
1520            s1.buffer(),
1521            l1.start_offset() * s1.dtype().size_in_bytes(),
1522            self.axis,
1523            l1.dims(),
1524            l1.stride(),
1525            self.reverse,
1526            self.inclusive,
1527            &output,
1528        )
1529        .map_err(hanzo_ml::Error::wrap)?;
1530
1531        let newstorage =
1532            hanzo_ml::MetalStorage::new(output, device.clone(), out_shape.elem_count(), s1.dtype());
1533        Ok((newstorage, out_shape))
1534    }
1535}
1536
1537#[allow(dead_code)]
1538pub trait CumSumOp {
1539    /// inclusive = false, reverse = false
1540    fn fast_cumsum<D: Dim>(&self, axis: D) -> Result<Tensor>;
1541
1542    fn fast_cumsum_config<D: Dim>(&self, axis: D, inclusive: bool, reverse: bool)
1543        -> Result<Tensor>;
1544}
1545
1546impl CumSumOp for Tensor {
1547    fn fast_cumsum<D: Dim>(&self, axis: D) -> Result<Tensor> {
1548        self.fast_cumsum_config(axis, false, false)
1549    }
1550
1551    fn fast_cumsum_config<D: Dim>(
1552        &self,
1553        axis: D,
1554        inclusive: bool,
1555        reverse: bool,
1556    ) -> Result<Tensor> {
1557        self.apply_op1_no_bwd(&CumSum {
1558            inclusive,
1559            reverse,
1560            axis: axis.to_index(self.shape(), "cumsum")?,
1561        })
1562    }
1563}
1564
/// Fused GPT-OSS SwiGLU activation (CUDA only).
///
/// Formula: output = (clamp(up, -limit, limit) + 1) * gate_clamped * sigmoid(gate_clamped * alpha)
/// where gate_clamped = min(gate, limit)
///
/// # Arguments
///
/// * `gate` / `up` - tensors of identical shape living on the same CUDA device;
///   both are made contiguous before the kernel is launched.
/// * `alpha` / `limit` - scalars forwarded unchanged to the kernel.
///
/// # Errors
///
/// Fails when the shapes differ, when the tensors are not CUDA tensors, when the
/// dtype is not F16 / BF16 / F32, or when the element count does not fit in the
/// `u32` the kernel takes.
#[cfg(feature = "cuda")]
pub fn gptoss_swiglu_fused(gate: &Tensor, up: &Tensor, alpha: f32, limit: f32) -> Result<Tensor> {
    use half::{bf16, f16};

    let gate = gate.contiguous()?;
    let up = up.contiguous()?;

    if gate.shape() != up.shape() {
        hanzo_ml::bail!(
            "gptoss_swiglu: gate and up must have same shape, got {:?} vs {:?}",
            gate.shape(),
            up.shape()
        );
    }

    let device = match gate.device() {
        hanzo_ml::Device::Cuda(dev) => dev,
        _ => hanzo_ml::bail!("gptoss_swiglu requires CUDA device"),
    };

    let n_elements = gate.elem_count();
    // The kernel takes a 32-bit element count; refuse instead of silently truncating.
    let kernel_len = match u32::try_from(n_elements) {
        Ok(len) => len,
        Err(_) => hanzo_ml::bail!(
            "gptoss_swiglu: {} elements exceed the u32 element count supported by the kernel",
            n_elements
        ),
    };
    let dtype = gate.dtype();

    let (gate_storage, gate_layout) = gate.storage_and_layout();
    let (up_storage, up_layout) = up.storage_and_layout();

    let gate_cuda = match &*gate_storage {
        hanzo_ml::Storage::Cuda(s) => s,
        _ => hanzo_ml::bail!("Expected CUDA storage for gate"),
    };
    let up_cuda = match &*up_storage {
        hanzo_ml::Storage::Cuda(s) => s,
        _ => hanzo_ml::bail!("Expected CUDA storage for up"),
    };

    // A contiguous tensor may still be a view that starts part-way into its
    // storage, so the element offset has to be honoured when taking pointers.
    let gate_offset = gate_layout.start_offset();
    let up_offset = up_layout.start_offset();

    let stream = device.cuda_stream().cu_stream();

    // One launch path shared by all supported dtypes; only the element type and
    // the FFI entry point differ.
    macro_rules! launch {
        ($ty:ty, $kernel:ident) => {{
            let output = device.alloc_zeros::<$ty>(n_elements)?;
            let gate_slice = gate_cuda.as_cuda_slice::<$ty>()?;
            let up_slice = up_cuda.as_cuda_slice::<$ty>()?;

            let (gate_ptr, _gate_guard) = slice_ptr(gate_slice, gate_offset);
            let (up_ptr, _up_guard) = slice_ptr(up_slice, up_offset);
            let (out_ptr, out_guard) = slice_ptr(&output, 0);

            // SAFETY: the pointers come from device slices that stay alive (and
            // guarded) for the duration of the call; `output` holds `n_elements`
            // values. Assumes the kernel touches at most `kernel_len` elements
            // per buffer — TODO confirm against the CUDA source.
            unsafe {
                ffi::$kernel(
                    gate_ptr as *const c_void,
                    up_ptr as *const c_void,
                    out_ptr as *mut c_void,
                    kernel_len,
                    alpha,
                    limit,
                    stream,
                );
            }

            // The guard borrows `output`; release it before moving the slice.
            drop(out_guard);
            let out_storage = CudaStorage::wrap_cuda_slice(output, device.clone());
            Ok(Tensor::from((
                hanzo_ml::Storage::Cuda(out_storage),
                gate.shape().clone(),
            )))
        }};
    }

    match dtype {
        DType::F16 => launch!(f16, gptoss_swiglu_f16),
        DType::BF16 => launch!(bf16, gptoss_swiglu_bf16),
        DType::F32 => launch!(f32, gptoss_swiglu_f32),
        _ => hanzo_ml::bail!("gptoss_swiglu: unsupported dtype {:?}", dtype),
    }
}
1693
/// Fused GPT-OSS SwiGLU for interleaved gate/up data (CUDA only).
///
/// This handles interleaved gate/up format directly, avoiding 2 tensor copies
/// from narrow().squeeze().contiguous().
///
/// Args:
///   gate_up: [N, intermediate_size, 2] - interleaved gate/up data
///   intermediate_size: must equal `gate_up.dims()[1]`
///   alpha: SwiGLU alpha parameter
///   limit: SwiGLU limit parameter
///
/// Returns: [N, intermediate_size] - activated output
///
/// # Errors
///
/// Fails when `gate_up` does not have shape `[N, intermediate_size, 2]`, when it
/// is not a CUDA tensor, when the dtype is not F16 / BF16 / F32, or when `N` or
/// `intermediate_size` do not fit in the `u32` the kernel takes.
#[cfg(feature = "cuda")]
pub fn gptoss_swiglu_interleaved(
    gate_up: &Tensor,
    intermediate_size: usize,
    alpha: f32,
    limit: f32,
) -> Result<Tensor> {
    use half::{bf16, f16};
    use std::ffi::c_void;

    let gate_up = gate_up.contiguous()?;

    let dims = gate_up.dims();
    if dims.len() != 3 || dims[2] != 2 {
        hanzo_ml::bail!(
            "gptoss_swiglu_interleaved: expected gate_up shape [N, intermediate_size, 2], got {:?}",
            dims
        );
    }
    // The kernel indexes the input with `intermediate_size`; a mismatch with the
    // real tensor extent would read outside (or mis-stride) the buffer.
    if dims[1] != intermediate_size {
        hanzo_ml::bail!(
            "gptoss_swiglu_interleaved: intermediate_size {} does not match gate_up shape {:?}",
            intermediate_size,
            dims
        );
    }

    let n = dims[0]; // num_tokens * topk
    let device = match gate_up.device() {
        hanzo_ml::Device::Cuda(dev) => dev,
        _ => hanzo_ml::bail!("gptoss_swiglu_interleaved requires CUDA device"),
    };

    // The kernel takes 32-bit extents; refuse instead of silently truncating.
    let (kernel_n, kernel_inter) = match (u32::try_from(n), u32::try_from(intermediate_size)) {
        (Ok(rows), Ok(cols)) => (rows, cols),
        _ => hanzo_ml::bail!(
            "gptoss_swiglu_interleaved: shape {:?} exceeds the u32 extents supported by the kernel",
            dims
        ),
    };

    let dtype = gate_up.dtype();
    let n_output_elements = n * intermediate_size;

    let (gate_up_storage, gate_up_layout) = gate_up.storage_and_layout();
    let gate_up_cuda = match &*gate_up_storage {
        hanzo_ml::Storage::Cuda(s) => s,
        _ => hanzo_ml::bail!("Expected CUDA storage for gate_up"),
    };
    // A contiguous tensor may still be a view that starts part-way into its
    // storage, so the element offset has to be honoured when taking the pointer.
    let gate_up_offset = gate_up_layout.start_offset();

    let stream = device.cuda_stream().cu_stream();

    // One launch path shared by all supported dtypes; only the element type and
    // the FFI entry point differ.
    macro_rules! launch {
        ($ty:ty, $kernel:ident) => {{
            let output = device.alloc_zeros::<$ty>(n_output_elements)?;
            let gate_up_slice = gate_up_cuda.as_cuda_slice::<$ty>()?;

            let (gate_up_ptr, _gate_up_guard) = slice_ptr(gate_up_slice, gate_up_offset);
            let (out_ptr, out_guard) = slice_ptr(&output, 0);

            // SAFETY: both pointers come from device slices that stay alive (and
            // guarded) for the duration of the call; the input holds
            // `n * intermediate_size * 2` values (checked above) and `output`
            // holds `n * intermediate_size`. Assumes the kernel stays within
            // those extents — TODO confirm against the CUDA source.
            unsafe {
                ffi::$kernel(
                    gate_up_ptr as *const c_void,
                    out_ptr as *mut c_void,
                    kernel_n,
                    kernel_inter,
                    alpha,
                    limit,
                    stream,
                );
            }

            // The guard borrows `output`; release it before moving the slice.
            drop(out_guard);
            let out_storage = CudaStorage::wrap_cuda_slice(output, device.clone());
            Ok(Tensor::from((
                hanzo_ml::Storage::Cuda(out_storage),
                Shape::from(vec![n, intermediate_size]),
            )))
        }};
    }

    match dtype {
        DType::F16 => launch!(f16, gptoss_swiglu_interleaved_f16),
        DType::BF16 => launch!(bf16, gptoss_swiglu_interleaved_bf16),
        DType::F32 => launch!(f32, gptoss_swiglu_interleaved_f32),
        _ => hanzo_ml::bail!("gptoss_swiglu_interleaved: unsupported dtype {:?}", dtype),
    }
}
1824
/// Custom-op state for the fused softmax with sinks used by GPT-OSS attention.
///
/// The op is applied to the attention logits (the tensor the op is run on) and
/// computes a softmax over rows of `k_len` values. Each row is normalised
/// together with the "sink" value of its head: the sink takes part in the
/// max / denominator of the softmax but has no entry in the output, so the
/// output has the same shape as the logits.
///
/// Instances are built by [`softmax_with_sinks`], which applies the optional
/// mask, makes both tensors contiguous and derives the sizes below from a
/// `[batch, heads, q_len, k_len]` logits tensor.
struct SoftmaxWithSinks {
    /// Per-head sink values; the entry for head `h` is used for every row that
    /// belongs to head `h`. Read with the same element type as the logits on
    /// the CPU and CUDA paths.
    sinks: Tensor,
    /// Number of attention heads (dimension 1 of the logits).
    num_heads: usize,
    /// Number of query positions (dimension 2 of the logits).
    q_len: usize,
    /// Number of key positions, i.e. the length of each softmax row.
    k_len: usize,
}
1842
1843impl CustomOp1 for SoftmaxWithSinks {
1844    fn name(&self) -> &'static str {
1845        "softmax-with-sinks"
1846    }
1847
1848    fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)> {
1849        use half::{bf16, f16};
1850
1851        let out_shape = layout.shape().clone();
1852        let total_rows = out_shape.elem_count() / self.k_len;
1853        let k_len = self.k_len;
1854        let num_heads = self.num_heads;
1855        let q_len = self.q_len;
1856        let offset = layout.start_offset();
1857
1858        let sinks_data = self.sinks.storage_and_layout();
1859        let sinks_cpu = match &*sinks_data.0 {
1860            hanzo_ml::Storage::Cpu(s) => s,
1861            _ => hanzo_ml::bail!("softmax_with_sinks cpu_fwd: sinks must be on CPU"),
1862        };
1863        let sinks_offset = sinks_data.1.start_offset();
1864
1865        match storage.dtype() {
1866            DType::F32 => {
1867                let logits = storage.as_slice::<f32>()?;
1868                let sinks_vals = sinks_cpu.as_slice::<f32>()?;
1869
1870                let mut result = vec![0f32; total_rows * k_len];
1871                result
1872                    .par_chunks_mut(k_len)
1873                    .enumerate()
1874                    .for_each(|(row, out_row)| {
1875                        let h = (row / q_len) % num_heads;
1876                        let sink_val = sinks_vals[sinks_offset + h];
1877                        let row_start = offset + row * k_len;
1878
1879                        let mut max_val = sink_val;
1880                        for k in 0..k_len {
1881                            let v = logits[row_start + k];
1882                            if v > max_val {
1883                                max_val = v;
1884                            }
1885                        }
1886
1887                        let mut sum = (sink_val - max_val).exp();
1888                        for k in 0..k_len {
1889                            let e = (logits[row_start + k] - max_val).exp();
1890                            out_row[k] = e;
1891                            sum += e;
1892                        }
1893
1894                        let inv_sum = 1.0 / sum;
1895                        for item in out_row.iter_mut().take(k_len) {
1896                            *item *= inv_sum;
1897                        }
1898                    });
1899
1900                Ok((CpuStorage::F32(result), out_shape))
1901            }
1902            DType::F16 => {
1903                let logits = storage.as_slice::<f16>()?;
1904                let sinks_vals = sinks_cpu.as_slice::<f16>()?;
1905
1906                let mut result = vec![f16::ZERO; total_rows * k_len];
1907                result
1908                    .par_chunks_mut(k_len)
1909                    .enumerate()
1910                    .for_each(|(row, out_row)| {
1911                        let h = (row / q_len) % num_heads;
1912                        let sink_val = sinks_vals[sinks_offset + h].to_f32();
1913                        let row_start = offset + row * k_len;
1914
1915                        let mut max_val = sink_val;
1916                        for k in 0..k_len {
1917                            let v = logits[row_start + k].to_f32();
1918                            if v > max_val {
1919                                max_val = v;
1920                            }
1921                        }
1922
1923                        let mut sum = (sink_val - max_val).exp();
1924                        for k in 0..k_len {
1925                            let e = (logits[row_start + k].to_f32() - max_val).exp();
1926                            out_row[k] = f16::from_f32(e);
1927                            sum += e;
1928                        }
1929
1930                        let inv_sum = 1.0f32 / sum;
1931                        for item in out_row.iter_mut().take(k_len) {
1932                            *item = f16::from_f32(item.to_f32() * inv_sum);
1933                        }
1934                    });
1935
1936                Ok((CpuStorage::F16(result), out_shape))
1937            }
1938            DType::BF16 => {
1939                let logits = storage.as_slice::<bf16>()?;
1940                let sinks_vals = sinks_cpu.as_slice::<bf16>()?;
1941
1942                let mut result = vec![bf16::ZERO; total_rows * k_len];
1943                result
1944                    .par_chunks_mut(k_len)
1945                    .enumerate()
1946                    .for_each(|(row, out_row)| {
1947                        let h = (row / q_len) % num_heads;
1948                        let sink_val = sinks_vals[sinks_offset + h].to_f32();
1949                        let row_start = offset + row * k_len;
1950
1951                        let mut max_val = sink_val;
1952                        for k in 0..k_len {
1953                            let v = logits[row_start + k].to_f32();
1954                            if v > max_val {
1955                                max_val = v;
1956                            }
1957                        }
1958
1959                        let mut sum = (sink_val - max_val).exp();
1960                        for k in 0..k_len {
1961                            let e = (logits[row_start + k].to_f32() - max_val).exp();
1962                            out_row[k] = bf16::from_f32(e);
1963                            sum += e;
1964                        }
1965
1966                        let inv_sum = 1.0f32 / sum;
1967                        for item in out_row.iter_mut().take(k_len) {
1968                            *item = bf16::from_f32(item.to_f32() * inv_sum);
1969                        }
1970                    });
1971
1972                Ok((CpuStorage::BF16(result), out_shape))
1973            }
1974            other => hanzo_ml::bail!("softmax_with_sinks: unsupported dtype {:?}", other),
1975        }
1976    }
1977
1978    #[cfg(feature = "cuda")]
1979    fn cuda_fwd(&self, storage: &CudaStorage, layout: &Layout) -> Result<(CudaStorage, Shape)> {
1980        use half::{bf16, f16};
1981
1982        let device = storage.device();
1983        let dtype = storage.dtype();
1984        let n_elements = layout.shape().elem_count();
1985        let out_shape = layout.shape().clone();
1986        let stream = device.cuda_stream().cu_stream();
1987        let logits_offset = layout.start_offset();
1988
1989        let batch_size = out_shape.dims()[0];
1990
1991        let sinks_data = self.sinks.storage_and_layout();
1992        let sinks_cuda = match &*sinks_data.0 {
1993            hanzo_ml::Storage::Cuda(s) => s,
1994            _ => hanzo_ml::bail!("softmax_with_sinks cuda_fwd: sinks must be on CUDA"),
1995        };
1996        let sinks_offset = sinks_data.1.start_offset();
1997
1998        match dtype {
1999            DType::F16 => {
2000                let output = device.alloc_zeros::<f16>(n_elements)?;
2001                let logits_slice = storage.as_cuda_slice::<f16>()?;
2002                let sinks_slice = sinks_cuda.as_cuda_slice::<f16>()?;
2003
2004                let (logits_ptr, _l_guard) = slice_ptr(logits_slice, logits_offset);
2005                let (sinks_ptr, _s_guard) = slice_ptr(sinks_slice, sinks_offset);
2006                let (out_ptr, _o_guard) = slice_ptr(&output, 0);
2007
2008                unsafe {
2009                    ffi::softmax_with_sinks_f16(
2010                        logits_ptr as *const c_void,
2011                        sinks_ptr as *const c_void,
2012                        std::ptr::null(), // mask pre-applied
2013                        out_ptr as *mut c_void,
2014                        batch_size as i32,
2015                        self.num_heads as i32,
2016                        self.q_len as i32,
2017                        self.k_len as i32,
2018                        1.0,
2019                        stream,
2020                    );
2021                }
2022
2023                drop(_o_guard);
2024                let out_storage = CudaStorage::wrap_cuda_slice(output, device.clone());
2025                Ok((out_storage, out_shape))
2026            }
2027            DType::BF16 => {
2028                let output = device.alloc_zeros::<bf16>(n_elements)?;
2029                let logits_slice = storage.as_cuda_slice::<bf16>()?;
2030                let sinks_slice = sinks_cuda.as_cuda_slice::<bf16>()?;
2031
2032                let (logits_ptr, _l_guard) = slice_ptr(logits_slice, logits_offset);
2033                let (sinks_ptr, _s_guard) = slice_ptr(sinks_slice, sinks_offset);
2034                let (out_ptr, _o_guard) = slice_ptr(&output, 0);
2035
2036                unsafe {
2037                    ffi::softmax_with_sinks_bf16(
2038                        logits_ptr as *const c_void,
2039                        sinks_ptr as *const c_void,
2040                        std::ptr::null(),
2041                        out_ptr as *mut c_void,
2042                        batch_size as i32,
2043                        self.num_heads as i32,
2044                        self.q_len as i32,
2045                        self.k_len as i32,
2046                        1.0,
2047                        stream,
2048                    );
2049                }
2050
2051                drop(_o_guard);
2052                let out_storage = CudaStorage::wrap_cuda_slice(output, device.clone());
2053                Ok((out_storage, out_shape))
2054            }
2055            DType::F32 => {
2056                let output = device.alloc_zeros::<f32>(n_elements)?;
2057                let logits_slice = storage.as_cuda_slice::<f32>()?;
2058                let sinks_slice = sinks_cuda.as_cuda_slice::<f32>()?;
2059
2060                let (logits_ptr, _l_guard) = slice_ptr(logits_slice, logits_offset);
2061                let (sinks_ptr, _s_guard) = slice_ptr(sinks_slice, sinks_offset);
2062                let (out_ptr, _o_guard) = slice_ptr(&output, 0);
2063
2064                unsafe {
2065                    ffi::softmax_with_sinks_f32(
2066                        logits_ptr as *const c_void,
2067                        sinks_ptr as *const c_void,
2068                        std::ptr::null(),
2069                        out_ptr as *mut c_void,
2070                        batch_size as i32,
2071                        self.num_heads as i32,
2072                        self.q_len as i32,
2073                        self.k_len as i32,
2074                        1.0,
2075                        stream,
2076                    );
2077                }
2078
2079                drop(_o_guard);
2080                let out_storage = CudaStorage::wrap_cuda_slice(output, device.clone());
2081                Ok((out_storage, out_shape))
2082            }
2083            _ => hanzo_ml::bail!("softmax_with_sinks: unsupported dtype {:?}", dtype),
2084        }
2085    }
2086
2087    #[cfg(feature = "metal")]
2088    fn metal_fwd(
2089        &self,
2090        storage: &hanzo_ml::MetalStorage,
2091        layout: &Layout,
2092    ) -> Result<(hanzo_ml::MetalStorage, Shape)> {
2093        let dtype = storage.dtype();
2094        let n_elements = layout.shape().elem_count();
2095        let out_shape = layout.shape().clone();
2096        let total_rows = n_elements / self.k_len;
2097
2098        let device = storage.device();
2099        let encoder = device.command_encoder()?;
2100        encoder.set_label("softmax-with-sinks");
2101
2102        let output = device.new_buffer(n_elements, dtype, "softmax-with-sinks-output")?;
2103
2104        let sinks_data = self.sinks.storage_and_layout();
2105        let sinks_metal = match &*sinks_data.0 {
2106            hanzo_ml::Storage::Metal(s) => s,
2107            _ => hanzo_ml::bail!("softmax_with_sinks metal_fwd: sinks must be on Metal"),
2108        };
2109        let sinks_offset = sinks_data.1.start_offset() * self.sinks.dtype().size_in_bytes();
2110
2111        crate::metal_kernels::call_softmax_with_sinks(
2112            device.device(),
2113            &encoder,
2114            &crate::metal_kernels::Kernels::new(),
2115            dtype,
2116            storage.buffer(),
2117            layout.start_offset() * dtype.size_in_bytes(),
2118            sinks_metal.buffer(),
2119            sinks_offset,
2120            &output,
2121            self.num_heads as u32,
2122            self.q_len as u32,
2123            self.k_len as u32,
2124            total_rows,
2125        )
2126        .map_err(hanzo_ml::Error::wrap)?;
2127
2128        let newstorage = hanzo_ml::MetalStorage::new(output, device.clone(), n_elements, dtype);
2129        Ok((newstorage, out_shape))
2130    }
2131}
2132
2133pub fn softmax_with_sinks(
2134    logits: &Tensor,
2135    sinks: &Tensor,
2136    mask: Option<&Tensor>,
2137) -> Result<Tensor> {
2138    let logits = if let Some(mask) = mask {
2139        logits.broadcast_add(mask)?
2140    } else {
2141        logits.clone()
2142    };
2143    let logits = logits.contiguous()?;
2144    let sinks = sinks.contiguous()?;
2145
2146    let dims = logits.dims();
2147    if dims.len() != 4 {
2148        hanzo_ml::bail!(
2149            "softmax_with_sinks: expected logits to have 4 dims [b, h, q, k], got {:?}",
2150            dims
2151        );
2152    }
2153
2154    let num_heads = dims[1];
2155    let q_len = dims[2];
2156    let k_len = dims[3];
2157
2158    if sinks.dims() != [num_heads] {
2159        hanzo_ml::bail!(
2160            "softmax_with_sinks: expected sinks shape [{}], got {:?}",
2161            num_heads,
2162            sinks.dims()
2163        );
2164    }
2165
2166    logits.apply_op1_no_bwd(&SoftmaxWithSinks {
2167        sinks: sinks.clone(),
2168        num_heads,
2169        q_len,
2170        k_len,
2171    })
2172}
2173
2174// ============================================================================
2175// Fused flash attention with sinks (Metal)
2176// ============================================================================
2177
/// Custom-op state for fused attention with per-head sinks on Metal.
///
/// The op is applied to the query tensor; key, value and sinks travel along as
/// fields. Only the Metal backend is implemented — the CPU path returns an
/// error. Instances are built by [`flash_attn_sinks_metal`].
#[allow(dead_code)]
struct FlashAttnSinksMetal {
    /// Key tensor; must be a 4-dimensional Metal tensor.
    key: Tensor,
    /// Value tensor; must be a Metal tensor.
    value: Tensor,
    sinks: Tensor, // [num_heads], always f32
    /// Softmax scaling factor forwarded to the kernels.
    softmax_scale: f32,
    /// Sliding-window size; only forwarded to the prefill kernel (q_len != 1).
    window_size: usize,
}
2186
2187impl CustomOp1 for FlashAttnSinksMetal {
2188    fn name(&self) -> &'static str {
2189        "flash-attn-sinks-metal"
2190    }
2191
2192    fn cpu_fwd(&self, _storage: &CpuStorage, _layout: &Layout) -> Result<(CpuStorage, Shape)> {
2193        hanzo_ml::bail!("flash_attn_sinks_metal: no CPU support, use softmax_with_sinks fallback")
2194    }
2195
2196    #[cfg(feature = "metal")]
2197    fn metal_fwd(
2198        &self,
2199        q_storage: &hanzo_ml::MetalStorage,
2200        q_layout: &Layout,
2201    ) -> Result<(hanzo_ml::MetalStorage, Shape)> {
2202        let dtype = q_storage.dtype();
2203        let out_shape = q_layout.shape().clone();
2204        let (batch_size, num_heads, q_len, head_dim) = q_layout.shape().dims4()?;
2205
2206        // Extract K storage
2207        let (k_s, k_l) = self.key.storage_and_layout();
2208        let k_metal = match &*k_s {
2209            hanzo_ml::Storage::Metal(s) => s,
2210            _ => hanzo_ml::bail!("flash_attn_sinks_metal: key must be a Metal tensor"),
2211        };
2212        let (_, num_kv_heads, k_len, _) = k_l.shape().dims4()?;
2213
2214        // Extract V storage
2215        let (v_s, v_l) = self.value.storage_and_layout();
2216        let v_metal = match &*v_s {
2217            hanzo_ml::Storage::Metal(s) => s,
2218            _ => hanzo_ml::bail!("flash_attn_sinks_metal: value must be a Metal tensor"),
2219        };
2220
2221        // Extract sinks storage
2222        let (s_s, s_l) = self.sinks.storage_and_layout();
2223        let sinks_metal = match &*s_s {
2224            hanzo_ml::Storage::Metal(s) => s,
2225            _ => hanzo_ml::bail!("flash_attn_sinks_metal: sinks must be a Metal tensor"),
2226        };
2227        let sinks_offset = s_l.start_offset() * self.sinks.dtype().size_in_bytes();
2228
2229        let device = q_storage.device();
2230        let elem_count = out_shape.elem_count();
2231        let output = device.new_buffer(elem_count, dtype, "flash-attn-sinks-output")?;
2232
2233        let encoder = device.command_encoder()?;
2234        encoder.set_label("flash-attn-sinks");
2235
2236        let kernels = crate::metal_kernels::Kernels::new();
2237
2238        let q_offset = q_layout.start_offset() * dtype.size_in_bytes();
2239        let k_offset = k_l.start_offset() * dtype.size_in_bytes();
2240        let v_offset = v_l.start_offset() * dtype.size_in_bytes();
2241
2242        if q_len == 1 {
2243            // Decode path: use sdpa_vector_with_sinks
2244            let gqa_factor = (num_heads / num_kv_heads) as i32;
2245            let b = batch_size * num_heads;
2246
2247            // k_stride and v_stride: stride between consecutive KV positions in the head dimension
2248            // For contiguous [B, Hkv, S, D] layout: stride between KV heads = S * D
2249            let k_stride = k_l.stride()[1]; // stride for kv_head dim (= k_len * head_dim)
2250            let v_stride = v_l.stride()[1];
2251
2252            let two_pass_threshold = 1024;
2253            if k_len >= two_pass_threshold {
2254                // Two-pass for long contexts
2255                let blocks: usize = 32;
2256                let intermediate = device.new_buffer(
2257                    b * blocks * head_dim,
2258                    DType::F32,
2259                    "sdpa-sinks-intermediate",
2260                )?;
2261                let sums = device.new_buffer(b * blocks, DType::F32, "sdpa-sinks-sums")?;
2262                let maxs = device.new_buffer(b * blocks, DType::F32, "sdpa-sinks-maxs")?;
2263
2264                crate::metal_kernels::call_sdpa_vector_with_sinks_2pass(
2265                    device.device(),
2266                    &encoder,
2267                    &kernels,
2268                    dtype,
2269                    q_storage.buffer(),
2270                    q_offset,
2271                    k_metal.buffer(),
2272                    k_offset,
2273                    v_metal.buffer(),
2274                    v_offset,
2275                    sinks_metal.buffer(),
2276                    sinks_offset,
2277                    &output,
2278                    &intermediate,
2279                    &sums,
2280                    &maxs,
2281                    head_dim,
2282                    gqa_factor,
2283                    k_len as i32,
2284                    k_stride,
2285                    v_stride,
2286                    self.softmax_scale,
2287                    b,
2288                )
2289                .map_err(hanzo_ml::Error::wrap)?;
2290            } else {
2291                // Single-pass
2292                crate::metal_kernels::call_sdpa_vector_with_sinks(
2293                    device.device(),
2294                    &encoder,
2295                    &kernels,
2296                    dtype,
2297                    q_storage.buffer(),
2298                    q_offset,
2299                    k_metal.buffer(),
2300                    k_offset,
2301                    v_metal.buffer(),
2302                    v_offset,
2303                    sinks_metal.buffer(),
2304                    sinks_offset,
2305                    &output,
2306                    head_dim,
2307                    gqa_factor,
2308                    k_len as i32,
2309                    k_stride,
2310                    v_stride,
2311                    self.softmax_scale,
2312                    b,
2313                )
2314                .map_err(hanzo_ml::Error::wrap)?;
2315            }
2316        } else {
2317            // Prefill path: use flash_attn_sinks_kernel
2318            crate::metal_kernels::call_flash_attn_sinks_prefill(
2319                device.device(),
2320                &encoder,
2321                &kernels,
2322                dtype,
2323                q_storage.buffer(),
2324                q_offset,
2325                k_metal.buffer(),
2326                k_offset,
2327                v_metal.buffer(),
2328                v_offset,
2329                sinks_metal.buffer(),
2330                sinks_offset,
2331                &output,
2332                self.softmax_scale,
2333                batch_size,
2334                q_len,
2335                k_len,
2336                num_heads,
2337                num_kv_heads,
2338                head_dim,
2339                self.window_size,
2340            )
2341            .map_err(hanzo_ml::Error::wrap)?;
2342        }
2343
2344        let newstorage = hanzo_ml::MetalStorage::new(output, device.clone(), elem_count, dtype);
2345        Ok((newstorage, out_shape))
2346    }
2347}
2348
2349/// Fused flash attention with per-head sinks for Metal devices.
2350///
2351/// Uses fused Metal kernels that compute Q·K^T -> softmax_with_sinks -> ×V
2352/// without materializing the N×N attention matrix. Per-head sinks contribute
2353/// to the softmax denominator without an associated value contribution.
2354///
2355/// Causal masking is applied for prefill (q_len > 1). For decode (q_len == 1),
2356/// all K/V positions are attended to.
2357///
2358/// # Arguments
2359///
2360/// * `q` - Query tensor `[batch_size, num_heads, q_len, head_dim]`
2361/// * `k` - Key tensor `[batch_size, num_kv_heads, k_len, head_dim]`
2362/// * `v` - Value tensor `[batch_size, num_kv_heads, k_len, head_dim]`
2363/// * `sinks` - Per-head sink values `[num_heads]` (will be cast to f32)
2364/// * `softmax_scale` - Scaling factor (typically `1 / sqrt(head_dim)`)
2365/// * `window_size` - Sliding window size (0 = full attention)
2366///
2367/// Returns `[batch_size, num_heads, q_len, head_dim]`
2368#[allow(clippy::too_many_arguments)]
2369pub fn flash_attn_sinks_metal(
2370    q: &Tensor,
2371    k: &Tensor,
2372    v: &Tensor,
2373    sinks: Option<&Tensor>,
2374    softmax_scale: f32,
2375    window_size: usize,
2376) -> Result<Tensor> {
2377    let q = q.contiguous()?;
2378    let k = k.contiguous()?;
2379    let v = v.contiguous()?;
2380
2381    let sinks = match sinks {
2382        Some(s) => s.to_dtype(DType::F32)?.contiguous()?,
2383        None => {
2384            // No sinks: create zeros (no effect on softmax)
2385            let num_heads = q.dim(1)?;
2386            Tensor::zeros(num_heads, DType::F32, q.device())?
2387        }
2388    };
2389
2390    let op = FlashAttnSinksMetal {
2391        key: k.clone(),
2392        value: v.clone(),
2393        sinks,
2394        softmax_scale,
2395        window_size,
2396    };
2397    q.apply_op1_no_bwd(&op)
2398}
2399
2400#[allow(dead_code)]
2401struct FlashAttnSinksVarlenMetal {
2402    key: Tensor,          // [total_kv, num_kv_heads, D]
2403    value: Tensor,        // [total_kv, num_kv_heads, D]
2404    sinks: Tensor,        // [num_heads], always f32
2405    cu_seqlens_q: Tensor, // [B+1] u32
2406    cu_seqlens_k: Tensor, // [B+1] u32
2407    softmax_scale: f32,
2408    window_size: usize,
2409}
2410
2411impl CustomOp1 for FlashAttnSinksVarlenMetal {
2412    fn name(&self) -> &'static str {
2413        "flash-attn-sinks-varlen-metal"
2414    }
2415
2416    fn cpu_fwd(&self, _storage: &CpuStorage, _layout: &Layout) -> Result<(CpuStorage, Shape)> {
2417        hanzo_ml::bail!("flash_attn_sinks_varlen_metal: no CPU support")
2418    }
2419
2420    #[cfg(feature = "metal")]
2421    fn metal_fwd(
2422        &self,
2423        q_storage: &hanzo_ml::MetalStorage,
2424        q_layout: &Layout,
2425    ) -> Result<(hanzo_ml::MetalStorage, Shape)> {
2426        let dtype = q_storage.dtype();
2427        let out_shape = q_layout.shape().clone();
2428        let (batch_size, num_heads, max_q_len, head_dim) = q_layout.shape().dims4()?;
2429
2430        // Extract K storage [total_kv, num_kv_heads, D]
2431        let (k_s, k_l) = self.key.storage_and_layout();
2432        let k_metal = match &*k_s {
2433            hanzo_ml::Storage::Metal(s) => s,
2434            _ => hanzo_ml::bail!("flash_attn_sinks_varlen_metal: key must be a Metal tensor"),
2435        };
2436        let (_, num_kv_heads, _) = k_l.shape().dims3()?;
2437
2438        // Extract V storage
2439        let (v_s, v_l) = self.value.storage_and_layout();
2440        let v_metal = match &*v_s {
2441            hanzo_ml::Storage::Metal(s) => s,
2442            _ => hanzo_ml::bail!("flash_attn_sinks_varlen_metal: value must be a Metal tensor"),
2443        };
2444
2445        // Extract sinks storage
2446        let (s_s, s_l) = self.sinks.storage_and_layout();
2447        let sinks_metal = match &*s_s {
2448            hanzo_ml::Storage::Metal(s) => s,
2449            _ => hanzo_ml::bail!("flash_attn_sinks_varlen_metal: sinks must be a Metal tensor"),
2450        };
2451        let sinks_offset = s_l.start_offset() * self.sinks.dtype().size_in_bytes();
2452
2453        // Extract cu_seqlens_q storage
2454        let (csq_s, csq_l) = self.cu_seqlens_q.storage_and_layout();
2455        let csq_metal = match &*csq_s {
2456            hanzo_ml::Storage::Metal(s) => s,
2457            _ => hanzo_ml::bail!(
2458                "flash_attn_sinks_varlen_metal: cu_seqlens_q must be a Metal tensor"
2459            ),
2460        };
2461        let csq_offset = csq_l.start_offset() * DType::U32.size_in_bytes();
2462
2463        // Extract cu_seqlens_k storage
2464        let (csk_s, csk_l) = self.cu_seqlens_k.storage_and_layout();
2465        let csk_metal = match &*csk_s {
2466            hanzo_ml::Storage::Metal(s) => s,
2467            _ => hanzo_ml::bail!(
2468                "flash_attn_sinks_varlen_metal: cu_seqlens_k must be a Metal tensor"
2469            ),
2470        };
2471        let csk_offset = csk_l.start_offset() * DType::U32.size_in_bytes();
2472
2473        let device = q_storage.device();
2474        let elem_count = out_shape.elem_count();
2475        let output = device.new_buffer(elem_count, dtype, "flash-attn-sinks-varlen-output")?;
2476
2477        let encoder = device.command_encoder()?;
2478        encoder.set_label("flash-attn-sinks-varlen");
2479
2480        let kernels = crate::metal_kernels::Kernels::new();
2481
2482        let q_offset = q_layout.start_offset() * dtype.size_in_bytes();
2483        let k_offset = k_l.start_offset() * dtype.size_in_bytes();
2484        let v_offset = v_l.start_offset() * dtype.size_in_bytes();
2485
2486        crate::metal_kernels::call_flash_attn_sinks_varlen_prefill(
2487            device.device(),
2488            &encoder,
2489            &kernels,
2490            dtype,
2491            q_storage.buffer(),
2492            q_offset,
2493            k_metal.buffer(),
2494            k_offset,
2495            v_metal.buffer(),
2496            v_offset,
2497            sinks_metal.buffer(),
2498            sinks_offset,
2499            &output,
2500            csq_metal.buffer(),
2501            csq_offset,
2502            csk_metal.buffer(),
2503            csk_offset,
2504            self.softmax_scale,
2505            batch_size,
2506            max_q_len,
2507            num_heads,
2508            num_kv_heads,
2509            head_dim,
2510            self.window_size,
2511        )
2512        .map_err(hanzo_ml::Error::wrap)?;
2513
2514        let newstorage = hanzo_ml::MetalStorage::new(output, device.clone(), elem_count, dtype);
2515        Ok((newstorage, out_shape))
2516    }
2517}
2518
2519/// Fused varlen flash attention with per-head sinks for Metal devices.
2520///
2521/// Handles variable-length sequences within a batch. Q is padded,
2522/// K/V are packed (concatenated across sequences).
2523///
2524/// # Arguments
2525///
2526/// * `q` - Query tensor `[batch_size, num_heads, max_q_len, head_dim]` (padded)
2527/// * `k` - Key tensor `[total_kv, num_kv_heads, head_dim]` (packed)
2528/// * `v` - Value tensor `[total_kv, num_kv_heads, head_dim]` (packed)
2529/// * `sinks` - Per-head sink values `[num_heads]` (will be cast to f32)
2530/// * `cu_seqlens_q` - Cumulative Q sequence lengths `[batch_size + 1]` (u32)
2531/// * `cu_seqlens_k` - Cumulative KV sequence lengths `[batch_size + 1]` (u32)
2532/// * `softmax_scale` - Scaling factor (typically `1 / sqrt(head_dim)`)
2533/// * `window_size` - Sliding window size (0 = full attention)
2534///
2535/// Returns `[batch_size, num_heads, max_q_len, head_dim]` (padding rows are zero)
2536#[allow(clippy::too_many_arguments)]
2537pub fn flash_attn_sinks_varlen_metal(
2538    q: &Tensor,
2539    k: &Tensor,
2540    v: &Tensor,
2541    sinks: Option<&Tensor>,
2542    cu_seqlens_q: &Tensor,
2543    cu_seqlens_k: &Tensor,
2544    softmax_scale: f32,
2545    window_size: usize,
2546) -> Result<Tensor> {
2547    let q = q.contiguous()?;
2548    let k = k.contiguous()?;
2549    let v = v.contiguous()?;
2550
2551    let sinks = match sinks {
2552        Some(s) => s.to_dtype(DType::F32)?.contiguous()?,
2553        None => {
2554            let num_heads = q.dim(1)?;
2555            Tensor::zeros(num_heads, DType::F32, q.device())?
2556        }
2557    };
2558
2559    let op = FlashAttnSinksVarlenMetal {
2560        key: k.clone(),
2561        value: v.clone(),
2562        sinks,
2563        cu_seqlens_q: cu_seqlens_q.clone(),
2564        cu_seqlens_k: cu_seqlens_k.clone(),
2565        softmax_scale,
2566        window_size,
2567    };
2568    q.apply_op1_no_bwd(&op)
2569}
2570
/// Activation selector for the fused GLU kernel.
///
/// The discriminants are passed to the kernels as `i32` and must stay in sync with
/// the GluActivation enum in CUDA (ops.cu) and Metal (fused_glu.metal).
#[derive(Debug, Clone, Copy)]
#[repr(i32)]
pub enum GluActivationType {
    /// SiLU, `x / (1 + exp(-x))`.
    Silu = 0,
    /// GELU, tanh approximation.
    Gelu = 1,
    /// ReLU, `max(x, 0)`.
    Relu = 2,
    /// GELU computed with `erf`.
    GeluErf = 3,
}
2581
// CPU activation functions for fused GLU

/// SiLU on a single value: `x / (1 + exp(-x))`.
fn cpu_silu(x: f32) -> f32 {
    let denominator = 1.0 + (-x).exp();
    x / denominator
}
2586
/// GELU on a single value using the tanh approximation:
/// `0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`.
fn cpu_gelu(x: f32) -> f32 {
    #[allow(clippy::excessive_precision)]
    const SQRT_2_OVER_PI: f32 = 0.7978845608;
    const COEFF: f32 = 0.044715;

    let cubed = x * x * x;
    let tanh_arg = SQRT_2_OVER_PI * (x + COEFF * cubed);
    let gate = 1.0 + tanh_arg.tanh();
    0.5 * x * gate
}
2595
/// ReLU on a single value: the larger of `x` and zero.
fn cpu_relu(x: f32) -> f32 {
    f32::max(x, 0.0)
}
2599
2600fn cpu_gelu_erf(x: f32) -> f32 {
2601    // gelu_erf: x * (1 + erf(x / sqrt(2))) / 2
2602    x * (1.0 + hanzo_ml::cpu::erf::erf_f32(x * std::f32::consts::FRAC_1_SQRT_2)) / 2.0
2603}
2604
2605fn apply_cpu_activation(x: f32, activation: GluActivationType) -> f32 {
2606    match activation {
2607        GluActivationType::Silu => cpu_silu(x),
2608        GluActivationType::Gelu => cpu_gelu(x),
2609        GluActivationType::Relu => cpu_relu(x),
2610        GluActivationType::GeluErf => cpu_gelu_erf(x),
2611    }
2612}
2613
2614struct FusedGlu(GluActivationType);
2615
2616impl CustomOp2 for FusedGlu {
2617    fn name(&self) -> &'static str {
2618        "fused-glu"
2619    }
2620
2621    fn cpu_fwd(
2622        &self,
2623        s1: &CpuStorage,
2624        l1: &Layout,
2625        s2: &CpuStorage,
2626        l2: &Layout,
2627    ) -> Result<(CpuStorage, Shape)> {
2628        use half::{bf16, f16};
2629
2630        let activation = self.0;
2631        let out_shape = l1.shape().clone();
2632        let len = out_shape.elem_count();
2633
2634        let result_storage = match s1.dtype() {
2635            DType::F32 => {
2636                let a_slice = s1.as_slice::<f32>()?;
2637                let b_slice = s2.as_slice::<f32>()?;
2638                let a_offset = l1.start_offset();
2639                let b_offset = l2.start_offset();
2640
2641                let result: Vec<f32> = (0..len)
2642                    .into_par_iter()
2643                    .map(|i| {
2644                        let a_val = a_slice[a_offset + i];
2645                        let b_val = b_slice[b_offset + i];
2646                        apply_cpu_activation(a_val, activation) * b_val
2647                    })
2648                    .collect();
2649                CpuStorage::F32(result)
2650            }
2651            DType::F16 => {
2652                let a_slice = s1.as_slice::<f16>()?;
2653                let b_slice = s2.as_slice::<f16>()?;
2654                let a_offset = l1.start_offset();
2655                let b_offset = l2.start_offset();
2656
2657                let result: Vec<f16> = (0..len)
2658                    .into_par_iter()
2659                    .map(|i| {
2660                        let a_val = a_slice[a_offset + i].to_f32();
2661                        // Cast activation back to f16 before multiplying, matching the reference
2662                        // two-step behavior: unary op in f32 -> cast to f16 -> binary mul
2663                        let activated = f16::from_f32(apply_cpu_activation(a_val, activation));
2664                        f16::from_f32(activated.to_f32() * b_slice[b_offset + i].to_f32())
2665                    })
2666                    .collect();
2667                CpuStorage::F16(result)
2668            }
2669            DType::BF16 => {
2670                let a_slice = s1.as_slice::<bf16>()?;
2671                let b_slice = s2.as_slice::<bf16>()?;
2672                let a_offset = l1.start_offset();
2673                let b_offset = l2.start_offset();
2674
2675                let result: Vec<bf16> = (0..len)
2676                    .into_par_iter()
2677                    .map(|i| {
2678                        let a_val = a_slice[a_offset + i].to_f32();
2679                        // Cast activation back to bf16 before multiplying, matching the reference
2680                        // two-step behavior: unary op in f32 -> cast to bf16 -> binary mul
2681                        let activated = bf16::from_f32(apply_cpu_activation(a_val, activation));
2682                        bf16::from_f32(activated.to_f32() * b_slice[b_offset + i].to_f32())
2683                    })
2684                    .collect();
2685                CpuStorage::BF16(result)
2686            }
2687            other => hanzo_ml::bail!("fused_glu: unsupported dtype {:?}", other),
2688        };
2689
2690        Ok((result_storage, out_shape))
2691    }
2692
2693    #[cfg(feature = "cuda")]
2694    fn cuda_fwd(
2695        &self,
2696        s1: &CudaStorage,
2697        l1: &Layout,
2698        s2: &CudaStorage,
2699        l2: &Layout,
2700    ) -> Result<(CudaStorage, Shape)> {
2701        use half::{bf16, f16};
2702
2703        let activation = self.0;
2704        let device = s1.device();
2705        let n_elements = l1.shape().elem_count();
2706        let dtype = s1.dtype();
2707        let out_shape = l1.shape().clone();
2708        let stream = device.cuda_stream().cu_stream();
2709        let a_offset = l1.start_offset();
2710        let b_offset = l2.start_offset();
2711
2712        match dtype {
2713            DType::F16 => {
2714                let output = device.alloc_zeros::<f16>(n_elements)?;
2715                let a_slice = s1.as_cuda_slice::<f16>()?;
2716                let b_slice = s2.as_cuda_slice::<f16>()?;
2717
2718                let (a_ptr, _a_guard) = slice_ptr(a_slice, a_offset);
2719                let (b_ptr, _b_guard) = slice_ptr(b_slice, b_offset);
2720                let (out_ptr, _o_guard) = slice_ptr(&output, 0);
2721
2722                unsafe {
2723                    ffi::fused_glu_f16(
2724                        a_ptr as *const c_void,
2725                        b_ptr as *const c_void,
2726                        out_ptr as *mut c_void,
2727                        n_elements as u32,
2728                        activation as i32,
2729                        stream,
2730                    );
2731                }
2732
2733                drop(_o_guard);
2734                let out_storage = CudaStorage::wrap_cuda_slice(output, device.clone());
2735                Ok((out_storage, out_shape))
2736            }
2737            DType::BF16 => {
2738                let output = device.alloc_zeros::<bf16>(n_elements)?;
2739                let a_slice = s1.as_cuda_slice::<bf16>()?;
2740                let b_slice = s2.as_cuda_slice::<bf16>()?;
2741
2742                let (a_ptr, _a_guard) = slice_ptr(a_slice, a_offset);
2743                let (b_ptr, _b_guard) = slice_ptr(b_slice, b_offset);
2744                let (out_ptr, _o_guard) = slice_ptr(&output, 0);
2745
2746                unsafe {
2747                    ffi::fused_glu_bf16(
2748                        a_ptr as *const c_void,
2749                        b_ptr as *const c_void,
2750                        out_ptr as *mut c_void,
2751                        n_elements as u32,
2752                        activation as i32,
2753                        stream,
2754                    );
2755                }
2756
2757                drop(_o_guard);
2758                let out_storage = CudaStorage::wrap_cuda_slice(output, device.clone());
2759                Ok((out_storage, out_shape))
2760            }
2761            DType::F32 => {
2762                let output = device.alloc_zeros::<f32>(n_elements)?;
2763                let a_slice = s1.as_cuda_slice::<f32>()?;
2764                let b_slice = s2.as_cuda_slice::<f32>()?;
2765
2766                let (a_ptr, _a_guard) = slice_ptr(a_slice, a_offset);
2767                let (b_ptr, _b_guard) = slice_ptr(b_slice, b_offset);
2768                let (out_ptr, _o_guard) = slice_ptr(&output, 0);
2769
2770                unsafe {
2771                    ffi::fused_glu_f32(
2772                        a_ptr as *const c_void,
2773                        b_ptr as *const c_void,
2774                        out_ptr as *mut c_void,
2775                        n_elements as u32,
2776                        activation as i32,
2777                        stream,
2778                    );
2779                }
2780
2781                drop(_o_guard);
2782                let out_storage = CudaStorage::wrap_cuda_slice(output, device.clone());
2783                Ok((out_storage, out_shape))
2784            }
2785            _ => hanzo_ml::bail!("fused_glu: unsupported dtype {:?}", dtype),
2786        }
2787    }
2788
2789    #[cfg(feature = "metal")]
2790    fn metal_fwd(
2791        &self,
2792        s1: &hanzo_ml::MetalStorage,
2793        l1: &Layout,
2794        s2: &hanzo_ml::MetalStorage,
2795        l2: &Layout,
2796    ) -> Result<(hanzo_ml::MetalStorage, Shape)> {
2797        let activation = self.0;
2798        let n_elements = l1.shape().elem_count();
2799        let dtype = s1.dtype();
2800        let out_shape = l1.shape().clone();
2801
2802        let device = s1.device();
2803        let encoder = device.command_encoder()?;
2804        encoder.set_label("fused-glu");
2805
2806        let output = device.new_buffer(n_elements, dtype, "fused-glu-output")?;
2807
2808        crate::metal_kernels::call_fused_glu(
2809            device.device(),
2810            &encoder,
2811            &crate::metal_kernels::Kernels::new(),
2812            dtype,
2813            s1.buffer(),
2814            s2.buffer(),
2815            l1.start_offset() * dtype.size_in_bytes(),
2816            l2.start_offset() * dtype.size_in_bytes(),
2817            n_elements,
2818            activation as i32,
2819            &output,
2820        )
2821        .map_err(hanzo_ml::Error::wrap)?;
2822
2823        let newstorage = hanzo_ml::MetalStorage::new(output, device.clone(), n_elements, dtype);
2824        Ok((newstorage, out_shape))
2825    }
2826}
2827
2828/// Fused GLU activation: output = activation(a) * b
2829///
2830/// This fuses the activation function application and element-wise multiplication
2831/// into a single pass, reducing memory bandwidth and eliminating
2832/// intermediate tensor allocation.
2833pub fn fused_glu(a: &Tensor, b: &Tensor, activation: GluActivationType) -> Result<Tensor> {
2834    let a = a.contiguous()?;
2835    let b = b.contiguous()?;
2836
2837    if a.shape() != b.shape() {
2838        hanzo_ml::bail!(
2839            "fused_glu: a and b must have same shape, got {:?} vs {:?}",
2840            a.shape(),
2841            b.shape()
2842        );
2843    }
2844
2845    // ROCm and Vulkan have no fused-glu kernel; decompose to eager `activation(a) * b`
2846    // (uses the backend's native unary + multiply kernels).
2847    if a.device().is_rocm() || a.device().is_vulkan() {
2848        let act = match activation {
2849            GluActivationType::Silu => a.silu()?,
2850            GluActivationType::Gelu => a.gelu()?,
2851            GluActivationType::GeluErf => a.gelu_erf()?,
2852            GluActivationType::Relu => a.relu()?,
2853        };
2854        return act.mul(&b);
2855    }
2856
2857    a.apply_op2_no_bwd(&b, &FusedGlu(activation))
2858}
2859
2860mod tests {
2861    #[test]
2862    fn test_cumsum_exclusive_forward_cpu() {
2863        use crate::utils::ops::CumSumOp;
2864        use hanzo_ml::Tensor;
2865        let device = hanzo_ml::Device::Cpu;
2866        let a = Tensor::from_vec(vec![1i64, 2, 3, 4], &[4], &device).unwrap();
2867        let b = a.fast_cumsum(0).unwrap().to_vec1::<i64>().unwrap();
2868        assert_eq!(b, [0, 1, 3, 6]);
2869    }
2870
2871    #[test]
2872    fn test_cumsum_inclusive_forward_cpu() {
2873        use crate::utils::ops::CumSumOp;
2874        use hanzo_ml::Tensor;
2875        let device = hanzo_ml::Device::Cpu;
2876        let a = Tensor::from_vec(vec![1i64, 2, 3, 4], &[4], &device).unwrap();
2877        let b = a
2878            .fast_cumsum_config(0, true, false)
2879            .unwrap()
2880            .to_vec1::<i64>()
2881            .unwrap();
2882        assert_eq!(b, [1, 3, 6, 10]);
2883    }
2884
2885    #[test]
2886    fn test_cumsum_exclusive_reverse_cpu() {
2887        use crate::utils::ops::CumSumOp;
2888        use hanzo_ml::Tensor;
2889        let device = hanzo_ml::Device::Cpu;
2890        let a = Tensor::from_vec(vec![1i64, 2, 3, 4], &[4], &device).unwrap();
2891        let b = a
2892            .fast_cumsum_config(0, false, true)
2893            .unwrap()
2894            .to_vec1::<i64>()
2895            .unwrap();
2896        assert_eq!(b, [9, 7, 4, 0]);
2897    }
2898
2899    #[test]
2900    fn test_cumsum_inclusive_reverse_cpu() {
2901        use crate::utils::ops::CumSumOp;
2902        use hanzo_ml::Tensor;
2903        let device = hanzo_ml::Device::Cpu;
2904        let a = Tensor::from_vec(vec![1i64, 2, 3, 4], &[4], &device).unwrap();
2905        let b = a
2906            .fast_cumsum_config(0, true, true)
2907            .unwrap()
2908            .to_vec1::<i64>()
2909            .unwrap();
2910        assert_eq!(b, [10, 9, 7, 4]);
2911    }
2912
2913    #[cfg(feature = "metal")]
2914    #[test]
2915    fn test_cumsum_exclusive_forward_metal() {
2916        use crate::utils::ops::CumSumOp;
2917        use hanzo_ml::Tensor;
2918        let device = hanzo_ml::Device::new_metal(0).unwrap();
2919        let a = Tensor::from_vec(vec![1i64, 2, 3, 4], &[4], &device).unwrap();
2920        let b = a.fast_cumsum(0).unwrap().to_vec1::<i64>().unwrap();
2921        assert_eq!(b, [0, 1, 3, 6]);
2922    }
2923
2924    #[cfg(feature = "metal")]
2925    #[test]
2926    fn test_cumsum_inclusive_forward_metal() {
2927        use crate::utils::ops::CumSumOp;
2928        use hanzo_ml::Tensor;
2929        let device = hanzo_ml::Device::new_metal(0).unwrap();
2930        let a = Tensor::from_vec(vec![1i64, 2, 3, 4], &[4], &device).unwrap();
2931        let b = a
2932            .fast_cumsum_config(0, true, false)
2933            .unwrap()
2934            .to_vec1::<i64>()
2935            .unwrap();
2936        assert_eq!(b, [1, 3, 6, 10]);
2937    }
2938
2939    #[cfg(feature = "metal")]
2940    #[test]
2941    fn test_cumsum_exclusive_reverse_metal() {
2942        use crate::utils::ops::CumSumOp;
2943        use hanzo_ml::Tensor;
2944        let device = hanzo_ml::Device::new_metal(0).unwrap();
2945        let a = Tensor::from_vec(vec![1i64, 2, 3, 4], &[4], &device).unwrap();
2946        let b = a
2947            .fast_cumsum_config(0, false, true)
2948            .unwrap()
2949            .to_vec1::<i64>()
2950            .unwrap();
2951        assert_eq!(b, [9, 7, 4, 0]);
2952    }
2953
2954    #[cfg(feature = "metal")]
2955    #[test]
2956    fn test_cumsum_inclusive_reverse_metal() {
2957        use crate::utils::ops::CumSumOp;
2958        use hanzo_ml::Tensor;
2959        let device = hanzo_ml::Device::new_metal(0).unwrap();
2960        let a = Tensor::from_vec(vec![1i64, 2, 3, 4], &[4], &device).unwrap();
2961        let b = a
2962            .fast_cumsum_config(0, true, true)
2963            .unwrap()
2964            .to_vec1::<i64>()
2965            .unwrap();
2966        assert_eq!(b, [10, 9, 7, 4]);
2967    }
2968
2969    #[test]
2970    fn test_nonzero_cpu() {
2971        use crate::utils::ops::NonZeroOp;
2972        use hanzo_ml::Tensor;
2973        let device = hanzo_ml::Device::Cpu;
2974        let a = Tensor::from_vec(
2975            vec![1f32, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0, 0.0],
2976            &[2, 4],
2977            &device,
2978        )
2979        .unwrap();
2980        let b = a.nonzero().unwrap().to_vec2::<u32>().unwrap();
2981        assert_eq!(b, [[0, 0], [0, 2], [1, 0], [1, 2]]);
2982    }
2983
2984    #[cfg(feature = "cuda")]
2985    #[test]
2986    fn test_nonzero_cuda() {
2987        use crate::utils::ops::NonZeroOp;
2988        use hanzo_ml::Tensor;
2989        let device = hanzo_ml::Device::new_cuda(0).unwrap();
2990        let a = Tensor::from_vec(
2991            vec![1f32, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0, 0.0],
2992            &[2, 4],
2993            &device,
2994        )
2995        .unwrap();
2996        let b = a.nonzero().unwrap().to_vec2::<u32>().unwrap();
2997        assert_eq!(b, [[0, 0], [0, 2], [1, 0], [1, 2]]);
2998    }
2999
3000    #[test]
3001    fn test_bitwise_and_cpu() {
3002        use crate::utils::ops::BitWiseOp;
3003        use hanzo_ml::Tensor;
3004        let device = hanzo_ml::Device::Cpu;
3005        let a =
3006            Tensor::from_vec(vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7], (5, 2), &device).unwrap();
3007        let b =
3008            Tensor::from_vec(vec![-1i64, 2, 3, -1, 1, -1, -1, 4, 5, 7], (5, 2), &device).unwrap();
3009        let c = a.bitwise_and(&b).unwrap().to_vec2::<i64>().unwrap();
3010        assert_eq!(c, [[1, 2], [3, -1], [1, -1], [-1, 4], [5, 7]]);
3011    }
3012
3013    #[cfg(feature = "cuda")]
3014    #[test]
3015    fn test_bitwise_and_cuda() {
3016        use crate::utils::ops::BitWiseOp;
3017        use hanzo_ml::Tensor;
3018        let device = hanzo_ml::Device::new_cuda(0).unwrap();
3019        let a =
3020            Tensor::from_vec(vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7], (5, 2), &device).unwrap();
3021        let b =
3022            Tensor::from_vec(vec![-1i64, 2, 3, -1, 1, -1, -1, 4, 0, 7], (5, 2), &device).unwrap();
3023        let c = a.bitwise_and(&b).unwrap().to_vec2::<i64>().unwrap();
3024        assert_eq!(c, [[1, 2], [3, -1], [1, -1], [-1, 4], [0, 7]]);
3025    }
3026
3027    #[test]
3028    fn test_bitwise_or_cpu() {
3029        use crate::utils::ops::BitWiseOp;
3030        use hanzo_ml::Tensor;
3031        let device = hanzo_ml::Device::Cpu;
3032        let a =
3033            Tensor::from_vec(vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7], (5, 2), &device).unwrap();
3034        let b = Tensor::from_vec(vec![-1i64, 0, 0, 0, 0, 0, 0, 0, 0, 8], (5, 2), &device).unwrap();
3035        let c = a.bitwise_or(&b).unwrap().to_vec2::<i64>().unwrap();
3036        assert_eq!(c, [[-1, 2], [3, -1], [-1, -1], [-1, 4], [5, 15]]);
3037    }
3038
3039    #[cfg(feature = "cuda")]
3040    #[test]
3041    fn test_bitwise_or_cuda() {
3042        use crate::utils::ops::BitWiseOp;
3043        use hanzo_ml::Tensor;
3044        let device = hanzo_ml::Device::new_cuda(0).unwrap();
3045        let a =
3046            Tensor::from_vec(vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7], (5, 2), &device).unwrap();
3047        let b = Tensor::from_vec(vec![-1i64, 0, 0, 0, 0, 0, 0, 0, 0, 8], (5, 2), &device).unwrap();
3048        let c = a.bitwise_or(&b).unwrap().to_vec2::<i64>().unwrap();
3049        assert_eq!(c, [[-1, 2], [3, -1], [-1, -1], [-1, 4], [5, 15]]);
3050    }
3051
3052    #[test]
3053    fn test_bitwise_xor_cpu() {
3054        use crate::utils::ops::BitWiseOp;
3055        use hanzo_ml::Tensor;
3056        let device = hanzo_ml::Device::Cpu;
3057        let a =
3058            Tensor::from_vec(vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7], (5, 2), &device).unwrap();
3059        let b = Tensor::from_vec(vec![-1i64, 0, 0, 0, 0, 0, 0, 0, 0, 8], (5, 2), &device).unwrap();
3060        let c = a.bitwise_xor(&b).unwrap().to_vec2::<i64>().unwrap();
3061        assert_eq!(c, [[-2, 2], [3, -1], [-1, -1], [-1, 4], [5, 15]]);
3062    }
3063
3064    #[cfg(feature = "cuda")]
3065    #[test]
3066    fn test_bitwise_xor_cuda() {
3067        use crate::utils::ops::BitWiseOp;
3068        use hanzo_ml::Tensor;
3069        let device = hanzo_ml::Device::new_cuda(0).unwrap();
3070        let a =
3071            Tensor::from_vec(vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7], (5, 2), &device).unwrap();
3072        let b = Tensor::from_vec(vec![-1i64, 0, 0, 0, 0, 0, 0, 0, 0, 8], (5, 2), &device).unwrap();
3073        let c = a.bitwise_xor(&b).unwrap().to_vec2::<i64>().unwrap();
3074        assert_eq!(c, [[-2, 2], [3, -1], [-1, -1], [-1, 4], [5, 15]]);
3075    }
3076
3077    #[test]
3078    fn test_nonzero_and() {
3079        use crate::utils::ops::{BitWiseOp, NonZeroOp};
3080        use hanzo_ml::{Device, Tensor};
3081
3082        let input1 = Tensor::from_vec(
3083            vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7],
3084            (10,),
3085            &Device::Cpu,
3086        )
3087        .unwrap();
3088        let input2 = Tensor::from_vec(
3089            vec![-1i64, 2, 3, -1, 1, -1, -1, 4, 5, 7],
3090            (10,),
3091            &Device::Cpu,
3092        )
3093        .unwrap();
3094        let input = Tensor::stack(&[input1, input2], 0).unwrap();
3095
3096        let lt = input.lt(0.0).unwrap();
3097        let gt = input.gt(-10.0).unwrap();
3098        let res = lt
3099            .bitwise_and(&gt)
3100            .unwrap()
3101            .nonzero()
3102            .unwrap()
3103            .to_vec2::<u32>()
3104            .unwrap();
3105
3106        assert_eq!(
3107            res,
3108            [
3109                [0, 3],
3110                [0, 4],
3111                [0, 5],
3112                [0, 6],
3113                [1, 0],
3114                [1, 3],
3115                [1, 5],
3116                [1, 6]
3117            ]
3118        );
3119    }
3120
3121    #[cfg(feature = "cuda")]
3122    #[test]
3123    fn nonzero_and_cuda() {
3124        use crate::utils::ops::{BitWiseOp, NonZeroOp};
3125        use hanzo_ml::{Device, Tensor};
3126
3127        let device = Device::new_cuda(0).unwrap();
3128        let input1 =
3129            Tensor::from_vec(vec![1i64, 2, 3, -1, -1, -1, -1, 4, 5, 7], (10,), &device).unwrap();
3130        let input2 =
3131            Tensor::from_vec(vec![-1i64, 2, 3, -1, 1, -1, -1, 4, 5, 7], (10,), &device).unwrap();
3132        let input = Tensor::stack(&[input1, input2], 0).unwrap();
3133
3134        let lt = input.lt(0.0).unwrap();
3135        let gt = input.gt(-10.0).unwrap();
3136        let res = lt
3137            .bitwise_and(&gt)
3138            .unwrap()
3139            .nonzero()
3140            .unwrap()
3141            .to_vec2::<u32>()
3142            .unwrap();
3143
3144        assert_eq!(
3145            res,
3146            [
3147                [0, 3],
3148                [0, 4],
3149                [0, 5],
3150                [0, 6],
3151                [1, 0],
3152                [1, 3],
3153                [1, 5],
3154                [1, 6]
3155            ]
3156        );
3157    }
3158
3159    #[test]
3160    fn test_bitpack_8bit_cpu() {
3161        use crate::HqqBits;
3162        use hanzo_ml::{Device, Tensor};
3163        let bits = HqqBits::Eight;
3164        let device = Device::Cpu;
3165        let wq = Tensor::from_vec(vec![257_i32, 258, 259, 260, 511, 512], (3, 2), &device).unwrap();
3166        let c = bits.bitpack_type()(wq.clone())
3167            .unwrap()
3168            .to_vec2::<u8>()
3169            .unwrap();
3170        assert_eq!(c, [[1, 2], [3, 4], [255, 0]]);
3171    }
3172
3173    #[cfg(feature = "cuda")]
3174    #[test]
3175    fn test_bitpack_8bit_cuda() {
3176        use crate::HqqBits;
3177        use hanzo_ml::{Device, Tensor};
3178        let bits = HqqBits::Eight;
3179        let device = Device::new_cuda(0).unwrap();
3180        // Use U8 tensor directly to avoid hanzo-ml's to_dtype which may not have
3181        // PTX compiled for newer GPU architectures (e.g., SM 120)
3182        let wq = Tensor::from_vec(vec![1_u8, 2, 3, 4, 255, 0], (3, 2), &device).unwrap();
3183        let c = bits.bitpack_type()(wq.clone())
3184            .unwrap()
3185            .to_vec2::<u8>()
3186            .unwrap();
3187        assert_eq!(c, [[1, 2], [3, 4], [255, 0]]);
3188    }
3189
3190    #[cfg(feature = "metal")]
3191    #[test]
3192    fn test_bitpack_8bit_metal() {
3193        use crate::HqqBits;
3194        use hanzo_ml::{Device, Tensor};
3195        let bits = HqqBits::Eight;
3196        let device = Device::new_metal(0).unwrap();
3197        let wq = Tensor::from_vec(vec![257_i32, 258, 259, 260, 511, 512], (3, 2), &device).unwrap();
3198        let c = bits.bitpack_type()(wq.clone())
3199            .unwrap()
3200            .to_vec2::<u8>()
3201            .unwrap();
3202        assert_eq!(c, [[1, 2], [3, 4], [255, 0]]);
3203    }
3204
3205    #[test]
3206    fn test_bitpack_4bit() {
3207        use crate::HqqBits;
3208        use hanzo_ml::{Device, Tensor};
3209        let bits = HqqBits::Four;
3210        let device = Device::Cpu;
3211        let wq = Tensor::from_vec(vec![1_u8, 2, 3, 4, 5, 6], (3, 2), &device).unwrap();
3212        let c = bits.bitpack_type()(wq.clone())
3213            .unwrap()
3214            .to_vec2::<u8>()
3215            .unwrap();
3216        assert_eq!(c, [[19, 36]]);
3217    }
3218
3219    #[cfg(feature = "cuda")]
3220    #[test]
3221    fn test_bitpack_4bit_cuda() {
3222        use crate::HqqBits;
3223        use hanzo_ml::{Device, Tensor};
3224        let bits = HqqBits::Four;
3225        let device = Device::new_cuda(0).unwrap();
3226        let wq = Tensor::from_vec(vec![1_u8, 2, 3, 4, 5, 6], (3, 2), &device).unwrap();
3227        let c = bits.bitpack_type()(wq.clone())
3228            .unwrap()
3229            .to_vec2::<u8>()
3230            .unwrap();
3231        assert_eq!(c, [[19, 36]]);
3232    }
3233
3234    #[cfg(feature = "metal")]
3235    #[test]
3236    fn test_bitpack_4bit_metal() {
3237        use crate::HqqBits;
3238        use hanzo_ml::{Device, Tensor};
3239        let bits = HqqBits::Four;
3240        let device = Device::new_metal(0).unwrap();
3241        let wq = Tensor::from_vec(vec![1_u8, 2, 3, 4, 5, 6], (3, 2), &device).unwrap();
3242        let c = bits.bitpack_type()(wq.clone())
3243            .unwrap()
3244            .to_vec2::<u8>()
3245            .unwrap();
3246        assert_eq!(c, [[19, 36]]);
3247    }
3248    // ─────────────────────────────── Sort / ArgSort ────────────────────────────────
3249    #[cfg(feature = "metal")]
3250    #[test]
3251    fn test_sort_and_argsort_vector_metal() {
3252        use crate::utils::ops::SortOp;
3253        use hanzo_ml::Tensor;
3254
3255        let device = hanzo_ml::Device::new_metal(0).unwrap();
3256        let a = Tensor::from_vec(vec![3i32, 1, 4, 2], &[4], &device).unwrap();
3257
3258        // sort (ascending)
3259        let sorted = a.fast_sort_asc(0).unwrap().to_vec1::<i32>().unwrap();
3260        assert_eq!(sorted, [1, 2, 3, 4]);
3261
3262        // argsort (ascending indices)
3263        let idx = a.fast_argsort_asc(0).unwrap().to_vec1::<u32>().unwrap();
3264        assert_eq!(idx, [1, 3, 0, 2]);
3265    }
3266
3267    #[cfg(feature = "metal")]
3268    #[test]
3269    fn test_sort_and_argsort_matrix_axis1_metal() {
3270        use crate::utils::ops::SortOp;
3271        use hanzo_ml::Tensor;
3272
3273        let device = hanzo_ml::Device::new_metal(0).unwrap();
3274        // 2 × 3 matrix:
3275        // [[3, 1, 2],
3276        //  [0, 4, 5]]
3277        let a = Tensor::from_vec(vec![3i32, 1, 2, 0, 4, 5], &[2, 3], &device).unwrap();
3278
3279        // Sort along axis=1 (second dimension)
3280        let sorted = a.fast_sort_asc(1).unwrap().to_vec2::<i32>().unwrap();
3281        assert_eq!(sorted, [[1, 2, 3], [0, 4, 5]]);
3282
3283        // ArgSort indices along axis=1
3284        let idx = a.fast_argsort_asc(1).unwrap().to_vec2::<u32>().unwrap();
3285        assert_eq!(idx, [[1, 2, 0], [0, 1, 2]]);
3286    }
3287
    // ─────────────────────────────── 4096-element vector ────────────────────────────────
3289    #[cfg(feature = "metal")]
3290    #[test]
3291    fn test_sort_and_argsort_vector_2048_metal() {
3292        use crate::utils::ops::SortOp;
3293        use hanzo_ml::Tensor;
3294
3295        const N: usize = 4096;
3296
3297        let device = hanzo_ml::Device::new_metal(0).expect("Metal device");
3298
3299        // Create a descending vector [4095, 4094, …, 0]
3300        let vals: Vec<i32> = (0..N as i32).rev().collect();
3301        let a = Tensor::from_vec(vals.clone(), &[N], &device).unwrap();
3302
3303        // ---- sort (ascending) ---------------------------------------------------------
3304        let sorted = a.fast_sort_asc(0).unwrap().to_vec1::<i32>().unwrap();
3305        let expected: Vec<i32> = (0..N as i32).collect();
3306        assert_eq!(sorted, expected);
3307
3308        // ---- argsort (indices that would sort) ---------------------------------------
3309        let idx = a.fast_argsort_asc(0).unwrap().to_vec1::<u32>().unwrap();
3310        // Because the input is reversed, the correct indices are likewise reversed
3311        for (i, &v) in idx.iter().enumerate() {
3312            assert_eq!(v as usize, N - 1 - i);
3313        }
3314    }
3315
3316    #[cfg(feature = "metal")]
3317    #[test]
3318    fn test_fused_glu_metal_silu_f32() {
3319        use super::{fused_glu, GluActivationType};
3320        use hanzo_ml::Tensor;
3321
3322        let cpu = hanzo_ml::Device::Cpu;
3323        let metal = hanzo_ml::Device::new_metal(0).unwrap();
3324
3325        let a_data: Vec<f32> = (0..256).map(|i| (i as f32 - 128.0) / 64.0).collect();
3326        let b_data: Vec<f32> = (0..256).map(|i| (i as f32 * 0.7 - 90.0) / 50.0).collect();
3327
3328        let a_cpu = Tensor::from_vec(a_data.clone(), &[4, 64], &cpu).unwrap();
3329        let b_cpu = Tensor::from_vec(b_data.clone(), &[4, 64], &cpu).unwrap();
3330        let a_metal = Tensor::from_vec(a_data, &[4, 64], &metal).unwrap();
3331        let b_metal = Tensor::from_vec(b_data, &[4, 64], &metal).unwrap();
3332
3333        let cpu_result = fused_glu(&a_cpu, &b_cpu, GluActivationType::Silu)
3334            .unwrap()
3335            .to_vec2::<f32>()
3336            .unwrap();
3337        let metal_result = fused_glu(&a_metal, &b_metal, GluActivationType::Silu)
3338            .unwrap()
3339            .to_device(&cpu)
3340            .unwrap()
3341            .to_vec2::<f32>()
3342            .unwrap();
3343
3344        for (row_cpu, row_metal) in cpu_result.iter().zip(metal_result.iter()) {
3345            for (c, m) in row_cpu.iter().zip(row_metal.iter()) {
3346                let diff = (c - m).abs();
3347                assert!(
3348                    diff < 1e-4,
3349                    "SiLU F32 mismatch: cpu={c}, metal={m}, diff={diff}"
3350                );
3351            }
3352        }
3353    }
3354
3355    #[cfg(feature = "metal")]
3356    #[test]
3357    fn test_fused_glu_metal_silu_f16() {
3358        use super::{fused_glu, GluActivationType};
3359        use hanzo_ml::{DType, Tensor};
3360
3361        let cpu = hanzo_ml::Device::Cpu;
3362        let metal = hanzo_ml::Device::new_metal(0).unwrap();
3363
3364        let a_data: Vec<f32> = (0..256).map(|i| (i as f32 - 128.0) / 64.0).collect();
3365        let b_data: Vec<f32> = (0..256).map(|i| (i as f32 * 0.7 - 90.0) / 50.0).collect();
3366
3367        let a_cpu = Tensor::from_vec(a_data.clone(), &[256], &cpu)
3368            .unwrap()
3369            .to_dtype(DType::F16)
3370            .unwrap();
3371        let b_cpu = Tensor::from_vec(b_data.clone(), &[256], &cpu)
3372            .unwrap()
3373            .to_dtype(DType::F16)
3374            .unwrap();
3375        let a_metal = Tensor::from_vec(a_data, &[256], &metal)
3376            .unwrap()
3377            .to_dtype(DType::F16)
3378            .unwrap();
3379        let b_metal = Tensor::from_vec(b_data, &[256], &metal)
3380            .unwrap()
3381            .to_dtype(DType::F16)
3382            .unwrap();
3383
3384        let cpu_result = fused_glu(&a_cpu, &b_cpu, GluActivationType::Silu)
3385            .unwrap()
3386            .to_dtype(DType::F32)
3387            .unwrap()
3388            .to_vec1::<f32>()
3389            .unwrap();
3390        let metal_result = fused_glu(&a_metal, &b_metal, GluActivationType::Silu)
3391            .unwrap()
3392            .to_device(&cpu)
3393            .unwrap()
3394            .to_dtype(DType::F32)
3395            .unwrap()
3396            .to_vec1::<f32>()
3397            .unwrap();
3398
3399        for (i, (c, m)) in cpu_result.iter().zip(metal_result.iter()).enumerate() {
3400            let diff = (c - m).abs();
3401            assert!(
3402                diff < 1e-2,
3403                "SiLU F16 mismatch at {i}: cpu={c}, metal={m}, diff={diff}"
3404            );
3405        }
3406    }
3407
3408    #[cfg(feature = "metal")]
3409    #[test]
3410    fn test_fused_glu_metal_all_activations() {
3411        use super::{fused_glu, GluActivationType};
3412        use hanzo_ml::Tensor;
3413
3414        let cpu = hanzo_ml::Device::Cpu;
3415        let metal = hanzo_ml::Device::new_metal(0).unwrap();
3416
3417        let a_data: Vec<f32> = (0..128).map(|i| (i as f32 - 64.0) / 32.0).collect();
3418        let b_data: Vec<f32> = (0..128).map(|i| (i as f32 * 0.5 - 32.0) / 20.0).collect();
3419
3420        for act in [
3421            GluActivationType::Silu,
3422            GluActivationType::Gelu,
3423            GluActivationType::Relu,
3424            GluActivationType::GeluErf,
3425        ] {
3426            let a_cpu = Tensor::from_vec(a_data.clone(), &[128], &cpu).unwrap();
3427            let b_cpu = Tensor::from_vec(b_data.clone(), &[128], &cpu).unwrap();
3428            let a_metal = Tensor::from_vec(a_data.clone(), &[128], &metal).unwrap();
3429            let b_metal = Tensor::from_vec(b_data.clone(), &[128], &metal).unwrap();
3430
3431            let cpu_result = fused_glu(&a_cpu, &b_cpu, act)
3432                .unwrap()
3433                .to_vec1::<f32>()
3434                .unwrap();
3435            let metal_result = fused_glu(&a_metal, &b_metal, act)
3436                .unwrap()
3437                .to_device(&cpu)
3438                .unwrap()
3439                .to_vec1::<f32>()
3440                .unwrap();
3441
3442            for (i, (c, m)) in cpu_result.iter().zip(metal_result.iter()).enumerate() {
3443                let diff = (c - m).abs();
3444                assert!(
3445                    diff < 1e-4,
3446                    "{act:?} F32 mismatch at {i}: cpu={c}, metal={m}, diff={diff}"
3447                );
3448            }
3449        }
3450    }
3451
3452    /// Test that fused_glu matches hanzo-ml's fallback path (a.gelu() * b) for BF16.
3453    /// This was the exact scenario that caused model failure (Gemma 3 4B, BF16, GeluPytorchTanh).
3454    #[cfg(feature = "metal")]
3455    #[test]
3456    fn test_fused_glu_matches_fallback_bf16() {
3457        use super::{fused_glu, GluActivationType};
3458        use hanzo_ml::{DType, Tensor};
3459
3460        let metal = hanzo_ml::Device::new_metal(0).unwrap();
3461
3462        // Use realistic-sized data matching model dimensions
3463        let n = 10240;
3464        let a_data: Vec<f32> = (0..n).map(|i| (i as f32 - 5120.0) / 2560.0).collect();
3465        let b_data: Vec<f32> = (0..n).map(|i| (i as f32 * 0.3 - 1500.0) / 1000.0).collect();
3466
3467        let a_metal = Tensor::from_vec(a_data.clone(), &[1, 2, n / 2], &metal)
3468            .unwrap()
3469            .to_dtype(DType::BF16)
3470            .unwrap();
3471        let b_metal = Tensor::from_vec(b_data.clone(), &[1, 2, n / 2], &metal)
3472            .unwrap()
3473            .to_dtype(DType::BF16)
3474            .unwrap();
3475
3476        // Fused path
3477        let fused = fused_glu(&a_metal, &b_metal, GluActivationType::Gelu).unwrap();
3478
3479        // Hanzo's fallback: a.gelu() * b (the tanh-approx GELU)
3480        let fallback = (a_metal.gelu().unwrap() * &b_metal).unwrap();
3481
3482        let fused_f32 = fused
3483            .to_dtype(DType::F32)
3484            .unwrap()
3485            .flatten_all()
3486            .unwrap()
3487            .to_vec1::<f32>()
3488            .unwrap();
3489        let fallback_f32 = fallback
3490            .to_dtype(DType::F32)
3491            .unwrap()
3492            .flatten_all()
3493            .unwrap()
3494            .to_vec1::<f32>()
3495            .unwrap();
3496
3497        let mut max_diff: f32 = 0.0;
3498        let mut num_mismatches = 0;
3499        for (f, fb) in fused_f32.iter().zip(fallback_f32.iter()) {
3500            let diff = (f - fb).abs();
3501            if diff > max_diff {
3502                max_diff = diff;
3503            }
3504            if diff > 0.0 {
3505                num_mismatches += 1;
3506            }
3507        }
3508        eprintln!(
3509            "BF16 Gelu fused vs fallback: max_diff={max_diff}, mismatches={num_mismatches}/{}",
3510            fused_f32.len()
3511        );
3512        // Allow up to 1 BF16 ULP difference (0.015625 at values around 1-2)
3513        // This is acceptable since Metal compiler may keep intermediate precision
3514        assert!(
3515            max_diff <= 0.015625,
3516            "BF16 Gelu fused vs reference fallback max_diff {max_diff} exceeds 1 BF16 ULP"
3517        );
3518    }
3519
3520    #[cfg(feature = "cuda")]
3521    #[test]
3522    fn test_fused_glu_cuda_silu_f32() {
3523        use super::{fused_glu, GluActivationType};
3524        use hanzo_ml::Tensor;
3525
3526        let cpu = hanzo_ml::Device::Cpu;
3527        let cuda = hanzo_ml::Device::new_cuda(0).unwrap();
3528
3529        let a_data: Vec<f32> = (0..256).map(|i| (i as f32 - 128.0) / 64.0).collect();
3530        let b_data: Vec<f32> = (0..256).map(|i| (i as f32 * 0.7 - 90.0) / 50.0).collect();
3531
3532        let a_cpu = Tensor::from_vec(a_data.clone(), &[4, 64], &cpu).unwrap();
3533        let b_cpu = Tensor::from_vec(b_data.clone(), &[4, 64], &cpu).unwrap();
3534        let a_cuda = Tensor::from_vec(a_data, &[4, 64], &cuda).unwrap();
3535        let b_cuda = Tensor::from_vec(b_data, &[4, 64], &cuda).unwrap();
3536
3537        let cpu_result = fused_glu(&a_cpu, &b_cpu, GluActivationType::Silu)
3538            .unwrap()
3539            .to_vec2::<f32>()
3540            .unwrap();
3541        let cuda_result = fused_glu(&a_cuda, &b_cuda, GluActivationType::Silu)
3542            .unwrap()
3543            .to_device(&cpu)
3544            .unwrap()
3545            .to_vec2::<f32>()
3546            .unwrap();
3547
3548        for (row_cpu, row_cuda) in cpu_result.iter().zip(cuda_result.iter()) {
3549            for (c, g) in row_cpu.iter().zip(row_cuda.iter()) {
3550                let diff = (c - g).abs();
3551                assert!(
3552                    diff < 1e-4,
3553                    "SiLU F32 mismatch: cpu={c}, cuda={g}, diff={diff}"
3554                );
3555            }
3556        }
3557    }
3558
3559    #[cfg(feature = "cuda")]
3560    #[test]
3561    fn test_fused_glu_cuda_silu_f16() {
3562        use super::{fused_glu, GluActivationType};
3563        use hanzo_ml::{DType, Tensor};
3564
3565        let cpu = hanzo_ml::Device::Cpu;
3566        let cuda = hanzo_ml::Device::new_cuda(0).unwrap();
3567
3568        let a_data: Vec<f32> = (0..256).map(|i| (i as f32 - 128.0) / 64.0).collect();
3569        let b_data: Vec<f32> = (0..256).map(|i| (i as f32 * 0.7 - 90.0) / 50.0).collect();
3570
3571        let a_cpu = Tensor::from_vec(a_data.clone(), &[256], &cpu)
3572            .unwrap()
3573            .to_dtype(DType::F16)
3574            .unwrap();
3575        let b_cpu = Tensor::from_vec(b_data.clone(), &[256], &cpu)
3576            .unwrap()
3577            .to_dtype(DType::F16)
3578            .unwrap();
3579        let a_cuda = Tensor::from_vec(a_data, &[256], &cuda)
3580            .unwrap()
3581            .to_dtype(DType::F16)
3582            .unwrap();
3583        let b_cuda = Tensor::from_vec(b_data, &[256], &cuda)
3584            .unwrap()
3585            .to_dtype(DType::F16)
3586            .unwrap();
3587
3588        let cpu_result = fused_glu(&a_cpu, &b_cpu, GluActivationType::Silu)
3589            .unwrap()
3590            .to_dtype(DType::F32)
3591            .unwrap()
3592            .to_vec1::<f32>()
3593            .unwrap();
3594        let cuda_result = fused_glu(&a_cuda, &b_cuda, GluActivationType::Silu)
3595            .unwrap()
3596            .to_device(&cpu)
3597            .unwrap()
3598            .to_dtype(DType::F32)
3599            .unwrap()
3600            .to_vec1::<f32>()
3601            .unwrap();
3602
3603        for (i, (c, g)) in cpu_result.iter().zip(cuda_result.iter()).enumerate() {
3604            let diff = (c - g).abs();
3605            assert!(
3606                diff < 1e-2,
3607                "SiLU F16 mismatch at {i}: cpu={c}, cuda={g}, diff={diff}"
3608            );
3609        }
3610    }
3611
3612    #[cfg(feature = "cuda")]
3613    #[test]
3614    fn test_fused_glu_cuda_all_activations() {
3615        use super::{fused_glu, GluActivationType};
3616        use hanzo_ml::Tensor;
3617
3618        let cpu = hanzo_ml::Device::Cpu;
3619        let cuda = hanzo_ml::Device::new_cuda(0).unwrap();
3620
3621        let a_data: Vec<f32> = (0..128).map(|i| (i as f32 - 64.0) / 32.0).collect();
3622        let b_data: Vec<f32> = (0..128).map(|i| (i as f32 * 0.5 - 32.0) / 20.0).collect();
3623
3624        for act in [
3625            GluActivationType::Silu,
3626            GluActivationType::Gelu,
3627            GluActivationType::Relu,
3628            GluActivationType::GeluErf,
3629        ] {
3630            let a_cpu = Tensor::from_vec(a_data.clone(), &[128], &cpu).unwrap();
3631            let b_cpu = Tensor::from_vec(b_data.clone(), &[128], &cpu).unwrap();
3632            let a_cuda = Tensor::from_vec(a_data.clone(), &[128], &cuda).unwrap();
3633            let b_cuda = Tensor::from_vec(b_data.clone(), &[128], &cuda).unwrap();
3634
3635            let cpu_result = fused_glu(&a_cpu, &b_cpu, act)
3636                .unwrap()
3637                .to_vec1::<f32>()
3638                .unwrap();
3639            let cuda_result = fused_glu(&a_cuda, &b_cuda, act)
3640                .unwrap()
3641                .to_device(&cpu)
3642                .unwrap()
3643                .to_vec1::<f32>()
3644                .unwrap();
3645
3646            for (i, (c, g)) in cpu_result.iter().zip(cuda_result.iter()).enumerate() {
3647                let diff = (c - g).abs();
3648                assert!(
3649                    diff < 1e-4,
3650                    "{act:?} F32 mismatch at {i}: cpu={c}, cuda={g}, diff={diff}"
3651                );
3652            }
3653        }
3654    }
3655
3656    #[cfg(feature = "cuda")]
3657    #[test]
3658    fn test_fused_glu_matches_fallback_bf16_cuda() {
3659        use super::{fused_glu, GluActivationType};
3660        use hanzo_ml::{DType, Tensor};
3661
3662        let cuda = hanzo_ml::Device::new_cuda(0).unwrap();
3663
3664        let n = 10240;
3665        let a_data: Vec<f32> = (0..n).map(|i| (i as f32 - 5120.0) / 2560.0).collect();
3666        let b_data: Vec<f32> = (0..n).map(|i| (i as f32 * 0.3 - 1500.0) / 1000.0).collect();
3667
3668        let a_cuda = Tensor::from_vec(a_data.clone(), &[1, 2, n / 2], &cuda)
3669            .unwrap()
3670            .to_dtype(DType::BF16)
3671            .unwrap();
3672        let b_cuda = Tensor::from_vec(b_data.clone(), &[1, 2, n / 2], &cuda)
3673            .unwrap()
3674            .to_dtype(DType::BF16)
3675            .unwrap();
3676
3677        // Fused path
3678        let fused = fused_glu(&a_cuda, &b_cuda, GluActivationType::Gelu).unwrap();
3679
3680        // Hanzo's fallback: a.gelu() * b (the tanh-approx GELU)
3681        let fallback = (a_cuda.gelu().unwrap() * &b_cuda).unwrap();
3682
3683        let fused_f32 = fused
3684            .to_dtype(DType::F32)
3685            .unwrap()
3686            .flatten_all()
3687            .unwrap()
3688            .to_vec1::<f32>()
3689            .unwrap();
3690        let fallback_f32 = fallback
3691            .to_dtype(DType::F32)
3692            .unwrap()
3693            .flatten_all()
3694            .unwrap()
3695            .to_vec1::<f32>()
3696            .unwrap();
3697
3698        let mut max_diff: f32 = 0.0;
3699        let mut num_mismatches = 0;
3700        for (f, fb) in fused_f32.iter().zip(fallback_f32.iter()) {
3701            let diff = (f - fb).abs();
3702            if diff > max_diff {
3703                max_diff = diff;
3704            }
3705            if diff > 0.0 {
3706                num_mismatches += 1;
3707            }
3708        }
3709        eprintln!(
3710            "CUDA BF16 Gelu fused vs fallback: max_diff={max_diff}, mismatches={num_mismatches}/{}",
3711            fused_f32.len()
3712        );
3713        // Allow up to 1 BF16 ULP difference (0.015625 at values around 1-2)
3714        assert!(
3715            max_diff <= 0.015625,
3716            "CUDA BF16 Gelu fused vs reference fallback max_diff {max_diff} exceeds 1 BF16 ULP"
3717        );
3718    }
3719}