// brainwires_reasoning/relevance_scorer.rs
1//! Relevance Scorer - Context Re-ranking
2//!
3//! Uses a provider to score and re-rank retrieved context items
4//! based on semantic relevance to the query, replacing fixed thresholds.
5
6use std::sync::Arc;
7use tracing::warn;
8
9use brainwires_core::message::Message;
10use brainwires_core::provider::{ChatOptions, Provider};
11
12use crate::InferenceTimer;
13
/// Result of relevance scoring for a single retrieved context item.
///
/// Produced by [`RelevanceScorer::rerank`]; carries both the re-ranked
/// score and the retrieval-time similarity score so callers can compare.
#[derive(Clone, Debug)]
pub struct RelevanceResult {
    /// The scored content (owned copy of the input item's text)
    pub content: String,
    /// Original index in the input list (position before re-ranking)
    pub original_index: usize,
    /// Relevance score (0.0 - 1.0); equals `original_score` on fallback
    pub relevance_score: f32,
    /// Original similarity score (before re-ranking)
    pub original_score: f32,
    /// Whether LLM was used for scoring (false when falling back)
    pub used_local_llm: bool,
}
28
29impl RelevanceResult {
30    /// Create from LLM scoring
31    pub fn from_local(
32        content: String,
33        original_index: usize,
34        relevance_score: f32,
35        original_score: f32,
36    ) -> Self {
37        Self {
38            content,
39            original_index,
40            relevance_score,
41            original_score,
42            used_local_llm: true,
43        }
44    }
45
46    /// Create from fallback (keep original score)
47    pub fn from_fallback(content: String, original_index: usize, original_score: f32) -> Self {
48        Self {
49            content,
50            original_index,
51            relevance_score: original_score,
52            original_score,
53            used_local_llm: false,
54        }
55    }
56}
57
/// Relevance scorer for context re-ranking
///
/// Wraps a chat [`Provider`] and uses it to score retrieved context
/// items against a query; on provider failure it falls back to the
/// original retrieval scores.
pub struct RelevanceScorer {
    /// Provider used to run the scoring prompts (project-defined trait)
    provider: Arc<dyn Provider>,
    /// Model identifier; recorded with inference timers for metrics
    model_id: String,
    /// Minimum score to include in results
    min_score: f32,
    /// Maximum items to re-rank (for efficiency)
    max_items: usize,
}
67
68impl RelevanceScorer {
69    /// Create a new relevance scorer
70    pub fn new(provider: Arc<dyn Provider>, model_id: impl Into<String>) -> Self {
71        Self {
72            provider,
73            model_id: model_id.into(),
74            min_score: 0.5,
75            max_items: 10,
76        }
77    }
78
79    /// Set minimum relevance score threshold
80    pub fn with_min_score(mut self, min_score: f32) -> Self {
81        self.min_score = min_score;
82        self
83    }
84
85    /// Set maximum items to re-rank
86    pub fn with_max_items(mut self, max_items: usize) -> Self {
87        self.max_items = max_items;
88        self
89    }
90
91    /// Re-rank a list of retrieved items by semantic relevance
92    ///
93    /// Returns items sorted by relevance score (highest first).
94    pub async fn rerank<T: AsRef<str>>(
95        &self,
96        query: &str,
97        items: &[(T, f32)], // (content, original_score) pairs
98    ) -> Vec<RelevanceResult> {
99        let timer = InferenceTimer::new("rerank_context", &self.model_id);
100
101        // Limit items for efficiency
102        let items_to_score: Vec<_> = items.iter().take(self.max_items).collect();
103
104        if items_to_score.is_empty() {
105            timer.finish(true);
106            return Vec::new();
107        }
108
109        // Build scoring prompt
110        let prompt = self.build_rerank_prompt(query, &items_to_score);
111
112        let messages = vec![Message::user(&prompt)];
113        let options = ChatOptions::deterministic(100);
114
115        match self.provider.chat(&messages, None, &options).await {
116            Ok(response) => {
117                let output = response.message.text_or_summary();
118                let mut results = self.parse_rerank_output(&output, items);
119
120                // Sort by relevance score descending
121                results.sort_by(|a, b| {
122                    b.relevance_score
123                        .partial_cmp(&a.relevance_score)
124                        .unwrap_or(std::cmp::Ordering::Equal)
125                });
126
127                // Filter by minimum score
128                results.retain(|r| r.relevance_score >= self.min_score);
129
130                timer.finish(true);
131                results
132            }
133            Err(e) => {
134                warn!(target: "local_llm", "Context re-ranking failed: {}", e);
135                timer.finish(false);
136
137                // Fallback: keep original order/scores
138                items
139                    .iter()
140                    .enumerate()
141                    .filter(|(_, (_, score))| *score >= self.min_score)
142                    .map(|(i, (content, score))| {
143                        RelevanceResult::from_fallback(content.as_ref().to_string(), i, *score)
144                    })
145                    .collect()
146            }
147        }
148    }
149
150    /// Score a single item's relevance to a query
151    pub async fn score_relevance(&self, query: &str, content: &str) -> Option<f32> {
152        let timer = InferenceTimer::new("score_relevance", &self.model_id);
153
154        let prompt = format!(
155            r#"Rate the relevance of this content to the query.
156
157Query: "{}"
158
159Content: "{}"
160
161Output a score from 0.0 (irrelevant) to 1.0 (highly relevant).
162Output ONLY the decimal number.
163
164Score:"#,
165            if query.len() > 100 {
166                &query[..100]
167            } else {
168                query
169            },
170            if content.len() > 300 {
171                &content[..300]
172            } else {
173                content
174            }
175        );
176
177        let messages = vec![Message::user(&prompt)];
178        let options = ChatOptions::deterministic(10);
179
180        match self.provider.chat(&messages, None, &options).await {
181            Ok(response) => {
182                let output = response.message.text_or_summary();
183                let score = self.parse_score(&output);
184                timer.finish(score.is_some());
185                score
186            }
187            Err(e) => {
188                warn!(target: "local_llm", "Relevance scoring failed: {}", e);
189                timer.finish(false);
190                None
191            }
192        }
193    }
194
195    /// Heuristic relevance scoring (no LLM)
196    pub fn score_heuristic(&self, query: &str, content: &str) -> f32 {
197        let query_lower = query.to_lowercase();
198        let content_lower = content.to_lowercase();
199
200        // Extract query words (>2 chars)
201        let query_words: Vec<&str> = query_lower
202            .split_whitespace()
203            .filter(|w| w.len() > 2)
204            .collect();
205
206        if query_words.is_empty() {
207            return 0.5; // Default for empty query
208        }
209
210        // Count word matches
211        let mut matches = 0;
212        for word in &query_words {
213            if content_lower.contains(word) {
214                matches += 1;
215            }
216        }
217
218        // Calculate overlap ratio
219        let overlap_ratio = matches as f32 / query_words.len() as f32;
220
221        // Check for exact phrase match (bonus)
222        let phrase_bonus = if content_lower.contains(&query_lower) {
223            0.2
224        } else {
225            0.0
226        };
227
228        (overlap_ratio * 0.8 + phrase_bonus).min(1.0)
229    }
230
231    /// Build the re-ranking prompt
232    fn build_rerank_prompt<T: AsRef<str>>(&self, query: &str, items: &[&(T, f32)]) -> String {
233        let mut prompt = format!(
234            r#"Rank these items by relevance to the query.
235
236Query: "{}"
237
238Items:
239"#,
240            if query.len() > 150 {
241                &query[..150]
242            } else {
243                query
244            }
245        );
246
247        for (i, (content, _)) in items.iter().enumerate() {
248            let truncated = if content.as_ref().len() > 150 {
249                &content.as_ref()[..150]
250            } else {
251                content.as_ref()
252            };
253            prompt.push_str(&format!("{}. {}\n", i + 1, truncated));
254        }
255
256        prompt.push_str(
257            r#"
258Output format: item_number:score (0.0-1.0)
259Example: 1:0.9, 2:0.3, 3:0.7
260
261Scores:"#,
262        );
263
264        prompt
265    }
266
267    /// Parse the re-ranking output
268    fn parse_rerank_output<T: AsRef<str>>(
269        &self,
270        output: &str,
271        items: &[(T, f32)],
272    ) -> Vec<RelevanceResult> {
273        let mut results = Vec::new();
274        let mut scored_indices = std::collections::HashSet::new();
275
276        // Parse "N:score" patterns
277        for part in output.split([',', '\n', ' ']) {
278            let part = part.trim();
279            if let Some(colon_pos) = part.find(':')
280                && let (Ok(idx), score_str) = (
281                    part[..colon_pos].trim().parse::<usize>(),
282                    part[colon_pos + 1..].trim(),
283                )
284                && let Ok(score) = score_str.parse::<f32>()
285            {
286                let actual_idx = idx.saturating_sub(1); // 1-indexed to 0-indexed
287                if actual_idx < items.len() && !scored_indices.contains(&actual_idx) {
288                    scored_indices.insert(actual_idx);
289                    let (content, original_score) = &items[actual_idx];
290                    results.push(RelevanceResult::from_local(
291                        content.as_ref().to_string(),
292                        actual_idx,
293                        score.clamp(0.0, 1.0),
294                        *original_score,
295                    ));
296                }
297            }
298        }
299
300        // Add any items that weren't scored (with original scores)
301        for (i, (content, original_score)) in items.iter().enumerate() {
302            if !scored_indices.contains(&i) {
303                results.push(RelevanceResult::from_fallback(
304                    content.as_ref().to_string(),
305                    i,
306                    *original_score,
307                ));
308            }
309        }
310
311        results
312    }
313
314    /// Parse a score from LLM output
315    fn parse_score(&self, output: &str) -> Option<f32> {
316        let trimmed = output.trim();
317
318        // Try direct parse
319        if let Ok(score) = trimmed.parse::<f32>() {
320            return Some(score.clamp(0.0, 1.0));
321        }
322
323        // Look for a number pattern
324        if let Ok(re) = regex::Regex::new(r"(\d+\.?\d*)")
325            && let Some(captures) = re.captures(trimmed)
326            && let Some(m) = captures.get(1)
327            && let Ok(score) = m.as_str().parse::<f32>()
328        {
329            return Some(score.clamp(0.0, 1.0));
330        }
331
332        None
333    }
334}
335
/// Builder for RelevanceScorer
///
/// All fields have defaults except `provider`, which must be supplied
/// before `build()` will return `Some`.
pub struct RelevanceScorerBuilder {
    /// Provider to score with; `build()` yields `None` if unset
    provider: Option<Arc<dyn Provider>>,
    /// Model identifier; defaults to "lfm2-350m"
    model_id: String,
    /// Minimum relevance score threshold; defaults to 0.5
    min_score: f32,
    /// Maximum number of items to re-rank; defaults to 10
    max_items: usize,
}
343
344impl Default for RelevanceScorerBuilder {
345    fn default() -> Self {
346        Self {
347            provider: None,
348            model_id: "lfm2-350m".to_string(),
349            min_score: 0.5,
350            max_items: 10,
351        }
352    }
353}
354
355impl RelevanceScorerBuilder {
356    /// Create a new builder with default settings.
357    pub fn new() -> Self {
358        Self::default()
359    }
360
361    /// Set the provider to use for relevance scoring.
362    pub fn provider(mut self, provider: Arc<dyn Provider>) -> Self {
363        self.provider = Some(provider);
364        self
365    }
366
367    /// Set the model ID to use for inference.
368    pub fn model_id(mut self, model_id: impl Into<String>) -> Self {
369        self.model_id = model_id.into();
370        self
371    }
372
373    /// Set the minimum relevance score to include in results.
374    pub fn min_score(mut self, min_score: f32) -> Self {
375        self.min_score = min_score;
376        self
377    }
378
379    /// Set the maximum number of items to re-rank.
380    pub fn max_items(mut self, max_items: usize) -> Self {
381        self.max_items = max_items;
382        self
383    }
384
385    /// Build the relevance scorer, returning `None` if no provider was set.
386    pub fn build(self) -> Option<RelevanceScorer> {
387        self.provider.map(|p| {
388            RelevanceScorer::new(p, self.model_id)
389                .with_min_score(self.min_score)
390                .with_max_items(self.max_items)
391        })
392    }
393}
394
#[cfg(test)]
mod tests {
    use super::*;

    // NOTE(review): the `*_direct` helpers below duplicate the bodies of
    // the corresponding private RelevanceScorer methods so the logic can
    // be exercised without constructing a provider. If a method changes,
    // its mirror here must be updated by hand — presumably intentional;
    // confirm before consolidating.

    #[test]
    fn test_relevance_result() {
        // LLM-scored result keeps both scores and sets the LLM flag.
        let local = RelevanceResult::from_local("test content".to_string(), 0, 0.9, 0.75);
        assert!(local.used_local_llm);
        assert_eq!(local.relevance_score, 0.9);
        assert_eq!(local.original_score, 0.75);

        // Fallback result reuses the original score and clears the flag.
        let fallback = RelevanceResult::from_fallback("test content".to_string(), 1, 0.7);
        assert!(!fallback.used_local_llm);
        assert_eq!(fallback.relevance_score, 0.7);
    }

    #[test]
    fn test_heuristic_scoring() {
        // Heavy word overlap ("rust", "async", "programming") => high score.
        let score = score_heuristic_direct(
            "rust async programming",
            "This article discusses async programming in Rust using tokio",
        );
        assert!(score > 0.5);

        // Little overlap ("python", "web", "development") => low score.
        let low_score = score_heuristic_direct(
            "python web development",
            "This article discusses async programming in Rust using tokio",
        );
        assert!(low_score < 0.3);
    }

    // Mirror of RelevanceScorer::score_heuristic (no &self required).
    fn score_heuristic_direct(query: &str, content: &str) -> f32 {
        let query_lower = query.to_lowercase();
        let content_lower = content.to_lowercase();

        // Words longer than 2 chars only, matching the production filter.
        let query_words: Vec<&str> = query_lower
            .split_whitespace()
            .filter(|w| w.len() > 2)
            .collect();

        if query_words.is_empty() {
            return 0.5;
        }

        let mut matches = 0;
        for word in &query_words {
            if content_lower.contains(word) {
                matches += 1;
            }
        }

        let overlap_ratio = matches as f32 / query_words.len() as f32;
        let phrase_bonus = if content_lower.contains(&query_lower) {
            0.2
        } else {
            0.0
        };

        (overlap_ratio * 0.8 + phrase_bonus).min(1.0)
    }

    #[test]
    fn test_parse_rerank_output() {
        let output = "1:0.9, 2:0.5, 3:0.7";
        let items = vec![
            ("first item".to_string(), 0.8),
            ("second item".to_string(), 0.6),
            ("third item".to_string(), 0.7),
        ];

        let results = parse_rerank_output_direct(output, &items);
        assert_eq!(results.len(), 3);

        // Find the highest scored item
        let best = results
            .iter()
            .max_by(|a, b| a.relevance_score.partial_cmp(&b.relevance_score).unwrap())
            .unwrap();
        assert_eq!(best.original_index, 0); // First item had 0.9 score
    }

    // Mirror of RelevanceScorer::parse_rerank_output, minus the
    // unscored-item fallback pass (this test scores every item).
    fn parse_rerank_output_direct(output: &str, items: &[(String, f32)]) -> Vec<RelevanceResult> {
        let mut results = Vec::new();
        let mut scored_indices = std::collections::HashSet::new();

        for part in output.split(',') {
            let part = part.trim();
            if let Some(colon_pos) = part.find(':') {
                if let (Ok(idx), score_str) = (
                    part[..colon_pos].trim().parse::<usize>(),
                    part[colon_pos + 1..].trim(),
                ) {
                    if let Ok(score) = score_str.parse::<f32>() {
                        // Model output is 1-indexed; duplicates ignored.
                        let actual_idx = idx.saturating_sub(1);
                        if actual_idx < items.len() && !scored_indices.contains(&actual_idx) {
                            scored_indices.insert(actual_idx);
                            let (content, original_score) = &items[actual_idx];
                            results.push(RelevanceResult::from_local(
                                content.clone(),
                                actual_idx,
                                score.clamp(0.0, 1.0),
                                *original_score,
                            ));
                        }
                    }
                }
            }
        }

        results
    }

    #[test]
    fn test_parse_score() {
        assert_eq!(parse_score_direct("0.85"), Some(0.85));
        assert_eq!(parse_score_direct("Score: 0.7"), Some(0.7));
        assert_eq!(parse_score_direct("1.5"), Some(1.0)); // Clamped
        assert_eq!(parse_score_direct("-0.5"), Some(0.0)); // Negative clamped to 0.0
        assert_eq!(parse_score_direct("not a score"), None); // No number found
    }

    // Mirror of RelevanceScorer::parse_score (no &self required).
    fn parse_score_direct(output: &str) -> Option<f32> {
        let trimmed = output.trim();

        // Direct parse handles plain numbers, including negatives.
        if let Ok(score) = trimmed.parse::<f32>() {
            return Some(score.clamp(0.0, 1.0));
        }

        // Otherwise pull the first decimal number out of the text.
        if let Ok(re) = regex::Regex::new(r"(\d+\.?\d*)") {
            if let Some(captures) = re.captures(trimmed) {
                if let Some(m) = captures.get(1) {
                    if let Ok(score) = m.as_str().parse::<f32>() {
                        return Some(score.clamp(0.0, 1.0));
                    }
                }
            }
        }

        None
    }
}