Skip to main content

ctxgraph_extract/
rel.rs

1use std::path::Path;
2
3use composable::Composable;
4use gliner::model::input::relation::schema::RelationSchema;
5use gliner::model::input::text::TextInput;
6use gliner::model::output::decoded::SpanOutput;
7use gliner::model::output::relation::RelationOutput;
8use gliner::model::params::Parameters;
9use gliner::model::pipeline::relation::RelationPipeline;
10use gliner::model::pipeline::token::TokenPipeline;
11use orp::model::Model;
12use orp::params::RuntimeParameters;
13use orp::pipeline::Pipeline;
14
15use crate::ner::ExtractedEntity;
16use crate::schema::ExtractionSchema;
17
/// A relation extracted between two entities.
///
/// Produced by both the model-based and the heuristic extraction paths in
/// this module.
#[derive(Debug, Clone)]
pub struct ExtractedRelation {
    /// Text of the head (subject) side of the relation.
    pub head: String,
    /// Relation label, e.g. a key of the extraction schema's relation types.
    pub relation: String,
    /// Text of the tail (object) side of the relation.
    pub tail: String,
    /// Confidence score. The model-based path forwards the probability
    /// reported by the model; the heuristic path uses a fixed value.
    pub confidence: f64,
}
26
/// Relation extraction engine.
///
/// Supports two modes:
/// - **Model-based**: Uses gline-rs `RelationPipeline` with the multitask ONNX model.
/// - **Heuristic**: Pattern-based extraction when no relation model is available.
pub enum RelEngine {
    /// Delegates extraction to a loaded [`ModelBasedRelEngine`].
    ModelBased(ModelBasedRelEngine),
    /// Keyword/co-occurrence based extraction; carries no state.
    Heuristic,
}
36
/// Model-based relation extraction using gline-rs.
///
/// Requires `gliner-multitask-large-v0.5` ONNX model.
pub struct ModelBasedRelEngine {
    /// Model handle created once in `new` and shared by every pipeline run.
    model: Model,
    /// Inference parameters; initialised to `Parameters::default()`.
    params: Parameters,
    /// Tokenizer location as a UTF-8 string, handed to the pipeline
    /// constructors on each `extract` call.
    tokenizer_path: String,
}
45
46impl ModelBasedRelEngine {
47    pub fn new(model_path: &Path, tokenizer_path: &Path) -> Result<Self, RelError> {
48        let runtime_params = RuntimeParameters::default();
49        let model = Model::new(
50            model_path
51                .to_str()
52                .ok_or(RelError::InvalidPath(model_path.display().to_string()))?,
53            runtime_params,
54        )
55        .map_err(|e| RelError::ModelLoad(e.to_string()))?;
56
57        Ok(Self {
58            model,
59            params: Parameters::default(),
60            tokenizer_path: tokenizer_path
61                .to_str()
62                .ok_or(RelError::InvalidPath(
63                    tokenizer_path.display().to_string(),
64                ))?
65                .to_string(),
66        })
67    }
68
69    pub fn extract(
70        &self,
71        text: &str,
72        labels: &[&str],
73        schema: &ExtractionSchema,
74    ) -> Result<(Vec<ExtractedEntity>, Vec<ExtractedRelation>), RelError> {
75        // Build relation schema from extraction schema
76        let mut relation_schema = RelationSchema::new();
77        for (rel_name, spec) in &schema.relation_types {
78            let heads: Vec<&str> = spec.head.iter().map(|s| s.as_str()).collect();
79            let tails: Vec<&str> = spec.tail.iter().map(|s| s.as_str()).collect();
80            relation_schema.push_with_allowed_labels(rel_name, &heads, &tails);
81        }
82
83        let input = TextInput::from_str(&[text], labels)
84            .map_err(|e| RelError::Inference(e.to_string()))?;
85
86        // Step 1: Run NER via TokenPipeline
87        let ner_pipeline = TokenPipeline::new(&self.tokenizer_path)
88            .map_err(|e| RelError::Inference(e.to_string()))?;
89        let ner_composable = ner_pipeline.to_composable(&self.model, &self.params);
90        let ner_output: SpanOutput = ner_composable
91            .apply(input)
92            .map_err(|e| RelError::Inference(e.to_string()))?;
93
94        // Collect entities from NER output using span character offsets directly
95        let mut entities = Vec::new();
96        for sequence_spans in &ner_output.spans {
97            for span in sequence_spans {
98                let (start, end) = span.offsets();
99                entities.push(ExtractedEntity {
100                    text: span.text().to_string(),
101                    entity_type: span.class().to_string(),
102                    span_start: start,
103                    span_end: end,
104                    confidence: span.probability() as f64,
105                });
106            }
107        }
108
109        // Step 2: Run relation extraction on top of NER output
110        let rel_pipeline =
111            RelationPipeline::default(&self.tokenizer_path, &relation_schema)
112                .map_err(|e| RelError::Inference(e.to_string()))?;
113        let rel_composable = rel_pipeline.to_composable(&self.model, &self.params);
114        let rel_output: RelationOutput = rel_composable
115            .apply(ner_output)
116            .map_err(|e| RelError::Inference(e.to_string()))?;
117
118        // Collect relations
119        let mut relations = Vec::new();
120        for sequence_rels in &rel_output.relations {
121            for rel in sequence_rels {
122                relations.push(ExtractedRelation {
123                    head: rel.subject().to_string(),
124                    relation: rel.class().to_string(),
125                    tail: rel.object().to_string(),
126                    confidence: rel.probability() as f64,
127                });
128            }
129        }
130
131        Ok((entities, relations))
132    }
133}
134
135impl RelEngine {
136    /// Create a model-based engine if the multitask model is available,
137    /// otherwise fall back to heuristic mode.
138    pub fn new(model_path: Option<&Path>, tokenizer_path: Option<&Path>) -> Result<Self, RelError> {
139        match (model_path, tokenizer_path) {
140            (Some(mp), Some(tp)) if mp.exists() && tp.exists() => {
141                let engine = ModelBasedRelEngine::new(mp, tp)?;
142                Ok(Self::ModelBased(engine))
143            }
144            _ => Ok(Self::Heuristic),
145        }
146    }
147
148    /// Extract relations between entities.
149    pub fn extract(
150        &self,
151        text: &str,
152        entities: &[ExtractedEntity],
153        schema: &ExtractionSchema,
154    ) -> Result<Vec<ExtractedRelation>, RelError> {
155        match self {
156            Self::ModelBased(engine) => {
157                let labels: Vec<&str> = schema.entity_labels();
158                let (_, relations) = engine.extract(text, &labels, schema)?;
159                Ok(relations)
160            }
161            Self::Heuristic => Ok(heuristic_relations(text, entities, schema)),
162        }
163    }
164}
165
166/// Heuristic relation extraction based on text patterns and entity co-occurrence.
167fn heuristic_relations(
168    text: &str,
169    entities: &[ExtractedEntity],
170    schema: &ExtractionSchema,
171) -> Vec<ExtractedRelation> {
172    let lower = text.to_lowercase();
173    let mut relations = Vec::new();
174
175    let patterns: &[(&str, &[&str])] = &[
176        ("chose", &["chose", "selected", "picked", "went with", "adopted"]),
177        ("rejected", &["rejected", "ruled out", "decided against", "dropped"]),
178        ("replaced", &["replaced", "migrated from", "switched from", "moved from"]),
179        ("depends_on", &["depends on", "relies on", "requires", "built on", "uses"]),
180        ("fixed", &["fixed", "resolved", "patched", "repaired", "debugged"]),
181        ("introduced", &["introduced", "added", "implemented", "created", "built"]),
182        ("deprecated", &["deprecated", "removed", "phased out", "sunset"]),
183        ("caused", &["caused", "resulted in", "led to", "triggered"]),
184        ("constrained_by", &["constrained by", "limited by", "blocked by", "due to"]),
185    ];
186
187    for (relation, keywords) in patterns {
188        let rel_spec = match schema.relation_types.get(*relation) {
189            Some(spec) => spec,
190            None => continue,
191        };
192
193        let keyword_found = keywords.iter().any(|kw| lower.contains(kw));
194        if !keyword_found {
195            continue;
196        }
197
198        for head in entities {
199            if !rel_spec.head.contains(&head.entity_type) {
200                continue;
201            }
202            for tail in entities {
203                if std::ptr::eq(head, tail) {
204                    continue;
205                }
206                if !rel_spec.tail.contains(&tail.entity_type) {
207                    continue;
208                }
209                relations.push(ExtractedRelation {
210                    head: head.text.clone(),
211                    relation: relation.to_string(),
212                    tail: tail.text.clone(),
213                    confidence: 0.6,
214                });
215            }
216        }
217    }
218
219    relations
220}
221
/// Errors produced while constructing or running the relation engine.
#[derive(Debug, thiserror::Error)]
pub enum RelError {
    /// A model or tokenizer path could not be converted to a UTF-8 string;
    /// carries the path's display form.
    #[error("invalid path: {0}")]
    InvalidPath(String),

    /// Loading the model failed; carries the underlying error message.
    #[error("failed to load model: {0}")]
    ModelLoad(String),

    /// Building the input or a pipeline, or running a pipeline, failed;
    /// carries the underlying error message.
    #[error("inference error: {0}")]
    Inference(String),
}