use tensor_rs::tensor::Tensor;
use super::{OpTrait, OpHandle};

#[cfg(feature = "use-serde")]
use serde::{Serialize, Deserialize};
#[cfg(feature = "use-serde")]
use std::any::Any;
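
/// Reduction modes for loss ops, following the usual none/mean/sum
/// convention. Declared as part of the public API; the ops below do not
/// reference it yet and always apply a mean-style normalization.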
pub enum Reduction {
    None,
    Mean,
    Sum,
}
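
/// Mean squared error loss: sum((x - y)^2) / (n * c), where n and c are
/// the first two dimensions of the input, so the inputs are assumed to
/// be 2-D [N, C] tensors.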
#[cfg_attr(feature = "use-serde", derive(Serialize, Deserialize))]
pub struct MSELoss {
    #[cfg_attr(feature = "use-serde", serde(skip))]
    handle: OpHandle,
}

impl MSELoss {
    pub fn new() -> MSELoss {
        MSELoss {
            handle: OpHandle::new(),
        }
    }
    handle_method!();
}

impl OpTrait for MSELoss {
    fn get_name(&self) -> &'static str {
        "MSE"
    }
    fn get_input_size(&self) -> usize {
        2
    }
    fn get_output_size(&self) -> usize {
        1
    }
    fn apply(&self, input: &[Tensor], output: &[Tensor]) {
        // (x - y)^2, summed over all elements, then normalized by n*c
        // (the first two dimensions of the input).
        let tmp = input[0].sub(&input[1]);
        let tmp2 = tmp.mul(&tmp);
        let tmp3 = tmp2.sum(None, false);
        let ret = tmp3.div(&input[0].get_n().mul(&input[0].get_c()));
        output[0].swap(&ret);
    }
    fn grad(&self, input: &[Tensor], output_grad: &[Tensor], input_grad: &[Tensor]) {
        if input.len() < 2 {
            panic!("MSELoss expects two inputs, got {}", input.len());
        }
        if input_grad.len() < 2 {
            panic!("MSELoss expects two input gradient tensors, got {}", input_grad.len());
        }
        if output_grad.is_empty() {
            panic!("MSELoss expects one output gradient, got {}", output_grad.len());
        }
        if !input[0].same_shape(&input[1]) {
            panic!("MSELoss expects two inputs with the same shape, got {:?}, {:?}", input[0].size(), input[1].size());
        }

        // d/dx of sum((x - y)^2)/(n*c) is 2*(x - y)/(n*c); for the
        // assumed 2-D [N, C] inputs, n*c equals the element count.
        // The factor of 2 is applied via add() on the difference.
        let diff = input[0].sub(&input[1]);
        let tmp2 = diff.add(&diff).div(&input[0].numel_tensor());
        let tmp3 = tmp2.mul(&output_grad[0]);
        input_grad[0].swap(&tmp3);

        // The gradient w.r.t. the second input is the negation.
        let diff = input[1].sub(&input[0]);
        let tmp2 = diff.add(&diff).div(&input[0].numel_tensor());
        let tmp3 = tmp2.mul(&output_grad[0]);
        input_grad[1].swap(&tmp3);
    }

    fn get_values(&self) -> Vec<Tensor> {
        Vec::new()
    }
    fn set_values(&self, _v: &[Tensor]) {
    }

    fn get_grads(&self) -> Vec<Tensor> {
        Vec::new()
    }
    #[cfg(feature = "use-serde")]
    fn as_any(&self) -> &dyn Any {
        self
    }
}

impl Default for MSELoss {
    fn default() -> Self {
        Self::new()
    }
}
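
/// Cross entropy loss over raw, unnormalized class scores.
/// Expects input[0] of shape [N, C] (scores) and input[1] of shape [N]
/// (class indices); one-hot targets are not supported. Computes the
/// batch mean of -score[class] + logsumexp(scores, dim=1).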
#[cfg_attr(feature = "use-serde", derive(Serialize, Deserialize))]
pub struct CrossEntropyLoss {
    #[cfg_attr(feature = "use-serde", serde(skip))]
    handle: OpHandle,
}

impl CrossEntropyLoss {
    pub fn new() -> CrossEntropyLoss {
        CrossEntropyLoss {
            handle: OpHandle::new(),
        }
    }
    handle_method!();
}

impl OpTrait for CrossEntropyLoss {
    fn get_name(&self) -> &'static str {
        "CrossEntropyLoss"
    }
    fn get_input_size(&self) -> usize {
        2
    }
    fn get_output_size(&self) -> usize {
        1
    }
    fn apply(&self, input: &[Tensor], output: &[Tensor]) {
        if input.len() != 2 {
            panic!("{} expects two inputs, got {}", self.get_name(), input.len());
        }
        if input[0].size().len() != (input[1].size().len() + 1) {
            panic!("{} expects inputs of dim d+1 and d, got {} and {}; one-hot encoding is not supported yet",
                   self.get_name(), input[0].size().len(), input[1].size().len());
        }

        // Pick out each sample's score for its target class, then take
        // the batch mean of -score[class] + logsumexp(scores, dim=1).
        let class_index = input[1].unsqueeze(1);
        let class_score = input[0].gather(1, &class_index);
        let val = class_score.neg().add(&input[0].logsumexp(Some(&[1]), true)).mean(None, false);
        output[0].swap(&val);
    }

    fn grad(&self, input: &[Tensor],
            output_grad: &[Tensor], input_grad: &[Tensor]) {
        let n = input[0].size()[0];
        let d = input[0].size()[1];
        // softmax(scores): exp(scores - logsumexp(scores, dim=1))
        let common = input[0].sub(&input[0].logsumexp(Some(&[1]), true).repeat(&[1, d])).exp();

        // Subtract the one-hot target (softmax - one_hot), scale by the
        // incoming gradient, and average over the batch of size n.
        let class_index = input[1].unsqueeze(1);
        let zeros = Tensor::zeros(&[n, d]);
        let subone = Tensor::ones(&[n, 1]).neg();
        let class_score = zeros.spread(1, &class_index, &subone);
        input_grad[0].swap(&(class_score.add(&common)).mul(&output_grad[0])
                           .div(&Tensor::int_n(&[1], n.try_into().expect("batch size fits in the index type"))));
    }

    fn get_values(&self) -> Vec<Tensor> {
        Vec::new()
    }
    fn set_values(&self, _v: &[Tensor]) {
    }
    fn get_grads(&self) -> Vec<Tensor> {
        Vec::new()
    }
    #[cfg(feature = "use-serde")]
    fn as_any(&self) -> &dyn Any {
        self
    }
}

impl Default for CrossEntropyLoss {
    fn default() -> Self {
        Self::new()
    }
}
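
/// Binary cross entropy on raw logits, combining the sigmoid and the
/// BCE step in a numerically stable form via log1pexp:
/// loss = sum(y * log(1 + e^-x) + (1 - y) * log(1 + e^x)) / (n * c).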
#[cfg_attr(feature = "use-serde", derive(Serialize, Deserialize))]
pub struct BCEWithLogitsLoss {
    #[cfg_attr(feature = "use-serde", serde(skip))]
    handle: OpHandle,
}

impl BCEWithLogitsLoss {
    pub fn new() -> BCEWithLogitsLoss {
        BCEWithLogitsLoss {
            handle: OpHandle::new(),
        }
    }
    handle_method!();
}

impl OpTrait for BCEWithLogitsLoss {
    fn get_name(&self) -> &'static str {
        "BCEWithLogitsLoss"
    }
    fn get_input_size(&self) -> usize {
        2
    }
    fn get_output_size(&self) -> usize {
        1
    }
    fn apply(&self, input: &[Tensor], output: &[Tensor]) {
        if input.len() != self.get_input_size() {
            panic!("{} expects two inputs, got {}", self.get_name(), input.len());
        }
        // y*log(1 + e^-x) + (1 - y)*log(1 + e^x), using the numerically
        // stable log1pexp, summed and then normalized by n*c.
        let ret_all = input[1].mul(&input[0].neg().log1pexp())
            .add(&(input[1].neg().add(&input[1].ones_like())).mul(&input[0].log1pexp()));
        let tmp3 = ret_all.sum(None, false);
        let ret = tmp3.div(&input[0].get_n().mul(&input[0].get_c()));
        output[0].swap(&ret);
    }

    fn grad(&self, input: &[Tensor],
            output_grad: &[Tensor],
            input_grad: &[Tensor]) {
        // d/dx [y*log(1 + e^-x) + (1 - y)*log(1 + e^x)]
        //   = -y/(e^x + 1) + (1 - y)/(e^-x + 1)
        let ones = Tensor::ones_like(&input[0]);
        let tmp1 = input[1].neg().div(&input[0].exp().add(&ones));
        let tmp2 = input[1].neg().add(&ones).div(&input[0].neg().exp().add(&ones));
        let tmp3 = tmp1.add(&tmp2);
        // Match the forward pass's n*c normalization (for the assumed
        // 2-D [N, C] inputs this equals the element count).
        let tmp4 = tmp3.mul(&output_grad[0]).div(&input[0].numel_tensor());

        // The target is treated as a constant, so its gradient is zero.
        let zeros = Tensor::zeros_like(&input[0]);
        input_grad[0].swap(&tmp4);
        input_grad[1].swap(&zeros);
    }

    fn get_values(&self) -> Vec<Tensor> {
        Vec::new()
    }

    fn set_values(&self, _v: &[Tensor]) {
    }

    fn get_grads(&self) -> Vec<Tensor> {
        Vec::new()
    }
    #[cfg(feature = "use-serde")]
    fn as_any(&self) -> &dyn Any {
        self
    }
}

impl Default for BCEWithLogitsLoss {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::op::_gradient_checker;

    #[test]
    fn test_cross_entropy_loss() {
        let a = Tensor::from_vec_f64(&[1., 2., 3., 4., 5., 6.], &[3, 2]);
        let b = Tensor::from_vec_f64(&[0., 0., 1.], &[3]);
        let c = CrossEntropyLoss::new();
        let d = Tensor::new();
        c.apply(&[a.ref_copy(), b.ref_copy()], &[d.ref_copy()]);
        assert!((d.get_scale_f64() - 0.97992826).abs() < 0.001);

        let a = Tensor::from_vec_f64(&[0.1, 0.1, 10., 10., 0.1, 0.1], &[2, 3]);
        let b = Tensor::from_vec_f64(&[2., 0.], &[2]);
        let c = CrossEntropyLoss::new();
        let d = Tensor::new();
        c.apply(&[a.ref_copy(), b.ref_copy()], &[d.ref_copy()]);
        println!("{:?}", d);

        let a = Tensor::from_vec_f64(&[0.1, 0.1, 10., 10., 0.1, 0.1], &[2, 3]);
        let b = Tensor::from_vec_f64(&[0., 2.], &[2]);
        let mut c = CrossEntropyLoss::new();
        let d = Tensor::new();
        c.apply(&[a.ref_copy(), b.ref_copy()], &[d.ref_copy()]);
        println!("{:?}", d);

        assert!(_gradient_checker(&mut c, &[a, b], Some(&[true, false]), None, None));
    }
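
    // A small forward-pass sanity check for MSELoss, written as a sketch
    // against the same test API used above (from_vec_f64, ref_copy,
    // get_scale_f64). It assumes get_n()/get_c() return the first two
    // dimensions, so a [3, 2] input is normalized by 3 * 2 = 6.
    #[test]
    fn test_mse_loss() {
        let a = Tensor::from_vec_f64(&[1., 2., 3., 4., 5., 6.], &[3, 2]);
        let b = Tensor::from_vec_f64(&[0., 0., 0., 0., 0., 0.], &[3, 2]);
        let c = MSELoss::new();
        let d = Tensor::new();
        c.apply(&[a.ref_copy(), b.ref_copy()], &[d.ref_copy()]);
        // sum of squares = 1 + 4 + 9 + 16 + 25 + 36 = 91; 91/6 ~ 15.1667
        assert!((d.get_scale_f64() - 91.0 / 6.0).abs() < 1e-6);
    }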
314 }
315}