use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use uuid::Uuid;
/// Alignment status attached to an [`Extraction`] (see `Extraction::alignment_status`).
///
/// Serialized in `snake_case`, e.g. `MatchExact` <-> `"match_exact"`.
/// NOTE(review): the precise meaning of each variant is decided by the code that assigns it,
/// which is not in this file — confirm there before relying on a particular interpretation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum AlignmentStatus {
    MatchExact,
    MatchGreater,
    MatchLesser,
    MatchFuzzy,
}
/// A span of positions in a text where either bound may be unknown.
///
/// The methods on this type treat the bounds as a half-open range `[start_pos, end_pos)`:
/// the length is `end_pos - start_pos`, and two spans that merely touch do not overlap.
/// `Default` yields an interval with both bounds unknown.
/// NOTE(review): whether positions are byte offsets or `char` indices is decided by the code
/// that fills them in, which is not visible here — confirm before using them to slice text.
#[derive(Debug, Clone, Default, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct CharInterval {
    /// Inclusive start position, if known.
    pub start_pos: Option<usize>,
    /// Exclusive end position, if known.
    pub end_pos: Option<usize>,
}
impl CharInterval {
pub fn new(start_pos: Option<usize>, end_pos: Option<usize>) -> Self {
Self { start_pos, end_pos }
}
pub fn overlaps_with(&self, other: &CharInterval) -> bool {
match (self.start_pos, self.end_pos, other.start_pos, other.end_pos) {
(Some(s1), Some(e1), Some(s2), Some(e2)) => {
s1 < e2 && s2 < e1
}
_ => false, }
}
pub fn length(&self) -> Option<usize> {
match (self.start_pos, self.end_pos) {
(Some(start), Some(end)) if end >= start => Some(end - start),
_ => None,
}
}
}
/// A span of token indices where either bound may be unknown.
///
/// `Default` yields an interval with both bounds unknown.
/// NOTE(review): nothing in this file fixes whether `end_token` is inclusive or exclusive,
/// or which tokenizer the indices refer to — confirm with the code that populates it.
#[derive(Debug, Clone, Default, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct TokenInterval {
    /// Index of the first token, if known.
    pub start_token: Option<usize>,
    /// Index marking the end of the span, if known.
    pub end_token: Option<usize>,
}
impl TokenInterval {
pub fn new(start_token: Option<usize>, end_token: Option<usize>) -> Self {
Self {
start_token,
end_token,
}
}
}
/// A single extracted item: a class label, the extracted text, and optional metadata.
///
/// All fields except `token_interval` take part in (de)serialization, in declaration order.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Extraction {
/// Class/category label of the extraction (matched exactly by
/// `AnnotatedDocument::extractions_by_class`).
pub extraction_class: String,
/// The extracted text itself.
pub extraction_text: String,
/// Position of the extraction in the source text, if known; used by `overlaps_with`.
pub char_interval: Option<CharInterval>,
/// Alignment status, if one has been assigned.
pub alignment_status: Option<AlignmentStatus>,
/// Optional index. NOTE(review): its meaning is set by producers outside this file — confirm.
pub extraction_index: Option<usize>,
/// Optional group index. NOTE(review): meaning set by producers outside this file — confirm.
pub group_index: Option<usize>,
/// Optional free-form description.
pub description: Option<String>,
/// Arbitrary JSON-valued attributes keyed by name; see `set_attribute` / `get_attribute`.
pub attributes: Option<HashMap<String, serde_json::Value>>,
/// Token span of the extraction. Never serialized; always `None` after deserialization.
#[serde(skip)]
pub token_interval: Option<TokenInterval>,
}
impl Extraction {
pub fn new(extraction_class: String, extraction_text: String) -> Self {
Self {
extraction_class,
extraction_text,
char_interval: None,
alignment_status: None,
extraction_index: None,
group_index: None,
description: None,
attributes: None,
token_interval: None,
}
}
}
impl Default for Extraction {
fn default() -> Self {
Self {
extraction_class: String::new(),
extraction_text: String::new(),
char_interval: None,
alignment_status: None,
extraction_index: None,
group_index: None,
description: None,
attributes: None,
token_interval: None,
}
}
}
impl Extraction {
    /// Creates an extraction located at `char_interval`; all other optional fields are unset.
    pub fn with_char_interval(
        extraction_class: String,
        extraction_text: String,
        char_interval: CharInterval,
    ) -> Self {
        Self {
            extraction_class,
            extraction_text,
            char_interval: Some(char_interval),
            ..Self::default()
        }
    }

    /// Sets (or replaces) the character interval.
    pub fn set_char_interval(&mut self, interval: CharInterval) {
        self.char_interval = Some(interval);
    }

    /// Inserts or replaces the attribute `key`, creating the attribute map on first use.
    pub fn set_attribute(&mut self, key: String, value: serde_json::Value) {
        self.attributes
            .get_or_insert_with(HashMap::new)
            .insert(key, value);
    }

    /// Looks up attribute `key`; `None` if there is no attribute map or no such key.
    pub fn get_attribute(&self, key: &str) -> Option<&serde_json::Value> {
        self.attributes.as_ref()?.get(key)
    }

    /// Returns `true` when both extractions have a character interval and those intervals
    /// overlap (per `CharInterval::overlaps_with`); `false` if either interval is missing.
    pub fn overlaps_with(&self, other: &Extraction) -> bool {
        match (&self.char_interval, &other.char_interval) {
            (Some(mine), Some(theirs)) => mine.overlaps_with(theirs),
            _ => false,
        }
    }
}
/// An input text with optional additional context and an optional identifier.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Document {
/// The document text.
pub text: String,
/// Optional extra context supplied alongside the text; always serialized, even when `None`.
pub additional_context: Option<String>,
/// Optional identifier; omitted from serialized output when `None`.
/// `get_document_id` fills this in with a random `doc_…` id on first call if unset.
#[serde(skip_serializing_if = "Option::is_none")]
pub document_id: Option<String>,
}
impl Document {
    /// Creates a document with the given text, no additional context, and no id.
    pub fn new(text: String) -> Self {
        Self {
            text,
            additional_context: None,
            document_id: None,
        }
    }

    /// Creates a document with the given text and additional context, and no id.
    pub fn with_context(text: String, additional_context: String) -> Self {
        Self {
            text,
            additional_context: Some(additional_context),
            document_id: None,
        }
    }

    /// Returns this document's id, generating and storing one on first use.
    ///
    /// A generated id is `doc_` followed by the first 8 hex characters of a random (v4)
    /// UUID. Once set, every call returns the same value. Takes `&mut self` because
    /// the generated id is cached in `document_id`.
    pub fn get_document_id(&mut self) -> String {
        self.document_id
            .get_or_insert_with(|| {
                // The simple (hyphen-free) UUID form is ASCII hex, so byte slicing is safe.
                let hex = Uuid::new_v4().simple().to_string();
                format!("doc_{}", &hex[..8])
            })
            .clone()
    }

    /// Overwrites the document id.
    pub fn set_document_id(&mut self, id: String) {
        self.document_id = Some(id);
    }
}
/// A text together with the extractions attached to it, plus an optional identifier.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct AnnotatedDocument {
/// Optional identifier; omitted from serialized output when `None`.
/// `get_document_id` fills this in with a random `doc_…` id on first call if unset.
#[serde(skip_serializing_if = "Option::is_none")]
pub document_id: Option<String>,
/// Extractions in insertion order; `None` until one is added (or set explicitly).
pub extractions: Option<Vec<Extraction>>,
/// The annotated text, if provided.
pub text: Option<String>,
}
impl AnnotatedDocument {
    /// Creates an empty annotated document: no id, no extractions, no text.
    pub fn new() -> Self {
        Self {
            document_id: None,
            extractions: None,
            text: None,
        }
    }

    /// Creates an annotated document from a list of extractions and its text, without an id.
    pub fn with_extractions(extractions: Vec<Extraction>, text: String) -> Self {
        Self {
            document_id: None,
            extractions: Some(extractions),
            text: Some(text),
        }
    }

    /// Returns this document's id, generating and storing one on first use.
    ///
    /// A generated id is `doc_` followed by the first 8 hex characters of a random (v4)
    /// UUID. Once set, every call returns the same value. Takes `&mut self` because
    /// the generated id is cached in `document_id`.
    pub fn get_document_id(&mut self) -> String {
        self.document_id
            .get_or_insert_with(|| {
                // The simple (hyphen-free) UUID form is ASCII hex, so byte slicing is safe.
                let hex = Uuid::new_v4().simple().to_string();
                format!("doc_{}", &hex[..8])
            })
            .clone()
    }

    /// Overwrites the document id.
    pub fn set_document_id(&mut self, id: String) {
        self.document_id = Some(id);
    }

    /// Appends an extraction, creating the extraction list if it is absent.
    pub fn add_extraction(&mut self, extraction: Extraction) {
        self.extractions
            .get_or_insert_with(Vec::new)
            .push(extraction);
    }

    /// Number of extractions; `0` when the extraction list is absent.
    pub fn extraction_count(&self) -> usize {
        self.extractions.as_ref().map_or(0, Vec::len)
    }

    /// Returns the extractions whose `extraction_class` equals `class_name` exactly,
    /// in insertion order. The result is empty when there are no extractions.
    pub fn extractions_by_class(&self, class_name: &str) -> Vec<&Extraction> {
        // `Option::iter().flatten()` walks the inner Vec when present and nothing otherwise.
        self.extractions
            .iter()
            .flatten()
            .filter(|extraction| extraction.extraction_class == class_name)
            .collect()
    }
}
impl Default for AnnotatedDocument {
fn default() -> Self {
Self::new()
}
}
/// A data format: JSON or YAML.
///
/// Serde uses the lowercase names (`"json"`, `"yaml"`) exactly. The `FromStr` impl is more
/// lenient and accepts any letter case.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum FormatType {
    /// JSON, serialized as `"json"`.
    Json,
    /// YAML, serialized as `"yaml"`.
    Yaml,
}
impl std::fmt::Display for FormatType {
    /// Writes the lowercase name (`json` / `yaml`), honouring width, fill and alignment flags.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            FormatType::Json => "json",
            FormatType::Yaml => "yaml",
        };
        // `pad` (unlike `write!`) applies the caller's width/fill/alignment/precision flags,
        // so e.g. `format!("{:>6}", fmt)` lines up like a plain `&str` would.
        f.pad(name)
    }
}
impl std::str::FromStr for FormatType {
    type Err = String;

    /// Parses `json` or `yaml` in any letter case; other input yields a descriptive error.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let normalized = s.to_lowercase();
        if normalized == "json" {
            Ok(FormatType::Json)
        } else if normalized == "yaml" {
            Ok(FormatType::Yaml)
        } else {
            // The error echoes the caller's original (un-normalized) input.
            Err(format!("Invalid format type: {}", s))
        }
    }
}
/// A text paired with a list of extractions for it.
/// NOTE(review): presumably used as worked examples for an extraction task — confirm with callers.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ExampleData {
/// The example text.
pub text: String,
/// Extractions associated with `text`, in insertion order.
pub extractions: Vec<Extraction>,
}
impl ExampleData {
pub fn new(text: String, extractions: Vec<Extraction>) -> Self {
Self { text, extractions }
}
pub fn with_text(text: String) -> Self {
Self {
text,
extractions: Vec::new(),
}
}
pub fn add_extraction(&mut self, extraction: Extraction) {
self.extractions.push(extraction);
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_char_interval_overlap() {
        let interval1 = CharInterval::new(Some(0), Some(5));
        let interval2 = CharInterval::new(Some(3), Some(8));
        let interval3 = CharInterval::new(Some(10), Some(15));
        assert!(interval1.overlaps_with(&interval2));
        assert!(interval2.overlaps_with(&interval1));
        assert!(!interval1.overlaps_with(&interval3));
        assert!(!interval3.overlaps_with(&interval1));
    }

    #[test]
    fn test_char_interval_overlap_edges() {
        // Intervals are half-open, so spans that only touch do not overlap.
        let left = CharInterval::new(Some(0), Some(5));
        let touching = CharInterval::new(Some(5), Some(8));
        assert!(!left.overlaps_with(&touching));
        assert!(!touching.overlaps_with(&left));
        // A missing bound on either side means "no overlap".
        let open = CharInterval::new(Some(0), None);
        assert!(!left.overlaps_with(&open));
        assert!(!open.overlaps_with(&left));
    }

    #[test]
    fn test_char_interval_length() {
        let interval = CharInterval::new(Some(5), Some(10));
        assert_eq!(interval.length(), Some(5));
        let interval_none = CharInterval::new(None, Some(10));
        assert_eq!(interval_none.length(), None);
        // An inverted interval is reported as having no length rather than underflowing.
        let inverted = CharInterval::new(Some(10), Some(5));
        assert_eq!(inverted.length(), None);
    }

    #[test]
    fn test_extraction_creation() {
        let extraction = Extraction::new("person".to_string(), "John Doe".to_string());
        assert_eq!(extraction.extraction_class, "person");
        assert_eq!(extraction.extraction_text, "John Doe");
        assert!(extraction.char_interval.is_none());
    }

    #[test]
    fn test_extraction_attributes() {
        let mut extraction = Extraction::new("person".to_string(), "John Doe".to_string());
        // No attribute map exists yet.
        assert_eq!(extraction.get_attribute("age"), None);
        extraction.set_attribute("age".to_string(), json!(30));
        extraction.set_attribute("city".to_string(), json!("New York"));
        assert_eq!(extraction.get_attribute("age"), Some(&json!(30)));
        assert_eq!(extraction.get_attribute("city"), Some(&json!("New York")));
        assert_eq!(extraction.get_attribute("nonexistent"), None);
    }

    #[test]
    fn test_extraction_overlap() {
        let mut extraction1 = Extraction::new("person".to_string(), "John".to_string());
        extraction1.set_char_interval(CharInterval::new(Some(0), Some(4)));
        let mut extraction2 = Extraction::new("name".to_string(), "John Doe".to_string());
        extraction2.set_char_interval(CharInterval::new(Some(2), Some(8)));
        let mut extraction3 = Extraction::new("city".to_string(), "Boston".to_string());
        extraction3.set_char_interval(CharInterval::new(Some(10), Some(16)));
        assert!(extraction1.overlaps_with(&extraction2));
        assert!(!extraction1.overlaps_with(&extraction3));
        // Without a char interval there is nothing to overlap with.
        let unplaced = Extraction::new("misc".to_string(), "?".to_string());
        assert!(!extraction1.overlaps_with(&unplaced));
    }

    #[test]
    fn test_document_id_generation() {
        let mut doc = Document::new("Test text".to_string());
        let id1 = doc.get_document_id();
        let id2 = doc.get_document_id();
        assert_eq!(id1, id2);
        assert!(id1.starts_with("doc_"));
        // "doc_" plus 8 hex characters.
        assert_eq!(id1.len(), 12);
    }

    #[test]
    fn test_annotated_document_id_generation() {
        let mut doc = AnnotatedDocument::new();
        let id1 = doc.get_document_id();
        assert_eq!(id1, doc.get_document_id());
        assert!(id1.starts_with("doc_"));
        assert_eq!(id1.len(), 12);
        doc.set_document_id("custom".to_string());
        assert_eq!(doc.get_document_id(), "custom");
    }

    #[test]
    fn test_annotated_document_operations() {
        let mut doc = AnnotatedDocument::new();
        assert_eq!(doc.extraction_count(), 0);
        assert!(doc.extractions_by_class("person").is_empty());
        let extraction1 = Extraction::new("person".to_string(), "Alice".to_string());
        let extraction2 = Extraction::new("person".to_string(), "Bob".to_string());
        let extraction3 = Extraction::new("location".to_string(), "Paris".to_string());
        doc.add_extraction(extraction1);
        doc.add_extraction(extraction2);
        doc.add_extraction(extraction3);
        assert_eq!(doc.extraction_count(), 3);
        let person_extractions = doc.extractions_by_class("person");
        assert_eq!(person_extractions.len(), 2);
        let location_extractions = doc.extractions_by_class("location");
        assert_eq!(location_extractions.len(), 1);
        assert!(doc.extractions_by_class("organization").is_empty());
    }

    #[test]
    fn test_format_type_conversion() {
        assert_eq!("json".parse::<FormatType>().unwrap(), FormatType::Json);
        assert_eq!("yaml".parse::<FormatType>().unwrap(), FormatType::Yaml);
        assert_eq!("JSON".parse::<FormatType>().unwrap(), FormatType::Json);
        assert!("xml".parse::<FormatType>().is_err());
        assert_eq!(
            "xml".parse::<FormatType>().unwrap_err(),
            "Invalid format type: xml"
        );
        assert_eq!(FormatType::Json.to_string(), "json");
        assert_eq!(FormatType::Yaml.to_string(), "yaml");
    }

    #[test]
    fn test_example_data() {
        let mut example = ExampleData::with_text("John is 30 years old".to_string());
        assert_eq!(example.extractions.len(), 0);
        example.add_extraction(Extraction::new("person".to_string(), "John".to_string()));
        example.add_extraction(Extraction::new("age".to_string(), "30".to_string()));
        assert_eq!(example.extractions.len(), 2);
    }

    #[test]
    fn test_serialization() {
        let extraction = Extraction::new("person".to_string(), "John Doe".to_string());
        let json_str = serde_json::to_string(&extraction).unwrap();
        let deserialized: Extraction = serde_json::from_str(&json_str).unwrap();
        assert_eq!(extraction, deserialized);

        let doc = Document::new("Test text".to_string());
        let json_str = serde_json::to_string(&doc).unwrap();
        let deserialized: Document = serde_json::from_str(&json_str).unwrap();
        assert_eq!(doc, deserialized);
    }

    #[test]
    fn test_token_interval_not_serialized() {
        // `token_interval` is `#[serde(skip)]`, so it does not survive a round trip.
        let mut extraction = Extraction::new("person".to_string(), "Jane".to_string());
        extraction.token_interval = Some(TokenInterval::new(Some(0), Some(1)));
        let json_str = serde_json::to_string(&extraction).unwrap();
        let restored: Extraction = serde_json::from_str(&json_str).unwrap();
        assert!(restored.token_interval.is_none());
        assert_eq!(restored.extraction_text, "Jane");
    }
}