use super::graph::{DepLabel, DependencyGraph};
/// One transition of the arc-eager parsing system.
///
/// The exact preconditions and effects are implemented by `ArcEagerConfig::apply`,
/// which rejects (returns `false` for) a transition whose preconditions fail.
#[derive(Debug, Clone)]
pub enum ArcEagerTransition {
/// Move the first buffer word onto the stack.
Shift,
/// Pop the stack top; only allowed when it is not the root and already has a head.
Reduce,
/// Add the arc (buffer front -> stack top) with this label, then pop the stack.
/// Not allowed on the root or on a top that already has a head.
LeftArc(DepLabel),
/// Add the arc (stack top -> buffer front) with this label, then move the front onto the stack.
RightArc(DepLabel),
}
impl std::fmt::Display for ArcEagerTransition {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Shift => write!(f, "SHIFT"),
Self::Reduce => write!(f, "REDUCE"),
Self::LeftArc(l) => write!(f, "LEFT-ARC({})", l),
Self::RightArc(l) => write!(f, "RIGHT-ARC({})", l),
}
}
}
/// Parser state (configuration) of the arc-eager transition system.
///
/// Word indices are 1-based; index 0 is the artificial root.
#[derive(Debug, Clone)]
pub struct ArcEagerConfig {
/// Indices on the stack; the last element is the top. `new` starts it as `[0]` (the root).
pub stack: Vec<usize>,
/// Indices not yet processed; the first element is the front.
// NOTE(review): front removal is `Vec::remove(0)` (O(n)); a `VecDeque` would avoid
// that but would change this public field's type.
pub buffer: Vec<usize>,
/// Arcs built so far, as `(head, dependent, label)`.
pub arcs: Vec<(usize, usize, DepLabel)>,
/// `has_head[i]` becomes `true` once word `i` is given a head. `new` sizes it
/// `n_tokens + 1` so that it can be indexed by word index, with slot 0 for the root.
pub has_head: Vec<bool>,
/// Number of words in the sentence (the root is not counted).
pub n_tokens: usize,
}
impl ArcEagerConfig {
    /// Builds the initial configuration for a sentence of `n_tokens` words.
    ///
    /// The initial state is:
    /// - only the root (index 0) on the stack;
    /// - words `1..=n_tokens` in the buffer, in order;
    /// - no arcs;
    /// - an all-`false` `has_head` table with one slot per word plus one for the root.
    pub fn new(n_tokens: usize) -> Self {
        let buffer: Vec<usize> = (1..=n_tokens).collect();
        let has_head = vec![false; n_tokens + 1];
        Self {
            stack: vec![0],
            buffer,
            arcs: Vec::new(),
            has_head,
            n_tokens,
        }
    }

    /// A configuration is terminal once every word has left the buffer.
    pub fn is_terminal(&self) -> bool {
        self.buffer_front().is_none()
    }

    /// Index on top of the stack (its last element), or `None` if the stack is empty.
    pub fn stack_top(&self) -> Option<usize> {
        self.stack.last().copied()
    }

    /// Index at the front of the buffer (its first element), or `None` if it is empty.
    pub fn buffer_front(&self) -> Option<usize> {
        self.buffer.first().copied()
    }

    /// Applies `t` in place and reports whether it was legal.
    ///
    /// A transition whose preconditions fail returns `false` and leaves the
    /// configuration unchanged. The preconditions are:
    /// - SHIFT needs a non-empty buffer.
    /// - REDUCE needs a non-root top that already has a head.
    /// - LEFT-ARC needs a top and a front, and the top must be non-root and still headless.
    /// - RIGHT-ARC needs a top and a front.
    pub fn apply(&mut self, t: &ArcEagerTransition) -> bool {
        match t {
            ArcEagerTransition::Shift => {
                if self.buffer.is_empty() {
                    return false;
                }
                let word = self.buffer.remove(0);
                self.stack.push(word);
                true
            }
            ArcEagerTransition::Reduce => match self.stack_top() {
                Some(top) if top != 0 && self.has_head[top] => {
                    self.stack.pop();
                    true
                }
                _ => false,
            },
            ArcEagerTransition::LeftArc(label) => {
                let (top, front) = match (self.stack_top(), self.buffer_front()) {
                    (Some(top), Some(front)) => (top, front),
                    _ => return false,
                };
                // The root can never be a dependent, and a word gets at most one head.
                if top == 0 || self.has_head[top] {
                    return false;
                }
                self.arcs.push((front, top, label.clone()));
                self.has_head[top] = true;
                self.stack.pop();
                true
            }
            ArcEagerTransition::RightArc(label) => {
                let (top, front) = match (self.stack_top(), self.buffer_front()) {
                    (Some(top), Some(front)) => (top, front),
                    _ => return false,
                };
                self.arcs.push((top, front, label.clone()));
                self.has_head[front] = true;
                self.buffer.remove(0);
                self.stack.push(front);
                true
            }
        }
    }

    /// Lists the transitions `apply` would accept from this configuration.
    ///
    /// Results come in the order SHIFT, REDUCE, LEFT-ARC, RIGHT-ARC. The arc
    /// transitions carry the placeholder label `DepLabel::Dep`.
    pub fn legal_transitions(&self) -> Vec<ArcEagerTransition> {
        let mut moves = Vec::new();
        let front = self.buffer_front();
        if front.is_some() {
            moves.push(ArcEagerTransition::Shift);
        }
        let top = match self.stack_top() {
            Some(top) => top,
            None => return moves,
        };
        if top != 0 && self.has_head[top] {
            moves.push(ArcEagerTransition::Reduce);
        }
        if front.is_some() {
            if top != 0 && !self.has_head[top] {
                moves.push(ArcEagerTransition::LeftArc(DepLabel::Dep));
            }
            moves.push(ArcEagerTransition::RightArc(DepLabel::Dep));
        }
        moves
    }
}
/// Stateless arc-eager dependency parser.
///
/// Transitions are chosen by POS-tag heuristics in its oracle rather than by a
/// trained model.
pub struct ArcEagerParser;
impl ArcEagerParser {
    /// Creates a parser. It holds no state; every `parse` call builds its own
    /// configuration.
    pub fn new() -> Self {
        Self
    }

    /// Parses one sentence with the rule-based arc-eager oracle.
    ///
    /// `tokens` and `pos_tags` are expected to be parallel: word `i` (1-based)
    /// has tag `pos_tags[i - 1]`. The oracle reads a missing tag as `"_"`.
    ///
    /// Arcs produced by transitions are added with score `1.0`. Afterwards, every
    /// word for which `DependencyGraph::head_of` still reports no head is
    /// attached to the root (index 0) with label `Root` and score `0.5`.
    /// An empty `tokens` slice yields an empty graph.
    pub fn parse(&self, tokens: &[String], pos_tags: &[String]) -> DependencyGraph {
        let n = tokens.len();
        if n == 0 {
            return DependencyGraph::new(Vec::new(), Vec::new());
        }
        let mut config = ArcEagerConfig::new(n);
        let mut graph = DependencyGraph::new(tokens.to_vec(), pos_tags.to_vec());

        // Every successful transition either moves a buffer word onto the stack
        // (SHIFT, RIGHT-ARC) or pops a non-root word (REDUCE, LEFT-ARC), so the
        // buffer is exhausted within 2n steps; the cap is only a safety net.
        let max_steps = 4 * n + 10;
        for _ in 0..max_steps {
            if config.is_terminal() {
                break;
            }
            let trans = self.oracle(&config, pos_tags);
            if !config.apply(&trans) {
                // A rejected transition leaves `config` untouched, and the buffer
                // is non-empty here, so SHIFT always succeeds and keeps progress.
                let shifted = config.apply(&ArcEagerTransition::Shift);
                debug_assert!(shifted, "SHIFT must succeed on a non-terminal configuration");
            }
        }

        for (head, dep, label) in &config.arcs {
            graph.add_arc(*head, *dep, label.clone(), 1.0);
        }
        // Words the transitions never attached hang off the root with a lower score.
        for i in 1..=n {
            if graph.head_of(i).is_none() {
                graph.add_arc(0, i, DepLabel::Root, 0.5);
            }
        }
        graph
    }

    /// Chooses the next transition from the POS tags of the stack top and the
    /// buffer front.
    ///
    /// `parse` calls this only on non-terminal configurations (non-empty buffer).
    /// The suggestion may be rejected by `ArcEagerConfig::apply`; `parse` then
    /// falls back to SHIFT.
    fn oracle(&self, config: &ArcEagerConfig, pos_tags: &[String]) -> ArcEagerTransition {
        // Tag of a 1-based word index. The root, or an absent position, reads as
        // "ROOT"; a word without a tag reads as "_".
        fn tag_at(pos_tags: &[String], idx: Option<usize>) -> &str {
            match idx {
                None | Some(0) => "ROOT",
                Some(i) => pos_tags.get(i - 1).map(String::as_str).unwrap_or("_"),
            }
        }

        let top = config.stack_top();
        let front = config.buffer_front();
        let pt = tag_at(pos_tags, top);
        let pf = tag_at(pos_tags, front);

        // With only the root (or nothing) on the stack, bring in the next word.
        let t = match top {
            None | Some(0) => return ArcEagerTransition::Shift,
            Some(t) => t,
        };
        let top_has_head = config.has_head[t];

        // Defensive: with an empty buffer the only productive move is REDUCE.
        if top_has_head && front.is_none() {
            return ArcEagerTransition::Reduce;
        }
        // Punctuation at the buffer front becomes a dependent of the stack top
        // (RIGHT-ARC makes the front the dependent); it must not become the top's head.
        if let Some(f) = front {
            if is_punct(pf) && !config.has_head[f] {
                return ArcEagerTransition::RightArc(DepLabel::Punct);
            }
        }
        // A determiner is headed by the following noun. LEFT-ARC is the only
        // transition that can attach the stack top to the buffer front as its head.
        // A SHIFT here would bury the determiner under the noun for good.
        if is_det(pt) && is_noun(pf) && !top_has_head {
            return ArcEagerTransition::LeftArc(DepLabel::Det);
        }
        // A verb takes the following noun as its object.
        if is_verb(pt) && is_noun(pf) {
            return ArcEagerTransition::RightArc(DepLabel::Obj);
        }
        // NOTE(review): a noun before a verb is only shifted. It therefore never
        // receives the verb as head and usually ends up on the root via `parse`'s fallback.
        if is_noun(pt) && is_verb(pf) {
            return ArcEagerTransition::Shift;
        }
        // A stack top that already has a head is finished.
        if top_has_head {
            return ArcEagerTransition::Reduce;
        }
        if front.is_some() {
            ArcEagerTransition::Shift
        } else {
            ArcEagerTransition::RightArc(DepLabel::Dep)
        }
    }
}
impl Default for ArcEagerParser {
fn default() -> Self {
Self::new()
}
}
/// Returns `true` if `pos` is a punctuation tag.
///
/// The following tags are accepted:
/// - Universal Dependencies `PUNCT`, and any tag starting with `PUNCT`.
/// - The lowercase `punct`, matching the lowercase spellings that the sibling
///   `is_det` / `is_noun` / `is_adj` / `is_verb` helpers accept.
/// - Bare `.`, `,`, `:`, `;`, `!`, `?`.
/// - The Penn Treebank quote tags (two backticks, two apostrophes).
/// - The Penn Treebank bracket tags `-LRB-` / `-RRB-`.
fn is_punct(pos: &str) -> bool {
    pos.starts_with("PUNCT")
        || matches!(
            pos,
            "punct" | "." | "," | ":" | ";" | "!" | "?" | "``" | "''" | "-LRB-" | "-RRB-"
        )
}
/// Returns `true` for a determiner tag: exactly `DT`, `DET` or `det`.
fn is_det(pos: &str) -> bool {
    ["DT", "DET", "det"].contains(&pos)
}
/// Returns `true` for a noun tag.
///
/// Accepts any tag starting with `NN`, plus the exact tags `noun`, `NOUN` and `PROPN`.
fn is_noun(pos: &str) -> bool {
    match pos {
        "noun" | "NOUN" | "PROPN" => true,
        other => other.starts_with("NN"),
    }
}
/// Returns `true` for an adjective tag: any tag starting with `JJ`, or exactly `adj` / `ADJ`.
fn is_adj(pos: &str) -> bool {
    pos == "adj" || pos == "ADJ" || pos.starts_with("JJ")
}
/// Returns `true` for a verb tag: any tag starting with `V`, or exactly `verb`.
fn is_verb(pos: &str) -> bool {
    // `VERB` is already covered by the `V` prefix, as are tags such as `VB` / `VBD`.
    pos == "verb" || pos.starts_with('V')
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_arc_eager_parse() {
        let tokens = ["The", "cat", "sat"].map(String::from).to_vec();
        let pos = ["DT", "NN", "VBD"].map(String::from).to_vec();
        let parser = ArcEagerParser::new();
        let graph = parser.parse(&tokens, &pos);
        assert_eq!(graph.n_tokens, 3);
        for i in 1..=3 {
            assert!(graph.head_of(i).is_some(), "token {} has no head", i);
        }
    }

    #[test]
    fn test_arc_eager_empty() {
        let parser = ArcEagerParser::new();
        let g = parser.parse(&[], &[]);
        assert_eq!(g.n_tokens, 0);
    }

    #[test]
    fn test_config_new() {
        let cfg = ArcEagerConfig::new(3);
        assert_eq!(cfg.stack, vec![0]);
        assert_eq!(cfg.buffer, vec![1, 2, 3]);
        assert!(cfg.arcs.is_empty());
        assert_eq!(cfg.has_head, vec![false; 4]);
        assert_eq!(cfg.n_tokens, 3);
    }

    #[test]
    fn test_config_shift() {
        let mut cfg = ArcEagerConfig::new(3);
        assert!(!cfg.is_terminal());
        assert!(cfg.apply(&ArcEagerTransition::Shift));
        assert_eq!(cfg.stack, vec![0, 1]);
        assert_eq!(cfg.buffer, vec![2, 3]);
    }

    #[test]
    fn test_config_left_arc() {
        let mut cfg = ArcEagerConfig::new(3);
        assert!(cfg.apply(&ArcEagerTransition::Shift));
        let ok = cfg.apply(&ArcEagerTransition::LeftArc(DepLabel::Det));
        assert!(ok);
        assert!(cfg.has_head[1]);
        assert_eq!(cfg.stack, vec![0]);
        assert_eq!(cfg.buffer, vec![2, 3]);
        // LEFT-ARC records (head = buffer front, dependent = stack top).
        assert_eq!(cfg.arcs.len(), 1);
        assert_eq!((cfg.arcs[0].0, cfg.arcs[0].1), (2, 1));
    }

    #[test]
    fn test_left_arc_rejects_root() {
        let mut cfg = ArcEagerConfig::new(2);
        assert!(!cfg.apply(&ArcEagerTransition::LeftArc(DepLabel::Dep)));
        assert_eq!(cfg.stack, vec![0]);
        assert_eq!(cfg.buffer, vec![1, 2]);
        assert!(cfg.arcs.is_empty());
    }

    #[test]
    fn test_config_right_arc() {
        let mut cfg = ArcEagerConfig::new(3);
        assert!(cfg.apply(&ArcEagerTransition::Shift));
        let ok = cfg.apply(&ArcEagerTransition::RightArc(DepLabel::Obj));
        assert!(ok);
        assert!(cfg.has_head[2]);
        assert_eq!(cfg.stack.last(), Some(&2));
        assert_eq!(cfg.buffer, vec![3]);
        // RIGHT-ARC records (head = stack top, dependent = buffer front).
        assert_eq!(cfg.arcs.len(), 1);
        assert_eq!((cfg.arcs[0].0, cfg.arcs[0].1), (1, 2));
    }

    #[test]
    fn test_reduce_precondition() {
        let mut cfg = ArcEagerConfig::new(2);
        assert!(cfg.apply(&ArcEagerTransition::Shift));
        let ok = cfg.apply(&ArcEagerTransition::Reduce);
        assert!(!ok);
        assert_eq!(cfg.stack, vec![0, 1]);
    }

    #[test]
    fn test_reduce_after_right_arc() {
        let mut cfg = ArcEagerConfig::new(2);
        assert!(cfg.apply(&ArcEagerTransition::RightArc(DepLabel::Dep)));
        assert!(cfg.apply(&ArcEagerTransition::Reduce));
        assert_eq!(cfg.stack, vec![0]);
        assert_eq!(cfg.buffer, vec![2]);
    }

    #[test]
    fn test_reduce_rejects_root() {
        let mut cfg = ArcEagerConfig::new(1);
        assert!(!cfg.apply(&ArcEagerTransition::Reduce));
        assert_eq!(cfg.stack, vec![0]);
    }

    #[test]
    fn test_initial_legal_transitions() {
        let cfg = ArcEagerConfig::new(2);
        let legal = cfg.legal_transitions();
        assert_eq!(legal.len(), 2);
        assert!(matches!(legal[0], ArcEagerTransition::Shift));
        assert!(matches!(legal[1], ArcEagerTransition::RightArc(_)));
    }

    #[test]
    fn test_terminal_after_consuming_buffer() {
        let mut cfg = ArcEagerConfig::new(2);
        assert!(cfg.apply(&ArcEagerTransition::Shift));
        assert!(cfg.apply(&ArcEagerTransition::Shift));
        assert!(cfg.is_terminal());
        assert!(!cfg.apply(&ArcEagerTransition::Shift));
    }

    #[test]
    fn test_transition_display() {
        assert_eq!(ArcEagerTransition::Shift.to_string(), "SHIFT");
        assert_eq!(ArcEagerTransition::Reduce.to_string(), "REDUCE");
    }
}