// anno/backends/inference/encoder.rs
1//! Core encoder traits for GLiNER/ModernBERT-style bi-encoder extraction.
2
3#[allow(unused_imports)]
4use crate::{Entity, EntityType};
5use anno_core::RaggedBatch;
6
7// Core Encoder Traits (GLiNER/ModernBERT Alignment)
8// =============================================================================
9
/// Text encoder trait for transformer-based encoders.
///
/// # Motivation
///
/// Modern NER systems require converting raw text into dense vector representations
/// that capture semantic meaning. This trait abstracts the encoding step, allowing
/// different transformer architectures to be used interchangeably.
///
/// # Supported Architectures
///
/// | Architecture | Context | Key Features | Speed |
/// |--------------|---------|--------------|-------|
/// | ModernBERT   | 8,192   | RoPE, GeGLU, unpadded inference | 3x faster |
/// | DeBERTaV3    | 512     | Disentangled attention | Baseline |
/// | BERT/RoBERTa | 512     | Classic, widely available | Baseline |
///
/// # Research Alignment (ModernBERT, Dec 2024)
///
/// From ModernBERT paper (arXiv:2412.13663):
/// > "Pareto improvements to BERT... encoder-only models offer great
/// > performance-size tradeoff for retrieval and classification."
///
/// Key innovations:
/// - **Alternating Attention**: Global attention every 3 layers, local (128-token
///   window) elsewhere. Reduces complexity for long sequences.
/// - **Unpadding**: "ModernBERT unpads inputs *before* the token embedding layer
///   and optionally repads model outputs leading to a 10-to-20 percent
///   performance improvement over previous methods."
/// - **RoPE**: Rotary positional embeddings enable extrapolation to longer sequences.
/// - **GeGLU**: Gated activation function improves over GELU.
///
/// # Example
///
/// ```ignore
/// use anno::TextEncoder;
///
/// fn process_document(encoder: &dyn TextEncoder, text: &str) {
///     let output = encoder.encode(text).unwrap();
///     println!("Encoded {} tokens into {} dimensions",
///              output.num_tokens, output.hidden_dim);
///
///     // Token offsets map back to character positions
///     for (i, (start, end)) in output.token_offsets.iter().enumerate() {
///         println!("Token {}: chars {}..{}", i, start, end);
///     }
/// }
/// ```
pub trait TextEncoder: Send + Sync {
    /// Encode text into token embeddings.
    ///
    /// # Arguments
    /// * `text` - Input text to encode
    ///
    /// # Returns
    /// An [`EncoderOutput`] holding token embeddings flattened as
    /// `[num_tokens, hidden_dim]` plus token-to-character offsets for
    /// span recovery. (There is no separate attention-mask field;
    /// `num_tokens` already counts only valid tokens.)
    ///
    /// # Errors
    /// Returns an error when the implementation fails to encode the text
    /// (implementation-specific; see the concrete encoder).
    fn encode(&self, text: &str) -> crate::Result<EncoderOutput>;

    /// Encode a batch of texts.
    ///
    /// # Arguments
    /// * `texts` - Batch of input texts
    ///
    /// # Returns
    /// The flattened embeddings for all documents plus a `RaggedBatch`
    /// describing per-document boundaries within that flat buffer.
    ///
    /// # Errors
    /// Returns an error if encoding any document in the batch fails.
    fn encode_batch(&self, texts: &[&str]) -> crate::Result<(Vec<f32>, RaggedBatch)>;

    /// Get the hidden dimension of the encoder.
    fn hidden_dim(&self) -> usize;

    /// Get the maximum sequence length.
    ///
    /// NOTE(review): handling of over-length inputs (truncation vs.
    /// chunking) is encoder-specific — confirm against the concrete
    /// implementation.
    fn max_length(&self) -> usize;

    /// Get the encoder architecture name (e.g., for logging/diagnostics).
    fn architecture(&self) -> &'static str;
}
86
/// Output from text encoding.
///
/// # Invariants
/// * `embeddings.len() == num_tokens * hidden_dim` (row-major: token `i`
///   occupies `embeddings[i * hidden_dim .. (i + 1) * hidden_dim]`)
/// * `token_offsets.len() == num_tokens`
#[derive(Debug, Clone, PartialEq)]
pub struct EncoderOutput {
    /// Token embeddings: [num_tokens, hidden_dim]
    pub embeddings: Vec<f32>,
    /// Number of tokens
    pub num_tokens: usize,
    /// Hidden dimension
    pub hidden_dim: usize,
    /// Token-to-character mapping (for span recovery)
    pub token_offsets: Vec<(usize, usize)>,
}

impl EncoderOutput {
    /// Borrow the embedding row for a single token.
    ///
    /// Encapsulates the flattened-layout index arithmetic so callers do
    /// not have to repeat it (and risk off-by-one slicing).
    ///
    /// # Returns
    /// `Some(&[f32])` of length `hidden_dim`, or `None` when `index` is
    /// out of range or the embeddings buffer is shorter than the layout
    /// invariant promises (checked via `get` instead of panicking slice
    /// syntax).
    #[must_use]
    pub fn token_embedding(&self, index: usize) -> Option<&[f32]> {
        if index >= self.num_tokens {
            return None;
        }
        let start = index * self.hidden_dim;
        self.embeddings.get(start..start + self.hidden_dim)
    }
}
99
100/// Label encoder trait for encoding entity type descriptions.
101///
102/// # Motivation
103///
104/// Zero-shot NER works by encoding entity type *descriptions* into the same
105/// vector space as text spans. Instead of training separate classifiers for
106/// each entity type, we compute similarity between spans and label embeddings.
107///
108/// This enables:
109/// - **Unlimited entity types** at inference (no retraining needed)
110/// - **Faster inference** when labels are pre-computed
111/// - **Better generalization** to unseen entity types via semantic similarity
112///
113/// # Research Alignment
114///
115/// From GLiNER bi-encoder (knowledgator/modern-gliner-bi-base-v1.0):
116/// > "textual encoder is ModernBERT-base and entity label encoder is
117/// > sentence transformer - BGE-small-en."
118///
119/// # Example
120///
121/// ```ignore
122/// use anno::LabelEncoder;
123///
124/// fn setup_custom_types(encoder: &dyn LabelEncoder) {
125///     // Encode rich descriptions for better matching
126///     let labels = &[
127///         "a named individual human being",
128///         "a company, institution, or organized group",
129///         "a geographical location, city, country, or region",
130///     ];
131///
132///     let embeddings = encoder.encode_labels(labels).unwrap();
133///     // Store embeddings in SemanticRegistry for fast lookup
134/// }
135/// ```
136pub trait LabelEncoder: Send + Sync {
137    /// Encode a single label description.
138    ///
139    /// # Arguments
140    /// * `label` - Label description (e.g., "a named individual human being")
141    fn encode_label(&self, label: &str) -> crate::Result<Vec<f32>>;
142
143    /// Encode multiple labels.
144    ///
145    /// # Arguments
146    /// * `labels` - Label descriptions
147    ///
148    /// # Returns
149    /// Flattened embeddings: [num_labels, hidden_dim]
150    fn encode_labels(&self, labels: &[&str]) -> crate::Result<Vec<f32>>;
151
152    /// Get the hidden dimension.
153    fn hidden_dim(&self) -> usize;
154}
155
/// Bi-encoder architecture combining text and label encoders.
///
/// # Motivation
///
/// The bi-encoder architecture treats NER as a **matching problem** rather than
/// a classification problem. It encodes text spans and entity labels separately,
/// then computes similarity scores to determine matches.
///
/// ```text
/// ┌─────────────────┐         ┌─────────────────┐
/// │   Text Input    │         │  Label Desc.    │
/// │ "Steve Jobs"    │         │ "person name"   │
/// └────────┬────────┘         └────────┬────────┘
///          │                           │
///          ▼                           ▼
/// ┌─────────────────┐         ┌─────────────────┐
/// │  TextEncoder    │         │  LabelEncoder   │
/// │  (ModernBERT)   │         │  (BGE-small)    │
/// └────────┬────────┘         └────────┬────────┘
///          │                           │
///          ▼                           ▼
/// ┌─────────────────┐         ┌─────────────────┐
/// │ Span Embedding  │◄───────►│ Label Embedding │
/// │   [768]         │ cosine  │   [768]         │
/// └─────────────────┘ sim     └─────────────────┘
///                      │
///                      ▼
///               Score: 0.92
/// ```
///
/// # Trade-offs
///
/// | Aspect | Bi-Encoder | Uni-Encoder |
/// |--------|------------|-------------|
/// | Entity types | Unlimited | Fixed at training |
/// | Inference speed | Faster (pre-compute labels) | Slower |
/// | Disambiguation | Harder (no label interaction) | Better |
/// | Generalization | Better to new types | Limited |
///
/// # Research Alignment
///
/// From GLiNER: "GLiNER frames NER as a matching problem, comparing candidate
/// spans with entity type embeddings."
///
/// From knowledgator: "Bi-encoder architecture brings several advantages...
/// unlimited entities, faster inference, better generalization."
///
/// Drawback: "Lack of inter-label interactions that make it hard to
/// disambiguate semantically similar but contextually different entities."
///
/// # Example
///
/// ```ignore
/// use anno::BiEncoder;
///
/// fn extract_custom_entities(bi_enc: &dyn BiEncoder, text: &str) {
///     let labels = &["software company", "hardware manufacturer", "person"];
///     let scores = bi_enc.encode_and_match(text, labels, 8).unwrap();
///
///     for s in scores.iter().filter(|s| s.score > 0.5) {
///         println!("Found '{}' as type {} (score: {:.2})",
///                  &text[s.start..s.end], labels[s.label_idx], s.score);
///     }
/// }
/// ```
pub trait BiEncoder: Send + Sync {
    /// Get the text encoder.
    fn text_encoder(&self) -> &dyn TextEncoder;

    /// Get the label encoder.
    fn label_encoder(&self) -> &dyn LabelEncoder;

    /// Encode text and labels, compute span-label similarities.
    ///
    /// # Arguments
    /// * `text` - Input text
    /// * `labels` - Entity type descriptions
    /// * `max_span_width` - Maximum span width to consider
    ///
    /// # Returns
    /// Similarity scores for each (span, label) pair. Each result's
    /// `label_idx` indexes into the `labels` argument, and `start`/`end`
    /// are character offsets into `text` (see [`SpanLabelScore`]).
    ///
    /// # Errors
    /// Returns an error when text or label encoding fails
    /// (implementation-specific; see the concrete bi-encoder).
    fn encode_and_match(
        &self,
        text: &str,
        labels: &[&str],
        max_span_width: usize,
    ) -> crate::Result<Vec<SpanLabelScore>>;
}
244
/// Score for a (span, label) match.
///
/// All fields are plain scalar values, so the type is `Copy` and cheap
/// to pass by value in scoring loops.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct SpanLabelScore {
    /// Span start (character offset)
    pub start: usize,
    /// Span end (character offset, exclusive)
    pub end: usize,
    /// Label index
    pub label_idx: usize,
    /// Similarity score (0.0 - 1.0)
    pub score: f32,
}

impl SpanLabelScore {
    /// Width of the span in characters (`end - start`).
    ///
    /// Uses `saturating_sub` so a malformed span (`end < start`) yields
    /// 0 instead of panicking on unsigned underflow.
    #[must_use]
    pub fn width(&self) -> usize {
        self.end.saturating_sub(self.start)
    }
}
257
258// =============================================================================