1use std::collections::HashMap;
2
3use chrono::{DateTime, Utc};
4use serde::Serialize;
5
6use super::graph::GraphMetrics;
7use crate::model::Issue;
8
/// Ranking strategy selected for a search.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SearchMode {
    /// Rank by the lexical text score alone.
    Text,
    /// Blend the text score with graph and metadata signals.
    Hybrid,
}

impl SearchMode {
    /// Parses a mode name, ignoring ASCII case.
    ///
    /// Only `"hybrid"` selects [`SearchMode::Hybrid`]; every other input,
    /// including the empty string, falls back to [`SearchMode::Text`].
    /// The input is not trimmed, so surrounding whitespace defeats the match.
    pub fn from_str_or_default(s: &str) -> Self {
        if s.eq_ignore_ascii_case("hybrid") {
            Self::Hybrid
        } else {
            Self::Text
        }
    }

    /// Returns the canonical lowercase name of the mode.
    pub const fn as_str(self) -> &'static str {
        match self {
            Self::Hybrid => "hybrid",
            Self::Text => "text",
        }
    }
}
34
/// Per-signal weights used to blend component scores in hybrid search.
///
/// Every built-in preset in this file sums to 1.0.
#[derive(Debug, Clone, Serialize)]
pub struct SearchWeights {
    /// Weight applied to the lexical text-match score.
    pub text: f64,
    /// Weight applied to the issue's PageRank value.
    pub pagerank: f64,
    /// Weight applied to the status score.
    pub status: f64,
    /// Weight applied to the impact (blocks-count) score.
    pub impact: f64,
    /// Weight applied to the priority score.
    pub priority: f64,
    /// Weight applied to the recency score.
    pub recency: f64,
}
44
45impl SearchWeights {
46 #[must_use]
47 pub fn normalize(&self) -> Self {
48 let sum =
49 self.text + self.pagerank + self.status + self.impact + self.priority + self.recency;
50 if sum <= 0.0 {
51 return Self::default_preset();
52 }
53 Self {
54 text: self.text / sum,
55 pagerank: self.pagerank / sum,
56 status: self.status / sum,
57 impact: self.impact / sum,
58 priority: self.priority / sum,
59 recency: self.recency / sum,
60 }
61 }
62}
63
64pub fn get_preset(name: &str) -> SearchWeights {
65 match name.to_ascii_lowercase().as_str() {
66 "bug-hunting" => SearchWeights {
67 text: 0.30,
68 pagerank: 0.15,
69 status: 0.15,
70 impact: 0.15,
71 priority: 0.20,
72 recency: 0.05,
73 },
74 "sprint-planning" => SearchWeights {
75 text: 0.30,
76 pagerank: 0.20,
77 status: 0.25,
78 impact: 0.15,
79 priority: 0.05,
80 recency: 0.05,
81 },
82 "impact-first" => SearchWeights {
83 text: 0.25,
84 pagerank: 0.30,
85 status: 0.10,
86 impact: 0.20,
87 priority: 0.10,
88 recency: 0.05,
89 },
90 "text-only" => SearchWeights {
91 text: 1.0,
92 pagerank: 0.0,
93 status: 0.0,
94 impact: 0.0,
95 priority: 0.0,
96 recency: 0.0,
97 },
98 _ => SearchWeights::default_preset(),
99 }
100}
101
102impl SearchWeights {
103 pub fn default_preset() -> Self {
104 Self {
105 text: 0.40,
106 pagerank: 0.20,
107 status: 0.15,
108 impact: 0.10,
109 priority: 0.10,
110 recency: 0.05,
111 }
112 }
113
114 pub fn from_json(json_str: &str) -> Result<Self, String> {
116 let map: HashMap<String, f64> =
117 serde_json::from_str(json_str).map_err(|e| format!("invalid weights JSON: {e}"))?;
118
119 let weights = Self {
120 text: map.get("text").copied().unwrap_or(0.0),
121 pagerank: map.get("pagerank").copied().unwrap_or(0.0),
122 status: map.get("status").copied().unwrap_or(0.0),
123 impact: map.get("impact").copied().unwrap_or(0.0),
124 priority: map.get("priority").copied().unwrap_or(0.0),
125 recency: map.get("recency").copied().unwrap_or(0.0),
126 };
127
128 if weights.text < 0.0
130 || weights.pagerank < 0.0
131 || weights.status < 0.0
132 || weights.impact < 0.0
133 || weights.priority < 0.0
134 || weights.recency < 0.0
135 {
136 return Err("all weights must be non-negative".to_string());
137 }
138
139 Ok(weights.normalize())
140 }
141}
142
143fn compute_text_score(query: &str, issue: &Issue) -> f64 {
149 let query_lower = query.to_ascii_lowercase();
150 let tokens: Vec<&str> = query_lower.split_whitespace().collect();
151 if tokens.is_empty() {
152 return 0.0;
153 }
154
155 let doc = format!(
157 "{id} {id} {id} {title} {title} {labels} {desc}",
158 id = issue.id.to_ascii_lowercase(),
159 title = issue.title.to_ascii_lowercase(),
160 labels = issue.labels.join(" ").to_ascii_lowercase(),
161 desc = issue.description.to_ascii_lowercase(),
162 );
163
164 if issue.id.to_ascii_lowercase() == query_lower {
166 return 1.0;
167 }
168
169 let mut hit_count = 0usize;
171 for token in &tokens {
172 if doc.contains(token) {
173 hit_count += 1;
174 }
175 }
176
177 if hit_count == 0 {
178 return 0.0;
179 }
180
181 let token_coverage = hit_count as f64 / tokens.len() as f64;
182
183 let title_lower = issue.title.to_ascii_lowercase();
186 const TITLE_MATCH_BONUS: f64 = 0.3;
187 let title_bonus = if title_lower.contains(&query_lower) {
188 TITLE_MATCH_BONUS
189 } else {
190 0.0
191 };
192
193 let id_lower = issue.id.to_ascii_lowercase();
196 const ID_MATCH_BONUS: f64 = 0.2;
197 let id_bonus = if id_lower.contains(&query_lower) {
198 ID_MATCH_BONUS
199 } else {
200 0.0
201 };
202
203 const TOKEN_COVERAGE_WEIGHT: f64 = 0.5;
207 (token_coverage * TOKEN_COVERAGE_WEIGHT + title_bonus + id_bonus).min(1.0)
208}
209
/// Reports whether `query` counts as "short": at most 12 bytes long, or made
/// of at most two whitespace-separated tokens.
///
/// The length is measured in bytes (`str::len`), not characters.
fn is_short_query(query: &str) -> bool {
    if query.len() <= 12 {
        return true;
    }
    // "At most two tokens" is the same as "there is no third token".
    query.split_whitespace().nth(2).is_none()
}
215
216fn adjust_weights_for_short_query(weights: &SearchWeights) -> SearchWeights {
218 if weights.text >= 0.55 {
219 return weights.clone();
220 }
221 let target = 0.55;
222 let remaining = 1.0 - weights.text;
223 if remaining <= 0.0 {
224 return weights.clone();
225 }
226 let scale = (1.0 - target) / remaining;
227 SearchWeights {
228 text: target,
229 pagerank: weights.pagerank * scale,
230 status: weights.status * scale,
231 impact: weights.impact * scale,
232 priority: weights.priority * scale,
233 recency: weights.recency * scale,
234 }
235 .normalize()
236}
237
/// Maps an issue status to a score in `[0.0, 1.0]`, ignoring ASCII case.
///
/// `open` scores 1.0, `in_progress` 0.8, `closed` 0.1 and `tombstone` 0.0;
/// any other status scores a neutral 0.5.
fn normalize_status(status: &str) -> f64 {
    const KNOWN: [(&str, f64); 4] = [
        ("open", 1.0),
        ("in_progress", 0.8),
        ("closed", 0.1),
        ("tombstone", 0.0),
    ];

    KNOWN
        .iter()
        .find(|&&(name, _)| status.eq_ignore_ascii_case(name))
        .map_or(0.5, |&(_, score)| score)
}
251
/// Maps a priority to a score: 0 scores 1.0, falling by 0.2 per level down
/// to 0.2 at priority 4.
///
/// Values outside `0..=4` are clamped into that range first.
fn normalize_priority(priority: i32) -> f64 {
    const SCORES: [f64; 5] = [1.0, 0.8, 0.6, 0.4, 0.2];
    // The clamp guarantees a value in 0..=4, so the cast cannot wrap and the
    // index is always in bounds.
    SCORES[priority.clamp(0, 4) as usize]
}
261
/// Expresses `blocks_count` as a fraction of `max_blocks`.
///
/// Returns a neutral 0.5 when `max_blocks` is zero, since there is nothing
/// to scale against.
fn normalize_impact(blocks_count: usize, max_blocks: usize) -> f64 {
    match max_blocks {
        0 => 0.5,
        ceiling => blocks_count as f64 / ceiling as f64,
    }
}
268
269fn normalize_recency(updated_at: Option<DateTime<Utc>>) -> f64 {
270 let Some(ts) = updated_at else {
271 return 0.0;
272 };
273 let now = Utc::now();
274 if ts > now {
275 return 0.5;
276 }
277 let days = (now - ts).num_days();
278 (-(days as f64) / 30.0_f64).exp()
280}
281
/// One ranked hit produced by `execute_search`.
#[derive(Debug, Clone, Serialize)]
pub struct SearchResult {
    /// Id of the matching issue.
    pub issue_id: String,
    /// Final ranking score: the boosted text score in text mode, the
    /// weighted blend of all components in hybrid mode.
    pub score: f64,
    /// Title of the matching issue.
    pub title: String,
    /// Boosted text score; set only in hybrid mode and omitted from the
    /// serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub text_score: Option<f64>,
    /// Non-text component scores; set only in hybrid mode and omitted from
    /// the serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub component_scores: Option<ComponentScores>,
}
296
/// The non-text signals that feed a hybrid score, before weighting.
#[derive(Debug, Clone, Serialize)]
pub struct ComponentScores {
    /// PageRank value looked up for the issue (0.0 when absent).
    pub pagerank: f64,
    /// Score derived from the issue status.
    pub status: f64,
    /// Score derived from how many issues this one blocks.
    pub impact: f64,
    /// Score derived from the issue priority.
    pub priority: f64,
    /// Score derived from the last-updated timestamp.
    pub recency: f64,
}
305
/// Serializable search response.
///
/// The envelope's fields are flattened into the top level of the output, and
/// `preset` / `weights` are omitted when `None`.
// NOTE(review): nothing in this file constructs this type; the field notes
// below are inferred from names and types, so confirm against the caller.
#[derive(Debug, Serialize)]
pub struct RobotSearchOutput {
    /// Shared envelope, flattened into this object on serialization.
    #[serde(flatten)]
    pub envelope: crate::robot::RobotEnvelope,
    /// Presumably the query string that was searched.
    pub query: String,
    /// Presumably the result limit that was requested.
    pub limit: usize,
    /// Search mode as a string (presumably from `SearchMode::as_str`).
    pub mode: String,
    /// Preset name, if any; skipped when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub preset: Option<String>,
    /// Weights, if any; skipped when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub weights: Option<SearchWeights>,
    /// The search hits.
    pub results: Vec<SearchResult>,
    /// Free-form hint strings accompanying the output.
    pub usage_hints: Vec<String>,
}
320
321pub fn execute_search(
323 query: &str,
324 issues: &[Issue],
325 metrics: &GraphMetrics,
326 mode: SearchMode,
327 weights: &SearchWeights,
328 limit: usize,
329) -> Vec<SearchResult> {
330 let query = query.trim();
331 if query.is_empty() {
332 return Vec::new();
333 }
334
335 let max_blocks = metrics.blocks_count.values().copied().max().unwrap_or(0);
336
337 let effective_weights = if mode == SearchMode::Hybrid && is_short_query(query) {
338 adjust_weights_for_short_query(weights)
339 } else {
340 weights.clone()
341 };
342
343 let mut results: Vec<SearchResult> = issues
344 .iter()
345 .filter_map(|issue| {
346 let text_score = compute_text_score(query, issue);
347
348 let lexical_boost = if is_short_query(query) {
350 let doc = format!(
351 "{} {} {} {}",
352 issue.id,
353 issue.title,
354 issue.labels.join(" "),
355 issue.description,
356 );
357 if doc
358 .to_ascii_lowercase()
359 .contains(&query.to_ascii_lowercase())
360 {
361 0.35
362 } else {
363 0.0
364 }
365 } else {
366 0.0
367 };
368
369 let boosted_text = (text_score + lexical_boost).min(1.0);
370
371 if boosted_text <= 0.0 {
375 return None;
376 }
377
378 let (score, text_score_field, components) = match mode {
379 SearchMode::Text => (boosted_text, None, None),
380 SearchMode::Hybrid => {
381 let pagerank = metrics.pagerank.get(&issue.id).copied().unwrap_or(0.0);
382 let status = normalize_status(&issue.status);
383 let blocks = metrics.blocks_count.get(&issue.id).copied().unwrap_or(0);
384 let impact = normalize_impact(blocks, max_blocks);
385 let priority = normalize_priority(issue.priority);
386 let recency = normalize_recency(issue.updated_at);
387
388 let hybrid_score = effective_weights.text * boosted_text
389 + effective_weights.pagerank * pagerank
390 + effective_weights.status * status
391 + effective_weights.impact * impact
392 + effective_weights.priority * priority
393 + effective_weights.recency * recency;
394
395 (
396 hybrid_score,
397 Some(boosted_text),
398 Some(ComponentScores {
399 pagerank,
400 status,
401 impact,
402 priority,
403 recency,
404 }),
405 )
406 }
407 };
408
409 if score <= 0.0 {
411 return None;
412 }
413
414 Some(SearchResult {
415 issue_id: issue.id.clone(),
416 score,
417 title: issue.title.clone(),
418 text_score: text_score_field,
419 component_scores: components,
420 })
421 })
422 .collect();
423
424 results.sort_by(|a, b| {
426 b.score
427 .total_cmp(&a.score)
428 .then_with(|| a.issue_id.cmp(&b.issue_id))
429 });
430
431 if let Some(pos) = results
433 .iter()
434 .position(|r| r.issue_id.eq_ignore_ascii_case(query))
435 {
436 if pos > 0 {
437 let exact = results.remove(pos);
438 results.insert(0, exact);
439 }
440 }
441
442 if limit > 0 {
443 results.truncate(limit);
444 }
445
446 results
447}
448
#[cfg(test)]
mod tests {
    use super::*;
    use crate::analysis::graph::IssueGraph;
    use crate::model::Issue;

    /// Builds an issue with the given fields; everything else is defaulted.
    fn make_issue(id: &str, title: &str, status: &str, priority: i32) -> Issue {
        Issue {
            id: id.to_string(),
            title: title.to_string(),
            status: status.to_string(),
            priority,
            ..Issue::default()
        }
    }

    /// Five-issue fixture together with the metrics computed from its graph.
    fn make_issues_and_metrics() -> (Vec<Issue>, GraphMetrics) {
        let issues = vec![
            make_issue("AUTH-1", "Fix authentication bug", "open", 0),
            make_issue("NET-2", "Network timeout handling", "in_progress", 1),
            make_issue("DB-3", "Database migration script", "open", 2),
            make_issue("AUTH-4", "OAuth token refresh", "blocked", 1),
            make_issue("UI-5", "Dashboard layout fix", "closed", 3),
        ];
        let graph = IssueGraph::build(&issues);
        let metrics = graph.compute_metrics();
        (issues, metrics)
    }

    /// Searches the shared fixture using the default weight preset.
    fn search_fixture(query: &str, mode: SearchMode, limit: usize) -> Vec<SearchResult> {
        let (issues, metrics) = make_issues_and_metrics();
        let weights = SearchWeights::default_preset();
        execute_search(query, &issues, &metrics, mode, &weights, limit)
    }

    /// Adds up the six weights in declaration order.
    fn weight_sum(w: &SearchWeights) -> f64 {
        w.text + w.pagerank + w.status + w.impact + w.priority + w.recency
    }

    #[test]
    fn text_search_basic() {
        let results = search_fixture("authentication", SearchMode::Text, 10);

        assert!(!results.is_empty());
        assert_eq!(results[0].issue_id, "AUTH-1");
    }

    #[test]
    fn text_search_no_results() {
        let results = search_fixture("zzzznotfound", SearchMode::Text, 10);

        assert!(results.is_empty());
    }

    #[test]
    fn text_search_whitespace_query_returns_no_results() {
        let results = search_fixture(" \t ", SearchMode::Text, 10);

        assert!(results.is_empty());
    }

    #[test]
    fn text_search_limit() {
        let results = search_fixture("fix", SearchMode::Text, 1);

        assert!(results.len() <= 1);
    }

    #[test]
    fn exact_id_match_promoted() {
        let results = search_fixture("DB-3", SearchMode::Text, 10);

        assert!(!results.is_empty());
        assert_eq!(results[0].issue_id, "DB-3");
    }

    #[test]
    fn hybrid_mode_includes_components() {
        let results = search_fixture("auth", SearchMode::Hybrid, 10);

        assert!(!results.is_empty());
        assert!(results[0].text_score.is_some());
        assert!(results[0].component_scores.is_some());
    }

    #[test]
    fn hybrid_search_whitespace_query_returns_no_results() {
        let results = search_fixture(" \n", SearchMode::Hybrid, 10);

        assert!(results.is_empty());
    }

    #[test]
    fn hybrid_search_without_lexical_match_returns_no_results() {
        let results = search_fixture("zzzznotfound", SearchMode::Hybrid, 10);

        assert!(results.is_empty());
    }

    #[test]
    fn preset_weights_valid() {
        let presets = [
            "default",
            "bug-hunting",
            "sprint-planning",
            "impact-first",
            "text-only",
        ];
        for name in &presets {
            let sum = weight_sum(&get_preset(name));
            assert!((sum - 1.0).abs() < 0.001, "preset {name} sum = {sum}");
        }
    }

    #[test]
    fn custom_weights_parsing() {
        let json = r#"{"text":0.5,"pagerank":0.2,"status":0.1,"impact":0.1,"priority":0.05,"recency":0.05}"#;
        let weights = SearchWeights::from_json(json).unwrap();

        assert!((weight_sum(&weights) - 1.0).abs() < 0.001);
    }

    #[test]
    fn short_query_detection() {
        assert!(is_short_query("auth"));
        assert!(is_short_query("fix bug"));
        assert!(!is_short_query("authentication handling in the login flow"));
    }

    #[test]
    fn short_query_weight_adjustment() {
        let adjusted = adjust_weights_for_short_query(&SearchWeights::default_preset());

        assert!(adjusted.text >= 0.55);
        assert!((weight_sum(&adjusted) - 1.0).abs() < 0.001);
    }

    #[test]
    fn deterministic_output() {
        // Both runs share one fixture so only the search itself is compared.
        let (issues, metrics) = make_issues_and_metrics();
        let weights = SearchWeights::default_preset();
        let run = || execute_search("fix", &issues, &metrics, SearchMode::Text, &weights, 10);
        let first = run();
        let second = run();

        assert_eq!(first.len(), second.len());
        for (a, b) in first.iter().zip(second.iter()) {
            assert_eq!(a.issue_id, b.issue_id);
            assert!((a.score - b.score).abs() < f64::EPSILON);
        }
    }

    #[test]
    fn recency_for_current_timestamp_is_high() {
        let score = normalize_recency(Some(chrono::Utc::now()));
        assert!(
            score > 0.9,
            "expected recent timestamp score > 0.9, got {score}"
        );
    }
}