#[cfg(feature = "rograg")]
use crate::Result;
#[cfg(feature = "rograg")]
use serde::{Deserialize, Serialize};
#[cfg(feature = "rograg")]
use std::collections::HashMap;
#[cfg(feature = "rograg")]
use strum::{Display as StrumDisplay, EnumString};
#[cfg(feature = "rograg")]
use thiserror::Error;
/// Error values describing why a query's intent could not be determined.
///
/// Nothing visible in this file constructs these variants; `classify`
/// reports the same conditions through [`IntentResult`] instead.
#[cfg(feature = "rograg")]
#[derive(Error, Debug)]
pub enum IntentClassificationError {
    /// No intent could be assigned to the query.
    #[error("Unable to classify query intent: {query}")]
    CannotClassify {
        /// The query text that could not be classified.
        query: String,
    },

    /// Several intents were plausible and none could be preferred.
    #[error("Ambiguous intent detected: {intents:?}")]
    AmbiguousIntent {
        /// The competing candidate intents.
        intents: Vec<QueryIntent>,
    },

    /// An intent was found, but with too little confidence to act on.
    #[error("Insufficient confidence for classification: {confidence}")]
    InsufficientConfidence {
        /// The confidence value that was judged insufficient.
        confidence: f32,
    },
}
/// The kinds of intent a query can be assigned.
///
/// String conversion in both directions comes from the derived strum
/// `Display` and `EnumString` implementations.
#[cfg(feature = "rograg")]
#[derive(Debug, Clone, PartialEq, Eq, Hash, StrumDisplay, EnumString, Serialize, Deserialize)]
pub enum QueryIntent {
    /// Fact lookups; cues include "what is/are", "which is", "how many", "how much".
    Factual,
    /// Requests for a definition; cues include "define", "definition of", "meaning of".
    Definitional,
    /// Questions about how things relate; cues include "relationship between",
    /// "connection between", "... related to".
    Relational,
    /// Time-oriented questions; cues include "when did/was", "what year",
    /// "before ... happened".
    Temporal,
    /// Cause-and-effect questions; cues include "why ...", "what caused", "reason for".
    Causal,
    /// Comparisons; cues include "compare ... to", "difference between", "versus"/"vs".
    Comparative,
    /// No patterns are registered for this intent in this file; it is the
    /// default `primary_intent` of [`IntentResult`].
    Exploratory,
    /// Requests for an overview; cues include "summarize", "tell me about", "describe".
    Summary,
    /// Assigned by the classifier when a query matches an inappropriate-content pattern.
    Inappropriate,
    /// Assigned by the classifier when no intent matched, or when a runner-up
    /// scored too close to the top intent.
    Ambiguous,
}
/// Outcome of classifying a single query.
#[cfg(feature = "rograg")]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IntentResult {
    /// The intent finally assigned to the query. This is
    /// [`QueryIntent::Ambiguous`] or [`QueryIntent::Inappropriate`] when the
    /// classifier decided so, regardless of which pattern scored highest.
    pub primary_intent: QueryIntent,

    /// Up to two runner-up intents with their scores, best first.
    pub secondary_intents: Vec<(QueryIntent, f32)>,

    /// Score of the top-ranked intent, at most `1.0`. It is `1.0` for
    /// inappropriate queries and `0.0` when nothing matched.
    pub confidence: f32,

    /// Whether the caller should decline to answer the query.
    pub should_refuse: bool,

    /// Human-readable explanation, present when `should_refuse` is set.
    pub refusal_reason: Option<String>,

    /// Optional hint telling the user how to rephrase a refused query.
    pub suggested_reformulation: Option<String>,

    /// Heuristic complexity of the query in `0.0..=1.0`; `0.0` for
    /// inappropriate queries, which are not analysed further.
    pub complexity_score: f32,
}
#[cfg(feature = "rograg")]
impl Default for IntentResult {
    /// Returns a neutral result: exploratory intent, no runner-ups, zero
    /// confidence and complexity, and no refusal.
    fn default() -> Self {
        IntentResult {
            primary_intent: QueryIntent::Exploratory,
            secondary_intents: Vec::new(),
            confidence: 0.0,
            complexity_score: 0.0,
            should_refuse: false,
            refusal_reason: None,
            suggested_reformulation: None,
        }
    }
}
/// Tunable settings for [`IntentClassifier`].
#[cfg(feature = "rograg")]
#[derive(Debug, Clone)]
pub struct IntentClassificationConfig {
    /// Confidence threshold (default `0.7`).
    ///
    /// NOTE(review): within this file the value is only echoed by
    /// `get_statistics`; `classify` never compares against it.
    pub confidence_threshold: f32,

    /// A top score below this value makes `classify` flag the query for
    /// refusal (default `0.8`).
    pub refusal_threshold: f32,

    /// Whether inappropriate-content patterns are compiled and checked
    /// (default `true`).
    pub enable_inappropriate_detection: bool,

    /// Whether a runner-up scoring above 80% of the top score turns the
    /// result into [`QueryIntent::Ambiguous`] (default `true`).
    pub enable_ambiguity_detection: bool,

    /// Whether rephrasing hints are generated for refused queries
    /// (default `true`).
    pub suggest_reformulations: bool,
}
#[cfg(feature = "rograg")]
impl Default for IntentClassificationConfig {
    /// Returns the stock configuration: thresholds of `0.7` (confidence) and
    /// `0.8` (refusal) with every optional feature switched on.
    fn default() -> Self {
        IntentClassificationConfig {
            // Numeric thresholds.
            confidence_threshold: 0.7,
            refusal_threshold: 0.8,
            // Feature toggles.
            enable_inappropriate_detection: true,
            enable_ambiguity_detection: true,
            suggest_reformulations: true,
        }
    }
}
/// Rule-based query intent classifier driven by keyword lists and regexes.
///
/// `Debug` is derived so the classifier can be logged and embedded in other
/// `#[derive(Debug)]` types; every field already implements it.
#[cfg(feature = "rograg")]
#[derive(Debug)]
pub struct IntentClassifier {
    /// Active settings; replaced through `update_config`.
    config: IntentClassificationConfig,
    /// Patterns registered for each intent, scored against every query.
    intent_patterns: HashMap<QueryIntent, Vec<IntentPattern>>,
    /// Regexes whose match marks a query as inappropriate.
    inappropriate_patterns: Vec<regex::Regex>,
}
/// One scoring rule for an intent: a keyword list plus a regex list.
#[cfg(feature = "rograg")]
#[derive(Debug, Clone)]
struct IntentPattern {
    /// Phrases looked up as plain substrings of the lowercased query.
    keywords: Vec<String>,
    /// Regexes tested against the lowercased query.
    patterns: Vec<regex::Regex>,
    /// Multiplier applied to this rule's combined keyword/regex score.
    weight: f32,
    /// When set, the rule contributes nothing unless every keyword is present.
    requires_all: bool,
}
#[cfg(feature = "rograg")]
impl IntentClassifier {
pub fn new() -> Result<Self> {
let config = IntentClassificationConfig::default();
let mut classifier = Self {
config,
intent_patterns: HashMap::new(),
inappropriate_patterns: vec![],
};
classifier.initialize_patterns()?;
Ok(classifier)
}
/// Creates a classifier that uses the supplied `config`.
///
/// The pattern tables start empty and are filled by `initialize_patterns`.
///
/// # Errors
/// Returns an error if any built-in pattern fails to compile.
pub fn with_config(config: IntentClassificationConfig) -> Result<Self> {
    let intent_patterns = HashMap::new();
    let inappropriate_patterns = Vec::new();
    let mut built = Self {
        config,
        intent_patterns,
        inappropriate_patterns,
    };
    built.initialize_patterns()?;
    Ok(built)
}
fn initialize_patterns(&mut self) -> Result<()> {
self.add_intent_pattern(
QueryIntent::Factual,
IntentPattern {
keywords: ["what", "which", "how many", "how much"]
.iter()
.map(|s| s.to_string())
.collect(),
patterns: vec![
regex::Regex::new(r"\bwhat (?:is|are|was|were)\b")?,
regex::Regex::new(r"\bwhich (?:is|are|was|were)\b")?,
regex::Regex::new(r"\bhow many\b")?,
regex::Regex::new(r"\bhow much\b")?,
],
weight: 1.0,
requires_all: false,
},
);
self.add_intent_pattern(
QueryIntent::Definitional,
IntentPattern {
keywords: ["define", "definition", "meaning", "explain", "what is"]
.iter()
.map(|s| s.to_string())
.collect(),
patterns: vec![
regex::Regex::new(r"\bdefine\b")?,
regex::Regex::new(r"\bdefinition of\b")?,
regex::Regex::new(r"\bmeaning of\b")?,
regex::Regex::new(r"\bexplain what\b")?,
regex::Regex::new(r"\bwhat (?:is|are) (?:the )?(?:concept|idea|notion) of\b")?,
],
weight: 1.0,
requires_all: false,
},
);
self.add_intent_pattern(
QueryIntent::Relational,
IntentPattern {
keywords: ["relationship", "related", "connection", "between", "and"]
.iter()
.map(|s| s.to_string())
.collect(),
patterns: vec![
regex::Regex::new(r"\brelationship between\b")?,
regex::Regex::new(r"\bhow (?:is|are) .+ related to\b")?,
regex::Regex::new(r"\bconnection between\b")?,
regex::Regex::new(r"\b\w+ and \w+\b")?, ],
weight: 1.0,
requires_all: false,
},
);
self.add_intent_pattern(
QueryIntent::Temporal,
IntentPattern {
keywords: ["when", "time", "date", "year", "before", "after", "during"]
.iter()
.map(|s| s.to_string())
.collect(),
patterns: vec![
regex::Regex::new(r"\bwhen (?:did|was|were|will|is|are)\b")?,
regex::Regex::new(r"\bwhat (?:time|date|year)\b")?,
regex::Regex::new(r"\bbefore .+ happened\b")?,
regex::Regex::new(r"\bafter .+ happened\b")?,
regex::Regex::new(r"\bduring .+ period\b")?,
],
weight: 1.0,
requires_all: false,
},
);
self.add_intent_pattern(
QueryIntent::Causal,
IntentPattern {
keywords: ["why", "because", "cause", "reason", "result", "due to"]
.iter()
.map(|s| s.to_string())
.collect(),
patterns: vec![
regex::Regex::new(r"\bwhy (?:did|was|were|is|are|do|does)\b")?,
regex::Regex::new(r"\bwhat (?:caused|causes)\b")?,
regex::Regex::new(r"\breason for\b")?,
regex::Regex::new(r"\bdue to what\b")?,
regex::Regex::new(r"\bwhat led to\b")?,
],
weight: 1.0,
requires_all: false,
},
);
self.add_intent_pattern(
QueryIntent::Comparative,
IntentPattern {
keywords: [
"compare",
"difference",
"versus",
"vs",
"better",
"worse",
"similar",
]
.iter()
.map(|s| s.to_string())
.collect(),
patterns: vec![
regex::Regex::new(r"\bcompare .+ (?:to|with|and)\b")?,
regex::Regex::new(r"\bdifference between\b")?,
regex::Regex::new(r"\b.+ (?:versus|vs) .+\b")?,
regex::Regex::new(r"\bwhich is (?:better|worse)\b")?,
regex::Regex::new(r"\bhow (?:similar|different)\b")?,
],
weight: 1.0,
requires_all: false,
},
);
self.add_intent_pattern(
QueryIntent::Summary,
IntentPattern {
keywords: [
"summarize",
"overview",
"summary",
"tell me about",
"describe",
]
.iter()
.map(|s| s.to_string())
.collect(),
patterns: vec![
regex::Regex::new(r"\bsummarize\b")?,
regex::Regex::new(r"\bgive (?:me )?(?:an )?overview\b")?,
regex::Regex::new(r"\btell me about\b")?,
regex::Regex::new(r"\bdescribe .+\b")?,
regex::Regex::new(r"\bwhat (?:can you tell me )?about\b")?,
],
weight: 1.0,
requires_all: false,
},
);
if self.config.enable_inappropriate_detection {
self.inappropriate_patterns = vec![
regex::Regex::new(r"\b(?:hate|violence|harm|illegal|inappropriate)\b")?,
];
}
Ok(())
}
/// Appends `pattern` to the rules registered for `intent`, creating the
/// intent's rule list on first use.
fn add_intent_pattern(&mut self, intent: QueryIntent, pattern: IntentPattern) {
    let registered = self.intent_patterns.entry(intent).or_insert_with(Vec::new);
    registered.push(pattern);
}
pub fn classify(&self, query: &str) -> Result<IntentResult> {
let query_lower = query.to_lowercase();
if self.config.enable_inappropriate_detection && self.is_inappropriate(&query_lower) {
return Ok(IntentResult {
primary_intent: QueryIntent::Inappropriate,
secondary_intents: vec![],
confidence: 1.0,
should_refuse: true,
refusal_reason: Some("Query contains inappropriate content".to_string()),
suggested_reformulation: None,
complexity_score: 0.0,
});
}
let mut intent_scores: HashMap<QueryIntent, f32> = HashMap::new();
for (intent, patterns) in &self.intent_patterns {
let score = self.calculate_intent_score(&query_lower, patterns);
if score > 0.0 {
intent_scores.insert(intent.clone(), score);
}
}
let mut sorted_intents: Vec<(QueryIntent, f32)> = intent_scores.into_iter().collect();
sorted_intents.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
if sorted_intents.is_empty() {
return Ok(IntentResult {
primary_intent: QueryIntent::Ambiguous,
secondary_intents: vec![],
confidence: 0.0,
should_refuse: true,
refusal_reason: Some("Unable to understand the query intent".to_string()),
suggested_reformulation: self.suggest_reformulation(query),
complexity_score: self.calculate_complexity(&query_lower),
});
}
let (primary_intent, primary_score) = sorted_intents[0].clone();
let secondary_intents: Vec<(QueryIntent, f32)> =
sorted_intents.into_iter().skip(1).take(2).collect();
let is_ambiguous = if self.config.enable_ambiguity_detection {
secondary_intents
.iter()
.any(|(_, score)| *score > primary_score * 0.8)
} else {
false
};
let final_intent = if is_ambiguous {
QueryIntent::Ambiguous
} else {
primary_intent
};
let should_refuse = primary_score < self.config.refusal_threshold || is_ambiguous;
let refusal_reason = if should_refuse {
if is_ambiguous {
Some("Query intent is ambiguous - please be more specific".to_string())
} else {
Some("Insufficient confidence in understanding the query".to_string())
}
} else {
None
};
Ok(IntentResult {
primary_intent: final_intent,
secondary_intents,
confidence: primary_score,
should_refuse,
refusal_reason,
suggested_reformulation: if should_refuse && self.config.suggest_reformulations {
self.suggest_reformulation(query)
} else {
None
},
complexity_score: self.calculate_complexity(&query_lower),
})
}
fn calculate_intent_score(&self, query: &str, patterns: &[IntentPattern]) -> f32 {
let mut total_score = 0.0;
for pattern in patterns {
let mut pattern_score = 0.0;
let keyword_matches = pattern
.keywords
.iter()
.filter(|keyword| query.contains(&keyword.to_lowercase()))
.count();
if pattern.requires_all && keyword_matches != pattern.keywords.len() {
continue; }
if keyword_matches > 0 {
pattern_score += (keyword_matches as f32 / pattern.keywords.len() as f32) * 0.5;
}
let regex_matches = pattern
.patterns
.iter()
.filter(|regex| regex.is_match(query))
.count();
if regex_matches > 0 {
pattern_score += (regex_matches as f32 / pattern.patterns.len() as f32) * 0.5;
}
total_score += pattern_score * pattern.weight;
}
total_score.min(1.0) }
/// Returns `true` when `query` matches at least one inappropriate-content
/// regex; always `false` while the pattern list is empty.
fn is_inappropriate(&self, query: &str) -> bool {
    for blocked in &self.inappropriate_patterns {
        if blocked.is_match(query) {
            return true;
        }
    }
    false
}
/// Estimates how complex `query` is on a `0.0..=1.0` scale.
///
/// The estimate is a weighted sum of four signals:
/// * word count, saturating at 20 words (weight `0.3`);
/// * number of `.`, `?`, `!` characters, at least 1, saturating at 3 (weight `0.2`);
/// * alphabetic characters per word, saturating at 8 (weight `0.2`);
/// * a flat `0.3` when the text contains one of " and ", " or ", " but ",
///   " because ", " since ", " although ".
fn calculate_complexity(&self, query: &str) -> f32 {
    let words = query.split_whitespace().count();
    let terminators = query
        .chars()
        .filter(|c| ['.', '?', '!'].contains(c))
        .count();
    let sentences = terminators.max(1);
    let letters = query.chars().filter(|c| c.is_alphabetic()).count();
    // `max(1)` avoids dividing by zero for an empty query.
    let mean_word_length = letters as f32 / words.max(1) as f32;

    let length_part = (words as f32 / 20.0).min(1.0);
    let sentence_part = (sentences as f32 / 3.0).min(1.0);
    let word_length_part = (mean_word_length / 8.0).min(1.0);

    // Coordinating and subordinating connectives both add the same bonus.
    let connectives = [" and ", " or ", " but ", " because ", " since ", " although "];
    let has_connective = connectives.iter().any(|marker| query.contains(*marker));
    let construct_part = if has_connective { 0.3 } else { 0.0 };

    let weighted =
        length_part * 0.3 + sentence_part * 0.2 + word_length_part * 0.2 + construct_part;
    weighted.min(1.0)
}
/// Produces a rephrasing hint for `query`, or `None` when hints are disabled
/// in the configuration.
///
/// The first matching rule wins: a "tell me about" opener, an " and "
/// conjunction, more than 20 words, missing terminal punctuation, and
/// finally a generic request to be more specific.
fn suggest_reformulation(&self, query: &str) -> Option<String> {
    if !self.config.suggest_reformulations {
        return None;
    }

    let lowered = query.to_lowercase();
    let has_terminal_punctuation =
        query.ends_with('?') || query.ends_with('.') || query.ends_with('!');

    let hint = if lowered.starts_with("tell me about") {
        "Try asking a more specific question like 'What is...?' or 'How does...?'"
    } else if lowered.contains(" and ") {
        "Try breaking your question into separate parts or focus on one aspect"
    } else if query.split_whitespace().count() > 20 {
        "Try using a shorter, more focused question"
    } else if !has_terminal_punctuation {
        "Try phrasing your request as a clear question"
    } else {
        "Try being more specific about what information you're looking for"
    };

    Some(hint.to_string())
}
pub fn get_config(&self) -> &IntentClassificationConfig {
&self.config
}
pub fn update_config(&mut self, config: IntentClassificationConfig) -> Result<()> {
let old_inappropriate_detection = self.config.enable_inappropriate_detection;
self.config = config;
if self.config.enable_inappropriate_detection != old_inappropriate_detection {
self.initialize_patterns()?;
}
Ok(())
}
pub fn get_statistics(&self) -> IntentClassificationStats {
let total_patterns = self
.intent_patterns
.values()
.map(|patterns| patterns.len())
.sum();
IntentClassificationStats {
supported_intents: self.intent_patterns.keys().cloned().collect(),
total_patterns,
inappropriate_patterns: self.inappropriate_patterns.len(),
confidence_threshold: self.config.confidence_threshold,
refusal_threshold: self.config.refusal_threshold,
}
}
}
/// Snapshot of an [`IntentClassifier`]'s setup, as produced by `get_statistics`.
#[cfg(feature = "rograg")]
#[derive(Debug)]
pub struct IntentClassificationStats {
    /// Intents that have at least one rule list registered.
    pub supported_intents: Vec<QueryIntent>,
    /// Total number of intent rules across all intents.
    pub total_patterns: usize,
    /// Number of inappropriate-content regexes currently held.
    pub inappropriate_patterns: usize,
    /// The configured confidence threshold.
    pub confidence_threshold: f32,
    /// The configured refusal threshold.
    pub refusal_threshold: f32,
}
#[cfg(feature = "rograg")]
impl IntentClassificationStats {
    /// Emits the statistics as one `tracing` info-level event, with the two
    /// thresholds rendered to two decimal places.
    pub fn print(&self) {
        let Self {
            supported_intents,
            total_patterns,
            inappropriate_patterns,
            confidence_threshold,
            refusal_threshold,
        } = self;

        tracing::info!(
            supported_intents = supported_intents.len(),
            total_patterns = *total_patterns,
            inappropriate_patterns = *inappropriate_patterns,
            confidence_threshold = format!("{:.2}", confidence_threshold),
            refusal_threshold = format!("{:.2}", refusal_threshold),
            "Intent classification statistics"
        );
    }
}
// Every item under test is gated on the `rograg` feature, so the whole module
// is gated too. Gating only the individual tests left `use super::*;` unused
// (an `unused_imports` warning) in test builds without the feature.
#[cfg(all(test, feature = "rograg"))]
mod tests {
    use super::*;

    /// A query that matches no registered keyword or regex must come back as
    /// ambiguous (or at least with low confidence).
    #[test]
    fn test_ambiguous_query() {
        let classifier = IntentClassifier::new().unwrap();
        let result = classifier.classify("something unclear").unwrap();
        assert!(result.primary_intent == QueryIntent::Ambiguous || result.confidence < 0.5);
    }

    /// A long, multi-clause question must score as more complex than a short one.
    #[test]
    fn test_complexity_calculation() {
        let classifier = IntentClassifier::new().unwrap();
        let simple_result = classifier.classify("What is Tom?").unwrap();
        let complex_result = classifier.classify("What is the intricate relationship between Entity Name and Second Entity, and how does it evolve throughout their various adventures and escapades?").unwrap();
        assert!(complex_result.complexity_score > simple_result.complexity_score);
    }

    /// With a high refusal threshold a vague request is refused, and the
    /// refusal must carry a reformulation hint.
    #[test]
    fn test_reformulation_suggestions() {
        let config = IntentClassificationConfig {
            suggest_reformulations: true,
            refusal_threshold: 0.9,
            ..Default::default()
        };
        let classifier = IntentClassifier::with_config(config).unwrap();
        let result = classifier.classify("tell me about stuff").unwrap();
        assert!(result.suggested_reformulation.is_some());
    }
}