//! ctxgraph_extract — extraction pipeline (`pipeline.rs`).

1use std::path::{Path, PathBuf};
2
3use chrono::{DateTime, Utc};
4
5use crate::coref::CorefResolver;
6use crate::ner::{ExtractedEntity, NerEngine, NerError};
7use crate::rel::{ExtractedRelation, RelEngine, RelError};
8use crate::remap;
9use crate::schema::{ExtractionSchema, SchemaError};
10use crate::temporal::{self, TemporalResult};
11
/// Complete result of running the extraction pipeline on a piece of text.
#[derive(Debug, Clone)]
pub struct ExtractionResult {
    // Entities from NER plus coref/dictionary supplements, after confidence filtering.
    pub entities: Vec<ExtractedEntity>,
    // Relations between extracted entities, after confidence filtering.
    pub relations: Vec<ExtractedRelation>,
    // Temporal expressions parsed relative to the caller-supplied reference time.
    pub temporal: Vec<TemporalResult>,
}
19
/// The extraction pipeline orchestrates NER, relation extraction, and temporal parsing.
///
/// Created once and reused across multiple episodes. Model loading happens at construction
/// time (~100-500ms), but subsequent inference calls are fast (<15ms).
pub struct ExtractionPipeline {
    // Entity/relation type definitions; entity label names are fed to the NER engine.
    schema: ExtractionSchema,
    // Span-based GLiNER v2.1 NER engine.
    ner: NerEngine,
    // Relation-extraction engine; its model files are optional at construction time.
    rel: RelEngine,
    // Minimum confidence for keeping extracted entities and relations.
    confidence_threshold: f64,
}
30
31impl ExtractionPipeline {
32    /// Create a new extraction pipeline.
33    ///
34    /// - `schema`: Entity/relation type definitions.
35    /// - `models_dir`: Directory containing ONNX model files.
36    /// - `confidence_threshold`: Minimum confidence to keep an extraction (default: 0.5).
37    pub fn new(
38        schema: ExtractionSchema,
39        models_dir: &Path,
40        confidence_threshold: f64,
41    ) -> Result<Self, PipelineError> {
42        // Locate NER model files (span-based GLiNER v2.1)
43        let ner_model = find_ner_model(models_dir)?;
44        let ner_tokenizer = find_tokenizer(models_dir, "gliner")?;
45
46        // Use the pipeline's confidence_threshold as the GLiNER model-level threshold too.
47        // Parameters::default() hardcodes 0.5 which silently drops low-confidence spans
48        // before we ever see them; pass our threshold so callers control the cutoff.
49        let ner = NerEngine::new(&ner_model, &ner_tokenizer, confidence_threshold as f32)
50            .map_err(PipelineError::Ner)?;
51
52        // Locate relation model files (multitask GLiNER) — optional
53        let rel_model = find_rel_model(models_dir);
54        let rel_tokenizer = find_tokenizer(models_dir, "multitask").ok();
55
56        let rel = RelEngine::new(
57            rel_model.as_deref(),
58            rel_tokenizer.as_deref(),
59        )
60        .map_err(PipelineError::Rel)?;
61
62        Ok(Self {
63            schema,
64            ner,
65            rel,
66            confidence_threshold,
67        })
68    }
69
70    /// Create a pipeline with default settings.
71    ///
72    /// Uses `ExtractionSchema::default()` and 0.5 confidence threshold.
73    pub fn with_defaults(models_dir: &Path) -> Result<Self, PipelineError> {
74        Self::new(ExtractionSchema::default(), models_dir, 0.5)
75    }
76
77    /// Extract entities, relations, and temporal expressions from text.
78    pub fn extract(
79        &self,
80        text: &str,
81        reference_time: DateTime<Utc>,
82    ) -> Result<ExtractionResult, PipelineError> {
83        // Step 1: NER — extract entities.
84        // GLiNER v2.1 span mode works best with schema key names as labels
85        // ("Person", "Database", etc.) — these match the model's training vocabulary
86        // and produce reliable entity_type values without a description→key mapping.
87        // Natural-language descriptions were tested but hurt performance because
88        // the model misclassifies service/component names under "programming language"
89        // and similar too-generic prompts.
90        let labels: Vec<&str> = self.schema.entity_labels();
91        let mut entities = self
92            .ner
93            .extract(text, &labels, None)
94            .map_err(PipelineError::Ner)?;
95
96        // Filter by confidence
97        entities.retain(|e| e.confidence >= self.confidence_threshold);
98
99        // Step 1b: Coreference resolution — resolve pronouns to preceding entities
100        let coref_entities = CorefResolver::resolve(text, &entities);
101        entities.extend(coref_entities);
102
103        // Step 1c: Supplement entities — dictionary-based detection for known names
104        // that GLiNER missed, boosting recall from ~0.59 toward ~0.75+
105        remap::supplement_entities(text, &mut entities);
106
107        // Step 1d: Entity type remapping — fix common GLiNER misclassifications
108        // using domain knowledge lookup tables (Database, Infrastructure, Pattern, etc.)
109        remap::remap_entity_types(&mut entities);
110
111        // Step 1e: LLM entity cleanup — DISABLED.
112        // Tested both local Ollama (qwen2.5:3b) and GPT-4.1-mini for cleaning
113        // up GLiNER entity names. Both hurt full benchmark F1:
114        //   - Ollama: entity 0.845→0.787, combined 0.678→0.630
115        //   - GPT:    entity 0.845→0.811, combined 0.700→0.659
116        // Root cause: most entities are already correct; cleanup removes/renames
117        // valid multi-word entities like "saga pattern", "Cloudflare Workers".
118        // The approach helps ~5 hard cases but hurts ~15 others.
119        // Future: try full LLM entity extraction (Graphiti-style) instead of
120        // cleanup, or only apply cleanup to entities with low confidence.
121
122        // Step 2: Relation extraction
123        let mut relations = self
124            .rel
125            .extract(text, &entities, &self.schema)
126            .map_err(PipelineError::Rel)?;
127
128        // Filter by confidence
129        relations.retain(|r| r.confidence >= self.confidence_threshold);
130
131        // Step 3: Temporal parsing
132        let temporal = temporal::parse_temporal(text, reference_time);
133
134        Ok(ExtractionResult {
135            entities,
136            relations,
137            temporal,
138        })
139    }
140
141    /// Get the schema used by this pipeline.
142    pub fn schema(&self) -> &ExtractionSchema {
143        &self.schema
144    }
145
146    /// Get the confidence threshold.
147    pub fn confidence_threshold(&self) -> f64 {
148        self.confidence_threshold
149    }
150}
151
152/// Find the NER ONNX model file in the models directory.
153///
154/// Looks for these files in order:
155/// 1. `gliner_large-v2.1/onnx/model_int8.onnx` (quantized, recommended)
156/// 2. `gliner_large-v2.1/onnx/model.onnx` (full precision)
157/// 3. `gliner2-large-q8.onnx` (legacy flat layout)
158fn find_ner_model(models_dir: &Path) -> Result<PathBuf, PipelineError> {
159    let candidates = [
160        models_dir.join("gliner_large-v2.1/onnx/model_int8.onnx"),
161        models_dir.join("gliner_large-v2.1/onnx/model.onnx"),
162        models_dir.join("gliner2-large-q8.onnx"),
163    ];
164
165    for c in &candidates {
166        if c.exists() {
167            return Ok(c.clone());
168        }
169    }
170
171    Err(PipelineError::ModelNotFound {
172        model: "GLiNER v2.1 NER".into(),
173        searched: candidates.iter().map(|p| p.display().to_string()).collect(),
174    })
175}
176
/// Find the relation extraction model (token-level multitask GLiNER).
///
/// NOTE: gline-rs RelationPipeline requires a **token-level** model (4 inputs:
/// input_ids, attention_mask, words_mask, text_lengths). Span-level models like
/// `gliner_multi-v2.1` are NOT compatible and must not be listed here.
///
/// Compatible model: `knowledgator/gliner-multitask-large-v0.5` (token_level mode).
/// Pre-converted ONNX available from `onnx-community/gliner-multitask-large-v0.5`.
fn find_rel_model(models_dir: &Path) -> Option<PathBuf> {
    // Checked in priority order; the first path present on disk wins.
    let search_paths = [
        // INT8 quantized (from onnx-community, downloaded by ModelManager)
        models_dir.join("gliner-multitask-large-v0.5/onnx/model_int8.onnx"),
        // Full precision (from manual conversion via scripts/convert_model.py)
        models_dir.join("gliner-multitask-large-v0.5/onnx/model.onnx"),
        // Legacy flat layout
        models_dir.join("gliner-multitask-large.onnx"),
    ];

    for path in search_paths {
        if path.exists() {
            return Some(path);
        }
    }
    None
}
197
198/// Find a tokenizer.json file associated with a model.
199fn find_tokenizer(models_dir: &Path, prefix: &str) -> Result<PathBuf, PipelineError> {
200    let candidates = if prefix == "gliner" {
201        vec![
202            models_dir.join("gliner_large-v2.1/tokenizer.json"),
203            models_dir.join("tokenizer.json"),
204        ]
205    } else if prefix == "multitask" {
206        vec![
207            models_dir.join("gliner-multitask-large-v0.5/tokenizer.json"),
208            models_dir.join("tokenizer.json"),
209        ]
210    } else {
211        vec![models_dir.join("tokenizer.json")]
212    };
213
214    for c in &candidates {
215        if c.exists() {
216            return Ok(c.clone());
217        }
218    }
219
220    Err(PipelineError::ModelNotFound {
221        model: format!("{prefix} tokenizer").into(),
222        searched: candidates.iter().map(|p| p.display().to_string()).collect(),
223    })
224}
225
/// Errors produced while constructing or running the extraction pipeline.
#[derive(Debug, thiserror::Error)]
pub enum PipelineError {
    /// Entity extraction failed; wraps the NER engine's error via `#[from]`.
    #[error("NER error: {0}")]
    Ner(#[from] NerError),

    /// Relation extraction failed; wraps the relation engine's error via `#[from]`.
    #[error("relation extraction error: {0}")]
    Rel(#[from] RelError),

    /// Schema validation/definition error.
    #[error("schema error: {0}")]
    Schema(#[from] SchemaError),

    /// A required model or tokenizer file was not found on disk.
    /// `searched` lists every candidate path that was checked.
    #[error("model not found: {model}. Searched: {searched:?}")]
    ModelNotFound {
        model: String,
        searched: Vec<String>,
    },
}