use std::collections::HashSet;
use crate::document::{
DocumentTree, NodeId, NodeReference, RefType, ReferenceExtractor, RetrievalIndex,
};
/// Settings that bound and filter reference following.
#[derive(Debug, Clone)]
pub struct ReferenceConfig {
/// Depth bound for the recursive traversal: a node reached at a depth greater
/// than or equal to this value is not scanned for references.
pub max_depth: usize,
/// Upper bound on the number of followed references returned; the recursive
/// traversal also stops descending once this many have been collected.
pub max_references: usize,
/// Whether `RefType::Page` references may be followed.
pub follow_pages: bool,
/// Whether `RefType::Table` and `RefType::Figure` references may be followed.
pub follow_tables_figures: bool,
/// Confidence threshold; references scoring below it are not followed.
pub min_confidence: f32,
/// Allow-list of reference types; a type absent from this list is never followed.
pub include_types: Vec<RefType>,
}
impl Default for ReferenceConfig {
    /// Builds the baseline configuration: depth bound 3, at most 10 references,
    /// a 0.5 confidence threshold, page and table/figure following switched on,
    /// and Section, Appendix, Table, Figure and Page in the allow-list.
    fn default() -> Self {
        let include_types = vec![
            RefType::Section,
            RefType::Appendix,
            RefType::Table,
            RefType::Figure,
            RefType::Page,
        ];

        Self {
            max_depth: 3,
            max_references: 10,
            follow_pages: true,
            follow_tables_figures: true,
            min_confidence: 0.5,
            include_types,
        }
    }
}
impl ReferenceConfig {
    /// Preset with tighter limits: depth bound 2 and at most 5 references.
    /// Every other setting is taken from the default configuration.
    pub fn conservative() -> Self {
        Self {
            max_references: 5,
            max_depth: 2,
            ..Self::default()
        }
    }

    /// Preset with looser limits: depth bound 5 and at most 20 references.
    /// Every other setting is taken from the default configuration.
    pub fn aggressive() -> Self {
        Self {
            max_references: 20,
            max_depth: 5,
            ..Self::default()
        }
    }

    /// Returns `true` when references of `ref_type` may be followed.
    ///
    /// The type must appear in `include_types`. On top of that, page references
    /// require `follow_pages`, and table/figure references require
    /// `follow_tables_figures`; any other listed type is always allowed.
    pub fn should_follow(&self, ref_type: RefType) -> bool {
        let listed = self.include_types.contains(&ref_type);
        listed
            && match ref_type {
                RefType::Page => self.follow_pages,
                RefType::Table | RefType::Figure => self.follow_tables_figures,
                _ => true,
            }
    }
}
/// A reference discovered during traversal, together with where it was found
/// and what (if anything) it resolved to.
#[derive(Debug, Clone)]
pub struct FollowedReference {
/// Node whose content contained the reference.
pub source_node: NodeId,
/// The reference as returned by the extractor.
pub reference: NodeReference,
/// Node the reference resolved to; `None` when it did not resolve.
pub target_node: Option<NodeId>,
/// Traversal depth at which the reference was found (0 for a starting node).
pub depth: usize,
}
impl FollowedReference {
    /// Returns `true` when this reference has a target node, i.e. `target_node`
    /// is `Some`.
    pub fn is_resolved(&self) -> bool {
        self.target_node.is_some()
    }
}
/// Follows references between document nodes, governed by a `ReferenceConfig`.
#[derive(Debug, Clone)]
pub struct ReferenceFollower {
// Limits and filters applied during traversal.
config: ReferenceConfig,
}
impl Default for ReferenceFollower {
    /// Creates a follower that uses the default `ReferenceConfig`.
    fn default() -> Self {
        let config = ReferenceConfig::default();
        Self::new(config)
    }
}
impl ReferenceFollower {
pub fn new(config: ReferenceConfig) -> Self {
Self { config }
}
pub fn with_defaults() -> Self {
Self::default()
}
pub fn follow_from_node(
&self,
tree: &DocumentTree,
index: &RetrievalIndex,
node_id: NodeId,
) -> Vec<FollowedReference> {
let mut results = Vec::new();
let mut visited = HashSet::new();
visited.insert(node_id);
self.follow_from_node_inner(tree, index, node_id, 0, &mut visited, &mut results);
results.sort_by(|a, b| {
b.reference
.confidence
.partial_cmp(&a.reference.confidence)
.unwrap_or(std::cmp::Ordering::Equal)
});
results.truncate(self.config.max_references);
results
}
fn follow_from_node_inner(
&self,
tree: &DocumentTree,
index: &RetrievalIndex,
node_id: NodeId,
depth: usize,
visited: &mut HashSet<NodeId>,
results: &mut Vec<FollowedReference>,
) {
if depth >= self.config.max_depth {
return;
}
if results.len() >= self.config.max_references {
return;
}
let node = match tree.get(node_id) {
Some(n) => n,
None => return,
};
let _refs = if !node.references.is_empty() {
node.references.clone()
} else {
ReferenceExtractor::extract(&node.content)
};
let resolved_refs = ReferenceExtractor::extract_and_resolve(&node.content, tree, index);
for r#ref in resolved_refs {
if !self.config.should_follow(r#ref.ref_type) {
continue;
}
if r#ref.confidence < self.config.min_confidence {
continue;
}
let followed = FollowedReference {
source_node: node_id,
reference: r#ref.clone(),
target_node: r#ref.target_node,
depth,
};
results.push(followed);
if let Some(target_id) = r#ref.target_node {
if !visited.contains(&target_id) {
visited.insert(target_id);
self.follow_from_node_inner(
tree,
index,
target_id,
depth + 1,
visited,
results,
);
}
}
}
}
pub fn follow_from_nodes(
&self,
tree: &DocumentTree,
index: &RetrievalIndex,
node_ids: &[NodeId],
) -> Vec<FollowedReference> {
let mut all_results = Vec::new();
let mut visited = HashSet::new();
visited.extend(node_ids.iter().copied());
for &node_id in node_ids {
self.follow_from_node_inner(tree, index, node_id, 0, &mut visited, &mut all_results);
}
let mut seen_targets = HashSet::new();
all_results.retain(|r| {
if let Some(target) = r.target_node {
seen_targets.insert(target)
} else {
true }
});
all_results.sort_by(|a, b| {
b.reference
.confidence
.partial_cmp(&a.reference.confidence)
.unwrap_or(std::cmp::Ordering::Equal)
});
all_results.truncate(self.config.max_references);
all_results
}
pub fn find_reachable_nodes(
&self,
tree: &DocumentTree,
index: &RetrievalIndex,
start_node: NodeId,
) -> HashSet<NodeId> {
let mut reachable = HashSet::new();
let mut stack = vec![start_node];
while let Some(node_id) = stack.pop() {
if reachable.contains(&node_id) {
continue;
}
reachable.insert(node_id);
if let Some(node) = tree.get(node_id) {
let _refs = if !node.references.is_empty() {
node.references.clone()
} else {
ReferenceExtractor::extract(&node.content)
};
let resolved = ReferenceExtractor::extract_and_resolve(&node.content, tree, index);
for r#ref in resolved {
if self.config.should_follow(r#ref.ref_type)
&& r#ref.confidence >= self.config.min_confidence
{
if let Some(target_id) = r#ref.target_node {
if !reachable.contains(&target_id) {
stack.push(target_id);
}
}
}
}
}
if reachable.len() >= self.config.max_references * 2 {
break;
}
}
reachable
}
pub fn config(&self) -> &ReferenceConfig {
&self.config
}
}
/// Result of expanding a set of nodes through the references they contain.
/// Field descriptions reflect how `expand_with_references` fills them in.
#[derive(Debug, Clone)]
pub struct ReferenceExpansion {
/// Nodes the expansion started from.
pub original_nodes: Vec<NodeId>,
/// Reference targets that were not among the original nodes, without duplicates.
pub expanded_nodes: Vec<NodeId>,
/// The followed references the expansion was derived from.
pub references: Vec<FollowedReference>,
/// Largest `depth` among `references`, or 0 when there are none.
pub depth: usize,
}
impl ReferenceExpansion {
    /// Returns the original nodes followed by the expanded nodes, in that order.
    pub fn all_nodes(&self) -> Vec<NodeId> {
        self.original_nodes
            .iter()
            .chain(self.expanded_nodes.iter())
            .copied()
            .collect()
    }

    /// Returns only the nodes that the expansion added.
    pub fn new_nodes(&self) -> &[NodeId] {
        self.expanded_nodes.as_slice()
    }

    /// Returns `true` when at least one node was added by the expansion.
    pub fn has_expansion(&self) -> bool {
        !self.expanded_nodes.is_empty()
    }
}
/// Expands `initial_nodes` with the targets of the references they lead to.
///
/// A `ReferenceFollower` is built from `config` (or the default configuration
/// when `None`) and run over all initial nodes. Resolved targets that are not
/// among the initial nodes are gathered, without duplicates, in the order the
/// follower returned them. The reported depth is the largest depth among the
/// followed references, or 0 when none were followed.
pub fn expand_with_references(
    tree: &DocumentTree,
    index: &RetrievalIndex,
    initial_nodes: &[NodeId],
    config: Option<ReferenceConfig>,
) -> ReferenceExpansion {
    let follower = ReferenceFollower::new(config.unwrap_or_default());
    let references = follower.follow_from_nodes(tree, index, initial_nodes);

    // Seed the set with the initial nodes so they never count as "expanded".
    let mut seen: HashSet<NodeId> = initial_nodes.iter().copied().collect();
    let expanded_nodes: Vec<NodeId> = references
        .iter()
        .filter_map(|followed| followed.target_node)
        // `insert` is true only the first time a target is encountered.
        .filter(|target_id| seen.insert(*target_id))
        .collect();

    let depth = references
        .iter()
        .map(|followed| followed.depth)
        .max()
        .unwrap_or(0);

    ReferenceExpansion {
        original_nodes: initial_nodes.to_vec(),
        expanded_nodes,
        references,
        depth,
    }
}
// Unit tests for the configuration presets and the small helper types.
// The traversal methods themselves are not exercised here.
#[cfg(test)]
mod tests {
use super::*;
// The default configuration exposes the documented limits and toggles.
#[test]
fn test_reference_config_default() {
let config = ReferenceConfig::default();
assert_eq!(config.max_depth, 3);
assert_eq!(config.max_references, 10);
assert!(config.follow_pages);
assert!(config.follow_tables_figures);
}
// The conservative preset lowers both limits.
#[test]
fn test_reference_config_conservative() {
let config = ReferenceConfig::conservative();
assert_eq!(config.max_depth, 2);
assert_eq!(config.max_references, 5);
}
// The aggressive preset raises both limits.
#[test]
fn test_reference_config_aggressive() {
let config = ReferenceConfig::aggressive();
assert_eq!(config.max_depth, 5);
assert_eq!(config.max_references, 20);
}
// Types in the default allow-list are followed; `RefType::Unknown` is not in
// that list and must be rejected.
#[test]
fn test_reference_config_should_follow() {
let config = ReferenceConfig::default();
assert!(config.should_follow(RefType::Section));
assert!(config.should_follow(RefType::Appendix));
assert!(config.should_follow(RefType::Table));
assert!(config.should_follow(RefType::Page));
assert!(!config.should_follow(RefType::Unknown));
}
// `is_resolved` mirrors whether `target_node` is set.
#[test]
fn test_followed_reference_is_resolved() {
use indextree::Arena;
// A node id is needed for the fixtures; allocate one from a fresh arena.
let mut arena = Arena::new();
let node = arena.new_node(crate::document::TreeNode::default());
let node_id = NodeId(node);
let resolved = FollowedReference {
source_node: node_id,
reference: NodeReference::new(
"Section 2.1".to_string(),
"2.1".to_string(),
RefType::Section,
0,
),
target_node: Some(node_id),
depth: 0,
};
let unresolved = FollowedReference {
source_node: node_id,
reference: NodeReference::new(
"Section 99".to_string(),
"99".to_string(),
RefType::Section,
0,
),
target_node: None,
depth: 0,
};
assert!(resolved.is_resolved());
assert!(!unresolved.is_resolved());
}
// An expansion with no nodes reports no expansion and an empty node list.
#[test]
fn test_reference_expansion() {
let expansion = ReferenceExpansion {
original_nodes: vec![],
expanded_nodes: vec![],
references: vec![],
depth: 0,
};
assert!(!expansion.has_expansion());
assert_eq!(expansion.all_nodes().len(), 0);
}
}