use std::collections::HashMap;
use chrono::{DateTime, Utc};
use serde::Serialize;
use super::graph::GraphMetrics;
use crate::model::Issue;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SearchMode {
Text,
Hybrid,
}
impl SearchMode {
pub fn from_str_or_default(s: &str) -> Self {
match s.to_ascii_lowercase().as_str() {
"hybrid" => Self::Hybrid,
_ => Self::Text,
}
}
pub const fn as_str(self) -> &'static str {
match self {
Self::Text => "text",
Self::Hybrid => "hybrid",
}
}
}
#[derive(Debug, Clone, Serialize)]
pub struct SearchWeights {
pub text: f64,
pub pagerank: f64,
pub status: f64,
pub impact: f64,
pub priority: f64,
pub recency: f64,
}
impl SearchWeights {
#[must_use]
pub fn normalize(&self) -> Self {
let sum =
self.text + self.pagerank + self.status + self.impact + self.priority + self.recency;
if sum <= 0.0 {
return Self::default_preset();
}
Self {
text: self.text / sum,
pagerank: self.pagerank / sum,
status: self.status / sum,
impact: self.impact / sum,
priority: self.priority / sum,
recency: self.recency / sum,
}
}
}
pub fn get_preset(name: &str) -> SearchWeights {
match name.to_ascii_lowercase().as_str() {
"bug-hunting" => SearchWeights {
text: 0.30,
pagerank: 0.15,
status: 0.15,
impact: 0.15,
priority: 0.20,
recency: 0.05,
},
"sprint-planning" => SearchWeights {
text: 0.30,
pagerank: 0.20,
status: 0.25,
impact: 0.15,
priority: 0.05,
recency: 0.05,
},
"impact-first" => SearchWeights {
text: 0.25,
pagerank: 0.30,
status: 0.10,
impact: 0.20,
priority: 0.10,
recency: 0.05,
},
"text-only" => SearchWeights {
text: 1.0,
pagerank: 0.0,
status: 0.0,
impact: 0.0,
priority: 0.0,
recency: 0.0,
},
_ => SearchWeights::default_preset(),
}
}
impl SearchWeights {
pub fn default_preset() -> Self {
Self {
text: 0.40,
pagerank: 0.20,
status: 0.15,
impact: 0.10,
priority: 0.10,
recency: 0.05,
}
}
pub fn from_json(json_str: &str) -> Result<Self, String> {
let map: HashMap<String, f64> =
serde_json::from_str(json_str).map_err(|e| format!("invalid weights JSON: {e}"))?;
let weights = Self {
text: map.get("text").copied().unwrap_or(0.0),
pagerank: map.get("pagerank").copied().unwrap_or(0.0),
status: map.get("status").copied().unwrap_or(0.0),
impact: map.get("impact").copied().unwrap_or(0.0),
priority: map.get("priority").copied().unwrap_or(0.0),
recency: map.get("recency").copied().unwrap_or(0.0),
};
if weights.text < 0.0
|| weights.pagerank < 0.0
|| weights.status < 0.0
|| weights.impact < 0.0
|| weights.priority < 0.0
|| weights.recency < 0.0
{
return Err("all weights must be non-negative".to_string());
}
Ok(weights.normalize())
}
}
fn compute_text_score(query: &str, issue: &Issue) -> f64 {
let query_lower = query.to_ascii_lowercase();
let tokens: Vec<&str> = query_lower.split_whitespace().collect();
if tokens.is_empty() {
return 0.0;
}
let doc = format!(
"{id} {id} {id} {title} {title} {labels} {desc}",
id = issue.id.to_ascii_lowercase(),
title = issue.title.to_ascii_lowercase(),
labels = issue.labels.join(" ").to_ascii_lowercase(),
desc = issue.description.to_ascii_lowercase(),
);
if issue.id.to_ascii_lowercase() == query_lower {
return 1.0;
}
let mut hit_count = 0usize;
for token in &tokens {
if doc.contains(token) {
hit_count += 1;
}
}
if hit_count == 0 {
return 0.0;
}
let token_coverage = hit_count as f64 / tokens.len() as f64;
let title_lower = issue.title.to_ascii_lowercase();
const TITLE_MATCH_BONUS: f64 = 0.3;
let title_bonus = if title_lower.contains(&query_lower) {
TITLE_MATCH_BONUS
} else {
0.0
};
let id_lower = issue.id.to_ascii_lowercase();
const ID_MATCH_BONUS: f64 = 0.2;
let id_bonus = if id_lower.contains(&query_lower) {
ID_MATCH_BONUS
} else {
0.0
};
const TOKEN_COVERAGE_WEIGHT: f64 = 0.5;
(token_coverage * TOKEN_COVERAGE_WEIGHT + title_bonus + id_bonus).min(1.0)
}
fn is_short_query(query: &str) -> bool {
let tokens = query.split_whitespace().count();
tokens <= 2 || query.len() <= 12
}
fn adjust_weights_for_short_query(weights: &SearchWeights) -> SearchWeights {
if weights.text >= 0.55 {
return weights.clone();
}
let target = 0.55;
let remaining = 1.0 - weights.text;
if remaining <= 0.0 {
return weights.clone();
}
let scale = (1.0 - target) / remaining;
SearchWeights {
text: target,
pagerank: weights.pagerank * scale,
status: weights.status * scale,
impact: weights.impact * scale,
priority: weights.priority * scale,
recency: weights.recency * scale,
}
.normalize()
}
fn normalize_status(status: &str) -> f64 {
match status.to_ascii_lowercase().as_str() {
"open" => 1.0,
"in_progress" => 0.8,
"closed" => 0.1,
"tombstone" => 0.0,
_ => 0.5,
}
}
fn normalize_priority(priority: i32) -> f64 {
match priority.clamp(0, 4) {
0 => 1.0,
1 => 0.8,
2 => 0.6,
3 => 0.4,
_ => 0.2,
}
}
fn normalize_impact(blocks_count: usize, max_blocks: usize) -> f64 {
if max_blocks == 0 {
return 0.5;
}
blocks_count as f64 / max_blocks as f64
}
fn normalize_recency(updated_at: Option<DateTime<Utc>>) -> f64 {
let Some(ts) = updated_at else {
return 0.0;
};
let now = Utc::now();
if ts > now {
return 0.5;
}
let days = (now - ts).num_days();
(-(days as f64) / 30.0_f64).exp()
}
#[derive(Debug, Clone, Serialize)]
pub struct SearchResult {
pub issue_id: String,
pub score: f64,
pub title: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub text_score: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub component_scores: Option<ComponentScores>,
}
#[derive(Debug, Clone, Serialize)]
pub struct ComponentScores {
pub pagerank: f64,
pub status: f64,
pub impact: f64,
pub priority: f64,
pub recency: f64,
}
#[derive(Debug, Serialize)]
pub struct RobotSearchOutput {
#[serde(flatten)]
pub envelope: crate::robot::RobotEnvelope,
pub query: String,
pub limit: usize,
pub mode: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub preset: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub weights: Option<SearchWeights>,
pub results: Vec<SearchResult>,
pub usage_hints: Vec<String>,
}
pub fn execute_search(
query: &str,
issues: &[Issue],
metrics: &GraphMetrics,
mode: SearchMode,
weights: &SearchWeights,
limit: usize,
) -> Vec<SearchResult> {
let query = query.trim();
if query.is_empty() {
return Vec::new();
}
let max_blocks = metrics.blocks_count.values().copied().max().unwrap_or(0);
let effective_weights = if mode == SearchMode::Hybrid && is_short_query(query) {
adjust_weights_for_short_query(weights)
} else {
weights.clone()
};
let mut results: Vec<SearchResult> = issues
.iter()
.filter_map(|issue| {
let text_score = compute_text_score(query, issue);
let lexical_boost = if is_short_query(query) {
let doc = format!(
"{} {} {} {}",
issue.id,
issue.title,
issue.labels.join(" "),
issue.description,
);
if doc
.to_ascii_lowercase()
.contains(&query.to_ascii_lowercase())
{
0.35
} else {
0.0
}
} else {
0.0
};
let boosted_text = (text_score + lexical_boost).min(1.0);
if boosted_text <= 0.0 {
return None;
}
let (score, text_score_field, components) = match mode {
SearchMode::Text => (boosted_text, None, None),
SearchMode::Hybrid => {
let pagerank = metrics.pagerank.get(&issue.id).copied().unwrap_or(0.0);
let status = normalize_status(&issue.status);
let blocks = metrics.blocks_count.get(&issue.id).copied().unwrap_or(0);
let impact = normalize_impact(blocks, max_blocks);
let priority = normalize_priority(issue.priority);
let recency = normalize_recency(issue.updated_at);
let hybrid_score = effective_weights.text * boosted_text
+ effective_weights.pagerank * pagerank
+ effective_weights.status * status
+ effective_weights.impact * impact
+ effective_weights.priority * priority
+ effective_weights.recency * recency;
(
hybrid_score,
Some(boosted_text),
Some(ComponentScores {
pagerank,
status,
impact,
priority,
recency,
}),
)
}
};
if score <= 0.0 {
return None;
}
Some(SearchResult {
issue_id: issue.id.clone(),
score,
title: issue.title.clone(),
text_score: text_score_field,
component_scores: components,
})
})
.collect();
results.sort_by(|a, b| {
b.score
.total_cmp(&a.score)
.then_with(|| a.issue_id.cmp(&b.issue_id))
});
if let Some(pos) = results
.iter()
.position(|r| r.issue_id.eq_ignore_ascii_case(query))
{
if pos > 0 {
let exact = results.remove(pos);
results.insert(0, exact);
}
}
if limit > 0 {
results.truncate(limit);
}
results
}
#[cfg(test)]
mod tests {
use super::*;
use crate::analysis::graph::IssueGraph;
use crate::model::Issue;
fn make_issue(id: &str, title: &str, status: &str, priority: i32) -> Issue {
Issue {
id: id.to_string(),
title: title.to_string(),
status: status.to_string(),
priority,
..Issue::default()
}
}
fn make_issues_and_metrics() -> (Vec<Issue>, GraphMetrics) {
let issues = vec![
make_issue("AUTH-1", "Fix authentication bug", "open", 0),
make_issue("NET-2", "Network timeout handling", "in_progress", 1),
make_issue("DB-3", "Database migration script", "open", 2),
make_issue("AUTH-4", "OAuth token refresh", "blocked", 1),
make_issue("UI-5", "Dashboard layout fix", "closed", 3),
];
let graph = IssueGraph::build(&issues);
let metrics = graph.compute_metrics();
(issues, metrics)
}
#[test]
fn text_search_basic() {
let (issues, metrics) = make_issues_and_metrics();
let weights = SearchWeights::default_preset();
let results = execute_search(
"authentication",
&issues,
&metrics,
SearchMode::Text,
&weights,
10,
);
assert!(!results.is_empty());
assert_eq!(results[0].issue_id, "AUTH-1");
}
#[test]
fn text_search_no_results() {
let (issues, metrics) = make_issues_and_metrics();
let weights = SearchWeights::default_preset();
let results = execute_search(
"zzzznotfound",
&issues,
&metrics,
SearchMode::Text,
&weights,
10,
);
assert!(results.is_empty());
}
#[test]
fn text_search_whitespace_query_returns_no_results() {
let (issues, metrics) = make_issues_and_metrics();
let weights = SearchWeights::default_preset();
let results = execute_search(" \t ", &issues, &metrics, SearchMode::Text, &weights, 10);
assert!(results.is_empty());
}
#[test]
fn text_search_limit() {
let (issues, metrics) = make_issues_and_metrics();
let weights = SearchWeights::default_preset();
let results = execute_search("fix", &issues, &metrics, SearchMode::Text, &weights, 1);
assert!(results.len() <= 1);
}
#[test]
fn exact_id_match_promoted() {
let (issues, metrics) = make_issues_and_metrics();
let weights = SearchWeights::default_preset();
let results = execute_search("DB-3", &issues, &metrics, SearchMode::Text, &weights, 10);
assert!(!results.is_empty());
assert_eq!(results[0].issue_id, "DB-3");
}
#[test]
fn hybrid_mode_includes_components() {
let (issues, metrics) = make_issues_and_metrics();
let weights = SearchWeights::default_preset();
let results = execute_search("auth", &issues, &metrics, SearchMode::Hybrid, &weights, 10);
assert!(!results.is_empty());
assert!(results[0].text_score.is_some());
assert!(results[0].component_scores.is_some());
}
#[test]
fn hybrid_search_whitespace_query_returns_no_results() {
let (issues, metrics) = make_issues_and_metrics();
let weights = SearchWeights::default_preset();
let results = execute_search(" \n", &issues, &metrics, SearchMode::Hybrid, &weights, 10);
assert!(results.is_empty());
}
#[test]
fn hybrid_search_without_lexical_match_returns_no_results() {
let (issues, metrics) = make_issues_and_metrics();
let weights = SearchWeights::default_preset();
let results = execute_search(
"zzzznotfound",
&issues,
&metrics,
SearchMode::Hybrid,
&weights,
10,
);
assert!(results.is_empty());
}
#[test]
fn preset_weights_valid() {
let presets = [
"default",
"bug-hunting",
"sprint-planning",
"impact-first",
"text-only",
];
for name in &presets {
let w = get_preset(name);
let sum = w.text + w.pagerank + w.status + w.impact + w.priority + w.recency;
assert!((sum - 1.0).abs() < 0.001, "preset {name} sum = {sum}");
}
}
#[test]
fn custom_weights_parsing() {
let json = r#"{"text":0.5,"pagerank":0.2,"status":0.1,"impact":0.1,"priority":0.05,"recency":0.05}"#;
let weights = SearchWeights::from_json(json).unwrap();
let sum = weights.text
+ weights.pagerank
+ weights.status
+ weights.impact
+ weights.priority
+ weights.recency;
assert!((sum - 1.0).abs() < 0.001);
}
#[test]
fn short_query_detection() {
assert!(is_short_query("auth"));
assert!(is_short_query("fix bug"));
assert!(!is_short_query("authentication handling in the login flow"));
}
#[test]
fn short_query_weight_adjustment() {
let weights = SearchWeights::default_preset();
let adjusted = adjust_weights_for_short_query(&weights);
assert!(adjusted.text >= 0.55);
let sum = adjusted.text
+ adjusted.pagerank
+ adjusted.status
+ adjusted.impact
+ adjusted.priority
+ adjusted.recency;
assert!((sum - 1.0).abs() < 0.001);
}
#[test]
fn deterministic_output() {
let (issues, metrics) = make_issues_and_metrics();
let weights = SearchWeights::default_preset();
let r1 = execute_search("fix", &issues, &metrics, SearchMode::Text, &weights, 10);
let r2 = execute_search("fix", &issues, &metrics, SearchMode::Text, &weights, 10);
assert_eq!(r1.len(), r2.len());
for (a, b) in r1.iter().zip(r2.iter()) {
assert_eq!(a.issue_id, b.issue_id);
assert!((a.score - b.score).abs() < f64::EPSILON);
}
}
#[test]
fn recency_for_current_timestamp_is_high() {
let score = normalize_recency(Some(chrono::Utc::now()));
assert!(
score > 0.9,
"expected recent timestamp score > 0.9, got {score}"
);
}
}