use ndarray::{Array, Array1, Array2};
use std::collections::HashMap;
use crate::{EmbeddingModel, TrainingData};
/// Dimension settings shared by the text/auxiliary fusion strategies.
#[derive(Debug, Clone)]
pub struct MultimodalFusion {
/// Length of the text segment in the vector built by `concatenate`.
pub text_dim: usize,
/// Length of the auxiliary segment in the vector built by `concatenate`.
pub aux_dim: usize,
/// Column count that both projection matrices must have in `project_and_fuse`.
pub fused_dim: usize,
}
impl MultimodalFusion {
    /// Builds a fusion configuration from the three dimensionalities.
    pub fn new(text_dim: usize, aux_dim: usize, fused_dim: usize) -> Self {
        Self {
            text_dim,
            aux_dim,
            fused_dim,
        }
    }

    /// Returns a vector of length `text_dim + aux_dim` whose first `text_dim`
    /// entries are filled from `text` and whose remaining entries are filled
    /// from `aux`.
    ///
    /// # Panics
    ///
    /// Each half is written with ndarray's `assign`, which broadcasts its
    /// argument and panics when that is impossible. A `text` whose length is
    /// neither `text_dim` nor 1 (and likewise for `aux` and `aux_dim`) therefore
    /// panics.
    pub fn concatenate(&self, text: &Array1<f32>, aux: &Array1<f32>) -> Array1<f32> {
        let mut result = Array::zeros(self.text_dim + self.aux_dim);
        result.slice_mut(ndarray::s![..self.text_dim]).assign(text);
        result.slice_mut(ndarray::s![self.text_dim..]).assign(aux);
        result
    }

    /// Element-wise blend `text * text_weight + aux * (1 - text_weight)`.
    ///
    /// Returns `None` when the two vectors differ in length.
    pub fn weighted_average(
        &self,
        text: &Array1<f32>,
        aux: &Array1<f32>,
        text_weight: f32,
    ) -> Option<Array1<f32>> {
        if text.len() != aux.len() {
            return None;
        }
        let aux_weight = 1.0 - text_weight;
        Some(text * text_weight + aux * aux_weight)
    }

    /// Blends the two vectors with a weight derived from their cosine similarity.
    ///
    /// The text weight is `0.5 + 0.5 * tanh(cosine)` and the auxiliary vector
    /// receives the complement.
    ///
    /// Returns `None` when the lengths differ. When either vector has zero norm
    /// the cosine is undefined and a copy of `text` is returned unchanged.
    pub fn attention_fusion(&self, text: &Array1<f32>, aux: &Array1<f32>) -> Option<Array1<f32>> {
        if text.len() != aux.len() {
            return None;
        }
        let fused = match Self::nonzero_cosine(text, aux) {
            Some(cosine) => {
                let attn = 0.5 + 0.5 * cosine.tanh();
                text * attn + aux * (1.0 - attn)
            }
            None => text.clone(),
        };
        Some(fused)
    }

    /// Projects both vectors with their matrices (`vector · matrix`) and returns
    /// the element-wise mean of the two projections.
    ///
    /// Returns `None` when either matrix does not have `fused_dim` columns, or
    /// when a vector's length does not match the row count of its matrix. The
    /// latter case would otherwise panic inside `dot`.
    pub fn project_and_fuse(
        &self,
        text: &Array1<f32>,
        aux: &Array1<f32>,
        text_proj: &Array2<f32>,
        aux_proj: &Array2<f32>,
    ) -> Option<Array1<f32>> {
        let (text_rows, text_cols) = text_proj.dim();
        let (aux_rows, aux_cols) = aux_proj.dim();
        if text_cols != self.fused_dim || aux_cols != self.fused_dim {
            return None;
        }
        // `dot` panics on incompatible shapes, so reject them up front.
        if text.len() != text_rows || aux.len() != aux_rows {
            return None;
        }
        let text_p = text.dot(text_proj);
        let aux_p = aux.dot(aux_proj);
        Some((text_p + aux_p) / 2.0)
    }

    /// Cosine similarity between the two vectors.
    ///
    /// Returns `0.0` when the lengths differ or when either vector has zero norm.
    pub fn cross_modal_similarity(text: &Array1<f32>, aux: &Array1<f32>) -> f32 {
        if text.len() != aux.len() {
            return 0.0;
        }
        Self::nonzero_cosine(text, aux).unwrap_or(0.0)
    }

    /// Cosine of the angle between `a` and `b`, or `None` if either norm is zero.
    ///
    /// Callers are expected to have checked that the lengths match; the dot
    /// product is taken over the zipped elements.
    fn nonzero_cosine(a: &Array1<f32>, b: &Array1<f32>) -> Option<f32> {
        let dot: f32 = a.iter().zip(b.iter()).map(|(&x, &y)| x * y).sum();
        let norm_a = a.iter().map(|&x| x * x).sum::<f32>().sqrt();
        let norm_b = b.iter().map(|&x| x * x).sum::<f32>().sqrt();
        if norm_a == 0.0 || norm_b == 0.0 {
            None
        } else {
            Some(dot / (norm_a * norm_b))
        }
    }
}
/// Linear map applied to embeddings by `align` and fitted by
/// `train_from_dictionary`.
#[derive(Debug, Clone)]
pub struct CrossLingualAligner {
/// Projection matrix; `new` initialises it to the `dim x dim` identity.
pub projection: Array2<f32>,
}
impl CrossLingualAligner {
pub fn new(dim: usize) -> Self {
let mut proj = Array::zeros((dim, dim));
for i in 0..dim {
proj[[i, i]] = 1.0;
}
Self { projection: proj }
}
pub fn align(&self, embedding: &Array1<f32>) -> Array1<f32> {
self.projection.dot(embedding)
}
pub fn train_from_dictionary(
&mut self,
pairs: &[(Array1<f32>, Array1<f32>)],
epochs: usize,
learning_rate: f32,
) {
let dim = self.projection.nrows();
for _ in 0..epochs {
for (src, tgt) in pairs {
let pred = self.projection.dot(src);
let error = &pred - tgt;
for i in 0..dim {
let grad = error[i] * src;
let mut row = self.projection.row_mut(i);
row -= &(grad * learning_rate);
}
}
}
}
}
/// Namespace for continuing the training of an existing model on
/// domain-specific sentences.
pub struct DomainAdapter;

impl DomainAdapter {
    /// Trains `model` on `domain_sentences` for `epochs` epochs and, on success,
    /// replaces `*data` with the domain training data.
    ///
    /// The steps are:
    /// 1. Build a `TrainingData` from `domain_sentences` with copies of the
    ///    current `vocab` and `reverse_vocab`.
    /// 2. Pass the distinct domain words to `model.incremental_vocab_update`.
    ///    The list is sorted, so the argument is identical from run to run.
    /// 3. Temporarily set `model.config.epochs` to `epochs` and call
    ///    `model.train`.
    ///
    /// # Errors
    ///
    /// Propagates any error from `incremental_vocab_update` or `train`. In both
    /// cases `*data` is left untouched, and `model.config.epochs` holds its
    /// original value on return.
    pub fn adapt(
        model: &mut EmbeddingModel,
        data: &mut TrainingData,
        domain_sentences: &[Vec<String>],
        epochs: usize,
    ) -> Result<(), String> {
        let mut domain_data = TrainingData {
            sentences: domain_sentences.to_vec(),
            vocab: data.vocab.clone(),
            reverse_vocab: data.reverse_vocab.clone(),
        };

        // Sort + dedup instead of a HashSet round-trip: same set of words, but
        // in a reproducible order.
        let mut domain_words: Vec<String> = domain_sentences.iter().flatten().cloned().collect();
        domain_words.sort_unstable();
        domain_words.dedup();
        model.incremental_vocab_update(&domain_words, &mut domain_data)?;

        // Restore the configured epoch count before propagating a training
        // error, so a failed adaptation does not leave the override in place.
        let original_epochs = std::mem::replace(&mut model.config.epochs, epochs);
        let outcome = model.train(&domain_data);
        model.config.epochs = original_epochs;
        outcome?;

        *data = domain_data;
        Ok(())
    }
}
/// Namespace for building a single document vector from sentence embeddings.
pub struct DocumentEmbedder;

impl DocumentEmbedder {
    /// Returns the mean of the embeddings of all sentences that
    /// `model.sentence_embedding` can embed.
    ///
    /// Sentences for which `sentence_embedding` yields `None` are ignored and do
    /// not count towards the mean. Returns `None` when `sentences` is empty or
    /// when no sentence could be embedded.
    pub fn embed_document(
        model: &EmbeddingModel,
        data: &TrainingData,
        sentences: &[Vec<String>],
    ) -> Option<Array1<f32>> {
        if sentences.is_empty() {
            return None;
        }

        let mut total = Array::zeros(model.config.embedding_dim);
        let mut embedded = 0usize;
        let vectors = sentences
            .iter()
            .filter_map(|sentence| model.sentence_embedding(sentence, data));
        for vector in vectors {
            total += &vector;
            embedded += 1;
        }

        match embedded {
            0 => None,
            n => Some(&total / (n as f32)),
        }
    }
}
/// Character n-gram extractor and averager for subword-based word vectors.
pub struct SubwordEmbedder {
    /// Smallest n-gram length (in characters) to extract.
    pub min_n: usize,
    /// Largest n-gram length (in characters) to extract.
    pub max_n: usize,
}

impl SubwordEmbedder {
    /// Creates an embedder that extracts n-grams of `min_n..=max_n` characters.
    pub fn new(min_n: usize, max_n: usize) -> Self {
        Self { min_n, max_n }
    }

    /// Returns every character n-gram of `<word>` (the word wrapped in `<` and
    /// `>`) for each length in `min_n..=max_n`, shortest lengths first.
    ///
    /// Lengths are counted in `char`s, not bytes. Lengths longer than the
    /// wrapped word produce nothing. A `min_n` of 0 is treated as 1, because a
    /// zero-length window is meaningless and `slice::windows(0)` panics.
    pub fn ngrams(&self, word: &str) -> Vec<String> {
        let bounded: Vec<char> = format!("<{}>", word).chars().collect();
        let shortest = self.min_n.max(1);
        let longest = self.max_n.min(bounded.len());

        let mut grams = Vec::new();
        // An inverted range (shortest > longest) is empty, so nothing is emitted.
        for n in shortest..=longest {
            grams.extend(
                bounded
                    .windows(n)
                    .map(|window| window.iter().collect::<String>()),
            );
        }
        grams
    }

    /// Averages the vectors of those n-grams of `word` that are present in
    /// `ngram_vectors`.
    ///
    /// An n-gram that occurs several times in the word contributes once per
    /// occurrence. Returns `None` when none of the word's n-grams has a vector.
    pub fn embed(
        &self,
        word: &str,
        ngram_vectors: &HashMap<String, Array1<f32>>,
    ) -> Option<Array1<f32>> {
        let mut sum: Option<Array1<f32>> = None;
        let mut count = 0usize;
        for gram in self.ngrams(word) {
            if let Some(vec) = ngram_vectors.get(&gram) {
                sum = Some(match sum {
                    Some(acc) => acc + vec,
                    None => vec.clone(),
                });
                count += 1;
            }
        }
        // `sum` is `Some` only after at least one hit, so `count >= 1` here.
        sum.map(|total| total / (count as f32))
    }
}
/// Namespace for nearest-prototype classification by cosine similarity.
pub struct ZeroShotTransfer;

impl ZeroShotTransfer {
    /// Returns the label of the prototype with the highest cosine similarity to
    /// `query`, together with that similarity.
    ///
    /// Prototypes are skipped when they have zero norm, when their length
    /// differs from the query's (a cosine between vectors of different lengths
    /// is not meaningful), or when the similarity is NaN. Ties are broken in
    /// favour of the lexicographically smaller label, so the result does not
    /// depend on `HashMap` iteration order.
    ///
    /// Returns `None` when the query has zero norm or no prototype qualifies.
    pub fn classify(
        query: &Array1<f32>,
        class_prototypes: &HashMap<String, Array1<f32>>,
    ) -> Option<(String, f32)> {
        // The query norm is the same for every prototype; compute it once.
        let norm_q = query.iter().map(|&x| x * x).sum::<f32>().sqrt();
        if norm_q.is_nan() || norm_q <= 0.0 {
            return None;
        }

        let mut best: Option<(&String, f32)> = None;
        for (label, proto) in class_prototypes {
            if proto.len() != query.len() {
                continue;
            }
            let norm_p = proto.iter().map(|&x| x * x).sum::<f32>().sqrt();
            if norm_p.is_nan() || norm_p <= 0.0 {
                continue;
            }
            let dot: f32 = query.iter().zip(proto.iter()).map(|(&a, &b)| a * b).sum();
            let sim = dot / (norm_q * norm_p);

            let is_better = match best {
                // Comparisons with NaN are false, so NaN never becomes the best.
                None => sim > f32::NEG_INFINITY,
                Some((best_label, best_sim)) => {
                    sim > best_sim || (sim == best_sim && label < best_label)
                }
            };
            if is_better {
                best = Some((label, sim));
            }
        }

        // Clone only the winning label.
        best.map(|(label, sim)| (label.clone(), sim))
    }
}