intent_engine/search.rs
1//! Search utilities for intent-engine
2//!
3//! This module provides:
4//! 1. CJK (Chinese, Japanese, Korean) search utilities for detecting when to use
5//! LIKE fallback vs FTS5 trigram search
6//! 2. Unified search across tasks and events
7//!
8//! **Background**: SQLite FTS5 with trigram tokenizer requires at least 3 consecutive
9//! characters to match. This is problematic for CJK languages where single-character
10//! or two-character searches are common (e.g., "用户", "认证").
11//!
//! **Solution**: For short CJK queries, we fall back to LIKE search, which supports
//! substring matching of any length, albeit more slowly.
14
/// Reports whether `c` falls in one of the Unicode ranges this module treats
/// as CJK (Chinese ideographs, Japanese kana, Korean Hangul syllables).
///
/// Ranges are listed in ascending code-point order. Neighbouring blocks that
/// are *not* listed (e.g. Bopomofo at U+3100) are deliberately excluded.
pub fn is_cjk_char(c: char) -> bool {
    matches!(
        c,
        // Hiragana (Japanese)
        '\u{3040}'..='\u{309F}'
        // Katakana (Japanese)
        | '\u{30A0}'..='\u{30FF}'
        // CJK Extension A
        | '\u{3400}'..='\u{4DBF}'
        // CJK Unified Ideographs (most common Chinese characters)
        | '\u{4E00}'..='\u{9FFF}'
        // Hangul Syllables (Korean)
        | '\u{AC00}'..='\u{D7AF}'
        // CJK Extension B
        | '\u{20000}'..='\u{2A6DF}'
        // CJK Extensions C through F; the four ranges are contiguous, so
        // they are expressed as a single span
        | '\u{2A700}'..='\u{2EBEF}'
    )
}
37
38/// Determine if a query should use LIKE fallback instead of FTS5 trigram
39///
40/// Returns `true` if:
41/// - Query is a single CJK character, OR
42/// - Query is two CJK characters
43///
44/// Trigram tokenizer requires 3+ characters for matching, so we use LIKE
45/// for shorter CJK queries to ensure they work.
46pub fn needs_like_fallback(query: &str) -> bool {
47 let chars: Vec<char> = query.chars().collect();
48
49 // Single-character CJK
50 if chars.len() == 1 && is_cjk_char(chars[0]) {
51 return true;
52 }
53
54 // Two-character all-CJK
55 // This is optional - could also let trigram handle it, but trigram
56 // needs minimum 3 chars so two-char CJK won't work well
57 if chars.len() == 2 && chars.iter().all(|c| is_cjk_char(*c)) {
58 return true;
59 }
60
61 false
62}
63
/// Escapes double quotes in an FTS5 query string.
///
/// FTS5 queries have their own syntax (AND, OR, NOT, `*`, `"phrase search"`,
/// ...). This function touches only the double-quote character, doubling each
/// occurrence; every other character, including single quotes and FTS5
/// operators, is copied through unchanged.
///
/// # Arguments
/// * `query` - The query string to escape
///
/// # Returns
/// A new string in which every `"` has been replaced by `""`.
///
/// # Example
/// ```ignore
/// use crate::search::escape_fts5;
///
/// let escaped = escape_fts5("user \"admin\" role");
/// assert_eq!(escaped, "user \"\"admin\"\" role");
/// ```
pub fn escape_fts5(query: &str) -> String {
    let mut escaped = String::with_capacity(query.len());
    for ch in query.chars() {
        if ch == '"' {
            escaped.push_str("\"\"");
        } else {
            escaped.push(ch);
        }
    }
    escaped
}
86
87// ============================================================================
88// Unified Search
89// ============================================================================
90
91use crate::db::models::{Event, PaginatedSearchResults, SearchResult, Task};
92use crate::error::Result;
93use crate::tasks::TaskManager;
94use sqlx::{Row, SqlitePool};
95
96pub struct SearchManager<'a> {
97 pool: &'a SqlitePool,
98}
99
100impl<'a> SearchManager<'a> {
101 pub fn new(pool: &'a SqlitePool) -> Self {
102 Self { pool }
103 }
104
105 /// Unified search across tasks and events with pagination support
106 ///
107 /// This is the new unified search method that replaces unified_search().
108 /// Key improvements:
109 /// - Pagination support (limit, offset)
110 /// - Flexible result inclusion (tasks, events)
111 /// - Optional priority-based secondary sorting
112 /// - Returns PaginatedSearchResults with metadata
113 ///
114 /// # Parameters
115 /// - `query`: FTS5 search query string
116 /// - `include_tasks`: Whether to search in tasks (default: true)
117 /// - `include_events`: Whether to search in events (default: true)
118 /// - `limit`: Maximum number of results per source (default: 20)
119 /// - `offset`: Number of results to skip (default: 0)
120 /// - `sort_by_priority`: Enable priority-based secondary sorting (default: false)
121 ///
122 /// # Returns
123 /// PaginatedSearchResults with mixed task and event results, ordered by relevance (FTS5 rank)
124 pub async fn search(
125 &self,
126 query: &str,
127 include_tasks: bool,
128 include_events: bool,
129 limit: Option<i64>,
130 offset: Option<i64>,
131 sort_by_priority: bool,
132 ) -> Result<PaginatedSearchResults> {
133 let limit = limit.unwrap_or(20);
134 let offset = offset.unwrap_or(0);
135
136 // Handle empty or whitespace-only queries
137 if query.trim().is_empty() {
138 return Ok(PaginatedSearchResults {
139 results: Vec::new(),
140 total_tasks: 0,
141 total_events: 0,
142 has_more: false,
143 limit,
144 offset,
145 });
146 }
147
148 // Handle queries with no searchable content (only special characters)
149 let has_searchable = query.chars().any(|c| c.is_alphanumeric() || is_cjk_char(c));
150 if !has_searchable {
151 return Ok(PaginatedSearchResults {
152 results: Vec::new(),
153 total_tasks: 0,
154 total_events: 0,
155 has_more: false,
156 limit,
157 offset,
158 });
159 }
160
161 // Escape FTS5 special characters
162 let escaped_query = escape_fts5(query);
163
164 let mut total_tasks: i64 = 0;
165 let mut total_events: i64 = 0;
166 let mut all_results: Vec<(SearchResult, f64)> = Vec::new();
167
168 // Check if we need LIKE fallback for short CJK queries
169 let use_like_fallback = needs_like_fallback(query);
170
171 if use_like_fallback {
172 // LIKE fallback path for short CJK queries (1-2 chars)
173 let like_pattern = format!("%{}%", query);
174
175 // Search tasks if enabled
176 if include_tasks {
177 // Get total count
178 let count_result = sqlx::query_scalar::<_, i64>(
179 "SELECT COUNT(*) FROM tasks WHERE name LIKE ? OR spec LIKE ?",
180 )
181 .bind(&like_pattern)
182 .bind(&like_pattern)
183 .fetch_one(self.pool)
184 .await?;
185 total_tasks = count_result;
186
187 // Build ORDER BY clause
188 let order_by = if sort_by_priority {
189 "ORDER BY COALESCE(priority, 0) ASC, id ASC"
190 } else {
191 "ORDER BY id ASC"
192 };
193
194 // Query tasks with pagination
195 let task_query = format!(
196 r#"
197 SELECT
198 id,
199 parent_id,
200 name,
201 spec,
202 status,
203 complexity,
204 priority,
205 first_todo_at,
206 first_doing_at,
207 first_done_at,
208 active_form,
209 owner
210 FROM tasks
211 WHERE name LIKE ? OR spec LIKE ?
212 {}
213 LIMIT ? OFFSET ?
214 "#,
215 order_by
216 );
217
218 let rows = sqlx::query(&task_query)
219 .bind(&like_pattern)
220 .bind(&like_pattern)
221 .bind(limit)
222 .bind(offset)
223 .fetch_all(self.pool)
224 .await?;
225
226 for row in rows {
227 let task = Task {
228 id: row.get("id"),
229 parent_id: row.get("parent_id"),
230 name: row.get("name"),
231 spec: row.get("spec"),
232 status: row.get("status"),
233 complexity: row.get("complexity"),
234 priority: row.get("priority"),
235 first_todo_at: row.get("first_todo_at"),
236 first_doing_at: row.get("first_doing_at"),
237 first_done_at: row.get("first_done_at"),
238 active_form: row.get("active_form"),
239 owner: row.get("owner"),
240 };
241
242 // Determine match field and create snippet
243 let (match_field, match_snippet) = if task.name.contains(query) {
244 ("name".to_string(), task.name.clone())
245 } else if let Some(ref spec) = task.spec {
246 if spec.contains(query) {
247 ("spec".to_string(), spec.clone())
248 } else {
249 ("name".to_string(), task.name.clone())
250 }
251 } else {
252 ("name".to_string(), task.name.clone())
253 };
254
255 all_results.push((
256 SearchResult::Task {
257 task,
258 match_snippet,
259 match_field,
260 },
261 1.0, // Constant rank for LIKE results
262 ));
263 }
264 }
265
266 // Search events if enabled
267 if include_events {
268 // Get total count
269 let count_result = sqlx::query_scalar::<_, i64>(
270 "SELECT COUNT(*) FROM events WHERE discussion_data LIKE ?",
271 )
272 .bind(&like_pattern)
273 .fetch_one(self.pool)
274 .await?;
275 total_events = count_result;
276
277 // Query events with pagination
278 let rows = sqlx::query(
279 r#"
280 SELECT
281 id,
282 task_id,
283 timestamp,
284 log_type,
285 discussion_data
286 FROM events
287 WHERE discussion_data LIKE ?
288 ORDER BY id ASC
289 LIMIT ? OFFSET ?
290 "#,
291 )
292 .bind(&like_pattern)
293 .bind(limit)
294 .bind(offset)
295 .fetch_all(self.pool)
296 .await?;
297
298 let task_mgr = TaskManager::new(self.pool);
299 for row in rows {
300 let event = Event {
301 id: row.get("id"),
302 task_id: row.get("task_id"),
303 timestamp: row.get("timestamp"),
304 log_type: row.get("log_type"),
305 discussion_data: row.get("discussion_data"),
306 };
307
308 // Create match snippet
309 let match_snippet = event.discussion_data.clone();
310
311 // Get task ancestry chain for this event
312 let task_chain = task_mgr.get_task_ancestry(event.task_id).await?;
313
314 all_results.push((
315 SearchResult::Event {
316 event,
317 task_chain,
318 match_snippet,
319 },
320 1.0, // Constant rank for LIKE results
321 ));
322 }
323 }
324 } else {
325 // FTS5 path for longer queries (3+ chars)
326 // Search tasks if enabled
327 if include_tasks {
328 // Get total count
329 let count_result = sqlx::query_scalar::<_, i64>(
330 "SELECT COUNT(*) FROM tasks_fts WHERE tasks_fts MATCH ?",
331 )
332 .bind(&escaped_query)
333 .fetch_one(self.pool)
334 .await?;
335 total_tasks = count_result;
336
337 // Build ORDER BY clause
338 let order_by = if sort_by_priority {
339 "ORDER BY rank ASC, COALESCE(t.priority, 0) ASC, t.id ASC"
340 } else {
341 "ORDER BY rank ASC, t.id ASC"
342 };
343
344 // Query tasks with pagination
345 let task_query = format!(
346 r#"
347 SELECT
348 t.id,
349 t.parent_id,
350 t.name,
351 t.spec,
352 t.status,
353 t.complexity,
354 t.priority,
355 t.first_todo_at,
356 t.first_doing_at,
357 t.first_done_at,
358 t.active_form,
359 t.owner,
360 COALESCE(
361 snippet(tasks_fts, 1, '**', '**', '...', 15),
362 snippet(tasks_fts, 0, '**', '**', '...', 15)
363 ) as match_snippet,
364 rank
365 FROM tasks_fts
366 INNER JOIN tasks t ON tasks_fts.rowid = t.id
367 WHERE tasks_fts MATCH ?
368 {}
369 LIMIT ? OFFSET ?
370 "#,
371 order_by
372 );
373
374 let rows = sqlx::query(&task_query)
375 .bind(&escaped_query)
376 .bind(limit)
377 .bind(offset)
378 .fetch_all(self.pool)
379 .await?;
380
381 for row in rows {
382 let task = Task {
383 id: row.get("id"),
384 parent_id: row.get("parent_id"),
385 name: row.get("name"),
386 spec: row.get("spec"),
387 status: row.get("status"),
388 complexity: row.get("complexity"),
389 priority: row.get("priority"),
390 first_todo_at: row.get("first_todo_at"),
391 first_doing_at: row.get("first_doing_at"),
392 first_done_at: row.get("first_done_at"),
393 active_form: row.get("active_form"),
394 owner: row.get("owner"),
395 };
396 let match_snippet: String = row.get("match_snippet");
397 let rank: f64 = row.get("rank");
398
399 // Determine match field based on snippet content
400 let match_field = if task
401 .spec
402 .as_ref()
403 .map(|s| match_snippet.to_lowercase().contains(&s.to_lowercase()))
404 .unwrap_or(false)
405 {
406 "spec".to_string()
407 } else {
408 "name".to_string()
409 };
410
411 all_results.push((
412 SearchResult::Task {
413 task,
414 match_snippet,
415 match_field,
416 },
417 rank,
418 ));
419 }
420 }
421
422 // Search events if enabled
423 if include_events {
424 // Get total count
425 let count_result = sqlx::query_scalar::<_, i64>(
426 "SELECT COUNT(*) FROM events_fts WHERE events_fts MATCH ?",
427 )
428 .bind(&escaped_query)
429 .fetch_one(self.pool)
430 .await?;
431 total_events = count_result;
432
433 // Query events with pagination
434 let rows = sqlx::query(
435 r#"
436 SELECT
437 e.id,
438 e.task_id,
439 e.timestamp,
440 e.log_type,
441 e.discussion_data,
442 snippet(events_fts, 0, '**', '**', '...', 15) as match_snippet,
443 rank
444 FROM events_fts
445 INNER JOIN events e ON events_fts.rowid = e.id
446 WHERE events_fts MATCH ?
447 ORDER BY rank ASC, e.id ASC
448 LIMIT ? OFFSET ?
449 "#,
450 )
451 .bind(&escaped_query)
452 .bind(limit)
453 .bind(offset)
454 .fetch_all(self.pool)
455 .await?;
456
457 let task_mgr = TaskManager::new(self.pool);
458 for row in rows {
459 let event = Event {
460 id: row.get("id"),
461 task_id: row.get("task_id"),
462 timestamp: row.get("timestamp"),
463 log_type: row.get("log_type"),
464 discussion_data: row.get("discussion_data"),
465 };
466 let match_snippet: String = row.get("match_snippet");
467 let rank: f64 = row.get("rank");
468
469 // Get task ancestry chain for this event
470 let task_chain = task_mgr.get_task_ancestry(event.task_id).await?;
471
472 all_results.push((
473 SearchResult::Event {
474 event,
475 task_chain,
476 match_snippet,
477 },
478 rank,
479 ));
480 }
481 }
482 } // End of else block (FTS5 path)
483
484 // Sort all results by rank (relevance)
485 all_results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
486
487 // Extract results without rank
488 let results: Vec<SearchResult> =
489 all_results.into_iter().map(|(result, _)| result).collect();
490
491 // Calculate has_more
492 let total_count = total_tasks + total_events;
493 let has_more = offset + (results.len() as i64) < total_count;
494
495 Ok(PaginatedSearchResults {
496 results,
497 total_tasks,
498 total_events,
499 has_more,
500 limit,
501 offset,
502 })
503 }
504}
505
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that every character in `chars` is classified as CJK.
    fn assert_all_cjk(chars: &[char]) {
        for &c in chars {
            assert!(is_cjk_char(c), "expected U+{:04X} to be CJK", c as u32);
        }
    }

    /// Asserts that no character in `chars` is classified as CJK.
    fn assert_none_cjk(chars: &[char]) {
        for &c in chars {
            assert!(!is_cjk_char(c), "expected U+{:04X} not to be CJK", c as u32);
        }
    }

    /// Asserts that every query in `queries` is routed to the LIKE fallback.
    fn assert_fallback(queries: &[&str]) {
        for &query in queries {
            assert!(
                needs_like_fallback(query),
                "expected LIKE fallback for {:?}",
                query
            );
        }
    }

    /// Asserts that no query in `queries` is routed to the LIKE fallback.
    fn assert_no_fallback(queries: &[&str]) {
        for &query in queries {
            assert!(
                !needs_like_fallback(query),
                "expected no LIKE fallback for {:?}",
                query
            );
        }
    }

    /// Asserts `escape_fts5(input) == expected` for every `(input, expected)` pair.
    fn assert_escapes(cases: &[(&str, &str)]) {
        for &(input, expected) in cases {
            assert_eq!(escape_fts5(input), expected);
        }
    }

    #[test]
    fn test_is_cjk_char() {
        assert_all_cjk(&[
            // Chinese characters
            '中', '文', '认', '证',
            // Japanese Hiragana
            'あ', 'い',
            // Japanese Katakana
            'ア', 'イ',
            // Korean Hangul
            '가', '나',
        ]);

        // Non-CJK
        assert_none_cjk(&['a', 'A', '1', ' ', '.']);
    }

    #[test]
    fn test_needs_like_fallback() {
        assert_fallback(&[
            // Single CJK character
            "中", "认", "あ", "가",
            // Two CJK characters
            "中文", "认证", "用户",
        ]);

        assert_no_fallback(&[
            // Three or more CJK characters can use FTS5
            "用户认", "用户认证",
            // English can use FTS5; a single ASCII char is not CJK
            "JWT", "auth", "a",
            // Mixed can use FTS5
            "JWT认证", "API接口",
        ]);
    }

    #[test]
    fn test_needs_like_fallback_mixed_cjk_ascii() {
        assert_no_fallback(&[
            // Two characters, one CJK and one ASCII: not all chars are CJK
            "中a", "a中", "認1",
            // Three or more characters mixing CJK and ASCII
            "中文API", "JWT认证系统", "API中文文档",
        ]);
    }

    #[test]
    fn test_needs_like_fallback_edge_cases() {
        assert_no_fallback(&[
            // Empty string
            "",
            // Whitespace only
            " ", " ",
            // Single non-CJK
            "1", "@", " ",
            // Two non-CJK
            "ab", "12",
        ]);
    }

    #[test]
    fn test_is_cjk_char_extension_ranges() {
        assert_all_cjk(&[
            // First and last char of CJK Extension A (U+3400..U+4DBF)
            '\u{3400}', '\u{4DBF}',
            // First and last char of CJK Unified Ideographs (U+4E00..U+9FFF)
            '\u{4E00}', '\u{9FFF}',
        ]);

        assert_none_cjk(&[
            '\u{33FF}', // Just before Extension A
            '\u{4DC0}', // Just after Extension A
            '\u{4DFF}', // Just before Unified Ideographs
            '\u{A000}', // Just after Unified Ideographs
        ]);
    }

    #[test]
    fn test_is_cjk_char_japanese() {
        assert_all_cjk(&[
            // Hiragana range (U+3040..U+309F): first, middle, last
            '\u{3040}', 'ひ', '\u{309F}',
            // Katakana range (U+30A0..U+30FF): first, middle, last
            '\u{30A0}', 'カ', '\u{30FF}',
        ]);

        assert_none_cjk(&[
            '\u{303F}', // Before Hiragana
            '\u{3100}', // After Katakana (Bopomofo, not CJK by our definition)
        ]);
    }

    #[test]
    fn test_is_cjk_char_korean() {
        // Hangul Syllables (U+AC00..U+D7AF): first (가), middle, last
        assert_all_cjk(&['\u{AC00}', '한', '\u{D7AF}']);

        assert_none_cjk(&[
            '\u{ABFF}', // Before Hangul
            '\u{D7B0}', // After Hangul
        ]);
    }

    #[test]
    fn test_escape_fts5_basic() {
        assert_escapes(&[
            // No quotes: nothing to escape
            ("hello world", "hello world"),
            ("JWT authentication", "JWT authentication"),
            // Single quotes are left alone; only double quotes are escaped
            ("user's task", "user's task"),
        ]);
    }

    #[test]
    fn test_escape_fts5_double_quotes() {
        assert_escapes(&[
            // One quoted word
            (r#""admin""#, r#"""admin"""#),
            // Several quoted words
            (r#""user" and "admin""#, r#"""user"" and ""admin"""#),
            // Quotes at different positions
            (r#"start "middle" end"#, r#"start ""middle"" end"#),
            ("\"start", "\"\"start"),
            ("end\"", "end\"\""),
        ]);
    }

    #[test]
    fn test_escape_fts5_complex_queries() {
        assert_escapes(&[
            // Quoted phrase inside a longer query
            (
                r#"search for "exact phrase" here"#,
                r#"search for ""exact phrase"" here"#,
            ),
            // Empty string
            ("", ""),
            // Only quotes
            ("\"", "\"\""),
            ("\"\"", "\"\"\"\""),
            ("\"\"\"", "\"\"\"\"\"\""),
        ]);
    }

    #[test]
    fn test_escape_fts5_cjk_with_quotes() {
        assert_escapes(&[
            // CJK text with quotes
            (r#"用户"管理员"权限"#, r#"用户""管理员""权限"#),
            (r#""認証"システム"#, r#"""認証""システム"#),
            // Mixed CJK and English with quotes
            (r#"API"接口"documentation"#, r#"API""接口""documentation"#),
        ]);
    }

    #[test]
    fn test_needs_like_fallback_unicode_normalization() {
        // Most CJK characters have no composed/decomposed variants, so this
        // only exercises the general behaviour on standard characters.
        assert_fallback(&[
            // Single CJK characters
            "中", "日",
            // Two CJK characters
            "中日", "認證",
        ]);
    }
}