#[cfg(feature = "rograg")]
use crate::Result;
#[cfg(feature = "rograg")]
use async_trait::async_trait;
#[cfg(feature = "rograg")]
use serde::{Deserialize, Serialize};
#[cfg(feature = "rograg")]
use strum::{Display as StrumDisplay, EnumString};
#[cfg(feature = "rograg")]
use thiserror::Error;
/// Failure modes for query decomposition.
///
/// The `Display` text of each variant comes from its `#[error(...)]` attribute.
/// None of the decomposers visible in this file construct these variants.
#[cfg(feature = "rograg")]
#[derive(Error, Debug)]
pub enum DecompositionError {
/// The query was judged too complex to decompose; `message` carries the detail.
#[error("Query too complex to decompose: {message}")]
TooComplex {
message: String,
},
/// The query structure was invalid; `message` carries the detail.
#[error("Invalid query structure: {message}")]
InvalidStructure {
message: String,
},
/// A named decomposition `strategy` failed for the given `reason`.
#[error("Decomposition strategy failed: {strategy}: {reason}")]
StrategyFailed {
strategy: String,
reason: String,
},
/// Decomposition produced no usable subqueries.
#[error("No valid subqueries generated")]
NoValidSubqueries,
}
/// Identifies which decomposition approach produced a [`DecompositionResult`].
///
/// Derives string conversion (`Display` / `FromStr` via strum) and serde support.
#[cfg(feature = "rograg")]
#[derive(Debug, Clone, StrumDisplay, EnumString, Serialize, Deserialize, PartialEq)]
pub enum DecompositionStrategy {
/// Reported by `SemanticQueryDecomposer` and by `DecompositionResult::single_query`.
Semantic,
/// Reported by `SyntacticQueryDecomposer`.
Syntactic,
/// Reported by `HybridQueryDecomposer` when it returns a delegated result.
Hybrid,
/// Not produced by any decomposer in this file.
Logical,
}
/// Outcome of decomposing one query into subqueries.
#[cfg(feature = "rograg")]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DecompositionResult {
/// The query text exactly as it was passed to the decomposer.
pub original_query: String,
/// The generated subqueries; a single element means the query was not split.
pub subqueries: Vec<Subquery>,
/// The strategy label attached by the decomposer that built this result.
pub strategy_used: DecompositionStrategy,
/// Confidence score chosen by the decomposer (this file uses 0.5, 0.7, 0.8 and 1.0).
pub confidence: f32,
/// Dependency edges between subqueries, referenced by subquery id.
pub dependencies: Vec<QueryDependency>,
}
/// One piece of a decomposed query.
#[cfg(feature = "rograg")]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Subquery {
/// Identifier assigned by the decomposer (e.g. `sem_0`, `sem_fallback_1`, `syn_2`, `single`).
pub id: String,
/// The subquery text.
pub text: String,
/// Category assigned to this subquery.
pub query_type: SubqueryType,
/// Ranking weight; `DecompositionResult::ordered_subqueries` sorts higher values first.
pub priority: f32,
/// Ids of subqueries this one depends on (the semantic decomposer stores the previous id here).
pub dependencies: Vec<String>,
}
/// Category of a [`Subquery`].
///
/// The per-variant notes describe how `SyntacticQueryDecomposer::classify_clause_type`
/// assigns each category; other decomposers assign categories by their own rules.
#[cfg(feature = "rograg")]
#[derive(Debug, Clone, StrumDisplay, EnumString, Serialize, Deserialize)]
pub enum SubqueryType {
/// Clauses starting with "who" or "what person"; also the default for unsplit queries.
Entity,
/// Clauses containing "relation" or "connect".
Relationship,
/// Fallback category when no other rule matches.
Attribute,
/// Clauses starting with "when".
Temporal,
/// Clauses starting with "why" or containing "because".
Causal,
/// Clauses containing "compare" or "versus".
Comparative,
/// Clauses starting with "what" (other than "what person").
Definitional,
}
/// A directed dependency edge between two subqueries, identified by their ids.
#[cfg(feature = "rograg")]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryDependency {
/// Id of the subquery that depends on another.
pub dependent_id: String,
/// Id of the subquery that is depended upon.
pub prerequisite_id: String,
/// The kind of dependency this edge represents.
pub dependency_type: DependencyType,
}
/// Kind of relationship expressed by a [`QueryDependency`].
///
/// The per-variant notes describe where `SemanticQueryDecomposer::analyze_dependencies`
/// emits each kind.
#[cfg(feature = "rograg")]
#[derive(Debug, Clone, StrumDisplay, EnumString, Serialize, Deserialize)]
pub enum DependencyType {
/// Emitted for a causal subquery that depends on the subquery immediately before it.
Sequential,
/// Emitted when a relationship, attribute or comparative subquery depends on an earlier one.
Reference,
/// Emitted when a temporal subquery depends on the first earlier entity subquery.
Context,
}
/// Common interface for the query decomposers in this file.
///
/// Implementors must be `Send + Sync`.
#[cfg(feature = "rograg")]
#[async_trait]
pub trait QueryDecomposer: Send + Sync {
/// Decomposes `query` into a [`DecompositionResult`].
async fn decompose(&self, query: &str) -> Result<DecompositionResult>;
/// Reports whether this decomposer recognises something it can split in `query`.
fn can_decompose(&self, query: &str) -> bool;
/// Short name of the strategy (the implementations here return "semantic", "syntactic", "hybrid").
fn strategy_name(&self) -> &str;
}
/// Decomposer driven by a list of regex-based question patterns.
#[cfg(feature = "rograg")]
pub struct SemanticQueryDecomposer {
// Patterns are tried in order; `decompose` uses the first one that matches.
patterns: Vec<SemanticPattern>,
}
/// One question pattern: how to recognise it and how to turn it into subqueries.
#[cfg(feature = "rograg")]
#[derive(Debug, Clone)]
struct SemanticPattern {
// Regex used to test whether a query fits this pattern.
pattern: regex::Regex,
// Produces the subquery texts for a query; returns an empty list when nothing can be extracted.
extractor: fn(&str) -> Vec<String>,
// Category assigned to every subquery produced from this pattern.
subquery_type: SubqueryType,
}
#[cfg(feature = "rograg")]
impl SemanticQueryDecomposer {
    /// Regex source for questions of the form "who/what is X and Y".
    const IDENTITY_CONJUNCTION: &'static str = r"\b(who|what) is (.+?) and (.+)";

    /// Regex source for questions of the form "how is/are X related/connected to Y".
    const RELATIONSHIP_QUESTION: &'static str =
        r"\bhow (?:is|are) (.+?) (?:related to|connected to) (.+)";

    /// Builds a decomposer with the two built-in question patterns.
    ///
    /// # Errors
    ///
    /// Returns an error if one of the pattern regexes fails to compile.
    pub fn new() -> Result<Self> {
        let patterns = vec![
            SemanticPattern {
                pattern: regex::Regex::new(Self::IDENTITY_CONJUNCTION)?,
                extractor: Self::split_identity_conjunction,
                subquery_type: SubqueryType::Entity,
            },
            SemanticPattern {
                pattern: regex::Regex::new(Self::RELATIONSHIP_QUESTION)?,
                extractor: Self::split_relationship_question,
                subquery_type: SubqueryType::Relationship,
            },
        ];
        Ok(Self { patterns })
    }

    /// Splits "who/what is X and Y" into "<who|what> is X" and "Y".
    ///
    /// Returns an empty list when `text` does not match the pattern.
    fn split_identity_conjunction(text: &str) -> Vec<String> {
        // Extractors are plain `fn` pointers and cannot capture state, so the
        // compiled regex is cached in a static instead of being rebuilt per call.
        static MATCHER: std::sync::OnceLock<regex::Regex> = std::sync::OnceLock::new();
        let matcher = MATCHER.get_or_init(|| {
            regex::Regex::new(Self::IDENTITY_CONJUNCTION).expect("static regex literal")
        });

        match matcher.captures(text) {
            // Groups 1-3 are mandatory in the pattern, so indexing cannot fail on a match.
            Some(caps) => vec![
                format!("{} is {}", &caps[1], &caps[2]),
                caps[3].to_string(),
            ],
            None => Vec::new(),
        }
    }

    /// Splits "how is/are X related/connected to Y" into one question about
    /// each operand plus one question about their relationship.
    ///
    /// Returns an empty list when `text` does not match the pattern.
    fn split_relationship_question(text: &str) -> Vec<String> {
        // Cached for the same reason as in `split_identity_conjunction`.
        static MATCHER: std::sync::OnceLock<regex::Regex> = std::sync::OnceLock::new();
        let matcher = MATCHER.get_or_init(|| {
            regex::Regex::new(Self::RELATIONSHIP_QUESTION).expect("static regex literal")
        });

        match matcher.captures(text) {
            // Groups 1-2 are mandatory in the pattern, so indexing cannot fail on a match.
            Some(caps) => vec![
                format!("What is {}", &caps[1]),
                format!("What is {}", &caps[2]),
                format!("How are {} and {} related", &caps[1], &caps[2]),
            ],
            None => Vec::new(),
        }
    }
}
#[cfg(feature = "rograg")]
#[async_trait]
impl QueryDecomposer for SemanticQueryDecomposer {
    /// Decomposes `query` using the first matching semantic pattern, falling
    /// back to splitting on a conjunction word when no pattern yields subqueries.
    ///
    /// Pattern matches get confidence 0.8 and a linear dependency chain
    /// (`sem_N` depends on `sem_{N-1}`); conjunction splits get confidence 0.5
    /// and no per-subquery dependencies. If neither produces a subquery, the
    /// query is returned unsplit via [`DecompositionResult::single_query`].
    async fn decompose(&self, query: &str) -> Result<DecompositionResult> {
        /// Splits `query` at every whole-word, ASCII-case-insensitive
        /// occurrence of `word`, dropping the word itself.
        ///
        /// A word is a maximal run of alphanumeric characters, so "or" does
        /// not split "history" and "AND" is treated like "and".
        fn split_on_word<'q>(query: &'q str, word: &str) -> Vec<&'q str> {
            let mut parts = Vec::new();
            let mut part_start = 0;
            let mut token_start: Option<usize> = None;
            // The trailing sentinel flushes a token that ends at the end of the query.
            let sentinel = std::iter::once((query.len(), ' '));
            for (pos, ch) in query.char_indices().chain(sentinel) {
                if ch.is_alphanumeric() {
                    token_start.get_or_insert(pos);
                } else if let Some(start) = token_start.take() {
                    if query[start..pos].eq_ignore_ascii_case(word) {
                        parts.push(&query[part_start..start]);
                        part_start = pos;
                    }
                }
            }
            parts.push(&query[part_start..]);
            parts
        }

        let mut all_subqueries = Vec::new();
        let mut strategy_confidence = 0.0;

        // Only the first matching pattern is used.
        if let Some(pattern) = self.patterns.iter().find(|p| p.pattern.is_match(query)) {
            for (idx, text) in (pattern.extractor)(query).into_iter().enumerate() {
                let text = text.trim();
                if text.is_empty() {
                    continue;
                }
                all_subqueries.push(Subquery {
                    id: format!("sem_{idx}"),
                    text: text.to_string(),
                    query_type: pattern.subquery_type.clone(),
                    priority: 1.0 - (idx as f32 * 0.1),
                    dependencies: if idx > 0 {
                        vec![format!("sem_{}", idx - 1)]
                    } else {
                        vec![]
                    },
                });
            }
            strategy_confidence = 0.8;
        }

        if all_subqueries.is_empty() {
            const CONJUNCTIONS: [&str; 5] = ["and", "or", "but", "also", "furthermore"];
            // Use the first conjunction that actually splits the query.
            for conjunction in CONJUNCTIONS.iter() {
                let parts = split_on_word(query, conjunction);
                if parts.len() <= 1 {
                    continue;
                }
                for (idx, part) in parts.iter().enumerate() {
                    let text = part.trim();
                    if !text.is_empty() {
                        all_subqueries.push(Subquery {
                            id: format!("sem_fallback_{idx}"),
                            text: text.to_string(),
                            query_type: SubqueryType::Entity,
                            priority: 1.0 - (idx as f32 * 0.2),
                            dependencies: vec![],
                        });
                    }
                }
                strategy_confidence = 0.5;
                break;
            }
        }

        if all_subqueries.is_empty() {
            return Ok(DecompositionResult::single_query(query.to_string()));
        }

        let dependencies = self.analyze_dependencies(&all_subqueries);
        Ok(DecompositionResult {
            original_query: query.to_string(),
            subqueries: all_subqueries,
            strategy_used: DecompositionStrategy::Semantic,
            confidence: strategy_confidence,
            dependencies,
        })
    }

    /// Returns `true` when at least one semantic pattern matches `query`.
    ///
    /// The conjunction fallback used by `decompose` is not considered here.
    fn can_decompose(&self, query: &str) -> bool {
        self.patterns.iter().any(|p| p.pattern.is_match(query))
    }

    /// Returns the strategy name, `"semantic"`.
    fn strategy_name(&self) -> &str {
        "semantic"
    }
}
#[cfg(feature = "rograg")]
impl SemanticQueryDecomposer {
    /// Derives dependency edges between `subqueries` from their types and order.
    ///
    /// A subquery can only depend on subqueries that appear before it:
    /// - relationship and attribute subqueries reference every earlier entity
    ///   or definitional subquery;
    /// - comparative subqueries reference every earlier entity, attribute or
    ///   definitional subquery;
    /// - temporal subqueries take context from the first earlier entity subquery;
    /// - causal subqueries follow the immediately preceding subquery when it is
    ///   an entity, temporal or relationship subquery;
    /// - entity and definitional subqueries have no prerequisites.
    ///
    /// The result is sorted by `(dependent_id, prerequisite_id)` with duplicate
    /// id pairs removed.
    fn analyze_dependencies(&self, subqueries: &[Subquery]) -> Vec<QueryDependency> {
        // Builds the edge "`dependent` requires `prerequisite`".
        fn edge(
            dependent: &Subquery,
            prerequisite: &Subquery,
            kind: DependencyType,
        ) -> QueryDependency {
            QueryDependency {
                dependent_id: dependent.id.clone(),
                prerequisite_id: prerequisite.id.clone(),
                dependency_type: kind,
            }
        }

        let mut edges: Vec<QueryDependency> = Vec::new();

        for (position, current) in subqueries.iter().enumerate() {
            // Only subqueries listed before the current one are candidates.
            let earlier = &subqueries[..position];

            match current.query_type {
                SubqueryType::Relationship | SubqueryType::Attribute => {
                    edges.extend(
                        earlier
                            .iter()
                            .filter(|candidate| {
                                matches!(
                                    candidate.query_type,
                                    SubqueryType::Entity | SubqueryType::Definitional
                                )
                            })
                            .map(|candidate| edge(current, candidate, DependencyType::Reference)),
                    );
                }
                SubqueryType::Comparative => {
                    edges.extend(
                        earlier
                            .iter()
                            .filter(|candidate| {
                                matches!(
                                    candidate.query_type,
                                    SubqueryType::Entity
                                        | SubqueryType::Attribute
                                        | SubqueryType::Definitional
                                )
                            })
                            .map(|candidate| edge(current, candidate, DependencyType::Reference)),
                    );
                }
                SubqueryType::Temporal => {
                    // Only the first earlier entity provides context.
                    let anchor = earlier
                        .iter()
                        .find(|candidate| matches!(candidate.query_type, SubqueryType::Entity));
                    edges.extend(
                        anchor.map(|candidate| edge(current, candidate, DependencyType::Context)),
                    );
                }
                SubqueryType::Causal => {
                    // Only the immediately preceding subquery is considered.
                    let predecessor = earlier.last().filter(|previous| {
                        matches!(
                            previous.query_type,
                            SubqueryType::Entity
                                | SubqueryType::Temporal
                                | SubqueryType::Relationship
                        )
                    });
                    edges.extend(
                        predecessor
                            .map(|previous| edge(current, previous, DependencyType::Sequential)),
                    );
                }
                SubqueryType::Entity | SubqueryType::Definitional => {}
            }
        }

        edges.sort_by(|left, right| {
            (left.dependent_id.as_str(), left.prerequisite_id.as_str())
                .cmp(&(right.dependent_id.as_str(), right.prerequisite_id.as_str()))
        });
        edges.dedup_by(|left, right| {
            left.prerequisite_id == right.prerequisite_id
                && left.dependent_id == right.dependent_id
        });
        edges
    }
}
/// Decomposer that splits a query into clauses at separator words and punctuation.
#[cfg(feature = "rograg")]
pub struct SyntacticQueryDecomposer {
// Words and punctuation marks treated as clause separators.
clause_separators: Vec<String>,
}
/// `Default` delegates to [`SyntacticQueryDecomposer::new`].
#[cfg(feature = "rograg")]
impl Default for SyntacticQueryDecomposer {
fn default() -> Self {
Self::new()
}
}
#[cfg(feature = "rograg")]
impl SyntacticQueryDecomposer {
    /// Creates a decomposer with the built-in list of clause separators
    /// (coordinating words, conjunctive adverbs, `,` and `;`).
    pub fn new() -> Self {
        const SEPARATORS: [&str; 10] = [
            "and",
            "or",
            "but",
            ",",
            ";",
            "also",
            "furthermore",
            "moreover",
            "however",
            "therefore",
        ];
        Self {
            clause_separators: SEPARATORS.iter().map(|sep| sep.to_string()).collect(),
        }
    }

    /// Returns `true` when the separator occupying `query[start..end]` is not
    /// glued to a neighbouring word.
    ///
    /// A separator edge that is alphanumeric must not touch another
    /// alphanumeric character, so "or" inside "history" is rejected while
    /// punctuation separators such as "," are accepted anywhere.
    fn is_whole_word_match(query: &str, start: usize, end: usize) -> bool {
        let is_word_char = |ch: Option<char>| ch.is_some_and(char::is_alphanumeric);
        let matched = &query[start..end];

        let glued_before = is_word_char(matched.chars().next())
            && is_word_char(query[..start].chars().next_back());
        let glued_after = is_word_char(matched.chars().next_back())
            && is_word_char(query[end..].chars().next());

        !glued_before && !glued_after
    }

    /// Finds every separator occurrence in `query` as a `(start, end)` byte
    /// range, sorted by position and without duplicates.
    ///
    /// Matching is ASCII-case-insensitive and runs on the original bytes, so
    /// the offsets are always valid for slicing `query`. Lowercasing the query
    /// first would not guarantee that, because lowercasing can change the byte
    /// length of non-ASCII text.
    fn separator_spans(&self, query: &str) -> Vec<(usize, usize)> {
        let haystack = query.as_bytes();
        let mut spans = Vec::new();

        for separator in &self.clause_separators {
            let needle = separator.as_bytes();
            if needle.is_empty() {
                continue;
            }
            // A byte-wise match of valid UTF-8 starts and ends on character
            // boundaries: ASCII bytes never occur inside a multi-byte sequence.
            let hits = haystack
                .windows(needle.len())
                .enumerate()
                .filter(|(_, window)| window.eq_ignore_ascii_case(needle))
                .map(|(start, _)| (start, start + needle.len()))
                .filter(|&(start, end)| Self::is_whole_word_match(query, start, end));
            spans.extend(hits);
        }

        spans.sort_unstable();
        spans.dedup();
        spans
    }

    /// Splits `query` into trimmed clauses: the text between consecutive
    /// separators, with the separators themselves removed.
    ///
    /// Clauses of three bytes or fewer are discarded.
    fn extract_clauses(&self, query: &str) -> Vec<String> {
        let mut clauses = Vec::new();
        let mut cursor = 0;

        // The sentinel span at the end flushes the text after the last separator.
        let end_of_query = std::iter::once((query.len(), query.len()));
        for (start, end) in self.separator_spans(query).into_iter().chain(end_of_query) {
            // Spans from different separators may overlap; skip text already consumed.
            if start >= cursor {
                let clause = query[cursor..start].trim();
                if clause.len() > 3 {
                    clauses.push(clause.to_string());
                }
            }
            cursor = cursor.max(end);
        }

        clauses
    }

    /// Assigns a [`SubqueryType`] to `clause` from its leading question word
    /// or from keywords it contains (case-insensitive).
    ///
    /// Rules are checked in order; the first match wins and
    /// [`SubqueryType::Attribute`] is the fallback.
    fn classify_clause_type(&self, clause: &str) -> SubqueryType {
        let clause_lower = clause.to_lowercase();
        if clause_lower.starts_with("who") || clause_lower.starts_with("what person") {
            SubqueryType::Entity
        } else if clause_lower.starts_with("what") {
            SubqueryType::Definitional
        } else if clause_lower.starts_with("when") {
            SubqueryType::Temporal
        } else if clause_lower.starts_with("why") || clause_lower.contains("because") {
            SubqueryType::Causal
        } else if clause_lower.contains("relation") || clause_lower.contains("connect") {
            SubqueryType::Relationship
        } else if clause_lower.contains("compare") || clause_lower.contains("versus") {
            SubqueryType::Comparative
        } else {
            SubqueryType::Attribute
        }
    }
}
#[cfg(feature = "rograg")]
#[async_trait]
impl QueryDecomposer for SyntacticQueryDecomposer {
    /// Splits `query` into clauses and returns one subquery per clause.
    ///
    /// When fewer than two clauses are found, the query is returned unsplit
    /// via [`DecompositionResult::single_query`]. Otherwise subqueries are
    /// numbered `syn_0`, `syn_1`, … with priority decreasing by 0.1 per
    /// position, and no dependencies are recorded.
    async fn decompose(&self, query: &str) -> Result<DecompositionResult> {
        let clauses = self.extract_clauses(query);
        if clauses.len() <= 1 {
            return Ok(DecompositionResult::single_query(query.to_string()));
        }

        let subqueries: Vec<Subquery> = clauses
            .into_iter()
            .enumerate()
            .map(|(idx, clause)| {
                // Classify before moving the clause text into the subquery.
                let query_type = self.classify_clause_type(&clause);
                Subquery {
                    id: format!("syn_{idx}"),
                    text: clause,
                    query_type,
                    priority: 1.0 - (idx as f32 * 0.1),
                    dependencies: Vec::new(),
                }
            })
            .collect();

        Ok(DecompositionResult {
            original_query: query.to_string(),
            subqueries,
            strategy_used: DecompositionStrategy::Syntactic,
            // At least two clauses exist here, so this is always a real split.
            confidence: 0.7,
            dependencies: Vec::new(),
        })
    }

    /// Returns `true` when `query` contains any separator as a
    /// case-insensitive substring.
    fn can_decompose(&self, query: &str) -> bool {
        let query_lower = query.to_lowercase();
        self.clause_separators
            .iter()
            .any(|sep| query_lower.contains(sep.to_lowercase().as_str()))
    }

    /// Returns the strategy name, `"syntactic"`.
    fn strategy_name(&self) -> &str {
        "syntactic"
    }
}
/// Decomposer that tries the semantic strategy first and the syntactic one second.
#[cfg(feature = "rograg")]
pub struct HybridQueryDecomposer {
// Tried first; its result is used only when its confidence exceeds 0.6.
semantic: SemanticQueryDecomposer,
// Used when the semantic decomposer does not apply or is not confident enough.
syntactic: SyntacticQueryDecomposer,
}
#[cfg(feature = "rograg")]
impl HybridQueryDecomposer {
    /// Builds a hybrid decomposer from a fresh semantic and syntactic decomposer.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`SemanticQueryDecomposer::new`].
    pub fn new() -> Result<Self> {
        let semantic = SemanticQueryDecomposer::new()?;
        let syntactic = SyntacticQueryDecomposer::new();
        Ok(Self {
            semantic,
            syntactic,
        })
    }
}
#[cfg(feature = "rograg")]
#[async_trait]
impl QueryDecomposer for HybridQueryDecomposer {
    /// Decomposes `query` with the semantic decomposer when it applies and is
    /// confident (confidence above 0.6), otherwise with the syntactic one.
    ///
    /// Results taken from either delegate are relabelled with
    /// [`DecompositionStrategy::Hybrid`]. When neither delegate applies, the
    /// query is returned unsplit via [`DecompositionResult::single_query`].
    async fn decompose(&self, query: &str) -> Result<DecompositionResult> {
        // Relabels a delegate's result as coming from the hybrid strategy.
        let as_hybrid = |mut outcome: DecompositionResult| -> DecompositionResult {
            outcome.strategy_used = DecompositionStrategy::Hybrid;
            outcome
        };

        if self.semantic.can_decompose(query) {
            let outcome = self.semantic.decompose(query).await?;
            if outcome.confidence > 0.6 {
                return Ok(as_hybrid(outcome));
            }
        }

        if !self.syntactic.can_decompose(query) {
            return Ok(DecompositionResult::single_query(query.to_string()));
        }

        self.syntactic.decompose(query).await.map(as_hybrid)
    }

    /// Returns `true` when either delegate reports it can decompose `query`.
    fn can_decompose(&self, query: &str) -> bool {
        [
            self.semantic.can_decompose(query),
            self.syntactic.can_decompose(query),
        ]
        .contains(&true)
    }

    /// Returns the strategy name, `"hybrid"`.
    fn strategy_name(&self) -> &str {
        "hybrid"
    }
}
#[cfg(feature = "rograg")]
impl DecompositionResult {
    /// Wraps `query` as an undecomposed result.
    ///
    /// The result holds exactly one subquery with id `"single"`, type
    /// [`SubqueryType::Entity`] and priority 1.0. It is labelled
    /// [`DecompositionStrategy::Semantic`] with confidence 1.0 and has no
    /// dependencies.
    pub fn single_query(query: String) -> Self {
        let only = Subquery {
            id: String::from("single"),
            text: query.clone(),
            query_type: SubqueryType::Entity,
            priority: 1.0,
            dependencies: Vec::new(),
        };

        Self {
            original_query: query,
            subqueries: vec![only],
            strategy_used: DecompositionStrategy::Semantic,
            confidence: 1.0,
            dependencies: Vec::new(),
        }
    }

    /// Returns `true` when the query was split into more than one subquery.
    pub fn is_decomposed(&self) -> bool {
        self.subqueries.len() >= 2
    }

    /// Returns references to the subqueries sorted by descending priority.
    ///
    /// The sort is stable. Priorities that cannot be compared (NaN) are
    /// treated as equal.
    pub fn ordered_subqueries(&self) -> Vec<&Subquery> {
        use std::cmp::Ordering;

        let mut ranked = Vec::with_capacity(self.subqueries.len());
        for subquery in &self.subqueries {
            ranked.push(subquery);
        }
        ranked.sort_by(|lhs: &&Subquery, rhs: &&Subquery| {
            match rhs.priority.partial_cmp(&lhs.priority) {
                Some(ordering) => ordering,
                None => Ordering::Equal,
            }
        });
        ranked
    }
}
// Async tests for the three decomposers; each test is compiled only with the
// "rograg" feature enabled.
#[cfg(test)]
mod tests {
use super::*;
// A conjoined question must be split into at least two subqueries and be
// labelled with the semantic strategy.
#[cfg(feature = "rograg")]
#[tokio::test]
async fn test_semantic_decomposition() {
let decomposer = SemanticQueryDecomposer::new().unwrap();
let result = decomposer
.decompose("Who is Entity Name and what is his relationship with Second Entity?")
.await
.unwrap();
assert!(result.is_decomposed());
assert!(result.subqueries.len() >= 2);
assert_eq!(result.strategy_used, DecompositionStrategy::Semantic);
}
// A query joined by ", and also" must be split and labelled syntactic.
#[cfg(feature = "rograg")]
#[tokio::test]
async fn test_syntactic_decomposition() {
let decomposer = SyntacticQueryDecomposer::new();
let result = decomposer
.decompose("Tell me about Entity Name, and also describe Second Entity")
.await
.unwrap();
assert!(result.is_decomposed());
assert_eq!(result.strategy_used, DecompositionStrategy::Syntactic);
}
// The hybrid decomposer must relabel a delegated result as Hybrid.
#[cfg(feature = "rograg")]
#[tokio::test]
async fn test_hybrid_decomposition() {
let decomposer = HybridQueryDecomposer::new().unwrap();
let result = decomposer
.decompose("What is friendship and how are Tom and Huck related?")
.await
.unwrap();
assert_eq!(result.strategy_used, DecompositionStrategy::Hybrid);
}
// A query with no separators or patterns must come back as one subquery.
#[cfg(feature = "rograg")]
#[tokio::test]
async fn test_single_query_fallback() {
let decomposer = HybridQueryDecomposer::new().unwrap();
let result = decomposer.decompose("Simple query").await.unwrap();
assert!(!result.is_decomposed());
assert_eq!(result.subqueries.len(), 1);
}
}