reinforcex 0.0.2

Deep Reinforcement Learning Framework
use super::base_policy_network::BasePolicy;

use crate::misc::weight_initializer::he_init;
use crate::prob_distributions::BaseDistribution;
use crate::prob_distributions::SoftmaxDistribution;
use tch::nn::{Init, Linear, LinearConfig, Module};
use tch::{nn, Tensor};

pub struct FCSoftmaxPolicy {
    layers: Vec<Linear>,
    logits_layer: Linear,
    n_input_channels: i64,
    n_actions: i64,
    min_prob: f64,
}

pub struct FCSoftmaxPolicyWithValue {
    base_policy: FCSoftmaxPolicy,
    value_layer: Linear,
}

impl FCSoftmaxPolicy {
    pub fn new(
        vs: &nn::VarStore,
        n_input_channels: i64,
        n_actions: i64,
        n_hidden_layers: usize,
        n_hidden_channels: Option<i64>,
        min_prob: f64,
    ) -> Self {
        let root = vs.root();
        let mut layers = Vec::new();
        let n_hidden_channels = n_hidden_channels.unwrap_or(256);

        layers.push(nn::linear(
            &root,
            n_input_channels,
            n_hidden_channels,
            LinearConfig {
                ws_init: he_init(n_input_channels),
                bs_init: Some(Init::Const(0.0)),
                bias: true,
            },
        ));
        for _ in 0..n_hidden_layers {
            layers.push(nn::linear(
                &root,
                n_hidden_channels,
                n_hidden_channels,
                LinearConfig {
                    ws_init: he_init(n_hidden_channels),
                    bs_init: Some(Init::Const(0.0)),
                    bias: true,
                },
            ));
        }

        let logits_layer: Linear = nn::linear(
            &root,
            n_hidden_channels,
            n_actions,
            LinearConfig {
                ws_init: he_init(n_hidden_channels),
                bs_init: Some(Init::Const(0.0)),
                bias: true,
            },
        );

        FCSoftmaxPolicy {
            layers,
            logits_layer,
            n_input_channels,
            n_actions,
            min_prob,
        }
    }

    fn compute_medium_layer(&self, x: &Tensor) -> Tensor {
        let mut h = x.view([-1, self.n_input_channels]);

        for layer in &self.layers {
            h = (layer.forward(&h)).relu();
        }

        h
    }
}

impl BasePolicy for FCSoftmaxPolicy {
    fn forward(&self, x: &Tensor) -> (Box<dyn BaseDistribution>, Option<Tensor>) {
        let h = self.compute_medium_layer(x);
        let logits = self.logits_layer.forward(&h).relu();
        (
            Box::new(SoftmaxDistribution::new(logits, 1.0, self.min_prob)),
            None,
        )
    }

    fn is_cuda(&self) -> bool {
        false
    }
}

impl FCSoftmaxPolicyWithValue {
    pub fn new(
        vs: &nn::VarStore,
        n_input_channels: i64,
        n_actions: i64,
        n_hidden_layers: usize,
        n_hidden_channels: Option<i64>,
        min_prob: f64,
    ) -> Self {
        let base_policy: FCSoftmaxPolicy = FCSoftmaxPolicy::new(
            vs,
            n_input_channels,
            n_actions,
            n_hidden_layers,
            n_hidden_channels,
            min_prob,
        );

        let root = vs.root();
        let n_hidden_channels = n_hidden_channels.unwrap_or(256);
        let value_layer = nn::linear(
            &root,
            n_hidden_channels,
            1,
            LinearConfig {
                ws_init: he_init(n_hidden_channels),
                bs_init: Some(Init::Const(0.0)),
                bias: true,
            },
        );

        FCSoftmaxPolicyWithValue {
            base_policy,
            value_layer,
        }
    }
}

impl BasePolicy for FCSoftmaxPolicyWithValue {
    fn forward(&self, x: &Tensor) -> (Box<dyn BaseDistribution>, Option<Tensor>) {
        let h = self.base_policy.compute_medium_layer(x);
        let logits = self.base_policy.logits_layer.forward(&h).relu();
        let value = self.value_layer.forward(&h);
        (
            Box::new(SoftmaxDistribution::new(
                logits,
                1.0,
                self.base_policy.min_prob,
            )),
            Some(value),
        )
    }

    fn is_cuda(&self) -> bool {
        self.base_policy.is_cuda()
    }
}