oxirs_embed/multimodal/impl/config.rs

//! Configuration types for multi-modal embeddings

use crate::ModelConfig;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

/// Cross-modal alignment configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CrossModalConfig {
    /// Base model configuration
    pub base_config: ModelConfig,
    /// Text embedding dimension
    pub text_dim: usize,
    /// Knowledge graph embedding dimension
    pub kg_dim: usize,
    /// Unified embedding dimension
    pub unified_dim: usize,
    /// Alignment objective type
    pub alignment_objective: AlignmentObjective,
    /// Contrastive learning parameters
    pub contrastive_config: ContrastiveConfig,
    /// Multi-task learning weights
    pub task_weights: HashMap<String, f32>,
    /// Cross-domain transfer settings
    pub cross_domain_config: CrossDomainConfig,
}

impl Default for CrossModalConfig {
    fn default() -> Self {
        let mut task_weights = HashMap::new();
        task_weights.insert("text_kg_alignment".to_string(), 1.0);
        task_weights.insert("entity_description".to_string(), 0.8);
        task_weights.insert("property_text".to_string(), 0.6);
        task_weights.insert("multilingual".to_string(), 0.4);

        Self {
            base_config: ModelConfig::default(),
            text_dim: 768,
            kg_dim: 128,
            unified_dim: 512,
            alignment_objective: AlignmentObjective::ContrastiveLearning,
            contrastive_config: ContrastiveConfig::default(),
            task_weights,
            cross_domain_config: CrossDomainConfig::default(),
        }
    }
}
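
// Illustrative sketch only (an addition, not part of the upstream API): shows
// how the per-task weights in `CrossModalConfig::task_weights` would typically
// combine individual task losses into a single multi-task objective. The
// function name and signature are assumptions made for illustration.
#[allow(dead_code)]
fn weighted_multi_task_loss(
    config: &CrossModalConfig,
    task_losses: &HashMap<String, f32>,
) -> f32 {
    // Scale each task loss by its configured weight; tasks without a
    // configured weight contribute nothing (weight 0.0).
    task_losses
        .iter()
        .map(|(task, loss)| config.task_weights.get(task).copied().unwrap_or(0.0) * loss)
        .sum()
}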

/// Alignment objective types
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AlignmentObjective {
    /// Contrastive learning for positive/negative pairs
    ContrastiveLearning,
    /// Mutual information maximization
    MutualInformation,
    /// Adversarial alignment with discriminator
    AdversarialAlignment,
    /// Multi-task learning with shared representations
    MultiTaskLearning,
    /// Self-supervised objectives
    SelfSupervised,
    /// Meta-learning for few-shot adaptation
    MetaLearning,
}

/// Contrastive learning configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContrastiveConfig {
    /// Temperature parameter for contrastive loss
    pub temperature: f32,
    /// Number of negative samples
    pub negative_samples: usize,
    /// Hard negative mining
    pub hard_negative_mining: bool,
    /// Margin for triplet loss
    pub margin: f32,
    /// Use InfoNCE loss
    pub use_info_nce: bool,
}

impl Default for ContrastiveConfig {
    fn default() -> Self {
        Self {
            temperature: 0.07,
            negative_samples: 64,
            hard_negative_mining: true,
            margin: 0.2,
            use_info_nce: true,
        }
    }
}
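
// Illustrative sketch only (an addition, not part of the upstream API): an
// InfoNCE-style contrastive loss for one positive pair against a set of
// negatives, showing how `ContrastiveConfig::temperature` scales similarity
// scores before the softmax. The function name and signature are assumptions
// made for illustration.
#[allow(dead_code)]
fn info_nce_loss(config: &ContrastiveConfig, pos_sim: f32, neg_sims: &[f32]) -> f32 {
    // Temperature-scaled exponentials of the similarity scores.
    let pos = (pos_sim / config.temperature).exp();
    let neg_sum: f32 = neg_sims
        .iter()
        .map(|s| (s / config.temperature).exp())
        .sum();
    // InfoNCE: -log( exp(s_pos / t) / (exp(s_pos / t) + sum_i exp(s_neg_i / t)) )
    -(pos / (pos + neg_sum)).ln()
}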

/// Cross-domain transfer configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CrossDomainConfig {
    /// Enable domain adaptation
    pub enable_domain_adaptation: bool,
    /// Source domains for transfer learning
    pub source_domains: Vec<String>,
    /// Target domains
    pub target_domains: Vec<String>,
    /// Domain adversarial training
    pub domain_adversarial: bool,
    /// Gradual domain adaptation
    pub gradual_adaptation: bool,
}

impl Default for CrossDomainConfig {
    fn default() -> Self {
        Self {
            enable_domain_adaptation: true,
            source_domains: vec!["general".to_string(), "scientific".to_string()],
            target_domains: vec!["biomedical".to_string(), "legal".to_string()],
            domain_adversarial: false,
            gradual_adaptation: true,
        }
    }
}
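
// Minimal test sketch (an addition, not part of the original file): exercises
// the default configurations above using only the types defined in this module.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn cross_modal_defaults_are_consistent() {
        let config = CrossModalConfig::default();
        assert_eq!(config.text_dim, 768);
        assert_eq!(config.kg_dim, 128);
        assert_eq!(config.unified_dim, 512);
        // All four default tasks carry a positive weight.
        assert_eq!(config.task_weights.len(), 4);
        assert!(config.task_weights.values().all(|w| *w > 0.0));
    }

    #[test]
    fn contrastive_defaults_use_info_nce() {
        let config = ContrastiveConfig::default();
        assert!(config.use_info_nce);
        assert!(config.temperature > 0.0 && config.temperature < 1.0);
        assert_eq!(config.negative_samples, 64);
    }
}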