use ndarray::Array1;
use subsume::ndarray_backend::distance::query2box_distance;
use subsume::ndarray_backend::NdarrayBox;
use subsume::Box as BoxTrait;
/// Ranks `candidates` by the probability that `query` contains them, best first.
///
/// Each candidate is scored with `query.containment_prob(..)`. When that call
/// returns an error the candidate is kept with a score of `0.0`, so a single
/// failing candidate does not abort the whole ranking.
///
/// Returns `(name, score)` pairs ordered by descending score. A NaN score is
/// returned unchanged but ordered after every other score; the sort never
/// panics on it.
fn rank_candidates<'a>(
    query: &NdarrayBox,
    candidates: &[(&'a str, &NdarrayBox)],
) -> Vec<(&'a str, f32)> {
    // `partial_cmp` has no answer for NaN, so rank NaN as the worst possible
    // score and compare with a total order.
    let sort_key = |score: f32| if score.is_nan() { f32::NEG_INFINITY } else { score };

    let mut scored: Vec<(&'a str, f32)> = candidates
        .iter()
        .map(|&(name, candidate)| {
            // Best-effort scoring: an unscorable candidate counts as "not contained".
            let prob = query.containment_prob(candidate).unwrap_or(0.0);
            (name, prob)
        })
        .collect();

    // Descending: compare `b` against `a`.
    scored.sort_by(|a, b| sort_key(b.1).total_cmp(&sort_key(a.1)));
    scored
}
/// Ranks `candidates` by their `query2box_distance` to `query`, closest first.
///
/// Each candidate box is reduced to its center point, `(min + max) * 0.5`, and
/// that point is scored against `query` with the given `alpha`. When the
/// distance call returns an error the candidate is kept with a distance of
/// `f32::INFINITY`, which places it at the end of the ranking.
///
/// Returns `(name, distance)` pairs ordered by ascending distance. A NaN
/// distance is returned unchanged but ordered last; the sort never panics on it.
fn rank_by_distance<'a>(
    query: &NdarrayBox,
    candidates: &[(&'a str, &NdarrayBox)],
    alpha: f32,
) -> Vec<(&'a str, f32)> {
    // `partial_cmp` has no answer for NaN, so rank NaN as the worst possible
    // distance and compare with a total order.
    let sort_key = |dist: f32| if dist.is_nan() { f32::INFINITY } else { dist };

    let mut scored: Vec<(&'a str, f32)> = candidates
        .iter()
        .map(|&(name, candidate)| {
            let center: Array1<f32> = (candidate.min() + candidate.max()) * 0.5;
            // Best-effort scoring: an unscorable candidate is treated as infinitely far.
            let dist = query2box_distance(query, &center, alpha).unwrap_or(f32::INFINITY);
            (name, dist)
        })
        .collect();

    // Ascending: compare `a` against `b`.
    scored.sort_by(|a, b| sort_key(a.1).total_cmp(&sort_key(b.1)));
    scored
}
/// Writes a ranking to stdout: a heading line, one numbered line per entry,
/// then a blank line.
///
/// Entries are numbered from 1 in slice order. An entry whose score is
/// strictly greater than `0.5` is tagged with `<-- answer`.
fn print_ranking(label: &str, ranking: &[(&str, f32)]) {
    println!(" {label}");
    for (position, &(name, score)) in (1usize..).zip(ranking) {
        // Only a strict `> 0.5` earns the tag (so NaN never does).
        let marker = (score > 0.5).then_some("<-- answer").unwrap_or_default();
        println!(" {}: {:>8} score={:.4} {}", position, name, score, marker);
    }
    println!();
}
/// Runs the Query2Box demo: builds a handful of 8-dimensional boxes for
/// countries, cities and languages, then prints four query rankings.
///
/// * Q1 / Q2 rank cities and languages by containment probability in the
///   France box.
/// * Q3 intersects France with the union of all city boxes and ranks the
///   languages against that intersection.
/// * Q4 re-scores Q1 and Q2 with `query2box_distance` and prints the distance
///   of the Paris center for several `alpha` values.
///
/// # Errors
///
/// Returns the first `subsume::BoxError` produced by a box constructor,
/// `union`, `intersection`, `volume`, or the direct `query2box_distance` call.
fn main() -> Result<(), subsume::BoxError> {
    println!("=== Query2Box: Compositional Query Answering ===\n");

    let dim = 8;

    // Coordinates that start at `base` and grow by `step` per dimension.
    let vary = |base: f32, step: f32| -> Array1<f32> {
        (0..dim).map(|d| base + step * d as f32).collect()
    };
    // Same ramp, but dimensions 2 and 3 start from `inner_base` while every
    // other dimension starts from `outer_base`. Used for the UK-side boxes,
    // which sit in the 0.x range on dimensions 2-3 and the 2.x range elsewhere.
    let vary_split = |outer_base: f32, inner_base: f32, step: f32| -> Array1<f32> {
        (0..dim)
            .map(|d| {
                let base = if (2..4).contains(&d) {
                    inner_base
                } else {
                    outer_base
                };
                base + step * d as f32
            })
            .collect()
    };

    // Countries.
    let france = NdarrayBox::new(vary(0.0, 0.01), vary(1.0, 0.01), 1.0)?;
    // Built but not queried below; the underscore keeps the compiler quiet.
    let _uk = NdarrayBox::new(
        vary_split(2.0, 0.0, 0.01),
        vary_split(3.0, 1.0, 0.01),
        1.0,
    )?;

    // Cities.
    let paris = NdarrayBox::new(vary(0.1, 0.02), vary(0.3, 0.02), 1.0)?;
    let lyon = NdarrayBox::new(vary(0.5, 0.015), vary(0.7, 0.015), 1.0)?;
    let london = NdarrayBox::new(
        vary_split(2.2, 0.2, 0.02),
        vary_split(2.4, 0.4, 0.02),
        1.0,
    )?;

    // Languages.
    let french = NdarrayBox::new(vary(0.2, 0.015), vary(0.5, 0.015), 1.0)?;
    let english = NdarrayBox::new(
        vary_split(2.1, 0.1, 0.02),
        vary_split(2.8, 0.8, 0.02),
        1.0,
    )?;

    println!("Q1: What cities are in France?\n");
    let city_candidates: Vec<(&str, &NdarrayBox)> =
        vec![("Paris", &paris), ("Lyon", &lyon), ("London", &london)];
    let q1 = rank_candidates(&france, &city_candidates);
    print_ranking("Rank by P(France contains city):", &q1);

    println!("Q2: What languages are spoken in France?\n");
    let lang_candidates: Vec<(&str, &NdarrayBox)> =
        vec![("French", &french), ("English", &english)];
    let q2 = rank_candidates(&france, &lang_candidates);
    print_ranking("Rank by P(France contains language):", &q2);

    println!("Q3: Languages spoken in countries with French cities (2-hop)\n");
    println!(" Hop 1: intersect France with city-containing region");
    let city_region = paris.union(&lyon)?.union(&london)?;
    let hop1 = france.intersection(&city_region)?;
    println!(
        " intersection volume: {:.4} (> 0 confirms France has cities)",
        hop1.volume()?
    );
    println!(" Hop 2: rank languages by containment in hop-1 result\n");
    let q3 = rank_candidates(&hop1, &lang_candidates);
    print_ranking("Rank by P(hop1_box contains language):", &q3);

    println!("Q4: Alpha-weighted distance scoring (Ren et al., 2020)\n");
    let alpha = 0.02;
    println!(" alpha = {alpha} (inside penalty << outside penalty)\n");

    println!(" Q1 re-scored: cities in France (by distance, ascending)\n");
    let q4a = rank_by_distance(&france, &city_candidates, alpha);
    for (i, (name, dist)) in q4a.iter().enumerate() {
        // Distances below 1.0 are tagged as answers in this demo.
        let marker = if *dist < 1.0 { "<-- answer" } else { "" };
        println!(" {}: {:>8} dist={:.4} {}", i + 1, name, dist, marker);
    }
    println!();

    println!(" Q2 re-scored: languages in France (by distance, ascending)\n");
    let q4b = rank_by_distance(&france, &lang_candidates, alpha);
    for (i, (name, dist)) in q4b.iter().enumerate() {
        println!(" {}: {:>8} dist={:.4}", i + 1, name, dist);
    }
    println!();

    println!(" Alpha sensitivity: distance for Paris across alpha values\n");
    let paris_center: Array1<f32> = (paris.min() + paris.max()) * 0.5;
    for &a in &[0.0, 0.02, 0.1, 0.5, 1.0] {
        let d = query2box_distance(&france, &paris_center, a)?;
        println!(" alpha={a:.2}: dist={d:.4}");
    }
    println!();

    println!("--- Summary ---\n");
    println!(" Q1 correctly ranks Paris and Lyon above London.");
    println!(" Q2 ranks French highest (fully inside France).");
    println!(" Q3 chains two hops: city containment, then language containment.");
    println!(" Intersection volume decreases at each hop, narrowing the answer set.");
    println!(" Q4 shows Query2Box distance scoring: lower distance = better answer.");
    println!(" Alpha controls inside-vs-outside penalty balance.");
    Ok(())
}