use std::sync::Arc;
use tracing::warn;
use brainwires_core::message::Message;
use brainwires_core::provider::{ChatOptions, Provider};
use super::InferenceTimer;
/// One context item after relevance scoring, pairing the score assigned
/// during re-ranking with the score the item arrived with.
#[derive(Clone, Debug)]
pub struct RelevanceResult {
/// Text of the scored item.
pub content: String,
/// Position of the item in the caller's input slice.
pub original_index: usize,
/// Relevance score in [0.0, 1.0]; LLM-assigned, or a copy of
/// `original_score` when the fallback path was used.
pub relevance_score: f32,
/// Score the item carried before re-ranking (e.g. a retrieval score).
pub original_score: f32,
/// True when `relevance_score` came from the local LLM rather than
/// the fallback path.
pub used_local_llm: bool,
}
impl RelevanceResult {
    /// Builds a result whose relevance score was produced by the local LLM.
    pub fn from_local(
        content: String,
        original_index: usize,
        relevance_score: f32,
        original_score: f32,
    ) -> Self {
        Self {
            used_local_llm: true,
            relevance_score,
            original_score,
            original_index,
            content,
        }
    }

    /// Builds a fallback result: the pre-existing score is reused as the
    /// relevance score and the LLM flag is cleared.
    pub fn from_fallback(content: String, original_index: usize, original_score: f32) -> Self {
        Self {
            used_local_llm: false,
            relevance_score: original_score,
            original_score,
            original_index,
            content,
        }
    }
}
/// Scores and re-ranks context items against a query using a local LLM
/// provider, with a lexical heuristic available as a cheap alternative.
pub struct RelevanceScorer {
/// Chat provider used for scoring/reranking calls.
provider: Arc<dyn Provider>,
/// Model identifier passed to inference timing/logging.
model_id: String,
/// Minimum relevance score an item must reach to be kept (default 0.5).
min_score: f32,
/// Maximum number of items sent to the model per rerank (default 10).
max_items: usize,
}
impl RelevanceScorer {
    /// Creates a scorer with default settings: minimum score 0.5 and at
    /// most 10 items considered per rerank call.
    pub fn new(provider: Arc<dyn Provider>, model_id: impl Into<String>) -> Self {
        Self {
            provider,
            model_id: model_id.into(),
            min_score: 0.5,
            max_items: 10,
        }
    }

    /// Sets the minimum relevance score an item must reach to be kept.
    pub fn with_min_score(mut self, min_score: f32) -> Self {
        self.min_score = min_score;
        self
    }

    /// Sets the maximum number of items sent to the model for reranking.
    pub fn with_max_items(mut self, max_items: usize) -> Self {
        self.max_items = max_items;
        self
    }

    /// Truncates `s` to at most `max_bytes` bytes, stepping back to the
    /// nearest UTF-8 character boundary.
    ///
    /// The previous direct `&s[..n]` slicing panicked whenever byte `n`
    /// fell inside a multi-byte character, which arbitrary query/content
    /// text can easily trigger.
    fn truncate_utf8(s: &str, max_bytes: usize) -> &str {
        if s.len() <= max_bytes {
            return s;
        }
        let mut end = max_bytes;
        // is_char_boundary(0) is always true, so this loop terminates.
        while !s.is_char_boundary(end) {
            end -= 1;
        }
        &s[..end]
    }

    /// Re-ranks `items` (content, original_score) by relevance to `query`
    /// using the local LLM.
    ///
    /// Only the first `max_items` entries are included in the prompt. On
    /// success, results are sorted by descending relevance and filtered by
    /// `min_score`; items the model did not score fall back to their
    /// original score (see `parse_rerank_output`). If the provider call
    /// fails entirely, all items are returned with their original scores
    /// (filtered by `min_score`, input order preserved).
    pub async fn rerank<T: AsRef<str>>(
        &self,
        query: &str,
        items: &[(T, f32)],
    ) -> Vec<RelevanceResult> {
        let timer = InferenceTimer::new("rerank_context", &self.model_id);
        let items_to_score: Vec<_> = items.iter().take(self.max_items).collect();
        if items_to_score.is_empty() {
            timer.finish(true);
            return Vec::new();
        }
        let prompt = self.build_rerank_prompt(query, &items_to_score);
        let messages = vec![Message::user(&prompt)];
        let options = ChatOptions::deterministic(100);
        match self.provider.chat(&messages, None, &options).await {
            Ok(response) => {
                let output = response.message.text_or_summary();
                let mut results = self.parse_rerank_output(&output, items);
                // Highest relevance first; NaN-safe comparison treats
                // incomparable pairs as equal.
                results.sort_by(|a, b| {
                    b.relevance_score
                        .partial_cmp(&a.relevance_score)
                        .unwrap_or(std::cmp::Ordering::Equal)
                });
                results.retain(|r| r.relevance_score >= self.min_score);
                timer.finish(true);
                results
            }
            Err(e) => {
                warn!(target: "local_llm", "Context re-ranking failed: {}", e);
                timer.finish(false);
                // Best-effort fallback: keep original ordering and scores.
                items
                    .iter()
                    .enumerate()
                    .filter(|(_, (_, score))| *score >= self.min_score)
                    .map(|(i, (content, score))| {
                        RelevanceResult::from_fallback(content.as_ref().to_string(), i, *score)
                    })
                    .collect()
            }
        }
    }

    /// Scores how relevant `content` is to `query`.
    ///
    /// Returns a value in [0.0, 1.0], or `None` when the provider call or
    /// output parsing fails. The query is capped at 100 bytes and the
    /// content at 300 bytes (at UTF-8 boundaries) before prompting.
    pub async fn score_relevance(&self, query: &str, content: &str) -> Option<f32> {
        let timer = InferenceTimer::new("score_relevance", &self.model_id);
        let prompt = format!(
            r#"Rate the relevance of this content to the query.
Query: "{}"
Content: "{}"
Output a score from 0.0 (irrelevant) to 1.0 (highly relevant).
Output ONLY the decimal number.
Score:"#,
            Self::truncate_utf8(query, 100),
            Self::truncate_utf8(content, 300)
        );
        let messages = vec![Message::user(&prompt)];
        let options = ChatOptions::deterministic(10);
        match self.provider.chat(&messages, None, &options).await {
            Ok(response) => {
                let output = response.message.text_or_summary();
                let score = self.parse_score(&output);
                timer.finish(score.is_some());
                score
            }
            Err(e) => {
                warn!(target: "local_llm", "Relevance scoring failed: {}", e);
                timer.finish(false);
                None
            }
        }
    }

    /// Cheap lexical relevance estimate that needs no LLM call.
    ///
    /// Computes the fraction of query words (length > 2) present in the
    /// content, scaled by 0.8, plus a 0.2 bonus for an exact phrase
    /// match, capped at 1.0. Returns a neutral 0.5 when the query has no
    /// usable words.
    pub fn score_heuristic(&self, query: &str, content: &str) -> f32 {
        let query_lower = query.to_lowercase();
        let content_lower = content.to_lowercase();
        let query_words: Vec<&str> = query_lower
            .split_whitespace()
            .filter(|w| w.len() > 2)
            .collect();
        if query_words.is_empty() {
            // No signal either way; report a neutral score.
            return 0.5;
        }
        let matches = query_words
            .iter()
            .filter(|w| content_lower.contains(*w))
            .count();
        let overlap_ratio = matches as f32 / query_words.len() as f32;
        let phrase_bonus = if content_lower.contains(&query_lower) {
            0.2
        } else {
            0.0
        };
        (overlap_ratio * 0.8 + phrase_bonus).min(1.0)
    }

    /// Builds the numbered rerank prompt. The query is capped at 150
    /// bytes and each item at 150 bytes (at UTF-8 boundaries).
    fn build_rerank_prompt<T: AsRef<str>>(&self, query: &str, items: &[&(T, f32)]) -> String {
        let mut prompt = format!(
            r#"Rank these items by relevance to the query.
Query: "{}"
Items:
"#,
            Self::truncate_utf8(query, 150)
        );
        for (i, (content, _)) in items.iter().enumerate() {
            let truncated = Self::truncate_utf8(content.as_ref(), 150);
            prompt.push_str(&format!("{}. {}\n", i + 1, truncated));
        }
        prompt.push_str(
            r#"
Output format: item_number:score (0.0-1.0)
Example: 1:0.9, 2:0.3, 3:0.7
Scores:"#,
        );
        prompt
    }

    /// Parses "N:score" pairs out of the model output.
    ///
    /// Item numbers are 1-based; out-of-range and duplicate indices are
    /// ignored; scores are clamped to [0.0, 1.0]. Items the model did
    /// not score are appended as fallback results carrying their
    /// original score, so every input item appears exactly once.
    fn parse_rerank_output<T: AsRef<str>>(
        &self,
        output: &str,
        items: &[(T, f32)],
    ) -> Vec<RelevanceResult> {
        let mut results = Vec::new();
        let mut scored_indices = std::collections::HashSet::new();
        for part in output.split([',', '\n', ' ']) {
            let part = part.trim();
            if let Some(colon_pos) = part.find(':')
                && let (Ok(idx), score_str) = (
                    part[..colon_pos].trim().parse::<usize>(),
                    part[colon_pos + 1..].trim(),
                )
                && let Ok(score) = score_str.parse::<f32>()
            {
                // Model numbering is 1-based; a stray "0:" also maps to
                // index 0 via saturating_sub.
                let actual_idx = idx.saturating_sub(1);
                if actual_idx < items.len() && !scored_indices.contains(&actual_idx) {
                    scored_indices.insert(actual_idx);
                    let (content, original_score) = &items[actual_idx];
                    results.push(RelevanceResult::from_local(
                        content.as_ref().to_string(),
                        actual_idx,
                        score.clamp(0.0, 1.0),
                        *original_score,
                    ));
                }
            }
        }
        for (i, (content, original_score)) in items.iter().enumerate() {
            if !scored_indices.contains(&i) {
                results.push(RelevanceResult::from_fallback(
                    content.as_ref().to_string(),
                    i,
                    *original_score,
                ));
            }
        }
        results
    }

    /// Parses a relevance score from model output: tries the whole
    /// trimmed string first, then the first decimal number appearing
    /// anywhere in it. Results are clamped to [0.0, 1.0].
    fn parse_score(&self, output: &str) -> Option<f32> {
        let trimmed = output.trim();
        if let Ok(score) = trimmed.parse::<f32>() {
            return Some(score.clamp(0.0, 1.0));
        }
        if let Ok(re) = regex::Regex::new(r"(\d+\.?\d*)")
            && let Some(captures) = re.captures(trimmed)
            && let Some(m) = captures.get(1)
            && let Ok(score) = m.as_str().parse::<f32>()
        {
            return Some(score.clamp(0.0, 1.0));
        }
        None
    }
}
/// Builder for `RelevanceScorer`; `build` yields `None` unless a
/// provider has been supplied.
pub struct RelevanceScorerBuilder {
/// Required LLM provider; `build` fails without it.
provider: Option<Arc<dyn Provider>>,
/// Model identifier (defaults to "lfm2-350m").
model_id: String,
/// Minimum relevance score threshold (defaults to 0.5).
min_score: f32,
/// Maximum items per rerank call (defaults to 10).
max_items: usize,
}
impl Default for RelevanceScorerBuilder {
    /// Starting configuration: no provider, the small local model, a 0.5
    /// score floor, and at most 10 items per rerank call.
    fn default() -> Self {
        Self {
            provider: None,
            model_id: String::from("lfm2-350m"),
            min_score: 0.5,
            max_items: 10,
        }
    }
}
impl RelevanceScorerBuilder {
    /// Starts a builder with the default configuration.
    pub fn new() -> Self {
        Self::default()
    }

    /// Supplies the LLM provider (required for `build` to succeed).
    pub fn provider(mut self, provider: Arc<dyn Provider>) -> Self {
        self.provider = Some(provider);
        self
    }

    /// Overrides the model identifier.
    pub fn model_id(mut self, model_id: impl Into<String>) -> Self {
        self.model_id = model_id.into();
        self
    }

    /// Overrides the minimum relevance score threshold.
    pub fn min_score(mut self, min_score: f32) -> Self {
        self.min_score = min_score;
        self
    }

    /// Overrides the maximum number of items per rerank call.
    pub fn max_items(mut self, max_items: usize) -> Self {
        self.max_items = max_items;
        self
    }

    /// Finalizes the builder; yields `None` when no provider was set.
    pub fn build(self) -> Option<RelevanceScorer> {
        let provider = self.provider?;
        Some(
            RelevanceScorer::new(provider, self.model_id)
                .with_min_score(self.min_score)
                .with_max_items(self.max_items),
        )
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_relevance_result() {
        let local = RelevanceResult::from_local("test content".to_string(), 0, 0.9, 0.75);
        assert!(local.used_local_llm);
        assert_eq!(local.relevance_score, 0.9);
        assert_eq!(local.original_score, 0.75);
        let fallback = RelevanceResult::from_fallback("test content".to_string(), 1, 0.7);
        assert!(!fallback.used_local_llm);
        assert_eq!(fallback.relevance_score, 0.7);
    }

    #[test]
    fn test_heuristic_scoring() {
        let score = score_heuristic_direct(
            "rust async programming",
            "This article discusses async programming in Rust using tokio",
        );
        assert!(score > 0.5);
        let low_score = score_heuristic_direct(
            "python web development",
            "This article discusses async programming in Rust using tokio",
        );
        assert!(low_score < 0.3);
    }

    /// Standalone copy of the word-overlap heuristic so it can be
    /// exercised without constructing a provider-backed scorer.
    fn score_heuristic_direct(query: &str, content: &str) -> f32 {
        let query_lower = query.to_lowercase();
        let content_lower = content.to_lowercase();
        let words: Vec<&str> = query_lower
            .split_whitespace()
            .filter(|w| w.len() > 2)
            .collect();
        if words.is_empty() {
            return 0.5;
        }
        let hits = words.iter().filter(|w| content_lower.contains(*w)).count();
        let overlap = hits as f32 / words.len() as f32;
        let bonus = if content_lower.contains(&query_lower) {
            0.2
        } else {
            0.0
        };
        (overlap * 0.8 + bonus).min(1.0)
    }

    #[test]
    fn test_parse_rerank_output() {
        let output = "1:0.9, 2:0.5, 3:0.7";
        let items = vec![
            ("first item".to_string(), 0.8),
            ("second item".to_string(), 0.6),
            ("third item".to_string(), 0.7),
        ];
        let results = parse_rerank_output_direct(output, &items);
        assert_eq!(results.len(), 3);
        let best = results
            .iter()
            .max_by(|a, b| a.relevance_score.partial_cmp(&b.relevance_score).unwrap())
            .unwrap();
        assert_eq!(best.original_index, 0);
    }

    /// Standalone "N:score" parser mirroring the scored-items half of
    /// `RelevanceScorer::parse_rerank_output` (no fallback appending).
    fn parse_rerank_output_direct(output: &str, items: &[(String, f32)]) -> Vec<RelevanceResult> {
        let mut seen = std::collections::HashSet::new();
        let mut parsed = Vec::new();
        for token in output.split(',').map(str::trim) {
            let Some((num, score_text)) = token.split_once(':') else {
                continue;
            };
            let (Ok(position), Ok(raw_score)) =
                (num.trim().parse::<usize>(), score_text.trim().parse::<f32>())
            else {
                continue;
            };
            let idx = position.saturating_sub(1);
            if idx >= items.len() || !seen.insert(idx) {
                continue;
            }
            let (content, original) = &items[idx];
            parsed.push(RelevanceResult::from_local(
                content.clone(),
                idx,
                raw_score.clamp(0.0, 1.0),
                *original,
            ));
        }
        parsed
    }

    #[test]
    fn test_parse_score() {
        assert_eq!(parse_score_direct("0.85"), Some(0.85));
        assert_eq!(parse_score_direct("Score: 0.7"), Some(0.7));
        assert_eq!(parse_score_direct("1.5"), Some(1.0));
        assert_eq!(parse_score_direct("-0.5"), Some(0.0));
        assert_eq!(parse_score_direct("not a score"), None);
    }

    /// Standalone score parser: whole-string parse first, then the first
    /// decimal number found via regex, clamped to [0.0, 1.0].
    fn parse_score_direct(output: &str) -> Option<f32> {
        let text = output.trim();
        if let Ok(direct) = text.parse::<f32>() {
            return Some(direct.clamp(0.0, 1.0));
        }
        let re = regex::Regex::new(r"(\d+\.?\d*)").ok()?;
        re.captures(text)?
            .get(1)?
            .as_str()
            .parse::<f32>()
            .ok()
            .map(|s| s.clamp(0.0, 1.0))
    }
}