rlkit/policies.rs

//! Policy module, containing implementations of various action selection policies.

use crate::types::{Action, QValue};
use candle_core::{Tensor, Result};
use rand::{Rng, rng};
use candle_nn::ops;

/// Policy configuration
#[deprecated(
    since = "0.0.3",
    note = "The enum PolicyConfig is deprecated, please instantiate the policy directly."
)]
#[derive(Debug, Default, Clone)]
pub enum PolicyConfig {
    EpsilonGreedy {
        epsilon_start: f32,
        epsilon_min: f32,
        epsilon_decay: f32,
    },
    Boltzmann {
        temperature_start: f32,
        temperature_min: f32,
        temperature_decay: f32,
    },
    OrnsteinUhlenbeck {
        mu: f32,
        theta: f32,
        sigma: f32,
        action_dim: usize,
    },
    GaussianNoise {
        mean: f32,
        std_dev: f32,
        decay_rate: f32,
    },
    #[default]
    DeterministicPolicy,
}

impl PolicyConfig {
    /// Return the default ε-greedy policy parameter configuration, commonly used in DQN.
    pub const fn dqn_epsilon_greedy() -> Self {
        Self::EpsilonGreedy {
            epsilon_start: 1.0,
            epsilon_min: 0.01,
            epsilon_decay: 0.995,
        }
    }

    /// Return the default Boltzmann policy parameter configuration, commonly used in DQN.
    pub const fn default_boltzmann() -> Self {
        Self::Boltzmann {
            temperature_start: 1.0,
            temperature_min: 0.1,
            temperature_decay: 0.99,
        }
    }

    /// Return the default Ornstein-Uhlenbeck process parameter configuration, commonly used in DDPG.
    pub const fn ddpg_ornstein_uhlenbeck(action_dim: usize) -> Self {
        Self::OrnsteinUhlenbeck {
            mu: 0.0,
            theta: 0.15,
            sigma: 0.2,
            action_dim,
        }
    }

    /// Return the default Gaussian noise policy parameter configuration, commonly used in DDPG.
    pub const fn default_gaussian_noise() -> Self {
        Self::GaussianNoise {
            mean: 0.0,
            std_dev: 0.2,
            decay_rate: 0.99,
        }
    }
}

impl PolicyConfig {
    /// Create a policy instance based on the configuration.
    pub fn create_policy<T>(&self, action_dim: usize) -> Result<Box<dyn Policy<T>>>
    where
        T: Copy + From<f32> + std::ops::Add<Output = T>
            + rand::distr::uniform::SampleUniform + Default + std::cmp::PartialOrd + std::fmt::Display,
    {
        match self {
            Self::EpsilonGreedy { epsilon_start, epsilon_min, epsilon_decay } => {
                Ok(Box::new(EpsilonGreedy::new(*epsilon_start, *epsilon_min, *epsilon_decay)))
            }
            Self::Boltzmann { temperature_start, temperature_min, temperature_decay } => {
                Ok(Box::new(Boltzmann::new(*temperature_start, *temperature_min, *temperature_decay)))
            }
            Self::OrnsteinUhlenbeck { mu, theta, sigma, action_dim: _ } => {
                // The `action_dim` argument passed to this method takes precedence over the
                // value stored in the configuration variant.
                Ok(Box::new(OrnsteinUhlenbeck::new(*mu, *theta, *sigma, action_dim)))
            }
            Self::GaussianNoise { mean, std_dev, decay_rate } => {
                Ok(Box::new(GaussianNoise::new(*mean, *std_dev, *decay_rate)))
            }
            Self::DeterministicPolicy => {
                Ok(Box::new(DeterministicPolicy))
            }
        }
    }
}
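
// A minimal usage sketch for the deprecated configuration path, assuming f32-valued
// actions. `example_policy_from_config` is a hypothetical helper, not part of the
// crate's API; new code should construct the concrete policy (e.g. `EpsilonGreedy::new`)
// directly, as the deprecation note suggests.
#[allow(deprecated, dead_code)]
fn example_policy_from_config() -> Result<Box<dyn Policy<f32>>> {
    let config = PolicyConfig::dqn_epsilon_greedy();
    // `action_dim` is only used by the Ornstein-Uhlenbeck variant.
    config.create_policy::<f32>(1)
}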

/// Policy interface, defining methods for action selection.
pub trait Policy<T = u16> {
    /// Select an action based on the network output.
    fn select_action(&mut self, q_values: &QValue<T>) -> Result<Action<T>>;

    /// Update the policy parameters (e.g., ε value).
    fn update(&mut self);

    /// Get a string representation of the current policy parameters.
    fn get_params(&self) -> String;
}
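
// A minimal sketch of how a `Policy` trait object might be driven once per environment
// step, assuming the caller has already turned its network output into a `QValue<T>`.
// `step_with_policy` is a hypothetical helper, not part of the crate's API.
#[allow(dead_code)]
fn step_with_policy<T>(policy: &mut dyn Policy<T>, q_values: &QValue<T>) -> Result<Action<T>> {
    // Choose a (possibly exploratory) action, then let the policy decay its parameters.
    let action = policy.select_action(q_values)?;
    policy.update();
    Ok(action)
}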

/// ε-Greedy policy, commonly used in DQN for exploration.
pub struct EpsilonGreedy {
    /// Current ε value
    pub epsilon: f32,
    /// Minimum ε value
    pub epsilon_min: f32,
    /// ε decay rate
    pub epsilon_decay: f32,
}

impl EpsilonGreedy {
    /// Create a new ε-Greedy policy.
    ///
    /// # Arguments
    /// * `epsilon_start` - Initial ε value
    /// * `epsilon_min` - Minimum ε value
    /// * `epsilon_decay` - ε decay rate
    pub fn new(epsilon_start: f32, epsilon_min: f32, epsilon_decay: f32) -> Self {
        Self {
            epsilon: epsilon_start,
            epsilon_min,
            epsilon_decay,
        }
    }
}

impl<T> Policy<T> for EpsilonGreedy
where
    T: Copy + rand::distr::uniform::SampleUniform + Default + std::cmp::PartialOrd,
{
    /// Select an action based on the ε-Greedy policy.
    ///
    /// # Arguments
    /// * `q_values` - Q-value distribution for the current state
    fn select_action(&mut self, q_values: &QValue<T>) -> Result<Action<T>> {
        let mut rng = rng();

        match q_values {
            QValue::Deterministic(action) => {
                if rng.random::<f32>() < self.epsilon {
                    Ok(action.random(&mut rng))
                } else {
                    Ok(action.clone())
                }
            },
            QValue::Stochastic(actions_with_values) => {
                if rng.random::<f32>() < self.epsilon {
                    // Pick one of the available actions uniformly at random.
                    let random_idx = rng.random_range(0..actions_with_values.len());
                    Ok(actions_with_values[random_idx].0.clone())
                } else {
                    // Otherwise act greedily with respect to the Q-values.
                    Ok(q_values.best_action().clone())
                }
            }
        }
    }

    /// Update the ε value according to the decay rate.
    fn update(&mut self) {
        // Decay ε until it reaches the configured minimum.
        if self.epsilon > self.epsilon_min {
            self.epsilon *= self.epsilon_decay;
        }
    }

    /// Get a string representation of the current ε value.
    fn get_params(&self) -> String {
        format!("ε={:.4}", self.epsilon)
    }
}
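
// A small sanity check (sketch): ε should shrink with each update and stop decaying
// once it falls to (or one step below) the configured minimum.
#[cfg(test)]
#[test]
fn epsilon_greedy_decays_epsilon() {
    let mut policy = EpsilonGreedy::new(1.0, 0.01, 0.5);
    Policy::<f32>::update(&mut policy);
    assert!(policy.epsilon < 1.0);
    for _ in 0..100 {
        Policy::<f32>::update(&mut policy);
    }
    // At most one decay step can land below the minimum before decay stops.
    assert!(policy.epsilon >= 0.01 * 0.5);
}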

/// Boltzmann policy, commonly used in DQN for exploration.
pub struct Boltzmann {
    /// Current temperature value
    pub temperature: f32,
    /// Minimum temperature value
    pub temperature_min: f32,
    /// Temperature decay rate
    pub temperature_decay: f32,
}

impl Boltzmann {
    /// Create a new Boltzmann policy.
    ///
    /// # Arguments
    /// * `temperature_start` - Initial temperature value
    /// * `temperature_min` - Minimum temperature value
    /// * `temperature_decay` - Temperature decay rate
    pub fn new(temperature_start: f32, temperature_min: f32, temperature_decay: f32) -> Self {
        Self {
            temperature: temperature_start,
            temperature_min,
            temperature_decay,
        }
    }
}

impl<T> Policy<T> for Boltzmann
where
    T: Copy,
{
    /// Select an action based on the Boltzmann policy.
    ///
    /// # Arguments
    /// * `q_values` - Q-value distribution for the current state
    fn select_action(&mut self, q_values: &QValue<T>) -> Result<Action<T>> {
        match q_values {
            QValue::Deterministic(action) => {
                // For a deterministic Q-value, return the corresponding action as-is.
                Ok(action.clone())
            },
            QValue::Stochastic(actions_with_values) => {
                // For stochastic Q-values, sample an action according to softmax probabilities.
                let mut rng = rng();

                // Extract the Q-value of every candidate action.
                let values: Vec<f32> = actions_with_values.iter()
                    .map(|(_, q_val)| *q_val)
                    .collect();

                // Compute softmax probabilities of the temperature-scaled Q-values
                // (the scalar temperature tensor is broadcast across the value vector).
                let values_tensor = Tensor::new(values.as_slice(), &candle_core::Device::Cpu)?;
                let temperature_tensor = Tensor::new(self.temperature, &candle_core::Device::Cpu)?;
                let scaled_values = values_tensor.broadcast_div(&temperature_tensor)?;
                let probabilities = ops::softmax(&scaled_values, 0)?;

                // Sample an action index from the probability distribution.
                let probabilities_vec = probabilities.to_vec1::<f32>()?;
                let sample = rng.random::<f32>();
                let mut cumulative = 0.0;

                for (i, &prob) in probabilities_vec.iter().enumerate() {
                    cumulative += prob;
                    if sample < cumulative {
                        // Return the selected action.
                        return Ok(actions_with_values[i].0.clone());
                    }
                }

                // Guard against floating-point rounding: fall back to the last action.
                Ok(actions_with_values.last().unwrap().0.clone())
            }
        }
    }

    /// Update the temperature value according to the decay rate.
    fn update(&mut self) {
        // Decay the temperature until it reaches the configured minimum.
        if self.temperature > self.temperature_min {
            self.temperature *= self.temperature_decay;
        }
    }

    /// Get a string representation of the current temperature value.
    fn get_params(&self) -> String {
        format!("T={:.4}", self.temperature)
    }
}
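
// A small sanity check (sketch): the temperature should decay toward its minimum,
// gradually shifting the softmax from near-uniform exploration toward greedy selection.
#[cfg(test)]
#[test]
fn boltzmann_decays_temperature() {
    let mut policy = Boltzmann::new(1.0, 0.1, 0.9);
    Policy::<f32>::update(&mut policy);
    assert!((policy.temperature - 0.9).abs() < 1e-6);
}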

/// Ornstein-Uhlenbeck process noise, commonly used in DDPG for exploration.
pub struct OrnsteinUhlenbeck {
    /// Mean value
    pub mu: f32,
    /// Theta parameter
    pub theta: f32,
    /// Sigma parameter
    pub sigma: f32,
    /// Action dimension
    pub action_dim: usize,
    /// Current state
    pub state: Option<Vec<f32>>,
}

impl OrnsteinUhlenbeck {
    /// Create a new Ornstein-Uhlenbeck process noise.
    ///
    /// # Arguments
    /// * `mu` - Mean value
    /// * `theta` - Theta parameter
    /// * `sigma` - Sigma parameter
    /// * `action_dim` - Action dimension
    pub fn new(mu: f32, theta: f32, sigma: f32, action_dim: usize) -> Self {
        Self {
            mu,
            theta,
            sigma,
            action_dim,
            state: None,
        }
    }

    fn sample(&mut self) -> Vec<f32> {
        let mut rng = rng();

        match &mut self.state {
            Some(state) => {
                // Mean-reverting update: dx = θ(μ - x) + σ·noise.
                for i in 0..self.action_dim {
                    let dx = self.theta * (self.mu - state[i]) + self.sigma * rng.random_range(-1.0..1.0);
                    state[i] += dx;
                }
                state.clone()
            },
            None => {
                // Initialise the state at the mean on the first call.
                let state = vec![self.mu; self.action_dim];
                self.state = Some(state.clone());
                state
            }
        }
    }
}

impl<T> Policy<T> for OrnsteinUhlenbeck
where
    T: Copy + From<f32> + std::ops::Add<Output = T>,
{
    /// Select an action based on the Ornstein-Uhlenbeck process noise.
    ///
    /// # Arguments
    /// * `q_values` - Q-value distribution for the current state
    fn select_action(&mut self, q_values: &QValue<T>) -> Result<Action<T>> {
        match q_values {
            QValue::Deterministic(action) => {
                // For a deterministic Q-value, add noise to the given action.
                let mut action_data = action.value.clone();

                // Add OU noise to each action dimension
                // (assumes `action_dim` matches the action length).
                let noise = self.sample();
                for i in 0..action_data.len() {
                    action_data[i] = action_data[i] + T::from(noise[i]);
                }

                // Return the noisy action with the same upper bounds.
                Ok(Action::new(action_data, action.uppers.clone()))
            },
            QValue::Stochastic(_actions_with_values) => {
                // For stochastic Q-values, take the best action and add noise to it.
                let best_action = q_values.best_action();
                let mut action_data = best_action.value.clone();

                // Add OU noise to each action dimension.
                let noise = self.sample();
                for i in 0..action_data.len() {
                    action_data[i] = action_data[i] + T::from(noise[i]);
                }

                // Return the noisy action with the same upper bounds.
                Ok(Action::new(action_data, best_action.uppers.clone()))
            }
        }
    }

    fn update(&mut self) {
        // The OU process keeps its own state; no parameter decay is needed here.
    }

    fn get_params(&self) -> String {
        format!("μ={:.4}, θ={:.4}, σ={:.4}", self.mu, self.theta, self.sigma)
    }
}
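
// A small sanity check (sketch): the first call initialises the internal state at `mu`
// for every action dimension, and subsequent samples keep the same dimensionality.
#[cfg(test)]
#[test]
fn ornstein_uhlenbeck_sample_shape() {
    let mut noise = OrnsteinUhlenbeck::new(0.0, 0.15, 0.2, 3);
    assert_eq!(noise.sample(), vec![0.0; 3]);
    assert_eq!(noise.sample().len(), 3);
}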

/// Gaussian noise policy for exploration in DDPG.
pub struct GaussianNoise {
    /// Mean value
    pub mean: f32,
    /// Standard deviation
    pub std_dev: f32,
    /// Decay rate for standard deviation
    pub decay_rate: f32,
}

impl GaussianNoise {
    /// Create a new Gaussian noise policy.
    ///
    /// # Arguments
    /// * `mean` - Mean value
    /// * `std_dev` - Standard deviation
    /// * `decay_rate` - Decay rate for the standard deviation
    pub fn new(mean: f32, std_dev: f32, decay_rate: f32) -> Self {
        Self {
            mean,
            std_dev,
            decay_rate,
        }
    }

    fn sample(&self, size: usize) -> Vec<f32> {
        let mut rng = rng();
        (0..size)
            .map(|_| {
                // Box-Muller transform: turn two uniform draws into a standard normal draw,
                // then scale and shift it by the configured standard deviation and mean.
                let u1 = rng.random::<f32>().max(f32::MIN_POSITIVE);
                let u2 = rng.random::<f32>();
                let z = (-2.0 * u1.ln()).sqrt() * (std::f32::consts::TAU * u2).cos();
                self.mean + self.std_dev * z
            })
            .collect()
    }
}

impl<T> Policy<T> for GaussianNoise
where
    T: Copy + From<f32> + std::ops::Add<Output = T> + std::fmt::Display,
{
    fn select_action(&mut self, q_values: &QValue<T>) -> Result<Action<T>> {
        match q_values {
            QValue::Deterministic(action) => {
                // For a deterministic Q-value, add noise to the given action.
                let mut action_data = action.value.clone();

                // Add Gaussian noise to each action dimension.
                let noise = self.sample(action_data.len());
                for i in 0..action_data.len() {
                    action_data[i] = action_data[i] + T::from(noise[i]);
                }

                // Return the noisy action with the same upper bounds.
                Ok(Action::new(action_data, action.uppers.clone()))
            },
            QValue::Stochastic(_actions_with_values) => {
                // For stochastic Q-values, take the best action and add noise to it.
                let best_action = q_values.best_action();
                let mut action_data = best_action.value.clone();

                // Add Gaussian noise to each action dimension.
                let noise = self.sample(action_data.len());
                for i in 0..action_data.len() {
                    action_data[i] = action_data[i] + T::from(noise[i]);
                }

                // Return the noisy action with the same upper bounds.
                Ok(Action::new(action_data, best_action.uppers.clone()))
            }
        }
    }

    fn update(&mut self) {
        // Decay the standard deviation.
        self.std_dev *= self.decay_rate;
    }

    fn get_params(&self) -> String {
        format!("μ={:.4}, σ={:.4}", self.mean, self.std_dev)
    }
}
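
// A small sanity check (sketch): each update shrinks the standard deviation, so the
// injected exploration noise fades over the course of training.
#[cfg(test)]
#[test]
fn gaussian_noise_decays_std_dev() {
    let mut policy = GaussianNoise::new(0.0, 0.2, 0.5);
    assert_eq!(policy.sample(4).len(), 4);
    Policy::<f32>::update(&mut policy);
    assert!((policy.std_dev - 0.1).abs() < 1e-6);
}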

/// Deterministic policy, directly using the network output as the action.
pub struct DeterministicPolicy;

impl DeterministicPolicy {
    pub fn new() -> Self {
        Self
    }
}

impl<T> Policy<T> for DeterministicPolicy
where
    T: Copy,
{
    fn select_action(&mut self, q_values: &QValue<T>) -> Result<Action<T>> {
        match q_values {
            QValue::Deterministic(action) => {
                // For a deterministic Q-value, return the corresponding action as-is.
                Ok(action.clone())
            },
            QValue::Stochastic(_actions_with_values) => {
                // For stochastic Q-values, return the best action.
                Ok(q_values.best_action().clone())
            }
        }
    }

    fn update(&mut self) {
        // A deterministic policy has no parameters to update.
    }

    fn get_params(&self) -> String {
        "Deterministic".to_string()
    }
}
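
// A small sanity check (sketch): the deterministic policy exposes no tunable
// parameters, so `update` is a no-op and `get_params` only reports its name.
#[cfg(test)]
#[test]
fn deterministic_policy_reports_no_parameters() {
    let mut policy = DeterministicPolicy::new();
    Policy::<f32>::update(&mut policy);
    assert_eq!(Policy::<f32>::get_params(&policy), "Deterministic");
}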