1use std::path::Path;
2
3use composable::Composable;
4use gliner::model::input::relation::schema::RelationSchema;
5use gliner::model::input::text::TextInput;
6use gliner::model::output::decoded::SpanOutput;
7use gliner::model::output::relation::RelationOutput;
8use gliner::model::params::Parameters;
9use gliner::model::pipeline::relation::RelationPipeline;
10use gliner::model::pipeline::token::TokenPipeline;
11use orp::model::Model;
12use orp::params::RuntimeParameters;
13use orp::pipeline::Pipeline;
14
15use crate::ner::ExtractedEntity;
16use crate::schema::ExtractionSchema;
17
/// A single directed relation between two entity mentions.
///
/// Values compare field by field. Because `confidence` is an `f64`, only
/// `PartialEq` (not `Eq`) is available, and a `NaN` confidence never compares
/// equal to anything.
#[derive(Debug, Clone, PartialEq)]
pub struct ExtractedRelation {
    /// Text of the entity the relation starts from.
    pub head: String,
    /// Name of the relation type (for example `"depends_on"`).
    pub relation: String,
    /// Text of the entity the relation points to.
    pub tail: String,
    /// Score attached to the relation by whichever engine produced it.
    pub confidence: f64,
}
26
27pub enum RelEngine {
33 ModelBased(ModelBasedRelEngine),
34 Heuristic,
35}
36
37pub struct ModelBasedRelEngine {
41 model: Model,
42 params: Parameters,
43 tokenizer_path: String,
44}
45
46impl ModelBasedRelEngine {
47 pub fn new(model_path: &Path, tokenizer_path: &Path) -> Result<Self, RelError> {
48 let runtime_params = RuntimeParameters::default();
49 let model = Model::new(
50 model_path
51 .to_str()
52 .ok_or(RelError::InvalidPath(model_path.display().to_string()))?,
53 runtime_params,
54 )
55 .map_err(|e| RelError::ModelLoad(e.to_string()))?;
56
57 Ok(Self {
58 model,
59 params: Parameters::default(),
60 tokenizer_path: tokenizer_path
61 .to_str()
62 .ok_or(RelError::InvalidPath(
63 tokenizer_path.display().to_string(),
64 ))?
65 .to_string(),
66 })
67 }
68
69 pub fn extract(
70 &self,
71 text: &str,
72 labels: &[&str],
73 schema: &ExtractionSchema,
74 ) -> Result<(Vec<ExtractedEntity>, Vec<ExtractedRelation>), RelError> {
75 let mut relation_schema = RelationSchema::new();
77 for (rel_name, spec) in &schema.relation_types {
78 let heads: Vec<&str> = spec.head.iter().map(|s| s.as_str()).collect();
79 let tails: Vec<&str> = spec.tail.iter().map(|s| s.as_str()).collect();
80 relation_schema.push_with_allowed_labels(rel_name, &heads, &tails);
81 }
82
83 let input = TextInput::from_str(&[text], labels)
84 .map_err(|e| RelError::Inference(e.to_string()))?;
85
86 let ner_pipeline = TokenPipeline::new(&self.tokenizer_path)
88 .map_err(|e| RelError::Inference(e.to_string()))?;
89 let ner_composable = ner_pipeline.to_composable(&self.model, &self.params);
90 let ner_output: SpanOutput = ner_composable
91 .apply(input)
92 .map_err(|e| RelError::Inference(e.to_string()))?;
93
94 let mut entities = Vec::new();
96 for sequence_spans in &ner_output.spans {
97 for span in sequence_spans {
98 let (start, end) = span.offsets();
99 entities.push(ExtractedEntity {
100 text: span.text().to_string(),
101 entity_type: span.class().to_string(),
102 span_start: start,
103 span_end: end,
104 confidence: span.probability() as f64,
105 });
106 }
107 }
108
109 let rel_pipeline =
111 RelationPipeline::default(&self.tokenizer_path, &relation_schema)
112 .map_err(|e| RelError::Inference(e.to_string()))?;
113 let rel_composable = rel_pipeline.to_composable(&self.model, &self.params);
114 let rel_output: RelationOutput = rel_composable
115 .apply(ner_output)
116 .map_err(|e| RelError::Inference(e.to_string()))?;
117
118 let mut relations = Vec::new();
120 for sequence_rels in &rel_output.relations {
121 for rel in sequence_rels {
122 relations.push(ExtractedRelation {
123 head: rel.subject().to_string(),
124 relation: rel.class().to_string(),
125 tail: rel.object().to_string(),
126 confidence: rel.probability() as f64,
127 });
128 }
129 }
130
131 Ok((entities, relations))
132 }
133}
134
135impl RelEngine {
136 pub fn new(model_path: Option<&Path>, tokenizer_path: Option<&Path>) -> Result<Self, RelError> {
139 match (model_path, tokenizer_path) {
140 (Some(mp), Some(tp)) if mp.exists() && tp.exists() => {
141 let engine = ModelBasedRelEngine::new(mp, tp)?;
142 Ok(Self::ModelBased(engine))
143 }
144 _ => Ok(Self::Heuristic),
145 }
146 }
147
148 pub fn extract(
150 &self,
151 text: &str,
152 entities: &[ExtractedEntity],
153 schema: &ExtractionSchema,
154 ) -> Result<Vec<ExtractedRelation>, RelError> {
155 match self {
156 Self::ModelBased(engine) => {
157 let labels: Vec<&str> = schema.entity_labels();
158 let (_, relations) = engine.extract(text, &labels, schema)?;
159 Ok(relations)
160 }
161 Self::Heuristic => Ok(heuristic_relations(text, entities, schema)),
162 }
163 }
164}
165
166fn heuristic_relations(
168 text: &str,
169 entities: &[ExtractedEntity],
170 schema: &ExtractionSchema,
171) -> Vec<ExtractedRelation> {
172 let lower = text.to_lowercase();
173 let mut relations = Vec::new();
174
175 let patterns: &[(&str, &[&str])] = &[
176 ("chose", &["chose", "selected", "picked", "went with", "adopted"]),
177 ("rejected", &["rejected", "ruled out", "decided against", "dropped"]),
178 ("replaced", &["replaced", "migrated from", "switched from", "moved from"]),
179 ("depends_on", &["depends on", "relies on", "requires", "built on", "uses"]),
180 ("fixed", &["fixed", "resolved", "patched", "repaired", "debugged"]),
181 ("introduced", &["introduced", "added", "implemented", "created", "built"]),
182 ("deprecated", &["deprecated", "removed", "phased out", "sunset"]),
183 ("caused", &["caused", "resulted in", "led to", "triggered"]),
184 ("constrained_by", &["constrained by", "limited by", "blocked by", "due to"]),
185 ];
186
187 for (relation, keywords) in patterns {
188 let rel_spec = match schema.relation_types.get(*relation) {
189 Some(spec) => spec,
190 None => continue,
191 };
192
193 let keyword_found = keywords.iter().any(|kw| lower.contains(kw));
194 if !keyword_found {
195 continue;
196 }
197
198 for head in entities {
199 if !rel_spec.head.contains(&head.entity_type) {
200 continue;
201 }
202 for tail in entities {
203 if std::ptr::eq(head, tail) {
204 continue;
205 }
206 if !rel_spec.tail.contains(&tail.entity_type) {
207 continue;
208 }
209 relations.push(ExtractedRelation {
210 head: head.text.clone(),
211 relation: relation.to_string(),
212 tail: tail.text.clone(),
213 confidence: 0.6,
214 });
215 }
216 }
217 }
218
219 relations
220}
221
222#[derive(Debug, thiserror::Error)]
223pub enum RelError {
224 #[error("invalid path: {0}")]
225 InvalidPath(String),
226
227 #[error("failed to load model: {0}")]
228 ModelLoad(String),
229
230 #[error("inference error: {0}")]
231 Inference(String),
232}