use crate::op::{BackpropOp, Op};
use crate::tensor::from_storage;
use crate::{CpuStorage, CudaStorage, Layout, MetalStorage, Result, Shape, Tensor};
use std::sync::Arc;

/// Unary ops that can be defined in user-land.
pub trait CustomOp1 {
    // Box<dyn> does not support const yet, so use a function to get the name.
    fn name(&self) -> &'static str;

    /// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
    /// offsets etc so the associated layout should be used to access it.
    fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)>;

    /// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
    /// offsets etc so the associated layout should be used to access it.
    fn cuda_fwd(&self, _storage: &CudaStorage, _layout: &Layout) -> Result<(CudaStorage, Shape)> {
        Err(crate::Error::Cuda(
            format!("no cuda implementation for {}", self.name()).into(),
        ))
    }

    /// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary
    /// strides, offsets etc so the associated layout should be used to access it.
    fn metal_fwd(
        &self,
        _storage: &MetalStorage,
        _layout: &Layout,
    ) -> Result<(MetalStorage, Shape)> {
        Err(crate::Error::Metal(
            format!("no metal implementation for {}", self.name()).into(),
        ))
    }

    /// This function takes as arguments the argument `arg` used in the forward pass, the result
    /// produced by the forward operation `res` and the gradient of the result `grad_res`.
    /// The function should return the gradient of the argument.
    fn bwd(&self, _arg: &Tensor, _res: &Tensor, _grad_res: &Tensor) -> Result<Option<Tensor>> {
        Err(crate::Error::BackwardNotSupported { op: self.name() })
    }
}

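// A hedged, minimal sketch of a user-defined op: a CPU-only ReLU. The
// `ExampleRelu` type is illustrative rather than part of this crate's API; it
// assumes a contiguous f32 input and relies on the default (erroring) cuda and
// metal implementations.
#[allow(dead_code)]
struct ExampleRelu;

impl CustomOp1 for ExampleRelu {
    fn name(&self) -> &'static str {
        "example-relu"
    }

    fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)> {
        // Only the contiguous case is handled here; a complete op would walk
        // the layout strides instead of bailing out.
        let (o1, o2) = match layout.contiguous_offsets() {
            None => crate::bail!("example-relu: input has to be contiguous"),
            Some(offsets) => offsets,
        };
        let data: Vec<f32> = storage.as_slice::<f32>()?[o1..o2]
            .iter()
            .map(|&v| v.max(0f32))
            .collect();
        Ok((CpuStorage::F32(data), layout.shape().clone()))
    }

    fn bwd(&self, arg: &Tensor, _res: &Tensor, grad_res: &Tensor) -> Result<Option<Tensor>> {
        // d/dx relu(x) is 1 for x > 0 and 0 elsewhere, so mask the incoming
        // gradient with the positive entries of the argument.
        let mask = arg.gt(&arg.zeros_like()?)?.to_dtype(arg.dtype())?;
        Ok(Some((grad_res * mask)?))
    }
}
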
/// Binary ops that can be defined in user-land.
pub trait CustomOp2 {
    fn name(&self) -> &'static str;

    /// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
    /// offsets etc so the associated layout should be used to access it.
    fn cpu_fwd(
        &self,
        s1: &CpuStorage,
        l1: &Layout,
        s2: &CpuStorage,
        l2: &Layout,
    ) -> Result<(CpuStorage, Shape)>;

    fn cuda_fwd(
        &self,
        _: &CudaStorage,
        _: &Layout,
        _: &CudaStorage,
        _: &Layout,
    ) -> Result<(CudaStorage, Shape)> {
        Err(crate::Error::Cuda(
            format!("no cuda implementation for {}", self.name()).into(),
        ))
    }

    fn metal_fwd(
        &self,
        _: &MetalStorage,
        _: &Layout,
        _: &MetalStorage,
        _: &Layout,
    ) -> Result<(MetalStorage, Shape)> {
        Err(crate::Error::Metal(
            format!("no metal implementation for {}", self.name()).into(),
        ))
    }

    fn bwd(
        &self,
        _arg1: &Tensor,
        _arg2: &Tensor,
        _res: &Tensor,
        _grad_res: &Tensor,
    ) -> Result<(Option<Tensor>, Option<Tensor>)> {
        Err(crate::Error::BackwardNotSupported { op: self.name() })
    }
}

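// A matching sketch for the binary case: a CPU-only elementwise addition.
// `ExampleAdd` is illustrative only; it assumes two contiguous f32 inputs of
// identical shape and leaves `bwd` unimplemented.
#[allow(dead_code)]
struct ExampleAdd;

impl CustomOp2 for ExampleAdd {
    fn name(&self) -> &'static str {
        "example-add"
    }

    fn cpu_fwd(
        &self,
        s1: &CpuStorage,
        l1: &Layout,
        s2: &CpuStorage,
        l2: &Layout,
    ) -> Result<(CpuStorage, Shape)> {
        if l1.shape() != l2.shape() {
            crate::bail!("example-add: shape mismatch {:?} {:?}", l1.shape(), l2.shape())
        }
        let (lhs, rhs) = match (l1.contiguous_offsets(), l2.contiguous_offsets()) {
            (Some((a1, a2)), Some((b1, b2))) => (
                &s1.as_slice::<f32>()?[a1..a2],
                &s2.as_slice::<f32>()?[b1..b2],
            ),
            _ => crate::bail!("example-add: inputs have to be contiguous"),
        };
        let data: Vec<f32> = lhs.iter().zip(rhs.iter()).map(|(a, b)| a + b).collect();
        Ok((CpuStorage::F32(data), l1.shape().clone()))
    }
}
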
/// Ternary ops that can be defined in user-land.
pub trait CustomOp3 {
    fn name(&self) -> &'static str;

    /// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
    /// offsets etc so the associated layout should be used to access it.
    fn cpu_fwd(
        &self,
        s1: &CpuStorage,
        l1: &Layout,
        s2: &CpuStorage,
        l2: &Layout,
        s3: &CpuStorage,
        l3: &Layout,
    ) -> Result<(CpuStorage, Shape)>;

    fn cuda_fwd(
        &self,
        _: &CudaStorage,
        _: &Layout,
        _: &CudaStorage,
        _: &Layout,
        _: &CudaStorage,
        _: &Layout,
    ) -> Result<(CudaStorage, Shape)> {
        Err(crate::Error::Cuda(
            format!("no cuda implementation for {}", self.name()).into(),
        ))
    }

    fn metal_fwd(
        &self,
        _: &MetalStorage,
        _: &Layout,
        _: &MetalStorage,
        _: &Layout,
        _: &MetalStorage,
        _: &Layout,
    ) -> Result<(MetalStorage, Shape)> {
        Err(crate::Error::Metal(
            format!("no metal implementation for {}", self.name()).into(),
        ))
    }

    fn bwd(
        &self,
        _arg1: &Tensor,
        _arg2: &Tensor,
        _arg3: &Tensor,
        _res: &Tensor,
        _grad_res: &Tensor,
    ) -> Result<(Option<Tensor>, Option<Tensor>, Option<Tensor>)> {
        Err(crate::Error::BackwardNotSupported { op: self.name() })
    }
}

impl Tensor {
    /// Applies a unary custom op without backward support.
    pub fn apply_op1_no_bwd<C: CustomOp1>(&self, c: &C) -> Result<Self> {
        let (storage, shape) = self.storage().apply_op1(self.layout(), c)?;
        Ok(from_storage(storage, shape, BackpropOp::none(), false))
    }

    /// Applies a binary custom op without backward support.
    pub fn apply_op2_no_bwd<C: CustomOp2>(&self, rhs: &Self, c: &C) -> Result<Self> {
        let (storage, shape) =
            self.storage()
                .apply_op2(self.layout(), &rhs.storage(), rhs.layout(), c)?;
        Ok(from_storage(storage, shape, BackpropOp::none(), false))
    }

    /// Applies a ternary custom op without backward support.
    pub fn apply_op3_no_bwd<C: CustomOp3>(&self, t2: &Self, t3: &Self, c: &C) -> Result<Self> {
        let (storage, shape) = self.storage().apply_op3(
            self.layout(),
            &t2.storage(),
            t2.layout(),
            &t3.storage(),
            t3.layout(),
            c,
        )?;
        Ok(from_storage(storage, shape, BackpropOp::none(), false))
    }

    /// Applies a unary custom op.
    pub fn apply_op1_arc(&self, c: Arc<Box<dyn CustomOp1 + Send + Sync>>) -> Result<Self> {
        let (storage, shape) = self
            .storage()
            .apply_op1(self.layout(), c.as_ref().as_ref())?;
        let op = BackpropOp::new1(self, |s| Op::CustomOp1(s, c.clone()));
        Ok(from_storage(storage, shape, op, false))
    }

    pub fn apply_op1<C: 'static + CustomOp1 + Send + Sync>(&self, c: C) -> Result<Self> {
        self.apply_op1_arc(Arc::new(Box::new(c)))
    }

    /// Applies a binary custom op.
    pub fn apply_op2_arc(
        &self,
        rhs: &Self,
        c: Arc<Box<dyn CustomOp2 + Send + Sync>>,
    ) -> Result<Self> {
        let (storage, shape) = self.storage().apply_op2(
            self.layout(),
            &rhs.storage(),
            rhs.layout(),
            c.as_ref().as_ref(),
        )?;
        let op = BackpropOp::new2(self, rhs, |t1, t2| Op::CustomOp2(t1, t2, c.clone()));
        Ok(from_storage(storage, shape, op, false))
    }

    pub fn apply_op2<C: 'static + CustomOp2 + Send + Sync>(&self, r: &Self, c: C) -> Result<Self> {
        self.apply_op2_arc(r, Arc::new(Box::new(c)))
    }

    /// Applies a ternary custom op.
    pub fn apply_op3_arc(
        &self,
        t2: &Self,
        t3: &Self,
        c: Arc<Box<dyn CustomOp3 + Send + Sync>>,
    ) -> Result<Self> {
        let (storage, shape) = self.storage().apply_op3(
            self.layout(),
            &t2.storage(),
            t2.layout(),
            &t3.storage(),
            t3.layout(),
            c.as_ref().as_ref(),
        )?;
        let op = BackpropOp::new3(self, t2, t3, |t1, t2, t3| {
            Op::CustomOp3(t1, t2, t3, c.clone())
        });
        Ok(from_storage(storage, shape, op, false))
    }

    pub fn apply_op3<C: 'static + CustomOp3 + Send + Sync>(
        &self,
        t2: &Self,
        t3: &Self,
        c: C,
    ) -> Result<Self> {
        self.apply_op3_arc(t2, t3, Arc::new(Box::new(c)))
    }
}

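// A hedged usage sketch tying the pieces together with the illustrative
// `ExampleRelu` defined above: `apply_op1` records the op for backprop (its
// `bwd` is called during `backward`), while `apply_op1_no_bwd` detaches the
// result from the autograd graph.
#[allow(dead_code)]
fn apply_example_relu(t: &Tensor) -> Result<Tensor> {
    t.apply_op1(ExampleRelu)
}
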
/// Unary ops that can be defined in user-land.
/// These ops work in place and as such back-prop is unsupported.
pub trait InplaceOp1 {
    fn name(&self) -> &'static str;

    /// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
    /// offsets etc so the associated layout should be used to access it.
    fn cpu_fwd(&self, storage: &mut CpuStorage, layout: &Layout) -> Result<()>;

    fn cuda_fwd(&self, _storage: &mut CudaStorage, _layout: &Layout) -> Result<()> {
        Err(crate::Error::Cuda(
            format!("no cuda implementation for {}", self.name()).into(),
        ))
    }

    fn metal_fwd(&self, _storage: &mut MetalStorage, _layout: &Layout) -> Result<()> {
        Err(crate::Error::Metal(
            format!("no metal implementation for {}", self.name()).into(),
        ))
    }
}

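// An illustrative in-place op: scale a contiguous f32 tensor by a constant.
// `ExampleScale` is a sketch, not part of the crate API.
#[allow(dead_code)]
struct ExampleScale(f32);

impl InplaceOp1 for ExampleScale {
    fn name(&self) -> &'static str {
        "example-scale"
    }

    fn cpu_fwd(&self, storage: &mut CpuStorage, layout: &Layout) -> Result<()> {
        let (o1, o2) = match layout.contiguous_offsets() {
            None => crate::bail!("example-scale: input has to be contiguous"),
            Some(offsets) => offsets,
        };
        match storage {
            CpuStorage::F32(data) => {
                for v in data[o1..o2].iter_mut() {
                    *v *= self.0
                }
            }
            _ => crate::bail!("example-scale: input is not a f32 tensor"),
        }
        Ok(())
    }
}
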
pub trait InplaceOp2 {
    fn name(&self) -> &'static str;

    fn cpu_fwd(&self, s1: &mut CpuStorage, l1: &Layout, s2: &CpuStorage, l2: &Layout)
        -> Result<()>;

    fn cuda_fwd(&self, _: &mut CudaStorage, _: &Layout, _: &CudaStorage, _: &Layout) -> Result<()> {
        Err(crate::Error::Cuda(
            format!("no cuda implementation for {}", self.name()).into(),
        ))
    }

    fn metal_fwd(
        &self,
        _: &mut MetalStorage,
        _: &Layout,
        _: &MetalStorage,
        _: &Layout,
    ) -> Result<()> {
        Err(crate::Error::Metal(
            format!("no metal implementation for {}", self.name()).into(),
        ))
    }
}

pub trait InplaceOp3 {
    fn name(&self) -> &'static str;

    fn cpu_fwd(
        &self,
        s1: &mut CpuStorage,
        l1: &Layout,
        s2: &CpuStorage,
        l2: &Layout,
        s3: &CpuStorage,
        l3: &Layout,
    ) -> Result<()>;

    fn cuda_fwd(
        &self,
        _: &mut CudaStorage,
        _: &Layout,
        _: &CudaStorage,
        _: &Layout,
        _: &CudaStorage,
        _: &Layout,
    ) -> Result<()> {
        Err(crate::Error::Cuda(
            format!("no cuda implementation for {}", self.name()).into(),
        ))
    }

    fn metal_fwd(
        &self,
        _: &mut MetalStorage,
        _: &Layout,
        _: &MetalStorage,
        _: &Layout,
        _: &MetalStorage,
        _: &Layout,
    ) -> Result<()> {
        Err(crate::Error::Metal(
            format!("no metal implementation for {}", self.name()).into(),
        ))
    }
}

impl Tensor {
    /// Applies a unary custom op in place.
    pub fn inplace_op1<C: InplaceOp1>(&self, c: &C) -> Result<()> {
        self.storage_mut().inplace_op1(self.layout(), c)
    }

    /// Applies a binary custom op in place (on the first tensor).
    pub fn inplace_op2<C: InplaceOp2>(&self, rhs: &Self, c: &C) -> Result<()> {
        self.storage_mut()
            .inplace_op2(self.layout(), &rhs.storage(), rhs.layout(), c)
    }

    /// Applies a ternary custom op in place (on the first tensor).
    pub fn inplace_op3<C: InplaceOp3>(&self, t2: &Self, t3: &Self, c: &C) -> Result<()> {
        self.storage_mut().inplace_op3(
            self.layout(),
            &t2.storage(),
            t2.layout(),
            &t3.storage(),
            t3.layout(),
            c,
        )
    }
}

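// Usage sketch for the in-place path, using the illustrative `ExampleScale`
// op from above: this mutates the tensor's storage directly and records
// nothing on the autograd tape.
#[allow(dead_code)]
fn double_in_place(t: &Tensor) -> Result<()> {
    t.inplace_op1(&ExampleScale(2.0))
}
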
pub struct UgIOp1 {
    name: &'static str,
    #[cfg(feature = "cuda")]
    func: cudarc::driver::CudaFunction,
    #[cfg(feature = "metal")]
    func: metal::ComputePipelineState,
}

impl UgIOp1 {
    #[allow(unused)]
    #[cfg(not(target_arch = "wasm32"))]
    pub fn new(
        name: &'static str,
        kernel: ug::lang::ssa::Kernel,
        device: &crate::Device,
    ) -> Result<Self> {
        #[cfg(feature = "cuda")]
        {
            let device = device.as_cuda_device()?;
            let func = device.compile(name, kernel)?;
            Ok(Self {
                name,
                func: func.into_cuda_function(),
            })
        }
        #[cfg(feature = "metal")]
        {
            let device = device.as_metal_device()?;
            let func = device.compile(name, kernel)?;
            Ok(Self { name, func })
        }
        #[cfg(not(any(feature = "cuda", feature = "metal")))]
        {
            Ok(Self { name })
        }
    }
}

impl InplaceOp1 for UgIOp1 {
    fn name(&self) -> &'static str {
        self.name
    }

    fn cpu_fwd(&self, _: &mut CpuStorage, _: &Layout) -> Result<()> {
        crate::bail!("ug ops are only supported on metal/cuda at the moment")
    }

    #[cfg(feature = "metal")]
    fn metal_fwd(&self, sto: &mut MetalStorage, layout: &Layout) -> Result<()> {
        use crate::backend::BackendStorage;
        use candle_metal_kernels::utils::EncoderProvider;

        let elem_count = layout.shape().elem_count();
        if sto.dtype() != crate::DType::F32 {
            crate::bail!("input is not a f32 tensor")
        }
        let device = sto.device();
        let command_buffer = device.command_buffer()?;
        let command_buffer = &command_buffer;
        let encoder = command_buffer.encoder();
        let encoder = encoder.as_ref();
        encoder.set_compute_pipeline_state(&self.func);
        // Launch one thread per element, in groups of 32 when the element
        // count allows it.
        let (g, b) = if elem_count % 32 == 0 {
            (elem_count / 32, 32)
        } else {
            (elem_count, 1)
        };
        let grid_dims = metal::MTLSize {
            width: g as u64,
            height: 1,
            depth: 1,
        };
        let group_dims = candle_metal_kernels::utils::get_block_dims(b as u64, 1, 1);
        candle_metal_kernels::utils::set_param(encoder, 0, (sto.buffer(), 0usize));

        encoder.use_resource(sto.buffer(), metal::MTLResourceUsage::Write);
        encoder.dispatch_threads(grid_dims, group_dims);

        Ok(())
    }

    #[cfg(feature = "cuda")]
    fn cuda_fwd(&self, sto: &mut CudaStorage, layout: &Layout) -> Result<()> {
        use crate::cuda_backend::WrapErr;
        use cudarc::driver::PushKernelArg;

        let elem_count = layout.shape().elem_count();
        let stream = sto.device.cuda_stream();
        // The kernel operates on a flat f32 view of the storage and requires
        // contiguous data.
        let sto = sto.as_cuda_slice::<f32>()?;
        let sto = match layout.contiguous_offsets() {
            None => crate::bail!("input has to be contiguous"),
            Some((o1, o2)) => sto.slice(o1..o2),
        };
        // Launch one thread per element, in groups of 32 when the element
        // count allows it.
        let (g, b) = if elem_count % 32 == 0 {
            (elem_count / 32, 32)
        } else {
            (elem_count, 1)
        };
        let cfg = cudarc::driver::LaunchConfig {
            grid_dim: (g as u32, 1, 1),
            block_dim: (b as u32, 1, 1),
            shared_mem_bytes: 0,
        };
        let mut builder = stream.launch_builder(&self.func);
        builder.arg(&sto);
        unsafe { builder.launch(cfg) }.w()?;
        Ok(())
    }
}