use crate::AprenderError;
/// Entities found in a piece of text, grouped by category.
///
/// Each field is an independent list for one category; an empty `Vec`
/// means nothing of that kind was recorded.
#[derive(Debug, Clone, PartialEq)]
pub struct Entities {
    /// Email addresses.
    pub emails: Vec<String>,
    /// URLs (the extractor only records `http://` and `https://` links).
    pub urls: Vec<String>,
    /// Phone numbers.
    pub phone_numbers: Vec<String>,
    /// `@user` mentions, stored including the leading `@`.
    pub mentions: Vec<String>,
    /// `#tag` hashtags, stored including the leading `#`.
    pub hashtags: Vec<String>,
    /// Candidate named entities produced by the extractor's capitalized-word
    /// heuristic; left empty when that heuristic is disabled.
    pub named_entities: Vec<String>,
}
impl Entities {
    /// Creates a collection in which every entity category is empty.
    #[must_use]
    pub fn new() -> Self {
        Self {
            emails: vec![],
            urls: vec![],
            phone_numbers: vec![],
            mentions: vec![],
            hashtags: vec![],
            named_entities: vec![],
        }
    }

    /// Returns `true` when no category holds any entity.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.categories()
            .iter()
            .copied()
            .all(<[String]>::is_empty)
    }

    /// Returns the sum of the lengths of all category lists.
    #[must_use]
    pub fn total_count(&self) -> usize {
        self.categories().iter().copied().map(<[String]>::len).sum()
    }

    /// Borrows every category list, in field-declaration order, so aggregate
    /// queries can treat them uniformly.
    fn categories(&self) -> [&[String]; 6] {
        [
            self.emails.as_slice(),
            self.urls.as_slice(),
            self.phone_numbers.as_slice(),
            self.mentions.as_slice(),
            self.hashtags.as_slice(),
            self.named_entities.as_slice(),
        ]
    }
}
impl Default for Entities {
fn default() -> Self {
Self::new()
}
}
/// Rule-based extractor that pulls emails, URLs, phone numbers, mentions,
/// hashtags and, optionally, capitalized-word named entities out of free text.
#[derive(Debug)]
pub struct EntityExtractor {
    /// Whether `extract` runs the capitalized-word heuristic to populate
    /// `Entities::named_entities`; `new()` sets this to `true`.
    extract_named_entities: bool,
}
impl EntityExtractor {
#[must_use]
pub fn new() -> Self {
Self {
extract_named_entities: true,
}
}
#[must_use]
pub fn with_named_entities(mut self, enable: bool) -> Self {
self.extract_named_entities = enable;
self
}
pub fn extract(&self, text: &str) -> Result<Entities, AprenderError> {
let mut entities = Entities::new();
entities.emails = EntityExtractor::extract_emails(text);
entities.urls = EntityExtractor::extract_urls(text);
entities.phone_numbers = self.extract_phone_numbers(text);
entities.mentions = EntityExtractor::extract_mentions(text);
entities.hashtags = EntityExtractor::extract_hashtags(text);
if self.extract_named_entities {
entities.named_entities = EntityExtractor::extract_capitalized_words(text);
}
Ok(entities)
}
fn extract_emails(text: &str) -> Vec<String> {
let mut emails = Vec::new();
for word in text.split_whitespace() {
if EntityExtractor::is_email(word) {
emails.push(word.to_string());
}
}
emails
}
fn is_email(s: &str) -> bool {
let at_count = s.chars().filter(|&c| c == '@').count();
if at_count != 1 {
return false;
}
let parts: Vec<&str> = s.split('@').collect();
if parts.len() != 2 {
return false;
}
let local = parts[0];
let domain = parts[1];
if local.is_empty() {
return false;
}
if !domain.contains('.') || domain.is_empty() {
return false;
}
let domain_parts: Vec<&str> = domain.split('.').collect();
if domain_parts.iter().any(|p| p.is_empty()) {
return false;
}
true
}
fn extract_urls(text: &str) -> Vec<String> {
let mut urls = Vec::new();
for word in text.split_whitespace() {
if word.starts_with("http://") || word.starts_with("https://") {
let url = word.trim_end_matches(|c: char| c.is_ascii_punctuation());
if !url.is_empty() {
urls.push(url.to_string());
}
}
}
urls
}
#[allow(clippy::unused_self)]
fn extract_phone_numbers(&self, text: &str) -> Vec<String> {
let mut phones = Vec::new();
for word in text.split_whitespace() {
if self.is_phone_number(word) {
phones.push(word.to_string());
}
}
phones
}
#[allow(clippy::unused_self)]
fn is_phone_number(&self, s: &str) -> bool {
let digits: String = s.chars().filter(char::is_ascii_digit).collect();
if digits.len() != 10 {
return false;
}
let formats = ["###-###-####", "##########", "(###) ###-####"];
for format in &formats {
if Self::matches_phone_format(s, format) {
return true;
}
}
false
}
fn matches_phone_format(s: &str, format: &str) -> bool {
if s.len() != format.len() {
return false;
}
for (c1, c2) in s.chars().zip(format.chars()) {
match c2 {
'#' => {
if !c1.is_ascii_digit() {
return false;
}
}
_ => {
if c1 != c2 {
return false;
}
}
}
}
true
}
fn extract_mentions(text: &str) -> Vec<String> {
let mut mentions = Vec::new();
for word in text.split_whitespace() {
if word.starts_with('@') && word.len() > 1 {
let mention = word[1..].trim_end_matches(|c: char| c.is_ascii_punctuation());
if !mention.is_empty() && mention.chars().all(|c| c.is_alphanumeric() || c == '_') {
mentions.push(format!("@{mention}"));
}
}
}
mentions
}
fn extract_hashtags(text: &str) -> Vec<String> {
let mut hashtags = Vec::new();
for word in text.split_whitespace() {
if word.starts_with('#') && word.len() > 1 {
let hashtag = word[1..].trim_end_matches(|c: char| c.is_ascii_punctuation());
if !hashtag.is_empty() && hashtag.chars().all(|c| c.is_alphanumeric() || c == '_') {
hashtags.push(format!("#{hashtag}"));
}
}
}
hashtags
}
fn extract_capitalized_words(text: &str) -> Vec<String> {
let mut entities = Vec::new();
for word in text.split_whitespace() {
let clean: String = word.chars().filter(|c| c.is_alphabetic()).collect();
if clean.is_empty() {
continue;
}
let first_char = clean.chars().next().expect("checked is_empty");
if first_char.is_uppercase() && clean.len() > 1 {
if !clean.chars().skip(1).all(char::is_uppercase) {
entities.push(clean);
}
}
}
let mut seen = std::collections::HashSet::new();
entities.retain(|e| seen.insert(e.clone()));
entities
}
}
impl Default for EntityExtractor {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
#[path = "entities_tests.rs"]
mod tests;
#[cfg(test)]
#[path = "entities_contract_falsify.rs"]
mod entities_contract_falsify;