Skip to main content

anno/backends/
encoder.rs

1//! Encoder abstraction for GLiNER and span matching models.
2//!
3//! # Design
4//!
5//! GLiNER separates the encoder (BERT/RoBERTa/ModernBERT) from the span matching head.
6//! This module provides abstractions for:
7//!
8//! 1. **Encoder**: Transforms text to embeddings
9//! 2. **SpanMatcher**: Takes embeddings + entity type embeddings and computes similarity
10//!
11//! # Available Encoders
12//!
13//! | Model | Context | Notes |
14//! |-------|---------|-------|
15//! | BERT | 512 | Classic, widely supported |
16//! | RoBERTa | 512 | Similar interface, different pretraining |
17//! | DeBERTa | 512 | Alternative attention formulation |
18//! | ModernBERT | 8192 | Long-context encoder |
19//!
20//! # GLiNER Models by Encoder
21//!
22//! | Model ID | Base Encoder | Mode |
23//! |----------|--------------|------|
24//! | `onnx-community/gliner_small-v2.1` | DeBERTa-v3-small | Span |
25//! | `onnx-community/gliner_medium-v2.1` | DeBERTa-v3-base | Span |
26//! | `onnx-community/gliner_large-v2.1` | DeBERTa-v3-large | Span |
27//! | `knowledgator/modern-gliner-bi-large-v1.0` | ModernBERT-large | Span |
28//! | `knowledgator/gliner-multitask-v1.0` | DeBERTa-v3-base | Token |
29
30use std::fmt;
31
/// Encoder (backbone) architectures that GLiNER checkpoints are known to be built on.
///
/// The enum is `#[non_exhaustive]`, so new architectures can be added without breaking
/// downstream `match` expressions (they must already carry a wildcard arm).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum EncoderType {
    /// Original BERT; 512-token context.
    Bert,
    /// RoBERTa; 512-token context.
    Roberta,
    /// DeBERTa; 512-token context, revised (disentangled) attention.
    Deberta,
    /// DeBERTa v3; 512-token context. Backbone of the `gliner_*-v2.1` checkpoints.
    DebertaV3,
    /// ModernBERT; 8192-token context.
    ModernBert,
    /// An encoder this module does not recognise.
    Unknown,
}
49
50impl EncoderType {
51    /// Maximum context length for this encoder.
52    #[must_use]
53    pub const fn max_context_length(&self) -> usize {
54        match self {
55            EncoderType::Bert => 512,
56            EncoderType::Roberta => 512,
57            EncoderType::Deberta => 512,
58            EncoderType::DebertaV3 => 512,
59            EncoderType::ModernBert => 8192,
60            EncoderType::Unknown => 512,
61        }
62    }
63
64    /// Whether this encoder uses RoPE (rotary position embeddings).
65    #[must_use]
66    pub const fn uses_rope(&self) -> bool {
67        matches!(self, EncoderType::ModernBert)
68    }
69
70    /// Relative speed (higher = faster).
71    ///
72    /// This is a coarse, internal-only heuristic. Do not treat it as a benchmark.
73    #[must_use]
74    pub const fn relative_speed(&self) -> u8 {
75        match self {
76            EncoderType::Bert => 5,
77            EncoderType::Roberta => 5,
78            EncoderType::Deberta => 4,
79            EncoderType::DebertaV3 => 4,
80            EncoderType::ModernBert => 6, // unpadding can reduce wasted compute
81            EncoderType::Unknown => 3,
82        }
83    }
84}
85
86impl fmt::Display for EncoderType {
87    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
88        match self {
89            EncoderType::Bert => write!(f, "BERT"),
90            EncoderType::Roberta => write!(f, "RoBERTa"),
91            EncoderType::Deberta => write!(f, "DeBERTa"),
92            EncoderType::DebertaV3 => write!(f, "DeBERTa-v3"),
93            EncoderType::ModernBert => write!(f, "ModernBERT"),
94            EncoderType::Unknown => write!(f, "Unknown"),
95        }
96    }
97}
98
99/// Known GLiNER model variants.
100#[derive(Debug, Clone, PartialEq, Eq)]
101pub struct GLiNERModel {
102    /// HuggingFace model ID.
103    pub model_id: &'static str,
104    /// Base encoder type.
105    pub encoder: EncoderType,
106    /// Model size (parameters).
107    pub size: ModelSize,
108    /// Whether this model supports relation extraction.
109    pub supports_relations: bool,
110    /// Notes about this model.
111    pub notes: &'static str,
112}
113
/// Coarse size class of a model, by parameter count.
///
/// Variants are declared from smallest to largest, so the derived [`Ord`] ranks them
/// `Small < Medium < Large < XLarge`. This allows size comparisons and sorting
/// without a hand-written ranking.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum ModelSize {
    /// Smallest size class.
    Small,
    /// Medium size class.
    Medium,
    /// Large size class.
    Large,
    /// Extra-large size class.
    XLarge,
}
126
127impl fmt::Display for ModelSize {
128    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
129        match self {
130            ModelSize::Small => write!(f, "S"),
131            ModelSize::Medium => write!(f, "M"),
132            ModelSize::Large => write!(f, "L"),
133            ModelSize::XLarge => write!(f, "XL"),
134        }
135    }
136}
137
138/// Catalog of known GLiNER models.
139pub static GLINER_MODELS: &[GLiNERModel] = &[
140    // DeBERTa-based (standard)
141    GLiNERModel {
142        model_id: "onnx-community/gliner_small-v2.1",
143        encoder: EncoderType::DebertaV3,
144        size: ModelSize::Small,
145        supports_relations: false,
146        notes: "Fast, good accuracy, recommended for CPU",
147    },
148    GLiNERModel {
149        model_id: "onnx-community/gliner_medium-v2.1",
150        encoder: EncoderType::DebertaV3,
151        size: ModelSize::Medium,
152        supports_relations: false,
153        notes: "Balanced speed/accuracy",
154    },
155    GLiNERModel {
156        model_id: "onnx-community/gliner_large-v2.1",
157        encoder: EncoderType::DebertaV3,
158        size: ModelSize::Large,
159        supports_relations: false,
160        notes: "Higher accuracy, recommended for GPU",
161    },
162    // ModernBERT-based (long-context)
163    GLiNERModel {
164        model_id: "knowledgator/modern-gliner-bi-large-v1.0",
165        encoder: EncoderType::ModernBert,
166        size: ModelSize::Large,
167        supports_relations: false,
168        notes: "Long-context encoder variant",
169    },
170    // Multitask (relations)
171    GLiNERModel {
172        model_id: "knowledgator/gliner-multitask-v1.0",
173        encoder: EncoderType::DebertaV3,
174        size: ModelSize::Medium,
175        supports_relations: true,
176        notes: "Supports relation extraction",
177    },
178    GLiNERModel {
179        model_id: "onnx-community/gliner-multitask-large-v0.5",
180        encoder: EncoderType::DebertaV3,
181        size: ModelSize::Large,
182        supports_relations: true,
183        notes: "Large multitask, higher accuracy relations",
184    },
185];
186
187impl GLiNERModel {
188    /// Find a model by ID.
189    #[must_use]
190    pub fn by_id(model_id: &str) -> Option<&'static GLiNERModel> {
191        GLINER_MODELS.iter().find(|m| m.model_id == model_id)
192    }
193
194    /// Get all models with a specific encoder.
195    #[must_use]
196    pub fn by_encoder(encoder: EncoderType) -> Vec<&'static GLiNERModel> {
197        GLINER_MODELS
198            .iter()
199            .filter(|m| m.encoder == encoder)
200            .collect()
201    }
202
203    /// Get models that support relations.
204    #[must_use]
205    pub fn with_relations() -> Vec<&'static GLiNERModel> {
206        GLINER_MODELS
207            .iter()
208            .filter(|m| m.supports_relations)
209            .collect()
210    }
211
212    /// Get the fastest model.
213    #[must_use]
214    pub fn fastest() -> &'static GLiNERModel {
215        &GLINER_MODELS[0] // Small is fastest
216    }
217
218    /// Get the most accurate model.
219    #[must_use]
220    pub fn most_accurate() -> &'static GLiNERModel {
221        // Prefer the ModernBERT variant when present; otherwise fall back to the largest DeBERTa v3.
222        GLINER_MODELS
223            .iter()
224            .find(|m| m.encoder == EncoderType::ModernBert)
225            .unwrap_or(&GLINER_MODELS[2])
226    }
227}
228
// Unit tests for the encoder metadata and the GLiNER model catalog.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_encoder_type_display() {
        assert_eq!(EncoderType::ModernBert.to_string(), "ModernBERT");
        assert_eq!(EncoderType::DebertaV3.to_string(), "DeBERTa-v3");
    }

    #[test]
    fn test_model_size_display() {
        assert_eq!(ModelSize::Small.to_string(), "S");
        assert_eq!(ModelSize::XLarge.to_string(), "XL");
    }

    #[test]
    fn test_model_lookup() {
        let model = GLiNERModel::by_id("onnx-community/gliner_small-v2.1")
            .expect("small v2.1 checkpoint should be in the catalog");
        assert_eq!(model.encoder, EncoderType::DebertaV3);
    }

    #[test]
    fn test_model_lookup_unknown_id() {
        assert!(GLiNERModel::by_id("no-such/model").is_none());
    }

    #[test]
    fn test_model_ids_are_unique() {
        let mut seen = std::collections::HashSet::new();
        for model in GLINER_MODELS {
            assert!(seen.insert(model.model_id), "duplicate model id {}", model.model_id);
        }
    }

    #[test]
    fn test_models_by_encoder() {
        let modern_models = GLiNERModel::by_encoder(EncoderType::ModernBert);
        assert!(!modern_models.is_empty());
        assert!(modern_models
            .iter()
            .all(|m| m.encoder == EncoderType::ModernBert));
    }

    #[test]
    fn test_models_with_relations() {
        let relation_models = GLiNERModel::with_relations();
        assert!(!relation_models.is_empty());
        assert!(relation_models.iter().all(|m| m.supports_relations));
    }

    #[test]
    fn test_fastest_model() {
        let fastest = GLiNERModel::fastest();
        assert_eq!(fastest.size, ModelSize::Small);
    }

    #[test]
    fn test_most_accurate() {
        let best = GLiNERModel::most_accurate();
        assert_eq!(best.encoder, EncoderType::ModernBert);
    }

    #[test]
    fn test_context_length() {
        assert_eq!(EncoderType::Bert.max_context_length(), 512);
        assert_eq!(EncoderType::DebertaV3.max_context_length(), 512);
        assert_eq!(EncoderType::ModernBert.max_context_length(), 8192);
    }

    #[test]
    fn test_uses_rope() {
        assert!(EncoderType::ModernBert.uses_rope());
        assert!(!EncoderType::DebertaV3.uses_rope());
    }
}