use rand::Rng;
use serde::{Deserialize, Serialize};
/// A fully connected (affine) layer: `output[i] = biases[i] + sum_k weights[i][k] * input[k]`.
///
/// All fields are public and are serialized/deserialized as-is via serde.
#[derive(Clone, Serialize, Deserialize)]
pub struct Linear {
/// Weight matrix stored row-major: one row per output unit, one column per input unit
/// (`weights[output_index][input_index]`).
pub weights: Vec<Vec<f64>>,
/// One bias per output unit.
pub biases: Vec<f64>,
/// Number of inputs the layer consumes (columns of `weights`).
pub in_size: usize,
/// Number of outputs the layer produces (rows of `weights`, length of `biases`).
pub out_size: usize,
}
impl Linear {
    /// Creates a layer mapping `in_size` inputs to `out_size` outputs.
    ///
    /// Every weight is drawn uniformly from the half-open range `[-bound, bound)` with
    /// `bound = sqrt(2 / in_size)`; all biases start at zero.
    pub fn new(in_size: usize, out_size: usize) -> Self {
        let mut rng = rand::thread_rng();
        let bound = (2.0_f64 / in_size as f64).sqrt();

        // Rows are filled one after another so the weights are drawn in row-major order.
        let mut weights = Vec::with_capacity(out_size);
        for _ in 0..out_size {
            let row: Vec<f64> = (0..in_size).map(|_| rng.gen_range(-bound..bound)).collect();
            weights.push(row);
        }

        Self {
            weights,
            biases: vec![0.0; out_size],
            in_size,
            out_size,
        }
    }

    /// Computes `biases[i] + dot(weights[i], input)` for every output unit `i`.
    ///
    /// The dot product pairs weights and inputs with `zip`, so if `input` and a weight row
    /// differ in length the extra elements of the longer one are ignored.
    ///
    /// # Panics
    ///
    /// Panics if `out_size` exceeds the number of bias entries or weight rows.
    pub fn forward(&self, input: &[f64]) -> Vec<f64> {
        let mut output = Vec::with_capacity(self.out_size);
        for unit in 0..self.out_size {
            let bias = self.biases[unit];
            let dot: f64 = self.weights[unit]
                .iter()
                .zip(input)
                .map(|(w, x)| w * x)
                .sum();
            output.push(bias + dot);
        }
        output
    }

    /// Applies one plain SGD step: `param -= lr * grad` for every weight and bias.
    ///
    /// `w_grads` must have the same shape as `weights` and `b_grads` the same length as
    /// `biases`.
    ///
    /// # Panics
    ///
    /// Panics if a gradient or parameter container is smaller than `out_size` / `in_size`.
    pub fn sgd_update(&mut self, w_grads: &[Vec<f64>], b_grads: &[f64], lr: f64) {
        for unit in 0..self.out_size {
            self.biases[unit] -= lr * b_grads[unit];

            let weight_row = &mut self.weights[unit][..self.in_size];
            let grad_row = &w_grads[unit][..self.in_size];
            for (weight, grad) in weight_row.iter_mut().zip(grad_row) {
                *weight -= lr * grad;
            }
        }
    }
}
/// Element-wise rectified linear unit: each element becomes `max(x, 0.0)`.
#[inline]
fn relu(v: &[f64]) -> Vec<f64> {
    let mut activated = Vec::with_capacity(v.len());
    for &value in v {
        activated.push(f64::max(value, 0.0));
    }
    activated
}
/// Element-wise derivative of ReLU evaluated at the pre-activations `pre`:
/// `1.0` where the value is strictly positive, `0.0` everywhere else.
#[inline]
fn relu_grad(pre: &[f64]) -> Vec<f64> {
    let mut grad = vec![0.0; pre.len()];
    for (slot, &z) in grad.iter_mut().zip(pre) {
        if z > 0.0 {
            *slot = 1.0;
        }
    }
    grad
}
/// A multi-layer Q-value network built from `Linear` layers.
///
/// When both `value_head` and `advantage_head` are `Some`, the network is treated as a
/// dueling network: `layers` is a shared trunk and the two heads are combined into Q-values.
/// Otherwise `layers` alone forms a plain feed-forward network.
#[derive(Clone, Serialize, Deserialize)]
pub struct QNetwork {
/// The stacked layers, applied in order (the shared trunk in the dueling case).
pub layers: Vec<Linear>,
/// Layer widths recorded at construction time; the dueling constructor appends the
/// action count after the trunk sizes.
pub layer_sizes: Vec<usize>,
/// Optional state-value head (dueling networks only).
/// `#[serde(default)]` makes a missing field deserialize as `None`.
#[serde(default)]
pub value_head: Option<Linear>,
/// Optional per-action advantage head (dueling networks only).
/// `#[serde(default)]` makes a missing field deserialize as `None`.
#[serde(default)]
pub advantage_head: Option<Linear>,
}
impl QNetwork {
pub fn new(layer_sizes: &[usize]) -> Self {
assert!(layer_sizes.len() >= 2, "need at least input + output layer");
let layers = layer_sizes.windows(2).map(|w| Linear::new(w[0], w[1])).collect();
Self { layers, layer_sizes: layer_sizes.to_vec(), value_head: None, advantage_head: None }
}
pub fn new_dueling(shared_sizes: &[usize], action_dim: usize) -> Self {
assert!(shared_sizes.len() >= 2, "need at least input + one hidden");
let layers = shared_sizes.windows(2).map(|w| Linear::new(w[0], w[1])).collect();
let last_hidden = *shared_sizes.last().unwrap();
let value_head = Linear::new(last_hidden, 1);
let advantage_head = Linear::new(last_hidden, action_dim);
let mut ls = shared_sizes.to_vec();
ls.push(action_dim);
Self { layers, layer_sizes: ls, value_head: Some(value_head), advantage_head: Some(advantage_head) }
}
pub fn forward(&self, input: &[f64]) -> Vec<f64> {
let mut x = input.to_vec();
let is_dueling = self.value_head.is_some() && self.advantage_head.is_some();
if is_dueling {
for layer in &self.layers {
let pre = layer.forward(&x);
x = relu(&pre);
}
let vh = self.value_head.as_ref().unwrap();
let ah = self.advantage_head.as_ref().unwrap();
let v = vh.forward(&x)[0];
let a = ah.forward(&x);
let mean_a = a.iter().sum::<f64>() / a.len() as f64;
a.iter().map(|&ai| v + ai - mean_a).collect()
} else {
let last = self.layers.len().saturating_sub(1);
for (i, layer) in self.layers.iter().enumerate() {
let pre = layer.forward(&x);
x = if i < last { relu(&pre) } else { pre };
}
x
}
}
fn forward_cache(&self, input: &[f64]) -> (Vec<f64>, Vec<Vec<f64>>, Vec<Vec<f64>>) {
let mut pre_acts = Vec::new();
let mut post_acts = vec![input.to_vec()];
let mut x = input.to_vec();
let is_dueling = self.value_head.is_some() && self.advantage_head.is_some();
let last = self.layers.len().saturating_sub(1);
for (i, layer) in self.layers.iter().enumerate() {
let pre = layer.forward(&x);
pre_acts.push(pre.clone());
x = if is_dueling || i < last { relu(&pre) } else { pre };
post_acts.push(x.clone());
}
let output = if is_dueling {
let vh = self.value_head.as_ref().unwrap();
let ah = self.advantage_head.as_ref().unwrap();
let v = vh.forward(&x)[0];
let a = ah.forward(&x);
let mean_a = a.iter().sum::<f64>() / a.len() as f64;
a.iter().map(|&ai| v + ai - mean_a).collect()
} else {
x
};
(output, pre_acts, post_acts)
}
pub fn backward_step(
&mut self,
input: &[f64],
targets: &[f64],
mask: &[bool],
lr: f64,
grad_clip: f64,
) -> f64 {
let is_dueling = self.value_head.is_some() && self.advantage_head.is_some();
if !is_dueling {
let (output, pre_acts, post_acts) = self.forward_cache(input);
let mut loss = 0.0;
let n_masked = mask.iter().filter(|&&m| m).count().max(1);
let mut delta: Vec<f64> = output
.iter()
.zip(targets)
.zip(mask)
.map(|((o, t), &m)| {
if m {
let err = o - t;
loss += err * err;
2.0 * err / n_masked as f64
} else {
0.0
}
})
.collect();
let n_layers = self.layers.len();
for i in (0..n_layers).rev() {
if i < n_layers - 1 {
let rg = relu_grad(&pre_acts[i]);
for (d, r) in delta.iter_mut().zip(&rg) { *d *= r; }
}
let out_sz = self.layers[i].out_size;
let in_sz = self.layers[i].in_size;
let layer_input = &post_acts[i];
let mut w_grads = vec![vec![0.0; in_sz]; out_sz];
let mut b_grads = vec![0.0; out_sz];
let mut prev_delta = vec![0.0; in_sz];
for j in 0..out_sz {
b_grads[j] = delta[j].clamp(-grad_clip, grad_clip);
for k in 0..in_sz {
let g = (delta[j] * layer_input[k]).clamp(-grad_clip, grad_clip);
w_grads[j][k] = g;
prev_delta[k] += delta[j] * self.layers[i].weights[j][k];
}
}
self.layers[i].sgd_update(&w_grads, &b_grads, lr);
delta = prev_delta;
}
return loss / n_masked as f64;
}
let (output, trunk_pre, trunk_post) = self.forward_cache(input);
let trunk_out = trunk_post.last().unwrap().clone();
let mut loss = 0.0;
let n_masked = mask.iter().filter(|&&m| m).count().max(1);
let n_actions = output.len();
let dq: Vec<f64> = output.iter().zip(targets).zip(mask).map(|((o, t), &m)| {
if m {
let err = o - t;
loss += err * err;
2.0 * err / n_masked as f64
} else {
0.0
}
}).collect();
let sum_dq: f64 = dq.iter().sum();
let inv_n = 1.0 / n_actions as f64;
let da: Vec<f64> = dq.iter().map(|&dqi| dqi - inv_n * sum_dq).collect();
let dv = sum_dq;
{
let ah = self.advantage_head.as_mut().unwrap();
let in_sz = ah.in_size;
let out_sz = ah.out_size;
let mut w_grads = vec![vec![0.0; in_sz]; out_sz];
let mut b_grads = vec![0.0; out_sz];
let mut trunk_delta = vec![0.0; in_sz];
for j in 0..out_sz {
b_grads[j] = da[j].clamp(-grad_clip, grad_clip);
for k in 0..in_sz {
let g = (da[j] * trunk_out[k]).clamp(-grad_clip, grad_clip);
w_grads[j][k] = g;
trunk_delta[k] += da[j] * ah.weights[j][k];
}
}
ah.sgd_update(&w_grads, &b_grads, lr);
let vh = self.value_head.as_mut().unwrap();
let vin_sz = vh.in_size;
let mut vw_grads = vec![vec![0.0; vin_sz]; 1];
let mut vb_grads = vec![0.0; 1];
vb_grads[0] = dv.clamp(-grad_clip, grad_clip);
for k in 0..vin_sz {
let g = (dv * trunk_out[k]).clamp(-grad_clip, grad_clip);
vw_grads[0][k] = g;
trunk_delta[k] += dv * vh.weights[0][k];
}
vh.sgd_update(&vw_grads, &vb_grads, lr);
let n_layers = self.layers.len();
let mut delta = trunk_delta;
for i in (0..n_layers).rev() {
let rg = relu_grad(&trunk_pre[i]);
for (d, r) in delta.iter_mut().zip(&rg) { *d *= r; }
let out_sz = self.layers[i].out_size;
let in_sz = self.layers[i].in_size;
let layer_input = &trunk_post[i];
let mut w_grads = vec![vec![0.0; in_sz]; out_sz];
let mut b_grads = vec![0.0; out_sz];
let mut prev_delta = vec![0.0; in_sz];
for j in 0..out_sz {
b_grads[j] = delta[j].clamp(-grad_clip, grad_clip);
for k in 0..in_sz {
let g = (delta[j] * layer_input[k]).clamp(-grad_clip, grad_clip);
w_grads[j][k] = g;
prev_delta[k] += delta[j] * self.layers[i].weights[j][k];
}
}
self.layers[i].sgd_update(&w_grads, &b_grads, lr);
delta = prev_delta;
}
}
loss / n_masked as f64
}
pub fn copy_from(&mut self, src: &QNetwork) {
self.layers = src.layers.clone();
self.value_head = src.value_head.clone();
self.advantage_head = src.advantage_head.clone();
}
pub fn save(&self, path: &str) -> std::io::Result<()> {
std::fs::write(path, serde_json::to_string(self).unwrap())
}
pub fn load(path: &str) -> std::io::Result<Self> {
let raw = std::fs::read_to_string(path)?;
Ok(serde_json::from_str(&raw).unwrap())
}
}