1use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use uuid::Uuid;
9
/// Alignment outcome that can be recorded on an [`Extraction`] through its
/// `alignment_status` field.
///
/// Values are serialized in `snake_case`, so `MatchExact` is written as
/// `"match_exact"`, `MatchFuzzy` as `"match_fuzzy"`, and so on.
// NOTE(review): nothing in this file defines what each variant means; the
// per-variant notes below are read off the names only — confirm against the
// code that assigns them.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum AlignmentStatus {
    /// Presumably an exact match between extraction text and source text.
    MatchExact,
    /// Presumably the matched source span is larger than the extraction text.
    MatchGreater,
    /// Presumably the matched source span is smaller than the extraction text.
    MatchLesser,
    /// Presumably an approximate (non-exact) match.
    MatchFuzzy,
}
23
/// A span of character positions whose endpoints may each be unknown.
///
/// [`CharInterval::overlaps_with`] and [`CharInterval::length`] only produce a
/// positive result when both endpoints are present, and both read the span as
/// `start_pos..end_pos` with `end_pos` excluded.
// NOTE(review): whether positions are byte offsets or Unicode scalar indices
// is not visible in this file — confirm with the code that fills them in.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct CharInterval {
    /// First position of the span, if known.
    pub start_pos: Option<usize>,
    /// Position just past the end of the span, if known.
    pub end_pos: Option<usize>,
}
32
33impl CharInterval {
34 pub fn new(start_pos: Option<usize>, end_pos: Option<usize>) -> Self {
36 Self { start_pos, end_pos }
37 }
38
39 pub fn overlaps_with(&self, other: &CharInterval) -> bool {
41 match (self.start_pos, self.end_pos, other.start_pos, other.end_pos) {
42 (Some(s1), Some(e1), Some(s2), Some(e2)) => {
43 s1 < e2 && s2 < e1
45 }
46 _ => false, }
48 }
49
50 pub fn length(&self) -> Option<usize> {
52 match (self.start_pos, self.end_pos) {
53 (Some(start), Some(end)) if end >= start => Some(end - start),
54 _ => None,
55 }
56 }
57}
58
/// A span of token indices whose endpoints may each be unknown.
// NOTE(review): this file never reads the bounds, so whether `end_token` is
// inclusive or exclusive cannot be determined here — confirm with the tokenizer
// or aligner that produces these values.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct TokenInterval {
    /// Index of the first token in the span, if known.
    pub start_token: Option<usize>,
    /// Index marking the end of the span, if known.
    pub end_token: Option<usize>,
}
67
68impl TokenInterval {
69 pub fn new(start_token: Option<usize>, end_token: Option<usize>) -> Self {
71 Self {
72 start_token,
73 end_token,
74 }
75 }
76}
77
/// A single extracted item: a class label, the extracted text, and optional
/// positional and descriptive metadata.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Extraction {
    /// Class label of the extraction; this is the value
    /// `AnnotatedDocument::extractions_by_class` compares against.
    pub extraction_class: String,
    /// The extracted text itself.
    pub extraction_text: String,
    /// Character span associated with the extraction, if one has been set.
    /// Used by [`Extraction::overlaps_with`].
    pub char_interval: Option<CharInterval>,
    /// Alignment outcome for this extraction, if recorded.
    pub alignment_status: Option<AlignmentStatus>,
    /// Optional index of this extraction.
    // NOTE(review): presumably its position in the producing output — confirm.
    pub extraction_index: Option<usize>,
    /// Optional group index.
    // NOTE(review): presumably ties related extractions together — confirm.
    pub group_index: Option<usize>,
    /// Optional free-form description.
    pub description: Option<String>,
    /// Optional attributes keyed by name with arbitrary JSON values; created
    /// on demand by [`Extraction::set_attribute`].
    pub attributes: Option<HashMap<String, serde_json::Value>>,
    /// Token span associated with the extraction, if any.
    ///
    /// Marked `#[serde(skip)]`: it is never written during serialization and
    /// is filled with its default (`None`) during deserialization, so a
    /// serialize/deserialize round trip drops any value stored here.
    #[serde(skip)]
    pub token_interval: Option<TokenInterval>,
}
105
106impl Extraction {
107 pub fn new(extraction_class: String, extraction_text: String) -> Self {
109 Self {
110 extraction_class,
111 extraction_text,
112 char_interval: None,
113 alignment_status: None,
114 extraction_index: None,
115 group_index: None,
116 description: None,
117 attributes: None,
118 token_interval: None,
119 }
120 }
121}
122
123impl Default for Extraction {
124 fn default() -> Self {
125 Self {
126 extraction_class: String::new(),
127 extraction_text: String::new(),
128 char_interval: None,
129 alignment_status: None,
130 extraction_index: None,
131 group_index: None,
132 description: None,
133 attributes: None,
134 token_interval: None,
135 }
136 }
137}
138
139impl Extraction {
140 pub fn with_char_interval(
142 extraction_class: String,
143 extraction_text: String,
144 char_interval: CharInterval,
145 ) -> Self {
146 Self {
147 extraction_class,
148 extraction_text,
149 char_interval: Some(char_interval),
150 alignment_status: None,
151 extraction_index: None,
152 group_index: None,
153 description: None,
154 attributes: None,
155 token_interval: None,
156 }
157 }
158
159 pub fn set_char_interval(&mut self, interval: CharInterval) {
161 self.char_interval = Some(interval);
162 }
163
164 pub fn set_attribute(&mut self, key: String, value: serde_json::Value) {
166 if self.attributes.is_none() {
167 self.attributes = Some(HashMap::new());
168 }
169 if let Some(attrs) = &mut self.attributes {
170 attrs.insert(key, value);
171 }
172 }
173
174 pub fn get_attribute(&self, key: &str) -> Option<&serde_json::Value> {
176 self.attributes.as_ref()?.get(key)
177 }
178
179 pub fn overlaps_with(&self, other: &Extraction) -> bool {
181 match (&self.char_interval, &other.char_interval) {
182 (Some(interval1), Some(interval2)) => interval1.overlaps_with(interval2),
183 _ => false,
184 }
185 }
186}
187
/// An input document: its text, optional extra context, and an optional id.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Document {
    /// The document text.
    pub text: String,
    /// Optional additional context accompanying the text.
    // NOTE(review): how this context is consumed is not visible in this file.
    pub additional_context: Option<String>,
    /// Identifier of the document, if one has been assigned.
    ///
    /// Left out of the serialized output while it is `None`.
    /// `Document::get_document_id` generates and stores one on demand.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub document_id: Option<String>,
}
201
202impl Document {
203 pub fn new(text: String) -> Self {
205 Self {
206 text,
207 additional_context: None,
208 document_id: None,
209 }
210 }
211
212 pub fn with_context(text: String, additional_context: String) -> Self {
214 Self {
215 text,
216 additional_context: Some(additional_context),
217 document_id: None,
218 }
219 }
220
221 pub fn get_document_id(&mut self) -> String {
223 if let Some(id) = &self.document_id {
224 id.clone()
225 } else {
226 let id = format!("doc_{}", Uuid::new_v4().simple().to_string()[..8].to_string());
227 self.document_id = Some(id.clone());
228 id
229 }
230 }
231
232 pub fn set_document_id(&mut self, id: String) {
234 self.document_id = Some(id);
235 }
236}
237
/// A document together with the extractions found in it.
///
/// Every field is optional; `AnnotatedDocument::new` starts with all of them
/// unset.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct AnnotatedDocument {
    /// Identifier of the document, if one has been assigned.
    ///
    /// Left out of the serialized output while it is `None`.
    /// `AnnotatedDocument::get_document_id` generates and stores one on demand.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub document_id: Option<String>,
    /// Extractions attached to the document; `None` until one is added or the
    /// document is built with `AnnotatedDocument::with_extractions`.
    pub extractions: Option<Vec<Extraction>>,
    /// The document text, if available.
    pub text: Option<String>,
}
251
252impl AnnotatedDocument {
253 pub fn new() -> Self {
255 Self {
256 document_id: None,
257 extractions: None,
258 text: None,
259 }
260 }
261
262 pub fn with_extractions(extractions: Vec<Extraction>, text: String) -> Self {
264 Self {
265 document_id: None,
266 extractions: Some(extractions),
267 text: Some(text),
268 }
269 }
270
271 pub fn get_document_id(&mut self) -> String {
273 if let Some(id) = &self.document_id {
274 id.clone()
275 } else {
276 let id = format!("doc_{}", Uuid::new_v4().simple().to_string()[..8].to_string());
277 self.document_id = Some(id.clone());
278 id
279 }
280 }
281
282 pub fn set_document_id(&mut self, id: String) {
284 self.document_id = Some(id);
285 }
286
287 pub fn add_extraction(&mut self, extraction: Extraction) {
289 if self.extractions.is_none() {
290 self.extractions = Some(Vec::new());
291 }
292 if let Some(extractions) = &mut self.extractions {
293 extractions.push(extraction);
294 }
295 }
296
297 pub fn extraction_count(&self) -> usize {
299 self.extractions.as_ref().map_or(0, |e| e.len())
300 }
301
302 pub fn extractions_by_class(&self, class_name: &str) -> Vec<&Extraction> {
304 self.extractions
305 .as_ref()
306 .map_or(Vec::new(), |extractions| {
307 extractions
308 .iter()
309 .filter(|e| e.extraction_class == class_name)
310 .collect()
311 })
312 }
313}
314
315impl Default for AnnotatedDocument {
316 fn default() -> Self {
317 Self::new()
318 }
319}
320
/// Supported data formats.
///
/// Serialized as the lowercase variant name (`"json"` / `"yaml"`), which is
/// also the text produced by the `Display` implementation and accepted
/// (case-insensitively) by the `FromStr` implementation in this file.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum FormatType {
    /// JSON format.
    Json,
    /// YAML format.
    Yaml,
}
330
331impl std::fmt::Display for FormatType {
332 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
333 match self {
334 FormatType::Json => write!(f, "json"),
335 FormatType::Yaml => write!(f, "yaml"),
336 }
337 }
338}
339
340impl std::str::FromStr for FormatType {
341 type Err = String;
342
343 fn from_str(s: &str) -> Result<Self, Self::Err> {
344 match s.to_lowercase().as_str() {
345 "json" => Ok(FormatType::Json),
346 "yaml" => Ok(FormatType::Yaml),
347 _ => Err(format!("Invalid format type: {}", s)),
348 }
349 }
350}
351
/// A piece of text paired with the extractions associated with it.
// NOTE(review): presumably used as a worked example for an extraction task —
// the consumer is not visible in this file, confirm before relying on this.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ExampleData {
    /// The example text.
    pub text: String,
    /// Extractions belonging to `text`; may be empty.
    pub extractions: Vec<Extraction>,
}
363
364impl ExampleData {
365 pub fn new(text: String, extractions: Vec<Extraction>) -> Self {
367 Self { text, extractions }
368 }
369
370 pub fn with_text(text: String) -> Self {
372 Self {
373 text,
374 extractions: Vec::new(),
375 }
376 }
377
378 pub fn add_extraction(&mut self, extraction: Extraction) {
380 self.extractions.push(extraction);
381 }
382}
383
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds an [`Extraction`] from string slices.
    fn extraction(class: &str, text: &str) -> Extraction {
        Extraction::new(class.to_string(), text.to_string())
    }

    /// Builds a [`CharInterval`] with both bounds present.
    fn span(start: usize, end: usize) -> CharInterval {
        CharInterval::new(Some(start), Some(end))
    }

    /// Overlap is symmetric and false for disjoint spans.
    #[test]
    fn test_char_interval_overlap() {
        let head = span(0, 5);
        let middle = span(3, 8);
        let tail = span(10, 15);

        assert!(head.overlaps_with(&middle));
        assert!(middle.overlaps_with(&head));
        assert!(!head.overlaps_with(&tail));
        assert!(!tail.overlaps_with(&head));
    }

    /// Length needs both bounds to be present.
    #[test]
    fn test_char_interval_length() {
        assert_eq!(span(5, 10).length(), Some(5));

        let missing_start = CharInterval::new(None, Some(10));
        assert_eq!(missing_start.length(), None);
    }

    /// `Extraction::new` stores class and text and leaves the span unset.
    #[test]
    fn test_extraction_creation() {
        let person = extraction("person", "John Doe");

        assert_eq!(person.extraction_class, "person");
        assert_eq!(person.extraction_text, "John Doe");
        assert!(person.char_interval.is_none());
    }

    /// Attributes can be stored and read back; unknown keys give `None`.
    #[test]
    fn test_extraction_attributes() {
        let mut person = extraction("person", "John Doe");
        person.set_attribute("age".to_string(), json!(30));
        person.set_attribute("city".to_string(), json!("New York"));

        assert_eq!(person.get_attribute("age"), Some(&json!(30)));
        assert_eq!(person.get_attribute("city"), Some(&json!("New York")));
        assert_eq!(person.get_attribute("nonexistent"), None);
    }

    /// Extraction overlap follows the overlap of the character spans.
    #[test]
    fn test_extraction_overlap() {
        let mut first_name = extraction("person", "John");
        first_name.set_char_interval(span(0, 4));

        let mut full_name = extraction("name", "John Doe");
        full_name.set_char_interval(span(2, 8));

        let mut city = extraction("city", "Boston");
        city.set_char_interval(span(10, 16));

        assert!(first_name.overlaps_with(&full_name));
        assert!(!first_name.overlaps_with(&city));
    }

    /// A generated document id is stable across calls and has the expected shape.
    #[test]
    fn test_document_id_generation() {
        let mut doc = Document::new("Test text".to_string());
        let first_id = doc.get_document_id();
        let second_id = doc.get_document_id();

        // The id is generated once and then reused.
        assert_eq!(first_id, second_id);
        assert!(first_id.starts_with("doc_"));
        // "doc_" (4 characters) plus 8 characters taken from the UUID.
        assert_eq!(first_id.len(), 12);
    }

    /// Adding, counting and filtering extractions on an annotated document.
    #[test]
    fn test_annotated_document_operations() {
        let mut doc = AnnotatedDocument::new();
        assert_eq!(doc.extraction_count(), 0);

        doc.add_extraction(extraction("person", "Alice"));
        doc.add_extraction(extraction("person", "Bob"));
        doc.add_extraction(extraction("location", "Paris"));

        assert_eq!(doc.extraction_count(), 3);

        let people = doc.extractions_by_class("person");
        assert_eq!(people.len(), 2);

        let places = doc.extractions_by_class("location");
        assert_eq!(places.len(), 1);
    }

    /// Parsing is case-insensitive, rejects unknown names, and `Display`
    /// produces the lowercase name.
    #[test]
    fn test_format_type_conversion() {
        assert_eq!("json".parse::<FormatType>().unwrap(), FormatType::Json);
        assert_eq!("yaml".parse::<FormatType>().unwrap(), FormatType::Yaml);
        assert_eq!("JSON".parse::<FormatType>().unwrap(), FormatType::Json);

        assert!("xml".parse::<FormatType>().is_err());

        assert_eq!(FormatType::Json.to_string(), "json");
        assert_eq!(FormatType::Yaml.to_string(), "yaml");
    }

    /// `with_text` starts empty and `add_extraction` appends.
    #[test]
    fn test_example_data() {
        let mut example = ExampleData::with_text("John is 30 years old".to_string());
        assert_eq!(example.extractions.len(), 0);

        example.add_extraction(extraction("person", "John"));
        example.add_extraction(extraction("age", "30"));

        assert_eq!(example.extractions.len(), 2);
    }

    /// JSON round trips preserve equality for `Extraction` and `Document`.
    #[test]
    fn test_serialization() {
        let original_extraction = extraction("person", "John Doe");
        let encoded = serde_json::to_string(&original_extraction).unwrap();
        let decoded: Extraction = serde_json::from_str(&encoded).unwrap();
        assert_eq!(original_extraction, decoded);

        let original_doc = Document::new("Test text".to_string());
        let encoded = serde_json::to_string(&original_doc).unwrap();
        let decoded: Document = serde_json::from_str(&encoded).unwrap();
        assert_eq!(original_doc, decoded);
    }
}