use crate::error::{AprenderError, Result};
/// Settings for a CPT training run.
///
/// Within this file only `learning_rate`, `warmup_steps` and `total_steps` are
/// read (by `CptSchedule::new`). The other fields are carried for code that
/// lives elsewhere, so their descriptions are inferred from names and defaults
/// and are marked as such.
#[derive(Debug, Clone)]
pub struct CptConfig {
    /// Base learning rate. `validate` requires it to be positive and finite.
    pub learning_rate: f64,
    /// Number of warm-up steps at the start of the schedule.
    pub warmup_steps: usize,
    /// Total number of training steps. `validate` requires it to be non-zero.
    pub total_steps: usize,
    /// Sequence length (presumably in tokens). `validate` requires it to be
    /// non-zero.
    pub seq_length: usize,
    /// Domain/general mixing ratio. `validate` requires it to lie in
    /// `[0.0, 1.0]`; presumably the share of domain samples — confirm with the
    /// caller that builds the mixer.
    pub domain_mix_ratio: f64,
    /// Weight-decay coefficient. Not validated or read in this file.
    pub weight_decay: f64,
    /// Seed value. Not validated or read in this file.
    pub seed: u64,
    /// Presumably the gradient-clipping threshold. Not validated or read in
    /// this file.
    pub max_grad_norm: f64,
    /// Presumably the replay-buffer capacity. Not validated or read in this
    /// file.
    pub replay_buffer_size: usize,
}

impl Default for CptConfig {
    /// Returns the baseline configuration; it passes [`CptConfig::validate`].
    fn default() -> Self {
        Self {
            learning_rate: 2e-5,
            warmup_steps: 100,
            total_steps: 1000,
            seq_length: 512,
            domain_mix_ratio: 0.7,
            weight_decay: 0.01,
            seed: 42,
            max_grad_norm: 1.0,
            replay_buffer_size: 0,
        }
    }
}

impl CptConfig {
    /// Checks the fields that have hard constraints.
    ///
    /// Checks run in a fixed order and the first violation is reported:
    /// `learning_rate`, `total_steps`, `domain_mix_ratio`, `seq_length`.
    /// `warmup_steps`, `weight_decay`, `seed`, `max_grad_norm` and
    /// `replay_buffer_size` are not checked.
    ///
    /// # Errors
    ///
    /// Returns `AprenderError::FormatError` when
    /// - `learning_rate` is not a positive finite number,
    /// - `total_steps` is zero,
    /// - `domain_mix_ratio` is NaN or outside `[0.0, 1.0]`,
    /// - `seq_length` is zero.
    pub fn validate(&self) -> Result<()> {
        if self.learning_rate <= 0.0 || !self.learning_rate.is_finite() {
            return Err(AprenderError::FormatError {
                message: format!(
                    "learning_rate must be positive finite, got {}",
                    self.learning_rate
                ),
            });
        }
        if self.total_steps == 0 {
            return Err(AprenderError::FormatError {
                message: "total_steps must be > 0".to_string(),
            });
        }
        // `contains` is false for NaN, so a NaN ratio is rejected here. The
        // pair of comparisons `< 0.0 || > 1.0` would let NaN through because
        // every ordered comparison with NaN is false.
        if !(0.0..=1.0).contains(&self.domain_mix_ratio) {
            return Err(AprenderError::FormatError {
                message: format!(
                    "domain_mix_ratio must be in [0.0, 1.0], got {}",
                    self.domain_mix_ratio
                ),
            });
        }
        if self.seq_length == 0 {
            return Err(AprenderError::FormatError {
                message: "seq_length must be > 0".to_string(),
            });
        }
        Ok(())
    }
}
/// Learning-rate schedule: linear warm-up followed by cosine decay.
#[derive(Debug, Clone)]
pub struct CptSchedule {
    /// Learning rate reached at the end of warm-up.
    base_lr: f64,
    /// Number of steps over which the rate ramps up linearly from zero.
    warmup_steps: usize,
    /// Step at which the cosine decay reaches its minimum.
    total_steps: usize,
}

impl CptSchedule {
    /// Builds a schedule from `config.learning_rate`, `config.warmup_steps`
    /// and `config.total_steps`. The config is not validated here.
    #[must_use]
    pub fn new(config: &CptConfig) -> Self {
        Self {
            base_lr: config.learning_rate,
            warmup_steps: config.warmup_steps,
            total_steps: config.total_steps,
        }
    }

    /// Returns the learning rate for the given zero-based `step`.
    ///
    /// * `step < warmup_steps`: `base_lr * step / warmup_steps`, so step 0
    ///   yields `0.0`.
    /// * otherwise: `base_lr * 0.5 * (1 + cos(pi * progress))`, where
    ///   `progress` is the fraction of the decay span
    ///   (`total_steps - warmup_steps`) already covered, clamped to `1.0`.
    ///   Steps at or beyond `total_steps` therefore return (approximately)
    ///   zero.
    ///
    /// If `warmup_steps >= total_steps` the decay span is treated as a single
    /// step instead of underflowing: the rate is `base_lr` at
    /// `step == warmup_steps` and zero afterwards.
    #[must_use]
    pub fn lr_at_step(&self, step: usize) -> f64 {
        if step < self.warmup_steps {
            // `step < warmup_steps` implies `warmup_steps >= 1`, so the
            // division is always well defined.
            return self.base_lr * (step as f64 / self.warmup_steps as f64);
        }
        // Saturate so that a warm-up longer than the whole run cannot
        // underflow (panic in debug builds, wrap-around in release builds).
        let decay_steps = self.total_steps.saturating_sub(self.warmup_steps).max(1);
        let progress = ((step - self.warmup_steps) as f64 / decay_steps as f64).min(1.0);
        self.base_lr * 0.5 * (1.0 + (std::f64::consts::PI * progress).cos())
    }
}
/// Deterministic chooser that interleaves "domain" and "general" sequences.
#[derive(Debug, Clone)]
pub struct DataMixer {
    /// Threshold compared against each pseudo-random draw in `[0, 1)`.
    domain_mix_ratio: f64,
    /// Base value combined with the draw counter to derive each draw.
    seed: u64,
    /// Number of draws made so far.
    step: usize,
}

impl DataMixer {
    /// Multiplier of the linear-congruential mixing step.
    const LCG_MULTIPLIER: u64 = 6364136223846793005;
    /// Increment of the linear-congruential mixing step.
    const LCG_INCREMENT: u64 = 1442695040888963407;

    /// Creates a mixer with the given ratio and seed; no draws made yet.
    /// The ratio is stored as-is, without range checking.
    #[must_use]
    pub fn new(domain_mix_ratio: f64, seed: u64) -> Self {
        Self {
            domain_mix_ratio,
            seed,
            step: 0,
        }
    }

    /// Makes one draw and reports whether it selects the domain source.
    ///
    /// The draw counter is advanced first; the draw itself depends only on
    /// `seed` and the counter, so the sequence is reproducible. The result is
    /// `true` when the draw (a value in `[0, 1)`) is below the mix ratio.
    pub fn next_is_domain(&mut self) -> bool {
        self.step += 1;
        let counter = self.step as u64;
        let mixed = self
            .seed
            .wrapping_add(counter)
            .wrapping_mul(Self::LCG_MULTIPLIER)
            .wrapping_add(Self::LCG_INCREMENT);
        // Keep the top 31 bits and scale them into [0, 1).
        let top_bits = mixed >> 33;
        let unit = top_bits as f64 / (1u64 << 31) as f64;
        unit < self.domain_mix_ratio
    }

    /// Assembles up to `batch_size` sequences from the two sources.
    ///
    /// One draw is made per slot (even once both sources are used up). Each
    /// slot takes the next unused sequence from the drawn source, falling
    /// back to the other source when the drawn one is exhausted. Sequences
    /// are consumed in order from the start of each slice and cloned into the
    /// result. When both sources run out the batch is returned short.
    pub fn mix_batches(
        &mut self,
        domain_tokens: &[Vec<u32>],
        general_tokens: &[Vec<u32>],
        batch_size: usize,
    ) -> Vec<Vec<u32>> {
        let mut mixed = Vec::with_capacity(batch_size);
        let mut domain_iter = domain_tokens.iter();
        let mut general_iter = general_tokens.iter();
        for _ in 0..batch_size {
            let prefer_domain = self.next_is_domain();
            let picked = if prefer_domain {
                domain_iter.next().or_else(|| general_iter.next())
            } else {
                general_iter.next().or_else(|| domain_iter.next())
            };
            if let Some(sequence) = picked {
                mixed.push(sequence.clone());
            }
        }
        mixed
    }

    /// Returns the mix ratio this mixer was created with.
    #[must_use]
    pub fn ratio(&self) -> f64 {
        self.domain_mix_ratio
    }
}
/// Fixed-capacity ring buffer of token sequences.
///
/// Once `capacity` sequences are stored, each new one overwrites the oldest.
#[derive(Debug, Clone)]
pub struct ReplayBuffer {
    /// Stored sequences; never longer than `capacity`.
    buffer: Vec<Vec<u32>>,
    /// Maximum number of sequences kept.
    capacity: usize,
    /// Count of sequences accepted so far; `insert_idx % capacity` is the
    /// slot that the next overwrite targets.
    insert_idx: usize,
}

impl ReplayBuffer {
    /// Creates an empty buffer that keeps at most `capacity` sequences.
    ///
    /// A capacity of zero is allowed and yields a buffer that stores nothing.
    #[must_use]
    pub fn new(capacity: usize) -> Self {
        Self {
            buffer: Vec::with_capacity(capacity),
            capacity,
            insert_idx: 0,
        }
    }

    /// Stores `tokens`, overwriting the oldest entry when the buffer is full.
    ///
    /// With a capacity of zero the sequence is dropped and nothing changes.
    pub fn add(&mut self, tokens: Vec<u32>) {
        // Without this guard the overwrite branch would compute
        // `insert_idx % 0` and panic.
        if self.capacity == 0 {
            return;
        }
        if self.buffer.len() < self.capacity {
            self.buffer.push(tokens);
        } else {
            let slot = self.insert_idx % self.capacity;
            self.buffer[slot] = tokens;
        }
        self.insert_idx += 1;
    }

    /// Returns `n` sequences drawn with replacement from the buffer.
    ///
    /// The draws are a pure function of `seed` and the buffer contents, so
    /// the same seed reproduces the same sample. An empty buffer yields an
    /// empty vector regardless of `n`.
    #[must_use]
    pub fn sample(&self, n: usize, seed: u64) -> Vec<Vec<u32>> {
        if self.buffer.is_empty() {
            return vec![];
        }
        let stored = self.buffer.len();
        let mut state = seed;
        (0..n)
            .map(|_| {
                // One linear-congruential step per draw; the high bits pick
                // the slot.
                state = state
                    .wrapping_mul(6364136223846793005)
                    .wrapping_add(1442695040888963407);
                let idx = (state >> 33) as usize % stored;
                self.buffer[idx].clone()
            })
            .collect()
    }

    /// Number of sequences currently stored.
    #[must_use]
    pub fn len(&self) -> usize {
        self.buffer.len()
    }

    /// Returns `true` when no sequence is stored.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.buffer.is_empty()
    }

    /// Maximum number of sequences the buffer keeps.
    #[must_use]
    pub fn capacity(&self) -> usize {
        self.capacity
    }
}
/// Running counters and statistics for a training run.
#[derive(Debug, Clone)]
pub struct CptProgress {
    /// Number of `update` calls recorded so far.
    pub step: usize,
    /// Step count at which the run is considered finished.
    pub total_steps: usize,
    /// Learning rate passed to the most recent `update`.
    pub current_lr: f64,
    /// Smoothed loss (see [`CptProgress::update`]).
    pub avg_loss: f64,
    /// Number of updates flagged as domain samples.
    pub domain_samples: usize,
    /// Number of updates not flagged as domain samples.
    pub general_samples: usize,
    /// Replay-sample counter. `update` never touches it; callers maintain it.
    pub replay_samples: usize,
}

impl CptProgress {
    /// Creates a tracker at step 0 with all statistics zeroed.
    #[must_use]
    pub fn new(total_steps: usize) -> Self {
        Self {
            step: 0,
            total_steps,
            current_lr: 0.0,
            avg_loss: 0.0,
            domain_samples: 0,
            general_samples: 0,
            replay_samples: 0,
        }
    }

    /// Records one training step.
    ///
    /// * `lr` becomes `current_lr`.
    /// * `loss` is folded into `avg_loss`: a plain arithmetic mean over the
    ///   first 100 steps, then an exponential moving average with a decay of
    ///   `0.99`.
    /// * `domain` selects which sample counter is incremented.
    pub fn update(&mut self, lr: f64, loss: f64, domain: bool) {
        self.step += 1;
        self.current_lr = lr;
        // `step` is already 1-based here, so `1 - 1/step` gives the first loss
        // full weight (alpha = 0) instead of averaging it with the zero
        // initial value, which would halve the first reported loss.
        let alpha = 0.99_f64.min(1.0 - 1.0 / self.step as f64);
        self.avg_loss = alpha * self.avg_loss + (1.0 - alpha) * loss;
        if domain {
            self.domain_samples += 1;
        } else {
            self.general_samples += 1;
        }
    }

    /// Completed share of the run, `step / total_steps`.
    ///
    /// A `total_steps` of zero is treated as one to avoid dividing by zero.
    /// The value exceeds `1.0` if more than `total_steps` updates are
    /// recorded.
    #[must_use]
    pub fn fraction(&self) -> f64 {
        self.step as f64 / self.total_steps.max(1) as f64
    }

    /// Returns `true` once at least `total_steps` updates have been recorded.
    #[must_use]
    pub fn is_done(&self) -> bool {
        self.step >= self.total_steps
    }
}
// Unit tests are kept in the separate file `cpt_tests.rs` and are compiled
// only for test builds.
#[cfg(test)]
#[path = "cpt_tests.rs"]
mod tests;