1use candle_core::{Device, Tensor, DType, Var};
6use candle_nn::{VarBuilder, Optimizer, AdamW, ParamsAdamW, VarMap, Linear, Module, linear, layer_norm, LayerNorm};
7use crate::replay_buffer::{PrioritizedReplayBuffer};
8use crate::{Result, agents::{RLAgent, AlgorithmType, AgentInfo}};
9use rand::RngExt;
10use rand_distr::{Normal, Distribution};
11use std::path::{Path, PathBuf};
12use crate::models::ModelMetadata;
13use std::collections::HashMap;
14
15
16fn sample_categorical(probs: &[f32]) -> usize {
18 let mut rng = rand::rng();
19 let random_val: f32 = rng.random();
20 let mut cumsum = 0.0;
21 for (i, &prob) in probs.iter().enumerate() {
22 cumsum += prob;
23 if random_val < cumsum {
24 return i;
25 }
26 }
27 probs.len() - 1
28}
29fn sample_gaussian(means: &[f32], stds: &[f32]) -> Vec<f32> {
30 let mut rng = rand::rng();
31 means.iter().zip(stds.iter())
32 .map(|(&mean, &std)| {
33 let normal = Normal::new(mean, std).unwrap_or_else(|_| Normal::new(0.0, 1.0).unwrap());
34 normal.sample(&mut rng)
35 })
36 .collect()
37}
38
39fn save_linear_helper(
41 tensors: &mut HashMap<String, (Vec<usize>, Vec<f32>)>,
42 name: &str,
43 linear: &Linear
44) -> Result<()> {
45 let weight = linear.weight();
46 let weight_shape = weight.dims().to_vec();
47 let weight_data = weight.flatten_all()
48 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?
49 .to_vec1::<f32>()
50 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
51 tensors.insert(format!("{}.weight", name), (weight_shape, weight_data));
52
53 if let Some(bias) = linear.bias() {
54 let bias_shape = bias.dims().to_vec();
55 let bias_data = bias.flatten_all()
56 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?
57 .to_vec1::<f32>()
58 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
59 tensors.insert(format!("{}.bias", name), (bias_shape, bias_data));
60 }
61 Ok(())
62}
63
64fn save_layernorm_helper(
65 tensors: &mut HashMap<String, (Vec<usize>, Vec<f32>)>,
66 name: &str,
67 ln: &LayerNorm
68) -> Result<()> {
69 let weight = ln.weight();
70 let weight_shape = weight.dims().to_vec();
71 let weight_data = weight.flatten_all()
72 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?
73 .to_vec1::<f32>()
74 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
75 tensors.insert(format!("{}.weight", name), (weight_shape, weight_data));
76
77 if let Some(bias) = ln.bias() {
78 let bias_shape = bias.dims().to_vec();
79 let bias_data = bias.flatten_all()
80 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?
81 .to_vec1::<f32>()
82 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
83 tensors.insert(format!("{}.bias", name), (bias_shape, bias_data));
84 }
85 Ok(())
86}
87
88#[allow(dead_code)]
90pub struct ActorCriticNetwork {
91 fc1: Linear,
93 ln1: LayerNorm,
94 fc2: Linear,
95 ln2: LayerNorm,
96 fc3: Linear,
97 ln3: LayerNorm,
98 actor_discrete: Linear,
100 actor_param_mean: Linear,
101 actor_param_logstd: Var, critic_fc1: Linear,
105 critic_fc2: Linear,
106
107 device: Device,
108 num_actions: usize,
109 num_params: usize,
110}
111
112
113impl ActorCriticNetwork {
114
115 pub fn new(
116 state_dim: usize,
117 num_actions: usize,
118 num_params: usize,
119 vb: VarBuilder,
120 ) -> candle_core::error::Result<Self> {
121 let device = vb.device().clone();
122 let fc1 = linear(state_dim, 512, vb.pp("fc1"))?;
124 let ln1 = layer_norm(512, 1e-5, vb.pp("ln1"))?;
125 let fc2 = linear(512, 256, vb.pp("fc2"))?;
126 let ln2 = layer_norm(256, 1e-5, vb.pp("ln2"))?;
127 let fc3 = linear(256, 128, vb.pp("fc3"))?;
128 let ln3 = layer_norm(128, 1e-5, vb.pp("ln3"))?;
129
130 let actor_discrete = linear(128, num_actions, vb.pp("actor_discrete"))?;
132 let actor_param_mean = linear(128, num_params, vb.pp("actor_param_mean"))?;
133
134 let logstd_init = Tensor::from_vec(
136 vec![-1.0f32; num_params],
137 &[num_params],
138 &device
139 )?;
140 let actor_param_logstd = Var::from_tensor(&logstd_init)?;
141
142 let critic_fc1 = linear(128, 64, vb.pp("critic_fc1"))?;
144 let critic_fc2 = linear(64, 1, vb.pp("critic_fc2"))?;
145
146 Ok(Self {
147 fc1, ln1, fc2, ln2, fc3, ln3,
148 actor_discrete,
149 actor_param_mean,
150 actor_param_logstd,
151 critic_fc1,
152 critic_fc2,
153 device,
154 num_actions,
155 num_params,
156 })
157 }
158
159 pub fn forward(
160 &self,
161 state: &Tensor,
162 _training: bool,
163 ) -> candle_core::error::Result<(Tensor, Tensor, Tensor, Tensor)> {
164 let mut x = self.fc1.forward(state)?;
166 x = self.ln1.forward(&x)?;
167 x = x.relu()?;
168
169 x = self.fc2.forward(&x)?;
170 x = self.ln2.forward(&x)?;
171 x = x.relu()?;
172
173 x = self.fc3.forward(&x)?;
174 x = self.ln3.forward(&x)?;
175 let features = x.relu()?;
176
177 let action_logits = self.actor_discrete.forward(&features)?;
179 let param_mean = self.actor_param_mean.forward(&features)?.tanh()?;
180 let param_std = self.actor_param_logstd.as_tensor().exp()?;
181
182 let mut value = self.critic_fc1.forward(&features)?;
184 value = value.relu()?;
185 let value = self.critic_fc2.forward(&value)?.squeeze(1)?;
186
187 Ok((action_logits, param_mean, param_std, value))
188 }
189
190 pub fn save_to_file(&self, path: &Path, metadata: ModelMetadata) -> Result<()> {
192 use std::fs::File;
195 use std::io::Write;
196 let mut file = File::create(path)?;
197
198 let metadata_json = serde_json::to_string(&metadata)
200 .map_err(|e| crate::ExtractionError::ParseError(e.to_string()))?;
201 let metadata_bytes = metadata_json.as_bytes();
202 let metadata_len = metadata_bytes.len() as u64;
203
204 file.write_all(&metadata_len.to_le_bytes())?;
205 file.write_all(metadata_bytes)?;
206
207 let mut tensors: HashMap<String, (Vec<usize>, Vec<f32>)> = HashMap::new();
209
210 save_linear_helper(&mut tensors, "fc1", &self.fc1)?;
212 save_layernorm_helper(&mut tensors, "ln1", &self.ln1)?;
213 save_linear_helper(&mut tensors, "fc2", &self.fc2)?;
214 save_layernorm_helper(&mut tensors, "ln2", &self.ln2)?;
215 save_linear_helper(&mut tensors, "fc3", &self.fc3)?;
216 save_layernorm_helper(&mut tensors, "ln3", &self.ln3)?;
217
218 save_linear_helper(&mut tensors, "actor_discrete", &self.actor_discrete)?;
219 save_linear_helper(&mut tensors, "actor_param_mean", &self.actor_param_mean)?;
220
221 let logstd_tensor = self.actor_param_logstd.as_tensor();
223 let logstd_shape = logstd_tensor.dims().to_vec();
224 let logstd_data = logstd_tensor.flatten_all()
225 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?
226 .to_vec1::<f32>()
227 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
228 tensors.insert("actor_param_logstd".to_string(), (logstd_shape, logstd_data));
229
230 save_linear_helper(&mut tensors, "critic_fc1", &self.critic_fc1)?;
231 save_linear_helper(&mut tensors, "critic_fc2", &self.critic_fc2)?;
232
233 let tensor_count = tensors.len() as u64;
235 file.write_all(&tensor_count.to_le_bytes())?;
236
237 for (name, (shape, data)) in tensors.iter() {
239 let name_bytes = name.as_bytes();
240 let name_len = name_bytes.len() as u64;
241 file.write_all(&name_len.to_le_bytes())?;
242 file.write_all(name_bytes)?;
243
244 let shape_len = shape.len() as u64;
245 file.write_all(&shape_len.to_le_bytes())?;
246 for &dim in shape {
247 file.write_all(&(dim as u64).to_le_bytes())?;
248 }
249
250 let data_len = data.len() as u64;
251 file.write_all(&data_len.to_le_bytes())?;
252 for &value in data {
253 file.write_all(&value.to_le_bytes())?;
254 }
255 }
256
257 let file_size = std::fs::metadata(path)?.len();
258 tracing::info!("PPO model saved: {} bytes", file_size);
259
260 Ok(())
261 }
262
263 pub fn load_from_file(
265 path: &Path,
266 state_dim: usize,
267 num_actions: usize,
268 num_params: usize,
269 device: &Device,
270 ) -> Result<(Self, VarMap)> { use std::fs::File;
272 use std::io::Read;
273
274 tracing::info!("Loading PPO model from: {}", path.display());
275
276 let mut file = File::open(path)?;
277
278 let mut metadata_len_bytes = [0u8; 8];
280 file.read_exact(&mut metadata_len_bytes)?;
281 let metadata_len = u64::from_le_bytes(metadata_len_bytes) as usize;
282 if metadata_len > 10 * 1024 * 1024 {
283 return Err(crate::ExtractionError::ParseError(format!("Invalid model file: metadata length {} is too large", metadata_len)));
284 }
285
286 let mut metadata_bytes = vec![0u8; metadata_len];
287 file.read_exact(&mut metadata_bytes)?;
288
289 let metadata_json = String::from_utf8(metadata_bytes)
290 .map_err(|e| crate::ExtractionError::ParseError(e.to_string()))?;
291 let _metadata: ModelMetadata = serde_json::from_str(&metadata_json)
292 .map_err(|e| crate::ExtractionError::ParseError(e.to_string()))?;
293
294 tracing::info!("Model metadata loaded, loading tensors...");
295
296 let mut tensor_count_bytes = [0u8; 8];
298 file.read_exact(&mut tensor_count_bytes)?;
299 let tensor_count = u64::from_le_bytes(tensor_count_bytes) as usize;
300
301 let mut tensors: HashMap<String, Tensor> = HashMap::new();
302
303 for _ in 0..tensor_count {
304 let mut name_len_bytes = [0u8; 8];
305 file.read_exact(&mut name_len_bytes)?;
306 let name_len = u64::from_le_bytes(name_len_bytes) as usize;
307
308 let mut name_bytes = vec![0u8; name_len];
309 file.read_exact(&mut name_bytes)?;
310 let name = String::from_utf8(name_bytes)
311 .map_err(|e| crate::ExtractionError::ParseError(e.to_string()))?;
312
313 let mut shape_len_bytes = [0u8; 8];
314 file.read_exact(&mut shape_len_bytes)?;
315 let shape_len = u64::from_le_bytes(shape_len_bytes) as usize;
316
317 let mut shape = Vec::with_capacity(shape_len);
318 for _ in 0..shape_len {
319 let mut dim_bytes = [0u8; 8];
320 file.read_exact(&mut dim_bytes)?;
321 shape.push(u64::from_le_bytes(dim_bytes) as usize);
322 }
323
324 let mut data_len_bytes = [0u8; 8];
325 file.read_exact(&mut data_len_bytes)?;
326 let data_len = u64::from_le_bytes(data_len_bytes) as usize;
327
328 let mut data = Vec::with_capacity(data_len);
329 for _ in 0..data_len {
330 let mut value_bytes = [0u8; 4];
331 file.read_exact(&mut value_bytes)?;
332 data.push(f32::from_le_bytes(value_bytes));
333 }
334
335 let tensor = Tensor::from_vec(data, shape.as_slice(), device)
336 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
337 tensors.insert(name, tensor);
338 }
339
340 tracing::info!("Loaded {} tensors, reconstructing model...", tensors.len());
341
342 let mut varmap = VarMap::new();
344 let vb = VarBuilder::from_varmap(&varmap, DType::F32, device);
345 let mut network = ActorCriticNetwork::new(state_dim, num_actions, num_params, vb)
346 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
347
348 for (name, tensor) in tensors.iter() {
349 if name == "actor_param_logstd" {
350 network.actor_param_logstd = Var::from_tensor(tensor)
351 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
352 } else {
353 varmap.set_one(name, tensor)
354 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
355 }
356 }
357
358 Ok((network, varmap))
359 }
360
/// Returns the log standard deviation variable of the continuous action head.
///
/// This variable is created outside the `VarBuilder`, so callers that build an
/// optimizer from a `VarMap` need it separately.
pub(crate) fn logstd_var(&self) -> &Var {
    &self.actor_param_logstd
}
365
/// Loads a network from `path` onto `device`.
///
/// Thin alias for [`Self::load_from_file`]; arguments and result are passed
/// through unchanged.
pub fn load_with_device(
    path: &Path,
    state_dim: usize,
    num_actions: usize,
    num_params: usize,
    device: &Device,
) -> Result<(Self, VarMap)> {
    Self::load_from_file(path, state_dim, num_actions, num_params, device)
}
376
377 #[allow(dead_code)]
379 pub(crate) fn save_to_safetensors(&self, path: &PathBuf) -> Result<()> {
380 use safetensors::tensor::{Dtype, TensorView};
381 use std::collections::HashMap;
382
383 let mut tensors_data: HashMap<String, TensorView> = HashMap::new();
384 let mut all_tensor_bytes: Vec<(String, Vec<usize>, Vec<u8>)> = Vec::new();
385
386 let mut collect_tensor = |name: &str, tensor: &Tensor| -> Result<()> {
388 let shape = tensor.dims().to_vec();
389 let data = tensor.flatten_all()
390 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?
391 .to_vec1::<f32>()
392 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
393 let bytes: Vec<u8> = data.iter()
394 .flat_map(|&f| f.to_le_bytes())
395 .collect();
396
397 all_tensor_bytes.push((name.to_string(), shape, bytes));
398 Ok(())
399 };
400
401 collect_tensor("fc1.weight", self.fc1.weight())?;
403 if let Some(bias) = self.fc1.bias() {
404 collect_tensor("fc1.bias", bias)?;
405 }
406
407 collect_tensor("ln1.weight", self.ln1.weight())?;
408 if let Some(bias) = self.ln1.bias() {
409 collect_tensor("ln1.bias", bias)?;
410 }
411
412 collect_tensor("fc2.weight", self.fc2.weight())?;
413 if let Some(bias) = self.fc2.bias() {
414 collect_tensor("fc2.bias", bias)?;
415 }
416
417 collect_tensor("ln2.weight", self.ln2.weight())?;
418 if let Some(bias) = self.ln2.bias() {
419 collect_tensor("ln2.bias", bias)?;
420 }
421
422 collect_tensor("fc3.weight", self.fc3.weight())?;
423 if let Some(bias) = self.fc3.bias() {
424 collect_tensor("fc3.bias", bias)?;
425 }
426
427 collect_tensor("ln3.weight", self.ln3.weight())?;
428 if let Some(bias) = self.ln3.bias() {
429 collect_tensor("ln3.bias", bias)?;
430 }
431
432 collect_tensor("actor_discrete.weight", self.actor_discrete.weight())?;
433 if let Some(bias) = self.actor_discrete.bias() {
434 collect_tensor("actor_discrete.bias", bias)?;
435 }
436
437 collect_tensor("actor_param_mean.weight", self.actor_param_mean.weight())?;
438 if let Some(bias) = self.actor_param_mean.bias() {
439 collect_tensor("actor_param_mean.bias", bias)?;
440 }
441
442 collect_tensor("actor_param_logstd", self.actor_param_logstd.as_tensor())?;
443
444 collect_tensor("critic_fc1.weight", self.critic_fc1.weight())?;
445 if let Some(bias) = self.critic_fc1.bias() {
446 collect_tensor("critic_fc1.bias", bias)?;
447 }
448
449 collect_tensor("critic_fc2.weight", self.critic_fc2.weight())?;
450 if let Some(bias) = self.critic_fc2.bias() {
451 collect_tensor("critic_fc2.bias", bias)?;
452 }
453
454 for (name, shape, bytes) in &all_tensor_bytes {
456 tensors_data.insert(
457 name.clone(),
458 TensorView::new(Dtype::F32, shape.clone(), bytes)
459 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?
460 );
461 }
462
463 let serialized = safetensors::serialize(&tensors_data, None)
464 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
465
466 std::fs::write(path, serialized)?;
467
468 tracing::info!("PPO model saved to SafeTensors: {} bytes",
469 std::fs::metadata(path).map(|m| m.len()).unwrap_or(0));
470
471 Ok(())
472 }
473
/// Saves the network together with `metadata` to `path`.
///
/// Despite the name, this delegates directly to [`Self::save_to_file`] and so
/// writes the same format that method writes.
#[allow(dead_code)]
pub(crate) fn save_to_onnx_with_metadata(&self, path: &Path, metadata: ModelMetadata) -> Result<()> {
    self.save_to_file(path, metadata)
}
479
480}
481
/// PPO agent wrapping an [`ActorCriticNetwork`] and its AdamW optimizer.
pub struct PPOAgent {
    /// Policy/value network.
    network: ActorCriticNetwork,
    /// Optimizer stepped once per `ppo_update` call.
    optimizer: AdamW,
    /// Variable map the network's layers were built from.
    #[allow(dead_code)]
    varmap: VarMap,
    /// Half-width of the clipping interval `[1 - eps, 1 + eps]` applied to the
    /// probability ratio in `ppo_update`.
    clip_epsilon: f32,
    /// Lambda used by `calculate_gae`.
    gae_lambda: f32,
    /// Weight of the value loss in the total loss.
    value_loss_coef: f32,
    /// Weight of the entropy term subtracted from the total loss.
    entropy_coef: f32,
    /// Number of `ppo_update` passes performed per `train_step`.
    ppo_epochs: usize,

    /// Dimension of the state vector.
    state_dim: usize,
    /// Number of discrete actions.
    num_actions: usize,
    /// Number of continuous action parameters.
    num_params: usize,
    /// Discount factor used by `calculate_gae`.
    gamma: f32,
    /// Number of completed `train_step` calls that performed updates.
    step_count: usize,
    /// Device on which input tensors are created.
    device: Device,
}
502
503impl PPOAgent {
504 pub fn new(
505 state_dim: usize,
506 num_actions: usize,
507 num_params: usize,
508 gamma: f32,
509 lr: f64,
510 device: &Device,
511 varmap: VarMap,
512 ) -> Result<Self> {
513 let vb = VarBuilder::from_varmap(&varmap, DType::F32, device);
514 let network = ActorCriticNetwork::new(state_dim, num_actions, num_params, vb)
515 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
516
517 let mut trainable_vars = varmap.all_vars();
519 trainable_vars.push(network.logstd_var().clone());
520
521 let params = ParamsAdamW {
522 lr,
523 beta1: 0.9,
524 beta2: 0.999,
525 eps: 1e-8,
526 weight_decay: 0.0,
527 };
528
529 let optimizer = AdamW::new(trainable_vars, params)
530 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
531
532 Ok(Self {
533 network,
534 optimizer,
535 varmap,
536 clip_epsilon: 0.2,
537 gae_lambda: 0.95,
538 value_loss_coef: 0.5,
539 entropy_coef: 0.01,
540 ppo_epochs: 4,
541 state_dim,
542 num_actions,
543 num_params,
544 gamma,
545 step_count: 0,
546 device: device.clone(),
547 })
548 }
549
550 fn calculate_gae(
552 &self,
553 rewards: &[f32],
554 values: &[f32],
555 next_value: f32,
556 dones: &[bool],
557 ) -> (Vec<f32>, Vec<f32>) {
558 let mut advantages = vec![0.0; rewards.len()];
559 let mut returns = vec![0.0; rewards.len()];
560
561 let mut gae = 0.0;
562 let mut next_val = next_value;
563
564 for t in (0..rewards.len()).rev() {
565 let done_mask = if dones[t] { 0.0 } else { 1.0 };
566 let delta = rewards[t] + self.gamma * next_val * done_mask - values[t];
567 gae = delta + self.gamma * self.gae_lambda * done_mask * gae;
568 advantages[t] = gae;
569 returns[t] = gae + values[t];
570 next_val = values[t];
571 }
572
573 (advantages, returns)
574 }
575
576 fn discrete_log_prob(
578 logits: &Tensor,
579 actions: &Tensor,
580 ) -> candle_core::error::Result<Tensor> {
581 let log_probs = candle_nn::ops::log_softmax(logits, 1)?;
582 log_probs.gather(&actions.unsqueeze(1)?, 1)?.squeeze(1)
583 }
584
585 fn continuous_log_prob(
587 mean: &Tensor,
588 std: &Tensor,
589 actions: &Tensor,
590 ) -> candle_core::error::Result<Tensor> {
591 let batch_size = mean.dims()[0];
593 let num_params = mean.dims()[1];
594
595 let std_broadcast = std.unsqueeze(0)?.broadcast_as(mean.shape())?;
598 let variance = std_broadcast.sqr()?;
599 let diff = (actions - mean)?;
600
601 let pi_constant = Tensor::new(
603 vec![2.0 * std::f32::consts::PI; batch_size * num_params],
604 mean.device()
605 )?.reshape(&[batch_size, num_params])?;
606
607 let log_prob = -0.5 * (
608 diff.sqr()?.div(&variance)? +
609 variance.log()? +
610 pi_constant.log()?
611 )?;
612
613 log_prob?.sum(1)
614 }
615
616 fn calculate_entropy(
618 logits: &Tensor,
619 std: &Tensor,
620 ) -> candle_core::error::Result<Tensor> {
621 let probs = candle_nn::ops::softmax(logits, 1)?;
623 let log_probs = candle_nn::ops::log_softmax(logits, 1)?;
624 let discrete_entropy = -1.0 * (probs * log_probs)?.sum(1)?.mean_all()?;
625
626 let num_params = std.dims()[0];
629 let constant = Tensor::new(
630 vec![0.5 * (1.0 + 2.0 * std::f32::consts::PI).ln(); num_params],
631 std.device()
632 )?;
633
634 let continuous_entropy = (std.log()? + constant)?.mean_all()?;
635
636 discrete_entropy + continuous_entropy
637 }
638
639 fn ppo_update(
641 &mut self,
642 states: &Tensor,
643 actions_discrete: &Tensor,
644 actions_continuous: &Tensor,
645 old_log_probs: &Tensor,
646 advantages: &Tensor,
647 returns: &Tensor,
648 ) -> Result<(f32, f32, f32)> {
649 let (action_logits, param_mean, param_std, values) =
651 self.network.forward(states, true)
652 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
653
654 let log_probs_discrete = Self::discrete_log_prob(&action_logits, actions_discrete)
656 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
657 let log_probs_continuous = Self::continuous_log_prob(¶m_mean, ¶m_std, actions_continuous)
658 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
659 let log_probs = (log_probs_discrete + log_probs_continuous)
660 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
661
662 let ratio = (log_probs.clone() - old_log_probs)?.exp()?;
664
665 let batch_size = advantages.dims()[0];
667
668 let adv_mean_scalar = advantages.mean_all()?.to_scalar::<f32>()?;
670 let adv_variance = advantages.sub(&Tensor::new(&[adv_mean_scalar], advantages.device())?.broadcast_as(advantages.shape())?)?.sqr()?.mean_all()?;
671 let adv_std_scalar = (adv_variance.to_scalar::<f32>()? + 1e-8).sqrt();
672
673 let adv_mean_broadcast = Tensor::new(vec![adv_mean_scalar; batch_size], advantages.device())?;
675 let adv_std_broadcast = Tensor::new(vec![adv_std_scalar; batch_size], advantages.device())?;
676
677 let advantages_norm = ((advantages - &adv_mean_broadcast)? / &adv_std_broadcast)?;
679
680 let surr1 = (ratio.clone() * &advantages_norm)?;
681
682 let ratio_clipped = ratio.clamp(1.0 - self.clip_epsilon, 1.0 + self.clip_epsilon)?;
683 let surr2 = (ratio_clipped * advantages_norm)?;
684
685 let policy_loss = (-1.0 * surr1.minimum(&surr2)?.mean_all()?)?;
686
687 let value_loss = (values - returns)?.sqr()?.mean_all()?;
689
690 let entropy = Self::calculate_entropy(&action_logits, ¶m_std)?;
692
693 let value_loss_coef = self.value_loss_coef as f64;
695 let entropy_coef = self.entropy_coef as f64;
696 let total_loss = ((&policy_loss + (value_loss.clone() * value_loss_coef)?)? - (entropy.clone() * entropy_coef)?)?;
697
698 let total_loss_scalar = total_loss.to_scalar::<f32>()?;
700 let policy_loss_scalar = policy_loss.to_scalar::<f32>()?;
701
702 if total_loss_scalar.is_nan() || total_loss_scalar.is_infinite() {
703 return Err(crate::ExtractionError::ModelError(
704 format!("Invalid PPO loss: {}", total_loss_scalar)
705 ));
706 }
707
708 let grads = total_loss.backward()
710 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
711
712 self.optimizer.step(&grads)
713 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
714
715 Ok((
716 policy_loss_scalar,
717 value_loss.to_scalar::<f32>()?,
718 entropy.to_scalar::<f32>()?,
719 ))
720 }
721
722 pub fn load_with_device(
723 path: &Path,
724 state_dim: usize,
725 num_actions: usize,
726 num_params: usize,
727 device: &Device,
728 ) -> Result<Self> {
729 let (network, varmap) = ActorCriticNetwork::load_from_file(
730 path, state_dim, num_actions, num_params, device
731 )?;
732
733 let trainable_vars = varmap.all_vars();
735 let params = ParamsAdamW {
736 lr: 3e-4,
737 beta1: 0.9,
738 beta2: 0.999,
739 eps: 1e-8,
740 weight_decay: 0.0,
741 };
742
743 let optimizer = AdamW::new(trainable_vars, params)
744 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
745
746 Ok(Self {
747 network,
748 optimizer,
749 varmap,
750 clip_epsilon: 0.2,
751 gae_lambda: 0.95,
752 value_loss_coef: 0.5,
753 entropy_coef: 0.01,
754 ppo_epochs: 4,
755 state_dim,
756 num_actions,
757 num_params,
758 gamma: 0.95,
759 step_count: 0,
760 device: device.clone(),
761 })
762 }
763}
764impl RLAgent for PPOAgent {
765 fn select_action(&self, state: &[f32], _epsilon: f32) -> Result<(usize, Vec<f32>)> {
766 let state_tensor = Tensor::from_vec(
768 state.to_vec(),
769 &[1, state.len()],
770 &self.device
771 )?;
772 let (action_logits, param_mean, param_std, _value) =
773 self.network.forward(&state_tensor, false)?;
774
775 let probs = candle_nn::ops::softmax(&action_logits, 1)?.to_vec2::<f32>()?;
777 let discrete_action = sample_categorical(&probs[0]);
778
779 let mean_vec = param_mean.to_vec2::<f32>()?;
781 let std_vec = param_std.to_vec1::<f32>()?;
782 let continuous_params = sample_gaussian(&mean_vec[0], &std_vec);
783
784 Ok((discrete_action, continuous_params))
785 }
786
787 fn save_with_metadata(
788 &self,
789 path: &Path,
790 training_episodes: usize,
791 hyperparameters: HashMap<String, f64>,
792 ) -> Result<()> {
793 let metadata = ModelMetadata::new(
794 self.state_dim,
795 self.num_actions,
796 self.num_params,
797 AlgorithmType::PPO,
798 training_episodes,
799 hyperparameters,
800 );
801
802 self.network.save_to_file(path, metadata)?;
803
804 let safetensors_path = path.with_extension("safetensors");
805 self.network.save_to_safetensors(&safetensors_path)?;
806
807 tracing::info!("PPO model saved: ONNX ({} bytes), SafeTensors ({} bytes)",
808 std::fs::metadata(path).map(|m| m.len()).unwrap_or(0),
809 std::fs::metadata(&safetensors_path).map(|m| m.len()).unwrap_or(0));
810
811 Ok(())
812 }
813
/// Saves the model to `path` with zero training episodes and no
/// hyper-parameters recorded in the metadata.
fn save(&self, path: &Path) -> Result<()> {
    self.save_with_metadata(path, 0, std::collections::HashMap::new())
}
817
818 fn train_step(
819 &mut self,
820 replay_buffer: &mut PrioritizedReplayBuffer,
821 batch_size: usize,
822 ) -> Result<f32> {
823 let batch = replay_buffer.sample(batch_size);
824 if batch.is_none() {
825 return Ok(0.0);
826 }
827
828 let batch = batch.unwrap();
829 let experiences = &batch.experiences;
830
831 if experiences.is_empty() {
832 return Ok(0.0);
833 }
834
835 let state_dim = experiences[0].state.len();
837 let states_flat: Vec<f32> = experiences.iter()
838 .flat_map(|e| e.state.clone())
839 .collect();
840 let states_tensor = Tensor::from_vec(
841 states_flat,
842 &[experiences.len(), state_dim],
843 &self.device
844 )?;
845
846 let (old_logits, old_means, old_stds, old_values) =
848 self.network.forward(&states_tensor, false)?;
849
850 let actions_discrete: Vec<i64> = experiences.iter()
852 .map(|e| e.action.0 as i64)
853 .collect();
854 let actions_discrete_tensor = Tensor::from_vec(
855 actions_discrete,
856 &[experiences.len()],
857 &self.device
858 )?;
859
860 let actions_continuous_flat: Vec<f32> = experiences.iter()
861 .flat_map(|e| e.action.1.clone())
862 .collect();
863 let actions_continuous_tensor = Tensor::from_vec(
864 actions_continuous_flat,
865 &[experiences.len(), self.num_params],
866 &self.device
867 )?;
868
869 let old_log_probs_discrete = Self::discrete_log_prob(&old_logits, &actions_discrete_tensor)?;
871 let old_log_probs_continuous = Self::continuous_log_prob(&old_means, &old_stds, &actions_continuous_tensor)?;
872 let old_log_probs = (old_log_probs_discrete + old_log_probs_continuous)?;
873
874 let rewards: Vec<f32> = experiences.iter().map(|e| e.reward).collect();
876 let values_vec: Vec<f32> = old_values.to_vec1()?;
877 let dones: Vec<bool> = experiences.iter().map(|e| e.done).collect();
878
879 let (advantages, returns) = self.calculate_gae(
880 &rewards,
881 &values_vec,
882 0.0,
883 &dones,
884 );
885
886 let advantages_tensor = Tensor::from_vec(advantages, &[experiences.len()], &self.device)?;
887 let returns_tensor = Tensor::from_vec(returns, &[experiences.len()], &self.device)?;
888
889 let mut total_policy_loss = 0.0;
891 let mut total_value_loss = 0.0;
892 let mut _total_entropy = 0.0;
893
894 for _ in 0..self.ppo_epochs {
895 let (policy_loss, value_loss, entropy) = self.ppo_update(
896 &states_tensor,
897 &actions_discrete_tensor,
898 &actions_continuous_tensor,
899 &old_log_probs,
900 &advantages_tensor,
901 &returns_tensor,
902 )?;
903
904 total_policy_loss += policy_loss;
905 total_value_loss += value_loss;
906 _total_entropy += entropy;
907 }
908
909 self.step_count += 1;
910
911 let avg_loss = (total_policy_loss + total_value_loss) / self.ppo_epochs as f32;
912 Ok(avg_loss)
913 }
914
/// Intentionally a no-op: this agent holds a single network and no separate
/// target network.
fn update_target_network(&mut self) {
}
918
/// Returns how many `train_step` calls have completed their updates.
fn get_step_count(&self) -> usize {
    self.step_count
}
922
/// Identifies this agent as PPO.
fn algorithm_type(&self) -> AlgorithmType {
    AlgorithmType::PPO
}
926
927 fn get_info(&self) -> AgentInfo {
928 let sd = self.state_dim;
930 let na = self.num_actions;
931 let np = self.num_params;
932 let num_parameters =
933 (sd * 512 + 512) + (512 * 2) + (512 * 256 + 256) + (256 * 2) + (256 * 128 + 128) + (128 * 2) + (128 * na + na) + (128 * np + np) + np + (128 * 64 + 64) + (64 * 1 + 1); AgentInfo {
946 algorithm: AlgorithmType::PPO,
947 num_parameters,
948 state_dim: self.state_dim,
949 num_actions: self.num_actions,
950 continuous_params: self.num_params,
951 version: "1.0.0".to_string(),
952 features: vec![
953 "actor_critic".to_string(),
954 "clipped_objective".to_string(),
955 "gae".to_string(),
956 "entropy_bonus".to_string(),
957 ],
958 }
959 }
960}
961
/// Debug-build helper: prints `name` and the dimensions of `tensor` to stderr.
#[cfg(debug_assertions)]
#[allow(dead_code)]
fn debug_tensor_shape(name: &str, tensor: &Tensor) {
    eprintln!("DEBUG: {} shape: {:?}", name, tensor.dims());
}
973
/// Release-build counterpart of the debug helper: does nothing.
#[cfg(not(debug_assertions))]
fn debug_tensor_shape(_name: &str, _tensor: &Tensor) {
    }