use super::*;
#[test]
fn test_rlvr_config_default() {
let cfg = RlvrConfig::default();
assert!((cfg.learning_rate - 1e-4).abs() < 1e-10);
assert!((cfg.kl_coeff - 0.1).abs() < 1e-10);
assert!((cfg.reward_scale - 1.0).abs() < 1e-10);
assert_eq!(cfg.max_response_len, 512);
assert_eq!(cfg.num_samples, 4);
}
#[test]
fn test_rlvr_config_validate_ok() {
assert!(RlvrConfig::default().validate().is_ok());
}
#[test]
fn test_rlvr_config_validate_bad_lr() {
let cfg = RlvrConfig {
learning_rate: -0.01,
..Default::default()
};
assert!(cfg.validate().is_err());
}
#[test]
fn test_rlvr_config_validate_bad_kl() {
let cfg = RlvrConfig {
kl_coeff: -1.0,
..Default::default()
};
assert!(cfg.validate().is_err());
}
#[test]
fn test_policy_gradient_positive_reward() {
let loss = RlvrLoss::new(RlvrConfig::default());
let log_probs = vec![-1.0, -2.0, -1.5];
let rewards = vec![1.0, 1.0, 1.0];
let pg = loss.compute_policy_gradient(&log_probs, &rewards);
assert!((pg - 1.5).abs() < 1e-10);
}
#[test]
fn test_policy_gradient_zero_reward() {
let loss = RlvrLoss::new(RlvrConfig::default());
let log_probs = vec![-1.0, -2.0];
let rewards = vec![0.0, 0.0];
let pg = loss.compute_policy_gradient(&log_probs, &rewards);
assert!((pg - 0.0).abs() < 1e-10);
}
#[test]
fn test_policy_gradient_empty() {
let loss = RlvrLoss::new(RlvrConfig::default());
let pg = loss.compute_policy_gradient(&[], &[]);
assert!((pg - 0.0).abs() < 1e-10);
}
#[test]
fn test_kl_penalty_identical() {
let loss = RlvrLoss::new(RlvrConfig::default());
let probs = vec![-1.0, -2.0, -3.0];
let kl = loss.compute_kl_penalty(&probs, &probs);
assert!((kl - 0.0).abs() < 1e-10);
}
#[test]
fn test_kl_penalty_different() {
let loss = RlvrLoss::new(RlvrConfig::default());
let policy = vec![-1.0, -2.0];
let reference = vec![-2.0, -3.0];
let kl = loss.compute_kl_penalty(&policy, &reference);
assert!((kl - 1.0).abs() < 1e-10);
}
#[test]
fn test_total_loss() {
let loss = RlvrLoss::new(RlvrConfig {
kl_coeff: 0.1,
..Default::default()
});
let log_probs = vec![-1.0, -2.0];
let rewards = vec![1.0, 1.0];
let ref_probs = vec![-2.0, -3.0];
let total = loss.compute_total_loss(&log_probs, &rewards, &ref_probs);
assert!((total - 1.6).abs() < 1e-10);
}
#[test]
fn test_math_reward_correct() {
let reward = MathReward;
let result = reward.verify("What is 2+2? expected: 4", "The answer is 4.");
assert!(result.correct);
assert!(result.score > 0.0);
}
#[test]
fn test_math_reward_incorrect() {
let reward = MathReward;
let result = reward.verify("What is 2+2? expected: 4", "The answer is 5.");
assert!(!result.correct);
}
#[test]
fn test_code_reward_correct() {
let reward = CodeReward;
let result = reward.verify(
"Write hello world. must contain: hello world",
"fn main() { println!(\"hello world\"); }",
);
assert!(result.correct);
}
#[test]
fn test_rlvr_metrics_from_batch() {
let loss = RlvrLoss::new(RlvrConfig::default());
let log_probs = vec![-1.0, -2.0];
let rewards = vec![1.0, 0.0];
let ref_probs = vec![-1.0, -2.0];
let metrics = RlvrMetrics::from_batch(&loss, &log_probs, &rewards, &ref_probs);
assert_eq!(metrics.num_samples, 2);
assert!(metrics.avg_reward >= 0.0);
assert!(metrics.avg_loss.is_finite());
}
#[test]
fn test_rlvr_metrics_empty() {
let loss = RlvrLoss::new(RlvrConfig::default());
let metrics = RlvrMetrics::from_batch(&loss, &[], &[], &[]);
assert_eq!(metrics.num_samples, 0);
}
#[test]
fn falsify_rlvr_001_policy_gradient_finite() {
let loss = RlvrLoss::new(RlvrConfig::default());
for r in [-10.0, -1.0, 0.0, 1.0, 10.0] {
let pg = loss.compute_policy_gradient(&[-1.0, -2.0], &[r, r]);
assert!(
pg.is_finite(),
"Policy gradient must be finite for reward={}",
r
);
}
}
#[test]
fn falsify_rlvr_002_kl_nonnegative() {
let loss = RlvrLoss::new(RlvrConfig::default());
for offset in [-5.0, -1.0, 0.0, 1.0, 5.0] {
let policy = vec![-1.0 + offset, -2.0 + offset];
let reference = vec![-1.0, -2.0];
let kl = loss.compute_kl_penalty(&policy, &reference);
assert!(kl.is_finite(), "KL must be finite for offset={}", offset);
}
}
#[test]
fn falsify_rlvr_003_reward_bounded() {
let math = MathReward;
let code = CodeReward;
for prompt in ["2+2=?", "x=1", "hello"] {
for response in ["4", "wrong", "fn main() {}"] {
let r1 = math.verify(prompt, response);
let r2 = code.verify(prompt, response);
assert!(
r1.score >= 0.0 && r1.score <= 1.0,
"Math score out of range"
);
assert!(
r2.score >= 0.0 && r2.score <= 1.0,
"Code score out of range"
);
}
}
}