use std::collections::{HashMap, HashSet};
use crate::types::{ChatMessage, MessageRole};
/// BM25 term-frequency saturation parameter (`k1`): larger values let repeated
/// occurrences of a query term keep increasing a message's score for longer.
const BM25_K1: f64 = 1.2;
/// BM25 length-normalization strength (`b`): 0 ignores message length, 1 fully
/// scales term frequency by the message's length relative to the average length.
const BM25_B: f64 = 0.75;
/// Ranks the non-system messages in `messages` by BM25 relevance to `query`.
///
/// Each non-system message is one document; system messages are skipped and
/// never appear in the output. Query and message text go through the same
/// tokenizer, so matching is case-insensitive.
///
/// Returns `(index, score)` pairs where `index` is the position in the original
/// `messages` slice. Pairs are sorted by descending score, and ties keep their
/// original message order (the sort is stable). Scores are never negative or
/// NaN. If `query` yields no tokens, every non-system message is returned in
/// original order with a score of `0.0`.
pub fn score_messages(query: &str, messages: &[ChatMessage]) -> Vec<(usize, f64)> {
    let query_terms = tokenize(query);
    if query_terms.is_empty() {
        return messages
            .iter()
            .enumerate()
            .filter(|(_, m)| m.role != MessageRole::System)
            .map(|(i, _)| (i, 0.0))
            .collect();
    }

    // (original index, tokens) for every non-system message.
    let docs: Vec<(usize, Vec<String>)> = messages
        .iter()
        .enumerate()
        .filter(|(_, m)| m.role != MessageRole::System)
        .map(|(i, m)| (i, tokenize(&m.content)))
        .collect();
    if docs.is_empty() {
        return Vec::new();
    }

    // Per-document term counts, built once and shared by the document-frequency
    // pass and the scoring pass.
    let tf_maps: Vec<HashMap<&str, usize>> =
        docs.iter().map(|(_, terms)| term_freq(terms)).collect();

    // Counts are converted to f64 for the BM25 arithmetic; chat histories are far
    // smaller than 2^52, so no precision is actually lost.
    #[allow(clippy::cast_precision_loss)]
    let n = docs.len() as f64;
    #[allow(clippy::cast_precision_loss)]
    let avg_dl: f64 = docs.iter().map(|(_, t)| t.len() as f64).sum::<f64>() / n;

    // Document frequency counts each document at most once per *distinct* term.
    // Iterating a deduplicated set stops a repeated query word from inflating df
    // beyond `n`, which would drive the IDF negative (or NaN).
    let unique_query_terms: HashSet<&str> = query_terms.iter().map(String::as_str).collect();
    let df: HashMap<&str, usize> = unique_query_terms
        .iter()
        .map(|&qt| (qt, tf_maps.iter().filter(|tf| tf.contains_key(qt)).count()))
        .collect();

    let mut scores: Vec<(usize, f64)> = docs
        .iter()
        .zip(&tf_maps)
        .map(|((idx, terms), tf_map)| {
            #[allow(clippy::cast_precision_loss)]
            let dl = terms.len() as f64;
            // If every document is empty, avg_dl is 0 and dl / avg_dl would be
            // 0/0 = NaN. Treat each document as average length instead; all term
            // frequencies are 0 in that case, so the score is 0 either way.
            let length_norm = if avg_dl > 0.0 {
                1.0 - BM25_B + BM25_B * dl / avg_dl
            } else {
                1.0
            };
            let score: f64 = query_terms
                .iter()
                .map(|qt| {
                    #[allow(clippy::cast_precision_loss)]
                    let doc_freq = df.get(qt.as_str()).copied().unwrap_or(0) as f64;
                    // BM25 IDF with the +1 inside the log (ln_1p), which stays
                    // positive as long as doc_freq <= n.
                    let idf = ((n - doc_freq + 0.5) / (doc_freq + 0.5)).ln_1p();
                    #[allow(clippy::cast_precision_loss)]
                    let tf = tf_map.get(qt.as_str()).copied().unwrap_or(0) as f64;
                    idf * (tf * (BM25_K1 + 1.0)) / BM25_K1.mul_add(length_norm, tf)
                })
                .sum();
            (*idx, score)
        })
        .collect();

    // Scores are finite here, so total_cmp is a plain descending order;
    // sort_by is stable, so equal scores keep message order.
    scores.sort_by(|a, b| b.1.total_cmp(&a.1));
    scores
}
/// Counts the occurrences of each distinct token in `terms`.
///
/// The returned map borrows its keys from `terms`; tokens that never occur are
/// absent rather than mapped to zero.
fn term_freq(terms: &[String]) -> HashMap<&str, usize> {
    terms.iter().fold(HashMap::new(), |mut counts, term| {
        *counts.entry(term.as_str()).or_insert(0_usize) += 1;
        counts
    })
}
/// Splits `text` into lowercase search tokens.
///
/// Tokens are separated by whitespace or ASCII punctuation. Fragments shorter
/// than two characters (e.g. "a", "I") are discarded.
fn tokenize(text: &str) -> Vec<String> {
    text.split(|c: char| c.is_whitespace() || c.is_ascii_punctuation())
        // Count characters, not bytes: `str::len` is the UTF-8 byte length, so a
        // lone multi-byte character such as "é" or "👍" would otherwise pass.
        .filter(|word| word.chars().nth(1).is_some())
        .map(str::to_lowercase)
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `ChatMessage` with the given role and content.
    ///
    /// `MessageRole::Tool` messages are constructed via `ChatMessage::user`.
    fn msg(role: MessageRole, content: &str) -> ChatMessage {
        match role {
            MessageRole::User | MessageRole::Tool => ChatMessage::user(content),
            MessageRole::Assistant => ChatMessage::assistant(content),
            MessageRole::System => ChatMessage::system(content),
        }
    }

    #[test]
    fn relevant_message_scores_highest() {
        let messages = vec![
            msg(MessageRole::User, "What is the weather in Berlin?"),
            msg(
                MessageRole::Assistant,
                "The weather in Berlin is sunny, 22°C.",
            ),
            msg(MessageRole::User, "Tell me about Rust programming."),
            msg(
                MessageRole::Assistant,
                "Rust is a systems programming language focused on safety.",
            ),
        ];
        let scores = score_messages("Rust programming language", &messages);
        assert!(!scores.is_empty());
        let top = scores[0];
        let top_msg = &messages[top.0];
        assert!(
            top_msg.content.contains("Rust"),
            "Expected Rust-related message to score highest, got: {}",
            top_msg.content
        );
        assert!(top.1 > 0.0, "Top match should have a positive score");
    }

    #[test]
    fn system_messages_excluded() {
        let messages = vec![
            msg(MessageRole::System, "You are a helpful assistant."),
            msg(MessageRole::User, "Hello!"),
            msg(MessageRole::Assistant, "Hi there!"),
        ];
        let scores = score_messages("Hello", &messages);
        assert_eq!(scores.len(), 2, "System messages should be excluded");
        assert!(
            scores
                .iter()
                .all(|(idx, _)| messages[*idx].role != MessageRole::System)
        );
    }

    #[test]
    fn empty_query_returns_zero_scores() {
        let messages = vec![
            msg(MessageRole::User, "Hello"),
            msg(MessageRole::Assistant, "Hi"),
        ];
        let scores = score_messages("", &messages);
        assert_eq!(scores.len(), 2);
        assert!(scores.iter().all(|(_, s)| *s == 0.0));
        // With no query terms, messages come back in their original order.
        let indices: Vec<usize> = scores.iter().map(|(i, _)| *i).collect();
        assert_eq!(indices, vec![0, 1]);
    }

    #[test]
    fn empty_messages_returns_empty() {
        let scores = score_messages("query", &[]);
        assert!(scores.is_empty());
    }

    #[test]
    fn multiple_term_overlap_scores_higher() {
        let messages = vec![
            msg(MessageRole::User, "I like apples"),
            msg(
                MessageRole::User,
                "I like apples and oranges from the market",
            ),
            msg(MessageRole::User, "The cat sat on the mat"),
        ];
        let scores = score_messages("apples oranges market", &messages);
        assert_eq!(scores.len(), 3, "Every non-system message should be scored");
        let top = scores[0];
        assert_eq!(top.0, 1, "Message with most query terms should rank first");
    }

    #[test]
    fn preserves_original_indices() {
        let messages = vec![
            msg(MessageRole::System, "System prompt"),
            msg(MessageRole::User, "First user message"),
            msg(MessageRole::Assistant, "First response"),
            msg(MessageRole::User, "Second user message"),
        ];
        let scores = score_messages("first", &messages);
        for (idx, _) in &scores {
            assert!(*idx < messages.len());
            assert_ne!(messages[*idx].role, MessageRole::System);
        }
        // Every non-system message is returned exactly once, under its original index.
        let mut indices: Vec<usize> = scores.iter().map(|(i, _)| *i).collect();
        indices.sort_unstable();
        assert_eq!(indices, vec![1, 2, 3]);
        // The only message without "first" scores zero and ranks last.
        assert_eq!(scores.last().map(|(i, _)| *i), Some(3));
        assert!(scores
            .iter()
            .filter(|(i, _)| *i != 3)
            .all(|(_, s)| *s > 0.0));
    }

    #[test]
    fn scores_are_non_negative() {
        let messages = vec![
            msg(MessageRole::User, "Hello world"),
            msg(MessageRole::Assistant, "Goodbye moon"),
        ];
        let scores = score_messages("Hello world", &messages);
        assert_eq!(scores.len(), 2);
        assert!(scores.iter().all(|(_, s)| *s >= 0.0));
    }
}