use rand::RngExt;
use rand::SeedableRng;
use rand::rngs::StdRng;
use rlevo_core::action::DiscreteAction;
use rlevo_core::base::{Action, Observation, State};
use rlevo_core::environment::{ConstructableEnv, Environment, EnvironmentError, SnapshotBase};
use rlevo_core::reward::ScalarReward;
use serde::{Deserialize, Serialize};
/// Draws a single card value using the supplied generator.
///
/// A rank is sampled from `1..=13` on every call, independently of earlier
/// draws. Ranks above 10 (the face cards) count as 10, so the result is always
/// in `1..=10` and 10 is four times as likely as any other value. An ace is
/// reported as `1`.
fn draw_card(rng: &mut StdRng) -> u8 {
    let rank: u8 = rng.random_range(1u8..=13);
    if rank > 10 {
        10
    } else {
        rank
    }
}
/// Scores a hand, returning `(total, usable_ace)`.
///
/// Cards are summed at face value with every ace counted as 1. When the hand
/// holds at least one ace and counting a single ace as 11 keeps the total at
/// 21 or below, the promoted total is returned together with `true`; otherwise
/// the plain sum is returned together with `false`.
fn hand_value(hand: &[u8]) -> (u8, bool) {
    let hard_total = hand.iter().copied().sum::<u8>();
    // Promoting an ace adds 10, so it only fits while the hard total is <= 11.
    let ace_fits = hard_total <= 11 && hand.iter().any(|&card| card == 1);
    if ace_fits {
        (hard_total + 10, true)
    } else {
        (hard_total, false)
    }
}
/// Reports whether `hand` is a natural: exactly two cards scoring 21.
fn is_natural(hand: &[u8]) -> bool {
    match hand {
        [_, _] => {
            let (total, _) = hand_value(hand);
            total == 21
        }
        _ => false,
    }
}
/// Rule set that decides how naturals (two-card 21s) affect the reward.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BlackjackVariant {
/// Reward is the plain comparison of player and dealer totals. When
/// `natural_pays_bonus` is `true`, a winning player natural against a dealer
/// without a natural pays `1.5` instead of `1.0`.
Standard { natural_pays_bonus: bool },
/// After every step, a natural held by exactly one side overrides the result:
/// a player natural forces reward `1.0`, a dealer natural forces `-1.0`, and
/// the episode is marked done in both cases.
SuttonBarto,
}
impl Default for BlackjackVariant {
fn default() -> Self {
BlackjackVariant::Standard {
natural_pays_bonus: false,
}
}
}
/// Construction parameters for a [`Blackjack`] environment.
#[derive(Debug, Clone)]
pub struct BlackjackConfig {
/// Rule variant applied when rewards are computed.
pub variant: BlackjackVariant,
/// Seed for the environment's `StdRng`, from which every card is drawn.
pub seed: u64,
}
impl Default for BlackjackConfig {
    /// Default rule variant (standard rules, no natural bonus) with seed `0`.
    fn default() -> Self {
        let variant = BlackjackVariant::default();
        Self { variant, seed: 0 }
    }
}
impl BlackjackConfig {
pub fn builder() -> BlackjackConfigBuilder {
BlackjackConfigBuilder::default()
}
}
/// Step-by-step builder for [`BlackjackConfig`].
///
/// A fresh builder holds the default rule variant and seed `0`. The type
/// derives `Debug` and `Clone` so that, like the other public types in this
/// module, it can be logged and reused as a template.
#[derive(Debug, Clone, Default)]
pub struct BlackjackConfigBuilder {
    // Rule variant that `build` copies into the finished config.
    variant: BlackjackVariant,
    // RNG seed that `build` copies into the finished config.
    seed: u64,
}
impl BlackjackConfigBuilder {
    /// Replaces the rule variant and hands the builder back for chaining.
    pub fn variant(self, v: BlackjackVariant) -> Self {
        Self { variant: v, ..self }
    }

    /// Replaces the RNG seed and hands the builder back for chaining.
    pub fn seed(self, s: u64) -> Self {
        Self { seed: s, ..self }
    }

    /// Consumes the builder and produces the finished [`BlackjackConfig`].
    pub fn build(self) -> BlackjackConfig {
        let Self { variant, seed } = self;
        BlackjackConfig { variant, seed }
    }
}
/// Full state of the hand currently being played.
#[derive(Debug, Clone)]
pub struct BlackjackState {
/// Player total as computed by `hand_value`: an ace counts as 11 whenever
/// that does not push the total past 21.
pub player_sum: u8,
/// Value of the dealer's first card.
pub dealer_showing: u8,
/// Whether `player_sum` currently counts an ace as 11.
pub usable_ace: bool,
// Card values held by the player, in draw order.
player_hand: Vec<u8>,
// Card values held by the dealer; the first entry is the card reported as
// `dealer_showing`.
dealer_hand: Vec<u8>,
}
impl State<1> for BlackjackState {
    type Observation = BlackjackObservation;

    /// One axis of length three, matching the three fields of the observation.
    fn shape() -> [usize; 1] {
        [3]
    }

    /// Copies the public summary fields into an observation; the usable-ace
    /// flag is encoded as `0` or `1`. The hands themselves are not exposed.
    fn observe(&self) -> BlackjackObservation {
        let Self {
            player_sum,
            dealer_showing,
            usable_ace,
            ..
        } = self;
        BlackjackObservation {
            player_sum: *player_sum,
            dealer_showing: *dealer_showing,
            usable_ace: u8::from(*usable_ace),
        }
    }

    /// Checks that the player total lies in `4..=32` and the dealer's visible
    /// card in `1..=10`.
    fn is_valid(&self) -> bool {
        matches!(self.player_sum, 4..=32) && matches!(self.dealer_showing, 1..=10)
    }
}
/// What the agent sees of a [`BlackjackState`]: the summary values only, never
/// the individual cards.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BlackjackObservation {
/// Player total, copied from the state.
pub player_sum: u8,
/// Value of the dealer's first card, copied from the state.
pub dealer_showing: u8,
/// Usable-ace flag encoded as `1` (true) or `0` (false).
pub usable_ace: u8,
}
impl Observation<1> for BlackjackObservation {
    /// One axis holding the three observation fields.
    fn shape() -> [usize; 1] {
        const FIELD_COUNT: usize = 3;
        [FIELD_COUNT]
    }
}
/// The two moves available to the player; the discriminant doubles as the
/// discrete action index.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BlackjackAction {
/// Stop drawing: the dealer plays out their hand and the episode ends.
Stick = 0,
/// Draw one more card; the episode ends if the player total exceeds 21.
Hit = 1,
}
impl Action<1> for BlackjackAction {
    /// A single scalar action component.
    fn shape() -> [usize; 1] {
        const COMPONENTS: usize = 1;
        [COMPONENTS]
    }

    /// Every variant is a legal move, so this is always `true`.
    fn is_valid(&self) -> bool {
        match self {
            Self::Stick | Self::Hit => true,
        }
    }
}
impl DiscreteAction<1> for BlackjackAction {
    const ACTION_COUNT: usize = 2;

    /// Maps `0` to [`BlackjackAction::Stick`] and `1` to [`BlackjackAction::Hit`].
    ///
    /// # Panics
    ///
    /// Panics when `index` is 2 or greater.
    fn from_index(index: usize) -> Self {
        const BY_INDEX: [BlackjackAction; 2] = [BlackjackAction::Stick, BlackjackAction::Hit];
        match BY_INDEX.get(index) {
            Some(&action) => action,
            None => panic!("BlackjackAction index {index} out of range [0, 2)"),
        }
    }

    /// Returns the index of this action: `0` for stick, `1` for hit.
    fn to_index(&self) -> usize {
        match self {
            Self::Stick => 0,
            Self::Hit => 1,
        }
    }
}
/// Blackjack environment: one player against a dealer, with every card drawn
/// from a seeded random number generator.
#[derive(Debug)]
pub struct Blackjack {
// Hands and derived totals of the current episode.
state: BlackjackState,
// Rule variant and seed this environment was built with.
config: BlackjackConfig,
// Generator seeded from `config.seed`; the source of every card draw.
rng: StdRng,
}
impl Blackjack {
    /// Creates an environment whose card generator is seeded from `config.seed`.
    ///
    /// No cards are dealt here: the state starts with empty hands, a player
    /// total of 0 and a dealer card of 1 until `reset` deals the first hand.
    pub fn with_config(config: BlackjackConfig) -> Self {
        let undealt = BlackjackState {
            player_sum: 0,
            dealer_showing: 1,
            usable_ace: false,
            player_hand: Vec::new(),
            dealer_hand: Vec::new(),
        };
        Self {
            // The seed is read before `config` is moved into the struct below.
            rng: StdRng::seed_from_u64(config.seed),
            state: undealt,
            config,
        }
    }

    /// Deals a fresh hand: two cards to the player, then two to the dealer,
    /// and replaces the state with the resulting totals.
    fn deal_initial(&mut self) {
        let rng = &mut self.rng;
        // Draw order matters for reproducibility: both player cards come first.
        let player_hand = vec![draw_card(rng), draw_card(rng)];
        let dealer_hand = vec![draw_card(rng), draw_card(rng)];

        let (player_sum, usable_ace) = hand_value(&player_hand);
        let dealer_showing = dealer_hand[0];
        self.state = BlackjackState {
            player_sum,
            dealer_showing,
            usable_ace,
            player_hand,
            dealer_hand,
        };
    }

    /// Applies the Sutton & Barto natural rule on top of an already computed
    /// step result.
    ///
    /// Does nothing for other variants. Under `SuttonBarto`, a natural held by
    /// exactly one side overwrites the reward (`1.0` for the player, `-1.0`
    /// for the dealer) and marks the episode as done. If both sides or neither
    /// side hold a natural, `reward` and `done` are left untouched.
    fn apply_sab_override(&self, reward: &mut f32, done: &mut bool) {
        if !matches!(self.config.variant, BlackjackVariant::SuttonBarto) {
            return;
        }
        let player_natural = is_natural(&self.state.player_hand);
        let dealer_natural = is_natural(&self.state.dealer_hand);
        match (player_natural, dealer_natural) {
            (true, false) => {
                *reward = 1.0;
                *done = true;
            }
            (false, true) => {
                *reward = -1.0;
                *done = true;
            }
            _ => {}
        }
    }
}
impl ConstructableEnv for Blackjack {
fn new(_render: bool) -> Self {
Self::with_config(BlackjackConfig::default())
}
}
impl Environment<1, 1, 1> for Blackjack {
    type StateType = BlackjackState;
    type ObservationType = BlackjackObservation;
    type ActionType = BlackjackAction;
    type RewardType = ScalarReward;
    type SnapshotType = SnapshotBase<1, BlackjackObservation, ScalarReward>;

    /// Deals a new hand and returns a running snapshot with zero reward.
    fn reset(&mut self) -> Result<Self::SnapshotType, EnvironmentError> {
        self.deal_initial();
        let first_observation = self.state.observe();
        Ok(SnapshotBase::running(first_observation, ScalarReward(0.0)))
    }

    /// Plays one player action and returns the resulting snapshot.
    ///
    /// * `Hit` draws a card for the player. Going over 21 ends the episode
    ///   with reward `-1.0`; otherwise the episode continues with reward `0.0`.
    /// * `Stick` lets the dealer draw until their total reaches 17 or more,
    ///   then scores the hand: `1.0` if the dealer busts or the player total
    ///   is higher, `-1.0` if it is lower, `0.0` on a tie. With the standard
    ///   variant and `natural_pays_bonus`, a winning player natural against a
    ///   dealer without a natural pays `1.5`.
    ///
    /// The Sutton & Barto natural override is applied to the result of either
    /// action before the snapshot is built.
    fn step(&mut self, action: BlackjackAction) -> Result<Self::SnapshotType, EnvironmentError> {
        let (mut reward, mut done): (f32, bool) = match action {
            BlackjackAction::Hit => {
                let drawn = draw_card(&mut self.rng);
                self.state.player_hand.push(drawn);
                let (total, usable) = hand_value(&self.state.player_hand);
                self.state.player_sum = total;
                self.state.usable_ace = usable;

                let bust = total > 21;
                (if bust { -1.0 } else { 0.0 }, bust)
            }
            BlackjackAction::Stick => {
                // Dealer policy: keep drawing while the total is below 17.
                let dealer_total = loop {
                    let total = hand_value(&self.state.dealer_hand).0;
                    if total >= 17 {
                        break total;
                    }
                    let drawn = draw_card(&mut self.rng);
                    self.state.dealer_hand.push(drawn);
                };
                let player_total = self.state.player_sum;

                let player_wins = dealer_total > 21 || player_total > dealer_total;
                let bonus_enabled = matches!(
                    self.config.variant,
                    BlackjackVariant::Standard {
                        natural_pays_bonus: true
                    }
                );
                let payout = if player_wins {
                    // The bonus only ever upgrades a win, never a tie or a loss.
                    let bonus_applies = bonus_enabled
                        && is_natural(&self.state.player_hand)
                        && !is_natural(&self.state.dealer_hand);
                    if bonus_applies {
                        1.5
                    } else {
                        1.0
                    }
                } else if player_total < dealer_total {
                    -1.0
                } else {
                    0.0
                };
                (payout, true)
            }
        };

        self.apply_sab_override(&mut reward, &mut done);

        let observation = self.state.observe();
        let payout = ScalarReward(reward);
        Ok(if done {
            SnapshotBase::terminated(observation, payout)
        } else {
            SnapshotBase::running(observation, payout)
        })
    }
}
impl crate::render::AsciiRenderable for Blackjack {
fn render_ascii(&self) -> String {
let ace = if self.state.usable_ace { "A" } else { "" };
format!(
"Blackjack player={}{ace} dealer_showing={}",
self.state.player_sum, self.state.dealer_showing
)
}
fn render_styled(&self) -> crate::render::StyledFrame {
use crate::render::palette::{AGENT_FG, AGENT_MODIFIER};
use crate::render::{SpanStyle, StyledFrame, StyledLine, StyledSpan};
const LABEL: &str = "Blackjack";
let line = self.render_ascii();
let label_style = SpanStyle::default()
.fg(AGENT_FG)
.with_modifier(AGENT_MODIFIER);
let styled_line = if let Some(rest) = line.strip_prefix(LABEL) {
StyledLine::from_spans(vec![
StyledSpan::new(LABEL, label_style),
StyledSpan::raw(rest.to_string()),
])
} else {
StyledLine::unstyled(line)
};
StyledFrame {
lines: vec![styled_line],
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use rlevo_core::action::DiscreteAction;
    use rlevo_core::base::Observation;
    use rlevo_core::environment::Snapshot;

    /// Environment with the default configuration (standard rules, seed 0).
    fn make_env() -> Blackjack {
        Blackjack::with_config(BlackjackConfig::default())
    }

    impl Blackjack {
        /// Test-only helper that replaces both hands and recomputes the derived
        /// state fields. `dealer_hand` must not be empty.
        fn set_hands(&mut self, player_hand: Vec<u8>, dealer_hand: Vec<u8>) {
            let (player_sum, usable_ace) = hand_value(&player_hand);
            let dealer_showing = dealer_hand[0];
            self.state = BlackjackState {
                player_sum,
                dealer_showing,
                usable_ace,
                player_hand,
                dealer_hand,
            };
        }
    }

    #[test]
    fn action_count() {
        assert_eq!(BlackjackAction::ACTION_COUNT, 2);
    }

    #[test]
    fn action_roundtrip() {
        for i in 0..BlackjackAction::ACTION_COUNT {
            assert_eq!(BlackjackAction::from_index(i).to_index(), i);
        }
    }

    #[test]
    fn obs_shape() {
        assert_eq!(BlackjackObservation::shape(), [3]);
    }

    #[test]
    fn obs_encoding() {
        let obs = BlackjackObservation {
            player_sum: 18,
            dealer_showing: 10,
            usable_ace: 0,
        };
        assert_eq!(obs.player_sum, 18);
        assert_eq!(obs.dealer_showing, 10);
        assert_eq!(obs.usable_ace, 0);
    }

    #[test]
    fn bust_on_hit_returns_negative_one() {
        // From a hard 20, every card except an ace busts; scan seeds until one does.
        for seed in 0u64..200 {
            let mut env = make_env();
            env.reset().unwrap();
            env.set_hands(vec![10, 10], vec![6, 5]);
            env.rng = StdRng::seed_from_u64(seed);
            let snap = env.step(BlackjackAction::Hit).unwrap();
            let r: f32 = (*snap.reward()).into();
            if r == -1.0 {
                assert!(snap.is_done());
                return;
            }
        }
        panic!("could not find a busting seed in 200 tries");
    }

    #[test]
    fn dealer_bust_returns_positive_one() {
        for seed in 0u64..200 {
            let mut env = make_env();
            env.reset().unwrap();
            // A player total of 12 loses to every non-bust dealer finish
            // (17..=21), so a reward of +1 can only come from a dealer bust.
            // The dealer starts on 16 and must draw exactly one card.
            env.set_hands(vec![10, 2], vec![10, 6]);
            env.rng = StdRng::seed_from_u64(seed);
            let snap = env.step(BlackjackAction::Stick).unwrap();
            let r: f32 = (*snap.reward()).into();
            if r == 1.0 {
                assert!(
                    hand_value(&env.state.dealer_hand).0 > 21,
                    "reward 1.0 must come from a dealer bust"
                );
                assert!(snap.is_done());
                return;
            }
        }
        panic!("could not find a dealer-bust seed in 200 tries");
    }

    #[test]
    fn push_on_equal_sums() {
        let mut env = make_env();
        env.reset().unwrap();
        // Dealer already stands on 18, so no card is drawn and the totals tie.
        env.set_hands(vec![9, 9], vec![10, 8]);
        let snap = env.step(BlackjackAction::Stick).unwrap();
        let r: f32 = (*snap.reward()).into();
        assert_eq!(r, 0.0, "equal sums must push (reward 0), got {r}");
        assert!(snap.is_done());
    }

    #[test]
    fn natural_pays_1_5_when_flag_set() {
        let cfg = BlackjackConfig::builder()
            .variant(BlackjackVariant::Standard {
                natural_pays_bonus: true,
            })
            .seed(0)
            .build();
        let mut env = Blackjack::with_config(cfg);
        env.reset().unwrap();
        env.set_hands(vec![1, 10], vec![8, 9]);
        let snap = env.step(BlackjackAction::Stick).unwrap();
        let r: f32 = (*snap.reward()).into();
        assert!((r - 1.5).abs() < 1e-6, "natural pays 1.5, got {r}");
        assert!(snap.is_done());
    }

    #[test]
    fn sab_player_natural_wins_one() {
        let cfg = BlackjackConfig::builder()
            .variant(BlackjackVariant::SuttonBarto)
            .seed(0)
            .build();
        let mut env = Blackjack::with_config(cfg);
        env.reset().unwrap();
        env.set_hands(vec![1, 10], vec![8, 9]);
        let snap = env.step(BlackjackAction::Stick).unwrap();
        let r: f32 = (*snap.reward()).into();
        assert!(
            (r - 1.0).abs() < 1e-6,
            "SAB player natural must pay 1.0, got {r}"
        );
        assert!(snap.is_done());
    }

    #[test]
    fn sab_dealer_natural_costs_one() {
        let cfg = BlackjackConfig::builder()
            .variant(BlackjackVariant::SuttonBarto)
            .seed(0)
            .build();
        let mut env = Blackjack::with_config(cfg);
        env.reset().unwrap();
        env.set_hands(vec![9, 7], vec![1, 10]);
        let snap = env.step(BlackjackAction::Stick).unwrap();
        let r: f32 = (*snap.reward()).into();
        assert!(
            (r - (-1.0)).abs() < 1e-6,
            "SAB dealer natural costs -1.0, got {r}"
        );
        assert!(snap.is_done());
    }

    #[test]
    fn determinism() {
        let cfg = BlackjackConfig {
            variant: BlackjackVariant::default(),
            seed: 99,
        };
        // Plays ten episodes with a fixed "hit below 18" policy and sums the rewards.
        let run = || {
            let mut env = Blackjack::with_config(cfg.clone());
            let mut total = 0.0_f32;
            for _ in 0..10 {
                env.reset().unwrap();
                loop {
                    let a = if env.state.player_sum < 18 {
                        BlackjackAction::Hit
                    } else {
                        BlackjackAction::Stick
                    };
                    let snap = env.step(a).unwrap();
                    let r: f32 = (*snap.reward()).into();
                    total += r;
                    if snap.is_done() {
                        break;
                    }
                }
            }
            total
        };
        let a = run();
        let b = run();
        assert!(
            (a - b).abs() < 1e-5,
            "same seed must give same rewards; got {a} vs {b}"
        );
    }

    #[test]
    fn render_styled_matches_ascii() {
        use crate::render::AsciiRenderable;
        let mut env = Blackjack::with_config(BlackjackConfig::default());
        env.reset().unwrap();
        let plain = env.render_ascii();
        let styled = env.render_styled();
        assert_eq!(styled.lines.len(), 1);
        assert_eq!(styled.plain_text(), plain);
    }

    #[test]
    fn render_styled_uses_palette_consts() {
        use crate::render::AsciiRenderable;
        use crate::render::palette::{AGENT_FG, AGENT_MODIFIER};
        let mut env = Blackjack::with_config(BlackjackConfig::default());
        env.reset().unwrap();
        let styled = env.render_styled();
        let label = styled.lines[0]
            .spans
            .iter()
            .find(|s| s.text == "Blackjack")
            .expect("Blackjack label span present");
        assert_eq!(label.style.fg, Some(AGENT_FG));
        assert!(label.style.modifier.contains(AGENT_MODIFIER));
    }

    #[test]
    fn render_ascii_within_width_budget() {
        use crate::render::AsciiRenderable;
        let mut env = Blackjack::with_config(BlackjackConfig::default());
        env.reset().unwrap();
        for line in env.render_ascii().lines() {
            assert!(
                line.chars().count() <= 80,
                "line exceeds 80 cols: {line:?} ({} chars)",
                line.chars().count()
            );
        }
    }
}
impl rlevo_core::render::payload::TabularPayloadSource for Blackjack {
fn tabular_snapshot(&self) -> rlevo_core::render::payload::TabularSnapshot {
use rlevo_core::render::payload::{CardTable, TabularLayout, TabularSnapshot};
TabularSnapshot {
layout: TabularLayout::Cards(CardTable {
player_cards: self.state.player_hand.clone(),
player_total: self.state.player_sum,
usable_ace: self.state.usable_ace,
dealer_cards: self.state.dealer_hand.clone(),
dealer_showing: self.state.dealer_showing,
}),
}
}
}