/// One logged interaction consumed by the off-policy estimators in this file.
///
/// The estimators form an importance weight as `target_prop / logging_prop`
/// and combine it with `reward` and/or `reward_hat`.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct LoggedSample {
    /// Reward recorded for this sample.
    pub reward: f32,
    /// Propensity under the logging policy; the denominator of the importance
    /// weight, rejected by the estimators when it is not positive.
    pub logging_prop: f32,
    /// Propensity under the target policy; the numerator of the importance weight.
    pub target_prop: f32,
    /// Reward estimate for this sample, averaged by `direct_method` and used as
    /// the baseline term in the doubly-robust estimators.
    pub reward_hat: f32,
}
/// Failure modes reported by the off-policy estimators in this file.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum OffPolicyError {
    /// The slice of logged samples was empty.
    EmptyLog,
    /// A sample's logging propensity failed the "must be > 0" check.
    NonPositivePropensity,
    /// The clip threshold failed the "must be > 0" check.
    InvalidClip,
}

impl std::fmt::Display for OffPolicyError {
    /// Writes the fixed, human-readable message for the variant.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match *self {
            Self::EmptyLog => "off-policy log is empty",
            Self::NonPositivePropensity => "logging propensity must be > 0",
            Self::InvalidClip => "clip threshold must be > 0",
        };
        f.write_str(message)
    }
}

impl std::error::Error for OffPolicyError {}

/// Shorthand for results whose error type is [`OffPolicyError`].
pub type OffPolicyResult<T> = Result<T, OffPolicyError>;
/// Importance weight `target_prop / logging_prop` for a single logged sample.
///
/// # Errors
///
/// Returns [`OffPolicyError::NonPositivePropensity`] when `logging_prop` is
/// zero, negative, or NaN.
#[inline]
fn weight(s: &LoggedSample) -> OffPolicyResult<f32> {
    // `NaN <= 0.0` evaluates to false, so NaN has to be rejected explicitly;
    // otherwise it would pass validation and turn every estimate into NaN.
    if s.logging_prop.is_nan() || s.logging_prop <= 0.0 {
        return Err(OffPolicyError::NonPositivePropensity);
    }
    Ok(s.target_prop / s.logging_prop)
}
/// Inverse-propensity-scoring estimate: the mean of `weight * reward` over
/// the log, where each weight is `target_prop / logging_prop`.
///
/// # Errors
///
/// * [`OffPolicyError::EmptyLog`] if `samples` is empty.
/// * [`OffPolicyError::NonPositivePropensity`] if any sample's logging
///   propensity is rejected while computing its weight.
pub fn ips(samples: &[LoggedSample]) -> OffPolicyResult<f32> {
    if samples.is_empty() {
        return Err(OffPolicyError::EmptyLog);
    }
    // Accumulate in f64: an f32 running sum loses low-order bits once it grows
    // large, which visibly biases the mean on logs with many samples.
    let mut total = 0.0_f64;
    for s in samples {
        total += f64::from(weight(s)?) * f64::from(s.reward);
    }
    // `usize -> f64` is exact for any realistic slice length; the final
    // narrowing rounds the mean to the nearest f32.
    Ok((total / samples.len() as f64) as f32)
}
/// Clipped inverse-propensity-scoring estimate: like [`ips`], but every
/// importance weight is capped at `clip` before it multiplies the reward.
///
/// # Errors
///
/// * [`OffPolicyError::EmptyLog`] if `samples` is empty (checked first).
/// * [`OffPolicyError::InvalidClip`] if `clip` is zero, negative, or NaN.
/// * [`OffPolicyError::NonPositivePropensity`] if any sample's logging
///   propensity is rejected while computing its weight.
pub fn ips_clipped(samples: &[LoggedSample], clip: f32) -> OffPolicyResult<f32> {
    if samples.is_empty() {
        return Err(OffPolicyError::EmptyLog);
    }
    // `NaN <= 0.0` is false and `f32::min` ignores a NaN operand, so a NaN
    // threshold would silently disable clipping; reject it up front.
    if clip.is_nan() || clip <= 0.0 {
        return Err(OffPolicyError::InvalidClip);
    }
    // Accumulate in f64 so the running sum keeps its precision on large logs.
    let mut total = 0.0_f64;
    for s in samples {
        let capped = weight(s)?.min(clip);
        total += f64::from(capped) * f64::from(s.reward);
    }
    // `usize -> f64` is exact for any realistic slice length.
    Ok((total / samples.len() as f64) as f32)
}
/// Self-normalised inverse-propensity-scoring estimate:
/// `sum(weight * reward) / sum(weight)`.
///
/// When the absolute sum of the weights is below `1e-12` the ratio is treated
/// as undefined and `0.0` is returned instead.
///
/// # Errors
///
/// * [`OffPolicyError::EmptyLog`] if `samples` is empty.
/// * [`OffPolicyError::NonPositivePropensity`] if any sample's logging
///   propensity is rejected while computing its weight.
pub fn snips(samples: &[LoggedSample]) -> OffPolicyResult<f32> {
    if samples.is_empty() {
        return Err(OffPolicyError::EmptyLog);
    }
    // Both running sums are kept in f64 so that large logs do not lose
    // low-order bits before the ratio is formed.
    let mut weighted_reward = 0.0_f64;
    let mut weight_total = 0.0_f64;
    for s in samples {
        let w = f64::from(weight(s)?);
        weighted_reward += w * f64::from(s.reward);
        weight_total += w;
    }
    // Guard against dividing by a (near-)zero normaliser.
    if weight_total.abs() < 1e-12 {
        return Ok(0.0);
    }
    Ok((weighted_reward / weight_total) as f32)
}
/// Direct-method estimate: the mean of `reward_hat` over the log.
///
/// Propensities and observed rewards are not read by this estimator.
///
/// # Errors
///
/// Returns [`OffPolicyError::EmptyLog`] if `samples` is empty.
pub fn direct_method(samples: &[LoggedSample]) -> OffPolicyResult<f32> {
    if samples.is_empty() {
        return Err(OffPolicyError::EmptyLog);
    }
    // Sum in f64: an f32 running sum drifts noticeably once the log holds
    // hundreds of thousands of samples.
    let total: f64 = samples.iter().map(|s| f64::from(s.reward_hat)).sum();
    // `usize -> f64` is exact for any realistic slice length.
    Ok((total / samples.len() as f64) as f32)
}
/// Doubly-robust estimate: the mean over the log of
/// `reward_hat + weight * (reward - reward_hat)`, where each weight is
/// `target_prop / logging_prop`.
///
/// # Errors
///
/// * [`OffPolicyError::EmptyLog`] if `samples` is empty.
/// * [`OffPolicyError::NonPositivePropensity`] if any sample's logging
///   propensity is rejected while computing its weight.
pub fn doubly_robust(samples: &[LoggedSample]) -> OffPolicyResult<f32> {
    if samples.is_empty() {
        return Err(OffPolicyError::EmptyLog);
    }
    // Accumulate in f64 so the running sum keeps its precision on large logs.
    let mut total = 0.0_f64;
    for s in samples {
        let w = f64::from(weight(s)?);
        let baseline = f64::from(s.reward_hat);
        // Model prediction plus the importance-weighted residual.
        total += baseline + w * (f64::from(s.reward) - baseline);
    }
    // `usize -> f64` is exact for any realistic slice length.
    Ok((total / samples.len() as f64) as f32)
}
/// Clipped doubly-robust estimate: like [`doubly_robust`], but every
/// importance weight is capped at `clip` before it scales the residual
/// `reward - reward_hat`.
///
/// # Errors
///
/// * [`OffPolicyError::EmptyLog`] if `samples` is empty (checked first).
/// * [`OffPolicyError::InvalidClip`] if `clip` is zero, negative, or NaN.
/// * [`OffPolicyError::NonPositivePropensity`] if any sample's logging
///   propensity is rejected while computing its weight.
pub fn doubly_robust_clipped(samples: &[LoggedSample], clip: f32) -> OffPolicyResult<f32> {
    if samples.is_empty() {
        return Err(OffPolicyError::EmptyLog);
    }
    // `NaN <= 0.0` is false and `f32::min` ignores a NaN operand, so a NaN
    // threshold would silently disable clipping; reject it up front.
    if clip.is_nan() || clip <= 0.0 {
        return Err(OffPolicyError::InvalidClip);
    }
    // Accumulate in f64 so the running sum keeps its precision on large logs.
    let mut total = 0.0_f64;
    for s in samples {
        let capped = f64::from(weight(s)?.min(clip));
        let baseline = f64::from(s.reward_hat);
        total += baseline + capped * (f64::from(s.reward) - baseline);
    }
    // `usize -> f64` is exact for any realistic slice length.
    Ok((total / samples.len() as f64) as f32)
}
/// Effective sample size of the importance weights:
/// `(sum of weights)^2 / (sum of squared weights)`.
///
/// The ratio does not change when all weights are scaled by a common factor,
/// so only an exactly zero sum of squares (every weight is zero) is mapped to
/// `0.0`; uniformly small but non-zero weights still yield the sample count.
///
/// # Errors
///
/// * [`OffPolicyError::EmptyLog`] if `samples` is empty.
/// * [`OffPolicyError::NonPositivePropensity`] if any sample's logging
///   propensity is rejected while computing its weight.
pub fn effective_sample_size(samples: &[LoggedSample]) -> OffPolicyResult<f32> {
    if samples.is_empty() {
        return Err(OffPolicyError::EmptyLog);
    }
    // f64 accumulators: the square of any non-zero f32 weight is representable
    // in f64, so tiny weights cannot underflow to zero, and large logs keep
    // their precision.
    let mut sum = 0.0_f64;
    let mut sum_sq = 0.0_f64;
    for s in samples {
        let w = f64::from(weight(s)?);
        sum += w;
        sum_sq += w * w;
    }
    // A sum of squares is never negative; `<=` just avoids an exact float `==`.
    if sum_sq <= 0.0 {
        return Ok(0.0);
    }
    Ok((sum * sum / sum_sq) as f32)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Positional shorthand for building a [`LoggedSample`] fixture.
    fn logged(reward: f32, logging_prop: f32, target_prop: f32, reward_hat: f32) -> LoggedSample {
        LoggedSample {
            reward,
            logging_prop,
            target_prop,
            reward_hat,
        }
    }

    // ---- ips -------------------------------------------------------------

    /// With identical logging and target propensities every weight is 1, so
    /// the estimate is the plain mean reward.
    #[test]
    fn ips_on_policy_recovers_mean_reward() {
        let data = [
            logged(1.0, 0.5, 0.5, 0.0),
            logged(0.0, 0.5, 0.5, 0.0),
            logged(1.0, 0.5, 0.5, 0.0),
        ];
        let v = ips(&data).expect("ips");
        assert!((v - 2.0 / 3.0).abs() < 1e-5, "got {v}");
    }

    /// A single sample with weight 0.8 / 0.4 = 2 and reward 1 yields 2.
    #[test]
    fn ips_reweights_correctly() {
        let data = [logged(1.0, 0.4, 0.8, 0.0)];
        let v = ips(&data).expect("ips");
        assert!((v - 2.0).abs() < 1e-5, "got {v}");
    }

    #[test]
    fn ips_empty_errors() {
        assert_eq!(ips(&[]), Err(OffPolicyError::EmptyLog));
    }

    #[test]
    fn ips_zero_propensity_errors() {
        let data = [logged(1.0, 0.0, 0.5, 0.0)];
        assert_eq!(ips(&data), Err(OffPolicyError::NonPositivePropensity));
    }

    // ---- ips_clipped -----------------------------------------------------

    /// Raw weight 9 is capped at the clip threshold 3.
    #[test]
    fn clipped_ips_caps_weight() {
        let data = [logged(1.0, 0.1, 0.9, 0.0)];
        let v = ips_clipped(&data, 3.0).expect("clipped");
        assert!((v - 3.0).abs() < 1e-5, "got {v}");
    }

    #[test]
    fn clipped_ips_invalid_clip_errors() {
        let data = [logged(1.0, 0.5, 0.5, 0.0)];
        assert_eq!(ips_clipped(&data, 0.0), Err(OffPolicyError::InvalidClip));
        assert_eq!(ips_clipped(&[], 1.0), Err(OffPolicyError::EmptyLog));
    }

    // ---- snips -----------------------------------------------------------

    /// Weights 2 and 0.5 with rewards 1 and 0 give 2 / 2.5 = 0.8.
    #[test]
    fn snips_normalises() {
        let data = [logged(1.0, 0.4, 0.8, 0.0), logged(0.0, 0.8, 0.4, 0.0)];
        let v = snips(&data).expect("snips");
        assert!((v - 0.8).abs() < 1e-5, "got {v}");
    }

    /// All weights are zero, so the normaliser is zero and 0 is returned.
    #[test]
    fn snips_zero_weight_returns_zero() {
        let data = [logged(1.0, 0.5, 0.0, 0.0), logged(0.0, 0.5, 0.0, 0.0)];
        let v = snips(&data).expect("snips");
        assert!(v.abs() < 1e-7, "got {v}");
    }

    #[test]
    fn snips_empty_errors() {
        assert_eq!(snips(&[]), Err(OffPolicyError::EmptyLog));
    }

    // ---- direct_method ---------------------------------------------------

    #[test]
    fn direct_method_averages_rhat() {
        let data = [
            logged(0.0, 0.5, 0.5, 0.2),
            logged(0.0, 0.5, 0.5, 0.4),
            logged(0.0, 0.5, 0.5, 0.6),
        ];
        let v = direct_method(&data).expect("dm");
        assert!((v - 0.4).abs() < 1e-5, "got {v}");
    }

    #[test]
    fn direct_method_empty_errors() {
        assert_eq!(direct_method(&[]), Err(OffPolicyError::EmptyLog));
    }

    // ---- doubly_robust ---------------------------------------------------

    /// When `reward_hat == reward` the residual term vanishes.
    #[test]
    fn dr_perfect_reward_model_equals_dm() {
        let data = [logged(1.0, 0.3, 0.9, 1.0), logged(0.0, 0.7, 0.2, 0.0)];
        let dr = doubly_robust(&data).expect("dr");
        let dm = direct_method(&data).expect("dm");
        assert!((dr - dm).abs() < 1e-5, "dr {dr} should equal dm {dm}");
    }

    /// When `reward_hat == 0` only the weighted reward term remains.
    #[test]
    fn dr_zero_reward_model_equals_ips() {
        let data = [
            logged(1.0, 0.4, 0.8, 0.0),
            logged(0.0, 0.5, 0.5, 0.0),
            logged(1.0, 0.25, 0.5, 0.0),
        ];
        let dr = doubly_robust(&data).expect("dr");
        let ips_v = ips(&data).expect("ips");
        assert!(
            (dr - ips_v).abs() < 1e-5,
            "dr {dr} should equal ips {ips_v}"
        );
    }

    #[test]
    fn dr_empty_and_bad_propensity_errors() {
        assert_eq!(doubly_robust(&[]), Err(OffPolicyError::EmptyLog));
        let data = [logged(1.0, -0.1, 0.5, 0.0)];
        assert_eq!(
            doubly_robust(&data),
            Err(OffPolicyError::NonPositivePropensity)
        );
    }

    /// Unit weights stay far below the threshold of 10, so clipping is a no-op.
    #[test]
    fn dr_clipped_matches_dr_when_below_threshold() {
        let data = [logged(1.0, 0.5, 0.5, 0.3), logged(0.0, 0.5, 0.5, 0.1)];
        let dr = doubly_robust(&data).expect("dr");
        let drc = doubly_robust_clipped(&data, 10.0).expect("drc");
        assert!((dr - drc).abs() < 1e-5, "dr {dr} vs clipped {drc}");
        assert_eq!(
            doubly_robust_clipped(&data, 0.0),
            Err(OffPolicyError::InvalidClip)
        );
    }

    // ---- effective_sample_size -------------------------------------------

    #[test]
    fn ess_uniform_weights_equals_n() {
        let data = [
            logged(1.0, 0.5, 0.5, 0.0),
            logged(0.0, 0.5, 0.5, 0.0),
            logged(1.0, 0.5, 0.5, 0.0),
            logged(0.0, 0.5, 0.5, 0.0),
        ];
        let ess = effective_sample_size(&data).expect("ess");
        assert!((ess - 4.0).abs() < 1e-4, "got {ess}");
    }

    /// One sample carries weight 100 while the other two carry weight 0.
    #[test]
    fn ess_concentrated_weights_low() {
        let data = [
            logged(1.0, 0.01, 1.0, 0.0),
            logged(0.0, 1.0, 0.0, 0.0),
            logged(0.0, 1.0, 0.0, 0.0),
        ];
        let ess = effective_sample_size(&data).expect("ess");
        assert!(ess < 1.5, "expected low ESS, got {ess}");
    }

    // ---- error formatting ------------------------------------------------

    #[test]
    fn error_display_messages() {
        assert!(OffPolicyError::EmptyLog.to_string().contains("empty"));
        assert!(
            OffPolicyError::NonPositivePropensity
                .to_string()
                .contains("propensity")
        );
        assert!(OffPolicyError::InvalidClip.to_string().contains("clip"));
    }
}