use std::collections::HashMap;
use std::fmt;
use crate::core::{Hand, RSPokerError, Value};
/// A two-card starting hand described only by its two card values and a
/// suitedness flag (no concrete suits).
///
/// `PreflopHand::new` stores the larger value in `high`, the other in `low`,
/// and clears `suited` when both values are equal.
///
/// With the `serde` feature enabled the type is (de)serialized through its
/// `String` conversions (`try_from = "String"`, `into = "String"`).
///
/// NOTE(review): with the `arbitrary` feature the `Arbitrary` impl is derived
/// on the raw fields, so it presumably can yield values that `new` would never
/// produce (e.g. `high < low`, or a suited pair) — confirm whether that matters.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(feature = "serde", serde(try_from = "String", into = "String"))]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
pub struct PreflopHand {
/// The larger of the two card values (as ordered by `Value`'s comparison).
high: Value,
/// The smaller of the two card values; equal to `high` for a pair.
low: Value,
/// Whether both cards share a suit; `new` forces this to `false` for pairs.
suited: bool,
}
impl PreflopHand {
    /// Creates a hand from two card values given in either order.
    ///
    /// The larger value is stored as `high` and the other as `low`. When both
    /// values are equal the `suited` argument is ignored and the hand is
    /// stored as not suited.
    pub fn new(v1: Value, v2: Value, suited: bool) -> Self {
        let (high, low) = if v1 >= v2 { (v1, v2) } else { (v2, v1) };
        Self {
            high,
            low,
            // Pairs are always normalised to `suited == false`.
            suited: suited && high != low,
        }
    }

    /// Returns `true` when both card values are equal.
    pub fn is_pair(&self) -> bool {
        self.low == self.high
    }

    /// Returns the stored suitedness flag (always `false` for hands built
    /// through [`PreflopHand::new`] from two equal values).
    pub fn suited(&self) -> bool {
        self.suited
    }

    /// Returns the higher card value.
    pub fn high(&self) -> Value {
        self.high
    }

    /// Returns the lower card value.
    pub fn low(&self) -> Value {
        self.low
    }

    /// Renders the hand as text: the two value characters (high first)
    /// followed by no suffix for a pair, `s` for suited, or `o` otherwise.
    pub fn to_notation(&self) -> String {
        let suffix = if self.is_pair() {
            ""
        } else if self.suited {
            "s"
        } else {
            "o"
        };
        format!("{}{}{}", self.high.to_char(), self.low.to_char(), suffix)
    }

    /// Parses the textual form produced by [`PreflopHand::to_notation`].
    ///
    /// Accepted shapes:
    /// * two value characters that parse to the same value (a pair);
    /// * two value characters followed by `s`/`S` (the values must differ);
    /// * two value characters followed by `o`/`O` (values may be equal or not).
    ///
    /// # Errors
    ///
    /// Returns `RSPokerError::InvalidPreflopNotation` carrying a copy of `s`
    /// when the input is not 2 or 3 characters long, when a value character
    /// is rejected by `Value::from_char`, when a two-character input is not a
    /// pair, when a pair is marked suited, or when the suffix is not `s`/`o`.
    pub fn from_notation(s: &str) -> Result<Self, RSPokerError> {
        // Every failure path reports the same variant with the full input.
        let invalid = || RSPokerError::InvalidPreflopNotation(s.to_string());

        // Pull at most four characters: two values, an optional suffix, and a
        // probe that must be `None` (anything longer than three is rejected).
        let mut symbols = s.chars();
        let (first, second, suffix) = match (
            symbols.next(),
            symbols.next(),
            symbols.next(),
            symbols.next(),
        ) {
            (Some(first), Some(second), suffix, None) => (first, second, suffix),
            _ => return Err(invalid()),
        };

        let v1 = Value::from_char(first).ok_or_else(&invalid)?;
        let v2 = Value::from_char(second).ok_or_else(&invalid)?;
        let paired = v1 == v2;

        let suited = match suffix.map(|c| c.to_ascii_lowercase()) {
            // No suffix is only legal for pairs.
            None if paired => false,
            // A suited marker is only legal for two different values.
            Some('s') if !paired => true,
            Some('o') => false,
            _ => return Err(invalid()),
        };

        Ok(Self::new(v1, v2, suited))
    }

    /// Enumerates every combination of two values from `Value::values()`.
    ///
    /// For each value (in the order `Value::values()` yields them) paired
    /// with every value at or before it in that sequence, the non-suited
    /// hand is emitted first, followed by the suited hand when the two
    /// values differ. The capacity hint of 169 assumes 13 distinct values.
    pub fn all() -> Vec<Self> {
        let ranks = Value::values();
        let mut hands = Vec::with_capacity(169);
        for (idx, &high) in ranks.iter().enumerate() {
            for &low in ranks.iter().take(idx + 1) {
                hands.push(Self::new(high, low, false));
                if high != low {
                    hands.push(Self::new(high, low, true));
                }
            }
        }
        hands
    }
}
impl fmt::Display for PreflopHand {
    /// Writes the string returned by `to_notation` to the formatter as-is;
    /// width, fill and alignment flags on the formatter are not applied.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let notation = self.to_notation();
        f.write_str(&notation)
    }
}
impl TryFrom<String> for PreflopHand {
type Error = RSPokerError;
fn try_from(value: String) -> Result<Self, Self::Error> {
Self::from_notation(&value)
}
}
impl From<PreflopHand> for String {
fn from(hand: PreflopHand) -> Self {
hand.to_notation()
}
}
impl TryFrom<&Hand> for PreflopHand {
    type Error = RSPokerError;

    /// Builds a `PreflopHand` from a `Hand` whose `count()` is exactly 2.
    ///
    /// The two cards' values are passed to `PreflopHand::new`, with `suited`
    /// set when both cards report the same suit.
    ///
    /// # Errors
    ///
    /// Returns `RSPokerError::InvalidPreflopHandSize` carrying the reported
    /// count when the hand does not hold exactly two cards (or when the
    /// iterator yields fewer than two despite that count).
    fn try_from(hand: &Hand) -> Result<Self, Self::Error> {
        let count = hand.count();
        if count == 2 {
            let mut cards = hand.iter();
            if let Some(first) = cards.next() {
                if let Some(second) = cards.next() {
                    let same_suit = first.suit == second.suit;
                    return Ok(Self::new(first.value, second.value, same_suit));
                }
            }
        }
        // Single failure exit: wrong count, or iterator ran short.
        Err(RSPokerError::InvalidPreflopHandSize(count))
    }
}
/// The three action kinds a `PreflopStrategy` assigns frequencies to.
///
/// Used as the selector for `PreflopStrategy::frequency` and as the result of
/// `PreflopStrategy::sample`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
pub enum PreflopActionType {
/// Give up the hand; its frequency is whatever raise and call leave over.
Fold,
/// Call.
Call,
/// Raise.
Raise,
}
/// The preflop situation, keyed by how many raises have already happened
/// (see `PreflopScenario::from_raise_count`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
pub enum PreflopScenario {
/// No raise so far (raise count 0); labelled "RFI" — presumably "raise first in".
Rfi,
/// One raise so far (raise count 1); labelled "vs Open".
VsOpen,
/// Two raises so far (raise count 2); labelled "vs 3-Bet".
Vs3Bet,
/// Three or more raises so far; labelled "vs 4-Bet".
Vs4Bet,
}
impl PreflopScenario {
    /// Maps the number of raises seen so far to a scenario.
    ///
    /// Counts 0, 1 and 2 map to `Rfi`, `VsOpen` and `Vs3Bet`; every count of
    /// 3 or more maps to `Vs4Bet`.
    pub fn from_raise_count(raises: u8) -> Self {
        // `all()` lists the scenarios in raise-count order, so the count
        // doubles as an index; anything past the end saturates at `Vs4Bet`.
        let by_raise_count = Self::all();
        by_raise_count
            .get(usize::from(raises))
            .copied()
            .unwrap_or(Self::Vs4Bet)
    }

    /// Returns every scenario, ordered `Rfi`, `VsOpen`, `Vs3Bet`, `Vs4Bet`.
    pub const fn all() -> [Self; 4] {
        [Self::Rfi, Self::VsOpen, Self::Vs3Bet, Self::Vs4Bet]
    }

    /// Returns the short display label for the scenario.
    pub fn label(self) -> &'static str {
        match self {
            Self::Vs4Bet => "vs 4-Bet",
            Self::Vs3Bet => "vs 3-Bet",
            Self::VsOpen => "vs Open",
            Self::Rfi => "RFI",
        }
    }
}
/// Raise and call frequencies for one hand; the fold frequency is implied as
/// the remainder (see `PreflopStrategy::fold_freq`).
///
/// With the `serde` feature, each field defaults when absent and is omitted
/// from the output when `is_zero` returns `true` for it.
///
/// NOTE(review): the derived `Deserialize` (and `Arbitrary`) impls fill the
/// fields directly, so they presumably bypass the range checks performed by
/// `PreflopStrategy::new` — confirm whether out-of-range input must be rejected.
#[derive(Debug, Clone, Copy, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
pub struct PreflopStrategy {
/// Frequency of raising; `new` accepts values from 0.0 up to 1.0 plus a 0.001 tolerance.
#[cfg_attr(feature = "serde", serde(default, skip_serializing_if = "is_zero"))]
raise: f32,
/// Frequency of calling; `new` accepts values from 0.0 up to 1.0 plus a 0.001 tolerance.
#[cfg_attr(feature = "serde", serde(default, skip_serializing_if = "is_zero"))]
call: f32,
}
/// Serde `skip_serializing_if` predicate: `true` when the value compares
/// equal to `0.0` (which includes `-0.0`, and excludes NaN).
#[cfg(feature = "serde")]
fn is_zero(f: &f32) -> bool {
    let value = *f;
    value == 0.0
}
impl PreflopStrategy {
    /// Creates a strategy from raise and call frequencies.
    ///
    /// Each frequency must lie in `0.0..=1.001` and their sum must not exceed
    /// `1.001`; the extra `0.001` is a tolerance above 1.0.
    ///
    /// # Errors
    ///
    /// Returns `RSPokerError::InvalidStrategyFrequencies` describing the first
    /// failing check, in the order: `raise` range, `call` range, then the sum.
    /// NaN inputs fail the range checks.
    pub fn new(raise: f32, call: f32) -> Result<Self, RSPokerError> {
        // Upper bound shared by all three checks: 1.0 plus a small tolerance.
        let limit: f32 = 1.0 + 0.001;
        let allowed = 0.0..=limit;

        let problem = if !allowed.contains(&raise) {
            Some(format!("raise = {raise}"))
        } else if !allowed.contains(&call) {
            Some(format!("call = {call}"))
        } else if raise + call > limit {
            Some(format!("raise + call = {:.3}", raise + call))
        } else {
            None
        };

        match problem {
            Some(detail) => Err(RSPokerError::InvalidStrategyFrequencies(detail)),
            None => Ok(Self { raise, call }),
        }
    }

    /// Strategy that never raises or calls (raise 0.0, call 0.0).
    pub const fn fold() -> Self {
        Self { raise: 0.0, call: 0.0 }
    }

    /// Strategy that always raises (raise 1.0, call 0.0).
    pub const fn pure_raise() -> Self {
        Self { raise: 1.0, call: 0.0 }
    }

    /// Strategy that always calls (raise 0.0, call 1.0).
    pub const fn pure_call() -> Self {
        Self { raise: 0.0, call: 1.0 }
    }

    /// Returns the raise frequency.
    pub fn raise(&self) -> f32 {
        self.raise
    }

    /// Returns the call frequency.
    pub fn call(&self) -> f32 {
        self.call
    }

    /// Returns the fold frequency: `1.0 - raise - call`, floored at `0.0` so
    /// that sums inside the tolerance never yield a negative value.
    pub fn fold_freq(&self) -> f32 {
        let remainder = 1.0 - self.raise - self.call;
        remainder.max(0.0)
    }

    /// Returns `true` when both the raise and call frequencies equal `0.0`.
    pub fn is_pure_fold(&self) -> bool {
        self.call == 0.0 && self.raise == 0.0
    }

    /// Returns the frequency assigned to `action`.
    pub fn frequency(&self, action: PreflopActionType) -> f32 {
        match action {
            PreflopActionType::Fold => self.fold_freq(),
            PreflopActionType::Call => self.call,
            PreflopActionType::Raise => self.raise,
        }
    }

    /// Picks an action for a caller-supplied random number.
    ///
    /// Values below `raise` select `Raise`, values below `raise + call`
    /// select `Call`, and everything else selects `Fold`.
    /// `random_value` is presumably expected to be uniform in `[0, 1)` —
    /// this is not checked here.
    pub fn sample(&self, random_value: f32) -> PreflopActionType {
        let raise_cutoff = self.raise;
        let call_cutoff = self.raise + self.call;
        if random_value < raise_cutoff {
            PreflopActionType::Raise
        } else if random_value < call_cutoff {
            PreflopActionType::Call
        } else {
            PreflopActionType::Fold
        }
    }
}
impl Default for PreflopStrategy {
fn default() -> Self {
Self::fold()
}
}
/// A mapping from `PreflopHand` to the `PreflopStrategy` assigned to it.
///
/// The `Default` value is an empty chart. With the `serde` feature the chart
/// is `transparent`, i.e. it (de)serializes exactly as the inner map does.
#[derive(Debug, Clone, PartialEq, Default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(feature = "serde", serde(transparent))]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
pub struct PreflopChart {
/// Strategy per hand; hands with no entry are simply missing from the map.
strategies: HashMap<PreflopHand, PreflopStrategy>,
}
impl PreflopChart {
pub fn new() -> Self {
Self::default()
}
pub fn get(&self, hand: &PreflopHand) -> Option<&PreflopStrategy> {
self.strategies.get(hand)
}
pub fn get_or_fold(&self, hand: &PreflopHand) -> PreflopStrategy {
self.strategies.get(hand).copied().unwrap_or_default()
}
pub fn set(&mut self, hand: PreflopHand, strategy: PreflopStrategy) {
self.strategies.insert(hand, strategy);
}
pub fn remove(&mut self, hand: &PreflopHand) -> Option<PreflopStrategy> {
self.strategies.remove(hand)
}
pub fn iter(&self) -> impl Iterator<Item = (&PreflopHand, &PreflopStrategy)> {
self.strategies.iter()
}
pub fn len(&self) -> usize {
self.strategies.len()
}
pub fn is_empty(&self) -> bool {
self.strategies.is_empty()
}
}
/// Unit tests for the preflop hand, strategy, scenario and chart types.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::{Card, Suit};

    /// Argument order must not matter: the higher value always ends up in `high`.
    #[test]
    fn test_preflop_hand_ordering() {
        let hand1 = PreflopHand::new(Value::King, Value::Ace, true);
        let hand2 = PreflopHand::new(Value::Ace, Value::King, true);
        assert_eq!(hand1, hand2);
        assert_eq!(hand1.high(), Value::Ace);
        assert_eq!(hand1.low(), Value::King);
    }

    /// Two-card hands convert, with suitedness derived from the cards' suits.
    #[test]
    fn test_preflop_hand_try_from_hand() {
        let mut cards = Hand::new();
        cards.insert(Card::new(Value::Ace, Suit::Spade));
        cards.insert(Card::new(Value::King, Suit::Spade));
        let preflop = PreflopHand::try_from(&cards).expect("valid 2-card hand");
        assert!(preflop.suited());
        assert_eq!(preflop.high(), Value::Ace);
        assert_eq!(preflop.low(), Value::King);

        let mut cards = Hand::new();
        cards.insert(Card::new(Value::Ace, Suit::Spade));
        cards.insert(Card::new(Value::King, Suit::Heart));
        let preflop = PreflopHand::try_from(&cards).expect("valid 2-card hand");
        assert!(!preflop.suited());

        let mut cards = Hand::new();
        cards.insert(Card::new(Value::Queen, Suit::Spade));
        cards.insert(Card::new(Value::Queen, Suit::Heart));
        let preflop = PreflopHand::try_from(&cards).expect("valid 2-card hand");
        assert!(preflop.is_pair());
        assert!(!preflop.suited());
    }

    /// Hands with zero, one or three cards are rejected.
    #[test]
    fn test_preflop_hand_try_from_invalid_size() {
        let cards = Hand::new();
        assert!(PreflopHand::try_from(&cards).is_err());

        let mut cards = Hand::new();
        cards.insert(Card::new(Value::Ace, Suit::Spade));
        assert!(PreflopHand::try_from(&cards).is_err());

        let mut cards = Hand::new();
        cards.insert(Card::new(Value::Ace, Suit::Spade));
        cards.insert(Card::new(Value::King, Suit::Spade));
        cards.insert(Card::new(Value::Queen, Suit::Spade));
        assert!(PreflopHand::try_from(&cards).is_err());
    }

    /// `all()` yields 13 pairs + 78 suited + 78 offsuit = 169 hands.
    #[test]
    fn test_preflop_hand_all_count() {
        let all = PreflopHand::all();
        assert_eq!(all.len(), 169);
        let pairs = all.iter().filter(|h| h.is_pair()).count();
        let suited = all.iter().filter(|h| h.suited()).count();
        let offsuit = all.iter().filter(|h| !h.is_pair() && !h.suited()).count();
        assert_eq!(pairs, 13);
        assert_eq!(suited, 78);
        assert_eq!(offsuit, 78);
    }

    /// Every hand survives `to_notation` followed by `from_notation`.
    #[test]
    fn test_notation_roundtrip() {
        for hand in PreflopHand::all() {
            let notation = hand.to_notation();
            let parsed = PreflopHand::from_notation(&notation).unwrap();
            assert_eq!(hand, parsed, "Failed roundtrip for {}", notation);
        }
    }

    #[test]
    fn test_notation_case_insensitive() {
        let parsed_lower = PreflopHand::from_notation("aks").unwrap();
        let parsed_upper = PreflopHand::from_notation("AKS").unwrap();
        let parsed_mixed = PreflopHand::from_notation("AkS").unwrap();
        assert_eq!(parsed_lower, parsed_upper);
        assert_eq!(parsed_lower, parsed_mixed);
    }

    #[test]
    fn test_notation_invalid() {
        assert!(PreflopHand::from_notation("A").is_err());
        assert!(PreflopHand::from_notation("AKso").is_err());
        assert!(PreflopHand::from_notation("XKs").is_err());
        assert!(PreflopHand::from_notation("AK").is_err());
        assert!(PreflopHand::from_notation("AAs").is_err());
        assert!(PreflopHand::from_notation("AKx").is_err());
    }

    #[test]
    fn test_strategy_pure_raise() {
        let s = PreflopStrategy::pure_raise();
        assert_eq!(s.raise(), 1.0);
        assert_eq!(s.call(), 0.0);
        assert_eq!(s.fold_freq(), 0.0);
        assert_eq!(s.frequency(PreflopActionType::Raise), 1.0);
        assert_eq!(s.frequency(PreflopActionType::Fold), 0.0);
    }

    #[test]
    fn test_strategy_pure_call() {
        let s = PreflopStrategy::pure_call();
        assert_eq!(s.raise(), 0.0);
        assert_eq!(s.call(), 1.0);
        assert_eq!(s.fold_freq(), 0.0);
    }

    #[test]
    fn test_strategy_fold() {
        let s = PreflopStrategy::fold();
        assert_eq!(s.fold_freq(), 1.0);
        assert_eq!(s.raise(), 0.0);
        assert_eq!(s.call(), 0.0);
        assert!(s.is_pure_fold());
    }

    #[test]
    fn test_strategy_mixed() {
        let s = PreflopStrategy::new(0.6, 0.3).unwrap();
        assert_eq!(s.raise(), 0.6);
        assert_eq!(s.call(), 0.3);
        assert!((s.fold_freq() - 0.1).abs() < 1e-5);
    }

    #[test]
    fn test_strategy_rejects_out_of_range() {
        assert!(PreflopStrategy::new(1.5, 0.0).is_err());
        assert!(PreflopStrategy::new(-0.1, 0.0).is_err());
        assert!(PreflopStrategy::new(0.0, 1.5).is_err());
        assert!(PreflopStrategy::new(0.0, -0.1).is_err());
    }

    #[test]
    fn test_strategy_rejects_sum_over_one() {
        let result = PreflopStrategy::new(0.6, 0.5);
        assert!(result.is_err());
    }

    #[test]
    fn test_strategy_allows_partial_sum() {
        let s = PreflopStrategy::new(0.5, 0.0).unwrap();
        assert!((s.fold_freq() - 0.5).abs() < 1e-5);
    }

    /// Sampling boundaries: `[0, raise)` raises, `[raise, raise + call)` calls.
    #[test]
    fn test_strategy_sample() {
        let s = PreflopStrategy::new(0.5, 0.3).unwrap();
        assert_eq!(s.sample(0.0), PreflopActionType::Raise);
        assert_eq!(s.sample(0.49), PreflopActionType::Raise);
        assert_eq!(s.sample(0.5), PreflopActionType::Call);
        assert_eq!(s.sample(0.79), PreflopActionType::Call);
        assert_eq!(s.sample(0.85), PreflopActionType::Fold);
    }

    #[test]
    fn test_strategy_default_is_fold() {
        let s = PreflopStrategy::default();
        assert!(s.is_pure_fold());
    }

    #[test]
    fn test_chart_get_set() {
        let mut chart = PreflopChart::new();
        let aa = PreflopHand::new(Value::Ace, Value::Ace, false);
        assert!(chart.get(&aa).is_none());
        chart.set(aa, PreflopStrategy::pure_raise());
        assert_eq!(chart.get(&aa).unwrap().raise(), 1.0);
    }

    #[test]
    fn test_chart_get_or_fold() {
        let chart = PreflopChart::new();
        let unknown = PreflopHand::new(Value::Seven, Value::Two, false);
        let strategy = chart.get_or_fold(&unknown);
        assert!(strategy.is_pure_fold());
    }

    #[test]
    fn test_chart_remove() {
        let mut chart = PreflopChart::new();
        let aa = PreflopHand::new(Value::Ace, Value::Ace, false);
        chart.set(aa, PreflopStrategy::pure_raise());
        assert_eq!(chart.len(), 1);
        assert!(chart.remove(&aa).is_some());
        assert_eq!(chart.len(), 0);
        assert!(chart.remove(&aa).is_none());
    }

    #[test]
    fn test_chart_len_is_empty() {
        let mut chart = PreflopChart::new();
        assert!(chart.is_empty());
        chart.set(
            PreflopHand::new(Value::Ace, Value::Ace, false),
            PreflopStrategy::pure_raise(),
        );
        assert!(!chart.is_empty());
        assert_eq!(chart.len(), 1);
    }

    /// Raise counts of three or more all saturate at `Vs4Bet`.
    #[test]
    fn test_scenario_from_raise_count() {
        assert_eq!(PreflopScenario::from_raise_count(0), PreflopScenario::Rfi);
        assert_eq!(
            PreflopScenario::from_raise_count(1),
            PreflopScenario::VsOpen
        );
        assert_eq!(
            PreflopScenario::from_raise_count(2),
            PreflopScenario::Vs3Bet
        );
        assert_eq!(
            PreflopScenario::from_raise_count(3),
            PreflopScenario::Vs4Bet
        );
        assert_eq!(
            PreflopScenario::from_raise_count(10),
            PreflopScenario::Vs4Bet
        );
    }

    /// JSON round-trip tests; only compiled with the `serde` feature.
    #[cfg(feature = "serde")]
    mod serde_tests {
        use super::*;

        #[test]
        fn test_serde_hand_roundtrip() {
            let hand = PreflopHand::new(Value::Ace, Value::King, true);
            let json = serde_json::to_string(&hand).unwrap();
            assert_eq!(json, "\"AKs\"");
            let parsed: PreflopHand = serde_json::from_str(&json).unwrap();
            assert_eq!(hand, parsed);
        }

        #[test]
        fn test_serde_action_type() {
            let action = PreflopActionType::Raise;
            let json = serde_json::to_string(&action).unwrap();
            assert_eq!(json, "\"Raise\"");
            let parsed: PreflopActionType = serde_json::from_str(&json).unwrap();
            assert_eq!(action, parsed);
        }

        #[test]
        fn test_serde_strategy_minimal() {
            let s: PreflopStrategy = serde_json::from_str("{}").unwrap();
            assert!(s.is_pure_fold());
        }

        #[test]
        fn test_serde_strategy_raise_only() {
            let s: PreflopStrategy = serde_json::from_str(r#"{"raise": 1.0}"#).unwrap();
            assert_eq!(s.raise(), 1.0);
            assert_eq!(s.call(), 0.0);
        }

        #[test]
        fn test_serde_strategy_call_only() {
            let s: PreflopStrategy = serde_json::from_str(r#"{"call": 0.5}"#).unwrap();
            assert_eq!(s.raise(), 0.0);
            assert_eq!(s.call(), 0.5);
        }

        #[test]
        fn test_serde_strategy_both() {
            let s: PreflopStrategy =
                serde_json::from_str(r#"{"raise": 0.5, "call": 0.3}"#).unwrap();
            assert_eq!(s.raise(), 0.5);
            assert_eq!(s.call(), 0.3);
        }

        /// Zero-valued fields are omitted from the serialized form.
        #[test]
        fn test_serde_strategy_skip_zero() {
            let s = PreflopStrategy::pure_raise();
            let json = serde_json::to_string(&s).unwrap();
            assert_eq!(json, r#"{"raise":1.0}"#);
            let s = PreflopStrategy::fold();
            let json = serde_json::to_string(&s).unwrap();
            assert_eq!(json, "{}");
        }

        #[test]
        fn test_serde_chart_json() {
            let mut chart = PreflopChart::new();
            chart.set(
                PreflopHand::new(Value::Ace, Value::Ace, false),
                PreflopStrategy::new(0.85, 0.15).unwrap(),
            );
            chart.set(
                PreflopHand::new(Value::Ace, Value::King, true),
                PreflopStrategy::pure_raise(),
            );
            let json = serde_json::to_string(&chart).unwrap();
            let parsed: PreflopChart = serde_json::from_str(&json).unwrap();
            assert_eq!(chart, parsed);
        }

        /// An empty chart serializes as a bare JSON object (transparent wrapper).
        #[test]
        fn test_serde_chart_transparent() {
            let chart = PreflopChart::new();
            let json = serde_json::to_string(&chart).unwrap();
            assert_eq!(json, "{}");
        }
    }
}