use ghostflow_core::Tensor;
use std::collections::VecDeque;
use rand::Rng;

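/// Fixed-capacity experience replay buffer for off-policy learning.
/// Once full, the oldest transition is evicted in FIFO order.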
#[derive(Debug, Clone)]
pub struct ReplayBuffer {
    capacity: usize,
    buffer: VecDeque<Experience>,
}

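/// A single environment transition `(state, action, reward, next_state, done)`.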
#[derive(Debug, Clone)]
pub struct Experience {
    pub state: Tensor,
    pub action: usize,
    pub reward: f32,
    pub next_state: Tensor,
    pub done: bool,
}

impl ReplayBuffer {
    pub fn new(capacity: usize) -> Self {
        ReplayBuffer {
            capacity,
            buffer: VecDeque::with_capacity(capacity),
        }
    }

    /// Appends a transition, evicting the oldest one when the buffer is full.
    pub fn push(&mut self, experience: Experience) {
        if self.buffer.len() >= self.capacity {
            self.buffer.pop_front();
        }
        self.buffer.push_back(experience);
    }

    /// Draws `batch_size` transitions uniformly at random, with replacement.
    /// Callers must ensure the buffer is non-empty.
    pub fn sample(&self, batch_size: usize) -> Vec<Experience> {
        let mut rng = rand::thread_rng();
        let mut samples = Vec::with_capacity(batch_size);

        for _ in 0..batch_size {
            let idx = rng.gen_range(0..self.buffer.len());
            samples.push(self.buffer[idx].clone());
        }

        samples
    }

    pub fn len(&self) -> usize {
        self.buffer.len()
    }

    pub fn is_empty(&self) -> bool {
        self.buffer.is_empty()
    }
}

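/// Deep Q-Network agent: an epsilon-greedy behaviour policy over `q_network`,
/// a replay buffer of past transitions, and a periodically synchronised
/// `target_network` used to form the TD target
/// `y = r + gamma * max_a' Q_target(s', a')`.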
pub struct DQNAgent {
    q_network: QNetwork,
    target_network: QNetwork,
    replay_buffer: ReplayBuffer,
    gamma: f32,
    epsilon: f32,
    epsilon_decay: f32,
    epsilon_min: f32,
    learning_rate: f32,
    batch_size: usize,
    target_update_freq: usize,
    steps: usize,
}

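/// Three-layer MLP mapping a state vector to one Q-value per action. Weights
/// are plain tensors initialised from a scaled standard normal distribution.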
#[derive(Debug, Clone)]
pub struct QNetwork {
    fc1: Tensor,
    fc2: Tensor,
    fc3: Tensor,
    state_dim: usize,
    action_dim: usize,
}

impl QNetwork {
    pub fn new(state_dim: usize, action_dim: usize, hidden_dim: usize) -> Self {
        let fc1 = Tensor::randn(&[state_dim, hidden_dim]).mul_scalar(0.01);
        let fc2 = Tensor::randn(&[hidden_dim, hidden_dim]).mul_scalar(0.01);
        let fc3 = Tensor::randn(&[hidden_dim, action_dim]).mul_scalar(0.01);

        QNetwork {
            fc1,
            fc2,
            fc3,
            state_dim,
            action_dim,
        }
    }

    /// Computes Q-values for every action: two ReLU hidden layers followed by
    /// a linear output layer.
    pub fn forward(&self, state: &Tensor) -> Tensor {
        let h1 = state.matmul(&self.fc1).unwrap().relu();
        let h2 = h1.matmul(&self.fc2).unwrap().relu();
        h2.matmul(&self.fc3).unwrap()
    }

    /// Returns the Q-value of a single action for a single (batch-of-one) state.
    pub fn q_value(&self, state: &Tensor, action: usize) -> f32 {
        let q_values = self.forward(state);
        q_values.data_f32()[action]
    }
}

impl DQNAgent {
    pub fn new(
        state_dim: usize,
        action_dim: usize,
        hidden_dim: usize,
        buffer_capacity: usize,
        gamma: f32,
        epsilon: f32,
        learning_rate: f32,
        batch_size: usize,
    ) -> Self {
        let q_network = QNetwork::new(state_dim, action_dim, hidden_dim);
        let target_network = q_network.clone();
        let replay_buffer = ReplayBuffer::new(buffer_capacity);

        DQNAgent {
            q_network,
            target_network,
            replay_buffer,
            gamma,
            epsilon,
            epsilon_decay: 0.995,
            epsilon_min: 0.01,
            learning_rate,
            batch_size,
            target_update_freq: 100,
            steps: 0,
        }
    }

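    /// Epsilon-greedy action selection: with probability `epsilon` pick a
    /// uniformly random action, otherwise pick the action with the highest
    /// predicted Q-value.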
    pub fn select_action(&self, state: &Tensor) -> usize {
        let mut rng = rand::thread_rng();

        if rng.gen::<f32>() < self.epsilon {
            rng.gen_range(0..self.q_network.action_dim)
        } else {
            let q_values = self.q_network.forward(state);
            let data = q_values.data_f32();
            data.iter()
                .enumerate()
                .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
                .map(|(idx, _)| idx)
                .unwrap()
        }
    }

    pub fn store_experience(&mut self, experience: Experience) {
        self.replay_buffer.push(experience);
    }

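    /// One training iteration: sample a minibatch from the replay buffer, form
    /// the TD target `r + gamma * max_a' Q_target(s', a')` (just `r` at
    /// terminal transitions), and accumulate the squared TD error. The target
    /// network is re-synchronised every `target_update_freq` calls and epsilon
    /// is decayed. Returns the mean squared TD error over the batch; no
    /// gradient update is applied here.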
    pub fn train(&mut self) -> f32 {
        if self.replay_buffer.len() < self.batch_size {
            return 0.0;
        }

        let batch = self.replay_buffer.sample(self.batch_size);
        let mut total_loss = 0.0;

        for exp in batch {
            let target_q = if exp.done {
                exp.reward
            } else {
                let next_q_values = self.target_network.forward(&exp.next_state);
                let max_next_q = next_q_values
                    .data_f32()
                    .iter()
                    .cloned()
                    .fold(f32::NEG_INFINITY, f32::max);
                exp.reward + self.gamma * max_next_q
            };

            let current_q = self.q_network.q_value(&exp.state, exp.action);
            let loss = (current_q - target_q).powi(2);
            total_loss += loss;
        }

        self.steps += 1;
        if self.steps % self.target_update_freq == 0 {
            self.target_network = self.q_network.clone();
        }

        self.epsilon = (self.epsilon * self.epsilon_decay).max(self.epsilon_min);

        total_loss / self.batch_size as f32
    }
}

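/// Three-layer MLP policy mapping a state vector to a categorical action
/// distribution (softmax over the output logits). Shared by the REINFORCE,
/// actor-critic, and PPO agents below.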
#[derive(Debug, Clone)]
pub struct PolicyNetwork {
    fc1: Tensor,
    fc2: Tensor,
    fc3: Tensor,
}

impl PolicyNetwork {
    pub fn new(state_dim: usize, action_dim: usize, hidden_dim: usize) -> Self {
        let fc1 = Tensor::randn(&[state_dim, hidden_dim]).mul_scalar(0.01);
        let fc2 = Tensor::randn(&[hidden_dim, hidden_dim]).mul_scalar(0.01);
        let fc3 = Tensor::randn(&[hidden_dim, action_dim]).mul_scalar(0.01);

        PolicyNetwork { fc1, fc2, fc3 }
    }

    /// Returns the action probabilities for `state`.
    pub fn forward(&self, state: &Tensor) -> Tensor {
        let h1 = state.matmul(&self.fc1).unwrap().relu();
        let h2 = h1.matmul(&self.fc2).unwrap().relu();
        let logits = h2.matmul(&self.fc3).unwrap();
        logits.softmax(-1)
    }

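    /// Samples an action from the categorical distribution returned by
    /// `forward` via inverse-CDF sampling on a uniform draw, falling back to
    /// the last action index if rounding keeps the cumulative sum below the
    /// draw.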
    pub fn sample_action(&self, state: &Tensor) -> usize {
        let probs = self.forward(state);
        let prob_data = probs.data_f32();

        let mut rng = rand::thread_rng();
        let sample: f32 = rng.gen();
        let mut cumsum = 0.0;

        for (i, &p) in prob_data.iter().enumerate() {
            cumsum += p;
            if sample < cumsum {
                return i;
            }
        }

        prob_data.len() - 1
    }
}

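/// Monte-Carlo policy-gradient (REINFORCE) agent: it records full episodes of
/// states, actions, and rewards, then computes discounted returns
/// `G_t = r_t + gamma * G_{t+1}` once the episode ends.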
pub struct REINFORCEAgent {
    policy: PolicyNetwork,
    gamma: f32,
    learning_rate: f32,
    episode_rewards: Vec<f32>,
    episode_actions: Vec<usize>,
    episode_states: Vec<Tensor>,
}

impl REINFORCEAgent {
    pub fn new(
        state_dim: usize,
        action_dim: usize,
        hidden_dim: usize,
        gamma: f32,
        learning_rate: f32,
    ) -> Self {
        let policy = PolicyNetwork::new(state_dim, action_dim, hidden_dim);

        REINFORCEAgent {
            policy,
            gamma,
            learning_rate,
            episode_rewards: Vec::new(),
            episode_actions: Vec::new(),
            episode_states: Vec::new(),
        }
    }

    pub fn select_action(&self, state: &Tensor) -> usize {
        self.policy.sample_action(state)
    }

    /// Records one (state, action, reward) step of the current episode.
    pub fn store_step(&mut self, state: Tensor, action: usize, reward: f32) {
        self.episode_states.push(state);
        self.episode_actions.push(action);
        self.episode_rewards.push(reward);
    }

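    /// Finalises the current episode: computes discounted returns for every
    /// step, normalises them to zero mean and unit variance as a simple
    /// variance-reduction baseline, clears the episode buffers, and returns
    /// the episode's total discounted return.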
    pub fn train_episode(&mut self) -> f32 {
        let episode_len = self.episode_rewards.len();
        if episode_len == 0 {
            return 0.0;
        }

        // Discounted returns, computed backwards over the episode.
        let mut returns = vec![0.0; episode_len];
        let mut g = 0.0;
        for t in (0..episode_len).rev() {
            g = self.episode_rewards[t] + self.gamma * g;
            returns[t] = g;
        }

        // Capture the episode return before normalisation so the reported
        // value is the actual discounted return from the first step.
        let total_return = returns[0];

        // Normalise returns to zero mean and unit variance.
        let mean = returns.iter().sum::<f32>() / episode_len as f32;
        let variance = returns.iter().map(|r| (r - mean).powi(2)).sum::<f32>() / episode_len as f32;
        let std = variance.sqrt();
        for r in &mut returns {
            *r = (*r - mean) / (std + 1e-8);
        }

        self.episode_rewards.clear();
        self.episode_actions.clear();
        self.episode_states.clear();

        total_return
    }
}

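/// One-step actor-critic agent: a `PolicyNetwork` actor and a `ValueNetwork`
/// critic estimating V(s), driven by the TD error
/// `delta = r + gamma * V(s') - V(s)`.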
pub struct ActorCriticAgent {
    actor: PolicyNetwork,
    critic: ValueNetwork,
    gamma: f32,
    actor_lr: f32,
    critic_lr: f32,
}

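/// Three-layer MLP state-value estimator V(s) with a single scalar output.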
#[derive(Debug, Clone)]
pub struct ValueNetwork {
    fc1: Tensor,
    fc2: Tensor,
    fc3: Tensor,
}

impl ValueNetwork {
    pub fn new(state_dim: usize, hidden_dim: usize) -> Self {
        let fc1 = Tensor::randn(&[state_dim, hidden_dim]).mul_scalar(0.01);
        let fc2 = Tensor::randn(&[hidden_dim, hidden_dim]).mul_scalar(0.01);
        let fc3 = Tensor::randn(&[hidden_dim, 1]).mul_scalar(0.01);

        ValueNetwork { fc1, fc2, fc3 }
    }

    /// Estimates V(s) for a single (batch-of-one) state.
    pub fn forward(&self, state: &Tensor) -> f32 {
        let h1 = state.matmul(&self.fc1).unwrap().relu();
        let h2 = h1.matmul(&self.fc2).unwrap().relu();
        let value = h2.matmul(&self.fc3).unwrap();
        value.data_f32()[0]
    }
}

impl ActorCriticAgent {
    pub fn new(
        state_dim: usize,
        action_dim: usize,
        hidden_dim: usize,
        gamma: f32,
        actor_lr: f32,
        critic_lr: f32,
    ) -> Self {
        let actor = PolicyNetwork::new(state_dim, action_dim, hidden_dim);
        let critic = ValueNetwork::new(state_dim, hidden_dim);

        ActorCriticAgent {
            actor,
            critic,
            gamma,
            actor_lr,
            critic_lr,
        }
    }

    pub fn select_action(&self, state: &Tensor) -> usize {
        self.actor.sample_action(state)
    }

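    /// Computes the one-step TD error for a transition and returns the
    /// resulting `(actor_loss, critic_loss)` pair: the negated TD error as the
    /// actor's advantage signal and the squared TD error for the critic. No
    /// parameter update is applied here.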
    pub fn train_step(
        &mut self,
        state: &Tensor,
        _action: usize,
        reward: f32,
        next_state: &Tensor,
        done: bool,
    ) -> (f32, f32) {
        let value = self.critic.forward(state);
        let next_value = if done { 0.0 } else { self.critic.forward(next_state) };
        let td_error = reward + self.gamma * next_value - value;

        let actor_loss = -td_error;
        let critic_loss = td_error.powi(2);

        (actor_loss, critic_loss)
    }
}

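/// Proximal Policy Optimization (PPO) agent skeleton holding an actor, a
/// critic, and the usual PPO hyperparameters (discount `gamma`, GAE parameter
/// `lambda`, clipping range `epsilon_clip`). In this file only GAE advantage
/// estimation is implemented.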
pub struct PPOAgent {
    actor: PolicyNetwork,
    critic: ValueNetwork,
    gamma: f32,
    lambda: f32,
    epsilon_clip: f32,
    actor_lr: f32,
    critic_lr: f32,
}

impl PPOAgent {
    pub fn new(
        state_dim: usize,
        action_dim: usize,
        hidden_dim: usize,
        gamma: f32,
        lambda: f32,
        epsilon_clip: f32,
    ) -> Self {
        let actor = PolicyNetwork::new(state_dim, action_dim, hidden_dim);
        let critic = ValueNetwork::new(state_dim, hidden_dim);

        PPOAgent {
            actor,
            critic,
            gamma,
            lambda,
            epsilon_clip,
            actor_lr: 3e-4,
            critic_lr: 1e-3,
        }
    }

    pub fn select_action(&self, state: &Tensor) -> usize {
        self.actor.sample_action(state)
    }

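    /// Generalised Advantage Estimation (GAE): for each step t,
    /// `delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)` and
    /// `A_t = delta_t + gamma * lambda * A_{t+1}`, accumulated backwards over
    /// the trajectory. `next_value` is the critic's estimate for the state
    /// after the final reward.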
    pub fn compute_gae(&self, rewards: &[f32], values: &[f32], next_value: f32) -> Vec<f32> {
        let mut advantages = vec![0.0; rewards.len()];
        let mut gae = 0.0;

        for t in (0..rewards.len()).rev() {
            let next_val = if t == rewards.len() - 1 { next_value } else { values[t + 1] };
            let delta = rewards[t] + self.gamma * next_val - values[t];
            gae = delta + self.gamma * self.lambda * gae;
            advantages[t] = gae;
        }

        advantages
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_replay_buffer() {
        let mut buffer = ReplayBuffer::new(10);
        let state = Tensor::zeros(&[4]);
        let next_state = Tensor::zeros(&[4]);

        let exp = Experience {
            state: state.clone(),
            action: 0,
            reward: 1.0,
            next_state: next_state.clone(),
            done: false,
        };

        buffer.push(exp);
        assert_eq!(buffer.len(), 1);
    }

    #[test]
    fn test_dqn_agent() {
        let agent = DQNAgent::new(4, 2, 64, 1000, 0.99, 1.0, 0.001, 32);
        let state = Tensor::randn(&[1, 4]);
        let action = agent.select_action(&state);
        assert!(action < 2);
    }

    #[test]
    fn test_policy_network() {
        let policy = PolicyNetwork::new(4, 2, 64);
        let state = Tensor::randn(&[1, 4]);
        let probs = policy.forward(&state);

        let sum: f32 = probs.data_f32().iter().sum();
        assert!((sum - 1.0).abs() < 0.01);
    }

    #[test]
    fn test_reinforce_agent() {
        let mut agent = REINFORCEAgent::new(4, 2, 64, 0.99, 0.001);
        let state = Tensor::randn(&[1, 4]);
        let action = agent.select_action(&state);

        agent.store_step(state, action, 1.0);
        assert_eq!(agent.episode_rewards.len(), 1);
    }

    #[test]
    fn test_actor_critic() {
        let agent = ActorCriticAgent::new(4, 2, 64, 0.99, 0.001, 0.001);
        let state = Tensor::randn(&[1, 4]);
        let action = agent.select_action(&state);
        assert!(action < 2);
    }
}