use crate::kg_embeddings::{
ComplEx, DistMult, EmbeddingConfig, EmbeddingModel, TrainingStats, TransE,
};
use crate::{StarResult, StarTriple};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
use tracing::{debug, info, instrument, warn};
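/// Configuration for the embedding pipeline: model selection strategy,
/// cross-validation, hyperparameter search budget, early stopping, and
/// prediction caching.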
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PipelineConfig {
pub model_selection: ModelSelectionStrategy,
pub cv_folds: usize,
pub search_budget: usize,
pub early_stopping: bool,
pub early_stopping_patience: usize,
pub validation_split: f64,
pub random_seed: u64,
pub enable_caching: bool,
pub cache_size: usize,
}
impl Default for PipelineConfig {
fn default() -> Self {
Self {
model_selection: ModelSelectionStrategy::BestValidation,
cv_folds: 5,
search_budget: 20,
early_stopping: true,
early_stopping_patience: 5,
validation_split: 0.2,
random_seed: 42,
enable_caching: true,
cache_size: 10000,
}
}
}
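/// Strategy for picking the winning model in
/// [`EmbeddingPipeline::fit_and_select`]. `Force*` variants restrict the
/// search to a single model type; `Ensemble` and `Force*` currently score
/// candidates by MRR (see [`EvaluationMetrics::get_score`]).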
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum ModelSelectionStrategy {
BestValidation,
BestMRR,
BestHits10,
Ensemble,
ForceTransE,
ForceDistMult,
ForceComplEx,
}
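/// The knowledge-graph embedding model families supported by the pipeline.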
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub enum ModelType {
TransE,
DistMult,
ComplEx,
}
impl std::fmt::Display for ModelType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ModelType::TransE => write!(f, "TransE"),
ModelType::DistMult => write!(f, "DistMult"),
ModelType::ComplEx => write!(f, "ComplEx"),
}
}
}
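/// Link-prediction quality metrics: MRR in `[0, 1]`, `hits_at_*` as
/// percentages in `[0, 100]`, mean rank, and validation loss.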
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EvaluationMetrics {
pub mrr: f64,
pub hits_at_1: f64,
pub hits_at_3: f64,
pub hits_at_10: f64,
pub mean_rank: f64,
pub validation_loss: f64,
}
impl EvaluationMetrics {
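    /// Computes ranking metrics from `(correct_index, ranked_candidates)`
    /// pairs, where `ranked_candidates` is sorted best-first as
    /// `(entity, score)`. A candidate counts as correct when its entity
    /// label contains the string form of `correct_index` (e.g. `"entity_0"`
    /// matches index `0`). `validation_loss` cannot be derived from
    /// rankings and is set to `0.0`.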
pub fn from_predictions(predictions: &[(usize, Vec<(String, f64)>)]) -> Self {
let mut reciprocal_ranks = Vec::new();
let mut ranks = Vec::new();
let mut hits_1 = 0;
let mut hits_3 = 0;
let mut hits_10 = 0;
for (correct_idx, predicted) in predictions {
            // Find the rank of the correct entity within the ranked candidates.
            if let Some(rank) = predicted
                .iter()
                .position(|(entity, _)| entity.contains(&correct_idx.to_string()))
            {
let rank_val = (rank + 1) as f64;
reciprocal_ranks.push(1.0 / rank_val);
ranks.push(rank_val);
if rank < 1 {
hits_1 += 1;
}
if rank < 3 {
hits_3 += 1;
}
if rank < 10 {
hits_10 += 1;
}
}
}
        // Guard against an empty prediction set to avoid NaN metrics.
        let n = predictions.len().max(1) as f64;
        Self {
            mrr: reciprocal_ranks.iter().sum::<f64>() / n,
            hits_at_1: (hits_1 as f64 / n) * 100.0,
            hits_at_3: (hits_3 as f64 / n) * 100.0,
            hits_at_10: (hits_10 as f64 / n) * 100.0,
            mean_rank: ranks.iter().sum::<f64>() / n,
            // Loss cannot be derived from rankings; callers set it after training.
            validation_loss: 0.0,
        }
}
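    /// Returns a scalar score for this metric set under the given strategy;
    /// higher is always better (validation loss is negated).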
pub fn get_score(&self, strategy: &ModelSelectionStrategy) -> f64 {
match strategy {
ModelSelectionStrategy::BestValidation => -self.validation_loss,
ModelSelectionStrategy::BestMRR => self.mrr,
ModelSelectionStrategy::BestHits10 => self.hits_at_10,
            // Ensemble and Force* strategies fall back to MRR.
            _ => self.mrr,
        }
}
}
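/// The outcome of training and evaluating one model under one configuration.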
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelEvaluation {
pub model_type: ModelType,
pub config: EmbeddingConfig,
pub training_stats: TrainingStats,
pub metrics: EvaluationMetrics,
pub hyperparameters: HashMap<String, String>,
}
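/// A ranked list of `(entity, score)` predictions, best first.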
type PredictionResult = Vec<(String, f64)>;
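/// End-to-end pipeline for knowledge-graph embedding: hyperparameter
/// search, model selection, cross-validation, and prediction caching.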
pub struct EmbeddingPipeline {
config: PipelineConfig,
model_registry: Arc<RwLock<HashMap<String, Box<dyn EmbeddingModel>>>>,
prediction_cache: Arc<RwLock<HashMap<String, PredictionResult>>>,
best_model_metadata: Arc<RwLock<Option<ModelEvaluation>>>,
}
impl EmbeddingPipeline {
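    /// Creates a pipeline with the given configuration and an empty model
    /// registry and prediction cache.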
pub fn new(config: PipelineConfig) -> Self {
Self {
config,
model_registry: Arc::new(RwLock::new(HashMap::new())),
prediction_cache: Arc::new(RwLock::new(HashMap::new())),
best_model_metadata: Arc::new(RwLock::new(None)),
}
}
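    /// Builds a grid of candidate configurations over embedding dimension,
    /// learning rate, and margin, truncated to `search_budget` entries.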
fn generate_hyperparam_configs(&self) -> Vec<EmbeddingConfig> {
let mut configs = Vec::new();
let embedding_dims = vec![64, 128, 256];
let learning_rates = vec![0.001, 0.01, 0.05];
let margins = vec![0.5, 1.0, 2.0];
for &dim in &embedding_dims {
for &lr in &learning_rates {
for &margin in &margins {
configs.push(EmbeddingConfig {
embedding_dim: dim,
learning_rate: lr,
margin,
batch_size: 128,
num_negative_samples: 10,
use_gpu: false,
enable_rdfstar_context: true,
l2_reg: 0.0001,
});
if configs.len() >= self.config.search_budget {
return configs;
}
}
}
}
configs
}
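    /// Trains `model` for a fixed number of epochs and packages the result
    /// as a [`ModelEvaluation`]. Ranking metrics are currently placeholders;
    /// only `validation_loss` reflects the actual training run.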
#[instrument(skip(self, model, train_triples, _val_triples))]
async fn train_and_evaluate(
&self,
model: &mut dyn EmbeddingModel,
model_type: ModelType,
config: &EmbeddingConfig,
train_triples: &[StarTriple],
_val_triples: &[StarTriple],
) -> StarResult<ModelEvaluation> {
info!("Training {} model with config: {:?}", model_type, config);
        let epochs = 100;
        let training_stats = model.train(train_triples, epochs)?;
        let validation_loss = training_stats.final_loss;
        // TODO: placeholder ranking metrics. A real link-prediction
        // evaluation over `_val_triples` should replace these fixed values.
        let metrics = EvaluationMetrics {
            mrr: 0.75,
            hits_at_1: 65.0,
            hits_at_3: 80.0,
            hits_at_10: 90.0,
            mean_rank: 10.0,
            validation_loss,
};
let mut hyperparameters = HashMap::new();
hyperparameters.insert(
"embedding_dim".to_string(),
config.embedding_dim.to_string(),
);
hyperparameters.insert(
"learning_rate".to_string(),
config.learning_rate.to_string(),
);
hyperparameters.insert("margin".to_string(), config.margin.to_string());
Ok(ModelEvaluation {
model_type,
config: config.clone(),
training_stats,
metrics,
hyperparameters,
})
}
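    /// Runs hyperparameter search across all candidate model types and
    /// returns the evaluation of the best model found.
    ///
    /// Every successfully trained model is stored in the internal registry
    /// under a `"{model_type}-{config_index}"` key, and the winner's
    /// metadata is cached for later retrieval via [`Self::best_model`].
    ///
    /// # Example
    ///
    /// A minimal usage sketch (marked `ignore` since it assumes pre-split
    /// training and validation triples):
    ///
    /// ```ignore
    /// let pipeline = EmbeddingPipeline::new(PipelineConfig::default());
    /// let best = pipeline.fit_and_select(&train_triples, &val_triples).await?;
    /// println!("Best: {} (MRR {:.4})", best.model_type, best.metrics.mrr);
    /// ```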
#[instrument(skip(self, train_triples, val_triples))]
pub async fn fit_and_select(
&self,
train_triples: &[StarTriple],
val_triples: &[StarTriple],
) -> StarResult<ModelEvaluation> {
info!("Starting model selection and hyperparameter tuning");
let mut best_evaluation: Option<ModelEvaluation> = None;
let mut best_score = f64::NEG_INFINITY;
let configs = self.generate_hyperparam_configs();
info!("Generated {} hyperparameter configurations", configs.len());
let model_types = match &self.config.model_selection {
ModelSelectionStrategy::ForceTransE => vec![ModelType::TransE],
ModelSelectionStrategy::ForceDistMult => vec![ModelType::DistMult],
ModelSelectionStrategy::ForceComplEx => vec![ModelType::ComplEx],
_ => vec![ModelType::TransE, ModelType::DistMult, ModelType::ComplEx],
};
for model_type in model_types {
for (i, config) in configs.iter().enumerate() {
debug!(
"Evaluating {}/{}: {} with config {}",
i + 1,
configs.len(),
model_type,
i
);
let mut model: Box<dyn EmbeddingModel> = match model_type {
ModelType::TransE => Box::new(TransE::new(config.clone())),
ModelType::DistMult => Box::new(DistMult::new(config.clone())),
ModelType::ComplEx => Box::new(ComplEx::new(config.clone())),
};
match self
.train_and_evaluate(
model.as_mut(),
model_type,
config,
train_triples,
val_triples,
)
.await
{
Ok(evaluation) => {
let score = evaluation.metrics.get_score(&self.config.model_selection);
info!(
"{} evaluation: MRR={:.4}, Hits@10={:.2}%, Loss={:.4}, Score={:.4}",
model_type,
evaluation.metrics.mrr,
evaluation.metrics.hits_at_10,
evaluation.metrics.validation_loss,
score
);
if score > best_score {
best_score = score;
best_evaluation = Some(evaluation.clone());
info!("New best model: {} with score {:.4}", model_type, score);
}
let model_id = format!("{}-{}", model_type, i);
self.model_registry.write().await.insert(model_id, model);
}
Err(e) => {
warn!("Failed to evaluate {} with config {}: {}", model_type, i, e);
}
}
}
}
if let Some(best) = &best_evaluation {
*self.best_model_metadata.write().await = Some(best.clone());
info!(
"Model selection complete. Best model: {} with MRR={:.4}",
best.model_type, best.metrics.mrr
);
}
best_evaluation.ok_or_else(|| crate::StarError::QueryError {
message: "No successful model evaluation found".to_string(),
query_fragment: None,
position: None,
suggestion: Some("Check training data and configurations".to_string()),
})
}
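    /// Runs `cv_folds`-fold cross-validation for one model type and
    /// configuration, returning per-fold metrics. Folds are contiguous
    /// slices and the last fold absorbs the remainder, so the input should
    /// contain at least `cv_folds` triples.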
#[instrument(skip(self, triples))]
pub async fn cross_validate(
&self,
triples: &[StarTriple],
model_type: ModelType,
config: &EmbeddingConfig,
) -> StarResult<Vec<EvaluationMetrics>> {
info!(
"Performing {}-fold cross-validation for {}",
self.config.cv_folds, model_type
);
let fold_size = triples.len() / self.config.cv_folds;
let mut fold_metrics = Vec::new();
for fold in 0..self.config.cv_folds {
let start = fold * fold_size;
let end = if fold == self.config.cv_folds - 1 {
triples.len()
} else {
(fold + 1) * fold_size
};
let val_triples = &triples[start..end];
let mut train_triples = Vec::new();
train_triples.extend_from_slice(&triples[0..start]);
train_triples.extend_from_slice(&triples[end..]);
let mut model: Box<dyn EmbeddingModel> = match model_type {
ModelType::TransE => Box::new(TransE::new(config.clone())),
ModelType::DistMult => Box::new(DistMult::new(config.clone())),
ModelType::ComplEx => Box::new(ComplEx::new(config.clone())),
};
let evaluation = self
.train_and_evaluate(
model.as_mut(),
model_type,
config,
&train_triples,
val_triples,
)
.await?;
info!(
"Fold {}/{}: MRR={:.4}, Hits@10={:.2}%",
fold + 1,
self.config.cv_folds,
evaluation.metrics.mrr,
evaluation.metrics.hits_at_10
);
fold_metrics.push(evaluation.metrics);
}
info!(
"Cross-validation complete. Average MRR: {:.4}",
fold_metrics.iter().map(|m| m.mrr).sum::<f64>() / fold_metrics.len() as f64
);
Ok(fold_metrics)
}
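    /// Returns metadata for the best model from the most recent
    /// [`Self::fit_and_select`] run, if one has completed.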
pub async fn best_model(&self) -> Option<ModelEvaluation> {
self.best_model_metadata.read().await.clone()
}
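    /// Looks up a trained model from the registry by its `"{type}-{index}"` id.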
pub async fn get_model(&self, _model_id: &str) -> Option<Arc<dyn EmbeddingModel>> {
        // Not yet implemented: the registry owns models as
        // `Box<dyn EmbeddingModel>`, which cannot be handed out as a shared
        // `Arc`; always returns `None` for now.
        None
    }
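    /// Empties the prediction cache.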
pub async fn clear_cache(&self) {
self.prediction_cache.write().await.clear();
info!("Prediction cache cleared");
}
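    /// Returns `(current_entries, configured_capacity)` for the prediction cache.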
pub async fn cache_stats(&self) -> (usize, usize) {
let cache = self.prediction_cache.read().await;
(cache.len(), self.config.cache_size)
}
}
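/// Extracts numeric feature vectors from triples for downstream models.
/// Feature computation is currently stubbed out.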
pub struct FeatureExtractor {
#[allow(dead_code)]
include_graph_structure: bool,
#[allow(dead_code)]
include_triple_patterns: bool,
}
impl FeatureExtractor {
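    /// Creates an extractor with all feature groups enabled.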
pub fn new() -> Self {
Self {
include_graph_structure: true,
include_triple_patterns: true,
}
}
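    /// Extracts a feature vector from a single triple.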
pub fn extract_features(&self, _triple: &StarTriple) -> Vec<f64> {
        // Placeholder: a single constant feature until graph-structure and
        // triple-pattern features are implemented.
        vec![1.0]
    }
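    /// Extracts feature vectors for a batch of triples.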
pub fn extract_batch(&self, triples: &[StarTriple]) -> Vec<Vec<f64>> {
triples.iter().map(|t| self.extract_features(t)).collect()
}
}
impl Default for FeatureExtractor {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{model::NamedNode, StarTerm};
fn create_test_triples() -> Vec<StarTriple> {
vec![
StarTriple {
subject: StarTerm::NamedNode(NamedNode {
iri: "Alice".to_string(),
}),
predicate: StarTerm::NamedNode(NamedNode {
iri: "knows".to_string(),
}),
object: StarTerm::NamedNode(NamedNode {
iri: "Bob".to_string(),
}),
},
StarTriple {
subject: StarTerm::NamedNode(NamedNode {
iri: "Bob".to_string(),
}),
predicate: StarTerm::NamedNode(NamedNode {
iri: "knows".to_string(),
}),
object: StarTerm::NamedNode(NamedNode {
iri: "Charlie".to_string(),
}),
},
]
}
#[test]
fn test_pipeline_config_default() {
let config = PipelineConfig::default();
assert_eq!(config.cv_folds, 5);
assert_eq!(config.search_budget, 20);
assert!(config.early_stopping);
}
#[test]
fn test_model_type_display() {
assert_eq!(ModelType::TransE.to_string(), "TransE");
assert_eq!(ModelType::DistMult.to_string(), "DistMult");
assert_eq!(ModelType::ComplEx.to_string(), "ComplEx");
}
#[test]
fn test_evaluation_metrics_from_predictions() {
let predictions = vec![
(0, vec![("entity_0".to_string(), 0.9)]),
(1, vec![("entity_1".to_string(), 0.8)]),
];
let metrics = EvaluationMetrics::from_predictions(&predictions);
assert!(metrics.mrr >= 0.0 && metrics.mrr <= 1.0);
}
#[test]
fn test_feature_extractor() {
let extractor = FeatureExtractor::new();
let triples = create_test_triples();
let features = extractor.extract_features(&triples[0]);
assert!(!features.is_empty());
}
#[tokio::test]
async fn test_pipeline_creation() {
let config = PipelineConfig::default();
let pipeline = EmbeddingPipeline::new(config);
assert!(pipeline.best_model().await.is_none());
}
#[tokio::test]
async fn test_cache_management() {
let config = PipelineConfig::default();
let pipeline = EmbeddingPipeline::new(config);
let (size, capacity) = pipeline.cache_stats().await;
assert_eq!(size, 0);
assert_eq!(capacity, 10000);
pipeline.clear_cache().await;
let (size_after, _) = pipeline.cache_stats().await;
assert_eq!(size_after, 0);
}
#[test]
fn test_hyperparam_generation() {
let config = PipelineConfig {
search_budget: 10,
..Default::default()
};
let pipeline = EmbeddingPipeline::new(config);
let configs = pipeline.generate_hyperparam_configs();
assert!(configs.len() <= 10);
assert!(!configs.is_empty());
}
#[test]
fn test_model_selection_strategies() {
assert_eq!(
ModelSelectionStrategy::BestValidation,
ModelSelectionStrategy::BestValidation
);
assert_ne!(
ModelSelectionStrategy::BestMRR,
ModelSelectionStrategy::BestHits10
);
}
}