1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
#![allow(clippy::float_cmp)]
use crate::core::{
record::{Record, RecordValue, Recorder},
Env, Obs, Policy, Step,
};
use log::info;
use std::cell::RefCell;
pub fn sample<E: Env, P: Policy<E>>(
env: &mut E,
policy: &mut P,
obs_prev: &RefCell<Option<E::Obs>>,
) -> (Step<E>, Record) {
let obs = obs_prev
.replace(None)
.expect("The buffer of the previous observations is not initialized.");
let a = policy.sample(&obs);
let (step, record) = env.step(&a);
let obs_reset = env.reset(Some(&step.is_done)).unwrap();
let obs_reset = step.obs.clone().merge(obs_reset, &step.is_done);
obs_prev.replace(Some(obs_reset));
(step, record)
}
#[cfg_attr(doc, aquamarine::aquamarine)]
/// Runs `n_episodes` episodes with `policy` on `env` and returns the cumulative
/// reward of each episode, in order.
///
/// The environment is reset once before the first episode; after that,
/// [`sample`] is called repeatedly and keeps the observation buffer up to date.
/// An episode ends on the first step for which `step.is_done[0] == 1`. Only
/// index `0` of `step.reward` and `step.is_done` is read.
///
/// One `info!` line is emitted per episode with the episode index, the number of
/// steps taken (including the terminal step) and the cumulative reward.
///
/// # Panics
///
/// Panics if the result of the initial `env.reset(None)` cannot be unwrapped, or
/// if [`sample`] panics.
pub fn eval<E: Env, P: Policy<E>>(env: &mut E, policy: &mut P, n_episodes: usize) -> Vec<f32> {
    // Exactly one entry is pushed per episode, so the final size is known.
    let mut rs = Vec::with_capacity(n_episodes);
    let obs = env.reset(None).unwrap();
    let obs_prev = RefCell::new(Some(obs));

    for i in 0..n_episodes {
        let mut r_sum = 0.0;
        let mut steps = 0;
        loop {
            let (step, _) = sample(env, policy, &obs_prev);
            r_sum += step.reward[0];
            // Count every step, including the one that terminates the episode.
            steps += 1;
            if step.is_done[0] == 1 {
                break;
            }
        }
        rs.push(r_sum);
        info!("Episode {:?}, {:?} steps, reward = {:?}", i, steps, r_sum);
    }

    rs
}
/// Runs `n_episodes` episodes with `policy` on `env`, writing one record per step
/// to `recorder`, and returns the cumulative reward of each episode in order.
///
/// The environment is reset once before the first episode. For every step the
/// record returned by [`sample`] is extended with three scalar entries before
/// being passed to `recorder.write`:
///
/// * `"reward"`  – `step.reward[0]`
/// * `"episode"` – index of the current episode, starting at 0
/// * `"step"`    – index of the step within the episode, starting at 0
///
/// An episode ends on the first step for which `step.is_done[0] == 1`. Only
/// index `0` of `step.reward` and `step.is_done` is read.
///
/// # Panics
///
/// Panics if the result of the initial `env.reset(None)` cannot be unwrapped, or
/// if [`sample`] panics.
pub fn eval_with_recorder<E, P, R>(
    env: &mut E,
    policy: &mut P,
    n_episodes: usize,
    recorder: &mut R,
) -> Vec<f32>
where
    E: Env,
    P: Policy<E>,
    R: Recorder,
{
    let first_obs = env.reset(None).unwrap();
    let obs_prev = RefCell::new(Some(first_obs));

    (0..n_episodes)
        .map(|episode| {
            let mut total_reward = 0.0;
            let mut step_idx = 0;
            // The loop evaluates to the episode's cumulative reward once it terminates.
            loop {
                let (step, mut record) = sample(env, policy, &obs_prev);
                let reward = step.reward[0];
                total_reward += reward;

                record.insert("reward", RecordValue::Scalar(reward as _));
                record.insert("episode", RecordValue::Scalar(episode as _));
                record.insert("step", RecordValue::Scalar(step_idx as _));
                recorder.write(record);

                if step.is_done[0] == 1 {
                    break total_reward;
                }
                step_idx += 1;
            }
        })
        .collect()
}