pub mod box_trainer;
pub mod cone_trainer;
pub mod evaluation;
pub mod negative_sampling;
#[cfg(feature = "candle-backend")]
#[cfg_attr(docsrs, doc(cfg(feature = "candle-backend")))]
pub mod candle_trainer;
#[cfg(feature = "candle-backend")]
#[cfg_attr(docsrs, doc(cfg(feature = "candle-backend")))]
pub mod candle_el_trainer;
#[cfg(feature = "candle-backend")]
#[cfg_attr(docsrs, doc(cfg(feature = "candle-backend")))]
pub mod candle_cone_trainer;
use crate::BoxError;
/// Relation-specific geometric transform applied to a box's bounds.
///
/// Serializable (serde) so learned transforms can be persisted alongside
/// the embeddings they belong to.
#[derive(Debug, Clone, Default, PartialEq, serde::Serialize, serde::Deserialize)]
pub enum RelationTransform {
/// No-op transform: bounds pass through unchanged. This is the default.
#[default]
Identity,
/// Element-wise translation: the offset is added to both the min and max
/// corners. Offset length must match the bounds' dimensionality (checked
/// in debug builds only — see `apply_to_bounds`).
Translation(Vec<f32>),
}
impl RelationTransform {
    /// Returns `true` for the no-op [`RelationTransform::Identity`] variant.
    #[inline]
    pub fn is_identity(&self) -> bool {
        match self {
            RelationTransform::Identity => true,
            RelationTransform::Translation(_) => false,
        }
    }

    /// Applies this transform to a box given by its `min`/`max` corner
    /// slices, returning the transformed corners as owned vectors.
    ///
    /// Length agreement between `min`, `max`, and a translation offset is
    /// asserted in debug builds only; in release builds a mismatched offset
    /// silently truncates to the shorter length (zip semantics).
    pub fn apply_to_bounds(&self, min: &[f32], max: &[f32]) -> (Vec<f32>, Vec<f32>) {
        debug_assert_eq!(min.len(), max.len(), "min/max length mismatch");
        match self {
            RelationTransform::Identity => (min.to_vec(), max.to_vec()),
            RelationTransform::Translation(offset) => {
                debug_assert_eq!(
                    offset.len(),
                    min.len(),
                    "Translation offset length ({}) != bounds length ({})",
                    offset.len(),
                    min.len()
                );
                // Same shift is applied to both corners, so share the code.
                let shifted = |corner: &[f32]| -> Vec<f32> {
                    corner
                        .iter()
                        .zip(offset.iter())
                        .map(|(c, d)| c + d)
                        .collect()
                };
                (shifted(min), shifted(max))
            }
        }
    }
}
/// Strategy for generating negative (corrupted) triples during training.
///
/// Derives `PartialEq`/`Eq` for consistency with [`RelationTransform`], so
/// strategies (and configs containing them) can be compared in tests and
/// configuration handling. This is a backward-compatible addition.
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum NegativeSamplingStrategy {
    /// Sample corrupted entities uniformly.
    Uniform,
    /// Replace the head entity of a positive triple.
    CorruptHead,
    /// Replace the tail entity of a positive triple.
    CorruptTail,
    /// Replace both the head and the tail entity.
    CorruptBoth,
}
/// Hyperparameters controlling embedding training.
///
/// Numeric fields are sanity-checked by [`TrainingConfig::validate`];
/// [`Default`] provides a workable baseline configuration.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct TrainingConfig {
/// Optimizer step size; must be positive and finite (see `validate`).
pub learning_rate: f32,
/// Number of training epochs to run.
pub epochs: usize,
/// Triples per mini-batch; must be > 0 (see `validate`).
pub batch_size: usize,
/// Negative samples generated per positive triple; must be > 0.
pub negative_samples: usize,
/// How negative triples are generated (see [`NegativeSamplingStrategy`]).
pub negative_strategy: NegativeSamplingStrategy,
/// Margin for the ranking loss; must be non-negative and finite.
pub margin: f32,
/// Stop after this many epochs without improvement; `None` disables
/// early stopping.
pub early_stopping_patience: Option<usize>,
/// Minimum improvement that counts as progress for early stopping.
pub early_stopping_min_delta: f32,
/// Regularization coefficient (0 disables the term).
pub regularization: f32,
/// Warmup epoch count — presumably used for LR or beta scheduling in the
/// trainers; confirm against the trainer implementations.
pub warmup_epochs: usize,
/// Relative weight of the negative-sample loss term.
pub negative_weight: f32,
/// Softplus sharpness parameter; must be positive and finite. The serde
/// alias `gumbel_beta` keeps older serialized configs loadable.
#[serde(alias = "gumbel_beta")]
pub softplus_beta: f32,
/// Final softplus beta — presumably annealed to from `softplus_beta`
/// over training; confirm in the trainers.
#[serde(alias = "gumbel_beta_final")]
pub softplus_beta_final: f32,
/// Gradient clipping threshold; must be positive and finite.
pub max_grad_norm: f32,
/// Temperature for self-adversarial negative weighting.
pub adversarial_temperature: f32,
/// If set, use an InfoNCE-style loss instead of the margin loss —
/// confirm exact semantics in the trainers.
pub use_infonce: bool,
/// Defaults to `false` when absent from serialized configs.
#[serde(default)]
pub symmetric_loss: bool,
/// Enable self-adversarial negative weighting; serde-defaults to `false`.
#[serde(default)]
pub self_adversarial: bool,
/// Enable Bernoulli head/tail corruption; serde-defaults to `false`.
#[serde(default)]
pub bernoulli_sampling: bool,
}
impl Default for TrainingConfig {
fn default() -> Self {
Self {
learning_rate: 1e-3, epochs: 100,
batch_size: 512, negative_samples: 1,
negative_strategy: NegativeSamplingStrategy::CorruptTail,
margin: 1.0, early_stopping_patience: Some(10), early_stopping_min_delta: 0.001,
regularization: 0.0001,
warmup_epochs: 10,
negative_weight: 1.0,
softplus_beta: 10.0,
softplus_beta_final: 50.0,
max_grad_norm: 10.0,
adversarial_temperature: 1.0,
use_infonce: false,
symmetric_loss: false,
self_adversarial: false,
bernoulli_sampling: false,
}
}
}
impl TrainingConfig {
    /// Validates numeric hyperparameters, returning the first violation.
    ///
    /// Rates, betas, norms, and temperatures must be positive and finite;
    /// counts must be non-zero; margin-like weights must be non-negative
    /// and finite. All [`Default`] values pass.
    ///
    /// # Errors
    ///
    /// Returns [`BoxError::Internal`] with a message naming the offending
    /// field and its value.
    pub fn validate(&self) -> Result<(), BoxError> {
        // Field must be finite and strictly positive.
        fn positive(name: &str, value: f32) -> Result<(), BoxError> {
            if !value.is_finite() || value <= 0.0 {
                return Err(BoxError::Internal(format!(
                    "{} must be positive and finite, got {}",
                    name, value
                )));
            }
            Ok(())
        }
        // Field must be finite and >= 0.
        fn non_negative(name: &str, value: f32) -> Result<(), BoxError> {
            if !value.is_finite() || value < 0.0 {
                return Err(BoxError::Internal(format!(
                    "{} must be non-negative and finite, got {}",
                    name, value
                )));
            }
            Ok(())
        }

        positive("learning_rate", self.learning_rate)?;
        if self.batch_size == 0 {
            return Err(BoxError::Internal("batch_size must be > 0".to_string()));
        }
        if self.negative_samples == 0 {
            return Err(BoxError::Internal(
                "negative_samples must be > 0".to_string(),
            ));
        }
        non_negative("margin", self.margin)?;
        // Previously unchecked fields that can silently poison training if
        // NaN or negative.
        non_negative("regularization", self.regularization)?;
        non_negative("negative_weight", self.negative_weight)?;
        positive("softplus_beta", self.softplus_beta)?;
        // Consistency fix: softplus_beta_final was accepted unchecked even
        // though softplus_beta was validated.
        positive("softplus_beta_final", self.softplus_beta_final)?;
        positive("max_grad_norm", self.max_grad_norm)?;
        // A zero/negative/NaN temperature would break self-adversarial
        // weighting — NOTE(review): confirm its use as a divisor/scale in
        // the trainer implementations.
        positive("adversarial_temperature", self.adversarial_temperature)?;
        Ok(())
    }
}
/// Aggregate link-prediction metrics produced by the evaluation functions
/// in [`evaluation`].
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct EvaluationResults {
/// Mean reciprocal rank over all queries.
pub mrr: f32,
/// MRR restricted to head-prediction queries.
pub head_mrr: f32,
/// MRR restricted to tail-prediction queries.
pub tail_mrr: f32,
/// Fraction of queries where the true entity ranks first.
pub hits_at_1: f32,
/// Fraction of queries where the true entity ranks in the top 3.
pub hits_at_3: f32,
/// Fraction of queries where the true entity ranks in the top 10.
pub hits_at_10: f32,
/// Arithmetic mean of the true entities' ranks.
pub mean_rank: f32,
/// Metrics broken down per relation.
pub per_relation: Vec<PerRelationResults>,
}
/// Link-prediction metrics for a single relation.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct PerRelationResults {
/// Relation label this entry summarizes.
pub relation: String,
/// Mean reciprocal rank for queries involving this relation.
pub mrr: f32,
/// Hits@10 for queries involving this relation.
pub hits_at_10: f32,
/// Number of evaluated items for this relation — presumably test triples;
/// confirm in the evaluation module.
pub count: usize,
}
/// Outcome of a complete training run.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct TrainingResult {
/// Evaluation metrics of the resulting model — whether this is the final
/// or best-epoch model depends on the trainer; confirm there.
pub final_results: EvaluationResults,
/// Training loss recorded per epoch.
pub loss_history: Vec<f32>,
/// Validation MRR recorded per epoch.
pub validation_mrr_history: Vec<f32>,
/// Epoch index that achieved the best validation metric.
pub best_epoch: usize,
/// Wall-clock training time in seconds, when measured.
pub training_time_seconds: Option<f64>,
}
pub use negative_sampling::{compute_relation_cardinalities, RelationCardinality};
pub use evaluation::{
evaluate_link_prediction, evaluate_link_prediction_filtered, evaluate_link_prediction_interned,
evaluate_link_prediction_interned_filtered, FilteredTripleIndex, FilteredTripleIndexIds,
};
#[cfg(feature = "ndarray-backend")]
#[cfg_attr(docsrs, doc(cfg(feature = "ndarray-backend")))]
pub use evaluation::evaluate_link_prediction_interned_with_transforms;
pub use box_trainer::{compute_analytical_gradients, compute_pair_loss, BoxEmbeddingTrainer};
pub use cone_trainer::{
compute_cone_analytical_gradients, compute_cone_pair_loss, ConeEmbeddingTrainer,
};