#![allow(missing_docs)]
use axonml_autograd::Variable;
use axonml_nn::{BatchNorm2d, Conv2d, Module, Parameter};
use axonml_tensor::Tensor;
/// Single-scale predictive-coding unit: maintains a prediction of the next
/// feature map and gates incoming features by how "surprising" they are
/// relative to that prediction.
pub struct PredictiveCodingModule {
    /// 3x3 conv producing the next-step prediction from the gated features.
    predict_conv: Conv2d,
    /// Batch norm applied to the conv output before it is stored as the prediction.
    predict_bn: BatchNorm2d,
    // Channel count supplied at construction; currently unused beyond setup.
    _channels: usize,
    /// Prediction produced by the previous `forward` call; `None` until the
    /// first call (or after `reset`).
    prediction: Option<Variable>,
    /// Scales the per-pixel surprise before the sigmoid gate; larger values
    /// sharpen the gating.
    pub temperature: f32,
}
impl PredictiveCodingModule {
    /// Creates a predictive-coding unit for `channels`-channel feature maps.
    ///
    /// The predictor is a 3x3 conv (stride 1, padding 1, with bias) followed
    /// by batch norm, so the predicted map keeps the input's spatial size.
    /// Temperature defaults to 1.0.
    pub fn new(channels: usize) -> Self {
        Self {
            predict_conv: Conv2d::with_options(channels, channels, (3, 3), (1, 1), (1, 1), true),
            predict_bn: BatchNorm2d::new(channels),
            _channels: channels,
            prediction: None,
            temperature: 1.0,
        }
    }

    /// Gates `actual` against the stored prediction and refreshes the prediction.
    ///
    /// Expects a 4-D variable; `shape()[0..4]` is read directly as NCHW —
    /// TODO confirm layout against callers. Returns `(gated, surprise)` where
    /// `surprise` has shape `[B, 1, H, W]`.
    ///
    /// * With a stored prediction of matching spatial size: the per-pixel
    ///   surprise is the channel-mean squared error between `actual` and the
    ///   prediction, the gate is `sigmoid(surprise * temperature)`, and the
    ///   output is `gate * actual + (1 - gate) * prediction`. The returned
    ///   surprise map holds the sigmoid-squashed gate values (range [0.5, 1)),
    ///   not the raw MSE — NOTE(review): presumably intentional so it matches
    ///   the all-ones first-frame convention below; confirm.
    /// * Otherwise (first frame, or spatial size changed): `actual` passes
    ///   through unchanged and the surprise map is all ones (max surprise).
    ///
    /// In both paths the new prediction `bn(conv(output))` is stored for the
    /// next call.
    pub fn forward(&mut self, actual: &Variable) -> (Variable, Variable) {
        let shape = actual.shape();
        let (b, c, h, w) = (shape[0], shape[1], shape[2], shape[3]);
        let actual_data = actual.data().to_vec();
        let spatial = h * w;
        if let Some(ref pred) = self.prediction {
            let pred_shape = pred.shape();
            // Only gate when the stored prediction matches spatially; otherwise
            // fall through to the cold-start path below.
            if pred_shape[2] == h && pred_shape[3] == w {
                let pred_data = pred.data().to_vec();
                // Per-pixel prediction error, averaged over channels.
                let mut surprise = vec![0.0f32; b * spatial];
                for bi in 0..b {
                    for y in 0..h {
                        for x in 0..w {
                            let mut mse = 0.0f32;
                            for ci in 0..c {
                                // Flat NCHW index.
                                let idx = bi * c * spatial + ci * spatial + y * w + x;
                                let diff = actual_data[idx] - pred_data[idx];
                                mse += diff * diff;
                            }
                            surprise[bi * spatial + y * w + x] = mse / c as f32;
                        }
                    }
                }
                // Sigmoid gate: surprise >= 0, so gate is in [0.5, 1); higher
                // temperature makes the transition steeper.
                let mut gate = vec![0.0f32; b * spatial];
                for i in 0..b * spatial {
                    gate[i] = 1.0 / (1.0 + (-surprise[i] * self.temperature).exp());
                }
                // The returned "surprise" map is the GATE (sigmoid of the MSE),
                // not the raw `surprise` vector computed above.
                let surprise_var = Variable::new(
                    Tensor::from_vec(gate.clone(), &[b, 1, h, w]).unwrap(),
                    false,
                );
                let gate_var = Variable::new(
                    Tensor::from_vec(gate, &[b, 1, h, w]).unwrap(),
                    false,
                );
                // Broadcast the single-channel gate across all feature channels.
                let gate_expanded = gate_var.expand(&[b, c, h, w]);
                let ones = Variable::new(
                    Tensor::from_vec(vec![1.0f32; b * c * h * w], &[b, c, h, w]).unwrap(),
                    false,
                );
                let inv_gate = &ones - &gate_expanded;
                // Surprising pixels keep the observation; predictable pixels
                // are pulled toward the prediction.
                let gated_var = &(&gate_expanded * actual) + &(&inv_gate * pred);
                // Predict the next frame from the gated output.
                self.prediction = Some(
                    self.predict_bn
                        .forward(&self.predict_conv.forward(&gated_var)),
                );
                return (gated_var, surprise_var);
            }
        }
        // Cold start (no prediction, or spatial size changed): everything is
        // maximally surprising and the input passes through untouched.
        let surprise_data = vec![1.0f32; b * spatial];
        let surprise_var = Variable::new(
            Tensor::from_vec(surprise_data, &[b, 1, h, w]).unwrap(),
            false,
        );
        self.prediction = Some(self.predict_bn.forward(&self.predict_conv.forward(actual)));
        (actual.clone(), surprise_var)
    }

    /// Drops the stored prediction; the next `forward` takes the cold-start path.
    pub fn reset(&mut self) {
        self.prediction = None;
    }

    /// True once `forward` has run at least once since construction or `reset`.
    pub fn has_prediction(&self) -> bool {
        self.prediction.is_some()
    }

    /// Trainable parameters of the predictor (conv then batch norm).
    pub fn parameters(&self) -> Vec<Parameter> {
        let mut p = Vec::new();
        p.extend(self.predict_conv.parameters());
        p.extend(self.predict_bn.parameters());
        p
    }

    /// Switches the batch-norm layer to inference mode.
    pub fn eval(&mut self) {
        self.predict_bn.eval();
    }

    /// Switches the batch-norm layer back to training mode.
    pub fn train(&mut self) {
        self.predict_bn.train();
    }
}
/// Three independent [`PredictiveCodingModule`]s, one per feature scale,
/// all built with the same channel count.
pub struct MultiScalePredictiveCoding {
    /// Unit for the first scale.
    pub scale1: PredictiveCodingModule,
    /// Unit for the second scale.
    pub scale2: PredictiveCodingModule,
    /// Unit for the third scale.
    pub scale3: PredictiveCodingModule,
}
impl MultiScalePredictiveCoding {
    /// Builds three independent per-scale units, each configured for
    /// `channels`-channel feature maps.
    pub fn new(channels: usize) -> Self {
        Self {
            scale1: PredictiveCodingModule::new(channels),
            scale2: PredictiveCodingModule::new(channels),
            scale3: PredictiveCodingModule::new(channels),
        }
    }

    /// Runs each scale's predictive-coding forward pass on its input and
    /// returns the `(gated, surprise)` pair per scale, in order.
    pub fn forward(
        &mut self,
        f1: &Variable,
        f2: &Variable,
        f3: &Variable,
    ) -> (
        (Variable, Variable),
        (Variable, Variable),
        (Variable, Variable),
    ) {
        (
            self.scale1.forward(f1),
            self.scale2.forward(f2),
            self.scale3.forward(f3),
        )
    }

    /// Clears the stored prediction in every scale.
    pub fn reset(&mut self) {
        for module in [&mut self.scale1, &mut self.scale2, &mut self.scale3] {
            module.reset();
        }
    }

    /// Collects the trainable parameters of all three scales, in order.
    pub fn parameters(&self) -> Vec<Parameter> {
        self.scale1
            .parameters()
            .into_iter()
            .chain(self.scale2.parameters())
            .chain(self.scale3.parameters())
            .collect()
    }

    /// Puts every scale into inference mode.
    pub fn eval(&mut self) {
        for module in [&mut self.scale1, &mut self.scale2, &mut self.scale3] {
            module.eval();
        }
    }

    /// Puts every scale back into training mode.
    pub fn train(&mut self) {
        for module in [&mut self.scale1, &mut self.scale2, &mut self.scale3] {
            module.train();
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a constant-valued variable with the given NCHW shape.
    fn filled(value: f32, shape: &[usize]) -> Variable {
        let numel: usize = shape.iter().product();
        Variable::new(Tensor::from_vec(vec![value; numel], shape).unwrap(), false)
    }

    /// Mean of a variable's flattened data.
    fn mean_of(v: &Variable) -> f32 {
        let data = v.data().to_vec();
        data.iter().sum::<f32>() / data.len() as f32
    }

    #[test]
    fn test_predictive_coding_first_frame() {
        let mut pc = PredictiveCodingModule::new(32);
        assert!(!pc.has_prediction());
        let x = filled(0.5, &[1, 32, 8, 8]);
        let (gated, surprise) = pc.forward(&x);
        assert_eq!(gated.shape(), vec![1, 32, 8, 8]);
        assert_eq!(surprise.shape(), vec![1, 1, 8, 8]);
        assert!(pc.has_prediction());
        // First frame: surprise map is all ones.
        let values = surprise.data().to_vec();
        assert!(values.iter().all(|&v| (v - 1.0).abs() < 1e-5));
    }

    #[test]
    fn test_predictive_coding_identical_frames() {
        let mut pc = PredictiveCodingModule::new(16);
        let x = filled(0.3, &[1, 16, 4, 4]);
        pc.forward(&x);
        let (_gated, surprise) = pc.forward(&x);
        assert!(mean_of(&surprise) <= 1.0);
    }

    #[test]
    fn test_predictive_coding_changed_features() {
        let mut pc = PredictiveCodingModule::new(8);
        let x1 = filled(0.0, &[1, 8, 4, 4]);
        let x2 = filled(5.0, &[1, 8, 4, 4]);
        pc.forward(&x1);
        let (_gated, surprise) = pc.forward(&x2);
        let avg_surprise = mean_of(&surprise);
        assert!(
            avg_surprise > 0.3,
            "Expected high surprise, got {avg_surprise}"
        );
    }

    #[test]
    fn test_predictive_coding_output_finite() {
        let mut pc = PredictiveCodingModule::new(16);
        let x = filled(0.5, &[1, 16, 8, 8]);
        pc.forward(&x);
        let (gated, surprise) = pc.forward(&x);
        assert!(gated.data().to_vec().iter().all(|v| v.is_finite()));
        assert!(surprise.data().to_vec().iter().all(|v| v.is_finite()));
    }

    #[test]
    fn test_predictive_coding_reset() {
        let mut pc = PredictiveCodingModule::new(8);
        let x = filled(0.1, &[1, 8, 4, 4]);
        pc.forward(&x);
        assert!(pc.has_prediction());
        pc.reset();
        assert!(!pc.has_prediction());
    }

    #[test]
    fn test_multi_scale_predictive_coding() {
        let mut mspc = MultiScalePredictiveCoding::new(96);
        let f1 = filled(0.1, &[1, 96, 40, 40]);
        let f2 = filled(0.1, &[1, 96, 20, 20]);
        let f3 = filled(0.1, &[1, 96, 10, 10]);
        let ((g1, s1), (g2, s2), (g3, s3)) = mspc.forward(&f1, &f2, &f3);
        assert_eq!(g1.shape(), vec![1, 96, 40, 40]);
        assert_eq!(s1.shape(), vec![1, 1, 40, 40]);
        assert_eq!(g2.shape(), vec![1, 96, 20, 20]);
        assert_eq!(s2.shape(), vec![1, 1, 20, 20]);
        assert_eq!(g3.shape(), vec![1, 96, 10, 10]);
        assert_eq!(s3.shape(), vec![1, 1, 10, 10]);
    }
}