//! Direct Preference Optimization (DPO) (GH-449)
//!
//! Implements the DPO algorithm from Rafailov et al. 2023:
//! "Direct Preference Optimization: Your Language Model is Secretly a Reward Model"
//!
//! DPO directly optimizes the policy using preference pairs (chosen/rejected)
//! without training a separate reward model (unlike RLHF/PPO).
//!
//! Loss: L_DPO = -E[log σ(β · (log(π(y_w|x)/π_ref(y_w|x)) - log(π(y_l|x)/π_ref(y_l|x))))]
//!
//! where y_w = chosen response, y_l = rejected response, β = temperature
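//!
//! A minimal usage sketch. The `use` path below is an assumption about this
//! crate's module layout, so the example is marked `ignore`:
//!
//! ```ignore
//! // Hypothetical path; adjust to wherever this module lives in the crate.
//! use aprender::dpo::{DpoConfig, DpoLoss, PreferencePair};
//!
//! let loss = DpoLoss::new(DpoConfig::default()); // beta = 0.1
//! let pair = PreferencePair {
//!     chosen_logprob: -10.0,
//!     rejected_logprob: -12.0,
//!     ref_chosen_logprob: -11.0,
//!     ref_rejected_logprob: -11.0,
//! };
//! // logit = 0.1 * ((-10 - (-11)) - (-12 - (-11))) = 0.1 * 2.0 = 0.2
//! println!("DPO loss: {:.4}", loss.compute(&pair)); // ≈ 0.5981
//! ```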
use crate::error::{AprenderError, Result};
/// Configuration for DPO training.
#[derive(Debug, Clone)]
pub struct DpoConfig {
/// Temperature parameter (β). Higher β = more conservative updates.
/// Typical range: 0.1 to 0.5
pub beta: f64,
/// Learning rate
pub learning_rate: f64,
/// Label smoothing (0.0 = none). Reduces overfitting to preference pairs.
pub label_smoothing: f64,
    /// Whether to use reference-model log-probs (standard DPO).
    /// If false, uses a reference-free SimPO-style variant.
    pub use_reference: bool,
    /// Length normalization for the SimPO variant.
    ///
    /// Note: `PreferencePair` carries no sequence lengths, so this flag is
    /// informational here; callers must length-normalize the log-probs
    /// upstream before constructing pairs.
    pub length_normalize: bool,
}
impl Default for DpoConfig {
fn default() -> Self {
Self {
beta: 0.1,
learning_rate: 5e-7,
label_smoothing: 0.0,
use_reference: true,
length_normalize: false,
}
}
}
impl DpoConfig {
/// Validate configuration constraints.
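    ///
    /// Sketch of an accepted and a rejected configuration (marked `ignore`
    /// since the module path is crate-specific):
    ///
    /// ```ignore
    /// let mut config = DpoConfig::default();
    /// assert!(config.validate().is_ok());
    /// config.beta = -1.0; // beta must be positive and finite
    /// assert!(config.validate().is_err());
    /// ```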
pub fn validate(&self) -> Result<()> {
if self.beta <= 0.0 || !self.beta.is_finite() {
return Err(AprenderError::FormatError {
message: format!("beta must be positive finite, got {}", self.beta),
});
}
if self.learning_rate <= 0.0 || !self.learning_rate.is_finite() {
return Err(AprenderError::FormatError {
message: format!(
"learning_rate must be positive finite, got {}",
self.learning_rate
),
});
}
if self.label_smoothing < 0.0 || self.label_smoothing >= 1.0 {
return Err(AprenderError::FormatError {
message: format!(
"label_smoothing must be in [0.0, 1.0), got {}",
self.label_smoothing
),
});
}
Ok(())
}
}
/// A preference pair for DPO training.
#[derive(Debug, Clone)]
pub struct PreferencePair {
/// Log-probability of chosen response under current policy
pub chosen_logprob: f64,
/// Log-probability of rejected response under current policy
pub rejected_logprob: f64,
/// Log-probability of chosen response under reference policy
pub ref_chosen_logprob: f64,
/// Log-probability of rejected response under reference policy
pub ref_rejected_logprob: f64,
}
/// DPO loss calculator.
#[derive(Debug, Clone)]
pub struct DpoLoss {
config: DpoConfig,
}
impl DpoLoss {
/// Create a new DPO loss calculator.
#[must_use]
pub fn new(config: DpoConfig) -> Self {
Self { config }
}
/// Compute DPO loss for a single preference pair.
///
/// L = -log σ(β · (log_ratio_chosen - log_ratio_rejected))
///
/// where log_ratio = log π(y|x) - log π_ref(y|x)
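    ///
    /// A worked numeric sketch (marked `ignore` since the module path is
    /// crate-specific):
    ///
    /// ```ignore
    /// let loss = DpoLoss::new(DpoConfig::default()); // beta = 0.1
    /// let pair = PreferencePair {
    ///     chosen_logprob: -1.0,
    ///     rejected_logprob: -2.0,
    ///     ref_chosen_logprob: -1.5,
    ///     ref_rejected_logprob: -1.5,
    /// };
    /// // log ratios: 0.5 and -0.5, so logit = 0.1 * 1.0 = 0.1
    /// // loss = -log σ(0.1) ≈ 0.6444
    /// assert!((loss.compute(&pair) - 0.6444).abs() < 1e-3);
    /// ```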
#[must_use]
pub fn compute(&self, pair: &PreferencePair) -> f64 {
let log_ratio_chosen = if self.config.use_reference {
pair.chosen_logprob - pair.ref_chosen_logprob
} else {
pair.chosen_logprob
};
let log_ratio_rejected = if self.config.use_reference {
pair.rejected_logprob - pair.ref_rejected_logprob
} else {
pair.rejected_logprob
};
let logit = self.config.beta * (log_ratio_chosen - log_ratio_rejected);
// -log σ(logit) with label smoothing
if self.config.label_smoothing > 0.0 {
let eps = self.config.label_smoothing;
// Smoothed: -(1-eps) * log σ(logit) - eps * log σ(-logit)
-(1.0 - eps) * log_sigmoid(logit) - eps * log_sigmoid(-logit)
} else {
-log_sigmoid(logit)
}
}
/// Compute DPO loss for a batch of preference pairs.
#[must_use]
pub fn compute_batch(&self, pairs: &[PreferencePair]) -> f64 {
if pairs.is_empty() {
return 0.0;
}
pairs.iter().map(|p| self.compute(p)).sum::<f64>() / pairs.len() as f64
}
    /// Compute the gradient of the DPO loss w.r.t. the policy log-probs.
    ///
    /// Returns `(grad_chosen, grad_rejected)`: gradients for the policy's
    /// log-probabilities of the chosen and rejected responses. Label
    /// smoothing is honored, matching `compute`.
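    ///
    /// A finite-difference sanity check (sketch, marked `ignore` since the
    /// module path is crate-specific):
    ///
    /// ```ignore
    /// let loss = DpoLoss::new(DpoConfig::default());
    /// let mut pair = PreferencePair {
    ///     chosen_logprob: -1.0,
    ///     rejected_logprob: -2.0,
    ///     ref_chosen_logprob: -1.5,
    ///     ref_rejected_logprob: -1.5,
    /// };
    /// let (grad_chosen, _) = loss.gradient(&pair);
    /// let base = loss.compute(&pair);
    /// let h = 1e-6;
    /// pair.chosen_logprob += h;
    /// let fd = (loss.compute(&pair) - base) / h;
    /// assert!((grad_chosen - fd).abs() < 1e-4);
    /// ```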
#[must_use]
pub fn gradient(&self, pair: &PreferencePair) -> (f64, f64) {
let log_ratio_chosen = if self.config.use_reference {
pair.chosen_logprob - pair.ref_chosen_logprob
} else {
pair.chosen_logprob
};
let log_ratio_rejected = if self.config.use_reference {
pair.rejected_logprob - pair.ref_rejected_logprob
} else {
pair.rejected_logprob
};
let logit = self.config.beta * (log_ratio_chosen - log_ratio_rejected);
        // For L = -(1 - eps) * log σ(logit) - eps * log σ(-logit):
        //   dL/d(logit) = -(1 - eps) * σ(-logit) + eps * σ(logit)
        // (eps = 0 recovers the standard DPO gradient, -σ(-logit)).
        let eps = self.config.label_smoothing;
        let dl_dlogit = -(1.0 - eps) * sigmoid(-logit) + eps * sigmoid(logit);
        // Chain rule: d(logit)/d(chosen) = β, d(logit)/d(rejected) = -β.
        let grad_chosen = self.config.beta * dl_dlogit;
        let grad_rejected = -self.config.beta * dl_dlogit;
        (grad_chosen, grad_rejected)
}
/// Compute implicit reward for a response.
///
/// r(x, y) = β * (log π(y|x) - log π_ref(y|x))
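    ///
    /// For example, with β = 0.1, a policy log-prob of -1.0 and a reference
    /// log-prob of -1.5 give an implicit reward of 0.1 · 0.5 = 0.05.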
#[must_use]
pub fn implicit_reward(&self, policy_logprob: f64, ref_logprob: f64) -> f64 {
self.config.beta * (policy_logprob - ref_logprob)
}
    /// Compute the accuracy of the policy's preference ranking.
    ///
    /// Returns the fraction of pairs where the implicit reward of the chosen
    /// response exceeds that of the rejected one, i.e. where the policy ranks
    /// the pair correctly. Uses the same log-ratio definition as `compute`.
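    ///
    /// Sketch (marked `ignore` since the module path is crate-specific):
    ///
    /// ```ignore
    /// // One correctly ranked pair and one misranked pair => accuracy 0.5.
    /// let pairs = vec![
    ///     PreferencePair { chosen_logprob: -1.0, rejected_logprob: -2.0,
    ///                      ref_chosen_logprob: -1.5, ref_rejected_logprob: -1.5 },
    ///     PreferencePair { chosen_logprob: -2.0, rejected_logprob: -1.0,
    ///                      ref_chosen_logprob: -1.5, ref_rejected_logprob: -1.5 },
    /// ];
    /// let loss = DpoLoss::new(DpoConfig::default());
    /// assert!((loss.accuracy(&pairs) - 0.5).abs() < f64::EPSILON);
    /// ```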
#[must_use]
pub fn accuracy(&self, pairs: &[PreferencePair]) -> f64 {
if pairs.is_empty() {
return 0.0;
}
let correct = pairs
.iter()
            .filter(|p| {
                // Use the same log-ratio definition as `compute`.
                let (chosen_ratio, rejected_ratio) = if self.config.use_reference {
                    (
                        p.chosen_logprob - p.ref_chosen_logprob,
                        p.rejected_logprob - p.ref_rejected_logprob,
                    )
                } else {
                    (p.chosen_logprob, p.rejected_logprob)
                };
                chosen_ratio > rejected_ratio
            })
.count();
correct as f64 / pairs.len() as f64
}
/// Get configuration.
#[must_use]
pub fn config(&self) -> &DpoConfig {
&self.config
}
}
/// Numerically stable log-sigmoid: log(σ(x)) = -log(1 + exp(-x))
fn log_sigmoid(x: f64) -> f64 {
if x > 20.0 {
        // For large x: log σ(x) = -log(1 + e^{-x}) ≈ -e^{-x}
-(-x).exp()
} else if x < -20.0 {
// For very negative x: log(σ(x)) ≈ x
x
} else {
-(1.0 + (-x).exp()).ln()
}
}
/// Standard sigmoid function.
fn sigmoid(x: f64) -> f64 {
1.0 / (1.0 + (-x).exp())
}
/// DPO training metrics.
#[derive(Debug, Clone, Default)]
pub struct DpoMetrics {
    /// Average loss over the batch
pub avg_loss: f64,
/// Preference accuracy (chosen > rejected)
pub accuracy: f64,
/// Average chosen reward
pub avg_chosen_reward: f64,
/// Average rejected reward
pub avg_rejected_reward: f64,
/// Reward margin (chosen - rejected)
pub reward_margin: f64,
/// Number of pairs processed
pub num_pairs: usize,
}
impl DpoMetrics {
/// Compute metrics from a batch of preference pairs.
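    ///
    /// Sketch (marked `ignore` since the module path is crate-specific):
    ///
    /// ```ignore
    /// let loss = DpoLoss::new(DpoConfig::default());
    /// let pairs = vec![PreferencePair {
    ///     chosen_logprob: -1.0,
    ///     rejected_logprob: -2.0,
    ///     ref_chosen_logprob: -1.5,
    ///     ref_rejected_logprob: -1.5,
    /// }];
    /// let metrics = DpoMetrics::from_batch(&loss, &pairs);
    /// // reward_margin = 0.1 * (0.5 - (-0.5)) = 0.1
    /// assert!((metrics.reward_margin - 0.1).abs() < 1e-12);
    /// assert_eq!(metrics.num_pairs, 1);
    /// ```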
#[must_use]
pub fn from_batch(loss: &DpoLoss, pairs: &[PreferencePair]) -> Self {
if pairs.is_empty() {
return Self::default();
}
let avg_loss = loss.compute_batch(pairs);
let accuracy = loss.accuracy(pairs);
let (total_chosen, total_rejected) = pairs.iter().fold((0.0, 0.0), |(tc, tr), p| {
let rc = loss.implicit_reward(p.chosen_logprob, p.ref_chosen_logprob);
let rr = loss.implicit_reward(p.rejected_logprob, p.ref_rejected_logprob);
(tc + rc, tr + rr)
});
let n = pairs.len() as f64;
let avg_chosen = total_chosen / n;
let avg_rejected = total_rejected / n;
Self {
avg_loss,
accuracy,
avg_chosen_reward: avg_chosen,
avg_rejected_reward: avg_rejected,
reward_margin: avg_chosen - avg_rejected,
num_pairs: pairs.len(),
}
}
}
#[cfg(test)]
#[path = "dpo_tests.rs"]
mod tests;