1use super::adaptation::RealTimeFinetuning;
4use super::config::CrossModalConfig;
5use super::encoders::{AlignmentNetwork, KGEncoder, TextEncoder};
6use super::learning::FewShotLearning;
7use crate::{EmbeddingModel, ModelStats, TrainingStats, Vector};
8use anyhow::{anyhow, Result};
9use async_trait::async_trait;
10use chrono::Utc;
11use scirs2_core::ndarray_ext::Array1;
12use serde::{Deserialize, Serialize};
13use std::collections::HashMap;
14use uuid::Uuid;
15
/// Multi-modal embedding model that fuses free text and knowledge-graph (KG)
/// structure into a single unified embedding space.
///
/// Embeddings and cross-modal mappings are cached in `HashMap`s keyed by the
/// raw text / entity identifier strings; the encoders and alignment network
/// perform the actual modality fusion.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultiModalEmbedding {
    /// Cross-modal configuration (text/kg/unified dimensions, sub-configs).
    pub config: CrossModalConfig,
    /// Random 128-bit model identifier assigned in `new`.
    pub model_id: Uuid,
    /// Cached text embeddings, keyed by the raw input text.
    pub text_embeddings: HashMap<String, Array1<f32>>,
    /// Cached raw KG entity embeddings, keyed by entity identifier.
    pub kg_embeddings: HashMap<String, Array1<f32>>,
    /// Fused embeddings, keyed by `"{text}|{entity}"`
    /// (see `generate_unified_embedding`).
    pub unified_embeddings: HashMap<String, Array1<f32>>,
    /// text -> entity alignment pairs used as training supervision.
    pub text_kg_alignments: HashMap<String, String>,
    /// entity -> natural-language description.
    pub entity_descriptions: HashMap<String, String>,
    /// property/relation -> textual description.
    pub property_texts: HashMap<String, String>,
    /// concept -> its translations in other languages.
    pub multilingual_mappings: HashMap<String, Vec<String>>,
    /// source-domain concept -> target-domain concept.
    pub cross_domain_mappings: HashMap<String, String>,
    /// Text-modality encoder ("BERT"-style; constructed in `new`).
    pub text_encoder: TextEncoder,
    /// KG-modality encoder ("ComplEx"-style; constructed in `new`).
    pub kg_encoder: KGEncoder,
    /// Cross-modal attention network producing the unified embeddings.
    pub alignment_network: AlignmentNetwork,
    /// Statistics from the most recent `train` run.
    pub training_stats: TrainingStats,
    /// Aggregate model metadata (entity/relation/triple counts, dimensions).
    pub model_stats: ModelStats,
    /// Set to `true` once `train` completes; reset by `clear`.
    pub is_trained: bool,
}
46
/// Snapshot of the multi-modal model's cache sizes and configured dimensions,
/// as produced by `MultiModalEmbedding::get_multimodal_stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultiModalStats {
    /// Number of cached text embeddings.
    pub num_text_embeddings: usize,
    /// Number of cached raw KG embeddings.
    pub num_kg_embeddings: usize,
    /// Number of cached fused (text|entity) embeddings.
    pub num_unified_embeddings: usize,
    /// Number of registered text -> entity alignments.
    pub num_alignments: usize,
    /// Number of registered entity descriptions.
    pub num_entity_descriptions: usize,
    /// Number of registered property texts.
    pub num_property_texts: usize,
    /// Number of registered multilingual concept mappings.
    pub num_multilingual_mappings: usize,
    /// Number of registered cross-domain concept mappings.
    pub num_cross_domain_mappings: usize,
    /// Configured text embedding dimensionality.
    pub text_dim: usize,
    /// Configured KG embedding dimensionality.
    pub kg_dim: usize,
    /// Configured unified embedding dimensionality.
    pub unified_dim: usize,
}
62
impl MultiModalEmbedding {
    /// Builds a fresh, untrained multi-modal model from `config`.
    ///
    /// Instantiates a BERT-style text encoder, a ComplEx-style KG encoder and
    /// a cross-modal attention alignment network sized from the configured
    /// dimensions. All embedding caches start empty.
    pub fn new(config: CrossModalConfig) -> Self {
        // Random 128-bit identifier drawn from the project-wide scirs2 RNG.
        let model_id = {
            use scirs2_core::random::{Random, Rng};
            let mut rng = Random::default();
            Uuid::from_u128(rng.random::<u128>())
        };
        let created_at = Utc::now();

        // Per-modality encoders plus the fusion network, sized from config.
        let text_encoder = TextEncoder::new("BERT".to_string(), config.text_dim, config.text_dim);
        let kg_encoder = KGEncoder::new(
            "ComplEx".to_string(),
            config.kg_dim,
            config.kg_dim,
            config.kg_dim,
        );
        let alignment_network = AlignmentNetwork::new(
            "CrossModalAttention".to_string(),
            config.text_dim,
            config.kg_dim,
            config.unified_dim / 2,
            config.unified_dim,
        );

        Self {
            model_id,
            text_embeddings: HashMap::new(),
            kg_embeddings: HashMap::new(),
            unified_embeddings: HashMap::new(),
            text_kg_alignments: HashMap::new(),
            entity_descriptions: HashMap::new(),
            property_texts: HashMap::new(),
            multilingual_mappings: HashMap::new(),
            cross_domain_mappings: HashMap::new(),
            text_encoder,
            kg_encoder,
            alignment_network,
            training_stats: TrainingStats {
                epochs_completed: 0,
                final_loss: 0.0,
                training_time_seconds: 0.0,
                convergence_achieved: false,
                loss_history: Vec::new(),
            },
            model_stats: ModelStats {
                num_entities: 0,
                num_relations: 0,
                num_triples: 0,
                dimensions: config.unified_dim,
                is_trained: false,
                model_type: "MultiModalEmbedding".to_string(),
                creation_time: created_at,
                last_training_time: None,
            },
            is_trained: false,
            // `config` is moved last so its fields above stay readable.
            config,
        }
    }
124
125 pub fn add_text_kg_alignment(&mut self, text: &str, entity: &str) {
127 self.text_kg_alignments
128 .insert(text.to_string(), entity.to_string());
129 }
130
131 pub fn add_entity_description(&mut self, entity: &str, description: &str) {
133 self.entity_descriptions
134 .insert(entity.to_string(), description.to_string());
135 }
136
137 pub fn add_property_text(&mut self, property: &str, text_description: &str) {
139 self.property_texts
140 .insert(property.to_string(), text_description.to_string());
141 }
142
143 pub fn add_multilingual_mapping(&mut self, concept: &str, translations: Vec<String>) {
145 self.multilingual_mappings
146 .insert(concept.to_string(), translations);
147 }
148
149 pub fn add_cross_domain_mapping(&mut self, source_concept: &str, target_concept: &str) {
151 self.cross_domain_mappings
152 .insert(source_concept.to_string(), target_concept.to_string());
153 }
154
155 pub async fn generate_unified_embedding(
157 &mut self,
158 text: &str,
159 entity: &str,
160 ) -> Result<Array1<f32>> {
161 let text_embedding = self.text_encoder.encode(text)?;
163
164 let kg_embedding_raw = self.get_or_create_kg_embedding(entity)?;
166
167 let kg_embedding = self.kg_encoder.encode_entity(&kg_embedding_raw)?;
169
170 let (unified_embedding, alignment_score) = self
172 .alignment_network
173 .align(&text_embedding, &kg_embedding)?;
174
175 self.text_embeddings
177 .insert(text.to_string(), text_embedding);
178 self.kg_embeddings
179 .insert(entity.to_string(), kg_embedding_raw); self.unified_embeddings
181 .insert(format!("{text}|{entity}"), unified_embedding.clone());
182
183 println!("Generated unified embedding with alignment score: {alignment_score:.3}");
184
185 Ok(unified_embedding)
186 }
187
188 pub fn get_or_create_kg_embedding(&self, entity: &str) -> Result<Array1<f32>> {
190 if let Some(embedding) = self.kg_embeddings.get(entity) {
191 Ok(embedding.clone())
192 } else {
193 let mut embedding = vec![0.0; self.config.kg_dim];
195 let entity_bytes = entity.as_bytes();
196
197 for (i, &byte) in entity_bytes.iter().enumerate() {
198 if i < self.config.kg_dim {
199 embedding[i] = (byte as f32 / 255.0 - 0.5) * 2.0;
200 }
201 }
202
203 Ok(Array1::from_vec(embedding))
204 }
205 }
206
207 pub fn contrastive_loss(
209 &self,
210 positive_pairs: &[(String, String)],
211 negative_pairs: &[(String, String)],
212 ) -> Result<f32> {
213 let mut positive_scores = Vec::new();
214 let mut negative_scores = Vec::new();
215
216 for (text, entity) in positive_pairs {
218 if let (Some(text_emb), Some(kg_emb_raw)) = (
219 self.text_embeddings.get(text),
220 self.kg_embeddings.get(entity),
221 ) {
222 let kg_emb = self.kg_encoder.encode_entity(kg_emb_raw)?;
223 let score = self
224 .alignment_network
225 .compute_alignment_score(text_emb, &kg_emb);
226 positive_scores.push(score);
227 }
228 }
229
230 for (text, entity) in negative_pairs {
232 if let (Some(text_emb), Some(kg_emb_raw)) = (
233 self.text_embeddings.get(text),
234 self.kg_embeddings.get(entity),
235 ) {
236 let kg_emb = self.kg_encoder.encode_entity(kg_emb_raw)?;
237 let score = self
238 .alignment_network
239 .compute_alignment_score(text_emb, &kg_emb);
240 negative_scores.push(score);
241 }
242 }
243
244 let temperature = self.config.contrastive_config.temperature;
246 let mut loss = 0.0;
247
248 for &pos_score in &positive_scores {
249 let pos_exp = (pos_score / temperature).exp();
250 let mut neg_sum = 0.0;
251
252 for &neg_score in &negative_scores {
253 neg_sum += (neg_score / temperature).exp();
254 }
255
256 if neg_sum > 0.0 {
257 loss -= (pos_exp / (pos_exp + neg_sum)).ln();
258 }
259 }
260
261 if !positive_scores.is_empty() {
262 loss /= positive_scores.len() as f32;
263 }
264
265 Ok(loss)
266 }
267
268 pub async fn zero_shot_prediction(
270 &self,
271 text: &str,
272 candidate_entities: &[String],
273 ) -> Result<Vec<(String, f32)>> {
274 let text_embedding = self.text_encoder.encode(text)?;
275 let mut scores = Vec::new();
276
277 for entity in candidate_entities {
278 if let Some(kg_embedding_raw) = self.kg_embeddings.get(entity) {
279 let kg_encoded = self.kg_encoder.encode_entity(kg_embedding_raw)?;
280 let score = self
281 .alignment_network
282 .compute_alignment_score(&text_embedding, &kg_encoded);
283 scores.push((entity.clone(), score));
284 }
285 }
286
287 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
289
290 Ok(scores)
291 }
292
293 pub async fn cross_domain_transfer(
295 &mut self,
296 source_domain: &str,
297 target_domain: &str,
298 ) -> Result<f32> {
299 if !self.config.cross_domain_config.enable_domain_adaptation {
300 return Ok(0.0);
301 }
302
303 let mut transfer_pairs = Vec::new();
305 for (source_concept, target_concept) in &self.cross_domain_mappings {
306 if source_concept.contains(source_domain) && target_concept.contains(target_domain) {
307 transfer_pairs.push((source_concept.clone(), target_concept.clone()));
308 }
309 }
310
311 if transfer_pairs.is_empty() {
312 return Ok(0.0);
313 }
314
315 let mut adaptation_loss = 0.0;
317 for (source, target) in &transfer_pairs {
318 if let (Some(source_emb), Some(target_emb)) = (
319 self.unified_embeddings.get(source),
320 self.unified_embeddings.get(target),
321 ) {
322 let diff = source_emb - target_emb;
324 adaptation_loss += diff.dot(&diff).sqrt();
325 }
326 }
327
328 adaptation_loss /= transfer_pairs.len() as f32;
329
330 println!(
331 "Cross-domain transfer loss ({source_domain} -> {target_domain}): {adaptation_loss:.3}"
332 );
333
334 Ok(adaptation_loss)
335 }
336
337 pub async fn multilingual_alignment(&self, concept: &str) -> Result<Vec<(String, f32)>> {
339 if let Some(translations) = self.multilingual_mappings.get(concept) {
340 let mut alignment_scores = Vec::new();
341
342 if let Some(base_embedding) = self.unified_embeddings.get(concept) {
343 for translation in translations {
344 if let Some(trans_embedding) = self.unified_embeddings.get(translation) {
345 let score = self
346 .alignment_network
347 .compute_alignment_score(base_embedding, trans_embedding);
348 alignment_scores.push((translation.clone(), score));
349 }
350 }
351 }
352
353 alignment_scores
355 .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
356
357 Ok(alignment_scores)
358 } else {
359 Ok(Vec::new())
360 }
361 }
362
363 pub fn get_multimodal_stats(&self) -> MultiModalStats {
365 MultiModalStats {
366 num_text_embeddings: self.text_embeddings.len(),
367 num_kg_embeddings: self.kg_embeddings.len(),
368 num_unified_embeddings: self.unified_embeddings.len(),
369 num_alignments: self.text_kg_alignments.len(),
370 num_entity_descriptions: self.entity_descriptions.len(),
371 num_property_texts: self.property_texts.len(),
372 num_multilingual_mappings: self.multilingual_mappings.len(),
373 num_cross_domain_mappings: self.cross_domain_mappings.len(),
374 text_dim: self.config.text_dim,
375 kg_dim: self.config.kg_dim,
376 unified_dim: self.config.unified_dim,
377 }
378 }
379
    /// Builder-style hook for attaching a few-shot learning configuration.
    ///
    /// NOTE(review): the configuration is currently ignored (`_few_shot_config`)
    /// and the model is returned unchanged — `few_shot_learn` always builds a
    /// `FewShotLearning::default()`. Wire the config through if it matters.
    pub fn with_few_shot_learning(self, _few_shot_config: FewShotLearning) -> Self {
        self
    }
386
    /// Runs few-shot adaptation against this model.
    ///
    /// `support_examples` are (text, entity, label) triples; `query_examples`
    /// are (text, entity) pairs to score. Delegates to a default-configured
    /// `FewShotLearning` learner — any config passed to
    /// `with_few_shot_learning` is not used here.
    pub async fn few_shot_learn(
        &self,
        support_examples: &[(String, String, String)],
        query_examples: &[(String, String)],
    ) -> Result<Vec<(String, f32)>> {
        let mut few_shot_learner = FewShotLearning::default();
        few_shot_learner
            .few_shot_adapt(support_examples, query_examples, self)
            .await
    }
398
    /// Builder-style hook for attaching a real-time fine-tuning configuration.
    ///
    /// NOTE(review): the configuration is currently ignored (`_rt_config`) and
    /// the model is returned unchanged — `real_time_update` always builds a
    /// `RealTimeFinetuning::default()`. Wire the config through if it matters.
    pub fn with_real_time_finetuning(self, _rt_config: RealTimeFinetuning) -> Self {
        self
    }
405
    /// Feeds one (text, entity, label) example to a default-configured
    /// real-time fine-tuner and applies its model update, returning the loss.
    ///
    /// NOTE(review): any config from `with_real_time_finetuning` is unused —
    /// a fresh `RealTimeFinetuning::default()` is built on every call.
    pub async fn real_time_update(&mut self, text: &str, entity: &str, label: &str) -> Result<f32> {
        let mut rt_finetuning = RealTimeFinetuning::default();
        rt_finetuning.add_example(text.to_string(), entity.to_string(), label.to_string());
        rt_finetuning.update_model(self).await
    }
}
413
#[async_trait]
impl EmbeddingModel for MultiModalEmbedding {
    /// Base model configuration embedded in the cross-modal config.
    fn config(&self) -> &crate::ModelConfig {
        &self.config.base_config
    }
419
    /// Random identifier assigned at construction.
    fn model_id(&self) -> &Uuid {
        &self.model_id
    }
423
    /// Static model-type tag; matches `model_stats.model_type`.
    fn model_type(&self) -> &'static str {
        "MultiModalEmbedding"
    }
427
428 fn add_triple(&mut self, triple: crate::Triple) -> Result<()> {
429 let subject = &triple.subject.iri;
431 let predicate = &triple.predicate.iri;
432 let object = &triple.object.iri;
433
434 if let Some(description) = self.entity_descriptions.get(subject).cloned() {
436 self.add_text_kg_alignment(&description, subject);
437 }
438
439 if let Some(description) = self.entity_descriptions.get(object).cloned() {
440 self.add_text_kg_alignment(&description, object);
441 }
442
443 if let Some(property_text) = self.property_texts.get(predicate).cloned() {
445 self.add_text_kg_alignment(&property_text, predicate);
446 }
447
448 Ok(())
449 }
450
451 async fn train(&mut self, epochs: Option<usize>) -> Result<TrainingStats> {
452 let epochs = epochs.unwrap_or(100);
453 let start_time = std::time::Instant::now();
454 let mut loss_history = Vec::new();
455
456 for epoch in 0..epochs {
458 let mut epoch_loss = 0.0;
459 let mut num_batches = 0;
460
461 let alignment_pairs: Vec<_> = self
463 .text_kg_alignments
464 .iter()
465 .map(|(k, v)| (k.clone(), v.clone()))
466 .collect();
467 for (text, entity) in &alignment_pairs {
468 if let Ok(unified) = self.generate_unified_embedding(text, entity).await {
470 let loss = unified.iter().map(|&x| x * x).sum::<f32>() / unified.len() as f32;
472 epoch_loss += loss;
473 num_batches += 1;
474 }
475 }
476
477 if alignment_pairs.len() > 1 {
479 let positive_pairs: Vec<_> = alignment_pairs
480 .iter()
481 .map(|(t, e)| (t.to_string(), e.to_string()))
482 .collect();
483
484 let mut negative_pairs = Vec::new();
486 for i in 0..positive_pairs.len().min(10) {
487 let neg_entity = &positive_pairs[(i + 1) % positive_pairs.len()].1;
488 negative_pairs.push((positive_pairs[i].0.clone(), neg_entity.clone()));
489 }
490
491 if let Ok(contrastive_loss) =
492 self.contrastive_loss(&positive_pairs, &negative_pairs)
493 {
494 epoch_loss += contrastive_loss;
495 num_batches += 1;
496 }
497 }
498
499 if num_batches > 0 {
500 epoch_loss /= num_batches as f32;
501 }
502
503 loss_history.push(epoch_loss as f64);
504
505 if epoch % 10 == 0 {
506 println!("Multi-modal training epoch {epoch}: Loss = {epoch_loss:.6}");
507 }
508
509 if epoch_loss < 0.001 {
511 break;
512 }
513 }
514
515 let training_time = start_time.elapsed().as_secs_f64();
516
517 self.training_stats = TrainingStats {
518 epochs_completed: epochs,
519 final_loss: loss_history.last().copied().unwrap_or(0.0),
520 training_time_seconds: training_time,
521 convergence_achieved: loss_history.last().is_some_and(|&loss| loss < 0.001),
522 loss_history,
523 };
524
525 self.is_trained = true;
526 self.model_stats.is_trained = true;
527 self.model_stats.last_training_time = Some(Utc::now());
528
529 self.model_stats.num_entities = self.kg_embeddings.len();
531 self.model_stats.num_relations = self.property_texts.len();
532 self.model_stats.num_triples = self.text_kg_alignments.len();
533
534 Ok(self.training_stats.clone())
535 }
536
537 fn get_entity_embedding(&self, entity: &str) -> Result<Vector> {
538 if let Some(embedding) = self.unified_embeddings.get(entity) {
539 Ok(Vector::from_array1(embedding))
540 } else if let Some(embedding) = self.kg_embeddings.get(entity) {
541 Ok(Vector::from_array1(embedding))
542 } else {
543 Err(anyhow!("Entity {} not found", entity))
544 }
545 }
546
547 fn get_relation_embedding(&self, relation: &str) -> Result<Vector> {
548 if let Some(embedding) = self.kg_embeddings.get(relation) {
549 Ok(Vector::from_array1(embedding))
550 } else {
551 Err(anyhow!("Relation {} not found", relation))
552 }
553 }
554
555 fn score_triple(&self, subject: &str, predicate: &str, object: &str) -> Result<f64> {
556 let subject_emb = self.get_entity_embedding(subject)?;
557 let predicate_emb = self.get_relation_embedding(predicate)?;
558 let object_emb = self.get_entity_embedding(object)?;
559
560 let mut score = 0.0;
562 for i in 0..subject_emb
563 .dimensions
564 .min(predicate_emb.dimensions)
565 .min(object_emb.dimensions)
566 {
567 let diff = subject_emb.values[i] + predicate_emb.values[i] - object_emb.values[i];
568 score += diff * diff;
569 }
570
571 Ok(1.0 / (1.0 + score as f64))
573 }
574
575 fn predict_objects(
576 &self,
577 subject: &str,
578 predicate: &str,
579 k: usize,
580 ) -> Result<Vec<(String, f64)>> {
581 let mut scores = Vec::new();
582
583 for entity in self.kg_embeddings.keys() {
584 if entity != subject {
585 if let Ok(score) = self.score_triple(subject, predicate, entity) {
586 scores.push((entity.clone(), score));
587 }
588 }
589 }
590
591 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
592 scores.truncate(k);
593
594 Ok(scores)
595 }
596
597 fn predict_subjects(
598 &self,
599 predicate: &str,
600 object: &str,
601 k: usize,
602 ) -> Result<Vec<(String, f64)>> {
603 let mut scores = Vec::new();
604
605 for entity in self.kg_embeddings.keys() {
606 if entity != object {
607 if let Ok(score) = self.score_triple(entity, predicate, object) {
608 scores.push((entity.clone(), score));
609 }
610 }
611 }
612
613 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
614 scores.truncate(k);
615
616 Ok(scores)
617 }
618
619 fn predict_relations(
620 &self,
621 subject: &str,
622 object: &str,
623 k: usize,
624 ) -> Result<Vec<(String, f64)>> {
625 let mut scores = Vec::new();
626
627 for relation in self.property_texts.keys() {
628 if let Ok(score) = self.score_triple(subject, relation, object) {
629 scores.push((relation.clone(), score));
630 }
631 }
632
633 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
634 scores.truncate(k);
635
636 Ok(scores)
637 }
638
    /// All entity identifiers with a cached raw KG embedding (unordered).
    fn get_entities(&self) -> Vec<String> {
        self.kg_embeddings.keys().cloned().collect()
    }
642
    /// All relation identifiers with a registered property text (unordered).
    fn get_relations(&self) -> Vec<String> {
        self.property_texts.keys().cloned().collect()
    }
646
    /// Copy of the aggregate model statistics (counts refreshed by `train`).
    fn get_stats(&self) -> ModelStats {
        self.model_stats.clone()
    }
650
    /// Persistence stub: currently a no-op that always succeeds.
    /// NOTE(review): no model state is written to `_path` — implement this or
    /// make the limitation explicit to callers relying on persistence.
    fn save(&self, _path: &str) -> Result<()> {
        Ok(())
    }
655
    /// Persistence stub: currently a no-op that always succeeds.
    /// NOTE(review): nothing is read from `_path` and the model is left
    /// unchanged — mirror whatever `save` ends up writing.
    fn load(&mut self, _path: &str) -> Result<()> {
        Ok(())
    }
660
661 fn clear(&mut self) {
662 self.text_embeddings.clear();
663 self.kg_embeddings.clear();
664 self.unified_embeddings.clear();
665 self.text_kg_alignments.clear();
666 self.entity_descriptions.clear();
667 self.property_texts.clear();
668 self.multilingual_mappings.clear();
669 self.cross_domain_mappings.clear();
670 self.is_trained = false;
671 }
672
    /// `true` after a successful `train`; reset to `false` by `clear`.
    fn is_trained(&self) -> bool {
        self.is_trained
    }
676
677 async fn encode(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
678 let mut embeddings = Vec::new();
679
680 for text in texts {
681 if let Some(embedding) = self.text_embeddings.get(text) {
682 embeddings.push(embedding.to_vec());
683 } else {
684 let embedding = self.text_encoder.encode(text)?;
686 embeddings.push(embedding.to_vec());
687 }
688 }
689
690 Ok(embeddings)
691 }
692}