rogue_net/
categorical_action_head.rs1use ndarray::prelude::*;
2use rand::Rng;
3
4use crate::fun::softmax;
5use crate::linear::Linear;
6use crate::msgpack::TensorDict;
7#[derive(Debug, Clone)]
8pub struct CategoricalActionHead {
9 proj: Linear,
10}
11
12impl<'a> From<&'a TensorDict> for CategoricalActionHead {
13 fn from(state_dict: &TensorDict) -> Self {
14 let dict = state_dict.as_dict();
15 CategoricalActionHead {
16 proj: Linear::from(&dict["proj"]),
17 }
18 }
19}
20
21impl CategoricalActionHead {
22 pub fn forward(&self, x: ArrayView2<f32>, actors: Vec<usize>) -> (Array2<f32>, Vec<u64>) {
23 let actor_x = x.select(Axis(0), &actors);
24 let logits = self.proj.forward(actor_x.view());
25 let probs = softmax(&logits);
26 let mut rng = rand::thread_rng();
28 let mut acts = vec![0; actors.len()];
29 for i in 0..probs.dim().0 {
30 let mut r = rng.gen::<f32>();
31 for j in 0..probs.dim().1 {
32 r -= probs[[i, j]];
33 if r <= 0.0 {
34 acts[i] = j as u64;
35 break;
36 }
37 }
38 }
39 (probs, acts)
40 }
41}