langextract_rust/
data.rs

1//! Core data types for the annotation pipeline.
2//!
3//! This module defines the fundamental data structures used throughout the langextract
4//! library, including documents, extractions, and configuration types.
5
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use uuid::Uuid;
9
/// Status indicating how well an extraction aligns with the source text.
///
/// The `rename_all = "snake_case"` attribute makes serde read and write the
/// variants as `match_exact`, `match_greater`, `match_lesser` and
/// `match_fuzzy`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum AlignmentStatus {
    /// Extraction matches the source text exactly
    MatchExact,
    /// Extraction text is longer than the source span
    MatchGreater,
    /// Extraction text is shorter than the source span
    MatchLesser,
    /// Extraction text approximately matches but with differences
    MatchFuzzy,
}
23
/// Represents a character interval in text.
///
/// The interval is half-open: `start_pos` is inclusive and `end_pos` is
/// exclusive. Either bound may be absent (`None`).
///
/// NOTE(review): this type does not establish whether positions count bytes
/// or Unicode scalar values — confirm against the code that produces them.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct CharInterval {
    /// Starting position of the interval (inclusive)
    pub start_pos: Option<usize>,
    /// Ending position of the interval (exclusive)
    pub end_pos: Option<usize>,
}
32
33impl CharInterval {
34    /// Create a new character interval
35    pub fn new(start_pos: Option<usize>, end_pos: Option<usize>) -> Self {
36        Self { start_pos, end_pos }
37    }
38
39    /// Check if this interval overlaps with another
40    pub fn overlaps_with(&self, other: &CharInterval) -> bool {
41        match (self.start_pos, self.end_pos, other.start_pos, other.end_pos) {
42            (Some(s1), Some(e1), Some(s2), Some(e2)) => {
43                // Two intervals overlap if one starts before the other ends
44                s1 < e2 && s2 < e1
45            }
46            _ => false, // If any position is None, consider no overlap
47        }
48    }
49
50    /// Get the length of the interval
51    pub fn length(&self) -> Option<usize> {
52        match (self.start_pos, self.end_pos) {
53            (Some(start), Some(end)) if end >= start => Some(end - start),
54            _ => None,
55        }
56    }
57}
58
/// Token interval information (placeholder for future tokenizer integration).
///
/// Both bounds are optional token indices. This type does not say whether
/// `end_token` is inclusive or exclusive.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct TokenInterval {
    /// Starting token index
    pub start_token: Option<usize>,
    /// Ending token index
    pub end_token: Option<usize>,
}
67
68impl TokenInterval {
69    /// Create a new token interval
70    pub fn new(start_token: Option<usize>, end_token: Option<usize>) -> Self {
71        Self {
72            start_token,
73            end_token,
74        }
75    }
76}
77
/// Represents an extraction extracted from text
///
/// This struct encapsulates an extraction's characteristics and its position
/// within the source text. It can represent diverse information for NLP
/// information extraction tasks.
///
/// Only `extraction_class` and `extraction_text` are mandatory; every other
/// field is optional. `token_interval` is marked `#[serde(skip)]`, so it is
/// never written when serializing and is filled with its default (`None`)
/// when deserializing.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Extraction {
    /// The class or type of the extraction
    pub extraction_class: String,
    /// The extracted text content
    pub extraction_text: String,
    /// Character position in the original text
    pub char_interval: Option<CharInterval>,
    /// How well this extraction aligns with the source
    pub alignment_status: Option<AlignmentStatus>,
    /// Index of this extraction in the list
    pub extraction_index: Option<usize>,
    /// Group index for related extractions
    pub group_index: Option<usize>,
    /// Human-readable description
    pub description: Option<String>,
    /// Additional attributes as key-value pairs (arbitrary JSON values)
    pub attributes: Option<HashMap<String, serde_json::Value>>,
    /// Token position information (excluded from serde input and output)
    #[serde(skip)]
    pub token_interval: Option<TokenInterval>,
}
105
106impl Extraction {
107    /// Create a new extraction with just the class and text
108    pub fn new(extraction_class: String, extraction_text: String) -> Self {
109        Self {
110            extraction_class,
111            extraction_text,
112            char_interval: None,
113            alignment_status: None,
114            extraction_index: None,
115            group_index: None,
116            description: None,
117            attributes: None,
118            token_interval: None,
119        }
120    }
121}
122
123impl Default for Extraction {
124    fn default() -> Self {
125        Self {
126            extraction_class: String::new(),
127            extraction_text: String::new(),
128            char_interval: None,
129            alignment_status: None,
130            extraction_index: None,
131            group_index: None,
132            description: None,
133            attributes: None,
134            token_interval: None,
135        }
136    }
137}
138
139impl Extraction {
140    /// Create a new extraction with character interval
141    pub fn with_char_interval(
142        extraction_class: String,
143        extraction_text: String,
144        char_interval: CharInterval,
145    ) -> Self {
146        Self {
147            extraction_class,
148            extraction_text,
149            char_interval: Some(char_interval),
150            alignment_status: None,
151            extraction_index: None,
152            group_index: None,
153            description: None,
154            attributes: None,
155            token_interval: None,
156        }
157    }
158
159    /// Set the character interval for this extraction
160    pub fn set_char_interval(&mut self, interval: CharInterval) {
161        self.char_interval = Some(interval);
162    }
163
164    /// Set an attribute value
165    pub fn set_attribute(&mut self, key: String, value: serde_json::Value) {
166        if self.attributes.is_none() {
167            self.attributes = Some(HashMap::new());
168        }
169        if let Some(attrs) = &mut self.attributes {
170            attrs.insert(key, value);
171        }
172    }
173
174    /// Get an attribute value
175    pub fn get_attribute(&self, key: &str) -> Option<&serde_json::Value> {
176        self.attributes.as_ref()?.get(key)
177    }
178
179    /// Check if this extraction overlaps with another based on character intervals
180    pub fn overlaps_with(&self, other: &Extraction) -> bool {
181        match (&self.char_interval, &other.char_interval) {
182            (Some(interval1), Some(interval2)) => interval1.overlaps_with(interval2),
183            _ => false,
184        }
185    }
186}
187
/// Document class for input text
///
/// Represents a single document to be processed by the annotation pipeline.
///
/// `document_id` is left out of the serialized form while it is `None`
/// (see the `skip_serializing_if` attribute on the field).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Document {
    /// Raw text content of the document
    pub text: String,
    /// Optional additional context to supplement prompt instructions
    pub additional_context: Option<String>,
    /// Unique identifier for the document; `None` until one is set or generated
    #[serde(skip_serializing_if = "Option::is_none")]
    pub document_id: Option<String>,
}
201
202impl Document {
203    /// Create a new document with the given text
204    pub fn new(text: String) -> Self {
205        Self {
206            text,
207            additional_context: None,
208            document_id: None,
209        }
210    }
211
212    /// Create a new document with text and additional context
213    pub fn with_context(text: String, additional_context: String) -> Self {
214        Self {
215            text,
216            additional_context: Some(additional_context),
217            document_id: None,
218        }
219    }
220
221    /// Get or generate a document ID
222    pub fn get_document_id(&mut self) -> String {
223        if let Some(id) = &self.document_id {
224            id.clone()
225        } else {
226            let id = format!("doc_{}", Uuid::new_v4().simple().to_string()[..8].to_string());
227            self.document_id = Some(id.clone());
228            id
229        }
230    }
231
232    /// Set a specific document ID
233    pub fn set_document_id(&mut self, id: String) {
234        self.document_id = Some(id);
235    }
236}
237
/// Annotated document with extractions
///
/// Represents the result of processing a document through the annotation pipeline.
///
/// All three fields are optional. `document_id` is left out of the serialized
/// form while it is `None` (see the `skip_serializing_if` attribute).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct AnnotatedDocument {
    /// Unique identifier for the document; `None` until one is set or generated
    #[serde(skip_serializing_if = "Option::is_none")]
    pub document_id: Option<String>,
    /// List of extractions found in the document
    pub extractions: Option<Vec<Extraction>>,
    /// Original text content
    pub text: Option<String>,
}
251
252impl AnnotatedDocument {
253    /// Create a new annotated document
254    pub fn new() -> Self {
255        Self {
256            document_id: None,
257            extractions: None,
258            text: None,
259        }
260    }
261
262    /// Create an annotated document with extractions and text
263    pub fn with_extractions(extractions: Vec<Extraction>, text: String) -> Self {
264        Self {
265            document_id: None,
266            extractions: Some(extractions),
267            text: Some(text),
268        }
269    }
270
271    /// Get or generate a document ID
272    pub fn get_document_id(&mut self) -> String {
273        if let Some(id) = &self.document_id {
274            id.clone()
275        } else {
276            let id = format!("doc_{}", Uuid::new_v4().simple().to_string()[..8].to_string());
277            self.document_id = Some(id.clone());
278            id
279        }
280    }
281
282    /// Set the document ID
283    pub fn set_document_id(&mut self, id: String) {
284        self.document_id = Some(id);
285    }
286
287    /// Add an extraction to this document
288    pub fn add_extraction(&mut self, extraction: Extraction) {
289        if self.extractions.is_none() {
290            self.extractions = Some(Vec::new());
291        }
292        if let Some(extractions) = &mut self.extractions {
293            extractions.push(extraction);
294        }
295    }
296
297    /// Get the number of extractions
298    pub fn extraction_count(&self) -> usize {
299        self.extractions.as_ref().map_or(0, |e| e.len())
300    }
301
302    /// Get extractions of a specific class
303    pub fn extractions_by_class(&self, class_name: &str) -> Vec<&Extraction> {
304        self.extractions
305            .as_ref()
306            .map_or(Vec::new(), |extractions| {
307                extractions
308                    .iter()
309                    .filter(|e| e.extraction_class == class_name)
310                    .collect()
311            })
312    }
313}
314
315impl Default for AnnotatedDocument {
316    fn default() -> Self {
317        Self::new()
318    }
319}
320
/// Enumeration of supported output formats.
///
/// The `rename_all = "lowercase"` attribute makes serde read and write the
/// variants as `json` and `yaml`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum FormatType {
    /// JSON output format
    Json,
    /// YAML output format
    Yaml,
}
330
331impl std::fmt::Display for FormatType {
332    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
333        match self {
334            FormatType::Json => write!(f, "json"),
335            FormatType::Yaml => write!(f, "yaml"),
336        }
337    }
338}
339
340impl std::str::FromStr for FormatType {
341    type Err = String;
342
343    fn from_str(s: &str) -> Result<Self, Self::Err> {
344        match s.to_lowercase().as_str() {
345            "json" => Ok(FormatType::Json),
346            "yaml" => Ok(FormatType::Yaml),
347            _ => Err(format!("Invalid format type: {}", s)),
348        }
349    }
350}
351
/// Example data for training/prompting
///
/// Represents a single training example that shows the model how to extract
/// information from text. Unlike [`AnnotatedDocument`], both fields are
/// always present; an example without extractions holds an empty vector.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ExampleData {
    /// The raw input text (sentence, paragraph, etc.)
    pub text: String,
    /// List of extractions that should be found in this text
    pub extractions: Vec<Extraction>,
}
363
364impl ExampleData {
365    /// Create a new example with text and extractions
366    pub fn new(text: String, extractions: Vec<Extraction>) -> Self {
367        Self { text, extractions }
368    }
369
370    /// Create an example with just text (no extractions)
371    pub fn with_text(text: String) -> Self {
372        Self {
373            text,
374            extractions: Vec::new(),
375        }
376    }
377
378    /// Add an extraction to this example
379    pub fn add_extraction(&mut self, extraction: Extraction) {
380        self.extractions.push(extraction);
381    }
382}
383
#[cfg(test)]
mod tests {
    //! Unit tests for the core data types defined in this module.

    use super::*;
    use serde_json::json;

    /// Overlap is symmetric, respects the exclusive end bound, and is always
    /// `false` when a bound is missing.
    #[test]
    fn test_char_interval_overlap() {
        let interval1 = CharInterval::new(Some(0), Some(5));
        let interval2 = CharInterval::new(Some(3), Some(8));
        let interval3 = CharInterval::new(Some(10), Some(15));

        assert!(interval1.overlaps_with(&interval2));
        assert!(interval2.overlaps_with(&interval1));
        assert!(!interval1.overlaps_with(&interval3));
        assert!(!interval3.overlaps_with(&interval1));

        // The end is exclusive, so [0, 5) and [5, 8) only touch.
        let adjacent = CharInterval::new(Some(5), Some(8));
        assert!(!interval1.overlaps_with(&adjacent));
        assert!(!adjacent.overlaps_with(&interval1));

        // A missing bound on either side means "no overlap".
        let open_ended = CharInterval::new(Some(0), None);
        assert!(!open_ended.overlaps_with(&interval1));
        assert!(!interval1.overlaps_with(&open_ended));
    }

    /// Length is `end - start`, and `None` for missing or inverted bounds.
    #[test]
    fn test_char_interval_length() {
        let interval = CharInterval::new(Some(5), Some(10));
        assert_eq!(interval.length(), Some(5));

        let interval_none = CharInterval::new(None, Some(10));
        assert_eq!(interval_none.length(), None);

        let empty = CharInterval::new(Some(7), Some(7));
        assert_eq!(empty.length(), Some(0));

        let inverted = CharInterval::new(Some(10), Some(5));
        assert_eq!(inverted.length(), None);
    }

    #[test]
    fn test_extraction_creation() {
        let extraction = Extraction::new("person".to_string(), "John Doe".to_string());
        assert_eq!(extraction.extraction_class, "person");
        assert_eq!(extraction.extraction_text, "John Doe");
        assert!(extraction.char_interval.is_none());
    }

    #[test]
    fn test_extraction_attributes() {
        let mut extraction = Extraction::new("person".to_string(), "John Doe".to_string());
        extraction.set_attribute("age".to_string(), json!(30));
        extraction.set_attribute("city".to_string(), json!("New York"));

        assert_eq!(extraction.get_attribute("age"), Some(&json!(30)));
        assert_eq!(extraction.get_attribute("city"), Some(&json!("New York")));
        assert_eq!(extraction.get_attribute("nonexistent"), None);
    }

    #[test]
    fn test_extraction_overlap() {
        let mut extraction1 = Extraction::new("person".to_string(), "John".to_string());
        extraction1.set_char_interval(CharInterval::new(Some(0), Some(4)));

        let mut extraction2 = Extraction::new("name".to_string(), "John Doe".to_string());
        extraction2.set_char_interval(CharInterval::new(Some(2), Some(8)));

        let mut extraction3 = Extraction::new("city".to_string(), "Boston".to_string());
        extraction3.set_char_interval(CharInterval::new(Some(10), Some(16)));

        assert!(extraction1.overlaps_with(&extraction2));
        assert!(!extraction1.overlaps_with(&extraction3));

        // An extraction without a character interval never overlaps.
        let unpositioned = Extraction::new("misc".to_string(), "n/a".to_string());
        assert!(!extraction1.overlaps_with(&unpositioned));
        assert!(!unpositioned.overlaps_with(&extraction1));
    }

    #[test]
    fn test_document_id_generation() {
        let mut doc = Document::new("Test text".to_string());
        let id1 = doc.get_document_id();
        let id2 = doc.get_document_id();

        assert_eq!(id1, id2); // Should be same ID when called multiple times
        assert!(id1.starts_with("doc_"));
        assert_eq!(id1.len(), 12); // "doc_" + 8 hex chars
    }

    /// `AnnotatedDocument` generates ids the same way and honors explicit ids.
    #[test]
    fn test_annotated_document_id_generation() {
        let mut doc = AnnotatedDocument::new();
        let id1 = doc.get_document_id();

        assert_eq!(id1, doc.get_document_id());
        assert!(id1.starts_with("doc_"));
        assert_eq!(id1.len(), 12); // "doc_" + 8 hex chars

        doc.set_document_id("custom".to_string());
        assert_eq!(doc.get_document_id(), "custom");
    }

    #[test]
    fn test_annotated_document_operations() {
        let mut doc = AnnotatedDocument::new();
        assert_eq!(doc.extraction_count(), 0);

        let extraction1 = Extraction::new("person".to_string(), "Alice".to_string());
        let extraction2 = Extraction::new("person".to_string(), "Bob".to_string());
        let extraction3 = Extraction::new("location".to_string(), "Paris".to_string());

        doc.add_extraction(extraction1);
        doc.add_extraction(extraction2);
        doc.add_extraction(extraction3);

        assert_eq!(doc.extraction_count(), 3);

        let person_extractions = doc.extractions_by_class("person");
        assert_eq!(person_extractions.len(), 2);

        let location_extractions = doc.extractions_by_class("location");
        assert_eq!(location_extractions.len(), 1);
    }

    #[test]
    fn test_format_type_conversion() {
        assert_eq!("json".parse::<FormatType>().unwrap(), FormatType::Json);
        assert_eq!("yaml".parse::<FormatType>().unwrap(), FormatType::Yaml);
        assert_eq!("JSON".parse::<FormatType>().unwrap(), FormatType::Json);

        assert!("xml".parse::<FormatType>().is_err());

        assert_eq!(FormatType::Json.to_string(), "json");
        assert_eq!(FormatType::Yaml.to_string(), "yaml");
    }

    #[test]
    fn test_example_data() {
        let mut example = ExampleData::with_text("John is 30 years old".to_string());
        assert!(example.extractions.is_empty());

        example.add_extraction(Extraction::new("person".to_string(), "John".to_string()));
        example.add_extraction(Extraction::new("age".to_string(), "30".to_string()));

        assert_eq!(example.extractions.len(), 2);
    }

    #[test]
    fn test_serialization() {
        let extraction = Extraction::new("person".to_string(), "John Doe".to_string());
        let json_str = serde_json::to_string(&extraction).unwrap();
        let deserialized: Extraction = serde_json::from_str(&json_str).unwrap();
        assert_eq!(extraction, deserialized);

        let doc = Document::new("Test text".to_string());
        let json_str = serde_json::to_string(&doc).unwrap();
        let deserialized: Document = serde_json::from_str(&json_str).unwrap();
        assert_eq!(doc, deserialized);

        // Pin the wire names chosen by the `rename_all` attributes.
        assert_eq!(serde_json::to_string(&FormatType::Json).unwrap(), "\"json\"");
        assert_eq!(
            serde_json::to_string(&AlignmentStatus::MatchExact).unwrap(),
            "\"match_exact\""
        );
    }
}