1use yscv_autograd::{Graph, NodeId};
2use yscv_tensor::Tensor;
3
4use crate::ModelError;
5
6fn validate_loss_inputs(
7 graph: &Graph,
8 prediction: NodeId,
9 target: NodeId,
10) -> Result<usize, ModelError> {
11 let prediction_shape = graph.value(prediction)?.shape().to_vec();
12 let target_shape = graph.value(target)?.shape().to_vec();
13 if prediction_shape != target_shape {
14 return Err(ModelError::PredictionTargetShapeMismatch {
15 prediction: prediction_shape,
16 target: target_shape,
17 });
18 }
19
20 let element_count = graph.value(prediction)?.len();
21 if element_count == 0 {
22 return Err(ModelError::EmptyLossTensor);
23 }
24 Ok(element_count)
25}
26
27fn abs_node(graph: &mut Graph, input: NodeId) -> Result<NodeId, ModelError> {
28 let zero = graph.constant(Tensor::scalar(0.0));
29 let neg_input = graph.sub(zero, input)?;
30 let positive = graph.relu(input)?;
31 let negative = graph.relu(neg_input)?;
32 graph.add(positive, negative).map_err(Into::into)
33}
34
35pub fn mse_loss(
37 graph: &mut Graph,
38 prediction: NodeId,
39 target: NodeId,
40) -> Result<NodeId, ModelError> {
41 let element_count = validate_loss_inputs(graph, prediction, target)?;
42
43 let diff = graph.sub(prediction, target)?;
44 let sq = graph.mul(diff, diff)?;
45 let sum = graph.sum(sq)?;
46 let inv_count = graph.constant(Tensor::scalar(1.0 / element_count as f32));
47 graph.mul(sum, inv_count).map_err(Into::into)
48}
49
50pub fn mae_loss(
52 graph: &mut Graph,
53 prediction: NodeId,
54 target: NodeId,
55) -> Result<NodeId, ModelError> {
56 let element_count = validate_loss_inputs(graph, prediction, target)?;
57
58 let diff = graph.sub(prediction, target)?;
59 let abs = abs_node(graph, diff)?;
60 let sum = graph.sum(abs)?;
61 let inv_count = graph.constant(Tensor::scalar(1.0 / element_count as f32));
62 graph.mul(sum, inv_count).map_err(Into::into)
63}
64
/// Huber loss: quadratic near zero, linear in the tails, averaged over all
/// elements.
///
/// With a = |prediction - target| the relu primitives give:
///   excess  = relu(a - delta)   (how far the error exceeds delta; 0 inside)
///   clipped = a - excess        (= min(a, delta))
/// and the per-element loss 0.5 * clipped^2 + delta * excess equals the
/// standard piecewise Huber definition on both branches.
///
/// # Errors
/// Returns [`ModelError::InvalidHuberDelta`] for a non-finite or
/// non-positive `delta`; otherwise propagates validation and
/// graph-construction errors.
pub fn huber_loss(
    graph: &mut Graph,
    prediction: NodeId,
    target: NodeId,
    delta: f32,
) -> Result<NodeId, ModelError> {
    if !delta.is_finite() || delta <= 0.0 {
        return Err(ModelError::InvalidHuberDelta { delta });
    }
    let element_count = validate_loss_inputs(graph, prediction, target)?;

    let diff = graph.sub(prediction, target)?;
    let abs = abs_node(graph, diff)?;
    let delta_node = graph.constant(Tensor::scalar(delta));
    // excess = max(|diff| - delta, 0); clipped = min(|diff|, delta).
    let abs_minus_delta = graph.sub(abs, delta_node)?;
    let excess = graph.relu(abs_minus_delta)?;
    let clipped = graph.sub(abs, excess)?;

    // 0.5 * clipped^2 covers the quadratic region; delta * excess covers
    // the linear region. Exactly one term is "active" per element.
    let clipped_sq = graph.mul(clipped, clipped)?;
    let half = graph.constant(Tensor::scalar(0.5));
    let quadratic = graph.mul(clipped_sq, half)?;
    let linear = graph.mul(excess, delta_node)?;
    let per_element = graph.add(quadratic, linear)?;
    let sum = graph.sum(per_element)?;
    let inv_count = graph.constant(Tensor::scalar(1.0 / element_count as f32));
    graph.mul(sum, inv_count).map_err(Into::into)
}
94
95pub fn hinge_loss(
98 graph: &mut Graph,
99 prediction: NodeId,
100 target: NodeId,
101 margin: f32,
102) -> Result<NodeId, ModelError> {
103 if !margin.is_finite() || margin <= 0.0 {
104 return Err(ModelError::InvalidHingeMargin { margin });
105 }
106 let element_count = validate_loss_inputs(graph, prediction, target)?;
107
108 let product = graph.mul(prediction, target)?;
109 let margin_node = graph.constant(Tensor::scalar(margin));
110 let raw = graph.sub(margin_node, product)?;
111 let positive = graph.relu(raw)?;
112 let sum = graph.sum(positive)?;
113 let inv_count = graph.constant(Tensor::scalar(1.0 / element_count as f32));
114 graph.mul(sum, inv_count).map_err(Into::into)
115}
116
/// Binary cross-entropy: -mean(t*log(p) + (1-t)*log(1-p)).
///
/// Predictions are clamped into [eps, 1 - eps] using only relu/add/sub
/// graph primitives so the clamp stays inside the autograd graph:
///   max(p, eps)     = relu(p - eps) + eps
///   min(q, 1 - eps) = q - relu(q - (1 - eps))
///
/// # Errors
/// Propagates shape-mismatch / empty-tensor validation errors and any
/// graph-construction failures.
pub fn bce_loss(
    graph: &mut Graph,
    prediction: NodeId,
    target: NodeId,
) -> Result<NodeId, ModelError> {
    let element_count = validate_loss_inputs(graph, prediction, target)?;

    let eps = 1e-7_f32;
    let eps_node = graph.constant(Tensor::scalar(eps));
    let one_node = graph.constant(Tensor::scalar(1.0));

    // Lower clamp: pred_above_eps = max(prediction, eps).
    let shifted_low = graph.sub(prediction, eps_node)?;
    let positive_part = graph.relu(shifted_low)?;
    let pred_above_eps = graph.add(positive_part, eps_node)?;

    // Upper clamp: pred_safe = min(pred_above_eps, 1 - eps).
    let one_minus_eps_node = graph.constant(Tensor::scalar(1.0 - eps));
    let over = graph.sub(pred_above_eps, one_minus_eps_node)?;
    let excess = graph.relu(over)?;
    let pred_safe = graph.sub(pred_above_eps, excess)?;

    let log_pred = graph.log(pred_safe)?;

    // 1 - pred_safe is already >= eps because pred_safe <= 1 - eps; the
    // extra eps keeps the log argument strictly positive (slightly
    // asymmetric with the p branch, but safe: the argument stays <= 1).
    let one_minus_pred = graph.sub(one_node, pred_safe)?;
    let one_minus_pred_safe = graph.add(one_minus_pred, eps_node)?;
    let log_one_minus_pred = graph.log(one_minus_pred_safe)?;

    // Assemble -mean(t*log(p) + (1-t)*log(1-p)).
    let term1 = graph.mul(target, log_pred)?;
    let one_minus_t = graph.sub(one_node, target)?;
    let term2 = graph.mul(one_minus_t, log_one_minus_pred)?;
    let combined = graph.add(term1, term2)?;
    let sum = graph.sum(combined)?;
    let neg_sum = graph.neg(sum)?;
    let inv_count = graph.constant(Tensor::scalar(1.0 / element_count as f32));
    graph.mul(neg_sum, inv_count).map_err(Into::into)
}
160
161pub fn nll_loss(
167 graph: &mut Graph,
168 log_probs: NodeId,
169 targets: NodeId,
170) -> Result<NodeId, ModelError> {
171 let lp_shape = graph.value(log_probs)?.shape().to_vec();
172 let t_shape = graph.value(targets)?.shape().to_vec();
173
174 if lp_shape.len() != 2 {
175 return Err(ModelError::InvalidInputShape {
176 expected_features: 0,
177 got: lp_shape,
178 });
179 }
180 if t_shape.len() != 2 || t_shape[1] != 1 {
181 return Err(ModelError::PredictionTargetShapeMismatch {
182 prediction: lp_shape.clone(),
183 target: t_shape,
184 });
185 }
186 let batch_size = lp_shape[0];
187 let num_classes = lp_shape[1];
188 if batch_size == 0 {
189 return Err(ModelError::EmptyLossTensor);
190 }
191
192 let lp_data = graph.value(log_probs)?.data().to_vec();
193 let t_data = graph.value(targets)?.data().to_vec();
194
195 let mut selected = vec![0.0f32; batch_size];
196 for i in 0..batch_size {
197 let class_idx = t_data[i] as usize;
198 if class_idx >= num_classes {
199 return Err(ModelError::InvalidDatasetRecordValue {
200 line: i,
201 field: "nll_target",
202 index: 0,
203 reason: "class index out of range",
204 });
205 }
206 selected[i] = lp_data[i * num_classes + class_idx];
207 }
208
209 let selected_node = graph.constant(Tensor::from_vec(vec![batch_size], selected)?);
210 let sum = graph.sum(selected_node)?;
211 let neg_sum = graph.neg(sum)?;
212 let inv_batch = graph.constant(Tensor::scalar(1.0 / batch_size as f32));
213 graph.mul(neg_sum, inv_batch).map_err(Into::into)
214}
215
216pub fn cross_entropy_loss(
222 graph: &mut Graph,
223 logits: NodeId,
224 targets: NodeId,
225) -> Result<NodeId, ModelError> {
226 let shape = graph.value(logits)?.shape().to_vec();
227 if shape.len() != 2 {
228 return Err(ModelError::InvalidInputShape {
229 expected_features: 0,
230 got: shape,
231 });
232 }
233 let batch_size = shape[0];
234 let num_classes = shape[1];
235
236 if batch_size == 0 {
237 return Err(ModelError::EmptyLossTensor);
238 }
239
240 let logits_data = graph.value(logits)?.data().to_vec();
241 let t_data = graph.value(targets)?.data().to_vec();
242 let t_shape = graph.value(targets)?.shape().to_vec();
243
244 if t_shape.len() != 2 || t_shape[1] != 1 || t_shape[0] != batch_size {
245 return Err(ModelError::PredictionTargetShapeMismatch {
246 prediction: shape.clone(),
247 target: t_shape,
248 });
249 }
250
251 let mut log_probs = vec![0.0f32; batch_size * num_classes];
254 for b in 0..batch_size {
255 let row = &logits_data[b * num_classes..(b + 1) * num_classes];
256 let max_val = row.iter().copied().fold(f32::NEG_INFINITY, f32::max);
257 let sum_exp: f32 = row.iter().map(|&v| (v - max_val).exp()).sum();
258 let log_sum_exp = max_val + sum_exp.ln();
259 for c in 0..num_classes {
260 log_probs[b * num_classes + c] = row[c] - log_sum_exp;
261 }
262 }
263
264 let mut neg_sum = 0.0f32;
266 for i in 0..batch_size {
267 let class_idx = t_data[i] as usize;
268 if class_idx >= num_classes {
269 return Err(ModelError::InvalidDatasetRecordValue {
270 line: i,
271 field: "cross_entropy_target",
272 index: 0,
273 reason: "class index out of range",
274 });
275 }
276 neg_sum -= log_probs[i * num_classes + class_idx];
277 }
278
279 let loss_val = neg_sum / batch_size as f32;
280 let loss_node = graph.constant(Tensor::scalar(loss_val));
281 Ok(loss_node)
282}
283
/// Binary focal loss: mean of -alpha * (1 - p_t)^gamma * ln(p_t), where
/// p_t is the predicted probability of the true class.
///
/// Computed on the host from the current tensor values and emitted as a
/// constant node.
/// NOTE(review): returning a constant presumably blocks gradient flow into
/// `prediction` — confirm this loss is intended for evaluation only.
/// NOTE(review): `alpha` and `gamma` are not validated; negative or
/// non-finite values silently produce meaningless losses.
///
/// # Errors
/// Propagates shape-mismatch / empty-tensor validation errors.
pub fn focal_loss(
    graph: &mut Graph,
    prediction: NodeId,
    target: NodeId,
    alpha: f32,
    gamma: f32,
) -> Result<NodeId, ModelError> {
    let element_count = validate_loss_inputs(graph, prediction, target)?;
    let eps = 1e-7_f32;

    let pred_data = graph.value(prediction)?.data().to_vec();
    let target_data = graph.value(target)?.data().to_vec();

    let mut loss_sum = 0.0f32;
    for i in 0..element_count {
        // Clamp away from {0, 1} so ln() stays finite.
        let p = pred_data[i].clamp(eps, 1.0 - eps);
        let t = target_data[i];
        // Labels are treated as binary with a 0.5 threshold.
        let pt = if t > 0.5 { p } else { 1.0 - p };
        loss_sum += -alpha * (1.0 - pt).powf(gamma) * pt.ln();
    }

    let loss_val = loss_sum / element_count as f32;
    Ok(graph.constant(Tensor::scalar(loss_val)))
}
312
313pub fn dice_loss(
318 graph: &mut Graph,
319 prediction: NodeId,
320 target: NodeId,
321 smooth: f32,
322) -> Result<NodeId, ModelError> {
323 let element_count = validate_loss_inputs(graph, prediction, target)?;
324
325 let pred_data = graph.value(prediction)?.data().to_vec();
326 let target_data = graph.value(target)?.data().to_vec();
327
328 let mut intersection = 0.0f32;
329 let mut pred_sum = 0.0f32;
330 let mut target_sum = 0.0f32;
331 for i in 0..element_count {
332 intersection += pred_data[i] * target_data[i];
333 pred_sum += pred_data[i];
334 target_sum += target_data[i];
335 }
336
337 let dice = (2.0 * intersection + smooth) / (pred_sum + target_sum + smooth);
338 Ok(graph.constant(Tensor::scalar(1.0 - dice)))
339}
340
341pub fn triplet_loss(
346 graph: &mut Graph,
347 anchor: NodeId,
348 positive: NodeId,
349 negative: NodeId,
350 margin: f32,
351) -> Result<NodeId, ModelError> {
352 let a_shape = graph.value(anchor)?.shape().to_vec();
353 let p_shape = graph.value(positive)?.shape().to_vec();
354 let n_shape = graph.value(negative)?.shape().to_vec();
355 if a_shape != p_shape || a_shape != n_shape {
356 return Err(ModelError::PredictionTargetShapeMismatch {
357 prediction: a_shape,
358 target: p_shape,
359 });
360 }
361 if a_shape.len() != 2 || a_shape[0] == 0 {
362 return Err(ModelError::EmptyLossTensor);
363 }
364 let batch = a_shape[0];
365 let dim = a_shape[1];
366
367 let a_data = graph.value(anchor)?.data().to_vec();
368 let p_data = graph.value(positive)?.data().to_vec();
369 let n_data = graph.value(negative)?.data().to_vec();
370
371 let mut loss_sum = 0.0f32;
372 for b in 0..batch {
373 let mut dp = 0.0f32;
374 let mut dn = 0.0f32;
375 for d in 0..dim {
376 let idx = b * dim + d;
377 dp += (a_data[idx] - p_data[idx]).powi(2);
378 dn += (a_data[idx] - n_data[idx]).powi(2);
379 }
380 loss_sum += (dp.sqrt() - dn.sqrt() + margin).max(0.0);
381 }
382
383 Ok(graph.constant(Tensor::scalar(loss_sum / batch as f32)))
384}
385
386pub fn contrastive_loss(
391 graph: &mut Graph,
392 x1: NodeId,
393 x2: NodeId,
394 label: NodeId,
395 margin: f32,
396) -> Result<NodeId, ModelError> {
397 let s1 = graph.value(x1)?.shape().to_vec();
398 let s2 = graph.value(x2)?.shape().to_vec();
399 if s1 != s2 || s1.len() != 2 || s1[0] == 0 {
400 return Err(ModelError::PredictionTargetShapeMismatch {
401 prediction: s1,
402 target: s2,
403 });
404 }
405 let batch = s1[0];
406 let dim = s1[1];
407
408 let x1d = graph.value(x1)?.data().to_vec();
409 let x2d = graph.value(x2)?.data().to_vec();
410 let ld = graph.value(label)?.data().to_vec();
411
412 let mut loss_sum = 0.0f32;
413 for b in 0..batch {
414 let mut dist_sq = 0.0f32;
415 for d in 0..dim {
416 let idx = b * dim + d;
417 dist_sq += (x1d[idx] - x2d[idx]).powi(2);
418 }
419 let dist = dist_sq.sqrt();
420 let y = ld[b];
421 loss_sum += y * dist_sq + (1.0 - y) * (margin - dist).max(0.0).powi(2);
422 }
423
424 Ok(graph.constant(Tensor::scalar(loss_sum / batch as f32)))
425}
426
427pub fn cosine_embedding_loss(
431 graph: &mut Graph,
432 x1: NodeId,
433 x2: NodeId,
434 label: NodeId,
435 margin: f32,
436) -> Result<NodeId, ModelError> {
437 let s1 = graph.value(x1)?.shape().to_vec();
438 let s2 = graph.value(x2)?.shape().to_vec();
439 if s1 != s2 || s1.len() != 2 || s1[0] == 0 {
440 return Err(ModelError::PredictionTargetShapeMismatch {
441 prediction: s1,
442 target: s2,
443 });
444 }
445 let batch = s1[0];
446 let dim = s1[1];
447
448 let x1d = graph.value(x1)?.data().to_vec();
449 let x2d = graph.value(x2)?.data().to_vec();
450 let ld = graph.value(label)?.data().to_vec();
451
452 let mut loss_sum = 0.0f32;
453 for b in 0..batch {
454 let mut dot = 0.0f32;
455 let mut n1 = 0.0f32;
456 let mut n2 = 0.0f32;
457 for d in 0..dim {
458 let idx = b * dim + d;
459 dot += x1d[idx] * x2d[idx];
460 n1 += x1d[idx] * x1d[idx];
461 n2 += x2d[idx] * x2d[idx];
462 }
463 let cos = dot / (n1.sqrt() * n2.sqrt()).max(1e-8);
464 let y = ld[b];
465 if y > 0.0 {
466 loss_sum += 1.0 - cos;
467 } else {
468 loss_sum += (cos - margin).max(0.0);
469 }
470 }
471
472 Ok(graph.constant(Tensor::scalar(loss_sum / batch as f32)))
473}
474
475pub fn label_smoothing_cross_entropy(
480 graph: &mut Graph,
481 logits: NodeId,
482 targets: NodeId,
483 smoothing: f32,
484) -> Result<NodeId, ModelError> {
485 let shape = graph.value(logits)?.shape().to_vec();
486 if shape.len() != 2 || shape[0] == 0 {
487 return Err(ModelError::EmptyLossTensor);
488 }
489 let batch_size = shape[0];
490 let num_classes = shape[1];
491
492 let logits_data = graph.value(logits)?.data().to_vec();
493 let t_data = graph.value(targets)?.data().to_vec();
494
495 let smooth_val = smoothing / num_classes as f32;
496 let confidence = 1.0 - smoothing;
497
498 let mut total_loss = 0.0f32;
499 for b in 0..batch_size {
500 let row = &logits_data[b * num_classes..(b + 1) * num_classes];
501 let max_val = row.iter().copied().fold(f32::NEG_INFINITY, f32::max);
502 let sum_exp: f32 = row.iter().map(|&v| (v - max_val).exp()).sum();
503 let log_sum_exp = max_val + sum_exp.ln();
504
505 let class_idx = t_data[b] as usize;
506 for c in 0..num_classes {
507 let log_prob = row[c] - log_sum_exp;
508 let target_prob = if c == class_idx {
509 confidence + smooth_val
510 } else {
511 smooth_val
512 };
513 total_loss -= target_prob * log_prob;
514 }
515 }
516
517 Ok(graph.constant(Tensor::scalar(total_loss / batch_size as f32)))
518}
519
/// Connectionist Temporal Classification loss (forward algorithm only).
///
/// `log_probs` is expected as [T, batch, num_classes] (time-major);
/// `blank` is the index of the CTC blank symbol. The alpha recursion is
/// computed on the host and the mean negative log-likelihood is emitted
/// as a constant node.
///
/// NOTE(review): because the result is a constant node, gradients
/// presumably do not flow back into `log_probs` — confirm this is
/// intended.
/// NOTE(review): several host-side indexes are unchecked and can panic:
/// an input length of 0 leaves `alpha` empty before `alpha[0]` is written;
/// label values are not validated against `num_classes`; and
/// `input_lengths` / `target_lengths` larger than the tensor dimensions
/// read out of bounds. Callers must supply consistent lengths.
///
/// # Errors
/// [`ModelError::InvalidInputShape`] when `log_probs` is not 3-D.
pub fn ctc_loss(
    graph: &mut Graph,
    log_probs: NodeId,
    targets: NodeId,
    input_lengths: NodeId,
    target_lengths: NodeId,
    blank: usize,
) -> Result<NodeId, ModelError> {
    let lp_shape = graph.value(log_probs)?.shape().to_vec();
    if lp_shape.len() != 3 {
        return Err(ModelError::InvalidInputShape {
            expected_features: 0,
            got: lp_shape,
        });
    }
    let _t_max = lp_shape[0];
    let batch = lp_shape[1];
    let num_classes = lp_shape[2];

    let lp_data = graph.value(log_probs)?.data().to_vec();
    let tgt_data = graph.value(targets)?.data().to_vec();
    let il_data = graph.value(input_lengths)?.data().to_vec();
    let tl_data = graph.value(target_lengths)?.data().to_vec();

    // Per-batch stride into the (possibly padded) target matrix; a 1-D
    // target tensor is treated as batch-major flattened rows.
    let tgt_shape = graph.value(targets)?.shape().to_vec();
    let s_max = if tgt_shape.len() >= 2 {
        tgt_shape[1]
    } else {
        tgt_shape[0] / batch
    };

    let mut total_loss = 0.0f32;

    for b in 0..batch {
        let input_len = il_data[b] as usize;
        let target_len = tl_data[b] as usize;

        // Extended label sequence: blanks interleaved around each symbol,
        // e.g. [a, b] -> [_, a, _, b, _].
        let label_len = 2 * target_len + 1;
        let mut labels = vec![blank; label_len];
        for s in 0..target_len {
            labels[2 * s + 1] = tgt_data[b * s_max + s] as usize;
        }

        // alpha[t * label_len + s]: log-probability of all alignments that
        // reach extended-label position s after consuming t frames.
        let mut alpha = vec![f32::NEG_INFINITY; label_len * input_len];
        // t = 0 may start at the leading blank or at the first symbol.
        alpha[0] = lp_data[b * num_classes + labels[0]];
        if label_len > 1 {
            alpha[1] = lp_data[b * num_classes + labels[1]];
        }

        for t in 1..input_len {
            for s in 0..label_len {
                let lp_idx = t * batch * num_classes + b * num_classes + labels[s];
                let log_p = lp_data[lp_idx];
                // Transitions: stay on s, advance from s-1, or skip s-2
                // (skipping allowed only between distinct non-blanks).
                let mut sum = alpha[(t - 1) * label_len + s];
                if s > 0 {
                    sum = log_sum_exp_pair(sum, alpha[(t - 1) * label_len + s - 1]);
                }
                if s > 1 && labels[s] != blank && labels[s] != labels[s - 2] {
                    sum = log_sum_exp_pair(sum, alpha[(t - 1) * label_len + s - 2]);
                }
                alpha[t * label_len + s] = sum + log_p;
            }
        }

        // Valid endings: the final blank or the final symbol.
        let last_t = input_len - 1;
        let log_likelihood = log_sum_exp_pair(
            alpha[last_t * label_len + label_len - 1],
            if label_len >= 2 {
                alpha[last_t * label_len + label_len - 2]
            } else {
                f32::NEG_INFINITY
            },
        );
        total_loss -= log_likelihood;
    }

    Ok(graph.constant(Tensor::scalar(total_loss / batch as f32)))
}
605
606pub fn smooth_l1_loss(
616 graph: &mut Graph,
617 prediction: NodeId,
618 target: NodeId,
619 beta: f32,
620) -> Result<NodeId, ModelError> {
621 if !beta.is_finite() || beta <= 0.0 {
622 return Err(ModelError::InvalidHuberDelta { delta: beta });
623 }
624 let element_count = validate_loss_inputs(graph, prediction, target)?;
625
626 let pred_data = graph.value(prediction)?.data().to_vec();
627 let target_data = graph.value(target)?.data().to_vec();
628
629 let mut loss_sum = 0.0f32;
630 for i in 0..element_count {
631 let x = (pred_data[i] - target_data[i]).abs();
632 if x < beta {
633 loss_sum += 0.5 * x * x / beta;
634 } else {
635 loss_sum += x - 0.5 * beta;
636 }
637 }
638
639 Ok(graph.constant(Tensor::scalar(loss_sum / element_count as f32)))
640}
641
642pub fn kl_div_loss(
652 graph: &mut Graph,
653 log_prediction: NodeId,
654 target: NodeId,
655) -> Result<NodeId, ModelError> {
656 let element_count = validate_loss_inputs(graph, log_prediction, target)?;
657
658 let log_pred_data = graph.value(log_prediction)?.data().to_vec();
659 let target_data = graph.value(target)?.data().to_vec();
660
661 let mut loss_sum = 0.0f32;
662 for i in 0..element_count {
663 let t = target_data[i];
664 if t > 0.0 {
665 loss_sum += t * (t.ln() - log_pred_data[i]);
666 }
667 }
668
669 Ok(graph.constant(Tensor::scalar(loss_sum / element_count as f32)))
670}
671
/// Numerically stable log(exp(a) + exp(b)).
///
/// Negative infinity acts as the additive identity in log-space, letting
/// the CTC forward recursion start from "impossible" states.
fn log_sum_exp_pair(a: f32, b: f32) -> f32 {
    if a == f32::NEG_INFINITY {
        b
    } else if b == f32::NEG_INFINITY {
        a
    } else {
        // Factor out the larger term so the exponentials cannot overflow.
        let hi = a.max(b);
        hi + ((a - hi).exp() + (b - hi).exp()).ln()
    }
}
682
/// Knowledge-distillation loss: alpha * T^2 * KL(teacher || student) at
/// temperature T, plus (1 - alpha) * cross-entropy against hard labels.
///
/// NOTE(review): the KL term uses `graph.mean` over every element rather
/// than sum-over-classes / mean-over-batch ("batchmean"); confirm this
/// matches the intended scaling before comparing against other
/// frameworks.
/// NOTE(review): `temperature` is not validated — zero divides by zero;
/// `alpha` is not clamped to [0, 1].
/// NOTE(review): gradient flow through the hard-loss term depends on
/// `cross_entropy_loss` building differentiable graph ops — verify.
pub fn distillation_loss(
    graph: &mut Graph,
    student: NodeId,
    teacher: NodeId,
    labels: NodeId,
    temperature: f32,
    alpha: f32,
) -> Result<NodeId, ModelError> {
    let t_scalar = graph.constant(Tensor::scalar(temperature));
    let t2_scalar = graph.constant(Tensor::scalar(temperature * temperature));

    // Soften both distributions by the temperature.
    let s_scaled = graph.div(student, t_scalar)?;
    let t_scaled = graph.div(teacher, t_scalar)?;

    let s_log_softmax = graph.log_softmax(s_scaled)?;
    let t_softmax = graph.softmax(t_scaled)?;

    // Pointwise KL(teacher || student): t * (log t - log s).
    let t_log = graph.log(t_softmax)?;
    let kl_pointwise = graph.sub(t_log, s_log_softmax)?;
    let kl_weighted = graph.mul(t_softmax, kl_pointwise)?;
    let kl_sum = graph.mean(kl_weighted)?;
    // The T^2 factor restores gradient magnitude after temperature scaling.
    let soft_loss = graph.mul(kl_sum, t2_scalar)?;

    let hard_loss = cross_entropy_loss(graph, student, labels)?;

    // Blend the soft (distillation) and hard (label) objectives.
    let alpha_node = graph.constant(Tensor::scalar(alpha));
    let one_minus_alpha = graph.constant(Tensor::scalar(1.0 - alpha));

    let weighted_soft = graph.mul(soft_loss, alpha_node)?;
    let weighted_hard = graph.mul(hard_loss, one_minus_alpha)?;

    graph.add(weighted_soft, weighted_hard).map_err(Into::into)
}