use crate::{
    cross_modal_embeddings::{
        AudioData, ImageData, Modality, ModalityData, MultiModalContent, VideoData,
    },
    Vector,
};
use anyhow::{anyhow, Result};
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JointEmbeddingConfig {
    /// Dimensionality of the shared joint embedding space.
    pub joint_dim: usize,
    /// Temperature for scaling contrastive similarity scores.
    pub temperature: f32,
    /// Base learning rate for alignment training.
    pub learning_rate: f32,
    /// Margin applied to negative-pair scores.
    pub margin: f32,
    /// Enable contrastive learning.
    pub contrastive_learning: bool,
    /// Enable triplet loss.
    pub triplet_loss: bool,
    /// Enable hard negative mining.
    pub hard_negative_mining: bool,
    /// Training batch size.
    pub batch_size: usize,
    /// Negative samples drawn per positive pair.
    pub negative_samples: usize,
    /// Enable curriculum learning.
    pub curriculum_learning: bool,
    /// L2 weight decay.
    pub weight_decay: f32,
    /// Gradient clipping threshold.
    pub gradient_clip: f32,
    /// Enable domain adaptation of embeddings.
    pub domain_adaptation: bool,
    /// Strength of cross-modal alignment and adaptation.
    pub alignment_strength: f32,
    /// Enable self-supervised objectives.
    pub self_supervised: bool,
}

impl Default for JointEmbeddingConfig {
    fn default() -> Self {
        Self {
            joint_dim: 512,
            temperature: 0.07,
            learning_rate: 1e-4,
            margin: 0.2,
            contrastive_learning: true,
            triplet_loss: false,
            hard_negative_mining: true,
            batch_size: 256,
            negative_samples: 5,
            curriculum_learning: false,
            weight_decay: 1e-4,
            gradient_clip: 1.0,
            domain_adaptation: true,
            alignment_strength: 1.0,
            self_supervised: false,
        }
    }
}
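
// A minimal sketch of overriding selected fields while keeping the remaining
// defaults (plain struct-update syntax; the values here are illustrative):
//
//     let config = JointEmbeddingConfig {
//         joint_dim: 256,
//         negative_samples: 10,
//         ..Default::default()
//     };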

/// (positive_pairs, negative_pairs) gathered for contrastive alignment.
type ContrastivePairs = (
    Vec<(Modality, Vector, Modality, Vector)>,
    Vec<(Modality, Vector, Modality, Vector)>,
);

/// Shared joint embedding space with per-modality projectors.
pub struct JointEmbeddingSpace {
    config: JointEmbeddingConfig,
    text_projector: LinearProjector,
    image_projector: LinearProjector,
    audio_projector: LinearProjector,
    video_projector: LinearProjector,
    attention_mechanism: CrossModalAttention,
    alignment_cache: Arc<RwLock<HashMap<String, AlignmentPair>>>,
    training_stats: Arc<RwLock<TrainingStatistics>>,
    temperature_scheduler: TemperatureScheduler,
    domain_adapter: DomainAdapter,
}

#[derive(Debug, Clone)]
pub struct LinearProjector {
    weights: Vec<Vec<f32>>,
    bias: Vec<f32>,
    input_dim: usize,
    output_dim: usize,
    dropout_rate: f32,
    activation: ActivationFunction,
}

#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum ActivationFunction {
    ReLU,
    GELU,
    Tanh,
    Sigmoid,
    Swish,
    Mish,
    LeakyReLU(f32),
}

#[derive(Debug, Clone)]
pub struct CrossModalAttention {
    query_projector: LinearProjector,
    key_projector: LinearProjector,
    value_projector: LinearProjector,
    output_projector: LinearProjector,
    num_heads: usize,
    head_dim: usize,
    dropout_rate: f32,
    scale: f32,
    enable_relative_pos: bool,
}

#[derive(Debug, Clone)]
pub struct AlignmentPair {
    modality1: Modality,
    modality2: Modality,
    embedding1: Vector,
    embedding2: Vector,
    similarity: f32,
    confidence: f32,
    timestamp: std::time::SystemTime,
}

#[derive(Debug, Clone, Default)]
pub struct TrainingStatistics {
    total_samples: u64,
    positive_pairs: u64,
    negative_pairs: u64,
    average_loss: f32,
    average_similarity: f32,
    convergence_rate: f32,
    alignment_accuracy: f32,
    cross_modal_retrieval_acc: HashMap<(Modality, Modality), f32>,
    training_epochs: u32,
    last_improvement: u32,
}

#[derive(Debug, Clone)]
pub struct TemperatureScheduler {
    initial_temperature: f32,
    final_temperature: f32,
    decay_steps: usize,
    current_step: usize,
    schedule_type: ScheduleType,
}

#[derive(Debug, Clone, Copy)]
pub enum ScheduleType {
    Linear,
    Exponential,
    Cosine,
    Warmup,
}

#[derive(Debug, Clone)]
pub struct DomainAdapter {
    source_stats: DomainStatistics,
    target_stats: DomainStatistics,
    adaptation_weights: Vec<f32>,
    domain_classifier: Option<DomainClassifier>,
    adaptation_strength: f32,
}

#[derive(Debug, Clone, Default)]
pub struct DomainStatistics {
    mean: Vec<f32>,
    variance: Vec<f32>,
    sample_count: usize,
    feature_statistics: HashMap<String, f32>,
}

#[derive(Debug, Clone)]
pub struct DomainClassifier {
    weights: Vec<Vec<f32>>,
    bias: Vec<f32>,
    accuracy: f32,
}

pub struct CLIPAligner {
    joint_space: JointEmbeddingSpace,
    optimizer: ContrastiveOptimizer,
    data_augmentation: DataAugmentation,
    curriculum: CurriculumLearning,
}

#[derive(Debug, Clone)]
pub struct ContrastiveOptimizer {
    learning_rate: f32,
    momentum: f32,
    weight_decay: f32,
    gradient_history: HashMap<String, Vec<f32>>,
    adaptive_lr: bool,
    lr_schedule: LearningRateSchedule,
}

#[derive(Debug, Clone, Copy)]
pub enum LearningRateSchedule {
    Constant,
    StepDecay { step_size: usize, gamma: f32 },
    ExponentialDecay { gamma: f32 },
    CosineAnnealing { min_lr: f32, max_epochs: usize },
}

#[derive(Debug, Clone)]
pub struct DataAugmentation {
    text_augmentations: Vec<TextAugmentation>,
    image_augmentations: Vec<ImageAugmentation>,
    audio_augmentations: Vec<AudioAugmentation>,
    cross_modal_mixup: bool,
    augmentation_probability: f32,
}

#[derive(Debug, Clone)]
pub enum TextAugmentation {
    RandomWordDropout(f32),
    Paraphrasing,
    BackTranslation,
    SynonymReplacement(f32),
    ContextualAugmentation,
}

#[derive(Debug, Clone)]
pub enum ImageAugmentation {
    RandomCrop { size: (u32, u32) },
    RandomFlip { horizontal: bool, vertical: bool },
    ColorJitter { brightness: f32, contrast: f32, saturation: f32 },
    RandomRotation { max_angle: f32 },
    GaussianBlur { sigma: f32 },
}

#[derive(Debug, Clone)]
pub enum AudioAugmentation {
    TimeStretch { factor: f32 },
    PitchShift { semitones: f32 },
    AddNoise { snr_db: f32 },
    FrequencyMasking { max_freq_mask: f32 },
    TimeMasking { max_time_mask: f32 },
}

#[derive(Debug, Clone)]
pub struct CurriculumLearning {
    enabled: bool,
    current_difficulty: f32,
    difficulty_schedule: DifficultySchedule,
    pacing_function: PacingFunction,
    competence_threshold: f32,
}

#[derive(Debug, Clone)]
pub enum DifficultySchedule {
    Linear { start: f32, end: f32, epochs: usize },
    Exponential { base: f32, scale: f32 },
    Adaptive { improvement_threshold: f32 },
}

#[derive(Debug, Clone)]
pub enum PacingFunction {
    Root,
    Linear,
    Logarithmic,
    Polynomial(f32),
}
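
// Sketch of how a `DifficultySchedule::Linear` could be evaluated; note that
// `CurriculumLearning::update_difficulty` below uses an adaptive rule
// instead, so this is only illustrative:
//
//     difficulty(epoch) = start + (end - start) * (epoch / epochs)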

impl LinearProjector {
    pub fn new(
        input_dim: usize,
        output_dim: usize,
        dropout_rate: f32,
        activation: ActivationFunction,
    ) -> Self {
        // Xavier/Glorot-style uniform bound: sqrt(6 / (fan_in + fan_out)).
        let limit = (6.0 / (input_dim + output_dim) as f32).sqrt();
        let mut weights = Vec::with_capacity(output_dim);

        for _ in 0..output_dim {
            let mut row = Vec::with_capacity(input_dim);
            for _ in 0..input_dim {
                // Deterministic placeholder in lieu of a random draw,
                // kept within [-limit, limit].
                let weight = ((row.len() as f32 * 0.01) % 2.0 - 1.0) * limit;
                row.push(weight);
            }
            weights.push(row);
        }

        let bias = vec![0.0; output_dim];

        Self {
            weights,
            bias,
            input_dim,
            output_dim,
            dropout_rate,
            activation,
        }
    }

    pub fn forward(&self, input: &Vector) -> Result<Vector> {
        if input.dimensions != self.input_dim {
            return Err(anyhow!(
                "Input dimension mismatch: expected {}, got {}",
                self.input_dim,
                input.dimensions
            ));
        }

        let input_values = input.as_f32();
        let mut output = vec![0.0; self.output_dim];

        // Affine transform: output = W * input + b.
        for (i, output_val) in output.iter_mut().enumerate().take(self.output_dim) {
            let mut sum = self.bias[i];
            for (j, &input_val) in input_values.iter().enumerate().take(self.input_dim) {
                sum += input_val * self.weights[i][j];
            }
            *output_val = sum;
        }

        for value in &mut output {
            *value = self.apply_activation(*value);
        }

        // Deterministic stand-in for dropout with inverted scaling;
        // a trained deployment would randomize this (and skip it at inference).
        if self.dropout_rate > 0.0 {
            for (i, value) in output.iter_mut().enumerate() {
                if (i as f32 * 0.12345) % 1.0 < self.dropout_rate {
                    *value = 0.0;
                } else {
                    *value /= 1.0 - self.dropout_rate;
                }
            }
        }

        Ok(Vector::new(output))
    }

    fn apply_activation(&self, x: f32) -> f32 {
        match self.activation {
            ActivationFunction::ReLU => x.max(0.0),
            ActivationFunction::GELU => {
                // tanh approximation of GELU.
                let sqrt_2_pi = (2.0 / std::f32::consts::PI).sqrt();
                let inner = sqrt_2_pi * (x + 0.044715 * x.powi(3));
                0.5 * x * (1.0 + inner.tanh())
            }
            ActivationFunction::Tanh => x.tanh(),
            ActivationFunction::Sigmoid => 1.0 / (1.0 + (-x).exp()),
            ActivationFunction::Swish => x * (1.0 / (1.0 + (-x).exp())),
            ActivationFunction::Mish => x * (1.0 + x.exp()).ln().tanh(),
            ActivationFunction::LeakyReLU(alpha) => {
                if x > 0.0 {
                    x
                } else {
                    alpha * x
                }
            }
        }
    }

    pub fn update_weights(&mut self, gradients: &[Vec<f32>], learning_rate: f32) {
        for i in 0..self.output_dim {
            for j in 0..self.input_dim {
                if i < gradients.len() && j < gradients[i].len() {
                    self.weights[i][j] -= learning_rate * gradients[i][j];
                }
            }
        }
    }
}
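
// Usage sketch for the projector (dims are illustrative; `Vector::new` and
// `dimensions` are the same APIs used throughout this module):
//
//     let proj = LinearProjector::new(4, 2, 0.0, ActivationFunction::ReLU);
//     let out = proj.forward(&Vector::new(vec![1.0, 0.5, -0.5, 0.0]))?;
//     assert_eq!(out.dimensions, 2);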

impl CrossModalAttention {
    pub fn new(
        input_dim: usize,
        num_heads: usize,
        dropout_rate: f32,
        enable_relative_pos: bool,
    ) -> Self {
        debug_assert_eq!(
            input_dim % num_heads,
            0,
            "input_dim must be divisible by num_heads"
        );
        let head_dim = input_dim / num_heads;
        let scale = 1.0 / (head_dim as f32).sqrt();

        Self {
            query_projector: LinearProjector::new(
                input_dim,
                input_dim,
                dropout_rate,
                ActivationFunction::ReLU,
            ),
            key_projector: LinearProjector::new(
                input_dim,
                input_dim,
                dropout_rate,
                ActivationFunction::ReLU,
            ),
            value_projector: LinearProjector::new(
                input_dim,
                input_dim,
                dropout_rate,
                ActivationFunction::ReLU,
            ),
            output_projector: LinearProjector::new(
                input_dim,
                input_dim,
                dropout_rate,
                ActivationFunction::ReLU,
            ),
            num_heads,
            head_dim,
            dropout_rate,
            scale,
            enable_relative_pos,
        }
    }

    pub fn cross_attention(
        &self,
        query_modality: &Vector,
        key_modality: &Vector,
        value_modality: &Vector,
    ) -> Result<Vector> {
        let query = self.query_projector.forward(query_modality)?;
        let key = self.key_projector.forward(key_modality)?;
        let value = self.value_projector.forward(value_modality)?;

        let attended = self.multi_head_attention(&query, &key, &value)?;

        self.output_projector.forward(&attended)
    }

    fn multi_head_attention(&self, query: &Vector, key: &Vector, value: &Vector) -> Result<Vector> {
        let query_vals = query.as_f32();
        let key_vals = key.as_f32();
        let value_vals = value.as_f32();

        if query_vals.len() != key_vals.len() || key_vals.len() != value_vals.len() {
            return Err(anyhow!("Dimension mismatch in attention"));
        }

        let _seq_len = query_vals.len() / self.head_dim;
        let mut output = vec![0.0; query_vals.len()];

        // Each head attends over its own slice of the embedding.
        for head in 0..self.num_heads {
            let head_start = head * self.head_dim;
            let head_end = head_start + self.head_dim;

            let head_query = &query_vals[head_start..head_end];
            let head_key = &key_vals[head_start..head_end];
            let head_value = &value_vals[head_start..head_end];

            let attention_score = self.compute_attention_score(head_query, head_key);

            for i in 0..self.head_dim {
                output[head_start + i] = head_value[i] * attention_score;
            }
        }

        if self.enable_relative_pos {
            self.apply_relative_position_encoding(&mut output)?;
        }

        Ok(Vector::new(output))
    }

    fn compute_attention_score(&self, query: &[f32], key: &[f32]) -> f32 {
        let dot_product: f32 = query.iter().zip(key.iter()).map(|(q, k)| q * k).sum();
        let scaled_score = dot_product * self.scale;

        // tanh bounds the score to [-1, 1]; with a single key/value pair
        // there is no sequence to softmax over.
        scaled_score.tanh()
    }

    fn apply_relative_position_encoding(&self, output: &mut [f32]) -> Result<()> {
        let output_len = output.len();
        for (i, value) in output.iter_mut().enumerate() {
            let pos_encoding = (i as f32 / output_len as f32).sin();
            *value += 0.1 * pos_encoding;
        }
        Ok(())
    }
}
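
// Usage sketch: let a text embedding attend to an image embedding (both must
// share the joint dimension, which must be divisible by `num_heads`):
//
//     let attn = CrossModalAttention::new(512, 8, 0.0, false);
//     let attended = attn.cross_attention(&text_joint, &image_joint, &image_joint)?;
//     assert_eq!(attended.dimensions, 512);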

impl TemperatureScheduler {
    pub fn new(
        initial_temperature: f32,
        final_temperature: f32,
        decay_steps: usize,
        schedule_type: ScheduleType,
    ) -> Self {
        Self {
            initial_temperature,
            final_temperature,
            decay_steps,
            current_step: 0,
            schedule_type,
        }
    }

    pub fn get_current_temperature(&self) -> f32 {
        if self.current_step >= self.decay_steps {
            return self.final_temperature;
        }

        let progress = self.current_step as f32 / self.decay_steps as f32;

        match self.schedule_type {
            ScheduleType::Linear => {
                self.initial_temperature
                    + (self.final_temperature - self.initial_temperature) * progress
            }
            ScheduleType::Exponential => {
                self.initial_temperature
                    * (self.final_temperature / self.initial_temperature).powf(progress)
            }
            ScheduleType::Cosine => {
                let cosine_progress = 0.5 * (1.0 + (std::f32::consts::PI * progress).cos());
                self.final_temperature
                    + (self.initial_temperature - self.final_temperature) * cosine_progress
            }
            ScheduleType::Warmup => {
                // Linear warmup over the first 10% of steps, then linear decay.
                if progress < 0.1 {
                    self.initial_temperature * (progress / 0.1)
                } else {
                    let decay_progress = (progress - 0.1) / 0.9;
                    self.initial_temperature
                        + (self.final_temperature - self.initial_temperature) * decay_progress
                }
            }
        }
    }

    pub fn step(&mut self) {
        self.current_step += 1;
    }
}
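
// Worked example of the cosine schedule with the values used by
// `JointEmbeddingSpace::new` (0.14 -> 0.07 over 1000 steps):
//
//     T(step) = 0.07 + (0.14 - 0.07) * 0.5 * (1 + cos(pi * step / 1000))
//     T(0) = 0.14, T(500) = 0.105, T(1000) = 0.07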

impl DomainAdapter {
    pub fn new(adaptation_strength: f32) -> Self {
        Self {
            source_stats: DomainStatistics::default(),
            target_stats: DomainStatistics::default(),
            adaptation_weights: Vec::new(),
            domain_classifier: None,
            adaptation_strength,
        }
    }

    pub fn adapt_embedding(&self, embedding: &Vector, is_source_domain: bool) -> Result<Vector> {
        let input_values = embedding.as_f32();
        let mut adapted_values = input_values.clone();

        // Without fitted weights there is nothing to adapt yet.
        if self.adaptation_weights.len() != input_values.len() {
            return Ok(embedding.clone());
        }

        let stats = if is_source_domain {
            &self.source_stats
        } else {
            &self.target_stats
        };

        for (i, adapted_value) in adapted_values.iter_mut().enumerate() {
            if i < stats.mean.len() && i < stats.variance.len() {
                // Standardize, then blend the adapted value with the original.
                let normalized =
                    (*adapted_value - stats.mean[i]) / (stats.variance[i].sqrt() + 1e-8);

                *adapted_value = normalized * self.adaptation_weights[i] * self.adaptation_strength
                    + *adapted_value * (1.0 - self.adaptation_strength);
            }
        }

        Ok(Vector::new(adapted_values))
    }

    pub fn update_domain_statistics(&mut self, embeddings: &[Vector], is_source_domain: bool) {
        let stats = if is_source_domain {
            &mut self.source_stats
        } else {
            &mut self.target_stats
        };

        if embeddings.is_empty() {
            return;
        }

        let dim = embeddings[0].dimensions;
        if stats.mean.len() != dim {
            stats.mean = vec![0.0; dim];
            stats.variance = vec![0.0; dim];
            stats.sample_count = 0;
        }

        // Welford's online update, counting one sample per embedding
        // (not per feature).
        for embedding in embeddings {
            stats.sample_count += 1;
            let n = stats.sample_count as f32;
            let values = embedding.as_f32();
            for (i, &value) in values.iter().enumerate().take(dim) {
                let delta = value - stats.mean[i];
                stats.mean[i] += delta / n;
                let delta2 = value - stats.mean[i];
                stats.variance[i] += delta * delta2;
            }
        }

        // Convert the accumulated M2 sums into unbiased variances. This is
        // done in place, so statistics should be rebuilt rather than streamed
        // across repeated calls.
        if stats.sample_count > 1 {
            for variance in &mut stats.variance {
                *variance /= (stats.sample_count - 1) as f32;
            }
        }

        self.update_adaptation_weights();
    }

    fn update_adaptation_weights(&mut self) {
        let dim = self.source_stats.mean.len();
        if dim == 0 || dim != self.target_stats.mean.len() {
            return;
        }

        self.adaptation_weights = vec![1.0; dim];

        for i in 0..dim {
            // Per-dimension discrepancy from mean shift and log-variance ratio.
            let mean_diff = (self.source_stats.mean[i] - self.target_stats.mean[i]).abs();
            let var_ratio = (self.source_stats.variance[i]
                / (self.target_stats.variance[i] + 1e-8))
                .ln()
                .abs();

            let discrepancy = mean_diff + 0.5 * var_ratio;
            self.adaptation_weights[i] = 1.0 / (1.0 + discrepancy);
        }
    }
}
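
// The adaptation weight for dimension i shrinks with domain discrepancy:
//
//     w_i = 1 / (1 + |mu_s - mu_t| + 0.5 * |ln(var_s / var_t)|)
//
// so dimensions whose source and target statistics agree pass through almost
// unchanged (w_i close to 1), while divergent ones are attenuated before the
// `adaptation_strength` blend in `adapt_embedding`.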

impl JointEmbeddingSpace {
    pub fn new(config: JointEmbeddingConfig) -> Self {
        let text_projector = LinearProjector::new(
            768, // e.g. a BERT-base text encoder width
            config.joint_dim,
            0.1,
            ActivationFunction::GELU,
        );

        let image_projector = LinearProjector::new(
            2048, // e.g. ResNet-50 pooled image features
            config.joint_dim,
            0.1,
            ActivationFunction::GELU,
        );

        let audio_projector = LinearProjector::new(
            1024, // e.g. a wav2vec-style audio encoder width
            config.joint_dim,
            0.1,
            ActivationFunction::GELU,
        );

        let video_projector = LinearProjector::new(
            1536, // fused frame + audio features (see create_video_embedding)
            config.joint_dim,
            0.1,
            ActivationFunction::GELU,
        );

        let attention_mechanism = CrossModalAttention::new(config.joint_dim, 8, 0.1, true);

        // Anneal from a softer temperature down to the configured one.
        let temperature_scheduler = TemperatureScheduler::new(
            config.temperature * 2.0,
            config.temperature,
            1000,
            ScheduleType::Cosine,
        );

        let domain_adapter = DomainAdapter::new(config.alignment_strength);

        Self {
            config,
            text_projector,
            image_projector,
            audio_projector,
            video_projector,
            attention_mechanism,
            alignment_cache: Arc::new(RwLock::new(HashMap::new())),
            training_stats: Arc::new(RwLock::new(TrainingStatistics::default())),
            temperature_scheduler,
            domain_adapter,
        }
    }

    pub fn project_to_joint_space(&self, modality: Modality, embedding: &Vector) -> Result<Vector> {
        let projected = match modality {
            Modality::Text => self.text_projector.forward(embedding)?,
            Modality::Image => self.image_projector.forward(embedding)?,
            Modality::Audio => self.audio_projector.forward(embedding)?,
            Modality::Video => self.video_projector.forward(embedding)?,
            _ => {
                // Fall back to the text projector for unhandled modalities.
                self.text_projector.forward(embedding)?
            }
        };

        Ok(projected.normalized())
    }

    pub fn cross_modal_similarity(
        &self,
        modality1: Modality,
        embedding1: &Vector,
        modality2: Modality,
        embedding2: &Vector,
    ) -> Result<f32> {
        let joint_emb1 = self.project_to_joint_space(modality1, embedding1)?;
        let joint_emb2 = self.project_to_joint_space(modality2, embedding2)?;

        // For genuinely cross-modal pairs, let each embedding attend to the
        // other before comparing.
        if modality1 != modality2 {
            let attended_emb1 =
                self.attention_mechanism
                    .cross_attention(&joint_emb1, &joint_emb2, &joint_emb2)?;
            let attended_emb2 =
                self.attention_mechanism
                    .cross_attention(&joint_emb2, &joint_emb1, &joint_emb1)?;

            attended_emb1.cosine_similarity(&attended_emb2)
        } else {
            joint_emb1.cosine_similarity(&joint_emb2)
        }
    }

    pub fn contrastive_align(
        &mut self,
        positive_pairs: &[(Modality, Vector, Modality, Vector)],
        negative_pairs: &[(Modality, Vector, Modality, Vector)],
    ) -> Result<f32> {
        let mut total_loss = 0.0;
        let temperature = self.temperature_scheduler.get_current_temperature();

        for (mod1, emb1, mod2, emb2) in positive_pairs {
            let similarity = self.cross_modal_similarity(*mod1, emb1, *mod2, emb2)?;
            let positive_score = similarity / temperature;

            // Logistic (softplus) loss ln(1 + e^{-s}): non-negative and
            // shrinking as the positive pair's score grows. (The original
            // -ln(1 + s) form is undefined for s < -1.)
            let positive_loss = (-positive_score).exp().ln_1p();
            total_loss += positive_loss;

            self.cache_alignment(*mod1, emb1.clone(), *mod2, emb2.clone(), similarity);
        }

        for (mod1, emb1, mod2, emb2) in negative_pairs {
            let similarity = self.cross_modal_similarity(*mod1, emb1, *mod2, emb2)?;
            let negative_score = similarity / temperature;

            // Hinge on negatives: penalize scores above -margin.
            let negative_loss = (negative_score + self.config.margin).max(0.0);
            total_loss += negative_loss;
        }

        self.update_training_stats(positive_pairs.len(), negative_pairs.len(), total_loss);

        self.temperature_scheduler.step();

        // Average, guarding against an empty pair set.
        let denom = (positive_pairs.len() + negative_pairs.len()).max(1) as f32;
        Ok(total_loss / denom)
    }

    pub fn cross_modal_search(
        &self,
        query_modality: Modality,
        query_embedding: &Vector,
        candidate_modality: Modality,
        candidate_embeddings: &[Vector],
        top_k: usize,
    ) -> Result<Vec<(usize, f32)>> {
        let query_joint = self.project_to_joint_space(query_modality, query_embedding)?;
        let mut similarities = Vec::new();

        for (idx, candidate) in candidate_embeddings.iter().enumerate() {
            let candidate_joint = self.project_to_joint_space(candidate_modality, candidate)?;

            let similarity = if query_modality != candidate_modality {
                let attended_query = self.attention_mechanism.cross_attention(
                    &query_joint,
                    &candidate_joint,
                    &candidate_joint,
                )?;
                attended_query.cosine_similarity(&candidate_joint)?
            } else {
                query_joint.cosine_similarity(&candidate_joint)?
            };

            similarities.push((idx, similarity));
        }

        // Sort best-first; treat incomparable (NaN) scores as equal.
        similarities
            .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        similarities.truncate(top_k);

        Ok(similarities)
    }

    pub fn zero_shot_retrieval(
        &self,
        query_modality: Modality,
        query_embedding: &Vector,
        target_modality: Modality,
        target_embeddings: &[Vector],
        top_k: usize,
    ) -> Result<Vec<(usize, f32)>> {
        // Validate that the query projects cleanly, then reuse the shared
        // cross-modal search path.
        let _query_joint = self.project_to_joint_space(query_modality, query_embedding)?;

        self.cross_modal_search(
            query_modality,
            query_embedding,
            target_modality,
            target_embeddings,
            top_k,
        )
    }

    pub fn multi_modal_fusion(&self, modalities: &[(Modality, Vector)]) -> Result<Vector> {
        if modalities.is_empty() {
            return Err(anyhow!("No modalities provided for fusion"));
        }

        let mut joint_embeddings = Vec::new();
        for (modality, embedding) in modalities {
            let joint_emb = self.project_to_joint_space(*modality, embedding)?;
            joint_embeddings.push(joint_emb);
        }

        // Enrich each embedding with equally weighted cross-attention over
        // every other modality.
        let mut attended_embeddings = Vec::new();
        for i in 0..joint_embeddings.len() {
            let mut attended = joint_embeddings[i].clone();

            for j in 0..joint_embeddings.len() {
                if i != j {
                    let cross_attended = self.attention_mechanism.cross_attention(
                        &joint_embeddings[i],
                        &joint_embeddings[j],
                        &joint_embeddings[j],
                    )?;

                    let weight = 1.0 / joint_embeddings.len() as f32;
                    attended = attended.add(&cross_attended.scale(weight))?;
                }
            }

            attended_embeddings.push(attended);
        }

        // Average the attended embeddings into a single fused vector.
        if attended_embeddings.len() == 1 {
            Ok(attended_embeddings[0].clone())
        } else {
            let mut fused = attended_embeddings[0].clone();
            for embedding in attended_embeddings.iter().skip(1) {
                fused = fused.add(embedding)?;
            }
            Ok(fused.scale(1.0 / attended_embeddings.len() as f32))
        }
    }

    fn cache_alignment(
        &self,
        mod1: Modality,
        emb1: Vector,
        mod2: Modality,
        emb2: Vector,
        similarity: f32,
    ) {
        let alignment = AlignmentPair {
            modality1: mod1,
            modality2: mod2,
            embedding1: emb1,
            embedding2: emb2,
            similarity,
            confidence: similarity.abs(),
            timestamp: std::time::SystemTime::now(),
        };

        let cache_key = format!("{mod1:?}_{mod2:?}_{similarity}");
        let mut cache = self.alignment_cache.write();
        cache.insert(cache_key, alignment);

        // Evict the oldest entry once the cache grows past its cap.
        if cache.len() > 10000 {
            if let Some(oldest_key) = cache
                .iter()
                .min_by_key(|(_, v)| v.timestamp)
                .map(|(k, _)| k.clone())
            {
                cache.remove(&oldest_key);
            }
        }
    }

    fn update_training_stats(&self, positive_count: usize, negative_count: usize, loss: f32) {
        let mut stats = self.training_stats.write();
        stats.total_samples += (positive_count + negative_count) as u64;
        stats.positive_pairs += positive_count as u64;
        stats.negative_pairs += negative_count as u64;

        // Coarse running average of batch losses, weighted by sample count.
        let total_samples = stats.total_samples as f32;
        stats.average_loss = (stats.average_loss * (total_samples - 1.0) + loss) / total_samples;
    }

    pub fn get_training_stats(&self) -> TrainingStatistics {
        self.training_stats.read().clone()
    }

    pub fn get_cache_stats(&self) -> (usize, f32) {
        let cache = self.alignment_cache.read();
        let cache_size = cache.len();
        let avg_similarity = if cache.is_empty() {
            0.0
        } else {
            cache.values().map(|a| a.similarity).sum::<f32>() / cache_size as f32
        };
        (cache_size, avg_similarity)
    }

    pub fn evaluate_retrieval(
        &self,
        test_pairs: &[(Modality, Vector, Modality, Vector)],
        distractors: &[(Modality, Vector)],
        k_values: &[usize],
    ) -> Result<HashMap<usize, f32>> {
        let mut recall_at_k = HashMap::new();

        for &k in k_values {
            let mut total_recall = 0.0;

            for (query_mod, query_emb, target_mod, target_emb) in test_pairs {
                // The true target sits at index 0 of the candidate list.
                let mut candidates = vec![target_emb.clone()];
                for (distractor_mod, distractor_emb) in distractors {
                    if *distractor_mod == *target_mod {
                        candidates.push(distractor_emb.clone());
                    }
                }

                let results =
                    self.cross_modal_search(*query_mod, query_emb, *target_mod, &candidates, k)?;

                let found_target = results.iter().any(|(idx, _)| *idx == 0);
                if found_target {
                    total_recall += 1.0;
                }
            }

            recall_at_k.insert(k, total_recall / test_pairs.len() as f32);
        }

        Ok(recall_at_k)
    }
}
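
// Usage sketch: text-to-image retrieval over a candidate pool. Input dims
// must match the per-modality projectors (768 for text, 2048 for images):
//
//     let space = JointEmbeddingSpace::new(JointEmbeddingConfig::default());
//     let hits = space.cross_modal_search(
//         Modality::Text, &text_emb, Modality::Image, &image_embs, 10)?;
//     // `hits` holds (index, similarity) pairs sorted best-first.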

impl CLIPAligner {
    pub fn new(config: JointEmbeddingConfig) -> Self {
        let joint_space = JointEmbeddingSpace::new(config.clone());
        let optimizer = ContrastiveOptimizer::new(config.learning_rate, 0.9, config.weight_decay);
        let data_augmentation = DataAugmentation::default();
        let curriculum = CurriculumLearning::new();

        Self {
            joint_space,
            optimizer,
            data_augmentation,
            curriculum,
        }
    }

    pub fn train_alignment(
        &mut self,
        training_data: &[(MultiModalContent, MultiModalContent)],
        epochs: usize,
    ) -> Result<Vec<f32>> {
        let mut epoch_losses = Vec::new();

        for epoch in 0..epochs {
            let mut epoch_loss = 0.0;
            let mut batch_count = 0;

            for batch in training_data.chunks(self.joint_space.config.batch_size) {
                let (positive_pairs, negative_pairs) = self.create_contrastive_pairs(batch)?;

                let augmented_positive = self.augment_pairs(&positive_pairs)?;
                let augmented_negative = self.augment_pairs(&negative_pairs)?;

                let batch_loss = self
                    .joint_space
                    .contrastive_align(&augmented_positive, &augmented_negative)?;

                epoch_loss += batch_loss;
                batch_count += 1;

                if self.curriculum.enabled {
                    self.curriculum.update_difficulty(batch_loss);
                }
            }

            // Guard against empty training data producing a NaN average.
            let avg_epoch_loss = epoch_loss / batch_count.max(1) as f32;
            epoch_losses.push(avg_epoch_loss);

            self.optimizer.step_schedule();

            tracing::info!(
                "Epoch {}/{}: Average Loss = {:.4}, Temperature = {:.4}",
                epoch + 1,
                epochs,
                avg_epoch_loss,
                self.joint_space
                    .temperature_scheduler
                    .get_current_temperature()
            );
        }

        Ok(epoch_losses)
    }

    fn create_contrastive_pairs(
        &self,
        batch: &[(MultiModalContent, MultiModalContent)],
    ) -> Result<ContrastivePairs> {
        let mut positive_pairs = Vec::new();
        let mut negative_pairs = Vec::new();

        // Positives: modality pairs drawn from matched content pairs.
        for (content1, content2) in batch {
            for (mod1, data1) in &content1.modalities {
                for (mod2, data2) in &content2.modalities {
                    if let (Ok(emb1), Ok(emb2)) = (
                        self.extract_embedding(*mod1, data1),
                        self.extract_embedding(*mod2, data2),
                    ) {
                        positive_pairs.push((*mod1, emb1, *mod2, emb2));
                    }
                }
            }
        }

        // Negatives: modality pairs drawn from mismatched items in the batch.
        let batch_size = batch.len();
        for i in 0..batch_size {
            for j in 0..batch_size {
                if i != j {
                    let (content1, _) = &batch[i];
                    let (_, content2) = &batch[j];

                    for (mod1, data1) in &content1.modalities {
                        for (mod2, data2) in &content2.modalities {
                            if let (Ok(emb1), Ok(emb2)) = (
                                self.extract_embedding(*mod1, data1),
                                self.extract_embedding(*mod2, data2),
                            ) {
                                negative_pairs.push((*mod1, emb1, *mod2, emb2));
                            }
                        }
                    }
                }
            }
        }

        // Cap negatives at the configured ratio per positive pair.
        let max_negatives = positive_pairs.len() * self.joint_space.config.negative_samples;
        negative_pairs.truncate(max_negatives);

        Ok((positive_pairs, negative_pairs))
    }

    fn extract_embedding(&self, modality: Modality, data: &ModalityData) -> Result<Vector> {
        match (modality, data) {
            (Modality::Text, ModalityData::Text(text)) => {
                let words: Vec<&str> = text.split_whitespace().collect();
                let embedding = self.create_text_embedding(&words);
                Ok(embedding)
            }
            (Modality::Image, ModalityData::Image(image)) => {
                let embedding = self.create_image_embedding(image);
                Ok(embedding)
            }
            (Modality::Audio, ModalityData::Audio(audio)) => {
                let embedding = self.create_audio_embedding(audio);
                Ok(embedding)
            }
            (Modality::Video, ModalityData::Video(video)) => {
                let embedding = self.create_video_embedding(video);
                Ok(embedding)
            }
            (Modality::Numeric, ModalityData::Numeric(values)) => Ok(Vector::new(values.clone())),
            _ => Err(anyhow!("Modality-data type mismatch")),
        }
    }

    fn create_text_embedding(&self, words: &[&str]) -> Vector {
        // Hashed bag-of-words with positional decay, in a 768-dim space.
        let mut embedding = vec![0.0; 768];
        for (i, word) in words.iter().enumerate().take(100) {
            let hash = self.simple_hash(word) as usize;
            let idx = hash % embedding.len();
            embedding[idx] += 1.0 / (i + 1) as f32;
        }

        // L2-normalize.
        let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 0.0 {
            for value in &mut embedding {
                *value /= norm;
            }
        }

        Vector::new(embedding)
    }

    fn create_image_embedding(&self, image: &ImageData) -> Vector {
        let mut embedding = vec![0.0; 2048];

        // First 256 dims: color histogram features.
        let color_features = self.extract_color_features(image);
        for (i, &feature) in color_features.iter().enumerate().take(256) {
            if i < embedding.len() {
                embedding[i] = feature;
            }
        }

        // Next 256 dims: texture (LBP) features.
        let texture_features = self.extract_texture_features(image);
        for (i, &feature) in texture_features.iter().enumerate().take(256) {
            if i + 256 < embedding.len() {
                embedding[i + 256] = feature;
            }
        }

        Vector::new(embedding)
    }

    fn create_audio_embedding(&self, audio: &AudioData) -> Vector {
        let mut embedding = vec![0.0; 1024];

        // Prefer precomputed features; otherwise fall back to simple
        // spectral energy features.
        if let Some(ref features) = audio.features {
            for (i, &feature) in features.iter().enumerate().take(embedding.len()) {
                embedding[i] = feature;
            }
        } else {
            let spectral_features = self.extract_spectral_features(audio);
            for (i, &feature) in spectral_features.iter().enumerate().take(embedding.len()) {
                embedding[i] = feature;
            }
        }

        Vector::new(embedding)
    }

    fn create_video_embedding(&self, video: &VideoData) -> Vector {
        let mut embedding = vec![0.0; 1536];

        // First 1024 dims: features from the first frame.
        if !video.frames.is_empty() {
            let frame_embedding = self.create_image_embedding(&video.frames[0]);
            let frame_values = frame_embedding.as_f32();
            for (i, &value) in frame_values.iter().enumerate().take(1024) {
                if i < embedding.len() {
                    embedding[i] = value;
                }
            }
        }

        // Remaining 512 dims: features from the audio track, if present.
        if let Some(ref audio) = video.audio {
            let audio_embedding = self.create_audio_embedding(audio);
            let audio_values = audio_embedding.as_f32();
            for (i, &value) in audio_values.iter().enumerate().take(512) {
                if i + 1024 < embedding.len() {
                    embedding[i + 1024] = value;
                }
            }
        }

        Vector::new(embedding)
    }

    fn simple_hash(&self, text: &str) -> u64 {
        // djb2-style string hash.
        let mut hash = 5381u64;
        for byte in text.bytes() {
            hash = hash.wrapping_mul(33).wrapping_add(byte as u64);
        }
        hash
    }

    fn extract_color_features(&self, image: &ImageData) -> Vec<f32> {
        // 256-bin intensity histogram, normalized to sum to 1.
        let mut histogram = vec![0.0; 256];

        match image.format {
            crate::cross_modal_embeddings::ImageFormat::RGB => {
                for chunk in image.data.chunks(3) {
                    if chunk.len() == 3 {
                        let intensity = (chunk[0] as f32 + chunk[1] as f32 + chunk[2] as f32) / 3.0;
                        let bin = (intensity as usize).min(255);
                        histogram[bin] += 1.0;
                    }
                }
            }
            _ => {
                for &pixel in &image.data {
                    let bin = (pixel as usize).min(255);
                    histogram[bin] += 1.0;
                }
            }
        }

        let total: f32 = histogram.iter().sum();
        if total > 0.0 {
            for value in &mut histogram {
                *value /= total;
            }
        }

        histogram
    }

    fn extract_texture_features(&self, image: &ImageData) -> Vec<f32> {
        // 8-neighbor local binary pattern (LBP) histogram over the interior
        // pixels, normalized to sum to 1.
        let mut features = vec![0.0; 256];

        let width = image.width as usize;
        let height = image.height as usize;

        if width > 2 && height > 2 {
            for y in 1..height - 1 {
                for x in 1..width - 1 {
                    let center_idx = y * width + x;
                    if center_idx < image.data.len() {
                        let center = image.data[center_idx];
                        let mut pattern = 0u8;

                        let neighbors = [
                            (-1, -1),
                            (0, -1),
                            (1, -1),
                            (-1, 0),
                            (1, 0),
                            (-1, 1),
                            (0, 1),
                            (1, 1),
                        ];

                        for (bit, (dx, dy)) in neighbors.iter().enumerate() {
                            let nx = (x as i32 + dx) as usize;
                            let ny = (y as i32 + dy) as usize;
                            let neighbor_idx = ny * width + nx;

                            if neighbor_idx < image.data.len() && image.data[neighbor_idx] > center
                            {
                                pattern |= 1 << bit;
                            }
                        }

                        features[pattern as usize] += 1.0;
                    }
                }
            }
        }

        let total: f32 = features.iter().sum();
        if total > 0.0 {
            for value in &mut features {
                *value /= total;
            }
        }

        features
    }

    fn extract_spectral_features(&self, audio: &AudioData) -> Vec<f32> {
        // Per-chunk RMS energy over 128 equal slices of the waveform.
        let mut features = vec![0.0; 128];

        if !audio.samples.is_empty() {
            // Guard against short clips producing a zero chunk size.
            let chunk_size = (audio.samples.len() / features.len()).max(1);

            for (i, feature) in features.iter_mut().enumerate() {
                let start = i * chunk_size;
                let end = ((i + 1) * chunk_size).min(audio.samples.len());

                if start < end {
                    let chunk = &audio.samples[start..end];

                    let energy: f32 = chunk.iter().map(|x| x * x).sum();
                    *feature = energy.sqrt() / (chunk.len() as f32).sqrt();
                }
            }
        }

        features
    }

    fn augment_pairs(
        &self,
        pairs: &[(Modality, Vector, Modality, Vector)],
    ) -> Result<Vec<(Modality, Vector, Modality, Vector)>> {
        let mut augmented = Vec::new();

        for (mod1, emb1, mod2, emb2) in pairs {
            let aug_emb1 = self.add_noise(emb1, 0.01)?;
            let aug_emb2 = self.add_noise(emb2, 0.01)?;
            augmented.push((*mod1, aug_emb1, *mod2, aug_emb2));
        }

        Ok(augmented)
    }

    fn add_noise(&self, embedding: &Vector, noise_std: f32) -> Result<Vector> {
        let values = embedding.as_f32();
        let mut noisy_values = Vec::with_capacity(values.len());

        for (i, &value) in values.iter().enumerate() {
            // Deterministic sinusoidal perturbation standing in for Gaussian
            // noise, clamped to keep the distortion small.
            let noise = ((i as f32 * 0.1234).sin() * noise_std).clamp(-0.1, 0.1);
            noisy_values.push(value + noise);
        }

        Ok(Vector::new(noisy_values))
    }
}
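
// End-to-end sketch, assuming `paired_data` holds matched
// (MultiModalContent, MultiModalContent) pairs prepared upstream:
//
//     let mut aligner = CLIPAligner::new(JointEmbeddingConfig::default());
//     let losses = aligner.train_alignment(&paired_data, 10)?;
//     // `losses[e]` is the mean contrastive loss for epoch `e`.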

impl ContrastiveOptimizer {
    pub fn new(learning_rate: f32, momentum: f32, weight_decay: f32) -> Self {
        Self {
            learning_rate,
            momentum,
            weight_decay,
            gradient_history: HashMap::new(),
            adaptive_lr: true,
            lr_schedule: LearningRateSchedule::CosineAnnealing {
                min_lr: learning_rate * 0.01,
                max_epochs: 100,
            },
        }
    }

    pub fn step_schedule(&mut self) {
        match self.lr_schedule {
            LearningRateSchedule::StepDecay {
                step_size: _,
                gamma,
            } => {
                self.learning_rate *= gamma;
            }
            LearningRateSchedule::ExponentialDecay { gamma } => {
                self.learning_rate *= gamma;
            }
            LearningRateSchedule::CosineAnnealing {
                min_lr,
                max_epochs: _,
            } => {
                // Fixed per-call progress; a full implementation would track
                // the current epoch against `max_epochs`.
                let progress = 0.01;
                let lr_range = self.learning_rate - min_lr;
                self.learning_rate =
                    min_lr + lr_range * (1.0 + (std::f32::consts::PI * progress).cos()) / 2.0;
            }
            LearningRateSchedule::Constant => {}
        }
    }
}
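
// Example: `ExponentialDecay { gamma: 0.95 }` multiplies the rate by 0.95 per
// `step_schedule` call, so after 14 calls the rate is roughly halved
// (0.95^14 ≈ 0.49).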

impl Default for DataAugmentation {
    fn default() -> Self {
        Self {
            text_augmentations: vec![
                TextAugmentation::RandomWordDropout(0.1),
                TextAugmentation::SynonymReplacement(0.1),
            ],
            image_augmentations: vec![
                ImageAugmentation::RandomFlip {
                    horizontal: true,
                    vertical: false,
                },
                ImageAugmentation::ColorJitter {
                    brightness: 0.2,
                    contrast: 0.2,
                    saturation: 0.2,
                },
            ],
            audio_augmentations: vec![
                AudioAugmentation::AddNoise { snr_db: 20.0 },
                AudioAugmentation::TimeStretch { factor: 1.1 },
            ],
            cross_modal_mixup: false,
            augmentation_probability: 0.5,
        }
    }
}

impl Default for CurriculumLearning {
    fn default() -> Self {
        Self::new()
    }
}

impl CurriculumLearning {
    pub fn new() -> Self {
        Self {
            enabled: false,
            current_difficulty: 0.0,
            difficulty_schedule: DifficultySchedule::Linear {
                start: 0.0,
                end: 1.0,
                epochs: 50,
            },
            pacing_function: PacingFunction::Root,
            competence_threshold: 0.8,
        }
    }

    pub fn update_difficulty(&mut self, loss: f32) {
        if self.enabled {
            // Raise difficulty while the model is competent; back off otherwise.
            if loss < self.competence_threshold {
                self.current_difficulty = (self.current_difficulty + 0.01).min(1.0);
            } else {
                self.current_difficulty = (self.current_difficulty - 0.005).max(0.0);
            }
        }
    }
}
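
// Example: with `competence_threshold = 0.8`, a batch loss of 0.5 raises
// `current_difficulty` by 0.01 (capped at 1.0), while a loss of 1.2 lowers it
// by 0.005 (floored at 0.0).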

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_joint_embedding_space() {
        let config = JointEmbeddingConfig::default();
        let joint_space = JointEmbeddingSpace::new(config);

        let text_embedding = Vector::new(vec![0.1; 768]);
        let image_embedding = Vector::new(vec![0.2; 2048]);

        let joint_text = joint_space
            .project_to_joint_space(Modality::Text, &text_embedding)
            .unwrap();
        let joint_image = joint_space
            .project_to_joint_space(Modality::Image, &image_embedding)
            .unwrap();

        assert_eq!(joint_text.dimensions, 512);
        assert_eq!(joint_image.dimensions, 512);

        let similarity = joint_space
            .cross_modal_similarity(
                Modality::Text,
                &text_embedding,
                Modality::Image,
                &image_embedding,
            )
            .unwrap();

        assert!((-1.0..=1.0).contains(&similarity));
    }

    #[test]
    fn test_cross_modal_attention() {
        let attention = CrossModalAttention::new(128, 4, 0.1, true);

        let query = Vector::new(vec![0.1; 128]);
        let key = Vector::new(vec![0.2; 128]);
        let value = Vector::new(vec![0.3; 128]);

        let result = attention.cross_attention(&query, &key, &value).unwrap();
        assert_eq!(result.dimensions, 128);
    }

    #[test]
    fn test_contrastive_learning() {
        let config = JointEmbeddingConfig::default();
        let mut joint_space = JointEmbeddingSpace::new(config);

        let positive_pairs = vec![(
            Modality::Text,
            Vector::new(vec![0.1; 768]),
            Modality::Image,
            Vector::new(vec![0.1; 2048]),
        )];

        let negative_pairs = vec![(
            Modality::Text,
            Vector::new(vec![0.1; 768]),
            Modality::Image,
            Vector::new(vec![-0.1; 2048]),
        )];

        let loss = joint_space
            .contrastive_align(&positive_pairs, &negative_pairs)
            .unwrap();

        assert!(loss >= 0.0);
    }

    #[test]
    fn test_clip_aligner() {
        let config = JointEmbeddingConfig::default();
        let aligner = CLIPAligner::new(config);

        let text_words = vec!["hello", "world"];
        let text_embedding = aligner.create_text_embedding(&text_words);
        assert_eq!(text_embedding.dimensions, 768);

        let (cache_size, _) = aligner.joint_space.get_cache_stats();
        assert_eq!(cache_size, 0);
    }
}