oxirs_embed/vision_language_graph/config.rs

//! Module for vision-language-graph integration

use crate::ModelConfig;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

/// Configuration for vision-language-graph integration
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct VisionLanguageGraphConfig {
    /// Base model configuration
    pub base_config: ModelConfig,
    /// Vision encoder configuration
    pub vision_config: VisionEncoderConfig,
    /// Language encoder configuration
    pub language_config: LanguageEncoderConfig,
    /// Graph encoder configuration
    pub graph_config: GraphEncoderConfig,
    /// Multi-modal transformer configuration
    pub transformer_config: MultiModalTransformerConfig,
    /// Meta-learning configuration
    pub meta_learning_config: MetaLearningConfig,
    /// Transfer learning configuration
    pub transfer_config: TransferLearningConfig,
    /// Joint training configuration
    pub joint_training_config: JointTrainingConfig,
}

/// Vision encoder configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VisionEncoderConfig {
    /// Vision model architecture
    pub architecture: VisionArchitecture,
    /// Input image dimensions
    pub image_size: (usize, usize),
    /// Number of channels
    pub channels: usize,
    /// Patch size for vision transformer
    pub patch_size: (usize, usize),
    /// Vision embedding dimension
    pub vision_dim: usize,
    /// CNN backbone configuration
    pub cnn_config: CNNConfig,
    /// Vision transformer configuration
    pub vit_config: ViTConfig,
}

impl Default for VisionEncoderConfig {
    fn default() -> Self {
        Self {
            architecture: VisionArchitecture::VisionTransformer,
            image_size: (224, 224),
            channels: 3,
            patch_size: (16, 16),
            vision_dim: 768,
            cnn_config: CNNConfig::default(),
            vit_config: ViTConfig::default(),
        }
    }
}
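
// Illustrative helper (a sketch, not part of the original API): for a ViT-style
// encoder this is the number of non-overlapping patches implied by the image and
// patch sizes; with the defaults, (224 / 16) * (224 / 16) = 196.
impl VisionEncoderConfig {
    /// Number of patches implied by `image_size` and `patch_size`.
    pub fn num_patches(&self) -> usize {
        (self.image_size.0 / self.patch_size.0) * (self.image_size.1 / self.patch_size.1)
    }
}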

/// Vision architectures
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum VisionArchitecture {
    /// Convolutional Neural Networks
    ResNet,
    EfficientNet,
    DenseNet,
    /// Vision Transformers
    VisionTransformer,
    DeiT,
    Swin,
    /// Hybrid architectures
    ConViT,
    CvT,
    /// CLIP-style encoders
    CLIPVision,
}

/// CNN configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CNNConfig {
    /// Number of layers
    pub num_layers: usize,
    /// Filter sizes per layer
    pub filter_sizes: Vec<usize>,
    /// Stride sizes
    pub stride_sizes: Vec<usize>,
    /// Pooling configuration
    pub pooling: PoolingType,
    /// Normalization type
    pub normalization: NormalizationType,
}

impl Default for CNNConfig {
    fn default() -> Self {
        Self {
            num_layers: 4,
            filter_sizes: vec![64, 128, 256, 512],
            stride_sizes: vec![2, 2, 2, 2],
            pooling: PoolingType::AdaptiveAvgPool,
            normalization: NormalizationType::BatchNorm,
        }
    }
}

/// Vision Transformer configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ViTConfig {
    /// Number of transformer layers
    pub num_layers: usize,
    /// Number of attention heads
    pub num_heads: usize,
    /// MLP hidden dimension
    pub mlp_dim: usize,
    /// Dropout rate
    pub dropout_rate: f32,
    /// Position encoding type
    pub position_encoding: PositionEncodingType,
}

impl Default for ViTConfig {
    fn default() -> Self {
        Self {
            num_layers: 12,
            num_heads: 12,
            mlp_dim: 3072,
            dropout_rate: 0.1,
            position_encoding: PositionEncodingType::Learnable,
        }
    }
}

/// Pooling types
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum PoolingType {
    MaxPool,
    AvgPool,
    AdaptiveAvgPool,
    AdaptiveMaxPool,
    GlobalAvgPool,
}

/// Normalization types
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum NormalizationType {
    BatchNorm,
    LayerNorm,
    GroupNorm,
    InstanceNorm,
}

/// Position encoding types
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum PositionEncodingType {
    Learnable,
    Sinusoidal,
    Relative,
    RoPE, // Rotary Position Embedding
}

/// Language encoder configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LanguageEncoderConfig {
    /// Language model architecture
    pub architecture: LanguageArchitecture,
    /// Vocabulary size
    pub vocab_size: usize,
    /// Language embedding dimension
    pub language_dim: usize,
    /// Maximum sequence length
    pub max_seq_length: usize,
    /// Transformer configuration
    pub transformer_config: LanguageTransformerConfig,
}

impl Default for LanguageEncoderConfig {
    fn default() -> Self {
        Self {
            architecture: LanguageArchitecture::BERT,
            vocab_size: 30522,
            language_dim: 768,
            max_seq_length: 512,
            transformer_config: LanguageTransformerConfig::default(),
        }
    }
}

/// Language architectures
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum LanguageArchitecture {
    BERT,
    RoBERTa,
    DeBERTa,
    ELECTRA,
    GPT,
    T5,
    CLIP,
    ALIGN,
}
/// Language transformer configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LanguageTransformerConfig {
    /// Number of transformer layers
    pub num_layers: usize,
    /// Number of attention heads
    pub num_heads: usize,
    /// Hidden dimension
    pub hidden_dim: usize,
    /// Intermediate (feed-forward) dimension
    pub intermediate_dim: usize,
    /// Dropout rate
    pub dropout_rate: f32,
    /// Activation function
    pub activation: ActivationFunction,
}

impl Default for LanguageTransformerConfig {
    fn default() -> Self {
        Self {
            num_layers: 12,
            num_heads: 12,
            hidden_dim: 768,
            intermediate_dim: 3072,
            dropout_rate: 0.1,
            activation: ActivationFunction::GELU,
        }
    }
}

/// Graph encoder configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphEncoderConfig {
    /// Graph neural network architecture
    pub architecture: GraphArchitecture,
    /// Node embedding dimension
    pub node_dim: usize,
    /// Edge embedding dimension
    pub edge_dim: usize,
    /// Graph embedding dimension
    pub graph_dim: usize,
    /// Number of GNN layers
    pub num_layers: usize,
    /// Aggregation function
    pub aggregation: AggregationFunction,
    /// Readout function
    pub readout: ReadoutFunction,
}

impl Default for GraphEncoderConfig {
    fn default() -> Self {
        Self {
            architecture: GraphArchitecture::GraphTransformer,
            node_dim: 256,
            edge_dim: 128,
            graph_dim: 768, // Match unified_dim for proper fusion
            num_layers: 6,
            aggregation: AggregationFunction::Attention,
            readout: ReadoutFunction::GlobalAttention,
        }
    }
}

/// Graph neural network architectures
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum GraphArchitecture {
    GCN,
    GraphSAGE,
    GAT,
    GraphTransformer,
    GIN,
    PNA,
    GPS, // General, Powerful, Scalable Graph Transformer
}

/// Aggregation functions for GNNs
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AggregationFunction {
    Mean,
    Max,
    Sum,
    Attention,
    LSTM,
    GRU,
}

/// Readout functions for graph-level representations
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ReadoutFunction {
    GlobalMean,
    GlobalMax,
    GlobalSum,
    GlobalAttention,
    Set2Set,
    DiffPool,
}

/// Activation functions
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ActivationFunction {
    ReLU,
    GELU,
    Swish,
    Mish,
    ELU,
    LeakyReLU,
    Tanh,
    Sigmoid,
}

/// Multi-modal transformer configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultiModalTransformerConfig {
    /// Unified embedding dimension
    pub unified_dim: usize,
    /// Number of fusion layers
    pub num_fusion_layers: usize,
    /// Cross-attention configuration
    pub cross_attention_config: CrossAttentionConfig,
    /// Fusion strategy
    pub fusion_strategy: FusionStrategy,
    /// Positional encoding for modalities
    pub modality_encoding: ModalityEncoding,
}

impl Default for MultiModalTransformerConfig {
    fn default() -> Self {
        Self {
            unified_dim: 768,
            num_fusion_layers: 6,
            cross_attention_config: CrossAttentionConfig::default(),
            fusion_strategy: FusionStrategy::CrossAttention,
            modality_encoding: ModalityEncoding::Learnable,
        }
    }
}
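
// Sketch of a consistency check (illustrative, not part of the original API):
// the defaults are chosen so that the graph embedding dimension matches
// `unified_dim` and the cross-attention heads tile it exactly (12 * 64 = 768).
impl VisionLanguageGraphConfig {
    /// Returns true if the graph encoder and cross-attention dimensions line up
    /// with the unified fusion dimension.
    pub fn fusion_dims_consistent(&self) -> bool {
        let attn = &self.transformer_config.cross_attention_config;
        self.graph_config.graph_dim == self.transformer_config.unified_dim
            && attn.num_heads * attn.head_dim == self.transformer_config.unified_dim
    }
}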

/// Cross-attention configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CrossAttentionConfig {
    /// Number of attention heads
    pub num_heads: usize,
    /// Attention head dimension
    pub head_dim: usize,
    /// Dropout rate
    pub dropout_rate: f32,
    /// Use residual connections
    pub use_residual: bool,
    /// Attention mechanism
    pub attention_mechanism: AttentionMechanism,
}

impl Default for CrossAttentionConfig {
    fn default() -> Self {
        Self {
            num_heads: 12,
            head_dim: 64,
            dropout_rate: 0.1,
            use_residual: true,
            attention_mechanism: AttentionMechanism::ScaledDotProduct,
        }
    }
}

/// Attention mechanisms
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AttentionMechanism {
    ScaledDotProduct,
    MultiHead,
    SparseAttention,
    LinearAttention,
    PerformerAttention,
    CoAttn, // Co-Attention
}

/// Fusion strategies
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum FusionStrategy {
    /// Early fusion (concatenation)
    EarlyFusion,
    /// Late fusion (separate processing)
    LateFusion,
    /// Cross-attention between modalities
    CrossAttention,
    /// Progressive fusion
    ProgressiveFusion,
    /// Adaptive fusion
    AdaptiveFusion,
    /// Tensor fusion
    TensorFusion,
}

/// Modality encoding types
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ModalityEncoding {
    /// No modality encoding
    None,
    /// Learnable modality embeddings
    Learnable,
    /// Fixed modality embeddings
    Fixed,
    /// Position-aware modality encoding
    PositionAware,
}

/// Meta-learning configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MetaLearningConfig {
    /// Meta-learning algorithm
    pub algorithm: MetaLearningAlgorithm,
    /// Support set size for few-shot learning
    pub support_set_size: usize,
    /// Query set size
    pub query_set_size: usize,
    /// Number of adaptation steps
    pub adaptation_steps: usize,
    /// Inner learning rate
    pub inner_lr: f32,
    /// Outer learning rate
    pub outer_lr: f32,
    /// Task-specific parameters
    pub task_specific_params: TaskSpecificParams,
}

impl Default for MetaLearningConfig {
    fn default() -> Self {
        Self {
            algorithm: MetaLearningAlgorithm::MAML,
            support_set_size: 5,
            query_set_size: 15,
            adaptation_steps: 5,
            inner_lr: 0.01,
            outer_lr: 0.001,
            task_specific_params: TaskSpecificParams::default(),
        }
    }
}

/// Meta-learning algorithms
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum MetaLearningAlgorithm {
    /// Model-Agnostic Meta-Learning
    MAML,
    /// First-Order MAML
    FOMAML,
    /// Reptile
    Reptile,
    /// Prototypical Networks
    ProtoNet,
    /// Relation Networks
    RelationNet,
    /// Memory-Augmented Neural Networks
    MANN,
    /// Meta-Learning with Adaptive Parameters
    AMAML,
}

/// Task-specific parameters
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaskSpecificParams {
    /// Task categories
    pub task_categories: Vec<TaskCategory>,
    /// Domain-specific weights
    pub domain_weights: HashMap<String, f32>,
    /// Task difficulty adjustment
    pub difficulty_adjustment: bool,
}

impl Default for TaskSpecificParams {
    fn default() -> Self {
        let mut domain_weights = HashMap::new();
        domain_weights.insert("vision".to_string(), 1.0);
        domain_weights.insert("language".to_string(), 1.0);
        domain_weights.insert("graph".to_string(), 1.0);

        Self {
            task_categories: vec![
                TaskCategory::ImageCaptioning,
                TaskCategory::VisualQuestionAnswering,
                TaskCategory::GraphGrounding,
            ],
            domain_weights,
            difficulty_adjustment: true,
        }
    }
}

/// Task categories for multi-modal learning
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum TaskCategory {
    /// Image captioning
    ImageCaptioning,
    /// Visual question answering
    VisualQuestionAnswering,
    /// Image-text retrieval
    ImageTextRetrieval,
    /// Graph-text alignment
    GraphTextAlignment,
    /// Graph grounding in images
    GraphGrounding,
    /// Multi-modal reasoning
    MultiModalReasoning,
    /// Cross-modal generation
    CrossModalGeneration,
}

/// Transfer learning configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TransferLearningConfig {
    /// Transfer strategy
    pub strategy: TransferStrategy,
    /// Source domains
    pub source_domains: Vec<String>,
    /// Target domains
    pub target_domains: Vec<String>,
    /// Domain adaptation configuration
    pub domain_adaptation: DomainAdaptationConfig,
    /// Zero-shot configuration
    pub zero_shot_config: ZeroShotConfig,
    /// Few-shot configuration
    pub few_shot_config: FewShotConfig,
}

impl Default for TransferLearningConfig {
    fn default() -> Self {
        Self {
            strategy: TransferStrategy::ProgressiveTransfer,
            source_domains: vec!["general".to_string(), "imagenet".to_string()],
            target_domains: vec!["medical".to_string(), "scientific".to_string()],
            domain_adaptation: DomainAdaptationConfig::default(),
            zero_shot_config: ZeroShotConfig::default(),
            few_shot_config: FewShotConfig::default(),
        }
    }
}
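
// Illustrative helper (an assumption, not part of the original API): maps the
// configured transfer strategy to whether the pretrained backbone stays frozen,
// following the doc comments on `TransferStrategy` below.
impl TransferLearningConfig {
    /// True when the chosen strategy keeps the backbone frozen.
    pub fn freezes_backbone(&self) -> bool {
        matches!(self.strategy, TransferStrategy::FeatureExtraction)
    }
}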

/// Transfer learning strategies
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum TransferStrategy {
    /// Fine-tuning all parameters
    FineTuning,
    /// Feature extraction (frozen backbone)
    FeatureExtraction,
    /// Progressive transfer
    ProgressiveTransfer,
    /// Multi-task learning
    MultiTaskLearning,
    /// Domain adaptation
    DomainAdaptation,
    /// Continual learning
    ContinualLearning,
}

/// Domain adaptation configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DomainAdaptationConfig {
    /// Adaptation method
    pub method: DomainAdaptationMethod,
    /// Adversarial training
    pub adversarial_training: bool,
    /// Gradient reversal layer
    pub gradient_reversal: bool,
    /// Domain classifier weight
    pub domain_classifier_weight: f32,
}

impl Default for DomainAdaptationConfig {
    fn default() -> Self {
        Self {
            method: DomainAdaptationMethod::DANN,
            adversarial_training: true,
            gradient_reversal: true,
            domain_classifier_weight: 0.1,
        }
    }
}

/// Domain adaptation methods
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum DomainAdaptationMethod {
    /// Domain-Adversarial Neural Networks
    DANN,
    /// Maximum Mean Discrepancy
    MMD,
    /// Correlation Alignment
    CORAL,
    /// Wasserstein Distance Guided Representation Learning
    WDGRL,
    /// Conditional Domain Adversarial Networks
    CDAN,
585
586/// Zero-shot learning configuration
587#[derive(Debug, Clone, Serialize, Deserialize)]
588pub struct ZeroShotConfig {
589    /// Zero-shot method
590    pub method: ZeroShotMethod,
591    /// Semantic space dimension
592    pub semantic_dim: usize,
593    /// Use auxiliary attributes
594    pub use_attributes: bool,
595    /// Attribute dimension
596    pub attribute_dim: usize,
597}
598
599impl Default for ZeroShotConfig {
600    fn default() -> Self {
601        Self {
602            method: ZeroShotMethod::CLIP,
603            semantic_dim: 512,
604            use_attributes: true,
605            attribute_dim: 256,
606        }
607    }
608}
609
610/// Zero-shot learning methods
611#[derive(Debug, Clone, Serialize, Deserialize)]
612pub enum ZeroShotMethod {
613    /// CLIP-style contrastive learning
614    CLIP,
615    /// ALIGN-style learning
616    ALIGN,
617    /// Attribute-based learning
618    Attribute,
619    /// Semantic embedding
620    SemanticEmbedding,
621    /// Knowledge graph guided
622    KnowledgeGuided,
623}
624
625/// Few-shot learning configuration
626#[derive(Debug, Clone, Serialize, Deserialize)]
627pub struct FewShotConfig {
628    /// Few-shot method
629    pub method: FewShotMethod,
630    /// Number of shots
631    pub num_shots: usize,
632    /// Episode configuration
633    pub episode_config: EpisodeConfig,
634}
635
636impl Default for FewShotConfig {
637    fn default() -> Self {
638        Self {
639            method: FewShotMethod::ProtoNet,
640            num_shots: 5,
641            episode_config: EpisodeConfig::default(),
642        }
643    }
644}
645
646/// Few-shot learning methods
647#[derive(Debug, Clone, Serialize, Deserialize)]
648pub enum FewShotMethod {
649    /// Prototypical Networks
650    ProtoNet,
651    /// Matching Networks
652    MatchingNet,
653    /// Relation Networks
654    RelationNet,
655    /// Meta-learning approaches
656    MetaLearning,
657}
658
659/// Episode configuration for few-shot learning
660#[derive(Debug, Clone, Serialize, Deserialize)]
661pub struct EpisodeConfig {
662    /// Number of classes per episode
663    pub num_classes: usize,
664    /// Support samples per class
665    pub support_per_class: usize,
666    /// Query samples per class
667    pub query_per_class: usize,
668}
669
670impl Default for EpisodeConfig {
671    fn default() -> Self {
672        Self {
673            num_classes: 5,
674            support_per_class: 5,
675            query_per_class: 15,
676        }
677    }
678}
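
// Illustrative helper (a sketch, not in the original module): total samples an
// N-way K-shot episode draws; with the defaults, 5 * (5 + 15) = 100.
impl EpisodeConfig {
    /// Samples drawn per episode across support and query sets.
    pub fn samples_per_episode(&self) -> usize {
        self.num_classes * (self.support_per_class + self.query_per_class)
    }
}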

/// Joint training configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JointTrainingConfig {
    /// Training objectives
    pub objectives: Vec<TrainingObjective>,
    /// Objective weights
    pub objective_weights: HashMap<String, f32>,
    /// Curriculum learning
    pub curriculum_learning: bool,
    /// Progressive training
    pub progressive_training: bool,
}

impl Default for JointTrainingConfig {
    fn default() -> Self {
        let mut objective_weights = HashMap::new();
        objective_weights.insert("vision_language_alignment".to_string(), 1.0);
        objective_weights.insert("language_graph_alignment".to_string(), 0.8);
        objective_weights.insert("vision_graph_alignment".to_string(), 0.6);
        objective_weights.insert("tri_modal_alignment".to_string(), 1.2);

        Self {
            objectives: vec![
                TrainingObjective::ContrastiveLearning,
                TrainingObjective::MaskedLanguageModeling,
                TrainingObjective::ImageTextMatching,
                TrainingObjective::GraphAlignment,
            ],
            objective_weights,
            curriculum_learning: true,
            progressive_training: true,
        }
    }
}

/// Training objectives
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum TrainingObjective {
    /// Contrastive learning between modalities
    ContrastiveLearning,
    /// Masked language modeling
    MaskedLanguageModeling,
    /// Image-text matching
    ImageTextMatching,
    /// Graph-text alignment
    GraphAlignment,
    /// Visual question answering
    VisualQuestionAnswering,
    /// Image captioning
    ImageCaptioning,
    /// Graph reasoning
    GraphReasoning,
    /// Multi-modal reasoning
    MultiModalReasoning,
}
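
#[cfg(test)]
mod tests {
    use super::*;

    // Minimal usage sketch (illustrative values, not recommended settings):
    // build the default configuration and override a few fields. Relies on
    // `ModelConfig: Default`, which the derive on `VisionLanguageGraphConfig`
    // already requires.
    #[test]
    fn default_config_builds_and_can_be_customized() {
        let mut config = VisionLanguageGraphConfig::default();
        config.vision_config.image_size = (384, 384);
        config.transformer_config.fusion_strategy = FusionStrategy::ProgressiveFusion;

        assert_eq!(config.language_config.language_dim, 768);
        assert_eq!(
            config.graph_config.graph_dim,
            config.transformer_config.unified_dim
        );
    }
}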