ghostflow_nn/rl.rs

//! Reinforcement Learning module
//!
//! Implements RL building blocks and algorithms:
//! - Experience replay buffer
//! - Deep Q-Network (DQN)
//! - Policy Gradient (REINFORCE)
//! - Actor-Critic (A2C)
//! - Proximal Policy Optimization (PPO)

use ghostflow_core::Tensor;
use std::collections::VecDeque;
use rand::Rng;

/// Experience replay buffer for off-policy learning
#[derive(Debug, Clone)]
pub struct ReplayBuffer {
    capacity: usize,
    buffer: VecDeque<Experience>,
}

/// A single transition (state, action, reward, next_state, done)
#[derive(Debug, Clone)]
pub struct Experience {
    pub state: Tensor,
    pub action: usize,
    pub reward: f32,
    pub next_state: Tensor,
    pub done: bool,
}

impl ReplayBuffer {
    /// Create a new replay buffer with the given capacity
    pub fn new(capacity: usize) -> Self {
        ReplayBuffer {
            capacity,
            buffer: VecDeque::with_capacity(capacity),
        }
    }

    /// Add an experience to the buffer, evicting the oldest entry when full
    pub fn push(&mut self, experience: Experience) {
        if self.buffer.len() >= self.capacity {
            self.buffer.pop_front();
        }
        self.buffer.push_back(experience);
    }

    /// Sample a batch of experiences uniformly at random (with replacement).
    /// Callers must ensure the buffer is non-empty.
    pub fn sample(&self, batch_size: usize) -> Vec<Experience> {
        let mut rng = rand::thread_rng();
        let mut samples = Vec::with_capacity(batch_size);

        for _ in 0..batch_size {
            let idx = rng.gen_range(0..self.buffer.len());
            samples.push(self.buffer[idx].clone());
        }

        samples
    }

    /// Get current size of buffer
    pub fn len(&self) -> usize {
        self.buffer.len()
    }

    /// Check if buffer is empty
    pub fn is_empty(&self) -> bool {
        self.buffer.is_empty()
    }
}
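
// Illustrative usage of `ReplayBuffer` (a minimal sketch, not part of the
// training API): fill the buffer past its capacity, then draw a random
// mini-batch. Uses `Tensor::zeros` as in the tests below; the sizes are
// arbitrary example values.
#[allow(dead_code)]
fn replay_buffer_example() {
    let mut buffer = ReplayBuffer::new(100);

    // Store more transitions than the capacity; the oldest entries are evicted.
    for step in 0..150_usize {
        buffer.push(Experience {
            state: Tensor::zeros(&[4]),
            action: step % 2,
            reward: 1.0,
            next_state: Tensor::zeros(&[4]),
            done: false,
        });
    }
    assert_eq!(buffer.len(), 100);

    // Sample a mini-batch (uniform, with replacement) for an update step.
    let batch = buffer.sample(32);
    assert_eq!(batch.len(), 32);
}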

/// Deep Q-Network (DQN) agent
pub struct DQNAgent {
    q_network: QNetwork,
    target_network: QNetwork,
    replay_buffer: ReplayBuffer,
    gamma: f32,
    epsilon: f32,
    epsilon_decay: f32,
    epsilon_min: f32,
    learning_rate: f32,
    batch_size: usize,
    target_update_freq: usize,
    steps: usize,
}

/// Q-Network (simple MLP with two hidden layers)
#[derive(Debug, Clone)]
pub struct QNetwork {
    fc1: Tensor,
    fc2: Tensor,
    fc3: Tensor,
    state_dim: usize,
    action_dim: usize,
}

impl QNetwork {
    /// Create a new Q-Network with small random weights
    pub fn new(state_dim: usize, action_dim: usize, hidden_dim: usize) -> Self {
        let fc1 = Tensor::randn(&[state_dim, hidden_dim]).mul_scalar(0.01);
        let fc2 = Tensor::randn(&[hidden_dim, hidden_dim]).mul_scalar(0.01);
        let fc3 = Tensor::randn(&[hidden_dim, action_dim]).mul_scalar(0.01);

        QNetwork {
            fc1,
            fc2,
            fc3,
            state_dim,
            action_dim,
        }
    }

    /// Forward pass: compute Q-values for all actions
    pub fn forward(&self, state: &Tensor) -> Tensor {
        let h1 = state.matmul(&self.fc1).unwrap().relu();
        let h2 = h1.matmul(&self.fc2).unwrap().relu();
        h2.matmul(&self.fc3).unwrap()
    }

    /// Get the Q-value for a specific action (assumes a single state, batch size 1)
    pub fn q_value(&self, state: &Tensor, action: usize) -> f32 {
        let q_values = self.forward(state);
        q_values.data_f32()[action]
    }
}

impl DQNAgent {
    /// Create a new DQN agent
    pub fn new(
        state_dim: usize,
        action_dim: usize,
        hidden_dim: usize,
        buffer_capacity: usize,
        gamma: f32,
        epsilon: f32,
        learning_rate: f32,
        batch_size: usize,
    ) -> Self {
        let q_network = QNetwork::new(state_dim, action_dim, hidden_dim);
        let target_network = q_network.clone();
        let replay_buffer = ReplayBuffer::new(buffer_capacity);

        DQNAgent {
            q_network,
            target_network,
            replay_buffer,
            gamma,
            epsilon,
            epsilon_decay: 0.995,
            epsilon_min: 0.01,
            learning_rate,
            batch_size,
            target_update_freq: 100,
            steps: 0,
        }
    }

    /// Select an action using the epsilon-greedy policy
    pub fn select_action(&self, state: &Tensor) -> usize {
        let mut rng = rand::thread_rng();

        if rng.gen::<f32>() < self.epsilon {
            // Random action (exploration)
            rng.gen_range(0..self.q_network.action_dim)
        } else {
            // Greedy action (exploitation): argmax over predicted Q-values
            let q_values = self.q_network.forward(state);
            let data = q_values.data_f32();
            data.iter()
                .enumerate()
                .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
                .map(|(idx, _)| idx)
                .unwrap()
        }
    }

    /// Store experience in the replay buffer
    pub fn store_experience(&mut self, experience: Experience) {
        self.replay_buffer.push(experience);
    }

    /// Train the agent on a batch of experiences and return the mean TD loss.
    ///
    /// Note: this computes the temporal-difference loss and maintains the
    /// target network and epsilon schedule, but does not yet backpropagate
    /// the loss into the Q-network weights.
    pub fn train(&mut self) -> f32 {
        if self.replay_buffer.len() < self.batch_size {
            return 0.0;
        }

        let batch = self.replay_buffer.sample(self.batch_size);
        let mut total_loss = 0.0;

        for exp in batch {
            // Compute target Q-value: r + γ * max_a' Q_target(s', a')
            let target_q = if exp.done {
                exp.reward
            } else {
                let next_q_values = self.target_network.forward(&exp.next_state);
                let max_next_q = next_q_values.data_f32().iter()
                    .cloned()
                    .fold(f32::NEG_INFINITY, f32::max);
                exp.reward + self.gamma * max_next_q
            };

            // Compute current Q-value
            let current_q = self.q_network.q_value(&exp.state, exp.action);

            // Compute loss (MSE)
            let loss = (current_q - target_q).powi(2);
            total_loss += loss;
        }

        // Update target network periodically
        self.steps += 1;
        if self.steps % self.target_update_freq == 0 {
            self.target_network = self.q_network.clone();
        }

        // Decay epsilon
        self.epsilon = (self.epsilon * self.epsilon_decay).max(self.epsilon_min);

        total_loss / self.batch_size as f32
    }
}
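
// Sketch of how the DQN pieces above fit together in an interaction loop.
// The `env_step` closure is a hypothetical stand-in for an environment
// (not part of ghostflow): it maps an action to (next_state, reward, done).
// Hyperparameters and shapes are arbitrary example values.
#[allow(dead_code)]
fn dqn_loop_sketch<E>(mut env_step: E, num_steps: usize)
where
    E: FnMut(usize) -> (Tensor, f32, bool),
{
    let mut agent = DQNAgent::new(4, 2, 64, 10_000, 0.99, 1.0, 1e-3, 32);
    let mut state = Tensor::zeros(&[1, 4]); // assumed initial observation, shape [1, state_dim]

    for _ in 0..num_steps {
        // 1. Act with the epsilon-greedy policy.
        let action = agent.select_action(&state);
        // 2. Step the (hypothetical) environment.
        let (next_state, reward, done) = env_step(action);
        // 3. Store the transition and run one training step
        //    (a no-op until the buffer holds a full batch).
        agent.store_experience(Experience {
            state: state.clone(),
            action,
            reward,
            next_state: next_state.clone(),
            done,
        });
        let _loss = agent.train();
        // Episode resets are omitted here for brevity.
        state = next_state;
    }
}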

/// Policy network for policy gradient methods
#[derive(Debug, Clone)]
pub struct PolicyNetwork {
    fc1: Tensor,
    fc2: Tensor,
    fc3: Tensor,
}

impl PolicyNetwork {
    /// Create a new policy network
    pub fn new(state_dim: usize, action_dim: usize, hidden_dim: usize) -> Self {
        let fc1 = Tensor::randn(&[state_dim, hidden_dim]).mul_scalar(0.01);
        let fc2 = Tensor::randn(&[hidden_dim, hidden_dim]).mul_scalar(0.01);
        let fc3 = Tensor::randn(&[hidden_dim, action_dim]).mul_scalar(0.01);

        PolicyNetwork { fc1, fc2, fc3 }
    }

    /// Forward pass: compute action probabilities
    pub fn forward(&self, state: &Tensor) -> Tensor {
        let h1 = state.matmul(&self.fc1).unwrap().relu();
        let h2 = h1.matmul(&self.fc2).unwrap().relu();
        let logits = h2.matmul(&self.fc3).unwrap();
        logits.softmax(-1)
    }

    /// Sample an action from the policy's categorical distribution
    pub fn sample_action(&self, state: &Tensor) -> usize {
        let probs = self.forward(state);
        let prob_data = probs.data_f32();

        // Inverse-CDF sampling: draw u ~ U(0, 1) and return the first action
        // whose cumulative probability exceeds u
        let mut rng = rand::thread_rng();
        let sample: f32 = rng.gen();
        let mut cumsum = 0.0;

        for (i, &p) in prob_data.iter().enumerate() {
            cumsum += p;
            if sample < cumsum {
                return i;
            }
        }

        // Floating-point rounding can leave the cumulative sum just below 1.0;
        // fall back to the last action
        prob_data.len() - 1
    }
}
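
// Worked example of the inverse-CDF sampling used in `sample_action`, shown on
// plain slices so the mechanics are easy to follow. The probabilities and the
// draw value are made up for illustration.
#[allow(dead_code)]
fn categorical_sampling_example() {
    let probs = [0.2_f32, 0.5, 0.3];
    let draw = 0.65_f32; // pretend this came from rng.gen::<f32>()

    // Cumulative sums are 0.2, 0.7, 1.0; the first one exceeding 0.65 is at
    // index 1, so action 1 is selected.
    let mut cumsum = 0.0;
    let mut chosen = probs.len() - 1;
    for (i, &p) in probs.iter().enumerate() {
        cumsum += p;
        if draw < cumsum {
            chosen = i;
            break;
        }
    }
    assert_eq!(chosen, 1);
}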

/// REINFORCE (Policy Gradient) agent
pub struct REINFORCEAgent {
    policy: PolicyNetwork,
    gamma: f32,
    learning_rate: f32,
    episode_rewards: Vec<f32>,
    episode_actions: Vec<usize>,
    episode_states: Vec<Tensor>,
}

impl REINFORCEAgent {
    /// Create a new REINFORCE agent
    pub fn new(state_dim: usize, action_dim: usize, hidden_dim: usize, gamma: f32, learning_rate: f32) -> Self {
        let policy = PolicyNetwork::new(state_dim, action_dim, hidden_dim);

        REINFORCEAgent {
            policy,
            gamma,
            learning_rate,
            episode_rewards: Vec::new(),
            episode_actions: Vec::new(),
            episode_states: Vec::new(),
        }
    }

    /// Select an action by sampling from the policy
    pub fn select_action(&self, state: &Tensor) -> usize {
        self.policy.sample_action(state)
    }

    /// Store one step of the current episode
    pub fn store_step(&mut self, state: Tensor, action: usize, reward: f32) {
        self.episode_states.push(state);
        self.episode_actions.push(action);
        self.episode_rewards.push(reward);
    }

    /// Process a completed episode and return its discounted return G_0.
    ///
    /// Computes discounted returns G_t = r_t + γ * G_{t+1} and normalizes them;
    /// the normalized returns are what would weight the policy-gradient update,
    /// which is not yet applied to the network weights here.
    pub fn train_episode(&mut self) -> f32 {
        let episode_len = self.episode_rewards.len();
        if episode_len == 0 {
            return 0.0;
        }

        // Compute discounted returns, working backwards from the final step
        let mut returns = vec![0.0; episode_len];
        let mut g = 0.0;
        for t in (0..episode_len).rev() {
            g = self.episode_rewards[t] + self.gamma * g;
            returns[t] = g;
        }

        // Record the episode's discounted return before normalization
        let total_return = returns[0];

        // Normalize returns to zero mean and unit variance for stability
        let mean = returns.iter().sum::<f32>() / episode_len as f32;
        let std = (returns.iter().map(|r| (r - mean).powi(2)).sum::<f32>() / episode_len as f32).sqrt();
        for r in &mut returns {
            *r = (*r - mean) / (std + 1e-8);
        }

        // Clear episode data
        self.episode_rewards.clear();
        self.episode_actions.clear();
        self.episode_states.clear();

        total_return
    }
}
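
// Small numeric check of the discounted-return recursion used in
// `train_episode`: for rewards [1, 1, 1] and γ = 0.9,
// G_2 = 1.0, G_1 = 1 + 0.9 * 1 = 1.9, G_0 = 1 + 0.9 * 1.9 = 2.71.
#[allow(dead_code)]
fn discounted_returns_example() {
    let rewards = [1.0_f32, 1.0, 1.0];
    let gamma = 0.9_f32;

    let mut returns = [0.0_f32; 3];
    let mut g = 0.0;
    for t in (0..rewards.len()).rev() {
        g = rewards[t] + gamma * g;
        returns[t] = g;
    }

    assert!((returns[0] - 2.71).abs() < 1e-5);
    assert!((returns[1] - 1.9).abs() < 1e-5);
    assert!((returns[2] - 1.0).abs() < 1e-5);
}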

/// Actor-Critic agent (A2C)
pub struct ActorCriticAgent {
    actor: PolicyNetwork,
    critic: ValueNetwork,
    gamma: f32,
    actor_lr: f32,
    critic_lr: f32,
}

/// Value network for the critic
#[derive(Debug, Clone)]
pub struct ValueNetwork {
    fc1: Tensor,
    fc2: Tensor,
    fc3: Tensor,
}

impl ValueNetwork {
    /// Create a new value network
    pub fn new(state_dim: usize, hidden_dim: usize) -> Self {
        let fc1 = Tensor::randn(&[state_dim, hidden_dim]).mul_scalar(0.01);
        let fc2 = Tensor::randn(&[hidden_dim, hidden_dim]).mul_scalar(0.01);
        let fc3 = Tensor::randn(&[hidden_dim, 1]).mul_scalar(0.01);

        ValueNetwork { fc1, fc2, fc3 }
    }

    /// Forward pass: compute the scalar state value V(s)
    pub fn forward(&self, state: &Tensor) -> f32 {
        let h1 = state.matmul(&self.fc1).unwrap().relu();
        let h2 = h1.matmul(&self.fc2).unwrap().relu();
        let value = h2.matmul(&self.fc3).unwrap();
        value.data_f32()[0]
    }
}

impl ActorCriticAgent {
    /// Create a new Actor-Critic agent
    pub fn new(
        state_dim: usize,
        action_dim: usize,
        hidden_dim: usize,
        gamma: f32,
        actor_lr: f32,
        critic_lr: f32,
    ) -> Self {
        let actor = PolicyNetwork::new(state_dim, action_dim, hidden_dim);
        let critic = ValueNetwork::new(state_dim, hidden_dim);

        ActorCriticAgent {
            actor,
            critic,
            gamma,
            actor_lr,
            critic_lr,
        }
    }

    /// Select action from the actor's policy
    pub fn select_action(&self, state: &Tensor) -> usize {
        self.actor.sample_action(state)
    }

    /// Compute the actor and critic losses for a single transition.
    ///
    /// Returns (actor_loss, critic_loss); parameter updates are not yet applied.
    pub fn train_step(&mut self, state: &Tensor, action: usize, reward: f32, next_state: &Tensor, done: bool) -> (f32, f32) {
        // Compute the TD error, which also serves as the advantage estimate:
        // δ = r + γ*V(s') - V(s)
        let value = self.critic.forward(state);
        let next_value = if done { 0.0 } else { self.critic.forward(next_state) };
        let td_error = reward + self.gamma * next_value - value;

        // Actor loss: simplified to -advantage; the full objective would be
        // -log π(a|s) * advantage
        let actor_loss = -td_error;

        // Critic loss (MSE on the TD error)
        let critic_loss = td_error.powi(2);

        (actor_loss, critic_loss)
    }
}
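
// Numeric illustration of the TD-error/advantage computation in `train_step`:
// with r = 1.0, γ = 0.99, V(s) = 0.5 and V(s') = 0.6,
// δ = 1.0 + 0.99 * 0.6 - 0.5 = 1.094, so the critic loss is δ² ≈ 1.196836.
// The values are made up for illustration.
#[allow(dead_code)]
fn td_error_example() {
    let (reward, gamma) = (1.0_f32, 0.99_f32);
    let (value, next_value) = (0.5_f32, 0.6_f32);

    let td_error = reward + gamma * next_value - value;
    let critic_loss = td_error.powi(2);

    assert!((td_error - 1.094).abs() < 1e-4);
    assert!((critic_loss - 1.196836).abs() < 1e-4);
}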

/// PPO (Proximal Policy Optimization) agent
pub struct PPOAgent {
    actor: PolicyNetwork,
    critic: ValueNetwork,
    gamma: f32,
    lambda: f32, // GAE parameter
    epsilon_clip: f32, // clipping range for the PPO surrogate objective
    actor_lr: f32,
    critic_lr: f32,
}

impl PPOAgent {
    /// Create a new PPO agent (default learning rates: actor 3e-4, critic 1e-3)
    pub fn new(
        state_dim: usize,
        action_dim: usize,
        hidden_dim: usize,
        gamma: f32,
        lambda: f32,
        epsilon_clip: f32,
    ) -> Self {
        let actor = PolicyNetwork::new(state_dim, action_dim, hidden_dim);
        let critic = ValueNetwork::new(state_dim, hidden_dim);

        PPOAgent {
            actor,
            critic,
            gamma,
            lambda,
            epsilon_clip,
            actor_lr: 3e-4,
            critic_lr: 1e-3,
        }
    }

    /// Select action from the policy
    pub fn select_action(&self, state: &Tensor) -> usize {
        self.actor.sample_action(state)
    }

    /// Compute Generalized Advantage Estimation (GAE):
    /// δ_t = r_t + γ * V(s_{t+1}) - V(s_t), A_t = δ_t + γλ * A_{t+1}.
    /// Assumes the trajectory does not end in a terminal state (no done mask).
    pub fn compute_gae(&self, rewards: &[f32], values: &[f32], next_value: f32) -> Vec<f32> {
        let mut advantages = vec![0.0; rewards.len()];
        let mut gae = 0.0;

        for t in (0..rewards.len()).rev() {
            let next_val = if t == rewards.len() - 1 { next_value } else { values[t + 1] };
            let delta = rewards[t] + self.gamma * next_val - values[t];
            gae = delta + self.gamma * self.lambda * gae;
            advantages[t] = gae;
        }

        advantages
    }
}
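
// Small numeric check of `compute_gae`: with γ = 1.0 and λ = 1.0, GAE reduces
// to the plain advantage A_t = (r_t + ... + r_{T-1} + V_bootstrap) - V(s_t).
// The constructor arguments below are arbitrary example values.
#[allow(dead_code)]
fn gae_example() {
    let agent = PPOAgent::new(4, 2, 64, 1.0, 1.0, 0.2);

    let rewards = [1.0_f32, 1.0];
    let values = [0.5_f32, 0.5];
    let next_value = 0.0; // bootstrap value after the last step

    let adv = agent.compute_gae(&rewards, &values, next_value);

    // A_1 = 1.0 + 0.0 - 0.5 = 0.5
    // A_0 = (1.0 + 0.5 - 0.5) + 1.0 * 1.0 * 0.5 = 1.5
    assert!((adv[1] - 0.5).abs() < 1e-6);
    assert!((adv[0] - 1.5).abs() < 1e-6);
}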

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_replay_buffer() {
        let mut buffer = ReplayBuffer::new(10);
        let state = Tensor::zeros(&[4]);
        let next_state = Tensor::zeros(&[4]);

        let exp = Experience {
            state: state.clone(),
            action: 0,
            reward: 1.0,
            next_state: next_state.clone(),
            done: false,
        };

        buffer.push(exp);
        assert_eq!(buffer.len(), 1);
    }

    #[test]
    fn test_dqn_agent() {
        let agent = DQNAgent::new(4, 2, 64, 1000, 0.99, 1.0, 0.001, 32);
        let state = Tensor::randn(&[1, 4]);
        let action = agent.select_action(&state);
        assert!(action < 2);
    }

    #[test]
    fn test_policy_network() {
        let policy = PolicyNetwork::new(4, 2, 64);
        let state = Tensor::randn(&[1, 4]);
        let probs = policy.forward(&state);

        // Check probabilities sum to 1
        let sum: f32 = probs.data_f32().iter().sum();
        assert!((sum - 1.0).abs() < 0.01);
    }

    #[test]
    fn test_reinforce_agent() {
        let mut agent = REINFORCEAgent::new(4, 2, 64, 0.99, 0.001);
        let state = Tensor::randn(&[1, 4]);
        let action = agent.select_action(&state);

        agent.store_step(state, action, 1.0);
        assert_eq!(agent.episode_rewards.len(), 1);
    }

    #[test]
    fn test_actor_critic() {
        let agent = ActorCriticAgent::new(4, 2, 64, 0.99, 0.001, 0.001);
        let state = Tensor::randn(&[1, 4]);
        let action = agent.select_action(&state);
        assert!(action < 2);
    }
}