// burn_rl/transition_buffer/base.rs
use burn_core::{Tensor, prelude::Backend, tensor::Distribution};
use derive_new::new;

use super::SliceAccess;

6#[derive(Clone, new)]
8pub struct Transition<B: Backend, S, A> {
9 pub state: S,
11 pub next_state: S,
13 pub action: A,
15 pub reward: Tensor<B, 1>,
17 pub done: Tensor<B, 1>,
19}
20
21pub struct TransitionBatch<B: Backend, SB, AB> {
23 pub states: SB,
25 pub next_states: SB,
27 pub actions: AB,
29 pub rewards: Tensor<B, 2>,
31 pub dones: Tensor<B, 2>,
33}
34
35pub struct TransitionBuffer<B: Backend, SB: SliceAccess<B>, AB: SliceAccess<B>> {
41 states: Option<SB>,
42 next_states: Option<SB>,
43 actions: Option<AB>,
44 rewards: Option<Tensor<B, 2>>,
45 dones: Option<Tensor<B, 2>>,
46 capacity: usize,
47 write_head: usize,
48 len: usize,
49 device: B::Device,
50}
51
52impl<B: Backend, SB: SliceAccess<B>, AB: SliceAccess<B>> TransitionBuffer<B, SB, AB> {
53 pub fn new(capacity: usize, device: &B::Device) -> Self {
55 Self {
56 states: None,
57 next_states: None,
58 actions: None,
59 rewards: None,
60 dones: None,
61 capacity,
62 write_head: 0,
63 len: 0,
64 device: device.clone(),
65 }
66 }
67
68 fn ensure_init(&mut self, state: &SB, next_state: &SB, action: &AB) {
69 if self.states.is_none() {
70 self.states = Some(SB::zeros_like(state, self.capacity, &self.device));
71 self.next_states = Some(SB::zeros_like(next_state, self.capacity, &self.device));
72 self.actions = Some(AB::zeros_like(action, self.capacity, &self.device));
73 self.rewards = Some(Tensor::zeros([self.capacity, 1], &self.device));
74 self.dones = Some(Tensor::zeros([self.capacity, 1], &self.device));
75 }
76 }
77
78 pub fn push(&mut self, state: SB, next_state: SB, action: AB, reward: f32, done: bool) {
80 self.ensure_init(&state, &next_state, &action);
81
82 let idx = self.write_head % self.capacity;
83
84 self.states
85 .as_mut()
86 .unwrap()
87 .slice_assign_inplace(idx, state);
88 self.next_states
89 .as_mut()
90 .unwrap()
91 .slice_assign_inplace(idx, next_state);
92 self.actions
93 .as_mut()
94 .unwrap()
95 .slice_assign_inplace(idx, action);
96
97 let reward = Tensor::from_data([[reward]], &self.device);
98 self.rewards
99 .as_mut()
100 .unwrap()
101 .inplace(|r| r.slice_assign(idx..idx + 1, reward));
102
103 let done_val = if done { 1.0f32 } else { 0.0 };
104 let done = Tensor::from_data([[done_val]], &self.device);
105 self.dones
106 .as_mut()
107 .unwrap()
108 .inplace(|d| d.slice_assign(idx..idx + 1, done));
109
110 self.write_head += 1;
111 if self.len < self.capacity {
112 self.len += 1;
113 }
114 }
115
116 pub fn sample(&self, batch_size: usize) -> TransitionBatch<B, SB, AB> {
118 assert!(batch_size <= self.len, "batch_size exceeds buffer length");
119
120 let indices = Tensor::<B, 1>::random(
121 [batch_size],
122 Distribution::Uniform(0.0, self.len as f64),
123 &self.device,
124 )
125 .int();
126
127 TransitionBatch {
128 states: self
129 .states
130 .as_ref()
131 .unwrap()
132 .clone()
133 .select(0, indices.clone()),
134 next_states: self
135 .next_states
136 .as_ref()
137 .unwrap()
138 .clone()
139 .select(0, indices.clone()),
140 actions: self
141 .actions
142 .as_ref()
143 .unwrap()
144 .clone()
145 .select(0, indices.clone()),
146 rewards: self
147 .rewards
148 .as_ref()
149 .unwrap()
150 .clone()
151 .select(0, indices.clone()),
152 dones: self.dones.as_ref().unwrap().clone().select(0, indices),
153 }
154 }
155
156 pub fn len(&self) -> usize {
158 self.len
159 }
160
161 pub fn is_empty(&self) -> bool {
163 self.len == 0
164 }
165
166 pub fn capacity(&self) -> usize {
168 self.capacity
169 }
170}
171
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;

    type TB = Tensor<TestBackend, 2>;

    /// Pushes a transition whose fields are all derived from `val`:
    /// state `[val, val]`, next state `[val + 1, val + 1]`, action `[val]`,
    /// reward `val`, and `done = false`.
    fn push_transition(
        buffer: &mut TransitionBuffer<TestBackend, TB, TB>,
        device: &<TestBackend as Backend>::Device,
        val: f32,
    ) {
        let state = Tensor::<TestBackend, 2>::from_data([[val, val]], device);
        let next_state = Tensor::<TestBackend, 2>::from_data([[val + 1.0, val + 1.0]], device);
        let action = Tensor::<TestBackend, 2>::from_data([[val]], device);
        buffer.push(state, next_state, action, val, false);
    }

    #[test]
    fn push_increment_len() {
        let device = Default::default();
        let mut buffer = TransitionBuffer::<TestBackend, TB, TB>::new(5, &device);

        assert_eq!(buffer.len(), 0);
        assert!(buffer.is_empty());

        push_transition(&mut buffer, &device, 1.0);
        assert_eq!(buffer.len(), 1);

        push_transition(&mut buffer, &device, 2.0);
        assert_eq!(buffer.len(), 2);
    }

    #[test]
    fn push_overwrites_when_full() {
        let device = Default::default();
        let mut buffer = TransitionBuffer::<TestBackend, TB, TB>::new(3, &device);

        for i in 0..5 {
            push_transition(&mut buffer, &device, i as f32);
        }

        assert_eq!(buffer.len(), 3);
        assert_eq!(buffer.capacity(), 3);
    }

    #[test]
    fn sample_keeps_fields_aligned_and_excludes_overwritten() {
        let device = Default::default();
        let mut buffer = TransitionBuffer::<TestBackend, TB, TB>::new(3, &device);

        // Five pushes into three slots: the transitions for 0.0 and 1.0 are overwritten.
        for i in 0..5 {
            push_transition(&mut buffer, &device, i as f32);
        }

        for _ in 0..10 {
            let batch = buffer.sample(3);
            let rewards: Vec<f32> = batch.rewards.into_data().iter::<f32>().collect();
            let states: Vec<f32> = batch.states.into_data().iter::<f32>().collect();
            let next_states: Vec<f32> = batch.next_states.into_data().iter::<f32>().collect();
            let actions: Vec<f32> = batch.actions.into_data().iter::<f32>().collect();

            assert_eq!(rewards.len(), 3);
            for (row, &reward) in rewards.iter().enumerate() {
                assert!(reward >= 2.0, "overwritten transition {reward} was sampled");
                assert_eq!(states[2 * row], reward);
                assert_eq!(states[2 * row + 1], reward);
                assert_eq!(next_states[2 * row], reward + 1.0);
                assert_eq!(next_states[2 * row + 1], reward + 1.0);
                assert_eq!(actions[row], reward);
            }
            assert!(batch.dones.into_data().iter::<f32>().all(|d| d == 0.0));
        }
    }

    #[test]
    fn sample_returns_correct_shapes() {
        let device = Default::default();
        let mut buffer = TransitionBuffer::<TestBackend, TB, TB>::new(10, &device);

        for i in 0..5 {
            push_transition(&mut buffer, &device, i as f32);
        }

        let batch = buffer.sample(3);
        assert_eq!(batch.states.dims(), [3, 2]);
        assert_eq!(batch.next_states.dims(), [3, 2]);
        assert_eq!(batch.actions.dims(), [3, 1]);
        assert_eq!(batch.rewards.dims(), [3, 1]);
        assert_eq!(batch.dones.dims(), [3, 1]);
    }

    #[test]
    #[should_panic(expected = "batch_size exceeds buffer length")]
    fn sample_panics_when_batch_too_large() {
        let device = Default::default();
        let mut buffer = TransitionBuffer::<TestBackend, TB, TB>::new(5, &device);

        push_transition(&mut buffer, &device, 1.0);
        buffer.sample(5);
    }
}