1use crate::op::{BackpropOp, Op};
2use crate::tensor::from_storage;
3use crate::{CpuStorage, CudaStorage, Layout, MetalStorage, Result, Shape, Tensor};
4use std::sync::Arc;
5
/// Unary custom operation: a user-defined op applied to a single tensor.
///
/// Implementors must provide a CPU forward pass; the CUDA and Metal forward
/// passes and the backward pass are optional and report an error by default.
pub trait CustomOp1 {
    /// Name of the op, used in error messages (and useful for debugging).
    fn name(&self) -> &'static str;

    /// CPU forward pass: returns the output storage together with its shape.
    /// The input storage may be strided; `layout` describes how to traverse it.
    fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)>;

    /// CUDA forward pass; defaults to an error for ops without a CUDA kernel.
    fn cuda_fwd(&self, _storage: &CudaStorage, _layout: &Layout) -> Result<(CudaStorage, Shape)> {
        Err(crate::Error::Cuda(
            format!("no cuda implementation for {}", self.name()).into(),
        ))
    }

    /// Metal forward pass; defaults to an error for ops without a Metal kernel.
    fn metal_fwd(
        &self,
        _storage: &MetalStorage,
        _layout: &Layout,
    ) -> Result<(MetalStorage, Shape)> {
        Err(crate::Error::Metal(
            format!("no metal implementation for {}", self.name()).into(),
        ))
    }

    /// Backward pass: given the forward argument `_arg`, the forward result
    /// `_res`, and the gradient of the result `_grad_res`, return the gradient
    /// with respect to the argument (`None` when there is none).
    /// Defaults to reporting that backprop is unsupported for this op.
    fn bwd(&self, _arg: &Tensor, _res: &Tensor, _grad_res: &Tensor) -> Result<Option<Tensor>> {
        Err(crate::Error::BackwardNotSupported { op: self.name() })
    }
}
42
/// Binary custom operation: a user-defined op applied to two tensors.
///
/// Same contract as [`CustomOp1`], extended to two storage/layout pairs.
pub trait CustomOp2 {
    /// Name of the op, used in error messages (and useful for debugging).
    fn name(&self) -> &'static str;

    /// CPU forward pass over the two (storage, layout) input pairs.
    fn cpu_fwd(
        &self,
        s1: &CpuStorage,
        l1: &Layout,
        s2: &CpuStorage,
        l2: &Layout,
    ) -> Result<(CpuStorage, Shape)>;

    /// CUDA forward pass; defaults to an error for ops without a CUDA kernel.
    fn cuda_fwd(
        &self,
        _: &CudaStorage,
        _: &Layout,
        _: &CudaStorage,
        _: &Layout,
    ) -> Result<(CudaStorage, Shape)> {
        Err(crate::Error::Cuda(
            format!("no cuda implementation for {}", self.name()).into(),
        ))
    }

    /// Metal forward pass; defaults to an error for ops without a Metal kernel.
    fn metal_fwd(
        &self,
        _: &MetalStorage,
        _: &Layout,
        _: &MetalStorage,
        _: &Layout,
    ) -> Result<(MetalStorage, Shape)> {
        Err(crate::Error::Metal(
            format!("no metal implementation for {}", self.name()).into(),
        ))
    }

    /// Backward pass: returns the gradients with respect to the two forward
    /// arguments (`None` entries mean no gradient for that argument).
    /// Defaults to reporting that backprop is unsupported for this op.
    fn bwd(
        &self,
        _arg1: &Tensor,
        _arg2: &Tensor,
        _res: &Tensor,
        _grad_res: &Tensor,
    ) -> Result<(Option<Tensor>, Option<Tensor>)> {
        Err(crate::Error::BackwardNotSupported { op: self.name() })
    }
}
94
/// Ternary custom operation: a user-defined op applied to three tensors.
///
/// Same contract as [`CustomOp1`], extended to three storage/layout pairs.
pub trait CustomOp3 {
    /// Name of the op, used in error messages (and useful for debugging).
    fn name(&self) -> &'static str;

    /// CPU forward pass over the three (storage, layout) input pairs.
    fn cpu_fwd(
        &self,
        s1: &CpuStorage,
        l1: &Layout,
        s2: &CpuStorage,
        l2: &Layout,
        s3: &CpuStorage,
        l3: &Layout,
    ) -> Result<(CpuStorage, Shape)>;

    /// CUDA forward pass; defaults to an error for ops without a CUDA kernel.
    fn cuda_fwd(
        &self,
        _: &CudaStorage,
        _: &Layout,
        _: &CudaStorage,
        _: &Layout,
        _: &CudaStorage,
        _: &Layout,
    ) -> Result<(CudaStorage, Shape)> {
        Err(crate::Error::Cuda(
            format!("no cuda implementation for {}", self.name()).into(),
        ))
    }

    /// Metal forward pass; defaults to an error for ops without a Metal kernel.
    fn metal_fwd(
        &self,
        _: &MetalStorage,
        _: &Layout,
        _: &MetalStorage,
        _: &Layout,
        _: &MetalStorage,
        _: &Layout,
    ) -> Result<(MetalStorage, Shape)> {
        Err(crate::Error::Metal(
            format!("no metal implementation for {}", self.name()).into(),
        ))
    }

    /// Backward pass: returns the gradients with respect to the three forward
    /// arguments (`None` entries mean no gradient for that argument).
    /// Defaults to reporting that backprop is unsupported for this op.
    fn bwd(
        &self,
        _arg1: &Tensor,
        _arg2: &Tensor,
        _arg3: &Tensor,
        _res: &Tensor,
        _grad_res: &Tensor,
    ) -> Result<(Option<Tensor>, Option<Tensor>, Option<Tensor>)> {
        Err(crate::Error::BackwardNotSupported { op: self.name() })
    }
}
153
impl Tensor {
    /// Applies a unary custom op without backward support.
    ///
    /// The result carries `BackpropOp::none()`, so gradients do not flow
    /// through this operation.
    pub fn apply_op1_no_bwd<C: CustomOp1>(&self, c: &C) -> Result<Self> {
        let (storage, shape) = self.storage().apply_op1(self.layout(), c)?;
        Ok(from_storage(storage, shape, BackpropOp::none(), false))
    }

    /// Applies a binary custom op without backward support.
    pub fn apply_op2_no_bwd<C: CustomOp2>(&self, rhs: &Self, c: &C) -> Result<Self> {
        let (storage, shape) =
            self.storage()
                .apply_op2(self.layout(), &rhs.storage(), rhs.layout(), c)?;
        Ok(from_storage(storage, shape, BackpropOp::none(), false))
    }

    /// Applies a ternary custom op without backward support.
    pub fn apply_op3_no_bwd<C: CustomOp3>(&self, t2: &Self, t3: &Self, c: &C) -> Result<Self> {
        let (storage, shape) = self.storage().apply_op3(
            self.layout(),
            &t2.storage(),
            t2.layout(),
            &t3.storage(),
            t3.layout(),
            c,
        )?;
        Ok(from_storage(storage, shape, BackpropOp::none(), false))
    }

    /// Applies a unary custom op and records it in the backprop graph, so
    /// `CustomOp1::bwd` will be invoked during backpropagation.
    ///
    /// The op is shared via `Arc` because it is kept alive inside the
    /// recorded `Op::CustomOp1` node.
    pub fn apply_op1_arc(&self, c: Arc<Box<dyn CustomOp1 + Send + Sync>>) -> Result<Self> {
        // `as_ref().as_ref()` peels the Arc, then the Box, yielding a plain
        // trait-object reference for the storage-level dispatch.
        let (storage, shape) = self
            .storage()
            .apply_op1(self.layout(), c.as_ref().as_ref())?;
        let op = BackpropOp::new1(self, |s| Op::CustomOp1(s, c.clone()));
        Ok(from_storage(storage, shape, op, false))
    }

    /// Convenience wrapper around [`Self::apply_op1_arc`] taking the op by value.
    pub fn apply_op1<C: 'static + CustomOp1 + Send + Sync>(&self, c: C) -> Result<Self> {
        self.apply_op1_arc(Arc::new(Box::new(c)))
    }

    /// Applies a binary custom op and records it in the backprop graph.
    pub fn apply_op2_arc(
        &self,
        rhs: &Self,
        c: Arc<Box<dyn CustomOp2 + Send + Sync>>,
    ) -> Result<Self> {
        let (storage, shape) = self.storage().apply_op2(
            self.layout(),
            &rhs.storage(),
            rhs.layout(),
            c.as_ref().as_ref(),
        )?;
        let op = BackpropOp::new2(self, rhs, |t1, t2| Op::CustomOp2(t1, t2, c.clone()));
        Ok(from_storage(storage, shape, op, false))
    }

    /// Convenience wrapper around [`Self::apply_op2_arc`] taking the op by value.
    pub fn apply_op2<C: 'static + CustomOp2 + Send + Sync>(&self, r: &Self, c: C) -> Result<Self> {
        self.apply_op2_arc(r, Arc::new(Box::new(c)))
    }

    /// Applies a ternary custom op and records it in the backprop graph.
    pub fn apply_op3_arc(
        &self,
        t2: &Self,
        t3: &Self,
        c: Arc<Box<dyn CustomOp3 + Send + Sync>>,
    ) -> Result<Self> {
        let (storage, shape) = self.storage().apply_op3(
            self.layout(),
            &t2.storage(),
            t2.layout(),
            &t3.storage(),
            t3.layout(),
            c.as_ref().as_ref(),
        )?;
        let op = BackpropOp::new3(self, t2, t3, |t1, t2, t3| {
            Op::CustomOp3(t1, t2, t3, c.clone())
        });
        Ok(from_storage(storage, shape, op, false))
    }

    /// Convenience wrapper around [`Self::apply_op3_arc`] taking the op by value.
    pub fn apply_op3<C: 'static + CustomOp3 + Send + Sync>(
        &self,
        t2: &Self,
        t3: &Self,
        c: C,
    ) -> Result<Self> {
        self.apply_op3_arc(t2, t3, Arc::new(Box::new(c)))
    }
}
245
/// Unary in-place operation: mutates the tensor's storage directly instead of
/// producing a new storage/shape pair.
///
/// Implementors must provide a CPU forward pass; CUDA and Metal forward
/// passes are optional and report an error by default.
pub trait InplaceOp1 {
    /// Name of the op, used in error messages (and useful for debugging).
    fn name(&self) -> &'static str;

    /// CPU forward pass, mutating `storage` in place.
    fn cpu_fwd(&self, storage: &mut CpuStorage, layout: &Layout) -> Result<()>;

    /// CUDA forward pass; defaults to an error for ops without a CUDA kernel.
    fn cuda_fwd(&self, _storage: &mut CudaStorage, _layout: &Layout) -> Result<()> {
        Err(crate::Error::Cuda(
            format!("no cuda implementation for {}", self.name()).into(),
        ))
    }

    /// Metal forward pass; defaults to an error for ops without a Metal kernel.
    fn metal_fwd(&self, _storage: &mut MetalStorage, _layout: &Layout) -> Result<()> {
        Err(crate::Error::Metal(
            format!("no metal implementation for {}", self.name()).into(),
        ))
    }
}
274
/// Binary in-place operation: mutates the first tensor's storage, reading the
/// second tensor as an immutable input.
pub trait InplaceOp2 {
    /// Name of the op, used in error messages (and useful for debugging).
    fn name(&self) -> &'static str;

    /// CPU forward pass: `s1` is mutated in place, `s2` is read-only.
    fn cpu_fwd(&self, s1: &mut CpuStorage, l1: &Layout, s2: &CpuStorage, l2: &Layout)
        -> Result<()>;

    /// CUDA forward pass; defaults to an error for ops without a CUDA kernel.
    fn cuda_fwd(&self, _: &mut CudaStorage, _: &Layout, _: &CudaStorage, _: &Layout) -> Result<()> {
        Err(crate::Error::Cuda(
            format!("no cuda implementation for {}", self.name()).into(),
        ))
    }

    /// Metal forward pass; defaults to an error for ops without a Metal kernel.
    fn metal_fwd(
        &self,
        _: &mut MetalStorage,
        _: &Layout,
        _: &MetalStorage,
        _: &Layout,
    ) -> Result<()> {
        Err(crate::Error::Metal(
            format!("no metal implementation for {}", self.name()).into(),
        ))
    }
}
305
/// Ternary in-place operation: mutates the first tensor's storage, reading the
/// second and third tensors as immutable inputs.
pub trait InplaceOp3 {
    /// Name of the op, used in error messages (and useful for debugging).
    fn name(&self) -> &'static str;

    /// CPU forward pass: `s1` is mutated in place, `s2`/`s3` are read-only.
    fn cpu_fwd(
        &self,
        s1: &mut CpuStorage,
        l1: &Layout,
        s2: &CpuStorage,
        l2: &Layout,
        s3: &CpuStorage,
        l3: &Layout,
    ) -> Result<()>;

    /// CUDA forward pass; defaults to an error for ops without a CUDA kernel.
    fn cuda_fwd(
        &self,
        _: &mut CudaStorage,
        _: &Layout,
        _: &CudaStorage,
        _: &Layout,
        _: &CudaStorage,
        _: &Layout,
    ) -> Result<()> {
        Err(crate::Error::Cuda(
            format!("no cuda implementation for {}", self.name()).into(),
        ))
    }

    /// Metal forward pass; defaults to an error for ops without a Metal kernel.
    fn metal_fwd(
        &self,
        _: &mut MetalStorage,
        _: &Layout,
        _: &MetalStorage,
        _: &Layout,
        _: &MetalStorage,
        _: &Layout,
    ) -> Result<()> {
        Err(crate::Error::Metal(
            format!("no metal implementation for {}", self.name()).into(),
        ))
    }
}
353
impl Tensor {
    /// Applies a unary in-place op to this tensor's storage.
    pub fn inplace_op1<C: InplaceOp1>(&self, c: &C) -> Result<()> {
        self.storage_mut().inplace_op1(self.layout(), c)
    }

    /// Applies a binary in-place op: `self` is mutated, `rhs` is read-only.
    ///
    /// NOTE(review): this takes a mutable borrow of `self`'s storage while
    /// sharing `rhs`'s — behavior when `self` and `rhs` alias the same
    /// storage depends on `storage_mut()`/`storage()`; confirm callers never
    /// pass the same tensor for both.
    pub fn inplace_op2<C: InplaceOp2>(&self, rhs: &Self, c: &C) -> Result<()> {
        self.storage_mut()
            .inplace_op2(self.layout(), &rhs.storage(), rhs.layout(), c)
    }

    /// Applies a ternary in-place op: `self` is mutated, `t2`/`t3` are
    /// read-only. The same aliasing caveat as `inplace_op2` applies.
    pub fn inplace_op3<C: InplaceOp3>(&self, t2: &Self, t3: &Self, c: &C) -> Result<()> {
        self.storage_mut().inplace_op3(
            self.layout(),
            &t2.storage(),
            t2.layout(),
            &t3.storage(),
            t3.layout(),
            c,
        )
    }
}
378
/// A compiled "ug" kernel wrapped as an in-place unary op.
///
/// Only the backend selected by cargo features carries a compiled function;
/// with neither `cuda` nor `metal` enabled, only the name is stored.
/// NOTE(review): enabling `cuda` and `metal` together would declare `func`
/// twice — presumably these features are mutually exclusive; confirm.
#[cfg(feature = "ug")]
pub struct UgIOp1 {
    name: &'static str,
    // CUDA-compiled kernel entry point.
    #[cfg(feature = "cuda")]
    func: cudarc::driver::CudaFunction,
    // Metal compute pipeline for the compiled kernel.
    #[cfg(feature = "metal")]
    func: candle_metal_kernels::metal::ComputePipeline,
}
387
#[cfg(feature = "ug")]
impl UgIOp1 {
    /// Compiles the given ug SSA kernel for `device` and wraps it as an op.
    ///
    /// The backend is chosen at compile time via the `cuda`/`metal` features;
    /// with neither enabled, only the kernel name is retained.
    #[allow(unused)]
    #[cfg(all(not(target_arch = "wasm32"), not(target_os = "ios")))]
    pub fn new(
        name: &'static str,
        kernel: candle_ug::lang::ssa::Kernel,
        device: &crate::Device,
    ) -> Result<Self> {
        #[cfg(feature = "cuda")]
        {
            let device = device.as_cuda_device()?;
            let func = device.compile(name, kernel)?;
            Ok(Self {
                name,
                func: func.into_cuda_function(),
            })
        }
        #[cfg(feature = "metal")]
        {
            let device = device.as_metal_device()?;
            let func = device.compile(name, kernel)?;
            Ok(Self { name, func })
        }
        #[cfg(not(any(feature = "cuda", feature = "metal")))]
        {
            // No GPU backend compiled in: keep the name so error messages work.
            Ok(Self { name })
        }
    }
}
418
#[cfg(feature = "ug")]
impl InplaceOp1 for UgIOp1 {
    fn name(&self) -> &'static str {
        self.name
    }

    /// ug kernels are device-only: there is no CPU lowering.
    fn cpu_fwd(&self, _: &mut CpuStorage, _: &Layout) -> Result<()> {
        crate::bail!("ug ops are only supported on metal/cuda at the moment")
    }

    /// Runs the compiled Metal pipeline over `sto` in place.
    ///
    /// Only f32 storage is accepted. NOTE(review): unlike the CUDA path below,
    /// this does not check that the layout is contiguous — confirm whether ug
    /// Metal kernels require contiguity.
    #[cfg(feature = "metal")]
    fn metal_fwd(&self, sto: &mut MetalStorage, layout: &Layout) -> Result<()> {
        use crate::backend::BackendStorage;
        use objc2_metal;

        let elem_count = layout.shape().elem_count();
        if sto.dtype() != crate::DType::F32 {
            crate::bail!("input is not a f32 tensor")
        }
        let device = sto.device();
        let encoder = device.command_encoder()?;
        encoder.set_compute_pipeline_state(&self.func);
        // One thread per element: use 32-wide groups when the count divides
        // evenly, otherwise fall back to width-1 groups.
        let (g, b) = if elem_count.is_multiple_of(32) {
            (elem_count / 32, 32)
        } else {
            (elem_count, 1)
        };
        let grid_dims = objc2_metal::MTLSize {
            width: g,
            height: 1,
            depth: 1,
        };
        let group_dims = candle_metal_kernels::utils::get_block_dims(b, 1, 1);
        candle_metal_kernels::utils::set_param(&encoder, 0, (sto.buffer(), 0usize));

        encoder.use_resource(sto.buffer(), objc2_metal::MTLResourceUsage::Write);
        encoder.dispatch_threads(grid_dims, group_dims);

        Ok(())
    }

    /// Launches the compiled CUDA function over `sto` in place.
    ///
    /// Requires a contiguous f32 slice; the kernel receives the slice as its
    /// single argument.
    #[cfg(feature = "cuda")]
    fn cuda_fwd(&self, sto: &mut CudaStorage, layout: &Layout) -> Result<()> {
        use crate::cuda_backend::WrapErr;
        use cudarc::driver::PushKernelArg;

        let elem_count = layout.shape().elem_count();
        let stream = sto.device.cuda_stream();
        let sto = sto.as_cuda_slice::<f32>()?;
        let sto = match layout.contiguous_offsets() {
            None => crate::bail!("input has to be contiguous"),
            Some((o1, o2)) => sto.slice(o1..o2),
        };
        // Same launch geometry as the Metal path: 32-thread blocks when the
        // element count divides evenly, otherwise one-thread blocks.
        // (Previously written as `elem_count % 32 == 0`; use `is_multiple_of`
        // for consistency with `metal_fwd` above.)
        let (g, b) = if elem_count.is_multiple_of(32) {
            (elem_count / 32, 32)
        } else {
            (elem_count, 1)
        };
        let cfg = cudarc::driver::LaunchConfig {
            grid_dim: (g as u32, 1, 1),
            block_dim: (b as u32, 1, 1),
            shared_mem_bytes: 0,
        };
        let mut builder = stream.launch_builder(&self.func);
        builder.arg(&sto);
        unsafe { builder.launch(cfg) }.w()?;
        Ok(())
    }
}
489}