use crate::offset::TextSpan;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
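
/// The kind of reference detected in text.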
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
pub enum ReferenceType {
    WikipediaUrl,
    WikidataUrl,
    DbpediaUrl,
    WebUrl,
    AcademicCitation,
    Doi,
    Arxiv,
    CrossReference,
    FootnoteMarker,
    Isbn,
    SocialHandle,
    Hashtag,
    #[default]
    Unknown,
}

impl ReferenceType {
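    /// Returns true for reference types that point at an external resource
    /// which can be dereferenced (knowledge-base URLs, web URLs, DOIs, arXiv links).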
    pub fn is_resolvable(&self) -> bool {
        matches!(
            self,
            Self::WikipediaUrl
                | Self::WikidataUrl
                | Self::DbpediaUrl
                | Self::WebUrl
                | Self::Doi
                | Self::Arxiv
        )
    }
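
    /// Returns true if this reference links into a knowledge base
    /// (Wikipedia, Wikidata, or DBpedia).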
    pub fn is_kb_link(&self) -> bool {
        matches!(
            self,
            Self::WikipediaUrl | Self::WikidataUrl | Self::DbpediaUrl
        )
    }
}
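
/// A reference detected in text, with `start` and `end` recorded as character
/// offsets into the source string.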
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Reference {
    pub text: String,
    pub start: usize,
    pub end: usize,
    pub reference_type: ReferenceType,
    pub url: Option<String>,
    pub entity_id: Option<String>,
    pub title: Option<String>,
    pub antecedent: Option<String>,
    pub is_resolved: bool,
}

impl Reference {
    pub fn new(text: &str, start: usize, end: usize, ref_type: ReferenceType) -> Self {
        Self {
            text: text.to_string(),
            start,
            end,
            reference_type: ref_type,
            url: None,
            entity_id: None,
            title: None,
            antecedent: None,
            is_resolved: false,
        }
    }

    pub fn with_url(mut self, url: &str) -> Self {
        self.url = Some(url.to_string());
        self
    }

    pub fn with_entity_id(mut self, id: &str) -> Self {
        self.entity_id = Some(id.to_string());
        self
    }

    pub fn with_antecedent(mut self, antecedent: &str) -> Self {
        self.antecedent = Some(antecedent.to_string());
        self
    }

    pub fn mark_resolved(mut self) -> Self {
        self.is_resolved = true;
        self
    }
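
    /// Returns the Wikidata QID (for example "Q937") if this reference points at Wikidata.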
    pub fn wikidata_qid(&self) -> Option<&str> {
        if self.reference_type == ReferenceType::WikidataUrl {
            self.entity_id.as_deref()
        } else {
            None
        }
    }
}
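
/// Extracts references such as URLs, academic citations, DOIs, cross-references,
/// and social handles from plain text.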
#[derive(Debug, Clone, Default)]
pub struct ReferenceExtractor {
    extract_wikipedia: bool,
    extract_web_urls: bool,
    extract_citations: bool,
    extract_social: bool,
}

impl ReferenceExtractor {
    pub fn new() -> Self {
        Self {
            extract_wikipedia: true,
            extract_web_urls: true,
            extract_citations: true,
            extract_social: true,
        }
    }

    pub fn wikipedia(mut self, enabled: bool) -> Self {
        self.extract_wikipedia = enabled;
        self
    }
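
    /// Runs every extraction pass that is enabled (DOIs and cross-references are
    /// always extracted) and returns the matches sorted by start offset.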
    pub fn extract(&self, text: &str) -> Vec<Reference> {
        let mut refs = Vec::new();

        if self.extract_web_urls || self.extract_wikipedia {
            refs.extend(self.extract_urls(text));
        }

        if self.extract_citations {
            refs.extend(self.extract_citations(text));
        }

        refs.extend(self.extract_dois(text));

        refs.extend(self.extract_cross_refs(text));

        if self.extract_social {
            refs.extend(self.extract_social_handles(text));
        }

        refs.sort_by_key(|r| r.start);

        refs
    }

    fn extract_urls(&self, text: &str) -> Vec<Reference> {
        let mut refs = Vec::new();

        let url_pattern = regex::Regex::new(r"https?://[^\s<>\[\]{}|\\^`\x00-\x1f\x7f]+").ok();

        if let Some(re) = url_pattern {
            for m in re.find_iter(text) {
                let url = m.as_str();
                let ref_type = self.classify_url(url);

                if !self.extract_wikipedia && ref_type == ReferenceType::WikipediaUrl {
                    continue;
                }

                let span = TextSpan::from_bytes(text, m.start(), m.end());
                let mut reference =
                    Reference::new(url, span.char_start, span.char_end, ref_type.clone());
                reference.url = Some(url.to_string());

                if let Some(id) = self.extract_entity_id(url, &ref_type) {
                    reference.entity_id = Some(id);
                }

                refs.push(reference);
            }
        }

        refs
    }

    fn classify_url(&self, url: &str) -> ReferenceType {
        if url.contains("wikipedia.org") {
            ReferenceType::WikipediaUrl
        } else if url.contains("wikidata.org") {
            ReferenceType::WikidataUrl
        } else if url.contains("dbpedia.org") {
            ReferenceType::DbpediaUrl
        } else if url.contains("arxiv.org") {
            ReferenceType::Arxiv
        } else if url.contains("doi.org") {
            ReferenceType::Doi
        } else {
            ReferenceType::WebUrl
        }
    }
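
    /// Extracts the entity identifier from a knowledge-base URL: the page title for
    /// Wikipedia, the QID for Wikidata, or the resource name for DBpedia.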
    fn extract_entity_id(&self, url: &str, ref_type: &ReferenceType) -> Option<String> {
        match ref_type {
            ReferenceType::WikipediaUrl => {
                url.split("/wiki/").last().map(|s| s.to_string())
            }
            ReferenceType::WikidataUrl => {
                let re = regex::Regex::new(r"Q\d+").ok()?;
                re.find(url).map(|m| m.as_str().to_string())
            }
            ReferenceType::DbpediaUrl => {
                url.split("/resource/").last().map(|s| s.to_string())
            }
            _ => None,
        }
    }
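
    /// Matches simple author-year citation patterns such as "Smith et al. (2020)"
    /// or "[Smith 2020]".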
    fn extract_citations(&self, text: &str) -> Vec<Reference> {
        let mut refs = Vec::new();

        let citation_patterns = [
            r"\b([A-Z][a-z]+(?:\s+(?:et\s+al\.?|and\s+[A-Z][a-z]+))?),?\s*\(?\d{4}\)?",
            r"\[([A-Z][a-z]+)\s+\d{4}\]",
        ];

        for pattern in &citation_patterns {
            if let Ok(re) = regex::Regex::new(pattern) {
                for m in re.find_iter(text) {
                    let span = TextSpan::from_bytes(text, m.start(), m.end());
                    refs.push(Reference::new(
                        m.as_str(),
                        span.char_start,
                        span.char_end,
                        ReferenceType::AcademicCitation,
                    ));
                }
            }
        }

        refs
    }

    fn extract_dois(&self, text: &str) -> Vec<Reference> {
        let mut refs = Vec::new();

        let doi_pattern = regex::Regex::new(r"10\.\d{4,}/[^\s]+").ok();

        if let Some(re) = doi_pattern {
            for m in re.find_iter(text) {
                let span = TextSpan::from_bytes(text, m.start(), m.end());
                let mut reference = Reference::new(
                    m.as_str(),
                    span.char_start,
                    span.char_end,
                    ReferenceType::Doi,
                );
                reference.url = Some(format!("https://doi.org/{}", m.as_str()));
                refs.push(reference);
            }
        }

        refs
    }

    fn extract_cross_refs(&self, text: &str) -> Vec<Reference> {
        let mut refs = Vec::new();

        let patterns = [
            r"\b[Ss]ection\s+\d+(?:\.\d+)*",
            r"\b[Ff]igure\s+\d+(?:\.\d+)*",
            r"\b[Tt]able\s+\d+(?:\.\d+)*",
            r"\b[Aa]ppendix\s+[A-Z]",
            r"\b[Cc]hapter\s+\d+",
        ];

        for pattern in &patterns {
            if let Ok(re) = regex::Regex::new(pattern) {
                for m in re.find_iter(text) {
                    let span = TextSpan::from_bytes(text, m.start(), m.end());
                    refs.push(Reference::new(
                        m.as_str(),
                        span.char_start,
                        span.char_end,
                        ReferenceType::CrossReference,
                    ));
                }
            }
        }

        refs
    }

    fn extract_social_handles(&self, text: &str) -> Vec<Reference> {
        let mut refs = Vec::new();

        if let Ok(re) = regex::Regex::new(r"@[A-Za-z_][A-Za-z0-9_]{0,14}") {
            for m in re.find_iter(text) {
                let span = TextSpan::from_bytes(text, m.start(), m.end());
                refs.push(Reference::new(
                    m.as_str(),
                    span.char_start,
                    span.char_end,
                    ReferenceType::SocialHandle,
                ));
            }
        }

        if let Ok(re) = regex::Regex::new(r"#[A-Za-z][A-Za-z0-9_]*") {
            for m in re.find_iter(text) {
                let span = TextSpan::from_bytes(text, m.start(), m.end());
                refs.push(Reference::new(
                    m.as_str(),
                    span.char_start,
                    span.char_end,
                    ReferenceType::Hashtag,
                ));
            }
        }

        refs
    }
}
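
/// A reference together with the content, entities, and metadata obtained by resolving it.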
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResolvedReference {
    pub reference: Reference,
    pub content: Option<String>,
    pub entities: Vec<ExtractedEntity>,
    pub metadata: HashMap<String, String>,
    pub error: Option<String>,
}
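
/// An entity mention extracted from resolved reference content.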
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExtractedEntity {
    pub text: String,
    pub entity_type: String,
    pub confidence: f64,
    pub start: usize,
    pub end: usize,
}
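
/// A directed graph of reference links between documents, stored as node names
/// and weighted, typed edges.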
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ReferenceGraph {
    pub nodes: Vec<String>,
    pub edges: Vec<(String, String, ReferenceType, f64)>,
}

impl ReferenceGraph {
    pub fn new() -> Self {
        Self::default()
    }

    pub fn add_document(&mut self, doc_id: &str) {
        if !self.nodes.contains(&doc_id.to_string()) {
            self.nodes.push(doc_id.to_string());
        }
    }

    pub fn add_reference(
        &mut self,
        source_doc: &str,
        target_doc: &str,
        ref_type: ReferenceType,
        weight: f64,
    ) {
        self.add_document(source_doc);
        self.add_document(target_doc);
        self.edges.push((
            source_doc.to_string(),
            target_doc.to_string(),
            ref_type,
            weight,
        ));
    }

    pub fn get_references(&self, doc_id: &str) -> Vec<(&str, &ReferenceType)> {
        self.edges
            .iter()
            .filter(|(src, _, _, _)| src == doc_id)
            .map(|(_, tgt, rt, _)| (tgt.as_str(), rt))
            .collect()
    }

    pub fn get_referrers(&self, doc_id: &str) -> Vec<(&str, &ReferenceType)> {
        self.edges
            .iter()
            .filter(|(_, tgt, _, _)| tgt == doc_id)
            .map(|(src, _, rt, _)| (src.as_str(), rt))
            .collect()
    }
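
    /// Shortest number of reference hops from any root (a document with no referrers)
    /// to `doc_id`, found by breadth-first search; roots have depth 0 and unreachable
    /// documents return `usize::MAX`.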
    pub fn get_depth(&self, doc_id: &str) -> usize {
        use std::collections::{HashSet, VecDeque};

        let roots: HashSet<&str> = self
            .nodes
            .iter()
            .filter(|n| self.get_referrers(n).is_empty())
            .map(|s| s.as_str())
            .collect();

        if roots.contains(doc_id) {
            return 0;
        }

        let mut visited: HashSet<&str> = HashSet::new();
        let mut queue: VecDeque<(&str, usize)> = VecDeque::new();

        for root in &roots {
            queue.push_back((*root, 0));
            visited.insert(*root);
        }

        while let Some((current, depth)) = queue.pop_front() {
            for (target, _) in self.get_references(current) {
                if target == doc_id {
                    return depth + 1;
                }
                if !visited.contains(target) {
                    visited.insert(target);
                    queue.push_back((target, depth + 1));
                }
            }
        }

        usize::MAX
    }
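
    /// Converts the edge list into `(source_index, target_index, weight)` triples,
    /// where the indices are positions in `nodes`.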
    pub fn to_graph_edges(&self) -> Vec<(usize, usize, f64)> {
        let node_index: HashMap<&str, usize> = self
            .nodes
            .iter()
            .enumerate()
            .map(|(i, n)| (n.as_str(), i))
            .collect();

        self.edges
            .iter()
            .filter_map(|(src, tgt, _, weight)| {
                let src_idx = node_index.get(src.as_str())?;
                let tgt_idx = node_index.get(tgt.as_str())?;
                Some((*src_idx, *tgt_idx, *weight))
            })
            .collect()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::offset::TextSpan;

    #[test]
    fn test_wikipedia_url_extraction() {
        let extractor = ReferenceExtractor::new();
        let text = "See https://en.wikipedia.org/wiki/Albert_Einstein for more.";
        let refs = extractor.extract(text);

        assert_eq!(refs.len(), 1);
        assert_eq!(refs[0].reference_type, ReferenceType::WikipediaUrl);
        assert_eq!(refs[0].entity_id, Some("Albert_Einstein".to_string()));
    }

    #[test]
    fn test_reference_offsets_are_character_offsets_with_unicode_prefix() {
        let extractor = ReferenceExtractor::new();
        let text = "Müller: see https://en.wikipedia.org/wiki/Paris for travel tips.";
        let refs = extractor.extract(text);
        assert_eq!(refs.len(), 1);

        let r = &refs[0];
        let extracted = TextSpan::from_chars(text, r.start, r.end).extract(text);
        assert_eq!(extracted, r.text);
    }

    #[test]
    fn test_wikidata_url_extraction() {
        let extractor = ReferenceExtractor::new();
        let text = "Entity: https://www.wikidata.org/wiki/Q937";
        let refs = extractor.extract(text);

        assert_eq!(refs.len(), 1);
        assert_eq!(refs[0].reference_type, ReferenceType::WikidataUrl);
        assert_eq!(refs[0].entity_id, Some("Q937".to_string()));
    }

    #[test]
    fn test_doi_extraction() {
        let extractor = ReferenceExtractor::new();
        let text = "The paper 10.1038/nature12373 shows interesting results.";
        let refs = extractor.extract(text);

        assert_eq!(refs.len(), 1);
        assert_eq!(refs[0].reference_type, ReferenceType::Doi);
        assert!(refs[0].url.as_ref().unwrap().contains("doi.org"));
    }

    #[test]
    fn test_cross_reference_extraction() {
        let extractor = ReferenceExtractor::new();
        let text = "As shown in Section 3.2 and Figure 5, the results are clear.";
        let refs = extractor.extract(text);

        assert_eq!(refs.len(), 2);
        assert!(refs
            .iter()
            .all(|r| r.reference_type == ReferenceType::CrossReference));
    }

    #[test]
    fn test_social_handle_extraction() {
        let extractor = ReferenceExtractor::new();
        let text = "Follow @OpenAI and check #MachineLearning for updates.";
        let refs = extractor.extract(text);

        assert_eq!(refs.len(), 2);
        assert!(refs
            .iter()
            .any(|r| r.reference_type == ReferenceType::SocialHandle));
        assert!(refs
            .iter()
            .any(|r| r.reference_type == ReferenceType::Hashtag));
    }

    #[test]
    fn test_reference_graph() {
        let mut graph = ReferenceGraph::new();

        graph.add_reference("doc1", "wiki_einstein", ReferenceType::WikipediaUrl, 1.0);
        graph.add_reference("doc2", "wiki_einstein", ReferenceType::WikipediaUrl, 1.0);
        graph.add_reference("doc1", "doc3", ReferenceType::WebUrl, 0.5);

        assert_eq!(graph.nodes.len(), 4);
        assert_eq!(graph.edges.len(), 3);

        let refs = graph.get_references("doc1");
        assert_eq!(refs.len(), 2);

        let referrers = graph.get_referrers("wiki_einstein");
        assert_eq!(referrers.len(), 2);
    }

    #[test]
    fn test_reference_depth() {
        let mut graph = ReferenceGraph::new();

        graph.add_document("root");
        graph.add_reference("root", "level1", ReferenceType::WebUrl, 1.0);
        graph.add_reference("level1", "level2", ReferenceType::WebUrl, 1.0);
        graph.add_reference("level2", "level3", ReferenceType::WebUrl, 1.0);

        assert_eq!(graph.get_depth("root"), 0);
        assert_eq!(graph.get_depth("level1"), 1);
        assert_eq!(graph.get_depth("level2"), 2);
        assert_eq!(graph.get_depth("level3"), 3);
    }

    #[test]
    fn test_multiple_references_same_text() {
        let extractor = ReferenceExtractor::new();
        let text = "See https://en.wikipedia.org/wiki/Paris and \
                    https://en.wikipedia.org/wiki/London for travel info.";
        let refs = extractor.extract(text);

        assert_eq!(refs.len(), 2);
        assert!(refs
            .iter()
            .any(|r| r.entity_id == Some("Paris".to_string())));
        assert!(refs
            .iter()
            .any(|r| r.entity_id == Some("London".to_string())));
    }

    #[test]
    fn test_multilingual_wikipedia_urls() {
        let extractor = ReferenceExtractor::new();

        let text_ja = "See https://ja.wikipedia.org/wiki/東京 for info.";
        let refs_ja = extractor.extract(text_ja);
        assert!(!refs_ja.is_empty());

        let text_zh = "See https://zh.wikipedia.org/wiki/北京 for info.";
        let refs_zh = extractor.extract(text_zh);
        assert!(!refs_zh.is_empty());

        let text_ar = "See https://ar.wikipedia.org/wiki/القاهرة for info.";
        let refs_ar = extractor.extract(text_ar);
        assert!(!refs_ar.is_empty());
    }

    #[test]
    fn test_empty_text() {
        let extractor = ReferenceExtractor::new();
        let refs = extractor.extract("");

        assert!(refs.is_empty());
    }

    #[test]
    fn test_no_references() {
        let extractor = ReferenceExtractor::new();
        let text = "This is plain text with no references at all.";
        let refs = extractor.extract(text);

        assert!(refs.is_empty());
    }

    #[test]
    fn test_reference_type_display() {
        assert_ne!(ReferenceType::WikipediaUrl, ReferenceType::WikidataUrl);
        assert_ne!(ReferenceType::Doi, ReferenceType::Arxiv);
        assert_ne!(ReferenceType::SocialHandle, ReferenceType::Hashtag);
    }

    #[test]
    fn test_dbpedia_url_extraction() {
        let extractor = ReferenceExtractor::new();
        let text = "Resource: http://dbpedia.org/resource/Albert_Einstein";
        let refs = extractor.extract(text);

        assert_eq!(refs.len(), 1);
        assert_eq!(refs[0].reference_type, ReferenceType::DbpediaUrl);
    }

    #[test]
    fn test_arxiv_id_extraction() {
        let extractor = ReferenceExtractor::new();
        let text = "The paper arXiv:2301.07041 introduced new methods.";
        let refs = extractor.extract(text);
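
        // Bare arXiv identifiers (e.g. "arXiv:2301.07041") are not matched by the current
        // patterns; only arxiv.org URLs are classified as `ReferenceType::Arxiv`, so this
        // test only checks that extraction runs without panicking on such input.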
        let _ = refs;
    }

    #[test]
    fn test_reference_serialization() {
        let reference = Reference {
            text: "https://example.com".to_string(),
            start: 0,
            end: 19,
            reference_type: ReferenceType::WebUrl,
            url: Some("https://example.com".to_string()),
            entity_id: None,
            title: None,
            antecedent: None,
            is_resolved: true,
        };

        let json = serde_json::to_string(&reference).unwrap();
        let deserialized: Reference = serde_json::from_str(&json).unwrap();

        assert_eq!(deserialized.text, reference.text);
        assert_eq!(deserialized.reference_type, reference.reference_type);
    }

    #[test]
    fn test_reference_graph_empty() {
        let graph = ReferenceGraph::new();

        assert!(graph.nodes.is_empty());
        assert!(graph.edges.is_empty());

        let refs = graph.get_references("nonexistent");
        assert!(refs.is_empty());
    }

    #[test]
    fn test_citation_patterns() {
        let extractor = ReferenceExtractor::new();

        let text = "According to Smith et al. (2020), the results show...";
        let refs = extractor.extract(text);

        assert!(
            refs.is_empty()
                || refs
                    .iter()
                    .any(|r| r.reference_type == ReferenceType::AcademicCitation)
        );
    }
}