gliner/model/output/
relation.rs

1use composable::Composable;
2use crate::model::input::relation::schema::RelationSchema;
3use crate::model::pipeline::context::RelationContext;
4use crate::util::result::Result;
5use crate::text::span::Span;
6use super::decoded::SpanOutput;
7
8/// Defines the final output of the relation extraction pipeline
9pub struct RelationOutput {
10    pub texts: Vec<String>,
11    pub entities: Vec<String>,
12    pub relations: Vec<Vec<Relation>>,    
13}
14
/// A single extracted relation: `subject --class--> object`.
///
/// Derives `Debug`/`Clone`/`PartialEq` so callers can log, copy and compare results.
#[derive(Debug, Clone, PartialEq)]
pub struct Relation {
    /// Relation label (the key looked up in the relation schema).
    class: String,
    /// Text of the subject, decoded from the span label.
    subject: String,
    /// Text of the object, taken from the span text.
    object: String,
    /// Index of the input sequence within the batch.
    sequence: usize,
    /// Start offset of the object span (same unit as `Span::offsets`).
    start: usize,
    /// End offset of the object span (same unit as `Span::offsets`).
    end: usize,
    /// Probability reported for the span (presumably in `[0, 1]`; displayed as a percentage).
    probability: f32,
}
32
33
34impl Relation {
35    
36    pub fn from(span: Span) -> Result<Self> {
37        let (start, end) = span.offsets();
38        let (subject, class) = Self::decode(span.class())?;
39        Ok(Self {
40            class,
41            subject,
42            object: span.text().to_string(),
43            sequence: span.sequence(),
44            start,
45            end,
46            probability: span.probability(),
47        })
48    }
49    
50    pub fn class(&self) -> &str {
51        &self.class
52    }
53    
54    pub fn subject(&self) -> &str {
55        &self.subject
56    }
57    
58    pub fn object(&self) -> &str {
59        &self.object
60    }
61    
62    pub fn sequence(&self) -> usize {
63        self.sequence
64    }
65    
66    pub fn offsets(&self) -> (usize, usize) {
67        (self.start, self.end)
68    }
69    
70    pub fn probability(&self) -> f32 {
71        self.probability
72    }
73    
74    fn decode(rel_class: &str) -> Result<(String, String)> {
75        let split: Vec<&str> = rel_class.split(" <> ").collect();
76        if split.len() != 2 {
77            RelationFormatError::invalid_relation_label(rel_class).err()
78        }
79        else {
80            Ok((split.get(0).unwrap().to_string(), split.get(1).unwrap().to_string()))
81        }        
82    }
83}
84
85
86impl std::fmt::Display for RelationOutput {
87    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
88        for relations in &self.relations {
89            for relation in relations {
90                writeln!(f, "{:3} | {:15} | {:10} | {:15} | {:.1}%", relation.sequence(), relation.subject(), relation.class(), relation.object(), relation.probability() * 100.0)?;
91            }
92        }
93        Ok(())
94    }
95}
96
97
98/// SpanOutput -> RelationOutput
99pub struct SpanOutputToRelationOutput<'a> {
100    schema: &'a RelationSchema,
101}
102
103impl<'a> SpanOutputToRelationOutput<'a> {
104    pub fn new(schema: &'a RelationSchema) -> Self {
105        Self { schema }
106    }
107
108    fn is_valid(&self, relation: &Relation, context: &RelationContext) -> Result<bool> {
109        // check that one the potential labels of the object is allowed by the relation schema ("potential" because the model outputs the text of the object, not its actual label, and in some corner cases the same entity might have several labels)
110        // note that we might have no label at all, if the object is not part of the extracted entities (in such case the relation is not valid)
111        if let Some(potential_labels) = context.entity_labels.get(relation.object()) {
112            // get the spec for the relation label (checking that is is actually expected according to the schema)
113            let spec = self.schema.relations().get(relation.class()).ok_or(RelationFormatError::unexpected_relation_label(relation.class()))?;
114            // check that the spec allows one of the labels
115            Ok(spec.allows_one_of_objects(potential_labels))
116        }
117        else {
118            // in case the extracted object is not part of the extracted entities
119            Ok(false)
120        }
121    }
122}
123
124impl Composable<(SpanOutput, RelationContext), RelationOutput> for SpanOutputToRelationOutput<'_> {
125    fn apply(&self, input: (SpanOutput, RelationContext)) -> Result<RelationOutput> {
126        let (input, context) = input;        
127        let mut result = Vec::new();
128        for seq in input.spans {
129            let mut relations = Vec::new();
130            for span in seq {
131                let relation = Relation::from(span)?;
132                if self.is_valid(&relation, &context)? {
133                    relations.push(relation);
134                }
135            }
136            result.push(relations);
137        }
138        Ok(RelationOutput { 
139            texts: input.texts,
140            entities: input.entities,
141            relations: result 
142        })
143    }
144}
145
146
147
/// Error raised when a span label produced by the relation extraction pipeline is malformed
/// or unexpected.
///
/// This most likely signals an internal error, unless the pipeline was used incorrectly.
#[derive(Debug, Clone)]
pub struct RelationFormatError {
    /// Human-readable description of the problem.
    message: String,
}
155
156impl RelationFormatError {
157    pub fn invalid_relation_label(label: &str) -> Self {
158        Self { message: format!("invalid relation label format: {label}") }
159    }
160
161    pub fn unexpected_relation_label(label: &str) -> Self {
162        Self { message: format!("unexpected relation label: {label}") }
163    }
164
165    pub fn err<T>(self) -> Result<T> {
166        Err(Box::new(self))
167    }
168}
169
170impl std::error::Error for RelationFormatError { }
171
172impl std::fmt::Display for RelationFormatError {
173    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
174        f.write_str(&self.message)
175    }
176}