use crate::aixi::common::{
Action, ObservationKeyMode, PerceptVal, Reward, observation_repr_from_stream,
};
use rayon::prelude::*;
use std::collections::HashMap;
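/// An (observation stream, reward) pair; used as the hash key under which a
/// chance node stores the decision-node child for that sampled percept.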
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
struct PerceptOutcome {
observations: Box<[PerceptVal]>,
reward: Reward,
}
impl PerceptOutcome {
fn new(observations: Vec<PerceptVal>, reward: Reward) -> Self {
Self {
observations: observations.into_boxed_slice(),
reward,
}
}
}
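/// Model-plus-RNG interface the Monte-Carlo planner rolls futures out
/// against. Implementors own a generative environment model that can be
/// speculatively updated and reverted, plus the random source used by the
/// search; `Send` allows seeded clones to run on rayon worker threads.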
pub trait AgentSimulator: Send {
fn get_num_actions(&self) -> usize;
fn get_num_observation_bits(&self) -> usize;
fn observation_stream_len(&self) -> usize {
1
}
fn observation_key_mode(&self) -> ObservationKeyMode {
ObservationKeyMode::FullStream
}
fn observation_repr_from_stream(&self, observations: &[PerceptVal]) -> Vec<PerceptVal> {
observation_repr_from_stream(
self.observation_key_mode(),
observations,
self.get_num_observation_bits(),
)
}
fn get_num_reward_bits(&self) -> usize;
fn horizon(&self) -> usize;
fn max_reward(&self) -> Reward;
fn min_reward(&self) -> Reward;
fn reward_offset(&self) -> i64 {
0
}
fn get_explore_exploit_ratio(&self) -> f64 {
1.0
}
fn discount_gamma(&self) -> f64 {
1.0
}
fn model_update_action(&mut self, action: Action);
fn gen_percept_and_update(&mut self, bits: usize) -> u64;
fn begin_simulation(&mut self) {}
fn model_revert(&mut self, steps: usize);
fn gen_range(&mut self, end: usize) -> usize;
fn gen_f64(&mut self) -> f64;
fn boxed_clone(&self) -> Box<dyn AgentSimulator> {
self.boxed_clone_with_seed(0)
}
fn boxed_clone_with_seed(&self, seed: u64) -> Box<dyn AgentSimulator>;
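    /// Normalizes a cumulative (optionally discounted) return into [0, 1]
    /// using the per-step reward bounds. For example, with rewards in
    /// [-1, 1], horizon 5, and the default undiscounted gamma, the return
    /// range is [-5, 5], so -5 maps to 0.0, 0 to 0.5, and +5 to 1.0.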
fn norm_reward(&self, reward: f64) -> f64 {
let min = self.min_reward() as f64;
let max = self.max_reward() as f64;
let h = self.horizon() as f64;
let gamma = self.discount_gamma().clamp(0.0, 1.0);
let discount_sum = if (gamma - 1.0).abs() < 1e-9 {
h
} else {
(1.0 - gamma.powi(h as i32)) / (1.0 - gamma)
};
let range = (max - min) * discount_sum;
let min_cumulative = min * discount_sum;
if range.abs() < 1e-9 {
0.5
} else {
(reward - min_cumulative) / range
}
}
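    /// Samples one full percept from the model: the observation stream first,
    /// then the reward bits, updating the model after each draw. Returns the
    /// observation key (per `observation_key_mode`) and the signed reward
    /// after subtracting `reward_offset`.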
fn gen_percepts_and_update(&mut self) -> (Vec<PerceptVal>, Reward) {
let obs_bits = self.get_num_observation_bits();
let obs_len = self.observation_stream_len().max(1);
let mut observations = Vec::with_capacity(obs_len);
for _ in 0..obs_len {
observations.push(self.gen_percept_and_update(obs_bits));
}
let obs_key = self.observation_repr_from_stream(&observations);
let rew_bits = self.get_num_reward_bits();
let rew_u = self.gen_percept_and_update(rew_bits);
let rew = (rew_u as i64) - self.reward_offset();
(obs_key, rew)
}
}
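/// A node of the Monte-Carlo search tree. Decision nodes branch on the
/// agent's actions; chance nodes branch on the (observation, reward)
/// percepts sampled from the model. `mean` tracks the running average of
/// the discounted returns sampled through this node.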
#[derive(Clone)]
pub struct SearchNode {
visits: u32,
mean: f64,
is_chance_node: bool,
action_children: Vec<Option<SearchNode>>,
percept_children: HashMap<PerceptOutcome, SearchNode>,
}
impl SearchNode {
pub fn new(is_chance_node: bool) -> Self {
Self {
visits: 0,
mean: 0.0,
is_chance_node,
action_children: Vec::new(),
percept_children: HashMap::new(),
}
}
pub fn best_action(&self, agent: &mut dyn AgentSimulator) -> Action {
let mut best_actions = Vec::new();
let mut best_mean = -f64::INFINITY;
for (action, child) in self.action_children.iter().enumerate() {
let Some(child) = child.as_ref() else {
continue;
};
let mean = child.mean;
if mean > best_mean {
best_mean = mean;
best_actions.clear();
best_actions.push(action as u64);
} else if (mean - best_mean).abs() < 1e-9 {
best_actions.push(action as u64);
}
}
if best_actions.is_empty() {
return 0;
}
let idx = agent.gen_range(best_actions.len());
best_actions[idx] as Action
}
fn expectation(&self) -> f64 {
self.mean
}
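    /// Merges the statistics a parallel worker accumulated on its private
    /// copy of the tree back into this shared one. `updated - base` is the
    /// worker's contribution, so merging several workers against the same
    /// `base` snapshot never double-counts the snapshot's own visits.
    /// Mismatched node kinds or regressed visit counts are ignored
    /// defensively.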
fn apply_delta(&mut self, base: &SearchNode, updated: &SearchNode) {
if self.is_chance_node != base.is_chance_node
|| self.is_chance_node != updated.is_chance_node
{
return;
}
let base_visits = base.visits as f64;
let updated_visits = updated.visits as f64;
if updated_visits < base_visits {
return;
}
let delta_visits = updated.visits - base.visits;
if delta_visits > 0 {
let base_sum = base.mean * base_visits;
let updated_sum = updated.mean * updated_visits;
let delta_sum = updated_sum - base_sum;
let total_visits = self.visits + delta_visits;
let total_sum = self.mean * (self.visits as f64) + delta_sum;
self.visits = total_visits;
self.mean = if total_visits > 0 {
total_sum / (total_visits as f64)
} else {
0.0
};
}
        if self.is_chance_node {
            for (key, updated_child) in &updated.percept_children {
                match (
                    base.percept_children.get(key),
                    self.percept_children.get_mut(key),
                ) {
                    // Both trees already know this percept: merge in only the
                    // visits the worker added on top of the snapshot.
                    (Some(base_child), Some(self_child)) => {
                        self_child.apply_delta(base_child, updated_child);
                    }
                    // The worker expanded a percept the snapshot lacked: its
                    // whole subtree is new relative to an empty baseline.
                    (None, Some(self_child)) => {
                        let empty = SearchNode::new(updated_child.is_chance_node);
                        self_child.apply_delta(&empty, updated_child);
                    }
                    // No local child for this percept yet: absorb the worker's
                    // subtree wholesale by diffing it against an empty node.
                    (_, None) => {
                        let mut child = SearchNode::new(updated_child.is_chance_node);
                        child.apply_delta(
                            &SearchNode::new(updated_child.is_chance_node),
                            updated_child,
                        );
                        self.percept_children.insert(key.clone(), child);
                    }
                }
            }
} else {
let max_len = base
.action_children
.len()
.max(updated.action_children.len());
if self.action_children.len() < max_len {
self.action_children.resize_with(max_len, || None);
}
for idx in 0..max_len {
let base_child = base.action_children.get(idx).and_then(|c| c.as_ref());
let updated_child = updated.action_children.get(idx).and_then(|c| c.as_ref());
let Some(updated_child) = updated_child else {
continue;
};
match (base_child, self.action_children.get_mut(idx)) {
(Some(base_child), Some(Some(self_child))) => {
self_child.apply_delta(base_child, updated_child);
}
(Some(base_child), Some(slot @ None)) => {
let mut child = SearchNode::new(updated_child.is_chance_node);
child.apply_delta(base_child, updated_child);
*slot = Some(child);
}
(None, Some(Some(self_child))) => {
let empty = SearchNode::new(updated_child.is_chance_node);
self_child.apply_delta(&empty, updated_child);
}
(None, Some(slot @ None)) => {
let mut child = SearchNode::new(updated_child.is_chance_node);
child.apply_delta(
&SearchNode::new(updated_child.is_chance_node),
updated_child,
);
*slot = Some(child);
}
_ => {}
}
}
}
}
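    /// UCB1-style selection at a decision node: unvisited actions are tried
    /// first in random order; afterwards the child maximizing mean return
    /// plus the exploration bonus wins, with exact ties broken uniformly by
    /// reservoir sampling. Also applies the chosen action to the model.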
fn select_action(&mut self, agent: &mut dyn AgentSimulator) -> (&mut SearchNode, Action) {
let num_actions = agent.get_num_actions();
if self.action_children.len() < num_actions {
self.action_children.resize_with(num_actions, || None);
}
let mut unvisited = Vec::new();
for a in 0..num_actions {
if self.action_children[a].is_none() {
unvisited.push(a as u64);
}
}
let action;
if !unvisited.is_empty() {
let idx = agent.gen_range(unvisited.len());
action = unvisited[idx];
self.action_children[action as usize] = Some(SearchNode::new(true));
} else {
            let c = agent.get_explore_exploit_ratio().max(0.0);
            // Scale the exploration bonus to the full range of returns
            // reachable within the horizon; biasing only by the positive part
            // of `max_reward` would disable exploration for environments whose
            // rewards are never positive.
            let explore_bias = (agent.horizon() as f64)
                * (((agent.max_reward() - agent.min_reward()) as f64).max(0.0));
let mut best_val = -f64::INFINITY;
let mut best_action = None;
let mut num_maximal_actions = 0usize;
let log_visits = (self.visits as f64).ln().max(0.0);
for (a, child) in self.action_children.iter().enumerate() {
let Some(child) = child.as_ref() else {
continue;
};
let nvisits = child.visits as f64;
let val = child.expectation() + explore_bias * ((c * log_visits) / nvisits).sqrt();
debug_assert!(
val.is_finite(),
"UCB score must be finite for visited MC-AIXI action children"
);
match val.total_cmp(&best_val) {
std::cmp::Ordering::Greater => {
best_val = val;
best_action = Some(a as u64);
num_maximal_actions = 1;
}
std::cmp::Ordering::Equal => {
num_maximal_actions += 1;
if agent.gen_range(num_maximal_actions) == 0 {
best_action = Some(a as u64);
}
}
std::cmp::Ordering::Less => {}
}
}
action = best_action.expect("visited MC-AIXI node must have a maximal action");
}
agent.model_update_action(action as Action);
(
self.action_children[action as usize]
.as_mut()
.expect("missing action child"),
action as Action,
)
}
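    /// Runs one rollout of up to `horizon` remaining steps from this node.
    /// Chance nodes sample a percept from the model and recurse; decision
    /// nodes either descend via `select_action` or, on their first visit,
    /// fall back to a random playout. The model is rewound by `total_horizon`
    /// steps once the rollout bottoms out, and the sampled discounted return
    /// is backed up into `mean` on the way out.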
pub fn sample(
&mut self,
agent: &mut dyn AgentSimulator,
horizon: usize,
total_horizon: usize,
) -> f64 {
if horizon == 0 {
agent.model_revert(total_horizon);
return 0.0;
}
let reward;
if self.is_chance_node {
let (obs, rew) = agent.gen_percepts_and_update();
let key = PerceptOutcome::new(obs, rew);
let child = self
.percept_children
.entry(key)
.or_insert_with(|| SearchNode::new(false));
reward = (rew as f64)
+ agent.discount_gamma() * child.sample(agent, horizon - 1, total_horizon);
} else if self.visits == 0 {
reward = Self::playout(agent, horizon, total_horizon);
} else {
let (child, _act) = self.select_action(agent);
reward = child.sample(agent, horizon, total_horizon);
}
self.mean = (reward + (self.visits as f64) * self.mean) / ((self.visits + 1) as f64);
self.visits += 1;
reward
}
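    /// Completes a rollout with uniformly random actions, accumulating the
    /// discounted reward, then rewinds the model to its pre-search state.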
fn playout(agent: &mut dyn AgentSimulator, horizon: usize, total_horizon: usize) -> f64 {
let mut total_rew = 0.0;
let num_actions = agent.get_num_actions();
let gamma = agent.discount_gamma().clamp(0.0, 1.0);
let mut discount = 1.0;
for _ in 0..horizon {
let act = agent.gen_range(num_actions);
agent.model_update_action(act as Action);
let (_key, rew) = agent.gen_percepts_and_update();
total_rew += discount * (rew as f64);
discount *= gamma;
}
agent.model_revert(total_horizon);
total_rew
}
}
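/// Search tree persisted across agent/environment cycles; between decisions
/// it is pruned down to the subtree consistent with what actually happened.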
pub struct SearchTree {
root: Option<SearchNode>,
}
impl SearchTree {
pub fn new() -> Self {
Self {
root: Some(SearchNode::new(false)),
}
}
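    /// Prunes the tree to the current history, runs `samples` rollouts, and
    /// returns the greedy root action. When rayon offers multiple threads and
    /// there is more than one sample, the work is root-parallelized: every
    /// worker searches a private clone of the tree with an independently
    /// seeded agent, and the workers' statistics are folded back into the
    /// shared root via `apply_delta` against the pre-search snapshot.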
pub fn search(
&mut self,
agent: &mut dyn AgentSimulator,
prev_obs_stream: &[PerceptVal],
prev_rew: Reward,
prev_act: u64,
samples: usize,
) -> Action {
self.prune_tree(agent, prev_obs_stream, prev_rew, prev_act);
let root = self.root.as_mut().unwrap();
let h = agent.horizon();
let threads = rayon::current_num_threads().max(1);
if samples < 2 || threads < 2 {
for _ in 0..samples {
agent.begin_simulation();
root.sample(agent, h, h);
}
return root.best_action(agent);
}
let workers = threads.min(samples);
let base = samples / workers;
let extra = samples % workers;
let snapshot = root.clone();
let mut agents = Vec::with_capacity(workers);
for i in 0..workers {
let seed = agent.gen_f64().to_bits() ^ (i as u64);
agents.push(agent.boxed_clone_with_seed(seed));
}
let results: Vec<SearchNode> = agents
.into_par_iter()
.enumerate()
.map(|(i, mut local_agent)| {
let mut local_root = snapshot.clone();
let iterations = base + usize::from(i < extra);
for _ in 0..iterations {
local_agent.begin_simulation();
local_root.sample(local_agent.as_mut(), h, h);
}
local_root
})
.collect();
for local in &results {
root.apply_delta(&snapshot, local);
}
root.best_action(agent)
}
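    /// Advances the root past the last cycle: follow the action actually
    /// taken, then the percept actually received; the matching grandchild
    /// becomes the new root, and any mismatch resets the tree to an empty
    /// decision node.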
fn prune_tree(
&mut self,
agent: &mut dyn AgentSimulator,
prev_obs_stream: &[PerceptVal],
prev_rew: Reward,
prev_act: u64,
) {
if self.root.is_none() {
self.root = Some(SearchNode::new(false));
return;
}
let mut old_root = self.root.take().unwrap();
let action_child_opt = if old_root.action_children.len() > prev_act as usize {
old_root.action_children[prev_act as usize].take()
} else {
None
};
        if let Some(mut chance_child) = action_child_opt {
            let obs_repr = agent.observation_repr_from_stream(prev_obs_stream);
            let key = PerceptOutcome::new(obs_repr, prev_rew);
            // The surviving subtree is the decision node reached via the
            // percept that was actually observed after `prev_act`.
            if let Some(decision_child) = chance_child.percept_children.remove(&key) {
                self.root = Some(decision_child);
            } else {
                self.root = Some(SearchNode::new(false));
            }
        } else {
            self.root = Some(SearchNode::new(false));
        }
}
}
impl Default for SearchTree {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::aixi::common::ObservationKeyMode;
use std::sync::{
Arc,
atomic::{AtomicUsize, Ordering},
};
#[derive(Clone)]
struct DummyAgent {
obs_bits: usize,
rew_bits: usize,
horizon: usize,
min_reward: Reward,
max_reward: Reward,
key_mode: ObservationKeyMode,
}
impl DummyAgent {
fn new(obs_bits: usize, key_mode: ObservationKeyMode) -> Self {
Self {
obs_bits,
rew_bits: 8,
horizon: 5,
min_reward: -1,
max_reward: 1,
key_mode,
}
}
}
impl AgentSimulator for DummyAgent {
fn get_num_actions(&self) -> usize {
4
}
fn get_num_observation_bits(&self) -> usize {
self.obs_bits
}
fn observation_key_mode(&self) -> ObservationKeyMode {
self.key_mode
}
fn get_num_reward_bits(&self) -> usize {
self.rew_bits
}
fn horizon(&self) -> usize {
self.horizon
}
fn max_reward(&self) -> Reward {
self.max_reward
}
fn min_reward(&self) -> Reward {
self.min_reward
}
fn model_update_action(&mut self, _action: Action) {}
fn gen_percept_and_update(&mut self, _bits: usize) -> u64 {
0
}
fn model_revert(&mut self, _steps: usize) {}
fn gen_range(&mut self, _end: usize) -> usize {
0
}
fn gen_f64(&mut self) -> f64 {
0.0
}
fn boxed_clone_with_seed(&self, _seed: u64) -> Box<dyn AgentSimulator> {
Box::new(self.clone())
}
}
fn build_tree_with_key(
agent: &DummyAgent,
prev_act: u64,
prev_obs_stream: &[PerceptVal],
prev_rew: Reward,
kept_mean: f64,
kept_visits: u32,
) -> SearchTree {
let mut old_root = SearchNode::new(false);
old_root.action_children.resize(prev_act as usize + 1, None);
let mut chance_child = SearchNode::new(true);
let mut kept = SearchNode::new(false);
kept.mean = kept_mean;
kept.visits = kept_visits;
let obs_repr = agent.observation_repr_from_stream(prev_obs_stream);
let key = PerceptOutcome::new(obs_repr, prev_rew);
chance_child.percept_children.insert(key, kept);
old_root.action_children[prev_act as usize] = Some(chance_child);
SearchTree {
root: Some(old_root),
}
}
#[test]
fn prune_tree_keeps_matching_subtree() {
let prev_act = 2u64;
let prev_obs_stream = vec![9u64, 2u64, 7u64];
let prev_rew: Reward = 3;
let mut agent = DummyAgent::new(3, ObservationKeyMode::FullStream);
let mut tree = build_tree_with_key(&agent, prev_act, &prev_obs_stream, prev_rew, 123.0, 7);
tree.prune_tree(&mut agent, &prev_obs_stream, prev_rew, prev_act);
let root = tree.root.as_ref().expect("root should exist");
assert!(!root.is_chance_node);
assert_eq!(root.mean, 123.0);
assert_eq!(root.visits, 7);
}
#[test]
fn prune_tree_resets_when_action_missing() {
let prev_act = 10u64;
let prev_obs_stream = vec![1u64];
let prev_rew: Reward = 0;
let mut agent = DummyAgent::new(1, ObservationKeyMode::FullStream);
let mut tree = SearchTree::new();
tree.prune_tree(&mut agent, &prev_obs_stream, prev_rew, prev_act);
let root = tree.root.as_ref().unwrap();
assert!(!root.is_chance_node);
assert_eq!(root.visits, 0);
assert_eq!(root.mean, 0.0);
}
#[test]
fn prune_tree_resets_when_percept_key_missing() {
let prev_act = 0u64;
let prev_obs_stream = vec![1u64, 2u64];
let prev_rew: Reward = 1;
let mut agent = DummyAgent::new(4, ObservationKeyMode::Last);
let mut tree = build_tree_with_key(&agent, prev_act, &[9u64], prev_rew, 9.0, 2);
tree.prune_tree(&mut agent, &prev_obs_stream, prev_rew, prev_act);
let root = tree.root.as_ref().unwrap();
assert!(!root.is_chance_node);
assert_eq!(root.visits, 0);
assert_eq!(root.mean, 0.0);
}
#[test]
fn prune_tree_resets_when_reward_mismatch_shares_observation_key() {
let prev_act = 1u64;
let prev_obs_stream = vec![4u64, 5u64];
let kept_rew: Reward = -2;
let requested_rew: Reward = 2;
let mut agent = DummyAgent::new(6, ObservationKeyMode::FullStream);
let mut tree = build_tree_with_key(&agent, prev_act, &prev_obs_stream, kept_rew, 77.0, 11);
tree.prune_tree(&mut agent, &prev_obs_stream, requested_rew, prev_act);
let root = tree.root.as_ref().unwrap();
assert!(!root.is_chance_node);
assert_eq!(root.visits, 0);
assert_eq!(root.mean, 0.0);
}
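    // Sanity checks for `norm_reward`, derived from its formula: with rewards
    // in [-1, 1], horizon 5, and the default undiscounted gamma, cumulative
    // returns span [-5, 5].
    #[test]
    fn norm_reward_maps_cumulative_extremes_to_unit_interval() {
        let agent = DummyAgent::new(1, ObservationKeyMode::FullStream);
        assert!(agent.norm_reward(-5.0).abs() < 1e-9);
        assert!((agent.norm_reward(5.0) - 1.0).abs() < 1e-9);
        assert!((agent.norm_reward(0.0) - 0.5).abs() < 1e-9);
    }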
#[derive(Clone)]
struct BeginCountingAgent {
begins: Arc<AtomicUsize>,
}
impl AgentSimulator for BeginCountingAgent {
fn get_num_actions(&self) -> usize {
2
}
fn get_num_observation_bits(&self) -> usize {
1
}
fn get_num_reward_bits(&self) -> usize {
1
}
fn horizon(&self) -> usize {
1
}
fn max_reward(&self) -> Reward {
1
}
fn min_reward(&self) -> Reward {
0
}
fn model_update_action(&mut self, _action: Action) {}
fn gen_percept_and_update(&mut self, _bits: usize) -> u64 {
0
}
fn begin_simulation(&mut self) {
self.begins.fetch_add(1, Ordering::Relaxed);
}
fn model_revert(&mut self, _steps: usize) {}
fn gen_range(&mut self, _end: usize) -> usize {
0
}
fn gen_f64(&mut self) -> f64 {
0.0
}
fn boxed_clone_with_seed(&self, _seed: u64) -> Box<dyn AgentSimulator> {
Box::new(self.clone())
}
}
#[test]
fn search_calls_begin_simulation_for_each_rollout() {
let begins = Arc::new(AtomicUsize::new(0));
let mut agent = BeginCountingAgent {
begins: begins.clone(),
};
let mut tree = SearchTree::new();
let _ = tree.search(&mut agent, &[0], 0, 0, 5);
assert_eq!(begins.load(Ordering::Relaxed), 5);
}
#[derive(Clone)]
struct TieBreakAgent {
next_range: Arc<AtomicUsize>,
}
impl AgentSimulator for TieBreakAgent {
fn get_num_actions(&self) -> usize {
4
}
fn get_num_observation_bits(&self) -> usize {
1
}
fn get_num_reward_bits(&self) -> usize {
1
}
fn horizon(&self) -> usize {
1
}
fn max_reward(&self) -> Reward {
1
}
fn min_reward(&self) -> Reward {
0
}
fn get_explore_exploit_ratio(&self) -> f64 {
0.0
}
fn model_update_action(&mut self, _action: Action) {}
fn gen_percept_and_update(&mut self, _bits: usize) -> u64 {
0
}
fn model_revert(&mut self, _steps: usize) {}
fn gen_range(&mut self, end: usize) -> usize {
self.next_range
.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |value| {
Some(value.saturating_sub(1))
})
.expect("range source should be initialized")
% end
}
fn gen_f64(&mut self) -> f64 {
0.0
}
fn boxed_clone_with_seed(&self, _seed: u64) -> Box<dyn AgentSimulator> {
Box::new(self.clone())
}
}
#[test]
fn select_action_uses_uniform_tie_break_for_maximal_ucb_actions() {
let mut node = SearchNode::new(false);
node.visits = 16;
node.action_children = vec![
Some(SearchNode {
visits: 5,
mean: 0.1,
is_chance_node: true,
action_children: Vec::new(),
percept_children: HashMap::new(),
}),
Some(SearchNode {
visits: 5,
mean: 0.9,
is_chance_node: true,
action_children: Vec::new(),
percept_children: HashMap::new(),
}),
Some(SearchNode {
visits: 5,
mean: 0.2,
is_chance_node: true,
action_children: Vec::new(),
percept_children: HashMap::new(),
}),
Some(SearchNode {
visits: 5,
mean: 0.9,
is_chance_node: true,
action_children: Vec::new(),
percept_children: HashMap::new(),
}),
];
let mut agent = TieBreakAgent {
next_range: Arc::new(AtomicUsize::new(0)),
};
let (_child, action) = node.select_action(&mut agent);
        assert_eq!(
            action, 3,
            "exact UCB ties must be broken uniformly; the scripted RNG should pick the later tied action"
        );
}
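    // Exercises the delta-merge arithmetic in isolation: starting from a
    // shared snapshot, only the visits a worker added on top of it should
    // fold back in.
    #[test]
    fn apply_delta_folds_in_only_new_visits() {
        let mut base = SearchNode::new(false);
        base.visits = 2;
        base.mean = 1.0; // cumulative sum 2.0
        let mut updated = base.clone();
        updated.visits = 4;
        updated.mean = 2.0; // cumulative sum 8.0, i.e. the worker added 6.0
        let mut merged = base.clone();
        merged.apply_delta(&base, &updated);
        assert_eq!(merged.visits, 4);
        assert!((merged.mean - 2.0).abs() < 1e-9);
    }
    // `best_action` is the greedy final decision: it ignores the exploration
    // bonus entirely, so the RNG is only consulted for exact mean ties.
    #[test]
    fn best_action_picks_child_with_highest_mean() {
        let mut node = SearchNode::new(false);
        node.action_children = vec![
            Some(SearchNode {
                visits: 3,
                mean: 0.1,
                is_chance_node: true,
                action_children: Vec::new(),
                percept_children: HashMap::new(),
            }),
            Some(SearchNode {
                visits: 3,
                mean: 0.9,
                is_chance_node: true,
                action_children: Vec::new(),
                percept_children: HashMap::new(),
            }),
        ];
        let mut agent = DummyAgent::new(1, ObservationKeyMode::FullStream);
        assert_eq!(node.best_action(&mut agent), 1);
    }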
}