mctrust 0.4.0

Universal search & planning toolkit — MCTS, bandit search, pluggable evaluators, tree reuse, DAG transpositions, root parallelism. Define an Environment; the search handles the rest.
Documentation
//! Concurrent stress tests for mctrust.
//!
//! These tests exercise multi-threaded access patterns, race conditions,
//! and parallel search scaling.

use mctrust::*;
use std::sync::{Arc, Barrier, Mutex};
use std::thread;
use std::time::Duration;

// =============================================================================
// BanditSearch Concurrent Hammer Tests
// =============================================================================

/// Two threads share a 10-arm bandit behind a mutex.
///
/// Each thread makes 20 pull attempts (40 in total), observing a fixed reward
/// of 0.5 whenever `next_arm` yields an arm. The test passes when exactly 10
/// pulls have been recorded.
#[test]
fn bandit_concurrent_2_threads_10_arms() {
    const WORKERS: usize = 2;
    const ATTEMPTS_PER_WORKER: usize = 20;

    // Populate the bandit before sharing it; arms alternate between groups 0 and 1.
    let mut bandit = BanditSearch::new_seeded(BanditConfig::default(), 42);
    for arm_id in 0..10u64 {
        bandit.add_arm(arm_id, arm_id as u32 % 2);
    }
    let shared = Arc::new(Mutex::new(bandit));

    let workers: Vec<_> = (0..WORKERS)
        .map(|_| {
            let bandit = Arc::clone(&shared);
            thread::spawn(move || {
                for _ in 0..ATTEMPTS_PER_WORKER {
                    // The guard is re-acquired per attempt so the threads interleave.
                    let mut guard = bandit.lock().unwrap();
                    if let Some(arm) = guard.next_arm() {
                        guard.observe(arm, 0.5);
                    }
                }
            })
        })
        .collect();

    for worker in workers {
        worker.join().unwrap();
    }

    let guard = shared.lock().unwrap();
    assert_eq!(guard.total_pulls(), 10);
}

/// Eight barrier-synchronised threads contend for a 100-arm bandit behind a mutex.
///
/// Each thread makes 50 pull attempts (400 in total) and reports a reward of
/// 0.0 or 1.0 depending on the parity of its thread index. The test then checks
/// that exactly 100 pulls were recorded and that the per-group visit counts sum
/// to the same figure.
#[test]
fn bandit_concurrent_8_threads_100_arms() {
    const WORKERS: usize = 8;

    // Arms are spread round-robin over four groups.
    let mut bandit = BanditSearch::new_seeded(BanditConfig::default(), 42);
    for arm_id in 0..100u64 {
        bandit.add_arm(arm_id, arm_id as u32 % 4);
    }
    let shared = Arc::new(Mutex::new(bandit));
    let start_line = Arc::new(Barrier::new(WORKERS));

    let workers: Vec<_> = (0..WORKERS)
        .map(|index| {
            let gate = Arc::clone(&start_line);
            let bandit = Arc::clone(&shared);
            // The reward depends only on the worker index, so compute it once.
            let reward = (index % 2) as f64;
            thread::spawn(move || {
                // Hold every worker until all of them are ready.
                gate.wait();
                for _ in 0..50 {
                    let mut guard = bandit.lock().unwrap();
                    if let Some(arm) = guard.next_arm() {
                        guard.observe(arm, reward);
                    }
                }
            })
        })
        .collect();

    for worker in workers {
        worker.join().unwrap();
    }

    let guard = shared.lock().unwrap();
    assert_eq!(guard.total_pulls(), 100);
    let visits_across_groups: u32 = guard
        .group_stats()
        .iter()
        .map(|group| group.visits)
        .sum();
    assert_eq!(visits_across_groups, 100);
}

/// Thirty-two barrier-synchronised threads contend for a 1000-arm bandit.
///
/// Each thread makes 100 pull attempts (3200 in total) with a reward of 0.0 or
/// 1.0 chosen by the parity of its thread index. Afterwards the test expects
/// exactly 1000 recorded pulls, a non-empty set of group statistics, and group
/// visit counts that add up to 1000.
#[test]
fn bandit_concurrent_32_threads_1000_arms() {
    const WORKERS: usize = 32;

    // Arms are spread round-robin over ten groups.
    let mut bandit = BanditSearch::new_seeded(BanditConfig::default(), 42);
    for arm_id in 0..1000u64 {
        bandit.add_arm(arm_id, arm_id as u32 % 10);
    }
    let shared = Arc::new(Mutex::new(bandit));
    let start_line = Arc::new(Barrier::new(WORKERS));

    let workers: Vec<_> = (0..WORKERS)
        .map(|index| {
            let gate = Arc::clone(&start_line);
            let bandit = Arc::clone(&shared);
            let reward = (index % 2) as f64;
            thread::spawn(move || {
                // Release all workers at once to maximise lock contention.
                gate.wait();
                for _ in 0..100 {
                    let mut guard = bandit.lock().unwrap();
                    if let Some(arm) = guard.next_arm() {
                        guard.observe(arm, reward);
                    }
                }
            })
        })
        .collect();

    for worker in workers {
        worker.join().unwrap();
    }

    let guard = shared.lock().unwrap();
    assert_eq!(guard.total_pulls(), 1000);

    let groups = guard.group_stats();
    assert!(!groups.is_empty());
    let visits_across_groups: u32 = groups.iter().map(|group| group.visits).sum();
    assert_eq!(visits_across_groups, 1000);
}

/// Ten threads feed a 200-arm bandit a rotating mix of rewards.
///
/// Each thread makes 30 pull attempts (300 in total). The reward for an
/// observation cycles through 1.0, -1.0, 0.0 and 0.5 according to
/// `(thread index + attempt index) % 4`. The test passes when no thread panics
/// and exactly 200 pulls have been recorded.
#[test]
fn bandit_concurrent_mixed_rewards_no_panic() {
    // Indexed by `(worker + attempt) % 4`.
    const REWARD_CYCLE: [f64; 4] = [1.0, -1.0, 0.0, 0.5];
    const WORKERS: usize = 10;
    const ATTEMPTS_PER_WORKER: usize = 30;

    // Arms are spread round-robin over five groups.
    let mut bandit = BanditSearch::new_seeded(BanditConfig::default(), 42);
    for arm_id in 0..200u64 {
        bandit.add_arm(arm_id, arm_id as u32 % 5);
    }
    let shared = Arc::new(Mutex::new(bandit));

    let workers: Vec<_> = (0..WORKERS)
        .map(|worker| {
            let bandit = Arc::clone(&shared);
            thread::spawn(move || {
                for attempt in 0..ATTEMPTS_PER_WORKER {
                    let mut guard = bandit.lock().unwrap();
                    if let Some(arm) = guard.next_arm() {
                        let reward = REWARD_CYCLE[(worker + attempt) % 4];
                        guard.observe(arm, reward);
                    }
                }
            })
        })
        .collect();

    for worker in workers {
        worker.join().unwrap();
    }

    let guard = shared.lock().unwrap();
    assert_eq!(guard.total_pulls(), 200);
}

/// Takes a checkpoint of a partially explored bandit from a second thread.
///
/// The bandit gets 50 arms in a single group and 10 pull attempts on the
/// calling thread. A spawned thread then locks it and calls `checkpoint()`.
/// The test asserts that the pull count is still 10 afterwards.
#[test]
fn bandit_concurrent_checkpoint_under_lock() {
    let mut bandit = BanditSearch::new_seeded(BanditConfig::default(), 42);
    for arm_id in 0..50u64 {
        bandit.add_arm(arm_id, 0);
    }
    // Warm-up pulls so the checkpoint captures non-trivial state.
    for _ in 0..10 {
        if let Some(arm) = bandit.next_arm() {
            bandit.observe(arm, 0.5);
        }
    }
    let shared = Arc::new(Mutex::new(bandit));

    let remote = Arc::clone(&shared);
    thread::spawn(move || {
        let guard = remote.lock().unwrap();
        let _snapshot = guard.checkpoint();
    })
    .join()
    .unwrap();

    let guard = shared.lock().unwrap();
    assert_eq!(guard.total_pulls(), 10);
}

// =============================================================================
// TreeSearch Parallel Tests
// =============================================================================

/// Test environment for the parallel tree-search tests: a single signed
/// counter that [`ConcurrentMove`] actions push up or down by one.
#[derive(Clone)]
struct ConcurrentGame(i32);

/// Actions available in [`ConcurrentGame`].
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
enum ConcurrentMove {
    /// Adds one to the counter.
    Up,
    /// Subtracts one from the counter.
    Down,
}

impl Environment for ConcurrentGame {
    type Action = ConcurrentMove;

    /// Offers both moves while the counter's magnitude is below 10, and no
    /// moves once it reaches 10 or more.
    fn legal_actions(&self) -> Vec<ConcurrentMove> {
        let can_move = self.0.abs() < 10;
        if can_move {
            vec![ConcurrentMove::Up, ConcurrentMove::Down]
        } else {
            Vec::new()
        }
    }

    /// Shifts the counter by one in the direction named by `action`.
    fn apply(&mut self, action: &ConcurrentMove) {
        let delta = match action {
            ConcurrentMove::Up => 1,
            ConcurrentMove::Down => -1,
        };
        self.0 += delta;
    }

    /// Reports a win at 5 or above, a failure at -5 or below, and an ongoing
    /// game anywhere in between.
    fn evaluate(&self) -> Outcome {
        match self.0 {
            position if position >= 5 => Outcome::Success(Reward::WIN),
            position if position <= -5 => Outcome::Failure,
            _ => Outcome::Ongoing,
        }
    }
}

/// Runs a seeded 2 000-iteration, depth-10 search over `ConcurrentGame` with
/// `run_parallel(4)` and expects it to produce a best action.
#[cfg(feature = "parallel")]
#[test]
fn treesearch_parallel_4_threads() {
    let config = SearchConfig::builder()
        .iterations(2_000)
        .max_depth(10)
        .build();
    let mut search = TreeSearch::with_seed(ConcurrentGame(0), config, 42);

    assert!(search.run_parallel(4).is_some());
}

/// Runs a seeded 4 000-iteration, depth-10 search over `ConcurrentGame` with
/// `run_parallel(8)` and expects it to produce a best action.
#[cfg(feature = "parallel")]
#[test]
fn treesearch_parallel_8_threads() {
    let config = SearchConfig::builder()
        .iterations(4_000)
        .max_depth(10)
        .build();
    let mut search = TreeSearch::with_seed(ConcurrentGame(0), config, 42);

    assert!(search.run_parallel(8).is_some());
}

/// Checks that `run_parallel(1)` returns the same result as `run()` when both
/// searches start from the same state, configuration and seed.
#[cfg(feature = "parallel")]
#[test]
fn treesearch_parallel_single_thread_matches_serial() {
    let config = SearchConfig::builder()
        .iterations(1_000)
        .max_depth(10)
        .build();

    // Two identically seeded searches; only the entry point differs.
    let mut sequential = TreeSearch::with_seed(ConcurrentGame(0), config.clone(), 42);
    let mut single_worker = TreeSearch::with_seed(ConcurrentGame(0), config, 42);

    let expected = sequential.run();
    let actual = single_worker.run_parallel(1);

    assert_eq!(expected, actual);
}

/// Runs a four-way parallel search with a 50 ms time budget set alongside a
/// one-million-iteration configuration.
///
/// The search must still return a best action, and the receiver's own
/// simulation counter must remain at zero afterwards.
#[cfg(feature = "parallel")]
#[test]
fn treesearch_parallel_with_time_budget() {
    let budget = Duration::from_millis(50);
    let mut config = SearchConfig::builder().iterations(1_000_000).build();
    config.time_budget = Some(budget);

    let mut search = TreeSearch::new(ConcurrentGame(0), config);
    assert!(search.run_parallel(4).is_some());

    let recorded = search.total_simulations();
    assert_eq!(
        recorded,
        0,
        "run_parallel uses ephemeral workers; the receiver's tree and visit counts are unchanged"
    );
}

// =============================================================================
// TreeSearch + Checkpoint Concurrent Access
// =============================================================================

/// Serde-serialisable counter environment used by the checkpoint and
/// inspection tests; [`CheckpointMove`] actions add their payload to it.
// NOTE(review): the serde derives are presumably required by `checkpoint()` —
// confirm against the crate's trait bounds.
#[derive(Clone, serde::Serialize, serde::Deserialize)]
struct CheckpointEnv(i32);

/// Action for [`CheckpointEnv`]: the wrapped value is the signed step added to
/// the environment's counter.
#[derive(Clone, Debug, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
struct CheckpointMove(i32);

impl Environment for CheckpointEnv {
    type Action = CheckpointMove;
    fn legal_actions(&self) -> Vec<CheckpointMove> {
        vec![CheckpointMove(1), CheckpointMove(-1)]
    }
    fn apply(&mut self, action: &CheckpointMove) {
        self.0 += action.0;
    }
    fn evaluate(&self) -> Outcome {
        if self.0 >= 3 {
            Outcome::Success(Reward::WIN)
        } else if self.0 <= -3 {
            Outcome::Failure
        } else {
            Outcome::Ongoing
        }
    }
}

/// Four threads each lock a shared, not-yet-run `TreeSearch` and call
/// `checkpoint()` on it. The test passes when every thread joins without
/// panicking.
#[test]
fn treesearch_checkpoint_from_multiple_threads() {
    let config = SearchConfig::builder().iterations(100).build();
    let shared = Arc::new(Mutex::new(TreeSearch::with_seed(
        CheckpointEnv(0),
        config,
        42,
    )));

    let workers: Vec<_> = (0..4)
        .map(|_| {
            let tree = Arc::clone(&shared);
            thread::spawn(move || {
                let guard = tree.lock().unwrap();
                let _snapshot = guard.checkpoint();
            })
        })
        .collect();

    for worker in workers {
        worker.join().unwrap();
    }
}

/// Advances a seeded `TreeSearch` by 50 `run_step` calls, then lets four
/// threads lock it in turn and read its tree size, simulation count, root
/// statistics and principal variation. The test passes when every thread
/// joins without panicking.
#[test]
fn treesearch_run_step_concurrent_inspection() {
    let config = SearchConfig::builder().iterations(100).build();
    let mut tree = TreeSearch::with_seed(CheckpointEnv(0), config, 42);

    // Grow the tree on this thread before handing it to the readers.
    for _ in 0..50 {
        tree.run_step();
    }
    let shared = Arc::new(Mutex::new(tree));

    let readers: Vec<_> = (0..4)
        .map(|_| {
            let tree = Arc::clone(&shared);
            thread::spawn(move || {
                let guard = tree.lock().unwrap();
                let _node_count = guard.tree_size();
                let _simulations = guard.total_simulations();
                let _root = guard.root_stats();
                let _line = guard.principal_variation();
            })
        })
        .collect();

    for reader in readers {
        reader.join().unwrap();
    }
}