//! Loss functions: contrastive (InfoNCE, triplet), class-imbalance (focal),
//! distribution (KL divergence), segmentation (Dice), and margin (hinge) losses.
2/// `InfoNCE` (Noise Contrastive Estimation) loss for contrastive learning.
3///
4/// Also known as NT-Xent (Normalized Temperature-scaled Cross Entropy).
5///
6/// ```text
7/// L = -log(exp(sim(z_i, z_j)/τ) / Σ_k exp(sim(z_i, z_k)/τ))
8/// ```
9///
10/// # Arguments
11///
12/// * `anchor` - Anchor embedding
13/// * `positive` - Positive example embedding
14/// * `negatives` - Slice of negative example embeddings
15/// * `temperature` - Temperature scaling parameter (typically 0.07-0.5)
16///
17/// # Returns
18///
19/// The `InfoNCE` loss value
20///
21/// # Example
22///
23/// ```
24/// use aprender::loss::info_nce_loss;
25/// use aprender::primitives::Vector;
26///
27/// let anchor = Vector::from_slice(&[1.0, 0.0, 0.0]);
28/// let positive = Vector::from_slice(&[0.9, 0.1, 0.0]);
29/// let negatives = vec![
30///     Vector::from_slice(&[0.0, 1.0, 0.0]),
31///     Vector::from_slice(&[0.0, 0.0, 1.0]),
32/// ];
33///
34/// let loss = info_nce_loss(&anchor, &positive, &negatives, 0.1);
35/// assert!(loss >= 0.0);
36/// ```
37#[must_use]
38pub fn info_nce_loss(
39    anchor: &Vector<f32>,
40    positive: &Vector<f32>,
41    negatives: &[Vector<f32>],
42    temperature: f32,
43) -> f32 {
44    assert_eq!(
45        anchor.len(),
46        positive.len(),
47        "Anchor and positive must have same dimension"
48    );
49    for neg in negatives {
50        assert_eq!(
51            anchor.len(),
52            neg.len(),
53            "All embeddings must have same dimension"
54        );
55    }
56    assert!(temperature > 0.0, "Temperature must be positive");
57
58    // Compute similarity with positive
59    let sim_pos = cosine_similarity(anchor, positive) / temperature;
60
61    // Compute log-sum-exp over all (positive + negatives)
62    let mut max_sim = sim_pos;
63    for neg in negatives {
64        let sim_neg = cosine_similarity(anchor, neg) / temperature;
65        max_sim = max_sim.max(sim_neg);
66    }
67
68    // Numerically stable log-sum-exp
69    let mut sum_exp = (sim_pos - max_sim).exp();
70    for neg in negatives {
71        let sim_neg = cosine_similarity(anchor, neg) / temperature;
72        sum_exp += (sim_neg - max_sim).exp();
73    }
74
75    // -log(exp(sim_pos) / sum_exp) = -sim_pos + log(sum_exp)
76    -sim_pos + max_sim + sum_exp.ln()
77}
78
79/// Focal loss for class imbalance (spec: more-learning-specs.md §18).
80///
81/// Down-weights easy examples, focuses on hard examples:
82///
83/// ```text
84/// FL(p) = -α * (1 - p)^γ * log(p)
85/// ```
86///
87/// # Arguments
88///
89/// * `predictions` - Predicted probabilities (after sigmoid/softmax)
90/// * `targets` - Binary targets (0 or 1)
91/// * `alpha` - Balancing factor (typically 0.25 for rare class)
92/// * `gamma` - Focusing parameter (typically 2.0)
93///
94/// # Returns
95///
96/// The focal loss value
97///
98/// # Example
99///
100/// ```
101/// use aprender::loss::focal_loss;
102/// use aprender::primitives::Vector;
103///
104/// let predictions = Vector::from_slice(&[0.9, 0.1, 0.8]);
105/// let targets = Vector::from_slice(&[1.0, 0.0, 1.0]);
106///
107/// let loss = focal_loss(&predictions, &targets, 0.25, 2.0);
108/// assert!(loss >= 0.0);
109/// ```
110#[must_use]
111pub fn focal_loss(predictions: &Vector<f32>, targets: &Vector<f32>, alpha: f32, gamma: f32) -> f32 {
112    assert_eq!(
113        predictions.len(),
114        targets.len(),
115        "Predictions and targets must have same length"
116    );
117
118    let n = predictions.len() as f32;
119    let mut sum = 0.0;
120
121    for i in 0..predictions.len() {
122        let p = predictions[i].clamp(1e-7, 1.0 - 1e-7);
123        let t = targets[i];
124
125        // For positive class (t=1): -α * (1-p)^γ * log(p)
126        // For negative class (t=0): -(1-α) * p^γ * log(1-p)
127        let loss = if t > 0.5 {
128            -alpha * (1.0 - p).powf(gamma) * p.ln()
129        } else {
130            -(1.0 - alpha) * p.powf(gamma) * (1.0 - p).ln()
131        };
132
133        sum += loss;
134    }
135
136    sum / n
137}
138
139/// KL Divergence loss between two probability distributions.
140///
141/// ```text
142/// KL(P || Q) = Σ P(x) * log(P(x) / Q(x))
143/// ```
144///
145/// # Arguments
146///
147/// * `p` - True distribution (targets)
148/// * `q` - Predicted distribution
149///
150/// # Returns
151///
152/// The KL divergence (always >= 0)
153///
154/// # Example
155///
156/// ```
157/// use aprender::loss::kl_divergence;
158/// use aprender::primitives::Vector;
159///
160/// let p = Vector::from_slice(&[0.5, 0.3, 0.2]);
161/// let q = Vector::from_slice(&[0.4, 0.4, 0.2]);
162///
163/// let kl = kl_divergence(&p, &q);
164/// assert!(kl >= 0.0);
165/// ```
166#[must_use]
167pub fn kl_divergence(p: &Vector<f32>, q: &Vector<f32>) -> f32 {
168    assert_eq!(p.len(), q.len(), "Distributions must have same length");
169
170    let mut sum = 0.0;
171    for i in 0..p.len() {
172        if p[i] > 1e-10 {
173            let q_safe = q[i].max(1e-10);
174            sum += p[i] * (p[i] / q_safe).ln();
175        }
176    }
177
178    sum
179}
180
181/// Triplet loss function (struct wrapper).
182#[derive(Debug, Clone, Copy)]
183pub struct TripletLoss {
184    margin: f32,
185}
186
187impl TripletLoss {
188    /// Creates a new Triplet loss with the given margin.
189    #[must_use]
190    pub fn new(margin: f32) -> Self {
191        Self { margin }
192    }
193
194    /// Returns the margin parameter.
195    #[must_use]
196    pub fn margin(&self) -> f32 {
197        self.margin
198    }
199
200    /// Compute triplet loss for given embeddings.
201    #[must_use]
202    pub fn compute_triplet(
203        &self,
204        anchor: &Vector<f32>,
205        positive: &Vector<f32>,
206        negative: &Vector<f32>,
207    ) -> f32 {
208        triplet_loss(anchor, positive, negative, self.margin)
209    }
210}
211
212/// Focal loss function (struct wrapper).
213#[derive(Debug, Clone, Copy)]
214pub struct FocalLoss {
215    alpha: f32,
216    gamma: f32,
217}
218
219impl FocalLoss {
220    /// Creates a new Focal loss with given parameters.
221    ///
222    /// # Arguments
223    ///
224    /// * `alpha` - Balancing factor for rare class (typically 0.25)
225    /// * `gamma` - Focusing parameter (typically 2.0)
226    #[must_use]
227    pub fn new(alpha: f32, gamma: f32) -> Self {
228        Self { alpha, gamma }
229    }
230
231    /// Returns the alpha parameter.
232    #[must_use]
233    pub fn alpha(&self) -> f32 {
234        self.alpha
235    }
236
237    /// Returns the gamma parameter.
238    #[must_use]
239    pub fn gamma(&self) -> f32 {
240        self.gamma
241    }
242}
243
244impl Loss for FocalLoss {
245    fn compute(&self, y_pred: &Vector<f32>, y_true: &Vector<f32>) -> f32 {
246        focal_loss(y_pred, y_true, self.alpha, self.gamma)
247    }
248
249    fn name(&self) -> &'static str {
250        "Focal"
251    }
252}
253
254/// `InfoNCE` / NT-Xent loss function (struct wrapper).
255#[derive(Debug, Clone, Copy)]
256pub struct InfoNCELoss {
257    temperature: f32,
258}
259
260impl InfoNCELoss {
261    /// Creates a new `InfoNCE` loss with given temperature.
262    ///
263    /// # Arguments
264    ///
265    /// * `temperature` - Temperature scaling (typically 0.07-0.5)
266    #[must_use]
267    pub fn new(temperature: f32) -> Self {
268        Self { temperature }
269    }
270
271    /// Returns the temperature parameter.
272    #[must_use]
273    pub fn temperature(&self) -> f32 {
274        self.temperature
275    }
276
277    /// Compute `InfoNCE` loss for contrastive learning.
278    #[must_use]
279    pub fn compute_contrastive(
280        &self,
281        anchor: &Vector<f32>,
282        positive: &Vector<f32>,
283        negatives: &[Vector<f32>],
284    ) -> f32 {
285        info_nce_loss(anchor, positive, negatives, self.temperature)
286    }
287}
288
289/// Dice loss for segmentation tasks.
290///
291/// Measures overlap between predicted and ground truth masks:
292/// ```text
293/// Dice = 2 * |X ∩ Y| / (|X| + |Y|)
294/// Loss = 1 - Dice
295/// ```
296///
297/// # Arguments
298/// * `y_pred` - Predicted probabilities (0-1)
299/// * `y_true` - Ground truth binary mask (0 or 1)
300/// * `smooth` - Smoothing factor to avoid division by zero
301#[must_use]
302pub fn dice_loss(y_pred: &Vector<f32>, y_true: &Vector<f32>, smooth: f32) -> f32 {
303    assert_eq!(y_pred.len(), y_true.len());
304
305    let mut intersection = 0.0;
306    let mut pred_sum = 0.0;
307    let mut true_sum = 0.0;
308
309    for i in 0..y_pred.len() {
310        intersection += y_pred[i] * y_true[i];
311        pred_sum += y_pred[i];
312        true_sum += y_true[i];
313    }
314
315    let dice = (2.0 * intersection + smooth) / (pred_sum + true_sum + smooth);
316    1.0 - dice
317}
318
319/// Hinge loss for SVM-style margin classification.
320///
321/// ```text
322/// L = max(0, margin - y_true * y_pred)
323/// ```
324///
325/// # Arguments
326/// * `y_pred` - Predicted scores (raw, not probabilities)
327/// * `y_true` - True labels (-1 or 1)
328/// * `margin` - Margin threshold (typically 1.0)
329#[must_use]
330pub fn hinge_loss(y_pred: &Vector<f32>, y_true: &Vector<f32>, margin: f32) -> f32 {
331    assert_eq!(y_pred.len(), y_true.len());
332
333    let mut sum = 0.0;
334    for i in 0..y_pred.len() {
335        let loss = (margin - y_true[i] * y_pred[i]).max(0.0);
336        sum += loss;
337    }
338    sum / y_pred.len() as f32
339}
340
341/// Squared hinge loss (smoother gradient).
342#[must_use]
343pub fn squared_hinge_loss(y_pred: &Vector<f32>, y_true: &Vector<f32>, margin: f32) -> f32 {
344    assert_eq!(y_pred.len(), y_true.len());
345
346    let mut sum = 0.0;
347    for i in 0..y_pred.len() {
348        let loss = (margin - y_true[i] * y_pred[i]).max(0.0);
349        sum += loss * loss;
350    }
351    sum / y_pred.len() as f32
352}
353
354/// Dice loss struct wrapper.
355#[derive(Debug, Clone, Copy)]
356pub struct DiceLoss {
357    smooth: f32,
358}
359
360impl DiceLoss {
361    #[must_use]
362    pub fn new(smooth: f32) -> Self {
363        Self { smooth }
364    }
365
366    #[must_use]
367    pub fn smooth(&self) -> f32 {
368        self.smooth
369    }
370}
371
372impl Loss for DiceLoss {
373    fn compute(&self, y_pred: &Vector<f32>, y_true: &Vector<f32>) -> f32 {
374        dice_loss(y_pred, y_true, self.smooth)
375    }
376
377    fn name(&self) -> &'static str {
378        "Dice"
379    }
380}
381
382/// Hinge loss struct wrapper.
383#[derive(Debug, Clone, Copy)]
384pub struct HingeLoss {
385    margin: f32,
386}
387
388impl HingeLoss {
389    #[must_use]
390    pub fn new(margin: f32) -> Self {
391        Self { margin }
392    }
393
394    #[must_use]
395    pub fn margin(&self) -> f32 {
396        self.margin
397    }
398}
399
400impl Loss for HingeLoss {
401    fn compute(&self, y_pred: &Vector<f32>, y_true: &Vector<f32>) -> f32 {
402        hinge_loss(y_pred, y_true, self.margin)
403    }
404
405    fn name(&self) -> &'static str {
406        "Hinge"
407    }
408}
409
410/// Connectionist Temporal Classification (CTC) Loss.
411///
412/// Used for sequence-to-sequence tasks where alignment is unknown.
413/// Common in speech recognition and OCR.
414///
415/// Reference: Graves et al., "Connectionist Temporal Classification" (2006)
416#[derive(Debug, Clone)]
417pub struct CTCLoss {
418    blank_idx: usize,
419}