nncombinator/cuda/kernel/
lossfunction.rs

1//! Implementation of various loss functions
2use std::marker::PhantomData;
3use libc::{c_int, c_void};
4use crate::cuda::{AsKernelPtr, CudaConstPtr, CudaTensor1dPtr, CudaTensor1dPtrView, CudaVec, CudaVecView, DataTypeInfo, Kernel, KernelArgs};
5use crate::ope::UnitValue;
6
// Entry points of the CUDA loss-function kernels.
//
// Within this file the symbols are never called directly; only their addresses are
// taken (for the `Kernel::FUNC_PTR` constants below).
//
// The declarations return nothing: a C `void` return corresponds to Rust's `()`,
// while `c_void` is only meaningful behind a pointer.
//
// NOTE(review): the argument structs in this file hand `usize` values over for the
// last two parameters, whereas they are declared as `c_int` here — confirm the
// parameter types against the CUDA sources.
// NOTE(review): the argument structs pass `expected` first and `actual` second,
// while the first two parameters are named `r` and `t` here — confirm the intended
// order against the CUDA sources.
extern "C" {
    fn loss_linear_batch_mse_derive_float(r: *const f32, t: *const f32, output: *mut f32, nlen: c_int, batch_size: c_int);
    fn loss_linear_batch_mse_derive_double(r: *const f64, t: *const f64, output: *mut f64, nlen: c_int, batch_size: c_int);
    fn loss_linear_batch_cross_entropy_derive_float(r: *const f32, t: *const f32, output: *mut f32, nlen: c_int, batch_size: c_int);
    fn loss_linear_batch_cross_entropy_derive_double(r: *const f64, t: *const f64, output: *mut f64, nlen: c_int, batch_size: c_int);
    fn loss_linear_batch_cross_entropy_multiclass_derive_float(r: *const f32, t: *const f32, output: *mut f32, nlen: c_int, batch_size: c_int);
    fn loss_linear_batch_cross_entropy_multiclass_derive_double(r: *const f64, t: *const f64, output: *mut f64, nlen: c_int, batch_size: c_int);
}
15/// Define a list to be passed to the cuda kernel function during mini-batch execution as the argument of mse.
16pub struct LinearBatchMseArgs<'a,T,const N:usize> where T: DataTypeInfo + UnitValue<T> {
17    /// expected value
18    expected: CudaConstPtr<'a,CudaVecView<'a,T,CudaTensor1dPtr<T,N>>>,
19    /// actual value
20    actual: CudaConstPtr<'a,CudaVecView<'a,T,CudaTensor1dPtr<T,N>>>,
21    pub output: CudaVec<T,CudaTensor1dPtr<T,N>>,
22    out_len: usize,
23    batch_len: usize,
24}
25/// Create an instance of an object representing the list of arguments to
26/// compute the loss function mse during mini-batch execution.
27impl<'a,T,const N:usize> LinearBatchMseArgs<'a,T,N> where T: DataTypeInfo + UnitValue<T> {
28    /// Create a LinearBatchMseArgs instance
29    /// # Arguments
30    /// * `expected` - Expected Value
31    /// * `actual` - Actual Value
32    /// * `out_len` - Number of scalar values in output
33    /// * `batch_len` - batch count
34    pub fn new(t:&'a CudaVecView<'a,T,CudaTensor1dPtr<T,N>>,r:&'a CudaVecView<'a,T,CudaTensor1dPtr<T,N>>,
35               output: CudaVec<T,CudaTensor1dPtr<T,N>>,out_len:usize,batch_len:usize) -> LinearBatchMseArgs<'a,T,N> {
36        LinearBatchMseArgs {
37            expected: CudaConstPtr::new(t),
38            actual: CudaConstPtr::new(r),
39            output: output,
40            out_len: out_len,
41            batch_len: batch_len
42        }
43    }
44}
45impl<'a,T,const N:usize> KernelArgs for LinearBatchMseArgs<'a,T,N> where T: DataTypeInfo + UnitValue<T> {
46    fn as_vec(&mut self) -> Vec<&mut dyn AsKernelPtr> {
47        vec![
48            &mut self.expected,
49            &mut self.actual,
50            &mut self.output,
51            &mut self.out_len,
52            &mut self.batch_len
53        ]
54    }
55}
56pub struct LinearBatchMse<'a,T,const N:usize> where T: DataTypeInfo + UnitValue<T> {
57    t:PhantomData<T>,
58    n:PhantomData<[();N]>,
59    l:PhantomData<&'a ()>
60}
61impl<'a,T,const N:usize> LinearBatchMse<'a,T,N> where T: DataTypeInfo + UnitValue<T> {
62    /// Create a LinearBatchMse instance
63    pub fn new() -> LinearBatchMse<'a,T,N> {
64        LinearBatchMse {
65            t: PhantomData::<T>,
66            n: PhantomData::<[();N]>,
67            l: PhantomData::<&'a ()>
68        }
69    }
70}
71impl<'a,const N:usize> Kernel for LinearBatchMse<'a,f32,N> {
72    const FUNC_PTR: *const c_void = loss_linear_batch_mse_derive_float as *const c_void;
73    type Args = LinearBatchMseArgs<'a,f32,N>;
74}
75impl<'a,const N:usize> Kernel for LinearBatchMse<'a,f64,N> {
76    const FUNC_PTR: *const c_void = loss_linear_batch_mse_derive_double as *const c_void;
77    type Args = LinearBatchMseArgs<'a,f64,N>;
78}
79/// Defines the list passed to the cuda kernel function as the argument of mse.
80pub struct LinearMseArgs<'a,T,const N:usize> where T: DataTypeInfo + UnitValue<T> {
81    /// expected value
82    expected: CudaConstPtr<'a,CudaTensor1dPtrView<'a,T,N>>,
83    /// actual value
84    actual: CudaConstPtr<'a,CudaTensor1dPtrView<'a,T,N>>,
85    pub output: CudaTensor1dPtr<T,N>,
86    out_len: usize,
87    batch_len: usize,
88}
89/// Create an instance of an object representing the argument list for computing the loss function mse.
90impl<'a,T,const N:usize> LinearMseArgs<'a,T,N> where T: DataTypeInfo + UnitValue<T> {
91    /// Create a LinearMseArgs instance
92    /// # Arguments
93    /// * `expected` - Expected Value
94    /// * `actual` - Actual Value
95    /// * `out_len` - Number of scalar values in output
96    pub fn new(t:&'a CudaTensor1dPtrView<'a,T,N>,
97               r:&'a CudaTensor1dPtrView<'a,T,N>,
98               output: CudaTensor1dPtr<T,N>,
99               out_len:usize) -> LinearMseArgs<'a,T,N> {
100        LinearMseArgs {
101            expected: CudaConstPtr::new(t),
102            actual: CudaConstPtr::new(r),
103            output: output,
104            out_len: out_len,
105            batch_len: 1
106        }
107    }
108}
109impl<'a,T,const N:usize> KernelArgs for LinearMseArgs<'a,T,N> where T: DataTypeInfo + UnitValue<T> {
110    fn as_vec(&mut self) -> Vec<&mut dyn AsKernelPtr> {
111        vec![
112            &mut self.expected,
113            &mut self.actual,
114            &mut self.output,
115            &mut self.out_len,
116            &mut self.batch_len
117        ]
118    }
119}
120pub struct LinearMse<'a,T,const N:usize> where T: DataTypeInfo + UnitValue<T> {
121    t:PhantomData<T>,
122    n:PhantomData<[();N]>,
123    l:PhantomData<&'a ()>
124}
125impl<'a,T,const N:usize> LinearMse<'a,T,N> where T: DataTypeInfo + UnitValue<T> {
126    /// Create a LinearMse instance
127    pub fn new() -> LinearMse<'a,T,N> {
128        LinearMse {
129            t: PhantomData::<T>,
130            n: PhantomData::<[();N]>,
131            l: PhantomData::<&'a ()>
132        }
133    }
134}
135impl<'a,const N:usize> Kernel for LinearMse<'a,f32,N> {
136    const FUNC_PTR: *const c_void = loss_linear_batch_mse_derive_float as *const c_void;
137    type Args = LinearMseArgs<'a,f32,N>;
138}
139impl<'a,const N:usize> Kernel for LinearMse<'a,f64,N> {
140    const FUNC_PTR: *const c_void = loss_linear_batch_mse_derive_double as *const c_void;
141    type Args = LinearMseArgs<'a,f64,N>;
142}
143/// Defines the list that is passed to the cuda kernel function as cross-entropy arguments during mini-batch execution.
144pub struct LinearBatchCrossEntropyArgs<'a,T,const N:usize> where T: DataTypeInfo + UnitValue<T> {
145    /// expected value
146    expected: CudaConstPtr<'a,CudaVecView<'a,T,CudaTensor1dPtr<T,N>>>,
147    /// actual value
148    actual: CudaConstPtr<'a,CudaVecView<'a,T,CudaTensor1dPtr<T,N>>>,
149    pub output: CudaVec<T,CudaTensor1dPtr<T,N>>,
150    out_len: usize,
151    batch_len: usize,
152}
153/// Create an instance of an object representing a list of arguments to calculate
154/// the result of passing a mini-batch to the loss function cross entropy.
155impl<'a,T,const N:usize> LinearBatchCrossEntropyArgs<'a,T,N> where T: DataTypeInfo + UnitValue<T> {
156    /// Create a LinearBatchCrossEntropyArgs instance
157    /// # Arguments
158    /// * `expected` - Expected Value
159    /// * `actual` - Actual Value
160    /// * `out_len` - Number of scalar values in output
161    /// * `batch_len` - batch count
162    pub fn new(t:&'a CudaVecView<'a,T,CudaTensor1dPtr<T,N>>,
163               r:&'a CudaVecView<'a,T,CudaTensor1dPtr<T,N>>,
164               output: CudaVec<T,CudaTensor1dPtr<T,N>>,
165               out_len:usize,batch_len:usize) -> LinearBatchCrossEntropyArgs<'a,T,N> {
166        LinearBatchCrossEntropyArgs {
167            expected: CudaConstPtr::new(t),
168            actual: CudaConstPtr::new(r),
169            output: output,
170            out_len: out_len,
171            batch_len: batch_len
172        }
173    }
174}
175impl<'a,T,const N:usize> KernelArgs for LinearBatchCrossEntropyArgs<'a,T,N> where T: DataTypeInfo + UnitValue<T> {
176    fn as_vec(&mut self) -> Vec<&mut dyn AsKernelPtr> {
177        vec![
178            &mut self.expected,
179            &mut self.actual,
180            &mut self.output,
181            &mut self.out_len,
182            &mut self.batch_len
183        ]
184    }
185}
186pub struct LinearBatchCrossEntropy<'a,T,const N:usize> where T: DataTypeInfo + UnitValue<T> {
187    t:PhantomData<T>,
188    n:PhantomData<[();N]>,
189    l:PhantomData<&'a ()>
190}
191impl<'a,T,const N:usize> LinearBatchCrossEntropy<'a,T,N> where T: DataTypeInfo + UnitValue<T> {
192    /// Create a LinearBatchCrossEntropy instance
193    pub fn new() -> LinearBatchCrossEntropy<'a,T,N> {
194        LinearBatchCrossEntropy {
195            t: PhantomData::<T>,
196            n: PhantomData::<[();N]>,
197            l: PhantomData::<&'a ()>
198        }
199    }
200}
201impl<'a,const N:usize> Kernel for LinearBatchCrossEntropy<'a,f32,N> {
202    const FUNC_PTR: *const c_void = loss_linear_batch_cross_entropy_derive_float as *const c_void;
203    type Args = LinearBatchCrossEntropyArgs<'a,f32,N>;
204}
205impl<'a,const N:usize> Kernel for LinearBatchCrossEntropy<'a,f64,N> {
206    const FUNC_PTR: *const c_void = loss_linear_batch_cross_entropy_derive_double as *const c_void;
207    type Args = LinearBatchCrossEntropyArgs<'a,f64,N>;
208}
209/// Defines the list passed to the cuda kernel function as the argument of cross entropy.
210pub struct LinearCrossEntropyArgs<'a,T,const N:usize> where T: DataTypeInfo + UnitValue<T> {
211    /// expected value
212    expected: CudaConstPtr<'a,CudaTensor1dPtrView<'a,T,N>>,
213    /// actual value
214    actual: CudaConstPtr<'a,CudaTensor1dPtrView<'a,T,N>>,
215    pub output: CudaTensor1dPtr<T,N>,
216    out_len: usize,
217    batch_len: usize,
218}
219/// Create an instance of an object representing the argument list for computing the loss function cross entropy.
220impl<'a,T,const N:usize> LinearCrossEntropyArgs<'a,T,N> where T: DataTypeInfo + UnitValue<T> {
221    /// Create a LinearCrossEntropyArgs instance
222    /// # Arguments
223    /// * `expected` - Expected Value
224    /// * `actual` - Actual Value
225    /// * `out_len` - Number of scalar values in output
226    pub fn new(t:&'a CudaTensor1dPtrView<'a,T,N>,
227               r:&'a CudaTensor1dPtrView<'a,T,N>,
228               output: CudaTensor1dPtr<T,N>,
229               out_len:usize) -> LinearCrossEntropyArgs<'a, T, N> {
230        LinearCrossEntropyArgs {
231            expected: CudaConstPtr::new(t),
232            actual: CudaConstPtr::new(r),
233            output: output,
234            out_len: out_len,
235            batch_len: 1
236        }
237    }
238}
239impl<'a,T,const N:usize> KernelArgs for LinearCrossEntropyArgs<'a,T,N> where T: DataTypeInfo + UnitValue<T> {
240    fn as_vec(&mut self) -> Vec<&mut dyn AsKernelPtr> {
241        vec![
242            &mut self.expected,
243            &mut self.actual,
244            &mut self.output,
245            &mut self.out_len,
246            &mut self.batch_len
247        ]
248    }
249}
250pub struct LinearCrossEntropy<'a,T,const N:usize> where T: DataTypeInfo + UnitValue<T> {
251    t:PhantomData<T>,
252    n:PhantomData<[();N]>,
253    l:PhantomData<&'a ()>
254}
255impl<'a,T,const N:usize> LinearCrossEntropy<'a,T,N> where T: DataTypeInfo + UnitValue<T> {
256    /// Create a LinearCrossEntropy instance
257    pub fn new() -> LinearCrossEntropy<'a,T,N> {
258        LinearCrossEntropy {
259            t: PhantomData::<T>,
260            n: PhantomData::<[();N]>,
261            l: PhantomData::<&'a ()>
262        }
263    }
264}
265impl<'a,const N:usize> Kernel for LinearCrossEntropy<'a,f32,N> {
266    const FUNC_PTR: *const c_void = loss_linear_batch_cross_entropy_derive_float as *const c_void;
267    type Args = LinearCrossEntropyArgs<'a,f32,N>;
268}
269impl<'a,const N:usize> Kernel for LinearCrossEntropy<'a,f64,N> {
270    const FUNC_PTR: *const c_void = loss_linear_batch_cross_entropy_derive_double as *const c_void;
271    type Args = LinearCrossEntropyArgs<'a,f64,N>;
272}
273/// Defines the list that is passed to the cuda kernel function as arguments
274/// to the croos entropy multiclass during mini-batch execution.
275pub struct LinearBatchCrossEntropyMulticlassArgs<'a,T,const N:usize> where T: DataTypeInfo + UnitValue<T> {
276    /// expected value
277    expected: CudaConstPtr<'a,CudaVecView<'a,T,CudaTensor1dPtr<T,N>>>,
278    /// actual value
279    actual: CudaConstPtr<'a,CudaVecView<'a,T,CudaTensor1dPtr<T,N>>>,
280    pub output: CudaVec<T,CudaTensor1dPtr<T,N>>,
281    out_len: usize,
282    batch_len: usize,
283}
284/// Create an instance of an object representing a list of arguments to compute the result of passing a mini-batch
285/// to the loss function cross entropy multiclass.
286impl<'a,T,const N:usize> LinearBatchCrossEntropyMulticlassArgs<'a,T,N> where T: DataTypeInfo + UnitValue<T> {
287    /// Create a LinearBatchCrossEntropyMulticlassArgs instance
288    /// # Arguments
289    /// * `expected` - Expected Value
290    /// * `actual` - Actual Value
291    /// * `out_len` - Number of scalar values in output
292    /// * `batch_len` - batch count
293    pub fn new(t:&'a CudaVecView<'a,T,CudaTensor1dPtr<T,N>>,
294               r:&'a CudaVecView<'a,T,CudaTensor1dPtr<T,N>>,
295               output: CudaVec<T,CudaTensor1dPtr<T,N>>,
296               out_len:usize,batch_len:usize) -> LinearBatchCrossEntropyMulticlassArgs<'a,T,N> {
297        LinearBatchCrossEntropyMulticlassArgs {
298            expected: CudaConstPtr::new(t),
299            actual: CudaConstPtr::new(r),
300            output: output,
301            out_len: out_len,
302            batch_len: batch_len
303        }
304    }
305}
306impl<'a,T,const N:usize> KernelArgs for LinearBatchCrossEntropyMulticlassArgs<'a,T,N> where T: DataTypeInfo + UnitValue<T> {
307    fn as_vec(&mut self) -> Vec<&mut dyn AsKernelPtr> {
308        vec![
309            &mut self.expected,
310            &mut self.actual,
311            &mut self.output,
312            &mut self.out_len,
313            &mut self.batch_len
314        ]
315    }
316}
317pub struct LinearBatchCrossEntropyMulticlass<'a,T,const N:usize> where T: DataTypeInfo + UnitValue<T> {
318    t:PhantomData<T>,
319    n:PhantomData<[();N]>,
320    l:PhantomData<&'a ()>
321}
322impl<'a,T,const N:usize> LinearBatchCrossEntropyMulticlass<'a,T,N> where T: DataTypeInfo + UnitValue<T> {
323    /// Create a LinearBatchCrossEntropyMulticlass instance
324    pub fn new() -> LinearBatchCrossEntropyMulticlass<'a,T,N> {
325        LinearBatchCrossEntropyMulticlass {
326            t: PhantomData::<T>,
327            n: PhantomData::<[();N]>,
328            l: PhantomData::<&'a ()>
329        }
330    }
331}
332impl<'a,const N:usize> Kernel for LinearBatchCrossEntropyMulticlass<'a,f32,N> {
333    const FUNC_PTR: *const c_void = loss_linear_batch_cross_entropy_multiclass_derive_float as *const c_void;
334    type Args = LinearBatchCrossEntropyMulticlassArgs<'a,f32,N>;
335}
336impl<'a,const N:usize> Kernel for LinearBatchCrossEntropyMulticlass<'a,f64,N> {
337    const FUNC_PTR: *const c_void = loss_linear_batch_cross_entropy_multiclass_derive_double as *const c_void;
338    type Args = LinearBatchCrossEntropyMulticlassArgs<'a,f64,N>;
339}
340/// Defines the list passed to the cuda kernel function as the argument of croos entropy multiclass
341pub struct LinearCrossEntropyMulticlassArgs<'a,T,const N:usize> where T: DataTypeInfo + UnitValue<T> {
342    /// expected value
343    expected: CudaConstPtr<'a,CudaTensor1dPtrView<'a,T,N>>,
344    /// actual value
345    actual: CudaConstPtr<'a,CudaTensor1dPtrView<'a,T,N>>,
346    pub output: CudaTensor1dPtr<T,N>,
347    out_len: usize,
348    batch_len: usize,
349}
350/// Create an instance of an object representing the argument list for computing the loss function cross entropy multiclass.
351impl<'a,T,const N:usize> LinearCrossEntropyMulticlassArgs<'a,T,N> where T: DataTypeInfo + UnitValue<T> {
352    /// Create a LinearCrossEntropyMulticlassArgs instance
353    /// # Arguments
354    /// * `expected` - Expected Value
355    /// * `actual` - Actual Value
356    /// * `out_len` - Number of scalar values in output
357    pub fn new(t:&'a CudaTensor1dPtrView<T,N>,
358               r:&'a CudaTensor1dPtrView<'a,T,N>,
359               output: CudaTensor1dPtr<T,N>,
360               out_len:usize) -> LinearCrossEntropyMulticlassArgs<'a,T,N> {
361        LinearCrossEntropyMulticlassArgs {
362            expected: CudaConstPtr::new(t),
363            actual: CudaConstPtr::new(r),
364            output: output,
365            out_len: out_len,
366            batch_len: 1
367        }
368    }
369}
370impl<'a,T,const N:usize> KernelArgs for LinearCrossEntropyMulticlassArgs<'a,T,N> where T: DataTypeInfo + UnitValue<T> {
371    fn as_vec(&mut self) -> Vec<&mut dyn AsKernelPtr> {
372        vec![
373            &mut self.expected,
374            &mut self.actual,
375            &mut self.output,
376            &mut self.out_len,
377            &mut self.batch_len
378        ]
379    }
380}
381pub struct LinearCrossEntropyMulticlass<'a,T,const N:usize> where T: DataTypeInfo + UnitValue<T> {
382    t:PhantomData<T>,
383    n:PhantomData<[();N]>,
384    l:PhantomData<&'a ()>
385}
386impl<'a,T,const N:usize> LinearCrossEntropyMulticlass<'a,T,N> where T: DataTypeInfo + UnitValue<T> {
387    /// Create a LinearCrossEntropyMulticlass instance
388    pub fn new() -> LinearCrossEntropyMulticlass<'a,T,N> {
389        LinearCrossEntropyMulticlass {
390            t: PhantomData::<T>,
391            n: PhantomData::<[();N]>,
392            l: PhantomData::<&'a ()>
393        }
394    }
395}
396impl<'a,const N:usize> Kernel for LinearCrossEntropyMulticlass<'a,f32,N> {
397    const FUNC_PTR: *const c_void = loss_linear_batch_cross_entropy_multiclass_derive_float as *const c_void;
398    type Args = LinearCrossEntropyMulticlassArgs<'a, f32, N>;
399}
400impl<'a,const N:usize> Kernel for LinearCrossEntropyMulticlass<'a,f64,N> {
401    const FUNC_PTR: *const c_void = loss_linear_batch_cross_entropy_multiclass_derive_double as *const c_void;
402    type Args = LinearCrossEntropyMulticlassArgs<'a,f64,N>;
403}