1#![allow(unused_imports)]
2use crate::{Entity, EntityType, Error, Result};
7use anno_core::EntityCategory;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10#[cfg(feature = "candle")]
11use std::sync::RwLock;
12
13use crate::backends::inference::{ExtractionWithRelations, RelationExtractor, ZeroShotNER};
14
15#[cfg(feature = "onnx")]
23pub(super) const TOKEN_ENT: u32 = 128002;
24#[cfg(feature = "onnx")]
26pub(super) const TOKEN_SEP: u32 = 128003;
27#[cfg(feature = "onnx")]
29pub(super) const TOKEN_START: u32 = 1;
30#[cfg(feature = "onnx")]
32pub(super) const TOKEN_END: u32 = 2;
33
34pub(super) const MAX_SPAN_WIDTH: usize = 12;
36#[cfg(feature = "candle")]
38pub(super) const MAX_COUNT: usize = 20;
39
/// Cache of label embeddings keyed by label text.
///
/// With the `candle` feature enabled this wraps an `RwLock`-guarded map from label to
/// embedding vector, so lookups from several threads can proceed concurrently. Without
/// the feature the type holds only a zero-sized `PhantomData` marker.
#[derive(Debug, Default)]
pub struct LabelCache {
    /// Label text -> embedding vector.
    #[cfg(feature = "candle")]
    cache: RwLock<HashMap<String, Vec<f32>>>,
    /// Zero-sized placeholder used when the `candle` feature is disabled.
    #[cfg(not(feature = "candle"))]
    _phantom: std::marker::PhantomData<()>,
}

#[cfg(feature = "candle")]
impl LabelCache {
    /// Creates an empty cache.
    pub(super) fn new() -> Self {
        Self {
            cache: RwLock::new(HashMap::new()),
        }
    }

    /// Returns a copy of the cached embedding for `label`, or `None` if it is not cached.
    pub(super) fn get(&self, label: &str) -> Option<Vec<f32>> {
        // A poisoned lock only means some thread panicked while holding it; the map is
        // still memory-safe to read. Recover the guard rather than silently reporting a
        // miss forever after the first panic.
        let map = self
            .cache
            .read()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        map.get(label).cloned()
    }

    /// Stores `embedding` under `label`, replacing any previous entry.
    pub(super) fn insert(&self, label: String, embedding: Vec<f32>) {
        // Same poisoning policy as `get`: keep the cache usable instead of dropping writes.
        let mut map = self
            .cache
            .write()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        map.insert(label, embedding);
    }
}

#[cfg(not(feature = "candle"))]
impl LabelCache {
    /// Creates the placeholder cache used when the `candle` feature is disabled.
    // NOTE(review): presumably unused in non-candle builds, hence the dead_code allowance.
    #[allow(dead_code)]
    fn new() -> Self {
        Self {
            _phantom: std::marker::PhantomData,
        }
    }
}
81
82#[derive(Debug, Clone, Default, Serialize, Deserialize)]
101pub struct TaskSchema {
102 pub entities: Option<EntityTask>,
104 pub classifications: Vec<ClassificationTask>,
106 pub structures: Vec<StructureTask>,
108}
109
110impl TaskSchema {
111 pub fn new() -> Self {
113 Self::default()
114 }
115
116 pub fn with_entities(mut self, types: &[&str]) -> Self {
118 self.entities = Some(EntityTask {
119 types: types.iter().map(|s| s.to_string()).collect(),
120 descriptions: HashMap::new(),
121 });
122 self
123 }
124
125 pub fn with_entities_described(mut self, types_with_desc: HashMap<String, String>) -> Self {
127 let types: Vec<String> = types_with_desc.keys().cloned().collect();
128 self.entities = Some(EntityTask {
129 types,
130 descriptions: types_with_desc,
131 });
132 self
133 }
134
135 pub fn with_classification(mut self, name: &str, labels: &[&str], multi_label: bool) -> Self {
137 self.classifications.push(ClassificationTask {
138 name: name.to_string(),
139 labels: labels.iter().map(|s| s.to_string()).collect(),
140 multi_label,
141 descriptions: HashMap::new(),
142 });
143 self
144 }
145
146 pub fn with_structure(mut self, task: StructureTask) -> Self {
148 self.structures.push(task);
149 self
150 }
151}
152
153#[derive(Debug, Clone, Default, Serialize, Deserialize)]
155pub struct EntityTask {
156 pub types: Vec<String>,
158 pub descriptions: HashMap<String, String>,
160}
161
162#[derive(Debug, Clone, Default, Serialize, Deserialize)]
164pub struct ClassificationTask {
165 pub name: String,
167 pub labels: Vec<String>,
169 pub multi_label: bool,
171 pub descriptions: HashMap<String, String>,
173}
174
175#[derive(Debug, Clone, Default, Serialize, Deserialize)]
177pub struct StructureTask {
178 pub name: String,
180 #[serde(skip)]
182 pub structure_type: String,
183 pub fields: Vec<StructureField>,
185}
186
187impl StructureTask {
188 pub fn new(name: &str) -> Self {
190 Self {
191 name: name.to_string(),
192 structure_type: name.to_string(),
193 fields: Vec::new(),
194 }
195 }
196
197 pub fn with_field(mut self, name: &str, field_type: FieldType) -> Self {
199 self.fields.push(StructureField {
200 name: name.to_string(),
201 field_type,
202 description: None,
203 choices: None,
204 });
205 self
206 }
207
208 pub fn with_field_described(
210 mut self,
211 name: &str,
212 field_type: FieldType,
213 description: &str,
214 ) -> Self {
215 self.fields.push(StructureField {
216 name: name.to_string(),
217 field_type,
218 description: Some(description.to_string()),
219 choices: None,
220 });
221 self
222 }
223
224 pub fn with_choice_field(mut self, name: &str, choices: &[&str]) -> Self {
226 self.fields.push(StructureField {
227 name: name.to_string(),
228 field_type: FieldType::Choice,
229 description: None,
230 choices: Some(choices.iter().map(|s| s.to_string()).collect()),
231 });
232 self
233 }
234}
235
236#[derive(Debug, Clone, Serialize, Deserialize)]
238pub struct StructureField {
239 pub name: String,
241 pub field_type: FieldType,
243 pub description: Option<String>,
245 pub choices: Option<Vec<String>>,
247}
248
249#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
251pub enum FieldType {
252 String,
254 List,
256 Choice,
258}
259
260#[derive(Debug, Clone, Default, Serialize, Deserialize)]
266pub struct ExtractionResult {
267 pub entities: Vec<Entity>,
269 pub classifications: HashMap<String, ClassificationResult>,
271 pub structures: Vec<ExtractedStructure>,
273}
274
275#[derive(Debug, Clone, Default, Serialize, Deserialize)]
277pub struct ClassificationResult {
278 pub labels: Vec<String>,
280 pub scores: HashMap<String, f32>,
282}
283
284#[derive(Debug, Clone, Default, Serialize, Deserialize)]
286pub struct ExtractedStructure {
287 pub structure_type: String,
289 pub fields: HashMap<String, StructureValue>,
291}
292
293#[derive(Debug, Clone, Serialize, Deserialize)]
295pub enum StructureValue {
296 Single(String),
298 List(Vec<String>),
300}
301
302