border_candle_agent/
lib.rs

1//! RL agents implemented with [candle](https://crates.io/crates/candle-core).
2pub mod atari_cnn;
3pub mod dqn;
4// pub mod iqn;
5pub mod awac;
6pub mod bc;
7pub mod iql;
8pub mod mlp;
9pub mod model;
10pub mod opt;
11pub mod sac;
12mod tensor_batch;
13pub mod util;
14use candle_core::{backend::BackendDevice, DeviceLocation, Module};
15use serde::{Deserialize, Serialize};
16pub use tensor_batch::{TensorBatch, ZeroTensor};
17
18#[derive(Clone, Debug, Copy, Deserialize, Serialize, PartialEq)]
19/// Device for using candle.
20///
21/// This enum is added because [`candle_core::Device`] does not support serialization.
22///
23/// [`candle_core::Device`]: https://docs.rs/candle-core/0.4.1/candle_core/enum.Device.html
24pub enum Device {
25    /// The main CPU device.
26    Cpu,
27
28    /// The main GPU device.
29    Cuda(usize),
30}
31
32impl From<candle_core::Device> for Device {
33    fn from(device: candle_core::Device) -> Self {
34        match device {
35            candle_core::Device::Cpu => Self::Cpu,
36            candle_core::Device::Cuda(cuda_device) => {
37                let loc = cuda_device.location();
38                match loc {
39                    DeviceLocation::Cuda { gpu_id } => Self::Cuda(gpu_id),
40                    _ => panic!(),
41                }
42            }
43            _ => unimplemented!(),
44        }
45    }
46}
47
48impl Into<candle_core::Device> for Device {
49    fn into(self) -> candle_core::Device {
50        match self {
51            Self::Cpu => candle_core::Device::Cpu,
52            Self::Cuda(n) => candle_core::Device::new_cuda(n).unwrap(),
53        }
54    }
55}
56
57#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
58pub enum Activation {
59    None,
60    ReLU,
61    Tanh,
62    Sigmoid,
63}
64
65impl Activation {
66    pub fn forward(&self, x: &candle_core::Tensor) -> candle_core::Tensor {
67        match self {
68            Self::None => x.clone(),
69            Self::ReLU => x.relu().unwrap(),
70            Self::Tanh => x.tanh().unwrap(),
71            Self::Sigmoid => candle_nn::Activation::Sigmoid.forward(&x).unwrap(),
72        }
73    }
74}