use crate::RateBackend;
use crate::aixi::common::{
Action, ObservationKeyMode, PerceptVal, RandomGenerator, Reward, decode, encode,
observation_repr_from_stream,
};
use crate::aixi::mcts::{AgentSimulator, SearchTree};
#[cfg(feature = "backend-mamba")]
use crate::aixi::model::MambaPredictor;
#[cfg(feature = "backend-rwkv")]
use crate::aixi::model::RwkvPredictor;
use crate::aixi::model::{
CtwPredictor, FacCtwPredictor, Predictor, RateBackendBitPredictor, RosaPredictor, ZpaqPredictor,
};
use crate::aixi::rate_backend::{adapt_rate_backend_for_bit_tokens, rate_backend_contains_zpaq};
#[cfg(feature = "backend-mamba")]
use crate::load_mamba_model_from_path;
#[cfg(feature = "backend-rwkv")]
use crate::load_rwkv7_model_from_path;
use crate::{validate_rate_backend, validate_zpaq_rate_method};
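/// Configuration for an MC-AIXI agent: predictor selection, percept and
/// action bit widths, reward range, and Monte-Carlo planner parameters.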
#[derive(Clone)]
pub struct AgentConfig {
pub algorithm: String,
pub ct_depth: usize,
pub agent_horizon: usize,
pub observation_bits: usize,
pub observation_stream_len: usize,
pub observation_key_mode: ObservationKeyMode,
pub reward_bits: usize,
pub agent_actions: usize,
pub num_simulations: usize,
pub exploration_exploitation_ratio: f64,
pub discount_gamma: f64,
pub min_reward: Reward,
pub max_reward: Reward,
pub reward_offset: Reward,
pub random_seed: Option<u64>,
pub rate_backend: Option<RateBackend>,
pub rate_backend_max_order: i64,
pub rwkv_model_path: Option<String>,
pub rwkv_method: Option<String>,
pub mamba_model_path: Option<String>,
pub mamba_method: Option<String>,
pub rosa_max_order: Option<i64>,
pub zpaq_method: Option<String>,
}
impl AgentConfig {
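/// Checks the configuration for internal consistency: planner parameters,
/// the reward range against `reward_bits` and `reward_offset`, and the
/// availability of the selected predictor backend.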
pub fn validate(&self) -> Result<(), String> {
if self.agent_actions == 0 {
return Err("agent_actions must be >= 1".to_string());
}
if self.agent_horizon == 0 {
return Err("agent_horizon must be >= 1".to_string());
}
if self.num_simulations == 0 {
return Err("num_simulations must be >= 1".to_string());
}
if self.exploration_exploitation_ratio <= 0.0 {
return Err("exploration_exploitation_ratio must be > 0".to_string());
}
if !(0.0..=1.0).contains(&self.discount_gamma) {
return Err(format!(
"discount_gamma must be in [0, 1] for MC-AIXI, got {}",
self.discount_gamma
));
}
if self.max_reward < self.min_reward {
return Err(format!(
"max_reward must be >= min_reward (got {} < {})",
self.max_reward, self.min_reward
));
}
let min_shifted = (self.min_reward as i128) + (self.reward_offset as i128);
let max_shifted = (self.max_reward as i128) + (self.reward_offset as i128);
if min_shifted < 0 {
return Err(format!(
"reward_offset too small: min_reward + reward_offset must be >= 0 (got {})",
min_shifted
));
}
if self.reward_bits < 64 {
let max_enc = (1u128 << self.reward_bits) - 1;
if (max_shifted as u128) > max_enc {
return Err(format!(
"reward_bits too small for configured reward range: max shifted reward {} exceeds {}",
max_shifted, max_enc
));
}
}
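// A configured rate_backend overrides `algorithm`, so the algorithm-specific
// checks below are skipped when one is present.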
if let Some(rate_backend) = &self.rate_backend {
validate_rate_backend(rate_backend)
.map_err(|err| format!("invalid rate_backend: {err}"))?;
if rate_backend_contains_zpaq(rate_backend) {
return Err(
"MC-AIXI strict generic rate_backend support requires reversible action conditioning; configured rate_backend contains zpaq which does not provide the reversible action conditioning required by \"A Monte-Carlo AIXI Approximation\""
.to_string(),
);
}
return Ok(());
}
match self.algorithm.as_str() {
"ctw" | "fac-ctw" | "ac-ctw" | "ctw-context-tree" | "rosa" => {}
#[cfg(feature = "backend-rwkv")]
"rwkv" => {
let has_method = self
.rwkv_method
.as_deref()
.map(str::trim)
.is_some_and(|v| !v.is_empty());
let has_path = self
.rwkv_model_path
.as_deref()
.map(str::trim)
.is_some_and(|v| !v.is_empty());
if !(has_method || has_path) {
return Err(
"algorithm=rwkv requires rwkv_model_path or rwkv_method when no rate_backend override is configured"
.to_string(),
);
}
}
#[cfg(not(feature = "backend-rwkv"))]
"rwkv" => return Err("algorithm=rwkv requires backend-rwkv feature".to_string()),
#[cfg(feature = "backend-mamba")]
"mamba" => {
let has_method = self
.mamba_method
.as_deref()
.map(str::trim)
.is_some_and(|v| !v.is_empty());
let has_path = self
.mamba_model_path
.as_deref()
.map(str::trim)
.is_some_and(|v| !v.is_empty());
if !(has_method || has_path) {
return Err(
"algorithm=mamba requires mamba_model_path or mamba_method when no rate_backend override is configured"
.to_string(),
);
}
}
#[cfg(not(feature = "backend-mamba"))]
"mamba" => return Err("algorithm=mamba requires backend-mamba feature".to_string()),
"zpaq" => {
let method = self.zpaq_method.as_deref().unwrap_or("1");
if let Err(err) = validate_zpaq_rate_method(method) {
return Err(format!("Invalid zpaq method for AIXI: {err}"));
}
}
other => return Err(format!("Unknown algorithm: {other}")),
}
Ok(())
}
}
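/// An MC-AIXI agent: a bit-level predictive model of the environment plus a
/// Monte-Carlo search tree for planning. `obs_buffer` and `sym_buffer` are
/// reusable scratch space for percept and symbol encoding.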
pub struct Agent {
model: Box<dyn Predictor>,
planner: Option<SearchTree>,
config: AgentConfig,
age: u64,
total_reward: f64,
action_bits: usize,
rng: RandomGenerator,
obs_buffer: Vec<u64>,
sym_buffer: Vec<bool>,
}
impl Agent {
pub fn new(config: AgentConfig) -> Self {
Self::try_new(config).unwrap_or_else(|err| panic!("Invalid MC-AIXI config: {err}"))
}
pub fn try_new(config: AgentConfig) -> Result<Self, String> {
config.validate()?;
// Bits needed to encode an action index: ceil(log2(agent_actions)), with a
// minimum of one bit so the single-action case still emits a symbol.
let action_bits = if config.agent_actions <= 1 {
1
} else {
(usize::BITS - (config.agent_actions - 1).leading_zeros()) as usize
};
let model = build_model(&config)?;
let rng = if let Some(seed) = config.random_seed {
RandomGenerator::from_seed(seed)
} else {
RandomGenerator::new()
};
Ok(Self {
model,
planner: Some(SearchTree::new()),
config,
age: 0,
total_reward: 0.0,
action_bits,
rng,
obs_buffer: Vec::with_capacity(128),
sym_buffer: Vec::with_capacity(64),
})
}
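/// Clones the agent for a simulation rollout: the model is deep-cloned, the
/// RNG is forked from `seed`, and no planner is attached.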
fn clone_for_simulation(&self, seed: u64) -> Self {
Self {
model: self.model.boxed_clone(),
planner: None,
config: self.config.clone(),
age: self.age,
total_reward: self.total_reward,
action_bits: self.action_bits,
rng: self.rng.fork_with(seed),
obs_buffer: Vec::with_capacity(128),
sym_buffer: Vec::with_capacity(64),
}
}
pub fn reset(&mut self) {
self.age = 0;
self.total_reward = 0.0;
}
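/// Plans the next action by running `num_simulations` Monte-Carlo searches
/// from the current history. The planner is temporarily taken out of `self`
/// so the search can borrow the agent mutably as its simulator.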
pub fn get_planned_action(
&mut self,
prev_obs_stream: &[PerceptVal],
prev_rew: Reward,
prev_act: Action,
) -> Action {
let mut planner = self.planner.take().expect("Planner missing");
let num_sim = self.config.num_simulations;
let action = planner.search(self, prev_obs_stream, prev_rew, prev_act, num_sim);
self.planner = Some(planner);
action
}
pub fn model_update_percept(&mut self, observation: PerceptVal, reward: Reward) {
self.model_update_percept_stream(&[observation], reward);
}
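/// Commits a real percept to the model: each observation is encoded with
/// `observation_bits` bits, followed by the offset reward in `reward_bits`
/// bits, mirroring the bit layout sampled in `gen_percepts_and_update`.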
pub fn model_update_percept_stream(&mut self, observations: &[PerceptVal], reward: Reward) {
debug_assert!(
!observations.is_empty() || self.config.observation_bits == 0,
"percept update missing observation stream"
);
let mut percept_syms = Vec::new();
for &obs in observations {
encode(&mut percept_syms, obs, self.config.observation_bits);
}
crate::aixi::common::encode_reward_offset(
&mut percept_syms,
reward,
self.config.reward_bits,
self.config.reward_offset,
);
for &sym in &percept_syms {
self.model.commit_update(sym);
}
self.total_reward += reward as f64;
}
pub fn observation_repr_from_stream(&self, observations: &[PerceptVal]) -> Vec<PerceptVal> {
observation_repr_from_stream(
self.config.observation_key_mode,
observations,
self.config.observation_bits,
)
}
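/// Commits an action taken in the real environment to the model history via
/// the committed path, in contrast to `model_update_action`, which records
/// simulated actions revertibly.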
pub fn model_update_action_external(&mut self, action: Action) {
self.sym_buffer.clear();
encode(&mut self.sym_buffer, action, self.action_bits);
for &sym in &self.sym_buffer {
self.model.commit_update_history(sym);
}
}
}
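/// Builds the predictor selected by `config`. A configured `rate_backend`
/// takes precedence over `algorithm`; otherwise the algorithm string picks
/// one of the CTW, ROSA, RWKV, Mamba, or zpaq predictors, subject to the
/// features compiled in.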
fn build_model(config: &AgentConfig) -> Result<Box<dyn Predictor>, String> {
if let Some(rate_backend) = config.rate_backend.clone() {
let bit_backend = adapt_rate_backend_for_bit_tokens(rate_backend);
let predictor = RateBackendBitPredictor::new(bit_backend, config.rate_backend_max_order)?;
return Ok(Box::new(predictor));
}
match config.algorithm.as_str() {
"ctw" | "fac-ctw" => {
let obs_len = config.observation_stream_len.max(1);
let percept_bits = (config.observation_bits * obs_len) + config.reward_bits;
Ok(Box::new(FacCtwPredictor::new(
config.ct_depth,
percept_bits,
)))
}
"ac-ctw" | "ctw-context-tree" => Ok(Box::new(CtwPredictor::new(config.ct_depth))),
"rosa" => {
let max_order = config.rosa_max_order.unwrap_or(20);
Ok(Box::new(RosaPredictor::new(max_order)))
}
#[cfg(feature = "backend-rwkv")]
"rwkv" => {
if let Some(method) = config
.rwkv_method
.as_deref()
.map(str::trim)
.filter(|v| !v.is_empty())
{
let predictor = RwkvPredictor::from_method(method)
.map_err(|err| format!("Invalid RWKV method for AIXI: {err}"))?;
Ok(Box::new(predictor))
} else {
let path = config.rwkv_model_path.as_ref().ok_or_else(|| {
"RWKV model path required when rwkv_method is not configured".to_string()
})?;
let model_arc = load_rwkv7_model_from_path(path);
Ok(Box::new(RwkvPredictor::new(model_arc)))
}
}
#[cfg(not(feature = "backend-rwkv"))]
"rwkv" => Err("RWKV backend disabled at compile time".to_string()),
#[cfg(feature = "backend-mamba")]
"mamba" => {
if let Some(method) = config
.mamba_method
.as_deref()
.map(str::trim)
.filter(|v| !v.is_empty())
{
let predictor = MambaPredictor::from_method(method)
.map_err(|err| format!("Invalid Mamba method for AIXI: {err}"))?;
Ok(Box::new(predictor))
} else {
let path = config.mamba_model_path.as_ref().ok_or_else(|| {
"Mamba model path required when mamba_method is not configured".to_string()
})?;
let model_arc = load_mamba_model_from_path(path);
Ok(Box::new(MambaPredictor::new(model_arc)))
}
}
#[cfg(not(feature = "backend-mamba"))]
"mamba" => Err("Mamba backend disabled at compile time".to_string()),
"zpaq" => {
let method = config
.zpaq_method
.clone()
.unwrap_or_else(|| "1".to_string());
if let Err(err) = validate_zpaq_rate_method(&method) {
return Err(format!("Invalid zpaq method for AIXI: {err}"));
}
Ok(Box::new(ZpaqPredictor::new(method, 2f64.powi(-24))))
}
_ => Err(format!("Unknown algorithm: {}", config.algorithm)),
}
}
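/// The simulator interface consumed by the Monte-Carlo search tree: exposes
/// the planning parameters and drives the model through revertible rollouts.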
impl AgentSimulator for Agent {
fn get_num_actions(&self) -> usize {
self.config.agent_actions
}
fn get_num_observation_bits(&self) -> usize {
self.config.observation_bits
}
fn observation_stream_len(&self) -> usize {
self.config.observation_stream_len.max(1)
}
fn observation_key_mode(&self) -> ObservationKeyMode {
self.config.observation_key_mode
}
fn get_num_reward_bits(&self) -> usize {
self.config.reward_bits
}
fn horizon(&self) -> usize {
self.config.agent_horizon
}
fn max_reward(&self) -> Reward {
self.config.max_reward
}
fn min_reward(&self) -> Reward {
self.config.min_reward
}
fn reward_offset(&self) -> i64 {
self.config.reward_offset
}
fn get_explore_exploit_ratio(&self) -> f64 {
self.config.exploration_exploitation_ratio
}
fn discount_gamma(&self) -> f64 {
self.config.discount_gamma
}
fn model_update_action(&mut self, action: Action) {
self.sym_buffer.clear();
encode(&mut self.sym_buffer, action, self.action_bits);
for &sym in &self.sym_buffer {
self.model.update_history(sym);
}
}
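/// Samples `bits` symbols one at a time from the model's predictive
/// distribution, feeding each sampled bit back into the model, and decodes
/// the bit string into an unsigned value.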
fn gen_percept_and_update(&mut self, bits: usize) -> u64 {
self.sym_buffer.clear();
for _ in 0..bits {
let prob_1 = self.model.predict_one();
let sym = self.rng.gen_bool(prob_1);
self.model.update(sym);
self.sym_buffer.push(sym);
}
decode(&self.sym_buffer, bits)
}
fn begin_simulation(&mut self) {
self.model.begin_rollback_scope();
}
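/// Samples a full percept during a rollout: the observation stream, its
/// key-mode representation, and the reward shifted back by `reward_offset`.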
fn gen_percepts_and_update(&mut self) -> (Vec<PerceptVal>, Reward) {
let obs_bits = self.config.observation_bits;
let obs_len = self.config.observation_stream_len.max(1);
self.obs_buffer.clear();
for _ in 0..obs_len {
let p = self.gen_percept_and_update(obs_bits);
self.obs_buffer.push(p);
}
let obs_repr = observation_repr_from_stream(
self.config.observation_key_mode,
&self.obs_buffer,
obs_bits,
);
let rew_bits = self.config.reward_bits;
let rew_u = self.gen_percept_and_update(rew_bits);
let rew = (rew_u as i64) - self.config.reward_offset;
(obs_repr, rew)
}
fn gen_range(&mut self, end: usize) -> usize {
self.rng.gen_range(end)
}
fn gen_f64(&mut self) -> f64 {
self.rng.gen_f64()
}
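/// Undoes `steps` simulated interaction cycles. Predictors with scoped
/// rollback restore their pre-simulation state in a single call; otherwise
/// every percept bit is reverted and every action bit popped per step.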
fn model_revert(&mut self, steps: usize) {
if self.model.rollback_scope() {
return;
}
let obs_bits = self.config.observation_bits * self.config.observation_stream_len.max(1);
let percept_bits = obs_bits + self.config.reward_bits;
for _ in 0..steps {
for _ in 0..percept_bits {
self.model.revert();
}
for _ in 0..self.action_bits {
self.model.pop_history();
}
}
}
fn boxed_clone_with_seed(&self, seed: u64) -> Box<dyn AgentSimulator> {
Box::new(self.clone_for_simulation(seed))
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::sync::{Arc, Mutex};
#[derive(Clone, Default)]
struct CallCounts {
update: usize,
commit_update: usize,
update_history: usize,
commit_update_history: usize,
begin_scope: usize,
rollback_scope: usize,
revert: usize,
pop_history: usize,
}
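/// Predictor stub that only counts method calls; `rollback_scope` always
/// reports success so the scope-based rollback path can be asserted.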
#[derive(Clone)]
struct InstrumentedPredictor {
counts: Arc<Mutex<CallCounts>>,
}
impl InstrumentedPredictor {
fn new(counts: Arc<Mutex<CallCounts>>) -> Self {
Self { counts }
}
}
impl Predictor for InstrumentedPredictor {
fn update(&mut self, _sym: bool) {
self.counts.lock().unwrap().update += 1;
}
fn commit_update(&mut self, _sym: bool) {
self.counts.lock().unwrap().commit_update += 1;
}
fn update_history(&mut self, _sym: bool) {
self.counts.lock().unwrap().update_history += 1;
}
fn commit_update_history(&mut self, _sym: bool) {
self.counts.lock().unwrap().commit_update_history += 1;
}
fn revert(&mut self) {
self.counts.lock().unwrap().revert += 1;
}
fn pop_history(&mut self) {
self.counts.lock().unwrap().pop_history += 1;
}
fn begin_rollback_scope(&mut self) {
self.counts.lock().unwrap().begin_scope += 1;
}
fn rollback_scope(&mut self) -> bool {
self.counts.lock().unwrap().rollback_scope += 1;
true
}
fn predict_prob(&mut self, sym: bool) -> f64 {
if sym { 0.75 } else { 0.25 }
}
fn model_name(&self) -> String {
"InstrumentedPredictor".to_string()
}
fn boxed_clone(&self) -> Box<dyn Predictor> {
Box::new(self.clone())
}
}
fn basic_config() -> AgentConfig {
AgentConfig {
algorithm: "ac-ctw".to_string(),
ct_depth: 8,
agent_horizon: 2,
observation_bits: 2,
observation_stream_len: 2,
observation_key_mode: ObservationKeyMode::FullStream,
reward_bits: 3,
agent_actions: 4,
num_simulations: 2,
exploration_exploitation_ratio: 1.0,
discount_gamma: 0.95,
min_reward: -2,
max_reward: 3,
reward_offset: 2,
random_seed: Some(7),
rate_backend: None,
rate_backend_max_order: 8,
rwkv_model_path: None,
rwkv_method: None,
mamba_model_path: None,
mamba_method: None,
rosa_max_order: None,
zpaq_method: None,
}
}
#[test]
fn external_history_updates_use_committed_predictor_paths() {
let mut agent = Agent::try_new(basic_config()).expect("valid agent config");
let counts = Arc::new(Mutex::new(CallCounts::default()));
agent.model = Box::new(InstrumentedPredictor::new(counts.clone()));
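// Two observations at 2 bits each plus one 3-bit reward give 7 committed
// percept bits; action 3 of 4 possible actions encodes as 2 history bits.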
agent.model_update_percept_stream(&[1, 2], 1);
agent.model_update_action_external(3);
let snapshot = counts.lock().unwrap().clone();
assert_eq!(snapshot.commit_update, 7);
assert_eq!(snapshot.commit_update_history, 2);
assert_eq!(snapshot.update, 0);
assert_eq!(snapshot.update_history, 0);
}
#[test]
fn simulation_revert_prefers_predictor_scope_when_available() {
let mut agent = Agent::try_new(basic_config()).expect("valid agent config");
let counts = Arc::new(Mutex::new(CallCounts::default()));
agent.model = Box::new(InstrumentedPredictor::new(counts.clone()));
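// The stub predictor reports scope support, so model_revert must take the
// single rollback_scope call and skip the per-bit revert/pop fallback.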
AgentSimulator::begin_simulation(&mut agent);
agent.model_revert(3);
let snapshot = counts.lock().unwrap().clone();
assert_eq!(snapshot.begin_scope, 1);
assert_eq!(snapshot.rollback_scope, 1);
assert_eq!(snapshot.revert, 0);
assert_eq!(snapshot.pop_history, 0);
}
}