1use std::any::Any;
14use std::fmt;
15use std::sync::Arc;
16
17use arrow_array::{RecordBatch, StringArray, UInt32Array};
18use arrow_schema::{DataType, Field, Schema, SchemaRef};
19use datafusion_common::Result;
20use datafusion_execution::{SendableRecordBatchStream, TaskContext};
21use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType};
22use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
23use datafusion_physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
24
/// Coarse complexity tier assigned to a query by `QueryFeatures::classify`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Complexity {
    /// The point total is below the configured medium threshold.
    Simple,
    /// The point total reaches the medium threshold but not the complex one.
    Medium,
    /// The point total reaches the configured complex threshold.
    Complex,
}

impl fmt::Display for Complexity {
    /// Writes the variant name (`Simple`, `Medium` or `Complex`).
    ///
    /// The name goes through [`fmt::Formatter::pad`], so width, fill, alignment
    /// and precision flags (for example `{:>8}`) are honoured. Plain `{}` and
    /// `to_string()` output is exactly the bare variant name.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::Simple => "Simple",
            Self::Medium => "Medium",
            Self::Complex => "Complex",
        };
        f.pad(label)
    }
}
42
/// Thresholds consulted by `QueryFeatures::classify`.
#[derive(Clone, Debug)]
pub struct ComplexityConfig {
    /// A query scores a point when its token count is strictly greater than this.
    pub token_threshold: usize,
    /// A query scores a point when its entity count is strictly greater than this.
    pub entity_threshold: usize,
    /// Minimum point total for `Complexity::Medium`.
    pub medium_threshold: u32,
    /// Minimum point total for `Complexity::Complex`; checked before the medium threshold.
    pub complex_threshold: u32,
}

impl Default for ComplexityConfig {
    /// Returns the stock thresholds: more than 50 tokens or more than 3 entities
    /// score a point each; 1 point makes a query medium and 3 points make it complex.
    fn default() -> Self {
        const TOKEN_THRESHOLD: usize = 50;
        const ENTITY_THRESHOLD: usize = 3;
        const MEDIUM_POINTS: u32 = 1;
        const COMPLEX_POINTS: u32 = 3;

        ComplexityConfig {
            token_threshold: TOKEN_THRESHOLD,
            entity_threshold: ENTITY_THRESHOLD,
            medium_threshold: MEDIUM_POINTS,
            complex_threshold: COMPLEX_POINTS,
        }
    }
}
66
/// Per-query signals that `QueryFeatures::classify` turns into a point total.
///
/// `Default` yields all-zero counts and all-false flags, which scores no points
/// from the flags or from `graph_depth`.
#[derive(Clone, Debug, Default)]
pub struct QueryFeatures {
    /// Token count; scores a point when it exceeds `ComplexityConfig::token_threshold`.
    pub token_count: usize,
    /// Scores a point when set.
    /// NOTE(review): presumably marks a temporal reference in the query — confirm with the producer.
    pub has_temporal: bool,
    /// Entity count; scores a point when it exceeds `ComplexityConfig::entity_threshold`.
    pub entity_count: usize,
    /// Scores a point when greater than 1.
    pub graph_depth: u32,
    /// Scores a point when set.
    /// NOTE(review): presumably marks causal phrasing in the query — confirm with the producer.
    pub has_causal: bool,
    /// Scores a point when set.
    pub is_iterative: bool,
}
83
84impl QueryFeatures {
85 pub fn classify(&self, config: &ComplexityConfig) -> (Complexity, u32) {
87 let mut points: u32 = 0;
88 if self.token_count > config.token_threshold {
89 points += 1;
90 }
91 if self.has_temporal {
92 points += 1;
93 }
94 if self.entity_count > config.entity_threshold {
95 points += 1;
96 }
97 if self.graph_depth > 1 {
98 points += 1;
99 }
100 if self.has_causal {
101 points += 1;
102 }
103 if self.is_iterative {
104 points += 1;
105 }
106
107 let complexity = if points >= config.complex_threshold {
108 Complexity::Complex
109 } else if points >= config.medium_threshold {
110 Complexity::Medium
111 } else {
112 Complexity::Simple
113 };
114
115 (complexity, points)
116 }
117}
118
119fn output_schema() -> SchemaRef {
121 Arc::new(Schema::new(vec![
122 Field::new("query_complexity", DataType::Utf8, false),
123 Field::new("complexity_points", DataType::UInt32, false),
124 ]))
125}
126
/// Physical plan node that classifies a fixed `QueryFeatures` value.
///
/// The node has no child plans. Executing it yields a single one-row batch
/// holding the complexity tier and the point total.
#[derive(Debug)]
pub struct QueryComplexityExec {
    /// Features scored whenever the node is displayed or executed.
    features: QueryFeatures,
    /// Thresholds handed to `QueryFeatures::classify`.
    config: ComplexityConfig,
    /// Output schema, obtained from `output_schema()` in `new`.
    schema: SchemaRef,
    /// Plan properties, computed once in `new`.
    properties: PlanProperties,
}
138
139impl QueryComplexityExec {
140 pub fn new(features: QueryFeatures, config: ComplexityConfig) -> Self {
141 let schema = output_schema();
142 let properties = PlanProperties::new(
143 datafusion_physical_expr::EquivalenceProperties::new(schema.clone()),
144 datafusion_physical_plan::Partitioning::UnknownPartitioning(1),
145 EmissionType::Final,
146 Boundedness::Bounded,
147 );
148 Self {
149 features,
150 config,
151 schema,
152 properties,
153 }
154 }
155
156 pub fn features(&self) -> &QueryFeatures {
157 &self.features
158 }
159
160 pub fn config(&self) -> &ComplexityConfig {
161 &self.config
162 }
163}
164
165impl DisplayAs for QueryComplexityExec {
166 fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
167 let (complexity, points) = self.features.classify(&self.config);
168 write!(
169 f,
170 "QueryComplexityExec: complexity={complexity}, points={points}"
171 )
172 }
173}
174
175impl ExecutionPlan for QueryComplexityExec {
176 fn name(&self) -> &str {
177 "QueryComplexityExec"
178 }
179
180 fn as_any(&self) -> &dyn Any {
181 self
182 }
183
184 fn schema(&self) -> SchemaRef {
185 self.schema.clone()
186 }
187
188 fn properties(&self) -> &PlanProperties {
189 &self.properties
190 }
191
192 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
193 vec![]
194 }
195
196 fn with_new_children(
197 self: Arc<Self>,
198 _children: Vec<Arc<dyn ExecutionPlan>>,
199 ) -> Result<Arc<dyn ExecutionPlan>> {
200 Ok(Arc::new(Self::new(
201 self.features.clone(),
202 self.config.clone(),
203 )))
204 }
205
206 fn execute(
207 &self,
208 _partition: usize,
209 _context: Arc<TaskContext>,
210 ) -> Result<SendableRecordBatchStream> {
211 let (complexity, points) = self.features.classify(&self.config);
212
213 let batch = RecordBatch::try_new(
214 self.schema.clone(),
215 vec![
216 Arc::new(StringArray::from(vec![complexity.to_string()])),
217 Arc::new(UInt32Array::from(vec![points])),
218 ],
219 )?;
220
221 let schema = self.schema.clone();
222 let stream = futures::stream::once(async move { Ok(batch) });
223 Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream)))
224 }
225}
226
#[cfg(test)]
mod tests {
    use super::*;

    /// Classifies `features` against the default thresholds.
    fn classify_with_defaults(features: &QueryFeatures) -> (Complexity, u32) {
        features.classify(&ComplexityConfig::default())
    }

    /// A short query with no flags set scores nothing.
    #[test]
    fn simple_query() {
        let short_query = QueryFeatures {
            token_count: 5,
            ..Default::default()
        };
        assert_eq!(
            classify_with_defaults(&short_query),
            (Complexity::Simple, 0)
        );
    }

    /// A single flag is enough to reach the medium tier.
    #[test]
    fn medium_query_temporal() {
        let temporal_query = QueryFeatures {
            token_count: 10,
            has_temporal: true,
            ..Default::default()
        };
        assert_eq!(
            classify_with_defaults(&temporal_query),
            (Complexity::Medium, 1)
        );
    }

    /// A graph depth above 1 counts as one point.
    #[test]
    fn medium_query_graph_depth() {
        let deep_query = QueryFeatures {
            graph_depth: 2,
            ..Default::default()
        };
        assert_eq!(
            classify_with_defaults(&deep_query),
            (Complexity::Medium, 1)
        );
    }

    /// Every signal firing yields the maximum of six points.
    #[test]
    fn complex_query_all_features() {
        let everything = QueryFeatures {
            token_count: 60,
            has_temporal: true,
            entity_count: 5,
            graph_depth: 3,
            has_causal: true,
            is_iterative: true,
        };
        assert_eq!(
            classify_with_defaults(&everything),
            (Complexity::Complex, 6)
        );
    }

    /// Exactly three points lands on the default complex threshold.
    #[test]
    fn complex_query_three_features() {
        let three_signals = QueryFeatures {
            has_temporal: true,
            entity_count: 5,
            has_causal: true,
            ..Default::default()
        };
        assert_eq!(
            classify_with_defaults(&three_signals),
            (Complexity::Complex, 3)
        );
    }

    /// Lowered thresholds change both the score and the resulting tier.
    #[test]
    fn custom_thresholds() {
        let query = QueryFeatures {
            token_count: 30,
            has_temporal: true,
            ..Default::default()
        };
        let lenient = ComplexityConfig {
            token_threshold: 20,
            complex_threshold: 2,
            ..Default::default()
        };
        assert_eq!(query.classify(&lenient), (Complexity::Complex, 2));
    }

    /// Classifying 10,000 times must finish within 10 ms in total.
    // NOTE(review): this asserts on wall-clock time and may be flaky on a
    // heavily loaded machine.
    #[test]
    fn classification_sub_millisecond() {
        const ITERATIONS: usize = 10_000;

        let busy_query = QueryFeatures {
            token_count: 100,
            has_temporal: true,
            entity_count: 10,
            graph_depth: 5,
            has_causal: true,
            is_iterative: true,
        };
        let config = ComplexityConfig::default();

        let timer = std::time::Instant::now();
        for _ in 0..ITERATIONS {
            std::hint::black_box(busy_query.classify(&config));
        }
        let elapsed = timer.elapsed();

        assert!(elapsed.as_millis() < 10, "too slow: {elapsed:?}");
    }

    /// Executing partition 0 yields one row with the tier name and the point total.
    #[tokio::test]
    async fn execute_produces_batch() {
        use futures::StreamExt;

        let exec = QueryComplexityExec::new(
            QueryFeatures {
                has_temporal: true,
                has_causal: true,
                entity_count: 5,
                ..Default::default()
            },
            ComplexityConfig::default(),
        );

        let task_ctx = Arc::new(TaskContext::default());
        let mut batches = exec.execute(0, task_ctx).unwrap();
        let first = batches.next().await.unwrap().unwrap();
        assert_eq!(first.num_rows(), 1);

        let tier_column = first
            .column(0)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        assert_eq!(tier_column.value(0), "Complex");

        let points_column = first
            .column(1)
            .as_any()
            .downcast_ref::<UInt32Array>()
            .unwrap();
        assert_eq!(points_column.value(0), 3);
    }
}