use crate::error::{DatasetError, DatasetResult};
use crate::types::{PreferencePair, TrainingExample};
/// An indexable, iterable, in-memory collection of training items.
///
/// Implementors must be `Send + Sync` so a dataset can be shared across threads, and
/// their items must be `Clone` so that [`Dataset::split`] can return owned copies.
pub trait Dataset: Send + Sync {
    /// The element type stored in the dataset.
    type Item: Clone;

    /// Number of items currently held.
    fn len(&self) -> usize;

    /// `true` when the dataset holds no items (i.e. `len()` is zero).
    fn is_empty(&self) -> bool {
        matches!(self.len(), 0)
    }

    /// Borrows the item at `index`, or `None` when `index` is out of range.
    fn get(&self, index: usize) -> Option<&Self::Item>;

    /// Iterates over the items by reference, in dataset order.
    fn iter(&self) -> Box<dyn Iterator<Item = &Self::Item> + '_>;

    /// Reorders the items in place using `seed` as the source of randomness.
    fn shuffle(&mut self, seed: u64);

    /// Partitions the items into two owned vectors according to `ratio`.
    ///
    /// By convention the first vector receives (roughly) the leading `ratio` fraction of
    /// the items and the second receives the rest.
    fn split(&self, ratio: f32) -> (Vec<Self::Item>, Vec<Self::Item>);
}
/// An ordered, in-memory collection of [`TrainingExample`]s backed by a `Vec`.
#[derive(Clone, Debug)]
pub struct InstructDataset {
    /// Examples in dataset order; positional indices refer into this vector.
    examples: Vec<TrainingExample>,
}
impl InstructDataset {
    /// Creates a dataset that takes ownership of `examples`, preserving their order.
    pub fn new(examples: Vec<TrainingExample>) -> Self {
        Self { examples }
    }

    /// Builds a dataset from any iterable of examples, preserving iteration order.
    pub fn from_examples(examples: impl IntoIterator<Item = TrainingExample>) -> Self {
        Self {
            examples: examples.into_iter().collect(),
        }
    }

    /// Appends a single example to the end of the dataset.
    pub fn push(&mut self, example: TrainingExample) {
        self.examples.push(example);
    }

    /// Appends every example yielded by `examples`, in order.
    pub fn extend(&mut self, examples: impl IntoIterator<Item = TrainingExample>) {
        self.examples.extend(examples);
    }

    /// Removes and returns the example at `index`, shifting later examples down by one.
    ///
    /// # Errors
    ///
    /// Returns [`DatasetError::IndexOutOfBounds`] carrying the requested `index` and the
    /// current length when `index` is not less than the length; the dataset is then left
    /// unchanged.
    pub fn remove(&mut self, index: usize) -> DatasetResult<TrainingExample> {
        let len = self.examples.len();
        if index >= len {
            return Err(DatasetError::IndexOutOfBounds { index, len });
        }
        Ok(self.examples.remove(index))
    }

    /// Sums `TrainingExample::estimated_tokens` over all examples (0 when empty).
    pub fn total_estimated_tokens(&self) -> usize {
        self.examples.iter().map(|e| e.estimated_tokens()).sum()
    }

    /// Returns a new dataset holding clones of the examples for which `predicate`
    /// returns `true`, in their original order. `self` is not modified.
    ///
    /// The predicate is `FnMut`, so it may carry state (a counter, a set of already-seen
    /// ids, ...). It is called once per example, in dataset order. Every `Fn` closure is
    /// also `FnMut`, so existing callers are unaffected.
    pub fn filter<F>(&self, mut predicate: F) -> Self
    where
        F: FnMut(&TrainingExample) -> bool,
    {
        Self {
            examples: self
                .examples
                .iter()
                .filter(|e| predicate(e))
                .cloned()
                .collect(),
        }
    }

    /// Borrows the examples as a slice, in dataset order.
    pub fn as_slice(&self) -> &[TrainingExample] {
        &self.examples
    }

    /// Consumes the dataset and returns the underlying vector of examples.
    pub fn into_inner(self) -> Vec<TrainingExample> {
        self.examples
    }
}
impl Dataset for InstructDataset {
    type Item = TrainingExample;

    fn len(&self) -> usize {
        self.examples.len()
    }

    fn get(&self, index: usize) -> Option<&TrainingExample> {
        self.examples.get(index)
    }

    fn iter(&self) -> Box<dyn Iterator<Item = &TrainingExample> + '_> {
        Box::new(self.examples.iter())
    }

    /// Shuffles the examples in place with a Fisher–Yates pass driven by a 64-bit linear
    /// congruential generator seeded with `seed`.
    ///
    /// The arithmetic is fully deterministic, so a given seed always yields the same
    /// permutation of a given input order. Datasets with 0 or 1 examples are untouched.
    fn shuffle(&mut self, seed: u64) {
        const LCG_MULTIPLIER: u64 = 6364136223846793005;
        const LCG_INCREMENT: u64 = 1442695040888963407;

        let len = self.examples.len();
        if len <= 1 {
            return;
        }
        let mut state = seed;
        for i in (1..len).rev() {
            state = state
                .wrapping_mul(LCG_MULTIPLIER)
                .wrapping_add(LCG_INCREMENT);
            // Keep only the high 31 bits of the state; the low bits of a
            // power-of-two-modulus LCG are the weakest. The `%` has a slight modulo
            // bias, kept so existing seeds continue to produce the same order.
            let j = (state >> 33) as usize % (i + 1);
            self.examples.swap(i, j);
        }
    }

    /// Splits into `(train, eval)` clones at index `(len as f32 * ratio) as usize`.
    ///
    /// `ratio` is clamped to `[0.0, 1.0]`. A NaN ratio yields index 0: `clamp` returns
    /// NaN, and a NaN-to-integer `as` cast produces 0.
    fn split(&self, ratio: f32) -> (Vec<TrainingExample>, Vec<TrainingExample>) {
        let len = self.examples.len();
        let ratio = ratio.clamp(0.0, 1.0);
        // The f32 arithmetic is kept so existing splits do not move. However, beyond 2^24
        // items `len as f32` can round *up* past `len`, which would make the slice below
        // panic, so the index is capped at `len`.
        let split_idx = ((len as f32 * ratio) as usize).min(len);
        let (train, eval) = self.examples.split_at(split_idx);
        (train.to_vec(), eval.to_vec())
    }
}
/// An ordered, in-memory collection of [`PreferencePair`]s backed by a `Vec`.
#[derive(Clone, Debug)]
pub struct PreferenceDataset {
    /// Pairs in dataset order; positional indices refer into this vector.
    pairs: Vec<PreferencePair>,
}
impl PreferenceDataset {
    /// Creates a dataset that takes ownership of `pairs`, preserving their order.
    pub fn new(pairs: Vec<PreferencePair>) -> Self {
        Self { pairs }
    }

    /// Appends a single pair to the end of the dataset.
    pub fn push(&mut self, pair: PreferencePair) {
        self.pairs.push(pair);
    }

    /// Builds a dataset from any iterable of pairs, preserving iteration order.
    pub fn from_pairs(pairs: impl IntoIterator<Item = PreferencePair>) -> Self {
        Self {
            pairs: pairs.into_iter().collect(),
        }
    }

    /// Appends every pair yielded by `pairs`, in order.
    pub fn extend(&mut self, pairs: impl IntoIterator<Item = PreferencePair>) {
        self.pairs.extend(pairs);
    }

    /// Removes and returns the pair at `index`, shifting later pairs down by one.
    ///
    /// # Errors
    ///
    /// Returns [`DatasetError::IndexOutOfBounds`] carrying the requested `index` and the
    /// current length when `index` is not less than the length; the dataset is then left
    /// unchanged.
    pub fn remove(&mut self, index: usize) -> DatasetResult<PreferencePair> {
        let len = self.pairs.len();
        if index >= len {
            return Err(DatasetError::IndexOutOfBounds { index, len });
        }
        Ok(self.pairs.remove(index))
    }

    /// Returns a new dataset holding clones of the pairs for which `predicate` returns
    /// `true`, in their original order. `self` is not modified.
    ///
    /// The predicate is `FnMut`, so it may carry state. It is called once per pair, in
    /// dataset order. Every `Fn` closure is also `FnMut`, so existing callers are
    /// unaffected.
    pub fn filter<F>(&self, mut predicate: F) -> Self
    where
        F: FnMut(&PreferencePair) -> bool,
    {
        Self {
            pairs: self
                .pairs
                .iter()
                .filter(|p| predicate(p))
                .cloned()
                .collect(),
        }
    }

    /// Sums `PreferencePair::estimated_tokens` over all pairs (0 when empty).
    pub fn total_estimated_tokens(&self) -> usize {
        self.pairs.iter().map(|p| p.estimated_tokens()).sum()
    }

    /// Borrows the pairs as a slice, in dataset order.
    pub fn as_slice(&self) -> &[PreferencePair] {
        &self.pairs
    }

    /// Consumes the dataset and returns the underlying vector of pairs.
    pub fn into_inner(self) -> Vec<PreferencePair> {
        self.pairs
    }
}
impl Dataset for PreferenceDataset {
    type Item = PreferencePair;

    fn len(&self) -> usize {
        self.pairs.len()
    }

    fn get(&self, index: usize) -> Option<&PreferencePair> {
        self.pairs.get(index)
    }

    fn iter(&self) -> Box<dyn Iterator<Item = &PreferencePair> + '_> {
        Box::new(self.pairs.iter())
    }

    /// Shuffles the pairs in place with a Fisher–Yates pass driven by a 64-bit linear
    /// congruential generator seeded with `seed`.
    ///
    /// The arithmetic is fully deterministic, so a given seed always yields the same
    /// permutation of a given input order. Datasets with 0 or 1 pairs are untouched.
    fn shuffle(&mut self, seed: u64) {
        const LCG_MULTIPLIER: u64 = 6364136223846793005;
        const LCG_INCREMENT: u64 = 1442695040888963407;

        let len = self.pairs.len();
        if len <= 1 {
            return;
        }
        let mut state = seed;
        for i in (1..len).rev() {
            state = state
                .wrapping_mul(LCG_MULTIPLIER)
                .wrapping_add(LCG_INCREMENT);
            // Keep only the high 31 bits of the state; the low bits of a
            // power-of-two-modulus LCG are the weakest. The `%` has a slight modulo
            // bias, kept so existing seeds continue to produce the same order.
            let j = (state >> 33) as usize % (i + 1);
            self.pairs.swap(i, j);
        }
    }

    /// Splits into `(train, eval)` clones at index `(len as f32 * ratio) as usize`.
    ///
    /// `ratio` is clamped to `[0.0, 1.0]`. A NaN ratio yields index 0: `clamp` returns
    /// NaN, and a NaN-to-integer `as` cast produces 0.
    fn split(&self, ratio: f32) -> (Vec<PreferencePair>, Vec<PreferencePair>) {
        let len = self.pairs.len();
        let ratio = ratio.clamp(0.0, 1.0);
        // The f32 arithmetic is kept so existing splits do not move. However, beyond 2^24
        // items `len as f32` can round *up* past `len`, which would make the slice below
        // panic, so the index is capped at `len`.
        let split_idx = ((len as f32 * ratio) as usize).min(len);
        let (train, eval) = self.pairs.split_at(split_idx);
        (train.to_vec(), eval.to_vec())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::TrainingMessage;

    /// Builds `n` single-turn examples with ids `ex-0` .. `ex-{n-1}`.
    fn sample_examples(n: usize) -> Vec<TrainingExample> {
        (0..n)
            .map(|i| {
                TrainingExample::with_id(
                    format!("ex-{i}"),
                    vec![
                        TrainingMessage::user(format!("Question {i}")),
                        TrainingMessage::assistant(format!("Answer {i}")),
                    ],
                )
            })
            .collect()
    }

    /// Builds a one-turn preference pair whose messages are suffixed with `tag`.
    fn sample_pair(tag: &str) -> PreferencePair {
        PreferencePair::new(
            vec![TrainingMessage::user(format!("Q{tag}"))],
            vec![TrainingMessage::assistant(format!("Good{tag}"))],
            vec![TrainingMessage::assistant(format!("Bad{tag}"))],
        )
    }

    #[test]
    fn test_instruct_dataset_basics() {
        let ds = InstructDataset::new(sample_examples(10));
        assert_eq!(ds.len(), 10);
        assert!(!ds.is_empty());
        assert!(ds.get(0).is_some());
        assert!(ds.get(10).is_none());
    }

    #[test]
    fn test_instruct_dataset_split() {
        let ds = InstructDataset::new(sample_examples(10));
        let (train, eval) = ds.split(0.8);
        assert_eq!(train.len(), 8);
        assert_eq!(eval.len(), 2);
    }

    #[test]
    fn test_instruct_dataset_split_clamps_ratio() {
        let ds = InstructDataset::new(sample_examples(10));
        let (train, eval) = ds.split(1.5);
        assert_eq!((train.len(), eval.len()), (10, 0));
        let (train, eval) = ds.split(-0.5);
        assert_eq!((train.len(), eval.len()), (0, 10));
        // NaN survives `clamp` and casts to index 0.
        let (train, eval) = ds.split(f32::NAN);
        assert_eq!((train.len(), eval.len()), (0, 10));
    }

    #[test]
    fn test_instruct_dataset_split_empty() {
        let ds = InstructDataset::new(Vec::new());
        let (train, eval) = ds.split(0.8);
        assert!(train.is_empty());
        assert!(eval.is_empty());
    }

    #[test]
    fn test_instruct_dataset_shuffle_deterministic() {
        let mut ds1 = InstructDataset::new(sample_examples(20));
        let mut ds2 = InstructDataset::new(sample_examples(20));
        ds1.shuffle(42);
        ds2.shuffle(42);
        for (a, b) in ds1.iter().zip(ds2.iter()) {
            assert_eq!(a.id, b.id);
        }
    }

    #[test]
    fn test_instruct_dataset_shuffle_is_permutation() {
        let mut ds = InstructDataset::new(sample_examples(20));
        let mut before: Vec<_> = ds.iter().map(|e| e.id.clone()).collect();
        ds.shuffle(7);
        let mut after: Vec<_> = ds.iter().map(|e| e.id.clone()).collect();
        before.sort();
        after.sort();
        assert_eq!(before, after);
    }

    #[test]
    fn test_instruct_dataset_shuffle_small_is_noop() {
        let mut empty = InstructDataset::new(Vec::new());
        empty.shuffle(1);
        assert!(empty.is_empty());

        let mut single = InstructDataset::new(sample_examples(1));
        single.shuffle(1);
        assert_eq!(single.len(), 1);
        assert_eq!(single.get(0).unwrap().id, "ex-0");
    }

    #[test]
    fn test_instruct_dataset_filter() {
        let ds = InstructDataset::new(sample_examples(10));
        let filtered = ds.filter(|e| e.id.ends_with('5') || e.id.ends_with('7'));
        assert_eq!(filtered.len(), 2);
    }

    #[test]
    fn test_instruct_dataset_remove() {
        let mut ds = InstructDataset::new(sample_examples(3));
        let removed = ds.remove(0).unwrap();
        assert_eq!(removed.id, "ex-0");
        assert_eq!(ds.len(), 2);
        assert_eq!(ds.get(0).unwrap().id, "ex-1");
        assert!(ds.remove(2).is_err());
        assert_eq!(ds.len(), 2);
    }

    #[test]
    fn test_instruct_dataset_push_extend_and_tokens() {
        let mut ds = InstructDataset::new(Vec::new());
        assert_eq!(ds.total_estimated_tokens(), 0);
        let mut examples = sample_examples(3);
        ds.push(examples.remove(0));
        ds.extend(examples);
        assert_eq!(ds.len(), 3);
        let expected: usize = ds.iter().map(|e| e.estimated_tokens()).sum();
        assert_eq!(ds.total_estimated_tokens(), expected);
    }

    #[test]
    fn test_preference_dataset() {
        let pairs = vec![PreferencePair::new(
            vec![TrainingMessage::user("Q")],
            vec![TrainingMessage::assistant("Good")],
            vec![TrainingMessage::assistant("Bad")],
        )];
        let ds = PreferenceDataset::new(pairs);
        assert_eq!(ds.len(), 1);
        assert!(ds.get(0).is_some());
    }

    #[test]
    fn test_preference_dataset_from_pairs() {
        let pairs = vec![
            PreferencePair::new(
                vec![TrainingMessage::user("Q1")],
                vec![TrainingMessage::assistant("Good1")],
                vec![TrainingMessage::assistant("Bad1")],
            ),
            PreferencePair::new(
                vec![TrainingMessage::user("Q2")],
                vec![TrainingMessage::assistant("Good2")],
                vec![TrainingMessage::assistant("Bad2")],
            ),
        ];
        let ds = PreferenceDataset::from_pairs(pairs);
        assert_eq!(ds.len(), 2);
    }

    #[test]
    fn test_preference_dataset_extend() {
        let mut ds = PreferenceDataset::new(vec![]);
        ds.extend(vec![PreferencePair::new(
            vec![TrainingMessage::user("Q")],
            vec![TrainingMessage::assistant("Good")],
            vec![TrainingMessage::assistant("Bad")],
        )]);
        assert_eq!(ds.len(), 1);
    }

    #[test]
    fn test_preference_dataset_remove() {
        let mut ds = PreferenceDataset::new(vec![PreferencePair::new(
            vec![TrainingMessage::user("Q")],
            vec![TrainingMessage::assistant("Good")],
            vec![TrainingMessage::assistant("Bad")],
        )]);
        let removed = ds.remove(0).unwrap();
        assert_eq!(removed.prompt.len(), 1);
        assert!(ds.is_empty());
        assert!(ds.remove(0).is_err());
    }

    #[test]
    fn test_preference_dataset_filter() {
        let ds = PreferenceDataset::new(vec![
            PreferencePair::new(
                vec![TrainingMessage::user("short")],
                vec![TrainingMessage::assistant("a")],
                vec![TrainingMessage::assistant("b")],
            ),
            PreferencePair::new(
                vec![TrainingMessage::user("this is a longer prompt message")],
                vec![TrainingMessage::assistant("good")],
                vec![TrainingMessage::assistant("bad")],
            ),
        ]);
        let filtered = ds.filter(|p| p.estimated_tokens() > 5);
        assert_eq!(filtered.len(), 1);
    }

    #[test]
    fn test_preference_dataset_split_and_shuffle() {
        let mut ds = PreferenceDataset::from_pairs((0..4).map(|i| sample_pair(&i.to_string())));
        ds.shuffle(3);
        assert_eq!(ds.len(), 4);
        let (train, eval) = ds.split(0.5);
        assert_eq!(train.len(), 2);
        assert_eq!(eval.len(), 2);
        let (train, eval) = ds.split(2.0);
        assert_eq!((train.len(), eval.len()), (4, 0));
    }
}