1use candle_core::{Device, Tensor, DType, Var};
7use candle_nn::{VarBuilder, Optimizer, AdamW, ParamsAdamW, VarMap, Linear, Module, linear, layer_norm, LayerNorm};
8use crate::replay_buffer::PrioritizedReplayBuffer;
9use crate::{Result, agents::{RLAgent, AlgorithmType, AgentInfo}};
10use tracing::info;
11use std::path::Path;
12use std::collections::HashMap;
13use crate::models::ModelMetadata;
14use candle_nn::ops::softmax;
15
/// Policy network for hybrid (discrete + continuous) SAC.
///
/// A shared layer-normalized MLP trunk (512 -> 256 -> 128) feeds three heads:
/// discrete action logits, continuous parameter means, and their log-std-devs.
#[allow(dead_code)]
pub struct SACActorNetwork {
    // Shared trunk: Linear + LayerNorm per stage (ReLU applied in forward).
    fc1: Linear,
    ln1: LayerNorm,
    fc2: Linear,
    ln2: LayerNorm,
    fc3: Linear,
    ln3: LayerNorm,
    // Head: unnormalized scores over the discrete actions.
    action_logits: Linear,

    // Heads: Gaussian parameters for the continuous action components
    // (mean is tanh-squashed, log-std clamped in forward).
    param_mean: Linear,
    param_logstd: Linear,

    device: Device,
    num_actions: usize,
    num_params: usize,
}
36impl SACActorNetwork {
37 pub fn new(
38 state_dim: usize,
39 num_actions: usize,
40 num_params: usize,
41 vb: VarBuilder,
42 ) -> candle_core::error::Result<Self> {
43 let device = vb.device().clone();
44 let fc1 = linear(state_dim, 512, vb.pp("fc1"))?;
45 let ln1 = layer_norm(512, 1e-5, vb.pp("ln1"))?;
46 let fc2 = linear(512, 256, vb.pp("fc2"))?;
47 let ln2 = layer_norm(256, 1e-5, vb.pp("ln2"))?;
48 let fc3 = linear(256, 128, vb.pp("fc3"))?;
49 let ln3 = layer_norm(128, 1e-5, vb.pp("ln3"))?;
50
51 let action_logits = linear(128, num_actions, vb.pp("action_logits"))?;
52 let param_mean = linear(128, num_params, vb.pp("param_mean"))?;
53 let param_logstd = linear(128, num_params, vb.pp("param_logstd"))?;
54
55 Ok(Self {
56 fc1, ln1, fc2, ln2, fc3, ln3,
57 action_logits,
58 param_mean,
59 param_logstd,
60 device,
61 num_actions,
62 num_params,
63 })
64 }
65
66 pub fn forward(&self, state: &Tensor) -> candle_core::error::Result<(Tensor, Tensor, Tensor)> {
67 let mut x = self.fc1.forward(state)?;
68 x = self.ln1.forward(&x)?;
69 x = x.relu()?;
70
71 x = self.fc2.forward(&x)?;
72 x = self.ln2.forward(&x)?;
73 x = x.relu()?;
74
75 x = self.fc3.forward(&x)?;
76 x = self.ln3.forward(&x)?;
77 let features = x.relu()?;
78
79 let action_logits = self.action_logits.forward(&features)?;
80 let param_mean = self.param_mean.forward(&features)?.tanh()?;
81 let param_logstd = self.param_logstd.forward(&features)?.clamp(-20.0, 2.0)?;
82
83 Ok((action_logits, param_mean, param_logstd))
84 }
85}
86
/// Twin Q-value network (clipped double-Q) over concatenated
/// (state, discrete one-hot action, continuous params) inputs.
///
/// Each head is an independent layer-normalized MLP: input -> 512 -> 256 -> 1.
#[allow(dead_code)]
pub struct SACCriticNetwork {
    // Q1 head.
    q1_fc1: Linear,
    q1_ln1: LayerNorm,
    q1_fc2: Linear,
    q1_ln2: LayerNorm,
    q1_output: Linear,
    // Q2 head (same architecture, separate parameters).
    q2_fc1: Linear,
    q2_ln1: LayerNorm,
    q2_fc2: Linear,
    q2_ln2: LayerNorm,
    q2_output: Linear,

    num_actions: usize,
    num_params: usize,
}
106impl SACCriticNetwork {
107 pub fn new(
108 state_dim: usize,
109 num_actions: usize,
110 num_params: usize,
111 vb: VarBuilder,
112 ) -> candle_core::error::Result<Self> {
113 let input_dim = state_dim + num_actions + num_params;
115 let q1_fc1 = linear(input_dim, 512, vb.pp("q1_fc1"))?;
117 let q1_ln1 = layer_norm(512, 1e-5, vb.pp("q1_ln1"))?;
118 let q1_fc2 = linear(512, 256, vb.pp("q1_fc2"))?;
119 let q1_ln2 = layer_norm(256, 1e-5, vb.pp("q1_ln2"))?;
120 let q1_output = linear(256, 1, vb.pp("q1_output"))?;
121
122 let q2_fc1 = linear(input_dim, 512, vb.pp("q2_fc1"))?;
124 let q2_ln1 = layer_norm(512, 1e-5, vb.pp("q2_ln1"))?;
125 let q2_fc2 = linear(512, 256, vb.pp("q2_fc2"))?;
126 let q2_ln2 = layer_norm(256, 1e-5, vb.pp("q2_ln2"))?;
127 let q2_output = linear(256, 1, vb.pp("q2_output"))?;
128
129 Ok(Self {
130 q1_fc1, q1_ln1, q1_fc2, q1_ln2, q1_output,
131 q2_fc1, q2_ln1, q2_fc2, q2_ln2, q2_output,
132 num_actions,
133 num_params,
134 })
135 }
136
137 pub fn forward(
138 &self,
139 state: &Tensor,
140 action_discrete: &Tensor,
141 action_continuous: &Tensor,
142 ) -> candle_core::error::Result<(Tensor, Tensor)> {
143 let state_action = Tensor::cat(&[state, action_discrete, action_continuous], 1)?;
145
146 let mut x1 = self.q1_fc1.forward(&state_action)?;
148 x1 = self.q1_ln1.forward(&x1)?;
149 x1 = x1.relu()?;
150 x1 = self.q1_fc2.forward(&x1)?;
151 x1 = self.q1_ln2.forward(&x1)?;
152 x1 = x1.relu()?;
153 let q1 = self.q1_output.forward(&x1)?.squeeze(1)?;
154
155 let mut x2 = self.q2_fc1.forward(&state_action)?;
157 x2 = self.q2_ln1.forward(&x2)?;
158 x2 = x2.relu()?;
159 x2 = self.q2_fc2.forward(&x2)?;
160 x2 = self.q2_ln2.forward(&x2)?;
161 x2 = x2.relu()?;
162 let q2 = self.q2_output.forward(&x2)?.squeeze(1)?;
163
164 Ok((q1, q2))
165 }
166}
167
/// Soft Actor-Critic agent for a hybrid (discrete + continuous) action space:
/// stochastic actor, twin online critics, a target critic for TD targets, and
/// a learnable entropy temperature.
pub struct SACAgent {
    actor: SACActorNetwork,
    critic: SACCriticNetwork,
    // Copy of the critic used to compute TD targets.
    target_critic: SACCriticNetwork,
    actor_optimizer: AdamW,
    critic_optimizer: AdamW,

    // Temperature alpha stored as log(alpha) so it stays positive.
    log_alpha: Var,
    alpha_optimizer: AdamW,
    // Entropy level the temperature controller steers the policy toward.
    target_entropy: f32,

    // Kept alive so the optimizers' parameter handles remain valid.
    #[allow(dead_code)]
    actor_varmap: VarMap,
    #[allow(dead_code)]
    critic_varmap: VarMap,
    #[allow(dead_code)]
    alpha_varmap: VarMap,

    num_actions: usize,
    num_params: usize,
    // Discount factor.
    gamma: f32,
    // tau: Polyak coefficient; step_count throttles/logs target updates.
    tau: f32, step_count: usize,
    device: Device,
}
195
196
197fn save_linear_helper(
198 tensors: &mut HashMap<String, (Vec<usize>, Vec<f32>)>,
199 name: &str,
200 linear: &Linear
201) -> Result<()> {
202 let weight = linear.weight();
203 let weight_shape = weight.dims().to_vec();
204 let weight_data = weight.flatten_all()
205 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?
206 .to_vec1::<f32>()
207 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
208 tensors.insert(format!("{}.weight", name), (weight_shape, weight_data));
209
210 if let Some(bias) = linear.bias() {
211 let bias_shape = bias.dims().to_vec();
212 let bias_data = bias.flatten_all()
213 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?
214 .to_vec1::<f32>()
215 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
216 tensors.insert(format!("{}.bias", name), (bias_shape, bias_data));
217 }
218 Ok(())
219}
220
221fn save_layernorm_helper(
222 tensors: &mut HashMap<String, (Vec<usize>, Vec<f32>)>,
223 name: &str,
224 ln: &LayerNorm
225) -> Result<()> {
226 let weight = ln.weight();
227 let weight_shape = weight.dims().to_vec();
228 let weight_data = weight.flatten_all()
229 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?
230 .to_vec1::<f32>()
231 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
232 tensors.insert(format!("{}.weight", name), (weight_shape, weight_data));
233
234 if let Some(bias) = ln.bias() {
235 let bias_shape = bias.dims().to_vec();
236 let bias_data = bias.flatten_all()
237 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?
238 .to_vec1::<f32>()
239 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
240 tensors.insert(format!("{}.bias", name), (bias_shape, bias_data));
241 }
242 Ok(())
243}
244
/// Soft (Polyak) update of a target `Linear` layer toward the online layer.
///
/// NOTE(review): this is currently a no-op stub — it reads both weights but
/// performs no update, so target layers never track their online
/// counterparts. candle `Linear` exposes its weight as an immutable `Tensor`,
/// so a real Polyak update would have to go through the owning `VarMap`
/// (`Var::set`) instead of this interface — confirm intended.
fn soft_update_linear(
    target: &Linear,
    source: &Linear,
    _tau: f32,
    _device: &Device,
) -> candle_core::error::Result<()> {
    let _source_weight = source.weight();
    let _target_weight = target.weight();

    Ok(())
}
265
/// Soft (Polyak) update of a target `LayerNorm` layer toward the online layer.
///
/// NOTE(review): no-op stub, same limitation as `soft_update_linear` — the
/// target layer's parameters are never actually modified here.
fn soft_update_layernorm(
    target: &LayerNorm,
    source: &LayerNorm,
    _tau: f32,
    _device: &Device,
) -> candle_core::error::Result<()> {
    let _source_weight = source.weight();
    let _target_weight = target.weight();

    Ok(())
}
280
281impl SACAgent {
282 #[allow(clippy::too_many_arguments)]
283 pub fn new(
284 state_dim: usize,
285 num_actions: usize,
286 num_params: usize,
287 gamma: f32,
288 lr: f64,
289 device: &Device,
290 actor_varmap: VarMap,
291 critic_varmap: VarMap,
292 ) -> Result<Self> {
293 let actor_vb = VarBuilder::from_varmap(&actor_varmap, DType::F32, device);
295 let actor = SACActorNetwork::new(state_dim, num_actions, num_params, actor_vb)
296 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
297
298 let critic_vb = VarBuilder::from_varmap(&critic_varmap, DType::F32, device);
300 let critic = SACCriticNetwork::new(state_dim, num_actions, num_params, critic_vb.pp("online"))
301 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
302
303 let target_critic_varmap = VarMap::new();
304 let target_vb = VarBuilder::from_varmap(&target_critic_varmap, DType::F32, device);
305 let target_critic = SACCriticNetwork::new(state_dim, num_actions, num_params, target_vb.pp("target"))
306 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
307
308 let alpha_varmap = VarMap::new();
310 let log_alpha_init = Tensor::zeros(&[], DType::F32, device)?;
312 let log_alpha = Var::from_tensor(&log_alpha_init)?;
313
314 let target_entropy = -(num_actions as f32 + num_params as f32);
316
317 let actor_params = ParamsAdamW { lr, beta1: 0.9, beta2: 0.999, eps: 1e-8, weight_decay: 0.0 };
319 let actor_optimizer = AdamW::new(actor_varmap.all_vars(), actor_params)
320 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
321
322 let critic_params = ParamsAdamW { lr, beta1: 0.9, beta2: 0.999, eps: 1e-8, weight_decay: 0.0 };
323 let critic_optimizer = AdamW::new(critic_varmap.all_vars(), critic_params)
324 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
325
326 let alpha_params = ParamsAdamW { lr: lr * 0.1, beta1: 0.9, beta2: 0.999, eps: 1e-8, weight_decay: 0.0 };
327 let alpha_optimizer = AdamW::new(vec![log_alpha.clone()], alpha_params)
328 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
329
330 Ok(Self {
331 actor,
332 critic,
333 target_critic,
334 actor_optimizer,
335 critic_optimizer,
336 log_alpha,
337 alpha_optimizer,
338 target_entropy,
339 actor_varmap,
340 critic_varmap,
341 alpha_varmap,
342 num_actions,
343 num_params,
344 gamma,
345 tau: 0.005,
346 step_count: 0,
347 device: device.clone(),
348 })
349 }
350
    /// Samples a (relaxed) hybrid action from the current policy.
    ///
    /// Returns `(discrete_onehot, continuous_params, log_prob)`:
    /// * `discrete_onehot` — a Gumbel-softmax relaxation (soft one-hot),
    /// * `continuous_params` — a reparameterized Gaussian sample
    ///   (mean + std * noise),
    /// * `log_prob` — the summed log-probability of both parts, one value
    ///   per batch row.
    fn sample_action(&self, state: &Tensor) -> Result<(Tensor, Tensor, Tensor)> {
        let (action_logits, param_mean, param_logstd) = self.actor.forward(state)?;

        // NOTE(review): `action_probs.log()` below can produce -inf for
        // near-zero probabilities; a log_softmax would be more stable — confirm.
        let action_probs = softmax(&action_logits, 1)?;
        let action_discrete_onehot = self.gumbel_softmax(&action_logits, 1.0f32)?;

        let param_std = param_logstd.exp()?;

        // Reparameterization trick: sample = mean + std * eps, eps ~ N(0, 1).
        let noise = Tensor::randn(0.0f32, 1.0f32, param_mean.shape(), &self.device)?;
        let action_continuous = (&param_mean + &param_std.mul(&noise)?)?;

        // Discrete log-prob approximated against the relaxed one-hot sample.
        let log_prob_discrete = action_probs.log()?.mul(&action_discrete_onehot)?.sum(1)?;
        let log_prob_continuous = self.gaussian_log_prob(&param_mean, &param_std, &action_continuous)?;
        let log_prob = (log_prob_discrete + log_prob_continuous)?;

        Ok((action_discrete_onehot, action_continuous, log_prob))
    }
373
374 fn gumbel_softmax(&self, logits: &Tensor, temperature: f32) -> candle_core::error::Result<Tensor> {
376 let uniform = Tensor::rand(0.0f32, 1.0f32, logits.shape(), logits.device())?;
378
379 let eps = 1e-10f32;
381 let gumbel = uniform.clamp(eps, 1.0f32 - eps)?;
382 let gumbel = gumbel.log()?.neg()?;
383 let gumbel = gumbel.log()?.neg()?;
384
385 let batch_size = logits.dims()[0];
387 let num_actions = logits.dims()[1];
388 let temp_tensor = Tensor::from_vec(
389 vec![temperature; batch_size * num_actions],
390 &[batch_size, num_actions],
391 logits.device()
392 )?;
393
394 let y = (logits.clone() + gumbel)?.div(&temp_tensor)?;
395 softmax(&y, 1)
396 }
397
398 fn gaussian_log_prob(&self, mean: &Tensor, std: &Tensor, value: &Tensor) -> candle_core::error::Result<Tensor> {
400 let batch_size = mean.dims()[0];
402 let num_params = mean.dims()[1];
403
404 let std_broadcast = if std.dims().len() == 1 {
406 std.unsqueeze(0)?.broadcast_as(mean.shape())?
407 } else {
408 std.clone()
409 };
410
411 let variance = std_broadcast.sqr()?;
412 let log_std = std_broadcast.log()?;
413 let diff = (value - mean)?;
414
415 let pi_constant = Tensor::from_vec(
417 vec![2.0f32 * std::f32::consts::PI; batch_size * num_params],
418 &[batch_size, num_params],
419 mean.device()
420 )?;
421
422 let half_tensor = Tensor::from_vec(
424 vec![0.5f32; batch_size * num_params],
425 &[batch_size, num_params],
426 mean.device()
427 )?;
428
429 let nll = half_tensor.mul(&(
430 diff.sqr()?.div(&variance)? +
431 pi_constant.log()? +
432 log_std.mul(&Tensor::from_vec(
433 vec![2.0f32; batch_size * num_params],
434 &[batch_size, num_params],
435 mean.device()
436 )?)?
437 )?)?;
438
439 nll.sum(1)
440 }
441
    /// Periodically "updates" the target critic toward the online critic,
    /// throttled to every 100 training steps (logged every 1000).
    ///
    /// NOTE(review): the per-layer `soft_update_*` helpers are no-op stubs, so
    /// this currently performs no actual weight transfer — the target critic
    /// keeps whatever weights it was constructed with. Confirm whether this is
    /// a known TODO before relying on target-network behavior.
    fn soft_update_target(&mut self) -> Result<()> {
        let tau = self.tau;
        let device = &self.device;

        if self.step_count.is_multiple_of(100) {
            if self.step_count.is_multiple_of(1000) {
                info!("SAC target network update at step {} (tau={})", self.step_count, tau);
            }

            // Results deliberately ignored (`let _`): the helpers always
            // return Ok and do no work today.
            let _ = soft_update_linear(&self.target_critic.q1_fc1, &self.critic.q1_fc1, tau, device);
            let _ = soft_update_layernorm(&self.target_critic.q1_ln1, &self.critic.q1_ln1, tau, device);
            let _ = soft_update_linear(&self.target_critic.q1_fc2, &self.critic.q1_fc2, tau, device);
            let _ = soft_update_layernorm(&self.target_critic.q1_ln2, &self.critic.q1_ln2, tau, device);
            let _ = soft_update_linear(&self.target_critic.q1_output, &self.critic.q1_output, tau, device);

            let _ = soft_update_linear(&self.target_critic.q2_fc1, &self.critic.q2_fc1, tau, device);
            let _ = soft_update_layernorm(&self.target_critic.q2_ln1, &self.critic.q2_ln1, tau, device);
            let _ = soft_update_linear(&self.target_critic.q2_fc2, &self.critic.q2_fc2, tau, device);
            let _ = soft_update_layernorm(&self.target_critic.q2_ln2, &self.critic.q2_ln2, tau, device);
            let _ = soft_update_linear(&self.target_critic.q2_output, &self.critic.q2_output, tau, device);
        }

        Ok(())
    }
481
482 pub fn save_to_file(&self, path: &Path, metadata: ModelMetadata) -> Result<()> {
484 use std::fs::File;
485 use std::io::Write;
486 let mut file = File::create(path)?;
487
488 let metadata_json = serde_json::to_string(&metadata)
490 .map_err(|e| crate::ExtractionError::ParseError(e.to_string()))?;
491 let metadata_bytes = metadata_json.as_bytes();
492 let metadata_len = metadata_bytes.len() as u64;
493
494 file.write_all(&metadata_len.to_le_bytes())?;
495 file.write_all(metadata_bytes)?;
496
497 let mut tensors: HashMap<String, (Vec<usize>, Vec<f32>)> = HashMap::new();
499
500 save_linear_helper(&mut tensors, "actor.fc1", &self.actor.fc1)?;
502 save_layernorm_helper(&mut tensors, "actor.ln1", &self.actor.ln1)?;
503 save_linear_helper(&mut tensors, "actor.fc2", &self.actor.fc2)?;
504 save_layernorm_helper(&mut tensors, "actor.ln2", &self.actor.ln2)?;
505 save_linear_helper(&mut tensors, "actor.fc3", &self.actor.fc3)?;
506 save_layernorm_helper(&mut tensors, "actor.ln3", &self.actor.ln3)?;
507 save_linear_helper(&mut tensors, "actor.action_logits", &self.actor.action_logits)?;
508 save_linear_helper(&mut tensors, "actor.param_mean", &self.actor.param_mean)?;
509 save_linear_helper(&mut tensors, "actor.param_logstd", &self.actor.param_logstd)?;
510
511 save_linear_helper(&mut tensors, "critic.q1_fc1", &self.critic.q1_fc1)?;
513 save_layernorm_helper(&mut tensors, "critic.q1_ln1", &self.critic.q1_ln1)?;
514 save_linear_helper(&mut tensors, "critic.q1_fc2", &self.critic.q1_fc2)?;
515 save_layernorm_helper(&mut tensors, "critic.q1_ln2", &self.critic.q1_ln2)?;
516 save_linear_helper(&mut tensors, "critic.q1_output", &self.critic.q1_output)?;
517
518 save_linear_helper(&mut tensors, "critic.q2_fc1", &self.critic.q2_fc1)?;
519 save_layernorm_helper(&mut tensors, "critic.q2_ln1", &self.critic.q2_ln1)?;
520 save_linear_helper(&mut tensors, "critic.q2_fc2", &self.critic.q2_fc2)?;
521 save_layernorm_helper(&mut tensors, "critic.q2_ln2", &self.critic.q2_ln2)?;
522 save_linear_helper(&mut tensors, "critic.q2_output", &self.critic.q2_output)?;
523
524 let log_alpha_tensor = self.log_alpha.as_tensor();
526 let log_alpha_shape = log_alpha_tensor.dims().to_vec();
527 let log_alpha_data = log_alpha_tensor.flatten_all()
528 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?
529 .to_vec1::<f32>()
530 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
531 tensors.insert("log_alpha".to_string(), (log_alpha_shape, log_alpha_data));
532
533 let tensor_count = tensors.len() as u64;
535 file.write_all(&tensor_count.to_le_bytes())?;
536
537 for (name, (shape, data)) in tensors.iter() {
539 let name_bytes = name.as_bytes();
540 let name_len = name_bytes.len() as u64;
541 file.write_all(&name_len.to_le_bytes())?;
542 file.write_all(name_bytes)?;
543
544 let shape_len = shape.len() as u64;
545 file.write_all(&shape_len.to_le_bytes())?;
546 for &dim in shape {
547 file.write_all(&(dim as u64).to_le_bytes())?;
548 }
549
550 let data_len = data.len() as u64;
551 file.write_all(&data_len.to_le_bytes())?;
552 for &value in data {
553 file.write_all(&value.to_le_bytes())?;
554 }
555 }
556
557 let file_size = std::fs::metadata(path)?.len();
558 tracing::info!("SAC model saved: {} bytes", file_size);
559
560 Ok(())
561 }
562
563 pub fn load_from_file(
565 path: &Path,
566 state_dim: usize,
567 num_actions: usize,
568 num_params: usize,
569 device: &Device,
570 ) -> Result<Self> {
571 use std::fs::File;
572 use std::io::Read;
573
574 tracing::info!("Loading SAC model from: {}", path.display());
575
576 let mut file = File::open(path)?;
577
578 let mut metadata_len_bytes = [0u8; 8];
580 file.read_exact(&mut metadata_len_bytes)?;
581 let metadata_len = u64::from_le_bytes(metadata_len_bytes) as usize;
582 if metadata_len > 10 * 1024 * 1024 {
583 return Err(crate::ExtractionError::ParseError(format!("Invalid model file: metadata length {} is too large", metadata_len)));
584 }
585
586 let mut metadata_bytes = vec![0u8; metadata_len];
587 file.read_exact(&mut metadata_bytes)?;
588
589 let metadata_json = String::from_utf8(metadata_bytes)
590 .map_err(|e| crate::ExtractionError::ParseError(e.to_string()))?;
591 let _metadata: ModelMetadata = serde_json::from_str(&metadata_json)
592 .map_err(|e| crate::ExtractionError::ParseError(e.to_string()))?;
593
594 tracing::info!("Model metadata loaded, loading tensors...");
595
596 let mut tensor_count_bytes = [0u8; 8];
598 file.read_exact(&mut tensor_count_bytes)?;
599 let tensor_count = u64::from_le_bytes(tensor_count_bytes) as usize;
600
601 let mut tensors: HashMap<String, Tensor> = HashMap::new();
602
603 for _ in 0..tensor_count {
604 let mut name_len_bytes = [0u8; 8];
605 file.read_exact(&mut name_len_bytes)?;
606 let name_len = u64::from_le_bytes(name_len_bytes) as usize;
607
608 let mut name_bytes = vec![0u8; name_len];
609 file.read_exact(&mut name_bytes)?;
610 let name = String::from_utf8(name_bytes)
611 .map_err(|e| crate::ExtractionError::ParseError(e.to_string()))?;
612
613 let mut shape_len_bytes = [0u8; 8];
614 file.read_exact(&mut shape_len_bytes)?;
615 let shape_len = u64::from_le_bytes(shape_len_bytes) as usize;
616
617 let mut shape = Vec::with_capacity(shape_len);
618 for _ in 0..shape_len {
619 let mut dim_bytes = [0u8; 8];
620 file.read_exact(&mut dim_bytes)?;
621 shape.push(u64::from_le_bytes(dim_bytes) as usize);
622 }
623
624 let mut data_len_bytes = [0u8; 8];
625 file.read_exact(&mut data_len_bytes)?;
626 let data_len = u64::from_le_bytes(data_len_bytes) as usize;
627
628 let mut data = Vec::with_capacity(data_len);
629 for _ in 0..data_len {
630 let mut value_bytes = [0u8; 4];
631 file.read_exact(&mut value_bytes)?;
632 data.push(f32::from_le_bytes(value_bytes));
633 }
634
635 let tensor = Tensor::from_vec(data, shape.as_slice(), device)
636 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
637 tensors.insert(name, tensor);
638 }
639
640 tracing::info!("Loaded {} tensors, reconstructing model...", tensors.len());
641
642 let mut actor_varmap = VarMap::new();
644 let actor_vb = VarBuilder::from_varmap(&actor_varmap, DType::F32, device);
645 let _ = SACActorNetwork::new(state_dim, num_actions, num_params, actor_vb)
646 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
647
648 let mut critic_varmap = VarMap::new();
649 let critic_vb = VarBuilder::from_varmap(&critic_varmap, DType::F32, device);
650 let _ = SACCriticNetwork::new(state_dim, num_actions, num_params, critic_vb.pp("online"))
651 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
652
653 for (name, tensor) in tensors.iter() {
654 if name.starts_with("actor.") {
655 let actor_name = name.strip_prefix("actor.").unwrap();
656 actor_varmap.set_one(actor_name, tensor)
657 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
658 } else if name.starts_with("critic.") {
659 let critic_name = format!("online.{}", name.strip_prefix("critic.").unwrap());
661 critic_varmap.set_one(&critic_name, tensor)
662 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
663 }
664 }
666
667 Self::new(state_dim, num_actions, num_params, 0.95, 3e-4, device, actor_varmap, critic_varmap)
668 }
669
    /// Alias for [`Self::load_from_file`]; the `device` argument there already
    /// controls tensor placement, so this exists only for naming symmetry with
    /// other agents' APIs.
    pub fn load_with_device(
        path: &Path,
        state_dim: usize,
        num_actions: usize,
        num_params: usize,
        device: &Device,
    ) -> Result<Self> {
        Self::load_from_file(path, state_dim, num_actions, num_params, device)
    }
680
681 pub fn save_to_safetensors(&self, path: &Path) -> Result<()> {
683 use safetensors::tensor::{Dtype, TensorView};
684 use std::collections::HashMap;
685
686 let mut tensors_data: HashMap<String, TensorView> = HashMap::new();
687 let mut all_tensor_bytes: Vec<(String, Vec<usize>, Vec<u8>)> = Vec::new();
688
689 let mut collect_tensor = |name: &str, tensor: &Tensor| -> Result<()> {
691 let shape = tensor.dims().to_vec();
692 let data = tensor.flatten_all()
693 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?
694 .to_vec1::<f32>()
695 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
696 let bytes: Vec<u8> = data.iter()
697 .flat_map(|&f| f.to_le_bytes())
698 .collect();
699
700 all_tensor_bytes.push((name.to_string(), shape, bytes));
701 Ok(())
702 };
703
704 collect_tensor("actor.fc1.weight", self.actor.fc1.weight())?;
706 if let Some(bias) = self.actor.fc1.bias() {
707 collect_tensor("actor.fc1.bias", bias)?;
708 }
709
710 collect_tensor("actor.ln1.weight", self.actor.ln1.weight())?;
711 if let Some(bias) = self.actor.ln1.bias() {
712 collect_tensor("actor.ln1.bias", bias)?;
713 }
714
715 collect_tensor("actor.fc2.weight", self.actor.fc2.weight())?;
716 if let Some(bias) = self.actor.fc2.bias() {
717 collect_tensor("actor.fc2.bias", bias)?;
718 }
719
720 collect_tensor("actor.ln2.weight", self.actor.ln2.weight())?;
721 if let Some(bias) = self.actor.ln2.bias() {
722 collect_tensor("actor.ln2.bias", bias)?;
723 }
724
725 collect_tensor("actor.fc3.weight", self.actor.fc3.weight())?;
726 if let Some(bias) = self.actor.fc3.bias() {
727 collect_tensor("actor.fc3.bias", bias)?;
728 }
729
730 collect_tensor("actor.ln3.weight", self.actor.ln3.weight())?;
731 if let Some(bias) = self.actor.ln3.bias() {
732 collect_tensor("actor.ln3.bias", bias)?;
733 }
734
735 collect_tensor("actor.action_logits.weight", self.actor.action_logits.weight())?;
736 if let Some(bias) = self.actor.action_logits.bias() {
737 collect_tensor("actor.action_logits.bias", bias)?;
738 }
739
740 collect_tensor("actor.param_mean.weight", self.actor.param_mean.weight())?;
741 if let Some(bias) = self.actor.param_mean.bias() {
742 collect_tensor("actor.param_mean.bias", bias)?;
743 }
744
745 collect_tensor("actor.param_logstd.weight", self.actor.param_logstd.weight())?;
746 if let Some(bias) = self.actor.param_logstd.bias() {
747 collect_tensor("actor.param_logstd.bias", bias)?;
748 }
749
750 collect_tensor("critic.q1_fc1.weight", self.critic.q1_fc1.weight())?;
752 if let Some(bias) = self.critic.q1_fc1.bias() {
753 collect_tensor("critic.q1_fc1.bias", bias)?;
754 }
755
756 collect_tensor("critic.q1_ln1.weight", self.critic.q1_ln1.weight())?;
757 if let Some(bias) = self.critic.q1_ln1.bias() {
758 collect_tensor("critic.q1_ln1.bias", bias)?;
759 }
760
761 collect_tensor("critic.q1_fc2.weight", self.critic.q1_fc2.weight())?;
762 if let Some(bias) = self.critic.q1_fc2.bias() {
763 collect_tensor("critic.q1_fc2.bias", bias)?;
764 }
765
766 collect_tensor("critic.q1_ln2.weight", self.critic.q1_ln2.weight())?;
767 if let Some(bias) = self.critic.q1_ln2.bias() {
768 collect_tensor("critic.q1_ln2.bias", bias)?;
769 }
770
771 collect_tensor("critic.q1_output.weight", self.critic.q1_output.weight())?;
772 if let Some(bias) = self.critic.q1_output.bias() {
773 collect_tensor("critic.q1_output.bias", bias)?;
774 }
775
776 collect_tensor("critic.q2_fc1.weight", self.critic.q2_fc1.weight())?;
777 if let Some(bias) = self.critic.q2_fc1.bias() {
778 collect_tensor("critic.q2_fc1.bias", bias)?;
779 }
780
781 collect_tensor("critic.q2_ln1.weight", self.critic.q2_ln1.weight())?;
782 if let Some(bias) = self.critic.q2_ln1.bias() {
783 collect_tensor("critic.q2_ln1.bias", bias)?;
784 }
785
786 collect_tensor("critic.q2_fc2.weight", self.critic.q2_fc2.weight())?;
787 if let Some(bias) = self.critic.q2_fc2.bias() {
788 collect_tensor("critic.q2_fc2.bias", bias)?;
789 }
790
791 collect_tensor("critic.q2_ln2.weight", self.critic.q2_ln2.weight())?;
792 if let Some(bias) = self.critic.q2_ln2.bias() {
793 collect_tensor("critic.q2_ln2.bias", bias)?;
794 }
795
796 collect_tensor("critic.q2_output.weight", self.critic.q2_output.weight())?;
797 if let Some(bias) = self.critic.q2_output.bias() {
798 collect_tensor("critic.q2_output.bias", bias)?;
799 }
800
801 collect_tensor("log_alpha", self.log_alpha.as_tensor())?;
803
804 for (name, shape, bytes) in &all_tensor_bytes {
806 tensors_data.insert(
807 name.clone(),
808 TensorView::new(Dtype::F32, shape.clone(), bytes)
809 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?
810 );
811 }
812
813 let serialized = safetensors::serialize(&tensors_data, None)
814 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
815
816 std::fs::write(path, serialized)?;
817
818 tracing::info!("SAC model saved to SafeTensors: {} bytes",
819 std::fs::metadata(path).map(|m| m.len()).unwrap_or(0));
820
821 Ok(())
822 }
823
    /// NOTE(review): despite the name, this does NOT emit ONNX — it delegates
    /// to `save_to_file` and writes the project's custom binary format.
    /// Kept for interface compatibility; consider renaming or implementing a
    /// real ONNX export.
    pub fn save_to_onnx_with_metadata(&self, path: &Path, metadata: ModelMetadata) -> Result<()> {
        self.save_to_file(path, metadata)
    }
828}
829
830impl RLAgent for SACAgent {
831 fn select_action(&self, state: &[f32], _epsilon: f32) -> Result<(usize, Vec<f32>)> {
832 let state_tensor = Tensor::from_vec(state.to_vec(), &[1, state.len()], &self.device)?;
833
834 let (action_logits, param_mean, _param_logstd) = self.actor.forward(&state_tensor)
835 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
836
837 let action_probs = softmax(&action_logits, 1)
839 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
840
841 let action_probs_vec = action_probs.to_vec2::<f32>()
842 .map_err(|e| crate::ExtractionError::ModelError(format!("Failed to convert action probs to vec2: {}", e)))?;
843
844 let discrete_action = action_probs_vec.first()
846 .ok_or_else(|| crate::ExtractionError::ModelError("Empty action probabilities".to_string()))?
847 .iter()
848 .enumerate()
849 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
850 .map(|(idx, _)| idx)
851 .unwrap_or(0);
852
853 let param_mean_vec = param_mean.to_vec2::<f32>()
855 .map_err(|e| crate::ExtractionError::ModelError(format!("Failed to convert param mean to vec2: {}", e)))?;
856
857 let continuous_params = param_mean_vec.first()
858 .ok_or_else(|| crate::ExtractionError::ModelError("Empty param mean".to_string()))?
859 .clone();
860
861 Ok((discrete_action, continuous_params))
862 }
863
    /// One SAC training step: critic update (TD target from the target
    /// critic), actor update (entropy-regularized Q maximization), then
    /// temperature update (toward `target_entropy`). Returns the critic loss.
    fn train_step(&mut self, replay_buffer: &mut PrioritizedReplayBuffer, batch_size: usize) -> Result<f32> {
        let batch = replay_buffer.sample(batch_size);
        if batch.is_none() {
            // Not enough experience yet — report zero loss rather than error.
            return Ok(0.0);
        }

        let batch = batch.unwrap();
        let experiences = &batch.experiences;

        if experiences.is_empty() {
            return Ok(0.0);
        }

        // ---- Batch tensors: states, next states, rewards, done flags. ----
        let state_dim = experiences[0].state.len();
        let states_flat: Vec<f32> = experiences.iter().flat_map(|e| e.state.clone()).collect();
        let states = Tensor::from_vec(states_flat, &[experiences.len(), state_dim], &self.device)?;

        let next_states_flat: Vec<f32> = experiences.iter().flat_map(|e| e.next_state.clone()).collect();
        let next_states = Tensor::from_vec(next_states_flat, &[experiences.len(), state_dim], &self.device)?;

        let rewards: Vec<f32> = experiences.iter().map(|e| e.reward).collect();
        let rewards_tensor = Tensor::from_vec(rewards, &[experiences.len()], &self.device)?;

        let dones: Vec<f32> = experiences.iter().map(|e| if e.done { 1.0 } else { 0.0 }).collect();
        let dones_tensor = Tensor::from_vec(dones, &[experiences.len()], &self.device)?;

        // Temperature alpha = exp(log_alpha), read out as a detached scalar —
        // the critic/actor losses intentionally do not backprop into alpha.
        let alpha = self.log_alpha.as_tensor().exp()?;
        let alpha_scalar = if alpha.dims().is_empty() {
            alpha.to_scalar::<f32>()?
        } else {
            alpha.to_vec1::<f32>()?.first().copied().unwrap_or(0.0)
        };

        // ---- Critic update: y = r + (1 - done) * gamma * (min Q' - alpha * log pi'). ----
        let (next_action_discrete, next_action_continuous, next_log_prob) = self.sample_action(&next_states)?;
        let (next_q1, next_q2) = self.target_critic.forward(&next_states, &next_action_discrete, &next_action_continuous)?;
        let next_q = next_q1.minimum(&next_q2)?;

        let batch_size_val = experiences.len();
        let alpha_broadcast = Tensor::from_vec(vec![alpha_scalar; batch_size_val], &[batch_size_val], &self.device)?;
        let gamma_tensor = Tensor::from_vec(vec![self.gamma; batch_size_val], &[batch_size_val], &self.device)?;
        let ones = Tensor::ones(&[batch_size_val], DType::F32, &self.device)?;

        // NOTE(review): target_q is not explicitly detached. The target
        // critic's weights live in a separate varmap so the critic optimizer
        // will not update them, but gradients can still flow into the actor
        // through next_log_prob during critic_loss.backward() — confirm.
        let target_q = (
            &rewards_tensor +
            (&ones - &dones_tensor)?.mul(&gamma_tensor)?.mul(
                &(&next_q - &alpha_broadcast.mul(&next_log_prob)?)?
            )?
        )?;

        // One-hot encode the replayed discrete actions (out-of-range indices
        // yield an all-zero row rather than panicking).
        let actions_discrete: Vec<f32> = experiences.iter()
            .flat_map(|e| {
                let mut onehot = vec![0.0f32; self.num_actions];
                if e.action.0 < self.num_actions {
                    onehot[e.action.0] = 1.0;
                }
                onehot
            })
            .collect();
        let actions_discrete_tensor = Tensor::from_vec(actions_discrete, &[experiences.len(), self.num_actions], &self.device)?;

        let actions_continuous_flat: Vec<f32> = experiences.iter().flat_map(|e| e.action.1.clone()).collect();
        let actions_continuous_tensor = Tensor::from_vec(actions_continuous_flat, &[experiences.len(), self.num_params], &self.device)?;

        let (current_q1, current_q2) = self.critic.forward(&states, &actions_discrete_tensor, &actions_continuous_tensor)?;

        // MSE of both Q heads against the shared TD target.
        let critic_loss = (
            (&current_q1 - &target_q)?.sqr()? +
            (&current_q2 - &target_q)?.sqr()?
        )?.mean_all()?;

        let critic_grads = critic_loss.backward()?;
        self.critic_optimizer.step(&critic_grads)
            .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;

        // ---- Actor update: minimize alpha * log pi - min Q(s, a~pi). ----
        let (sampled_action_discrete, sampled_action_continuous, log_prob) = self.sample_action(&states)?;
        let (q1_new, q2_new) = self.critic.forward(&states, &sampled_action_discrete, &sampled_action_continuous)?;
        let q_new = q1_new.minimum(&q2_new)?;

        let log_prob_size = log_prob.dims()[0];
        let alpha_broadcast_actor = Tensor::from_vec(vec![alpha_scalar; log_prob_size], &[log_prob_size], &self.device)?;
        let actor_loss = (&alpha_broadcast_actor.mul(&log_prob)? - &q_new)?.mean_all()?;

        let actor_grads = actor_loss.backward()?;
        self.actor_optimizer.step(&actor_grads)
            .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;

        // ---- Temperature update: minimize -log_alpha * (log pi + target_entropy). ----
        let target_entropy_tensor = Tensor::from_vec(
            vec![self.target_entropy; log_prob_size],
            &[log_prob_size],
            &self.device
        )?;

        // Detach so the alpha loss does not backprop into the actor.
        let alpha_loss_term = (&log_prob + &target_entropy_tensor)?;
        let alpha_loss_term_detached = alpha_loss_term.detach();

        // NOTE(review): log_alpha is read out as a scalar and rebuilt as a
        // constant tensor below, which severs the autograd link to the
        // log_alpha Var — confirm the alpha optimizer actually receives a
        // nonzero gradient from this loss.
        let log_alpha_tensor = self.log_alpha.as_tensor();
        let log_alpha_scalar = if log_alpha_tensor.dims().is_empty() {
            log_alpha_tensor.to_scalar::<f32>()?
        } else {
            log_alpha_tensor.to_vec1::<f32>()?.first().copied().unwrap_or(0.0)
        };

        let log_alpha_broadcast = Tensor::from_vec(
            vec![log_alpha_scalar; log_prob_size],
            &[log_prob_size],
            &self.device
        )?;

        let alpha_loss = (&log_alpha_broadcast.neg()? * &alpha_loss_term_detached)?.mean_all()?;

        let alpha_grads = alpha_loss.backward()?;
        self.alpha_optimizer.step(&alpha_grads)
            .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;

        // Target network maintenance (currently a no-op; see soft_update_target).
        self.soft_update_target()?;

        self.step_count += 1;

        Ok(critic_loss.to_scalar::<f32>()?)
    }
997
    /// Required by the `RLAgent` trait; intentionally empty — SAC performs its
    /// (throttled) target maintenance inside `train_step` via
    /// `soft_update_target`.
    fn update_target_network(&mut self) {
    }
1001
    /// Number of completed training steps.
    fn get_step_count(&self) -> usize {
        self.step_count
    }
1005
    /// Saves the model in both the custom binary format (at `path`) and
    /// SafeTensors (same path with a `.safetensors` extension).
    ///
    /// NOTE(review): the state dimension in the metadata is hard-coded to 300;
    /// the agent does not store its own state_dim, so confirm this matches the
    /// dimension the agent was constructed with.
    fn save_with_metadata(
        &self,
        path: &Path,
        training_episodes: usize,
        hyperparameters: HashMap<String, f64>,
    ) -> Result<()> {
        let metadata = ModelMetadata::new(
            300, // state_dim — hard-coded; TODO confirm against the live agent
            self.num_actions,
            self.num_params,
            AlgorithmType::SAC, training_episodes,
            hyperparameters,
        );

        // "ONNX" is a misnomer — this writes the custom binary format; see
        // save_to_onnx_with_metadata.
        self.save_to_onnx_with_metadata(path, metadata)?;

        let safetensors_path = path.with_extension("safetensors");
        self.save_to_safetensors(&safetensors_path)?;

        tracing::info!("SAC model saved with metadata: ONNX ({} bytes), SafeTensors ({} bytes)",
            std::fs::metadata(path).map(|m| m.len()).unwrap_or(0),
            std::fs::metadata(&safetensors_path).map(|m| m.len()).unwrap_or(0));

        Ok(())
    }
1034
    /// Convenience save: zero episodes and no hyperparameters recorded.
    fn save(&self, path: &Path) -> Result<()> {
        self.save_with_metadata(path, 0, HashMap::new())
    }
1038
    /// Identifies this agent as Soft Actor-Critic.
    fn algorithm_type(&self) -> AlgorithmType {
        AlgorithmType::SAC
    }
1042
    /// Static description of the agent for introspection.
    ///
    /// NOTE(review): `num_parameters` and `state_dim` are placeholder zeros,
    /// not the real values — the agent does not track either.
    fn get_info(&self) -> AgentInfo {
        AgentInfo {
            algorithm: AlgorithmType::SAC,
            num_parameters: 0, // placeholder — not computed
            state_dim: 0,      // placeholder — actual dimension not stored
            num_actions: self.num_actions,
            continuous_params: self.num_params,
            version: "1.0.0".to_string(),
            features: vec![
                "twin_q".to_string(),
                "entropy_regularization".to_string(),
                "automatic_temperature".to_string(),
                "off_policy".to_string(),
            ],
        }
    }
1059
1060}