use indexmap::IndexMap;
use petgraph::graph::{DiGraph, NodeIndex};
use regex::Regex;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub struct Concept {
pub text: String,
pub concept_type: ConceptType,
pub frequency: usize,
pub document_ids: HashSet<String>,
pub chunk_ids: HashSet<String>,
}
/// Category of an extracted concept.
///
/// NOTE(review): within this file only `NounPhrase` is ever assigned (in
/// `ConceptGraphBuilder::add_document_concepts`); the remaining variants are
/// presumably reserved for finer-grained classification — confirm with callers.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ConceptType {
/// A (multi-word) noun phrase; the type given to every concept created in this file.
NounPhrase,
/// Presumably a proper name such as a person, place or organisation — TODO confirm.
NamedEntity,
/// Presumably a single frequent word — TODO confirm.
Keyword,
/// Presumably a domain-specific term — TODO confirm.
TechnicalTerm,
}
/// A co-occurrence link between two concepts, derived from the chunks they share.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConceptRelation {
/// Text of the first concept of the pair (the one earlier in the builder's concept map).
pub source: String,
/// Text of the second concept of the pair.
pub target: String,
/// Number of chunks in which both concepts occur (`shared_chunks.len()`).
pub count: usize,
/// Ids of the chunks in which both concepts occur.
pub shared_chunks: Vec<String>,
/// Overlap score: shared chunks divided by the size of the union of both
/// concepts' chunk sets (0.0 when the union is empty).
pub confidence: f32,
}
/// The finished concept graph: the concepts, their co-occurrence relations and
/// a `petgraph` index over them.
///
/// Only `concepts` and `relations` take part in (de)serialization. The two
/// `#[serde(skip)]` fields are omitted when serializing and are filled with
/// their `Default` (empty) values when deserializing, so graph lookups on a
/// freshly deserialized value find nothing until the index is rebuilt.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConceptGraph {
/// All concepts, keyed by concept text, in insertion order.
pub concepts: IndexMap<String, Concept>,
/// Co-occurrence relations between pairs of concepts.
pub relations: Vec<ConceptRelation>,
/// Directed graph whose node weights are concept texts and whose edge
/// weights are relation confidences.
#[serde(skip)]
pub graph: DiGraph<String, f32>,
/// Maps a concept text to its node in `graph`.
#[serde(skip)]
pub concept_to_node: HashMap<String, NodeIndex>,
}
/// Heuristic concept extractor based on two regular expressions plus a
/// word-frequency keyword pass.
pub struct ConceptExtractor {
/// Minimum length a phrase or keyword must have to be kept.
min_length: usize,
/// Maximum number of whitespace-separated words allowed in a phrase.
max_words: usize,
/// Matches a capitalised word followed by one to four further words, each of
/// which may or may not be capitalised (ASCII letters only).
noun_phrase_pattern: Regex,
/// Matches two or more consecutive capitalised words (ASCII letters only).
capitalized_pattern: Regex,
/// Lower-case words that are ignored as keywords and penalised inside phrases.
stopwords: HashSet<String>,
}
impl ConceptExtractor {
pub fn new() -> Self {
Self::with_config(ConceptExtractorConfig::default())
}
pub fn with_config(config: ConceptExtractorConfig) -> Self {
let noun_phrase_pattern =
Regex::new(r"\b[A-Z][a-z]+(?:\s+[A-Z]?[a-z]+){1,4}\b").expect("Invalid regex pattern");
let capitalized_pattern =
Regex::new(r"\b[A-Z][a-z]+(?:\s+[A-Z][a-z]+)+\b").expect("Invalid regex pattern");
Self {
min_length: config.min_length,
max_words: config.max_words,
noun_phrase_pattern,
capitalized_pattern,
stopwords: Self::default_stopwords(),
}
}
pub fn extract_concepts(&self, text: &str) -> Vec<String> {
let mut concepts = Vec::new();
for cap in self.capitalized_pattern.captures_iter(text) {
if let Some(phrase) = cap.get(0) {
let phrase_text = phrase.as_str();
if self.is_valid_concept(phrase_text) {
concepts.push(phrase_text.to_string());
}
}
}
for cap in self.noun_phrase_pattern.captures_iter(text) {
if let Some(phrase) = cap.get(0) {
let phrase_text = phrase.as_str();
if self.is_valid_concept(phrase_text) {
concepts.push(phrase_text.to_string());
}
}
}
let keywords = self.extract_keywords(text);
concepts.extend(keywords);
concepts.sort();
concepts.dedup();
concepts
}
fn is_valid_concept(&self, phrase: &str) -> bool {
if phrase.len() < self.min_length {
return false;
}
let word_count = phrase.split_whitespace().count();
if word_count > self.max_words {
return false;
}
let words: Vec<&str> = phrase.split_whitespace().collect();
let stopword_count = words
.iter()
.filter(|w| self.stopwords.contains(&w.to_lowercase()))
.count();
if stopword_count > words.len() / 2 {
return false;
}
true
}
fn extract_keywords(&self, text: &str) -> Vec<String> {
let mut word_freq: HashMap<String, usize> = HashMap::new();
for word in text.split_whitespace() {
let normalized = word
.to_lowercase()
.trim_matches(|c: char| !c.is_alphanumeric())
.to_string();
if normalized.len() >= self.min_length && !self.stopwords.contains(&normalized) {
*word_freq.entry(normalized).or_insert(0) += 1;
}
}
let mut keywords: Vec<_> = word_freq.into_iter().collect();
keywords.sort_by(|a, b| b.1.cmp(&a.1));
keywords.into_iter()
.take(20) .filter(|(_, freq)| *freq >= 2) .map(|(word, _)| word)
.collect()
}
fn default_stopwords() -> HashSet<String> {
vec![
"the", "a", "an", "and", "or", "but", "in", "on", "at", "to", "for", "of", "with",
"by", "from", "as", "is", "was", "are", "were", "be", "been", "being", "have", "has",
"had", "do", "does", "did", "will", "would", "should", "could", "may", "might", "must",
"can", "this", "that", "these", "those", "it", "its", "i", "you", "he", "she", "we",
"they", "them", "their", "what", "which", "who", "when", "where", "why", "how", "all",
"each", "every", "both", "few", "more", "most", "other", "some", "such", "no", "nor",
"not", "only", "own", "same", "so", "than", "too", "very", "just", "now",
]
.into_iter()
.map(String::from)
.collect()
}
}
/// Tunable limits for [`ConceptExtractor`].
#[derive(Debug, Clone)]
pub struct ConceptExtractorConfig {
    /// Minimum length a phrase or keyword must have to be kept.
    pub min_length: usize,
    /// Maximum number of whitespace-separated words allowed in a phrase.
    pub max_words: usize,
}

impl Default for ConceptExtractorConfig {
    /// Defaults: `min_length` of 3 and `max_words` of 5.
    fn default() -> Self {
        let min_length = 3;
        let max_words = 5;
        ConceptExtractorConfig {
            min_length,
            max_words,
        }
    }
}
/// Accumulates per-document and per-chunk concepts and turns them into a
/// `ConceptGraph`.
pub struct ConceptGraphBuilder {
/// Every concept registered so far, keyed by concept text, in insertion order.
concepts: IndexMap<String, Concept>,
/// The concept list most recently supplied for each document id.
document_concepts: HashMap<String, Vec<String>>,
/// The concept list most recently supplied for each chunk id.
chunk_concepts: HashMap<String, Vec<String>>,
/// Minimum number of shared chunks two concepts need for a relation to be
/// created (1 unless overridden).
co_occurrence_threshold: usize,
}
impl ConceptGraphBuilder {
pub fn new() -> Self {
Self {
concepts: IndexMap::new(),
document_concepts: HashMap::new(),
chunk_concepts: HashMap::new(),
co_occurrence_threshold: 1,
}
}
pub fn with_co_occurrence_threshold(mut self, threshold: usize) -> Self {
self.co_occurrence_threshold = threshold;
self
}
pub fn add_document_concepts(&mut self, document_id: &str, extracted_concepts: Vec<String>) {
self.document_concepts
.insert(document_id.to_string(), extracted_concepts.clone());
for concept_text in extracted_concepts {
let concept = self
.concepts
.entry(concept_text.clone())
.or_insert_with(|| Concept {
text: concept_text.clone(),
concept_type: ConceptType::NounPhrase,
frequency: 0,
document_ids: HashSet::new(),
chunk_ids: HashSet::new(),
});
concept.frequency += 1;
concept.document_ids.insert(document_id.to_string());
}
}
pub fn add_chunk_concepts(&mut self, chunk_id: &str, extracted_concepts: Vec<String>) {
self.chunk_concepts
.insert(chunk_id.to_string(), extracted_concepts.clone());
for concept_text in extracted_concepts {
if let Some(concept) = self.concepts.get_mut(&concept_text) {
concept.chunk_ids.insert(chunk_id.to_string());
}
}
}
pub fn build(self) -> ConceptGraph {
let mut graph = DiGraph::new();
let mut concept_to_node = HashMap::new();
for (concept_text, _) in &self.concepts {
let node_idx = graph.add_node(concept_text.clone());
concept_to_node.insert(concept_text.clone(), node_idx);
}
let relations = self.build_co_occurrence_relations();
for relation in &relations {
if let (Some(&source_idx), Some(&target_idx)) = (
concept_to_node.get(&relation.source),
concept_to_node.get(&relation.target),
) {
graph.add_edge(source_idx, target_idx, relation.confidence);
}
}
ConceptGraph {
concepts: self.concepts,
relations,
graph,
concept_to_node,
}
}
fn build_co_occurrence_relations(&self) -> Vec<ConceptRelation> {
let mut relations = Vec::new();
let concept_list: Vec<_> = self.concepts.keys().collect();
for i in 0..concept_list.len() {
for j in (i + 1)..concept_list.len() {
let concept_a = concept_list[i];
let concept_b = concept_list[j];
if let (Some(concept_a_data), Some(concept_b_data)) =
(self.concepts.get(concept_a), self.concepts.get(concept_b))
{
let shared_chunks: Vec<String> = concept_a_data
.chunk_ids
.intersection(&concept_b_data.chunk_ids)
.cloned()
.collect();
if shared_chunks.len() >= self.co_occurrence_threshold {
let confidence = self.calculate_confidence(
&concept_a_data.chunk_ids,
&concept_b_data.chunk_ids,
&shared_chunks,
);
relations.push(ConceptRelation {
source: concept_a.clone(),
target: concept_b.clone(),
count: shared_chunks.len(),
shared_chunks,
confidence,
});
}
}
}
}
relations
}
fn calculate_confidence(
&self,
chunks_a: &HashSet<String>,
chunks_b: &HashSet<String>,
shared: &[String],
) -> f32 {
let intersection = shared.len();
let union = chunks_a.len() + chunks_b.len() - intersection;
if union == 0 {
return 0.0;
}
intersection as f32 / union as f32
}
}
impl Default for ConceptExtractor {
fn default() -> Self {
Self::new()
}
}
impl Default for ConceptGraphBuilder {
fn default() -> Self {
Self::new()
}
}
impl ConceptGraph {
pub fn get_related_concepts(&self, concept: &str, max_results: usize) -> Vec<String> {
if let Some(&node_idx) = self.concept_to_node.get(concept) {
let mut related = Vec::new();
for edge in self.graph.edges(node_idx) {
related.push((self.graph[edge.target()].clone(), *edge.weight()));
}
related.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
related
.into_iter()
.take(max_results)
.map(|(c, _)| c)
.collect()
} else {
Vec::new()
}
}
pub fn concept_count(&self) -> usize {
self.concepts.len()
}
pub fn relation_count(&self) -> usize {
self.relations.len()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Capitalised multi-word phrases in a sentence should surface as concepts.
    #[test]
    fn test_concept_extraction() {
        let text = "Machine Learning and Artificial Intelligence are transforming Natural Language Processing.";
        let concepts = ConceptExtractor::new().extract_concepts(text);

        assert!(!concepts.is_empty());
        let mentions_ml = concepts
            .iter()
            .any(|c| c.contains("Machine") || c.contains("Learning"));
        assert!(mentions_ml);
    }

    /// Two concepts sharing a chunk should yield two nodes and a relation.
    #[test]
    fn test_concept_graph_building() {
        let pair = || vec!["concept_a".to_string(), "concept_b".to_string()];

        let mut builder = ConceptGraphBuilder::new();
        builder.add_document_concepts("doc1", pair());
        builder.add_chunk_concepts("chunk1", pair());
        let graph = builder.build();

        assert_eq!(graph.concept_count(), 2);
        assert!(!graph.relations.is_empty());
    }

    /// Phrases dominated by stopwords are rejected; ordinary phrases are kept.
    #[test]
    fn test_stopword_filtering() {
        let extractor = ConceptExtractor::new();

        let only_stopwords = extractor.is_valid_concept("the the the");
        let real_phrase = extractor.is_valid_concept("Machine Learning");

        assert!(!only_stopwords);
        assert!(real_phrase);
    }
}