Skip to main content

anno/backends/
box_embeddings.rs

1//! Box embeddings for coreference resolution.
2//!
3//! This module implements geometric representations (box embeddings) that encode
4//! logical invariants of coreference resolution, addressing limitations of
5//! vector-based approaches.
6//!
7//! **Note**: Training code is in `box_embeddings_training.rs`. The [matryoshka-box](https://github.com/arclabs561/matryoshka-box)
8//! research project extends training with matryoshka-specific features (variable dimensions, etc.).
9//!
10//! # Key Concepts
11//!
12//! - **Box Embeddings**: Entities represented as axis-aligned hyperrectangles
13//! - **Conditional Probability**: Coreference = high mutual overlap
14//! - **Temporal Boxes**: Entities that evolve over time
15//! - **Uncertainty-Aware**: Box volume = confidence
16//!
17//! # Research Background
18//!
19//! This implementation is related to the **matryoshka-box** research project (not yet published),
20//! which combines matryoshka embeddings (variable dimensions) with box embeddings (hierarchical reasoning).
21//! Standard training is in `box_embeddings_training.rs`; matryoshka-box extends it with research features.
22//!
23//! Based on research from:
24//! - Vilnis et al. (2018): "Probabilistic Embedding of Knowledge Graphs with Box Lattice Measures"
25//! - Lee et al. (2022): "Box Embeddings for Event-Event Relation Extraction" (BERE)
26//! - Messner et al. (2022): "Temporal Knowledge Graph Completion with Box Embeddings" (BoxTE)
27//! - Chen et al. (2021): "Uncertainty-Aware Knowledge Graph Embeddings" (UKGE)
28//!
29//! # Complementary Geometric Representations
30//!
31//! Box embeddings are one of several geometric approaches available in Anno.
32//! See `archive/geometric-2024-12/` for alternatives:
33//!
34//! | Representation | Best For | Module |
35//! |---------------|----------|--------|
36//! | **Box embeddings** | Temporal, uncertainty | This module |
37//! | Hyperbolic (Poincaré) | Deep type hierarchies | `archive/geometric-2024-12/hyperbolic.rs` |
38//! | Sheaf NN | Gradient-level transitivity | `archive/geometric-2024-12/sheaf.rs` |
39//! | TDA | Structural diagnostics | `archive/geometric-2024-12/tda.rs` |
40//!
41//! These approaches are **complementary**, not competing. Use boxes when you need:
42//! - Explicit uncertainty (volume = confidence)
43//! - Temporal evolution (min/max with velocity)
44//! - Easy visualization and debugging
45
46use serde::{Deserialize, Serialize};
47use std::f32;
48
49/// A box embedding representing an entity in d-dimensional space.
50///
51/// Boxes are axis-aligned hyperrectangles defined by min/max bounds in each dimension.
52/// Coreference is modeled as high mutual conditional probability (overlap).
53#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
54pub struct BoxEmbedding {
55    /// Lower bound in each dimension (d-dimensional vector).
56    pub min: Vec<f32>,
57    /// Upper bound in each dimension (d-dimensional vector).
58    pub max: Vec<f32>,
59}
60
61impl BoxEmbedding {
62    /// Create a new box embedding.
63    ///
64    /// # Panics
65    ///
66    /// Panics if `min.len() != max.len()` or if any `min[i] > max[i]`.
67    pub fn new(min: Vec<f32>, max: Vec<f32>) -> Self {
68        assert_eq!(min.len(), max.len(), "min and max must have same dimension");
69        for (i, (&m, &max_val)) in min.iter().zip(max.iter()).enumerate() {
70            assert!(
71                m <= max_val,
72                "min[{}] = {} must be <= max[{}] = {}",
73                i,
74                m,
75                i,
76                max_val
77            );
78        }
79        Self { min, max }
80    }
81
82    /// Get the dimension of the box.
83    #[must_use]
84    pub fn dim(&self) -> usize {
85        self.min.len()
86    }
87
88    /// Compute the volume of the box.
89    ///
90    /// Volume = product of (max - min) for each dimension.
91    #[must_use]
92    pub fn volume(&self) -> f32 {
93        self.min
94            .iter()
95            .zip(self.max.iter())
96            .map(|(&m, &max_val)| (max_val - m).max(0.0))
97            .product()
98    }
99
100    /// Compute the intersection volume with another box.
101    ///
102    /// Returns 0.0 if boxes are disjoint.
103    #[must_use]
104    pub fn intersection_volume(&self, other: &Self) -> f32 {
105        assert_eq!(
106            self.dim(),
107            other.dim(),
108            "Boxes must have same dimension for intersection"
109        );
110
111        self.min
112            .iter()
113            .zip(self.max.iter())
114            .zip(other.min.iter().zip(other.max.iter()))
115            .map(|((&m1, &max1), (&m2, &max2))| {
116                let intersection_min = m1.max(m2);
117                let intersection_max = max1.min(max2);
118                (intersection_max - intersection_min).max(0.0)
119            })
120            .product()
121    }
122
123    /// Compute conditional probability P(self | other).
124    ///
125    /// This is the BERE model's coreference metric:
126    /// P(A|B) = Vol(A ∩ B) / Vol(B)
127    ///
128    /// Returns a value in [0.0, 1.0] where:
129    /// - 1.0 = self is completely contained in other
130    /// - 0.0 = boxes are disjoint
131    #[must_use]
132    pub fn conditional_probability(&self, other: &Self) -> f32 {
133        let vol_other = other.volume();
134        if vol_other == 0.0 {
135            return 0.0;
136        }
137        self.intersection_volume(other) / vol_other
138    }
139
140    /// Compute mutual coreference score.
141    ///
142    /// Coreference requires high mutual conditional probability:
143    /// score = (P(A|B) + P(B|A)) / 2
144    ///
145    /// This ensures both boxes largely contain each other (high overlap).
146    #[must_use]
147    pub fn coreference_score(&self, other: &Self) -> f32 {
148        let p_a_given_b = self.conditional_probability(other);
149        let p_b_given_a = other.conditional_probability(self);
150        (p_a_given_b + p_b_given_a) / 2.0
151    }
152
153    /// Check if this box is contained in another box.
154    ///
155    /// Returns true if self ⊆ other (all dimensions).
156    #[must_use]
157    pub fn is_contained_in(&self, other: &Self) -> bool {
158        assert_eq!(self.dim(), other.dim(), "Boxes must have same dimension");
159        self.min
160            .iter()
161            .zip(self.max.iter())
162            .zip(other.min.iter().zip(other.max.iter()))
163            .all(|((&m1, &max1), (&m2, &max2))| m2 <= m1 && max1 <= max2)
164    }
165
166    /// Check if boxes are disjoint (no overlap).
167    #[must_use]
168    pub fn is_disjoint(&self, other: &Self) -> bool {
169        self.intersection_volume(other) == 0.0
170    }
171
172    /// Create a box embedding from a vector embedding.
173    ///
174    /// Converts a point embedding to a box by creating a small hypercube
175    /// around the point. The box size is controlled by `radius`.
176    ///
177    /// # Arguments
178    ///
179    /// * `vector` - Vector embedding (point in space)
180    /// * `radius` - Half-width of the box in each dimension
181    ///
182    /// # Example
183    ///
184    /// ```rust,ignore
185    /// let vector = vec![0.5, 0.5, 0.5];
186    /// let box_embedding = BoxEmbedding::from_vector(&vector, 0.1);
187    /// // Creates box: min=[0.4, 0.4, 0.4], max=[0.6, 0.6, 0.6]
188    /// ```
189    #[must_use]
190    pub fn from_vector(vector: &[f32], radius: f32) -> Self {
191        let min: Vec<f32> = vector.iter().map(|&v| v - radius).collect();
192        let max: Vec<f32> = vector.iter().map(|&v| v + radius).collect();
193        Self::new(min, max)
194    }
195
196    /// Create a box embedding from a vector with adaptive radius.
197    ///
198    /// Uses a radius proportional to the vector's magnitude, creating
199    /// larger boxes for vectors further from the origin.
200    ///
201    /// # Arguments
202    ///
203    /// * `vector` - Vector embedding
204    /// * `radius_factor` - Multiplier for adaptive radius (default: 0.1)
205    #[must_use]
206    pub fn from_vector_adaptive(vector: &[f32], radius_factor: f32) -> Self {
207        let magnitude: f32 = vector.iter().map(|&v| v * v).sum::<f32>().sqrt();
208        let radius = magnitude * radius_factor + 0.01; // Add small epsilon
209        Self::from_vector(vector, radius)
210    }
211
212    /// Get the center point of the box.
213    ///
214    /// Returns the midpoint in each dimension.
215    #[must_use]
216    pub fn center(&self) -> Vec<f32> {
217        self.min
218            .iter()
219            .zip(self.max.iter())
220            .map(|(&m, &max_val)| (m + max_val) / 2.0)
221            .collect()
222    }
223
224    /// Get the size (width) in each dimension.
225    #[must_use]
226    pub fn size(&self) -> Vec<f32> {
227        self.min
228            .iter()
229            .zip(self.max.iter())
230            .map(|(&m, &max_val)| (max_val - m).max(0.0))
231            .collect()
232    }
233
234    /// Compute the intersection box with another box.
235    ///
236    /// Returns a new box representing the overlapping region.
237    /// If boxes are disjoint, returns a zero-volume box.
238    #[must_use]
239    pub fn intersection(&self, other: &Self) -> Self {
240        assert_eq!(
241            self.dim(),
242            other.dim(),
243            "Boxes must have same dimension for intersection"
244        );
245
246        let min: Vec<f32> = self
247            .min
248            .iter()
249            .zip(other.min.iter())
250            .map(|(&a, &b)| a.max(b))
251            .collect();
252
253        let max: Vec<f32> = self
254            .max
255            .iter()
256            .zip(other.max.iter())
257            .map(|(&a, &b)| a.min(b))
258            .collect();
259
260        Self { min, max }
261    }
262
263    /// Compute the union box (bounding box containing both).
264    #[must_use]
265    pub fn union(&self, other: &Self) -> Self {
266        assert_eq!(
267            self.dim(),
268            other.dim(),
269            "Boxes must have same dimension for union"
270        );
271
272        let min: Vec<f32> = self
273            .min
274            .iter()
275            .zip(other.min.iter())
276            .map(|(&a, &b)| a.min(b))
277            .collect();
278
279        let max: Vec<f32> = self
280            .max
281            .iter()
282            .zip(other.max.iter())
283            .map(|(&a, &b)| a.max(b))
284            .collect();
285
286        Self { min, max }
287    }
288
289    /// Compute overlap probability (Jaccard-style).
290    ///
291    /// P(overlap) = Vol(intersection) / Vol(union)
292    #[must_use]
293    pub fn overlap_prob(&self, other: &Self) -> f32 {
294        let intersection_vol = self.intersection_volume(other);
295        let union_vol = self.volume() + other.volume() - intersection_vol;
296        if union_vol == 0.0 {
297            return 0.0;
298        }
299        intersection_vol / union_vol
300    }
301
302    /// Compute minimum Euclidean distance between two boxes.
303    ///
304    /// Returns 0.0 if boxes overlap.
305    #[must_use]
306    pub fn distance(&self, other: &Self) -> f32 {
307        assert_eq!(
308            self.dim(),
309            other.dim(),
310            "Boxes must have same dimension for distance"
311        );
312
313        let dist_sq: f32 = self
314            .min
315            .iter()
316            .zip(self.max.iter())
317            .zip(other.min.iter().zip(other.max.iter()))
318            .map(|((&min1, &max1), (&min2, &max2))| {
319                // Gap in this dimension
320                let gap = if max1 < min2 {
321                    min2 - max1 // other is to the right
322                } else if max2 < min1 {
323                    min1 - max2 // other is to the left
324                } else {
325                    0.0 // overlap in this dimension
326                };
327                gap * gap
328            })
329            .sum();
330
331        dist_sq.sqrt()
332    }
333}
334
335// =============================================================================
336// Subsume Trait Implementation (optional, feature-gated)
337// =============================================================================
338
/// Adapter exposing [`BoxEmbedding`] through the `subsume_core::Box` trait.
///
/// Only compiled when the `subsume` feature is enabled. Every method forwards
/// to the inherent `BoxEmbedding` method; temperature arguments are ignored
/// because these boxes have hard boundaries. The two-box operations report a
/// dimension mismatch as `BoxError::DimensionMismatch` before forwarding, so
/// the panicking assertions of the inherent methods are not reached.
#[cfg(feature = "subsume")]
impl subsume_core::Box for BoxEmbedding {
    type Scalar = f32;
    type Vector = Vec<f32>;

    fn min(&self) -> &Self::Vector {
        &self.min
    }

    fn max(&self) -> &Self::Vector {
        &self.max
    }

    fn dim(&self) -> usize {
        self.min.len()
    }

    fn volume(&self, _temperature: Self::Scalar) -> Result<Self::Scalar, subsume_core::BoxError> {
        // Hard boxes: the temperature has no effect on the volume.
        Ok(BoxEmbedding::volume(self))
    }

    fn intersection(&self, other: &Self) -> Result<Self, subsume_core::BoxError> {
        let (expected, actual) = (self.dim(), other.dim());
        if expected == actual {
            Ok(BoxEmbedding::intersection(self, other))
        } else {
            Err(subsume_core::BoxError::DimensionMismatch { expected, actual })
        }
    }

    fn containment_prob(
        &self,
        other: &Self,
        _temperature: Self::Scalar,
    ) -> Result<Self::Scalar, subsume_core::BoxError> {
        let (expected, actual) = (self.dim(), other.dim());
        if expected != actual {
            return Err(subsume_core::BoxError::DimensionMismatch { expected, actual });
        }
        // Forwards unchanged: the value is Vol(self ∩ other) / Vol(other).
        // NOTE(review): an earlier comment described this as the subsume
        // quantity P(other ⊆ self) "with swapped args", yet the arguments are
        // passed in the same order — confirm against subsume_core's contract.
        Ok(self.conditional_probability(other))
    }

    fn overlap_prob(
        &self,
        other: &Self,
        _temperature: Self::Scalar,
    ) -> Result<Self::Scalar, subsume_core::BoxError> {
        let (expected, actual) = (self.dim(), other.dim());
        if expected == actual {
            Ok(BoxEmbedding::overlap_prob(self, other))
        } else {
            Err(subsume_core::BoxError::DimensionMismatch { expected, actual })
        }
    }

    fn union(&self, other: &Self) -> Result<Self, subsume_core::BoxError> {
        let (expected, actual) = (self.dim(), other.dim());
        if expected == actual {
            Ok(BoxEmbedding::union(self, other))
        } else {
            Err(subsume_core::BoxError::DimensionMismatch { expected, actual })
        }
    }

    fn center(&self) -> Result<Self::Vector, subsume_core::BoxError> {
        Ok(BoxEmbedding::center(self))
    }

    fn distance(&self, other: &Self) -> Result<Self::Scalar, subsume_core::BoxError> {
        let (expected, actual) = (self.dim(), other.dim());
        if expected == actual {
            Ok(BoxEmbedding::distance(self, other))
        } else {
            Err(subsume_core::BoxError::DimensionMismatch { expected, actual })
        }
    }
}
429
/// Tunable settings for box-based coreference resolution.
#[derive(Debug, Clone)]
pub struct BoxCorefConfig {
    /// Lowest coreference score at which two entities are linked.
    pub coreference_threshold: f32,
    /// Toggles enforcement of syntactic constraints (Principle B/C).
    pub enforce_syntactic_constraints: bool,
    /// Largest token distance still counted as the local domain (Principle B).
    pub max_local_distance: usize,
    /// Half-width used when vector embeddings are converted into boxes;
    /// `None` leaves no radius configured.
    pub vector_to_box_radius: Option<f32>,
}

impl Default for BoxCorefConfig {
    /// Defaults: threshold `0.7`, syntactic constraints on, local distance
    /// `5` tokens, vector-to-box radius `0.1`.
    fn default() -> Self {
        let coreference_threshold = 0.7;
        let enforce_syntactic_constraints = true;
        let max_local_distance = 5;
        let vector_to_box_radius = Some(0.1);
        Self {
            coreference_threshold,
            enforce_syntactic_constraints,
            max_local_distance,
            vector_to_box_radius,
        }
    }
}
453
454// =============================================================================
455// Temporal Boxes (BoxTE-style)
456// =============================================================================
457
458/// A temporal box embedding that evolves over time.
459///
460/// Based on BoxTE (Messner et al., 2022), this models entities that change
461/// over time. For example, "The President" refers to Obama in 2012 but
462/// Trump in 2017 - they should not corefer despite the same surface form.
463///
464/// # Example
465///
466/// ```rust,ignore
467/// use anno::backends::box_embeddings::{BoxEmbedding, TemporalBox, BoxVelocity};
468///
469/// // "The President" in 2012 (Obama)
470/// let base = BoxEmbedding::new(vec![0.0, 0.0], vec![1.0, 1.0]);
471/// let velocity = BoxVelocity::new(vec![0.0, 0.0], vec![0.0, 0.0]); // Static
472/// let obama_presidency = TemporalBox::new(base, velocity, (2012.0, 2016.0));
473///
474/// // "The President" in 2017 (Trump)
475/// let trump_base = BoxEmbedding::new(vec![5.0, 5.0], vec![6.0, 6.0]);
476/// let trump_presidency = TemporalBox::new(trump_base, velocity, (2017.0, 2021.0));
477///
478/// // Should not corefer (different time ranges)
479/// assert_eq!(obama_presidency.coreference_at_time(&trump_presidency, 2015.0), 0.0);
480/// ```
481#[derive(Debug, Clone, PartialEq)]
482pub struct TemporalBox {
483    /// Base box at time t=0 (or reference time)
484    pub base: BoxEmbedding,
485    /// Velocity: how box moves/resizes per time unit
486    pub velocity: BoxVelocity,
487    /// Time range where this box is valid [start, end)
488    pub time_range: (f64, f64),
489}
490
/// Rate of change of a temporal box, expressed per time unit.
#[derive(Debug, Clone, PartialEq)]
pub struct BoxVelocity {
    /// Per-time-unit change of each lower bound (one entry per dimension).
    pub min_delta: Vec<f32>,
    /// Per-time-unit change of each upper bound (one entry per dimension).
    pub max_delta: Vec<f32>,
}

impl BoxVelocity {
    /// Build a velocity from explicit per-dimension deltas.
    ///
    /// No validation is performed here; the two vectors are stored as given.
    #[must_use]
    pub fn new(min_delta: Vec<f32>, max_delta: Vec<f32>) -> Self {
        Self {
            min_delta,
            max_delta,
        }
    }

    /// Build a velocity of all zeros for a `dim`-dimensional box, i.e. a box
    /// that does not change over time.
    #[must_use]
    pub fn static_velocity(dim: usize) -> Self {
        let zeros = vec![0.0; dim];
        Self {
            min_delta: zeros.clone(),
            max_delta: zeros,
        }
    }
}
519
520impl TemporalBox {
521    /// Create a new temporal box.
522    ///
523    /// # Arguments
524    ///
525    /// * `base` - Base box at reference time
526    /// * `velocity` - How box evolves per time unit
527    /// * `time_range` - (start, end) time range where box is valid
528    #[must_use]
529    pub fn new(base: BoxEmbedding, velocity: BoxVelocity, time_range: (f64, f64)) -> Self {
530        assert_eq!(
531            base.dim(),
532            velocity.min_delta.len(),
533            "base and velocity must have same dimension"
534        );
535        assert_eq!(
536            velocity.min_delta.len(),
537            velocity.max_delta.len(),
538            "velocity min and max deltas must have same dimension"
539        );
540        Self {
541            base,
542            velocity,
543            time_range,
544        }
545    }
546
547    /// Get the box at a specific time.
548    ///
549    /// Returns None if time is outside the valid range.
550    #[must_use]
551    pub fn at_time(&self, time: f64) -> Option<BoxEmbedding> {
552        if time < self.time_range.0 || time >= self.time_range.1 {
553            return None;
554        }
555
556        // Compute time offset from reference (using start of range as reference)
557        let time_offset = time - self.time_range.0;
558
559        // Apply velocity to base box
560        let new_min: Vec<f32> = self
561            .base
562            .min
563            .iter()
564            .zip(self.velocity.min_delta.iter())
565            .map(|(&m, &delta)| m + delta * time_offset as f32)
566            .collect();
567
568        let new_max: Vec<f32> = self
569            .base
570            .max
571            .iter()
572            .zip(self.velocity.max_delta.iter())
573            .map(|(&max_val, &delta)| max_val + delta * time_offset as f32)
574            .collect();
575
576        Some(BoxEmbedding::new(new_min, new_max))
577    }
578
579    /// Compute coreference score at a specific time.
580    ///
581    /// Returns 0.0 if either box is invalid at the given time.
582    #[must_use]
583    pub fn coreference_at_time(&self, other: &Self, time: f64) -> f32 {
584        let box_a = match self.at_time(time) {
585            Some(b) => b,
586            None => return 0.0,
587        };
588        let box_b = match other.at_time(time) {
589            Some(b) => b,
590            None => return 0.0,
591        };
592        box_a.coreference_score(&box_b)
593    }
594
595    /// Check if this temporal box is valid at the given time.
596    #[must_use]
597    pub fn is_valid_at(&self, time: f64) -> bool {
598        time >= self.time_range.0 && time < self.time_range.1
599    }
600}
601
602// =============================================================================
603// Uncertainty-Aware Boxes (UKGE-style)
604// =============================================================================
605
606/// An uncertainty-aware box embedding (UKGE-style).
607///
608/// Based on UKGE (Chen et al., 2021), box volume represents confidence:
609/// - Small box = high confidence (precise, trusted fact)
610/// - Large box = low confidence (vague, uncertain, or dubious claim)
611///
612/// This enables conflict detection: if two high-confidence boxes are disjoint,
613/// they represent contradictory claims.
614///
615/// # Example
616///
617/// ```rust,ignore
618/// use anno::backends::box_embeddings::{BoxEmbedding, UncertainBox};
619///
620/// // High-confidence claim: "Trump is in NY" (small, precise box)
621/// let claim_a = UncertainBox::new(
622///     BoxEmbedding::new(vec![0.0, 0.0], vec![0.1, 0.1]), // Small = high confidence
623///     0.95, // Source trust
624/// );
625///
626/// // Contradictory claim: "Trump is in FL" (also high confidence, but disjoint)
627/// let claim_b = UncertainBox::new(
628///     BoxEmbedding::new(vec![5.0, 5.0], vec![5.1, 5.1]), // Disjoint from claim_a
629///     0.90,
630/// );
631///
632/// // Should detect conflict
633/// assert!(claim_a.detect_conflict(&claim_b).is_some());
634/// ```
635#[derive(Debug, Clone, PartialEq)]
636pub struct UncertainBox {
637    /// The underlying box embedding
638    pub box_embedding: BoxEmbedding,
639    /// Source trustworthiness (0.0-1.0)
640    pub source_trust: f32,
641}
642
643impl UncertainBox {
644    /// Create a new uncertainty-aware box.
645    ///
646    /// Confidence is derived from box volume (smaller = higher confidence).
647    #[must_use]
648    pub fn new(box_embedding: BoxEmbedding, source_trust: f32) -> Self {
649        assert!(
650            (0.0..=1.0).contains(&source_trust),
651            "source_trust must be in [0.0, 1.0]"
652        );
653        Self {
654            box_embedding,
655            source_trust,
656        }
657    }
658
659    /// Get confidence derived from box volume.
660    ///
661    /// Smaller boxes = higher confidence. This is a heuristic:
662    /// confidence ≈ 1.0 / (1.0 + volume)
663    #[must_use]
664    pub fn confidence(&self) -> f32 {
665        let vol = self.box_embedding.volume();
666        // Normalize: confidence decreases as volume increases
667        // Using sigmoid-like function: 1 / (1 + volume)
668        1.0 / (1.0 + vol)
669    }
670
671    /// Detect conflict with another uncertain box.
672    ///
673    /// Returns Some(Conflict) if both boxes are high-confidence but disjoint,
674    /// indicating contradictory claims.
675    #[must_use]
676    pub fn detect_conflict(&self, other: &Self) -> Option<Conflict> {
677        let overlap = self.box_embedding.intersection_volume(&other.box_embedding);
678        let min_vol = self
679            .box_embedding
680            .volume()
681            .min(other.box_embedding.volume());
682
683        // If both are high-confidence (small volume) but disjoint, conflict
684        let conf_a = self.confidence();
685        let conf_b = other.confidence();
686        let threshold = 0.8;
687
688        if overlap < min_vol * 0.1 && conf_a > threshold && conf_b > threshold {
689            Some(Conflict {
690                claim_a_trust: self.source_trust,
691                claim_b_trust: other.source_trust,
692                severity: (1.0 - overlap / min_vol.max(1e-6)) * (conf_a + conf_b) / 2.0,
693            })
694        } else {
695            None
696        }
697    }
698}
699
/// Outcome of conflict detection between two uncertain claims.
#[derive(Debug, Clone, PartialEq)]
pub struct Conflict {
    /// Source trust of the first claim.
    pub claim_a_trust: f32,
    /// Source trust of the second claim.
    pub claim_b_trust: f32,
    /// How severe the conflict is (0.0-1.0, higher = more severe).
    pub severity: f32,
}
710
711// =============================================================================
712// Interaction Modeling (Triple Intersection)
713// =============================================================================
714
715/// Compute interaction strength between actor, action, and target.
716///
717/// Models asymmetric relations (e.g., "Company A acquired Company B")
718/// via triple intersection volume. The interaction is the volume where
719/// all three boxes overlap.
720///
721/// # Arguments
722///
723/// * `actor_box` - Box for the actor (e.g., buyer)
724/// * `action_box` - Box for the action/relation (e.g., "acquired")
725/// * `target_box` - Box for the target (e.g., company being acquired)
726///
727/// # Returns
728///
729/// Conditional probability P(action, target | actor), representing
730/// how much of the actor's space contains the interaction.
731#[must_use]
732pub fn interaction_strength(
733    actor_box: &BoxEmbedding,
734    action_box: &BoxEmbedding,
735    target_box: &BoxEmbedding,
736) -> f32 {
737    // Triple intersection: where all three boxes overlap
738    // For simplicity, we compute pairwise intersections and take minimum
739    // In full implementation, would compute true 3-way intersection
740    let actor_action = actor_box.intersection_volume(action_box);
741    let action_target = action_box.intersection_volume(target_box);
742    let actor_target = actor_box.intersection_volume(target_box);
743
744    // Interaction volume ≈ minimum of pairwise intersections
745    let interaction_vol = actor_action.min(action_target).min(actor_target);
746
747    // P(interaction | actor) = interaction_vol / vol(actor)
748    let vol_actor = actor_box.volume();
749    if vol_actor == 0.0 {
750        return 0.0;
751    }
752    interaction_vol / vol_actor
753}
754
755/// Compute asymmetric roles in a relation.
756///
757/// For a relation like "acquired", determines which entity is the
758/// buyer vs. seller based on conditional probabilities.
759///
760/// # Returns
761///
762/// (buyer_role, seller_role) where each is the interaction strength
763/// for that role.
764#[must_use]
765pub fn acquisition_roles(
766    entity_a: &BoxEmbedding,
767    entity_b: &BoxEmbedding,
768    acquisition_box: &BoxEmbedding,
769) -> (f32, f32) {
770    let buyer_role = interaction_strength(entity_a, acquisition_box, entity_b);
771    let seller_role = interaction_strength(entity_b, acquisition_box, entity_a);
772    (buyer_role, seller_role)
773}
774
775// =============================================================================
776// Gumbel Boxes (Noise Robustness)
777// =============================================================================
778
/// A box with soft, probabilistic boundaries.
///
/// Rather than hard walls, each boundary is smoothed so that membership falls
/// off gradually around it. Such "fuzzy" boxes tolerate slight misalignments,
/// which avoids brittle all-or-nothing decisions on noisy data.
///
/// # Example
///
/// ```rust,ignore
/// use anno::backends::box_embeddings::{BoxEmbedding, GumbelBox};
///
/// let mean_box = BoxEmbedding::new(vec![0.0, 0.0], vec![1.0, 1.0]);
/// let gumbel_box = GumbelBox::new(mean_box, 0.1); // Low temperature = sharp
///
/// // Membership is a probability rather than a yes/no answer.
/// let point = vec![0.5, 0.5];
/// let prob = gumbel_box.membership_probability(&point);
/// assert!(prob > 0.5); // High probability inside box
/// ```
#[derive(Debug, Clone, PartialEq)]
pub struct GumbelBox {
    /// Mean position of the lower boundary in each dimension.
    pub mean_min: Vec<f32>,
    /// Mean position of the upper boundary in each dimension.
    pub mean_max: Vec<f32>,
    /// Temperature controlling the fuzziness (higher = more fuzzy).
    /// Typical values: 0.01-0.1 for sharp, 0.5-1.0 for fuzzy.
    pub temperature: f32,
}
808
809impl GumbelBox {
810    /// Create a new Gumbel box.
811    #[must_use]
812    pub fn new(mean_box: BoxEmbedding, temperature: f32) -> Self {
813        assert!(
814            temperature > 0.0,
815            "temperature must be positive, got {}",
816            temperature
817        );
818        Self {
819            mean_min: mean_box.min,
820            mean_max: mean_box.max,
821            temperature,
822        }
823    }
824
825    /// Compute membership probability for a point.
826    ///
827    /// Returns probability that point belongs to this box (0.0-1.0).
828    /// Uses Gumbel CDF approximation for soft boundaries.
829    #[must_use]
830    pub fn membership_probability(&self, point: &[f32]) -> f32 {
831        assert_eq!(
832            point.len(),
833            self.mean_min.len(),
834            "point dimension must match box dimension"
835        );
836
837        let mut prob = 1.0;
838        for (i, &coord) in point.iter().enumerate() {
839            // Gumbel CDF approximation: P(x < max) ≈ 1 / (1 + exp(-(max - x) / temp))
840            // For min boundary: P(x > min) ≈ 1 / (1 + exp(-(x - min) / temp))
841            let min_prob = 1.0 / (1.0 + (-(coord - self.mean_min[i]) / self.temperature).exp());
842            let max_prob = 1.0 / (1.0 + (-(self.mean_max[i] - coord) / self.temperature).exp());
843            prob *= min_prob * max_prob;
844        }
845        prob
846    }
847
848    /// Compute robust coreference score with another Gumbel box.
849    ///
850    /// Samples points from self and checks membership in other, averaging
851    /// probabilities. This tolerates slight misalignments.
852    ///
853    /// # Arguments
854    ///
855    /// * `other` - The other Gumbel box to compare against
856    /// * `samples` - Number of sample points to use (more = more accurate but slower)
857    /// * `rng` - Optional RNG for sampling. If None, uses deterministic grid sampling.
858    #[must_use]
859    pub fn robust_coreference(&self, other: &Self, samples: usize) -> f32 {
860        assert_eq!(
861            self.mean_min.len(),
862            other.mean_min.len(),
863            "boxes must have same dimension"
864        );
865
866        // Deterministic grid sampling (no RNG dependency)
867        // For each dimension, sample at regular intervals
868        let samples_per_dim = (samples as f32)
869            .powf(1.0 / self.mean_min.len() as f32)
870            .ceil() as usize;
871        let mut total_prob = 0.0;
872        let mut count = 0;
873
874        // Generate grid points
875        let mut indices = vec![0; self.mean_min.len()];
876        loop {
877            // Compute point from grid indices
878            let point: Vec<f32> = self
879                .mean_min
880                .iter()
881                .zip(self.mean_max.iter())
882                .zip(indices.iter())
883                .map(|((&min_val, &max_val), &idx)| {
884                    let t = idx as f32 / (samples_per_dim - 1).max(1) as f32;
885                    min_val + t * (max_val - min_val)
886                })
887                .collect();
888
889            total_prob += other.membership_probability(&point);
890            count += 1;
891
892            // Increment grid indices
893            let mut carry = true;
894            for idx in &mut indices {
895                if carry {
896                    *idx += 1;
897                    if *idx >= samples_per_dim {
898                        *idx = 0;
899                        carry = true;
900                    } else {
901                        carry = false;
902                    }
903                }
904            }
905
906            if carry || count >= samples {
907                break;
908            }
909        }
910
911        total_prob / count as f32
912    }
913}
914
915// =============================================================================
916// Subsume Trait Implementations for GumbelBox
917// =============================================================================
918
/// `subsume_core::Box` implementation backed by the Gumbel box's mean corners.
///
/// Binary operations (`intersection`, `union`, `distance`, and the
/// probabilities built on them) return `BoxError::DimensionMismatch` when the
/// two boxes have different dimensions.
#[cfg(feature = "subsume")]
impl subsume_core::Box for GumbelBox {
    type Scalar = f32;
    type Vector = Vec<f32>;

    /// Mean position of the minimum corner.
    fn min(&self) -> &Self::Vector {
        &self.mean_min
    }

    /// Mean position of the maximum corner.
    fn max(&self) -> &Self::Vector {
        &self.mean_max
    }

    /// Number of dimensions (length of the minimum corner).
    fn dim(&self) -> usize {
        self.mean_min.len()
    }

    /// Soft volume: the product over all axes of the softplus-smoothed side
    /// length `temperature * ln(1 + exp((max - min) / temperature))`.
    ///
    /// The product is accumulated in log space (sum of `ln(side)`) and
    /// exponentiated once at the end. A zero-dimensional box has volume `1.0`.
    ///
    /// NOTE(review): `temperature` is assumed to be strictly positive; this is
    /// not validated here.
    fn volume(&self, temperature: Self::Scalar) -> Result<Self::Scalar, subsume_core::BoxError> {
        let log_vol: f32 = self
            .mean_min
            .iter()
            .zip(self.mean_max.iter())
            .map(|(&lo, &hi)| {
                let scaled = (hi - lo) / temperature;
                // Overflow-safe softplus: ln(1 + e^x) = max(x, 0) + ln(1 + e^-|x|).
                let softplus = scaled.max(0.0) + (-scaled.abs()).exp().ln_1p();
                // Log of the smoothed side length; summing logs multiplies sides.
                (temperature * softplus).ln()
            })
            .sum();
        Ok(log_vol.exp())
    }

    /// Smooth intersection of two Gumbel boxes.
    ///
    /// The new minimum corner is a log-sum-exp (smooth maximum) of the two
    /// minimum corners and the new maximum corner is the matching smooth
    /// minimum, both scaled by `self.temperature`. The result keeps
    /// `self.temperature`.
    ///
    /// # Errors
    ///
    /// Returns `BoxError::DimensionMismatch` if the dimensions differ.
    fn intersection(&self, other: &Self) -> Result<Self, subsume_core::BoxError> {
        if self.dim() != other.dim() {
            return Err(subsume_core::BoxError::DimensionMismatch {
                expected: self.dim(),
                actual: other.dim(),
            });
        }

        let temp = self.temperature;
        let mut new_min = Vec::with_capacity(self.dim());
        let mut new_max = Vec::with_capacity(self.dim());

        for i in 0..self.dim() {
            // temp * logsumexp(m1 / temp, m2 / temp), written in the stable
            // "max + log1p(exp(-|difference|))" form.
            let m1 = self.mean_min[i];
            let m2 = other.mean_min[i];
            let lse_min = m1.max(m2) + temp * (-(m1 - m2).abs() / temp).exp().ln_1p();
            new_min.push(lse_min);

            // Mirror image for the upper corner: a smooth minimum.
            let x1 = self.mean_max[i];
            let x2 = other.mean_max[i];
            let lse_max = x1.min(x2) - temp * (-(x1 - x2).abs() / temp).exp().ln_1p();
            new_max.push(lse_max);
        }

        Ok(GumbelBox {
            mean_min: new_min,
            mean_max: new_max,
            temperature: self.temperature,
        })
    }

    /// Ratio of the intersection volume to the volume of `other`, or `0.0`
    /// when `other` has zero volume.
    ///
    /// # Errors
    ///
    /// Propagates any error from `intersection` or `volume`.
    fn containment_prob(
        &self,
        other: &Self,
        temperature: Self::Scalar,
    ) -> Result<Self::Scalar, subsume_core::BoxError> {
        let intersection = self.intersection(other)?;
        let vol_int = intersection.volume(temperature)?;
        let vol_other = other.volume(temperature)?;
        if vol_other == 0.0 {
            return Ok(0.0);
        }
        Ok(vol_int / vol_other)
    }

    /// Intersection-over-union of the two soft volumes, or `0.0` when the
    /// union volume is not positive.
    ///
    /// # Errors
    ///
    /// Propagates any error from `intersection` or `volume`.
    fn overlap_prob(
        &self,
        other: &Self,
        temperature: Self::Scalar,
    ) -> Result<Self::Scalar, subsume_core::BoxError> {
        let intersection = self.intersection(other)?;
        let vol_int = intersection.volume(temperature)?;
        let vol_self = self.volume(temperature)?;
        let vol_other = other.volume(temperature)?;
        // Inclusion-exclusion: |A ∪ B| = |A| + |B| - |A ∩ B|.
        let vol_union = vol_self + vol_other - vol_int;
        if vol_union <= 0.0 {
            return Ok(0.0);
        }
        Ok(vol_int / vol_union)
    }

    /// Smallest axis-aligned box enclosing both mean boxes (hard min/max per
    /// axis). The result keeps `self.temperature`.
    ///
    /// # Errors
    ///
    /// Returns `BoxError::DimensionMismatch` if the dimensions differ.
    fn union(&self, other: &Self) -> Result<Self, subsume_core::BoxError> {
        if self.dim() != other.dim() {
            return Err(subsume_core::BoxError::DimensionMismatch {
                expected: self.dim(),
                actual: other.dim(),
            });
        }

        let mut new_min = Vec::with_capacity(self.dim());
        let mut new_max = Vec::with_capacity(self.dim());
        for i in 0..self.dim() {
            new_min.push(self.mean_min[i].min(other.mean_min[i]));
            new_max.push(self.mean_max[i].max(other.mean_max[i]));
        }
        Ok(GumbelBox {
            mean_min: new_min,
            mean_max: new_max,
            temperature: self.temperature,
        })
    }

    /// Midpoint between the mean minimum and maximum corners on every axis.
    fn center(&self) -> Result<Self::Vector, subsume_core::BoxError> {
        let center = self
            .mean_min
            .iter()
            .zip(self.mean_max.iter())
            .map(|(&lo, &hi)| (lo + hi) / 2.0)
            .collect();
        Ok(center)
    }

    /// Euclidean distance between the two mean boxes: `0.0` when they touch or
    /// overlap on every axis, otherwise the norm of the per-axis gaps.
    ///
    /// # Errors
    ///
    /// Returns `BoxError::DimensionMismatch` if the dimensions differ.
    fn distance(&self, other: &Self) -> Result<Self::Scalar, subsume_core::BoxError> {
        if self.dim() != other.dim() {
            return Err(subsume_core::BoxError::DimensionMismatch {
                expected: self.dim(),
                actual: other.dim(),
            });
        }

        let mut dist_sq = 0.0;
        for i in 0..self.dim() {
            // Separation along this axis; zero when the two intervals overlap.
            let gap = if self.mean_max[i] < other.mean_min[i] {
                other.mean_min[i] - self.mean_max[i]
            } else if other.mean_max[i] < self.mean_min[i] {
                self.mean_min[i] - other.mean_max[i]
            } else {
                0.0
            };
            dist_sq += gap * gap;
        }
        Ok(dist_sq.sqrt())
    }
}
1047
/// `subsume_core::GumbelBox` implementation that delegates to the inherent
/// `GumbelBox` methods.
#[cfg(feature = "subsume")]
impl subsume_core::GumbelBox for GumbelBox {
    /// Temperature stored on this box.
    fn temperature(&self) -> Self::Scalar {
        self.temperature
    }

    /// Soft membership probability of `point`, computed by the inherent
    /// `membership_probability` method.
    ///
    /// # Errors
    ///
    /// Returns `BoxError::DimensionMismatch` when `point` does not have the
    /// same dimension as the box. The inherent method would panic in that
    /// case, so the check is done here to honour the fallible signature.
    fn membership_probability(
        &self,
        point: &Self::Vector,
    ) -> Result<Self::Scalar, subsume_core::BoxError> {
        if point.len() != self.mean_min.len() {
            return Err(subsume_core::BoxError::DimensionMismatch {
                expected: self.mean_min.len(),
                actual: point.len(),
            });
        }
        Ok(self.membership_probability(point))
    }

    /// Deterministic "sample": returns the box center, falling back to an
    /// empty vector if the center cannot be computed. No randomness is used.
    fn sample(&self) -> Self::Vector {
        self.center().unwrap_or_default()
    }
}
1065
1066// Note: BoxCorefResolver is implemented in src/eval/coref_resolver.rs
1067// to be alongside other coreference resolvers.
1068
#[cfg(test)]
mod tests {
    use super::*;

    /// A unit square has volume 1; a 2 x 3 x 4 cuboid has volume 24.
    #[test]
    fn test_box_volume() {
        let unit_square = BoxEmbedding::new(vec![0.0, 0.0], vec![1.0, 1.0]);
        assert_eq!(unit_square.volume(), 1.0);

        let cuboid = BoxEmbedding::new(vec![0.0, 0.0, 0.0], vec![2.0, 3.0, 4.0]);
        assert_eq!(cuboid.volume(), 24.0);
    }

    /// Overlapping boxes share a unit square; far-apart boxes share nothing.
    #[test]
    fn test_intersection_volume() {
        let anchor = BoxEmbedding::new(vec![0.0, 0.0], vec![2.0, 2.0]);
        let shifted = BoxEmbedding::new(vec![1.0, 1.0], vec![3.0, 3.0]);
        assert_eq!(anchor.intersection_volume(&shifted), 1.0);

        // No common region at all.
        let far_away = BoxEmbedding::new(vec![5.0, 5.0], vec![6.0, 6.0]);
        assert_eq!(anchor.intersection_volume(&far_away), 0.0);
    }

    /// Expected values correspond to intersection volume divided by the
    /// volume of the argument box.
    #[test]
    fn test_conditional_probability() {
        let small = BoxEmbedding::new(vec![0.0, 0.0], vec![1.0, 1.0]); // volume 1
        let large = BoxEmbedding::new(vec![0.0, 0.0], vec![2.0, 2.0]); // volume 4

        // `small` lies entirely inside `large`, so the intersection is `small`.
        assert_eq!(small.conditional_probability(&large), 0.25); // 1 / 4
        assert_eq!(large.conditional_probability(&small), 1.0); // 1 / 1
    }

    /// Identical boxes score (approximately) 1; disjoint boxes score exactly 0.
    #[test]
    fn test_coreference_score() {
        let original = BoxEmbedding::new(vec![0.0, 0.0], vec![1.0, 1.0]);
        let duplicate = BoxEmbedding::new(vec![0.0, 0.0], vec![1.0, 1.0]);
        assert!((original.coreference_score(&duplicate) - 1.0).abs() < 1e-6);

        let near_origin = BoxEmbedding::new(vec![0.0, 0.0], vec![1.0, 1.0]);
        let elsewhere = BoxEmbedding::new(vec![2.0, 2.0], vec![3.0, 3.0]);
        assert_eq!(near_origin.coreference_score(&elsewhere), 0.0);
    }

    /// Containment is directional: the inner box is contained in the outer
    /// one, not the other way round.
    #[test]
    fn test_containment() {
        let inner = BoxEmbedding::new(vec![0.0, 0.0], vec![1.0, 1.0]);
        let outer = BoxEmbedding::new(vec![0.0, 0.0], vec![2.0, 2.0]);
        assert!(inner.is_contained_in(&outer));
        assert!(!outer.is_contained_in(&inner));
    }

    /// Partially overlapping boxes have positive intersection volume and a
    /// positive coreference score.
    #[test]
    fn test_box_operations() {
        let first = BoxEmbedding::new(vec![0.0, 0.0], vec![1.0, 1.0]);
        let second = BoxEmbedding::new(vec![0.5, 0.5], vec![1.5, 1.5]);

        assert!(first.intersection_volume(&second) > 0.0);
        assert!(first.coreference_score(&second) > 0.0);
    }

    /// `from_vector` with radius 0.1 yields a cube of side 0.2 around the point.
    #[test]
    fn test_from_vector() {
        let point = vec![0.5, 0.5, 0.5];
        let cube = BoxEmbedding::from_vector(&point, 0.1);

        assert_eq!(cube.min, vec![0.4, 0.4, 0.4]);
        assert_eq!(cube.max, vec![0.6, 0.6, 0.6]);
        // 0.2^3 = 0.008, compared with a tolerance for float rounding.
        assert!((cube.volume() - 0.008).abs() < 1e-6);
    }

    /// Center is the per-axis midpoint; size is the per-axis extent.
    #[test]
    fn test_center_and_size() {
        let rect = BoxEmbedding::new(vec![0.0, 1.0], vec![2.0, 3.0]);

        assert_eq!(rect.center(), vec![1.0, 2.0]);
        assert_eq!(rect.size(), vec![2.0, 2.0]);
    }

    // =========================================================================
    // Temporal Box Tests
    // =========================================================================

    /// A static temporal box is valid only inside its time range and keeps
    /// its base coordinates there.
    #[test]
    fn test_temporal_box_at_time() {
        let base = BoxEmbedding::new(vec![0.0, 0.0], vec![1.0, 1.0]);
        let no_motion = BoxVelocity::static_velocity(2);
        let temporal = TemporalBox::new(base, no_motion, (2012.0, 2016.0));

        // Validity follows the (2012, 2016) range.
        assert!(temporal.is_valid_at(2014.0));
        assert!(!temporal.is_valid_at(2017.0));

        // With zero velocity the snapshot equals the base box.
        let snapshot = temporal.at_time(2014.0).unwrap();
        assert_eq!(snapshot.min, vec![0.0, 0.0]);
        assert_eq!(snapshot.max, vec![1.0, 1.0]);
    }

    /// A box with velocity 0.1 per time unit on both corners drifts by 0.5
    /// after five time units.
    #[test]
    fn test_temporal_box_with_velocity() {
        let base = BoxEmbedding::new(vec![0.0, 0.0], vec![1.0, 1.0]);
        let drift = BoxVelocity::new(vec![0.1, 0.1], vec![0.1, 0.1]);
        let temporal = TemporalBox::new(base, drift, (0.0, 10.0));

        // t = 0: still at the base position.
        let at_start = temporal.at_time(0.0).unwrap();
        assert_eq!(at_start.min, vec![0.0, 0.0]);
        assert_eq!(at_start.max, vec![1.0, 1.0]);

        // t = 5: both corners shifted by 0.1 * 5.
        let at_five = temporal.at_time(5.0).unwrap();
        assert_eq!(at_five.min, vec![0.5, 0.5]);
        assert_eq!(at_five.max, vec![1.5, 1.5]);
    }

    /// Entities whose validity ranges never overlap score zero at any time.
    #[test]
    fn test_temporal_box_coreference() {
        let obama_base = BoxEmbedding::new(vec![0.0, 0.0], vec![1.0, 1.0]);
        let trump_base = BoxEmbedding::new(vec![5.0, 5.0], vec![6.0, 6.0]);
        let no_motion = BoxVelocity::static_velocity(2);

        let obama = TemporalBox::new(obama_base, no_motion.clone(), (2012.0, 2016.0));
        let trump = TemporalBox::new(trump_base, no_motion, (2017.0, 2021.0));

        // At each probe time only one of the two boxes is valid.
        assert_eq!(obama.coreference_at_time(&trump, 2015.0), 0.0);
        assert_eq!(obama.coreference_at_time(&trump, 2018.0), 0.0);
    }

    // =========================================================================
    // Uncertainty-Aware Box Tests
    // =========================================================================

    /// With equal source trust, a small box reports higher confidence than a
    /// large one.
    #[test]
    fn test_uncertain_box_confidence() {
        let tight = UncertainBox::new(BoxEmbedding::new(vec![0.0, 0.0], vec![0.1, 0.1]), 0.9);
        assert!(tight.confidence() > 0.5);

        let loose = UncertainBox::new(BoxEmbedding::new(vec![0.0, 0.0], vec![10.0, 10.0]), 0.9);
        assert!(loose.confidence() < tight.confidence());
    }

    /// Two high-trust claims in disjoint regions produce a conflict that
    /// records both trust values.
    #[test]
    fn test_conflict_detection() {
        let claim_a = UncertainBox::new(BoxEmbedding::new(vec![0.0, 0.0], vec![0.1, 0.1]), 0.95);
        let claim_b = UncertainBox::new(BoxEmbedding::new(vec![5.0, 5.0], vec![5.1, 5.1]), 0.90);

        // A missing conflict fails the test here.
        let found = claim_a.detect_conflict(&claim_b).unwrap();
        assert!(found.severity > 0.0);
        assert_eq!(found.claim_a_trust, 0.95);
        assert_eq!(found.claim_b_trust, 0.90);
    }

    /// Claims whose boxes overlap are not reported as conflicting.
    #[test]
    fn test_no_conflict_for_overlapping_boxes() {
        let claim_a = UncertainBox::new(BoxEmbedding::new(vec![0.0, 0.0], vec![1.0, 1.0]), 0.95);
        let claim_b = UncertainBox::new(BoxEmbedding::new(vec![0.5, 0.5], vec![1.5, 1.5]), 0.90);

        assert!(claim_a.detect_conflict(&claim_b).is_none());
    }

    // =========================================================================
    // Gumbel Box Tests
    // =========================================================================

    /// A point at the box center is more probable than one well outside it.
    #[test]
    fn test_gumbel_box_membership() {
        let gumbel = GumbelBox::new(BoxEmbedding::new(vec![0.0, 0.0], vec![1.0, 1.0]), 0.1);

        let interior = vec![0.5, 0.5];
        let p_interior = gumbel.membership_probability(&interior);
        assert!(p_interior > 0.5);

        let exterior = vec![2.0, 2.0];
        let p_exterior = gumbel.membership_probability(&exterior);
        assert!(p_exterior < p_interior);
    }

    /// Higher temperature softens the boundary, so a point just outside the
    /// box is more probable under the fuzzy box than under the sharp one.
    #[test]
    fn test_gumbel_temperature_effect() {
        let unit = BoxEmbedding::new(vec![0.0, 0.0], vec![1.0, 1.0]);
        let sharp = GumbelBox::new(unit.clone(), 0.01);
        let fuzzy = GumbelBox::new(unit, 1.0);

        let just_outside = vec![1.1, 1.1];
        let p_sharp = sharp.membership_probability(&just_outside);
        let p_fuzzy = fuzzy.membership_probability(&just_outside);

        assert!(p_fuzzy > p_sharp);
    }

    /// A box compared against a slightly smaller nested box scores above 0.3
    /// with 100 grid samples.
    #[test]
    fn test_gumbel_robust_coreference() {
        let outer = GumbelBox::new(BoxEmbedding::new(vec![0.0, 0.0], vec![1.0, 1.0]), 0.1);
        let inner = GumbelBox::new(BoxEmbedding::new(vec![0.1, 0.1], vec![0.9, 0.9]), 0.1);

        let score = outer.robust_coreference(&inner, 100);
        assert!(score > 0.3);
    }

    // =========================================================================
    // Interaction Modeling Tests
    // =========================================================================

    /// Nested actor/action/target boxes give a strength in (0, 1].
    #[test]
    fn test_interaction_strength() {
        let actor = BoxEmbedding::new(vec![0.0, 0.0], vec![1.0, 1.0]);
        let action = BoxEmbedding::new(vec![0.2, 0.2], vec![0.8, 0.8]);
        let target = BoxEmbedding::new(vec![0.3, 0.3], vec![0.7, 0.7]);

        let strength = interaction_strength(&actor, &action, &target);
        assert!(strength > 0.0);
        assert!(strength <= 1.0);
    }

    /// Role scores for buyer and seller are never negative.
    #[test]
    fn test_acquisition_roles() {
        let buyer = BoxEmbedding::new(vec![0.0, 0.0], vec![1.0, 1.0]);
        let seller = BoxEmbedding::new(vec![0.5, 0.5], vec![1.5, 1.5]);
        let acquisition = BoxEmbedding::new(vec![0.2, 0.2], vec![0.8, 0.8]);

        let (buyer_role, seller_role) = acquisition_roles(&buyer, &seller, &acquisition);

        // Only non-negativity is checked. The two roles may coincide for these
        // hand-picked boxes, so asymmetry is deliberately not asserted.
        assert!(buyer_role >= 0.0);
        assert!(seller_role >= 0.0);
    }

    // =========================================================================
    // New Methods Tests (intersection, union, overlap_prob, distance)
    // =========================================================================

    /// The intersection of [0,2]^2 and [1,3]^2 is the unit square [1,2]^2.
    #[test]
    fn test_intersection_box() {
        let lhs = BoxEmbedding::new(vec![0.0, 0.0], vec![2.0, 2.0]);
        let rhs = BoxEmbedding::new(vec![1.0, 1.0], vec![3.0, 3.0]);

        let common = lhs.intersection(&rhs);
        assert_eq!(common.min, vec![1.0, 1.0]);
        assert_eq!(common.max, vec![2.0, 2.0]);
        assert_eq!(common.volume(), 1.0);
    }

    /// The union of [0,2]^2 and [1,3]^2 is the bounding box [0,3]^2.
    #[test]
    fn test_union_box() {
        let lhs = BoxEmbedding::new(vec![0.0, 0.0], vec![2.0, 2.0]);
        let rhs = BoxEmbedding::new(vec![1.0, 1.0], vec![3.0, 3.0]);

        let hull = lhs.union(&rhs);
        assert_eq!(hull.min, vec![0.0, 0.0]);
        assert_eq!(hull.max, vec![3.0, 3.0]);
        assert_eq!(hull.volume(), 9.0);
    }

    /// Overlap probability is ~1 for identical boxes, ~0 for disjoint boxes
    /// and strictly in between for a partial overlap.
    #[test]
    fn test_overlap_prob() {
        let unit = BoxEmbedding::new(vec![0.0, 0.0], vec![1.0, 1.0]);

        let same = BoxEmbedding::new(vec![0.0, 0.0], vec![1.0, 1.0]);
        assert!((unit.overlap_prob(&same) - 1.0).abs() < 0.001);

        let disjoint = BoxEmbedding::new(vec![5.0, 5.0], vec![6.0, 6.0]);
        assert!((unit.overlap_prob(&disjoint) - 0.0).abs() < 0.001);

        let partial = BoxEmbedding::new(vec![0.5, 0.5], vec![1.5, 1.5]);
        let partial_overlap = unit.overlap_prob(&partial);
        assert!(partial_overlap > 0.0 && partial_overlap < 1.0);
    }

    /// Distance is 0 for overlapping boxes and the Euclidean gap otherwise.
    #[test]
    fn test_distance() {
        let anchor = BoxEmbedding::new(vec![0.0, 0.0], vec![2.0, 2.0]);

        let touching = BoxEmbedding::new(vec![1.0, 1.0], vec![3.0, 3.0]);
        assert_eq!(anchor.distance(&touching), 0.0);

        let remote = BoxEmbedding::new(vec![5.0, 5.0], vec![6.0, 6.0]);
        let gap = anchor.distance(&remote);
        assert!(gap > 0.0);
        // Gap of 3 on each axis: sqrt(3^2 + 3^2) = sqrt(18) ≈ 4.24.
        assert!((gap - 18.0_f32.sqrt()).abs() < 0.01);
    }

    // =========================================================================
    // Subsume Trait Tests (feature-gated)
    // =========================================================================

    /// `BoxEmbedding` can be driven through the `subsume_core::Box` trait.
    #[test]
    #[cfg(feature = "subsume")]
    fn test_subsume_trait_implementation() {
        use subsume_core::Box as SubsumeBox;

        let outer = BoxEmbedding::new(vec![0.0, 0.0], vec![2.0, 2.0]);
        let inner = BoxEmbedding::new(vec![0.5, 0.5], vec![1.5, 1.5]);

        // Accessors exposed by the trait.
        assert_eq!(SubsumeBox::dim(&outer), 2);
        assert_eq!(SubsumeBox::min(&outer), &vec![0.0, 0.0]);
        assert_eq!(SubsumeBox::max(&outer), &vec![2.0, 2.0]);

        // Volume of the 2 x 2 box; the temperature argument does not change
        // the expected value for a hard box.
        let volume = SubsumeBox::volume(&outer, 1.0).unwrap();
        assert_eq!(volume, 4.0);

        // `inner` sits inside `outer`, so containment must be positive.
        let containment = SubsumeBox::containment_prob(&outer, &inner, 1.0).unwrap();
        assert!(containment > 0.0);

        // Overlapping boxes are at distance zero.
        let separation = SubsumeBox::distance(&outer, &inner).unwrap();
        assert_eq!(separation, 0.0);
    }
}
1419}