1use crate::op::{BackpropOp, Op};
2use crate::tensor::from_storage;
3#[cfg(feature = "rocm")]
4use crate::RocmStorage;
5#[cfg(feature = "vulkan")]
6use crate::VulkanStorage;
7use crate::{CpuStorage, CudaStorage, Layout, MetalStorage, Result, Shape, Tensor};
8use std::sync::Arc;
9
/// Prints a diagnostic line to stderr when a custom op falls through to a default
/// `vulkan_fwd` body.
///
/// The line is only emitted when the `HANZO_VK_PROFILE` environment variable is set
/// to something other than `"0"`; an unset or unreadable variable silences it.
///
/// * `name` - name of the custom op that bailed.
/// * `l` - layout whose shape dimensions are included in the message.
#[cfg(feature = "vulkan")]
fn log_vulkan_custom_op_bail(name: &str, l: &Layout) {
    let profiling = match std::env::var("HANZO_VK_PROFILE") {
        Ok(value) => value != "0",
        Err(_) => false,
    };
    if !profiling {
        return;
    }
    eprintln!(
        "[HANZO_VK_PROFILE] custom-op bail op={name} shape={:?} (no vulkan_fwd; would round-trip/err)",
        l.shape().dims()
    );
}
25
26pub trait CustomOp1 {
28 fn name(&self) -> &'static str;
30
31 fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)>;
34
35 fn cuda_fwd(&self, _storage: &CudaStorage, _layout: &Layout) -> Result<(CudaStorage, Shape)> {
38 Err(crate::Error::Cuda(
39 format!("no cuda implementation for {}", self.name()).into(),
40 ))
41 }
42
43 #[cfg(feature = "rocm")]
44 fn rocm_fwd(&self, _storage: &RocmStorage, _layout: &Layout) -> Result<(RocmStorage, Shape)> {
45 Err(crate::Error::Msg(format!(
46 "no rocm implementation for {}",
47 self.name()
48 )))
49 }
50 #[cfg(feature = "vulkan")]
51 fn vulkan_fwd(
52 &self,
53 _storage: &VulkanStorage,
54 _layout: &Layout,
55 ) -> Result<(VulkanStorage, Shape)> {
56 log_vulkan_custom_op_bail(self.name(), _layout);
57 Err(crate::Error::Msg(format!(
58 "no vulkan implementation for {}",
59 self.name()
60 )))
61 }
62
63 fn metal_fwd(
66 &self,
67 _storage: &MetalStorage,
68 _layout: &Layout,
69 ) -> Result<(MetalStorage, Shape)> {
70 Err(crate::Error::Metal(
71 format!("no metal implementation for {}", self.name()).into(),
72 ))
73 }
74
75 fn bwd(&self, _arg: &Tensor, _res: &Tensor, _grad_res: &Tensor) -> Result<Option<Tensor>> {
79 Err(crate::Error::BackwardNotSupported { op: self.name() })
80 }
81}
82
83pub trait CustomOp2 {
84 fn name(&self) -> &'static str;
85
86 fn cpu_fwd(
89 &self,
90 s1: &CpuStorage,
91 l1: &Layout,
92 s2: &CpuStorage,
93 l2: &Layout,
94 ) -> Result<(CpuStorage, Shape)>;
95
96 fn cuda_fwd(
99 &self,
100 _: &CudaStorage,
101 _: &Layout,
102 _: &CudaStorage,
103 _: &Layout,
104 ) -> Result<(CudaStorage, Shape)> {
105 Err(crate::Error::Cuda(
106 format!("no cuda implementation for {}", self.name()).into(),
107 ))
108 }
109
110 #[cfg(feature = "rocm")]
111 fn rocm_fwd(
112 &self,
113 _: &RocmStorage,
114 _: &Layout,
115 _: &RocmStorage,
116 _: &Layout,
117 ) -> Result<(RocmStorage, Shape)> {
118 Err(crate::Error::Msg(format!(
119 "no rocm implementation for {}",
120 self.name()
121 )))
122 }
123 #[cfg(feature = "vulkan")]
124 fn vulkan_fwd(
125 &self,
126 _: &VulkanStorage,
127 l1: &Layout,
128 _: &VulkanStorage,
129 _: &Layout,
130 ) -> Result<(VulkanStorage, Shape)> {
131 log_vulkan_custom_op_bail(self.name(), l1);
132 Err(crate::Error::Msg(format!(
133 "no vulkan implementation for {}",
134 self.name()
135 )))
136 }
137
138 fn metal_fwd(
141 &self,
142 _: &MetalStorage,
143 _: &Layout,
144 _: &MetalStorage,
145 _: &Layout,
146 ) -> Result<(MetalStorage, Shape)> {
147 Err(crate::Error::Metal(
148 format!("no metal implementation for {}", self.name()).into(),
149 ))
150 }
151
152 fn bwd(
153 &self,
154 _arg1: &Tensor,
155 _arg2: &Tensor,
156 _res: &Tensor,
157 _grad_res: &Tensor,
158 ) -> Result<(Option<Tensor>, Option<Tensor>)> {
159 Err(crate::Error::BackwardNotSupported { op: self.name() })
160 }
161}
162
163pub trait CustomOp3 {
164 fn name(&self) -> &'static str;
165
166 fn cpu_fwd(
169 &self,
170 s1: &CpuStorage,
171 l1: &Layout,
172 s2: &CpuStorage,
173 l2: &Layout,
174 s3: &CpuStorage,
175 l3: &Layout,
176 ) -> Result<(CpuStorage, Shape)>;
177
178 fn cuda_fwd(
181 &self,
182 _: &CudaStorage,
183 _: &Layout,
184 _: &CudaStorage,
185 _: &Layout,
186 _: &CudaStorage,
187 _: &Layout,
188 ) -> Result<(CudaStorage, Shape)> {
189 Err(crate::Error::Cuda(
190 format!("no cuda implementation for {}", self.name()).into(),
191 ))
192 }
193
194 #[cfg(feature = "rocm")]
195 fn rocm_fwd(
196 &self,
197 _: &RocmStorage,
198 _: &Layout,
199 _: &RocmStorage,
200 _: &Layout,
201 _: &RocmStorage,
202 _: &Layout,
203 ) -> Result<(RocmStorage, Shape)> {
204 Err(crate::Error::Msg(format!(
205 "no rocm implementation for {}",
206 self.name()
207 )))
208 }
209 #[cfg(feature = "vulkan")]
210 fn vulkan_fwd(
211 &self,
212 _: &VulkanStorage,
213 l1: &Layout,
214 _: &VulkanStorage,
215 _: &Layout,
216 _: &VulkanStorage,
217 _: &Layout,
218 ) -> Result<(VulkanStorage, Shape)> {
219 log_vulkan_custom_op_bail(self.name(), l1);
220 Err(crate::Error::Msg(format!(
221 "no vulkan implementation for {}",
222 self.name()
223 )))
224 }
225
226 fn metal_fwd(
229 &self,
230 _: &MetalStorage,
231 _: &Layout,
232 _: &MetalStorage,
233 _: &Layout,
234 _: &MetalStorage,
235 _: &Layout,
236 ) -> Result<(MetalStorage, Shape)> {
237 Err(crate::Error::Metal(
238 format!("no metal implementation for {}", self.name()).into(),
239 ))
240 }
241
242 fn bwd(
243 &self,
244 _arg1: &Tensor,
245 _arg2: &Tensor,
246 _arg3: &Tensor,
247 _res: &Tensor,
248 _grad_res: &Tensor,
249 ) -> Result<(Option<Tensor>, Option<Tensor>, Option<Tensor>)> {
250 Err(crate::Error::BackwardNotSupported { op: self.name() })
251 }
252}
253
254impl Tensor {
255 pub fn apply_op1_no_bwd<C: CustomOp1>(&self, c: &C) -> Result<Self> {
257 let (storage, shape) = self.storage().apply_op1(self.layout(), c)?;
258 Ok(from_storage(storage, shape, BackpropOp::none(), false))
259 }
260
261 pub fn apply_op2_no_bwd<C: CustomOp2>(&self, rhs: &Self, c: &C) -> Result<Self> {
263 let (storage, shape) =
264 self.storage()
265 .apply_op2(self.layout(), &rhs.storage(), rhs.layout(), c)?;
266 Ok(from_storage(storage, shape, BackpropOp::none(), false))
267 }
268
269 pub fn apply_op3_no_bwd<C: CustomOp3>(&self, t2: &Self, t3: &Self, c: &C) -> Result<Self> {
271 let (storage, shape) = self.storage().apply_op3(
272 self.layout(),
273 &t2.storage(),
274 t2.layout(),
275 &t3.storage(),
276 t3.layout(),
277 c,
278 )?;
279 Ok(from_storage(storage, shape, BackpropOp::none(), false))
280 }
281
282 pub fn apply_op1_arc(&self, c: Arc<Box<dyn CustomOp1 + Send + Sync>>) -> Result<Self> {
284 let (storage, shape) = self
285 .storage()
286 .apply_op1(self.layout(), c.as_ref().as_ref())?;
287 let op = BackpropOp::new1(self, |s| Op::CustomOp1(s, c.clone()));
288 Ok(from_storage(storage, shape, op, false))
289 }
290
291 pub fn apply_op1<C: 'static + CustomOp1 + Send + Sync>(&self, c: C) -> Result<Self> {
292 self.apply_op1_arc(Arc::new(Box::new(c)))
293 }
294
295 pub fn apply_op2_arc(
297 &self,
298 rhs: &Self,
299 c: Arc<Box<dyn CustomOp2 + Send + Sync>>,
300 ) -> Result<Self> {
301 let (storage, shape) = self.storage().apply_op2(
302 self.layout(),
303 &rhs.storage(),
304 rhs.layout(),
305 c.as_ref().as_ref(),
306 )?;
307 let op = BackpropOp::new2(self, rhs, |t1, t2| Op::CustomOp2(t1, t2, c.clone()));
308 Ok(from_storage(storage, shape, op, false))
309 }
310
311 pub fn apply_op2<C: 'static + CustomOp2 + Send + Sync>(&self, r: &Self, c: C) -> Result<Self> {
312 self.apply_op2_arc(r, Arc::new(Box::new(c)))
313 }
314
315 pub fn apply_op3_arc(
317 &self,
318 t2: &Self,
319 t3: &Self,
320 c: Arc<Box<dyn CustomOp3 + Send + Sync>>,
321 ) -> Result<Self> {
322 let (storage, shape) = self.storage().apply_op3(
323 self.layout(),
324 &t2.storage(),
325 t2.layout(),
326 &t3.storage(),
327 t3.layout(),
328 c.as_ref().as_ref(),
329 )?;
330 let op = BackpropOp::new3(self, t2, t3, |t1, t2, t3| {
331 Op::CustomOp3(t1, t2, t3, c.clone())
332 });
333 Ok(from_storage(storage, shape, op, false))
334 }
335
336 pub fn apply_op3<C: 'static + CustomOp3 + Send + Sync>(
337 &self,
338 t2: &Self,
339 t3: &Self,
340 c: C,
341 ) -> Result<Self> {
342 self.apply_op3_arc(t2, t3, Arc::new(Box::new(c)))
343 }
344}
345
346pub trait InplaceOp1 {
351 fn name(&self) -> &'static str;
353
354 fn cpu_fwd(&self, storage: &mut CpuStorage, layout: &Layout) -> Result<()>;
357
358 fn cuda_fwd(&self, _storage: &mut CudaStorage, _layout: &Layout) -> Result<()> {
361 Err(crate::Error::Cuda(
362 format!("no cuda implementation for {}", self.name()).into(),
363 ))
364 }
365
366 #[cfg(feature = "rocm")]
367 fn rocm_fwd(&self, _storage: &mut RocmStorage, _layout: &Layout) -> Result<()> {
368 Err(crate::Error::Msg(format!(
369 "no rocm implementation for {}",
370 self.name()
371 )))
372 }
373 #[cfg(feature = "vulkan")]
374 fn vulkan_fwd(&self, _storage: &mut VulkanStorage, _layout: &Layout) -> Result<()> {
375 log_vulkan_custom_op_bail(self.name(), _layout);
376 Err(crate::Error::Msg(format!(
377 "no vulkan implementation for {}",
378 self.name()
379 )))
380 }
381
382 fn metal_fwd(&self, _storage: &mut MetalStorage, _layout: &Layout) -> Result<()> {
385 Err(crate::Error::Metal(
386 format!("no metal implementation for {}", self.name()).into(),
387 ))
388 }
389}
390
391pub trait InplaceOp2 {
392 fn name(&self) -> &'static str;
393
394 fn cpu_fwd(&self, s1: &mut CpuStorage, l1: &Layout, s2: &CpuStorage, l2: &Layout)
397 -> Result<()>;
398
399 fn cuda_fwd(&self, _: &mut CudaStorage, _: &Layout, _: &CudaStorage, _: &Layout) -> Result<()> {
402 Err(crate::Error::Cuda(
403 format!("no cuda implementation for {}", self.name()).into(),
404 ))
405 }
406
407 #[cfg(feature = "rocm")]
408 fn rocm_fwd(&self, _: &mut RocmStorage, _: &Layout, _: &RocmStorage, _: &Layout) -> Result<()> {
409 Err(crate::Error::Msg(format!(
410 "no rocm implementation for {}",
411 self.name()
412 )))
413 }
414 #[cfg(feature = "vulkan")]
415 fn vulkan_fwd(
416 &self,
417 _: &mut VulkanStorage,
418 l1: &Layout,
419 _: &VulkanStorage,
420 _: &Layout,
421 ) -> Result<()> {
422 log_vulkan_custom_op_bail(self.name(), l1);
423 Err(crate::Error::Msg(format!(
424 "no vulkan implementation for {}",
425 self.name()
426 )))
427 }
428
429 fn metal_fwd(
432 &self,
433 _: &mut MetalStorage,
434 _: &Layout,
435 _: &MetalStorage,
436 _: &Layout,
437 ) -> Result<()> {
438 Err(crate::Error::Metal(
439 format!("no metal implementation for {}", self.name()).into(),
440 ))
441 }
442}
443
444pub trait InplaceOp3 {
445 fn name(&self) -> &'static str;
446
447 fn cpu_fwd(
450 &self,
451 s1: &mut CpuStorage,
452 l1: &Layout,
453 s2: &CpuStorage,
454 l2: &Layout,
455 s3: &CpuStorage,
456 l3: &Layout,
457 ) -> Result<()>;
458
459 fn cuda_fwd(
462 &self,
463 _: &mut CudaStorage,
464 _: &Layout,
465 _: &CudaStorage,
466 _: &Layout,
467 _: &CudaStorage,
468 _: &Layout,
469 ) -> Result<()> {
470 Err(crate::Error::Cuda(
471 format!("no cuda implementation for {}", self.name()).into(),
472 ))
473 }
474
475 #[cfg(feature = "rocm")]
476 fn rocm_fwd(
477 &self,
478 _: &mut RocmStorage,
479 _: &Layout,
480 _: &RocmStorage,
481 _: &Layout,
482 _: &RocmStorage,
483 _: &Layout,
484 ) -> Result<()> {
485 Err(crate::Error::Msg(format!(
486 "no rocm implementation for {}",
487 self.name()
488 )))
489 }
490 #[cfg(feature = "vulkan")]
491 fn vulkan_fwd(
492 &self,
493 _: &mut VulkanStorage,
494 l1: &Layout,
495 _: &VulkanStorage,
496 _: &Layout,
497 _: &VulkanStorage,
498 _: &Layout,
499 ) -> Result<()> {
500 log_vulkan_custom_op_bail(self.name(), l1);
501 Err(crate::Error::Msg(format!(
502 "no vulkan implementation for {}",
503 self.name()
504 )))
505 }
506
507 fn metal_fwd(
510 &self,
511 _: &mut MetalStorage,
512 _: &Layout,
513 _: &MetalStorage,
514 _: &Layout,
515 _: &MetalStorage,
516 _: &Layout,
517 ) -> Result<()> {
518 Err(crate::Error::Metal(
519 format!("no metal implementation for {}", self.name()).into(),
520 ))
521 }
522}
523
524impl Tensor {
525 pub fn inplace_op1<C: InplaceOp1>(&self, c: &C) -> Result<()> {
527 self.storage_mut().inplace_op1(self.layout(), c)
528 }
529
530 pub fn inplace_op2<C: InplaceOp2>(&self, rhs: &Self, c: &C) -> Result<()> {
532 self.storage_mut()
533 .inplace_op2(self.layout(), &rhs.storage(), rhs.layout(), c)
534 }
535
536 pub fn inplace_op3<C: InplaceOp3>(&self, t2: &Self, t3: &Self, c: &C) -> Result<()> {
538 self.storage_mut().inplace_op3(
539 self.layout(),
540 &t2.storage(),
541 t2.layout(),
542 &t3.storage(),
543 t3.layout(),
544 c,
545 )
546 }
547}
548
/// In-place unary op whose kernel is compiled from a `hanzo_ug` SSA kernel.
///
/// The `func` field only exists when the `cuda` or `metal` feature is enabled and
/// holds the compiled handle for that backend.
#[cfg(feature = "ug")]
pub struct UgIOp1 {
    /// Name handed to the compiler and returned by `InplaceOp1::name`.
    name: &'static str,
    /// Compiled CUDA function.
    #[cfg(feature = "cuda")]
    func: cudarc::driver::CudaFunction,
    /// Compiled Metal compute pipeline.
    #[cfg(feature = "metal")]
    func: hanzo_metal_kernels::metal::ComputePipeline,
}

#[cfg(feature = "ug")]
impl UgIOp1 {
    /// Compiles `kernel` under `name` for the backend behind `device`.
    ///
    /// With the `cuda` feature the device must be a CUDA device, with the `metal`
    /// feature a Metal device; the `as_*_device` conversion or the compilation
    /// error is propagated otherwise. Without either feature the op is created
    /// without compiling anything.
    // `kernel` and `device` are unused when neither backend feature is enabled.
    #[allow(unused)]
    #[cfg(not(any(target_arch = "wasm32", target_os = "ios")))]
    pub fn new(
        name: &'static str,
        kernel: hanzo_ug::lang::ssa::Kernel,
        device: &crate::Device,
    ) -> Result<Self> {
        #[cfg(feature = "cuda")]
        {
            let func = device
                .as_cuda_device()?
                .compile(name, kernel)?
                .into_cuda_function();
            Ok(Self { name, func })
        }
        #[cfg(feature = "metal")]
        {
            let func = device.as_metal_device()?.compile(name, kernel)?;
            Ok(Self { name, func })
        }
        #[cfg(not(any(feature = "cuda", feature = "metal")))]
        {
            Ok(Self { name })
        }
    }
}
588
#[cfg(feature = "ug")]
impl InplaceOp1 for UgIOp1 {
    /// Returns the name the op was created with.
    fn name(&self) -> &'static str {
        self.name
    }

    /// Always fails: ug kernels have no CPU implementation.
    fn cpu_fwd(&self, _: &mut CpuStorage, _: &Layout) -> Result<()> {
        crate::bail!("ug ops are only supported on metal/cuda at the moment")
    }

    /// Encodes the compiled pipeline over the `f32` elements described by `layout`.
    ///
    /// # Errors
    /// Fails when the storage is not `f32`, when the layout is not contiguous, or
    /// when no command encoder can be obtained.
    #[cfg(feature = "metal")]
    fn metal_fwd(&self, sto: &mut MetalStorage, layout: &Layout) -> Result<()> {
        use crate::backend::BackendStorage;

        if sto.dtype() != crate::DType::F32 {
            crate::bail!("input is not a f32 tensor")
        }
        // The kernel walks a flat buffer, so the view must be contiguous and the
        // buffer has to be bound at the view's first element rather than at 0.
        // NOTE(review): assumes the last argument of `set_output_buffer` is a byte
        // offset, as in Metal's `setBuffer:offset:atIndex:` — confirm.
        let byte_offset = match layout.contiguous_offsets() {
            None => crate::bail!("input has to be contiguous"),
            Some((start, _)) => start * std::mem::size_of::<f32>(),
        };
        let elem_count = layout.shape().elem_count();
        if elem_count == 0 {
            // Nothing to update; avoid dispatching a zero-sized grid.
            return Ok(());
        }
        let device = sto.device();
        let encoder = device.command_encoder()?;
        encoder.set_compute_pipeline_state(&self.func);
        // Use 32-wide groups when the element count allows it, single-thread
        // groups otherwise.
        let (g, b) = if elem_count.is_multiple_of(32) {
            (elem_count / 32, 32)
        } else {
            (elem_count, 1)
        };
        let grid_dims = objc2_metal::MTLSize {
            width: g,
            height: 1,
            depth: 1,
        };
        let group_dims = hanzo_metal_kernels::utils::get_block_dims(b, 1, 1);
        encoder.set_output_buffer(0, Some(sto.buffer()), byte_offset);
        encoder.dispatch_threads(grid_dims, group_dims);

        Ok(())
    }

    /// Launches the compiled CUDA function over the `f32` elements described by
    /// `layout`.
    ///
    /// # Errors
    /// Fails when the storage is not `f32`, when the layout is not contiguous, when
    /// the required grid size does not fit in a `u32`, or when the launch itself
    /// fails.
    #[cfg(feature = "cuda")]
    fn cuda_fwd(&self, sto: &mut CudaStorage, layout: &Layout) -> Result<()> {
        use crate::cuda_backend::WrapErr;
        use cudarc::driver::PushKernelArg;

        let elem_count = layout.shape().elem_count();
        let stream = sto.device.cuda_stream();
        let sto = sto.as_cuda_slice::<f32>()?;
        let sto = match layout.contiguous_offsets() {
            None => crate::bail!("input has to be contiguous"),
            Some((o1, o2)) => sto.slice(o1..o2),
        };
        if elem_count == 0 {
            // A zero-sized grid is not a valid launch configuration and there is
            // nothing to update anyway.
            return Ok(());
        }
        // Use 32-wide blocks when the element count allows it, single-thread
        // blocks otherwise.
        let (g, b) = if elem_count.is_multiple_of(32) {
            (elem_count / 32, 32u32)
        } else {
            (elem_count, 1u32)
        };
        // A plain `as u32` cast would wrap silently and only process part of the
        // tensor, so reject grids that do not fit.
        let grid = match u32::try_from(g) {
            Ok(grid) => grid,
            Err(_) => crate::bail!(
                "{}: {} elements do not fit in a single kernel launch",
                self.name,
                elem_count
            ),
        };
        let cfg = cudarc::driver::LaunchConfig {
            grid_dim: (grid, 1, 1),
            block_dim: (b, 1, 1),
            shared_mem_bytes: 0,
        };
        let mut builder = stream.launch_builder(&self.func);
        builder.arg(&sto);
        // SAFETY: the only kernel argument is the `f32` slice above, which outlives
        // the launch call. This assumes the compiled kernel touches no element
        // beyond grid * block == elem_count — TODO confirm against the ug kernel
        // contract.
        unsafe { builder.launch(cfg) }.w()?;
        Ok(())
    }
}