oxirs_chat/rag/
query_processing.rs1use super::*;
6
/// Stateless helper for analysing natural-language queries.
///
/// The processor carries no data; its methods extract constraints, classify
/// the intent of a query and estimate how complex a query is.
///
/// `Default` is derived rather than hand-written: for a unit struct the
/// derived implementation yields exactly the same value as
/// `QueryProcessor::new()`.
#[derive(Debug, Clone, Default)]
pub struct QueryProcessor;
15
16impl QueryProcessor {
17 pub fn new() -> Self {
18 Self
19 }
20
21 pub async fn extract_constraints(
23 &self,
24 query: &str,
25 entities: &[ExtractedEntity],
26 ) -> Result<Vec<QueryConstraint>> {
27 let mut constraints = Vec::new();
28 let query_lower = query.to_lowercase();
29
30 let temporal_patterns = [
32 (
33 r"(?:in|during|from|since|before|after)\s+(\d{4})",
34 ConstraintType::Temporal,
35 "year",
36 ),
37 (
38 r"(?:today|yesterday|tomorrow|now|recent)",
39 ConstraintType::Temporal,
40 "relative_time",
41 ),
42 ];
43
44 for (pattern, constraint_type, operator) in temporal_patterns {
45 let regex = Regex::new(pattern)?;
46 for cap in regex.captures_iter(&query_lower) {
47 if let Some(value) = cap.get(1) {
48 constraints.push(QueryConstraint {
49 constraint_type,
50 value: value.as_str().to_string(),
51 operator: operator.to_string(),
52 });
53 } else if cap.get(0).is_some() {
54 constraints.push(QueryConstraint {
55 constraint_type,
56 value: cap
57 .get(0)
58 .expect("capture group 0 should exist")
59 .as_str()
60 .to_string(),
61 operator: operator.to_string(),
62 });
63 }
64 }
65 }
66
67 if query_lower.contains("type")
69 || query_lower.contains("kind")
70 || query_lower.contains("class")
71 {
72 constraints.push(QueryConstraint {
73 constraint_type: ConstraintType::Type,
74 value: "type_constraint".to_string(),
75 operator: "equals".to_string(),
76 });
77 }
78
79 let value_patterns = [
81 (r"(?:greater than|more than|>\s*)(\d+)", "greater_than"),
82 (r"(?:less than|fewer than|<\s*)(\d+)", "less_than"),
83 (r"(?:equals?|is|=\s*)(\d+)", "equals"),
84 ];
85
86 for (pattern, operator) in value_patterns {
87 let regex = Regex::new(pattern)?;
88 for cap in regex.captures_iter(&query_lower) {
89 if let Some(value) = cap.get(1) {
90 constraints.push(QueryConstraint {
91 constraint_type: ConstraintType::Value,
92 value: value.as_str().to_string(),
93 operator: operator.to_string(),
94 });
95 }
96 }
97 }
98
99 for entity in entities {
101 match entity.entity_type {
102 EntityType::Person => {
103 constraints.push(QueryConstraint {
104 constraint_type: ConstraintType::Entity,
105 value: entity.text.clone(),
106 operator: "person_filter".to_string(),
107 });
108 }
109 EntityType::Location => {
110 constraints.push(QueryConstraint {
111 constraint_type: ConstraintType::Spatial,
112 value: entity.text.clone(),
113 operator: "location_filter".to_string(),
114 });
115 }
116 _ => {}
117 }
118 }
119
120 debug!("Extracted {} constraints from query", constraints.len());
121 Ok(constraints)
122 }
123
124 pub fn analyze_query_intent(&self, query: &str) -> QueryIntent {
126 let query_lower = query.to_lowercase();
127
128 if query_lower.contains("how many") || query_lower.contains("count") {
129 QueryIntent::Counting
130 } else if query_lower.contains("what is") || query_lower.contains("define") {
131 QueryIntent::Definition
132 } else if query_lower.contains("compare") || query_lower.contains("difference") {
133 QueryIntent::Comparison
134 } else if query_lower.contains("list") || query_lower.contains("show all") {
135 QueryIntent::Listing
136 } else if query_lower.contains("why") || query_lower.contains("because") {
137 QueryIntent::Explanation
138 } else {
139 QueryIntent::General
140 }
141 }
142
143 pub fn calculate_query_complexity(&self, query: &str) -> f64 {
145 let word_count = query.split_whitespace().count();
146 let unique_words = query.split_whitespace().collect::<HashSet<_>>().len();
147 let question_words = ["what", "how", "why", "when", "where", "who", "which"];
148 let query_lower = query.to_lowercase();
149
150 let question_word_count = question_words
151 .iter()
152 .filter(|word| query_lower.contains(*word))
153 .count();
154
155 let complexity = (word_count as f64 * 0.05)
156 + (unique_words as f64 * 0.1)
157 + (question_word_count as f64 * 0.2);
158
159 complexity.min(1.0)
160 }
161}
162
163#[derive(Debug, Clone)]
165pub struct QueryConstraint {
166 pub constraint_type: ConstraintType,
167 pub value: String,
168 pub operator: String,
169}
170
/// Category of a [`QueryConstraint`].
///
/// The enum is a plain set of tags, so it derives equality and hashing in
/// addition to `Copy`; this lets callers compare variants with `==` and use
/// them as set or map keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ConstraintType {
    /// Time-related constraint (a year or a relative time word).
    Temporal,
    /// Location-related constraint.
    Spatial,
    /// Constraint on the type, kind or class of the result.
    Type,
    /// Numeric comparison constraint.
    Value,
    /// Constraint tied to a named entity.
    Entity,
    /// Relationship constraint; not produced by any code visible in this file.
    Relationship,
}
181
182#[derive(Debug, Clone, PartialEq, Hash, serde::Serialize, serde::Deserialize)]
184pub enum QueryIntent {
185 Definition,
186 Comparison,
187 Counting,
188 Listing,
189 Explanation,
190 General,
191}
192
193use super::graph_traversal::{EntityType, ExtractedEntity};