1use candle_core::{Device, Tensor, DType, Var};
6use candle_nn::{VarBuilder, Optimizer, AdamW, ParamsAdamW, VarMap, Linear, Module, linear, layer_norm, LayerNorm};
7use crate::replay_buffer::{PrioritizedReplayBuffer};
8use crate::{Result, agents::{RLAgent, AlgorithmType, AgentInfo}};
9use rand::RngExt;
10use rand_distr::{Normal, Distribution};
11use std::path::{Path, PathBuf};
12use crate::models::ModelMetadata;
13use std::collections::HashMap;
14
15
16fn sample_categorical(probs: &[f32]) -> usize {
18 let mut rng = rand::rng();
19 let random_val: f32 = rng.random();
20 let mut cumsum = 0.0;
21 for (i, &prob) in probs.iter().enumerate() {
22 cumsum += prob;
23 if random_val < cumsum {
24 return i;
25 }
26 }
27 probs.len() - 1
28}
29fn sample_gaussian(means: &[f32], stds: &[f32]) -> Vec<f32> {
30 let mut rng = rand::rng();
31 means.iter().zip(stds.iter())
32 .map(|(&mean, &std)| {
33 let normal = Normal::new(mean, std).unwrap_or_else(|_| Normal::new(0.0, 1.0).unwrap());
34 normal.sample(&mut rng)
35 })
36 .collect()
37}
38
39fn save_linear_helper(
41 tensors: &mut HashMap<String, (Vec<usize>, Vec<f32>)>,
42 name: &str,
43 linear: &Linear
44) -> Result<()> {
45 let weight = linear.weight();
46 let weight_shape = weight.dims().to_vec();
47 let weight_data = weight.flatten_all()
48 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?
49 .to_vec1::<f32>()
50 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
51 tensors.insert(format!("{}.weight", name), (weight_shape, weight_data));
52
53 if let Some(bias) = linear.bias() {
54 let bias_shape = bias.dims().to_vec();
55 let bias_data = bias.flatten_all()
56 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?
57 .to_vec1::<f32>()
58 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
59 tensors.insert(format!("{}.bias", name), (bias_shape, bias_data));
60 }
61 Ok(())
62}
63
64fn save_layernorm_helper(
65 tensors: &mut HashMap<String, (Vec<usize>, Vec<f32>)>,
66 name: &str,
67 ln: &LayerNorm
68) -> Result<()> {
69 let weight = ln.weight();
70 let weight_shape = weight.dims().to_vec();
71 let weight_data = weight.flatten_all()
72 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?
73 .to_vec1::<f32>()
74 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
75 tensors.insert(format!("{}.weight", name), (weight_shape, weight_data));
76
77 if let Some(bias) = ln.bias() {
78 let bias_shape = bias.dims().to_vec();
79 let bias_data = bias.flatten_all()
80 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?
81 .to_vec1::<f32>()
82 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
83 tensors.insert(format!("{}.bias", name), (bias_shape, bias_data));
84 }
85 Ok(())
86}
87
/// Actor-critic network for a hybrid (discrete + continuous) action space:
/// a shared MLP trunk with LayerNorm feeds a discrete-action head, a
/// continuous-parameter Gaussian head, and a small critic (value) head.
#[allow(dead_code)]
pub struct ActorCriticNetwork {
    // Shared feature trunk: three Linear + LayerNorm stages.
    fc1: Linear,
    ln1: LayerNorm,
    fc2: Linear,
    ln2: LayerNorm,
    fc3: Linear,
    ln3: LayerNorm,
    // Actor heads: logits over discrete actions, and the mean of the
    // Gaussian over continuous parameters.
    actor_discrete: Linear,
    actor_param_mean: Linear,
    // Log-std of the continuous Gaussian, held as a free Var (not created
    // through the VarBuilder).
    // NOTE(review): this Var is not registered in the VarMap used to build
    // the optimizer elsewhere in this file — confirm it is meant to be
    // excluded from training.
    actor_param_logstd: Var, critic_fc1: Linear,
    critic_fc2: Linear,

    device: Device,
    num_actions: usize,
    num_params: usize,
}
111
112
impl ActorCriticNetwork {

    /// Build the network under `vb`, registering every Linear/LayerNorm
    /// weight with the caller's VarMap (prefixes "fc1", "ln1", ...).
    ///
    /// `actor_param_logstd` is created as a standalone `Var`, initialized
    /// to -1.0 per continuous parameter (so std = e^-1 ≈ 0.37).
    /// NOTE(review): that Var is NOT part of the VarBuilder's VarMap, so an
    /// optimizer built from the VarMap will not update it — confirm intent.
    pub fn new(
        state_dim: usize,
        num_actions: usize,
        num_params: usize,
        vb: VarBuilder,
    ) -> candle_core::error::Result<Self> {
        let device = vb.device().clone();
        // Shared trunk: state_dim -> 512 -> 256 -> 128.
        let fc1 = linear(state_dim, 512, vb.pp("fc1"))?;
        let ln1 = layer_norm(512, 1e-5, vb.pp("ln1"))?;
        let fc2 = linear(512, 256, vb.pp("fc2"))?;
        let ln2 = layer_norm(256, 1e-5, vb.pp("ln2"))?;
        let fc3 = linear(256, 128, vb.pp("fc3"))?;
        let ln3 = layer_norm(128, 1e-5, vb.pp("ln3"))?;

        // Actor heads over the 128-d trunk features.
        let actor_discrete = linear(128, num_actions, vb.pp("actor_discrete"))?;
        let actor_param_mean = linear(128, num_params, vb.pp("actor_param_mean"))?;

        let logstd_init = Tensor::from_vec(
            vec![-1.0f32; num_params],
            &[num_params],
            &device
        )?;
        let actor_param_logstd = Var::from_tensor(&logstd_init)?;

        // Critic head: 128 -> 64 -> 1.
        let critic_fc1 = linear(128, 64, vb.pp("critic_fc1"))?;
        let critic_fc2 = linear(64, 1, vb.pp("critic_fc2"))?;

        Ok(Self {
            fc1, ln1, fc2, ln2, fc3, ln3,
            actor_discrete,
            actor_param_mean,
            actor_param_logstd,
            critic_fc1,
            critic_fc2,
            device,
            num_actions,
            num_params,
        })
    }

    /// Forward pass for a batch of states.
    ///
    /// Returns `(action_logits, param_mean, param_std, value)`:
    /// * `action_logits` — unnormalized scores over discrete actions;
    /// * `param_mean` — tanh-squashed Gaussian means (each in [-1, 1]);
    /// * `param_std` — per-parameter std, `exp(logstd)` (state-independent);
    /// * `value` — critic estimate with the trailing dim squeezed away.
    ///
    /// `_training` is currently unused (no dropout/batch-norm to toggle).
    pub fn forward(
        &self,
        state: &Tensor,
        _training: bool,
    ) -> candle_core::error::Result<(Tensor, Tensor, Tensor, Tensor)> {
        // Trunk: Linear -> LayerNorm -> ReLU, three times.
        let mut x = self.fc1.forward(state)?;
        x = self.ln1.forward(&x)?;
        x = x.relu()?;

        x = self.fc2.forward(&x)?;
        x = self.ln2.forward(&x)?;
        x = x.relu()?;

        x = self.fc3.forward(&x)?;
        x = self.ln3.forward(&x)?;
        let features = x.relu()?;

        let action_logits = self.actor_discrete.forward(&features)?;
        let param_mean = self.actor_param_mean.forward(&features)?.tanh()?;
        let param_std = self.actor_param_logstd.as_tensor().exp()?;

        let mut value = self.critic_fc1.forward(&features)?;
        value = value.relu()?;
        // Squeeze [batch, 1] -> [batch].
        let value = self.critic_fc2.forward(&value)?.squeeze(1)?;

        Ok((action_logits, param_mean, param_std, value))
    }

    /// Serialize the model to a custom little-endian binary format:
    ///
    ///   [u64 metadata_len][metadata JSON bytes]
    ///   [u64 tensor_count]
    ///   per tensor: [u64 name_len][name bytes]
    ///               [u64 shape_len][u64 dim]...
    ///               [u64 data_len][f32 LE value]...
    ///
    /// Tensors are written in HashMap iteration order (unspecified), which
    /// is safe because `load_from_file` matches tensors by name.
    pub fn save_to_file(&self, path: &Path, metadata: ModelMetadata) -> Result<()> {
        use std::fs::File;
        use std::io::Write;
        let mut file = File::create(path)?;

        let metadata_json = serde_json::to_string(&metadata)
            .map_err(|e| crate::ExtractionError::ParseError(e.to_string()))?;
        let metadata_bytes = metadata_json.as_bytes();
        let metadata_len = metadata_bytes.len() as u64;

        file.write_all(&metadata_len.to_le_bytes())?;
        file.write_all(metadata_bytes)?;

        // Collect every parameter as (shape, flattened f32 data) keyed by
        // the same names `new` uses with the VarBuilder.
        let mut tensors: HashMap<String, (Vec<usize>, Vec<f32>)> = HashMap::new();

        save_linear_helper(&mut tensors, "fc1", &self.fc1)?;
        save_layernorm_helper(&mut tensors, "ln1", &self.ln1)?;
        save_linear_helper(&mut tensors, "fc2", &self.fc2)?;
        save_layernorm_helper(&mut tensors, "ln2", &self.ln2)?;
        save_linear_helper(&mut tensors, "fc3", &self.fc3)?;
        save_layernorm_helper(&mut tensors, "ln3", &self.ln3)?;

        save_linear_helper(&mut tensors, "actor_discrete", &self.actor_discrete)?;
        save_linear_helper(&mut tensors, "actor_param_mean", &self.actor_param_mean)?;

        // The log-std Var is extracted by hand (it is neither a Linear nor
        // a LayerNorm, so the helpers above do not apply).
        let logstd_tensor = self.actor_param_logstd.as_tensor();
        let logstd_shape = logstd_tensor.dims().to_vec();
        let logstd_data = logstd_tensor.flatten_all()
            .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?
            .to_vec1::<f32>()
            .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
        tensors.insert("actor_param_logstd".to_string(), (logstd_shape, logstd_data));

        save_linear_helper(&mut tensors, "critic_fc1", &self.critic_fc1)?;
        save_linear_helper(&mut tensors, "critic_fc2", &self.critic_fc2)?;

        let tensor_count = tensors.len() as u64;
        file.write_all(&tensor_count.to_le_bytes())?;

        for (name, (shape, data)) in tensors.iter() {
            let name_bytes = name.as_bytes();
            let name_len = name_bytes.len() as u64;
            file.write_all(&name_len.to_le_bytes())?;
            file.write_all(name_bytes)?;

            let shape_len = shape.len() as u64;
            file.write_all(&shape_len.to_le_bytes())?;
            for &dim in shape {
                file.write_all(&(dim as u64).to_le_bytes())?;
            }

            let data_len = data.len() as u64;
            file.write_all(&data_len.to_le_bytes())?;
            // f32 values written one at a time, little-endian.
            for &value in data {
                file.write_all(&value.to_le_bytes())?;
            }
        }

        let file_size = std::fs::metadata(path)?.len();
        tracing::info!("PPO model saved: {} bytes", file_size);

        Ok(())
    }

    /// Reverse of `save_to_file`: read the binary format, rebuild a fresh
    /// network with `new`, then copy each saved tensor into the VarMap by
    /// name (the log-std Var is replaced directly on the struct).
    ///
    /// Returns the network together with its VarMap so the caller can
    /// build an optimizer over the loaded weights.
    pub fn load_from_file(
        path: &Path,
        state_dim: usize,
        num_actions: usize,
        num_params: usize,
        device: &Device,
    ) -> Result<(Self, VarMap)> {
        use std::fs::File;
        use std::io::Read;

        tracing::info!("Loading PPO model from: {}", path.display());

        let mut file = File::open(path)?;

        let mut metadata_len_bytes = [0u8; 8];
        file.read_exact(&mut metadata_len_bytes)?;
        let metadata_len = u64::from_le_bytes(metadata_len_bytes) as usize;
        // Sanity cap (10 MiB) so a corrupt/foreign file cannot trigger a
        // huge allocation.
        if metadata_len > 10 * 1024 * 1024 {
            return Err(crate::ExtractionError::ParseError(format!("Invalid model file: metadata length {} is too large", metadata_len)));
        }

        let mut metadata_bytes = vec![0u8; metadata_len];
        file.read_exact(&mut metadata_bytes)?;

        let metadata_json = String::from_utf8(metadata_bytes)
            .map_err(|e| crate::ExtractionError::ParseError(e.to_string()))?;
        // Parsed only to validate the header; the values are not used here.
        let _metadata: ModelMetadata = serde_json::from_str(&metadata_json)
            .map_err(|e| crate::ExtractionError::ParseError(e.to_string()))?;

        tracing::info!("Model metadata loaded, loading tensors...");

        let mut tensor_count_bytes = [0u8; 8];
        file.read_exact(&mut tensor_count_bytes)?;
        let tensor_count = u64::from_le_bytes(tensor_count_bytes) as usize;

        let mut tensors: HashMap<String, Tensor> = HashMap::new();

        for _ in 0..tensor_count {
            let mut name_len_bytes = [0u8; 8];
            file.read_exact(&mut name_len_bytes)?;
            let name_len = u64::from_le_bytes(name_len_bytes) as usize;

            let mut name_bytes = vec![0u8; name_len];
            file.read_exact(&mut name_bytes)?;
            let name = String::from_utf8(name_bytes)
                .map_err(|e| crate::ExtractionError::ParseError(e.to_string()))?;

            let mut shape_len_bytes = [0u8; 8];
            file.read_exact(&mut shape_len_bytes)?;
            let shape_len = u64::from_le_bytes(shape_len_bytes) as usize;

            let mut shape = Vec::with_capacity(shape_len);
            for _ in 0..shape_len {
                let mut dim_bytes = [0u8; 8];
                file.read_exact(&mut dim_bytes)?;
                shape.push(u64::from_le_bytes(dim_bytes) as usize);
            }

            let mut data_len_bytes = [0u8; 8];
            file.read_exact(&mut data_len_bytes)?;
            let data_len = u64::from_le_bytes(data_len_bytes) as usize;

            let mut data = Vec::with_capacity(data_len);
            for _ in 0..data_len {
                let mut value_bytes = [0u8; 4];
                file.read_exact(&mut value_bytes)?;
                data.push(f32::from_le_bytes(value_bytes));
            }

            let tensor = Tensor::from_vec(data, shape.as_slice(), device)
                .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
            tensors.insert(name, tensor);
        }

        tracing::info!("Loaded {} tensors, reconstructing model...", tensors.len());

        // Fresh network + VarMap; the random init below is immediately
        // overwritten by the saved weights.
        let mut varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, DType::F32, device);
        let mut network = ActorCriticNetwork::new(state_dim, num_actions, num_params, vb)
            .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;

        for (name, tensor) in tensors.iter() {
            if name == "actor_param_logstd" {
                // Not in the VarMap — replace the struct field directly.
                network.actor_param_logstd = Var::from_tensor(tensor)
                    .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
            } else {
                varmap.set_one(name, tensor)
                    .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
            }
        }

        Ok((network, varmap))
    }

    /// Thin alias of `load_from_file` (kept for API symmetry).
    pub fn load_with_device(
        path: &Path,
        state_dim: usize,
        num_actions: usize,
        num_params: usize,
        device: &Device,
    ) -> Result<(Self, VarMap)> {
        Self::load_from_file(path, state_dim, num_actions, num_params, device)
    }

    /// Export all parameters to the SafeTensors format (F32).
    ///
    /// Two-phase: tensor bytes are first collected into `all_tensor_bytes`
    /// (owned storage), then `TensorView`s borrowing those buffers are
    /// handed to `safetensors::serialize`.
    #[allow(dead_code)]
    pub(crate) fn save_to_safetensors(&self, path: &PathBuf) -> Result<()> {
        use safetensors::tensor::{Dtype, TensorView};
        use std::collections::HashMap;

        let mut tensors_data: HashMap<String, TensorView> = HashMap::new();
        let mut all_tensor_bytes: Vec<(String, Vec<usize>, Vec<u8>)> = Vec::new();

        // Flatten a tensor to little-endian f32 bytes and stash it.
        let mut collect_tensor = |name: &str, tensor: &Tensor| -> Result<()> {
            let shape = tensor.dims().to_vec();
            let data = tensor.flatten_all()
                .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?
                .to_vec1::<f32>()
                .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
            let bytes: Vec<u8> = data.iter()
                .flat_map(|&f| f.to_le_bytes())
                .collect();

            all_tensor_bytes.push((name.to_string(), shape, bytes));
            Ok(())
        };

        collect_tensor("fc1.weight", self.fc1.weight())?;
        if let Some(bias) = self.fc1.bias() {
            collect_tensor("fc1.bias", bias)?;
        }

        collect_tensor("ln1.weight", self.ln1.weight())?;
        if let Some(bias) = self.ln1.bias() {
            collect_tensor("ln1.bias", bias)?;
        }

        collect_tensor("fc2.weight", self.fc2.weight())?;
        if let Some(bias) = self.fc2.bias() {
            collect_tensor("fc2.bias", bias)?;
        }

        collect_tensor("ln2.weight", self.ln2.weight())?;
        if let Some(bias) = self.ln2.bias() {
            collect_tensor("ln2.bias", bias)?;
        }

        collect_tensor("fc3.weight", self.fc3.weight())?;
        if let Some(bias) = self.fc3.bias() {
            collect_tensor("fc3.bias", bias)?;
        }

        collect_tensor("ln3.weight", self.ln3.weight())?;
        if let Some(bias) = self.ln3.bias() {
            collect_tensor("ln3.bias", bias)?;
        }

        collect_tensor("actor_discrete.weight", self.actor_discrete.weight())?;
        if let Some(bias) = self.actor_discrete.bias() {
            collect_tensor("actor_discrete.bias", bias)?;
        }

        collect_tensor("actor_param_mean.weight", self.actor_param_mean.weight())?;
        if let Some(bias) = self.actor_param_mean.bias() {
            collect_tensor("actor_param_mean.bias", bias)?;
        }

        collect_tensor("actor_param_logstd", self.actor_param_logstd.as_tensor())?;

        collect_tensor("critic_fc1.weight", self.critic_fc1.weight())?;
        if let Some(bias) = self.critic_fc1.bias() {
            collect_tensor("critic_fc1.bias", bias)?;
        }

        collect_tensor("critic_fc2.weight", self.critic_fc2.weight())?;
        if let Some(bias) = self.critic_fc2.bias() {
            collect_tensor("critic_fc2.bias", bias)?;
        }

        // Build the borrowing views now that the byte buffers are stable.
        for (name, shape, bytes) in &all_tensor_bytes {
            tensors_data.insert(
                name.clone(),
                TensorView::new(Dtype::F32, shape.clone(), bytes)
                    .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?
            );
        }

        let serialized = safetensors::serialize(&tensors_data, None)
            .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;

        std::fs::write(path, serialized)?;

        tracing::info!("PPO model saved to SafeTensors: {} bytes",
            std::fs::metadata(path).map(|m| m.len()).unwrap_or(0));

        Ok(())
    }

    /// NOTE(review): despite the name, this does not produce ONNX — it
    /// delegates to `save_to_file` (the custom binary format).
    #[allow(dead_code)]
    pub(crate) fn save_to_onnx_with_metadata(&self, path: &Path, metadata: ModelMetadata) -> Result<()> {
        self.save_to_file(path, metadata)
    }

}
476
/// PPO agent over a hybrid (discrete + continuous) action space, built on
/// `ActorCriticNetwork` and optimized with AdamW.
pub struct PPOAgent {
    network: ActorCriticNetwork,
    optimizer: AdamW,
    #[allow(dead_code)]
    // Owns the trainable Vars the optimizer was constructed from.
    varmap: VarMap,
    // PPO hyperparameters.
    clip_epsilon: f32,    // clipping range for the probability ratio
    gae_lambda: f32,      // GAE lambda
    value_loss_coef: f32, // weight of the critic (value) loss
    entropy_coef: f32,    // weight of the entropy bonus
    ppo_epochs: usize,    // gradient epochs per train_step batch

    num_actions: usize,
    num_params: usize,
    gamma: f32,        // discount factor
    step_count: usize, // completed train_step calls
    device: Device,
}
496
497impl PPOAgent {
498 pub fn new(
499 state_dim: usize,
500 num_actions: usize,
501 num_params: usize,
502 gamma: f32,
503 lr: f64,
504 device: &Device,
505 varmap: VarMap,
506 ) -> Result<Self> {
507 let vb = VarBuilder::from_varmap(&varmap, DType::F32, device);
508 let network = ActorCriticNetwork::new(state_dim, num_actions, num_params, vb)
509 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
510 let trainable_vars = varmap.all_vars();
511 let params = ParamsAdamW {
512 lr,
513 beta1: 0.9,
514 beta2: 0.999,
515 eps: 1e-8,
516 weight_decay: 0.0,
517 };
518
519 let optimizer = AdamW::new(trainable_vars, params)
520 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
521
522 Ok(Self {
523 network,
524 optimizer,
525 varmap,
526 clip_epsilon: 0.2,
527 gae_lambda: 0.95,
528 value_loss_coef: 0.5,
529 entropy_coef: 0.01,
530 ppo_epochs: 4,
531 num_actions,
532 num_params,
533 gamma,
534 step_count: 0,
535 device: device.clone(),
536 })
537 }
538
539 fn calculate_gae(
541 &self,
542 rewards: &[f32],
543 values: &[f32],
544 next_value: f32,
545 dones: &[bool],
546 ) -> (Vec<f32>, Vec<f32>) {
547 let mut advantages = vec![0.0; rewards.len()];
548 let mut returns = vec![0.0; rewards.len()];
549
550 let mut gae = 0.0;
551 let mut next_val = next_value;
552
553 for t in (0..rewards.len()).rev() {
554 let done_mask = if dones[t] { 0.0 } else { 1.0 };
555 let delta = rewards[t] + self.gamma * next_val * done_mask - values[t];
556 gae = delta + self.gamma * self.gae_lambda * done_mask * gae;
557 advantages[t] = gae;
558 returns[t] = gae + values[t];
559 next_val = values[t];
560 }
561
562 (advantages, returns)
563 }
564
565 fn discrete_log_prob(
567 logits: &Tensor,
568 actions: &Tensor,
569 ) -> candle_core::error::Result<Tensor> {
570 let log_probs = candle_nn::ops::log_softmax(logits, 1)?;
571 log_probs.gather(&actions.unsqueeze(1)?, 1)?.squeeze(1)
572 }
573
574 fn continuous_log_prob(
576 mean: &Tensor,
577 std: &Tensor,
578 actions: &Tensor,
579 ) -> candle_core::error::Result<Tensor> {
580 let batch_size = mean.dims()[0];
582 let num_params = mean.dims()[1];
583
584 let std_broadcast = std.unsqueeze(0)?.broadcast_as(mean.shape())?;
587 let variance = std_broadcast.sqr()?;
588 let diff = (actions - mean)?;
589
590 let pi_constant = Tensor::new(
592 vec![2.0 * std::f32::consts::PI; batch_size * num_params],
593 mean.device()
594 )?.reshape(&[batch_size, num_params])?;
595
596 let log_prob = -0.5 * (
597 diff.sqr()?.div(&variance)? +
598 variance.log()? +
599 pi_constant.log()?
600 )?;
601
602 log_prob?.sum(1)
603 }
604
605 fn calculate_entropy(
607 logits: &Tensor,
608 std: &Tensor,
609 ) -> candle_core::error::Result<Tensor> {
610 let probs = candle_nn::ops::softmax(logits, 1)?;
612 let log_probs = candle_nn::ops::log_softmax(logits, 1)?;
613 let discrete_entropy = -1.0 * (probs * log_probs)?.sum(1)?.mean_all()?;
614
615 let num_params = std.dims()[0];
618 let constant = Tensor::new(
619 vec![0.5 * (1.0 + 2.0 * std::f32::consts::PI).ln(); num_params],
620 std.device()
621 )?;
622
623 let continuous_entropy = (std.log()? + constant)?.mean_all()?;
624
625 discrete_entropy + continuous_entropy
626 }
627
628 fn ppo_update(
630 &mut self,
631 states: &Tensor,
632 actions_discrete: &Tensor,
633 actions_continuous: &Tensor,
634 old_log_probs: &Tensor,
635 advantages: &Tensor,
636 returns: &Tensor,
637 ) -> Result<(f32, f32, f32)> {
638 let (action_logits, param_mean, param_std, values) =
640 self.network.forward(states, true)
641 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
642
643 let log_probs_discrete = Self::discrete_log_prob(&action_logits, actions_discrete)
645 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
646 let log_probs_continuous = Self::continuous_log_prob(¶m_mean, ¶m_std, actions_continuous)
647 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
648 let log_probs = (log_probs_discrete + log_probs_continuous)
649 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
650
651 let ratio = (log_probs.clone() - old_log_probs)?.exp()?;
653
654 let batch_size = advantages.dims()[0];
656
657 let adv_mean_scalar = advantages.mean_all()?.to_scalar::<f32>()?;
659 let adv_variance = advantages.sub(&Tensor::new(&[adv_mean_scalar], advantages.device())?.broadcast_as(advantages.shape())?)?.sqr()?.mean_all()?;
660 let adv_std_scalar = (adv_variance.to_scalar::<f32>()? + 1e-8).sqrt();
661
662 let adv_mean_broadcast = Tensor::new(vec![adv_mean_scalar; batch_size], advantages.device())?;
664 let adv_std_broadcast = Tensor::new(vec![adv_std_scalar; batch_size], advantages.device())?;
665
666 let advantages_norm = ((advantages - &adv_mean_broadcast)? / &adv_std_broadcast)?;
668
669 let surr1 = (ratio.clone() * &advantages_norm)?;
670
671 let ratio_clipped = ratio.clamp(1.0 - self.clip_epsilon, 1.0 + self.clip_epsilon)?;
672 let surr2 = (ratio_clipped * advantages_norm)?;
673
674 let policy_loss = (-1.0 * surr1.minimum(&surr2)?.mean_all()?)?;
675
676 let value_loss = (values - returns)?.sqr()?.mean_all()?;
678
679 let entropy = Self::calculate_entropy(&action_logits, ¶m_std)?;
681
682 let value_loss_weighted = value_loss.to_scalar::<f32>()? * self.value_loss_coef;
684 let entropy_weighted = entropy.to_scalar::<f32>()? * self.entropy_coef;
685 let policy_loss_scalar = policy_loss.to_scalar::<f32>()?;
686
687 let total_loss_scalar = policy_loss_scalar + value_loss_weighted - entropy_weighted;
688
689 let total_loss = Tensor::new(&[total_loss_scalar], policy_loss.device())?;
691
692 let grads = total_loss.backward()
694 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
695
696 self.optimizer.step(&grads)
697 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
698
699 Ok((
700 policy_loss_scalar,
701 value_loss.to_scalar::<f32>()?,
702 entropy.to_scalar::<f32>()?,
703 ))
704 }
705
706 pub fn load_with_device(
707 path: &Path,
708 state_dim: usize,
709 num_actions: usize,
710 num_params: usize,
711 device: &Device,
712 ) -> Result<Self> {
713 let (network, varmap) = ActorCriticNetwork::load_from_file(
714 path, state_dim, num_actions, num_params, device
715 )?;
716
717 let trainable_vars = varmap.all_vars();
719 let params = ParamsAdamW {
720 lr: 3e-4,
721 beta1: 0.9,
722 beta2: 0.999,
723 eps: 1e-8,
724 weight_decay: 0.0,
725 };
726
727 let optimizer = AdamW::new(trainable_vars, params)
728 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
729
730 Ok(Self {
731 network,
732 optimizer,
733 varmap,
734 clip_epsilon: 0.2,
735 gae_lambda: 0.95,
736 value_loss_coef: 0.5,
737 entropy_coef: 0.01,
738 ppo_epochs: 4,
739 num_actions,
740 num_params,
741 gamma: 0.95,
742 step_count: 0,
743 device: device.clone(),
744 })
745 }
746}
impl RLAgent for PPOAgent {
    /// Sample a hybrid action for `state`.
    ///
    /// PPO explores via its stochastic policy, so `_epsilon` is ignored
    /// (no epsilon-greedy). The discrete action is drawn from the softmax
    /// of the logits; continuous parameters from the diagonal Gaussian.
    fn select_action(&self, state: &[f32], _epsilon: f32) -> Result<(usize, Vec<f32>)> {
        // Batch of one: [1, state_dim].
        let state_tensor = Tensor::from_vec(
            state.to_vec(),
            &[1, state.len()],
            &self.device
        )?;
        let (action_logits, param_mean, param_std, _value) =
            self.network.forward(&state_tensor, false)?;

        let probs = candle_nn::ops::softmax(&action_logits, 1)?.to_vec2::<f32>()?;
        let discrete_action = sample_categorical(&probs[0]);

        let mean_vec = param_mean.to_vec2::<f32>()?;
        let std_vec = param_std.to_vec1::<f32>()?;
        let continuous_params = sample_gaussian(&mean_vec[0], &std_vec);

        Ok((discrete_action, continuous_params))
    }

    /// Save the network with a metadata header.
    /// NOTE(review): the state dimension is hard-coded to 300 here —
    /// confirm it matches the dimension the network was built with.
    fn save_with_metadata(
        &self,
        path: &Path,
        training_episodes: usize,
        hyperparameters: HashMap<String, f64>,
    ) -> Result<()> {
        let metadata = ModelMetadata::new(
            300,
            self.num_actions,
            self.num_params,
            AlgorithmType::PPO,
            training_episodes,
            hyperparameters,
        );

        self.network.save_to_file(path, metadata)
    }

    /// Save with zero episodes and empty hyperparameters.
    fn save(&self, path: &Path) -> Result<()> {
        self.save_with_metadata(path, 0, std::collections::HashMap::new())
    }

    /// One PPO training iteration over a batch from the replay buffer.
    ///
    /// Computes "old" log-probs and values once (pre-update policy), GAE
    /// advantages/returns, then runs `ppo_epochs` clipped-surrogate updates
    /// on the same batch. Returns the average (policy + value) loss over
    /// the epochs, or 0.0 when no batch is available.
    fn train_step(
        &mut self,
        replay_buffer: &mut PrioritizedReplayBuffer,
        batch_size: usize,
    ) -> Result<f32> {
        let batch = replay_buffer.sample(batch_size);
        if batch.is_none() {
            return Ok(0.0);
        }

        let batch = batch.unwrap();
        let experiences = &batch.experiences;

        if experiences.is_empty() {
            return Ok(0.0);
        }

        // Stack states into [batch, state_dim].
        let state_dim = experiences[0].state.len();
        let states_flat: Vec<f32> = experiences.iter()
            .flat_map(|e| e.state.clone())
            .collect();
        let states_tensor = Tensor::from_vec(
            states_flat,
            &[experiences.len(), state_dim],
            &self.device
        )?;

        // "Old" policy outputs, frozen before the update epochs begin.
        let (old_logits, old_means, old_stds, old_values) =
            self.network.forward(&states_tensor, false)?;

        let actions_discrete: Vec<i64> = experiences.iter()
            .map(|e| e.action.0 as i64)
            .collect();
        let actions_discrete_tensor = Tensor::from_vec(
            actions_discrete,
            &[experiences.len()],
            &self.device
        )?;

        let actions_continuous_flat: Vec<f32> = experiences.iter()
            .flat_map(|e| e.action.1.clone())
            .collect();
        let actions_continuous_tensor = Tensor::from_vec(
            actions_continuous_flat,
            &[experiences.len(), self.num_params],
            &self.device
        )?;

        let old_log_probs_discrete = Self::discrete_log_prob(&old_logits, &actions_discrete_tensor)?;
        let old_log_probs_continuous = Self::continuous_log_prob(&old_means, &old_stds, &actions_continuous_tensor)?;
        let old_log_probs = (old_log_probs_discrete + old_log_probs_continuous)?;

        let rewards: Vec<f32> = experiences.iter().map(|e| e.reward).collect();
        let values_vec: Vec<f32> = old_values.to_vec1()?;
        let dones: Vec<bool> = experiences.iter().map(|e| e.done).collect();

        // NOTE(review): the bootstrap value for the state after the last
        // experience is fixed at 0.0 — only correct if batches end at
        // episode boundaries; confirm against how the buffer is filled.
        let (advantages, returns) = self.calculate_gae(
            &rewards,
            &values_vec,
            0.0,
            &dones,
        );

        let advantages_tensor = Tensor::from_vec(advantages, &[experiences.len()], &self.device)?;
        let returns_tensor = Tensor::from_vec(returns, &[experiences.len()], &self.device)?;

        let mut total_policy_loss = 0.0;
        let mut total_value_loss = 0.0;
        let mut _total_entropy = 0.0;

        // Multiple optimization epochs over the same batch (standard PPO).
        for _ in 0..self.ppo_epochs {
            let (policy_loss, value_loss, entropy) = self.ppo_update(
                &states_tensor,
                &actions_discrete_tensor,
                &actions_continuous_tensor,
                &old_log_probs,
                &advantages_tensor,
                &returns_tensor,
            )?;

            total_policy_loss += policy_loss;
            total_value_loss += value_loss;
            _total_entropy += entropy;
        }

        self.step_count += 1;

        let avg_loss = (total_policy_loss + total_value_loss) / self.ppo_epochs as f32;
        Ok(avg_loss)
    }

    /// No-op: PPO has no target network; present only to satisfy the trait.
    fn update_target_network(&mut self) {
    }

    fn get_step_count(&self) -> usize {
        self.step_count
    }

    fn algorithm_type(&self) -> AlgorithmType {
        AlgorithmType::PPO
    }

    /// Static description of this agent.
    /// NOTE(review): `num_parameters` and `state_dim` are reported as 0
    /// (placeholders), not the real values.
    fn get_info(&self) -> AgentInfo {
        AgentInfo {
            algorithm: AlgorithmType::PPO,
            num_parameters: 0,
            state_dim: 0,
            num_actions: self.num_actions,
            continuous_params: self.num_params,
            version: "1.0.0".to_string(),
            features: vec![
                "actor_critic".to_string(),
                "clipped_objective".to_string(),
                "gae".to_string(),
                "entropy_bonus".to_string(),
            ],
        }
    }
}
918
919#[cfg(debug_assertions)]
926#[allow(dead_code)]
927fn debug_tensor_shape(name: &str, tensor: &Tensor) {
928 eprintln!("DEBUG: {} shape: {:?}", name, tensor.dims());
929}
930
/// Release-build stub of `debug_tensor_shape`: does nothing, so call sites
/// need no `cfg` guards.
#[cfg(not(debug_assertions))]
// Consistency fix: the debug-build twin carries #[allow(dead_code)]; without
// it here, an unused helper warns in release builds only.
#[allow(dead_code)]
fn debug_tensor_shape(_name: &str, _tensor: &Tensor) {
}