intent_engine/search.rs
1//! Search utilities for intent-engine
2//!
3//! This module provides:
4//! 1. CJK (Chinese, Japanese, Korean) search utilities for detecting when to use
5//! LIKE fallback vs FTS5 trigram search
6//! 2. Unified search across tasks and events
7//!
8//! **Background**: SQLite FTS5 with trigram tokenizer requires at least 3 consecutive
9//! characters to match. This is problematic for CJK languages where single-character
10//! or two-character searches are common (e.g., "用户", "认证").
11//!
//! **Solution**: For short CJK queries, we fall back to LIKE search, which supports
//! substring matching of any length, albeit more slowly.
14
/// Returns `true` when `c` belongs to one of the Unicode blocks this module treats as CJK.
///
/// Covered blocks: CJK Unified Ideographs and Extensions A–F, Hiragana, Katakana and
/// Hangul Syllables. Anything else, including Bopomofo, CJK symbols/punctuation,
/// half-width forms and compatibility ideographs, is reported as non-CJK.
pub fn is_cjk_char(c: char) -> bool {
    /// Inclusive `(first, last)` code-point bounds of every block treated as CJK.
    const CJK_RANGES: &[(u32, u32)] = &[
        (0x4E00, 0x9FFF),   // CJK Unified Ideographs (most common Chinese characters)
        (0x3400, 0x4DBF),   // CJK Extension A
        (0x20000, 0x2A6DF), // CJK Extension B
        (0x2A700, 0x2B73F), // CJK Extension C
        (0x2B740, 0x2B81F), // CJK Extension D
        (0x2B820, 0x2CEAF), // CJK Extension E
        (0x2CEB0, 0x2EBEF), // CJK Extension F
        (0x3040, 0x309F),   // Hiragana (Japanese)
        (0x30A0, 0x30FF),   // Katakana (Japanese)
        (0xAC00, 0xD7AF),   // Hangul Syllables (Korean)
    ];

    let code_point = u32::from(c);
    CJK_RANGES
        .iter()
        .any(|&(first, last)| (first..=last).contains(&code_point))
}
37
38/// Determine if a query should use LIKE fallback instead of FTS5 trigram
39///
40/// Returns `true` if:
41/// - Query is a single CJK character, OR
42/// - Query is two CJK characters
43///
44/// Trigram tokenizer requires 3+ characters for matching, so we use LIKE
45/// for shorter CJK queries to ensure they work.
46pub fn needs_like_fallback(query: &str) -> bool {
47 let chars: Vec<char> = query.chars().collect();
48
49 // Single-character CJK
50 if chars.len() == 1 && is_cjk_char(chars[0]) {
51 return true;
52 }
53
54 // Two-character all-CJK
55 // This is optional - could also let trigram handle it, but trigram
56 // needs minimum 3 chars so two-char CJK won't work well
57 if chars.len() == 2 && chars.iter().all(|c| is_cjk_char(*c)) {
58 return true;
59 }
60
61 false
62}
63
/// Doubles every double-quote character in `query`.
///
/// Inside an FTS5 string (text enclosed in `"…"`), `""` stands for one literal `"`, so
/// this is the escaping step needed before wrapping user text in double quotes. Nothing
/// else is touched: FTS5 operators and syntax (`AND`, `OR`, `NOT`, `*`, `:`, parentheses,
/// ...) pass through unchanged and keep their special meaning.
///
/// # Arguments
/// * `query` - The query string to escape
///
/// # Returns
/// A new string equal to `query` with each `"` replaced by `""`.
///
/// # Example
/// ```ignore
/// use crate::search::escape_fts5;
///
/// let escaped = escape_fts5("user \"admin\" role");
/// assert_eq!(escaped, "user \"\"admin\"\" role");
/// ```
pub fn escape_fts5(query: &str) -> String {
    let mut escaped = String::with_capacity(query.len());
    for ch in query.chars() {
        if ch == '"' {
            // Emit the extra quote first; the original one follows below.
            escaped.push('"');
        }
        escaped.push(ch);
    }
    escaped
}
86
87// ============================================================================
88// Unified Search
89// ============================================================================
90
91use crate::db::models::{Event, PaginatedSearchResults, SearchResult, Task};
92use crate::error::Result;
93use crate::tasks::TaskManager;
94use sqlx::{Row, SqlitePool};
95
96pub struct SearchManager<'a> {
97 pool: &'a SqlitePool,
98}
99
100impl<'a> SearchManager<'a> {
101 pub fn new(pool: &'a SqlitePool) -> Self {
102 Self { pool }
103 }
104
105 /// Unified search across tasks and events with pagination support
106 ///
107 /// This is the new unified search method that replaces unified_search().
108 /// Key improvements:
109 /// - Pagination support (limit, offset)
110 /// - Flexible result inclusion (tasks, events)
111 /// - Optional priority-based secondary sorting
112 /// - Returns PaginatedSearchResults with metadata
113 ///
114 /// # Parameters
115 /// - `query`: FTS5 search query string
116 /// - `include_tasks`: Whether to search in tasks (default: true)
117 /// - `include_events`: Whether to search in events (default: true)
118 /// - `limit`: Maximum number of results per source (default: 20)
119 /// - `offset`: Number of results to skip (default: 0)
120 /// - `sort_by_priority`: Enable priority-based secondary sorting (default: false)
121 ///
122 /// # Returns
123 /// PaginatedSearchResults with mixed task and event results, ordered by relevance (FTS5 rank)
124 pub async fn search(
125 &self,
126 query: &str,
127 include_tasks: bool,
128 include_events: bool,
129 limit: Option<i64>,
130 offset: Option<i64>,
131 sort_by_priority: bool,
132 ) -> Result<PaginatedSearchResults> {
133 let limit = limit.unwrap_or(20);
134 let offset = offset.unwrap_or(0);
135
136 // Handle empty or whitespace-only queries
137 if query.trim().is_empty() {
138 return Ok(PaginatedSearchResults {
139 results: Vec::new(),
140 total_tasks: 0,
141 total_events: 0,
142 has_more: false,
143 limit,
144 offset,
145 });
146 }
147
148 // Handle queries with no searchable content (only special characters)
149 let has_searchable = query.chars().any(|c| c.is_alphanumeric() || is_cjk_char(c));
150 if !has_searchable {
151 return Ok(PaginatedSearchResults {
152 results: Vec::new(),
153 total_tasks: 0,
154 total_events: 0,
155 has_more: false,
156 limit,
157 offset,
158 });
159 }
160
161 // Escape FTS5 special characters
162 let escaped_query = escape_fts5(query);
163
164 let mut total_tasks: i64 = 0;
165 let mut total_events: i64 = 0;
166 let mut all_results: Vec<(SearchResult, f64)> = Vec::new();
167
168 // Check if we need LIKE fallback for short CJK queries
169 let use_like_fallback = needs_like_fallback(query);
170
171 if use_like_fallback {
172 // LIKE fallback path for short CJK queries (1-2 chars)
173 let like_pattern = format!("%{}%", query);
174
175 // Search tasks if enabled
176 if include_tasks {
177 // Get total count
178 let count_result = sqlx::query_scalar::<_, i64>(
179 "SELECT COUNT(*) FROM tasks WHERE name LIKE ? OR spec LIKE ?",
180 )
181 .bind(&like_pattern)
182 .bind(&like_pattern)
183 .fetch_one(self.pool)
184 .await?;
185 total_tasks = count_result;
186
187 // Build ORDER BY clause
188 let order_by = if sort_by_priority {
189 "ORDER BY COALESCE(priority, 0) ASC, id ASC"
190 } else {
191 "ORDER BY id ASC"
192 };
193
194 // Query tasks with pagination
195 let task_query = format!(
196 r#"
197 SELECT
198 id,
199 parent_id,
200 name,
201 spec,
202 status,
203 complexity,
204 priority,
205 first_todo_at,
206 first_doing_at,
207 first_done_at,
208 active_form
209 FROM tasks
210 WHERE name LIKE ? OR spec LIKE ?
211 {}
212 LIMIT ? OFFSET ?
213 "#,
214 order_by
215 );
216
217 let rows = sqlx::query(&task_query)
218 .bind(&like_pattern)
219 .bind(&like_pattern)
220 .bind(limit)
221 .bind(offset)
222 .fetch_all(self.pool)
223 .await?;
224
225 for row in rows {
226 let task = Task {
227 id: row.get("id"),
228 parent_id: row.get("parent_id"),
229 name: row.get("name"),
230 spec: row.get("spec"),
231 status: row.get("status"),
232 complexity: row.get("complexity"),
233 priority: row.get("priority"),
234 first_todo_at: row.get("first_todo_at"),
235 first_doing_at: row.get("first_doing_at"),
236 first_done_at: row.get("first_done_at"),
237 active_form: row.get("active_form"),
238 };
239
240 // Determine match field and create snippet
241 let (match_field, match_snippet) = if task.name.contains(query) {
242 ("name".to_string(), task.name.clone())
243 } else if let Some(ref spec) = task.spec {
244 if spec.contains(query) {
245 ("spec".to_string(), spec.clone())
246 } else {
247 ("name".to_string(), task.name.clone())
248 }
249 } else {
250 ("name".to_string(), task.name.clone())
251 };
252
253 all_results.push((
254 SearchResult::Task {
255 task,
256 match_snippet,
257 match_field,
258 },
259 1.0, // Constant rank for LIKE results
260 ));
261 }
262 }
263
264 // Search events if enabled
265 if include_events {
266 // Get total count
267 let count_result = sqlx::query_scalar::<_, i64>(
268 "SELECT COUNT(*) FROM events WHERE discussion_data LIKE ?",
269 )
270 .bind(&like_pattern)
271 .fetch_one(self.pool)
272 .await?;
273 total_events = count_result;
274
275 // Query events with pagination
276 let rows = sqlx::query(
277 r#"
278 SELECT
279 id,
280 task_id,
281 timestamp,
282 log_type,
283 discussion_data
284 FROM events
285 WHERE discussion_data LIKE ?
286 ORDER BY id ASC
287 LIMIT ? OFFSET ?
288 "#,
289 )
290 .bind(&like_pattern)
291 .bind(limit)
292 .bind(offset)
293 .fetch_all(self.pool)
294 .await?;
295
296 let task_mgr = TaskManager::new(self.pool);
297 for row in rows {
298 let event = Event {
299 id: row.get("id"),
300 task_id: row.get("task_id"),
301 timestamp: row.get("timestamp"),
302 log_type: row.get("log_type"),
303 discussion_data: row.get("discussion_data"),
304 };
305
306 // Create match snippet
307 let match_snippet = event.discussion_data.clone();
308
309 // Get task ancestry chain for this event
310 let task_chain = task_mgr.get_task_ancestry(event.task_id).await?;
311
312 all_results.push((
313 SearchResult::Event {
314 event,
315 task_chain,
316 match_snippet,
317 },
318 1.0, // Constant rank for LIKE results
319 ));
320 }
321 }
322 } else {
323 // FTS5 path for longer queries (3+ chars)
324 // Search tasks if enabled
325 if include_tasks {
326 // Get total count
327 let count_result = sqlx::query_scalar::<_, i64>(
328 "SELECT COUNT(*) FROM tasks_fts WHERE tasks_fts MATCH ?",
329 )
330 .bind(&escaped_query)
331 .fetch_one(self.pool)
332 .await?;
333 total_tasks = count_result;
334
335 // Build ORDER BY clause
336 let order_by = if sort_by_priority {
337 "ORDER BY rank ASC, COALESCE(t.priority, 0) ASC, t.id ASC"
338 } else {
339 "ORDER BY rank ASC, t.id ASC"
340 };
341
342 // Query tasks with pagination
343 let task_query = format!(
344 r#"
345 SELECT
346 t.id,
347 t.parent_id,
348 t.name,
349 t.spec,
350 t.status,
351 t.complexity,
352 t.priority,
353 t.first_todo_at,
354 t.first_doing_at,
355 t.first_done_at,
356 t.active_form,
357 COALESCE(
358 snippet(tasks_fts, 1, '**', '**', '...', 15),
359 snippet(tasks_fts, 0, '**', '**', '...', 15)
360 ) as match_snippet,
361 rank
362 FROM tasks_fts
363 INNER JOIN tasks t ON tasks_fts.rowid = t.id
364 WHERE tasks_fts MATCH ?
365 {}
366 LIMIT ? OFFSET ?
367 "#,
368 order_by
369 );
370
371 let rows = sqlx::query(&task_query)
372 .bind(&escaped_query)
373 .bind(limit)
374 .bind(offset)
375 .fetch_all(self.pool)
376 .await?;
377
378 for row in rows {
379 let task = Task {
380 id: row.get("id"),
381 parent_id: row.get("parent_id"),
382 name: row.get("name"),
383 spec: row.get("spec"),
384 status: row.get("status"),
385 complexity: row.get("complexity"),
386 priority: row.get("priority"),
387 first_todo_at: row.get("first_todo_at"),
388 first_doing_at: row.get("first_doing_at"),
389 first_done_at: row.get("first_done_at"),
390 active_form: row.get("active_form"),
391 };
392 let match_snippet: String = row.get("match_snippet");
393 let rank: f64 = row.get("rank");
394
395 // Determine match field based on snippet content
396 let match_field = if task
397 .spec
398 .as_ref()
399 .map(|s| match_snippet.to_lowercase().contains(&s.to_lowercase()))
400 .unwrap_or(false)
401 {
402 "spec".to_string()
403 } else {
404 "name".to_string()
405 };
406
407 all_results.push((
408 SearchResult::Task {
409 task,
410 match_snippet,
411 match_field,
412 },
413 rank,
414 ));
415 }
416 }
417
418 // Search events if enabled
419 if include_events {
420 // Get total count
421 let count_result = sqlx::query_scalar::<_, i64>(
422 "SELECT COUNT(*) FROM events_fts WHERE events_fts MATCH ?",
423 )
424 .bind(&escaped_query)
425 .fetch_one(self.pool)
426 .await?;
427 total_events = count_result;
428
429 // Query events with pagination
430 let rows = sqlx::query(
431 r#"
432 SELECT
433 e.id,
434 e.task_id,
435 e.timestamp,
436 e.log_type,
437 e.discussion_data,
438 snippet(events_fts, 0, '**', '**', '...', 15) as match_snippet,
439 rank
440 FROM events_fts
441 INNER JOIN events e ON events_fts.rowid = e.id
442 WHERE events_fts MATCH ?
443 ORDER BY rank ASC, e.id ASC
444 LIMIT ? OFFSET ?
445 "#,
446 )
447 .bind(&escaped_query)
448 .bind(limit)
449 .bind(offset)
450 .fetch_all(self.pool)
451 .await?;
452
453 let task_mgr = TaskManager::new(self.pool);
454 for row in rows {
455 let event = Event {
456 id: row.get("id"),
457 task_id: row.get("task_id"),
458 timestamp: row.get("timestamp"),
459 log_type: row.get("log_type"),
460 discussion_data: row.get("discussion_data"),
461 };
462 let match_snippet: String = row.get("match_snippet");
463 let rank: f64 = row.get("rank");
464
465 // Get task ancestry chain for this event
466 let task_chain = task_mgr.get_task_ancestry(event.task_id).await?;
467
468 all_results.push((
469 SearchResult::Event {
470 event,
471 task_chain,
472 match_snippet,
473 },
474 rank,
475 ));
476 }
477 }
478 } // End of else block (FTS5 path)
479
480 // Sort all results by rank (relevance)
481 all_results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
482
483 // Extract results without rank
484 let results: Vec<SearchResult> =
485 all_results.into_iter().map(|(result, _)| result).collect();
486
487 // Calculate has_more
488 let total_count = total_tasks + total_events;
489 let has_more = offset + (results.len() as i64) < total_count;
490
491 Ok(PaginatedSearchResults {
492 results,
493 total_tasks,
494 total_events,
495 has_more,
496 limit,
497 offset,
498 })
499 }
500}
501
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_is_cjk_char() {
        // Chinese characters
        assert!(is_cjk_char('中'));
        assert!(is_cjk_char('文'));
        assert!(is_cjk_char('认'));
        assert!(is_cjk_char('证'));

        // Japanese Hiragana
        assert!(is_cjk_char('あ'));
        assert!(is_cjk_char('い'));

        // Japanese Katakana
        assert!(is_cjk_char('ア'));
        assert!(is_cjk_char('イ'));

        // Korean Hangul
        assert!(is_cjk_char('가'));
        assert!(is_cjk_char('나'));

        // Non-CJK
        assert!(!is_cjk_char('a'));
        assert!(!is_cjk_char('A'));
        assert!(!is_cjk_char('1'));
        assert!(!is_cjk_char(' '));
        assert!(!is_cjk_char('.'));
    }

    #[test]
    fn test_needs_like_fallback() {
        // Single CJK character - needs fallback
        assert!(needs_like_fallback("中"));
        assert!(needs_like_fallback("认"));
        assert!(needs_like_fallback("あ"));
        assert!(needs_like_fallback("가"));

        // Two CJK characters - needs fallback
        assert!(needs_like_fallback("中文"));
        assert!(needs_like_fallback("认证"));
        assert!(needs_like_fallback("用户"));

        // Three+ CJK characters - can use FTS5
        assert!(!needs_like_fallback("用户认"));
        assert!(!needs_like_fallback("用户认证"));

        // English - can use FTS5
        assert!(!needs_like_fallback("JWT"));
        assert!(!needs_like_fallback("auth"));
        assert!(!needs_like_fallback("a")); // Single ASCII char, not CJK

        // Mixed - can use FTS5
        assert!(!needs_like_fallback("JWT认证"));
        assert!(!needs_like_fallback("API接口"));
    }

    #[test]
    fn test_needs_like_fallback_mixed_cjk_ascii() {
        // Two characters: one CJK + one ASCII - should NOT need fallback
        // because not all chars are CJK
        assert!(!needs_like_fallback("中a"));
        assert!(!needs_like_fallback("a中"));
        assert!(!needs_like_fallback("認1"));

        // Three+ characters with mixed CJK/ASCII - can use FTS5
        assert!(!needs_like_fallback("中文API"));
        assert!(!needs_like_fallback("JWT认证系统"));
        assert!(!needs_like_fallback("API中文文档"));
    }

    #[test]
    fn test_needs_like_fallback_edge_cases() {
        // Empty string - no fallback needed
        assert!(!needs_like_fallback(""));

        // Whitespace only - no fallback
        assert!(!needs_like_fallback(" "));
        assert!(!needs_like_fallback("   "));
        assert!(!needs_like_fallback("\t"));

        // Ideographic space (U+3000) is CJK punctuation, not a CJK character
        assert!(!needs_like_fallback("\u{3000}"));
        assert!(!needs_like_fallback("\u{3000}\u{3000}"));

        // Single non-CJK - no fallback
        assert!(!needs_like_fallback("1"));
        assert!(!needs_like_fallback("@"));

        // Two non-CJK - no fallback
        assert!(!needs_like_fallback("ab"));
        assert!(!needs_like_fallback("12"));
    }

    #[test]
    fn test_needs_like_fallback_supplementary_plane() {
        // An Extension B ideograph is one `char` (four bytes in UTF-8), so the
        // length rule counts characters, not bytes.
        assert!(needs_like_fallback("\u{20000}"));
        assert!(needs_like_fallback("\u{20000}\u{20001}"));
        assert!(!needs_like_fallback("\u{20000}\u{20001}\u{20002}"));
    }

    #[test]
    fn test_is_cjk_char_extension_ranges() {
        // CJK Extension A (U+3400..U+4DBF)
        assert!(is_cjk_char('\u{3400}')); // First char of Extension A
        assert!(is_cjk_char('\u{4DBF}')); // Last char of Extension A

        // CJK Unified Ideographs (U+4E00..U+9FFF) - common range
        assert!(is_cjk_char('\u{4E00}')); // First common CJK
        assert!(is_cjk_char('\u{9FFF}')); // Last common CJK

        // Characters just outside ranges - should NOT be CJK
        assert!(!is_cjk_char('\u{33FF}')); // Just before Extension A
        assert!(!is_cjk_char('\u{4DC0}')); // Just after Extension A
        assert!(!is_cjk_char('\u{4DFF}')); // Just before Unified Ideographs
        assert!(!is_cjk_char('\u{A000}')); // Just after Unified Ideographs
    }

    #[test]
    fn test_is_cjk_char_supplementary_extensions() {
        // Extensions B-F live outside the Basic Multilingual Plane.
        assert!(is_cjk_char('\u{20000}')); // First of Extension B
        assert!(is_cjk_char('\u{2A6DF}')); // Last of Extension B
        assert!(is_cjk_char('\u{2A700}')); // First of Extension C
        assert!(is_cjk_char('\u{2EBEF}')); // Last of Extension F

        assert!(!is_cjk_char('\u{2A6E0}')); // Gap between Extensions B and C
        assert!(!is_cjk_char('\u{2EBF0}')); // Just after Extension F
        assert!(!is_cjk_char('\u{1F600}')); // Emoji
    }

    #[test]
    fn test_is_cjk_char_excluded_neighbors() {
        // Blocks near the CJK ranges that are deliberately not treated as CJK
        assert!(!is_cjk_char('\u{3000}')); // Ideographic space
        assert!(!is_cjk_char('\u{F900}')); // CJK Compatibility Ideographs
        assert!(!is_cjk_char('\u{FF71}')); // Half-width Katakana
    }

    #[test]
    fn test_is_cjk_char_japanese() {
        // Hiragana range (U+3040..U+309F)
        assert!(is_cjk_char('\u{3040}')); // First Hiragana
        assert!(is_cjk_char('ひ')); // Middle Hiragana
        assert!(is_cjk_char('\u{309F}')); // Last Hiragana

        // Katakana range (U+30A0..U+30FF)
        assert!(is_cjk_char('\u{30A0}')); // First Katakana
        assert!(is_cjk_char('カ')); // Middle Katakana
        assert!(is_cjk_char('\u{30FF}')); // Last Katakana

        // Just outside Japanese ranges
        assert!(!is_cjk_char('\u{303F}')); // Before Hiragana
        assert!(!is_cjk_char('\u{3100}')); // After Katakana (Bopomofo, not CJK by our definition)
    }

    #[test]
    fn test_is_cjk_char_korean() {
        // Hangul Syllables (U+AC00..U+D7AF)
        assert!(is_cjk_char('\u{AC00}')); // First Hangul syllable (가)
        assert!(is_cjk_char('한')); // Middle Hangul
        assert!(is_cjk_char('\u{D7AF}')); // Last Hangul syllable

        // Just outside Korean range
        assert!(!is_cjk_char('\u{ABFF}')); // Before Hangul
        assert!(!is_cjk_char('\u{D7B0}')); // After Hangul
    }

    #[test]
    fn test_escape_fts5_basic() {
        // No quotes - no escaping needed
        assert_eq!(escape_fts5("hello world"), "hello world");
        assert_eq!(escape_fts5("JWT authentication"), "JWT authentication");

        // Single quote (not escaped by this function, only double quotes)
        assert_eq!(escape_fts5("user's task"), "user's task");

        // FTS5 operators are passed through untouched
        assert_eq!(escape_fts5("auth AND jwt*"), "auth AND jwt*");
    }

    #[test]
    fn test_escape_fts5_double_quotes() {
        // Single double quote
        assert_eq!(escape_fts5("\"admin\""), "\"\"admin\"\"");

        // Multiple double quotes
        assert_eq!(
            escape_fts5("\"user\" and \"admin\""),
            "\"\"user\"\" and \"\"admin\"\""
        );

        // Double quotes at different positions
        assert_eq!(
            escape_fts5("start \"middle\" end"),
            "start \"\"middle\"\" end"
        );
        assert_eq!(escape_fts5("\"start"), "\"\"start");
        assert_eq!(escape_fts5("end\""), "end\"\"");
    }

    #[test]
    fn test_escape_fts5_complex_queries() {
        // Mixed quotes and special characters
        assert_eq!(
            escape_fts5("search for \"exact phrase\" here"),
            "search for \"\"exact phrase\"\" here"
        );

        // Empty string
        assert_eq!(escape_fts5(""), "");

        // Only quotes
        assert_eq!(escape_fts5("\""), "\"\"");
        assert_eq!(escape_fts5("\"\""), "\"\"\"\"");
        assert_eq!(escape_fts5("\"\"\""), "\"\"\"\"\"\"");
    }

    #[test]
    fn test_escape_fts5_cjk_with_quotes() {
        // CJK text with quotes
        assert_eq!(escape_fts5("用户\"管理员\"权限"), "用户\"\"管理员\"\"权限");
        assert_eq!(escape_fts5("\"認証\"システム"), "\"\"認証\"\"システム");

        // Mixed CJK and English with quotes
        assert_eq!(
            escape_fts5("API\"接口\"documentation"),
            "API\"\"接口\"\"documentation"
        );
    }

    #[test]
    fn test_needs_like_fallback_unicode_normalization() {
        // Test with different Unicode representations
        // Most CJK characters don't have composition, but test general behavior

        // Standard CJK characters
        assert!(needs_like_fallback("中"));
        assert!(needs_like_fallback("日"));

        // Two CJK characters
        assert!(needs_like_fallback("中日"));
        assert!(needs_like_fallback("認證"));
    }
}