Skip to main content

hirn_engine/
adaptive.rs

//! Adaptive retrieval strategy (Jeong et al., NAACL 2024).
//!
//! Classifies query complexity using lightweight heuristics and routes the
//! query to the retrieval strategy best suited to it:
//!
//! | Complexity | Strategy                                    |
//! |------------|---------------------------------------------|
//! | Simple     | Local only (HNSW + spreading activation)    |
//! | Moderate   | Hybrid (local + community global)           |
//! | Complex    | Full pipeline: RAPTOR + community + local   |
//!
//! The classifier sums five heuristic signals into a single score:
//! 1. **Token count** — longer queries tend to be more complex.
//! 2. **Clause count** — WHERE clauses and INVOLVING entities each add
//!    complexity; multi-entity queries benefit from graph traversal.
//! 3. **Question words** — analytical phrasings ("why", "compare",
//!    "trade-off") weigh more than plain interrogatives ("how", "which", "who").
//! 4. **Temporal scope** — a temporal constraint adds moderate complexity.
//! 5. **Graph traversal** — expand / follow-causes requests need multi-hop
//!    retrieval and push the query toward the full pipeline.
//!
//! Reference: "Adaptive-RAG: Learning to Adapt Retrieval-Augmented
//!             Large Language Models through Question Complexity"
//!             (Jeong et al., NAACL 2024)
22
23use hirn_query::ast::RetrievalMode;
24
/// Query complexity level determined by the adaptive classifier.
///
/// Variants are declared from least to most complex, and the derived
/// `PartialOrd`/`Ord` follow declaration order (`Simple < Moderate < Complex`).
/// Callers can therefore write threshold checks such as
/// `complexity >= QueryComplexity::Moderate`, or pick the hardest of several
/// classifications with `max()`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum QueryComplexity {
    /// Factoid / keyword lookups — vector search is sufficient.
    Simple,
    /// Multi-faceted queries — benefit from both local and global retrieval.
    Moderate,
    /// Analytical / comparative / multi-hop — need the full retrieval pipeline.
    Complex,
}
35
36/// Classify query complexity and return the recommended `RetrievalMode`.
37///
38/// This is a deterministic, rule-based classifier inspired by Adaptive-RAG.
39/// It avoids the cost of an LLM call for routing while still achieving good
40/// strategy selection for most queries.
41pub fn classify_and_route(
42    query: &str,
43    involving_count: usize,
44    where_count: usize,
45    has_temporal: bool,
46    has_expand: bool,
47    has_follow_causes: bool,
48) -> RetrievalMode {
49    let complexity = classify_query(
50        query,
51        involving_count,
52        where_count,
53        has_temporal,
54        has_expand,
55        has_follow_causes,
56    );
57
58    match complexity {
59        QueryComplexity::Simple => RetrievalMode::Local,
60        QueryComplexity::Moderate => RetrievalMode::Hybrid,
61        QueryComplexity::Complex => RetrievalMode::Raptor,
62    }
63}
64
65/// Classify query complexity into Simple / Moderate / Complex.
66pub fn classify_query(
67    query: &str,
68    involving_count: usize,
69    where_count: usize,
70    has_temporal: bool,
71    has_expand: bool,
72    has_follow_causes: bool,
73) -> QueryComplexity {
74    let mut score: u32 = 0;
75
76    // Signal 1: Token count (whitespace-split approximation).
77    let token_count = query.split_whitespace().count();
78    if token_count >= 20 {
79        score += 3;
80    } else if token_count >= 10 {
81        score += 2;
82    } else if token_count >= 4 {
83        score += 1;
84    }
85
86    // Signal 2: Clause count — each additional clause adds complexity.
87    score += (where_count as u32).min(3);
88    if involving_count > 2 {
89        score += 2;
90    } else if involving_count > 0 {
91        score += 1;
92    }
93
94    // Signal 3: Complex question words / analytical patterns.
95    let lower = query.to_lowercase();
96    let complex_patterns = [
97        "compare",
98        "contrast",
99        "why",
100        "how does",
101        "what caused",
102        "relationship between",
103        "difference between",
104        "trade-off",
105        "pros and cons",
106        "implications of",
107        "summarize all",
108        "overview of",
109        "explain the",
110        "analyze",
111    ];
112    let moderate_patterns = [
113        "how", "what are", "describe", "list", "when did", "where", "who", "which",
114    ];
115
116    let complex_hits = complex_patterns
117        .iter()
118        .filter(|p| lower.contains(*p))
119        .count();
120    let moderate_hits = moderate_patterns
121        .iter()
122        .filter(|p| lower.contains(*p))
123        .count();
124
125    score += (complex_hits as u32) * 2;
126    score += (moderate_hits as u32).min(2);
127
128    // Signal 4: Temporal scope adds moderate complexity.
129    if has_temporal {
130        score += 2;
131    }
132
133    // Signal 5: Expand / follow_causes demand graph traversal.
134    if has_expand {
135        score += 3;
136    }
137    if has_follow_causes {
138        score += 3;
139    }
140
141    // Route based on aggregate score.
142    if score >= 6 {
143        QueryComplexity::Complex
144    } else if score >= 3 {
145        QueryComplexity::Moderate
146    } else {
147        QueryComplexity::Simple
148    }
149}
150
#[cfg(test)]
mod tests {
    use super::*;

    /// Classifies `query` with no INVOLVING/WHERE clauses and no flags set.
    fn classify_plain(query: &str) -> QueryComplexity {
        classify_query(query, 0, 0, false, false, false)
    }

    #[test]
    fn simple_factoid_query() {
        assert_eq!(classify_plain("what is JWT"), QueryComplexity::Simple);
    }

    #[test]
    fn moderate_query_with_entity() {
        let query = "how does authentication work with OAuth tokens";
        let (involving, wheres) = (1, 0);
        let complexity = classify_query(query, involving, wheres, false, false, false);
        assert_eq!(complexity, QueryComplexity::Moderate);
    }

    #[test]
    fn complex_analytical_query() {
        let query = "compare the trade-off between JWT and session-based authentication across all services";
        let (involving, wheres) = (3, 1);
        let (temporal, expand, follow_causes) = (false, true, false);
        assert_eq!(
            classify_query(query, involving, wheres, temporal, expand, follow_causes),
            QueryComplexity::Complex
        );
    }

    #[test]
    fn temporal_adds_complexity() {
        let temporal = true;
        assert_eq!(
            classify_query("what happened with deployments", 0, 0, temporal, false, false),
            QueryComplexity::Moderate
        );
    }

    #[test]
    fn follow_causes_is_complex() {
        let follow_causes = true;
        assert_eq!(
            classify_query("why did the service fail", 0, 0, false, false, follow_causes),
            QueryComplexity::Complex
        );
    }

    #[test]
    fn classify_and_route_simple() {
        assert_eq!(
            classify_and_route("hello", 0, 0, false, false, false),
            RetrievalMode::Local
        );
    }

    #[test]
    fn classify_and_route_complex() {
        let query = "compare all authentication strategies and their trade-offs";
        let (involving, wheres) = (2, 1);
        let (temporal, expand, follow_causes) = (true, true, false);
        let mode = classify_and_route(query, involving, wheres, temporal, expand, follow_causes);
        assert_eq!(mode, RetrievalMode::Raptor);
    }
}