use ghostflow_core::Tensor;
use crate::linear::Linear;
use crate::vision_transformer::{VisionTransformer, ViTConfig};
use crate::Module;

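/// Configuration for a full CLIP model: the shared embedding dimension, the
/// vision and text tower configs, and the initial value of the learned logit
/// scale (2.6592 ~ ln(1/0.07), as in the original CLIP paper).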
#[derive(Debug, Clone)]
pub struct CLIPConfig {
    pub embed_dim: usize,
    pub vision_config: CLIPVisionConfig,
    pub text_config: CLIPTextConfig,
    pub logit_scale_init_value: f32,
}

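/// Vision tower hyperparameters (a standard ViT: input resolution, patch
/// size, width, depth, attention heads, and MLP expansion ratio).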
#[derive(Debug, Clone)]
pub struct CLIPVisionConfig {
    pub image_size: usize,
    pub patch_size: usize,
    pub hidden_size: usize,
    pub num_layers: usize,
    pub num_heads: usize,
    pub mlp_ratio: usize,
}

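/// Text tower hyperparameters (vocabulary size, width, depth, attention
/// heads, and the maximum token sequence length).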
#[derive(Debug, Clone)]
pub struct CLIPTextConfig {
    pub vocab_size: usize,
    pub hidden_size: usize,
    pub num_layers: usize,
    pub num_heads: usize,
    pub max_position_embeddings: usize,
}

impl Default for CLIPConfig {
    fn default() -> Self {
        CLIPConfig {
            embed_dim: 512,
            vision_config: CLIPVisionConfig::default(),
            text_config: CLIPTextConfig::default(),
            logit_scale_init_value: 2.6592,
        }
    }
}

impl Default for CLIPVisionConfig {
    fn default() -> Self {
        CLIPVisionConfig {
            image_size: 224,
            patch_size: 16,
            hidden_size: 768,
            num_layers: 12,
            num_heads: 12,
            mlp_ratio: 4,
        }
    }
}

impl Default for CLIPTextConfig {
    fn default() -> Self {
        CLIPTextConfig {
            vocab_size: 49408,
            hidden_size: 512,
            num_layers: 12,
            num_heads: 8,
            max_position_embeddings: 77,
        }
    }
}

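// Preset configurations matching the published OpenAI CLIP variants
// (ViT-B/32, ViT-B/16, and ViT-L/14).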
impl CLIPConfig {
    pub fn vit_b_32() -> Self {
        CLIPConfig {
            embed_dim: 512,
            vision_config: CLIPVisionConfig {
                image_size: 224,
                patch_size: 32,
                hidden_size: 768,
                num_layers: 12,
                num_heads: 12,
                mlp_ratio: 4,
            },
            text_config: CLIPTextConfig {
                vocab_size: 49408,
                hidden_size: 512,
                num_layers: 12,
                num_heads: 8,
                max_position_embeddings: 77,
            },
            logit_scale_init_value: 2.6592,
        }
    }

    pub fn vit_b_16() -> Self {
        CLIPConfig {
            embed_dim: 512,
            vision_config: CLIPVisionConfig {
                image_size: 224,
                patch_size: 16,
                hidden_size: 768,
                num_layers: 12,
                num_heads: 12,
                mlp_ratio: 4,
            },
            text_config: CLIPTextConfig::default(),
            logit_scale_init_value: 2.6592,
        }
    }

    pub fn vit_l_14() -> Self {
        CLIPConfig {
            embed_dim: 768,
            vision_config: CLIPVisionConfig {
                image_size: 224,
                patch_size: 14,
                hidden_size: 1024,
                num_layers: 24,
                num_heads: 16,
                mlp_ratio: 4,
            },
            text_config: CLIPTextConfig {
                vocab_size: 49408,
                hidden_size: 768,
                num_layers: 12,
                num_heads: 12,
                max_position_embeddings: 77,
            },
            logit_scale_init_value: 2.6592,
        }
    }
}

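/// Image tower: a Vision Transformer backbone followed by a linear projection
/// from the vision `hidden_size` to the shared `embed_dim`.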
pub struct CLIPVisionEncoder {
    vit: VisionTransformer,
    projection: Linear,
}

impl CLIPVisionEncoder {
    pub fn new(config: &CLIPVisionConfig, embed_dim: usize) -> Self {
        let vit_config = ViTConfig {
            image_size: config.image_size,
            patch_size: config.patch_size,
            in_channels: 3,
            embed_dim: config.hidden_size,
            num_layers: config.num_layers,
            num_heads: config.num_heads,
            mlp_dim: config.hidden_size * config.mlp_ratio,
            num_classes: 0,
            dropout: 0.0,
        };

        let vit = VisionTransformer::new(vit_config);
        let projection = Linear::new(config.hidden_size, embed_dim);

        CLIPVisionEncoder { vit, projection }
    }

    pub fn forward(&self, images: &Tensor) -> Result<Tensor, String> {
        let features = self.vit.forward(images)?;

        Ok(self.projection.forward(&features))
    }
}

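/// Text tower: token + learned position embeddings, a stack of pre-norm
/// transformer layers, a final LayerNorm, and a linear projection into the
/// shared embedding space. Features are pooled from the last token position
/// (a simplification of CLIP's end-of-text pooling).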
pub struct CLIPTextEncoder {
    token_embedding: Tensor,
    position_embedding: Tensor,
    layers: Vec<CLIPTextLayer>,
    ln_final: LayerNorm,
    projection: Linear,
}

impl CLIPTextEncoder {
    pub fn new(config: &CLIPTextConfig, embed_dim: usize) -> Self {
        let token_embedding = Tensor::randn(&[config.vocab_size, config.hidden_size]);
        let position_embedding = Tensor::randn(&[config.max_position_embeddings, config.hidden_size]);

        let layers = (0..config.num_layers)
            .map(|_| CLIPTextLayer::new(config.hidden_size, config.num_heads))
            .collect();

        let ln_final = LayerNorm::new(config.hidden_size, 1e-5);
        let projection = Linear::new(config.hidden_size, embed_dim);

        CLIPTextEncoder {
            token_embedding,
            position_embedding,
            layers,
            ln_final,
            projection,
        }
    }

    pub fn forward(&self, input_ids: &Tensor) -> Result<Tensor, String> {
        let dims = input_ids.dims();
        let seq_length = dims[1];

        let mut hidden_states = self.get_token_embeddings(input_ids)?;

        hidden_states = self.add_position_embeddings(&hidden_states, seq_length)?;

        for layer in &self.layers {
            hidden_states = layer.forward(&hidden_states)?;
        }

        hidden_states = self.ln_final.forward(&hidden_states)?;

        let features = self.extract_eos_features(&hidden_states, seq_length)?;

        Ok(self.projection.forward(&features))
    }

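    /// Embedding lookup: gathers the `token_embedding` row for each token id,
    /// producing a `[batch, seq_len, hidden_size]` tensor.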
    fn get_token_embeddings(&self, input_ids: &Tensor) -> Result<Tensor, String> {
        let ids_data = input_ids.data_f32();
        let embed_data = self.token_embedding.data_f32();
        let dims = input_ids.dims();
        let batch_size = dims[0];
        let seq_length = dims[1];
        let hidden_size = self.token_embedding.dims()[1];

        let mut result = Vec::with_capacity(batch_size * seq_length * hidden_size);

        for &id in ids_data.iter() {
            let idx = id as usize;
            let start = idx * hidden_size;
            let end = start + hidden_size;
            result.extend_from_slice(&embed_data[start..end]);
        }

        Tensor::from_slice(&result, &[batch_size, seq_length, hidden_size])
            .map_err(|e| format!("Failed to create embeddings: {:?}", e))
    }

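    /// Adds the learned position embedding for each sequence position to the
    /// token embeddings (broadcast over the batch dimension).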
    fn add_position_embeddings(&self, hidden_states: &Tensor, seq_length: usize) -> Result<Tensor, String> {
        let pos_embed_data = self.position_embedding.data_f32();
        let hidden_data = hidden_states.data_f32();
        let dims = hidden_states.dims();
        let hidden_size = dims[2];

        let mut result = Vec::with_capacity(hidden_data.len());

        for i in 0..hidden_data.len() {
            let pos = (i / hidden_size) % seq_length;
            let pos_idx = pos * hidden_size + (i % hidden_size);
            result.push(hidden_data[i] + pos_embed_data[pos_idx]);
        }

        Tensor::from_slice(&result, dims)
            .map_err(|e| format!("Failed to add position embeddings: {:?}", e))
    }

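    /// Pools the feature vector at the last token position of each sequence.
    /// (The reference CLIP pools at the end-of-text token located via the
    /// argmax of the token ids; using the last position is a simplification.)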
    fn extract_eos_features(&self, hidden_states: &Tensor, seq_length: usize) -> Result<Tensor, String> {
        let data = hidden_states.data_f32();
        let dims = hidden_states.dims();
        let batch_size = dims[0];
        let hidden_size = dims[2];

        let mut result = Vec::with_capacity(batch_size * hidden_size);

        for b in 0..batch_size {
            let start = (b * seq_length + seq_length - 1) * hidden_size;
            let end = start + hidden_size;
            result.extend_from_slice(&data[start..end]);
        }

        Tensor::from_slice(&result, &[batch_size, hidden_size])
            .map_err(|e| format!("Failed to extract EOS features: {:?}", e))
    }
}

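/// A single pre-norm transformer block for the text tower:
/// `x = x + attn(ln1(x))` followed by `x = x + mlp(ln2(x))`.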
pub struct CLIPTextLayer {
    self_attn: MultiHeadAttention,
    mlp: MLP,
    ln1: LayerNorm,
    ln2: LayerNorm,
}

impl CLIPTextLayer {
    fn new(hidden_size: usize, num_heads: usize) -> Self {
        CLIPTextLayer {
            self_attn: MultiHeadAttention::new(hidden_size, num_heads),
            mlp: MLP::new(hidden_size, hidden_size * 4),
            ln1: LayerNorm::new(hidden_size, 1e-5),
            ln2: LayerNorm::new(hidden_size, 1e-5),
        }
    }

    fn forward(&self, x: &Tensor) -> Result<Tensor, String> {
        let residual = x.clone();
        let x = self.ln1.forward(x)?;
        let x = self.self_attn.forward(&x)?;
        let x = x.add(&residual).unwrap_or(x);

        let residual = x.clone();
        let x = self.ln2.forward(&x)?;
        let x = self.mlp.forward(&x)?;
        let x = x.add(&residual).unwrap_or(x);

        Ok(x)
    }
}

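/// Simplified attention block. Only the query and output projections are
/// applied in `forward`; the key/value projections and head splitting are
/// kept as placeholders (fields prefixed with `_`) for a full implementation.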
pub struct MultiHeadAttention {
    q_proj: Linear,
    _k_proj: Linear,
    _v_proj: Linear,
    out_proj: Linear,
    _num_heads: usize,
    _head_dim: usize,
}

impl MultiHeadAttention {
    fn new(hidden_size: usize, num_heads: usize) -> Self {
        let head_dim = hidden_size / num_heads;
        MultiHeadAttention {
            q_proj: Linear::new(hidden_size, hidden_size),
            _k_proj: Linear::new(hidden_size, hidden_size),
            _v_proj: Linear::new(hidden_size, hidden_size),
            out_proj: Linear::new(hidden_size, hidden_size),
            _num_heads: num_heads,
            _head_dim: head_dim,
        }
    }

    fn forward(&self, x: &Tensor) -> Result<Tensor, String> {
        let q = self.q_proj.forward(x);
        Ok(self.out_proj.forward(&q))
    }
}

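/// Two-layer feed-forward block with a GELU activation in between.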
pub struct MLP {
    fc1: Linear,
    fc2: Linear,
}

impl MLP {
    fn new(hidden_size: usize, intermediate_size: usize) -> Self {
        MLP {
            fc1: Linear::new(hidden_size, intermediate_size),
            fc2: Linear::new(intermediate_size, hidden_size),
        }
    }

    fn forward(&self, x: &Tensor) -> Result<Tensor, String> {
        let x = self.fc1.forward(x);
        let x = x.gelu();
        Ok(self.fc2.forward(&x))
    }
}

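/// Layer normalization over the last dimension, with learnable scale and
/// bias, computed directly on the raw f32 buffer.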
pub struct LayerNorm {
    weight: Tensor,
    bias: Tensor,
    eps: f32,
}

impl LayerNorm {
    fn new(hidden_size: usize, eps: f32) -> Self {
        LayerNorm {
            weight: Tensor::ones(&[hidden_size]),
            bias: Tensor::zeros(&[hidden_size]),
            eps,
        }
    }

    fn forward(&self, x: &Tensor) -> Result<Tensor, String> {
        let x_data = x.data_f32();
        let dims = x.dims();
        let hidden_size = dims[dims.len() - 1];
        let batch_seq = x_data.len() / hidden_size;

        let weight_data = self.weight.data_f32();
        let bias_data = self.bias.data_f32();
        let mut result = Vec::with_capacity(x_data.len());

        for i in 0..batch_seq {
            let start = i * hidden_size;
            let end = start + hidden_size;
            let slice = &x_data[start..end];

            let mean: f32 = slice.iter().sum::<f32>() / hidden_size as f32;
            let variance: f32 = slice.iter()
                .map(|x| (x - mean).powi(2))
                .sum::<f32>() / hidden_size as f32;
            let std = (variance + self.eps).sqrt();

            for (j, &x) in slice.iter().enumerate() {
                result.push((x - mean) / std * weight_data[j] + bias_data[j]);
            }
        }

        Tensor::from_slice(&result, dims)
            .map_err(|e| format!("Failed to normalize: {:?}", e))
    }
}

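/// Contrastive Language-Image Pre-training (CLIP): paired image and text
/// encoders that map into a shared embedding space, plus a learned
/// temperature (`logit_scale`) applied to the cosine-similarity logits.
///
/// Illustrative usage sketch (shapes follow the tests at the bottom of this
/// file; adjust paths to wherever this module is exported):
///
/// ```text
/// let model = CLIP::new(CLIPConfig::vit_b_32());
/// let images = Tensor::randn(&[2, 3, 224, 224]);           // NCHW pixel batch
/// let tokens = Tensor::from_slice(&[1.0, 2.0], &[1, 2])?;  // token ids
/// let logits = model.forward(&images, &tokens)?;           // shape [2, 1]
/// ```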
pub struct CLIP {
    vision_encoder: CLIPVisionEncoder,
    text_encoder: CLIPTextEncoder,
    logit_scale: f32,
}

impl CLIP {
    pub fn new(config: CLIPConfig) -> Self {
        let vision_encoder = CLIPVisionEncoder::new(&config.vision_config, config.embed_dim);
        let text_encoder = CLIPTextEncoder::new(&config.text_config, config.embed_dim);
        let logit_scale = config.logit_scale_init_value.exp();

        CLIP {
            vision_encoder,
            text_encoder,
            logit_scale,
        }
    }

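    /// Encodes an NCHW image batch into L2-normalized embeddings of shape
    /// `[batch, embed_dim]`.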
    pub fn encode_image(&self, images: &Tensor) -> Result<Tensor, String> {
        let features = self.vision_encoder.forward(images)?;
        Ok(self.normalize(&features))
    }

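    /// Encodes a `[batch, seq_len]` tensor of token ids into L2-normalized
    /// embeddings of shape `[batch, embed_dim]`.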
    pub fn encode_text(&self, input_ids: &Tensor) -> Result<Tensor, String> {
        let features = self.text_encoder.forward(input_ids)?;
        Ok(self.normalize(&features))
    }

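    /// Computes the scaled image-text similarity matrix of shape
    /// `[num_images, num_texts]`: cosine similarities multiplied by
    /// `logit_scale`.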
    pub fn forward(&self, images: &Tensor, input_ids: &Tensor) -> Result<Tensor, String> {
        let image_features = self.encode_image(images)?;
        let text_features = self.encode_text(input_ids)?;

        self.compute_similarity(&image_features, &text_features)
    }

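    /// L2-normalizes each feature vector (row) over the last dimension,
    /// clamping the norm at 1e-8 to avoid division by zero.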
    fn normalize(&self, x: &Tensor) -> Tensor {
        let data = x.data_f32();
        let dims = x.dims();
        let feature_dim = dims[dims.len() - 1];
        let batch_size = data.len() / feature_dim;

        let mut result = Vec::with_capacity(data.len());

        for i in 0..batch_size {
            let start = i * feature_dim;
            let end = start + feature_dim;
            let slice = &data[start..end];

            let norm: f32 = slice.iter().map(|x| x * x).sum::<f32>().sqrt();
            let norm = norm.max(1e-8);

            for &x in slice.iter() {
                result.push(x / norm);
            }
        }

        Tensor::from_slice(&result, dims).unwrap_or_else(|_| x.clone())
    }

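    /// Dense dot-product between every (image, text) embedding pair, scaled
    /// by `logit_scale`; returns a `[num_images, num_texts]` logits matrix.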
    fn compute_similarity(&self, image_features: &Tensor, text_features: &Tensor) -> Result<Tensor, String> {
        let img_data = image_features.data_f32();
        let txt_data = text_features.data_f32();

        let img_dims = image_features.dims();
        let txt_dims = text_features.dims();

        let num_images = img_dims[0];
        let num_texts = txt_dims[0];
        let feature_dim = img_dims[1];

        let mut result = Vec::with_capacity(num_images * num_texts);

        for i in 0..num_images {
            for j in 0..num_texts {
                let mut dot_product = 0.0;
                for k in 0..feature_dim {
                    dot_product += img_data[i * feature_dim + k] * txt_data[j * feature_dim + k];
                }
                result.push(dot_product * self.logit_scale);
            }
        }

        Tensor::from_slice(&result, &[num_images, num_texts])
            .map_err(|e| format!("Failed to compute similarity: {:?}", e))
    }

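    /// Zero-shot classification: for each image, returns the index of the
    /// text prompt with the highest similarity score.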
    pub fn zero_shot_classify(&self, images: &Tensor, text_prompts: &Tensor) -> Result<Vec<usize>, String> {
        let similarity = self.forward(images, text_prompts)?;
        let data = similarity.data_f32();
        let dims = similarity.dims();
        let num_images = dims[0];
        let num_classes = dims[1];

        let mut predictions = Vec::with_capacity(num_images);

        for i in 0..num_images {
            let start = i * num_classes;
            let end = start + num_classes;
            let scores = &data[start..end];

            let pred = scores.iter()
                .enumerate()
                .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
                .map(|(idx, _)| idx)
                .unwrap_or(0);

            predictions.push(pred);
        }

        Ok(predictions)
    }

    /// For each image, returns the index of the best-matching text
    /// (equivalent to zero-shot classification).
    pub fn image_to_text_retrieval(&self, images: &Tensor, texts: &Tensor) -> Result<Vec<usize>, String> {
        self.zero_shot_classify(images, texts)
    }

    /// For each text, returns the index of the best-matching image.
    pub fn text_to_image_retrieval(&self, images: &Tensor, texts: &Tensor) -> Result<Vec<usize>, String> {
        let similarity = self.forward(images, texts)?;
        let data = similarity.data_f32();
        let dims = similarity.dims();
        let num_images = dims[0];
        let num_texts = dims[1];

        let mut predictions = Vec::with_capacity(num_texts);

        for j in 0..num_texts {
            let mut best_idx = 0;
            let mut best_score = data[j];

            for i in 1..num_images {
                let score = data[i * num_texts + j];
                if score > best_score {
                    best_score = score;
                    best_idx = i;
                }
            }

            predictions.push(best_idx);
        }

        Ok(predictions)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_clip_config() {
        let config = CLIPConfig::vit_b_32();
        assert_eq!(config.embed_dim, 512);
        assert_eq!(config.vision_config.patch_size, 32);

        let config = CLIPConfig::vit_l_14();
        assert_eq!(config.embed_dim, 768);
        assert_eq!(config.vision_config.num_layers, 24);
    }

    #[test]
    fn test_clip_vision_encoder() {
        let config = CLIPVisionConfig::default();
        let encoder = CLIPVisionEncoder::new(&config, 512);

        let images = Tensor::randn(&[2, 3, 224, 224]);
        let features = encoder.forward(&images).unwrap();

        assert_eq!(features.dims(), &[2, 512]);
    }

    #[test]
    fn test_clip_text_encoder() {
        let config = CLIPTextConfig::default();
        let encoder = CLIPTextEncoder::new(&config, 512);

        let input_ids = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let features = encoder.forward(&input_ids).unwrap();

        assert_eq!(features.dims(), &[2, 512]);
    }

    #[test]
    fn test_clip_model() {
        let config = CLIPConfig::vit_b_32();
        let model = CLIP::new(config);

        let images = Tensor::randn(&[2, 3, 224, 224]);
        let input_ids = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();

        let similarity = model.forward(&images, &input_ids).unwrap();
        assert_eq!(similarity.dims(), &[2, 2]);
    }

    #[test]
    fn test_zero_shot_classification() {
        let config = CLIPConfig::vit_b_32();
        let model = CLIP::new(config);

        let images = Tensor::randn(&[3, 3, 224, 224]);
        let text_prompts = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[3, 2]).unwrap();

        let predictions = model.zero_shot_classify(&images, &text_prompts).unwrap();
        assert_eq!(predictions.len(), 3);
    }

    #[test]
    fn test_layer_norm() {
        let ln = LayerNorm::new(128, 1e-5);
        let x = Tensor::randn(&[2, 4, 128]);
        let output = ln.forward(&x).unwrap();
        assert_eq!(output.dims(), &[2, 4, 128]);
    }
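
    // Additional check (a sketch, using only the APIs defined in this file):
    // encode_image should return unit-norm rows, since `normalize` divides
    // each feature vector by its L2 norm.
    #[test]
    fn test_encode_image_is_normalized() {
        let model = CLIP::new(CLIPConfig::vit_b_32());

        let images = Tensor::randn(&[1, 3, 224, 224]);
        let features = model.encode_image(&images).unwrap();
        assert_eq!(features.dims(), &[1, 512]);

        let norm: f32 = features.data_f32().iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-3);
    }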
}