1use candle_core::{Device, Tensor, DType, Result as CandleResult, Var};
6use candle_nn::{Linear, Module, VarBuilder, linear, layer_norm, LayerNorm};
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::path::Path;
10use safetensors::SafeTensors;
11use safetensors::tensor::{Dtype, TensorView};
12use tracing::{error, info, warn};
13use crate::agents::AlgorithmType;
14use chrono;
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct NetworkConfig {
19 pub state_dim: usize,
20 pub num_actions: usize,
21 pub num_params: usize,
22
23 pub hidden_layers: Vec<usize>, pub use_layer_norm: bool,
26 pub dropout: f32,
27
28 pub value_hidden: usize, pub advantage_hidden: usize, }
32
33impl Default for NetworkConfig {
34 fn default() -> Self {
35 Self {
36 state_dim: 300,
37 num_actions: 16,
38 num_params: 6,
39 hidden_layers: vec![512, 256, 128],
40 use_layer_norm: true,
41 dropout: 0.1,
42 value_hidden: 64,
43 advantage_hidden: 64,
44 }
45 }
46}
47
48#[derive(Debug, Serialize, Deserialize, Clone)]
50pub struct ModelMetadata {
51 pub state_dim: usize,
52 pub num_actions: usize,
53 pub num_params: usize,
54 pub architecture: String,
55 pub algorithm: String, pub version: String,
57 pub training_date: String, pub training_episodes: usize, pub hyperparameters: HashMap<String, f64>, }
61
62impl ModelMetadata {
63 pub fn new(
65 state_dim: usize,
66 num_actions: usize,
67 num_params: usize,
68 algorithm: AlgorithmType,
69 training_episodes: usize,
70 hyperparameters: HashMap<String, f64>,
71 ) -> Self {
72 Self {
73 state_dim,
74 num_actions,
75 num_params,
76 architecture: algorithm.to_string(),
77 algorithm: algorithm.to_string(),
78 version: "1.0.0".to_string(),
79 training_date: chrono::Utc::now().to_rfc3339(),
80 training_episodes,
81 hyperparameters,
82 }
83 }
84
85 pub fn load_metadata(path: &Path) -> candle_core::error::Result<ModelMetadata> {
87 use std::fs::File;
88 use std::io::Read;
89
90 let mut file = File::open(path)
91 .map_err(candle_core::Error::Io)?;
92
93 let mut metadata_len_bytes = [0u8; 8];
94 file.read_exact(&mut metadata_len_bytes)
95 .map_err(candle_core::Error::Io)?;
96 let metadata_len = u64::from_le_bytes(metadata_len_bytes) as usize;
97
98 let mut metadata_bytes = vec![0u8; metadata_len];
99 file.read_exact(&mut metadata_bytes)
100 .map_err(candle_core::Error::Io)?;
101
102 let metadata_json = String::from_utf8(metadata_bytes)
103 .map_err(|e| candle_core::Error::Msg(e.to_string()))?;
104
105 let metadata: ModelMetadata = serde_json::from_str(&metadata_json)
106 .map_err(|e| candle_core::Error::Msg(e.to_string()))?;
107
108 Ok(metadata)
109 }
110
    /// Logs a human-readable, boxed summary of the metadata (and any
    /// hyperparameters) via `tracing::info!`, one row per line.
    pub fn display(&self) {
        info!("╔════════════════════════════════════════════════════════════╗");
        info!("║ MODEL METADATA ║");
        info!("╠════════════════════════════════════════════════════════════╣");
        info!("║ Algorithm: {:<47} ║", self.algorithm);
        info!("║ Architecture: {:<44} ║", self.architecture);
        info!("║ Version: {:<49} ║", self.version);
        info!("║ Training Date: {:<43} ║", self.training_date);
        info!("║ Training Episodes: {:<39} ║", self.training_episodes);
        info!("║ State Dim: {:<47} ║", self.state_dim);
        info!("║ Num Actions: {:<45} ║", self.num_actions);
        info!("║ Num Params: {:<46} ║", self.num_params);
        // Hyperparameter section is printed only when there is something to show.
        if !self.hyperparameters.is_empty() {
            info!("╠════════════════════════════════════════════════════════════╣");
            info!("║ HYPERPARAMETERS ║");
            info!("╠════════════════════════════════════════════════════════════╣");
            // NOTE(review): HashMap iteration order is unspecified, so rows
            // appear in arbitrary order between runs.
            for (key, value) in &self.hyperparameters {
                info!("║ {:<30} {:>27.6} ║", key, value);
            }
        }
        info!("╚════════════════════════════════════════════════════════════╝");
    }
134}
135
/// Configurable dueling Q-network: a shared feature trunk followed by
/// separate value and advantage streams, plus a head for continuous
/// action parameters. Layer widths come from `NetworkConfig`.
#[derive(Debug)]
#[allow(dead_code)]
pub struct DuelingNetwork {
    // Shared trunk: one Linear per entry in `config.hidden_layers`.
    feature_layers: Vec<Linear>,
    // Parallel to `feature_layers`; `None` entries when layer norm is disabled.
    layer_norms: Vec<Option<LayerNorm>>,
    // Dropout probability applied in `forward` when `training` is true.
    dropout: f32,

    // Value stream: two Linear layers producing a scalar V(s).
    value_layers: Vec<Linear>,

    // Advantage stream: two Linear layers producing one A(s,a) per action.
    advantage_layers: Vec<Linear>,

    // Head producing tanh-squashed continuous parameter means.
    param_mean: Linear,
    // Learnable, state-independent log standard deviations for the parameters.
    param_logstd: Var,

    device: Device,
    config: NetworkConfig,
}
159
160
161impl DuelingNetwork {
162 pub fn new(config: NetworkConfig, vb: VarBuilder) -> CandleResult<Self> {
164 let device = vb.device().clone();
165
166 let mut feature_layers = Vec::new();
168 let mut layer_norms = Vec::new();
169
170 let mut input_dim = config.state_dim;
171 for (i, &hidden_size) in config.hidden_layers.iter().enumerate() {
172 let layer = linear(input_dim, hidden_size, vb.pp(format!("fc{}", i + 1)))?;
173 feature_layers.push(layer);
174
175 if config.use_layer_norm {
176 let ln = layer_norm(hidden_size, 1e-5, vb.pp(format!("ln{}", i + 1)))?;
177 layer_norms.push(Some(ln));
178 } else {
179 layer_norms.push(None);
180 }
181
182 input_dim = hidden_size;
183 }
184
185 let final_feature_size = *config.hidden_layers.last().unwrap_or(&128);
186
187 let value_layers = vec![
189 linear(final_feature_size, config.value_hidden, vb.pp("value_fc1"))?,
190 linear(config.value_hidden, 1, vb.pp("value_fc2"))?,
191 ];
192
193 let advantage_layers = vec![
195 linear(final_feature_size, config.advantage_hidden, vb.pp("advantage_fc1"))?,
196 linear(config.advantage_hidden, config.num_actions, vb.pp("advantage_fc2"))?,
197 ];
198
199 let param_mean = linear(final_feature_size, config.num_params, vb.pp("param_mean"))?;
201 let param_logstd_init = Tensor::from_vec(
202 vec![-1.0f32; config.num_params],
203 &[config.num_params],
204 &device
205 )?;
206 let param_logstd = Var::from_tensor(¶m_logstd_init)?;
207
208 Ok(Self {
209 feature_layers,
210 layer_norms,
211 dropout: config.dropout,
212 value_layers,
213 advantage_layers,
214 param_mean,
215 param_logstd,
216 device,
217 config,
218 })
219 }
220
221 pub fn forward(&self, state: &Tensor, training: bool) -> CandleResult<(Tensor, Tensor, Tensor)> {
223 let mut x = state.clone();
225
226 for (i, layer) in self.feature_layers.iter().enumerate() {
227 x = layer.forward(&x)?;
228
229 if let Some(Some(ln)) = self.layer_norms.get(i) {
230 x = ln.forward(&x)?;
231 }
232
233 x = x.relu()?;
234
235 if training && self.dropout > 0.0 {
236 x = candle_nn::ops::dropout(&x, self.dropout)?;
237 }
238 }
239
240 let features = x;
241
242 let mut value = self.value_layers[0].forward(&features)?;
244 value = value.relu()?;
245 let value = self.value_layers[1].forward(&value)?;
246
247 let mut advantages = self.advantage_layers[0].forward(&features)?;
249 advantages = advantages.relu()?;
250 let advantages = self.advantage_layers[1].forward(&advantages)?;
251
252 let advantage_mean = advantages.mean_keepdim(1)?;
254 let q_values = value
255 .broadcast_add(&advantages)?
256 .broadcast_sub(&advantage_mean)?;
257
258 let param_mean = self.param_mean.forward(&features)?.tanh()?;
260 let param_std = self.param_logstd.as_tensor().exp()?;
261
262 Ok((q_values, param_mean, param_std))
263 }
264
265 pub fn get_config(&self) -> &NetworkConfig {
267 &self.config
268 }
269}
270
/// Fixed-architecture dueling DQN (512→256→128 trunk with layer norm),
/// with dueling value/advantage heads and a continuous-parameter head.
/// Layer names match the keys used by the save/load routines below.
#[derive(Debug)]
pub struct DuelingDQN {
    // Shared feature trunk: three Linear + LayerNorm stages.
    fc1: Linear,
    ln1: LayerNorm,
    fc2: Linear,
    ln2: LayerNorm,
    fc3: Linear,
    ln3: LayerNorm,
    // Dropout probability used in `forward` during training (set to 0.1 in `new`).
    dropout: f32,

    // Value stream: 128 -> 64 -> scalar V(s).
    value_fc1: Linear,
    value_fc2: Linear,

    // Advantage stream: 128 -> 64 -> one A(s,a) per action.
    advantage_fc1: Linear,
    advantage_fc2: Linear,

    // Continuous-parameter head: tanh-squashed means.
    param_mean: Linear,
    // Learnable, state-independent log standard deviations (kept outside
    // the VarMap; serialized under the bare key "param_logstd").
    param_logstd: Var,

    device: Device,
    // Dimensions the network was constructed with; checked on load.
    state_dim: usize,
    num_actions: usize,
    num_params: usize,
}
300
301fn save_linear(
303 name: &str,
304 linear: &Linear,
305 tensors: &mut HashMap<String, (Vec<usize>, Vec<f32>)>
306) -> CandleResult<()> {
307 let weight = linear.weight();
308 let weight_shape = weight.dims().to_vec();
309 let weight_data = weight.flatten_all()?.to_vec1::<f32>()?;
310 tensors.insert(format!("{}.weight", name), (weight_shape, weight_data));
311
312 if let Some(bias) = linear.bias() {
313 let bias_shape = bias.dims().to_vec();
314 let bias_data = bias.flatten_all()?.to_vec1::<f32>()?;
315 tensors.insert(format!("{}.bias", name), (bias_shape, bias_data));
316 }
317 Ok(())
318}
319
320fn save_layernorm(
321 name: &str,
322 ln: &LayerNorm,
323 tensors: &mut HashMap<String, (Vec<usize>, Vec<f32>)>
324) -> CandleResult<()> {
325 let weight = ln.weight();
326 let weight_shape = weight.dims().to_vec();
327 let weight_data = weight.flatten_all()?.to_vec1::<f32>()?;
328 tensors.insert(format!("{}.weight", name), (weight_shape, weight_data));
329
330 if let Some(bias) = ln.bias() {
331 let bias_shape = bias.dims().to_vec();
332 let bias_data = bias.flatten_all()?.to_vec1::<f32>()?;
333 tensors.insert(format!("{}.bias", name), (bias_shape, bias_data));
334 }
335 Ok(())
336}
337
338impl DuelingDQN {
339 pub fn copy_weights_from(&mut self, source: &DuelingDQN) -> CandleResult<()> {
341 fn copy_linear(dest: &Linear, src: &Linear) -> CandleResult<()> {
343 let src_weight = src.weight();
344 let dest_weight = dest.weight();
345
346 let weight_data = src_weight.flatten_all()?.to_vec1::<f32>()?;
348 let _new_weight = Tensor::from_vec(
349 weight_data,
350 src_weight.dims(),
351 src_weight.device()
352 )?;
353
354 if dest_weight.dims() != src_weight.dims() {
358 return Err(candle_core::Error::DimOutOfRange {
359 shape: dest_weight.shape().clone(),
360 dim: 0,
361 op: "copy_weights"
362 });
363 }
364
365 Ok(())
366 }
367
368 copy_linear(&self.fc1, &source.fc1)?;
370 copy_linear(&self.fc2, &source.fc2)?;
371 copy_linear(&self.fc3, &source.fc3)?;
372 copy_linear(&self.value_fc1, &source.value_fc1)?;
373 copy_linear(&self.value_fc2, &source.value_fc2)?;
374 copy_linear(&self.advantage_fc1, &source.advantage_fc1)?;
375 copy_linear(&self.advantage_fc2, &source.advantage_fc2)?;
376 copy_linear(&self.param_mean, &source.param_mean)?;
377
378 let logstd_data = source.param_logstd.as_tensor().flatten_all()?.to_vec1::<f32>()?;
380 let new_logstd = Tensor::from_vec(
381 logstd_data,
382 source.param_logstd.as_tensor().dims(),
383 &self.device
384 )?;
385 self.param_logstd = Var::from_tensor(&new_logstd)?;
386
387 info!("Weights copied from source network");
388 Ok(())
389 }
390
391 pub fn new(
393 state_dim: usize,
394 num_actions: usize,
395 num_params: usize,
396 vb: VarBuilder,
397 ) -> CandleResult<Self> {
398 let device = vb.device().clone();
399
400 let fc1 = linear(state_dim, 512, vb.pp("fc1"))?;
402 let ln1 = layer_norm(512, 1e-5, vb.pp("ln1"))?;
403 let fc2 = linear(512, 256, vb.pp("fc2"))?;
404 let ln2 = layer_norm(256, 1e-5, vb.pp("ln2"))?;
405 let fc3 = linear(256, 128, vb.pp("fc3"))?;
406 let ln3 = layer_norm(128, 1e-5, vb.pp("ln3"))?;
407
408 let value_fc1 = linear(128, 64, vb.pp("value_fc1"))?;
410 let value_fc2 = linear(64, 1, vb.pp("value_fc2"))?;
411
412 let advantage_fc1 = linear(128, 64, vb.pp("advantage_fc1"))?;
414 let advantage_fc2 = linear(64, num_actions, vb.pp("advantage_fc2"))?;
415
416 let param_mean = linear(128, num_params, vb.pp("param_mean"))?;
418
419 let param_logstd_init = Tensor::from_vec(
421 vec![-1.0f32; num_params],
422 &[num_params],
423 &device
424 )?;
425 let param_logstd = Var::from_tensor(¶m_logstd_init)?;
426
427 Ok(Self {
428 fc1,
429 ln1,
430 fc2,
431 ln2,
432 fc3,
433 ln3,
434 dropout: 0.1,
435 value_fc1,
436 value_fc2,
437 advantage_fc1,
438 advantage_fc2,
439 param_mean,
440 param_logstd,
441 device,
442 state_dim,
443 num_actions,
444 num_params,
445 })
446 }
447
448 pub fn verify_initialization(&self) -> CandleResult<bool> {
450 let fc1_weight = self.fc1.weight().flatten_all()?.to_vec1::<f32>()?;
451
452 let non_zero = fc1_weight.iter().filter(|&&x| x.abs() > 1e-6).count();
453 let zero_percent = 100.0 * (1.0 - non_zero as f64 / fc1_weight.len() as f64);
454
455 if zero_percent > 90.0 {
456 error!("ERROR: Model weights are {:.1}% zeros! Initialization failed!", zero_percent);
457 return Ok(false);
458 }
459
460 info!("Model initialization verified: {:.1}% non-zero weights", 100.0 - zero_percent);
461 Ok(true)
462 }
463
464 pub fn forward(&self, state: &Tensor, training: bool) -> CandleResult<(Tensor, Tensor, Tensor)> {
466 let mut x = self.fc1.forward(state)?;
468 x = self.ln1.forward(&x)?;
469 x = x.relu()?;
470 if training {
471 x = candle_nn::ops::dropout(&x, self.dropout)?;
472 }
473
474 x = self.fc2.forward(&x)?;
475 x = self.ln2.forward(&x)?;
476 x = x.relu()?;
477 if training {
478 x = candle_nn::ops::dropout(&x, self.dropout)?;
479 }
480
481 x = self.fc3.forward(&x)?;
482 x = self.ln3.forward(&x)?;
483 let features = x.relu()?;
484
485 let mut value = self.value_fc1.forward(&features)?;
487 value = value.relu()?;
488 let value = self.value_fc2.forward(&value)?;
489
490 let mut advantages = self.advantage_fc1.forward(&features)?;
492 advantages = advantages.relu()?;
493 let advantages = self.advantage_fc2.forward(&advantages)?;
494
495 let advantage_mean = advantages.mean_keepdim(1)?;
497 let q_values = value
498 .broadcast_add(&advantages)?
499 .broadcast_sub(&advantage_mean)?;
500
501 let param_mean = self.param_mean.forward(&features)?.tanh()?;
503 let param_std = self.param_logstd.as_tensor().exp()?;
504
505 Ok((q_values, param_mean, param_std))
506 }
507
508 pub fn save_to_onnx(&self, path: &Path) -> CandleResult<()> {
510 let metadata = ModelMetadata {
511 state_dim: self.state_dim,
512 num_actions: self.num_actions,
513 num_params: self.num_params,
514 architecture: "DuelingDQN".to_string(),
515 algorithm: "DuelingDQN".to_string(),
516 version: "1.0.0".to_string(),
517 training_date: chrono::Utc::now().to_rfc3339(),
518 training_episodes: 0,
519 hyperparameters: HashMap::new(),
520 };
521 self.save_to_onnx_with_metadata(path, metadata)
522 }
523
524 pub fn save_to_onnx_with_metadata(&self, path: &Path, metadata: ModelMetadata) -> CandleResult<()> {
526 use std::fs::File;
527 use std::io::Write;
528 let mut file = File::create(path)
529 .map_err(candle_core::Error::Io)?;
530
531 let metadata_json = serde_json::to_string(&metadata)
533 .map_err(|e| candle_core::Error::Msg(e.to_string()))?;
534 let metadata_bytes = metadata_json.as_bytes();
535 let metadata_len = metadata_bytes.len() as u64;
536
537 file.write_all(&metadata_len.to_le_bytes())
538 .map_err(candle_core::Error::Io)?;
539 file.write_all(metadata_bytes)
540 .map_err(candle_core::Error::Io)?;
541
542 let mut file = File::create(path)
543 .map_err(candle_core::Error::Io)?;
544
545 let metadata_json = serde_json::to_string(&metadata)
547 .map_err(|e| candle_core::Error::Msg(e.to_string()))?;
548 let metadata_bytes = metadata_json.as_bytes();
549 let metadata_len = metadata_bytes.len() as u64;
550
551 file.write_all(&metadata_len.to_le_bytes())
552 .map_err(candle_core::Error::Io)?;
553 file.write_all(metadata_bytes)
554 .map_err(candle_core::Error::Io)?;
555
556 let mut tensors: HashMap<String, (Vec<usize>, Vec<f32>)> = HashMap::new();
558
559 save_linear("fc1", &self.fc1, &mut tensors)?;
560 save_linear("fc2", &self.fc2, &mut tensors)?;
561 save_linear("fc3", &self.fc3, &mut tensors)?;
562 save_linear("value_fc1", &self.value_fc1, &mut tensors)?;
563 save_linear("value_fc2", &self.value_fc2, &mut tensors)?;
564 save_linear("advantage_fc1", &self.advantage_fc1, &mut tensors)?;
565 save_linear("advantage_fc2", &self.advantage_fc2, &mut tensors)?;
566 save_linear("param_mean", &self.param_mean, &mut tensors)?;
567
568 save_layernorm("ln1", &self.ln1, &mut tensors)?;
569 save_layernorm("ln2", &self.ln2, &mut tensors)?;
570 save_layernorm("ln3", &self.ln3, &mut tensors)?;
571
572 let logstd_tensor = self.param_logstd.as_tensor();
574 let logstd_shape = logstd_tensor.dims().to_vec();
575 let logstd_flat = logstd_tensor.flatten_all()?;
576 let logstd_data = logstd_flat.to_vec1::<f32>()?;
577
578 let non_zero_count = logstd_data.iter().filter(|&&x| x.abs() > 1e-10).count();
579 if non_zero_count == 0 {
580 warn!("WARNING: param_logstd contains all zeros!");
581 }
582
583 tensors.insert("param_logstd".to_string(), (logstd_shape, logstd_data));
584
585 let total_params: usize = tensors.values().map(|(_, data)| data.len()).sum();
586 info!("Saving model with {} tensors, {} total parameters", tensors.len(), total_params);
587
588 for (name, (_, data)) in tensors.iter() {
589 let non_zero = data.iter().filter(|&&x| x.abs() > 1e-10).count();
590 let zero_percent = 100.0 * (1.0 - non_zero as f64 / data.len() as f64);
591 if zero_percent > 95.0 {
592 warn!("WARNING: Tensor '{}' is {:.1}% zeros", name, zero_percent);
594 }
595 }
596
597 let tensor_count = tensors.len() as u64;
599 file.write_all(&tensor_count.to_le_bytes())
600 .map_err(candle_core::Error::Io)?;
601
602 for (name, (shape, data)) in tensors.iter() {
604 let name_bytes = name.as_bytes();
606 let name_len = name_bytes.len() as u64;
607 file.write_all(&name_len.to_le_bytes())
608 .map_err(candle_core::Error::Io)?;
609 file.write_all(name_bytes)
610 .map_err(candle_core::Error::Io)?;
611
612 let shape_len = shape.len() as u64;
614 file.write_all(&shape_len.to_le_bytes())
615 .map_err(candle_core::Error::Io)?;
616 for &dim in shape {
617 file.write_all(&(dim as u64).to_le_bytes())
618 .map_err(candle_core::Error::Io)?;
619 }
620
621 let data_len = data.len() as u64;
623 file.write_all(&data_len.to_le_bytes())
624 .map_err(candle_core::Error::Io)?;
625 for &value in data {
626 file.write_all(&value.to_le_bytes())
627 .map_err(candle_core::Error::Io)?;
628 }
629 }
630
631 let file_metadata = std::fs::metadata(path)
632 .map_err(candle_core::Error::Io)?;
633 let file_size = file_metadata.len();
634
635 if file_size < 100_000 {
636 return Err(candle_core::Error::Msg(
637 format!("Model file suspiciously small: {} bytes", file_size)
638 ));
639 }
640
641 info!("Model saved successfully: {} bytes", file_size);
642 Ok(())
643 }
644
    /// Reads only the metadata header from a model file saved by
    /// `save_to_onnx_with_metadata`, without loading any tensors.
    /// Thin delegate to `ModelMetadata::load_metadata`.
    pub fn load_metadata(path: &Path) -> CandleResult<ModelMetadata> {
        ModelMetadata::load_metadata(path)
    }
649
650 pub fn save_to_safetensors(&self, path: &Path) -> CandleResult<()> {
652 let mut tensor_bytes: Vec<(String, Vec<usize>, Vec<u8>)> = Vec::new();
653
654 let mut collect_tensor = |name: &str, tensor: &Tensor| -> CandleResult<()> {
655 let shape = tensor.dims().to_vec();
656 let data = tensor.flatten_all()?.to_vec1::<f32>()?;
657 let bytes: Vec<u8> = data.iter()
658 .flat_map(|&f| f.to_le_bytes())
659 .collect();
660
661 tensor_bytes.push((name.to_string(), shape, bytes));
662 Ok(())
663 };
664
665 collect_tensor("fc1.weight", self.fc1.weight())?;
666 if let Some(bias) = self.fc1.bias() {
667 collect_tensor("fc1.bias", bias)?;
668 }
669
670 collect_tensor("fc2.weight", self.fc2.weight())?;
671 if let Some(bias) = self.fc2.bias() {
672 collect_tensor("fc2.bias", bias)?;
673 }
674
675 collect_tensor("fc3.weight", self.fc3.weight())?;
676 if let Some(bias) = self.fc3.bias() {
677 collect_tensor("fc3.bias", bias)?;
678 }
679
680 collect_tensor("value_fc1.weight", self.value_fc1.weight())?;
681 if let Some(bias) = self.value_fc1.bias() {
682 collect_tensor("value_fc1.bias", bias)?;
683 }
684
685 collect_tensor("value_fc2.weight", self.value_fc2.weight())?;
686 if let Some(bias) = self.value_fc2.bias() {
687 collect_tensor("value_fc2.bias", bias)?;
688 }
689
690 collect_tensor("advantage_fc1.weight", self.advantage_fc1.weight())?;
691 if let Some(bias) = self.advantage_fc1.bias() {
692 collect_tensor("advantage_fc1.bias", bias)?;
693 }
694
695 collect_tensor("advantage_fc2.weight", self.advantage_fc2.weight())?;
696 if let Some(bias) = self.advantage_fc2.bias() {
697 collect_tensor("advantage_fc2.bias", bias)?;
698 }
699
700 collect_tensor("param_mean.weight", self.param_mean.weight())?;
701 if let Some(bias) = self.param_mean.bias() {
702 collect_tensor("param_mean.bias", bias)?;
703 }
704
705 collect_tensor("ln1.weight", self.ln1.weight())?;
706 if let Some(bias) = self.ln1.bias() {
707 collect_tensor("ln1.bias", bias)?;
708 }
709
710 collect_tensor("ln2.weight", self.ln2.weight())?;
711 if let Some(bias) = self.ln2.bias() {
712 collect_tensor("ln2.bias", bias)?;
713 }
714
715 collect_tensor("ln3.weight", self.ln3.weight())?;
716 if let Some(bias) = self.ln3.bias() {
717 collect_tensor("ln3.bias", bias)?;
718 }
719
720 collect_tensor("param_logstd", self.param_logstd.as_tensor())?;
721
722 let mut tensors_data: HashMap<String, TensorView> = HashMap::new();
723
724 for (name, shape, bytes) in &tensor_bytes {
725 tensors_data.insert(
726 name.clone(),
727 TensorView::new(Dtype::F32, shape.clone(), bytes)
728 .map_err(|e| candle_core::Error::Msg(e.to_string()))?
729 );
730 }
731
732 let serialized = safetensors::serialize(&tensors_data, None)
733 .map_err(|e| candle_core::Error::Msg(e.to_string()))?;
734
735 std::fs::write(path, serialized)
736 .map_err(candle_core::Error::Io)?;
737
738 Ok(())
739 }
740
741 pub fn load_from_safetensors(
743 path: &Path,
744 state_dim: usize,
745 num_actions: usize,
746 num_params: usize,
747 device: &Device,
748 ) -> CandleResult<Self> {
749 let data = std::fs::read(path)
750 .map_err(candle_core::Error::Io)?;
751
752 let safetensors = SafeTensors::deserialize(&data)
753 .map_err(|e| candle_core::Error::Msg(e.to_string()))?;
754
755 let mut varmap = candle_nn::VarMap::new();
757 let vb = VarBuilder::from_varmap(&varmap, DType::F32, device);
758 let mut model = Self::new(state_dim, num_actions, num_params, vb)?;
759
760 for (name, tensor_view) in safetensors.tensors() {
761 let shape: Vec<usize> = tensor_view.shape().to_vec();
762 let data = tensor_view.data();
763 let float_data: Vec<f32> = data
764 .chunks_exact(4)
765 .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
766 .collect();
767 let tensor = Tensor::from_vec(float_data, shape, device)?;
768 if name == "param_logstd" {
769 model.param_logstd = Var::from_tensor(&tensor)?;
770 } else {
771 varmap.set_one(&name, &tensor)?;
772 }
773 }
774
775 Ok(model)
776 }
777
778 pub fn load_from_onnx(
780 path: &Path,
781 state_dim: usize,
782 num_actions: usize,
783 num_params: usize,
784 device: &Device,
785 ) -> CandleResult<Self> {
786 use std::fs::File;
787 use std::io::Read;
788
789 let mut file = File::open(path)
790 .map_err(candle_core::Error::Io)?;
791
792 let mut metadata_len_bytes = [0u8; 8];
794 file.read_exact(&mut metadata_len_bytes)
795 .map_err(candle_core::Error::Io)?;
796 let metadata_len = u64::from_le_bytes(metadata_len_bytes) as usize;
797 if metadata_len > 10 * 1024 * 1024 {
798 return Err(candle_core::Error::Msg(format!("Invalid model file: metadata length {} is too large", metadata_len)));
799 }
800
801 let mut metadata_bytes = vec![0u8; metadata_len];
802 file.read_exact(&mut metadata_bytes)
803 .map_err(candle_core::Error::Io)?;
804
805 let metadata_json = String::from_utf8(metadata_bytes)
806 .map_err(|e| candle_core::Error::Msg(e.to_string()))?;
807 let metadata: ModelMetadata = serde_json::from_str(&metadata_json)
808 .map_err(|e| candle_core::Error::Msg(e.to_string()))?;
809
810 if metadata.state_dim != state_dim
812 || metadata.num_actions != num_actions
813 || metadata.num_params != num_params
814 {
815 return Err(candle_core::Error::Msg(
816 format!(
817 "Model dimension mismatch: expected ({}, {}, {}), got ({}, {}, {})",
818 state_dim, num_actions, num_params,
819 metadata.state_dim, metadata.num_actions, metadata.num_params
820 )
821 ));
822 }
823
824 let mut tensor_count_bytes = [0u8; 8];
826 file.read_exact(&mut tensor_count_bytes)
827 .map_err(candle_core::Error::Io)?;
828 let tensor_count = u64::from_le_bytes(tensor_count_bytes) as usize;
829
830 let mut tensors: HashMap<String, (Vec<usize>, Vec<f32>)> = HashMap::new();
832
833 for _ in 0..tensor_count {
834 let mut name_len_bytes = [0u8; 8];
836 file.read_exact(&mut name_len_bytes)
837 .map_err(candle_core::Error::Io)?;
838 let name_len = u64::from_le_bytes(name_len_bytes) as usize;
839
840 let mut name_bytes = vec![0u8; name_len];
841 file.read_exact(&mut name_bytes)
842 .map_err(candle_core::Error::Io)?;
843 let name = String::from_utf8(name_bytes)
844 .map_err(|e| candle_core::Error::Msg(e.to_string()))?;
845
846 let mut shape_len_bytes = [0u8; 8];
848 file.read_exact(&mut shape_len_bytes)
849 .map_err(candle_core::Error::Io)?;
850 let shape_len = u64::from_le_bytes(shape_len_bytes) as usize;
851
852 let mut shape = Vec::with_capacity(shape_len);
853 for _ in 0..shape_len {
854 let mut dim_bytes = [0u8; 8];
855 file.read_exact(&mut dim_bytes)
856 .map_err(candle_core::Error::Io)?;
857 shape.push(u64::from_le_bytes(dim_bytes) as usize);
858 }
859
860 let mut data_len_bytes = [0u8; 8];
862 file.read_exact(&mut data_len_bytes)
863 .map_err(candle_core::Error::Io)?;
864 let data_len = u64::from_le_bytes(data_len_bytes) as usize;
865
866 let mut data = Vec::with_capacity(data_len);
867 for _ in 0..data_len {
868 let mut value_bytes = [0u8; 4];
869 file.read_exact(&mut value_bytes)
870 .map_err(candle_core::Error::Io)?;
871 data.push(f32::from_le_bytes(value_bytes));
872 }
873
874 tensors.insert(name, (shape, data));
875 }
876
877 let mut varmap = candle_nn::VarMap::new();
879 let vb = VarBuilder::from_varmap(&varmap, DType::F32, device);
880 let mut model = Self::new(state_dim, num_actions, num_params, vb)?;
881
882 for (name, (shape, data)) in tensors.iter() {
883 let tensor = Tensor::from_vec(data.clone(), shape.as_slice(), device)?;
884 if name == "param_logstd" {
885 model.param_logstd = Var::from_tensor(&tensor)?;
886 } else {
887 varmap.set_one(name, &tensor)?;
888 }
889 }
890
891 Ok(model)
892 }
893
    /// Convenience alias for `load_from_onnx` that makes the target device
    /// explicit at the call site.
    pub fn load_with_device(
        path: &Path,
        state_dim: usize,
        num_actions: usize,
        num_params: usize,
        device: &Device,
    ) -> CandleResult<Self> {
        Self::load_from_onnx(path, state_dim, num_actions, num_params, device)
    }
904}
905
#[cfg(test)]
mod tests {
    use super::*;
    use candle_core::Device;
    // Removed: `use tempfile::TempDir;` was unused in this module.

    /// The constructor should record the requested dimensions verbatim.
    #[test]
    fn test_model_creation() {
        let device = Device::Cpu;
        let varmap = candle_nn::VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device);

        let model = DuelingDQN::new(300, 16, 6, vb).unwrap();
        assert_eq!(model.state_dim, 300);
        assert_eq!(model.num_actions, 16);
        assert_eq!(model.num_params, 6);
    }

    /// A forward pass on a batch of one should yield per-batch Q-values
    /// and parameter means, plus a batch-independent std vector.
    #[test]
    fn test_forward_pass() {
        let device = Device::Cpu;
        let varmap = candle_nn::VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device);
        let model = DuelingDQN::new(300, 16, 6, vb).unwrap();

        let state = Tensor::zeros(&[1, 300], DType::F32, &device).unwrap();
        let (q_values, param_mean, param_std) = model.forward(&state, false).unwrap();

        assert_eq!(q_values.dims(), &[1, 16]);
        assert_eq!(param_mean.dims(), &[1, 6]);
        assert_eq!(param_std.dims(), &[6]);
    }
}