// candle_core/custom_op.rs

1use crate::op::{BackpropOp, Op};
2use crate::tensor::from_storage;
3use crate::{CpuStorage, CudaStorage, Layout, MetalStorage, Result, Shape, Tensor};
4use std::sync::Arc;
5
6/// Unary ops that can be defined in user-land.
7pub trait CustomOp1 {
8    // Box<dyn> does not support const yet, so use a function to get the name.
9    fn name(&self) -> &'static str;
10
11    /// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
12    /// offsets etc so the associated layout should be used to access it.
13    fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)>;
14
15    /// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
16    /// offsets etc so the associated layout should be used to access it.
17    fn cuda_fwd(&self, _storage: &CudaStorage, _layout: &Layout) -> Result<(CudaStorage, Shape)> {
18        Err(crate::Error::Cuda(
19            format!("no cuda implementation for {}", self.name()).into(),
20        ))
21    }
22
23    /// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
24    /// offsets etc so the associated layout should be used to access it.
25    fn metal_fwd(
26        &self,
27        _storage: &MetalStorage,
28        _layout: &Layout,
29    ) -> Result<(MetalStorage, Shape)> {
30        Err(crate::Error::Metal(
31            format!("no metal implementation for {}", self.name()).into(),
32        ))
33    }
34
35    /// This function takes as argument the argument `arg` used in the forward pass, the result
36    /// produced by the forward operation `res` and the gradient of the result `grad_res`.
37    /// The function should return the gradient of the argument.
38    fn bwd(&self, _arg: &Tensor, _res: &Tensor, _grad_res: &Tensor) -> Result<Option<Tensor>> {
39        Err(crate::Error::BackwardNotSupported { op: self.name() })
40    }
41}
42
43pub trait CustomOp2 {
44    fn name(&self) -> &'static str;
45
46    /// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
47    /// offsets etc so the associated layout should be used to access it.
48    fn cpu_fwd(
49        &self,
50        s1: &CpuStorage,
51        l1: &Layout,
52        s2: &CpuStorage,
53        l2: &Layout,
54    ) -> Result<(CpuStorage, Shape)>;
55
56    /// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
57    /// offsets etc so the associated layout should be used to access it.
58    fn cuda_fwd(
59        &self,
60        _: &CudaStorage,
61        _: &Layout,
62        _: &CudaStorage,
63        _: &Layout,
64    ) -> Result<(CudaStorage, Shape)> {
65        Err(crate::Error::Cuda(
66            format!("no cuda implementation for {}", self.name()).into(),
67        ))
68    }
69
70    /// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
71    /// offsets etc so the associated layout should be used to access it.
72    fn metal_fwd(
73        &self,
74        _: &MetalStorage,
75        _: &Layout,
76        _: &MetalStorage,
77        _: &Layout,
78    ) -> Result<(MetalStorage, Shape)> {
79        Err(crate::Error::Metal(
80            format!("no metal implementation for {}", self.name()).into(),
81        ))
82    }
83
84    fn bwd(
85        &self,
86        _arg1: &Tensor,
87        _arg2: &Tensor,
88        _res: &Tensor,
89        _grad_res: &Tensor,
90    ) -> Result<(Option<Tensor>, Option<Tensor>)> {
91        Err(crate::Error::BackwardNotSupported { op: self.name() })
92    }
93}
94
95pub trait CustomOp3 {
96    fn name(&self) -> &'static str;
97
98    /// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
99    /// offsets etc so the associated layout should be used to access it.
100    fn cpu_fwd(
101        &self,
102        s1: &CpuStorage,
103        l1: &Layout,
104        s2: &CpuStorage,
105        l2: &Layout,
106        s3: &CpuStorage,
107        l3: &Layout,
108    ) -> Result<(CpuStorage, Shape)>;
109
110    /// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
111    /// offsets etc so the associated layout should be used to access it.
112    fn cuda_fwd(
113        &self,
114        _: &CudaStorage,
115        _: &Layout,
116        _: &CudaStorage,
117        _: &Layout,
118        _: &CudaStorage,
119        _: &Layout,
120    ) -> Result<(CudaStorage, Shape)> {
121        Err(crate::Error::Cuda(
122            format!("no cuda implementation for {}", self.name()).into(),
123        ))
124    }
125
126    /// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
127    /// offsets etc so the associated layout should be used to access it.
128    fn metal_fwd(
129        &self,
130        _: &MetalStorage,
131        _: &Layout,
132        _: &MetalStorage,
133        _: &Layout,
134        _: &MetalStorage,
135        _: &Layout,
136    ) -> Result<(MetalStorage, Shape)> {
137        Err(crate::Error::Metal(
138            format!("no metal implementation for {}", self.name()).into(),
139        ))
140    }
141
142    fn bwd(
143        &self,
144        _arg1: &Tensor,
145        _arg2: &Tensor,
146        _arg3: &Tensor,
147        _res: &Tensor,
148        _grad_res: &Tensor,
149    ) -> Result<(Option<Tensor>, Option<Tensor>, Option<Tensor>)> {
150        Err(crate::Error::BackwardNotSupported { op: self.name() })
151    }
152}
153
154impl Tensor {
155    /// Applies a unary custom op without backward support
156    pub fn apply_op1_no_bwd<C: CustomOp1>(&self, c: &C) -> Result<Self> {
157        let (storage, shape) = self.storage().apply_op1(self.layout(), c)?;
158        Ok(from_storage(storage, shape, BackpropOp::none(), false))
159    }
160
161    /// Applies a binary custom op without backward support
162    pub fn apply_op2_no_bwd<C: CustomOp2>(&self, rhs: &Self, c: &C) -> Result<Self> {
163        let (storage, shape) =
164            self.storage()
165                .apply_op2(self.layout(), &rhs.storage(), rhs.layout(), c)?;
166        Ok(from_storage(storage, shape, BackpropOp::none(), false))
167    }
168
169    /// Applies a ternary custom op without backward support
170    pub fn apply_op3_no_bwd<C: CustomOp3>(&self, t2: &Self, t3: &Self, c: &C) -> Result<Self> {
171        let (storage, shape) = self.storage().apply_op3(
172            self.layout(),
173            &t2.storage(),
174            t2.layout(),
175            &t3.storage(),
176            t3.layout(),
177            c,
178        )?;
179        Ok(from_storage(storage, shape, BackpropOp::none(), false))
180    }
181
182    /// Applies a unary custom op.
183    pub fn apply_op1_arc(&self, c: Arc<Box<dyn CustomOp1 + Send + Sync>>) -> Result<Self> {
184        let (storage, shape) = self
185            .storage()
186            .apply_op1(self.layout(), c.as_ref().as_ref())?;
187        let op = BackpropOp::new1(self, |s| Op::CustomOp1(s, c.clone()));
188        Ok(from_storage(storage, shape, op, false))
189    }
190
191    pub fn apply_op1<C: 'static + CustomOp1 + Send + Sync>(&self, c: C) -> Result<Self> {
192        self.apply_op1_arc(Arc::new(Box::new(c)))
193    }
194
195    /// Applies a binary custom op.
196    pub fn apply_op2_arc(
197        &self,
198        rhs: &Self,
199        c: Arc<Box<dyn CustomOp2 + Send + Sync>>,
200    ) -> Result<Self> {
201        let (storage, shape) = self.storage().apply_op2(
202            self.layout(),
203            &rhs.storage(),
204            rhs.layout(),
205            c.as_ref().as_ref(),
206        )?;
207        let op = BackpropOp::new2(self, rhs, |t1, t2| Op::CustomOp2(t1, t2, c.clone()));
208        Ok(from_storage(storage, shape, op, false))
209    }
210
211    pub fn apply_op2<C: 'static + CustomOp2 + Send + Sync>(&self, r: &Self, c: C) -> Result<Self> {
212        self.apply_op2_arc(r, Arc::new(Box::new(c)))
213    }
214
215    /// Applies a ternary custom op.
216    pub fn apply_op3_arc(
217        &self,
218        t2: &Self,
219        t3: &Self,
220        c: Arc<Box<dyn CustomOp3 + Send + Sync>>,
221    ) -> Result<Self> {
222        let (storage, shape) = self.storage().apply_op3(
223            self.layout(),
224            &t2.storage(),
225            t2.layout(),
226            &t3.storage(),
227            t3.layout(),
228            c.as_ref().as_ref(),
229        )?;
230        let op = BackpropOp::new3(self, t2, t3, |t1, t2, t3| {
231            Op::CustomOp3(t1, t2, t3, c.clone())
232        });
233        Ok(from_storage(storage, shape, op, false))
234    }
235
236    pub fn apply_op3<C: 'static + CustomOp3 + Send + Sync>(
237        &self,
238        t2: &Self,
239        t3: &Self,
240        c: C,
241    ) -> Result<Self> {
242        self.apply_op3_arc(t2, t3, Arc::new(Box::new(c)))
243    }
244}
245
246// In place ops.
247
248/// Unary ops that can be defined in user-land.
249/// These ops work in place and as such back-prop is unsupported.
250pub trait InplaceOp1 {
251    // Box<dyn> does not support const yet, so use a function to get the name.
252    fn name(&self) -> &'static str;
253
254    /// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
255    /// offsets etc so the associated layout should be used to access it.
256    fn cpu_fwd(&self, storage: &mut CpuStorage, layout: &Layout) -> Result<()>;
257
258    /// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
259    /// offsets etc so the associated layout should be used to access it.
260    fn cuda_fwd(&self, _storage: &mut CudaStorage, _layout: &Layout) -> Result<()> {
261        Err(crate::Error::Cuda(
262            format!("no cuda implementation for {}", self.name()).into(),
263        ))
264    }
265
266    /// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
267    /// offsets etc so the associated layout should be used to access it.
268    fn metal_fwd(&self, _storage: &mut MetalStorage, _layout: &Layout) -> Result<()> {
269        Err(crate::Error::Metal(
270            format!("no metal implementation for {}", self.name()).into(),
271        ))
272    }
273}
274
275pub trait InplaceOp2 {
276    fn name(&self) -> &'static str;
277
278    /// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
279    /// offsets etc so the associated layout should be used to access it.
280    fn cpu_fwd(&self, s1: &mut CpuStorage, l1: &Layout, s2: &CpuStorage, l2: &Layout)
281        -> Result<()>;
282
283    /// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
284    /// offsets etc so the associated layout should be used to access it.
285    fn cuda_fwd(&self, _: &mut CudaStorage, _: &Layout, _: &CudaStorage, _: &Layout) -> Result<()> {
286        Err(crate::Error::Cuda(
287            format!("no cuda implementation for {}", self.name()).into(),
288        ))
289    }
290
291    /// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
292    /// offsets etc so the associated layout should be used to access it.
293    fn metal_fwd(
294        &self,
295        _: &mut MetalStorage,
296        _: &Layout,
297        _: &MetalStorage,
298        _: &Layout,
299    ) -> Result<()> {
300        Err(crate::Error::Metal(
301            format!("no metal implementation for {}", self.name()).into(),
302        ))
303    }
304}
305
306pub trait InplaceOp3 {
307    fn name(&self) -> &'static str;
308
309    /// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
310    /// offsets etc so the associated layout should be used to access it.
311    fn cpu_fwd(
312        &self,
313        s1: &mut CpuStorage,
314        l1: &Layout,
315        s2: &CpuStorage,
316        l2: &Layout,
317        s3: &CpuStorage,
318        l3: &Layout,
319    ) -> Result<()>;
320
321    /// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
322    /// offsets etc so the associated layout should be used to access it.
323    fn cuda_fwd(
324        &self,
325        _: &mut CudaStorage,
326        _: &Layout,
327        _: &CudaStorage,
328        _: &Layout,
329        _: &CudaStorage,
330        _: &Layout,
331    ) -> Result<()> {
332        Err(crate::Error::Cuda(
333            format!("no cuda implementation for {}", self.name()).into(),
334        ))
335    }
336
337    /// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
338    /// offsets etc so the associated layout should be used to access it.
339    fn metal_fwd(
340        &self,
341        _: &mut MetalStorage,
342        _: &Layout,
343        _: &MetalStorage,
344        _: &Layout,
345        _: &MetalStorage,
346        _: &Layout,
347    ) -> Result<()> {
348        Err(crate::Error::Metal(
349            format!("no metal implementation for {}", self.name()).into(),
350        ))
351    }
352}
353
impl Tensor {
    /// Applies a unary custom op in place.
    pub fn inplace_op1<C: InplaceOp1>(&self, c: &C) -> Result<()> {
        self.storage_mut().inplace_op1(self.layout(), c)
    }

    /// Applies a binary custom op in place (only the first tensor is modified).
    pub fn inplace_op2<C: InplaceOp2>(&self, rhs: &Self, c: &C) -> Result<()> {
        // `self`'s storage is borrowed mutably before `rhs`'s read access.
        // NOTE(review): presumably `storage_mut`/`storage` take locks, so passing the same tensor
        // as both receiver and `rhs` would conflict — confirm with the storage implementation.
        self.storage_mut()
            .inplace_op2(self.layout(), &rhs.storage(), rhs.layout(), c)
    }

    /// Applies a ternary custom op in place (only the first tensor is modified).
    pub fn inplace_op3<C: InplaceOp3>(&self, t2: &Self, t3: &Self, c: &C) -> Result<()> {
        // Same ordering as inplace_op2: mutable access to `self` first, then reads of `t2`/`t3`.
        self.storage_mut().inplace_op3(
            self.layout(),
            &t2.storage(),
            t2.layout(),
            &t3.storage(),
            t3.layout(),
            c,
        )
    }
}
378
/// A unary in-place op whose kernel comes from the `ug` compiler and runs on a gpu backend.
pub struct UgIOp1 {
    // Name reported through `InplaceOp1::name`.
    name: &'static str,
    // Compiled cuda kernel handle.
    #[cfg(feature = "cuda")]
    func: cudarc::driver::CudaFunction,
    // Compiled metal compute pipeline.
    // NOTE(review): enabling both `cuda` and `metal` features would declare `func` twice —
    // presumably these features are treated as mutually exclusive here; confirm.
    #[cfg(feature = "metal")]
    func: metal::ComputePipelineState,
}
386
impl UgIOp1 {
    /// Compiles `kernel` for `device` and wraps the resulting gpu function under `name`.
    ///
    /// With neither the `cuda` nor the `metal` feature enabled, only the name is stored and the
    /// `kernel`/`device` arguments are unused (hence `#[allow(unused)]`).
    #[allow(unused)]
    #[cfg(not(target_arch = "wasm32"))]
    pub fn new(
        name: &'static str,
        kernel: ug::lang::ssa::Kernel,
        device: &crate::Device,
    ) -> Result<Self> {
        #[cfg(feature = "cuda")]
        {
            // Errors if `device` is not actually a cuda device.
            let device = device.as_cuda_device()?;
            let func = device.compile(name, kernel)?;
            Ok(Self { name, func })
        }
        #[cfg(feature = "metal")]
        {
            // Errors if `device` is not actually a metal device.
            let device = device.as_metal_device()?;
            let func = device.compile(name, kernel)?;
            Ok(Self { name, func })
        }
        #[cfg(not(any(feature = "cuda", feature = "metal")))]
        {
            // No gpu backend compiled in: keep the name so error messages stay informative.
            Ok(Self { name })
        }
    }
}
413
414impl InplaceOp1 for UgIOp1 {
415    fn name(&self) -> &'static str {
416        self.name
417    }
418
419    fn cpu_fwd(&self, _: &mut CpuStorage, _: &Layout) -> Result<()> {
420        crate::bail!("ug ops are only supported on metal/cuda at the moment")
421    }
422
423    #[cfg(feature = "metal")]
424    fn metal_fwd(&self, sto: &mut MetalStorage, layout: &Layout) -> Result<()> {
425        use crate::backend::BackendStorage;
426        use candle_metal_kernels::utils::EncoderProvider;
427
428        let elem_count = layout.shape().elem_count();
429        if sto.dtype() != crate::DType::F32 {
430            // TODO: support more dtypes.
431            crate::bail!("input is not a f32 tensor")
432        }
433        let device = sto.device();
434        println!("here");
435        let command_buffer = device.command_buffer()?;
436        let command_buffer = &command_buffer;
437        let encoder = command_buffer.encoder();
438        let encoder = encoder.as_ref();
439        encoder.set_compute_pipeline_state(&self.func);
440        let (g, b) = if elem_count % 32 == 0 {
441            (elem_count / 32, 32)
442        } else {
443            (elem_count, 1)
444        };
445        let grid_dims = metal::MTLSize {
446            width: g as u64,
447            height: 1,
448            depth: 1,
449        };
450        let group_dims = candle_metal_kernels::utils::get_block_dims(b as u64, 1, 1);
451        candle_metal_kernels::utils::set_param(encoder, 0, (sto.buffer(), 0usize));
452
453        encoder.use_resource(sto.buffer(), metal::MTLResourceUsage::Write);
454        encoder.dispatch_threads(grid_dims, group_dims);
455
456        Ok(())
457    }
458
459    #[cfg(feature = "cuda")]
460    fn cuda_fwd(&self, sto: &mut CudaStorage, layout: &Layout) -> Result<()> {
461        use crate::cuda_backend::WrapErr;
462        use cudarc::driver::LaunchAsync;
463
464        let elem_count = layout.shape().elem_count();
465        // TODO: support more dtypes.
466        let sto = sto.as_cuda_slice::<f32>()?;
467        let sto = match layout.contiguous_offsets() {
468            None => crate::bail!("input has to be contiguous"),
469            Some((o1, o2)) => sto.slice(o1..o2),
470        };
471        let params = (&sto,);
472        let (g, b) = if elem_count % 32 == 0 {
473            (elem_count / 32, 32)
474        } else {
475            (elem_count, 1)
476        };
477        let cfg = cudarc::driver::LaunchConfig {
478            grid_dim: (g as u32, 1, 1),
479            block_dim: (b as u32, 1, 1),
480            shared_mem_bytes: 0,
481        };
482        unsafe { self.func.clone().launch(cfg, params) }.w()?;
483        Ok(())
484    }
485}