1use std::fmt;
31
/// Encoder family that a GLiNER model is built on.
///
/// The enum is `#[non_exhaustive]`: code in other crates must keep a wildcard
/// arm when matching on it, so new families can be added without a breaking
/// change.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
#[non_exhaustive]
pub enum EncoderType {
    /// BERT encoder.
    Bert,
    /// RoBERTa encoder.
    Roberta,
    /// DeBERTa encoder.
    Deberta,
    /// DeBERTa-v3 encoder.
    DebertaV3,
    /// ModernBERT encoder.
    ModernBert,
    /// Encoder family that is not one of the known variants.
    Unknown,
}
49
50impl EncoderType {
51 #[must_use]
53 pub const fn max_context_length(&self) -> usize {
54 match self {
55 EncoderType::Bert => 512,
56 EncoderType::Roberta => 512,
57 EncoderType::Deberta => 512,
58 EncoderType::DebertaV3 => 512,
59 EncoderType::ModernBert => 8192,
60 EncoderType::Unknown => 512,
61 }
62 }
63
64 #[must_use]
66 pub const fn uses_rope(&self) -> bool {
67 matches!(self, EncoderType::ModernBert)
68 }
69
70 #[must_use]
74 pub const fn relative_speed(&self) -> u8 {
75 match self {
76 EncoderType::Bert => 5,
77 EncoderType::Roberta => 5,
78 EncoderType::Deberta => 4,
79 EncoderType::DebertaV3 => 4,
80 EncoderType::ModernBert => 6, EncoderType::Unknown => 3,
82 }
83 }
84}
85
86impl fmt::Display for EncoderType {
87 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
88 match self {
89 EncoderType::Bert => write!(f, "BERT"),
90 EncoderType::Roberta => write!(f, "RoBERTa"),
91 EncoderType::Deberta => write!(f, "DeBERTa"),
92 EncoderType::DebertaV3 => write!(f, "DeBERTa-v3"),
93 EncoderType::ModernBert => write!(f, "ModernBERT"),
94 EncoderType::Unknown => write!(f, "Unknown"),
95 }
96 }
97}
98
99#[derive(Debug, Clone, PartialEq, Eq)]
101pub struct GLiNERModel {
102 pub model_id: &'static str,
104 pub encoder: EncoderType,
106 pub size: ModelSize,
108 pub supports_relations: bool,
110 pub notes: &'static str,
112}
113
/// Coarse size class of a model.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum ModelSize {
    /// Small model.
    Small,
    /// Medium model.
    Medium,
    /// Large model.
    Large,
    /// Extra-large model.
    XLarge,
}
126
127impl fmt::Display for ModelSize {
128 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
129 match self {
130 ModelSize::Small => write!(f, "S"),
131 ModelSize::Medium => write!(f, "M"),
132 ModelSize::Large => write!(f, "L"),
133 ModelSize::XLarge => write!(f, "XL"),
134 }
135 }
136}
137
138pub static GLINER_MODELS: &[GLiNERModel] = &[
140 GLiNERModel {
142 model_id: "onnx-community/gliner_small-v2.1",
143 encoder: EncoderType::DebertaV3,
144 size: ModelSize::Small,
145 supports_relations: false,
146 notes: "Fast, good accuracy, recommended for CPU",
147 },
148 GLiNERModel {
149 model_id: "onnx-community/gliner_medium-v2.1",
150 encoder: EncoderType::DebertaV3,
151 size: ModelSize::Medium,
152 supports_relations: false,
153 notes: "Balanced speed/accuracy",
154 },
155 GLiNERModel {
156 model_id: "onnx-community/gliner_large-v2.1",
157 encoder: EncoderType::DebertaV3,
158 size: ModelSize::Large,
159 supports_relations: false,
160 notes: "Higher accuracy, recommended for GPU",
161 },
162 GLiNERModel {
164 model_id: "knowledgator/modern-gliner-bi-large-v1.0",
165 encoder: EncoderType::ModernBert,
166 size: ModelSize::Large,
167 supports_relations: false,
168 notes: "Long-context encoder variant",
169 },
170 GLiNERModel {
172 model_id: "knowledgator/gliner-multitask-v1.0",
173 encoder: EncoderType::DebertaV3,
174 size: ModelSize::Medium,
175 supports_relations: true,
176 notes: "Supports relation extraction",
177 },
178 GLiNERModel {
179 model_id: "onnx-community/gliner-multitask-large-v0.5",
180 encoder: EncoderType::DebertaV3,
181 size: ModelSize::Large,
182 supports_relations: true,
183 notes: "Large multitask, higher accuracy relations",
184 },
185];
186
187impl GLiNERModel {
188 #[must_use]
190 pub fn by_id(model_id: &str) -> Option<&'static GLiNERModel> {
191 GLINER_MODELS.iter().find(|m| m.model_id == model_id)
192 }
193
194 #[must_use]
196 pub fn by_encoder(encoder: EncoderType) -> Vec<&'static GLiNERModel> {
197 GLINER_MODELS
198 .iter()
199 .filter(|m| m.encoder == encoder)
200 .collect()
201 }
202
203 #[must_use]
205 pub fn with_relations() -> Vec<&'static GLiNERModel> {
206 GLINER_MODELS
207 .iter()
208 .filter(|m| m.supports_relations)
209 .collect()
210 }
211
212 #[must_use]
214 pub fn fastest() -> &'static GLiNERModel {
215 &GLINER_MODELS[0] }
217
218 #[must_use]
220 pub fn most_accurate() -> &'static GLiNERModel {
221 GLINER_MODELS
223 .iter()
224 .find(|m| m.encoder == EncoderType::ModernBert)
225 .unwrap_or(&GLINER_MODELS[2])
226 }
227}
228
#[cfg(test)]
mod tests {
    use super::*;

    /// `Display` renders the expected encoder names.
    #[test]
    fn test_encoder_type_display() {
        let cases = [
            (EncoderType::ModernBert, "ModernBERT"),
            (EncoderType::DebertaV3, "DeBERTa-v3"),
        ];
        for (encoder, expected) in cases {
            assert_eq!(encoder.to_string(), expected);
        }
    }

    /// A registered id resolves to its table entry.
    #[test]
    fn test_model_lookup() {
        let found = GLiNERModel::by_id("onnx-community/gliner_small-v2.1")
            .expect("the small v2.1 model should be registered");
        assert_eq!(found.encoder, EncoderType::DebertaV3);
    }

    /// Filtering by encoder returns a non-empty list of matching entries only.
    #[test]
    fn test_models_by_encoder() {
        let modern_models = GLiNERModel::by_encoder(EncoderType::ModernBert);
        assert!(!modern_models.is_empty());
        for model in &modern_models {
            assert_eq!(model.encoder, EncoderType::ModernBert);
        }
    }

    /// The fastest pick is a small model.
    #[test]
    fn test_fastest_model() {
        assert_eq!(GLiNERModel::fastest().size, ModelSize::Small);
    }

    /// The most accurate pick is ModernBERT-based.
    #[test]
    fn test_most_accurate() {
        assert_eq!(
            GLiNERModel::most_accurate().encoder,
            EncoderType::ModernBert
        );
    }

    /// Context lengths match the values defined per encoder.
    #[test]
    fn test_context_length() {
        let cases = [(EncoderType::Bert, 512), (EncoderType::ModernBert, 8192)];
        for (encoder, expected) in cases {
            assert_eq!(encoder.max_context_length(), expected);
        }
    }
}