Skip to main content

oxirs_vec/
joint_embedding_spaces_transfer.rs

1//! Cross-space transfer functions for Joint Embedding Spaces
2//!
3//! Zero-shot transfer, knowledge transfer, domain adaptation,
4//! CLIP-style contrastive training, data augmentation, curriculum learning.
5
6use super::joint_embedding_spaces_aligner::JointEmbeddingSpace;
7use super::joint_embedding_spaces_types::{
8    AudioAugmentation, ContrastiveOptimizer, ContrastivePairs, CurriculumLearning,
9    DataAugmentation, DifficultySchedule, ImageAugmentation, JointEmbeddingConfig,
10    LearningRateSchedule, PacingFunction, TextAugmentation,
11};
12use crate::cross_modal_embeddings::{
13    AudioData, ImageData, Modality, ModalityData, MultiModalContent, VideoData,
14};
15use crate::Vector;
16use anyhow::{anyhow, Result};
17use std::collections::HashMap;
18
19// ─────────────────────────────────────────────────────────────────────────────
20// CLIPAligner
21// ─────────────────────────────────────────────────────────────────────────────
22
23/// CLIP-style contrastive learning implementation
24pub struct CLIPAligner {
25    pub(crate) joint_space: JointEmbeddingSpace,
26    pub(crate) optimizer: ContrastiveOptimizer,
27    pub(crate) data_augmentation: DataAugmentation,
28    pub(crate) curriculum: CurriculumLearning,
29}
30
31impl CLIPAligner {
32    pub fn new(config: JointEmbeddingConfig) -> Self {
33        let joint_space = JointEmbeddingSpace::new(config.clone());
34        let optimizer = ContrastiveOptimizer::new(config.learning_rate, 0.9, config.weight_decay);
35        let data_augmentation = DataAugmentation::default();
36        let curriculum = CurriculumLearning::new();
37
38        Self {
39            joint_space,
40            optimizer,
41            data_augmentation,
42            curriculum,
43        }
44    }
45
46    /// Train CLIP-style alignment with contrastive learning
47    pub fn train_alignment(
48        &mut self,
49        training_data: &[(MultiModalContent, MultiModalContent)],
50        epochs: usize,
51    ) -> Result<Vec<f32>> {
52        let mut epoch_losses = Vec::new();
53
54        for epoch in 0..epochs {
55            let mut epoch_loss = 0.0;
56            let mut batch_count = 0;
57
58            for batch in training_data.chunks(self.joint_space.config.batch_size) {
59                let (positive_pairs, negative_pairs) = self.create_contrastive_pairs(batch)?;
60
61                let augmented_positive = self.augment_pairs(&positive_pairs)?;
62                let augmented_negative = self.augment_pairs(&negative_pairs)?;
63
64                let batch_loss = self
65                    .joint_space
66                    .contrastive_align(&augmented_positive, &augmented_negative)?;
67
68                epoch_loss += batch_loss;
69                batch_count += 1;
70
71                if self.curriculum.enabled {
72                    self.curriculum.update_difficulty(batch_loss);
73                }
74            }
75
76            let avg_epoch_loss = epoch_loss / batch_count as f32;
77            epoch_losses.push(avg_epoch_loss);
78
79            self.optimizer.step_schedule();
80
81            tracing::info!(
82                "Epoch {}/{}: Average Loss = {:.4}, Temperature = {:.4}",
83                epoch + 1,
84                epochs,
85                avg_epoch_loss,
86                self.joint_space
87                    .temperature_scheduler
88                    .get_current_temperature()
89            );
90        }
91
92        Ok(epoch_losses)
93    }
94
95    fn create_contrastive_pairs(
96        &self,
97        batch: &[(MultiModalContent, MultiModalContent)],
98    ) -> Result<ContrastivePairs> {
99        let mut positive_pairs = Vec::new();
100        let mut negative_pairs = Vec::new();
101
102        for (content1, content2) in batch {
103            for (mod1, data1) in &content1.modalities {
104                for (mod2, data2) in &content2.modalities {
105                    if let (Ok(emb1), Ok(emb2)) = (
106                        self.extract_embedding(*mod1, data1),
107                        self.extract_embedding(*mod2, data2),
108                    ) {
109                        positive_pairs.push((*mod1, emb1, *mod2, emb2));
110                    }
111                }
112            }
113        }
114
115        let batch_size = batch.len();
116        for i in 0..batch_size {
117            for j in 0..batch_size {
118                if i != j {
119                    let (content1, _) = &batch[i];
120                    let (_, content2) = &batch[j];
121
122                    for (mod1, data1) in &content1.modalities {
123                        for (mod2, data2) in &content2.modalities {
124                            if let (Ok(emb1), Ok(emb2)) = (
125                                self.extract_embedding(*mod1, data1),
126                                self.extract_embedding(*mod2, data2),
127                            ) {
128                                negative_pairs.push((*mod1, emb1, *mod2, emb2));
129                            }
130                        }
131                    }
132                }
133            }
134        }
135
136        let max_negatives = positive_pairs.len() * self.joint_space.config.negative_samples;
137        negative_pairs.truncate(max_negatives);
138
139        Ok((positive_pairs, negative_pairs))
140    }
141
142    fn extract_embedding(&self, modality: Modality, data: &ModalityData) -> Result<Vector> {
143        match (modality, data) {
144            (Modality::Text, ModalityData::Text(text)) => {
145                let words: Vec<&str> = text.split_whitespace().collect();
146                let embedding = self.create_text_embedding(&words);
147                Ok(embedding)
148            }
149            (Modality::Image, ModalityData::Image(image)) => {
150                let embedding = self.create_image_embedding(image);
151                Ok(embedding)
152            }
153            (Modality::Audio, ModalityData::Audio(audio)) => {
154                let embedding = self.create_audio_embedding(audio);
155                Ok(embedding)
156            }
157            (Modality::Video, ModalityData::Video(video)) => {
158                let embedding = self.create_video_embedding(video);
159                Ok(embedding)
160            }
161            (Modality::Numeric, ModalityData::Numeric(values)) => Ok(Vector::new(values.clone())),
162            _ => Err(anyhow!("Modality-data type mismatch")),
163        }
164    }
165
166    pub(crate) fn create_text_embedding(&self, words: &[&str]) -> Vector {
167        let mut embedding = vec![0.0; 768];
168
169        for (i, word) in words.iter().enumerate().take(100) {
170            let hash = self.simple_hash(word) as usize;
171            let idx = hash % embedding.len();
172            embedding[idx] += 1.0 / (i + 1) as f32;
173        }
174
175        let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
176        if norm > 0.0 {
177            for value in &mut embedding {
178                *value /= norm;
179            }
180        }
181
182        Vector::new(embedding)
183    }
184
185    fn create_image_embedding(&self, image: &ImageData) -> Vector {
186        let mut embedding = vec![0.0; 2048];
187
188        let color_features = self.extract_color_features(image);
189        for (i, &feature) in color_features.iter().enumerate().take(256) {
190            if i < embedding.len() {
191                embedding[i] = feature;
192            }
193        }
194
195        let texture_features = self.extract_texture_features(image);
196        for (i, &feature) in texture_features.iter().enumerate().take(256) {
197            if i + 256 < embedding.len() {
198                embedding[i + 256] = feature;
199            }
200        }
201
202        Vector::new(embedding)
203    }
204
205    fn create_audio_embedding(&self, audio: &AudioData) -> Vector {
206        let mut embedding = vec![0.0; 1024];
207
208        if let Some(ref features) = audio.features {
209            for (i, &feature) in features.iter().enumerate().take(embedding.len()) {
210                embedding[i] = feature;
211            }
212        } else {
213            let spectral_features = self.extract_spectral_features(audio);
214            for (i, &feature) in spectral_features.iter().enumerate().take(embedding.len()) {
215                embedding[i] = feature;
216            }
217        }
218
219        Vector::new(embedding)
220    }
221
222    fn create_video_embedding(&self, video: &VideoData) -> Vector {
223        let mut embedding = vec![0.0; 1536];
224
225        if !video.frames.is_empty() {
226            let frame_embedding = self.create_image_embedding(&video.frames[0]);
227            let frame_values = frame_embedding.as_f32();
228            for (i, &value) in frame_values.iter().enumerate().take(1024) {
229                if i < embedding.len() {
230                    embedding[i] = value;
231                }
232            }
233        }
234
235        if let Some(ref audio) = video.audio {
236            let audio_embedding = self.create_audio_embedding(audio);
237            let audio_values = audio_embedding.as_f32();
238            for (i, &value) in audio_values.iter().enumerate().take(512) {
239                if i + 1024 < embedding.len() {
240                    embedding[i + 1024] = value;
241                }
242            }
243        }
244
245        Vector::new(embedding)
246    }
247
248    fn simple_hash(&self, text: &str) -> u64 {
249        let mut hash = 5381u64;
250        for byte in text.bytes() {
251            hash = hash.wrapping_mul(33).wrapping_add(byte as u64);
252        }
253        hash
254    }
255
256    fn extract_color_features(&self, image: &ImageData) -> Vec<f32> {
257        let mut histogram = vec![0.0; 256];
258
259        match image.format {
260            crate::cross_modal_embeddings::ImageFormat::RGB => {
261                for chunk in image.data.chunks(3) {
262                    if chunk.len() == 3 {
263                        let intensity = (chunk[0] as f32 + chunk[1] as f32 + chunk[2] as f32) / 3.0;
264                        let bin = (intensity as usize).min(255);
265                        histogram[bin] += 1.0;
266                    }
267                }
268            }
269            _ => {
270                for &pixel in &image.data {
271                    let bin = (pixel as usize).min(255);
272                    histogram[bin] += 1.0;
273                }
274            }
275        }
276
277        let total: f32 = histogram.iter().sum();
278        if total > 0.0 {
279            for value in &mut histogram {
280                *value /= total;
281            }
282        }
283
284        histogram
285    }
286
287    fn extract_texture_features(&self, image: &ImageData) -> Vec<f32> {
288        let mut features = vec![0.0; 256];
289
290        let width = image.width as usize;
291        let height = image.height as usize;
292
293        if width > 2 && height > 2 {
294            for y in 1..height - 1 {
295                for x in 1..width - 1 {
296                    let center_idx = y * width + x;
297                    if center_idx < image.data.len() {
298                        let center = image.data[center_idx];
299                        let mut pattern = 0u8;
300
301                        let neighbors = [
302                            (-1i32, -1i32),
303                            (0, -1),
304                            (1, -1),
305                            (-1, 0),
306                            (1, 0),
307                            (-1, 1),
308                            (0, 1),
309                            (1, 1),
310                        ];
311
312                        for (bit, (dx, dy)) in neighbors.iter().enumerate() {
313                            let nx = (x as i32 + dx) as usize;
314                            let ny = (y as i32 + dy) as usize;
315                            let neighbor_idx = ny * width + nx;
316
317                            if neighbor_idx < image.data.len() && image.data[neighbor_idx] > center
318                            {
319                                pattern |= 1 << bit;
320                            }
321                        }
322
323                        features[pattern as usize] += 1.0;
324                    }
325                }
326            }
327        }
328
329        let total: f32 = features.iter().sum();
330        if total > 0.0 {
331            for value in &mut features {
332                *value /= total;
333            }
334        }
335
336        features
337    }
338
339    fn extract_spectral_features(&self, audio: &AudioData) -> Vec<f32> {
340        let mut features = vec![0.0; 128];
341
342        if !audio.samples.is_empty() {
343            let chunk_size = audio.samples.len() / features.len();
344
345            for (i, feature) in features.iter_mut().enumerate() {
346                let start = i * chunk_size;
347                let end = ((i + 1) * chunk_size).min(audio.samples.len());
348
349                if start < end {
350                    let chunk = &audio.samples[start..end];
351                    let energy: f32 = chunk.iter().map(|x| x * x).sum();
352                    *feature = energy.sqrt() / (chunk.len() as f32).sqrt();
353                }
354            }
355        }
356
357        features
358    }
359
360    fn augment_pairs(
361        &self,
362        pairs: &[(Modality, Vector, Modality, Vector)],
363    ) -> Result<Vec<(Modality, Vector, Modality, Vector)>> {
364        let mut augmented = Vec::new();
365
366        for (mod1, emb1, mod2, emb2) in pairs {
367            let aug_emb1 = self.add_noise(emb1, 0.01)?;
368            let aug_emb2 = self.add_noise(emb2, 0.01)?;
369            augmented.push((*mod1, aug_emb1, *mod2, aug_emb2));
370        }
371
372        Ok(augmented)
373    }
374
375    fn add_noise(&self, embedding: &Vector, noise_std: f32) -> Result<Vector> {
376        let values = embedding.as_f32();
377        let mut noisy_values = Vec::with_capacity(values.len());
378
379        for (i, &value) in values.iter().enumerate() {
380            let noise = ((i as f32 * 0.1234).sin() * noise_std).clamp(-0.1, 0.1);
381            noisy_values.push(value + noise);
382        }
383
384        Ok(Vector::new(noisy_values))
385    }
386}
387
388// ─────────────────────────────────────────────────────────────────────────────
389// ContrastiveOptimizer impl
390// ─────────────────────────────────────────────────────────────────────────────
391
392impl ContrastiveOptimizer {
393    pub fn new(learning_rate: f32, momentum: f32, weight_decay: f32) -> Self {
394        Self {
395            learning_rate,
396            momentum,
397            weight_decay,
398            gradient_history: HashMap::new(),
399            adaptive_lr: true,
400            lr_schedule: LearningRateSchedule::CosineAnnealing {
401                min_lr: learning_rate * 0.01,
402                max_epochs: 100,
403            },
404        }
405    }
406
407    pub fn step_schedule(&mut self) {
408        match self.lr_schedule {
409            LearningRateSchedule::StepDecay {
410                step_size: _,
411                gamma,
412            } => {
413                self.learning_rate *= gamma;
414            }
415            LearningRateSchedule::ExponentialDecay { gamma } => {
416                self.learning_rate *= gamma;
417            }
418            LearningRateSchedule::CosineAnnealing {
419                min_lr,
420                max_epochs: _,
421            } => {
422                let progress = 0.01;
423                let lr_range = self.learning_rate - min_lr;
424                self.learning_rate =
425                    min_lr + lr_range * (1.0 + (std::f32::consts::PI * progress).cos()) / 2.0;
426            }
427            LearningRateSchedule::Constant => {}
428        }
429    }
430}
431
432// ─────────────────────────────────────────────────────────────────────────────
433// DataAugmentation impl
434// ─────────────────────────────────────────────────────────────────────────────
435
436impl Default for DataAugmentation {
437    fn default() -> Self {
438        Self {
439            text_augmentations: vec![
440                TextAugmentation::RandomWordDropout(0.1),
441                TextAugmentation::SynonymReplacement(0.1),
442            ],
443            image_augmentations: vec![
444                ImageAugmentation::RandomFlip {
445                    horizontal: true,
446                    vertical: false,
447                },
448                ImageAugmentation::ColorJitter {
449                    brightness: 0.2,
450                    contrast: 0.2,
451                    saturation: 0.2,
452                },
453            ],
454            audio_augmentations: vec![
455                AudioAugmentation::AddNoise { snr_db: 20.0 },
456                AudioAugmentation::TimeStretch { factor: 1.1 },
457            ],
458            cross_modal_mixup: false,
459            augmentation_probability: 0.5,
460        }
461    }
462}
463
464// ─────────────────────────────────────────────────────────────────────────────
465// CurriculumLearning impl
466// ─────────────────────────────────────────────────────────────────────────────
467
468impl Default for CurriculumLearning {
469    fn default() -> Self {
470        Self::new()
471    }
472}
473
474impl CurriculumLearning {
475    pub fn new() -> Self {
476        Self {
477            enabled: false,
478            current_difficulty: 0.0,
479            difficulty_schedule: DifficultySchedule::Linear {
480                start: 0.0,
481                end: 1.0,
482                epochs: 50,
483            },
484            pacing_function: PacingFunction::Root,
485            competence_threshold: 0.8,
486        }
487    }
488
489    pub fn update_difficulty(&mut self, loss: f32) {
490        if self.enabled {
491            if loss < self.competence_threshold {
492                self.current_difficulty = (self.current_difficulty + 0.01).min(1.0);
493            } else {
494                self.current_difficulty = (self.current_difficulty - 0.005).max(0.0);
495            }
496        }
497    }
498}
499
500/// Zero-shot cross-modal retrieval helper (delegates to JointEmbeddingSpace)
501pub fn zero_shot_retrieval(
502    space: &JointEmbeddingSpace,
503    query_modality: Modality,
504    query_embedding: &Vector,
505    target_modality: Modality,
506    target_embeddings: &[Vector],
507    top_k: usize,
508) -> Result<Vec<(usize, f32)>> {
509    space.cross_modal_search(
510        query_modality,
511        query_embedding,
512        target_modality,
513        target_embeddings,
514        top_k,
515    )
516}