1use super::joint_embedding_spaces_types::{
6 ActivationFunction, AlignmentPair, CrossModalAttention, DomainAdapter, DomainStatistics,
7 JointEmbeddingConfig, LinearProjector, ScheduleType, TemperatureScheduler, TrainingStatistics,
8};
9use crate::{cross_modal_embeddings::Modality, Vector};
10use anyhow::{anyhow, Result};
11use parking_lot::RwLock;
12use std::collections::HashMap;
13use std::sync::Arc;
14
15impl LinearProjector {
20 pub fn new(
21 input_dim: usize,
22 output_dim: usize,
23 dropout_rate: f32,
24 activation: ActivationFunction,
25 ) -> Self {
26 let limit = (6.0 / (input_dim + output_dim) as f32).sqrt();
28 let mut weights = Vec::with_capacity(output_dim);
29
30 for _ in 0..output_dim {
31 let mut row = Vec::with_capacity(input_dim);
32 for _ in 0..input_dim {
33 let weight = ((row.len() as f32 * 0.01) % 2.0 - 1.0) * limit;
34 row.push(weight);
35 }
36 weights.push(row);
37 }
38
39 let bias = vec![0.0; output_dim];
40
41 Self {
42 weights,
43 bias,
44 input_dim,
45 output_dim,
46 dropout_rate,
47 activation,
48 }
49 }
50
51 pub fn forward(&self, input: &Vector) -> Result<Vector> {
52 if input.dimensions != self.input_dim {
53 return Err(anyhow!(
54 "Input dimension mismatch: expected {}, got {}",
55 self.input_dim,
56 input.dimensions
57 ));
58 }
59
60 let input_values = input.as_f32();
61 let mut output = vec![0.0; self.output_dim];
62
63 for (i, output_val) in output.iter_mut().enumerate().take(self.output_dim) {
64 let mut sum = self.bias[i];
65 for (j, &input_val) in input_values.iter().enumerate().take(self.input_dim) {
66 sum += input_val * self.weights[i][j];
67 }
68 *output_val = sum;
69 }
70
71 for value in &mut output {
72 *value = self.apply_activation(*value);
73 }
74
75 if self.dropout_rate > 0.0 {
76 for (i, value) in output.iter_mut().enumerate() {
77 if (i as f32 * 0.12345) % 1.0 < self.dropout_rate {
78 *value = 0.0;
79 } else {
80 *value /= 1.0 - self.dropout_rate;
81 }
82 }
83 }
84
85 Ok(Vector::new(output))
86 }
87
88 fn apply_activation(&self, x: f32) -> f32 {
89 match self.activation {
90 ActivationFunction::ReLU => x.max(0.0),
91 ActivationFunction::GELU => {
92 let sqrt_2_pi = (2.0 / std::f32::consts::PI).sqrt();
93 let inner = sqrt_2_pi * (x + 0.044715 * x.powi(3));
94 0.5 * x * (1.0 + inner.tanh())
95 }
96 ActivationFunction::Tanh => x.tanh(),
97 ActivationFunction::Sigmoid => 1.0 / (1.0 + (-x).exp()),
98 ActivationFunction::Swish => x * (1.0 / (1.0 + (-x).exp())),
99 ActivationFunction::Mish => x * (1.0 + x.exp()).ln().tanh(),
100 ActivationFunction::LeakyReLU(alpha) => {
101 if x > 0.0 {
102 x
103 } else {
104 alpha * x
105 }
106 }
107 }
108 }
109
110 pub fn update_weights(&mut self, gradients: &[Vec<f32>], learning_rate: f32) {
111 for i in 0..self.output_dim {
112 for j in 0..self.input_dim {
113 if i < gradients.len() && j < gradients[i].len() {
114 self.weights[i][j] -= learning_rate * gradients[i][j];
115 }
116 }
117 }
118 }
119}
120
121impl CrossModalAttention {
126 pub fn new(
127 input_dim: usize,
128 num_heads: usize,
129 dropout_rate: f32,
130 enable_relative_pos: bool,
131 ) -> Self {
132 let head_dim = input_dim / num_heads;
133 let scale = 1.0 / (head_dim as f32).sqrt();
134
135 Self {
136 query_projector: LinearProjector::new(
137 input_dim,
138 input_dim,
139 dropout_rate,
140 ActivationFunction::ReLU,
141 ),
142 key_projector: LinearProjector::new(
143 input_dim,
144 input_dim,
145 dropout_rate,
146 ActivationFunction::ReLU,
147 ),
148 value_projector: LinearProjector::new(
149 input_dim,
150 input_dim,
151 dropout_rate,
152 ActivationFunction::ReLU,
153 ),
154 output_projector: LinearProjector::new(
155 input_dim,
156 input_dim,
157 dropout_rate,
158 ActivationFunction::ReLU,
159 ),
160 num_heads,
161 head_dim,
162 dropout_rate,
163 scale,
164 enable_relative_pos,
165 }
166 }
167
168 pub fn cross_attention(
169 &self,
170 query_modality: &Vector,
171 key_modality: &Vector,
172 value_modality: &Vector,
173 ) -> Result<Vector> {
174 let query = self.query_projector.forward(query_modality)?;
175 let key = self.key_projector.forward(key_modality)?;
176 let value = self.value_projector.forward(value_modality)?;
177
178 let attended = self.multi_head_attention(&query, &key, &value)?;
179 self.output_projector.forward(&attended)
180 }
181
182 fn multi_head_attention(&self, query: &Vector, key: &Vector, value: &Vector) -> Result<Vector> {
183 let query_vals = query.as_f32();
184 let key_vals = key.as_f32();
185 let value_vals = value.as_f32();
186
187 if query_vals.len() != key_vals.len() || key_vals.len() != value_vals.len() {
188 return Err(anyhow!("Dimension mismatch in attention"));
189 }
190
191 let _seq_len = query_vals.len() / self.head_dim;
192 let mut output = vec![0.0; query_vals.len()];
193
194 for head in 0..self.num_heads {
195 let head_start = head * self.head_dim;
196 let head_end = head_start + self.head_dim;
197
198 let head_query = &query_vals[head_start..head_end];
199 let head_key = &key_vals[head_start..head_end];
200 let head_value = &value_vals[head_start..head_end];
201
202 let attention_score = self.compute_attention_score(head_query, head_key);
203
204 for i in 0..self.head_dim {
205 output[head_start + i] = head_value[i] * attention_score;
206 }
207 }
208
209 if self.enable_relative_pos {
210 self.apply_relative_position_encoding(&mut output)?;
211 }
212
213 Ok(Vector::new(output))
214 }
215
216 fn compute_attention_score(&self, query: &[f32], key: &[f32]) -> f32 {
217 let dot_product: f32 = query.iter().zip(key.iter()).map(|(q, k)| q * k).sum();
218 let scaled_score = dot_product * self.scale;
219 scaled_score.tanh()
220 }
221
222 fn apply_relative_position_encoding(&self, output: &mut [f32]) -> Result<()> {
223 let output_len = output.len();
224 for (i, value) in output.iter_mut().enumerate() {
225 let pos_encoding = (i as f32 / output_len as f32).sin();
226 *value += 0.1 * pos_encoding;
227 }
228 Ok(())
229 }
230}
231
232impl TemperatureScheduler {
237 pub fn new(
238 initial_temperature: f32,
239 final_temperature: f32,
240 decay_steps: usize,
241 schedule_type: ScheduleType,
242 ) -> Self {
243 Self {
244 initial_temperature,
245 final_temperature,
246 decay_steps,
247 current_step: 0,
248 schedule_type,
249 }
250 }
251
252 pub fn get_current_temperature(&self) -> f32 {
253 if self.current_step >= self.decay_steps {
254 return self.final_temperature;
255 }
256
257 let progress = self.current_step as f32 / self.decay_steps as f32;
258
259 match self.schedule_type {
260 ScheduleType::Linear => {
261 self.initial_temperature
262 + (self.final_temperature - self.initial_temperature) * progress
263 }
264 ScheduleType::Exponential => {
265 self.initial_temperature
266 * (self.final_temperature / self.initial_temperature).powf(progress)
267 }
268 ScheduleType::Cosine => {
269 let cosine_progress = 0.5 * (1.0 + (std::f32::consts::PI * progress).cos());
270 self.final_temperature
271 + (self.initial_temperature - self.final_temperature) * cosine_progress
272 }
273 ScheduleType::Warmup => {
274 if progress < 0.1 {
275 self.initial_temperature * (progress / 0.1)
276 } else {
277 let decay_progress = (progress - 0.1) / 0.9;
278 self.initial_temperature
279 + (self.final_temperature - self.initial_temperature) * decay_progress
280 }
281 }
282 }
283 }
284
285 pub fn step(&mut self) {
286 self.current_step += 1;
287 }
288}
289
290impl DomainAdapter {
295 pub fn new(adaptation_strength: f32) -> Self {
296 Self {
297 source_stats: DomainStatistics::default(),
298 target_stats: DomainStatistics::default(),
299 adaptation_weights: Vec::new(),
300 domain_classifier: None,
301 adaptation_strength,
302 }
303 }
304
305 pub fn adapt_embedding(&self, embedding: &Vector, is_source_domain: bool) -> Result<Vector> {
306 let input_values = embedding.as_f32();
307 let mut adapted_values = input_values.clone();
308
309 if self.adaptation_weights.len() != input_values.len() {
310 return Ok(embedding.clone());
311 }
312
313 let stats = if is_source_domain {
314 &self.source_stats
315 } else {
316 &self.target_stats
317 };
318
319 for (i, adapted_value) in adapted_values.iter_mut().enumerate() {
320 if i < stats.mean.len() && i < stats.variance.len() {
321 let normalized =
322 (*adapted_value - stats.mean[i]) / (stats.variance[i].sqrt() + 1e-8);
323 *adapted_value = normalized * self.adaptation_weights[i] * self.adaptation_strength
324 + *adapted_value * (1.0 - self.adaptation_strength);
325 }
326 }
327
328 Ok(Vector::new(adapted_values))
329 }
330
331 pub fn update_domain_statistics(&mut self, embeddings: &[Vector], is_source_domain: bool) {
332 let stats = if is_source_domain {
333 &mut self.source_stats
334 } else {
335 &mut self.target_stats
336 };
337
338 if embeddings.is_empty() {
339 return;
340 }
341
342 let dim = embeddings[0].dimensions;
343 if stats.mean.len() != dim {
344 stats.mean = vec![0.0; dim];
345 stats.variance = vec![0.0; dim];
346 stats.sample_count = 0;
347 }
348
349 for embedding in embeddings {
350 let values = embedding.as_f32();
351 for (i, &value) in values.iter().enumerate().take(dim) {
352 let delta = value - stats.mean[i];
353 stats.sample_count += 1;
354 stats.mean[i] += delta / stats.sample_count as f32;
355 let delta2 = value - stats.mean[i];
356 stats.variance[i] += delta * delta2;
357 }
358 }
359
360 if stats.sample_count > 1 {
361 for variance in &mut stats.variance {
362 *variance /= (stats.sample_count - 1) as f32;
363 }
364 }
365
366 self.update_adaptation_weights();
367 }
368
369 fn update_adaptation_weights(&mut self) {
370 let dim = self.source_stats.mean.len();
371 if dim == 0 || dim != self.target_stats.mean.len() {
372 return;
373 }
374
375 self.adaptation_weights = vec![1.0; dim];
376
377 for i in 0..dim {
378 let mean_diff = (self.source_stats.mean[i] - self.target_stats.mean[i]).abs();
379 let var_ratio = (self.source_stats.variance[i]
380 / (self.target_stats.variance[i] + 1e-8))
381 .ln()
382 .abs();
383
384 let discrepancy = mean_diff + 0.5 * var_ratio;
385 self.adaptation_weights[i] = 1.0 / (1.0 + discrepancy);
386 }
387 }
388}
389
/// Joint embedding space: per-modality projectors into a shared dimension, a cross-modal
/// attention module, and the bookkeeping used while aligning modalities.
pub struct JointEmbeddingSpace {
    /// Configuration passed to `new`; its joint dimension, temperature, margin and
    /// alignment strength are read by the methods of this type.
    pub(crate) config: JointEmbeddingConfig,
    /// Projector used for `Modality::Text` and for modalities without a dedicated projector.
    pub(crate) text_projector: LinearProjector,
    /// Projector used for `Modality::Image`.
    pub(crate) image_projector: LinearProjector,
    /// Projector used for `Modality::Audio`.
    pub(crate) audio_projector: LinearProjector,
    /// Projector used for `Modality::Video`.
    pub(crate) video_projector: LinearProjector,
    /// Attention applied between joint-space embeddings of different modalities.
    pub(crate) attention_mechanism: CrossModalAttention,
    /// Alignments recorded for positive pairs, keyed by a string built from the two
    /// modalities and the similarity.
    pub(crate) alignment_cache: Arc<RwLock<HashMap<String, AlignmentPair>>>,
    /// Running counters and average loss updated by contrastive alignment.
    pub(crate) training_stats: Arc<RwLock<TrainingStatistics>>,
    /// Supplies the temperature that divides similarities in the contrastive loss.
    pub(crate) temperature_scheduler: TemperatureScheduler,
    /// Domain adapter created from the configured alignment strength.
    pub(crate) domain_adapter: DomainAdapter,
}
407
408impl JointEmbeddingSpace {
409 pub fn new(config: JointEmbeddingConfig) -> Self {
410 let text_projector =
411 LinearProjector::new(768, config.joint_dim, 0.1, ActivationFunction::GELU);
412
413 let image_projector =
414 LinearProjector::new(2048, config.joint_dim, 0.1, ActivationFunction::GELU);
415
416 let audio_projector =
417 LinearProjector::new(1024, config.joint_dim, 0.1, ActivationFunction::GELU);
418
419 let video_projector =
420 LinearProjector::new(1536, config.joint_dim, 0.1, ActivationFunction::GELU);
421
422 let attention_mechanism = CrossModalAttention::new(config.joint_dim, 8, 0.1, true);
423
424 let temperature_scheduler = TemperatureScheduler::new(
425 config.temperature * 2.0,
426 config.temperature,
427 1000,
428 ScheduleType::Cosine,
429 );
430
431 let domain_adapter = DomainAdapter::new(config.alignment_strength);
432
433 Self {
434 config,
435 text_projector,
436 image_projector,
437 audio_projector,
438 video_projector,
439 attention_mechanism,
440 alignment_cache: Arc::new(RwLock::new(HashMap::new())),
441 training_stats: Arc::new(RwLock::new(TrainingStatistics::default())),
442 temperature_scheduler,
443 domain_adapter,
444 }
445 }
446
447 pub fn project_to_joint_space(&self, modality: Modality, embedding: &Vector) -> Result<Vector> {
449 let projected = match modality {
450 Modality::Text => self.text_projector.forward(embedding)?,
451 Modality::Image => self.image_projector.forward(embedding)?,
452 Modality::Audio => self.audio_projector.forward(embedding)?,
453 Modality::Video => self.video_projector.forward(embedding)?,
454 _ => self.text_projector.forward(embedding)?,
455 };
456
457 Ok(projected.normalized())
458 }
459
460 pub fn cross_modal_similarity(
462 &self,
463 modality1: Modality,
464 embedding1: &Vector,
465 modality2: Modality,
466 embedding2: &Vector,
467 ) -> Result<f32> {
468 let joint_emb1 = self.project_to_joint_space(modality1, embedding1)?;
469 let joint_emb2 = self.project_to_joint_space(modality2, embedding2)?;
470
471 if modality1 != modality2 {
472 let attended_emb1 =
473 self.attention_mechanism
474 .cross_attention(&joint_emb1, &joint_emb2, &joint_emb2)?;
475 let attended_emb2 =
476 self.attention_mechanism
477 .cross_attention(&joint_emb2, &joint_emb1, &joint_emb1)?;
478
479 attended_emb1.cosine_similarity(&attended_emb2)
480 } else {
481 joint_emb1.cosine_similarity(&joint_emb2)
482 }
483 }
484
485 pub fn contrastive_align(
487 &mut self,
488 positive_pairs: &[(Modality, Vector, Modality, Vector)],
489 negative_pairs: &[(Modality, Vector, Modality, Vector)],
490 ) -> Result<f32> {
491 let mut total_loss = 0.0;
492 let temperature = self.temperature_scheduler.get_current_temperature();
493
494 for (mod1, emb1, mod2, emb2) in positive_pairs {
495 let similarity = self.cross_modal_similarity(*mod1, emb1, *mod2, emb2)?;
496 let positive_score = similarity / temperature;
497 let positive_loss = -positive_score.ln_1p();
498 total_loss += positive_loss;
499
500 self.cache_alignment(*mod1, emb1.clone(), *mod2, emb2.clone(), similarity);
501 }
502
503 for (mod1, emb1, mod2, emb2) in negative_pairs {
504 let similarity = self.cross_modal_similarity(*mod1, emb1, *mod2, emb2)?;
505 let negative_score = similarity / temperature;
506 let negative_loss = (negative_score + self.config.margin).max(0.0);
507 total_loss += negative_loss;
508 }
509
510 self.update_training_stats(positive_pairs.len(), negative_pairs.len(), total_loss);
511 self.temperature_scheduler.step();
512
513 Ok(total_loss / (positive_pairs.len() + negative_pairs.len()) as f32)
514 }
515
516 pub fn zero_shot_retrieval(
518 &self,
519 query_modality: Modality,
520 query_embedding: &Vector,
521 target_modality: Modality,
522 target_embeddings: &[Vector],
523 top_k: usize,
524 ) -> Result<Vec<(usize, f32)>> {
525 let _query_joint = self.project_to_joint_space(query_modality, query_embedding)?;
527
528 self.cross_modal_search(
530 query_modality,
531 query_embedding,
532 target_modality,
533 target_embeddings,
534 top_k,
535 )
536 }
537
538 pub fn cross_modal_search(
540 &self,
541 query_modality: Modality,
542 query_embedding: &Vector,
543 candidate_modality: Modality,
544 candidate_embeddings: &[Vector],
545 top_k: usize,
546 ) -> Result<Vec<(usize, f32)>> {
547 let query_joint = self.project_to_joint_space(query_modality, query_embedding)?;
548 let mut similarities = Vec::new();
549
550 for (idx, candidate) in candidate_embeddings.iter().enumerate() {
551 let candidate_joint = self.project_to_joint_space(candidate_modality, candidate)?;
552
553 let similarity = if query_modality != candidate_modality {
554 let attended_query = self.attention_mechanism.cross_attention(
555 &query_joint,
556 &candidate_joint,
557 &candidate_joint,
558 )?;
559 attended_query.cosine_similarity(&candidate_joint)?
560 } else {
561 query_joint.cosine_similarity(&candidate_joint)?
562 };
563
564 similarities.push((idx, similarity));
565 }
566
567 similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
568 similarities.truncate(top_k);
569
570 Ok(similarities)
571 }
572
573 pub fn multi_modal_fusion(&self, modalities: &[(Modality, Vector)]) -> Result<Vector> {
575 if modalities.is_empty() {
576 return Err(anyhow!("No modalities provided for fusion"));
577 }
578
579 let mut joint_embeddings = Vec::new();
580 for (modality, embedding) in modalities {
581 let joint_emb = self.project_to_joint_space(*modality, embedding)?;
582 joint_embeddings.push(joint_emb);
583 }
584
585 let mut attended_embeddings = Vec::new();
586 for i in 0..joint_embeddings.len() {
587 let mut attended = joint_embeddings[i].clone();
588
589 for j in 0..joint_embeddings.len() {
590 if i != j {
591 let cross_attended = self.attention_mechanism.cross_attention(
592 &joint_embeddings[i],
593 &joint_embeddings[j],
594 &joint_embeddings[j],
595 )?;
596
597 let weight = 1.0 / joint_embeddings.len() as f32;
598 attended = attended.add(&cross_attended.scale(weight))?;
599 }
600 }
601
602 attended_embeddings.push(attended);
603 }
604
605 if attended_embeddings.len() == 1 {
606 Ok(attended_embeddings[0].clone())
607 } else {
608 let mut fused = attended_embeddings[0].clone();
609 for embedding in attended_embeddings.iter().skip(1) {
610 fused = fused.add(embedding)?;
611 }
612 Ok(fused.scale(1.0 / attended_embeddings.len() as f32))
613 }
614 }
615
616 pub(crate) fn cache_alignment(
617 &self,
618 mod1: Modality,
619 emb1: Vector,
620 mod2: Modality,
621 emb2: Vector,
622 similarity: f32,
623 ) {
624 let alignment = AlignmentPair {
625 modality1: mod1,
626 modality2: mod2,
627 embedding1: emb1,
628 embedding2: emb2,
629 similarity,
630 confidence: similarity.abs(),
631 timestamp: std::time::SystemTime::now(),
632 };
633
634 let cache_key = format!("{mod1:?}_{mod2:?}_{similarity}");
635 let mut cache = self.alignment_cache.write();
636 cache.insert(cache_key, alignment);
637
638 if cache.len() > 10000 {
639 let mut entries: Vec<_> = cache.iter().collect();
640 entries.sort_by_key(|(_, v)| v.timestamp);
641 let oldest_key = entries[0].0.clone();
642 cache.remove(&oldest_key);
643 }
644 }
645
646 pub(crate) fn update_training_stats(
647 &self,
648 positive_count: usize,
649 negative_count: usize,
650 loss: f32,
651 ) {
652 let mut stats = self.training_stats.write();
653 stats.total_samples += (positive_count + negative_count) as u64;
654 stats.positive_pairs += positive_count as u64;
655 stats.negative_pairs += negative_count as u64;
656
657 let total_samples = stats.total_samples as f32;
658 stats.average_loss = (stats.average_loss * (total_samples - 1.0) + loss) / total_samples;
659 }
660
661 pub fn get_training_stats(&self) -> TrainingStatistics {
663 self.training_stats.read().clone()
664 }
665
666 pub fn get_cache_stats(&self) -> (usize, f32) {
668 let cache = self.alignment_cache.read();
669 let cache_size = cache.len();
670 let avg_similarity = if cache.is_empty() {
671 0.0
672 } else {
673 cache.values().map(|a| a.similarity).sum::<f32>() / cache_size as f32
674 };
675 (cache_size, avg_similarity)
676 }
677
678 pub fn evaluate_retrieval(
680 &self,
681 test_pairs: &[(Modality, Vector, Modality, Vector)],
682 distractors: &[(Modality, Vector)],
683 k_values: &[usize],
684 ) -> Result<HashMap<usize, f32>> {
685 let mut recall_at_k = HashMap::new();
686
687 for &k in k_values {
688 let mut total_recall = 0.0;
689
690 for (query_mod, query_emb, target_mod, target_emb) in test_pairs {
691 let mut candidates = vec![target_emb.clone()];
692 for (distractor_mod, distractor_emb) in distractors {
693 if *distractor_mod == *target_mod {
694 candidates.push(distractor_emb.clone());
695 }
696 }
697
698 let results =
699 self.cross_modal_search(*query_mod, query_emb, *target_mod, &candidates, k)?;
700
701 let found_target = results.iter().any(|(idx, _)| *idx == 0);
702 if found_target {
703 total_recall += 1.0;
704 }
705 }
706
707 recall_at_k.insert(k, total_recall / test_pairs.len() as f32);
708 }
709
710 Ok(recall_at_k)
711 }
712}