border_tch_agent/
lib.rs

1//! RL agents implemented with [tch](https://crates.io/crates/tch).
2pub mod cnn;
3pub mod dqn;
4pub mod iqn;
5pub mod mlp;
6pub mod model;
7pub mod opt;
8pub mod sac;
9mod tensor_batch;
10// pub mod replay_buffer;
11pub mod util;
12use serde::{Deserialize, Serialize};
13pub use tensor_batch::{TensorBatch, ZeroTensor};
14
15#[derive(Clone, Debug, Copy, Deserialize, Serialize, PartialEq)]
16/// Device for using tch-rs.
17///
18/// This enum is added because `tch::Device` does not support serialization.
19pub enum Device {
20    /// The main CPU device.
21    Cpu,
22
23    /// The main GPU device.
24    Cuda(usize),
25}
26
27impl From<tch::Device> for Device {
28    fn from(device: tch::Device) -> Self {
29        match device {
30            tch::Device::Cpu => Self::Cpu,
31            tch::Device::Cuda(n) => Self::Cuda(n),
32            tch::Device::Mps => unimplemented!(),
33            tch::Device::Vulkan => unimplemented!(),
34        }
35    }
36}
37
38impl Into<tch::Device> for Device {
39    fn into(self) -> tch::Device {
40        match self {
41            Self::Cpu => tch::Device::Cpu,
42            Self::Cuda(n) => tch::Device::Cuda(n),
43        }
44    }
45}