use anyhow::{anyhow, Result};
use serde::{Deserialize, Serialize};
use crate::{EmbeddingModel, Triple};
/// Cosine similarity between two equal-length vectors, clamped to [-1.0, 1.0].
///
/// Returns 0.0 when the slices differ in length, are empty, or either has
/// zero magnitude (similarity is undefined for a zero vector).
fn cosine_sim(a: &[f64], b: &[f64]) -> f64 {
    if a.len() != b.len() || a.is_empty() {
        return 0.0;
    }
    let mut dot = 0.0_f64;
    let mut sq_a = 0.0_f64;
    let mut sq_b = 0.0_f64;
    for (&x, &y) in a.iter().zip(b) {
        dot += x * y;
        sq_a += x * x;
        sq_b += y * y;
    }
    let norm_a = sq_a.sqrt();
    let norm_b = sq_b.sqrt();
    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }
    // Clamp to guard against floating-point drift just past ±1.
    (dot / (norm_a * norm_b)).clamp(-1.0, 1.0)
}
/// Evaluates link-prediction quality (MRR, Hits@K) and embedding similarity
/// for any [`EmbeddingModel`]. Stateless: `new()`/`Default` are free.
#[derive(Debug, Clone, Default)]
pub struct EmbeddingEvaluator;
impl EmbeddingEvaluator {
    /// Creates a new (stateless) evaluator.
    pub fn new() -> Self {
        Self
    }

    /// 1-based rank of the gold tail of `triple` among `entities`, ranked by
    /// `score_triple(head, rel, candidate)` in descending order. Candidates
    /// whose scoring fails are skipped; returns `None` when the gold tail is
    /// absent from the scored list.
    fn gold_tail_rank(
        model: &dyn EmbeddingModel,
        entities: &[String],
        triple: &Triple,
    ) -> Option<usize> {
        let head = &triple.subject.iri;
        let rel = &triple.predicate.iri;
        let tail = &triple.object.iri;
        // Borrow candidates instead of cloning each String just to rank them.
        let mut scored: Vec<(&String, f64)> = entities
            .iter()
            .filter_map(|cand| {
                model
                    .score_triple(head, rel, cand)
                    .ok()
                    .map(|s| (cand, s))
            })
            .collect();
        // Descending by score; NaN comparisons degrade to "equal" (stable-ish tie).
        scored.sort_unstable_by(|a, b| {
            b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)
        });
        scored
            .iter()
            .position(|(cand, _)| *cand == tail)
            .map(|pos| pos + 1)
    }

    /// Mean reciprocal rank of the gold tail over `test_triples`.
    ///
    /// Triples whose tail never appears in the ranking contribute 0.
    /// Returns 0.0 for an empty test set or a model with no entities.
    pub fn link_prediction_mrr(&self, model: &dyn EmbeddingModel, test_triples: &[Triple]) -> f64 {
        if test_triples.is_empty() {
            return 0.0;
        }
        let entities = model.get_entities();
        if entities.is_empty() {
            return 0.0;
        }
        let total: f64 = test_triples
            .iter()
            .map(|t| Self::gold_tail_rank(model, &entities, t).map_or(0.0, |r| 1.0 / r as f64))
            .sum();
        total / test_triples.len() as f64
    }

    /// Fraction of `test_triples` whose gold tail ranks within the top `k`.
    ///
    /// Returns 0.0 for an empty test set, `k == 0`, or a model with no
    /// entities.
    pub fn hits_at_k(&self, model: &dyn EmbeddingModel, test_triples: &[Triple], k: usize) -> f64 {
        if test_triples.is_empty() || k == 0 {
            return 0.0;
        }
        let entities = model.get_entities();
        if entities.is_empty() {
            return 0.0;
        }
        let hits = test_triples
            .iter()
            .filter(|t| Self::gold_tail_rank(model, &entities, t).map_or(false, |r| r <= k))
            .count();
        hits as f64 / test_triples.len() as f64
    }

    /// Cosine similarity of two embedding vectors (see [`cosine_sim`]).
    pub fn semantic_similarity(&self, emb1: &[f64], emb2: &[f64]) -> f64 {
        cosine_sim(emb1, emb2)
    }
}
/// A single analogy test case of the form "`a` is to `b` as `c` is to
/// `expected_d`" (e.g. man : king :: woman : queen). Serializable so
/// benchmark suites can be stored as JSON.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct AnalogyQuad {
    /// First term of the source pair.
    pub a: String,
    /// Second term of the source pair.
    pub b: String,
    /// First term of the target pair.
    pub c: String,
    /// The gold answer expected to complete the analogy.
    pub expected_d: String,
}
impl AnalogyQuad {
    /// Builds a quad from any string-like inputs (`&str`, `String`, ...).
    pub fn new(
        a: impl Into<String>,
        b: impl Into<String>,
        c: impl Into<String>,
        expected_d: impl Into<String>,
    ) -> Self {
        let (a, b, c, expected_d) = (a.into(), b.into(), c.into(), expected_d.into());
        Self { a, b, c, expected_d }
    }
}
/// Benchmark that scores a model's ability to solve analogies of the form
/// a : b :: c : ?, using the classic vector-offset rule (b - a + c) over
/// entity embeddings. Stateless.
#[derive(Debug, Clone, Default)]
pub struct AnalogicalReasoningBenchmark;
impl AnalogicalReasoningBenchmark {
    /// Creates a new (stateless) benchmark.
    pub fn new() -> Self {
        Self
    }

    /// Solves `a : b :: c : ?` via the offset vector `b - a + c` and reports
    /// whether the cosine-nearest entity — excluding `a`, `b` and `c`
    /// themselves — matches `expected_d`.
    ///
    /// Returns `false` when any of the three query embeddings is missing,
    /// their dimensions disagree (or are zero), or no candidate remains.
    pub fn evaluate_analogy(&self, model: &dyn EmbeddingModel, quad: &AnalogyQuad) -> bool {
        // Fetch an entity embedding, widening f32 components to f64.
        let fetch = |name: &str| -> Option<Vec<f64>> {
            model
                .get_entity_embedding(name)
                .ok()
                .map(|v| v.values.iter().map(|&x| x as f64).collect())
        };
        // Immediately-invoked closure so `?` can short-circuit the fetches
        // in the same a, b, c order as before.
        let verdict = (|| -> Option<bool> {
            let emb_a = fetch(&quad.a)?;
            let emb_b = fetch(&quad.b)?;
            let emb_c = fetch(&quad.c)?;
            let dim = emb_a.len();
            if dim == 0 || emb_b.len() != dim || emb_c.len() != dim {
                return None;
            }
            // Offset arithmetic: target = b - a + c.
            let target: Vec<f64> = emb_a
                .iter()
                .zip(&emb_b)
                .zip(&emb_c)
                .map(|((av, bv), cv)| bv - av + cv)
                .collect();
            // The three query terms are never acceptable answers.
            let skip = [quad.a.as_str(), quad.b.as_str(), quad.c.as_str()];
            model
                .get_entities()
                .into_iter()
                .filter(|e| !skip.contains(&e.as_str()))
                .filter_map(|cand| {
                    let emb = fetch(&cand)?;
                    Some((cand, cosine_sim(&target, &emb)))
                })
                .max_by(|x, y| x.1.partial_cmp(&y.1).unwrap_or(std::cmp::Ordering::Equal))
                .map(|(predicted, _)| predicted == quad.expected_d)
        })();
        verdict.unwrap_or(false)
    }

    /// Fraction of `quads` answered correctly; 0.0 for an empty slice.
    pub fn evaluate_analogies(&self, model: &dyn EmbeddingModel, quads: &[AnalogyQuad]) -> f64 {
        if quads.is_empty() {
            return 0.0;
        }
        let mut correct = 0usize;
        for quad in quads {
            if self.evaluate_analogy(model, quad) {
                correct += 1;
            }
        }
        correct as f64 / quads.len() as f64
    }
}
/// Clustering-quality metrics (silhouette score, Davies-Bouldin index)
/// computed directly over embedding vectors and their cluster assignments.
/// Stateless.
#[derive(Debug, Clone, Default)]
pub struct EmbeddingClusteringMetrics;
impl EmbeddingClusteringMetrics {
    /// Creates a new (stateless) metrics helper.
    pub fn new() -> Self {
        Self
    }

    /// Euclidean (L2) distance. Pairs components via `zip`, so any trailing
    /// components of the longer slice are silently ignored — callers are
    /// expected to pass equal-length vectors.
    fn euclidean_dist(a: &[f64], b: &[f64]) -> f64 {
        a.iter()
            .zip(b.iter())
            .map(|(x, y)| (x - y).powi(2))
            .sum::<f64>()
            .sqrt()
    }

    /// Component-wise mean of `points`; empty input yields an empty vector.
    /// Assumes every point has the dimension of the first one.
    fn cluster_centroid(points: &[&Vec<f64>]) -> Vec<f64> {
        if points.is_empty() {
            return Vec::new();
        }
        let dim = points[0].len();
        let n = points.len() as f64;
        let mut centroid = vec![0.0_f64; dim];
        for p in points {
            for (i, v) in p.iter().enumerate() {
                centroid[i] += v;
            }
        }
        centroid.iter_mut().for_each(|v| *v /= n);
        centroid
    }

    /// Mean silhouette coefficient over all samples.
    ///
    /// For each sample i: s(i) = (b(i) - a(i)) / max(a(i), b(i)), where a(i)
    /// is the mean intra-cluster distance and b(i) the smallest mean distance
    /// to any other cluster. Range [-1, 1]; higher means better separation.
    /// Zero-dimensional embeddings are skipped.
    ///
    /// # Errors
    /// Fails when the inputs differ in length, fewer than 2 samples are
    /// given, or fewer than 2 distinct clusters are present.
    pub fn silhouette_score(
        &self,
        embeddings: &[Vec<f64>],
        cluster_assignments: &[usize],
    ) -> Result<f64> {
        let n = embeddings.len();
        if n != cluster_assignments.len() {
            return Err(anyhow!(
                "embeddings ({}) and cluster_assignments ({}) must have the same length",
                n,
                cluster_assignments.len()
            ));
        }
        if n < 2 {
            return Err(anyhow!(
                "need at least 2 samples to compute silhouette score"
            ));
        }
        let unique_clusters: std::collections::HashSet<usize> =
            cluster_assignments.iter().copied().collect();
        if unique_clusters.len() < 2 {
            return Err(anyhow!(
                "need at least 2 distinct clusters; found {}",
                unique_clusters.len()
            ));
        }
        let mut scores = Vec::with_capacity(n);
        for i in 0..n {
            let c_i = cluster_assignments[i];
            let dim = embeddings[i].len();
            // A zero-dimensional embedding carries no information; skip it.
            if dim == 0 {
                continue;
            }
            // a(i): mean distance to the other members of i's own cluster.
            let intra_dists: Vec<f64> = (0..n)
                .filter(|&j| j != i && cluster_assignments[j] == c_i)
                .map(|j| Self::euclidean_dist(&embeddings[i], &embeddings[j]))
                .collect();
            let a_i = if intra_dists.is_empty() {
                // Singleton cluster: conventionally a(i) = 0.
                0.0
            } else {
                intra_dists.iter().sum::<f64>() / intra_dists.len() as f64
            };
            // b(i): smallest mean distance from i to any *other* cluster.
            let b_i = unique_clusters
                .iter()
                .filter(|&&c| c != c_i)
                .map(|&c| {
                    let dists: Vec<f64> = (0..n)
                        .filter(|&j| cluster_assignments[j] == c)
                        .map(|j| Self::euclidean_dist(&embeddings[i], &embeddings[j]))
                        .collect();
                    if dists.is_empty() {
                        f64::INFINITY
                    } else {
                        dists.iter().sum::<f64>() / dists.len() as f64
                    }
                })
                .fold(f64::INFINITY, f64::min);
            // Degenerate cases (no other reachable cluster, or both means
            // zero) contribute a neutral score of 0.
            let s_i = if b_i == f64::INFINITY || (a_i == 0.0 && b_i == 0.0) {
                0.0
            } else {
                let denom = a_i.max(b_i);
                if denom == 0.0 {
                    0.0
                } else {
                    (b_i - a_i) / denom
                }
            };
            scores.push(s_i);
        }
        if scores.is_empty() {
            return Ok(0.0);
        }
        Ok(scores.iter().sum::<f64>() / scores.len() as f64)
    }

    /// Davies-Bouldin index: the mean over clusters of the worst-case ratio
    /// (S_i + S_j) / d(c_i, c_j), where S is the mean distance of a cluster's
    /// points to its centroid. Lower is better (0 is ideal).
    ///
    /// Coincident centroids (distance 0) contribute a ratio of 0 rather than
    /// infinity, keeping the index finite for degenerate inputs.
    ///
    /// # Errors
    /// Fails when the inputs differ in length, fewer than 2 samples are
    /// given, or fewer than 2 distinct clusters are present.
    pub fn davies_bouldin_index(
        &self,
        embeddings: &[Vec<f64>],
        cluster_assignments: &[usize],
    ) -> Result<f64> {
        let n = embeddings.len();
        if n != cluster_assignments.len() {
            return Err(anyhow!(
                "embeddings ({}) and cluster_assignments ({}) must have the same length",
                n,
                cluster_assignments.len()
            ));
        }
        if n < 2 {
            return Err(anyhow!(
                "need at least 2 samples to compute Davies-Bouldin index"
            ));
        }
        let mut cluster_ids: Vec<usize> = cluster_assignments.to_vec();
        cluster_ids.sort_unstable();
        cluster_ids.dedup();
        if cluster_ids.len() < 2 {
            return Err(anyhow!(
                "need at least 2 distinct clusters; found {}",
                cluster_ids.len()
            ));
        }
        // Partition the points by cluster id (borrowed, no copies).
        let cluster_points: std::collections::HashMap<usize, Vec<&Vec<f64>>> = cluster_ids
            .iter()
            .map(|&cid| {
                let pts: Vec<&Vec<f64>> = embeddings
                    .iter()
                    .zip(cluster_assignments.iter())
                    .filter(|(_, &a)| a == cid)
                    .map(|(e, _)| e)
                    .collect();
                (cid, pts)
            })
            .collect();
        let centroids: std::collections::HashMap<usize, Vec<f64>> = cluster_ids
            .iter()
            .map(|&cid| (cid, Self::cluster_centroid(&cluster_points[&cid])))
            .collect();
        // S_c: mean distance of cluster c's points to its centroid.
        let scatter: std::collections::HashMap<usize, f64> = cluster_ids
            .iter()
            .map(|&cid| {
                let pts = &cluster_points[&cid];
                let c = &centroids[&cid];
                let s = if pts.is_empty() {
                    0.0
                } else {
                    let total: f64 = pts.iter().map(|p| Self::euclidean_dist(p, c)).sum();
                    total / pts.len() as f64
                };
                (cid, s)
            })
            .collect();
        let k = cluster_ids.len() as f64;
        let db_sum: f64 = cluster_ids
            .iter()
            .map(|&ci| {
                // R_i = max over j != i of (S_i + S_j) / d(c_i, c_j).
                let max_r = cluster_ids
                    .iter()
                    .filter(|&&cj| cj != ci)
                    .map(|&cj| {
                        let d = Self::euclidean_dist(&centroids[&ci], &centroids[&cj]);
                        if d == 0.0 {
                            0.0
                        } else {
                            (scatter[&ci] + scatter[&cj]) / d
                        }
                    })
                    .fold(f64::NEG_INFINITY, f64::max);
                if max_r == f64::NEG_INFINITY {
                    0.0
                } else {
                    max_r
                }
            })
            .sum();
        Ok(db_sum / k)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{ModelConfig, ModelStats, NamedNode, TrainingStats, Vector};
    use anyhow::Result as AResult;
    use async_trait::async_trait;
    use uuid::Uuid;

    /// Minimal deterministic model for exercising the evaluators:
    /// `score_triple` ranks a tail candidate by its position in the entity
    /// list (earlier = higher score), and entity embeddings are derived from
    /// the first four bytes of the entity name.
    struct MockModel {
        entities: Vec<String>,
        id: Uuid,
        config: ModelConfig,
    }

    impl MockModel {
        fn new(entities: Vec<String>) -> Self {
            Self {
                entities,
                id: Uuid::new_v4(),
                config: ModelConfig::default(),
            }
        }
    }

    #[async_trait]
    impl EmbeddingModel for MockModel {
        fn config(&self) -> &ModelConfig {
            &self.config
        }
        fn model_id(&self) -> &Uuid {
            &self.id
        }
        fn model_type(&self) -> &'static str {
            "mock"
        }
        fn add_triple(&mut self, _triple: Triple) -> AResult<()> {
            Ok(())
        }
        async fn train(&mut self, _epochs: Option<usize>) -> AResult<TrainingStats> {
            Ok(TrainingStats::default())
        }
        fn get_entity_embedding(&self, entity: &str) -> AResult<Vector> {
            let v: Vec<f32> = entity
                .bytes()
                .take(4)
                .enumerate()
                .map(|(i, b)| (b as f32 + i as f32) / 256.0)
                .collect();
            Ok(Vector::new(v))
        }
        fn get_relation_embedding(&self, _rel: &str) -> AResult<Vector> {
            Ok(Vector::new(vec![0.1, 0.2]))
        }
        fn score_triple(&self, _h: &str, _r: &str, t: &str) -> AResult<f64> {
            let score = self
                .entities
                .iter()
                .position(|e| e == t)
                .map(|pos| 1.0 / (pos + 1) as f64)
                .unwrap_or(0.0);
            Ok(score)
        }
        fn predict_objects(&self, _s: &str, _p: &str, k: usize) -> AResult<Vec<(String, f64)>> {
            Ok(self
                .entities
                .iter()
                .take(k)
                .map(|e| (e.clone(), 1.0))
                .collect())
        }
        fn predict_subjects(&self, _p: &str, _o: &str, k: usize) -> AResult<Vec<(String, f64)>> {
            Ok(self
                .entities
                .iter()
                .take(k)
                .map(|e| (e.clone(), 1.0))
                .collect())
        }
        fn predict_relations(&self, _s: &str, _o: &str, k: usize) -> AResult<Vec<(String, f64)>> {
            Ok(self
                .entities
                .iter()
                .take(k)
                .map(|e| (e.clone(), 1.0))
                .collect())
        }
        fn get_entities(&self) -> Vec<String> {
            self.entities.clone()
        }
        fn get_relations(&self) -> Vec<String> {
            vec!["rel".to_string()]
        }
        fn get_stats(&self) -> ModelStats {
            ModelStats::default()
        }
        fn save(&self, _path: &str) -> AResult<()> {
            Ok(())
        }
        fn load(&mut self, _path: &str) -> AResult<()> {
            Ok(())
        }
        fn clear(&mut self) {}
        fn is_trained(&self) -> bool {
            true
        }
        async fn encode(&self, texts: &[String]) -> AResult<Vec<Vec<f32>>> {
            Ok(texts.iter().map(|_| vec![0.0f32; 4]).collect())
        }
    }

    #[test]
    fn test_cosine_sim_identical() {
        let v = vec![1.0, 2.0, 3.0];
        assert!((cosine_sim(&v, &v) - 1.0).abs() < 1e-9);
    }

    #[test]
    fn test_cosine_sim_orthogonal() {
        let a = vec![1.0, 0.0];
        let b = vec![0.0, 1.0];
        assert!((cosine_sim(&a, &b)).abs() < 1e-9);
    }

    #[test]
    fn test_cosine_sim_opposite() {
        let a = vec![1.0, 0.0];
        let b = vec![-1.0, 0.0];
        assert!((cosine_sim(&a, &b) + 1.0).abs() < 1e-9);
    }

    #[test]
    fn test_cosine_sim_zero_vector() {
        let a = vec![0.0, 0.0];
        let b = vec![1.0, 2.0];
        assert_eq!(cosine_sim(&a, &b), 0.0);
    }

    #[test]
    fn test_semantic_similarity_via_evaluator() {
        let ev = EmbeddingEvaluator::new();
        let a = vec![1.0, 1.0];
        let b = vec![1.0, 1.0];
        assert!((ev.semantic_similarity(&a, &b) - 1.0).abs() < 1e-9);
    }

    #[test]
    fn test_semantic_similarity_orthogonal() {
        let ev = EmbeddingEvaluator::new();
        assert!((ev.semantic_similarity(&[1.0, 0.0], &[0.0, 1.0])).abs() < 1e-9);
    }

    #[test]
    fn test_semantic_similarity_empty() {
        let ev = EmbeddingEvaluator::new();
        assert_eq!(ev.semantic_similarity(&[], &[]), 0.0);
    }

    #[test]
    fn test_mrr_empty_triples() {
        let ev = EmbeddingEvaluator::new();
        let mock = MockModel::new(vec!["e1".into(), "e2".into()]);
        assert_eq!(ev.link_prediction_mrr(&mock, &[]), 0.0);
    }

    #[test]
    fn test_hits_at_k_empty_triples() {
        let ev = EmbeddingEvaluator::new();
        let mock = MockModel::new(vec!["e1".into()]);
        assert_eq!(ev.hits_at_k(&mock, &[], 3), 0.0);
    }

    #[test]
    fn test_hits_at_k_zero_k() {
        let ev = EmbeddingEvaluator::new();
        let mock = MockModel::new(vec!["e1".into()]);
        let triple = Triple::new(
            NamedNode::new("e1").expect("should succeed"),
            NamedNode::new("r").expect("should succeed"),
            NamedNode::new("e1").expect("should succeed"),
        );
        assert_eq!(ev.hits_at_k(&mock, &[triple], 0), 0.0);
    }

    #[test]
    fn test_analogy_quad_construction() {
        let q = AnalogyQuad::new("man", "king", "woman", "queen");
        assert_eq!(q.a, "man");
        assert_eq!(q.expected_d, "queen");
    }

    #[test]
    fn test_analogy_quad_serialization() {
        let q = AnalogyQuad::new("paris", "france", "berlin", "germany");
        let json = serde_json::to_string(&q).expect("serialize");
        let q2: AnalogyQuad = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(q, q2);
    }

    #[test]
    fn test_evaluate_analogies_empty() {
        let bench = AnalogicalReasoningBenchmark::new();
        let mock = MockModel::new(vec![]);
        assert_eq!(bench.evaluate_analogies(&mock, &[]), 0.0);
    }

    #[test]
    fn test_silhouette_perfect_clusters() {
        let metrics = EmbeddingClusteringMetrics::new();
        // Two tight, well-separated clusters => score near 1.
        let embeddings = vec![
            vec![0.0, 0.0],
            vec![0.1, 0.0],
            vec![0.0, 0.1],
            vec![10.0, 10.0],
            vec![10.1, 10.0],
            vec![10.0, 10.1],
        ];
        let assignments = vec![0, 0, 0, 1, 1, 1];
        let score = metrics
            .silhouette_score(&embeddings, &assignments)
            .expect("ok");
        assert!(score > 0.8, "expected high score, got {score}");
    }

    #[test]
    fn test_silhouette_mismatched_lengths() {
        let metrics = EmbeddingClusteringMetrics::new();
        let result = metrics.silhouette_score(&[vec![1.0]], &[0, 1]);
        assert!(result.is_err());
    }

    #[test]
    fn test_silhouette_single_cluster_error() {
        let metrics = EmbeddingClusteringMetrics::new();
        let embeddings = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let assignments = vec![0, 0];
        assert!(metrics.silhouette_score(&embeddings, &assignments).is_err());
    }

    #[test]
    fn test_davies_bouldin_perfect_clusters() {
        let metrics = EmbeddingClusteringMetrics::new();
        // Two tight, far-apart clusters => DB index near 0.
        let embeddings = vec![
            vec![0.0, 0.0],
            vec![0.05, 0.0],
            vec![20.0, 20.0],
            vec![20.05, 20.0],
        ];
        let assignments = vec![0, 0, 1, 1];
        let db = metrics
            .davies_bouldin_index(&embeddings, &assignments)
            .expect("ok");
        assert!(db < 0.1, "expected low DB index, got {db}");
    }

    #[test]
    fn test_davies_bouldin_mismatched_lengths() {
        let metrics = EmbeddingClusteringMetrics::new();
        let result = metrics.davies_bouldin_index(&[vec![1.0]], &[0, 1]);
        assert!(result.is_err());
    }

    #[test]
    fn test_davies_bouldin_single_cluster_error() {
        let metrics = EmbeddingClusteringMetrics::new();
        let embeddings = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let assignments = vec![0, 0];
        assert!(metrics
            .davies_bouldin_index(&embeddings, &assignments)
            .is_err());
    }

    #[test]
    fn test_silhouette_three_clusters() {
        let metrics = EmbeddingClusteringMetrics::new();
        let embeddings = vec![
            vec![0.0, 0.0],
            vec![0.1, 0.1],
            vec![10.0, 0.0],
            vec![10.1, 0.1],
            vec![5.0, 8.66],
            vec![5.1, 8.76],
        ];
        let assignments = vec![0, 0, 1, 1, 2, 2];
        let score = metrics
            .silhouette_score(&embeddings, &assignments)
            .expect("ok");
        assert!(score > 0.5, "expected positive silhouette, got {score}");
    }

    #[test]
    fn test_davies_bouldin_three_clusters() {
        let metrics = EmbeddingClusteringMetrics::new();
        let embeddings = vec![
            vec![0.0, 0.0],
            vec![0.1, 0.0],
            vec![10.0, 0.0],
            vec![10.1, 0.0],
            vec![5.0, 8.66],
            vec![5.1, 8.66],
        ];
        let assignments = vec![0, 0, 1, 1, 2, 2];
        let db = metrics
            .davies_bouldin_index(&embeddings, &assignments)
            .expect("ok");
        assert!(db < 0.5, "expected low DB index, got {db}");
    }

    #[test]
    fn test_mrr_correct_at_rank_1() {
        let ev = EmbeddingEvaluator::new();
        // "e1" scores highest under MockModel, so the gold tail ranks first.
        let mock = MockModel::new(vec!["e1".into(), "e2".into(), "e3".into()]);
        let triple = Triple::new(
            NamedNode::new("e1").expect("should succeed"),
            NamedNode::new("r").expect("should succeed"),
            NamedNode::new("e1").expect("should succeed"),
        );
        let mrr = ev.link_prediction_mrr(&mock, &[triple]);
        assert!((mrr - 1.0).abs() < 1e-9, "expected MRR=1 got {mrr}");
    }

    #[test]
    fn test_hits_at_k_correct_first() {
        let ev = EmbeddingEvaluator::new();
        let mock = MockModel::new(vec!["e1".into(), "e2".into(), "e3".into()]);
        let triple = Triple::new(
            NamedNode::new("e1").expect("should succeed"),
            NamedNode::new("r").expect("should succeed"),
            NamedNode::new("e1").expect("should succeed"),
        );
        let h = ev.hits_at_k(&mock, &[triple], 1);
        assert!((h - 1.0).abs() < 1e-9);
    }

    #[test]
    fn test_hits_at_k_not_in_top_1() {
        let ev = EmbeddingEvaluator::new();
        // "e3" ranks third under MockModel scoring, so it misses top-1.
        let mock = MockModel::new(vec!["e1".into(), "e2".into(), "e3".into()]);
        let triple = Triple::new(
            NamedNode::new("e1").expect("should succeed"),
            NamedNode::new("r").expect("should succeed"),
            NamedNode::new("e3").expect("should succeed"),
        );
        let h = ev.hits_at_k(&mock, &[triple], 1);
        assert!((h - 0.0).abs() < 1e-9);
    }
}