algos/ml/rl/policy_gradients.rs

use ndarray::{Array1, Array2};
use rand::Rng;
use std::collections::VecDeque;

/// REINFORCE (Monte Carlo policy gradient) agent with a learned value baseline.
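///
/// For reference, the standard REINFORCE-with-baseline update that this module
/// implements is
///
///   theta <- theta + alpha * (G_t - V(s_t)) * grad_theta log pi_theta(a_t | s_t)
///
/// where G_t = r_t + gamma * r_{t+1} + gamma^2 * r_{t+2} + ... is the discounted
/// return and V(s_t) is the value network's baseline estimate, itself regressed
/// toward G_t.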
pub struct PolicyGradients {
    policy_network: PolicyNetwork,
    value_network: ValueNetwork,
    memory: VecDeque<Experience>,
    learning_rate: f64,
    gamma: f64,
}

/// Single linear layer followed by a softmax over the action logits.
struct PolicyNetwork {
    weights: Array2<f64>,
    biases: Array1<f64>,
}

/// Single linear layer producing a scalar state-value estimate.
struct ValueNetwork {
    weights: Array2<f64>,
    biases: Array1<f64>,
}

/// One (state, action, reward) transition collected during an episode.
pub struct Experience {
    state: Array1<f64>,
    action: usize,
    reward: f64,
}

impl PolicyGradients {
    pub fn new(state_dim: usize, action_dim: usize, learning_rate: f64, gamma: f64) -> Self {
        PolicyGradients {
            policy_network: PolicyNetwork::new(state_dim, action_dim),
            value_network: ValueNetwork::new(state_dim),
            memory: VecDeque::new(),
            learning_rate,
            gamma,
        }
    }

    /// Samples an action from the policy's current distribution over actions.
    pub fn select_action(&self, state: &Array1<f64>) -> usize {
        let action_probs = self.policy_network.forward(state);
        self.sample_action(&action_probs)
    }

    /// Performs one REINFORCE update from a complete episode of transitions.
    pub fn train(&mut self, episode: Vec<Experience>) {
        let mut returns = Vec::new();
        let mut running_return = 0.0;

        // Compute the discounted return G_t for each step by scanning the
        // episode backwards: G_t = r_t + gamma * G_{t+1}.
        for experience in episode.iter().rev() {
            running_return = experience.reward + self.gamma * running_return;
            returns.push(running_return);
        }
        returns.reverse();

        // Convert returns to an array for indexed access below.
        let returns = Array1::from(returns);

        // Update both networks once per step of the episode.
        for (i, experience) in episode.iter().enumerate() {
            // Advantage: how much better the observed return was than the
            // baseline's estimate for this state.
            let value = self.value_network.forward(&experience.state);
            let advantage = returns[i] - value;

            // Update the policy network. The gradient of -log pi(a|s) with
            // respect to the logits of a linear softmax policy is
            // softmax(logits) - one_hot(action).
            let action_probs = self.policy_network.forward(&experience.state);
            let mut policy_gradient = action_probs.clone();
            policy_gradient[experience.action] -= 1.0;

            self.policy_network.backward(
                &experience.state,
                &policy_gradient,
                advantage,
                self.learning_rate,
            );

            // Update the value network (baseline) toward the observed return.
            self.value_network
                .backward(&experience.state, returns[i], self.learning_rate);
        }
    }

    /// Samples an index from a categorical distribution given its probabilities.
    fn sample_action(&self, probs: &Array1<f64>) -> usize {
        let mut rng = rand::thread_rng();
        let sample = rng.gen::<f64>();
        let mut cumsum = 0.0;

        for (i, &prob) in probs.iter().enumerate() {
            cumsum += prob;
            if sample < cumsum {
                return i;
            }
        }

        // Guard against floating-point rounding leaving a tiny probability gap.
        probs.len() - 1
    }

    /// Buffers a transition until the episode is complete.
    pub fn store_experience(&mut self, state: Array1<f64>, action: usize, reward: f64) {
        self.memory.push_back(Experience {
            state,
            action,
            reward,
        });
    }

    /// Drains and returns all transitions stored for the current episode.
    pub fn get_episode(&mut self) -> Vec<Experience> {
        self.memory.drain(..).collect()
    }
}

impl PolicyNetwork {
    fn new(input_dim: usize, output_dim: usize) -> Self {
        PolicyNetwork {
            weights: Array2::zeros((input_dim, output_dim)),
            biases: Array1::zeros(output_dim),
        }
    }

    fn forward(&self, state: &Array1<f64>) -> Array1<f64> {
        let logits = state.dot(&self.weights) + &self.biases;
        self.softmax(logits)
    }

    /// Gradient step on the policy parameters. The caller passes the gradient
    /// of -log pi with respect to the logits, so subtracting it (scaled by the
    /// advantage) ascends the expected return.
    fn backward(
        &mut self,
        state: &Array1<f64>,
        policy_gradient: &Array1<f64>,
        advantage: f64,
        learning_rate: f64,
    ) {
        // Policy gradient update: the weight gradient is the outer product of
        // the state and the logit gradient, scaled by the advantage.
        for i in 0..self.weights.nrows() {
            for j in 0..self.weights.ncols() {
                self.weights[[i, j]] -= learning_rate * advantage * state[i] * policy_gradient[j];
            }
        }

        for j in 0..self.biases.len() {
            self.biases[j] -= learning_rate * advantage * policy_gradient[j];
        }
    }

    /// Numerically stable softmax: subtracts the maximum logit before exponentiating.
    fn softmax(&self, x: Array1<f64>) -> Array1<f64> {
        let max_val = x.fold(f64::NEG_INFINITY, |a, &b| a.max(b));
        let exp_x = x.mapv(|a| (a - max_val).exp());
        let sum_exp = exp_x.sum();
        exp_x / sum_exp
    }
}

impl ValueNetwork {
    fn new(input_dim: usize) -> Self {
        ValueNetwork {
            weights: Array2::zeros((input_dim, 1)),
            biases: Array1::zeros(1),
        }
    }

    fn forward(&self, state: &Array1<f64>) -> f64 {
        (state.dot(&self.weights) + &self.biases)[0]
    }

    /// One gradient-descent step on the squared error between the predicted
    /// value and the Monte Carlo return used as the regression target.
    fn backward(&mut self, state: &Array1<f64>, target: f64, learning_rate: f64) {
        let prediction = self.forward(state);
        let error = target - prediction;

        // Value network update: move the prediction toward the target.
        for i in 0..self.weights.nrows() {
            self.weights[[i, 0]] += learning_rate * error * state[i];
        }
        self.biases[0] += learning_rate * error;
    }
}
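
// ---------------------------------------------------------------------------
// Usage sketch (illustrative, not part of the original API): trains the agent
// on a hypothetical one-state, two-action task where action 1 always pays +1
// and action 0 pays nothing. The `ToyEnv` type, the episode length, and the
// hyperparameters below are assumptions chosen only to demonstrate the
// collect-episode / train loop, not a definitive benchmark.
// ---------------------------------------------------------------------------
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array1;

    /// Hypothetical environment: a single fixed state, reward 1.0 for action 1.
    struct ToyEnv;

    impl ToyEnv {
        fn state(&self) -> Array1<f64> {
            Array1::from(vec![1.0, 0.0])
        }

        fn step(&self, action: usize) -> f64 {
            if action == 1 { 1.0 } else { 0.0 }
        }
    }

    #[test]
    fn prefers_the_rewarding_action_after_training() {
        let env = ToyEnv;
        let mut agent = PolicyGradients::new(2, 2, 0.1, 0.99);

        for _ in 0..500 {
            // Roll out a short episode, buffering each transition.
            for _ in 0..5 {
                let state = env.state();
                let action = agent.select_action(&state);
                let reward = env.step(action);
                agent.store_experience(state, action, reward);
            }
            // Train on the completed episode, then start a fresh one.
            let episode = agent.get_episode();
            agent.train(episode);
        }

        // Intentionally loose check: after training, the policy's probability
        // of the rewarding action should dominate.
        let probs = agent.policy_network.forward(&env.state());
        assert!(probs[1] > probs[0]);
    }
}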