// rogue_net/categorical_action_head.rs

1use ndarray::prelude::*;
2use rand::Rng;
3
4use crate::fun::softmax;
5use crate::linear::Linear;
6use crate::msgpack::TensorDict;
7#[derive(Debug, Clone)]
8pub struct CategoricalActionHead {
9    proj: Linear,
10}
11
12impl<'a> From<&'a TensorDict> for CategoricalActionHead {
13    fn from(state_dict: &TensorDict) -> Self {
14        let dict = state_dict.as_dict();
15        CategoricalActionHead {
16            proj: Linear::from(&dict["proj"]),
17        }
18    }
19}
20
21impl CategoricalActionHead {
22    pub fn forward(&self, x: ArrayView2<f32>, actors: Vec<usize>) -> (Array2<f32>, Vec<u64>) {
23        let actor_x = x.select(Axis(0), &actors);
24        let logits = self.proj.forward(actor_x.view());
25        let probs = softmax(&logits);
26        // TODO: efficient sampling
27        let mut rng = rand::thread_rng();
28        let mut acts = vec![0; actors.len()];
29        for i in 0..probs.dim().0 {
30            let mut r = rng.gen::<f32>();
31            for j in 0..probs.dim().1 {
32                r -= probs[[i, j]];
33                if r <= 0.0 {
34                    acts[i] = j as u64;
35                    break;
36                }
37            }
38        }
39        (probs, acts)
40    }
41}