1use super::joint_embedding_spaces_aligner::JointEmbeddingSpace;
7use super::joint_embedding_spaces_types::{
8 AudioAugmentation, ContrastiveOptimizer, ContrastivePairs, CurriculumLearning,
9 DataAugmentation, DifficultySchedule, ImageAugmentation, JointEmbeddingConfig,
10 LearningRateSchedule, PacingFunction, TextAugmentation,
11};
12use crate::cross_modal_embeddings::{
13 AudioData, ImageData, Modality, ModalityData, MultiModalContent, VideoData,
14};
15use crate::Vector;
16use anyhow::{anyhow, Result};
17use std::collections::HashMap;
18
19pub struct CLIPAligner {
25 pub(crate) joint_space: JointEmbeddingSpace,
26 pub(crate) optimizer: ContrastiveOptimizer,
27 pub(crate) data_augmentation: DataAugmentation,
28 pub(crate) curriculum: CurriculumLearning,
29}
30
31impl CLIPAligner {
32 pub fn new(config: JointEmbeddingConfig) -> Self {
33 let joint_space = JointEmbeddingSpace::new(config.clone());
34 let optimizer = ContrastiveOptimizer::new(config.learning_rate, 0.9, config.weight_decay);
35 let data_augmentation = DataAugmentation::default();
36 let curriculum = CurriculumLearning::new();
37
38 Self {
39 joint_space,
40 optimizer,
41 data_augmentation,
42 curriculum,
43 }
44 }
45
46 pub fn train_alignment(
48 &mut self,
49 training_data: &[(MultiModalContent, MultiModalContent)],
50 epochs: usize,
51 ) -> Result<Vec<f32>> {
52 let mut epoch_losses = Vec::new();
53
54 for epoch in 0..epochs {
55 let mut epoch_loss = 0.0;
56 let mut batch_count = 0;
57
58 for batch in training_data.chunks(self.joint_space.config.batch_size) {
59 let (positive_pairs, negative_pairs) = self.create_contrastive_pairs(batch)?;
60
61 let augmented_positive = self.augment_pairs(&positive_pairs)?;
62 let augmented_negative = self.augment_pairs(&negative_pairs)?;
63
64 let batch_loss = self
65 .joint_space
66 .contrastive_align(&augmented_positive, &augmented_negative)?;
67
68 epoch_loss += batch_loss;
69 batch_count += 1;
70
71 if self.curriculum.enabled {
72 self.curriculum.update_difficulty(batch_loss);
73 }
74 }
75
76 let avg_epoch_loss = epoch_loss / batch_count as f32;
77 epoch_losses.push(avg_epoch_loss);
78
79 self.optimizer.step_schedule();
80
81 tracing::info!(
82 "Epoch {}/{}: Average Loss = {:.4}, Temperature = {:.4}",
83 epoch + 1,
84 epochs,
85 avg_epoch_loss,
86 self.joint_space
87 .temperature_scheduler
88 .get_current_temperature()
89 );
90 }
91
92 Ok(epoch_losses)
93 }
94
95 fn create_contrastive_pairs(
96 &self,
97 batch: &[(MultiModalContent, MultiModalContent)],
98 ) -> Result<ContrastivePairs> {
99 let mut positive_pairs = Vec::new();
100 let mut negative_pairs = Vec::new();
101
102 for (content1, content2) in batch {
103 for (mod1, data1) in &content1.modalities {
104 for (mod2, data2) in &content2.modalities {
105 if let (Ok(emb1), Ok(emb2)) = (
106 self.extract_embedding(*mod1, data1),
107 self.extract_embedding(*mod2, data2),
108 ) {
109 positive_pairs.push((*mod1, emb1, *mod2, emb2));
110 }
111 }
112 }
113 }
114
115 let batch_size = batch.len();
116 for i in 0..batch_size {
117 for j in 0..batch_size {
118 if i != j {
119 let (content1, _) = &batch[i];
120 let (_, content2) = &batch[j];
121
122 for (mod1, data1) in &content1.modalities {
123 for (mod2, data2) in &content2.modalities {
124 if let (Ok(emb1), Ok(emb2)) = (
125 self.extract_embedding(*mod1, data1),
126 self.extract_embedding(*mod2, data2),
127 ) {
128 negative_pairs.push((*mod1, emb1, *mod2, emb2));
129 }
130 }
131 }
132 }
133 }
134 }
135
136 let max_negatives = positive_pairs.len() * self.joint_space.config.negative_samples;
137 negative_pairs.truncate(max_negatives);
138
139 Ok((positive_pairs, negative_pairs))
140 }
141
142 fn extract_embedding(&self, modality: Modality, data: &ModalityData) -> Result<Vector> {
143 match (modality, data) {
144 (Modality::Text, ModalityData::Text(text)) => {
145 let words: Vec<&str> = text.split_whitespace().collect();
146 let embedding = self.create_text_embedding(&words);
147 Ok(embedding)
148 }
149 (Modality::Image, ModalityData::Image(image)) => {
150 let embedding = self.create_image_embedding(image);
151 Ok(embedding)
152 }
153 (Modality::Audio, ModalityData::Audio(audio)) => {
154 let embedding = self.create_audio_embedding(audio);
155 Ok(embedding)
156 }
157 (Modality::Video, ModalityData::Video(video)) => {
158 let embedding = self.create_video_embedding(video);
159 Ok(embedding)
160 }
161 (Modality::Numeric, ModalityData::Numeric(values)) => Ok(Vector::new(values.clone())),
162 _ => Err(anyhow!("Modality-data type mismatch")),
163 }
164 }
165
166 pub(crate) fn create_text_embedding(&self, words: &[&str]) -> Vector {
167 let mut embedding = vec![0.0; 768];
168
169 for (i, word) in words.iter().enumerate().take(100) {
170 let hash = self.simple_hash(word) as usize;
171 let idx = hash % embedding.len();
172 embedding[idx] += 1.0 / (i + 1) as f32;
173 }
174
175 let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
176 if norm > 0.0 {
177 for value in &mut embedding {
178 *value /= norm;
179 }
180 }
181
182 Vector::new(embedding)
183 }
184
185 fn create_image_embedding(&self, image: &ImageData) -> Vector {
186 let mut embedding = vec![0.0; 2048];
187
188 let color_features = self.extract_color_features(image);
189 for (i, &feature) in color_features.iter().enumerate().take(256) {
190 if i < embedding.len() {
191 embedding[i] = feature;
192 }
193 }
194
195 let texture_features = self.extract_texture_features(image);
196 for (i, &feature) in texture_features.iter().enumerate().take(256) {
197 if i + 256 < embedding.len() {
198 embedding[i + 256] = feature;
199 }
200 }
201
202 Vector::new(embedding)
203 }
204
205 fn create_audio_embedding(&self, audio: &AudioData) -> Vector {
206 let mut embedding = vec![0.0; 1024];
207
208 if let Some(ref features) = audio.features {
209 for (i, &feature) in features.iter().enumerate().take(embedding.len()) {
210 embedding[i] = feature;
211 }
212 } else {
213 let spectral_features = self.extract_spectral_features(audio);
214 for (i, &feature) in spectral_features.iter().enumerate().take(embedding.len()) {
215 embedding[i] = feature;
216 }
217 }
218
219 Vector::new(embedding)
220 }
221
222 fn create_video_embedding(&self, video: &VideoData) -> Vector {
223 let mut embedding = vec![0.0; 1536];
224
225 if !video.frames.is_empty() {
226 let frame_embedding = self.create_image_embedding(&video.frames[0]);
227 let frame_values = frame_embedding.as_f32();
228 for (i, &value) in frame_values.iter().enumerate().take(1024) {
229 if i < embedding.len() {
230 embedding[i] = value;
231 }
232 }
233 }
234
235 if let Some(ref audio) = video.audio {
236 let audio_embedding = self.create_audio_embedding(audio);
237 let audio_values = audio_embedding.as_f32();
238 for (i, &value) in audio_values.iter().enumerate().take(512) {
239 if i + 1024 < embedding.len() {
240 embedding[i + 1024] = value;
241 }
242 }
243 }
244
245 Vector::new(embedding)
246 }
247
248 fn simple_hash(&self, text: &str) -> u64 {
249 let mut hash = 5381u64;
250 for byte in text.bytes() {
251 hash = hash.wrapping_mul(33).wrapping_add(byte as u64);
252 }
253 hash
254 }
255
256 fn extract_color_features(&self, image: &ImageData) -> Vec<f32> {
257 let mut histogram = vec![0.0; 256];
258
259 match image.format {
260 crate::cross_modal_embeddings::ImageFormat::RGB => {
261 for chunk in image.data.chunks(3) {
262 if chunk.len() == 3 {
263 let intensity = (chunk[0] as f32 + chunk[1] as f32 + chunk[2] as f32) / 3.0;
264 let bin = (intensity as usize).min(255);
265 histogram[bin] += 1.0;
266 }
267 }
268 }
269 _ => {
270 for &pixel in &image.data {
271 let bin = (pixel as usize).min(255);
272 histogram[bin] += 1.0;
273 }
274 }
275 }
276
277 let total: f32 = histogram.iter().sum();
278 if total > 0.0 {
279 for value in &mut histogram {
280 *value /= total;
281 }
282 }
283
284 histogram
285 }
286
287 fn extract_texture_features(&self, image: &ImageData) -> Vec<f32> {
288 let mut features = vec![0.0; 256];
289
290 let width = image.width as usize;
291 let height = image.height as usize;
292
293 if width > 2 && height > 2 {
294 for y in 1..height - 1 {
295 for x in 1..width - 1 {
296 let center_idx = y * width + x;
297 if center_idx < image.data.len() {
298 let center = image.data[center_idx];
299 let mut pattern = 0u8;
300
301 let neighbors = [
302 (-1i32, -1i32),
303 (0, -1),
304 (1, -1),
305 (-1, 0),
306 (1, 0),
307 (-1, 1),
308 (0, 1),
309 (1, 1),
310 ];
311
312 for (bit, (dx, dy)) in neighbors.iter().enumerate() {
313 let nx = (x as i32 + dx) as usize;
314 let ny = (y as i32 + dy) as usize;
315 let neighbor_idx = ny * width + nx;
316
317 if neighbor_idx < image.data.len() && image.data[neighbor_idx] > center
318 {
319 pattern |= 1 << bit;
320 }
321 }
322
323 features[pattern as usize] += 1.0;
324 }
325 }
326 }
327 }
328
329 let total: f32 = features.iter().sum();
330 if total > 0.0 {
331 for value in &mut features {
332 *value /= total;
333 }
334 }
335
336 features
337 }
338
339 fn extract_spectral_features(&self, audio: &AudioData) -> Vec<f32> {
340 let mut features = vec![0.0; 128];
341
342 if !audio.samples.is_empty() {
343 let chunk_size = audio.samples.len() / features.len();
344
345 for (i, feature) in features.iter_mut().enumerate() {
346 let start = i * chunk_size;
347 let end = ((i + 1) * chunk_size).min(audio.samples.len());
348
349 if start < end {
350 let chunk = &audio.samples[start..end];
351 let energy: f32 = chunk.iter().map(|x| x * x).sum();
352 *feature = energy.sqrt() / (chunk.len() as f32).sqrt();
353 }
354 }
355 }
356
357 features
358 }
359
360 fn augment_pairs(
361 &self,
362 pairs: &[(Modality, Vector, Modality, Vector)],
363 ) -> Result<Vec<(Modality, Vector, Modality, Vector)>> {
364 let mut augmented = Vec::new();
365
366 for (mod1, emb1, mod2, emb2) in pairs {
367 let aug_emb1 = self.add_noise(emb1, 0.01)?;
368 let aug_emb2 = self.add_noise(emb2, 0.01)?;
369 augmented.push((*mod1, aug_emb1, *mod2, aug_emb2));
370 }
371
372 Ok(augmented)
373 }
374
375 fn add_noise(&self, embedding: &Vector, noise_std: f32) -> Result<Vector> {
376 let values = embedding.as_f32();
377 let mut noisy_values = Vec::with_capacity(values.len());
378
379 for (i, &value) in values.iter().enumerate() {
380 let noise = ((i as f32 * 0.1234).sin() * noise_std).clamp(-0.1, 0.1);
381 noisy_values.push(value + noise);
382 }
383
384 Ok(Vector::new(noisy_values))
385 }
386}
387
388impl ContrastiveOptimizer {
393 pub fn new(learning_rate: f32, momentum: f32, weight_decay: f32) -> Self {
394 Self {
395 learning_rate,
396 momentum,
397 weight_decay,
398 gradient_history: HashMap::new(),
399 adaptive_lr: true,
400 lr_schedule: LearningRateSchedule::CosineAnnealing {
401 min_lr: learning_rate * 0.01,
402 max_epochs: 100,
403 },
404 }
405 }
406
407 pub fn step_schedule(&mut self) {
408 match self.lr_schedule {
409 LearningRateSchedule::StepDecay {
410 step_size: _,
411 gamma,
412 } => {
413 self.learning_rate *= gamma;
414 }
415 LearningRateSchedule::ExponentialDecay { gamma } => {
416 self.learning_rate *= gamma;
417 }
418 LearningRateSchedule::CosineAnnealing {
419 min_lr,
420 max_epochs: _,
421 } => {
422 let progress = 0.01;
423 let lr_range = self.learning_rate - min_lr;
424 self.learning_rate =
425 min_lr + lr_range * (1.0 + (std::f32::consts::PI * progress).cos()) / 2.0;
426 }
427 LearningRateSchedule::Constant => {}
428 }
429 }
430}
431
432impl Default for DataAugmentation {
437 fn default() -> Self {
438 Self {
439 text_augmentations: vec![
440 TextAugmentation::RandomWordDropout(0.1),
441 TextAugmentation::SynonymReplacement(0.1),
442 ],
443 image_augmentations: vec![
444 ImageAugmentation::RandomFlip {
445 horizontal: true,
446 vertical: false,
447 },
448 ImageAugmentation::ColorJitter {
449 brightness: 0.2,
450 contrast: 0.2,
451 saturation: 0.2,
452 },
453 ],
454 audio_augmentations: vec![
455 AudioAugmentation::AddNoise { snr_db: 20.0 },
456 AudioAugmentation::TimeStretch { factor: 1.1 },
457 ],
458 cross_modal_mixup: false,
459 augmentation_probability: 0.5,
460 }
461 }
462}
463
464impl Default for CurriculumLearning {
469 fn default() -> Self {
470 Self::new()
471 }
472}
473
474impl CurriculumLearning {
475 pub fn new() -> Self {
476 Self {
477 enabled: false,
478 current_difficulty: 0.0,
479 difficulty_schedule: DifficultySchedule::Linear {
480 start: 0.0,
481 end: 1.0,
482 epochs: 50,
483 },
484 pacing_function: PacingFunction::Root,
485 competence_threshold: 0.8,
486 }
487 }
488
489 pub fn update_difficulty(&mut self, loss: f32) {
490 if self.enabled {
491 if loss < self.competence_threshold {
492 self.current_difficulty = (self.current_difficulty + 0.01).min(1.0);
493 } else {
494 self.current_difficulty = (self.current_difficulty - 0.005).max(0.0);
495 }
496 }
497 }
498}
499
500pub fn zero_shot_retrieval(
502 space: &JointEmbeddingSpace,
503 query_modality: Modality,
504 query_embedding: &Vector,
505 target_modality: Modality,
506 target_embeddings: &[Vector],
507 top_k: usize,
508) -> Result<Vec<(usize, f32)>> {
509 space.cross_modal_search(
510 query_modality,
511 query_embedding,
512 target_modality,
513 target_embeddings,
514 top_k,
515 )
516}