//! `ctxgraph_extract::pipeline` — orchestrates NER, relation extraction,
//! coreference resolution, and temporal parsing into a single pipeline.
1use std::path::{Path, PathBuf};
2
3use chrono::{DateTime, Utc};
4
5use crate::coref::CorefResolver;
6use crate::ner::{ExtractedEntity, NerEngine, NerError};
7use crate::rel::{ExtractedRelation, RelEngine, RelError};
8use crate::schema::{ExtractionSchema, SchemaError};
9use crate::temporal::{self, TemporalResult};
10
/// Complete result of running the extraction pipeline on a piece of text.
#[derive(Debug, Clone)]
pub struct ExtractionResult {
    /// Entities found by NER (filtered to the pipeline's confidence threshold),
    /// plus coreference-resolved mentions appended afterwards.
    pub entities: Vec<ExtractedEntity>,
    /// Relations between the extracted entities, filtered to the pipeline's
    /// confidence threshold.
    pub relations: Vec<ExtractedRelation>,
    /// Temporal expressions parsed relative to the caller-supplied reference time.
    pub temporal: Vec<TemporalResult>,
}
18
/// The extraction pipeline orchestrates NER, relation extraction, and temporal parsing.
///
/// Created once and reused across multiple episodes. Model loading happens at construction
/// time (~100-500ms), but subsequent inference calls are fast (<15ms).
pub struct ExtractionPipeline {
    /// Entity/relation type definitions; supplies NER labels and drives
    /// relation extraction.
    schema: ExtractionSchema,
    /// Span-based GLiNER v2.1 NER engine.
    ner: NerEngine,
    /// Multitask GLiNER relation-extraction engine (its model files are optional
    /// at construction time).
    rel: RelEngine,
    /// Minimum confidence an extraction must meet to be kept. Also passed to the
    /// NER engine (as f32) as its model-level threshold.
    confidence_threshold: f64,
}
29
30impl ExtractionPipeline {
31    /// Create a new extraction pipeline.
32    ///
33    /// - `schema`: Entity/relation type definitions.
34    /// - `models_dir`: Directory containing ONNX model files.
35    /// - `confidence_threshold`: Minimum confidence to keep an extraction (default: 0.5).
36    pub fn new(
37        schema: ExtractionSchema,
38        models_dir: &Path,
39        confidence_threshold: f64,
40    ) -> Result<Self, PipelineError> {
41        // Locate NER model files (span-based GLiNER v2.1)
42        let ner_model = find_ner_model(models_dir)?;
43        let ner_tokenizer = find_tokenizer(models_dir, "gliner")?;
44
45        // Use the pipeline's confidence_threshold as the GLiNER model-level threshold too.
46        // Parameters::default() hardcodes 0.5 which silently drops low-confidence spans
47        // before we ever see them; pass our threshold so callers control the cutoff.
48        let ner = NerEngine::new(&ner_model, &ner_tokenizer, confidence_threshold as f32)
49            .map_err(PipelineError::Ner)?;
50
51        // Locate relation model files (multitask GLiNER) — optional
52        let rel_model = find_rel_model(models_dir);
53        let rel_tokenizer = find_tokenizer(models_dir, "multitask").ok();
54
55        let rel = RelEngine::new(
56            rel_model.as_deref(),
57            rel_tokenizer.as_deref(),
58        )
59        .map_err(PipelineError::Rel)?;
60
61        Ok(Self {
62            schema,
63            ner,
64            rel,
65            confidence_threshold,
66        })
67    }
68
69    /// Create a pipeline with default settings.
70    ///
71    /// Uses `ExtractionSchema::default()` and 0.5 confidence threshold.
72    pub fn with_defaults(models_dir: &Path) -> Result<Self, PipelineError> {
73        Self::new(ExtractionSchema::default(), models_dir, 0.5)
74    }
75
76    /// Extract entities, relations, and temporal expressions from text.
77    pub fn extract(
78        &self,
79        text: &str,
80        reference_time: DateTime<Utc>,
81    ) -> Result<ExtractionResult, PipelineError> {
82        // Step 1: NER — extract entities.
83        // GLiNER v2.1 was trained on standard NER categories; passing schema key
84        // names directly ("Person", "Database", etc.) performs best. Using
85        // natural-language descriptions hurts performance because the combined
86        // label tokens compete with the input text for the 512-token budget.
87        let labels: Vec<&str> = self.schema.entity_labels();
88        let mut entities = self
89            .ner
90            .extract(text, &labels, None)
91            .map_err(PipelineError::Ner)?;
92
93        // Filter by confidence
94        entities.retain(|e| e.confidence >= self.confidence_threshold);
95
96        // Step 1b: Coreference resolution — resolve pronouns to preceding entities
97        let coref_entities = CorefResolver::resolve(text, &entities);
98        entities.extend(coref_entities);
99
100        // Step 2: Relation extraction
101        let mut relations = self
102            .rel
103            .extract(text, &entities, &self.schema)
104            .map_err(PipelineError::Rel)?;
105
106        // Filter by confidence
107        relations.retain(|r| r.confidence >= self.confidence_threshold);
108
109        // Step 3: Temporal parsing
110        let temporal = temporal::parse_temporal(text, reference_time);
111
112        Ok(ExtractionResult {
113            entities,
114            relations,
115            temporal,
116        })
117    }
118
119    /// Get the schema used by this pipeline.
120    pub fn schema(&self) -> &ExtractionSchema {
121        &self.schema
122    }
123
124    /// Get the confidence threshold.
125    pub fn confidence_threshold(&self) -> f64 {
126        self.confidence_threshold
127    }
128}
129
130/// Find the NER ONNX model file in the models directory.
131///
132/// Looks for these files in order:
133/// 1. `gliner_large-v2.1/onnx/model_int8.onnx` (quantized, recommended)
134/// 2. `gliner_large-v2.1/onnx/model.onnx` (full precision)
135/// 3. `gliner2-large-q8.onnx` (legacy flat layout)
136fn find_ner_model(models_dir: &Path) -> Result<PathBuf, PipelineError> {
137    let candidates = [
138        models_dir.join("gliner_large-v2.1/onnx/model_int8.onnx"),
139        models_dir.join("gliner_large-v2.1/onnx/model.onnx"),
140        models_dir.join("gliner2-large-q8.onnx"),
141    ];
142
143    for c in &candidates {
144        if c.exists() {
145            return Ok(c.clone());
146        }
147    }
148
149    Err(PipelineError::ModelNotFound {
150        model: "GLiNER v2.1 NER".into(),
151        searched: candidates.iter().map(|p| p.display().to_string()).collect(),
152    })
153}
154
155/// Find the relation extraction model (multitask GLiNER).
156fn find_rel_model(models_dir: &Path) -> Option<PathBuf> {
157    let candidates = [
158        models_dir.join("gliner-multitask-large-v0.5/onnx/model.onnx"),
159        models_dir.join("gliner-multitask-large.onnx"),
160        models_dir.join("glirel-large.onnx"),
161    ];
162
163    candidates.into_iter().find(|c| c.exists())
164}
165
166/// Find a tokenizer.json file associated with a model.
167fn find_tokenizer(models_dir: &Path, prefix: &str) -> Result<PathBuf, PipelineError> {
168    let candidates = if prefix == "gliner" {
169        vec![
170            models_dir.join("gliner_large-v2.1/tokenizer.json"),
171            models_dir.join("tokenizer.json"),
172        ]
173    } else {
174        vec![
175            models_dir.join("gliner-multitask-large-v0.5/tokenizer.json"),
176            models_dir.join("tokenizer.json"),
177        ]
178    };
179
180    for c in &candidates {
181        if c.exists() {
182            return Ok(c.clone());
183        }
184    }
185
186    Err(PipelineError::ModelNotFound {
187        model: format!("{prefix} tokenizer").into(),
188        searched: candidates.iter().map(|p| p.display().to_string()).collect(),
189    })
190}
191
/// Errors produced while constructing or running an `ExtractionPipeline`.
#[derive(Debug, thiserror::Error)]
pub enum PipelineError {
    /// NER engine failed to load or run; converted automatically via `#[from]`.
    #[error("NER error: {0}")]
    Ner(#[from] NerError),

    /// Relation-extraction engine failed to load or run; converted via `#[from]`.
    #[error("relation extraction error: {0}")]
    Rel(#[from] RelError),

    /// Schema definition was invalid; converted via `#[from]`.
    #[error("schema error: {0}")]
    Schema(#[from] SchemaError),

    /// A required model or tokenizer file was not found on disk.
    #[error("model not found: {model}. Searched: {searched:?}")]
    ModelNotFound {
        // Human-readable name of the missing model (e.g. "GLiNER v2.1 NER").
        model: String,
        // Every filesystem path that was checked, for diagnostics.
        searched: Vec<String>,
    },
}