1#![warn(missing_docs)]
2#![cfg_attr(docsrs, feature(doc_cfg))]
3
4pub mod environment;
8pub mod policy;
10pub mod transition_buffer;
12
13pub use environment::*;
14pub use policy::*;
15pub use transition_buffer::*;
16
/// Backend type the crate's unit tests plug into backend-generic code:
/// an alias for `burn_ndarray::NdArray<f32>`. Compiled for test builds only.
#[cfg(test)]
pub(crate) type TestBackend = burn_ndarray::NdArray<f32>;
19
20#[cfg(test)]
21pub(crate) mod tests {
22 use crate::{Batchable, Policy, PolicyState, TestBackend};
23
24 use burn_core::record::Record;
25 use burn_core::{self as burn};
26
/// Field-less policy stand-in for the crate's unit tests.
///
/// It carries no data, so every instance is interchangeable with any other.
#[derive(Clone)]
pub(crate) struct MockPolicy {}

impl MockPolicy {
    /// Builds a mock policy; there is nothing to configure.
    pub fn new() -> Self {
        MockPolicy {}
    }
}
42
43 impl Policy<TestBackend> for MockPolicy {
44 type Observation = MockObservation;
45 type ActionDistribution = MockActionDistribution;
46 type Action = MockAction;
47 type ActionContext = MockActionContext;
48 type PolicyState = MockPolicyState;
49
50 fn forward(&mut self, obs: Self::Observation) -> Self::ActionDistribution {
51 let mut dists = vec![];
52
53 for _ in obs.0 {
54 dists.push(MockActionDistribution(vec![0.]));
55 }
56 MockActionDistribution::batch(dists)
57 }
58
59 fn action(
60 &mut self,
61 obs: Self::Observation,
62 deterministic: bool,
63 ) -> (Self::Action, Vec<Self::ActionContext>) {
64 let mut actions = vec![];
65 let mut contexts = vec![];
66
67 for _ in obs.0 {
68 if deterministic {
69 actions.push(MockAction(vec![1]));
70 } else {
71 actions.push(MockAction(vec![0]));
72 }
73 contexts.push(MockActionContext);
74 }
75
76 (MockAction::batch(actions), contexts)
77 }
78
79 fn update(&mut self, _update: Self::PolicyState) {}
80
81 fn state(&self) -> Self::PolicyState {
82 MockPolicyState
83 }
84
85 fn load_record(
86 self,
87 _record: <Self::PolicyState as PolicyState<TestBackend>>::Record,
88 ) -> Self {
89 self
90 }
91 }
92
/// Mock observation: a flat list of `f32` values.
#[derive(Clone)]
pub(crate) struct MockObservation(pub Vec<f32>);

/// Mock action: a flat list of `i32` values.
#[derive(Clone)]
pub(crate) struct MockAction(pub Vec<i32>);

/// Mock action distribution: a flat list of `f32` values, private to this module.
#[derive(Clone)]
pub(crate) struct MockActionDistribution(Vec<f32>);

/// Data-free marker used as the mock policy's per-action context.
#[derive(Clone)]
pub(crate) struct MockActionContext;

/// Data-free marker used as the mock policy's state.
#[derive(Clone)]
pub(crate) struct MockPolicyState;
111
112 #[derive(Clone, Record)]
113 pub(crate) struct MockRecord {
114 item: usize,
115 }
116
117 impl PolicyState<TestBackend> for MockPolicyState {
118 type Record = MockRecord;
119
120 fn into_record(self) -> Self::Record {
121 MockRecord { item: 0 }
122 }
123
124 fn load_record(&self, _record: Self::Record) -> Self {
125 self.clone()
126 }
127 }
128
129 impl Batchable for MockObservation {
130 fn batch(items: Vec<Self>) -> Self {
131 MockObservation(items.iter().flat_map(|m| m.0.clone()).collect())
132 }
133
134 fn unbatch(self) -> Vec<Self> {
135 vec![MockObservation(self.0)]
136 }
137 }
138
139 impl Batchable for MockAction {
140 fn batch(items: Vec<Self>) -> Self {
141 MockAction(items.iter().flat_map(|m| m.0.clone()).collect())
142 }
143
144 fn unbatch(self) -> Vec<Self> {
145 let mut actions = vec![];
146 for a in self.0 {
147 actions.push(MockAction(vec![a]));
148 }
149 actions
150 }
151 }
152
153 impl Batchable for MockActionDistribution {
154 fn batch(items: Vec<Self>) -> Self {
155 MockActionDistribution(items.iter().flat_map(|m| m.0.clone()).collect())
156 }
157
158 fn unbatch(self) -> Vec<Self> {
159 let mut dists = vec![];
160 for _ in self.0 {
161 dists.push(MockActionDistribution(vec![0.]));
162 }
163 dists
164 }
165 }
166}