use rand::RngExt;
use rand::SeedableRng;
use rand::rngs::StdRng;
use rlevo_core::base::{Action, Reward, State};
use rlevo_core::environment::{ConstructableEnv, Environment, EnvironmentError, SnapshotBase};
use rlevo_core::reward::ScalarReward;
use serde::{Deserialize, Serialize};
use std::f32::consts::TAU;
use std::fmt::{Display, Formatter};
use std::str::FromStr;
use super::k_armed::{KArmedBanditAction, KArmedBanditObservation, KArmedBanditState};
/// Configuration for [`AdversarialBandit`].
///
/// Besides being built directly, it can be parsed from text through the `FromStr` impl on this
/// type: either a bare integer (taken as `max_steps`) or comma-separated `key=value` pairs.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdversarialBanditConfig {
/// Step count at which `step` starts reporting the episode as terminated.
pub max_steps: usize,
/// Seed for the `StdRng` that draws the per-arm phase offsets.
pub seed: u64,
/// Length of one reward cycle, in steps. Wherever the period is used it is clamped to at
/// least 1, so `0` behaves like `1`.
pub period: usize,
/// Scale of the reward: each reward is `amplitude * 0.5 * (1 + cos θ)`.
pub amplitude: f32,
}
impl Default for AdversarialBanditConfig {
    /// Returns the stock configuration: 500 steps, seed 42, a 10-step period and unit amplitude.
    fn default() -> Self {
        const MAX_STEPS: usize = 500;
        const SEED: u64 = 42;
        const PERIOD: usize = 10;
        const AMPLITUDE: f32 = 1.0;

        Self {
            max_steps: MAX_STEPS,
            seed: SEED,
            period: PERIOD,
            amplitude: AMPLITUDE,
        }
    }
}
impl FromStr for AdversarialBanditConfig {
    type Err = String;

    /// Parses a configuration from text.
    ///
    /// Two forms are accepted (surrounding whitespace is ignored):
    ///
    /// * a bare unsigned integer, which sets `max_steps` and leaves every other field at its
    ///   default;
    /// * comma-separated `key=value` pairs with keys `max_steps`, `seed`, `period` and
    ///   `amplitude`. Fields that are not mentioned keep their defaults and empty segments
    ///   (for example a trailing comma) are skipped.
    ///
    /// # Errors
    ///
    /// Returns a message when a segment has no `=`, when a key is unknown, when a value does
    /// not parse as the field's type, or when the input contains no `key=value` pair at all.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // The "bad shape" message quotes the caller's input exactly as given (untrimmed).
        let malformed = || {
            format!(
                "Invalid AdversarialBanditConfig format. Expected either a number or 'key=value' pairs, got: {s}"
            )
        };

        let input = s.trim();

        // Shorthand form: a lone integer is the step budget.
        if let Ok(max_steps) = input.parse::<usize>() {
            return Ok(Self {
                max_steps,
                ..Self::default()
            });
        }

        let mut parsed = Self::default();
        let mut assignments = 0_usize;

        let entries = input
            .split(',')
            .map(str::trim)
            .filter(|entry| !entry.is_empty());

        for entry in entries {
            // Split at the first '='; anything after it belongs to the value.
            let (key, raw) = match entry.split_once('=') {
                Some((lhs, rhs)) => (lhs.trim(), rhs.trim()),
                None => return Err(malformed()),
            };

            match key {
                "max_steps" => {
                    parsed.max_steps = raw
                        .parse::<usize>()
                        .map_err(|e| format!("Failed to parse max_steps value: {e}"))?;
                }
                "seed" => {
                    parsed.seed = raw
                        .parse::<u64>()
                        .map_err(|e| format!("Failed to parse seed value: {e}"))?;
                }
                "period" => {
                    parsed.period = raw
                        .parse::<usize>()
                        .map_err(|e| format!("Failed to parse period value: {e}"))?;
                }
                "amplitude" => {
                    parsed.amplitude = raw
                        .parse::<f32>()
                        .map_err(|e| format!("Failed to parse amplitude value: {e}"))?;
                }
                other => {
                    return Err(format!(
                        "Unknown AdversarialBanditConfig key {other:?} (expected max_steps, seed, period, or amplitude)"
                    ));
                }
            }
            assignments += 1;
        }

        // Input made only of separators/whitespace carries no information: reject it.
        if assignments == 0 {
            Err(malformed())
        } else {
            Ok(parsed)
        }
    }
}
/// `K`-armed bandit whose per-arm reward follows a deterministic cosine schedule.
///
/// The reward of an arm at step `t` is `amplitude * 0.5 * (1 + cos(2π (t + phase) / period))`,
/// where `phase` is a per-arm offset drawn from the seeded RNG at construction and again on
/// every `reset`.
#[derive(Debug)]
pub struct AdversarialBandit<const K: usize> {
/// Bandit state whose `observe()` supplies the observation placed in each snapshot.
state: KArmedBanditState,
/// Number of valid `step` calls since construction or the last `reset`.
steps: usize,
/// Set once `steps` reaches `config.max_steps`; cleared by `reset`.
done: bool,
/// Settings this environment was built with.
config: AdversarialBanditConfig,
/// RNG seeded from `config.seed`; used to draw `phases`.
rng: StdRng,
/// Per-arm phase offsets, each drawn from `0..period` (with the period clamped to at least 1).
phases: [usize; K],
}
impl<const K: usize> Display for AdversarialBandit<K> {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(
f,
"AdversarialBandit<{K}>(step={}/{}, period={}, done={})",
self.steps, self.config.max_steps, self.config.period, self.done
)
}
}
impl<const K: usize> AdversarialBandit<K> {
    /// Builds an environment from the default configuration with `seed` substituted in.
    pub fn with_seed(seed: u64) -> Self {
        let config = AdversarialBanditConfig {
            seed,
            ..AdversarialBanditConfig::default()
        };
        Self::with_config(config)
    }

    /// Builds an environment from an explicit configuration.
    ///
    /// The RNG is seeded from `config.seed` and immediately used to draw one phase offset per
    /// arm. A `period` of `0` is treated as `1` so the phase range is never empty.
    pub fn with_config(config: AdversarialBanditConfig) -> Self {
        let mut rng = StdRng::seed_from_u64(config.seed);
        let phases = sample_phases::<K>(&mut rng, config.period.max(1));
        Self {
            state: KArmedBanditState,
            steps: 0,
            done: false,
            config,
            rng,
            phases,
        }
    }

    /// Per-arm phase offsets currently in effect.
    #[must_use]
    pub fn phases(&self) -> &[usize; K] {
        &self.phases
    }

    /// Reward paid by `arm` at step `t`:
    /// `amplitude * 0.5 * (1 + cos(2π * ((t + phase) mod period) / period))`.
    ///
    /// Indexes `phases` directly, so `arm` must be below `K`.
    fn reward_at(&self, arm: usize, t: usize) -> f32 {
        let period = self.config.period.max(1);
        // Reduce modulo the period in integer arithmetic *before* converting to `f32`.
        // Converting the raw `t + phase` loses integer precision past 2^24 and feeds `cos` an
        // ever-growing argument, so the schedule slowly drifts instead of repeating exactly.
        // Reducing each term first also keeps the sum from overflowing for very large `t`.
        let offset = (t % period + self.phases[arm] % period) % period;
        let theta = TAU * (offset as f32) / (period as f32);
        self.config.amplitude * 0.5 * (1.0 + theta.cos())
    }
}
/// Draws one phase offset per arm from `0..period`, consuming `K` values from `rng` in arm order.
///
/// Callers in this file always pass `period >= 1` (they clamp with `.max(1)`); an empty range
/// is presumably rejected by `random_range`, so keep that precondition.
fn sample_phases<const K: usize>(rng: &mut StdRng, period: usize) -> [usize; K] {
    // `from_fn` fills the array front to back, so arm 0 receives the first draw.
    std::array::from_fn(|_| rng.random_range(0..period))
}
impl<const K: usize> ConstructableEnv for AdversarialBandit<K> {
    /// Builds the environment from [`AdversarialBanditConfig::default`].
    ///
    /// The `render` flag is accepted to satisfy the trait signature and then discarded; it has
    /// no effect on the environment that is returned.
    fn new(render: bool) -> Self {
        let _ = render;
        let defaults = AdversarialBanditConfig::default();
        Self::with_config(defaults)
    }
}
impl<const K: usize> Environment<1, 1, 1> for AdversarialBandit<K> {
    type StateType = KArmedBanditState;
    type ObservationType = KArmedBanditObservation;
    type ActionType = KArmedBanditAction<K>;
    type RewardType = ScalarReward;
    type SnapshotType = SnapshotBase<1, KArmedBanditObservation, ScalarReward>;

    /// Starts a fresh episode.
    ///
    /// The RNG is re-seeded from `config.seed` and the phases are redrawn from it, so every
    /// episode of the same environment uses the same phase offsets. The step counter and the
    /// done flag are cleared, and a running snapshot with a zero reward is returned.
    fn reset(&mut self) -> Result<Self::SnapshotType, EnvironmentError> {
        let mut fresh_rng = StdRng::seed_from_u64(self.config.seed);
        let fresh_phases = sample_phases::<K>(&mut fresh_rng, self.config.period.max(1));

        self.rng = fresh_rng;
        self.phases = fresh_phases;
        self.state = KArmedBanditState;
        self.steps = 0;
        self.done = false;

        let observation = self.state.observe();
        Ok(SnapshotBase::running(observation, ScalarReward::zero()))
    }

    /// Pulls the arm named by `action` and advances the episode by one step.
    ///
    /// The reward is evaluated at the step index *before* it is incremented. Once the
    /// incremented counter reaches `config.max_steps` the snapshot is marked terminated and
    /// the done flag is set.
    ///
    /// # Errors
    ///
    /// Returns `EnvironmentError::InvalidAction` when `action.is_valid()` is false; the step
    /// counter is left untouched in that case.
    fn step(&mut self, action: Self::ActionType) -> Result<Self::SnapshotType, EnvironmentError> {
        if !action.is_valid() {
            let message = format!("arm index {} out of range [0, {K})", action.arm());
            return Err(EnvironmentError::InvalidAction(message));
        }

        let payout = self.reward_at(action.arm(), self.steps);
        let reward = ScalarReward(payout);
        self.steps += 1;
        let observation = self.state.observe();

        // Guard clause for the common case: the step budget is not exhausted yet.
        if self.steps < self.config.max_steps {
            return Ok(SnapshotBase::running(observation, reward));
        }

        self.done = true;
        Ok(SnapshotBase::terminated(observation, reward))
    }
}
impl<const K: usize> crate::render::AsciiRenderable for AdversarialBandit<K> {
    /// One-line status: the arm `argmax` picks among the rewards every arm would pay at the
    /// current step (presumably the highest-paying one), that reward, and the
    /// period/amplitude/step settings.
    fn render_ascii(&self) -> String {
        let now = self.steps;
        let cfg = &self.config;

        // Reward each arm would pay if pulled right now, in arm order.
        let mut rewards: Vec<f32> = Vec::with_capacity(K);
        for arm in 0..K {
            rewards.push(self.reward_at(arm, now));
        }

        let (best_arm, best_reward) = super::k_armed::argmax(&rewards);
        format!(
            "Adversarial (K={K}) best_arm={best_arm} (r={best_reward:.2}) period={} amp={:.2} step={}/{}",
            cfg.period, cfg.amplitude, now, cfg.max_steps
        )
    }

    /// Wraps the ASCII line in a single-line styled frame using the shared bandit line styler.
    fn render_styled(&self) -> crate::render::StyledFrame {
        let text = self.render_ascii();
        let styled_line = super::k_armed::style_bandit_line(&text);
        crate::render::StyledFrame {
            lines: vec![styled_line],
        }
    }
}
// Unit tests for the adversarial bandit: snapshot semantics, reward bounds, determinism,
// periodicity, termination, config parsing and rendering.
#[cfg(test)]
mod tests {
use super::*;
use rlevo_core::action::DiscreteAction;
use rlevo_core::environment::Snapshot;
// Arm count used by the environment-level tests below.
const K: usize = 10;
// `reset` must hand back a snapshot that is not done and carries a zero reward.
#[test]
fn environment_reset_yields_running_snapshot_with_zero_reward() {
let mut env = AdversarialBandit::<K>::with_config(AdversarialBanditConfig::default());
let snap = <AdversarialBandit<K> as Environment<1, 1, 1>>::reset(&mut env).expect("reset");
assert!(!snap.is_done());
assert_eq!(f32::from(*snap.reward()), 0.0);
}
// With amplitude 1.0 every reward over 50 steps (cycling through the arms) stays in [0, 1].
#[test]
fn rewards_are_bounded_by_amplitude() {
let cfg = AdversarialBanditConfig {
max_steps: 200,
seed: 5,
period: 7,
amplitude: 1.0,
};
let mut env = AdversarialBandit::<K>::with_config(cfg);
<AdversarialBandit<K> as Environment<1, 1, 1>>::reset(&mut env).unwrap();
for step in 0..50 {
let action = KArmedBanditAction::<K>::from_index(step % K);
let snap =
<AdversarialBandit<K> as Environment<1, 1, 1>>::step(&mut env, action).unwrap();
let r = f32::from(*snap.reward());
assert!(
(0.0..=1.0).contains(&r),
"reward {r} outside [0, 1] at step {step}"
);
}
}
// Two environments built from the same config must agree on phases and on every reward
// when fed the same action sequence.
#[test]
fn same_seed_produces_identical_reward_sequence() {
let cfg = AdversarialBanditConfig {
max_steps: 64,
seed: 21,
period: 8,
amplitude: 1.0,
};
let mut a = AdversarialBandit::<K>::with_config(cfg.clone());
let mut b = AdversarialBandit::<K>::with_config(cfg);
<AdversarialBandit<K> as Environment<1, 1, 1>>::reset(&mut a).unwrap();
<AdversarialBandit<K> as Environment<1, 1, 1>>::reset(&mut b).unwrap();
assert_eq!(a.phases(), b.phases());
for step in 0..32 {
let action = KArmedBanditAction::<K>::from_index(step % K);
let snap_a =
<AdversarialBandit<K> as Environment<1, 1, 1>>::step(&mut a, action).unwrap();
let snap_b =
<AdversarialBandit<K> as Environment<1, 1, 1>>::step(&mut b, action).unwrap();
assert_eq!(f32::from(*snap_a.reward()), f32::from(*snap_b.reward()));
}
}
// With period 5, an arm's reward at steps 0, 5 and 10 must match (within float tolerance).
#[test]
fn reward_schedule_is_periodic() {
let cfg = AdversarialBanditConfig {
max_steps: 1000,
seed: 1,
period: 5,
amplitude: 1.0,
};
let env = AdversarialBandit::<K>::with_config(cfg);
let r0 = env.reward_at(2, 0);
let r5 = env.reward_at(2, 5);
let r10 = env.reward_at(2, 10);
assert!((r0 - r5).abs() < 1e-5);
assert!((r0 - r10).abs() < 1e-5);
}
// Steps a freshly built env (no `reset`) three times with `max_steps = 3`: the first
// snapshot is not done, the third is terminated.
#[test]
fn environment_step_terminates_at_max_steps() {
let mut env = AdversarialBandit::<K>::with_config(AdversarialBanditConfig {
max_steps: 3,
seed: 1,
period: 4,
amplitude: 1.0,
});
let action = KArmedBanditAction::<K>::from_index(0);
let s1 = <AdversarialBandit<K> as Environment<1, 1, 1>>::step(&mut env, action).unwrap();
assert!(!s1.is_done());
let _ = <AdversarialBandit<K> as Environment<1, 1, 1>>::step(&mut env, action).unwrap();
let s3 = <AdversarialBandit<K> as Environment<1, 1, 1>>::step(&mut env, action).unwrap();
assert!(s3.is_terminated());
}
// All four keys are picked up from the `key=value` form.
#[test]
fn fromstr_kv_with_period_and_amplitude() {
let c: AdversarialBanditConfig = "max_steps=200,seed=9,period=12,amplitude=0.5"
.parse()
.unwrap();
assert_eq!(c.max_steps, 200);
assert_eq!(c.seed, 9);
assert_eq!(c.period, 12);
assert!((c.amplitude - 0.5).abs() < 1e-6);
}
// An unrecognised key is rejected with the "Unknown ... key" message.
#[test]
fn fromstr_unknown_key_errors() {
let err: String = "wrong=1".parse::<AdversarialBanditConfig>().unwrap_err();
assert!(err.contains("Unknown AdversarialBanditConfig key"));
}
// The styled frame is a single line whose plain text equals the ASCII rendering.
#[test]
fn render_styled_matches_ascii() {
use crate::render::AsciiRenderable;
let env: AdversarialBandit<10> = AdversarialBandit::with_seed(7);
let plain = env.render_ascii();
let styled = env.render_styled();
assert_eq!(styled.lines.len(), 1);
assert_eq!(styled.plain_text(), plain);
}
// The span starting with "Adversarial" must carry the shared agent foreground colour and
// modifier from the palette.
#[test]
fn render_styled_uses_palette_consts() {
use crate::render::AsciiRenderable;
use crate::render::palette::{AGENT_FG, AGENT_MODIFIER};
let env: AdversarialBandit<10> = AdversarialBandit::with_seed(7);
let styled = env.render_styled();
let label = styled.lines[0]
.spans
.iter()
.find(|s| s.text.starts_with("Adversarial"))
.expect("Adversarial label span present");
assert_eq!(label.style.fg, Some(AGENT_FG));
assert!(label.style.modifier.contains(AGENT_MODIFIER));
}
// Every rendered line must fit in 80 characters.
#[test]
fn render_ascii_within_width_budget() {
use crate::render::AsciiRenderable;
let env: AdversarialBandit<10> = AdversarialBandit::with_seed(7);
for line in env.render_ascii().lines() {
assert!(
line.chars().count() <= 80,
"line exceeds 80 cols: {line:?} ({} chars)",
line.chars().count()
);
}
}
}