use axonml_autograd::Variable;
use axonml_tensor::Tensor;

use crate::module::Module;

/// Reduction applied to the per-element losses.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub enum Reduction {
    /// Return the unreduced, element-wise losses.
    None,
    /// Average the losses over all elements.
    #[default]
    Mean,
    /// Sum the losses over all elements.
    Sum,
}

/// Mean squared error loss: the element-wise `(input - target)^2`,
/// reduced according to the configured [`Reduction`].
#[derive(Debug, Clone, Copy)]
pub struct MSELoss {
    reduction: Reduction,
}

impl MSELoss {
    /// Creates an MSE loss with [`Reduction::Mean`].
    pub fn new() -> Self {
        Self {
            reduction: Reduction::Mean,
        }
    }

    /// Creates an MSE loss with the given reduction.
    pub fn with_reduction(reduction: Reduction) -> Self {
        Self { reduction }
    }

    /// Computes the loss between `input` and `target`.
    pub fn compute(&self, input: &Variable, target: &Variable) -> Variable {
        let diff = input.sub_var(target);
        let squared = diff.pow(2.0);

        match self.reduction {
            Reduction::None => squared,
            Reduction::Mean => squared.mean(),
            Reduction::Sum => squared.sum(),
        }
    }
}

impl Default for MSELoss {
    fn default() -> Self {
        Self::new()
    }
}

impl Module for MSELoss {
    fn forward(&self, input: &Variable) -> Variable {
        // `Module::forward` takes a single input, so the loss acts as an
        // identity pass-through here; use `compute` to evaluate the loss.
        input.clone()
    }

    fn name(&self) -> &'static str {
        "MSELoss"
    }
}

/// Mean absolute error (L1) loss: the element-wise `|input - target|`,
/// reduced according to the configured [`Reduction`].
#[derive(Debug, Clone, Copy)]
pub struct L1Loss {
    reduction: Reduction,
}

impl L1Loss {
    /// Creates an L1 loss with [`Reduction::Mean`].
    pub fn new() -> Self {
        Self {
            reduction: Reduction::Mean,
        }
    }

    /// Creates an L1 loss with the given reduction.
    pub fn with_reduction(reduction: Reduction) -> Self {
        Self { reduction }
    }

    /// Computes the loss between `input` and `target`.
    pub fn compute(&self, input: &Variable, target: &Variable) -> Variable {
        let diff = input.sub_var(target);
        let diff_data = diff.data();
        // Take the element-wise absolute value of the difference.
        let abs_data: Vec<f32> = diff_data.to_vec().iter().map(|x| x.abs()).collect();
        let abs_tensor = Tensor::from_vec(abs_data, diff_data.shape()).unwrap();
        let abs_var = Variable::new(abs_tensor, diff.requires_grad());

        match self.reduction {
            Reduction::None => abs_var,
            Reduction::Mean => abs_var.mean(),
            Reduction::Sum => abs_var.sum(),
        }
    }
}

impl Default for L1Loss {
    fn default() -> Self {
        Self::new()
    }
}

/// Cross-entropy loss over raw logits.
///
/// `input` is expected to have shape `[batch, classes]` and `target` to hold
/// one class index per batch element. The per-sample loss is
/// `log_sum_exp(logits) - logits[target]`, computed with the max-subtraction
/// trick for numerical stability.
#[derive(Debug, Clone, Copy)]
pub struct CrossEntropyLoss {
    reduction: Reduction,
}

impl CrossEntropyLoss {
    /// Creates a cross-entropy loss with [`Reduction::Mean`].
    pub fn new() -> Self {
        Self {
            reduction: Reduction::Mean,
        }
    }

    /// Creates a cross-entropy loss with the given reduction.
    pub fn with_reduction(reduction: Reduction) -> Self {
        Self { reduction }
    }

    /// Computes the loss between logits `input` and class indices `target`.
    pub fn compute(&self, input: &Variable, target: &Variable) -> Variable {
        let input_data = input.data();
        let target_data = target.data();
        let shape = input_data.shape().to_vec();
        let batch_size = shape[0];
        let num_classes = shape[1];

        let input_vec = input_data.to_vec();
        let target_vec = target_data.to_vec();

        let mut losses = vec![0.0f32; batch_size];

        for b in 0..batch_size {
            let offset = b * num_classes;
            // Subtract the row maximum before exponentiating to avoid overflow.
            let max_val = (0..num_classes)
                .map(|c| input_vec[offset + c])
                .fold(f32::NEG_INFINITY, f32::max);

            let mut sum_exp = 0.0f32;
            for c in 0..num_classes {
                sum_exp += (input_vec[offset + c] - max_val).exp();
            }
            let log_sum_exp = max_val + sum_exp.ln();

            let target_class = target_vec[b] as usize;
            losses[b] = log_sum_exp - input_vec[offset + target_class];
        }

        let loss_tensor = Tensor::from_vec(losses, &[batch_size]).unwrap();
        let loss_var = Variable::new(loss_tensor, input.requires_grad());

        match self.reduction {
            Reduction::None => loss_var,
            Reduction::Mean => loss_var.mean(),
            Reduction::Sum => loss_var.sum(),
        }
    }
}

impl Default for CrossEntropyLoss {
    fn default() -> Self {
        Self::new()
    }
}

/// Negative log-likelihood loss.
///
/// `input` is expected to hold log-probabilities with shape `[batch, classes]`
/// and `target` to hold one class index per batch element.
#[derive(Debug, Clone, Copy)]
pub struct NLLLoss {
    reduction: Reduction,
}

impl NLLLoss {
    /// Creates an NLL loss with [`Reduction::Mean`].
    pub fn new() -> Self {
        Self {
            reduction: Reduction::Mean,
        }
    }

    /// Creates an NLL loss with the given reduction.
    pub fn with_reduction(reduction: Reduction) -> Self {
        Self { reduction }
    }

    /// Computes the loss between log-probabilities `input` and class indices `target`.
    pub fn compute(&self, input: &Variable, target: &Variable) -> Variable {
        let input_data = input.data();
        let target_data = target.data();
        let shape = input_data.shape().to_vec();
        let batch_size = shape[0];
        let num_classes = shape[1];

        let input_vec = input_data.to_vec();
        let target_vec = target_data.to_vec();

        let mut losses = vec![0.0f32; batch_size];

        for b in 0..batch_size {
            // The loss is the negated log-probability of the target class.
            let target_class = target_vec[b] as usize;
            losses[b] = -input_vec[b * num_classes + target_class];
        }

        let loss_tensor = Tensor::from_vec(losses, &[batch_size]).unwrap();
        let loss_var = Variable::new(loss_tensor, input.requires_grad());

        match self.reduction {
            Reduction::None => loss_var,
            Reduction::Mean => loss_var.mean(),
            Reduction::Sum => loss_var.sum(),
        }
    }
}

impl Default for NLLLoss {
    fn default() -> Self {
        Self::new()
    }
}

/// Binary cross-entropy loss over probabilities.
///
/// `input` is expected to contain probabilities in `[0, 1]`; values are clamped
/// away from 0 and 1 before taking logarithms.
#[derive(Debug, Clone, Copy)]
pub struct BCELoss {
    reduction: Reduction,
}

impl BCELoss {
    /// Creates a BCE loss with [`Reduction::Mean`].
    pub fn new() -> Self {
        Self {
            reduction: Reduction::Mean,
        }
    }

    /// Creates a BCE loss with the given reduction.
    pub fn with_reduction(reduction: Reduction) -> Self {
        Self { reduction }
    }

    /// Computes the loss between probabilities `input` and binary targets `target`.
    pub fn compute(&self, input: &Variable, target: &Variable) -> Variable {
        let eps = 1e-7f32;
        let input_data = input.data();
        let target_data = target.data();

        let input_vec = input_data.to_vec();
        let target_vec = target_data.to_vec();

        let losses: Vec<f32> = input_vec
            .iter()
            .zip(target_vec.iter())
            .map(|(&p, &t)| {
                // Clamp to avoid ln(0).
                let p_clamped = p.clamp(eps, 1.0 - eps);
                -(t * p_clamped.ln() + (1.0 - t) * (1.0 - p_clamped).ln())
            })
            .collect();

        let loss_tensor = Tensor::from_vec(losses, input_data.shape()).unwrap();
        let loss_var = Variable::new(loss_tensor, input.requires_grad());

        match self.reduction {
            Reduction::None => loss_var,
            Reduction::Mean => loss_var.mean(),
            Reduction::Sum => loss_var.sum(),
        }
    }
}

impl Default for BCELoss {
    fn default() -> Self {
        Self::new()
    }
}

/// Binary cross-entropy loss over raw logits.
///
/// Uses the numerically stable form `max(x, 0) - x * t + ln(1 + exp(-|x|))`,
/// which avoids overflow for large positive or negative logits.
#[derive(Debug, Clone, Copy)]
pub struct BCEWithLogitsLoss {
    reduction: Reduction,
}

impl BCEWithLogitsLoss {
    /// Creates a BCE-with-logits loss with [`Reduction::Mean`].
    pub fn new() -> Self {
        Self {
            reduction: Reduction::Mean,
        }
    }

    /// Creates a BCE-with-logits loss with the given reduction.
    pub fn with_reduction(reduction: Reduction) -> Self {
        Self { reduction }
    }

    /// Computes the loss between logits `input` and binary targets `target`.
    pub fn compute(&self, input: &Variable, target: &Variable) -> Variable {
        let input_data = input.data();
        let target_data = target.data();

        let input_vec = input_data.to_vec();
        let target_vec = target_data.to_vec();

        let losses: Vec<f32> = input_vec
            .iter()
            .zip(target_vec.iter())
            .map(|(&x, &t)| {
                let max_val = x.max(0.0);
                max_val - x * t + (1.0 + (-x.abs()).exp()).ln()
            })
            .collect();

        let loss_tensor = Tensor::from_vec(losses, input_data.shape()).unwrap();
        let loss_var = Variable::new(loss_tensor, input.requires_grad());

        match self.reduction {
            Reduction::None => loss_var,
            Reduction::Mean => loss_var.mean(),
            Reduction::Sum => loss_var.sum(),
        }
    }
}

impl Default for BCEWithLogitsLoss {
    fn default() -> Self {
        Self::new()
    }
}

/// Smooth L1 (Huber-style) loss.
///
/// For `|d| < beta` the loss is quadratic, `0.5 * d^2 / beta`; otherwise it is
/// linear, `|d| - 0.5 * beta`.
#[derive(Debug, Clone, Copy)]
pub struct SmoothL1Loss {
    reduction: Reduction,
    beta: f32,
}

impl SmoothL1Loss {
    /// Creates a smooth L1 loss with [`Reduction::Mean`] and `beta = 1.0`.
    pub fn new() -> Self {
        Self {
            reduction: Reduction::Mean,
            beta: 1.0,
        }
    }

    /// Creates a smooth L1 loss with the given `beta`.
    pub fn with_beta(beta: f32) -> Self {
        Self {
            reduction: Reduction::Mean,
            beta,
        }
    }

    /// Computes the loss between `input` and `target`.
    pub fn compute(&self, input: &Variable, target: &Variable) -> Variable {
        let diff = input.sub_var(target);
        let diff_data = diff.data();
        let diff_vec = diff_data.to_vec();

        let losses: Vec<f32> = diff_vec
            .iter()
            .map(|&d| {
                let abs_d = d.abs();
                if abs_d < self.beta {
                    // Quadratic region near zero.
                    0.5 * d * d / self.beta
                } else {
                    // Linear region for large errors.
                    abs_d - 0.5 * self.beta
                }
            })
            .collect();

        let loss_tensor = Tensor::from_vec(losses, diff_data.shape()).unwrap();
        let loss_var = Variable::new(loss_tensor, diff.requires_grad());

        match self.reduction {
            Reduction::None => loss_var,
            Reduction::Mean => loss_var.mean(),
            Reduction::Sum => loss_var.sum(),
        }
    }
}

impl Default for SmoothL1Loss {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_mse_loss() {
        let loss_fn = MSELoss::new();
        let input = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), false);
        let target = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), false);
        let loss = loss_fn.compute(&input, &target);
        assert!((loss.data().to_vec()[0] - 0.0).abs() < 1e-6);
    }

    #[test]
    fn test_mse_loss_nonzero() {
        let loss_fn = MSELoss::new();
        let input = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), false);
        let target = Variable::new(Tensor::from_vec(vec![2.0, 3.0, 4.0], &[3]).unwrap(), false);
        let loss = loss_fn.compute(&input, &target);
        assert!((loss.data().to_vec()[0] - 1.0).abs() < 1e-6);
    }
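
    // Additional checks mirroring the tests above, using the same Variable/Tensor
    // calls; the expected values are hand-computed from the formulas in
    // `L1Loss::compute` and `SmoothL1Loss::compute` (default beta = 1.0).
    #[test]
    fn test_l1_loss() {
        let loss_fn = L1Loss::new();
        let input = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(), false);
        let target = Variable::new(Tensor::from_vec(vec![3.0, 2.0, 1.0], &[3]).unwrap(), false);
        let loss = loss_fn.compute(&input, &target);
        // Absolute differences are [2, 0, 2]; their mean is 4/3.
        assert!((loss.data().to_vec()[0] - 4.0 / 3.0).abs() < 1e-6);
    }

    #[test]
    fn test_smooth_l1_loss() {
        let loss_fn = SmoothL1Loss::new();
        let input = Variable::new(Tensor::from_vec(vec![0.0, 2.0], &[2]).unwrap(), false);
        let target = Variable::new(Tensor::from_vec(vec![0.5, 0.0], &[2]).unwrap(), false);
        let loss = loss_fn.compute(&input, &target);
        // Quadratic branch: 0.5 * 0.5^2 = 0.125; linear branch: 2 - 0.5 = 1.5; mean = 0.8125.
        assert!((loss.data().to_vec()[0] - 0.8125).abs() < 1e-6);
    }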

    #[test]
    fn test_cross_entropy_loss() {
        let loss_fn = CrossEntropyLoss::new();
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 1.0, 2.0, 3.0], &[2, 3]).unwrap(),
            false,
        );
        let target = Variable::new(Tensor::from_vec(vec![2.0, 0.0], &[2]).unwrap(), false);
        let loss = loss_fn.compute(&input, &target);
        assert!(loss.data().to_vec()[0] > 0.0);
    }
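
    // A small NLLLoss check in the same style: inputs are log-probabilities of
    // shape [2, 2], and the expected value is hand-computed from `NLLLoss::compute`.
    #[test]
    fn test_nll_loss() {
        let loss_fn = NLLLoss::new();
        let input = Variable::new(
            Tensor::from_vec(vec![-0.5, -1.0, -2.0, -0.1], &[2, 2]).unwrap(),
            false,
        );
        let target = Variable::new(Tensor::from_vec(vec![0.0, 1.0], &[2]).unwrap(), false);
        let loss = loss_fn.compute(&input, &target);
        // Per-sample losses are [0.5, 0.1]; their mean is 0.3.
        assert!((loss.data().to_vec()[0] - 0.3).abs() < 1e-6);
    }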

    #[test]
    fn test_bce_loss() {
        let loss_fn = BCELoss::new();
        let input = Variable::new(Tensor::from_vec(vec![0.5, 0.5], &[2]).unwrap(), false);
        let target = Variable::new(Tensor::from_vec(vec![1.0, 0.0], &[2]).unwrap(), false);
        let loss = loss_fn.compute(&input, &target);
        assert!((loss.data().to_vec()[0] - 0.693).abs() < 0.01);
    }
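
    // BCEWithLogitsLoss check in the same style: for a zero logit the stable form
    // max(x, 0) - x*t + ln(1 + exp(-|x|)) reduces to ln(2) regardless of the target,
    // so the mean loss should match the BCE case with p = 0.5 above.
    #[test]
    fn test_bce_with_logits_loss() {
        let loss_fn = BCEWithLogitsLoss::new();
        let input = Variable::new(Tensor::from_vec(vec![0.0, 0.0], &[2]).unwrap(), false);
        let target = Variable::new(Tensor::from_vec(vec![1.0, 0.0], &[2]).unwrap(), false);
        let loss = loss_fn.compute(&input, &target);
        assert!((loss.data().to_vec()[0] - 0.693).abs() < 0.01);
    }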
}