use ghostflow_core::Tensor;

use crate::transformer::TransformerEncoder;
use crate::linear::Linear;
use crate::norm::LayerNorm;
use crate::Module;
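
/// Configuration for a Vision Transformer (ViT) model.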
#[derive(Debug, Clone)]
pub struct ViTConfig {
    pub image_size: usize,
    pub patch_size: usize,
    pub in_channels: usize,
    pub embed_dim: usize,
    pub num_layers: usize,
    pub num_heads: usize,
    pub mlp_dim: usize,
    pub num_classes: usize,
    pub dropout: f32,
}

impl Default for ViTConfig {
    fn default() -> Self {
        ViTConfig {
            image_size: 224,
            patch_size: 16,
            in_channels: 3,
            embed_dim: 768,
            num_layers: 12,
            num_heads: 12,
            mlp_dim: 3072,
            num_classes: 1000,
            dropout: 0.1,
        }
    }
}

impl ViTConfig {
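    /// ViT-Base: 12 layers, 768-dim embeddings, 12 heads, 3072-dim MLP.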
    pub fn vit_base() -> Self {
        Self::default()
    }

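    /// ViT-Large: 24 layers, 1024-dim embeddings, 16 heads, 4096-dim MLP.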
    pub fn vit_large() -> Self {
        ViTConfig {
            embed_dim: 1024,
            num_layers: 24,
            num_heads: 16,
            mlp_dim: 4096,
            ..Default::default()
        }
    }

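    /// ViT-Huge: 32 layers, 1280-dim embeddings, 16 heads, 5120-dim MLP.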
    pub fn vit_huge() -> Self {
        ViTConfig {
            embed_dim: 1280,
            num_layers: 32,
            num_heads: 16,
            mlp_dim: 5120,
            ..Default::default()
        }
    }

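    /// Number of non-overlapping patches the image is split into,
    /// i.e. (image_size / patch_size)^2.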
    pub fn num_patches(&self) -> usize {
        (self.image_size / self.patch_size) * (self.image_size / self.patch_size)
    }
}

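/// Splits an image into fixed-size patches and linearly projects each
/// flattened patch into the embedding space.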
pub struct PatchEmbedding {
    projection: Linear,
    patch_size: usize,
    num_patches: usize,
}

impl PatchEmbedding {
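    /// Creates a patch embedding layer for the given configuration.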
    pub fn new(config: &ViTConfig) -> Self {
        let patch_dim = config.patch_size * config.patch_size * config.in_channels;
        let projection = Linear::new(patch_dim, config.embed_dim);

        PatchEmbedding {
            projection,
            patch_size: config.patch_size,
            num_patches: config.num_patches(),
        }
    }

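    /// Rearranges a `[batch, channels, height, width]` image into flattened
    /// patches of shape `[batch, num_patches, patch_size * patch_size * channels]`.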
    fn extract_patches(&self, x: &Tensor) -> Result<Tensor, String> {
        let dims = x.dims();
        if dims.len() != 4 {
            return Err(format!("Expected 4D input, got {}D", dims.len()));
        }

        let batch_size = dims[0];
        let channels = dims[1];
        let height = dims[2];
        let width = dims[3];

        let num_patches_h = height / self.patch_size;
        let num_patches_w = width / self.patch_size;
        let patch_dim = self.patch_size * self.patch_size * channels;

        let x_data = x.data_f32();
        let mut patches = Vec::with_capacity(batch_size * num_patches_h * num_patches_w * patch_dim);

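        // Copy pixels patch by patch; within each patch the layout is
        // channel-major, then row, then column, matching `patch_dim` above.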
        for b in 0..batch_size {
            for ph in 0..num_patches_h {
                for pw in 0..num_patches_w {
                    for c in 0..channels {
                        for h in 0..self.patch_size {
                            for w in 0..self.patch_size {
                                let y = ph * self.patch_size + h;
                                let x_pos = pw * self.patch_size + w;
                                let idx = b * (channels * height * width)
                                    + c * (height * width)
                                    + y * width
                                    + x_pos;
                                patches.push(x_data[idx]);
                            }
                        }
                    }
                }
            }
        }

        Tensor::from_slice(&patches, &[batch_size, num_patches_h * num_patches_w, patch_dim])
            .map_err(|e| format!("Failed to create patches tensor: {:?}", e))
    }

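    /// Embeds an image batch: extracts patches and projects them to
    /// `[batch, num_patches, embed_dim]`.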
    pub fn forward(&self, x: &Tensor) -> Result<Tensor, String> {
        let patches = self.extract_patches(x)?;

        Ok(self.projection.forward(&patches))
    }
}

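/// Vision Transformer (ViT) for image classification: patch embedding, a
/// learnable class token, learned position embeddings, a transformer encoder,
/// and a linear classification head.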
pub struct VisionTransformer {
    config: ViTConfig,
    patch_embed: PatchEmbedding,
    cls_token: Tensor,
    pos_embed: Tensor,
    encoder: TransformerEncoder,
    norm: LayerNorm,
    head: Linear,
}

impl VisionTransformer {
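    /// Builds a Vision Transformer from the given configuration with randomly
    /// initialized class token and position embeddings.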
    pub fn new(config: ViTConfig) -> Self {
        let patch_embed = PatchEmbedding::new(&config);

        let cls_token = Tensor::randn(&[1, 1, config.embed_dim]);

        let num_positions = config.num_patches() + 1;
        let pos_embed = Tensor::randn(&[1, num_positions, config.embed_dim]);

        let encoder = TransformerEncoder::new(
            config.embed_dim,
            config.num_heads,
            config.mlp_dim,
            config.num_layers,
            config.dropout,
        );

        let norm = LayerNorm::new(&[config.embed_dim]);

        let head = Linear::new(config.embed_dim, config.num_classes);

        VisionTransformer {
            config,
            patch_embed,
            cls_token,
            pos_embed,
            encoder,
            norm,
            head,
        }
    }

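    /// Forward pass: patch embedding -> prepend class token -> add position
    /// embeddings -> transformer encoder -> layer norm -> classification head
    /// on the class-token output. Returns `[batch, num_classes]` logits.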
    pub fn forward(&self, x: &Tensor) -> Result<Tensor, String> {
        let batch_size = x.dims()[0];

        let x = self.patch_embed.forward(x)?;

        let cls_tokens = self.expand_cls_token(batch_size)?;

        let x = self.concat_cls_token(&x, &cls_tokens)?;

        let x = self.add_position_embedding(&x)?;

        let x = self.encoder.forward(&x);

        let x = self.norm.forward(&x);

        let cls_output = self.extract_cls_token(&x)?;

        Ok(self.head.forward(&cls_output))
    }

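    /// Repeats the learned class token across the batch: `[batch, 1, embed_dim]`.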
    fn expand_cls_token(&self, batch_size: usize) -> Result<Tensor, String> {
        let cls_data = self.cls_token.data_f32();
        let embed_dim = self.config.embed_dim;

        let mut expanded = Vec::with_capacity(batch_size * embed_dim);
        for _ in 0..batch_size {
            expanded.extend_from_slice(&cls_data);
        }

        Tensor::from_slice(&expanded, &[batch_size, 1, embed_dim])
            .map_err(|e| format!("Failed to expand class token: {:?}", e))
    }

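    /// Prepends the class tokens to the patch embeddings, producing
    /// `[batch, num_patches + 1, embed_dim]`.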
    fn concat_cls_token(&self, patches: &Tensor, cls_tokens: &Tensor) -> Result<Tensor, String> {
        let patches_data = patches.data_f32();
        let cls_data = cls_tokens.data_f32();

        let dims = patches.dims();
        let batch_size = dims[0];
        let num_patches = dims[1];
        let embed_dim = dims[2];

        let mut concatenated = Vec::with_capacity(batch_size * (num_patches + 1) * embed_dim);

        for b in 0..batch_size {
            let cls_start = b * embed_dim;
            concatenated.extend_from_slice(&cls_data[cls_start..cls_start + embed_dim]);

            let patch_start = b * num_patches * embed_dim;
            let patch_end = patch_start + num_patches * embed_dim;
            concatenated.extend_from_slice(&patches_data[patch_start..patch_end]);
        }

        Tensor::from_slice(&concatenated, &[batch_size, num_patches + 1, embed_dim])
            .map_err(|e| format!("Failed to concatenate tokens: {:?}", e))
    }

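    /// Adds the learned position embedding to every token, broadcasting the
    /// single embedding table over the batch dimension.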
    fn add_position_embedding(&self, x: &Tensor) -> Result<Tensor, String> {
        let x_data = x.data_f32();
        let pos_data = self.pos_embed.data_f32();

        let dims = x.dims();
        let batch_size = dims[0];
        let seq_len = dims[1];
        let embed_dim = dims[2];

        let mut result = Vec::with_capacity(x_data.len());

        for b in 0..batch_size {
            for s in 0..seq_len {
                for d in 0..embed_dim {
                    let x_idx = b * seq_len * embed_dim + s * embed_dim + d;
                    let pos_idx = s * embed_dim + d;
                    result.push(x_data[x_idx] + pos_data[pos_idx]);
                }
            }
        }

        Tensor::from_slice(&result, &[batch_size, seq_len, embed_dim])
            .map_err(|e| format!("Failed to add position embedding: {:?}", e))
    }

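    /// Extracts the class-token representation (sequence position 0) for each
    /// batch element: `[batch, embed_dim]`.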
    fn extract_cls_token(&self, x: &Tensor) -> Result<Tensor, String> {
        let x_data = x.data_f32();
        let dims = x.dims();
        let batch_size = dims[0];
        let seq_len = dims[1];
        let embed_dim = dims[2];

        let mut cls_output = Vec::with_capacity(batch_size * embed_dim);

        for b in 0..batch_size {
            let start = b * seq_len * embed_dim;
            let end = start + embed_dim;
            cls_output.extend_from_slice(&x_data[start..end]);
        }

        Tensor::from_slice(&cls_output, &[batch_size, embed_dim])
            .map_err(|e| format!("Failed to extract class token: {:?}", e))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_vit_config() {
        let config = ViTConfig::vit_base();
        assert_eq!(config.num_patches(), 196);

        let config = ViTConfig::vit_large();
        assert_eq!(config.embed_dim, 1024);

        let config = ViTConfig::vit_huge();
        assert_eq!(config.embed_dim, 1280);
    }

    #[test]
    fn test_patch_embedding() {
        let config = ViTConfig {
            image_size: 32,
            patch_size: 8,
            in_channels: 3,
            embed_dim: 64,
            ..Default::default()
        };

        let patch_embed = PatchEmbedding::new(&config);
        let input = Tensor::randn(&[2, 3, 32, 32]);

        let output = patch_embed.forward(&input).unwrap();
        assert_eq!(output.dims(), &[2, 16, 64]);
    }

    #[test]
    fn test_vision_transformer() {
        let config = ViTConfig {
            image_size: 32,
            patch_size: 8,
            in_channels: 3,
            embed_dim: 64,
            num_layers: 2,
            num_heads: 4,
            mlp_dim: 128,
            num_classes: 10,
            dropout: 0.1,
        };

        let vit = VisionTransformer::new(config);
        let input = Tensor::randn(&[2, 3, 32, 32]);

        let output = vit.forward(&input).unwrap();
        assert_eq!(output.dims(), &[2, 10]);
    }
}