use crate::ModelConfig;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Top-level configuration for aligning text embeddings with knowledge-graph
/// (KG) embeddings in a shared space.
///
/// Bundles the base model configuration, the per-modality dimensions, the
/// selected [`AlignmentObjective`], and the nested contrastive and
/// cross-domain settings. Field order is significant for serialized output.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CrossModalConfig {
/// Shared model settings; `Default` uses `ModelConfig::default()`.
pub base_config: ModelConfig,
/// Dimension of the text-side embeddings (name-based reading; confirm
/// against the text encoder that consumes it). `Default`: 768.
pub text_dim: usize,
/// Dimension of the KG-side embeddings (name-based reading; confirm
/// against the KG encoder). `Default`: 128.
pub kg_dim: usize,
/// Dimension of the shared/unified space both modalities are presumably
/// projected into — TODO confirm. `Default`: 512.
pub unified_dim: usize,
/// Which alignment strategy to use. `Default`: `ContrastiveLearning`.
pub alignment_objective: AlignmentObjective,
/// Hyper-parameters used by the contrastive objective.
pub contrastive_config: ContrastiveConfig,
/// Per-task weights keyed by task name (presumably multipliers on the
/// per-task losses — confirm in the training loop). `Default` seeds
/// `text_kg_alignment`, `entity_description`, `property_text` and
/// `multilingual`.
pub task_weights: HashMap<String, f32>,
/// Domain-adaptation settings (source/target domains and switches).
pub cross_domain_config: CrossDomainConfig,
}
impl Default for CrossModalConfig {
fn default() -> Self {
let mut task_weights = HashMap::new();
task_weights.insert("text_kg_alignment".to_string(), 1.0);
task_weights.insert("entity_description".to_string(), 0.8);
task_weights.insert("property_text".to_string(), 0.6);
task_weights.insert("multilingual".to_string(), 0.4);
Self {
base_config: ModelConfig::default(),
text_dim: 768,
kg_dim: 128,
unified_dim: 512,
alignment_objective: AlignmentObjective::ContrastiveLearning,
contrastive_config: ContrastiveConfig::default(),
task_weights,
cross_domain_config: CrossDomainConfig::default(),
}
}
}
/// Strategy used to align the text and knowledge-graph embedding spaces.
///
/// Selected through the `alignment_objective` field of `CrossModalConfig`.
/// The variants carry no data. `PartialEq`, `Eq` and `Hash` are derived so
/// callers can compare objectives with `==` and use them as map/set keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum AlignmentObjective {
    /// Contrastive alignment; tuned by `ContrastiveConfig`.
    ContrastiveLearning,
    /// Mutual-information-based alignment (semantics defined by the trainer).
    MutualInformation,
    /// Adversarial alignment (semantics defined by the trainer).
    AdversarialAlignment,
    /// Multi-task alignment, presumably weighted by `task_weights` — confirm.
    MultiTaskLearning,
    /// Self-supervised alignment (semantics defined by the trainer).
    SelfSupervised,
    /// Meta-learning-based alignment (semantics defined by the trainer).
    MetaLearning,
}
/// Hyper-parameters for contrastive alignment (used with
/// `AlignmentObjective::ContrastiveLearning`).
///
/// Derives `PartialEq` so configurations can be compared directly, for
/// example in tests or to detect whether a reloaded config changed. `Eq` is
/// not derived because `f32` has no total equality.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ContrastiveConfig {
    /// Temperature scaling for similarity scores (presumably divides the
    /// logits before the softmax — confirm in the loss). `Default`: 0.07.
    pub temperature: f32,
    /// Number of negative samples drawn (assumed per anchor — confirm).
    /// `Default`: 64.
    pub negative_samples: usize,
    /// Whether hard-negative mining is enabled. `Default`: `true`.
    pub hard_negative_mining: bool,
    /// Margin value, presumably for a margin-based ranking loss — confirm.
    /// `Default`: 0.2.
    pub margin: f32,
    /// Whether to use the InfoNCE formulation. `Default`: `true`.
    pub use_info_nce: bool,
}
impl Default for ContrastiveConfig {
fn default() -> Self {
Self {
temperature: 0.07,
negative_samples: 64,
hard_negative_mining: true,
margin: 0.2,
use_info_nce: true,
}
}
}
/// Settings for adapting the model across data domains.
///
/// Derives `PartialEq` and `Eq` (all fields are `bool` or `Vec<String>`) so
/// configurations can be compared directly.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct CrossDomainConfig {
    /// Master switch for domain adaptation. `Default`: `true`.
    pub enable_domain_adaptation: bool,
    /// Names of the domains adapted from. `Default`: `general`, `scientific`.
    pub source_domains: Vec<String>,
    /// Names of the domains adapted to. `Default`: `biomedical`, `legal`.
    pub target_domains: Vec<String>,
    /// Whether adversarial domain training is used. `Default`: `false`.
    pub domain_adversarial: bool,
    /// Whether adaptation is applied gradually (schedule defined by the
    /// trainer — confirm). `Default`: `true`.
    pub gradual_adaptation: bool,
}
impl Default for CrossDomainConfig {
fn default() -> Self {
Self {
enable_domain_adaptation: true,
source_domains: vec!["general".to_string(), "scientific".to_string()],
target_domains: vec!["biomedical".to_string(), "legal".to_string()],
domain_adversarial: false,
gradual_adaptation: true,
}
}
}