1use std::sync::Arc;
7use tracing::warn;
8
9use brainwires_core::message::Message;
10use brainwires_core::provider::{ChatOptions, Provider};
11
12use crate::InferenceTimer;
13
/// Outcome of scoring a single context item, retaining both the model-assigned
/// relevance and the score the item originally carried.
#[derive(Clone, Debug)]
pub struct RelevanceResult {
    /// The item's text content.
    pub content: String,
    /// Position of the item in the caller-supplied slice.
    pub original_index: usize,
    /// Relevance score in [0.0, 1.0] used for ranking/filtering.
    pub relevance_score: f32,
    /// Score the item had before re-ranking.
    pub original_score: f32,
    /// True when `relevance_score` came from the local model rather than the
    /// fallback path.
    pub used_local_llm: bool,
}

impl RelevanceResult {
    /// Builds a result whose relevance was produced by the local model.
    pub fn from_local(
        content: String,
        original_index: usize,
        relevance_score: f32,
        original_score: f32,
    ) -> Self {
        RelevanceResult {
            content,
            original_index,
            relevance_score,
            original_score,
            used_local_llm: true,
        }
    }

    /// Builds a fallback result: the original score doubles as the relevance
    /// score and the local-LLM flag is cleared.
    pub fn from_fallback(content: String, original_index: usize, original_score: f32) -> Self {
        RelevanceResult {
            content,
            original_index,
            relevance_score: original_score,
            original_score,
            used_local_llm: false,
        }
    }
}
57
/// Scores and re-ranks context items for relevance to a query using a local
/// LLM provider, falling back to the items' original scores or a word-overlap
/// heuristic when inference fails.
pub struct RelevanceScorer {
    // Chat provider that runs the scoring/re-ranking prompts.
    provider: Arc<dyn Provider>,
    // Model identifier; used for inference timing labels (presumably also for
    // model selection inside the provider — confirm against Provider::chat).
    model_id: String,
    // Minimum relevance score an item must reach to be kept (default 0.5).
    min_score: f32,
    // Maximum number of items included in a single re-rank prompt (default 10).
    max_items: usize,
}
67
68impl RelevanceScorer {
69 pub fn new(provider: Arc<dyn Provider>, model_id: impl Into<String>) -> Self {
71 Self {
72 provider,
73 model_id: model_id.into(),
74 min_score: 0.5,
75 max_items: 10,
76 }
77 }
78
79 pub fn with_min_score(mut self, min_score: f32) -> Self {
81 self.min_score = min_score;
82 self
83 }
84
85 pub fn with_max_items(mut self, max_items: usize) -> Self {
87 self.max_items = max_items;
88 self
89 }
90
91 pub async fn rerank<T: AsRef<str>>(
95 &self,
96 query: &str,
97 items: &[(T, f32)], ) -> Vec<RelevanceResult> {
99 let timer = InferenceTimer::new("rerank_context", &self.model_id);
100
101 let items_to_score: Vec<_> = items.iter().take(self.max_items).collect();
103
104 if items_to_score.is_empty() {
105 timer.finish(true);
106 return Vec::new();
107 }
108
109 let prompt = self.build_rerank_prompt(query, &items_to_score);
111
112 let messages = vec![Message::user(&prompt)];
113 let options = ChatOptions::deterministic(100);
114
115 match self.provider.chat(&messages, None, &options).await {
116 Ok(response) => {
117 let output = response.message.text_or_summary();
118 let mut results = self.parse_rerank_output(&output, items);
119
120 results.sort_by(|a, b| {
122 b.relevance_score
123 .partial_cmp(&a.relevance_score)
124 .unwrap_or(std::cmp::Ordering::Equal)
125 });
126
127 results.retain(|r| r.relevance_score >= self.min_score);
129
130 timer.finish(true);
131 results
132 }
133 Err(e) => {
134 warn!(target: "local_llm", "Context re-ranking failed: {}", e);
135 timer.finish(false);
136
137 items
139 .iter()
140 .enumerate()
141 .filter(|(_, (_, score))| *score >= self.min_score)
142 .map(|(i, (content, score))| {
143 RelevanceResult::from_fallback(content.as_ref().to_string(), i, *score)
144 })
145 .collect()
146 }
147 }
148 }
149
150 pub async fn score_relevance(&self, query: &str, content: &str) -> Option<f32> {
152 let timer = InferenceTimer::new("score_relevance", &self.model_id);
153
154 let prompt = format!(
155 r#"Rate the relevance of this content to the query.
156
157Query: "{}"
158
159Content: "{}"
160
161Output a score from 0.0 (irrelevant) to 1.0 (highly relevant).
162Output ONLY the decimal number.
163
164Score:"#,
165 if query.len() > 100 {
166 &query[..100]
167 } else {
168 query
169 },
170 if content.len() > 300 {
171 &content[..300]
172 } else {
173 content
174 }
175 );
176
177 let messages = vec![Message::user(&prompt)];
178 let options = ChatOptions::deterministic(10);
179
180 match self.provider.chat(&messages, None, &options).await {
181 Ok(response) => {
182 let output = response.message.text_or_summary();
183 let score = self.parse_score(&output);
184 timer.finish(score.is_some());
185 score
186 }
187 Err(e) => {
188 warn!(target: "local_llm", "Relevance scoring failed: {}", e);
189 timer.finish(false);
190 None
191 }
192 }
193 }
194
195 pub fn score_heuristic(&self, query: &str, content: &str) -> f32 {
197 let query_lower = query.to_lowercase();
198 let content_lower = content.to_lowercase();
199
200 let query_words: Vec<&str> = query_lower
202 .split_whitespace()
203 .filter(|w| w.len() > 2)
204 .collect();
205
206 if query_words.is_empty() {
207 return 0.5; }
209
210 let mut matches = 0;
212 for word in &query_words {
213 if content_lower.contains(word) {
214 matches += 1;
215 }
216 }
217
218 let overlap_ratio = matches as f32 / query_words.len() as f32;
220
221 let phrase_bonus = if content_lower.contains(&query_lower) {
223 0.2
224 } else {
225 0.0
226 };
227
228 (overlap_ratio * 0.8 + phrase_bonus).min(1.0)
229 }
230
231 fn build_rerank_prompt<T: AsRef<str>>(&self, query: &str, items: &[&(T, f32)]) -> String {
233 let mut prompt = format!(
234 r#"Rank these items by relevance to the query.
235
236Query: "{}"
237
238Items:
239"#,
240 if query.len() > 150 {
241 &query[..150]
242 } else {
243 query
244 }
245 );
246
247 for (i, (content, _)) in items.iter().enumerate() {
248 let truncated = if content.as_ref().len() > 150 {
249 &content.as_ref()[..150]
250 } else {
251 content.as_ref()
252 };
253 prompt.push_str(&format!("{}. {}\n", i + 1, truncated));
254 }
255
256 prompt.push_str(
257 r#"
258Output format: item_number:score (0.0-1.0)
259Example: 1:0.9, 2:0.3, 3:0.7
260
261Scores:"#,
262 );
263
264 prompt
265 }
266
267 fn parse_rerank_output<T: AsRef<str>>(
269 &self,
270 output: &str,
271 items: &[(T, f32)],
272 ) -> Vec<RelevanceResult> {
273 let mut results = Vec::new();
274 let mut scored_indices = std::collections::HashSet::new();
275
276 for part in output.split([',', '\n', ' ']) {
278 let part = part.trim();
279 if let Some(colon_pos) = part.find(':')
280 && let (Ok(idx), score_str) = (
281 part[..colon_pos].trim().parse::<usize>(),
282 part[colon_pos + 1..].trim(),
283 )
284 && let Ok(score) = score_str.parse::<f32>()
285 {
286 let actual_idx = idx.saturating_sub(1); if actual_idx < items.len() && !scored_indices.contains(&actual_idx) {
288 scored_indices.insert(actual_idx);
289 let (content, original_score) = &items[actual_idx];
290 results.push(RelevanceResult::from_local(
291 content.as_ref().to_string(),
292 actual_idx,
293 score.clamp(0.0, 1.0),
294 *original_score,
295 ));
296 }
297 }
298 }
299
300 for (i, (content, original_score)) in items.iter().enumerate() {
302 if !scored_indices.contains(&i) {
303 results.push(RelevanceResult::from_fallback(
304 content.as_ref().to_string(),
305 i,
306 *original_score,
307 ));
308 }
309 }
310
311 results
312 }
313
314 fn parse_score(&self, output: &str) -> Option<f32> {
316 let trimmed = output.trim();
317
318 if let Ok(score) = trimmed.parse::<f32>() {
320 return Some(score.clamp(0.0, 1.0));
321 }
322
323 if let Ok(re) = regex::Regex::new(r"(\d+\.?\d*)")
325 && let Some(captures) = re.captures(trimmed)
326 && let Some(m) = captures.get(1)
327 && let Ok(score) = m.as_str().parse::<f32>()
328 {
329 return Some(score.clamp(0.0, 1.0));
330 }
331
332 None
333 }
334}
335
336pub struct RelevanceScorerBuilder {
338 provider: Option<Arc<dyn Provider>>,
339 model_id: String,
340 min_score: f32,
341 max_items: usize,
342}
343
344impl Default for RelevanceScorerBuilder {
345 fn default() -> Self {
346 Self {
347 provider: None,
348 model_id: "lfm2-350m".to_string(),
349 min_score: 0.5,
350 max_items: 10,
351 }
352 }
353}
354
355impl RelevanceScorerBuilder {
356 pub fn new() -> Self {
358 Self::default()
359 }
360
361 pub fn provider(mut self, provider: Arc<dyn Provider>) -> Self {
363 self.provider = Some(provider);
364 self
365 }
366
367 pub fn model_id(mut self, model_id: impl Into<String>) -> Self {
369 self.model_id = model_id.into();
370 self
371 }
372
373 pub fn min_score(mut self, min_score: f32) -> Self {
375 self.min_score = min_score;
376 self
377 }
378
379 pub fn max_items(mut self, max_items: usize) -> Self {
381 self.max_items = max_items;
382 self
383 }
384
385 pub fn build(self) -> Option<RelevanceScorer> {
387 self.provider.map(|p| {
388 RelevanceScorer::new(p, self.model_id)
389 .with_min_score(self.min_score)
390 .with_max_items(self.max_items)
391 })
392 }
393}
394
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_relevance_result() {
        let local = RelevanceResult::from_local("test content".to_string(), 0, 0.9, 0.75);
        assert!(local.used_local_llm);
        assert_eq!(local.relevance_score, 0.9);
        assert_eq!(local.original_score, 0.75);

        let fallback = RelevanceResult::from_fallback("test content".to_string(), 1, 0.7);
        assert!(!fallback.used_local_llm);
        assert_eq!(fallback.relevance_score, 0.7);
    }

    #[test]
    fn test_heuristic_scoring() {
        let score = score_heuristic_direct(
            "rust async programming",
            "This article discusses async programming in Rust using tokio",
        );
        assert!(score > 0.5);

        let low_score = score_heuristic_direct(
            "python web development",
            "This article discusses async programming in Rust using tokio",
        );
        assert!(low_score < 0.3);
    }

    /// Provider-free mirror of `RelevanceScorer::score_heuristic`.
    fn score_heuristic_direct(query: &str, content: &str) -> f32 {
        let query_lower = query.to_lowercase();
        let content_lower = content.to_lowercase();

        let words: Vec<&str> = query_lower
            .split_whitespace()
            .filter(|w| w.len() > 2)
            .collect();

        if words.is_empty() {
            return 0.5;
        }

        let hits = words.iter().filter(|&&w| content_lower.contains(w)).count();
        let overlap = hits as f32 / words.len() as f32;
        let bonus = if content_lower.contains(&query_lower) {
            0.2
        } else {
            0.0
        };

        (overlap * 0.8 + bonus).min(1.0)
    }

    #[test]
    fn test_parse_rerank_output() {
        let output = "1:0.9, 2:0.5, 3:0.7";
        let items = vec![
            ("first item".to_string(), 0.8),
            ("second item".to_string(), 0.6),
            ("third item".to_string(), 0.7),
        ];

        let results = parse_rerank_output_direct(output, &items);
        assert_eq!(results.len(), 3);

        let best = results
            .iter()
            .max_by(|a, b| a.relevance_score.partial_cmp(&b.relevance_score).unwrap())
            .unwrap();
        assert_eq!(best.original_index, 0);
    }

    /// Provider-free mirror of `RelevanceScorer::parse_rerank_output`
    /// (comma-separated form only, no fallback entries).
    fn parse_rerank_output_direct(output: &str, items: &[(String, f32)]) -> Vec<RelevanceResult> {
        let mut results = Vec::new();
        let mut seen = std::collections::HashSet::new();

        for part in output.split(',') {
            let part = part.trim();
            let Some(colon) = part.find(':') else { continue };
            let Ok(idx) = part[..colon].trim().parse::<usize>() else {
                continue;
            };
            let Ok(score) = part[colon + 1..].trim().parse::<f32>() else {
                continue;
            };

            // Output is 1-indexed; first mention of an item wins.
            let actual = idx.saturating_sub(1);
            if actual < items.len() && seen.insert(actual) {
                let (content, original) = &items[actual];
                results.push(RelevanceResult::from_local(
                    content.clone(),
                    actual,
                    score.clamp(0.0, 1.0),
                    *original,
                ));
            }
        }

        results
    }

    #[test]
    fn test_parse_score() {
        assert_eq!(parse_score_direct("0.85"), Some(0.85));
        assert_eq!(parse_score_direct("Score: 0.7"), Some(0.7));
        assert_eq!(parse_score_direct("1.5"), Some(1.0));
        assert_eq!(parse_score_direct("-0.5"), Some(0.0));
        assert_eq!(parse_score_direct("not a score"), None);
    }

    /// Provider-free mirror of `RelevanceScorer::parse_score`.
    fn parse_score_direct(output: &str) -> Option<f32> {
        let trimmed = output.trim();

        if let Ok(score) = trimmed.parse::<f32>() {
            return Some(score.clamp(0.0, 1.0));
        }

        let re = regex::Regex::new(r"(\d+\.?\d*)").ok()?;
        let m = re.captures(trimmed)?.get(1)?;
        m.as_str().parse::<f32>().ok().map(|s| s.clamp(0.0, 1.0))
    }
}