#![allow(dead_code)]
pub mod dto;
pub use dto::*;
/// Chip amounts; used for the stack and blind constants (`STACK`, `B_BLIND`, `S_BLIND`).
/// Signed 16-bit — NOTE(review): confirm all pot/stack arithmetic stays within `i16::MAX` (32_767).
pub type Chips = i16;
/// Positional index — NOTE(review): presumably a seat/turn index; confirm at use sites.
pub type Position = usize;
/// Epoch counter — NOTE(review): meaning (training epoch? iteration?) not established here.
pub type Epoch = i16;
/// Energy-like scalar; used for `SINKHORN_TOLERANCE` and `SAMPLING_SMOOTHING`.
pub type Energy = f32;
/// Entropy-like scalar; used for the temperature constants (`SINKHORN_TEMPERATURE`, `SAMPLING_TEMPERATURE`).
pub type Entropy = f32;
/// Utility/payoff scalar; used for regret, pruning and bias constants.
pub type Utility = f32;
/// Probability scalar; used for `POLICY_MIN`, `SAMPLING_CURIOSITY` and `PRUNING_EXPLORE`.
pub type Probability = f32;
/// Types that can construct a random instance of themselves.
///
/// The randomness source and the distribution are left to each implementation.
pub trait Arbitrary {
fn random() -> Self;
}
/// Types that expose a typed identifier.
///
/// `T` defaults to `Self`, so an implementor returns `ID<Self>` unless it chooses
/// a different tag type.
pub trait Unique<T = Self> {
fn id(&self) -> ID<T>;
}
use std::cmp::Ordering;
use std::fmt::Debug;
use std::fmt::Display;
use std::fmt::Formatter;
use std::hash::Hash;
use std::hash::Hasher;
use std::marker::PhantomData;
/// A UUID tagged with a phantom type `T`, so identifiers for different kinds of entity
/// are distinct types; `cast` re-tags one explicitly.
///
/// The standard traits (Copy, Clone, Eq, Ord, Hash, Debug, Display) are implemented by hand
/// in this module rather than derived. This lets them hold for every `T` without requiring
/// `T` itself to implement them.
pub struct ID<T> {
// The underlying UUID; equality, ordering, hashing and formatting all delegate to it.
inner: uuid::Uuid,
// Zero-sized type tag; stores no data.
marker: PhantomData<T>,
}
impl<T> ID<T> {
    /// Returns a copy of the wrapped UUID.
    pub fn inner(&self) -> uuid::Uuid {
        self.inner
    }

    /// Re-tags this identifier with a different phantom type, keeping the same UUID.
    pub fn cast<U>(self) -> ID<U> {
        let Self { inner, .. } = self;
        ID {
            inner,
            marker: PhantomData,
        }
    }
}
impl<T> From<ID<T>> for uuid::Uuid {
fn from(id: ID<T>) -> Self {
id.inner()
}
}
impl<T> From<uuid::Uuid> for ID<T> {
fn from(inner: uuid::Uuid) -> Self {
Self {
inner,
marker: PhantomData,
}
}
}
/// Produces a new identifier from `Uuid::now_v7()` (a v7, timestamp-based UUID).
///
/// Unlike most `Default` impls this is not a fixed value: each call yields a fresh UUID.
impl<T> Default for ID<T> {
    fn default() -> Self {
        let inner = uuid::Uuid::now_v7();
        Self {
            inner,
            marker: PhantomData,
        }
    }
}
// Hand-written instead of `#[derive(Clone, Copy)]`: the derive would add `T: Clone` / `T: Copy`
// bounds, but `ID<T>` only holds a `Uuid` and a `PhantomData<T>`, so it is copyable for any `T`.
impl<T> Copy for ID<T> {}
impl<T> Clone for ID<T> {
// Canonical clone for a `Copy` type: a plain bitwise copy.
fn clone(&self) -> Self {
*self
}
}
/// Two identifiers are equal exactly when their UUIDs are equal; the phantom tag plays no part.
impl<T> PartialEq for ID<T> {
    fn eq(&self, other: &Self) -> bool {
        PartialEq::eq(&self.inner, &other.inner)
    }
}
impl<T> Eq for ID<T> {}

/// Identifiers order exactly as their UUIDs do.
impl<T> PartialOrd for ID<T> {
    // Canonical form: defer to the total order below.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
impl<T> Ord for ID<T> {
    fn cmp(&self, other: &Self) -> Ordering {
        Ord::cmp(&self.inner, &other.inner)
    }
}
/// Hashes only the UUID, consistent with `Eq`, which also compares only the UUID.
impl<T> Hash for ID<T> {
    fn hash<H: Hasher>(&self, state: &mut H) {
        Hash::hash(&self.inner, state)
    }
}
// Debug-formats as `ID(<uuid>)`; `debug_tuple` keeps alternate (`{:#?}`) formatting working.
impl<T> Debug for ID<T> {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
f.debug_tuple("ID").field(&self.inner).finish()
}
}
// Display-formats exactly like the inner UUID, forwarding the formatter (and its flags)
// to the UUID's own `Display` impl.
impl<T> Display for ID<T> {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
Display::fmt(&self.inner, f)
}
}
// ---- Game setup (amounts in `Chips`) ----
/// Presumably the number of players/seats — NOTE(review): confirm against the game engine.
pub const N: usize = 2;
/// Stack size in chips (presumably each player's starting stack).
pub const STACK: Chips = 100;
/// Big blind, in chips.
pub const B_BLIND: Chips = 2;
/// Small blind, in chips.
pub const S_BLIND: Chips = 1;
/// Cap on repeated raises — NOTE(review): confirm whether it applies per street or per hand.
pub const MAX_RAISE_REPEATS: usize = 3;
/// Depth limit for subgame trees — NOTE(review): unit (actions? plies?) defined by the tree code.
pub const MAX_DEPTH_SUBGAME: usize = 16;
/// Depth limit for full-game trees — same unit as `MAX_DEPTH_SUBGAME`.
pub const MAX_DEPTH_ALLGAME: usize = 32;
/// NOTE(review): unit is not stated here (seconds?) — confirm at the use site.
pub const SHOWDOWN_TIMEOUT: u64 = 5;

// ---- Sinkhorn solver (presumably entropic optimal transport) ----
pub const SINKHORN_TEMPERATURE: Entropy = 0.025;
pub const SINKHORN_ITERATIONS: usize = 128;
pub const SINKHORN_TOLERANCE: Energy = 0.001;

// ---- K-means abstraction: training iterations and cluster counts per street ----
pub const KMEANS_FLOP_TRAINING_ITERATIONS: usize = 20;
pub const KMEANS_TURN_TRAINING_ITERATIONS: usize = 24;
pub const KMEANS_FLOP_CLUSTER_COUNT: usize = 128;
pub const KMEANS_TURN_CLUSTER_COUNT: usize = 144;
/// NOTE(review): 101 presumably gives equity buckets 0..=100; confirm.
pub const KMEANS_EQTY_CLUSTER_COUNT: usize = 101;

/// Declared with the `Utility` alias (which is `f32`), like the other utility-valued constants.
pub const ASYMMETRIC_UTILITY: Utility = 2.0;

// ---- CFR batch size and total tree count, per game variant ----
pub const CFR_BATCH_SIZE_RPS: usize = 1;
pub const CFR_TREE_COUNT_RPS: usize = 8192;
pub const CFR_BATCH_SIZE_NLHE: usize = 128;
/// 0x10000000 = 2^28 = 268_435_456 trees.
pub const CFR_TREE_COUNT_NLHE: usize = 0x10000000;
pub const CFR_BATCH_SIZE_RIVER: usize = 16;
/// 0x10000 = 2^16 = 65_536 trees.
pub const CFR_TREE_COUNT_RIVER: usize = 0x10000;

// ---- Sampling ----
pub const SAMPLING_TEMPERATURE: Entropy = 2.0;
pub const SAMPLING_SMOOTHING: Energy = 0.5;
pub const SAMPLING_CURIOSITY: Probability = 0.01;

// ---- Policy / regret bounds and pruning ----
/// Smallest positive normal `f32`; presumably a floor keeping policy weights strictly positive.
pub const POLICY_MIN: Probability = Probability::MIN_POSITIVE;
/// Presumably the floor applied to accumulated regret.
pub const REGRET_MIN: Utility = -4e6;
/// Presumably regret below which branches become pruning candidates.
pub const PRUNING_THRESHOLD: Utility = -3e5;
/// Presumably the chance of exploring a branch that would otherwise be pruned.
pub const PRUNING_EXPLORE: Probability = 0.05;
/// 524_288 = 2^19; presumably iterations before pruning is enabled.
pub const PRUNING_WARMUP: usize = 524288;

// ---- Subgame solving ----
pub const SUBGAME_ALTS: usize = 4;
pub const SUBGAME_ITERATIONS: usize = 1024;

/// Interval between training progress log lines (60 seconds).
pub const TRAINING_LOG_INTERVAL: std::time::Duration = std::time::Duration::from_secs(60);

// ---- Action bias weights (presumably fold / raise / everything else) ----
pub const BIAS_FOLDS: Utility = 3.0;
pub const BIAS_RAISE: Utility = 0.5;
pub const BIAS_OTHER: Utility = 1.0;
/// Installs the global logger: `Info` and above go to the terminal, and `Debug` and above
/// go to a new file `logs/<unix-seconds>.log`.
///
/// Location, target and thread fields are switched off in the shared config.
///
/// # Panics
/// Panics if the `logs` directory or the log file cannot be created, if the system clock
/// reads earlier than the Unix epoch, or if `CombinedLogger::init` fails (for example,
/// because a global logger is already installed).
#[cfg(feature = "server")]
pub fn log() {
    std::fs::create_dir_all("logs").expect("create logs directory");
    let shared = simplelog::ConfigBuilder::new()
        .set_location_level(log::LevelFilter::Off)
        .set_target_level(log::LevelFilter::Off)
        .set_thread_level(log::LevelFilter::Off)
        .build();
    let started = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .expect("time moves slow")
        .as_secs();
    let sink = std::fs::File::create(format!("logs/{}.log", started)).expect("create log file");
    let to_file = simplelog::WriteLogger::new(log::LevelFilter::Debug, shared.clone(), sink);
    // The last user of the config takes it by value; no second clone is needed.
    let to_term = simplelog::TermLogger::new(
        log::LevelFilter::Info,
        shared,
        simplelog::TerminalMode::Mixed,
        simplelog::ColorChoice::Auto,
    );
    simplelog::CombinedLogger::init(vec![to_term, to_file]).expect("initialize logger");
}
/// Installs a Ctrl-C handler that terminates the process immediately.
///
/// Spawns a detached tokio task, so this must be called from within a tokio runtime
/// (`tokio::spawn` panics otherwise). On Ctrl-C it prints a blank line (presumably so the
/// log line doesn't share a row with the echoed `^C`), logs a warning, and calls
/// `std::process::exit(0)`. That skips destructors and any in-flight work, unlike the
/// graceful "Q" path started by `brb`, which only sets a flag.
///
/// If the Ctrl-C listener cannot be registered, the error is logged and the task ends
/// rather than panicking.
#[cfg(feature = "server")]
pub fn kys() {
    tokio::spawn(async move {
        if let Err(e) = tokio::signal::ctrl_c().await {
            log::error!("failed to listen for ctrl-c: {}", e);
            return;
        }
        println!();
        log::warn!("violent interrupt received, exiting immediately");
        std::process::exit(0);
    });
}
/// Graceful-stop flag: set by the stdin watcher that `brb` starts (when the user enters "Q")
/// and read by `interrupted()`.
#[cfg(feature = "server")]
static INTERRUPTED: std::sync::atomic::AtomicBool = std::sync::atomic::AtomicBool::new(false);
/// Optional training deadline: set at most once by `brb` from the `TRAIN_DURATION`
/// environment variable and read by `interrupted()`.
#[cfg(feature = "server")]
static DEADLINE: std::sync::OnceLock<std::time::Instant> = std::sync::OnceLock::new();
/// Returns `true` once training should wind down. That is either because a graceful stop
/// was requested, or because the optional deadline has been reached.
#[cfg(feature = "server")]
pub fn interrupted() -> bool {
    let requested = INTERRUPTED.load(std::sync::atomic::Ordering::Relaxed);
    requested
        || DEADLINE
            .get()
            .is_some_and(|&deadline| std::time::Instant::now() >= deadline)
}
/// Without the `server` feature there is no interrupt source, so this is always `false`.
#[cfg(not(feature = "server"))]
pub fn interrupted() -> bool {
    false
}
/// Arms the optional training deadline and starts the stdin "Q" listener.
///
/// * If the `TRAIN_DURATION` environment variable holds a value accepted by `parse_duration`
///   (an unsigned integer followed by `s`, `m`, `h` or `d`), `interrupted()` starts returning
///   `true` once that much time has passed. A value that cannot be parsed, or a deadline that
///   cannot be represented, is logged as a warning and ignored. Only the first successful
///   call sets the deadline (`OnceLock`).
/// * A background thread reads stdin line by line. A line equal to `q`/`Q` (ignoring
///   surrounding whitespace) sets the interrupt flag, and the thread then exits. The thread
///   also exits when stdin reaches EOF or fails, instead of busy-looping.
#[cfg(feature = "server")]
pub fn brb() {
    arm_deadline();
    watch_stdin();
}

/// Sets `DEADLINE` from `TRAIN_DURATION` when it is present, parseable and representable.
#[cfg(feature = "server")]
fn arm_deadline() {
    // Absent or non-Unicode variable: no deadline, as before.
    let Ok(duration) = std::env::var("TRAIN_DURATION") else {
        return;
    };
    let Some(span) = parse_duration(&duration) else {
        log::warn!("ignoring unparseable TRAIN_DURATION {:?}", duration);
        return;
    };
    // `Instant + Duration` panics on overflow; `checked_add` turns that into a warning.
    match std::time::Instant::now().checked_add(span) {
        Some(deadline) => {
            let _ = DEADLINE.set(deadline);
            log::info!("training will stop after {}", duration);
        }
        None => log::warn!("ignoring out-of-range TRAIN_DURATION {:?}", duration),
    }
}

/// Spawns a thread that sets `INTERRUPTED` when the user types "Q".
#[cfg(feature = "server")]
fn watch_stdin() {
    std::thread::spawn(|| {
        let stdin = std::io::stdin();
        // One buffer reused for every line.
        let mut line = String::new();
        loop {
            line.clear();
            match stdin.read_line(&mut line) {
                // EOF: stdin is closed (e.g. a detached process), so "Q" can never arrive.
                // Looping on Ok(0) would spin a core at 100%.
                Ok(0) => {
                    log::debug!("stdin closed; graceful interrupt via Q unavailable");
                    break;
                }
                Ok(_) if line.trim().eq_ignore_ascii_case("q") => {
                    log::warn!("graceful interrupt requested, finishing current batch...");
                    INTERRUPTED.store(true, std::sync::atomic::Ordering::Relaxed);
                    break;
                }
                Ok(_) => {}
                // A non-UTF-8 line is consumed and reported as InvalidData; keep listening.
                Err(e) if e.kind() == std::io::ErrorKind::InvalidData => {}
                Err(e) => {
                    log::warn!("stdin unreadable ({}), graceful interrupt disabled", e);
                    break;
                }
            }
        }
    });
}
/// Parses durations such as `"30s"`, `"15m"`, `"12h"` or `"2d"`.
///
/// Surrounding whitespace is ignored. The input must be an unsigned integer (as accepted by
/// `u64::from_str`) immediately followed by one unit character: `s`, `m`, `h` or `d`.
/// Returns `None` for anything else: a missing or unknown unit, a non-numeric amount, or an
/// amount whose length in seconds does not fit in a `u64`. It never panics, including on
/// input that ends in a multi-byte character.
#[cfg(feature = "server")]
fn parse_duration(s: &str) -> Option<std::time::Duration> {
    let s = s.trim();
    // Split off the last *character*, not the last byte. Byte-splitting panicked when the
    // input ended in a multi-byte UTF-8 character (e.g. "5µ").
    let unit = s.chars().next_back()?;
    let num = &s[..s.len() - unit.len_utf8()];
    let value: u64 = num.parse().ok()?;
    let seconds_per_unit: u64 = match unit {
        's' => 1,
        'm' => 60,
        'h' => 3600,
        'd' => 86400,
        _ => return None,
    };
    // Checked: a large amount times the unit could overflow u64.
    value
        .checked_mul(seconds_per_unit)
        .map(std::time::Duration::from_secs)
}