reinforcex 0.0.1

Deep Reinforcement Learning Framework
use super::base_agent::BaseAgent;
use crate::explorers::BaseExplorer;
use crate::misc::batch_states::batch_states;
use crate::models::BaseQFunction;
use crate::replay_buffer::ReplayBuffer;
use tch::{nn, no_grad, Device, Tensor};

pub struct DQN {
    model: Box<dyn BaseQFunction>,
    optimizer: nn::Optimizer,
    replay_buffer: ReplayBuffer,
    explorer: Box<dyn BaseExplorer>,
    action_size: usize,
    batch_size: usize,
    update_interval: usize,
    target_model: Box<dyn BaseQFunction>,
    target_update_interval: usize,
    gamma: f64,
    n_steps: usize,
    t: usize,
}

impl DQN {
    pub fn new(
        model: Box<dyn BaseQFunction>,
        optimizer: nn::Optimizer,
        action_size: usize,
        batch_size: usize,
        replay_buffer_capacity: usize,
        update_interval: usize,
        target_update_interval: usize,
        explorer: Box<dyn BaseExplorer>,
        gamma: f64,
        n_steps: usize,
    ) -> Self {
        let target_model = model.clone();
        DQN {
            model,
            optimizer,
            replay_buffer: ReplayBuffer::new(replay_buffer_capacity, n_steps),
            explorer,
            action_size,
            batch_size,
            update_interval,
            target_model,
            target_update_interval,
            gamma,
            n_steps,
            t: 0,
        }
    }

    fn _update(&mut self) {
        if self.replay_buffer.len() < self.batch_size {
            return;
        }
        let experiences = self.replay_buffer.sample(self.batch_size);
        let mut states: Vec<Tensor> = vec![];
        let mut n_step_after_states: Vec<Tensor> = vec![];
        let mut actions: Vec<Tensor> = vec![];
        let mut n_step_discounted_rewards: Vec<f64> = vec![];
        for experience in experiences {
            let state = experience.state.shallow_clone();
            let n_step_after_state = experience
                .n_step_after_experience
                .borrow()
                .as_ref()
                .unwrap()
                .state
                .shallow_clone();
            let action = experience.action.as_ref().unwrap().shallow_clone();
            let n_step_discounted_reward = experience
                .n_step_discounted_reward
                .borrow()
                .unwrap_or(experience.reward_for_this_state);
            states.push(state);
            n_step_after_states.push(n_step_after_state);
            actions.push(action);
            n_step_discounted_rewards.push(n_step_discounted_reward);
        }
        let q_values = self._compute_q_values(&n_step_after_states, n_step_discounted_rewards);
        let pred_q_values = self._compute_pred_q_values(&states, actions);
        let loss = self._compute_loss(q_values, pred_q_values);
        self.optimizer.zero_grad();
        loss.backward();
        self.optimizer.step();
    }

    fn _sync_target_model(&mut self) {
        self.target_model = self.model.clone();
    }

    fn _compute_q_values(
        &self,
        n_step_after_states: &Vec<Tensor>,
        n_step_discounted_rewards: Vec<f64>,
    ) -> Tensor {
        assert_eq!(n_step_after_states.len(), n_step_discounted_rewards.len());
        let _states = batch_states(n_step_after_states, self.model.is_cuda());
        let pred_q_values = self.target_model.forward(&_states);
        let max_q_values = pred_q_values.max_dim(1, false).0;
        let gamma_n = self.gamma.powi(self.n_steps as i32);
        let n_step_discounted_rewards_tensor = Tensor::from_slice(&n_step_discounted_rewards);
        let updated_q_values = max_q_values * gamma_n + n_step_discounted_rewards_tensor;
        updated_q_values
    }

    fn _compute_pred_q_values(&self, states: &Vec<Tensor>, actions: Vec<Tensor>) -> Tensor {
        assert_eq!(states.len(), actions.len());
        let _states = batch_states(states, self.model.is_cuda());
        let pred_q_values = self.model.forward(&_states);
        let actions = Tensor::stack(&actions, 0).to_kind(tch::Kind::Int64);
        let pred_q_values_selected = pred_q_values.gather(1, &actions, false).squeeze();
        pred_q_values_selected
    }

    fn _compute_loss(&self, q_values: Tensor, pred_q_values: Tensor) -> Tensor {
        let loss = (q_values - pred_q_values).square().mean(tch::Kind::Float);
        loss
    }
}

impl BaseAgent for DQN {
    fn act(&self, obs: &Tensor) -> Tensor {
        no_grad(|| {
            let state = batch_states(&vec![obs.shallow_clone()], self.model.is_cuda());
            let q_values = self.model.forward(&state);
            q_values.argmax(1, false).to_device(Device::Cpu)
        })
    }

    fn act_and_train(&mut self, obs: &Tensor, reward: f64) -> Tensor {
        self.t += 1;
        let state = batch_states(&vec![obs.shallow_clone()], self.model.is_cuda());
        let q_values = self.model.forward(&state);

        let greedy_action_func = || q_values.argmax(1, false).int64_value(&[0]) as usize;
        let random_action_func = || rand::random::<usize>() % self.action_size;

        let action_idx =
            self.explorer
                .select_action(self.t, &random_action_func, &greedy_action_func);
        let action = Tensor::from_slice(&[action_idx as i64])
            .detach()
            .to_device(Device::Cpu);

        self.replay_buffer.append(
            state,
            Some(action.shallow_clone()),
            reward,
            false,
            self.gamma,
        );
        if self.t % self.update_interval == 0 {
            self._update();
        }
        if self.t % self.target_update_interval == 0 {
            self._sync_target_model();
        }
        action
    }

    fn stop_episode_and_train(&mut self, obs: &Tensor, reward: f64) {
        let state = batch_states(&vec![obs.shallow_clone()], self.model.is_cuda());
        self.replay_buffer
            .append(state, None, reward, true, self.gamma);
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::explorers::EpsilonGreedy;
    use crate::models::FCQNetwork;
    use tch::{nn, nn::OptimizerConfig, Device, Kind, Tensor};

    #[test]
    fn test_dqn_new() {
        let vs = nn::VarStore::new(Device::Cpu);
        let optimizer = nn::Adam::default().build(&vs, 1e-3).unwrap();
        let model = FCQNetwork::new(&vs, 4, 2, 2, Some(64));
        let explorer = EpsilonGreedy::new(1.0, 0.1, 1000);

        let dqn = DQN::new(
            Box::new(model),
            optimizer,
            2,
            32,
            1000,
            8,
            100,
            Box::new(explorer),
            0.99,
            3,
        );

        assert_eq!(dqn.action_size, 2);
        assert_eq!(dqn.batch_size, 32);
        assert_eq!(dqn.update_interval, 8);
        assert_eq!(dqn.target_update_interval, 100);
        assert_eq!(dqn.gamma, 0.99);
        assert_eq!(dqn.n_steps, 3);
        assert_eq!(dqn.t, 0);
    }

    #[test]
    fn test_dqn_act_and_train() {
        let vs = nn::VarStore::new(Device::Cpu);
        let optimizer = nn::Adam::default().build(&vs, 1e-2).unwrap();
        let model = FCQNetwork::new(&vs, 4, 4, 2, Some(128));
        let explorer = EpsilonGreedy::new(1.0, 0.0, 1000);
        let mut dqn = DQN::new(
            Box::new(model),
            optimizer,
            4,
            16,
            1000,
            50,
            100,
            Box::new(explorer),
            0.5,
            1,
        );

        let mut reward = 0.0;
        for i in 0..2000 {
            let obs = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0]).to_kind(Kind::Float);
            let action = dqn.act_and_train(&obs, reward);
            let action_value = i64::from(action.int64_value(&[]));
            if action_value == 2 {
                reward = 100.0;
            } else {
                reward = 0.0
            }
            assert!([0, 1, 2, 3].contains(&action_value));
            assert_eq!(dqn.t, i + 1);
            if dqn.t > 1000 {
                assert_eq!(action_value, 2);
            }
        }
        let obs = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0]).to_kind(Kind::Float);
        dqn.stop_episode_and_train(&obs, 1.0);

        for _ in 0..1000 {
            let action = dqn.act(&obs);
            let action_value = i64::from(action.int64_value(&[]));
            assert_eq!(action_value, 2);
        }
    }
}