Skip to main content

anno/backends/box_embeddings/
mod.rs

1//! Box embeddings for coreference resolution.
2//!
3//! This module implements geometric representations (box embeddings) that encode
4//! logical invariants of coreference resolution, addressing limitations of
5//! vector-based approaches.
6//!
7//! **Note**: Training code is in `box_embeddings_training.rs`. The [matryoshka-box](https://github.com/arclabs561/matryoshka-box)
8//! research project extends training with matryoshka-specific features (variable dimensions, etc.).
9//!
10//! # Key Concepts
11//!
12//! - **Box Embeddings**: Entities represented as axis-aligned hyperrectangles
13//! - **Conditional Probability**: Coreference = high mutual overlap
14//! - **Temporal Boxes**: Entities that evolve over time
15//! - **Uncertainty-Aware**: Box volume = confidence
16//!
17//! # Research Background
18//!
19//! This implementation is related to the **matryoshka-box** research project (not yet published),
20//! which combines matryoshka embeddings (variable dimensions) with box embeddings (hierarchical reasoning).
21//! Standard training is in `box_embeddings_training.rs`; matryoshka-box extends it with research features.
22//!
23//! Based on research from:
24//! - Vilnis et al. (2018): "Probabilistic Embedding of Knowledge Graphs with Box Lattice Measures"
25//! - Lee et al. (2022): "Box Embeddings for Event-Event Relation Extraction" (BERE)
26//! - Messner et al. (2022): "Temporal Knowledge Graph Completion with Box Embeddings" (BoxTE)
27//! - Chen et al. (2021): "Uncertainty-Aware Knowledge Graph Embeddings" (UKGE)
28//!
29//! # Complementary Geometric Representations
30//!
31//! Box embeddings are one of several geometric approaches available in Anno.
32//! See `archive/geometric-2024-12/` for alternatives:
33//!
34//! | Representation | Best For | Module |
35//! |---------------|----------|--------|
36//! | **Box embeddings** | Temporal, uncertainty | This module |
37//! | Hyperbolic (Poincaré) | Deep type hierarchies | `archive/geometric-2024-12/hyperbolic.rs` |
38//! | Sheaf NN | Gradient-level transitivity | `archive/geometric-2024-12/sheaf.rs` |
39//! | TDA | Structural diagnostics | `archive/geometric-2024-12/tda.rs` |
40//!
41//! These approaches are **complementary**, not competing. Use boxes when you need:
42//! - Explicit uncertainty (volume = confidence)
43//! - Temporal evolution (min/max with velocity)
44//! - Easy visualization and debugging
45
46use serde::{Deserialize, Serialize};
47use std::f32;
48
49/// A box embedding representing an entity in d-dimensional space.
50///
51/// Boxes are axis-aligned hyperrectangles defined by min/max bounds in each dimension.
52/// Coreference is modeled as high mutual conditional probability (overlap).
53#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
54pub struct BoxEmbedding {
55    /// Lower bound in each dimension (d-dimensional vector).
56    pub min: Vec<f32>,
57    /// Upper bound in each dimension (d-dimensional vector).
58    pub max: Vec<f32>,
59}
60
61impl BoxEmbedding {
62    /// Create a new box embedding.
63    ///
64    /// # Panics
65    ///
66    /// Panics if `min.len() != max.len()` or if any `min[i] > max[i]`.
67    pub fn new(min: Vec<f32>, max: Vec<f32>) -> Self {
68        assert_eq!(min.len(), max.len(), "min and max must have same dimension");
69        for (i, (&m, &max_val)) in min.iter().zip(max.iter()).enumerate() {
70            assert!(
71                m <= max_val,
72                "min[{}] = {} must be <= max[{}] = {}",
73                i,
74                m,
75                i,
76                max_val
77            );
78        }
79        Self { min, max }
80    }
81
82    /// Get the dimension of the box.
83    #[must_use]
84    pub fn dim(&self) -> usize {
85        self.min.len()
86    }
87
88    /// Compute the volume of the box.
89    ///
90    /// Volume = product of (max - min) for each dimension.
91    #[must_use]
92    pub fn volume(&self) -> f32 {
93        self.min
94            .iter()
95            .zip(self.max.iter())
96            .map(|(&m, &max_val)| (max_val - m).max(0.0))
97            .product()
98    }
99
100    /// Compute the intersection volume with another box.
101    ///
102    /// Returns 0.0 if boxes are disjoint.
103    #[must_use]
104    pub fn intersection_volume(&self, other: &Self) -> f32 {
105        assert_eq!(
106            self.dim(),
107            other.dim(),
108            "Boxes must have same dimension for intersection"
109        );
110
111        self.min
112            .iter()
113            .zip(self.max.iter())
114            .zip(other.min.iter().zip(other.max.iter()))
115            .map(|((&m1, &max1), (&m2, &max2))| {
116                let intersection_min = m1.max(m2);
117                let intersection_max = max1.min(max2);
118                (intersection_max - intersection_min).max(0.0)
119            })
120            .product()
121    }
122
123    /// Compute conditional probability P(self | other).
124    ///
125    /// This is the BERE model's coreference metric:
126    /// P(A|B) = Vol(A ∩ B) / Vol(B)
127    ///
128    /// Returns a value in [0.0, 1.0] where:
129    /// - 1.0 = self is completely contained in other
130    /// - 0.0 = boxes are disjoint
131    #[must_use]
132    pub fn conditional_probability(&self, other: &Self) -> f32 {
133        let vol_other = other.volume();
134        if vol_other == 0.0 {
135            return 0.0;
136        }
137        self.intersection_volume(other) / vol_other
138    }
139
140    /// Compute mutual coreference score.
141    ///
142    /// Coreference requires high mutual conditional probability:
143    /// score = (P(A|B) + P(B|A)) / 2
144    ///
145    /// This ensures both boxes largely contain each other (high overlap).
146    #[must_use]
147    pub fn coreference_score(&self, other: &Self) -> f32 {
148        let p_a_given_b = self.conditional_probability(other);
149        let p_b_given_a = other.conditional_probability(self);
150        (p_a_given_b + p_b_given_a) / 2.0
151    }
152
153    /// Check if this box is contained in another box.
154    ///
155    /// Returns true if self ⊆ other (all dimensions).
156    #[must_use]
157    pub fn is_contained_in(&self, other: &Self) -> bool {
158        assert_eq!(self.dim(), other.dim(), "Boxes must have same dimension");
159        self.min
160            .iter()
161            .zip(self.max.iter())
162            .zip(other.min.iter().zip(other.max.iter()))
163            .all(|((&m1, &max1), (&m2, &max2))| m2 <= m1 && max1 <= max2)
164    }
165
166    /// Check if boxes are disjoint (no overlap).
167    #[must_use]
168    pub fn is_disjoint(&self, other: &Self) -> bool {
169        self.intersection_volume(other) == 0.0
170    }
171
172    /// Create a box embedding from a vector embedding.
173    ///
174    /// Converts a point embedding to a box by creating a small hypercube
175    /// around the point. The box size is controlled by `radius`.
176    ///
177    /// # Arguments
178    ///
179    /// * `vector` - Vector embedding (point in space)
180    /// * `radius` - Half-width of the box in each dimension
181    ///
182    /// # Example
183    ///
184    /// ```rust,ignore
185    /// let vector = vec![0.5, 0.5, 0.5];
186    /// let box_embedding = BoxEmbedding::from_vector(&vector, 0.1);
187    /// // Creates box: min=[0.4, 0.4, 0.4], max=[0.6, 0.6, 0.6]
188    /// ```
189    #[must_use]
190    pub fn from_vector(vector: &[f32], radius: f32) -> Self {
191        let min: Vec<f32> = vector.iter().map(|&v| v - radius).collect();
192        let max: Vec<f32> = vector.iter().map(|&v| v + radius).collect();
193        Self::new(min, max)
194    }
195
196    /// Create a box embedding from a vector with adaptive radius.
197    ///
198    /// Uses a radius proportional to the vector's magnitude, creating
199    /// larger boxes for vectors further from the origin.
200    ///
201    /// # Arguments
202    ///
203    /// * `vector` - Vector embedding
204    /// * `radius_factor` - Multiplier for adaptive radius (default: 0.1)
205    #[must_use]
206    pub fn from_vector_adaptive(vector: &[f32], radius_factor: f32) -> Self {
207        let magnitude: f32 = vector.iter().map(|&v| v * v).sum::<f32>().sqrt();
208        let radius = magnitude * radius_factor + 0.01; // Add small epsilon
209        Self::from_vector(vector, radius)
210    }
211
212    /// Get the center point of the box.
213    ///
214    /// Returns the midpoint in each dimension.
215    #[must_use]
216    pub fn center(&self) -> Vec<f32> {
217        self.min
218            .iter()
219            .zip(self.max.iter())
220            .map(|(&m, &max_val)| (m + max_val) / 2.0)
221            .collect()
222    }
223
224    /// Get the size (width) in each dimension.
225    #[must_use]
226    pub fn size(&self) -> Vec<f32> {
227        self.min
228            .iter()
229            .zip(self.max.iter())
230            .map(|(&m, &max_val)| (max_val - m).max(0.0))
231            .collect()
232    }
233
234    /// Compute the intersection box with another box.
235    ///
236    /// Returns a new box representing the overlapping region.
237    /// If boxes are disjoint, returns a zero-volume box.
238    #[must_use]
239    pub fn intersection(&self, other: &Self) -> Self {
240        assert_eq!(
241            self.dim(),
242            other.dim(),
243            "Boxes must have same dimension for intersection"
244        );
245
246        let min: Vec<f32> = self
247            .min
248            .iter()
249            .zip(other.min.iter())
250            .map(|(&a, &b)| a.max(b))
251            .collect();
252
253        let max: Vec<f32> = self
254            .max
255            .iter()
256            .zip(other.max.iter())
257            .map(|(&a, &b)| a.min(b))
258            .collect();
259
260        Self { min, max }
261    }
262
263    /// Compute the union box (bounding box containing both).
264    #[must_use]
265    pub fn union(&self, other: &Self) -> Self {
266        assert_eq!(
267            self.dim(),
268            other.dim(),
269            "Boxes must have same dimension for union"
270        );
271
272        let min: Vec<f32> = self
273            .min
274            .iter()
275            .zip(other.min.iter())
276            .map(|(&a, &b)| a.min(b))
277            .collect();
278
279        let max: Vec<f32> = self
280            .max
281            .iter()
282            .zip(other.max.iter())
283            .map(|(&a, &b)| a.max(b))
284            .collect();
285
286        Self { min, max }
287    }
288
289    /// Compute overlap probability (Jaccard-style).
290    ///
291    /// P(overlap) = Vol(intersection) / Vol(union)
292    #[must_use]
293    pub fn overlap_prob(&self, other: &Self) -> f32 {
294        let intersection_vol = self.intersection_volume(other);
295        let union_vol = self.volume() + other.volume() - intersection_vol;
296        if union_vol == 0.0 {
297            return 0.0;
298        }
299        intersection_vol / union_vol
300    }
301
302    /// Compute minimum Euclidean distance between two boxes.
303    ///
304    /// Returns 0.0 if boxes overlap.
305    #[must_use]
306    pub fn distance(&self, other: &Self) -> f32 {
307        assert_eq!(
308            self.dim(),
309            other.dim(),
310            "Boxes must have same dimension for distance"
311        );
312
313        let dist_sq: f32 = self
314            .min
315            .iter()
316            .zip(self.max.iter())
317            .zip(other.min.iter().zip(other.max.iter()))
318            .map(|((&min1, &max1), (&min2, &max2))| {
319                // Gap in this dimension
320                let gap = if max1 < min2 {
321                    min2 - max1 // other is to the right
322                } else if max2 < min1 {
323                    min1 - max2 // other is to the left
324                } else {
325                    0.0 // overlap in this dimension
326                };
327                gap * gap
328            })
329            .sum();
330
331        dist_sq.sqrt()
332    }
333}
334
335// =============================================================================
336// Subsume Trait Implementation (optional, feature-gated)
337// =============================================================================
338
/// Implements the subsume-core Box trait when the `subsume` feature is enabled.
///
/// This allows anno's BoxEmbedding to be used with subsume's distance metrics,
/// training utilities, and other advanced box operations.
///
/// Every binary operation reports a dimension mismatch as
/// `BoxError::DimensionMismatch` before delegating, so the panicking
/// assertions in the inherent methods are not reached through this trait.
/// The `temperature` arguments are ignored because these are hard boxes.
#[cfg(feature = "subsume")]
impl subsume_core::Box for BoxEmbedding {
    type Scalar = f32;
    type Vector = Vec<f32>;

    fn min(&self) -> &Self::Vector {
        &self.min
    }

    fn max(&self) -> &Self::Vector {
        &self.max
    }

    fn dim(&self) -> usize {
        self.min.len()
    }

    fn volume(&self, _temperature: Self::Scalar) -> Result<Self::Scalar, subsume_core::BoxError> {
        // anno's BoxEmbedding doesn't use temperature (hard boxes)
        Ok(BoxEmbedding::volume(self))
    }

    fn intersection(&self, other: &Self) -> Result<Self, subsume_core::BoxError> {
        ensure_same_dim(self, other)?;
        Ok(BoxEmbedding::intersection(self, other))
    }

    fn containment_prob(
        &self,
        other: &Self,
        _temperature: Self::Scalar,
    ) -> Result<Self::Scalar, subsume_core::BoxError> {
        ensure_same_dim(self, other)?;
        // subsume: P(other ⊆ self) = Vol(intersection) / Vol(other)
        // anno's `conditional_probability(self, other)` computes exactly this
        // ratio, so the arguments are passed through in the same order.
        Ok(self.conditional_probability(other))
    }

    fn overlap_prob(
        &self,
        other: &Self,
        _temperature: Self::Scalar,
    ) -> Result<Self::Scalar, subsume_core::BoxError> {
        ensure_same_dim(self, other)?;
        Ok(BoxEmbedding::overlap_prob(self, other))
    }

    fn union(&self, other: &Self) -> Result<Self, subsume_core::BoxError> {
        ensure_same_dim(self, other)?;
        Ok(BoxEmbedding::union(self, other))
    }

    fn center(&self) -> Result<Self::Vector, subsume_core::BoxError> {
        Ok(BoxEmbedding::center(self))
    }

    fn distance(&self, other: &Self) -> Result<Self::Scalar, subsume_core::BoxError> {
        ensure_same_dim(self, other)?;
        Ok(BoxEmbedding::distance(self, other))
    }

    fn truncate(&self, k: usize) -> Result<Self, subsume_core::BoxError> {
        if k > self.dim() {
            return Err(subsume_core::BoxError::MatryoshkaMismatch {
                requested: k,
                actual: self.dim(),
            });
        }
        // Built directly instead of through `BoxEmbedding::new`: a prefix of
        // the bounds is exactly as valid as the box it came from, and `new`
        // would panic on the `min > max` boxes that the inherent
        // `intersection` returns for disjoint inputs.
        Ok(BoxEmbedding {
            min: self.min[..k].to_vec(),
            max: self.max[..k].to_vec(),
        })
    }
}

/// Returns `BoxError::DimensionMismatch` unless both boxes have the same dimension.
#[cfg(feature = "subsume")]
fn ensure_same_dim(lhs: &BoxEmbedding, rhs: &BoxEmbedding) -> Result<(), subsume_core::BoxError> {
    let (expected, actual) = (lhs.min.len(), rhs.min.len());
    if expected == actual {
        Ok(())
    } else {
        Err(subsume_core::BoxError::DimensionMismatch { expected, actual })
    }
}
442
443/// Configuration for box-based coreference resolution.
444pub mod extras;
445pub use extras::*;
446#[cfg(test)]
447mod tests;