// candle_core/custom_op.rs

1use crate::op::{BackpropOp, Op};
2use crate::tensor::from_storage;
3use crate::{CpuStorage, CudaStorage, Layout, MetalStorage, Result, Shape, Tensor};
4use std::sync::Arc;
5
/// Unary ops that can be defined in user-land.
pub trait CustomOp1 {
    // Box<dyn> does not support const yet, so use a function to get the name.
    fn name(&self) -> &'static str;

    /// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
    /// offsets etc so the associated layout should be used to access it.
    fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)>;

    /// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
    /// offsets etc so the associated layout should be used to access it.
    ///
    /// The default implementation returns a `Cuda` error so that cpu-only ops do not need to
    /// provide a cuda kernel.
    fn cuda_fwd(&self, _storage: &CudaStorage, _layout: &Layout) -> Result<(CudaStorage, Shape)> {
        Err(crate::Error::Cuda(
            format!("no cuda implementation for {}", self.name()).into(),
        ))
    }

    /// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
    /// offsets etc so the associated layout should be used to access it.
    ///
    /// The default implementation returns a `Metal` error so that cpu-only ops do not need to
    /// provide a metal kernel.
    fn metal_fwd(
        &self,
        _storage: &MetalStorage,
        _layout: &Layout,
    ) -> Result<(MetalStorage, Shape)> {
        Err(crate::Error::Metal(
            format!("no metal implementation for {}", self.name()).into(),
        ))
    }

    /// This function takes as argument the argument `arg` used in the forward pass, the result
    /// produced by the forward operation `res` and the gradient of the result `grad_res`.
    /// The function should return the gradient of the argument.
    ///
    /// The default implementation reports backward as unsupported for this op.
    fn bwd(&self, _arg: &Tensor, _res: &Tensor, _grad_res: &Tensor) -> Result<Option<Tensor>> {
        Err(crate::Error::BackwardNotSupported { op: self.name() })
    }
}
42
/// Binary ops that can be defined in user-land.
pub trait CustomOp2 {
    // Box<dyn> does not support const yet, so use a function to get the name.
    fn name(&self) -> &'static str;

    /// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
    /// offsets etc so the associated layout should be used to access it.
    fn cpu_fwd(
        &self,
        s1: &CpuStorage,
        l1: &Layout,
        s2: &CpuStorage,
        l2: &Layout,
    ) -> Result<(CpuStorage, Shape)>;

    /// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
    /// offsets etc so the associated layout should be used to access it.
    ///
    /// The default implementation returns a `Cuda` error so that cpu-only ops do not need to
    /// provide a cuda kernel.
    fn cuda_fwd(
        &self,
        _: &CudaStorage,
        _: &Layout,
        _: &CudaStorage,
        _: &Layout,
    ) -> Result<(CudaStorage, Shape)> {
        Err(crate::Error::Cuda(
            format!("no cuda implementation for {}", self.name()).into(),
        ))
    }

    /// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
    /// offsets etc so the associated layout should be used to access it.
    ///
    /// The default implementation returns a `Metal` error so that cpu-only ops do not need to
    /// provide a metal kernel.
    fn metal_fwd(
        &self,
        _: &MetalStorage,
        _: &Layout,
        _: &MetalStorage,
        _: &Layout,
    ) -> Result<(MetalStorage, Shape)> {
        Err(crate::Error::Metal(
            format!("no metal implementation for {}", self.name()).into(),
        ))
    }

    /// The backward pass: given the two forward arguments, the forward result `res` and the
    /// gradient of the result `grad_res`, return the gradients of the two arguments (each one
    /// optional, `None` meaning no gradient flows to that argument).
    ///
    /// The default implementation reports backward as unsupported for this op.
    fn bwd(
        &self,
        _arg1: &Tensor,
        _arg2: &Tensor,
        _res: &Tensor,
        _grad_res: &Tensor,
    ) -> Result<(Option<Tensor>, Option<Tensor>)> {
        Err(crate::Error::BackwardNotSupported { op: self.name() })
    }
}
94
/// Ternary ops that can be defined in user-land.
pub trait CustomOp3 {
    // Box<dyn> does not support const yet, so use a function to get the name.
    fn name(&self) -> &'static str;

    /// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
    /// offsets etc so the associated layout should be used to access it.
    fn cpu_fwd(
        &self,
        s1: &CpuStorage,
        l1: &Layout,
        s2: &CpuStorage,
        l2: &Layout,
        s3: &CpuStorage,
        l3: &Layout,
    ) -> Result<(CpuStorage, Shape)>;

    /// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
    /// offsets etc so the associated layout should be used to access it.
    ///
    /// The default implementation returns a `Cuda` error so that cpu-only ops do not need to
    /// provide a cuda kernel.
    fn cuda_fwd(
        &self,
        _: &CudaStorage,
        _: &Layout,
        _: &CudaStorage,
        _: &Layout,
        _: &CudaStorage,
        _: &Layout,
    ) -> Result<(CudaStorage, Shape)> {
        Err(crate::Error::Cuda(
            format!("no cuda implementation for {}", self.name()).into(),
        ))
    }

    /// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
    /// offsets etc so the associated layout should be used to access it.
    ///
    /// The default implementation returns a `Metal` error so that cpu-only ops do not need to
    /// provide a metal kernel.
    fn metal_fwd(
        &self,
        _: &MetalStorage,
        _: &Layout,
        _: &MetalStorage,
        _: &Layout,
        _: &MetalStorage,
        _: &Layout,
    ) -> Result<(MetalStorage, Shape)> {
        Err(crate::Error::Metal(
            format!("no metal implementation for {}", self.name()).into(),
        ))
    }

    /// The backward pass: given the three forward arguments, the forward result `res` and the
    /// gradient of the result `grad_res`, return the gradients of the three arguments (each one
    /// optional, `None` meaning no gradient flows to that argument).
    ///
    /// The default implementation reports backward as unsupported for this op.
    fn bwd(
        &self,
        _arg1: &Tensor,
        _arg2: &Tensor,
        _arg3: &Tensor,
        _res: &Tensor,
        _grad_res: &Tensor,
    ) -> Result<(Option<Tensor>, Option<Tensor>, Option<Tensor>)> {
        Err(crate::Error::BackwardNotSupported { op: self.name() })
    }
}
153
154impl Tensor {
155    /// Applies a unary custom op without backward support
156    pub fn apply_op1_no_bwd<C: CustomOp1>(&self, c: &C) -> Result<Self> {
157        let (storage, shape) = self.storage().apply_op1(self.layout(), c)?;
158        Ok(from_storage(storage, shape, BackpropOp::none(), false))
159    }
160
161    /// Applies a binary custom op without backward support
162    pub fn apply_op2_no_bwd<C: CustomOp2>(&self, rhs: &Self, c: &C) -> Result<Self> {
163        let (storage, shape) =
164            self.storage()
165                .apply_op2(self.layout(), &rhs.storage(), rhs.layout(), c)?;
166        Ok(from_storage(storage, shape, BackpropOp::none(), false))
167    }
168
169    /// Applies a ternary custom op without backward support
170    pub fn apply_op3_no_bwd<C: CustomOp3>(&self, t2: &Self, t3: &Self, c: &C) -> Result<Self> {
171        let (storage, shape) = self.storage().apply_op3(
172            self.layout(),
173            &t2.storage(),
174            t2.layout(),
175            &t3.storage(),
176            t3.layout(),
177            c,
178        )?;
179        Ok(from_storage(storage, shape, BackpropOp::none(), false))
180    }
181
182    /// Applies a unary custom op.
183    pub fn apply_op1_arc(&self, c: Arc<Box<dyn CustomOp1 + Send + Sync>>) -> Result<Self> {
184        let (storage, shape) = self
185            .storage()
186            .apply_op1(self.layout(), c.as_ref().as_ref())?;
187        let op = BackpropOp::new1(self, |s| Op::CustomOp1(s, c.clone()));
188        Ok(from_storage(storage, shape, op, false))
189    }
190
191    pub fn apply_op1<C: 'static + CustomOp1 + Send + Sync>(&self, c: C) -> Result<Self> {
192        self.apply_op1_arc(Arc::new(Box::new(c)))
193    }
194
195    /// Applies a binary custom op.
196    pub fn apply_op2_arc(
197        &self,
198        rhs: &Self,
199        c: Arc<Box<dyn CustomOp2 + Send + Sync>>,
200    ) -> Result<Self> {
201        let (storage, shape) = self.storage().apply_op2(
202            self.layout(),
203            &rhs.storage(),
204            rhs.layout(),
205            c.as_ref().as_ref(),
206        )?;
207        let op = BackpropOp::new2(self, rhs, |t1, t2| Op::CustomOp2(t1, t2, c.clone()));
208        Ok(from_storage(storage, shape, op, false))
209    }
210
211    pub fn apply_op2<C: 'static + CustomOp2 + Send + Sync>(&self, r: &Self, c: C) -> Result<Self> {
212        self.apply_op2_arc(r, Arc::new(Box::new(c)))
213    }
214
215    /// Applies a ternary custom op.
216    pub fn apply_op3_arc(
217        &self,
218        t2: &Self,
219        t3: &Self,
220        c: Arc<Box<dyn CustomOp3 + Send + Sync>>,
221    ) -> Result<Self> {
222        let (storage, shape) = self.storage().apply_op3(
223            self.layout(),
224            &t2.storage(),
225            t2.layout(),
226            &t3.storage(),
227            t3.layout(),
228            c.as_ref().as_ref(),
229        )?;
230        let op = BackpropOp::new3(self, t2, t3, |t1, t2, t3| {
231            Op::CustomOp3(t1, t2, t3, c.clone())
232        });
233        Ok(from_storage(storage, shape, op, false))
234    }
235
236    pub fn apply_op3<C: 'static + CustomOp3 + Send + Sync>(
237        &self,
238        t2: &Self,
239        t3: &Self,
240        c: C,
241    ) -> Result<Self> {
242        self.apply_op3_arc(t2, t3, Arc::new(Box::new(c)))
243    }
244}
245
// In place ops.

/// Unary ops that can be defined in user-land.
/// These ops work in place and as such back-prop is unsupported.
pub trait InplaceOp1 {
    // Box<dyn> does not support const yet, so use a function to get the name.
    fn name(&self) -> &'static str;

    /// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
    /// offsets etc so the associated layout should be used to access it.
    fn cpu_fwd(&self, storage: &mut CpuStorage, layout: &Layout) -> Result<()>;

    /// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
    /// offsets etc so the associated layout should be used to access it.
    ///
    /// The default implementation returns a `Cuda` error so that cpu-only ops do not need to
    /// provide a cuda kernel.
    fn cuda_fwd(&self, _storage: &mut CudaStorage, _layout: &Layout) -> Result<()> {
        Err(crate::Error::Cuda(
            format!("no cuda implementation for {}", self.name()).into(),
        ))
    }

    /// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
    /// offsets etc so the associated layout should be used to access it.
    ///
    /// The default implementation returns a `Metal` error so that cpu-only ops do not need to
    /// provide a metal kernel.
    fn metal_fwd(&self, _storage: &mut MetalStorage, _layout: &Layout) -> Result<()> {
        Err(crate::Error::Metal(
            format!("no metal implementation for {}", self.name()).into(),
        ))
    }
}
274
/// Binary ops that can be defined in user-land and mutate their first argument in place.
/// Back-prop is unsupported for these ops.
pub trait InplaceOp2 {
    // Box<dyn> does not support const yet, so use a function to get the name.
    fn name(&self) -> &'static str;

    /// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
    /// offsets etc so the associated layout should be used to access it.
    /// Only `s1` is mutated; `s2` is read-only.
    fn cpu_fwd(&self, s1: &mut CpuStorage, l1: &Layout, s2: &CpuStorage, l2: &Layout)
        -> Result<()>;

    /// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
    /// offsets etc so the associated layout should be used to access it.
    ///
    /// The default implementation returns a `Cuda` error so that cpu-only ops do not need to
    /// provide a cuda kernel.
    fn cuda_fwd(&self, _: &mut CudaStorage, _: &Layout, _: &CudaStorage, _: &Layout) -> Result<()> {
        Err(crate::Error::Cuda(
            format!("no cuda implementation for {}", self.name()).into(),
        ))
    }

    /// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
    /// offsets etc so the associated layout should be used to access it.
    ///
    /// The default implementation returns a `Metal` error so that cpu-only ops do not need to
    /// provide a metal kernel.
    fn metal_fwd(
        &self,
        _: &mut MetalStorage,
        _: &Layout,
        _: &MetalStorage,
        _: &Layout,
    ) -> Result<()> {
        Err(crate::Error::Metal(
            format!("no metal implementation for {}", self.name()).into(),
        ))
    }
}
305
/// Ternary ops that can be defined in user-land and mutate their first argument in place.
/// Back-prop is unsupported for these ops.
pub trait InplaceOp3 {
    // Box<dyn> does not support const yet, so use a function to get the name.
    fn name(&self) -> &'static str;

    /// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
    /// offsets etc so the associated layout should be used to access it.
    /// Only `s1` is mutated; `s2` and `s3` are read-only.
    fn cpu_fwd(
        &self,
        s1: &mut CpuStorage,
        l1: &Layout,
        s2: &CpuStorage,
        l2: &Layout,
        s3: &CpuStorage,
        l3: &Layout,
    ) -> Result<()>;

    /// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
    /// offsets etc so the associated layout should be used to access it.
    ///
    /// The default implementation returns a `Cuda` error so that cpu-only ops do not need to
    /// provide a cuda kernel.
    fn cuda_fwd(
        &self,
        _: &mut CudaStorage,
        _: &Layout,
        _: &CudaStorage,
        _: &Layout,
        _: &CudaStorage,
        _: &Layout,
    ) -> Result<()> {
        Err(crate::Error::Cuda(
            format!("no cuda implementation for {}", self.name()).into(),
        ))
    }

    /// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
    /// offsets etc so the associated layout should be used to access it.
    ///
    /// The default implementation returns a `Metal` error so that cpu-only ops do not need to
    /// provide a metal kernel.
    fn metal_fwd(
        &self,
        _: &mut MetalStorage,
        _: &Layout,
        _: &MetalStorage,
        _: &Layout,
        _: &MetalStorage,
        _: &Layout,
    ) -> Result<()> {
        Err(crate::Error::Metal(
            format!("no metal implementation for {}", self.name()).into(),
        ))
    }
}
353
impl Tensor {
    /// Applies a unary custom op in place.
    pub fn inplace_op1<C: InplaceOp1>(&self, c: &C) -> Result<()> {
        self.storage_mut().inplace_op1(self.layout(), c)
    }

    /// Applies a binary custom op in place (mutating the first tensor only).
    // NOTE(review): this takes mutable access to `self`'s storage while also reading
    // `rhs`'s storage — presumably `self` and `rhs` must not share storage; confirm
    // how `storage_mut`/`storage` behave when both refer to the same tensor.
    pub fn inplace_op2<C: InplaceOp2>(&self, rhs: &Self, c: &C) -> Result<()> {
        self.storage_mut()
            .inplace_op2(self.layout(), &rhs.storage(), rhs.layout(), c)
    }

    /// Applies a ternary custom op in place (mutating the first tensor only).
    // NOTE(review): same aliasing caveat as `inplace_op2` for `t2`/`t3` — confirm.
    pub fn inplace_op3<C: InplaceOp3>(&self, t2: &Self, t3: &Self, c: &C) -> Result<()> {
        self.storage_mut().inplace_op3(
            self.layout(),
            &t2.storage(),
            t2.layout(),
            &t3.storage(),
            t3.layout(),
            c,
        )
    }
}
378
/// An in-place unary op backed by a kernel compiled through `ug`.
#[cfg(feature = "ug")]
pub struct UgIOp1 {
    // Kernel name, also used in error messages via `InplaceOp1::name`.
    name: &'static str,
    // NOTE(review): if both the "cuda" and "metal" features were enabled at once,
    // two fields named `func` would be declared — presumably these features are
    // mutually exclusive for this type; confirm the crate's feature setup.
    #[cfg(feature = "cuda")]
    func: cudarc::driver::CudaFunction,
    #[cfg(feature = "metal")]
    func: candle_metal_kernels::metal::ComputePipeline,
}
387
#[cfg(feature = "ug")]
impl UgIOp1 {
    /// Compiles `kernel` for `device` and wraps the resulting function as an
    /// in-place op. Fails if `device` does not match the enabled gpu backend or
    /// if compilation fails.
    // NOTE(review): when both "cuda" and "metal" are enabled, the two cfg blocks
    // below would both expand and this function would not compile — presumably the
    // features are mutually exclusive; confirm.
    #[allow(unused)]
    #[cfg(all(not(target_arch = "wasm32"), not(target_os = "ios")))]
    pub fn new(
        name: &'static str,
        kernel: candle_ug::lang::ssa::Kernel,
        device: &crate::Device,
    ) -> Result<Self> {
        #[cfg(feature = "cuda")]
        {
            let device = device.as_cuda_device()?;
            let func = device.compile(name, kernel)?;
            Ok(Self {
                name,
                func: func.into_cuda_function(),
            })
        }
        #[cfg(feature = "metal")]
        {
            let device = device.as_metal_device()?;
            let func = device.compile(name, kernel)?;
            Ok(Self { name, func })
        }
        // Without a gpu backend the op can still be constructed, but running it
        // will fail in `cpu_fwd` below.
        #[cfg(not(any(feature = "cuda", feature = "metal")))]
        {
            Ok(Self { name })
        }
    }
}
418
#[cfg(feature = "ug")]
impl InplaceOp1 for UgIOp1 {
    fn name(&self) -> &'static str {
        self.name
    }

    /// ug kernels are compiled for gpu backends only, so the cpu path always errors.
    fn cpu_fwd(&self, _: &mut CpuStorage, _: &Layout) -> Result<()> {
        crate::bail!("ug ops are only supported on metal/cuda at the moment")
    }

    /// Dispatches the compiled metal pipeline over the storage, one thread per element.
    #[cfg(feature = "metal")]
    fn metal_fwd(&self, sto: &mut MetalStorage, layout: &Layout) -> Result<()> {
        use crate::backend::BackendStorage;
        use objc2_metal;

        let elem_count = layout.shape().elem_count();
        if sto.dtype() != crate::DType::F32 {
            // TODO: support more dtypes.
            crate::bail!("input is not a f32 tensor")
        }
        let device = sto.device();
        let encoder = device.command_encoder()?;
        encoder.set_compute_pipeline_state(&self.func);
        // Use 32-wide groups when the element count divides evenly, otherwise fall
        // back to one-element groups so every element is covered exactly once.
        let (g, b) = if elem_count.is_multiple_of(32) {
            (elem_count / 32, 32)
        } else {
            (elem_count, 1)
        };
        let grid_dims = objc2_metal::MTLSize {
            width: g,
            height: 1,
            depth: 1,
        };
        let group_dims = candle_metal_kernels::utils::get_block_dims(b, 1, 1);
        // The kernel's single parameter is the storage buffer at offset 0.
        candle_metal_kernels::utils::set_param(&encoder, 0, (sto.buffer(), 0usize));

        encoder.use_resource(sto.buffer(), objc2_metal::MTLResourceUsage::Write);
        encoder.dispatch_threads(grid_dims, group_dims);

        Ok(())
    }

    /// Launches the compiled cuda function over the storage, one thread per element.
    #[cfg(feature = "cuda")]
    fn cuda_fwd(&self, sto: &mut CudaStorage, layout: &Layout) -> Result<()> {
        use crate::cuda_backend::WrapErr;
        use cudarc::driver::PushKernelArg;

        let elem_count = layout.shape().elem_count();
        let stream = sto.device.cuda_stream();
        // TODO: support more dtypes.
        let sto = sto.as_cuda_slice::<f32>()?;
        let sto = match layout.contiguous_offsets() {
            None => crate::bail!("input has to be contiguous"),
            Some((o1, o2)) => sto.slice(o1..o2),
        };
        // Same block-size heuristic as the metal path above: 32-wide blocks when the
        // element count divides evenly, otherwise one-element blocks.
        let (g, b) = if elem_count.is_multiple_of(32) {
            (elem_count / 32, 32)
        } else {
            (elem_count, 1)
        };
        let cfg = cudarc::driver::LaunchConfig {
            grid_dim: (g as u32, 1, 1),
            block_dim: (b as u32, 1, 1),
            shared_mem_bytes: 0,
        };
        let mut builder = stream.launch_builder(&self.func);
        builder.arg(&sto);
        // SAFETY(review): launching the compiled kernel with a single slice argument
        // — presumably this matches the signature `device.compile` produced; confirm.
        unsafe { builder.launch(cfg) }.w()?;
        Ok(())
    }
}
489}