// ============================================================================
// FILE: crates/content-extractor-rl/src/agents/dqn_agent.rs
// ============================================================================

use candle_core::{Device, Tensor, DType};
use candle_nn::{VarBuilder, Optimizer, AdamW, ParamsAdamW, VarMap};
use crate::models::{DuelingDQN, NetworkConfig};
use crate::replay_buffer::{PrioritizedReplayBuffer, SampledBatch};
use crate::{Result, agents::{RLAgent, AlgorithmType, AgentInfo}};
use rand::Rng;
use tracing::{info, warn};
use std::path::Path;

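// Typical lifecycle (an illustrative sketch, not an additional API; the
// target-update cadence below is arbitrary):
//
//   let mut agent = DQNAgent::new(network_config, 0.95, 1e-3, &device, varmap)?;
//   let (action, params) = agent.select_action(&state, epsilon)?;
//   let loss = agent.train_step(&mut replay_buffer, batch_size)?;
//   if agent.get_step_count() % 1000 == 0 {
//       agent.update_target_network();
//   }
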
/// DQN agent for article extraction with a hybrid action space:
/// a discrete action head plus continuous parameter outputs.
pub struct DQNAgent {
    pub(crate) online_network: DuelingDQN,
    target_network: DuelingDQN,
    optimizer: AdamW,
    varmap: VarMap,
    num_actions: usize,
    num_params: usize,
    gamma: f32,
    step_count: usize,
    device: Device,
}

impl DQNAgent {
    /// Create a new DQN agent with a custom network configuration
    pub fn new(
        network_config: NetworkConfig,
        gamma: f32,
        lr: f64,
        device: &Device,
        varmap: VarMap,
    ) -> Result<Self> {
        let vb = VarBuilder::from_varmap(&varmap, DType::F32, device);
        let online_network = DuelingDQN::new(
            network_config.state_dim,
            network_config.num_actions,
            network_config.num_params,
            vb.pp("online")
        )?;

        // The target network lives in its own varmap, so the optimizer below
        // (built from `varmap.all_vars()`) only ever updates the online network.
        let target_varmap = VarMap::new();
        let target_vb = VarBuilder::from_varmap(&target_varmap, DType::F32, device);
        let mut target_network = DuelingDQN::new(
            network_config.state_dim,
            network_config.num_actions,
            network_config.num_params,
            target_vb.pp("target")
        )?;

        // Get trainable variables from the varmap
        let trainable_vars = varmap.all_vars();

        let params = ParamsAdamW {
            lr,
            beta1: 0.9,
            beta2: 0.999,
            eps: 1e-8,
            weight_decay: 1e-4,
        };

        let optimizer = AdamW::new(trainable_vars, params)
            .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;

        // Copy online network weights to target network initially
        target_network.copy_weights_from(&online_network)
            .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;

        Ok(Self {
            online_network,
            target_network,
            optimizer,
            varmap,
            num_actions: network_config.num_actions,
            num_params: network_config.num_params,
            gamma,
            step_count: 0,
            device: device.clone(),
        })
    }

    /// Copy weights from source network to target network
    fn copy_network_weights(source: &DuelingDQN, target: &mut DuelingDQN) -> Result<()> {
        target.copy_weights_from(source)
            .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))
    }

    /// Update the target network with a hard update (full weight copy).
    pub fn update_target_network(&mut self) {
        // For a soft update with tau, you would instead blend:
        // target = tau * online + (1 - tau) * target
        if let Err(e) = Self::copy_network_weights(&self.online_network, &mut self.target_network) {
            warn!("Failed to update target network: {}", e);
        } else {
            info!("Target network updated (hard update)");
        }
    }
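
    // A minimal Polyak (soft) update sketch for the blend mentioned above.
    // It is not wired into the struct: it assumes both networks' `VarMap`s are
    // at hand and that `all_vars()` enumerates their variables in matching
    // order, which this file does not currently guarantee.
    #[allow(dead_code)]
    fn soft_update_sketch(online: &VarMap, target: &VarMap, tau: f64) -> candle_core::error::Result<()> {
        for (o, t) in online.all_vars().iter().zip(target.all_vars().iter()) {
            // target = tau * online + (1 - tau) * target
            let blended = (o.as_tensor().affine(tau, 0.0)?
                + t.as_tensor().affine(1.0 - tau, 0.0)?)?;
            t.set(&blended)?;
        }
        Ok(())
    }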

    /// Get step count
    pub fn get_step_count(&self) -> usize {
        self.step_count
    }

    /// Select action using epsilon-greedy policy
    pub fn select_action(&self, state: &[f32], epsilon: f32) -> Result<(usize, Vec<f32>)> {
        let mut rng = rand::rng();

        if rng.random::<f32>() < epsilon {
            let discrete_action = rng.random_range(0..self.num_actions);
            let params: Vec<f32> = (0..self.num_params)
                .map(|_| rng.random_range(-1.0..1.0))
                .collect();
            Ok((discrete_action, params))
        } else {
            // Greedy action
            let state_tensor = Tensor::from_vec(state.to_vec(), &[1, state.len()], &self.device)
                .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;

            let (q_values, param_mean, _param_std) = self.online_network.forward(&state_tensor, false)
                .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;

            // Get discrete action
            let q_vals = q_values.to_vec2::<f32>()
                .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
            let discrete_action = q_vals[0].iter()
                .enumerate()
                .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
                .map(|(idx, _)| idx)
                .unwrap_or(0);

            // Get continuous params
            let params = param_mean.to_vec2::<f32>()
                .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;
            let continuous_params = params[0].clone();

            Ok((discrete_action, continuous_params))
        }
    }

    /// Complete training step with proper loss calculation
    pub fn train_step(&mut self, replay_buffer: &mut PrioritizedReplayBuffer, batch_size: usize) -> Result<f32> {
        let Some(SampledBatch { experiences, indices, weights }) = replay_buffer.sample(batch_size) else {
            return Ok(0.0);
        };

        // Extract components from experiences
        let states: Vec<Vec<f32>> = experiences.iter()
            .map(|e| e.state.clone())
            .collect();
        let actions_discrete: Vec<usize> = experiences.iter()
            .map(|e| e.action.0)
            .collect();
        let actions_params: Vec<Vec<f32>> = experiences.iter()
            .map(|e| e.action.1.clone())
            .collect();
        let rewards: Vec<f32> = experiences.iter()
            .map(|e| e.reward)
            .collect();
        let next_states: Vec<Vec<f32>> = experiences.iter()
            .map(|e| e.next_state.clone())
            .collect();
        let dones: Vec<f32> = experiences.iter()
            .map(|e| if e.done { 1.0 } else { 0.0 })
            .collect();

        // Convert to tensors
        let state_dim = states[0].len();
        let states_flat: Vec<f32> = states.into_iter().flatten().collect();
        let states_tensor = Tensor::from_vec(
            states_flat,
            &[batch_size, state_dim],
            &self.device
        ).map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;

        let next_states_flat: Vec<f32> = next_states.into_iter().flatten().collect();
        let next_states_tensor = Tensor::from_vec(
            next_states_flat,
            &[batch_size, state_dim],
            &self.device
        ).map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;

        let rewards_tensor = Tensor::from_vec(
            rewards,
            &[batch_size],
            &self.device
        ).map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;

        let dones_tensor = Tensor::from_vec(
            dones,
            &[batch_size],
            &self.device
        ).map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;

        let weights_tensor = Tensor::from_vec(
            weights,
            &[batch_size],
            &self.device
        ).map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;

        // Actions tensors
        let actions_discrete_tensor = Tensor::from_vec(
            actions_discrete.iter().map(|&x| x as i64).collect::<Vec<_>>(),
            &[batch_size],
            &self.device
        ).map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;

        let actions_params_flat: Vec<f32> = actions_params.into_iter().flatten().collect();
        let actions_params_tensor = Tensor::from_vec(
            actions_params_flat,
            &[batch_size, self.num_params],
            &self.device
        ).map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;

        // Forward pass through online network
        let (q_values, param_means, param_stds) = self.online_network.forward(&states_tensor, true)
            .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;

        // VALIDATION: spot-check the first batch row of Q-values for NaN/Inf
        let q_sample = q_values.get(0)?.to_vec1::<f32>()?;
        if q_sample.iter().any(|&x| x.is_nan() || x.is_infinite()) {
            return Err(crate::ExtractionError::ModelError(
                "NaN/Inf detected in Q-values forward pass".to_string()
            ));
        }

        // Gather Q-values for taken actions
        let q_values_selected = q_values
            .gather(&actions_discrete_tensor.unsqueeze(1)?, 1)?
            .squeeze(1)?;

        // Double DQN: use the online network to select actions, the target
        // network to evaluate them
        let (next_q_online, _, _) = self.online_network.forward(&next_states_tensor, false)
            .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;

        let next_actions = next_q_online.argmax(1)?;

        let (next_q_target, _, _) = self.target_network.forward(&next_states_tensor, false)
            .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;

        let next_q_values = next_q_target
            .gather(&next_actions.unsqueeze(1)?, 1)?
            .squeeze(1)?;

        // Calculate TD targets with proper shape broadcasting
        let ones = Tensor::ones(&[batch_size], DType::F32, &self.device)
            .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;

        // Create gamma tensor with same shape as batch
        let gamma_vec = vec![self.gamma; batch_size];
        let gamma_tensor = Tensor::from_vec(gamma_vec, &[batch_size], &self.device)?;

        // Calculate discount factors: gamma * (1 - done)
        let discount_factors = (ones - dones_tensor)?
            .mul(&gamma_tensor)?;

        // TD target: reward + gamma * (1 - done) * next_q
        let td_targets = rewards_tensor
            .add(&next_q_values.mul(&discount_factors)?)?;

        // Calculate TD errors for priority update
        let td_errors_tensor = (td_targets.clone() - q_values_selected.clone())?;
        let td_errors: Vec<f32> = td_errors_tensor
            .to_vec1()
            .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;

        // Q-value loss (Smooth L1 / Huber loss), importance-weighted
        let q_loss_elements = smooth_l1_loss(&q_values_selected, &td_targets)?;
        let weighted_q_loss = (q_loss_elements * weights_tensor.clone())?;
        let loss_q = weighted_q_loss.mean_all()?;

        // Parameter loss (negative log-likelihood of Gaussian)
        let param_loss = self.calculate_param_loss(&param_means, &param_stds, &actions_params_tensor)?;

        // Combine the losses in-graph so backward() can reach the network
        // parameters. Rebuilding the loss as a fresh tensor from its scalar
        // value would detach it from the autograd graph and yield empty
        // gradients.
        let total_loss = (loss_q + param_loss.affine(0.1, 0.0)?)?;

        let total_loss_scalar = total_loss.to_scalar::<f32>()
            .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;

        // VALIDATION: Check final loss
        if total_loss_scalar.is_nan() || total_loss_scalar.is_infinite() {
            return Err(crate::ExtractionError::ModelError(
                format!("Invalid loss: {}", total_loss_scalar)
            ));
        }

        // Perform backward pass to get the GradStore with all gradients
        let mut grad_store = total_loss.backward()?;

        // Clip gradients by global norm to prevent explosion
        let vars = self.varmap.all_vars();
        let max_grad_norm = 1.0f32;
        let mut total_norm_sq = 0.0f32;

        // Accumulate the squared gradient norm over all trainable variables
        for var in &vars {
            if let Some(grad) = grad_store.get(var) {
                let norm_sq = grad.sqr()?.sum_all()?.to_scalar::<f32>()?;
                total_norm_sq += norm_sq;
            }
        }

        let total_norm = total_norm_sq.sqrt();

        // Apply gradient clipping if needed
        if total_norm > max_grad_norm {
            let clip_coef = max_grad_norm / (total_norm + 1e-6);

            // Scale each gradient by the clip coefficient. A scalar affine
            // broadcasts over the gradient's full shape; a mul against a
            // [1]-shaped tensor would fail with a shape mismatch.
            for var in &vars {
                if let Some(grad) = grad_store.get(var) {
                    let clipped_grad = grad.affine(clip_coef as f64, 0.0)?;
                    grad_store.insert(var, clipped_grad);
                }
            }

            if self.step_count.is_multiple_of(1000) {
                info!("Gradient norm: {:.4}, clipped with coef: {:.4}", total_norm, clip_coef);
            }
        }

        // Step the optimizer with the (possibly clipped) gradients
        self.optimizer.step(&grad_store)
            .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;

        // Update priorities in replay buffer
        replay_buffer.update_priorities(&indices, &td_errors);

        self.step_count += 1;

        Ok(total_loss_scalar)
    }

    /// Calculate parameter loss as the negative log-likelihood of a Gaussian:
    /// NLL = 0.5 * ln(2 * pi) + ln(sigma) + (x - mu)^2 / (2 * sigma^2)
    fn calculate_param_loss(
        &self,
        means: &Tensor,
        stds: &Tensor,
        actions: &Tensor,
    ) -> candle_core::error::Result<Tensor> {
        let batch_size = actions.dims()[0];
        let num_params = actions.dims()[1];

        let diff = actions.sub(means)?;

        // Broadcast stds to match batch dimension
        let stds_broadcast = stds.unsqueeze(0)?.broadcast_as(means.shape())?;

        let variance = stds_broadcast.sqr()?;
        let squared_diff = diff.sqr()?.div(&variance)?;

        let log_std = stds_broadcast.log()?;

        // Constant term 0.5 * ln(2 * pi), materialized at batch shape
        let two_pi_vec = vec![2.0 * std::f32::consts::PI; batch_size * num_params];
        let two_pi = Tensor::from_vec(two_pi_vec, &[batch_size, num_params], &self.device)?;

        let half_vec = vec![0.5f32; batch_size * num_params];
        let half_tensor = Tensor::from_vec(half_vec, &[batch_size, num_params], &self.device)?;

        let constant = two_pi.log()?.mul(&half_tensor)?;

        let nll = constant
            .add(&log_std)?
            .add(&squared_diff.mul(&half_tensor)?)?;

        nll.mean_all()
    }

    /// Save model in both ONNX and SafeTensors formats
    pub fn save(&self, path: &Path) -> Result<()> {
        self.online_network.save_to_onnx(path)
            .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;

        let safetensors_path = path.with_extension("safetensors");
        self.online_network.save_to_safetensors(&safetensors_path)
            .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;

        tracing::info!("Model saved: ONNX ({} bytes), SafeTensors ({} bytes)",
                   std::fs::metadata(path).map(|m| m.len()).unwrap_or(0),
                   std::fs::metadata(&safetensors_path).map(|m| m.len()).unwrap_or(0));

        Ok(())
    }

    /// Load model from ONNX format
    pub fn load(
        path: &Path,
        state_dim: usize,
        num_actions: usize,
        num_params: usize,
    ) -> Result<Self> {
        let device = crate::device::get_device();
        Self::load_with_device(path, state_dim, num_actions, num_params, &device)
    }

    pub fn load_with_device(
        path: &Path,
        state_dim: usize,
        num_actions: usize,
        num_params: usize,
        device: &Device,
    ) -> Result<Self> {
        tracing::info!("Loading model on device: {}", crate::device::get_device_info(device));

        let online_network = DuelingDQN::load_from_onnx(path, state_dim, num_actions, num_params, device)
            .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;

        // Create target network on the SAME device
        let target_varmap = VarMap::new();
        let vb_target = VarBuilder::from_varmap(&target_varmap, DType::F32, device);
        let target_network = DuelingDQN::new(state_dim, num_actions, num_params, vb_target)
            .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;

        // NOTE: this fresh varmap holds no variables, so the optimizer below
        // has nothing to update. A loaded agent works for inference, but
        // resuming training would require rebuilding it around the loaded
        // weights.
        let varmap = VarMap::new();
        let vars = varmap.all_vars();
        let params = ParamsAdamW::default();
        let optimizer = AdamW::new(vars, params)
            .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;

        Ok(Self {
            online_network,
            target_network,
            optimizer,
            varmap,
            num_actions,
            num_params,
            gamma: 0.95,
            step_count: 0,
            device: device.clone(),
        })
    }
}

// Implement RLAgent trait for DQNAgent. Action selection, the training step,
// and the target-network update delegate to the inherent implementations
// above rather than duplicating their bodies.
impl RLAgent for DQNAgent {
    fn select_action(&self, state: &[f32], epsilon: f32) -> Result<(usize, Vec<f32>)> {
        DQNAgent::select_action(self, state, epsilon)
    }

    fn train_step(&mut self, replay_buffer: &mut PrioritizedReplayBuffer, batch_size: usize) -> Result<f32> {
        DQNAgent::train_step(self, replay_buffer, batch_size)
    }

    fn update_target_network(&mut self) {
        DQNAgent::update_target_network(self)
    }

    fn get_step_count(&self) -> usize {
        self.step_count
    }

    fn save_with_metadata(
        &self,
        path: &Path,
        training_episodes: usize,
        hyperparameters: std::collections::HashMap<String, f64>,
    ) -> Result<()> {
        use crate::models::ModelMetadata;

        let metadata = ModelMetadata::new(
            300,  // state_dim - should get from self
            self.num_actions,
            self.num_params,
            AlgorithmType::DuelingDQN,
            training_episodes,
            hyperparameters,
        );

        self.online_network.save_to_onnx_with_metadata(path, metadata)
            .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;

        let safetensors_path = path.with_extension("safetensors");
        self.online_network.save_to_safetensors(&safetensors_path)
            .map_err(|e| crate::ExtractionError::ModelError(e.to_string()))?;

        tracing::info!("Model saved with metadata: ONNX ({} bytes), SafeTensors ({} bytes)",
               std::fs::metadata(path).map(|m| m.len()).unwrap_or(0),
               std::fs::metadata(&safetensors_path).map(|m| m.len()).unwrap_or(0));

        Ok(())
    }

    fn save(&self, path: &Path) -> Result<()> {
        // Default save without extra metadata
        self.save_with_metadata(path, 0, std::collections::HashMap::new())
    }

    fn algorithm_type(&self) -> AlgorithmType {
        AlgorithmType::DuelingDQN
    }

    fn get_info(&self) -> AgentInfo {
        AgentInfo {
            algorithm: AlgorithmType::DuelingDQN,
            num_parameters: 338525,
            state_dim: 300,
            num_actions: self.num_actions,
            continuous_params: self.num_params,
            version: "1.0.0".to_string(),
            features: vec![
                "dueling".to_string(),
                "double_dqn".to_string(),
                "prioritized_replay".to_string(),
            ],
        }
    }
}

/// Smooth L1 loss (Huber loss with delta = 1), element-wise:
///   0.5 * d^2     if |d| < 1
///   |d| - 0.5     otherwise
/// e.g. for predicted = [0.0, 2.0] and target = [0.0, 0.0] the result
/// is [0.0, 1.5].
fn smooth_l1_loss(predicted: &Tensor, target: &Tensor) -> candle_core::error::Result<Tensor> {
    let diff = predicted.sub(target)?;
    let abs_diff = diff.abs()?;

    let batch_size = predicted.dims()[0];
    let threshold_vec = vec![1.0f32; batch_size];
    let threshold = Tensor::from_vec(threshold_vec, &[batch_size], predicted.device())?;

    let half_vec = vec![0.5f32; batch_size];
    let half_tensor = Tensor::from_vec(half_vec, &[batch_size], predicted.device())?;

    let small_loss = diff.sqr()?.mul(&half_tensor)?;
    let large_loss = abs_diff.sub(&half_tensor)?;

    abs_diff.lt(&threshold)?
        .where_cond(&small_loss, &large_loss)
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::replay_buffer::{PrioritizedReplayBuffer, Experience};
    use candle_core::Device;
    use crate::Config;
    use crate::models::NetworkConfig;

    fn create_network_config(config: &Config) -> NetworkConfig {
        NetworkConfig {
            state_dim: config.state_dim,
            num_actions: config.num_discrete_actions,
            num_params: config.num_continuous_params,
            hidden_layers: vec![512, 256, 128],
            use_layer_norm: true,
            dropout: 0.1,
            value_hidden: 64,
            advantage_hidden: 64,
        }
    }

    #[test]
    fn test_train_step_no_shape_mismatch() {
        let device = Device::Cpu;
        let varmap = VarMap::new();
        let config = Config::default();
        let network_config = create_network_config(&config);

        let mut agent = DQNAgent::new(
            network_config,
            0.95,
            0.001,
            &device,
            varmap,
        ).unwrap();

        let mut replay_buffer = PrioritizedReplayBuffer::new(10000, 0.6, 0.4);

        for _ in 0..1000 {
            let exp = Experience {
                state: vec![0.1; 300],
                action: (0, vec![0.0; 6]),
                reward: 1.0,
                next_state: vec![0.2; 300],
                done: false,
            };
            replay_buffer.add(exp);
        }

        let result = agent.train_step(&mut replay_buffer, 512);

        match result {
            Ok(loss) => {
                println!("Training step successful, loss: {}", loss);
                assert!(!loss.is_nan(), "Loss should not be NaN");
                assert!(!loss.is_infinite(), "Loss should not be infinite");
            }
            Err(e) => {
                panic!("Training step failed: {}", e);
            }
        }
    }
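
    // Additional sanity checks, added as sketches alongside the original
    // test; they exercise only APIs already used in this file.

    #[test]
    fn test_smooth_l1_loss_values() {
        let device = Device::Cpu;
        let predicted = Tensor::from_vec(vec![0.0f32, 2.0], &[2], &device).unwrap();
        let target = Tensor::from_vec(vec![0.0f32, 0.0], &[2], &device).unwrap();

        let loss = smooth_l1_loss(&predicted, &target).unwrap();
        let vals = loss.to_vec1::<f32>().unwrap();

        // |0.0| < 1 -> 0.5 * 0.0^2 = 0.0; |2.0| >= 1 -> 2.0 - 0.5 = 1.5
        assert!((vals[0] - 0.0).abs() < 1e-6);
        assert!((vals[1] - 1.5).abs() < 1e-6);
    }

    #[test]
    fn test_select_action_in_bounds() {
        let device = Device::Cpu;
        let varmap = VarMap::new();
        let config = Config::default();
        let network_config = create_network_config(&config);
        let num_actions = network_config.num_actions;
        let num_params = network_config.num_params;

        let agent = DQNAgent::new(network_config, 0.95, 0.001, &device, varmap).unwrap();
        let state = vec![0.1f32; config.state_dim];

        // epsilon = 1.0 always takes the random branch, epsilon = 0.0 the
        // greedy one; both must stay within the configured action space.
        for epsilon in [1.0f32, 0.0] {
            let (action, params) = agent.select_action(&state, epsilon).unwrap();
            assert!(action < num_actions);
            assert_eq!(params.len(), num_params);
        }
    }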
}