// ctxgraph_extract/pipeline.rs
1use std::path::{Path, PathBuf};
2
3use chrono::{DateTime, Utc};
4
5use crate::coref::CorefResolver;
6use crate::ner::{ExtractedEntity, NerEngine, NerError};
7use crate::rel::{ExtractedRelation, RelEngine, RelError};
8use crate::remap;
9use crate::schema::{ExtractionSchema, SchemaError};
10use crate::temporal::{self, TemporalResult};
11
12/// Complete result of running the extraction pipeline on a piece of text.
13#[derive(Debug, Clone)]
14pub struct ExtractionResult {
15 pub entities: Vec<ExtractedEntity>,
16 pub relations: Vec<ExtractedRelation>,
17 pub temporal: Vec<TemporalResult>,
18}
19
/// The extraction pipeline orchestrates NER, relation extraction, and temporal parsing.
///
/// Created once and reused across multiple episodes. Model loading happens at construction
/// time (~100-500ms), but subsequent inference calls are fast (<15ms).
pub struct ExtractionPipeline {
    // Entity/relation type definitions; supplies NER labels and the relation schema.
    schema: ExtractionSchema,
    // Span-based GLiNER v2.1 engine used for entity extraction.
    ner: NerEngine,
    // Multitask GLiNER engine used for relation extraction. Constructed from
    // optional model paths — presumably tolerates missing files; confirm in RelEngine.
    rel: RelEngine,
    // Minimum confidence for keeping an extracted entity or relation; also passed
    // to the NER engine as its model-level threshold (as f32).
    confidence_threshold: f64,
}
30
31impl ExtractionPipeline {
32 /// Create a new extraction pipeline.
33 ///
34 /// - `schema`: Entity/relation type definitions.
35 /// - `models_dir`: Directory containing ONNX model files.
36 /// - `confidence_threshold`: Minimum confidence to keep an extraction (default: 0.5).
37 pub fn new(
38 schema: ExtractionSchema,
39 models_dir: &Path,
40 confidence_threshold: f64,
41 ) -> Result<Self, PipelineError> {
42 // Locate NER model files (span-based GLiNER v2.1)
43 let ner_model = find_ner_model(models_dir)?;
44 let ner_tokenizer = find_tokenizer(models_dir, "gliner")?;
45
46 // Use the pipeline's confidence_threshold as the GLiNER model-level threshold too.
47 // Parameters::default() hardcodes 0.5 which silently drops low-confidence spans
48 // before we ever see them; pass our threshold so callers control the cutoff.
49 let ner = NerEngine::new(&ner_model, &ner_tokenizer, confidence_threshold as f32)
50 .map_err(PipelineError::Ner)?;
51
52 // Locate relation model files (multitask GLiNER) — optional
53 let rel_model = find_rel_model(models_dir);
54 let rel_tokenizer = find_tokenizer(models_dir, "multitask").ok();
55
56 let rel = RelEngine::new(
57 rel_model.as_deref(),
58 rel_tokenizer.as_deref(),
59 )
60 .map_err(PipelineError::Rel)?;
61
62 Ok(Self {
63 schema,
64 ner,
65 rel,
66 confidence_threshold,
67 })
68 }
69
70 /// Create a pipeline with default settings.
71 ///
72 /// Uses `ExtractionSchema::default()` and 0.5 confidence threshold.
73 pub fn with_defaults(models_dir: &Path) -> Result<Self, PipelineError> {
74 Self::new(ExtractionSchema::default(), models_dir, 0.5)
75 }
76
77 /// Extract entities, relations, and temporal expressions from text.
78 pub fn extract(
79 &self,
80 text: &str,
81 reference_time: DateTime<Utc>,
82 ) -> Result<ExtractionResult, PipelineError> {
83 // Step 1: NER — extract entities.
84 // GLiNER v2.1 span mode works best with schema key names as labels
85 // ("Person", "Database", etc.) — these match the model's training vocabulary
86 // and produce reliable entity_type values without a description→key mapping.
87 // Natural-language descriptions were tested but hurt performance because
88 // the model misclassifies service/component names under "programming language"
89 // and similar too-generic prompts.
90 let labels: Vec<&str> = self.schema.entity_labels();
91 let mut entities = self
92 .ner
93 .extract(text, &labels, None)
94 .map_err(PipelineError::Ner)?;
95
96 // Filter by confidence
97 entities.retain(|e| e.confidence >= self.confidence_threshold);
98
99 // Step 1b: Coreference resolution — resolve pronouns to preceding entities
100 let coref_entities = CorefResolver::resolve(text, &entities);
101 entities.extend(coref_entities);
102
103 // Step 1c: Supplement entities — dictionary-based detection for known names
104 // that GLiNER missed, boosting recall from ~0.59 toward ~0.75+
105 remap::supplement_entities(text, &mut entities);
106
107 // Step 1d: Entity type remapping — fix common GLiNER misclassifications
108 // using domain knowledge lookup tables (Database, Infrastructure, Pattern, etc.)
109 remap::remap_entity_types(&mut entities);
110
111 // Step 1e: LLM entity cleanup — DISABLED.
112 // Tested both local Ollama (qwen2.5:3b) and GPT-4.1-mini for cleaning
113 // up GLiNER entity names. Both hurt full benchmark F1:
114 // - Ollama: entity 0.845→0.787, combined 0.678→0.630
115 // - GPT: entity 0.845→0.811, combined 0.700→0.659
116 // Root cause: most entities are already correct; cleanup removes/renames
117 // valid multi-word entities like "saga pattern", "Cloudflare Workers".
118 // The approach helps ~5 hard cases but hurts ~15 others.
119 // Future: try full LLM entity extraction (Graphiti-style) instead of
120 // cleanup, or only apply cleanup to entities with low confidence.
121
122 // Step 2: Relation extraction
123 let mut relations = self
124 .rel
125 .extract(text, &entities, &self.schema)
126 .map_err(PipelineError::Rel)?;
127
128 // Filter by confidence
129 relations.retain(|r| r.confidence >= self.confidence_threshold);
130
131 // Step 3: Temporal parsing
132 let temporal = temporal::parse_temporal(text, reference_time);
133
134 Ok(ExtractionResult {
135 entities,
136 relations,
137 temporal,
138 })
139 }
140
141 /// Get the schema used by this pipeline.
142 pub fn schema(&self) -> &ExtractionSchema {
143 &self.schema
144 }
145
146 /// Get the confidence threshold.
147 pub fn confidence_threshold(&self) -> f64 {
148 self.confidence_threshold
149 }
150}
151
152/// Find the NER ONNX model file in the models directory.
153///
154/// Looks for these files in order:
155/// 1. `gliner_large-v2.1/onnx/model_int8.onnx` (quantized, recommended)
156/// 2. `gliner_large-v2.1/onnx/model.onnx` (full precision)
157/// 3. `gliner2-large-q8.onnx` (legacy flat layout)
158fn find_ner_model(models_dir: &Path) -> Result<PathBuf, PipelineError> {
159 let candidates = [
160 models_dir.join("gliner_large-v2.1/onnx/model_int8.onnx"),
161 models_dir.join("gliner_large-v2.1/onnx/model.onnx"),
162 models_dir.join("gliner2-large-q8.onnx"),
163 ];
164
165 for c in &candidates {
166 if c.exists() {
167 return Ok(c.clone());
168 }
169 }
170
171 Err(PipelineError::ModelNotFound {
172 model: "GLiNER v2.1 NER".into(),
173 searched: candidates.iter().map(|p| p.display().to_string()).collect(),
174 })
175}
176
/// Find the relation extraction model (token-level multitask GLiNER).
///
/// NOTE: gline-rs RelationPipeline requires a **token-level** model (4 inputs:
/// input_ids, attention_mask, words_mask, text_lengths). Span-level models like
/// `gliner_multi-v2.1` are NOT compatible and must not be listed here.
///
/// Compatible model: `knowledgator/gliner-multitask-large-v0.5` (token_level mode).
/// Pre-converted ONNX available from `onnx-community/gliner-multitask-large-v0.5`.
fn find_rel_model(models_dir: &Path) -> Option<PathBuf> {
    let candidates = [
        // INT8 quantized (from onnx-community, downloaded by ModelManager)
        models_dir.join("gliner-multitask-large-v0.5/onnx/model_int8.onnx"),
        // Full precision (from manual conversion via scripts/convert_model.py)
        models_dir.join("gliner-multitask-large-v0.5/onnx/model.onnx"),
        // Legacy flat layout
        models_dir.join("gliner-multitask-large.onnx"),
    ];

    // Return the first candidate present on disk, taking ownership of the path.
    for candidate in candidates {
        if candidate.exists() {
            return Some(candidate);
        }
    }
    None
}
197
198/// Find a tokenizer.json file associated with a model.
199fn find_tokenizer(models_dir: &Path, prefix: &str) -> Result<PathBuf, PipelineError> {
200 let candidates = if prefix == "gliner" {
201 vec![
202 models_dir.join("gliner_large-v2.1/tokenizer.json"),
203 models_dir.join("tokenizer.json"),
204 ]
205 } else if prefix == "multitask" {
206 vec![
207 models_dir.join("gliner-multitask-large-v0.5/tokenizer.json"),
208 models_dir.join("tokenizer.json"),
209 ]
210 } else {
211 vec![models_dir.join("tokenizer.json")]
212 };
213
214 for c in &candidates {
215 if c.exists() {
216 return Ok(c.clone());
217 }
218 }
219
220 Err(PipelineError::ModelNotFound {
221 model: format!("{prefix} tokenizer").into(),
222 searched: candidates.iter().map(|p| p.display().to_string()).collect(),
223 })
224}
225
/// Errors produced by pipeline construction ([`ExtractionPipeline::new`]) and
/// extraction ([`ExtractionPipeline::extract`]).
#[derive(Debug, thiserror::Error)]
pub enum PipelineError {
    /// NER engine failure (model loading or inference).
    #[error("NER error: {0}")]
    Ner(#[from] NerError),

    /// Relation-extraction engine failure (model loading or inference).
    #[error("relation extraction error: {0}")]
    Rel(#[from] RelError),

    /// Invalid or unusable schema definition.
    #[error("schema error: {0}")]
    Schema(#[from] SchemaError),

    /// A required model or tokenizer file was not found on disk.
    #[error("model not found: {model}. Searched: {searched:?}")]
    ModelNotFound {
        // Human-readable name of the missing artifact (e.g. "GLiNER v2.1 NER").
        model: String,
        // Every filesystem path that was probed, for diagnostics.
        searched: Vec<String>,
    },
}