// ============================================================================
// FILE: crates/content-extractor-rl/src/models.rs
// ============================================================================

use std::collections::HashMap;
use std::path::Path;

use candle_core::{DType, Device, Result as CandleResult, Tensor, Var};
use candle_nn::{layer_norm, linear, LayerNorm, Linear, Module, VarBuilder};
use chrono;
use safetensors::tensor::{Dtype, TensorView};
use safetensors::SafeTensors;
use serde::{Deserialize, Serialize};
use tracing::{error, info, warn};

use crate::agents::AlgorithmType;

/// Configuration for neural network architecture.
///
/// Describes the shared feature encoder plus the dueling value/advantage
/// heads used by `DuelingNetwork`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NetworkConfig {
    /// Dimensionality of the input state vector.
    pub state_dim: usize,
    /// Number of discrete actions (width of the Q-value output).
    pub num_actions: usize,
    /// Number of continuous action parameters.
    pub num_params: usize,

    // Configurable architecture
    /// Hidden layer sizes for the shared encoder, e.g. [512, 256, 128].
    pub hidden_layers: Vec<usize>,
    /// Whether to apply layer normalization after each encoder layer.
    pub use_layer_norm: bool,
    /// Dropout probability applied between encoder layers during training.
    pub dropout: f32,

    // Value and advantage stream sizes
    /// Hidden width of the value stream, e.g. 64.
    pub value_hidden: usize,
    /// Hidden width of the advantage stream, e.g. 64.
    pub advantage_hidden: usize,
}

33impl Default for NetworkConfig {
34    fn default() -> Self {
35        Self {
36            state_dim: 300,
37            num_actions: 16,
38            num_params: 6,
39            hidden_layers: vec![512, 256, 128],
40            use_layer_norm: true,
41            dropout: 0.1,
42            value_hidden: 64,
43            advantage_hidden: 64,
44        }
45    }
46}
47
/// Enhanced model metadata with algorithm and hyperparameters.
///
/// Serialized as a length-prefixed JSON header at the start of saved model
/// files (see `DuelingDQN::save_to_onnx_with_metadata`).
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ModelMetadata {
    /// Dimensionality of the input state vector.
    pub state_dim: usize,
    /// Number of discrete actions.
    pub num_actions: usize,
    /// Number of continuous action parameters.
    pub num_params: usize,
    /// Human-readable architecture name (currently mirrors `algorithm`).
    pub architecture: String,
    /// Algorithm type used for training (stringified `AlgorithmType`).
    pub algorithm: String,
    /// Metadata/model format version string.
    pub version: String,
    /// RFC 3339 timestamp of when the model was trained.
    pub training_date: String,
    /// Number of training episodes completed.
    pub training_episodes: usize,
    /// Hyperparameters used during training, keyed by name.
    pub hyperparameters: HashMap<String, f64>,
}

62impl ModelMetadata {
63    /// Create new metadata
64    pub fn new(
65        state_dim: usize,
66        num_actions: usize,
67        num_params: usize,
68        algorithm: AlgorithmType,
69        training_episodes: usize,
70        hyperparameters: HashMap<String, f64>,
71    ) -> Self {
72        Self {
73            state_dim,
74            num_actions,
75            num_params,
76            architecture: algorithm.to_string(),
77            algorithm: algorithm.to_string(),
78            version: "1.0.0".to_string(),
79            training_date: chrono::Utc::now().to_rfc3339(),
80            training_episodes,
81            hyperparameters,
82        }
83    }
84
85    /// Load metadata from model file without loading full model
86    pub fn load_metadata(path: &Path) -> candle_core::error::Result<ModelMetadata> {
87        use std::fs::File;
88        use std::io::Read;
89
90        let mut file = File::open(path)
91            .map_err(candle_core::Error::Io)?;
92
93        let mut metadata_len_bytes = [0u8; 8];
94        file.read_exact(&mut metadata_len_bytes)
95            .map_err(candle_core::Error::Io)?;
96        let metadata_len = u64::from_le_bytes(metadata_len_bytes) as usize;
97
98        let mut metadata_bytes = vec![0u8; metadata_len];
99        file.read_exact(&mut metadata_bytes)
100            .map_err(candle_core::Error::Io)?;
101
102        let metadata_json = String::from_utf8(metadata_bytes)
103            .map_err(|e| candle_core::Error::Msg(e.to_string()))?;
104
105        let metadata: ModelMetadata = serde_json::from_str(&metadata_json)
106            .map_err(|e| candle_core::Error::Msg(e.to_string()))?;
107
108        Ok(metadata)
109    }
110
111    /// Display metadata in formatted way
112    pub fn display(&self) {
113        info!("╔════════════════════════════════════════════════════════════╗");
114        info!("║                    MODEL METADATA                          ║");
115        info!("╠════════════════════════════════════════════════════════════╣");
116        info!("║ Algorithm: {:<47} ║", self.algorithm);
117        info!("║ Architecture: {:<44} ║", self.architecture);
118        info!("║ Version: {:<49} ║", self.version);
119        info!("║ Training Date: {:<43} ║", self.training_date);
120        info!("║ Training Episodes: {:<39} ║", self.training_episodes);
121        info!("║ State Dim: {:<47} ║", self.state_dim);
122        info!("║ Num Actions: {:<45} ║", self.num_actions);
123        info!("║ Num Params: {:<46} ║", self.num_params);
124        if !self.hyperparameters.is_empty() {
125            info!("╠════════════════════════════════════════════════════════════╣");
126            info!("║                    HYPERPARAMETERS                         ║");
127            info!("╠════════════════════════════════════════════════════════════╣");
128            for (key, value) in &self.hyperparameters {
129                info!("║ {:<30} {:>27.6} ║", key, value);
130            }
131        }
132        info!("╚════════════════════════════════════════════════════════════╝");
133    }
134}
135
/// Generic neural network with dueling architecture.
/// Can be used by any RL algorithm (DQN, PPO, SAC, etc.)
#[derive(Debug)]
#[allow(dead_code)]
pub struct DuelingNetwork {
    // Shared feature encoder: one Linear per entry in
    // `NetworkConfig::hidden_layers`.
    feature_layers: Vec<Linear>,
    // Parallel to `feature_layers`; `None` when layer norm is disabled.
    layer_norms: Vec<Option<LayerNorm>>,
    // Dropout probability applied between encoder layers during training.
    dropout: f32,

    // Value stream (for DQN, A2C, PPO critic)
    value_layers: Vec<Linear>,

    // Advantage/Action stream (for DQN, or actor in policy gradient)
    advantage_layers: Vec<Linear>,

    // Continuous parameter head (for hybrid action spaces)
    param_mean: Linear,
    // Learnable, state-independent log std-dev of the continuous parameters.
    param_logstd: Var,

    device: Device,
    config: NetworkConfig,
}


161impl DuelingNetwork {
162    /// Create network from configuration
163    pub fn new(config: NetworkConfig, vb: VarBuilder) -> CandleResult<Self> {
164        let device = vb.device().clone();
165
166        // Build feature encoder layers
167        let mut feature_layers = Vec::new();
168        let mut layer_norms = Vec::new();
169
170        let mut input_dim = config.state_dim;
171        for (i, &hidden_size) in config.hidden_layers.iter().enumerate() {
172            let layer = linear(input_dim, hidden_size, vb.pp(format!("fc{}", i + 1)))?;
173            feature_layers.push(layer);
174
175            if config.use_layer_norm {
176                let ln = layer_norm(hidden_size, 1e-5, vb.pp(format!("ln{}", i + 1)))?;
177                layer_norms.push(Some(ln));
178            } else {
179                layer_norms.push(None);
180            }
181
182            input_dim = hidden_size;
183        }
184
185        let final_feature_size = *config.hidden_layers.last().unwrap_or(&128);
186
187        // Value stream
188        let value_layers = vec![
189            linear(final_feature_size, config.value_hidden, vb.pp("value_fc1"))?,
190            linear(config.value_hidden, 1, vb.pp("value_fc2"))?,
191        ];
192
193        // Advantage stream
194        let advantage_layers = vec![
195            linear(final_feature_size, config.advantage_hidden, vb.pp("advantage_fc1"))?,
196            linear(config.advantage_hidden, config.num_actions, vb.pp("advantage_fc2"))?,
197        ];
198
199        // Continuous parameter head
200        let param_mean = linear(final_feature_size, config.num_params, vb.pp("param_mean"))?;
201        let param_logstd_init = Tensor::from_vec(
202            vec![-1.0f32; config.num_params],
203            &[config.num_params],
204            &device
205        )?;
206        let param_logstd = Var::from_tensor(&param_logstd_init)?;
207
208        Ok(Self {
209            feature_layers,
210            layer_norms,
211            dropout: config.dropout,
212            value_layers,
213            advantage_layers,
214            param_mean,
215            param_logstd,
216            device,
217            config,
218        })
219    }
220
221    /// Forward pass through network
222    pub fn forward(&self, state: &Tensor, training: bool) -> CandleResult<(Tensor, Tensor, Tensor)> {
223        // Feature extraction
224        let mut x = state.clone();
225
226        for (i, layer) in self.feature_layers.iter().enumerate() {
227            x = layer.forward(&x)?;
228
229            if let Some(Some(ln)) = self.layer_norms.get(i) {
230                x = ln.forward(&x)?;
231            }
232
233            x = x.relu()?;
234
235            if training && self.dropout > 0.0 {
236                x = candle_nn::ops::dropout(&x, self.dropout)?;
237            }
238        }
239
240        let features = x;
241
242        // Value stream
243        let mut value = self.value_layers[0].forward(&features)?;
244        value = value.relu()?;
245        let value = self.value_layers[1].forward(&value)?;
246
247        // Advantage stream
248        let mut advantages = self.advantage_layers[0].forward(&features)?;
249        advantages = advantages.relu()?;
250        let advantages = self.advantage_layers[1].forward(&advantages)?;
251
252        // Combine: Q(s,a) = V(s) + (A(s,a) - mean(A(s,a)))
253        let advantage_mean = advantages.mean_keepdim(1)?;
254        let q_values = value
255            .broadcast_add(&advantages)?
256            .broadcast_sub(&advantage_mean)?;
257
258        // Continuous parameters
259        let param_mean = self.param_mean.forward(&features)?.tanh()?;
260        let param_std = self.param_logstd.as_tensor().exp()?;
261
262        Ok((q_values, param_mean, param_std))
263    }
264
265    /// Get network configuration
266    pub fn get_config(&self) -> &NetworkConfig {
267        &self.config
268    }
269}
270
/// Dueling DQN network architecture.
///
/// Fixed-size variant of `DuelingNetwork`: a 512 → 256 → 128 encoder with
/// layer norm, 64-wide value/advantage streams, and a continuous-parameter
/// head for hybrid action spaces.
#[derive(Debug)]
pub struct DuelingDQN {
    // Feature encoder
    fc1: Linear,
    ln1: LayerNorm,
    fc2: Linear,
    ln2: LayerNorm,
    fc3: Linear,
    ln3: LayerNorm,
    // Dropout probability applied after each encoder layer during training.
    dropout: f32,

    // Value stream
    value_fc1: Linear,
    value_fc2: Linear,

    // Advantage stream
    advantage_fc1: Linear,
    advantage_fc2: Linear,

    // Continuous parameter head
    param_mean: Linear,
    // Learnable, state-independent log std-dev of the continuous parameters.
    param_logstd: Var,

    device: Device,
    state_dim: usize,
    num_actions: usize,
    num_params: usize,
}

301// Helper functions for saving model
302fn save_linear(
303    name: &str,
304    linear: &Linear,
305    tensors: &mut HashMap<String, (Vec<usize>, Vec<f32>)>
306) -> CandleResult<()> {
307    let weight = linear.weight();
308    let weight_shape = weight.dims().to_vec();
309    let weight_data = weight.flatten_all()?.to_vec1::<f32>()?;
310    tensors.insert(format!("{}.weight", name), (weight_shape, weight_data));
311
312    if let Some(bias) = linear.bias() {
313        let bias_shape = bias.dims().to_vec();
314        let bias_data = bias.flatten_all()?.to_vec1::<f32>()?;
315        tensors.insert(format!("{}.bias", name), (bias_shape, bias_data));
316    }
317    Ok(())
318}
319
320fn save_layernorm(
321    name: &str,
322    ln: &LayerNorm,
323    tensors: &mut HashMap<String, (Vec<usize>, Vec<f32>)>
324) -> CandleResult<()> {
325    let weight = ln.weight();
326    let weight_shape = weight.dims().to_vec();
327    let weight_data = weight.flatten_all()?.to_vec1::<f32>()?;
328    tensors.insert(format!("{}.weight", name), (weight_shape, weight_data));
329
330    if let Some(bias) = ln.bias() {
331        let bias_shape = bias.dims().to_vec();
332        let bias_data = bias.flatten_all()?.to_vec1::<f32>()?;
333        tensors.insert(format!("{}.bias", name), (bias_shape, bias_data));
334    }
335    Ok(())
336}
337
impl DuelingDQN {
    /// Copy weights from another network.
    ///
    /// NOTE(review): despite the name, this does NOT copy the linear-layer
    /// weights. candle's `Linear` exposes no mutable access to its
    /// parameters, so `copy_linear` below only rebuilds (and discards) the
    /// source tensor and verifies that the shapes match. The only state
    /// actually transferred is `param_logstd`. Callers relying on this for
    /// target-network sync should rebuild the target from a shared `VarMap`
    /// instead — TODO confirm intended usage.
    pub fn copy_weights_from(&mut self, source: &DuelingDQN) -> CandleResult<()> {
        // Helper to copy a linear layer (currently: shape validation only).
        fn copy_linear(dest: &Linear, src: &Linear) -> CandleResult<()> {
            let src_weight = src.weight();
            let dest_weight = dest.weight();

            // Copy weight data. The rebuilt tensor is unused — see the
            // NOTE(review) on the method doc above.
            let weight_data = src_weight.flatten_all()?.to_vec1::<f32>()?;
            let _new_weight = Tensor::from_vec(
                weight_data,
                src_weight.dims(),
                src_weight.device()
            )?;

            // We can't directly modify Linear's internal weights in candle
            // This is a limitation - in practice, you'd recreate the layer
            // For now, we just verify dimensions match
            if dest_weight.dims() != src_weight.dims() {
                return Err(candle_core::Error::DimOutOfRange {
                    shape: dest_weight.shape().clone(),
                    dim: 0,
                    op: "copy_weights"
                });
            }

            Ok(())
        }

        // "Copy" (i.e. shape-check) all layers.
        copy_linear(&self.fc1, &source.fc1)?;
        copy_linear(&self.fc2, &source.fc2)?;
        copy_linear(&self.fc3, &source.fc3)?;
        copy_linear(&self.value_fc1, &source.value_fc1)?;
        copy_linear(&self.value_fc2, &source.value_fc2)?;
        copy_linear(&self.advantage_fc1, &source.advantage_fc1)?;
        copy_linear(&self.advantage_fc2, &source.advantage_fc2)?;
        copy_linear(&self.param_mean, &source.param_mean)?;

        // Copy param_logstd — this is a real copy (the Var is replaceable).
        let logstd_data = source.param_logstd.as_tensor().flatten_all()?.to_vec1::<f32>()?;
        let new_logstd = Tensor::from_vec(
            logstd_data,
            source.param_logstd.as_tensor().dims(),
            &self.device
        )?;
        self.param_logstd = Var::from_tensor(&new_logstd)?;

        // NOTE(review): misleading log — only param_logstd was copied.
        info!("Weights copied from source network");
        Ok(())
    }

    /// Create new Dueling DQN network with proper initialization.
    ///
    /// Fixed architecture: 512 → 256 → 128 encoder (each followed by layer
    /// norm + ReLU in `forward`), 64-wide value and advantage streams, and a
    /// continuous-parameter head. Dropout is fixed at 0.1.
    pub fn new(
        state_dim: usize,
        num_actions: usize,
        num_params: usize,
        vb: VarBuilder,
    ) -> CandleResult<Self> {
        let device = vb.device().clone();

        // Feature encoder - candle's linear already uses Xavier initialization
        let fc1 = linear(state_dim, 512, vb.pp("fc1"))?;
        let ln1 = layer_norm(512, 1e-5, vb.pp("ln1"))?;
        let fc2 = linear(512, 256, vb.pp("fc2"))?;
        let ln2 = layer_norm(256, 1e-5, vb.pp("ln2"))?;
        let fc3 = linear(256, 128, vb.pp("fc3"))?;
        let ln3 = layer_norm(128, 1e-5, vb.pp("ln3"))?;

        // Value stream
        let value_fc1 = linear(128, 64, vb.pp("value_fc1"))?;
        let value_fc2 = linear(64, 1, vb.pp("value_fc2"))?;

        // Advantage stream
        let advantage_fc1 = linear(128, 64, vb.pp("advantage_fc1"))?;
        let advantage_fc2 = linear(64, num_actions, vb.pp("advantage_fc2"))?;

        // Continuous parameter head
        let param_mean = linear(128, num_params, vb.pp("param_mean"))?;

        // Initialize param_logstd to reasonable small values (e^-1 ≈ 0.37 std)
        let param_logstd_init = Tensor::from_vec(
            vec![-1.0f32; num_params],
            &[num_params],
            &device
        )?;
        let param_logstd = Var::from_tensor(&param_logstd_init)?;

        Ok(Self {
            fc1,
            ln1,
            fc2,
            ln2,
            fc3,
            ln3,
            dropout: 0.1,
            value_fc1,
            value_fc2,
            advantage_fc1,
            advantage_fc2,
            param_mean,
            param_logstd,
            device,
            state_dim,
            num_actions,
            num_params,
        })
    }

448    /// Verify model weights are properly initialized
449    pub fn verify_initialization(&self) -> CandleResult<bool> {
450        let fc1_weight = self.fc1.weight().flatten_all()?.to_vec1::<f32>()?;
451
452        let non_zero = fc1_weight.iter().filter(|&&x| x.abs() > 1e-6).count();
453        let zero_percent = 100.0 * (1.0 - non_zero as f64 / fc1_weight.len() as f64);
454
455        if zero_percent > 90.0 {
456            error!("ERROR: Model weights are {:.1}% zeros! Initialization failed!", zero_percent);
457            return Ok(false);
458        }
459
460        info!("Model initialization verified: {:.1}% non-zero weights", 100.0 - zero_percent);
461        Ok(true)
462    }
463
    /// Forward pass through network.
    ///
    /// Returns `(q_values, param_mean, param_std)` where the Q-values follow
    /// the dueling rule `Q(s,a) = V(s) + (A(s,a) - mean_a A(s,a))`, the
    /// parameter means are `tanh`-squashed into [-1, 1], and the std-dev is
    /// state-independent (`exp` of the learned log-std).
    ///
    /// NOTE(review): `mean_keepdim(1)` assumes `state` is batched 2-D
    /// `(batch, state_dim)` — confirm callers never pass a 1-D state.
    pub fn forward(&self, state: &Tensor, training: bool) -> CandleResult<(Tensor, Tensor, Tensor)> {
        // Feature extraction: Linear → LayerNorm → ReLU → (train-only) dropout
        let mut x = self.fc1.forward(state)?;
        x = self.ln1.forward(&x)?;
        x = x.relu()?;
        if training {
            x = candle_nn::ops::dropout(&x, self.dropout)?;
        }

        x = self.fc2.forward(&x)?;
        x = self.ln2.forward(&x)?;
        x = x.relu()?;
        if training {
            x = candle_nn::ops::dropout(&x, self.dropout)?;
        }

        x = self.fc3.forward(&x)?;
        x = self.ln3.forward(&x)?;
        let features = x.relu()?;

        // Value stream
        let mut value = self.value_fc1.forward(&features)?;
        value = value.relu()?;
        let value = self.value_fc2.forward(&value)?;

        // Advantage stream
        let mut advantages = self.advantage_fc1.forward(&features)?;
        advantages = advantages.relu()?;
        let advantages = self.advantage_fc2.forward(&advantages)?;

        // Combine: Q(s,a) = V(s) + (A(s,a) - mean(A(s,a)))
        let advantage_mean = advantages.mean_keepdim(1)?;
        let q_values = value
            .broadcast_add(&advantages)?
            .broadcast_sub(&advantage_mean)?;

        // Continuous parameters
        let param_mean = self.param_mean.forward(&features)?.tanh()?;
        let param_std = self.param_logstd.as_tensor().exp()?;

        Ok((q_values, param_mean, param_std))
    }

508    /// Legacy save method (for backwards compatibility)
509    pub fn save_to_onnx(&self, path: &Path) -> CandleResult<()> {
510        let metadata = ModelMetadata {
511            state_dim: self.state_dim,
512            num_actions: self.num_actions,
513            num_params: self.num_params,
514            architecture: "DuelingDQN".to_string(),
515            algorithm: "DuelingDQN".to_string(),
516            version: "1.0.0".to_string(),
517            training_date: chrono::Utc::now().to_rfc3339(),
518            training_episodes: 0,
519            hyperparameters: HashMap::new(),
520        };
521        self.save_to_onnx_with_metadata(path, metadata)
522    }
523
524    /// Save model to ONNX format
525    pub fn save_to_onnx_with_metadata(&self, path: &Path, metadata: ModelMetadata) -> CandleResult<()> {
526        use std::fs::File;
527        use std::io::Write;
528        let mut file = File::create(path)
529            .map_err(candle_core::Error::Io)?;
530
531        // Write metadata
532        let metadata_json = serde_json::to_string(&metadata)
533            .map_err(|e| candle_core::Error::Msg(e.to_string()))?;
534        let metadata_bytes = metadata_json.as_bytes();
535        let metadata_len = metadata_bytes.len() as u64;
536
537        file.write_all(&metadata_len.to_le_bytes())
538            .map_err(candle_core::Error::Io)?;
539        file.write_all(metadata_bytes)
540            .map_err(candle_core::Error::Io)?;
541
542        let mut file = File::create(path)
543            .map_err(candle_core::Error::Io)?;
544
545        // Write metadata
546        let metadata_json = serde_json::to_string(&metadata)
547            .map_err(|e| candle_core::Error::Msg(e.to_string()))?;
548        let metadata_bytes = metadata_json.as_bytes();
549        let metadata_len = metadata_bytes.len() as u64;
550
551        file.write_all(&metadata_len.to_le_bytes())
552            .map_err(candle_core::Error::Io)?;
553        file.write_all(metadata_bytes)
554            .map_err(candle_core::Error::Io)?;
555
556        // Collect all tensors
557        let mut tensors: HashMap<String, (Vec<usize>, Vec<f32>)> = HashMap::new();
558
559        save_linear("fc1", &self.fc1, &mut tensors)?;
560        save_linear("fc2", &self.fc2, &mut tensors)?;
561        save_linear("fc3", &self.fc3, &mut tensors)?;
562        save_linear("value_fc1", &self.value_fc1, &mut tensors)?;
563        save_linear("value_fc2", &self.value_fc2, &mut tensors)?;
564        save_linear("advantage_fc1", &self.advantage_fc1, &mut tensors)?;
565        save_linear("advantage_fc2", &self.advantage_fc2, &mut tensors)?;
566        save_linear("param_mean", &self.param_mean, &mut tensors)?;
567
568        save_layernorm("ln1", &self.ln1, &mut tensors)?;
569        save_layernorm("ln2", &self.ln2, &mut tensors)?;
570        save_layernorm("ln3", &self.ln3, &mut tensors)?;
571
572        // Save param_logstd
573        let logstd_tensor = self.param_logstd.as_tensor();
574        let logstd_shape = logstd_tensor.dims().to_vec();
575        let logstd_flat = logstd_tensor.flatten_all()?;
576        let logstd_data = logstd_flat.to_vec1::<f32>()?;
577
578        let non_zero_count = logstd_data.iter().filter(|&&x| x.abs() > 1e-10).count();
579        if non_zero_count == 0 {
580            warn!("WARNING: param_logstd contains all zeros!");
581        }
582
583        tensors.insert("param_logstd".to_string(), (logstd_shape, logstd_data));
584
585        let total_params: usize = tensors.values().map(|(_, data)| data.len()).sum();
586        info!("Saving model with {} tensors, {} total parameters", tensors.len(), total_params);
587
588        for (name, (_, data)) in tensors.iter() {
589            let non_zero = data.iter().filter(|&&x| x.abs() > 1e-10).count();
590            let zero_percent = 100.0 * (1.0 - non_zero as f64 / data.len() as f64);
591            if zero_percent > 95.0 {
592                // TODO: ignore if name is 'ln1.bias', 'ln2.bias', or 'ln1.bias'
593                warn!("WARNING: Tensor '{}' is {:.1}% zeros", name, zero_percent);
594            }
595        }
596
597        // Write tensor count
598        let tensor_count = tensors.len() as u64;
599        file.write_all(&tensor_count.to_le_bytes())
600            .map_err(candle_core::Error::Io)?;
601
602        // Write each tensor
603        for (name, (shape, data)) in tensors.iter() {
604            // Name
605            let name_bytes = name.as_bytes();
606            let name_len = name_bytes.len() as u64;
607            file.write_all(&name_len.to_le_bytes())
608                .map_err(candle_core::Error::Io)?;
609            file.write_all(name_bytes)
610                .map_err(candle_core::Error::Io)?;
611
612            // Shape
613            let shape_len = shape.len() as u64;
614            file.write_all(&shape_len.to_le_bytes())
615                .map_err(candle_core::Error::Io)?;
616            for &dim in shape {
617                file.write_all(&(dim as u64).to_le_bytes())
618                    .map_err(candle_core::Error::Io)?;
619            }
620
621            // Data
622            let data_len = data.len() as u64;
623            file.write_all(&data_len.to_le_bytes())
624                .map_err(candle_core::Error::Io)?;
625            for &value in data {
626                file.write_all(&value.to_le_bytes())
627                    .map_err(candle_core::Error::Io)?;
628            }
629        }
630
631        let file_metadata = std::fs::metadata(path)
632            .map_err(candle_core::Error::Io)?;
633        let file_size = file_metadata.len();
634
635        if file_size < 100_000 {
636            return Err(candle_core::Error::Msg(
637                format!("Model file suspiciously small: {} bytes", file_size)
638            ));
639        }
640
641        info!("Model saved successfully: {} bytes", file_size);
642        Ok(())
643    }
644
    /// Load metadata only (delegates to `ModelMetadata::load_metadata`;
    /// does not read any tensor data).
    pub fn load_metadata(path: &Path) -> CandleResult<ModelMetadata> {
        ModelMetadata::load_metadata(path)
    }

650    /// Save model in SafeTensors format
651    pub fn save_to_safetensors(&self, path: &Path) -> CandleResult<()> {
652        let mut tensor_bytes: Vec<(String, Vec<usize>, Vec<u8>)> = Vec::new();
653
654        let mut collect_tensor = |name: &str, tensor: &Tensor| -> CandleResult<()> {
655            let shape = tensor.dims().to_vec();
656            let data = tensor.flatten_all()?.to_vec1::<f32>()?;
657            let bytes: Vec<u8> = data.iter()
658                .flat_map(|&f| f.to_le_bytes())
659                .collect();
660
661            tensor_bytes.push((name.to_string(), shape, bytes));
662            Ok(())
663        };
664
665        collect_tensor("fc1.weight", self.fc1.weight())?;
666        if let Some(bias) = self.fc1.bias() {
667            collect_tensor("fc1.bias", bias)?;
668        }
669
670        collect_tensor("fc2.weight", self.fc2.weight())?;
671        if let Some(bias) = self.fc2.bias() {
672            collect_tensor("fc2.bias", bias)?;
673        }
674
675        collect_tensor("fc3.weight", self.fc3.weight())?;
676        if let Some(bias) = self.fc3.bias() {
677            collect_tensor("fc3.bias", bias)?;
678        }
679
680        collect_tensor("value_fc1.weight", self.value_fc1.weight())?;
681        if let Some(bias) = self.value_fc1.bias() {
682            collect_tensor("value_fc1.bias", bias)?;
683        }
684
685        collect_tensor("value_fc2.weight", self.value_fc2.weight())?;
686        if let Some(bias) = self.value_fc2.bias() {
687            collect_tensor("value_fc2.bias", bias)?;
688        }
689
690        collect_tensor("advantage_fc1.weight", self.advantage_fc1.weight())?;
691        if let Some(bias) = self.advantage_fc1.bias() {
692            collect_tensor("advantage_fc1.bias", bias)?;
693        }
694
695        collect_tensor("advantage_fc2.weight", self.advantage_fc2.weight())?;
696        if let Some(bias) = self.advantage_fc2.bias() {
697            collect_tensor("advantage_fc2.bias", bias)?;
698        }
699
700        collect_tensor("param_mean.weight", self.param_mean.weight())?;
701        if let Some(bias) = self.param_mean.bias() {
702            collect_tensor("param_mean.bias", bias)?;
703        }
704
705        collect_tensor("ln1.weight", self.ln1.weight())?;
706        if let Some(bias) = self.ln1.bias() {
707            collect_tensor("ln1.bias", bias)?;
708        }
709
710        collect_tensor("ln2.weight", self.ln2.weight())?;
711        if let Some(bias) = self.ln2.bias() {
712            collect_tensor("ln2.bias", bias)?;
713        }
714
715        collect_tensor("ln3.weight", self.ln3.weight())?;
716        if let Some(bias) = self.ln3.bias() {
717            collect_tensor("ln3.bias", bias)?;
718        }
719
720        collect_tensor("param_logstd", self.param_logstd.as_tensor())?;
721
722        let mut tensors_data: HashMap<String, TensorView> = HashMap::new();
723
724        for (name, shape, bytes) in &tensor_bytes {
725            tensors_data.insert(
726                name.clone(),
727                TensorView::new(Dtype::F32, shape.clone(), bytes)
728                    .map_err(|e| candle_core::Error::Msg(e.to_string()))?
729            );
730        }
731
732        let serialized = safetensors::serialize(&tensors_data, None)
733            .map_err(|e| candle_core::Error::Msg(e.to_string()))?;
734
735        std::fs::write(path, serialized)
736            .map_err(candle_core::Error::Io)?;
737
738        Ok(())
739    }
740
    /// Load model from SafeTensors format.
    ///
    /// Builds a freshly initialized model first (so the `VarMap` contains
    /// every expected key), then overwrites each variable with the stored
    /// tensor. `param_logstd` is stored as a bare tensor and assigned
    /// directly rather than through the varmap.
    ///
    /// NOTE(review): assumes every stored tensor is little-endian f32
    /// (`chunks_exact(4)` + `from_le_bytes`); a file holding another dtype
    /// would be silently misinterpreted — TODO confirm the save path always
    /// writes `Dtype::F32`.
    pub fn load_from_safetensors(
        path: &Path,
        state_dim: usize,
        num_actions: usize,
        num_params: usize,
        device: &Device,
    ) -> CandleResult<Self> {
        let data = std::fs::read(path)
            .map_err(candle_core::Error::Io)?;

        let safetensors = SafeTensors::deserialize(&data)
            .map_err(|e| candle_core::Error::Msg(e.to_string()))?;

        // Create model first to populate varmap with correct keys, then overwrite with loaded values
        let mut varmap = candle_nn::VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, DType::F32, device);
        let mut model = Self::new(state_dim, num_actions, num_params, vb)?;

        for (name, tensor_view) in safetensors.tensors() {
            let shape: Vec<usize> = tensor_view.shape().to_vec();
            let data = tensor_view.data();
            // Reinterpret the raw bytes as little-endian f32 values.
            let float_data: Vec<f32> = data
                .chunks_exact(4)
                .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
                .collect();
            let tensor = Tensor::from_vec(float_data, shape, device)?;
            if name == "param_logstd" {
                // Bare Var, not tracked by the varmap.
                model.param_logstd = Var::from_tensor(&tensor)?;
            } else {
                varmap.set_one(&name, &tensor)?;
            }
        }

        Ok(model)
    }

778    /// Load model from ONNX format
779    pub fn load_from_onnx(
780        path: &Path,
781        state_dim: usize,
782        num_actions: usize,
783        num_params: usize,
784        device: &Device,
785    ) -> CandleResult<Self> {
786        use std::fs::File;
787        use std::io::Read;
788
789        let mut file = File::open(path)
790            .map_err(candle_core::Error::Io)?;
791
792        // Read metadata
793        let mut metadata_len_bytes = [0u8; 8];
794        file.read_exact(&mut metadata_len_bytes)
795            .map_err(candle_core::Error::Io)?;
796        let metadata_len = u64::from_le_bytes(metadata_len_bytes) as usize;
797        if metadata_len > 10 * 1024 * 1024 {
798            return Err(candle_core::Error::Msg(format!("Invalid model file: metadata length {} is too large", metadata_len)));
799        }
800
801        let mut metadata_bytes = vec![0u8; metadata_len];
802        file.read_exact(&mut metadata_bytes)
803            .map_err(candle_core::Error::Io)?;
804
805        let metadata_json = String::from_utf8(metadata_bytes)
806            .map_err(|e| candle_core::Error::Msg(e.to_string()))?;
807        let metadata: ModelMetadata = serde_json::from_str(&metadata_json)
808            .map_err(|e| candle_core::Error::Msg(e.to_string()))?;
809
810        // Verify dimensions
811        if metadata.state_dim != state_dim
812            || metadata.num_actions != num_actions
813            || metadata.num_params != num_params
814        {
815            return Err(candle_core::Error::Msg(
816                format!(
817                    "Model dimension mismatch: expected ({}, {}, {}), got ({}, {}, {})",
818                    state_dim, num_actions, num_params,
819                    metadata.state_dim, metadata.num_actions, metadata.num_params
820                )
821            ));
822        }
823
824        // Read tensor count
825        let mut tensor_count_bytes = [0u8; 8];
826        file.read_exact(&mut tensor_count_bytes)
827            .map_err(candle_core::Error::Io)?;
828        let tensor_count = u64::from_le_bytes(tensor_count_bytes) as usize;
829
830        // Read all tensors
831        let mut tensors: HashMap<String, (Vec<usize>, Vec<f32>)> = HashMap::new();
832
833        for _ in 0..tensor_count {
834            // Read name
835            let mut name_len_bytes = [0u8; 8];
836            file.read_exact(&mut name_len_bytes)
837                .map_err(candle_core::Error::Io)?;
838            let name_len = u64::from_le_bytes(name_len_bytes) as usize;
839
840            let mut name_bytes = vec![0u8; name_len];
841            file.read_exact(&mut name_bytes)
842                .map_err(candle_core::Error::Io)?;
843            let name = String::from_utf8(name_bytes)
844                .map_err(|e| candle_core::Error::Msg(e.to_string()))?;
845
846            // Read shape
847            let mut shape_len_bytes = [0u8; 8];
848            file.read_exact(&mut shape_len_bytes)
849                .map_err(candle_core::Error::Io)?;
850            let shape_len = u64::from_le_bytes(shape_len_bytes) as usize;
851
852            let mut shape = Vec::with_capacity(shape_len);
853            for _ in 0..shape_len {
854                let mut dim_bytes = [0u8; 8];
855                file.read_exact(&mut dim_bytes)
856                    .map_err(candle_core::Error::Io)?;
857                shape.push(u64::from_le_bytes(dim_bytes) as usize);
858            }
859
860            // Read data
861            let mut data_len_bytes = [0u8; 8];
862            file.read_exact(&mut data_len_bytes)
863                .map_err(candle_core::Error::Io)?;
864            let data_len = u64::from_le_bytes(data_len_bytes) as usize;
865
866            let mut data = Vec::with_capacity(data_len);
867            for _ in 0..data_len {
868                let mut value_bytes = [0u8; 4];
869                file.read_exact(&mut value_bytes)
870                    .map_err(candle_core::Error::Io)?;
871                data.push(f32::from_le_bytes(value_bytes));
872            }
873
874            tensors.insert(name, (shape, data));
875        }
876
877        // Create model first to populate varmap with correct keys, then overwrite with loaded values
878        let mut varmap = candle_nn::VarMap::new();
879        let vb = VarBuilder::from_varmap(&varmap, DType::F32, device);
880        let mut model = Self::new(state_dim, num_actions, num_params, vb)?;
881
882        for (name, (shape, data)) in tensors.iter() {
883            let tensor = Tensor::from_vec(data.clone(), shape.as_slice(), device)?;
884            if name == "param_logstd" {
885                model.param_logstd = Var::from_tensor(&tensor)?;
886            } else {
887                varmap.set_one(name, &tensor)?;
888            }
889        }
890
891        Ok(model)
892    }
893
    /// Load with specific device (thin alias for `Self::load_from_onnx`).
    pub fn load_with_device(
        path: &Path,
        state_dim: usize,
        num_actions: usize,
        num_params: usize,
        device: &Device,
    ) -> CandleResult<Self> {
        Self::load_from_onnx(path, state_dim, num_actions, num_params, device)
    }
}

#[cfg(test)]
mod tests {
    // `use super::*` already brings Device, DType, Tensor, VarBuilder etc.
    // into scope; the previous explicit `use candle_core::Device;` was
    // redundant and `use tempfile::TempDir;` was unused (compiler warning).
    use super::*;

    /// The network builds with the requested dimensions recorded.
    #[test]
    fn test_model_creation() {
        let device = Device::Cpu;
        let varmap = candle_nn::VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device);

        let model = DuelingDQN::new(300, 16, 6, vb).unwrap();
        assert_eq!(model.state_dim, 300);
        assert_eq!(model.num_actions, 16);
        assert_eq!(model.num_params, 6);
    }

    /// A single-state forward pass yields correctly shaped outputs:
    /// per-action Q-values, per-parameter means, and a shared std vector.
    #[test]
    fn test_forward_pass() {
        let device = Device::Cpu;
        let varmap = candle_nn::VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device);
        let model = DuelingDQN::new(300, 16, 6, vb).unwrap();

        let state = Tensor::zeros(&[1, 300], DType::F32, &device).unwrap();
        let (q_values, param_mean, param_std) = model.forward(&state, false).unwrap();

        assert_eq!(q_values.dims(), &[1, 16]);
        assert_eq!(param_mean.dims(), &[1, 6]);
        assert_eq!(param_std.dims(), &[6]);
    }
}