1use std::sync::Arc;
7use tracing::warn;
8
9use brainwires_core::message::Message;
10use brainwires_core::provider::{ChatOptions, Provider};
11
12use crate::InferenceTimer;
13
/// How strongly a query appears to depend on earlier conversation context.
///
/// Variants are declared from the weakest need to the strongest.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RetrievalNeed {
    /// The query is self-contained; no prior context is needed.
    None,
    /// Prior context might help but is not required.
    Low,
    /// The query likely references earlier discussion.
    Medium,
    /// The query definitely refers to prior conversation.
    High,
}

impl RetrievalNeed {
    /// Returns `true` when retrieval should be performed, i.e. for
    /// [`RetrievalNeed::Medium`] and [`RetrievalNeed::High`] only.
    pub fn should_retrieve(&self) -> bool {
        match self {
            Self::Medium | Self::High => true,
            Self::None | Self::Low => false,
        }
    }

    /// Converts the level into its fixed numeric score:
    /// `None` = 0.0, `Low` = 0.25, `Medium` = 0.6, `High` = 0.9.
    pub fn as_score(&self) -> f32 {
        match *self {
            Self::High => 0.9,
            Self::Medium => 0.6,
            Self::Low => 0.25,
            Self::None => 0.0,
        }
    }
}
43
44#[derive(Clone, Debug)]
46pub struct ClassificationResult {
47 pub need: RetrievalNeed,
49 pub confidence: f32,
51 pub used_local_llm: bool,
53 pub intent: Option<String>,
55}
56
57impl ClassificationResult {
58 pub fn from_local(need: RetrievalNeed, confidence: f32, intent: Option<String>) -> Self {
60 Self {
61 need,
62 confidence,
63 used_local_llm: true,
64 intent,
65 }
66 }
67
68 pub fn from_fallback(need: RetrievalNeed, confidence: f32) -> Self {
70 Self {
71 need,
72 confidence,
73 used_local_llm: false,
74 intent: None,
75 }
76 }
77}
78
79pub struct RetrievalClassifier {
81 provider: Arc<dyn Provider>,
82 model_id: String,
83}
84
85impl RetrievalClassifier {
86 pub fn new(provider: Arc<dyn Provider>, model_id: impl Into<String>) -> Self {
88 Self {
89 provider,
90 model_id: model_id.into(),
91 }
92 }
93
94 pub async fn classify(&self, query: &str, context_len: usize) -> Option<ClassificationResult> {
98 let timer = InferenceTimer::new("retrieval_classify", &self.model_id);
99
100 let prompt = self.build_classification_prompt(query, context_len);
101
102 let messages = vec![Message::user(&prompt)];
103 let options = ChatOptions::deterministic(50);
104
105 match self.provider.chat(&messages, None, &options).await {
106 Ok(response) => {
107 let output = response.message.text_or_summary();
108 let result = self.parse_classification(&output);
109 timer.finish(true);
110 Some(result)
111 }
112 Err(e) => {
113 warn!(target: "local_llm", "Retrieval classification failed: {}", e);
114 timer.finish(false);
115 None
116 }
117 }
118 }
119
120 pub fn classify_heuristic(&self, query: &str, context_len: usize) -> ClassificationResult {
124 let lower = query.to_lowercase();
125 let mut score = 0.0f32;
126 let mut matches = 0;
127
128 let reference_patterns = [
130 "earlier",
131 "before",
132 "we discussed",
133 "remember when",
134 "what was",
135 "didn't we",
136 "you mentioned",
137 "as i said",
138 "previously",
139 "last time",
140 "originally",
141 "initially",
142 "you said",
143 "i said",
144 "we talked",
145 "back when",
146 "recall",
147 "mentioned earlier",
148 "as mentioned",
149 ];
150
151 for pattern in reference_patterns {
152 if lower.contains(pattern) {
153 score += 0.4;
154 matches += 1;
155 }
156 }
157
158 let question_patterns = [
160 "what did",
161 "when did",
162 "why did",
163 "how did",
164 "where was",
165 "who was",
166 ];
167
168 for pattern in question_patterns {
169 if lower.contains(pattern) {
170 score += 0.25;
171 matches += 1;
172 }
173 }
174
175 let continuation_patterns = [
177 "continue",
178 "keep going",
179 "and then",
180 "what about",
181 "more about",
182 "tell me more",
183 "go on",
184 ];
185
186 for pattern in continuation_patterns {
187 if lower.contains(pattern) {
188 score += 0.15;
189 matches += 1;
190 }
191 }
192
193 if context_len < 3 {
195 score += 0.3;
196 } else if context_len < 5 {
197 score += 0.2;
198 } else if context_len < 10 {
199 score += 0.1;
200 }
201
202 if context_len < 10 && query.len() < 100 && lower.contains('?') {
204 let pronouns = ["it", "they", "that", "those", "the one"];
205 if pronouns
206 .iter()
207 .any(|p| lower.split_whitespace().any(|w| w == *p))
208 {
209 score += 0.2;
210 }
211 }
212
213 score = score.min(1.0);
214
215 let need = match score {
216 s if s >= 0.6 => RetrievalNeed::High,
217 s if s >= 0.35 => RetrievalNeed::Medium,
218 s if s >= 0.15 => RetrievalNeed::Low,
219 _ => RetrievalNeed::None,
220 };
221
222 let confidence = if matches > 0 {
223 0.7 + (matches as f32 * 0.05).min(0.2)
224 } else {
225 0.5
226 };
227
228 ClassificationResult::from_fallback(need, confidence)
229 }
230
231 fn build_classification_prompt(&self, query: &str, context_len: usize) -> String {
233 format!(
234 r#"Classify if this query needs to retrieve earlier conversation context.
235
236Query: "{}"
237Recent context messages: {}
238
239Classify as:
240- NONE: Query is self-contained, no prior context needed
241- LOW: Might benefit from context but not required
242- MEDIUM: Likely references earlier discussion
243- HIGH: Definitely refers to prior conversation
244
245Output format: LEVEL: brief reason
246Example: HIGH: references "earlier" and asks about past discussion
247
248Classification:"#,
249 if query.len() > 200 {
250 &query[..200]
251 } else {
252 query
253 },
254 context_len
255 )
256 }
257
258 fn parse_classification(&self, output: &str) -> ClassificationResult {
260 let upper = output.to_uppercase();
261 let trimmed = output.trim();
262
263 let intent = trimmed
265 .find(':')
266 .map(|colon_pos| trimmed[colon_pos + 1..].trim().to_string());
267
268 let need = if upper.starts_with("HIGH") || upper.contains("HIGH:") {
270 RetrievalNeed::High
271 } else if upper.starts_with("MEDIUM") || upper.contains("MEDIUM:") {
272 RetrievalNeed::Medium
273 } else if upper.starts_with("LOW") || upper.contains("LOW:") {
274 RetrievalNeed::Low
275 } else if upper.starts_with("NONE") || upper.contains("NONE:") {
276 RetrievalNeed::None
277 } else {
278 RetrievalNeed::Low
280 };
281
282 ClassificationResult::from_local(need, 0.8, intent)
283 }
284}
285
286pub struct RetrievalClassifierBuilder {
288 provider: Option<Arc<dyn Provider>>,
289 model_id: String,
290}
291
292impl Default for RetrievalClassifierBuilder {
293 fn default() -> Self {
294 Self {
295 provider: None,
296 model_id: "lfm2-350m".to_string(),
297 }
298 }
299}
300
301impl RetrievalClassifierBuilder {
302 pub fn new() -> Self {
304 Self::default()
305 }
306
307 pub fn provider(mut self, provider: Arc<dyn Provider>) -> Self {
309 self.provider = Some(provider);
310 self
311 }
312
313 pub fn model_id(mut self, model_id: impl Into<String>) -> Self {
315 self.model_id = model_id.into();
316 self
317 }
318
319 pub fn build(self) -> Option<RetrievalClassifier> {
321 self.provider
322 .map(|p| RetrievalClassifier::new(p, self.model_id))
323 }
324}
325
#[cfg(test)]
mod tests {
    use super::*;

    /// `should_retrieve` is true only for Medium/High, and scores are ordered.
    #[test]
    fn test_retrieval_need_methods() {
        assert!(!RetrievalNeed::None.should_retrieve());
        assert!(!RetrievalNeed::Low.should_retrieve());
        assert!(RetrievalNeed::Medium.should_retrieve());
        assert!(RetrievalNeed::High.should_retrieve());

        assert_eq!(RetrievalNeed::None.as_score(), 0.0);
        assert!(RetrievalNeed::High.as_score() > RetrievalNeed::Low.as_score());
    }

    /// The two constructors set `used_local_llm` and `intent` as expected.
    #[test]
    fn test_classification_result() {
        let local = ClassificationResult::from_local(
            RetrievalNeed::High,
            0.9,
            Some("references earlier discussion".to_string()),
        );
        assert!(local.used_local_llm);
        assert!(local.intent.is_some());

        let fallback = ClassificationResult::from_fallback(RetrievalNeed::Medium, 0.7);
        assert!(!fallback.used_local_llm);
        assert!(fallback.intent.is_none());
    }

    /// "earlier" (+0.4) plus "what did" (+0.25) reaches the High threshold.
    #[test]
    fn test_heuristic_classification_reference() {
        let result = classify_heuristic_direct("What did we discuss earlier?", 10);
        assert_eq!(result.need, RetrievalNeed::High);
        assert!(!result.used_local_llm);
    }

    /// A self-contained request with plenty of context scores zero.
    #[test]
    fn test_heuristic_classification_none() {
        let result = classify_heuristic_direct("Write a hello world function in Python", 20);
        assert_eq!(result.need, RetrievalNeed::None);
        // No pattern matched, so the baseline confidence is used.
        assert_eq!(result.confidence, 0.5);
    }

    /// "continue" (+0.15) plus the short-context bonus (+0.3) gives Medium.
    #[test]
    fn test_heuristic_short_context() {
        let result = classify_heuristic_direct("Continue please", 2);
        assert!(result.need.should_retrieve());
    }

    /// Simplified, provider-free copy of the keyword heuristic.
    ///
    /// It carries only a subset of the production pattern lists and omits the
    /// `context_len < 10` bonus and the pronoun bonus.
    /// TODO: exercise `RetrievalClassifier::classify_heuristic` directly once
    /// a test `Provider` is available, so production logic is what is tested.
    fn classify_heuristic_direct(query: &str, context_len: usize) -> ClassificationResult {
        let lower = query.to_lowercase();
        let mut score = 0.0f32;
        let mut matches = 0;

        let reference_patterns = ["earlier", "before", "we discussed", "previously"];

        for pattern in reference_patterns {
            if lower.contains(pattern) {
                score += 0.4;
                matches += 1;
            }
        }

        let question_patterns = ["what did", "when did", "why did"];

        for pattern in question_patterns {
            if lower.contains(pattern) {
                score += 0.25;
                matches += 1;
            }
        }

        let continuation_patterns = ["continue", "keep going", "and then"];

        for pattern in continuation_patterns {
            if lower.contains(pattern) {
                score += 0.15;
                matches += 1;
            }
        }

        if context_len < 3 {
            score += 0.3;
        } else if context_len < 5 {
            score += 0.2;
        }

        score = score.min(1.0);

        let need = match score {
            s if s >= 0.6 => RetrievalNeed::High,
            s if s >= 0.35 => RetrievalNeed::Medium,
            s if s >= 0.15 => RetrievalNeed::Low,
            _ => RetrievalNeed::None,
        };

        let confidence = if matches > 0 {
            0.7 + (matches as f32 * 0.05).min(0.2)
        } else {
            0.5
        };

        ClassificationResult::from_fallback(need, confidence)
    }

    /// Level words at the start of the reply select the level; the text after
    /// the first colon becomes the intent; unknown replies default to Low.
    #[test]
    fn test_parse_classification() {
        let high = parse_classification_direct("HIGH: references earlier discussion");
        assert_eq!(high.need, RetrievalNeed::High);
        assert_eq!(high.intent.as_deref(), Some("references earlier discussion"));
        assert!(high.used_local_llm);

        let none = parse_classification_direct("NONE: self-contained query");
        assert_eq!(none.need, RetrievalNeed::None);

        let medium = parse_classification_direct("medium: likely refers back");
        assert_eq!(medium.need, RetrievalNeed::Medium);

        let unknown = parse_classification_direct("no idea");
        assert_eq!(unknown.need, RetrievalNeed::Low);
        assert!(unknown.intent.is_none());
    }

    /// Simplified, provider-free copy of the reply parser.
    ///
    /// It only checks the level word at the very start of the reply and does
    /// not include the `LEVEL:` substring fallback.
    fn parse_classification_direct(output: &str) -> ClassificationResult {
        let upper = output.to_uppercase();
        let trimmed = output.trim();

        let intent = trimmed
            .find(':')
            .map(|colon_pos| trimmed[colon_pos + 1..].trim().to_string());

        let need = if upper.starts_with("HIGH") {
            RetrievalNeed::High
        } else if upper.starts_with("MEDIUM") {
            RetrievalNeed::Medium
        } else if upper.starts_with("LOW") {
            RetrievalNeed::Low
        } else if upper.starts_with("NONE") {
            RetrievalNeed::None
        } else {
            RetrievalNeed::Low
        };

        ClassificationResult::from_local(need, 0.8, intent)
    }
}