1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
//! RL agents implemented with [tch](https://crates.io/crates/tch).
pub mod cnn;
pub mod dqn;
pub mod iqn;
pub mod mlp;
pub mod model;
pub mod opt;
pub mod sac;
mod tensor_batch;
// pub mod replay_buffer;
pub mod util;
use serde::{Deserialize, Serialize};
pub use tensor_batch::{TensorSubBatch, ZeroTensor};

/// Serializable stand-in for `tch::Device`.
///
/// `tch::Device` does not support serialization, so this enum mirrors its
/// variants and is converted to and from the `tch` type where needed.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum Device {
    /// The CPU device.
    Cpu,

    /// A CUDA device. The `usize` payload is carried over unchanged when
    /// converting to or from `tch::Device::Cuda`.
    Cuda(usize),
}

impl From<tch::Device> for Device {
    fn from(device: tch::Device) -> Self {
        match device {
            tch::Device::Cpu => Self::Cpu,
            tch::Device::Cuda(n) => Self::Cuda(n),
        }
    }
}

impl Into<tch::Device> for Device {
    fn into(self) -> tch::Device {
        match self {
            Self::Cpu => tch::Device::Cpu,
            Self::Cuda(n) => tch::Device::Cuda(n),
        }
    }
}