Skip to main content

hanzo_ml/
custom_op.rs

1use crate::op::{BackpropOp, Op};
2use crate::tensor::from_storage;
3#[cfg(feature = "rocm")]
4use crate::RocmStorage;
5#[cfg(feature = "vulkan")]
6use crate::VulkanStorage;
7use crate::{CpuStorage, CudaStorage, Layout, MetalStorage, Result, Shape, Tensor};
8use std::sync::Arc;
9
/// Reports a custom/inplace op that is about to bail because it has no native Vulkan path.
///
/// Only active when the `HANZO_VK_PROFILE` environment variable is set to something other than
/// `"0"`. The size-only readback profiler cannot attribute a missing op to a name, so this prints
/// the op name and the dims of `l` to stderr; a follow-up GPU run then knows which `vulkan_fwd`
/// override to add next. The variable is only read on this cold bail path (the caller returns an
/// error right afterwards), so ops that do have a native path pay nothing for it. Compiled only
/// with the `vulkan` feature, like the default impls that call it, so other backends never see it.
#[cfg(feature = "vulkan")]
fn log_vulkan_custom_op_bail(name: &str, l: &Layout) {
    // Unset or non-UTF-8 values count as disabled, and so does the literal "0".
    let enabled = matches!(std::env::var("HANZO_VK_PROFILE"), Ok(value) if value != "0");
    if !enabled {
        return;
    }
    let dims = l.shape().dims();
    eprintln!(
        "[HANZO_VK_PROFILE] custom-op bail op={name} shape={:?} (no vulkan_fwd; would round-trip/err)",
        dims
    );
}
25
26/// Unary ops that can be defined in user-land.
27pub trait CustomOp1 {
28    // Box<dyn> does not support const yet, so use a function to get the name.
29    fn name(&self) -> &'static str;
30
31    /// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
32    /// offsets etc so the associated layout should be used to access it.
33    fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)>;
34
35    /// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
36    /// offsets etc so the associated layout should be used to access it.
37    fn cuda_fwd(&self, _storage: &CudaStorage, _layout: &Layout) -> Result<(CudaStorage, Shape)> {
38        Err(crate::Error::Cuda(
39            format!("no cuda implementation for {}", self.name()).into(),
40        ))
41    }
42
43    #[cfg(feature = "rocm")]
44    fn rocm_fwd(&self, _storage: &RocmStorage, _layout: &Layout) -> Result<(RocmStorage, Shape)> {
45        Err(crate::Error::Msg(format!(
46            "no rocm implementation for {}",
47            self.name()
48        )))
49    }
50    #[cfg(feature = "vulkan")]
51    fn vulkan_fwd(
52        &self,
53        _storage: &VulkanStorage,
54        _layout: &Layout,
55    ) -> Result<(VulkanStorage, Shape)> {
56        log_vulkan_custom_op_bail(self.name(), _layout);
57        Err(crate::Error::Msg(format!(
58            "no vulkan implementation for {}",
59            self.name()
60        )))
61    }
62
63    /// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
64    /// offsets etc so the associated layout should be used to access it.
65    fn metal_fwd(
66        &self,
67        _storage: &MetalStorage,
68        _layout: &Layout,
69    ) -> Result<(MetalStorage, Shape)> {
70        Err(crate::Error::Metal(
71            format!("no metal implementation for {}", self.name()).into(),
72        ))
73    }
74
75    /// This function takes as argument the argument `arg` used in the forward pass, the result
76    /// produced by the forward operation `res` and the gradient of the result `grad_res`.
77    /// The function should return the gradient of the argument.
78    fn bwd(&self, _arg: &Tensor, _res: &Tensor, _grad_res: &Tensor) -> Result<Option<Tensor>> {
79        Err(crate::Error::BackwardNotSupported { op: self.name() })
80    }
81}
82
83pub trait CustomOp2 {
84    fn name(&self) -> &'static str;
85
86    /// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
87    /// offsets etc so the associated layout should be used to access it.
88    fn cpu_fwd(
89        &self,
90        s1: &CpuStorage,
91        l1: &Layout,
92        s2: &CpuStorage,
93        l2: &Layout,
94    ) -> Result<(CpuStorage, Shape)>;
95
96    /// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
97    /// offsets etc so the associated layout should be used to access it.
98    fn cuda_fwd(
99        &self,
100        _: &CudaStorage,
101        _: &Layout,
102        _: &CudaStorage,
103        _: &Layout,
104    ) -> Result<(CudaStorage, Shape)> {
105        Err(crate::Error::Cuda(
106            format!("no cuda implementation for {}", self.name()).into(),
107        ))
108    }
109
110    #[cfg(feature = "rocm")]
111    fn rocm_fwd(
112        &self,
113        _: &RocmStorage,
114        _: &Layout,
115        _: &RocmStorage,
116        _: &Layout,
117    ) -> Result<(RocmStorage, Shape)> {
118        Err(crate::Error::Msg(format!(
119            "no rocm implementation for {}",
120            self.name()
121        )))
122    }
123    #[cfg(feature = "vulkan")]
124    fn vulkan_fwd(
125        &self,
126        _: &VulkanStorage,
127        l1: &Layout,
128        _: &VulkanStorage,
129        _: &Layout,
130    ) -> Result<(VulkanStorage, Shape)> {
131        log_vulkan_custom_op_bail(self.name(), l1);
132        Err(crate::Error::Msg(format!(
133            "no vulkan implementation for {}",
134            self.name()
135        )))
136    }
137
138    /// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
139    /// offsets etc so the associated layout should be used to access it.
140    fn metal_fwd(
141        &self,
142        _: &MetalStorage,
143        _: &Layout,
144        _: &MetalStorage,
145        _: &Layout,
146    ) -> Result<(MetalStorage, Shape)> {
147        Err(crate::Error::Metal(
148            format!("no metal implementation for {}", self.name()).into(),
149        ))
150    }
151
152    fn bwd(
153        &self,
154        _arg1: &Tensor,
155        _arg2: &Tensor,
156        _res: &Tensor,
157        _grad_res: &Tensor,
158    ) -> Result<(Option<Tensor>, Option<Tensor>)> {
159        Err(crate::Error::BackwardNotSupported { op: self.name() })
160    }
161}
162
163pub trait CustomOp3 {
164    fn name(&self) -> &'static str;
165
166    /// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
167    /// offsets etc so the associated layout should be used to access it.
168    fn cpu_fwd(
169        &self,
170        s1: &CpuStorage,
171        l1: &Layout,
172        s2: &CpuStorage,
173        l2: &Layout,
174        s3: &CpuStorage,
175        l3: &Layout,
176    ) -> Result<(CpuStorage, Shape)>;
177
178    /// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
179    /// offsets etc so the associated layout should be used to access it.
180    fn cuda_fwd(
181        &self,
182        _: &CudaStorage,
183        _: &Layout,
184        _: &CudaStorage,
185        _: &Layout,
186        _: &CudaStorage,
187        _: &Layout,
188    ) -> Result<(CudaStorage, Shape)> {
189        Err(crate::Error::Cuda(
190            format!("no cuda implementation for {}", self.name()).into(),
191        ))
192    }
193
194    #[cfg(feature = "rocm")]
195    fn rocm_fwd(
196        &self,
197        _: &RocmStorage,
198        _: &Layout,
199        _: &RocmStorage,
200        _: &Layout,
201        _: &RocmStorage,
202        _: &Layout,
203    ) -> Result<(RocmStorage, Shape)> {
204        Err(crate::Error::Msg(format!(
205            "no rocm implementation for {}",
206            self.name()
207        )))
208    }
209    #[cfg(feature = "vulkan")]
210    fn vulkan_fwd(
211        &self,
212        _: &VulkanStorage,
213        l1: &Layout,
214        _: &VulkanStorage,
215        _: &Layout,
216        _: &VulkanStorage,
217        _: &Layout,
218    ) -> Result<(VulkanStorage, Shape)> {
219        log_vulkan_custom_op_bail(self.name(), l1);
220        Err(crate::Error::Msg(format!(
221            "no vulkan implementation for {}",
222            self.name()
223        )))
224    }
225
226    /// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
227    /// offsets etc so the associated layout should be used to access it.
228    fn metal_fwd(
229        &self,
230        _: &MetalStorage,
231        _: &Layout,
232        _: &MetalStorage,
233        _: &Layout,
234        _: &MetalStorage,
235        _: &Layout,
236    ) -> Result<(MetalStorage, Shape)> {
237        Err(crate::Error::Metal(
238            format!("no metal implementation for {}", self.name()).into(),
239        ))
240    }
241
242    fn bwd(
243        &self,
244        _arg1: &Tensor,
245        _arg2: &Tensor,
246        _arg3: &Tensor,
247        _res: &Tensor,
248        _grad_res: &Tensor,
249    ) -> Result<(Option<Tensor>, Option<Tensor>, Option<Tensor>)> {
250        Err(crate::Error::BackwardNotSupported { op: self.name() })
251    }
252}
253
254impl Tensor {
255    /// Applies a unary custom op without backward support
256    pub fn apply_op1_no_bwd<C: CustomOp1>(&self, c: &C) -> Result<Self> {
257        let (storage, shape) = self.storage().apply_op1(self.layout(), c)?;
258        Ok(from_storage(storage, shape, BackpropOp::none(), false))
259    }
260
261    /// Applies a binary custom op without backward support
262    pub fn apply_op2_no_bwd<C: CustomOp2>(&self, rhs: &Self, c: &C) -> Result<Self> {
263        let (storage, shape) =
264            self.storage()
265                .apply_op2(self.layout(), &rhs.storage(), rhs.layout(), c)?;
266        Ok(from_storage(storage, shape, BackpropOp::none(), false))
267    }
268
269    /// Applies a ternary custom op without backward support
270    pub fn apply_op3_no_bwd<C: CustomOp3>(&self, t2: &Self, t3: &Self, c: &C) -> Result<Self> {
271        let (storage, shape) = self.storage().apply_op3(
272            self.layout(),
273            &t2.storage(),
274            t2.layout(),
275            &t3.storage(),
276            t3.layout(),
277            c,
278        )?;
279        Ok(from_storage(storage, shape, BackpropOp::none(), false))
280    }
281
282    /// Applies a unary custom op.
283    pub fn apply_op1_arc(&self, c: Arc<Box<dyn CustomOp1 + Send + Sync>>) -> Result<Self> {
284        let (storage, shape) = self
285            .storage()
286            .apply_op1(self.layout(), c.as_ref().as_ref())?;
287        let op = BackpropOp::new1(self, |s| Op::CustomOp1(s, c.clone()));
288        Ok(from_storage(storage, shape, op, false))
289    }
290
291    pub fn apply_op1<C: 'static + CustomOp1 + Send + Sync>(&self, c: C) -> Result<Self> {
292        self.apply_op1_arc(Arc::new(Box::new(c)))
293    }
294
295    /// Applies a binary custom op.
296    pub fn apply_op2_arc(
297        &self,
298        rhs: &Self,
299        c: Arc<Box<dyn CustomOp2 + Send + Sync>>,
300    ) -> Result<Self> {
301        let (storage, shape) = self.storage().apply_op2(
302            self.layout(),
303            &rhs.storage(),
304            rhs.layout(),
305            c.as_ref().as_ref(),
306        )?;
307        let op = BackpropOp::new2(self, rhs, |t1, t2| Op::CustomOp2(t1, t2, c.clone()));
308        Ok(from_storage(storage, shape, op, false))
309    }
310
311    pub fn apply_op2<C: 'static + CustomOp2 + Send + Sync>(&self, r: &Self, c: C) -> Result<Self> {
312        self.apply_op2_arc(r, Arc::new(Box::new(c)))
313    }
314
315    /// Applies a ternary custom op.
316    pub fn apply_op3_arc(
317        &self,
318        t2: &Self,
319        t3: &Self,
320        c: Arc<Box<dyn CustomOp3 + Send + Sync>>,
321    ) -> Result<Self> {
322        let (storage, shape) = self.storage().apply_op3(
323            self.layout(),
324            &t2.storage(),
325            t2.layout(),
326            &t3.storage(),
327            t3.layout(),
328            c.as_ref().as_ref(),
329        )?;
330        let op = BackpropOp::new3(self, t2, t3, |t1, t2, t3| {
331            Op::CustomOp3(t1, t2, t3, c.clone())
332        });
333        Ok(from_storage(storage, shape, op, false))
334    }
335
336    pub fn apply_op3<C: 'static + CustomOp3 + Send + Sync>(
337        &self,
338        t2: &Self,
339        t3: &Self,
340        c: C,
341    ) -> Result<Self> {
342        self.apply_op3_arc(t2, t3, Arc::new(Box::new(c)))
343    }
344}
345
346// In place ops.
347
348/// Unary ops that can be defined in user-land.
349/// These ops work in place and as such back-prop is unsupported.
350pub trait InplaceOp1 {
351    // Box<dyn> does not support const yet, so use a function to get the name.
352    fn name(&self) -> &'static str;
353
354    /// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
355    /// offsets etc so the associated layout should be used to access it.
356    fn cpu_fwd(&self, storage: &mut CpuStorage, layout: &Layout) -> Result<()>;
357
358    /// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
359    /// offsets etc so the associated layout should be used to access it.
360    fn cuda_fwd(&self, _storage: &mut CudaStorage, _layout: &Layout) -> Result<()> {
361        Err(crate::Error::Cuda(
362            format!("no cuda implementation for {}", self.name()).into(),
363        ))
364    }
365
366    #[cfg(feature = "rocm")]
367    fn rocm_fwd(&self, _storage: &mut RocmStorage, _layout: &Layout) -> Result<()> {
368        Err(crate::Error::Msg(format!(
369            "no rocm implementation for {}",
370            self.name()
371        )))
372    }
373    #[cfg(feature = "vulkan")]
374    fn vulkan_fwd(&self, _storage: &mut VulkanStorage, _layout: &Layout) -> Result<()> {
375        log_vulkan_custom_op_bail(self.name(), _layout);
376        Err(crate::Error::Msg(format!(
377            "no vulkan implementation for {}",
378            self.name()
379        )))
380    }
381
382    /// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
383    /// offsets etc so the associated layout should be used to access it.
384    fn metal_fwd(&self, _storage: &mut MetalStorage, _layout: &Layout) -> Result<()> {
385        Err(crate::Error::Metal(
386            format!("no metal implementation for {}", self.name()).into(),
387        ))
388    }
389}
390
391pub trait InplaceOp2 {
392    fn name(&self) -> &'static str;
393
394    /// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
395    /// offsets etc so the associated layout should be used to access it.
396    fn cpu_fwd(&self, s1: &mut CpuStorage, l1: &Layout, s2: &CpuStorage, l2: &Layout)
397        -> Result<()>;
398
399    /// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
400    /// offsets etc so the associated layout should be used to access it.
401    fn cuda_fwd(&self, _: &mut CudaStorage, _: &Layout, _: &CudaStorage, _: &Layout) -> Result<()> {
402        Err(crate::Error::Cuda(
403            format!("no cuda implementation for {}", self.name()).into(),
404        ))
405    }
406
407    #[cfg(feature = "rocm")]
408    fn rocm_fwd(&self, _: &mut RocmStorage, _: &Layout, _: &RocmStorage, _: &Layout) -> Result<()> {
409        Err(crate::Error::Msg(format!(
410            "no rocm implementation for {}",
411            self.name()
412        )))
413    }
414    #[cfg(feature = "vulkan")]
415    fn vulkan_fwd(
416        &self,
417        _: &mut VulkanStorage,
418        l1: &Layout,
419        _: &VulkanStorage,
420        _: &Layout,
421    ) -> Result<()> {
422        log_vulkan_custom_op_bail(self.name(), l1);
423        Err(crate::Error::Msg(format!(
424            "no vulkan implementation for {}",
425            self.name()
426        )))
427    }
428
429    /// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
430    /// offsets etc so the associated layout should be used to access it.
431    fn metal_fwd(
432        &self,
433        _: &mut MetalStorage,
434        _: &Layout,
435        _: &MetalStorage,
436        _: &Layout,
437    ) -> Result<()> {
438        Err(crate::Error::Metal(
439            format!("no metal implementation for {}", self.name()).into(),
440        ))
441    }
442}
443
444pub trait InplaceOp3 {
445    fn name(&self) -> &'static str;
446
447    /// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
448    /// offsets etc so the associated layout should be used to access it.
449    fn cpu_fwd(
450        &self,
451        s1: &mut CpuStorage,
452        l1: &Layout,
453        s2: &CpuStorage,
454        l2: &Layout,
455        s3: &CpuStorage,
456        l3: &Layout,
457    ) -> Result<()>;
458
459    /// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
460    /// offsets etc so the associated layout should be used to access it.
461    fn cuda_fwd(
462        &self,
463        _: &mut CudaStorage,
464        _: &Layout,
465        _: &CudaStorage,
466        _: &Layout,
467        _: &CudaStorage,
468        _: &Layout,
469    ) -> Result<()> {
470        Err(crate::Error::Cuda(
471            format!("no cuda implementation for {}", self.name()).into(),
472        ))
473    }
474
475    #[cfg(feature = "rocm")]
476    fn rocm_fwd(
477        &self,
478        _: &mut RocmStorage,
479        _: &Layout,
480        _: &RocmStorage,
481        _: &Layout,
482        _: &RocmStorage,
483        _: &Layout,
484    ) -> Result<()> {
485        Err(crate::Error::Msg(format!(
486            "no rocm implementation for {}",
487            self.name()
488        )))
489    }
490    #[cfg(feature = "vulkan")]
491    fn vulkan_fwd(
492        &self,
493        _: &mut VulkanStorage,
494        l1: &Layout,
495        _: &VulkanStorage,
496        _: &Layout,
497        _: &VulkanStorage,
498        _: &Layout,
499    ) -> Result<()> {
500        log_vulkan_custom_op_bail(self.name(), l1);
501        Err(crate::Error::Msg(format!(
502            "no vulkan implementation for {}",
503            self.name()
504        )))
505    }
506
507    /// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
508    /// offsets etc so the associated layout should be used to access it.
509    fn metal_fwd(
510        &self,
511        _: &mut MetalStorage,
512        _: &Layout,
513        _: &MetalStorage,
514        _: &Layout,
515        _: &MetalStorage,
516        _: &Layout,
517    ) -> Result<()> {
518        Err(crate::Error::Metal(
519            format!("no metal implementation for {}", self.name()).into(),
520        ))
521    }
522}
523
524impl Tensor {
525    /// Applies a unary custom op in place.
526    pub fn inplace_op1<C: InplaceOp1>(&self, c: &C) -> Result<()> {
527        self.storage_mut().inplace_op1(self.layout(), c)
528    }
529
530    /// Applies a unary custom op in place (for the first tensor).
531    pub fn inplace_op2<C: InplaceOp2>(&self, rhs: &Self, c: &C) -> Result<()> {
532        self.storage_mut()
533            .inplace_op2(self.layout(), &rhs.storage(), rhs.layout(), c)
534    }
535
536    /// Applies a ternary custom op in place (for the first tensor).
537    pub fn inplace_op3<C: InplaceOp3>(&self, t2: &Self, t3: &Self, c: &C) -> Result<()> {
538        self.storage_mut().inplace_op3(
539            self.layout(),
540            &t2.storage(),
541            t2.layout(),
542            &t3.storage(),
543            t3.layout(),
544            c,
545        )
546    }
547}
548
/// An in-place unary op ([`InplaceOp1`]) whose kernel is compiled from a
/// `hanzo_ug::lang::ssa::Kernel` by [`UgIOp1::new`].
///
/// The compiled handle is stored only for the backend enabled at build time (`cuda` or
/// `metal`); with neither feature the struct carries just the op name.
/// NOTE(review): enabling both `cuda` and `metal` declares `func` twice, which does not compile.
/// Presumably the two features are meant to be mutually exclusive here; confirm.
#[cfg(feature = "ug")]
pub struct UgIOp1 {
    /// Name reported by [`InplaceOp1::name`].
    name: &'static str,
    /// Compiled Metal pipeline, bound by `metal_fwd`.
    #[cfg(feature = "metal")]
    func: hanzo_metal_kernels::metal::ComputePipeline,
    /// Compiled CUDA function, launched by `cuda_fwd`.
    #[cfg(feature = "cuda")]
    func: cudarc::driver::CudaFunction,
}
557
#[cfg(feature = "ug")]
impl UgIOp1 {
    /// Compiles `kernel` on `device` and wraps the result as an in-place op called `name`.
    ///
    /// With the `cuda` feature the device is converted with `as_cuda_device`, with `metal` via
    /// `as_metal_device`; any conversion or compilation error is propagated. With neither
    /// feature nothing is compiled and only the name is kept. Unavailable on wasm32 and iOS.
    // `kernel` and `device` go unused when neither `cuda` nor `metal` is enabled.
    #[allow(unused)]
    #[cfg(all(not(target_arch = "wasm32"), not(target_os = "ios")))]
    pub fn new(
        name: &'static str,
        kernel: hanzo_ug::lang::ssa::Kernel,
        device: &crate::Device,
    ) -> Result<Self> {
        #[cfg(feature = "cuda")]
        {
            let compiled = device.as_cuda_device()?.compile(name, kernel)?;
            let func = compiled.into_cuda_function();
            Ok(Self { name, func })
        }
        #[cfg(feature = "metal")]
        {
            let func = device.as_metal_device()?.compile(name, kernel)?;
            Ok(Self { name, func })
        }
        #[cfg(not(any(feature = "cuda", feature = "metal")))]
        {
            Ok(Self { name })
        }
    }
}
588
/// Launch shape shared by the ug backends, as `(groups, threads_per_group)`.
///
/// Uses groups of 32 threads when `elem_count` is a multiple of 32, otherwise one thread per
/// group, so `groups * threads_per_group == elem_count` in both cases.
#[cfg(all(feature = "ug", any(feature = "cuda", feature = "metal")))]
fn ug_launch_dims(elem_count: usize) -> (usize, usize) {
    if elem_count.is_multiple_of(32) {
        (elem_count / 32, 32)
    } else {
        (elem_count, 1)
    }
}

#[cfg(feature = "ug")]
impl InplaceOp1 for UgIOp1 {
    fn name(&self) -> &'static str {
        self.name
    }

    /// ug kernels have no cpu implementation; always errors.
    fn cpu_fwd(&self, _: &mut CpuStorage, _: &Layout) -> Result<()> {
        crate::bail!("ug ops are only supported on metal/cuda at the moment")
    }

    /// Dispatches the compiled pipeline over the tensor's (contiguous, f32) elements.
    #[cfg(feature = "metal")]
    fn metal_fwd(&self, sto: &mut MetalStorage, layout: &Layout) -> Result<()> {
        use crate::backend::BackendStorage;

        let elem_count = layout.shape().elem_count();
        if sto.dtype() != crate::DType::F32 {
            // TODO: support more dtypes.
            crate::bail!("input is not a f32 tensor")
        }
        // Like the cuda path, the kernel indexes the buffer linearly, so only a contiguous layout
        // is valid. Bind the buffer at the layout's start so offset views are written in place
        // instead of from the beginning of the allocation.
        let Some((start, _)) = layout.contiguous_offsets() else {
            crate::bail!("input has to be contiguous")
        };
        let byte_offset = start * std::mem::size_of::<f32>();
        let device = sto.device();
        let encoder = device.command_encoder()?;
        encoder.set_compute_pipeline_state(&self.func);
        let (g, b) = ug_launch_dims(elem_count);
        let grid_dims = objc2_metal::MTLSize {
            width: g,
            height: 1,
            depth: 1,
        };
        let group_dims = hanzo_metal_kernels::utils::get_block_dims(b, 1, 1);
        encoder.set_output_buffer(0, Some(sto.buffer()), byte_offset);
        // NOTE(review): `grid_dims.width` is the group count `g`, not `elem_count`. If
        // `dispatch_threads` takes the total threads per grid (as Metal's
        // `dispatchThreads:threadsPerThreadgroup:` does), only `g` threads run. Confirm
        // against the encoder wrapper and the generated kernel's indexing.
        encoder.dispatch_threads(grid_dims, group_dims);

        Ok(())
    }

    /// Launches the compiled function over the tensor's contiguous f32 slice.
    #[cfg(feature = "cuda")]
    fn cuda_fwd(&self, sto: &mut CudaStorage, layout: &Layout) -> Result<()> {
        use crate::cuda_backend::WrapErr;
        use cudarc::driver::PushKernelArg;

        let elem_count = layout.shape().elem_count();
        let stream = sto.device.cuda_stream();
        // TODO: support more dtypes.
        let sto = sto.as_cuda_slice::<f32>()?;
        let sto = match layout.contiguous_offsets() {
            None => crate::bail!("input has to be contiguous"),
            Some((o1, o2)) => sto.slice(o1..o2),
        };
        let (g, b) = ug_launch_dims(elem_count);
        // `b` is 1 or 32; NOTE(review): `g as u32` truncates for elem_count > u32::MAX.
        let cfg = cudarc::driver::LaunchConfig {
            grid_dim: (g as u32, 1, 1),
            block_dim: (b as u32, 1, 1),
            shared_mem_bytes: 0,
        };
        let mut builder = stream.launch_builder(&self.func);
        builder.arg(&sto);
        // SAFETY: exactly one argument (the contiguous f32 slice covering the tensor) is pushed.
        // This assumes the ug-generated kernel takes a single f32 buffer and touches at most
        // `g * b == elem_count` elements — TODO confirm against the kernel generator.
        unsafe { builder.launch(cfg) }.w()?;
        Ok(())
    }
}