gliner/model/input/relation/
mod.rs

1pub mod schema;
2
3use composable::*;
4use std::collections::{HashMap, HashSet};
5use crate::model::pipeline::context::RelationContext;
6use crate::util::result::Result;
7use crate::model::output::decoded::SpanOutput;
8use schema::RelationSchema;
9
10
/// Instruction placed in front of every input text when building relation-extraction prompts.
const PROMPT_PREFIX: &str = "Extract relationships between entities from the text: \n";

/// Input data for Relation Extraction.
#[derive(Debug, Clone)]
pub struct RelationInput {
    /// One prompt per input text (as built by `from_spans`: the prompt prefix followed by the text).
    pub prompts: Vec<String>,
    /// Relation labels of the form `"<entity text> <> <relation>"`, one list shared by all prompts.
    pub labels: Vec<String>,
    /// Entity text -> set of entity classes that text was extracted with.
    pub entity_labels: HashMap<String, HashSet<String>>,
}
19
20impl RelationInput {
21
22    /// Builds a relation input from a span output and a relation schema
23    pub fn from_spans(spans: SpanOutput, schema: &RelationSchema) -> Self {
24        Self {
25            prompts: Self::make_prompts(&spans, PROMPT_PREFIX),
26            labels: Self::make_labels(&spans, schema),
27            entity_labels: Self::make_entity_labels(&spans),
28        }
29    }
30    
31    /// Prepare the prompts basing on the provided prefix
32    fn make_prompts(spans: &SpanOutput, prefix: &str) -> Vec<String> {        
33        spans.texts.iter().map(|t| format!("{prefix} {t}")).collect()
34    }
35    
36    /// Prepare the labels basing on extracted entities and the provided schema
37    fn make_labels(spans: &SpanOutput, schema: &RelationSchema) -> Vec<String> {
38        // List unique (entity, class) entries found in all spans for all sequences.
39        // This is sub-optimal because one huge label list will be made for all sequences, 
40        // but this is how GLiNER multitask works...
41        let mut unique_entities: HashSet<(&str, &str)> = HashSet::new();
42        for seq in &spans.spans {
43            for span in seq {
44                unique_entities.insert((span.text(), span.class()));
45            }
46        }
47
48        // Actually create the labels. Labels for not allowed entity classes for the subject (according 
49        // to the schema) will not be included. The check on the object class has to be made when
50        // decoding the result.
51        let mut result = Vec::new();
52        for (relation, spec) in schema.relations() {
53            unique_entities.iter()
54                .filter(|(_, class)| spec.allows_subject(class))
55                .map(|(text, _)| format!("{} <> {}", text, relation))
56                .for_each(|l| result.push(l));            
57        }
58
59        result
60    }
61    
62    /// Build entity-text -> entity-labels map (which will be used when decoding, to filter relations basing on allowed objects).
63    /// 
64    /// Multiple labels for the same entity text is supported, but in this case there is no guarantee that a 
65    /// relation actually mentions an entity as having a given label since we just have this information 
66    /// (limitation of GLiNER multi). So, as soon as one expected class is found for an entity, it will have
67    /// to be accepted without knowing its actual class within the relation (which is probably ok).
68    fn make_entity_labels(spans: &SpanOutput) -> HashMap<String, HashSet<String>> {        
69        let mut entity_labels = HashMap::<String, HashSet<String>>::new();
70        for seq in &spans.spans {
71            for span in seq {
72                entity_labels.entry(span.text().to_string()).or_default().insert(span.class().to_string());
73            }
74        }
75        entity_labels
76    }
77
78}
79
80
81pub struct SpanOutputToRelationInput<'a> {
82    schema: &'a RelationSchema
83}
84
85impl<'a> SpanOutputToRelationInput<'a> {
86    pub fn new(schema: &'a RelationSchema) -> Self {
87        Self { schema }
88    }
89}
90
91impl Composable<SpanOutput, RelationInput> for SpanOutputToRelationInput<'_> {
92    fn apply(&self, input: SpanOutput) -> Result<RelationInput> {
93        Ok(RelationInput::from_spans(input, self.schema))
94    }
95}
96
97
/// Stateless converter splitting a `RelationInput` into the model's text input and a
/// `RelationContext` carrying the entity labels.
#[derive(Default)]
pub struct RelationInputToTextInput {}
101
102impl Composable<RelationInput, (super::text::TextInput, RelationContext)> for RelationInputToTextInput {
103    fn apply(&self, input: RelationInput) -> Result<(super::text::TextInput, RelationContext)> {
104        Ok((super::text::TextInput::new(input.prompts, input.labels)?, RelationContext { entity_labels: input.entity_labels }))
105    }
106}