use crate::error::{NeuralError, Result};
use scirs2_core::ndarray::Array1;
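/// Inputs shared by the preference-alignment losses in this module. Only the
/// two policy log-probability arrays are mandatory; each loss validates the
/// optional fields it needs and returns a `NeuralError` when one is missing.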
#[derive(Debug, Clone)]
pub struct AlignmentBatch {
    /// Policy log-probabilities of the chosen (preferred) responses.
    pub chosen_logprobs: Array1<f64>,
    /// Policy log-probabilities of the rejected responses.
    pub rejected_logprobs: Array1<f64>,
    /// Reference-model log-probabilities of the chosen responses.
    pub ref_chosen_logprobs: Option<Array1<f64>>,
    /// Reference-model log-probabilities of the rejected responses.
    pub ref_rejected_logprobs: Option<Array1<f64>>,
    /// Token lengths of the chosen responses, used for length normalization.
    pub chosen_lengths: Option<Array1<f64>>,
    /// Token lengths of the rejected responses, used for length normalization.
    pub rejected_lengths: Option<Array1<f64>>,
    /// Scalar rewards per example (required by GRPO).
    pub rewards: Option<Array1<f64>>,
    /// Per-example desirable (`true`) / undesirable (`false`) labels (required by KTO).
    pub labels: Option<Vec<bool>>,
}
impl AlignmentBatch {
    /// Creates a batch from policy log-probabilities of the chosen and
    /// rejected responses; all optional fields start unset.
    pub fn new(chosen: Array1<f64>, rejected: Array1<f64>) -> Self {
Self {
chosen_logprobs: chosen,
rejected_logprobs: rejected,
ref_chosen_logprobs: None,
ref_rejected_logprobs: None,
chosen_lengths: None,
rejected_lengths: None,
rewards: None,
labels: None,
}
}
pub fn with_reference(mut self, ref_chosen: Array1<f64>, ref_rejected: Array1<f64>) -> Self {
self.ref_chosen_logprobs = Some(ref_chosen);
self.ref_rejected_logprobs = Some(ref_rejected);
self
}
pub fn with_rewards(mut self, rewards: Array1<f64>) -> Self {
self.rewards = Some(rewards);
self
}
pub fn with_labels(mut self, labels: Vec<bool>) -> Self {
self.labels = Some(labels);
self
}
pub fn with_lengths(mut self, chosen: Array1<f64>, rejected: Array1<f64>) -> Self {
self.chosen_lengths = Some(chosen);
self.rejected_lengths = Some(rejected);
self
}
    /// Number of examples in the batch.
    pub fn batch_size(&self) -> usize {
self.chosen_logprobs.len()
}
    /// Checks that the chosen and rejected log-probability arrays agree in
    /// length, returning the shared batch size.
    fn validate_lengths(&self, ctx: &str) -> Result<usize> {
let n = self.chosen_logprobs.len();
if self.rejected_logprobs.len() != n {
return Err(NeuralError::DimensionMismatch(format!(
"{ctx}: chosen_logprobs length {n} != rejected_logprobs length {}",
self.rejected_logprobs.len()
)));
}
Ok(n)
}
}
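/// Common interface over the alignment losses below: each consumes an
/// [`AlignmentBatch`] and reduces it to a single mean loss value.
///
/// A minimal usage sketch (marked `ignore` because the surrounding crate
/// layout and the provenance of the log-probabilities are assumed here):
///
/// ```ignore
/// let batch = AlignmentBatch::new(chosen_logprobs, rejected_logprobs);
/// let loss = SimpoLoss::new(SimpoConfig::default()).compute_loss(&batch)?;
/// ```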
pub trait AlignmentLoss {
fn compute_loss(&self, batch: &AlignmentBatch) -> Result<f64>;
}
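/// Configuration for GRPO (Group Relative Policy Optimization). Rewards are
/// standardized within the group to form advantages, which feed a PPO-style
/// clipped surrogate plus a β-weighted KL penalty toward the reference
/// policy:
///
/// A_i = (r_i − mean(r)) / std(r),
/// L_i = −min(ρ_i · A_i, clip(ρ_i, 1 − ε, 1 + ε) · A_i) + β · KL̂_i,
///
/// with ρ_i = exp(log π(y_i) − log π_ref(y_i)). In this implementation the
/// "group" is the whole batch, so `group_size` is informational only.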
#[derive(Debug, Clone)]
pub struct GrpoConfig {
pub beta: f64,
pub epsilon: f64,
pub group_size: usize,
pub use_token_level_kl: bool,
}
impl Default for GrpoConfig {
fn default() -> Self {
Self {
beta: 0.04,
epsilon: 0.2,
group_size: 8,
use_token_level_kl: false,
}
}
}
#[derive(Debug, Clone)]
pub struct GrpoLoss {
pub config: GrpoConfig,
}
impl GrpoLoss {
pub fn new(config: GrpoConfig) -> Self {
Self { config }
}
}
impl AlignmentLoss for GrpoLoss {
fn compute_loss(&self, batch: &AlignmentBatch) -> Result<f64> {
let rewards = batch.rewards.as_ref().ok_or_else(|| {
NeuralError::InvalidArgument("GRPO requires rewards".into())
})?;
let ref_lp = batch.ref_chosen_logprobs.as_ref().ok_or_else(|| {
NeuralError::InvalidArgument("GRPO requires ref_chosen_logprobs".into())
})?;
let n = batch.validate_lengths("GrpoLoss")?;
if n == 0 {
return Ok(0.0);
}
if rewards.len() != n {
return Err(NeuralError::DimensionMismatch(format!(
"GrpoLoss: rewards length {} != batch size {n}",
rewards.len()
)));
}
if ref_lp.len() != n {
return Err(NeuralError::DimensionMismatch(format!(
"GrpoLoss: ref_chosen_logprobs length {} != batch size {n}",
ref_lp.len()
)));
}
        let policy_lp = &batch.chosen_logprobs;
        // Group-relative advantages: standardize rewards across the batch.
        let mean_r: f64 = rewards.iter().sum::<f64>() / n as f64;
        let var_r: f64 = rewards.iter().map(|&r| (r - mean_r).powi(2)).sum::<f64>() / n as f64;
        let std_r = var_r.sqrt().max(1e-8);
        let mut total_loss = 0.0_f64;
        for i in 0..n {
            let advantage = (rewards[i] - mean_r) / std_r;
            let log_ratio = policy_lp[i] - ref_lp[i];
            let ratio = log_ratio.exp();
            // PPO-style clipped surrogate on the policy/reference ratio.
            let clipped_ratio = ratio.clamp(1.0 - self.config.epsilon, 1.0 + self.config.epsilon);
            let surrogate = -f64::min(ratio * advantage, clipped_ratio * advantage);
            // Per-sample KL estimate: E[log π − log π_ref] = KL(π ‖ π_ref)
            // under samples from π. (A lower-variance, non-negative choice is
            // exp(−log_ratio) + log_ratio − 1.)
            let kl_approx = log_ratio;
            total_loss += surrogate + self.config.beta * kl_approx;
        }
        Ok(total_loss / n as f64)
}
}
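/// Configuration for SimPO (Simple Preference Optimization), a
/// reference-free pairwise loss with a target reward margin γ:
///
/// L = −log σ(β · (r_chosen − r_rejected) − γ),
///
/// where r is the length-normalized log-probability when
/// `length_normalize` is set and lengths are present in the batch.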
#[derive(Debug, Clone)]
pub struct SimpoConfig {
pub beta: f64,
pub gamma: f64,
pub length_normalize: bool,
}
impl Default for SimpoConfig {
fn default() -> Self {
Self {
beta: 2.5,
gamma: 1.4,
length_normalize: true,
}
}
}
#[derive(Debug, Clone)]
pub struct SimpoLoss {
pub config: SimpoConfig,
}
impl SimpoLoss {
pub fn new(config: SimpoConfig) -> Self {
Self { config }
}
}
impl AlignmentLoss for SimpoLoss {
fn compute_loss(&self, batch: &AlignmentBatch) -> Result<f64> {
let n = batch.validate_lengths("SimpoLoss")?;
if n == 0 {
return Ok(0.0);
}
        let mut total = 0.0_f64;
        for i in 0..n {
            let mut r_chosen = batch.chosen_logprobs[i];
            let mut r_rejected = batch.rejected_logprobs[i];
            // SimPO's implicit reward is the length-normalized average
            // log-probability; fall back to the raw value when no lengths
            // were attached to the batch.
            if self.config.length_normalize {
                if let Some(ref lengths) = batch.chosen_lengths {
                    r_chosen /= lengths[i].max(1.0);
                }
                if let Some(ref lengths) = batch.rejected_lengths {
                    r_rejected /= lengths[i].max(1.0);
                }
            }
            // −log σ(β(r_c − r_r) − γ), computed as softplus(−margin).
            let margin = self.config.beta * (r_chosen - r_rejected) - self.config.gamma;
            total += softplus_neg(margin);
        }
        Ok(total / n as f64)
}
}
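/// Configuration for IPO (Identity Preference Optimization), which swaps
/// DPO's logistic loss for a squared regression of the preference
/// log-ratio gap onto the target 1/(2τ):
///
/// L = (h − 1/(2τ))²,
/// h = [log π(y_c) − log π_ref(y_c)] − [log π(y_r) − log π_ref(y_r)].
///
/// Smaller `tau` raises the target gap, i.e. demands a stronger preference
/// for the chosen response.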
#[derive(Debug, Clone)]
pub struct IpoConfig {
pub tau: f64,
}
impl Default for IpoConfig {
fn default() -> Self {
Self { tau: 0.1 }
}
}
#[derive(Debug, Clone)]
pub struct IpoLoss {
pub config: IpoConfig,
}
impl IpoLoss {
pub fn new(config: IpoConfig) -> Self {
Self { config }
}
}
impl AlignmentLoss for IpoLoss {
fn compute_loss(&self, batch: &AlignmentBatch) -> Result<f64> {
let ref_chosen = batch.ref_chosen_logprobs.as_ref().ok_or_else(|| {
NeuralError::InvalidArgument("IPO requires ref_chosen_logprobs".into())
})?;
let ref_rejected = batch.ref_rejected_logprobs.as_ref().ok_or_else(|| {
NeuralError::InvalidArgument("IPO requires ref_rejected_logprobs".into())
})?;
let n = batch.validate_lengths("IpoLoss")?;
if n == 0 {
return Ok(0.0);
}
if ref_chosen.len() != n {
return Err(NeuralError::DimensionMismatch(format!(
"IpoLoss: ref_chosen_logprobs length {} != batch size {n}",
ref_chosen.len()
)));
}
if ref_rejected.len() != n {
return Err(NeuralError::DimensionMismatch(format!(
"IpoLoss: ref_rejected_logprobs length {} != batch size {n}",
ref_rejected.len()
)));
}
        // IPO regresses the preference log-ratio gap onto the target 1/(2τ).
        let target = 1.0 / (2.0 * self.config.tau);
        let mut total = 0.0_f64;
        for i in 0..n {
            // h = log[π(y_c)/π_ref(y_c)] − log[π(y_r)/π_ref(y_r)].
            let h = (batch.chosen_logprobs[i] - ref_chosen[i])
                - (batch.rejected_logprobs[i] - ref_rejected[i]);
            total += (h - target).powi(2);
        }
        Ok(total / n as f64)
}
}
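/// Configuration for KTO (Kahneman-Tversky Optimization), which scores each
/// completion independently against a binary desirable/undesirable label
/// rather than in chosen/rejected pairs. The two cases are weighted by
/// `desirable_weight` and `undesirable_weight`, and the implied reward
/// r_θ = log π − log π_ref is centered on the reference point z_ref
/// maintained by [`KtoLoss::update_z_ref`].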
#[derive(Debug, Clone)]
pub struct KtoConfig {
pub beta: f64,
pub desirable_weight: f64,
pub undesirable_weight: f64,
}
impl Default for KtoConfig {
fn default() -> Self {
Self {
beta: 0.1,
desirable_weight: 1.0,
undesirable_weight: 1.0,
}
}
}
#[derive(Debug, Clone)]
pub struct KtoLoss {
    pub config: KtoConfig,
    /// Reference point z_ref, an estimate of KL(π ‖ π_ref) that centers the
    /// implied rewards; typically maintained as a running average.
    pub z_ref: f64,
}
impl KtoLoss {
    pub fn new(config: KtoConfig) -> Self {
        Self { config, z_ref: 0.0 }
    }
    /// Updates `z_ref` as an exponential moving average of fresh KL
    /// estimates: `z_ref ← momentum · z_ref + (1 − momentum) · kl_estimate`.
    pub fn update_z_ref(&mut self, kl_estimate: f64, momentum: f64) {
        self.z_ref = momentum * self.z_ref + (1.0 - momentum) * kl_estimate;
    }
}
impl AlignmentLoss for KtoLoss {
fn compute_loss(&self, batch: &AlignmentBatch) -> Result<f64> {
let labels = batch.labels.as_ref().ok_or_else(|| {
NeuralError::InvalidArgument("KTO requires labels".into())
})?;
let ref_lp = batch.ref_chosen_logprobs.as_ref().ok_or_else(|| {
NeuralError::InvalidArgument("KTO requires ref_chosen_logprobs".into())
})?;
        // KTO scores each completion independently against its label, so only
        // `chosen_logprobs` participates; `rejected_logprobs` is not used here.
        let n = batch.chosen_logprobs.len();
if n == 0 {
return Ok(0.0);
}
if labels.len() != n {
return Err(NeuralError::DimensionMismatch(format!(
"KtoLoss: labels length {} != batch size {n}",
labels.len()
)));
}
if ref_lp.len() != n {
return Err(NeuralError::DimensionMismatch(format!(
"KtoLoss: ref_chosen_logprobs length {} != batch size {n}",
ref_lp.len()
)));
}
let mut total = 0.0_f64;
for i in 0..n {
            // Implied reward r_θ = log π(y|x) − log π_ref(y|x). Following the
            // KTO formulation, β scales the gap to the reference point z_ref.
            let reward = batch.chosen_logprobs[i] - ref_lp[i];
            let loss_i = if labels[i] {
                // Desirable: λ_D · (1 − σ(β · (r_θ − z_ref))).
                let logit = self.config.beta * (reward - self.z_ref);
                self.config.desirable_weight * (1.0 - sigmoid(logit))
            } else {
                // Undesirable: λ_U · (1 − σ(β · (z_ref − r_θ))).
                let logit = self.config.beta * (self.z_ref - reward);
                self.config.undesirable_weight * (1.0 - sigmoid(logit))
            };
total += loss_i;
}
Ok(total / n as f64)
}
}
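/// Configuration for ORPO (Odds Ratio Preference Optimization), a
/// reference-free objective that adds a λ-weighted odds-ratio penalty to
/// the supervised NLL of the chosen response:
///
/// L = −log π(y_c) + λ · (−log σ(log[odds(y_c) / odds(y_r)])),
///
/// where odds(y) = p(y) / (1 − p(y)) and p(y) = exp(log π(y)).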
#[derive(Debug, Clone)]
pub struct OrpoConfig {
pub lambda: f64,
}
impl Default for OrpoConfig {
fn default() -> Self {
Self { lambda: 0.1 }
}
}
#[derive(Debug, Clone)]
pub struct OrpoLoss {
pub config: OrpoConfig,
}
impl OrpoLoss {
pub fn new(config: OrpoConfig) -> Self {
Self { config }
}
}
impl AlignmentLoss for OrpoLoss {
fn compute_loss(&self, batch: &AlignmentBatch) -> Result<f64> {
let n = batch.validate_lengths("OrpoLoss")?;
if n == 0 {
return Ok(0.0);
}
let mut total = 0.0_f64;
for i in 0..n {
let lp_c = batch.chosen_logprobs[i];
let lp_r = batch.rejected_logprobs[i];
            // Average log-probs map to probabilities; clamp away from 0 and 1
            // so the odds below stay finite.
            let p_c = lp_c.exp().clamp(1e-15, 1.0 - 1e-7);
            let p_r = lp_r.exp().clamp(1e-15, 1.0 - 1e-7);
            let odds_c = (p_c / (1.0 - p_c)).max(1e-15);
            let odds_r = (p_r / (1.0 - p_r)).max(1e-15);
            let log_or = (odds_c / odds_r).ln();
            // Supervised term: NLL of the chosen response.
            let sft_loss = -lp_c;
            // Odds-ratio penalty: −log σ(log_or) = softplus(−log_or).
            let or_penalty = softplus_neg(log_or);
            total += sft_loss + self.config.lambda * or_penalty;
}
Ok(total / n as f64)
}
}
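/// Numerically stable softplus of the negated argument,
/// softplus(−x) = ln(1 + e^(−x)) = −log σ(x); the branch keeps every
/// exponent non-positive so `exp` never overflows.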
#[inline]
fn softplus_neg(x: f64) -> f64 {
if x >= 0.0 {
(-x).exp().ln_1p()
} else {
-x + x.exp().ln_1p()
}
}
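/// Numerically stable logistic sigmoid, evaluated through whichever of the
/// two algebraically equivalent forms keeps the exponent non-positive, then
/// clamped strictly inside (0, 1).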
#[inline]
fn sigmoid(x: f64) -> f64 {
let s = if x >= 0.0 {
1.0 / (1.0 + (-x).exp())
} else {
let e = x.exp();
e / (1.0 + e)
};
s.clamp(1e-15, 1.0 - 1e-15)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_grpo_loss_basic() {
let rewards = Array1::from(vec![1.0_f64, 0.0, -1.0]);
let policy_lp = Array1::from(vec![-1.0_f64, -1.5, -2.0]);
let ref_lp = Array1::from(vec![-1.2_f64, -1.4, -1.9]);
let batch = AlignmentBatch::new(policy_lp, Array1::from(vec![-2.0_f64, -2.5, -3.0]))
.with_rewards(rewards)
.with_reference(ref_lp, Array1::from(vec![-2.1_f64, -2.6, -3.1]));
let loss = GrpoLoss::new(GrpoConfig::default())
.compute_loss(&batch)
.expect("grpo loss");
assert!(loss.is_finite(), "grpo loss={loss}");
}
#[test]
fn test_grpo_empty_batch() {
let batch = AlignmentBatch::new(Array1::from(vec![]), Array1::from(vec![]))
.with_rewards(Array1::from(vec![]))
.with_reference(Array1::from(vec![]), Array1::from(vec![]));
let loss = GrpoLoss::new(GrpoConfig::default())
.compute_loss(&batch)
.expect("empty batch");
assert_eq!(loss, 0.0);
}
#[test]
fn test_grpo_missing_rewards_error() {
let batch = AlignmentBatch::new(
Array1::from(vec![-1.0_f64]),
Array1::from(vec![-2.0_f64]),
)
.with_reference(
Array1::from(vec![-1.1_f64]),
Array1::from(vec![-2.1_f64]),
);
let result = GrpoLoss::new(GrpoConfig::default()).compute_loss(&batch);
assert!(result.is_err(), "should fail without rewards");
}
#[test]
fn test_grpo_group_normalization_equal_rewards() {
let n = 4;
let rewards = Array1::from(vec![1.0_f64; n]);
let policy_lp = Array1::from(vec![-1.0_f64; n]);
let ref_lp = Array1::from(vec![-1.0_f64; n]);
let batch = AlignmentBatch::new(policy_lp, Array1::from(vec![-2.0_f64; n]))
.with_rewards(rewards)
.with_reference(ref_lp, Array1::from(vec![-2.0_f64; n]));
        // Equal rewards give zero advantages, and identical policy/reference
        // log-probs give zero KL, so the loss should vanish.
        let config = GrpoConfig { beta: 0.04, ..Default::default() };
        let loss = GrpoLoss::new(config).compute_loss(&batch).expect("loss");
        assert!(loss.abs() < 1e-10, "expected ~0 loss, got {loss}");
}
#[test]
fn test_simpo_loss_length_normalized() {
let chosen_lp = Array1::from(vec![-4.0_f64, -6.0]);
let rejected_lp = Array1::from(vec![-2.0_f64, -3.0]);
        let chosen_len = Array1::from(vec![4.0_f64, 6.0]);
        let rejected_len = Array1::from(vec![2.0_f64, 3.0]);
let batch = AlignmentBatch::new(chosen_lp, rejected_lp)
.with_lengths(chosen_len, rejected_len);
let config = SimpoConfig { length_normalize: true, ..Default::default() };
let loss = SimpoLoss::new(config).compute_loss(&batch).expect("loss");
assert!(loss.is_finite() && loss > 0.0, "loss={loss}");
}
#[test]
fn test_simpo_no_length_norm() {
let chosen_lp = Array1::from(vec![-1.0_f64, -1.5]);
let rejected_lp = Array1::from(vec![-2.5_f64, -3.0]);
let batch = AlignmentBatch::new(chosen_lp, rejected_lp);
let config = SimpoConfig { length_normalize: false, ..Default::default() };
let loss = SimpoLoss::new(config).compute_loss(&batch).expect("loss");
assert!(loss.is_finite(), "loss={loss}");
}
#[test]
fn test_simpo_empty_batch() {
let batch = AlignmentBatch::new(Array1::from(vec![]), Array1::from(vec![]));
let loss = SimpoLoss::new(SimpoConfig::default())
.compute_loss(&batch)
.expect("empty");
assert_eq!(loss, 0.0);
}
#[test]
fn test_simpo_margin_effect() {
let chosen_lp = Array1::from(vec![-1.0_f64, -1.5]);
let rejected_lp = Array1::from(vec![-2.0_f64, -2.5]);
let batch = AlignmentBatch::new(chosen_lp, rejected_lp);
let loss_low_gamma = SimpoLoss::new(SimpoConfig { gamma: 0.5, ..Default::default() })
.compute_loss(&batch)
.expect("loss_low");
let loss_high_gamma = SimpoLoss::new(SimpoConfig { gamma: 3.0, ..Default::default() })
.compute_loss(&batch)
.expect("loss_high");
assert!(
loss_high_gamma > loss_low_gamma,
"high gamma={loss_high_gamma} should > low gamma={loss_low_gamma}"
);
}
#[test]
fn test_ipo_loss_formula() {
let chosen_lp = Array1::from(vec![-1.0_f64]);
let rejected_lp = Array1::from(vec![-2.0_f64]);
let ref_chosen = Array1::from(vec![-1.2_f64]);
let ref_rejected = Array1::from(vec![-1.9_f64]);
let batch = AlignmentBatch::new(chosen_lp, rejected_lp)
.with_reference(ref_chosen, ref_rejected);
        let config = IpoConfig { tau: 0.1 };
        let loss = IpoLoss::new(config).compute_loss(&batch).expect("loss");
        // h = (−1.0 − (−1.2)) − (−2.0 − (−1.9)) = 0.3; target = 1/(2 · 0.1) = 5.0;
        // loss = (0.3 − 5.0)² = (−4.7)².
        let expected = (-4.7_f64).powi(2);
assert!((loss - expected).abs() < 1e-9, "loss={loss}, expected={expected}");
}
#[test]
fn test_ipo_missing_reference() {
let batch = AlignmentBatch::new(
Array1::from(vec![-1.0_f64]),
Array1::from(vec![-2.0_f64]),
);
let result = IpoLoss::new(IpoConfig::default()).compute_loss(&batch);
assert!(result.is_err(), "should fail without reference");
}
#[test]
fn test_ipo_empty_batch() {
let batch = AlignmentBatch::new(Array1::from(vec![]), Array1::from(vec![]))
.with_reference(Array1::from(vec![]), Array1::from(vec![]));
let loss = IpoLoss::new(IpoConfig::default())
.compute_loss(&batch)
.expect("empty");
assert_eq!(loss, 0.0);
}
#[test]
fn test_kto_desirable_all_positive() {
let n = 3;
let chosen_lp = Array1::from(vec![-0.5_f64, -1.0, -1.5]);
let ref_lp = Array1::from(vec![-1.0_f64, -1.5, -2.0]);
let labels = vec![true; n];
let batch = AlignmentBatch::new(chosen_lp, Array1::from(vec![-2.0_f64; n]))
.with_reference(ref_lp, Array1::from(vec![-2.5_f64; n]))
.with_labels(labels);
let loss = KtoLoss::new(KtoConfig::default())
.compute_loss(&batch)
.expect("loss");
assert!(loss.is_finite() && loss >= 0.0, "kto desirable loss={loss}");
}
#[test]
fn test_kto_undesirable_all_negative() {
let n = 3;
let chosen_lp = Array1::from(vec![-0.5_f64, -1.0, -1.5]);
let ref_lp = Array1::from(vec![-1.0_f64, -1.5, -2.0]);
let labels = vec![false; n];
let batch = AlignmentBatch::new(chosen_lp, Array1::from(vec![-2.0_f64; n]))
.with_reference(ref_lp, Array1::from(vec![-2.5_f64; n]))
.with_labels(labels);
let loss = KtoLoss::new(KtoConfig::default())
.compute_loss(&batch)
.expect("loss");
assert!(loss.is_finite() && loss >= 0.0, "kto undesirable loss={loss}");
}
#[test]
fn test_kto_mixed_labels() {
let chosen_lp = Array1::from(vec![-1.0_f64, -2.0, -1.5]);
let ref_lp = Array1::from(vec![-1.2_f64, -1.8, -1.6]);
let labels = vec![true, false, true];
let batch = AlignmentBatch::new(
chosen_lp,
Array1::from(vec![-2.5_f64, -3.0, -2.8]),
)
.with_reference(ref_lp, Array1::from(vec![-2.6_f64, -3.1, -2.9]))
.with_labels(labels);
let loss = KtoLoss::new(KtoConfig::default())
.compute_loss(&batch)
.expect("loss");
assert!(loss.is_finite(), "kto mixed loss={loss}");
}
#[test]
fn test_kto_missing_labels_error() {
let batch = AlignmentBatch::new(
Array1::from(vec![-1.0_f64]),
Array1::from(vec![-2.0_f64]),
)
.with_reference(
Array1::from(vec![-1.1_f64]),
Array1::from(vec![-2.1_f64]),
);
let result = KtoLoss::new(KtoConfig::default()).compute_loss(&batch);
assert!(result.is_err(), "should fail without labels");
}
#[test]
fn test_kto_empty_batch() {
let batch = AlignmentBatch::new(Array1::from(vec![]), Array1::from(vec![]))
.with_reference(Array1::from(vec![]), Array1::from(vec![]))
.with_labels(vec![]);
let loss = KtoLoss::new(KtoConfig::default())
.compute_loss(&batch)
.expect("empty");
assert_eq!(loss, 0.0);
}
#[test]
fn test_orpo_basic_positive_loss() {
let chosen_lp = Array1::from(vec![-1.0_f64, -1.5, -0.8]);
let rejected_lp = Array1::from(vec![-2.0_f64, -3.0, -2.5]);
let batch = AlignmentBatch::new(chosen_lp, rejected_lp);
let loss = OrpoLoss::new(OrpoConfig::default())
.compute_loss(&batch)
.expect("orpo loss");
assert!(loss.is_finite() && loss > 0.0, "orpo loss={loss}");
}
#[test]
fn test_orpo_chosen_better_lower_loss() {
let chosen_lp_good = Array1::from(vec![-0.1_f64, -0.1]);
let rejected_lp_bad = Array1::from(vec![-10.0_f64, -10.0]);
let batch_good = AlignmentBatch::new(chosen_lp_good, rejected_lp_bad);
let chosen_lp_poor = Array1::from(vec![-3.0_f64, -3.0]);
let rejected_lp_close = Array1::from(vec![-3.1_f64, -3.1]);
let batch_poor = AlignmentBatch::new(chosen_lp_poor, rejected_lp_close);
let config = OrpoConfig { lambda: 0.5 };
let loss_good = OrpoLoss::new(config.clone())
.compute_loss(&batch_good)
.expect("good");
let loss_poor = OrpoLoss::new(config)
.compute_loss(&batch_poor)
.expect("poor");
        assert!(
            loss_good < loss_poor,
            "clear-margin loss {loss_good} should be below near-tie loss {loss_poor}"
        );
}
#[test]
fn test_orpo_empty_batch() {
let batch = AlignmentBatch::new(Array1::from(vec![]), Array1::from(vec![]));
let loss = OrpoLoss::new(OrpoConfig::default())
.compute_loss(&batch)
.expect("empty");
assert_eq!(loss, 0.0);
}
#[test]
fn test_alignment_batch_builder_all_methods() {
let n = 3;
let batch = AlignmentBatch::new(
Array1::from(vec![-1.0_f64; n]),
Array1::from(vec![-2.0_f64; n]),
)
.with_reference(
Array1::from(vec![-1.1_f64; n]),
Array1::from(vec![-2.1_f64; n]),
)
.with_rewards(Array1::from(vec![1.0_f64; n]))
.with_labels(vec![true, false, true])
.with_lengths(
Array1::from(vec![10.0_f64; n]),
Array1::from(vec![8.0_f64; n]),
);
assert_eq!(batch.batch_size(), n);
assert!(batch.ref_chosen_logprobs.is_some());
assert!(batch.ref_rejected_logprobs.is_some());
assert!(batch.rewards.is_some());
assert!(batch.labels.is_some());
assert!(batch.chosen_lengths.is_some());
assert!(batch.rejected_lengths.is_some());
}
#[test]
fn test_all_losses_zero_batch() {
let batch_empty = AlignmentBatch::new(Array1::from(vec![]), Array1::from(vec![]));
assert_eq!(
SimpoLoss::new(SimpoConfig::default())
.compute_loss(&batch_empty)
.expect("simpo empty"),
0.0
);
assert_eq!(
OrpoLoss::new(OrpoConfig::default())
.compute_loss(&batch_empty)
.expect("orpo empty"),
0.0
);
let batch_ipo = AlignmentBatch::new(Array1::from(vec![]), Array1::from(vec![]))
.with_reference(Array1::from(vec![]), Array1::from(vec![]));
assert_eq!(
IpoLoss::new(IpoConfig::default())
.compute_loss(&batch_ipo)
.expect("ipo empty"),
0.0
);
let batch_grpo = AlignmentBatch::new(Array1::from(vec![]), Array1::from(vec![]))
.with_rewards(Array1::from(vec![]))
.with_reference(Array1::from(vec![]), Array1::from(vec![]));
assert_eq!(
GrpoLoss::new(GrpoConfig::default())
.compute_loss(&batch_grpo)
.expect("grpo empty"),
0.0
);
let batch_kto = AlignmentBatch::new(Array1::from(vec![]), Array1::from(vec![]))
.with_reference(Array1::from(vec![]), Array1::from(vec![]))
.with_labels(vec![]);
assert_eq!(
KtoLoss::new(KtoConfig::default())
.compute_loss(&batch_kto)
.expect("kto empty"),
0.0
);
}
}