mctrust 0.4.0

Universal search and planning toolkit: MCTS, bandit search, pluggable evaluators, tree reuse, DAG transpositions, and root parallelism. Define an `Environment`, and the search handles the rest.
Documentation
// The tests derive `u32` group ids from small `u64` arm ids with `as` casts
// (e.g. `(i / 5) as u32`); the values involved are tiny, so the truncation
// this lint warns about cannot happen here.
#![allow(clippy::cast_possible_truncation)]

use std::collections::HashMap;

use super::*;

#[test]
fn bandit_basic_flow() {
    let mut search = BanditSearch::new_seeded(BanditConfig::default(), 42);

    // Twenty arms spread over four groups of five.
    for id in 0..20u64 {
        let group = (id / 5) as u32;
        search.add_arm(id, group);
    }

    // Up to twenty pull/observe rounds; only arms 0..5 pay out.
    let mut pulled = Vec::new();
    while pulled.len() < 20 {
        match search.next_arm() {
            Some(arm) => {
                pulled.push(arm);
                search.observe(arm, if arm < 5 { 1.0 } else { 0.0 });
            }
            None => break,
        }
    }

    assert_eq!(pulled.len(), 20);
}

#[test]
fn bandit_next_arms_collects_multiple() {
    let mut search = BanditSearch::new_seeded(BanditConfig::default(), 7);
    for id in 0..5u64 {
        search.add_arm(id, 0);
    }

    // First batch: three arms, observed last-selected first.
    let first_batch = search.next_arms(3);
    assert_eq!(first_batch.len(), 3);
    for arm in first_batch.into_iter().rev() {
        search.observe(arm, 0.5);
    }

    // Asking for more than remain yields only the two arms not yet pulled.
    let second_batch = search.next_arms(10);
    assert_eq!(second_batch.len(), 2);
    for arm in second_batch {
        search.observe(arm, 0.5);
    }

    assert_eq!(search.total_pulls(), 5);
}

#[test]
fn bandit_observe_with_signals_scalarizes() {
    let weights = HashMap::from([
        ("quality".to_string(), 1.0),
        ("latency".to_string(), 2.0),
    ]);
    let scalarizer = Scalarizer {
        signal_weights: weights,
        default_weight: 0.0,
    };
    let config = BanditConfig::builder().scalarizer(scalarizer).build();
    let mut search = BanditSearch::new_seeded(config, 11);

    search.add_arm(0, 0);
    let chosen = search.next_arm().unwrap();
    search.observe_with_signals(chosen, &[("quality", 0.5), ("latency", 1.0)]);

    // Weighted sum: 0.5 * 1.0 + 1.0 * 2.0 = 2.5.
    let stats = search.group_stats();
    assert_eq!(stats.len(), 1);
    let only_group = &stats[0];
    assert!((only_group.average_reward - 2.5).abs() < f64::EPSILON);
}

#[test]
fn bandit_reweight_signals_updates_online() {
    let weights = HashMap::from([
        ("quality".to_string(), 1.0),
        ("latency".to_string(), 0.0),
    ]);
    let scalarizer = Scalarizer {
        signal_weights: weights,
        default_weight: 0.0,
    };
    let config = BanditConfig::builder().scalarizer(scalarizer).build();
    let mut search = BanditSearch::new_seeded(config, 11);

    search.add_arm(0, 0);
    search.add_arm(1, 0);

    // Scored under the initial weights: 1.0 * 1.0 = 1.0.
    let first = search.next_arm().unwrap();
    search.observe_with_signals(first, &[("quality", 1.0)]);

    // Select the second arm, swap the weights, then observe it: 1.0 * 2.0 = 2.0.
    let second = search.next_arm().unwrap();
    search.reweight_signals(&[("quality", 0.0), ("latency", 2.0)]);
    search.observe_with_signals(second, &[("latency", 1.0)]);

    // Both observations land in group 0; their mean is (1.0 + 2.0) / 2 = 1.5.
    let stats = search.group_stats();
    assert_eq!(search.total_pulls(), 2);
    let group = &stats[0];
    assert_eq!(group.visits, 2);
    assert!((group.average_reward - 1.5).abs() < f64::EPSILON);
}

#[test]
fn bandit_respects_budget() {
    let config = BanditConfig::builder().max_pulls(5).build();
    let mut search = BanditSearch::new_seeded(config, 42);

    // Many more arms than the budget, so running out of arms is not what
    // should end the loop below; only `max_pulls` should.
    for i in 0..50u64 {
        search.add_arm(i, 0);
    }

    let mut count = 0;
    while let Some(arm) = search.next_arm() {
        count += 1;
        // Check against the configured budget itself, so an overrun fails on the
        // offending pull rather than only at the final count.
        assert!(
            count <= 5,
            "budget not respected: pull {} exceeds max_pulls(5)",
            count
        );
        search.observe(arm, 1.0);
    }

    assert_eq!(count, 5);
}

#[test]
fn bandit_empty_returns_none() {
    let mut search = BanditSearch::new(BanditConfig::default());
    // No arms registered, so there is nothing to select.
    let choice = search.next_arm();
    assert!(choice.is_none());
}

#[test]
fn bandit_rave_propagates() {
    let mut search = BanditSearch::new_seeded(BanditConfig::default(), 99);

    // Two groups, 5 arms each.
    for i in 0..10u64 {
        search.add_arm(i, (i / 5) as u32);
    }

    // Pull two arms with positive reward so RAVE definitely fires.
    for _ in 0..2 {
        let arm = search
            .next_arm()
            .expect("ten arms are registered, so the first two pulls must succeed");
        search.observe(arm, 1.0);
    }

    // Summed over all groups, the two rewarded observations must have recorded
    // exactly two RAVE visits.
    let stats = search.group_stats();
    let total_rave: u32 = stats.iter().map(|s| s.rave_visits).sum();
    assert_eq!(total_rave, 2);
}

#[test]
fn bandit_group_bias() {
    let mut search = BanditSearch::new_seeded(BanditConfig::default(), 7);

    // Group 0 holds arms 0..5, group 1 holds arms 5..10.
    for id in 0..10u64 {
        let group = (id / 5) as u32;
        search.add_arm(id, group);
    }

    // A very large bias on group 1 should dominate the very first selection.
    search.set_group_bias(1, 100.0);

    let first = search.next_arm().unwrap();
    assert!(first >= 5, "biased group should be selected first");
}

#[test]
fn bandit_total_pulls() {
    let mut search = BanditSearch::new(BanditConfig::default());
    for id in 0..5u64 {
        search.add_arm(id, 0);
    }

    // Three pull/observe rounds; a round with no available arm is skipped.
    for _round in 0..3 {
        if let Some(chosen) = search.next_arm() {
            search.observe(chosen, 0.5);
        }
    }

    let pulls = search.total_pulls();
    assert_eq!(pulls, 3);
}

#[test]
fn bandit_group_stats_correct() {
    let mut search = BanditSearch::new_seeded(BanditConfig::default(), 1);

    // Six arms in two groups of three.
    for id in 0..6u64 {
        let group = (id / 3) as u32;
        search.add_arm(id, group);
    }

    // Three rounds, letting the selection policy choose freely.
    for _round in 0..3 {
        if let Some(picked) = search.next_arm() {
            search.observe(picked, 0.5);
        }
    }

    let stats = search.group_stats();
    assert_eq!(stats.len(), 2);

    // Visits across both groups must add up to the number of observations.
    let total_visits = stats.iter().fold(0u32, |acc, s| acc + s.visits);
    assert_eq!(total_visits, 3);
}

#[test]
fn bandit_seeded_determinism() {
    let config = BanditConfig::default();

    // Two independent searches sharing config and seed.
    let mut left = BanditSearch::new_seeded(config.clone(), 42);
    let mut right = BanditSearch::new_seeded(config, 42);

    for id in 0..10u64 {
        let group = (id / 5) as u32;
        left.add_arm(id, group);
        right.add_arm(id, group);
    }

    // Step both in lockstep; every selection must match.
    for _step in 0..10 {
        let (left_choice, right_choice) = (left.next_arm(), right.next_arm());
        assert_eq!(
            left_choice, right_choice,
            "same seed should produce same sequence"
        );
        if let (Some(l), Some(r)) = (left_choice, right_choice) {
            left.observe(l, 0.5);
            right.observe(r, 0.5);
        }
    }
}

#[test]
fn bandit_many_groups() {
    let mut search = BanditSearch::new(BanditConfig::default());

    // 100 groups of 10 arms; arm ids are group * 10 + offset.
    for group in 0..100u32 {
        let base = u64::from(group) * 10;
        for offset in 0..10u32 {
            search.add_arm(base + u64::from(offset), group);
        }
    }

    // Pull for at most 200 rounds, stopping early if nothing is left.
    let mut count = 0;
    while count < 200 {
        match search.next_arm() {
            Some(arm) => {
                search.observe(arm, 0.1);
                count += 1;
            }
            None => break,
        }
    }

    // At least 100 pulls must have been possible.
    assert!(count >= 100);
}

#[test]
fn bandit_checkpoint_roundtrip() {
    let mut search = BanditSearch::new_seeded(BanditConfig::default(), 11);
    for id in 0..7u64 {
        search.add_arm(id, 0);
    }

    for _round in 0..3 {
        if let Some(chosen) = search.next_arm() {
            search.observe(chosen, 0.75);
        }
    }

    // Restoring from a checkpoint must carry the pull counter across.
    let restored = BanditSearch::restore(search.checkpoint());
    assert_eq!(restored.total_pulls(), 3);
}

#[test]
fn bandit_prevents_duplicate_arms() {
    let mut search = BanditSearch::new(BanditConfig::default());

    // Register the same arm id twice; the second call must be a no-op.
    for _attempt in 0..2 {
        search.add_arm(1, 0);
    }

    let stats = search.group_stats();
    let group = &stats[0];
    assert_eq!(group.total_arms, 1);
}

#[test]
fn bandit_zero_max_pulls_disables_limit() {
    let config = BanditConfig::builder().max_pulls(0).build();
    let mut search = BanditSearch::new_seeded(config, 7);

    for i in 0..15u64 {
        search.add_arm(i, 0);
    }

    // `max_pulls(0)` is expected to mean "no limit", so all 15 arms must be
    // selectable. Fail on the exact pull that is refused (instead of silently
    // breaking and only tripping the final count). Observe each pull as every
    // other test does, so the full pull/observe cycle is exercised.
    for pull in 1..=15 {
        let arm = search
            .next_arm()
            .unwrap_or_else(|| panic!("pull {} of 15 was refused with max_pulls(0)", pull));
        search.observe(arm, 0.5);
    }

    assert_eq!(search.total_pulls(), 15);
}

#[test]
fn test_next_arms_returns_correct_count() {
    let mut search = BanditSearch::new(BanditConfig::default());
    for id in 0..10u64 {
        search.add_arm(id, 0);
    }

    // Ten arms available and five requested, so exactly five come back.
    let batch = search.next_arms(5);
    assert_eq!(batch.len(), 5);
}

#[test]
fn test_observe_multi_uses_max() {
    let mut search = BanditSearch::new_seeded(BanditConfig::default(), 42);
    search.add_arm(0, 0);
    search.add_arm(1, 0);

    // Observe the arm that was actually selected. The seed decides whether it is
    // arm 0 or arm 1, so hard-coding arm 0 could target an arm that was never pulled.
    let arm = search.next_arm().expect("two arms are registered");
    search.observe_multi(arm, &[("sink", 1.0), ("crash", 0.5), ("coverage", 0.3)]);

    let stats = search
        .group_stat(0)
        .expect("group 0 has registered arms");
    // A single observation whose largest signal is 1.0: if observe_multi keeps
    // the maximum, the group's average reward is exactly 1.0.
    assert_eq!(stats.visits, 1);
    assert!((stats.average_reward - 1.0).abs() < f64::EPSILON);
}