use crate::op::{BackpropOp, Op};
use crate::tensor::from_storage;
use crate::{CpuStorage, CudaStorage, Layout, MetalStorage, Result, Shape, Tensor};
use std::sync::Arc;

/// Unary ops that can be defined in user-land.
pub trait CustomOp1 {
    // Box<dyn> does not support const yet, so use a function to get the name.
    fn name(&self) -> &'static str;

    /// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
    /// offsets etc so the associated layout should be used to access it.
    fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)>;

    /// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
    /// offsets etc so the associated layout should be used to access it.
    fn cuda_fwd(&self, _storage: &CudaStorage, _layout: &Layout) -> Result<(CudaStorage, Shape)> {
        Err(crate::Error::Cuda(
            format!("no cuda implementation for {}", self.name()).into(),
        ))
    }

    /// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary
    /// strides, offsets etc so the associated layout should be used to access it.
    fn metal_fwd(
        &self,
        _storage: &MetalStorage,
        _layout: &Layout,
    ) -> Result<(MetalStorage, Shape)> {
        Err(crate::Error::Metal(
            format!("no metal implementation for {}", self.name()).into(),
        ))
    }

    /// This function takes as argument the argument `arg` used in the forward pass, the result
    /// produced by the forward operation `res` and the gradient of the result `grad_res`.
    /// The function should return the gradient of the argument.
    fn bwd(&self, _arg: &Tensor, _res: &Tensor, _grad_res: &Tensor) -> Result<Option<Tensor>> {
        Err(crate::Error::BackwardNotSupported { op: self.name() })
    }
}
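
// An illustrative sketch, not part of the original file: a CPU-only `CustomOp1`
// computing an elementwise square. The name `Sqr` is hypothetical, only the
// contiguous f32 case is handled, cuda/metal fall back to the default errors,
// and backprop stays unsupported since `bwd` is not provided.
#[allow(dead_code)]
struct Sqr;

impl CustomOp1 for Sqr {
    fn name(&self) -> &'static str {
        "sqr"
    }

    fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)> {
        // Use the layout to locate the data; this sketch only accepts contiguous inputs.
        let (o1, o2) = match layout.contiguous_offsets() {
            Some(offsets) => offsets,
            None => crate::bail!("sqr requires a contiguous input"),
        };
        let vs = storage.as_slice::<f32>()?;
        let out: Vec<f32> = vs[o1..o2].iter().map(|&v| v * v).collect();
        Ok((CpuStorage::F32(out), layout.shape().clone()))
    }
}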

/// Binary ops that can be defined in user-land.
pub trait CustomOp2 {
    fn name(&self) -> &'static str;

    /// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
    /// offsets etc so the associated layout should be used to access it.
    fn cpu_fwd(
        &self,
        s1: &CpuStorage,
        l1: &Layout,
        s2: &CpuStorage,
        l2: &Layout,
    ) -> Result<(CpuStorage, Shape)>;

    /// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
    /// offsets etc so the associated layout should be used to access it.
    fn cuda_fwd(
        &self,
        _: &CudaStorage,
        _: &Layout,
        _: &CudaStorage,
        _: &Layout,
    ) -> Result<(CudaStorage, Shape)> {
        Err(crate::Error::Cuda(
            format!("no cuda implementation for {}", self.name()).into(),
        ))
    }

    /// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary
    /// strides, offsets etc so the associated layout should be used to access it.
    fn metal_fwd(
        &self,
        _: &MetalStorage,
        _: &Layout,
        _: &MetalStorage,
        _: &Layout,
    ) -> Result<(MetalStorage, Shape)> {
        Err(crate::Error::Metal(
            format!("no metal implementation for {}", self.name()).into(),
        ))
    }

    /// This function takes as argument the two arguments used in the forward pass, the result
    /// produced by the forward operation `res` and the gradient of the result `grad_res`.
    /// The function should return the gradients of both arguments.
    fn bwd(
        &self,
        _arg1: &Tensor,
        _arg2: &Tensor,
        _res: &Tensor,
        _grad_res: &Tensor,
    ) -> Result<(Option<Tensor>, Option<Tensor>)> {
        Err(crate::Error::BackwardNotSupported { op: self.name() })
    }
}
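
// An illustrative sketch, not part of the original file: a binary op with a
// backward pass. `ElemMul` is a hypothetical elementwise product over contiguous
// f32 inputs of the same shape; `bwd` returns the product-rule gradients so the
// op can participate in backprop via `apply_op2` below.
#[allow(dead_code)]
struct ElemMul;

impl CustomOp2 for ElemMul {
    fn name(&self) -> &'static str {
        "elem-mul"
    }

    fn cpu_fwd(
        &self,
        s1: &CpuStorage,
        l1: &Layout,
        s2: &CpuStorage,
        l2: &Layout,
    ) -> Result<(CpuStorage, Shape)> {
        // Only the contiguous case is handled in this sketch.
        let ((a1, a2), (b1, b2)) = match (l1.contiguous_offsets(), l2.contiguous_offsets()) {
            (Some(o1), Some(o2)) => (o1, o2),
            _ => crate::bail!("elem-mul requires contiguous inputs"),
        };
        let v1 = &s1.as_slice::<f32>()?[a1..a2];
        let v2 = &s2.as_slice::<f32>()?[b1..b2];
        let out: Vec<f32> = v1.iter().zip(v2.iter()).map(|(x, y)| x * y).collect();
        Ok((CpuStorage::F32(out), l1.shape().clone()))
    }

    fn bwd(
        &self,
        arg1: &Tensor,
        arg2: &Tensor,
        _res: &Tensor,
        grad_res: &Tensor,
    ) -> Result<(Option<Tensor>, Option<Tensor>)> {
        // Product rule: d(a*b)/da = b, d(a*b)/db = a.
        Ok((Some((grad_res * arg2)?), Some((grad_res * arg1)?)))
    }
}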

/// Ternary ops that can be defined in user-land.
pub trait CustomOp3 {
    fn name(&self) -> &'static str;

    /// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
    /// offsets etc so the associated layout should be used to access it.
    fn cpu_fwd(
        &self,
        s1: &CpuStorage,
        l1: &Layout,
        s2: &CpuStorage,
        l2: &Layout,
        s3: &CpuStorage,
        l3: &Layout,
    ) -> Result<(CpuStorage, Shape)>;

    /// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
    /// offsets etc so the associated layout should be used to access it.
    fn cuda_fwd(
        &self,
        _: &CudaStorage,
        _: &Layout,
        _: &CudaStorage,
        _: &Layout,
        _: &CudaStorage,
        _: &Layout,
    ) -> Result<(CudaStorage, Shape)> {
        Err(crate::Error::Cuda(
            format!("no cuda implementation for {}", self.name()).into(),
        ))
    }

    /// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary
    /// strides, offsets etc so the associated layout should be used to access it.
    fn metal_fwd(
        &self,
        _: &MetalStorage,
        _: &Layout,
        _: &MetalStorage,
        _: &Layout,
        _: &MetalStorage,
        _: &Layout,
    ) -> Result<(MetalStorage, Shape)> {
        Err(crate::Error::Metal(
            format!("no metal implementation for {}", self.name()).into(),
        ))
    }

    /// This function takes as argument the three arguments used in the forward pass, the result
    /// produced by the forward operation `res` and the gradient of the result `grad_res`.
    /// The function should return the gradients of the three arguments.
    fn bwd(
        &self,
        _arg1: &Tensor,
        _arg2: &Tensor,
        _arg3: &Tensor,
        _res: &Tensor,
        _grad_res: &Tensor,
    ) -> Result<(Option<Tensor>, Option<Tensor>, Option<Tensor>)> {
        Err(crate::Error::BackwardNotSupported { op: self.name() })
    }
}

impl Tensor {
    /// Applies a unary custom op without backward support.
    pub fn apply_op1_no_bwd<C: CustomOp1>(&self, c: &C) -> Result<Self> {
        let (storage, shape) = self.storage().apply_op1(self.layout(), c)?;
        Ok(from_storage(storage, shape, BackpropOp::none(), false))
    }

    /// Applies a binary custom op without backward support.
    pub fn apply_op2_no_bwd<C: CustomOp2>(&self, rhs: &Self, c: &C) -> Result<Self> {
        let (storage, shape) =
            self.storage()
                .apply_op2(self.layout(), &rhs.storage(), rhs.layout(), c)?;
        Ok(from_storage(storage, shape, BackpropOp::none(), false))
    }

    /// Applies a ternary custom op without backward support.
    pub fn apply_op3_no_bwd<C: CustomOp3>(&self, t2: &Self, t3: &Self, c: &C) -> Result<Self> {
        let (storage, shape) = self.storage().apply_op3(
            self.layout(),
            &t2.storage(),
            t2.layout(),
            &t3.storage(),
            t3.layout(),
            c,
        )?;
        Ok(from_storage(storage, shape, BackpropOp::none(), false))
    }

    /// Applies a unary custom op, the op being shared as an `Arc` so that it can also be stored
    /// in the backprop graph.
    pub fn apply_op1_arc(&self, c: Arc<Box<dyn CustomOp1 + Send + Sync>>) -> Result<Self> {
        let (storage, shape) = self
            .storage()
            .apply_op1(self.layout(), c.as_ref().as_ref())?;
        let op = BackpropOp::new1(self, |s| Op::CustomOp1(s, c.clone()));
        Ok(from_storage(storage, shape, op, false))
    }

    /// Applies a unary custom op.
    pub fn apply_op1<C: 'static + CustomOp1 + Send + Sync>(&self, c: C) -> Result<Self> {
        self.apply_op1_arc(Arc::new(Box::new(c)))
    }

    /// Applies a binary custom op, the op being shared as an `Arc` so that it can also be stored
    /// in the backprop graph.
    pub fn apply_op2_arc(
        &self,
        rhs: &Self,
        c: Arc<Box<dyn CustomOp2 + Send + Sync>>,
    ) -> Result<Self> {
        let (storage, shape) = self.storage().apply_op2(
            self.layout(),
            &rhs.storage(),
            rhs.layout(),
            c.as_ref().as_ref(),
        )?;
        let op = BackpropOp::new2(self, rhs, |t1, t2| Op::CustomOp2(t1, t2, c.clone()));
        Ok(from_storage(storage, shape, op, false))
    }

    /// Applies a binary custom op.
    pub fn apply_op2<C: 'static + CustomOp2 + Send + Sync>(&self, r: &Self, c: C) -> Result<Self> {
        self.apply_op2_arc(r, Arc::new(Box::new(c)))
    }

    /// Applies a ternary custom op, the op being shared as an `Arc` so that it can also be stored
    /// in the backprop graph.
    pub fn apply_op3_arc(
        &self,
        t2: &Self,
        t3: &Self,
        c: Arc<Box<dyn CustomOp3 + Send + Sync>>,
    ) -> Result<Self> {
        let (storage, shape) = self.storage().apply_op3(
            self.layout(),
            &t2.storage(),
            t2.layout(),
            &t3.storage(),
            t3.layout(),
            c.as_ref().as_ref(),
        )?;
        let op = BackpropOp::new3(self, t2, t3, |t1, t2, t3| {
            Op::CustomOp3(t1, t2, t3, c.clone())
        });
        Ok(from_storage(storage, shape, op, false))
    }

    /// Applies a ternary custom op.
    pub fn apply_op3<C: 'static + CustomOp3 + Send + Sync>(
        &self,
        t2: &Self,
        t3: &Self,
        c: C,
    ) -> Result<Self> {
        self.apply_op3_arc(t2, t3, Arc::new(Box::new(c)))
    }
}
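
// Example usage, not part of the original file, assuming the hypothetical `Sqr`
// and `ElemMul` ops sketched above:
//
//     let t = Tensor::new(&[1f32, 2., 3.], &Device::Cpu)?;
//     let sq = t.apply_op1_no_bwd(&Sqr)?; // forward only, no backprop node
//     let prod = t.apply_op2(&sq, ElemMul)?; // recorded in the graph, uses ElemMul::bwd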

/// Unary ops that can be defined in user-land.
/// These ops work in place and as such back-prop is unsupported.
pub trait InplaceOp1 {
    // Box<dyn> does not support const yet, so use a function to get the name.
    fn name(&self) -> &'static str;

    /// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
    /// offsets etc so the associated layout should be used to access it.
    fn cpu_fwd(&self, storage: &mut CpuStorage, layout: &Layout) -> Result<()>;

    /// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
    /// offsets etc so the associated layout should be used to access it.
    fn cuda_fwd(&self, _storage: &mut CudaStorage, _layout: &Layout) -> Result<()> {
        Err(crate::Error::Cuda(
            format!("no cuda implementation for {}", self.name()).into(),
        ))
    }

    /// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary
    /// strides, offsets etc so the associated layout should be used to access it.
    fn metal_fwd(&self, _storage: &mut MetalStorage, _layout: &Layout) -> Result<()> {
        Err(crate::Error::Metal(
            format!("no metal implementation for {}", self.name()).into(),
        ))
    }
}
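
// An illustrative sketch, not part of the original file: an in-place op scaling
// a tensor by a constant factor. `Scale` is a hypothetical name; only contiguous
// f32 storage is handled and the data is mutated directly.
#[allow(dead_code)]
struct Scale(f32);

impl InplaceOp1 for Scale {
    fn name(&self) -> &'static str {
        "scale"
    }

    fn cpu_fwd(&self, storage: &mut CpuStorage, layout: &Layout) -> Result<()> {
        let (o1, o2) = match layout.contiguous_offsets() {
            Some(offsets) => offsets,
            None => crate::bail!("scale requires a contiguous input"),
        };
        match storage {
            CpuStorage::F32(vs) => {
                for v in vs[o1..o2].iter_mut() {
                    *v *= self.0
                }
            }
            _ => crate::bail!("scale only supports f32"),
        }
        Ok(())
    }
}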

/// Binary ops that can be defined in user-land.
/// These ops work in place and as such back-prop is unsupported.
pub trait InplaceOp2 {
    fn name(&self) -> &'static str;

    /// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
    /// offsets etc so the associated layout should be used to access it.
    fn cpu_fwd(&self, s1: &mut CpuStorage, l1: &Layout, s2: &CpuStorage, l2: &Layout)
        -> Result<()>;

    /// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
    /// offsets etc so the associated layout should be used to access it.
    fn cuda_fwd(&self, _: &mut CudaStorage, _: &Layout, _: &CudaStorage, _: &Layout) -> Result<()> {
        Err(crate::Error::Cuda(
            format!("no cuda implementation for {}", self.name()).into(),
        ))
    }

    /// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary
    /// strides, offsets etc so the associated layout should be used to access it.
    fn metal_fwd(
        &self,
        _: &mut MetalStorage,
        _: &Layout,
        _: &MetalStorage,
        _: &Layout,
    ) -> Result<()> {
        Err(crate::Error::Metal(
            format!("no metal implementation for {}", self.name()).into(),
        ))
    }
}

/// Ternary ops that can be defined in user-land.
/// These ops work in place and as such back-prop is unsupported.
pub trait InplaceOp3 {
    fn name(&self) -> &'static str;

    /// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
    /// offsets etc so the associated layout should be used to access it.
    fn cpu_fwd(
        &self,
        s1: &mut CpuStorage,
        l1: &Layout,
        s2: &CpuStorage,
        l2: &Layout,
        s3: &CpuStorage,
        l3: &Layout,
    ) -> Result<()>;

    /// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
    /// offsets etc so the associated layout should be used to access it.
    fn cuda_fwd(
        &self,
        _: &mut CudaStorage,
        _: &Layout,
        _: &CudaStorage,
        _: &Layout,
        _: &CudaStorage,
        _: &Layout,
    ) -> Result<()> {
        Err(crate::Error::Cuda(
            format!("no cuda implementation for {}", self.name()).into(),
        ))
    }

    /// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary
    /// strides, offsets etc so the associated layout should be used to access it.
    fn metal_fwd(
        &self,
        _: &mut MetalStorage,
        _: &Layout,
        _: &MetalStorage,
        _: &Layout,
        _: &MetalStorage,
        _: &Layout,
    ) -> Result<()> {
        Err(crate::Error::Metal(
            format!("no metal implementation for {}", self.name()).into(),
        ))
    }
}

impl Tensor {
    /// Applies a unary custom op in place.
    pub fn inplace_op1<C: InplaceOp1>(&self, c: &C) -> Result<()> {
        self.storage_mut().inplace_op1(self.layout(), c)
    }

    /// Applies a binary custom op in place, only the storage of the first argument is modified.
    pub fn inplace_op2<C: InplaceOp2>(&self, rhs: &Self, c: &C) -> Result<()> {
        self.storage_mut()
            .inplace_op2(self.layout(), &rhs.storage(), rhs.layout(), c)
    }

    /// Applies a ternary custom op in place, only the storage of the first argument is modified.
    pub fn inplace_op3<C: InplaceOp3>(&self, t2: &Self, t3: &Self, c: &C) -> Result<()> {
        self.storage_mut().inplace_op3(
            self.layout(),
            &t2.storage(),
            t2.layout(),
            &t3.storage(),
            t3.layout(),
            c,
        )
    }
}
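
// Example usage, not part of the original file, assuming the hypothetical
// `Scale` op sketched above. In-place ops mutate the storage directly and do
// not create a backprop node:
//
//     let t = Tensor::new(&[1f32, 2., 3.], &Device::Cpu)?;
//     t.inplace_op1(&Scale(2.0))?; // t now holds [2., 4., 6.]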

/// An in-place unary op wrapping a kernel compiled with the `ug` language, runnable on cuda or
/// metal devices.
pub struct UgIOp1 {
    name: &'static str,
    // The `cuda` and `metal` features are assumed to be mutually exclusive here, otherwise the
    // two `func` fields would clash.
    #[cfg(feature = "cuda")]
    func: cudarc::driver::CudaFunction,
    #[cfg(feature = "metal")]
    func: metal::ComputePipelineState,
}

impl UgIOp1 {
    #[allow(unused)]
    #[cfg(not(target_arch = "wasm32"))]
    pub fn new(
        name: &'static str,
        kernel: ug::lang::ssa::Kernel,
        device: &crate::Device,
    ) -> Result<Self> {
        #[cfg(feature = "cuda")]
        {
            let device = device.as_cuda_device()?;
            let func = device.compile(name, kernel)?;
            Ok(Self { name, func })
        }
        #[cfg(feature = "metal")]
        {
            let device = device.as_metal_device()?;
            let func = device.compile(name, kernel)?;
            Ok(Self { name, func })
        }
        #[cfg(not(any(feature = "cuda", feature = "metal")))]
        {
            Ok(Self { name })
        }
    }
}

impl InplaceOp1 for UgIOp1 {
    fn name(&self) -> &'static str {
        self.name
    }

    fn cpu_fwd(&self, _: &mut CpuStorage, _: &Layout) -> Result<()> {
        crate::bail!("ug ops are only supported on metal/cuda at the moment")
    }

    #[cfg(feature = "metal")]
    fn metal_fwd(&self, sto: &mut MetalStorage, layout: &Layout) -> Result<()> {
        use crate::backend::BackendStorage;
        use candle_metal_kernels::utils::EncoderProvider;

        let elem_count = layout.shape().elem_count();
        // Only f32 is supported for now.
        if sto.dtype() != crate::DType::F32 {
            crate::bail!("input is not a f32 tensor")
        }
        let device = sto.device();
        let command_buffer = device.command_buffer()?;
        let command_buffer = &command_buffer;
        let encoder = command_buffer.encoder();
        let encoder = encoder.as_ref();
        encoder.set_compute_pipeline_state(&self.func);
        // Split the work into groups of 32 threads when the element count is a multiple of 32.
        let (g, b) = if elem_count % 32 == 0 {
            (elem_count / 32, 32)
        } else {
            (elem_count, 1)
        };
        let grid_dims = metal::MTLSize {
            width: g as u64,
            height: 1,
            depth: 1,
        };
        let group_dims = candle_metal_kernels::utils::get_block_dims(b as u64, 1, 1);
        candle_metal_kernels::utils::set_param(encoder, 0, (sto.buffer(), 0usize));

        encoder.use_resource(sto.buffer(), metal::MTLResourceUsage::Write);
        encoder.dispatch_threads(grid_dims, group_dims);

        Ok(())
    }

    #[cfg(feature = "cuda")]
    fn cuda_fwd(&self, sto: &mut CudaStorage, layout: &Layout) -> Result<()> {
        use crate::cuda_backend::WrapErr;
        use cudarc::driver::LaunchAsync;

        let elem_count = layout.shape().elem_count();
        // Only f32 is supported for now.
        let sto = sto.as_cuda_slice::<f32>()?;
        let sto = match layout.contiguous_offsets() {
            None => crate::bail!("input has to be contiguous"),
            Some((o1, o2)) => sto.slice(o1..o2),
        };
        let params = (&sto,);
        // Launch one thread per element, using blocks of 32 threads when the element count is a
        // multiple of 32.
        let (g, b) = if elem_count % 32 == 0 {
            (elem_count / 32, 32)
        } else {
            (elem_count, 1)
        };
        let cfg = cudarc::driver::LaunchConfig {
            grid_dim: (g as u32, 1, 1),
            block_dim: (b as u32, 1, 1),
            shared_mem_bytes: 0,
        };
        unsafe { self.func.clone().launch(cfg, params) }.w()?;
        Ok(())
    }
}