gliner/model/input/relation/
mod.rs1pub mod schema;
2
3use composable::*;
4use std::collections::{HashMap, HashSet};
5use crate::model::pipeline::context::RelationContext;
6use crate::util::result::Result;
7use crate::model::output::decoded::SpanOutput;
8use schema::RelationSchema;
9
10
/// Instruction placed in front of every text handed to the relation-extraction model.
///
/// The trailing newline is part of the prefix and is spelled out separately on purpose.
const PROMPT_PREFIX: &str = concat!("Extract relationships between entities from the text: ", "\n");
12
/// Input of the relation-extraction stage, derived from previously detected spans.
#[derive(Debug, Clone)]
pub struct RelationInput {
    /// One prompt per input text: the relation-extraction prefix followed by the text.
    pub prompts: Vec<String>,
    /// Candidate labels, each of the form `"<entity text> <> <relation>"`.
    pub labels: Vec<String>,
    /// For every entity text, the set of classes it was tagged with.
    pub entity_labels: HashMap<String, HashSet<String>>,
}
19
20impl RelationInput {
21
22 pub fn from_spans(spans: SpanOutput, schema: &RelationSchema) -> Self {
24 Self {
25 prompts: Self::make_prompts(&spans, PROMPT_PREFIX),
26 labels: Self::make_labels(&spans, schema),
27 entity_labels: Self::make_entity_labels(&spans),
28 }
29 }
30
31 fn make_prompts(spans: &SpanOutput, prefix: &str) -> Vec<String> {
33 spans.texts.iter().map(|t| format!("{prefix} {t}")).collect()
34 }
35
36 fn make_labels(spans: &SpanOutput, schema: &RelationSchema) -> Vec<String> {
38 let mut unique_entities: HashSet<(&str, &str)> = HashSet::new();
42 for seq in &spans.spans {
43 for span in seq {
44 unique_entities.insert((span.text(), span.class()));
45 }
46 }
47
48 let mut result = Vec::new();
52 for (relation, spec) in schema.relations() {
53 unique_entities.iter()
54 .filter(|(_, class)| spec.allows_subject(class))
55 .map(|(text, _)| format!("{} <> {}", text, relation))
56 .for_each(|l| result.push(l));
57 }
58
59 result
60 }
61
62 fn make_entity_labels(spans: &SpanOutput) -> HashMap<String, HashSet<String>> {
69 let mut entity_labels = HashMap::<String, HashSet<String>>::new();
70 for seq in &spans.spans {
71 for span in seq {
72 entity_labels.entry(span.text().to_string()).or_default().insert(span.class().to_string());
73 }
74 }
75 entity_labels
76 }
77
78}
79
80
81pub struct SpanOutputToRelationInput<'a> {
82 schema: &'a RelationSchema
83}
84
85impl<'a> SpanOutputToRelationInput<'a> {
86 pub fn new(schema: &'a RelationSchema) -> Self {
87 Self { schema }
88 }
89}
90
91impl Composable<SpanOutput, RelationInput> for SpanOutputToRelationInput<'_> {
92 fn apply(&self, input: SpanOutput) -> Result<RelationInput> {
93 Ok(RelationInput::from_spans(input, self.schema))
94 }
95}
96
97
/// Stateless conversion step from a [`RelationInput`] to a text input plus a
/// [`RelationContext`] carrying the entity labels.
#[derive(Debug, Default, Clone, Copy)]
pub struct RelationInputToTextInput {
}
101
102impl Composable<RelationInput, (super::text::TextInput, RelationContext)> for RelationInputToTextInput {
103 fn apply(&self, input: RelationInput) -> Result<(super::text::TextInput, RelationContext)> {
104 Ok((super::text::TextInput::new(input.prompts, input.labels)?, RelationContext { entity_labels: input.entity_labels }))
105 }
106}