Skip to main content

burn_rl/
lib.rs

1#![warn(missing_docs)]
2#![cfg_attr(docsrs, feature(doc_cfg))]
3
4//! A library for training reinforcement learning agents.
5
6/// Module for implementing an environment.
7pub mod environment;
8/// Module for implementing a policy.
9pub mod policy;
10/// Transition buffer.
11pub mod transition_buffer;
12
13pub use environment::*;
14pub use policy::*;
15pub use transition_buffer::*;
16
/// Backend used by this crate's unit tests.
#[cfg(test)]
pub(crate) type TestBackend = burn_ndarray::NdArray<f32>;

/// Mock types shared by the unit tests of this crate's modules.
#[cfg(test)]
pub(crate) mod tests {
    use crate::{Batchable, Policy, PolicyState, TestBackend};

    use burn_core::record::Record;
    // The `Record` derive expands to paths rooted at `burn::`, so alias the crate.
    use burn_core::{self as burn};

    /// Mock policy for testing.
    ///
    /// Every element of a [`MockObservation`] (list of f32) is treated as one batch item.
    ///
    /// Calling `forward()` with a [`MockObservation`] returns a [`MockActionDistribution`]
    /// containing a list of 0s of the same length as the observation.
    ///
    /// Calling `action()` with a [`MockObservation`] returns a [`MockAction`] with one action per
    /// observation element, together with one [`MockActionContext`] per element.
    /// The actions are all 1 if the call is requested as deterministic, or else 0.
    #[derive(Clone, Default)]
    pub(crate) struct MockPolicy {}

    impl MockPolicy {
        /// Creates a new mock policy (equivalent to `MockPolicy::default()`).
        pub fn new() -> Self {
            Self {}
        }
    }

    impl Policy<TestBackend> for MockPolicy {
        type Observation = MockObservation;
        type ActionDistribution = MockActionDistribution;
        type Action = MockAction;
        type ActionContext = MockActionContext;
        type PolicyState = MockPolicyState;

        fn forward(&mut self, obs: Self::Observation) -> Self::ActionDistribution {
            // One zero-valued distribution entry per observation element.
            MockActionDistribution(vec![0.; obs.0.len()])
        }

        fn action(
            &mut self,
            obs: Self::Observation,
            deterministic: bool,
        ) -> (Self::Action, Vec<Self::ActionContext>) {
            let batch_size = obs.0.len();
            // `true` -> 1, `false` -> 0, matching the documented mock behaviour.
            let value = i32::from(deterministic);

            (
                MockAction(vec![value; batch_size]),
                vec![MockActionContext; batch_size],
            )
        }

        fn update(&mut self, _update: Self::PolicyState) {}

        fn state(&self) -> Self::PolicyState {
            MockPolicyState
        }

        fn load_record(
            self,
            _record: <Self::PolicyState as PolicyState<TestBackend>>::Record,
        ) -> Self {
            self
        }
    }

    /// Mock observation for testing, represented as a vector of f32 where each element is one
    /// batch item. Can call `batch()` and `unbatch()` on it.
    #[derive(Clone)]
    pub(crate) struct MockObservation(pub Vec<f32>);

    /// Mock action for testing, represented as a vector of i32 (one action per batch item).
    /// Can call `batch()` and `unbatch()` on it.
    #[derive(Clone)]
    pub(crate) struct MockAction(pub Vec<i32>);

    /// Mock action distribution for testing, represented as a vector of f32 (one entry per batch
    /// item). Can call `batch()` and `unbatch()` on it.
    #[derive(Clone)]
    pub(crate) struct MockActionDistribution(Vec<f32>);

    /// Mock per-item action context returned by `MockPolicy::action`; carries no data.
    #[derive(Clone)]
    pub(crate) struct MockActionContext;

    /// Mock policy state for testing. It carries no data and has no effect on the policy;
    /// converting it into a record always yields a `MockRecord` whose `item` is 0.
    #[derive(Clone)]
    pub(crate) struct MockPolicyState;

    /// Record type associated with `MockPolicyState`.
    #[derive(Clone, Record)]
    pub(crate) struct MockRecord {
        item: usize,
    }

    impl PolicyState<TestBackend> for MockPolicyState {
        type Record = MockRecord;

        fn into_record(self) -> Self::Record {
            MockRecord { item: 0 }
        }

        fn load_record(&self, _record: Self::Record) -> Self {
            // The mock state holds nothing, so the record is ignored.
            self.clone()
        }
    }

    impl Batchable for MockObservation {
        fn batch(items: Vec<Self>) -> Self {
            // `items` is owned, so move the inner vectors instead of cloning them.
            MockObservation(items.into_iter().flat_map(|m| m.0).collect())
        }

        fn unbatch(self) -> Vec<Self> {
            // Split into one single-element observation per batch item so `unbatch` inverts
            // `batch`, consistent with `MockAction` and `MockActionDistribution`.
            self.0
                .into_iter()
                .map(|o| MockObservation(vec![o]))
                .collect()
        }
    }

    impl Batchable for MockAction {
        fn batch(items: Vec<Self>) -> Self {
            MockAction(items.into_iter().flat_map(|m| m.0).collect())
        }

        fn unbatch(self) -> Vec<Self> {
            self.0.into_iter().map(|a| MockAction(vec![a])).collect()
        }
    }

    impl Batchable for MockActionDistribution {
        fn batch(items: Vec<Self>) -> Self {
            MockActionDistribution(items.into_iter().flat_map(|m| m.0).collect())
        }

        fn unbatch(self) -> Vec<Self> {
            // Preserve each stored entry rather than fabricating zeros, so the round trip
            // `unbatch(batch(xs))` keeps the original values.
            self.0
                .into_iter()
                .map(|d| MockActionDistribution(vec![d]))
                .collect()
        }
    }
}