Skip to main content

anno/backends/inference/
mod.rs

1//! Inference abstractions shared across `anno` backends.
2//!
3//! This module is mostly **plumbing**: common traits, data shapes, and small
4//! utilities used by multiple NER / IE backends (including “fixed-label” and
5//! “open/zero-shot” styles).
6//!
7//! Some of the terminology and design choices correspond to well-known
8//! architectures in the NER/IE literature, but the code here should be treated
9//! as an implementation substrate, not a verbatim reproduction of any single
10//! paper’s experiment section.
11//!
12//! ## Paper pointers (context only)
13//!
14//! - GLiNER: arXiv:2311.08526
15//! - UniversalNER: arXiv:2308.03279
16//! - W2NER: arXiv:2112.10070
17//! - ModernBERT: arXiv:2412.13663
18
19use std::borrow::Cow;
20
21// =============================================================================
22// Modality Types
23// =============================================================================
24
25/// Input modality for the encoder.
26///
27/// Supports text, images, and hybrid (OCR + visual) inputs.
28/// This enables ColPali-style visual document understanding.
29#[derive(Debug, Clone)]
30pub enum ModalityInput<'a> {
31    /// Plain text input
32    Text(Cow<'a, str>),
33    /// Image bytes (PNG/JPEG)
34    Image {
35        /// Raw image bytes
36        data: Cow<'a, [u8]>,
37        /// Image format hint
38        format: ImageFormat,
39    },
40    /// Hybrid: text with visual location (e.g., OCR result)
41    Hybrid {
42        /// Extracted text
43        text: Cow<'a, str>,
44        /// Visual bounding boxes for each token/word
45        visual_positions: Vec<VisualPosition>,
46    },
47}
48
/// Hint telling the decoder which image format to expect.
///
/// [`ImageFormat::Png`] is the default.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum ImageFormat {
    /// Portable Network Graphics (the default).
    #[default]
    Png,
    /// JPEG.
    Jpeg,
    /// WebP.
    Webp,
    /// Format not known up front; left to auto-detection.
    Unknown,
}
62
/// Location of one text token inside a page image.
///
/// The geometry fields are normalized to the `0.0..=1.0` range rather than
/// expressed in pixels.
///
/// Equality is field-wise. Because the geometry is stored as `f32`, only
/// [`PartialEq`] is implemented (not `Eq`): a position containing `NaN` never
/// compares equal, not even to itself.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct VisualPosition {
    /// Index of the token/word this box belongs to.
    pub token_idx: u32,
    /// Normalized x coordinate (0.0-1.0).
    pub x: f32,
    /// Normalized y coordinate (0.0-1.0).
    pub y: f32,
    /// Normalized width (0.0-1.0).
    pub width: f32,
    /// Normalized height (0.0-1.0).
    pub height: f32,
    /// Page number, for multi-page documents.
    pub page: u32,
}
79
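// =============================================================================
// Submodules and Re-exports
// =============================================================================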
81
82pub mod registry;
83pub use registry::*;
84
85pub mod encoder;
86pub use encoder::*;
87
88pub mod traits;
89pub use traits::*;
90
91pub mod late_interaction;
92pub use late_interaction::*;
93
94pub mod span;
95pub use span::*;
96
97pub mod coref;
98pub use coref::*;
99
100pub mod relation_extraction;
101pub use relation_extraction::{
102    extract_relation_triples, extract_relations, RelationExtractionConfig,
103};
104
105pub mod binary_embeddings;
106pub use binary_embeddings::*;
// =============================================================================
// Tests
// =============================================================================
109
110#[cfg(test)]
111mod tests {
112    use super::coref::{resolve_coreferences, CoreferenceConfig};
113    use super::late_interaction::DotProductInteraction;
114    use super::*;
115    use crate::{Entity, EntityType};
116
117    #[test]
118    fn test_semantic_registry_builder() {
119        let registry = SemanticRegistry::builder()
120            .add_entity("person", "A human being")
121            .add_entity("organization", "A company or group")
122            .add_relation("WORKS_FOR", "Employment relationship")
123            .build_placeholder(768);
124
125        assert_eq!(registry.len(), 3);
126        assert_eq!(registry.entity_labels().count(), 2);
127        assert_eq!(registry.relation_labels().count(), 1);
128    }
129
130    #[test]
131    fn test_standard_ner_registry() {
132        let registry = SemanticRegistry::standard_ner(768);
133        assert!(registry.len() >= 5);
134        assert!(registry.label_index.contains_key("person"));
135        assert!(registry.label_index.contains_key("organization"));
136    }
137
138    #[test]
139    fn test_dot_product_interaction() {
140        let interaction = DotProductInteraction::new();
141
142        // 2 spans, 3 labels, hidden_dim=4
143        let span_embs = vec![
144            1.0, 0.0, 0.0, 0.0, // span 0
145            0.0, 1.0, 0.0, 0.0, // span 1
146        ];
147        let label_embs = vec![
148            1.0, 0.0, 0.0, 0.0, // label 0 (matches span 0)
149            0.0, 1.0, 0.0, 0.0, // label 1 (matches span 1)
150            0.5, 0.5, 0.0, 0.0, // label 2 (partial match both)
151        ];
152
153        let scores = interaction.compute_similarity(&span_embs, 2, &label_embs, 3, 4);
154
155        assert_eq!(scores.len(), 6); // 2 * 3
156        assert!((scores[0] - 1.0).abs() < 0.01); // span0 vs label0
157        assert!((scores[4] - 1.0).abs() < 0.01); // span1 vs label1
158    }
159
160    #[test]
161    fn test_cosine_similarity() {
162        let a = vec![1.0, 0.0, 0.0];
163        let b = vec![1.0, 0.0, 0.0];
164        assert!((cosine_similarity(&a, &b) - 1.0).abs() < 0.001);
165
166        let c = vec![0.0, 1.0, 0.0];
167        assert!(cosine_similarity(&a, &c).abs() < 0.001);
168
169        let d = vec![-1.0, 0.0, 0.0];
170        assert!((cosine_similarity(&a, &d) - (-1.0)).abs() < 0.001);
171    }
172
173    #[test]
174    fn test_coreference_string_match() {
175        let entities = vec![
176            Entity::new("Marie Curie", EntityType::Person, 0, 11, 0.95),
177            Entity::new("Curie", EntityType::Person, 50, 55, 0.90),
178        ];
179
180        let embeddings = vec![0.0f32; 2 * 768]; // Placeholder
181        let clusters =
182            resolve_coreferences(&entities, &embeddings, 768, &CoreferenceConfig::default());
183
184        assert_eq!(clusters.len(), 1);
185        assert_eq!(clusters[0].members.len(), 2);
186        assert_eq!(clusters[0].canonical_name, "Marie Curie");
187    }
188
189    #[test]
190    fn test_handshaking_matrix() {
191        // 3 tokens, 2 labels, threshold 0.5
192        let scores = vec![
193            // token 0 with tokens 0,1,2 for labels 0,1
194            0.9, 0.1, // (0,0)
195            0.2, 0.8, // (0,1)
196            0.1, 0.1, // (0,2)
197            // token 1 with tokens 0,1,2
198            0.0, 0.0, // (1,0) - skipped (lower triangle)
199            0.7, 0.2, // (1,1)
200            0.3, 0.6, // (1,2)
201            // token 2
202            0.0, 0.0, // (2,0)
203            0.0, 0.0, // (2,1)
204            0.1, 0.1, // (2,2)
205        ];
206
207        let matrix = HandshakingMatrix::from_dense(&scores, 3, 2, 0.5);
208
209        // Should have cells for scores >= 0.5
210        assert!(matrix.cells.len() >= 4);
211    }
212
213    #[test]
214    fn test_relation_extraction() {
215        let entities = vec![
216            Entity::new("Steve Jobs", EntityType::Person, 0, 10, 0.95),
217            Entity::new("Apple", EntityType::Organization, 20, 25, 0.90),
218        ];
219
220        let text = "Steve Jobs founded Apple Inc in 1976";
221
222        let registry = SemanticRegistry::builder()
223            .add_relation("FOUNDED", "Founded an organization")
224            .build_placeholder(768);
225
226        let config = RelationExtractionConfig::default();
227        let relations = extract_relations(&entities, text, &registry, &config);
228
229        assert!(!relations.is_empty());
230        assert_eq!(relations[0].relation_type, "FOUNDED");
231    }
232
233    #[test]
234    fn test_relation_extraction_uses_character_offsets_with_unicode_prefix() {
235        // Unicode prefix ensures byte offsets != character offsets.
236        let text = "👋 Steve Jobs founded Apple Inc.";
237
238        // Compute character offsets explicitly (Entity spans are char-based).
239        let steve_start = text.find("Steve Jobs").expect("substring present");
240        // `find` returns byte offset; convert to char offset.
241        let conv = crate::offset::SpanConverter::new(text);
242        let steve_start_char = conv.byte_to_char(steve_start);
243        let steve_end_char = steve_start_char + "Steve Jobs".chars().count();
244
245        let apple_start = text.find("Apple").expect("substring present");
246        let apple_start_char = conv.byte_to_char(apple_start);
247        let apple_end_char = apple_start_char + "Apple".chars().count();
248
249        let entities = vec![
250            Entity::new(
251                "Steve Jobs",
252                EntityType::Person,
253                steve_start_char,
254                steve_end_char,
255                0.95,
256            ),
257            Entity::new(
258                "Apple",
259                EntityType::Organization,
260                apple_start_char,
261                apple_end_char,
262                0.90,
263            ),
264        ];
265
266        let registry = SemanticRegistry::builder()
267            .add_relation("FOUNDED", "Founded an organization")
268            .build_placeholder(768);
269
270        let config = RelationExtractionConfig::default();
271        let relations = extract_relations(&entities, text, &registry, &config);
272
273        assert!(
274            !relations.is_empty(),
275            "Expected FOUNDED relation to be detected"
276        );
277        assert_eq!(relations[0].relation_type, "FOUNDED");
278
279        // Trigger span should exist and cover "founded" in character offsets.
280        let trigger = relations[0]
281            .trigger_span
282            .expect("expected trigger_span to be present");
283        let trigger_text: String = text
284            .chars()
285            .skip(trigger.0)
286            .take(trigger.1.saturating_sub(trigger.0))
287            .collect();
288        assert_eq!(trigger_text.to_ascii_lowercase(), "founded");
289    }
290
291    // =========================================================================
292    // Binary Embedding Tests
293    // =========================================================================
294
295    #[test]
296    fn test_binary_hash_creation() {
297        let embedding = vec![0.1, -0.2, 0.3, -0.4, 0.5, -0.6, 0.7, -0.8];
298        let hash = BinaryHash::from_embedding(&embedding);
299
300        assert_eq!(hash.dim, 8);
301        // Positive values at indices 0, 2, 4, 6 should be set
302        // bits[0] should have bits 0, 2, 4, 6 set = 0b01010101 = 85
303        assert_eq!(hash.bits[0], 85);
304    }
305
306    #[test]
307    fn test_hamming_distance_identical() {
308        let embedding = vec![0.1; 64];
309        let hash1 = BinaryHash::from_embedding(&embedding);
310        let hash2 = BinaryHash::from_embedding(&embedding);
311
312        assert_eq!(hash1.hamming_distance(&hash2), 0);
313    }
314
315    #[test]
316    fn test_hamming_distance_opposite() {
317        let embedding1 = vec![0.1; 64];
318        let embedding2 = vec![-0.1; 64];
319        let hash1 = BinaryHash::from_embedding(&embedding1);
320        let hash2 = BinaryHash::from_embedding(&embedding2);
321
322        assert_eq!(hash1.hamming_distance(&hash2), 64);
323    }
324
325    #[test]
326    fn test_hamming_distance_half() {
327        let embedding1 = vec![0.1; 64];
328        let mut embedding2 = vec![0.1; 64];
329        // Flip second half
330        embedding2[32..64].iter_mut().for_each(|x| *x = -0.1);
331
332        let hash1 = BinaryHash::from_embedding(&embedding1);
333        let hash2 = BinaryHash::from_embedding(&embedding2);
334
335        assert_eq!(hash1.hamming_distance(&hash2), 32);
336    }
337
338    #[test]
339    fn test_binary_blocker() {
340        let mut blocker = BinaryBlocker::new(5);
341
342        // Add some hashes
343        let base_embedding = vec![0.1; 64];
344        let similar_embedding = {
345            let mut e = vec![0.1; 64];
346            e[0] = -0.1; // Flip 1 bit
347            e[1] = -0.1; // Flip 2 bits
348            e
349        };
350        let different_embedding = vec![-0.1; 64];
351
352        blocker.add(0, BinaryHash::from_embedding(&base_embedding));
353        blocker.add(1, BinaryHash::from_embedding(&similar_embedding));
354        blocker.add(2, BinaryHash::from_embedding(&different_embedding));
355
356        // Query with base
357        let query = BinaryHash::from_embedding(&base_embedding);
358        let candidates = blocker.query(&query);
359
360        assert!(candidates.contains(&0), "Should find exact match");
361        assert!(
362            candidates.contains(&1),
363            "Should find similar (2 bits different)"
364        );
365        assert!(
366            !candidates.contains(&2),
367            "Should NOT find opposite (64 bits different)"
368        );
369    }
370
371    #[test]
372    fn test_two_stage_retrieval() {
373        // Create embeddings
374        let query = vec![1.0, 0.0, 0.0, 0.0];
375        let candidates = vec![
376            vec![1.0, 0.0, 0.0, 0.0],  // Identical
377            vec![0.9, 0.1, 0.0, 0.0],  // Similar
378            vec![-1.0, 0.0, 0.0, 0.0], // Opposite
379            vec![0.0, 1.0, 0.0, 0.0],  // Orthogonal
380        ];
381
382        // Generous threshold to get candidates
383        let results = two_stage_retrieval(&query, &candidates, 4, 2);
384
385        assert!(!results.is_empty());
386        // First result should be exact match
387        assert_eq!(results[0].0, 0);
388        assert!((results[0].1 - 1.0).abs() < 0.001);
389    }
390
391    #[test]
392    fn test_approximate_cosine() {
393        let embedding1 = vec![0.1; 768];
394        let embedding2 = vec![0.1; 768];
395        let hash1 = BinaryHash::from_embedding(&embedding1);
396        let hash2 = BinaryHash::from_embedding(&embedding2);
397
398        // Identical → approximate cosine should be ~1.0
399        let approx = hash1.approximate_cosine(&hash2);
400        assert!((approx - 1.0).abs() < 0.001);
401
402        // Opposite → approximate cosine should be ~-1.0
403        let embedding3 = vec![-0.1; 768];
404        let hash3 = BinaryHash::from_embedding(&embedding3);
405        let approx_opp = hash1.approximate_cosine(&hash3);
406        assert!((approx_opp - (-1.0)).abs() < 0.001);
407    }
408}