oxirs_embed/multimodal/impl/
config.rs1use crate::ModelConfig;
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6
7#[derive(Debug, Clone, Serialize, Deserialize)]
9pub struct CrossModalConfig {
10 pub base_config: ModelConfig,
11 pub text_dim: usize,
13 pub kg_dim: usize,
15 pub unified_dim: usize,
17 pub alignment_objective: AlignmentObjective,
19 pub contrastive_config: ContrastiveConfig,
21 pub task_weights: HashMap<String, f32>,
23 pub cross_domain_config: CrossDomainConfig,
25}
26
27impl Default for CrossModalConfig {
28 fn default() -> Self {
29 let mut task_weights = HashMap::new();
30 task_weights.insert("text_kg_alignment".to_string(), 1.0);
31 task_weights.insert("entity_description".to_string(), 0.8);
32 task_weights.insert("property_text".to_string(), 0.6);
33 task_weights.insert("multilingual".to_string(), 0.4);
34
35 Self {
36 base_config: ModelConfig::default(),
37 text_dim: 768,
38 kg_dim: 128,
39 unified_dim: 512,
40 alignment_objective: AlignmentObjective::ContrastiveLearning,
41 contrastive_config: ContrastiveConfig::default(),
42 task_weights,
43 cross_domain_config: CrossDomainConfig::default(),
44 }
45 }
46}
47
48#[derive(Debug, Clone, Serialize, Deserialize)]
50pub enum AlignmentObjective {
51 ContrastiveLearning,
53 MutualInformation,
55 AdversarialAlignment,
57 MultiTaskLearning,
59 SelfSupervised,
61 MetaLearning,
63}
64
65#[derive(Debug, Clone, Serialize, Deserialize)]
67pub struct ContrastiveConfig {
68 pub temperature: f32,
70 pub negative_samples: usize,
72 pub hard_negative_mining: bool,
74 pub margin: f32,
76 pub use_info_nce: bool,
78}
79
80impl Default for ContrastiveConfig {
81 fn default() -> Self {
82 Self {
83 temperature: 0.07,
84 negative_samples: 64,
85 hard_negative_mining: true,
86 margin: 0.2,
87 use_info_nce: true,
88 }
89 }
90}
91
92#[derive(Debug, Clone, Serialize, Deserialize)]
94pub struct CrossDomainConfig {
95 pub enable_domain_adaptation: bool,
97 pub source_domains: Vec<String>,
99 pub target_domains: Vec<String>,
101 pub domain_adversarial: bool,
103 pub gradual_adaptation: bool,
105}
106
107impl Default for CrossDomainConfig {
108 fn default() -> Self {
109 Self {
110 enable_domain_adaptation: true,
111 source_domains: vec!["general".to_string(), "scientific".to_string()],
112 target_domains: vec!["biomedical".to_string(), "legal".to_string()],
113 domain_adversarial: false,
114 gradual_adaptation: true,
115 }
116 }
117}