use ghostflow_core::Tensor;
use crate::deep::layers::{Conv2d, Dense, BatchNorm2d};
use crate::deep::activations::ReLU;
pub struct CLIP {
image_encoder: CLIPImageEncoder,
text_encoder: CLIPTextEncoder,
temperature: f32,
}
struct CLIPImageEncoder {
conv_layers: Vec<Conv2d>,
bn_layers: Vec<BatchNorm2d>,
projection: Dense,
}
impl CLIPImageEncoder {
fn new(embed_dim: usize) -> Self {
CLIPImageEncoder {
conv_layers: vec![
Conv2d::new(3, 64, (7, 7)).stride((2, 2)).padding((3, 3)),
Conv2d::new(64, 128, (3, 3)).stride((2, 2)).padding((1, 1)),
Conv2d::new(128, 256, (3, 3)).stride((2, 2)).padding((1, 1)),
],
bn_layers: vec![
BatchNorm2d::new(64),
BatchNorm2d::new(128),
BatchNorm2d::new(256),
],
projection: Dense::new(256, embed_dim),
}
}
fn forward(&mut self, x: &Tensor, training: bool) -> Tensor {
let mut out = x.clone();
for (conv, bn) in self.conv_layers.iter_mut().zip(self.bn_layers.iter_mut()) {
out = conv.forward(&out, training);
out = bn.forward(&out, training);
out = ReLU::new().forward(&out);
}
out = self.global_avg_pool(&out);
self.projection.forward(&out, training)
}
fn global_avg_pool(&self, x: &Tensor) -> Tensor {
let dims = x.dims();
let batch = dims[0];
let channels = dims[1];
let height = dims[2];
let width = dims[3];
let data = x.data_f32();
let mut result = vec![0.0f32; batch * channels];
for b in 0..batch {
for c in 0..channels {
let mut sum = 0.0f32;
for h in 0..height {
for w in 0..width {
let idx = ((b * channels + c) * height + h) * width + w;
sum += data[idx];
}
}
result[b * channels + c] = sum / (height * width) as f32;
}
}
Tensor::from_slice(&result, &[batch, channels]).unwrap()
}
}
struct CLIPTextEncoder {
embedding: Dense,
transformer_layers: Vec<TransformerLayer>,
projection: Dense,
}
struct TransformerLayer {
attention: Dense,
ffn: Vec<Dense>,
}
impl TransformerLayer {
fn new(d_model: usize) -> Self {
TransformerLayer {
attention: Dense::new(d_model, d_model),
ffn: vec![
Dense::new(d_model, d_model * 4),
Dense::new(d_model * 4, d_model),
],
}
}
fn forward(&mut self, x: &Tensor, training: bool) -> Tensor {
let attn_out = self.attention.forward(x, training);
let mut ffn_out = attn_out;
for (i, layer) in self.ffn.iter_mut().enumerate() {
ffn_out = layer.forward(&ffn_out, training);
if i == 0 {
ffn_out = ReLU::new().forward(&ffn_out);
}
}
ffn_out
}
}
impl CLIPTextEncoder {
fn new(vocab_size: usize, d_model: usize, embed_dim: usize, num_layers: usize) -> Self {
CLIPTextEncoder {
embedding: Dense::new(vocab_size, d_model),
transformer_layers: (0..num_layers).map(|_| TransformerLayer::new(d_model)).collect(),
projection: Dense::new(d_model, embed_dim),
}
}
fn forward(&mut self, x: &Tensor, training: bool) -> Tensor {
let mut out = self.embedding.forward(x, training);
for layer in &mut self.transformer_layers {
out = layer.forward(&out, training);
}
self.projection.forward(&out, training)
}
}
impl CLIP {
pub fn new(vocab_size: usize, embed_dim: usize) -> Self {
CLIP {
image_encoder: CLIPImageEncoder::new(embed_dim),
text_encoder: CLIPTextEncoder::new(vocab_size, 512, embed_dim, 12),
temperature: 0.07,
}
}
pub fn forward(&mut self, images: &Tensor, texts: &Tensor, training: bool) -> (Tensor, Tensor) {
let image_features = self.image_encoder.forward(images, training);
let text_features = self.text_encoder.forward(texts, training);
let image_features_norm = self.l2_normalize(&image_features);
let text_features_norm = self.l2_normalize(&text_features);
(image_features_norm, text_features_norm)
}
fn l2_normalize(&self, x: &Tensor) -> Tensor {
let data = x.data_f32();
let batch_size = x.dims()[0];
let dim = x.dims()[1];
let mut result = Vec::new();
for b in 0..batch_size {
let offset = b * dim;
let mut norm = 0.0f32;
for i in 0..dim {
norm += data[offset + i].powi(2);
}
norm = norm.sqrt().max(1e-8);
for i in 0..dim {
result.push(data[offset + i] / norm);
}
}
Tensor::from_slice(&result, x.dims()).unwrap()
}
}
pub struct ALIGN {
image_encoder: EfficientNetEncoder,
text_encoder: BERTEncoder,
temperature: f32,
}
struct EfficientNetEncoder {
blocks: Vec<MBConvBlock>,
projection: Dense,
}
struct MBConvBlock {
expand: Conv2d,
depthwise: Conv2d,
project: Conv2d,
}
impl MBConvBlock {
fn new(in_channels: usize, out_channels: usize) -> Self {
MBConvBlock {
expand: Conv2d::new(in_channels, in_channels * 6, (1, 1)),
depthwise: Conv2d::new(in_channels * 6, in_channels * 6, (3, 3)).padding((1, 1)),
project: Conv2d::new(in_channels * 6, out_channels, (1, 1)),
}
}
fn forward(&mut self, x: &Tensor, training: bool) -> Tensor {
let mut out = self.expand.forward(x, training);
out = ReLU::new().forward(&out);
out = self.depthwise.forward(&out, training);
out = ReLU::new().forward(&out);
self.project.forward(&out, training)
}
}
impl EfficientNetEncoder {
fn new(embed_dim: usize) -> Self {
EfficientNetEncoder {
blocks: vec![
MBConvBlock::new(3, 32),
MBConvBlock::new(32, 64),
MBConvBlock::new(64, 128),
],
projection: Dense::new(128, embed_dim),
}
}
fn forward(&mut self, x: &Tensor, training: bool) -> Tensor {
let mut out = x.clone();
for block in &mut self.blocks {
out = block.forward(&out, training);
}
self.projection.forward(&out, training)
}
}
struct BERTEncoder {
embedding: Dense,
layers: Vec<TransformerLayer>,
projection: Dense,
}
impl BERTEncoder {
fn new(vocab_size: usize, d_model: usize, embed_dim: usize) -> Self {
BERTEncoder {
embedding: Dense::new(vocab_size, d_model),
layers: (0..12).map(|_| TransformerLayer::new(d_model)).collect(),
projection: Dense::new(d_model, embed_dim),
}
}
fn forward(&mut self, x: &Tensor, training: bool) -> Tensor {
let mut out = self.embedding.forward(x, training);
for layer in &mut self.layers {
out = layer.forward(&out, training);
}
self.projection.forward(&out, training)
}
}
impl ALIGN {
pub fn new(vocab_size: usize, embed_dim: usize) -> Self {
ALIGN {
image_encoder: EfficientNetEncoder::new(embed_dim),
text_encoder: BERTEncoder::new(vocab_size, 768, embed_dim),
temperature: 0.07,
}
}
pub fn forward(&mut self, images: &Tensor, texts: &Tensor, training: bool) -> (Tensor, Tensor) {
let image_features = self.image_encoder.forward(images, training);
let text_features = self.text_encoder.forward(texts, training);
(image_features, text_features)
}
}
pub struct ViLT {
patch_embedding: PatchEmbedding,
text_embedding: Dense,
transformer: Vec<TransformerLayer>,
mlm_head: Dense,
itm_head: Dense,
}
struct PatchEmbedding {
projection: Conv2d,
}
impl PatchEmbedding {
fn new(patch_size: usize, embed_dim: usize) -> Self {
PatchEmbedding {
projection: Conv2d::new(3, embed_dim, (patch_size, patch_size)).stride((patch_size, patch_size)),
}
}
fn forward(&mut self, x: &Tensor, training: bool) -> Tensor {
self.projection.forward(x, training)
}
}
impl ViLT {
pub fn new(vocab_size: usize, patch_size: usize, embed_dim: usize) -> Self {
ViLT {
patch_embedding: PatchEmbedding::new(patch_size, embed_dim),
text_embedding: Dense::new(vocab_size, embed_dim),
transformer: (0..12).map(|_| TransformerLayer::new(embed_dim)).collect(),
mlm_head: Dense::new(embed_dim, vocab_size),
itm_head: Dense::new(embed_dim, 2),
}
}
pub fn forward(&mut self, images: &Tensor, texts: &Tensor, training: bool) -> (Tensor, Tensor) {
let image_embeds = self.patch_embedding.forward(images, training);
let text_embeds = self.text_embedding.forward(texts, training);
let combined = self.concatenate(&image_embeds, &text_embeds);
let mut out = combined;
for layer in &mut self.transformer {
out = layer.forward(&out, training);
}
let mlm_logits = self.mlm_head.forward(&out, training);
let itm_logits = self.itm_head.forward(&out, training);
(mlm_logits, itm_logits)
}
fn concatenate(&self, a: &Tensor, b: &Tensor) -> Tensor {
a.clone() }
}
pub struct BLIP {
vision_encoder: VisionEncoder,
text_encoder: TextEncoder,
text_decoder: TextDecoder,
}
struct VisionEncoder {
conv_layers: Vec<Conv2d>,
projection: Dense,
}
impl VisionEncoder {
fn new(embed_dim: usize) -> Self {
VisionEncoder {
conv_layers: vec![
Conv2d::new(3, 64, (3, 3)).padding((1, 1)),
Conv2d::new(64, 128, (3, 3)).padding((1, 1)),
],
projection: Dense::new(128, embed_dim),
}
}
fn forward(&mut self, x: &Tensor, training: bool) -> Tensor {
let mut out = x.clone();
for conv in &mut self.conv_layers {
out = conv.forward(&out, training);
out = ReLU::new().forward(&out);
}
self.projection.forward(&out, training)
}
}
struct TextEncoder {
embedding: Dense,
layers: Vec<TransformerLayer>,
}
impl TextEncoder {
fn new(vocab_size: usize, d_model: usize) -> Self {
TextEncoder {
embedding: Dense::new(vocab_size, d_model),
layers: (0..6).map(|_| TransformerLayer::new(d_model)).collect(),
}
}
fn forward(&mut self, x: &Tensor, training: bool) -> Tensor {
let mut out = self.embedding.forward(x, training);
for layer in &mut self.layers {
out = layer.forward(&out, training);
}
out
}
}
struct TextDecoder {
embedding: Dense,
layers: Vec<TransformerLayer>,
lm_head: Dense,
}
impl TextDecoder {
fn new(vocab_size: usize, d_model: usize) -> Self {
TextDecoder {
embedding: Dense::new(vocab_size, d_model),
layers: (0..6).map(|_| TransformerLayer::new(d_model)).collect(),
lm_head: Dense::new(d_model, vocab_size),
}
}
fn forward(&mut self, x: &Tensor, training: bool) -> Tensor {
let mut out = self.embedding.forward(x, training);
for layer in &mut self.layers {
out = layer.forward(&out, training);
}
self.lm_head.forward(&out, training)
}
}
impl BLIP {
pub fn new(vocab_size: usize, embed_dim: usize) -> Self {
BLIP {
vision_encoder: VisionEncoder::new(embed_dim),
text_encoder: TextEncoder::new(vocab_size, embed_dim),
text_decoder: TextDecoder::new(vocab_size, embed_dim),
}
}
pub fn forward(&mut self, images: &Tensor, texts: &Tensor, training: bool) -> (Tensor, Tensor) {
let image_features = self.vision_encoder.forward(images, training);
let text_features = self.text_encoder.forward(texts, training);
let decoded_text = self.text_decoder.forward(texts, training);
(text_features, decoded_text)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_clip() {
let mut clip = CLIP::new(50000, 512);
let images = Tensor::from_slice(&vec![0.5f32; 2 * 3 * 224 * 224], &[2, 3, 224, 224]).unwrap();
let texts = Tensor::from_slice(&vec![1.0f32; 2 * 77], &[2, 77]).unwrap();
let (img_feat, txt_feat) = clip.forward(&images, &texts, false);
assert_eq!(img_feat.dims()[1], 512);
assert_eq!(txt_feat.dims()[1], 512);
}
#[test]
fn test_vilt() {
let mut vilt = ViLT::new(30000, 16, 768);
let images = Tensor::from_slice(&vec![0.5f32; 1 * 3 * 224 * 224], &[1, 3, 224, 224]).unwrap();
let texts = Tensor::from_slice(&vec![1.0f32; 1 * 40], &[1, 40]).unwrap();
let (mlm, itm) = vilt.forward(&images, &texts, false);
assert_eq!(mlm.dims()[0], 1);
}
}