use ghostflow_core::Tensor;
use crate::deep::layers::{Conv2d, Dense, BatchNorm2d};
use crate::deep::activations::ReLU;
pub struct SimCLR {
encoder: ResNetEncoder,
projection_head: ProjectionHead,
temperature: f32,
}
struct ResNetEncoder {
conv_layers: Vec<Conv2d>,
bn_layers: Vec<BatchNorm2d>,
}
impl ResNetEncoder {
fn new() -> Self {
ResNetEncoder {
conv_layers: vec![
Conv2d::new(3, 64, (7, 7)).stride((2, 2)).padding((3, 3)),
Conv2d::new(64, 128, (3, 3)).stride((2, 2)).padding((1, 1)),
Conv2d::new(128, 256, (3, 3)).stride((2, 2)).padding((1, 1)),
Conv2d::new(256, 512, (3, 3)).stride((2, 2)).padding((1, 1)),
],
bn_layers: vec![
BatchNorm2d::new(64),
BatchNorm2d::new(128),
BatchNorm2d::new(256),
BatchNorm2d::new(512),
],
}
}
fn forward(&mut self, x: &Tensor, training: bool) -> Tensor {
let mut out = x.clone();
for (conv, bn) in self.conv_layers.iter_mut().zip(self.bn_layers.iter_mut()) {
out = conv.forward(&out, training);
out = bn.forward(&out, training);
out = ReLU::new().forward(&out);
}
self.global_avg_pool(&out)
}
fn global_avg_pool(&self, x: &Tensor) -> Tensor {
let dims = x.dims();
let batch = dims[0];
let channels = dims[1];
let height = dims[2];
let width = dims[3];
let data = x.data_f32();
let mut result = vec![0.0f32; batch * channels];
for b in 0..batch {
for c in 0..channels {
let mut sum = 0.0f32;
for h in 0..height {
for w in 0..width {
let idx = ((b * channels + c) * height + h) * width + w;
sum += data[idx];
}
}
result[b * channels + c] = sum / (height * width) as f32;
}
}
Tensor::from_slice(&result, &[batch, channels]).unwrap()
}
}
struct ProjectionHead {
fc1: Dense,
fc2: Dense,
}
impl ProjectionHead {
fn new(in_dim: usize, hidden_dim: usize, out_dim: usize) -> Self {
ProjectionHead {
fc1: Dense::new(in_dim, hidden_dim),
fc2: Dense::new(hidden_dim, out_dim),
}
}
fn forward(&mut self, x: &Tensor, training: bool) -> Tensor {
let mut out = self.fc1.forward(x, training);
out = ReLU::new().forward(&out);
self.fc2.forward(&out, training)
}
}
impl SimCLR {
pub fn new(projection_dim: usize) -> Self {
SimCLR {
encoder: ResNetEncoder::new(),
projection_head: ProjectionHead::new(512, 2048, projection_dim),
temperature: 0.5,
}
}
pub fn forward(&mut self, x1: &Tensor, x2: &Tensor, training: bool) -> (Tensor, Tensor) {
let h1 = self.encoder.forward(x1, training);
let h2 = self.encoder.forward(x2, training);
let z1 = self.projection_head.forward(&h1, training);
let z2 = self.projection_head.forward(&h2, training);
(z1, z2)
}
}
pub struct MoCo {
encoder_q: ResNetEncoder,
encoder_k: ResNetEncoder,
queue: Vec<Vec<f32>>,
queue_ptr: usize,
queue_size: usize,
momentum: f32,
}
impl MoCo {
pub fn new(queue_size: usize, momentum: f32) -> Self {
MoCo {
encoder_q: ResNetEncoder::new(),
encoder_k: ResNetEncoder::new(),
queue: vec![vec![0.0f32; 128]; queue_size],
queue_ptr: 0,
queue_size,
momentum,
}
}
pub fn forward(&mut self, x_q: &Tensor, x_k: &Tensor, training: bool) -> (Tensor, Tensor) {
let q = self.encoder_q.forward(x_q, training);
let k = self.encoder_k.forward(x_k, training);
(q, k)
}
pub fn update_queue(&mut self, keys: &Tensor) {
let batch_size = keys.dims()[0];
let key_data = keys.data_f32();
let key_dim = keys.dims()[1];
for b in 0..batch_size {
let key: Vec<f32> = (0..key_dim)
.map(|i| key_data[b * key_dim + i])
.collect();
self.queue[self.queue_ptr] = key;
self.queue_ptr = (self.queue_ptr + 1) % self.queue_size;
}
}
}
pub struct BYOL {
online_network: BYOLNetwork,
target_network: BYOLNetwork,
momentum: f32,
}
struct BYOLNetwork {
encoder: ResNetEncoder,
projector: ProjectionHead,
predictor: Dense,
}
impl BYOLNetwork {
fn new() -> Self {
BYOLNetwork {
encoder: ResNetEncoder::new(),
projector: ProjectionHead::new(512, 4096, 256),
predictor: Dense::new(256, 256),
}
}
fn forward(&mut self, x: &Tensor, training: bool, use_predictor: bool) -> Tensor {
let h = self.encoder.forward(x, training);
let z = self.projector.forward(&h, training);
if use_predictor {
self.predictor.forward(&z, training)
} else {
z
}
}
}
impl BYOL {
pub fn new(momentum: f32) -> Self {
BYOL {
online_network: BYOLNetwork::new(),
target_network: BYOLNetwork::new(),
momentum,
}
}
pub fn forward(&mut self, x1: &Tensor, x2: &Tensor, training: bool) -> (Tensor, Tensor, Tensor, Tensor) {
let p1 = self.online_network.forward(x1, training, true);
let p2 = self.online_network.forward(x2, training, true);
let z1 = self.target_network.forward(x1, false, false);
let z2 = self.target_network.forward(x2, false, false);
(p1, p2, z1, z2)
}
pub fn update_target_network(&mut self) {
}
}
pub struct SwAV {
encoder: ResNetEncoder,
projection_head: ProjectionHead,
prototypes: Dense,
num_prototypes: usize,
temperature: f32,
}
impl SwAV {
pub fn new(num_prototypes: usize, projection_dim: usize) -> Self {
SwAV {
encoder: ResNetEncoder::new(),
projection_head: ProjectionHead::new(512, 2048, projection_dim),
prototypes: Dense::new(projection_dim, num_prototypes),
num_prototypes,
temperature: 0.1,
}
}
pub fn forward(&mut self, x1: &Tensor, x2: &Tensor, training: bool) -> (Tensor, Tensor) {
let h1 = self.encoder.forward(x1, training);
let h2 = self.encoder.forward(x2, training);
let z1 = self.projection_head.forward(&h1, training);
let z2 = self.projection_head.forward(&h2, training);
let q1 = self.prototypes.forward(&z1, training);
let q2 = self.prototypes.forward(&z2, training);
(q1, q2)
}
pub fn sinkhorn(&self, q: &Tensor, num_iters: usize) -> Tensor {
let mut q_normalized = q.clone();
for _ in 0..num_iters {
q_normalized = self.normalize_rows(&q_normalized);
q_normalized = self.normalize_cols(&q_normalized);
}
q_normalized
}
fn normalize_rows(&self, x: &Tensor) -> Tensor {
let data = x.data_f32();
let rows = x.dims()[0];
let cols = x.dims()[1];
let mut result = Vec::new();
for i in 0..rows {
let offset = i * cols;
let mut sum = 0.0f32;
for j in 0..cols {
sum += data[offset + j];
}
for j in 0..cols {
result.push(data[offset + j] / sum.max(1e-8));
}
}
Tensor::from_slice(&result, x.dims()).unwrap()
}
fn normalize_cols(&self, x: &Tensor) -> Tensor {
let data = x.data_f32();
let rows = x.dims()[0];
let cols = x.dims()[1];
let mut col_sums = vec![0.0f32; cols];
for i in 0..rows {
for j in 0..cols {
col_sums[j] += data[i * cols + j];
}
}
let mut result = Vec::new();
for i in 0..rows {
for j in 0..cols {
result.push(data[i * cols + j] / col_sums[j].max(1e-8));
}
}
Tensor::from_slice(&result, x.dims()).unwrap()
}
}
pub struct DINO {
student: DINONetwork,
teacher: DINONetwork,
center: Vec<f32>,
momentum: f32,
}
struct DINONetwork {
backbone: ResNetEncoder,
head: ProjectionHead,
}
impl DINONetwork {
fn new(out_dim: usize) -> Self {
DINONetwork {
backbone: ResNetEncoder::new(),
head: ProjectionHead::new(512, 2048, out_dim),
}
}
fn forward(&mut self, x: &Tensor, training: bool) -> Tensor {
let features = self.backbone.forward(x, training);
self.head.forward(&features, training)
}
}
impl DINO {
pub fn new(out_dim: usize, momentum: f32) -> Self {
DINO {
student: DINONetwork::new(out_dim),
teacher: DINONetwork::new(out_dim),
center: vec![0.0f32; out_dim],
momentum,
}
}
pub fn forward(&mut self, x_student: &Tensor, x_teacher: &Tensor, training: bool) -> (Tensor, Tensor) {
let student_out = self.student.forward(x_student, training);
let teacher_out = self.teacher.forward(x_teacher, false);
let teacher_centered = self.center_output(&teacher_out);
(student_out, teacher_centered)
}
fn center_output(&self, x: &Tensor) -> Tensor {
let data = x.data_f32();
let batch_size = x.dims()[0];
let dim = x.dims()[1];
let mut result = Vec::new();
for b in 0..batch_size {
for d in 0..dim {
result.push(data[b * dim + d] - self.center[d]);
}
}
Tensor::from_slice(&result, x.dims()).unwrap()
}
pub fn update_center(&mut self, teacher_output: &Tensor) {
let data = teacher_output.data_f32();
let batch_size = teacher_output.dims()[0];
let dim = teacher_output.dims()[1];
let mut batch_mean = vec![0.0f32; dim];
for b in 0..batch_size {
for d in 0..dim {
batch_mean[d] += data[b * dim + d];
}
}
for d in 0..dim {
batch_mean[d] /= batch_size as f32;
self.center[d] = self.momentum * self.center[d] + (1.0 - self.momentum) * batch_mean[d];
}
}
}
pub struct MAE {
encoder: MAEEncoder,
decoder: MAEDecoder,
mask_ratio: f32,
}
struct MAEEncoder {
patch_embed: Conv2d,
transformer_blocks: Vec<TransformerBlock>,
}
struct MAEDecoder {
decoder_embed: Dense,
transformer_blocks: Vec<TransformerBlock>,
pred: Dense,
}
struct TransformerBlock {
attention: Dense,
mlp: Vec<Dense>,
}
impl TransformerBlock {
fn new(dim: usize) -> Self {
TransformerBlock {
attention: Dense::new(dim, dim),
mlp: vec![
Dense::new(dim, dim * 4),
Dense::new(dim * 4, dim),
],
}
}
fn forward(&mut self, x: &Tensor, training: bool) -> Tensor {
let attn_out = self.attention.forward(x, training);
let mut mlp_out = attn_out;
for (i, layer) in self.mlp.iter_mut().enumerate() {
mlp_out = layer.forward(&mlp_out, training);
if i == 0 {
mlp_out = ReLU::new().forward(&mlp_out);
}
}
mlp_out
}
}
impl MAEEncoder {
fn new(patch_size: usize, embed_dim: usize, depth: usize) -> Self {
MAEEncoder {
patch_embed: Conv2d::new(3, embed_dim, (patch_size, patch_size)).stride((patch_size, patch_size)),
transformer_blocks: (0..depth).map(|_| TransformerBlock::new(embed_dim)).collect(),
}
}
fn forward(&mut self, x: &Tensor, training: bool) -> Tensor {
let mut out = self.patch_embed.forward(x, training);
for block in &mut self.transformer_blocks {
out = block.forward(&out, training);
}
out
}
}
impl MAEDecoder {
fn new(encoder_dim: usize, decoder_dim: usize, patch_size: usize, depth: usize) -> Self {
MAEDecoder {
decoder_embed: Dense::new(encoder_dim, decoder_dim),
transformer_blocks: (0..depth).map(|_| TransformerBlock::new(decoder_dim)).collect(),
pred: Dense::new(decoder_dim, patch_size * patch_size * 3),
}
}
fn forward(&mut self, x: &Tensor, training: bool) -> Tensor {
let mut out = self.decoder_embed.forward(x, training);
for block in &mut self.transformer_blocks {
out = block.forward(&out, training);
}
self.pred.forward(&out, training)
}
}
impl MAE {
pub fn new(patch_size: usize, embed_dim: usize, mask_ratio: f32) -> Self {
MAE {
encoder: MAEEncoder::new(patch_size, embed_dim, 12),
decoder: MAEDecoder::new(embed_dim, 512, patch_size, 8),
mask_ratio,
}
}
pub fn forward(&mut self, x: &Tensor, training: bool) -> Tensor {
let encoded = self.encoder.forward(x, training);
self.decoder.forward(&encoded, training)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_simclr() {
let mut simclr = SimCLR::new(128);
let x1 = Tensor::from_slice(&vec![0.5f32; 2 * 3 * 224 * 224], &[2, 3, 224, 224]).unwrap();
let x2 = Tensor::from_slice(&vec![0.5f32; 2 * 3 * 224 * 224], &[2, 3, 224, 224]).unwrap();
let (z1, z2) = simclr.forward(&x1, &x2, false);
assert_eq!(z1.dims()[1], 128);
assert_eq!(z2.dims()[1], 128);
}
#[test]
fn test_byol() {
let mut byol = BYOL::new(0.996);
let x1 = Tensor::from_slice(&vec![0.5f32; 1 * 3 * 224 * 224], &[1, 3, 224, 224]).unwrap();
let x2 = Tensor::from_slice(&vec![0.5f32; 1 * 3 * 224 * 224], &[1, 3, 224, 224]).unwrap();
let (p1, p2, z1, z2) = byol.forward(&x1, &x2, false);
assert_eq!(p1.dims()[1], 256);
}
}