// content_extractor_rl/agents/ppo_agent.rs
1// ============================================================================
2// FILE: crates/content-extractor-rl/src/agents/ppo_agent.rs
3// ============================================================================
4
5use candle_core::{Device, Tensor, DType, Var};
6use candle_nn::{VarBuilder, Optimizer, AdamW, ParamsAdamW, VarMap, Linear, Module, linear, layer_norm, LayerNorm};
7use crate::replay_buffer::{PrioritizedReplayBuffer};
8use crate::{Result, agents::{RLAgent, AlgorithmType, AgentInfo}};
9use rand::RngExt;
10use rand_distr::{Normal, Distribution};
11use std::path::{Path, PathBuf};
12use crate::models::ModelMetadata;
13use std::collections::HashMap;
14
15
16// Helper functions
17fn sample_categorical(probs: &[f32]) -> usize {
18    let mut rng = rand::rng();
19    let random_val: f32 = rng.random();
20    let mut cumsum = 0.0;
21    for (i, &prob) in probs.iter().enumerate() {
22        cumsum += prob;
23        if random_val < cumsum {
24            return i;
25        }
26    }
27    probs.len() - 1
28}
29fn sample_gaussian(means: &[f32], stds: &[f32]) -> Vec<f32> {
30    let mut rng = rand::rng();
31    means.iter().zip(stds.iter())
32        .map(|(&mean, &std)| {
33            let normal = Normal::new(mean, std).unwrap_or_else(|_| Normal::new(0.0, 1.0).unwrap());
34            normal.sample(&mut rng)
35        })
36        .collect()
37}
38
39// Use helper functions that take tensors as parameter to avoid borrow conflicts
40fn save_linear_helper(
41    tensors: &mut HashMap<String, (Vec<usize>, Vec<f32>)>,
42    name: &str,
43    linear: &Linear
44) -> Result<()> {
45    let weight = linear.weight();
46    let weight_shape = weight.dims().to_vec();
47    let weight_data = weight.flatten_all()
48        .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?
49        .to_vec1::<f32>()
50        .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
51    tensors.insert(format!("{}.weight", name), (weight_shape, weight_data));
52
53    if let Some(bias) = linear.bias() {
54        let bias_shape = bias.dims().to_vec();
55        let bias_data = bias.flatten_all()
56            .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?
57            .to_vec1::<f32>()
58            .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
59        tensors.insert(format!("{}.bias", name), (bias_shape, bias_data));
60    }
61    Ok(())
62}
63
64fn save_layernorm_helper(
65    tensors: &mut HashMap<String, (Vec<usize>, Vec<f32>)>,
66    name: &str,
67    ln: &LayerNorm
68) -> Result<()> {
69    let weight = ln.weight();
70    let weight_shape = weight.dims().to_vec();
71    let weight_data = weight.flatten_all()
72        .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?
73        .to_vec1::<f32>()
74        .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
75    tensors.insert(format!("{}.weight", name), (weight_shape, weight_data));
76
77    if let Some(bias) = ln.bias() {
78        let bias_shape = bias.dims().to_vec();
79        let bias_data = bias.flatten_all()
80            .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?
81            .to_vec1::<f32>()
82            .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
83        tensors.insert(format!("{}.bias", name), (bias_shape, bias_data));
84    }
85    Ok(())
86}
87
/// Actor-Critic network for PPO: a shared MLP encoder feeding a hybrid
/// policy head (discrete action logits + diagonal-Gaussian continuous
/// parameters) and a scalar value head.
#[allow(dead_code)]
pub struct ActorCriticNetwork {
    // Shared feature encoder: state_dim -> 512 -> 256 -> 128, each stage
    // followed by LayerNorm + ReLU (see `forward`).
    fc1: Linear,
    ln1: LayerNorm,
    fc2: Linear,
    ln2: LayerNorm,
    fc3: Linear,
    ln3: LayerNorm,
    // Actor head (policy)
    actor_discrete: Linear,    // logits over `num_actions` discrete actions
    actor_param_mean: Linear,  // means of `num_params` continuous params (tanh-squashed)
    // Learnable, state-independent log std for the Gaussian head.
    // NOTE(review): created as a bare Var, not via the VarBuilder, so it is
    // NOT in the VarMap — an optimizer built from varmap.all_vars() will
    // not update it; confirm this is intentional.
    actor_param_logstd: Var,  // Learnable log std

    // Critic head (value function): 128 -> 64 -> 1
    critic_fc1: Linear,
    critic_fc2: Linear,

    device: Device,       // device all tensors are created on
    num_actions: usize,   // discrete action space size
    num_params: usize,    // continuous parameter count
}
111
112
113impl ActorCriticNetwork {
114
115    pub fn new(
116        state_dim: usize,
117        num_actions: usize,
118        num_params: usize,
119        vb: VarBuilder,
120    ) -> candle_core::error::Result<Self> {
121        let device = vb.device().clone();
122        // Shared encoder
123        let fc1 = linear(state_dim, 512, vb.pp("fc1"))?;
124        let ln1 = layer_norm(512, 1e-5, vb.pp("ln1"))?;
125        let fc2 = linear(512, 256, vb.pp("fc2"))?;
126        let ln2 = layer_norm(256, 1e-5, vb.pp("ln2"))?;
127        let fc3 = linear(256, 128, vb.pp("fc3"))?;
128        let ln3 = layer_norm(128, 1e-5, vb.pp("ln3"))?;
129
130        // Actor
131        let actor_discrete = linear(128, num_actions, vb.pp("actor_discrete"))?;
132        let actor_param_mean = linear(128, num_params, vb.pp("actor_param_mean"))?;
133
134        // Initialize learnable log std
135        let logstd_init = Tensor::from_vec(
136            vec![-1.0f32; num_params],
137            &[num_params],
138            &device
139        )?;
140        let actor_param_logstd = Var::from_tensor(&logstd_init)?;
141
142        // Critic
143        let critic_fc1 = linear(128, 64, vb.pp("critic_fc1"))?;
144        let critic_fc2 = linear(64, 1, vb.pp("critic_fc2"))?;
145
146        Ok(Self {
147            fc1, ln1, fc2, ln2, fc3, ln3,
148            actor_discrete,
149            actor_param_mean,
150            actor_param_logstd,
151            critic_fc1,
152            critic_fc2,
153            device,
154            num_actions,
155            num_params,
156        })
157    }
158
159    pub fn forward(
160        &self,
161        state: &Tensor,
162        _training: bool,
163    ) -> candle_core::error::Result<(Tensor, Tensor, Tensor, Tensor)> {
164        // Shared features
165        let mut x = self.fc1.forward(state)?;
166        x = self.ln1.forward(&x)?;
167        x = x.relu()?;
168
169        x = self.fc2.forward(&x)?;
170        x = self.ln2.forward(&x)?;
171        x = x.relu()?;
172
173        x = self.fc3.forward(&x)?;
174        x = self.ln3.forward(&x)?;
175        let features = x.relu()?;
176
177        // Actor outputs
178        let action_logits = self.actor_discrete.forward(&features)?;
179        let param_mean = self.actor_param_mean.forward(&features)?.tanh()?;
180        let param_std = self.actor_param_logstd.as_tensor().exp()?;
181
182        // Critic output
183        let mut value = self.critic_fc1.forward(&features)?;
184        value = value.relu()?;
185        let value = self.critic_fc2.forward(&value)?.squeeze(1)?;
186
187        Ok((action_logits, param_mean, param_std, value))
188    }
189
190    /// Save PPO model to file with metadata
191    pub fn save_to_file(&self, path: &Path, metadata: ModelMetadata) -> Result<()> {
192        // This method already exists in the file, just ensure it's properly visible
193        // The implementation around line 130-220 is already correct
194        use std::fs::File;
195        use std::io::Write;
196        let mut file = File::create(path)?;
197
198        // Write metadata
199        let metadata_json = serde_json::to_string(&metadata)
200            .map_err(|e| crate::ExtractionError::ParseError(e.to_string()))?;
201        let metadata_bytes = metadata_json.as_bytes();
202        let metadata_len = metadata_bytes.len() as u64;
203
204        file.write_all(&metadata_len.to_le_bytes())?;
205        file.write_all(metadata_bytes)?;
206
207        // Collect all tensors
208        let mut tensors: HashMap<String, (Vec<usize>, Vec<f32>)> = HashMap::new();
209
210        // Save all network components using helper functions
211        save_linear_helper(&mut tensors, "fc1", &self.fc1)?;
212        save_layernorm_helper(&mut tensors, "ln1", &self.ln1)?;
213        save_linear_helper(&mut tensors, "fc2", &self.fc2)?;
214        save_layernorm_helper(&mut tensors, "ln2", &self.ln2)?;
215        save_linear_helper(&mut tensors, "fc3", &self.fc3)?;
216        save_layernorm_helper(&mut tensors, "ln3", &self.ln3)?;
217
218        save_linear_helper(&mut tensors, "actor_discrete", &self.actor_discrete)?;
219        save_linear_helper(&mut tensors, "actor_param_mean", &self.actor_param_mean)?;
220
221        // Save learnable log std
222        let logstd_tensor = self.actor_param_logstd.as_tensor();
223        let logstd_shape = logstd_tensor.dims().to_vec();
224        let logstd_data = logstd_tensor.flatten_all()
225            .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?
226            .to_vec1::<f32>()
227            .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
228        tensors.insert("actor_param_logstd".to_string(), (logstd_shape, logstd_data));
229
230        save_linear_helper(&mut tensors, "critic_fc1", &self.critic_fc1)?;
231        save_linear_helper(&mut tensors, "critic_fc2", &self.critic_fc2)?;
232
233        // Write tensor count
234        let tensor_count = tensors.len() as u64;
235        file.write_all(&tensor_count.to_le_bytes())?;
236
237        // Write each tensor
238        for (name, (shape, data)) in tensors.iter() {
239            let name_bytes = name.as_bytes();
240            let name_len = name_bytes.len() as u64;
241            file.write_all(&name_len.to_le_bytes())?;
242            file.write_all(name_bytes)?;
243
244            let shape_len = shape.len() as u64;
245            file.write_all(&shape_len.to_le_bytes())?;
246            for &dim in shape {
247                file.write_all(&(dim as u64).to_le_bytes())?;
248            }
249
250            let data_len = data.len() as u64;
251            file.write_all(&data_len.to_le_bytes())?;
252            for &value in data {
253                file.write_all(&value.to_le_bytes())?;
254            }
255        }
256
257        let file_size = std::fs::metadata(path)?.len();
258        tracing::info!("PPO model saved: {} bytes", file_size);
259
260        Ok(())
261    }
262
263    /// Load PPO model from file - returns network and varmap
264    pub fn load_from_file(
265        path: &Path,
266        state_dim: usize,
267        num_actions: usize,
268        num_params: usize,
269        device: &Device,
270    ) -> Result<(Self, VarMap)> {  // FIXED: Return tuple
271        use std::fs::File;
272        use std::io::Read;
273
274        tracing::info!("Loading PPO model from: {}", path.display());
275
276        let mut file = File::open(path)?;
277
278        // Read metadata
279        let mut metadata_len_bytes = [0u8; 8];
280        file.read_exact(&mut metadata_len_bytes)?;
281        let metadata_len = u64::from_le_bytes(metadata_len_bytes) as usize;
282        if metadata_len > 10 * 1024 * 1024 {
283            return Err(crate::ExtractionError::ParseError(format!("Invalid model file: metadata length {} is too large", metadata_len)));
284        }
285
286        let mut metadata_bytes = vec![0u8; metadata_len];
287        file.read_exact(&mut metadata_bytes)?;
288
289        let metadata_json = String::from_utf8(metadata_bytes)
290            .map_err(|e| crate::ExtractionError::ParseError(e.to_string()))?;
291        let _metadata: ModelMetadata = serde_json::from_str(&metadata_json)
292            .map_err(|e| crate::ExtractionError::ParseError(e.to_string()))?;
293
294        tracing::info!("Model metadata loaded, loading tensors...");
295
296        // Read tensor count
297        let mut tensor_count_bytes = [0u8; 8];
298        file.read_exact(&mut tensor_count_bytes)?;
299        let tensor_count = u64::from_le_bytes(tensor_count_bytes) as usize;
300
301        let mut tensors: HashMap<String, Tensor> = HashMap::new();
302
303        for _ in 0..tensor_count {
304            let mut name_len_bytes = [0u8; 8];
305            file.read_exact(&mut name_len_bytes)?;
306            let name_len = u64::from_le_bytes(name_len_bytes) as usize;
307
308            let mut name_bytes = vec![0u8; name_len];
309            file.read_exact(&mut name_bytes)?;
310            let name = String::from_utf8(name_bytes)
311                .map_err(|e| crate::ExtractionError::ParseError(e.to_string()))?;
312
313            let mut shape_len_bytes = [0u8; 8];
314            file.read_exact(&mut shape_len_bytes)?;
315            let shape_len = u64::from_le_bytes(shape_len_bytes) as usize;
316
317            let mut shape = Vec::with_capacity(shape_len);
318            for _ in 0..shape_len {
319                let mut dim_bytes = [0u8; 8];
320                file.read_exact(&mut dim_bytes)?;
321                shape.push(u64::from_le_bytes(dim_bytes) as usize);
322            }
323
324            let mut data_len_bytes = [0u8; 8];
325            file.read_exact(&mut data_len_bytes)?;
326            let data_len = u64::from_le_bytes(data_len_bytes) as usize;
327
328            let mut data = Vec::with_capacity(data_len);
329            for _ in 0..data_len {
330                let mut value_bytes = [0u8; 4];
331                file.read_exact(&mut value_bytes)?;
332                data.push(f32::from_le_bytes(value_bytes));
333            }
334
335            let tensor = Tensor::from_vec(data, shape.as_slice(), device)
336                .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
337            tensors.insert(name, tensor);
338        }
339
340        tracing::info!("Loaded {} tensors, reconstructing model...", tensors.len());
341
342        // Create network first to populate varmap with correct keys, then overwrite with loaded values
343        let mut varmap = VarMap::new();
344        let vb = VarBuilder::from_varmap(&varmap, DType::F32, device);
345        let mut network = ActorCriticNetwork::new(state_dim, num_actions, num_params, vb)
346            .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
347
348        for (name, tensor) in tensors.iter() {
349            if name == "actor_param_logstd" {
350                network.actor_param_logstd = Var::from_tensor(tensor)
351                    .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
352            } else {
353                varmap.set_one(name, tensor)
354                    .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
355            }
356        }
357
358        Ok((network, varmap))
359    }
360
    /// Backwards-compatible alias for [`Self::load_from_file`]; kept so
    /// existing callers using the older name continue to work.
    pub fn load_with_device(
        path: &Path,
        state_dim: usize,
        num_actions: usize,
        num_params: usize,
        device: &Device,
    ) -> Result<(Self, VarMap)> {
        Self::load_from_file(path, state_dim, num_actions, num_params, device)
    }
371
372    /// Save to SafeTensors format
373    #[allow(dead_code)]
374    pub(crate) fn save_to_safetensors(&self, path: &PathBuf) -> Result<()> {
375        use safetensors::tensor::{Dtype, TensorView};
376        use std::collections::HashMap;
377
378        let mut tensors_data: HashMap<String, TensorView> = HashMap::new();
379        let mut all_tensor_bytes: Vec<(String, Vec<usize>, Vec<u8>)> = Vec::new();
380
381        // Collect all tensors
382        let mut collect_tensor = |name: &str, tensor: &Tensor| -> Result<()> {
383            let shape = tensor.dims().to_vec();
384            let data = tensor.flatten_all()
385                .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?
386                .to_vec1::<f32>()
387                .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
388            let bytes: Vec<u8> = data.iter()
389                .flat_map(|&f| f.to_le_bytes())
390                .collect();
391
392            all_tensor_bytes.push((name.to_string(), shape, bytes));
393            Ok(())
394        };
395
396        // Save all network components
397        collect_tensor("fc1.weight", self.fc1.weight())?;
398        if let Some(bias) = self.fc1.bias() {
399            collect_tensor("fc1.bias", bias)?;
400        }
401
402        collect_tensor("ln1.weight", self.ln1.weight())?;
403        if let Some(bias) = self.ln1.bias() {
404            collect_tensor("ln1.bias", bias)?;
405        }
406
407        collect_tensor("fc2.weight", self.fc2.weight())?;
408        if let Some(bias) = self.fc2.bias() {
409            collect_tensor("fc2.bias", bias)?;
410        }
411
412        collect_tensor("ln2.weight", self.ln2.weight())?;
413        if let Some(bias) = self.ln2.bias() {
414            collect_tensor("ln2.bias", bias)?;
415        }
416
417        collect_tensor("fc3.weight", self.fc3.weight())?;
418        if let Some(bias) = self.fc3.bias() {
419            collect_tensor("fc3.bias", bias)?;
420        }
421
422        collect_tensor("ln3.weight", self.ln3.weight())?;
423        if let Some(bias) = self.ln3.bias() {
424            collect_tensor("ln3.bias", bias)?;
425        }
426
427        collect_tensor("actor_discrete.weight", self.actor_discrete.weight())?;
428        if let Some(bias) = self.actor_discrete.bias() {
429            collect_tensor("actor_discrete.bias", bias)?;
430        }
431
432        collect_tensor("actor_param_mean.weight", self.actor_param_mean.weight())?;
433        if let Some(bias) = self.actor_param_mean.bias() {
434            collect_tensor("actor_param_mean.bias", bias)?;
435        }
436
437        collect_tensor("actor_param_logstd", self.actor_param_logstd.as_tensor())?;
438
439        collect_tensor("critic_fc1.weight", self.critic_fc1.weight())?;
440        if let Some(bias) = self.critic_fc1.bias() {
441            collect_tensor("critic_fc1.bias", bias)?;
442        }
443
444        collect_tensor("critic_fc2.weight", self.critic_fc2.weight())?;
445        if let Some(bias) = self.critic_fc2.bias() {
446            collect_tensor("critic_fc2.bias", bias)?;
447        }
448
449        // Convert to SafeTensors format
450        for (name, shape, bytes) in &all_tensor_bytes {
451            tensors_data.insert(
452                name.clone(),
453                TensorView::new(Dtype::F32, shape.clone(), bytes)
454                    .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?
455            );
456        }
457
458        let serialized = safetensors::serialize(&tensors_data, None)
459            .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
460
461        std::fs::write(path, serialized)?;
462
463        tracing::info!("PPO model saved to SafeTensors: {} bytes",
464                   std::fs::metadata(path).map(|m| m.len()).unwrap_or(0));
465
466        Ok(())
467    }
468
    /// NOTE(review): despite the name, this does NOT emit ONNX — it simply
    /// delegates to [`Self::save_to_file`]'s custom binary format. Rename it
    /// or implement real ONNX export before relying on the file extension.
    #[allow(dead_code)]
    pub(crate) fn save_to_onnx_with_metadata(&self, path: &Path, metadata: ModelMetadata) -> Result<()> {
        self.save_to_file(path, metadata)
    }
474
475}
476
/// PPO Agent: actor-critic network plus an AdamW optimizer and the PPO
/// hyperparameters consumed by `ppo_update` / `train_step`.
pub struct PPOAgent {
    network: ActorCriticNetwork,
    optimizer: AdamW,
    #[allow(dead_code)]
    varmap: VarMap,        // owns the trainable variables the optimizer updates
    // PPO hyperparameters
    clip_epsilon: f32,     // surrogate-objective clip range (set to 0.2 in constructors)
    gae_lambda: f32,       // GAE lambda (bias/variance trade-off)
    value_loss_coef: f32,  // weight of the value loss in the total loss
    entropy_coef: f32,     // weight of the entropy bonus
    ppo_epochs: usize,     // gradient passes per sampled batch

    num_actions: usize,    // discrete action space size
    num_params: usize,     // continuous parameter count
    gamma: f32,            // discount factor
    step_count: usize,     // completed train_step calls
    device: Device,        // device tensors are created on
}
496
497impl PPOAgent {
498    pub fn new(
499        state_dim: usize,
500        num_actions: usize,
501        num_params: usize,
502        gamma: f32,
503        lr: f64,
504        device: &Device,
505        varmap: VarMap,
506    ) -> Result<Self> {
507        let vb = VarBuilder::from_varmap(&varmap, DType::F32, device);
508        let network = ActorCriticNetwork::new(state_dim, num_actions, num_params, vb)
509            .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
510        let trainable_vars = varmap.all_vars();
511        let params = ParamsAdamW {
512            lr,
513            beta1: 0.9,
514            beta2: 0.999,
515            eps: 1e-8,
516            weight_decay: 0.0,
517        };
518
519        let optimizer = AdamW::new(trainable_vars, params)
520            .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
521
522        Ok(Self {
523            network,
524            optimizer,
525            varmap,
526            clip_epsilon: 0.2,
527            gae_lambda: 0.95,
528            value_loss_coef: 0.5,
529            entropy_coef: 0.01,
530            ppo_epochs: 4,
531            num_actions,
532            num_params,
533            gamma,
534            step_count: 0,
535            device: device.clone(),
536        })
537    }
538
539    /// Calculate Generalized Advantage Estimation (GAE)
540    fn calculate_gae(
541        &self,
542        rewards: &[f32],
543        values: &[f32],
544        next_value: f32,
545        dones: &[bool],
546    ) -> (Vec<f32>, Vec<f32>) {
547        let mut advantages = vec![0.0; rewards.len()];
548        let mut returns = vec![0.0; rewards.len()];
549
550        let mut gae = 0.0;
551        let mut next_val = next_value;
552
553        for t in (0..rewards.len()).rev() {
554            let done_mask = if dones[t] { 0.0 } else { 1.0 };
555            let delta = rewards[t] + self.gamma * next_val * done_mask - values[t];
556            gae = delta + self.gamma * self.gae_lambda * done_mask * gae;
557            advantages[t] = gae;
558            returns[t] = gae + values[t];
559            next_val = values[t];
560        }
561
562        (advantages, returns)
563    }
564
    /// Log-probability of each taken discrete action.
    ///
    /// `logits` is [batch, num_actions]; `actions` is a [batch] tensor of
    /// action indices. Returns a [batch] tensor: log_softmax(logits)
    /// gathered at each action index.
    fn discrete_log_prob(
        logits: &Tensor,
        actions: &Tensor,
    ) -> candle_core::error::Result<Tensor> {
        let log_probs = candle_nn::ops::log_softmax(logits, 1)?;
        // gather needs the index tensor shaped [batch, 1]; squeeze back to [batch].
        log_probs.gather(&actions.unsqueeze(1)?, 1)?.squeeze(1)
    }
573
    /// Log-probability of continuous actions under a diagonal Gaussian,
    /// summed over the parameter dimension.
    ///
    /// `mean`: [batch, num_params]; `std`: [num_params] (shared across the
    /// batch); `actions`: [batch, num_params]. Returns a [batch] tensor of
    /// `sum_j log N(action_j | mean_j, std_j^2)`.
    fn continuous_log_prob(
        mean: &Tensor,
        std: &Tensor,
        actions: &Tensor,
    ) -> candle_core::error::Result<Tensor> {
        // Get dimensions
        let batch_size = mean.dims()[0];
        let num_params = mean.dims()[1];

        // Broadcast std to match mean shape
        // std is [num_params], need [batch_size, num_params]
        let std_broadcast = std.unsqueeze(0)?.broadcast_as(mean.shape())?;
        let variance = std_broadcast.sqr()?;
        let diff = (actions - mean)?;

        // 2*pi constant materialized at [batch_size, num_params] so the
        // element-wise sum below needs no implicit broadcasting.
        let pi_constant = Tensor::new(
            vec![2.0 * std::f32::consts::PI; batch_size * num_params],
            mean.device()
        )?.reshape(&[batch_size, num_params])?;

        // log N(x|mu,var) = -0.5 * ((x-mu)^2/var + log var + log 2*pi), per element.
        let log_prob = -0.5 * (
            diff.sqr()?.div(&variance)? +
                variance.log()? +
                pi_constant.log()?
        )?;

        // Independent dimensions: total log-prob is the sum over params.
        log_prob?.sum(1)
    }
604
605    /// Calculate entropy for exploration bonus
606    fn calculate_entropy(
607        logits: &Tensor,
608        std: &Tensor,
609    ) -> candle_core::error::Result<Tensor> {
610        // Discrete entropy
611        let probs = candle_nn::ops::softmax(logits, 1)?;
612        let log_probs = candle_nn::ops::log_softmax(logits, 1)?;
613        let discrete_entropy = -1.0 * (probs * log_probs)?.sum(1)?.mean_all()?;
614
615        // Continuous entropy (Gaussian)
616        // std is [num_params], create constant with same shape
617        let num_params = std.dims()[0];
618        let constant = Tensor::new(
619            vec![0.5 * (1.0 + 2.0 * std::f32::consts::PI).ln(); num_params],
620            std.device()
621        )?;
622
623        let continuous_entropy = (std.log()? + constant)?.mean_all()?;
624
625        discrete_entropy + continuous_entropy
626    }
627
628    /// PPO update step
629    fn ppo_update(
630        &mut self,
631        states: &Tensor,
632        actions_discrete: &Tensor,
633        actions_continuous: &Tensor,
634        old_log_probs: &Tensor,
635        advantages: &Tensor,
636        returns: &Tensor,
637    ) -> Result<(f32, f32, f32)> {
638        // Forward pass
639        let (action_logits, param_mean, param_std, values) =
640            self.network.forward(states, true)
641                .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
642
643        // Calculate current log probabilities
644        let log_probs_discrete = Self::discrete_log_prob(&action_logits, actions_discrete)
645            .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
646        let log_probs_continuous = Self::continuous_log_prob(&param_mean, &param_std, actions_continuous)
647            .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
648        let log_probs = (log_probs_discrete + log_probs_continuous)
649            .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
650
651        // PPO clipped objective
652        let ratio = (log_probs.clone() - old_log_probs)?.exp()?;
653
654        // FIXED: Normalize advantages with proper shape handling
655        let batch_size = advantages.dims()[0];
656
657        // Calculate mean and std as scalars
658        let adv_mean_scalar = advantages.mean_all()?.to_scalar::<f32>()?;
659        let adv_variance = advantages.sub(&Tensor::new(&[adv_mean_scalar], advantages.device())?.broadcast_as(advantages.shape())?)?.sqr()?.mean_all()?;
660        let adv_std_scalar = (adv_variance.to_scalar::<f32>()? + 1e-8).sqrt();
661
662        // Create broadcast-able tensors
663        let adv_mean_broadcast = Tensor::new(vec![adv_mean_scalar; batch_size], advantages.device())?;
664        let adv_std_broadcast = Tensor::new(vec![adv_std_scalar; batch_size], advantages.device())?;
665
666        // Now shapes match: [batch_size] - [batch_size] / [batch_size]
667        let advantages_norm = ((advantages - &adv_mean_broadcast)? / &adv_std_broadcast)?;
668
669        let surr1 = (ratio.clone() * &advantages_norm)?;
670
671        let ratio_clipped = ratio.clamp(1.0 - self.clip_epsilon, 1.0 + self.clip_epsilon)?;
672        let surr2 = (ratio_clipped * advantages_norm)?;
673
674        let policy_loss = (-1.0 * surr1.minimum(&surr2)?.mean_all()?)?;
675
676        // Value loss with clipping
677        let value_loss = (values - returns)?.sqr()?.mean_all()?;
678
679        // Entropy bonus
680        let entropy = Self::calculate_entropy(&action_logits, &param_std)?;
681
682        // Total loss - combine as scalars
683        let value_loss_weighted = value_loss.to_scalar::<f32>()? * self.value_loss_coef;
684        let entropy_weighted = entropy.to_scalar::<f32>()? * self.entropy_coef;
685        let policy_loss_scalar = policy_loss.to_scalar::<f32>()?;
686
687        let total_loss_scalar = policy_loss_scalar + value_loss_weighted - entropy_weighted;
688
689        // Create tensor from combined scalar for backward pass
690        let total_loss = Tensor::new(&[total_loss_scalar], policy_loss.device())?;
691
692        // Backward and optimize
693        let grads = total_loss.backward()
694            .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
695
696        self.optimizer.step(&grads)
697            .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
698
699        Ok((
700            policy_loss_scalar,
701            value_loss.to_scalar::<f32>()?,
702            entropy.to_scalar::<f32>()?,
703        ))
704    }
705
    /// Reconstruct a PPO agent from a model file written by
    /// `ActorCriticNetwork::save_to_file`.
    ///
    /// NOTE(review): the learning rate (3e-4), gamma (0.95) and the other
    /// PPO hyperparameters are hard-coded here and may differ from the
    /// values the model was trained with (`new` takes them as arguments) —
    /// confirm before resuming training from a checkpoint.
    pub fn load_with_device(
        path: &Path,
        state_dim: usize,
        num_actions: usize,
        num_params: usize,
        device: &Device,
    ) -> Result<Self> {
        let (network, varmap) = ActorCriticNetwork::load_from_file(
            path, state_dim, num_actions, num_params, device
        )?;

        // Fresh AdamW over the loaded variables (optimizer state is not
        // persisted in the model file).
        let trainable_vars = varmap.all_vars();
        let params = ParamsAdamW {
            lr: 3e-4,
            beta1: 0.9,
            beta2: 0.999,
            eps: 1e-8,
            weight_decay: 0.0,
        };

        let optimizer = AdamW::new(trainable_vars, params)
            .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;

        Ok(Self {
            network,
            optimizer,
            varmap,
            clip_epsilon: 0.2,
            gae_lambda: 0.95,
            value_loss_coef: 0.5,
            entropy_coef: 0.01,
            ppo_epochs: 4,
            num_actions,
            num_params,
            gamma: 0.95,
            step_count: 0,
            device: device.clone(),
        })
    }
746}
747impl RLAgent for PPOAgent {
748    fn select_action(&self, state: &[f32], _epsilon: f32) -> Result<(usize, Vec<f32>)> {
749        // PPO uses stochastic policy, not epsilon-greedy
750        let state_tensor = Tensor::from_vec(
751            state.to_vec(),
752            &[1, state.len()],
753            &self.device
754        )?;
755        let (action_logits, param_mean, param_std, _value) =
756            self.network.forward(&state_tensor, false)?;
757
758        // Sample discrete action from categorical distribution
759        let probs = candle_nn::ops::softmax(&action_logits, 1)?.to_vec2::<f32>()?;
760        let discrete_action = sample_categorical(&probs[0]);
761
762        // Sample continuous parameters from Gaussian
763        let mean_vec = param_mean.to_vec2::<f32>()?;
764        let std_vec = param_std.to_vec1::<f32>()?;
765        let continuous_params = sample_gaussian(&mean_vec[0], &std_vec);
766
767        Ok((discrete_action, continuous_params))
768    }
769
    /// Persist the network along with a metadata header describing the
    /// training run.
    ///
    /// NOTE(review): the state dimension passed to `ModelMetadata::new` is
    /// hard-coded to 300; it is not tracked on the agent — confirm this
    /// matches the encoder's actual input size at construction time.
    fn save_with_metadata(
        &self,
        path: &Path,
        training_episodes: usize,
        hyperparameters: HashMap<String, f64>,
    ) -> Result<()> {
        let metadata = ModelMetadata::new(
            300,
            self.num_actions,
            self.num_params,
            AlgorithmType::PPO,
            training_episodes,
            hyperparameters,
        );

        self.network.save_to_file(path, metadata)
    }
787
788    fn save(&self, path: &Path) -> Result<()> {
789        self.save_with_metadata(path, 0, std::collections::HashMap::new())
790    }
791
792    fn train_step(
793        &mut self,
794        replay_buffer: &mut PrioritizedReplayBuffer,
795        batch_size: usize,
796    ) -> Result<f32> {
797        let batch = replay_buffer.sample(batch_size);
798        if batch.is_none() {
799            return Ok(0.0);
800        }
801
802        let batch = batch.unwrap();
803        let experiences = &batch.experiences;
804
805        if experiences.is_empty() {
806            return Ok(0.0);
807        }
808
809        // Convert experiences to tensors
810        let state_dim = experiences[0].state.len();
811        let states_flat: Vec<f32> = experiences.iter()
812            .flat_map(|e| e.state.clone())
813            .collect();
814        let states_tensor = Tensor::from_vec(
815            states_flat,
816            &[experiences.len(), state_dim],
817            &self.device
818        )?;
819
820        // Get old policy values
821        let (old_logits, old_means, old_stds, old_values) =
822            self.network.forward(&states_tensor, false)?;
823
824        // Extract actions
825        let actions_discrete: Vec<i64> = experiences.iter()
826            .map(|e| e.action.0 as i64)
827            .collect();
828        let actions_discrete_tensor = Tensor::from_vec(
829            actions_discrete,
830            &[experiences.len()],
831            &self.device
832        )?;
833
834        let actions_continuous_flat: Vec<f32> = experiences.iter()
835            .flat_map(|e| e.action.1.clone())
836            .collect();
837        let actions_continuous_tensor = Tensor::from_vec(
838            actions_continuous_flat,
839            &[experiences.len(), self.num_params],
840            &self.device
841        )?;
842
843        // Calculate old log probabilities
844        let old_log_probs_discrete = Self::discrete_log_prob(&old_logits, &actions_discrete_tensor)?;
845        let old_log_probs_continuous = Self::continuous_log_prob(&old_means, &old_stds, &actions_continuous_tensor)?;
846        let old_log_probs = (old_log_probs_discrete + old_log_probs_continuous)?;
847
848        // Calculate GAE
849        let rewards: Vec<f32> = experiences.iter().map(|e| e.reward).collect();
850        let values_vec: Vec<f32> = old_values.to_vec1()?;
851        let dones: Vec<bool> = experiences.iter().map(|e| e.done).collect();
852
853        let (advantages, returns) = self.calculate_gae(
854            &rewards,
855            &values_vec,
856            0.0,
857            &dones,
858        );
859
860        let advantages_tensor = Tensor::from_vec(advantages, &[experiences.len()], &self.device)?;
861        let returns_tensor = Tensor::from_vec(returns, &[experiences.len()], &self.device)?;
862
863        // PPO update for multiple epochs
864        let mut total_policy_loss = 0.0;
865        let mut total_value_loss = 0.0;
866        let mut _total_entropy = 0.0;
867
868        for _ in 0..self.ppo_epochs {
869            let (policy_loss, value_loss, entropy) = self.ppo_update(
870                &states_tensor,
871                &actions_discrete_tensor,
872                &actions_continuous_tensor,
873                &old_log_probs,
874                &advantages_tensor,
875                &returns_tensor,
876            )?;
877
878            total_policy_loss += policy_loss;
879            total_value_loss += value_loss;
880            _total_entropy += entropy;
881        }
882
883        self.step_count += 1;
884
885        let avg_loss = (total_policy_loss + total_value_loss) / self.ppo_epochs as f32;
886        Ok(avg_loss)
887    }
888
    /// No-op: PPO is on-policy and has no target network to synchronize.
    fn update_target_network(&mut self) {
        // PPO doesn't use target networks
    }

    /// Number of completed `train_step` calls.
    fn get_step_count(&self) -> usize {
        self.step_count
    }

    fn algorithm_type(&self) -> AlgorithmType {
        AlgorithmType::PPO
    }

    /// Static description of this agent's configuration and capabilities.
    fn get_info(&self) -> AgentInfo {
        AgentInfo {
            algorithm: AlgorithmType::PPO,
            num_parameters: 0, // TODO: calculate
            state_dim: 0, // not tracked on the agent; see save_with_metadata's hard-coded 300
            num_actions: self.num_actions,
            continuous_params: self.num_params,
            version: "1.0.0".to_string(),
            features: vec![
                "actor_critic".to_string(),
                "clipped_objective".to_string(),
                "gae".to_string(),
                "entropy_bonus".to_string(),
            ],
        }
    }
917}
918
919// ADDITIONAL HELPER: Debug tensor shapes (for development)
920// Usage in ppo_update for debugging:
921// debug_tensor_shape("advantages", advantages);
922// debug_tensor_shape("adv_mean", &adv_mean);
923// debug_tensor_shape("adv_std", &adv_std);
924
925#[cfg(debug_assertions)]
926#[allow(dead_code)]
927fn debug_tensor_shape(name: &str, tensor: &Tensor) {
928    eprintln!("DEBUG: {} shape: {:?}", name, tensor.dims());
929}
930
931#[cfg(not(debug_assertions))]
932fn debug_tensor_shape(_name: &str, _tensor: &Tensor) {
933    // No-op in release builds
934}