use crate::Scorer;
/// Family of fuzzy-logic operators used to combine two scores.
///
/// The same value selects both the conjunction (see `apply_t_norm`) and its
/// dual disjunction (see `apply_t_conorm`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TNorm {
/// Conjunction is `a.min(b)`; the dual disjunction is `a.max(b)`.
Min,
/// Conjunction is `a * b`; the dual disjunction is `a + b - a * b`.
Product,
}
/// How raw model scores are mapped to normalised scores (see `normalize_scores`).
///
/// Both variants negate the raw score first, so a smaller raw score yields a
/// larger normalised score.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ScoreNorm {
/// Each raw score `s` is mapped independently to `sigmoid(-s)`.
Sigmoid,
/// Raw scores are mapped jointly to `exp(-s)` weights divided by their sum.
Softmax,
}
/// Tunable parameters for query evaluation.
#[derive(Debug, Clone)]
pub struct QueryConfig {
/// Operator used by `beam_project` to combine a candidate's score with the
/// score of each tail reached from it.
pub t_norm_projection: TNorm,
/// Operator used by `eval_query` to combine branch scores: as a t-norm for
/// `Query::Intersection` and as the dual t-conorm for `Query::Union`.
pub t_norm_intersection: TNorm,
/// Number of highest-scoring intermediate entities kept per projection step.
pub beam_k: usize,
/// Normalisation applied to raw model scores.
pub score_norm: ScoreNorm,
}
impl Default for QueryConfig {
fn default() -> Self {
Self {
t_norm_projection: TNorm::Product,
t_norm_intersection: TNorm::Min,
beam_k: 128,
score_norm: ScoreNorm::Sigmoid,
}
}
}
/// A logical query over entities, represented as a tree.
///
/// `eval_query` turns a query into one score per entity.
#[derive(Debug, Clone)]
pub enum Query {
/// Leaf node: scored with `model.score_all_tails(entity, relation)` and then
/// normalised.
Anchor {
entity: usize,
relation: usize,
},
/// Follows `relation` from the entities matched by `inner` (beam search).
Project {
inner: Box<Query>,
relation: usize,
},
/// Conjunction of all `branches`.
Intersection {
branches: Vec<Query>,
},
/// Disjunction of all `branches`.
Union {
branches: Vec<Query>,
},
/// Complement of `inner`: every score `s` becomes `1.0 - s`.
Negation {
inner: Box<Query>,
},
}
impl Query {
pub fn anchor(entity: usize, relation: usize) -> Self {
Query::Anchor { entity, relation }
}
pub fn then(self, relation: usize) -> Self {
Query::Project {
inner: Box::new(self),
relation,
}
}
pub fn intersection(branches: Vec<Query>) -> Self {
assert!(
!branches.is_empty(),
"intersection requires at least one branch"
);
Query::Intersection { branches }
}
pub fn union(branches: Vec<Query>) -> Self {
assert!(!branches.is_empty(), "union requires at least one branch");
Query::Union { branches }
}
pub fn negate(self) -> Self {
Query::Negation {
inner: Box::new(self),
}
}
}
/// Evaluates `query` against `model` and returns the scores produced by
/// `eval_query`, using `model.num_entities()` as the entity count.
pub fn answer_query(model: &dyn Scorer, query: &Query, config: &QueryConfig) -> Vec<f32> {
    eval_query(model, query, config, model.num_entities())
}
/// Evaluates `query` and returns at most `k` `(entity index, score)` pairs,
/// ordered by `top_k_descending`.
pub fn answer_query_topk(
    model: &dyn Scorer,
    query: &Query,
    config: &QueryConfig,
    k: usize,
) -> Vec<(usize, f32)> {
    top_k_descending(&answer_query(model, query, config), k)
}
/// Recursively evaluates `query`, returning a vector of scores.
///
/// * `Anchor` scores all tails of `(entity, relation)` and normalises them.
/// * `Project` evaluates the inner query, then beam-projects it along `relation`.
/// * `Intersection` / `Union` evaluate every branch and merge them with the
///   t-norm / t-conorm selected by `config.t_norm_intersection`.
/// * `Negation` maps every inner score `s` to `1.0 - s`.
///
/// `n` is the entity count used to size the projection and merge buffers.
fn eval_query(model: &dyn Scorer, query: &Query, config: &QueryConfig, n: usize) -> Vec<f32> {
    // Shared by Intersection and Union: evaluate each branch, in order.
    let eval_branches = |branches: &[Query]| -> Vec<Vec<f32>> {
        let mut per_branch = Vec::with_capacity(branches.len());
        for branch in branches {
            per_branch.push(eval_query(model, branch, config, n));
        }
        per_branch
    };

    match query {
        Query::Negation { inner } => eval_query(model, inner, config, n)
            .into_iter()
            .map(|score| 1.0 - score)
            .collect(),
        Query::Union { branches } => {
            let per_branch = eval_branches(branches.as_slice());
            combine_disjunction(&per_branch, config.t_norm_intersection, n)
        }
        Query::Intersection { branches } => {
            let per_branch = eval_branches(branches.as_slice());
            combine_conjunction(&per_branch, config.t_norm_intersection, n)
        }
        Query::Project { inner, relation } => {
            let upstream = eval_query(model, inner, config, n);
            beam_project(model, &upstream, *relation, config, n)
        }
        Query::Anchor { entity, relation } => {
            let raw_scores = model.score_all_tails(*entity, *relation);
            normalize_scores(&raw_scores, config.score_norm)
        }
    }
}
fn beam_project(
model: &dyn Scorer,
inner_scores: &[f32],
relation: usize,
config: &QueryConfig,
n: usize,
) -> Vec<f32> {
let candidates = top_k_descending(inner_scores, config.beam_k);
let norm = config.t_norm_projection;
match norm {
TNorm::Min if config.score_norm == ScoreNorm::Sigmoid => {
let mut best_raw = vec![f32::NEG_INFINITY; n];
for &(entity, inner_score) in &candidates {
let inner_raw = logit(inner_score);
let raw_tail_scores = model.score_all_tails(entity, relation);
for (t, &raw) in raw_tail_scores.iter().enumerate() {
let tail_raw = -raw;
let combined_raw = inner_raw.min(tail_raw);
if combined_raw > best_raw[t] {
best_raw[t] = combined_raw;
}
}
}
best_raw.iter().map(|&r| sigmoid(r)).collect()
}
_ => {
let mut result = vec![0.0_f32; n];
for &(entity, inner_score) in &candidates {
let raw_tail_scores = model.score_all_tails(entity, relation);
let tail_probs = normalize_scores(&raw_tail_scores, config.score_norm);
for (t, &tail_prob) in tail_probs.iter().enumerate() {
let combined = apply_t_norm(inner_score, tail_prob, norm);
if combined > result[t] {
result[t] = combined;
}
}
}
result
}
}
}
/// Returns the log-odds `ln(p / (1 - p))` of a probability.
///
/// `p` is clamped to `[0, 1]` first. Scores built from floating-point
/// arithmetic (for example `a + b - a * b`, or `1.0 - s` of such a value) can
/// land a few ULPs outside the unit interval; without the clamp the ratio is
/// negative and the logarithm is NaN. The end points map to `-inf` and `+inf`.
/// A NaN input stays NaN.
fn logit(p: f32) -> f32 {
    let p = p.clamp(0.0, 1.0);
    (p / (1.0 - p)).ln()
}
/// Maps raw scores to normalised scores; smaller raw scores rank higher.
///
/// * `Sigmoid`: each score `s` becomes `sigmoid(-s)` independently.
/// * `Softmax`: scores become `exp(-(s - min))` weights divided by their sum.
///   Subtracting the minimum keeps every exponent non-positive. If the sum is
///   not positive (including NaN), a uniform `1 / len` vector is returned.
fn normalize_scores(raw: &[f32], norm: ScoreNorm) -> Vec<f32> {
    match norm {
        ScoreNorm::Softmax => {
            let smallest = raw.iter().fold(f32::INFINITY, |lo, &s| lo.min(s));
            let weights: Vec<f32> = raw.iter().map(|&s| (-(s - smallest)).exp()).collect();
            let total: f32 = weights.iter().sum();
            if total > 0.0 {
                weights.into_iter().map(|w| w / total).collect()
            } else {
                let count = raw.len();
                vec![1.0 / count as f32; count]
            }
        }
        ScoreNorm::Sigmoid => {
            let mut out = Vec::with_capacity(raw.len());
            for &s in raw {
                out.push(sigmoid(-s));
            }
            out
        }
    }
}
/// Logistic function `1 / (1 + e^-x)`.
///
/// Only `exp(-|x|)` is ever evaluated, so the exponential cannot overflow for
/// large-magnitude inputs of either sign.
fn sigmoid(x: f32) -> f32 {
    let decay = (-x.abs()).exp();
    let denom = 1.0 + decay;
    if x < 0.0 {
        decay / denom
    } else {
        1.0 / denom
    }
}
/// Fuzzy conjunction of `a` and `b`: the smaller of the two for `TNorm::Min`,
/// their product for `TNorm::Product`.
fn apply_t_norm(a: f32, b: f32, norm: TNorm) -> f32 {
    match norm {
        TNorm::Product => a * b,
        TNorm::Min => f32::min(a, b),
    }
}
/// Fuzzy disjunction of `a` and `b`, dual to `apply_t_norm`: the larger of the
/// two for `TNorm::Min`, the probabilistic sum `a + b - a * b` for
/// `TNorm::Product`.
fn apply_t_conorm(a: f32, b: f32, norm: TNorm) -> f32 {
    match norm {
        TNorm::Product => a + b - a * b,
        TNorm::Min => f32::max(a, b),
    }
}
/// Folds all branches into one score vector of length `n` using the t-norm.
///
/// The accumulator starts at `1.0` for every position, so with no branches the
/// result is all ones.
fn combine_conjunction(branch_scores: &[Vec<f32>], norm: TNorm, n: usize) -> Vec<f32> {
    branch_scores
        .iter()
        .fold(vec![1.0_f32; n], |mut acc, branch| {
            for (idx, &score) in branch.iter().enumerate() {
                acc[idx] = apply_t_norm(acc[idx], score, norm);
            }
            acc
        })
}
/// Folds all branches into one score vector of length `n` using the t-conorm
/// dual to `norm`.
///
/// The accumulator starts at `0.0` for every position, so with no branches the
/// result is all zeros.
fn combine_disjunction(branch_scores: &[Vec<f32>], norm: TNorm, n: usize) -> Vec<f32> {
    branch_scores
        .iter()
        .fold(vec![0.0_f32; n], |mut acc, branch| {
            for (idx, &score) in branch.iter().enumerate() {
                acc[idx] = apply_t_conorm(acc[idx], score, norm);
            }
            acc
        })
}
/// Returns the `k` highest-scoring `(index, score)` pairs, best first.
///
/// Ordering is a strict total order so the result is deterministic and the
/// sort routines never see an inconsistent comparator:
///
/// * larger scores come first;
/// * NaN scores are placed after every non-NaN score;
/// * equal scores (and NaNs among themselves) are ordered by ascending index.
///
/// If `k` is at least `scores.len()`, every entry is returned. When `k` is
/// smaller, only the kept prefix is fully sorted; the rest is discarded after
/// a linear-time selection.
fn top_k_descending(scores: &[f32], k: usize) -> Vec<(usize, f32)> {
    fn best_first(a: &(usize, f32), b: &(usize, f32)) -> std::cmp::Ordering {
        let by_score = match (a.1.is_nan(), b.1.is_nan()) {
            // Neither is NaN, so `partial_cmp` always succeeds here.
            (false, false) => b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal),
            // `false < true`, which pushes NaN entries to the back.
            (a_nan, b_nan) => a_nan.cmp(&b_nan),
        };
        by_score.then_with(|| a.0.cmp(&b.0))
    }

    if k == 0 {
        return Vec::new();
    }

    let mut indexed: Vec<(usize, f32)> = scores.iter().copied().enumerate().collect();
    if k < indexed.len() {
        // Partition so the first `k` slots hold the `k` best entries (unordered).
        indexed.select_nth_unstable_by(k - 1, best_first);
        indexed.truncate(k);
    }
    indexed.sort_unstable_by(best_first);
    indexed
}
// Unit tests for query evaluation, driven by a small deterministic model.
#[cfg(test)]
mod tests {
use super::*;
// Toy model over `n` entities: the ideal tail of `(h, r)` is `(h + r + 1) % n`.
struct ChainModel {
n: usize,
}
impl Scorer for ChainModel {
// Raw score is the absolute index distance to the ideal tail, so 0 is a
// perfect match and larger values are worse.
fn score(&self, h: usize, r: usize, t: usize) -> f32 {
let expected = (h + r + 1) % self.n;
t.abs_diff(expected) as f32
}
fn num_entities(&self) -> usize {
self.n
}
// NOTE(review): `score_all_tails` is not implemented here; presumably the
// `Scorer` trait supplies a default built on `score` — confirm in the trait.
}
// A single anchor: the highest normalised score must be at (2 + 3 + 1) % 10 = 6.
#[test]
fn anchor_query_matches_score_all_tails() {
let model = ChainModel { n: 10 };
let config = QueryConfig::default();
let query = Query::anchor(2, 3);
let scores = answer_query(&model, &query, &config);
assert_eq!(scores.len(), 10);
let best = scores
.iter()
.enumerate()
.max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
.unwrap();
assert_eq!(best.0, 6, "Entity 6 should score best for (2, 3, ?)");
}
// Two hops: (0, 0) leads to entity 1, then relation 2 leads to (1 + 2 + 1) = 4.
// The beam is as wide as the entity set, so no candidate is pruned.
#[test]
fn chain_2p_finds_two_hop_answer() {
let model = ChainModel { n: 20 };
let config = QueryConfig {
t_norm_projection: TNorm::Product,
t_norm_intersection: TNorm::Min,
beam_k: 20,
..QueryConfig::default()
};
let query = Query::anchor(0, 0).then(2);
let scores = answer_query(&model, &query, &config);
let top = top_k_descending(&scores, 3);
assert_eq!(top[0].0, 4, "Two-hop answer should be entity 4");
}
// Both anchors have ideal tail 5: (0 + 4 + 1) and (2 + 2 + 1).
#[test]
fn intersection_narrows_results() {
let model = ChainModel { n: 20 };
let config = QueryConfig::default();
let query = Query::intersection(vec![Query::anchor(0, 4), Query::anchor(2, 2)]);
let scores = answer_query(&model, &query, &config);
let top = top_k_descending(&scores, 1);
assert_eq!(top[0].0, 5, "Intersection should agree on entity 5");
}
// A union must never score an entity lower than either of its branches.
// `t_norm_intersection` also selects the t-conorm used for unions, so this
// exercises the probabilistic-sum (`Product`) disjunction.
#[test]
fn union_at_least_as_good_as_branches() {
let model = ChainModel { n: 10 };
let config = QueryConfig {
t_norm_intersection: TNorm::Product,
..QueryConfig::default()
};
let branch1 = Query::anchor(0, 0);
let branch2 = Query::anchor(3, 3);
let scores1 = answer_query(&model, &branch1, &config);
let scores2 = answer_query(&model, &branch2, &config);
let union_scores = answer_query(&model, &Query::union(vec![branch1, branch2]), &config);
for i in 0..10 {
assert!(
union_scores[i] >= scores1[i] - 1e-6,
"Union score should be >= branch 1 for entity {i}"
);
assert!(
union_scores[i] >= scores2[i] - 1e-6,
"Union score should be >= branch 2 for entity {i}"
);
}
}
// Negation maps every score `s` to `1.0 - s`, so the pair must sum to 1.
#[test]
fn negation_inverts_scores() {
let model = ChainModel { n: 10 };
let config = QueryConfig::default();
let query = Query::anchor(0, 0);
let scores = answer_query(&model, &query, &config);
let neg_scores = answer_query(&model, &query.clone().negate(), &config);
for i in 0..10 {
assert!(
(scores[i] + neg_scores[i] - 1.0).abs() < 1e-6,
"score + negated should equal 1.0 for entity {i}: {} + {} = {}",
scores[i],
neg_scores[i],
scores[i] + neg_scores[i],
);
}
}
// The top-k helper returns exactly `k` entries in non-increasing score order.
#[test]
fn topk_returns_sorted_descending() {
let model = ChainModel { n: 20 };
let config = QueryConfig::default();
let query = Query::anchor(0, 0);
let top = answer_query_topk(&model, &query, &config, 5);
assert_eq!(top.len(), 5);
for w in top.windows(2) {
assert!(w[0].1 >= w[1].1, "Top-k should be sorted descending");
}
}
// Intersection agrees on entity 5, then relation 0 leads to (5 + 0 + 1) = 6.
// `Min` projection with the default sigmoid normalisation takes the
// logit-space branch of `beam_project`.
#[test]
fn pi_query_intersect_then_project() {
let model = ChainModel { n: 20 };
let config = QueryConfig {
t_norm_projection: TNorm::Min,
t_norm_intersection: TNorm::Min,
beam_k: 20,
..QueryConfig::default()
};
let query = Query::intersection(vec![Query::anchor(0, 4), Query::anchor(2, 2)]).then(0);
let top = answer_query_topk(&model, &query, &config, 1);
assert_eq!(top[0].0, 6, "pi query should find entity 6");
}
// Saturation, the midpoint, and the extreme finite inputs must not produce
// NaN or infinity.
#[test]
fn sigmoid_is_numerically_stable() {
assert!((sigmoid(100.0) - 1.0).abs() < 1e-6);
assert!(sigmoid(-100.0).abs() < 1e-6);
assert!((sigmoid(0.0) - 0.5).abs() < 1e-6);
assert!(sigmoid(f32::MAX).is_finite());
assert!(sigmoid(f32::MIN).is_finite());
}
// T-norm axioms checked here: 1.0 is the identity element, and the operator
// is commutative.
#[test]
fn t_norm_properties() {
for &x in &[0.0, 0.3, 0.7, 1.0] {
assert!((apply_t_norm(x, 1.0, TNorm::Min) - x).abs() < 1e-6);
assert!((apply_t_norm(x, 1.0, TNorm::Product) - x).abs() < 1e-6);
}
let (a, b) = (0.3, 0.7);
assert_eq!(
apply_t_norm(a, b, TNorm::Min),
apply_t_norm(b, a, TNorm::Min)
);
assert!(
(apply_t_norm(a, b, TNorm::Product) - apply_t_norm(b, a, TNorm::Product)).abs() < 1e-6
);
}
// De Morgan duality: conorm(a, b) == 1 - norm(1 - a, 1 - b) for both families.
#[test]
fn t_conorm_de_morgan_duality() {
let (a, b) = (0.3, 0.7);
for norm in [TNorm::Min, TNorm::Product] {
let conorm = apply_t_conorm(a, b, norm);
let dual = 1.0 - apply_t_norm(1.0 - a, 1.0 - b, norm);
assert!(
(conorm - dual).abs() < 1e-6,
"De Morgan failed for {norm:?}: {conorm} vs {dual}"
);
}
}
}