//! `ctxgraph_extract` — pipeline.rs
use std::path::{Path, PathBuf};

use chrono::{DateTime, Utc};

use crate::ner::{ExtractedEntity, NerEngine, NerError};
use crate::rel::{ExtractedRelation, RelEngine, RelError};
use crate::schema::{ExtractionSchema, SchemaError};
use crate::temporal::{self, TemporalResult};
/// Complete result of running the extraction pipeline on a piece of text.
#[derive(Debug, Clone)]
pub struct ExtractionResult {
    /// Entities that survived the pipeline's confidence filter.
    pub entities: Vec<ExtractedEntity>,
    /// Relations between extracted entities, also confidence-filtered.
    pub relations: Vec<ExtractedRelation>,
    /// Temporal expressions parsed relative to the caller-supplied reference time.
    pub temporal: Vec<TemporalResult>,
}
17
/// The extraction pipeline orchestrates NER, relation extraction, and temporal parsing.
///
/// Created once and reused across multiple episodes. Model loading happens at construction
/// time (~100-500ms), but subsequent inference calls are fast (<15ms).
pub struct ExtractionPipeline {
    // Entity/relation type definitions; its entity labels are fed to the NER engine.
    schema: ExtractionSchema,
    // Span-based GLiNER v2.1 NER engine (required at construction).
    ner: NerEngine,
    // Multitask GLiNER relation engine (its model files are optional on disk).
    rel: RelEngine,
    // Minimum confidence for keeping extracted entities and relations.
    confidence_threshold: f64,
}
28
29impl ExtractionPipeline {
30    /// Create a new extraction pipeline.
31    ///
32    /// - `schema`: Entity/relation type definitions.
33    /// - `models_dir`: Directory containing ONNX model files.
34    /// - `confidence_threshold`: Minimum confidence to keep an extraction (default: 0.5).
35    pub fn new(
36        schema: ExtractionSchema,
37        models_dir: &Path,
38        confidence_threshold: f64,
39    ) -> Result<Self, PipelineError> {
40        // Locate NER model files (span-based GLiNER v2.1)
41        let ner_model = find_ner_model(models_dir)?;
42        let ner_tokenizer = find_tokenizer(models_dir, "gliner")?;
43
44        // Use the pipeline's confidence_threshold as the GLiNER model-level threshold too.
45        // Parameters::default() hardcodes 0.5 which silently drops low-confidence spans
46        // before we ever see them; pass our threshold so callers control the cutoff.
47        let ner = NerEngine::new(&ner_model, &ner_tokenizer, confidence_threshold as f32)
48            .map_err(PipelineError::Ner)?;
49
50        // Locate relation model files (multitask GLiNER) — optional
51        let rel_model = find_rel_model(models_dir);
52        let rel_tokenizer = find_tokenizer(models_dir, "multitask").ok();
53
54        let rel = RelEngine::new(
55            rel_model.as_deref(),
56            rel_tokenizer.as_deref(),
57        )
58        .map_err(PipelineError::Rel)?;
59
60        Ok(Self {
61            schema,
62            ner,
63            rel,
64            confidence_threshold,
65        })
66    }
67
68    /// Create a pipeline with default settings.
69    ///
70    /// Uses `ExtractionSchema::default()` and 0.5 confidence threshold.
71    pub fn with_defaults(models_dir: &Path) -> Result<Self, PipelineError> {
72        Self::new(ExtractionSchema::default(), models_dir, 0.5)
73    }
74
75    /// Extract entities, relations, and temporal expressions from text.
76    pub fn extract(
77        &self,
78        text: &str,
79        reference_time: DateTime<Utc>,
80    ) -> Result<ExtractionResult, PipelineError> {
81        // Step 1: NER — extract entities.
82        // GLiNER v2.1 was trained on standard NER categories; passing schema key
83        // names directly ("Person", "Database", etc.) performs best. Using
84        // natural-language descriptions hurts performance because the combined
85        // label tokens compete with the input text for the 512-token budget.
86        let labels: Vec<&str> = self.schema.entity_labels();
87        let mut entities = self
88            .ner
89            .extract(text, &labels, None)
90            .map_err(PipelineError::Ner)?;
91
92        // Filter by confidence
93        entities.retain(|e| e.confidence >= self.confidence_threshold);
94
95        // Step 2: Relation extraction
96        let mut relations = self
97            .rel
98            .extract(text, &entities, &self.schema)
99            .map_err(PipelineError::Rel)?;
100
101        // Filter by confidence
102        relations.retain(|r| r.confidence >= self.confidence_threshold);
103
104        // Step 3: Temporal parsing
105        let temporal = temporal::parse_temporal(text, reference_time);
106
107        Ok(ExtractionResult {
108            entities,
109            relations,
110            temporal,
111        })
112    }
113
114    /// Get the schema used by this pipeline.
115    pub fn schema(&self) -> &ExtractionSchema {
116        &self.schema
117    }
118
119    /// Get the confidence threshold.
120    pub fn confidence_threshold(&self) -> f64 {
121        self.confidence_threshold
122    }
123}
124
125/// Find the NER ONNX model file in the models directory.
126///
127/// Looks for these files in order:
128/// 1. `gliner_large-v2.1/onnx/model_int8.onnx` (quantized, recommended)
129/// 2. `gliner_large-v2.1/onnx/model.onnx` (full precision)
130/// 3. `gliner2-large-q8.onnx` (legacy flat layout)
131fn find_ner_model(models_dir: &Path) -> Result<PathBuf, PipelineError> {
132    let candidates = [
133        models_dir.join("gliner_large-v2.1/onnx/model_int8.onnx"),
134        models_dir.join("gliner_large-v2.1/onnx/model.onnx"),
135        models_dir.join("gliner2-large-q8.onnx"),
136    ];
137
138    for c in &candidates {
139        if c.exists() {
140            return Ok(c.clone());
141        }
142    }
143
144    Err(PipelineError::ModelNotFound {
145        model: "GLiNER v2.1 NER".into(),
146        searched: candidates.iter().map(|p| p.display().to_string()).collect(),
147    })
148}
149
/// Find the relation extraction model (multitask GLiNER).
///
/// Returns `None` when no candidate file exists; relation extraction is
/// optional, so callers treat a missing model as a soft condition.
fn find_rel_model(models_dir: &Path) -> Option<PathBuf> {
    // Relative candidate paths, probed in priority order.
    let layouts = [
        "gliner-multitask-large-v0.5/onnx/model.onnx",
        "gliner-multitask-large.onnx",
        "glirel-large.onnx",
    ];

    for layout in layouts {
        let candidate = models_dir.join(layout);
        if candidate.exists() {
            return Some(candidate);
        }
    }
    None
}
160
161/// Find a tokenizer.json file associated with a model.
162fn find_tokenizer(models_dir: &Path, prefix: &str) -> Result<PathBuf, PipelineError> {
163    let candidates = if prefix == "gliner" {
164        vec![
165            models_dir.join("gliner_large-v2.1/tokenizer.json"),
166            models_dir.join("tokenizer.json"),
167        ]
168    } else {
169        vec![
170            models_dir.join("gliner-multitask-large-v0.5/tokenizer.json"),
171            models_dir.join("tokenizer.json"),
172        ]
173    };
174
175    for c in &candidates {
176        if c.exists() {
177            return Ok(c.clone());
178        }
179    }
180
181    Err(PipelineError::ModelNotFound {
182        model: format!("{prefix} tokenizer").into(),
183        searched: candidates.iter().map(|p| p.display().to_string()).collect(),
184    })
185}
186
/// Errors produced by the extraction pipeline and its model-discovery helpers.
#[derive(Debug, thiserror::Error)]
pub enum PipelineError {
    /// Entity extraction failed. `#[from]` lets `?` convert a `NerError` directly.
    #[error("NER error: {0}")]
    Ner(#[from] NerError),

    /// Relation extraction failed; likewise auto-converted via `#[from]`.
    #[error("relation extraction error: {0}")]
    Rel(#[from] RelError),

    /// Schema validation/loading failed; auto-converted via `#[from]`.
    #[error("schema error: {0}")]
    Schema(#[from] SchemaError),

    /// A required model or tokenizer file was not found on disk.
    #[error("model not found: {model}. Searched: {searched:?}")]
    ModelNotFound {
        /// Human-readable name of the missing artifact (e.g. "GLiNER v2.1 NER").
        model: String,
        /// Every filesystem path that was probed, for diagnostics.
        searched: Vec<String>,
    },
}