use crate::{
cross_modal_embeddings::{
AudioData, ImageData, Modality, ModalityData, MultiModalContent, VideoData,
},
Vector,
};
use anyhow::{anyhow, Result};
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
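/// Configuration for contrastive joint-embedding training (CLIP-style):
/// projection width, InfoNCE temperature, loss margin, and switches for
/// the optimizer, curriculum, and domain-adaptation behavior.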
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JointEmbeddingConfig {
pub joint_dim: usize,
pub temperature: f32,
pub learning_rate: f32,
pub margin: f32,
pub contrastive_learning: bool,
pub triplet_loss: bool,
pub hard_negative_mining: bool,
pub batch_size: usize,
pub negative_samples: usize,
pub curriculum_learning: bool,
pub weight_decay: f32,
pub gradient_clip: f32,
pub domain_adaptation: bool,
pub alignment_strength: f32,
pub self_supervised: bool,
}
impl Default for JointEmbeddingConfig {
fn default() -> Self {
Self {
joint_dim: 512,
temperature: 0.07,
learning_rate: 1e-4,
margin: 0.2,
contrastive_learning: true,
triplet_loss: false,
hard_negative_mining: true,
batch_size: 256,
negative_samples: 5,
curriculum_learning: false,
weight_decay: 1e-4,
gradient_clip: 1.0,
domain_adaptation: true,
alignment_strength: 1.0,
self_supervised: false,
}
}
}
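/// Positive and negative (modality, embedding, modality, embedding) pairs
/// produced for one contrastive batch.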
type ContrastivePairs = (
Vec<(Modality, Vector, Modality, Vector)>,
Vec<(Modality, Vector, Modality, Vector)>,
);
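/// A shared embedding space into which text, image, audio, and video
/// vectors are projected, cross-attended, and compared.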
pub struct JointEmbeddingSpace {
config: JointEmbeddingConfig,
text_projector: LinearProjector,
image_projector: LinearProjector,
audio_projector: LinearProjector,
video_projector: LinearProjector,
attention_mechanism: CrossModalAttention,
alignment_cache: Arc<RwLock<HashMap<String, AlignmentPair>>>,
training_stats: Arc<RwLock<TrainingStatistics>>,
temperature_scheduler: TemperatureScheduler,
domain_adapter: DomainAdapter,
}
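/// A single dense layer (weights, bias, activation, dropout) that projects
/// a modality-specific embedding into the joint space.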
#[derive(Debug, Clone)]
pub struct LinearProjector {
weights: Vec<Vec<f32>>,
bias: Vec<f32>,
input_dim: usize,
output_dim: usize,
dropout_rate: f32,
activation: ActivationFunction,
}
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum ActivationFunction {
ReLU,
GELU,
Tanh,
Sigmoid,
Swish,
Mish,
LeakyReLU(f32),
}
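/// Multi-head attention applied across modalities: queries from one
/// modality attend over keys/values from another.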
#[derive(Debug, Clone)]
pub struct CrossModalAttention {
query_projector: LinearProjector,
key_projector: LinearProjector,
value_projector: LinearProjector,
output_projector: LinearProjector,
num_heads: usize,
head_dim: usize,
dropout_rate: f32,
scale: f32,
enable_relative_pos: bool,
}
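/// A cached cross-modal alignment: two embeddings, their similarity, and a
/// confidence score derived from it.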
#[derive(Debug, Clone)]
pub struct AlignmentPair {
modality1: Modality,
modality2: Modality,
embedding1: Vector,
embedding2: Vector,
similarity: f32,
confidence: f32,
timestamp: std::time::SystemTime,
}
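/// Aggregate counters and running metrics collected during training.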
#[derive(Debug, Clone, Default)]
pub struct TrainingStatistics {
total_samples: u64,
positive_pairs: u64,
negative_pairs: u64,
average_loss: f32,
average_similarity: f32,
convergence_rate: f32,
alignment_accuracy: f32,
cross_modal_retrieval_acc: HashMap<(Modality, Modality), f32>,
training_epochs: u32,
last_improvement: u32,
}
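/// Anneals the contrastive temperature from an initial to a final value
/// over a fixed number of steps.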
#[derive(Debug, Clone)]
pub struct TemperatureScheduler {
initial_temperature: f32,
final_temperature: f32,
decay_steps: usize,
current_step: usize,
schedule_type: ScheduleType,
}
#[derive(Debug, Clone, Copy)]
pub enum ScheduleType {
Linear,
Exponential,
Cosine,
Warmup,
}
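/// Tracks per-domain feature statistics and re-normalizes embeddings to
/// reduce source/target distribution shift.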
#[derive(Debug, Clone)]
pub struct DomainAdapter {
source_stats: DomainStatistics,
target_stats: DomainStatistics,
adaptation_weights: Vec<f32>,
domain_classifier: Option<DomainClassifier>,
adaptation_strength: f32,
}
#[derive(Debug, Clone, Default)]
pub struct DomainStatistics {
mean: Vec<f32>,
variance: Vec<f32>,
sample_count: usize,
feature_statistics: HashMap<String, f32>,
}
#[derive(Debug, Clone)]
pub struct DomainClassifier {
weights: Vec<Vec<f32>>,
bias: Vec<f32>,
accuracy: f32,
}
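/// End-to-end trainer pairing the joint space with a contrastive
/// optimizer, data augmentation, and optional curriculum learning.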
pub struct CLIPAligner {
joint_space: JointEmbeddingSpace,
optimizer: ContrastiveOptimizer,
data_augmentation: DataAugmentation,
curriculum: CurriculumLearning,
}
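/// SGD-style optimizer state for contrastive training, with a pluggable
/// learning-rate schedule.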
#[derive(Debug, Clone)]
pub struct ContrastiveOptimizer {
learning_rate: f32,
momentum: f32,
weight_decay: f32,
gradient_history: HashMap<String, Vec<f32>>,
adaptive_lr: bool,
lr_schedule: LearningRateSchedule,
}
#[derive(Debug, Clone, Copy)]
pub enum LearningRateSchedule {
Constant,
StepDecay { step_size: usize, gamma: f32 },
ExponentialDecay { gamma: f32 },
CosineAnnealing { min_lr: f32, max_epochs: usize },
}
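/// Per-modality augmentation policies applied to training pairs.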
#[derive(Debug, Clone)]
pub struct DataAugmentation {
text_augmentations: Vec<TextAugmentation>,
image_augmentations: Vec<ImageAugmentation>,
audio_augmentations: Vec<AudioAugmentation>,
cross_modal_mixup: bool,
augmentation_probability: f32,
}
#[derive(Debug, Clone)]
pub enum TextAugmentation {
RandomWordDropout(f32),
Paraphrasing,
BackTranslation,
SynonymReplacement(f32),
ContextualAugmentation,
}
#[derive(Debug, Clone)]
pub enum ImageAugmentation {
RandomCrop {
size: (u32, u32),
},
RandomFlip {
horizontal: bool,
vertical: bool,
},
ColorJitter {
brightness: f32,
contrast: f32,
saturation: f32,
},
RandomRotation {
max_angle: f32,
},
GaussianBlur {
sigma: f32,
},
}
#[derive(Debug, Clone)]
pub enum AudioAugmentation {
TimeStretch { factor: f32 },
PitchShift { semitones: f32 },
AddNoise { snr_db: f32 },
FrequencyMasking { max_freq_mask: f32 },
TimeMasking { max_time_mask: f32 },
}
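/// Difficulty pacing for curriculum learning: ramps sample difficulty up
/// or down as training competence changes.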
#[derive(Debug, Clone)]
pub struct CurriculumLearning {
enabled: bool,
current_difficulty: f32,
difficulty_schedule: DifficultySchedule,
pacing_function: PacingFunction,
competence_threshold: f32,
}
#[derive(Debug, Clone)]
pub enum DifficultySchedule {
Linear { start: f32, end: f32, epochs: usize },
Exponential { base: f32, scale: f32 },
Adaptive { improvement_threshold: f32 },
}
#[derive(Debug, Clone)]
pub enum PacingFunction {
Root,
Linear,
Logarithmic,
Polynomial(f32),
}
impl LinearProjector {
pub fn new(
input_dim: usize,
output_dim: usize,
dropout_rate: f32,
activation: ActivationFunction,
) -> Self {
// Xavier/Glorot-style uniform bound.
let limit = (6.0 / (input_dim + output_dim) as f32).sqrt();
let mut weights = Vec::with_capacity(output_dim);
for i in 0..output_dim {
let mut row = Vec::with_capacity(input_dim);
for j in 0..input_dim {
// Deterministic pseudo-random weight in [-limit, limit]; seeding on both
// indices keeps rows distinct (the previous seed depended only on the
// column index, which made every row identical).
let seed = (i * input_dim + j) as f32;
let weight = (seed * 0.7391).sin() * limit;
row.push(weight);
}
weights.push(row);
}
let bias = vec![0.0; output_dim];
Self {
weights,
bias,
input_dim,
output_dim,
dropout_rate,
activation,
}
}
pub fn forward(&self, input: &Vector) -> Result<Vector> {
if input.dimensions != self.input_dim {
return Err(anyhow!(
"Input dimension mismatch: expected {}, got {}",
self.input_dim,
input.dimensions
));
}
let input_values = input.as_f32();
let mut output = vec![0.0; self.output_dim];
for (i, output_val) in output.iter_mut().enumerate().take(self.output_dim) {
let mut sum = self.bias[i];
for (j, &input_val) in input_values.iter().enumerate().take(self.input_dim) {
sum += input_val * self.weights[i][j];
}
*output_val = sum;
}
for value in &mut output {
*value = self.apply_activation(*value);
}
if self.dropout_rate > 0.0 {
// Deterministic stand-in for a random dropout mask, with inverted
// scaling so expected activation magnitudes are preserved.
for (i, value) in output.iter_mut().enumerate() {
if (i as f32 * 0.12345) % 1.0 < self.dropout_rate {
*value = 0.0;
} else {
*value /= 1.0 - self.dropout_rate;
}
}
}
Ok(Vector::new(output))
}
fn apply_activation(&self, x: f32) -> f32 {
match self.activation {
ActivationFunction::ReLU => x.max(0.0),
ActivationFunction::GELU => {
let sqrt_2_pi = (2.0 / std::f32::consts::PI).sqrt();
let inner = sqrt_2_pi * (x + 0.044715 * x.powi(3));
0.5 * x * (1.0 + inner.tanh())
}
ActivationFunction::Tanh => x.tanh(),
ActivationFunction::Sigmoid => 1.0 / (1.0 + (-x).exp()),
ActivationFunction::Swish => x * (1.0 / (1.0 + (-x).exp())),
ActivationFunction::Mish => x * (1.0 + x.exp()).ln().tanh(),
ActivationFunction::LeakyReLU(alpha) => {
if x > 0.0 {
x
} else {
alpha * x
}
}
}
}
pub fn update_weights(&mut self, gradients: &[Vec<f32>], learning_rate: f32) {
for i in 0..self.output_dim {
for j in 0..self.input_dim {
if i < gradients.len() && j < gradients[i].len() {
self.weights[i][j] -= learning_rate * gradients[i][j];
}
}
}
}
}
impl CrossModalAttention {
pub fn new(
input_dim: usize,
num_heads: usize,
dropout_rate: f32,
enable_relative_pos: bool,
) -> Self {
// Assumes input_dim is evenly divisible by num_heads.
let head_dim = input_dim / num_heads;
let scale = 1.0 / (head_dim as f32).sqrt();
Self {
query_projector: LinearProjector::new(
input_dim,
input_dim,
dropout_rate,
ActivationFunction::ReLU,
),
key_projector: LinearProjector::new(
input_dim,
input_dim,
dropout_rate,
ActivationFunction::ReLU,
),
value_projector: LinearProjector::new(
input_dim,
input_dim,
dropout_rate,
ActivationFunction::ReLU,
),
output_projector: LinearProjector::new(
input_dim,
input_dim,
dropout_rate,
ActivationFunction::ReLU,
),
num_heads,
head_dim,
dropout_rate,
scale,
enable_relative_pos,
}
}
pub fn cross_attention(
&self,
query_modality: &Vector,
key_modality: &Vector,
value_modality: &Vector,
) -> Result<Vector> {
let query = self.query_projector.forward(query_modality)?;
let key = self.key_projector.forward(key_modality)?;
let value = self.value_projector.forward(value_modality)?;
let attended = self.multi_head_attention(&query, &key, &value)?;
self.output_projector.forward(&attended)
}
fn multi_head_attention(&self, query: &Vector, key: &Vector, value: &Vector) -> Result<Vector> {
let query_vals = query.as_f32();
let key_vals = key.as_f32();
let value_vals = value.as_f32();
if query_vals.len() != key_vals.len() || key_vals.len() != value_vals.len() {
return Err(anyhow!("Dimension mismatch in attention"));
}
let mut output = vec![0.0; query_vals.len()];
// Simplified attention: each head computes one scalar score between its
// query and key slices and applies it to the corresponding value slice.
for head in 0..self.num_heads {
let head_start = head * self.head_dim;
let head_end = head_start + self.head_dim;
let head_query = &query_vals[head_start..head_end];
let head_key = &key_vals[head_start..head_end];
let head_value = &value_vals[head_start..head_end];
let attention_score = self.compute_attention_score(head_query, head_key);
for i in 0..self.head_dim {
output[head_start + i] = head_value[i] * attention_score;
}
}
if self.enable_relative_pos {
self.apply_relative_position_encoding(&mut output)?;
}
Ok(Vector::new(output))
}
fn compute_attention_score(&self, query: &[f32], key: &[f32]) -> f32 {
let dot_product: f32 = query.iter().zip(key.iter()).map(|(q, k)| q * k).sum();
let scaled_score = dot_product * self.scale;
// With a single key per head, softmax would always yield 1.0, so a
// bounded tanh keeps the score in (-1, 1) instead.
scaled_score.tanh()
}
fn apply_relative_position_encoding(&self, output: &mut [f32]) -> Result<()> {
let output_len = output.len();
for (i, value) in output.iter_mut().enumerate() {
let pos_encoding = (i as f32 / output_len as f32).sin();
*value += 0.1 * pos_encoding;
}
Ok(())
}
}
impl TemperatureScheduler {
pub fn new(
initial_temperature: f32,
final_temperature: f32,
decay_steps: usize,
schedule_type: ScheduleType,
) -> Self {
Self {
initial_temperature,
final_temperature,
decay_steps,
current_step: 0,
schedule_type,
}
}
pub fn get_current_temperature(&self) -> f32 {
if self.current_step >= self.decay_steps {
return self.final_temperature;
}
let progress = self.current_step as f32 / self.decay_steps as f32;
match self.schedule_type {
ScheduleType::Linear => {
self.initial_temperature
+ (self.final_temperature - self.initial_temperature) * progress
}
ScheduleType::Exponential => {
self.initial_temperature
* (self.final_temperature / self.initial_temperature).powf(progress)
}
ScheduleType::Cosine => {
let cosine_progress = 0.5 * (1.0 + (std::f32::consts::PI * progress).cos());
self.final_temperature
+ (self.initial_temperature - self.final_temperature) * cosine_progress
}
ScheduleType::Warmup => {
if progress < 0.1 {
self.initial_temperature * (progress / 0.1)
} else {
let decay_progress = (progress - 0.1) / 0.9;
self.initial_temperature
+ (self.final_temperature - self.initial_temperature) * decay_progress
}
}
}
}
pub fn step(&mut self) {
self.current_step += 1;
}
}
impl DomainAdapter {
pub fn new(adaptation_strength: f32) -> Self {
Self {
source_stats: DomainStatistics::default(),
target_stats: DomainStatistics::default(),
adaptation_weights: Vec::new(),
domain_classifier: None,
adaptation_strength,
}
}
pub fn adapt_embedding(&self, embedding: &Vector, is_source_domain: bool) -> Result<Vector> {
let input_values = embedding.as_f32();
let mut adapted_values = input_values.clone();
if self.adaptation_weights.len() != input_values.len() {
// Adaptation weights not yet initialized for this dimensionality;
// pass the embedding through unchanged.
return Ok(embedding.clone());
}
let stats = if is_source_domain {
&self.source_stats
} else {
&self.target_stats
};
for (i, adapted_value) in adapted_values.iter_mut().enumerate() {
if i < stats.mean.len() && i < stats.variance.len() {
let normalized =
(*adapted_value - stats.mean[i]) / (stats.variance[i].sqrt() + 1e-8);
*adapted_value = normalized * self.adaptation_weights[i] * self.adaptation_strength
+ *adapted_value * (1.0 - self.adaptation_strength);
}
}
Ok(Vector::new(adapted_values))
}
pub fn update_domain_statistics(&mut self, embeddings: &[Vector], is_source_domain: bool) {
let stats = if is_source_domain {
&mut self.source_stats
} else {
&mut self.target_stats
};
if embeddings.is_empty() {
return;
}
let dim = embeddings[0].dimensions;
if stats.mean.len() != dim {
stats.mean = vec![0.0; dim];
stats.variance = vec![0.0; dim];
stats.sample_count = 0;
}
for embedding in embeddings {
let values = embedding.as_f32();
// Welford's online update; the sample count advances once per embedding
// (incrementing it per dimension, as before, inflated the count).
stats.sample_count += 1;
for (i, &value) in values.iter().enumerate().take(dim) {
let delta = value - stats.mean[i];
stats.mean[i] += delta / stats.sample_count as f32;
let delta2 = value - stats.mean[i];
stats.variance[i] += delta * delta2;
}
}
if stats.sample_count > 1 {
// Convert accumulated M2 sums into unbiased sample variance. Note that
// repeated calls re-normalize the accumulator, so statistics are most
// accurate when each domain is updated in a single batch.
for variance in &mut stats.variance {
*variance /= (stats.sample_count - 1) as f32;
}
}
self.update_adaptation_weights();
}
fn update_adaptation_weights(&mut self) {
let dim = self.source_stats.mean.len();
if dim == 0 || dim != self.target_stats.mean.len() {
return;
}
self.adaptation_weights = vec![1.0; dim];
for i in 0..dim {
let mean_diff = (self.source_stats.mean[i] - self.target_stats.mean[i]).abs();
let var_ratio = (self.source_stats.variance[i]
/ (self.target_stats.variance[i] + 1e-8))
.ln()
.abs();
let discrepancy = mean_diff + 0.5 * var_ratio;
self.adaptation_weights[i] = 1.0 / (1.0 + discrepancy);
}
}
}
impl JointEmbeddingSpace {
pub fn new(config: JointEmbeddingConfig) -> Self {
// Input dimensionalities assume typical encoder outputs:
// 768 (text), 2048 (image), 1024 (audio), 1536 (video).
let text_projector =
LinearProjector::new(768, config.joint_dim, 0.1, ActivationFunction::GELU);
let image_projector =
LinearProjector::new(2048, config.joint_dim, 0.1, ActivationFunction::GELU);
let audio_projector =
LinearProjector::new(1024, config.joint_dim, 0.1, ActivationFunction::GELU);
let video_projector =
LinearProjector::new(1536, config.joint_dim, 0.1, ActivationFunction::GELU);
let attention_mechanism = CrossModalAttention::new(config.joint_dim, 8, 0.1, true);
let temperature_scheduler = TemperatureScheduler::new(
config.temperature * 2.0,
config.temperature,
1000,
ScheduleType::Cosine,
);
let domain_adapter = DomainAdapter::new(config.alignment_strength);
Self {
config,
text_projector,
image_projector,
audio_projector,
video_projector,
attention_mechanism,
alignment_cache: Arc::new(RwLock::new(HashMap::new())),
training_stats: Arc::new(RwLock::new(TrainingStatistics::default())),
temperature_scheduler,
domain_adapter,
}
}
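/// Projects a modality-specific embedding into the joint space and
/// L2-normalizes the result.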
pub fn project_to_joint_space(&self, modality: Modality, embedding: &Vector) -> Result<Vector> {
let projected = match modality {
Modality::Text => self.text_projector.forward(embedding)?,
Modality::Image => self.image_projector.forward(embedding)?,
Modality::Audio => self.audio_projector.forward(embedding)?,
Modality::Video => self.video_projector.forward(embedding)?,
_ => {
// Fallback for other modalities: reuse the text projector (assumes a
// 768-dimensional input embedding).
self.text_projector.forward(embedding)?
}
};
Ok(projected.normalized())
}
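/// Cosine similarity between two embeddings in the joint space; when the
/// modalities differ, cross-attention is applied in both directions first.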
pub fn cross_modal_similarity(
&self,
modality1: Modality,
embedding1: &Vector,
modality2: Modality,
embedding2: &Vector,
) -> Result<f32> {
let joint_emb1 = self.project_to_joint_space(modality1, embedding1)?;
let joint_emb2 = self.project_to_joint_space(modality2, embedding2)?;
if modality1 != modality2 {
let attended_emb1 =
self.attention_mechanism
.cross_attention(&joint_emb1, &joint_emb2, &joint_emb2)?;
let attended_emb2 =
self.attention_mechanism
.cross_attention(&joint_emb2, &joint_emb1, &joint_emb1)?;
attended_emb1.cosine_similarity(&attended_emb2)
} else {
joint_emb1.cosine_similarity(&joint_emb2)
}
}
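/// One contrastive update over explicit positive and negative pairs,
/// returning the mean loss at the currently scheduled temperature.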
pub fn contrastive_align(
&mut self,
positive_pairs: &[(Modality, Vector, Modality, Vector)],
negative_pairs: &[(Modality, Vector, Modality, Vector)],
) -> Result<f32> {
let mut total_loss = 0.0;
let temperature = self.temperature_scheduler.get_current_temperature();
for (mod1, emb1, mod2, emb2) in positive_pairs {
let similarity = self.cross_modal_similarity(*mod1, emb1, *mod2, emb2)?;
let positive_score = similarity / temperature;
// Softplus loss ln(1 + e^{-s}): decreases as the positive score grows
// and is well-defined for any score (the previous -s.ln_1p() was NaN
// for scores below -1).
let positive_loss = (-positive_score).exp().ln_1p();
total_loss += positive_loss;
self.cache_alignment(*mod1, emb1.clone(), *mod2, emb2.clone(), similarity);
}
for (mod1, emb1, mod2, emb2) in negative_pairs {
let similarity = self.cross_modal_similarity(*mod1, emb1, *mod2, emb2)?;
let negative_score = similarity / temperature;
// Hinge penalty: a negative pair incurs loss only once its scaled score
// rises above -margin.
let negative_loss = (negative_score + self.config.margin).max(0.0);
total_loss += negative_loss;
}
self.update_training_stats(positive_pairs.len(), negative_pairs.len(), total_loss);
self.temperature_scheduler.step();
let pair_count = (positive_pairs.len() + negative_pairs.len()).max(1);
Ok(total_loss / pair_count as f32)
}
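/// Ranks candidate embeddings against a query across modalities and
/// returns the top-k (index, similarity) pairs.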
pub fn cross_modal_search(
&self,
query_modality: Modality,
query_embedding: &Vector,
candidate_modality: Modality,
candidate_embeddings: &[Vector],
top_k: usize,
) -> Result<Vec<(usize, f32)>> {
let query_joint = self.project_to_joint_space(query_modality, query_embedding)?;
let mut similarities = Vec::new();
for (idx, candidate) in candidate_embeddings.iter().enumerate() {
let candidate_joint = self.project_to_joint_space(candidate_modality, candidate)?;
let similarity = if query_modality != candidate_modality {
let attended_query = self.attention_mechanism.cross_attention(
&query_joint,
&candidate_joint,
&candidate_joint,
)?;
attended_query.cosine_similarity(&candidate_joint)?
} else {
query_joint.cosine_similarity(&candidate_joint)?
};
similarities.push((idx, similarity));
}
similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
similarities.truncate(top_k);
Ok(similarities)
}
pub fn zero_shot_retrieval(
&self,
query_modality: Modality,
query_embedding: &Vector,
target_modality: Modality,
target_embeddings: &[Vector],
top_k: usize,
) -> Result<Vec<(usize, f32)>> {
// Zero-shot retrieval reduces to a plain cross-modal search in the
// shared joint space; no task-specific adaptation is applied.
self.cross_modal_search(
query_modality,
query_embedding,
target_modality,
target_embeddings,
top_k,
)
}
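/// Fuses several modality embeddings into one joint vector by averaging
/// cross-attended projections.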
pub fn multi_modal_fusion(&self, modalities: &[(Modality, Vector)]) -> Result<Vector> {
if modalities.is_empty() {
return Err(anyhow!("No modalities provided for fusion"));
}
let mut joint_embeddings = Vec::new();
for (modality, embedding) in modalities {
let joint_emb = self.project_to_joint_space(*modality, embedding)?;
joint_embeddings.push(joint_emb);
}
let mut attended_embeddings = Vec::new();
for i in 0..joint_embeddings.len() {
let mut attended = joint_embeddings[i].clone();
for j in 0..joint_embeddings.len() {
if i != j {
let cross_attended = self.attention_mechanism.cross_attention(
&joint_embeddings[i],
&joint_embeddings[j],
&joint_embeddings[j],
)?;
let weight = 1.0 / joint_embeddings.len() as f32;
attended = attended.add(&cross_attended.scale(weight))?;
}
}
attended_embeddings.push(attended);
}
if attended_embeddings.len() == 1 {
Ok(attended_embeddings[0].clone())
} else {
let mut fused = attended_embeddings[0].clone();
for embedding in attended_embeddings.iter().skip(1) {
fused = fused.add(embedding)?;
}
Ok(fused.scale(1.0 / attended_embeddings.len() as f32))
}
}
fn cache_alignment(
&self,
mod1: Modality,
emb1: Vector,
mod2: Modality,
emb2: Vector,
similarity: f32,
) {
let alignment = AlignmentPair {
modality1: mod1,
modality2: mod2,
embedding1: emb1,
embedding2: emb2,
similarity,
confidence: similarity.abs(),
timestamp: std::time::SystemTime::now(),
};
let cache_key = format!("{mod1:?}_{mod2:?}_{similarity}");
let mut cache = self.alignment_cache.write();
cache.insert(cache_key, alignment);
// Bound the cache size by evicting the oldest entry once over capacity.
if cache.len() > 10000 {
let mut entries: Vec<_> = cache.iter().collect();
entries.sort_by_key(|(_, v)| v.timestamp);
let oldest_key = entries[0].0.clone();
cache.remove(&oldest_key);
}
}
fn update_training_stats(&self, positive_count: usize, negative_count: usize, loss: f32) {
let mut stats = self.training_stats.write();
stats.total_samples += (positive_count + negative_count) as u64;
stats.positive_pairs += positive_count as u64;
stats.negative_pairs += negative_count as u64;
let total_samples = stats.total_samples as f32;
stats.average_loss = (stats.average_loss * (total_samples - 1.0) + loss) / total_samples;
}
pub fn get_training_stats(&self) -> TrainingStatistics {
self.training_stats.read().clone()
}
pub fn get_cache_stats(&self) -> (usize, f32) {
let cache = self.alignment_cache.read();
let cache_size = cache.len();
let avg_similarity = if cache.is_empty() {
0.0
} else {
cache.values().map(|a| a.similarity).sum::<f32>() / cache_size as f32
};
(cache_size, avg_similarity)
}
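/// Computes recall@k for cross-modal retrieval: each query must find its
/// paired target (candidate index 0) among same-modality distractors.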
pub fn evaluate_retrieval(
&self,
test_pairs: &[(Modality, Vector, Modality, Vector)],
distractors: &[(Modality, Vector)],
k_values: &[usize],
) -> Result<HashMap<usize, f32>> {
let mut recall_at_k = HashMap::new();
for &k in k_values {
let mut total_recall = 0.0;
for (query_mod, query_emb, target_mod, target_emb) in test_pairs {
let mut candidates = vec![target_emb.clone()];
for (distractor_mod, distractor_emb) in distractors {
if *distractor_mod == *target_mod {
candidates.push(distractor_emb.clone());
}
}
let results =
self.cross_modal_search(*query_mod, query_emb, *target_mod, &candidates, k)?;
let found_target = results.iter().any(|(idx, _)| *idx == 0);
if found_target {
total_recall += 1.0;
}
}
recall_at_k.insert(k, total_recall / test_pairs.len() as f32);
}
Ok(recall_at_k)
}
}
impl CLIPAligner {
pub fn new(config: JointEmbeddingConfig) -> Self {
let joint_space = JointEmbeddingSpace::new(config.clone());
let optimizer = ContrastiveOptimizer::new(config.learning_rate, 0.9, config.weight_decay);
let data_augmentation = DataAugmentation::default();
let curriculum = CurriculumLearning::new();
Self {
joint_space,
optimizer,
data_augmentation,
curriculum,
}
}
pub fn train_alignment(
&mut self,
training_data: &[(MultiModalContent, MultiModalContent)],
epochs: usize,
) -> Result<Vec<f32>> {
let mut epoch_losses = Vec::new();
for epoch in 0..epochs {
let mut epoch_loss = 0.0;
let mut batch_count = 0;
for batch in training_data.chunks(self.joint_space.config.batch_size) {
let (positive_pairs, negative_pairs) = self.create_contrastive_pairs(batch)?;
let augmented_positive = self.augment_pairs(&positive_pairs)?;
let augmented_negative = self.augment_pairs(&negative_pairs)?;
let batch_loss = self
.joint_space
.contrastive_align(&augmented_positive, &augmented_negative)?;
epoch_loss += batch_loss;
batch_count += 1;
if self.curriculum.enabled {
self.curriculum.update_difficulty(batch_loss);
}
}
let avg_epoch_loss = epoch_loss / batch_count.max(1) as f32;
epoch_losses.push(avg_epoch_loss);
self.optimizer.step_schedule();
tracing::info!(
"Epoch {}/{}: Average Loss = {:.4}, Temperature = {:.4}",
epoch + 1,
epochs,
avg_epoch_loss,
self.joint_space
.temperature_scheduler
.get_current_temperature()
);
}
Ok(epoch_losses)
}
fn create_contrastive_pairs(
&self,
batch: &[(MultiModalContent, MultiModalContent)],
) -> Result<ContrastivePairs> {
let mut positive_pairs = Vec::new();
let mut negative_pairs = Vec::new();
for (content1, content2) in batch {
for (mod1, data1) in &content1.modalities {
for (mod2, data2) in &content2.modalities {
if let (Ok(emb1), Ok(emb2)) = (
self.extract_embedding(*mod1, data1),
self.extract_embedding(*mod2, data2),
) {
positive_pairs.push((*mod1, emb1, *mod2, emb2));
}
}
}
}
let batch_size = batch.len();
for i in 0..batch_size {
for j in 0..batch_size {
if i != j {
let (content1, _) = &batch[i];
let (_, content2) = &batch[j];
for (mod1, data1) in &content1.modalities {
for (mod2, data2) in &content2.modalities {
if let (Ok(emb1), Ok(emb2)) = (
self.extract_embedding(*mod1, data1),
self.extract_embedding(*mod2, data2),
) {
negative_pairs.push((*mod1, emb1, *mod2, emb2));
}
}
}
}
}
}
let max_negatives = positive_pairs.len() * self.joint_space.config.negative_samples;
negative_pairs.truncate(max_negatives);
Ok((positive_pairs, negative_pairs))
}
fn extract_embedding(&self, modality: Modality, data: &ModalityData) -> Result<Vector> {
match (modality, data) {
(Modality::Text, ModalityData::Text(text)) => {
let words: Vec<&str> = text.split_whitespace().collect();
let embedding = self.create_text_embedding(&words);
Ok(embedding)
}
(Modality::Image, ModalityData::Image(image)) => {
let embedding = self.create_image_embedding(image);
Ok(embedding)
}
(Modality::Audio, ModalityData::Audio(audio)) => {
let embedding = self.create_audio_embedding(audio);
Ok(embedding)
}
(Modality::Video, ModalityData::Video(video)) => {
let embedding = self.create_video_embedding(video);
Ok(embedding)
}
(Modality::Numeric, ModalityData::Numeric(values)) => Ok(Vector::new(values.clone())),
_ => Err(anyhow!("Modality-data type mismatch")),
}
}
fn create_text_embedding(&self, words: &[&str]) -> Vector {
let mut embedding = vec![0.0; 768];
for (i, word) in words.iter().enumerate().take(100) {
let hash = self.simple_hash(word) as usize;
let idx = hash % embedding.len();
// Hashed bag-of-words with positional decay: earlier words weigh more.
embedding[idx] += 1.0 / (i + 1) as f32;
}
let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
if norm > 0.0 {
for value in &mut embedding {
*value /= norm;
}
}
Vector::new(embedding)
}
fn create_image_embedding(&self, image: &ImageData) -> Vector {
let mut embedding = vec![0.0; 2048];
let color_features = self.extract_color_features(image);
for (i, &feature) in color_features.iter().enumerate().take(256) {
if i < embedding.len() {
embedding[i] = feature;
}
}
let texture_features = self.extract_texture_features(image);
for (i, &feature) in texture_features.iter().enumerate().take(256) {
if i + 256 < embedding.len() {
embedding[i + 256] = feature;
}
}
Vector::new(embedding)
}
fn create_audio_embedding(&self, audio: &AudioData) -> Vector {
let mut embedding = vec![0.0; 1024];
if let Some(ref features) = audio.features {
for (i, &feature) in features.iter().enumerate().take(embedding.len()) {
embedding[i] = feature;
}
} else {
let spectral_features = self.extract_spectral_features(audio);
for (i, &feature) in spectral_features.iter().enumerate().take(embedding.len()) {
embedding[i] = feature;
}
}
Vector::new(embedding)
}
fn create_video_embedding(&self, video: &VideoData) -> Vector {
let mut embedding = vec![0.0; 1536];
if !video.frames.is_empty() {
let frame_embedding = self.create_image_embedding(&video.frames[0]);
let frame_values = frame_embedding.as_f32();
for (i, &value) in frame_values.iter().enumerate().take(1024) {
if i < embedding.len() {
embedding[i] = value;
}
}
}
if let Some(ref audio) = video.audio {
let audio_embedding = self.create_audio_embedding(audio);
let audio_values = audio_embedding.as_f32();
for (i, &value) in audio_values.iter().enumerate().take(512) {
if i + 1024 < embedding.len() {
embedding[i + 1024] = value;
}
}
}
Vector::new(embedding)
}
fn simple_hash(&self, text: &str) -> u64 {
// djb2-style multiplicative string hash.
let mut hash = 5381u64;
for byte in text.bytes() {
hash = hash.wrapping_mul(33).wrapping_add(byte as u64);
}
hash
}
fn extract_color_features(&self, image: &ImageData) -> Vec<f32> {
let mut histogram = vec![0.0; 256];
match image.format {
crate::cross_modal_embeddings::ImageFormat::RGB => {
for chunk in image.data.chunks(3) {
if chunk.len() == 3 {
let intensity = (chunk[0] as f32 + chunk[1] as f32 + chunk[2] as f32) / 3.0;
let bin = (intensity as usize).min(255);
histogram[bin] += 1.0;
}
}
}
_ => {
for &pixel in &image.data {
let bin = (pixel as usize).min(255);
histogram[bin] += 1.0;
}
}
}
let total: f32 = histogram.iter().sum();
if total > 0.0 {
for value in &mut histogram {
*value /= total;
}
}
histogram
}
fn extract_texture_features(&self, image: &ImageData) -> Vec<f32> {
let mut features = vec![0.0; 256];
let width = image.width as usize;
let height = image.height as usize;
// Local binary pattern (LBP) texture histogram over 3x3 neighborhoods;
// indexing assumes one byte per pixel (grayscale layout).
if width > 2 && height > 2 {
for y in 1..height - 1 {
for x in 1..width - 1 {
let center_idx = y * width + x;
if center_idx < image.data.len() {
let center = image.data[center_idx];
let mut pattern = 0u8;
let neighbors = [
(-1, -1),
(0, -1),
(1, -1),
(-1, 0),
(1, 0),
(-1, 1),
(0, 1),
(1, 1),
];
for (bit, (dx, dy)) in neighbors.iter().enumerate() {
let nx = (x as i32 + dx) as usize;
let ny = (y as i32 + dy) as usize;
let neighbor_idx = ny * width + nx;
if neighbor_idx < image.data.len() && image.data[neighbor_idx] > center
{
pattern |= 1 << bit;
}
}
features[pattern as usize] += 1.0;
}
}
}
}
let total: f32 = features.iter().sum();
if total > 0.0 {
for value in &mut features {
*value /= total;
}
}
features
}
fn extract_spectral_features(&self, audio: &AudioData) -> Vec<f32> {
let mut features = vec![0.0; 128];
if !audio.samples.is_empty() {
// Band energies over equal-sized chunks; max(1) guards very short clips.
let chunk_size = (audio.samples.len() / features.len()).max(1);
for (i, feature) in features.iter_mut().enumerate() {
let start = i * chunk_size;
let end = ((i + 1) * chunk_size).min(audio.samples.len());
if start < end {
let chunk = &audio.samples[start..end];
let energy: f32 = chunk.iter().map(|x| x * x).sum();
*feature = energy.sqrt() / (chunk.len() as f32).sqrt();
}
}
}
features
}
fn augment_pairs(
&self,
pairs: &[(Modality, Vector, Modality, Vector)],
) -> Result<Vec<(Modality, Vector, Modality, Vector)>> {
let mut augmented = Vec::new();
for (mod1, emb1, mod2, emb2) in pairs {
let aug_emb1 = self.add_noise(emb1, 0.01)?;
let aug_emb2 = self.add_noise(emb2, 0.01)?;
augmented.push((*mod1, aug_emb1, *mod2, aug_emb2));
}
Ok(augmented)
}
fn add_noise(&self, embedding: &Vector, noise_std: f32) -> Result<Vector> {
let values = embedding.as_f32();
let mut noisy_values = Vec::with_capacity(values.len());
for (i, &value) in values.iter().enumerate() {
// Deterministic, index-seeded perturbation standing in for Gaussian noise.
let noise = ((i as f32 * 0.1234).sin() * noise_std).clamp(-0.1, 0.1);
noisy_values.push(value + noise);
}
Ok(Vector::new(noisy_values))
}
}
impl ContrastiveOptimizer {
pub fn new(learning_rate: f32, momentum: f32, weight_decay: f32) -> Self {
Self {
learning_rate,
momentum,
weight_decay,
gradient_history: HashMap::new(),
adaptive_lr: true,
lr_schedule: LearningRateSchedule::CosineAnnealing {
min_lr: learning_rate * 0.01,
max_epochs: 100,
},
}
}
pub fn step_schedule(&mut self) {
match self.lr_schedule {
LearningRateSchedule::StepDecay {
step_size: _,
gamma,
} => {
// Simplification: decays on every call rather than every `step_size` steps.
self.learning_rate *= gamma;
}
LearningRateSchedule::ExponentialDecay { gamma } => {
self.learning_rate *= gamma;
}
LearningRateSchedule::CosineAnnealing {
min_lr,
max_epochs: _,
} => {
// Approximation: a small fixed per-step progress shrinks the rate
// geometrically toward min_lr instead of tracking epochs exactly.
let progress = 0.01;
let lr_range = self.learning_rate - min_lr;
self.learning_rate =
min_lr + lr_range * (1.0 + (std::f32::consts::PI * progress).cos()) / 2.0;
}
LearningRateSchedule::Constant => {}
}
}
}
impl Default for DataAugmentation {
fn default() -> Self {
Self {
text_augmentations: vec![
TextAugmentation::RandomWordDropout(0.1),
TextAugmentation::SynonymReplacement(0.1),
],
image_augmentations: vec![
ImageAugmentation::RandomFlip {
horizontal: true,
vertical: false,
},
ImageAugmentation::ColorJitter {
brightness: 0.2,
contrast: 0.2,
saturation: 0.2,
},
],
audio_augmentations: vec![
AudioAugmentation::AddNoise { snr_db: 20.0 },
AudioAugmentation::TimeStretch { factor: 1.1 },
],
cross_modal_mixup: false,
augmentation_probability: 0.5,
}
}
}
impl Default for CurriculumLearning {
fn default() -> Self {
Self::new()
}
}
impl CurriculumLearning {
pub fn new() -> Self {
Self {
enabled: false,
current_difficulty: 0.0,
difficulty_schedule: DifficultySchedule::Linear {
start: 0.0,
end: 1.0,
epochs: 50,
},
pacing_function: PacingFunction::Root,
competence_threshold: 0.8,
}
}
pub fn update_difficulty(&mut self, loss: f32) {
if self.enabled {
if loss < self.competence_threshold {
self.current_difficulty = (self.current_difficulty + 0.01).min(1.0);
} else {
self.current_difficulty = (self.current_difficulty - 0.005).max(0.0);
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use anyhow::Result;
#[test]
fn test_joint_embedding_space() -> Result<()> {
let config = JointEmbeddingConfig::default();
let joint_space = JointEmbeddingSpace::new(config);
let text_embedding = Vector::new(vec![0.1; 768]);
let image_embedding = Vector::new(vec![0.2; 2048]);
let joint_text = joint_space.project_to_joint_space(Modality::Text, &text_embedding)?;
let joint_image = joint_space.project_to_joint_space(Modality::Image, &image_embedding)?;
assert_eq!(joint_text.dimensions, 512);
assert_eq!(joint_image.dimensions, 512);
let similarity = joint_space.cross_modal_similarity(
Modality::Text,
&text_embedding,
Modality::Image,
&image_embedding,
)?;
assert!((-1.0..=1.0).contains(&similarity));
Ok(())
}
#[test]
fn test_cross_modal_attention() -> Result<()> {
let attention = CrossModalAttention::new(128, 4, 0.1, true);
let query = Vector::new(vec![0.1; 128]);
let key = Vector::new(vec![0.2; 128]);
let value = Vector::new(vec![0.3; 128]);
let result = attention.cross_attention(&query, &key, &value)?;
assert_eq!(result.dimensions, 128);
Ok(())
}
#[test]
fn test_contrastive_learning() -> Result<()> {
let config = JointEmbeddingConfig::default();
let mut joint_space = JointEmbeddingSpace::new(config);
let positive_pairs = vec![(
Modality::Text,
Vector::new(vec![0.1; 768]),
Modality::Image,
Vector::new(vec![0.1; 2048]),
)];
let negative_pairs = vec![(
Modality::Text,
Vector::new(vec![0.1; 768]),
Modality::Image,
Vector::new(vec![-0.1; 2048]),
)];
let loss = joint_space.contrastive_align(&positive_pairs, &negative_pairs)?;
assert!(loss >= 0.0);
Ok(())
}
#[test]
fn test_clip_aligner() {
let config = JointEmbeddingConfig::default();
let aligner = CLIPAligner::new(config);
let text_words = vec!["hello", "world"];
let text_embedding = aligner.create_text_embedding(&text_words);
assert_eq!(text_embedding.dimensions, 768);
let (cache_size, _) = aligner.joint_space.get_cache_stats();
assert_eq!(cache_size, 0);
}
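// Usage sketches added as tests. Expected values follow directly from the
// implementations above (linear temperature schedule; 512-dim default
// joint space); no external behavior is assumed.
#[test]
fn test_temperature_scheduler() {
let mut scheduler = TemperatureScheduler::new(0.14, 0.07, 10, ScheduleType::Linear);
// Step 0: progress is zero, so the initial temperature is returned.
assert!((scheduler.get_current_temperature() - 0.14).abs() < 1e-6);
for _ in 0..10 {
scheduler.step();
}
// After all decay steps the scheduler clamps to the final temperature.
assert!((scheduler.get_current_temperature() - 0.07).abs() < 1e-6);
}
#[test]
fn test_multi_modal_fusion() -> Result<()> {
let joint_space = JointEmbeddingSpace::new(JointEmbeddingConfig::default());
// Input dimensions match the projector sizes assumed in `new`
// (768 for text, 2048 for images).
let modalities = vec![
(Modality::Text, Vector::new(vec![0.1; 768])),
(Modality::Image, Vector::new(vec![0.2; 2048])),
];
let fused = joint_space.multi_modal_fusion(&modalities)?;
// The fused vector lives in the joint space.
assert_eq!(fused.dimensions, 512);
Ok(())
}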
}