aprender/loss/loss.rs
1
2/// `InfoNCE` (Noise Contrastive Estimation) loss for contrastive learning.
3///
4/// Also known as NT-Xent (Normalized Temperature-scaled Cross Entropy).
5///
6/// ```text
7/// L = -log(exp(sim(z_i, z_j)/τ) / Σ_k exp(sim(z_i, z_k)/τ))
8/// ```
9///
10/// # Arguments
11///
12/// * `anchor` - Anchor embedding
13/// * `positive` - Positive example embedding
14/// * `negatives` - Slice of negative example embeddings
15/// * `temperature` - Temperature scaling parameter (typically 0.07-0.5)
16///
17/// # Returns
18///
19/// The `InfoNCE` loss value
20///
21/// # Example
22///
23/// ```
24/// use aprender::loss::info_nce_loss;
25/// use aprender::primitives::Vector;
26///
27/// let anchor = Vector::from_slice(&[1.0, 0.0, 0.0]);
28/// let positive = Vector::from_slice(&[0.9, 0.1, 0.0]);
29/// let negatives = vec![
30/// Vector::from_slice(&[0.0, 1.0, 0.0]),
31/// Vector::from_slice(&[0.0, 0.0, 1.0]),
32/// ];
33///
34/// let loss = info_nce_loss(&anchor, &positive, &negatives, 0.1);
35/// assert!(loss >= 0.0);
36/// ```
37#[must_use]
38pub fn info_nce_loss(
39 anchor: &Vector<f32>,
40 positive: &Vector<f32>,
41 negatives: &[Vector<f32>],
42 temperature: f32,
43) -> f32 {
44 assert_eq!(
45 anchor.len(),
46 positive.len(),
47 "Anchor and positive must have same dimension"
48 );
49 for neg in negatives {
50 assert_eq!(
51 anchor.len(),
52 neg.len(),
53 "All embeddings must have same dimension"
54 );
55 }
56 assert!(temperature > 0.0, "Temperature must be positive");
57
58 // Compute similarity with positive
59 let sim_pos = cosine_similarity(anchor, positive) / temperature;
60
61 // Compute log-sum-exp over all (positive + negatives)
62 let mut max_sim = sim_pos;
63 for neg in negatives {
64 let sim_neg = cosine_similarity(anchor, neg) / temperature;
65 max_sim = max_sim.max(sim_neg);
66 }
67
68 // Numerically stable log-sum-exp
69 let mut sum_exp = (sim_pos - max_sim).exp();
70 for neg in negatives {
71 let sim_neg = cosine_similarity(anchor, neg) / temperature;
72 sum_exp += (sim_neg - max_sim).exp();
73 }
74
75 // -log(exp(sim_pos) / sum_exp) = -sim_pos + log(sum_exp)
76 -sim_pos + max_sim + sum_exp.ln()
77}
78
79/// Focal loss for class imbalance (spec: more-learning-specs.md §18).
80///
81/// Down-weights easy examples, focuses on hard examples:
82///
83/// ```text
84/// FL(p) = -α * (1 - p)^γ * log(p)
85/// ```
86///
87/// # Arguments
88///
89/// * `predictions` - Predicted probabilities (after sigmoid/softmax)
90/// * `targets` - Binary targets (0 or 1)
91/// * `alpha` - Balancing factor (typically 0.25 for rare class)
92/// * `gamma` - Focusing parameter (typically 2.0)
93///
94/// # Returns
95///
96/// The focal loss value
97///
98/// # Example
99///
100/// ```
101/// use aprender::loss::focal_loss;
102/// use aprender::primitives::Vector;
103///
104/// let predictions = Vector::from_slice(&[0.9, 0.1, 0.8]);
105/// let targets = Vector::from_slice(&[1.0, 0.0, 1.0]);
106///
107/// let loss = focal_loss(&predictions, &targets, 0.25, 2.0);
108/// assert!(loss >= 0.0);
109/// ```
110#[must_use]
111pub fn focal_loss(predictions: &Vector<f32>, targets: &Vector<f32>, alpha: f32, gamma: f32) -> f32 {
112 assert_eq!(
113 predictions.len(),
114 targets.len(),
115 "Predictions and targets must have same length"
116 );
117
118 let n = predictions.len() as f32;
119 let mut sum = 0.0;
120
121 for i in 0..predictions.len() {
122 let p = predictions[i].clamp(1e-7, 1.0 - 1e-7);
123 let t = targets[i];
124
125 // For positive class (t=1): -α * (1-p)^γ * log(p)
126 // For negative class (t=0): -(1-α) * p^γ * log(1-p)
127 let loss = if t > 0.5 {
128 -alpha * (1.0 - p).powf(gamma) * p.ln()
129 } else {
130 -(1.0 - alpha) * p.powf(gamma) * (1.0 - p).ln()
131 };
132
133 sum += loss;
134 }
135
136 sum / n
137}
138
139/// KL Divergence loss between two probability distributions.
140///
141/// ```text
142/// KL(P || Q) = Σ P(x) * log(P(x) / Q(x))
143/// ```
144///
145/// # Arguments
146///
147/// * `p` - True distribution (targets)
148/// * `q` - Predicted distribution
149///
150/// # Returns
151///
152/// The KL divergence (always >= 0)
153///
154/// # Example
155///
156/// ```
157/// use aprender::loss::kl_divergence;
158/// use aprender::primitives::Vector;
159///
160/// let p = Vector::from_slice(&[0.5, 0.3, 0.2]);
161/// let q = Vector::from_slice(&[0.4, 0.4, 0.2]);
162///
163/// let kl = kl_divergence(&p, &q);
164/// assert!(kl >= 0.0);
165/// ```
166#[must_use]
167pub fn kl_divergence(p: &Vector<f32>, q: &Vector<f32>) -> f32 {
168 assert_eq!(p.len(), q.len(), "Distributions must have same length");
169
170 let mut sum = 0.0;
171 for i in 0..p.len() {
172 if p[i] > 1e-10 {
173 let q_safe = q[i].max(1e-10);
174 sum += p[i] * (p[i] / q_safe).ln();
175 }
176 }
177
178 sum
179}
180
181/// Triplet loss function (struct wrapper).
182#[derive(Debug, Clone, Copy)]
183pub struct TripletLoss {
184 margin: f32,
185}
186
187impl TripletLoss {
188 /// Creates a new Triplet loss with the given margin.
189 #[must_use]
190 pub fn new(margin: f32) -> Self {
191 Self { margin }
192 }
193
194 /// Returns the margin parameter.
195 #[must_use]
196 pub fn margin(&self) -> f32 {
197 self.margin
198 }
199
200 /// Compute triplet loss for given embeddings.
201 #[must_use]
202 pub fn compute_triplet(
203 &self,
204 anchor: &Vector<f32>,
205 positive: &Vector<f32>,
206 negative: &Vector<f32>,
207 ) -> f32 {
208 triplet_loss(anchor, positive, negative, self.margin)
209 }
210}
211
212/// Focal loss function (struct wrapper).
213#[derive(Debug, Clone, Copy)]
214pub struct FocalLoss {
215 alpha: f32,
216 gamma: f32,
217}
218
219impl FocalLoss {
220 /// Creates a new Focal loss with given parameters.
221 ///
222 /// # Arguments
223 ///
224 /// * `alpha` - Balancing factor for rare class (typically 0.25)
225 /// * `gamma` - Focusing parameter (typically 2.0)
226 #[must_use]
227 pub fn new(alpha: f32, gamma: f32) -> Self {
228 Self { alpha, gamma }
229 }
230
231 /// Returns the alpha parameter.
232 #[must_use]
233 pub fn alpha(&self) -> f32 {
234 self.alpha
235 }
236
237 /// Returns the gamma parameter.
238 #[must_use]
239 pub fn gamma(&self) -> f32 {
240 self.gamma
241 }
242}
243
244impl Loss for FocalLoss {
245 fn compute(&self, y_pred: &Vector<f32>, y_true: &Vector<f32>) -> f32 {
246 focal_loss(y_pred, y_true, self.alpha, self.gamma)
247 }
248
249 fn name(&self) -> &'static str {
250 "Focal"
251 }
252}
253
254/// `InfoNCE` / NT-Xent loss function (struct wrapper).
255#[derive(Debug, Clone, Copy)]
256pub struct InfoNCELoss {
257 temperature: f32,
258}
259
260impl InfoNCELoss {
261 /// Creates a new `InfoNCE` loss with given temperature.
262 ///
263 /// # Arguments
264 ///
265 /// * `temperature` - Temperature scaling (typically 0.07-0.5)
266 #[must_use]
267 pub fn new(temperature: f32) -> Self {
268 Self { temperature }
269 }
270
271 /// Returns the temperature parameter.
272 #[must_use]
273 pub fn temperature(&self) -> f32 {
274 self.temperature
275 }
276
277 /// Compute `InfoNCE` loss for contrastive learning.
278 #[must_use]
279 pub fn compute_contrastive(
280 &self,
281 anchor: &Vector<f32>,
282 positive: &Vector<f32>,
283 negatives: &[Vector<f32>],
284 ) -> f32 {
285 info_nce_loss(anchor, positive, negatives, self.temperature)
286 }
287}
288
289/// Dice loss for segmentation tasks.
290///
291/// Measures overlap between predicted and ground truth masks:
292/// ```text
293/// Dice = 2 * |X ∩ Y| / (|X| + |Y|)
294/// Loss = 1 - Dice
295/// ```
296///
297/// # Arguments
298/// * `y_pred` - Predicted probabilities (0-1)
299/// * `y_true` - Ground truth binary mask (0 or 1)
300/// * `smooth` - Smoothing factor to avoid division by zero
301#[must_use]
302pub fn dice_loss(y_pred: &Vector<f32>, y_true: &Vector<f32>, smooth: f32) -> f32 {
303 assert_eq!(y_pred.len(), y_true.len());
304
305 let mut intersection = 0.0;
306 let mut pred_sum = 0.0;
307 let mut true_sum = 0.0;
308
309 for i in 0..y_pred.len() {
310 intersection += y_pred[i] * y_true[i];
311 pred_sum += y_pred[i];
312 true_sum += y_true[i];
313 }
314
315 let dice = (2.0 * intersection + smooth) / (pred_sum + true_sum + smooth);
316 1.0 - dice
317}
318
319/// Hinge loss for SVM-style margin classification.
320///
321/// ```text
322/// L = max(0, margin - y_true * y_pred)
323/// ```
324///
325/// # Arguments
326/// * `y_pred` - Predicted scores (raw, not probabilities)
327/// * `y_true` - True labels (-1 or 1)
328/// * `margin` - Margin threshold (typically 1.0)
329#[must_use]
330pub fn hinge_loss(y_pred: &Vector<f32>, y_true: &Vector<f32>, margin: f32) -> f32 {
331 assert_eq!(y_pred.len(), y_true.len());
332
333 let mut sum = 0.0;
334 for i in 0..y_pred.len() {
335 let loss = (margin - y_true[i] * y_pred[i]).max(0.0);
336 sum += loss;
337 }
338 sum / y_pred.len() as f32
339}
340
341/// Squared hinge loss (smoother gradient).
342#[must_use]
343pub fn squared_hinge_loss(y_pred: &Vector<f32>, y_true: &Vector<f32>, margin: f32) -> f32 {
344 assert_eq!(y_pred.len(), y_true.len());
345
346 let mut sum = 0.0;
347 for i in 0..y_pred.len() {
348 let loss = (margin - y_true[i] * y_pred[i]).max(0.0);
349 sum += loss * loss;
350 }
351 sum / y_pred.len() as f32
352}
353
354/// Dice loss struct wrapper.
355#[derive(Debug, Clone, Copy)]
356pub struct DiceLoss {
357 smooth: f32,
358}
359
360impl DiceLoss {
361 #[must_use]
362 pub fn new(smooth: f32) -> Self {
363 Self { smooth }
364 }
365
366 #[must_use]
367 pub fn smooth(&self) -> f32 {
368 self.smooth
369 }
370}
371
372impl Loss for DiceLoss {
373 fn compute(&self, y_pred: &Vector<f32>, y_true: &Vector<f32>) -> f32 {
374 dice_loss(y_pred, y_true, self.smooth)
375 }
376
377 fn name(&self) -> &'static str {
378 "Dice"
379 }
380}
381
382/// Hinge loss struct wrapper.
383#[derive(Debug, Clone, Copy)]
384pub struct HingeLoss {
385 margin: f32,
386}
387
388impl HingeLoss {
389 #[must_use]
390 pub fn new(margin: f32) -> Self {
391 Self { margin }
392 }
393
394 #[must_use]
395 pub fn margin(&self) -> f32 {
396 self.margin
397 }
398}
399
400impl Loss for HingeLoss {
401 fn compute(&self, y_pred: &Vector<f32>, y_true: &Vector<f32>) -> f32 {
402 hinge_loss(y_pred, y_true, self.margin)
403 }
404
405 fn name(&self) -> &'static str {
406 "Hinge"
407 }
408}
409
/// Connectionist Temporal Classification (CTC) loss.
///
/// Intended for sequence-to-sequence tasks in which the alignment between
/// inputs and targets is unknown, such as speech recognition and OCR.
///
/// Reference: Graves et al., "Connectionist Temporal Classification" (2006)
#[derive(Clone, Debug)]
pub struct CTCLoss {
    // NOTE(review): presumably the vocabulary index of the CTC "blank" label;
    // its use lives in the impl outside this view — confirm there.
    blank_idx: usize,
}