/// Configuration for a loss that adds a weighted odds-ratio alignment term
/// to a supervised fine-tuning (SFT) loss.
///
/// `PartialEq` is derived so configurations can be compared directly
/// (for example with `assert_eq!`).
#[derive(Debug, Clone, PartialEq)]
pub struct OrpoLoss {
    /// Weight applied to the alignment term before it is added to the SFT loss.
    pub lambda: f64,
}
impl Default for OrpoLoss {
fn default() -> Self {
Self { lambda: 0.5 }
}
}
impl OrpoLoss {
    /// Creates a loss whose alignment term is weighted by `lambda`.
    pub fn new(lambda: f64) -> Self {
        Self { lambda }
    }

    /// Converts a probability into odds, `p / (1 - p)`.
    ///
    /// The probability is first clamped to `[1e-10, 1 - 1e-10]`, which keeps
    /// the denominator away from zero and the odds strictly positive.
    fn odds(prob: f64) -> f64 {
        let p = prob.clamp(1e-10, 1.0 - 1e-10);
        p / (1.0 - p)
    }

    /// Returns the alignment term `-ln(sigmoid(ln(odds(chosen) / odds(rejected))))`.
    ///
    /// The result shrinks as the odds of `chosen_prob` grow relative to the
    /// odds of `rejected_prob`. `lambda` is not applied here.
    pub fn compute_alignment_loss(&self, chosen_prob: f64, rejected_prob: f64) -> f64 {
        let chosen_odds = Self::odds(chosen_prob);
        let rejected_odds = Self::odds(rejected_prob);
        let log_odds_ratio = (chosen_odds / rejected_odds).ln();
        -log_sigmoid(log_odds_ratio)
    }

    /// Returns the combined loss `sft_loss + lambda * alignment_loss` for one example.
    pub fn compute(&self, sft_loss: f64, chosen_prob: f64, rejected_prob: f64) -> f64 {
        let alignment_loss = self.compute_alignment_loss(chosen_prob, rejected_prob);
        sft_loss + self.lambda * alignment_loss
    }

    /// Returns the mean of [`compute`](Self::compute) over a batch of examples.
    ///
    /// Element `i` of each slice describes example `i`. If the slices differ
    /// in length, only the leading examples present in all three are used,
    /// and the mean is taken over exactly those examples. Returns `0.0` when
    /// there is no complete example.
    pub fn compute_batch(
        &self,
        sft_losses: &[f64],
        chosen_probs: &[f64],
        rejected_probs: &[f64],
    ) -> f64 {
        // `zip` stops at the shortest input, so this is the number of terms
        // that actually enter the sum; dividing by anything else would skew
        // the mean when the slices have different lengths.
        let paired = sft_losses
            .len()
            .min(chosen_probs.len())
            .min(rejected_probs.len());
        if paired == 0 {
            return 0.0;
        }
        let sum: f64 = sft_losses
            .iter()
            .zip(chosen_probs)
            .zip(rejected_probs)
            .map(|((&sft, &cp), &rp)| self.compute(sft, cp, rp))
            .sum();
        sum / paired as f64
    }
}
/// Computes `ln(sigmoid(x)) = -ln(1 + e^(-x))` in a numerically stable way.
///
/// Both branches only exponentiate a non-positive number, so `exp` cannot
/// overflow. `ln_1p` is used instead of `(1.0 + e).ln()` because the latter
/// rounds `1.0 + e` to exactly `1.0` once `e` falls below machine epsilon,
/// which would make the result collapse to zero for large positive `x`.
fn log_sigmoid(x: f64) -> f64 {
    if x >= 0.0 {
        -((-x).exp().ln_1p())
    } else {
        x - x.exp().ln_1p()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A clearly preferred chosen response yields a smaller alignment loss
    /// than a tie between chosen and rejected.
    #[test]
    fn test_orpo_loss() {
        let orpo = OrpoLoss::new(0.5);
        let preferred = orpo.compute_alignment_loss(0.8, 0.3);
        let tied = orpo.compute_alignment_loss(0.5, 0.5);
        assert!(preferred < tied);
    }

    /// The combined loss exceeds the SFT loss on its own.
    #[test]
    fn test_orpo_full_loss() {
        let orpo = OrpoLoss::new(0.5);
        let sft_only = 2.0;
        let combined = orpo.compute(sft_only, 0.7, 0.3);
        assert!(combined > sft_only);
    }

    /// The batch result equals the average of the per-example losses.
    #[test]
    fn test_orpo_batch() {
        let orpo = OrpoLoss::new(0.5);
        let from_batch = orpo.compute_batch(&[2.0, 1.5], &[0.7, 0.8], &[0.3, 0.4]);
        let first = orpo.compute(2.0, 0.7, 0.3);
        let second = orpo.compute(1.5, 0.8, 0.4);
        let expected = (first + second) / 2.0;
        assert!((from_batch - expected).abs() < 1e-10);
    }

    /// Odds are 1 at p = 0.5, above 1 for larger p and below 1 for smaller p.
    #[test]
    fn test_odds() {
        let even = OrpoLoss::odds(0.5);
        assert!((even - 1.0).abs() < 1e-10);
        let likely = OrpoLoss::odds(0.8);
        assert!(likely > 1.0);
        let unlikely = OrpoLoss::odds(0.2);
        assert!(unlikely < 1.0);
    }
}