use constants::{TICKS_PER_ROUND, TOTAL_WOOD};
use domain::LearnDomain;
use estimator::NNStateValueEstimator;
use npc_engine_core::{
graphviz, ActiveTask, ActiveTasks, AgentId, IdleTask, MCTSConfiguration, StateValueEstimator,
MCTS,
};
use npc_engine_utils::{run_simple_executor, ExecutorState, ExecutorStateLocal};
use rand::{thread_rng, Rng};
use state::State;
mod behavior;
mod constants;
mod domain;
mod estimator;
mod state;
mod task;
#[derive(Default)]
struct LearnExecutorState {
estimator: NNStateValueEstimator,
planned_values: Vec<([f32; 5], f32)>,
}
impl LearnExecutorState {
pub fn wood_collected(&self) -> f32 {
TOTAL_WOOD as f32 - self.planned_values.last().unwrap().0.iter().sum::<f32>()
}
pub fn train_and_clear_data(&mut self) {
self.estimator.0.train(self.planned_values.iter(), 0.001);
self.planned_values.clear();
}
}
impl ExecutorStateLocal<LearnDomain> for LearnExecutorState {
fn create_initial_state(&self) -> State {
let mut rng = thread_rng();
let mut map = [0; 14];
for _tree in 0..TOTAL_WOOD {
let mut pos = rng.gen_range(0..14);
while map[pos] >= 3 {
pos = rng.gen_range(0..14);
}
map[pos] += 1;
}
State {
map,
wood_count: 0,
agent_pos: rng.gen_range(0..14),
}
}
fn init_task_queue(&self, _: &State) -> ActiveTasks<LearnDomain> {
vec![ActiveTask::new_with_end(0, AgentId(0), Box::new(IdleTask))]
.into_iter()
.collect()
}
fn keep_agent(&self, tick: u64, _state: &State, _agent: AgentId) -> bool {
tick < TICKS_PER_ROUND
}
}
impl ExecutorState<LearnDomain> for LearnExecutorState {
fn create_state_value_estimator(&self) -> Box<dyn StateValueEstimator<LearnDomain> + Send> {
Box::new(self.estimator.clone())
}
fn post_mcts_run_hook(
&mut self,
mcts: &MCTS<LearnDomain>,
_last_active_task: &ActiveTask<LearnDomain>,
) {
self.planned_values.push((
mcts.initial_state().local_view(),
mcts.q_value_at_root(AgentId(0)),
));
}
}
#[allow(dead_code)]
fn enable_map_display() {
use std::io::Write;
env_logger::builder()
.format(|buf, record| writeln!(buf, "{}", record.args()))
.filter(None, log::LevelFilter::Info)
.init();
}
fn main() {
const CONFIG: MCTSConfiguration = MCTSConfiguration {
allow_invalid_tasks: true,
visits: 20,
depth: TICKS_PER_ROUND as u32,
exploration: 1.414,
discount_hl: TICKS_PER_ROUND as f32 / 3.,
seed: None,
planning_task_duration: None,
};
graphviz::set_graph_output_depth(4);
let mut executor_state = LearnExecutorState::default();
for _epoch in 0..600 {
run_simple_executor::<LearnDomain, LearnExecutorState>(&CONFIG, &mut executor_state);
let wood_collected = executor_state.wood_collected();
println!("{wood_collected}");
executor_state.train_and_clear_data();
}
}