use regex::Regex;
use serde::{Deserialize, Serialize};
use std::sync::LazyLock;
use super::NodeId;
/// The kind of document object a cross-reference points at.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum RefType {
    /// A numbered section; chapter references are also classified as sections.
    Section,
    /// A lettered appendix ("Appendix G").
    Appendix,
    /// A numbered table.
    Table,
    /// A numbered figure.
    Figure,
    /// A page number.
    Page,
    /// A numbered equation.
    Equation,
    /// A footnote. No extraction pattern in this file produces this kind.
    Footnote,
    /// A code listing. No extraction pattern in this file produces this kind.
    Listing,
    /// A reference of unrecognised kind; displayed as "Reference".
    Unknown,
}
impl std::fmt::Display for RefType {
    /// Writes the human-readable label of the reference kind.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Section => "Section",
            Self::Appendix => "Appendix",
            Self::Table => "Table",
            Self::Figure => "Figure",
            Self::Page => "Page",
            Self::Equation => "Equation",
            Self::Footnote => "Footnote",
            Self::Listing => "Listing",
            // The catch-all kind is shown with a generic label.
            Self::Unknown => "Reference",
        };
        f.write_str(label)
    }
}
/// A single cross-reference found in text (for example "see Section 2.1").
///
/// A reference starts out unresolved (`target_node == None`) and can later be
/// linked to the node it points at.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeReference {
    /// The full matched text, e.g. "see Section 2.1".
    pub ref_text: String,
    /// The identifier captured from the match, e.g. "2.1" or "G".
    pub target_id: String,
    /// Which kind of object the reference names.
    pub ref_type: RefType,
    /// The node this reference resolved to, if any.
    pub target_node: Option<NodeId>,
    /// Confidence of the resolution; 0.0 for references built with [`NodeReference::new`].
    pub confidence: f32,
    /// Start offset of `ref_text` in the text it was extracted from.
    pub position: usize,
}
impl NodeReference {
    /// Builds an unresolved reference with zero confidence.
    pub fn new(ref_text: String, target_id: String, ref_type: RefType, position: usize) -> Self {
        NodeReference {
            position,
            confidence: 0.0,
            target_node: None,
            ref_type,
            target_id,
            ref_text,
        }
    }
    /// Builds a reference that is already linked to `target_node` with the given
    /// `confidence`.
    pub fn resolved(
        ref_text: String,
        target_id: String,
        ref_type: RefType,
        position: usize,
        target_node: NodeId,
        confidence: f32,
    ) -> Self {
        NodeReference {
            target_node: Some(target_node),
            confidence,
            ..Self::new(ref_text, target_id, ref_type, position)
        }
    }
    /// Returns `true` once the reference has been linked to a node.
    pub fn is_resolved(&self) -> bool {
        self.target_node.is_some()
    }
}
// Reference-detection patterns, one table per reference family.
//
// Every pattern starts with `\b` so keywords are only recognised as whole words.
// Without it, "each 3" matched as chapter 3 ("ch 3"), "step 3" as page 3 ("p 3"),
// "config 2" as figure 2 and "freq 4" as equation 4. The `regex` crate has no
// look-around, so explicit word boundaries are used instead of a lookbehind.
//
// Numeric targets use `\d+(?:\.\d+)*` rather than `[\d.]+`. This way a
// sentence-final period ("see Table 1.") is not captured as part of the
// identifier, and a lone "." is never accepted as a target.

/// Section and chapter references ("see Section 2.1", "Sec. 3", "ch. 4").
static SECTION_PATTERNS: LazyLock<Vec<(Regex, RefType)>> = LazyLock::new(|| {
    vec![
        (
            Regex::new(r"(?i)\b(?:see\s+)?(?:section|sec\.?)\s+(\d+(?:\.\d+)*)")
                .expect("section pattern is a valid regex"),
            RefType::Section,
        ),
        (
            Regex::new(r"(?i)\b(?:see\s+)?(?:chapter|ch\.?)\s+(\d+)")
                .expect("chapter pattern is a valid regex"),
            RefType::Section,
        ),
    ]
});
/// Appendix references ("see Appendix G", "app. b").
///
/// `(?i)` makes `[a-z]` match upper-case letters too. The trailing `\b` stops
/// ordinary words from being read as appendix letters ("the app is" -> "app i",
/// "Appendix Also" -> "A"). As a side effect, run-together forms such as
/// "Appendix A1" are not matched.
static APPENDIX_PATTERNS: LazyLock<Vec<(Regex, RefType)>> = LazyLock::new(|| {
    vec![(
        Regex::new(r"(?i)\b(?:see\s+)?(?:appendix|app\.?)\s+([a-z])\b")
            .expect("appendix pattern is a valid regex"),
        RefType::Appendix,
    )]
});
/// Table references ("Table 5.3", "tbl. 2").
static TABLE_PATTERNS: LazyLock<Vec<(Regex, RefType)>> = LazyLock::new(|| {
    vec![(
        Regex::new(r"(?i)\b(?:see\s+)?(?:table|tbl\.?)\s+(\d+(?:\.\d+)*)")
            .expect("table pattern is a valid regex"),
        RefType::Table,
    )]
});
/// Figure references ("Figure 2.1", "fig. 3").
static FIGURE_PATTERNS: LazyLock<Vec<(Regex, RefType)>> = LazyLock::new(|| {
    vec![(
        Regex::new(r"(?i)\b(?:see\s+)?(?:figure|fig\.?)\s+(\d+(?:\.\d+)*)")
            .expect("figure pattern is a valid regex"),
        RefType::Figure,
    )]
});
/// Page references ("see page 42", "p. 7").
static PAGE_PATTERNS: LazyLock<Vec<(Regex, RefType)>> = LazyLock::new(|| {
    vec![(
        Regex::new(r"(?i)\b(?:see\s+)?(?:page|p\.?)\s+(\d+)")
            .expect("page pattern is a valid regex"),
        RefType::Page,
    )]
});
/// Equation references ("Equation 4", "eq. 2.1").
static EQUATION_PATTERNS: LazyLock<Vec<(Regex, RefType)>> = LazyLock::new(|| {
    vec![(
        Regex::new(r"(?i)\b(?:see\s+)?(?:equation|eq\.?)\s+(\d+(?:\.\d+)*)")
            .expect("equation pattern is a valid regex"),
        RefType::Equation,
    )]
});
pub struct ReferenceExtractor;
impl ReferenceExtractor {
pub fn extract(text: &str) -> Vec<NodeReference> {
let mut references = Vec::new();
for (regex, ref_type) in SECTION_PATTERNS.iter() {
for cap in regex.captures_iter(text) {
if let (Some(full_match), Some(target)) = (cap.get(0), cap.get(1)) {
references.push(NodeReference::new(
full_match.as_str().to_string(),
target.as_str().to_string(),
*ref_type,
full_match.start(),
));
}
}
}
for (regex, ref_type) in APPENDIX_PATTERNS.iter() {
for cap in regex.captures_iter(text) {
if let (Some(full_match), Some(target)) = (cap.get(0), cap.get(1)) {
references.push(NodeReference::new(
full_match.as_str().to_string(),
target.as_str().to_uppercase(), *ref_type,
full_match.start(),
));
}
}
}
for (regex, ref_type) in TABLE_PATTERNS.iter() {
for cap in regex.captures_iter(text) {
if let (Some(full_match), Some(target)) = (cap.get(0), cap.get(1)) {
references.push(NodeReference::new(
full_match.as_str().to_string(),
target.as_str().to_string(),
*ref_type,
full_match.start(),
));
}
}
}
for (regex, ref_type) in FIGURE_PATTERNS.iter() {
for cap in regex.captures_iter(text) {
if let (Some(full_match), Some(target)) = (cap.get(0), cap.get(1)) {
references.push(NodeReference::new(
full_match.as_str().to_string(),
target.as_str().to_string(),
*ref_type,
full_match.start(),
));
}
}
}
for (regex, ref_type) in PAGE_PATTERNS.iter() {
for cap in regex.captures_iter(text) {
if let (Some(full_match), Some(target)) = (cap.get(0), cap.get(1)) {
references.push(NodeReference::new(
full_match.as_str().to_string(),
target.as_str().to_string(),
*ref_type,
full_match.start(),
));
}
}
}
for (regex, ref_type) in EQUATION_PATTERNS.iter() {
for cap in regex.captures_iter(text) {
if let (Some(full_match), Some(target)) = (cap.get(0), cap.get(1)) {
references.push(NodeReference::new(
full_match.as_str().to_string(),
target.as_str().to_string(),
*ref_type,
full_match.start(),
));
}
}
}
references.sort_by_key(|r| r.position);
references.dedup_by(|a, b| a.position == b.position);
references
}
pub fn extract_and_resolve(
text: &str,
tree: &super::DocumentTree,
index: &super::RetrievalIndex,
) -> Vec<NodeReference> {
let mut references = Self::extract(text);
for ref_mut in &mut references {
ref_mut.target_node = Self::resolve_reference(ref_mut, tree, index);
if ref_mut.target_node.is_some() {
ref_mut.confidence = 0.8;
}
}
references
}
fn resolve_reference(
r#ref: &NodeReference,
tree: &super::DocumentTree,
index: &super::RetrievalIndex,
) -> Option<NodeId> {
match r#ref.ref_type {
RefType::Section => {
if let Some(node_id) = index.find_by_structure(&r#ref.target_id) {
return Some(node_id);
}
for (structure, &node_id) in index.structures() {
if structure.starts_with(&format!("{}.", r#ref.target_id))
|| structure.as_str() == r#ref.target_id
{
return Some(node_id);
}
}
None
}
RefType::Appendix => {
for node_id in tree.traverse() {
if let Some(node) = tree.get(node_id) {
let title_lower = node.title.to_lowercase();
if title_lower
.starts_with(&format!("appendix {}", r#ref.target_id.to_lowercase()))
|| title_lower == format!("appendix {}", r#ref.target_id.to_lowercase())
{
return Some(node_id);
}
}
}
None
}
RefType::Table => {
for node_id in tree.traverse() {
if let Some(node) = tree.get(node_id) {
let title_lower = node.title.to_lowercase();
if title_lower.contains(&format!("table {}", r#ref.target_id)) {
return Some(node_id);
}
}
}
None
}
RefType::Figure => {
for node_id in tree.traverse() {
if let Some(node) = tree.get(node_id) {
let title_lower = node.title.to_lowercase();
if title_lower.contains(&format!("figure {}", r#ref.target_id))
|| title_lower.contains(&format!("fig {}", r#ref.target_id))
{
return Some(node_id);
}
}
}
None
}
RefType::Page => {
if let Ok(page) = r#ref.target_id.parse::<usize>() {
return index.find_by_page(page);
}
None
}
_ => None,
}
}
}
/// Caches reference resolutions keyed by [`NodeReference::ref_text`].
///
/// Entries are keyed only by the reference text, not by the tree or index used to
/// resolve them. Call [`ReferenceResolver::clear`] before reusing the resolver for
/// a different document.
#[derive(Debug, Clone, Default)]
pub struct ReferenceResolver {
    // A `None` value records a reference that was looked up but did not resolve.
    cache: std::collections::HashMap<String, Option<NodeId>>,
}
impl ReferenceResolver {
    /// Creates a resolver with an empty cache.
    pub fn new() -> Self {
        Self::default()
    }
    /// Resolves every reference whose text is not cached yet and records the
    /// outcome, including failures.
    ///
    /// Texts already in the cache, including repeats within `references`, are not
    /// resolved again.
    pub fn resolve_batch(
        &mut self,
        references: &[NodeReference],
        tree: &super::DocumentTree,
        index: &super::RetrievalIndex,
    ) {
        for reference in references {
            if self.cache.contains_key(&reference.ref_text) {
                continue;
            }
            let outcome = ReferenceExtractor::resolve_reference(reference, tree, index);
            self.cache.insert(reference.ref_text.clone(), outcome);
        }
    }
    /// Looks up a cached resolution.
    ///
    /// Returns `None` if `ref_text` was never resolved, `Some(None)` if resolution
    /// failed, and `Some(Some(id))` on success.
    pub fn get(&self, ref_text: &str) -> Option<Option<NodeId>> {
        self.cache.get(ref_text).copied()
    }
    /// Drops every cached resolution.
    pub fn clear(&mut self) {
        self.cache.clear();
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_extract_section_references() {
        let text = "For details, see Section 2.1 and Section 3.2.1.";
        let refs = ReferenceExtractor::extract(text);
        assert!(
            refs.iter()
                .any(|r| r.ref_type == RefType::Section && r.target_id == "2.1"),
            "expected Section 2.1 in {refs:?}"
        );
        // The second reference ends the sentence, so tolerate a captured trailing period.
        assert!(
            refs.iter().any(|r| r.ref_type == RefType::Section
                && r.target_id.trim_end_matches('.') == "3.2.1"),
            "expected Section 3.2.1 in {refs:?}"
        );
    }

    #[test]
    fn test_extract_appendix_references() {
        let text = "See Appendix G for more information.";
        let refs = ReferenceExtractor::extract(text);
        assert!(
            refs.iter()
                .any(|r| r.ref_type == RefType::Appendix && r.target_id == "G"),
            "expected Appendix G in {refs:?}"
        );
    }

    #[test]
    fn test_extract_appendix_letter_is_uppercased() {
        let refs = ReferenceExtractor::extract("see appendix b");
        assert!(
            refs.iter()
                .any(|r| r.ref_type == RefType::Appendix && r.target_id == "B"),
            "expected Appendix B in {refs:?}"
        );
    }

    #[test]
    fn test_extract_table_references() {
        let text = "The data is shown in Table 5.3 and Table 1.";
        let refs = ReferenceExtractor::extract(text);
        assert!(
            refs.iter()
                .any(|r| r.ref_type == RefType::Table && r.target_id == "5.3"),
            "expected Table 5.3 in {refs:?}"
        );
        assert!(
            refs.iter().any(
                |r| r.ref_type == RefType::Table && (r.target_id == "1" || r.target_id == "1.")
            ),
            "expected Table 1 in {refs:?}"
        );
    }

    #[test]
    fn test_extract_figure_references() {
        let text = "As shown in Figure 2.1 and fig. 3.";
        let refs = ReferenceExtractor::extract(text);
        assert!(
            refs.iter()
                .any(|r| r.ref_type == RefType::Figure && r.target_id == "2.1"),
            "expected Figure 2.1 in {refs:?}"
        );
        assert!(
            refs.iter()
                .any(|r| r.ref_type == RefType::Figure
                    && (r.target_id == "3" || r.target_id == "3.")),
            "expected Figure 3 in {refs:?}"
        );
    }

    #[test]
    fn test_extract_page_references() {
        let text = "See page 42 for details.";
        let refs = ReferenceExtractor::extract(text);
        assert!(
            refs.iter()
                .any(|r| r.ref_type == RefType::Page && r.target_id == "42"),
            "expected page 42 in {refs:?}"
        );
    }

    #[test]
    fn test_extract_mixed_references() {
        let text = "For details, see Section 2.1, Appendix G, and Table 5.3.";
        let refs = ReferenceExtractor::extract(text);
        assert_eq!(refs.len(), 3, "unexpected references: {refs:?}");
        assert!(refs.iter().any(|r| r.ref_type == RefType::Section));
        assert!(refs.iter().any(|r| r.ref_type == RefType::Appendix));
        assert!(refs.iter().any(|r| r.ref_type == RefType::Table));
    }

    #[test]
    fn test_extract_empty_text() {
        assert!(ReferenceExtractor::extract("").is_empty());
    }

    #[test]
    fn test_extract_is_sorted_by_position() {
        let text = "See Table 4, then Figure 2 and page 7.";
        let refs = ReferenceExtractor::extract(text);
        let kinds: Vec<RefType> = refs.iter().map(|r| r.ref_type).collect();
        assert_eq!(
            kinds,
            vec![RefType::Table, RefType::Figure, RefType::Page],
            "unexpected references: {refs:?}"
        );
        assert!(refs.windows(2).all(|w| w[0].position < w[1].position));
        assert!(refs.iter().all(|r| !r.is_resolved()));
    }

    #[test]
    fn test_ref_type_display() {
        assert_eq!(format!("{}", RefType::Section), "Section");
        assert_eq!(format!("{}", RefType::Appendix), "Appendix");
        assert_eq!(format!("{}", RefType::Table), "Table");
        assert_eq!(format!("{}", RefType::Figure), "Figure");
        assert_eq!(format!("{}", RefType::Page), "Page");
        assert_eq!(format!("{}", RefType::Equation), "Equation");
        assert_eq!(format!("{}", RefType::Footnote), "Footnote");
        assert_eq!(format!("{}", RefType::Listing), "Listing");
        assert_eq!(format!("{}", RefType::Unknown), "Reference");
    }

    #[test]
    fn test_node_reference_is_resolved() {
        let unresolved = NodeReference::new(
            "Section 2.1".to_string(),
            "2.1".to_string(),
            RefType::Section,
            0,
        );
        assert!(!unresolved.is_resolved());
        assert!(unresolved.target_node.is_none());
        assert_eq!(unresolved.confidence, 0.0);
    }
}