anno/backends/inference/encoder.rs
1//! Core encoder traits for GLiNER/ModernBERT-style bi-encoder extraction.
2
3#[allow(unused_imports)]
4use crate::{Entity, EntityType};
5use anno_core::RaggedBatch;
6
7// Core Encoder Traits (GLiNER/ModernBERT Alignment)
8// =============================================================================
9
/// Text encoder trait for transformer-based encoders.
///
/// # Motivation
///
/// Modern NER systems require converting raw text into dense vector representations
/// that capture semantic meaning. This trait abstracts the encoding step, allowing
/// different transformer architectures to be used interchangeably.
///
/// # Supported Architectures
///
/// | Architecture | Context | Key Features | Speed |
/// |--------------|---------|--------------|-------|
/// | ModernBERT | 8,192 | RoPE, GeGLU, unpadded inference | 3x faster |
/// | DeBERTaV3 | 512 | Disentangled attention | Baseline |
/// | BERT/RoBERTa | 512 | Classic, widely available | Baseline |
///
/// # Research Alignment (ModernBERT, Dec 2024)
///
/// From ModernBERT paper (arXiv:2412.13663):
/// > "Pareto improvements to BERT... encoder-only models offer great
/// > performance-size tradeoff for retrieval and classification."
///
/// Key innovations:
/// - **Alternating Attention**: Global attention every 3 layers, local (128-token
///   window) elsewhere. Reduces complexity for long sequences.
/// - **Unpadding**: "ModernBERT unpads inputs *before* the token embedding layer
///   and optionally repads model outputs leading to a 10-to-20 percent
///   performance improvement over previous methods."
/// - **RoPE**: Rotary positional embeddings enable extrapolation to longer sequences.
/// - **GeGLU**: Gated activation function improves over GELU.
///
/// # Example
///
/// ```ignore
/// use anno::TextEncoder;
///
/// fn process_document(encoder: &dyn TextEncoder, text: &str) {
///     let output = encoder.encode(text).unwrap();
///     println!("Encoded {} tokens into {} dimensions",
///         output.num_tokens, output.hidden_dim);
///
///     // Token offsets map back to character positions
///     for (i, (start, end)) in output.token_offsets.iter().enumerate() {
///         println!("Token {}: chars {}..{}", i, start, end);
///     }
/// }
/// ```
pub trait TextEncoder: Send + Sync {
    /// Encode text into token embeddings.
    ///
    /// # Arguments
    /// * `text` - Input text to encode
    ///
    /// # Returns
    /// An [`EncoderOutput`] whose `embeddings` field holds the token embeddings
    /// flattened row-major as `[num_tokens, hidden_dim]`, together with the
    /// token-to-character offsets needed for span recovery.
    ///
    /// # Errors
    /// Returns an error if encoding fails (implementation-defined; e.g. a
    /// backend/model failure). Specific conditions depend on the implementor.
    fn encode(&self, text: &str) -> crate::Result<EncoderOutput>;

    /// Encode a batch of texts.
    ///
    /// # Arguments
    /// * `texts` - Batch of input texts
    ///
    /// # Returns
    /// A tuple of:
    /// * the concatenated embeddings for all documents as a flat `Vec<f32>`, and
    /// * a `RaggedBatch` describing per-document boundaries within that buffer.
    ///
    /// # Errors
    /// Returns an error if encoding any document in the batch fails
    /// (implementation-defined).
    fn encode_batch(&self, texts: &[&str]) -> crate::Result<(Vec<f32>, RaggedBatch)>;

    /// Get the hidden dimension of the encoder (width of each token embedding).
    fn hidden_dim(&self) -> usize;

    /// Get the maximum sequence length (in tokens) the encoder accepts.
    fn max_length(&self) -> usize;

    /// Get the encoder architecture name (e.g. for logging/diagnostics).
    fn architecture(&self) -> &'static str;
}
86
/// Output from text encoding.
///
/// `embeddings` stores the token embeddings flattened row-major as a
/// `[num_tokens, hidden_dim]` matrix, so the embedding of token `i` occupies
/// `embeddings[i * hidden_dim .. (i + 1) * hidden_dim]`.
///
/// Expected invariants (upheld by [`TextEncoder::encode`] implementations):
/// * `embeddings.len() == num_tokens * hidden_dim`
/// * `token_offsets.len() == num_tokens`
#[derive(Debug, Clone, Default, PartialEq)]
pub struct EncoderOutput {
    /// Token embeddings, flattened row-major: [num_tokens, hidden_dim]
    pub embeddings: Vec<f32>,
    /// Number of tokens in the encoded sequence
    pub num_tokens: usize,
    /// Hidden dimension (width of each token embedding)
    pub hidden_dim: usize,
    /// Token-to-character mapping (for span recovery): `(start, end)` per token
    pub token_offsets: Vec<(usize, usize)>,
}
99
/// Label encoder trait for encoding entity type descriptions.
///
/// # Motivation
///
/// Zero-shot NER works by encoding entity type *descriptions* into the same
/// vector space as text spans. Instead of training separate classifiers for
/// each entity type, we compute similarity between spans and label embeddings.
///
/// This enables:
/// - **Unlimited entity types** at inference (no retraining needed)
/// - **Faster inference** when labels are pre-computed
/// - **Better generalization** to unseen entity types via semantic similarity
///
/// # Research Alignment
///
/// From GLiNER bi-encoder (knowledgator/modern-gliner-bi-base-v1.0):
/// > "textual encoder is ModernBERT-base and entity label encoder is
/// > sentence transformer - BGE-small-en."
///
/// # Example
///
/// ```ignore
/// use anno::LabelEncoder;
///
/// fn setup_custom_types(encoder: &dyn LabelEncoder) {
///     // Encode rich descriptions for better matching
///     let labels = &[
///         "a named individual human being",
///         "a company, institution, or organized group",
///         "a geographical location, city, country, or region",
///     ];
///
///     let embeddings = encoder.encode_labels(labels).unwrap();
///     // Store embeddings in SemanticRegistry for fast lookup
/// }
/// ```
pub trait LabelEncoder: Send + Sync {
    /// Encode a single label description.
    ///
    /// # Arguments
    /// * `label` - Label description (e.g., "a named individual human being")
    ///
    /// # Returns
    /// A single embedding vector of length [`LabelEncoder::hidden_dim`].
    ///
    /// # Errors
    /// Returns an error if encoding fails (implementation-defined).
    fn encode_label(&self, label: &str) -> crate::Result<Vec<f32>>;

    /// Encode multiple labels.
    ///
    /// # Arguments
    /// * `labels` - Label descriptions
    ///
    /// # Returns
    /// Flattened embeddings, row-major: [num_labels, hidden_dim]
    ///
    /// # Errors
    /// Returns an error if encoding any label fails (implementation-defined).
    fn encode_labels(&self, labels: &[&str]) -> crate::Result<Vec<f32>>;

    /// Get the hidden dimension (width of each label embedding).
    fn hidden_dim(&self) -> usize;
}
155
/// Bi-encoder architecture combining text and label encoders.
///
/// # Motivation
///
/// The bi-encoder architecture treats NER as a **matching problem** rather than
/// a classification problem. It encodes text spans and entity labels separately,
/// then computes similarity scores to determine matches.
///
/// ```text
/// ┌─────────────────┐         ┌─────────────────┐
/// │   Text Input    │         │   Label Desc.   │
/// │  "Steve Jobs"   │         │  "person name"  │
/// └────────┬────────┘         └────────┬────────┘
///          │                           │
///          ▼                           ▼
/// ┌─────────────────┐         ┌─────────────────┐
/// │   TextEncoder   │         │  LabelEncoder   │
/// │  (ModernBERT)   │         │   (BGE-small)   │
/// └────────┬────────┘         └────────┬────────┘
///          │                           │
///          ▼                           ▼
/// ┌─────────────────┐         ┌─────────────────┐
/// │ Span Embedding  │◄───────►│ Label Embedding │
/// │      [768]      │ cosine  │      [768]      │
/// └─────────────────┘  sim    └─────────────────┘
///                   │
///                   ▼
///              Score: 0.92
/// ```
///
/// # Trade-offs
///
/// | Aspect | Bi-Encoder | Uni-Encoder |
/// |--------|------------|-------------|
/// | Entity types | Unlimited | Fixed at training |
/// | Inference speed | Faster (pre-compute labels) | Slower |
/// | Disambiguation | Harder (no label interaction) | Better |
/// | Generalization | Better to new types | Limited |
///
/// # Research Alignment
///
/// From GLiNER: "GLiNER frames NER as a matching problem, comparing candidate
/// spans with entity type embeddings."
///
/// From knowledgator: "Bi-encoder architecture brings several advantages...
/// unlimited entities, faster inference, better generalization."
///
/// Drawback: "Lack of inter-label interactions that make it hard to
/// disambiguate semantically similar but contextually different entities."
///
/// # Example
///
/// ```ignore
/// use anno::BiEncoder;
///
/// fn extract_custom_entities(bi_enc: &dyn BiEncoder, text: &str) {
///     let labels = &["software company", "hardware manufacturer", "person"];
///     let scores = bi_enc.encode_and_match(text, labels, 8).unwrap();
///
///     for s in scores.iter().filter(|s| s.score > 0.5) {
///         println!("Found '{}' as type {} (score: {:.2})",
///             &text[s.start..s.end], labels[s.label_idx], s.score);
///     }
/// }
/// ```
pub trait BiEncoder: Send + Sync {
    /// Get the text encoder (encodes raw text into token embeddings).
    fn text_encoder(&self) -> &dyn TextEncoder;

    /// Get the label encoder (encodes entity type descriptions).
    fn label_encoder(&self) -> &dyn LabelEncoder;

    /// Encode text and labels, compute span-label similarities.
    ///
    /// # Arguments
    /// * `text` - Input text
    /// * `labels` - Entity type descriptions
    /// * `max_span_width` - Maximum span width to consider
    ///
    /// # Returns
    /// Similarity scores for each (span, label) pair.
    /// NOTE(review): neither ordering nor any score threshold is specified
    /// here — callers should filter/sort themselves (as the example does);
    /// implementors should document any guarantees they add.
    ///
    /// # Errors
    /// Returns an error if encoding or matching fails (implementation-defined).
    fn encode_and_match(
        &self,
        text: &str,
        labels: &[&str],
        max_span_width: usize,
    ) -> crate::Result<Vec<SpanLabelScore>>;
}
244
/// Score for a (span, label) match, as produced by [`BiEncoder::encode_and_match`].
///
/// All fields are small `Copy` data, so the struct itself derives `Copy` —
/// scores can be passed around and compared without cloning.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct SpanLabelScore {
    /// Span start (character offset)
    pub start: usize,
    /// Span end (character offset, exclusive)
    pub end: usize,
    /// Index into the `labels` slice passed to the matcher
    pub label_idx: usize,
    /// Similarity score (0.0 - 1.0)
    pub score: f32,
}
257
258// =============================================================================