border_candle_agent/
lib.rs1pub mod atari_cnn;
3pub mod dqn;
4pub mod awac;
6pub mod bc;
7pub mod iql;
8pub mod mlp;
9pub mod model;
10pub mod opt;
11pub mod sac;
12mod tensor_batch;
13pub mod util;
14use candle_core::{backend::BackendDevice, DeviceLocation, Module};
15use serde::{Deserialize, Serialize};
16pub use tensor_batch::{TensorBatch, ZeroTensor};
17
18#[derive(Clone, Debug, Copy, Deserialize, Serialize, PartialEq)]
19pub enum Device {
25 Cpu,
27
28 Cuda(usize),
30}
31
32impl From<candle_core::Device> for Device {
33 fn from(device: candle_core::Device) -> Self {
34 match device {
35 candle_core::Device::Cpu => Self::Cpu,
36 candle_core::Device::Cuda(cuda_device) => {
37 let loc = cuda_device.location();
38 match loc {
39 DeviceLocation::Cuda { gpu_id } => Self::Cuda(gpu_id),
40 _ => panic!(),
41 }
42 }
43 _ => unimplemented!(),
44 }
45 }
46}
47
48impl Into<candle_core::Device> for Device {
49 fn into(self) -> candle_core::Device {
50 match self {
51 Self::Cpu => candle_core::Device::Cpu,
52 Self::Cuda(n) => candle_core::Device::new_cuda(n).unwrap(),
53 }
54 }
55}
56
57#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
58pub enum Activation {
59 None,
60 ReLU,
61 Tanh,
62 Sigmoid,
63}
64
65impl Activation {
66 pub fn forward(&self, x: &candle_core::Tensor) -> candle_core::Tensor {
67 match self {
68 Self::None => x.clone(),
69 Self::ReLU => x.relu().unwrap(),
70 Self::Tanh => x.tanh().unwrap(),
71 Self::Sigmoid => candle_nn::Activation::Sigmoid.forward(&x).unwrap(),
72 }
73 }
74}