Skip to main content

anno/backends/box_embeddings/
extras.rs

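//! Box-embedding extras: `BoxCorefConfig`, temporal boxes (`TemporalBox`, `BoxVelocity`), uncertainty-aware boxes (`UncertainBox`, `Conflict`), triple-intersection interaction scoring, and Gumbel boxes (`GumbelBox`).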
2
3use super::*;
4
/// Configuration for box-embedding-based coreference.
///
/// [`Default`] yields: `coreference_threshold = 0.7`,
/// `enforce_syntactic_constraints = true`, `max_local_distance = 5`, and
/// `vector_to_box_radius = Some(0.1)`.
#[derive(Debug, Clone, PartialEq)]
pub struct BoxCorefConfig {
    /// Minimum coreference score to link entities
    pub coreference_threshold: f32,
    /// Whether to enforce syntactic constraints (Principle B/C)
    pub enforce_syntactic_constraints: bool,
    /// Maximum token distance for local domain (Principle B)
    pub max_local_distance: usize,
    /// Radius for converting vector embeddings to boxes (if using vectors);
    /// `None` when no conversion radius is configured.
    pub vector_to_box_radius: Option<f32>,
}

impl Default for BoxCorefConfig {
    fn default() -> Self {
        Self {
            coreference_threshold: 0.7,
            enforce_syntactic_constraints: true,
            max_local_distance: 5,
            vector_to_box_radius: Some(0.1),
        }
    }
}
27
28// =============================================================================
29// Temporal Boxes (BoxTE-style)
30// =============================================================================
31
32/// A temporal box embedding that evolves over time.
33///
34/// Based on BoxTE (Messner et al., 2022), this models entities that change
35/// over time. For example, "The President" refers to Obama in 2012 but
36/// Trump in 2017 - they should not corefer despite the same surface form.
37///
38/// # Example
39///
40/// ```rust,ignore
41/// use anno::backends::box_embeddings::{BoxEmbedding, TemporalBox, BoxVelocity};
42///
43/// // "The President" in 2012 (Obama)
44/// let base = BoxEmbedding::new(vec![0.0, 0.0], vec![1.0, 1.0]);
45/// let velocity = BoxVelocity::new(vec![0.0, 0.0], vec![0.0, 0.0]); // Static
46/// let obama_presidency = TemporalBox::new(base, velocity, (2012.0, 2016.0));
47///
48/// // "The President" in 2017 (Trump)
49/// let trump_base = BoxEmbedding::new(vec![5.0, 5.0], vec![6.0, 6.0]);
50/// let trump_presidency = TemporalBox::new(trump_base, velocity, (2017.0, 2021.0));
51///
52/// // Should not corefer (different time ranges)
53/// assert_eq!(obama_presidency.coreference_at_time(&trump_presidency, 2015.0), 0.0);
54/// ```
55#[derive(Debug, Clone, PartialEq)]
56pub struct TemporalBox {
57    /// Base box at time t=0 (or reference time)
58    pub base: BoxEmbedding,
59    /// Velocity: how box moves/resizes per time unit
60    pub velocity: BoxVelocity,
61    /// Time range where this box is valid [start, end)
62    pub time_range: (f64, f64),
63}
64
/// Velocity of a temporal box (change per time unit).
#[derive(Debug, Clone, PartialEq)]
pub struct BoxVelocity {
    /// Change in min bounds per time unit (d-dimensional vector).
    pub min_delta: Vec<f32>,
    /// Change in max bounds per time unit (d-dimensional vector).
    pub max_delta: Vec<f32>,
}

impl BoxVelocity {
    /// Create a box velocity from explicit per-dimension deltas.
    ///
    /// The deltas are stored exactly as given; use
    /// [`BoxVelocity::static_velocity`] for a velocity that leaves a box
    /// unchanged. Lengths are not validated here — [`TemporalBox::new`]
    /// checks them against the base box.
    #[must_use]
    pub fn new(min_delta: Vec<f32>, max_delta: Vec<f32>) -> Self {
        Self {
            min_delta,
            max_delta,
        }
    }

    /// Create a static velocity (all deltas zero) for `dim` dimensions.
    #[must_use]
    pub fn static_velocity(dim: usize) -> Self {
        Self::new(vec![0.0; dim], vec![0.0; dim])
    }
}
93
94impl TemporalBox {
95    /// Create a new temporal box.
96    ///
97    /// # Arguments
98    ///
99    /// * `base` - Base box at reference time
100    /// * `velocity` - How box evolves per time unit
101    /// * `time_range` - (start, end) time range where box is valid
102    #[must_use]
103    pub fn new(base: BoxEmbedding, velocity: BoxVelocity, time_range: (f64, f64)) -> Self {
104        assert_eq!(
105            base.dim(),
106            velocity.min_delta.len(),
107            "base and velocity must have same dimension"
108        );
109        assert_eq!(
110            velocity.min_delta.len(),
111            velocity.max_delta.len(),
112            "velocity min and max deltas must have same dimension"
113        );
114        Self {
115            base,
116            velocity,
117            time_range,
118        }
119    }
120
121    /// Get the box at a specific time.
122    ///
123    /// Returns None if time is outside the valid range.
124    #[must_use]
125    pub fn at_time(&self, time: f64) -> Option<BoxEmbedding> {
126        if time < self.time_range.0 || time >= self.time_range.1 {
127            return None;
128        }
129
130        // Compute time offset from reference (using start of range as reference)
131        let time_offset = time - self.time_range.0;
132
133        // Apply velocity to base box
134        let new_min: Vec<f32> = self
135            .base
136            .min
137            .iter()
138            .zip(self.velocity.min_delta.iter())
139            .map(|(&m, &delta)| m + delta * time_offset as f32)
140            .collect();
141
142        let new_max: Vec<f32> = self
143            .base
144            .max
145            .iter()
146            .zip(self.velocity.max_delta.iter())
147            .map(|(&max_val, &delta)| max_val + delta * time_offset as f32)
148            .collect();
149
150        Some(BoxEmbedding::new(new_min, new_max))
151    }
152
153    /// Compute coreference score at a specific time.
154    ///
155    /// Returns 0.0 if either box is invalid at the given time.
156    #[must_use]
157    pub fn coreference_at_time(&self, other: &Self, time: f64) -> f32 {
158        let box_a = match self.at_time(time) {
159            Some(b) => b,
160            None => return 0.0,
161        };
162        let box_b = match other.at_time(time) {
163            Some(b) => b,
164            None => return 0.0,
165        };
166        box_a.coreference_score(&box_b)
167    }
168
169    /// Check if this temporal box is valid at the given time.
170    #[must_use]
171    pub fn is_valid_at(&self, time: f64) -> bool {
172        time >= self.time_range.0 && time < self.time_range.1
173    }
174}
175
176// =============================================================================
177// Uncertainty-Aware Boxes (UKGE-style)
178// =============================================================================
179
180/// An uncertainty-aware box embedding (UKGE-style).
181///
182/// Based on UKGE (Chen et al., 2021), box volume represents confidence:
183/// - Small box = high confidence (precise, trusted fact)
184/// - Large box = low confidence (vague, uncertain, or dubious claim)
185///
186/// This enables conflict detection: if two high-confidence boxes are disjoint,
187/// they represent contradictory claims.
188///
189/// # Example
190///
191/// ```rust,ignore
192/// use anno::backends::box_embeddings::{BoxEmbedding, UncertainBox};
193///
194/// // High-confidence claim: "Trump is in NY" (small, precise box)
195/// let claim_a = UncertainBox::new(
196///     BoxEmbedding::new(vec![0.0, 0.0], vec![0.1, 0.1]), // Small = high confidence
197///     0.95, // Source trust
198/// );
199///
200/// // Contradictory claim: "Trump is in FL" (also high confidence, but disjoint)
201/// let claim_b = UncertainBox::new(
202///     BoxEmbedding::new(vec![5.0, 5.0], vec![5.1, 5.1]), // Disjoint from claim_a
203///     0.90,
204/// );
205///
206/// // Should detect conflict
207/// assert!(claim_a.detect_conflict(&claim_b).is_some());
208/// ```
209#[derive(Debug, Clone, PartialEq)]
210pub struct UncertainBox {
211    /// The underlying box embedding
212    pub box_embedding: BoxEmbedding,
213    /// Source trustworthiness (0.0-1.0)
214    pub source_trust: f32,
215}
216
217impl UncertainBox {
218    /// Create a new uncertainty-aware box.
219    ///
220    /// Confidence is derived from box volume (smaller = higher confidence).
221    #[must_use]
222    pub fn new(box_embedding: BoxEmbedding, source_trust: f32) -> Self {
223        assert!(
224            (0.0..=1.0).contains(&source_trust),
225            "source_trust must be in [0.0, 1.0]"
226        );
227        Self {
228            box_embedding,
229            source_trust,
230        }
231    }
232
233    /// Get confidence derived from box volume.
234    ///
235    /// Smaller boxes = higher confidence. This is a heuristic:
236    /// confidence ≈ 1.0 / (1.0 + volume)
237    #[must_use]
238    pub fn confidence(&self) -> f32 {
239        let vol = self.box_embedding.volume();
240        // Normalize: confidence decreases as volume increases
241        // Using sigmoid-like function: 1 / (1 + volume)
242        1.0 / (1.0 + vol)
243    }
244
245    /// Detect conflict with another uncertain box.
246    ///
247    /// Returns Some(Conflict) if both boxes are high-confidence but disjoint,
248    /// indicating contradictory claims.
249    #[must_use]
250    pub fn detect_conflict(&self, other: &Self) -> Option<Conflict> {
251        let overlap = self.box_embedding.intersection_volume(&other.box_embedding);
252        let min_vol = self
253            .box_embedding
254            .volume()
255            .min(other.box_embedding.volume());
256
257        // If both are high-confidence (small volume) but disjoint, conflict
258        let conf_a = self.confidence();
259        let conf_b = other.confidence();
260        let threshold = 0.8;
261
262        if overlap < min_vol * 0.1 && conf_a > threshold && conf_b > threshold {
263            Some(Conflict {
264                claim_a_trust: self.source_trust,
265                claim_b_trust: other.source_trust,
266                severity: (1.0 - overlap / min_vol.max(1e-6)) * (conf_a + conf_b) / 2.0,
267            })
268        } else {
269            None
270        }
271    }
272}
273
/// Represents a conflict between two uncertain claims, as produced by
/// `UncertainBox::detect_conflict`.
#[derive(Debug, Clone, PartialEq)]
pub struct Conflict {
    /// `source_trust` of the first claim (the receiver of `detect_conflict`)
    pub claim_a_trust: f32,
    /// `source_trust` of the second claim (the `other` argument)
    pub claim_b_trust: f32,
    /// Severity of conflict (0.0-1.0, higher = more severe): the non-overlapping
    /// fraction `1 - overlap / max(smaller volume, 1e-6)` scaled by the mean
    /// confidence of the two claims
    pub severity: f32,
}
284
285// =============================================================================
286// Interaction Modeling (Triple Intersection)
287// =============================================================================
288
289/// Compute interaction strength between actor, action, and target.
290///
291/// Models asymmetric relations (e.g., "Company A acquired Company B")
292/// via triple intersection volume. The interaction is the volume where
293/// all three boxes overlap.
294///
295/// # Arguments
296///
297/// * `actor_box` - Box for the actor (e.g., buyer)
298/// * `action_box` - Box for the action/relation (e.g., "acquired")
299/// * `target_box` - Box for the target (e.g., company being acquired)
300///
301/// # Returns
302///
303/// Conditional probability P(action, target | actor), representing
304/// how much of the actor's space contains the interaction.
305#[must_use]
306pub fn interaction_strength(
307    actor_box: &BoxEmbedding,
308    action_box: &BoxEmbedding,
309    target_box: &BoxEmbedding,
310) -> f32 {
311    // Triple intersection: where all three boxes overlap
312    // For simplicity, we compute pairwise intersections and take minimum
313    // In full implementation, would compute true 3-way intersection
314    let actor_action = actor_box.intersection_volume(action_box);
315    let action_target = action_box.intersection_volume(target_box);
316    let actor_target = actor_box.intersection_volume(target_box);
317
318    // Interaction volume ≈ minimum of pairwise intersections
319    let interaction_vol = actor_action.min(action_target).min(actor_target);
320
321    // P(interaction | actor) = interaction_vol / vol(actor)
322    let vol_actor = actor_box.volume();
323    if vol_actor == 0.0 {
324        return 0.0;
325    }
326    interaction_vol / vol_actor
327}
328
329/// Compute asymmetric roles in a relation.
330///
331/// For a relation like "acquired", determines which entity is the
332/// buyer vs. seller based on conditional probabilities.
333///
334/// # Returns
335///
336/// (buyer_role, seller_role) where each is the interaction strength
337/// for that role.
338#[must_use]
339pub fn acquisition_roles(
340    entity_a: &BoxEmbedding,
341    entity_b: &BoxEmbedding,
342    acquisition_box: &BoxEmbedding,
343) -> (f32, f32) {
344    let buyer_role = interaction_strength(entity_a, acquisition_box, entity_b);
345    let seller_role = interaction_strength(entity_b, acquisition_box, entity_a);
346    (buyer_role, seller_role)
347}
348
349// =============================================================================
350// Gumbel Boxes (Noise Robustness)
351// =============================================================================
352
353/// A Gumbel box with soft, probabilistic boundaries.
354///
355/// Instead of hard walls, boundaries are modeled as Gumbel distributions,
356/// creating "fuzzy" boxes that tolerate slight misalignments. This prevents
357/// brittle logic failures when data is noisy.
358///
359/// # Example
360///
361/// ```rust,ignore
362/// use anno::backends::box_embeddings::{BoxEmbedding, GumbelBox};
363///
364/// let mean_box = BoxEmbedding::new(vec![0.0, 0.0], vec![1.0, 1.0]);
365/// let gumbel_box = GumbelBox::new(mean_box, 0.1); // Low temperature = sharp
366///
367/// // Membership is probabilistic, not binary
368/// let point = vec![0.5, 0.5];
369/// let prob = gumbel_box.membership_probability(&point);
370/// assert!(prob > 0.5); // High probability inside box
371/// ```
372#[derive(Debug, Clone, PartialEq)]
373pub struct GumbelBox {
374    /// Mean box boundaries (lower bounds)
375    pub mean_min: Vec<f32>,
376    /// Mean box boundaries (upper bounds)
377    pub mean_max: Vec<f32>,
378    /// Temperature: controls fuzziness (higher = more fuzzy)
379    /// Typical values: 0.01-0.1 for sharp, 0.5-1.0 for fuzzy
380    pub temperature: f32,
381}
382
383impl GumbelBox {
384    /// Create a new Gumbel box.
385    #[must_use]
386    pub fn new(mean_box: BoxEmbedding, temperature: f32) -> Self {
387        assert!(
388            temperature > 0.0,
389            "temperature must be positive, got {}",
390            temperature
391        );
392        Self {
393            mean_min: mean_box.min,
394            mean_max: mean_box.max,
395            temperature,
396        }
397    }
398
399    /// Compute membership probability for a point.
400    ///
401    /// Returns probability that point belongs to this box (0.0-1.0).
402    /// Uses Gumbel CDF approximation for soft boundaries.
403    #[must_use]
404    pub fn membership_probability(&self, point: &[f32]) -> f32 {
405        assert_eq!(
406            point.len(),
407            self.mean_min.len(),
408            "point dimension must match box dimension"
409        );
410
411        let mut prob = 1.0;
412        for (i, &coord) in point.iter().enumerate() {
413            // Gumbel CDF approximation: P(x < max) ≈ 1 / (1 + exp(-(max - x) / temp))
414            // For min boundary: P(x > min) ≈ 1 / (1 + exp(-(x - min) / temp))
415            let min_prob = 1.0 / (1.0 + (-(coord - self.mean_min[i]) / self.temperature).exp());
416            let max_prob = 1.0 / (1.0 + (-(self.mean_max[i] - coord) / self.temperature).exp());
417            prob *= min_prob * max_prob;
418        }
419        prob
420    }
421
422    /// Compute robust coreference score with another Gumbel box.
423    ///
424    /// Samples points from self and checks membership in other, averaging
425    /// probabilities. This tolerates slight misalignments.
426    ///
427    /// # Arguments
428    ///
429    /// * `other` - The other Gumbel box to compare against
430    /// * `samples` - Number of sample points to use (more = more accurate but slower)
431    /// * `rng` - Optional RNG for sampling. If None, uses deterministic grid sampling.
432    #[must_use]
433    pub fn robust_coreference(&self, other: &Self, samples: usize) -> f32 {
434        assert_eq!(
435            self.mean_min.len(),
436            other.mean_min.len(),
437            "boxes must have same dimension"
438        );
439
440        // Deterministic grid sampling (no RNG dependency)
441        // For each dimension, sample at regular intervals
442        let samples_per_dim = (samples as f32)
443            .powf(1.0 / self.mean_min.len() as f32)
444            .ceil() as usize;
445        let mut total_prob = 0.0;
446        let mut count = 0;
447
448        // Generate grid points
449        let mut indices = vec![0; self.mean_min.len()];
450        loop {
451            // Compute point from grid indices
452            let point: Vec<f32> = self
453                .mean_min
454                .iter()
455                .zip(self.mean_max.iter())
456                .zip(indices.iter())
457                .map(|((&min_val, &max_val), &idx)| {
458                    let t = idx as f32 / (samples_per_dim - 1).max(1) as f32;
459                    min_val + t * (max_val - min_val)
460                })
461                .collect();
462
463            total_prob += other.membership_probability(&point);
464            count += 1;
465
466            // Increment grid indices
467            let mut carry = true;
468            for idx in &mut indices {
469                if carry {
470                    *idx += 1;
471                    if *idx >= samples_per_dim {
472                        *idx = 0;
473                        carry = true;
474                    } else {
475                        carry = false;
476                    }
477                }
478            }
479
480            if carry || count >= samples {
481                break;
482            }
483        }
484
485        total_prob / count as f32
486    }
487}
488
489// =============================================================================
490// Subsume Trait Implementations for GumbelBox
491// =============================================================================
492
#[cfg(feature = "subsume")]
impl subsume::Box for GumbelBox {
    type Scalar = f32;
    type Vector = Vec<f32>;

    fn min(&self) -> &Self::Vector {
        &self.mean_min
    }

    fn max(&self) -> &Self::Vector {
        &self.mean_max
    }

    fn dim(&self) -> usize {
        self.mean_min.len()
    }

    /// Soft volume: the product over dimensions of the softplus side length
    /// `T * ln(1 + exp((max - min) / T))`. The product is accumulated in log
    /// space so high-dimensional products do not overflow or underflow early.
    /// An empty (inverted) box has side lengths near zero, so its volume is
    /// near zero.
    fn volume(&self, temperature: Self::Scalar) -> Result<Self::Scalar, subsume::BoxError> {
        let log_vol: f32 = self
            .mean_min
            .iter()
            .zip(&self.mean_max)
            .map(|(&lo, &hi)| {
                let z = (hi - lo) / temperature;
                // Stable softplus: max(z, 0) + ln(1 + e^{-|z|}) never
                // exponentiates a large positive number.
                let side = temperature * (z.max(0.0) + (-z.abs()).exp().ln_1p());
                side.ln()
            })
            .sum();
        Ok(log_vol.exp())
    }

    /// Soft intersection: per dimension, a smooth maximum of the lower bounds
    /// and a smooth minimum of the upper bounds (log-sum-exp at `self.temperature`).
    fn intersection(&self, other: &Self) -> Result<Self, subsume::BoxError> {
        if self.dim() != other.dim() {
            return Err(subsume::BoxError::DimensionMismatch {
                expected: self.dim(),
                actual: other.dim(),
            });
        }

        // Gumbel intersection uses LSE for max-stability
        let mut new_min = Vec::with_capacity(self.dim());
        let mut new_max = Vec::with_capacity(self.dim());

        for i in 0..self.dim() {
            let m1 = self.mean_min[i];
            let m2 = other.mean_min[i];
            let lse_min =
                m1.max(m2) + self.temperature * (-(m1 - m2).abs() / self.temperature).exp().ln_1p();
            new_min.push(lse_min);

            let x1 = self.mean_max[i];
            let x2 = other.mean_max[i];
            let lse_max =
                x1.min(x2) - self.temperature * (-(x1 - x2).abs() / self.temperature).exp().ln_1p();
            new_max.push(lse_max);
        }

        Ok(GumbelBox {
            mean_min: new_min,
            mean_max: new_max,
            temperature: self.temperature,
        })
    }

    /// `vol(self ∩ other) / vol(other)`, or `0.0` when `other` has zero volume.
    fn containment_prob(
        &self,
        other: &Self,
        temperature: Self::Scalar,
    ) -> Result<Self::Scalar, subsume::BoxError> {
        let intersection = self.intersection(other)?;
        let vol_int = intersection.volume(temperature)?;
        let vol_other = other.volume(temperature)?;
        if vol_other == 0.0 {
            return Ok(0.0);
        }
        Ok(vol_int / vol_other)
    }

    /// Intersection-over-union of the soft volumes, or `0.0` for a
    /// non-positive union.
    fn overlap_prob(
        &self,
        other: &Self,
        temperature: Self::Scalar,
    ) -> Result<Self::Scalar, subsume::BoxError> {
        let intersection = self.intersection(other)?;
        let vol_int = intersection.volume(temperature)?;
        let vol_self = self.volume(temperature)?;
        let vol_other = other.volume(temperature)?;
        let vol_union = vol_self + vol_other - vol_int;
        if vol_union <= 0.0 {
            return Ok(0.0);
        }
        Ok(vol_int / vol_union)
    }

    /// Smallest axis-aligned box (hard min/max) enclosing both mean boxes.
    fn union(&self, other: &Self) -> Result<Self, subsume::BoxError> {
        // Same dimension check as `intersection`; indexing `other` below would
        // otherwise panic (or silently truncate) on mismatched inputs.
        if self.dim() != other.dim() {
            return Err(subsume::BoxError::DimensionMismatch {
                expected: self.dim(),
                actual: other.dim(),
            });
        }
        let mut new_min = Vec::with_capacity(self.dim());
        let mut new_max = Vec::with_capacity(self.dim());
        for i in 0..self.dim() {
            new_min.push(self.mean_min[i].min(other.mean_min[i]));
            new_max.push(self.mean_max[i].max(other.mean_max[i]));
        }
        Ok(GumbelBox {
            mean_min: new_min,
            mean_max: new_max,
            temperature: self.temperature,
        })
    }

    /// Midpoint of the mean box in every dimension.
    fn center(&self) -> Result<Self::Vector, subsume::BoxError> {
        let mut center = Vec::with_capacity(self.dim());
        for i in 0..self.dim() {
            center.push((self.mean_min[i] + self.mean_max[i]) / 2.0);
        }
        Ok(center)
    }

    /// Euclidean gap between the mean boxes (0 in dimensions where they overlap).
    fn distance(&self, other: &Self) -> Result<Self::Scalar, subsume::BoxError> {
        if self.dim() != other.dim() {
            return Err(subsume::BoxError::DimensionMismatch {
                expected: self.dim(),
                actual: other.dim(),
            });
        }
        let mut dist_sq = 0.0;
        for i in 0..self.dim() {
            let gap = if self.mean_max[i] < other.mean_min[i] {
                other.mean_min[i] - self.mean_max[i]
            } else if other.mean_max[i] < self.mean_min[i] {
                self.mean_min[i] - other.mean_max[i]
            } else {
                0.0
            };
            dist_sq += gap * gap;
        }
        Ok(dist_sq.sqrt())
    }

    /// Keep only the first `k` dimensions (Matryoshka-style truncation).
    fn truncate(&self, k: usize) -> Result<Self, subsume::BoxError> {
        if k > self.dim() {
            return Err(subsume::BoxError::MatryoshkaMismatch {
                requested: k,
                actual: self.dim(),
            });
        }
        Ok(GumbelBox {
            mean_min: self.mean_min[..k].to_vec(),
            mean_max: self.mean_max[..k].to_vec(),
            temperature: self.temperature,
        })
    }
}
635
#[cfg(feature = "subsume")]
impl subsume::GumbelBox for GumbelBox {
    fn temperature(&self) -> Self::Scalar {
        self.temperature
    }

    /// Membership probability of `point`. A dimension mismatch is reported as
    /// `BoxError::DimensionMismatch` instead of panicking.
    fn membership_probability(
        &self,
        point: &Self::Vector,
    ) -> Result<Self::Scalar, subsume::BoxError> {
        if point.len() != self.mean_min.len() {
            return Err(subsume::BoxError::DimensionMismatch {
                expected: self.mean_min.len(),
                actual: point.len(),
            });
        }
        // Inherent methods take precedence over trait methods, so this calls
        // the inherent `GumbelBox::membership_probability`, not itself.
        Ok(self.membership_probability(point))
    }

    /// Deterministic stand-in for sampling: returns the centre of the mean box.
    fn sample(&self) -> Self::Vector {
        self.mean_min
            .iter()
            .zip(self.mean_max.iter())
            .map(|(mn, mx)| (mn + mx) / 2.0)
            .collect()
    }
}
657
658// Note: BoxCorefResolver is implemented in src/eval/coref_resolver.rs
659// to be alongside other coreference resolvers.