use burn::tensor::Tensor;
use burn::tensor::backend::Backend;
use rand::SeedableRng;
use rand::rngs::StdRng;
use rand_distr::{Distribution, Normal};
use rlevo_core::action::DiscreteAction;
use rlevo_core::base::{
Action, Observation, Reward, State, TensorConversionError, TensorConvertible,
};
use rlevo_core::environment::{ConstructableEnv, Environment, EnvironmentError, SnapshotBase};
use rlevo_core::reward::ScalarReward;
use serde::{Deserialize, Serialize};
use std::fmt::{Display, Formatter};
use std::str::FromStr;
/// State of a [`KArmedBandit`].
///
/// A bandit has no state that evolves between pulls, so this is a zero-sized marker.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub struct KArmedBanditState;
/// Observation emitted by [`KArmedBandit`].
///
/// Carries no data: every pull observes the same (empty) state.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq, Deserialize, Serialize)]
pub struct KArmedBanditObservation;

impl Observation<1> for KArmedBanditObservation {
    /// Rank-1 shape with a single element.
    fn shape() -> [usize; 1] {
        [1]
    }
}
impl State<1> for KArmedBanditState {
    type Observation = KArmedBanditObservation;

    /// Rank-1 shape with a single element.
    fn shape() -> [usize; 1] {
        [1]
    }

    /// Always one element, matching [`Self::shape`].
    fn numel(&self) -> usize {
        1
    }

    /// The marker state can never be invalid.
    fn is_valid(&self) -> bool {
        true
    }

    /// Returns the (data-less) observation.
    fn observe(&self) -> Self::Observation {
        KArmedBanditObservation
    }
}
impl Display for KArmedBanditState {
    /// Writes the fixed text `KArmedBanditState`.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.write_str("KArmedBanditState")
    }
}
impl<B: Backend> TensorConvertible<1, B> for KArmedBanditState {
    /// Encodes the marker state as the single-element tensor `[0.0]`.
    fn to_tensor(&self, device: &<B as burn::tensor::backend::BackendTypes>::Device) -> Tensor<B, 1> {
        Tensor::from_floats([0.0_f32], device)
    }

    /// Decodes the marker state. The tensor's values are ignored; only its shape is checked.
    ///
    /// # Errors
    /// Returns `TensorConversionError` when the tensor's shape is not `[1]`.
    fn from_tensor(tensor: Tensor<B, 1>) -> Result<Self, TensorConversionError> {
        match tensor.dims() {
            [1] => Ok(Self),
            other => Err(TensorConversionError {
                message: format!("expected shape [1], got {other:?}"),
            }),
        }
    }
}
/// Choice of one of the `K` arms, stored as a zero-based index.
///
/// Build with [`KArmedBanditAction::new`] (returns an error for indices `>= K`) or
/// `DiscreteAction::from_index` (panics for indices `>= K`). The derived `Default`
/// selects arm 0, which is only a valid arm when `K >= 1`.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub struct KArmedBanditAction<const K: usize> {
    /// Zero-based index of the chosen arm.
    selected_arm: usize,
}
impl<const K: usize> KArmedBanditAction<K> {
    /// Creates an action selecting `arm`.
    ///
    /// # Errors
    /// Returns `EnvironmentError::InvalidAction` when `arm >= K`.
    pub fn new(arm: usize) -> Result<Self, EnvironmentError> {
        if arm >= K {
            return Err(EnvironmentError::InvalidAction(format!(
                "arm index {arm} out of range [0, {K})"
            )));
        }
        Ok(Self { selected_arm: arm })
    }

    /// Zero-based index of the selected arm.
    #[must_use]
    pub fn arm(&self) -> usize {
        self.selected_arm
    }
}
impl<const K: usize> Display for KArmedBanditAction<K> {
    /// Writes e.g. `KArmedBanditAction<10>(arm=3)`.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let arm = self.selected_arm;
        write!(f, "KArmedBanditAction<{K}>(arm={arm})")
    }
}
impl<const K: usize> Action<1> for KArmedBanditAction<K> {
    /// One slot per arm.
    fn shape() -> [usize; 1] {
        [K]
    }

    /// `true` when the stored arm index lies in `0..K`.
    fn is_valid(&self) -> bool {
        (0..K).contains(&self.selected_arm)
    }
}
impl<const K: usize> DiscreteAction<1> for KArmedBanditAction<K> {
    const ACTION_COUNT: usize = K;

    /// Builds the action for arm `index`.
    ///
    /// # Panics
    /// When `index >= K`.
    fn from_index(index: usize) -> Self {
        assert!(index < K, "KArmedBanditAction index {index} out of range [0, {K})");
        Self { selected_arm: index }
    }

    /// The arm index; the inverse of [`DiscreteAction::from_index`].
    fn to_index(&self) -> usize {
        self.arm()
    }
}
impl<const K: usize, B: Backend> TensorConvertible<1, B> for KArmedBanditAction<K> {
    /// Encodes the action as a length-`K` one-hot vector: `1.0` at the selected arm,
    /// `0.0` everywhere else.
    fn to_tensor(&self, device: &<B as burn::tensor::backend::BackendTypes>::Device) -> Tensor<B, 1> {
        let mut one_hot = [0.0_f32; K];
        one_hot[self.selected_arm] = 1.0;
        Tensor::from_floats(one_hot, device)
    }

    /// Decodes a length-`K` vector into the action at its argmax.
    ///
    /// Any score vector is accepted, not only an exact one-hot. Ties resolve to the
    /// lowest index. NaN entries are never selected, so a vector with no entry above
    /// `-inf` (e.g. all NaN) decodes to arm 0.
    ///
    /// # Errors
    /// Returns `TensorConversionError` when:
    /// - the shape is not `[K]`;
    /// - the data cannot be read back as `f32`;
    /// - the argmax is not a valid arm (only possible when `K == 0`).
    fn from_tensor(tensor: Tensor<B, 1>) -> Result<Self, TensorConversionError> {
        let dims = tensor.dims();
        if dims.as_slice() != [K] {
            return Err(TensorConversionError {
                message: format!("expected shape [{K}], got {dims:?}"),
            });
        }
        // `TensorData::to_vec::<f32>` requires the stored dtype to be exactly f32, but a
        // generic backend's float element may be f16/bf16/f64. Convert first so decoding
        // works for every backend instead of failing with a dtype mismatch.
        let values: Vec<f32> = tensor
            .into_data()
            .convert::<f32>()
            .to_vec()
            .map_err(|e| TensorConversionError {
                message: format!("failed to extract tensor data: {e:?}"),
            })?;
        // Strict `>` keeps the first maximum on ties and never selects NaN.
        let (best_arm, _) = values.iter().enumerate().fold(
            (0_usize, f32::NEG_INFINITY),
            |(i_best, v_best), (i, &v)| {
                if v > v_best { (i, v) } else { (i_best, v_best) }
            },
        );
        KArmedBanditAction::<K>::new(best_arm).map_err(|e| TensorConversionError {
            message: format!("invalid one-hot argmax: {e}"),
        })
    }
}
/// Configuration for [`KArmedBandit`].
///
/// It can also be parsed from a string via `FromStr`. The string is either a bare
/// number (sets `max_steps`) or comma-separated `key=value` pairs, e.g.
/// `"max_steps=1000,seed=7"`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct KArmedBanditConfig {
    /// Number of pulls after which an episode is reported as terminated.
    pub max_steps: usize,
    /// Seed for the RNG that draws the arm means and the rewards.
    pub seed: u64,
}
impl Default for KArmedBanditConfig {
    /// 500-pull episodes with seed 42.
    fn default() -> Self {
        let max_steps = 500;
        let seed = 42;
        Self { max_steps, seed }
    }
}
impl FromStr for KArmedBanditConfig {
    type Err = String;

    /// Parses either a bare integer (taken as `max_steps`) or comma-separated
    /// `key=value` pairs with keys `max_steps` and `seed`.
    ///
    /// Parsing rules:
    /// - Keys that are not mentioned keep their `Default` value.
    /// - Surrounding whitespace is ignored, as are empty segments (e.g. a trailing comma).
    /// - A repeated key overrides its earlier value.
    ///
    /// # Errors
    /// Returns a descriptive message when:
    /// - a segment has no `=`;
    /// - a key is unknown;
    /// - a value does not parse;
    /// - no key/value pair is present at all.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let input = s.trim();
        if let Ok(max_steps) = input.parse::<usize>() {
            return Ok(Self {
                max_steps,
                ..Self::default()
            });
        }

        // Shared by the "no '='" and "nothing parsed" error paths; built only on demand.
        let format_error = || {
            format!(
                "Invalid KArmedBanditConfig format. Expected either a number or 'key=value' pairs, got: {s}"
            )
        };

        let mut cfg = Self::default();
        let mut parsed_any = false;
        for pair in input.split(',').map(str::trim).filter(|pair| !pair.is_empty()) {
            let (key, value) = pair.split_once('=').ok_or_else(format_error)?;
            let value = value.trim();
            match key.trim() {
                "max_steps" => {
                    cfg.max_steps = value
                        .parse::<usize>()
                        .map_err(|e| format!("Failed to parse max_steps value: {e}"))?;
                }
                "seed" => {
                    cfg.seed = value
                        .parse::<u64>()
                        .map_err(|e| format!("Failed to parse seed value: {e}"))?;
                }
                other => {
                    return Err(format!(
                        "Unknown KArmedBanditConfig key {other:?} (expected max_steps or seed)"
                    ));
                }
            }
            parsed_any = true;
        }

        if parsed_any { Ok(cfg) } else { Err(format_error()) }
    }
}
/// Stationary K-armed bandit with Gaussian rewards.
///
/// Each arm `a` has a true mean `arm_means[a]` drawn from `N(0, 1)`. Pulling that arm
/// yields a reward drawn from `N(arm_means[a], 1)`. An episode is reported as terminated
/// once `config.max_steps` pulls have been made.
#[derive(Debug)]
pub struct KArmedBandit<const K: usize> {
    /// Marker state; the bandit has nothing that evolves between pulls.
    state: KArmedBanditState,
    /// Pulls made since construction or the last reset.
    steps: usize,
    /// Set once `steps` reaches `config.max_steps`.
    done: bool,
    /// Episode length and RNG seed.
    config: KArmedBanditConfig,
    /// RNG for arm means and rewards; re-seeded from `config.seed` on reset.
    rng: StdRng,
    /// True expected reward of each arm.
    arm_means: [f32; K],
}
impl<const K: usize> Display for KArmedBandit<K> {
    /// Writes e.g. `KArmedBandit<10>(step=3/500, done=false)`.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let Self {
            steps,
            done,
            config,
            ..
        } = self;
        write!(
            f,
            "KArmedBandit<{K}>(step={steps}/{}, done={done})",
            config.max_steps
        )
    }
}
impl<const K: usize> KArmedBandit<K> {
    /// Creates a bandit with the default `max_steps` and the given RNG `seed`.
    pub fn with_seed(seed: u64) -> Self {
        Self::with_config(KArmedBanditConfig {
            seed,
            ..Default::default()
        })
    }

    /// Creates a bandit from `config`.
    ///
    /// The RNG is seeded from `config.seed` and the `K` arm means are drawn from it
    /// immediately. Two bandits built from equal configs therefore start identical.
    pub fn with_config(config: KArmedBanditConfig) -> Self {
        let mut rng = StdRng::seed_from_u64(config.seed);
        let arm_means = sample_arm_means::<K>(&mut rng);
        Self {
            config,
            arm_means,
            rng,
            state: KArmedBanditState,
            steps: 0,
            done: false,
        }
    }

    /// Returns the bandit to its freshly constructed state.
    ///
    /// The RNG is re-seeded from the stored config, so the arm means are redrawn to the
    /// same values as at construction. The step counter and the done flag are cleared.
    pub fn reset(&mut self) {
        *self = Self::with_config(self.config.clone());
    }

    /// Pulls `arm` and returns the sampled reward.
    ///
    /// Advances the step counter and marks the bandit done once `max_steps` pulls have
    /// been made. Pulling after that point is still permitted.
    ///
    /// # Panics
    /// When `arm >= K`.
    pub fn pull(&mut self, arm: usize) -> f32 {
        let arm = KArmedBanditAction::<K>::new(arm)
            .expect("arm index in range")
            .arm();
        let reward = self.sample_reward(arm);
        self.steps += 1;
        self.done |= self.steps >= self.config.max_steps;
        reward
    }

    /// Whether `max_steps` pulls have been made since construction or the last reset.
    #[must_use]
    pub fn is_done(&self) -> bool {
        self.done
    }

    /// True expected reward of each arm.
    #[must_use]
    pub fn arm_means(&self) -> &[f32; K] {
        &self.arm_means
    }

    /// Draws one reward for `arm` from `N(arm_means[arm], 1)` using the bandit's RNG.
    fn sample_reward(&mut self, arm: usize) -> f32 {
        let reward_dist =
            Normal::new(self.arm_means[arm], 1.0).expect("N(mean, 1) is always valid");
        reward_dist.sample(&mut self.rng)
    }
}
/// Draws `K` independent arm means from the standard normal `N(0, 1)`.
///
/// Draws are consumed from `rng` in arm-index order, so a given RNG state always
/// maps to the same array.
pub(super) fn sample_arm_means<const K: usize>(rng: &mut StdRng) -> [f32; K] {
    let standard_normal = Normal::new(0.0_f32, 1.0).expect("N(0, 1) is always valid");
    // `array::from_fn` fills elements walking forward from index 0, preserving draw order.
    std::array::from_fn(|_| standard_normal.sample(rng))
}
impl<const K: usize> ConstructableEnv for KArmedBandit<K> {
    /// Builds a bandit from `KArmedBanditConfig::default()`.
    ///
    /// The `render` flag is not used by this environment.
    fn new(_render: bool) -> Self {
        Self::with_config(KArmedBanditConfig::default())
    }
}
impl<const K: usize> Environment<1, 1, 1> for KArmedBandit<K> {
    type StateType = KArmedBanditState;
    type ObservationType = KArmedBanditObservation;
    type ActionType = KArmedBanditAction<K>;
    type RewardType = ScalarReward;
    type SnapshotType = SnapshotBase<1, KArmedBanditObservation, ScalarReward>;

    /// Re-seeds the bandit via the inherent `KArmedBandit::reset`.
    ///
    /// Returns a running snapshot carrying a zero reward.
    fn reset(&mut self) -> Result<Self::SnapshotType, EnvironmentError> {
        KArmedBandit::<K>::reset(self);
        let observation = self.state.observe();
        Ok(SnapshotBase::running(observation, ScalarReward::zero()))
    }

    /// Pulls the action's arm and returns the sampled reward.
    ///
    /// The snapshot is terminated once `max_steps` pulls have been made, and running
    /// otherwise. Stepping after termination is not rejected.
    ///
    /// # Errors
    /// Returns `EnvironmentError::InvalidAction` when the action's arm index is `>= K`.
    fn step(&mut self, action: Self::ActionType) -> Result<Self::SnapshotType, EnvironmentError> {
        let arm = action.arm();
        if !action.is_valid() {
            return Err(EnvironmentError::InvalidAction(format!(
                "arm index {arm} out of range [0, {K})"
            )));
        }

        let reward = ScalarReward(self.sample_reward(arm));
        self.steps += 1;
        let observation = self.state.observe();

        if self.steps < self.config.max_steps {
            return Ok(SnapshotBase::running(observation, reward));
        }
        self.done = true;
        Ok(SnapshotBase::terminated(observation, reward))
    }
}
impl<const K: usize> crate::render::AsciiRenderable for KArmedBandit<K> {
    /// One-line summary.
    ///
    /// Shows the arm with the highest true mean (`q*`, two decimals) and the step counter.
    fn render_ascii(&self) -> String {
        let (best_arm, best_mean) = argmax(&self.arm_means);
        let (step, max_steps) = (self.steps, self.config.max_steps);
        format!("K-armed (K={K}) best_arm={best_arm} (q*={best_mean:.2}) step={step}/{max_steps}")
    }

    /// Same text as `render_ascii`, as a single line styled by `style_bandit_line`.
    fn render_styled(&self) -> crate::render::StyledFrame {
        let lines = vec![style_bandit_line(&self.render_ascii())];
        crate::render::StyledFrame { lines }
    }
}
/// Returns `(index, value)` of the largest element of `values`.
///
/// Returns `(0, 0.0)` when `values` is empty. Ties resolve to the *last* maximal index,
/// following `Iterator::max_by`. Pairs with an undefined ordering (NaN) compare as equal.
pub(super) fn argmax(values: &[f32]) -> (usize, f32) {
    let best = values
        .iter()
        .copied()
        .enumerate()
        .max_by(|(_, lhs), (_, rhs)| lhs.partial_cmp(rhs).unwrap_or(std::cmp::Ordering::Equal));
    best.unwrap_or((0, 0.0))
}
/// Styles a bandit summary line.
///
/// The text before the first space is the label and gets the agent palette style. The
/// remainder, including that space, stays unstyled. A line with no space is returned
/// entirely unstyled.
pub(super) fn style_bandit_line(line: &str) -> crate::render::StyledLine {
    use crate::render::palette::{AGENT_FG, AGENT_MODIFIER};
    use crate::render::{SpanStyle, StyledLine, StyledSpan};

    let Some(boundary) = line.find(' ') else {
        return StyledLine::unstyled(line);
    };
    let (label, rest) = line.split_at(boundary);
    let label_style = SpanStyle::default()
        .fg(AGENT_FG)
        .with_modifier(AGENT_MODIFIER);
    StyledLine::from_spans(vec![
        StyledSpan::new(label, label_style),
        StyledSpan::raw(rest.to_string()),
    ])
}
#[cfg(test)]
mod tests {
    use super::*;
    use rlevo_core::environment::Snapshot;

    type TestBackend = burn::backend::Flex;

    const K: usize = 10;

    #[test]
    fn state_round_trips_through_tensor() {
        let device = Default::default();
        let state = KArmedBanditState;
        let tensor =
            <KArmedBanditState as TensorConvertible<1, TestBackend>>::to_tensor(&state, &device);
        let back = <KArmedBanditState as TensorConvertible<1, TestBackend>>::from_tensor(tensor)
            .expect("round-trip should succeed for valid shape");
        assert_eq!(back, state);
    }

    #[test]
    fn state_from_tensor_rejects_wrong_shape() {
        use burn::tensor::{Tensor, TensorData as TD};
        let device = Default::default();
        let data = TD::new(vec![0.0_f32, 0.0_f32], [2]);
        let tensor = Tensor::<TestBackend, 1>::from_data(data, &device);
        let err = <KArmedBanditState as TensorConvertible<1, TestBackend>>::from_tensor(tensor)
            .expect_err("shape [2] should be rejected");
        assert!(err.message.contains("expected shape [1]"));
    }

    #[test]
    fn action_from_index_round_trips() {
        for i in 0..K {
            let action = KArmedBanditAction::<K>::from_index(i);
            assert_eq!(action.to_index(), i);
            assert!(action.is_valid());
        }
    }

    #[test]
    fn action_new_rejects_out_of_range() {
        let err = KArmedBanditAction::<K>::new(K).expect_err("expected InvalidAction");
        // `matches!` only yields a bool; it must be asserted or the check is a no-op.
        assert!(matches!(err, EnvironmentError::InvalidAction(_)));
    }

    #[test]
    #[should_panic(expected = "out of range")]
    fn action_from_index_panics_out_of_range() {
        let _ = KArmedBanditAction::<K>::from_index(K);
    }

    #[test]
    fn action_enumerate_covers_all_arms() {
        let all = KArmedBanditAction::<K>::enumerate();
        assert_eq!(all.len(), K);
        for (i, a) in all.iter().enumerate() {
            assert_eq!(a.to_index(), i);
        }
    }

    #[test]
    fn action_one_hot_round_trips_through_tensor() {
        let device = Default::default();
        for i in 0..K {
            let a = KArmedBanditAction::<K>::from_index(i);
            let t = <KArmedBanditAction<K> as TensorConvertible<1, TestBackend>>::to_tensor(
                &a, &device,
            );
            let back = <KArmedBanditAction<K> as TensorConvertible<1, TestBackend>>::from_tensor(t)
                .expect("valid one-hot");
            assert_eq!(back, a);
        }
    }

    #[test]
    fn action_from_tensor_rejects_wrong_shape() {
        use burn::tensor::{Tensor, TensorData as TD};
        let device = Default::default();
        let data = TD::new(vec![0.0_f32, 1.0_f32], [2]);
        let tensor = Tensor::<TestBackend, 1>::from_data(data, &device);
        let err = <KArmedBanditAction<K> as TensorConvertible<1, TestBackend>>::from_tensor(tensor)
            .expect_err("shape [2] should be rejected");
        assert!(err.message.contains("expected shape"));
    }

    #[test]
    fn action_from_tensor_decodes_argmax_with_first_index_tie_break() {
        use burn::tensor::{Tensor, TensorData as TD};
        let device = Default::default();
        let data = TD::new(vec![0.1_f32, 0.7, 0.7, -1.0], [4]);
        let tensor = Tensor::<TestBackend, 1>::from_data(data, &device);
        let action =
            <KArmedBanditAction<4> as TensorConvertible<1, TestBackend>>::from_tensor(tensor)
                .expect("scores with the right shape decode");
        assert_eq!(action.arm(), 1);
    }

    #[test]
    fn environment_new_constructs() {
        let env = <KArmedBandit<K> as ConstructableEnv>::new(false);
        assert_eq!(env.steps, 0);
        assert!(!env.done);
    }

    #[test]
    fn environment_reset_yields_running_snapshot_with_zero_reward() {
        let mut env = KArmedBandit::<K>::with_config(KArmedBanditConfig::default());
        let snap = <KArmedBandit<K> as Environment<1, 1, 1>>::reset(&mut env).expect("reset");
        assert!(!snap.is_done());
        assert_eq!(f32::from(*snap.reward()), 0.0);
    }

    #[test]
    fn environment_step_terminates_at_max_steps() {
        let mut env = KArmedBandit::<K>::with_config(KArmedBanditConfig {
            max_steps: 3,
            seed: 1,
        });
        let action = KArmedBanditAction::<K>::from_index(0);
        let s1 = <KArmedBandit<K> as Environment<1, 1, 1>>::step(&mut env, action).unwrap();
        assert!(!s1.is_done());
        let s2 = <KArmedBandit<K> as Environment<1, 1, 1>>::step(&mut env, action).unwrap();
        assert!(!s2.is_done());
        let s3 = <KArmedBandit<K> as Environment<1, 1, 1>>::step(&mut env, action).unwrap();
        assert!(s3.is_terminated());
    }

    #[test]
    fn same_seed_produces_identical_trajectories() {
        let cfg = KArmedBanditConfig {
            max_steps: 64,
            seed: 7,
        };
        let mut a = KArmedBandit::<K>::with_config(cfg.clone());
        let mut b = KArmedBandit::<K>::with_config(cfg);
        <KArmedBandit<K> as Environment<1, 1, 1>>::reset(&mut a).unwrap();
        <KArmedBandit<K> as Environment<1, 1, 1>>::reset(&mut b).unwrap();
        assert_eq!(a.arm_means(), b.arm_means());
        for step in 0..64 {
            let action = KArmedBanditAction::<K>::from_index(step % K);
            let snap_a = <KArmedBandit<K> as Environment<1, 1, 1>>::step(&mut a, action).unwrap();
            let snap_b = <KArmedBandit<K> as Environment<1, 1, 1>>::step(&mut b, action).unwrap();
            assert_eq!(f32::from(*snap_a.reward()), f32::from(*snap_b.reward()));
            assert_eq!(snap_a.status(), snap_b.status());
        }
    }

    #[test]
    fn reset_redraws_arm_means_from_config_seed() {
        let cfg = KArmedBanditConfig {
            max_steps: 10,
            seed: 99,
        };
        let mut env = KArmedBandit::<K>::with_config(cfg);
        let means_before = *env.arm_means();
        for _ in 0..5 {
            let _ = env.pull(0);
        }
        <KArmedBandit<K> as Environment<1, 1, 1>>::reset(&mut env).unwrap();
        let means_after = *env.arm_means();
        assert_eq!(means_before, means_after);
        assert_eq!(env.steps, 0);
    }

    #[test]
    fn alias_ten_armed_bandit_resolves_to_k_equals_10() {
        use crate::classic::{TenArmedBandit, TenArmedBanditAction};
        let mut env = TenArmedBandit::with_config(KArmedBanditConfig::default());
        <TenArmedBandit as Environment<1, 1, 1>>::reset(&mut env).unwrap();
        let action = TenArmedBanditAction::from_index(0);
        let snap = <TenArmedBandit as Environment<1, 1, 1>>::step(&mut env, action).unwrap();
        assert!(!snap.is_done());
        assert_eq!(env.arm_means().len(), 10);
    }

    #[test]
    fn k_other_than_10_constructs_and_steps() {
        let mut env = KArmedBandit::<4>::with_config(KArmedBanditConfig::default());
        <KArmedBandit<4> as Environment<1, 1, 1>>::reset(&mut env).unwrap();
        assert_eq!(env.arm_means().len(), 4);
        let action = KArmedBanditAction::<4>::from_index(3);
        let _ = <KArmedBandit<4> as Environment<1, 1, 1>>::step(&mut env, action).unwrap();
    }

    #[test]
    fn fromstr_simple_number_sets_max_steps() {
        let c: KArmedBanditConfig = "500".parse().unwrap();
        assert_eq!(c.max_steps, 500);
        assert_eq!(c.seed, 42);
    }

    #[test]
    fn fromstr_with_whitespace() {
        let c: KArmedBanditConfig = " 750 ".parse().unwrap();
        assert_eq!(c.max_steps, 750);
    }

    #[test]
    fn fromstr_key_value_max_steps() {
        let c: KArmedBanditConfig = "max_steps=1000".parse().unwrap();
        assert_eq!(c.max_steps, 1000);
    }

    #[test]
    fn fromstr_key_value_seed() {
        let c: KArmedBanditConfig = "seed=17".parse().unwrap();
        assert_eq!(c.seed, 17);
        assert_eq!(c.max_steps, 500);
    }

    #[test]
    fn fromstr_two_keys() {
        let c: KArmedBanditConfig = "max_steps=50,seed=3".parse().unwrap();
        assert_eq!(c.max_steps, 50);
        assert_eq!(c.seed, 3);
    }

    #[test]
    fn fromstr_key_value_with_whitespace() {
        let c: KArmedBanditConfig = "max_steps = 2000".parse().unwrap();
        assert_eq!(c.max_steps, 2000);
    }

    #[test]
    fn fromstr_zero_steps() {
        let c: KArmedBanditConfig = "0".parse().unwrap();
        assert_eq!(c.max_steps, 0);
    }

    #[test]
    fn fromstr_large_number() {
        let c: KArmedBanditConfig = "999999999".parse().unwrap();
        assert_eq!(c.max_steps, 999_999_999);
    }

    #[test]
    fn fromstr_invalid_format_errors() {
        let err: String = "invalid".parse::<KArmedBanditConfig>().unwrap_err();
        assert!(err.contains("Invalid KArmedBanditConfig format"));
    }

    #[test]
    fn fromstr_non_numeric_errors() {
        let err = "not_a_number".parse::<KArmedBanditConfig>();
        assert!(err.is_err());
    }

    #[test]
    fn fromstr_invalid_kv_number_errors() {
        let err: String = "max_steps=invalid"
            .parse::<KArmedBanditConfig>()
            .unwrap_err();
        assert!(err.contains("Failed to parse max_steps value"));
    }

    #[test]
    fn fromstr_unknown_key_errors() {
        let err: String = "wrong_key=500".parse::<KArmedBanditConfig>().unwrap_err();
        assert!(err.contains("Unknown KArmedBanditConfig key"));
    }

    #[test]
    fn config_default_has_expected_values() {
        let c = KArmedBanditConfig::default();
        assert_eq!(c.max_steps, 500);
        assert_eq!(c.seed, 42);
    }

    #[test]
    fn render_styled_matches_ascii() {
        use crate::render::AsciiRenderable;
        let env: KArmedBandit<K> = KArmedBandit::with_seed(7);
        let plain = env.render_ascii();
        let styled = env.render_styled();
        assert_eq!(styled.lines.len(), 1);
        assert_eq!(styled.plain_text(), plain);
    }

    #[test]
    fn render_styled_uses_palette_consts() {
        use crate::render::AsciiRenderable;
        use crate::render::palette::{AGENT_FG, AGENT_MODIFIER};
        let env: KArmedBandit<K> = KArmedBandit::with_seed(7);
        let styled = env.render_styled();
        let label = styled.lines[0]
            .spans
            .iter()
            .find(|s| s.text.starts_with("K-armed"))
            .expect("K-armed label span present");
        assert_eq!(label.style.fg, Some(AGENT_FG));
        assert!(label.style.modifier.contains(AGENT_MODIFIER));
    }

    #[test]
    fn render_ascii_within_width_budget() {
        use crate::render::AsciiRenderable;
        let env: KArmedBandit<K> = KArmedBandit::with_seed(7);
        for line in env.render_ascii().lines() {
            assert!(
                line.chars().count() <= 80,
                "line exceeds 80 cols: {line:?} ({} chars)",
                line.chars().count()
            );
        }
    }
}