// oxicuda_rl/policy/deterministic.rs
//! # Deterministic Policy (DDPG / TD3)
//!
//! Models a deterministic policy `μ_θ(s)` that maps states to actions directly.
//! Used in DDPG and TD3 where the critic provides the gradient signal.
//!
//! ## Target policy smoothing (TD3)
//!
//! TD3 adds clipped Gaussian noise to the target policy during Q-value
//! computation to prevent overfitting to sharp peaks in the Q-function:
//!
//! ```text
//! ã = clip(μ'(s') + clip(ε, -c, c), a_lo, a_hi)
//! where ε ~ N(0, σ²)
//! ```
//!
//! ## Exploration noise
//!
//! During training an **Ornstein-Uhlenbeck** process is commonly added to
//! action outputs for temporally correlated exploration:
//!
//! ```text
//! dx_t = θ(μ - x_t)dt + σ dW_t
//! ```
//! approximated as: `x_{t+1} = x_t + θ(μ - x_t) + σ N(0, 1)`.

use crate::error::{RlError, RlResult};
use crate::handle::RlHandle;

// ─── DeterministicPolicy ─────────────────────────────────────────────────────

/// Deterministic policy wrapper for DDPG/TD3.
///
/// Holds only the action bounds and exploration noise configuration; the
/// network parameters (`μ_θ(s)`) are managed externally.
#[derive(Debug, Clone)]
pub struct DeterministicPolicy {
    /// Number of action dimensions; action slices are validated against this.
    action_dim: usize,
    /// Action lower bound.
    action_low: f32,
    /// Action upper bound.
    action_high: f32,
}

44impl DeterministicPolicy {
45    /// Create a policy with symmetric action bounds `[-1, 1]`.
46    #[must_use]
47    pub fn new(action_dim: usize) -> Self {
48        Self::with_bounds(action_dim, -1.0, 1.0)
49    }
50
51    /// Create a policy with custom action bounds.
52    #[must_use]
53    pub fn with_bounds(action_dim: usize, action_low: f32, action_high: f32) -> Self {
54        assert!(action_dim > 0, "action_dim must be > 0");
55        assert!(action_low < action_high, "action_low must be < action_high");
56        Self {
57            action_dim,
58            action_low,
59            action_high,
60        }
61    }
62
63    /// Number of action dimensions.
64    #[must_use]
65    #[inline]
66    pub fn action_dim(&self) -> usize {
67        self.action_dim
68    }
69
70    /// Clip `action` to `[action_low, action_high]`.
71    ///
72    /// # Errors
73    ///
74    /// * [`RlError::DimensionMismatch`] if `action.len() != action_dim`.
75    pub fn clip_action(&self, action: &[f32]) -> RlResult<Vec<f32>> {
76        if action.len() != self.action_dim {
77            return Err(RlError::DimensionMismatch {
78                expected: self.action_dim,
79                got: action.len(),
80            });
81        }
82        Ok(action
83            .iter()
84            .map(|&a| a.clamp(self.action_low, self.action_high))
85            .collect())
86    }
87
88    /// Add Gaussian exploration noise and clip.
89    ///
90    /// Returns `clip(action + N(0, σ²), low, high)`.
91    ///
92    /// # Errors
93    ///
94    /// * [`RlError::DimensionMismatch`] if `action.len() != action_dim`.
95    pub fn exploration_action(
96        &self,
97        action: &[f32],
98        sigma: f32,
99        handle: &mut RlHandle,
100    ) -> RlResult<Vec<f32>> {
101        if action.len() != self.action_dim {
102            return Err(RlError::DimensionMismatch {
103                expected: self.action_dim,
104                got: action.len(),
105            });
106        }
107        let rng = handle.rng_mut();
108        let noisy: Vec<f32> = action
109            .iter()
110            .map(|&a| {
111                let u1 = (rng.next_f32() + 1e-10).min(1.0 - 1e-10);
112                let u2 = rng.next_f32();
113                let noise = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f32::consts::PI * u2).cos();
114                (a + sigma * noise).clamp(self.action_low, self.action_high)
115            })
116            .collect();
117        Ok(noisy)
118    }
119
120    /// TD3 target policy smoothing: add clipped noise to target actions.
121    ///
122    /// ```text
123    /// ã = clip(action + clip(N(0, σ²), -c, c), low, high)
124    /// ```
125    ///
126    /// # Errors
127    ///
128    /// * [`RlError::DimensionMismatch`] if `action.len() != action_dim`.
129    pub fn smooth_target_action(
130        &self,
131        action: &[f32],
132        sigma: f32,
133        clip_c: f32,
134        handle: &mut RlHandle,
135    ) -> RlResult<Vec<f32>> {
136        if action.len() != self.action_dim {
137            return Err(RlError::DimensionMismatch {
138                expected: self.action_dim,
139                got: action.len(),
140            });
141        }
142        let rng = handle.rng_mut();
143        let smoothed: Vec<f32> = action
144            .iter()
145            .map(|&a| {
146                let u1 = (rng.next_f32() + 1e-10).min(1.0 - 1e-10);
147                let u2 = rng.next_f32();
148                let noise_raw = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f32::consts::PI * u2).cos();
149                let noise = (sigma * noise_raw).clamp(-clip_c, clip_c);
150                (a + noise).clamp(self.action_low, self.action_high)
151            })
152            .collect();
153        Ok(smoothed)
154    }
155}

// ─── Ornstein-Uhlenbeck noise ────────────────────────────────────────────────

/// Ornstein-Uhlenbeck noise process for temporally correlated exploration.
///
/// Implements the discrete-time approximation:
/// ```text
/// x_{t+1} = x_t + θ(μ - x_t) + σ * N(0, 1)
/// ```
#[derive(Debug, Clone)]
pub struct OrnsteinUhlenbeck {
    /// Number of noise dimensions (matches the policy's action dimension).
    action_dim: usize,
    /// Mean-reversion level (default 0).
    mu: Vec<f32>,
    /// Mean-reversion speed θ.
    theta: f32,
    /// Noise scale σ.
    sigma: f32,
    /// Current state `x_t`; updated in place on every sample.
    state: Vec<f32>,
}

178impl OrnsteinUhlenbeck {
179    /// Create an OU process with zero mean.
180    ///
181    /// * `theta` — mean-reversion speed (typical: 0.15).
182    /// * `sigma` — noise scale (typical: 0.2).
183    #[must_use]
184    pub fn new(action_dim: usize, theta: f32, sigma: f32) -> Self {
185        Self {
186            action_dim,
187            mu: vec![0.0; action_dim],
188            theta,
189            sigma,
190            state: vec![0.0; action_dim],
191        }
192    }
193
194    /// Reset the OU process to its mean.
195    pub fn reset(&mut self) {
196        self.state.iter_mut().for_each(|x| *x = 0.0);
197    }
198
199    /// Advance the OU process by one step and return the noise vector.
200    pub fn sample(&mut self, handle: &mut RlHandle) -> Vec<f32> {
201        let rng = handle.rng_mut();
202        let mut out = Vec::with_capacity(self.action_dim);
203        let mut k = 0;
204        // Box-Muller pairs
205        while k < self.action_dim {
206            let u1 = (rng.next_f32() + 1e-10).min(1.0 - 1e-10);
207            let u2 = rng.next_f32();
208            let r = (-2.0 * u1.ln()).sqrt();
209            let theta = 2.0 * std::f32::consts::PI * u2;
210            out.push(r * theta.cos());
211            if k + 1 < self.action_dim {
212                out.push(r * theta.sin());
213            }
214            k += 2;
215        }
216        out.truncate(self.action_dim);
217
218        for (x, (&mu, &w)) in self.state.iter_mut().zip(self.mu.iter().zip(out.iter())) {
219            *x += self.theta * (mu - *x) + self.sigma * w;
220        }
221        self.state.clone()
222    }
223}
224
// ─── Tests ───────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;

    // ── DeterministicPolicy ──────────────────────────────────────────────────

    #[test]
    fn clip_action_within_bounds() {
        let policy = DeterministicPolicy::new(3);
        assert_eq!(
            policy.clip_action(&[-2.0, 0.0, 2.0]).unwrap(),
            vec![-1.0, 0.0, 1.0]
        );
    }

    #[test]
    fn exploration_action_stays_in_bounds() {
        let policy = DeterministicPolicy::new(4);
        let mut h = RlHandle::default_handle();
        for _ in 0..100 {
            for v in policy.exploration_action(&[0.0; 4], 0.3, &mut h).unwrap() {
                assert!(
                    (-1.0..=1.0).contains(&v),
                    "exploration action out of bounds: {v}"
                );
            }
        }
    }

    #[test]
    fn smooth_target_action_within_bounds() {
        let policy = DeterministicPolicy::new(2);
        let mut h = RlHandle::default_handle();
        for _ in 0..100 {
            for v in policy
                .smooth_target_action(&[0.5, -0.5], 0.2, 0.5, &mut h)
                .unwrap()
            {
                assert!(
                    (-1.0..=1.0).contains(&v),
                    "smoothed action out of bounds: {v}"
                );
            }
        }
    }

    #[test]
    fn clip_action_dimension_error() {
        assert!(DeterministicPolicy::new(3).clip_action(&[0.0; 2]).is_err());
    }

    // ── OrnsteinUhlenbeck ────────────────────────────────────────────────────

    #[test]
    fn ou_sample_correct_dim() {
        let mut process = OrnsteinUhlenbeck::new(4, 0.15, 0.2);
        let mut h = RlHandle::default_handle();
        assert_eq!(process.sample(&mut h).len(), 4);
    }

    #[test]
    fn ou_reset_zeroes_state() {
        let mut process = OrnsteinUhlenbeck::new(2, 0.15, 0.2);
        let mut h = RlHandle::default_handle();
        process.sample(&mut h);
        process.reset();
        assert_eq!(process.state, vec![0.0, 0.0]);
    }

    #[test]
    fn ou_mean_reversion() {
        // θ = 1.0 reverts fast and σ is tiny, so a state started far from
        // the zero mean should collapse toward 0 within a few dozen steps.
        let mut process = OrnsteinUhlenbeck::new(1, 1.0, 0.01);
        process.state = vec![100.0];
        let mut h = RlHandle::default_handle();
        for _ in 0..50 {
            process.sample(&mut h);
        }
        assert!(
            process.state[0].abs() < 5.0,
            "OU should mean-revert, state={}",
            process.state[0]
        );
    }
}