border_candle_agent/
lib.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
//! RL agents implemented with [candle](https://crates.io/crates/candle-core).
pub mod cnn;
pub mod dqn;
// pub mod iqn;
pub mod mlp;
pub mod model;
pub mod opt;
pub mod sac;
mod tensor_batch;
pub mod util;
use candle_core::{backend::BackendDevice, DeviceLocation};
use serde::{Deserialize, Serialize};
pub use tensor_batch::{TensorBatch, ZeroTensor};

/// Device for using candle.
///
/// This enum exists because [`candle_core::Device`] does not support serde
/// serialization, so it cannot be (de)serialized directly. This crate provides
/// conversions in both directions between this type and [`candle_core::Device`].
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Deserialize, Serialize)]
pub enum Device {
    /// The main CPU device.
    Cpu,

    /// A CUDA device, identified by its GPU ordinal (`gpu_id` in candle).
    Cuda(usize),
}

impl From<candle_core::Device> for Device {
    /// Converts a candle device into the serializable [`Device`].
    ///
    /// # Panics
    ///
    /// Panics if `device` is any candle variant other than `Cpu` or `Cuda`,
    /// because [`Device`] has no variant to represent it. Also panics if a CUDA
    /// device reports a non-CUDA location, which would be a broken invariant.
    fn from(device: candle_core::Device) -> Self {
        match device {
            candle_core::Device::Cpu => Self::Cpu,
            candle_core::Device::Cuda(cuda_device) => match cuda_device.location() {
                DeviceLocation::Cuda { gpu_id } => Self::Cuda(gpu_id),
                // A CUDA backend device is expected to always report a CUDA location.
                other => unreachable!("CUDA device reported a non-CUDA location: {other:?}"),
            },
            other => unimplemented!("conversion from candle device {other:?} is not supported"),
        }
    }
}

impl Into<candle_core::Device> for Device {
    fn into(self) -> candle_core::Device {
        match self {
            Self::Cpu => candle_core::Device::Cpu,
            Self::Cuda(n) => candle_core::Device::new_cuda(n).unwrap(),
        }
    }
}