1use anyhow::{anyhow, Result};
46use serde::{Deserialize, Serialize};
47use std::collections::HashMap;
48use tracing::{debug, info};
49
50#[derive(Debug, Clone, Serialize, Deserialize)]
52pub struct DatasetCharacteristics {
53 pub num_entities: usize,
55 pub num_relations: usize,
57 pub num_triples: usize,
59 pub avg_degree: f64,
61 pub is_sparse: bool,
63 pub has_hierarchies: bool,
65 pub has_complex_relations: bool,
67 pub domain: Option<String>,
69}
70
71impl DatasetCharacteristics {
72 pub fn infer(num_entities: usize, num_relations: usize, num_triples: usize) -> Self {
74 let avg_degree = if num_entities > 0 {
75 (num_triples as f64 * 2.0) / num_entities as f64
76 } else {
77 0.0
78 };
79
80 let is_sparse = avg_degree < (num_entities as f64).sqrt();
81
82 Self {
83 num_entities,
84 num_relations,
85 num_triples,
86 avg_degree,
87 is_sparse,
88 has_hierarchies: false, has_complex_relations: num_relations > 10,
90 domain: None,
91 }
92 }
93
94 pub fn density(&self) -> f64 {
96 if self.num_entities == 0 {
97 return 0.0;
98 }
99 let max_possible = (self.num_entities * (self.num_entities - 1)) as f64;
100 if max_possible == 0.0 {
101 return 0.0;
102 }
103 self.num_triples as f64 / max_possible
104 }
105
106 pub fn estimated_memory_mb(&self, embedding_dim: usize) -> f64 {
108 let entity_mem = (self.num_entities * embedding_dim * 4) as f64 / 1_048_576.0; let relation_mem = (self.num_relations * embedding_dim * 4) as f64 / 1_048_576.0;
111 let overhead = 50.0; entity_mem + relation_mem + overhead
114 }
115}
116
117#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
119pub enum UseCaseType {
120 LinkPrediction,
122 EntityClassification,
124 RelationExtraction,
126 QuestionAnswering,
128 KGCompletion,
130 SimilaritySearch,
132 GeneralPurpose,
134}
135
136#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
138pub enum ModelType {
139 TransE,
140 DistMult,
141 ComplEx,
142 RotatE,
143 HolE,
144 ConvE,
145 TuckER,
146 QuatD,
147 GNN,
148 Transformer,
149}
150
151impl std::fmt::Display for ModelType {
152 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
153 match self {
154 ModelType::TransE => write!(f, "TransE"),
155 ModelType::DistMult => write!(f, "DistMult"),
156 ModelType::ComplEx => write!(f, "ComplEx"),
157 ModelType::RotatE => write!(f, "RotatE"),
158 ModelType::HolE => write!(f, "HolE"),
159 ModelType::ConvE => write!(f, "ConvE"),
160 ModelType::TuckER => write!(f, "TuckER"),
161 ModelType::QuatD => write!(f, "QuatD"),
162 ModelType::GNN => write!(f, "GNN"),
163 ModelType::Transformer => write!(f, "Transformer"),
164 }
165 }
166}
167
168#[derive(Debug, Clone, Serialize, Deserialize)]
170pub struct ModelRecommendation {
171 pub model_type: ModelType,
172 pub suitability_score: f64,
173 pub reasoning: String,
174 pub pros: Vec<String>,
175 pub cons: Vec<String>,
176 pub recommended_dimensions: usize,
177 pub estimated_training_time: TrainingTime,
178 pub memory_requirement: MemoryRequirement,
179}
180
181#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
183pub enum TrainingTime {
184 Fast, Medium, Slow, VerySlow, }
189
190impl std::fmt::Display for TrainingTime {
191 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
192 match self {
193 TrainingTime::Fast => write!(f, "Fast (< 5 min)"),
194 TrainingTime::Medium => write!(f, "Medium (5-30 min)"),
195 TrainingTime::Slow => write!(f, "Slow (30-60 min)"),
196 TrainingTime::VerySlow => write!(f, "Very Slow (> 1 hour)"),
197 }
198 }
199}
200
201#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
203pub enum MemoryRequirement {
204 Low, Medium, High, VeryHigh, }
209
210impl std::fmt::Display for MemoryRequirement {
211 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
212 match self {
213 MemoryRequirement::Low => write!(f, "Low (< 500 MB)"),
214 MemoryRequirement::Medium => write!(f, "Medium (500 MB - 2 GB)"),
215 MemoryRequirement::High => write!(f, "High (2 GB - 8 GB)"),
216 MemoryRequirement::VeryHigh => write!(f, "Very High (> 8 GB)"),
217 }
218 }
219}
220
221pub struct ModelSelector {
223 model_profiles: HashMap<ModelType, ModelProfile>,
224}
225
226#[derive(Debug, Clone)]
228struct ModelProfile {
229 model_type: ModelType,
230 strengths: Vec<String>,
232 weaknesses: Vec<String>,
234 best_for: Vec<UseCaseType>,
236 complexity: u8,
238 speed: u8,
240 accuracy: u8,
242 handles_sparse: bool,
244 handles_hierarchies: bool,
246 handles_complex_relations: bool,
248}
249
250impl Default for ModelSelector {
251 fn default() -> Self {
252 Self::new()
253 }
254}
255
256impl ModelSelector {
257 pub fn new() -> Self {
259 let mut model_profiles = HashMap::new();
260
261 model_profiles.insert(
263 ModelType::TransE,
264 ModelProfile {
265 model_type: ModelType::TransE,
266 strengths: vec![
267 "Simple and efficient".to_string(),
268 "Good for hierarchical relations".to_string(),
269 "Fast training".to_string(),
270 ],
271 weaknesses: vec![
272 "Cannot model symmetric relations well".to_string(),
273 "Limited expressiveness".to_string(),
274 ],
275 best_for: vec![UseCaseType::LinkPrediction, UseCaseType::GeneralPurpose],
276 complexity: 2,
277 speed: 9,
278 accuracy: 6,
279 handles_sparse: true,
280 handles_hierarchies: true,
281 handles_complex_relations: false,
282 },
283 );
284
285 model_profiles.insert(
287 ModelType::DistMult,
288 ModelProfile {
289 model_type: ModelType::DistMult,
290 strengths: vec![
291 "Very fast".to_string(),
292 "Good for symmetric relations".to_string(),
293 "Low memory footprint".to_string(),
294 ],
295 weaknesses: vec![
296 "Cannot model asymmetric relations".to_string(),
297 "Cannot capture composition".to_string(),
298 ],
299 best_for: vec![
300 UseCaseType::SimilaritySearch,
301 UseCaseType::EntityClassification,
302 ],
303 complexity: 1,
304 speed: 10,
305 accuracy: 5,
306 handles_sparse: true,
307 handles_hierarchies: false,
308 handles_complex_relations: false,
309 },
310 );
311
312 model_profiles.insert(
314 ModelType::ComplEx,
315 ModelProfile {
316 model_type: ModelType::ComplEx,
317 strengths: vec![
318 "Handles symmetric and asymmetric relations".to_string(),
319 "Good theoretical properties".to_string(),
320 "State-of-the-art performance".to_string(),
321 ],
322 weaknesses: vec![
323 "More complex than TransE".to_string(),
324 "Requires more memory".to_string(),
325 ],
326 best_for: vec![UseCaseType::LinkPrediction, UseCaseType::KGCompletion],
327 complexity: 5,
328 speed: 7,
329 accuracy: 8,
330 handles_sparse: true,
331 handles_hierarchies: true,
332 handles_complex_relations: true,
333 },
334 );
335
336 model_profiles.insert(
338 ModelType::RotatE,
339 ModelProfile {
340 model_type: ModelType::RotatE,
341 strengths: vec![
342 "Excellent for complex relations".to_string(),
343 "Handles composition patterns".to_string(),
344 "Strong theoretical foundation".to_string(),
345 ],
346 weaknesses: vec![
347 "Slower than simpler models".to_string(),
348 "Higher memory usage".to_string(),
349 ],
350 best_for: vec![UseCaseType::LinkPrediction, UseCaseType::RelationExtraction],
351 complexity: 6,
352 speed: 6,
353 accuracy: 9,
354 handles_sparse: true,
355 handles_hierarchies: true,
356 handles_complex_relations: true,
357 },
358 );
359
360 model_profiles.insert(
362 ModelType::HolE,
363 ModelProfile {
364 model_type: ModelType::HolE,
365 strengths: vec![
366 "Memory efficient".to_string(),
367 "Good compositional properties".to_string(),
368 "Fast inference".to_string(),
369 ],
370 weaknesses: vec![
371 "Training can be slower".to_string(),
372 "Less intuitive than TransE".to_string(),
373 ],
374 best_for: vec![UseCaseType::KGCompletion, UseCaseType::LinkPrediction],
375 complexity: 5,
376 speed: 7,
377 accuracy: 7,
378 handles_sparse: true,
379 handles_hierarchies: false,
380 handles_complex_relations: true,
381 },
382 );
383
384 model_profiles.insert(
386 ModelType::ConvE,
387 ModelProfile {
388 model_type: ModelType::ConvE,
389 strengths: vec![
390 "State-of-the-art accuracy".to_string(),
391 "Captures complex patterns".to_string(),
392 "Scalable to large graphs".to_string(),
393 ],
394 weaknesses: vec![
395 "Requires more computational resources".to_string(),
396 "More complex to tune".to_string(),
397 "Slower training".to_string(),
398 ],
399 best_for: vec![UseCaseType::LinkPrediction, UseCaseType::KGCompletion],
400 complexity: 8,
401 speed: 4,
402 accuracy: 9,
403 handles_sparse: false,
404 handles_hierarchies: true,
405 handles_complex_relations: true,
406 },
407 );
408
409 model_profiles.insert(
411 ModelType::GNN,
412 ModelProfile {
413 model_type: ModelType::GNN,
414 strengths: vec![
415 "Leverages graph structure".to_string(),
416 "Good for node classification".to_string(),
417 "Captures neighborhood information".to_string(),
418 ],
419 weaknesses: vec![
420 "Computationally expensive".to_string(),
421 "Not ideal for very large graphs".to_string(),
422 ],
423 best_for: vec![
424 UseCaseType::EntityClassification,
425 UseCaseType::QuestionAnswering,
426 ],
427 complexity: 7,
428 speed: 5,
429 accuracy: 8,
430 handles_sparse: false,
431 handles_hierarchies: true,
432 handles_complex_relations: true,
433 },
434 );
435
436 model_profiles.insert(
438 ModelType::Transformer,
439 ModelProfile {
440 model_type: ModelType::Transformer,
441 strengths: vec![
442 "Excellent for complex patterns".to_string(),
443 "State-of-the-art on many tasks".to_string(),
444 "Flexible architecture".to_string(),
445 ],
446 weaknesses: vec![
447 "Very computationally expensive".to_string(),
448 "Requires large amounts of data".to_string(),
449 "High memory usage".to_string(),
450 ],
451 best_for: vec![UseCaseType::QuestionAnswering, UseCaseType::GeneralPurpose],
452 complexity: 9,
453 speed: 3,
454 accuracy: 9,
455 handles_sparse: false,
456 handles_hierarchies: true,
457 handles_complex_relations: true,
458 },
459 );
460
461 Self { model_profiles }
462 }
463
464 pub fn recommend_models(
466 &self,
467 characteristics: &DatasetCharacteristics,
468 use_case: UseCaseType,
469 ) -> Result<Vec<ModelRecommendation>> {
470 info!(
471 "Recommending models for dataset with {} entities, {} relations, {} triples",
472 characteristics.num_entities,
473 characteristics.num_relations,
474 characteristics.num_triples
475 );
476
477 let mut recommendations = Vec::new();
478
479 for (model_type, profile) in &self.model_profiles {
480 let score = self.calculate_suitability_score(profile, characteristics, use_case);
481
482 if score > 0.3 {
483 let recommendation = self.create_recommendation(
485 *model_type,
486 profile,
487 characteristics,
488 score,
489 use_case,
490 );
491 recommendations.push(recommendation);
492 }
493 }
494
495 recommendations.sort_by(|a, b| {
497 b.suitability_score
498 .partial_cmp(&a.suitability_score)
499 .unwrap_or(std::cmp::Ordering::Equal)
500 });
501
502 debug!("Generated {} model recommendations", recommendations.len());
503
504 Ok(recommendations)
505 }
506
507 fn calculate_suitability_score(
509 &self,
510 profile: &ModelProfile,
511 characteristics: &DatasetCharacteristics,
512 use_case: UseCaseType,
513 ) -> f64 {
514 let mut score: f64 = 0.5; if profile.best_for.contains(&use_case) {
518 score += 0.3;
519 }
520
521 if characteristics.is_sparse && profile.handles_sparse {
523 score += 0.1;
524 }
525
526 if characteristics.has_hierarchies && profile.handles_hierarchies {
527 score += 0.1;
528 }
529
530 if characteristics.has_complex_relations && profile.handles_complex_relations {
531 score += 0.1;
532 }
533
534 if characteristics.num_triples < 10000 && profile.complexity > 6 {
536 score -= 0.2;
537 }
538
539 if characteristics.num_triples > 100000 && profile.speed < 5 {
541 score -= 0.1;
542 }
543
544 if use_case == UseCaseType::LinkPrediction && profile.accuracy >= 8 {
546 score += 0.1;
547 }
548
549 score.clamp(0.0, 1.0)
551 }
552
553 fn create_recommendation(
555 &self,
556 model_type: ModelType,
557 profile: &ModelProfile,
558 characteristics: &DatasetCharacteristics,
559 score: f64,
560 use_case: UseCaseType,
561 ) -> ModelRecommendation {
562 let recommended_dimensions = self.recommend_dimensions(characteristics, profile);
563
564 let training_time =
565 self.estimate_training_time(characteristics, profile, recommended_dimensions);
566
567 let memory_requirement =
568 self.estimate_memory_requirement(characteristics, recommended_dimensions);
569
570 let reasoning = self.generate_reasoning(profile, characteristics, use_case);
571
572 ModelRecommendation {
573 model_type,
574 suitability_score: score,
575 reasoning,
576 pros: profile.strengths.clone(),
577 cons: profile.weaknesses.clone(),
578 recommended_dimensions,
579 estimated_training_time: training_time,
580 memory_requirement,
581 }
582 }
583
584 fn recommend_dimensions(
586 &self,
587 characteristics: &DatasetCharacteristics,
588 profile: &ModelProfile,
589 ) -> usize {
590 let base_dim = if characteristics.num_entities < 1000 {
591 32
592 } else if characteristics.num_entities < 10000 {
593 64
594 } else if characteristics.num_entities < 100000 {
595 128
596 } else {
597 256
598 };
599
600 if profile.complexity > 7 {
602 base_dim / 2 } else {
604 base_dim
605 }
606 }
607
608 fn estimate_training_time(
610 &self,
611 characteristics: &DatasetCharacteristics,
612 profile: &ModelProfile,
613 _dimensions: usize,
614 ) -> TrainingTime {
615 let data_size_factor = characteristics.num_triples as f64 / 50000.0;
616 let speed_factor = profile.speed as f64 / 10.0;
617
618 let estimated_minutes = data_size_factor / speed_factor * 10.0;
619
620 if estimated_minutes < 5.0 {
621 TrainingTime::Fast
622 } else if estimated_minutes < 30.0 {
623 TrainingTime::Medium
624 } else if estimated_minutes < 60.0 {
625 TrainingTime::Slow
626 } else {
627 TrainingTime::VerySlow
628 }
629 }
630
631 fn estimate_memory_requirement(
633 &self,
634 characteristics: &DatasetCharacteristics,
635 dimensions: usize,
636 ) -> MemoryRequirement {
637 let memory_mb = characteristics.estimated_memory_mb(dimensions);
638
639 if memory_mb < 500.0 {
640 MemoryRequirement::Low
641 } else if memory_mb < 2000.0 {
642 MemoryRequirement::Medium
643 } else if memory_mb < 8000.0 {
644 MemoryRequirement::High
645 } else {
646 MemoryRequirement::VeryHigh
647 }
648 }
649
650 fn generate_reasoning(
652 &self,
653 profile: &ModelProfile,
654 characteristics: &DatasetCharacteristics,
655 use_case: UseCaseType,
656 ) -> String {
657 let mut reasons = Vec::new();
658
659 if profile.best_for.contains(&use_case) {
660 reasons.push(format!("Well-suited for {:?}", use_case));
661 }
662
663 if characteristics.is_sparse && profile.handles_sparse {
664 reasons.push("Handles sparse graphs effectively".to_string());
665 }
666
667 if characteristics.has_hierarchies && profile.handles_hierarchies {
668 reasons.push("Good for hierarchical structures".to_string());
669 }
670
671 if characteristics.has_complex_relations && profile.handles_complex_relations {
672 reasons.push("Capable of modeling complex relations".to_string());
673 }
674
675 if profile.speed >= 8 {
676 reasons.push("Fast training and inference".to_string());
677 }
678
679 if profile.accuracy >= 8 {
680 reasons.push("High accuracy on benchmarks".to_string());
681 }
682
683 if reasons.is_empty() {
684 "General-purpose model".to_string()
685 } else {
686 reasons.join("; ")
687 }
688 }
689
690 pub fn compare_models(
692 &self,
693 models: &[ModelType],
694 characteristics: &DatasetCharacteristics,
695 ) -> Result<ModelComparison> {
696 if models.is_empty() {
697 return Err(anyhow!("No models provided for comparison"));
698 }
699
700 let mut comparisons = HashMap::new();
701
702 for model_type in models {
703 if let Some(profile) = self.model_profiles.get(model_type) {
704 let dimensions = self.recommend_dimensions(characteristics, profile);
705 let training_time =
706 self.estimate_training_time(characteristics, profile, dimensions);
707 let memory_req = self.estimate_memory_requirement(characteristics, dimensions);
708
709 comparisons.insert(
710 *model_type,
711 ModelComparisonEntry {
712 model_type: *model_type,
713 complexity: profile.complexity,
714 speed: profile.speed,
715 accuracy: profile.accuracy,
716 recommended_dimensions: dimensions,
717 estimated_training_time: training_time,
718 memory_requirement: memory_req,
719 },
720 );
721 }
722 }
723
724 Ok(ModelComparison {
725 models: comparisons,
726 dataset_size: characteristics.num_triples,
727 })
728 }
729}
730
731#[derive(Debug, Clone, Serialize, Deserialize)]
733pub struct ModelComparison {
734 pub models: HashMap<ModelType, ModelComparisonEntry>,
735 pub dataset_size: usize,
736}
737
738#[derive(Debug, Clone, Serialize, Deserialize)]
740pub struct ModelComparisonEntry {
741 pub model_type: ModelType,
742 pub complexity: u8,
743 pub speed: u8,
744 pub accuracy: u8,
745 pub recommended_dimensions: usize,
746 pub estimated_training_time: TrainingTime,
747 pub memory_requirement: MemoryRequirement,
748}
749
#[cfg(test)]
mod tests {
    use super::*;

    /// Sparse dataset (avg degree 10 < sqrt(10_000)) with many relation types.
    fn medium_dataset() -> DatasetCharacteristics {
        DatasetCharacteristics::infer(10000, 50, 50000)
    }

    #[test]
    fn test_dataset_characteristics_infer() {
        let chars = DatasetCharacteristics::infer(1000, 10, 5000);
        assert_eq!(chars.num_entities, 1000);
        assert_eq!(chars.num_relations, 10);
        assert_eq!(chars.num_triples, 5000);
        assert!(chars.avg_degree > 0.0);
    }

    #[test]
    fn test_dataset_density() {
        let chars = DatasetCharacteristics {
            num_entities: 100,
            num_relations: 5,
            num_triples: 500,
            avg_degree: 5.0,
            is_sparse: false,
            has_hierarchies: false,
            has_complex_relations: false,
            domain: None,
        };

        let density = chars.density();
        assert!(density > 0.0);
        assert!(density < 1.0);
    }

    #[test]
    fn test_density_of_degenerate_graphs() {
        // No entity pairs exist, so density is defined as zero.
        assert_eq!(DatasetCharacteristics::infer(0, 0, 0).density(), 0.0);
        assert_eq!(DatasetCharacteristics::infer(1, 1, 10).density(), 0.0);
    }

    #[test]
    fn test_model_selector_creation() {
        let selector = ModelSelector::new();
        assert!(!selector.model_profiles.is_empty());
        assert!(selector.model_profiles.contains_key(&ModelType::TransE));
        assert!(selector.model_profiles.contains_key(&ModelType::ComplEx));
    }

    #[test]
    fn test_model_recommendation() -> Result<()> {
        let selector = ModelSelector::new();
        let recommendations =
            selector.recommend_models(&medium_dataset(), UseCaseType::LinkPrediction)?;

        assert!(!recommendations.is_empty());

        // Sorted best-first.
        for pair in recommendations.windows(2) {
            assert!(pair[0].suitability_score >= pair[1].suitability_score);
        }

        // Only models above the 0.3 cut-off are returned, and scores are clamped to 1.0.
        for rec in &recommendations {
            assert!(rec.suitability_score > 0.3 && rec.suitability_score <= 1.0);
        }

        Ok(())
    }

    #[test]
    fn test_model_comparison() -> Result<()> {
        let selector = ModelSelector::new();
        let models = [ModelType::TransE, ModelType::ComplEx, ModelType::RotatE];
        let comparison = selector.compare_models(&models, &medium_dataset())?;

        assert_eq!(comparison.models.len(), 3);
        assert!(comparison.models.contains_key(&ModelType::TransE));
        assert!(comparison.models.contains_key(&ModelType::ComplEx));
        assert!(comparison.models.contains_key(&ModelType::RotatE));

        Ok(())
    }

    #[test]
    fn test_compare_models_rejects_empty_input() {
        let selector = ModelSelector::new();
        assert!(selector.compare_models(&[], &medium_dataset()).is_err());
    }

    #[test]
    fn test_small_dataset_recommendations() -> Result<()> {
        let selector = ModelSelector::new();
        let characteristics = DatasetCharacteristics::infer(100, 5, 500);

        let recommendations =
            selector.recommend_models(&characteristics, UseCaseType::GeneralPurpose)?;

        let top_model = recommendations
            .first()
            .expect("at least one model should be recommended for a small dataset");
        // Small graphs should not get large embeddings.
        assert!(top_model.recommended_dimensions <= 64);

        Ok(())
    }

    #[test]
    fn test_large_dataset_recommendations() -> Result<()> {
        let selector = ModelSelector::new();
        let characteristics = DatasetCharacteristics::infer(100000, 100, 500000);

        let recommendations =
            selector.recommend_models(&characteristics, UseCaseType::LinkPrediction)?;

        let top_model = recommendations
            .first()
            .expect("at least one model should be recommended for a large dataset");
        assert!(top_model.recommended_dimensions >= 64);

        Ok(())
    }

    #[test]
    fn test_memory_estimation() {
        let memory_mb = medium_dataset().estimated_memory_mb(128);

        assert!(memory_mb > 0.0);
        // Well below 10 GB for a 10k-entity graph.
        assert!(memory_mb < 10000.0);
    }

    #[test]
    fn test_use_case_specific_recommendations() -> Result<()> {
        let selector = ModelSelector::new();
        let characteristics = medium_dataset();

        let link_pred_recs =
            selector.recommend_models(&characteristics, UseCaseType::LinkPrediction)?;
        let similarity_recs =
            selector.recommend_models(&characteristics, UseCaseType::SimilaritySearch)?;

        assert!(!link_pred_recs.is_empty());
        assert!(!similarity_recs.is_empty());

        Ok(())
    }
}