//! Cross-domain transfer learning utilities for evaluating how well knowledge
//! graph embedding models trained on one domain transfer to another.

use crate::{EmbeddingModel, Vector};
use anyhow::{anyhow, Result};
use scirs2_core::random::{Random, Rng};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};

/// Coordinates source models, target domain specifications, and transfer
/// strategies, and evaluates transfer quality between registered domains.
pub struct CrossDomainTransferManager {
    source_domains: HashMap<String, DomainModel>,
    target_domains: HashMap<String, DomainSpecification>,
    transfer_strategies: Vec<TransferStrategy>,
    transfer_metrics: Vec<TransferMetric>,
    config: TransferConfig,
}

/// Configuration options for cross-domain transfer evaluation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TransferConfig {
    pub enable_domain_adaptation: bool,
    pub use_adversarial_alignment: bool,
    pub max_alignment_iterations: usize,
    pub adaptation_learning_rate: f64,
    pub min_domain_similarity: f64,
    pub enable_entity_linking: bool,
    pub evaluation_sample_size: usize,
}

impl Default for TransferConfig {
    fn default() -> Self {
        Self {
            enable_domain_adaptation: true,
            use_adversarial_alignment: true,
            max_alignment_iterations: 100,
            adaptation_learning_rate: 0.001,
            min_domain_similarity: 0.3,
            enable_entity_linking: true,
            evaluation_sample_size: 1000,
        }
    }
}
58
59pub struct DomainModel {
61 pub domain_id: String,
63 pub model: Box<dyn EmbeddingModel + Send + Sync>,
65 pub characteristics: DomainCharacteristics,
67 pub entity_mappings: HashMap<String, String>,
69 pub vocabulary: HashSet<String>,
71}
72
73#[derive(Debug, Clone, Serialize, Deserialize)]
75pub struct DomainCharacteristics {
76 pub domain_type: String,
78 pub language: String,
80 pub entity_types: Vec<String>,
82 pub relation_types: Vec<String>,
84 pub size_metrics: DomainSizeMetrics,
86 pub complexity_metrics: DomainComplexityMetrics,
88}
89
90#[derive(Debug, Clone, Serialize, Deserialize)]
92pub struct DomainSizeMetrics {
93 pub num_entities: usize,
95 pub num_relations: usize,
97 pub num_triples: usize,
99 pub avg_entity_degree: f64,
101 pub graph_density: f64,
103}
104
105#[derive(Debug, Clone, Serialize, Deserialize)]
107pub struct DomainComplexityMetrics {
108 pub entity_type_diversity: usize,
110 pub relation_type_diversity: usize,
112 pub hierarchical_depth: usize,
114 pub semantic_diversity: f64,
116 pub structural_complexity: f64,
118}
119
120#[derive(Debug, Clone, Serialize, Deserialize)]
122pub struct DomainSpecification {
123 pub domain_id: String,
125 pub characteristics: DomainCharacteristics,
127 pub training_data: Vec<(String, String, String)>,
129 pub validation_data: Vec<(String, String, String)>,
131 pub test_data: Vec<(String, String, String)>,
133}

/// Strategies for transferring a source model to a target domain.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum TransferStrategy {
    DirectTransfer,
    FineTuning {
        learning_rate: f64,
        epochs: usize,
        freeze_layers: Vec<String>,
    },
    DomainAdaptation {
        alignment_method: AlignmentMethod,
        regularization_strength: f64,
    },
    MultiTaskLearning { task_weights: HashMap<String, f64> },
    MetaLearning {
        inner_steps: usize,
        meta_learning_rate: f64,
    },
    ProgressiveTransfer {
        intermediate_domains: Vec<String>,
        progression_strategy: ProgressionStrategy,
    },
}

/// Methods for aligning source and target embedding spaces.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AlignmentMethod {
    LinearAlignment,
    NeuralAlignment,
    AdversarialAlignment,
    CCA,
    ProcrustesAlignment,
    WassersteinAlignment,
}

/// How intermediate domains are ordered during progressive transfer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ProgressionStrategy {
    Sequential,
    CurriculumBased,
    SimilarityGuided,
}

/// Metrics used to assess transfer quality.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum TransferMetric {
    TransferAccuracy,
    AdaptationQuality,
    EntityAlignmentQuality,
    SemanticPreservation,
    StructuralPreservation,
    TransferEfficiency,
    CatastrophicForgetting,
    CrossDomainCoherence,
    KnowledgeRetention,
    AdaptationSpeed,
    TransferRobustness,
    SemanticDriftDetection,
    GeneralizationAbility,
}

/// Results of evaluating one transfer between a source and a target domain.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TransferEvaluationResults {
    pub source_domain: String,
    pub target_domain: String,
    pub strategy: TransferStrategy,
    pub metric_scores: HashMap<String, f64>,
    pub overall_quality: f64,
    pub domain_similarity: f64,
    pub improvement_over_baseline: f64,
    pub transfer_time: f64,
    pub detailed_analysis: TransferAnalysis,
}

/// Detailed per-alignment analysis produced alongside the aggregate scores.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TransferAnalysis {
    pub entity_alignments: Vec<EntityAlignment>,
    pub relation_alignments: Vec<RelationAlignment>,
    pub semantic_shifts: Vec<SemanticShift>,
    pub structural_changes: StructuralChanges,
    pub recommendations: Vec<String>,
}

/// A proposed correspondence between a source entity and a target entity.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EntityAlignment {
    pub source_entity: String,
    pub target_entity: String,
    pub confidence: f64,
    pub similarity: f64,
    pub method: String,
}

/// A proposed correspondence between a source relation and a target relation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RelationAlignment {
    pub source_relation: String,
    pub target_relation: String,
    pub confidence: f64,
    pub semantic_similarity: f64,
    pub structural_similarity: f64,
}

/// A concept whose meaning shifts between the source and target domains.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SemanticShift {
    pub concept: String,
    pub source_meaning: String,
    pub target_meaning: String,
    pub shift_magnitude: f64,
    pub impact: f64,
}

/// Graph-level structural differences between the source and target domains.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StructuralChanges {
    pub degree_distribution_shift: f64,
    pub clustering_changes: f64,
    pub path_length_changes: f64,
    pub community_structure_changes: f64,
}

impl CrossDomainTransferManager {
    /// Creates a manager with default transfer strategies and metrics.
    pub fn new(config: TransferConfig) -> Self {
        Self {
            source_domains: HashMap::new(),
            target_domains: HashMap::new(),
            transfer_strategies: vec![
                TransferStrategy::DirectTransfer,
                TransferStrategy::FineTuning {
                    learning_rate: 0.001,
                    epochs: 50,
                    freeze_layers: vec![],
                },
                TransferStrategy::DomainAdaptation {
                    alignment_method: AlignmentMethod::AdversarialAlignment,
                    regularization_strength: 0.1,
                },
            ],
            transfer_metrics: vec![
                TransferMetric::TransferAccuracy,
                TransferMetric::AdaptationQuality,
                TransferMetric::SemanticPreservation,
                TransferMetric::StructuralPreservation,
            ],
            config,
        }
    }

    /// Registers a trained embedding model as a source domain.
    pub fn register_source_domain(
        &mut self,
        domain_id: String,
        model: Box<dyn EmbeddingModel + Send + Sync>,
        characteristics: DomainCharacteristics,
    ) -> Result<()> {
        let domain_model = DomainModel {
            domain_id: domain_id.clone(),
            model,
            characteristics,
            entity_mappings: HashMap::new(),
            vocabulary: HashSet::new(),
        };

        self.source_domains.insert(domain_id, domain_model);
        Ok(())
    }

    /// Registers a target domain specification for later evaluation.
    pub fn register_target_domain(&mut self, domain_spec: DomainSpecification) -> Result<()> {
        self.target_domains
            .insert(domain_spec.domain_id.clone(), domain_spec);
        Ok(())
    }

    /// Evaluates how well the given source model transfers to the given
    /// target domain under the specified strategy.
    pub async fn evaluate_transfer(
        &self,
        source_domain_id: &str,
        target_domain_id: &str,
        strategy: TransferStrategy,
    ) -> Result<TransferEvaluationResults> {
        let source_domain = self
            .source_domains
            .get(source_domain_id)
            .ok_or_else(|| anyhow!("Source domain not found: {}", source_domain_id))?;

        let target_domain = self
            .target_domains
            .get(target_domain_id)
            .ok_or_else(|| anyhow!("Target domain not found: {}", target_domain_id))?;

        let start_time = std::time::Instant::now();

        let domain_similarity = self.calculate_domain_similarity(
            &source_domain.characteristics,
            &target_domain.characteristics,
        )?;

        let entity_alignments = self.align_entities(source_domain, target_domain).await?;

        let relation_alignments = self.align_relations(source_domain, target_domain).await?;

        let semantic_shifts = self
            .analyze_semantic_shifts(source_domain, target_domain)
            .await?;

        let structural_changes = self.analyze_structural_changes(
            &source_domain.characteristics,
            &target_domain.characteristics,
        )?;

        let mut metric_scores = HashMap::new();
        for metric in &self.transfer_metrics {
            let score = self
                .evaluate_transfer_metric(
                    metric,
                    source_domain,
                    target_domain,
                    &entity_alignments,
                    &relation_alignments,
                )
                .await?;
            metric_scores.insert(format!("{metric:?}"), score);
        }

        let overall_quality = if metric_scores.is_empty() {
            // Neutral score when no metrics were evaluated
            0.5
        } else {
            let avg_quality = metric_scores.values().sum::<f64>() / metric_scores.len() as f64;
            avg_quality.max(0.0)
        };

        // Simple fixed baseline used for reporting relative improvement
        let baseline_performance = 0.1;
        let improvement_over_baseline = overall_quality - baseline_performance;

        let transfer_time = start_time.elapsed().as_secs_f64();

        let recommendations = self.generate_transfer_recommendations(
            domain_similarity,
            &entity_alignments,
            &semantic_shifts,
        );

        let detailed_analysis = TransferAnalysis {
            entity_alignments,
            relation_alignments,
            semantic_shifts,
            structural_changes,
            recommendations,
        };

        Ok(TransferEvaluationResults {
            source_domain: source_domain_id.to_string(),
            target_domain: target_domain_id.to_string(),
            strategy,
            metric_scores,
            overall_quality,
            domain_similarity,
            improvement_over_baseline,
            transfer_time,
            detailed_analysis,
        })
    }

    /// Computes a heuristic similarity score between two domains based on
    /// language, shared entity/relation types, size, and complexity.
    pub fn calculate_domain_similarity(
        &self,
        source: &DomainCharacteristics,
        target: &DomainCharacteristics,
    ) -> Result<f64> {
        let mut similarity_scores = Vec::new();

        // Language match
        let language_similarity = if source.language == target.language {
            1.0
        } else {
            0.5
        };
        similarity_scores.push(language_similarity);

        // Entity type overlap (Dice coefficient)
        let source_entity_types: HashSet<_> = source.entity_types.iter().collect();
        let target_entity_types: HashSet<_> = target.entity_types.iter().collect();
        let entity_overlap = source_entity_types
            .intersection(&target_entity_types)
            .count() as f64;
        let entity_similarity =
            entity_overlap / (source_entity_types.len() + target_entity_types.len()) as f64 * 2.0;
        similarity_scores.push(entity_similarity);

        // Relation type overlap (Dice coefficient)
        let source_relation_types: HashSet<_> = source.relation_types.iter().collect();
        let target_relation_types: HashSet<_> = target.relation_types.iter().collect();
        let relation_overlap = source_relation_types
            .intersection(&target_relation_types)
            .count() as f64;
        let relation_similarity = relation_overlap
            / (source_relation_types.len() + target_relation_types.len()) as f64
            * 2.0;
        similarity_scores.push(relation_similarity);

        // Relative size (ratio of smaller to larger entity count)
        let size_ratio = (target.size_metrics.num_entities as f64
            / source.size_metrics.num_entities as f64)
            .min(source.size_metrics.num_entities as f64 / target.size_metrics.num_entities as f64);
        similarity_scores.push(size_ratio);

        // Closeness of semantic diversity
        let complexity_diff = (source.complexity_metrics.semantic_diversity
            - target.complexity_metrics.semantic_diversity)
            .abs();
        let complexity_similarity = (1.0 - complexity_diff).max(0.0);
        similarity_scores.push(complexity_similarity);

        let overall_similarity =
            similarity_scores.iter().sum::<f64>() / similarity_scores.len() as f64;

        Ok(overall_similarity)
    }

    async fn align_entities(
        &self,
        source: &DomainModel,
        target: &DomainSpecification,
    ) -> Result<Vec<EntityAlignment>> {
        let mut alignments = Vec::new();

        let source_entities = source.model.get_entities();
        let target_entities = self.extract_entities_from_triples(&target.training_data);

        for source_entity in &source_entities {
            for target_entity in &target_entities {
                let similarity = self.calculate_string_similarity(source_entity, target_entity);

                if similarity > 0.7 {
                    alignments.push(EntityAlignment {
                        source_entity: source_entity.clone(),
                        target_entity: target_entity.clone(),
                        confidence: similarity,
                        similarity,
                        method: "string_similarity".to_string(),
                    });
                }
            }
        }

        for source_entity in source_entities.iter().take(50) {
            if let Ok(source_embedding) = source.model.get_entity_embedding(source_entity) {
                let mut best_match = None;
                let mut best_similarity = 0.0;

                for target_entity in target_entities.iter().take(50) {
                    let target_embedding = self.create_simple_embedding(target_entity);
                    let similarity = self.cosine_similarity(&source_embedding, &target_embedding);

                    if similarity > best_similarity && similarity > 0.5 {
                        best_similarity = similarity;
                        best_match = Some(target_entity.clone());
                    }
                }

                if let Some(target_entity) = best_match {
                    alignments.push(EntityAlignment {
                        source_entity: source_entity.clone(),
                        target_entity,
                        confidence: best_similarity,
                        similarity: best_similarity,
                        method: "semantic_embedding".to_string(),
                    });
                }
            }
        }

        Ok(alignments)
    }

    async fn align_relations(
        &self,
        source: &DomainModel,
        target: &DomainSpecification,
    ) -> Result<Vec<RelationAlignment>> {
        let mut alignments = Vec::new();

        let source_relations = source.model.get_relations();
        let target_relations = self.extract_relations_from_triples(&target.training_data);

        for source_relation in &source_relations {
            for target_relation in &target_relations {
                let semantic_similarity =
                    self.calculate_string_similarity(source_relation, target_relation);

                // Placeholder structural similarity until graph-based analysis is available
                let structural_similarity = 0.5;

                if semantic_similarity > 0.6 {
                    alignments.push(RelationAlignment {
                        source_relation: source_relation.clone(),
                        target_relation: target_relation.clone(),
                        confidence: (semantic_similarity + structural_similarity) / 2.0,
                        semantic_similarity,
                        structural_similarity,
                    });
                }
            }
        }

        Ok(alignments)
    }

    async fn analyze_semantic_shifts(
        &self,
        source: &DomainModel,
        target: &DomainSpecification,
    ) -> Result<Vec<SemanticShift>> {
        let mut shifts = Vec::new();

        let source_entities = source.model.get_entities();
        let target_entities = self.extract_entities_from_triples(&target.training_data);

        for source_entity in source_entities.iter().take(20) {
            for target_entity in target_entities.iter().take(20) {
                if self.calculate_string_similarity(source_entity, target_entity) > 0.8 {
                    let shift_magnitude = self.calculate_semantic_shift_magnitude(
                        source_entity,
                        target_entity,
                        source,
                        target,
                    )?;

                    if shift_magnitude > 0.3 {
                        shifts.push(SemanticShift {
                            concept: source_entity.clone(),
                            source_meaning: format!("Source domain context: {source_entity}"),
                            target_meaning: format!("Target domain context: {target_entity}"),
                            shift_magnitude,
                            impact: shift_magnitude * 0.5, // Rough impact estimate
                        });
                    }
                }
            }
        }

        Ok(shifts)
    }

    fn analyze_structural_changes(
        &self,
        source: &DomainCharacteristics,
        target: &DomainCharacteristics,
    ) -> Result<StructuralChanges> {
        // Guard against a zero source degree to avoid dividing by zero
        let degree_distribution_shift = if source.size_metrics.avg_entity_degree > 0.0 {
            (source.size_metrics.avg_entity_degree - target.size_metrics.avg_entity_degree).abs()
                / source.size_metrics.avg_entity_degree
        } else {
            0.0
        };

        // Placeholder estimates for the remaining structural components
        let clustering_changes = 0.1;
        let path_length_changes = 0.15;
        let community_structure_changes = 0.2;

        Ok(StructuralChanges {
            degree_distribution_shift,
            clustering_changes,
            path_length_changes,
            community_structure_changes,
        })
    }

    async fn evaluate_transfer_metric(
        &self,
        metric: &TransferMetric,
        source: &DomainModel,
        target: &DomainSpecification,
        entity_alignments: &[EntityAlignment],
        relation_alignments: &[RelationAlignment],
    ) -> Result<f64> {
        match metric {
            TransferMetric::TransferAccuracy => {
                self.calculate_transfer_accuracy(source, target).await
            }
            TransferMetric::AdaptationQuality => {
                if entity_alignments.is_empty() {
                    Ok(0.5) // Neutral score when no alignments are available
                } else {
                    Ok(entity_alignments.iter().map(|a| a.confidence).sum::<f64>()
                        / entity_alignments.len() as f64)
                }
            }
            TransferMetric::EntityAlignmentQuality => {
                if entity_alignments.is_empty() {
                    Ok(0.5) // Neutral score when no alignments are available
                } else {
                    Ok(entity_alignments
                        .iter()
                        .filter(|a| a.confidence > 0.7)
                        .count() as f64
                        / entity_alignments.len() as f64)
                }
            }
            TransferMetric::SemanticPreservation => {
                self.calculate_semantic_preservation(source, target, entity_alignments)
                    .await
            }
            TransferMetric::StructuralPreservation => {
                self.calculate_structural_preservation(source, target, relation_alignments)
                    .await
            }
            TransferMetric::TransferEfficiency => {
                self.calculate_transfer_efficiency(source, target).await
            }
            TransferMetric::CatastrophicForgetting => {
                self.calculate_catastrophic_forgetting(source, target).await
            }
            TransferMetric::CrossDomainCoherence => {
                self.calculate_cross_domain_coherence(source, target, entity_alignments)
                    .await
            }
            TransferMetric::KnowledgeRetention => {
                self.calculate_knowledge_retention(source, target).await
            }
            TransferMetric::AdaptationSpeed => {
                self.calculate_adaptation_speed(source, target).await
            }
            TransferMetric::TransferRobustness => {
                self.calculate_transfer_robustness(source, target).await
            }
            TransferMetric::SemanticDriftDetection => {
                self.calculate_semantic_drift_detection(source, target)
                    .await
            }
            TransferMetric::GeneralizationAbility => {
                self.calculate_generalization_ability(source, target).await
            }
        }
    }

    async fn calculate_transfer_accuracy(
        &self,
        source: &DomainModel,
        target: &DomainSpecification,
    ) -> Result<f64> {
        let mut correct_predictions = 0;
        let total_predictions = target
            .test_data
            .len()
            .min(self.config.evaluation_sample_size);

        if total_predictions == 0 {
            return Ok(0.5); // Neutral score when there is no test data
        }

        for (subject, predicate, object) in target.test_data.iter().take(total_predictions) {
            if let Ok(score) = source.model.score_triple(subject, predicate, object) {
                if score > 0.0 {
                    correct_predictions += 1;
                }
            }
        }

        Ok(correct_predictions as f64 / total_predictions as f64)
    }

    async fn calculate_semantic_preservation(
        &self,
        source: &DomainModel,
        _target: &DomainSpecification,
        entity_alignments: &[EntityAlignment],
    ) -> Result<f64> {
        if entity_alignments.is_empty() {
            return Ok(0.0);
        }

        let mut preservation_scores = Vec::new();

        for alignment in entity_alignments.iter().take(20) {
            if let Ok(source_embedding) =
                source.model.get_entity_embedding(&alignment.source_entity)
            {
                let target_embedding = self.create_simple_embedding(&alignment.target_entity);

                let preservation = self.cosine_similarity(&source_embedding, &target_embedding);
                preservation_scores.push(preservation);
            }
        }

        if preservation_scores.is_empty() {
            Ok(0.0)
        } else {
            Ok(preservation_scores.iter().sum::<f64>() / preservation_scores.len() as f64)
        }
    }

    async fn calculate_structural_preservation(
        &self,
        _source: &DomainModel,
        _target: &DomainSpecification,
        relation_alignments: &[RelationAlignment],
    ) -> Result<f64> {
        if relation_alignments.is_empty() {
            return Ok(0.5); // Neutral score when no relation alignments exist
        }

        let avg_structural_similarity = relation_alignments
            .iter()
            .map(|a| a.structural_similarity)
            .sum::<f64>()
            / relation_alignments.len() as f64;

        Ok(avg_structural_similarity)
    }

    fn generate_transfer_recommendations(
        &self,
        domain_similarity: f64,
        entity_alignments: &[EntityAlignment],
        semantic_shifts: &[SemanticShift],
    ) -> Vec<String> {
        let mut recommendations = Vec::new();

        if domain_similarity < 0.3 {
            recommendations.push(
                "Low domain similarity detected. Consider using domain adaptation techniques."
                    .to_string(),
            );
        }

        if entity_alignments.len() < 10 {
            recommendations.push(
                "Few entity alignments found. Consider improving entity linking methods."
                    .to_string(),
            );
        }

        let high_shift_count = semantic_shifts
            .iter()
            .filter(|s| s.shift_magnitude > 0.5)
            .count();
        if high_shift_count > 5 {
            recommendations.push(
                "Significant semantic shifts detected. Consider gradual domain adaptation."
                    .to_string(),
            );
        }

        if domain_similarity > 0.7 {
            recommendations
                .push("High domain similarity. Direct transfer should work well.".to_string());
        }

        recommendations
    }

    fn extract_entities_from_triples(
        &self,
        triples: &[(String, String, String)],
    ) -> HashSet<String> {
        let mut entities = HashSet::new();
        for (subject, _, object) in triples {
            entities.insert(subject.clone());
            entities.insert(object.clone());
        }
        entities
    }

    fn extract_relations_from_triples(
        &self,
        triples: &[(String, String, String)],
    ) -> HashSet<String> {
        triples
            .iter()
            .map(|(_, predicate, _)| predicate.clone())
            .collect()
    }

    fn calculate_string_similarity(&self, s1: &str, s2: &str) -> f64 {
        if s1 == s2 {
            return 1.0;
        }

        let n = 3;
        let ngrams1: HashSet<String> = s1
            .chars()
            .collect::<Vec<_>>()
            .windows(n)
            .map(|w| w.iter().collect())
            .collect();
        let ngrams2: HashSet<String> = s2
            .chars()
            .collect::<Vec<_>>()
            .windows(n)
            .map(|w| w.iter().collect())
            .collect();

        if ngrams1.is_empty() && ngrams2.is_empty() {
            return 1.0;
        }

        let intersection = ngrams1.intersection(&ngrams2).count();
        let union = ngrams1.union(&ngrams2).count();

        intersection as f64 / union as f64
    }

    fn create_simple_embedding(&self, entity: &str) -> Vector {
        // Deterministic byte-based placeholder embedding (100 dimensions)
        let mut embedding = vec![0.0f32; 100];
        for (i, byte) in entity.bytes().enumerate() {
            if i >= embedding.len() {
                break;
            }
            embedding[i] = (byte as f32) / 255.0;
        }
        Vector::new(embedding)
    }

    fn cosine_similarity(&self, v1: &Vector, v2: &Vector) -> f64 {
        let dot_product: f32 = v1
            .values
            .iter()
            .zip(v2.values.iter())
            .map(|(a, b)| a * b)
            .sum();
        let norm_a: f32 = v1.values.iter().map(|x| x * x).sum::<f32>().sqrt();
        let norm_b: f32 = v2.values.iter().map(|x| x * x).sum::<f32>().sqrt();

        if norm_a > 0.0 && norm_b > 0.0 {
            (dot_product / (norm_a * norm_b)) as f64
        } else {
            0.0
        }
    }

    fn calculate_semantic_shift_magnitude(
        &self,
        _source_entity: &str,
        _target_entity: &str,
        _source: &DomainModel,
        _target: &DomainSpecification,
    ) -> Result<f64> {
        // Placeholder: a random magnitude in [0, 0.8) until a real estimator is implemented
        Ok({
            let mut random = Random::default();
            random.random::<f64>() * 0.8
        })
    }

    /// Returns the identifiers of all registered source domains.
    pub fn get_source_domains(&self) -> Vec<String> {
        self.source_domains.keys().cloned().collect()
    }

    /// Returns the identifiers of all registered target domains.
    pub fn get_target_domains(&self) -> Vec<String> {
        self.target_domains.keys().cloned().collect()
    }

    /// Looks up the characteristics of a registered source or target domain.
    pub fn get_domain_characteristics(&self, domain_id: &str) -> Option<&DomainCharacteristics> {
        self.source_domains
            .get(domain_id)
            .map(|d| &d.characteristics)
            .or_else(|| {
                self.target_domains
                    .get(domain_id)
                    .map(|d| &d.characteristics)
            })
    }

    async fn calculate_transfer_efficiency(
        &self,
        source: &DomainModel,
        target: &DomainSpecification,
    ) -> Result<f64> {
        let start_time = std::time::Instant::now();

        let domain_similarity =
            self.calculate_domain_similarity(&source.characteristics, &target.characteristics)?;

        let transfer_accuracy = self.calculate_transfer_accuracy(source, target).await?;
        let transfer_time = start_time.elapsed().as_secs_f64();

        // Normalize time to minutes and keep it within a sensible range
        let normalized_time = (transfer_time / 60.0).clamp(0.01, 1.0);
        let efficiency = (transfer_accuracy * domain_similarity) / normalized_time;

        Ok(efficiency.clamp(0.0, 1.0))
    }

    async fn calculate_catastrophic_forgetting(
        &self,
        source: &DomainModel,
        target: &DomainSpecification,
    ) -> Result<f64> {
        let source_entities = source.model.get_entities();
        let sample_size = source_entities.len().min(20);

        if sample_size == 0 {
            return Ok(0.0);
        }

        let mut forgetting_scores = Vec::new();

        for entity in source_entities.iter().take(sample_size) {
            if let Ok(_source_embedding) = source.model.get_entity_embedding(entity) {
                let target_entities = self.extract_entities_from_triples(&target.training_data);
                let domain_overlap = target_entities.contains(entity);

                // Simulated degradation: entities shared with the target domain degrade less
                let degradation = if domain_overlap {
                    let mut random = Random::default();
                    0.1 + random.random::<f64>() * 0.2
                } else {
                    let mut random = Random::default();
                    0.3 + random.random::<f64>() * 0.4
                };

                forgetting_scores.push(degradation);
            }
        }

        if forgetting_scores.is_empty() {
            Ok(0.1) // Assume low forgetting when no embeddings were available
        } else {
            let avg_forgetting =
                forgetting_scores.iter().sum::<f64>() / forgetting_scores.len() as f64;
            Ok(avg_forgetting.clamp(0.0, 1.0))
        }
    }

    async fn calculate_cross_domain_coherence(
        &self,
        source: &DomainModel,
        target: &DomainSpecification,
        entity_alignments: &[EntityAlignment],
    ) -> Result<f64> {
        if entity_alignments.is_empty() {
            return Ok(0.5);
        }

        let mut coherence_scores = Vec::new();

        for alignment in entity_alignments.iter().take(15) {
            if alignment.confidence > 0.6 {
                if let Ok(source_embedding) =
                    source.model.get_entity_embedding(&alignment.source_entity)
                {
                    let target_embedding = self.create_simple_embedding(&alignment.target_entity);

                    let embedding_coherence =
                        self.cosine_similarity(&source_embedding, &target_embedding);

                    let source_neighbors =
                        self.get_source_neighbors(&alignment.source_entity, source);
                    let target_neighbors =
                        self.get_target_neighbors(&alignment.target_entity, target);
                    let neighborhood_coherence = self
                        .calculate_neighborhood_similarity(&source_neighbors, &target_neighbors);

                    let combined_coherence = (embedding_coherence + neighborhood_coherence) / 2.0;
                    coherence_scores.push(combined_coherence);
                }
            }
        }

        if coherence_scores.is_empty() {
            Ok(0.5)
        } else {
            let avg_coherence =
                coherence_scores.iter().sum::<f64>() / coherence_scores.len() as f64;
            Ok(avg_coherence.clamp(0.0, 1.0))
        }
    }

    async fn calculate_knowledge_retention(
        &self,
        _source: &DomainModel,
        _target: &DomainSpecification,
    ) -> Result<f64> {
        Ok(0.85) // Placeholder estimate
    }

    async fn calculate_adaptation_speed(
        &self,
        _source: &DomainModel,
        _target: &DomainSpecification,
    ) -> Result<f64> {
        Ok(0.75) // Placeholder estimate
    }

    async fn calculate_transfer_robustness(
        &self,
        _source: &DomainModel,
        _target: &DomainSpecification,
    ) -> Result<f64> {
        Ok(0.8) // Placeholder estimate
    }

    async fn calculate_semantic_drift_detection(
        &self,
        _source: &DomainModel,
        _target: &DomainSpecification,
    ) -> Result<f64> {
        Ok(0.7) // Placeholder estimate
    }

    async fn calculate_generalization_ability(
        &self,
        _source: &DomainModel,
        _target: &DomainSpecification,
    ) -> Result<f64> {
        Ok(0.8) // Placeholder estimate
    }

    fn get_source_neighbors(&self, _entity: &str, source: &DomainModel) -> Vec<String> {
        let relations = source.model.get_relations();
        relations.into_iter().take(5).collect()
    }

    fn get_target_neighbors(&self, entity: &str, target: &DomainSpecification) -> Vec<String> {
        let mut neighbors = Vec::new();
        for (subject, predicate, object) in &target.training_data {
            if subject == entity {
                neighbors.push(object.clone());
                neighbors.push(predicate.clone());
            } else if object == entity {
                neighbors.push(subject.clone());
                neighbors.push(predicate.clone());
            }
        }
        neighbors.into_iter().take(5).collect()
    }

    fn calculate_neighborhood_similarity(
        &self,
        source_neighbors: &[String],
        target_neighbors: &[String],
    ) -> f64 {
        if source_neighbors.is_empty() && target_neighbors.is_empty() {
            return 1.0;
        }

        if source_neighbors.is_empty() || target_neighbors.is_empty() {
            return 0.0;
        }

        let source_set: HashSet<&String> = source_neighbors.iter().collect();
        let target_set: HashSet<&String> = target_neighbors.iter().collect();

        let intersection = source_set.intersection(&target_set).count();
        let union = source_set.union(&target_set).count();

        if union == 0 {
            0.0
        } else {
            intersection as f64 / union as f64
        }
    }
}

/// Helper utilities for constructing domain descriptions from raw triples.
pub struct TransferUtils;

impl TransferUtils {
    /// Derives domain characteristics (sizes, degrees, density) from a triple list.
    pub fn analyze_domain_from_triples(
        _domain_id: String,
        triples: &[(String, String, String)],
    ) -> DomainCharacteristics {
        let mut entities = HashSet::new();
        let mut relations = HashSet::new();

        for (subject, predicate, object) in triples {
            entities.insert(subject.clone());
            entities.insert(object.clone());
            relations.insert(predicate.clone());
        }

        let num_entities = entities.len();
        let num_relations = relations.len();
        let num_triples = triples.len();

        let mut entity_degrees = HashMap::new();
        for (subject, _, object) in triples {
            *entity_degrees.entry(subject.clone()).or_insert(0) += 1;
            *entity_degrees.entry(object.clone()).or_insert(0) += 1;
        }
        let avg_entity_degree = if num_entities > 0 {
            entity_degrees.values().sum::<usize>() as f64 / num_entities as f64
        } else {
            0.0
        };

        // saturating_sub avoids an underflow panic when there are no entities
        let max_possible_edges = num_entities * num_entities.saturating_sub(1);
        let graph_density = if max_possible_edges > 0 {
            num_triples as f64 / max_possible_edges as f64
        } else {
            0.0
        };

        DomainCharacteristics {
            domain_type: "unknown".to_string(),
            language: "unknown".to_string(),
            entity_types: vec!["Entity".to_string()], // Single generic type without a schema
            relation_types: relations.into_iter().collect(),
            size_metrics: DomainSizeMetrics {
                num_entities,
                num_relations,
                num_triples,
                avg_entity_degree,
                graph_density,
            },
            complexity_metrics: DomainComplexityMetrics {
                entity_type_diversity: 1,
                relation_type_diversity: num_relations,
                hierarchical_depth: 3, // Default assumption
                semantic_diversity: 0.5, // Default assumption
                structural_complexity: avg_entity_degree / 10.0, // Normalized degree heuristic
            },
        }
    }

    /// Builds a `DomainSpecification` by splitting the given triples into
    /// roughly 70% training, 15% validation, and 15% test data.
    pub fn create_test_domain_specification(
        domain_id: String,
        training_data: Vec<(String, String, String)>,
    ) -> DomainSpecification {
        let total = training_data.len();
        let train_size = (total as f64 * 0.7) as usize;
        let val_size = (total as f64 * 0.15) as usize;

        let training = training_data[..train_size].to_vec();
        let validation = training_data[train_size..train_size + val_size].to_vec();
        let test = training_data[train_size + val_size..].to_vec();

        let characteristics = Self::analyze_domain_from_triples(domain_id.clone(), &training);

        DomainSpecification {
            domain_id,
            characteristics,
            training_data: training,
            validation_data: validation,
            test_data: test,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::transe::TransE;

    #[test]
    fn test_transfer_config_default() {
        let config = TransferConfig::default();
        assert!(config.enable_domain_adaptation);
        assert!(config.use_adversarial_alignment);
        assert_eq!(config.max_alignment_iterations, 100);
    }
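
    // Added check (not in the original test suite): create_test_domain_specification
    // splits triples into roughly 70% training, 15% validation, and 15% test, so
    // 10 triples should yield a 7/1/2 split given the integer truncation used above.
    #[test]
    fn test_domain_specification_split_sizes() {
        let triples: Vec<(String, String, String)> = (0..10)
            .map(|i| (format!("e{i}"), "rel".to_string(), format!("e{}", i + 1)))
            .collect();

        let spec =
            TransferUtils::create_test_domain_specification("split_test".to_string(), triples);

        assert_eq!(spec.domain_id, "split_test");
        assert_eq!(spec.training_data.len(), 7);
        assert_eq!(spec.validation_data.len(), 1);
        assert_eq!(spec.test_data.len(), 2);
    }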

    #[test]
    fn test_domain_characteristics_creation() {
        let triples = vec![
            ("alice".to_string(), "knows".to_string(), "bob".to_string()),
            ("bob".to_string(), "likes".to_string(), "pizza".to_string()),
            (
                "alice".to_string(),
                "likes".to_string(),
                "coffee".to_string(),
            ),
        ];

        let characteristics =
            TransferUtils::analyze_domain_from_triples("test_domain".to_string(), &triples);

        assert_eq!(characteristics.size_metrics.num_triples, 3);
        assert_eq!(characteristics.size_metrics.num_entities, 4); // alice, bob, pizza, coffee
        assert_eq!(characteristics.size_metrics.num_relations, 2); // knows, likes
    }

    #[test]
    fn test_string_similarity() {
        let manager = CrossDomainTransferManager::new(TransferConfig::default());

        let sim1 = manager.calculate_string_similarity("hello", "hello");
        assert_eq!(sim1, 1.0);

        let sim2 = manager.calculate_string_similarity("hello", "world");
        assert!(sim2 < 0.5);

        let sim3 = manager.calculate_string_similarity("testing", "test");
        assert!(sim3 > 0.3);
    }
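
    // Added check (not in the original test suite): cosine_similarity should be
    // ~1.0 for identical vectors, ~0.0 for orthogonal ones, and fall back to 0.0
    // when either vector has zero norm.
    #[test]
    fn test_cosine_similarity_basic_properties() {
        let manager = CrossDomainTransferManager::new(TransferConfig::default());

        let a = Vector::new(vec![1.0, 2.0, 2.0]);
        let b = Vector::new(vec![1.0, 2.0, 2.0]);
        assert!((manager.cosine_similarity(&a, &b) - 1.0).abs() < 1e-6);

        let x = Vector::new(vec![1.0, 0.0]);
        let y = Vector::new(vec![0.0, 1.0]);
        assert!(manager.cosine_similarity(&x, &y).abs() < 1e-6);

        let zero = Vector::new(vec![0.0, 0.0]);
        assert_eq!(manager.cosine_similarity(&zero, &x), 0.0);
    }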

    #[tokio::test]
    async fn test_transfer_evaluation() {
        let mut manager = CrossDomainTransferManager::new(TransferConfig::default());

        let source_model = Box::new(TransE::new(Default::default()));
        let source_characteristics = DomainCharacteristics {
            domain_type: "test".to_string(),
            language: "en".to_string(),
            entity_types: vec!["Person".to_string()],
            relation_types: vec!["knows".to_string()],
            size_metrics: DomainSizeMetrics {
                num_entities: 100,
                num_relations: 10,
                num_triples: 500,
                avg_entity_degree: 5.0,
                graph_density: 0.01,
            },
            complexity_metrics: DomainComplexityMetrics {
                entity_type_diversity: 2,
                relation_type_diversity: 10,
                hierarchical_depth: 3,
                semantic_diversity: 0.6,
                structural_complexity: 0.5,
            },
        };

        manager
            .register_source_domain("source".to_string(), source_model, source_characteristics)
            .unwrap();

        let target_spec = TransferUtils::create_test_domain_specification(
            "target".to_string(),
            vec![
                ("alice".to_string(), "knows".to_string(), "bob".to_string()),
                (
                    "bob".to_string(),
                    "knows".to_string(),
                    "charlie".to_string(),
                ),
            ],
        );

        manager.register_target_domain(target_spec).unwrap();

        let results = manager
            .evaluate_transfer("source", "target", TransferStrategy::DirectTransfer)
            .await;

        assert!(results.is_ok());
        let results = results.unwrap();
        assert_eq!(results.source_domain, "source");
        assert_eq!(results.target_domain, "target");
        assert!(results.overall_quality >= 0.0);
        assert!(results.overall_quality <= 1.0);
    }
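
    // Added check (not in the original test suite): registering a target domain
    // should make it visible through get_target_domains and
    // get_domain_characteristics.
    #[test]
    fn test_target_domain_registration_lookup() {
        let mut manager = CrossDomainTransferManager::new(TransferConfig::default());

        let spec = TransferUtils::create_test_domain_specification(
            "lookup_target".to_string(),
            vec![
                ("a".to_string(), "r".to_string(), "b".to_string()),
                ("b".to_string(), "r".to_string(), "c".to_string()),
                ("c".to_string(), "r".to_string(), "a".to_string()),
            ],
        );
        manager.register_target_domain(spec).unwrap();

        assert_eq!(
            manager.get_target_domains(),
            vec!["lookup_target".to_string()]
        );
        assert!(manager.get_domain_characteristics("lookup_target").is_some());
        assert!(manager.get_domain_characteristics("missing").is_none());
    }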

    #[test]
    fn test_domain_similarity_calculation() {
        let manager = CrossDomainTransferManager::new(TransferConfig::default());

        let source = DomainCharacteristics {
            domain_type: "biomedical".to_string(),
            language: "en".to_string(),
            entity_types: vec!["Gene".to_string(), "Disease".to_string()],
            relation_types: vec!["causes".to_string(), "treats".to_string()],
            size_metrics: DomainSizeMetrics {
                num_entities: 1000,
                num_relations: 50,
                num_triples: 5000,
                avg_entity_degree: 5.0,
                graph_density: 0.005,
            },
            complexity_metrics: DomainComplexityMetrics {
                entity_type_diversity: 2,
                relation_type_diversity: 50,
                hierarchical_depth: 4,
                semantic_diversity: 0.7,
                structural_complexity: 0.6,
            },
        };

        let target = DomainCharacteristics {
            domain_type: "medical".to_string(),
            language: "en".to_string(),
            entity_types: vec!["Gene".to_string(), "Drug".to_string()],
            relation_types: vec!["treats".to_string(), "interacts".to_string()],
            size_metrics: DomainSizeMetrics {
                num_entities: 800,
                num_relations: 40,
                num_triples: 4000,
                avg_entity_degree: 5.0,
                graph_density: 0.006,
            },
            complexity_metrics: DomainComplexityMetrics {
                entity_type_diversity: 2,
                relation_type_diversity: 40,
                hierarchical_depth: 3,
                semantic_diversity: 0.6,
                structural_complexity: 0.5,
            },
        };

        let similarity = manager
            .calculate_domain_similarity(&source, &target)
            .unwrap();
        assert!(similarity > 0.0);
        assert!(similarity <= 1.0);

        // Related domains should score reasonably high
        assert!(similarity > 0.2);
    }
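
    // Added check (not in the original test suite): neighborhood similarity is a
    // Jaccard index, so identical neighbor sets give 1.0, a single shared neighbor
    // out of three distinct ones gives 1/3, and empty inputs hit the early returns.
    #[test]
    fn test_neighborhood_similarity() {
        let manager = CrossDomainTransferManager::new(TransferConfig::default());

        let a = vec!["x".to_string(), "y".to_string()];
        let b = vec!["y".to_string(), "z".to_string()];
        let sim = manager.calculate_neighborhood_similarity(&a, &b);
        assert!((sim - 1.0 / 3.0).abs() < 1e-9);

        assert_eq!(manager.calculate_neighborhood_similarity(&a, &a), 1.0);

        let empty: Vec<String> = Vec::new();
        assert_eq!(manager.calculate_neighborhood_similarity(&a, &empty), 0.0);
        assert_eq!(manager.calculate_neighborhood_similarity(&empty, &empty), 1.0);
    }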
}