1pub mod cnn;
3pub mod dqn;
4pub mod iqn;
5pub mod mlp;
6pub mod model;
7pub mod opt;
8pub mod sac;
9mod tensor_batch;
10pub mod util;
12use serde::{Deserialize, Serialize};
13pub use tensor_batch::{TensorBatch, ZeroTensor};
14
15#[derive(Clone, Debug, Copy, Deserialize, Serialize, PartialEq)]
16pub enum Device {
20 Cpu,
22
23 Cuda(usize),
25}
26
27impl From<tch::Device> for Device {
28 fn from(device: tch::Device) -> Self {
29 match device {
30 tch::Device::Cpu => Self::Cpu,
31 tch::Device::Cuda(n) => Self::Cuda(n),
32 tch::Device::Mps => unimplemented!(),
33 tch::Device::Vulkan => unimplemented!(),
34 }
35 }
36}
37
38impl Into<tch::Device> for Device {
39 fn into(self) -> tch::Device {
40 match self {
41 Self::Cpu => tch::Device::Cpu,
42 Self::Cuda(n) => tch::Device::Cuda(n),
43 }
44 }
45}