use crate::error::{SeqError, SeqResult};
/// A single sequence-labelling tag in the BIO / BIOES family.
///
/// Every variant except [`Tag::Outside`] carries the entity type, i.e. the
/// text after the first `-` in the textual form (`B-PER` carries `"PER"`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Tag {
/// Textual form `O`: carries no entity type.
Outside,
/// Textual form `B-<type>`.
Begin(String),
/// Textual form `I-<type>`.
Inside(String),
/// Textual form `E-<type>`; rejected by `validate_bio` as a BIOES-only tag.
End(String),
/// Textual form `S-<type>`; rejected by `validate_bio` as a BIOES-only tag.
Single(String),
}
impl Tag {
    /// Parses the textual form of a tag.
    ///
    /// `"O"` maps to [`Tag::Outside`]. Anything else must look like
    /// `<prefix>-<type>`, where `<prefix>` is one of `B`, `I`, `E`, `S` and
    /// `<type>` is non-empty. Only the first `-` separates the two parts, so
    /// the type itself may contain hyphens (`B-GPE-CITY` has type `GPE-CITY`).
    ///
    /// # Errors
    ///
    /// Returns [`SeqError::InvalidObservation`] when the entity type is
    /// missing or empty, or when the prefix is not recognised. The
    /// missing-type check runs before the prefix check.
    pub fn parse(s: &str) -> SeqResult<Tag> {
        if s == "O" {
            return Ok(Tag::Outside);
        }

        // Split on the first '-' only so hyphens inside the type survive.
        let (prefix, entity) = match s.split_once('-') {
            Some((head, tail)) if !tail.is_empty() => (head, tail),
            _ => {
                return Err(SeqError::InvalidObservation(format!(
                    "tag '{s}' is missing an entity type"
                )));
            }
        };

        // Pick the variant constructor first; the type string is only
        // allocated once the prefix is known to be valid.
        let variant: fn(String) -> Tag = match prefix {
            "B" => Tag::Begin,
            "I" => Tag::Inside,
            "E" => Tag::End,
            "S" => Tag::Single,
            other => {
                return Err(SeqError::InvalidObservation(format!(
                    "unknown tag prefix '{other}' in '{s}'"
                )));
            }
        };
        Ok(variant(entity.to_string()))
    }

    /// Renders the tag back into its textual form (`O`, `B-PER`, ...).
    ///
    /// This is the inverse of [`Tag::parse`] for every string it accepts.
    #[must_use]
    pub fn to_tag_string(&self) -> String {
        match self {
            Tag::Begin(t) => format!("B-{t}"),
            Tag::Inside(t) => format!("I-{t}"),
            Tag::End(t) => format!("E-{t}"),
            Tag::Single(t) => format!("S-{t}"),
            Tag::Outside => String::from("O"),
        }
    }

    /// Returns the entity type carried by the tag, or `None` for
    /// [`Tag::Outside`].
    #[must_use]
    pub fn entity_type(&self) -> Option<&str> {
        let ty = match self {
            Tag::Begin(ty) | Tag::Inside(ty) | Tag::End(ty) | Tag::Single(ty) => ty,
            Tag::Outside => return None,
        };
        Some(ty.as_str())
    }
}
/// Parses every string in `tags` with [`Tag::parse`], preserving order.
///
/// # Errors
///
/// Stops at the first string that fails to parse and returns that error.
pub fn parse_tags(tags: &[&str]) -> SeqResult<Vec<Tag>> {
    let mut parsed = Vec::with_capacity(tags.len());
    for raw in tags {
        parsed.push(Tag::parse(raw)?);
    }
    Ok(parsed)
}
/// A contiguous run of tags that share one entity type, as produced by
/// `extract_spans`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Span {
/// Entity type shared by every tag in the span (e.g. `"PER"`).
pub entity_type: String,
/// Index of the first tag in the span.
pub start: usize,
/// Index of the last tag in the span (inclusive; equals `start` for a
/// one-token span).
pub end: usize,
}
/// Groups a tag sequence into entity spans, tolerating malformed input.
///
/// The scan is lenient rather than validating:
/// * `B-x` and `S-x` always start a new span, closing any span in progress;
/// * `I-x` / `E-x` extend the span in progress when its type is `x`,
///   otherwise they close it and start a new span at their own position;
/// * `O` closes the span in progress;
/// * `E-x` and `S-x` close the span they belong to at their own position;
/// * a span still open at the end of the input ends at the last index.
///
/// Spans are returned in order of appearance with inclusive `end` indices.
#[must_use]
pub fn extract_spans(tags: &[Tag]) -> Vec<Span> {
    let mut found = Vec::new();
    // Entity type and start index of the span still being built, if any.
    let mut current: Option<(&str, usize)> = None;

    for (idx, tag) in tags.iter().enumerate() {
        // Phase 1: does this tag cut off the span in progress?
        let interrupted = match (current, tag) {
            (None, _) => false,
            (Some(_), Tag::Outside | Tag::Begin(_) | Tag::Single(_)) => true,
            (Some((ty, _)), Tag::Inside(next) | Tag::End(next)) => ty != next.as_str(),
        };
        if interrupted {
            if let Some((ty, start)) = current.take() {
                // A span can only be in progress if it was opened at an
                // earlier index, so `idx >= 1` here.
                found.push(Span {
                    entity_type: ty.to_string(),
                    start,
                    end: idx - 1,
                });
            }
        }

        // Phase 2: classify the tag. `O` has nothing left to do because any
        // span in progress was already closed above.
        let (entity, closes_here) = match tag {
            Tag::Outside => continue,
            Tag::Begin(ty) | Tag::Inside(ty) => (ty.as_str(), false),
            Tag::End(ty) | Tag::Single(ty) => (ty.as_str(), true),
        };

        // Phase 3: continue the surviving span (necessarily of the same
        // type) or open a fresh one here, then close it if the tag says so.
        let (ty, start) = current.unwrap_or((entity, idx));
        current = if closes_here {
            found.push(Span {
                entity_type: ty.to_string(),
                start,
                end: idx,
            });
            None
        } else {
            Some((ty, start))
        };
    }

    // A span left open by the final tag runs to the end of the input.
    if let Some((ty, start)) = current {
        found.push(Span {
            entity_type: ty.to_string(),
            start,
            end: tags.len() - 1,
        });
    }
    found
}
pub fn bio_to_bioes(tags: &[Tag]) -> SeqResult<Vec<Tag>> {
if tags.is_empty() {
return Err(SeqError::EmptyInput);
}
let spans = extract_spans(tags);
let mut out = vec![Tag::Outside; tags.len()];
for sp in spans {
if sp.start == sp.end {
out[sp.start] = Tag::Single(sp.entity_type);
} else {
out[sp.start] = Tag::Begin(sp.entity_type.clone());
for idx in sp.start + 1..sp.end {
out[idx] = Tag::Inside(sp.entity_type.clone());
}
out[sp.end] = Tag::End(sp.entity_type);
}
}
Ok(out)
}
pub fn bioes_to_bio(tags: &[Tag]) -> SeqResult<Vec<Tag>> {
if tags.is_empty() {
return Err(SeqError::EmptyInput);
}
let spans = extract_spans(tags);
let mut out = vec![Tag::Outside; tags.len()];
for sp in spans {
out[sp.start] = Tag::Begin(sp.entity_type.clone());
for idx in sp.start + 1..=sp.end {
out[idx] = Tag::Inside(sp.entity_type.clone());
}
}
Ok(out)
}
/// Checks that `tags` is a well-formed BIO sequence.
///
/// The rules enforced are:
/// * `E-x` and `S-x` never appear (they belong to BIOES);
/// * every `I-x` directly follows a `B-x` or `I-x` of the same type.
///
/// An empty sequence is accepted.
///
/// # Errors
///
/// Returns [`SeqError::GraphInvariantViolated`] describing the first
/// offending tag and its index.
pub fn validate_bio(tags: &[Tag]) -> SeqResult<()> {
    // Entity type of the span the previous tag belongs to; `None` at the
    // start of the input and after an `O`.
    let mut open_type: Option<&str> = None;

    for (i, tag) in tags.iter().enumerate() {
        open_type = match tag {
            Tag::Outside => None,
            Tag::Begin(t) => Some(t.as_str()),
            Tag::Inside(t) => {
                if open_type != Some(t.as_str()) {
                    return Err(SeqError::GraphInvariantViolated(format!(
                        "I-{t} at {i} not preceded by B-{t}/I-{t}"
                    )));
                }
                Some(t.as_str())
            }
            Tag::End(_) | Tag::Single(_) => {
                return Err(SeqError::GraphInvariantViolated(format!(
                    "BIO sequence contains BIOES tag '{}' at {i}",
                    tag.to_tag_string()
                )));
            }
        };
    }
    Ok(())
}
/// Checks that `tags` is a well-formed BIOES sequence.
///
/// The rules enforced are:
/// * `B-x` must be directly followed by `I-x` or `E-x`;
/// * `I-x` must directly follow `B-x` or `I-x` **and** be directly followed
///   by `I-x` or `E-x`, so a span can only finish on `E-x`;
/// * `E-x` must directly follow `B-x` or `I-x`;
/// * `O` and `S-x` are always accepted.
///
/// An empty sequence is accepted.
///
/// # Errors
///
/// Returns [`SeqError::GraphInvariantViolated`] describing the first
/// offending tag (scanning left to right) and its index.
pub fn validate_bioes(tags: &[Tag]) -> SeqResult<()> {
    // True when `neighbour` is a `B-`/`I-` tag of type `ty`, i.e. it leaves
    // a span of that type open for the tag that follows it.
    fn leaves_open(neighbour: Option<&Tag>, ty: &str) -> bool {
        matches!(neighbour, Some(Tag::Begin(p) | Tag::Inside(p)) if p.as_str() == ty)
    }
    // True when `neighbour` is an `I-`/`E-` tag of type `ty`, i.e. it carries
    // on the span that the tag before it left open.
    fn carries_on(neighbour: Option<&Tag>, ty: &str) -> bool {
        matches!(neighbour, Some(Tag::Inside(p) | Tag::End(p)) if p.as_str() == ty)
    }

    for (i, tag) in tags.iter().enumerate() {
        let prev = i.checked_sub(1).and_then(|p| tags.get(p));
        let next = tags.get(i + 1);

        match tag {
            Tag::Outside | Tag::Single(_) => {}
            Tag::Begin(t) => {
                if !carries_on(next, t) {
                    return Err(SeqError::GraphInvariantViolated(format!(
                        "B-{t} at {i} is not continued by I-{t}/E-{t}"
                    )));
                }
            }
            Tag::Inside(t) | Tag::End(t) => {
                if !leaves_open(prev, t) {
                    return Err(SeqError::GraphInvariantViolated(format!(
                        "{} at {i} does not continue an open span of type {t}",
                        tag.to_tag_string()
                    )));
                }
                // An `I-` tag sits strictly inside a span, so it also needs
                // a continuation; without this check `B-x I-x O` (or a
                // trailing `B-x I-x`) would pass although no `E-x` closes it.
                if matches!(tag, Tag::Inside(_)) && !carries_on(next, t) {
                    return Err(SeqError::GraphInvariantViolated(format!(
                        "I-{t} at {i} is not continued by I-{t}/E-{t}"
                    )));
                }
            }
        }
    }
    Ok(())
}
// Unit tests for tag parsing, span extraction, BIO <-> BIOES conversion and
// the two validators.
#[cfg(test)]
mod tests {
use super::*;
// Parses string tags, panicking on any invalid one; keeps the tests terse.
fn tags(strs: &[&str]) -> Vec<Tag> {
parse_tags(strs).expect("parse")
}
// Each recognised textual form maps to the matching variant.
#[test]
fn parse_outside_and_typed() {
assert_eq!(Tag::parse("O").expect("o"), Tag::Outside);
assert_eq!(Tag::parse("B-PER").expect("b"), Tag::Begin("PER".into()));
assert_eq!(Tag::parse("I-LOC").expect("i"), Tag::Inside("LOC".into()));
assert_eq!(Tag::parse("E-ORG").expect("e"), Tag::End("ORG".into()));
assert_eq!(Tag::parse("S-MISC").expect("s"), Tag::Single("MISC".into()));
}
// Only the first '-' separates prefix from type, so the type keeps its hyphens.
#[test]
fn parse_handles_hyphenated_types() {
assert_eq!(
Tag::parse("B-GPE-CITY").expect("ok"),
Tag::Begin("GPE-CITY".into())
);
}
// Unknown prefix, empty type, and missing separator are all errors.
#[test]
fn parse_rejects_bad_prefix_and_empty_type() {
assert!(Tag::parse("X-PER").is_err());
assert!(Tag::parse("B-").is_err());
assert!(Tag::parse("B").is_err());
}
// parse followed by to_tag_string reproduces the input string.
#[test]
fn to_tag_string_roundtrip() {
for s in ["O", "B-PER", "I-PER", "E-PER", "S-LOC"] {
assert_eq!(Tag::parse(s).expect("p").to_tag_string(), s);
}
}
// A two-token span and a one-token span, with inclusive end indices.
#[test]
fn extract_single_and_multi_spans() {
let t = tags(&["B-PER", "I-PER", "O", "S-LOC"]);
let spans = extract_spans(&t);
assert_eq!(spans.len(), 2);
assert_eq!(
spans[0],
Span {
entity_type: "PER".into(),
start: 0,
end: 1
}
);
assert_eq!(
spans[1],
Span {
entity_type: "LOC".into(),
start: 3,
end: 3
}
);
}
// A second B- of the same type starts a new span rather than extending.
#[test]
fn extract_adjacent_same_type_spans() {
let t = tags(&["B-PER", "B-PER"]);
let spans = extract_spans(&t);
assert_eq!(spans.len(), 2);
assert_eq!(spans[0].start, 0);
assert_eq!(spans[0].end, 0);
assert_eq!(spans[1].start, 1);
assert_eq!(spans[1].end, 1);
}
// An I- with nothing open is treated leniently as the start of a span.
#[test]
fn extract_dangling_inside_starts_new_span() {
let t = tags(&["I-PER", "I-PER"]);
let spans = extract_spans(&t);
assert_eq!(spans.len(), 1);
assert_eq!(spans[0].start, 0);
assert_eq!(spans[0].end, 1);
}
// An I- of a different type closes the open span and opens a new one.
#[test]
fn extract_type_change_splits() {
let t = tags(&["B-PER", "I-LOC"]);
let spans = extract_spans(&t);
assert_eq!(spans.len(), 2);
assert_eq!(spans[0].entity_type, "PER");
assert_eq!(spans[0].end, 0);
assert_eq!(spans[1].entity_type, "LOC");
assert_eq!(spans[1].start, 1);
}
// The last token of a span becomes E-, a lone B- becomes S-.
#[test]
fn bio_to_bioes_basic() {
let bio = tags(&["B-PER", "I-PER", "O", "B-LOC"]);
let bioes = bio_to_bioes(&bio).expect("conv");
let got: Vec<String> = bioes.iter().map(Tag::to_tag_string).collect();
assert_eq!(got, vec!["B-PER", "E-PER", "O", "S-LOC"]);
}
// E- becomes I-, S- becomes B-.
#[test]
fn bioes_to_bio_basic() {
let bioes = tags(&["B-PER", "E-PER", "O", "S-LOC"]);
let bio = bioes_to_bio(&bioes).expect("conv");
let got: Vec<String> = bio.iter().map(Tag::to_tag_string).collect();
assert_eq!(got, vec!["B-PER", "I-PER", "O", "B-LOC"]);
}
// BIO -> BIOES -> BIO yields the same spans as the original sequence.
#[test]
fn conversion_roundtrip_preserves_spans() {
let bio = tags(&["O", "B-PER", "I-PER", "I-PER", "O", "B-LOC", "O"]);
let bioes = bio_to_bioes(&bio).expect("a");
let back = bioes_to_bio(&bioes).expect("b");
assert_eq!(extract_spans(&bio), extract_spans(&back));
}
// Both converters report EmptyInput for an empty slice.
#[test]
fn convert_rejects_empty() {
assert!(matches!(bio_to_bioes(&[]), Err(SeqError::EmptyInput)));
assert!(matches!(bioes_to_bio(&[]), Err(SeqError::EmptyInput)));
}
#[test]
fn validate_bio_accepts_well_formed() {
let t = tags(&["O", "B-PER", "I-PER", "O", "B-LOC"]);
assert!(validate_bio(&t).is_ok());
}
// I- directly after O has no B-/I- to continue.
#[test]
fn validate_bio_rejects_dangling_inside() {
let t = tags(&["O", "I-PER"]);
assert!(matches!(
validate_bio(&t),
Err(SeqError::GraphInvariantViolated(_))
));
}
// I- must have the same type as the tag it follows.
#[test]
fn validate_bio_rejects_type_mismatch_continuation() {
let t = tags(&["B-PER", "I-LOC"]);
assert!(validate_bio(&t).is_err());
}
// S- (and E-) are BIOES-only and invalid in a BIO sequence.
#[test]
fn validate_bio_rejects_bioes_tags() {
let t = tags(&["S-PER"]);
assert!(validate_bio(&t).is_err());
}
#[test]
fn validate_bioes_accepts_well_formed() {
let t = tags(&["B-PER", "I-PER", "E-PER", "O", "S-LOC"]);
assert!(validate_bioes(&t).is_ok());
}
// B- must be followed by I-/E- of the same type.
#[test]
fn validate_bioes_rejects_begin_without_end() {
let t = tags(&["B-PER", "O"]);
assert!(matches!(
validate_bioes(&t),
Err(SeqError::GraphInvariantViolated(_))
));
}
// E- with no preceding B-/I- does not close anything.
#[test]
fn validate_bioes_rejects_dangling_end() {
let t = tags(&["E-PER"]);
assert!(validate_bioes(&t).is_err());
}
}