oxicuda_rl/policy/
deterministic.rs1use crate::error::{RlError, RlResult};
27use crate::handle::RlHandle;
28
29#[derive(Debug, Clone)]
36pub struct DeterministicPolicy {
37 action_dim: usize,
38 action_low: f32,
40 action_high: f32,
42}
43
44impl DeterministicPolicy {
45 #[must_use]
47 pub fn new(action_dim: usize) -> Self {
48 Self::with_bounds(action_dim, -1.0, 1.0)
49 }
50
51 #[must_use]
53 pub fn with_bounds(action_dim: usize, action_low: f32, action_high: f32) -> Self {
54 assert!(action_dim > 0, "action_dim must be > 0");
55 assert!(action_low < action_high, "action_low must be < action_high");
56 Self {
57 action_dim,
58 action_low,
59 action_high,
60 }
61 }
62
63 #[must_use]
65 #[inline]
66 pub fn action_dim(&self) -> usize {
67 self.action_dim
68 }
69
70 pub fn clip_action(&self, action: &[f32]) -> RlResult<Vec<f32>> {
76 if action.len() != self.action_dim {
77 return Err(RlError::DimensionMismatch {
78 expected: self.action_dim,
79 got: action.len(),
80 });
81 }
82 Ok(action
83 .iter()
84 .map(|&a| a.clamp(self.action_low, self.action_high))
85 .collect())
86 }
87
88 pub fn exploration_action(
96 &self,
97 action: &[f32],
98 sigma: f32,
99 handle: &mut RlHandle,
100 ) -> RlResult<Vec<f32>> {
101 if action.len() != self.action_dim {
102 return Err(RlError::DimensionMismatch {
103 expected: self.action_dim,
104 got: action.len(),
105 });
106 }
107 let rng = handle.rng_mut();
108 let noisy: Vec<f32> = action
109 .iter()
110 .map(|&a| {
111 let u1 = (rng.next_f32() + 1e-10).min(1.0 - 1e-10);
112 let u2 = rng.next_f32();
113 let noise = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f32::consts::PI * u2).cos();
114 (a + sigma * noise).clamp(self.action_low, self.action_high)
115 })
116 .collect();
117 Ok(noisy)
118 }
119
120 pub fn smooth_target_action(
130 &self,
131 action: &[f32],
132 sigma: f32,
133 clip_c: f32,
134 handle: &mut RlHandle,
135 ) -> RlResult<Vec<f32>> {
136 if action.len() != self.action_dim {
137 return Err(RlError::DimensionMismatch {
138 expected: self.action_dim,
139 got: action.len(),
140 });
141 }
142 let rng = handle.rng_mut();
143 let smoothed: Vec<f32> = action
144 .iter()
145 .map(|&a| {
146 let u1 = (rng.next_f32() + 1e-10).min(1.0 - 1e-10);
147 let u2 = rng.next_f32();
148 let noise_raw = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f32::consts::PI * u2).cos();
149 let noise = (sigma * noise_raw).clamp(-clip_c, clip_c);
150 (a + noise).clamp(self.action_low, self.action_high)
151 })
152 .collect();
153 Ok(smoothed)
154 }
155}
156
157#[derive(Debug, Clone)]
166pub struct OrnsteinUhlenbeck {
167 action_dim: usize,
168 mu: Vec<f32>,
170 theta: f32,
172 sigma: f32,
174 state: Vec<f32>,
176}
177
178impl OrnsteinUhlenbeck {
179 #[must_use]
184 pub fn new(action_dim: usize, theta: f32, sigma: f32) -> Self {
185 Self {
186 action_dim,
187 mu: vec![0.0; action_dim],
188 theta,
189 sigma,
190 state: vec![0.0; action_dim],
191 }
192 }
193
194 pub fn reset(&mut self) {
196 self.state.iter_mut().for_each(|x| *x = 0.0);
197 }
198
199 pub fn sample(&mut self, handle: &mut RlHandle) -> Vec<f32> {
201 let rng = handle.rng_mut();
202 let mut out = Vec::with_capacity(self.action_dim);
203 let mut k = 0;
204 while k < self.action_dim {
206 let u1 = (rng.next_f32() + 1e-10).min(1.0 - 1e-10);
207 let u2 = rng.next_f32();
208 let r = (-2.0 * u1.ln()).sqrt();
209 let theta = 2.0 * std::f32::consts::PI * u2;
210 out.push(r * theta.cos());
211 if k + 1 < self.action_dim {
212 out.push(r * theta.sin());
213 }
214 k += 2;
215 }
216 out.truncate(self.action_dim);
217
218 for (x, (&mu, &w)) in self.state.iter_mut().zip(self.mu.iter().zip(out.iter())) {
219 *x += self.theta * (mu - *x) + self.sigma * w;
220 }
221 self.state.clone()
222 }
223}
224
#[cfg(test)]
mod tests {
    use super::*;

    /// Components outside the default `[-1, 1]` bounds are clamped to them.
    #[test]
    fn clip_action_within_bounds() {
        let policy = DeterministicPolicy::new(3);
        let result = policy.clip_action(&[-2.0, 0.0, 2.0]).unwrap();
        assert_eq!(result, vec![-1.0, 0.0, 1.0]);
    }

    /// Noisy exploration actions never leave the default bounds.
    #[test]
    fn exploration_action_stays_in_bounds() {
        let policy = DeterministicPolicy::new(4);
        let mut handle = RlHandle::default_handle();
        let base = [0.0_f32; 4];
        for _ in 0..100 {
            let noisy = policy.exploration_action(&base, 0.3, &mut handle).unwrap();
            noisy.into_iter().for_each(|v| {
                assert!(
                    (-1.0..=1.0).contains(&v),
                    "exploration action out of bounds: {v}"
                );
            });
        }
    }

    /// Smoothed target actions never leave the default bounds.
    #[test]
    fn smooth_target_action_within_bounds() {
        let policy = DeterministicPolicy::new(2);
        let mut handle = RlHandle::default_handle();
        let base = [0.5_f32, -0.5];
        for _ in 0..100 {
            let smoothed = policy
                .smooth_target_action(&base, 0.2, 0.5, &mut handle)
                .unwrap();
            smoothed.into_iter().for_each(|v| {
                assert!(
                    (-1.0..=1.0).contains(&v),
                    "smoothed action out of bounds: {v}"
                );
            });
        }
    }

    /// A wrongly sized action vector is rejected.
    #[test]
    fn clip_action_dimension_error() {
        let policy = DeterministicPolicy::new(3);
        let outcome = policy.clip_action(&[0.0; 2]);
        assert!(outcome.is_err());
    }

    /// One OU sample has as many components as the process was built with.
    #[test]
    fn ou_sample_correct_dim() {
        let mut handle = RlHandle::default_handle();
        let mut process = OrnsteinUhlenbeck::new(4, 0.15, 0.2);
        let drawn = process.sample(&mut handle);
        assert_eq!(drawn.len(), 4);
    }

    /// `reset` returns the state to all zeros after it has been advanced.
    #[test]
    fn ou_reset_zeroes_state() {
        let mut handle = RlHandle::default_handle();
        let mut process = OrnsteinUhlenbeck::new(2, 0.15, 0.2);
        process.sample(&mut handle);
        process.reset();
        assert_eq!(process.state, vec![0.0, 0.0]);
    }

    /// Starting far from the mean, the state is pulled back towards zero.
    #[test]
    fn ou_mean_reversion() {
        let mut process = OrnsteinUhlenbeck::new(1, 1.0, 0.01);
        // Start far away from the zero mean.
        process.state = vec![100.0];
        let mut handle = RlHandle::default_handle();
        (0..50).for_each(|_| {
            process.sample(&mut handle);
        });
        let final_value = process.state[0];
        assert!(
            final_value.abs() < 5.0,
            "OU should mean-revert, state={}",
            final_value
        );
    }
}