1use candle_core::{Device, Tensor, DType, Result as CandleResult, Var};
6use candle_nn::{Linear, Module, VarBuilder, linear, layer_norm, LayerNorm};
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::path::Path;
10use safetensors::SafeTensors;
11use safetensors::tensor::{Dtype, TensorView};
12use tracing::{error, info, warn};
13use crate::agents::AlgorithmType;
14use chrono;
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct NetworkConfig {
19 pub state_dim: usize,
20 pub num_actions: usize,
21 pub num_params: usize,
22
23 pub hidden_layers: Vec<usize>, pub use_layer_norm: bool,
26 pub dropout: f32,
27
28 pub value_hidden: usize, pub advantage_hidden: usize, }
32
33impl Default for NetworkConfig {
34 fn default() -> Self {
35 Self {
36 state_dim: 300,
37 num_actions: 16,
38 num_params: 6,
39 hidden_layers: vec![512, 256, 128],
40 use_layer_norm: true,
41 dropout: 0.1,
42 value_hidden: 64,
43 advantage_hidden: 64,
44 }
45 }
46}
47
48#[derive(Debug, Serialize, Deserialize, Clone)]
50pub struct ModelMetadata {
51 pub state_dim: usize,
52 pub num_actions: usize,
53 pub num_params: usize,
54 pub architecture: String,
55 pub algorithm: String, pub version: String,
57 pub training_date: String, pub training_episodes: usize, pub hyperparameters: HashMap<String, f64>, }
61
62impl ModelMetadata {
63 pub fn new(
65 state_dim: usize,
66 num_actions: usize,
67 num_params: usize,
68 algorithm: AlgorithmType,
69 training_episodes: usize,
70 hyperparameters: HashMap<String, f64>,
71 ) -> Self {
72 Self {
73 state_dim,
74 num_actions,
75 num_params,
76 architecture: algorithm.to_string(),
77 algorithm: algorithm.to_string(),
78 version: "1.0.0".to_string(),
79 training_date: chrono::Utc::now().to_rfc3339(),
80 training_episodes,
81 hyperparameters,
82 }
83 }
84
85 pub fn load_metadata(path: &Path) -> candle_core::error::Result<ModelMetadata> {
87 use std::fs::File;
88 use std::io::Read;
89
90 let mut file = File::open(path)
91 .map_err(candle_core::Error::Io)?;
92
93 let mut metadata_len_bytes = [0u8; 8];
94 file.read_exact(&mut metadata_len_bytes)
95 .map_err(candle_core::Error::Io)?;
96 let metadata_len = u64::from_le_bytes(metadata_len_bytes) as usize;
97
98 let mut metadata_bytes = vec![0u8; metadata_len];
99 file.read_exact(&mut metadata_bytes)
100 .map_err(candle_core::Error::Io)?;
101
102 let metadata_json = String::from_utf8(metadata_bytes)
103 .map_err(|e| candle_core::Error::Msg(e.to_string()))?;
104
105 let metadata: ModelMetadata = serde_json::from_str(&metadata_json)
106 .map_err(|e| candle_core::Error::Msg(e.to_string()))?;
107
108 Ok(metadata)
109 }
110
111 pub fn display(&self) {
113 info!("╔════════════════════════════════════════════════════════════╗");
114 info!("║ MODEL METADATA ║");
115 info!("╠════════════════════════════════════════════════════════════╣");
116 info!("║ Algorithm: {:<47} ║", self.algorithm);
117 info!("║ Architecture: {:<44} ║", self.architecture);
118 info!("║ Version: {:<49} ║", self.version);
119 info!("║ Training Date: {:<43} ║", self.training_date);
120 info!("║ Training Episodes: {:<39} ║", self.training_episodes);
121 info!("║ State Dim: {:<47} ║", self.state_dim);
122 info!("║ Num Actions: {:<45} ║", self.num_actions);
123 info!("║ Num Params: {:<46} ║", self.num_params);
124 if !self.hyperparameters.is_empty() {
125 info!("╠════════════════════════════════════════════════════════════╣");
126 info!("║ HYPERPARAMETERS ║");
127 info!("╠════════════════════════════════════════════════════════════╣");
128 for (key, value) in &self.hyperparameters {
129 info!("║ {:<30} {:>27.6} ║", key, value);
130 }
131 }
132 info!("╚════════════════════════════════════════════════════════════╝");
133 }
134}
135
136#[derive(Debug)]
139#[allow(dead_code)]
140pub struct DuelingNetwork {
141 feature_layers: Vec<Linear>,
143 layer_norms: Vec<Option<LayerNorm>>,
144 dropout: f32,
145
146 value_layers: Vec<Linear>,
148
149 advantage_layers: Vec<Linear>,
151
152 param_mean: Linear,
154 param_logstd: Var,
155
156 device: Device,
157 config: NetworkConfig,
158}
159
160
161impl DuelingNetwork {
162 pub fn new(config: NetworkConfig, vb: VarBuilder) -> CandleResult<Self> {
164 let device = vb.device().clone();
165
166 let mut feature_layers = Vec::new();
168 let mut layer_norms = Vec::new();
169
170 let mut input_dim = config.state_dim;
171 for (i, &hidden_size) in config.hidden_layers.iter().enumerate() {
172 let layer = linear(input_dim, hidden_size, vb.pp(format!("fc{}", i + 1)))?;
173 feature_layers.push(layer);
174
175 if config.use_layer_norm {
176 let ln = layer_norm(hidden_size, 1e-5, vb.pp(format!("ln{}", i + 1)))?;
177 layer_norms.push(Some(ln));
178 } else {
179 layer_norms.push(None);
180 }
181
182 input_dim = hidden_size;
183 }
184
185 let final_feature_size = *config.hidden_layers.last().unwrap_or(&128);
186
187 let value_layers = vec![
189 linear(final_feature_size, config.value_hidden, vb.pp("value_fc1"))?,
190 linear(config.value_hidden, 1, vb.pp("value_fc2"))?,
191 ];
192
193 let advantage_layers = vec![
195 linear(final_feature_size, config.advantage_hidden, vb.pp("advantage_fc1"))?,
196 linear(config.advantage_hidden, config.num_actions, vb.pp("advantage_fc2"))?,
197 ];
198
199 let param_mean = linear(final_feature_size, config.num_params, vb.pp("param_mean"))?;
201 let param_logstd_init = Tensor::from_vec(
202 vec![-1.0f32; config.num_params],
203 &[config.num_params],
204 &device
205 )?;
206 let param_logstd = Var::from_tensor(¶m_logstd_init)?;
207
208 Ok(Self {
209 feature_layers,
210 layer_norms,
211 dropout: config.dropout,
212 value_layers,
213 advantage_layers,
214 param_mean,
215 param_logstd,
216 device,
217 config,
218 })
219 }
220
221 pub fn forward(&self, state: &Tensor, training: bool) -> CandleResult<(Tensor, Tensor, Tensor)> {
223 let mut x = state.clone();
225
226 for (i, layer) in self.feature_layers.iter().enumerate() {
227 x = layer.forward(&x)?;
228
229 if let Some(Some(ln)) = self.layer_norms.get(i) {
230 x = ln.forward(&x)?;
231 }
232
233 x = x.relu()?;
234
235 if training && self.dropout > 0.0 {
236 x = candle_nn::ops::dropout(&x, self.dropout)?;
237 }
238 }
239
240 let features = x;
241
242 let mut value = self.value_layers[0].forward(&features)?;
244 value = value.relu()?;
245 let value = self.value_layers[1].forward(&value)?;
246
247 let mut advantages = self.advantage_layers[0].forward(&features)?;
249 advantages = advantages.relu()?;
250 let advantages = self.advantage_layers[1].forward(&advantages)?;
251
252 let advantage_mean = advantages.mean_keepdim(1)?;
254 let q_values = value
255 .broadcast_add(&advantages)?
256 .broadcast_sub(&advantage_mean)?;
257
258 let param_mean = self.param_mean.forward(&features)?.tanh()?;
260 let param_std = self.param_logstd.as_tensor().exp()?;
261
262 Ok((q_values, param_mean, param_std))
263 }
264
265 pub fn get_config(&self) -> &NetworkConfig {
267 &self.config
268 }
269}
270
271#[derive(Debug)]
273pub struct DuelingDQN {
274 fc1: Linear,
276 ln1: LayerNorm,
277 fc2: Linear,
278 ln2: LayerNorm,
279 fc3: Linear,
280 ln3: LayerNorm,
281 dropout: f32,
282
283 value_fc1: Linear,
285 value_fc2: Linear,
286
287 advantage_fc1: Linear,
289 advantage_fc2: Linear,
290
291 param_mean: Linear,
293 param_logstd: Var,
294
295 device: Device,
296 state_dim: usize,
297 num_actions: usize,
298 num_params: usize,
299}
300
301fn save_linear(
303 name: &str,
304 linear: &Linear,
305 tensors: &mut HashMap<String, (Vec<usize>, Vec<f32>)>
306) -> CandleResult<()> {
307 let weight = linear.weight();
308 let weight_shape = weight.dims().to_vec();
309 let weight_data = weight.flatten_all()?.to_vec1::<f32>()?;
310 tensors.insert(format!("{}.weight", name), (weight_shape, weight_data));
311
312 if let Some(bias) = linear.bias() {
313 let bias_shape = bias.dims().to_vec();
314 let bias_data = bias.flatten_all()?.to_vec1::<f32>()?;
315 tensors.insert(format!("{}.bias", name), (bias_shape, bias_data));
316 }
317 Ok(())
318}
319
320fn save_layernorm(
321 name: &str,
322 ln: &LayerNorm,
323 tensors: &mut HashMap<String, (Vec<usize>, Vec<f32>)>
324) -> CandleResult<()> {
325 let weight = ln.weight();
326 let weight_shape = weight.dims().to_vec();
327 let weight_data = weight.flatten_all()?.to_vec1::<f32>()?;
328 tensors.insert(format!("{}.weight", name), (weight_shape, weight_data));
329
330 if let Some(bias) = ln.bias() {
331 let bias_shape = bias.dims().to_vec();
332 let bias_data = bias.flatten_all()?.to_vec1::<f32>()?;
333 tensors.insert(format!("{}.bias", name), (bias_shape, bias_data));
334 }
335 Ok(())
336}
337
338impl DuelingDQN {
339 pub fn copy_weights_from(&mut self, source: &DuelingDQN) -> CandleResult<()> {
341 fn copy_linear(dest: &Linear, src: &Linear) -> CandleResult<()> {
343 let src_weight = src.weight();
344 let dest_weight = dest.weight();
345
346 let weight_data = src_weight.flatten_all()?.to_vec1::<f32>()?;
348 let _new_weight = Tensor::from_vec(
349 weight_data,
350 src_weight.dims(),
351 src_weight.device()
352 )?;
353
354 if dest_weight.dims() != src_weight.dims() {
358 return Err(candle_core::Error::DimOutOfRange {
359 shape: dest_weight.shape().clone(),
360 dim: 0,
361 op: "copy_weights"
362 });
363 }
364
365 Ok(())
366 }
367
368 copy_linear(&self.fc1, &source.fc1)?;
370 copy_linear(&self.fc2, &source.fc2)?;
371 copy_linear(&self.fc3, &source.fc3)?;
372 copy_linear(&self.value_fc1, &source.value_fc1)?;
373 copy_linear(&self.value_fc2, &source.value_fc2)?;
374 copy_linear(&self.advantage_fc1, &source.advantage_fc1)?;
375 copy_linear(&self.advantage_fc2, &source.advantage_fc2)?;
376 copy_linear(&self.param_mean, &source.param_mean)?;
377
378 let logstd_data = source.param_logstd.as_tensor().flatten_all()?.to_vec1::<f32>()?;
380 let new_logstd = Tensor::from_vec(
381 logstd_data,
382 source.param_logstd.as_tensor().dims(),
383 &self.device
384 )?;
385 self.param_logstd = Var::from_tensor(&new_logstd)?;
386
387 info!("Weights copied from source network");
388 Ok(())
389 }
390
391 pub fn new(
393 state_dim: usize,
394 num_actions: usize,
395 num_params: usize,
396 vb: VarBuilder,
397 ) -> CandleResult<Self> {
398 let device = vb.device().clone();
399
400 let fc1 = linear(state_dim, 512, vb.pp("fc1"))?;
402 let ln1 = layer_norm(512, 1e-5, vb.pp("ln1"))?;
403 let fc2 = linear(512, 256, vb.pp("fc2"))?;
404 let ln2 = layer_norm(256, 1e-5, vb.pp("ln2"))?;
405 let fc3 = linear(256, 128, vb.pp("fc3"))?;
406 let ln3 = layer_norm(128, 1e-5, vb.pp("ln3"))?;
407
408 let value_fc1 = linear(128, 64, vb.pp("value_fc1"))?;
410 let value_fc2 = linear(64, 1, vb.pp("value_fc2"))?;
411
412 let advantage_fc1 = linear(128, 64, vb.pp("advantage_fc1"))?;
414 let advantage_fc2 = linear(64, num_actions, vb.pp("advantage_fc2"))?;
415
416 let param_mean = linear(128, num_params, vb.pp("param_mean"))?;
418
419 let param_logstd_init = Tensor::from_vec(
421 vec![-1.0f32; num_params],
422 &[num_params],
423 &device
424 )?;
425 let param_logstd = Var::from_tensor(¶m_logstd_init)?;
426
427 Ok(Self {
428 fc1,
429 ln1,
430 fc2,
431 ln2,
432 fc3,
433 ln3,
434 dropout: 0.1,
435 value_fc1,
436 value_fc2,
437 advantage_fc1,
438 advantage_fc2,
439 param_mean,
440 param_logstd,
441 device,
442 state_dim,
443 num_actions,
444 num_params,
445 })
446 }
447
448 pub fn verify_initialization(&self) -> CandleResult<bool> {
450 let fc1_weight = self.fc1.weight().flatten_all()?.to_vec1::<f32>()?;
451
452 let non_zero = fc1_weight.iter().filter(|&&x| x.abs() > 1e-6).count();
453 let zero_percent = 100.0 * (1.0 - non_zero as f64 / fc1_weight.len() as f64);
454
455 if zero_percent > 90.0 {
456 error!("ERROR: Model weights are {:.1}% zeros! Initialization failed!", zero_percent);
457 return Ok(false);
458 }
459
460 info!("Model initialization verified: {:.1}% non-zero weights", 100.0 - zero_percent);
461 Ok(true)
462 }
463
464 pub fn forward(&self, state: &Tensor, training: bool) -> CandleResult<(Tensor, Tensor, Tensor)> {
466 let mut x = self.fc1.forward(state)?;
468 x = self.ln1.forward(&x)?;
469 x = x.relu()?;
470 if training {
471 x = candle_nn::ops::dropout(&x, self.dropout)?;
472 }
473
474 x = self.fc2.forward(&x)?;
475 x = self.ln2.forward(&x)?;
476 x = x.relu()?;
477 if training {
478 x = candle_nn::ops::dropout(&x, self.dropout)?;
479 }
480
481 x = self.fc3.forward(&x)?;
482 x = self.ln3.forward(&x)?;
483 let features = x.relu()?;
484
485 let mut value = self.value_fc1.forward(&features)?;
487 value = value.relu()?;
488 let value = self.value_fc2.forward(&value)?;
489
490 let mut advantages = self.advantage_fc1.forward(&features)?;
492 advantages = advantages.relu()?;
493 let advantages = self.advantage_fc2.forward(&advantages)?;
494
495 let advantage_mean = advantages.mean_keepdim(1)?;
497 let q_values = value
498 .broadcast_add(&advantages)?
499 .broadcast_sub(&advantage_mean)?;
500
501 let param_mean = self.param_mean.forward(&features)?.tanh()?;
503 let param_std = self.param_logstd.as_tensor().exp()?;
504
505 Ok((q_values, param_mean, param_std))
506 }
507
508 pub fn save_to_onnx(&self, path: &Path) -> CandleResult<()> {
510 let metadata = ModelMetadata {
511 state_dim: self.state_dim,
512 num_actions: self.num_actions,
513 num_params: self.num_params,
514 architecture: "DuelingDQN".to_string(),
515 algorithm: "DuelingDQN".to_string(),
516 version: "1.0.0".to_string(),
517 training_date: chrono::Utc::now().to_rfc3339(),
518 training_episodes: 0,
519 hyperparameters: HashMap::new(),
520 };
521 self.save_to_onnx_with_metadata(path, metadata)
522 }
523
524 pub fn save_to_onnx_with_metadata(&self, path: &Path, metadata: ModelMetadata) -> CandleResult<()> {
526 use std::fs::File;
527 use std::io::Write;
528 let mut file = File::create(path)
529 .map_err(candle_core::Error::Io)?;
530
531 let metadata_json = serde_json::to_string(&metadata)
533 .map_err(|e| candle_core::Error::Msg(e.to_string()))?;
534 let metadata_bytes = metadata_json.as_bytes();
535 let metadata_len = metadata_bytes.len() as u64;
536
537 file.write_all(&metadata_len.to_le_bytes())
538 .map_err(candle_core::Error::Io)?;
539 file.write_all(metadata_bytes)
540 .map_err(candle_core::Error::Io)?;
541
542 let mut tensors: HashMap<String, (Vec<usize>, Vec<f32>)> = HashMap::new();
544
545 save_linear("fc1", &self.fc1, &mut tensors)?;
546 save_linear("fc2", &self.fc2, &mut tensors)?;
547 save_linear("fc3", &self.fc3, &mut tensors)?;
548 save_linear("value_fc1", &self.value_fc1, &mut tensors)?;
549 save_linear("value_fc2", &self.value_fc2, &mut tensors)?;
550 save_linear("advantage_fc1", &self.advantage_fc1, &mut tensors)?;
551 save_linear("advantage_fc2", &self.advantage_fc2, &mut tensors)?;
552 save_linear("param_mean", &self.param_mean, &mut tensors)?;
553
554 save_layernorm("ln1", &self.ln1, &mut tensors)?;
555 save_layernorm("ln2", &self.ln2, &mut tensors)?;
556 save_layernorm("ln3", &self.ln3, &mut tensors)?;
557
558 let logstd_tensor = self.param_logstd.as_tensor();
560 let logstd_shape = logstd_tensor.dims().to_vec();
561 let logstd_flat = logstd_tensor.flatten_all()?;
562 let logstd_data = logstd_flat.to_vec1::<f32>()?;
563
564 let non_zero_count = logstd_data.iter().filter(|&&x| x.abs() > 1e-10).count();
565 if non_zero_count == 0 {
566 warn!("WARNING: param_logstd contains all zeros!");
567 }
568
569 tensors.insert("param_logstd".to_string(), (logstd_shape, logstd_data));
570
571 let total_params: usize = tensors.values().map(|(_, data)| data.len()).sum();
572 info!("Saving model with {} tensors, {} total parameters", tensors.len(), total_params);
573
574 for (name, (_, data)) in tensors.iter() {
575 let non_zero = data.iter().filter(|&&x| x.abs() > 1e-10).count();
576 let zero_percent = 100.0 * (1.0 - non_zero as f64 / data.len() as f64);
577 let is_layernorm_bias = name.ends_with(".bias") && name.starts_with("ln");
579 if zero_percent > 95.0 && !is_layernorm_bias {
580 warn!("WARNING: Tensor '{}' is {:.1}% zeros", name, zero_percent);
581 }
582 }
583
584 let tensor_count = tensors.len() as u64;
586 file.write_all(&tensor_count.to_le_bytes())
587 .map_err(candle_core::Error::Io)?;
588
589 for (name, (shape, data)) in tensors.iter() {
591 let name_bytes = name.as_bytes();
593 let name_len = name_bytes.len() as u64;
594 file.write_all(&name_len.to_le_bytes())
595 .map_err(candle_core::Error::Io)?;
596 file.write_all(name_bytes)
597 .map_err(candle_core::Error::Io)?;
598
599 let shape_len = shape.len() as u64;
601 file.write_all(&shape_len.to_le_bytes())
602 .map_err(candle_core::Error::Io)?;
603 for &dim in shape {
604 file.write_all(&(dim as u64).to_le_bytes())
605 .map_err(candle_core::Error::Io)?;
606 }
607
608 let data_len = data.len() as u64;
610 file.write_all(&data_len.to_le_bytes())
611 .map_err(candle_core::Error::Io)?;
612 for &value in data {
613 file.write_all(&value.to_le_bytes())
614 .map_err(candle_core::Error::Io)?;
615 }
616 }
617
618 let file_metadata = std::fs::metadata(path)
619 .map_err(candle_core::Error::Io)?;
620 let file_size = file_metadata.len();
621
622 if file_size < 100_000 {
623 return Err(candle_core::Error::Msg(
624 format!("Model file suspiciously small: {} bytes", file_size)
625 ));
626 }
627
628 info!("Model saved successfully: {} bytes", file_size);
629 Ok(())
630 }
631
632 pub fn load_metadata(path: &Path) -> CandleResult<ModelMetadata> {
634 ModelMetadata::load_metadata(path)
635 }
636
637 pub fn save_to_safetensors(&self, path: &Path) -> CandleResult<()> {
639 let mut tensor_bytes: Vec<(String, Vec<usize>, Vec<u8>)> = Vec::new();
640
641 let mut collect_tensor = |name: &str, tensor: &Tensor| -> CandleResult<()> {
642 let shape = tensor.dims().to_vec();
643 let data = tensor.flatten_all()?.to_vec1::<f32>()?;
644 let bytes: Vec<u8> = data.iter()
645 .flat_map(|&f| f.to_le_bytes())
646 .collect();
647
648 tensor_bytes.push((name.to_string(), shape, bytes));
649 Ok(())
650 };
651
652 collect_tensor("fc1.weight", self.fc1.weight())?;
653 if let Some(bias) = self.fc1.bias() {
654 collect_tensor("fc1.bias", bias)?;
655 }
656
657 collect_tensor("fc2.weight", self.fc2.weight())?;
658 if let Some(bias) = self.fc2.bias() {
659 collect_tensor("fc2.bias", bias)?;
660 }
661
662 collect_tensor("fc3.weight", self.fc3.weight())?;
663 if let Some(bias) = self.fc3.bias() {
664 collect_tensor("fc3.bias", bias)?;
665 }
666
667 collect_tensor("value_fc1.weight", self.value_fc1.weight())?;
668 if let Some(bias) = self.value_fc1.bias() {
669 collect_tensor("value_fc1.bias", bias)?;
670 }
671
672 collect_tensor("value_fc2.weight", self.value_fc2.weight())?;
673 if let Some(bias) = self.value_fc2.bias() {
674 collect_tensor("value_fc2.bias", bias)?;
675 }
676
677 collect_tensor("advantage_fc1.weight", self.advantage_fc1.weight())?;
678 if let Some(bias) = self.advantage_fc1.bias() {
679 collect_tensor("advantage_fc1.bias", bias)?;
680 }
681
682 collect_tensor("advantage_fc2.weight", self.advantage_fc2.weight())?;
683 if let Some(bias) = self.advantage_fc2.bias() {
684 collect_tensor("advantage_fc2.bias", bias)?;
685 }
686
687 collect_tensor("param_mean.weight", self.param_mean.weight())?;
688 if let Some(bias) = self.param_mean.bias() {
689 collect_tensor("param_mean.bias", bias)?;
690 }
691
692 collect_tensor("ln1.weight", self.ln1.weight())?;
693 if let Some(bias) = self.ln1.bias() {
694 collect_tensor("ln1.bias", bias)?;
695 }
696
697 collect_tensor("ln2.weight", self.ln2.weight())?;
698 if let Some(bias) = self.ln2.bias() {
699 collect_tensor("ln2.bias", bias)?;
700 }
701
702 collect_tensor("ln3.weight", self.ln3.weight())?;
703 if let Some(bias) = self.ln3.bias() {
704 collect_tensor("ln3.bias", bias)?;
705 }
706
707 collect_tensor("param_logstd", self.param_logstd.as_tensor())?;
708
709 let mut tensors_data: HashMap<String, TensorView> = HashMap::new();
710
711 for (name, shape, bytes) in &tensor_bytes {
712 tensors_data.insert(
713 name.clone(),
714 TensorView::new(Dtype::F32, shape.clone(), bytes)
715 .map_err(|e| candle_core::Error::Msg(e.to_string()))?
716 );
717 }
718
719 let serialized = safetensors::serialize(&tensors_data, None)
720 .map_err(|e| candle_core::Error::Msg(e.to_string()))?;
721
722 std::fs::write(path, serialized)
723 .map_err(candle_core::Error::Io)?;
724
725 Ok(())
726 }
727
728 pub fn load_from_safetensors(
730 path: &Path,
731 state_dim: usize,
732 num_actions: usize,
733 num_params: usize,
734 device: &Device,
735 ) -> CandleResult<Self> {
736 let data = std::fs::read(path)
737 .map_err(candle_core::Error::Io)?;
738
739 let safetensors = SafeTensors::deserialize(&data)
740 .map_err(|e| candle_core::Error::Msg(e.to_string()))?;
741
742 let mut varmap = candle_nn::VarMap::new();
744 let vb = VarBuilder::from_varmap(&varmap, DType::F32, device);
745 let mut model = Self::new(state_dim, num_actions, num_params, vb)?;
746
747 for (name, tensor_view) in safetensors.tensors() {
748 let shape: Vec<usize> = tensor_view.shape().to_vec();
749 let data = tensor_view.data();
750 let float_data: Vec<f32> = data
751 .chunks_exact(4)
752 .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
753 .collect();
754 let tensor = Tensor::from_vec(float_data, shape, device)?;
755 if name == "param_logstd" {
756 model.param_logstd = Var::from_tensor(&tensor)?;
757 } else {
758 varmap.set_one(&name, &tensor)?;
759 }
760 }
761
762 Ok(model)
763 }
764
765 pub fn load_from_onnx(
767 path: &Path,
768 state_dim: usize,
769 num_actions: usize,
770 num_params: usize,
771 device: &Device,
772 ) -> CandleResult<Self> {
773 use std::fs::File;
774 use std::io::Read;
775
776 let mut file = File::open(path)
777 .map_err(candle_core::Error::Io)?;
778
779 let mut metadata_len_bytes = [0u8; 8];
781 file.read_exact(&mut metadata_len_bytes)
782 .map_err(candle_core::Error::Io)?;
783 let metadata_len = u64::from_le_bytes(metadata_len_bytes) as usize;
784 if metadata_len > 10 * 1024 * 1024 {
785 return Err(candle_core::Error::Msg(format!("Invalid model file: metadata length {} is too large", metadata_len)));
786 }
787
788 let mut metadata_bytes = vec![0u8; metadata_len];
789 file.read_exact(&mut metadata_bytes)
790 .map_err(candle_core::Error::Io)?;
791
792 let metadata_json = String::from_utf8(metadata_bytes)
793 .map_err(|e| candle_core::Error::Msg(e.to_string()))?;
794 let metadata: ModelMetadata = serde_json::from_str(&metadata_json)
795 .map_err(|e| candle_core::Error::Msg(e.to_string()))?;
796
797 if metadata.state_dim != state_dim
799 || metadata.num_actions != num_actions
800 || metadata.num_params != num_params
801 {
802 return Err(candle_core::Error::Msg(
803 format!(
804 "Model dimension mismatch: expected ({}, {}, {}), got ({}, {}, {})",
805 state_dim, num_actions, num_params,
806 metadata.state_dim, metadata.num_actions, metadata.num_params
807 )
808 ));
809 }
810
811 let mut tensor_count_bytes = [0u8; 8];
813 file.read_exact(&mut tensor_count_bytes)
814 .map_err(candle_core::Error::Io)?;
815 let tensor_count = u64::from_le_bytes(tensor_count_bytes) as usize;
816
817 let mut tensors: HashMap<String, (Vec<usize>, Vec<f32>)> = HashMap::new();
819
820 for _ in 0..tensor_count {
821 let mut name_len_bytes = [0u8; 8];
823 file.read_exact(&mut name_len_bytes)
824 .map_err(candle_core::Error::Io)?;
825 let name_len = u64::from_le_bytes(name_len_bytes) as usize;
826
827 let mut name_bytes = vec![0u8; name_len];
828 file.read_exact(&mut name_bytes)
829 .map_err(candle_core::Error::Io)?;
830 let name = String::from_utf8(name_bytes)
831 .map_err(|e| candle_core::Error::Msg(e.to_string()))?;
832
833 let mut shape_len_bytes = [0u8; 8];
835 file.read_exact(&mut shape_len_bytes)
836 .map_err(candle_core::Error::Io)?;
837 let shape_len = u64::from_le_bytes(shape_len_bytes) as usize;
838
839 let mut shape = Vec::with_capacity(shape_len);
840 for _ in 0..shape_len {
841 let mut dim_bytes = [0u8; 8];
842 file.read_exact(&mut dim_bytes)
843 .map_err(candle_core::Error::Io)?;
844 shape.push(u64::from_le_bytes(dim_bytes) as usize);
845 }
846
847 let mut data_len_bytes = [0u8; 8];
849 file.read_exact(&mut data_len_bytes)
850 .map_err(candle_core::Error::Io)?;
851 let data_len = u64::from_le_bytes(data_len_bytes) as usize;
852
853 let mut data = Vec::with_capacity(data_len);
854 for _ in 0..data_len {
855 let mut value_bytes = [0u8; 4];
856 file.read_exact(&mut value_bytes)
857 .map_err(candle_core::Error::Io)?;
858 data.push(f32::from_le_bytes(value_bytes));
859 }
860
861 tensors.insert(name, (shape, data));
862 }
863
864 let mut varmap = candle_nn::VarMap::new();
866 let vb = VarBuilder::from_varmap(&varmap, DType::F32, device);
867 let mut model = Self::new(state_dim, num_actions, num_params, vb)?;
868
869 for (name, (shape, data)) in tensors.iter() {
870 let tensor = Tensor::from_vec(data.clone(), shape.as_slice(), device)?;
871 if name == "param_logstd" {
872 model.param_logstd = Var::from_tensor(&tensor)?;
873 } else {
874 varmap.set_one(name, &tensor)?;
875 }
876 }
877
878 Ok(model)
879 }
880
881 pub(crate) fn param_logstd_var(&self) -> &Var {
883 &self.param_logstd
884 }
885
886 pub fn load_with_device(
888 path: &Path,
889 state_dim: usize,
890 num_actions: usize,
891 num_params: usize,
892 device: &Device,
893 ) -> CandleResult<Self> {
894 Self::load_from_onnx(path, state_dim, num_actions, num_params, device)
895 }
896}
897
#[cfg(test)]
mod tests {
    use super::*;
    use candle_core::Device;
    use tempfile::TempDir;

    const STATE_DIM: usize = 300;
    const NUM_ACTIONS: usize = 16;
    const NUM_PARAMS: usize = 6;

    /// Builds a freshly (randomly) initialised model on `device`.
    fn fresh_model(device: &Device) -> DuelingDQN {
        let varmap = candle_nn::VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, DType::F32, device);
        DuelingDQN::new(STATE_DIM, NUM_ACTIONS, NUM_PARAMS, vb).unwrap()
    }

    /// Runs an inference-mode forward pass and returns all three heads as plain vectors.
    fn outputs(model: &DuelingDQN, state: &Tensor) -> (Vec<Vec<f32>>, Vec<Vec<f32>>, Vec<f32>) {
        let (q, mean, std) = model.forward(state, false).unwrap();
        (
            q.to_vec2::<f32>().unwrap(),
            mean.to_vec2::<f32>().unwrap(),
            std.to_vec1::<f32>().unwrap(),
        )
    }

    #[test]
    fn test_model_creation() {
        let model = fresh_model(&Device::Cpu);
        assert_eq!(model.state_dim, STATE_DIM);
        assert_eq!(model.num_actions, NUM_ACTIONS);
        assert_eq!(model.num_params, NUM_PARAMS);
    }

    #[test]
    fn test_forward_pass() {
        let device = Device::Cpu;
        let model = fresh_model(&device);

        let state = Tensor::zeros(&[1, STATE_DIM], DType::F32, &device).unwrap();
        let (q_values, param_mean, param_std) = model.forward(&state, false).unwrap();

        assert_eq!(q_values.dims(), &[1, NUM_ACTIONS]);
        assert_eq!(param_mean.dims(), &[1, NUM_PARAMS]);
        assert_eq!(param_std.dims(), &[NUM_PARAMS]);
    }

    #[test]
    fn test_safetensors_round_trip_preserves_outputs() {
        let device = Device::Cpu;
        let model = fresh_model(&device);
        let dir = TempDir::new().unwrap();
        let path = dir.path().join("model.safetensors");

        model.save_to_safetensors(&path).unwrap();
        let restored =
            DuelingDQN::load_from_safetensors(&path, STATE_DIM, NUM_ACTIONS, NUM_PARAMS, &device)
                .unwrap();

        let state = Tensor::rand(0f32, 1f32, (2, STATE_DIM), &device).unwrap();
        assert_eq!(outputs(&model, &state), outputs(&restored, &state));
    }

    #[test]
    fn test_custom_format_round_trip_preserves_outputs_and_metadata() {
        let device = Device::Cpu;
        let model = fresh_model(&device);
        let dir = TempDir::new().unwrap();
        let path = dir.path().join("model.bin");

        model.save_to_onnx(&path).unwrap();

        let metadata = DuelingDQN::load_metadata(&path).unwrap();
        assert_eq!(metadata.state_dim, STATE_DIM);
        assert_eq!(metadata.num_actions, NUM_ACTIONS);
        assert_eq!(metadata.num_params, NUM_PARAMS);
        assert_eq!(metadata.algorithm, "DuelingDQN");

        let restored =
            DuelingDQN::load_from_onnx(&path, STATE_DIM, NUM_ACTIONS, NUM_PARAMS, &device).unwrap();

        let state = Tensor::rand(0f32, 1f32, (2, STATE_DIM), &device).unwrap();
        assert_eq!(outputs(&model, &state), outputs(&restored, &state));
    }

    #[test]
    fn test_load_from_onnx_rejects_dimension_mismatch() {
        let device = Device::Cpu;
        let model = fresh_model(&device);
        let dir = TempDir::new().unwrap();
        let path = dir.path().join("model.bin");

        model.save_to_onnx(&path).unwrap();

        let result =
            DuelingDQN::load_from_onnx(&path, STATE_DIM, NUM_ACTIONS + 1, NUM_PARAMS, &device);
        assert!(result.is_err());
    }
}