use mctrust::*;
use std::sync::{Arc, Barrier, Mutex};
use std::thread;
use std::time::Duration;
/// Two threads race to pull arms from a shared 10-arm bandit (two groups).
///
/// A `Barrier` releases both workers at the same moment, as in the larger
/// concurrency tests in this file. Without it, the first worker could finish
/// all of its iterations before the second is even scheduled, and the mutex
/// would never be contended.
///
/// Each worker makes 20 pull attempts (40 in total), but exactly 10 pulls are
/// expected. NOTE(review): this presumably relies on `next_arm` returning
/// `None` once every arm has been handed out — confirm against `BanditSearch`.
#[test]
fn bandit_concurrent_2_threads_10_arms() {
    const THREADS: usize = 2;
    let search = Arc::new(Mutex::new(BanditSearch::new_seeded(
        BanditConfig::default(),
        42,
    )));
    {
        let mut s = search.lock().unwrap();
        for i in 0..10u64 {
            s.add_arm(i, i as u32 % 2);
        }
    }
    let barrier = Arc::new(Barrier::new(THREADS));
    let mut handles = Vec::with_capacity(THREADS);
    for _ in 0..THREADS {
        let b = Arc::clone(&barrier);
        let s = Arc::clone(&search);
        handles.push(thread::spawn(move || {
            // Start both workers together so they really contend for the lock.
            b.wait();
            for _ in 0..20 {
                let mut locked = s.lock().unwrap();
                if let Some(arm) = locked.next_arm() {
                    locked.observe(arm, 0.5);
                }
            }
        }));
    }
    for h in handles {
        h.join().unwrap();
    }
    let locked = search.lock().unwrap();
    assert_eq!(locked.total_pulls(), 10);
    // Same invariant as the larger tests: per-group visits add up to the pulls.
    let total_visits: u32 = locked.group_stats().iter().map(|g| g.visits).sum();
    assert_eq!(total_visits, 10);
}
/// Eight workers, released together by a barrier, pull from a shared 100-arm
/// bandit spread over four groups.
///
/// Afterwards both the pull count and the summed per-group visit counts must
/// equal the number of arms.
#[test]
fn bandit_concurrent_8_threads_100_arms() {
    let shared = Arc::new(Mutex::new(BanditSearch::new_seeded(
        BanditConfig::default(),
        42,
    )));

    // Register arms 0..100, assigning them round-robin to groups 0..4.
    {
        let mut guard = shared.lock().unwrap();
        (0..100u64).for_each(|id| {
            guard.add_arm(id, id as u32 % 4);
        });
    }

    let start_line = Arc::new(Barrier::new(8));
    let workers: Vec<_> = (0..8i32)
        .map(|t| {
            let gate = Arc::clone(&start_line);
            let bandit = Arc::clone(&shared);
            thread::spawn(move || {
                // Odd-numbered workers report 1.0, even-numbered ones 0.0.
                let reward = f64::from(t % 2);
                gate.wait();
                for _ in 0..50 {
                    let mut guard = bandit.lock().unwrap();
                    if let Some(arm) = guard.next_arm() {
                        guard.observe(arm, reward);
                    }
                }
            })
        })
        // Collect first so every thread is spawned before any join
        // (joining lazily would deadlock on the barrier).
        .collect();

    for worker in workers {
        worker.join().unwrap();
    }

    let guard = shared.lock().unwrap();
    assert_eq!(guard.total_pulls(), 100);
    let visits = guard
        .group_stats()
        .iter()
        .map(|g| g.visits)
        .sum::<u32>();
    assert_eq!(visits, 100);
}
/// Stress test: 32 barrier-synchronised workers pull from a shared 1000-arm
/// bandit spread over ten groups.
///
/// The group statistics must be non-empty, and both the pull count and the
/// summed per-group visits must equal the number of arms.
#[test]
fn bandit_concurrent_32_threads_1000_arms() {
    let shared = Arc::new(Mutex::new(BanditSearch::new_seeded(
        BanditConfig::default(),
        42,
    )));

    // Register arms 0..1000, assigning them round-robin to groups 0..10.
    {
        let mut guard = shared.lock().unwrap();
        (0..1000u64).for_each(|id| {
            guard.add_arm(id, id as u32 % 10);
        });
    }

    let start_line = Arc::new(Barrier::new(32));
    let workers: Vec<_> = (0..32i32)
        .map(|t| {
            let gate = Arc::clone(&start_line);
            let bandit = Arc::clone(&shared);
            thread::spawn(move || {
                // Alternate 0.0 / 1.0 rewards by worker parity.
                let reward = f64::from(t % 2);
                gate.wait();
                for _ in 0..100 {
                    let mut guard = bandit.lock().unwrap();
                    if let Some(arm) = guard.next_arm() {
                        guard.observe(arm, reward);
                    }
                }
            })
        })
        // Materialise all handles before joining so no worker waits forever
        // at the barrier.
        .collect();

    for worker in workers {
        worker.join().unwrap();
    }

    let guard = shared.lock().unwrap();
    assert_eq!(guard.total_pulls(), 1000);
    let stats = guard.group_stats();
    assert!(!stats.is_empty());
    let visits = stats.iter().map(|g| g.visits).sum::<u32>();
    assert_eq!(visits, 1000);
}
/// Ten workers feed a mix of positive, negative, zero and fractional rewards
/// into a shared 200-arm bandit (five groups).
///
/// The test passes if no worker panics and the final pull count equals the
/// number of arms. A `Barrier` starts all workers at once, as in the other
/// concurrent tests. Without it, early workers could finish before later ones
/// start, and rewards from different threads would never interleave.
#[test]
fn bandit_concurrent_mixed_rewards_no_panic() {
    const THREADS: usize = 10;
    let search = Arc::new(Mutex::new(BanditSearch::new_seeded(
        BanditConfig::default(),
        42,
    )));
    {
        let mut s = search.lock().unwrap();
        for i in 0..200u64 {
            s.add_arm(i, i as u32 % 5);
        }
    }
    let barrier = Arc::new(Barrier::new(THREADS));
    let mut handles = Vec::with_capacity(THREADS);
    for t in 0..THREADS {
        let b = Arc::clone(&barrier);
        let s = Arc::clone(&search);
        handles.push(thread::spawn(move || {
            // Release every worker together so observations interleave.
            b.wait();
            for i in 0..30 {
                let mut locked = s.lock().unwrap();
                if let Some(arm) = locked.next_arm() {
                    // Cycle through the four reward kinds, offset per worker.
                    let reward = match (t + i) % 4 {
                        0 => 1.0,
                        1 => -1.0,
                        2 => 0.0,
                        _ => 0.5,
                    };
                    locked.observe(arm, reward);
                }
            }
        }));
    }
    for h in handles {
        h.join().unwrap();
    }
    let locked = search.lock().unwrap();
    assert_eq!(locked.total_pulls(), 200);
}
/// Takes a checkpoint from a second thread while that thread holds the mutex.
///
/// Checks that the 10 pulls recorded beforehand on the main thread are still
/// reported afterwards.
#[test]
fn bandit_concurrent_checkpoint_under_lock() {
    let shared = Arc::new(Mutex::new(BanditSearch::new_seeded(
        BanditConfig::default(),
        42,
    )));

    {
        let mut guard = shared.lock().unwrap();
        // Fifty arms, all in group 0.
        for id in 0..50u64 {
            guard.add_arm(id, 0);
        }
        // Ten pull/observe rounds before any other thread touches the search.
        for _round in 0..10 {
            if let Some(arm) = guard.next_arm() {
                guard.observe(arm, 0.5);
            }
        }
    }

    let handle = {
        let bandit = Arc::clone(&shared);
        thread::spawn(move || {
            let guard = bandit.lock().unwrap();
            let _cp = guard.checkpoint();
        })
    };
    handle.join().unwrap();

    assert_eq!(shared.lock().unwrap().total_pulls(), 10);
}
/// A one-dimensional walk used by the parallel tree-search tests.
///
/// The state is a signed position that each move shifts by one.
#[derive(Clone)]
struct ConcurrentGame(i32);

/// The two moves available in [`ConcurrentGame`].
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
enum ConcurrentMove {
    Up,
    Down,
}

impl Environment for ConcurrentGame {
    type Action = ConcurrentMove;

    /// Both moves are legal while the position's magnitude is below 10.
    /// At or beyond ±10 no moves remain.
    fn legal_actions(&self) -> Vec<ConcurrentMove> {
        if self.0.abs() < 10 {
            vec![ConcurrentMove::Up, ConcurrentMove::Down]
        } else {
            Vec::new()
        }
    }

    /// `Up` increments the position; `Down` decrements it.
    fn apply(&mut self, action: &ConcurrentMove) {
        let step = match action {
            ConcurrentMove::Up => 1,
            ConcurrentMove::Down => -1,
        };
        self.0 += step;
    }

    /// A position of +5 or more is a win and -5 or less is a failure.
    /// Anything in between is still ongoing.
    fn evaluate(&self) -> Outcome {
        match self.0 {
            pos if pos >= 5 => Outcome::Success(Reward::WIN),
            pos if pos <= -5 => Outcome::Failure,
            _ => Outcome::Ongoing,
        }
    }
}
/// Four worker threads sharing a 2 000-iteration budget (depth limit 10)
/// must still produce a best move.
#[cfg(feature = "parallel")]
#[test]
fn treesearch_parallel_4_threads() {
    let config = SearchConfig::builder()
        .iterations(2_000)
        .max_depth(10)
        .build();
    let mut search = TreeSearch::with_seed(ConcurrentGame(0), config, 42);
    let best = search.run_parallel(4);
    assert!(best.is_some());
}
/// Eight worker threads sharing a 4 000-iteration budget (depth limit 10)
/// must still produce a best move.
#[cfg(feature = "parallel")]
#[test]
fn treesearch_parallel_8_threads() {
    let threads = 8;
    let config = SearchConfig::builder()
        .iterations(4_000)
        .max_depth(10)
        .build();
    let mut search = TreeSearch::with_seed(ConcurrentGame(0), config, 42);
    let best = search.run_parallel(threads);
    assert!(best.is_some());
}
/// With one worker and the same seed and configuration, `run_parallel(1)`
/// must return exactly what the serial `run()` returns.
#[cfg(feature = "parallel")]
#[test]
fn treesearch_parallel_single_thread_matches_serial() {
    let cfg = SearchConfig::builder()
        .iterations(1_000)
        .max_depth(10)
        .build();

    // Serial baseline first, using a copy of the shared configuration.
    let mut serial = TreeSearch::with_seed(ConcurrentGame(0), cfg.clone(), 42);
    let serial_result = serial.run();

    // Then the single-worker parallel run, consuming the original configuration.
    let mut parallel = TreeSearch::with_seed(ConcurrentGame(0), cfg, 42);
    let parallel_result = parallel.run_parallel(1);

    assert_eq!(serial_result, parallel_result);
}
/// A parallel run bounded by a 50 ms time budget must still yield a best
/// move.
///
/// It must also leave the receiver's own simulation count at zero.
#[cfg(feature = "parallel")]
#[test]
fn treesearch_parallel_with_time_budget() {
    // The iteration cap is far above what 50 ms should allow — presumably so
    // the time budget, not the cap, is what ends the search.
    let config = {
        let mut cfg = SearchConfig::builder().iterations(1_000_000).build();
        cfg.time_budget = Some(Duration::from_millis(50));
        cfg
    };
    let mut search = TreeSearch::new(ConcurrentGame(0), config);

    let best = search.run_parallel(4);
    assert!(best.is_some());

    assert_eq!(
        search.total_simulations(),
        0,
        "run_parallel uses ephemeral workers; the receiver's tree and visit counts are unchanged"
    );
}
/// A serialisable one-dimensional walk used by the checkpoint tests.
///
/// Each move adds its payload to the position.
#[derive(Clone, serde::Serialize, serde::Deserialize)]
struct CheckpointEnv(i32);

/// A signed step applied to [`CheckpointEnv`].
#[derive(Clone, Debug, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
struct CheckpointMove(i32);

impl Environment for CheckpointEnv {
    type Action = CheckpointMove;

    /// Stepping +1 and -1 is always offered, even after the game is decided.
    fn legal_actions(&self) -> Vec<CheckpointMove> {
        [1, -1].iter().map(|&delta| CheckpointMove(delta)).collect()
    }

    /// Shifts the position by the move's signed step.
    fn apply(&mut self, action: &CheckpointMove) {
        let &CheckpointMove(delta) = action;
        self.0 += delta;
    }

    /// A position of +3 or more is a win and -3 or less is a failure.
    /// Anything in between is still ongoing.
    fn evaluate(&self) -> Outcome {
        match self.0 {
            pos if pos >= 3 => Outcome::Success(Reward::WIN),
            pos if pos <= -3 => Outcome::Failure,
            _ => Outcome::Ongoing,
        }
    }
}
/// Four threads each take a checkpoint of a seeded, not-yet-run search.
///
/// They take turns through the mutex. The test passes if no thread panics.
#[test]
fn treesearch_checkpoint_from_multiple_threads() {
    let shared = Arc::new(Mutex::new(TreeSearch::with_seed(
        CheckpointEnv(0),
        SearchConfig::builder().iterations(100).build(),
        42,
    )));

    let workers: Vec<_> = (0..4)
        .map(|_| {
            let search = Arc::clone(&shared);
            thread::spawn(move || {
                let guard = search.lock().unwrap();
                let _cp = guard.checkpoint();
            })
        })
        .collect();

    workers.into_iter().for_each(|worker| worker.join().unwrap());
}
/// After 50 single steps on the main thread, four threads read the search
/// under the lock.
///
/// Each reads the tree size, simulation count, root statistics and principal
/// variation. The test passes if none of these reads panics.
#[test]
fn treesearch_run_step_concurrent_inspection() {
    let shared = Arc::new(Mutex::new(TreeSearch::with_seed(
        CheckpointEnv(0),
        SearchConfig::builder().iterations(100).build(),
        42,
    )));

    // Advance the search step by step before handing it to the readers.
    {
        let mut guard = shared.lock().unwrap();
        (0..50).for_each(|_| {
            guard.run_step();
        });
    }

    let readers: Vec<_> = (0..4)
        .map(|_| {
            let search = Arc::clone(&shared);
            thread::spawn(move || {
                let guard = search.lock().unwrap();
                let _size = guard.tree_size();
                let _sims = guard.total_simulations();
                let _stats = guard.root_stats();
                let _pv = guard.principal_variation();
            })
        })
        .collect();

    for reader in readers {
        reader.join().unwrap();
    }
}