/// A single binary symbol; `encode`/`decode` treat `true` as a 1 bit and `false` as a 0 bit.
pub type Symbol = bool;
/// Growable sequence of symbols; `encode` and its reward variants append to it.
pub type SymbolList = Vec<Symbol>;
/// Action value. NOTE(review): its meaning is not visible in this file — confirm against callers.
pub type Action = u64;
/// Signed reward value. NOTE(review): not referenced by the visible code (which uses `i64` directly).
pub type Reward = i64;
/// A single observation value, as consumed by `observation_key_from_stream`.
pub type PerceptVal = u64;
/// Selects how a stream of observations is summarised by
/// `observation_key_from_stream` / `observation_repr_from_stream`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ObservationKeyMode {
/// Keep the entire stream. Only `observation_repr_from_stream` supports this;
/// `observation_key_from_stream` debug-asserts and falls back to `StreamHash`.
FullStream,
/// Use the first observation, or 0 for an empty stream.
First,
/// Use the last observation, or 0 for an empty stream.
Last,
/// Fold the observations, masked to `observation_bits` low bits, into one rotate-xor hash.
StreamHash,
}
/// Reduces an observation stream to a single key according to `mode`.
///
/// * `First` / `Last` return the first / last observation, or 0 when the stream is empty.
/// * `StreamHash` keeps only the low `observation_bits` bits of each observation
///   (all 64 bits when `observation_bits >= 64`, none when it is 0) and folds them
///   with `acc = acc.rotate_left(7) ^ masked`, starting from 0.
/// * `FullStream` cannot be represented as one key: this triggers a debug assertion,
///   and release builds fall back to the `StreamHash` result.
pub fn observation_key_from_stream(
    mode: ObservationKeyMode,
    observations: &[PerceptVal],
    observation_bits: usize,
) -> PerceptVal {
    /// Rotate-xor fold over every observation truncated to its low `bits` bits.
    fn masked_rotate_xor(stream: &[PerceptVal], bits: usize) -> PerceptVal {
        let keep: u64 = match bits {
            0 => 0,
            b if b >= 64 => u64::MAX,
            // 1..=63: a shift of this size cannot overflow.
            b => (1u64 << b) - 1,
        };
        stream
            .iter()
            .fold(0u64, |acc, &obs| acc.rotate_left(7) ^ (obs & keep))
    }

    match mode {
        ObservationKeyMode::First => observations.first().copied().unwrap_or(0),
        ObservationKeyMode::Last => observations.last().copied().unwrap_or(0),
        ObservationKeyMode::StreamHash => masked_rotate_xor(observations, observation_bits),
        ObservationKeyMode::FullStream => {
            debug_assert!(
                false,
                "observation_key_from_stream called with FullStream; use observation_repr_from_stream"
            );
            masked_rotate_xor(observations, observation_bits)
        }
    }
}
/// Builds the vector representation of an observation stream.
///
/// `FullStream` returns a copy of the whole stream. Every other mode returns a
/// one-element vector holding `observation_key_from_stream(mode, ..)`.
pub fn observation_repr_from_stream(
    mode: ObservationKeyMode,
    observations: &[PerceptVal],
    observation_bits: usize,
) -> Vec<PerceptVal> {
    if matches!(mode, ObservationKeyMode::FullStream) {
        observations.to_vec()
    } else {
        let key = observation_key_from_stream(mode, observations, observation_bits);
        vec![key]
    }
}
/// Small, fast pseudo-random number generator based on xorshift64*.
///
/// Each step applies three xor-shifts to a 64-bit state and scrambles the
/// output with a multiply by `0x2545F4914F6CDD1D`. The state is always kept
/// non-zero, because an all-zero xorshift state only ever produces zeros.
/// Child generators are derived through a SplitMix64 finaliser
/// ([`RandomGenerator::fork_with`]).
///
/// This is not a cryptographically secure generator.
#[derive(Clone, Copy)]
pub struct RandomGenerator {
    /// Current xorshift state; never zero.
    state: u64,
}

impl RandomGenerator {
    /// Substitute state used whenever a seed (or derived state) would be zero.
    const FALLBACK_STATE: u64 = 0xCAFEBABEDEADBEEF;
    /// 64-bit golden-ratio constant, used to whiten seeds and as the SplitMix64 increment.
    const GOLDEN_GAMMA: u64 = 0x9E3779B97F4A7C15;

    /// Replaces a zero seed with [`Self::FALLBACK_STATE`]; any other seed is returned unchanged.
    #[inline]
    fn nonzero_state(seed: u64) -> u64 {
        if seed == 0 {
            Self::FALLBACK_STATE
        } else {
            seed
        }
    }

    /// Picks the seed used by [`RandomGenerator::new`].
    ///
    /// The seed comes from the first source that applies:
    /// 1. With the `backend-zpaq` feature, 8 random bytes from `zpaq_rs` (read
    ///    little-endian), if that call succeeds.
    /// 2. On `wasm32`, a fixed constant.
    /// 3. Otherwise, the low 64 bits of the wall-clock nanoseconds since the
    ///    Unix epoch, xored with the golden gamma.
    ///
    /// The result may be zero; callers pass it through `nonzero_state`.
    #[inline]
    fn initial_seed() -> u64 {
        #[cfg(feature = "backend-zpaq")]
        {
            if let Ok(bytes) = zpaq_rs::random_bytes(8) {
                // Check the length first: `copy_from_slice` panics on a length
                // mismatch, and a short buffer should fall back to the sources below.
                if bytes.len() >= 8 {
                    let mut seed_arr = [0u8; 8];
                    seed_arr.copy_from_slice(&bytes[..8]);
                    return u64::from_le_bytes(seed_arr);
                }
            }
        }
        #[cfg(target_arch = "wasm32")]
        {
            // `SystemTime` is not consulted on wasm32 (presumably because it is
            // unavailable there), so every such generator starts from this seed.
            return Self::FALLBACK_STATE ^ Self::GOLDEN_GAMMA;
        }
        #[cfg(not(target_arch = "wasm32"))]
        #[allow(clippy::cast_possible_truncation)] // keeping only the low 64 bits of the nanos is intended
        {
            let nanos = std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .map(|d| d.as_nanos() as u64)
                .unwrap_or(Self::FALLBACK_STATE);
            return nanos ^ Self::GOLDEN_GAMMA;
        }
        // Reached only if none of the cfg-gated blocks above returned.
        #[allow(unreachable_code)]
        Self::FALLBACK_STATE
    }

    /// Creates a generator seeded from the platform source described on `initial_seed`.
    pub fn new() -> Self {
        Self {
            state: Self::nonzero_state(Self::initial_seed()),
        }
    }

    /// Creates a deterministic generator from `seed`.
    ///
    /// A zero seed is replaced by `0xCAFEBABEDEADBEEF`, so `from_seed(0)` produces
    /// the same sequence as `from_seed(0xCAFEBABEDEADBEEF)`.
    pub fn from_seed(seed: u64) -> Self {
        Self {
            state: Self::nonzero_state(seed),
        }
    }

    /// Advances the state by one xorshift64* step and returns the scrambled output.
    pub fn next_u64(&mut self) -> u64 {
        let mut x = self.state;
        x ^= x >> 12;
        x ^= x << 25;
        x ^= x >> 27;
        self.state = x;
        x.wrapping_mul(0x2545F4914F6CDD1D)
    }

    /// Returns a uniformly distributed integer in `0..end`, or 0 when `end == 0`.
    ///
    /// Uses rejection sampling to avoid modulo bias. Draws that fall into the
    /// incomplete final block (the top `2^64 mod end` values) are discarded and
    /// redrawn. Every accepted draw maps to `draw % end`, so the output matches
    /// a plain modulo except in those rejected cases, which occur with
    /// probability `(2^64 mod end) / 2^64`.
    pub fn gen_range(&mut self, end: usize) -> usize {
        if end == 0 {
            return 0;
        }
        let n = end as u64;
        // `n.wrapping_neg() % n` equals `2^64 mod n`, the size of the incomplete block.
        let limit = u64::MAX - n.wrapping_neg() % n;
        loop {
            let draw = self.next_u64();
            if draw <= limit {
                // `draw % n < n == end`, so the cast back to usize cannot truncate.
                return (draw % n) as usize;
            }
        }
    }

    /// Returns `true` with probability `p`.
    ///
    /// Never returns `true` for `p <= 0.0` or NaN; always returns `true` for `p >= 1.0`.
    pub fn gen_bool(&mut self, p: f64) -> bool {
        self.gen_f64() < p
    }

    /// Returns a uniform `f64` in `[0.0, 1.0)` built from the top 53 bits of one draw.
    pub fn gen_f64(&mut self) -> f64 {
        let v = self.next_u64() >> 11;
        // Multiplying by 2^-53 maps the 53-bit integer exactly onto [0, 1).
        (v as f64) * (1.0 / 9007199254740992.0)
    }

    /// Derives a child generator from the current state and `salt`, without advancing `self`.
    ///
    /// The child state is a SplitMix64 hash of `state ^ salt ^ GOLDEN_GAMMA` (zero is
    /// replaced by the fallback), so the same parent state and salt always yield
    /// the same child.
    pub fn fork_with(&self, salt: u64) -> Self {
        let mixed = Self::splitmix64(self.state ^ salt ^ Self::GOLDEN_GAMMA);
        Self {
            state: Self::nonzero_state(mixed),
        }
    }

    /// One SplitMix64 step: add the golden gamma, then apply the xor-shift/multiply finaliser.
    fn splitmix64(x: u64) -> u64 {
        let mut z = x.wrapping_add(Self::GOLDEN_GAMMA);
        z = (z ^ (z >> 30)).wrapping_mul(0xBF58476D1CE4E5B9);
        z = (z ^ (z >> 27)).wrapping_mul(0x94D049BB133111EB);
        z ^ (z >> 31)
    }
}

impl Default for RandomGenerator {
    /// Same as [`RandomGenerator::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// Appends the low `bits` bits of `value` to `symlist`, least-significant bit first.
///
/// Positions beyond bit 63 are padded with `false`, because the shifted value is zero by then.
pub fn encode(symlist: &mut SymbolList, value: u64, bits: usize) {
    let mut remaining = value;
    symlist.extend((0..bits).map(|_| {
        let bit = (remaining & 1) == 1;
        remaining >>= 1;
        bit
    }));
}
pub fn encode_reward(symlist: &mut SymbolList, value: i64, bits: usize) {
let mut v = value as u64;
for _ in 0..bits {
symlist.push((v & 1) == 1);
v >>= 1;
}
}
/// Appends `value + offset` to `symlist` as an unsigned `bits`-wide field, LSB first.
///
/// Adding `offset` lets a signed reward in `[-offset, 2^bits - offset)` be stored as
/// a non-negative integer. Only the low `bits` bits of the sum are written (via
/// `encode`), so rewards outside that range are truncated rather than rejected.
///
/// The sum uses wrapping (two's-complement) arithmetic. Plain `+` would panic on
/// overflow in debug builds and wrap in release. Because the field is truncated
/// anyway, wrapping gives the same, well-defined result in both profiles. For
/// example, `value = i64::MAX, offset = 1, bits = 64` stores the bit pattern of `i64::MIN`.
pub fn encode_reward_offset(symlist: &mut SymbolList, value: i64, bits: usize, offset: i64) {
    let shifted = value.wrapping_add(offset) as u64;
    encode(symlist, shifted, bits);
}
/// Reads the trailing `bits` symbols of `symlist` as an unsigned integer.
///
/// This undoes `encode`: the final symbol is the most significant bit, and the
/// symbol `bits` positions from the end is the least significant. When
/// `bits > 64`, only the low 64 bits of the value survive. `bits == 0` yields 0.
///
/// # Panics
///
/// Panics if `bits > symlist.len()`.
pub fn decode(symlist: &[Symbol], bits: usize) -> u64 {
    assert!(bits <= symlist.len());
    let field = &symlist[symlist.len() - bits..];
    // Walk from the last symbol backwards, shifting each bit in below the previous ones.
    field
        .iter()
        .rev()
        .fold(0u64, |acc, &sym| (acc << 1) | u64::from(sym))
}
/// Reads the trailing `bits` symbols as a two's-complement signed integer.
///
/// For `1..=63` bits, the top bit of the field is the sign and is extended
/// through the upper bits. For `bits >= 64`, the decoded `u64` is reinterpreted
/// as `i64`. `bits == 0` yields 0 without reading `symlist`.
///
/// # Panics
///
/// Panics (via `decode`) if `bits > symlist.len()` and `bits > 0`.
pub fn decode_reward(symlist: &[Symbol], bits: usize) -> i64 {
    match bits {
        0 => 0,
        1..=63 => {
            // Move the field's sign bit up to bit 63, then shift it back down
            // arithmetically so the sign fills the vacated high bits.
            let spare = 64 - bits;
            ((decode(symlist, bits) << spare) as i64) >> spare
        }
        _ => decode(symlist, bits) as i64,
    }
}
/// Reads the trailing `bits` symbols as an unsigned field and subtracts `offset`.
///
/// The decoded field is reinterpreted as `i64` before the subtraction. The
/// subtraction wraps (two's complement) instead of panicking on overflow in
/// debug builds, so debug and release agree. For example, a 64-bit field
/// holding `i64::MIN`'s bit pattern with `offset = 1` decodes to `i64::MAX`.
///
/// `bits == 0` returns 0 (not `-offset`) without reading `symlist`.
///
/// # Panics
///
/// Panics (via `decode`) if `bits > symlist.len()` and `bits > 0`.
pub fn decode_reward_offset(symlist: &[Symbol], bits: usize, offset: i64) -> i64 {
    if bits == 0 {
        return 0;
    }
    let v = decode(symlist, bits) as i64;
    v.wrapping_sub(offset)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for the key function under test, keeping assertions on one line.
    fn key(mode: ObservationKeyMode, obs: &[PerceptVal], bits: usize) -> PerceptVal {
        observation_key_from_stream(mode, obs, bits)
    }

    #[test]
    fn observation_repr_full_stream_is_identity() {
        let obs: Vec<PerceptVal> = vec![1, 2, 3];
        let repr = observation_repr_from_stream(ObservationKeyMode::FullStream, &obs, 8);
        assert_eq!(repr, obs);
    }

    #[test]
    fn observation_key_first_last() {
        let obs: [PerceptVal; 3] = [10, 20, 30];
        assert_eq!(key(ObservationKeyMode::First, &obs, 8), 10);
        assert_eq!(key(ObservationKeyMode::Last, &obs, 8), 30);
        // An empty stream has no first/last element; both selectors yield 0.
        for mode in [ObservationKeyMode::First, ObservationKeyMode::Last] {
            assert_eq!(key(mode, &[], 8), 0);
        }
    }

    #[test]
    fn observation_key_stream_hash_masks_and_mix() {
        // A 3-bit mask maps 9 -> 1 and keeps 2; the hash is rotl(1, 7) ^ 2 = 128 ^ 2.
        assert_eq!(key(ObservationKeyMode::StreamHash, &[9, 2], 3), 130);
    }

    #[test]
    fn observation_key_stream_hash_observation_bits_zero_is_zero() {
        // A zero-width mask discards every observation.
        assert_eq!(key(ObservationKeyMode::StreamHash, &[123, 456, 789], 0), 0);
    }

    #[test]
    fn observation_key_stream_hash_observation_bits_ge_64_uses_full_u64() {
        let obs: [PerceptVal; 2] = [u64::MAX, 0x0123_4567_89ab_cdef];
        assert_eq!(
            key(ObservationKeyMode::StreamHash, &obs, 64),
            key(ObservationKeyMode::StreamHash, &obs, 128)
        );
    }
}