ghostflow_nn/clip.rs

//! CLIP (Contrastive Language-Image Pre-training)
//!
//! Implements the CLIP architecture:
//! - Vision Transformer for image encoding
//! - Transformer for text encoding
//! - Contrastive learning objective
//! - Zero-shot classification
//! - Text-image similarity

use ghostflow_core::Tensor;
use crate::linear::Linear;
use crate::vision_transformer::{VisionTransformer, ViTConfig};
use crate::Module;

/// CLIP configuration
#[derive(Debug, Clone)]
pub struct CLIPConfig {
    /// Embedding dimension (shared between vision and text)
    pub embed_dim: usize,
    /// Vision config
    pub vision_config: CLIPVisionConfig,
    /// Text config
    pub text_config: CLIPTextConfig,
    /// Initial value of the logit scale (stored in log space; exponentiated when the model is built)
    pub logit_scale_init_value: f32,
}

/// CLIP Vision configuration
#[derive(Debug, Clone)]
pub struct CLIPVisionConfig {
    /// Image size
    pub image_size: usize,
    /// Patch size
    pub patch_size: usize,
    /// Hidden size
    pub hidden_size: usize,
    /// Number of layers
    pub num_layers: usize,
    /// Number of attention heads
    pub num_heads: usize,
    /// MLP expansion ratio (mlp_dim = hidden_size * mlp_ratio)
    pub mlp_ratio: usize,
}

/// CLIP Text configuration
#[derive(Debug, Clone)]
pub struct CLIPTextConfig {
    /// Vocabulary size
    pub vocab_size: usize,
    /// Hidden size
    pub hidden_size: usize,
    /// Number of layers
    pub num_layers: usize,
    /// Number of attention heads
    pub num_heads: usize,
    /// Maximum sequence length
    pub max_position_embeddings: usize,
}

impl Default for CLIPConfig {
    fn default() -> Self {
        CLIPConfig {
            embed_dim: 512,
            vision_config: CLIPVisionConfig::default(),
            text_config: CLIPTextConfig::default(),
            logit_scale_init_value: 2.6592, // ln(1/0.07)
        }
    }
}

impl Default for CLIPVisionConfig {
    fn default() -> Self {
        CLIPVisionConfig {
            image_size: 224,
            patch_size: 16,
            hidden_size: 768,
            num_layers: 12,
            num_heads: 12,
            mlp_ratio: 4,
        }
    }
}

impl Default for CLIPTextConfig {
    fn default() -> Self {
        CLIPTextConfig {
            vocab_size: 49408,
            hidden_size: 512,
            num_layers: 12,
            num_heads: 8,
            max_position_embeddings: 77,
        }
    }
}

impl CLIPConfig {
    /// CLIP ViT-B/32
    pub fn vit_b_32() -> Self {
        CLIPConfig {
            embed_dim: 512,
            vision_config: CLIPVisionConfig {
                image_size: 224,
                patch_size: 32,
                hidden_size: 768,
                num_layers: 12,
                num_heads: 12,
                mlp_ratio: 4,
            },
            text_config: CLIPTextConfig {
                vocab_size: 49408,
                hidden_size: 512,
                num_layers: 12,
                num_heads: 8,
                max_position_embeddings: 77,
            },
            logit_scale_init_value: 2.6592,
        }
    }

    /// CLIP ViT-B/16
    pub fn vit_b_16() -> Self {
        CLIPConfig {
            embed_dim: 512,
            vision_config: CLIPVisionConfig {
                image_size: 224,
                patch_size: 16,
                hidden_size: 768,
                num_layers: 12,
                num_heads: 12,
                mlp_ratio: 4,
            },
            text_config: CLIPTextConfig::default(),
            logit_scale_init_value: 2.6592,
        }
    }

    /// CLIP ViT-L/14
    pub fn vit_l_14() -> Self {
        CLIPConfig {
            embed_dim: 768,
            vision_config: CLIPVisionConfig {
                image_size: 224,
                patch_size: 14,
                hidden_size: 1024,
                num_layers: 24,
                num_heads: 16,
                mlp_ratio: 4,
            },
            text_config: CLIPTextConfig {
                vocab_size: 49408,
                hidden_size: 768,
                num_layers: 12,
                num_heads: 12,
                max_position_embeddings: 77,
            },
            logit_scale_init_value: 2.6592,
        }
    }
}
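
// Worked example (illustrative only, not used elsewhere in this file): the vision
// transformer sees (image_size / patch_size)^2 patches plus one CLS token, so
// ViT-B/32 processes (224 / 32)^2 + 1 = 50 tokens per image and ViT-L/14 processes
// (224 / 14)^2 + 1 = 257.
#[allow(dead_code)]
fn vision_sequence_length(config: &CLIPVisionConfig) -> usize {
    let patches_per_side = config.image_size / config.patch_size;
    patches_per_side * patches_per_side + 1
}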

/// CLIP Vision Encoder (using Vision Transformer)
pub struct CLIPVisionEncoder {
    vit: VisionTransformer,
    projection: Linear,
}

impl CLIPVisionEncoder {
    /// Create new vision encoder
    pub fn new(config: &CLIPVisionConfig, embed_dim: usize) -> Self {
        // Convert to ViT config
        let vit_config = ViTConfig {
            image_size: config.image_size,
            patch_size: config.patch_size,
            in_channels: 3,
            embed_dim: config.hidden_size,
            num_layers: config.num_layers,
            num_heads: config.num_heads,
            mlp_dim: config.hidden_size * config.mlp_ratio,
            num_classes: 0, // No classification head
            dropout: 0.0,
        };

        let vit = VisionTransformer::new(vit_config);
        let projection = Linear::new(config.hidden_size, embed_dim);

        CLIPVisionEncoder { vit, projection }
    }

    /// Encode images
    pub fn forward(&self, images: &Tensor) -> Result<Tensor, String> {
        // Get ViT features (CLS token)
        let features = self.vit.forward(images)?;

        // Project to shared embedding space
        Ok(self.projection.forward(&features))
    }
}

/// CLIP Text Encoder
pub struct CLIPTextEncoder {
    token_embedding: Tensor,
    position_embedding: Tensor,
    layers: Vec<CLIPTextLayer>,
    ln_final: LayerNorm,
    projection: Linear,
}

impl CLIPTextEncoder {
    /// Create new text encoder
    pub fn new(config: &CLIPTextConfig, embed_dim: usize) -> Self {
        let token_embedding = Tensor::randn(&[config.vocab_size, config.hidden_size]);
        let position_embedding = Tensor::randn(&[config.max_position_embeddings, config.hidden_size]);

        let layers = (0..config.num_layers)
            .map(|_| CLIPTextLayer::new(config.hidden_size, config.num_heads))
            .collect();

        let ln_final = LayerNorm::new(config.hidden_size, 1e-5);
        let projection = Linear::new(config.hidden_size, embed_dim);

        CLIPTextEncoder {
            token_embedding,
            position_embedding,
            layers,
            ln_final,
            projection,
        }
    }

    /// Encode text
    pub fn forward(&self, input_ids: &Tensor) -> Result<Tensor, String> {
        let dims = input_ids.dims();
        let seq_length = dims[1];

        // Get token embeddings
        let mut hidden_states = self.get_token_embeddings(input_ids)?;

        // Add position embeddings
        hidden_states = self.add_position_embeddings(&hidden_states, seq_length)?;

        // Pass through transformer layers
        for layer in &self.layers {
            hidden_states = layer.forward(&hidden_states)?;
        }

        // Final layer norm
        hidden_states = self.ln_final.forward(&hidden_states)?;

        // Extract features at the last token position (the reference CLIP takes the
        // EOS/EOT token position from input_ids; this simplification assumes the
        // sequence ends with it)
        let features = self.extract_eos_features(&hidden_states, seq_length)?;

        // Project to shared embedding space
        Ok(self.projection.forward(&features))
    }

    fn get_token_embeddings(&self, input_ids: &Tensor) -> Result<Tensor, String> {
        let ids_data = input_ids.data_f32();
        let embed_data = self.token_embedding.data_f32();
        let dims = input_ids.dims();
        let batch_size = dims[0];
        let seq_length = dims[1];
        let hidden_size = self.token_embedding.dims()[1];

        let mut result = Vec::with_capacity(batch_size * seq_length * hidden_size);

        // Look up the embedding row for each token id
        for &id in ids_data.iter() {
            let idx = id as usize;
            let start = idx * hidden_size;
            let end = start + hidden_size;
            result.extend_from_slice(&embed_data[start..end]);
        }

        Tensor::from_slice(&result, &[batch_size, seq_length, hidden_size])
            .map_err(|e| format!("Failed to create embeddings: {:?}", e))
    }

    fn add_position_embeddings(&self, hidden_states: &Tensor, seq_length: usize) -> Result<Tensor, String> {
        let pos_embed_data = self.position_embedding.data_f32();
        let hidden_data = hidden_states.data_f32();
        let dims = hidden_states.dims();
        let hidden_size = dims[2];

        let mut result = Vec::with_capacity(hidden_data.len());

        for i in 0..hidden_data.len() {
            let pos = (i / hidden_size) % seq_length;
            let pos_idx = pos * hidden_size + (i % hidden_size);
            result.push(hidden_data[i] + pos_embed_data[pos_idx]);
        }

        Tensor::from_slice(&result, dims)
            .map_err(|e| format!("Failed to add position embeddings: {:?}", e))
    }

    fn extract_eos_features(&self, hidden_states: &Tensor, seq_length: usize) -> Result<Tensor, String> {
        let data = hidden_states.data_f32();
        let dims = hidden_states.dims();
        let batch_size = dims[0];
        let hidden_size = dims[2];

        let mut result = Vec::with_capacity(batch_size * hidden_size);

        // Extract the last token for each batch element
        for b in 0..batch_size {
            let start = (b * seq_length + seq_length - 1) * hidden_size;
            let end = start + hidden_size;
            result.extend_from_slice(&data[start..end]);
        }

        Tensor::from_slice(&result, &[batch_size, hidden_size])
            .map_err(|e| format!("Failed to extract EOS features: {:?}", e))
    }
}

/// CLIP Text Transformer Layer
pub struct CLIPTextLayer {
    self_attn: MultiHeadAttention,
    mlp: MLP,
    ln1: LayerNorm,
    ln2: LayerNorm,
}

impl CLIPTextLayer {
    fn new(hidden_size: usize, num_heads: usize) -> Self {
        CLIPTextLayer {
            self_attn: MultiHeadAttention::new(hidden_size, num_heads),
            mlp: MLP::new(hidden_size, hidden_size * 4),
            ln1: LayerNorm::new(hidden_size, 1e-5),
            ln2: LayerNorm::new(hidden_size, 1e-5),
        }
    }

    fn forward(&self, x: &Tensor) -> Result<Tensor, String> {
        // Pre-norm self-attention with residual connection
        // (falls back to the un-added tensor if the elementwise add fails)
        let residual = x.clone();
        let x = self.ln1.forward(x)?;
        let x = self.self_attn.forward(&x)?;
        let x = x.add(&residual).unwrap_or(x);

        // Pre-norm MLP with residual connection
        let residual = x.clone();
        let x = self.ln2.forward(&x)?;
        let x = self.mlp.forward(&x)?;
        let x = x.add(&residual).unwrap_or(x);

        Ok(x)
    }
}

/// Multi-Head Attention (simplified)
pub struct MultiHeadAttention {
    q_proj: Linear,
    _k_proj: Linear,
    _v_proj: Linear,
    out_proj: Linear,
    _num_heads: usize,
    _head_dim: usize,
}

impl MultiHeadAttention {
    fn new(hidden_size: usize, num_heads: usize) -> Self {
        let head_dim = hidden_size / num_heads;
        MultiHeadAttention {
            q_proj: Linear::new(hidden_size, hidden_size),
            _k_proj: Linear::new(hidden_size, hidden_size),
            _v_proj: Linear::new(hidden_size, hidden_size),
            out_proj: Linear::new(hidden_size, hidden_size),
            _num_heads: num_heads,
            _head_dim: head_dim,
        }
    }

    fn forward(&self, x: &Tensor) -> Result<Tensor, String> {
        let q = self.q_proj.forward(x);
        // Simplified attention: only the query and output projections are applied.
        // A real implementation would do full multi-head scaled dot-product attention
        // (see the illustrative sketch after this impl).
        Ok(self.out_proj.forward(&q))
    }
}
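
// Illustrative sketch (not wired into MultiHeadAttention above): full multi-head
// scaled dot-product attention over [batch, seq, hidden] tensors, assuming q, k and v
// were already produced by the q/k/v projections. The flat data layout follows the
// conventions used elsewhere in this file; a production version would also apply a
// causal mask over the text sequence.
#[allow(dead_code)]
fn scaled_dot_product_attention_sketch(
    q: &Tensor,
    k: &Tensor,
    v: &Tensor,
    num_heads: usize,
) -> Result<Tensor, String> {
    let dims = q.dims();
    let (batch, seq, hidden) = (dims[0], dims[1], dims[2]);
    let head_dim = hidden / num_heads;
    let scale = 1.0 / (head_dim as f32).sqrt();

    let (qd, kd, vd) = (q.data_f32(), k.data_f32(), v.data_f32());
    let mut out = vec![0.0f32; batch * seq * hidden];

    for b in 0..batch {
        for h in 0..num_heads {
            for i in 0..seq {
                // Attention scores of query position i against every key position
                let mut scores = vec![0.0f32; seq];
                for j in 0..seq {
                    let mut dot = 0.0;
                    for d in 0..head_dim {
                        let qi = (b * seq + i) * hidden + h * head_dim + d;
                        let kj = (b * seq + j) * hidden + h * head_dim + d;
                        dot += qd[qi] * kd[kj];
                    }
                    scores[j] = dot * scale;
                }
                // Numerically stable softmax over the key dimension
                let max = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
                let mut denom = 0.0;
                for s in scores.iter_mut() {
                    *s = (*s - max).exp();
                    denom += *s;
                }
                // Weighted sum of value vectors for this head
                for d in 0..head_dim {
                    let mut acc = 0.0;
                    for j in 0..seq {
                        let vj = (b * seq + j) * hidden + h * head_dim + d;
                        acc += scores[j] / denom * vd[vj];
                    }
                    out[(b * seq + i) * hidden + h * head_dim + d] = acc;
                }
            }
        }
    }

    Tensor::from_slice(&out, dims)
        .map_err(|e| format!("Failed to build attention output: {:?}", e))
}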

/// MLP (Feed-Forward Network)
pub struct MLP {
    fc1: Linear,
    fc2: Linear,
}

impl MLP {
    fn new(hidden_size: usize, intermediate_size: usize) -> Self {
        MLP {
            fc1: Linear::new(hidden_size, intermediate_size),
            fc2: Linear::new(intermediate_size, hidden_size),
        }
    }

    fn forward(&self, x: &Tensor) -> Result<Tensor, String> {
        let x = self.fc1.forward(x);
        let x = x.gelu();
        Ok(self.fc2.forward(&x))
    }
}

/// Layer Normalization
pub struct LayerNorm {
    weight: Tensor,
    bias: Tensor,
    eps: f32,
}

impl LayerNorm {
    fn new(hidden_size: usize, eps: f32) -> Self {
        LayerNorm {
            weight: Tensor::ones(&[hidden_size]),
            bias: Tensor::zeros(&[hidden_size]),
            eps,
        }
    }

    fn forward(&self, x: &Tensor) -> Result<Tensor, String> {
        let x_data = x.data_f32();
        let dims = x.dims();
        let hidden_size = dims[dims.len() - 1];
        let batch_seq = x_data.len() / hidden_size;

        let weight_data = self.weight.data_f32();
        let bias_data = self.bias.data_f32();
        let mut result = Vec::with_capacity(x_data.len());

        for i in 0..batch_seq {
            let start = i * hidden_size;
            let end = start + hidden_size;
            let slice = &x_data[start..end];

            // Compute mean and variance over the last dimension
            let mean: f32 = slice.iter().sum::<f32>() / hidden_size as f32;
            let variance: f32 = slice.iter()
                .map(|v| (v - mean).powi(2))
                .sum::<f32>() / hidden_size as f32;
            let std = (variance + self.eps).sqrt();

            // Normalize, then scale and shift
            for (j, &v) in slice.iter().enumerate() {
                result.push((v - mean) / std * weight_data[j] + bias_data[j]);
            }
        }

        Tensor::from_slice(&result, dims)
            .map_err(|e| format!("Failed to normalize: {:?}", e))
    }
}

/// CLIP Model
pub struct CLIP {
    vision_encoder: CLIPVisionEncoder,
    text_encoder: CLIPTextEncoder,
    logit_scale: f32,
}

impl CLIP {
    /// Create new CLIP model
    pub fn new(config: CLIPConfig) -> Self {
        let vision_encoder = CLIPVisionEncoder::new(&config.vision_config, config.embed_dim);
        let text_encoder = CLIPTextEncoder::new(&config.text_config, config.embed_dim);
        let logit_scale = config.logit_scale_init_value.exp();

        CLIP {
            vision_encoder,
            text_encoder,
            logit_scale,
        }
    }

    /// Encode images
    pub fn encode_image(&self, images: &Tensor) -> Result<Tensor, String> {
        let features = self.vision_encoder.forward(images)?;
        Ok(self.normalize(&features))
    }

    /// Encode text
    pub fn encode_text(&self, input_ids: &Tensor) -> Result<Tensor, String> {
        let features = self.text_encoder.forward(input_ids)?;
        Ok(self.normalize(&features))
    }

    /// Forward pass (compute similarity matrix)
    pub fn forward(&self, images: &Tensor, input_ids: &Tensor) -> Result<Tensor, String> {
        let image_features = self.encode_image(images)?;
        let text_features = self.encode_text(input_ids)?;

        // Compute cosine similarity
        self.compute_similarity(&image_features, &text_features)
    }

    /// Normalize features (L2 normalization)
    fn normalize(&self, x: &Tensor) -> Tensor {
        let data = x.data_f32();
        let dims = x.dims();
        let feature_dim = dims[dims.len() - 1];
        let batch_size = data.len() / feature_dim;

        let mut result = Vec::with_capacity(data.len());

        for i in 0..batch_size {
            let start = i * feature_dim;
            let end = start + feature_dim;
            let slice = &data[start..end];

            // Compute L2 norm
            let norm: f32 = slice.iter().map(|v| v * v).sum::<f32>().sqrt();
            let norm = norm.max(1e-8); // Avoid division by zero

            // Normalize
            for &v in slice.iter() {
                result.push(v / norm);
            }
        }

        Tensor::from_slice(&result, dims).unwrap_or_else(|_| x.clone())
    }

    /// Compute the image-text similarity matrix (logits)
    fn compute_similarity(&self, image_features: &Tensor, text_features: &Tensor) -> Result<Tensor, String> {
        let img_data = image_features.data_f32();
        let txt_data = text_features.data_f32();

        let img_dims = image_features.dims();
        let txt_dims = text_features.dims();

        let num_images = img_dims[0];
        let num_texts = txt_dims[0];
        let feature_dim = img_dims[1];

        let mut result = Vec::with_capacity(num_images * num_texts);

        // The dot product equals cosine similarity here, since the features are
        // L2-normalized; the learned temperature (logit_scale) then scales the logits
        for i in 0..num_images {
            for j in 0..num_texts {
                let mut dot_product = 0.0;
                for k in 0..feature_dim {
                    dot_product += img_data[i * feature_dim + k] * txt_data[j * feature_dim + k];
                }
                result.push(dot_product * self.logit_scale);
            }
        }

        Tensor::from_slice(&result, &[num_images, num_texts])
            .map_err(|e| format!("Failed to compute similarity: {:?}", e))
    }

    /// Zero-shot classification: return the index of the best-matching text prompt
    /// for each image
    pub fn zero_shot_classify(&self, images: &Tensor, text_prompts: &Tensor) -> Result<Vec<usize>, String> {
        let similarity = self.forward(images, text_prompts)?;
        let data = similarity.data_f32();
        let dims = similarity.dims();
        let num_images = dims[0];
        let num_classes = dims[1];

        let mut predictions = Vec::with_capacity(num_images);

        for i in 0..num_images {
            let start = i * num_classes;
            let end = start + num_classes;
            let scores = &data[start..end];

            let pred = scores.iter()
                .enumerate()
                .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
                .map(|(idx, _)| idx)
                .unwrap_or(0);

            predictions.push(pred);
        }

        Ok(predictions)
    }

    /// Image-text retrieval (find the best matching text for each image)
    pub fn image_to_text_retrieval(&self, images: &Tensor, texts: &Tensor) -> Result<Vec<usize>, String> {
        self.zero_shot_classify(images, texts)
    }

    /// Text-image retrieval (find the best matching image for each text)
    pub fn text_to_image_retrieval(&self, images: &Tensor, texts: &Tensor) -> Result<Vec<usize>, String> {
        let similarity = self.forward(images, texts)?;
        let data = similarity.data_f32();
        let dims = similarity.dims();
        let num_images = dims[0];
        let num_texts = dims[1];

        let mut predictions = Vec::with_capacity(num_texts);

        // Walk each column of the [num_images, num_texts] matrix: for each text,
        // find the best image
        for j in 0..num_texts {
            let mut best_idx = 0;
            let mut best_score = data[j];

            for i in 1..num_images {
                let score = data[i * num_texts + j];
                if score > best_score {
                    best_score = score;
                    best_idx = i;
                }
            }

            predictions.push(best_idx);
        }

        Ok(predictions)
    }
}
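
// The module header lists the contrastive learning objective, which is not implemented
// above. A minimal sketch of the symmetric cross-entropy (InfoNCE) loss is given below,
// assuming `logits` is the square [N, N] matrix returned by `CLIP::forward` for a batch
// of N matched image-text pairs (matching pairs on the diagonal). It is illustrative
// only and not used elsewhere in this file.
#[allow(dead_code)]
fn clip_contrastive_loss_sketch(logits: &Tensor) -> Result<f32, String> {
    let dims = logits.dims();
    let n = dims[0];
    if dims.len() != 2 || dims[1] != n {
        return Err(format!("Expected square logits, got {:?}", dims));
    }
    let data = logits.data_f32();

    // Cross-entropy of a score vector against the diagonal target index,
    // computed with a numerically stable log-sum-exp
    let cross_entropy = |scores: &[f32], target: usize| -> f32 {
        let max = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let log_sum_exp = scores.iter().map(|s| (s - max).exp()).sum::<f32>().ln() + max;
        log_sum_exp - scores[target]
    };

    let mut image_to_text = 0.0; // rows: each image scored against all texts
    let mut text_to_image = 0.0; // columns: each text scored against all images
    for i in 0..n {
        let row: Vec<f32> = (0..n).map(|j| data[i * n + j]).collect();
        let col: Vec<f32> = (0..n).map(|j| data[j * n + i]).collect();
        image_to_text += cross_entropy(&row, i);
        text_to_image += cross_entropy(&col, i);
    }

    // Average the two directions, then average over the batch
    Ok((image_to_text + text_to_image) / (2.0 * n as f32))
}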

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_clip_config() {
        let config = CLIPConfig::vit_b_32();
        assert_eq!(config.embed_dim, 512);
        assert_eq!(config.vision_config.patch_size, 32);

        let config = CLIPConfig::vit_l_14();
        assert_eq!(config.embed_dim, 768);
        assert_eq!(config.vision_config.num_layers, 24);
    }

    #[test]
    fn test_clip_vision_encoder() {
        let config = CLIPVisionConfig::default();
        let encoder = CLIPVisionEncoder::new(&config, 512);

        let images = Tensor::randn(&[2, 3, 224, 224]);
        let features = encoder.forward(&images).unwrap();

        assert_eq!(features.dims(), &[2, 512]);
    }

    #[test]
    fn test_clip_text_encoder() {
        let config = CLIPTextConfig::default();
        let encoder = CLIPTextEncoder::new(&config, 512);

        let input_ids = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let features = encoder.forward(&input_ids).unwrap();

        assert_eq!(features.dims(), &[2, 512]);
    }

    #[test]
    fn test_clip_model() {
        let config = CLIPConfig::vit_b_32();
        let model = CLIP::new(config);

        let images = Tensor::randn(&[2, 3, 224, 224]);
        let input_ids = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();

        let similarity = model.forward(&images, &input_ids).unwrap();
        assert_eq!(similarity.dims(), &[2, 2]); // 2 images x 2 texts
    }

    #[test]
    fn test_zero_shot_classification() {
        let config = CLIPConfig::vit_b_32();
        let model = CLIP::new(config);

        let images = Tensor::randn(&[3, 3, 224, 224]);
        let text_prompts = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[3, 2]).unwrap();

        let predictions = model.zero_shot_classify(&images, &text_prompts).unwrap();
        assert_eq!(predictions.len(), 3); // 3 images
    }

    #[test]
    fn test_layer_norm() {
        let ln = LayerNorm::new(128, 1e-5);
        let x = Tensor::randn(&[2, 4, 128]);
        let output = ln.forward(&x).unwrap();
        assert_eq!(output.dims(), &[2, 4, 128]);
    }
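
    // Additional check, mirroring the tests above: text-to-image retrieval should
    // return one image index per text, each pointing into the image batch.
    #[test]
    fn test_text_to_image_retrieval() {
        let model = CLIP::new(CLIPConfig::vit_b_32());

        let images = Tensor::randn(&[2, 3, 224, 224]);
        let texts = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[3, 2]).unwrap();

        let predictions = model.text_to_image_retrieval(&images, &texts).unwrap();
        assert_eq!(predictions.len(), 3); // one prediction per text
        assert!(predictions.iter().all(|&idx| idx < 2)); // indices stay within the 2 images
    }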
}