use crate::ModelConfig;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Top-level configuration for the vision–language–graph multi-modal model.
///
/// Bundles one config per modality encoder plus fusion, meta-learning,
/// transfer-learning, and joint-training settings. `Default` is derived,
/// which requires `ModelConfig: Default` (declared elsewhere in the crate).
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct VisionLanguageGraphConfig {
/// Shared base model settings from the crate root.
pub base_config: ModelConfig,
/// Image encoder settings.
pub vision_config: VisionEncoderConfig,
/// Text encoder settings.
pub language_config: LanguageEncoderConfig,
/// Graph encoder settings.
pub graph_config: GraphEncoderConfig,
/// Cross-modal fusion transformer settings.
pub transformer_config: MultiModalTransformerConfig,
/// Few-shot / meta-learning settings.
pub meta_learning_config: MetaLearningConfig,
/// Transfer-learning and domain-adaptation settings.
pub transfer_config: TransferLearningConfig,
/// Multi-objective joint training settings.
pub joint_training_config: JointTrainingConfig,
}
/// Settings for the image-encoding branch.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VisionEncoderConfig {
    /// Which vision backbone family to use.
    pub architecture: VisionArchitecture,
    /// Spatial input size in pixels (square 224x224 by default).
    pub image_size: (usize, usize),
    /// Number of input image channels (3 = RGB by default).
    pub channels: usize,
    /// Patch size for transformer-style backbones.
    pub patch_size: (usize, usize),
    /// Dimensionality of the produced vision embedding.
    pub vision_dim: usize,
    /// Options used when a CNN backbone is selected.
    pub cnn_config: CNNConfig,
    /// Options used when a ViT-family backbone is selected.
    pub vit_config: ViTConfig,
}

impl Default for VisionEncoderConfig {
    /// ViT-Base style defaults: 224x224 3-channel input, 16x16 patches,
    /// 768-dimensional output embedding.
    fn default() -> Self {
        VisionEncoderConfig {
            vision_dim: 768,
            image_size: (224, 224),
            patch_size: (16, 16),
            channels: 3,
            architecture: VisionArchitecture::VisionTransformer,
            cnn_config: CNNConfig::default(),
            vit_config: ViTConfig::default(),
        }
    }
}
/// Supported vision backbone families.
///
/// The enum is fieldless, so it additionally derives `Copy`, `PartialEq`,
/// `Eq`, and `Hash`, allowing cheap copies, direct comparison of configs,
/// and use as a map key. This is purely additive and backward-compatible.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum VisionArchitecture {
    ResNet,
    EfficientNet,
    DenseNet,
    VisionTransformer,
    /// Data-efficient image transformer.
    DeiT,
    /// Shifted-window transformer.
    Swin,
    /// Convolution-augmented ViT variant.
    ConViT,
    /// Convolutional vision transformer variant.
    CvT,
    /// The vision tower of a CLIP-style dual encoder.
    CLIPVision,
}
/// Configuration for a convolutional vision backbone.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CNNConfig {
    /// Number of convolutional stages.
    pub num_layers: usize,
    /// Output channels per stage (presumably one entry per stage — not enforced here).
    pub filter_sizes: Vec<usize>,
    /// Stride per stage (presumably one entry per stage — not enforced here).
    pub stride_sizes: Vec<usize>,
    /// Spatial pooling applied by the backbone.
    pub pooling: PoolingType,
    /// Normalization layer used between stages.
    pub normalization: NormalizationType,
}

impl Default for CNNConfig {
    /// A small 4-stage backbone: channels 64→512, stride 2 per stage,
    /// batch norm, and adaptive average pooling.
    fn default() -> Self {
        CNNConfig {
            normalization: NormalizationType::BatchNorm,
            pooling: PoolingType::AdaptiveAvgPool,
            num_layers: 4,
            filter_sizes: vec![64, 128, 256, 512],
            stride_sizes: vec![2, 2, 2, 2],
        }
    }
}
/// Configuration for a Vision Transformer backbone.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ViTConfig {
    /// Number of transformer encoder layers.
    pub num_layers: usize,
    /// Attention heads per layer.
    pub num_heads: usize,
    /// Hidden width of the per-layer MLP block.
    pub mlp_dim: usize,
    /// Dropout probability.
    pub dropout_rate: f32,
    /// Positional-encoding scheme.
    pub position_encoding: PositionEncodingType,
}

impl Default for ViTConfig {
    /// ViT-Base defaults: 12 layers x 12 heads, 3072-wide MLP,
    /// 0.1 dropout, learnable position embeddings.
    fn default() -> Self {
        ViTConfig {
            position_encoding: PositionEncodingType::Learnable,
            dropout_rate: 0.1,
            mlp_dim: 3072,
            num_heads: 12,
            num_layers: 12,
        }
    }
}
/// Spatial pooling operator for CNN backbones.
///
/// Fieldless, so `Copy`, `PartialEq`, `Eq`, and `Hash` are derived as well,
/// enabling cheap copies and direct config comparison (purely additive).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum PoolingType {
    MaxPool,
    AvgPool,
    AdaptiveAvgPool,
    AdaptiveMaxPool,
    GlobalAvgPool,
}
/// Normalization layer variant.
///
/// Fieldless, so `Copy`, `PartialEq`, `Eq`, and `Hash` are derived as well
/// (purely additive, backward-compatible).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum NormalizationType {
    BatchNorm,
    LayerNorm,
    GroupNorm,
    InstanceNorm,
}
/// How positional information is injected into transformer inputs.
///
/// Fieldless, so `Copy`, `PartialEq`, `Eq`, and `Hash` are derived as well
/// (purely additive). Also fixes the original formatting, which collapsed
/// the last variant and the closing brace onto one line.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum PositionEncodingType {
    /// Learned absolute position embeddings.
    Learnable,
    /// Fixed sinusoidal encodings (original Transformer style).
    Sinusoidal,
    /// Relative position encodings.
    Relative,
    /// Rotary position embeddings.
    RoPE,
}
/// Settings for the text-encoding branch.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LanguageEncoderConfig {
    /// Which language-model family to use.
    pub architecture: LanguageArchitecture,
    /// Tokenizer vocabulary size.
    pub vocab_size: usize,
    /// Dimensionality of the produced text embedding.
    pub language_dim: usize,
    /// Maximum token sequence length.
    pub max_seq_length: usize,
    /// Transformer hyper-parameters for the encoder.
    pub transformer_config: LanguageTransformerConfig,
}

impl Default for LanguageEncoderConfig {
    /// BERT-base defaults: 30 522-token WordPiece vocabulary, 768-dim
    /// embeddings, 512-token context.
    fn default() -> Self {
        LanguageEncoderConfig {
            transformer_config: LanguageTransformerConfig::default(),
            max_seq_length: 512,
            language_dim: 768,
            vocab_size: 30522,
            architecture: LanguageArchitecture::BERT,
        }
    }
}
/// Supported language-model families.
///
/// Fieldless, so `Copy`, `PartialEq`, `Eq`, and `Hash` are derived as well
/// (purely additive, backward-compatible).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum LanguageArchitecture {
    BERT,
    RoBERTa,
    DeBERTa,
    ELECTRA,
    GPT,
    T5,
    /// Text tower of a CLIP-style dual encoder.
    CLIP,
    /// Text tower of an ALIGN-style dual encoder.
    ALIGN,
}
/// Transformer hyper-parameters for the language encoder.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LanguageTransformerConfig {
    /// Number of encoder layers.
    pub num_layers: usize,
    /// Attention heads per layer.
    pub num_heads: usize,
    /// Model (hidden) dimensionality.
    pub hidden_dim: usize,
    /// Width of the feed-forward intermediate layer.
    pub intermediate_dim: usize,
    /// Dropout probability.
    pub dropout_rate: f32,
    /// Activation used in the feed-forward blocks.
    pub activation: ActivationFunction,
}

impl Default for LanguageTransformerConfig {
    /// BERT-base defaults: 12 layers x 12 heads, 768 hidden / 3072
    /// intermediate, GELU, 0.1 dropout.
    fn default() -> Self {
        LanguageTransformerConfig {
            activation: ActivationFunction::GELU,
            dropout_rate: 0.1,
            intermediate_dim: 3072,
            hidden_dim: 768,
            num_heads: 12,
            num_layers: 12,
        }
    }
}
/// Settings for the graph-encoding branch.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphEncoderConfig {
    /// Which graph neural network family to use.
    pub architecture: GraphArchitecture,
    /// Dimensionality of node features.
    pub node_dim: usize,
    /// Dimensionality of edge features.
    pub edge_dim: usize,
    /// Dimensionality of the produced whole-graph embedding.
    pub graph_dim: usize,
    /// Number of message-passing / transformer layers.
    pub num_layers: usize,
    /// Neighborhood aggregation operator.
    pub aggregation: AggregationFunction,
    /// Graph-level readout operator.
    pub readout: ReadoutFunction,
}

impl Default for GraphEncoderConfig {
    /// 6-layer graph transformer with attention aggregation and a global
    /// attention readout producing a 768-dim graph embedding.
    ///
    /// The original source fused `graph_dim` and `num_layers` onto one
    /// line; this reformats them per rustfmt without any value change.
    fn default() -> Self {
        Self {
            architecture: GraphArchitecture::GraphTransformer,
            node_dim: 256,
            edge_dim: 128,
            graph_dim: 768,
            num_layers: 6,
            aggregation: AggregationFunction::Attention,
            readout: ReadoutFunction::GlobalAttention,
        }
    }
}
/// Supported graph neural network families.
///
/// Fieldless, so `Copy`, `PartialEq`, `Eq`, and `Hash` are derived as well
/// (purely additive). Also fixes the original formatting, which collapsed
/// the last variant and the closing brace onto one line.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum GraphArchitecture {
    GCN,
    GraphSAGE,
    GAT,
    GraphTransformer,
    /// Graph isomorphism network.
    GIN,
    /// Principal neighbourhood aggregation.
    PNA,
    /// General, powerful, scalable graph transformer (GraphGPS).
    GPS,
}
/// Neighborhood aggregation operator for message passing.
///
/// Fieldless, so `Copy`, `PartialEq`, `Eq`, and `Hash` are derived as well
/// (purely additive, backward-compatible).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum AggregationFunction {
    Mean,
    Max,
    Sum,
    Attention,
    LSTM,
    GRU,
}
/// Graph-level readout: how node embeddings are condensed into one
/// whole-graph vector.
///
/// Fieldless, so `Copy`, `PartialEq`, `Eq`, and `Hash` are derived as well
/// (purely additive, backward-compatible).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ReadoutFunction {
    GlobalMean,
    GlobalMax,
    GlobalSum,
    GlobalAttention,
    Set2Set,
    DiffPool,
}
/// Nonlinear activation function variants.
///
/// Fieldless, so `Copy`, `PartialEq`, `Eq`, and `Hash` are derived as well
/// (purely additive, backward-compatible).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ActivationFunction {
    ReLU,
    GELU,
    Swish,
    Mish,
    ELU,
    LeakyReLU,
    Tanh,
    Sigmoid,
}
/// Settings for the transformer that fuses the three modalities.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultiModalTransformerConfig {
    /// Shared dimensionality all modality embeddings are projected to.
    pub unified_dim: usize,
    /// Number of fusion layers.
    pub num_fusion_layers: usize,
    /// Cross-attention hyper-parameters.
    pub cross_attention_config: CrossAttentionConfig,
    /// How modalities are combined.
    pub fusion_strategy: FusionStrategy,
    /// How each token is tagged with its source modality.
    pub modality_encoding: ModalityEncoding,
}

impl Default for MultiModalTransformerConfig {
    /// Six cross-attention fusion layers over a 768-dim unified space with
    /// learnable modality embeddings.
    fn default() -> Self {
        MultiModalTransformerConfig {
            modality_encoding: ModalityEncoding::Learnable,
            fusion_strategy: FusionStrategy::CrossAttention,
            cross_attention_config: CrossAttentionConfig::default(),
            num_fusion_layers: 6,
            unified_dim: 768,
        }
    }
}
/// Hyper-parameters for cross-modal attention layers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CrossAttentionConfig {
    /// Number of attention heads.
    pub num_heads: usize,
    /// Dimensionality per head.
    pub head_dim: usize,
    /// Dropout probability.
    pub dropout_rate: f32,
    /// Whether a residual connection wraps the attention block.
    pub use_residual: bool,
    /// Attention variant to use.
    pub attention_mechanism: AttentionMechanism,
}

impl Default for CrossAttentionConfig {
    /// 12 heads x 64 dims (768 total) of residual scaled dot-product
    /// attention with 0.1 dropout.
    fn default() -> Self {
        CrossAttentionConfig {
            attention_mechanism: AttentionMechanism::ScaledDotProduct,
            use_residual: true,
            dropout_rate: 0.1,
            head_dim: 64,
            num_heads: 12,
        }
    }
}
/// Attention variants available to the fusion layers.
///
/// Fieldless, so `Copy`, `PartialEq`, `Eq`, and `Hash` are derived as well
/// (purely additive). Also fixes the original formatting, which collapsed
/// the last variant and the closing brace onto one line.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum AttentionMechanism {
    ScaledDotProduct,
    MultiHead,
    SparseAttention,
    LinearAttention,
    /// FAVOR+-style kernelized attention (Performer).
    PerformerAttention,
    /// Co-attention between two modality streams.
    CoAttn,
}
/// Strategy for combining modality streams.
///
/// Fieldless, so `Copy`, `PartialEq`, `Eq`, and `Hash` are derived as well
/// (purely additive, backward-compatible).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum FusionStrategy {
    EarlyFusion,
    LateFusion,
    CrossAttention,
    ProgressiveFusion,
    AdaptiveFusion,
    TensorFusion,
}
/// How tokens are tagged with their source modality.
///
/// Fieldless, so `Copy`, `PartialEq`, `Eq`, and `Hash` are derived as well
/// (purely additive, backward-compatible).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ModalityEncoding {
    None,
    Learnable,
    Fixed,
    PositionAware,
}
/// Settings for episodic meta-learning.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MetaLearningConfig {
    /// Meta-learning algorithm to run.
    pub algorithm: MetaLearningAlgorithm,
    /// Examples per task used for inner-loop adaptation.
    pub support_set_size: usize,
    /// Examples per task used for outer-loop evaluation.
    pub query_set_size: usize,
    /// Gradient steps taken in the inner loop.
    pub adaptation_steps: usize,
    /// Inner-loop (task adaptation) learning rate.
    pub inner_lr: f32,
    /// Outer-loop (meta-update) learning rate.
    pub outer_lr: f32,
    /// Per-task tweaks (categories, domain weights, difficulty).
    pub task_specific_params: TaskSpecificParams,
}

impl Default for MetaLearningConfig {
    /// 5-shot / 15-query MAML with 5 inner steps and a 10x inner/outer
    /// learning-rate ratio (0.01 vs. 0.001).
    fn default() -> Self {
        MetaLearningConfig {
            task_specific_params: TaskSpecificParams::default(),
            outer_lr: 0.001,
            inner_lr: 0.01,
            adaptation_steps: 5,
            query_set_size: 15,
            support_set_size: 5,
            algorithm: MetaLearningAlgorithm::MAML,
        }
    }
}
/// Supported meta-learning algorithms.
///
/// Fieldless, so `Copy`, `PartialEq`, `Eq`, and `Hash` are derived as well
/// (purely additive, backward-compatible).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum MetaLearningAlgorithm {
    /// Model-agnostic meta-learning.
    MAML,
    /// First-order approximation of MAML.
    FOMAML,
    Reptile,
    /// Prototypical networks.
    ProtoNet,
    RelationNet,
    /// Memory-augmented neural network.
    MANN,
    /// Presumably an adaptive/approximate MAML variant — confirm with users.
    AMAML,
}
/// Per-task tweaks applied during meta-learning.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaskSpecificParams {
    /// Task families sampled during training.
    pub task_categories: Vec<TaskCategory>,
    /// Relative loss weight per domain, keyed by domain name.
    pub domain_weights: HashMap<String, f32>,
    /// Whether task difficulty is adjusted dynamically.
    pub difficulty_adjustment: bool,
}

impl Default for TaskSpecificParams {
    /// Equal (1.0) weighting across the vision, language, and graph
    /// domains; three representative task categories; difficulty
    /// adjustment enabled.
    fn default() -> Self {
        let domain_weights: HashMap<String, f32> = ["vision", "language", "graph"]
            .iter()
            .map(|domain| (domain.to_string(), 1.0))
            .collect();
        TaskSpecificParams {
            domain_weights,
            difficulty_adjustment: true,
            task_categories: vec![
                TaskCategory::ImageCaptioning,
                TaskCategory::VisualQuestionAnswering,
                TaskCategory::GraphGrounding,
            ],
        }
    }
}
/// Multi-modal task families.
///
/// Fieldless, so `Copy`, `PartialEq`, `Eq`, and `Hash` are derived as well
/// (purely additive, backward-compatible).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum TaskCategory {
    ImageCaptioning,
    VisualQuestionAnswering,
    ImageTextRetrieval,
    GraphTextAlignment,
    GraphGrounding,
    MultiModalReasoning,
    CrossModalGeneration,
}
/// Settings controlling transfer from source to target domains.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TransferLearningConfig {
    /// Overall transfer strategy.
    pub strategy: TransferStrategy,
    /// Names of domains trained on first.
    pub source_domains: Vec<String>,
    /// Names of domains transferred to.
    pub target_domains: Vec<String>,
    /// Domain-adaptation options.
    pub domain_adaptation: DomainAdaptationConfig,
    /// Zero-shot transfer options.
    pub zero_shot_config: ZeroShotConfig,
    /// Few-shot transfer options.
    pub few_shot_config: FewShotConfig,
}

impl Default for TransferLearningConfig {
    /// Progressive transfer from general/ImageNet data toward medical and
    /// scientific target domains.
    fn default() -> Self {
        let sources = vec!["general".to_string(), "imagenet".to_string()];
        let targets = vec!["medical".to_string(), "scientific".to_string()];
        TransferLearningConfig {
            few_shot_config: FewShotConfig::default(),
            zero_shot_config: ZeroShotConfig::default(),
            domain_adaptation: DomainAdaptationConfig::default(),
            target_domains: targets,
            source_domains: sources,
            strategy: TransferStrategy::ProgressiveTransfer,
        }
    }
}
/// Overall transfer-learning strategies.
///
/// Fieldless, so `Copy`, `PartialEq`, `Eq`, and `Hash` are derived as well
/// (purely additive, backward-compatible).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum TransferStrategy {
    FineTuning,
    FeatureExtraction,
    ProgressiveTransfer,
    MultiTaskLearning,
    DomainAdaptation,
    ContinualLearning,
}
/// Settings for domain-adaptation training.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DomainAdaptationConfig {
    /// Adaptation method to apply.
    pub method: DomainAdaptationMethod,
    /// Whether adversarial training is enabled.
    pub adversarial_training: bool,
    /// Whether a gradient-reversal layer is used.
    pub gradient_reversal: bool,
    /// Loss weight of the domain classifier.
    pub domain_classifier_weight: f32,
}

impl Default for DomainAdaptationConfig {
    /// DANN-style adversarial adaptation with gradient reversal and a 0.1
    /// domain-classifier loss weight.
    fn default() -> Self {
        DomainAdaptationConfig {
            domain_classifier_weight: 0.1,
            gradient_reversal: true,
            adversarial_training: true,
            method: DomainAdaptationMethod::DANN,
        }
    }
}
/// Supported domain-adaptation methods.
///
/// Fieldless, so `Copy`, `PartialEq`, `Eq`, and `Hash` are derived as well
/// (purely additive, backward-compatible).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum DomainAdaptationMethod {
    /// Domain-adversarial neural network.
    DANN,
    /// Maximum mean discrepancy matching.
    MMD,
    /// Correlation alignment.
    CORAL,
    /// Wasserstein-distance-guided representation learning.
    WDGRL,
    /// Conditional domain-adversarial network.
    CDAN,
}
/// Settings for zero-shot transfer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ZeroShotConfig {
    /// Zero-shot method to use.
    pub method: ZeroShotMethod,
    /// Dimensionality of the semantic embedding space.
    pub semantic_dim: usize,
    /// Whether attribute vectors are used in addition to embeddings.
    pub use_attributes: bool,
    /// Dimensionality of the attribute vectors.
    pub attribute_dim: usize,
}

impl Default for ZeroShotConfig {
    /// CLIP-style zero-shot transfer: 512-dim semantic space plus 256-dim
    /// attribute vectors.
    fn default() -> Self {
        ZeroShotConfig {
            attribute_dim: 256,
            use_attributes: true,
            semantic_dim: 512,
            method: ZeroShotMethod::CLIP,
        }
    }
}
/// Supported zero-shot transfer methods.
///
/// Fieldless, so `Copy`, `PartialEq`, `Eq`, and `Hash` are derived as well
/// (purely additive, backward-compatible).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ZeroShotMethod {
    CLIP,
    ALIGN,
    Attribute,
    SemanticEmbedding,
    KnowledgeGuided,
}
/// Settings for few-shot transfer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FewShotConfig {
    /// Few-shot method to use.
    pub method: FewShotMethod,
    /// Labeled examples available per class.
    pub num_shots: usize,
    /// Episode sampling layout.
    pub episode_config: EpisodeConfig,
}

impl Default for FewShotConfig {
    /// 5-shot prototypical networks with the default episode layout.
    fn default() -> Self {
        FewShotConfig {
            episode_config: EpisodeConfig::default(),
            num_shots: 5,
            method: FewShotMethod::ProtoNet,
        }
    }
}
/// Supported few-shot learning methods.
///
/// Fieldless, so `Copy`, `PartialEq`, `Eq`, and `Hash` are derived as well
/// (purely additive, backward-compatible).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum FewShotMethod {
    ProtoNet,
    MatchingNet,
    RelationNet,
    MetaLearning,
}
/// Layout of one few-shot episode (N-way, K-shot sampling).
///
/// All fields are `usize`, so `Copy`, `PartialEq`, `Eq`, and `Hash` are
/// derived as well — a purely additive, backward-compatible improvement
/// that lets episode layouts be compared and used as map keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct EpisodeConfig {
    /// Number of classes sampled per episode (the "N" in N-way).
    pub num_classes: usize,
    /// Support examples per class (the "K" in K-shot).
    pub support_per_class: usize,
    /// Query examples per class used for evaluation.
    pub query_per_class: usize,
}

impl Default for EpisodeConfig {
    /// Standard 5-way, 5-shot, 15-query episode layout.
    fn default() -> Self {
        Self {
            num_classes: 5,
            support_per_class: 5,
            query_per_class: 15,
        }
    }
}
/// Settings for joint multi-objective training.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JointTrainingConfig {
    /// Objectives optimized together.
    pub objectives: Vec<TrainingObjective>,
    /// Loss weight per alignment objective, keyed by objective name.
    pub objective_weights: HashMap<String, f32>,
    /// Whether curriculum learning is enabled.
    pub curriculum_learning: bool,
    /// Whether training proceeds in progressive stages.
    pub progressive_training: bool,
}

impl Default for JointTrainingConfig {
    /// Four alignment objectives with tri-modal alignment weighted highest
    /// (1.2) and vision–graph lowest (0.6); curriculum and progressive
    /// training both enabled.
    fn default() -> Self {
        let objective_weights: HashMap<String, f32> = [
            ("vision_language_alignment", 1.0),
            ("language_graph_alignment", 0.8),
            ("vision_graph_alignment", 0.6),
            ("tri_modal_alignment", 1.2),
        ]
        .iter()
        .map(|(name, weight)| (name.to_string(), *weight))
        .collect();
        JointTrainingConfig {
            progressive_training: true,
            curriculum_learning: true,
            objective_weights,
            objectives: vec![
                TrainingObjective::ContrastiveLearning,
                TrainingObjective::MaskedLanguageModeling,
                TrainingObjective::ImageTextMatching,
                TrainingObjective::GraphAlignment,
            ],
        }
    }
}
/// Training objectives available for joint optimization.
///
/// Fieldless, so `Copy`, `PartialEq`, `Eq`, and `Hash` are derived as well
/// (purely additive, backward-compatible).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum TrainingObjective {
    ContrastiveLearning,
    MaskedLanguageModeling,
    ImageTextMatching,
    GraphAlignment,
    VisualQuestionAnswering,
    ImageCaptioning,
    GraphReasoning,
    MultiModalReasoning,
}