anno/backends/box_embeddings/mod.rs
1//! Box embeddings for coreference resolution.
2//!
3//! This module implements geometric representations (box embeddings) that encode
4//! logical invariants of coreference resolution, addressing limitations of
5//! vector-based approaches.
6//!
7//! **Note**: Training code is in `box_embeddings_training.rs`. The [matryoshka-box](https://github.com/arclabs561/matryoshka-box)
8//! research project extends training with matryoshka-specific features (variable dimensions, etc.).
9//!
10//! # Key Concepts
11//!
12//! - **Box Embeddings**: Entities represented as axis-aligned hyperrectangles
13//! - **Conditional Probability**: Coreference = high mutual overlap
14//! - **Temporal Boxes**: Entities that evolve over time
15//! - **Uncertainty-Aware**: Box volume = confidence
16//!
17//! # Research Background
18//!
19//! This implementation is related to the **matryoshka-box** research project (not yet published),
20//! which combines matryoshka embeddings (variable dimensions) with box embeddings (hierarchical reasoning).
21//! Standard training is in `box_embeddings_training.rs`; matryoshka-box extends it with research features.
22//!
23//! Based on research from:
24//! - Vilnis et al. (2018): "Probabilistic Embedding of Knowledge Graphs with Box Lattice Measures"
25//! - Lee et al. (2022): "Box Embeddings for Event-Event Relation Extraction" (BERE)
26//! - Messner et al. (2022): "Temporal Knowledge Graph Completion with Box Embeddings" (BoxTE)
27//! - Chen et al. (2021): "Uncertainty-Aware Knowledge Graph Embeddings" (UKGE)
28//!
29//! # Complementary Geometric Representations
30//!
31//! Box embeddings are one of several geometric approaches available in Anno.
32//! See `archive/geometric-2024-12/` for alternatives:
33//!
34//! | Representation | Best For | Module |
35//! |---------------|----------|--------|
36//! | **Box embeddings** | Temporal, uncertainty | This module |
37//! | Hyperbolic (Poincaré) | Deep type hierarchies | `archive/geometric-2024-12/hyperbolic.rs` |
38//! | Sheaf NN | Gradient-level transitivity | `archive/geometric-2024-12/sheaf.rs` |
39//! | TDA | Structural diagnostics | `archive/geometric-2024-12/tda.rs` |
40//!
41//! These approaches are **complementary**, not competing. Use boxes when you need:
42//! - Explicit uncertainty (volume = confidence)
43//! - Temporal evolution (min/max with velocity)
44//! - Easy visualization and debugging
45
46use serde::{Deserialize, Serialize};
47use std::f32;
48
/// A box embedding representing an entity in d-dimensional space.
///
/// Boxes are axis-aligned hyperrectangles defined by min/max bounds in each dimension.
/// Coreference is modeled as high mutual conditional probability (overlap).
///
/// [`BoxEmbedding::new`] checks that `min` and `max` have the same length and
/// that `min[i] <= max[i]` in every dimension. Both fields are public and the
/// `Deserialize` impl is derived, so a value built through a struct literal or
/// deserialization does not go through that check.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct BoxEmbedding {
    /// Lower bound in each dimension (d-dimensional vector).
    pub min: Vec<f32>,
    /// Upper bound in each dimension (d-dimensional vector).
    pub max: Vec<f32>,
}
60
61impl BoxEmbedding {
62 /// Create a new box embedding.
63 ///
64 /// # Panics
65 ///
66 /// Panics if `min.len() != max.len()` or if any `min[i] > max[i]`.
67 pub fn new(min: Vec<f32>, max: Vec<f32>) -> Self {
68 assert_eq!(min.len(), max.len(), "min and max must have same dimension");
69 for (i, (&m, &max_val)) in min.iter().zip(max.iter()).enumerate() {
70 assert!(
71 m <= max_val,
72 "min[{}] = {} must be <= max[{}] = {}",
73 i,
74 m,
75 i,
76 max_val
77 );
78 }
79 Self { min, max }
80 }
81
82 /// Get the dimension of the box.
83 #[must_use]
84 pub fn dim(&self) -> usize {
85 self.min.len()
86 }
87
88 /// Compute the volume of the box.
89 ///
90 /// Volume = product of (max - min) for each dimension.
91 #[must_use]
92 pub fn volume(&self) -> f32 {
93 self.min
94 .iter()
95 .zip(self.max.iter())
96 .map(|(&m, &max_val)| (max_val - m).max(0.0))
97 .product()
98 }
99
100 /// Compute the intersection volume with another box.
101 ///
102 /// Returns 0.0 if boxes are disjoint.
103 #[must_use]
104 pub fn intersection_volume(&self, other: &Self) -> f32 {
105 assert_eq!(
106 self.dim(),
107 other.dim(),
108 "Boxes must have same dimension for intersection"
109 );
110
111 self.min
112 .iter()
113 .zip(self.max.iter())
114 .zip(other.min.iter().zip(other.max.iter()))
115 .map(|((&m1, &max1), (&m2, &max2))| {
116 let intersection_min = m1.max(m2);
117 let intersection_max = max1.min(max2);
118 (intersection_max - intersection_min).max(0.0)
119 })
120 .product()
121 }
122
123 /// Compute conditional probability P(self | other).
124 ///
125 /// This is the BERE model's coreference metric:
126 /// P(A|B) = Vol(A ∩ B) / Vol(B)
127 ///
128 /// Returns a value in [0.0, 1.0] where:
129 /// - 1.0 = self is completely contained in other
130 /// - 0.0 = boxes are disjoint
131 #[must_use]
132 pub fn conditional_probability(&self, other: &Self) -> f32 {
133 let vol_other = other.volume();
134 if vol_other == 0.0 {
135 return 0.0;
136 }
137 self.intersection_volume(other) / vol_other
138 }
139
140 /// Compute mutual coreference score.
141 ///
142 /// Coreference requires high mutual conditional probability:
143 /// score = (P(A|B) + P(B|A)) / 2
144 ///
145 /// This ensures both boxes largely contain each other (high overlap).
146 #[must_use]
147 pub fn coreference_score(&self, other: &Self) -> f32 {
148 let p_a_given_b = self.conditional_probability(other);
149 let p_b_given_a = other.conditional_probability(self);
150 (p_a_given_b + p_b_given_a) / 2.0
151 }
152
153 /// Check if this box is contained in another box.
154 ///
155 /// Returns true if self ⊆ other (all dimensions).
156 #[must_use]
157 pub fn is_contained_in(&self, other: &Self) -> bool {
158 assert_eq!(self.dim(), other.dim(), "Boxes must have same dimension");
159 self.min
160 .iter()
161 .zip(self.max.iter())
162 .zip(other.min.iter().zip(other.max.iter()))
163 .all(|((&m1, &max1), (&m2, &max2))| m2 <= m1 && max1 <= max2)
164 }
165
166 /// Check if boxes are disjoint (no overlap).
167 #[must_use]
168 pub fn is_disjoint(&self, other: &Self) -> bool {
169 self.intersection_volume(other) == 0.0
170 }
171
172 /// Create a box embedding from a vector embedding.
173 ///
174 /// Converts a point embedding to a box by creating a small hypercube
175 /// around the point. The box size is controlled by `radius`.
176 ///
177 /// # Arguments
178 ///
179 /// * `vector` - Vector embedding (point in space)
180 /// * `radius` - Half-width of the box in each dimension
181 ///
182 /// # Example
183 ///
184 /// ```rust,ignore
185 /// let vector = vec![0.5, 0.5, 0.5];
186 /// let box_embedding = BoxEmbedding::from_vector(&vector, 0.1);
187 /// // Creates box: min=[0.4, 0.4, 0.4], max=[0.6, 0.6, 0.6]
188 /// ```
189 #[must_use]
190 pub fn from_vector(vector: &[f32], radius: f32) -> Self {
191 let min: Vec<f32> = vector.iter().map(|&v| v - radius).collect();
192 let max: Vec<f32> = vector.iter().map(|&v| v + radius).collect();
193 Self::new(min, max)
194 }
195
196 /// Create a box embedding from a vector with adaptive radius.
197 ///
198 /// Uses a radius proportional to the vector's magnitude, creating
199 /// larger boxes for vectors further from the origin.
200 ///
201 /// # Arguments
202 ///
203 /// * `vector` - Vector embedding
204 /// * `radius_factor` - Multiplier for adaptive radius (default: 0.1)
205 #[must_use]
206 pub fn from_vector_adaptive(vector: &[f32], radius_factor: f32) -> Self {
207 let magnitude: f32 = vector.iter().map(|&v| v * v).sum::<f32>().sqrt();
208 let radius = magnitude * radius_factor + 0.01; // Add small epsilon
209 Self::from_vector(vector, radius)
210 }
211
212 /// Get the center point of the box.
213 ///
214 /// Returns the midpoint in each dimension.
215 #[must_use]
216 pub fn center(&self) -> Vec<f32> {
217 self.min
218 .iter()
219 .zip(self.max.iter())
220 .map(|(&m, &max_val)| (m + max_val) / 2.0)
221 .collect()
222 }
223
224 /// Get the size (width) in each dimension.
225 #[must_use]
226 pub fn size(&self) -> Vec<f32> {
227 self.min
228 .iter()
229 .zip(self.max.iter())
230 .map(|(&m, &max_val)| (max_val - m).max(0.0))
231 .collect()
232 }
233
234 /// Compute the intersection box with another box.
235 ///
236 /// Returns a new box representing the overlapping region.
237 /// If boxes are disjoint, returns a zero-volume box.
238 #[must_use]
239 pub fn intersection(&self, other: &Self) -> Self {
240 assert_eq!(
241 self.dim(),
242 other.dim(),
243 "Boxes must have same dimension for intersection"
244 );
245
246 let min: Vec<f32> = self
247 .min
248 .iter()
249 .zip(other.min.iter())
250 .map(|(&a, &b)| a.max(b))
251 .collect();
252
253 let max: Vec<f32> = self
254 .max
255 .iter()
256 .zip(other.max.iter())
257 .map(|(&a, &b)| a.min(b))
258 .collect();
259
260 Self { min, max }
261 }
262
263 /// Compute the union box (bounding box containing both).
264 #[must_use]
265 pub fn union(&self, other: &Self) -> Self {
266 assert_eq!(
267 self.dim(),
268 other.dim(),
269 "Boxes must have same dimension for union"
270 );
271
272 let min: Vec<f32> = self
273 .min
274 .iter()
275 .zip(other.min.iter())
276 .map(|(&a, &b)| a.min(b))
277 .collect();
278
279 let max: Vec<f32> = self
280 .max
281 .iter()
282 .zip(other.max.iter())
283 .map(|(&a, &b)| a.max(b))
284 .collect();
285
286 Self { min, max }
287 }
288
289 /// Compute overlap probability (Jaccard-style).
290 ///
291 /// P(overlap) = Vol(intersection) / Vol(union)
292 #[must_use]
293 pub fn overlap_prob(&self, other: &Self) -> f32 {
294 let intersection_vol = self.intersection_volume(other);
295 let union_vol = self.volume() + other.volume() - intersection_vol;
296 if union_vol == 0.0 {
297 return 0.0;
298 }
299 intersection_vol / union_vol
300 }
301
302 /// Compute minimum Euclidean distance between two boxes.
303 ///
304 /// Returns 0.0 if boxes overlap.
305 #[must_use]
306 pub fn distance(&self, other: &Self) -> f32 {
307 assert_eq!(
308 self.dim(),
309 other.dim(),
310 "Boxes must have same dimension for distance"
311 );
312
313 let dist_sq: f32 = self
314 .min
315 .iter()
316 .zip(self.max.iter())
317 .zip(other.min.iter().zip(other.max.iter()))
318 .map(|((&min1, &max1), (&min2, &max2))| {
319 // Gap in this dimension
320 let gap = if max1 < min2 {
321 min2 - max1 // other is to the right
322 } else if max2 < min1 {
323 min1 - max2 // other is to the left
324 } else {
325 0.0 // overlap in this dimension
326 };
327 gap * gap
328 })
329 .sum();
330
331 dist_sq.sqrt()
332 }
333}
334
335// =============================================================================
336// Subsume Trait Implementation (optional, feature-gated)
337// =============================================================================
338
/// Implements the subsume-core Box trait when the `subsume` feature is enabled.
///
/// This allows anno's BoxEmbedding to be used with subsume's distance metrics,
/// training utilities, and other advanced box operations.
///
/// Every binary operation reports a dimension mismatch as
/// `BoxError::DimensionMismatch` instead of hitting the `assert_eq!` in the
/// inherent methods. The `temperature` arguments are ignored because
/// `BoxEmbedding` models hard boxes.
#[cfg(feature = "subsume")]
impl subsume_core::Box for BoxEmbedding {
    type Scalar = f32;
    type Vector = Vec<f32>;

    fn min(&self) -> &Self::Vector {
        &self.min
    }

    fn max(&self) -> &Self::Vector {
        &self.max
    }

    fn dim(&self) -> usize {
        self.min.len()
    }

    fn volume(&self, _temperature: Self::Scalar) -> Result<Self::Scalar, subsume_core::BoxError> {
        // anno's BoxEmbedding doesn't use temperature (hard boxes)
        Ok(BoxEmbedding::volume(self))
    }

    fn intersection(&self, other: &Self) -> Result<Self, subsume_core::BoxError> {
        ensure_same_dim(self, other)?;
        Ok(BoxEmbedding::intersection(self, other))
    }

    fn containment_prob(
        &self,
        other: &Self,
        _temperature: Self::Scalar,
    ) -> Result<Self::Scalar, subsume_core::BoxError> {
        ensure_same_dim(self, other)?;
        // subsume: P(other ⊆ self) = Vol(intersection) / Vol(other).
        // `self.conditional_probability(other)` divides the intersection volume
        // by Vol(other), so the arguments are passed through in the same order.
        // NOTE(review): the subsume-side definition is taken from the original
        // comment here; confirm against the subsume-core trait docs.
        Ok(self.conditional_probability(other))
    }

    fn overlap_prob(
        &self,
        other: &Self,
        _temperature: Self::Scalar,
    ) -> Result<Self::Scalar, subsume_core::BoxError> {
        ensure_same_dim(self, other)?;
        Ok(BoxEmbedding::overlap_prob(self, other))
    }

    fn union(&self, other: &Self) -> Result<Self, subsume_core::BoxError> {
        ensure_same_dim(self, other)?;
        Ok(BoxEmbedding::union(self, other))
    }

    fn center(&self) -> Result<Self::Vector, subsume_core::BoxError> {
        Ok(BoxEmbedding::center(self))
    }

    fn distance(&self, other: &Self) -> Result<Self::Scalar, subsume_core::BoxError> {
        ensure_same_dim(self, other)?;
        Ok(BoxEmbedding::distance(self, other))
    }

    fn truncate(&self, k: usize) -> Result<Self, subsume_core::BoxError> {
        if k > self.dim() {
            return Err(subsume_core::BoxError::MatryoshkaMismatch {
                requested: k,
                actual: self.dim(),
            });
        }
        // Build the value directly instead of going through `BoxEmbedding::new`:
        // keeping a prefix of the bounds cannot introduce a new violation, and
        // `new` would panic (in a `Result`-returning method) if `self` already
        // carried inverted bounds, e.g. after deserialization.
        Ok(BoxEmbedding {
            min: self.min[..k].to_vec(),
            max: self.max[..k].to_vec(),
        })
    }
}

/// Returns `BoxError::DimensionMismatch` unless `a` and `b` have the same
/// number of dimensions (`expected` is `a`'s dimension, `actual` is `b`'s).
#[cfg(feature = "subsume")]
fn ensure_same_dim(a: &BoxEmbedding, b: &BoxEmbedding) -> Result<(), subsume_core::BoxError> {
    if a.min.len() != b.min.len() {
        return Err(subsume_core::BoxError::DimensionMismatch {
            expected: a.min.len(),
            actual: b.min.len(),
        });
    }
    Ok(())
}
442
443/// Configuration for box-based coreference resolution.
444pub mod extras;
445pub use extras::*;
446#[cfg(test)]
447mod tests;