// brainwires_reasoning/retrieval_classifier.rs

1//! Retrieval Classifier - Enhanced Retrieval Gating
2//!
3//! Uses a provider to classify retrieval need semantically,
4//! replacing pattern-based detection with understanding of intent.
5
6use std::sync::Arc;
7use tracing::warn;
8
9use brainwires_core::message::Message;
10use brainwires_core::provider::{ChatOptions, Provider};
11
12use crate::InferenceTimer;
13
/// Result of retrieval classification.
///
/// Variants are ordered from weakest to strongest signal that the
/// current query depends on earlier conversation history.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum RetrievalNeed {
    /// Context already on hand is sufficient; no retrieval needed.
    None,
    /// Retrieval could help but is not required.
    Low,
    /// Query likely references earlier discussion; retrieval advised.
    Medium,
    /// Query definitely refers to prior conversation; retrieve.
    High,
}
26
27impl RetrievalNeed {
28    /// Check if retrieval should be performed
29    pub fn should_retrieve(&self) -> bool {
30        matches!(self, RetrievalNeed::Medium | RetrievalNeed::High)
31    }
32
33    /// Convert to a priority score (0.0 - 1.0)
34    pub fn as_score(&self) -> f32 {
35        match self {
36            RetrievalNeed::None => 0.0,
37            RetrievalNeed::Low => 0.25,
38            RetrievalNeed::Medium => 0.6,
39            RetrievalNeed::High => 0.9,
40        }
41    }
42}
43
44/// Result of classification with confidence
45#[derive(Clone, Debug)]
46pub struct ClassificationResult {
47    /// The classified retrieval need
48    pub need: RetrievalNeed,
49    /// Confidence score (0.0 - 1.0)
50    pub confidence: f32,
51    /// Whether LLM was used
52    pub used_local_llm: bool,
53    /// Detected intent (if LLM was used)
54    pub intent: Option<String>,
55}
56
57impl ClassificationResult {
58    /// Create a result from LLM classification
59    pub fn from_local(need: RetrievalNeed, confidence: f32, intent: Option<String>) -> Self {
60        Self {
61            need,
62            confidence,
63            used_local_llm: true,
64            intent,
65        }
66    }
67
68    /// Create a result from pattern-based fallback
69    pub fn from_fallback(need: RetrievalNeed, confidence: f32) -> Self {
70        Self {
71            need,
72            confidence,
73            used_local_llm: false,
74            intent: None,
75        }
76    }
77}
78
79/// Retrieval classifier for enhanced gating
80pub struct RetrievalClassifier {
81    provider: Arc<dyn Provider>,
82    model_id: String,
83}
84
85impl RetrievalClassifier {
86    /// Create a new retrieval classifier
87    pub fn new(provider: Arc<dyn Provider>, model_id: impl Into<String>) -> Self {
88        Self {
89            provider,
90            model_id: model_id.into(),
91        }
92    }
93
94    /// Classify retrieval need using the provider
95    ///
96    /// Returns classification with intent understanding.
97    pub async fn classify(&self, query: &str, context_len: usize) -> Option<ClassificationResult> {
98        let timer = InferenceTimer::new("retrieval_classify", &self.model_id);
99
100        let prompt = self.build_classification_prompt(query, context_len);
101
102        let messages = vec![Message::user(&prompt)];
103        let options = ChatOptions::deterministic(50);
104
105        match self.provider.chat(&messages, None, &options).await {
106            Ok(response) => {
107                let output = response.message.text_or_summary();
108                let result = self.parse_classification(&output);
109                timer.finish(true);
110                Some(result)
111            }
112            Err(e) => {
113                warn!(target: "local_llm", "Retrieval classification failed: {}", e);
114                timer.finish(false);
115                None
116            }
117        }
118    }
119
120    /// Heuristic classification (pattern-based fallback)
121    ///
122    /// Used when provider is unavailable or fails.
123    pub fn classify_heuristic(&self, query: &str, context_len: usize) -> ClassificationResult {
124        let lower = query.to_lowercase();
125        let mut score = 0.0f32;
126        let mut matches = 0;
127
128        // Reference patterns (high weight)
129        let reference_patterns = [
130            "earlier",
131            "before",
132            "we discussed",
133            "remember when",
134            "what was",
135            "didn't we",
136            "you mentioned",
137            "as i said",
138            "previously",
139            "last time",
140            "originally",
141            "initially",
142            "you said",
143            "i said",
144            "we talked",
145            "back when",
146            "recall",
147            "mentioned earlier",
148            "as mentioned",
149        ];
150
151        for pattern in reference_patterns {
152            if lower.contains(pattern) {
153                score += 0.4;
154                matches += 1;
155            }
156        }
157
158        // Question patterns (medium weight)
159        let question_patterns = [
160            "what did",
161            "when did",
162            "why did",
163            "how did",
164            "where was",
165            "who was",
166        ];
167
168        for pattern in question_patterns {
169            if lower.contains(pattern) {
170                score += 0.25;
171                matches += 1;
172            }
173        }
174
175        // Continuation patterns (low weight)
176        let continuation_patterns = [
177            "continue",
178            "keep going",
179            "and then",
180            "what about",
181            "more about",
182            "tell me more",
183            "go on",
184        ];
185
186        for pattern in continuation_patterns {
187            if lower.contains(pattern) {
188                score += 0.15;
189                matches += 1;
190            }
191        }
192
193        // Context length adjustment
194        if context_len < 3 {
195            score += 0.3;
196        } else if context_len < 5 {
197            score += 0.2;
198        } else if context_len < 10 {
199            score += 0.1;
200        }
201
202        // Pronoun patterns (only for short queries)
203        if context_len < 10 && query.len() < 100 && lower.contains('?') {
204            let pronouns = ["it", "they", "that", "those", "the one"];
205            if pronouns
206                .iter()
207                .any(|p| lower.split_whitespace().any(|w| w == *p))
208            {
209                score += 0.2;
210            }
211        }
212
213        score = score.min(1.0);
214
215        let need = match score {
216            s if s >= 0.6 => RetrievalNeed::High,
217            s if s >= 0.35 => RetrievalNeed::Medium,
218            s if s >= 0.15 => RetrievalNeed::Low,
219            _ => RetrievalNeed::None,
220        };
221
222        let confidence = if matches > 0 {
223            0.7 + (matches as f32 * 0.05).min(0.2)
224        } else {
225            0.5
226        };
227
228        ClassificationResult::from_fallback(need, confidence)
229    }
230
231    /// Build the classification prompt
232    fn build_classification_prompt(&self, query: &str, context_len: usize) -> String {
233        format!(
234            r#"Classify if this query needs to retrieve earlier conversation context.
235
236Query: "{}"
237Recent context messages: {}
238
239Classify as:
240- NONE: Query is self-contained, no prior context needed
241- LOW: Might benefit from context but not required
242- MEDIUM: Likely references earlier discussion
243- HIGH: Definitely refers to prior conversation
244
245Output format: LEVEL: brief reason
246Example: HIGH: references "earlier" and asks about past discussion
247
248Classification:"#,
249            if query.len() > 200 {
250                &query[..200]
251            } else {
252                query
253            },
254            context_len
255        )
256    }
257
258    /// Parse the LLM output to extract classification
259    fn parse_classification(&self, output: &str) -> ClassificationResult {
260        let upper = output.to_uppercase();
261        let trimmed = output.trim();
262
263        // Extract intent from the reason part
264        let intent = trimmed
265            .find(':')
266            .map(|colon_pos| trimmed[colon_pos + 1..].trim().to_string());
267
268        // Parse the level
269        let need = if upper.starts_with("HIGH") || upper.contains("HIGH:") {
270            RetrievalNeed::High
271        } else if upper.starts_with("MEDIUM") || upper.contains("MEDIUM:") {
272            RetrievalNeed::Medium
273        } else if upper.starts_with("LOW") || upper.contains("LOW:") {
274            RetrievalNeed::Low
275        } else if upper.starts_with("NONE") || upper.contains("NONE:") {
276            RetrievalNeed::None
277        } else {
278            // Ambiguous - default to low
279            RetrievalNeed::Low
280        };
281
282        ClassificationResult::from_local(need, 0.8, intent)
283    }
284}
285
286/// Builder for RetrievalClassifier
287pub struct RetrievalClassifierBuilder {
288    provider: Option<Arc<dyn Provider>>,
289    model_id: String,
290}
291
292impl Default for RetrievalClassifierBuilder {
293    fn default() -> Self {
294        Self {
295            provider: None,
296            model_id: "lfm2-350m".to_string(),
297        }
298    }
299}
300
301impl RetrievalClassifierBuilder {
302    /// Create a new builder with default settings.
303    pub fn new() -> Self {
304        Self::default()
305    }
306
307    /// Set the provider to use for retrieval classification.
308    pub fn provider(mut self, provider: Arc<dyn Provider>) -> Self {
309        self.provider = Some(provider);
310        self
311    }
312
313    /// Set the model ID to use for inference.
314    pub fn model_id(mut self, model_id: impl Into<String>) -> Self {
315        self.model_id = model_id.into();
316        self
317    }
318
319    /// Build the retrieval classifier, returning `None` if no provider was set.
320    pub fn build(self) -> Option<RetrievalClassifier> {
321        self.provider
322            .map(|p| RetrievalClassifier::new(p, self.model_id))
323    }
324}
325
#[cfg(test)]
mod tests {
    use super::*;

    // NOTE(review): the `*_direct` helpers below duplicate a subset of the
    // production logic because constructing a `RetrievalClassifier` requires
    // a `Provider`. They do NOT exercise the real methods; keep them in sync
    // with `classify_heuristic` / `parse_classification` (or replace them
    // with a mock provider).

    #[test]
    fn test_retrieval_need_methods() {
        assert!(!RetrievalNeed::None.should_retrieve());
        assert!(!RetrievalNeed::Low.should_retrieve());
        assert!(RetrievalNeed::Medium.should_retrieve());
        assert!(RetrievalNeed::High.should_retrieve());

        assert_eq!(RetrievalNeed::None.as_score(), 0.0);
        assert!(RetrievalNeed::High.as_score() > RetrievalNeed::Low.as_score());
    }

    #[test]
    fn test_classification_result() {
        let local = ClassificationResult::from_local(
            RetrievalNeed::High,
            0.9,
            Some("references earlier discussion".to_string()),
        );
        assert!(local.used_local_llm);
        assert!(local.intent.is_some());

        let fallback = ClassificationResult::from_fallback(RetrievalNeed::Medium, 0.7);
        assert!(!fallback.used_local_llm);
        assert!(fallback.intent.is_none());
    }

    #[test]
    fn test_heuristic_classification_reference() {
        // "earlier" (0.4) + "we discussed" (0.4) + "what did" (0.25) lands
        // well above the 0.6 High threshold.
        let result = classify_heuristic_direct("What did we discuss earlier?", 10);
        assert_eq!(result.need, RetrievalNeed::High);
    }

    #[test]
    fn test_heuristic_classification_none() {
        // Self-contained request with a long context: no patterns match and
        // no context bonus applies.
        let result = classify_heuristic_direct("Write a hello world function in Python", 20);
        assert_eq!(result.need, RetrievalNeed::None);
    }

    #[test]
    fn test_heuristic_short_context() {
        // Short context should increase retrieval need:
        // "continue" (0.15) + short-context bonus (0.3) >= 0.35 (Medium).
        let result = classify_heuristic_direct("Continue please", 2);
        assert!(result.need.should_retrieve());
    }

    /// Simplified mirror of `RetrievalClassifier::classify_heuristic`
    /// (subset of patterns, no pronoun handling).
    fn classify_heuristic_direct(query: &str, context_len: usize) -> ClassificationResult {
        let lower = query.to_lowercase();
        let mut score = 0.0f32;
        let mut matches = 0;

        let reference_patterns = ["earlier", "before", "we discussed", "previously"];

        for pattern in reference_patterns {
            if lower.contains(pattern) {
                score += 0.4;
                matches += 1;
            }
        }

        let question_patterns = ["what did", "when did", "why did"];

        for pattern in question_patterns {
            if lower.contains(pattern) {
                score += 0.25;
                matches += 1;
            }
        }

        // Continuation patterns (matching the real implementation)
        let continuation_patterns = ["continue", "keep going", "and then"];

        for pattern in continuation_patterns {
            if lower.contains(pattern) {
                score += 0.15;
                matches += 1;
            }
        }

        if context_len < 3 {
            score += 0.3;
        } else if context_len < 5 {
            score += 0.2;
        }

        score = score.min(1.0);

        // Same score-to-level thresholds as the production method.
        let need = match score {
            s if s >= 0.6 => RetrievalNeed::High,
            s if s >= 0.35 => RetrievalNeed::Medium,
            s if s >= 0.15 => RetrievalNeed::Low,
            _ => RetrievalNeed::None,
        };

        let confidence = if matches > 0 {
            0.7 + (matches as f32 * 0.05).min(0.2)
        } else {
            0.5
        };

        ClassificationResult::from_fallback(need, confidence)
    }

    #[test]
    fn test_parse_classification() {
        // Test parsing logic
        let high = parse_classification_direct("HIGH: references earlier discussion");
        assert_eq!(high.need, RetrievalNeed::High);

        let none = parse_classification_direct("NONE: self-contained query");
        assert_eq!(none.need, RetrievalNeed::None);
    }

    /// Simplified mirror of `RetrievalClassifier::parse_classification`
    /// (prefix matching only; no mid-string "LEVEL:" detection).
    fn parse_classification_direct(output: &str) -> ClassificationResult {
        let upper = output.to_uppercase();
        let trimmed = output.trim();

        // Everything after the first ':' is the intent/reason text.
        let intent = trimmed
            .find(':')
            .map(|colon_pos| trimmed[colon_pos + 1..].trim().to_string());

        let need = if upper.starts_with("HIGH") {
            RetrievalNeed::High
        } else if upper.starts_with("MEDIUM") {
            RetrievalNeed::Medium
        } else if upper.starts_with("LOW") {
            RetrievalNeed::Low
        } else if upper.starts_with("NONE") {
            RetrievalNeed::None
        } else {
            RetrievalNeed::Low
        };

        ClassificationResult::from_local(need, 0.8, intent)
    }
}