use std::collections::HashMap;
use super::super::episode::{Episode, EpisodeContext, Outcome};
use super::super::record::Record;
use super::super::training::TrainingData;
use super::{LearnError, LearnModel};
use crate::types::GroupId;
/// A preference pair: a preferred ("chosen") and a dispreferred ("rejected") [`Episode`].
///
/// Produced by `DpoLearnModel::build_pairs`, which only pairs episodes sharing a group id.
#[derive(Debug, Clone)]
pub struct DpoPair {
    /// Episode whose response is preferred.
    pub chosen: Episode,
    /// Episode whose response is dispreferred.
    pub rejected: Episode,
    /// Group the pair was drawn from (not validated by [`DpoPair::new`]).
    pub group_id: GroupId,
    /// `chosen.outcome.score() - rejected.outcome.score()`, captured at construction time.
    pub quality_gap: f64,
}

impl DpoPair {
    /// Creates a pair, recording the score difference between `chosen` and `rejected`.
    pub fn new(chosen: Episode, rejected: Episode, group_id: GroupId) -> Self {
        // Computed before the episodes are moved into the struct.
        let quality_gap = chosen.outcome.score() - rejected.outcome.score();
        Self {
            quality_gap,
            chosen,
            rejected,
            group_id,
        }
    }
}
/// Settings controlling how DPO preference pairs are selected.
#[derive(Debug, Clone)]
pub struct DpoConfig {
    /// Minimum `chosen - rejected` score difference for a pair to be kept.
    pub min_quality_gap: f64,
    /// Maximum number of pairs returned (widest gaps first); `None` keeps every pair.
    pub max_pairs: Option<usize>,
    /// When `true`, an episode may appear in several pairs; when `false`, at most one
    /// pair is produced per group.
    pub allow_reuse: bool,
}

impl Default for DpoConfig {
    /// `min_quality_gap = 0.1`, no pair limit, reuse allowed.
    fn default() -> Self {
        let min_quality_gap = 0.1;
        let max_pairs = None;
        Self {
            min_quality_gap,
            max_pairs,
            allow_reuse: true,
        }
    }
}
/// Direct Preference Optimization learn model.
///
/// Turns groups of success/failure [`Episode`]s into chosen/rejected training pairs.
/// `extractor` maps an episode to its `(prompt, response)` text.
pub struct DpoLearnModel<F: Fn(&Episode) -> Option<(String, String)> + Send + Sync> {
    /// Used as the system prompt of generated training data when non-empty.
    system_prompt: String,
    /// Pair-selection settings.
    config: DpoConfig,
    /// Extracts `(prompt, response)` from an episode; `None` means the data is missing.
    extractor: F,
}
impl<F> DpoLearnModel<F>
where
F: Fn(&Episode) -> Option<(String, String)> + Send + Sync,
{
pub fn new(extractor: F) -> Self {
Self {
system_prompt: String::new(),
config: DpoConfig::default(),
extractor,
}
}
pub fn with_system_prompt(mut self, prompt: impl Into<String>) -> Self {
self.system_prompt = prompt.into();
self
}
pub fn with_config(mut self, config: DpoConfig) -> Self {
self.config = config;
self
}
pub fn with_min_quality_gap(mut self, gap: f64) -> Self {
self.config.min_quality_gap = gap;
self
}
pub fn with_max_pairs(mut self, max: usize) -> Self {
self.config.max_pairs = Some(max);
self
}
pub fn build_pairs(&self, episodes: &[Episode]) -> Vec<DpoPair> {
let mut by_group: HashMap<GroupId, Vec<&Episode>> = HashMap::new();
for ep in episodes {
if let Some(gid) = ep.group_id {
by_group.entry(gid).or_default().push(ep);
}
}
let mut pairs = Vec::new();
for (group_id, group_episodes) in by_group {
let (successes, failures): (Vec<_>, Vec<_>) = group_episodes
.into_iter()
.partition(|ep| ep.outcome.is_success());
if successes.is_empty() || failures.is_empty() {
continue;
}
let mut sorted_successes: Vec<_> = successes;
sorted_successes.sort_by(|a, b| {
let a_score = a.outcome.score();
let b_score = b.outcome.score();
b_score
.partial_cmp(&a_score)
.unwrap_or(std::cmp::Ordering::Equal)
});
let mut sorted_failures: Vec<_> = failures;
sorted_failures.sort_by(|a, b| {
let a_score = a.outcome.score();
let b_score = b.outcome.score();
a_score
.partial_cmp(&b_score)
.unwrap_or(std::cmp::Ordering::Equal)
});
for success_ep in &sorted_successes {
for failure_ep in &sorted_failures {
let chosen_score = success_ep.outcome.score();
let rejected_score = failure_ep.outcome.score();
let gap = chosen_score - rejected_score;
if gap < self.config.min_quality_gap {
continue;
}
let pair = DpoPair::new((*success_ep).clone(), (*failure_ep).clone(), group_id);
pairs.push(pair);
if !self.config.allow_reuse {
break;
}
}
if !self.config.allow_reuse {
break;
}
}
}
pairs.sort_by(|a, b| {
b.quality_gap
.partial_cmp(&a.quality_gap)
.unwrap_or(std::cmp::Ordering::Equal)
});
if let Some(max) = self.config.max_pairs {
pairs.truncate(max);
}
pairs
}
pub fn convert_pair(&self, pair: &DpoPair) -> Result<TrainingData, LearnError> {
let (chosen_prompt, chosen_response) = (self.extractor)(&pair.chosen)
.ok_or_else(|| LearnError::MissingData("chosen prompt/response".into()))?;
let (rejected_prompt, rejected_response) = (self.extractor)(&pair.rejected)
.ok_or_else(|| LearnError::MissingData("rejected prompt/response".into()))?;
if chosen_prompt != rejected_prompt {
return Err(LearnError::InvalidEpisode(format!(
"Prompt mismatch: '{}' vs '{}'",
chosen_prompt, rejected_prompt
)));
}
let training = if self.system_prompt.is_empty() {
TrainingData::dpo(&chosen_prompt, &chosen_response, &rejected_response)
} else {
TrainingData::dpo_with_system(
&self.system_prompt,
&chosen_prompt,
&chosen_response,
&rejected_response,
)
};
Ok(training
.with_episode_id(pair.chosen.id.to_string())
.with_custom("rejected_episode_id", pair.rejected.id.to_string())
.with_custom("quality_gap", pair.quality_gap.to_string())
.with_custom("group_id", pair.group_id.0.to_string()))
}
pub fn convert_pairs(&self, pairs: &[DpoPair]) -> Vec<TrainingData> {
pairs
.iter()
.filter_map(|pair| self.convert_pair(pair).ok())
.collect()
}
}
/// [`LearnModel`] adapter for DPO.
///
/// DPO works on pairs of episodes, so the trait's single-episode hooks are inert here:
/// `build_episodes` yields nothing, `convert` always errors and `evaluate` panics.
/// Use `DpoLearnModel::build_pairs` and `DpoLearnModel::convert_pair` instead.
impl<F> LearnModel for DpoLearnModel<F>
where
    F: Fn(&Episode) -> Option<(String, String)> + Send + Sync,
{
    fn name(&self) -> &str {
        "dpo"
    }

    fn objective(&self) -> &str {
        "Learn preferences from success/failure Episode pairs within the same group"
    }

    /// Always returns an empty list; this model does not derive episodes from records.
    fn build_episodes(&self, _records: &[Record]) -> Vec<Episode> {
        Vec::new()
    }

    /// # Panics
    ///
    /// Always panics: DPO compares episodes within a group rather than scoring one episode.
    fn evaluate(&self, _context: &EpisodeContext) -> Outcome {
        panic!(
            "DpoLearnModel::evaluate() should not be called.\n\
             DPO learning compares multiple Episodes by group_id, not individual Episode evaluation.\n\
             Use build_pairs() to generate training pairs from Episodes."
        );
    }

    /// Always fails with [`LearnError::InvalidEpisode`]; a single episode cannot form a pair.
    fn convert(&self, _episode: &Episode) -> Result<TrainingData, LearnError> {
        let reason = String::from("DPO requires pairs, use convert_pair instead");
        Err(LearnError::InvalidEpisode(reason))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::learn::episode::EpisodeBuilder;
    use crate::learn::record::ActionRecord;
    use crate::types::TaskId;

    /// Builds a one-record episode in `group_id`; `score` is only used for successes.
    fn create_test_episode(
        task_id: TaskId,
        group_id: GroupId,
        success: bool,
        score: f64,
    ) -> Episode {
        let outcome = if success {
            Outcome::success(score)
        } else {
            Outcome::failure("test failure")
        };
        EpisodeBuilder::default()
            .learn_model("test")
            .task_id(task_id)
            .group_id(group_id)
            .record(ActionRecord::new(1, 0, "TestAction").success(success))
            .outcome(outcome)
            .build()
    }

    /// Extractor giving every episode the same prompt and an id-specific response.
    fn test_extractor(ep: &Episode) -> Option<(String, String)> {
        Some((
            "test prompt".to_string(),
            format!("response for {:?}", ep.id),
        ))
    }

    /// Two successes and two failures in one group: four candidate pairs.
    fn two_by_two_group() -> Vec<Episode> {
        let group_id = GroupId::new();
        vec![
            create_test_episode(TaskId::new(), group_id, true, 0.9),
            create_test_episode(TaskId::new(), group_id, true, 0.5),
            create_test_episode(TaskId::new(), group_id, false, 0.0),
            create_test_episode(TaskId::new(), group_id, false, 0.0),
        ]
    }

    #[test]
    fn test_build_pairs_basic() {
        let group_id = GroupId::new();
        let task1 = TaskId::new();
        let task2 = TaskId::new();
        let episodes = vec![
            create_test_episode(task1, group_id, true, 0.9),
            create_test_episode(task2, group_id, false, 0.0),
        ];
        let dpo = DpoLearnModel::new(test_extractor);
        let pairs = dpo.build_pairs(&episodes);
        assert_eq!(pairs.len(), 1);
        assert!(pairs[0].quality_gap > 0.0);
    }

    #[test]
    fn test_build_pairs_different_groups() {
        let group1 = GroupId::new();
        let group2 = GroupId::new();
        let episodes = vec![
            create_test_episode(TaskId::new(), group1, true, 0.9),
            create_test_episode(TaskId::new(), group2, false, 0.0),
        ];
        let dpo = DpoLearnModel::new(test_extractor);
        let pairs = dpo.build_pairs(&episodes);
        assert!(pairs.is_empty());
    }

    #[test]
    fn test_min_quality_gap() {
        let group_id = GroupId::new();
        let episodes = vec![
            create_test_episode(TaskId::new(), group_id, true, 0.6),
            create_test_episode(TaskId::new(), group_id, false, 0.0),
        ];
        let dpo = DpoLearnModel::new(test_extractor).with_min_quality_gap(0.5);
        let pairs = dpo.build_pairs(&episodes);
        assert_eq!(pairs.len(), 1);
        let dpo = DpoLearnModel::new(test_extractor).with_min_quality_gap(0.7);
        let pairs = dpo.build_pairs(&episodes);
        assert!(pairs.is_empty());
    }

    #[test]
    fn test_pairs_sorted_by_descending_gap_and_truncated() {
        let episodes = two_by_two_group();
        let all = DpoLearnModel::new(test_extractor).build_pairs(&episodes);
        assert_eq!(all.len(), 4);
        assert!(all.windows(2).all(|w| w[0].quality_gap >= w[1].quality_gap));

        let limited = DpoLearnModel::new(test_extractor)
            .with_max_pairs(2)
            .build_pairs(&episodes);
        assert_eq!(limited.len(), 2);
        assert_eq!(limited[0].quality_gap, all[0].quality_gap);
        assert_eq!(limited[1].quality_gap, all[1].quality_gap);
    }

    #[test]
    fn test_no_reuse_keeps_single_widest_pair_per_group() {
        let episodes = two_by_two_group();
        let widest = DpoLearnModel::new(test_extractor).build_pairs(&episodes)[0].quality_gap;

        let dpo = DpoLearnModel::new(test_extractor).with_config(DpoConfig {
            allow_reuse: false,
            ..DpoConfig::default()
        });
        let pairs = dpo.build_pairs(&episodes);
        assert_eq!(pairs.len(), 1);
        assert_eq!(pairs[0].quality_gap, widest);
    }

    #[test]
    fn test_convert_pair_success_and_prompt_mismatch() {
        let episodes = two_by_two_group();
        let dpo = DpoLearnModel::new(test_extractor);
        let pairs = dpo.build_pairs(&episodes);
        assert!(dpo.convert_pair(&pairs[0]).is_ok());
        assert_eq!(dpo.convert_pairs(&pairs).len(), pairs.len());

        // Chosen and rejected episodes get different prompts, which must be rejected.
        let mismatched = DpoLearnModel::new(|ep: &Episode| {
            let prompt = if ep.outcome.is_success() { "a" } else { "b" };
            Some((prompt.to_string(), "response".to_string()))
        });
        assert!(matches!(
            mismatched.convert_pair(&pairs[0]),
            Err(LearnError::InvalidEpisode(_))
        ));
        assert!(mismatched.convert_pairs(&pairs).is_empty());
    }
}