use anyhow::{anyhow, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use tracing::{debug, info};
/// Summary statistics of a graph dataset, used to score candidate embedding models.
///
/// Build it from raw counts with [`DatasetCharacteristics::infer`], or construct it directly
/// when more is known about the data (e.g. whether it contains hierarchies).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatasetCharacteristics {
    /// Number of distinct entities (graph nodes).
    pub num_entities: usize,
    /// Number of distinct relation types.
    pub num_relations: usize,
    /// Number of triples (edges) in the graph.
    pub num_triples: usize,
    /// Average entity degree; `infer` computes it as `2 * num_triples / num_entities`.
    pub avg_degree: f64,
    /// Sparsity flag; `infer` sets it when `avg_degree < sqrt(num_entities)`.
    pub is_sparse: bool,
    /// Whether the data contains hierarchical relations; `infer` cannot detect this and
    /// always leaves it `false`.
    pub has_hierarchies: bool,
    /// Whether the data has complex relation patterns; `infer` approximates this as
    /// "more than 10 relation types".
    pub has_complex_relations: bool,
    /// Optional domain label supplied by the caller; `infer` leaves it `None`.
    pub domain: Option<String>,
}
impl DatasetCharacteristics {
    /// Derives dataset characteristics from raw graph counts.
    ///
    /// - `avg_degree` is `2 * num_triples / num_entities` (each triple touches two entity
    ///   endpoints), or `0.0` when there are no entities.
    /// - `is_sparse` is `true` when `avg_degree < sqrt(num_entities)`.
    /// - `has_complex_relations` is `true` when there are more than 10 relation types.
    /// - `has_hierarchies` is always `false` and `domain` is `None`: neither can be inferred
    ///   from counts alone, so callers should set them explicitly when known.
    pub fn infer(num_entities: usize, num_relations: usize, num_triples: usize) -> Self {
        let avg_degree = if num_entities > 0 {
            (num_triples as f64 * 2.0) / num_entities as f64
        } else {
            0.0
        };
        let is_sparse = avg_degree < (num_entities as f64).sqrt();
        Self {
            num_entities,
            num_relations,
            num_triples,
            avg_degree,
            is_sparse,
            has_hierarchies: false,
            has_complex_relations: num_relations > 10,
            domain: None,
        }
    }

    /// Ratio of triples to ordered entity pairs: `num_triples / (n * (n - 1))`.
    ///
    /// Returns `0.0` for graphs with fewer than two entities. The relation count is not part
    /// of the denominator, so a multi-relational graph can report a value above `1.0`.
    pub fn density(&self) -> f64 {
        if self.num_entities < 2 {
            return 0.0;
        }
        // Compute in f64: `n * (n - 1)` in `usize` overflows for large graphs (already past
        // ~65k entities on 32-bit targets), which panics in debug builds.
        let n = self.num_entities as f64;
        self.num_triples as f64 / (n * (n - 1.0))
    }

    /// Rough memory estimate in MiB for training with `embedding_dim`-sized `f32` embeddings.
    ///
    /// Counts one embedding per entity and per relation, plus a fixed 50 MiB overhead.
    /// Model-specific parameters (e.g. convolution filters) are not included.
    pub fn estimated_memory_mb(&self, embedding_dim: usize) -> f64 {
        const BYTES_PER_F32: f64 = 4.0;
        const BYTES_PER_MB: f64 = 1_048_576.0;
        const OVERHEAD_MB: f64 = 50.0;
        // f64 arithmetic avoids the `usize` overflow of `count * dim * 4` on large inputs.
        let dim = embedding_dim as f64;
        let entity_mem = self.num_entities as f64 * dim * BYTES_PER_F32 / BYTES_PER_MB;
        let relation_mem = self.num_relations as f64 * dim * BYTES_PER_F32 / BYTES_PER_MB;
        entity_mem + relation_mem + OVERHEAD_MB
    }
}
/// The downstream task a caller wants embeddings for.
///
/// Each model profile lists the use cases it is best suited to, and matching the requested
/// use case raises a model's suitability score.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum UseCaseType {
    /// Predicting missing links.
    LinkPrediction,
    /// Classifying entities.
    EntityClassification,
    /// Extracting relations.
    RelationExtraction,
    /// Answering questions over the graph.
    QuestionAnswering,
    /// Knowledge-graph completion.
    KGCompletion,
    /// Finding similar entities.
    SimilaritySearch,
    /// No specific task.
    GeneralPurpose,
}
/// Embedding model families the selector knows about.
///
/// Note: `ModelSelector::new` defines profiles for every variant except `TuckER` and
/// `QuatD`, so those two are never recommended.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ModelType {
    TransE,
    DistMult,
    ComplEx,
    RotatE,
    HolE,
    ConvE,
    TuckER,
    QuatD,
    GNN,
    Transformer,
}
impl std::fmt::Display for ModelType {
    /// Writes the model's canonical name (e.g. `TransE`, `ComplEx`); width/fill flags are ignored.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            ModelType::TransE => "TransE",
            ModelType::DistMult => "DistMult",
            ModelType::ComplEx => "ComplEx",
            ModelType::RotatE => "RotatE",
            ModelType::HolE => "HolE",
            ModelType::ConvE => "ConvE",
            ModelType::TuckER => "TuckER",
            ModelType::QuatD => "QuatD",
            ModelType::GNN => "GNN",
            ModelType::Transformer => "Transformer",
        };
        f.write_str(label)
    }
}
/// A single model suggestion produced by `ModelSelector::recommend_models`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelRecommendation {
    /// The recommended model family.
    pub model_type: ModelType,
    /// Heuristic fit for the dataset and use case, clamped to `[0.0, 1.0]`.
    pub suitability_score: f64,
    /// Human-readable explanation of why the model was suggested.
    pub reasoning: String,
    /// Known strengths of the model.
    pub pros: Vec<String>,
    /// Known weaknesses of the model.
    pub cons: Vec<String>,
    /// Suggested embedding dimensionality.
    pub recommended_dimensions: usize,
    /// Coarse training-time estimate.
    pub estimated_training_time: TrainingTime,
    /// Coarse memory-footprint estimate.
    pub memory_requirement: MemoryRequirement,
}
/// Coarse training-time bucket; the ranges match the `Display` labels.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum TrainingTime {
    /// Under 5 minutes.
    Fast,
    /// Between 5 and 30 minutes.
    Medium,
    /// Between 30 and 60 minutes.
    Slow,
    /// More than an hour.
    VerySlow,
}
impl std::fmt::Display for TrainingTime {
    /// Writes the bucket name together with its approximate time range.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            TrainingTime::Fast => "Fast (< 5 min)",
            TrainingTime::Medium => "Medium (5-30 min)",
            TrainingTime::Slow => "Slow (30-60 min)",
            TrainingTime::VerySlow => "Very Slow (> 1 hour)",
        };
        f.write_str(label)
    }
}
/// Coarse memory-footprint bucket; the ranges match the `Display` labels.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum MemoryRequirement {
    /// Under 500 MB.
    Low,
    /// Between 500 MB and 2 GB.
    Medium,
    /// Between 2 GB and 8 GB.
    High,
    /// More than 8 GB.
    VeryHigh,
}
impl std::fmt::Display for MemoryRequirement {
    /// Writes the bucket name together with its approximate memory range.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            MemoryRequirement::Low => "Low (< 500 MB)",
            MemoryRequirement::Medium => "Medium (500 MB - 2 GB)",
            MemoryRequirement::High => "High (2 GB - 8 GB)",
            MemoryRequirement::VeryHigh => "Very High (> 8 GB)",
        };
        f.write_str(label)
    }
}
/// Recommends and compares embedding models by scoring built-in per-model profiles against
/// [`DatasetCharacteristics`].
pub struct ModelSelector {
    /// Static description of each supported model, keyed by model type.
    model_profiles: HashMap<ModelType, ModelProfile>,
}
/// Hand-written description of one model family, used for scoring and explanations.
#[derive(Debug, Clone)]
struct ModelProfile {
    /// The model this profile describes.
    model_type: ModelType,
    /// Strengths, copied into a recommendation's `pros`.
    strengths: Vec<String>,
    /// Weaknesses, copied into a recommendation's `cons`.
    weaknesses: Vec<String>,
    /// Use cases that earn a scoring bonus.
    best_for: Vec<UseCaseType>,
    /// Relative complexity rating (built-in profiles use 1-10; higher = more complex).
    complexity: u8,
    /// Relative speed rating (built-in profiles use 1-10; higher = faster).
    speed: u8,
    /// Relative accuracy rating (built-in profiles use 1-10; higher = more accurate).
    accuracy: u8,
    /// Whether the model copes well with sparse graphs.
    handles_sparse: bool,
    /// Whether the model copes well with hierarchical relations.
    handles_hierarchies: bool,
    /// Whether the model copes well with complex relation patterns.
    handles_complex_relations: bool,
}
impl Default for ModelSelector {
fn default() -> Self {
Self::new()
}
}
impl ModelSelector {
    /// Minimum suitability score (exclusive) a model needs to appear in recommendations.
    const MIN_SUITABILITY_SCORE: f64 = 0.3;

    /// Creates a selector pre-loaded with the built-in model profiles.
    ///
    /// Profiles exist for TransE, DistMult, ComplEx, RotatE, HolE, ConvE, GNN and Transformer.
    /// `ModelType::TuckER` and `ModelType::QuatD` have no profile, so they are never
    /// recommended and are skipped by [`ModelSelector::compare_models`].
    pub fn new() -> Self {
        let profiles = vec![
            ModelProfile {
                model_type: ModelType::TransE,
                strengths: vec![
                    "Simple and efficient".to_string(),
                    "Good for hierarchical relations".to_string(),
                    "Fast training".to_string(),
                ],
                weaknesses: vec![
                    "Cannot model symmetric relations well".to_string(),
                    "Limited expressiveness".to_string(),
                ],
                best_for: vec![UseCaseType::LinkPrediction, UseCaseType::GeneralPurpose],
                complexity: 2,
                speed: 9,
                accuracy: 6,
                handles_sparse: true,
                handles_hierarchies: true,
                handles_complex_relations: false,
            },
            ModelProfile {
                model_type: ModelType::DistMult,
                strengths: vec![
                    "Very fast".to_string(),
                    "Good for symmetric relations".to_string(),
                    "Low memory footprint".to_string(),
                ],
                weaknesses: vec![
                    "Cannot model asymmetric relations".to_string(),
                    "Cannot capture composition".to_string(),
                ],
                best_for: vec![
                    UseCaseType::SimilaritySearch,
                    UseCaseType::EntityClassification,
                ],
                complexity: 1,
                speed: 10,
                accuracy: 5,
                handles_sparse: true,
                handles_hierarchies: false,
                handles_complex_relations: false,
            },
            ModelProfile {
                model_type: ModelType::ComplEx,
                strengths: vec![
                    "Handles symmetric and asymmetric relations".to_string(),
                    "Good theoretical properties".to_string(),
                    "State-of-the-art performance".to_string(),
                ],
                weaknesses: vec![
                    "More complex than TransE".to_string(),
                    "Requires more memory".to_string(),
                ],
                best_for: vec![UseCaseType::LinkPrediction, UseCaseType::KGCompletion],
                complexity: 5,
                speed: 7,
                accuracy: 8,
                handles_sparse: true,
                handles_hierarchies: true,
                handles_complex_relations: true,
            },
            ModelProfile {
                model_type: ModelType::RotatE,
                strengths: vec![
                    "Excellent for complex relations".to_string(),
                    "Handles composition patterns".to_string(),
                    "Strong theoretical foundation".to_string(),
                ],
                weaknesses: vec![
                    "Slower than simpler models".to_string(),
                    "Higher memory usage".to_string(),
                ],
                best_for: vec![UseCaseType::LinkPrediction, UseCaseType::RelationExtraction],
                complexity: 6,
                speed: 6,
                accuracy: 9,
                handles_sparse: true,
                handles_hierarchies: true,
                handles_complex_relations: true,
            },
            ModelProfile {
                model_type: ModelType::HolE,
                strengths: vec![
                    "Memory efficient".to_string(),
                    "Good compositional properties".to_string(),
                    "Fast inference".to_string(),
                ],
                weaknesses: vec![
                    "Training can be slower".to_string(),
                    "Less intuitive than TransE".to_string(),
                ],
                best_for: vec![UseCaseType::KGCompletion, UseCaseType::LinkPrediction],
                complexity: 5,
                speed: 7,
                accuracy: 7,
                handles_sparse: true,
                handles_hierarchies: false,
                handles_complex_relations: true,
            },
            ModelProfile {
                model_type: ModelType::ConvE,
                strengths: vec![
                    "State-of-the-art accuracy".to_string(),
                    "Captures complex patterns".to_string(),
                    "Scalable to large graphs".to_string(),
                ],
                weaknesses: vec![
                    "Requires more computational resources".to_string(),
                    "More complex to tune".to_string(),
                    "Slower training".to_string(),
                ],
                best_for: vec![UseCaseType::LinkPrediction, UseCaseType::KGCompletion],
                complexity: 8,
                speed: 4,
                accuracy: 9,
                handles_sparse: false,
                handles_hierarchies: true,
                handles_complex_relations: true,
            },
            ModelProfile {
                model_type: ModelType::GNN,
                strengths: vec![
                    "Leverages graph structure".to_string(),
                    "Good for node classification".to_string(),
                    "Captures neighborhood information".to_string(),
                ],
                weaknesses: vec![
                    "Computationally expensive".to_string(),
                    "Not ideal for very large graphs".to_string(),
                ],
                best_for: vec![
                    UseCaseType::EntityClassification,
                    UseCaseType::QuestionAnswering,
                ],
                complexity: 7,
                speed: 5,
                accuracy: 8,
                handles_sparse: false,
                handles_hierarchies: true,
                handles_complex_relations: true,
            },
            ModelProfile {
                model_type: ModelType::Transformer,
                strengths: vec![
                    "Excellent for complex patterns".to_string(),
                    "State-of-the-art on many tasks".to_string(),
                    "Flexible architecture".to_string(),
                ],
                weaknesses: vec![
                    "Very computationally expensive".to_string(),
                    "Requires large amounts of data".to_string(),
                    "High memory usage".to_string(),
                ],
                best_for: vec![UseCaseType::QuestionAnswering, UseCaseType::GeneralPurpose],
                complexity: 9,
                speed: 3,
                accuracy: 9,
                handles_sparse: false,
                handles_hierarchies: true,
                handles_complex_relations: true,
            },
        ];
        // Key each profile by its own `model_type` so the map key and the profile can never
        // disagree (previously the type was written out twice per entry).
        let model_profiles: HashMap<ModelType, ModelProfile> = profiles
            .into_iter()
            .map(|profile| (profile.model_type, profile))
            .collect();
        Self { model_profiles }
    }

    /// Scores every profiled model for `characteristics` and `use_case`, and returns those
    /// scoring above 0.3, best first.
    ///
    /// Models with equal scores are ordered by `ModelType` declaration order, so the result
    /// is deterministic across runs.
    ///
    /// # Errors
    /// Currently never fails; the `Result` is kept for API stability.
    pub fn recommend_models(
        &self,
        characteristics: &DatasetCharacteristics,
        use_case: UseCaseType,
    ) -> Result<Vec<ModelRecommendation>> {
        info!(
            "Recommending models for dataset with {} entities, {} relations, {} triples",
            characteristics.num_entities,
            characteristics.num_relations,
            characteristics.num_triples
        );
        let mut recommendations: Vec<ModelRecommendation> = self
            .model_profiles
            .iter()
            .filter_map(|(model_type, profile)| {
                let score = self.calculate_suitability_score(profile, characteristics, use_case);
                (score > Self::MIN_SUITABILITY_SCORE).then(|| {
                    self.create_recommendation(
                        *model_type,
                        profile,
                        characteristics,
                        score,
                        use_case,
                    )
                })
            })
            .collect();
        // HashMap iteration order is randomized per map instance, so ties must be broken
        // explicitly or equal-score models come back in a different order on each run.
        recommendations.sort_by(|a, b| {
            b.suitability_score
                .partial_cmp(&a.suitability_score)
                .unwrap_or(std::cmp::Ordering::Equal)
                .then_with(|| (a.model_type as u8).cmp(&(b.model_type as u8)))
        });
        debug!("Generated {} model recommendations", recommendations.len());
        Ok(recommendations)
    }

    /// Heuristic suitability of `profile` for the dataset and use case, clamped to `[0, 1]`.
    ///
    /// Starts at 0.5, then:
    /// - +0.3 if the use case is in the profile's `best_for` list;
    /// - +0.1 each for matching sparse / hierarchical / complex-relation support;
    /// - -0.2 for complex models (complexity > 6) on small datasets (< 10k triples);
    /// - -0.1 for slow models (speed < 5) on large datasets (> 100k triples);
    /// - +0.1 for high-accuracy models (accuracy >= 8) when doing link prediction.
    fn calculate_suitability_score(
        &self,
        profile: &ModelProfile,
        characteristics: &DatasetCharacteristics,
        use_case: UseCaseType,
    ) -> f64 {
        let mut score: f64 = 0.5;
        if profile.best_for.contains(&use_case) {
            score += 0.3;
        }
        if characteristics.is_sparse && profile.handles_sparse {
            score += 0.1;
        }
        if characteristics.has_hierarchies && profile.handles_hierarchies {
            score += 0.1;
        }
        if characteristics.has_complex_relations && profile.handles_complex_relations {
            score += 0.1;
        }
        if characteristics.num_triples < 10000 && profile.complexity > 6 {
            score -= 0.2;
        }
        if characteristics.num_triples > 100000 && profile.speed < 5 {
            score -= 0.1;
        }
        if use_case == UseCaseType::LinkPrediction && profile.accuracy >= 8 {
            score += 0.1;
        }
        score.clamp(0.0, 1.0)
    }

    /// Assembles a full recommendation (dimensions, cost estimates, reasoning) for one model.
    fn create_recommendation(
        &self,
        model_type: ModelType,
        profile: &ModelProfile,
        characteristics: &DatasetCharacteristics,
        score: f64,
        use_case: UseCaseType,
    ) -> ModelRecommendation {
        let recommended_dimensions = self.recommend_dimensions(characteristics, profile);
        let training_time =
            self.estimate_training_time(characteristics, profile, recommended_dimensions);
        let memory_requirement =
            self.estimate_memory_requirement(characteristics, recommended_dimensions);
        let reasoning = self.generate_reasoning(profile, characteristics, use_case);
        ModelRecommendation {
            model_type,
            suitability_score: score,
            reasoning,
            pros: profile.strengths.clone(),
            cons: profile.weaknesses.clone(),
            recommended_dimensions,
            estimated_training_time: training_time,
            memory_requirement,
        }
    }

    /// Embedding size chosen from the entity count (32 / 64 / 128 / 256 at the 1k / 10k /
    /// 100k thresholds), halved for very complex models (complexity > 7).
    fn recommend_dimensions(
        &self,
        characteristics: &DatasetCharacteristics,
        profile: &ModelProfile,
    ) -> usize {
        let base_dim = if characteristics.num_entities < 1000 {
            32
        } else if characteristics.num_entities < 10000 {
            64
        } else if characteristics.num_entities < 100000 {
            128
        } else {
            256
        };
        // Expensive architectures get smaller embeddings to keep training cost in check.
        if profile.complexity > 7 {
            base_dim / 2
        } else {
            base_dim
        }
    }

    /// Rough training-time bucket: about 10 minutes per 50k triples at speed 10, scaling
    /// inversely with the profile's speed rating. `_dimensions` is currently unused.
    fn estimate_training_time(
        &self,
        characteristics: &DatasetCharacteristics,
        profile: &ModelProfile,
        _dimensions: usize,
    ) -> TrainingTime {
        let data_size_factor = characteristics.num_triples as f64 / 50000.0;
        // Clamp to 1 so a zero speed rating cannot cause a division by zero.
        let speed_factor = f64::from(profile.speed.max(1)) / 10.0;
        let estimated_minutes = data_size_factor / speed_factor * 10.0;
        if estimated_minutes < 5.0 {
            TrainingTime::Fast
        } else if estimated_minutes < 30.0 {
            TrainingTime::Medium
        } else if estimated_minutes < 60.0 {
            TrainingTime::Slow
        } else {
            TrainingTime::VerySlow
        }
    }

    /// Buckets `DatasetCharacteristics::estimated_memory_mb` at 500 / 2000 / 8000 MB.
    fn estimate_memory_requirement(
        &self,
        characteristics: &DatasetCharacteristics,
        dimensions: usize,
    ) -> MemoryRequirement {
        let memory_mb = characteristics.estimated_memory_mb(dimensions);
        if memory_mb < 500.0 {
            MemoryRequirement::Low
        } else if memory_mb < 2000.0 {
            MemoryRequirement::Medium
        } else if memory_mb < 8000.0 {
            MemoryRequirement::High
        } else {
            MemoryRequirement::VeryHigh
        }
    }

    /// Builds a `"; "`-separated explanation from the same signals used for scoring, plus
    /// speed/accuracy highlights (ratings >= 8). Falls back to `"General-purpose model"`.
    fn generate_reasoning(
        &self,
        profile: &ModelProfile,
        characteristics: &DatasetCharacteristics,
        use_case: UseCaseType,
    ) -> String {
        let mut reasons = Vec::new();
        if profile.best_for.contains(&use_case) {
            reasons.push(format!("Well-suited for {:?}", use_case));
        }
        if characteristics.is_sparse && profile.handles_sparse {
            reasons.push("Handles sparse graphs effectively".to_string());
        }
        if characteristics.has_hierarchies && profile.handles_hierarchies {
            reasons.push("Good for hierarchical structures".to_string());
        }
        if characteristics.has_complex_relations && profile.handles_complex_relations {
            reasons.push("Capable of modeling complex relations".to_string());
        }
        if profile.speed >= 8 {
            reasons.push("Fast training and inference".to_string());
        }
        if profile.accuracy >= 8 {
            reasons.push("High accuracy on benchmarks".to_string());
        }
        if reasons.is_empty() {
            "General-purpose model".to_string()
        } else {
            reasons.join("; ")
        }
    }

    /// Side-by-side ratings and cost estimates for the requested models.
    ///
    /// Models without a built-in profile (currently `TuckER` and `QuatD`) are skipped and
    /// logged at debug level; duplicates in `models` collapse into a single entry.
    ///
    /// # Errors
    /// Returns an error if `models` is empty.
    pub fn compare_models(
        &self,
        models: &[ModelType],
        characteristics: &DatasetCharacteristics,
    ) -> Result<ModelComparison> {
        if models.is_empty() {
            return Err(anyhow!("No models provided for comparison"));
        }
        let mut comparisons = HashMap::new();
        for &model_type in models {
            let profile = match self.model_profiles.get(&model_type) {
                Some(profile) => profile,
                None => {
                    // Previously dropped silently; make the omission visible in logs.
                    debug!("Skipping {} in comparison: no model profile available", model_type);
                    continue;
                }
            };
            let dimensions = self.recommend_dimensions(characteristics, profile);
            let training_time = self.estimate_training_time(characteristics, profile, dimensions);
            let memory_req = self.estimate_memory_requirement(characteristics, dimensions);
            comparisons.insert(
                model_type,
                ModelComparisonEntry {
                    model_type,
                    complexity: profile.complexity,
                    speed: profile.speed,
                    accuracy: profile.accuracy,
                    recommended_dimensions: dimensions,
                    estimated_training_time: training_time,
                    memory_requirement: memory_req,
                },
            );
        }
        Ok(ModelComparison {
            models: comparisons,
            dataset_size: characteristics.num_triples,
        })
    }
}
/// Result of `ModelSelector::compare_models`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelComparison {
    /// One entry per requested model that has a built-in profile.
    pub models: HashMap<ModelType, ModelComparisonEntry>,
    /// Number of triples in the dataset the comparison was made for.
    pub dataset_size: usize,
}
/// Ratings and cost estimates for a single model within a [`ModelComparison`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelComparisonEntry {
    /// The model being described.
    pub model_type: ModelType,
    /// Relative complexity rating from the model's profile.
    pub complexity: u8,
    /// Relative speed rating from the model's profile.
    pub speed: u8,
    /// Relative accuracy rating from the model's profile.
    pub accuracy: u8,
    /// Suggested embedding dimensionality for this dataset.
    pub recommended_dimensions: usize,
    /// Coarse training-time estimate.
    pub estimated_training_time: TrainingTime,
    /// Coarse memory-footprint estimate.
    pub memory_requirement: MemoryRequirement,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Mid-sized graph shared by several tests: 10k entities, 50 relations, 50k triples.
    fn medium_dataset() -> DatasetCharacteristics {
        DatasetCharacteristics::infer(10000, 50, 50000)
    }

    #[test]
    fn test_dataset_characteristics_infer() {
        let chars = DatasetCharacteristics::infer(1000, 10, 5000);
        assert_eq!(chars.num_entities, 1000);
        assert_eq!(chars.num_relations, 10);
        assert_eq!(chars.num_triples, 5000);
        // Each triple contributes two endpoints: 2 * 5000 / 1000.
        assert!((chars.avg_degree - 10.0).abs() < 1e-12);
        // 10 < sqrt(1000) ≈ 31.6
        assert!(chars.is_sparse);
        // Exactly 10 relations is not "more than 10".
        assert!(!chars.has_complex_relations);
        assert!(!chars.has_hierarchies);
        assert!(chars.domain.is_none());
    }

    #[test]
    fn test_infer_empty_graph() {
        let chars = DatasetCharacteristics::infer(0, 0, 0);
        assert_eq!(chars.avg_degree, 0.0);
        assert!(!chars.is_sparse);
        assert_eq!(chars.density(), 0.0);
    }

    #[test]
    fn test_dataset_density() {
        let chars = DatasetCharacteristics {
            num_entities: 100,
            num_relations: 5,
            num_triples: 500,
            avg_degree: 5.0,
            is_sparse: false,
            has_hierarchies: false,
            has_complex_relations: false,
            domain: None,
        };
        // 500 triples over 100 * 99 ordered entity pairs.
        assert!((chars.density() - 500.0 / 9900.0).abs() < 1e-12);
    }

    #[test]
    fn test_density_single_entity_is_zero() {
        let chars = DatasetCharacteristics::infer(1, 1, 3);
        assert_eq!(chars.density(), 0.0);
    }

    #[test]
    fn test_model_selector_creation() {
        let selector = ModelSelector::new();
        assert!(!selector.model_profiles.is_empty());
        assert!(selector.model_profiles.contains_key(&ModelType::TransE));
        assert!(selector.model_profiles.contains_key(&ModelType::ComplEx));
    }

    #[test]
    fn test_model_recommendation() -> Result<()> {
        let selector = ModelSelector::new();
        let recommendations =
            selector.recommend_models(&medium_dataset(), UseCaseType::LinkPrediction)?;
        assert!(!recommendations.is_empty());
        for pair in recommendations.windows(2) {
            assert!(pair[0].suitability_score >= pair[1].suitability_score);
        }
        assert!(recommendations
            .iter()
            .all(|r| r.suitability_score > 0.3 && r.suitability_score <= 1.0));
        Ok(())
    }

    #[test]
    fn test_model_comparison() -> Result<()> {
        let selector = ModelSelector::new();
        let models = [ModelType::TransE, ModelType::ComplEx, ModelType::RotatE];
        let comparison = selector.compare_models(&models, &medium_dataset())?;
        assert_eq!(comparison.models.len(), 3);
        assert_eq!(comparison.dataset_size, 50000);
        for model in &models {
            assert!(comparison.models.contains_key(model));
        }
        Ok(())
    }

    #[test]
    fn test_compare_models_rejects_empty_input() {
        let selector = ModelSelector::new();
        assert!(selector.compare_models(&[], &medium_dataset()).is_err());
    }

    #[test]
    fn test_small_dataset_recommendations() -> Result<()> {
        let selector = ModelSelector::new();
        let characteristics = DatasetCharacteristics::infer(100, 5, 500);
        let recommendations =
            selector.recommend_models(&characteristics, UseCaseType::GeneralPurpose)?;
        let top_model = recommendations
            .first()
            .expect("small dataset should yield at least one recommendation");
        // TransE is cheap and targets GeneralPurpose, so it scores highest (0.8).
        assert_eq!(top_model.model_type, ModelType::TransE);
        assert_eq!(top_model.recommended_dimensions, 32);
        Ok(())
    }

    #[test]
    fn test_large_dataset_recommendations() -> Result<()> {
        let selector = ModelSelector::new();
        let characteristics = DatasetCharacteristics::infer(100000, 100, 500000);
        let recommendations =
            selector.recommend_models(&characteristics, UseCaseType::LinkPrediction)?;
        let top_model = recommendations
            .first()
            .expect("large dataset should yield at least one recommendation");
        assert!(top_model.recommended_dimensions >= 64);
        Ok(())
    }

    #[test]
    fn test_memory_estimation() {
        let memory_mb = medium_dataset().estimated_memory_mb(128);
        // (10_000 entities + 50 relations) * 128 dims * 4 bytes, plus 50 MB overhead.
        let expected = (10_050.0 * 128.0 * 4.0) / 1_048_576.0 + 50.0;
        assert!((memory_mb - expected).abs() < 1e-9);
    }

    #[test]
    fn test_use_case_specific_recommendations() -> Result<()> {
        let selector = ModelSelector::new();
        let characteristics = medium_dataset();
        let link_pred_recs =
            selector.recommend_models(&characteristics, UseCaseType::LinkPrediction)?;
        let similarity_recs =
            selector.recommend_models(&characteristics, UseCaseType::SimilaritySearch)?;
        assert!(!link_pred_recs.is_empty());
        assert!(!similarity_recs.is_empty());
        Ok(())
    }

    #[test]
    fn test_display_labels() {
        assert_eq!(ModelType::RotatE.to_string(), "RotatE");
        assert_eq!(TrainingTime::Medium.to_string(), "Medium (5-30 min)");
        assert_eq!(MemoryRequirement::Low.to_string(), "Low (< 500 MB)");
    }
}