// gliner/model/output/relation.rs

use composable::Composable;

use crate::model::input::relation::schema::RelationSchema;
use crate::model::pipeline::context::RelationContext;
use crate::text::span::Span;
use crate::util::result::Result;

use super::decoded::SpanOutput;

/// Result of relation extraction over a batch of input sequences.
#[derive(Debug, Clone, PartialEq)]
pub struct RelationOutput {
    /// The input texts, passed through from the span decoding step.
    pub texts: Vec<String>,
    /// Entity strings passed through from the span decoding step
    /// (presumably the entity labels used for prediction — TODO confirm).
    pub entities: Vec<String>,
    /// Extracted relations, one inner vector per input sequence.
    pub relations: Vec<Vec<Relation>>,
}

/// A single extracted `subject --class--> object` relation.
#[derive(Debug, Clone, PartialEq)]
pub struct Relation {
    /// Relation class (second part of the `"<subject> <> <relation>"` span label).
    class: String,
    /// Subject of the relation (first part of the span label).
    subject: String,
    /// Object of the relation (the text covered by the span).
    object: String,
    /// Sequence the span was found in (as reported by the span).
    sequence: usize,
    /// Start offset of the object span (unit as reported by the span — TODO confirm).
    start: usize,
    /// End offset of the object span (unit as reported by the span — TODO confirm).
    end: usize,
    /// Probability reported for the span.
    probability: f32,
}

34impl Relation {
35
36 pub fn from(span: Span) -> Result<Self> {
37 let (start, end) = span.offsets();
38 let (subject, class) = Self::decode(span.class())?;
39 Ok(Self {
40 class,
41 subject,
42 object: span.text().to_string(),
43 sequence: span.sequence(),
44 start,
45 end,
46 probability: span.probability(),
47 })
48 }
49
50 pub fn class(&self) -> &str {
51 &self.class
52 }
53
54 pub fn subject(&self) -> &str {
55 &self.subject
56 }
57
58 pub fn object(&self) -> &str {
59 &self.object
60 }
61
62 pub fn sequence(&self) -> usize {
63 self.sequence
64 }
65
66 pub fn offsets(&self) -> (usize, usize) {
67 (self.start, self.end)
68 }
69
70 pub fn probability(&self) -> f32 {
71 self.probability
72 }
73
74 fn decode(rel_class: &str) -> Result<(String, String)> {
75 let split: Vec<&str> = rel_class.split(" <> ").collect();
76 if split.len() != 2 {
77 RelationFormatError::invalid_relation_label(rel_class).err()
78 }
79 else {
80 Ok((split.get(0).unwrap().to_string(), split.get(1).unwrap().to_string()))
81 }
82 }
83}
84
85
86impl std::fmt::Display for RelationOutput {
87 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
88 for relations in &self.relations {
89 for relation in relations {
90 writeln!(f, "{:3} | {:15} | {:10} | {:15} | {:.1}%", relation.sequence(), relation.subject(), relation.class(), relation.object(), relation.probability() * 100.0)?;
91 }
92 }
93 Ok(())
94 }
95}
96
97
98pub struct SpanOutputToRelationOutput<'a> {
100 schema: &'a RelationSchema,
101}
102
103impl<'a> SpanOutputToRelationOutput<'a> {
104 pub fn new(schema: &'a RelationSchema) -> Self {
105 Self { schema }
106 }
107
108 fn is_valid(&self, relation: &Relation, context: &RelationContext) -> Result<bool> {
109 if let Some(potential_labels) = context.entity_labels.get(relation.object()) {
112 let spec = self.schema.relations().get(relation.class()).ok_or(RelationFormatError::unexpected_relation_label(relation.class()))?;
114 Ok(spec.allows_one_of_objects(potential_labels))
116 }
117 else {
118 Ok(false)
120 }
121 }
122}
123
124impl Composable<(SpanOutput, RelationContext), RelationOutput> for SpanOutputToRelationOutput<'_> {
125 fn apply(&self, input: (SpanOutput, RelationContext)) -> Result<RelationOutput> {
126 let (input, context) = input;
127 let mut result = Vec::new();
128 for seq in input.spans {
129 let mut relations = Vec::new();
130 for span in seq {
131 let relation = Relation::from(span)?;
132 if self.is_valid(&relation, &context)? {
133 relations.push(relation);
134 }
135 }
136 result.push(relations);
137 }
138 Ok(RelationOutput {
139 texts: input.texts,
140 entities: input.entities,
141 relations: result
142 })
143 }
144}
145
146
147
148#[derive(Debug, Clone)]
149pub struct RelationFormatError {
153 message: String,
154}
155
156impl RelationFormatError {
157 pub fn invalid_relation_label(label: &str) -> Self {
158 Self { message: format!("invalid relation label format: {label}") }
159 }
160
161 pub fn unexpected_relation_label(label: &str) -> Self {
162 Self { message: format!("unexpected relation label: {label}") }
163 }
164
165 pub fn err<T>(self) -> Result<T> {
166 Err(Box::new(self))
167 }
168}
169
170impl std::error::Error for RelationFormatError { }
171
172impl std::fmt::Display for RelationFormatError {
173 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
174 f.write_str(&self.message)
175 }
176}