algos/ml/rl/policy_gradients.rs

use ndarray::{Array1, Array2};
use rand::Rng;
use std::collections::VecDeque;

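/// REINFORCE-style policy gradient agent with a learned value baseline.
///
/// Both the policy and the value estimate are single linear layers over the raw
/// state vector. Transitions are collected one step at a time, and the agent is
/// trained on whole episodes using discounted Monte Carlo returns, with the value
/// network's prediction subtracted as a baseline to form the advantage.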
pub struct PolicyGradients {
    policy_network: PolicyNetwork,
    value_network: ValueNetwork,
    memory: VecDeque<Experience>,
    learning_rate: f64,
    gamma: f64,
}

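/// Linear softmax policy: action probabilities are softmax(state · W + b).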
struct PolicyNetwork {
    weights: Array2<f64>,
    biases: Array1<f64>,
}

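/// Linear state-value estimator V(s) = state · w + b, used as the baseline.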
struct ValueNetwork {
    weights: Array2<f64>,
    biases: Array1<f64>,
}

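/// A single (state, action, reward) transition recorded during an episode.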
pub struct Experience {
    state: Array1<f64>,
    action: usize,
    reward: f64,
}

impl PolicyGradients {
    pub fn new(state_dim: usize, action_dim: usize, learning_rate: f64, gamma: f64) -> Self {
        PolicyGradients {
            policy_network: PolicyNetwork::new(state_dim, action_dim),
            value_network: ValueNetwork::new(state_dim),
            memory: VecDeque::new(),
            learning_rate,
            gamma,
        }
    }

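    /// Samples an action index from the policy's softmax distribution over `state`.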
    pub fn select_action(&self, state: &Array1<f64>) -> usize {
        let action_probs = self.policy_network.forward(state);
        self.sample_action(&action_probs)
    }

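    /// Trains on one complete episode: computes discounted returns, forms advantages
    /// against the value baseline, applies a REINFORCE update to the policy, and
    /// regresses the value network toward the observed returns.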
    pub fn train(&mut self, episode: Vec<Experience>) {
        // Discounted returns G_t = r_t + gamma * G_{t+1}, computed by walking the
        // episode backwards, then restored to chronological order.
        let mut returns = Vec::new();
        let mut running_return = 0.0;

        for experience in episode.iter().rev() {
            running_return = experience.reward + self.gamma * running_return;
            returns.push(running_return);
        }
        returns.reverse();

        let returns = Array1::from(returns);

        for (i, experience) in episode.iter().enumerate() {
            // Advantage: how much better the observed return was than the baseline.
            let value = self.value_network.forward(&experience.state);
            let advantage = returns[i] - value;

            // Gradient of -log pi(a|s) with respect to the logits is
            // (softmax - one_hot(a)).
            let mut policy_gradient = self.policy_network.forward(&experience.state);
            policy_gradient[experience.action] -= 1.0;

            self.policy_network.backward(
                &experience.state,
                &policy_gradient,
                advantage,
                self.learning_rate,
            );

            // Fit the value network toward the observed return for this step.
            self.value_network
                .backward(&experience.state, returns[i], self.learning_rate);
        }
    }

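    /// Inverse-CDF sampling from the categorical distribution `probs`: accumulate
    /// probabilities until they exceed a uniform random draw.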
    fn sample_action(&self, probs: &Array1<f64>) -> usize {
        let mut rng = rand::thread_rng();
        let sample = rng.gen::<f64>();
        let mut cumsum = 0.0;

        for (i, &prob) in probs.iter().enumerate() {
            cumsum += prob;
            if sample < cumsum {
                return i;
            }
        }

        // Guard against floating-point round-off leaving a tiny gap at the top.
        probs.len() - 1
    }

    /// Records a single transition for the episode currently being collected.
    pub fn store_experience(&mut self, state: Array1<f64>, action: usize, reward: f64) {
        self.memory.push_back(Experience {
            state,
            action,
            reward,
        });
    }

    /// Drains all stored transitions and returns them as one episode.
    pub fn get_episode(&mut self) -> Vec<Experience> {
        self.memory.drain(..).collect()
    }
}

impl PolicyNetwork {
    fn new(input_dim: usize, output_dim: usize) -> Self {
        PolicyNetwork {
            weights: Array2::zeros((input_dim, output_dim)),
            biases: Array1::zeros(output_dim),
        }
    }

    fn forward(&self, state: &Array1<f64>) -> Array1<f64> {
        let logits = state.dot(&self.weights) + &self.biases;
        self.softmax(logits)
    }

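    /// One SGD step on the linear layer. `policy_gradient` is expected to be
    /// (softmax - one_hot(action)), i.e. the gradient of -log pi(a|s) with respect
    /// to the logits; scaling it by `advantage` yields the REINFORCE policy update.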
    fn backward(
        &mut self,
        state: &Array1<f64>,
        policy_gradient: &Array1<f64>,
        advantage: f64,
        learning_rate: f64,
    ) {
        for i in 0..self.weights.nrows() {
            for j in 0..self.weights.ncols() {
                self.weights[[i, j]] -= learning_rate * advantage * state[i] * policy_gradient[j];
            }
        }

        for j in 0..self.biases.len() {
            self.biases[j] -= learning_rate * advantage * policy_gradient[j];
        }
    }

    fn softmax(&self, x: Array1<f64>) -> Array1<f64> {
        // Subtract the max logit for numerical stability before exponentiating.
        let max_val = x.fold(f64::NEG_INFINITY, |a, &b| a.max(b));
        let exp_x = x.mapv(|a| (a - max_val).exp());
        let sum_exp = exp_x.sum();
        exp_x / sum_exp
    }
}

impl ValueNetwork {
    fn new(input_dim: usize) -> Self {
        ValueNetwork {
            weights: Array2::zeros((input_dim, 1)),
            biases: Array1::zeros(1),
        }
    }

    fn forward(&self, state: &Array1<f64>) -> f64 {
        (state.dot(&self.weights) + &self.biases)[0]
    }

    fn backward(&mut self, state: &Array1<f64>, target: f64, learning_rate: f64) {
        // One gradient step on the squared error between the prediction and the return.
        let prediction = self.forward(state);
        let error = target - prediction;

        for i in 0..self.weights.nrows() {
            self.weights[[i, 0]] += learning_rate * error * state[i];
        }
        self.biases[0] += learning_rate * error;
    }
}
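
// Illustrative usage: a minimal smoke-test sketch that hand-builds a short episode
// with arbitrary states and rewards, then runs one training pass to exercise the
// store_experience -> get_episode -> train flow end to end.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn trains_on_a_single_hand_built_episode() {
        // 4-dimensional state, 2 actions, small learning rate, typical discount factor.
        let mut agent = PolicyGradients::new(4, 2, 0.01, 0.99);

        // Collect a 5-step episode; only the final step carries a reward.
        for step in 0..5 {
            let state = Array1::from(vec![0.1, 0.2, 0.3, 0.4]);
            let action = agent.select_action(&state);
            let reward = if step == 4 { 1.0 } else { 0.0 };
            agent.store_experience(state, action, reward);
        }

        // One full REINFORCE-with-baseline update over the collected episode.
        let episode = agent.get_episode();
        agent.train(episode);
    }
}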