1use candle_core::{Device, Tensor, DType};
6use candle_nn::{VarBuilder, Optimizer, AdamW, ParamsAdamW, VarMap};
7use crate::models::{DuelingDQN, NetworkConfig};
8use crate::replay_buffer::{PrioritizedReplayBuffer, SampledBatch};
9use crate::{Result, agents::{RLAgent, AlgorithmType, AgentInfo}};
10use rand::RngExt;
11use tracing::{info, warn};
12use std::path::Path;
13
/// Double/Dueling DQN agent over a hybrid action space: a discrete action
/// index plus a vector of continuous parameters.
pub struct DQNAgent {
    // Trained network; selects actions and produces TD estimates.
    pub(crate) online_network: DuelingDQN,
    // Slow-moving copy of the online network used to compute stable TD
    // targets; refreshed by `update_target_network` (hard update).
    target_network: DuelingDQN,
    // AdamW over the variables registered in `varmap`.
    optimizer: AdamW,
    // Holds the online network's trainable variables.
    varmap: VarMap,
    // Number of discrete actions.
    num_actions: usize,
    // Number of continuous action parameters.
    num_params: usize,
    // Discount factor applied in TD targets.
    gamma: f32,
    // Count of completed training steps (drives periodic logging).
    step_count: usize,
    // Device all tensors are created on.
    device: Device,
}
26
27impl DQNAgent {
28 pub fn new(
30 network_config: NetworkConfig,
31 gamma: f32,
32 lr: f64,
33 device: &Device,
34 varmap: VarMap,
35 ) -> Result<Self> {
36 let vb = VarBuilder::from_varmap(&varmap, DType::F32, device);
37 let online_network = DuelingDQN::new(
38 network_config.state_dim,
39 network_config.num_actions,
40 network_config.num_params,
41 vb.pp("online")
42 )?;
43
44 let target_varmap = VarMap::new();
45 let target_vb = VarBuilder::from_varmap(&target_varmap, DType::F32, device);
46 let mut target_network = DuelingDQN::new(
47 network_config.state_dim,
48 network_config.num_actions,
49 network_config.num_params,
50 target_vb.pp("target")
51 )?;
52
53 let trainable_vars = varmap.all_vars();
55
56 let params = ParamsAdamW {
57 lr,
58 beta1: 0.9,
59 beta2: 0.999,
60 eps: 1e-8,
61 weight_decay: 1e-4,
62 };
63
64 let optimizer = AdamW::new(trainable_vars, params)
65 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
66
67 target_network.copy_weights_from(&online_network)
69 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
70
71 Ok(Self {
72 online_network,
73 target_network,
74 optimizer,
75 varmap,
76 num_actions: network_config.num_actions,
77 num_params: network_config.num_params,
78 gamma,
79 step_count: 0,
80 device: device.clone(),
81 })
82 }
83
84 fn copy_network_weights(source: &DuelingDQN, target: &mut DuelingDQN) -> Result<()> {
86 target.copy_weights_from(source)
87 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))
88 }
89
90 pub fn update_target_network(&mut self) {
92 if let Err(e) = Self::copy_network_weights(&self.online_network, &mut self.target_network) {
95 warn!("Failed to update target network: {}", e);
96 } else {
97 info!("Target network updated (hard update)");
98 }
99 }
100
101 pub fn get_step_count(&self) -> usize {
103 self.step_count
104 }
105
106 pub fn select_action(&self, state: &[f32], epsilon: f32) -> Result<(usize, Vec<f32>)> {
108 let mut rng = rand::rng();
109
110 if rng.random::<f32>() < epsilon {
111 let discrete_action = rng.random_range(0..self.num_actions);
112 let params: Vec<f32> = (0..self.num_params)
113 .map(|_| rng.random_range(-1.0..1.0))
114 .collect();
115 Ok((discrete_action, params))
116 } else {
117 let state_tensor = Tensor::from_vec(state.to_vec(), &[1, state.len()], &self.device)
119 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
120
121 let (q_values, param_mean, _param_std) = self.online_network.forward(&state_tensor, false)
122 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
123
124 let q_vals = q_values.to_vec2::<f32>()
126 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
127 let discrete_action = q_vals[0].iter()
128 .enumerate()
129 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
130 .map(|(idx, _)| idx)
131 .unwrap_or(0);
132
133 let params = param_mean.to_vec2::<f32>()
135 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
136 let continuous_params = params[0].clone();
137
138 Ok((discrete_action, continuous_params))
139 }
140 }
141
142 pub fn train_step(&mut self, replay_buffer: &mut PrioritizedReplayBuffer, batch_size: usize) -> Result<f32> {
144 let batch = replay_buffer.sample(batch_size);
145
146 if batch.is_none() {
147 return Ok(0.0);
148 }
149
150 let SampledBatch { experiences, indices, weights } = batch.unwrap();
151
152 let states: Vec<Vec<f32>> = experiences.iter()
154 .map(|e| e.state.clone())
155 .collect();
156 let actions_discrete: Vec<usize> = experiences.iter()
157 .map(|e| e.action.0)
158 .collect();
159 let actions_params: Vec<Vec<f32>> = experiences.iter()
160 .map(|e| e.action.1.clone())
161 .collect();
162 let rewards: Vec<f32> = experiences.iter()
163 .map(|e| e.reward)
164 .collect();
165 let next_states: Vec<Vec<f32>> = experiences.iter()
166 .map(|e| e.next_state.clone())
167 .collect();
168 let dones: Vec<f32> = experiences.iter()
169 .map(|e| if e.done { 1.0 } else { 0.0 })
170 .collect();
171
172 let state_dim = states[0].len();
174 let states_flat: Vec<f32> = states.into_iter().flatten().collect();
175 let states_tensor = Tensor::from_vec(
176 states_flat,
177 &[batch_size, state_dim],
178 &self.device
179 ).map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
180
181 let next_states_flat: Vec<f32> = next_states.into_iter().flatten().collect();
182 let next_states_tensor = Tensor::from_vec(
183 next_states_flat,
184 &[batch_size, state_dim],
185 &self.device
186 ).map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
187
188 let rewards_tensor = Tensor::from_vec(
189 rewards,
190 &[batch_size],
191 &self.device
192 ).map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
193
194 let dones_tensor = Tensor::from_vec(
195 dones,
196 &[batch_size],
197 &self.device
198 ).map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
199
200 let weights_tensor = Tensor::from_vec(
201 weights,
202 &[batch_size],
203 &self.device
204 ).map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
205
206 let actions_discrete_tensor = Tensor::from_vec(
208 actions_discrete.iter().map(|&x| x as i64).collect::<Vec<_>>(),
209 &[batch_size],
210 &self.device
211 ).map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
212
213 let actions_params_flat: Vec<f32> = actions_params.into_iter().flatten().collect();
214 let actions_params_tensor = Tensor::from_vec(
215 actions_params_flat,
216 &[batch_size, self.num_params],
217 &self.device
218 ).map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
219
220 let (q_values, param_means, param_stds) = self.online_network.forward(&states_tensor, true)
222 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
223
224 let q_sample = q_values.get(0)?.to_vec1::<f32>()?;
226 if q_sample.iter().any(|&x| x.is_nan() || x.is_infinite()) {
227 return Err(crate::ExtractionError::ModelError(
228 "NaN/Inf detected in Q-values forward pass".to_string()
229 ));
230 }
231
232 let q_values_selected = q_values
234 .gather(&actions_discrete_tensor.unsqueeze(1)?, 1)?
235 .squeeze(1)?;
236
237 let (next_q_online, _, _) = self.online_network.forward(&next_states_tensor, false)
239 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
240
241 let next_actions = next_q_online.argmax(1)?;
242
243 let (next_q_target, _, _) = self.target_network.forward(&next_states_tensor, false)
244 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
245
246 let next_q_values = next_q_target
247 .gather(&next_actions.unsqueeze(1)?, 1)?
248 .squeeze(1)?;
249
250 let ones = Tensor::ones(&[batch_size], DType::F32, &self.device)
252 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
253
254 let gamma_vec = vec![self.gamma; batch_size];
256 let gamma_tensor = Tensor::from_vec(gamma_vec, &[batch_size], &self.device)?;
257
258 let discount_factors = (ones - dones_tensor)?
260 .mul(&gamma_tensor)?;
261
262 let td_targets = rewards_tensor
264 .add(&next_q_values.mul(&discount_factors)?)?;
265
266 let td_errors_tensor = (td_targets.clone() - q_values_selected.clone())?;
268 let td_errors: Vec<f32> = td_errors_tensor
269 .to_vec1()
270 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
271
272 let q_loss_elements = smooth_l1_loss(&q_values_selected, &td_targets)?;
274 let weighted_q_loss = (q_loss_elements * weights_tensor.clone())?;
275 let loss_q = weighted_q_loss.mean_all()?;
276
277 let param_loss = self.calculate_param_loss(¶m_means, ¶m_stds, &actions_params_tensor)?;
279
280 let loss_q_scalar = loss_q.to_scalar::<f32>()
282 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
283
284 let param_loss_scalar = param_loss.to_scalar::<f32>()
285 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
286
287 let total_loss_scalar = loss_q_scalar + 0.1 * param_loss_scalar;
288
289 let total_loss = Tensor::from_vec(
291 vec![total_loss_scalar],
292 &[1],
293 &self.device
294 )?;
295
296 if total_loss_scalar.is_nan() || total_loss_scalar.is_infinite() {
298 return Err(crate::ExtractionError::ModelError(
299 format!("Invalid loss: {}", total_loss_scalar)
300 ));
301 }
302
303 let mut grad_store = total_loss.backward()?;
306
307 let vars = self.varmap.all_vars();
310 let max_grad_norm = 1.0f32;
311 let mut total_norm_sq = 0.0f32;
312
313 for var in &vars {
315 if let Some(grad) = grad_store.get(var) {
316 let norm_sq = grad.sqr()?.sum_all()?.to_scalar::<f32>()?;
317 total_norm_sq += norm_sq;
318 }
319 }
320
321 let total_norm = total_norm_sq.sqrt();
322
323 if total_norm > max_grad_norm {
325 let clip_coef = max_grad_norm / (total_norm + 1e-6);
326
327 for var in self.varmap.all_vars() {
329 if let Some(grad) = grad_store.get(&var) {
330 let clip_coef_tensor = Tensor::from_vec(
332 vec![clip_coef],
333 &[1],
334 &self.device
335 )?;
336
337 let clipped_grad = grad.mul(&clip_coef_tensor)?;
339
340 grad_store.insert(&var, clipped_grad);
342 }
343 }
344
345 if self.step_count.is_multiple_of(1000) {
346 info!("Gradient norm: {:.4}, clipped with coef: {:.4}", total_norm, clip_coef);
347 }
348 }
349
350 self.optimizer.step(&grad_store)
352 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
353
354 replay_buffer.update_priorities(&indices, &td_errors);
356
357 self.step_count += 1;
358
359 Ok(total_loss_scalar)
361 }
362
363 fn calculate_param_loss(
365 &self,
366 means: &Tensor,
367 stds: &Tensor,
368 actions: &Tensor,
369 ) -> candle_core::error::Result<Tensor> {
370 let batch_size = actions.dims()[0];
371 let num_params = actions.dims()[1];
372
373 let diff = actions.sub(means)?;
374
375 let stds_broadcast = stds.unsqueeze(0)?.broadcast_as(means.shape())?;
377
378 let variance = stds_broadcast.sqr()?;
379 let squared_diff = diff.sqr()?.div(&variance)?;
380
381 let log_std = stds_broadcast.log()?;
382
383 let pi_vec = vec![std::f32::consts::PI; batch_size * num_params];
385 let pi_constant = Tensor::from_vec(pi_vec, &[batch_size, num_params], &self.device)?;
386
387 let half_vec = vec![0.5f32; batch_size * num_params];
388 let half_tensor = Tensor::from_vec(half_vec, &[batch_size, num_params], &self.device)?;
389
390 let constant = pi_constant.log()?.mul(&half_tensor)?;
391
392 let nll = constant
393 .add(&log_std)?
394 .add(&squared_diff.mul(&half_tensor)?)?;
395
396 nll.mean_all()
397 }
398
399 pub fn save(&self, path: &std::path::Path) -> Result<()> {
401 self.online_network.save_to_onnx(path)
402 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
403
404 let safetensors_path = path.with_extension("safetensors");
405 self.online_network.save_to_safetensors(&safetensors_path)
406 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
407
408 tracing::info!("Model saved: ONNX ({} bytes), SafeTensors ({} bytes)",
409 std::fs::metadata(path).map(|m| m.len()).unwrap_or(0),
410 std::fs::metadata(&safetensors_path).map(|m| m.len()).unwrap_or(0));
411
412 Ok(())
413 }
414
415 pub fn load(
417 path: &std::path::Path,
418 state_dim: usize,
419 num_actions: usize,
420 num_params: usize,
421 ) -> Result<Self> {
422 let device = crate::device::get_device();
423 Self::load_with_device(path, state_dim, num_actions, num_params, &device)
424 }
425
426 pub fn load_with_device(
427 path: &std::path::Path,
428 state_dim: usize,
429 num_actions: usize,
430 num_params: usize,
431 device: &Device,
432 ) -> Result<Self> {
433 tracing::info!("Loading model on device: {}", crate::device::get_device_info(device));
434
435 let online_network = DuelingDQN::load_from_onnx(path, state_dim, num_actions, num_params, device)
436 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
437
438 let target_varmap = VarMap::new();
440 let vb_target = VarBuilder::from_varmap(&target_varmap, DType::F32, device);
441 let target_network = DuelingDQN::new(state_dim, num_actions, num_params, vb_target)
442 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
443
444 let varmap = VarMap::new();
445 let vars = varmap.all_vars();
446 let params = ParamsAdamW::default();
447 let optimizer = AdamW::new(vars, params)
448 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
449
450 Ok(Self {
451 online_network,
452 target_network,
453 optimizer,
454 varmap,
455 num_actions,
456 num_params,
457 gamma: 0.95,
458 step_count: 0,
459 device: device.clone(),
460 })
461 }
462}
463
464impl RLAgent for DQNAgent {
466 fn select_action(&self, state: &[f32], epsilon: f32) -> Result<(usize, Vec<f32>)> {
467 let mut rng = rand::rng();
468
469 if rng.random::<f32>() < epsilon {
470 let discrete_action = rng.random_range(0..self.num_actions);
471 let params: Vec<f32> = (0..self.num_params)
472 .map(|_| rng.random_range(-1.0..1.0))
473 .collect();
474 Ok((discrete_action, params))
475 } else {
476 let state_tensor = Tensor::from_vec(state.to_vec(), &[1, state.len()], &self.device)
477 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
478
479 let (q_values, param_mean, _param_std) = self.online_network.forward(&state_tensor, false)
480 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
481
482 let q_vals = q_values.to_vec2::<f32>()
483 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
484 let discrete_action = q_vals[0].iter()
485 .enumerate()
486 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
487 .map(|(idx, _)| idx)
488 .unwrap_or(0);
489
490 let params = param_mean.to_vec2::<f32>()
491 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
492 let continuous_params = params[0].clone();
493
494 Ok((discrete_action, continuous_params))
495 }
496 }
497
498
499 fn train_step(&mut self, replay_buffer: &mut PrioritizedReplayBuffer, batch_size: usize) -> Result<f32> {
501 let batch = replay_buffer.sample(batch_size);
502
503 if batch.is_none() {
504 return Ok(0.0);
505 }
506
507 let SampledBatch { experiences, indices, weights } = batch.unwrap();
508
509 let states: Vec<Vec<f32>> = experiences.iter()
511 .map(|e| e.state.clone())
512 .collect();
513 let actions_discrete: Vec<usize> = experiences.iter()
514 .map(|e| e.action.0)
515 .collect();
516 let actions_params: Vec<Vec<f32>> = experiences.iter()
517 .map(|e| e.action.1.clone())
518 .collect();
519 let rewards: Vec<f32> = experiences.iter()
520 .map(|e| e.reward)
521 .collect();
522 let next_states: Vec<Vec<f32>> = experiences.iter()
523 .map(|e| e.next_state.clone())
524 .collect();
525 let dones: Vec<f32> = experiences.iter()
526 .map(|e| if e.done { 1.0 } else { 0.0 })
527 .collect();
528
529 let state_dim = states[0].len();
531 let states_flat: Vec<f32> = states.into_iter().flatten().collect();
532 let states_tensor = Tensor::from_vec(
533 states_flat,
534 &[batch_size, state_dim],
535 &self.device
536 ).map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
537
538 let next_states_flat: Vec<f32> = next_states.into_iter().flatten().collect();
539 let next_states_tensor = Tensor::from_vec(
540 next_states_flat,
541 &[batch_size, state_dim],
542 &self.device
543 ).map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
544
545 let rewards_tensor = Tensor::from_vec(
546 rewards,
547 &[batch_size],
548 &self.device
549 ).map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
550
551 let dones_tensor = Tensor::from_vec(
552 dones,
553 &[batch_size],
554 &self.device
555 ).map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
556
557 let weights_tensor = Tensor::from_vec(
558 weights,
559 &[batch_size],
560 &self.device
561 ).map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
562
563 let actions_discrete_tensor = Tensor::from_vec(
565 actions_discrete.iter().map(|&x| x as i64).collect::<Vec<_>>(),
566 &[batch_size],
567 &self.device
568 ).map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
569
570 let actions_params_flat: Vec<f32> = actions_params.into_iter().flatten().collect();
571 let actions_params_tensor = Tensor::from_vec(
572 actions_params_flat,
573 &[batch_size, self.num_params],
574 &self.device
575 ).map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
576
577 let (q_values, param_means, param_stds) = self.online_network.forward(&states_tensor, true)
579 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
580
581 let q_sample = q_values.get(0)?.to_vec1::<f32>()?;
583 if q_sample.iter().any(|&x| x.is_nan() || x.is_infinite()) {
584 return Err(crate::ExtractionError::ModelError(
585 "NaN/Inf detected in Q-values forward pass".to_string()
586 ));
587 }
588
589 let q_values_selected = q_values
591 .gather(&actions_discrete_tensor.unsqueeze(1)?, 1)?
592 .squeeze(1)?;
593
594 let (next_q_online, _, _) = self.online_network.forward(&next_states_tensor, false)
596 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
597
598 let next_actions = next_q_online.argmax(1)?;
599
600 let (next_q_target, _, _) = self.target_network.forward(&next_states_tensor, false)
601 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
602
603 let next_q_values = next_q_target
604 .gather(&next_actions.unsqueeze(1)?, 1)?
605 .squeeze(1)?;
606
607 let ones = Tensor::ones(&[batch_size], DType::F32, &self.device)
609 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
610
611 let gamma_vec = vec![self.gamma; batch_size];
613 let gamma_tensor = Tensor::from_vec(gamma_vec, &[batch_size], &self.device)?;
614
615 let discount_factors = (ones - dones_tensor)?
617 .mul(&gamma_tensor)?;
618
619 let td_targets = rewards_tensor
621 .add(&next_q_values.mul(&discount_factors)?)?;
622
623 let td_errors_tensor = (td_targets.clone() - q_values_selected.clone())?;
625 let td_errors: Vec<f32> = td_errors_tensor
626 .to_vec1()
627 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
628
629 let q_loss_elements = smooth_l1_loss(&q_values_selected, &td_targets)?;
631 let weighted_q_loss = (q_loss_elements * weights_tensor.clone())?;
632 let loss_q = weighted_q_loss.mean_all()?;
633
634 let param_loss = self.calculate_param_loss(¶m_means, ¶m_stds, &actions_params_tensor)?;
636
637 let loss_q_scalar = loss_q.to_scalar::<f32>()
639 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
640
641 let param_loss_scalar = param_loss.to_scalar::<f32>()
642 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
643
644 let total_loss_scalar = loss_q_scalar + 0.1 * param_loss_scalar;
645
646 let total_loss = Tensor::from_vec(
648 vec![total_loss_scalar],
649 &[1],
650 &self.device
651 )?;
652
653 if total_loss_scalar.is_nan() || total_loss_scalar.is_infinite() {
655 return Err(crate::ExtractionError::ModelError(
656 format!("Invalid loss: {}", total_loss_scalar)
657 ));
658 }
659
660 let mut grad_store = total_loss.backward()?;
663
664 let vars = self.varmap.all_vars();
667 let max_grad_norm = 1.0f32;
668 let mut total_norm_sq = 0.0f32;
669
670 for var in &vars {
672 if let Some(grad) = grad_store.get(var) {
673 let norm_sq = grad.sqr()?.sum_all()?.to_scalar::<f32>()?;
674 total_norm_sq += norm_sq;
675 }
676 }
677
678 let total_norm = total_norm_sq.sqrt();
679
680 if total_norm > max_grad_norm {
682 let clip_coef = max_grad_norm / (total_norm + 1e-6);
683
684 for var in self.varmap.all_vars() {
686 if let Some(grad) = grad_store.get(&var) {
687 let clip_coef_tensor = Tensor::from_vec(
689 vec![clip_coef],
690 &[1],
691 &self.device
692 )?;
693
694 let clipped_grad = grad.mul(&clip_coef_tensor)?;
696
697 grad_store.insert(&var, clipped_grad);
699 }
700 }
701
702 if self.step_count.is_multiple_of(1000) {
703 info!("Gradient norm: {:.4}, clipped with coef: {:.4}", total_norm, clip_coef);
704 }
705 }
706
707 self.optimizer.step(&grad_store)
709 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
710
711 replay_buffer.update_priorities(&indices, &td_errors);
713
714 self.step_count += 1;
715
716 Ok(total_loss_scalar)
718 }
719
720 fn update_target_network(&mut self) {
721 if let Err(e) = Self::copy_network_weights(&self.online_network, &mut self.target_network) {
722 warn!("Failed to update target network: {}", e);
723 } else {
724 info!("Target network updated (hard update)");
725 }
726 }
727
728 fn get_step_count(&self) -> usize {
729 self.step_count
730 }
731
732 fn save_with_metadata(
733 &self,
734 path: &Path,
735 training_episodes: usize,
736 hyperparameters: std::collections::HashMap<String, f64>,
737 ) -> Result<()> {
738 use crate::models::ModelMetadata;
739
740 let metadata = ModelMetadata::new(
741 300, self.num_actions,
743 self.num_params,
744 AlgorithmType::DuelingDQN,
745 training_episodes,
746 hyperparameters,
747 );
748
749 self.online_network.save_to_onnx_with_metadata(path, metadata)
750 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
751
752 let safetensors_path = path.with_extension("safetensors");
753 self.online_network.save_to_safetensors(&safetensors_path)
754 .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
755
756 tracing::info!("Model saved with metadata: ONNX ({} bytes), SafeTensors ({} bytes)",
757 std::fs::metadata(path).map(|m| m.len()).unwrap_or(0),
758 std::fs::metadata(&safetensors_path).map(|m| m.len()).unwrap_or(0));
759
760 Ok(())
761 }
762
763 fn save(&self, path: &Path) -> Result<()> {
764 self.save_with_metadata(path, 0, std::collections::HashMap::new())
766 }
767
768
769 fn algorithm_type(&self) -> AlgorithmType {
770 AlgorithmType::DuelingDQN
771 }
772
773 fn get_info(&self) -> AgentInfo {
774 AgentInfo {
775 algorithm: AlgorithmType::DuelingDQN,
776 num_parameters: 338525,
777 state_dim: 300,
778 num_actions: self.num_actions,
779 continuous_params: self.num_params,
780 version: "1.0.0".to_string(),
781 features: vec![
782 "dueling".to_string(),
783 "double_dqn".to_string(),
784 "prioritized_replay".to_string(),
785 ],
786 }
787 }
788}
789
790fn smooth_l1_loss(predicted: &Tensor, target: &Tensor) -> candle_core::error::Result<Tensor>
792{
793 let diff = predicted.sub(target)?;
794 let abs_diff = diff.abs()?;
795
796 let batch_size = predicted.dims()[0];
797 let threshold_vec = vec![1.0f32; batch_size];
798 let threshold = Tensor::from_vec(threshold_vec, &[batch_size], predicted.device())?;
799
800 let half_vec = vec![0.5f32; batch_size];
801 let half_tensor = Tensor::from_vec(half_vec, &[batch_size], predicted.device())?;
802
803 let small_loss = diff.sqr()?.mul(&half_tensor)?;
804 let large_loss = abs_diff.sub(&half_tensor)?;
805
806 abs_diff.lt(&threshold)?
807 .where_cond(&small_loss, &large_loss)
808}
809
#[cfg(test)]
mod tests {
    use super::*;
    use crate::replay_buffer::{PrioritizedReplayBuffer, Experience};
    use candle_core::Device;
    use candle_nn::VarBuilder;
    use candle_core::DType;
    use crate::Config;
    use crate::models::NetworkConfig;

    // Builds a NetworkConfig from the crate-level Config using the default
    // hidden-layer topology (512/256/128 with layer norm and 0.1 dropout).
    fn create_network_config(config: &Config) -> NetworkConfig {
        NetworkConfig {
            state_dim: config.state_dim,
            num_actions: config.num_discrete_actions,
            num_params: config.num_continuous_params,
            hidden_layers: vec![512, 256, 128],
            use_layer_norm: true,
            dropout: 0.1,
            value_hidden: 64,
            advantage_hidden: 64,
        }
    }

    // Smoke test: fills the buffer with identical transitions, runs a single
    // training step with batch 512, and asserts the returned loss is finite.
    // Exercises the full tensor-shape pipeline (gather/squeeze/broadcast).
    #[test]
    fn test_train_step_no_shape_mismatch() {
        let device = Device::Cpu;
        let varmap = VarMap::new();
        // NOTE(review): `vb` is unused here — `DQNAgent::new` builds its own
        // VarBuilder from the shared varmap.
        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device);
        let config = Config::default();
        let network_config = create_network_config(&config);

        let mut agent = DQNAgent::new(
            network_config,
            0.95,
            0.001,
            &device,
            varmap,
        ).unwrap();

        // Prioritized buffer with alpha = 0.6, beta = 0.4.
        let mut replay_buffer = PrioritizedReplayBuffer::new(10000, 0.6, 0.4);

        // 1000 identical transitions: enough to sample a 512-element batch.
        for _ in 0..1000 {
            let exp = Experience {
                state: vec![0.1; 300],
                action: (0, vec![0.0; 6]),
                reward: 1.0,
                next_state: vec![0.2; 300],
                done: false,
            };
            replay_buffer.add(exp);
        }

        let result = agent.train_step(&mut replay_buffer, 512);

        match result {
            Ok(loss) => {
                println!("Training step successful, loss: {}", loss);
                assert!(!loss.is_nan(), "Loss should not be NaN");
                assert!(!loss.is_infinite(), "Loss should not be infinite");
            }
            Err(e) => {
                panic!("Training step failed: {}", e);
            }
        }
    }
}