use crate::ModelConfig;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct VisionLanguageGraphConfig {
    /// Base model configuration shared across modalities.
    pub base_config: ModelConfig,
    /// Vision encoder configuration.
    pub vision_config: VisionEncoderConfig,
    /// Language encoder configuration.
    pub language_config: LanguageEncoderConfig,
    /// Graph encoder configuration.
    pub graph_config: GraphEncoderConfig,
    /// Multi-modal fusion transformer configuration.
    pub transformer_config: MultiModalTransformerConfig,
    /// Meta-learning configuration.
    pub meta_learning_config: MetaLearningConfig,
    /// Transfer learning configuration.
    pub transfer_config: TransferLearningConfig,
    /// Joint training configuration.
    pub joint_training_config: JointTrainingConfig,
}
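
// Usage sketch (illustrative, not part of the original API): constructs the
// derived `Default` configuration and overrides one nested field via struct
// update syntax. Relies on `ModelConfig: Default`, which the derive above
// already assumes.
#[cfg(test)]
mod vision_language_graph_config_example {
    use super::*;

    #[test]
    fn default_config_can_be_overridden() {
        let config = VisionLanguageGraphConfig {
            transformer_config: MultiModalTransformerConfig {
                num_fusion_layers: 8,
                ..MultiModalTransformerConfig::default()
            },
            ..VisionLanguageGraphConfig::default()
        };
        assert_eq!(config.transformer_config.num_fusion_layers, 8);
        // Component defaults are preserved for fields that were not overridden.
        assert_eq!(config.vision_config.vision_dim, 768);
    }
}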

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VisionEncoderConfig {
    pub architecture: VisionArchitecture,
    pub image_size: (usize, usize),
    pub channels: usize,
    pub patch_size: (usize, usize),
    pub vision_dim: usize,
    pub cnn_config: CNNConfig,
    pub vit_config: ViTConfig,
}

impl Default for VisionEncoderConfig {
    fn default() -> Self {
        Self {
            architecture: VisionArchitecture::VisionTransformer,
            image_size: (224, 224),
            channels: 3,
            patch_size: (16, 16),
            vision_dim: 768,
            cnn_config: CNNConfig::default(),
            vit_config: ViTConfig::default(),
        }
    }
}
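
// Usage sketch (illustrative): switches the default ViT backbone to a CNN
// architecture. The concrete layer and filter counts below are assumptions
// made for the example, not recommended settings.
#[cfg(test)]
mod vision_encoder_config_example {
    use super::*;

    #[test]
    fn cnn_backbone_can_be_configured() {
        let vision = VisionEncoderConfig {
            architecture: VisionArchitecture::ResNet,
            image_size: (256, 256),
            cnn_config: CNNConfig {
                num_layers: 5,
                filter_sizes: vec![64, 128, 256, 512, 1024],
                stride_sizes: vec![2, 2, 2, 2, 2],
                ..CNNConfig::default()
            },
            ..VisionEncoderConfig::default()
        };
        // One filter size and one stride per convolutional layer.
        assert_eq!(vision.cnn_config.filter_sizes.len(), vision.cnn_config.num_layers);
        assert_eq!(vision.cnn_config.stride_sizes.len(), vision.cnn_config.num_layers);
    }
}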

#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum VisionArchitecture {
    ResNet,
    EfficientNet,
    DenseNet,
    VisionTransformer,
    DeiT,
    Swin,
    ConViT,
    CvT,
    CLIPVision,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CNNConfig {
    pub num_layers: usize,
    pub filter_sizes: Vec<usize>,
    pub stride_sizes: Vec<usize>,
    pub pooling: PoolingType,
    pub normalization: NormalizationType,
}

impl Default for CNNConfig {
    fn default() -> Self {
        Self {
            num_layers: 4,
            filter_sizes: vec![64, 128, 256, 512],
            stride_sizes: vec![2, 2, 2, 2],
            pooling: PoolingType::AdaptiveAvgPool,
            normalization: NormalizationType::BatchNorm,
        }
    }
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ViTConfig {
    pub num_layers: usize,
    pub num_heads: usize,
    pub mlp_dim: usize,
    pub dropout_rate: f32,
    pub position_encoding: PositionEncodingType,
}

impl Default for ViTConfig {
    fn default() -> Self {
        Self {
            num_layers: 12,
            num_heads: 12,
            mlp_dim: 3072,
            dropout_rate: 0.1,
            position_encoding: PositionEncodingType::Learnable,
        }
    }
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum PoolingType {
    MaxPool,
    AvgPool,
    AdaptiveAvgPool,
    AdaptiveMaxPool,
    GlobalAvgPool,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum NormalizationType {
    BatchNorm,
    LayerNorm,
    GroupNorm,
    InstanceNorm,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum PositionEncodingType {
    Learnable,
    Sinusoidal,
    Relative,
    RoPE,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LanguageEncoderConfig {
    pub architecture: LanguageArchitecture,
    pub vocab_size: usize,
    pub language_dim: usize,
    pub max_seq_length: usize,
    pub transformer_config: LanguageTransformerConfig,
}

impl Default for LanguageEncoderConfig {
    fn default() -> Self {
        Self {
            architecture: LanguageArchitecture::BERT,
            vocab_size: 30522,
            language_dim: 768,
            max_seq_length: 512,
            transformer_config: LanguageTransformerConfig::default(),
        }
    }
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum LanguageArchitecture {
    BERT,
    RoBERTa,
    DeBERTa,
    ELECTRA,
    GPT,
    T5,
    CLIP,
    ALIGN,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LanguageTransformerConfig {
    pub num_layers: usize,
    pub num_heads: usize,
    pub hidden_dim: usize,
    pub intermediate_dim: usize,
    pub dropout_rate: f32,
    pub activation: ActivationFunction,
}

impl Default for LanguageTransformerConfig {
    fn default() -> Self {
        Self {
            num_layers: 12,
            num_heads: 12,
            hidden_dim: 768,
            intermediate_dim: 3072,
            dropout_rate: 0.1,
            activation: ActivationFunction::GELU,
        }
    }
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphEncoderConfig {
    pub architecture: GraphArchitecture,
    pub node_dim: usize,
    pub edge_dim: usize,
    pub graph_dim: usize,
    pub num_layers: usize,
    pub aggregation: AggregationFunction,
    pub readout: ReadoutFunction,
}

impl Default for GraphEncoderConfig {
    fn default() -> Self {
        Self {
            architecture: GraphArchitecture::GraphTransformer,
            node_dim: 256,
            edge_dim: 128,
            graph_dim: 768,
            num_layers: 6,
            aggregation: AggregationFunction::Attention,
            readout: ReadoutFunction::GlobalAttention,
        }
    }
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum GraphArchitecture {
    GCN,
    GraphSAGE,
    GAT,
    GraphTransformer,
    GIN,
    PNA,
    GPS,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AggregationFunction {
    Mean,
    Max,
    Sum,
    Attention,
    LSTM,
    GRU,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ReadoutFunction {
    GlobalMean,
    GlobalMax,
    GlobalSum,
    GlobalAttention,
    Set2Set,
    DiffPool,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ActivationFunction {
    ReLU,
    GELU,
    Swish,
    Mish,
    ELU,
    LeakyReLU,
    Tanh,
    Sigmoid,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultiModalTransformerConfig {
    pub unified_dim: usize,
    pub num_fusion_layers: usize,
    pub cross_attention_config: CrossAttentionConfig,
    pub fusion_strategy: FusionStrategy,
    pub modality_encoding: ModalityEncoding,
}

impl Default for MultiModalTransformerConfig {
    fn default() -> Self {
        Self {
            unified_dim: 768,
            num_fusion_layers: 6,
            cross_attention_config: CrossAttentionConfig::default(),
            fusion_strategy: FusionStrategy::CrossAttention,
            modality_encoding: ModalityEncoding::Learnable,
        }
    }
}
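
// Consistency sketch (illustrative): the defaults are set up so that each
// per-modality embedding width matches `unified_dim`, and the cross-attention
// heads factor it exactly (12 heads x 64 = 768). Treating matching widths as
// required is an assumption of this example; a projection layer could relax it.
#[cfg(test)]
mod multi_modal_dimension_example {
    use super::*;

    #[test]
    fn default_dimensions_line_up_for_fusion() {
        let config = VisionLanguageGraphConfig::default();
        let unified = config.transformer_config.unified_dim;
        assert_eq!(config.vision_config.vision_dim, unified);
        assert_eq!(config.language_config.language_dim, unified);
        assert_eq!(config.graph_config.graph_dim, unified);

        let attn = &config.transformer_config.cross_attention_config;
        assert_eq!(attn.num_heads * attn.head_dim, unified);
    }
}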

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CrossAttentionConfig {
    pub num_heads: usize,
    pub head_dim: usize,
    pub dropout_rate: f32,
    pub use_residual: bool,
    pub attention_mechanism: AttentionMechanism,
}

impl Default for CrossAttentionConfig {
    fn default() -> Self {
        Self {
            num_heads: 12,
            head_dim: 64,
            dropout_rate: 0.1,
            use_residual: true,
            attention_mechanism: AttentionMechanism::ScaledDotProduct,
        }
    }
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AttentionMechanism {
    ScaledDotProduct,
    MultiHead,
    SparseAttention,
    LinearAttention,
    PerformerAttention,
    CoAttn,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum FusionStrategy {
    EarlyFusion,
    LateFusion,
    CrossAttention,
    ProgressiveFusion,
    AdaptiveFusion,
    TensorFusion,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ModalityEncoding {
    None,
    Learnable,
    Fixed,
    PositionAware,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MetaLearningConfig {
    pub algorithm: MetaLearningAlgorithm,
    pub support_set_size: usize,
    pub query_set_size: usize,
    pub adaptation_steps: usize,
    pub inner_lr: f32,
    pub outer_lr: f32,
    pub task_specific_params: TaskSpecificParams,
}

impl Default for MetaLearningConfig {
    fn default() -> Self {
        Self {
            algorithm: MetaLearningAlgorithm::MAML,
            support_set_size: 5,
            query_set_size: 15,
            adaptation_steps: 5,
            inner_lr: 0.01,
            outer_lr: 0.001,
            task_specific_params: TaskSpecificParams::default(),
        }
    }
}
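
// Usage sketch (illustrative): tightens the default 5-shot MAML setup into a
// 1-shot configuration. The learning-rate and step values here are example
// assumptions, not tuned recommendations.
#[cfg(test)]
mod meta_learning_config_example {
    use super::*;

    #[test]
    fn one_shot_setup_can_be_expressed() {
        let meta = MetaLearningConfig {
            support_set_size: 1,
            adaptation_steps: 10,
            inner_lr: 0.005,
            ..MetaLearningConfig::default()
        };
        assert_eq!(meta.support_set_size, 1);
        // The outer-loop settings keep their defaults.
        assert!((meta.outer_lr - 0.001).abs() < f32::EPSILON);
    }
}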

#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum MetaLearningAlgorithm {
    MAML,
    FOMAML,
    Reptile,
    ProtoNet,
    RelationNet,
    MANN,
    AMAML,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaskSpecificParams {
    pub task_categories: Vec<TaskCategory>,
    pub domain_weights: HashMap<String, f32>,
    pub difficulty_adjustment: bool,
}

impl Default for TaskSpecificParams {
    fn default() -> Self {
        let mut domain_weights = HashMap::new();
        domain_weights.insert("vision".to_string(), 1.0);
        domain_weights.insert("language".to_string(), 1.0);
        domain_weights.insert("graph".to_string(), 1.0);

        Self {
            task_categories: vec![
                TaskCategory::ImageCaptioning,
                TaskCategory::VisualQuestionAnswering,
                TaskCategory::GraphGrounding,
            ],
            domain_weights,
            difficulty_adjustment: true,
        }
    }
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum TaskCategory {
    ImageCaptioning,
    VisualQuestionAnswering,
    ImageTextRetrieval,
    GraphTextAlignment,
    GraphGrounding,
    MultiModalReasoning,
    CrossModalGeneration,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TransferLearningConfig {
    pub strategy: TransferStrategy,
    pub source_domains: Vec<String>,
    pub target_domains: Vec<String>,
    pub domain_adaptation: DomainAdaptationConfig,
    pub zero_shot_config: ZeroShotConfig,
    pub few_shot_config: FewShotConfig,
}

impl Default for TransferLearningConfig {
    fn default() -> Self {
        Self {
            strategy: TransferStrategy::ProgressiveTransfer,
            source_domains: vec!["general".to_string(), "imagenet".to_string()],
            target_domains: vec!["medical".to_string(), "scientific".to_string()],
            domain_adaptation: DomainAdaptationConfig::default(),
            zero_shot_config: ZeroShotConfig::default(),
            few_shot_config: FewShotConfig::default(),
        }
    }
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum TransferStrategy {
    FineTuning,
    FeatureExtraction,
    ProgressiveTransfer,
    MultiTaskLearning,
    DomainAdaptation,
    ContinualLearning,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DomainAdaptationConfig {
    pub method: DomainAdaptationMethod,
    pub adversarial_training: bool,
    pub gradient_reversal: bool,
    pub domain_classifier_weight: f32,
}

impl Default for DomainAdaptationConfig {
    fn default() -> Self {
        Self {
            method: DomainAdaptationMethod::DANN,
            adversarial_training: true,
            gradient_reversal: true,
            domain_classifier_weight: 0.1,
        }
    }
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum DomainAdaptationMethod {
    DANN,
    MMD,
    CORAL,
    WDGRL,
    CDAN,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ZeroShotConfig {
    pub method: ZeroShotMethod,
    pub semantic_dim: usize,
    pub use_attributes: bool,
    pub attribute_dim: usize,
}

impl Default for ZeroShotConfig {
    fn default() -> Self {
        Self {
            method: ZeroShotMethod::CLIP,
            semantic_dim: 512,
            use_attributes: true,
            attribute_dim: 256,
        }
    }
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ZeroShotMethod {
    CLIP,
    ALIGN,
    Attribute,
    SemanticEmbedding,
    KnowledgeGuided,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FewShotConfig {
    pub method: FewShotMethod,
    pub num_shots: usize,
    pub episode_config: EpisodeConfig,
}

impl Default for FewShotConfig {
    fn default() -> Self {
        Self {
            method: FewShotMethod::ProtoNet,
            num_shots: 5,
            episode_config: EpisodeConfig::default(),
        }
    }
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum FewShotMethod {
    ProtoNet,
    MatchingNet,
    RelationNet,
    MetaLearning,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EpisodeConfig {
    pub num_classes: usize,
    pub support_per_class: usize,
    pub query_per_class: usize,
}

impl Default for EpisodeConfig {
    fn default() -> Self {
        Self {
            num_classes: 5,
            support_per_class: 5,
            query_per_class: 15,
        }
    }
}
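
// Consistency sketch (illustrative): by default the few-shot settings describe
// a 5-way 5-shot episode, so `num_shots` should agree with
// `support_per_class`. Treating that agreement as an invariant is an
// assumption of this example, not something the types enforce.
#[cfg(test)]
mod few_shot_episode_example {
    use super::*;

    #[test]
    fn default_episode_matches_num_shots() {
        let few_shot = FewShotConfig::default();
        assert_eq!(few_shot.num_shots, few_shot.episode_config.support_per_class);
        assert_eq!(few_shot.episode_config.num_classes, 5);
    }
}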

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JointTrainingConfig {
    pub objectives: Vec<TrainingObjective>,
    pub objective_weights: HashMap<String, f32>,
    pub curriculum_learning: bool,
    pub progressive_training: bool,
}

impl Default for JointTrainingConfig {
    fn default() -> Self {
        let mut objective_weights = HashMap::new();
        objective_weights.insert("vision_language_alignment".to_string(), 1.0);
        objective_weights.insert("language_graph_alignment".to_string(), 0.8);
        objective_weights.insert("vision_graph_alignment".to_string(), 0.6);
        objective_weights.insert("tri_modal_alignment".to_string(), 1.2);

        Self {
            objectives: vec![
                TrainingObjective::ContrastiveLearning,
                TrainingObjective::MaskedLanguageModeling,
                TrainingObjective::ImageTextMatching,
                TrainingObjective::GraphAlignment,
            ],
            objective_weights,
            curriculum_learning: true,
            progressive_training: true,
        }
    }
}
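
// Usage sketch (illustrative): adds an extra objective and reweights one of
// the default alignment terms. The weight value and the choice of objective
// are assumptions made for the example.
#[cfg(test)]
mod joint_training_config_example {
    use super::*;

    #[test]
    fn objectives_and_weights_can_be_extended() {
        let mut joint = JointTrainingConfig::default();
        joint.objectives.push(TrainingObjective::ImageCaptioning);
        joint
            .objective_weights
            .insert("vision_graph_alignment".to_string(), 0.9);

        assert_eq!(joint.objectives.len(), 5);
        assert_eq!(joint.objective_weights["vision_graph_alignment"], 0.9);
    }
}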

#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum TrainingObjective {
    ContrastiveLearning,
    MaskedLanguageModeling,
    ImageTextMatching,
    GraphAlignment,
    VisualQuestionAnswering,
    ImageCaptioning,
    GraphReasoning,
    MultiModalReasoning,
}