Skip to main content

oxirs_embed/vision_language_graph/
mod.rs

//! Vision-Language-Graph Multi-Modal Integration
//!
//! This module implements multi-modal integration of vision, language and
//! knowledge-graph inputs, with features including:
//! - Multi-modal transformers with cross-attention
//! - Joint representation learning
//! - Zero-shot and few-shot transfer learning
//! - Meta-learning for adaptation
//! - Vision-text-graph unified embedding spaces
10
11use crate::{EmbeddingModel, ModelConfig, ModelStats, TrainingStats, Triple, Vector};
12use anyhow::{anyhow, Result};
13use async_trait::async_trait;
14use chrono::Utc;
15use scirs2_core::ndarray_ext::{Array1, Array2, Array3};
16use scirs2_core::random::{Random, Rng};
17use serde::{Deserialize, Serialize};
18use std::collections::HashMap;
19use uuid::Uuid;
20
21pub mod config;
22pub mod encoders;
23pub mod meta_learner;
24pub mod transformer;
25
26pub use config::*;
27pub use encoders::*;
28pub use meta_learner::*;
29pub use transformer::*;
30
31/// Vision-Language-Graph embedding model
32#[derive(Debug)]
33pub struct VisionLanguageGraphModel {
34    pub config: VisionLanguageGraphConfig,
35    pub model_id: Uuid,
36    /// Vision encoder
37    pub vision_encoder: VisionEncoder,
38    /// Language encoder
39    pub language_encoder: LanguageEncoder,
40    /// Graph encoder
41    pub graph_encoder: GraphEncoder,
42    /// Multi-modal transformer
43    pub multimodal_transformer: MultiModalTransformer,
44    /// Meta-learner for adaptation
45    pub meta_learner: MetaLearner,
46    /// Cached embeddings
47    pub vision_embeddings: HashMap<String, Array1<f32>>,
48    pub language_embeddings: HashMap<String, Array1<f32>>,
49    pub graph_embeddings: HashMap<String, Array1<f32>>,
50    pub unified_embeddings: HashMap<String, Array1<f32>>,
51    /// Training state
52    pub training_stats: Option<TrainingStats>,
53    pub is_trained: bool,
54}
55
56/// Vision encoder
57impl VisionLanguageGraphModel {
58    /// Create new vision-language-graph model
59    pub fn new(config: VisionLanguageGraphConfig) -> Self {
60        let model_id = Uuid::new_v4();
61
62        let vision_encoder = VisionEncoder::new(config.vision_config.clone());
63        let language_encoder = LanguageEncoder::new(config.language_config.clone());
64        let graph_encoder = GraphEncoder::new(config.graph_config.clone());
65        let multimodal_transformer = MultiModalTransformer::new(config.transformer_config.clone());
66        let meta_learner = MetaLearner::new(config.meta_learning_config.clone());
67
68        Self {
69            config,
70            model_id,
71            vision_encoder,
72            language_encoder,
73            graph_encoder,
74            multimodal_transformer,
75            meta_learner,
76            vision_embeddings: HashMap::new(),
77            language_embeddings: HashMap::new(),
78            graph_embeddings: HashMap::new(),
79            unified_embeddings: HashMap::new(),
80            training_stats: None,
81            is_trained: false,
82        }
83    }
84
85    /// Generate unified multi-modal embedding
86    pub async fn generate_unified_embedding(
87        &mut self,
88        image: Option<&Array3<f32>>,
89        text: Option<&str>,
90        graph_data: Option<(&Array2<f32>, &Array2<f32>, &Array2<f32>)>,
91    ) -> Result<Array1<f32>> {
92        let mut embeddings = Vec::new();
93
94        // Vision embedding
95        let vision_emb = if let Some(img) = image {
96            let emb = self.vision_encoder.encode_image(img)?;
97            self.vision_embeddings
98                .insert("current_image".to_string(), emb.clone());
99            emb
100        } else {
101            Array1::zeros(self.config.vision_config.vision_dim)
102        };
103        embeddings.push(vision_emb.clone());
104
105        // Language embedding
106        let language_emb = if let Some(txt) = text {
107            let emb = self.language_encoder.encode_text(txt)?;
108            self.language_embeddings
109                .insert("current_text".to_string(), emb.clone());
110            emb
111        } else {
112            Array1::zeros(self.config.language_config.language_dim)
113        };
114        embeddings.push(language_emb.clone());
115
116        // Graph embedding
117        let graph_emb = if let Some((nodes, edges, adj)) = graph_data {
118            let emb = self.graph_encoder.encode_graph(nodes, edges, adj)?;
119            self.graph_embeddings
120                .insert("current_graph".to_string(), emb.clone());
121            emb
122        } else {
123            Array1::zeros(self.config.graph_config.graph_dim)
124        };
125        embeddings.push(graph_emb.clone());
126
127        // Fuse embeddings
128        let unified_emb =
129            self.multimodal_transformer
130                .fuse_embeddings(&vision_emb, &language_emb, &graph_emb)?;
131
132        self.unified_embeddings
133            .insert("current_unified".to_string(), unified_emb.clone());
134
135        Ok(unified_emb)
136    }
137
138    /// Zero-shot prediction
139    pub fn zero_shot_predict(
140        &self,
141        query_embedding: &Array1<f32>,
142        class_prototypes: &HashMap<String, Array1<f32>>,
143    ) -> Result<String> {
144        let mut best_class = String::new();
145        let mut best_score = f32::NEG_INFINITY;
146
147        for (class_name, prototype) in class_prototypes {
148            let score = self.cosine_similarity(query_embedding, prototype);
149            if score > best_score {
150                best_score = score;
151                best_class = class_name.clone();
152            }
153        }
154
155        Ok(best_class)
156    }
157
158    /// Few-shot adaptation
159    pub fn few_shot_adapt(
160        &mut self,
161        support_examples: &[(Array1<f32>, String)],
162        query_examples: &[Array1<f32>],
163    ) -> Result<Vec<String>> {
164        // Convert support examples to meta-learning format
165        let support_set: Vec<(Array1<f32>, Array1<f32>)> = support_examples
166            .iter()
167            .map(|(emb, label)| {
168                let label_emb = Array1::from_vec(vec![label.len() as f32]); // Simplified label encoding
169                (emb.clone(), label_emb)
170            })
171            .collect();
172
173        let query_set: Vec<(Array1<f32>, Array1<f32>)> = query_examples
174            .iter()
175            .map(|emb| (emb.clone(), Array1::zeros(1)))
176            .collect();
177
178        // Adapt meta-learner
179        let _adapted_params = self.meta_learner.adapt_to_task(&support_set, &query_set)?;
180
181        // Make predictions on query set
182        let mut predictions = Vec::new();
183
184        for query_emb in query_examples {
185            // Find nearest support example
186            let mut best_label = String::new();
187            let mut best_distance = f32::INFINITY;
188
189            for (support_emb, label) in support_examples {
190                let distance = self.euclidean_distance(query_emb, support_emb);
191                if distance < best_distance {
192                    best_distance = distance;
193                    best_label = label.clone();
194                }
195            }
196
197            predictions.push(best_label);
198        }
199
200        Ok(predictions)
201    }
202
203    /// Cosine similarity
204    fn cosine_similarity(&self, a: &Array1<f32>, b: &Array1<f32>) -> f32 {
205        let dot_product = a.dot(b);
206        let norm_a = a.dot(a).sqrt();
207        let norm_b = b.dot(b).sqrt();
208
209        if norm_a > 0.0 && norm_b > 0.0 {
210            dot_product / (norm_a * norm_b)
211        } else {
212            0.0
213        }
214    }
215
216    /// Euclidean distance
217    fn euclidean_distance(&self, a: &Array1<f32>, b: &Array1<f32>) -> f32 {
218        let diff = a - b;
219        diff.dot(&diff).sqrt()
220    }
221}
222
223/// Multi-modal statistics
224#[derive(Debug, Clone, Serialize, Deserialize)]
225pub struct VisionLanguageGraphStats {
226    pub num_vision_samples: usize,
227    pub num_language_samples: usize,
228    pub num_graph_samples: usize,
229    pub num_unified_embeddings: usize,
230    pub vision_dim: usize,
231    pub language_dim: usize,
232    pub graph_dim: usize,
233    pub unified_dim: usize,
234    pub zero_shot_accuracy: f32,
235    pub few_shot_accuracy: f32,
236    pub cross_modal_alignment_score: f32,
237}
238
239impl Default for VisionLanguageGraphStats {
240    fn default() -> Self {
241        Self {
242            num_vision_samples: 0,
243            num_language_samples: 0,
244            num_graph_samples: 0,
245            num_unified_embeddings: 0,
246            vision_dim: 768,
247            language_dim: 768,
248            graph_dim: 512,
249            unified_dim: 768,
250            zero_shot_accuracy: 0.0,
251            few_shot_accuracy: 0.0,
252            cross_modal_alignment_score: 0.0,
253        }
254    }
255}
256
257#[async_trait]
258impl EmbeddingModel for VisionLanguageGraphModel {
259    fn config(&self) -> &ModelConfig {
260        &self.config.base_config
261    }
262
263    fn model_id(&self) -> &Uuid {
264        &self.model_id
265    }
266
267    fn model_type(&self) -> &'static str {
268        "VisionLanguageGraphModel"
269    }
270
271    fn add_triple(&mut self, _triple: Triple) -> Result<()> {
272        // Implementation would process triples for graph structure
273        Ok(())
274    }
275
276    async fn train(&mut self, epochs: Option<usize>) -> Result<TrainingStats> {
277        let epochs = epochs.unwrap_or(self.config.base_config.max_epochs);
278        let start_time = std::time::Instant::now();
279
280        let mut loss_history = Vec::new();
281
282        for epoch in 0..epochs {
283            // Simulate multi-modal training
284            let epoch_loss = self.train_epoch().await?;
285            loss_history.push(epoch_loss);
286
287            if epoch > 10 && epoch_loss < 1e-4 {
288                break;
289            }
290        }
291
292        let training_time = start_time.elapsed().as_secs_f64();
293        let final_loss = loss_history.last().copied().unwrap_or(0.0);
294
295        let stats = TrainingStats {
296            epochs_completed: loss_history.len(),
297            final_loss,
298            training_time_seconds: training_time,
299            convergence_achieved: final_loss < 1e-4,
300            loss_history,
301        };
302
303        self.training_stats = Some(stats.clone());
304        self.is_trained = true;
305
306        Ok(stats)
307    }
308
309    fn get_entity_embedding(&self, entity: &str) -> Result<Vector> {
310        if let Some(embedding) = self.unified_embeddings.get(entity) {
311            Ok(Vector::new(embedding.to_vec()))
312        } else {
313            Err(anyhow!("Entity not found: {}", entity))
314        }
315    }
316
317    fn get_relation_embedding(&self, relation: &str) -> Result<Vector> {
318        if let Some(embedding) = self.unified_embeddings.get(relation) {
319            Ok(Vector::new(embedding.to_vec()))
320        } else {
321            Err(anyhow!("Relation not found: {}", relation))
322        }
323    }
324
325    fn score_triple(&self, subject: &str, predicate: &str, object: &str) -> Result<f64> {
326        let subject_emb = self.get_entity_embedding(subject)?;
327        let predicate_emb = self.get_relation_embedding(predicate)?;
328        let object_emb = self.get_entity_embedding(object)?;
329
330        // Simple TransE-style scoring
331        let subject_arr = Array1::from_vec(subject_emb.values);
332        let predicate_arr = Array1::from_vec(predicate_emb.values);
333        let object_arr = Array1::from_vec(object_emb.values);
334
335        let predicted = &subject_arr + &predicate_arr;
336        let diff = &predicted - &object_arr;
337        let distance = diff.dot(&diff).sqrt();
338
339        Ok(-distance as f64)
340    }
341
342    fn predict_objects(
343        &self,
344        subject: &str,
345        predicate: &str,
346        k: usize,
347    ) -> Result<Vec<(String, f64)>> {
348        let mut scores = Vec::new();
349
350        for entity in self.unified_embeddings.keys() {
351            if entity != subject {
352                let score = self.score_triple(subject, predicate, entity)?;
353                scores.push((entity.clone(), score));
354            }
355        }
356
357        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
358        scores.truncate(k);
359
360        Ok(scores)
361    }
362
363    fn predict_subjects(
364        &self,
365        predicate: &str,
366        object: &str,
367        k: usize,
368    ) -> Result<Vec<(String, f64)>> {
369        let mut scores = Vec::new();
370
371        for entity in self.unified_embeddings.keys() {
372            if entity != object {
373                let score = self.score_triple(entity, predicate, object)?;
374                scores.push((entity.clone(), score));
375            }
376        }
377
378        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
379        scores.truncate(k);
380
381        Ok(scores)
382    }
383
384    fn predict_relations(
385        &self,
386        subject: &str,
387        object: &str,
388        k: usize,
389    ) -> Result<Vec<(String, f64)>> {
390        let mut scores = Vec::new();
391
392        for relation in self.unified_embeddings.keys() {
393            let score = self.score_triple(subject, relation, object)?;
394            scores.push((relation.clone(), score));
395        }
396
397        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
398        scores.truncate(k);
399
400        Ok(scores)
401    }
402
403    fn get_entities(&self) -> Vec<String> {
404        self.unified_embeddings.keys().cloned().collect()
405    }
406
407    fn get_relations(&self) -> Vec<String> {
408        self.unified_embeddings.keys().cloned().collect()
409    }
410
411    fn get_stats(&self) -> ModelStats {
412        ModelStats {
413            num_entities: self.unified_embeddings.len(),
414            num_relations: self.unified_embeddings.len(),
415            num_triples: 0,
416            dimensions: self.config.transformer_config.unified_dim,
417            is_trained: self.is_trained,
418            model_type: self.model_type().to_string(),
419            creation_time: Utc::now(),
420            last_training_time: if self.is_trained {
421                Some(Utc::now())
422            } else {
423                None
424            },
425        }
426    }
427
428    fn save(&self, _path: &str) -> Result<()> {
429        Ok(())
430    }
431
432    fn load(&mut self, _path: &str) -> Result<()> {
433        Ok(())
434    }
435
436    fn clear(&mut self) {
437        self.vision_embeddings.clear();
438        self.language_embeddings.clear();
439        self.graph_embeddings.clear();
440        self.unified_embeddings.clear();
441        self.is_trained = false;
442        self.training_stats = None;
443    }
444
445    fn is_trained(&self) -> bool {
446        self.is_trained
447    }
448
449    async fn encode(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
450        let mut results = Vec::new();
451
452        for text in texts {
453            let embedding = self.language_encoder.encode_text(text)?;
454            results.push(embedding.to_vec());
455        }
456
457        Ok(results)
458    }
459}
460
461impl VisionLanguageGraphModel {
462    /// Training epoch for multi-modal model
463    async fn train_epoch(&mut self) -> Result<f64> {
464        // Simulate multi-modal training loss
465        let mut random = Random::default();
466        let vision_loss = 0.1 * random.random::<f64>();
467        let language_loss = 0.1 * random.random::<f64>();
468        let graph_loss = 0.1 * random.random::<f64>();
469        let fusion_loss = 0.1 * random.random::<f64>();
470
471        let total_loss = vision_loss + language_loss + graph_loss + fusion_loss;
472
473        Ok(total_loss)
474    }
475}
476
#[cfg(test)]
mod tests {
    use super::*;

    /// Small language-encoder config shared by the fast tests.
    fn small_language_config() -> LanguageEncoderConfig {
        LanguageEncoderConfig {
            vocab_size: 256,
            language_dim: 32,
            max_seq_length: 16,
            transformer_config: LanguageTransformerConfig {
                num_layers: 2,
                num_heads: 2,
                hidden_dim: 32,
                intermediate_dim: 64,
                ..LanguageTransformerConfig::default()
            },
            ..LanguageEncoderConfig::default()
        }
    }

    /// Create a small VisionLanguageGraphConfig suitable for fast tests in debug builds.
    /// All dimensions are reduced to avoid multi-second matrix allocations.
    fn small_test_config() -> VisionLanguageGraphConfig {
        VisionLanguageGraphConfig {
            vision_config: VisionEncoderConfig {
                image_size: (32, 32),
                channels: 3,
                patch_size: (8, 8),
                vision_dim: 32,
                cnn_config: CNNConfig {
                    num_layers: 2,
                    filter_sizes: vec![8, 16],
                    stride_sizes: vec![2, 2],
                    ..CNNConfig::default()
                },
                vit_config: ViTConfig {
                    num_layers: 2,
                    num_heads: 2,
                    mlp_dim: 64,
                    ..ViTConfig::default()
                },
                ..VisionEncoderConfig::default()
            },
            language_config: small_language_config(),
            graph_config: GraphEncoderConfig {
                node_dim: 16,
                edge_dim: 8,
                graph_dim: 32,
                num_layers: 2,
                ..GraphEncoderConfig::default()
            },
            transformer_config: MultiModalTransformerConfig {
                unified_dim: 32,
                num_fusion_layers: 2,
                cross_attention_config: CrossAttentionConfig {
                    num_heads: 2,
                    head_dim: 16,
                    ..CrossAttentionConfig::default()
                },
                ..MultiModalTransformerConfig::default()
            },
            ..VisionLanguageGraphConfig::default()
        }
    }

    #[test]
    fn test_vision_language_graph_config_default() {
        let config = VisionLanguageGraphConfig::default();
        assert_eq!(config.vision_config.vision_dim, 768);
        assert_eq!(config.language_config.language_dim, 768);
        assert_eq!(config.graph_config.graph_dim, 768); // Updated to match unified_dim
    }

    #[test]
    fn test_vision_encoder_creation() {
        let config = VisionEncoderConfig::default();
        let encoder = VisionEncoder::new(config);
        assert!(!encoder.cnn_parameters.is_empty());
        assert!(!encoder.vit_parameters.is_empty());
    }

    #[test]
    fn test_language_encoder_creation() {
        // Use small dimensions to avoid timeout in debug builds
        let encoder = LanguageEncoder::new(small_language_config());
        assert_eq!(encoder.token_embeddings.nrows(), 256);
        assert_eq!(encoder.position_embeddings.nrows(), 16);
    }

    #[test]
    fn test_graph_encoder_creation() {
        let config = GraphEncoderConfig::default();
        let encoder = GraphEncoder::new(config);
        assert!(!encoder.node_parameters.is_empty());
        assert!(!encoder.edge_parameters.is_empty());
    }

    #[test]
    fn test_multimodal_transformer_creation() {
        let config = MultiModalTransformerConfig::default();
        let transformer = MultiModalTransformer::new(config);
        assert!(!transformer.cross_attention_params.is_empty());
        assert!(!transformer.fusion_params.is_empty());
    }

    #[test]
    #[cfg_attr(debug_assertions, ignore = "Model initialization slow in debug builds")]
    fn test_vision_language_graph_model_creation() {
        let config = VisionLanguageGraphConfig::default();
        let model = VisionLanguageGraphModel::new(config);
        assert!(!model.is_trained);
        assert_eq!(model.unified_embeddings.len(), 0);
    }

    #[test]
    fn test_vision_encoder_image_encoding() {
        let config = VisionEncoderConfig::default();
        let encoder = VisionEncoder::new(config);

        let mut random = Random::default();
        let image = Array3::from_shape_fn((224, 224, 3), |_| random.random::<f32>());
        let embedding = encoder.encode_image(&image).unwrap();

        assert_eq!(embedding.len(), encoder.config.vision_dim);
    }

    #[test]
    fn test_language_encoder_text_encoding() {
        // Use small dimensions to avoid timeout in debug builds
        let encoder = LanguageEncoder::new(small_language_config());

        let text = "Hello world, this is a test";
        let embedding = encoder
            .encode_text(text)
            .expect("encode_text should succeed");

        assert_eq!(embedding.len(), encoder.config.language_dim);
    }

    #[test]
    fn test_graph_encoder_graph_encoding() {
        let config = GraphEncoderConfig::default();
        let node_dim = config.node_dim;
        let edge_dim = config.edge_dim;
        let encoder = GraphEncoder::new(config);

        let mut random = Random::default();
        let node_features = Array2::from_shape_fn((5, node_dim), |_| random.random::<f32>());
        let edge_features = Array2::from_shape_fn((10, edge_dim), |_| random.random::<f32>());
        let adjacency = Array2::eye(5);

        let embedding = encoder
            .encode_graph(&node_features, &edge_features, &adjacency)
            .unwrap();

        assert_eq!(embedding.len(), encoder.config.graph_dim);
    }

    #[tokio::test]
    #[cfg_attr(debug_assertions, ignore = "Embedding tests require release builds")]
    async fn test_unified_embedding_generation() {
        let config = VisionLanguageGraphConfig::default();
        let mut model = VisionLanguageGraphModel::new(config);

        let mut random = Random::default();
        let image = Array3::from_shape_fn((224, 224, 3), |_| random.random::<f32>());
        let text = "A beautiful landscape with mountains";
        let node_features = Array2::from_shape_fn((3, 256), |_| random.random::<f32>());
        let edge_features = Array2::from_shape_fn((6, 128), |_| random.random::<f32>());
        let adjacency = Array2::eye(3);

        let unified_embedding = model
            .generate_unified_embedding(
                Some(&image),
                Some(text),
                Some((&node_features, &edge_features, &adjacency)),
            )
            .await
            .unwrap();

        assert!(!unified_embedding.is_empty());
        assert_eq!(model.vision_embeddings.len(), 1);
        assert_eq!(model.language_embeddings.len(), 1);
        assert_eq!(model.graph_embeddings.len(), 1);
        assert_eq!(model.unified_embeddings.len(), 1);
    }

    #[test]
    fn test_zero_shot_prediction() {
        let config = small_test_config();
        let model = VisionLanguageGraphModel::new(config);

        // A single generator for every draw. Re-creating `Random::default()`
        // between draws could replay the same sequence and make the query and
        // both prototypes identical.
        let mut random = Random::default();
        let query = Array1::from_shape_fn(32, |_| random.random::<f32>());

        let mut prototypes = HashMap::new();
        prototypes.insert(
            "class1".to_string(),
            Array1::from_shape_fn(32, |_| random.random::<f32>()),
        );
        prototypes.insert(
            "class2".to_string(),
            Array1::from_shape_fn(32, |_| random.random::<f32>()),
        );

        let prediction = model
            .zero_shot_predict(&query, &prototypes)
            .expect("zero_shot_predict should succeed");
        assert!(prototypes.contains_key(&prediction));
    }

    #[test]
    fn test_zero_shot_prediction_picks_most_similar_prototype() {
        let model = VisionLanguageGraphModel::new(small_test_config());

        let query = Array1::from_vec(vec![1.0_f32, 0.0]);
        let mut prototypes = HashMap::new();
        prototypes.insert("aligned".to_string(), Array1::from_vec(vec![2.0_f32, 0.0]));
        prototypes.insert("orthogonal".to_string(), Array1::from_vec(vec![0.0_f32, 3.0]));

        let prediction = model
            .zero_shot_predict(&query, &prototypes)
            .expect("zero_shot_predict should succeed");
        assert_eq!(prediction, "aligned");
    }

    #[test]
    #[cfg_attr(debug_assertions, ignore = "Embedding tests require release builds")]
    fn test_few_shot_adaptation() {
        let config = VisionLanguageGraphConfig::default();
        let mut model = VisionLanguageGraphModel::new(config);

        // One generator, so queries are not replays of the support examples.
        let mut random = Random::default();
        let support_examples = vec![
            (
                Array1::from_shape_fn(512, |_| random.random::<f32>()),
                "cat".to_string(),
            ),
            (
                Array1::from_shape_fn(512, |_| random.random::<f32>()),
                "dog".to_string(),
            ),
        ];

        let query_examples = vec![
            Array1::from_shape_fn(512, |_| random.random::<f32>()),
            Array1::from_shape_fn(512, |_| random.random::<f32>()),
        ];

        let predictions = model
            .few_shot_adapt(&support_examples, &query_examples)
            .unwrap();
        assert_eq!(predictions.len(), 2);
    }

    #[test]
    fn test_meta_learner_adaptation() {
        let config = MetaLearningConfig::default();
        let mut meta_learner = MetaLearner::new(config);

        let mut random = Random::default();
        let support_set = vec![
            (
                Array1::from_shape_fn(512, |_| random.random::<f32>()),
                Array1::from_vec(vec![1.0]),
            ),
            (
                Array1::from_shape_fn(512, |_| random.random::<f32>()),
                Array1::from_vec(vec![0.0]),
            ),
        ];

        let query_set = vec![];

        let adapted_params = meta_learner
            .adapt_to_task(&support_set, &query_set)
            .unwrap();
        assert!(!adapted_params.is_empty());
    }

    #[tokio::test]
    async fn test_vision_language_graph_training() {
        let config = small_test_config();
        let mut model = VisionLanguageGraphModel::new(config);

        let stats = model.train(Some(3)).await.expect("training should succeed");
        assert_eq!(stats.epochs_completed, 3);
        assert!(model.is_trained());
    }

    #[tokio::test]
    #[cfg_attr(debug_assertions, ignore = "Embedding tests require release builds")]
    async fn test_vision_language_graph_encoding() {
        let config = VisionLanguageGraphConfig::default();
        let expected_dim = config.language_config.language_dim;
        let model = VisionLanguageGraphModel::new(config);

        let texts = vec!["hello world".to_string(), "test encoding".to_string()];
        let embeddings = model.encode(&texts).await.unwrap();

        assert_eq!(embeddings.len(), 2);
        assert_eq!(embeddings[0].len(), expected_dim);
    }
}