use border_tch_agent::{model::SubModel, util::OutDim};
use serde::{Deserialize, Serialize};
use tch::{nn, nn::Module, Device, Tensor};
/// Configuration used to build a [`CNN`].
///
/// Derives `Serialize`/`Deserialize`, so it can be stored and restored with serde.
// The type name is an all-caps acronym, so the corresponding clippy lint is silenced.
#[allow(clippy::upper_case_acronyms)]
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
pub struct CNNConfig {
/// Number of input channels of the first convolution layer.
/// Presumably the number of stacked frames per observation — TODO confirm with callers.
n_stack: i64,
/// Number of output features of the final linear layer.
out_dim: i64,
}
/// Exposes the output dimension of the configuration through the `OutDim` trait.
impl OutDim for CNNConfig {
    /// Returns the currently configured number of output features.
    fn get_out_dim(&self) -> i64 {
        self.out_dim
    }

    /// Replaces the configured number of output features with `v`.
    fn set_out_dim(&mut self, v: i64) {
        self.out_dim = v;
    }
}
impl CNNConfig {
    /// Creates a configuration from its two settings.
    ///
    /// * `n_stack` - number of input channels of the first convolution layer.
    /// * `out_dim` - number of output features of the final linear layer.
    // Kept even when unused in a given build of the crate, hence the lint exemption.
    #[allow(dead_code)]
    pub fn new(n_stack: i64, out_dim: i64) -> Self {
        Self { n_stack, out_dim }
    }
}
/// Convolutional network made of three convolution layers followed by two linear layers.
// The type name is an all-caps acronym, so the corresponding clippy lint is silenced.
#[allow(clippy::upper_case_acronyms)]
pub struct CNN {
/// Number of input channels of the first convolution layer.
n_stack: i64,
/// Number of output features of the final linear layer.
out_dim: i64,
/// Device of the variable store the model was built with; inputs are moved to it
/// before the forward pass.
device: Device,
/// The layer stack applied by `forward`.
seq: nn::Sequential,
}
impl CNN {
    /// Returns a convolution configuration that uses stride `s` and library
    /// defaults for every other setting.
    fn stride(s: i64) -> nn::ConvConfig {
        nn::ConvConfig {
            stride: s,
            ..nn::ConvConfig::default()
        }
    }

    /// Builds the layer stack and registers its parameters under the root of `var_store`.
    ///
    /// * `n_stack` - number of input channels of the first convolution layer.
    /// * `out_dim` - number of output features of the last linear layer.
    ///
    /// Parameters are registered under the names `c1`, `c2`, `c3`, `l1` and `l2`,
    /// in that order.
    fn create_net(var_store: &nn::VarStore, n_stack: i64, out_dim: i64) -> nn::Sequential {
        let root = var_store.root();

        // Layers are created in the same order in which they are applied, so the
        // variables are registered (and initialized) as c1, c2, c3, l1, l2.
        let conv1 = nn::conv2d(&root / "c1", n_stack, 32, 8, Self::stride(4));
        let conv2 = nn::conv2d(&root / "c2", 32, 64, 4, Self::stride(2));
        let conv3 = nn::conv2d(&root / "c3", 64, 64, 3, Self::stride(1));
        // 3136 == 64 * 7 * 7; presumably the flattened conv output for 84x84 frames
        // — TODO confirm against the environment's observation size.
        let fc1 = nn::linear(&root / "l1", 3136, 512, Default::default());
        let fc2 = nn::linear(&root / "l2", 512, out_dim as _, Default::default());

        nn::seq()
            // Drop dimension 2, convert to float and divide by 255.
            // NOTE(review): looks like 8-bit pixel input of shape
            // (batch, n_stack, 1, H, W) is expected — confirm.
            .add_fn(|xs| xs.squeeze_dim(2).internal_cast_float(true) / 255)
            .add(conv1)
            .add_fn(|xs| xs.relu())
            .add(conv2)
            .add_fn(|xs| xs.relu())
            .add(conv3)
            // Activation, then flatten the feature maps for the linear layers.
            .add_fn(|xs| xs.relu().flat_view())
            .add(fc1)
            .add_fn(|xs| xs.relu())
            .add(fc2)
    }
}
impl SubModel for CNN {
type Config = CNNConfig;
type Input = Tensor;
type Output = Tensor;
fn forward(&self, x: &Self::Input) -> Tensor {
self.seq.forward(&x.to(self.device))
}
fn build(var_store: &nn::VarStore, config: Self::Config) -> Self {
let n_stack = config.n_stack;
let out_dim = config.out_dim;
let device = var_store.device();
let seq = Self::create_net(var_store, n_stack, out_dim);
Self {
n_stack,
out_dim,
device,
seq,
}
}
fn clone_with_var_store(&self, var_store: &nn::VarStore) -> Self {
let n_stack = self.n_stack;
let out_dim = self.out_dim;
let device = var_store.device();
let seq = Self::create_net(&var_store, n_stack, out_dim);
Self {
n_stack,
out_dim,
device,
seq,
}
}
}