pub mod arc_eager;
pub mod arc_standard;
pub mod graph;
pub mod projective;
pub use arc_eager::ArcEagerParser;
pub use arc_standard::{ArcStandardConfig, ArcStandardParser, Transition};
pub use graph::{DepLabel, DependencyArc, DependencyGraph};
pub use projective::{ChuLiuEdmonds, EisnerParser, ScoreMatrix};
/// Integration tests that drive the re-exported graph and parser types together.
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the three-token "The cat sat" graph shared by the projectivity
    /// and CoNLL-U tests.
    ///
    /// Arcs are added in the order 0→3 (`Root`), 3→2 (`Subj`), 2→1 (`Det`).
    ///
    /// NOTE(review): the `add_arc` arguments are read here as
    /// (head, dependent, label, score). That order is consistent with the head
    /// columns asserted in `integration_conllu_output`; the meaning of the
    /// trailing `1.0` is an assumption — confirm against `graph.rs`.
    fn the_cat_sat_tree() -> DependencyGraph {
        let mut tree = DependencyGraph::new(
            vec!["The".into(), "cat".into(), "sat".into()],
            vec!["DT".into(), "NN".into(), "VBD".into()],
        );
        for (head, dependent, label) in [
            (0, 3, DepLabel::Root),
            (3, 2, DepLabel::Subj),
            (2, 1, DepLabel::Det),
        ] {
            tree.add_arc(head, dependent, label, 1.0);
        }
        tree
    }

    /// The hand-built "The cat sat" tree must be reported as projective.
    #[test]
    fn integration_projectivity_simple_tree() {
        let tree = the_cat_sat_tree();
        assert!(
            tree.is_projective(),
            "simple left-branching tree should be projective"
        );
    }

    /// The arc-standard parser must keep all three tokens and give each of
    /// the indices 1..=3 a head.
    #[test]
    fn integration_arc_standard_three_tokens() {
        let words = vec!["The".into(), "cat".into(), "sat".into()];
        let tags = vec!["DT".into(), "NN".into(), "VBD".into()];

        let parser = ArcStandardParser::new();
        let graph = parser.parse(&words, &tags);

        assert_eq!(graph.n_tokens, 3);
        for token_idx in 1..=3 {
            let has_head = graph.head_of(token_idx).is_some();
            assert!(has_head, "arc-standard: token {} missing head", token_idx);
        }
    }

    /// LAS and UAS are both 1.0 for an identical prediction; relabelling every
    /// arc as `Dep` keeps UAS at 1.0 while pushing LAS below 1.0.
    #[test]
    fn integration_las_uas() {
        let words = vec!["a".into(), "b".into(), "c".into()];
        let tags = vec!["NN".into(); 3];

        let mut gold = DependencyGraph::new(words.clone(), tags.clone());
        for (head, dependent, label) in [
            (0, 3, DepLabel::Root),
            (3, 2, DepLabel::Subj),
            (2, 1, DepLabel::Det),
        ] {
            gold.add_arc(head, dependent, label, 1.0);
        }

        // A clone of the gold graph is a perfect prediction.
        let pred = gold.clone();
        let las = pred.las(&gold);
        let uas = pred.uas(&gold);
        assert!((las - 1.0).abs() < 1e-9, "LAS should be 1.0, got {}", las);
        assert!((uas - 1.0).abs() < 1e-9, "UAS should be 1.0, got {}", uas);

        // Same attachments as gold, but every label replaced by `Dep`.
        let mut relabelled = DependencyGraph::new(words, tags);
        for (head, dependent) in [(0, 3), (3, 2), (2, 1)] {
            relabelled.add_arc(head, dependent, DepLabel::Dep, 1.0);
        }
        let relabelled_las = relabelled.las(&gold);
        let relabelled_uas = relabelled.uas(&gold);
        assert!(
            (relabelled_uas - 1.0).abs() < 1e-9,
            "UAS with wrong labels should be 1.0"
        );
        assert!(
            relabelled_las < 1.0,
            "LAS with wrong labels should be < 1.0"
        );
    }

    /// The CoNLL-U rendering must mention every token and the expected
    /// head/label pairs, and parsing it back must preserve the token list.
    #[test]
    fn integration_conllu_output() {
        let original = the_cat_sat_tree();
        let conllu = original.to_conllu();

        // Every surface form appears somewhere in the output.
        assert!(conllu.contains("The"), "CoNLL-U missing 'The'");
        assert!(conllu.contains("cat"), "CoNLL-U missing 'cat'");
        assert!(conllu.contains("sat"), "CoNLL-U missing 'sat'");

        // Tab-delimited head/label fragments for the determiner and the root.
        assert!(
            conllu.contains("\t2\tdet\t"),
            "CoNLL-U wrong head/label for 'The'"
        );
        assert!(
            conllu.contains("\t0\troot\t"),
            "CoNLL-U wrong head/label for 'sat'"
        );

        // Round trip: only the token count and token list are compared.
        let reparsed = DependencyGraph::from_conllu(&conllu);
        assert_eq!(reparsed.n_tokens, original.n_tokens);
        assert_eq!(reparsed.tokens, original.tokens);
    }

    /// The heuristic Eisner parser must attach every token and produce a
    /// projective graph.
    #[test]
    fn integration_eisner_small_sentence() {
        let words = vec!["The".into(), "cat".into(), "sat".into()];
        let tags = vec!["DT".into(), "NN".into(), "VBD".into()];

        let parser = EisnerParser::from_heuristic(words.len());
        let graph = parser.parse_to_graph(words, tags);

        assert_eq!(graph.n_tokens, 3);
        for token_idx in 1..=3 {
            let has_head = graph.head_of(token_idx).is_some();
            assert!(has_head, "Eisner: token {} missing head", token_idx);
        }
        assert!(graph.is_projective());
    }

    /// The arc-eager parser must keep all three tokens and give each of the
    /// indices 1..=3 a head.
    #[test]
    fn integration_arc_eager_three_tokens() {
        let words = vec!["The".into(), "cat".into(), "sat".into()];
        let tags = vec!["DT".into(), "NN".into(), "VBD".into()];

        let parser = ArcEagerParser::new();
        let graph = parser.parse(&words, &tags);

        assert_eq!(graph.n_tokens, 3);
        for token_idx in 1..=3 {
            let has_head = graph.head_of(token_idx).is_some();
            assert!(has_head, "arc-eager: token {} missing head", token_idx);
        }
    }

    /// Chu-Liu/Edmonds over a distance-heuristic score matrix must attach
    /// every one of the three tokens.
    #[test]
    fn integration_chu_liu_edmonds() {
        let words = vec!["The".into(), "cat".into(), "sat".into()];
        let tags = vec!["DT".into(), "NN".into(), "VBD".into()];
        let scores = ScoreMatrix::from_distance_heuristic(3);

        let graph = ChuLiuEdmonds::parse_to_graph(&scores, words, tags);

        assert_eq!(graph.n_tokens, 3);
        for token_idx in 1..=3 {
            let has_head = graph.head_of(token_idx).is_some();
            assert!(has_head, "CLE: token {} missing head", token_idx);
        }
    }
}