ghostflow_nn/
vision_transformer.rs

//! Vision Transformer (ViT) Implementation
//!
//! Implements the Vision Transformer described in "An Image is Worth 16x16
//! Words: Transformers for Image Recognition at Scale" (Dosovitskiy et al.,
//! ICLR 2021):
//! - Patch embedding
//! - Position embedding
//! - Transformer encoder blocks
//! - Classification head
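//!
//! # Example
//!
//! A minimal usage sketch. The `ghostflow_nn::vision_transformer` module path
//! is assumed from this crate's layout, and `Tensor::randn` is the constructor
//! used elsewhere in this file:
//!
//! ```no_run
//! use ghostflow_core::Tensor;
//! use ghostflow_nn::vision_transformer::{ViTConfig, VisionTransformer};
//!
//! // ViT-Base over a batch of one 224x224 RGB image.
//! let vit = VisionTransformer::new(ViTConfig::vit_base());
//! let images = Tensor::randn(&[1, 3, 224, 224]);
//! let logits = vit.forward(&images).unwrap();
//! assert_eq!(logits.dims(), &[1, 1000]);
//! ```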

use ghostflow_core::Tensor;
use crate::transformer::TransformerEncoder;
use crate::linear::Linear;
use crate::norm::LayerNorm;
use crate::Module;

/// Vision Transformer configuration
#[derive(Debug, Clone)]
pub struct ViTConfig {
    /// Image size (assumed square)
    pub image_size: usize,
    /// Patch size (assumed square)
    pub patch_size: usize,
    /// Number of input channels
    pub in_channels: usize,
    /// Embedding dimension
    pub embed_dim: usize,
    /// Number of transformer layers
    pub num_layers: usize,
    /// Number of attention heads
    pub num_heads: usize,
    /// MLP hidden dimension
    pub mlp_dim: usize,
    /// Number of output classes
    pub num_classes: usize,
    /// Dropout rate
    pub dropout: f32,
}

impl Default for ViTConfig {
    fn default() -> Self {
        ViTConfig {
            image_size: 224,
            patch_size: 16,
            in_channels: 3,
            embed_dim: 768,
            num_layers: 12,
            num_heads: 12,
            mlp_dim: 3072,
            num_classes: 1000,
            dropout: 0.1,
        }
    }
}

impl ViTConfig {
    /// ViT-Base configuration
    pub fn vit_base() -> Self {
        Self::default()
    }

    /// ViT-Large configuration
    pub fn vit_large() -> Self {
        ViTConfig {
            embed_dim: 1024,
            num_layers: 24,
            num_heads: 16,
            mlp_dim: 4096,
            ..Default::default()
        }
    }

    /// ViT-Huge configuration
    pub fn vit_huge() -> Self {
        ViTConfig {
            embed_dim: 1280,
            num_layers: 32,
            num_heads: 16,
            mlp_dim: 5120,
            ..Default::default()
        }
    }

    /// Get number of patches
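    /// (assumes `image_size` is divisible by `patch_size`).
    ///
    /// # Example
    ///
    /// A quick check of the arithmetic (module path assumed, as in the
    /// module-level example):
    ///
    /// ```no_run
    /// # use ghostflow_nn::vision_transformer::ViTConfig;
    /// let config = ViTConfig::vit_base();
    /// assert_eq!(config.num_patches(), 196); // (224 / 16)^2
    /// ```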
    pub fn num_patches(&self) -> usize {
        (self.image_size / self.patch_size) * (self.image_size / self.patch_size)
    }
}

/// Patch embedding layer
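///
/// Splits an image into non-overlapping `patch_size x patch_size` patches,
/// flattens each patch to `patch_size * patch_size * in_channels` values, and
/// linearly projects it to `embed_dim`.
///
/// # Example
///
/// A shape sketch mirroring the unit test below (module path assumed):
///
/// ```no_run
/// # use ghostflow_core::Tensor;
/// # use ghostflow_nn::vision_transformer::{PatchEmbedding, ViTConfig};
/// let config = ViTConfig { image_size: 32, patch_size: 8, embed_dim: 64, ..Default::default() };
/// let embed = PatchEmbedding::new(&config);
/// let images = Tensor::randn(&[2, 3, 32, 32]);
/// let patches = embed.forward(&images).unwrap();
/// assert_eq!(patches.dims(), &[2, 16, 64]); // (32/8)^2 = 16 patches
/// ```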
pub struct PatchEmbedding {
    /// Projection layer
    projection: Linear,
    /// Patch size
    patch_size: usize,
    /// Number of patches
    num_patches: usize,
}

impl PatchEmbedding {
    /// Create new patch embedding
    pub fn new(config: &ViTConfig) -> Self {
        let patch_dim = config.patch_size * config.patch_size * config.in_channels;
        let projection = Linear::new(patch_dim, config.embed_dim);

        PatchEmbedding {
            projection,
            patch_size: config.patch_size,
            num_patches: config.num_patches(),
        }
    }

    /// Extract patches from image
    fn extract_patches(&self, x: &Tensor) -> Result<Tensor, String> {
        // Input:  [batch, channels, height, width]
        // Output: [batch, num_patches, patch_dim]

        let dims = x.dims();
        if dims.len() != 4 {
            return Err(format!("Expected 4D input, got {}D", dims.len()));
        }

        let batch_size = dims[0];
        let channels = dims[1];
        let height = dims[2];
        let width = dims[3];

        // Reject inputs that don't tile evenly; the integer division below
        // would otherwise silently drop the rightmost and bottom pixels.
        if height % self.patch_size != 0 || width % self.patch_size != 0 {
            return Err(format!(
                "Image size {}x{} is not divisible by patch size {}",
                height, width, self.patch_size
            ));
        }

        let num_patches_h = height / self.patch_size;
        let num_patches_w = width / self.patch_size;
        let patch_dim = self.patch_size * self.patch_size * channels;

        // Extract patches in row-major patch order; within a patch, values
        // are laid out as [channel][row][col] to match `patch_dim` above.
        let x_data = x.data_f32();
        let mut patches = Vec::with_capacity(batch_size * num_patches_h * num_patches_w * patch_dim);

        for b in 0..batch_size {
            for ph in 0..num_patches_h {
                for pw in 0..num_patches_w {
                    // Extract single patch
                    for c in 0..channels {
                        for h in 0..self.patch_size {
                            for w in 0..self.patch_size {
                                let y = ph * self.patch_size + h;
                                let x_pos = pw * self.patch_size + w;
                                let idx = b * (channels * height * width) +
                                         c * (height * width) +
                                         y * width +
                                         x_pos;
                                patches.push(x_data[idx]);
                            }
                        }
                    }
                }
            }
        }

        Tensor::from_slice(&patches, &[batch_size, num_patches_h * num_patches_w, patch_dim])
            .map_err(|e| format!("Failed to create patches tensor: {:?}", e))
    }

    /// Forward pass
    pub fn forward(&self, x: &Tensor) -> Result<Tensor, String> {
        // Extract patches
        let patches = self.extract_patches(x)?;

        // Project patches to embedding dimension
        Ok(self.projection.forward(&patches))
    }
}

// Note: PatchEmbedding doesn't implement Module trait due to Result return type
// Use patch_embed.forward() directly instead

/// Vision Transformer model
pub struct VisionTransformer {
    /// Configuration
    config: ViTConfig,
    /// Patch embedding
    patch_embed: PatchEmbedding,
    /// Class token
    cls_token: Tensor,
    /// Position embedding
    pos_embed: Tensor,
    /// Transformer encoder
    encoder: TransformerEncoder,
    /// Layer normalization
    norm: LayerNorm,
    /// Classification head
    head: Linear,
}

impl VisionTransformer {
    /// Create new Vision Transformer
    pub fn new(config: ViTConfig) -> Self {
        let patch_embed = PatchEmbedding::new(&config);

        // Create class token [1, 1, embed_dim]
        let cls_token = Tensor::randn(&[1, 1, config.embed_dim]);

        // Create position embedding [1, num_patches + 1, embed_dim]
        let num_positions = config.num_patches() + 1; // +1 for class token
        let pos_embed = Tensor::randn(&[1, num_positions, config.embed_dim]);

        // Create transformer encoder
        let encoder = TransformerEncoder::new(
            config.embed_dim,
            config.num_heads,
            config.mlp_dim,
            config.num_layers,
            config.dropout,
        );

        // Layer norm
        let norm = LayerNorm::new(&[config.embed_dim]);

        // Classification head
        let head = Linear::new(config.embed_dim, config.num_classes);

        VisionTransformer {
            config,
            patch_embed,
            cls_token,
            pos_embed,
            encoder,
            norm,
            head,
        }
    }

    /// Forward pass
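    ///
    /// Shape flow: `[B, C, H, W]` -> patch embed `[B, N, D]` -> prepend class
    /// token `[B, N + 1, D]` -> encoder + norm -> class token `[B, D]` ->
    /// logits `[B, num_classes]`.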
    pub fn forward(&self, x: &Tensor) -> Result<Tensor, String> {
        let batch_size = x.dims()[0];

        // Patch embedding: [batch, num_patches, embed_dim]
        let x = self.patch_embed.forward(x)?;

        // Expand class token for batch: [batch, 1, embed_dim]
        let cls_tokens = self.expand_cls_token(batch_size)?;

        // Concatenate class token: [batch, num_patches + 1, embed_dim]
        let x = self.concat_cls_token(&x, &cls_tokens)?;

        // Add position embedding
        let x = self.add_position_embedding(&x)?;

        // Transformer encoder
        let x = self.encoder.forward(&x);

        // Layer norm
        let x = self.norm.forward(&x);

        // Extract class token: [batch, embed_dim]
        let cls_output = self.extract_cls_token(&x)?;

        // Classification head: [batch, num_classes]
        Ok(self.head.forward(&cls_output))
    }

    /// Expand class token for batch
    fn expand_cls_token(&self, batch_size: usize) -> Result<Tensor, String> {
        let cls_data = self.cls_token.data_f32();
        let embed_dim = self.config.embed_dim;

        let mut expanded = Vec::with_capacity(batch_size * embed_dim);
        for _ in 0..batch_size {
            expanded.extend_from_slice(&cls_data);
        }

        Tensor::from_slice(&expanded, &[batch_size, 1, embed_dim])
            .map_err(|e| format!("Failed to expand class token: {:?}", e))
    }

    /// Concatenate class token with patches
    fn concat_cls_token(&self, patches: &Tensor, cls_tokens: &Tensor) -> Result<Tensor, String> {
        let patches_data = patches.data_f32();
        let cls_data = cls_tokens.data_f32();

        let dims = patches.dims();
        let batch_size = dims[0];
        let num_patches = dims[1];
        let embed_dim = dims[2];

        let mut concatenated = Vec::with_capacity(batch_size * (num_patches + 1) * embed_dim);

        for b in 0..batch_size {
            // Add class token
            let cls_start = b * embed_dim;
            concatenated.extend_from_slice(&cls_data[cls_start..cls_start + embed_dim]);

            // Add patches
            let patch_start = b * num_patches * embed_dim;
            let patch_end = patch_start + num_patches * embed_dim;
            concatenated.extend_from_slice(&patches_data[patch_start..patch_end]);
        }

        Tensor::from_slice(&concatenated, &[batch_size, num_patches + 1, embed_dim])
            .map_err(|e| format!("Failed to concatenate tokens: {:?}", e))
    }

    /// Add position embedding
    fn add_position_embedding(&self, x: &Tensor) -> Result<Tensor, String> {
        let x_data = x.data_f32();
        let pos_data = self.pos_embed.data_f32();

        let dims = x.dims();
        let batch_size = dims[0];
        let seq_len = dims[1];
        let embed_dim = dims[2];

        // Guard against indexing past the learned positions, e.g. when the
        // input image size differs from the configured one.
        if seq_len * embed_dim > pos_data.len() {
            return Err(format!(
                "Sequence length {} exceeds the {} learned positions",
                seq_len,
                pos_data.len() / embed_dim
            ));
        }

        let mut result = Vec::with_capacity(x_data.len());

        // Broadcast the [1, seq_len, embed_dim] position embedding over the batch
        for b in 0..batch_size {
            for s in 0..seq_len {
                for d in 0..embed_dim {
                    let x_idx = b * seq_len * embed_dim + s * embed_dim + d;
                    let pos_idx = s * embed_dim + d;
                    result.push(x_data[x_idx] + pos_data[pos_idx]);
                }
            }
        }

        Tensor::from_slice(&result, &[batch_size, seq_len, embed_dim])
            .map_err(|e| format!("Failed to add position embedding: {:?}", e))
    }

    /// Extract class token from sequence
    fn extract_cls_token(&self, x: &Tensor) -> Result<Tensor, String> {
        let x_data = x.data_f32();
        let dims = x.dims();
        let batch_size = dims[0];
        let seq_len = dims[1];
        let embed_dim = dims[2];

        let mut cls_output = Vec::with_capacity(batch_size * embed_dim);

        // The class token was prepended at sequence position 0, so it is the
        // first embed_dim values of each batch element
        for b in 0..batch_size {
            let start = b * seq_len * embed_dim;
            let end = start + embed_dim;
            cls_output.extend_from_slice(&x_data[start..end]);
        }

        Tensor::from_slice(&cls_output, &[batch_size, embed_dim])
            .map_err(|e| format!("Failed to extract class token: {:?}", e))
    }
}

// Note: VisionTransformer doesn't implement Module trait due to Result return type
// Use vit.forward() directly instead

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_vit_config() {
        let config = ViTConfig::vit_base();
        assert_eq!(config.num_patches(), 196); // (224/16)^2

        let config = ViTConfig::vit_large();
        assert_eq!(config.embed_dim, 1024);

        let config = ViTConfig::vit_huge();
        assert_eq!(config.embed_dim, 1280);
    }

    #[test]
    fn test_patch_embedding() {
        let config = ViTConfig {
            image_size: 32,
            patch_size: 8,
            in_channels: 3,
            embed_dim: 64,
            ..Default::default()
        };

        let patch_embed = PatchEmbedding::new(&config);
        let input = Tensor::randn(&[2, 3, 32, 32]); // batch=2, channels=3, 32x32

        let output = patch_embed.forward(&input).unwrap();
        assert_eq!(output.dims(), &[2, 16, 64]); // 16 patches, 64 embed_dim
    }

    #[test]
    fn test_vision_transformer() {
        let config = ViTConfig {
            image_size: 32,
            patch_size: 8,
            in_channels: 3,
            embed_dim: 64,
            num_layers: 2,
            num_heads: 4,
            mlp_dim: 128,
            num_classes: 10,
            dropout: 0.1,
        };

        let vit = VisionTransformer::new(config);
        let input = Tensor::randn(&[2, 3, 32, 32]);

        let output = vit.forward(&input).unwrap();
        assert_eq!(output.dims(), &[2, 10]); // batch=2, num_classes=10
    }
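
    // A small additional check (relies on the divisibility validation in
    // `extract_patches` above): inputs that don't tile evenly into patches
    // should be rejected rather than silently truncated.
    #[test]
    fn test_patch_embedding_rejects_indivisible_input() {
        let config = ViTConfig {
            image_size: 32,
            patch_size: 8,
            in_channels: 3,
            embed_dim: 64,
            ..Default::default()
        };

        let patch_embed = PatchEmbedding::new(&config);
        let input = Tensor::randn(&[1, 3, 30, 30]); // 30 is not divisible by 8

        assert!(patch_embed.forward(&input).is_err());
    }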
}