ctxgraph_extract/
pipeline.rs1use std::path::{Path, PathBuf};
2
3use chrono::{DateTime, Utc};
4
5use crate::coref::CorefResolver;
6use crate::ner::{ExtractedEntity, NerEngine, NerError};
7use crate::rel::{ExtractedRelation, RelEngine, RelError};
8use crate::schema::{ExtractionSchema, SchemaError};
9use crate::temporal::{self, TemporalResult};
10
11#[derive(Debug, Clone)]
13pub struct ExtractionResult {
14 pub entities: Vec<ExtractedEntity>,
15 pub relations: Vec<ExtractedRelation>,
16 pub temporal: Vec<TemporalResult>,
17}
18
19pub struct ExtractionPipeline {
24 schema: ExtractionSchema,
25 ner: NerEngine,
26 rel: RelEngine,
27 confidence_threshold: f64,
28}
29
30impl ExtractionPipeline {
31 pub fn new(
37 schema: ExtractionSchema,
38 models_dir: &Path,
39 confidence_threshold: f64,
40 ) -> Result<Self, PipelineError> {
41 let ner_model = find_ner_model(models_dir)?;
43 let ner_tokenizer = find_tokenizer(models_dir, "gliner")?;
44
45 let ner = NerEngine::new(&ner_model, &ner_tokenizer, confidence_threshold as f32)
49 .map_err(PipelineError::Ner)?;
50
51 let rel_model = find_rel_model(models_dir);
53 let rel_tokenizer = find_tokenizer(models_dir, "multitask").ok();
54
55 let rel = RelEngine::new(
56 rel_model.as_deref(),
57 rel_tokenizer.as_deref(),
58 )
59 .map_err(PipelineError::Rel)?;
60
61 Ok(Self {
62 schema,
63 ner,
64 rel,
65 confidence_threshold,
66 })
67 }
68
69 pub fn with_defaults(models_dir: &Path) -> Result<Self, PipelineError> {
73 Self::new(ExtractionSchema::default(), models_dir, 0.5)
74 }
75
76 pub fn extract(
78 &self,
79 text: &str,
80 reference_time: DateTime<Utc>,
81 ) -> Result<ExtractionResult, PipelineError> {
82 let labels: Vec<&str> = self.schema.entity_labels();
88 let mut entities = self
89 .ner
90 .extract(text, &labels, None)
91 .map_err(PipelineError::Ner)?;
92
93 entities.retain(|e| e.confidence >= self.confidence_threshold);
95
96 let coref_entities = CorefResolver::resolve(text, &entities);
98 entities.extend(coref_entities);
99
100 let mut relations = self
102 .rel
103 .extract(text, &entities, &self.schema)
104 .map_err(PipelineError::Rel)?;
105
106 relations.retain(|r| r.confidence >= self.confidence_threshold);
108
109 let temporal = temporal::parse_temporal(text, reference_time);
111
112 Ok(ExtractionResult {
113 entities,
114 relations,
115 temporal,
116 })
117 }
118
119 pub fn schema(&self) -> &ExtractionSchema {
121 &self.schema
122 }
123
124 pub fn confidence_threshold(&self) -> f64 {
126 self.confidence_threshold
127 }
128}
129
130fn find_ner_model(models_dir: &Path) -> Result<PathBuf, PipelineError> {
137 let candidates = [
138 models_dir.join("gliner_large-v2.1/onnx/model_int8.onnx"),
139 models_dir.join("gliner_large-v2.1/onnx/model.onnx"),
140 models_dir.join("gliner2-large-q8.onnx"),
141 ];
142
143 for c in &candidates {
144 if c.exists() {
145 return Ok(c.clone());
146 }
147 }
148
149 Err(PipelineError::ModelNotFound {
150 model: "GLiNER v2.1 NER".into(),
151 searched: candidates.iter().map(|p| p.display().to_string()).collect(),
152 })
153}
154
/// Locates an optional relation-extraction ONNX model under `models_dir`.
///
/// Candidates are tried in the listed order; the first one that is a regular
/// file (symlinks are followed) is returned, or `None` if there is none.
fn find_rel_model(models_dir: &Path) -> Option<PathBuf> {
    let candidates = [
        models_dir.join("gliner-multitask-large-v0.5/onnx/model.onnx"),
        models_dir.join("gliner-multitask-large.onnx"),
        models_dir.join("glirel-large.onnx"),
    ];

    // `is_file` rather than `exists`, so a directory at a candidate path does
    // not shadow a real model file further down the list.
    candidates.into_iter().find(|c| c.is_file())
}
165
166fn find_tokenizer(models_dir: &Path, prefix: &str) -> Result<PathBuf, PipelineError> {
168 let candidates = if prefix == "gliner" {
169 vec![
170 models_dir.join("gliner_large-v2.1/tokenizer.json"),
171 models_dir.join("tokenizer.json"),
172 ]
173 } else {
174 vec![
175 models_dir.join("gliner-multitask-large-v0.5/tokenizer.json"),
176 models_dir.join("tokenizer.json"),
177 ]
178 };
179
180 for c in &candidates {
181 if c.exists() {
182 return Ok(c.clone());
183 }
184 }
185
186 Err(PipelineError::ModelNotFound {
187 model: format!("{prefix} tokenizer").into(),
188 searched: candidates.iter().map(|p| p.display().to_string()).collect(),
189 })
190}
191
192#[derive(Debug, thiserror::Error)]
193pub enum PipelineError {
194 #[error("NER error: {0}")]
195 Ner(#[from] NerError),
196
197 #[error("relation extraction error: {0}")]
198 Rel(#[from] RelError),
199
200 #[error("schema error: {0}")]
201 Schema(#[from] SchemaError),
202
203 #[error("model not found: {model}. Searched: {searched:?}")]
204 ModelNotFound {
205 model: String,
206 searched: Vec<String>,
207 },
208}