//! reinforcex 0.0.4
//!
//! Deep Reinforcement Learning Framework

use std::collections::HashSet;

use tch::Tensor;

use ulid::Ulid;

/// Common interface shared by agents of this deep reinforcement learning framework.
///
/// Methods taking `&mut self` may update the agent's internal state; methods taking
/// `&self` only receive a shared borrow.
pub trait BaseAgent {
    // ---- identity -------------------------------------------------------------

    /// Borrows the [`Ulid`] that identifies this agent.
    fn get_agent_id(&self) -> &Ulid;

    // ---- acting ---------------------------------------------------------------

    /// Returns an action (as a `Tensor`) for the observation `obs`.
    ///
    /// Only a shared borrow of the agent is taken. NOTE(review): presumably this is the
    /// inference / evaluation path with no learning step. Confirm against implementations.
    fn act(&self, obs: &Tensor) -> Tensor;

    // ---- training -------------------------------------------------------------

    /// Returns an action (as a `Tensor`) for `obs` while also training the agent.
    ///
    /// `reward` is an `f64` scalar. NOTE(review): presumably it is the reward obtained
    /// for the previously returned action. Verify with implementors.
    fn act_and_train(&mut self, obs: &Tensor, reward: f64) -> Tensor;

    /// Ends the current episode and trains on the final `obs` / `reward` pair.
    ///
    /// Unlike [`BaseAgent::act_and_train`], no action is returned.
    fn stop_episode_and_train(&mut self, obs: &Tensor, reward: f64);

    // ---- reporting ------------------------------------------------------------

    /// Returns named numeric statistics as `(name, value)` pairs.
    fn get_statistics(&self) -> Vec<(String, f64)>;

    // ---- persistence ----------------------------------------------------------

    /// Saves the agent under the directory `dirname`.
    ///
    /// NOTE(review): `ancestors` is presumably the set of already-visited names, used to
    /// avoid saving a shared component twice. Confirm. No `Result` is returned, so
    /// failures cannot be reported to the caller through the return value.
    fn save(&self, dirname: &str, ancestors: HashSet<String>);

    /// Loads the agent from the directory `dirname`.
    ///
    /// NOTE(review): this takes `&self`, so restoring state requires interior mutability
    /// in the implementor. Confirm this is intended. As with [`BaseAgent::save`], errors
    /// cannot be surfaced via the return value.
    fn load(&self, dirname: &str, ancestors: HashSet<String>);
}