oxirs_embed/
model_selection.rs

1//! # Model Selection Guidance
2//!
3//! This module provides intelligent model selection and recommendation capabilities
4//! to help users choose the most appropriate embedding model for their specific
5//! knowledge graph and use case.
6//!
7//! ## Features
8//!
9//! - **Automatic Model Recommendation**: Based on dataset characteristics
10//! - **Model Comparison**: Compare multiple models on the same dataset
11//! - **Performance Profiling**: Benchmark model performance
12//! - **Resource Requirements**: Estimate memory and compute needs
13//! - **Use Case Matching**: Recommend models for specific applications
14//!
15//! ## Example Usage
16//!
17//! ```rust,no_run
18//! use oxirs_embed::model_selection::{ModelSelector, DatasetCharacteristics, UseCaseType};
19//!
20//! # async fn example() -> anyhow::Result<()> {
21//! // Define dataset characteristics
22//! let characteristics = DatasetCharacteristics {
23//!     num_entities: 10000,
24//!     num_relations: 50,
25//!     num_triples: 50000,
26//!     avg_degree: 5.0,
27//!     is_sparse: false,
28//!     has_hierarchies: true,
29//!     has_complex_relations: true,
30//!     domain: Some("biomedical".to_string()),
31//! };
32//!
33//! // Get model recommendations
34//! let selector = ModelSelector::new();
35//! let recommendations = selector.recommend_models(&characteristics, UseCaseType::LinkPrediction)?;
36//!
37//! for rec in recommendations {
38//!     println!("Model: {}, Score: {:.2}, Reason: {}",
39//!              rec.model_type, rec.suitability_score, rec.reasoning);
40//! }
41//! # Ok(())
42//! # }
43//! ```
44
45use anyhow::{anyhow, Result};
46use serde::{Deserialize, Serialize};
47use std::collections::HashMap;
48use tracing::{debug, info};
49
/// Dataset characteristics for model selection.
///
/// These statistics feed the heuristics in [`ModelSelector`]:
/// - The size fields drive the embedding-dimension, training-time and memory estimates.
/// - The boolean flags are matched against each model profile's capabilities.
///
/// Fill the fields in by hand, or derive them with [`DatasetCharacteristics::infer`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatasetCharacteristics {
    /// Number of unique entities
    pub num_entities: usize,
    /// Number of unique relations
    pub num_relations: usize,
    /// Total number of triples
    pub num_triples: usize,
    /// Average node degree (`infer` computes `2 * num_triples / num_entities`)
    pub avg_degree: f64,
    /// Whether the graph is sparse (`infer` sets this when `avg_degree < sqrt(num_entities)`)
    pub is_sparse: bool,
    /// Whether the graph has hierarchical structure (`infer` defaults this to `false`)
    pub has_hierarchies: bool,
    /// Whether the graph has complex multi-hop relations
    /// (`infer` sets this when there are more than 10 relations)
    pub has_complex_relations: bool,
    /// Domain of the knowledge graph (e.g., "biomedical", "general", "social").
    /// NOTE(review): not read by the scoring heuristics in this module.
    pub domain: Option<String>,
}
70
71impl DatasetCharacteristics {
72    /// Automatically infer characteristics from basic statistics
73    pub fn infer(num_entities: usize, num_relations: usize, num_triples: usize) -> Self {
74        let avg_degree = if num_entities > 0 {
75            (num_triples as f64 * 2.0) / num_entities as f64
76        } else {
77            0.0
78        };
79
80        let is_sparse = avg_degree < (num_entities as f64).sqrt();
81
82        Self {
83            num_entities,
84            num_relations,
85            num_triples,
86            avg_degree,
87            is_sparse,
88            has_hierarchies: false, // Conservative default
89            has_complex_relations: num_relations > 10,
90            domain: None,
91        }
92    }
93
94    /// Calculate graph density
95    pub fn density(&self) -> f64 {
96        if self.num_entities == 0 {
97            return 0.0;
98        }
99        let max_possible = (self.num_entities * (self.num_entities - 1)) as f64;
100        if max_possible == 0.0 {
101            return 0.0;
102        }
103        self.num_triples as f64 / max_possible
104    }
105
106    /// Estimate memory requirements in MB
107    pub fn estimated_memory_mb(&self, embedding_dim: usize) -> f64 {
108        // Rough estimate: entities + relations + overhead
109        let entity_mem = (self.num_entities * embedding_dim * 4) as f64 / 1_048_576.0; // 4 bytes per f32
110        let relation_mem = (self.num_relations * embedding_dim * 4) as f64 / 1_048_576.0;
111        let overhead = 50.0; // MB for other structures
112
113        entity_mem + relation_mem + overhead
114    }
115}
116
/// Type of use case for model selection.
///
/// `ModelSelector` boosts a model's suitability score when the requested use
/// case appears in that model's profile. It adds a further bonus for
/// high-accuracy models under `LinkPrediction`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum UseCaseType {
    /// Link prediction (predicting missing triples)
    LinkPrediction,
    /// Entity classification
    EntityClassification,
    /// Relation extraction
    RelationExtraction,
    /// Question answering
    QuestionAnswering,
    /// Knowledge graph completion
    KGCompletion,
    /// Similarity search
    SimilaritySearch,
    /// General purpose embeddings
    GeneralPurpose,
}
135
/// Available embedding model types.
///
/// The `Display` form of each variant is the variant name itself.
/// NOTE(review): `ModelSelector::new` registers profiles for every variant
/// except `TuckER` and `QuatD`. Those two are therefore never recommended and
/// are skipped by `ModelSelector::compare_models`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ModelType {
    TransE,
    DistMult,
    ComplEx,
    RotatE,
    HolE,
    ConvE,
    TuckER,
    QuatD,
    GNN,
    Transformer,
}
150
151impl std::fmt::Display for ModelType {
152    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
153        match self {
154            ModelType::TransE => write!(f, "TransE"),
155            ModelType::DistMult => write!(f, "DistMult"),
156            ModelType::ComplEx => write!(f, "ComplEx"),
157            ModelType::RotatE => write!(f, "RotatE"),
158            ModelType::HolE => write!(f, "HolE"),
159            ModelType::ConvE => write!(f, "ConvE"),
160            ModelType::TuckER => write!(f, "TuckER"),
161            ModelType::QuatD => write!(f, "QuatD"),
162            ModelType::GNN => write!(f, "GNN"),
163            ModelType::Transformer => write!(f, "Transformer"),
164        }
165    }
166}
167
/// Model recommendation with reasoning, as produced by
/// `ModelSelector::recommend_models`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelRecommendation {
    /// The recommended model
    pub model_type: ModelType,
    /// Heuristic suitability in `[0.0, 1.0]`. `recommend_models` only returns
    /// entries scoring above `0.3`.
    pub suitability_score: f64,
    /// Human-readable justification (`"; "`-separated reasons)
    pub reasoning: String,
    /// Copied from the model profile's strengths
    pub pros: Vec<String>,
    /// Copied from the model profile's weaknesses
    pub cons: Vec<String>,
    /// Suggested embedding dimensionality
    pub recommended_dimensions: usize,
    /// Coarse training-time bucket (heuristic, not measured)
    pub estimated_training_time: TrainingTime,
    /// Coarse memory bucket derived from `DatasetCharacteristics::estimated_memory_mb`
    pub memory_requirement: MemoryRequirement,
}
180
/// Training time estimate.
///
/// These are heuristic buckets computed from triple count and the model's
/// speed score. They are not measurements.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum TrainingTime {
    Fast,     // < 5 minutes
    Medium,   // 5-30 minutes
    Slow,     // 30-60 minutes
    VerySlow, // > 1 hour
}
189
190impl std::fmt::Display for TrainingTime {
191    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
192        match self {
193            TrainingTime::Fast => write!(f, "Fast (< 5 min)"),
194            TrainingTime::Medium => write!(f, "Medium (5-30 min)"),
195            TrainingTime::Slow => write!(f, "Slow (30-60 min)"),
196            TrainingTime::VerySlow => write!(f, "Very Slow (> 1 hour)"),
197        }
198    }
199}
200
/// Memory requirement estimate.
///
/// These buckets classify `DatasetCharacteristics::estimated_memory_mb`.
/// The thresholds are 500, 2000 and 8000 MB.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum MemoryRequirement {
    Low,      // < 500 MB
    Medium,   // 500 MB - 2 GB
    High,     // 2 GB - 8 GB
    VeryHigh, // > 8 GB
}
209
210impl std::fmt::Display for MemoryRequirement {
211    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
212        match self {
213            MemoryRequirement::Low => write!(f, "Low (< 500 MB)"),
214            MemoryRequirement::Medium => write!(f, "Medium (500 MB - 2 GB)"),
215            MemoryRequirement::High => write!(f, "High (2 GB - 8 GB)"),
216            MemoryRequirement::VeryHigh => write!(f, "Very High (> 8 GB)"),
217        }
218    }
219}
220
/// Model selector for intelligent recommendation.
///
/// Holds one built-in [`ModelProfile`] per supported [`ModelType`] and scores
/// those profiles against a [`DatasetCharacteristics`] and a [`UseCaseType`].
pub struct ModelSelector {
    /// Built-in profiles keyed by model type (populated in `ModelSelector::new`)
    model_profiles: HashMap<ModelType, ModelProfile>,
}
225
/// Profile of a model's characteristics, used by the scoring heuristics in
/// `ModelSelector`.
#[derive(Debug, Clone)]
struct ModelProfile {
    /// The model this profile describes (mirrors its key in `model_profiles`)
    model_type: ModelType,
    /// Strengths of this model (surfaced as recommendation `pros`)
    strengths: Vec<String>,
    /// Weaknesses of this model (surfaced as recommendation `cons`)
    weaknesses: Vec<String>,
    /// Best use cases (a match adds +0.3 to the suitability score)
    best_for: Vec<UseCaseType>,
    /// Complexity score (1-10, higher = more complex).
    /// Values > 6 are penalized on small datasets; values > 7 halve the recommended dimensions.
    complexity: u8,
    /// Speed score (1-10, higher = faster).
    /// Values < 5 are penalized on large datasets; the score also scales the training-time estimate.
    speed: u8,
    /// Accuracy score (1-10, higher = more accurate); values >= 8 earn a link-prediction bonus
    accuracy: u8,
    /// Works well with sparse graphs
    handles_sparse: bool,
    /// Works well with hierarchies
    handles_hierarchies: bool,
    /// Works well with complex relations
    handles_complex_relations: bool,
}
249
250impl Default for ModelSelector {
251    fn default() -> Self {
252        Self::new()
253    }
254}
255
256impl ModelSelector {
257    /// Create a new model selector with predefined model profiles
258    pub fn new() -> Self {
259        let mut model_profiles = HashMap::new();
260
261        // TransE profile
262        model_profiles.insert(
263            ModelType::TransE,
264            ModelProfile {
265                model_type: ModelType::TransE,
266                strengths: vec![
267                    "Simple and efficient".to_string(),
268                    "Good for hierarchical relations".to_string(),
269                    "Fast training".to_string(),
270                ],
271                weaknesses: vec![
272                    "Cannot model symmetric relations well".to_string(),
273                    "Limited expressiveness".to_string(),
274                ],
275                best_for: vec![UseCaseType::LinkPrediction, UseCaseType::GeneralPurpose],
276                complexity: 2,
277                speed: 9,
278                accuracy: 6,
279                handles_sparse: true,
280                handles_hierarchies: true,
281                handles_complex_relations: false,
282            },
283        );
284
285        // DistMult profile
286        model_profiles.insert(
287            ModelType::DistMult,
288            ModelProfile {
289                model_type: ModelType::DistMult,
290                strengths: vec![
291                    "Very fast".to_string(),
292                    "Good for symmetric relations".to_string(),
293                    "Low memory footprint".to_string(),
294                ],
295                weaknesses: vec![
296                    "Cannot model asymmetric relations".to_string(),
297                    "Cannot capture composition".to_string(),
298                ],
299                best_for: vec![
300                    UseCaseType::SimilaritySearch,
301                    UseCaseType::EntityClassification,
302                ],
303                complexity: 1,
304                speed: 10,
305                accuracy: 5,
306                handles_sparse: true,
307                handles_hierarchies: false,
308                handles_complex_relations: false,
309            },
310        );
311
312        // ComplEx profile
313        model_profiles.insert(
314            ModelType::ComplEx,
315            ModelProfile {
316                model_type: ModelType::ComplEx,
317                strengths: vec![
318                    "Handles symmetric and asymmetric relations".to_string(),
319                    "Good theoretical properties".to_string(),
320                    "State-of-the-art performance".to_string(),
321                ],
322                weaknesses: vec![
323                    "More complex than TransE".to_string(),
324                    "Requires more memory".to_string(),
325                ],
326                best_for: vec![UseCaseType::LinkPrediction, UseCaseType::KGCompletion],
327                complexity: 5,
328                speed: 7,
329                accuracy: 8,
330                handles_sparse: true,
331                handles_hierarchies: true,
332                handles_complex_relations: true,
333            },
334        );
335
336        // RotatE profile
337        model_profiles.insert(
338            ModelType::RotatE,
339            ModelProfile {
340                model_type: ModelType::RotatE,
341                strengths: vec![
342                    "Excellent for complex relations".to_string(),
343                    "Handles composition patterns".to_string(),
344                    "Strong theoretical foundation".to_string(),
345                ],
346                weaknesses: vec![
347                    "Slower than simpler models".to_string(),
348                    "Higher memory usage".to_string(),
349                ],
350                best_for: vec![UseCaseType::LinkPrediction, UseCaseType::RelationExtraction],
351                complexity: 6,
352                speed: 6,
353                accuracy: 9,
354                handles_sparse: true,
355                handles_hierarchies: true,
356                handles_complex_relations: true,
357            },
358        );
359
360        // HolE profile
361        model_profiles.insert(
362            ModelType::HolE,
363            ModelProfile {
364                model_type: ModelType::HolE,
365                strengths: vec![
366                    "Memory efficient".to_string(),
367                    "Good compositional properties".to_string(),
368                    "Fast inference".to_string(),
369                ],
370                weaknesses: vec![
371                    "Training can be slower".to_string(),
372                    "Less intuitive than TransE".to_string(),
373                ],
374                best_for: vec![UseCaseType::KGCompletion, UseCaseType::LinkPrediction],
375                complexity: 5,
376                speed: 7,
377                accuracy: 7,
378                handles_sparse: true,
379                handles_hierarchies: false,
380                handles_complex_relations: true,
381            },
382        );
383
384        // ConvE profile
385        model_profiles.insert(
386            ModelType::ConvE,
387            ModelProfile {
388                model_type: ModelType::ConvE,
389                strengths: vec![
390                    "State-of-the-art accuracy".to_string(),
391                    "Captures complex patterns".to_string(),
392                    "Scalable to large graphs".to_string(),
393                ],
394                weaknesses: vec![
395                    "Requires more computational resources".to_string(),
396                    "More complex to tune".to_string(),
397                    "Slower training".to_string(),
398                ],
399                best_for: vec![UseCaseType::LinkPrediction, UseCaseType::KGCompletion],
400                complexity: 8,
401                speed: 4,
402                accuracy: 9,
403                handles_sparse: false,
404                handles_hierarchies: true,
405                handles_complex_relations: true,
406            },
407        );
408
409        // GNN profile
410        model_profiles.insert(
411            ModelType::GNN,
412            ModelProfile {
413                model_type: ModelType::GNN,
414                strengths: vec![
415                    "Leverages graph structure".to_string(),
416                    "Good for node classification".to_string(),
417                    "Captures neighborhood information".to_string(),
418                ],
419                weaknesses: vec![
420                    "Computationally expensive".to_string(),
421                    "Not ideal for very large graphs".to_string(),
422                ],
423                best_for: vec![
424                    UseCaseType::EntityClassification,
425                    UseCaseType::QuestionAnswering,
426                ],
427                complexity: 7,
428                speed: 5,
429                accuracy: 8,
430                handles_sparse: false,
431                handles_hierarchies: true,
432                handles_complex_relations: true,
433            },
434        );
435
436        // Transformer profile
437        model_profiles.insert(
438            ModelType::Transformer,
439            ModelProfile {
440                model_type: ModelType::Transformer,
441                strengths: vec![
442                    "Excellent for complex patterns".to_string(),
443                    "State-of-the-art on many tasks".to_string(),
444                    "Flexible architecture".to_string(),
445                ],
446                weaknesses: vec![
447                    "Very computationally expensive".to_string(),
448                    "Requires large amounts of data".to_string(),
449                    "High memory usage".to_string(),
450                ],
451                best_for: vec![UseCaseType::QuestionAnswering, UseCaseType::GeneralPurpose],
452                complexity: 9,
453                speed: 3,
454                accuracy: 9,
455                handles_sparse: false,
456                handles_hierarchies: true,
457                handles_complex_relations: true,
458            },
459        );
460
461        Self { model_profiles }
462    }
463
464    /// Recommend models for given dataset and use case
465    pub fn recommend_models(
466        &self,
467        characteristics: &DatasetCharacteristics,
468        use_case: UseCaseType,
469    ) -> Result<Vec<ModelRecommendation>> {
470        info!(
471            "Recommending models for dataset with {} entities, {} relations, {} triples",
472            characteristics.num_entities,
473            characteristics.num_relations,
474            characteristics.num_triples
475        );
476
477        let mut recommendations = Vec::new();
478
479        for (model_type, profile) in &self.model_profiles {
480            let score = self.calculate_suitability_score(profile, characteristics, use_case);
481
482            if score > 0.3 {
483                // Only include models with reasonable suitability
484                let recommendation = self.create_recommendation(
485                    *model_type,
486                    profile,
487                    characteristics,
488                    score,
489                    use_case,
490                );
491                recommendations.push(recommendation);
492            }
493        }
494
495        // Sort by suitability score (descending)
496        recommendations.sort_by(|a, b| {
497            b.suitability_score
498                .partial_cmp(&a.suitability_score)
499                .unwrap_or(std::cmp::Ordering::Equal)
500        });
501
502        debug!("Generated {} model recommendations", recommendations.len());
503
504        Ok(recommendations)
505    }
506
507    /// Calculate suitability score for a model (0.0 to 1.0)
508    fn calculate_suitability_score(
509        &self,
510        profile: &ModelProfile,
511        characteristics: &DatasetCharacteristics,
512        use_case: UseCaseType,
513    ) -> f64 {
514        let mut score: f64 = 0.5; // Base score
515
516        // Use case match (strong signal)
517        if profile.best_for.contains(&use_case) {
518            score += 0.3;
519        }
520
521        // Dataset characteristics match
522        if characteristics.is_sparse && profile.handles_sparse {
523            score += 0.1;
524        }
525
526        if characteristics.has_hierarchies && profile.handles_hierarchies {
527            score += 0.1;
528        }
529
530        if characteristics.has_complex_relations && profile.handles_complex_relations {
531            score += 0.1;
532        }
533
534        // Penalize complex models for small datasets
535        if characteristics.num_triples < 10000 && profile.complexity > 6 {
536            score -= 0.2;
537        }
538
539        // Penalize slow models for large datasets (unless accuracy is critical)
540        if characteristics.num_triples > 100000 && profile.speed < 5 {
541            score -= 0.1;
542        }
543
544        // Bonus for high accuracy models in link prediction
545        if use_case == UseCaseType::LinkPrediction && profile.accuracy >= 8 {
546            score += 0.1;
547        }
548
549        // Clamp to [0, 1]
550        score.clamp(0.0, 1.0)
551    }
552
553    /// Create a detailed recommendation
554    fn create_recommendation(
555        &self,
556        model_type: ModelType,
557        profile: &ModelProfile,
558        characteristics: &DatasetCharacteristics,
559        score: f64,
560        use_case: UseCaseType,
561    ) -> ModelRecommendation {
562        let recommended_dimensions = self.recommend_dimensions(characteristics, profile);
563
564        let training_time =
565            self.estimate_training_time(characteristics, profile, recommended_dimensions);
566
567        let memory_requirement =
568            self.estimate_memory_requirement(characteristics, recommended_dimensions);
569
570        let reasoning = self.generate_reasoning(profile, characteristics, use_case);
571
572        ModelRecommendation {
573            model_type,
574            suitability_score: score,
575            reasoning,
576            pros: profile.strengths.clone(),
577            cons: profile.weaknesses.clone(),
578            recommended_dimensions,
579            estimated_training_time: training_time,
580            memory_requirement,
581        }
582    }
583
584    /// Recommend embedding dimensions based on dataset size
585    fn recommend_dimensions(
586        &self,
587        characteristics: &DatasetCharacteristics,
588        profile: &ModelProfile,
589    ) -> usize {
590        let base_dim = if characteristics.num_entities < 1000 {
591            32
592        } else if characteristics.num_entities < 10000 {
593            64
594        } else if characteristics.num_entities < 100000 {
595            128
596        } else {
597            256
598        };
599
600        // Adjust for model complexity
601        if profile.complexity > 7 {
602            base_dim / 2 // Complex models can use lower dimensions
603        } else {
604            base_dim
605        }
606    }
607
608    /// Estimate training time
609    fn estimate_training_time(
610        &self,
611        characteristics: &DatasetCharacteristics,
612        profile: &ModelProfile,
613        _dimensions: usize,
614    ) -> TrainingTime {
615        let data_size_factor = characteristics.num_triples as f64 / 50000.0;
616        let speed_factor = profile.speed as f64 / 10.0;
617
618        let estimated_minutes = data_size_factor / speed_factor * 10.0;
619
620        if estimated_minutes < 5.0 {
621            TrainingTime::Fast
622        } else if estimated_minutes < 30.0 {
623            TrainingTime::Medium
624        } else if estimated_minutes < 60.0 {
625            TrainingTime::Slow
626        } else {
627            TrainingTime::VerySlow
628        }
629    }
630
631    /// Estimate memory requirement
632    fn estimate_memory_requirement(
633        &self,
634        characteristics: &DatasetCharacteristics,
635        dimensions: usize,
636    ) -> MemoryRequirement {
637        let memory_mb = characteristics.estimated_memory_mb(dimensions);
638
639        if memory_mb < 500.0 {
640            MemoryRequirement::Low
641        } else if memory_mb < 2000.0 {
642            MemoryRequirement::Medium
643        } else if memory_mb < 8000.0 {
644            MemoryRequirement::High
645        } else {
646            MemoryRequirement::VeryHigh
647        }
648    }
649
650    /// Generate reasoning explanation
651    fn generate_reasoning(
652        &self,
653        profile: &ModelProfile,
654        characteristics: &DatasetCharacteristics,
655        use_case: UseCaseType,
656    ) -> String {
657        let mut reasons = Vec::new();
658
659        if profile.best_for.contains(&use_case) {
660            reasons.push(format!("Well-suited for {:?}", use_case));
661        }
662
663        if characteristics.is_sparse && profile.handles_sparse {
664            reasons.push("Handles sparse graphs effectively".to_string());
665        }
666
667        if characteristics.has_hierarchies && profile.handles_hierarchies {
668            reasons.push("Good for hierarchical structures".to_string());
669        }
670
671        if characteristics.has_complex_relations && profile.handles_complex_relations {
672            reasons.push("Capable of modeling complex relations".to_string());
673        }
674
675        if profile.speed >= 8 {
676            reasons.push("Fast training and inference".to_string());
677        }
678
679        if profile.accuracy >= 8 {
680            reasons.push("High accuracy on benchmarks".to_string());
681        }
682
683        if reasons.is_empty() {
684            "General-purpose model".to_string()
685        } else {
686            reasons.join("; ")
687        }
688    }
689
690    /// Compare multiple models on the same criteria
691    pub fn compare_models(
692        &self,
693        models: &[ModelType],
694        characteristics: &DatasetCharacteristics,
695    ) -> Result<ModelComparison> {
696        if models.is_empty() {
697            return Err(anyhow!("No models provided for comparison"));
698        }
699
700        let mut comparisons = HashMap::new();
701
702        for model_type in models {
703            if let Some(profile) = self.model_profiles.get(model_type) {
704                let dimensions = self.recommend_dimensions(characteristics, profile);
705                let training_time =
706                    self.estimate_training_time(characteristics, profile, dimensions);
707                let memory_req = self.estimate_memory_requirement(characteristics, dimensions);
708
709                comparisons.insert(
710                    *model_type,
711                    ModelComparisonEntry {
712                        model_type: *model_type,
713                        complexity: profile.complexity,
714                        speed: profile.speed,
715                        accuracy: profile.accuracy,
716                        recommended_dimensions: dimensions,
717                        estimated_training_time: training_time,
718                        memory_requirement: memory_req,
719                    },
720                );
721            }
722        }
723
724        Ok(ModelComparison {
725            models: comparisons,
726            dataset_size: characteristics.num_triples,
727        })
728    }
729}
730
/// Comparison result for multiple models, produced by `ModelSelector::compare_models`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelComparison {
    /// One entry per requested model that has a built-in profile
    /// (unprofiled models are absent)
    pub models: HashMap<ModelType, ModelComparisonEntry>,
    /// Dataset size, taken from `DatasetCharacteristics::num_triples`
    pub dataset_size: usize,
}
737
/// Entry in model comparison: a model's profile scores plus the
/// dataset-specific estimates computed for it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelComparisonEntry {
    /// The compared model
    pub model_type: ModelType,
    /// Profile complexity score (1-10, higher = more complex)
    pub complexity: u8,
    /// Profile speed score (1-10, higher = faster)
    pub speed: u8,
    /// Profile accuracy score (1-10, higher = more accurate)
    pub accuracy: u8,
    /// Suggested embedding dimensionality for this dataset
    pub recommended_dimensions: usize,
    /// Coarse training-time bucket (heuristic)
    pub estimated_training_time: TrainingTime,
    /// Coarse memory bucket at `recommended_dimensions`
    pub memory_requirement: MemoryRequirement,
}
749
#[cfg(test)]
mod tests {
    use super::*;

    /// Mid-sized graph shared by several tests: 10k entities, 50 relations, 50k triples.
    fn medium_graph() -> DatasetCharacteristics {
        DatasetCharacteristics::infer(10_000, 50, 50_000)
    }

    #[test]
    fn test_dataset_characteristics_infer() {
        let inferred = DatasetCharacteristics::infer(1_000, 10, 5_000);
        assert_eq!(
            (
                inferred.num_entities,
                inferred.num_relations,
                inferred.num_triples
            ),
            (1_000, 10, 5_000)
        );
        assert!(inferred.avg_degree > 0.0);
    }

    #[test]
    fn test_dataset_density() {
        let graph = DatasetCharacteristics {
            num_entities: 100,
            num_relations: 5,
            num_triples: 500,
            avg_degree: 5.0,
            is_sparse: false,
            has_hierarchies: false,
            has_complex_relations: false,
            domain: None,
        };

        let density = graph.density();
        assert!(density > 0.0 && density < 1.0);
    }

    #[test]
    fn test_model_selector_creation() {
        let selector = ModelSelector::new();
        let profiles = &selector.model_profiles;
        assert!(!profiles.is_empty());
        for required in [ModelType::TransE, ModelType::ComplEx] {
            assert!(profiles.contains_key(&required));
        }
    }

    #[test]
    fn test_model_recommendation() -> Result<()> {
        let recommendations =
            ModelSelector::new().recommend_models(&medium_graph(), UseCaseType::LinkPrediction)?;

        assert!(!recommendations.is_empty());

        // Recommendations must be ordered by descending suitability.
        assert!(recommendations
            .windows(2)
            .all(|pair| pair[0].suitability_score >= pair[1].suitability_score));

        Ok(())
    }

    #[test]
    fn test_model_comparison() -> Result<()> {
        let requested = [ModelType::TransE, ModelType::ComplEx, ModelType::RotatE];
        let comparison = ModelSelector::new().compare_models(&requested, &medium_graph())?;

        assert_eq!(comparison.models.len(), requested.len());
        assert!(requested
            .iter()
            .all(|model| comparison.models.contains_key(model)));

        Ok(())
    }

    #[test]
    fn test_small_dataset_recommendations() -> Result<()> {
        let small = DatasetCharacteristics::infer(100, 5, 500);
        let recommendations =
            ModelSelector::new().recommend_models(&small, UseCaseType::GeneralPurpose)?;

        // Small graphs should get compact embeddings.
        let best = recommendations
            .first()
            .expect("at least one recommendation for a small graph");
        assert!(best.recommended_dimensions <= 64);

        Ok(())
    }

    #[test]
    fn test_large_dataset_recommendations() -> Result<()> {
        let large = DatasetCharacteristics::infer(100_000, 100, 500_000);
        let recommendations =
            ModelSelector::new().recommend_models(&large, UseCaseType::LinkPrediction)?;

        let best = recommendations
            .first()
            .expect("at least one recommendation for a large graph");
        assert!(best.recommended_dimensions >= 64);

        Ok(())
    }

    #[test]
    fn test_memory_estimation() {
        let memory_mb = medium_graph().estimated_memory_mb(128);

        // Positive and within a sane upper bound.
        assert!(memory_mb > 0.0);
        assert!(memory_mb < 10_000.0);
    }

    #[test]
    fn test_use_case_specific_recommendations() -> Result<()> {
        let selector = ModelSelector::new();
        let graph = medium_graph();

        // Different use cases may prefer different models, but each must yield something.
        for use_case in [UseCaseType::LinkPrediction, UseCaseType::SimilaritySearch] {
            let recs = selector.recommend_models(&graph, use_case)?;
            assert!(!recs.is_empty());
        }

        Ok(())
    }
}