Skip to main content

oxirs_vec/
joint_embedding_spaces_aligner.rs

1//! Alignment algorithms for Joint Embedding Spaces
2//!
3//! CCA, Procrustes, manifold alignment, cross-space distance computation.
4
5use super::joint_embedding_spaces_types::{
6    ActivationFunction, AlignmentPair, CrossModalAttention, DomainAdapter, DomainStatistics,
7    JointEmbeddingConfig, LinearProjector, ScheduleType, TemperatureScheduler, TrainingStatistics,
8};
9use crate::{cross_modal_embeddings::Modality, Vector};
10use anyhow::{anyhow, Result};
11use parking_lot::RwLock;
12use std::collections::HashMap;
13use std::sync::Arc;
14
15// ─────────────────────────────────────────────────────────────────────────────
16// LinearProjector impl
17// ─────────────────────────────────────────────────────────────────────────────
18
19impl LinearProjector {
20    pub fn new(
21        input_dim: usize,
22        output_dim: usize,
23        dropout_rate: f32,
24        activation: ActivationFunction,
25    ) -> Self {
26        // Xavier/Glorot initialization
27        let limit = (6.0 / (input_dim + output_dim) as f32).sqrt();
28        let mut weights = Vec::with_capacity(output_dim);
29
30        for _ in 0..output_dim {
31            let mut row = Vec::with_capacity(input_dim);
32            for _ in 0..input_dim {
33                let weight = ((row.len() as f32 * 0.01) % 2.0 - 1.0) * limit;
34                row.push(weight);
35            }
36            weights.push(row);
37        }
38
39        let bias = vec![0.0; output_dim];
40
41        Self {
42            weights,
43            bias,
44            input_dim,
45            output_dim,
46            dropout_rate,
47            activation,
48        }
49    }
50
51    pub fn forward(&self, input: &Vector) -> Result<Vector> {
52        if input.dimensions != self.input_dim {
53            return Err(anyhow!(
54                "Input dimension mismatch: expected {}, got {}",
55                self.input_dim,
56                input.dimensions
57            ));
58        }
59
60        let input_values = input.as_f32();
61        let mut output = vec![0.0; self.output_dim];
62
63        for (i, output_val) in output.iter_mut().enumerate().take(self.output_dim) {
64            let mut sum = self.bias[i];
65            for (j, &input_val) in input_values.iter().enumerate().take(self.input_dim) {
66                sum += input_val * self.weights[i][j];
67            }
68            *output_val = sum;
69        }
70
71        for value in &mut output {
72            *value = self.apply_activation(*value);
73        }
74
75        if self.dropout_rate > 0.0 {
76            for (i, value) in output.iter_mut().enumerate() {
77                if (i as f32 * 0.12345) % 1.0 < self.dropout_rate {
78                    *value = 0.0;
79                } else {
80                    *value /= 1.0 - self.dropout_rate;
81                }
82            }
83        }
84
85        Ok(Vector::new(output))
86    }
87
88    fn apply_activation(&self, x: f32) -> f32 {
89        match self.activation {
90            ActivationFunction::ReLU => x.max(0.0),
91            ActivationFunction::GELU => {
92                let sqrt_2_pi = (2.0 / std::f32::consts::PI).sqrt();
93                let inner = sqrt_2_pi * (x + 0.044715 * x.powi(3));
94                0.5 * x * (1.0 + inner.tanh())
95            }
96            ActivationFunction::Tanh => x.tanh(),
97            ActivationFunction::Sigmoid => 1.0 / (1.0 + (-x).exp()),
98            ActivationFunction::Swish => x * (1.0 / (1.0 + (-x).exp())),
99            ActivationFunction::Mish => x * (1.0 + x.exp()).ln().tanh(),
100            ActivationFunction::LeakyReLU(alpha) => {
101                if x > 0.0 {
102                    x
103                } else {
104                    alpha * x
105                }
106            }
107        }
108    }
109
110    pub fn update_weights(&mut self, gradients: &[Vec<f32>], learning_rate: f32) {
111        for i in 0..self.output_dim {
112            for j in 0..self.input_dim {
113                if i < gradients.len() && j < gradients[i].len() {
114                    self.weights[i][j] -= learning_rate * gradients[i][j];
115                }
116            }
117        }
118    }
119}
120
121// ─────────────────────────────────────────────────────────────────────────────
122// CrossModalAttention impl
123// ─────────────────────────────────────────────────────────────────────────────
124
125impl CrossModalAttention {
126    pub fn new(
127        input_dim: usize,
128        num_heads: usize,
129        dropout_rate: f32,
130        enable_relative_pos: bool,
131    ) -> Self {
132        let head_dim = input_dim / num_heads;
133        let scale = 1.0 / (head_dim as f32).sqrt();
134
135        Self {
136            query_projector: LinearProjector::new(
137                input_dim,
138                input_dim,
139                dropout_rate,
140                ActivationFunction::ReLU,
141            ),
142            key_projector: LinearProjector::new(
143                input_dim,
144                input_dim,
145                dropout_rate,
146                ActivationFunction::ReLU,
147            ),
148            value_projector: LinearProjector::new(
149                input_dim,
150                input_dim,
151                dropout_rate,
152                ActivationFunction::ReLU,
153            ),
154            output_projector: LinearProjector::new(
155                input_dim,
156                input_dim,
157                dropout_rate,
158                ActivationFunction::ReLU,
159            ),
160            num_heads,
161            head_dim,
162            dropout_rate,
163            scale,
164            enable_relative_pos,
165        }
166    }
167
168    pub fn cross_attention(
169        &self,
170        query_modality: &Vector,
171        key_modality: &Vector,
172        value_modality: &Vector,
173    ) -> Result<Vector> {
174        let query = self.query_projector.forward(query_modality)?;
175        let key = self.key_projector.forward(key_modality)?;
176        let value = self.value_projector.forward(value_modality)?;
177
178        let attended = self.multi_head_attention(&query, &key, &value)?;
179        self.output_projector.forward(&attended)
180    }
181
182    fn multi_head_attention(&self, query: &Vector, key: &Vector, value: &Vector) -> Result<Vector> {
183        let query_vals = query.as_f32();
184        let key_vals = key.as_f32();
185        let value_vals = value.as_f32();
186
187        if query_vals.len() != key_vals.len() || key_vals.len() != value_vals.len() {
188            return Err(anyhow!("Dimension mismatch in attention"));
189        }
190
191        let _seq_len = query_vals.len() / self.head_dim;
192        let mut output = vec![0.0; query_vals.len()];
193
194        for head in 0..self.num_heads {
195            let head_start = head * self.head_dim;
196            let head_end = head_start + self.head_dim;
197
198            let head_query = &query_vals[head_start..head_end];
199            let head_key = &key_vals[head_start..head_end];
200            let head_value = &value_vals[head_start..head_end];
201
202            let attention_score = self.compute_attention_score(head_query, head_key);
203
204            for i in 0..self.head_dim {
205                output[head_start + i] = head_value[i] * attention_score;
206            }
207        }
208
209        if self.enable_relative_pos {
210            self.apply_relative_position_encoding(&mut output)?;
211        }
212
213        Ok(Vector::new(output))
214    }
215
216    fn compute_attention_score(&self, query: &[f32], key: &[f32]) -> f32 {
217        let dot_product: f32 = query.iter().zip(key.iter()).map(|(q, k)| q * k).sum();
218        let scaled_score = dot_product * self.scale;
219        scaled_score.tanh()
220    }
221
222    fn apply_relative_position_encoding(&self, output: &mut [f32]) -> Result<()> {
223        let output_len = output.len();
224        for (i, value) in output.iter_mut().enumerate() {
225            let pos_encoding = (i as f32 / output_len as f32).sin();
226            *value += 0.1 * pos_encoding;
227        }
228        Ok(())
229    }
230}
231
232// ─────────────────────────────────────────────────────────────────────────────
233// TemperatureScheduler impl
234// ─────────────────────────────────────────────────────────────────────────────
235
236impl TemperatureScheduler {
237    pub fn new(
238        initial_temperature: f32,
239        final_temperature: f32,
240        decay_steps: usize,
241        schedule_type: ScheduleType,
242    ) -> Self {
243        Self {
244            initial_temperature,
245            final_temperature,
246            decay_steps,
247            current_step: 0,
248            schedule_type,
249        }
250    }
251
252    pub fn get_current_temperature(&self) -> f32 {
253        if self.current_step >= self.decay_steps {
254            return self.final_temperature;
255        }
256
257        let progress = self.current_step as f32 / self.decay_steps as f32;
258
259        match self.schedule_type {
260            ScheduleType::Linear => {
261                self.initial_temperature
262                    + (self.final_temperature - self.initial_temperature) * progress
263            }
264            ScheduleType::Exponential => {
265                self.initial_temperature
266                    * (self.final_temperature / self.initial_temperature).powf(progress)
267            }
268            ScheduleType::Cosine => {
269                let cosine_progress = 0.5 * (1.0 + (std::f32::consts::PI * progress).cos());
270                self.final_temperature
271                    + (self.initial_temperature - self.final_temperature) * cosine_progress
272            }
273            ScheduleType::Warmup => {
274                if progress < 0.1 {
275                    self.initial_temperature * (progress / 0.1)
276                } else {
277                    let decay_progress = (progress - 0.1) / 0.9;
278                    self.initial_temperature
279                        + (self.final_temperature - self.initial_temperature) * decay_progress
280                }
281            }
282        }
283    }
284
285    pub fn step(&mut self) {
286        self.current_step += 1;
287    }
288}
289
290// ─────────────────────────────────────────────────────────────────────────────
291// DomainAdapter impl
292// ─────────────────────────────────────────────────────────────────────────────
293
294impl DomainAdapter {
295    pub fn new(adaptation_strength: f32) -> Self {
296        Self {
297            source_stats: DomainStatistics::default(),
298            target_stats: DomainStatistics::default(),
299            adaptation_weights: Vec::new(),
300            domain_classifier: None,
301            adaptation_strength,
302        }
303    }
304
305    pub fn adapt_embedding(&self, embedding: &Vector, is_source_domain: bool) -> Result<Vector> {
306        let input_values = embedding.as_f32();
307        let mut adapted_values = input_values.clone();
308
309        if self.adaptation_weights.len() != input_values.len() {
310            return Ok(embedding.clone());
311        }
312
313        let stats = if is_source_domain {
314            &self.source_stats
315        } else {
316            &self.target_stats
317        };
318
319        for (i, adapted_value) in adapted_values.iter_mut().enumerate() {
320            if i < stats.mean.len() && i < stats.variance.len() {
321                let normalized =
322                    (*adapted_value - stats.mean[i]) / (stats.variance[i].sqrt() + 1e-8);
323                *adapted_value = normalized * self.adaptation_weights[i] * self.adaptation_strength
324                    + *adapted_value * (1.0 - self.adaptation_strength);
325            }
326        }
327
328        Ok(Vector::new(adapted_values))
329    }
330
331    pub fn update_domain_statistics(&mut self, embeddings: &[Vector], is_source_domain: bool) {
332        let stats = if is_source_domain {
333            &mut self.source_stats
334        } else {
335            &mut self.target_stats
336        };
337
338        if embeddings.is_empty() {
339            return;
340        }
341
342        let dim = embeddings[0].dimensions;
343        if stats.mean.len() != dim {
344            stats.mean = vec![0.0; dim];
345            stats.variance = vec![0.0; dim];
346            stats.sample_count = 0;
347        }
348
349        for embedding in embeddings {
350            let values = embedding.as_f32();
351            for (i, &value) in values.iter().enumerate().take(dim) {
352                let delta = value - stats.mean[i];
353                stats.sample_count += 1;
354                stats.mean[i] += delta / stats.sample_count as f32;
355                let delta2 = value - stats.mean[i];
356                stats.variance[i] += delta * delta2;
357            }
358        }
359
360        if stats.sample_count > 1 {
361            for variance in &mut stats.variance {
362                *variance /= (stats.sample_count - 1) as f32;
363            }
364        }
365
366        self.update_adaptation_weights();
367    }
368
369    fn update_adaptation_weights(&mut self) {
370        let dim = self.source_stats.mean.len();
371        if dim == 0 || dim != self.target_stats.mean.len() {
372            return;
373        }
374
375        self.adaptation_weights = vec![1.0; dim];
376
377        for i in 0..dim {
378            let mean_diff = (self.source_stats.mean[i] - self.target_stats.mean[i]).abs();
379            let var_ratio = (self.source_stats.variance[i]
380                / (self.target_stats.variance[i] + 1e-8))
381                .ln()
382                .abs();
383
384            let discrepancy = mean_diff + 0.5 * var_ratio;
385            self.adaptation_weights[i] = 1.0 / (1.0 + discrepancy);
386        }
387    }
388}
389
390// ─────────────────────────────────────────────────────────────────────────────
391// JointEmbeddingSpace
392// ─────────────────────────────────────────────────────────────────────────────
393
394/// Joint embedding space for cross-modal alignment
395pub struct JointEmbeddingSpace {
396    pub(crate) config: JointEmbeddingConfig,
397    pub(crate) text_projector: LinearProjector,
398    pub(crate) image_projector: LinearProjector,
399    pub(crate) audio_projector: LinearProjector,
400    pub(crate) video_projector: LinearProjector,
401    pub(crate) attention_mechanism: CrossModalAttention,
402    pub(crate) alignment_cache: Arc<RwLock<HashMap<String, AlignmentPair>>>,
403    pub(crate) training_stats: Arc<RwLock<TrainingStatistics>>,
404    pub(crate) temperature_scheduler: TemperatureScheduler,
405    pub(crate) domain_adapter: DomainAdapter,
406}
407
408impl JointEmbeddingSpace {
409    pub fn new(config: JointEmbeddingConfig) -> Self {
410        let text_projector =
411            LinearProjector::new(768, config.joint_dim, 0.1, ActivationFunction::GELU);
412
413        let image_projector =
414            LinearProjector::new(2048, config.joint_dim, 0.1, ActivationFunction::GELU);
415
416        let audio_projector =
417            LinearProjector::new(1024, config.joint_dim, 0.1, ActivationFunction::GELU);
418
419        let video_projector =
420            LinearProjector::new(1536, config.joint_dim, 0.1, ActivationFunction::GELU);
421
422        let attention_mechanism = CrossModalAttention::new(config.joint_dim, 8, 0.1, true);
423
424        let temperature_scheduler = TemperatureScheduler::new(
425            config.temperature * 2.0,
426            config.temperature,
427            1000,
428            ScheduleType::Cosine,
429        );
430
431        let domain_adapter = DomainAdapter::new(config.alignment_strength);
432
433        Self {
434            config,
435            text_projector,
436            image_projector,
437            audio_projector,
438            video_projector,
439            attention_mechanism,
440            alignment_cache: Arc::new(RwLock::new(HashMap::new())),
441            training_stats: Arc::new(RwLock::new(TrainingStatistics::default())),
442            temperature_scheduler,
443            domain_adapter,
444        }
445    }
446
447    /// Project modality-specific embedding to joint space
448    pub fn project_to_joint_space(&self, modality: Modality, embedding: &Vector) -> Result<Vector> {
449        let projected = match modality {
450            Modality::Text => self.text_projector.forward(embedding)?,
451            Modality::Image => self.image_projector.forward(embedding)?,
452            Modality::Audio => self.audio_projector.forward(embedding)?,
453            Modality::Video => self.video_projector.forward(embedding)?,
454            _ => self.text_projector.forward(embedding)?,
455        };
456
457        Ok(projected.normalized())
458    }
459
460    /// Compute cross-modal similarity in joint space
461    pub fn cross_modal_similarity(
462        &self,
463        modality1: Modality,
464        embedding1: &Vector,
465        modality2: Modality,
466        embedding2: &Vector,
467    ) -> Result<f32> {
468        let joint_emb1 = self.project_to_joint_space(modality1, embedding1)?;
469        let joint_emb2 = self.project_to_joint_space(modality2, embedding2)?;
470
471        if modality1 != modality2 {
472            let attended_emb1 =
473                self.attention_mechanism
474                    .cross_attention(&joint_emb1, &joint_emb2, &joint_emb2)?;
475            let attended_emb2 =
476                self.attention_mechanism
477                    .cross_attention(&joint_emb2, &joint_emb1, &joint_emb1)?;
478
479            attended_emb1.cosine_similarity(&attended_emb2)
480        } else {
481            joint_emb1.cosine_similarity(&joint_emb2)
482        }
483    }
484
485    /// Contrastive learning alignment training
486    pub fn contrastive_align(
487        &mut self,
488        positive_pairs: &[(Modality, Vector, Modality, Vector)],
489        negative_pairs: &[(Modality, Vector, Modality, Vector)],
490    ) -> Result<f32> {
491        let mut total_loss = 0.0;
492        let temperature = self.temperature_scheduler.get_current_temperature();
493
494        for (mod1, emb1, mod2, emb2) in positive_pairs {
495            let similarity = self.cross_modal_similarity(*mod1, emb1, *mod2, emb2)?;
496            let positive_score = similarity / temperature;
497            let positive_loss = -positive_score.ln_1p();
498            total_loss += positive_loss;
499
500            self.cache_alignment(*mod1, emb1.clone(), *mod2, emb2.clone(), similarity);
501        }
502
503        for (mod1, emb1, mod2, emb2) in negative_pairs {
504            let similarity = self.cross_modal_similarity(*mod1, emb1, *mod2, emb2)?;
505            let negative_score = similarity / temperature;
506            let negative_loss = (negative_score + self.config.margin).max(0.0);
507            total_loss += negative_loss;
508        }
509
510        self.update_training_stats(positive_pairs.len(), negative_pairs.len(), total_loss);
511        self.temperature_scheduler.step();
512
513        Ok(total_loss / (positive_pairs.len() + negative_pairs.len()) as f32)
514    }
515
516    /// Zero-shot cross-modal retrieval
517    pub fn zero_shot_retrieval(
518        &self,
519        query_modality: Modality,
520        query_embedding: &Vector,
521        target_modality: Modality,
522        target_embeddings: &[Vector],
523        top_k: usize,
524    ) -> Result<Vec<(usize, f32)>> {
525        // Project query to joint space
526        let _query_joint = self.project_to_joint_space(query_modality, query_embedding)?;
527
528        // Search across target modality
529        self.cross_modal_search(
530            query_modality,
531            query_embedding,
532            target_modality,
533            target_embeddings,
534            top_k,
535        )
536    }
537
538    /// Find cross-modal nearest neighbors in joint space
539    pub fn cross_modal_search(
540        &self,
541        query_modality: Modality,
542        query_embedding: &Vector,
543        candidate_modality: Modality,
544        candidate_embeddings: &[Vector],
545        top_k: usize,
546    ) -> Result<Vec<(usize, f32)>> {
547        let query_joint = self.project_to_joint_space(query_modality, query_embedding)?;
548        let mut similarities = Vec::new();
549
550        for (idx, candidate) in candidate_embeddings.iter().enumerate() {
551            let candidate_joint = self.project_to_joint_space(candidate_modality, candidate)?;
552
553            let similarity = if query_modality != candidate_modality {
554                let attended_query = self.attention_mechanism.cross_attention(
555                    &query_joint,
556                    &candidate_joint,
557                    &candidate_joint,
558                )?;
559                attended_query.cosine_similarity(&candidate_joint)?
560            } else {
561                query_joint.cosine_similarity(&candidate_joint)?
562            };
563
564            similarities.push((idx, similarity));
565        }
566
567        similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
568        similarities.truncate(top_k);
569
570        Ok(similarities)
571    }
572
573    /// Multi-modal fusion in joint space
574    pub fn multi_modal_fusion(&self, modalities: &[(Modality, Vector)]) -> Result<Vector> {
575        if modalities.is_empty() {
576            return Err(anyhow!("No modalities provided for fusion"));
577        }
578
579        let mut joint_embeddings = Vec::new();
580        for (modality, embedding) in modalities {
581            let joint_emb = self.project_to_joint_space(*modality, embedding)?;
582            joint_embeddings.push(joint_emb);
583        }
584
585        let mut attended_embeddings = Vec::new();
586        for i in 0..joint_embeddings.len() {
587            let mut attended = joint_embeddings[i].clone();
588
589            for j in 0..joint_embeddings.len() {
590                if i != j {
591                    let cross_attended = self.attention_mechanism.cross_attention(
592                        &joint_embeddings[i],
593                        &joint_embeddings[j],
594                        &joint_embeddings[j],
595                    )?;
596
597                    let weight = 1.0 / joint_embeddings.len() as f32;
598                    attended = attended.add(&cross_attended.scale(weight))?;
599                }
600            }
601
602            attended_embeddings.push(attended);
603        }
604
605        if attended_embeddings.len() == 1 {
606            Ok(attended_embeddings[0].clone())
607        } else {
608            let mut fused = attended_embeddings[0].clone();
609            for embedding in attended_embeddings.iter().skip(1) {
610                fused = fused.add(embedding)?;
611            }
612            Ok(fused.scale(1.0 / attended_embeddings.len() as f32))
613        }
614    }
615
616    pub(crate) fn cache_alignment(
617        &self,
618        mod1: Modality,
619        emb1: Vector,
620        mod2: Modality,
621        emb2: Vector,
622        similarity: f32,
623    ) {
624        let alignment = AlignmentPair {
625            modality1: mod1,
626            modality2: mod2,
627            embedding1: emb1,
628            embedding2: emb2,
629            similarity,
630            confidence: similarity.abs(),
631            timestamp: std::time::SystemTime::now(),
632        };
633
634        let cache_key = format!("{mod1:?}_{mod2:?}_{similarity}");
635        let mut cache = self.alignment_cache.write();
636        cache.insert(cache_key, alignment);
637
638        if cache.len() > 10000 {
639            let mut entries: Vec<_> = cache.iter().collect();
640            entries.sort_by_key(|(_, v)| v.timestamp);
641            let oldest_key = entries[0].0.clone();
642            cache.remove(&oldest_key);
643        }
644    }
645
646    pub(crate) fn update_training_stats(
647        &self,
648        positive_count: usize,
649        negative_count: usize,
650        loss: f32,
651    ) {
652        let mut stats = self.training_stats.write();
653        stats.total_samples += (positive_count + negative_count) as u64;
654        stats.positive_pairs += positive_count as u64;
655        stats.negative_pairs += negative_count as u64;
656
657        let total_samples = stats.total_samples as f32;
658        stats.average_loss = (stats.average_loss * (total_samples - 1.0) + loss) / total_samples;
659    }
660
661    /// Get training statistics
662    pub fn get_training_stats(&self) -> TrainingStatistics {
663        self.training_stats.read().clone()
664    }
665
666    /// Get alignment cache statistics
667    pub fn get_cache_stats(&self) -> (usize, f32) {
668        let cache = self.alignment_cache.read();
669        let cache_size = cache.len();
670        let avg_similarity = if cache.is_empty() {
671            0.0
672        } else {
673            cache.values().map(|a| a.similarity).sum::<f32>() / cache_size as f32
674        };
675        (cache_size, avg_similarity)
676    }
677
678    /// Evaluate cross-modal retrieval performance
679    pub fn evaluate_retrieval(
680        &self,
681        test_pairs: &[(Modality, Vector, Modality, Vector)],
682        distractors: &[(Modality, Vector)],
683        k_values: &[usize],
684    ) -> Result<HashMap<usize, f32>> {
685        let mut recall_at_k = HashMap::new();
686
687        for &k in k_values {
688            let mut total_recall = 0.0;
689
690            for (query_mod, query_emb, target_mod, target_emb) in test_pairs {
691                let mut candidates = vec![target_emb.clone()];
692                for (distractor_mod, distractor_emb) in distractors {
693                    if *distractor_mod == *target_mod {
694                        candidates.push(distractor_emb.clone());
695                    }
696                }
697
698                let results =
699                    self.cross_modal_search(*query_mod, query_emb, *target_mod, &candidates, k)?;
700
701                let found_target = results.iter().any(|(idx, _)| *idx == 0);
702                if found_target {
703                    total_recall += 1.0;
704                }
705            }
706
707            recall_at_k.insert(k, total_recall / test_pairs.len() as f32);
708        }
709
710        Ok(recall_at_k)
711    }
712}