ctxgraph_extract/
pipeline.rs1use std::path::{Path, PathBuf};
2
3use chrono::{DateTime, Utc};
4
5use crate::ner::{ExtractedEntity, NerEngine, NerError};
6use crate::rel::{ExtractedRelation, RelEngine, RelError};
7use crate::schema::{ExtractionSchema, SchemaError};
8use crate::temporal::{self, TemporalResult};
9
10#[derive(Debug, Clone)]
12pub struct ExtractionResult {
13 pub entities: Vec<ExtractedEntity>,
14 pub relations: Vec<ExtractedRelation>,
15 pub temporal: Vec<TemporalResult>,
16}
17
18pub struct ExtractionPipeline {
23 schema: ExtractionSchema,
24 ner: NerEngine,
25 rel: RelEngine,
26 confidence_threshold: f64,
27}
28
29impl ExtractionPipeline {
30 pub fn new(
36 schema: ExtractionSchema,
37 models_dir: &Path,
38 confidence_threshold: f64,
39 ) -> Result<Self, PipelineError> {
40 let ner_model = find_ner_model(models_dir)?;
42 let ner_tokenizer = find_tokenizer(models_dir, "gliner")?;
43
44 let ner = NerEngine::new(&ner_model, &ner_tokenizer, confidence_threshold as f32)
48 .map_err(PipelineError::Ner)?;
49
50 let rel_model = find_rel_model(models_dir);
52 let rel_tokenizer = find_tokenizer(models_dir, "multitask").ok();
53
54 let rel = RelEngine::new(
55 rel_model.as_deref(),
56 rel_tokenizer.as_deref(),
57 )
58 .map_err(PipelineError::Rel)?;
59
60 Ok(Self {
61 schema,
62 ner,
63 rel,
64 confidence_threshold,
65 })
66 }
67
68 pub fn with_defaults(models_dir: &Path) -> Result<Self, PipelineError> {
72 Self::new(ExtractionSchema::default(), models_dir, 0.5)
73 }
74
75 pub fn extract(
77 &self,
78 text: &str,
79 reference_time: DateTime<Utc>,
80 ) -> Result<ExtractionResult, PipelineError> {
81 let labels: Vec<&str> = self.schema.entity_labels();
87 let mut entities = self
88 .ner
89 .extract(text, &labels, None)
90 .map_err(PipelineError::Ner)?;
91
92 entities.retain(|e| e.confidence >= self.confidence_threshold);
94
95 let mut relations = self
97 .rel
98 .extract(text, &entities, &self.schema)
99 .map_err(PipelineError::Rel)?;
100
101 relations.retain(|r| r.confidence >= self.confidence_threshold);
103
104 let temporal = temporal::parse_temporal(text, reference_time);
106
107 Ok(ExtractionResult {
108 entities,
109 relations,
110 temporal,
111 })
112 }
113
114 pub fn schema(&self) -> &ExtractionSchema {
116 &self.schema
117 }
118
119 pub fn confidence_threshold(&self) -> f64 {
121 self.confidence_threshold
122 }
123}
124
125fn find_ner_model(models_dir: &Path) -> Result<PathBuf, PipelineError> {
132 let candidates = [
133 models_dir.join("gliner_large-v2.1/onnx/model_int8.onnx"),
134 models_dir.join("gliner_large-v2.1/onnx/model.onnx"),
135 models_dir.join("gliner2-large-q8.onnx"),
136 ];
137
138 for c in &candidates {
139 if c.exists() {
140 return Ok(c.clone());
141 }
142 }
143
144 Err(PipelineError::ModelNotFound {
145 model: "GLiNER v2.1 NER".into(),
146 searched: candidates.iter().map(|p| p.display().to_string()).collect(),
147 })
148}
149
/// Locates an optional relation-extraction model file under `models_dir`.
///
/// Candidates are checked in this order, and the first that is an existing regular
/// file is returned:
/// `gliner-multitask-large-v0.5/onnx/model.onnx`, `gliner-multitask-large.onnx`,
/// `glirel-large.onnx`. Returns `None` if none exists.
fn find_rel_model(models_dir: &Path) -> Option<PathBuf> {
    let candidates = [
        models_dir.join("gliner-multitask-large-v0.5/onnx/model.onnx"),
        models_dir.join("gliner-multitask-large.onnx"),
        models_dir.join("glirel-large.onnx"),
    ];

    // `is_file` rather than `exists`: a same-named directory should fall through to
    // `None` instead of being handed to the relation engine as a model.
    candidates.into_iter().find(|c| c.is_file())
}
160
161fn find_tokenizer(models_dir: &Path, prefix: &str) -> Result<PathBuf, PipelineError> {
163 let candidates = if prefix == "gliner" {
164 vec![
165 models_dir.join("gliner_large-v2.1/tokenizer.json"),
166 models_dir.join("tokenizer.json"),
167 ]
168 } else {
169 vec![
170 models_dir.join("gliner-multitask-large-v0.5/tokenizer.json"),
171 models_dir.join("tokenizer.json"),
172 ]
173 };
174
175 for c in &candidates {
176 if c.exists() {
177 return Ok(c.clone());
178 }
179 }
180
181 Err(PipelineError::ModelNotFound {
182 model: format!("{prefix} tokenizer").into(),
183 searched: candidates.iter().map(|p| p.display().to_string()).collect(),
184 })
185}
186
187#[derive(Debug, thiserror::Error)]
188pub enum PipelineError {
189 #[error("NER error: {0}")]
190 Ner(#[from] NerError),
191
192 #[error("relation extraction error: {0}")]
193 Rel(#[from] RelError),
194
195 #[error("schema error: {0}")]
196 Schema(#[from] SchemaError),
197
198 #[error("model not found: {model}. Searched: {searched:?}")]
199 ModelNotFound {
200 model: String,
201 searched: Vec<String>,
202 },
203}