Skip to main content

burn_rl/transition_buffer/
base.rs

1use burn_core::{Tensor, prelude::Backend, tensor::Distribution};
2use derive_new::new;
3
4use super::SliceAccess;
5
6/// A state transition in an environment.
7#[derive(Clone, new)]
8pub struct Transition<B: Backend, S, A> {
9    /// The initial state.
10    pub state: S,
11    /// The state after the step was taken.
12    pub next_state: S,
13    /// The action taken in the step.
14    pub action: A,
15    /// The reward.
16    pub reward: Tensor<B, 1>,
17    /// If the environment has reached a terminal state.
18    pub done: Tensor<B, 1>,
19}
20
21/// A batch of transitions.
22pub struct TransitionBatch<B: Backend, SB, AB> {
23    /// Batched initial states.
24    pub states: SB,
25    /// Batched resulting states.
26    pub next_states: SB,
27    /// Batched actions.
28    pub actions: AB,
29    /// Batched rewards.
30    pub rewards: Tensor<B, 2>,
31    /// Batched flags for terminal states.
32    pub dones: Tensor<B, 2>,
33}
34
35/// A tensor-backed circular buffer for transitions.
36///
37/// Uses [`SliceAccess`] to store state and action batches in contiguous
38/// tensor storage, enabling efficient random sampling via `select`.
39/// The buffer lazily initializes its storage on the first `push` call.
40pub struct TransitionBuffer<B: Backend, SB: SliceAccess<B>, AB: SliceAccess<B>> {
41    states: Option<SB>,
42    next_states: Option<SB>,
43    actions: Option<AB>,
44    rewards: Option<Tensor<B, 2>>,
45    dones: Option<Tensor<B, 2>>,
46    capacity: usize,
47    write_head: usize,
48    len: usize,
49    device: B::Device,
50}
51
52impl<B: Backend, SB: SliceAccess<B>, AB: SliceAccess<B>> TransitionBuffer<B, SB, AB> {
53    /// Creates a new buffer. Storage is lazily allocated on the first `push`.
54    pub fn new(capacity: usize, device: &B::Device) -> Self {
55        Self {
56            states: None,
57            next_states: None,
58            actions: None,
59            rewards: None,
60            dones: None,
61            capacity,
62            write_head: 0,
63            len: 0,
64            device: device.clone(),
65        }
66    }
67
68    fn ensure_init(&mut self, state: &SB, next_state: &SB, action: &AB) {
69        if self.states.is_none() {
70            self.states = Some(SB::zeros_like(state, self.capacity, &self.device));
71            self.next_states = Some(SB::zeros_like(next_state, self.capacity, &self.device));
72            self.actions = Some(AB::zeros_like(action, self.capacity, &self.device));
73            self.rewards = Some(Tensor::zeros([self.capacity, 1], &self.device));
74            self.dones = Some(Tensor::zeros([self.capacity, 1], &self.device));
75        }
76    }
77
78    /// Add a transition, overwriting the oldest if full.
79    pub fn push(&mut self, state: SB, next_state: SB, action: AB, reward: f32, done: bool) {
80        self.ensure_init(&state, &next_state, &action);
81
82        let idx = self.write_head % self.capacity;
83
84        self.states
85            .as_mut()
86            .unwrap()
87            .slice_assign_inplace(idx, state);
88        self.next_states
89            .as_mut()
90            .unwrap()
91            .slice_assign_inplace(idx, next_state);
92        self.actions
93            .as_mut()
94            .unwrap()
95            .slice_assign_inplace(idx, action);
96
97        let reward = Tensor::from_data([[reward]], &self.device);
98        self.rewards
99            .as_mut()
100            .unwrap()
101            .inplace(|r| r.slice_assign(idx..idx + 1, reward));
102
103        let done_val = if done { 1.0f32 } else { 0.0 };
104        let done = Tensor::from_data([[done_val]], &self.device);
105        self.dones
106            .as_mut()
107            .unwrap()
108            .inplace(|d| d.slice_assign(idx..idx + 1, done));
109
110        self.write_head += 1;
111        if self.len < self.capacity {
112            self.len += 1;
113        }
114    }
115
116    /// Sample a random batch of transitions.
117    pub fn sample(&self, batch_size: usize) -> TransitionBatch<B, SB, AB> {
118        assert!(batch_size <= self.len, "batch_size exceeds buffer length");
119
120        let indices = Tensor::<B, 1>::random(
121            [batch_size],
122            Distribution::Uniform(0.0, self.len as f64),
123            &self.device,
124        )
125        .int();
126
127        TransitionBatch {
128            states: self
129                .states
130                .as_ref()
131                .unwrap()
132                .clone()
133                .select(0, indices.clone()),
134            next_states: self
135                .next_states
136                .as_ref()
137                .unwrap()
138                .clone()
139                .select(0, indices.clone()),
140            actions: self
141                .actions
142                .as_ref()
143                .unwrap()
144                .clone()
145                .select(0, indices.clone()),
146            rewards: self
147                .rewards
148                .as_ref()
149                .unwrap()
150                .clone()
151                .select(0, indices.clone()),
152            dones: self.dones.as_ref().unwrap().clone().select(0, indices),
153        }
154    }
155
156    /// Current number of stored transitions.
157    pub fn len(&self) -> usize {
158        self.len
159    }
160
161    /// Whether the buffer is empty.
162    pub fn is_empty(&self) -> bool {
163        self.len == 0
164    }
165
166    /// Buffer capacity.
167    pub fn capacity(&self) -> usize {
168        self.capacity
169    }
170}
171
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;

    type TB = Tensor<TestBackend, 2>;

    /// Pushes a transition whose fields are all derived from `val`:
    /// state `[val, val]`, next state `[val + 1, val + 1]`, action `[val]`,
    /// reward `val`, not done.
    fn push_transition(
        buffer: &mut TransitionBuffer<TestBackend, TB, TB>,
        device: &<TestBackend as Backend>::Device,
        val: f32,
    ) {
        let state = Tensor::<TestBackend, 2>::from_data([[val, val]], device);
        let next_state = Tensor::<TestBackend, 2>::from_data([[val + 1.0, val + 1.0]], device);
        let action = Tensor::<TestBackend, 2>::from_data([[val]], device);
        buffer.push(state, next_state, action, val, false);
    }

    /// Reads a tensor back to the host as row-major `f32` values.
    fn host_values<const D: usize>(tensor: Tensor<TestBackend, D>) -> Vec<f32> {
        tensor.into_data().iter::<f32>().collect()
    }

    #[test]
    fn push_increment_len() {
        let device = Default::default();
        let mut buffer = TransitionBuffer::<TestBackend, TB, TB>::new(5, &device);

        assert_eq!(buffer.len(), 0);
        assert!(buffer.is_empty());

        push_transition(&mut buffer, &device, 1.0);
        assert_eq!(buffer.len(), 1);

        push_transition(&mut buffer, &device, 2.0);
        assert_eq!(buffer.len(), 2);
    }

    #[test]
    fn push_overwrites_when_full() {
        let device = Default::default();
        let mut buffer = TransitionBuffer::<TestBackend, TB, TB>::new(3, &device);

        for i in 0..5 {
            push_transition(&mut buffer, &device, i as f32);
        }

        assert_eq!(buffer.len(), 3);
        assert_eq!(buffer.capacity(), 3);

        // Only the three most recent transitions (vals 2, 3, 4) may survive.
        let batch = buffer.sample(3);
        for reward in host_values(batch.rewards) {
            assert!(
                (2.0..=4.0).contains(&reward),
                "overwritten transition with reward {reward} was sampled"
            );
        }
    }

    #[test]
    fn sample_returns_correct_shapes() {
        let device = Default::default();
        let mut buffer = TransitionBuffer::<TestBackend, TB, TB>::new(10, &device);

        for i in 0..5 {
            push_transition(&mut buffer, &device, i as f32);
        }

        let batch = buffer.sample(3);
        assert_eq!(batch.states.dims(), [3, 2]);
        assert_eq!(batch.next_states.dims(), [3, 2]);
        assert_eq!(batch.actions.dims(), [3, 1]);
        assert_eq!(batch.rewards.dims(), [3, 1]);
        assert_eq!(batch.dones.dims(), [3, 1]);
    }

    #[test]
    fn sample_keeps_fields_aligned_and_within_filled_slots() {
        let device = Default::default();
        // Capacity above len: an unwritten (all-zero) slot would break the
        // `next_state == state + 1` relation and be caught below.
        let mut buffer = TransitionBuffer::<TestBackend, TB, TB>::new(10, &device);

        for i in 0..5 {
            push_transition(&mut buffer, &device, i as f32);
        }

        let batch = buffer.sample(5);
        let states = host_values(batch.states);
        let next_states = host_values(batch.next_states);
        let actions = host_values(batch.actions);
        let rewards = host_values(batch.rewards);
        let dones = host_values(batch.dones);

        for row in 0..5 {
            let val = rewards[row];
            assert!((0.0..=4.0).contains(&val), "unexpected reward {val}");
            assert_eq!(actions[row], val);
            assert_eq!(states[2 * row], val);
            assert_eq!(states[2 * row + 1], val);
            assert_eq!(next_states[2 * row], val + 1.0);
            assert_eq!(next_states[2 * row + 1], val + 1.0);
            assert_eq!(dones[row], 0.0);
        }
    }

    #[test]
    #[should_panic(expected = "batch_size exceeds buffer length")]
    fn sample_panics_when_batch_too_large() {
        let device = Default::default();
        let mut buffer = TransitionBuffer::<TestBackend, TB, TB>::new(5, &device);

        push_transition(&mut buffer, &device, 1.0);
        buffer.sample(5);
    }
}