//! Vision-language-graph multimodal embedding model.
//!
//! Combines dedicated vision, language, and graph encoders with a
//! multimodal transformer that fuses the per-modality embeddings into a
//! single unified representation, plus a meta-learner that supports
//! zero-shot and few-shot adaptation.

use crate::{EmbeddingModel, ModelConfig, ModelStats, TrainingStats, Triple, Vector};
use anyhow::{anyhow, Result};
use async_trait::async_trait;
use chrono::Utc;
use scirs2_core::ndarray_ext::{Array1, Array2, Array3};
use scirs2_core::random::{Random, Rng};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use uuid::Uuid;

pub mod config;
pub mod encoders;
pub mod meta_learner;
pub mod transformer;

pub use config::*;
pub use encoders::*;
pub use meta_learner::*;
pub use transformer::*;

/// Multimodal embedding model that fuses vision, language, and graph inputs.
#[derive(Debug)]
pub struct VisionLanguageGraphModel {
    /// Model configuration.
    pub config: VisionLanguageGraphConfig,
    /// Unique identifier for this model instance.
    pub model_id: Uuid,
    /// Encoder for image inputs.
    pub vision_encoder: VisionEncoder,
    /// Encoder for text inputs.
    pub language_encoder: LanguageEncoder,
    /// Encoder for graph inputs.
    pub graph_encoder: GraphEncoder,
    /// Transformer that fuses the per-modality embeddings.
    pub multimodal_transformer: MultiModalTransformer,
    /// Meta-learner for few-shot task adaptation.
    pub meta_learner: MetaLearner,
    /// Cached vision embeddings, keyed by identifier.
    pub vision_embeddings: HashMap<String, Array1<f32>>,
    /// Cached language embeddings, keyed by identifier.
    pub language_embeddings: HashMap<String, Array1<f32>>,
    /// Cached graph embeddings, keyed by identifier.
    pub graph_embeddings: HashMap<String, Array1<f32>>,
    /// Cached unified (fused) embeddings, keyed by identifier.
    pub unified_embeddings: HashMap<String, Array1<f32>>,
    /// Statistics from the most recent training run, if any.
    pub training_stats: Option<TrainingStats>,
    /// Whether the model has been trained.
    pub is_trained: bool,
}

impl VisionLanguageGraphModel {
    /// Create a new model from the given configuration.
    pub fn new(config: VisionLanguageGraphConfig) -> Self {
        let model_id = Uuid::new_v4();

        let vision_encoder = VisionEncoder::new(config.vision_config.clone());
        let language_encoder = LanguageEncoder::new(config.language_config.clone());
        let graph_encoder = GraphEncoder::new(config.graph_config.clone());
        let multimodal_transformer = MultiModalTransformer::new(config.transformer_config.clone());
        let meta_learner = MetaLearner::new(config.meta_learning_config.clone());

        Self {
            config,
            model_id,
            vision_encoder,
            language_encoder,
            graph_encoder,
            multimodal_transformer,
            meta_learner,
            vision_embeddings: HashMap::new(),
            language_embeddings: HashMap::new(),
            graph_embeddings: HashMap::new(),
            unified_embeddings: HashMap::new(),
            training_stats: None,
            is_trained: false,
        }
    }

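    /// Fuse any combination of image, text, and graph inputs into a single
    /// unified embedding; modalities that are not supplied fall back to zero
    /// vectors of the configured dimensionality.
    ///
    /// A minimal usage sketch (not compiled here; assumes `config` was built
    /// like the `small_test_config()` in the tests below):
    ///
    /// ```ignore
    /// let mut model = VisionLanguageGraphModel::new(config);
    /// // Text-only query: the vision and graph slots are zero-filled.
    /// let emb = model
    ///     .generate_unified_embedding(None, Some("a red cube on a table"), None)
    ///     .await?;
    /// ```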
    pub async fn generate_unified_embedding(
        &mut self,
        image: Option<&Array3<f32>>,
        text: Option<&str>,
        graph_data: Option<(&Array2<f32>, &Array2<f32>, &Array2<f32>)>,
    ) -> Result<Array1<f32>> {
        // Encode each available modality; missing ones fall back to zeros.
        let vision_emb = if let Some(img) = image {
            let emb = self.vision_encoder.encode_image(img)?;
            self.vision_embeddings
                .insert("current_image".to_string(), emb.clone());
            emb
        } else {
            Array1::zeros(self.config.vision_config.vision_dim)
        };

        let language_emb = if let Some(txt) = text {
            let emb = self.language_encoder.encode_text(txt)?;
            self.language_embeddings
                .insert("current_text".to_string(), emb.clone());
            emb
        } else {
            Array1::zeros(self.config.language_config.language_dim)
        };

        let graph_emb = if let Some((nodes, edges, adj)) = graph_data {
            let emb = self.graph_encoder.encode_graph(nodes, edges, adj)?;
            self.graph_embeddings
                .insert("current_graph".to_string(), emb.clone());
            emb
        } else {
            Array1::zeros(self.config.graph_config.graph_dim)
        };

        // Fuse the three modality embeddings into a unified representation.
        let unified_emb =
            self.multimodal_transformer
                .fuse_embeddings(&vision_emb, &language_emb, &graph_emb)?;

        self.unified_embeddings
            .insert("current_unified".to_string(), unified_emb.clone());

        Ok(unified_emb)
    }

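    /// Zero-shot classification: return the class whose prototype embedding
    /// has the highest cosine similarity to the query embedding. Returns an
    /// error if `class_prototypes` is empty.
    ///
    /// A minimal sketch (not compiled here; the prototype values are
    /// hypothetical placeholder embeddings):
    ///
    /// ```ignore
    /// let mut prototypes = HashMap::new();
    /// prototypes.insert("cat".to_string(), cat_prototype);
    /// prototypes.insert("dog".to_string(), dog_prototype);
    /// let best_class = model.zero_shot_predict(&query, &prototypes)?;
    /// ```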
    pub fn zero_shot_predict(
        &self,
        query_embedding: &Array1<f32>,
        class_prototypes: &HashMap<String, Array1<f32>>,
    ) -> Result<String> {
        if class_prototypes.is_empty() {
            return Err(anyhow!("no class prototypes provided"));
        }

        let mut best_class = String::new();
        let mut best_score = f32::NEG_INFINITY;

        for (class_name, prototype) in class_prototypes {
            let score = self.cosine_similarity(query_embedding, prototype);
            if score > best_score {
                best_score = score;
                best_class = class_name.clone();
            }
        }

        Ok(best_class)
    }

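    /// Few-shot adaptation: run the meta-learner on the support set, then
    /// label each query by its nearest support example in Euclidean distance.
    ///
    /// A minimal sketch (not compiled here; the embeddings are hypothetical
    /// placeholders):
    ///
    /// ```ignore
    /// let support = vec![(cat_emb, "cat".to_string()), (dog_emb, "dog".to_string())];
    /// let labels = model.few_shot_adapt(&support, &[query_emb])?;
    /// ```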
    pub fn few_shot_adapt(
        &mut self,
        support_examples: &[(Array1<f32>, String)],
        query_examples: &[Array1<f32>],
    ) -> Result<Vec<String>> {
        // Encode each support label as a crude one-element embedding (its
        // string length); this stands in for a real label encoder.
        let support_set: Vec<(Array1<f32>, Array1<f32>)> = support_examples
            .iter()
            .map(|(emb, label)| {
                let label_emb = Array1::from_vec(vec![label.len() as f32]);
                (emb.clone(), label_emb)
            })
            .collect();

        // Queries are unlabeled, so pair them with zero placeholders.
        let query_set: Vec<(Array1<f32>, Array1<f32>)> = query_examples
            .iter()
            .map(|emb| (emb.clone(), Array1::zeros(1)))
            .collect();

        // Adapt the meta-learner to this task; the adapted parameters are
        // not consumed by the nearest-neighbor classifier below.
        let _adapted_params = self.meta_learner.adapt_to_task(&support_set, &query_set)?;

        // Classify each query by its nearest support example.
        let mut predictions = Vec::new();

        for query_emb in query_examples {
            let mut best_label = String::new();
            let mut best_distance = f32::INFINITY;

            for (support_emb, label) in support_examples {
                let distance = self.euclidean_distance(query_emb, support_emb);
                if distance < best_distance {
                    best_distance = distance;
                    best_label = label.clone();
                }
            }

            predictions.push(best_label);
        }

        Ok(predictions)
    }

    /// Cosine similarity between two vectors; returns 0.0 if either has zero norm.
    fn cosine_similarity(&self, a: &Array1<f32>, b: &Array1<f32>) -> f32 {
        let dot_product = a.dot(b);
        let norm_a = a.dot(a).sqrt();
        let norm_b = b.dot(b).sqrt();

        if norm_a > 0.0 && norm_b > 0.0 {
            dot_product / (norm_a * norm_b)
        } else {
            0.0
        }
    }

    /// Euclidean (L2) distance between two vectors.
    fn euclidean_distance(&self, a: &Array1<f32>, b: &Array1<f32>) -> f32 {
        let diff = a - b;
        diff.dot(&diff).sqrt()
    }
}

/// Aggregate statistics for a vision-language-graph model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VisionLanguageGraphStats {
    pub num_vision_samples: usize,
    pub num_language_samples: usize,
    pub num_graph_samples: usize,
    pub num_unified_embeddings: usize,
    pub vision_dim: usize,
    pub language_dim: usize,
    pub graph_dim: usize,
    pub unified_dim: usize,
    pub zero_shot_accuracy: f32,
    pub few_shot_accuracy: f32,
    pub cross_modal_alignment_score: f32,
}

impl Default for VisionLanguageGraphStats {
    fn default() -> Self {
        Self {
            num_vision_samples: 0,
            num_language_samples: 0,
            num_graph_samples: 0,
            num_unified_embeddings: 0,
            vision_dim: 768,
            language_dim: 768,
            graph_dim: 512,
            unified_dim: 768,
            zero_shot_accuracy: 0.0,
            few_shot_accuracy: 0.0,
            cross_modal_alignment_score: 0.0,
        }
    }
}

#[async_trait]
impl EmbeddingModel for VisionLanguageGraphModel {
    fn config(&self) -> &ModelConfig {
        &self.config.base_config
    }

    fn model_id(&self) -> &Uuid {
        &self.model_id
    }

    fn model_type(&self) -> &'static str {
        "VisionLanguageGraphModel"
    }

    fn add_triple(&mut self, _triple: Triple) -> Result<()> {
        // Triples are not stored; this model builds embeddings from raw
        // multimodal inputs rather than from an explicit triple store.
        Ok(())
    }

    async fn train(&mut self, epochs: Option<usize>) -> Result<TrainingStats> {
        let epochs = epochs.unwrap_or(self.config.base_config.max_epochs);
        let start_time = std::time::Instant::now();

        let mut loss_history = Vec::new();

        for epoch in 0..epochs {
            let epoch_loss = self.train_epoch().await?;
            loss_history.push(epoch_loss);

            // Early stopping once the loss has effectively converged.
            if epoch > 10 && epoch_loss < 1e-4 {
                break;
            }
        }

        let training_time = start_time.elapsed().as_secs_f64();
        let final_loss = loss_history.last().copied().unwrap_or(0.0);

        let stats = TrainingStats {
            epochs_completed: loss_history.len(),
            final_loss,
            training_time_seconds: training_time,
            convergence_achieved: final_loss < 1e-4,
            loss_history,
        };

        self.training_stats = Some(stats.clone());
        self.is_trained = true;

        Ok(stats)
    }

    fn get_entity_embedding(&self, entity: &str) -> Result<Vector> {
        if let Some(embedding) = self.unified_embeddings.get(entity) {
            Ok(Vector::new(embedding.to_vec()))
        } else {
            Err(anyhow!("Entity not found: {}", entity))
        }
    }

    fn get_relation_embedding(&self, relation: &str) -> Result<Vector> {
        if let Some(embedding) = self.unified_embeddings.get(relation) {
            Ok(Vector::new(embedding.to_vec()))
        } else {
            Err(anyhow!("Relation not found: {}", relation))
        }
    }

    fn score_triple(&self, subject: &str, predicate: &str, object: &str) -> Result<f64> {
        let subject_emb = self.get_entity_embedding(subject)?;
        let predicate_emb = self.get_relation_embedding(predicate)?;
        let object_emb = self.get_entity_embedding(object)?;

        let subject_arr = Array1::from_vec(subject_emb.values);
        let predicate_arr = Array1::from_vec(predicate_emb.values);
        let object_arr = Array1::from_vec(object_emb.values);

        // TransE-style translational scoring: the score is the negative
        // distance between (subject + predicate) and object.
        let predicted = &subject_arr + &predicate_arr;
        let diff = &predicted - &object_arr;
        let distance = diff.dot(&diff).sqrt();

        Ok(-(distance as f64))
    }

    fn predict_objects(
        &self,
        subject: &str,
        predicate: &str,
        k: usize,
    ) -> Result<Vec<(String, f64)>> {
        let mut scores = Vec::new();

        for entity in self.unified_embeddings.keys() {
            if entity != subject {
                let score = self.score_triple(subject, predicate, entity)?;
                scores.push((entity.clone(), score));
            }
        }

        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        scores.truncate(k);

        Ok(scores)
    }

    fn predict_subjects(
        &self,
        predicate: &str,
        object: &str,
        k: usize,
    ) -> Result<Vec<(String, f64)>> {
        let mut scores = Vec::new();

        for entity in self.unified_embeddings.keys() {
            if entity != object {
                let score = self.score_triple(entity, predicate, object)?;
                scores.push((entity.clone(), score));
            }
        }

        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        scores.truncate(k);

        Ok(scores)
    }

    fn predict_relations(
        &self,
        subject: &str,
        object: &str,
        k: usize,
    ) -> Result<Vec<(String, f64)>> {
        let mut scores = Vec::new();

        for relation in self.unified_embeddings.keys() {
            let score = self.score_triple(subject, relation, object)?;
            scores.push((relation.clone(), score));
        }

        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        scores.truncate(k);

        Ok(scores)
    }

    fn get_entities(&self) -> Vec<String> {
        self.unified_embeddings.keys().cloned().collect()
    }

    fn get_relations(&self) -> Vec<String> {
        // Entities and relations share the unified embedding space, so the
        // same keys are reported for both.
        self.unified_embeddings.keys().cloned().collect()
    }

    fn get_stats(&self) -> ModelStats {
        ModelStats {
            num_entities: self.unified_embeddings.len(),
            num_relations: self.unified_embeddings.len(),
            num_triples: 0,
            dimensions: self.config.transformer_config.unified_dim,
            is_trained: self.is_trained,
            model_type: self.model_type().to_string(),
            creation_time: Utc::now(),
            last_training_time: if self.is_trained {
                Some(Utc::now())
            } else {
                None
            },
        }
    }

    fn save(&self, _path: &str) -> Result<()> {
        // Persistence is not implemented yet; this is a no-op.
        Ok(())
    }

    fn load(&mut self, _path: &str) -> Result<()> {
        // Persistence is not implemented yet; this is a no-op.
        Ok(())
    }

    fn clear(&mut self) {
        self.vision_embeddings.clear();
        self.language_embeddings.clear();
        self.graph_embeddings.clear();
        self.unified_embeddings.clear();
        self.is_trained = false;
        self.training_stats = None;
    }

    fn is_trained(&self) -> bool {
        self.is_trained
    }

    async fn encode(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
        let mut results = Vec::new();

        for text in texts {
            let embedding = self.language_encoder.encode_text(text)?;
            results.push(embedding.to_vec());
        }

        Ok(results)
    }
}

impl VisionLanguageGraphModel {
    async fn train_epoch(&mut self) -> Result<f64> {
        // Placeholder epoch: the per-modality and fusion losses are
        // simulated with small random values rather than computed from data.
        let mut random = Random::default();
        let vision_loss = 0.1 * random.random::<f64>();
        let language_loss = 0.1 * random.random::<f64>();
        let graph_loss = 0.1 * random.random::<f64>();
        let fusion_loss = 0.1 * random.random::<f64>();

        let total_loss = vision_loss + language_loss + graph_loss + fusion_loss;

        Ok(total_loss)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Build a deliberately small configuration so tests run quickly.
    fn small_test_config() -> VisionLanguageGraphConfig {
        VisionLanguageGraphConfig {
            vision_config: VisionEncoderConfig {
                image_size: (32, 32),
                channels: 3,
                patch_size: (8, 8),
                vision_dim: 32,
                cnn_config: CNNConfig {
                    num_layers: 2,
                    filter_sizes: vec![8, 16],
                    stride_sizes: vec![2, 2],
                    ..CNNConfig::default()
                },
                vit_config: ViTConfig {
                    num_layers: 2,
                    num_heads: 2,
                    mlp_dim: 64,
                    ..ViTConfig::default()
                },
                ..VisionEncoderConfig::default()
            },
            language_config: LanguageEncoderConfig {
                vocab_size: 256,
                language_dim: 32,
                max_seq_length: 16,
                transformer_config: LanguageTransformerConfig {
                    num_layers: 2,
                    num_heads: 2,
                    hidden_dim: 32,
                    intermediate_dim: 64,
                    ..LanguageTransformerConfig::default()
                },
                ..LanguageEncoderConfig::default()
            },
            graph_config: GraphEncoderConfig {
                node_dim: 16,
                edge_dim: 8,
                graph_dim: 32,
                num_layers: 2,
                ..GraphEncoderConfig::default()
            },
            transformer_config: MultiModalTransformerConfig {
                unified_dim: 32,
                num_fusion_layers: 2,
                cross_attention_config: CrossAttentionConfig {
                    num_heads: 2,
                    head_dim: 16,
                    ..CrossAttentionConfig::default()
                },
                ..MultiModalTransformerConfig::default()
            },
            ..VisionLanguageGraphConfig::default()
        }
    }

    #[test]
    fn test_vision_language_graph_config_default() {
        let config = VisionLanguageGraphConfig::default();
        assert_eq!(config.vision_config.vision_dim, 768);
        assert_eq!(config.language_config.language_dim, 768);
        assert_eq!(config.graph_config.graph_dim, 768);
    }

    #[test]
    fn test_vision_encoder_creation() {
        let config = VisionEncoderConfig::default();
        let encoder = VisionEncoder::new(config);
        assert!(!encoder.cnn_parameters.is_empty());
        assert!(!encoder.vit_parameters.is_empty());
    }

    #[test]
    fn test_language_encoder_creation() {
        let config = LanguageEncoderConfig {
            vocab_size: 256,
            language_dim: 32,
            max_seq_length: 16,
            transformer_config: LanguageTransformerConfig {
                num_layers: 2,
                num_heads: 2,
                hidden_dim: 32,
                intermediate_dim: 64,
                ..LanguageTransformerConfig::default()
            },
            ..LanguageEncoderConfig::default()
        };
        let encoder = LanguageEncoder::new(config);
        assert_eq!(encoder.token_embeddings.nrows(), 256);
        assert_eq!(encoder.position_embeddings.nrows(), 16);
    }

    #[test]
    fn test_graph_encoder_creation() {
        let config = GraphEncoderConfig::default();
        let encoder = GraphEncoder::new(config);
        assert!(!encoder.node_parameters.is_empty());
        assert!(!encoder.edge_parameters.is_empty());
    }

    #[test]
    fn test_multimodal_transformer_creation() {
        let config = MultiModalTransformerConfig::default();
        let transformer = MultiModalTransformer::new(config);
        assert!(!transformer.cross_attention_params.is_empty());
        assert!(!transformer.fusion_params.is_empty());
    }

    #[test]
    #[cfg_attr(debug_assertions, ignore = "Model initialization slow in debug builds")]
    fn test_vision_language_graph_model_creation() {
        let config = VisionLanguageGraphConfig::default();
        let model = VisionLanguageGraphModel::new(config);
        assert!(!model.is_trained);
        assert_eq!(model.unified_embeddings.len(), 0);
    }

    #[test]
    fn test_vision_encoder_image_encoding() {
        let config = VisionEncoderConfig::default();
        let encoder = VisionEncoder::new(config);

        let mut random = Random::default();
        let image = Array3::from_shape_fn((224, 224, 3), |_| random.random::<f32>());
        let embedding = encoder.encode_image(&image).unwrap();

        assert_eq!(embedding.len(), encoder.config.vision_dim);
    }

    #[test]
    fn test_language_encoder_text_encoding() {
        let config = LanguageEncoderConfig {
            vocab_size: 256,
            language_dim: 32,
            max_seq_length: 16,
            transformer_config: LanguageTransformerConfig {
                num_layers: 2,
                num_heads: 2,
                hidden_dim: 32,
                intermediate_dim: 64,
                ..LanguageTransformerConfig::default()
            },
            ..LanguageEncoderConfig::default()
        };
        let encoder = LanguageEncoder::new(config);

        let text = "Hello world, this is a test";
        let embedding = encoder
            .encode_text(text)
            .expect("encode_text should succeed");

        assert_eq!(embedding.len(), encoder.config.language_dim);
    }

    #[test]
    fn test_graph_encoder_graph_encoding() {
        let config = GraphEncoderConfig::default();
        let node_dim = config.node_dim;
        let edge_dim = config.edge_dim;
        let encoder = GraphEncoder::new(config);

        let mut random = Random::default();
        let node_features = Array2::from_shape_fn((5, node_dim), |_| random.random::<f32>());
        let edge_features = Array2::from_shape_fn((10, edge_dim), |_| random.random::<f32>());
        let adjacency = Array2::eye(5);

        let embedding = encoder
            .encode_graph(&node_features, &edge_features, &adjacency)
            .unwrap();

        assert_eq!(embedding.len(), encoder.config.graph_dim);
    }

    #[tokio::test]
    #[cfg_attr(debug_assertions, ignore = "Embedding tests require release builds")]
    async fn test_unified_embedding_generation() {
        let config = VisionLanguageGraphConfig::default();
        let mut model = VisionLanguageGraphModel::new(config);

        let mut random = Random::default();
        let image = Array3::from_shape_fn((224, 224, 3), |_| random.random::<f32>());
        let text = "A beautiful landscape with mountains";
        let node_features = Array2::from_shape_fn((3, 256), |_| random.random::<f32>());
        let edge_features = Array2::from_shape_fn((6, 128), |_| random.random::<f32>());
        let adjacency = Array2::eye(3);

        let unified_embedding = model
            .generate_unified_embedding(
                Some(&image),
                Some(text),
                Some((&node_features, &edge_features, &adjacency)),
            )
            .await
            .unwrap();

        assert!(!unified_embedding.is_empty());
        assert_eq!(model.vision_embeddings.len(), 1);
        assert_eq!(model.language_embeddings.len(), 1);
        assert_eq!(model.graph_embeddings.len(), 1);
        assert_eq!(model.unified_embeddings.len(), 1);
    }

    #[test]
    fn test_zero_shot_prediction() {
        let config = small_test_config();
        let model = VisionLanguageGraphModel::new(config);

        let mut random = Random::default();
        let query = Array1::from_shape_fn(32, |_| random.random::<f32>());

        let mut prototypes = HashMap::new();
        prototypes.insert(
            "class1".to_string(),
            Array1::from_shape_fn(32, |_| random.random::<f32>()),
        );
        prototypes.insert(
            "class2".to_string(),
            Array1::from_shape_fn(32, |_| random.random::<f32>()),
        );

        let prediction = model
            .zero_shot_predict(&query, &prototypes)
            .expect("zero_shot_predict should succeed");
        assert!(prototypes.contains_key(&prediction));
    }

    #[test]
    #[cfg_attr(debug_assertions, ignore = "Embedding tests require release builds")]
    fn test_few_shot_adaptation() {
        let config = VisionLanguageGraphConfig::default();
        let mut model = VisionLanguageGraphModel::new(config);

        let mut random = Random::default();
        let support_examples = vec![
            (
                Array1::from_shape_fn(512, |_| random.random::<f32>()),
                "cat".to_string(),
            ),
            (
                Array1::from_shape_fn(512, |_| random.random::<f32>()),
                "dog".to_string(),
            ),
        ];

        let query_examples = vec![
            Array1::from_shape_fn(512, |_| random.random::<f32>()),
            Array1::from_shape_fn(512, |_| random.random::<f32>()),
        ];

        let predictions = model
            .few_shot_adapt(&support_examples, &query_examples)
            .unwrap();
        assert_eq!(predictions.len(), 2);
    }

    #[test]
    fn test_meta_learner_adaptation() {
        let config = MetaLearningConfig::default();
        let mut meta_learner = MetaLearner::new(config);

        let mut random = Random::default();
        let support_set = vec![
            (
                Array1::from_shape_fn(512, |_| random.random::<f32>()),
                Array1::from_vec(vec![1.0]),
            ),
            (
                Array1::from_shape_fn(512, |_| random.random::<f32>()),
                Array1::from_vec(vec![0.0]),
            ),
        ];

        let query_set = vec![];

        let adapted_params = meta_learner
            .adapt_to_task(&support_set, &query_set)
            .unwrap();
        assert!(!adapted_params.is_empty());
    }

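    // Deterministic sanity checks for the private similarity/distance
    // helpers; the expected values follow directly from their formulas.
    #[test]
    fn test_similarity_and_distance_helpers() {
        let model = VisionLanguageGraphModel::new(small_test_config());

        let a = Array1::from_vec(vec![1.0, 0.0]);
        let b = Array1::from_vec(vec![0.0, 1.0]);
        let zero = Array1::from_vec(vec![0.0, 0.0]);

        // Identical vectors have similarity 1; orthogonal vectors 0.
        assert!((model.cosine_similarity(&a, &a) - 1.0).abs() < 1e-6);
        assert!(model.cosine_similarity(&a, &b).abs() < 1e-6);
        // A zero vector short-circuits to 0.0 instead of dividing by zero.
        assert_eq!(model.cosine_similarity(&a, &zero), 0.0);

        // 3-4-5 triangle.
        let p = Array1::from_vec(vec![0.0, 0.0]);
        let q = Array1::from_vec(vec![3.0, 4.0]);
        assert!((model.euclidean_distance(&p, &q) - 5.0).abs() < 1e-6);
    }
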
    #[tokio::test]
    async fn test_vision_language_graph_training() {
        let config = small_test_config();
        let mut model = VisionLanguageGraphModel::new(config);

        let stats = model.train(Some(3)).await.expect("training should succeed");
        assert_eq!(stats.epochs_completed, 3);
        assert!(model.is_trained());
    }

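    // Checks the TransE-style scorer on hand-built embeddings: when
    // subject + predicate equals object exactly, the negative-distance
    // score should be (approximately) zero, its maximum possible value.
    #[test]
    fn test_score_triple_translational_scoring() {
        let mut model = VisionLanguageGraphModel::new(small_test_config());

        model
            .unified_embeddings
            .insert("s".to_string(), Array1::from_vec(vec![1.0, 0.0]));
        model
            .unified_embeddings
            .insert("p".to_string(), Array1::from_vec(vec![0.0, 1.0]));
        model
            .unified_embeddings
            .insert("o".to_string(), Array1::from_vec(vec![1.0, 1.0]));

        let score = model.score_triple("s", "p", "o").unwrap();
        assert!(score.abs() < 1e-6);
    }
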
    #[tokio::test]
    #[cfg_attr(debug_assertions, ignore = "Embedding tests require release builds")]
    async fn test_vision_language_graph_encoding() {
        let config = VisionLanguageGraphConfig::default();
        let expected_dim = config.language_config.language_dim;
        let model = VisionLanguageGraphModel::new(config);

        let texts = vec!["hello world".to_string(), "test encoding".to_string()];
        let embeddings = model.encode(&texts).await.unwrap();

        assert_eq!(embeddings.len(), 2);
        assert_eq!(embeddings[0].len(), expected_dim);
    }
}