//! Soft Actor-Critic network
// ============================================================================
// FILE: crates/content-extractor-rl/src/agents/sac_agent.rs
// ============================================================================

use candle_core::{Device, Tensor, DType, Var};
use candle_nn::{VarBuilder, Optimizer, AdamW, ParamsAdamW, VarMap, Linear, Module, linear, layer_norm, LayerNorm};
use crate::replay_buffer::PrioritizedReplayBuffer;
use crate::{Result, agents::{RLAgent, AlgorithmType, AgentInfo}};
use tracing::info;
use std::path::Path;
use std::collections::HashMap;
use crate::models::ModelMetadata;
use candle_nn::ops::softmax;

/// Actor network for SAC (outputs discrete action logits plus continuous
/// parameter mean and log_std)
#[allow(dead_code)]
pub struct SACActorNetwork {
    fc1: Linear,
    ln1: LayerNorm,
    fc2: Linear,
    ln2: LayerNorm,
    fc3: Linear,
    ln3: LayerNorm,
    // Discrete action head
    action_logits: Linear,

    // Continuous parameter heads
    param_mean: Linear,
    param_logstd: Linear,

    device: Device,
    num_actions: usize,
    num_params: usize,
}

impl SACActorNetwork {
    pub fn new(
        state_dim: usize,
        num_actions: usize,
        num_params: usize,
        vb: VarBuilder,
    ) -> candle_core::error::Result<Self> {
        let device = vb.device().clone();
        let fc1 = linear(state_dim, 512, vb.pp("fc1"))?;
        let ln1 = layer_norm(512, 1e-5, vb.pp("ln1"))?;
        let fc2 = linear(512, 256, vb.pp("fc2"))?;
        let ln2 = layer_norm(256, 1e-5, vb.pp("ln2"))?;
        let fc3 = linear(256, 128, vb.pp("fc3"))?;
        let ln3 = layer_norm(128, 1e-5, vb.pp("ln3"))?;

        let action_logits = linear(128, num_actions, vb.pp("action_logits"))?;
        let param_mean = linear(128, num_params, vb.pp("param_mean"))?;
        let param_logstd = linear(128, num_params, vb.pp("param_logstd"))?;

        Ok(Self {
            fc1, ln1, fc2, ln2, fc3, ln3,
            action_logits,
            param_mean,
            param_logstd,
            device,
            num_actions,
            num_params,
        })
    }

    pub fn forward(&self, state: &Tensor) -> candle_core::error::Result<(Tensor, Tensor, Tensor)> {
        let mut x = self.fc1.forward(state)?;
        x = self.ln1.forward(&x)?;
        x = x.relu()?;

        x = self.fc2.forward(&x)?;
        x = self.ln2.forward(&x)?;
        x = x.relu()?;

        x = self.fc3.forward(&x)?;
        x = self.ln3.forward(&x)?;
        let features = x.relu()?;

        let action_logits = self.action_logits.forward(&features)?;
        let param_mean = self.param_mean.forward(&features)?.tanh()?;
        let param_logstd = self.param_logstd.forward(&features)?.clamp(-20.0, 2.0)?;

        Ok((action_logits, param_mean, param_logstd))
    }
}
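
// Shape contract for `SACActorNetwork::forward` (batch size B), as implied by
// the layer definitions above:
//   state         -> [B, state_dim]
//   action_logits -> [B, num_actions]
//   param_mean    -> [B, num_params], tanh-squashed into [-1, 1]
//   param_logstd  -> [B, num_params], clamped to [-20, 2]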

/// Twin Q-network for SAC
#[allow(dead_code)]
pub struct SACCriticNetwork {
    // Q1 network
    q1_fc1: Linear,
    q1_ln1: LayerNorm,
    q1_fc2: Linear,
    q1_ln2: LayerNorm,
    q1_output: Linear,
    // Q2 network (twin)
    q2_fc1: Linear,
    q2_ln1: LayerNorm,
    q2_fc2: Linear,
    q2_ln2: LayerNorm,
    q2_output: Linear,

    num_actions: usize,
    num_params: usize,
}

impl SACCriticNetwork {
    pub fn new(
        state_dim: usize,
        num_actions: usize,
        num_params: usize,
        vb: VarBuilder,
    ) -> candle_core::error::Result<Self> {
        // Combined state-action dimension
        let input_dim = state_dim + num_actions + num_params;
        // Q1 network
        let q1_fc1 = linear(input_dim, 512, vb.pp("q1_fc1"))?;
        let q1_ln1 = layer_norm(512, 1e-5, vb.pp("q1_ln1"))?;
        let q1_fc2 = linear(512, 256, vb.pp("q1_fc2"))?;
        let q1_ln2 = layer_norm(256, 1e-5, vb.pp("q1_ln2"))?;
        let q1_output = linear(256, 1, vb.pp("q1_output"))?;

        // Q2 network
        let q2_fc1 = linear(input_dim, 512, vb.pp("q2_fc1"))?;
        let q2_ln1 = layer_norm(512, 1e-5, vb.pp("q2_ln1"))?;
        let q2_fc2 = linear(512, 256, vb.pp("q2_fc2"))?;
        let q2_ln2 = layer_norm(256, 1e-5, vb.pp("q2_ln2"))?;
        let q2_output = linear(256, 1, vb.pp("q2_output"))?;

        Ok(Self {
            q1_fc1, q1_ln1, q1_fc2, q1_ln2, q1_output,
            q2_fc1, q2_ln1, q2_fc2, q2_ln2, q2_output,
            num_actions,
            num_params,
        })
    }

    pub fn forward(
        &self,
        state: &Tensor,
        action_discrete: &Tensor,
        action_continuous: &Tensor,
    ) -> candle_core::error::Result<(Tensor, Tensor)> {
        // Concatenate state and actions
        let state_action = Tensor::cat(&[state, action_discrete, action_continuous], 1)?;

        // Q1 forward
        let mut x1 = self.q1_fc1.forward(&state_action)?;
        x1 = self.q1_ln1.forward(&x1)?;
        x1 = x1.relu()?;
        x1 = self.q1_fc2.forward(&x1)?;
        x1 = self.q1_ln2.forward(&x1)?;
        x1 = x1.relu()?;
        let q1 = self.q1_output.forward(&x1)?.squeeze(1)?;

        // Q2 forward
        let mut x2 = self.q2_fc1.forward(&state_action)?;
        x2 = self.q2_ln1.forward(&x2)?;
        x2 = x2.relu()?;
        x2 = self.q2_fc2.forward(&x2)?;
        x2 = self.q2_ln2.forward(&x2)?;
        x2 = x2.relu()?;
        let q2 = self.q2_output.forward(&x2)?.squeeze(1)?;

        Ok((q1, q2))
    }
}
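
// Shape contract for `SACCriticNetwork::forward` (batch size B):
//   concatenated input -> [B, state_dim + num_actions + num_params]
//   q1, q2             -> [B] (squeezed from [B, 1])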

/// SAC Agent with automatic entropy tuning
pub struct SACAgent {
    actor: SACActorNetwork,
    critic: SACCriticNetwork,
    target_critic: SACCriticNetwork,
    actor_optimizer: AdamW,
    critic_optimizer: AdamW,

    // Automatic temperature tuning
    log_alpha: Var,
    alpha_optimizer: AdamW,
    target_entropy: f32,

    #[allow(dead_code)]
    actor_varmap: VarMap,
    #[allow(dead_code)]
    critic_varmap: VarMap,
    #[allow(dead_code)]
    alpha_varmap: VarMap,

    num_actions: usize,
    num_params: usize,
    gamma: f32,
    tau: f32,  // Soft update coefficient
    step_count: usize,
    device: Device,
}

fn save_linear_helper(
    tensors: &mut HashMap<String, (Vec<usize>, Vec<f32>)>,
    name: &str,
    linear: &Linear
) -> Result<()> {
    let weight = linear.weight();
    let weight_shape = weight.dims().to_vec();
    let weight_data = weight.flatten_all()
        .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?
        .to_vec1::<f32>()
        .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
    tensors.insert(format!("{}.weight", name), (weight_shape, weight_data));

    if let Some(bias) = linear.bias() {
        let bias_shape = bias.dims().to_vec();
        let bias_data = bias.flatten_all()
            .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?
            .to_vec1::<f32>()
            .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
        tensors.insert(format!("{}.bias", name), (bias_shape, bias_data));
    }
    Ok(())
}

fn save_layernorm_helper(
    tensors: &mut HashMap<String, (Vec<usize>, Vec<f32>)>,
    name: &str,
    ln: &LayerNorm
) -> Result<()> {
    let weight = ln.weight();
    let weight_shape = weight.dims().to_vec();
    let weight_data = weight.flatten_all()
        .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?
        .to_vec1::<f32>()
        .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
    tensors.insert(format!("{}.weight", name), (weight_shape, weight_data));

    if let Some(bias) = ln.bias() {
        let bias_shape = bias.dims().to_vec();
        let bias_data = bias.flatten_all()
            .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?
            .to_vec1::<f32>()
            .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
        tensors.insert(format!("{}.bias", name), (bias_shape, bias_data));
    }
    Ok(())
}

/// Helper to perform soft update between two linear layers
fn soft_update_linear(
    target: &Linear,
    source: &Linear,
    _tau: f32,
    _device: &Device,
) -> candle_core::error::Result<()> {
    // Soft update: target = tau * source + (1 - tau) * target
    // Note: this is a placeholder. `Linear` only exposes its weights as plain
    // `Tensor`s, which cannot be mutated in place; a real interpolation needs
    // access to the underlying `Var`s (see the VarMap-based sketch below).

    let _source_weight = source.weight();
    let _target_weight = target.weight();

    // TODO: interpolate weights once the target parameters are held as `Var`s.
    // For now, this is a no-op.

    Ok(())
}

/// Helper to perform soft update between two layer norms
fn soft_update_layernorm(
    target: &LayerNorm,
    source: &LayerNorm,
    _tau: f32,
    _device: &Device,
) -> candle_core::error::Result<()> {
    let _source_weight = source.weight();
    let _target_weight = target.weight();

    // TODO: interpolate weights once the target parameters are held as `Var`s.

    Ok(())
}
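
// A minimal sketch of a working Polyak update, assuming the online and target
// critics were built from one shared VarMap under hypothetical "online." /
// "target." key prefixes (not the current wiring of this file). candle's
// `Var::set` mutates parameters in place, so no network rebuild is needed.
#[allow(dead_code)]
fn polyak_update_varmap(varmap: &VarMap, tau: f64) -> candle_core::Result<()> {
    let vars = varmap.data().lock().unwrap();
    for (name, online) in vars.iter() {
        let Some(suffix) = name.strip_prefix("online.") else { continue };
        if let Some(target) = vars.get(&format!("target.{suffix}")) {
            // target <- tau * online + (1 - tau) * target
            let blended = ((online.as_tensor() * tau)? + (target.as_tensor() * (1.0 - tau))?)?;
            target.set(&blended)?;
        }
    }
    Ok(())
}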

impl SACAgent {
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        state_dim: usize,
        num_actions: usize,
        num_params: usize,
        gamma: f32,
        lr: f64,
        device: &Device,
        actor_varmap: VarMap,
        critic_varmap: VarMap,
    ) -> Result<Self> {
        // Create actor
        let actor_vb = VarBuilder::from_varmap(&actor_varmap, DType::F32, device);
        let actor = SACActorNetwork::new(state_dim, num_actions, num_params, actor_vb)
            .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;

        // Create critic and target critic
        let critic_vb = VarBuilder::from_varmap(&critic_varmap, DType::F32, device);
        let critic = SACCriticNetwork::new(state_dim, num_actions, num_params, critic_vb.pp("online"))
            .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;

        // NOTE: the target critic gets its own VarMap and therefore its own
        // random initialization; it is never hard-copied from the online critic
        // (see soft_update_target, whose interpolation is still a placeholder).
        let target_critic_varmap = VarMap::new();
        let target_vb = VarBuilder::from_varmap(&target_critic_varmap, DType::F32, device);
        let target_critic = SACCriticNetwork::new(state_dim, num_actions, num_params, target_vb.pp("target"))
            .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;

        // Initialize temperature (alpha) for entropy regularization - ensure F32 dtype
        let alpha_varmap = VarMap::new();
        // Create with explicit F32 dtype
        let log_alpha_init = Tensor::zeros(&[], DType::F32, device)?;
        let log_alpha = Var::from_tensor(&log_alpha_init)?;

        // Target entropy: -dim(action_space)
        let target_entropy = -(num_actions as f32 + num_params as f32);

        // Create optimizers
        let actor_params = ParamsAdamW { lr, beta1: 0.9, beta2: 0.999, eps: 1e-8, weight_decay: 0.0 };
        let actor_optimizer = AdamW::new(actor_varmap.all_vars(), actor_params)
            .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;

        let critic_params = ParamsAdamW { lr, beta1: 0.9, beta2: 0.999, eps: 1e-8, weight_decay: 0.0 };
        let critic_optimizer = AdamW::new(critic_varmap.all_vars(), critic_params)
            .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;

        let alpha_params = ParamsAdamW { lr: lr * 0.1, beta1: 0.9, beta2: 0.999, eps: 1e-8, weight_decay: 0.0 };
        let alpha_optimizer = AdamW::new(vec![log_alpha.clone()], alpha_params)
            .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;

        Ok(Self {
            actor,
            critic,
            target_critic,
            actor_optimizer,
            critic_optimizer,
            log_alpha,
            alpha_optimizer,
            target_entropy,
            actor_varmap,
            critic_varmap,
            alpha_varmap,
            num_actions,
            num_params,
            gamma,
            tau: 0.005,
            step_count: 0,
            device: device.clone(),
        })
    }

    /// Sample action from policy
    fn sample_action(&self, state: &Tensor) -> Result<(Tensor, Tensor, Tensor)> {
        let (action_logits, param_mean, param_logstd) = self.actor.forward(state)?;

        // Sample discrete action (Gumbel-Softmax for differentiability)
        let action_probs = softmax(&action_logits, 1)?;
        let action_discrete_onehot = self.gumbel_softmax(&action_logits, 1.0f32)?;

        // Sample continuous params (reparameterization trick)
        let param_std = param_logstd.exp()?;

        // Use randn noise with proper F32 dtype
        let noise = Tensor::randn(0.0f32, 1.0f32, param_mean.shape(), &self.device)?;
        let action_continuous = (&param_mean + &param_std.mul(&noise)?)?;

        // Calculate log probability for entropy
        let log_prob_discrete = action_probs.log()?.mul(&action_discrete_onehot)?.sum(1)?;
        let log_prob_continuous = self.gaussian_log_prob(&param_mean, &param_std, &action_continuous)?;
        let log_prob = (log_prob_discrete + log_prob_continuous)?;

        Ok((action_discrete_onehot, action_continuous, log_prob))
    }
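
    // Sampling above uses the reparameterization trick for the continuous head,
    //   a_cont = mean + std * eps,  eps ~ N(0, I),
    // and a Gumbel-Softmax relaxation for the discrete head, so both log-probs
    // stay differentiable with respect to the actor parameters.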

    /// Gumbel-Softmax for discrete actions
    fn gumbel_softmax(&self, logits: &Tensor, temperature: f32) -> candle_core::error::Result<Tensor> {
        // Proper Gumbel noise generation with F32
        let uniform = Tensor::rand(0.0f32, 1.0f32, logits.shape(), logits.device())?;

        // Gumbel noise: -log(-log(U))
        let eps = 1e-10f32;
        let gumbel = uniform.clamp(eps, 1.0f32 - eps)?;
        let gumbel = gumbel.log()?.neg()?;
        let gumbel = gumbel.log()?.neg()?;

        // FIXED: Create temperature tensor with proper F32 dtype
        let batch_size = logits.dims()[0];
        let num_actions = logits.dims()[1];
        let temp_tensor = Tensor::from_vec(
            vec![temperature; batch_size * num_actions],
            &[batch_size, num_actions],
            logits.device()
        )?;

        let y = (logits.clone() + gumbel)?.div(&temp_tensor)?;
        softmax(&y, 1)
    }
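
    // A straight-through variant is a common refinement (sketch only; the
    // `one_hot_of_argmax` helper is hypothetical, not defined in this file):
    //   let y_soft = self.gumbel_softmax(&logits, temperature)?;
    //   let y_hard = one_hot_of_argmax(&y_soft)?;
    //   let y = (y_hard.sub(&y_soft.detach())? + y_soft)?; // forward: hard, gradient: soft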

    /// Gaussian log probability
    fn gaussian_log_prob(&self, mean: &Tensor, std: &Tensor, value: &Tensor) -> candle_core::error::Result<Tensor> {
        // Ensure proper shape broadcasting
        let batch_size = mean.dims()[0];
        let num_params = mean.dims()[1];

        // Broadcast std to match mean if needed
        let std_broadcast = if std.dims().len() == 1 {
            std.unsqueeze(0)?.broadcast_as(mean.shape())?
        } else {
            std.clone()
        };

        let variance = std_broadcast.sqr()?;
        let log_std = std_broadcast.log()?;
        let diff = (value - mean)?;

        // Create 2*pi constant with proper F32 dtype
        let pi_constant = Tensor::from_vec(
            vec![2.0f32 * std::f32::consts::PI; batch_size * num_params],
            &[batch_size, num_params],
            mean.device()
        )?;

        // Create half tensor with proper F32 dtype
        let half_tensor = Tensor::from_vec(
            vec![0.5f32; batch_size * num_params],
            &[batch_size, num_params],
            mean.device()
        )?;

        // Per-dimension negative log-likelihood:
        //   0.5 * (diff^2 / var + ln(2*pi) + 2 * log_std)
        let nll = half_tensor.mul(&(
            diff.sqr()?.div(&variance)? +
                pi_constant.log()? +
                log_std.mul(&Tensor::from_vec(
                    vec![2.0f32; batch_size * num_params],
                    &[batch_size, num_params],
                    mean.device()
                )?)?
        )?)?;

        // The log probability is the *negated* NLL, summed over parameter dims;
        // returning the NLL itself would flip the sign of the entropy term
        nll.sum(1)?.neg()
    }

    /// Soft update of target network
    fn soft_update_target(&mut self) -> Result<()> {
        // Soft update: target = tau * online + (1 - tau) * target
        // Note: the helper functions below are currently no-ops, because the
        // target network's weights are not held here as mutable `Var`s.

        let tau = self.tau;
        let device = &self.device;

        // A full implementation would interpolate weights:
        //   target_weight = tau * online_weight + (1 - tau) * target_weight
        // (see the VarMap-based sketch above). Until then we throttle the calls
        // and log periodically.
        if self.step_count.is_multiple_of(100) {
            if self.step_count.is_multiple_of(1000) {
                info!("SAC target network update at step {} (tau={})", self.step_count, tau);
            }

            // Attempt soft update on each layer
            // Note: these remain no-ops until the target weights are exposed as `Var`s
            let _ = soft_update_linear(&self.target_critic.q1_fc1, &self.critic.q1_fc1, tau, device);
            let _ = soft_update_layernorm(&self.target_critic.q1_ln1, &self.critic.q1_ln1, tau, device);
            let _ = soft_update_linear(&self.target_critic.q1_fc2, &self.critic.q1_fc2, tau, device);
            let _ = soft_update_layernorm(&self.target_critic.q1_ln2, &self.critic.q1_ln2, tau, device);
            let _ = soft_update_linear(&self.target_critic.q1_output, &self.critic.q1_output, tau, device);

            let _ = soft_update_linear(&self.target_critic.q2_fc1, &self.critic.q2_fc1, tau, device);
            let _ = soft_update_layernorm(&self.target_critic.q2_ln1, &self.critic.q2_ln1, tau, device);
            let _ = soft_update_linear(&self.target_critic.q2_fc2, &self.critic.q2_fc2, tau, device);
            let _ = soft_update_layernorm(&self.target_critic.q2_ln2, &self.critic.q2_ln2, tau, device);
            let _ = soft_update_linear(&self.target_critic.q2_output, &self.critic.q2_output, tau, device);
        }

        Ok(())
    }

    /// Save SAC model to file with metadata
    pub fn save_to_file(&self, path: &Path, metadata: ModelMetadata) -> Result<()> {
        use std::fs::File;
        use std::io::Write;
        let mut file = File::create(path)?;

        // Write metadata
        let metadata_json = serde_json::to_string(&metadata)
            .map_err(|e| crate::ExtractionError::ParseError(e.to_string()))?;
        let metadata_bytes = metadata_json.as_bytes();
        let metadata_len = metadata_bytes.len() as u64;

        file.write_all(&metadata_len.to_le_bytes())?;
        file.write_all(metadata_bytes)?;

        // Collect all tensors - FIXED: Use helper functions
        let mut tensors: HashMap<String, (Vec<usize>, Vec<f32>)> = HashMap::new();

        // Save actor network
        save_linear_helper(&mut tensors, "actor.fc1", &self.actor.fc1)?;
        save_layernorm_helper(&mut tensors, "actor.ln1", &self.actor.ln1)?;
        save_linear_helper(&mut tensors, "actor.fc2", &self.actor.fc2)?;
        save_layernorm_helper(&mut tensors, "actor.ln2", &self.actor.ln2)?;
        save_linear_helper(&mut tensors, "actor.fc3", &self.actor.fc3)?;
        save_layernorm_helper(&mut tensors, "actor.ln3", &self.actor.ln3)?;
        save_linear_helper(&mut tensors, "actor.action_logits", &self.actor.action_logits)?;
        save_linear_helper(&mut tensors, "actor.param_mean", &self.actor.param_mean)?;
        save_linear_helper(&mut tensors, "actor.param_logstd", &self.actor.param_logstd)?;

        // Save critic network (Q1 and Q2)
        save_linear_helper(&mut tensors, "critic.q1_fc1", &self.critic.q1_fc1)?;
        save_layernorm_helper(&mut tensors, "critic.q1_ln1", &self.critic.q1_ln1)?;
        save_linear_helper(&mut tensors, "critic.q1_fc2", &self.critic.q1_fc2)?;
        save_layernorm_helper(&mut tensors, "critic.q1_ln2", &self.critic.q1_ln2)?;
        save_linear_helper(&mut tensors, "critic.q1_output", &self.critic.q1_output)?;

        save_linear_helper(&mut tensors, "critic.q2_fc1", &self.critic.q2_fc1)?;
        save_layernorm_helper(&mut tensors, "critic.q2_ln1", &self.critic.q2_ln1)?;
        save_linear_helper(&mut tensors, "critic.q2_fc2", &self.critic.q2_fc2)?;
        save_layernorm_helper(&mut tensors, "critic.q2_ln2", &self.critic.q2_ln2)?;
        save_linear_helper(&mut tensors, "critic.q2_output", &self.critic.q2_output)?;

        // Save log_alpha (temperature parameter)
        let log_alpha_tensor = self.log_alpha.as_tensor();
        let log_alpha_shape = log_alpha_tensor.dims().to_vec();
        let log_alpha_data = log_alpha_tensor.flatten_all()
            .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?
            .to_vec1::<f32>()
            .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
        tensors.insert("log_alpha".to_string(), (log_alpha_shape, log_alpha_data));

        // Write tensor count
        let tensor_count = tensors.len() as u64;
        file.write_all(&tensor_count.to_le_bytes())?;

        // Write each tensor
        for (name, (shape, data)) in tensors.iter() {
            let name_bytes = name.as_bytes();
            let name_len = name_bytes.len() as u64;
            file.write_all(&name_len.to_le_bytes())?;
            file.write_all(name_bytes)?;

            let shape_len = shape.len() as u64;
            file.write_all(&shape_len.to_le_bytes())?;
            for &dim in shape {
                file.write_all(&(dim as u64).to_le_bytes())?;
            }

            let data_len = data.len() as u64;
            file.write_all(&data_len.to_le_bytes())?;
            for &value in data {
                file.write_all(&value.to_le_bytes())?;
            }
        }

        let file_size = std::fs::metadata(path)?.len();
        tracing::info!("SAC model saved: {} bytes", file_size);

        Ok(())
    }
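
    // On-disk layout written by `save_to_file` / read by `load_from_file`,
    // all integers little-endian:
    //   [u64 metadata_len][metadata JSON bytes]
    //   [u64 tensor_count]
    //   per tensor: [u64 name_len][name bytes]
    //               [u64 shape_len][u64 dim]*shape_len
    //               [u64 data_len][f32 value]*data_len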

    /// Load SAC model from file
    pub fn load_from_file(
        path: &Path,
        state_dim: usize,
        num_actions: usize,
        num_params: usize,
        device: &Device,
    ) -> Result<Self> {
        use std::fs::File;
        use std::io::Read;

        tracing::info!("Loading SAC model from: {}", path.display());

        let mut file = File::open(path)?;

        // Read metadata
        let mut metadata_len_bytes = [0u8; 8];
        file.read_exact(&mut metadata_len_bytes)?;
        let metadata_len = u64::from_le_bytes(metadata_len_bytes) as usize;
        if metadata_len > 10 * 1024 * 1024 {
            return Err(crate::ExtractionError::ParseError(format!("Invalid model file: metadata length {} is too large", metadata_len)));
        }

        let mut metadata_bytes = vec![0u8; metadata_len];
        file.read_exact(&mut metadata_bytes)?;

        let metadata_json = String::from_utf8(metadata_bytes)
            .map_err(|e| crate::ExtractionError::ParseError(e.to_string()))?;
        let _metadata: ModelMetadata = serde_json::from_str(&metadata_json)
            .map_err(|e| crate::ExtractionError::ParseError(e.to_string()))?;

        tracing::info!("Model metadata loaded, loading tensors...");

        // Read tensor count
        let mut tensor_count_bytes = [0u8; 8];
        file.read_exact(&mut tensor_count_bytes)?;
        let tensor_count = u64::from_le_bytes(tensor_count_bytes) as usize;

        let mut tensors: HashMap<String, Tensor> = HashMap::new();

        for _ in 0..tensor_count {
            let mut name_len_bytes = [0u8; 8];
            file.read_exact(&mut name_len_bytes)?;
            let name_len = u64::from_le_bytes(name_len_bytes) as usize;

            let mut name_bytes = vec![0u8; name_len];
            file.read_exact(&mut name_bytes)?;
            let name = String::from_utf8(name_bytes)
                .map_err(|e| crate::ExtractionError::ParseError(e.to_string()))?;

            let mut shape_len_bytes = [0u8; 8];
            file.read_exact(&mut shape_len_bytes)?;
            let shape_len = u64::from_le_bytes(shape_len_bytes) as usize;

            let mut shape = Vec::with_capacity(shape_len);
            for _ in 0..shape_len {
                let mut dim_bytes = [0u8; 8];
                file.read_exact(&mut dim_bytes)?;
                shape.push(u64::from_le_bytes(dim_bytes) as usize);
            }

            let mut data_len_bytes = [0u8; 8];
            file.read_exact(&mut data_len_bytes)?;
            let data_len = u64::from_le_bytes(data_len_bytes) as usize;

            let mut data = Vec::with_capacity(data_len);
            for _ in 0..data_len {
                let mut value_bytes = [0u8; 4];
                file.read_exact(&mut value_bytes)?;
                data.push(f32::from_le_bytes(value_bytes));
            }

            let tensor = Tensor::from_vec(data, shape.as_slice(), device)
                .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
            tensors.insert(name, tensor);
        }

        tracing::info!("Loaded {} tensors, reconstructing model...", tensors.len());

        // Create varmaps and populate keys by building networks first, then overwrite with loaded values
        let mut actor_varmap = VarMap::new();
        let actor_vb = VarBuilder::from_varmap(&actor_varmap, DType::F32, device);
        let _ = SACActorNetwork::new(state_dim, num_actions, num_params, actor_vb)
            .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;

        let mut critic_varmap = VarMap::new();
        let critic_vb = VarBuilder::from_varmap(&critic_varmap, DType::F32, device);
        let _ = SACCriticNetwork::new(state_dim, num_actions, num_params, critic_vb.pp("online"))
            .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;

        for (name, tensor) in tensors.iter() {
            if name.starts_with("actor.") {
                let actor_name = name.strip_prefix("actor.").unwrap();
                actor_varmap.set_one(actor_name, tensor)
                    .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
            } else if name.starts_with("critic.") {
                // critic is stored under "online." prefix in the varmap
                let critic_name = format!("online.{}", name.strip_prefix("critic.").unwrap());
                critic_varmap.set_one(&critic_name, tensor)
                    .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
            }
            // log_alpha is re-initialized to zero in Self::new
        }

        // Rebuild the agent; gamma and lr fall back to defaults here, since the
        // loaded metadata is not consulted for them
        Self::new(state_dim, num_actions, num_params, 0.95, 3e-4, device, actor_varmap, critic_varmap)
    }

    /// Convenience wrapper around `load_from_file`
    pub fn load_with_device(
        path: &Path,
        state_dim: usize,
        num_actions: usize,
        num_params: usize,
        device: &Device,
    ) -> Result<Self> {
        Self::load_from_file(path, state_dim, num_actions, num_params, device)
    }
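
    // Roundtrip sketch (hypothetical path and dimensions):
    //   agent.save_with_metadata(Path::new("sac_model.bin"), episodes, HashMap::new())?;
    //   let restored = SACAgent::load_from_file(Path::new("sac_model.bin"), 300, 8, 4, &Device::Cpu)?;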

    /// Save to SafeTensors format
    pub fn save_to_safetensors(&self, path: &Path) -> Result<()> {
        use safetensors::tensor::{Dtype, TensorView};
        use std::collections::HashMap;

        let mut tensors_data: HashMap<String, TensorView> = HashMap::new();
        let mut all_tensor_bytes: Vec<(String, Vec<usize>, Vec<u8>)> = Vec::new();

        // Collect all tensors
        let mut collect_tensor = |name: &str, tensor: &Tensor| -> Result<()> {
            let shape = tensor.dims().to_vec();
            let data = tensor.flatten_all()
                .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?
                .to_vec1::<f32>()
                .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
            let bytes: Vec<u8> = data.iter()
                .flat_map(|&f| f.to_le_bytes())
                .collect();

            all_tensor_bytes.push((name.to_string(), shape, bytes));
            Ok(())
        };

        // Actor and critic layers, saved as "<name>.weight" / "<name>.bias"
        let linears: [(&str, &Linear); 12] = [
            ("actor.fc1", &self.actor.fc1),
            ("actor.fc2", &self.actor.fc2),
            ("actor.fc3", &self.actor.fc3),
            ("actor.action_logits", &self.actor.action_logits),
            ("actor.param_mean", &self.actor.param_mean),
            ("actor.param_logstd", &self.actor.param_logstd),
            ("critic.q1_fc1", &self.critic.q1_fc1),
            ("critic.q1_fc2", &self.critic.q1_fc2),
            ("critic.q1_output", &self.critic.q1_output),
            ("critic.q2_fc1", &self.critic.q2_fc1),
            ("critic.q2_fc2", &self.critic.q2_fc2),
            ("critic.q2_output", &self.critic.q2_output),
        ];
        let layer_norms: [(&str, &LayerNorm); 7] = [
            ("actor.ln1", &self.actor.ln1),
            ("actor.ln2", &self.actor.ln2),
            ("actor.ln3", &self.actor.ln3),
            ("critic.q1_ln1", &self.critic.q1_ln1),
            ("critic.q1_ln2", &self.critic.q1_ln2),
            ("critic.q2_ln1", &self.critic.q2_ln1),
            ("critic.q2_ln2", &self.critic.q2_ln2),
        ];

        for (name, layer) in linears {
            collect_tensor(&format!("{name}.weight"), layer.weight())?;
            if let Some(bias) = layer.bias() {
                collect_tensor(&format!("{name}.bias"), bias)?;
            }
        }
        for (name, ln) in layer_norms {
            collect_tensor(&format!("{name}.weight"), ln.weight())?;
            if let Some(bias) = ln.bias() {
                collect_tensor(&format!("{name}.bias"), bias)?;
            }
        }

        // Save log_alpha
        collect_tensor("log_alpha", self.log_alpha.as_tensor())?;

        // Convert to SafeTensors format
        for (name, shape, bytes) in &all_tensor_bytes {
            tensors_data.insert(
                name.clone(),
                TensorView::new(Dtype::F32, shape.clone(), bytes)
                    .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?
            );
        }

        let serialized = safetensors::serialize(&tensors_data, None)
            .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;

        std::fs::write(path, serialized)?;

        tracing::info!("SAC model saved to SafeTensors: {} bytes",
                       std::fs::metadata(path).map(|m| m.len()).unwrap_or(0));

        Ok(())
    }

    /// Save with metadata (wrapper around `save_to_file`; despite the name,
    /// this writes the custom binary format rather than actual ONNX)
    pub fn save_to_onnx_with_metadata(&self, path: &Path, metadata: ModelMetadata) -> Result<()> {
        self.save_to_file(path, metadata)
    }
}
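
// Usage sketch (hypothetical dimensions; the VarMaps own the trainable weights):
//   let actor_vm = VarMap::new();
//   let critic_vm = VarMap::new();
//   let mut agent = SACAgent::new(300, 8, 4, 0.95, 3e-4, &Device::Cpu, actor_vm, critic_vm)?;
//   let (action, params) = agent.select_action(&vec![0.0f32; 300], 0.0)?;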

impl RLAgent for SACAgent {
    fn select_action(&self, state: &[f32], _epsilon: f32) -> Result<(usize, Vec<f32>)> {
        let state_tensor = Tensor::from_vec(state.to_vec(), &[1, state.len()], &self.device)?;

        let (action_logits, param_mean, _param_logstd) = self.actor.forward(&state_tensor)
            .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;

        // For inference, use mean of distributions with proper error handling
        let action_probs = softmax(&action_logits, 1)
            .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;

        let action_probs_vec = action_probs.to_vec2::<f32>()
            .map_err(|e| crate::ExtractionError::ModelError(format!("Failed to convert action probs to vec2: {}", e)))?;

        // Find discrete action with highest probability
        let discrete_action = action_probs_vec.first()
            .ok_or_else(|| crate::ExtractionError::ModelError("Empty action probabilities".to_string()))?
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
            .map(|(idx, _)| idx)
            .unwrap_or(0);

        // Get continuous params with proper error handling
        let param_mean_vec = param_mean.to_vec2::<f32>()
            .map_err(|e| crate::ExtractionError::ModelError(format!("Failed to convert param mean to vec2: {}", e)))?;

        let continuous_params = param_mean_vec.first()
            .ok_or_else(|| crate::ExtractionError::ModelError("Empty param mean".to_string()))?
            .clone();

        Ok((discrete_action, continuous_params))
    }

    fn train_step(&mut self, replay_buffer: &mut PrioritizedReplayBuffer, batch_size: usize) -> Result<f32> {
        let batch = replay_buffer.sample(batch_size);
        if batch.is_none() {
            return Ok(0.0);
        }

        let batch = batch.unwrap();
        let experiences = &batch.experiences;

        if experiences.is_empty() {
            return Ok(0.0);
        }

        // Convert to tensors - all with explicit F32 dtype
        let state_dim = experiences[0].state.len();
        let states_flat: Vec<f32> = experiences.iter().flat_map(|e| e.state.clone()).collect();
        let states = Tensor::from_vec(states_flat, &[experiences.len(), state_dim], &self.device)?;

        let next_states_flat: Vec<f32> = experiences.iter().flat_map(|e| e.next_state.clone()).collect();
        let next_states = Tensor::from_vec(next_states_flat, &[experiences.len(), state_dim], &self.device)?;

        let rewards: Vec<f32> = experiences.iter().map(|e| e.reward).collect();
        let rewards_tensor = Tensor::from_vec(rewards, &[experiences.len()], &self.device)?;

        let dones: Vec<f32> = experiences.iter().map(|e| if e.done { 1.0 } else { 0.0 }).collect();
        let dones_tensor = Tensor::from_vec(dones, &[experiences.len()], &self.device)?;

        // FIXED: Get current alpha (temperature) - ensure F32 dtype
        let alpha = self.log_alpha.as_tensor().exp()?;
        let alpha_scalar = if alpha.dims().is_empty() {
            alpha.to_scalar::<f32>()?
        } else {
            alpha.to_vec1::<f32>()?.first().copied().unwrap_or(0.0)
        };

        // Update critic
        let (next_action_discrete, next_action_continuous, next_log_prob) = self.sample_action(&next_states)?;
        let (next_q1, next_q2) = self.target_critic.forward(&next_states, &next_action_discrete, &next_action_continuous)?;
        let next_q = next_q1.minimum(&next_q2)?;

        // FIXED: All tensors explicitly F32
        let batch_size_val = experiences.len();
        let alpha_broadcast = Tensor::from_vec(vec![alpha_scalar; batch_size_val], &[batch_size_val], &self.device)?;
        let gamma_tensor = Tensor::from_vec(vec![self.gamma; batch_size_val], &[batch_size_val], &self.device)?;
        let ones = Tensor::ones(&[batch_size_val], DType::F32, &self.device)?;

        let target_q = (
            &rewards_tensor +
                (&ones - &dones_tensor)?.mul(&gamma_tensor)?.mul(
                    &(&next_q - &alpha_broadcast.mul(&next_log_prob)?)?
                )?
        )?;

        // Current actions (from experience)
        let actions_discrete: Vec<f32> = experiences.iter()
            .flat_map(|e| {
                let mut onehot = vec![0.0f32; self.num_actions];
                if e.action.0 < self.num_actions {
                    onehot[e.action.0] = 1.0;
                }
                onehot
            })
            .collect();
        let actions_discrete_tensor = Tensor::from_vec(actions_discrete, &[experiences.len(), self.num_actions], &self.device)?;

        let actions_continuous_flat: Vec<f32> = experiences.iter().flat_map(|e| e.action.1.clone()).collect();
        let actions_continuous_tensor = Tensor::from_vec(actions_continuous_flat, &[experiences.len(), self.num_params], &self.device)?;

        let (current_q1, current_q2) = self.critic.forward(&states, &actions_discrete_tensor, &actions_continuous_tensor)?;

        let critic_loss = (
            (&current_q1 - &target_q)?.sqr()? +
                (&current_q2 - &target_q)?.sqr()?
        )?.mean_all()?;

        // Backward and update critic
        let critic_grads = critic_loss.backward()?;
        self.critic_optimizer.step(&critic_grads)
            .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;

        // Update actor
        let (sampled_action_discrete, sampled_action_continuous, log_prob) = self.sample_action(&states)?;
        let (q1_new, q2_new) = self.critic.forward(&states, &sampled_action_discrete, &sampled_action_continuous)?;
        let q_new = q1_new.minimum(&q2_new)?;

        // FIXED: Broadcast alpha for actor loss - explicit F32
        let log_prob_size = log_prob.dims()[0];
        let alpha_broadcast_actor = Tensor::from_vec(vec![alpha_scalar; log_prob_size], &[log_prob_size], &self.device)?;
        let actor_loss = (&alpha_broadcast_actor.mul(&log_prob)? - &q_new)?.mean_all()?;

        let actor_grads = actor_loss.backward()?;
        self.actor_optimizer.step(&actor_grads)
            .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;

        // Update temperature (alpha)
        // Broadcast target_entropy to match log_prob shape - explicit F32
        let target_entropy_tensor = Tensor::from_vec(
            vec![self.target_entropy; log_prob_size],
            &[log_prob_size],
            &self.device
        )?;

        // Detach the entropy term so the alpha loss only trains log_alpha
        let alpha_loss_term = (&log_prob + &target_entropy_tensor)?;
        let alpha_loss_term_detached = alpha_loss_term.detach();

        // Broadcast log_alpha via broadcast_as to keep it in the autograd graph;
        // rebuilding the tensor from an extracted f32 scalar detaches it from
        // the log_alpha Var, zeroing its gradient and making this update a no-op
        let log_alpha_broadcast = self.log_alpha.as_tensor().broadcast_as(&[log_prob_size])?;

        let alpha_loss = (&log_alpha_broadcast.neg()? * &alpha_loss_term_detached)?.mean_all()?;

        let alpha_grads = alpha_loss.backward()?;
        self.alpha_optimizer.step(&alpha_grads)
            .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;

        // Soft update target network
        self.soft_update_target()?;

        self.step_count += 1;

        Ok(critic_loss.to_scalar::<f32>()?)
    }
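
    // The three updates above follow the standard SAC objectives:
    //   J_Q     = E[ (Q_i(s,a) - (r + gamma * (1 - d) * (min_j Q'_j(s',a') - alpha * log pi(a'|s'))))^2 ]
    //   J_pi    = E[ alpha * log pi(a|s) - min_i Q_i(s,a) ]
    //   J_alpha = E[ -log_alpha * (log pi(a|s) + target_entropy) ]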

    fn update_target_network(&mut self) {
        // SAC uses soft updates, called in train_step
    }

    fn get_step_count(&self) -> usize {
        self.step_count
    }

    fn save_with_metadata(
        &self,
        path: &Path,
        training_episodes: usize,
        hyperparameters: HashMap<String, f64>,
    ) -> Result<()> {
        let metadata = ModelMetadata::new(
            300,
            self.num_actions,
            self.num_params,
            AlgorithmType::SAC,  // FIXED: Was PPO, should be SAC
            training_episodes,
            hyperparameters,
        );

        // Save ONNX with metadata
        self.save_to_onnx_with_metadata(path, metadata)?;

        // Save SafeTensors
        let safetensors_path = path.with_extension("safetensors");
        self.save_to_safetensors(&safetensors_path)?;

        tracing::info!("SAC model saved with metadata: ONNX ({} bytes), SafeTensors ({} bytes)",
               std::fs::metadata(path).map(|m| m.len()).unwrap_or(0),
               std::fs::metadata(&safetensors_path).map(|m| m.len()).unwrap_or(0));

        Ok(())
    }

    fn save(&self, path: &Path) -> Result<()> {
        self.save_with_metadata(path, 0, HashMap::new())
    }

    fn algorithm_type(&self) -> AlgorithmType {
        AlgorithmType::SAC
    }

    fn get_info(&self) -> AgentInfo {
        AgentInfo {
            algorithm: AlgorithmType::SAC,
            num_parameters: 0,
            state_dim: 0,
            num_actions: self.num_actions,
            continuous_params: self.num_params,
            version: "1.0.0".to_string(),
            features: vec![
                "twin_q".to_string(),
                "entropy_regularization".to_string(),
                "automatic_temperature".to_string(),
                "off_policy".to_string(),
            ],
        }
    }
}