Skip to main content

anno/backends/gliner2/
schema.rs

1#![allow(unused_imports)]
2//! GLiNER2 shared schema types: task definition, extraction results, caches.
3//!
4//! These are feature-agnostic — imported by both the ONNX and Candle backends.
5
6use crate::{Entity, EntityType, Error, Result};
7use anno_core::EntityCategory;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10#[cfg(feature = "candle")]
11use std::sync::RwLock;
12
13use crate::backends::inference::{ExtractionWithRelations, RelationExtractor, ZeroShotNER};
14
15// =============================================================================
16// Special Token IDs (gliner-multitask-large-v0.5 vocabulary)
17// Valid tokens: [MASK]=128000, [FLERT]=128001, <<ENT>>=128002, <<SEP>>=128003
18// Note: [P], [C], [L] markers don't exist in this model - DO NOT USE 128004+
19// =============================================================================
20
21/// <<ENT>> token - entity type marker (class_token_index in config)
22#[cfg(feature = "onnx")]
23pub(super) const TOKEN_ENT: u32 = 128002;
24/// <<SEP>> separator token
25#[cfg(feature = "onnx")]
26pub(super) const TOKEN_SEP: u32 = 128003;
27/// Start token [CLS]
28#[cfg(feature = "onnx")]
29pub(super) const TOKEN_START: u32 = 1;
30/// End token [SEP]
31#[cfg(feature = "onnx")]
32pub(super) const TOKEN_END: u32 = 2;
33
34/// Default max span width
35pub(super) const MAX_SPAN_WIDTH: usize = 12;
36/// Max count for structure instances (0-19)
37#[cfg(feature = "candle")]
38pub(super) const MAX_COUNT: usize = 20;
39
40// =============================================================================
41// Label Embedding Cache
42// =============================================================================
43
/// Cache for label embeddings, keyed by label text, to avoid recomputing them.
///
/// The backing map only exists with the `candle` feature. Without it the type is a
/// placeholder holding only a `PhantomData`, so code can name `LabelCache`
/// regardless of features.
#[derive(Debug, Default)]
pub struct LabelCache {
    /// Label text -> embedding vector.
    #[cfg(feature = "candle")]
    cache: RwLock<HashMap<String, Vec<f32>>>,
    /// Placeholder field so the struct has a body when `candle` is disabled.
    #[cfg(not(feature = "candle"))]
    _phantom: std::marker::PhantomData<()>,
}

#[cfg(feature = "candle")]
impl LabelCache {
    /// Create an empty cache.
    pub(super) fn new() -> Self {
        Self::default()
    }

    /// Return a copy of the cached embedding for `label`, or `None` on a miss.
    ///
    /// A poisoned lock is recovered instead of being treated as a miss. The only
    /// write path is the single `HashMap::insert` in [`LabelCache::insert`], so a
    /// panic elsewhere while a guard was held does not make the map unusable. The
    /// previous behaviour silently disabled the cache forever after one panic.
    pub(super) fn get(&self, label: &str) -> Option<Vec<f32>> {
        self.cache
            .read()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .get(label)
            .cloned()
    }

    /// Store `embedding` under `label`, replacing any previous entry.
    ///
    /// Like [`LabelCache::get`], this recovers a poisoned lock rather than silently
    /// dropping the write.
    pub(super) fn insert(&self, label: String, embedding: Vec<f32>) {
        self.cache
            .write()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .insert(label, embedding);
    }
}

#[cfg(not(feature = "candle"))]
impl LabelCache {
    /// Create the (stateless) placeholder cache.
    // Kept so both feature configurations expose a `new`; unused when candle is off.
    #[allow(dead_code)]
    fn new() -> Self {
        Self {
            _phantom: std::marker::PhantomData,
        }
    }
}
81
82// =============================================================================
83// Task Schema
84// =============================================================================
85
86/// Schema defining what to extract.
87///
88/// Use builder methods to construct complex schemas:
89///
90/// ```rust,ignore
91/// let schema = TaskSchema::new()
92///     .with_entities(&["person", "organization"])
93///     .with_classification("sentiment", &["positive", "negative"], false)
94///     .with_structure(
95///         StructureTask::new("product")
96///             .with_field("name", FieldType::String)
97///             .with_field("price", FieldType::String)
98///     );
99/// ```
100#[derive(Debug, Clone, Default, Serialize, Deserialize)]
101pub struct TaskSchema {
102    /// Entity types to extract
103    pub entities: Option<EntityTask>,
104    /// Classification tasks
105    pub classifications: Vec<ClassificationTask>,
106    /// Structure extraction tasks
107    pub structures: Vec<StructureTask>,
108}
109
110impl TaskSchema {
111    /// Create empty schema.
112    pub fn new() -> Self {
113        Self::default()
114    }
115
116    /// Add entity types to extract.
117    pub fn with_entities(mut self, types: &[&str]) -> Self {
118        self.entities = Some(EntityTask {
119            types: types.iter().map(|s| s.to_string()).collect(),
120            descriptions: HashMap::new(),
121        });
122        self
123    }
124
125    /// Add entity types with descriptions for better zero-shot.
126    pub fn with_entities_described(mut self, types_with_desc: HashMap<String, String>) -> Self {
127        let types: Vec<String> = types_with_desc.keys().cloned().collect();
128        self.entities = Some(EntityTask {
129            types,
130            descriptions: types_with_desc,
131        });
132        self
133    }
134
135    /// Add a classification task.
136    pub fn with_classification(mut self, name: &str, labels: &[&str], multi_label: bool) -> Self {
137        self.classifications.push(ClassificationTask {
138            name: name.to_string(),
139            labels: labels.iter().map(|s| s.to_string()).collect(),
140            multi_label,
141            descriptions: HashMap::new(),
142        });
143        self
144    }
145
146    /// Add a structure extraction task.
147    pub fn with_structure(mut self, task: StructureTask) -> Self {
148        self.structures.push(task);
149        self
150    }
151}
152
153/// Entity extraction task configuration.
154#[derive(Debug, Clone, Default, Serialize, Deserialize)]
155pub struct EntityTask {
156    /// Entity type labels
157    pub types: Vec<String>,
158    /// Optional descriptions for each type
159    pub descriptions: HashMap<String, String>,
160}
161
162/// Classification task configuration.
163#[derive(Debug, Clone, Default, Serialize, Deserialize)]
164pub struct ClassificationTask {
165    /// Task name (e.g., "sentiment")
166    pub name: String,
167    /// Class labels
168    pub labels: Vec<String>,
169    /// Whether multiple labels can be selected
170    pub multi_label: bool,
171    /// Optional descriptions for labels
172    pub descriptions: HashMap<String, String>,
173}
174
175/// Hierarchical structure extraction task.
176#[derive(Debug, Clone, Default, Serialize, Deserialize)]
177pub struct StructureTask {
178    /// Structure type name (parent entity)
179    pub name: String,
180    /// Internal alias for compatibility
181    #[serde(skip)]
182    pub structure_type: String,
183    /// Child fields to extract
184    pub fields: Vec<StructureField>,
185}
186
187impl StructureTask {
188    /// Create new structure task.
189    pub fn new(name: &str) -> Self {
190        Self {
191            name: name.to_string(),
192            structure_type: name.to_string(),
193            fields: Vec::new(),
194        }
195    }
196
197    /// Add a field to extract.
198    pub fn with_field(mut self, name: &str, field_type: FieldType) -> Self {
199        self.fields.push(StructureField {
200            name: name.to_string(),
201            field_type,
202            description: None,
203            choices: None,
204        });
205        self
206    }
207
208    /// Add a field with description.
209    pub fn with_field_described(
210        mut self,
211        name: &str,
212        field_type: FieldType,
213        description: &str,
214    ) -> Self {
215        self.fields.push(StructureField {
216            name: name.to_string(),
217            field_type,
218            description: Some(description.to_string()),
219            choices: None,
220        });
221        self
222    }
223
224    /// Add a choice field with constrained options.
225    pub fn with_choice_field(mut self, name: &str, choices: &[&str]) -> Self {
226        self.fields.push(StructureField {
227            name: name.to_string(),
228            field_type: FieldType::Choice,
229            description: None,
230            choices: Some(choices.iter().map(|s| s.to_string()).collect()),
231        });
232        self
233    }
234}
235
236/// Structure field configuration.
237#[derive(Debug, Clone, Serialize, Deserialize)]
238pub struct StructureField {
239    /// Field name
240    pub name: String,
241    /// Field type
242    pub field_type: FieldType,
243    /// Optional description
244    pub description: Option<String>,
245    /// For Choice type: allowed values
246    pub choices: Option<Vec<String>>,
247}
248
249/// Field type for structure extraction.
250#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
251pub enum FieldType {
252    /// Single string value
253    String,
254    /// List of values
255    List,
256    /// Choice from constrained options
257    Choice,
258}
259
260// =============================================================================
261// Extraction Results
262// =============================================================================
263
264/// Combined extraction result.
265#[derive(Debug, Clone, Default, Serialize, Deserialize)]
266pub struct ExtractionResult {
267    /// Extracted entities
268    pub entities: Vec<Entity>,
269    /// Classification results by task name
270    pub classifications: HashMap<String, ClassificationResult>,
271    /// Extracted structures
272    pub structures: Vec<ExtractedStructure>,
273}
274
275/// Classification result.
276#[derive(Debug, Clone, Default, Serialize, Deserialize)]
277pub struct ClassificationResult {
278    /// Selected label(s)
279    pub labels: Vec<String>,
280    /// Score for each label
281    pub scores: HashMap<String, f32>,
282}
283
284/// Extracted structure instance.
285#[derive(Debug, Clone, Default, Serialize, Deserialize)]
286pub struct ExtractedStructure {
287    /// Structure type
288    pub structure_type: String,
289    /// Extracted field values
290    pub fields: HashMap<String, StructureValue>,
291}
292
293/// Value for a structure field.
294#[derive(Debug, Clone, Serialize, Deserialize)]
295pub enum StructureValue {
296    /// Single value
297    Single(String),
298    /// List of values
299    List(Vec<String>),
300}
301
302// =============================================================================