use ordered_float::OrderedFloat;
use petgraph::algo::tarjan_scc;
use udgraph::graph::{DepTriple, Sentence};
use super::DependencyEncoding;
use crate::EncodingProb;
use conllu::display::ConlluSentence;
use udgraph::Error;
/// Attach every token that has no head to the token at `head_idx`.
///
/// `labels` is paired positionally with the tokens of `sentence` (nodes for
/// which `is_token()` is false are skipped); if `labels` is shorter than the
/// number of tokens, the remaining tokens are left untouched.
///
/// The relation of a newly attached token is the label of the first
/// encoding in its label list (presumably the best-scoring one -- confirm
/// against the caller). When the list is empty, the token is attached
/// without a relation instead of panicking.
///
/// # Errors
///
/// Propagates any error returned while adding a relation to the
/// dependency graph.
pub fn attach_orphans<'a, S, H>(
    labels: &[S],
    sentence: &mut Sentence,
    head_idx: usize,
) -> Result<(), Error>
where
    H: 'a + Clone,
    S: AsRef<[EncodingProb<DependencyEncoding<H>>]>,
{
    // The indices have to be collected up front: the loop below needs a
    // mutable borrow of the sentence, so it cannot run while an iterator
    // over the sentence is still alive.
    #[allow(clippy::needless_collect)]
    let token_indices: Vec<_> = (0..sentence.len())
        .filter(|&idx| sentence[idx].is_token())
        .collect();

    for (idx, encodings) in token_indices.into_iter().zip(labels) {
        if sentence.dep_graph().head(idx).is_some() {
            continue;
        }

        // `first()` rather than `[0]`: an empty label list must not panic.
        let relation = encodings
            .as_ref()
            .first()
            .map(|encoding| encoding.encoding().label().to_owned());

        sentence
            .dep_graph_mut()
            .add_deprel(DepTriple::new(head_idx, relation, idx))?;
    }

    Ok(())
}
/// Break all cycles in the dependency graph of `sent`.
///
/// Strongly connected components with more than one node are computed
/// repeatedly. In every such component the lowest-numbered token other than
/// `root_idx` is re-attached to `root_idx`, keeping the relation of its
/// current head edge. The procedure stops once no such component is left.
///
/// # Errors
///
/// Propagates any error returned while adding a relation to the
/// dependency graph.
///
/// # Panics
///
/// Panics when two consecutive rounds yield the same set of components
/// (i.e. no progress was made), or when the token selected for
/// re-attachment has no head.
pub fn break_cycles(sent: &mut Sentence, root_idx: usize) -> Result<(), Error> {
    let mut previous_cycles = Vec::new();

    loop {
        // Only components with at least two nodes are of interest.
        let mut cycles = tarjan_scc(sent.get_ref());
        cycles.retain(|component| component.len() > 1);

        if cycles.is_empty() {
            return Ok(());
        }

        // Seeing the same components twice in a row means the previous
        // round did not change anything and we would loop forever.
        assert_ne!(
            cycles,
            previous_cycles,
            "Could not break cycle(s) in:\n\n{}",
            ConlluSentence::borrowed(sent)
        );

        for cycle in &cycles {
            let reattached = cycle
                .iter()
                .map(|node| node.index())
                .filter(|&token| token != root_idx)
                .min()
                .expect("Cannot get minimum, but iterator is non-empty");

            // Copy the relation into an owned value so that the immutable
            // borrow ends before the graph is modified.
            let relation = sent
                .dep_graph()
                .head(reattached)
                .expect("Token without a head")
                .relation()
                .map(ToOwned::to_owned);

            sent.dep_graph_mut()
                .add_deprel(DepTriple::new(root_idx, relation, reattached))?;
        }

        previous_cycles = cycles;
    }
}
/// Find the best-scoring encoding that attaches a token to the root node.
///
/// For every entry of `labels` (entry `i` is decoded as token `i + 1`),
/// the first encoding is taken whose label equals `root_relation` and
/// which `decode_fun` decodes to a triple with head `0`. Among these
/// per-token candidates the one with the highest probability is returned
/// together with that probability; on equal probabilities the candidate
/// of the later token wins.
///
/// Returns `None` when no token has such an encoding.
fn find_root_candidate<'a, S, H, F>(
    labels: &[S],
    decode_fun: F,
    root_relation: &str,
) -> Option<(DepTriple<String>, f32)>
where
    H: 'a + Clone,
    S: AsRef<[EncodingProb<DependencyEncoding<H>>]>,
    F: Fn(usize, &DependencyEncoding<H>) -> Option<DepTriple<String>>,
{
    // First usable root attachment among the encodings of one token.
    let candidate_for = |(offset, token_labels): (usize, &S)| {
        let token_idx = offset + 1;

        token_labels.as_ref().iter().find_map(|scored| {
            if scored.encoding().label() != root_relation {
                return None;
            }

            decode_fun(token_idx, scored.encoding())
                .filter(|triple| triple.head() == 0)
                .map(|triple| (triple, scored.prob()))
        })
    };

    labels
        .iter()
        .enumerate()
        .filter_map(candidate_for)
        .max_by(|(_, left), (_, right)| OrderedFloat(*left).cmp(&OrderedFloat(*right)))
}
/// Return the index of the root token, creating a root attachment if the
/// sentence does not have one yet.
///
/// If some token is already attached to node `0`, the index of the first
/// such token is returned and the sentence is left unchanged. Otherwise
/// the best root candidate found in `labels` (decoded with `decode_fun`
/// and restricted to encodings labelled `root_relation`) is added to the
/// graph. When there is no candidate either, token `1` is attached to
/// node `0` with `root_relation`.
///
/// # Errors
///
/// Propagates any error returned while adding the relation to the
/// dependency graph.
pub fn find_or_create_root<'a, S, H, F>(
    labels: &[S],
    sentence: &mut Sentence,
    decode_fun: F,
    root_relation: &str,
) -> Result<usize, Error>
where
    H: 'a + Clone,
    S: AsRef<[EncodingProb<DependencyEncoding<H>>]>,
    F: Fn(usize, &DependencyEncoding<H>) -> Option<DepTriple<String>>,
{
    if let Some(existing_root) = first_root(sentence) {
        return Ok(existing_root);
    }

    let root_triple = find_root_candidate(labels, decode_fun, root_relation)
        .map(|(candidate, _prob)| candidate)
        // No usable prediction: fall back to the first token.
        .unwrap_or_else(|| DepTriple::new(0, Some(root_relation.to_owned()), 1));

    let new_root = root_triple.dependent();
    sentence.dep_graph_mut().add_deprel(root_triple)?;

    Ok(new_root)
}
/// Return the index of the first token whose head is node `0`, if any.
fn first_root(sentence: &Sentence) -> Option<usize> {
    sentence
        .iter()
        .enumerate()
        // Only token nodes are considered.
        .filter(|(_, node)| node.token().is_some())
        .map(|(idx, _)| idx)
        .find(|&idx| {
            sentence
                .dep_graph()
                .head(idx)
                .map_or(false, |triple| triple.head() == 0)
        })
}
#[cfg(test)]
mod tests {
    use udgraph::graph::{DepTriple, Sentence};
    use udgraph::token::TokenBuilder;

    use super::{attach_orphans, break_cycles, find_or_create_root, first_root};
    use crate::depseq::{DependencyEncoding, PosLayer, RelativePos, RelativePosEncoder};
    use crate::{EncodingProb, SentenceEncoder};

    const ROOT_POS: &str = "ROOT";
    const ROOT_RELATION: &str = "root";

    /// Forms and part-of-speech tags of the example sentence.
    const WORDS: [(&str, &str); 4] = [
        ("Die", "det"),
        ("AWO", "noun"),
        ("veruntreute", "verb"),
        ("Spendengeld", "noun"),
    ];

    /// Add `(head, relation, dependent)` edges to `sent`, in order.
    fn add_relations(sent: &mut Sentence, relations: &[(usize, &str, usize)]) {
        for &(head, relation, dependent) in relations {
            sent.dep_graph_mut()
                .add_deprel(DepTriple::new(head, Some(relation), dependent))
                .unwrap();
        }
    }

    /// Build the example sentence with tags on the XPOS layer and the
    /// given edges.
    fn xpos_sentence(relations: &[(usize, &str, usize)]) -> Sentence {
        let mut sent = Sentence::new();
        for &(form, tag) in WORDS.iter() {
            sent.push(TokenBuilder::new(form).xpos(tag).into());
        }
        add_relations(&mut sent, relations);
        sent
    }

    /// Well-formed tree: token 3 is the root.
    fn test_graph() -> Sentence {
        xpos_sentence(&[
            (2, "det", 1),
            (3, "subj", 2),
            (0, ROOT_RELATION, 3),
            (3, "obj", 4),
        ])
    }

    /// Graph in which tokens 1 and 2 are each other's head. Tags are on
    /// the UPOS layer here.
    fn test_graph_cycle() -> Sentence {
        let mut sent = Sentence::new();
        for &(form, tag) in WORDS.iter() {
            sent.push(TokenBuilder::new(form).upos(tag).into());
        }
        add_relations(
            &mut sent,
            &[
                (2, "det", 1),
                (1, "subj", 2),
                (0, ROOT_RELATION, 3),
                (3, "obj", 4),
            ],
        );
        sent
    }

    /// Graph in which no token is attached to node 0.
    fn test_graph_no_root() -> Sentence {
        xpos_sentence(&[
            (2, "det", 1),
            (3, "subj", 2),
            (4, "foo", 3),
            (3, "obj", 4),
        ])
    }

    #[test]
    fn find_first_root() {
        let sent = test_graph();
        assert_eq!(first_root(&sent), Some(3));
    }

    #[test]
    fn attach_two_orphans() {
        // Tokens 2 and 4 have no head yet.
        let mut sent = xpos_sentence(&[(2, "det", 1), (0, ROOT_RELATION, 3)]);

        let encodings: Vec<_> = RelativePosEncoder::new(PosLayer::XPos, ROOT_RELATION)
            .encode(&test_graph())
            .unwrap()
            .into_iter()
            .map(|e| [EncodingProb::new(e, 1.)])
            .collect();

        attach_orphans(&encodings, &mut sent, 3).unwrap();

        assert_eq!(sent, test_graph());
    }

    #[test]
    fn add_missing_root() {
        let mut sent = test_graph_no_root();

        // Encoding with the given label that points at the root node.
        let scored = |label: &str, prob| {
            EncodingProb::new(
                DependencyEncoding {
                    label: label.to_owned(),
                    head: RelativePos::new(ROOT_POS, -1),
                },
                prob,
            )
        };

        let encodings: Vec<Vec<EncodingProb<DependencyEncoding<RelativePos>>>> = vec![
            vec![scored(ROOT_RELATION, 0.4)],
            vec![],
            vec![scored("distractor", 0.6), scored(ROOT_RELATION, 0.4)],
            vec![scored(ROOT_RELATION, 0.3)],
        ];

        let pos_table = RelativePosEncoder::new(PosLayer::XPos, "root").pos_position_table(&sent);

        find_or_create_root(
            &encodings,
            &mut sent,
            |idx, encoding| RelativePosEncoder::decode_idx(&pos_table, idx, encoding).ok(),
            ROOT_RELATION,
        )
        .unwrap();

        assert_eq!(sent, test_graph());
    }

    #[test]
    fn break_simple_cycle() {
        // Expected result: token 1 is re-attached to token 3.
        let mut expected = test_graph_cycle();
        add_relations(&mut expected, &[(3, "det", 1)]);

        let mut sent = test_graph_cycle();
        break_cycles(&mut sent, 3).unwrap();

        assert_eq!(sent, expected);
    }
}