Skip to main content

hirn_exec/operators/
query_complexity.rs

//! `QueryComplexityExec` — classifies query complexity for depth scheduling.
//!
//! Classification rules (the token, entity, and level thresholds are configurable via
//! `ComplexityConfig`; the numbers below are the defaults):
//! - Token count > 50: +1 point
//! - Temporal keywords present: +1 point
//! - Entity count > 3: +1 point
//! - Multi-hop requested (EXPAND GRAPH DEPTH > 1): +1 point (this cut-off is fixed)
//! - Causal query (FOLLOW CAUSES): +1 point
//! - Iterative mode: +1 point
//!
//! With the default thresholds: Simple (0 pts) / Medium (1–2 pts) / Complex (3+ pts).
12
13use std::any::Any;
14use std::fmt;
15use std::sync::Arc;
16
17use arrow_array::{RecordBatch, StringArray, UInt32Array};
18use arrow_schema::{DataType, Field, Schema, SchemaRef};
19use datafusion_common::Result;
20use datafusion_execution::{SendableRecordBatchStream, TaskContext};
21use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType};
22use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
23use datafusion_physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
24
/// Coarse complexity bucket assigned to a query, used for depth scheduling.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Complexity {
    /// Lowest bucket.
    Simple,
    /// Middle bucket.
    Medium,
    /// Highest bucket.
    Complex,
}

impl fmt::Display for Complexity {
    /// Writes the variant name verbatim: `Simple`, `Medium` or `Complex`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match *self {
            Complexity::Simple => "Simple",
            Complexity::Medium => "Medium",
            Complexity::Complex => "Complex",
        };
        f.write_str(label)
    }
}
42
/// Tunable cut-offs used when scoring a query and bucketing the score.
#[derive(Debug, Clone)]
pub struct ComplexityConfig {
    /// A query earns a point when its token count is strictly greater than this
    /// value (default: 50).
    pub token_threshold: usize,
    /// A query earns a point when its entity count is strictly greater than this
    /// value (default: 3).
    pub entity_threshold: usize,
    /// Minimum number of points for the Medium level (default: 1).
    pub medium_threshold: u32,
    /// Minimum number of points for the Complex level (default: 3).
    pub complex_threshold: u32,
}

impl Default for ComplexityConfig {
    /// Returns the stock thresholds: tokens 50, entities 3, Medium at 1 point,
    /// Complex at 3 points.
    fn default() -> Self {
        let (token_threshold, entity_threshold) = (50, 3);
        let (medium_threshold, complex_threshold) = (1, 3);
        Self {
            token_threshold,
            entity_threshold,
            medium_threshold,
            complex_threshold,
        }
    }
}
66
67/// Features extracted from a query for complexity classification.
68#[derive(Debug, Clone, Default)]
69pub struct QueryFeatures {
70    /// Approximate token count of the query text.
71    pub token_count: usize,
72    /// Whether temporal keywords are present (AFTER, BEFORE, BETWEEN, AS OF).
73    pub has_temporal: bool,
74    /// Number of entities referenced (INVOLVING clause count).
75    pub entity_count: usize,
76    /// Graph expansion depth (0 = no expansion).
77    pub graph_depth: u32,
78    /// Whether FOLLOW CAUSES is present.
79    pub has_causal: bool,
80    /// Whether iterative mode is requested.
81    pub is_iterative: bool,
82}
83
84impl QueryFeatures {
85    /// Classify query complexity based on features and config.
86    pub fn classify(&self, config: &ComplexityConfig) -> (Complexity, u32) {
87        let mut points: u32 = 0;
88        if self.token_count > config.token_threshold {
89            points += 1;
90        }
91        if self.has_temporal {
92            points += 1;
93        }
94        if self.entity_count > config.entity_threshold {
95            points += 1;
96        }
97        if self.graph_depth > 1 {
98            points += 1;
99        }
100        if self.has_causal {
101            points += 1;
102        }
103        if self.is_iterative {
104            points += 1;
105        }
106
107        let complexity = if points >= config.complex_threshold {
108            Complexity::Complex
109        } else if points >= config.medium_threshold {
110            Complexity::Medium
111        } else {
112            Complexity::Simple
113        };
114
115        (complexity, points)
116    }
117}
118
119/// Output schema: `query_complexity (Utf8)`, `complexity_points (UInt32)`.
120fn output_schema() -> SchemaRef {
121    Arc::new(Schema::new(vec![
122        Field::new("query_complexity", DataType::Utf8, false),
123        Field::new("complexity_points", DataType::UInt32, false),
124    ]))
125}
126
127/// DataFusion operator that classifies query complexity for depth scheduling.
128///
129/// This is a leaf operator (no children) — it computes classification from
130/// `QueryFeatures` provided at construction time.
131#[derive(Debug)]
132pub struct QueryComplexityExec {
133    features: QueryFeatures,
134    config: ComplexityConfig,
135    schema: SchemaRef,
136    properties: PlanProperties,
137}
138
139impl QueryComplexityExec {
140    pub fn new(features: QueryFeatures, config: ComplexityConfig) -> Self {
141        let schema = output_schema();
142        let properties = PlanProperties::new(
143            datafusion_physical_expr::EquivalenceProperties::new(schema.clone()),
144            datafusion_physical_plan::Partitioning::UnknownPartitioning(1),
145            EmissionType::Final,
146            Boundedness::Bounded,
147        );
148        Self {
149            features,
150            config,
151            schema,
152            properties,
153        }
154    }
155
156    pub fn features(&self) -> &QueryFeatures {
157        &self.features
158    }
159
160    pub fn config(&self) -> &ComplexityConfig {
161        &self.config
162    }
163}
164
165impl DisplayAs for QueryComplexityExec {
166    fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
167        let (complexity, points) = self.features.classify(&self.config);
168        write!(
169            f,
170            "QueryComplexityExec: complexity={complexity}, points={points}"
171        )
172    }
173}
174
175impl ExecutionPlan for QueryComplexityExec {
176    fn name(&self) -> &str {
177        "QueryComplexityExec"
178    }
179
180    fn as_any(&self) -> &dyn Any {
181        self
182    }
183
184    fn schema(&self) -> SchemaRef {
185        self.schema.clone()
186    }
187
188    fn properties(&self) -> &PlanProperties {
189        &self.properties
190    }
191
192    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
193        vec![]
194    }
195
196    fn with_new_children(
197        self: Arc<Self>,
198        _children: Vec<Arc<dyn ExecutionPlan>>,
199    ) -> Result<Arc<dyn ExecutionPlan>> {
200        Ok(Arc::new(Self::new(
201            self.features.clone(),
202            self.config.clone(),
203        )))
204    }
205
206    fn execute(
207        &self,
208        _partition: usize,
209        _context: Arc<TaskContext>,
210    ) -> Result<SendableRecordBatchStream> {
211        let (complexity, points) = self.features.classify(&self.config);
212
213        let batch = RecordBatch::try_new(
214            self.schema.clone(),
215            vec![
216                Arc::new(StringArray::from(vec![complexity.to_string()])),
217                Arc::new(UInt32Array::from(vec![points])),
218            ],
219        )?;
220
221        let schema = self.schema.clone();
222        let stream = futures::stream::once(async move { Ok(batch) });
223        Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream)))
224    }
225}
226
#[cfg(test)]
mod tests {
    use super::*;

    use futures::StreamExt;

    /// A short query with no other signals scores zero points.
    #[test]
    fn simple_query() {
        let features = QueryFeatures {
            token_count: 5,
            ..Default::default()
        };
        let config = ComplexityConfig::default();
        let (complexity, points) = features.classify(&config);
        assert_eq!(complexity, Complexity::Simple);
        assert_eq!(points, 0);
    }

    /// A single temporal signal is enough for Medium with default thresholds.
    #[test]
    fn medium_query_temporal() {
        let features = QueryFeatures {
            token_count: 10,
            has_temporal: true,
            ..Default::default()
        };
        let config = ComplexityConfig::default();
        let (complexity, points) = features.classify(&config);
        assert_eq!(complexity, Complexity::Medium);
        assert_eq!(points, 1);
    }

    /// Graph depth above 1 counts as a multi-hop signal.
    #[test]
    fn medium_query_graph_depth() {
        let features = QueryFeatures {
            graph_depth: 2,
            ..Default::default()
        };
        let config = ComplexityConfig::default();
        let (complexity, points) = features.classify(&config);
        assert_eq!(complexity, Complexity::Medium);
        assert_eq!(points, 1);
    }

    /// Values exactly at a threshold do not score; the comparisons are strict.
    #[test]
    fn thresholds_are_exclusive() {
        let config = ComplexityConfig::default();

        let at_limit = QueryFeatures {
            token_count: config.token_threshold,
            entity_count: config.entity_threshold,
            graph_depth: 1,
            ..Default::default()
        };
        assert_eq!(at_limit.classify(&config), (Complexity::Simple, 0));

        let above_limit = QueryFeatures {
            token_count: config.token_threshold + 1,
            entity_count: config.entity_threshold + 1,
            graph_depth: 2,
            ..Default::default()
        };
        assert_eq!(above_limit.classify(&config), (Complexity::Complex, 3));
    }

    /// Every signal firing yields the maximum of six points.
    #[test]
    fn complex_query_all_features() {
        let features = QueryFeatures {
            token_count: 60,
            has_temporal: true,
            entity_count: 5,
            graph_depth: 3,
            has_causal: true,
            is_iterative: true,
        };
        let config = ComplexityConfig::default();
        let (complexity, points) = features.classify(&config);
        assert_eq!(complexity, Complexity::Complex);
        assert_eq!(points, 6);
    }

    /// Exactly three signals reach the default Complex threshold.
    #[test]
    fn complex_query_three_features() {
        let features = QueryFeatures {
            has_temporal: true,
            entity_count: 5,
            has_causal: true,
            ..Default::default()
        };
        let config = ComplexityConfig::default();
        let (complexity, points) = features.classify(&config);
        assert_eq!(complexity, Complexity::Complex);
        assert_eq!(points, 3);
    }

    /// Lowered token and Complex thresholds change both the score and the level.
    #[test]
    fn custom_thresholds() {
        let features = QueryFeatures {
            token_count: 30,
            has_temporal: true,
            ..Default::default()
        };
        let config = ComplexityConfig {
            token_threshold: 20,
            complex_threshold: 2,
            ..Default::default()
        };
        let (complexity, points) = features.classify(&config);
        assert_eq!(complexity, Complexity::Complex);
        assert_eq!(points, 2);
    }

    /// Smoke check that a single classification stays far below one millisecond.
    #[test]
    fn classification_sub_millisecond() {
        let iterations: u32 = 10_000;
        let features = QueryFeatures {
            token_count: 100,
            has_temporal: true,
            entity_count: 10,
            graph_depth: 5,
            has_causal: true,
            is_iterative: true,
        };
        let config = ComplexityConfig::default();
        let start = std::time::Instant::now();
        for _ in 0..iterations {
            // black_box the input as well so the call cannot be hoisted or folded.
            std::hint::black_box(std::hint::black_box(&features).classify(&config));
        }
        let elapsed = start.elapsed();
        // The budget is 1ms *per classification*. A tight wall-clock bound on the
        // whole loop is flaky on loaded CI machines and in unoptimised builds.
        let budget = std::time::Duration::from_millis(1) * iterations;
        assert!(
            elapsed < budget,
            "too slow: {elapsed:?} for {iterations} classifications"
        );
    }

    /// Executing partition 0 yields exactly one single-row batch.
    #[tokio::test]
    async fn execute_produces_batch() {
        let features = QueryFeatures {
            has_temporal: true,
            has_causal: true,
            entity_count: 5,
            ..Default::default()
        };
        let exec = QueryComplexityExec::new(features, ComplexityConfig::default());
        let ctx = Arc::new(TaskContext::default());
        let mut stream = exec.execute(0, ctx).unwrap();

        let batch = stream.next().await.unwrap().unwrap();
        assert_eq!(batch.num_rows(), 1);

        let complexity = batch
            .column(0)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        assert_eq!(complexity.value(0), "Complex");

        let points = batch
            .column(1)
            .as_any()
            .downcast_ref::<UInt32Array>()
            .unwrap();
        assert_eq!(points.value(0), 3);

        // The operator emits a single batch and then ends the stream.
        assert!(stream.next().await.is_none());
    }
}