1use hanzo_ml::{CpuStorage, DType, Layout, Module, Result, Shape, Tensor, D};
5use rayon::prelude::*;
6
7pub fn softmax<D: hanzo_ml::shape::Dim>(xs: &Tensor, dim: D) -> Result<Tensor> {
23 let dim = dim.to_index(xs.shape(), "softmax")?;
24 let max = xs.max_keepdim(dim)?;
25 let diff = xs.broadcast_sub(&max)?;
26 let num = diff.exp()?;
27 let den = num.sum_keepdim(dim)?;
28 num.broadcast_div(&den)
29}
30
31pub fn log_softmax<D: hanzo_ml::shape::Dim>(xs: &Tensor, d: D) -> Result<Tensor> {
32 let d = d.to_index(xs.shape(), "log-softmax")?;
33 let max = xs.max_keepdim(d)?;
34 let diff = xs.broadcast_sub(&max)?;
35 let sum_exp = diff.exp()?.sum_keepdim(d)?;
36 let log_sm = diff.broadcast_sub(&sum_exp.log()?)?;
37 Ok(log_sm)
38}
39
40pub fn silu(xs: &Tensor) -> Result<Tensor> {
41 xs.silu()
42}
43
44pub fn swiglu(xs: &Tensor) -> Result<Tensor> {
45 let xs = xs.chunk(2, D::Minus1)?;
46 &xs[0].silu()? * &xs[1]
47}
48
/// Custom op computing the element-wise logistic sigmoid `1 / (1 + exp(-x))`,
/// with CPU, CUDA and Metal forward implementations and an analytic backward pass.
struct Sigmoid;

impl hanzo_ml::CustomOp1 for Sigmoid {
    fn name(&self) -> &'static str {
        "sigmoid"
    }

    // CPU forward: supports bf16/f16/f32/f64 and rejects every other dtype.
    fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)> {
        use hanzo_ml::backend::BackendStorage;

        // Scalar sigmoid, evaluated as `(exp(-v) + 1)^-1`.
        fn fwd<T: num_traits::Float>(v: T) -> T {
            (v.neg().exp() + T::one()).recip()
        }

        // `unary_map` receives the layout, so no contiguity check is done here.
        let storage = match storage {
            CpuStorage::BF16(slice) => {
                CpuStorage::BF16(hanzo_ml::cpu_backend::unary_map(slice, layout, fwd))
            }
            CpuStorage::F16(slice) => {
                CpuStorage::F16(hanzo_ml::cpu_backend::unary_map(slice, layout, fwd))
            }
            CpuStorage::F32(slice) => {
                CpuStorage::F32(hanzo_ml::cpu_backend::unary_map(slice, layout, fwd))
            }
            CpuStorage::F64(slice) => {
                CpuStorage::F64(hanzo_ml::cpu_backend::unary_map(slice, layout, fwd))
            }
            _ => Err(hanzo_ml::Error::UnsupportedDTypeForOp(
                storage.dtype(),
                self.name(),
            ))?,
        };
        Ok((storage, layout.shape().clone()))
    }

    // CUDA forward: launches the `usigmoid` kernel from the unary kernel module,
    // one logical element per output value.
    #[cfg(feature = "cuda")]
    fn cuda_fwd(
        &self,
        storage: &hanzo_ml::CudaStorage,
        layout: &Layout,
    ) -> Result<(hanzo_ml::CudaStorage, Shape)> {
        use hanzo_ml::backend::BackendStorage;
        use hanzo_ml::cuda_backend::cudarc::driver::{
            CudaSlice, DeviceRepr, LaunchConfig, PushKernelArg, ValidAsZeroBits,
        };
        use hanzo_ml::cuda_backend::SlicePtrOrNull;
        use hanzo_ml::cuda_backend::{kernel_name, kernels, Map1, WrapErr};
        use hanzo_ml::{CudaDevice, WithDType};

        struct S;
        impl Map1 for S {
            fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
                &self,
                src: &CudaSlice<T>,
                dev: &CudaDevice,
                layout: &Layout,
            ) -> Result<CudaSlice<T>> {
                let shape = layout.shape();
                let dims = shape.dims();
                let el_count = shape.elem_count();
                let cfg = LaunchConfig::for_num_elems(el_count as u32);
                // Layout description for the kernel; presumably dims/strides, or a
                // null pointer when contiguous (see `SlicePtrOrNull`) — TODO confirm.
                let ds = SlicePtrOrNull::params_from_layout(dev, layout)?;
                let src = &src.slice(layout.start_offset()..);
                let func = dev.get_or_load_func(&kernel_name::<T>("usigmoid"), &kernels::UNARY)?;
                // SAFETY (review): the buffer is uninitialized; the kernel is expected
                // to write all `el_count` outputs before it is read.
                let out = unsafe { dev.alloc::<T>(el_count)? };

                // Argument order: element count, rank, layout info, input, output.
                let mut builder = func.builder();
                hanzo_ml::builder_arg!(builder, el_count, dims.len());
                ds.builder_arg(&mut builder);
                builder.arg(src);
                builder.arg(&out);
                unsafe { builder.launch(cfg) }.w()?;
                Ok(out)
            }
        }

        let dev = storage.device();
        let slice = S.map(&storage.slice, dev, layout)?;
        let dst = hanzo_ml::CudaStorage {
            slice,
            device: dev.clone(),
        };
        Ok((dst, layout.shape().clone()))
    }

    // Metal forward: picks the contiguous or strided sigmoid kernel depending on
    // the input layout; only f16/f32/bf16 are supported.
    #[cfg(feature = "metal")]
    fn metal_fwd(
        &self,
        storage: &hanzo_ml::MetalStorage,
        layout: &Layout,
    ) -> Result<(hanzo_ml::MetalStorage, Shape)> {
        use hanzo_ml::backend::BackendStorage;
        use hanzo_ml::MetalError;
        let device = storage.device();
        let dtype = storage.dtype();
        let shape = layout.shape();
        let el_count = shape.elem_count();
        let buffer = device.new_buffer(el_count, dtype, "sigmoid")?;
        let encoder = device.command_encoder()?;
        encoder.set_label("sigmoid");
        // The input view starts `start_offset` elements into the shared buffer.
        let src = hanzo_metal_kernels::BufferOffset {
            buffer: storage.buffer(),
            offset_in_bytes: layout.start_offset() * storage.dtype().size_in_bytes(),
        };

        if layout.is_contiguous() {
            use hanzo_metal_kernels::unary::contiguous;
            let kernel_name = match dtype {
                DType::F16 => contiguous::sigmoid::HALF,
                DType::F32 => contiguous::sigmoid::FLOAT,
                DType::BF16 => contiguous::sigmoid::BFLOAT,
                dtype => {
                    hanzo_ml::bail!("Metal contiguous unary sigmoid {dtype:?} not implemented")
                }
            };
            hanzo_metal_kernels::call_unary_contiguous(
                device.metal_device(),
                &encoder,
                device.kernels(),
                kernel_name,
                dtype.size_in_bytes(),
                el_count,
                src,
                &buffer,
            )
            .map_err(MetalError::from)?;
        } else {
            use hanzo_metal_kernels::unary::strided;
            let kernel_name = match dtype {
                DType::F16 => strided::sigmoid::HALF,
                DType::F32 => strided::sigmoid::FLOAT,
                DType::BF16 => strided::sigmoid::BFLOAT,
                dtype => {
                    hanzo_ml::bail!("Metal strided unary sigmoid {dtype:?} not implemented")
                }
            };
            let dst = hanzo_metal_kernels::BufferOffset::zero_offset(&buffer);
            hanzo_metal_kernels::call_unary_strided(
                device.metal_device(),
                &encoder,
                device.kernels(),
                kernel_name,
                layout.dims(),
                src,
                layout.stride(),
                dst,
            )
            .map_err(MetalError::from)?;
        }

        let new_storage = hanzo_ml::MetalStorage::new(buffer, device.clone(), el_count, dtype);
        Ok((new_storage, layout.shape().clone()))
    }

    // Backward pass: d/dx sigmoid(x) = s * (1 - s), where `s` is the forward
    // result `res`, so the input `_arg` is not needed.
    fn bwd(&self, _arg: &Tensor, res: &Tensor, grad_res: &Tensor) -> Result<Option<Tensor>> {
        let d_dx_sigmoid = res.ones_like()?.sub(res)?.mul(res)?;
        Ok(Some(grad_res.mul(&d_dx_sigmoid)?))
    }
}
212
213pub fn sigmoid(xs: &Tensor) -> Result<Tensor> {
214 xs.apply_op1(Sigmoid)
215}
216
217pub fn hard_sigmoid(xs: &Tensor) -> Result<Tensor> {
218 ((xs + 3.0)? / 6.0)?.clamp(0f32, 1f32)
220}
221
222pub fn mish(xs: &Tensor) -> Result<Tensor> {
223 xs * (1.0 + xs.exp()?)?.log()?.tanh()
224}
225
226pub fn leaky_relu(xs: &Tensor, negative_slope: f64) -> Result<Tensor> {
227 let zeros = xs.zeros_like()?;
228 xs.maximum(&zeros)? + xs.minimum(&zeros)? * negative_slope
229}
230
231pub fn selu(xs: &Tensor, alpha: f32, gamma: f32) -> Result<Tensor> {
232 let is_pos = xs.gt(0f32)?;
233 let alpha_t = Tensor::full(alpha, xs.dims(), xs.device())?;
234 let neg = xs.exp()?.mul(&alpha_t)?.sub(&alpha_t)?;
235 let selu = is_pos.where_cond(xs, &neg)?;
236 let gamma_t = Tensor::full(gamma, xs.dims(), xs.device())?;
237 selu.broadcast_mul(&gamma_t)
238}
239
240pub fn dropout(xs: &Tensor, drop_p: f32) -> Result<Tensor> {
241 if !(0. ..1.).contains(&drop_p) {
247 hanzo_ml::bail!("dropout probability has to be in [0, 1), got {drop_p}")
248 }
249 let rand = Tensor::rand(0f32, 1f32, xs.shape(), xs.device())?;
250 let scale = 1.0 / (1.0 - drop_p as f64);
251 let drop_p = Tensor::new(drop_p, xs.device())?.broadcast_as(xs.shape())?;
252 let mask = (rand.ge(&drop_p)?.to_dtype(xs.dtype())? * scale)?;
253 xs * mask
254}
255
256#[derive(Clone, Debug)]
257pub struct Dropout {
258 drop_p: f32,
259}
260
261impl Dropout {
262 pub fn new(drop_p: f32) -> Dropout {
263 Self { drop_p }
264 }
265
266 pub fn forward(&self, xs: &Tensor, train: bool) -> Result<Tensor> {
267 if train {
268 dropout(xs, self.drop_p)
269 } else {
270 Ok(xs.clone())
271 }
272 }
273}
274
275impl hanzo_ml::ModuleT for Dropout {
276 fn forward_t(&self, xs: &Tensor, train: bool) -> Result<Tensor> {
277 self.forward(xs, train)
278 }
279}
280
281struct SoftmaxLastDim;
282
283impl hanzo_ml::CustomOp1 for SoftmaxLastDim {
284 fn name(&self) -> &'static str {
285 "softmax-last-dim"
286 }
287
288 fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)> {
289 fn softmax<T: hanzo_ml::WithDType + num_traits::Float>(
290 src: &[T],
291 layout: &Layout,
292 ) -> Result<(CpuStorage, Shape)> {
293 let src = match layout.contiguous_offsets() {
294 None => hanzo_ml::bail!("input has to be contiguous"),
295 Some((o1, o2)) => &src[o1..o2],
296 };
297 let el_count = layout.shape().elem_count();
298 let dims = layout.shape().dims();
299 let dim_m1 = dims[dims.len() - 1];
300 let mut dst = vec![T::zero(); el_count];
301 src.par_chunks(dim_m1)
302 .zip(dst.par_chunks_mut(dim_m1))
303 .for_each(|(src, dst)| {
304 let mut max = T::neg_infinity();
305 unsafe { T::vec_reduce_max(src.as_ptr(), &mut max, dim_m1) };
306 for (s, d) in src.iter().zip(dst.iter_mut()) {
307 *d = (*s - max).exp();
308 }
309 let mut sum_exp = T::zero();
310 unsafe { T::vec_reduce_sum(dst.as_ptr(), &mut sum_exp, dim_m1) };
311 for d in dst.iter_mut() {
312 *d /= sum_exp
313 }
314 });
315 let storage = hanzo_ml::WithDType::to_cpu_storage_owned(dst);
316 Ok((storage, Shape::from_dims(dims)))
317 }
318
319 match storage {
320 CpuStorage::BF16(slice) => softmax::<half::bf16>(slice, layout),
321 CpuStorage::F16(slice) => softmax::<half::f16>(slice, layout),
322 CpuStorage::F32(slice) => softmax::<f32>(slice, layout),
323 CpuStorage::F64(slice) => softmax::<f64>(slice, layout),
324 _ => hanzo_ml::bail!("unsupported dtype for softmax {:?}", storage),
325 }
326 }
327
328 #[cfg(feature = "vulkan")]
329 fn vulkan_fwd(
330 &self,
331 storage: &hanzo_ml::VulkanStorage,
332 layout: &Layout,
333 ) -> Result<(hanzo_ml::VulkanStorage, Shape)> {
334 let out = storage.softmax_last_dim(layout)?;
335 Ok((out, layout.shape().clone()))
336 }
337
338 #[cfg(feature = "cuda")]
339 fn cuda_fwd(
340 &self,
341 storage: &hanzo_ml::CudaStorage,
342 layout: &Layout,
343 ) -> Result<(hanzo_ml::CudaStorage, Shape)> {
344 use hanzo_ml::cuda_backend::cudarc::driver::{
345 CudaSlice, DeviceRepr, LaunchConfig, PushKernelArg,
346 };
347 use hanzo_ml::cuda_backend::{kernel_name, kernels, Map1, WrapErr};
348 use hanzo_ml::{CudaDevice, WithDType};
349
350 struct S;
351 impl Map1 for S {
352 fn f<T: DeviceRepr + WithDType>(
353 &self,
354 src: &CudaSlice<T>,
355 dev: &CudaDevice,
356 layout: &Layout,
357 ) -> Result<CudaSlice<T>> {
358 let src = match layout.contiguous_offsets() {
359 None => hanzo_ml::bail!("input has to be contiguous"),
360 Some((o1, o2)) => src.slice(o1..o2),
361 };
362 let el = layout.shape().elem_count();
363 let dims = layout.shape().dims();
364 let dim_m1 = dims[dims.len() - 1];
365 let (n_rows, n_cols) = (el / dim_m1, dim_m1);
366
367 let cfg = LaunchConfig {
368 grid_dim: (n_rows as u32, 1, 1),
369 block_dim: (1, 32, 1),
370 shared_mem_bytes: 0,
371 };
372 let func = dev.get_or_load_func(&kernel_name::<T>("softmax"), &kernels::REDUCE)?;
373 let dst = unsafe { dev.alloc::<T>(el)? };
375 let mut builder = func.builder();
376 builder.arg(&src);
377 builder.arg(&dst);
378 hanzo_ml::builder_arg!(builder, n_cols as i32);
379 unsafe { builder.launch(cfg) }.w()?;
381 Ok(dst)
382 }
383 }
384
385 use hanzo_ml::backend::BackendStorage;
386 let dev = storage.device();
387 let slice = S.map(&storage.slice, dev, layout)?;
388 let dst = hanzo_ml::cuda_backend::CudaStorage {
389 slice,
390 device: dev.clone(),
391 };
392 Ok((dst, layout.shape().clone()))
393 }
394
395 #[cfg(feature = "metal")]
396 fn metal_fwd(
397 &self,
398 storage: &hanzo_ml::MetalStorage,
399 layout: &Layout,
400 ) -> Result<(hanzo_ml::MetalStorage, Shape)> {
401 use hanzo_ml::backend::BackendStorage;
402 let device = storage.device();
403 let encoder = device.command_encoder()?;
404 encoder.set_label("softmax");
405 let kernels = device.kernels();
406 let name = match storage.dtype() {
407 DType::F32 => "softmax_f32",
408 DType::F16 => "softmax_f16",
409 DType::BF16 => "softmax_bf16",
410 dtype => hanzo_ml::bail!("softmax-last-dim is not implemented for {dtype:?}"),
411 };
412
413 let n = layout.stride().len();
414 if !(layout.is_contiguous() && layout.stride()[n - 1] == 1) {
415 hanzo_ml::bail!("Non contiguous softmax-last-dim is not implemented");
416 }
417
418 let last_dim = layout.dims()[layout.shape().rank() - 1];
419 let elem_count = layout.shape().elem_count();
420 let output = device.new_buffer(elem_count, storage.dtype(), "softmax")?;
421 hanzo_metal_kernels::call_last_softmax(
422 device.metal_device(),
423 &encoder,
424 kernels,
425 name,
426 elem_count,
427 last_dim,
428 storage.buffer(),
429 layout.start_offset() * storage.dtype().size_in_bytes(),
430 &output,
431 )
432 .map_err(hanzo_ml::Error::wrap)?;
433 let newstorage =
434 hanzo_ml::MetalStorage::new(output, device.clone(), elem_count, storage.dtype());
435 Ok((newstorage, layout.shape().clone()))
436 }
437}
438
439pub fn softmax_last_dim(xs: &Tensor) -> Result<Tensor> {
440 if xs.device().is_rocm() {
441 return softmax(xs, D::Minus1);
442 }
443 xs.apply_op1_no_bwd(&SoftmaxLastDim)
444}
445
446#[derive(Debug, Clone)]
447struct RmsNorm {
448 eps: f32,
449}
450
451impl hanzo_ml::CustomOp2 for RmsNorm {
452 fn name(&self) -> &'static str {
453 "rms-norm"
454 }
455
456 fn cpu_fwd(
457 &self,
458 s1: &CpuStorage,
459 l1: &Layout,
460 s2: &CpuStorage,
461 l2: &Layout,
462 ) -> Result<(CpuStorage, Shape)> {
463 use hanzo_ml::backend::BackendStorage;
464
465 let eps = self.eps;
466 fn inner<
467 T: hanzo_ml::WithDType
468 + num_traits::Float
469 + num_traits::AsPrimitive<f32>
470 + num_traits::FromPrimitive,
471 >(
472 src: &[T],
473 layout: &Layout,
474 alpha: &[T],
475 alpha_layout: &Layout,
476 eps: f32,
477 ) -> Result<(CpuStorage, Shape)> {
478 let src = match layout.contiguous_offsets() {
479 None => hanzo_ml::bail!("input has to be contiguous"),
480 Some((o1, o2)) => &src[o1..o2],
481 };
482 let alpha = match alpha_layout.contiguous_offsets() {
483 None => hanzo_ml::bail!("alpha has to be contiguous"),
484 Some((o1, o2)) => &alpha[o1..o2],
485 };
486 let el_count = layout.shape().elem_count();
487 let dims = layout.shape().dims();
488 let dim_m1 = dims[dims.len() - 1];
489 let mut dst = vec![T::zero(); el_count];
490 src.par_chunks(dim_m1)
491 .zip(dst.par_chunks_mut(dim_m1))
492 .for_each(|(src, dst)| {
493 let sum2 = src
494 .iter()
495 .map(|&v| {
496 let v = v.as_();
497 v * v
498 })
499 .sum::<f32>();
500 let m = (sum2 / dim_m1 as f32 + eps).sqrt();
501 let m = T::from_f32(m).unwrap_or_else(T::nan);
502 for ((d, s), alpha) in dst.iter_mut().zip(src.iter()).zip(alpha) {
503 *d = *s / m * *alpha
504 }
505 });
506 let storage = hanzo_ml::WithDType::to_cpu_storage_owned(dst);
507 Ok((storage, Shape::from_dims(dims)))
508 }
509
510 use CpuStorage as C;
511 match (s1, s2) {
512 (C::BF16(s1), C::BF16(s2)) => inner::<half::bf16>(s1, l1, s2, l2, eps),
513 (C::F16(s1), C::F16(s2)) => inner::<half::f16>(s1, l1, s2, l2, eps),
514 (C::F32(s1), C::F32(s2)) => inner::<f32>(s1, l1, s2, l2, eps),
515 _ => hanzo_ml::bail!("unsupported dtype for rmsnorm {:?}", s1.dtype()),
516 }
517 }
518
519 #[cfg(feature = "vulkan")]
520 fn vulkan_fwd(
521 &self,
522 s1: &hanzo_ml::VulkanStorage,
523 l1: &Layout,
524 s2: &hanzo_ml::VulkanStorage,
525 l2: &Layout,
526 ) -> Result<(hanzo_ml::VulkanStorage, Shape)> {
527 let out = s1.rms_norm(l1, s2, l2, self.eps)?;
528 Ok((out, l1.shape().clone()))
529 }
530
531 #[cfg(feature = "cuda")]
532 fn cuda_fwd(
533 &self,
534 s1: &hanzo_ml::CudaStorage,
535 l1: &Layout,
536 s2: &hanzo_ml::CudaStorage,
537 l2: &Layout,
538 ) -> Result<(hanzo_ml::CudaStorage, Shape)> {
539 use hanzo_ml::cuda_backend::cudarc::driver::{
540 CudaSlice, DeviceRepr, LaunchConfig, PushKernelArg,
541 };
542 use hanzo_ml::cuda_backend::{kernel_name, kernels, Map2, WrapErr};
543 use hanzo_ml::{CudaDevice, WithDType};
544
545 struct S {
546 eps: f32,
547 }
548 impl Map2 for S {
549 fn f<T: DeviceRepr + WithDType>(
550 &self,
551 src: &CudaSlice<T>,
552 layout: &Layout,
553 alpha: &CudaSlice<T>,
554 alpha_layout: &Layout,
555 dev: &CudaDevice,
556 ) -> Result<CudaSlice<T>> {
557 let src = match layout.contiguous_offsets() {
558 None => hanzo_ml::bail!("input has to be contiguous"),
559 Some((o1, o2)) => src.slice(o1..o2),
560 };
561 let alpha = match alpha_layout.contiguous_offsets() {
562 None => hanzo_ml::bail!("alpha has to be contiguous"),
563 Some((o1, o2)) => alpha.slice(o1..o2),
564 };
565 let el = layout.shape().elem_count();
566 let dims = layout.shape().dims();
567 let dim_m1 = dims[dims.len() - 1];
568 let (n_rows, n_cols) = (el / dim_m1, dim_m1);
569
570 let block_size = if n_cols < 1024 { 32 } else { 1024 };
571 let cfg = LaunchConfig {
572 grid_dim: (n_rows as u32, 1, 1),
573 block_dim: (block_size, 1, 1),
574 shared_mem_bytes: 0,
575 };
576 let func = dev.get_or_load_func(&kernel_name::<T>("rmsnorm"), &kernels::REDUCE)?;
577 let dst = unsafe { dev.alloc::<T>(el)? };
579 let mut builder = func.builder();
580 builder.arg(&src);
581 builder.arg(&dst);
582 builder.arg(&alpha);
583 hanzo_ml::builder_arg!(builder, n_cols as i32, block_size as i32, self.eps);
584 unsafe { builder.launch(cfg) }.w()?;
586 Ok(dst)
587 }
588 }
589
590 use hanzo_ml::backend::BackendStorage;
591 let dev = s1.device();
592 let slice = S { eps: self.eps }.map(&s1.slice, l1, &s2.slice, l2, dev)?;
593 let dst = hanzo_ml::cuda_backend::CudaStorage {
594 slice,
595 device: dev.clone(),
596 };
597 Ok((dst, l1.shape().clone()))
598 }
599
600 #[cfg(feature = "metal")]
601 fn metal_fwd(
602 &self,
603 s1: &hanzo_ml::MetalStorage,
604 l1: &Layout,
605 s2: &hanzo_ml::MetalStorage,
606 l2: &Layout,
607 ) -> Result<(hanzo_ml::MetalStorage, Shape)> {
608 use hanzo_ml::backend::BackendStorage;
609 let device = s1.device();
610 let encoder = device.command_encoder()?;
611 encoder.set_label("rmsnorm");
612 let kernels = device.kernels();
613 let name = match (s1.dtype(), s2.dtype()) {
614 (DType::F32, DType::F32) => "rmsnorm_f32",
615 (DType::F16, DType::F16) => "rmsnorm_f16",
616 (DType::BF16, DType::BF16) => "rmsnorm_bf16",
617 (dt1, dt2) => hanzo_ml::bail!("rmsnorm is not implemented for {dt1:?} {dt2:?}"),
618 };
619
620 if !(l1.is_contiguous() && l2.is_contiguous()) {
621 hanzo_ml::bail!("Non contiguous rmsnorm is not implemented");
622 }
623
624 let last_dim = l1.dims()[l1.shape().rank() - 1];
625 let elem_count = l1.shape().elem_count();
626 let output = device.new_buffer(elem_count, s1.dtype(), "rmsnorm")?;
627 hanzo_metal_kernels::call_rms_norm(
628 device.metal_device(),
629 &encoder,
630 kernels,
631 name,
632 elem_count,
633 last_dim,
634 self.eps,
635 s1.buffer(),
636 l1.start_offset() * s1.dtype().size_in_bytes(),
637 s2.buffer(),
638 l2.start_offset() * s2.dtype().size_in_bytes(),
639 &output,
640 )
641 .map_err(hanzo_ml::Error::wrap)?;
642 let newstorage =
643 hanzo_ml::MetalStorage::new(output, device.clone(), elem_count, s1.dtype());
644 Ok((newstorage, l1.shape().clone()))
645 }
646}
647
648pub fn rms_norm_slow(x: &Tensor, alpha: &Tensor, eps: f32) -> Result<Tensor> {
649 let x_dtype = x.dtype();
650 let internal_dtype = match x_dtype {
651 DType::F16 | DType::BF16 => DType::F32,
652 d => d,
653 };
654 let hidden_size = x.dim(D::Minus1)?;
655 let x = x.to_dtype(internal_dtype)?;
656 let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
657 let x_normed = x.broadcast_div(&(norm_x + eps as f64)?.sqrt()?)?;
658 x_normed.to_dtype(x_dtype)?.broadcast_mul(alpha)
659}
660
661pub fn rms_norm(xs: &Tensor, alpha: &Tensor, eps: f32) -> Result<Tensor> {
662 let hidden_size_xs = xs.dim(D::Minus1)?;
663 let hidden_size_alpha = alpha.dims1()?;
664 if hidden_size_xs != hidden_size_alpha {
665 hanzo_ml::bail!(
666 "shape mismatch in rms-norm {:?} {:?}",
667 xs.shape(),
668 alpha.shape()
669 )
670 }
671 if xs.device().is_rocm() {
674 return rms_norm_slow(xs, alpha, eps);
675 }
676 xs.apply_op2_no_bwd(alpha, &RmsNorm { eps })
677}
678
679struct SiluMul;
681
682impl hanzo_ml::CustomOp2 for SiluMul {
683 fn name(&self) -> &'static str {
684 "silu-mul"
685 }
686
687 fn cpu_fwd(
688 &self,
689 s1: &CpuStorage,
690 l1: &Layout,
691 s2: &CpuStorage,
692 l2: &Layout,
693 ) -> Result<(CpuStorage, Shape)> {
694 fn inner<
695 T: hanzo_ml::WithDType
696 + num_traits::Float
697 + num_traits::AsPrimitive<f32>
698 + num_traits::FromPrimitive,
699 >(
700 a: &[T],
701 la: &Layout,
702 b: &[T],
703 lb: &Layout,
704 ) -> Result<(CpuStorage, Shape)> {
705 let a = match la.contiguous_offsets() {
706 Some((o1, o2)) => &a[o1..o2],
707 None => hanzo_ml::bail!("silu-mul: a must be contiguous"),
708 };
709 let b = match lb.contiguous_offsets() {
710 Some((o1, o2)) => &b[o1..o2],
711 None => hanzo_ml::bail!("silu-mul: b must be contiguous"),
712 };
713 let dst: Vec<T> = a
714 .iter()
715 .zip(b.iter())
716 .map(|(&x, &y)| {
717 let xf = x.as_();
718 T::from_f32(xf / (1.0 + (-xf).exp()) * y.as_()).unwrap_or_else(T::nan)
719 })
720 .collect();
721 Ok((
722 hanzo_ml::WithDType::to_cpu_storage_owned(dst),
723 Shape::from_dims(la.shape().dims()),
724 ))
725 }
726 use hanzo_ml::backend::BackendStorage;
727 use CpuStorage as C;
728 match (s1, s2) {
729 (C::BF16(a), C::BF16(b)) => inner::<half::bf16>(a, l1, b, l2),
730 (C::F16(a), C::F16(b)) => inner::<half::f16>(a, l1, b, l2),
731 (C::F32(a), C::F32(b)) => inner::<f32>(a, l1, b, l2),
732 _ => hanzo_ml::bail!("silu-mul: unsupported dtype {:?}", s1.dtype()),
733 }
734 }
735
736 #[cfg(feature = "vulkan")]
737 fn vulkan_fwd(
738 &self,
739 s1: &hanzo_ml::VulkanStorage,
740 l1: &Layout,
741 s2: &hanzo_ml::VulkanStorage,
742 l2: &Layout,
743 ) -> Result<(hanzo_ml::VulkanStorage, Shape)> {
744 let out = s1.silu_mul(l1, s2, l2)?;
745 Ok((out, l1.shape().clone()))
746 }
747}
748
749pub fn silu_mul(gate: &Tensor, up: &Tensor) -> Result<Tensor> {
751 if gate.device().is_cuda() || gate.device().is_metal() {
752 return silu(gate)?.mul(up);
754 }
755 gate.apply_op2_no_bwd(up, &SiluMul)
756}
757
758#[derive(Debug, Clone)]
759struct LayerNorm {
760 eps: f32,
761}
762
763impl hanzo_ml::CustomOp3 for LayerNorm {
764 fn name(&self) -> &'static str {
765 "layer-norm"
766 }
767
768 fn cpu_fwd(
769 &self,
770 s1: &CpuStorage,
771 l1: &Layout,
772 s2: &CpuStorage,
773 l2: &Layout,
774 s3: &CpuStorage,
775 l3: &Layout,
776 ) -> Result<(CpuStorage, Shape)> {
777 use hanzo_ml::backend::BackendStorage;
778
779 let eps = self.eps;
780 fn inner<
781 T: hanzo_ml::WithDType
782 + num_traits::Float
783 + num_traits::AsPrimitive<f32>
784 + num_traits::FromPrimitive,
785 >(
786 src: &[T],
787 layout: &Layout,
788 alpha: &[T],
789 alpha_layout: &Layout,
790 beta: &[T],
791 beta_layout: &Layout,
792 eps: f32,
793 ) -> Result<(CpuStorage, Shape)> {
794 let src = match layout.contiguous_offsets() {
795 None => hanzo_ml::bail!("input has to be contiguous"),
796 Some((o1, o2)) => &src[o1..o2],
797 };
798 let alpha = match alpha_layout.contiguous_offsets() {
799 None => hanzo_ml::bail!("alpha has to be contiguous"),
800 Some((o1, o2)) => &alpha[o1..o2],
801 };
802 let beta = match beta_layout.contiguous_offsets() {
803 None => hanzo_ml::bail!("beta has to be contiguous"),
804 Some((o1, o2)) => &beta[o1..o2],
805 };
806 let el_count = layout.shape().elem_count();
807 let dims = layout.shape().dims();
808 let dim_m1 = dims[dims.len() - 1];
809 let mut dst = vec![T::zero(); el_count];
810 src.par_chunks(dim_m1)
811 .zip(dst.par_chunks_mut(dim_m1))
812 .for_each(|(src, dst)| {
813 let mut sum = 0f32;
814 let mut sum2 = 0f32;
815 for v in src {
816 let v = v.as_();
817 sum += v;
818 sum2 += v * v;
819 }
820 let mean = sum / dim_m1 as f32;
821 let var = sum2 / dim_m1 as f32 - mean * mean;
822 let inv_std = (var + eps).sqrt().recip();
823 for ((d, s), (alpha, beta)) in
824 dst.iter_mut().zip(src.iter()).zip(alpha.iter().zip(beta))
825 {
826 let alpha = alpha.as_();
827 let beta = beta.as_();
828 let d_ = (s.as_() - mean) * inv_std * alpha + beta;
829 *d = T::from_f32(d_).unwrap_or_else(T::nan);
830 }
831 });
832 let storage = hanzo_ml::WithDType::to_cpu_storage_owned(dst);
833 Ok((storage, Shape::from_dims(dims)))
834 }
835
836 use CpuStorage as C;
837 match (s1, s2, s3) {
838 (C::BF16(s1), C::BF16(s2), C::BF16(s3)) => {
839 inner::<half::bf16>(s1, l1, s2, l2, s3, l3, eps)
840 }
841 (C::F16(s1), C::F16(s2), C::F16(s3)) => inner::<half::f16>(s1, l1, s2, l2, s3, l3, eps),
842 (C::F32(s1), C::F32(s2), C::F32(s3)) => inner::<f32>(s1, l1, s2, l2, s3, l3, eps),
843 _ => hanzo_ml::bail!("unsupported dtype for rmsnorm {:?}", s1.dtype()),
844 }
845 }
846
847 #[cfg(feature = "cuda")]
848 fn cuda_fwd(
849 &self,
850 s1: &hanzo_ml::CudaStorage,
851 l1: &Layout,
852 s2: &hanzo_ml::CudaStorage,
853 l2: &Layout,
854 s3: &hanzo_ml::CudaStorage,
855 l3: &Layout,
856 ) -> Result<(hanzo_ml::CudaStorage, Shape)> {
857 use hanzo_ml::cuda_backend::cudarc::driver::{
858 CudaSlice, DeviceRepr, LaunchConfig, PushKernelArg,
859 };
860 use hanzo_ml::cuda_backend::{kernel_name, kernels, Map3, WrapErr};
861 use hanzo_ml::{CudaDevice, WithDType};
862
863 struct S {
864 eps: f32,
865 }
866 impl Map3 for S {
867 fn f<T: DeviceRepr + WithDType>(
868 &self,
869 src: &CudaSlice<T>,
870 layout: &Layout,
871 alpha: &CudaSlice<T>,
872 alpha_layout: &Layout,
873 beta: &CudaSlice<T>,
874 beta_layout: &Layout,
875 dev: &CudaDevice,
876 ) -> Result<CudaSlice<T>> {
877 let src = match layout.contiguous_offsets() {
878 None => hanzo_ml::bail!("input has to be contiguous"),
879 Some((o1, o2)) => src.slice(o1..o2),
880 };
881 let alpha = match alpha_layout.contiguous_offsets() {
882 None => hanzo_ml::bail!("alpha has to be contiguous"),
883 Some((o1, o2)) => alpha.slice(o1..o2),
884 };
885 let beta = match beta_layout.contiguous_offsets() {
886 None => hanzo_ml::bail!("beta has to be contiguous"),
887 Some((o1, o2)) => beta.slice(o1..o2),
888 };
889 let el = layout.shape().elem_count();
890 let dims = layout.shape().dims();
891 let dim_m1 = dims[dims.len() - 1];
892 let (n_rows, n_cols) = (el / dim_m1, dim_m1);
893
894 let block_size = if n_cols < 1024 { 32 } else { 1024 };
895 let cfg = LaunchConfig {
896 grid_dim: (n_rows as u32, 1, 1),
897 block_dim: (block_size, 1, 1),
898 shared_mem_bytes: 0,
899 };
900 let func =
901 dev.get_or_load_func(&kernel_name::<T>("layernorm"), &kernels::REDUCE)?;
902 let dst = unsafe { dev.alloc::<T>(el)? };
904 let mut builder = func.builder();
905 builder.arg(&src);
906 builder.arg(&dst);
907 builder.arg(&alpha);
908 builder.arg(&beta);
909 hanzo_ml::builder_arg!(builder, n_cols as i32, block_size as i32, self.eps);
910 unsafe { builder.launch(cfg) }.w()?;
912 Ok(dst)
913 }
914 }
915
916 use hanzo_ml::backend::BackendStorage;
917 let dev = s1.device();
918 let slice = S { eps: self.eps }.map(&s1.slice, l1, &s2.slice, l2, &s3.slice, l3, dev)?;
919 let dst = hanzo_ml::cuda_backend::CudaStorage {
920 slice,
921 device: dev.clone(),
922 };
923 Ok((dst, l1.shape().clone()))
924 }
925
926 #[cfg(feature = "metal")]
927 fn metal_fwd(
928 &self,
929 s1: &hanzo_ml::MetalStorage,
930 l1: &Layout,
931 s2: &hanzo_ml::MetalStorage,
932 l2: &Layout,
933 s3: &hanzo_ml::MetalStorage,
934 l3: &Layout,
935 ) -> Result<(hanzo_ml::MetalStorage, Shape)> {
936 use hanzo_ml::backend::BackendStorage;
937 let device = s1.device();
938 let encoder = device.command_encoder()?;
939 encoder.set_label("layernorm");
940 let kernels = device.kernels();
941 let name = match (s1.dtype(), s2.dtype(), s3.dtype()) {
942 (DType::F32, DType::F32, DType::F32) => "layernorm_f32",
943 (DType::F16, DType::F16, DType::F16) => "layernorm_f16",
944 (DType::BF16, DType::BF16, DType::BF16) => "layernorm_bf16",
945 (dt1, dt2, dt3) => {
946 hanzo_ml::bail!("layernorm is not implemented for {dt1:?} {dt2:?} {dt3:?}")
947 }
948 };
949
950 if !(l1.is_contiguous() && l2.is_contiguous() && l3.is_contiguous()) {
951 hanzo_ml::bail!("Non contiguous layernorm is not implemented");
952 }
953
954 let last_dim = l1.dims()[l1.shape().rank() - 1];
955 let elem_count = l1.shape().elem_count();
956 let output = device.new_buffer(elem_count, s1.dtype(), "layernorm")?;
957 hanzo_metal_kernels::call_layer_norm(
958 device.metal_device(),
959 &encoder,
960 kernels,
961 name,
962 elem_count,
963 last_dim,
964 self.eps,
965 s1.buffer(),
966 l1.start_offset() * s1.dtype().size_in_bytes(),
967 s2.buffer(),
968 l2.start_offset() * s2.dtype().size_in_bytes(),
969 s3.buffer(),
970 l3.start_offset() * s3.dtype().size_in_bytes(),
971 &output,
972 )
973 .map_err(hanzo_ml::Error::wrap)?;
974 let newstorage =
975 hanzo_ml::MetalStorage::new(output, device.clone(), elem_count, s1.dtype());
976 Ok((newstorage, l1.shape().clone()))
977 }
978}
979
980pub fn layer_norm_slow(x: &Tensor, alpha: &Tensor, beta: &Tensor, eps: f32) -> Result<Tensor> {
981 let x_dtype = x.dtype();
982 let internal_dtype = match x_dtype {
983 DType::F16 | DType::BF16 => DType::F32,
984 d => d,
985 };
986 let hidden_size = x.dim(D::Minus1)?;
987 let x = x.to_dtype(internal_dtype)?;
988 let x = {
989 let mean_x = (x.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
990 x.broadcast_sub(&mean_x)?
991 };
992 let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
993 let x_normed = x.broadcast_div(&(norm_x + eps as f64)?.sqrt()?)?;
994 x_normed
995 .to_dtype(x_dtype)?
996 .broadcast_mul(alpha)?
997 .broadcast_add(beta)
998}
999
1000pub fn layer_norm(xs: &Tensor, alpha: &Tensor, beta: &Tensor, eps: f32) -> Result<Tensor> {
1001 let hidden_size_xs = xs.dim(D::Minus1)?;
1002 let hidden_size_alpha = alpha.dims1()?;
1003 let hidden_size_beta = beta.dims1()?;
1004 if hidden_size_xs != hidden_size_alpha || hidden_size_xs != hidden_size_beta {
1005 hanzo_ml::bail!(
1006 "shape mismatch in layer-norm src: {:?} alpha: {:?} beta: {:?}",
1007 xs.shape(),
1008 alpha.shape(),
1009 beta.shape()
1010 )
1011 }
1012 if xs.device().is_rocm() {
1013 return layer_norm_slow(xs, alpha, beta, eps);
1014 }
1015 xs.apply_op3_no_bwd(alpha, beta, &LayerNorm { eps })
1016}
1017
1018pub fn pixel_shuffle(xs: &Tensor, upscale_factor: usize) -> Result<Tensor> {
1020 let (b_size, c, h, w) = xs.dims4()?;
1021 let out_c = c / upscale_factor / upscale_factor;
1022 xs.reshape((b_size, out_c, upscale_factor, upscale_factor, h, w))?
1023 .permute((0, 1, 4, 2, 5, 3))?
1024 .reshape((b_size, out_c, h * upscale_factor, w * upscale_factor))
1025}
1026
1027pub fn pixel_unshuffle(xs: &Tensor, downscale_factor: usize) -> Result<Tensor> {
1028 let (b_size, c, h, w) = xs.dims4()?;
1029 let out_c = c * downscale_factor * downscale_factor;
1030 xs.reshape((
1031 b_size,
1032 c,
1033 h / downscale_factor,
1034 downscale_factor,
1035 w / downscale_factor,
1036 downscale_factor,
1037 ))?
1038 .permute((0, 1, 3, 5, 2, 4))?
1039 .reshape((b_size, out_c, h / downscale_factor, w / downscale_factor))
1040}
1041
1042pub fn replication_pad2d(xs: &Tensor, pad: usize) -> Result<Tensor> {
1044 match pad {
1045 0 => Ok(xs.clone()),
1046 1 => {
1047 let (_b_size, _c, h, w) = xs.dims4()?;
1048 let (first, last) = (xs.narrow(3, 0, 1)?, xs.narrow(3, w - 1, 1)?);
1049 let xs = Tensor::cat(&[&first, xs, &last], 3)?;
1050 let (first, last) = (xs.narrow(2, 0, 1)?, xs.narrow(2, h - 1, 1)?);
1051 Tensor::cat(&[&first, &xs, &last], 2)
1052 }
1053 n => hanzo_ml::bail!("replication-pad with a size of {n} is not supported"),
1054 }
1055}
1056
1057#[derive(Clone, Debug)]
1058pub struct Identity;
1059
1060impl Identity {
1061 pub fn new() -> Identity {
1062 Self
1063 }
1064}
1065
1066impl Default for Identity {
1067 fn default() -> Self {
1068 Self
1069 }
1070}
1071
1072impl Module for Identity {
1073 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
1074 Ok(xs.clone())
1075 }
1076}
1077
/// Parameters for the scaled dot-product attention custom op (named `metal-sdpa`).
// NOTE(review): `dead_code` is presumably allowed because some fields are only
// read from feature-gated (Metal) code paths; confirm before removing.
#[allow(dead_code)]
struct Sdpa {
    // Presumably the multiplier applied to attention scores; TODO confirm with callers.
    scale: f32,
    // Presumably a logit soft-capping value; TODO confirm semantics and "disabled" value.
    softcapping: f32,
    // Optional attention mask; its presence affects which Metal kernel variant
    // is considered supported.
    mask: Option<Tensor>,
    // Presumably requests causal masking; TODO confirm.
    do_causal: bool,
}
1085
1086impl hanzo_ml::CustomOp3 for Sdpa {
1087 fn name(&self) -> &'static str {
1088 "metal-sdpa"
1089 }
1090
1091 fn cpu_fwd(
1092 &self,
1093 _s1: &CpuStorage,
1094 _l1: &Layout,
1095 _s2: &CpuStorage,
1096 _l2: &Layout,
1097 _s3: &CpuStorage,
1098 _l3: &Layout,
1099 ) -> Result<(CpuStorage, Shape)> {
1100 hanzo_ml::bail!("SDPA has no cpu impl")
1101 }
1102
1103 #[cfg(feature = "metal")]
1104 fn metal_fwd(
1105 &self,
1106 q: &hanzo_ml::MetalStorage,
1107 q_l: &Layout,
1108 k: &hanzo_ml::MetalStorage,
1109 k_l: &Layout,
1110 v: &hanzo_ml::MetalStorage,
1111 v_l: &Layout,
1112 ) -> Result<(hanzo_ml::MetalStorage, Shape)> {
1113 use hanzo_metal_kernels::SdpaDType;
1114 use hanzo_ml::backend::BackendStorage;
1115
1116 let device = q.device();
1117
1118 let out_dims = vec![q_l.dim(0)?, q_l.dim(1)?, q_l.dim(2)?, v_l.dim(3)?];
1119 let elem_count: usize = out_dims.iter().product();
1120 let out_shape = Shape::from_dims(&out_dims);
1121 let out_layout = Layout::contiguous(out_shape.clone());
1122
1123 let output = device.new_buffer(elem_count, q.dtype(), "sdpa_o")?;
1124
1125 if q_l.dim(D::Minus1)? != k_l.dim(D::Minus1)? {
1127 hanzo_ml::bail!("`q` and `k` last dims must match");
1128 }
1129
1130 if v_l.dim(D::Minus(3))? != k_l.dim(D::Minus(3))? {
1132 hanzo_ml::bail!("`k` and `v` head dims must match");
1133 }
1134
1135 if q_l.dim(D::Minus(3))? % k_l.dim(D::Minus(3))? != 0 {
1137 hanzo_ml::bail!("query `n_heads` must be a multiple of `n_kv_heads`");
1138 }
1139
1140 let k_head = k_l.dim(D::Minus1)?;
1141 let q_head = q_l.dim(D::Minus1)?;
1142 let q_seq = q_l.dim(2)?;
1143 let k_seq = k_l.dim(2)?;
1144
1145 let mut implementation_supports_use_case = q_head == k_head;
1146 let supported_head_dim = q_head == 32
1147 || q_head == 64
1148 || q_head == 72
1149 || q_head == 80
1150 || q_head == 96
1151 || q_head == 128
1152 || q_head == 256
1153 || q_head == 512;
1154
1155 let supports_sdpa_full_mask = self.mask.is_none() || q_seq <= k_seq;
1156 let supports_sdpa_full_dtype = !(q_head == 512 && q.dtype() == DType::F32);
1158 let supports_sdpa_full =
1159 q_seq > 8 && supported_head_dim && supports_sdpa_full_mask && supports_sdpa_full_dtype;
1160 let supports_sdpa_vector = q_seq <= 8 && supported_head_dim && q_seq <= k_seq;
1161
1162 implementation_supports_use_case &= supports_sdpa_full || supports_sdpa_vector;
1163
1164 if !supported_head_dim {
1165 hanzo_ml::bail!(
1166 "Meta SDPA does not support q head dim {q_head}: q dims {:?}, k dims {:?}, v dims {:?}.",
1167 q_l.dims(),
1168 k_l.dims(),
1169 v_l.dims()
1170 );
1171 }
1172 if !implementation_supports_use_case {
1173 hanzo_ml::bail!(
1174 "Meta SDPA does not support q dims {:?}, k dims {:?}, v dims {:?}.",
1175 q_l.dims(),
1176 k_l.dims(),
1177 v_l.dims()
1178 );
1179 }
1180
1181 for t in [k.dtype(), v.dtype()] {
1182 if q.dtype() != t {
1183 hanzo_ml::bail!("all q, k, v dtypes must match.");
1184 }
1185 }
1186
1187 let itype = match q.dtype() {
1188 DType::BF16 => SdpaDType::BF16,
1189 DType::F16 => SdpaDType::F16,
1190 DType::F32 => SdpaDType::F32,
1191 other => hanzo_ml::bail!("unsupported sdpa type {other:?}"),
1192 };
1193
1194 let encoder = q.device().command_encoder()?;
1195 if supports_sdpa_vector {
1196 const TWO_PASS_K_THRESHOLD: usize = 1024;
1199 if k_seq >= TWO_PASS_K_THRESHOLD {
1200 let mut intermediate_shape = [
1201 &out_dims[0..out_dims.len() - 2],
1202 &[hanzo_metal_kernels::SDPA_2PASS_BLOCKS],
1203 &[out_dims[out_dims.len() - 1]],
1204 ]
1205 .concat();
1206 let intermediate = device.new_buffer(
1207 intermediate_shape.iter().product::<usize>(),
1208 DType::F32,
1209 "sdpa_2pass_intermediate",
1210 )?;
1211 let _ = intermediate_shape.pop().unwrap();
1212 let sums = device.new_buffer(
1213 intermediate_shape.iter().product::<usize>(),
1214 DType::F32,
1215 "sdpa_2pass_sums",
1216 )?;
1217 let maxs = device.new_buffer(
1218 intermediate_shape.iter().product::<usize>(),
1219 DType::F32,
1220 "sdpa_2pass_maxs",
1221 )?;
1222
1223 encoder.set_label("vector_attention");
1224 hanzo_metal_kernels::call_sdpa_vector_2pass(
1225 q.device().device(),
1226 &encoder,
1227 q.device().kernels(),
1228 q_l.start_offset(),
1229 q_l.dims(),
1230 q.buffer(),
1231 k_l.start_offset(),
1232 k_l.dims(),
1233 k_l.stride(),
1234 k.buffer(),
1235 v_l.start_offset(),
1236 v_l.stride(),
1237 v.buffer(),
1238 &output,
1239 &intermediate,
1240 &sums,
1241 &maxs,
1242 self.scale,
1243 self.softcapping,
1244 itype,
1245 )
1246 .map_err(hanzo_ml::Error::wrap)?;
1247 } else {
1248 encoder.set_label("vector_attention");
1249 hanzo_metal_kernels::call_sdpa_vector(
1250 q.device().device(),
1251 &encoder,
1252 q.device().kernels(),
1253 q_l.start_offset(),
1254 q_l.dims(),
1255 q.buffer(),
1256 k_l.start_offset(),
1257 k_l.dims(),
1258 k_l.stride(),
1259 k.buffer(),
1260 v_l.start_offset(),
1261 v_l.stride(),
1262 v.buffer(),
1263 &output,
1264 self.scale,
1265 self.softcapping,
1266 itype,
1267 )
1268 .map_err(hanzo_ml::Error::wrap)?;
1269 }
1270 } else if supports_sdpa_full {
1271 encoder.set_label("full_attention");
1272 if self.softcapping != 1. {
1273 hanzo_ml::bail!("SDPA full requires softcapping to be disabled (1.0)");
1274 }
1275
1276 let mask_s_l = self.mask.as_ref().map(|m| m.storage_and_layout());
1277
1278 let (mask_type, mask_buffer, mask_strides) = if let Some(mask) = &self.mask {
1279 let (mask_s, mask_l) = mask_s_l.as_ref().unwrap();
1280
1281 let mask_buffer = match &**mask_s {
1282 hanzo_ml::Storage::Metal(m) => m.buffer(),
1283 _ => hanzo_ml::bail!("Expected metal device for mask"),
1284 };
1285
1286 let mask_type = match mask.dtype() {
1287 DType::BF16 => SdpaDType::BF16,
1288 DType::F16 => SdpaDType::F16,
1289 DType::F32 => SdpaDType::F32,
1290 other => hanzo_ml::bail!("unsupported sdpa type {other:?}"),
1291 };
1292 if mask_type != itype {
1293 hanzo_ml::bail!("Mask type {mask_type:?} must match q type {itype:?}");
1294 }
1295
1296 if mask_l.dims() != [q_l.dim(0)?, q_l.dim(1)?, q_l.dim(2)?, k_seq] {
1297 hanzo_ml::bail!(
1298 "Mask shape must be {:?} (bs, qheads, qseq, kseq), got {:?}",
1299 [q_l.dim(0)?, q_head, q_l.dim(2)?, k_seq],
1300 mask_l.dims()
1301 );
1302 }
1303
1304 (
1305 Some(mask_type),
1306 Some(mask_buffer),
1307 Some(mask_l.stride().to_vec()),
1308 )
1309 } else {
1310 (None, None, None)
1311 };
1312
1313 hanzo_metal_kernels::call_sdpa_full(
1314 q.device().device(),
1315 &encoder,
1316 q.device().kernels(),
1317 q_l.start_offset(),
1318 q_l.dims(),
1319 q_l.stride(),
1320 q.buffer(),
1321 k_l.start_offset(),
1322 k_l.dims(),
1323 k_l.stride(),
1324 k.buffer(),
1325 v_l.start_offset(),
1326 v.buffer(),
1327 v_l.stride(),
1328 mask_type,
1329 mask_buffer,
1330 mask_strides.as_deref(),
1331 &output,
1332 out_layout.stride(),
1333 self.scale,
1334 self.do_causal,
1335 itype,
1336 )
1337 .map_err(hanzo_ml::Error::wrap)?;
1338 } else {
1339 hanzo_ml::bail!("must be vector or full sdpa kernel");
1340 }
1341
1342 let newstorage = hanzo_ml::MetalStorage::new(output, device.clone(), elem_count, q.dtype());
1343 Ok((newstorage, out_shape))
1344 }
1345}
1346
1347pub fn sdpa(
1376 q: &Tensor,
1377 k: &Tensor,
1378 v: &Tensor,
1379 mask: Option<&Tensor>,
1380 do_causal: bool,
1381 scale: f32,
1382 softcapping: f32,
1383) -> Result<Tensor> {
1384 q.apply_op3_no_bwd(
1385 k,
1386 v,
1387 &Sdpa {
1388 scale,
1389 softcapping,
1390 mask: mask.cloned(),
1391 do_causal,
1392 },
1393 )
1394}