use std::collections::VecDeque;
/// Relation label attached to a dependency arc.
///
/// Each unit variant has a fixed lowercase textual form (shown in backticks below) that the
/// `Display` impl emits and `DepLabel::from_str` recognises. The names appear to follow the
/// Universal Dependencies relation inventory.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum DepLabel {
    /// Root of the sentence (`root`).
    Root,
    /// Nominal subject (`nsubj`).
    Subj,
    /// Direct object (`obj`).
    Obj,
    /// Indirect object (`iobj`).
    Iobj,
    /// Clausal subject (`csubj`).
    Csubj,
    /// Clausal complement (`ccomp`).
    Ccomp,
    /// Open clausal complement (`xcomp`).
    Xcomp,
    /// Nominal modifier (`nmod`).
    Nmod,
    /// Adjectival modifier (`amod`).
    Amod,
    /// Adverbial modifier (`advmod`).
    Advmod,
    /// Auxiliary (`aux`).
    Aux,
    /// Determiner (`det`).
    Det,
    /// Case marker (`case`).
    Case,
    /// Punctuation (`punct`).
    Punct,
    /// Conjunct (`conj`).
    Conj,
    /// Coordinating conjunction (`cc`).
    Cc,
    /// Subordinating marker (`mark`).
    Mark,
    /// Unspecified dependency (`dep`).
    Dep,
    /// Any label without a dedicated variant; carries the label text verbatim.
    Other(String),
}
/// Renders a label in its lowercase textual form (e.g. `Subj` as `nsubj`).
///
/// `Other` prints its payload unchanged.
impl std::fmt::Display for DepLabel {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Resolve the text first so there is a single write at the end.
        let text: &str = match self {
            Self::Root => "root",
            Self::Subj => "nsubj",
            Self::Obj => "obj",
            Self::Iobj => "iobj",
            Self::Csubj => "csubj",
            Self::Ccomp => "ccomp",
            Self::Xcomp => "xcomp",
            Self::Nmod => "nmod",
            Self::Amod => "amod",
            Self::Advmod => "advmod",
            Self::Aux => "aux",
            Self::Det => "det",
            Self::Case => "case",
            Self::Punct => "punct",
            Self::Conj => "conj",
            Self::Cc => "cc",
            Self::Mark => "mark",
            Self::Dep => "dep",
            Self::Other(custom) => custom.as_str(),
        };
        f.write_str(text)
    }
}
impl DepLabel {
pub fn from_str(s: &str) -> Self {
match s {
"root" => Self::Root,
"nsubj" => Self::Subj,
"obj" => Self::Obj,
"iobj" => Self::Iobj,
"csubj" => Self::Csubj,
"ccomp" => Self::Ccomp,
"xcomp" => Self::Xcomp,
"nmod" => Self::Nmod,
"amod" => Self::Amod,
"advmod" => Self::Advmod,
"aux" => Self::Aux,
"det" => Self::Det,
"case" => Self::Case,
"punct" => Self::Punct,
"conj" => Self::Conj,
"cc" => Self::Cc,
"mark" => Self::Mark,
"dep" => Self::Dep,
other => Self::Other(other.to_string()),
}
}
}
/// A single labelled, scored edge from a head to one of its dependents.
#[derive(Debug, Clone)]
pub struct DependencyArc {
    /// Index of the governing node. The graph code in this file treats indices as 1-based
    /// token positions, with `0` standing for the artificial root.
    pub head: usize,
    /// Index of the governed node, using the same numbering as `head`.
    pub dependent: usize,
    /// Relation holding between `head` and `dependent`.
    pub label: DepLabel,
    /// Numeric score attached to the arc; presumably a parser confidence (arcs read from
    /// CoNLL-U are given `1.0`).
    pub score: f64,
}
/// A sentence together with the dependency arcs over its tokens.
///
/// Arcs address tokens by 1-based position (`tokens[i - 1]` is node `i`); node `0` is the
/// artificial root. Nothing enforces that the arcs form a tree.
#[derive(Debug, Clone)]
pub struct DependencyGraph {
    /// Surface forms, in sentence order.
    pub tokens: Vec<String>,
    /// Part-of-speech tags, intended to be parallel to `tokens`.
    pub pos_tags: Vec<String>,
    /// All arcs, in insertion order.
    pub arcs: Vec<DependencyArc>,
    /// Number of tokens; the constructors set this to `tokens.len()`.
    pub n_tokens: usize,
}
impl DependencyGraph {
pub fn new(tokens: Vec<String>, pos_tags: Vec<String>) -> Self {
let n = tokens.len();
Self {
tokens,
pos_tags,
arcs: Vec::new(),
n_tokens: n,
}
}
pub fn add_arc(&mut self, head: usize, dependent: usize, label: DepLabel, score: f64) {
self.arcs.push(DependencyArc { head, dependent, label, score });
}
pub fn head_of(&self, i: usize) -> Option<usize> {
self.arcs.iter().find(|a| a.dependent == i).map(|a| a.head)
}
pub fn label_of(&self, i: usize) -> Option<&DepLabel> {
self.arcs.iter().find(|a| a.dependent == i).map(|a| &a.label)
}
pub fn dependents_of(&self, i: usize) -> Vec<usize> {
self.arcs
.iter()
.filter(|a| a.head == i)
.map(|a| a.dependent)
.collect()
}
pub fn path(&self, src: usize, dst: usize) -> Vec<(usize, bool)> {
let size = self.n_tokens + 1; let mut adj: Vec<Vec<(usize, bool)>> = vec![Vec::new(); size];
for arc in &self.arcs {
if arc.head < size && arc.dependent < size {
adj[arc.head].push((arc.dependent, false)); adj[arc.dependent].push((arc.head, true)); }
}
let mut visited = vec![false; size];
let mut prev: Vec<Option<(usize, bool)>> = vec![None; size];
let mut queue = VecDeque::new();
if src < size {
queue.push_back(src);
visited[src] = true;
}
'bfs: while let Some(curr) = queue.pop_front() {
if curr == dst {
break 'bfs;
}
for &(next, up) in &adj[curr] {
if !visited[next] {
visited[next] = true;
prev[next] = Some((curr, up));
queue.push_back(next);
}
}
}
let mut path = Vec::new();
let mut curr = dst;
while let Some((p, up)) = prev[curr] {
path.push((curr, up));
curr = p;
}
path.reverse();
path
}
pub fn is_projective(&self) -> bool {
for a1 in &self.arcs {
for a2 in &self.arcs {
if std::ptr::eq(a1, a2) {
continue;
}
let lo1 = a1.head.min(a1.dependent);
let hi1 = a1.head.max(a1.dependent);
let lo2 = a2.head.min(a2.dependent);
let hi2 = a2.head.max(a2.dependent);
if (lo1 < lo2 && lo2 < hi1 && hi1 < hi2)
|| (lo2 < lo1 && lo1 < hi2 && hi2 < hi1)
{
return false;
}
}
}
true
}
pub fn las(&self, gold: &DependencyGraph) -> f64 {
if self.n_tokens == 0 {
return 0.0;
}
let correct = self.arcs.iter().filter(|pred| {
gold.arcs.iter().any(|g| {
g.head == pred.head && g.dependent == pred.dependent && g.label == pred.label
})
}).count();
correct as f64 / self.n_tokens as f64
}
pub fn uas(&self, gold: &DependencyGraph) -> f64 {
if self.n_tokens == 0 {
return 0.0;
}
let correct = self.arcs.iter().filter(|pred| {
gold.arcs.iter().any(|g| {
g.head == pred.head && g.dependent == pred.dependent
})
}).count();
correct as f64 / self.n_tokens as f64
}
pub fn to_conllu(&self) -> String {
let mut out = String::new();
for i in 1..=self.n_tokens {
let arc = self.arcs.iter().find(|a| a.dependent == i);
let (head, label) = arc
.map(|a| (a.head, a.label.to_string()))
.unwrap_or((0, "dep".to_string()));
let form = self.tokens.get(i - 1).map(|s| s.as_str()).unwrap_or("_");
let pos = self.pos_tags.get(i - 1).map(|s| s.as_str()).unwrap_or("_");
out += &format!("{}\t{}\t_\t{}\t_\t_\t{}\t{}\t_\t_\n", i, form, pos, head, label);
}
out
}
pub fn from_conllu(conllu: &str) -> Self {
let mut tokens = Vec::new();
let mut pos_tags = Vec::new();
let mut arcs = Vec::new();
for line in conllu.lines() {
let line = line.trim();
if line.is_empty() || line.starts_with('#') {
continue;
}
let cols: Vec<&str> = line.split('\t').collect();
if cols.len() < 8 {
continue;
}
if cols[0].contains('-') || cols[0].contains('.') {
continue;
}
let dep_idx: usize = cols[0].parse().unwrap_or(0);
let form = cols[1].to_string();
let pos = cols[3].to_string();
let head: usize = cols[6].parse().unwrap_or(0);
let label = DepLabel::from_str(cols[7]);
tokens.push(form);
pos_tags.push(pos);
arcs.push(DependencyArc { head, dependent: dep_idx, label, score: 1.0 });
}
let n = tokens.len();
Self { tokens, pos_tags, arcs, n_tokens: n }
}
pub fn subtree(&self, root_idx: usize) -> Vec<usize> {
let mut result = Vec::new();
let mut stack = vec![root_idx];
while let Some(node) = stack.pop() {
result.push(node);
for child in self.dependents_of(node) {
stack.push(child);
}
}
result.sort_unstable();
result
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// "The cat sat" with arcs root -> sat(3) -> cat(2) -> The(1).
    fn simple_graph() -> DependencyGraph {
        let tokens = vec!["The".into(), "cat".into(), "sat".into()];
        let pos = vec!["DT".into(), "NN".into(), "VBD".into()];
        let mut g = DependencyGraph::new(tokens, pos);
        g.add_arc(0, 3, DepLabel::Root, 1.0);
        g.add_arc(3, 2, DepLabel::Subj, 1.0);
        g.add_arc(2, 1, DepLabel::Det, 1.0);
        g
    }

    #[test]
    fn test_label_text_roundtrip() {
        assert_eq!(DepLabel::Subj.to_string(), "nsubj");
        assert_eq!(DepLabel::from_str("nsubj"), DepLabel::Subj);
        assert_eq!(DepLabel::from_str(&DepLabel::Advmod.to_string()), DepLabel::Advmod);
        // Unknown labels survive unchanged.
        let custom = DepLabel::from_str("acl:relcl");
        assert_eq!(custom, DepLabel::Other("acl:relcl".to_string()));
        assert_eq!(custom.to_string(), "acl:relcl");
    }

    #[test]
    fn test_projectivity() {
        let g = simple_graph();
        assert!(g.is_projective());
    }

    #[test]
    fn test_non_projective() {
        let tokens = vec!["a".into(), "b".into(), "c".into(), "d".into()];
        let pos = vec!["NN".into(); 4];
        let mut g = DependencyGraph::new(tokens, pos);
        // 1 -> 3 and 2 -> 4 overlap without nesting, so they cross.
        g.add_arc(1, 3, DepLabel::Dep, 1.0);
        g.add_arc(2, 4, DepLabel::Dep, 1.0);
        assert!(!g.is_projective());
    }

    #[test]
    fn test_head_of_and_dependents_of() {
        let g = simple_graph();
        assert_eq!(g.head_of(3), Some(0));
        assert_eq!(g.head_of(2), Some(3));
        assert_eq!(g.head_of(1), Some(2));
        assert_eq!(g.head_of(0), None);
        assert_eq!(g.label_of(2), Some(&DepLabel::Subj));
        assert_eq!(g.label_of(0), None);
        assert_eq!(g.dependents_of(3), vec![2]);
        assert!(g.dependents_of(1).is_empty());
    }

    #[test]
    fn test_conllu_roundtrip() {
        let g = simple_graph();
        let conllu = g.to_conllu();
        let g2 = DependencyGraph::from_conllu(&conllu);
        assert_eq!(g2.n_tokens, g.n_tokens);
        assert_eq!(g2.tokens, g.tokens);
        assert_eq!(g2.pos_tags, g.pos_tags);
        // Heads and labels must survive the round trip, not just the forms.
        for i in 1..=g.n_tokens {
            assert_eq!(g2.head_of(i), g.head_of(i));
            assert_eq!(g2.label_of(i), g.label_of(i));
        }
    }

    #[test]
    fn test_from_conllu_skips_non_token_lines() {
        let text = "# text = The cat\n\n1-2\tThecat\t_\t_\t_\t_\t_\t_\t_\t_\n\
                    1\tThe\t_\tDT\t_\t_\t2\tdet\t_\t_\n\
                    too\tfew\tcolumns\n\
                    2\tcat\t_\tNN\t_\t_\t0\troot\t_\t_\n";
        let g = DependencyGraph::from_conllu(text);
        assert_eq!(g.tokens, vec!["The".to_string(), "cat".to_string()]);
        assert_eq!(g.head_of(1), Some(2));
        assert_eq!(g.label_of(2), Some(&DepLabel::Root));
    }

    #[test]
    fn test_las_uas() {
        let gold = simple_graph();
        let pred = simple_graph();
        assert!((pred.las(&gold) - 1.0).abs() < 1e-9);
        assert!((pred.uas(&gold) - 1.0).abs() < 1e-9);
    }

    #[test]
    fn test_las_uas_with_label_error() {
        let gold = simple_graph();
        let mut pred = DependencyGraph::new(gold.tokens.clone(), gold.pos_tags.clone());
        pred.add_arc(0, 3, DepLabel::Root, 1.0);
        pred.add_arc(3, 2, DepLabel::Obj, 1.0);
        pred.add_arc(2, 1, DepLabel::Det, 1.0);
        // One wrong label lowers LAS but leaves UAS untouched.
        assert!((pred.las(&gold) - 2.0 / 3.0).abs() < 1e-9);
        assert!((pred.uas(&gold) - 1.0).abs() < 1e-9);
    }

    #[test]
    fn test_subtree() {
        let g = simple_graph();
        assert_eq!(g.subtree(3), vec![1, 2, 3]);
        assert_eq!(g.subtree(2), vec![1, 2]);
        assert_eq!(g.subtree(1), vec![1]);
    }

    #[test]
    fn test_path() {
        let g = simple_graph();
        // Climbing from "The" to "sat" takes two upward steps.
        assert_eq!(g.path(1, 3), vec![(2, true), (3, true)]);
        // The reverse walk descends instead.
        assert_eq!(g.path(3, 1), vec![(2, false), (1, false)]);
        assert!(g.path(2, 2).is_empty());
    }
}