anno/backends/box_embeddings/extras.rs
1//! BoxCorefConfig, TemporalBox, BoxVelocity, UncertainBox, and related types.
2
3use super::*;
4
/// Configuration for box-embedding-based coreference.
///
/// The stock values come from the `Default` impl.
#[derive(Debug, Clone, PartialEq)]
pub struct BoxCorefConfig {
    /// Minimum coreference score required to link entities.
    pub coreference_threshold: f32,
    /// Whether to enforce syntactic binding constraints (Principle B/C).
    pub enforce_syntactic_constraints: bool,
    /// Maximum token distance treated as the local domain (Principle B).
    pub max_local_distance: usize,
    /// Radius used when converting vector embeddings to boxes (if vectors are used).
    pub vector_to_box_radius: Option<f32>,
}
16
17impl Default for BoxCorefConfig {
18 fn default() -> Self {
19 Self {
20 coreference_threshold: 0.7,
21 enforce_syntactic_constraints: true,
22 max_local_distance: 5,
23 vector_to_box_radius: Some(0.1),
24 }
25 }
26}
27
28// =============================================================================
29// Temporal Boxes (BoxTE-style)
30// =============================================================================
31
32/// A temporal box embedding that evolves over time.
33///
34/// Based on BoxTE (Messner et al., 2022), this models entities that change
35/// over time. For example, "The President" refers to Obama in 2012 but
36/// Trump in 2017 - they should not corefer despite the same surface form.
37///
38/// # Example
39///
40/// ```rust,ignore
41/// use anno::backends::box_embeddings::{BoxEmbedding, TemporalBox, BoxVelocity};
42///
43/// // "The President" in 2012 (Obama)
44/// let base = BoxEmbedding::new(vec![0.0, 0.0], vec![1.0, 1.0]);
45/// let velocity = BoxVelocity::new(vec![0.0, 0.0], vec![0.0, 0.0]); // Static
46/// let obama_presidency = TemporalBox::new(base, velocity, (2012.0, 2016.0));
47///
48/// // "The President" in 2017 (Trump)
49/// let trump_base = BoxEmbedding::new(vec![5.0, 5.0], vec![6.0, 6.0]);
50/// let trump_presidency = TemporalBox::new(trump_base, velocity, (2017.0, 2021.0));
51///
52/// // Should not corefer (different time ranges)
53/// assert_eq!(obama_presidency.coreference_at_time(&trump_presidency, 2015.0), 0.0);
54/// ```
55#[derive(Debug, Clone, PartialEq)]
56pub struct TemporalBox {
57 /// Base box at time t=0 (or reference time)
58 pub base: BoxEmbedding,
59 /// Velocity: how box moves/resizes per time unit
60 pub velocity: BoxVelocity,
61 /// Time range where this box is valid [start, end)
62 pub time_range: (f64, f64),
63}
64
/// Per-time-unit drift of a temporal box's bounds.
///
/// Both vectors are indexed by dimension; entry `i` is how far the
/// corresponding bound of dimension `i` moves per unit of time.
#[derive(Debug, Clone, PartialEq)]
pub struct BoxVelocity {
    /// Drift of each lower bound per time unit.
    pub min_delta: Vec<f32>,
    /// Drift of each upper bound per time unit.
    pub max_delta: Vec<f32>,
}
73
74impl BoxVelocity {
75 /// Create a new box velocity (static by default).
76 #[must_use]
77 pub fn new(min_delta: Vec<f32>, max_delta: Vec<f32>) -> Self {
78 Self {
79 min_delta,
80 max_delta,
81 }
82 }
83
84 /// Create a static velocity (no change over time).
85 #[must_use]
86 pub fn static_velocity(dim: usize) -> Self {
87 Self {
88 min_delta: vec![0.0; dim],
89 max_delta: vec![0.0; dim],
90 }
91 }
92}
93
94impl TemporalBox {
95 /// Create a new temporal box.
96 ///
97 /// # Arguments
98 ///
99 /// * `base` - Base box at reference time
100 /// * `velocity` - How box evolves per time unit
101 /// * `time_range` - (start, end) time range where box is valid
102 #[must_use]
103 pub fn new(base: BoxEmbedding, velocity: BoxVelocity, time_range: (f64, f64)) -> Self {
104 assert_eq!(
105 base.dim(),
106 velocity.min_delta.len(),
107 "base and velocity must have same dimension"
108 );
109 assert_eq!(
110 velocity.min_delta.len(),
111 velocity.max_delta.len(),
112 "velocity min and max deltas must have same dimension"
113 );
114 Self {
115 base,
116 velocity,
117 time_range,
118 }
119 }
120
121 /// Get the box at a specific time.
122 ///
123 /// Returns None if time is outside the valid range.
124 #[must_use]
125 pub fn at_time(&self, time: f64) -> Option<BoxEmbedding> {
126 if time < self.time_range.0 || time >= self.time_range.1 {
127 return None;
128 }
129
130 // Compute time offset from reference (using start of range as reference)
131 let time_offset = time - self.time_range.0;
132
133 // Apply velocity to base box
134 let new_min: Vec<f32> = self
135 .base
136 .min
137 .iter()
138 .zip(self.velocity.min_delta.iter())
139 .map(|(&m, &delta)| m + delta * time_offset as f32)
140 .collect();
141
142 let new_max: Vec<f32> = self
143 .base
144 .max
145 .iter()
146 .zip(self.velocity.max_delta.iter())
147 .map(|(&max_val, &delta)| max_val + delta * time_offset as f32)
148 .collect();
149
150 Some(BoxEmbedding::new(new_min, new_max))
151 }
152
153 /// Compute coreference score at a specific time.
154 ///
155 /// Returns 0.0 if either box is invalid at the given time.
156 #[must_use]
157 pub fn coreference_at_time(&self, other: &Self, time: f64) -> f32 {
158 let box_a = match self.at_time(time) {
159 Some(b) => b,
160 None => return 0.0,
161 };
162 let box_b = match other.at_time(time) {
163 Some(b) => b,
164 None => return 0.0,
165 };
166 box_a.coreference_score(&box_b)
167 }
168
169 /// Check if this temporal box is valid at the given time.
170 #[must_use]
171 pub fn is_valid_at(&self, time: f64) -> bool {
172 time >= self.time_range.0 && time < self.time_range.1
173 }
174}
175
176// =============================================================================
177// Uncertainty-Aware Boxes (UKGE-style)
178// =============================================================================
179
180/// An uncertainty-aware box embedding (UKGE-style).
181///
182/// Based on UKGE (Chen et al., 2021), box volume represents confidence:
183/// - Small box = high confidence (precise, trusted fact)
184/// - Large box = low confidence (vague, uncertain, or dubious claim)
185///
186/// This enables conflict detection: if two high-confidence boxes are disjoint,
187/// they represent contradictory claims.
188///
189/// # Example
190///
191/// ```rust,ignore
192/// use anno::backends::box_embeddings::{BoxEmbedding, UncertainBox};
193///
194/// // High-confidence claim: "Trump is in NY" (small, precise box)
195/// let claim_a = UncertainBox::new(
196/// BoxEmbedding::new(vec![0.0, 0.0], vec![0.1, 0.1]), // Small = high confidence
197/// 0.95, // Source trust
198/// );
199///
200/// // Contradictory claim: "Trump is in FL" (also high confidence, but disjoint)
201/// let claim_b = UncertainBox::new(
202/// BoxEmbedding::new(vec![5.0, 5.0], vec![5.1, 5.1]), // Disjoint from claim_a
203/// 0.90,
204/// );
205///
206/// // Should detect conflict
207/// assert!(claim_a.detect_conflict(&claim_b).is_some());
208/// ```
209#[derive(Debug, Clone, PartialEq)]
210pub struct UncertainBox {
211 /// The underlying box embedding
212 pub box_embedding: BoxEmbedding,
213 /// Source trustworthiness (0.0-1.0)
214 pub source_trust: f32,
215}
216
217impl UncertainBox {
218 /// Create a new uncertainty-aware box.
219 ///
220 /// Confidence is derived from box volume (smaller = higher confidence).
221 #[must_use]
222 pub fn new(box_embedding: BoxEmbedding, source_trust: f32) -> Self {
223 assert!(
224 (0.0..=1.0).contains(&source_trust),
225 "source_trust must be in [0.0, 1.0]"
226 );
227 Self {
228 box_embedding,
229 source_trust,
230 }
231 }
232
233 /// Get confidence derived from box volume.
234 ///
235 /// Smaller boxes = higher confidence. This is a heuristic:
236 /// confidence ≈ 1.0 / (1.0 + volume)
237 #[must_use]
238 pub fn confidence(&self) -> f32 {
239 let vol = self.box_embedding.volume();
240 // Normalize: confidence decreases as volume increases
241 // Using sigmoid-like function: 1 / (1 + volume)
242 1.0 / (1.0 + vol)
243 }
244
245 /// Detect conflict with another uncertain box.
246 ///
247 /// Returns Some(Conflict) if both boxes are high-confidence but disjoint,
248 /// indicating contradictory claims.
249 #[must_use]
250 pub fn detect_conflict(&self, other: &Self) -> Option<Conflict> {
251 let overlap = self.box_embedding.intersection_volume(&other.box_embedding);
252 let min_vol = self
253 .box_embedding
254 .volume()
255 .min(other.box_embedding.volume());
256
257 // If both are high-confidence (small volume) but disjoint, conflict
258 let conf_a = self.confidence();
259 let conf_b = other.confidence();
260 let threshold = 0.8;
261
262 if overlap < min_vol * 0.1 && conf_a > threshold && conf_b > threshold {
263 Some(Conflict {
264 claim_a_trust: self.source_trust,
265 claim_b_trust: other.source_trust,
266 severity: (1.0 - overlap / min_vol.max(1e-6)) * (conf_a + conf_b) / 2.0,
267 })
268 } else {
269 None
270 }
271 }
272}
273
/// A contradiction between two uncertain claims, as reported by conflict detection.
#[derive(Debug, Clone, PartialEq)]
pub struct Conflict {
    /// Trust assigned to the source of the first claim.
    pub claim_a_trust: f32,
    /// Trust assigned to the source of the second claim.
    pub claim_b_trust: f32,
    /// How severe the contradiction is (0.0-1.0; larger means more severe).
    pub severity: f32,
}
284
285// =============================================================================
286// Interaction Modeling (Triple Intersection)
287// =============================================================================
288
289/// Compute interaction strength between actor, action, and target.
290///
291/// Models asymmetric relations (e.g., "Company A acquired Company B")
292/// via triple intersection volume. The interaction is the volume where
293/// all three boxes overlap.
294///
295/// # Arguments
296///
297/// * `actor_box` - Box for the actor (e.g., buyer)
298/// * `action_box` - Box for the action/relation (e.g., "acquired")
299/// * `target_box` - Box for the target (e.g., company being acquired)
300///
301/// # Returns
302///
303/// Conditional probability P(action, target | actor), representing
304/// how much of the actor's space contains the interaction.
305#[must_use]
306pub fn interaction_strength(
307 actor_box: &BoxEmbedding,
308 action_box: &BoxEmbedding,
309 target_box: &BoxEmbedding,
310) -> f32 {
311 // Triple intersection: where all three boxes overlap
312 // For simplicity, we compute pairwise intersections and take minimum
313 // In full implementation, would compute true 3-way intersection
314 let actor_action = actor_box.intersection_volume(action_box);
315 let action_target = action_box.intersection_volume(target_box);
316 let actor_target = actor_box.intersection_volume(target_box);
317
318 // Interaction volume ≈ minimum of pairwise intersections
319 let interaction_vol = actor_action.min(action_target).min(actor_target);
320
321 // P(interaction | actor) = interaction_vol / vol(actor)
322 let vol_actor = actor_box.volume();
323 if vol_actor == 0.0 {
324 return 0.0;
325 }
326 interaction_vol / vol_actor
327}
328
329/// Compute asymmetric roles in a relation.
330///
331/// For a relation like "acquired", determines which entity is the
332/// buyer vs. seller based on conditional probabilities.
333///
334/// # Returns
335///
336/// (buyer_role, seller_role) where each is the interaction strength
337/// for that role.
338#[must_use]
339pub fn acquisition_roles(
340 entity_a: &BoxEmbedding,
341 entity_b: &BoxEmbedding,
342 acquisition_box: &BoxEmbedding,
343) -> (f32, f32) {
344 let buyer_role = interaction_strength(entity_a, acquisition_box, entity_b);
345 let seller_role = interaction_strength(entity_b, acquisition_box, entity_a);
346 (buyer_role, seller_role)
347}
348
349// =============================================================================
350// Gumbel Boxes (Noise Robustness)
351// =============================================================================
352
/// A box whose walls are soft: each boundary is treated probabilistically.
///
/// Rather than a hard inside/outside test, membership decays smoothly near the
/// mean boundaries. The width of that transition is set by `temperature`, so
/// small misalignments in noisy data do not flip results abruptly.
///
/// # Example
///
/// ```rust,ignore
/// use anno::backends::box_embeddings::{BoxEmbedding, GumbelBox};
///
/// let mean_box = BoxEmbedding::new(vec![0.0, 0.0], vec![1.0, 1.0]);
/// let soft = GumbelBox::new(mean_box, 0.1); // low temperature -> sharp walls
///
/// // Membership is a probability rather than a yes/no answer.
/// let prob = soft.membership_probability(&[0.5, 0.5]);
/// assert!(prob > 0.5); // well inside the box
/// ```
#[derive(Debug, Clone, PartialEq)]
pub struct GumbelBox {
    /// Mean lower boundary in each dimension.
    pub mean_min: Vec<f32>,
    /// Mean upper boundary in each dimension.
    pub mean_max: Vec<f32>,
    /// Softness of the walls (larger = fuzzier).
    /// Roughly 0.01-0.1 gives sharp walls; 0.5-1.0 gives fuzzy ones.
    pub temperature: f32,
}
382
383impl GumbelBox {
384 /// Create a new Gumbel box.
385 #[must_use]
386 pub fn new(mean_box: BoxEmbedding, temperature: f32) -> Self {
387 assert!(
388 temperature > 0.0,
389 "temperature must be positive, got {}",
390 temperature
391 );
392 Self {
393 mean_min: mean_box.min,
394 mean_max: mean_box.max,
395 temperature,
396 }
397 }
398
399 /// Compute membership probability for a point.
400 ///
401 /// Returns probability that point belongs to this box (0.0-1.0).
402 /// Uses Gumbel CDF approximation for soft boundaries.
403 #[must_use]
404 pub fn membership_probability(&self, point: &[f32]) -> f32 {
405 assert_eq!(
406 point.len(),
407 self.mean_min.len(),
408 "point dimension must match box dimension"
409 );
410
411 let mut prob = 1.0;
412 for (i, &coord) in point.iter().enumerate() {
413 // Gumbel CDF approximation: P(x < max) ≈ 1 / (1 + exp(-(max - x) / temp))
414 // For min boundary: P(x > min) ≈ 1 / (1 + exp(-(x - min) / temp))
415 let min_prob = 1.0 / (1.0 + (-(coord - self.mean_min[i]) / self.temperature).exp());
416 let max_prob = 1.0 / (1.0 + (-(self.mean_max[i] - coord) / self.temperature).exp());
417 prob *= min_prob * max_prob;
418 }
419 prob
420 }
421
422 /// Compute robust coreference score with another Gumbel box.
423 ///
424 /// Samples points from self and checks membership in other, averaging
425 /// probabilities. This tolerates slight misalignments.
426 ///
427 /// # Arguments
428 ///
429 /// * `other` - The other Gumbel box to compare against
430 /// * `samples` - Number of sample points to use (more = more accurate but slower)
431 /// * `rng` - Optional RNG for sampling. If None, uses deterministic grid sampling.
432 #[must_use]
433 pub fn robust_coreference(&self, other: &Self, samples: usize) -> f32 {
434 assert_eq!(
435 self.mean_min.len(),
436 other.mean_min.len(),
437 "boxes must have same dimension"
438 );
439
440 // Deterministic grid sampling (no RNG dependency)
441 // For each dimension, sample at regular intervals
442 let samples_per_dim = (samples as f32)
443 .powf(1.0 / self.mean_min.len() as f32)
444 .ceil() as usize;
445 let mut total_prob = 0.0;
446 let mut count = 0;
447
448 // Generate grid points
449 let mut indices = vec![0; self.mean_min.len()];
450 loop {
451 // Compute point from grid indices
452 let point: Vec<f32> = self
453 .mean_min
454 .iter()
455 .zip(self.mean_max.iter())
456 .zip(indices.iter())
457 .map(|((&min_val, &max_val), &idx)| {
458 let t = idx as f32 / (samples_per_dim - 1).max(1) as f32;
459 min_val + t * (max_val - min_val)
460 })
461 .collect();
462
463 total_prob += other.membership_probability(&point);
464 count += 1;
465
466 // Increment grid indices
467 let mut carry = true;
468 for idx in &mut indices {
469 if carry {
470 *idx += 1;
471 if *idx >= samples_per_dim {
472 *idx = 0;
473 carry = true;
474 } else {
475 carry = false;
476 }
477 }
478 }
479
480 if carry || count >= samples {
481 break;
482 }
483 }
484
485 total_prob / count as f32
486 }
487}
488
489// =============================================================================
490// Subsume Trait Implementations for GumbelBox
491// =============================================================================
492
/// Return `DimensionMismatch` unless `a` and `b` have the same dimension.
#[cfg(feature = "subsume")]
fn ensure_same_dim(a: &GumbelBox, b: &GumbelBox) -> Result<(), subsume::BoxError> {
    let (expected, actual) = (a.mean_min.len(), b.mean_min.len());
    if expected == actual {
        Ok(())
    } else {
        Err(subsume::BoxError::DimensionMismatch { expected, actual })
    }
}

#[cfg(feature = "subsume")]
impl subsume::Box for GumbelBox {
    type Scalar = f32;
    type Vector = Vec<f32>;

    fn min(&self) -> &Self::Vector {
        &self.mean_min
    }

    fn max(&self) -> &Self::Vector {
        &self.mean_max
    }

    fn dim(&self) -> usize {
        self.mean_min.len()
    }

    /// Soft volume `∏_i T · softplus((max_i - min_i) / T)`, accumulated in log space.
    ///
    /// Degenerate or inverted dimensions (e.g. a disjoint intersection) give a
    /// side length near zero, so the volume tends to 0 rather than 1.
    fn volume(&self, temperature: Self::Scalar) -> Result<Self::Scalar, subsume::BoxError> {
        let log_vol: f32 = self
            .mean_min
            .iter()
            .zip(&self.mean_max)
            .map(|(&lo, &hi)| {
                let x = (hi - lo) / temperature;
                // Stable softplus: max(x, 0) + ln(1 + e^{-|x|}) cannot overflow,
                // unlike ln(1 + e^x) for large x.
                let softplus = x.max(0.0) + (-x.abs()).exp().ln_1p();
                // Sum the *logarithms* of the side lengths. The previous code summed
                // the side lengths themselves, so exp() yielded ∏ e^{side} instead of ∏ side.
                (temperature * softplus).ln()
            })
            .sum();
        Ok(log_vol.exp())
    }

    fn intersection(&self, other: &Self) -> Result<Self, subsume::BoxError> {
        ensure_same_dim(self, other)?;

        // Gumbel intersection: smooth max of the lower bounds and smooth min of the
        // upper bounds via log-sum-exp at this box's temperature.
        let mut new_min = Vec::with_capacity(self.dim());
        let mut new_max = Vec::with_capacity(self.dim());

        for i in 0..self.dim() {
            let m1 = self.mean_min[i];
            let m2 = other.mean_min[i];
            let lse_min =
                m1.max(m2) + self.temperature * (-(m1 - m2).abs() / self.temperature).exp().ln_1p();
            new_min.push(lse_min);

            let x1 = self.mean_max[i];
            let x2 = other.mean_max[i];
            let lse_max =
                x1.min(x2) - self.temperature * (-(x1 - x2).abs() / self.temperature).exp().ln_1p();
            new_max.push(lse_max);
        }

        Ok(GumbelBox {
            mean_min: new_min,
            mean_max: new_max,
            temperature: self.temperature,
        })
    }

    /// `vol(self ∩ other) / vol(other)`, or 0.0 when `other` has zero volume.
    fn containment_prob(
        &self,
        other: &Self,
        temperature: Self::Scalar,
    ) -> Result<Self::Scalar, subsume::BoxError> {
        let intersection = self.intersection(other)?;
        let vol_int = intersection.volume(temperature)?;
        let vol_other = other.volume(temperature)?;
        if vol_other == 0.0 {
            return Ok(0.0);
        }
        Ok(vol_int / vol_other)
    }

    /// Intersection-over-union of the soft volumes, or 0.0 for an empty union.
    fn overlap_prob(
        &self,
        other: &Self,
        temperature: Self::Scalar,
    ) -> Result<Self::Scalar, subsume::BoxError> {
        let intersection = self.intersection(other)?;
        let vol_int = intersection.volume(temperature)?;
        let vol_self = self.volume(temperature)?;
        let vol_other = other.volume(temperature)?;
        let vol_union = vol_self + vol_other - vol_int;
        if vol_union <= 0.0 {
            return Ok(0.0);
        }
        Ok(vol_int / vol_union)
    }

    /// Hard bounding box of both mean boxes (keeps `self`'s temperature).
    fn union(&self, other: &Self) -> Result<Self, subsume::BoxError> {
        // Previously missing: mismatched dimensions panicked or were silently truncated.
        ensure_same_dim(self, other)?;
        let new_min = self
            .mean_min
            .iter()
            .zip(&other.mean_min)
            .map(|(&a, &b)| a.min(b))
            .collect();
        let new_max = self
            .mean_max
            .iter()
            .zip(&other.mean_max)
            .map(|(&a, &b)| a.max(b))
            .collect();
        Ok(GumbelBox {
            mean_min: new_min,
            mean_max: new_max,
            temperature: self.temperature,
        })
    }

    /// Midpoint of the mean bounds in every dimension.
    fn center(&self) -> Result<Self::Vector, subsume::BoxError> {
        Ok(self
            .mean_min
            .iter()
            .zip(&self.mean_max)
            .map(|(&lo, &hi)| (lo + hi) / 2.0)
            .collect())
    }

    /// Euclidean gap between the mean boxes (0.0 when they overlap in every dimension).
    fn distance(&self, other: &Self) -> Result<Self::Scalar, subsume::BoxError> {
        // Previously missing: mismatched dimensions panicked or were silently truncated.
        ensure_same_dim(self, other)?;
        let mut dist_sq = 0.0_f32;
        for i in 0..self.dim() {
            let gap = if self.mean_max[i] < other.mean_min[i] {
                other.mean_min[i] - self.mean_max[i]
            } else if other.mean_max[i] < self.mean_min[i] {
                self.mean_min[i] - other.mean_max[i]
            } else {
                0.0
            };
            dist_sq += gap * gap;
        }
        Ok(dist_sq.sqrt())
    }

    /// Keep only the first `k` dimensions.
    fn truncate(&self, k: usize) -> Result<Self, subsume::BoxError> {
        if k > self.dim() {
            return Err(subsume::BoxError::MatryoshkaMismatch {
                requested: k,
                actual: self.dim(),
            });
        }
        Ok(GumbelBox {
            mean_min: self.mean_min[..k].to_vec(),
            mean_max: self.mean_max[..k].to_vec(),
            temperature: self.temperature,
        })
    }
}
635
#[cfg(feature = "subsume")]
impl subsume::GumbelBox for GumbelBox {
    fn temperature(&self) -> Self::Scalar {
        self.temperature
    }

    /// Soft membership of `point`, with a dimension mismatch reported as an error.
    ///
    /// The inherent `GumbelBox::membership_probability` asserts on mismatched
    /// dimensions. This trait method returns `Result`, so the check is done first.
    fn membership_probability(
        &self,
        point: &Self::Vector,
    ) -> Result<Self::Scalar, subsume::BoxError> {
        let expected = self.mean_min.len();
        if point.len() != expected {
            return Err(subsume::BoxError::DimensionMismatch {
                expected,
                actual: point.len(),
            });
        }
        // Inherent methods take priority over trait methods during method
        // resolution, so this calls the inherent `&[f32]` version (no recursion).
        Ok(self.membership_probability(point))
    }

    /// Deterministic "sample": the midpoint of the mean bounds in each dimension.
    ///
    /// NOTE(review): no randomness is involved; confirm this matches what callers
    /// of `subsume::GumbelBox::sample` expect.
    fn sample(&self) -> Self::Vector {
        self.mean_min
            .iter()
            .zip(self.mean_max.iter())
            .map(|(mn, mx)| (mn + mx) / 2.0)
            .collect()
    }
}
657
658// Note: BoxCorefResolver is implemented in src/eval/coref_resolver.rs
659// to be alongside other coreference resolvers.