use crate::types::{Action, QValue};
use candle_core::{Result, Tensor};
use candle_nn::ops;
use rand::{rng, Rng};
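
/// Configuration enum for constructing an exploration [`Policy`].
///
/// A minimal usage sketch (hedged: assumes the caller wants a boxed
/// `Policy<f32>`; marked `ignore` because the enum is deprecated in favour
/// of instantiating policies directly):
///
/// ```ignore
/// #[allow(deprecated)]
/// let policy: Box<dyn Policy<f32>> =
///     PolicyConfig::dqn_epsilon_greedy().create_policy(4)?;
///
/// // Preferred, non-deprecated equivalent:
/// let policy: Box<dyn Policy<f32>> =
///     Box::new(EpsilonGreedy::new(1.0, 0.01, 0.995));
/// ```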
#[deprecated(
    since = "0.0.3",
    note = "PolicyConfig is deprecated; instantiate the desired policy directly instead."
)]
#[derive(Debug, Default, Clone)]
pub enum PolicyConfig {
    EpsilonGreedy {
        epsilon_start: f32,
        epsilon_min: f32,
        epsilon_decay: f32,
    },
    Boltzmann {
        temperature_start: f32,
        temperature_min: f32,
        temperature_decay: f32,
    },
    OrnsteinUhlenbeck {
        mu: f32,
        theta: f32,
        sigma: f32,
        action_dim: usize,
    },
    GaussianNoise {
        mean: f32,
        std_dev: f32,
        decay_rate: f32,
    },
    #[default]
    DeterministicPolicy,
}

impl PolicyConfig {
    /// ε-greedy defaults commonly used with DQN.
    pub const fn dqn_epsilon_greedy() -> Self {
        Self::EpsilonGreedy {
            epsilon_start: 1.0,
            epsilon_min: 0.01,
            epsilon_decay: 0.995,
        }
    }

    /// Default Boltzmann (softmax) exploration schedule.
    pub const fn default_boltzmann() -> Self {
        Self::Boltzmann {
            temperature_start: 1.0,
            temperature_min: 0.1,
            temperature_decay: 0.99,
        }
    }

    /// Ornstein-Uhlenbeck defaults commonly used with DDPG.
    pub const fn ddpg_ornstein_uhlenbeck(action_dim: usize) -> Self {
        Self::OrnsteinUhlenbeck {
            mu: 0.0,
            theta: 0.15,
            sigma: 0.2,
            action_dim,
        }
    }

    /// Default additive-noise schedule for continuous actions.
    pub const fn default_gaussian_noise() -> Self {
        Self::GaussianNoise {
            mean: 0.0,
            std_dev: 0.2,
            decay_rate: 0.99,
        }
    }
}

impl PolicyConfig {
    /// Build a boxed [`Policy`] from this configuration.
    pub fn create_policy<T>(&self, action_dim: usize) -> Result<Box<dyn Policy<T>>>
    where
        T: Copy
            + From<f32>
            + std::ops::Add<Output = T>
            + rand::distr::uniform::SampleUniform
            + Default
            + std::cmp::PartialOrd
            + std::fmt::Display,
    {
        match self {
            Self::EpsilonGreedy { epsilon_start, epsilon_min, epsilon_decay } => {
                Ok(Box::new(EpsilonGreedy::new(*epsilon_start, *epsilon_min, *epsilon_decay)))
            }
            Self::Boltzmann { temperature_start, temperature_min, temperature_decay } => {
                Ok(Box::new(Boltzmann::new(*temperature_start, *temperature_min, *temperature_decay)))
            }
            // Note: the `action_dim` stored in the variant is ignored here;
            // the `action_dim` argument to this method takes precedence.
            Self::OrnsteinUhlenbeck { mu, theta, sigma, action_dim: _ } => {
                Ok(Box::new(OrnsteinUhlenbeck::new(*mu, *theta, *sigma, action_dim)))
            }
            Self::GaussianNoise { mean, std_dev, decay_rate } => {
                Ok(Box::new(GaussianNoise::new(*mean, *std_dev, *decay_rate)))
            }
            Self::DeterministicPolicy => Ok(Box::new(DeterministicPolicy)),
        }
    }
}
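
/// Shared interface for exploration strategies.
///
/// A sketch of the expected call pattern inside an agent loop (hedged:
/// `q_network.forward(...)`, `env.reset()`, and `env.step(...)` are
/// hypothetical names, not part of this module):
///
/// ```ignore
/// let mut policy = EpsilonGreedy::new(1.0, 0.01, 0.995);
/// let mut state = env.reset()?;
/// loop {
///     let q_values: QValue<f32> = q_network.forward(&state)?;
///     let action = policy.select_action(&q_values)?;
///     state = env.step(&action)?;
///     policy.update(); // advance the exploration schedule
/// }
/// ```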
pub trait Policy<T = u16> {
    /// Select an action given the current Q-values.
    fn select_action(&mut self, q_values: &QValue<T>) -> Result<Action<T>>;

    /// Advance the exploration schedule (e.g. decay ε or temperature).
    fn update(&mut self);

    /// Human-readable summary of the current exploration parameters.
    fn get_params(&self) -> String;
}
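
/// ε-greedy exploration: with probability `epsilon` take a random action,
/// otherwise take the greedy (highest-Q) action. `epsilon` decays
/// multiplicatively towards `epsilon_min` on each [`Policy::update`].
///
/// A worked decay example (standalone arithmetic, no crate API assumed):
///
/// ```
/// let (mut epsilon, epsilon_min, decay) = (1.0f32, 0.01f32, 0.995f32);
/// for _ in 0..1000 {
///     if epsilon > epsilon_min {
///         epsilon = (epsilon * decay).max(epsilon_min);
///     }
/// }
/// // 0.995^n drops below 0.01 after roughly 919 steps, so epsilon has
/// // reached its floor by now.
/// assert_eq!(epsilon, epsilon_min);
/// ```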
pub struct EpsilonGreedy {
    /// Current exploration rate.
    pub epsilon: f32,
    /// Floor below which `epsilon` never decays.
    pub epsilon_min: f32,
    /// Multiplicative decay factor applied on each `update`.
    pub epsilon_decay: f32,
}

impl EpsilonGreedy {
    pub fn new(epsilon_start: f32, epsilon_min: f32, epsilon_decay: f32) -> Self {
        Self {
            epsilon: epsilon_start,
            epsilon_min,
            epsilon_decay,
        }
    }
}

impl<T> Policy<T> for EpsilonGreedy
where
    T: Copy + rand::distr::uniform::SampleUniform + Default + std::cmp::PartialOrd,
{
    fn select_action(&mut self, q_values: &QValue<T>) -> Result<Action<T>> {
        let mut rng = rng();

        match q_values {
            QValue::Deterministic(action) => {
                if rng.random::<f32>() < self.epsilon {
                    // Explore: perturb to a random action.
                    Ok(action.random(&mut rng))
                } else {
                    Ok(action.clone())
                }
            }
            QValue::Stochastic(actions_with_values) => {
                if rng.random::<f32>() < self.epsilon {
                    // Explore: pick an action uniformly at random.
                    let random_idx = rng.random_range(0..actions_with_values.len());
                    Ok(actions_with_values[random_idx].0.clone())
                } else {
                    // Exploit: pick the action with the highest Q-value.
                    Ok(q_values.best_action().clone())
                }
            }
        }
    }

    fn update(&mut self) {
        // Multiplicative decay, clamped so epsilon never undershoots the floor.
        if self.epsilon > self.epsilon_min {
            self.epsilon = (self.epsilon * self.epsilon_decay).max(self.epsilon_min);
        }
    }

    fn get_params(&self) -> String {
        format!("ε={:.4}", self.epsilon)
    }
}
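
/// Boltzmann (softmax) exploration: samples actions with probability
/// proportional to `exp(Q / temperature)`. High temperatures approach
/// uniform random choice; low temperatures approach greedy choice.
///
/// A worked example of the temperature scaling (standalone arithmetic,
/// no crate API assumed):
///
/// ```
/// let q = [1.0f32, 2.0, 3.0];
/// let temperature = 0.5f32;
/// let exps: Vec<f32> = q.iter().map(|v| (v / temperature).exp()).collect();
/// let total: f32 = exps.iter().sum();
/// let probs: Vec<f32> = exps.iter().map(|e| e / total).collect();
/// // T = 0.5 sharpens the distribution relative to plain softmax:
/// // probs ≈ [0.016, 0.117, 0.867] vs softmax(q) ≈ [0.090, 0.245, 0.665].
/// assert!((probs.iter().sum::<f32>() - 1.0).abs() < 1e-6);
/// assert!(probs[2] > 0.8);
/// ```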
pub struct Boltzmann {
    /// Current softmax temperature.
    pub temperature: f32,
    /// Floor below which the temperature never decays.
    pub temperature_min: f32,
    /// Multiplicative decay factor applied on each `update`.
    pub temperature_decay: f32,
}

impl Boltzmann {
    pub fn new(temperature_start: f32, temperature_min: f32, temperature_decay: f32) -> Self {
        Self {
            temperature: temperature_start,
            temperature_min,
            temperature_decay,
        }
    }
}

impl<T> Policy<T> for Boltzmann
where
    T: Copy,
{
    fn select_action(&mut self, q_values: &QValue<T>) -> Result<Action<T>> {
        match q_values {
            QValue::Deterministic(action) => {
                // A single deterministic action carries no distribution to
                // sample from; return it unchanged.
                Ok(action.clone())
            }
            QValue::Stochastic(actions_with_values) => {
                let mut rng = rng();

                let values: Vec<f32> = actions_with_values
                    .iter()
                    .map(|(_, q_val)| *q_val)
                    .collect();

                // Boltzmann distribution: softmax(Q / temperature). The
                // temperature is a rank-0 tensor, so `broadcast_div` is
                // required (`div` insists on matching shapes).
                let values_tensor = Tensor::new(values.as_slice(), &candle_core::Device::Cpu)?;
                let temperature_tensor = Tensor::new(self.temperature, &candle_core::Device::Cpu)?;
                let scaled_values = values_tensor.broadcast_div(&temperature_tensor)?;
                let probabilities = ops::softmax(&scaled_values, 0)?;

                // Inverse-CDF sampling over the categorical distribution.
                let probabilities_vec = probabilities.to_vec1::<f32>()?;
                let sample = rng.random::<f32>();
                let mut cumulative = 0.0;

                for (i, &prob) in probabilities_vec.iter().enumerate() {
                    cumulative += prob;
                    if sample < cumulative {
                        return Ok(actions_with_values[i].0.clone());
                    }
                }

                // Floating-point round-off can leave `cumulative` just short
                // of 1.0; fall back to the last action.
                Ok(actions_with_values.last().unwrap().0.clone())
            }
        }
    }

    fn update(&mut self) {
        // Multiplicative decay, clamped at the temperature floor.
        if self.temperature > self.temperature_min {
            self.temperature = (self.temperature * self.temperature_decay).max(self.temperature_min);
        }
    }

    fn get_params(&self) -> String {
        format!("T={:.4}", self.temperature)
    }
}
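
/// Ornstein-Uhlenbeck noise for continuous control (as popularised by
/// DDPG): temporally correlated noise that drifts back towards the mean
/// `mu` at rate `theta`, with volatility `sigma`:
///
/// `x_{t+1} = x_t + theta * (mu - x_t) + sigma * n_t`
///
/// A standalone sketch of the scalar update (note that this module, like
/// the sketch, draws `n_t` uniformly from [-1, 1) rather than from a true
/// Gaussian):
///
/// ```
/// use rand::{rng, Rng};
///
/// let (mu, theta, sigma) = (0.0f32, 0.15f32, 0.2f32);
/// let mut x = mu;
/// let mut rng = rng();
/// for _ in 0..100 {
///     x += theta * (mu - x) + sigma * rng.random_range(-1.0..1.0);
/// }
/// // Mean reversion hard-bounds the excursion: |x| <= sigma / theta ≈ 1.33.
/// assert!(x.abs() <= sigma / theta);
/// ```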
pub struct OrnsteinUhlenbeck {
    /// Long-run mean the noise reverts to.
    pub mu: f32,
    /// Mean-reversion rate.
    pub theta: f32,
    /// Volatility of the stochastic term.
    pub sigma: f32,
    /// Dimensionality of the action (and of the noise state).
    pub action_dim: usize,
    /// Current noise state; lazily initialised to `mu` on first use.
    pub state: Option<Vec<f32>>,
}

impl OrnsteinUhlenbeck {
    pub fn new(mu: f32, theta: f32, sigma: f32, action_dim: usize) -> Self {
        Self {
            mu,
            theta,
            sigma,
            action_dim,
            state: None,
        }
    }

    /// Advance the OU process one step and return the current noise state.
    fn sample(&mut self) -> Vec<f32> {
        let mut rng = rng();

        match &mut self.state {
            Some(state) => {
                for value in state.iter_mut() {
                    // Discrete OU update with dt = 1. The stochastic term
                    // uses uniform noise in [-1, 1) as a cheap stand-in for
                    // the Gaussian increment of the textbook OU process.
                    let dx = self.theta * (self.mu - *value)
                        + self.sigma * rng.random_range(-1.0..1.0);
                    *value += dx;
                }
                state.clone()
            }
            None => {
                // First call: start the process at its mean.
                let state = vec![self.mu; self.action_dim];
                self.state = Some(state.clone());
                state
            }
        }
    }
}

impl<T> Policy<T> for OrnsteinUhlenbeck
where
    T: Copy + From<f32> + std::ops::Add<Output = T>,
{
    fn select_action(&mut self, q_values: &QValue<T>) -> Result<Action<T>> {
        // Both variants reduce to "take the greedy action and perturb it
        // with temporally correlated OU noise".
        let base_action = match q_values {
            QValue::Deterministic(action) => action,
            QValue::Stochastic(_) => q_values.best_action(),
        };

        let mut action_data = base_action.value.clone();
        let noise = self.sample();
        for i in 0..action_data.len() {
            action_data[i] = action_data[i] + T::from(noise[i]);
        }

        Ok(Action::new(action_data, base_action.uppers.clone()))
    }

    fn update(&mut self) {
        // The OU process has no decay schedule; nothing to update.
    }

    fn get_params(&self) -> String {
        format!("μ={:.4}, θ={:.4}, σ={:.4}", self.mu, self.theta, self.sigma)
    }
}
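
/// Additive noise exploration for continuous actions: perturbs the greedy
/// action with noise centred on `mean` and scaled by `std_dev`, which
/// decays by `decay_rate` on each [`Policy::update`].
///
/// A minimal usage sketch (hedged: `uppers` stands for whatever action
/// bounds `Action::new` expects; marked `ignore` for that reason):
///
/// ```ignore
/// let mut policy = GaussianNoise::new(0.0, 0.2, 0.99);
/// let proposed = QValue::Deterministic(Action::new(vec![0.5f32, -0.3], uppers));
/// let noisy = policy.select_action(&proposed)?;
/// policy.update(); // std_dev: 0.2 -> 0.198
/// ```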
pub struct GaussianNoise {
    /// Centre of the noise distribution.
    pub mean: f32,
    /// Current noise scale.
    pub std_dev: f32,
    /// Multiplicative decay applied to `std_dev` on each `update`.
    pub decay_rate: f32,
}

impl GaussianNoise {
    pub fn new(mean: f32, std_dev: f32, decay_rate: f32) -> Self {
        Self {
            mean,
            std_dev,
            decay_rate,
        }
    }

    /// Draw a noise vector. Note: despite the struct's name this uses
    /// uniform noise in [-1, 1) scaled by `std_dev`; a true Gaussian would
    /// need e.g. `rand_distr::Normal`.
    fn sample(&self, size: usize) -> Vec<f32> {
        let mut rng = rng();
        (0..size)
            .map(|_| rng.random_range(-1.0..1.0) * self.std_dev + self.mean)
            .collect()
    }
}

impl<T> Policy<T> for GaussianNoise
where
    T: Copy + From<f32> + std::ops::Add<Output = T> + std::fmt::Display,
{
    fn select_action(&mut self, q_values: &QValue<T>) -> Result<Action<T>> {
        // Both variants reduce to "take the greedy action and perturb it
        // with independent additive noise".
        let base_action = match q_values {
            QValue::Deterministic(action) => action,
            QValue::Stochastic(_) => q_values.best_action(),
        };

        let mut action_data = base_action.value.clone();
        let noise = self.sample(action_data.len());
        for i in 0..action_data.len() {
            action_data[i] = action_data[i] + T::from(noise[i]);
        }

        Ok(Action::new(action_data, base_action.uppers.clone()))
    }

    fn update(&mut self) {
        // Shrink the noise scale; there is no floor, so std_dev decays
        // towards zero.
        self.std_dev *= self.decay_rate;
    }

    fn get_params(&self) -> String {
        format!("μ={:.4}, σ={:.4}", self.mean, self.std_dev)
    }
}
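
/// No-exploration baseline: always returns the greedy action unchanged.
/// Useful for evaluation runs where the learned policy should be followed
/// exactly.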
#[derive(Default)]
pub struct DeterministicPolicy;

impl DeterministicPolicy {
    pub fn new() -> Self {
        Self
    }
}

impl<T> Policy<T> for DeterministicPolicy
where
    T: Copy,
{
    fn select_action(&mut self, q_values: &QValue<T>) -> Result<Action<T>> {
        match q_values {
            QValue::Deterministic(action) => Ok(action.clone()),
            QValue::Stochastic(_) => Ok(q_values.best_action().clone()),
        }
    }

    fn update(&mut self) {
        // Nothing to anneal for a purely greedy policy.
    }

    fn get_params(&self) -> String {
        "Deterministic".to_string()
    }
}